In [None]:
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from transformers import BertTokenizer, BertModel
import torch.optim.lr_scheduler as lr_scheduler
from pytorch_model_summary import summary
from torch.utils.data import Dataset
from torch import optim
import torch.nn as nn
import torch

import numpy as np
import unicodedata
import random
import math
import re

In [None]:
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook can still be executed end to end on a machine without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Process Data

In [None]:
# Load the tokenizer and the BERT model from the same pretrained checkpoint
# ('bert-base-uncased'); both are used as module-level globals by later cells.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_encoder = BertModel.from_pretrained('bert-base-uncased')

In [None]:
# Mark every BERT weight as trainable so the encoder is fine-tuned together
# with the rest of the model.
for weight in bert_encoder.parameters():
    weight.requires_grad = True

# The pooler's dense layer is already covered by the loop above; it is flagged
# again explicitly, exactly as before.
pooler_dense = bert_encoder.pooler.dense
pooler_dense.weight.requires_grad = True
pooler_dense.bias.requires_grad = True

# Uncomment to inspect which parameters are trainable:
# for name, param in bert_encoder.named_parameters():
#     print(f"{name}: {param.requires_grad}")

In [None]:
# Smoke test: tokenize one sentence, run it through BERT without tracking
# gradients, and mean-pool the token vectors into a single sentence vector.
# `padding="max_length"` + `truncation=True` is the current spelling of the
# deprecated `pad_to_max_length=True` (pad/cut every input to `max_length`).
tokenized = tokenizer.encode_plus("hello my name is nate",
                                  max_length=20,
                                  padding="max_length",
                                  truncation=True,
                                  return_attention_mask=True,
                                  return_tensors="pt")
with torch.no_grad():
    encodings = bert_encoder(**tokenized)

last_hidden_states = encodings.last_hidden_state
# Average over the token axis (dim=1), drop the singleton batch axis, and
# convert to a NumPy array.
bert_encodings = last_hidden_states.mean(dim=1).squeeze().numpy()

In [None]:
# Token ids that start and stop decoding.
# NOTE(review): presumably the [CLS] / [SEP] ids of the 'bert-base-uncased'
# vocabulary -- confirm against `tokenizer.cls_token_id` / `tokenizer.sep_token_id`.
SOS_token, EOS_token = 101, 102

In [None]:
def encodeString(text, tokenizer, encoder):
    """Tokenize `text` and return the encoder's per-token hidden states.

    Args:
        text: The sentence to encode.
        tokenizer: Object exposing `encode_plus` (e.g. a HuggingFace tokenizer).
        encoder: Callable accepting the tokenizer output as keyword arguments
            and returning an object with a `last_hidden_state` attribute.

    Returns:
        A `(last_hidden_state, attention_mask)` tuple. The input is padded or
        truncated to 20 tokens; the attention mask is the one produced by the
        tokenizer.
    """
    # `padding="max_length"` + `truncation=True` replaces the deprecated
    # `pad_to_max_length=True` and yields the same fixed-length encoding.
    indexed = tokenizer.encode_plus(text,
                                    max_length=20,
                                    padding="max_length",
                                    truncation=True,
                                    return_attention_mask=True,
                                    return_tensors="pt")
    attention_mask = indexed["attention_mask"]
    # Inference only: no gradients are needed for feature extraction.
    with torch.no_grad():
        encodings = encoder(**indexed)

    return encodings.last_hidden_state, attention_mask

In [None]:
def readLangs(path="/media/nathanmon/389E28739E282BB6/Users/Natha/Datasets/MyJarvisConversation/conversation.txt"):
    """Read the conversation transcript and return one string per exchange.

    Lines starting with "U" are treated as the user's turn (text after the
    first 6 characters) and lines starting with "J" as the reply (text after
    the first 8 characters, with every "/u" expanded to "/u ").
    NOTE(review): the offsets suggest "User: " / "Jarvis: " prefixes -- confirm
    against the data file.

    Each returned entry is `<user text>/t<reply text>`; the literal two
    characters "/t" are the separator callers split on. Line breaks are kept.

    Unpaired lines are handled gracefully instead of raising IndexError or
    leaving stray empty entries: a reply with no preceding user line is
    ignored, a user line followed by another user line is superseded by the
    later one, and a trailing user line with no reply is dropped.

    Args:
        path: Location of the transcript file. Defaults to the original
            hard-coded dataset path.

    Returns:
        list[str]: One "<user>/t<reply>" string per complete exchange.
    """
    print("Reading lines...")

    lines = []
    pending_user = None  # user turn still waiting for its reply

    with open(path, "r") as f:
        # Iterate the file lazily instead of materialising it with readlines().
        for line in f:
            if line.startswith("U"):
                pending_user = line[6:] + "/t"
            elif line.startswith("J") and pending_user is not None:
                line = line.replace("/u", "/u ")
                lines.append(pending_user + line[8:])
                pending_user = None

    return lines

In [None]:
# Upper bound (exclusive) on the number of space-separated words per side.
MAX_LENGTH = 20

def filterPair(p):
    """Return True when both `p[0]` and `p[1]` have fewer than MAX_LENGTH words.

    Words are counted by splitting on a single space. `p[1]` is only inspected
    when `p[0]` already passes the check.
    """
    if len(p[0].split(' ')) >= MAX_LENGTH:
        return False
    return len(p[1].split(' ')) < MAX_LENGTH


def filterPairs(pairs):
    """Return a new list holding only the pairs accepted by `filterPair`."""
    return list(filter(filterPair, pairs))

