In this notebook, we'll explore a "functional" transformer implementation that can be efficiently executed using Thunder. This approach offers several advantages in terms of optimization and code clarity.

Python's expressivity and flexibility make it an excellent choice for developing deep learning models. However, these same qualities can hinder performance optimization and make existing codebases harder to understand or modify.
Recently, there has been a growing trend towards developing models in a simpler, more transparent style that facilitates optimization and comprehension. Projects like [LitGPT](https://github.com/Lightning-AI/litgpt) and [nanoGPT](https://github.com/karpathy/nanoGPT) are examples of this trend.
One such style is the "functional" programming style, which is free of side effects and can be easily understood and optimized by both developers and compilers.

We'll cover the following key points:
* The structure and implementation of a functional transformer
* Advantages of this approach compared to traditional implementations
* How Thunder can be applied to optimize and execute this functional transformer

By the end of this notebook, you'll have a clear understanding of how functional programming principles can be leveraged to create more efficient and compiler-friendly transformer models.

**Credit**: The code used in this notebook is adapted from https://gist.github.com/nreHieW/a4ae05d216c5326c9fb9a70fcdda3274 

To transform a PyTorch module into a "functional" Python implementation, we need to restructure how neural network parameters are handled. Instead of relying on class members and `nn.Module`, we'll pass the module's parameters explicitly as function inputs. This approach enhances transparency and makes the code more amenable to optimization.
To maintain clean and organized code, we'll use named tuples to group related parameters together.
Let's examine the helper named tuples we'll use to organize our transformer's parameters:

In [2]:
import torch
from typing import NamedTuple

# Helper classes that group the parameters together
class LayerWeights(NamedTuple):
    input_norm: torch.Tensor  # (hidden)
    post_attn_norm: torch.Tensor  # (hidden)
    q_proj: torch.Tensor  # (hidden, q_intermediate)
    k_proj: torch.Tensor  # (hidden, kv_intermediate)
    v_proj: torch.Tensor  # (hidden, kv_intermediate)
    o_proj: torch.Tensor  # (q_intermediate, hidden)
    gate_proj: torch.Tensor  # (hidden, intermediate)
    up_proj: torch.Tensor  # (hidden, intermediate)
    down_proj: torch.Tensor  # (intermediate, hidden)


class TransformerWeights(NamedTuple):
    layers: list[LayerWeights]
    token_emb: torch.Tensor  # (vocab_size, hidden)
    final_norm: torch.Tensor  # (hidden)
    lm_head: torch.Tensor  # (hidden, vocab_size)
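
Because the containers are plain named tuples, they can be filled with small random tensors to try things out before loading a real checkpoint. The sketch below is only an illustration: `TinyLayerWeights` and `random_layer_weights` are hypothetical names that are not part of the original code, and the sizes are arbitrary values chosen to be divisible by the 32 query heads and 8 key/value heads used later in this notebook.

```python
import torch
from typing import NamedTuple


class TinyLayerWeights(NamedTuple):  # hypothetical copy with the same fields as LayerWeights
    input_norm: torch.Tensor  # (hidden)
    post_attn_norm: torch.Tensor  # (hidden)
    q_proj: torch.Tensor  # (hidden, q_intermediate)
    k_proj: torch.Tensor  # (hidden, kv_intermediate)
    v_proj: torch.Tensor  # (hidden, kv_intermediate)
    o_proj: torch.Tensor  # (q_intermediate, hidden)
    gate_proj: torch.Tensor  # (hidden, intermediate)
    up_proj: torch.Tensor  # (hidden, intermediate)
    down_proj: torch.Tensor  # (intermediate, hidden)


def random_layer_weights(hidden, q_inter, kv_inter, inter, seed=0):
    """Hypothetical helper: small random weights with the documented shapes."""
    g = torch.Generator().manual_seed(seed)

    def w(*shape):
        return torch.randn(*shape, generator=g) * 0.02

    return TinyLayerWeights(
        input_norm=torch.ones(hidden),
        post_attn_norm=torch.ones(hidden),
        q_proj=w(hidden, q_inter),
        k_proj=w(hidden, kv_inter),
        v_proj=w(hidden, kv_inter),
        o_proj=w(q_inter, hidden),
        gate_proj=w(hidden, inter),
        up_proj=w(hidden, inter),
        down_proj=w(inter, hidden),
    )


layer = random_layer_weights(hidden=64, q_inter=64, kv_inter=16, inter=128)

# A named tuple iterates like a regular tuple, so counting parameters is a one-liner.
n_params = sum(t.numel() for t in layer)
print(layer.q_proj.shape, n_params)

# Fields cannot be reassigned in place; `_replace` returns a new tuple instead.
frozen = layer._replace(q_proj=layer.q_proj.detach())
```

Immutability is part of what keeps this style free of side effects: a function that receives the tuple cannot swap out a field behind the caller's back, although it could still mutate a tensor in place.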

Next, we'll write all the layers used in the transformer using functional implementations. This process involves converting each layer from a PyTorch module into a pure function (a function with no side effects) that takes both the input data and the relevant parameters as explicit arguments. 
In the following sections, we'll walk through the functional implementations of key transformer components, including:
* Normalization (RMSNorm)
* Feed-forward network
* Embedding layer
* Attention mechanism

Each function will clearly define its inputs, including both the data to be processed and the necessary parameters. This approach will provide a comprehensive view of how data flows through the transformer architecture.
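
To make the module-to-function conversion concrete, here is a minimal sketch on a single linear layer. `LinearWeights` and `linear` are hypothetical names used only for this illustration; the notebook's own layers follow the same pattern with larger parameter groups.

```python
import torch
from typing import NamedTuple


class LinearWeights(NamedTuple):  # hypothetical container for this example
    weight: torch.Tensor  # (in_features, out_features), already transposed
    bias: torch.Tensor  # (out_features)


def linear(x: torch.Tensor, weights: LinearWeights) -> torch.Tensor:
    # Pure function: the output depends only on the arguments.
    return x @ weights.weight + weights.bias


torch.manual_seed(0)
module = torch.nn.Linear(8, 4)

# nn.Linear stores its weight as (out_features, in_features), hence the transpose.
weights = LinearWeights(module.weight.detach().t(), module.bias.detach())

x = torch.randn(2, 8)
same = torch.allclose(linear(x, weights), module(x), atol=1e-5)
print(same)
```

The module and the function compute the same result; the difference is that the function's parameters are visible in its signature instead of being hidden in object state.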

In [3]:
from torch.nn.functional import silu, softmax

NUM_Q_HEADS = 32  # Llama numbers
NUM_KV_HEADS = 8  # Llama numbers
SLIDING_WINDOW_SIZE = 4096

# RMS normalization (RMSNorm): scales by the root mean square, with no mean subtraction or bias
def norm(x: torch.Tensor, weight: torch.Tensor):
    in_dtype = x.dtype
    compute_dtype = torch.float32
    x = x.to(compute_dtype)
    eps = 1e-5  # eps might change depending on the model
    out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return weight * out.to(in_dtype)


# Feed-forward network
def ffn(x: torch.Tensor, weights: LayerWeights):
    gate = silu(x @ weights.gate_proj)
    fused = gate * (x @ weights.up_proj)
    return fused @ weights.down_proj


# Rotary Positional Encoding
def rope(x: torch.Tensor, freqs_cis: torch.Tensor):
    def rotate(x):
        """
        rotate(torch.arange(4))
        > tensor([-2, -3,  0,  1])
        """
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    cos, sin = freqs_cis
    cos = cos.to(device=x.device, dtype=x.dtype)  # freqs_cis is created on the CPU by default
    sin = sin.to(device=x.device, dtype=x.dtype)
    right = rotate(x.reshape(*x.shape[:-1], -1, 2)).reshape(x.shape)
    out = x * cos + right * sin
    return out.to(x.dtype)


# Attention
def attn(
    x: torch.Tensor,
    weights: LayerWeights,
    freqs_cis: tuple,
    sliding_window_size=None,
):
    bs, seq_len, d_model = x.shape
    xq, xk, xv = x @ weights.q_proj, x @ weights.k_proj, x @ weights.v_proj
    xq = xq.view(bs, seq_len, NUM_Q_HEADS, -1).transpose(1, 2)  # (bs, NUM_Q_HEADS, seq_len, head_dim)
    xk = xk.view(bs, seq_len, NUM_KV_HEADS, -1).transpose(1, 2)  # (bs, NUM_KV_HEADS, seq_len, head_dim)
    xv = xv.view(bs, seq_len, NUM_KV_HEADS, -1).transpose(1, 2)  # (bs, NUM_KV_HEADS, seq_len, head_dim)
    head_dim = xq.shape[-1]

    # Treat GQA as MHA and just repeat along the head dimension
    xk = torch.repeat_interleave(xk, NUM_Q_HEADS // NUM_KV_HEADS, dim=1)
    xv = torch.repeat_interleave(xv, NUM_Q_HEADS // NUM_KV_HEADS, dim=1)
    xq = rope(xq, freqs_cis)
    xk = rope(xk, freqs_cis)

    attn_scores = (xq @ xk.transpose(2, 3)) * (head_dim**-0.5)
    # (seq_len, seq_len) causal mask; it broadcasts over the batch and head dimensions. The constant is taken from Gemma.
    mask = torch.triu(torch.full((seq_len, seq_len), -2.3819763e38), diagonal=1)
    if sliding_window_size is not None:  # Sliding window attention
        all_ones = torch.ones((seq_len, seq_len))
        sliding_mask = torch.triu(all_ones, -1 * sliding_window_size + 1) * torch.tril(all_ones, sliding_window_size - 1)
        mask = torch.where(sliding_mask == 1, mask, -2.3819763e38)
    mask = mask.to(x.device, x.dtype)
    attn_scores = attn_scores + mask
    attn_probs = softmax(attn_scores, dim=-1)
    attn_out = attn_probs @ xv
    attn_out = attn_out.transpose(1, 2).contiguous().view(bs, seq_len, -1)
    return attn_out @ weights.o_proj


# for efficiency, should precompute for 0..max_length * 2 then select [:curr_length]
def precompute_freqs_cis(head_dim: int, seq_len: int, base_theta: float = 500000.0):
    inv_freqs = 1.0 / (base_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))  # Eq 15: theta_{1} ... theta_{dim/2}. Shape: (dim/2)
    m = torch.arange(seq_len)  # all possible position indices
    freqs = torch.outer(m, inv_freqs).float()  # [m_i * theta_j] for all i (positions) and j (frequencies). Shape: (seq_len, dim/2) | freqs[i][j] == m[i] * inv_freqs[j]
    cos = torch.cos(freqs)  # Shape: (seq_len, dim/2)
    cos = torch.repeat_interleave(cos, 2, dim=-1)  # Shape: (seq_len, dim)
    sin = torch.sin(freqs)  # Shape: (seq_len, dim/2)
    sin = torch.repeat_interleave(sin, 2, dim=-1)  # Shape: (seq_len, dim)
    return (cos, sin)


def transformer(in_tokens: torch.Tensor, weights: TransformerWeights):
    x = torch.nn.functional.embedding(in_tokens, weights.token_emb)
    b, t, d = x.shape
    q_intermediate = weights.layers[0].q_proj.shape[1]
    freqs_cis = precompute_freqs_cis(q_intermediate // NUM_Q_HEADS, t)  # (cos, sin)
    for i, layer in enumerate(weights.layers):
        residual = x
        hidden = norm(x, layer.input_norm)
        hidden = attn(hidden, layer, freqs_cis, sliding_window_size=SLIDING_WINDOW_SIZE if i % 6 != 0 else None)  # Follows https://research.character.ai/optimizing-inference/
        hidden = residual + hidden

        residual = hidden
        hidden = norm(hidden, layer.post_attn_norm)
        hidden = ffn(hidden, layer)
        hidden = residual + hidden
        x = hidden

    x = norm(x, weights.final_norm)
    x = x @ weights.lm_head
    return x
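
The masking logic in `attn` is the least obvious part of the code above, so it helps to look at it in isolation. The sketch below repeats the same `triu`/`tril` construction for a tiny sequence, building the mask as a `(seq_len, seq_len)` matrix. `causal_sliding_mask` is a hypothetical helper written for this illustration and is not part of the original code.

```python
import torch

NEG = -2.3819763e38  # same large negative constant as in the attention function


def causal_sliding_mask(seq_len, window=None):
    # Causal part: block every key position j > i (strictly above the diagonal).
    mask = torch.triu(torch.full((seq_len, seq_len), NEG), diagonal=1)
    if window is not None:
        ones = torch.ones((seq_len, seq_len))
        # Band of width `window` around the diagonal: |i - j| <= window - 1.
        band = torch.triu(ones, -window + 1) * torch.tril(ones, window - 1)
        # Outside the band everything is blocked; inside it the causal mask applies.
        mask = torch.where(band == 1, mask, NEG)
    return mask


# 1 marks a key position that the query in that row may attend to.
visible_causal = (causal_sliding_mask(5) == 0).int()
visible_window = (causal_sliding_mask(5, window=2) == 0).int()
print(visible_causal)
print(visible_window)
```

With `window=2`, each query sees itself and the one token before it, so the window size counts the current position. Without a window, the mask reduces to the usual lower-triangular causal pattern.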

Once the functional versions of the layers and transformer are ready, we'll load the weights for each layer into our `LayerWeights` and `TransformerWeights` container classes. These classes will store the parameters so that they can be easily passed as inputs to the functional transformer.

**NOTE**: To run the cells below, you'll need access to the Meta-Llama-3-8B model on Hugging Face. Download the model weights and place them at `Meta-Llama-3-8B/consolidated.00.pth`. See the [model page](https://huggingface.co/meta-llama/Meta-Llama-3-8B) to learn more about Meta-Llama-3-8B.

In [4]:
from transformers import AutoTokenizer
import thunder

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
# Load the official repo weights (they must already be downloaded, see the note above)
state_dict = torch.load("Meta-Llama-3-8B/consolidated.00.pth", map_location="cuda")
layers = []
n_layers = 32
for i in range(n_layers):
    layer = LayerWeights(
        input_norm=state_dict[f"layers.{i}.attention_norm.weight"],
        post_attn_norm=state_dict[f"layers.{i}.ffn_norm.weight"],
        q_proj=state_dict[f"layers.{i}.attention.wq.weight"].t(),
        k_proj=state_dict[f"layers.{i}.attention.wk.weight"].t(),
        v_proj=state_dict[f"layers.{i}.attention.wv.weight"].t(),
        o_proj=state_dict[f"layers.{i}.attention.wo.weight"].t(),
        gate_proj=state_dict[f"layers.{i}.feed_forward.w1.weight"].t(),
        up_proj=state_dict[f"layers.{i}.feed_forward.w3.weight"].t(),
        down_proj=state_dict[f"layers.{i}.feed_forward.w2.weight"].t(),
    )
    layers.append(layer)

weights = TransformerWeights(
    layers=layers,
    token_emb=state_dict["tok_embeddings.weight"],
    final_norm=state_dict["norm.weight"],
    lm_head=state_dict["output.weight"].t(),
)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Finally, we’ll use Thunder to execute the "functional" transformer and observe the results. 

In [5]:
prompt = "the answer to the ultimate question of life "
in_tokens = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda")

# Use thunder on the "functional" transformer
jitted_transformer = thunder.jit(transformer)
for _ in range(10):
    out = jitted_transformer(in_tokens, weights)
    next_token = torch.argmax(out[:, -1, :], dim=-1, keepdim=True)  # greedy decoding, shape (bs, 1)
    in_tokens = torch.cat((in_tokens, next_token), dim=1)

del weights
del state_dict

print("Ours:", tokenizer.decode(in_tokens[0].tolist()))

Ours: <|begin_of_text|>the answer to the ultimate question of life 42
the answer to the ultimate question of life
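
The generation loop above is plain greedy decoding: take the logits at the last position, pick the most likely token, append it, and repeat. As a sketch, the loop can be factored into a small helper that also handles batches larger than one. `greedy_generate` and `toy_model` are hypothetical names introduced for this illustration; `toy_model` is a stand-in that replaces the transformer so the example runs without any weights.

```python
import torch


def greedy_generate(model_fn, in_tokens: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    """Hypothetical helper: append the argmax token `max_new_tokens` times."""
    for _ in range(max_new_tokens):
        logits = model_fn(in_tokens)  # (batch, seq_len, vocab_size)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (batch, 1)
        in_tokens = torch.cat((in_tokens, next_token), dim=1)
    return in_tokens


def toy_model(tokens: torch.Tensor, vocab_size: int = 10) -> torch.Tensor:
    # Stand-in for the transformer: always predicts (token + 1) % vocab_size.
    target = (tokens + 1) % vocab_size
    return torch.nn.functional.one_hot(target, vocab_size).float()


out = greedy_generate(toy_model, torch.tensor([[1, 2], [7, 8]]), max_new_tokens=3)
print(out)
```

Assuming the `jitted_transformer` and `weights` objects from the cell above are still in scope, the same helper could be called as `greedy_generate(lambda t: jitted_transformer(t, weights), in_tokens, 10)`.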


And that's it! Converting a PyTorch module into a functional Python implementation is easier than you might think, and Thunder works seamlessly with this functional version of the transformer.