In [2]:
import ast
import json
import math
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

In [4]:
# Location of the Cornell Movie-Dialogs Corpus files, relative to the notebook
corpus_dir = 'cornell movie-dialogs corpus'
corpus_movie_conv = corpus_dir + '/movie_conversations.txt'
corpus_movie_lines = corpus_dir + '/movie_lines.txt'

# Upper bound on the number of words kept per sentence
max_len = 25

In [5]:
# Load every conversation record (one text line per record) into a list.
# NOTE(review): no encoding is given, so the platform default is used -- confirm
# that it matches the corpus files on the machine running this notebook.
with open(corpus_movie_conv, 'r') as conv_file:
    conv = list(conv_file)

In [6]:
# Load every movie-line record (one text line per record) into a list.
# NOTE(review): no encoding is given, so the platform default is used -- confirm
# that it matches the corpus files on the machine running this notebook.
with open(corpus_movie_lines, 'r') as lines_file:
    lines = list(lines_file)

In [7]:
# Build a lookup table: line ID (first field of a record) -> line text (last field)
field_sep = " +++$+++ "
lines_dic = {}

for record in lines:
    # Text before the first separator is the ID; text after the last one is the utterance
    line_id = record.partition(field_sep)[0]
    line_text = record.rpartition(field_sep)[2]
    lines_dic[line_id] = line_text

In [8]:
def remove_punc(string):
    """Return `string` lowercased with punctuation characters removed.

    Every character that appears in the punctuation set below is dropped;
    all other characters (letters, digits, whitespace, and symbols not in the
    set such as '+' or '=') are kept in their original order.

    Args:
        string: The text to clean.

    Returns:
        The cleaned, lowercased text.
    """
    # The backslash is written as '\\' so the literal contains a real backslash
    # followed by a comma without relying on the invalid escape sequence '\,'.
    punctuations = '''!()-[]{};:'"\\,<>./?@#$%^&*_~'''

    # Join once instead of growing a string character by character
    return "".join(char for char in string if char not in punctuations).lower()

In [9]:
# Build [question_words, answer_words] pairs from consecutive lines of each conversation
pairs = []

for con in conv:

    # The last field of a conversation record is a Python list literal of line IDs.
    # ast.literal_eval only parses literals, so text in the corpus file can never
    # be executed as code (unlike eval).
    ids = ast.literal_eval(con.split(" +++$+++ ")[-1].strip())

    # Pair every line with the line that follows it; the final line has no reply,
    # so the range stops one short of the end.
    for i in range(len(ids) - 1):
        # Clean both utterances: strip surrounding whitespace, drop punctuation, lowercase
        first = remove_punc(lines_dic[ids[i]].strip())
        second = remove_punc(lines_dic[ids[i + 1]].strip())

        # Tokenize on whitespace and keep at most max_len words of each side
        pairs.append([first.split()[:max_len], second.split()[:max_len]])

In [10]:
# Count how often each word occurs across all questions and answers.
# Words are visited question-first, pair by pair, so first-seen order is preserved.
word_freq = Counter(
    word
    for question_words, answer_words in pairs
    for sentence in (question_words, answer_words)
    for word in sentence
)

In [11]:
# Frequency threshold: only words seen strictly MORE than this many times are kept
min_word_freq = 5

# Vocabulary words, in the order they were first counted
words = [w for w, count in word_freq.items() if count > min_word_freq]

# Real words get the ids 1..len(words); id 0 is reserved for padding below
word_map = {w: idx for idx, w in enumerate(words, start=1)}

# Special tokens take the next free ids, in this order: unknown word, start, end
for special_token in ('<unk>', '<start>', '<end>'):
    word_map[special_token] = len(word_map) + 1

# Padding token uses id 0
word_map['<pad>'] = 0

In [12]:
# Report the vocabulary size (real words plus the four special tokens)
print("Total words are:", len(word_map))

Total words are: 18243


In [13]:
# Persist the vocabulary so it can be reloaded without rebuilding it from the corpus
with open('WORDMAP_corpus.json', 'w') as word_map_file:
    json.dump(word_map, word_map_file)