In [None]:
def prepareData():
    """Load the transcript and return a list of `[input, response]` pairs.

    Every raw exchange is lower-cased, stripped of line breaks and split on the
    literal "/t" separator. Only afterwards is the length filter applied, so
    that `filterPairs` sees real pairs. Filtering the raw strings instead
    would make `p[0]` / `p[1]` single characters, which always pass.

    Returns:
        list[list[str]]: The pairs accepted by `filterPairs`.
    """
    raw_lines = readLangs()
    pairs = [line.lower().replace("\n", "").split("/t") for line in raw_lines]
    return filterPairs(pairs)

In [None]:
# Build the [input, response] pairs and print one at random as a sanity check.
pairs = prepareData()
print(random.choice(pairs))

In [None]:
# Markers that identify an augmentable command, and the text file holding the
# candidate replacement values for each of them (same order in both lists).
keywords = ["/u shopping", "/u todolist", "/u wiki", "/u volume", "/a/"]
filenames = ["shopping_items", "todo_list_items", "wiki_queries", "volumes", "apps"]
augments = {"shopping_items": [], "todo_list_items": [],
            "wiki_queries": [], "volumes": [], "apps": []}

for filename in filenames:
    with open(f"/media/nathanmon/389E28739E282BB6/Users/Natha/Datasets/MyJarvisConversation/{filename}.txt", "r") as f:
        # One candidate per line. Blank lines are skipped so that an empty
        # string can never be picked as a replacement value.
        augments[filename] = [entry for entry in (line.strip() for line in f) if entry]

In [None]:
class TokenizedDataset(Dataset):
    """Holds [input, response] text pairs and serves them as tokenized batches.

    Batching is done by `getitem`, which returns a whole batch at once (it is
    not the per-item `__getitem__` protocol). Tokenization relies on the
    module-level `tokenizer`, and tensors are moved to the module-level
    `device`.

    Args:
        pairs: Sequence of `(input_text, target_text)` pairs.
        augments: Mapping from augmentation file name (e.g. "shopping_items")
            to the list of replacement strings to sample from.
        batch_size: Number of pairs per batch returned by `getitem`.
        max_length: Stored on the instance. Note that `getitem` currently
            tokenizes to a fixed length of 20 and does not read this value.
    """

    def __init__(self, pairs, augments,
                 batch_size=16, max_length=150):
        self.pairs = pairs
        self.augments = augments
        self.batch_size = batch_size
        self.max_length = max_length

        # Words kept by `sentence2num`. Besides number words this includes the
        # special volume values "mute" and "?".
        self.numbers = ["zero", "one", "two", "three", "four",
           "five", "six", "seven", "eight", "nine",
           "ten", "eleven", "twelve", "thirteen",
           "fourteen", "fifteen", "sixteen",
           "seventeen", "eighteen", "nineteen",
           "twenty", "thirty", "forty", "fifty",
           "sixty", "seventy", "eighty", "ninety",
           "hundred", "thousand", "million", "billion",
           "trillion", "quadrillion", "quintillion", "mute", "?"]

    def __len__(self):
        """Number of pairs (not number of batches)."""
        return len(self.pairs)

    def sentence2num(self, sentence):
        """Keep only the words of `sentence` listed in `self.numbers`.

        Matching is case-insensitive; the surviving words keep their original
        casing and are re-joined with single spaces.
        """
        return " ".join(word for word in sentence.split(" ")
                        if word.lower() in self.numbers)

    def find_tgt(self, response, loc="'"):
        """Return the text between the first two occurrences of `loc`.

        Raises:
            ValueError: If `response` contains fewer than two `loc` markers.
        """
        lower = response.index(loc) + len(loc)
        upper = response[lower:].index(loc) + lower
        return response[lower:upper]

    def augment(self, inp, tgt):
        """Swap the variable part of a command for a random alternative.

        The first keyword found in `inp` or `tgt` selects the replacement list.
        The current value is located (between single quotes in the target, or
        between "/a" markers in the input for apps) and replaced in both
        strings by a random entry from `self.augments`. "/a" markers are
        stripped from the returned strings when a replacement was made.

        A keyword is skipped, leaving the pair untouched for it, when:
          * the delimiters needed to locate the current value are missing;
          * the located value is empty (replacing "" would insert the
            replacement between every character);
          * it is a volume command whose value is "?" or "mute".

        Returns:
            tuple[str, str]: The (possibly rewritten) `(inp, tgt)`.
        """
        keywords = ["/u shopping", "/u todolist", "/u wiki", "/u volume", "/a"]
        filenames = ["shopping_items", "todo_list_items", "wiki_queries", "volumes", "apps"]

        for keyword, filename in zip(keywords, filenames):
            if keyword not in tgt and keyword not in inp:
                continue

            try:
                if keyword == "/u volume":
                    inp_item = self.sentence2num(self.find_tgt(tgt))
                elif keyword == "/a":
                    inp_item = self.find_tgt(inp, "/a")
                else:
                    inp_item = self.find_tgt(tgt)
                # For apps the value is marked differently in input and target.
                tgt_item = self.find_tgt(tgt) if keyword == "/a" else inp_item
            except ValueError:
                # Delimiters missing: this keyword cannot be augmented here.
                continue

            # "?" (query) and "mute" are not numeric levels; leave them alone.
            if keyword == "/u volume" and inp_item.lower() in ("?", "mute"):
                continue
            if not inp_item or not tgt_item:
                continue

            replacement = random.choice(self.augments[filename])
            inp = inp.replace(inp_item, replacement)
            tgt = tgt.replace(tgt_item, replacement)
            return inp.replace("/a", ""), tgt.replace("/a", "")
        return inp, tgt

    def getitem(self, idx, augment=False):
        """Build batch number `idx` as tensors on `device`.

        Args:
            idx: Zero-based batch index; the batch covers
                `pairs[idx * batch_size : (idx + 1) * batch_size]`.
            augment: When True, each pair first goes through `augment`.

        Returns:
            `(input_ids, token_type_ids, attention_mask, targets_in,
            targets_out)`. The first three come from tokenizing the inputs to
            length 20. `targets_in` is the tokenized target without its last
            token and `targets_out` the same sequence without its first token
            (decoder input and expected output, shifted by one position).
        """
        inps_tokenized, inps_types, inps_masked, targs_in, targs_out = [], [], [], [], []
        start_idx = idx*self.batch_size
        # Use this dataset's own batch size rather than a notebook-level global.
        for (inp, tgt) in self.pairs[start_idx:start_idx+self.batch_size]:
            if augment:
                inp, tgt = self.augment(inp, tgt)

            # `padding="max_length"` + `truncation=True` replaces the
            # deprecated `pad_to_max_length=True`.
            inp_encoded = tokenizer.encode_plus(inp,
                                                max_length=20,
                                                padding="max_length",
                                                truncation=True,
                                                return_attention_mask=True,
                                                return_tensors="pt")
            inps_tokenized.append(inp_encoded['input_ids'][0].tolist())
            inps_types.append(inp_encoded["token_type_ids"][0].tolist())
            inps_masked.append(inp_encoded['attention_mask'][0].tolist())

            tgt_encoded = tokenizer.encode_plus(tgt,
                                                max_length=20,
                                                padding="max_length",
                                                truncation=True,
                                                return_attention_mask=True,
                                                return_tensors="pt")
            tgt_tokenized = tgt_encoded['input_ids'][0]

            targs_in.append(tgt_tokenized[:-1].tolist())
            targs_out.append(tgt_tokenized[1:].tolist())

        inps_tokenized = torch.tensor(inps_tokenized).to(device)
        inps_types = torch.tensor(inps_types).to(device)
        inps_masked = torch.tensor(inps_masked).to(device)
        targs_in = torch.tensor(targs_in).to(device)
        targs_out = torch.tensor(targs_out).to(device)

        return inps_tokenized, inps_types, inps_masked, targs_in, targs_out

