In [3]:
import dataclasses

import torch
from labml_helpers.module import Module
from torch import nn
from torch.utils.data import Dataset, DataLoader

from labml import experiment, lab, tracker, monit, logger
from labml.logger import Text
from labml.utils.download import download_file
from labml_nn.experiments.nlp_autoregression import transpose_batch
from labml_nn.optimizers.noam import Noam
from labml_nn.transformers import Encoder, MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
from labml_nn.transformers.utils import subsequent_mask


In [None]:
class AutoregressiveModel(Module):
    """
    ## Auto regressive model
    This is a simple autoregressive model that uses a transformer encoder.

    A causal (subsequent) attention mask is cached on the instance and rebuilt
    whenever the input's sequence length or device differs from the cached one.
    """

    def __init__(self, src_embed: Module, encoder: Encoder, generator: Module):
        """
        * `src_embed` is the module that embeds the source tokens
        * `encoder` is the transformer encoder
        * `generator` is the module that generates the next token logits
        """
        super().__init__()
        self.src_embed = src_embed
        self.encoder = encoder
        self.generator = generator
        # Cached causal mask; built lazily in `forward`
        self.src_mask = None

    def forward(self, src: torch.Tensor):
        """
        * `src` is the source token ids. The sequence length is read as
          `len(src)`, i.e. the first dimension (presumably
          `[seq_len, batch_size]` — TODO confirm against the data pipeline).

        Returns the output of `generator` applied to the encoder output.
        """
        seq_len = len(src)
        # Create subsequent mask, so that the transformer can only pay attention to past tokens.
        # The mask is assigned as a plain attribute rather than registered as a
        # buffer, so moving the module does not move it. Rebuild it on a device
        # change as well as on a length change.
        if (self.src_mask is None
                or self.src_mask.size(0) != seq_len
                or self.src_mask.device != src.device):
            self.src_mask = subsequent_mask(seq_len).to(src.device)

        # Embed tokens and run through the encoder
        res = self.encoder(self.src_embed(src), self.src_mask)

        # Generate next token logits
        return self.generator(res)


In [None]:
@dataclasses.dataclass
class Configs:
    """
    ### Configurations for the Transformer model
    Hyper-parameters for the model architecture and the training loop.

    Fields (in constructor order):

    * `d_model` — width of the token embeddings
    * `seq_len` — sequence (context) length
    * `batch_size` — number of sequences per training batch
    * `n_layers` — number of transformer layers
    * `n_heads` — number of attention heads
    * `dropout` — dropout probability
    * `d_ff` — hidden width of the feed-forward network
    * `glu_variant` — one of 'GLU', 'Bilinear', 'ReGLU', 'GEGLU', 'SwiGLU',
      'ReLU', 'GELU'
    * `epochs` — number of training epochs
    * `grad_norm_clip` — gradient-norm clipping threshold
    """

    # --- model dimensions ---
    d_model: int = dataclasses.field(default=512)
    seq_len: int = dataclasses.field(default=128)
    batch_size: int = dataclasses.field(default=32)

    # --- transformer stack ---
    n_layers: int = dataclasses.field(default=6)
    n_heads: int = dataclasses.field(default=8)
    dropout: float = dataclasses.field(default=0.1)

    # --- feed-forward sub-layer ---
    d_ff: int = dataclasses.field(default=2048)
    glu_variant: str = dataclasses.field(default='GLU')

    # --- training loop ---
    epochs: int = dataclasses.field(default=5)
    grad_norm_clip: float = dataclasses.field(default=0.5)


In [None]:
class TinyShakespeareDataset(Dataset):
    """
    ### Tiny Shakespeare Dataset
    This is a character-level dataset class for the Tiny Shakespeare dataset.

    Item `i` is a pair `(x, y)` of `seq_len`-long tensors of character ids,
    where `y` is `x` shifted forward by one character.
    """

    def __init__(self, seq_len: int):
        """
        * `seq_len` is the length of the sequence of data taken for training
        """
        super().__init__()
        # Location of the text file
        path = lab.get_data_path() / 'tiny_shakespeare.txt'
        # Download the file if it's not present
        if not path.exists():
            download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
        # Read the file content. The encoding is pinned so decoding does not
        # depend on the platform's default locale.
        with open(str(path), 'r', encoding='utf-8') as f:
            text = f.read()
        # Character vocabulary; sorted so the id assignment is deterministic
        self.chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(self.chars)}
        self.itos = dict(enumerate(self.chars))
        # Encode the entire text
        self.data = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        # Define sequence length
        self.seq_len = seq_len

    def __len__(self):
        """
        Number of sequences in the dataset.

        Each start index needs `seq_len + 1` characters (the input plus the
        one-step-shifted target). The result is clamped at 0 so that text
        shorter than `seq_len` yields an empty dataset instead of a negative
        length.
        """
        return max(0, len(self.data) - self.seq_len)

    def __getitem__(self, index):
        """
        Get the input sequence (x) and the target sequence (y).
        Target sequence is the input sequence shifted by one character.

        Negative indices count from the end. Out-of-range indices raise
        `IndexError`. Raising is what stops plain `for x, y in dataset`
        iteration, which relies on the `__getitem__` protocol.
        """
        n = len(self)
        position = index + n if index < 0 else index
        if not 0 <= position < n:
            raise IndexError(f'index {index} out of range for dataset of length {n}')
        # Get the sequence of indices (offset by one for target sequence)
        x = self.data[position:position + self.seq_len]
        y = self.data[position + 1:position + self.seq_len + 1]
        return x, y
