# Table of Contents

1. [Transformer model components:](#Transformer-model-components:)  
   1.1 [FMM-Attention factorization](#FMM-Attention-factorization)  
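   1.2 [Attention](#Attention-Module)  
   1.3 [Multi-headed Attention](#Multi-head-Attention)  
   1.4 [Transformer Block](#Transformer-Block)  
   1.5 [Modality-specific components](#Modality-specific-embeddings-and-classifiers)  
   1.6 [Transformer](#Transformer-Model)  
2. [Training](#Training-loop)  
3. [Data Prep](#Data-prep)  
4. [Running](#Run-the-experiment)  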

In [None]:
# Dependencies
import math
import torch
from torch import nn
from torch.nn import functional as F

# Transformer model components:
## FMM-Attention factorization

Input: Query, Key, and Value matrices (each n by d), along with n and d  
Output: An approximation of softmax(QK^T / sqrt(d)) V (an n by d matrix)  
The attention matrix here is an estimate of softmax. Each exp(x), with x = q·k / sqrt(d), is replaced by its order-p Taylor polynomial 1 + x + x^2/2! + ... + x^p/p!, and each row is then normalized to sum to 1. This code does not yet exploit the factorization (it still builds the full n by n matrix). It's just a test to help us take the next steps.  
If the result is bad, but not too bad, increase p. It can go up to 4, but from 5 onward it would not be worth it in my opinion.
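For reference, here is a vectorized sketch of the same approximation. It also includes one simple way to factorize the p = 2 case so that no n by n matrix is ever built. The sketch assumes p is the Taylor order (terms k = 0..p) and that Q, K, V are single-head (n, d) tensors. The function names are hypothetical. The factorization uses a polynomial feature map, which is not the fast multipole method itself. First scale q and k by d^(-1/4) and define phi(u) = [1, u, vec(u u^T) / sqrt(2)]. Then phi(q)·phi(k) = 1 + x + x^2/2. This lowers the cost from O(n^2 d) to O(n d^3).

```python
import math
import torch


def taylor_attention_dense(Q, K, V, p=2):
    """O(n^2) reference: softmax with exp(x) replaced by sum_{k=0}^{p} x^k / k!."""
    d = Q.shape[-1]
    x = Q @ K.T / math.sqrt(d)                              # (n, n) scaled scores
    w = sum(x**k / math.factorial(k) for k in range(p + 1))
    return (w / w.sum(dim=-1, keepdim=True)) @ V            # (n, d)


def taylor_features_p2(U):
    """Hypothetical feature map with phi(q) . phi(k) = 1 + x + x^2 / 2, x = q . k / sqrt(d)."""
    n, d = U.shape
    U = U * d ** -0.25                                      # split 1/sqrt(d) between q and k
    outer = (U.unsqueeze(2) * U.unsqueeze(1)).reshape(n, d * d)  # vec(u u^T) per row
    ones = torch.ones(n, 1, dtype=U.dtype, device=U.device)
    return torch.cat([ones, U, outer / math.sqrt(2)], dim=1)     # (n, 1 + d + d^2)


def taylor_attention_factorized_p2(Q, K, V):
    """Same output as the dense version with p=2, without forming the (n, n) matrix."""
    fq, fk = taylor_features_p2(Q), taylor_features_p2(K)
    numerator = fq @ (fk.T @ V)                             # (n, d): sum_j w_ij v_j
    denominator = fq @ fk.sum(dim=0)                        # (n,):   sum_j w_ij
    return numerator / denominator.unsqueeze(1)


torch.manual_seed(0)
Q, K, V = (torch.randn(6, 4, dtype=torch.float64) for _ in range(3))
dense = taylor_attention_dense(Q, K, V, p=2)
fact = taylor_attention_factorized_p2(Q, K, V)
print(torch.allclose(dense, fact))  # True
```

With an even order such as p = 2, the polynomial 1 + x + x^2/2 is always positive. As a result, the row normalization never divides by zero.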

In [None]:
# See score_function.py
def score_before_jit(Q, K, V, n: int, d: int, p: int = 2):
    # Approximate softmax(Q K^T / sqrt(d)) V, replacing exp(x) with its
    # order-p Taylor polynomial sum_{k=0}^{p} x^k / k!
    # (n, d, p are annotated as ints; TorchScript would otherwise treat them as Tensors)
    sqrtd = math.sqrt(d)
    fastmax = torch.zeros([n, n], dtype=Q.dtype, device=Q.device)

    for i in range(n):
        for j in range(n):
            x = torch.dot(Q[i], K[j]) / sqrtd  # scaled dot-product score
            for k in range(p + 1):
                fastmax[i][j] += x**k / math.factorial(k)

    # Normalize each row so its weights sum to 1 (the softmax denominator)
    fastmax = fastmax / fastmax.sum(dim=1, keepdim=True)

    # Weighted sum of value vectors: (n, n) @ (n, d) -> (n, d)
    return fastmax @ V

score = torch.jit.script(score_before_jit)

## Attention Module

The core attention operation of a transformer. The work here is in producing
a square n x n attention matrix, A. Each row of A corresponds to a single
token, and the row is a weighting of how much information should be accumulated
from each other token.

This weighting is ultimately determined by learning a good linear transformation
that maps each token vector to a "query" vector. The query vector for each token
gets compared with each other token by a "compatibility function", and higher
compatibility results in a higher weight in the resulting attention row.

Here we use "dot product attention" for the compatibility function. We could
compute the compatibility/closeness of the query vector directly with the
token vector, but it turns out that learning another linear transformation
from the token vector to a "key" vector space and then computing the
compatibility between those produces better results. The dot products are
divided by sqrt(d_head) and passed through a softmax, so each row of A is
positive and sums to 1.

Finally, we do a mat-mul of the attention matrix with the original matrix of
token vectors to get the updates from these weightings. Just as it is helpful
to learn a transformation from the raw token vectors to key vectors, it also
turns out to be helpful to learn a transformation to a "value" vector
space and do the mat-mul update on this space instead. The final output is the
updated token vectors in the value vector space.
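Here is a minimal single-head sketch of the steps above on a toy input. Random matrices stand in for the learned nn.Linear weights, and the names W_q, W_k, W_v are illustrative only. It uses the same scaled dot-product formulation as the AttentionHead below.

```python
import math
import torch

torch.manual_seed(0)
n, d_model, d_head = 4, 8, 3
X = torch.randn(n, d_model)                 # one sequence of n token vectors

# Projections into query, key, and value spaces (random here, learned in the model)
W_q, W_k, W_v = (torch.randn(d_model, d_head) for _ in range(3))
Q, K, V = X @ W_q, X @ W_k, X @ W_v         # each (n, d_head)

# Compatibility of every query with every key, scaled by sqrt(d_head)
scores = Q @ K.T / math.sqrt(d_head)        # (n, n)
A = torch.softmax(scores, dim=-1)           # row i: how much token i draws from each token

V_hat = A @ V                               # (n, d_head): updated tokens in value space
print(A.sum(dim=-1))                        # each row sums to 1
print(V_hat.shape)                          # torch.Size([4, 3])
```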

Note on masking/causal attention/auto-regressive/encoder-decoder:  
More often than not, we will apply a mask on the attention matrix so that each token only gets weightings from itself and the tokens earlier in the sequence, and the weightings from tokens later in the sequence get zeroed out. A transformer that doesn't apply masking is often referred to as an "encoder-only" model (e.g. BERT), and a transformer that applies masking is "decoder-only" (e.g. GPT-3). Because masking makes it so that tokens can only be updated by information from earlier tokens, masked attention is sometimes called "causal attention", and a model with this formulation is called "auto-regressive".
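Here is a short sketch of the masking step. It assumes PyTorch 2.0+ for F.scaled_dot_product_attention. It builds the upper-triangular mask by hand, as the manual branch of AttentionHead does. It then checks the result against the built-in is_causal=True option.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, n, d = 2, 5, 4
Q, K, V = (torch.randn(b, n, d) for _ in range(3))

# Manual causal attention: position i may only attend to positions j <= i
scores = Q @ K.mT / math.sqrt(d)                                     # (b, n, n)
future = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)  # True above the diagonal
scores = scores.masked_fill(future, float("-inf"))                   # softmax gives these weight 0
A = torch.softmax(scores, dim=-1)                                    # lower-triangular rows
manual = A @ V

# Built-in version with the same causal mask
builtin = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print(torch.allclose(manual, builtin, atol=1e-5))
```

In the first row of each A, all of the weight falls on token 0, the only token it is allowed to see.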

Classic Q, K, V explanation:  
Q: what a token "wants"   
K: what a token "is"  
V: what a token "shares"  

In [None]:
# See model.py
class AttentionHead(nn.Module):
    """The core attention operation of a transformer."""

    def __init__(self, d_model, d_head, dropout_rate=0.2, use_masking=False, fastmax=False, fmm_p=2):
        super().__init__()
        self.d_head = d_head
        self.dropout_rate = dropout_rate
        self.use_masking = use_masking
        self.fastmax = fastmax
        self.fmm_p = fmm_p
        self.query_transform = nn.Linear(d_model, d_head, bias=False)
        self.key_transform = nn.Linear(d_model, d_head, bias=False)
        self.value_transform = nn.Linear(d_model, d_head, bias=False)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, X):
        # X shape: (b, n, d_model)
        # b: batch size
        # n: num tokens/features
        # d_model: dim of vector representation of each token
        Q = self.query_transform(X)  # (b, n, d_head)
        K = self.key_transform(X)  # (b, n, d_head)
        V = self.value_transform(X)  # (b, n, d_head)
        if self.fastmax:
            n = X.shape[1]
            results = []
            for i in range(Q.shape[0]):
                QQ = Q[i]  # (n, d_head)
                KK = K[i]  # (n, d_head)
                VV = V[i]  # (n, d_head)
                VV_hat = score(QQ, KK, VV, n, self.d_head, self.fmm_p)  # (n, d_head)
                results.append(VV_hat)  # (n, d_head)
            V_hat = torch.stack(results)  # (b, n, d_head)
        elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            # If using PyTorch 2.0+, we have access to built-in attention that implements Flash Attention CUDA kernels
            V_hat = F.scaled_dot_product_attention(
                Q, K, V,
                dropout_p=self.dropout_rate if self.training else 0,
                is_causal=self.use_masking,
            )
        else:
            # manual implementation of scaled dot product attention
            A = Q @ K.mT  # (b, n, d_head) @ (b, d_head, n) -> (b, n, n)
            A = A / math.sqrt(self.d_head)
            if self.use_masking:
                # Mask out upper triangular of A so that weightings only apply to
                # previous tokens in the sequence. Filling with -inf before softmax
                # will cause softmax to give weight 0 to these tokens.
                n = X.shape[1]
                upper_tri_mask = torch.triu(
                    torch.ones((n, n), dtype=bool, device=X.device), diagonal=1
                )  # Upper triangular matrix
                A = A.masked_fill(
                    upper_tri_mask, float("-inf")
                )  # Fill upper triangular with -inf
            A = F.softmax(A, dim=-1)
            A = self.dropout(A)
            V_hat = A @ V  # (b, n, n) @ (b, n, d_head) -> (b, n, d_head)
        return V_hat

## Multi-head Attention

A simple module that runs some number of attention heads in parallel
and concatenates the resulting vectors into a single, longer output vector.

Each parallel attention head works with vectors of size d_model / num_heads,
so d_model must be divisible by num_heads for the concatenated output to have size d_model.

There is also a simple linear transformation after the concatenation.
I'd be interested in running some ablations to see if this is necessary.
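The per-head loop is easy to read, but it launches num_heads separate attention calls. A common alternative fuses the heads into one projection per role and treats heads as an extra batch dimension. The sketch below assumes PyTorch 2.0+ and uses random matrices in place of the learned projections. It checks that both forms give the same result.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, n, d_model, num_heads = 2, 5, 12, 3
d_head = d_model // num_heads
X = torch.randn(b, n, d_model)
# One combined projection per role; columns h*d_head:(h+1)*d_head belong to head h
W_q, W_k, W_v = (torch.randn(d_model, d_model) / d_model**0.5 for _ in range(3))

# (1) Loop over heads and concatenate, as MultiHeadAttention does
outs = []
for h in range(num_heads):
    cols = slice(h * d_head, (h + 1) * d_head)
    outs.append(F.scaled_dot_product_attention(X @ W_q[:, cols], X @ W_k[:, cols], X @ W_v[:, cols]))
looped = torch.cat(outs, dim=-1)                          # (b, n, d_model)


# (2) Same computation with heads as an extra batch dimension
def split_heads(T):
    """(b, n, d_model) -> (b, num_heads, n, d_head)"""
    return T.view(b, n, num_heads, d_head).transpose(1, 2)


batched = F.scaled_dot_product_attention(split_heads(X @ W_q), split_heads(X @ W_k), split_heads(X @ W_v))
batched = batched.transpose(1, 2).reshape(b, n, d_model)  # merge heads back in concatenation order

print(torch.allclose(looped, batched, atol=1e-5))
```

The fused form typically runs faster on a GPU, because a few large matrix multiplies replace many small ones. The loop is kept in this notebook for readability.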


In [None]:
class MultiHeadAttention(nn.Module):
    """Runs multiple attention modules in parallel and concatenates the results."""

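    def __init__(self, num_heads, d_model, dropout_rate=0.2, use_masking=False, fastmax=False, fmm_p=2):
        super().__init__()
        # Concatenating num_heads outputs of size d_head must give back d_model
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        d_head = d_model // num_heads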
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(AttentionHead(d_model, d_head, dropout_rate, use_masking, fastmax, fmm_p))
        self.linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, X):
        # X shape: (b, n, d_model)
        heads_outputs = []
        for head in self.heads:
            heads_outputs.append(head(X))  # (b, n, d_head)
        outputs_concated = torch.cat(heads_outputs, dim=-1)  # (b, n, d_model)
        output = self.linear(outputs_concated)  # (b, n, d_model)
        output = self.dropout(output)
        return output

## Transformer Block

The basic element of a transformer, which is repeated some number of times
in the larger model. Each of these blocks does a round of attention
followed by an MLP with one hidden layer. The
attention in each block is split into multiple "heads" which allow for
different attention queries in parallel at each step.

In the MLP sub-block, there are no connections between tokens/features -
these are just functions mapping each token vector to an updated vector of
the same dimensions.

Following best practices for training deep networks, layer norm is applied
to the input of both the attention sub-block and the MLP sub-block (the
"pre-norm" arrangement). A residual connection adds each sub-block's input
back to its output.
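Written as equations, a standard pre-norm block with input X computes the following. This is a sketch of the usual formulation, with dropout omitted.

$$
\begin{aligned}
H &= X + \mathrm{MultiHead}\big(\mathrm{LN}_1(X)\big) \\
Y &= H + \mathrm{MLP}\big(\mathrm{LN}_2(H)\big), \qquad \mathrm{MLP}(z) = W_2\,\mathrm{ReLU}(W_1 z + b_1) + b_2
\end{aligned}
$$

The MLP is applied to each token vector separately. The second residual adds $H$ rather than $X$, so the attention output stays on the residual path into the next block.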

In [None]:
class TransformerBlock(nn.Module):
    """Basic element of a Transformer: Attention + MLP."""

    def __init__(self, num_heads, d_model, d_mlp, dropout_rate=0.2, use_masking=False, fastmax=False, fmm_p=2):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.attention_subblock = MultiHeadAttention(
            num_heads, d_model, dropout_rate, use_masking, fastmax, fmm_p
        )
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.mlp_subblock = nn.Sequential(
            nn.Linear(d_model, d_mlp),
            nn.ReLU(),
            nn.Linear(d_mlp, d_model),
            nn.Dropout(dropout_rate),
        )

    def forward(self, X):
        # X shape: (b, n, d_model)
        attention_output = self.layer_norm1(X)
        attention_output = self.attention_subblock(attention_output)  # (b, n, d_model)
        attention_output = attention_output + X  # residual connection
        mlp_output = self.layer_norm2(attention_output)
        mlp_output = self.mlp_subblock(mlp_output)  # (b, n, d_model)
        mlp_output = mlp_output + attention_output  # residual connection around the MLP sub-block
        return mlp_output

## Modality-specific embeddings and classifiers

In [None]:
# See modules.py
class PerTokenClassifier(nn.Module):
    """Linear classifier that outputs one class prediction per token."""
    def __init__(self, d_model, n_classes):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model)
        self.linear = nn.Linear(d_model, n_classes)
    def forward(self, X):
        
        X = self.layer_norm(X)  # (b, n, d_model)
        logits = self.linear(X)  # (b, n, n_classes)
        return logits
    
    
class AvgPoolClassifier(nn.Module):
    """Linear classifier that outputs one class prediction by averaging tokens."""
    def __init__(self, d_model, n_classes):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.linear = nn.Linear(d_model, n_classes)
    def forward(self, X):
        X = self.layer_norm(X)  # (b, n, d_model)
        X = X.permute(0, 2, 1)  # (b, d_model, n) to fit AdaptiveAvgPool1d input format
        X = self.pool(X).squeeze(-1)  # (b, d_model)
        logits = self.linear(X)  # (b, n_classes)
        return logits
    

class ConvEmbedding(nn.Module):
    """Embedding for images where each token is a patch of pixels convolved, n_kernels=d_model."""
    def __init__(self, channels, patch_shape, d_model):
        super().__init__()
        self.conv = nn.Conv2d(
            channels, d_model, kernel_size=patch_shape, stride=patch_shape
        )
    def forward(self, images):
        # images shape: (b, c, h, w)
        X = self.conv(images)  # (b, d_model, h // patch_h, w // patch_w)
        X = X.flatten(2)  # (b, d_model, n)
        return X.transpose(1,2)  # (b, n, d_model)
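Here is a quick shape check of the patch embedding for MNIST-sized inputs. It mirrors the settings used later (5x5 patches, d_model = 15) with a bare nn.Conv2d. Because the kernel size equals the stride, each output position corresponds to one non-overlapping patch. For 28x28 images, the conv yields floor((28 - 5) / 5) + 1 = 5 positions per side, so the last 3 rows and columns of pixels are not covered.

```python
import torch
from torch import nn

# Same layer as ConvEmbedding: kernel_size == stride == patch shape
channels, d_model, patch_shape = 1, 15, (5, 5)
conv = nn.Conv2d(channels, d_model, kernel_size=patch_shape, stride=patch_shape)

images = torch.randn(2, channels, 28, 28)    # a batch of two MNIST-sized images
X = conv(images)                             # (2, 15, 5, 5): a 5x5 grid of patches
tokens = X.flatten(2).transpose(1, 2)        # (2, 25, 15): 25 patch tokens per image
print(X.shape, tokens.shape)
```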

## Transformer Model

Simple Transformer model based on the paper "Attention Is All You Need".
Consists of some number of repeated, identical blocks. Each block does a
round of attention followed by an MLP with one
hidden layer. The attention in each block is split into multiple "heads"
which allow for different attention queries in parallel at each step.

There is an embedding layer at the beginning that maps each token/feature to a
d_model-dimension vector of real numbers. This is a lookup table (nn.Embedding)
for a discrete vocabulary, or a strided convolution over image patches. A learned
positional embedding is added to explicitly provide position information to
the embedding vector.

Finally there is a classification layer that maps from d_model
vectors to whatever is needed for the task at hand (e.g. a logit vector
for classification over a huge set of vocab_size possible classes).

A single transformer model is called a "decoder" if it uses a mask in
training to only learn attention connections from earlier tokens to later
tokens. An "encoder" omits this mask and learns the fullest possible
contextual representation of every token in the input. The decoder
restriction is useful for learning to make predictions on input sequences
shorter than the n_tokens the model was trained on.
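Here is a sketch of the input stage for a discrete vocabulary. It uses the same nn.Embedding layers as the class below, but with toy sizes. The token embedding and the learned positional embedding are summed, and the (n, d_model) positional term broadcasts across the batch.

```python
import torch
from torch import nn

torch.manual_seed(0)
vocab_size, max_features, d_model = 10, 8, 6
token_embedding = nn.Embedding(vocab_size, d_model)      # discrete token id -> vector
pos_embedding = nn.Embedding(max_features, d_model)      # position index -> learned vector

tokens = torch.tensor([[3, 3, 7], [1, 2, 3]])            # (b=2, n=3) token ids
tok_emb = token_embedding(tokens)                        # (2, 3, 6)
pos_emb = pos_embedding(torch.arange(tokens.shape[1]))   # (3, 6), broadcast over the batch
X = tok_emb + pos_emb                                    # (2, 3, 6)

# The same token id (3) at positions 0 and 1 now gets different input vectors
print(torch.allclose(X[0, 0], X[0, 1]))                  # False
```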

In [None]:
# See model.py
class Transformer(nn.Module):
    """
    Simple model consisting of repeated blocks of Attention + MLP.
    in_feature_dim: vector length of each individual token or feature. If using
        discrete features such as words, this will be vocab_size because each
        word can be represented an integer or by a one-hot vector of that length.
        If using image patches, simply the number of pixels in each flattened
        patch.
    out_dim: again, the desired vector length. If classification, just
        the number of classes. If word prediction, the vocab_size so we can
        output a probability over each word in the vocab.
    max_features: the maximum number of features/tokens/patches in the input. 
        This is only used to define a positional encoding that learns to 
        represent spatial positions up to that number. Can be set to a longer
        lenth than the actual intended number of tokens in the input.
    """
        
    def __init__(
        self,
        out_dim,
        max_features,
        d_model=512,
        d_mlp=2048,
        heads_per_block=8,
        num_blocks=6,
        dropout_rate=0.2,
        use_masking=False,
        embedding="linear",
        classifier="per_token",
        vocab_size=None,
        patch_shape=None,
        patch_channels=None,
        fastmax=False,
        fmm_p=2
    ):
        super().__init__()

        # Embedding layers: feature space -> d_model space (X)
        if embedding == "discrete_set":
            self.token_embedding = nn.Embedding(vocab_size, d_model)
        elif embedding == "patch_conv":
            self.token_embedding = ConvEmbedding(patch_channels, patch_shape, d_model)
        else:
            raise ValueError(f"Unknown embedding type: {embedding}")
        self.pos_embedding = nn.Embedding(max_features, d_model)
        
        # Transformer core: X -> X_hat
        self.blocks = nn.Sequential()
        for i in range(num_blocks):
            self.blocks.append(
                TransformerBlock(
                    heads_per_block, d_model, d_mlp, dropout_rate, use_masking, fastmax, fmm_p
                )
            )
        
        # Classifier head: X_hat -> logits
        if classifier == "per_token":
            self.classifier = PerTokenClassifier(d_model, out_dim)
        elif classifier == "avg_pool":
            self.classifier = AvgPoolClassifier(d_model, out_dim)
        else:
            raise ValueError(f"Unknown classifier type: {classifier}")
        
        # Weight initialization
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, tokens):
        # tokens shape: (b, n) integer ids for a discrete vocabulary,
        #               or (b, c, h, w) images for the patch_conv embedding
        # b: batch size
        # n: num tokens/patches after embedding
        # d_model: dim of vector representation of each token
        tok_emb = self.token_embedding(tokens)  # (b, n, d_model)
        # Count tokens after embedding: for images, tokens.shape[1] is the channel count, not n
        n = tok_emb.shape[1]
        position_indices = torch.arange(n, device=tokens.device)  # (n,)
        pos_emb = self.pos_embedding(position_indices)  # (n, d_model)
        X = tok_emb + pos_emb  # (b, n, d_model)
        X = self.blocks(X)  # (b, n, d_model)
        logits = self.classifier(X)  # (b, n, out_dim) per_token, or (b, out_dim) avg_pool
        return logits

## Training loop

Use the PyTorch Lightning library to automate logging and checkpointing.

In [None]:
# See train.py
import time
from datetime import timedelta

import lightning as L
from lightning.pytorch import callbacks

import torch
from torch.utils import data


class LightingWrapper(L.LightningModule):
    """Wraps a nn.Module with an associated loss function and optimizer."""

    def __init__(self, model, loss_fn, optim, lr):
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn
        self.optim = optim
        self.lr = lr

    def forward(self, X):
        return self.model(X)
    
    def configure_optimizers(self):
        return self.optim(self.model.parameters(), lr=self.lr)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # Text modality:
        # x: (b, n) where each token is an integer in [0, n_classes) (vocab_size)
        # y: (b, n) There are n labels for n tokens (the predicted next token). Each
        #    successive label gets a longer preceding series of tokens to use as its
        #    features due to masking in attention.
        # Image modality:
        # x: (b, c, h, w) images, split into patch tokens by the conv embedding
        # y: (b,) There is a single class label for each image.
        # The model handles both input shapes properly
        logits = self.model(x)  # (b, n, n_classes) or (b, n_classes) depending on classification head
        if isinstance(self.model.classifier, PerTokenClassifier):
            logits = logits.view(-1, logits.shape[-1])  # (b*n, n_classes)
        # This handles both output shapes
        y = y.view(-1)  # (b*n) or (b)
        loss = self.loss_fn(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # Identical to training step
        x, y = batch
        logits = self.model(x)
        if isinstance(self.model.classifier, PerTokenClassifier):
            logits = logits.view(-1, logits.shape[-1])
        y = y.view(-1)
        loss = self.loss_fn(logits, y)

        # Compute accuracy
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return loss
    

def train(
    model,
    train_loader,
    val_loader=None,
    epochs=1,
    loss_fn=torch.nn.functional.cross_entropy,  # callable(logits, targets) -> loss
    optim=torch.optim.AdamW,
    lr=1e-3,
    name="",
):
    """
    Given a model, dataset, loss function and optimizer, train the model.
    Uses PyTorch Lightning to handle logging, checkpointing, and GPU acceleration.
    """

    wrapped_model = LightingWrapper(model, loss_fn, optim, lr)
    ckpt = callbacks.ModelCheckpoint()
    trainer = L.Trainer(
        max_epochs=epochs,
        val_check_interval=1 / 4,
        accelerator="auto",
        callbacks=[ckpt],
        default_root_dir="logs/" + name,
    )

    # Run the training loop
    start_time = time.time()
    trainer.fit(wrapped_model, train_loader, val_loader)
    end_time = time.time()
    elapsed_time = str(timedelta(seconds=end_time - start_time))
    print(f"Training time: {elapsed_time}")

    return wrapped_model.model
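The reshape in training_step turns per-token prediction into ordinary classification over b*n examples. In the sketch below, random tensors stand in for the model output and labels. It checks that flattening gives the same mean loss as passing class scores in dimension 1, a layout that F.cross_entropy also accepts.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, n, n_classes = 2, 4, 5
logits = torch.randn(b, n, n_classes)            # per-token classifier output
y = torch.randint(0, n_classes, (b, n))          # one target class per token

# What training_step does: treat every token as its own example
flat_loss = F.cross_entropy(logits.view(-1, n_classes), y.view(-1))

# Equivalent form: classes in dimension 1, i.e. (b, C, n) logits with (b, n) targets
channel_loss = F.cross_entropy(logits.transpose(1, 2), y)
print(torch.allclose(flat_loss, channel_loss))   # True
```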

# Data prep

In [None]:
# See data.py
import json
import os
import requests
from torch.utils import data
from torchvision import datasets, transforms

def download_shakespear_data():
    url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    response = requests.get(url)
    text = response.text
    unique_chars = sorted(list(set(text)))
    token_to_int = {ch: i for i, ch in enumerate(unique_chars)}
    os.makedirs("data/shakespeare", exist_ok=True)
    with open("data/shakespeare/input.txt", "w") as f:
        f.write(text)
    with open("data/shakespeare/vocab.json", "w") as f:
        json.dump(token_to_int, f)


def encode(text):
    """create a mapping from unique vocab tokens to integers"""
    # Get vocab data if not already downloaded
    if not os.path.exists("data/input.txt"):
        download_shakespear_data()
    with open("data/vocab.json", "r") as f:
        token_to_int = json.load(f)
    encoded_text = [token_to_int[char] for char in text]
    return torch.tensor(encoded_text, dtype=torch.long)


def decode(tokens):
    """create a mapping from integers back to unique vocab tokens"""
    # Get vocab data if not already downloaded
    if not os.path.exists("data/shakespeare/input.txt"):
        download_shakespear_data()
    with open("data/shakespeare/vocab.json", "r") as f:
        token_to_int = json.load(f)
    int_to_token = {v: k for k, v in token_to_int.items()}
    token_list = [int_to_token[token] for token in tokens]
    return "".join(token_list)


class ShakespeareDataset(data.Dataset):
    def __init__(self, tokens_per_chunk):
        super().__init__()
        # Download data if not already downloaded
        if not os.path.exists("data/shakespeare/input.txt"):
            download_shakespear_data()
        with open("data/shakespeare/input.txt", "r") as f:
            text = f.read()
        with open("data/shakespeare/vocab.json", "r") as f:
            token_to_int = json.load(f)
        self.data = encode(text)
        self.vocab_size = len(token_to_int)
        self.block_size = tokens_per_chunk

    def __len__(self):
        # A single example is a non-overlapping chunk of block_size characters.
        # Subtract 1 so the last chunk still has a full-length shifted label.
        return (len(self.data) - 1) // self.block_size

    def __getitem__(self, i):
        # The corresponding label for each example is a chunk of characters of the same size,
        # but shifted one character to the right. Thus the task is to predict the next character
        # given all of the previous characters in a block. In this sense, the model learns to
        # generate predictions based on varying amounts of preceding characters, ranging
        # from just a single character to the entire block.
        start = i * self.block_size  # index i selects the i-th chunk, not the i-th character
        x = self.data[start : start + self.block_size]
        y = self.data[start + 1 : start + self.block_size + 1]
        return x, y


def get_data(dataset_name, batch_size, n_features=None, train_ratio=0.9):
    """
    Get dataloaders for a given dataset.
    Returns:
        train_loader: a DataLoader for the training set
        val_loader: a DataLoader for the validation set
        feature_dim: the dimensionality of each feature one example from the dataset
        n_classes: the number of classes in the dataset
    Feature_dim and n_classes are required for initializing the Transformer model.
    """
    if dataset_name == "shakespeare":
        n_features = n_features or 32
        dataset = ShakespeareDataset(n_features)
        n = len(dataset)
        train_size = int(train_ratio * n)
        test_size = n - train_size
        train_data, val_data = data.random_split(dataset, [train_size, test_size])
        train_loader = data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
        val_loader = data.DataLoader(val_data, batch_size=batch_size)
    
    elif dataset_name == "mnist":
        train_data = datasets.MNIST(root="data/mnist", train=True, download=True, transform=transforms.ToTensor())
        val_data = datasets.MNIST(root="data/mnist", train=False, download=True, transform=transforms.ToTensor())
        train_loader = data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
        val_loader = data.DataLoader(val_data, batch_size=batch_size)

    else:
        raise ValueError(f"Dataset '{dataset_name}' not supported. Try 'shakespeare' or 'mnist'.")
    return train_loader, val_loader
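Here is a toy version of the chunking that ShakespeareDataset performs, assuming non-overlapping chunks. Each target chunk is its input chunk shifted one character to the right. The text and block size are made up for illustration. The names encoded and decode_chars are hypothetical and chosen to avoid shadowing the notebook's data module and decode function.

```python
import torch

text = "to be or not to be"
vocab = sorted(set(text))
token_to_int = {ch: i for i, ch in enumerate(vocab)}
int_to_token = {i: ch for ch, i in token_to_int.items()}
encoded = torch.tensor([token_to_int[ch] for ch in text], dtype=torch.long)


def decode_chars(ids):
    """Turn a 1-D tensor of ids back into a string."""
    return "".join(int_to_token[int(i)] for i in ids)


block_size = 4
n_chunks = (len(encoded) - 1) // block_size          # leave room for the shifted target
chunks = []
for i in range(n_chunks):
    start = i * block_size                            # non-overlapping chunks
    x = encoded[start : start + block_size]
    y = encoded[start + 1 : start + block_size + 1]   # same window, one character later
    chunks.append((x, y))

x, y = chunks[0]
print(repr(decode_chars(x)), "->", repr(decode_chars(y)))  # 'to b' -> 'o be'
```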

## Run the experiment

In [None]:
# See run_image_example.py

# Modality-specific parameters (images)
patch_shape = (5, 5)  # 5x5 pixel patches
n_patches_per_image = 25
# With 5x5 patches and stride 5 on 28x28 images, we fit 28 // 5 = 5 patches per side,
# i.e. a 5x5 grid of 25 patches (the last 3 rows and columns of pixels are dropped).
# The core transformer is agnostic of this number, but it learns a conv embedding based on the patch size.
# So the patch shape at training must remain the same at inference, but different size images or different
# conv strides can be used at inference time, as long as the patch count stays within max_features.
# Attention cost is O(n^2) in this number.
embedding = "patch_conv" # Learnable conv kernels that map patch pixels to d_model vectors.
classifier = "avg_pool" # Average all n patch vectors into a single vector and do linear classification on that vector.

# Model hyperparameters
max_features = n_patches_per_image # Used for defining pos_encoding. Must be at least n_patches, but could be more.
d_model = 15
d_mlp = 24
heads_per_block = 5
num_blocks = 3
dropout_rate = 0.0

# Training parameters
lr = 1e-3
batch_size = 32
epochs = 2
loss_fn = F.cross_entropy
optim = torch.optim.AdamW

# FMM test
fastmax = True
p = 2


if __name__ == "__main__":
    # Create dataloaders
    dataset_name = 'mnist'
    train_loader, test_loader = get_data(dataset_name, batch_size)

    # MNIST info
    n_classes = 10
    channels = 1

    # Create model
    model = Transformer(
        out_dim=n_classes,
        max_features=max_features,
        d_model=d_model,
        d_mlp=d_mlp,
        heads_per_block=heads_per_block,
        num_blocks=num_blocks,
        dropout_rate=dropout_rate,
        use_masking=False,
        embedding=embedding,
        classifier=classifier,
        patch_shape=patch_shape,
        patch_channels=channels,
        fastmax=fastmax,
        fmm_p=p
    )

    # Train
    model = train(
        model,
        train_loader,
        test_loader,
        epochs=epochs,
        loss_fn=loss_fn,
        optim=optim,
        lr=lr,
        name=dataset_name,
    )