In [None]:
# Despite its name, `dataloader` is a TokenizedDataset (not a torch DataLoader);
# batches are pulled from it manually through its `getitem` method.
batch_size = 16
dataloader = TokenizedDataset(pairs, augments, batch_size=batch_size)

In [None]:
# Try the augmentation on one random pair; the cell output shows the
# (possibly rewritten) input and target.
pair = random.choice(pairs)
dataloader.augment(pair[0], pair[1])

In [None]:
# Show the pair sampled in the previous cell, for comparison with its augmented version.
pair

# Create Model

### Positional Encoding

In [None]:
def positional_encoding(length, depth):
    """Build a sinusoidal position table on the module-level `device`.

    For every position `p` in `[0, length)` and every index `i` over half of
    `depth`, the angle is `p / 10000 ** (i / (depth / 2))`. The sines of all
    angles are concatenated with the cosines of all angles along the last
    axis (sines first, then cosines -- they are not interleaved).

    Returns:
        A float32 tensor with one row per position.
    """
    half_depth = depth / 2

    position_idx = torch.arange(length).unsqueeze(1)                   # (length, 1)
    depth_fraction = torch.arange(half_depth).unsqueeze(0) / half_depth  # (1, n_angles)

    angles = position_idx * (1 / (10000**depth_fraction))              # (length, n_angles)

    table = torch.cat((angles.sin(), angles.cos()), dim=-1)
    return table.to(device, dtype=torch.float32)

