# The goal is to access and prepare text data to train a model to act as a chatbot regarding any desired category
##### Disclaimer: This is being done on a single machine so the model is not going to be well trained

## Import the necessary libraries

In [1]:
import re
import math
import time
import string
import logging
from pathlib import Path
from collections import Counter
from urllib.parse import urljoin
from typing import Union, Optional, Callable

import nltk
import torch
import pynvml
import requests
import seaborn as sns
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from bs4 import BeautifulSoup
from wordcloud import WordCloud
from datasets import load_dataset
from nltk.corpus import stopwords
from tokenizers.models import WordLevel
from torch.cuda.amp import GradScaler, autocast
from tokenizers.trainers import WordLevelTrainer
from tokenizers import Tokenizer, pre_tokenizers
from torch.utils.data.dataloader import default_collate
from torch.utils.data import Dataset, DataLoader, random_split, dataset
from torch.optim.lr_scheduler import StepLR, CyclicLR, CosineAnnealingLR
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AdamW, logging as hf_logging, get_linear_schedule_with_warmup

from generic_transformer import build_transformer, Transformer

## Configure Notebook Settings

In [2]:
sns.set()
hf_logging.set_verbosity_error()
nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Amram\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

# Create a logger

In [3]:
def create_logger(name: str, file: str, log_name: str):
    """Create (or reconfigure) a logger that writes to two log files and the console.

    The log files live in a ``Logs`` directory next to ``file``:
    ``<log_name>_Errors.log`` receives ERROR and above, ``<log_name>_Info.log``
    receives INFO and above, and a console stream handler receives DEBUG and above.

    Args:
        name: Name passed to ``logging.getLogger``.
        file: Path of the script/notebook the logs belong to; its parent directory
            hosts the ``Logs`` folder (created if missing).
        log_name: Prefix used for the two log file names.

    Returns:
        The configured ``logging.Logger``.
    """
    # The logger lets every record through; the individual handlers filter by level
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    # Calling this again for the same name (e.g. re-running a notebook cell) would
    # otherwise stack another set of handlers and emit every record multiple times
    for old_handler in list(logger.handlers):
        old_handler.close()
        logger.removeHandler(old_handler)

    formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(name)s:%(message)s')

    # Create the Logs directory exactly where the file handlers write (the directory
    # used to be created one level above the files' location)
    log_dir = Path(file).parent / 'Logs'
    log_dir.mkdir(parents=True, exist_ok=True)

    # pathlib joins replace the hard-coded Windows backslash, which was an invalid
    # escape sequence and produced a literal "Logs\\..." file name on POSIX systems
    error_file_handler = logging.FileHandler(log_dir / f'{log_name}_Errors.log')
    error_file_handler.setLevel(logging.ERROR)
    info_file_handler = logging.FileHandler(log_dir / f'{log_name}_Info.log')
    info_file_handler.setLevel(logging.INFO)
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.DEBUG)

    for handler in (error_file_handler, info_file_handler, stream_handler):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger

In [4]:
# Shared logger passed to the data-retrieval helpers below; it writes to ./Logs and to the console
logger = create_logger('TransformerChatbot', './TransformerChatbot.py', 'TransformerChatbot')

## Automated text data retrieval from Project Gutenberg

In [5]:
def get_subcategory_link(url: str, subcategory: str, logger: logging.Logger, timeout: float = 30.0):
    """Return the absolute URL of the first link on ``url`` whose text contains ``subcategory``.

    Args:
        url: Bookshelf index page to search.
        subcategory: Case-insensitive substring to look for in the link texts.
        logger: Used to list every available link when no match is found.
        timeout: Seconds to wait for the HTTP response (requests waits forever without one).

    Returns:
        The absolute URL of the first matching link.

    Raises:
        ValueError: If the page cannot be fetched, or no link text contains ``subcategory``.
    """
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        # Report the HTTP failure itself instead of claiming that no link matched
        raise ValueError(
            f"Failed to retrieve {url} while looking for subcategory {subcategory!r} - "
            f"Status code: {response.status_code}"
        )

    soup = BeautifulSoup(response.content, 'html.parser')
    query = subcategory.lower()
    subcategory_link_tag = soup.find('a', href=True, string=lambda text: text and query in text.lower())

    if subcategory_link_tag:
        # urljoin copes with both site-relative ("/ebooks/...") and absolute hrefs
        return urljoin('https://www.gutenberg.org', subcategory_link_tag['href'])

    # No match: log every link on the page so the caller can pick a valid subcategory
    for link in soup.find_all('a', href=True):
        href = link['href']
        text = link.get_text(strip=True)
        if href and text:
            logger.info(f"Available Options: {text}, Href: {urljoin('https://www.gutenberg.org', href)}")

    raise ValueError(f"No subcategory link found for query: {subcategory}")


def get_book_links(subcategory_url: str, logger: logging.Logger, page_size: int = 25, timeout: float = 30.0):
    """Collect the "Plain Text UTF-8" download link of every book on a Gutenberg bookshelf.

    The bookshelf is paginated ``page_size`` books at a time via the 1-based
    ``start_index`` query parameter. Each listed book's page is fetched to find its
    plain-text link.

    Args:
        subcategory_url: URL of the bookshelf (subcategory) page.
        logger: Receives progress and per-book problems.
        page_size: Number of books per bookshelf page.
        timeout: Seconds to wait for each HTTP response.

    Returns:
        Dict mapping ``"<title>_<author>"`` to the href of the book's .txt file, as found on the book page.

    Raises:
        ValueError: If ``page_size`` is not positive or a bookshelf page cannot be retrieved.
    """
    if page_size < 1:
        raise ValueError(f"page_size must be positive, got {page_size}")

    txt_file_links = {}
    start_index = 1

    while True:
        # The first page has no query string; later pages are addressed by start_index
        paginated_url = subcategory_url if start_index == 1 else f"{subcategory_url}?start_index={start_index}"

        response = requests.get(paginated_url, timeout=timeout)
        if response.status_code != 200:
            # This message used to reference an undefined name and raised NameError instead
            raise ValueError(f"Failed to retrieve bookshelf page {paginated_url} - Status code: {response.status_code}")

        logger.info(f'Accessed {paginated_url}')
        soup = BeautifulSoup(response.content, 'html.parser')
        book_links = soup.find_all('li', class_='booklink')

        # No more books listed: we are past the last page
        if not book_links:
            break

        for i, book in enumerate(book_links):
            # Number books across pages, not just within the current page
            book_number = start_index + i
            title_tag = book.find('span', class_='title')
            subtitle_tag = book.find('span', class_='subtitle')
            if not (title_tag and subtitle_tag):
                logger.info(f'Title or author not found for book {book_number}')
                continue

            anchor = book.find('a', href=True)
            if anchor is None:
                logger.info(f'No book page link found for book {book_number}')
                continue

            title = title_tag.text
            author = subtitle_tag.text
            book_page_link = urljoin('https://www.gutenberg.org', anchor['href'])
            book_response = requests.get(book_page_link, timeout=timeout)
            if book_response.status_code != 200:
                logger.info(f"Failed to retrieve the book page for book {book_number}. Status code: {book_response.status_code}")
                continue

            book_soup = BeautifulSoup(book_response.content, 'html.parser')
            txt_link_tag = book_soup.find('a', href=True, string='Plain Text UTF-8')
            if txt_link_tag:
                txt_file_links[f'{title}_{author}'] = txt_link_tag['href']
            else:
                logger.info(f'No .txt link found for book {book_number}')

        # A page shorter than page_size is the last one
        if len(book_links) < page_size:
            break
        start_index += page_size

    return txt_file_links


def get_sentences(txt_links: dict, logger: logging.Logger, intro_pct: float = 0.02, timeout: float = 30.0):
    """Download each book's plain text and split it into sentences.

    For every book, the first and last ``intro_pct`` fraction of the word tokens is
    dropped (intended to strip front/back matter). Occurrences of "project gutenberg"
    and then "gutenberg" are then removed (case-insensitive) before sentence splitting.

    Args:
        txt_links: Mapping whose values are .txt hrefs (site-relative or absolute).
        logger: Receives download failures.
        intro_pct: Fraction of word tokens to trim from each end of every book.
        timeout: Seconds to wait for each HTTP response.

    Returns:
        A flat list of sentences from all books; tokens inside a sentence are space-separated.
    """
    all_sentences = []

    for link in txt_links.values():
        full_url = urljoin('https://www.gutenberg.org', link)
        response = requests.get(full_url, timeout=timeout)
        if response.status_code != 200:
            # Route failures through the logger (previously print) so they reach the log files
            logger.warning(f"Failed to retrieve the .txt file from {full_url}. Status code: {response.status_code}")
            continue

        words = nltk.word_tokenize(response.text)

        # Trim with an explicit end index: words[0:-0] is empty, so books with cut == 0
        # (fewer than 1 / intro_pct tokens) used to be dropped entirely
        cut = int(intro_pct * len(words))
        text = ' '.join(words[cut:len(words) - cut])

        # Chain the substitutions; the second one used to start again from `text`,
        # silently discarding the "project gutenberg" removal
        remaining_text = re.sub('project gutenberg', '', text, flags=re.IGNORECASE)
        remaining_text = re.sub('gutenberg', '', remaining_text, flags=re.IGNORECASE)

        all_sentences.extend(nltk.sent_tokenize(remaining_text))

    return all_sentences


