In [1]:
from collections import Counter
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torch.utils.data
import math
import torch.nn.functional as F

In [2]:
import os
!mkdir -p data
os.chdir("/content/data")

In [3]:
!wget http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip
!unzip -o cornell_movie_dialogs_corpus.zip

# os.chdir("/content")

--2021-02-16 10:02:18--  http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip
Resolving www.cs.cornell.edu (www.cs.cornell.edu)... 132.236.207.36
Connecting to www.cs.cornell.edu (www.cs.cornell.edu)|132.236.207.36|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9916637 (9.5M) [application/zip]
Saving to: ‘cornell_movie_dialogs_corpus.zip’


2021-02-16 10:02:19 (12.4 MB/s) - ‘cornell_movie_dialogs_corpus.zip’ saved [9916637/9916637]

Archive:  cornell_movie_dialogs_corpus.zip
   creating: cornell movie-dialogs corpus/
  inflating: cornell movie-dialogs corpus/.DS_Store  
   creating: __MACOSX/
   creating: __MACOSX/cornell movie-dialogs corpus/
  inflating: __MACOSX/cornell movie-dialogs corpus/._.DS_Store  
  inflating: cornell movie-dialogs corpus/chameleons.pdf  
  inflating: __MACOSX/cornell movie-dialogs corpus/._chameleons.pdf  
  inflating: cornell movie-dialogs corpus/movie_characters_metadata.txt  
  inflating: cornell movie-dialog

In [4]:
corpus_movie_conv = 'cornell movie-dialogs corpus/movie_conversations.txt'
corpus_movie_lines = 'cornell movie-dialogs corpus/movie_lines.txt'
max_len = 25

In [5]:
with open(corpus_movie_conv, 'r',
           encoding='utf-8',
          errors='ignore') as c:
    conv = c.readlines()

In [6]:
with open(corpus_movie_lines, 'r',
          encoding='utf-8',
          errors='ignore') as l:
    lines = l.readlines()

In [7]:
lines_dic = {}
for line in lines:
    objects = line.split(" +++$+++ ")
    lines_dic[objects[0]] = objects[-1]

In [8]:
def remove_punc(string):
    punctuations = '''!()-[]{};:'"\,<>./?@#$%^&*_~'''
    no_punct = ""
    for char in string:
        if char not in punctuations:
            no_punct = no_punct + char  # space is also a character
    return no_punct.lower()

In [9]:
pairs = []
for con in conv:
    ids = eval(con.split(" +++$+++ ")[-1])
    for i in range(len(ids)):
        qa_pairs = []
        
        if i==len(ids)-1:
            break
        
        first = remove_punc(lines_dic[ids[i]].strip())      
        second = remove_punc(lines_dic[ids[i+1]].strip())
        qa_pairs.append(first.split()[:max_len])
        qa_pairs.append(second.split()[:max_len])
        pairs.append(qa_pairs)

In [10]:
word_freq = Counter()
for pair in pairs:
    word_freq.update(pair[0])
    word_freq.update(pair[1])

In [11]:
min_word_freq = 5
words = [w for w in word_freq.keys() if word_freq[w] > min_word_freq]
word_map = {k: v + 1 for v, k in enumerate(words)}
word_map['<unk>'] = len(word_map) + 1
word_map['<start>'] = len(word_map) + 1
word_map['<end>'] = len(word_map) + 1
word_map['<pad>'] = 0

In [12]:
print("Total words are: {}".format(len(word_map)))

Total words are: 18190


In [13]:
with open('WORDMAP_corpus.json', 'w') as j:
    json.dump(word_map, j)

In [14]:
def encode_question(words, word_map):
    enc_c = [word_map.get(word, word_map['<unk>']) for word in words] + [word_map['<pad>']] * (max_len - len(words))
    return enc_c

In [15]:
def encode_reply(words, word_map):
    enc_c = [word_map['<start>']] + [word_map.get(word, word_map['<unk>']) for word in words] + \
    [word_map['<end>']] + [word_map['<pad>']] * (max_len - len(words))
    return enc_c