In [14]:
def encode_question(words, word_map):
    """Convert a tokenized question into a list of word ids padded to `max_len`.

    Words missing from `word_map` are replaced by the '<unk>' id. The result is
    right-padded with the '<pad>' id by `max_len - len(words)` entries, where
    `max_len` is the module-level constant.
    """
    encoded = []
    for token in words:
        encoded.append(word_map.get(token, word_map['<unk>']))

    # Right-pad up to the global max_len
    encoded.extend([word_map['<pad>']] * (max_len - len(words)))
    return encoded

In [15]:
def encode_reply(words, word_map):
    """Convert a tokenized reply into word ids framed by '<start>' and '<end>'.

    Layout of the result: '<start>' id, one id per word ('<unk>' id for words
    missing from `word_map`), '<end>' id, then `max_len - len(words)` '<pad>'
    ids. For a reply of at most `max_len` words the total length is therefore
    `max_len + 2`, where `max_len` is the module-level constant.
    """
    encoded = [word_map['<start>']]
    for token in words:
        encoded.append(word_map.get(token, word_map['<unk>']))
    encoded.append(word_map['<end>'])

    # Padding count is based on the word count only, not on the start/end tokens
    encoded.extend([word_map['<pad>']] * (max_len - len(words)))
    return encoded

In [16]:
# Encode every (question, answer) pair as [question_ids, reply_ids]
pairs_encoded = [
    [encode_question(question_words, word_map), encode_reply(reply_words, word_map)]
    for question_words, reply_words in pairs
]

In [17]:
# Persist the encoded pairs; the Dataset class below reads them back from this file
with open('pairs_encoded.json', 'w') as pairs_file:
    json.dump(pairs_encoded, pairs_file)

In [18]:
# Reverse the word_map to create a reverse mapping from integers to words
# rev_word_map = {v: k for k, v in word_map.items()}

# Convert the first question (index 0) of the second pair (index 1) in pairs_encoded back to words
# ' '.join([rev_word_map[v] for v in pairs_encoded[1][0]])

In [19]:
class Dataset(torch.utils.data.Dataset):
    """Dataset of encoded (question, reply) pairs read from 'pairs_encoded.json'.

    The base class is referenced by its fully qualified name so that this
    definition does not depend on the imported `Dataset` name it shadows;
    re-running the cell therefore always subclasses the PyTorch base class.
    """

    def __init__(self):
        # Load the encoded question-reply pairs; the context manager closes the file
        with open('pairs_encoded.json') as pairs_file:
            self.pairs = json.load(pairs_file)

        # Number of pairs, cached for __len__
        self.dataset_size = len(self.pairs)

    def __getitem__(self, i):
        """Return the i-th pair as two LongTensors: (question, reply)."""
        question = torch.LongTensor(self.pairs[i][0])
        reply = torch.LongTensor(self.pairs[i][1])
        return question, reply

    def __len__(self):
        """Return the total number of pairs in the dataset."""
        return self.dataset_size

In [20]:
# DataLoader settings for training
loader_options = dict(
    batch_size=100,    # 100 pairs per batch
    shuffle=True,      # reshuffle the pairs every epoch
    pin_memory=True,   # page-locked host memory for faster host-to-GPU copies
)

# Batches of (question, reply) tensors drawn from the encoded pairs
train_loader = DataLoader(Dataset(), **loader_options)

In [22]:
# Fetch the first batch of question-reply pairs from the train_loader
# question, reply = next(iter(train_loader))

In [24]:
def create_masks(question, reply_input, reply_target):
    """Build the attention and loss masks for one training batch.

    Token id 0 is treated as padding. All masks are boolean and are created on
    the same device as the tensor they are derived from, so the function does
    not depend on any module-level `device` variable.

    Args:
        question: (batch_size, question_len) tensor of token ids.
        reply_input: (batch_size, reply_len) tensor of decoder input token ids.
        reply_target: (batch_size, reply_len) tensor of decoder target token ids.

    Returns:
        A tuple (question_mask, reply_input_mask, reply_target_mask):
        - question_mask: (batch_size, 1, 1, question_len), True at non-padding tokens.
        - reply_input_mask: (batch_size, 1, reply_len, reply_len), True where the
          key position is non-padding AND not after the query position.
        - reply_target_mask: (batch_size, reply_len), True at non-padding targets.
    """
    # Padding mask for the question, reshaped to broadcast over heads and query positions
    question_mask = (question != 0).unsqueeze(1).unsqueeze(1)

    # Lower-triangular (causal) mask: position i may attend to positions <= i
    seq_len = reply_input.size(-1)
    causal_mask = torch.ones(seq_len, seq_len, device=reply_input.device).tril().bool()
    causal_mask = causal_mask.unsqueeze(0)  # (1, reply_len, reply_len)

    # Combine the padding mask (batch_size, 1, reply_len) with the causal mask
    reply_input_mask = (reply_input != 0).unsqueeze(1) & causal_mask
    reply_input_mask = reply_input_mask.unsqueeze(1)  # add the head dimension

    # Mask used by the loss to ignore padded target positions
    reply_target_mask = reply_target != 0

    return question_mask, reply_input_mask, reply_target_mask

