We will build a chatbot using transformers from scratch using the Pytorch library.
Pytorch already has a transformer class as well as transformer encoders and decoders. However here we will code it from scratch to help understand the concept.

We will be using the Cornell Movie Dialogs Corpus. 

# Imports

In [1]:
from collections import Counter
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torch.utils.data
import math
import torch.nn.functional as F

In [2]:
corpus_movie_conv='data/cornell movie-dialogs corpus/movie_conversations.txt'
corpus_movie_lines = 'data/cornell movie-dialogs corpus/movie_lines.txt'

# Dataset Preprocessing
1. Load the data in batches to feed into the network. To store this data we need to find max length and pad the shorter sentences.

2. These datasets tell us the conversations that happen, it includes which lines are in a conversation, the character that said each line, etc. Extract the line number and utterance. 

3. Format the conversations into a dictonary.

4. Group conversations.


In [3]:
max_len=25

# conversations
with open(corpus_movie_conv,'r') as c:
    conv = c.readlines()

# lines
with open(corpus_movie_lines,'r') as l:
    lines = l.readlines()

In [4]:
#conv
#lines

lines[0]
# we want to extract line id and the saying 
lines[0].split('+++$+++')

['L1045 ', ' u0 ', ' m0 ', ' BIANCA ', ' They do not!\n']

In [5]:
lines_dict={}
for line in lines:
    objects = line.split(' +++$+++ ')
    lines_dict[objects[0]] = objects[-1]

In [6]:
def remove_punc(string):
    '''Remove punctuation to make it easier for NN'''
    punctuations='''!()-[]{};:'"\,<>./?@#$%^&*_`~'''
    no_punc=''
    for char in string:
        if char not in punctuations:
            no_punc+=char
    return no_punc.lower()

In [7]:
pairs = []
for con in conv:
    ids=eval(con.split(' +++$+++ ')[-1])
    # note eval converts string to list and strip function removes extra spacing in a string
    for i in range(len(ids)):
        qa_pairs = []
        
        if i ==len(ids)-1:
            break
            
        first =remove_punc(lines_dict[ids[i]].strip())
        second =remove_punc(lines_dict[ids[i+1]].strip())
        
        # Create a 2D list
        # Append and split strings because you want to process words one at a time in the transformer
        # Trim to maximum length so we can assemble in a matrix
        qa_pairs.append(first.split()[:max_len])
        qa_pairs.append(second.split()[:max_len])
        
        # take the whole 2D list and append to the pairs list.
        pairs.append(qa_pairs)
        
        
# eval(conv[0].split(' +++$+++ ')[-1])
# lines_dict['L194'].strip().split()
# len(pairs)
# pairs[9]

In [8]:
# Find word frequency for all conversation pairs.
word_freq = Counter()
for pair in pairs:
    # this only updates unique words
    word_freq.update(pair[0])
    word_freq.update(pair[1])
    
# add minimum word frequency so vocabulary doesn't get out of hand. 
min_word_freq = 5
words = [w for w in word_freq.keys() if word_freq[w] > min_word_freq]

# construct word to index dictionary and pytorch will map it to a OHE which will be mapped to embedding
word_map = {k:v+1 for v,k in enumerate(words)}

# add tokens
# unknown will be for all words with a frequency of 5 or less
word_map['<unk>'] = len(word_map) + 1
word_map['<start>'] = len(word_map) + 1
word_map['<end>'] = len(word_map) + 1
# index for padding will be zero. Therefore when checking if words in a sentence !=0 the non padded words will be 1 and padded
# words will be 0
word_map['<pad>'] = 0

In [9]:
print('Total words are {}'.format(len(word_map)))

Total words are 18243


In [10]:
# Dump so we dont need to run the code again we can directly load it
with open('WORDMAP_corpus.json','w') as j:
    json.dump(word_map,j)

In [11]:
# we cant provide words to NN so we will provide indices
# note the get function will get the value of a key and if its not present it will get value of unknown token
# add padding for sententences less than max_len

def encode_question(words, word_map):
    enc_c = [word_map.get(word,word_map['<unk>']) for word in words] + [word_map['<pad>']] * (max_len - len(words))
    return enc_c

# pairs[0][0]
# encode_question(pairs[0][0],word_map)