In [16]:
pairs_encoded = []
for pair in pairs:
    qus = encode_question(pair[0], word_map)
    ans = encode_reply(pair[1], word_map)
    pairs_encoded.append([qus, ans])

In [17]:
with open('pairs_encoded.json', 'w') as p:
    json.dump(pairs_encoded, p)

In [18]:
# rev_word_map = {v: k for k, v in word_map.items()}
# ' '.join([rev_word_map[v] for v in pairs_encoded[1][0]])

In [19]:
class Dataset(Dataset):

    def __init__(self):

        self.pairs = json.load(open('pairs_encoded.json'))
        self.dataset_size = len(self.pairs)

    def __getitem__(self, i):
        
        question = torch.LongTensor(self.pairs[i][0])
        reply = torch.LongTensor(self.pairs[i][1])
            
        return question, reply

    def __len__(self):
        return self.dataset_size

In [20]:
train_loader = torch.utils.data.DataLoader(Dataset(),
                                           batch_size = 100, 
                                           shuffle=True, 
                                           pin_memory=True)

In [21]:
# question, reply = next(iter(train_loader))

In [22]:
def create_masks(question, reply_input, reply_target):
    
    def subsequent_mask(size):
        mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
        return mask.unsqueeze(0)
    
    question_mask = question!=0
    question_mask = question_mask.to(device)
    question_mask = question_mask.unsqueeze(1).unsqueeze(1)         # (batch_size, 1, 1, max_words)
     
    reply_input_mask = reply_input!=0
    reply_input_mask = reply_input_mask.unsqueeze(1)  # (batch_size, 1, max_words)
    reply_input_mask = reply_input_mask & subsequent_mask(reply_input.size(-1)).type_as(reply_input_mask.data) 
    reply_input_mask = reply_input_mask.unsqueeze(1) # (batch_size, 1, max_words, max_words)
    reply_target_mask = reply_target!=0              # (batch_size, max_words)
    
    return question_mask, reply_input_mask, reply_target_mask

In [23]:
class Embeddings(nn.Module):
    """
    Implements embeddings of the words and adds their positional encodings. 
    """
    def __init__(self, vocab_size, d_model, max_len = 50, num_layers = 6):
        super(Embeddings, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(0.1)
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pe = self.create_positinal_encoding(max_len, self.d_model)     # (1, max_len, d_model)
        self.te = self.create_positinal_encoding(num_layers, self.d_model)  # (1, num_layers, d_model)
        self.dropout = nn.Dropout(0.1)
        
    def create_positinal_encoding(self, max_len, d_model):
        pe = torch.zeros(max_len, d_model).to(device)
        for pos in range(max_len):   # for each position of the word
            for i in range(0, d_model, 2):   # for each dimension of the each position
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i)/d_model)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1))/d_model)))
        pe = pe.unsqueeze(0)   # include the batch size
        return pe
        
    def forward(self, embedding, layer_idx):
        if layer_idx == 0:
            embedding = self.embed(embedding) * math.sqrt(self.d_model)
        embedding += self.pe[:, :embedding.size(1)]   # pe will automatically be expanded with the same batch size as encoded_words
        # embedding: (batch_size, max_len, d_model), te: (batch_size, 1, d_model)
        embedding += self.te[:, layer_idx, :].unsqueeze(1).repeat(1, embedding.size(1), 1)
        embedding = self.dropout(embedding)
        return embedding