In [25]:
"Modified to implement Universal Transformer:"

class Embeddings(nn.Module):
    """
    Word embeddings plus two additive sinusoidal encodings: one indexed by the
    word position (`pe`) and one indexed by the layer/recurrence step (`te`).

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Embedding / model dimension.
        max_len: Number of word positions for which `pe` is precomputed; longer
            sequences cannot be encoded.
        num_layers: Number of layer indices for which `te` is precomputed.
    """

    def __init__(self, vocab_size, d_model, max_len=50, num_layers=6):
        super(Embeddings, self).__init__()

        self.d_model = d_model  # Model dimension
        self.dropout = nn.Dropout(0.1)  # Dropout layer for regularization
        self.embed = nn.Embedding(vocab_size, d_model)  # Word embedding layer

        # The encodings are registered as non-persistent buffers: they follow the
        # module through .to(device) (no reliance on a global `device`) while
        # staying out of state_dict(), exactly as when they were plain attributes.
        # Positional encodings for word positions: (1, max_len, d_model)
        self.register_buffer('pe', self.create_positinal_encoding(max_len, self.d_model),
                             persistent=False)

        # Positional encodings for layer indices: (1, num_layers, d_model)
        self.register_buffer('te', self.create_positinal_encoding(num_layers, self.d_model),
                             persistent=False)

    def create_positinal_encoding(self, max_len, d_model):
        """Return a (1, max_len, d_model) float tensor of sinusoidal encodings.

        Column j of row `pos` holds sin(pos / 10000 ** (2 * j / d_model)) when j
        is even and cos(pos / 10000 ** (2 * j / d_model)) when j is odd. This is
        the same formula the per-element loop version of this method used.
        """
        positions = torch.arange(max_len, dtype=torch.float64).unsqueeze(1)  # (max_len, 1)
        columns = torch.arange(d_model, dtype=torch.float64)                 # (d_model,)

        # Computed in float64 and cast once at the end to limit rounding differences
        angles = positions / torch.pow(10000.0, 2 * columns / d_model)

        pe = torch.zeros(max_len, d_model, dtype=torch.float64)
        pe[:, 0::2] = torch.sin(angles[:, 0::2])
        pe[:, 1::2] = torch.cos(angles[:, 1::2])

        return pe.float().unsqueeze(0)  # Add a batch dimension

    def forward(self, embedding, layer_idx):
        """Add position and layer encodings, then apply dropout.

        Args:
            embedding: For `layer_idx == 0`, a (batch_size, seq_len) tensor of
                token ids; otherwise a (batch_size, seq_len, d_model) tensor of
                hidden states from the previous step.
            layer_idx: Index of the current layer/recurrence step.

        Returns:
            A (batch_size, seq_len, d_model) tensor. The input tensor is never
            modified in place.
        """
        if layer_idx == 0:
            # Token ids are looked up and scaled only on the first step
            embedding = self.embed(embedding) * math.sqrt(self.d_model)

        # Word-position encoding; broadcasts over the batch dimension
        embedding = embedding + self.pe[:, :embedding.size(1)]

        # Layer-index encoding: (1, 1, d_model) broadcasts over batch and positions
        embedding = embedding + self.te[:, layer_idx, :].unsqueeze(1)

        # Apply dropout for regularization
        return self.dropout(embedding)