In [12]:
# note we need to include start and end token when decoding.
def encode_reply(words, word_map):
    enc_c = [word_map['<start>']] + [word_map.get(word,word_map['<unk>']) for word in words] + [word_map['<end>']] + ([word_map['<pad>']] * (max_len - len(words)))
    return enc_c

In [13]:
# now we will encode the pairs
pairs_encoded=[]

for pair in pairs:
    ques = encode_question(pair[0],word_map)
    ans = encode_reply(pair[1],word_map)
    
    #2D list of encoded pairs
    pairs_encoded.append([ques, ans])

In [14]:
# save pairs encoded.
with open('pairs_encoded.json','w') as w:
    json.dump(pairs_encoded,w)
# now we dont need to run the code again we can directly load it

In [15]:
# Create dataset that will inherit from pytorch Dataset class so it has all of its functions and attributes

class Dataset(Dataset):
    
    def __init__(self,path='pairs_encoded.json'):
        # define loading function
        self.pairs = json.load(open(path))
        self.dataset_size = len(self.pairs)
        
        
    def __getitem__(self,i):
        '''We will override the getitem function from dataset class. This function is what retreives a sample from the 
        dataset. This will retrieve one element and loop for self.dataset_size.
        Note: needs to be long tensor because these are discrete integer values.'''
        
        question = torch.LongTensor(self.pairs[i][0])
        reply = torch.LongTensor(self.pairs[i][1])
        
        # return one pair, later in dataloader these will be returned in batches. 
        return question, reply
    def __len__(self):
        return self.dataset_size

In [16]:
# Create dataloader to load data in batches. 
train_loader = torch.utils.data.DataLoader(Dataset(), 
                                          batch_size = 100,
                                          shuffle = True,
                                         pin_memory = True)

In [17]:
# lets visualize a batch of samples. 
question, reply = next(iter(train_loader))

In [18]:
question.shape
# 100 samples, 25 words is maxlen

torch.Size([100, 25])

In [19]:
reply.shape
# 27 also includes the start and end token so 2 extra words 

torch.Size([100, 27])

# Build Architecture

In [20]:
# Create a function to create input and target masks for the decoder.

def create_masks(question, reply_input, reply_target):
    
    #input example:
    #sentence: <start> I slept last night <end>
    # reply input: <start> I slept last night
    # reply target: I slept last night <end>
    
    
    def subsequent_mask(size):
        '''creates the whole matrix mask for one sentence'''
        # takes in size of mask
        # note when we use masks we want them to be integers
        #pytorch has a function that does the transpose of masking already
        # mask needs to be integer type not float.
        mask = torch.triu(torch.ones(size,size)).transpose(0,1).type(dtype=torch.uint8)
        # since 4D tensors we need to unsqueezed in 0th dimention to add a dimention
        return mask.unsqueeze(0)
    
    # We need a question mask so words are 1 and padded elements are 0.
    # Also the question mask is just one line (batchsize=1 when you do that). try typing question[0]
    question_mask = (question!=0).to(device)
    #unsqueeze because we are using 4D tensor so mask should also be 4D, need to add 2Ds here
    question_mask = question_mask.unsqueeze(1).unsqueeze(1)  #(batchsize,1,1,maxwords)
    
    # do same for the reply input mask. This is to put a 1 where there are words and 0 where there is padding.
    reply_input_mask = reply_input!=0
    # Just like question mask again, it needs to be unsqueezed twice to make it a 4D mask, however second time needs 
    # to be done after its been combined with subsequent mask because now decoder masks both padded AND future words
    # we want decoder to only attend to words already generated
    reply_input_mask=reply_input_mask.unsqueeze(1)   #(batchsize,1,maxwords)
    reply_input_mask=reply_input_mask & subsequent_mask(reply_input.size(-1)).type_as(reply_input_mask.data)  #-1 is maxwords
    # that gave (batchsize,maxwords,maxwords)
    reply_input_mask=reply_input_mask.unsqueeze(1) 
    
    # this doesnt need to be unsqueezed because not used in model just loss 
    reply_target_mask=reply_target!=0
    
    return question_mask, reply_input_mask, reply_target_mask
    