In [None]:
class PositionalEmbedding(nn.Module):
    """Token embedding scaled by sqrt(d_model) plus a fixed positional table.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Embedding width; also the depth of the positional table.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.d_model = d_model
        # Index 0 is the padding index of the embedding table.
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        # The positional encoding injects word order into the sequence by
        # giving nearby positions similar vectors. It is precomputed for 2048
        # positions and sliced to the actual sequence length in `forward`.
        self.pos_encoding = positional_encoding(length=2048, depth=d_model)

    def compute_mask(self, *args, **kwargs):
        # NOTE(review): this delegates to `self.embedding.compute_mask`, which
        # looks like a Keras idiom; `torch.nn.Embedding` does not appear to
        # provide such a method -- verify before relying on this.
        return self.embedding.compute_mask(*args, **kwargs)

    def forward(self, x):
        """Embed the token ids `x` (sequence length read from axis 1)."""
        seq_len = x.shape[1]
        # This factor sets the relative scale of the embedding and the
        # positional encoding.
        scale = math.sqrt(float(self.d_model))
        embedded = self.embedding(x) * scale
        return embedded + self.pos_encoding.unsqueeze(0)[:, :seq_len]

### Attention

In [None]:
class BaseAttention(nn.Module):
    """Common pieces of the attention blocks: multi-head attention + LayerNorm.

    Args:
        d_model: Feature size normalised by the LayerNorm.
        **kwargs: Forwarded unchanged to `nn.MultiheadAttention`. `num_heads`,
            when present, is also stored on the instance.
    """

    def __init__(self, d_model, **kwargs):
        super().__init__()
        mha_options = dict(kwargs)
        self.num_heads = mha_options.get('num_heads')
        self.mha = nn.MultiheadAttention(**mha_options)
        self.layernorm = nn.LayerNorm(d_model)

In [None]:
class CrossAttention(BaseAttention):
    """Attention from the decoder sequence `x` over the encoder `context`.

    Both inputs are batch-first `(batch, seq, features)`. They are permuted to
    sequence-first for `nn.MultiheadAttention` and the output is permuted
    back. The result is added to `x` (residual) and layer-normalised.
    """

    def forward(self, x, context):
        x_ = x.permute(1, 0, 2)
        context_ = context.permute(1, 0, 2)
        attn_output, attn_scores = self.mha(
            query=x_,
            key=context_,
            value=context_,
            need_weights=True)
        attn_output = attn_output.permute(1, 0, 2)

        # Cache the attention scores for plotting later. The weights returned
        # by nn.MultiheadAttention are already batch-first
        # (batch, target_len, source_len), so unlike `attn_output` they must
        # not be permuted.
        self.last_attn_scores = attn_scores

        x = x + attn_output
        x = self.layernorm(x)

        return x

sample_ca = CrossAttention(d_model=256, embed_dim=128, num_heads=2, kdim=256)

In [None]:
class CausalSelfAttention(BaseAttention):
    """Self-attention in which a position may only attend to itself and earlier ones.

    The input is batch-first `(batch, seq, features)`. It is permuted to
    sequence-first for `nn.MultiheadAttention` and the output is permuted
    back, added to `x` (residual) and layer-normalised.
    """

    def forward(self, x):
        x_ = x.permute(1, 0, 2)
        # A single (seq, seq) causal mask is shared by every batch element and
        # head, so it does not need to be expanded per head. It is created on
        # the input's own device so the layer works wherever the data lives.
        attention_mask = nn.Transformer.generate_square_subsequent_mask(x_.shape[0]).to(x.device)

        attn_output = self.mha(
            query=x_,
            value=x_,
            key=x_,
            attn_mask=attention_mask,
            is_causal=True)[0]
        attn_output = attn_output.permute(1, 0, 2)
        x = x + attn_output
        x = self.layernorm(x)
        return x

sample_csa = CausalSelfAttention(d_model=256, embed_dim=128,
                                 num_heads=2, kdim=256)

### Encoder

In [None]:
class FeedForward(nn.Module):
    """Position-wise feed-forward block with residual connection and LayerNorm.

    The inner stack is Linear(d_model -> dff), ReLU, Linear(dff -> d_model),
    Dropout(dropout_rate). Both the stack and the LayerNorm are moved to the
    module-level `device` at construction time.
    """

    def __init__(self, d_model, dff, dropout_rate=0.1):
        super().__init__()
        stack = [
            nn.Linear(d_model, dff),
            nn.ReLU(),
            nn.Linear(dff, d_model),
            nn.Dropout(dropout_rate),
        ]
        self.seq = nn.Sequential(*stack).to(device)
        self.layer_norm = nn.LayerNorm(d_model).to(device)

    def forward(self, x):
        """Return LayerNorm(x + seq(x))."""
        residual = x
        return self.layer_norm(residual + self.seq(x))

sample_ffn = FeedForward(28, 512)

In [None]:
class Encoder(nn.Module):
    """BERT followed by a feed-forward block and a projection to `d_model`.

    The module wraps the module-level `bert_encoder`.

    Args:
        emb_size: Width of BERT's hidden states.
        d_model: Width of the returned context vectors.
        dff: Inner width of the feed-forward block.
        dropout_rate: Accepted for signature symmetry; it is not forwarded to
            any sub-module here.
    """

    def __init__(self, *, emb_size, d_model, dff,
                   dropout_rate=0.1):
        super(Encoder, self).__init__()

        self.bert_encoder = bert_encoder

        self.ffn = FeedForward(emb_size, dff)
        self.linear = nn.Linear(emb_size, d_model)

    def forward(self, x):
        """Encode a `(input_ids, token_type_ids, attention_mask)` tuple.

        Returns:
            Per-token context vectors of width `d_model`.
        """
        input_tensor, input_type, input_mask = x
        # Pass the tensors by keyword: BertModel.forward takes
        # `attention_mask` as its second positional parameter and
        # `token_type_ids` as its third, so a positional call in
        # (ids, types, mask) order would swap the two.
        x = self.bert_encoder(input_ids=input_tensor,
                              token_type_ids=input_type,
                              attention_mask=input_mask).last_hidden_state
        x = self.ffn(x)
        x = self.linear(x)
        return x

### Decoder

In [None]:
class DecoderLayer(nn.Module):
    """One decoder block: causal self-attention, cross-attention, feed-forward.

    Both attention sub-layers use `d_model` as embedding and key width and
    share `num_heads` and `dropout_rate`. `dropout_rate` is not forwarded to
    the feed-forward block, which keeps its own default.
    """

    def __init__(self,
                   *,
                   d_model,
                   num_heads,
                   dff,
                   dropout_rate=0.1):
        super(DecoderLayer, self).__init__()

        attention_options = dict(
            d_model=d_model,
            embed_dim=d_model,
            num_heads=num_heads,
            kdim=d_model,
            dropout=dropout_rate)

        self.causal_self_attention = CausalSelfAttention(**attention_options).to(device)
        self.cross_attention = CrossAttention(**attention_options).to(device)
        self.ffn = FeedForward(d_model, dff)

    def forward(self, x, context):
        """Run the three sub-layers in order and return the result."""
        attended = self.causal_self_attention(x=x)
        attended = self.cross_attention(x=attended, context=context)

        # Keep the most recent cross-attention scores around for plotting.
        self.last_attn_scores = self.cross_attention.last_attn_scores

        return self.ffn(attended)  # Shape `(batch_size, seq_len, d_model)`.

In [None]:
class Decoder(nn.Module):
    """Stack of `DecoderLayer`s over embedded target tokens, ending in vocabulary logits.

    Args:
        emb_size: Input width of `self.linear`.
        num_layers: Number of decoder layers.
        d_model: Model width used throughout the stack.
        num_heads: Attention heads per layer.
        dff: Inner width of each layer's feed-forward block.
        vocab_size: Size of the target vocabulary (embedding rows and logits).
        dropout_rate: Dropout applied after the embedding and inside the
            attention sub-layers.
    """

    def __init__(self, *, emb_size, num_layers, d_model, num_heads, dff, vocab_size,
                   dropout_rate=0.1):
        super(Decoder, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        # Not used by `forward`; kept so the module's parameters (and any
        # saved state dict) stay unchanged.
        self.linear = nn.Linear(emb_size, d_model)
        # Follow the notebook-wide `device` setting instead of a hard-coded
        # "cuda", consistent with the other sub-modules.
        self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size,
                                                 d_model=d_model).to(device)
        self.dropout = nn.Dropout(dropout_rate)
        self.dec_layers = nn.ModuleList([
            DecoderLayer(d_model=d_model, num_heads=num_heads,
                         dff=dff, dropout_rate=dropout_rate)
            for _ in range(num_layers)])

        self.last_attn_scores = None

        self.final_layer = nn.Linear(d_model, vocab_size)

    def forward(self, x, context):
        """Return logits over the vocabulary for every target position.

        Args:
            x: Target token ids, batch-first.
            context: Encoder output attended to by every layer.
        """
        x = self.pos_embedding(x)  # (batch_size, target_seq_len, d_model)

        x = self.dropout(x)

        for layer in self.dec_layers:
            x = layer(x, context)

        # Expose the last layer's cross-attention scores for plotting.
        self.last_attn_scores = self.dec_layers[-1].last_attn_scores
        logits = self.final_layer(x)

        return logits

### Transformer

In [None]:
class Transformer(nn.Module):
    """Convenience wrapper that chains `Encoder` and `Decoder`.

    The `bert_encoder` argument is accepted but not used by this class, and
    `self.final_layer` is created but never applied in `forward` (the logits
    come straight from the decoder).
    """

    def __init__(self, *, bert_encoder, emb_size, num_layers, d_model, num_heads,
                 dff, vocab_size, dropout_rate=0.1):
        super().__init__()
        self.encoder = Encoder(emb_size=emb_size, d_model=d_model,
                               dff=dff, dropout_rate=dropout_rate)
        self.decoder = Decoder(emb_size=emb_size, num_layers=num_layers,
                               d_model=d_model, num_heads=num_heads,
                               dff=dff, vocab_size=vocab_size,
                               dropout_rate=dropout_rate)

        self.final_layer = nn.Linear(d_model, vocab_size)

    def forward(self, input):
        """Map `(input_ids, token_type_ids, attention_mask, target_ids)` to logits."""
        input_tensor, input_type, input_mask, target_ids = input

        encoder_inputs = (input_tensor, input_type, input_mask)
        context = self.encoder(encoder_inputs)

        # The decoder's output is returned as-is.
        return self.decoder(target_ids, context)

In [None]:
# Drop a model left over from an earlier run of this cell, if there is one.
# Only the "name is not defined" case is ignored; any other error surfaces.
try:
    del transformer
except NameError:
    pass

# Model hyper-parameters.
num_layers = 5
emb_size = 768
d_model = 1024
dff = 2056
num_heads = 16
dropout_rate = 0.7

# Alternative: build the single wrapper model instead of separate halves.
# transformer = Transformer(
#     bert_encoder=bert_encoder,
#     emb_size=emb_size,
#     num_layers=num_layers,
#     d_model=d_model,
#     num_heads=num_heads,
#     dff=dff,
#     vocab_size=len(tokenizer.get_vocab()),
#     dropout_rate=dropout_rate).to(device)

encoder = Encoder(emb_size=emb_size, d_model=d_model, dff=dff, dropout_rate=dropout_rate).to(device)
decoder = Decoder(emb_size=emb_size, num_layers=num_layers, d_model=d_model,
                               num_heads=num_heads, dff=dff, vocab_size=len(tokenizer.get_vocab()),
                               dropout_rate=dropout_rate).to(device)

In [None]:
# Push one sample sentence through the encoder and print a layer summary of
# both halves of the model.
# `padding="max_length"` + `truncation=True` replaces the deprecated
# `pad_to_max_length=True`.
encoding = tokenizer.encode_plus("hello there",
                                  max_length=20,
                                  padding="max_length",
                                  truncation=True,
                                  return_attention_mask=True,
                                  return_tensors="pt").to(device)
tokenized = encoding["input_ids"]
types = encoding["token_type_ids"]
mask = encoding["attention_mask"]
context = encoder((tokenized, types, mask))

print(summary(encoder, (tokenized, types, mask), show_input=True))
# The decoder is summarised with a dummy batch of 20 zero token ids.
print(summary(decoder, torch.zeros((1, 20), dtype=torch.int32).to(device), context, show_input=True))

# Train Model

In [None]:
def masked_loss(label, pred):
    """Cross-entropy between `pred` logits and `label`, ignoring label 0.

    Both tensors are flattened over every axis but the last (class) axis of
    `pred`; positions whose label is 0 are removed before the loss is taken.

    Args:
        label: Integer class ids; 0 marks positions to ignore.
        pred: Unnormalised scores with the class axis last.

    Returns:
        Scalar tensor: mean cross-entropy over the kept positions.
    """
    num_classes = pred.size(-1)
    flat_labels = label.view(-1)
    keep = (label != 0).view(-1)

    kept_logits = pred.view(-1, num_classes)[keep]
    kept_labels = flat_labels[keep]

    return nn.CrossEntropyLoss(ignore_index=0)(kept_logits, kept_labels)


def masked_accuracy(label, pred):
    """Fraction of non-zero-label positions whose argmax prediction is correct.

    Args:
        label: Integer class ids; positions equal to 0 are excluded.
        pred: Scores whose axis 2 is the class axis.

    Returns:
        Scalar float32 tensor: correct kept positions / kept positions.
    """
    predicted_ids = pred.argmax(dim=2)
    label = label.to(predicted_ids.dtype)

    not_padding = label != 0
    correct = (label == predicted_ids) & not_padding

    return correct.to(torch.float32).sum() / not_padding.to(torch.float32).sum()

In [None]:
def _batch_loss(encoder, decoder, input_tensor, input_type, input_mask,
                tensor_in, tensor_out):
    """Forward one batch through the model(s) and return its masked loss."""
    if decoder is not None:
        context = encoder((input_tensor, input_type, input_mask))
        logits = decoder(tensor_in, context)
    else:
        # Single-model mode: `encoder` consumes the target input as well.
        logits = encoder((input_tensor, input_type, input_mask, tensor_in))
    return masked_loss(tensor_out, logits)


def train_epoch(dataloader, encoder, decoder, encoder_optimizer,
                decoder_optimizer, criterion, augment, train=True):
    """Run one pass over `dataloader` and return the mean loss per batch.

    Only full batches are processed: the number of batches is
    `len(dataloader) // dataloader.batch_size`, so a trailing partial batch is
    skipped. The returned mean is taken over exactly the batches that ran.

    Args:
        dataloader: Object exposing `__len__` (number of pairs), `batch_size`
            and `getitem(batch_index, augment)`.
        encoder: Encoder module, or the whole model when `decoder` is None.
        decoder: Decoder module, or None.
        encoder_optimizer: Optimizer stepped after every training batch.
        decoder_optimizer: Optimizer for `decoder`, or None.
        criterion: Unused; the loss is always computed by `masked_loss`.
        augment: Forwarded to `dataloader.getitem`.
        train: When True, back-propagate and step the optimizers. When False,
            evaluate only: the models are switched to eval mode (dropout off)
            and gradients are not tracked for the duration of the call, after
            which every sub-module's previous train/eval flag is restored.

    Returns:
        float: Sum of batch losses divided by the number of batches.

    Raises:
        ValueError: If the dataset holds fewer pairs than one full batch, so
            that no loss can be computed.
    """
    num_batches = len(dataloader) // dataloader.batch_size
    if num_batches == 0:
        raise ValueError(
            "dataset has %d pairs, fewer than one full batch of %d"
            % (len(dataloader), dataloader.batch_size))

    models = [model for model in (encoder, decoder) if model is not None]
    saved_modes = []
    if not train:
        # Remember each sub-module's flag so the exact mix of train/eval
        # states can be restored afterwards.
        saved_modes = [(sub, sub.training)
                       for model in models for sub in model.modules()]
        for model in models:
            model.eval()

    total_loss = 0
    try:
        for batch in range(num_batches):
            batch_tensors = dataloader.getitem(batch, augment)

            if train:
                encoder_optimizer.zero_grad()
                if decoder_optimizer is not None:
                    decoder_optimizer.zero_grad()

                loss = _batch_loss(encoder, decoder, *batch_tensors)
                loss.backward()

                encoder_optimizer.step()
                if decoder_optimizer is not None:
                    decoder_optimizer.step()
            else:
                with torch.no_grad():
                    loss = _batch_loss(encoder, decoder, *batch_tensors)

            total_loss += loss.item()
    finally:
        for sub, was_training in saved_modes:
            sub.training = was_training

    return total_loss / num_batches

In [None]:
import time
import math

def asMinutes(s):
    """Format a duration of `s` seconds as '<minutes>m <seconds>s'.

    Minutes are the floor of `s / 60`; the seconds part is what remains and
    is truncated to an integer by the `%d` format.
    """
    minutes = math.floor(s / 60)
    seconds = s - minutes * 60
    return '%dm %ds' % (minutes, seconds)

def timeSince(since, percent):
    """Return '<elapsed> (- <remaining>)' for a job started at `since`.

    Args:
        since: Start time as returned by `time.time()`.
        percent: Fraction of the work completed so far; the total duration is
            extrapolated as `elapsed / percent`.
    """
    elapsed = time.time() - since
    estimated_total = elapsed / (percent)
    remaining = estimated_total - elapsed
    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))

In [None]:
class SaveBestModel:
    """
    Save the best model seen so far during training. Whenever the current
    epoch's validation loss is lower than the lowest loss recorded so far, the
    model and optimizer states are written to disk.

    Checkpoints are written to 'checkpoints/best_encoder.pth' and, when a
    decoder is given, 'checkpoints/best_decoder.pth'. Each checkpoint is a
    dict with the keys 'epoch', 'model_state_dict', 'optimizer_state_dict'
    and 'loss' (the validation loss that triggered the save).
    """
    def __init__(
        self, best_valid_loss=float('inf')
    ):
        # Loss to beat; nothing is saved until a lower value is seen.
        self.best_valid_loss = best_valid_loss

    def __call__(
        self, current_valid_loss,
        epoch, encoder, decoder,
        encoder_optimizer, decoder_optimizer, criterion
    ):
        """Save checkpoints if `current_valid_loss` improves on the best so far.

        Args:
            current_valid_loss: Validation loss of the epoch just finished.
            epoch: Epoch index; it is reported and stored as `epoch + 1`.
            encoder: Module whose state is always saved on improvement.
            decoder: Optional second module; skipped when None.
            encoder_optimizer: Optimizer saved alongside `encoder`.
            decoder_optimizer: Optimizer saved alongside `decoder`.
            criterion: Accepted for interface compatibility; it is not
                written to the checkpoint.
        """
        if current_valid_loss < self.best_valid_loss:
            import os  # local import: only needed when a checkpoint is written

            self.best_valid_loss = current_valid_loss
            print(f"Best validation loss: {self.best_valid_loss}")
            print(f"Saving best model for epoch: {epoch+1}")

            # torch.save does not create missing directories.
            os.makedirs('checkpoints', exist_ok=True)

            # Store the loss value itself rather than the loss-function
            # object, so the checkpoint holds only plain data and tensors.
            torch.save({
                'epoch': epoch+1,
                'model_state_dict': encoder.state_dict(),
                'optimizer_state_dict': encoder_optimizer.state_dict(),
                'loss': current_valid_loss,
                }, 'checkpoints/best_encoder.pth')
            if decoder is not None:
                torch.save({
                    'epoch': epoch+1,
                    'model_state_dict': decoder.state_dict(),
                    'optimizer_state_dict': decoder_optimizer.state_dict(),
                    'loss': current_valid_loss,
                    }, 'checkpoints/best_decoder.pth')

In [None]:
def train(train_dataloader, val_dataloader, encoder, decoder, n_epochs, augment=True,
          encoder_lr=1e-3, decoder_lr=1e-3, print_every=100, plot_every=100,
          best_valid_loss=.59):
    """Train the model, validate after every epoch and keep the best checkpoint.

    Each epoch runs `train_epoch` once on the training data and once (without
    weight updates) on the validation data. The validation loss drives both
    checkpointing (`SaveBestModel`) and the ReduceLROnPlateau schedulers.
    At the end the averaged losses and learning rates are plotted.

    Args:
        train_dataloader: Training batches (see `train_epoch`).
        val_dataloader: Validation batches.
        encoder: Encoder module, or the whole model when `decoder` is None.
        decoder: Decoder module, or None.
        n_epochs: Number of epochs to run.
        augment: Forwarded to `train_epoch` for both training and validation.
        encoder_lr: Initial learning rate of the encoder optimizer.
        decoder_lr: Initial learning rate of the decoder optimizer.
        print_every: Print averaged losses and learning rates every this many
            epochs.
        plot_every: Record one averaged point for the plots every this many
            epochs.
        best_valid_loss: Validation loss a model must beat before the first
            checkpoint is written. Defaults to 0.59.

    Returns:
        list[float]: The averaged training losses recorded for plotting.
    """
    start = time.time()
    print_train_loss_total = 0  # Reset every print_every
    plot_train_loss_total = 0  # Reset every plot_every
    plot_train_losses = []

    plot_encoder_lrs = []
    plot_decoder_lrs = []

    plot_val_loss_total = 0
    plot_val_losses = []
    print_val_loss_total = 0  # Reset every print_every

    save_best = SaveBestModel(best_valid_loss=best_valid_loss)
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=encoder_lr,
                                   betas=(0.95, 0.9995), eps=1e-9)
    encoder_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(encoder_optimizer, "min", factor=0.05, patience=250)
    decoder_optimizer = None
    decoder_scheduler = None
    if decoder is not None:
        decoder_optimizer = optim.Adam(decoder.parameters(), lr=decoder_lr,
                                       betas=(0.95, 0.9995), eps=1e-9)
        decoder_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(decoder_optimizer, "min", factor=0.05, patience=250)

    # Passed through for interface compatibility only.
    criterion = nn.NLLLoss()

    for epoch in range(1, n_epochs + 1):
        train_loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer,
                                 decoder_optimizer, criterion, augment)
        print_train_loss_total += train_loss
        plot_train_loss_total += train_loss

        # Evaluate on the validation data (no weight updates).
        val_loss = train_epoch(val_dataloader, encoder, decoder, encoder_optimizer,
                               decoder_optimizer, criterion, augment, train=False)
        print_val_loss_total += val_loss
        plot_val_loss_total += val_loss

        # SaveBestModel reports and stores `epoch + 1`, while this loop counts
        # from 1; pass the zero-based index so the recorded epoch is correct.
        save_best(val_loss, epoch - 1, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)

        encoder_scheduler.step(val_loss)
        if decoder_scheduler is not None:
            decoder_scheduler.step(val_loss)

        if epoch % print_every == 0:
            print_train_loss_avg = print_train_loss_total / print_every
            print_train_loss_total = 0
            print_val_loss_avg = print_val_loss_total / print_every
            print_val_loss_total = 0
            print('%s (%d %d%%) %.4f %.4f %.7f %.7f' % (timeSince(start, epoch / n_epochs),
                epoch, epoch / n_epochs * 100, print_train_loss_avg, print_val_loss_avg,
                encoder_optimizer.param_groups[0]["lr"],
                decoder_optimizer.param_groups[0]["lr"] if decoder_optimizer is not None else encoder_optimizer.param_groups[0]["lr"]))
            plot_encoder_lrs.append(encoder_optimizer.param_groups[0]["lr"])
            if decoder is not None:
                plot_decoder_lrs.append(decoder_optimizer.param_groups[0]["lr"])

        if epoch % plot_every == 0:
            plot_train_loss_avg = plot_train_loss_total / plot_every
            plot_train_losses.append(plot_train_loss_avg)

            plot_train_loss_total = 0

            plot_val_loss_avg = plot_val_loss_total / plot_every
            plot_val_losses.append(plot_val_loss_avg)

            plot_val_loss_total = 0

    showPlot(plot_train_losses, "loss", plot_val_losses, "val_loss")
    if decoder is None:
        showPlot(plot_encoder_lrs, "encoder learning rate")
    else:
        showPlot(plot_encoder_lrs, "encoder learning_rate", plot_decoder_lrs, "decoder learning rate")
    return plot_train_losses

In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np

def showPlot(points, points_name, points2=None, points2_name=None):
    """Plot one or two series against their index on a single set of axes.

    Args:
        points: First series of y-values.
        points_name: Legend label of the first series.
        points2: Optional second series of y-values.
        points2_name: Legend label of the second series.
    """
    # A single figure is created here (a separate plt.figure() call would
    # leave an extra, empty figure in the output).
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.5)
    ax.yaxis.set_major_locator(loc)

    ax.plot(np.arange(len(points)), points)
    labels = [points_name]
    # Identity check: `!= None` is ambiguous for array-like inputs.
    if points2 is not None:
        ax.plot(points2)
        labels.append(points2_name)
    ax.legend(labels)

In [None]:
# Total number of pairs available; the train/validation split below slices this list.
len(pairs)

In [None]:
# Release cached GPU memory that PyTorch is holding but not using, before training starts.
torch.cuda.empty_cache()

In [None]:
batch_size = 32

# The first 210 pairs are used for training; everything after them is held
# out for validation.
train_dataloader = TokenizedDataset(pairs[:210], augments, batch_size=batch_size)
val_dataloader = TokenizedDataset(pairs[210:], augments, batch_size=batch_size)

# 3000 epochs, with a smaller learning rate for the BERT-based encoder than
# for the decoder; progress is printed and plotted every 5 epochs.
history = train(
    train_dataloader,
    val_dataloader,
    encoder,
    decoder,
    3000,
    encoder_lr=1e-5,
    decoder_lr=1e-4,
    print_every=5,
    plot_every=5,
)

In [None]:
# Load the best checkpoints written during training. `map_location` places
# the stored tensors on the configured device, so the files can be read even
# when they were saved from a different device.
encoder_ckpt = torch.load("checkpoints/best_encoder.pth", map_location=device)
decoder_ckpt = torch.load("checkpoints/best_decoder.pth", map_location=device)

In [None]:
# Restore the weights stored under 'model_state_dict' in each loaded checkpoint.
encoder.load_state_dict(encoder_ckpt['model_state_dict'])
decoder.load_state_dict(decoder_ckpt['model_state_dict'])

In [None]:
def sentenceFromIndexes(encoded):
    """Convert a sequence of token ids into the corresponding token strings.

    The id -> token table is built once per call from the module-level
    `tokenizer` vocabulary and keyed by the ids themselves, so the result
    does not depend on the iteration order of the vocabulary dict and the
    vocabulary is not re-listed for every token.

    Args:
        encoded: Iterable of integer token ids.

    Returns:
        list[str]: One token per id, in order.
    """
    id_to_token = {index: token for token, index in tokenizer.get_vocab().items()}
    return [id_to_token[index] for index in encoded]

In [None]:
class Chatbot():
    """Greedy decoder that turns an input sentence into a response.

    Args:
        encoder: Module mapping `(input_ids, token_type_ids, attention_mask)`
            to a context tensor.
        decoder: Module mapping `(target_ids, context)` to logits and exposing
            `last_attn_scores` after each forward pass.
    """

    def __init__(self, encoder, decoder):
        self.encoder = encoder
        self.decoder = decoder

    def __call__(self, sentence):
        """Generate a response for `sentence`.

        Decoding starts from `SOS_token` and appends the highest-scoring token
        at each step until `EOS_token` is produced or `MAX_LENGTH` tokens have
        been generated. The models run in eval mode (dropout disabled)
        without gradient tracking; every sub-module's previous train/eval
        flag is restored before returning.

        Returns:
            tuple: `(text, attention_weights)`, where `text` is the generated
            tokens joined by spaces (including the start token) and
            `attention_weights` are the decoder's last cross-attention scores
            from a final pass over the complete output.
        """
        # `padding="max_length"` + `truncation=True` replaces the deprecated
        # `pad_to_max_length=True`.
        inp_encoded = tokenizer.encode_plus(sentence,
                                            max_length=20,
                                            padding="max_length",
                                            truncation=True,
                                            return_attention_mask=True,
                                            return_tensors="pt")
        encoder_inputs = (inp_encoded['input_ids'].to(device),
                          inp_encoded["token_type_ids"].to(device),
                          inp_encoded['attention_mask'].to(device))

        # Switch to eval mode for generation, remembering each sub-module's
        # flag so the previous state can be restored exactly.
        saved_modes = [(sub, sub.training)
                       for model in (self.encoder, self.decoder)
                       for sub in model.modules()]
        self.encoder.eval()
        self.decoder.eval()

        try:
            with torch.no_grad():
                # The context depends only on the input sentence, so it is
                # computed once rather than at every decoding step.
                context = self.encoder(encoder_inputs)

                output_array = torch.tensor([[SOS_token]]).to(device)  # tokens, batch size

                for _ in range(MAX_LENGTH):
                    output = output_array.transpose(0, 1)  # batch size, tokens
                    predictions = self.decoder(output, context)  # batch size, tokens, vocab size

                    predictions = predictions[:, -1:, :]  # batch_size, 1, vocab_size
                    predicted_id = torch.argmax(predictions, -1)

                    output_array = torch.cat((output_array, predicted_id), 0)

                    if predicted_id.item() == EOS_token:
                        break

                output = torch.unsqueeze(torch.flatten(output_array), 0)

                # One more pass over the full output to refresh the cached
                # attention scores.
                self.decoder(output, context)
                attention_weights = self.decoder.last_attn_scores
        finally:
            for sub, was_training in saved_modes:
                sub.training = was_training

        tokens = sentenceFromIndexes(output[0].tolist())
        text = ' '.join(tokens)

        return text, attention_weights.to(device)

In [None]:
# Wrap the trained encoder/decoder pair in the greedy-decoding helper.
# chatbot = Chatbot(transformer)
chatbot = Chatbot(encoder, decoder)

In [None]:
def print_translation(sentence, tokens):
    """Print the input sentence and the generated text in two aligned rows.

    Both labels are padded to 15 characters and followed by ': ', so the
    values line up (the input label no longer carries its own extra colon).

    Args:
        sentence: The text given to the chatbot.
        tokens: The generated text to display.
    """
    print(f'{"Input":15s}: {sentence}')
    print(f'\n{"Prediction":15s}: {tokens}')

In [None]:
# Ask the chatbot one sample question and print the input next to the generated text.
sentence = "What is the temperature today"

# The attention weights are returned as well but are not used in this cell.
translated_text, attention_weights = chatbot(sentence)
print_translation(sentence, translated_text)