Natural Language Processing Tutorial
======

This is the tutorial of the 2024 [Mediterranean Machine Learning Summer School](https://www.m2lschool.org/) on Natural Language Processing!

This tutorial will explore the fundamental aspects of Natural Language Processing (NLP). Basic Python programming skills are expected.
Prior knowledge of standard NLP techniques (e.g. text tokenization and classification with ML) is beneficial but optional when working through the notebooks as they assume minimal prior knowledge.

This tutorial combines detailed analysis and development of essential NLP concepts via custom (i.e. from scratch) implementations. Other necessary NLP components will be developed using PyTorch's NLP library implementations. As a result, the tutorial offers deep understanding and facilitates easy usage in future applications.

## Outline

* Part I: Introduction to Text Tokenization and Classification
  *  Text Classification: Simple Classifier
  *  Text Classification: Encoder-only Transformer

* Part II: Introduction to Decoder-only Transformer and Sparse Mixture of Experts Architecture
  *  Text Generation: Decoder-only Transformer
  *  Text Generation: Decoder-only Transformer + MoE

* Part III: Introduction to Parameter Efficient Fine-tuning
  *  Fine-tuning the full Pre-trained Models
  *  Fine-tuning using Low-Rank Adaptation of Large Language Models (LoRA)

## Notation

* Sections marked as [📝] contain cells with missing code that you should complete.
* Sections marked with [📚] contain cells that you should read and modify to understand how your changes alter the obtained results.
* External resources are mentioned with [✨]. These provide valuable supplementary information for this tutorial and offer opportunities for further in-depth exploration of the topics covered.
* Sections that contain code that test the functionality of other sections are marked with [✍]. You are more that welcome to modify these sections so that you can understand code functionality.


## Libraries

This tutorial leverages [PyTorch](https://pytorch.org/) for neural network implementation and training, complemented by standard Python libraries for data processing and the [Hugging Face](https://huggingface.co/) datasets library for accessing NLP resources.

GPU access is recommended for optimal performance, particularly for model training and text generation. While all code can run on CPU, a CUDA-enabled environment will significantly speed up these processes.

## Credits

The tutorial is created by:

* [Georgios Peikos](https://www.linkedin.com/in/peikosgeorgios/)
* [Luca Herranz-Celotti](http://LuCeHe.github.io)

It is inspired by and synthesizes various online resources, which are cited throughout for reference and further reading.

## Note for Colab users

To grab a GPU (if available), make sure you go to `Edit -> Notebook settings` and choose a GPU under `Hardware accelerator`





---



# Part II: Introduction to the decoder-only Transformers architecture and Sparse Mixture of Expert

We create a decoder-only Transformer architecture from the bottom up, including a custom text tokenizer and an efficient dataset handler. We will explore all essential components of this architecture, train the model, and show its capabilities in text generation.

Then, we will enhance our base model by incorporating a gating function and implementing a sparse mixture of experts.



---



# Decoder-only Transformer Architecture

The decoder-only transformer architecture consists of multiple identical blocks stacked sequentially. Each block is composed of two main elements:
- A masked multi-head self-attention mechanism.
- A feed-forward neural network.

These components are typically encapsulated within residual connections and layer normalization. In this section, we will explore the internal structure of these blocks in greater depth and provide a practical PyTorch implementation.

![Decoder Only Architecture](https://drive.google.com/uc?id=1ksROxQxf3b7dlBUoIQggzyLeBaPO-AQn)




## Importing Libraries

In [26]:
!pip install datasets

import math
from collections import Counter
from typing import List, Tuple, Union
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

from datasets import load_dataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


## 📝 Text Tokenization from Scratch

Tokenization is a fundamental step in NLP that converts raw text into a format that systems can understand and process. It enables the transformation of variable-length text sequences into fixed-size numerical representations, which is crucial for input to neural network models.

Here, we create a simple text tokenizer for basic word-level tokenization tasks.
The tokenizer can be improved so that:

1.   Methods for handling very large vocabularies (e.g., frequency thresholding)
2.   Support for n-grams or phrase detection
3.   Handle punctuation. For instance now, tokens like "word." and "word" are being treated differently.

Also, we create a testing function to showcase the codes behavior.

**✨ Additional Resources:**

*   Overview of hugging Face tokenizers [Link-huggingface](https://huggingface.co/docs/transformers/en/tokenizer_summary)

In [2]:
class SimpleTokenizer:
    def __init__(self):
        """Initialize the tokenizer with special tokens and prepare vocabulary structures."""
        # Special tokens are used for various purposes in NLP tasks:
        # <PAD>: Used for padding sequences to a fixed length
        # <UNK>: Represents unknown words not in the vocabulary
        # <SOS>: Marks the start of a sequence
        # <EOS>: Marks the end of a sequence
        self.special_tokens = ["<PAD>", "<UNK>", "<SOS>", "<EOS>"]

        # word_to_idx: Maps words to unique integer indices
        # This is crucial for converting text into a format that neural networks can process
        self.word_to_idx = {token: idx for idx, token in enumerate(self.special_tokens)}

        # idx_to_word: The reverse mapping of word_to_idx
        # This is used for converting model outputs back into readable text
        self.idx_to_word = {idx: token for idx, token in enumerate(self.special_tokens)}

        # Counter object to keep track of word frequencies in the corpus
        self.word_count = Counter()

    def fit(self, texts: List[str]) -> None:
        """Build the vocabulary from a list of texts."""
        # Count the frequency of each word in the entire corpus
        for text in texts:
            self.word_count.update(text.split())

        # Add each unique word to the vocabulary
        # We assign a unique index to each word, which the model will use to represent words
        for word in self.word_count:
            if word not in self.word_to_idx:
                idx = len(self.word_to_idx)
                self.word_to_idx[word] = idx
                self.idx_to_word[idx] = word

    def encode(self, text: str) -> List[int]:
        """Convert a text string to a list of indices."""
        # This method is used to prepare input for the model
        # It converts each word to its corresponding index
        # If a word is not in the vocabulary, it uses the <UNK> token
        return [self.word_to_idx.get(word, self.word_to_idx["<UNK>"]) for word in text.split()]

    def decode(self, indices: List[int]) -> str:
        """Convert a list of indices back to a text string."""
        # This method is used to convert model output back into readable text
        # It maps each index back to its corresponding word
        return " ".join([self.idx_to_word.get(idx, "<UNK>") for idx in indices])

    def encode_batch(self, texts: List[str]) -> List[List[int]]:
        """Convert a batch of text strings to lists of indices."""
        return [self.encode(text) for text in texts]

    def decode_batch(self, batch_indices: List[List[int]]) -> List[str]:
        """Convert a batch of lists of indices back to text strings."""
        return [self.decode(indices) for indices in batch_indices]

    def show_vocab(self):
        """Display the vocabulary."""
        # Useful for debugging and understanding the tokenizer's state
        print("Vocabulary:")
        for word, idx in self.word_to_idx.items():
            print(f"{word}: {idx}")

    def __len__(self):
        """Return the size of the vocabulary."""
        # The vocabulary size is an important parameter for the model
        # It determines the dimensionality of the model's output layer
        return len(self.word_to_idx)

### ✍ Testing the Tokenizer

This testing function shows examples of text tokenization presenting also extreme use cases.

In [3]:
def test_tokenizer():
    print("\nTesting SimpleTokenizer")
    print("=" * 30)

    # Sample texts
    texts = [
        "The quick brown fox jumps over the lazy dog.",
        "Pack my box with five dozen liquor jugs!",
        "How vexingly quick daft zebras jump!",
        "This is a sentence with some punctuation, including commas.",
        "This text contains an unknown word: monkey",
        ""  # Empty string to test edge case
    ]

    # Initialize and fit the tokenizer
    tokenizer = SimpleTokenizer()
    tokenizer.fit(texts)

    # Display vocabulary
    tokenizer.show_vocab()
    print(f"\nVocabulary size: {len(tokenizer)}")

    # Test encoding and decoding
    print("\nEncoding and Decoding Test:")
    for text in texts:
        encoded = tokenizer.encode(text)
        decoded = tokenizer.decode(encoded)
        print(f"\nOriginal: {text}")
        print(f"Encoded : {encoded}")
        print(f"Decoded : {decoded}")
        print(f"Match   : {'✓' if text.strip().lower() == decoded.strip().lower() else '✗'}")

    # Test unknown word handling
    print("\nUnknown Word Handling Test:")
    unknown_text = "This text contains an unknown word: xylophone"
    encoded_unknown = tokenizer.encode(unknown_text)
    decoded_unknown = tokenizer.decode(encoded_unknown)
    print(f"Original: {unknown_text}")
    print(f"Encoded : {encoded_unknown}")
    print(f"Decoded : {decoded_unknown}")

    # Test special tokens
    print("\nSpecial Tokens Test:")
    special_text = "< SOS > This is a test sentence <EOS>"
    encoded_special = tokenizer.encode(special_text)
    decoded_special = tokenizer.decode(encoded_special)
    print(f"Original: {special_text}")
    print(f"Encoded : {encoded_special}")
    print(f"Decoded : {decoded_special}")

    # Test case sensitivity
    print("\nCase Sensitivity Test:")
    case_text = "The Quick Brown Fox"
    encoded_case = tokenizer.encode(case_text)
    decoded_case = tokenizer.decode(encoded_case)
    print(f"Original: {case_text}")
    print(f"Encoded : {encoded_case}")
    print(f"Decoded : {decoded_case}")

print("\nChecking the tokenizer's outputs")
test_tokenizer()


Checking the tokenizer's outputs

Testing SimpleTokenizer
Vocabulary:
<PAD>: 0
<UNK>: 1
<SOS>: 2
<EOS>: 3
The: 4
quick: 5
brown: 6
fox: 7
jumps: 8
over: 9
the: 10
lazy: 11
dog.: 12
Pack: 13
my: 14
box: 15
with: 16
five: 17
dozen: 18
liquor: 19
jugs!: 20
How: 21
vexingly: 22
daft: 23
zebras: 24
jump!: 25
This: 26
is: 27
a: 28
sentence: 29
some: 30
punctuation,: 31
including: 32
commas.: 33
text: 34
contains: 35
an: 36
unknown: 37
word:: 38
monkey: 39

Vocabulary size: 40

Encoding and Decoding Test:

Original: The quick brown fox jumps over the lazy dog.
Encoded : [4, 5, 6, 7, 8, 9, 10, 11, 12]
Decoded : The quick brown fox jumps over the lazy dog.
Match   : ✓

Original: Pack my box with five dozen liquor jugs!
Encoded : [13, 14, 15, 16, 17, 18, 19, 20]
Decoded : Pack my box with five dozen liquor jugs!
Match   : ✓

Original: How vexingly quick daft zebras jump!
Encoded : [21, 22, 5, 23, 24, 25]
Decoded : How vexingly quick daft zebras jump!
Match   : ✓

Original: This is a sentence wi

## 📚 TextDataset: Efficient Text Processing

The TextDataset class is a crucial component in preparing text data for deep learning models, implementing a sliding window approach that allows processing of variable-length texts while maintaining context.

This class bridges the gap between raw text data and the input requirements of neural networks, handling tasks such as tokenization, padding, and attention mask generation, which are essential for training effective sequence models like Transformers.

**✨ Additional Resources:**

*   Padding and truncation [Link-huggingface](https://huggingface.co/docs/transformers/en/pad_truncation)


In [4]:
class TextDataset(Dataset):
    def __init__(self, texts: List[str], tokenizer: SimpleTokenizer, max_length: int, overlap: int = 50):
        """
        Initialize the TextDataset with sliding window functionality.

        Args:
            texts (List[str]): List of input texts.
            tokenizer (SimpleTokenizer): Tokenizer object for encoding texts.
            max_length (int): Maximum length of encoded sequences.
            overlap (int): Number of overlapping tokens between windows.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.overlap = overlap
        self.data = []
        self.attention_masks = []
        self.document_map = []  # Maps each window to its original document
        self.original_texts = texts  # Store original texts

        for doc_idx, text in enumerate(texts):
            tokens = self.tokenizer.encode(text)
            windows = self.create_sliding_windows(tokens)

            for window in windows:
                attention_mask = [1] * len(window)  # 1 for real tokens

                # Pad if necessary
                if len(window) < max_length:
                    padding_length = max_length - len(window)
                    window = window + [self.tokenizer.word_to_idx["<PAD>"]] * padding_length
                    attention_mask = attention_mask + [0] * padding_length  # 0 for padding in attention mask

                self.data.append(window)
                self.attention_masks.append(attention_mask)
                self.document_map.append(doc_idx)

    def create_sliding_windows(self, tokens: List[int]) -> List[List[int]]:
        """
        Create sliding windows from a list of tokens.

        Args:
            tokens (List[int]): List of token ids.

        Returns:
            List[List[int]]: List of token windows.
        """
        windows = []
        # Calculate stride: how many tokens to move for each new window
        # -1 accounts for the added <SOS> token at the start of each window
        stride = self.max_length - self.overlap - 1

        for start in range(0, len(tokens), stride):
            # Create a window starting with <SOS> token
            window = [self.tokenizer.word_to_idx["<SOS>"]] + tokens[start:start + self.max_length - 1]
            if len(window) < self.max_length:
                # This is the last window, add <EOS> token
                window.append(self.tokenizer.word_to_idx["<EOS>"])
            windows.append(window)

        return windows

    def get_original_document(self, doc_idx: int) -> str:
        """Retrieve the original document text."""
        if 0 <= doc_idx < len(self.original_texts):
            return self.original_texts[doc_idx]
        else:
            raise IndexError(f"Document index {doc_idx} is out of range.")

    def get_document_length(self, doc_idx: int) -> int:
        """Get the number of tokens in the original document."""
        if 0 <= doc_idx < len(self.original_texts):
            return len(self.tokenizer.encode(self.original_texts[doc_idx]))
        else:
            raise IndexError(f"Document index {doc_idx} is out of range.")

    def window_to_document_position(self, window_idx: int, token_idx: int) -> Tuple[int, int]:
        """Map a position in a window back to its position in the original document."""
        if 0 <= window_idx < len(self.data):
            doc_idx = self.document_map[window_idx]
            doc_windows = self.get_document_windows(doc_idx)
            # Find which window of the document this is
            relative_window_idx = doc_windows.index(window_idx)
            # Calculate the start position of this window in the document
            window_start = relative_window_idx * (self.max_length - self.overlap - 1)
            # -1 to account for <SOS> token at the start of each window
            return doc_idx, window_start + token_idx - 1
        else:
            raise IndexError(f"Window index {window_idx} is out of range.")

    def __len__(self) -> int:
        """Get the number of windows in the dataset."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        Get a sample from the dataset.

        Args:
            idx (int): Index of the sample.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, int]:
                A tuple containing (token_ids, attention_mask, document_index).
        """
        if 0 <= idx < len(self.data):
            # Add an extra dimension to make it batch-first (batch_size=1)
            return (torch.tensor(self.data[idx]).unsqueeze(0),
                    torch.tensor(self.attention_masks[idx]).unsqueeze(0),
                    self.document_map[idx])
        else:
            raise IndexError(f"Index {idx} is out of range.")


    def get_document_windows(self, doc_idx: int) -> List[int]:
        """
        Get all window indices for a specific document.

        Args:
            doc_idx (int): Index of the document.

        Returns:
            List[int]: List of window indices belonging to the document.
        """
        return [i for i, doc in enumerate(self.document_map) if doc == doc_idx]

### ✍ Testing the Dataset Processing

This testing function shows how the TextDataset and SimpleTokenizer classes work together.

In [5]:
def test_sliding_window_dataset():
    print("\n--- Testing Sliding Window Dataset ---\n")

    texts = [
        "This is a short sentence.",
        "This is a much longer sentence that will be split into multiple windows to demonstrate the sliding window approach. It contains enough tokens to create at least two or three windows depending on the chosen maximum length and overlap.",
        "Another sentence of medium length that might create two windows.",
        "",  # Empty text to test edge case
        "Short."  # Very short text to test edge case
    ]

    try:
        tokenizer = SimpleTokenizer()
        tokenizer.fit(texts)

        max_length = 16
        overlap = 5
        dataset = TextDataset(texts, tokenizer, max_length, overlap)

        print(f"Dataset configuration:")
        print(f"  Max length: {max_length}")
        print(f"  Overlap: {overlap}")
        print(f"  Total windows: {len(dataset)}")
        print(f"  Vocabulary size: {len(tokenizer)}\n")

        for doc_idx, text in enumerate(texts):
            print(f"Document {doc_idx}:")
            print(f"  Original text: '{text}'")
            print(f"  Original length: {len(text.split())}")

            window_indices = dataset.get_document_windows(doc_idx)
            print(f"  Number of windows: {len(window_indices)}")

            for i, window_idx in enumerate(window_indices):
                tokens, attention_mask, _ = dataset[window_idx]
                # Remove the batch dimension for decoding
                decoded = tokenizer.decode(tokens.squeeze(0).tolist())
                print(f"\n    Window {i}:")
                print(f"    Tokens shape: {tokens.shape}")
                print(f"    Tokens: {tokens.squeeze(0).tolist()}")
                print(f"    Attention mask shape: {attention_mask.shape}")
                print(f"    Attention mask: {attention_mask.squeeze(0).tolist()}")
                print(f"    Decoded: '{decoded}'")
                print(f"    Window length: {tokens.size(1)}")  # Use size(1) for sequence length

            print("\n" + "-"*50)

        tokenizer.show_vocab()

    except Exception as e:
        print(f"An error occurred: {str(e)}")

    print("\n--- End of Test ---")

# Run the test
test_sliding_window_dataset()


--- Testing Sliding Window Dataset ---

Dataset configuration:
  Max length: 16
  Overlap: 5
  Total windows: 7
  Vocabulary size: 48

Document 0:
  Original text: 'This is a short sentence.'
  Original length: 5
  Number of windows: 1

    Window 0:
    Tokens shape: torch.Size([1, 16])
    Tokens: [2, 4, 5, 6, 7, 8, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    Attention mask shape: torch.Size([1, 16])
    Attention mask: [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    Decoded: '<SOS> This is a short sentence. <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>'
    Window length: 16

--------------------------------------------------
Document 1:
  Original text: 'This is a much longer sentence that will be split into multiple windows to demonstrate the sliding window approach. It contains enough tokens to create at least two or three windows depending on the chosen maximum length and overlap.'
  Original length: 39
  Number of windows: 4

    Window 0:
    Tokens shape: torch.Size

## 📝 Positional Encoding

Positional Encoding adds information about the position of each token in the sequence. This is necessary because the self-attention mechanism in Transformers doesn't inherently have a notion of token order.


\begin{equation}
PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{\frac{2i}{d_{model}}}}\right)
\end{equation}

\begin{equation}
PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{\frac{2i}{d_{model}}}}\right)
\end{equation}



**✨ Additional Resources:**

*   Transformer Architecture: The Positional Encoding [Link-kazemnejad](https://kazemnejad.com/blog/transformer_architecture_positional_encoding/)

*   Positional Encoding in Transformers [Link-geeksforgeeks](https://www.geeksforgeeks.org/positional-encoding-in-transformers/)




In [6]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_len=5000):
        """
        Inputs
            d_model - Hidden dimensionality of the input.
            max_len - Maximum length of a sequence to expect.
        """
        super().__init__()

        # Create matrix of [SeqLen, HiddenDim] representing the positional encoding for max_len inputs
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)

        # register_buffer => Tensor which is not a parameter, but should be part of the modules state.
        # Used for tensors that need to be on the same device as the module.
        # persistent=False tells PyTorch to not add the buffer to the state dict (e.g. when we save the model)
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return x

### ✍ Testing the Positional Encoding for Text

In [7]:
def test_positional_encoding_with_dataset():
    print("\n--- Testing Positional Encoding with Dataset ---\n")

    # Set a fixed seed for reproducibility
    torch.manual_seed(42)

    # Sample texts
    texts = [
        "This is a short sentence.",
        "This is a much longer sentence that will be split into multiple windows. Observe the term overlapping?",
        "Another sentence of medium length."
    ]

    # Initialize tokenizer and fit it to the texts
    tokenizer = SimpleTokenizer()
    tokenizer.fit(texts)

    # Create dataset
    max_length = 10
    overlap = 2
    dataset = TextDataset(texts, tokenizer, max_length, overlap)

    # Initialize positional encoding
    d_model = 16  # Small dimension for demonstration
    pos_encoder = PositionalEncoding(d_model, max_length)

    print(f"Dataset configuration:")
    print(f"  Max length: {max_length}")
    print(f"  Overlap: {overlap}")
    print(f"  Total windows: {len(dataset)}")
    print(f"  Vocabulary size: {len(tokenizer)}")
    print(f"  Embedding dimension: {d_model}\n")

    # Process each window through the positional encoding
    for i in range(len(dataset)):
        tokens, attention_mask, doc_idx = dataset[i]

        # Convert tokens to "embeddings" (just for demonstration)
        pseudo_embeddings = torch.rand(1, tokens.size(1), d_model)  # (batch_size, seq_len, d_model)

        # Apply positional encoding
        encoded = pos_encoder(pseudo_embeddings)

        print(f"Window {i} (from document {doc_idx}):")
        print(f"  Original tokens: {tokens.squeeze(0).tolist()}")
        print(f"  Attention mask: {attention_mask.squeeze(0).tolist()}")
        print(f"  Decoded: '{tokenizer.decode(tokens.squeeze(0).tolist())}'")
        print(f"  Shape after positional encoding: {encoded.shape}")

        # Display the positional encoding effect for all tokens
        print(f"  Positional encoding effect:")
        for j in range(tokens.size(1)):
            if attention_mask[0, j] == 1:  # Only show for non-padding tokens
                print(f"    Token {j}:")
                print(f"      Before: {pseudo_embeddings[0, j, :].tolist()}")
                print(f"      After:  {encoded[0, j, :].tolist()}")

        print()

    print("--- End of Test ---")

# Run the test
test_positional_encoding_with_dataset()


--- Testing Positional Encoding with Dataset ---

Dataset configuration:
  Max length: 10
  Overlap: 2
  Total windows: 5
  Vocabulary size: 27
  Embedding dimension: 16

Window 0 (from document 0):
  Original tokens: [2, 4, 5, 6, 7, 8, 3, 0, 0, 0]
  Attention mask: [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
  Decoded: '<SOS> This is a short sentence. <EOS> <PAD> <PAD> <PAD>'
  Shape after positional encoding: torch.Size([1, 10, 16])
  Positional encoding effect:
    Token 0:
      Before: [0.8822692632675171, 0.9150039553642273, 0.38286375999450684, 0.9593056440353394, 0.3904482126235962, 0.600895345211029, 0.2565724849700928, 0.7936413288116455, 0.9407714605331421, 0.13318592309951782, 0.9345980882644653, 0.5935796499252319, 0.8694044351577759, 0.5677152872085571, 0.7410940527915955, 0.42940449714660645]
      After:  [0.8822692632675171, 1.915004014968872, 0.38286375999450684, 1.9593056440353394, 0.3904482126235962, 1.6008954048156738, 0.2565724849700928, 1.7936413288116455, 0.9407714605331421

## 📝 Masked Multihead Attention Mechanism

As you have implemented the Attention Mechanism in Part I, here you will have to use its ready implementation from Pytorch.

Masked Attention mechanism allows the transformer model to focus on relevant parts of the input sequence while preventing information leakage from future tokens during sequential processing (i.e. we use the term Masked).

Please, visit [Link-pytorch](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) for a detailed description of the MultiheadAttention function in PyTorch.

**✨ Additional Resources:**

*   Multi-head Attention, deep dive [Link-towardsdatascience](https://towardsdatascience.com/transformers-explained-visually-part-3-multi-head-attention-deep-dive-1c1ff1024853)
* Attention Is All You Need (original Transformer paper) [Link-ArXiv](https://arxiv.org/abs/1706.03762)

* A visual explanation of the attention mechanism [Link-youtube](https://www.youtube.com/watch?v=bCz4OMemCcA&t=1208s&ab_channel=UmarJamil)

In [8]:
class MaskedAttention(nn.Module):
    def __init__(self, d_model: int, nhead: int, dropout: float = 0.1):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.d_model = d_model
        self.nhead = nhead

    def generate_square_subsequent_mask(self, sz: int) -> torch.Tensor:
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        # The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0). See: https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#Transformer
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        seq_len = x.size(1)
        attn_mask = self.generate_square_subsequent_mask(seq_len).to(x.device)
        output, _ = self.multihead_attn(x, x, x, attn_mask=attn_mask)
        return output, attn_mask

### ✍ Testing the Multihead Attention

In [9]:
def test_expanded_transformer_components():
    print("\n--- Testing Expanded Transformer Components ---\n")

    # Set a fixed seed for reproducibility
    torch.manual_seed(42)

    # Sample texts
    texts = [
        "This is a short sentence.",
        "This is a much longer sentence that will be split into multiple windows. Observe the term overlapping?",
        "Another sentence of medium length."
    ]

    # Initialize tokenizer and fit it to the texts
    tokenizer = SimpleTokenizer()
    tokenizer.fit(texts)

    # Create dataset
    max_length = 10
    overlap = 2
    dataset = TextDataset(texts, tokenizer, max_length, overlap)

    # Hyperparameters
    d_model = 16  # Small dimension for demonstration
    nhead = 2

    # Initialize components
    pos_encoder = PositionalEncoding(d_model, max_length)
    masked_self_attn = MaskedAttention(d_model, nhead)

    print(f"Dataset configuration:")
    print(f"  Max length: {max_length}")
    print(f"  Overlap: {overlap}")
    print(f"  Total windows: {len(dataset)}")
    print(f"  Vocabulary size: {len(tokenizer)}")
    print(f"  Embedding dimension: {d_model}")
    print(f"  Number of attention heads: {nhead}\n")

    # Process each window
    for i in range(len(dataset)):
        tokens, attention_mask, doc_idx = dataset[i]

        print(f"Window {i} (from document {doc_idx}):")
        print(f"  Original tokens: {tokens.squeeze(0).tolist()}")
        print(f"  Attention mask: {attention_mask.squeeze(0).tolist()}")
        print(f"  Decoded: '{tokenizer.decode(tokens.squeeze(0).tolist())}'")

        # Convert tokens to "embeddings" (just for demonstration)
        pseudo_embeddings = torch.rand(1, tokens.size(1), d_model)  # (batch_size, seq_len, d_model)
        print(f"  Shape of pseudo embeddings: {pseudo_embeddings.shape}")

        # Apply positional encoding
        pos_encoded = pos_encoder(pseudo_embeddings)
        print(f"  Shape after positional encoding: {pos_encoded.shape}")

        # Apply masked self-attention
        attn_output, attn_mask = masked_self_attn(pos_encoded)
        print(f"  Shape after masked self-attention: {attn_output.shape}")

        # Display the effect of positional encoding and attention for all tokens
        print(f"  Transformer effect on tokens:")
        for j in range(tokens.size(1)):
            if attention_mask[0, j] == 1:  # Only show for non-padding tokens
                print(f"    Token {j}:")
                print(f"      Initial:   {pseudo_embeddings[0, j, :5].tolist()}")
                print(f"      Positional:{pos_encoded[0, j, :5].tolist()}")
                print(f"      Attention Mask: {attn_mask[j, :5].tolist()}")  # Show first 5 values of attention mask
                print(f"      Attention: {attn_output[0, j, :5].tolist()}")
        print()

    print("--- End of Expanded Test ---")

# Run the expanded test
test_expanded_transformer_components()


--- Testing Expanded Transformer Components ---

Dataset configuration:
  Max length: 10
  Overlap: 2
  Total windows: 5
  Vocabulary size: 27
  Embedding dimension: 16
  Number of attention heads: 2

Window 0 (from document 0):
  Original tokens: [2, 4, 5, 6, 7, 8, 3, 0, 0, 0]
  Attention mask: [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
  Decoded: '<SOS> This is a short sentence. <EOS> <PAD> <PAD> <PAD>'
  Shape of pseudo embeddings: torch.Size([1, 10, 16])
  Shape after positional encoding: torch.Size([1, 10, 16])
  Shape after masked self-attention: torch.Size([1, 10, 16])
  Transformer effect on tokens:
    Token 0:
      Initial:   [0.09746289253234863, 0.8920455574989319, 0.5080603361129761, 0.6052985191345215, 0.2980855107307434]
      Positional:[0.09746289253234863, 1.892045497894287, 0.5080603361129761, 1.6052985191345215, 0.2980855107307434]
      Attention Mask: [0.0, -inf, -inf, -inf, -inf]
      Attention: [-0.30292782187461853, 0.058979008346796036, 0.34361016750335693, 0.133262112

## 📝 Feed Forward Netwrok

A feed-forward network is a multi-layered structure in which information moves in a single direction, from the input layer to the output layer.


**✨ Additional Resources:**

*   Transformer Feed-Forward Layers Are Key-Value Memories
 [Link-ArXiv](https://arxiv.org/abs/2012.14913)







In [10]:
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()

        # Define feed-forward network using nn.Sequential
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),  # First linear layer
            nn.ReLU(),                 # ReLU activation
            nn.Dropout(dropout),       # Dropout for regularization
            nn.Linear(d_ff, d_model)   # Second linear layer
        )

    def forward(self, x):
        return self.net(x)  # Apply the feed-forward network

### ✍ Testing the Feed Forward Layer

In [11]:
def test_all_transformer_components():
    print("\n--- Testing Transformer Components ---\n")

    # Set a fixed seed for reproducibility
    torch.manual_seed(42)

    # Sample texts
    texts = [
        "This is a short sentence.",
        "This is a much longer sentence that will be split into multiple windows. Observe the term overlapping?",
        "Another sentence of medium length."
    ]

    # Initialize tokenizer and fit it to the texts
    tokenizer = SimpleTokenizer()
    tokenizer.fit(texts)

    # Create dataset
    max_length = 10
    overlap = 2
    dataset = TextDataset(texts, tokenizer, max_length, overlap)

    # Hyperparameters
    d_model = 16  # Small dimension for demonstration
    nhead = 2
    d_ff = d_model  # Feed-forward dimension
    dropout = 0.1

    # Initialize components
    pos_encoder = PositionalEncoding(d_model, max_length)
    masked_self_attn = MaskedAttention(d_model, nhead)
    feed_forward = FeedForward(d_model, d_ff, dropout)

    print(f"Dataset configuration:")
    print(f"  Max length: {max_length}")
    print(f"  Overlap: {overlap}")
    print(f"  Total windows: {len(dataset)}")
    print(f"  Vocabulary size: {len(tokenizer)}")
    print(f"  Embedding dimension: {d_model}")
    print(f"  Number of attention heads: {nhead}")
    print(f"  Feed-forward dimension: {d_ff}\n")

    # Process each window
    for i in range(len(dataset)):
        tokens, attention_mask, doc_idx = dataset[i]

        print(f"Window {i} (from document {doc_idx}):")
        print(f"  Original tokens: {tokens.squeeze(0).tolist()}")
        print(f"  Attention mask: {attention_mask.squeeze(0).tolist()}")
        print(f"  Decoded: '{tokenizer.decode(tokens.squeeze(0).tolist())}'")

        # Convert tokens to "embeddings" (just for demonstration)
        pseudo_embeddings = torch.rand(1, tokens.size(1), d_model)  # (batch_size, seq_len, d_model)
        print(f"  Shape of pseudo embeddings: {pseudo_embeddings.shape}")

        # Apply positional encoding
        pos_encoded = pos_encoder(pseudo_embeddings)
        print(f"  Shape after positional encoding: {pos_encoded.shape}")

        # Apply masked self-attention
        attn_output, attn_mask = masked_self_attn(pos_encoded)
        print(f"  Shape after masked self-attention: {attn_output.shape}")

        # Apply feed-forward network
        ff_output = feed_forward(attn_output)
        print(f"  Shape after feed-forward: {ff_output.shape}")

        # Display the effect of positional encoding, attention, and feed-forward for all tokens
        print(f"  Transformer effect on tokens:")
        for j in range(tokens.size(1)):
            if attention_mask[0, j] == 1:  # Only show for non-padding tokens
                print(f"    Token {j}:")
                print(f"      Initial:   {pseudo_embeddings[0, j, :5].tolist()}")
                print(f"      Positional:{pos_encoded[0, j, :5].tolist()}")
                print(f"      Attention Mask: {attn_mask[j, :5].tolist()}")  # Show first 5 values
                print(f"      Attention: {attn_output[0, j, :5].tolist()}")
                print(f"      Feed-Forward: {ff_output[0, j, :5].tolist()}")
        print()

    print("--- End of Expanded Test ---")

# Run the expanded test
test_all_transformer_components()


--- Testing Transformer Components ---

Dataset configuration:
  Max length: 10
  Overlap: 2
  Total windows: 5
  Vocabulary size: 27
  Embedding dimension: 16
  Number of attention heads: 2
  Feed-forward dimension: 16

Window 0 (from document 0):
  Original tokens: [2, 4, 5, 6, 7, 8, 3, 0, 0, 0]
  Attention mask: [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
  Decoded: '<SOS> This is a short sentence. <EOS> <PAD> <PAD> <PAD>'
  Shape of pseudo embeddings: torch.Size([1, 10, 16])
  Shape after positional encoding: torch.Size([1, 10, 16])
  Shape after masked self-attention: torch.Size([1, 10, 16])
  Shape after feed-forward: torch.Size([1, 10, 16])
  Transformer effect on tokens:
    Token 0:
      Initial:   [0.50750333070755, 0.8034182190895081, 0.532285213470459, 0.5399761199951172, 0.6362065076828003]
      Positional:[0.50750333070755, 1.8034181594848633, 0.532285213470459, 1.5399761199951172, 0.6362065076828003]
      Attention Mask: [0.0, -inf, -inf, -inf, -inf]
      Attention: [-0.21086835

## 📝 Decoder Layer

Implementation of a Transformer Decoder Layer with Masked Multi-Head Attention and Feed Forward Network








![Decoder Only Architecture](https://drive.google.com/uc?id=1ksROxQxf3b7dlBUoIQggzyLeBaPO-AQn)


In [12]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, nhead: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.masked_attention = MaskedAttention(d_model, nhead, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Masked Multi-Head Attention
        normed_x = self.norm1(x)
        attn_output, _ = self.masked_attention(normed_x) # _ because we returned also the mask in the previous demonstration
        x = x + self.dropout1(attn_output)

        # Feed Forward
        normed_x = self.norm2(x)
        ff_output = self.feed_forward(normed_x)
        x = x + self.dropout2(ff_output)

        return x

## 📝 Decoder-only Transformer

The components of a Decoder-Only Transformer include an embedding layer for token representation, positional encoding for sequential information, stacked decoder layers for hierarchical processing, layer normalization for stability, and an output projection layer for generating tokens.

In [13]:
class DecoderOnlyTransformer(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, nhead: int, num_layers: int,
                 d_ff: int, max_seq_length: int, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.max_seq_length = max_seq_length

        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, d_model)

        # Positional encoding
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        # Decoder layers
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, nhead, d_ff, dropout)
            for _ in range(num_layers)
        ])

        # Final layer norm
        self.final_norm = nn.LayerNorm(d_model)

        # Output projection
        self.output_projection = nn.Linear(d_model, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: (batch_size, seq_len)

        # Embed the input
        x = self.embedding(x) * math.sqrt(self.d_model)

        # Add positional encoding
        x = self.positional_encoding(x)

        # Apply decoder layers
        for layer in self.layers:
            x = layer(x)

        # Apply final layer norm
        x = self.final_norm(x)

        # Project to vocabulary size
        output = self.output_projection(x)

        return output

    def generate(self, start_tokens: torch.Tensor, max_length: int,
                 temperature: float = 1.0) -> torch.Tensor:
        self.eval()
        current_seq = start_tokens

        with torch.no_grad():
            for _ in range(max_length - start_tokens.size(1)):
                # Ensure we're not exceeding the maximum sequence length
                if current_seq.size(1) > self.max_seq_length:
                    current_seq = current_seq[:, -self.max_seq_length:]

                # Get model predictions
                logits = self(current_seq)
                next_token_logits = logits[:, -1, :] / temperature

                # Sample next token
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

                # Append next token to sequence
                current_seq = torch.cat([current_seq, next_token], dim=1)

                # Check if we've generated an EOS token
                if next_token.item() == self.vocab_size - 1:  # Assuming EOS is the last token
                    break

        return current_seq

### ✍ Displaying the Decoder-only Transformer Architecture

In [14]:
!pip install torchinfo
from torchinfo import summary

# Initialize the model with some example parameters
vocab_size = 10000
d_model = 512
nhead = 2
num_layers = 1
d_ff = 2048
max_seq_length = 1024
dropout = 0.1

# Define your model
model = DecoderOnlyTransformer(
    vocab_size=vocab_size,
    d_model=d_model,
    nhead=nhead,
    num_layers=num_layers,
    d_ff=d_ff,
    max_seq_length=max_seq_length,
    dropout=dropout
)

# Print the model summary
summary(model, input_size=(1, max_seq_length), dtypes=[torch.int64])

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


Layer (type:depth-idx)                        Output Shape              Param #
DecoderOnlyTransformer                        [1, 1024, 10000]          --
├─Embedding: 1-1                              [1, 1024, 512]            5,120,000
├─PositionalEncoding: 1-2                     [1, 1024, 512]            --
├─ModuleList: 1-3                             --                        --
│    └─DecoderLayer: 2-1                      [1, 1024, 512]            --
│    │    └─LayerNorm: 3-1                    [1, 1024, 512]            1,024
│    │    └─MaskedAttention: 3-2              [1, 1024, 512]            1,050,624
│    │    └─Dropout: 3-3                      [1, 1024, 512]            --
│    │    └─LayerNorm: 3-4                    [1, 1024, 512]            1,024
│    │    └─FeedForward: 3-5                  [1, 1024, 512]            2,099,712
│    │    └─Dropout: 3-6                      [1, 1024, 512]            --
├─LayerNorm: 1-4                              [1, 1024, 512]        

## 📝 Training the Decoder-only Transformer

In [15]:
# Load the tiny_shakespeare dataset
dataset = load_dataset("tiny_shakespeare", split="train")
# Load the tiny_shakespeare dataset
# dataset = load_dataset("lyimo/shakespear", split="train")

# Extract the text from the dataset
texts = dataset["text"]

# Hyperparameters
d_model = 256
nhead = 2
num_layers = 2
d_ff = 256
max_seq_length = 128
batch_size = 64
num_epochs = 10
learning_rate = 0.0001
dropout = 0.2

# Tokenize and prepare data
tokenizer = SimpleTokenizer()
tokenizer.fit(texts)
vocab_size = len(tokenizer.word_to_idx)

dataset = TextDataset(texts, tokenizer, max_seq_length)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

print(f"Vocabulary size: {vocab_size}")

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create model and move to device
model = DecoderOnlyTransformer(vocab_size, d_model, nhead, num_layers, d_ff, max_seq_length, dropout).to(device)

# Create optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.word_to_idx["<PAD>"])

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch_idx, batch in enumerate(train_loader):
        optimizer.zero_grad()

        input_seq, _, _ = batch  # Unpack batch
        input_seq = input_seq.squeeze(1).to(device)  # Move input to device and remove extra dimension

        # Forward pass
        output = model(input_seq)


        # Reshape output tensor
        output = output[:, :-1, :].contiguous().view(-1, output.size(-1))  # Shift predictions to the left

        # Shift targets to the right (original targets)
        target_seq = input_seq[:, 1:].contiguous().view(-1)


        # Compute loss
        loss = criterion(output, target_seq)

        # Debugging prints
        print(f"Loss: {loss.item()}")

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if batch_idx == 0:
          # Debugging prints
          print(f"Epoch: {epoch+1}, Batch: {batch_idx+1}")
          print(f"Input sequence shape: {input_seq.shape}")
          print(f"Input sequence: {input_seq.unsqueeze(1)}")
          print(f"Output shape before reshape: {output.shape}")
          print(f"Output shape after reshape: {output.shape}")
          print(f"Target sequence shape: {target_seq.shape}")

    # Print epoch loss
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader):.4f}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading builder script:   0%|          | 0.00/3.73k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/6.10k [00:00<?, ?B/s]

The repository for tiny_shakespeare contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/tiny_shakespeare.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


Downloading data:   0%|          | 0.00/435k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1 [00:00<?, ? examples/s]

Vocabulary size: 23845
Loss: 10.230379104614258
Epoch: 1, Batch: 1
Input sequence shape: torch.Size([64, 128])
Input sequence: tensor([[[    2,   120,   604,  ...,    53,  5272,    44]],

        [[    2,  6008,   621,  ...,    53, 12935,   102]],

        [[    2,   221,   235,  ...,  1464,    68, 15352]],

        ...,

        [[    2,   124,   142,  ...,   122, 18124,    28]],

        [[    2,  6613,   529,  ...,    21,    46,  5532]],

        [[    2,  2028,  8209,  ...,   235,  4709,  8219]]], device='cuda:0')
Output shape before reshape: torch.Size([8128, 23845])
Output shape after reshape: torch.Size([8128, 23845])
Target sequence shape: torch.Size([8128])
Loss: 10.25078010559082
Loss: 10.250794410705566
Loss: 10.237319946289062
Loss: 10.231711387634277
Loss: 10.222508430480957
Loss: 10.237987518310547
Loss: 10.212271690368652
Loss: 10.215311050415039
Loss: 10.223275184631348
Loss: 10.2117280960083
Loss: 10.201935768127441
Loss: 10.206974029541016
Loss: 10.209392547607422
Los

### ✍ Testing the Decoder-only Transformer

In [16]:
texts = ["Better three hours too soon than", " I believe I can ", "My words fly up, my", "Brevity is ", "Love looks not with the eyes, but", "To be or "]

for quote in texts:
  start_tokens = torch.tensor(tokenizer.encode(quote)).unsqueeze(0).to(device)  # Add batch dimension and move to device

  generated_tokens = model.generate(start_tokens, max_length=20, temperature=.9)
  generated_text = tokenizer.decode(generated_tokens.squeeze().tolist())

  print(generated_text)

Better three hours too soon than no blame but a in Edward, By'r like young on bound their in fifteen
I believe I can QUEEN disgrace, Bridget you, on Hastings, have you and this he small: of First Elizabeth It
My words fly up, my prophesy Indeed, I. thus. Such wrap Menenius. amain, ever, pleader, Third fool, manner sugar, gin.
<UNK> is thou thou you Warwick. not And forces captain gentle-sleeping Olympian cheek? For good Citizen: scarfs Jupiter, 'lordship:' wisely.
Love looks not with the eyes, but are, well-warranted By to had nod; Marry, indeed. and SICINIUS: nice are such
To be or begin. SCROOP: Cobham, youth: say. applause cheaper brazen duke? on Sneak he's he, went answering baseness material.


# Decoder-only with MoE instead of FFN

In the Sparse Mixture of Experts (MoE) architecture, the self-attention mechanism within each transformer block stays the same.

However, a key modification is made to the structure **of each block**: the standard **feed-forward neural network** is replaced with **multiple sparsely activated feed-forward networks, known as experts.**

"Sparse activation" means that each token in the sequence is routed to only a small number of these experts—usually one or two—out of the entire pool.



**✨ Additional Resources:**

*   makeMoE: Implement a Sparse Mixture of Experts Language Model from Scratch [Link-huggingface](https://huggingface.co/blog/AviSoori1x/makemoe-from-scratch)




## 📝 Expert Layer

In [17]:
class Expert(nn.Module):
    """ An MLP with a single hidden layer and ReLU activation, serving as an expert in a Mixture of Experts. """
    def __init__(self, n_embd: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

## 📝 Gating in MoE Architectures

Types of gating in Mixture of Experts (MoE) systems include Top-k gating, Noisy Top-k gating (as implemented here), and other variants like Hierarchical gating or Soft gating.

Gating is essential in MoE systems because it determines which experts to use for each input, allowing the model to specialize different experts for different types of inputs or tasks.

Specifically, Noisy Top-k gating adds controlled randomness to the expert selection process, which can help balance expert utilization and potentially improve model performance by introducing exploration in the routing mechanism.

In [18]:
class NoisyTopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k_moe):
        super(NoisyTopkRouter, self).__init__()
        # Store the top_k_moe parameter which specifies the number of top experts to select
        self.top_k_moe = top_k_moe
        # Linear layer to compute logits for routing
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        # Linear layer to compute noise logits for added noise
        self.noise_linear = nn.Linear(n_embed, num_experts)

    def forward(self, mh_output):
        # Compute the logits for routing to experts
        logits = self.topkroute_linear(mh_output)
        # Compute the noise logits
        noise_logits = self.noise_linear(mh_output)
        # Generate noise with standard deviation determined by softplus of noise logits
        noise = torch.randn_like(logits) * F.softplus(noise_logits)
        # Add noise to the original logits to get noisy logits
        noisy_logits = logits + noise
        # Select the top k logits and their indices from the noisy logits
        top_k_moe_logits, indices = noisy_logits.topk(self.top_k_moe, dim=-1)
        # Create a tensor full of -inf values
        zeros = torch.full_like(noisy_logits, float('-inf'))
        # Scatter the top k logits into the zeros tensor to create a sparse logits tensor
        sparse_logits = zeros.scatter(-1, indices, top_k_moe_logits)
        # Apply softmax to the sparse logits to get the final router output
        router_output = F.softmax(sparse_logits, dim=-1)
        # Return the router output and the indices of the selected experts
        return router_output, indices

## 📚 Sparse MoE Layer

In [19]:
class SparseMoE(nn.Module):
    def __init__(self, n_embed, num_experts, top_k_moe):
        super(SparseMoE, self).__init__()
        # Initialize the NoisyTopkRouter to determine which experts to activate
        self.router = NoisyTopkRouter(n_embed, num_experts, top_k_moe)
        # Create a list of expert networks, each being a feed-forward network
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
        # Store the number of top experts to activate
        self.top_k_moe = top_k_moe

    def forward(self, x):
        # Get the gating output and indices from the router
        gating_output, indices = self.router(x)
        # Initialize the final output tensor with zeros, having the same shape as x
        final_output = torch.zeros_like(x)
        # Flatten the input tensor to simplify processing
        flat_x = x.view(-1, x.size(-1))
        # Flatten the gating output tensor to align with the flattened input
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))

        # Iterate over each expert
        for i, expert in enumerate(self.experts):
            # Create a mask to identify where the current expert is used
            expert_mask = (indices == i).any(dim=-1)
            # Flatten the expert mask to match the flattened input
            flat_mask = expert_mask.view(-1)

            if flat_mask.any():  # Check if there are any positions using the current expert
                # Extract the inputs for the current expert based on the mask
                expert_input = flat_x[flat_mask]
                # Get the output from the current expert
                expert_output = expert(expert_input)
                # Get the gating scores for the current expert
                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)
                # Compute the weighted output based on the gating scores
                weighted_output = expert_output * gating_scores
                # Add the weighted output to the final output tensor
                final_output[expert_mask] += weighted_output.view_as(final_output[expert_mask])

        # Return the final output tensor which combines the results from all activated experts
        return final_output


## 📝 Decoder with Sparse MoE

![Decoder Only Architecture](https://drive.google.com/uc?id=1ksROxQxf3b7dlBUoIQggzyLeBaPO-AQn)



In [20]:
class DecoderLayerMoE(nn.Module):
    def __init__(self, d_model, nhead, d_ff, num_experts, top_k_moe, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.sparse_moe = SparseMoE(d_model, num_experts, top_k_moe)

    def forward(self, x):
        x2 = self.norm1(x)
        attn_output, _ = self.self_attn(x2, x2, x2, attn_mask=self.generate_square_subsequent_mask(x.size(1)).to(x.device))
        x = x + self.dropout1(attn_output)
        x2 = self.norm2(x)
        moe_output = self.sparse_moe(x2)
        x = x + self.dropout2(moe_output)
        return x

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

### ✍ Showcase how Sparse MoE handles its inputs

-  Experiment: Change the number of experts in the SparseMoE_example model.
  -  Observation: Observe how increasing or decreasing the number of experts affects the routing, gating outputs, and final output.

- Experiment: Adjust the top_k_moe parameter to select more or fewer top experts.
  - Observation: See how the number of experts activated for each token changes and how it impacts the final output.



In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseMoE_example(nn.Module):
    def __init__(self, n_embed, num_experts, top_k_moe):
        super(SparseMoE_example, self).__init__()
        self.router = NoisyTopkRouter(n_embed, num_experts, top_k_moe)
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
        self.top_k_moe = top_k_moe

    def forward(self, x):
        gating_output, indices = self.router(x)
        final_output = torch.zeros_like(x)
        flat_x = x.view(-1, x.size(-1))
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))

        for i, expert in enumerate(self.experts):
            expert_mask = (indices == i).any(dim=-1)
            flat_mask = expert_mask.view(-1)

            if flat_mask.any():
                expert_input = flat_x[flat_mask]
                expert_output = expert(expert_input)
                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)
                weighted_output = expert_output * gating_scores
                final_output[expert_mask] += weighted_output.view_as(final_output[expert_mask])

        return final_output

    def forward_debug_example(self, x):
        # Forward pass with debug prints
        gating_output, indices = self.router(x)
        print("Gating Output Shape:", gating_output.shape)
        print("Gating Output:", gating_output)
        print("Expert Indices Shape:", indices.shape)
        print("Expert Indices:", indices)

        print("Input Shape:", x.shape)
        print("Input:", x)

        final_output = torch.zeros_like(x)
        flat_x = x.view(-1, x.size(-1))
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))

        for i, expert in enumerate(self.experts):
            print("\n" + "-"*50)
            expert_mask = (indices == i).any(dim=-1)
            flat_mask = expert_mask.view(-1)

            if flat_mask.any():
                expert_input = flat_x[flat_mask]
                print(f"Expert {i} Input Shape:", expert_input.shape)
                print(f"Expert {i} Input:", expert_input)

                expert_output = expert(expert_input)
                print(f"Expert {i} Output Shape:", expert_output.shape)
                print(f"Expert {i} Output:", expert_output)

                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)
                print(f"Gating Scores for Expert {i}:")
                print(gating_scores.squeeze())

                print(f"Weighted Output for Expert {i}:")
                weighted_output = expert_output * gating_scores
                print(weighted_output)

                final_output[expert_mask] += weighted_output.view_as(final_output[expert_mask])
                print(f"Expert {i} final_output Shape:", final_output[expert_mask].shape)
                print(f"Expert {i} final_output:", final_output[expert_mask])

        print("Final MoE Output Shape:", final_output.shape)
        print("Final MoE Output:", final_output)
        print("-"*50)
        return final_output


# Example usage and debugging prints
def test_sparse_moe():
    # Parameters
    batch_size = 2   #
    seq_length = 1   # number of tokens, if 1 then it is easier to see which experts are activated and how each embedding is calculated
    n_embed = 5
    num_experts = 6  # Increased number of experts
    top_k_moe = 2   # If you modify, more or less experts will be activated for each input token

    # Random input tensor (simulating token embeddings)
    random_input = torch.randn(batch_size, seq_length, n_embed)

    # Initialize SparseMoE
    sparse_moe = SparseMoE_example(n_embed, num_experts, top_k_moe)

    # Forward pass with debugging example
    final_output = sparse_moe.forward_debug_example(random_input)

    print("\nRandom Input Tensor:")
    print(random_input)
    print("\nFinal Output Tensor (after MoE processing):")
    print(final_output)

# Run the test function
test_sparse_moe()

Gating Output Shape: torch.Size([2, 1, 6])
Gating Output: tensor([[[0.6733, 0.0000, 0.0000, 0.0000, 0.3267, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.3676, 0.0000, 0.6324]]],
       grad_fn=<SoftmaxBackward0>)
Expert Indices Shape: torch.Size([2, 1, 2])
Expert Indices: tensor([[[0, 4]],

        [[5, 3]]])
Input Shape: torch.Size([2, 1, 5])
Input: tensor([[[ 0.3745,  0.9364,  1.1533,  1.2270,  0.7363]],

        [[-0.7316,  0.3322, -0.4338,  0.0275,  1.8612]]])

--------------------------------------------------
Expert 0 Input Shape: torch.Size([1, 5])
Expert 0 Input: tensor([[0.3745, 0.9364, 1.1533, 1.2270, 0.7363]])
Expert 0 Output Shape: torch.Size([1, 5])
Expert 0 Output: tensor([[-0.1023, -0.5338,  0.2251,  0.2022, -0.5254]], grad_fn=<MulBackward0>)
Gating Scores for Expert 0:
tensor(0.6733, grad_fn=<SqueezeBackward0>)
Weighted Output for Expert 0:
tensor([[-0.0688, -0.3594,  0.1515,  0.1361, -0.3538]], grad_fn=<MulBackward0>)
Expert 0 final_output Shape: torch.Size([1, 5])
E

## 📝 Decoder-only Transformer with MoE

In [23]:
class DecoderOnlyTransformerMoE(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers, d_ff, max_seq_length, dropout, num_experts, top_k_moe):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_seq_length)
        self.layers = nn.ModuleList([DecoderLayerMoE(d_model, nhead, d_ff, num_experts, top_k_moe, dropout) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)
        self.output = nn.Linear(d_model, vocab_size)
        self.max_seq_length = max_seq_length
        self.num_experts = num_experts
        self.top_k_moe = top_k_moe
        self.vocab_size = vocab_size

    def forward(self, x):
        x = self.embedding(x)
        x = self.pos_encoder(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)
        return self.output(x)

    def generate(self, start_tokens: torch.Tensor, max_length: int, temperature: float = 1.0) -> torch.Tensor:
        self.eval()
        current_seq = start_tokens

        with torch.no_grad():  # Disable gradient computation
            # Generate tokens until max_length is reached or end token is generated
            for _ in range(max_length - start_tokens.size(1)):
                # Ensure the sequence length does not exceed max_seq_length
                if current_seq.size(1) > self.max_seq_length:
                    current_seq = current_seq[:, -self.max_seq_length:]

                # Get logits from the model
                logits = self(current_seq)

                # Extract logits for the next token and scale by temperature
                next_token_logits = logits[:, -1, :] / temperature

                # Compute probabilities using softmax
                probs = F.softmax(next_token_logits, dim=-1)

                # Sample the next token from the probability distribution
                next_token = torch.multinomial(probs, num_samples=1)

                # Append the next token to the current sequence
                current_seq = torch.cat([current_seq, next_token], dim=1)

                # Stop if the end token is generated (vocab_size - 1 assumed to be the end token)
                if next_token.item() == self.vocab_size - 1:
                    break

        # Return the generated sequence
        return current_seq

### ✍ Displaying the Decoder-only Transformer with MoE Architecture

In [24]:
# Initialize the model with some example parameters
vocab_size = 10000
d_model = 512
nhead = 2
num_layers = 1
d_ff = 2048
max_seq_length = 1024
dropout = 0.1
num_experts = 4
top_k_moe = 2

# Define your model
model = DecoderOnlyTransformer(
    vocab_size=vocab_size,
    d_model=d_model,
    nhead=nhead,
    num_layers=num_layers,
    d_ff=d_ff,
    max_seq_length=max_seq_length,
    dropout=dropout,
)

# Define your model
model_moe = DecoderOnlyTransformerMoE(
    vocab_size=vocab_size,
    d_model=d_model,
    nhead=nhead,
    num_layers=num_layers,
    d_ff=d_ff,
    max_seq_length=max_seq_length,
    dropout=dropout,
    num_experts=num_experts,
    top_k_moe=top_k_moe
)

# Print the model summary
print(50*"-")
print(summary(model, input_size=(1, max_seq_length), dtypes=[torch.int64]))
print(50*"-")
print(summary(model_moe, input_size=(1, max_seq_length), dtypes=[torch.int64]))

--------------------------------------------------
Layer (type:depth-idx)                        Output Shape              Param #
DecoderOnlyTransformer                        [1, 1024, 10000]          --
├─Embedding: 1-1                              [1, 1024, 512]            5,120,000
├─PositionalEncoding: 1-2                     [1, 1024, 512]            --
├─ModuleList: 1-3                             --                        --
│    └─DecoderLayer: 2-1                      [1, 1024, 512]            --
│    │    └─LayerNorm: 3-1                    [1, 1024, 512]            1,024
│    │    └─MaskedAttention: 3-2              [1, 1024, 512]            1,050,624
│    │    └─Dropout: 3-3                      [1, 1024, 512]            --
│    │    └─LayerNorm: 3-4                    [1, 1024, 512]            1,024
│    │    └─FeedForward: 3-5                  [1, 1024, 512]            2,099,712
│    │    └─Dropout: 3-6                      [1, 1024, 512]            --
├─LayerNorm: 1-4 

In [27]:
def print_model_summary(model, input_size):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    dummy_input = torch.zeros(input_size, dtype=torch.int64).to(device)

    def register_hook(module):
        def hook(module, input, output):
            class_name = module.__class__.__name__
            module_idx = len(summary)
            m_key = f"{module_idx:03d} {class_name}"
            summary[m_key] = {}
            summary[m_key]["input_shape"] = list(input[0].size())
            if isinstance(output, torch.Tensor):
                summary[m_key]["output_shape"] = list(output.size())
            elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
                summary[m_key]["output_shape"] = [list(out.size()) for out in output]
            else:
                summary[m_key]["output_shape"] = "multiple outputs"
            params = sum(p.numel() for p in module.parameters())
            summary[m_key]["num_params"] = params

        if not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList):
            hooks.append(module.register_forward_hook(hook))

    summary = {}
    hooks = []
    model.apply(register_hook)
    model(dummy_input)
    for h in hooks:
        h.remove()

    print("----------------------------------------------------------------")
    print("{:>20}  {:>25} {:>15}".format("Layer (type)", "Input Shape", "Param #"))
    print("================================================================")
    total_params = 0
    total_output = 0
    for layer in summary:
        line_new = "{:>20}  {:>25} {:>15}".format(
            layer,
            str(summary[layer]["input_shape"]),
            "{0:,}".format(summary[layer]["num_params"]),
        )
        total_params += summary[layer]["num_params"]
        if isinstance(summary[layer]["output_shape"], list) and all(isinstance(i, int) for i in summary[layer]["output_shape"]):
            total_output += np.prod(summary[layer]["output_shape"])
        print(line_new)
    print("================================================================")
    print(f"Total params: {total_params:,}")
    print("----------------------------------------------------------------")

vocab_size = 10000
d_model = 512
nhead = 8
num_layers = 1
d_ff = 2048
max_seq_length = 1024
dropout = 0.1
num_experts = 2
top_k_moe = 1

# Ensure you have the correct definitions for these classes
# from your_model_definitions import DecoderOnlyTransformer, DecoderOnlyTransformerMoE

model = DecoderOnlyTransformer(
    vocab_size=vocab_size,
    d_model=d_model,
    nhead=nhead,
    num_layers=num_layers,
    d_ff=d_ff,
    max_seq_length=max_seq_length,
    dropout=dropout,
)

model_moe = DecoderOnlyTransformerMoE(
    vocab_size=vocab_size,
    d_model=d_model,
    nhead=nhead,
    num_layers=num_layers,
    d_ff=d_ff,
    max_seq_length=max_seq_length,
    dropout=dropout,
    num_experts=num_experts,
    top_k_moe=top_k_moe
)

print("Summary for DecoderOnlyTransformer")
print_model_summary(model, (1, max_seq_length))

print("\nSummary for DecoderOnlyTransformerMoE")
print_model_summary(model_moe, (1, max_seq_length))


Summary for DecoderOnlyTransformer
----------------------------------------------------------------
        Layer (type)                Input Shape         Param #
       000 Embedding                  [1, 1024]       5,120,000
001 PositionalEncoding             [1, 1024, 512]               0
       002 LayerNorm             [1, 1024, 512]           1,024
003 MultiheadAttention             [1, 1024, 512]       1,050,624
 004 MaskedAttention             [1, 1024, 512]       1,050,624
         005 Dropout             [1, 1024, 512]               0
       006 LayerNorm             [1, 1024, 512]           1,024
          007 Linear             [1, 1024, 512]       1,050,624
            008 ReLU            [1, 1024, 2048]               0
         009 Dropout            [1, 1024, 2048]               0
          010 Linear            [1, 1024, 2048]       1,049,088
     011 FeedForward             [1, 1024, 512]       2,099,712
         012 Dropout             [1, 1024, 512]               0


## 📝 Training the Decoder-only Transformer with MoE

In [28]:
# Load the tiny_shakespeare dataset
dataset = load_dataset("tiny_shakespeare", split="train")

# Extract the text from the dataset
texts = dataset["text"]

# Hyperparameters
d_model = 128
nhead = 2
num_layers = 2
d_ff = 256
max_seq_length = 64
batch_size = 32
num_epochs = 1
learning_rate = 0.0001
dropout = 0.2
num_experts=4
top_k_moe=2

# Tokenize and prepare data
tokenizer = SimpleTokenizer()
tokenizer.fit(texts)
vocab_size = len(tokenizer.word_to_idx)

dataset = TextDataset(texts, tokenizer, max_seq_length)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

print(f"Vocabulary size: {vocab_size}")

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create model and move to device
model_moe = DecoderOnlyTransformerMoE(vocab_size, d_model, nhead, num_layers, d_ff, max_seq_length, dropout, num_experts, top_k_moe).to(device)

# Create optimizer and loss function
optimizer = torch.optim.AdamW(model_moe.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.word_to_idx["<PAD>"])

# Training loop
for epoch in range(num_epochs):
    model_moe.train()
    total_loss = 0
    for batch_idx, batch in enumerate(train_loader):
        optimizer.zero_grad()

        input_seq, _, _ = batch  # Unpack batch
        input_seq = input_seq.squeeze(1).to(device)  # Move input to device and remove extra dimension

        # Forward pass
        output = model_moe(input_seq)


        # Reshape output tensor
        output = output[:, :-1, :].contiguous().view(-1, output.size(-1))  # Shift predictions to the left

        # Shift targets to the right (original targets)
        target_seq = input_seq[:, 1:].contiguous().view(-1)


        # Compute loss
        loss = criterion(output, target_seq)

        # Debugging prints
        print(f"Loss: {loss.item()}")

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if batch_idx == 0:
          # Debugging prints
          print(f"Epoch: {epoch+1}, Batch: {batch_idx+1}")
          print(f"Input sequence shape: {input_seq.shape}")
          print(f"Input sequence: {input_seq.unsqueeze(1)}")
          print(f"Output shape before reshape: {output.shape}")
          print(f"Output shape after reshape: {output.shape}")
          print(f"Target sequence shape: {target_seq.shape}")

    # Print epoch loss
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader):.4f}")


Vocabulary size: 23845
Loss: 10.261877059936523
Epoch: 1, Batch: 1
Input sequence shape: torch.Size([32, 64])
Input sequence: tensor([[[    2,  1721,   102,  ...,    35,  5782,    61]],

        [[    2,    32,  4437,  ...,  4432,   102,  5270]],

        [[    2, 10984,  1131,  ..., 14708,   122,   424]],

        ...,

        [[    2,   122,    46,  ..., 13476, 10791,   675]],

        [[    2,  7823,  1873,  ..., 17657,   547,   144]],

        [[    2,   235,  8362,  ...,  7171,   501,    35]]], device='cuda:0')
Output shape before reshape: torch.Size([2016, 23845])
Output shape after reshape: torch.Size([2016, 23845])
Target sequence shape: torch.Size([2016])
Loss: 10.270655632019043
Loss: 10.238564491271973
Loss: 10.24071979522705
Loss: 10.202763557434082
Loss: 10.206976890563965
Loss: 10.17464542388916
Loss: 10.15494441986084
Loss: 10.17193603515625
Loss: 10.165684700012207
Loss: 10.136601448059082
Loss: 10.122502326965332
Loss: 10.115411758422852
Loss: 10.085820198059082
Loss:

### ✍ Testing the Decoder-only Transformer with MoE

In [29]:
texts = ["Better three hours too soon than", " I believe I can ", "My words fly up, my", "Brevity is ", "Love looks not with the eyes, but", "To be or "]

for quote in texts:
  start_tokens = torch.tensor(tokenizer.encode(quote)).unsqueeze(0).to(device)  # Add batch dimension and move to device

  generated_tokens = model_moe.generate(start_tokens, max_length=20, temperature=.9)
  generated_text = tokenizer.decode(generated_tokens.squeeze().tolist())

  print(generated_text)

Better three hours too soon than you humour he A to deep it Teaching Gaunt, me the use QUEEN their
I believe I can begin great for the I Marcius, these And Of too Lancaster; than And the from of
My words fly up, my from O, here? in wed prorogue whom do sirrah, Now, my a the would we
<UNK> is to of should they covenant their The and to you, by of then fall a function, I is
Love looks not with the eyes, but creeping A first him I I King have agony. and be not pardoning
To be or verier his makes Aufidius I and royal is of, hotly you along one without knowledge, In But