def get_text_data(subcategory: str, logger: logging.Logger):
    """Fetch every book of a Project Gutenberg bookshelf subcategory and split it into sentences.

    Args:
        subcategory: Case-insensitive text to match against the bookshelf link names.
        logger: Passed through to the retrieval helpers.

    Returns:
        Tuple ``(txt_links, all_sentences)``: the ``"<title>_<author>" -> .txt href``
        mapping and the flat list of extracted sentences.

    Raises:
        ValueError: If the subcategory link cannot be found or a page cannot be retrieved.
    """
    # Start from Project Gutenberg's bookshelf index
    url = 'https://www.gutenberg.org/ebooks/bookshelf/'
    subcategory_url = get_subcategory_link(url=url, subcategory=subcategory, logger=logger)

    if not subcategory_url:
        # This branch used to only log and implicitly return None, so callers unpacking
        # the result failed with an unrelated TypeError; fail loudly instead
        logger.warning("Failed to find subcategory link.")
        raise ValueError(f"Failed to find subcategory link for: {subcategory}")

    txt_links = get_book_links(subcategory_url=subcategory_url, logger=logger)
    all_sentences = get_sentences(txt_links=txt_links, logger=logger)

    # Only the sentence count is logged
    logger.info(f"Total number of sentences: {len(all_sentences)}")
    return txt_links, all_sentences

In [6]:
# Scrape the "World War II" bookshelf (links + sentences) and report the wall-clock time it took
load_data_time = time.time()
txt_links, all_sentences = get_text_data(subcategory='World War II', logger=logger)
print(f'Loading data took {time.time() - load_data_time:.2f} seconds')

2024-05-27 09:07:18,478:INFO:TransformerChatbot:Accessed https://www.gutenberg.org/ebooks/bookshelf/325
2024-05-27 09:07:21,353:INFO:TransformerChatbot:Title or author not found for book 8
2024-05-27 09:07:27,467:INFO:TransformerChatbot:Title or author not found for book 24
2024-05-27 09:07:28,225:INFO:TransformerChatbot:Accessed https://www.gutenberg.org/ebooks/bookshelf/325?start_index=26
2024-05-27 09:07:38,390:INFO:TransformerChatbot:Accessed https://www.gutenberg.org/ebooks/bookshelf/325?start_index=51
2024-05-27 09:07:48,356:INFO:TransformerChatbot:Accessed https://www.gutenberg.org/ebooks/bookshelf/325?start_index=76
2024-05-27 09:09:04,757:INFO:TransformerChatbot:Total number of sentences: 308954


Loading data took 107.14 seconds


In [7]:
# Notebook display of the scraped "<title>_<author>" -> .txt link mapping
txt_links

{'Integration of the Armed Forces, 1940-1965_Morris J. MacGregor': '/ebooks/20587.txt.utf-8',
 'The Homing Pigeon_United States. War Department and United States. Army. Signal Corps': '/ebooks/55084.txt.utf-8',
 'Closing In: Marines in the Seizure of Iwo Jima_Joseph H. Alexander': '/ebooks/49080.txt.utf-8',
 'Motorcycle, Solo (Harley-Davidson Model WLA)_United States. War Department': '/ebooks/51058.txt.utf-8',
 'Across the Reef: The Marine Assault of Tarawa_Joseph H. Alexander': '/ebooks/48836.txt.utf-8',
 'Portable Flame Thrower M2-2_United States. War Department': '/ebooks/53669.txt.utf-8',
 'Blesky nad Beskydami (Czech)_František Omelka': '/ebooks/47754.txt.utf-8',
 'Leyte: The Return to the Philippines_M. Hamlin Cannon': '/ebooks/48991.txt.utf-8',
 'Day of Infamy Speech: Given before the US Congress December 8 1941_Franklin D. Roosevelt': '/ebooks/21805.txt.utf-8',
 'Forward, Children!_Paul Alexander Bartlett': '/ebooks/44717.txt.utf-8',
 'Bloody Beaches: The Marines at Peleliu_Go

In [8]:
# Notebook display of the extracted sentences
all_sentences

['existence of four black Regular Army regiments also institutionalized segregation , granting federal recognition to a system racially separate and theoretically equal in treatment and opportunity a generation before the Supreme Court sanctioned such a distinction in _Plessy_ v .',
 '_Ferguson_ .',
 '[ 1-6 ] So important to many in the black community was this guaranteed existence of the four regiments that had served with distinction against the frontier Indians that few complained about segregation .',
 'In fact , as historian Jack Foner has pointed out , black leaders sometimes interpreted demands for integration as attempts to eliminate black soldiers altogether .',
 '[ 1-7 ] [ Footnote 1-6 : 163 U.S. 537 ( 1896 ) .',
 'In this 1896 case concerning segregated seating on a Louisiana railroad , the Supreme Court ruled that so long as equality of accommodation existed , segregation could not in itself be considered discriminatory and therefore did not violate the equal rights provisi

## Run preliminary analytics on the unprocessed text data

In [21]:
def word_analytics(sentences: list, width: int = 800, height: int = 400, background_color: str = 'white'):
    """Plot a word cloud of the non-stop-words and a histogram of sentence lengths.

    Timing for each stage is printed as it completes.

    Args:
        sentences: Sentences whose tokens are separated by single spaces.
        width: Word cloud canvas width in pixels.
        height: Word cloud canvas height in pixels.
        background_color: Word cloud background colour.
    """
    # Drop NLTK English stop words (for this analysis only); matching is case-sensitive
    remove_time = time.time()
    stop_words = set(stopwords.words('english'))
    content_words = [word for sentence in sentences for word in sentence.split(' ') if word not in stop_words]
    print(f'removal of stop words took {time.time() - remove_time} seconds')

    # Build a word cloud to visualize the most common words
    tokenize_time = time.time()
    tokens = nltk.word_tokenize(' '.join(content_words))
    print(f'tokenization took {time.time() - tokenize_time} seconds')

    processed_text = ' '.join(tokens)

    generation_time = time.time()
    wordcloud = WordCloud(width=width, height=height, background_color=background_color).generate(processed_text)
    print(f'generation took {time.time() - generation_time} seconds')

    plt.figure(figsize=(10, 5))
    plt.imshow(wordcloud, interpolation='bilinear')
    plt.axis('off')
    plt.show()

    # Sequence length = number of whitespace-separated tokens per sentence, stop words
    # included since the model sees them. The previous version plotted the character
    # length of individual words, which says nothing about a good seq_len.
    print('starting sequencing')
    sequence_time = time.time()
    sequence_lengths = [len(sentence.split()) for sentence in sentences]
    print(f'sequencing took {time.time() - sequence_time} seconds')

    fig, ax = plt.subplots()
    ax.hist(sequence_lengths, bins=50)
    ax.set_xlabel('Sequence Length')
    ax.set_ylabel('Frequency')
    plt.show()

In [None]:
# Plot the word cloud and the "Sequence Length" histogram for the scraped sentences
word_analytics(sentences=all_sentences)

removal of stop words took 0.6951816082000732 seconds
tokenization took 16.1490535736084 seconds
generation took 9.42536449432373 seconds


## Build the Transformer Model

### Define activation functions

#### ReLU (Rectified Linear Unit):
$\mathrm{ReLU}(x) = \max(0, x)$


#### Softmax: 
$softmax(z_i) = \frac{e^{z_i}}{\sum^n_{j=1}e^{z_j}}$

In [None]:
def relu(x: torch.Tensor):
    """Element-wise rectified linear unit: ``max(0, x)``.

    The comparison is against a 0-dim zero tensor of torch's default dtype, so the
    result follows the usual ``torch.maximum`` type-promotion rules.
    """
    zero = torch.zeros(())
    return torch.maximum(x, zero)


def softmax(x: torch.Tensor, dim: int = -1):
    """Numerically stable softmax of ``x`` along ``dim``.

    The maximum along ``dim`` is subtracted before exponentiating so large logits
    cannot overflow; the shift cancels out in the normalisation.
    """
    shifted = x - x.amax(dim=dim, keepdim=True)
    numerators = shifted.exp()
    denominator = numerators.sum(dim=dim, keepdim=True)
    return numerators / denominator

### Define the layers of the Transformer model
Input -> Tokenization -> Encoder -> Decoder -> Projection -> Output

Encoder: Source Embedding -> Positional Encoding -> N × [Layer Normalization -> h-head Self-Attention -> Residual Connection -> Layer Normalization -> Feed Forward -> Residual Connection] -> Final Layer Normalization -> Encoder Output

Decoder: Target Embedding -> Positional Encoding -> N × [Layer Normalization -> h-head Masked Self-Attention -> Residual Connection -> Layer Normalization -> h-head Cross-Attention over the Encoder Output -> Residual Connection -> Layer Normalization -> Feed Forward -> Residual Connection] -> Final Layer Normalization -> Decoder Output

In [None]:
class InputEmbeddings(nn.Module):
    """Token-id embedding lookup scaled by sqrt(d_model)."""

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        # The attribute name "embedding" determines the state_dict key; keep it stable
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x: torch.Tensor):
        """Map integer token ids (any shape) to d_model-sized vectors, scaled by sqrt(d_model)."""
        vectors = self.embedding(x)
        scale = math.sqrt(self.d_model)
        return vectors * scale


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to the input, then apply dropout.

    Even feature columns hold ``sin(pos * 10000^(-2i/d_model))`` and odd columns the
    matching ``cos``. Odd ``d_model`` is supported: the last sin column simply has
    no cos partner.
    """

    def __init__(self, d_model: int, seq_len: int, dropout: float):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        # One frequency per even column: ceil(d_model / 2) values
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -math.log(10000) / d_model)

        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only d_model // 2 odd columns; trimming div_term avoids the
        # shape-mismatch error the assignment raised for odd d_model
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # (1, seq_len, d_model) so it broadcasts over the batch dimension
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor):
        """Add the encodings of the first ``x.shape[1]`` positions.

        Assumes ``x`` is (batch, seq, d_model) with seq <= seq_len.
        """
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)
        return self.dropout(x)


class LayerNormalization(nn.Module):
    """Layer normalization over the last dimension with a learnable scalar gain and bias.

    Computes ``alpha * (x - mean) / sqrt(var + eps) + bias``. The mean and the biased
    variance are taken over the last dimension. This matches
    ``torch.nn.functional.layer_norm`` except that the affine parameters are single
    scalars rather than per-feature vectors.
    """

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Shape (1,): one gain and one bias shared by every feature
        self.alpha = nn.Parameter(torch.ones(1))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor):
        mean = x.mean(dim=-1, keepdim=True)
        centered = x - mean
        var = (centered ** 2).mean(dim=-1, keepdim=True)
        # Previously this divided by sqrt(std + eps), i.e. the fourth root of the variance,
        # so the output was not normalized to unit variance
        return self.alpha * centered / torch.sqrt(var + self.eps) + self.bias


class FeedForwardBlock(nn.Module):
    """Position-wise feed-forward network: Linear(d_model -> d_ff) -> ReLU -> Dropout -> Linear(d_ff -> d_model)."""

    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        # Creation order (linear_1, dropout, linear_2) is kept so seeded initialisation is unchanged
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x: torch.Tensor):
        hidden = self.linear_1(x)
        hidden = relu(hidden)
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)


class MultiHeadAttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are projected with ``w_q`` / ``w_k`` / ``w_v``, split into
    ``h`` heads of width ``d_k = d_model // h`` and attended per head. The heads are then
    merged back to ``d_model`` and projected with ``w_o``. The attention weights of the
    latest forward pass are stored on ``attention_scores``.
    """

    def __init__(self, d_model: int, h: int, dropout: float):
        super().__init__()
        self.d_model = d_model
        self.h = h
        assert d_model % h == 0, 'd_model is not divisible by h'
        self.d_k = d_model // h
        # Keep the q, k, v, o creation order: each nn.Linear draws from the global RNG
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self._attention_scores = None

    @property
    def attention_scores(self):
        """Attention weights (after softmax and dropout) from the last ``forward`` call, or None."""
        return self._attention_scores

    @attention_scores.setter
    def attention_scores(self, attention_scores: torch.Tensor):
        self._attention_scores = attention_scores

    def _split_heads(self, projected: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, seq, d_model) into (batch, h, seq, d_k)."""
        batch, seq = projected.shape[0], projected.shape[1]
        return projected.view(batch, seq, self.h, self.d_k).transpose(1, 2)

    def attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor]):
        """Return ``(weights @ value, weights)`` for per-head (batch, h, seq, d_k) tensors.

        Scores where ``mask`` is zero/False are set to -1e9 before the softmax.
        """
        scores = (query @ key.transpose(-2, -1)) / math.sqrt(query.shape[-1])

        if mask is not None:
            allowed = mask.to(torch.bool).expand_as(scores)
            # -1e9 is outside the float16 range, so write it into a float32 tensor
            scores = scores.float()
            scores.masked_fill_(allowed == 0, -1e9)

        weights = softmax(scores, dim=-1)
        if self.dropout is not None:
            weights = self.dropout(weights)
        return weights @ value, weights

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor]):
        query = self._split_heads(self.w_q(q))
        key = self._split_heads(self.w_k(k))
        value = self._split_heads(self.w_v(v))
        attended, self.attention_scores = self.attention(query, key, value, mask)
        # (batch, h, seq, d_k) -> (batch, seq, h * d_k)
        merged = attended.transpose(1, 2).contiguous().view(attended.shape[0], -1, self.h * self.d_k)
        return self.w_o(merged)


class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(sublayer(norm(x)))``."""

    def __init__(self, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization()

    def forward(self, x: torch.Tensor, sublayer: Callable):
        normalized = self.norm(x)
        update = self.dropout(sublayer(normalized))
        return x + update


class EncoderBlock(nn.Module):
    """One encoder layer: self-attention, then feed-forward, each in its own ResidualConnection."""

    def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])

    def forward(self, x: torch.Tensor, src_mask: Optional[torch.Tensor]):
        attention_residual, feed_forward_residual = self.residual_connections

        def self_attend(hidden):
            # Queries, keys and values all come from the same normalized sequence
            return self.self_attention_block(hidden, hidden, hidden, src_mask)

        x = attention_residual(x, self_attend)
        return feed_forward_residual(x, self.feed_forward_block)