In [27]:
'''responsible for word embeddings for input sequence and positional encoding'''
class Embeddings(nn.Module):

    def __init__(self,vocab_size,d_model,max_len=50):
        #d_model is dimentionality of embeddings, and I know max len is 25 but we will trim this later
        super(Embeddings, self).__init__()
        self.d_model=d_model
        # we will use dropout as regularization, we will follow hyperparameters in the paper
        self.dropout=nn.Dropout(0.1)
        self.embed = nn.Embedding(vocab_size,d_model)  # input is size of the dictionary of embeddings and output is size of
                                                        # the desired embedding vectors
        
        # positional encoding (pe) is fixed matrix with no params to upadate so no backprop through this matrix
        self.pe = self.create_positional_encoding(max_len,d_model)
        
    def create_positional_encoding(self,max_len,d_model):
        # recall max_len is x axis, and d_model is how many curves you create to make positional vector
        # recall must be same shape as embedding so can be added together.
        pe = torch.zeros(max_len,d_model).to(device)
        
        for pos in range(max_len):
            for i in range(0,d_model,2):
                # for each position we include sin and cos elements up until total dimantions of model
                pe[pos,i] = math.sin(pos/(10000 **((2*i)/d_model)))
                pe[pos,i+1] = math.cos(pos/(10000 **((2*(i+1))/d_model)))
                # pos will go to maximum length
                
        # we are working in batches so unsqueeze
        # 1 will be automatically expanded with same batch size as encoded words in forward function
        pe = pe.unsqueeze(0)  # 0 because batch size is in first dimention (1,max_len,d_model)
        
        return pe
    
    def forward(self,encoded_words):
        '''Tell pytorch how to run the class'''
        
        # recall we want to give more wieght to words rather than positional encodings so mult
        # by sqrt of dimentionality of model
        embeddings = self.embed(encoded_words) * math.sqrt(self.d_model) #(batchsize,max_words,d_model)
        
        #trim positional encoding by max words extracted from embeddings
#         max_words=embeddings.size[1]
        embeddings += self.pe[:,:embeddings.size(1)] # pe will automatically be expandied to batch size of embeddings matrix
        embeddings = self.dropout(embeddings)
        return embeddings 

In [28]:
# recall 3 kinds of attention in transformer`
# encoder self attention
# decoder masked self attention
# decoder source attention from encoder

class MultiHeadAttention(nn.Module):
    
    def __init__(self,heads,d_model):
        super(MultiHeadAttention,self).__init__()
        
        # need to ensure number of heads is compatible with dimentionality of embeddings 
        # assert - make sure statement is correct
        assert d_model % heads == 0
        
        # dimentionality of each head
        self.d_k = d_model // heads # division result with no remainder
        self.heads = heads
        self.dropout = nn.Dropout(0.1)
        
        # from embeddings we create three weights matricies 
        self.query = nn.Linear(d_model,d_model)
        self.key = nn.Linear(d_model,d_model)
        self.value = nn.Linear(d_model,d_model)
        
        # concat layer that allows concat head results to interact
        self.concat = nn.Linear(d_model,d_model)
        
    def forward(self,query,key,value,mask):
            '''# in self attention and masked self attention all 3 come from model input, in source attention the 
            # value comes from encoded representation from encoder 
            query, key and value: (batchsize,max_words,512)
            mask for masked self attention, mask for source attention, masked self multihead attention 
            in decoder (batchsize,1,1,max_words)'''
            
            query = self.query(query)  #(batchsize,maxwords,512)
            key = self.key(key)        #(batchsize,maxwords,512)
            value = self.value(value)  #(batchsize,maxwords,512)
            
            #reshape by number of heads
            # note cant do (-1)view trick in 1 step because it will reshape the max_words to number of heads (8), 
            # which you don't want
            # (batchsize,maxwords,512) --> (batchsize,maxwords,8,64) --> (batchsize,8,maxwords,64)
            query = query.view(query.shape[0],-1,self.heads,self.d_k).permute(0,2,1,3)  
            key = key.view(query.shape[0],-1,self.heads,self.d_k).permute(0,2,1,3)  
            value = value.view(query.shape[0],-1,self.heads,self.d_k).permute(0,2,1,3)  
            
            #dot Q and K
            # (batchsize,8,maxwords,64) dot (batchsize,8,maxwords,64)T --> (batchsize,8,maxwords,maxwords)
            # note in source attention first maxwords is max words from decoder and second is maxwords from encoder
            # in self attention theyre the same
            # when you do dot product or transpose on 4D matrix the first 2Ds stay the same. 
            scores = torch.matmul(query,key.permute(0,1,3,2)) / math.sqrt(query.size(-1))
            # mask the padded elements using masked_full method from pytorch. you can pass a mask
            # and specify a value you want replaced. we will replace with small value so its ignored
            # by softmax to prevent calculation of attention weights over padded elements
            scores = scores.masked_fill(mask ==0,1e-9)
            # pass through softmax over all words - always over -1 
            weights = F.softmax(scores,dim=-1)
            weights = self.dropout(weights)
            
            # dot product with value
            #(batchsize,8,maxwords,maxwords) dot (batchsize,8,maxwords,64)
            context = torch.matmul(weights,value)
            
            # concat context matrix into a vector for all heads together
            # (batchsize,8,maxwords,64) --> (batchsize,maxwords,8,64)--> (batchsize,maxwords,8*64)
            context = context.permute(0,2,1,3).reshape(context.shape[0],-1,self.heads*self.d_k)
            
            # run though linear interaction layer
            interacted=self.concat(context)
            return interacted  #(batchsize,maxwords,8*64)

