
# Chatbot Using TorchScript, Attention & Profiling

This notebook demonstrates how to build a sequence-to-sequence chatbot using PyTorch. The chatbot uses an encoder–decoder architecture with Luong attention and includes data preprocessing, vocabulary construction, model definition, training routines (with teacher forcing), evaluation functions, and checkpoint saving/loading.

The notebook is highly commented so that each part of the code is explained in detail.

Key technologies and concepts include:
- **PyTorch** for deep learning model construction and GPU/CPU computation.
- **GRU-based Encoder/Decoder** for sequence modeling.
- **Luong Attention** for weighted context during decoding.
- **Data preprocessing**: normalization, tokenization, padding, and masking.
- **Checkpointing** to save and resume model training.
- **Torch profiler** for performance insights (using `torch.profiler`).


## Section 1: Imports, Device Configuration, and Corpus Setup

> **Text Box:**
> This section loads all required libraries, sets up the computing device (GPU if available, otherwise CPU), and downloads the corpus using Convokit.
> Additional inline comments explain each import and the purpose of key lines.

In [None]:
# Import necessary libraries for deep learning, data processing, and file handling
import torch                 # Main PyTorch library for tensor operations and GPU acceleration
import torch.nn as nn        # Module for building neural networks
import torch.optim as optim  # Contains optimization algorithms like SGD and Adam
import torch.nn.functional as F  # Provides functions like softmax and loss functions
import csv                   # For CSV file read/write operations
import random                # Standard library for random number generation
import re                    # Regular expressions for string matching and manipulation
import os                    # Operating system interfaces for file system operations
import unicodedata           # For Unicode text normalization
import codecs                # For encoding and decoding operations
import json                  # To work with JSON files
import itertools             # Used for efficient looping operations, especially for batching
import math                  # Provides access to mathematical functions
from io import open          # Import open from io for file operations with encoding support
from torch.profiler import profile, record_function, ProfilerActivity  # Profiling tools to analyze performance

# Enables inline plotting for visualizations within the notebook
%matplotlib inline

# Select the computing device: if an accelerator (like a GPU) is available, use it; else, fall back to CPU
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

# Print a message signaling that corpus setup is beginning
print("Setting up corpus...")

# Install the Convokit library which is used to download and manage conversation corpora
!pip install convokit

# Import necessary components from the convokit library
from convokit import Corpus, download

# Define a custom directory where the corpus data will be stored
custom_path = "/content/drive/MyDrive/Data_cornell"

# Download the 'movie-corpus' to the specified custom path
corpus = Corpus(filename=download("movie-corpus", data_dir=custom_path))

# Confirm the corpus has been downloaded by printing the path
print("Corpus downloaded to:", custom_path)
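
The `torch.accelerator` namespace used above is only present in fairly recent PyTorch releases. For older installs, a fallback along the following lines may help. This is a minimal sketch: `pick_device` is a hypothetical helper that is not part of the notebook, and the `ImportError` guard is there only so the snippet runs standalone.

```python
def pick_device():
    """Return a device string, preferring an accelerator when one is usable."""
    try:
        import torch
    except ImportError:
        # Only so this sketch runs without PyTorch installed
        return "cpu"
    accel = getattr(torch, "accelerator", None)
    if accel is not None and accel.is_available():
        return accel.current_accelerator().type
    # Older PyTorch: fall back to the long-standing CUDA check
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

device = pick_device()
print(f"Using {device} device")
```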

## Section 2: Preprocessing the Corpus Data

> **Text Box:**
> This section preprocesses the raw corpus data. It includes functions to:
> 1. Print sample lines of the raw data file (to quickly verify its contents).
> 2. Load and parse conversation lines from a JSONL file. Each line represents an utterance in a conversation.
> 3. Extract sentence pairs (an input and its corresponding response) and write them to a formatted text file for subsequent training.
>
> Every function is extensively commented to clarify its purpose.

In [None]:
# Function to print the first 'n' lines from a file
def printLines(file, n=10):
    # Open the file in binary mode
    with open(file, 'rb') as datafile:
        # Read all lines from the file
        lines = datafile.readlines()
    # Print only the first 'n' lines
    for line in lines[:n]:
        print(line)

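# Define the corpus name and construct the path to the corpus directory
# (note: this rebinds `corpus`, which previously held the Convokit Corpus object, to a path string)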
corpus_name = "movie-corpus"
corpus = os.path.join("/content/drive/MyDrive/Data_cornell", corpus_name)

# Display sample lines from the raw corpus file to verify its content
printLines(os.path.join(corpus, "utterances.jsonl"))

# Function to load lines and conversations from a JSONL file
def loadLinesAndConversations(fileName):
    # Create two empty dictionaries: one for individual lines and one for grouped conversations
    lines = {}
    conversations = {}

    # Open the JSONL file using a specified encoding
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            # Each line in the file is parsed as a JSON object
            lineJson = json.loads(line)

            # Create a simplified line object with essential information
            lineObj = {
                "lineID": lineJson["id"],
                "characterID": lineJson["speaker"],
                "text": lineJson["text"]
            }

            # Save the line object in the lines dictionary
            lines[lineObj['lineID']] = lineObj

            # For grouping: if the conversation_id is new, create a new conversation object
            if lineJson["conversation_id"] not in conversations:
                convObj = {
                    "conversationID": lineJson["conversation_id"],
                    "movieID": lineJson["meta"]["movie_id"],
                    "lines": [lineObj]  # Initialize with the current line
                }
            else:
                # Otherwise, insert the line at the front of the existing conversation: the file lists
                # utterances newest-first, so front insertion leaves each conversation in chronological order
                convObj = conversations[lineJson["conversation_id"]]
                convObj["lines"].insert(0, lineObj)
            # Save/update the conversation object in the conversations dictionary
            conversations[convObj["conversationID"]] = convObj

    # Return both dictionaries
    return lines, conversations

# Function to extract sentence pairs (input and target responses) from conversations
def extractSentencePairs(conversations):
    qa_pairs = []
    # Loop over each conversation in the dictionary
    for conversation in conversations.values():
        # Iterate over each pair of consecutive lines (ignores the last line as it has no reply)
        for i in range(len(conversation["lines"]) - 1):
            inputLine = conversation["lines"][i]["text"].strip()
            targetLine = conversation["lines"][i+1]["text"].strip()

            # Only add the pair if both input and target lines are non-empty
            if inputLine and targetLine:
                qa_pairs.append([inputLine, targetLine])
    return qa_pairs

# Process the corpus by parsing the JSONL file
print("\nProcessing corpus into lines and conversations...")
lines, conversations = loadLinesAndConversations(os.path.join(corpus, "utterances.jsonl"))

# Define the file path to save the newly formatted data
datafile = os.path.join(corpus, "formatted_movie_lines.txt")

# Define the delimiter to use when writing the new file (a tab character)
delimiter = '\t'

# Write the extracted sentence pairs into a new CSV (tab delimited) file
print("\nWriting newly formatted file...")
with open(datafile, 'w', encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
    for pair in extractSentencePairs(conversations):
        writer.writerow(pair)  # Each row represents an input-response pair

# Print sample lines from the newly formatted file to verify correctness
print("\nSample lines from file:")
printLines(datafile)
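
The grouping and pairing logic above can be checked on a few hand-written records. The sketch below is a simplified stand-in for `loadLinesAndConversations` and `extractSentencePairs`: the three utterances are made up, and the newest-first ordering is an assumption about how the file is laid out.

```python
import json

# Toy utterances shaped like utterances.jsonl rows, newest line first (assumed ordering)
raw = [
    '{"id": "L3", "conversation_id": "L1", "speaker": "u0", "text": "Fine, thanks."}',
    '{"id": "L2", "conversation_id": "L1", "speaker": "u1", "text": "How are you?"}',
    '{"id": "L1", "conversation_id": "L1", "speaker": "u0", "text": "Hi."}',
]

conversations = {}
for line in raw:
    utt = json.loads(line)
    # Insert at the front so each conversation ends up oldest-first
    conversations.setdefault(utt["conversation_id"], []).insert(0, utt["text"].strip())

qa_pairs = []
for texts in conversations.values():
    # Pair each line with the line that follows it; the last line has no reply
    for prompt, reply in zip(texts, texts[1:]):
        if prompt and reply:
            qa_pairs.append([prompt, reply])

print(qa_pairs)
```

Each utterance except the first and last appears twice: once as a reply and once as a prompt.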

## Section 3: Vocabulary and Tokenization Setup

> **Text Box:**
> Here we define the vocabulary management and text normalization routines.
>
> **Special Tokens:**
> - PAD: Padding token to make sequences the same length
> - SOS: Start-of-Sentence token
> - EOS: End-of-Sentence token
>
> **Voc Class:** Maintains mappings between words and unique indices, counts words, and can trim rarely-used words.
>
> **Normalization Functions:** Help process the raw text into a consistent format (ASCII, lowercase, trimmed, and cleaned of non-letter characters).

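The effect of trimming can be previewed without the full class. The following is a minimal sketch using `collections.Counter` in place of `Voc`; the sentences and the `min_count` value are illustrative only.

```python
from collections import Counter

PAD_token, SOS_token, EOS_token = 0, 1, 2

sentences = ["hello there .", "hello again .", "goodbye ."]
counts = Counter(word for s in sentences for word in s.split())

min_count = 2
# Keep words seen at least min_count times, in first-seen order
keep_words = [w for w, c in counts.items() if c >= min_count]

# Indices 0..2 are reserved for the special tokens, so real words start at 3
word2index = {w: i + 3 for i, w in enumerate(keep_words)}
index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
index2word.update({i: w for w, i in word2index.items()})

print(word2index)
```

Words dropped by trimming no longer have an index, so any sentence that still contains them would raise a `KeyError` when converted to indices.
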
In [None]:
# Define special token constants
PAD_token = 0  # Padding token
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

# Class to build and maintain the vocabulary
class Voc:
    def __init__(self, name):
        # Name or tag for the vocabulary instance
        self.name = name
        self.trimmed = False  # Flag to indicate if trimming has been done
        self.word2index = {}  # Dictionary mapping words to indices
        self.word2count = {}  # Dictionary counting how often each word appears
        # Reverse mapping: indices to words, pre-populated with special tokens
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count of words; starts at 3 because of the three special tokens

    def addSentence(self, sentence):
        # Split the sentence into words and add each word to the vocabulary
        for word in sentence.split():
            self.addWord(word)

    def addWord(self, word):
        # If the word is new to the vocabulary, add it
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            # Otherwise, just update the word frequency
            self.word2count[word] += 1

    def trim(self, min_count):
        # Ensure that trimming happens only once
        if self.trimmed:
            return
        self.trimmed = True

        # Identify words that appear at least 'min_count' times
        keep_words = [k for k, v in self.word2count.items() if v >= min_count]
        print('keep_words {} / {} = {:.4f}'.format(len(keep_words), len(self.word2index), len(keep_words)/len(self.word2index)))

        # Reinitialize the dictionaries to only include the frequently occurring words
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3

        # Add back each word that met the minimum count threshold
        for word in keep_words:
            self.addWord(word)

MAX_LENGTH = 10  # Maximum sentence length to consider when filtering sentence pairs

# Function to convert Unicode string to ASCII
def unicodeToAscii(s):
    # Normalize the string into a canonical form and filter out combining characters
    return ''.join(c for c in unicodedata.normalize('NFD', s)
                   if unicodedata.category(c) != 'Mn')

# Function to normalize strings by converting to lowercase, trimming whitespace, and removing non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    # Add a space before punctuation for better tokenization
    s = re.sub(r"([.!?])", r" \1", s)
    # Replace any character that is not a letter or punctuation with a space
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    # Remove extra spaces
    s = re.sub(r"\s+", r" ", s).strip()
    return s

# Function to read the formatted data file and return the vocabulary and sentence pairs
def readVocs(datafile, corpus_name):
    print("Reading lines...")
    # Read the entire file as a string and split it into individual lines
    with open(datafile, encoding='utf-8') as f:
        lines = f.read().strip().split('\n')
    # For each line, split by tab and normalize both parts of the pair
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
    voc = Voc(corpus_name)  # Create a new vocabulary object
    return voc, pairs

# Filter function to keep only pairs where both sentences are shorter than MAX_LENGTH
def filterPair(p):
    return len(p[0].split()) < MAX_LENGTH and len(p[1].split()) < MAX_LENGTH

# Apply the filter to all sentence pairs
def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]
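
To see how normalization and length filtering combine, the steps can be run on a couple of made-up pairs. This is a minimal sketch: `normalize` and `keep` are compact stand-ins for `normalizeString` and `filterPair`, and the sample sentences are hypothetical.

```python
import re
import unicodedata

MAX_LENGTH = 10  # same threshold as the notebook

def normalize(s):
    # Same steps as normalizeString: lowercase, strip accents, isolate . ! ?, drop other symbols
    s = unicodedata.normalize('NFD', s.lower().strip())
    s = ''.join(c for c in s if unicodedata.category(c) != 'Mn')
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", " ", s)
    return re.sub(r"\s+", " ", s).strip()

def keep(pair):
    # Both sides of the pair must have fewer than MAX_LENGTH tokens
    return all(len(s.split()) < MAX_LENGTH for s in pair)

raw_pairs = [
    ("Où est-ce?", "Hello, World!"),
    ("Hi.", "This reply has far too many words in it to pass the length filter."),
]
pairs = [[normalize(a), normalize(b)] for a, b in raw_pairs]
short_pairs = [p for p in pairs if keep(p)]
print(short_pairs)
```

Punctuation counts as a token after normalization, so a nine-word sentence ending in a period already has ten tokens and is filtered out.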


## Section 4: Data Preparation for Model Input

> **Text Box:**
> This section defines helper functions to convert sentences to numerical indices, pad sequences with the PAD token, and create masks so the model ignores padding during training.
> These functions ultimately generate batches of data to be fed into the model during training.
> Each function has extensive comments to clarify its purpose.

In [None]:
# Convert a sentence into a list of indices using the vocabulary mapping, and append the EOS token
def indexesFromSentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split()] + [EOS_token]

# Zero-pad a list of sequences so that they all have the same length
def zeroPadding(l, fillvalue=PAD_token):
    # itertools.zip_longest groups elements from each sequence; missing values are filled with 'fillvalue'
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

# Create a binary matrix which indicates non-PAD tokens (1 for real token, 0 for PAD)
def binaryMatrix(l, value=PAD_token):
    m = []
    for seq in l:
        # For each token in the sequence, check if it differs from the padding value
        m.append([0 if token == value else 1 for token in seq])
    return m

# Prepare input variables (padded sequences and lengths) from a list of sentences
def inputVar(l, voc):
    # Convert each sentence into a sequence of indices
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    # Record the original lengths of each sequence
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    # Pad the sequences with the PAD token so that all sequences have the same length
    padList = zeroPadding(indexes_batch)
    # Convert the padded list into a PyTorch LongTensor
    padVar = torch.LongTensor(padList)
    return padVar, lengths

# Prepare output variables (padded sequences, mask matrix, and max sequence length) for the target sentences
def outputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max(len(indexes) for indexes in indexes_batch)  # Find the length of the longest sequence
    padList = zeroPadding(indexes_batch)
    # Create a mask where PAD tokens are marked as 0
    mask = torch.BoolTensor(binaryMatrix(padList))
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len

# Combine input and target sentences into a single batch, properly padded and sorted by sentence length
def batch2TrainData(voc, pair_batch):
    # Sort sentence pairs in descending order of the length of the input sentence
    pair_batch.sort(key=lambda x: len(x[0].split()), reverse=True)
    # Separate input and target sentences into two separate lists
    input_batch = [pair[0] for pair in pair_batch]
    output_batch = [pair[1] for pair in pair_batch]
    # Get padded tensor and lengths for input sentences
    inp, lengths = inputVar(input_batch, voc)
    # Get padded tensor, mask, and max target sentence length for target sentences
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len

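# Create a small random batch to verify data preparation functionality
# (`voc` and `pairs` must already exist, e.g. built with readVocs and filterPairs, before this cell runs)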
small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches

# Print out the prepared tensors to inspect the shapes and contents
print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_len:", max_target_len)
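
The shapes produced by padding are easiest to see on a tiny batch. The sketch below re-implements `zeroPadding` and `binaryMatrix` in plain Python under the names `zero_padding` and `binary_matrix`; the index values are arbitrary.

```python
import itertools

PAD_token = 0

def zero_padding(seqs, fillvalue=PAD_token):
    # zip_longest transposes the batch: result[t] holds step t of every sequence
    return list(itertools.zip_longest(*seqs, fillvalue=fillvalue))

def binary_matrix(padded, value=PAD_token):
    # 1 marks a real token, 0 marks padding
    return [[0 if tok == value else 1 for tok in step] for step in padded]

batch = [[5, 6, 2], [7, 2]]  # two index sequences, each ending in EOS (2)
padded = zero_padding(batch)
mask = binary_matrix(padded)
print(padded)  # [(5, 7), (6, 2), (2, 0)]
print(mask)    # [[1, 1], [1, 1], [1, 0]]
```

The padded result is laid out as `(max_length, batch_size)`, which is why the training loop can index `target_variable[t]` to get every sequence's token at step `t`.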

## Section 5: Model Architecture

> **Text Box:**
> This section defines the architecture of the chatbot model:
>
> 1. **EncoderRNN:** A GRU-based encoder that processes the input embeddings and produces a hidden state for each token in the input.
>
> 2. **Attn Module:** Implements Luong-style attention. It supports three methods: 'dot', 'general', and 'concat'. The attention mechanism computes a weighted sum of encoder outputs based on the current decoder hidden state.
>
> 3. **LuongAttnDecoderRNN:** A decoder that uses the attention weights along with its GRU outputs to produce predictions for the next word in the sequence, one word at a time.

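For reference, the three scoring methods correspond to the standard Luong score functions, written here with $h_t$ for the current decoder state and $\bar{h}_s$ for the encoder output at source position $s$. The symbols are notation introduced for this note, and the bias term that `nn.Linear` adds in the code is left out.

```latex
\[
\operatorname{score}(h_t, \bar{h}_s) =
\begin{cases}
h_t^{\top} \bar{h}_s & \text{dot} \\
h_t^{\top} W_a \bar{h}_s & \text{general} \\
v_a^{\top} \tanh\!\left(W_a [h_t ; \bar{h}_s]\right) & \text{concat}
\end{cases}
\]
\[
a_t(s) = \frac{\exp\big(\operatorname{score}(h_t, \bar{h}_s)\big)}
              {\sum_{s'} \exp\big(\operatorname{score}(h_t, \bar{h}_{s'})\big)},
\qquad
c_t = \sum_{s} a_t(s)\, \bar{h}_s
\]
```

The weights $a_t(s)$ are what `Attn.forward` returns, and the context vector $c_t$ is the `bmm` product computed in the decoder.
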
In [None]:
# EncoderRNN class using a bidirectional GRU
class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        # Number of layers for the GRU
        self.n_layers = n_layers
        # Hidden size for the GRU
        self.hidden_size = hidden_size
        # Embedding layer to convert word indices into embeddings
        self.embedding = embedding

        # Define the bidirectional GRU. Note that input_size and hidden_size are both hidden_size
        # because the embedding layer is expected to produce vectors of size hidden_size
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        # Convert input indices to embeddings
        embedded = self.embedding(input_seq)

        # Pack the embedded sequences for efficient processing in the GRU
        packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)

        # Pass the packed sequence through the GRU
        outputs, hidden = self.gru(packed, hidden)

        # Unpack the sequences back to padded format
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)

        # Sum the outputs from both directions (forward and backward)
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        return outputs, hidden

# Attention module implementing Luong attention
class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        # Ensure the provided attention method is one of the allowed choices
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(f"{self.method} is not an appropriate attention method.")
        self.hidden_size = hidden_size

        # For 'general' attention, we use a linear layer to transform the encoder output
        if self.method == 'general':
            self.attn = nn.Linear(hidden_size, hidden_size)
        # For 'concat' attention, we concatenate and then transform; also define a parameter vector v
        elif self.method == 'concat':
            self.attn = nn.Linear(hidden_size * 2, hidden_size)
            # torch.FloatTensor(n) returns uninitialized memory, so draw the initial values explicitly
            self.v = nn.Parameter(torch.rand(hidden_size))

    def dot_score(self, hidden, encoder_output):
        # Element-wise multiplication and sum the result along the feature dimension
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        # Transform the encoder output and then compute the dot product
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        # Concatenate the hidden state and encoder output, then apply a nonlinear transformation
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), dim=2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        # Choose the appropriate scoring function based on the specified method
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)

        # Transpose dimensions so that the batch is the first dimension
        attn_energies = attn_energies.t()

        # Normalize the attention scores to probabilities
        return F.softmax(attn_energies, dim=1).unsqueeze(1)

# Decoder RNN that integrates Luong attention
class LuongAttnDecoderRNN(nn.Module):
    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()
        # Save the model configuration
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        # Define the embedding layer (shared with the encoder)
        self.embedding = embedding
        # Apply dropout to embeddings to prevent overfitting
        self.embedding_dropout = nn.Dropout(dropout)

        # Define the GRU layer
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))

        # Linear layer to combine GRU outputs with attention context
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        # Final output layer mapping to vocabulary size
        self.out = nn.Linear(hidden_size, output_size)

        # Attention mechanism
        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        # Convert current input token to its embedding
        embedded = self.embedding(input_step)
        # Apply dropout to the embedding
        embedded = self.embedding_dropout(embedded)

        # Forward pass through the GRU; note that the decoder is unidirectional
        rnn_output, hidden = self.gru(embedded, last_hidden)

        # Compute attention weights using the current GRU output and the encoder outputs
        attn_weights = self.attn(rnn_output, encoder_outputs)

        # Compute context vector as the weighted sum of encoder outputs
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))

        # Remove the time-step dimension from GRU output and context
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)

        # Concatenate the GRU output and the context vector
        concat_input = torch.cat((rnn_output, context), 1)
        # Pass through a tanh activation after linear combination to produce the final representation
        concat_output = torch.tanh(self.concat(concat_input))

        # Predict the next word using a softmax layer
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        return output, hidden
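
The `dot` scoring path and the context computation can be traced by hand on a two-position source. This is a minimal plain-Python sketch for one batch element, not the batched tensor code above; `softmax` and `dot_attention` are illustrative names.

```python
import math

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def dot_attention(hidden, encoder_outputs):
    # One score per source position: dot product of decoder state and encoder output
    scores = [sum(h * e for h, e in zip(hidden, enc)) for enc in encoder_outputs]
    weights = softmax(scores)
    # Context vector: attention-weighted sum of the encoder outputs
    dim = len(hidden)
    context = [sum(w * enc[d] for w, enc in zip(weights, encoder_outputs))
               for d in range(dim)]
    return weights, context

weights, context = dot_attention([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]])
print(weights, context)
```

The decoder state aligns with the first encoder output, so that position receives the larger weight (about 0.73) and dominates the context vector.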

## Section 6: Training Functions

> **Text Box:**
> This section defines the training loop and associated functions:
>
> - **maskNLLLoss:** Computes the negative log-likelihood loss while ignoring padded elements.
> - **train:** Runs a single training iteration, including forward propagation, loss computation, backpropagation, and optimizer stepping.
> - **trainIters:** Manages multiple iterations of training and handles checkpoint saving and progress printing.
>
> Each function includes detailed comments explaining each step.

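The masking idea behind `maskNLLLoss` can be shown without tensors. The sketch below handles one time step of a two-sequence batch; `mask_nll_loss` is an illustrative name and the probabilities are made up.

```python
import math

def mask_nll_loss(probs, targets, mask):
    """probs: one probability row per sequence; targets: gold indices; mask: True for real tokens."""
    # Negative log-probability of the gold token, skipping padded positions
    losses = [-math.log(row[t]) for row, t, keep in zip(probs, targets, mask) if keep]
    return sum(losses) / len(losses), len(losses)

probs = [[0.5, 0.25, 0.25], [0.1, 0.8, 0.1]]
# The second sequence is already padded at this step, so it contributes nothing
loss, n_total = mask_nll_loss(probs, targets=[0, 2], mask=[True, False])
print(loss, n_total)
```

Returning the token count alongside the mean lets the caller weight each step's loss by the number of real tokens, which is what `train` does with `print_losses`.
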
In [None]:
# Compute the loss while taking into account only non-PAD tokens
def maskNLLLoss(inp, target, mask):
    # Total number of non-PAD tokens in this batch
    nTotal = mask.sum()
    # Calculate the negative log-likelihood loss for the predicted probabilities corresponding to the target tokens
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    # Apply the mask to remove padded elements and compute the mean loss
    loss = crossEntropy.masked_select(mask).mean().to(device)
    return loss, nTotal.item()

# Single training iteration over one batch
def train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, embedding,
          encoder_optimizer, decoder_optimizer, batch_size, clip, max_length=MAX_LENGTH):
    # Zero the gradients for both encoder and decoder optimizers
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Move input, target, and mask tensors to the correct device (GPU or CPU)
    input_variable = input_variable.to(device)
    target_variable = target_variable.to(device)
    mask = mask.to(device)
    # Ensure lengths are on the CPU for packing
    lengths = lengths.to("cpu")

    loss = 0
    print_losses = []
    n_totals = 0

    # Run the encoder forward pass; it outputs features and hidden states
    encoder_outputs, encoder_hidden = encoder(input_variable, lengths)

    # Prepare the first input of the decoder which is the SOS token for each sentence in the batch
    decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]]).to(device)

    # Initialize decoder hidden state with encoder's final hidden state (for the first 'n_layers' layers)
    decoder_hidden = encoder_hidden[:decoder.n_layers]

    # Decide once per batch whether to use teacher forcing
    # (teacher_forcing_ratio is read from the global scope; it is not a parameter of train)
    use_teacher_forcing = random.random() < teacher_forcing_ratio

    # Run decoder forward pass one time-step at a time
    if use_teacher_forcing:
        # For each time step, feed the correct token from the target as the next input
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
            # Next input is the actual target token
            decoder_input = target_variable[t].view(1, -1)
            # Compute the loss for this time step
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal
    else:
        # Without teacher forcing, use the decoder's own predictions as the next input
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
            # Get the most likely word token from the output
            _, topi = decoder_output.topk(1)
            # Feed the predicted tokens to the next time step, reshaped from (batch, 1) to (1, batch)
            decoder_input = topi.transpose(0, 1)
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal

    # Backpropagate the loss
    loss.backward()

    # Clip gradients to prevent exploding gradients
    _ = nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    # Update model parameters
    encoder_optimizer.step()
    decoder_optimizer.step()

    return sum(print_losses) / n_totals

# Function to run multiple training iterations
def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding,
               encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip,
               corpus_name, teacher_forcing_ratio, profile=False, save_model=False):
    # Create training batches for each iteration using randomly selected sentence pairs
    training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)]) for _ in range(n_iteration)]

    print('Initializing ...')
    start_iteration = 1  # Default starting iteration
    print_loss = 0

    # If resuming from a checkpoint, update start_iteration (this part can be extended)
    print("Training...")

    # Iterate over each training batch
    for iteration in range(start_iteration, n_iteration + 1):
        training_batch = training_batches[iteration - 1]
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        # Train over the current batch and get the loss
        loss = train(input_variable, lengths, target_variable, mask, max_target_len, encoder,
                     decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size, clip)
        print_loss += loss

        # Print progress every 'print_every' iterations
        if iteration % print_every == 0:
            print_loss_avg = print_loss / print_every
            print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(
                iteration, iteration / n_iteration * 100, print_loss_avg))
            print_loss = 0

        # Save a checkpoint every 'save_every' iterations
        if iteration % save_every == 0:
            directory = os.path.join(save_dir, model_name, corpus_name, '{}-{}_{}'.format(
                encoder_n_layers, decoder_n_layers, hidden_size))
            os.makedirs(directory, exist_ok=True)
            torch.save({
                'iteration': iteration,
                'en': encoder.state_dict(),
                'de': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
                'loss': loss,
                'voc_dict': voc.__dict__,
                'embedding': embedding.state_dict()
            }, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint')))
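
The difference between the two branches in `train` comes down to which token is fed to the decoder at the next step. A minimal sketch for a single sequence follows; `decoder_inputs` and the prediction values are hypothetical.

```python
SOS_token = 1

def decoder_inputs(targets, predictions, use_teacher_forcing):
    """Return the token fed to the decoder at each step of one sequence."""
    fed = [SOS_token]
    source = targets if use_teacher_forcing else predictions
    # The input at step t+1 is the step-t token: the gold target or the model's own guess
    fed.extend(source[:-1])
    return fed

targets = [5, 6, 2]       # gold reply, ending in EOS (2)
predictions = [5, 9, 2]   # hypothetical greedy guesses from the model
print(decoder_inputs(targets, predictions, True))   # [1, 5, 6]
print(decoder_inputs(targets, predictions, False))  # [1, 5, 9]
```

With teacher forcing the wrong guess at step 2 (token 9) never reaches the decoder; without it, the error is fed back in. In both cases the loss is still computed against the gold targets.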


## Section 7: Evaluation and Interaction

> **Text Box:**
> In this section we define functions for model inference:
>
> - **GreedySearchDecoder:** Uses a greedy approach to select the most likely token at each time step.
> - **evaluate:** Processes an input sentence and converts model outputs back to words.
> - **evaluateInput:** Provides an interactive loop for chatting with the bot.
>
> Extensive inline comments explain each step of the process.

In [None]:
# Greedy decoder that uses the encoder and decoder to generate responses one token at a time
class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder  # The pre-trained encoder model
        self.decoder = decoder  # The pre-trained decoder model

    def forward(self, input_seq, input_length, max_length):
        # Run the encoder to get outputs and hidden states
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)

        # Use the final hidden state of the encoder as the initial hidden state for the decoder
        decoder_hidden = encoder_hidden[:self.decoder.n_layers]

        # Initialize the decoder input with the SOS token
        decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * SOS_token

        # Containers for the output tokens and scores
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)

        # Decode one token at a time up to max_length
        for _ in range(max_length):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            # Choose the token with the highest probability
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            # Append the chosen token and its score
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # Prepare the token for the next iteration (add a time-step dimension)
            decoder_input = torch.unsqueeze(decoder_input, 0)
        return all_tokens, all_scores

# Evaluate an input sentence and return the decoded words
def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
    # Convert the sentence into indices using the vocabulary
    indexes_batch = [indexesFromSentence(voc, sentence)]
    # Calculate the lengths of the sentence(s)
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    # Create a tensor representing the input batch (transpose for correct dimensions)
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1).to(device)
    lengths = lengths.to("cpu")

    # Use the searcher (decoder) to generate output tokens and associated scores;
    # gradients are not needed at inference time
    with torch.no_grad():
        tokens, scores = searcher(input_batch, lengths, max_length)

    # Convert the output token indices back into words
    decoded_words = [voc.index2word[token.item()] for token in tokens]
    return decoded_words

# Interactive loop to chat with the bot
def evaluateInput(encoder, decoder, searcher, voc):
    while True:
        try:
            # Read user input
            input_sentence = input('> ')
            if input_sentence.lower() in ['q', 'quit']:
                break  # Exit the loop if the user types 'q' or 'quit'
            # Normalize the input sentence
            input_sentence = normalizeString(input_sentence)
            # Generate the output words from the model
            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
            # Remove special tokens from the output
            output_words = [x for x in output_words if x not in ['EOS', 'PAD']]
            print('Bot:', ' '.join(output_words))
        except KeyError:
            # In case an unknown word is encountered, print an error message
            print("Error: Encountered unknown word.")
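
Because the greedy loop always runs for `max_length` steps, any words generated after the first `EOS` survive the filter in `evaluateInput`, which only drops the `EOS` and `PAD` entries themselves. One possible refinement, sketched here with a hypothetical helper `clean_response` that is not part of the notebook, cuts the reply at the first `EOS` before removing padding.

```python
def clean_response(words, eos="EOS", pad="PAD"):
    """Keep words up to (not including) the first EOS, then drop padding."""
    if eos in words:
        words = words[:words.index(eos)]
    return [w for w in words if w != pad]

# A made-up decoder output in which generation continues after the first EOS
raw = ["hello", "there", "EOS", "there", "PAD", "EOS"]

filtered = [w for w in raw if w not in ("EOS", "PAD")]  # the filter used in evaluateInput
trimmed = clean_response(raw)

print(filtered)  # ['hello', 'there', 'there']
print(trimmed)   # ['hello', 'there']
```

Whether trimming is preferable depends on how the model was trained; it is shown only as an option.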


## Section 8: Model Initialization, Checkpoint Loading, and Training Setup

> **Text Box:**
> In this final section we configure the model parameters (hidden size, number of layers, dropout rate, batch size),
> initialize the embedding layer, encoder, and decoder, and set up the training process.
> A wrapper function, `wrapper_train`, encapsulates training initialization, with optional checkpoint saving, optional profiling, and Weights & Biases (W&B) logging.
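
The `loadFilename` and `checkpoint_iter` settings below imply resuming from a saved checkpoint. As a rough sketch of what such a round trip can look like, the example below saves and restores a tiny stand-in model and its optimizer with `torch.save` and `torch.load`. The dictionary keys (`"iteration"`, `"model"`, `"opt"`), the file name, and the `build` helper are assumptions for illustration and may differ from what this notebook's training routine actually writes.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

def build():
    # Tiny stand-in for the encoder/decoder pair used in the notebook
    model = nn.GRU(input_size=4, hidden_size=4)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    return model, optimizer

model, optimizer = build()

# Take one optimizer step so the optimizer has state worth saving
output, _ = model(torch.randn(3, 1, 4))
output.sum().backward()
optimizer.step()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "checkpoint.tar")  # hypothetical file name
    torch.save({"iteration": 1,
                "model": model.state_dict(),
                "opt": optimizer.state_dict()}, path)

    # Rebuild fresh objects, then restore their state from disk
    restored_model, restored_optimizer = build()
    checkpoint = torch.load(path, map_location="cpu")
    restored_model.load_state_dict(checkpoint["model"])
    restored_optimizer.load_state_dict(checkpoint["opt"])

# Training would resume from the iteration after the saved one
start_iteration = checkpoint["iteration"] + 1
print(start_iteration)  # 2
```

Loading with `map_location` places the restored tensors on a known device, which is one reason training code often moves optimizer state to the target device afterwards.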

In [None]:
# Configuration parameters for the model
model_name = 'cb_model'        # Name or identifier for the model
attn_model = 'dot'             # Attention mechanism type to use ('dot', 'general', or 'concat')
hidden_size = 512              # Dimensionality of the hidden state and embeddings
encoder_n_layers = 2           # Number of layers in the encoder
decoder_n_layers = 2           # Number of layers in the decoder
dropout = 0.1                  # Dropout rate to avoid overfitting
batch_size = 64                # Size of each training batch

loadFilename = None            # Specify a checkpoint file to load (None to start from scratch)
checkpoint_iter = 4000         # Iteration number of the checkpoint (if applicable)

# Wrapper function to initialize and start training
def wrapper_train(config=None, profile=False, save_model=False):
    import wandb  # For logging training metrics (Weights & Biases)
    run = wandb.init(project="W&BProjectName")
    config = run.config if config is None else config

    # Local model hyperparameters. These shadow the module-level values defined above
    # (note that hidden_size is 500 here, not 512) and are not currently read from 'config'.
    model_name = 'cb_model'
    attn_model = 'dot'
    hidden_size = 500
    encoder_n_layers = 2
    decoder_n_layers = 2
    dropout = 0.1
    batch_size = 64

    print('Building encoder and decoder ...')

    # Initialize the embedding layer with the vocabulary size and the chosen hidden size
    embedding = nn.Embedding(voc.num_words, hidden_size)

    # Initialize the encoder using the embedding layer
    encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)

    # Initialize the decoder with attention, using the same embedding layer
    decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)

    # Move both models to the selected device (GPU or CPU)
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    print('Models built and ready to go!')

    # Retrieve training hyperparameters from the configuration
    clip = config.clip
    teacher_forcing_ratio = config.tf_ratio
    learning_rate = config.lr
    decoder_learning_ratio = config.decoder_lrn_ratio
    n_iteration = 4000
    print_every = 1
    save_every = 500

    # Ensure models are in training mode (this activates dropout layers, etc.)
    encoder.train()
    decoder.train()

    print('Building optimizers ...')
    # Choose an optimizer (Adam or SGD) based on the configuration
    optimizer_fn = optim.Adam if config.optimizer == "adam" else optim.SGD
    encoder_optimizer = optimizer_fn(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optimizer_fn(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)

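    # Move optimizer state tensors to the correct device. A freshly built optimizer has
    # empty state, so these loops only matter when state was restored from a checkpoint.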
    for state in encoder_optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)

    for state in decoder_optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)

    print("Starting Training!")

    # Begin training iterations
    trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,
               embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size,
               print_every, save_every, clip, corpus_name, teacher_forcing_ratio, profile, save_model)
    return encoder, decoder

# Define a configuration dataclass for clearer parameter management
from dataclasses import dataclass

@dataclass
class Config:
    clip: float              # Gradient clipping value
    tf_ratio: float          # Teacher forcing ratio (probability of using true target as next input)
    lr: float                # Learning rate
    optimizer: str           # Optimizer type (e.g., 'adam' or 'sgd')
    decoder_lrn_ratio: float # Multiplier for decoder's learning rate relative to encoder's

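# Start training using the configuration defined above.
# Caution: if trainIters passes `clip` to nn.utils.clip_grad_norm_, a max norm of 0.0
# scales every gradient to zero, so no learning occurs. tf_ratio=0.0 means the decoder
# never receives the ground-truth token as its next input.
encoder, decoder = wrapper_train(
    Config(clip=0.0, tf_ratio=0.0, lr=0.0001, optimizer="adam", decoder_lrn_ratio=1.0),
    profile=True,
    save_model=True,
)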

# Set module-level training parameters for a follow-up run. These are independent of the
# values used inside wrapper_train above (and could also be part of the Config class).
clip = 100
teacher_forcing_ratio = 1.0
learning_rate = 0.001
decoder_learning_ratio = 5.0
n_iteration = 4000
print_every = 1
save_every = 500

# Ensure the models are still in training mode
encoder.train()
decoder.train()

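print('Building optimizers ...')
# Note: the encoder is optimized with SGD here, while the decoder uses Adam
encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)

# Move optimizer states to the selected device (a no-op for freshly built optimizers,
# whose state is still empty)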
for state in encoder_optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

for state in decoder_optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

print("Training iteration in progress...")
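
The `clip` setting has a large effect on training, so it is worth seeing what different values do. The sketch below assumes that the training step clips gradients with `torch.nn.utils.clip_grad_norm_`, which is not shown in this section. The single linear layer and the hypothetical helper `grad_norm_after_clipping` are stand-ins chosen so that the unclipped gradient norm is exactly 4.

```python
import torch
import torch.nn as nn

def grad_norm_after_clipping(max_norm):
    """Return the total gradient norm of a toy layer after clipping to max_norm."""
    layer = nn.Linear(3, 1)
    # With an all-ones batch of 2 rows, each weight gradient is 2 and the bias
    # gradient is 2, so the unclipped total norm is sqrt(4 * 2**2) = 4.
    loss = layer(torch.ones(2, 3)).sum()
    loss.backward()
    nn.utils.clip_grad_norm_(layer.parameters(), max_norm)
    squared = sum(p.grad.pow(2).sum() for p in layer.parameters())
    return torch.sqrt(squared).item()

zero_clip = grad_norm_after_clipping(0.0)    # every gradient is scaled to zero
unit_clip = grad_norm_after_clipping(1.0)    # norm is rescaled down to about 1
loose_clip = grad_norm_after_clipping(50.0)  # norm 4 is below the limit, so unchanged

print(zero_clip, unit_clip, loose_clip)
```

Under this assumption, a clip value of 0.0 silently prevents any parameter update, while a generous value such as 50.0 only intervenes when gradients become unusually large.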