In [24]:
class MultiHeadAttention(nn.Module):
    
    def __init__(self, heads, d_model):
        
        super(MultiHeadAttention, self).__init__()
        assert d_model % heads == 0
        self.d_k = d_model // heads
        self.heads = heads
        self.dropout = nn.Dropout(0.1)
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.concat = nn.Linear(d_model, d_model)
        
    def forward(self, query, key, value, mask):
        """
        query, key, value of shape: (batch_size, max_len, 512)
        mask of shape: (batch_size, 1, 1, max_words)
        """
        # (batch_size, max_len, 512)
        query = self.query(query)
        key = self.key(key)        
        value = self.value(value)   
        
        # (batch_size, max_len, 512) --> (batch_size, max_len, h, d_k) --> (batch_size, h, max_len, d_k)
        query = query.view(query.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)   
        key = key.view(key.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)  
        value = value.view(value.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)  
        
        # (batch_size, h, max_len, d_k) matmul (batch_size, h, d_k, max_len) --> (batch_size, h, max_len, max_len)
        scores = torch.matmul(query, key.permute(0,1,3,2)) / math.sqrt(query.size(-1))
        scores = scores.masked_fill(mask == 0, -1e9)    # (batch_size, h, max_len, max_len)
        weights = F.softmax(scores, dim = -1)           # (batch_size, h, max_len, max_len)
        weights = self.dropout(weights)
        # (batch_size, h, max_len, max_len) matmul (batch_size, h, max_len, d_k) --> (batch_size, h, max_len, d_k)
        context = torch.matmul(weights, value)
        # (batch_size, h, max_len, d_k) --> (batch_size, max_len, h, d_k) --> (batch_size, max_len, h * d_k)
        context = context.permute(0,2,1,3).contiguous().view(context.shape[0], -1, self.heads * self.d_k)
        # (batch_size, max_len, h * d_k)
        interacted = self.concat(context)
        return interacted 

In [25]:
class FeedForward(nn.Module):

    def __init__(self, d_model, middle_dim = 2048):
        super(FeedForward, self).__init__()
        
        self.fc1 = nn.Linear(d_model, middle_dim)
        self.fc2 = nn.Linear(middle_dim, d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        out = F.relu(self.fc1(x))
        out = self.fc2(self.dropout(out))
        return out

In [26]:
class EncoderLayer(nn.Module):

    def __init__(self, d_model, heads):
        super(EncoderLayer, self).__init__()
        self.layernorm = nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, embeddings, mask):
        interacted = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, mask))
        interacted = self.layernorm(interacted + embeddings)
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        encoded = self.layernorm(feed_forward_out + interacted)
        return encoded

In [27]:
class DecoderLayer(nn.Module):
    
    def __init__(self, d_model, heads):
        super(DecoderLayer, self).__init__()
        self.layernorm = nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadAttention(heads, d_model)
        self.src_multihead = MultiHeadAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model)
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, embeddings, encoded, src_mask, target_mask):
        query = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, target_mask))
        query = self.layernorm(query + embeddings)
        interacted = self.dropout(self.src_multihead(query, encoded, encoded, src_mask))
        interacted = self.layernorm(interacted + query)
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        decoded = self.layernorm(feed_forward_out + interacted)
        return decoded

In [28]:
class Transformer(nn.Module):
    
    def __init__(self, d_model, heads, num_layers, word_map):
        super(Transformer, self).__init__()
        
        self.d_model = d_model
        self.num_layers = num_layers
        self.vocab_size = len(word_map)
        self.embed = Embeddings(self.vocab_size, d_model, num_layers = num_layers)
        self.encoder = EncoderLayer(d_model, heads) 
        self.decoder = DecoderLayer(d_model, heads)
        self.logit = nn.Linear(d_model, self.vocab_size)
        
    def encode(self, src_embeddings, src_mask):
        for i in range(self.num_layers):
            src_embeddings = self.embed(src_embeddings, i)
            src_embeddings = self.encoder(src_embeddings, src_mask)
        return src_embeddings
    
    def decode(self, tgt_embeddings, target_mask, src_embeddings, src_mask):
        for i in range(self.num_layers):
            tgt_embeddings = self.embed(tgt_embeddings, i)
            tgt_embeddings = self.decoder(tgt_embeddings, src_embeddings, src_mask, target_mask)
        return tgt_embeddings
        
    def forward(self, src_words, src_mask, target_words, target_mask):
        encoded = self.encode(src_words, src_mask)
        decoded = self.decode(target_words, target_mask, encoded, src_mask)
        out = F.log_softmax(self.logit(decoded), dim = 2)
        return out