In [29]:
class FeedForward(nn.Module):
    def __init__(self,d_model,middle_dim=2048):
        super(FeedForward,self).__init__()
        
        self.fc1=nn.Linear(d_model,middle_dim)
        self.fc2=nn.Linear(middle_dim,d_model)
        self.dropout=nn.Dropout(0,1)
        
    def forward(self,x):
        #x is output of multihead attention
        out = F.relu(self.fc1(x))
        out = self.fc2(self.dropout(out))
        return out  #(batchsize,maxwords,8*64)

In [30]:
class EncoderLayer(nn.Module):
    
    def __init__(self,d_model,heads):
        super(EncoderLayer,self).__init__()
        
        #in encoder we only have self attention - not source attention
        # its also not masked
        self.self_multihead = MultiHeadAttention(heads,d_model)
        self.feed_forward = FeedForward(d_model)
        #layernorm is applied to output of attention and FF which both have dimantionality of 512
        self.layernorm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)
        
        
    def forward(self,embeddings, mask):
        '''mask is source mask which I believe stops the model from attending to padded tokens, we can see forward 
        function of multihead class takes query key and values but those are all just the embeddings in the 
        case of embedding'''
        interacted = self.self_multihead(embeddings,embeddings,embeddings,mask)
        interacted = self.dropout(interacted)
        #residual connection
        interacted = interacted + embeddings
        #layer norm
        interacted = self.layernorm(interacted)
        
        #FF
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        #residual connection
        feed_forward_out = feed_forward_out + interacted 
        #layer norm
        encoded = self.layernorm(feed_forward_out)
        
        # encoded representation of 1 block
        return encoded

In [146]:
#word embeddings needed for the masked self attention, encoded representation from encoder needed for
#source attention,source mask needed for encoded representation, target mask for 
#self attention so can attend to its own inpstart with self attention - get query from decoder
# we can see forward function of multihead class takes query key and values but those areall just 
#the embeddings

class DecoderLayer(nn.Module):
    # same as encoder but need to add source attention which attends to encoder outputs
    def __init__(self,d_model,heads):
        super(DecoderLayer, self).__init__()
        
        
        self.self_multihead = MultiHeadAttention(heads,d_model)
        self.src_multihead = MultiHeadAttention(heads,d_model)
        self.feed_forward = FeedForward(d_model)
        #layernorm is applied to output of attention and FF which both have dimantionality of 512
        self.layernorm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)
        
    def forward(self,embeddings,encoded,src_mask, target_mask):
        query = self.self_multihead(embeddings,embeddings,embeddings,target_mask)
        query = self.dropout(query)
        
        
        #layer norm and residule connections
        query = self.layernorm(query + embeddings)
        
        # now we need to apply source multi head attention that attends to the encoder output
        # the query comes from self attention of decoder, keys and values come from encoded representation
        interacted = self.src_multihead(query,encoded,encoded,src_mask)
        interacted = self.dropout(interacted)
        
        #layer norm and residule connections
        interacted = self.layernorm(interacted + query)
        
        #feed forward
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        decoded = self.layernorm(feed_forward_out + interacted)
        
        return decoded 