In [None]:
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention computed in parallel over `heads` heads."""

    def __init__(self, heads, d_model):
        super(MultiHeadAttention, self).__init__()

        # Each head works on an equal slice of the model dimension
        assert d_model % heads == 0
        self.d_k = d_model // heads
        self.heads = heads
        self.dropout = nn.Dropout(0.1)

        # Input projections for queries, keys and values
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

        # Output projection applied after the heads are merged back together
        self.concat = nn.Linear(d_model, d_model)

    def _split_heads(self, projected):
        """(batch, seq_len, d_model) -> (batch, heads, seq_len, d_k)."""
        batch_size = projected.shape[0]
        return projected.view(batch_size, -1, self.heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask):
        """
        query, key, value: (batch_size, seq_len, d_model)
        mask: broadcastable to (batch_size, heads, query_len, key_len); positions
              where the mask equals 0 receive (effectively) zero attention weight.
        Returns a tensor of shape (batch_size, query_len, d_model).
        """
        q = self._split_heads(self.query(query))
        k = self._split_heads(self.key(key))
        v = self._split_heads(self.value(value))

        # Similarity of every query with every key, scaled by sqrt(d_k)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Masked positions get a large negative score so softmax sends them to ~0
        scores = scores.masked_fill(mask == 0, -1e9)
        weights = self.dropout(F.softmax(scores, dim=-1))

        # Weighted sum of the values: (batch, heads, query_len, d_k)
        context = weights @ v

        # Merge the heads back: (batch, query_len, heads * d_k)
        batch_size = context.shape[0]
        merged = context.transpose(1, 2).contiguous().view(batch_size, -1, self.heads * self.d_k)

        return self.concat(merged)

In [None]:
class FeedForward(nn.Module):
    """Position-wise network: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, middle_dim = 2048):
        super(FeedForward, self).__init__()

        # Expand from d_model to middle_dim ...
        self.fc1 = nn.Linear(d_model, middle_dim)

        # ... and project back down to d_model
        self.fc2 = nn.Linear(middle_dim, d_model)

        # Dropout applied to the hidden activations
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # ReLU on the first layer's output, dropout on the activations,
        # then the second linear layer
        hidden = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(hidden)

In [None]:
class EncoderLayer(nn.Module):
    """Self-attention followed by a feed-forward network, each wrapped in
    dropout, a residual connection and layer normalization.

    Note that a single LayerNorm module and a single Dropout module are shared
    by both sub-layers.
    """

    def __init__(self, d_model, heads):
        super(EncoderLayer, self).__init__()

        self.layernorm = nn.LayerNorm(d_model)                    # shared normalization
        self.self_multihead = MultiHeadAttention(heads, d_model)  # self-attention
        self.feed_forward = FeedForward(d_model)                  # position-wise network
        self.dropout = nn.Dropout(0.1)                            # shared dropout

    def forward(self, embeddings, mask):
        # Sub-layer 1: self-attention -> dropout -> residual -> layer norm
        attended = self.self_multihead(embeddings, embeddings, embeddings, mask)
        attended = self.layernorm(self.dropout(attended) + embeddings)

        # Sub-layer 2: feed-forward -> dropout -> residual -> layer norm
        transformed = self.dropout(self.feed_forward(attended))
        return self.layernorm(transformed + attended)

In [None]:
class DecoderLayer(nn.Module):
    """Masked self-attention, encoder-decoder attention and a feed-forward
    network, each wrapped in dropout, a residual connection and layer norm.

    Note that a single LayerNorm module and a single Dropout module are shared
    by all three sub-layers.
    """

    def __init__(self, d_model, heads):
        super(DecoderLayer, self).__init__()

        self.layernorm = nn.LayerNorm(d_model)                    # shared normalization
        self.self_multihead = MultiHeadAttention(heads, d_model)  # decoder self-attention
        self.src_multihead = MultiHeadAttention(heads, d_model)   # attention over encoder output
        self.feed_forward = FeedForward(d_model)                  # position-wise network
        self.dropout = nn.Dropout(0.1)                            # shared dropout

    def forward(self, embeddings, encoded, src_mask, target_mask):
        # Sub-layer 1: masked self-attention over the decoder inputs
        self_attended = self.self_multihead(embeddings, embeddings, embeddings, target_mask)
        query = self.layernorm(self.dropout(self_attended) + embeddings)

        # Sub-layer 2: queries come from the decoder, keys/values from the encoder
        cross_attended = self.src_multihead(query, encoded, encoded, src_mask)
        interacted = self.layernorm(self.dropout(cross_attended) + query)

        # Sub-layer 3: position-wise feed-forward network
        transformed = self.dropout(self.feed_forward(interacted))
        return self.layernorm(transformed + interacted)