In [29]:
class AdamWarmup:
    
    def __init__(self, model_size, warmup_steps, optimizer):
        
        self.model_size = model_size
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self.current_step = 0
        self.lr = 0
        
    def get_lr(self):
        return self.model_size ** (-0.5) * min(self.current_step ** (-0.5), self.current_step * self.warmup_steps ** (-1.5))
        
    def step(self):
        # Increment the number of steps each time we call the step function
        self.current_step += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        # update the learning rate
        self.lr = lr
        self.optimizer.step()       

In [30]:
# def maskNLLLoss(inp, target, mask):
#     nTotal = mask.sum()
#     crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
#     loss = crossEntropy.masked_select(mask).mean()
#     loss = loss.to(device)
#     return loss, nTotal.item()

The torch.gather function (or torch.Tensor.gather) is a multi-index selection method. Look at the following example from the official docs:

Let's start with going through the semantics of the different arguments: The first argument, input, is the source tensor that we want to select elements from. The second, dim, is the dimension (or axis in tensorflow/numpy) that we want to collect along. And finally, index are the indices to index input. As for the semantics of the operation, this is how the official docs explain it:

So let's go through the example.

the input tensor is [[1, 2], [3, 4]], and the dim argument is 1, i.e. we want to collect from the second dimension. The indices for the second dimension are given as [0, 0] and [1, 0].

As we "skip" the first dimension (the dimension we want to collect along is 1), the first dimension of the result is implicitly given as the first dimension of the index. That means that the indices hold the second dimension, or the column indices, but not the row indices. Those are given by the indices of the index tensor itself. For the example, this means that the output will have in its first row a selection of the elements of the input tensor's first row as well, as given by the first row of the index tensor's first row. As the column-indices are given by [0, 0], we therefore select the first element of the first row of the input twice, resulting in [1, 1]. Similarly, the elements of the second row of the result are a result of indexing the second row of the input tensor by the elements of the second row of the index tensor, resulting in [4, 3].
https://stackoverflow.com/questions/50999977/what-does-the-gather-function-do-in-pytorch-in-layman-terms

In [31]:
t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  1],
#        [ 4,  3]])

In [32]:
## ORGINAL CODE

## ORGINAL
# target shape: torch.Size([64])
# prediction shape: torch.Size([64, 7826])
# mask shape: torch.Size([64])

## THIS CODE
# target shape: torch.Size([2600])
# prediction shape: torch.Size([2600, 18190])
# mask shape: torch.Size([100, 26])

In [33]:
class maskNLLLoss(nn.Module):

    def __init__(self, size, smooth):
        super(maskNLLLoss, self).__init__()
        # self.criterion = nn.KLDivLoss(size_average=False, reduce=False)
        # self.confidence = 1.0 - smooth
        # self.smooth = smooth
        # self.size = size
        self.ls = nn.LogSoftmax(dim=1)
        self.nllloss = nn.NLLLoss()
        
    def forward(self, prediction, target, mask):
        """
        prediction of shape: (batch_size, max_words, vocab_size)
        target and mask of shape: (batch_size, max_words)
        """
        mask = mask.to(device)
        nTotal = mask.sum()
        
        target = target.contiguous().view(-1)   # (batch_size * max_words)
        # labels = prediction.data.clone()
        prediction = prediction.view(-1, prediction.size(-1))   # (batch_size * max_words, vocab_size)
        mask = mask.contiguous().view(-1)   # (batch_size * max_words)
        output = self.ls(prediction)
        loss = self.nllloss(output, target)
        # print(f"target shape: {target.shape}") ## target shape: torch.Size([2600])
        # print(f"prediction shape: {prediction.shape}") ## prediction shape: torch.Size([2600, 18190])
        # print(f"mask shape: {mask.shape}") ## mask shape: torch.Size([100, 26]). here 100 is batch
        # crossEntropy = -torch.log(torch.gather(prediction.data, 1, target.data.view(-1, 1)).squeeze(1))
        # crossEntropy.requres_grad = True
        # loss = crossEntropy.masked_select(mask).mean()
        # loss.requres_grad = True
        # print(f"Loss:{loss}")
        # loss = loss.to(device)
        
        # loss = Variable(loss, requires_grad = True)
        return loss