class Encoder(nn.Module):
    """Stack of encoder layers followed by a final layer normalization."""

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return self.norm(hidden)


class DecoderBlock(nn.Module):
    """One decoder layer: self-attention under ``tgt_mask``, cross-attention over the
    encoder output under ``src_mask``, then feed-forward — each in its own ResidualConnection."""

    def __init__(self, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])

    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor, src_mask: Optional[torch.Tensor], tgt_mask: Optional[torch.Tensor]):
        self_residual, cross_residual, feed_forward_residual = self.residual_connections

        def attend_to_self(hidden):
            return self.self_attention_block(hidden, hidden, hidden, tgt_mask)

        def attend_to_encoder(hidden):
            # Queries come from the decoder; keys and values from the encoder output
            return self.cross_attention_block(hidden, encoder_output, encoder_output, src_mask)

        x = self_residual(x, attend_to_self)
        x = cross_residual(x, attend_to_encoder)
        return feed_forward_residual(x, self.feed_forward_block)


class Decoder(nn.Module):
    """Stack of decoder layers followed by a final layer normalization."""

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor, src_mask: Optional[torch.Tensor], tgt_mask: Optional[torch.Tensor]):
        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, encoder_output, src_mask, tgt_mask)
        return self.norm(hidden)


class ProjectionLayer(nn.Module):
    """Project hidden states to vocabulary log-probabilities (log-softmax over the last dimension)."""

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x: torch.Tensor):
        logits = self.proj(x)
        # log(softmax(z) + 1e-8) biased every log-prob and clipped very unlikely tokens
        # at about log(1e-8) = -18.4; log_softmax computes z - logsumexp(z) exactly and stably
        return torch.log_softmax(logits, dim=-1)


class Transformer(nn.Module):
    """Encoder-decoder Transformer assembled from pre-built components.

    ``forward`` embeds and position-encodes both sequences, runs the encoder and the
    decoder, and returns ``projection_layer`` applied to the decoder output.
    """

    def __init__(
            self,
            encoder: Encoder,
            decoder: Decoder,
            src_embed: InputEmbeddings,
            tgt_embed: InputEmbeddings,
            src_pos: PositionalEncoding,
            tgt_pos: PositionalEncoding,
            projection_layer: ProjectionLayer
    ):
        super().__init__()
        # Attribute names are state_dict prefixes; keep them unchanged
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src: torch.Tensor, src_mask: Optional[torch.Tensor]):
        """Embed, position-encode and encode the source token ids."""
        embedded = self.src_pos(self.src_embed(src))
        return self.encoder(embedded, src_mask)

    def decode(self, encoder_output: torch.Tensor, src_mask: Optional[torch.Tensor], tgt: torch.Tensor, tgt_mask: Optional[torch.Tensor]):
        """Embed and position-encode the target token ids, then decode against ``encoder_output``."""
        embedded = self.tgt_pos(self.tgt_embed(tgt))
        return self.decoder(embedded, encoder_output, src_mask, tgt_mask)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: Optional[torch.Tensor] = None, tgt_mask: Optional[torch.Tensor] = None):
        memory = self.encode(src, src_mask)
        decoded = self.decode(memory, src_mask, tgt, tgt_mask)
        return self.projection_layer(decoded)