In [None]:
"Modified to implement Universal Transformer:"

class Transformer(nn.Module):
    """Encoder-decoder model in which ONE encoder layer and ONE decoder layer
    are each applied `num_layers` times, with position and layer-index
    encodings re-added before every application.
    """

    def __init__(self, d_model, heads, num_layers, word_map):
        super(Transformer, self).__init__()
        self.d_model = d_model            # Model dimension
        self.num_layers = num_layers      # Number of times each shared layer is applied
        self.vocab_size = len(word_map)   # Vocabulary size

        # Word embeddings plus position / layer-index encodings
        self.embed = Embeddings(self.vocab_size, d_model, num_layers=num_layers)

        # The shared (weight-tied across steps) encoder and decoder layers
        self.encoder = EncoderLayer(d_model, heads)
        self.decoder = DecoderLayer(d_model, heads)

        # Projection from the decoder output to vocabulary scores
        self.logit = nn.Linear(d_model, self.vocab_size)

    def encode(self, src_embeddings, src_mask):
        """Run the source through the shared encoder layer `num_layers` times.

        `src_embeddings` holds token ids on entry; the embedding module turns
        them into vectors at step 0 and only adds encodings on later steps.
        """
        hidden = src_embeddings
        for step in range(self.num_layers):
            hidden = self.encoder(self.embed(hidden, step), src_mask)
        return hidden

    def decode(self, tgt_embeddings, target_mask, src_embeddings, src_mask):
        """Run the target through the shared decoder layer `num_layers` times,
        attending to the encoder output `src_embeddings` at every step.

        `tgt_embeddings` holds token ids on entry (see `encode`).
        """
        hidden = tgt_embeddings
        for step in range(self.num_layers):
            hidden = self.decoder(self.embed(hidden, step), src_embeddings, src_mask, target_mask)
        return hidden

    def forward(self, src_words, src_mask, target_words, target_mask):
        """Return log-probabilities over the vocabulary for every target position."""
        memory = self.encode(src_words, src_mask)
        decoder_out = self.decode(target_words, target_mask, memory, src_mask)
        return F.log_softmax(self.logit(decoder_out), dim=2)


In [None]:
class AdamWarmup:
    """Wraps an optimizer and sets its learning rate on every step using the
    schedule

        lr = model_size ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

    i.e. the rate grows linearly for `warmup_steps` steps and then decays with
    the inverse square root of the step number.

    Args:
        model_size: Model dimension used to scale the learning rate.
        warmup_steps: Number of steps over which the learning rate increases.
        optimizer: The wrapped optimizer; must expose `param_groups` and `step()`.
    """

    def __init__(self, model_size, warmup_steps, optimizer):
        self.model_size = model_size
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self.current_step = 0  # Number of optimization steps taken so far
        self.lr = 0            # Learning rate applied on the most recent step

    def get_lr(self):
        """Return the learning rate for the current step.

        Before the first call to `step()` the step count is 0; the schedule is
        undefined there (0 ** -0.5), so 0 is returned, matching the initial `lr`.
        """
        if self.current_step == 0:
            return 0
        return self.model_size ** (-0.5) * min(self.current_step ** (-0.5),
                                               self.current_step * self.warmup_steps ** (-1.5))

    def step(self):
        """Advance the schedule by one step, apply the new learning rate to every
        parameter group, and perform one optimizer step."""
        self.current_step += 1
        lr = self.get_lr()

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        # Remember the rate that was just applied
        self.lr = lr

        self.optimizer.step()

