## Implementing Llama2 3B


In [3]:
import torch
import torch.nn as nn

### **Normalization**

One of the main divergences from the original GPT transformer architecture is the normalization technique. Normalization recenters and/or rescales the activations flowing into each layer so that their scale stays stable from layer to layer, which helps stabilize training and speed up convergence.

Traditional **LayerNorm** normalizes each token's vector using the mean and variance computed across the feature dimension. It has one hyperparameter, $\epsilon$, a small constant that keeps the denominator away from zero, and two learnable parameters: $\gamma$, a per-feature scale, and $\beta$, a per-feature shift.
$$y=\frac{x-E[x]}{\sqrt{Var(x) + \epsilon}} * \gamma + \beta$$
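As a minimal sketch, the LayerNorm formula can be written out by hand and checked against PyTorch's built-in `torch.nn.functional.layer_norm`. The input shape, `eps`, and the random `gamma`/`beta` values are illustrative assumptions. LayerNorm uses the biased variance (dividing by $n$), so the variance is computed as a plain mean of squared deviations.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(4, 8)    # (tokens, features), arbitrary example input
eps = 1e-5
gamma = torch.rand(8)   # learnable scale (random here for illustration)
beta = torch.rand(8)    # learnable shift (random here for illustration)

# y = (x - E[x]) / sqrt(Var(x) + eps) * gamma + beta, computed per row
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)  # biased variance, as LayerNorm uses
y_manual = (x - mean) / torch.sqrt(var + eps) * gamma + beta

# PyTorch's built-in LayerNorm over the last dimension
y_builtin = F.layer_norm(x, normalized_shape=(8,), weight=gamma, bias=beta, eps=eps)

print(torch.allclose(y_manual, y_builtin, atol=1e-5))  # True
```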

In the Llama 2 architecture, **RMSNorm** (root mean square normalization) is used instead. RMSNorm skips the mean subtraction and divides only by the root mean square of the features, which makes it cheaper to compute than LayerNorm, and in practice any drop in performance is negligible. Notably, RMSNorm has no shift parameter $\beta$.

$$y_i=\frac{x_i}{\sqrt{\epsilon +\frac{1}{n}\sum_{j=1}^{n}{x_j^2}}}*\gamma_i$$
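To see how the two formulas relate, here is a small illustrative sketch (the helper name `rms_norm` and the shapes are assumptions) that implements the RMSNorm equation directly. For rows whose features already have zero mean, $\mathrm{Var}(x)$ equals the mean of $x^2$, so RMSNorm matches LayerNorm with $\beta = 0$; on uncentered inputs they differ because RMSNorm never subtracts the mean.

```python
import torch
import torch.nn.functional as F

def rms_norm(x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # y_i = x_i / sqrt(eps + mean_j(x_j^2)) * gamma_i, over the last dimension
    rms = torch.sqrt(eps + x.pow(2).mean(dim=-1, keepdim=True))
    return x / rms * gamma

torch.manual_seed(0)
x = torch.randn(3, 8)
gamma = torch.ones(8)

# Center each row so its feature mean is zero
x_centered = x - x.mean(dim=-1, keepdim=True)

# For zero-mean rows, LayerNorm (no bias) and RMSNorm agree
ln = F.layer_norm(x_centered, (8,), weight=gamma, bias=None, eps=1e-5)
print(torch.allclose(rms_norm(x_centered, gamma), ln, atol=1e-5))  # True

# On uncentered rows they differ, since RMSNorm does not subtract the mean
print(torch.allclose(rms_norm(x, gamma), F.layer_norm(x, (8,), eps=1e-5), atol=1e-5))  # False
```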


In [21]:
class RMSNorm(nn.Module):
    def __init__(self, embd_dim:int, eps=1e-5):
        super().__init__()
        self.eps = eps # epsilon
        self.embd_dim = embd_dim
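        # gamma: one learnable scale per feature, initialized to 1
        self.weight = nn.Parameter(torch.ones(embd_dim, dtype=torch.float32))
    
    def forward(self, x:torch.Tensor):
        # Compute in float32 for numerical stability, then cast back to the input dtype
        x_f = x.float()
        means = x_f.pow(2).mean(dim=-1, keepdim=True)  # mean of squares over features
        x_norm = x_f * torch.rsqrt(means + self.eps)   # x / RMS(x)
        return (x_norm * self.weight).to(dtype=x.dtype)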

##### Forward Pass Step Through


In [22]:
inputs = torch.rand((4, 8))
rmsn = RMSNorm(embd_dim=8)
inputs

tensor([[0.0189, 0.5897, 0.3435, 0.6396, 0.5835, 0.6628, 0.8831, 0.2713],
        [0.7535, 0.1432, 0.0761, 0.6367, 0.8376, 0.4416, 0.8005, 0.3437],
        [0.9414, 0.9058, 0.2968, 0.6923, 0.4357, 0.1979, 0.1784, 0.7046],
        [0.1627, 0.2033, 0.6560, 0.0777, 0.2667, 0.8596, 0.5893, 0.6754]])

In [23]:
# Current unnormalized sum
inputs.sum(dim=-1)

tensor([3.9924, 4.0329, 4.3529, 3.4907])

In [24]:
# Square each element, then take the mean along the feature dimension
inputs_mean = inputs.pow(2).mean(dim=1, keepdim=True)
inputs_mean

tensor([[0.3135],
        [0.3319],
        [0.3789],
        [0.2647]])

In [25]:
# Add epsilon and take sqrt
torch.sqrt(inputs_mean + rmsn.eps)

tensor([[0.5600],
        [0.5761],
        [0.6156],
        [0.5145]])

In [26]:
# But since we need x / RMS(x), let's take the reciprocal square root instead
torch.rsqrt(inputs_mean + rmsn.eps)

tensor([[1.7858],
        [1.7359],
        [1.6245],
        [1.9436]])

In [27]:
# now multiply by the numerator x (inputs in this case) 
inputs_norm = inputs * torch.rsqrt(inputs_mean + rmsn.eps)
inputs_norm

tensor([[0.0337, 1.0531, 0.6135, 1.1422, 1.0420, 1.1836, 1.5770, 0.4846],
        [1.3081, 0.2486, 0.1322, 1.1052, 1.4539, 0.7665, 1.3895, 0.5966],
        [1.5292, 1.4715, 0.4822, 1.1246, 0.7078, 0.3215, 0.2898, 1.1447],
        [0.3161, 0.3952, 1.2751, 0.1509, 0.5183, 1.6706, 1.1454, 1.3128]])

In [28]:
# Finally, multiply by gamma_i, the learnable weights (initialized to ones, so the values are unchanged here)
inputs_norm * rmsn.weight

tensor([[0.0337, 1.0531, 0.6135, 1.1422, 1.0420, 1.1836, 1.5770, 0.4846],
        [1.3081, 0.2486, 0.1322, 1.1052, 1.4539, 0.7665, 1.3895, 0.5966],
        [1.5292, 1.4715, 0.4822, 1.1246, 0.7078, 0.3215, 0.2898, 1.1447],
        [0.3161, 0.3952, 1.2751, 0.1509, 0.5183, 1.6706, 1.1454, 1.3128]],
       grad_fn=<MulBackward0>)

In [29]:
rms = RMSNorm(inputs.size(-1))
rms(inputs)

tensor([[0.0337, 1.0531, 0.6135, 1.1422, 1.0420, 1.1836, 1.5770, 0.4846],
        [1.3081, 0.2486, 0.1322, 1.1052, 1.4539, 0.7665, 1.3895, 0.5966],
        [1.5292, 1.4715, 0.4822, 1.1246, 0.7078, 0.3215, 0.2898, 1.1447],
        [0.3161, 0.3952, 1.2751, 0.1509, 0.5183, 1.6706, 1.1454, 1.3128]],
       grad_fn=<MulBackward0>)

### **Activation Functions**

Activation functions are non-linear functions placed between the linear layers of a neural network. Without them, a network could only learn linear relationships: stacking two linear layers with no activation in between is equivalent to a single linear layer, because $W_2(W_1x) = (W_2W_1)x$, so the two layers collapse into one.
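The collapse can be checked numerically. In this sketch (layer sizes and random weights are illustrative), two stacked `nn.Linear` layers with no activation between them are folded into a single layer with weight $W_2 W_1$ and bias $W_2 b_1 + b_2$, and both produce the same output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer1 = nn.Linear(8, 16)
layer2 = nn.Linear(16, 4)
x = torch.rand(5, 8)

# Two linear layers applied back to back, with no activation in between
stacked = layer2(layer1(x))

# Fold them into one linear layer: W = W2 @ W1, b = W2 @ b1 + b2
merged = nn.Linear(8, 4)
with torch.no_grad():
    merged.weight.copy_(layer2.weight @ layer1.weight)
    merged.bias.copy_(layer2.weight @ layer1.bias + layer2.bias)

print(torch.allclose(stacked, merged(x), atol=1e-5))  # True
```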

There are a variety of activation functions. A simple one is ReLU, which clips every negative input to 0 and leaves positive inputs unchanged. It became popular because its gradient is 1 for all positive inputs, which helps mitigate the _vanishing gradient_ problem. More recent activation functions such as GELU and SiLU are smooth, and they often perform better than ReLU in practice. Because negative inputs are not immediately cut off to 0 as they are with ReLU, these smoother functions allow for more nuanced learning.
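A quick comparison (the example inputs are arbitrary) of how ReLU and SiLU treat negative inputs: ReLU outputs exactly zero there and its gradient is zero, while SiLU lets small negative values through and keeps a non-zero gradient.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-3.0, -1.0, -0.5, 0.0, 0.5, 1.0, 3.0], requires_grad=True)

relu_out = F.relu(x)
silu_out = F.silu(x)
print(relu_out)  # negatives clipped to 0
print(silu_out)  # negatives shrunk toward 0, not clipped

# Gradient of each activation with respect to its input
relu_grad, = torch.autograd.grad(relu_out.sum(), x)
silu_grad, = torch.autograd.grad(silu_out.sum(), x)
print(relu_grad)  # 0 for every x < 0
print(silu_grad)  # non-zero for x < 0
```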

Llama 2 uses **SwiGLU**, a gated linear unit (GLU) variant built on the sigmoid-weighted linear unit (SiLU).

**SiLU**

$$silu(x) = x * \sigma(x)$$

**SwiGLU**

$$SwiGLU(x) = SiLU(Linear_1(x)) * Linear_2(x)$$

In PyTorch, SiLU is available as the built-in module `nn.SiLU`:


In [19]:
nn.SiLU()

SiLU()
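To confirm that `nn.SiLU` matches the formula above, this sketch compares it with $x \cdot \sigma(x)$ written out using `torch.sigmoid`, then builds the SwiGLU expression from two bias-free linear layers. The layer names `linear_1`/`linear_2` and the sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 8)

# SiLU: silu(x) = x * sigmoid(x)
silu = nn.SiLU()
print(torch.allclose(silu(x), x * torch.sigmoid(x), atol=1e-6))  # True

# SwiGLU(x) = SiLU(Linear_1(x)) * Linear_2(x), an element-wise gated product
linear_1 = nn.Linear(8, 16, bias=False)
linear_2 = nn.Linear(8, 16, bias=False)
swiglu = silu(linear_1(x)) * linear_2(x)
print(swiglu.shape)  # torch.Size([2, 16])
```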

### **Feed Forward Layer**

The feed-forward layer is where the network gets to "think about" each token individually: it is applied to every token position independently, with no mixing across positions.

In the feed-forward network, Llama 2 uses **SwiGLU**, a gated linear unit variant of sigmoid-weighted linear units.

**Recall SiLU is**:

$$silu(x) = x * \sigma(x)$$

**SwiGLU**:

$$SwiGLU(x) = SiLU(Linear_1(x)) * Linear_2(x)$$
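Because the feed-forward layer acts on each token independently, running it on a whole sequence gives the same result as running it on one token at a time. This sketch checks that with a hypothetical SwiGLU block written with `torch.nn.functional.silu` and `F.linear`; the weight names `w1`, `w2`, `w3` and the sizes are assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
emb_dim, hidden_dim, seq_len = 8, 16, 5
w1 = torch.randn(hidden_dim, emb_dim)  # gated projection (Linear_1)
w2 = torch.randn(hidden_dim, emb_dim)  # second projection (Linear_2)
w3 = torch.randn(emb_dim, hidden_dim)  # projection back to emb_dim

def swiglu_ffn(x: torch.Tensor) -> torch.Tensor:
    # SiLU(x W1^T) * (x W2^T), then project back to emb_dim
    return F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w2), w3)

x = torch.randn(seq_len, emb_dim)

whole_sequence = swiglu_ffn(x)
token_by_token = torch.stack([swiglu_ffn(x[t]) for t in range(seq_len)])

print(torch.allclose(whole_sequence, token_by_token, atol=1e-4))  # True
```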


In [50]:
class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
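        # Three linear layers: fc1 and fc2 project emb_dim -> hidden_dim,
        # SiLU is applied to fc1's output to gate fc2's output, and fc3 projects back to emb_dim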
        self.fc1 = nn.Linear(config.emb_dim, config.hidden_dim, dtype=config.dtype, bias=config.ff_bias)
        self.fc2 = nn.Linear(config.emb_dim, config.hidden_dim, dtype=config.dtype, bias=config.ff_bias)
        self.fc3 = nn.Linear(config.hidden_dim, config.emb_dim, dtype=config.dtype, bias=config.ff_bias)
        self.silu = nn.SiLU()

    def forward(self, x:torch.Tensor):
        x_fc1 = self.fc1(x)
        x_fc2 = self.fc2(x)
        x = self.silu(x_fc1) * x_fc2
        return self.fc3(x)

##### Forward Pass Step Through


In [51]:
from dataclasses import dataclass

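@dataclass
class ExConfig:
    # Type annotations make these real dataclass fields (without them they are plain class attributes)
    emb_dim: int = 8
    hidden_dim: int = 4
    ff_bias: bool = False
    dtype: torch.dtype = torch.float32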

In [52]:
inputs = torch.stack([torch.rand(4, 8), torch.rand(4, 8)], dim=0)
inputs

tensor([[[0.3603, 0.4199, 0.4171, 0.0011, 0.1867, 0.8628, 0.3320, 0.7426],
         [0.8094, 0.2329, 0.5079, 0.1834, 0.5775, 0.2740, 0.1080, 0.4881],
         [0.5460, 0.8762, 0.4189, 0.0390, 0.4255, 0.4692, 0.5556, 0.6423],
         [0.6320, 0.7746, 0.2760, 0.3436, 0.3357, 0.7936, 0.7033, 0.2979]],

        [[0.1314, 0.5537, 0.2424, 0.6685, 0.4951, 0.2792, 0.5484, 0.1634],
         [0.8764, 0.6410, 0.5488, 0.4683, 0.3687, 0.1157, 0.7247, 0.8691],
         [0.4152, 0.2310, 0.1379, 0.5891, 0.0307, 0.7268, 0.1730, 0.1222],
         [0.8189, 0.2870, 0.1445, 0.3491, 0.0457, 0.6918, 0.5878, 0.2607]]])

In [53]:
ff = FeedForward(ExConfig())

In [54]:
fc1 = ff.fc1(inputs)
fc1

tensor([[[ 0.1568,  0.1593, -0.3864, -0.3850],
         [-0.0262,  0.1807, -0.5945, -0.0604],
         [ 0.1912,  0.4057, -0.4512, -0.3500],
         [ 0.2077,  0.2110, -0.4929, -0.3798]],

        [[ 0.1745,  0.1748, -0.3208, -0.1518],
         [-0.1126,  0.3203, -0.5189, -0.1164],
         [ 0.0739, -0.1579, -0.3674, -0.1471],
         [-0.0506, -0.0101, -0.4048, -0.2046]]], grad_fn=<UnsafeViewBackward0>)

In [55]:
fc2 = ff.fc2(inputs)
fc2

tensor([[[ 0.3318,  0.2992,  0.3246,  0.3024],
         [ 0.2211,  0.5143,  0.1726, -0.0024],
         [ 0.3801,  0.1504,  0.2583,  0.3293],
         [ 0.3755,  0.0780,  0.2880,  0.3753]],

        [[ 0.2400,  0.1438,  0.0485,  0.0684],
         [ 0.1901,  0.3467,  0.3630,  0.0854],
         [ 0.0382,  0.2343,  0.2841,  0.0519],
         [ 0.1691,  0.1246,  0.3673,  0.2421]]], grad_fn=<UnsafeViewBackward0>)

In [56]:
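# SiLU scales each value by sigmoid(x): small positive values are roughly halved,
# and negative values shrink toward 0 without being clipped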
silu = ff.silu(fc1)
silu

tensor([[[ 0.0846,  0.0860, -0.1563, -0.1559],
         [-0.0129,  0.0985, -0.2114, -0.0293],
         [ 0.1047,  0.2435, -0.1756, -0.1447],
         [ 0.1146,  0.1166, -0.1869, -0.1543]],

        [[ 0.0949,  0.0950, -0.1349, -0.0702],
         [-0.0531,  0.1856, -0.1936, -0.0548],
         [ 0.0383, -0.0727, -0.1503, -0.0681],
         [-0.0246, -0.0050, -0.1620, -0.0919]]], grad_fn=<SiluBackward0>)

In [59]:
x = silu * fc2
out = ff.fc3(x)
out

tensor([[[ 0.0045,  0.0264, -0.0309,  0.0294,  0.0157,  0.0001,  0.0221,
           0.0325],
         [ 0.0133,  0.0217,  0.0090,  0.0138, -0.0172, -0.0161, -0.0130,
           0.0224],
         [ 0.0051,  0.0298, -0.0281,  0.0213,  0.0102, -0.0064,  0.0184,
           0.0349],
         [ 0.0036,  0.0275, -0.0467,  0.0308,  0.0291,  0.0086,  0.0343,
           0.0306]],

        [[ 0.0051,  0.0108, -0.0048, -0.0053, -0.0021, -0.0055, -0.0004,
           0.0061],
         [ 0.0217,  0.0338,  0.0034,  0.0335, -0.0156, -0.0151, -0.0114,
           0.0341],
         [ 0.0126,  0.0117, -0.0234,  0.0231,  0.0189,  0.0168,  0.0133,
           0.0017],
         [ 0.0102,  0.0176, -0.0271,  0.0392,  0.0202,  0.0138,  0.0187,
           0.0171]]], grad_fn=<UnsafeViewBackward0>)

In [61]:
torch.allclose(ff(inputs), out)

True

### **Positional Encodings**

Another divergence from the original GPT architecture is the positional encoding approach.
There are two kinds of token position to contend with, _absolute_ and _relative_.

- _Absolute_ position is a token's position within the whole sequence of tokens.

- _Relative_ position is a token's position compared to another token within the sequence (for example, 3 tokens earlier).

**Sinusoidal Positional Encoding**

- The original transformer architecture used sinusoidal positional encoding, which represents each token's position as a vector of sine and cosine values at different frequencies; this vector is added to the token embedding.
- The downside of this approach is that it tends to generalize poorly to sequence lengths longer than those seen during training.
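A minimal sketch of the sinusoidal encoding from the original Transformer paper, $PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d})$ and $PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d})$. The function name `sinusoidal_encoding` and the sizes are illustrative.

```python
import torch

def sinusoidal_encoding(seq_len: int, d_model: int, base: float = 10_000.0) -> torch.Tensor:
    assert d_model % 2 == 0, "d_model must be even"
    positions = torch.arange(seq_len, dtype=torch.float32)[:, None]        # (seq_len, 1)
    exponents = torch.arange(0, d_model, 2, dtype=torch.float32) / d_model  # 2i / d
    angles = positions / base ** exponents                                  # (seq_len, d_model // 2)
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(angles)  # even dimensions use sine
    pe[:, 1::2] = torch.cos(angles)  # odd dimensions use cosine
    return pe

pe = sinusoidal_encoding(seq_len=5, d_model=16)
print(pe.shape)  # torch.Size([5, 16])
print(pe[0])     # position 0: sin(0) = 0 in even dims, cos(0) = 1 in odd dims
```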

**Rotary Positional Embeddings (RoPE)**

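- RoPE is a more modern approach to positional encoding. Instead of adding a position vector to the embeddings, RoPE rotates pairs of dimensions in the query and key vectors by an angle proportional to the token's position, so the attention score between two tokens depends on their relative position. Because of these rotational properties, RoPE tends to extend better than sinusoidal positional encoding to longer contexts not seen in training.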
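The relative-position property can be checked directly. In this sketch (the helper name `rotate` and the sizes are assumptions; it pairs dimension $i$ with dimension $i + d/2$, the same split-halves convention used by `compute_rope`), a query rotated for position $m$ and a key rotated for position $n$ have a dot product that depends only on the offset $m - n$, and the rotation leaves vector lengths unchanged.

```python
import torch

def rotate(x: torch.Tensor, pos: int, theta_base: float = 10_000.0) -> torch.Tensor:
    # Rotate each pair (x[i], x[i + d/2]) by the angle pos * theta_base^(-2i/d)
    d = x.shape[-1]
    inv_freq = 1.0 / theta_base ** (torch.arange(0, d, 2, dtype=torch.float64) / d)
    angles = pos * inv_freq
    cos, sin = torch.cos(angles), torch.sin(angles)
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

torch.manual_seed(0)
q = torch.randn(16, dtype=torch.float64)
k = torch.randn(16, dtype=torch.float64)

# Same offset (m - n = 3) at different absolute positions gives the same score
score_a = rotate(q, 5) @ rotate(k, 2)
score_b = rotate(q, 105) @ rotate(k, 102)
print(torch.allclose(score_a, score_b))  # True

# Rotations preserve the vector norm
print(torch.allclose(rotate(q, 42).norm(), q.norm()))  # True
```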


In [62]:
def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096):
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Compute the inverse frequencies
    inv_freq = 1.0 / ( theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))

    # Generate position indices
    positions = torch.arange(context_length)

    # Compute the angles
    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)

    # Expand angles to match the head_dim
    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)

    # Precompute sine and cosine
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin

def compute_rope(x, cos, sin):
    # x: (batch_size, num_heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split x into first half and second half
    x1 = x[..., : head_dim // 2]  # First half
    x2 = x[..., head_dim // 2 :]  # Second half

    # Adjust sin and cos shapes
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotary transformation
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)

    return x_rotated.to(dtype=x.dtype)

##### Stepping Through the Functions


In [64]:
# Define the settings
batch_size, context_len, num_heads, head_dim, theta_base = 2, 5, 4, 16, 10_000

Let's explore the inverse frequencies.


In [65]:
# Let's first make sure that head_dim is an even number
assert head_dim % 2 == 0, "Head dimension needs to be even"
# Create a tensor from 0 -> head_dim counting by 2
torch.arange(0, head_dim, 2)

tensor([ 0,  2,  4,  6,  8, 10, 12, 14])

In [75]:
# Convert to floats and divide by head_dim, giving exponents 2i / head_dim in [0, 1)
torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim

tensor([0.0000, 0.1250, 0.2500, 0.3750, 0.5000, 0.6250, 0.7500, 0.8750])

In [78]:
# Raise theta_base to those exponents; theta_base controls how quickly the frequencies decay across dimensions
theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim)

tensor([1.0000e+00, 3.1623e+00, 1.0000e+01, 3.1623e+01, 1.0000e+02, 3.1623e+02,
        1.0000e+03, 3.1623e+03])

In [80]:
# The whole thing put together
inv_freq = 1.0 / ( theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))
inv_freq

tensor([1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
        1.0000e-03, 3.1623e-04])

Once we have the inverse frequencies, let's generate the position indices. These are simply the integers from 0 to `context_length - 1`.


In [82]:
torch.arange(context_len)

tensor([0, 1, 2, 3, 4])
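The next step in `precompute_rope_params` takes the outer product of positions and inverse frequencies, so row $m$ holds the angles for position $m$, then duplicates them to cover both halves of the head dimension. A self-contained sketch with the same settings used above (`context_len = 5`, `head_dim = 16`):

```python
import torch

context_len, head_dim, theta_base = 5, 16, 10_000

inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2).float() / head_dim))
positions = torch.arange(context_len)

# Outer product via broadcasting: (context_len, 1) * (1, head_dim // 2)
angles = positions[:, None] * inv_freq[None, :]
print(angles.shape)  # torch.Size([5, 8])

# Duplicate so each half of the head dimension sees the same angles
angles = torch.cat([angles, angles], dim=1)
print(angles.shape)  # torch.Size([5, 16])

cos, sin = torch.cos(angles), torch.sin(angles)
print(cos[0])  # position 0 has angle 0 everywhere, so cos is all ones
```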

### **Multi-Headed Attention**