In [None]:
def build_transformer(vocab_size: int, seq_len: int, d_model: int, N: int, h: int, dropout: float, d_ff: int) -> Transformer:
    """Assemble an encoder-decoder Transformer with ``N`` layers per side and ``h`` attention heads.

    Source and target share ``vocab_size`` and ``seq_len``. Every parameter with more
    than one dimension is re-initialised with Xavier/Glorot uniform initialisation.
    """
    # NOTE: modules are created in the same order as before. nn.Embedding / nn.Linear draw
    # from the global RNG on construction, so this order fixes the seeded initial biases.
    src_embed = InputEmbeddings(d_model, vocab_size)
    tgt_embed = InputEmbeddings(d_model, vocab_size)

    # Sinusoidal positional encodings (no learnable parameters)
    src_pos = PositionalEncoding(d_model, seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, seq_len, dropout)

    # Each encoder layer: self-attention + feed-forward
    encoder_layers = [
        EncoderBlock(
            MultiHeadAttentionBlock(d_model, h, dropout),
            FeedForwardBlock(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(N)
    ]

    # Each decoder layer: self-attention + cross-attention + feed-forward
    decoder_layers = [
        DecoderBlock(
            MultiHeadAttentionBlock(d_model, h, dropout),
            MultiHeadAttentionBlock(d_model, h, dropout),
            FeedForwardBlock(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(N)
    ]

    encoder = Encoder(nn.ModuleList(encoder_layers))
    decoder = Decoder(nn.ModuleList(decoder_layers))
    projection_layer = ProjectionLayer(d_model, vocab_size)
    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

    # Xavier/Glorot uniform initialisation for all weight matrices (biases and scalars are left alone)
    for weight in (param for param in transformer.parameters() if param.dim() > 1):
        nn.init.xavier_uniform_(weight)

    return transformer

## Train the model using either the Transformer defined above or the GPT-2 model (transfer learning)

In [None]:
class ChatbotDataset(Dataset):
    """Turn (context, response) pairs into fixed-length tensors for an encoder-decoder model.

    Each item of the wrapped ``dataset`` must be a mapping with ``'context'`` and
    ``'response'`` strings. Both are tokenized and truncated to ``seq_len - 2`` tokens,
    wrapped in SOS/EOS, and right-padded with PAD to exactly ``seq_len``.
    """

    # Special-token id tensors seen across all instances, deduplicated by id value
    special_tokens = []

    def __init__(self, dataset: dataset.Subset, tokenizer: Union[Tokenizer, GPT2Tokenizer], seq_len: int, logger: logging.Logger, transfer_learning: bool):
        super().__init__()

        self.dataset = dataset
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.logger = logger
        self.transfer_learning = transfer_learning

        # The word-level tokenizer and the GPT-2 tokenizer expose their special tokens differently
        if not transfer_learning:
            self.sos_token = torch.tensor([tokenizer.token_to_id('[SOS]')], dtype=torch.int64)
            self.eos_token = torch.tensor([tokenizer.token_to_id('[EOS]')], dtype=torch.int64)
            self.pad_token = torch.tensor([tokenizer.token_to_id('[PAD]')], dtype=torch.int64)
            self.unk_token = torch.tensor([tokenizer.token_to_id('[UNK]')], dtype=torch.int64)
        else:
            self.sos_token = torch.tensor([tokenizer.convert_tokens_to_ids('<|sos|>')])
            self.eos_token = torch.tensor([tokenizer.eos_token_id])
            self.pad_token = torch.tensor([tokenizer.convert_tokens_to_ids('<|pad|>')])
            self.unk_token = torch.tensor([tokenizer.convert_tokens_to_ids('<|unk|>')])

        # Tensors hash by identity, so the old set-based merge never removed duplicates and the
        # class-level list grew by four entries per instance; deduplicate on the id values instead
        merged = list(self.__class__.special_tokens)
        seen_ids = {int(token.item()) for token in merged}
        for token in (self.sos_token, self.eos_token, self.pad_token, self.unk_token):
            token_id = int(token.item())
            if token_id not in seen_ids:
                seen_ids.add(token_id)
                merged.append(token)
        self.__class__.special_tokens = merged

    def __len__(self):
        # Length of this split (e.g. a random_split Subset), not of the underlying full dataset
        return len(self.dataset)

    def __getitem__(self, index: int):
        # Index through the Subset so its own indices are honoured; indexing
        # self.dataset.dataset directly ignored the train/validation split entirely
        sample = self.dataset[index]
        context = sample['context']
        response = sample['response']

        if self.transfer_learning:
            context_tokens = self.tokenizer.encode(context)
            response_tokens = self.tokenizer.encode(response)
        else:
            context_tokens = self.tokenizer.encode(context).ids
            response_tokens = self.tokenizer.encode(response).ids

        # Leave room for the SOS and EOS tokens
        max_tokens = self.seq_len - 2
        context_tokens = context_tokens[:max_tokens]
        response_tokens = response_tokens[:max_tokens]

        # PAD tokens are appended after EOS to reach seq_len
        num_enc_padding_token = self.seq_len - len(context_tokens) - 2
        num_dec_padding_token = self.seq_len - len(response_tokens) - 2

        # Only reachable when seq_len < 2 (truncation above guarantees room otherwise)
        if num_enc_padding_token < 0 or num_dec_padding_token < 0:
            self.logger.error(f'Sentence is too long - Context Token Length: {len(context_tokens)}, Response Token Length: {len(response_tokens)}')
            raise ValueError(f'Sentence is too long - Context Token Length: {len(context_tokens)}, Response Token Length: {len(response_tokens)}')

        context_ids = torch.tensor(context_tokens, dtype=torch.int64)
        response_ids = torch.tensor(response_tokens, dtype=torch.int64)

        # Encoder input: [SOS] context [EOS] [PAD]...
        encoder_input = torch.cat([self.sos_token, context_ids, self.eos_token, self.pad_token.repeat(num_enc_padding_token)])
        # Decoder input: [SOS] response [EOS] [PAD]...
        decoder_input = torch.cat([self.sos_token, response_ids, self.eos_token, self.pad_token.repeat(num_dec_padding_token)])
        # Label: the decoder input shifted left by one, i.e. response [EOS] [PAD]...
        label = torch.cat([response_ids, self.eos_token, self.pad_token.repeat(num_dec_padding_token + 1)])

        assert encoder_input.size(0) == self.seq_len, f'sequence length must be {self.seq_len}, encoder input length is {encoder_input.size(0)}'
        assert decoder_input.size(0) == self.seq_len, f'sequence length must be {self.seq_len}, decoder input length is {decoder_input.size(0)}'
        assert label.size(0) == self.seq_len, f'sequence length must be {self.seq_len}, label length is {label.size(0)}'

        # Encoder mask: (1, 1, seq_len), True for non-PAD positions; broadcasts over heads and queries
        encoder_mask = (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(1)

        # Decoder mask: (1, seq_len, seq_len), non-PAD keys that are not in the future
        subsequent_mask = torch.triu(torch.ones((self.seq_len, self.seq_len), dtype=torch.uint8), diagonal=1).to(torch.bool)
        decoder_mask = (decoder_input != self.pad_token).unsqueeze(0).unsqueeze(1) & ~subsequent_mask

        return {
            'encoder_input': encoder_input,
            'decoder_input': decoder_input,
            'encoder_mask': encoder_mask,
            'decoder_mask': decoder_mask,
            'label': label,
            'context': context,
            'response': response
        }


def remove_stutter(text: str) -> str:
    """Collapse hyphenated stutters such as ``"T-The"`` or ``"W-W-What"`` into the plain word.

    A stutter is a single letter followed by one or more hyphen-separated repetitions of
    that letter (compared case-insensitively) and then the rest of the word. The stutter is
    replaced by the first letter plus the rest of the word, so the first letter's casing is kept.

    Args:
        text: Raw text that may contain stutters.

    Returns:
        The text with every stutter collapsed.
    """
    # ``(?:-\1)+`` consumes every repetition, so "w-w-what" collapses to "what" in one pass
    # instead of leaving "w-what" behind.
    stutter_pattern = re.compile(r'\b(\w)(?:-\1)+(\w+)\b', re.IGNORECASE)
    return stutter_pattern.sub(r'\1\2', text)


def remove_most_common_words(sentences: list, top_n: int, stop_words: Optional[set] = None):
    """Strip stop words, then the ``top_n`` most frequent remaining words, from every sentence.

    Args:
        sentences: Sentences whose words are separated by whitespace.
        top_n: How many of the most frequent (non stop) words to remove; ``0`` removes none.
        stop_words: Words dropped before counting. Defaults to NLTK's English stop-word list.

    Returns:
        One filtered sentence per input sentence, in the same order. A sentence becomes an
        empty string when all of its words were removed.
    """
    import string
    from collections import Counter

    stop_words = set(stopwords.words('english')) if stop_words is None else set(stop_words)

    # Words are compared in this normalized form (punctuation removed, lowercased) both when
    # counting and when filtering, so "War," and "war" count as the same word.
    translator = str.maketrans('', '', string.punctuation)

    def normalize(word: str) -> str:
        return word.translate(translator).lower()

    processed_sentences = [
        ' '.join(word for word in sentence.split() if word not in stop_words)
        for sentence in sentences
    ]

    # Count normalized words, ignoring tokens that were pure punctuation (normalize to '')
    normalized_words = (normalize(word) for sentence in processed_sentences for word in sentence.split())
    word_counts = Counter(word for word in normalized_words if word)

    # List of (word, count) pairs, most frequent first
    most_common_words = word_counts.most_common(top_n)

    print(f'Top {str(top_n) + " words" if top_n > 1 else "word"} being removed from the text data:\n{most_common_words}')

    # Membership must be tested against the words themselves, not the (word, count) tuples
    words_to_remove = {word for word, _ in most_common_words}

    return [
        ' '.join(word for word in sentence.split() if normalize(word) not in words_to_remove)
        for sentence in processed_sentences
    ]


def preprocess_text_data(dataset: list, seq_len: int, logger: logging.Logger, top_n: int = 1) -> dict:
    """Clean raw text lines and pair consecutive sentences into context/response examples.

    Each line has ``\\N`` markers and stutters removed. It is lowercased and stripped of
    everything except ASCII letters, digits and whitespace. Lines with fewer than two words
    are dropped. Longer lines are split into chunks of at most ``seq_len - 2`` words, and the
    ``top_n`` most common words are removed via ``remove_most_common_words``.

    Args:
        dataset: Raw text lines.
        seq_len: Model sequence length; two positions are reserved for the SOS and EOS tokens.
        logger: Logger used to report over-long sentences.
        top_n: Number of most common words to remove from the corpus.

    Returns:
        ``{i: {'context': sentence_i, 'response': sentence_{i+1}}}`` for every pair of
        consecutive cleaned sentences.

    Raises:
        ValueError: If ``seq_len`` leaves no room for a word next to the SOS and EOS tokens.
    """
    # The max sentence length is 2 words less than the sequence length to leave room for sos and eos.
    max_sentence_length = seq_len - 2
    if max_sentence_length < 1:
        # Without this guard the chunking below would slice zero words forever.
        raise ValueError(f'seq_len must be at least 3 to fit SOS, EOS and one word, got {seq_len}')

    # The text might have stutters, weird symbols or characters, or extra white space; remove it all.
    normalized_text = []
    for t in dataset:
        sentence = remove_stutter(' '.join(t.replace('\\N', ' ').split())).lower()
        sentence = re.sub(r'[^0-9A-Za-z\s]', '', sentence)
        sentence = re.sub(r'\s+', ' ', sentence.strip())
        if len(sentence.split(' ')) > 1:
            normalized_text.append(sentence)

    # Texts longer than max_sentence_length are split into consecutive chunks of valid length
    text = []
    for t in normalized_text:
        words = t.split(' ')
        text.extend(
            ' '.join(words[start:start + max_sentence_length])
            for start in range(0, len(words), max_sentence_length)
        )

    # Remove the top n most common words from the text data to avoid overfitting
    text = remove_most_common_words(sentences=text, top_n=top_n)

    # Pair every sentence with the one that follows it
    text_dict = {}
    for i, (context, response) in enumerate(zip(text, text[1:])):
        n_words = len(context.split(' '))
        if n_words > max_sentence_length:
            logger.error(f'Context {i} has {n_words} words, more than the maximum of {max_sentence_length}')

        text_dict[i] = {'context': context, 'response': response}

    return text_dict


def get_or_build_tokenizer(config: dict, dataset: Optional[dict] = None) -> Union[GPT2Tokenizer, Tokenizer]:
    """Return the tokenizer that matches ``config``.

    With ``config['transfer_learning']`` the pretrained GPT-2 tokenizer is returned, extended
    with ``<|sos|>``, ``<|pad|>`` and ``<|unk|>``. Otherwise a word-level tokenizer is
    loaded from ``config['tokenizer_path']``. If that file is missing, one is trained on
    ``dataset`` and saved there.

    Args:
        config: Must contain ``transfer_learning`` and, when it is False, ``tokenizer_path``.
        dataset: Context/response dict, as built by ``preprocess_text_data``. Only needed
            when a word-level tokenizer has to be trained.

    Returns:
        A ``GPT2Tokenizer`` or a ``tokenizers.Tokenizer``.

    Raises:
        ValueError: If a word-level tokenizer must be trained but ``dataset`` is empty or None.
    """
    if config['transfer_learning']:
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

        # Define the special tokens
        sos_token = "<|sos|>"
        pad_token = "<|pad|>"
        unk_token = "<|unk|>"

        # Add the special tokens to the tokenizer
        tokenizer.add_special_tokens({
            'pad_token': pad_token,
            'unk_token': unk_token,
            'additional_special_tokens': [sos_token]
        })

        # NOTE: the vocabulary grew, so the GPT-2 model using this tokenizer must call
        # ``model.resize_token_embeddings(len(tokenizer))``; that is not done here.
        return tokenizer

    # Load the word-level tokenizer from storage when it has already been trained
    tokenizer_path = Path(config['tokenizer_path'])
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    if not dataset:
        raise ValueError(f"No tokenizer found at '{tokenizer_path}' and no dataset was provided to train one")

    tokenizer = Tokenizer(WordLevel(unk_token='[UNK]'))
    tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
    trainer = WordLevelTrainer(
        special_tokens=['[UNK]', '[PAD]', '[SOS]', '[EOS]'],
        min_frequency=2
    )

    # Every context plus the final response; for the dict built by preprocess_text_data this
    # covers each sentence exactly once.
    lexicon = [example['context'] for example in dataset.values()] + [dataset[len(dataset) - 1]['response']]
    tokenizer.train_from_iterator(lexicon, trainer=trainer)

    tokenizer_path.parent.mkdir(parents=True, exist_ok=True)
    tokenizer.save(str(tokenizer_path))
    return tokenizer


def split_dataset(
        dataset: dict,
        tokenizer: Union[Tokenizer, GPT2Tokenizer],
        config: dict,
        logger: logging.Logger
) -> (ChatbotDataset, ChatbotDataset):
    """Randomly split ``dataset`` 90/10 and wrap both parts in ``ChatbotDataset`` objects.

    Args:
        dataset: Index-keyed context/response examples.
        tokenizer: Tokenizer passed to both ``ChatbotDataset`` objects.
        config: Supplies ``seq_len`` and ``transfer_learning``.
        logger: Logger passed to both ``ChatbotDataset`` objects.

    Returns:
        ``(train_dataset, val_dataset)``.
    """
    total = len(dataset)
    n_train = int(0.9 * total)
    train_subset, val_subset = random_split(dataset, (n_train, total - n_train))

    # Constructor arguments shared by the training and validation datasets
    shared_kwargs = dict(
        tokenizer=tokenizer,
        seq_len=config['seq_len'],
        logger=logger,
        transfer_learning=config['transfer_learning'],
    )

    return (
        ChatbotDataset(dataset=train_subset, **shared_kwargs),
        ChatbotDataset(dataset=val_subset, **shared_kwargs),
    )


def custom_collate_fn(batch: list):
    """Collate dataset items into a batch, dropping malformed or non-finite items.

    An item is kept only if it has every required key and its ``encoder_input`` and
    ``decoder_input`` are non-None tensors without NaN or Inf values. Dropped items are
    reported with ``print``.

    Args:
        batch: Items produced by the dataset (dicts).

    Returns:
        ``default_collate`` of the kept items, with non-finite encoder/decoder inputs replaced
        by zeros. If nothing survives, a dict with the same keys is returned holding empty
        tensors and empty lists.
    """
    required_keys = ('encoder_input', 'decoder_input', 'encoder_mask', 'decoder_mask', 'context', 'response', 'label')

    filtered_batch = []
    for idx, item in enumerate(batch):
        try:
            if not all(key in item for key in required_keys):
                print(f"Missing keys at index {idx} - {item}")
                continue

            encoder_input = item['encoder_input']
            decoder_input = item['decoder_input']

            if (encoder_input is not None and decoder_input is not None and
                    not torch.isnan(encoder_input).any() and not torch.isnan(decoder_input).any() and
                    not torch.isinf(encoder_input).any() and not torch.isinf(decoder_input).any()):
                filtered_batch.append(item)
            else:
                print(f"Invalid data (None, NaN, Inf) found at index {idx} - encoder_input: {encoder_input}, decoder_input: {decoder_input}")
        # Exception rather than BaseException so KeyboardInterrupt / SystemExit still stop the program
        except Exception as e:
            print(f"Error at index {idx} - {item}: {e}")

    if not filtered_batch:
        return {
            'encoder_input': torch.tensor([]),
            'decoder_input': torch.tensor([]),
            'encoder_mask': torch.tensor([]),
            'decoder_mask': torch.tensor([]),
            'label': torch.tensor([]),
            'context': [],
            'response': [],
        }

    collated = default_collate(filtered_batch)

    # Defensive: replace any non-finite value left in the inputs with zeros
    for key in ('encoder_input', 'decoder_input'):
        collated[key] = torch.where(
            torch.isfinite(collated[key]), collated[key], torch.zeros_like(collated[key])
        )

    return collated


def print_gpu_utilization(logger: logging.Logger, scale: str = 'MB'):
    """Log the memory currently in use on GPU 0, as reported by NVML.

    NVML errors (for example no NVIDIA driver or GPU on a CPU-only machine) are logged as a
    warning instead of raised, so this diagnostic call cannot abort a training run.

    Args:
        logger: Logger that receives the message.
        scale: ``'MB'`` reports in units of 1024**2 bytes; any other value uses 1024**3 bytes (GB).
    """
    try:
        pynvml.nvmlInit()
    except pynvml.NVMLError as error:
        logger.warning(f'Could not query GPU memory: {error}')
        return

    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
    except pynvml.NVMLError as error:
        logger.warning(f'Could not query GPU memory: {error}')
        return
    finally:
        # Pair every nvmlInit with nvmlShutdown; this function runs once per epoch.
        pynvml.nvmlShutdown()

    # True division so the two decimals are meaningful (floor division always printed .00)
    if scale == 'MB':
        logger.info(f"GPU memory occupied: {info.used / 1024 ** 2:.2f} MB.")
    else:
        logger.info(f"GPU memory occupied: {info.used / 1024 ** 3:.2f} GB.")


def beam_search_decode(model: Transformer, input_tokens: torch.Tensor, seq_len: int, tokenizer: Tokenizer, device: torch.device, num_beams: int = 3):
    """Decode one response with beam search over summed token log-probabilities.

    Args:
        model: Encoder-decoder model. ``model.forward(src, tgt)`` is expected to return the
            logits tensor, or an object exposing it as ``.logits``; it is indexed as ``[:, -1, :]``.
        input_tokens: Encoded source sequence, already on ``device``. Beams are tracked with
            shape ``(1, length)``, so a batch of one is assumed.
        seq_len: Maximum length of the decoded sequence, SOS included.
        tokenizer: Word-level tokenizer providing the ``[SOS]`` and ``[EOS]`` ids.
        device: Device for the decoded tensors.
        num_beams: Beam width.

    Returns:
        Tensor of shape ``(1, length)`` holding the highest-scoring sequence, starting with SOS.
    """
    model.eval()
    sos_id = tokenizer.token_to_id('[SOS]')
    eos_id = tokenizer.token_to_id('[EOS]')

    # Each beam is (decoded tokens of shape (1, length), cumulative log-probability)
    beams = [(torch.tensor([[sos_id]], dtype=torch.long, device=device), 0.0)]

    with torch.no_grad():
        # seq_len - 1 steps because SOS already occupies one position
        for _ in range(seq_len - 1):
            candidates = []
            for decoded_tokens, log_prob in beams:
                # A finished beam is carried over unchanged instead of being extended past EOS
                if decoded_tokens[0, -1].item() == eos_id:
                    candidates.append((decoded_tokens, log_prob))
                    continue

                output = model.forward(input_tokens, decoded_tokens)
                # The generic Transformer returns the logits tensor itself; HF models wrap it in `.logits`
                logits = getattr(output, 'logits', output)
                next_token_log_probs = torch.log_softmax(logits[:, -1, :], dim=-1)

                # Expand this beam with its `num_beams` most likely next tokens
                top_next_tokens = next_token_log_probs.topk(num_beams, dim=-1)
                for i in range(num_beams):
                    next_token = top_next_tokens.indices[:, i].unsqueeze(0)
                    next_log_prob = top_next_tokens.values[:, i].item()
                    candidates.append((torch.cat((decoded_tokens, next_token), dim=1), log_prob + next_log_prob))

            # Keep only the top `num_beams` sequences
            beams = sorted(candidates, key=lambda beam: beam[1], reverse=True)[:num_beams]

            # Log-probabilities are <= 0, so extending other beams can only lower their scores:
            # once the best beam has finished it can no longer be overtaken.
            if beams[0][0][0, -1].item() == eos_id:
                break

    return beams[0][0]


def nucleus_sampling_decode(model: Transformer, input_tokens: torch.Tensor, seq_len: int, tokenizer: Tokenizer, device: torch.device, top_p: float = 0.9):
    """Decode one response with nucleus (top-p) sampling.

    At each step the most likely tokens are kept until their cumulative probability exceeds
    ``top_p`` (the single most likely token is always kept). The next token is sampled from
    that set, renormalized.

    Args:
        model: Encoder-decoder model. ``model.forward(src, tgt)`` is expected to return the
            logits tensor, or an object exposing it as ``.logits``.
        input_tokens: Encoded source sequence, already on ``device``. A batch of one is assumed.
        seq_len: Maximum length of the decoded sequence, SOS included.
        tokenizer: Word-level tokenizer providing the ``[SOS]`` and ``[EOS]`` ids.
        device: Device for the decoded tensors.
        top_p: Cumulative probability threshold of the nucleus.

    Returns:
        Tensor of shape ``(1, length)`` starting with SOS; it ends with EOS if EOS was sampled.
    """
    model.eval()
    eos_id = tokenizer.token_to_id('[EOS]')
    decoded_tokens = torch.tensor([tokenizer.token_to_id('[SOS]')], dtype=torch.long, device=device).unsqueeze(0)

    with torch.no_grad():
        # seq_len - 1 steps because SOS already occupies one position
        for _ in range(seq_len - 1):
            # Same positional call as the other generic-Transformer decoders in this file
            output = model.forward(input_tokens, decoded_tokens)
            logits = getattr(output, 'logits', output)
            next_token_probs = torch.softmax(logits[:, -1, :], dim=-1)

            # Work in probability-sorted order to build the nucleus
            sorted_probs, sorted_indices = torch.sort(next_token_probs, descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

            # Keep a token when the mass ranked above it is still <= top_p; always keep the first
            nucleus_mask = cumulative_probs <= top_p
            nucleus_mask[..., 1:] = nucleus_mask[..., :-1].clone()
            nucleus_mask[..., 0] = True

            # The mask is in sorted order, so sample there and map the rank back to a vocabulary id
            nucleus_probs = sorted_probs * nucleus_mask
            nucleus_probs = nucleus_probs / nucleus_probs.sum(dim=-1, keepdim=True)
            sampled_rank = torch.multinomial(nucleus_probs, num_samples=1)
            next_token = sorted_indices.gather(-1, sampled_rank)

            decoded_tokens = torch.cat((decoded_tokens, next_token), dim=1)

            # Stop decoding if the end-of-sequence token is produced
            if next_token.item() == eos_id:
                break

    return decoded_tokens


def autoregressive_decode(model: Transformer, input_tokens: torch.Tensor, seq_len: int, tokenizer: Tokenizer, device: torch.device):
    """Greedily decode a batch, returning exactly ``seq_len`` tokens per row.

    Every row starts with SOS. Once a row emits EOS, each later position of that row is
    PAD. Decoding stops early when every row has emitted EOS, and the result is then padded
    with PAD up to ``seq_len``.

    Args:
        model: Generic Transformer whose ``forward(src=..., tgt=...)`` returns logits indexed as
            ``[:, -1, :]``.
        input_tokens: Encoded source batch; its first dimension is the batch size.
        seq_len: Length of every returned row.
        tokenizer: Word-level tokenizer providing the ``[SOS]``, ``[EOS]`` and ``[PAD]`` ids.
        device: Device for the decoded tensor.

    Returns:
        Long tensor of shape ``(batch_size, seq_len)`` on ``device``.
    """
    sos_id = tokenizer.token_to_id('[SOS]')
    eos_id = tokenizer.token_to_id('[EOS]')
    pad_id = tokenizer.token_to_id('[PAD]')

    batch_size = input_tokens.size(0)
    decoded_tokens = torch.full((batch_size, 1), sos_id, dtype=torch.long, device=device)
    # Which rows have already produced EOS
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

    with torch.no_grad():
        for _ in range(seq_len - 1):  # -1 because the SOS token is already in place
            output = model.forward(src=input_tokens, tgt=decoded_tokens)
            next_token = output[:, -1, :].argmax(dim=-1)

            # Rows that already finished only receive padding
            next_token = torch.where(finished, torch.full_like(next_token, pad_id), next_token)
            decoded_tokens = torch.cat((decoded_tokens, next_token.unsqueeze(1)), dim=1)

            # Stop once every row has produced EOS (not necessarily in the same step)
            finished |= next_token == eos_id
            if finished.all().item():
                break

    # Pad the sequence if decoding stopped early
    if decoded_tokens.size(1) < seq_len:
        padding = torch.full((batch_size, seq_len - decoded_tokens.size(1)), pad_id, device=device, dtype=torch.long)
        decoded_tokens = torch.cat((decoded_tokens, padding), dim=1)

    # Only has an effect for degenerate seq_len < 1
    return decoded_tokens[:, :seq_len]


def _gpt2_decoder_logits(model, encoder_input: torch.Tensor, decoder_input: torch.Tensor, tokenizer) -> torch.Tensor:
    """Run GPT-2 on ``[encoder_input | decoder_input]`` and return only the decoder-part logits."""
    # Concatenate encoder and decoder inputs
    combined_input = torch.cat((encoder_input, decoder_input), dim=1)  # (batch_size, seq_len * 2)

    # Attention mask for the combined input: 1 for real tokens, 0 for padding
    encoder_mask = (encoder_input != tokenizer.pad_token_id).long()  # (batch_size, seq_len)
    decoder_mask = (decoder_input != tokenizer.pad_token_id).long()  # (batch_size, seq_len)
    combined_mask = torch.cat((encoder_mask, decoder_mask), dim=1).unsqueeze(1)  # (batch_size, 1, seq_len * 2)

    outputs = model(input_ids=combined_input, attention_mask=combined_mask)

    # Target labels correspond only to the decoder part, so drop the encoder positions
    return outputs.logits[:, encoder_input.size(1):, :]


def train_chatbot(raw_dataset: list, config: dict, train_info: dict, logger: logging.Logger):
    """Train a chatbot on ``raw_dataset`` and save its weights to ``config['model_path']``.

    Training is skipped entirely when a file already exists at ``config['model_path']``.

    Args:
        raw_dataset: Raw text lines; they are cleaned with ``preprocess_text_data``.
        config: Hyper-parameters and paths. Read here: ``model_path``, ``num_epochs``,
            ``seq_len``, ``batch_size``, ``learning_rate``, ``transfer_learning`` and
            ``patience``. The generic transformer also needs ``d_model``, ``N``, ``h``,
            ``dropout`` and ``d_ff``.
        train_info: Metrics accumulator. It must already hold the lists ``train_losses``,
            ``learning_rates``, ``val_losses`` and ``epochs``, plus ``best_val_loss`` and
            ``patience_counter``. The cleaned dataset is stored under ``dataset``.
        logger: Logger for progress messages.

    Returns:
        ``train_info`` updated in place. It is returned untouched when a trained model already exists.
    """
    # Skip training if the model has already been trained and saved
    model_path = Path(config['model_path'])
    if model_path.exists():
        logger.info(f'Found an existing model at {model_path}, skipping training')
        return train_info

    # Values used multiple times
    num_epochs = config['num_epochs']
    transfer_learning = config['transfer_learning']

    # Preprocess the text data so that it is clean and prepared for training a chatbot
    dataset = preprocess_text_data(dataset=raw_dataset, seq_len=config['seq_len'], logger=logger)

    # Store the cleaned text dataset for later analysis
    train_info['dataset'] = dataset

    # Assign the deep learning model to a computing device (GPU or CPU)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f'The neural network will be running on {device}')

    # Build a tokenizer using the cleaned text dataset if it doesn't exist
    tokenizer = get_or_build_tokenizer(config=config, dataset=dataset)

    # Split the dataset into training and validation sets and convert them into torch Dataset objects
    train_dataset, val_dataset = split_dataset(dataset=dataset, tokenizer=tokenizer, config=config, logger=logger)

    # Create DataLoader objects for the training and validation sets
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, collate_fn=custom_collate_fn, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_collate_fn, pin_memory=True)

    # Parameters for model construction and progress reporting
    vocab_size = len(tokenizer.get_vocab())
    total_batches = len(train_loader)
    total_val_batches = len(val_loader)

    # Different models are built depending on whether transfer learning is enabled.
    # Different schedulers and optimizers can be tried to see how fast a minimum is found.
    if transfer_learning:
        model = GPT2LMHeadModel.from_pretrained('gpt2')
        model.to(device)
        model.resize_token_embeddings(len(tokenizer))
        criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.convert_tokens_to_ids('<|pad|>'))
        optimizer = AdamW(model.parameters(), lr=config['learning_rate'])
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0,
                                                    num_training_steps=total_batches * num_epochs)
    else:
        model = build_transformer(
            vocab_size=vocab_size,
            seq_len=config['seq_len'],
            d_model=config['d_model'],
            N=config['N'],
            h=config['h'],
            dropout=config['dropout'],
            d_ff=config['d_ff']
        )
        model.to(device)
        criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id('[PAD]'))
        optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])

        # scheduler = CosineAnnealingLR(optimizer, T_max=total_batches * config['num_epochs'])
        # scheduler = StepLR(optimizer, step_size=5, gamma=0.1)
        scheduler = CyclicLR(
            optimizer,
            base_lr=config['learning_rate'],
            max_lr=1e-3,
            # CyclicLR divides by the cycle size, so it must be at least 1
            step_size_up=max(1, total_batches // 2),
            mode='triangular2'
        )

    # Gradient scaler for mixed precision: keeps small fp16 gradients from underflowing to zero
    scaler = GradScaler()

    # Iterator over the epochs that doubles as a progress bar
    epoch_iterator = tqdm(range(num_epochs), desc='Epoch counter')

    for epoch in epoch_iterator:
        # Keep track of progress and time
        epoch_time = time.time()
        train_loss = 0
        train_batches_used = 0
        last_printed_progress = 0

        # Training loop
        model.train()
        for idx, batch in enumerate(train_loader):
            train_batch_start_time = time.time()
            progress = (idx + 1) / total_batches * 100

            if int(progress) // 10 > last_printed_progress:
                last_printed_progress = int(progress) // 10
                logger.info(f'Epoch {epoch + 1}/{num_epochs} Training Loop is {progress:.2f}% completed')

            # custom_collate_fn returns empty tensors when every item of a batch was invalid
            if batch['encoder_input'].numel() == 0:
                logger.info(f'Skipping empty training batch {idx + 1}')
                continue

            # Send all tensors of the batch to the training device
            encoder_input = batch['encoder_input'].to(device, non_blocking=True)
            decoder_input = batch['decoder_input'].to(device, non_blocking=True)
            encoder_mask = batch['encoder_mask'].to(device, non_blocking=True)
            decoder_mask = batch['decoder_mask'].to(device, non_blocking=True)
            labels = batch['label'].to(device, non_blocking=True)

            # Zero out the gradients from the previous batch
            optimizer.zero_grad()

            # Mixed precision forward pass
            with autocast():
                if transfer_learning:
                    output = _gpt2_decoder_logits(model, encoder_input, decoder_input, tokenizer)
                    loss = criterion(output.reshape(-1, vocab_size), labels.view(-1))
                else:
                    output = model.forward(encoder_input, decoder_input, encoder_mask, decoder_mask)
                    loss = criterion(output.view(-1, vocab_size), labels.view(-1))

            # If NaNs are encountered, skip that batch to avoid runtime exceptions
            if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                logger.info(f"NaN or Inf detected in loss at batch {idx + 1} - {batch['context']}, {batch['response']}")
                continue

            # Backpropagation on the scaled loss
            scaler.scale(loss).backward()

            # Gradients are still multiplied by the loss scale here; unscale them first so that
            # max_norm applies to the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)

            # Update optimizer and scaler
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item()
            train_batches_used += 1

            scheduler.step()

            train_batch_end_time = time.time()
            if train_batch_end_time - train_batch_start_time > 3:
                logger.info(f'batch {idx + 1} took {train_batch_end_time - train_batch_start_time} seconds')

        # Average over the batches that contributed (NaN when every batch was skipped)
        train_loss = train_loss / train_batches_used if train_batches_used else float('nan')

        train_info['train_losses'].append(train_loss)
        train_info['learning_rates'].append(optimizer.param_groups[0]["lr"])

        # Validation loop
        model.eval()
        val_loss = 0
        val_batches_used = 0
        last_printed_val_progress = 0

        with torch.no_grad():
            for idx, batch in enumerate(val_loader):
                val_batch_start_time = time.time()
                progress = (idx + 1) / total_val_batches * 100

                if int(progress) // 10 > last_printed_val_progress:
                    last_printed_val_progress = int(progress) // 10
                    logger.info(f'Epoch {epoch + 1}/{num_epochs} Validation Loop is {progress:.2f}% completed')

                if batch['encoder_input'].numel() == 0:
                    logger.info(f'Skipping empty validation batch {idx + 1}')
                    continue

                encoder_input = batch['encoder_input'].to(device, non_blocking=True)
                encoder_mask = batch['encoder_mask'].to(device, non_blocking=True)
                decoder_input = batch['decoder_input'].to(device, non_blocking=True)
                labels = batch['label'].to(device, non_blocking=True)

                if transfer_learning:
                    output = _gpt2_decoder_logits(model, encoder_input, decoder_input, tokenizer)
                    loss = criterion(output.reshape(-1, vocab_size), labels.view(-1))
                else:
                    # Inference validation: decode greedily, then score that decoded sequence
                    output_tokens = autoregressive_decode(model, encoder_input, config['seq_len'], tokenizer, device)
                    output = model.forward(encoder_input, output_tokens, encoder_mask, None)
                    loss = criterion(output.view(-1, vocab_size), labels.view(-1))

                if torch.isnan(loss) or torch.isinf(loss):
                    logger.info(f"NaN or Inf detected in loss at batch {idx + 1} - {batch['context']}, {batch['response']}")
                    continue

                val_loss += loss.item()
                val_batches_used += 1

                val_batch_end_time = time.time()
                if val_batch_end_time - val_batch_start_time > 3:
                    logger.info(f'batch {idx + 1} took {val_batch_end_time - val_batch_start_time} seconds')

        # Inf when nothing was usable, so such an epoch never counts as an improvement
        val_loss = val_loss / val_batches_used if val_batches_used else float('inf')

        train_info['val_losses'].append(val_loss)
        train_info['epochs'].append(epoch + 1)

        # Early-stopping bookkeeping: reset patience on improvement, otherwise count up
        if val_loss < train_info['best_val_loss']:
            train_info['best_val_loss'] = val_loss
            train_info['patience_counter'] = 0
        else:
            train_info['patience_counter'] += 1

        # Performance console output
        monitoring_string = f'Epoch {epoch + 1}/{num_epochs} took {time.time() - epoch_time:.2f} seconds, Loss: {train_loss},  Validation Loss: {val_loss}, Learning rate: {optimizer.param_groups[0]["lr"]}'
        logger.info(monitoring_string)
        epoch_iterator.set_postfix({f'Epoch {epoch + 1}/{num_epochs}': monitoring_string})
        print_gpu_utilization(logger=logger)

        if train_info['patience_counter'] >= config['patience']:
            logger.info("Early stopping triggered.")
            break

    torch.save(model.state_dict(), config['model_path'])
    return train_info

## Create Chatbot

In [None]:
def load_chatbot(device: torch.device, config: dict, vocab_size: int, tokenizer: Optional[Union[Tokenizer, GPT2Tokenizer]]):
    """Rebuild the chatbot architecture, load its trained weights and put it in eval mode.

    Args:
        device: Device the model is moved to.
        config: Supplies ``transfer_learning`` and ``model_path``. The generic transformer also
            needs ``seq_len``, ``d_model``, ``N``, ``h``, ``dropout`` and ``d_ff``.
        vocab_size: Vocabulary size of the generic transformer.
        tokenizer: Needed when transfer learning is enabled, to size GPT-2's embedding matrix.

    Returns:
        The model on ``device``, in evaluation mode.
    """
    if config['transfer_learning']:
        model = GPT2LMHeadModel.from_pretrained('gpt2')
        # The tokenizer has extra special tokens, so the embeddings must match its length
        model.resize_token_embeddings(len(tokenizer))
    else:
        model = build_transformer(
            vocab_size=vocab_size,
            seq_len=config['seq_len'],
            d_model=config['d_model'],
            N=config['N'],
            h=config['h'],
            dropout=config['dropout'],
            d_ff=config['d_ff']
        )

    # map_location lets weights saved from a GPU be loaded on a CPU-only machine (and vice versa)
    state_dict = torch.load(config['model_path'], map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def preprocess_sentence(sentence: str, tokenizer: Union[Tokenizer, GPT2Tokenizer], config: dict, special_tokens: list, device: torch.device):
    """Encode a user sentence into a ``(1, seq_len)`` tensor laid out as ``SOS, tokens, EOS, PAD...``.

    The sentence is lowercased and reduced to ASCII letters, digits and single spaces. This is
    the same character-level cleaning that ``preprocess_text_data`` applies to the training text.

    Args:
        sentence: Raw user input.
        tokenizer: GPT-2 tokenizer (transfer learning) or word-level ``tokenizers.Tokenizer``.
        config: Supplies ``seq_len`` and ``transfer_learning``.
        special_tokens: Token ids ``[SOS, EOS, PAD, ...]``; only the first three are used.
        device: Device of the returned tensor.

    Returns:
        Long tensor of shape ``(1, seq_len)`` on ``device``.
    """
    seq_len = config['seq_len']

    # Match the training text normalization; otherwise capitalized or punctuated words map to UNK
    sentence = re.sub(r'[^0-9A-Za-z\s]', '', sentence.lower())
    sentence = re.sub(r'\s+', ' ', sentence).strip()

    if config['transfer_learning']:
        tokens = tokenizer.encode(sentence)
    else:
        tokens = tokenizer.encode(sentence).ids

    # Truncate the content before framing it, so the EOS token is never cut off
    tokens = tokens[:max(seq_len - 2, 0)]
    tokens = [special_tokens[0]] + tokens + [special_tokens[1]]
    tokens = tokens + [special_tokens[2]] * (seq_len - len(tokens))
    return torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)


def chatbot_predict(
        model: Union[Transformer, GPT2LMHeadModel],
        sentence: str,
        tokenizer: Union[Tokenizer, GPT2Tokenizer],
        config: dict,
        device: torch.device,
):
    """Generate a reply to ``sentence`` with top-k sampling (k = 10).

    Generation starts from SOS. It stops when the model samples EOS, which is only allowed
    once at least one word has been produced, or when the sequence reaches
    ``config['seq_len']`` tokens. SOS, PAD and UNK are never sampled, and the same token is
    never emitted twice in a row.

    Args:
        model: GPT-2 (transfer learning) or the generic Transformer, already on ``device``.
        sentence: Raw user input.
        tokenizer: Tokenizer matching ``model``.
        config: Supplies ``seq_len`` and ``transfer_learning``.
        device: Device used for all tensors.

    Returns:
        The decoded reply, with special tokens removed.
    """
    model.eval()

    seq_len = config['seq_len']
    transfer_learning = config['transfer_learning']
    top_k = 10

    # Special token ids, ordered [SOS, EOS, PAD, UNK]
    if transfer_learning:
        special_tokens = [tokenizer.convert_tokens_to_ids('<|sos|>'), tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids('<|pad|>'), tokenizer.convert_tokens_to_ids('<|unk|>')]
    else:
        special_tokens = [tokenizer.token_to_id('[SOS]'), tokenizer.token_to_id('[EOS]'), tokenizer.token_to_id('[PAD]'), tokenizer.token_to_id('[UNK]')]
    sos_id, eos_id = special_tokens[0], special_tokens[1]

    # Tokens that must never appear in a reply. EOS is handled separately so generation can end.
    banned_ids = [sos_id, special_tokens[2], special_tokens[3]]

    input_tokens = preprocess_sentence(sentence=sentence, tokenizer=tokenizer, config=config, special_tokens=special_tokens, device=device)

    decoded_tokens = torch.tensor([[sos_id]], dtype=torch.long, device=device)
    last_token_id = None

    with torch.no_grad():
        while decoded_tokens.size(1) < seq_len:
            if transfer_learning:
                # GPT-2 sees [input | decoded so far] as one sequence
                combined_input = torch.cat((input_tokens, decoded_tokens), dim=1)
                encoder_mask = (input_tokens != tokenizer.pad_token_id).long()
                decoder_mask = (decoded_tokens != tokenizer.pad_token_id).long()
                combined_mask = torch.cat((encoder_mask, decoder_mask), dim=1).unsqueeze(1)

                outputs = model(input_ids=combined_input, attention_mask=combined_mask)
                # Keep only the logits of the decoder part
                output_tokens = outputs.logits[:, input_tokens.size(1):, :]
            else:
                output_tokens = model.forward(input_tokens, decoded_tokens)

            next_token_logits = output_tokens[:, -1, :]

            next_token_logits[:, banned_ids] = float('-inf')
            if decoded_tokens.size(1) == 1:
                # Nothing generated yet: forbid EOS so the reply is never empty
                next_token_logits[:, eos_id] = float('-inf')
            if last_token_id is not None:
                # Avoid immediately repeating the previous token
                next_token_logits[:, last_token_id] = float('-inf')

            # Sample from the top-k tokens (k capped by the vocabulary size)
            k = min(top_k, next_token_logits.size(-1))
            top_k_logits, top_k_indices = torch.topk(next_token_logits, k)
            probabilities = torch.softmax(top_k_logits, dim=-1)
            next_token_idx = torch.multinomial(probabilities, 1).item()
            next_token_id = top_k_indices[0, next_token_idx].item()

            if next_token_id == eos_id:
                break

            next_token = torch.tensor([[next_token_id]], dtype=torch.long, device=device)
            decoded_tokens = torch.cat((decoded_tokens, next_token), dim=1)
            last_token_id = next_token_id

    # Row 0 as a list (always a list, even when only SOS was decoded)
    output_ids = decoded_tokens[0].tolist()
    return tokenizer.decode(output_ids, skip_special_tokens=True)


def chat(config: dict, name: str = 'Goku'):
    """Run an interactive command-line chat session with the trained chatbot.

    Gets (or builds) the tokenizer for ``config``, loads the chatbot with
    ``load_chatbot`` and then repeatedly reads a line from stdin, answers it
    with ``chatbot_predict`` and prints the reply.

    Args:
        config: Configuration dictionary; forwarded unchanged to
            ``get_or_build_tokenizer``, ``load_chatbot`` and ``chatbot_predict``.
        name: Display name printed in front of every chatbot reply.

    The session ends when the user types ``quit`` (case-insensitive, surrounding
    whitespace ignored), when stdin reaches end-of-file, or on Ctrl+C.
    """
    # Assign the device that the model will use to perform calculations
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get or build the tokenizer corresponding to the text dataset the chatbot was trained on
    tokenizer = get_or_build_tokenizer(config=config)

    # The vocab size is needed to build the transformer model before loading the chatbot into memory
    vocab_size = len(tokenizer.get_vocab())
    chatbot = load_chatbot(device=device, config=config, vocab_size=vocab_size, tokenizer=tokenizer)

    # Start chatting with the chatbot!
    print('Type `quit` to end chatting')
    while True:
        try:
            sentence = input('You: ')
        except (EOFError, KeyboardInterrupt):
            # stdin was closed or the user interrupted: end the session cleanly instead of crashing
            print()
            break

        sentence = sentence.strip()
        if not sentence:
            # Nothing to answer; prompt again rather than feeding an empty sentence to the model
            continue
        if sentence.lower() == 'quit':
            break

        output = chatbot_predict(chatbot, sentence, tokenizer, config, device)
        print(f'{name}: {output}')

## Train the model and run analytics

In [None]:
# Create a config dictionary in order to change deep learning hyperparameters
config = {
    'subcategory': 'World War II',
    'dropout': 0.1,
    'seq_len': 32,
    'batch_size': 128,
    'd_model': 512,
    'h': 8,
    'N': 6,
    'd_ff': 2048,
    'num_epochs': 100,
    'learning_rate': 1e-5,
    'patience': 5,
    'transfer_learning': False,
}

# Create automated naming: the model file and the log share one run name, while the tokenizer file
# only depends on the subcategory
suffix = '_TL' if config['transfer_learning'] else ''
subcategory = config['subcategory'].replace(' ', '_')
run_name = f'{subcategory}_LLM_{config["num_epochs"]}{suffix}'
config['model_path'] = f'./{run_name}.pt'
config['tokenizer_path'] = f'./{subcategory}_tokenizer.json'

# Create a logger object to track training progress.
# `__file__` is not defined when this cell runs inside a Jupyter kernel, so fall back to a path in the
# current working directory; create_logger derives the location of its Logs directory from this path.
log_anchor = globals().get('__file__', str(Path.cwd() / 'chatbot.ipynb'))
logger = create_logger(__name__, log_anchor, f'{run_name}_Chatbot')

# Load data into memory using project gutenberg's library
load_data_time = time.time()
txt_links, all_sentences = get_text_data(subcategory=config['subcategory'], logger=logger)
logger.info(f'Loading data took {time.time() - load_data_time:.2f} seconds')

# Run preliminary analytics on the text data
word_analytics(sentences=all_sentences)

# Instantiate a dictionary to track training metrics
train_info = {
    'train_losses': [],
    'val_losses': [],
    'learning_rates': [],
    'epochs': [],
    'best_val_loss': float('inf'),
    'patience_counter': 0,
}

# Train the chatbot
train_info = train_chatbot(raw_dataset=all_sentences, config=config, train_info=train_info, logger=logger)

# Run performance analytics only when all per-epoch metric lists were filled in lockstep
metric_keys = ('train_losses', 'val_losses', 'learning_rates', 'epochs')
if len({len(train_info[key]) for key in metric_keys}) == 1:
    # `train_info` starts without a 'dataset' entry, so it must be added by train_chatbot; guard against
    # its absence (or emptiness) instead of failing with KeyError after a full training run.
    # Named `trained_dataset` to avoid shadowing the `torch.utils.data.dataset` module imported at the top.
    trained_dataset = train_info.get('dataset')
    if trained_dataset:
        n_strings = len(trained_dataset.keys())
        # NOTE(review): assumes trained_dataset maps consecutive int indices 0..n-1 to {'context', 'response'}
        # pairs where each response is the next context, so every context plus the final response
        # recovers the preprocessed sentence sequence — confirm against train_chatbot
        preprocessed_sentences = [trained_dataset[i]['context'] for i in range(n_strings)]
        preprocessed_sentences.append(trained_dataset[n_strings - 1]['response'])
        word_analytics(sentences=preprocessed_sentences)
    else:
        logger.warning('train_info has no preprocessed dataset; skipping post-training word analytics')
    plot_losses(train_info=train_info)

In [None]:
# Open an interactive session with the trained chatbot using the configuration defined above
chat(config)