In [147]:
class Transformer(nn.Module):
    
    def __init__(self, d_model, heads, num_layers, word_map):
        super(Transformer,self).__init__()
        # num_layers is number of encoder layers, all have predefined architecture we specified but
        # will have different weights
        self.d_model = d_model
        self.vocab = len(word_map)
        self.embed = Embeddings(self.vocab, d_model)
        
        #assemble models in a list, this is list of 6 encoder layers
        # note underscore is more memory efficient, 
        self.encoder = nn.ModuleList([EncoderLayer(d_model, heads) for _ in range(num_layers)])
        self.decoder = nn.ModuleList([DecoderLayer(d_model, heads) for _ in range(num_layers)])
        
        #classification layer
        #input will be size d_model which is size of decoded that is returned in decoder layer
        self.logit = nn.Linear(d_model,self.vocab)
        
    def encode(self,src_words,src_mask):
        #src_words is question and src_mask is the question mask
        src_embeddings = self.embed(src_words)
        
        # for each encoder layer supply updated parameter inputs from output of previous layer
        for layer in self.encoder:
            src_embeddings = layer(src_embeddings, src_mask)
        return src_embeddings
    
    def decode(self,target_words,target_mask, src_embeddings,src_mask):
        # src mask prevents attention over padded words
        # target mask is mask for self attention to rpevent peeking
        
        # use same embeddings because question and reply share the same vocab - different for translation tasks
        tgt_embeddings = self.embed(target_words)
        
        # for each encoder layer supply updated parameter inputs from output of previous layer
        for layer in self.decoder:
            tgt_embeddings = layer(tgt_embeddings,src_embeddings, src_mask, target_mask)
        return tgt_embeddings
    
    def forward(self,src_words,src_mask,target_words,target_mask):
        encoded = self.encode(src_words,src_mask)
        decoded = self.decode(target_words,target_mask,encoded,src_mask)
        out = self.logit(decoded)
        # log softmax doesnt need to be used if we use nn.crossentorpyloss because it will automatically take log and softmax
        #output, but here were usng KL divergence loss so we need to specify it manually. 
        out = F.log_softmax(out)

        return out

In [148]:
class AdamWarmup:
    '''Create an instance of this to train/update weights'''
    def __init__(self,model_size,warmup_steps,optimizer):
        self.model_size = model_size
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self.current_step = 0
        self.lr = 0
        
    def get_lr(self):
        '''create adamwarmup equation'''
        return self.model_size ** (-0.5) * min(self.current_step ** (-0.5), self.current_step * self.warmup_steps ** (-1.5))
    
    def step(self):
        '''called each time you update lr and will update the weights
        recall the normal step function optimizer.step() which does w<--w-lr*gradient.'''
        
        self.current_step +=1
        lr=self.get_lr()
        
        for param_group in self.optimizer.param_groups:
            #params group is a dictionary that includes the lr which we will update
            param_group['lr'] = lr
        
        self.lr=lr
        
        #update weight
        self.optimizer.step()
            