In [34]:
class LossWithLS(nn.Module):

    def __init__(self, size, smooth):
        super(LossWithLS, self).__init__()
        self.criterion = nn.KLDivLoss(size_average=False, reduce=False)
        self.confidence = 1.0 - smooth
        self.smooth = smooth
        self.size = size
        
    def forward(self, prediction, target, mask):
        """
        prediction of shape: (batch_size, max_words, vocab_size)
        target and mask of shape: (batch_size, max_words)
        """
        prediction = prediction.view(-1, prediction.size(-1))   # (batch_size * max_words, vocab_size)
        target = target.contiguous().view(-1)   # (batch_size * max_words)
        mask = mask.float()
        mask = mask.view(-1)       # (batch_size * max_words)
        labels = prediction.data.clone()
        labels.fill_(self.smooth / (self.size - 1))
        labels.scatter_(1, target.data.unsqueeze(1), self.confidence)
        loss = self.criterion(prediction, labels)    # (batch_size * max_words, vocab_size)
        loss = (loss.sum(1) * mask).sum() / mask.sum()
        return loss

In [35]:
d_model = 512
heads = 8
num_layers = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 10

with open('WORDMAP_corpus.json', 'r') as j:
    word_map = json.load(j)
    
transformer = Transformer(d_model = d_model, heads = heads, num_layers = num_layers, word_map = word_map)
transformer = transformer.to(device)
adam_optimizer = torch.optim.Adam(transformer.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
transformer_optimizer = AdamWarmup(model_size = d_model, warmup_steps = 4000, optimizer = adam_optimizer)
# criterion = LossWithLS(len(word_map), 0.2)  ## changed loss function
criterion = maskNLLLoss(len(word_map), 0.2)


In [36]:
def train(train_loader, transformer, criterion, epoch):
    
    transformer.train()
    sum_loss = 0
    count = 0

    for i, (question, reply) in enumerate(train_loader):
        
        samples = question.shape[0]

        # Move to device
        question = question.to(device)
        reply = reply.to(device)

        # Prepare Target Data
        reply_input = reply[:, :-1]
        reply_target = reply[:, 1:]

        # Create mask and add dimensions
        question_mask, reply_input_mask, reply_target_mask = create_masks(question, reply_input, reply_target)

        # Get the transformer outputs
        out = transformer(question, question_mask, reply_input, reply_input_mask)

        # Compute the loss
        loss = criterion(out, reply_target, reply_target_mask)
        loss.requres_grad = True
        # Backprop
        transformer_optimizer.optimizer.zero_grad()
        loss.backward()
        transformer_optimizer.step()
        
        sum_loss += loss.item() * samples
        count += samples
        
        if i % 100 == 0:
            print("Epoch [{}][{}/{}]\tLoss: {:.3f}".format(epoch, i, len(train_loader), sum_loss/count))

In [37]:
def evaluate(transformer, question, question_mask, max_len, word_map):
    """
    Performs Greedy Decoding with a batch size of 1
    """
    rev_word_map = {v: k for k, v in word_map.items()}
    transformer.eval()
    start_token = word_map['<start>']
    encoded = transformer.encode(question, question_mask)
    words = torch.LongTensor([[start_token]]).to(device)
    
    for step in range(max_len - 1):
        size = words.shape[1]
        target_mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
        target_mask = target_mask.to(device).unsqueeze(0).unsqueeze(0)
        decoded = transformer.decode(words, target_mask, encoded, question_mask)
        predictions = transformer.logit(decoded[:, -1])
        _, next_word = torch.max(predictions, dim = 1)
        next_word = next_word.item()
        if next_word == word_map['<end>']:
            break
        words = torch.cat([words, torch.LongTensor([[next_word]]).to(device)], dim = 1)   # (1,step+2)
        
    # Construct Sentence
    if words.dim() == 2:
        words = words.squeeze(0)
        words = words.tolist()
        
    sen_idx = [w for w in words if w not in {word_map['<start>']}]
    sentence = ' '.join([rev_word_map[sen_idx[k]] for k in range(len(sen_idx))])
    
    return sentence

In [38]:
for epoch in range(epochs):
    
    train(train_loader, transformer, criterion, epoch)
    
    state = {'epoch': epoch, 'transformer': transformer, 'transformer_optimizer': transformer_optimizer}
    torch.save(state, 'checkpoint_' + str(epoch) + '.pth.tar')

Epoch [0][0/2217]	Loss: 10.009
Epoch [0][100/2217]	Loss: 7.575
Epoch [0][200/2217]	Loss: 5.512
Epoch [0][300/2217]	Loss: 4.579
Epoch [0][400/2217]	Loss: 4.046
Epoch [0][500/2217]	Loss: 3.707
Epoch [0][600/2217]	Loss: 3.472
Epoch [0][700/2217]	Loss: 3.302
Epoch [0][800/2217]	Loss: 3.169
Epoch [0][900/2217]	Loss: 3.059
Epoch [0][1000/2217]	Loss: 2.969
Epoch [0][1100/2217]	Loss: 2.893
Epoch [0][1200/2217]	Loss: 2.832
Epoch [0][1300/2217]	Loss: 2.775
Epoch [0][1400/2217]	Loss: 2.729
Epoch [0][1500/2217]	Loss: 2.686
Epoch [0][1600/2217]	Loss: 2.649
Epoch [0][1700/2217]	Loss: 2.616
Epoch [0][1800/2217]	Loss: 2.585
Epoch [0][1900/2217]	Loss: 2.557
Epoch [0][2000/2217]	Loss: 2.533
Epoch [0][2100/2217]	Loss: 2.510
Epoch [0][2200/2217]	Loss: 2.490
Epoch [1][0/2217]	Loss: 2.143
Epoch [1][100/2217]	Loss: 2.040
Epoch [1][200/2217]	Loss: 2.033
Epoch [1][300/2217]	Loss: 2.028
Epoch [1][400/2217]	Loss: 2.030
Epoch [1][500/2217]	Loss: 2.030
Epoch [1][600/2217]	Loss: 2.035
Epoch [1][700/2217]	Loss: 2.03

## ORGINAL
target shape: torch.Size([64])
prediction shape: torch.Size([64, 7826])
mask shape: torch.Size([64])

## THIS CODE


In [39]:
## ORGINAL CODE

## ORGINAL
# target shape: torch.Size([64])
# prediction shape: torch.Size([64, 7826])
# mask shape: torch.Size([64])

## THIS CODE
# target shape: torch.Size([2600])
# prediction shape: torch.Size([2600, 18190])
# mask shape: torch.Size([100, 26])

In [40]:
checkpoint = torch.load('checkpoint_0.pth.tar')
transformer = checkpoint['transformer']

In [41]:
while(1):
    question = input("Question: ") 
    if question == 'quit':
        break
    max_len = input("Maximum Reply Length: ")
    enc_qus = [word_map.get(word, word_map['<unk>']) for word in question.split()]
    question = torch.LongTensor(enc_qus).to(device).unsqueeze(0)
    question_mask = (question!=0).to(device).unsqueeze(1).unsqueeze(1)  
    sentence = evaluate(transformer, question, question_mask, int(max_len), word_map)
    print(sentence)

Question: How are you today?
Maximum Reply Length: 20
i dont know
Question: WHAT ARE YOU DOING?
Maximum Reply Length: 10
i dont know what about it
Question: When is the show?
Maximum Reply Length: 10
i dont know
Question: what is your name?
Maximum Reply Length: 20
i dont know
Question: how is weather today?
Maximum Reply Length: 10
i dont know
Question: how is the movie?
Maximum Reply Length: 10
i dont know


KeyboardInterrupt: ignored