In [None]:
class LossWithLS(nn.Module):
    """KL-divergence loss against label-smoothed targets, averaged over the
    positions selected by a mask.

    The true class receives probability `1 - smooth`; the remaining `smooth`
    is spread evenly over the other `size - 1` classes.

    Args:
        size: Number of classes (vocabulary size).
        smooth: Total probability mass moved away from the true class.
    """

    def __init__(self, size, smooth):
        super(LossWithLS, self).__init__()

        # Element-wise KL divergence; reduction='none' is the non-deprecated
        # spelling of the legacy size_average=False, reduce=False arguments.
        self.criterion = nn.KLDivLoss(reduction='none')

        # Probability assigned to the true label
        self.confidence = 1.0 - smooth

        # Smoothing factor
        self.smooth = smooth

        # Vocabulary size
        self.size = size

    def forward(self, prediction, target, mask):
        """
        prediction: (batch_size, max_words, vocab_size) log-probabilities
        target and mask: (batch_size, max_words)

        Returns the sum of the per-position losses at masked-in positions
        divided by the number of masked-in positions (0 if there are none).
        """
        # Flatten batch and position dimensions
        prediction = prediction.reshape(-1, prediction.size(-1))  # (batch_size * max_words, vocab_size)
        target = target.reshape(-1)                               # (batch_size * max_words)
        mask = mask.float().reshape(-1)                           # (batch_size * max_words)

        # Smoothed target distribution; it is a constant, so no gradient is tracked
        with torch.no_grad():
            labels = torch.full_like(prediction, self.smooth / (self.size - 1))
            labels.scatter_(1, target.unsqueeze(1), self.confidence)

        # Per-element loss: (batch_size * max_words, vocab_size)
        loss = self.criterion(prediction, labels)

        # Sum over classes, zero out masked positions, average over the kept ones.
        # The clamp avoids a 0/0 = NaN when every position is masked out.
        loss = (loss.sum(1) * mask).sum() / mask.sum().clamp(min=1)

        return loss

In [None]:
# Hyperparameters and runtime configuration
d_model = 512    # Dimensionality of the embeddings and hidden layers
heads = 8        # Attention heads per multi-head attention module
num_layers = 3   # Number of times the shared encoder/decoder layer is applied
epochs = 10      # Training epochs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # GPU when available

# Reload the vocabulary written earlier in the notebook
with open('WORDMAP_corpus.json', 'r') as word_map_file:
    word_map = json.load(word_map_file)

# Build the model and move it to the selected device
transformer = Transformer(d_model=d_model, heads=heads, num_layers=num_layers, word_map=word_map).to(device)