In [172]:
class LossWithLS(nn.Module):
    def __init__(self, size, smooth):
        #size is vocab size and smooth is smoothing param
        super(LossWithLS,self).__init__()
        self.criterion = nn.KLDivLoss(size_average = False, reduce = False)
        #size avg =true would average over all samples however we manually compute avg by number of words in the mask. In the
        # mask 1 means you have a word so we will sum to get num of words and divide by this
        # reduce =True would change the shape to make 1D vector which we dont want because we want to multiply predictions
        # with the mask later on and the mask has a specific shape we want to follow. 
        # to get average the words correspond to a value of 1 so we will sum up all of mask. 
        self.confidence = 1 - smooth
        self.smooth = smooth
        self.size = size
        
    def forward(self, prediction, target, mask):
        '''takes the prediction, target, and mask for the target 
         prediction comes out of logit layer in decoder --> (batchsize,max_words,vocab_size)
            target mask comes from create_mask function which output question mask, reply_input_mask (utilized 
            in transformer), and reply_target_mask (utilized in loss function)
            target and mask have shape (batch_size,max_words)'''
            
        # do reshaping because we are processing words in parallel rather than in sequence:
        #reshape prediction (batchsize,maxwords,vocabsize) --> (batchsize*maxwords,vocabsize)
        prediction = prediction.view(-1,prediction.size(-1))
        # reshape target and mask (batchsize,maxwords) --> (batchsize*maxwords)
        # note target is indicies of correct vocab so values go up to vocabsize
        # print(target.size())
        target = target.reshape(-1)
        mask = mask.float()
        # note target and mask have the same shape
        mask = mask.reshape(-1)
            
        #labels will be smoothed version of target variable
        # copy predictions to get label variable in proper shape which will be smoothed 
        labels = prediction.data.clone()
        #replace data
        labels.fill_(self.smooth / self.size-1)
        # place confidence value in index of correct class
        # scatter function will create labels in which we want to minimize KL div
        # input first dimention(vocab size),  the indicies of the target, and with smoothing
        # also note target needs to be same size as labels in scatter so unsqueeze
        labels.scatter(1,target.data.unsqueeze(1),self.confidence)
        # now labels have smoothed based on the true target to minimize KL divergence with prediction
        
        #calculate loss
        loss = self.criterion(prediction,labels)   #(batchsize*maxwords,vocabsize) shape is retained
        # because we specified False in instance of KL div
        
        # take sum of loss over dim=1(vocab) to get 1D tensor, mask is same shape as 1D tensor 
        # so can be multiplied to mask out the padded words. sum losses for non padded words. Get
        # average loss of non padded words by dividing by number of non padded words
        loss = (loss.sum(1) * mask).sum() / mask.sum()

        return loss

In [174]:
d_model=512
heads = 8
num_layers= 1  # in paper it is 6
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
epochs = 1  # in paper its 25

with open('WORDMAP_corpus.json','r') as j:
    word_map = json.load(j)
    
    
transformer = Transformer(d_model=d_model, heads = heads, num_layers = num_layers, word_map = word_map)
transformer.to(device)

# parameters is all weights, define lr as 0 because adam warmup will change it and set it
adam_optimizer = torch.optim.Adam(transformer.parameters(),lr=0, betas = (0.9, 0.98), eps = 1e-9)
transformer_optimizer = AdamWarmup(model_size = d_model, warmup_steps = 4000, optimizer = adam_optimizer)

#size is vocab size aka length of wordmap
criterion = LossWithLS(size = len(word_map), smooth = 0.3)

In [175]:
def train(train_loader, transformer, criterion, epoch):
    '''recieves train_laoder which is in charge of loading in batches, the model which is the transformer, 
    criterion which is loss function, and epoch number'''
    
    # specify training mode because we used batchnorm which acts differently in train and test time
    transformer.train()
    sum_loss = 0
    #take incriment according to number of samles
    count = 0
    
    # we can see from train loader function that it returns question and reply
    for i, (question, reply) in enumerate(train_loader):
        
        #get batch size
        samples = question.shape[0]
        
        #move to device
        question = question.to(device)
        reply = reply.to(device)
        
        # prepare encoder input, take whole batch and everything in maxwords except last word which is end token
        reply_input = reply[:,:-1]
        #dont take start token
        reply_target = reply[:,1:]
        #print(reply_target.size())
        #eg sentence: <start> I went home <end>
        # input: <start> I went home
        # target aka decoder prediction: I went home <end>
        
        # create masks
        question_mask, reply_input_mask, reply_target_mask = create_masks(question, reply_input, reply_target)
        
        # run inputs and masks through transformer
        out = transformer(question,question_mask,reply_input,reply_input_mask)
        
        #calculate loss
        loss= criterion(out,reply_target, reply_target_mask)
        
        #backprop
        # zero out gradient at each new batch
        transformer_optimizer.optimizer.zero_grad()
        #calcualate gradient
        loss.backward()
        #update weight
        transformer_optimizer.step()
        
        #calculate statistics
        sum_loss+=loss.item()*samples  # only get value not whole tensor. At last batch for an epoch the count = len(train_loader)*100
        count += samples
        if i% 100 == 0:
            print('Epoch [{}][{}]/[{}]\tLoss: {:.3f}'.format(epoch,i,len(train_loader),sum_loss/count))

In [176]:
def evaluate(transformer, question, question_mask, max_len, word_map):
    '''Performs greedy search decoding with batch size of 1.'''
    # output of transformer is indicies so we want to do opposite of wordmap
    rev_word_map = {v:k for k,v in word_map.items()}
    
    transformer.eval()
    # create start token to tell decoder to start decoding, get from wordmap
    start_token = word_map['<start>']
    # get encoded representation
    encoded = transformer.encode(question,question_mask)
    words = torch.LongTensor([[start_token]]).to(device)  #must be 2D matrix so 2 brackets. word -->(1,1)
    
    #concat generated words frm decoder to next input of decoder
    # note first token is start token
    for step in range(max_len - 1):
        size = words.shape[0]  # in first iteration will be 1 then will increase
        
        #create target mask, cant use create mask funtion because size is varying eg size starts with just 1 word
        # so for each step impliment new target mask
        target_mask = torch.triu(torch.ones(size,size)).transpose(0,1).type(dtype=torch.uint8)
        # need to unsqueeze because we need it in 4D
        target_mask = target_mask.to(device).unsqueeze(0)
        
        # decode 
        decoded = transformer.decode(words,target_mask, encoded,question_mask)
        #output layer
        
        #decoded is shape (batchsize=1,maxwords=1,vocabsize)
        # take only last element, we only want prediction of last word so we want to take final decoded output
        predictions = transformer.logit(decoded[:,-1])
        # predictions is shape (1,vocabsize)
        #predict next word by taking max of output
        _, next_word = torch.max(predictions,dim=1)  # next_word -->(1,1)
        #since we called torch it returns a tensor, extract the value
        next_word = next_word.item()
        
        #check to see if were at the end
        if next_word == word_map['<end>']:
            break
            
        words = torch.cat([words,torch.LongTensor([[next_word]]).to(device)], dim=1) # shape is now (1,step+2) because on dim=1
        
        # at end we have rpedicted words concat together
        #turn into sentence by getting rid of the 0th dimention, also in target we dont want start token
    words = words.squeeze(0)
    words = words.tolist()
    sen_idx = [w for w in words if w not in {word_map['<start>']}]
    #convert indices to words #convert to string
    sentence = " ".join(rev_word_map[sen_idx[k]] for k in range(len(sen_idx)))
        
    return sentence 

In [177]:
#training
for epoch in range(epochs):
    train(train_loader, transformer, criterion, epoch)
    #save checkpoint(saving a dictionary) - model weights, optimizer, and model, and epoch we're at
    state = {'epoch':epoch, 'transformer':transformer, 
             'transformer_optimizer':transformer_optimizer}
    torch.save(state, 'checkpoint_'+str(epoch)+'.tar')

torch.Size([100, 26])


  out = F.log_softmax(out)


tensor(0., grad_fn=<DivBackward0>)
Epoch [0][0]/[2217]	Loss: 0.000
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])

torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0.

tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
Epoch [0][300]/[2217]	Loss: 0.000
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26

torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0.

tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward

torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0.

tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward

tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward

torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0.

tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
Epoch [0][1300]/[2217]	Loss: 0.000
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 2

torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0.

tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward

torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0.

tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward

torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0.

tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward0>)
torch.Size([100, 26])
tensor(0., grad_fn=<DivBackward

In [None]:
# function for evaluation on users own input
# load checkpoint so if you cleared the kernel and ran this it owuld load the checkpoints with the
# weights and transformer model
# usually include epoch number, we will train it for 1 epoch so put a zero
checkpoint = torch.load('checkpoint_0.tar')
#load model and its weights
transformer = checkpoint['transformer']

#perform eval- go until user quits
while(1):
    #get user input
    question = input("Question: ")
    if question == 'quit':
        break
    max_len = input("Enter max words to be generated: ")
    
    #encode questions using wordmap
    enc_question =[word_map.get(word,word_map['<unk>']) for word in question.split()] 
    #transform to long tensor and get index of each word
    question = torch.LongTensor(enc_question).to(device).unsqueeze(0)
    #create mask for all non padded words (aka not equal to zero)
    #unsqueeze twice because 4D
    question_mask = (question!=0).to(device).unsqueeze(1).unsqueeze(1)
    #evaluate function to return sentence as a string
    sentence = evaluate(transformer,question,question_mask,int(max_len),word_map)
    print(sentence)

Question: how are you
Enter max words to be generated: 3
ups utoldu
Question: what is the weather?
Enter max words to be generated: 5
fruitcake tales eliminate —