# Adam starts at lr=0; the warmup wrapper overwrites the rate on every step
adam_optimizer = torch.optim.Adam(transformer.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
transformer_optimizer = AdamWarmup(model_size=d_model, warmup_steps=4000, optimizer=adam_optimizer)

# Label-smoothed loss with a smoothing factor of 0.1
criterion = LossWithLS(len(word_map), 0.1)

In [None]:
def train(train_loader, transformer, criterion, epoch):
    """Train `transformer` for one pass over `train_loader`.

    Uses the module-level `device` and `transformer_optimizer`. Every 100
    batches the running sample-weighted average loss is printed.
    """
    transformer.train()

    running_loss = 0   # sum of (batch loss * batch size)
    seen = 0           # number of samples processed so far

    for batch_idx, (question, reply) in enumerate(train_loader):
        batch_size = question.shape[0]

        question = question.to(device)
        reply = reply.to(device)

        # Teacher forcing: the decoder sees tokens [0, n-1) and predicts tokens [1, n)
        reply_input, reply_target = reply[:, :-1], reply[:, 1:]

        question_mask, reply_input_mask, reply_target_mask = create_masks(
            question, reply_input, reply_target)

        # Forward pass and loss over the non-padding target positions
        out = transformer(question, question_mask, reply_input, reply_input_mask)
        loss = criterion(out, reply_target, reply_target_mask)

        # Backward pass and parameter update (the wrapper also updates the learning rate)
        transformer_optimizer.optimizer.zero_grad()
        loss.backward()
        transformer_optimizer.step()

        running_loss += loss.item() * batch_size
        seen += batch_size

        if batch_idx % 100 == 0:
            print("Epoch [{}][{}/{}]\tLoss: {:.3f}".format(
                epoch, batch_idx, len(train_loader), running_loss / seen))

In [None]:
def evaluate(transformer, question, question_mask, max_len, word_map):
    """
    Performs Greedy Decoding with a batch size of 1.

    Args:
        transformer: Model exposing `eval()`, `encode(question, mask)`,
            `decode(words, target_mask, encoded, question_mask)` and `logit(x)`.
        question: (1, question_len) LongTensor of token ids.
        question_mask: Mask for `question`, passed through to the model.
        max_len: At most `max_len - 1` tokens are generated.
        word_map: Dict mapping words to ids; must contain '<start>' and '<end>'.

    Returns:
        The generated words joined by single spaces, without the '<start>'
        token and without the '<end>' token.
    """
    # Reverse the word map to convert indices back to words
    rev_word_map = {v: k for k, v in word_map.items()}
    start_token = word_map['<start>']
    end_token = word_map['<end>']

    # Work on whatever device the question lives on instead of a global `device`
    run_device = question.device

    transformer.eval()

    # No gradients are needed for inference; this avoids building an autograd
    # graph that grows with every decoding step.
    with torch.no_grad():
        encoded = transformer.encode(question, question_mask)

        # The sequence generated so far, starting with just the start token: (1, 1)
        words = torch.LongTensor([[start_token]]).to(run_device)

        for _ in range(max_len - 1):
            size = words.shape[1]

            # Causal mask so each position only attends to itself and earlier ones
            target_mask = torch.ones(size, size, device=run_device).tril().bool()
            target_mask = target_mask.unsqueeze(0).unsqueeze(0)

            # Re-decode the whole prefix and score the token after its last position
            decoded = transformer.decode(words, target_mask, encoded, question_mask)
            predictions = transformer.logit(decoded[:, -1])

            # Greedy choice: highest-scoring token
            next_word = predictions.argmax(dim=1).item()

            if next_word == end_token:
                break

            next_token = torch.LongTensor([[next_word]]).to(run_device)
            words = torch.cat([words, next_token], dim=1)   # (1, step + 2)

    # Drop the start token and map the remaining ids back to words
    sen_idx = [w for w in words.squeeze(0).tolist() if w != start_token]
    sentence = ' '.join(rev_word_map[idx] for idx in sen_idx)

    return sentence

In [None]:
# Train for `epochs` epochs, writing one checkpoint file per epoch
for epoch in range(epochs):

    train(train_loader, transformer, criterion, epoch)

    # The checkpoint stores the model and optimizer wrapper objects themselves
    checkpoint_path = 'checkpoint_{}.pth.tar'.format(epoch)
    torch.save(
        {'epoch': epoch, 'transformer': transformer, 'transformer_optimizer': transformer_optimizer},
        checkpoint_path,
    )

In [None]:
# Load the checkpoint written after the final training epoch. The training loop
# saves files named 'checkpoint_<epoch>.pth.tar', so the last one is epochs - 1.
checkpoint_path = 'checkpoint_' + str(epochs - 1) + '.pth.tar'

# map_location lets a checkpoint saved on a GPU be opened on a CPU-only machine.
# NOTE(review): the checkpoint holds a pickled model object, not just weights;
# recent PyTorch releases may require weights_only=False here -- confirm against
# the installed version, and only load checkpoint files you trust.
checkpoint = torch.load(checkpoint_path, map_location=device)

# Retrieve the Transformer model from the checkpoint
transformer = checkpoint['transformer']

In [None]:
# Interactive session: read a question, generate a reply, repeat until 'quit'
while True:
    # Take a question as input from the user
    raw_question = input("Question: ")

    # If the user types 'quit', exit the loop
    if raw_question == 'quit':
        break

    # Maximum reply length. It is kept in its own variable so the module-level
    # `max_len` used by encode_question / encode_reply is not overwritten.
    try:
        reply_len = int(input("Maximum Reply Length: "))
    except ValueError:
        print("Please enter a whole number for the maximum reply length.")
        continue

    # The decoder can only encode as many positions as were precomputed in the
    # embedding module, so longer requests are capped instead of crashing.
    reply_len = min(reply_len, transformer.embed.pe.size(1) + 1)

    # Normalize the question exactly like the training data: lowercase, no
    # punctuation, whitespace-tokenized, truncated to max_len words.
    question_words = remove_punc(raw_question).split()[:max_len]
    if not question_words:
        print("Please enter a question containing at least one word.")
        continue

    # Map words to ids; words missing from the vocabulary become '<unk>'
    enc_qus = [word_map.get(word, word_map['<unk>']) for word in question_words]

    # Shape (1, question_len): unsqueeze(0) adds the batch dimension
    question = torch.LongTensor(enc_qus).to(device).unsqueeze(0)

    # Padding mask reshaped to (1, 1, 1, question_len) for the attention layers
    question_mask = (question != 0).to(device).unsqueeze(1).unsqueeze(1)

    # Generate and print the reply
    sentence = evaluate(transformer, question, question_mask, reply_len, word_map)
    print(sentence)