## Implementing Llama 2 7B


In [3]:
import torch
import torch.nn as nn
from dataclasses import dataclass

### **Normalization**

One of the main divergences from the original GPT transformer architecture is the normalization technique. The purpose of normalization is to rescale the activations flowing through the network, which helps stabilize training and speeds up convergence.

Traditional **LayerNorm** normalizes each input using the mean and variance computed across the feature dimension. It has one hyperparameter, $\epsilon$, a small constant that keeps the denominator away from zero, and two learnable parameters: $\gamma$, a per-feature scale, and $\beta$, a per-feature shift.
$$y=\frac{x-E[x]}{\sqrt{Var(x) + \epsilon}} * \gamma + \beta$$

In the Llama 2 architecture, **RMSNorm** (root mean square normalization) is used instead. RMSNorm skips the mean subtraction and only rescales by the root mean square, which makes it cheaper to compute than LayerNorm, and in practice its performance cost is negligible. Notably, the shift parameter $\beta$ is dropped in RMSNorm.

$$y_i=\frac{x_i}{\sqrt{\frac{1}{n}\sum_{j=1}^{n}x_j^2+\epsilon}}\,\gamma_i$$
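As a quick sanity check of the two formulas, here is a minimal sketch showing that RMSNorm and LayerNorm give identical results when each row already has zero mean (then $Var(x)$ equals the mean of squares), and differ otherwise. The `rms_norm` helper is hypothetical and fixes $\gamma = 1$; LayerNorm is called through PyTorch's `F.layer_norm` without affine parameters.

```python
import torch
import torch.nn.functional as F

def rms_norm(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Hypothetical helper: x / sqrt(mean(x^2) + eps), with gamma fixed at 1
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

torch.manual_seed(0)
x = torch.rand(4, 8)
x_centered = x - x.mean(dim=-1, keepdim=True)  # force each row to zero mean

# LayerNorm with no affine parameters (gamma = 1, beta = 0)
ln_out = F.layer_norm(x_centered, normalized_shape=(8,), eps=1e-5)
rms_out = rms_norm(x_centered, eps=1e-5)

print(torch.allclose(ln_out, rms_out, atol=1e-5))  # True: zero-mean rows, so Var(x) == mean(x^2)
print(torch.allclose(F.layer_norm(x, (8,), eps=1e-5), rms_norm(x), atol=1e-5))  # False: rows not centered
```

The only work RMSNorm saves is the mean computation and subtraction, which is exactly the term that vanishes for centered inputs.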


In [21]:
class RMSNorm(nn.Module):
    def __init__(self, embd_dim:int, eps=1e-5):
        super().__init__()
        self.eps = eps # epsilon
        self.embd_dim = embd_dim
        # gamma: learnable per-feature scale, initialized to ones
        self.weight = nn.Parameter(torch.ones(embd_dim))
    
    def forward(self, x:torch.Tensor):
        # Compute in float32 for numerical stability, then cast back to the input dtype
        x_float = x.float()
        means = x_float.pow(2).mean(dim=-1, keepdim=True)
        x_norm = x_float * torch.rsqrt(means + self.eps)
        return (x_norm * self.weight).to(dtype=x.dtype)

##### Forward Pass Step Through


In [22]:
inputs = torch.rand((4, 8))
rmsn = RMSNorm(embd_dim=8)
inputs

tensor([[0.0189, 0.5897, 0.3435, 0.6396, 0.5835, 0.6628, 0.8831, 0.2713],
        [0.7535, 0.1432, 0.0761, 0.6367, 0.8376, 0.4416, 0.8005, 0.3437],
        [0.9414, 0.9058, 0.2968, 0.6923, 0.4357, 0.1979, 0.1784, 0.7046],
        [0.1627, 0.2033, 0.6560, 0.0777, 0.2667, 0.8596, 0.5893, 0.6754]])

In [23]:
# Current unnormalized sum
inputs.sum(dim=-1)

tensor([3.9924, 4.0329, 4.3529, 3.4907])

In [24]:
# Take the mean of the squared values along the feature dimension
inputs_mean = inputs.pow(2).mean(dim=1, keepdim=True)
inputs_mean

tensor([[0.3135],
        [0.3319],
        [0.3789],
        [0.2647]])

In [25]:
# Add epsilon and take sqrt
torch.sqrt(inputs_mean + rmsn.eps)

tensor([[0.5600],
        [0.5761],
        [0.6156],
        [0.5145]])

In [26]:
# Since we compute x / RMS(x), take the reciprocal square root (rsqrt) instead of dividing
torch.rsqrt(inputs_mean + rmsn.eps)

tensor([[1.7858],
        [1.7359],
        [1.6245],
        [1.9436]])

In [27]:
# now multiply by the numerator x (inputs in this case) 
inputs_norm = inputs * torch.rsqrt(inputs_mean + rmsn.eps)
inputs_norm

tensor([[0.0337, 1.0531, 0.6135, 1.1422, 1.0420, 1.1836, 1.5770, 0.4846],
        [1.3081, 0.2486, 0.1322, 1.1052, 1.4539, 0.7665, 1.3895, 0.5966],
        [1.5292, 1.4715, 0.4822, 1.1246, 0.7078, 0.3215, 0.2898, 1.1447],
        [0.3161, 0.3952, 1.2751, 0.1509, 0.5183, 1.6706, 1.1454, 1.3128]])

In [28]:
# Finally, multiply by gamma_i, the learnable weights (initialized to ones, so the values are unchanged here)
inputs_norm * rmsn.weight

tensor([[0.0337, 1.0531, 0.6135, 1.1422, 1.0420, 1.1836, 1.5770, 0.4846],
        [1.3081, 0.2486, 0.1322, 1.1052, 1.4539, 0.7665, 1.3895, 0.5966],
        [1.5292, 1.4715, 0.4822, 1.1246, 0.7078, 0.3215, 0.2898, 1.1447],
        [0.3161, 0.3952, 1.2751, 0.1509, 0.5183, 1.6706, 1.1454, 1.3128]],
       grad_fn=<MulBackward0>)

In [29]:
rms = RMSNorm(inputs.size(-1))
rms.forward(inputs)

tensor([[0.0337, 1.0531, 0.6135, 1.1422, 1.0420, 1.1836, 1.5770, 0.4846],
        [1.3081, 0.2486, 0.1322, 1.1052, 1.4539, 0.7665, 1.3895, 0.5966],
        [1.5292, 1.4715, 0.4822, 1.1246, 0.7078, 0.3215, 0.2898, 1.1447],
        [0.3161, 0.3952, 1.2751, 0.1509, 0.5183, 1.6706, 1.1454, 1.3128]],
       grad_fn=<MulBackward0>)

### **Activation Functions**

Activation functions are non-linear functions placed between the linear layers of a neural network. Without them, a neural network could only learn linear relationships. For example, two linear layers applied back to back with no activation in between compute $W_2(W_1x) = (W_2W_1)x$, which is just a single linear layer with weight matrix $W_2W_1$: the two layers collapse into one.
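The collapse argument can be checked numerically. In this sketch, two bias-free `nn.Linear` layers (sizes chosen arbitrarily) are applied back to back and then replaced by a single matrix product with $W_2W_1$; inserting a ReLU between them breaks the equivalence.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer1 = nn.Linear(8, 16, bias=False)
layer2 = nn.Linear(16, 4, bias=False)
x = torch.randn(3, 8)

# Two stacked linear layers with no activation in between
stacked = layer2(layer1(x))

# nn.Linear computes x @ W.T, so the composition equals x @ (W2 @ W1).T
merged_weight = layer2.weight @ layer1.weight  # shape (4, 8)
collapsed = x @ merged_weight.T

print(torch.allclose(stacked, collapsed, atol=1e-5))  # True: the layers collapse into one

# A non-linearity between the layers breaks the equivalence
with_act = layer2(torch.relu(layer1(x)))
print(torch.allclose(with_act, collapsed, atol=1e-5))  # False
```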

There are a variety of activation functions. A simple one is ReLU, which clips every negative input to 0 and leaves positive inputs unchanged. ReLU became popular because its gradient is 1 for all positive inputs, which helps avoid the _vanishing gradient_ problem that affects saturating activations such as sigmoid and tanh. More recent activation functions such as GELU and SiLU are smooth approximations of ReLU and often perform better in practice. Their smoothness allows for more nuanced learning, since negative inputs are not immediately cut off to 0 as they are with ReLU.
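To make the ReLU vs. SiLU difference concrete, the sketch below evaluates both on a few hand-picked inputs using PyTorch's functional `F.relu` and `F.silu`.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-3.0, -1.0, 0.0, 1.0, 3.0])

relu_out = F.relu(x)  # negatives clipped to exactly 0
silu_out = F.silu(x)  # negatives shrunk toward 0 but not zeroed

print(relu_out)  # tensor([0., 0., 0., 1., 3.])
print(silu_out)

# SiLU is x * sigmoid(x); for example silu(-1) = -1 / (1 + e) ≈ -0.2689
assert torch.allclose(silu_out, x * torch.sigmoid(x))
```

The small negative outputs (and their non-zero gradients) are what let SiLU keep some signal from negative pre-activations.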

Llama 2 uses **SwiGLU**, a gated linear unit (GLU) variant built on SiLU, the sigmoid-weighted linear unit.

**SiLU**

$$silu(x) = x * \sigma(x)$$

**SwiGLU**

$$SwiGLU(x) = SiLU(Linear_1(x)) * Linear_2(x)$$

In PyTorch, SiLU is available as the built-in module `nn.SiLU`:


In [19]:
nn.SiLU()

SiLU()

### **Feed Forward Layer**

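The feed-forward layer is where the network gets to "think about" each token individually: unlike attention, it is applied to every token position independently, with no mixing between positions.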

In the feed-forward network, Llama 2 uses **SwiGLU**, the gated linear unit variant of SiLU introduced above.

**Recall SiLU is**:

$$silu(x) = x * \sigma(x)$$

**SwiGLU**:

$$SwiGLU(x) = SiLU(Linear_1(x)) * Linear_2(x)$$
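To illustrate the claim that the feed-forward layer treats each token independently, here is a minimal sketch of the SwiGLU feed-forward computation written as a plain function. The layer sizes and the `swiglu_ffn` helper are hypothetical stand-ins, not the notebook's `FeedForward` class; the check runs a whole sequence at once and then one token at a time and compares the results.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
emb_dim, hidden_dim = 8, 16

# Hypothetical bias-free weights: fc1 (gated by SiLU), fc2, and fc3 (projects back)
fc1 = nn.Linear(emb_dim, hidden_dim, bias=False)
fc2 = nn.Linear(emb_dim, hidden_dim, bias=False)
fc3 = nn.Linear(hidden_dim, emb_dim, bias=False)

def swiglu_ffn(x: torch.Tensor) -> torch.Tensor:
    # SwiGLU(x) = SiLU(Linear_1(x)) * Linear_2(x), then project back with Linear_3
    return fc3(F.silu(fc1(x)) * fc2(x))

x = torch.randn(2, 5, emb_dim)  # (batch, seq_len, emb_dim)

# Process the whole sequence at once...
full = swiglu_ffn(x)
# ...and one token position at a time
per_token = torch.stack([swiglu_ffn(x[:, t]) for t in range(x.size(1))], dim=1)

print(torch.allclose(full, per_token, atol=1e-5))  # True: no mixing across positions
```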


In [199]:
class FeedForward(nn.Module):
    def __init__(self, config:dataclass):
        super().__init__()
        # Three linear layers: fc1 and fc2 project to hidden_dim (SiLU gates fc1), fc3 projects back to emb_dim
        self.fc1 = nn.Linear(config.emb_dim, config.hidden_dim, dtype=config.dtype, bias=config.ff_bias)
        self.fc2 = nn.Linear(config.emb_dim, config.hidden_dim, dtype=config.dtype, bias=config.ff_bias)
        self.fc3 = nn.Linear(config.hidden_dim, config.emb_dim, dtype=config.dtype, bias=config.ff_bias)
        self.silu = nn.SiLU()

    def forward(self, x:torch.Tensor):
        x_fc1 = self.fc1(x)
        x_fc2 = self.fc2(x)
        x = self.silu(x_fc1) * x_fc2
        return self.fc3(x)

##### Forward Pass Step Through


In [200]:
from dataclasses import dataclass

@dataclass
class ExConfig:
    emb_dim = 8
    hidden_dim = 4
    ff_bias = False
    dtype = torch.float32

In [201]:
inputs = torch.stack([torch.rand(4, 8), torch.rand(4, 8)], dim=0)
inputs

tensor([[[0.8562, 0.8171, 0.5510, 0.7538, 0.5183, 0.2659, 0.6102, 0.2956],
         [0.1618, 0.9196, 0.5412, 0.0672, 0.5936, 0.3335, 0.2901, 0.7050],
         [0.3917, 0.9414, 0.8767, 0.2810, 0.1935, 0.4370, 0.1713, 0.2047],
         [0.7780, 0.0842, 0.0210, 0.3345, 0.3533, 0.9742, 0.6587, 0.1125]],

        [[0.2863, 0.3911, 0.6510, 0.6795, 0.3936, 0.4174, 0.6731, 0.4670],
         [0.1124, 0.7038, 0.6871, 0.1183, 0.8075, 0.6196, 0.8748, 0.1718],
         [0.8527, 0.1414, 0.2472, 0.8804, 0.4532, 0.4600, 0.6645, 0.3190],
         [0.2205, 0.9183, 0.0903, 0.6148, 0.2728, 0.5989, 0.1556, 0.6848]]])

In [202]:
ff = FeedForward(ExConfig())

In [203]:
fc1 = ff.fc1(inputs)
fc1

tensor([[[ 0.4497,  0.1524, -0.1136, -0.2573],
         [ 0.3189,  0.1773, -0.2373, -0.0403],
         [ 0.2861,  0.2566,  0.0712, -0.0957],
         [ 0.6609, -0.0991, -0.0587,  0.0728]],

        [[ 0.1569,  0.3293,  0.0407, -0.0748],
         [ 0.7395,  0.1013, -0.0921,  0.2568],
         [ 0.3133,  0.1137, -0.0527, -0.2240],
         [ 0.1214,  0.1897, -0.1192, -0.0503]]], grad_fn=<UnsafeViewBackward0>)

In [204]:
fc2 = ff.fc2(inputs)
fc2

tensor([[[-0.0040, -0.2185,  0.1382,  0.0865],
         [-0.2930, -0.0662,  0.6157,  0.1834],
         [-0.2536,  0.1811,  0.4009,  0.2920],
         [ 0.4699, -0.2878,  0.2240, -0.1464]],

        [[ 0.1701, -0.2044,  0.3307, -0.0213],
         [ 0.0190, -0.2882,  0.6598, -0.1842],
         [ 0.3800, -0.3569, -0.0411, -0.0663],
         [ 0.0429, -0.1022,  0.4397,  0.0784]]], grad_fn=<UnsafeViewBackward0>)

In [205]:
# SiLU shrinks negative values toward 0 (without zeroing them) and leaves larger positive values nearly unchanged
silu = ff.silu(fc1)
silu

tensor([[[ 0.2746,  0.0820, -0.0536, -0.1122],
         [ 0.1847,  0.0965, -0.1046, -0.0197],
         [ 0.1634,  0.1447,  0.0369, -0.0455],
         [ 0.4359, -0.0471, -0.0285,  0.0377]],

        [[ 0.0846,  0.1915,  0.0207, -0.0360],
         [ 0.5005,  0.0532, -0.0439,  0.1448],
         [ 0.1810,  0.0601, -0.0257, -0.0995],
         [ 0.0644,  0.1038, -0.0561, -0.0245]]], grad_fn=<SiluBackward0>)

In [206]:
x = silu * fc2
out = ff.fc3(x)
out

tensor([[[-5.8990e-04, -4.8279e-03, -2.4181e-03, -7.2122e-03,  7.1769e-03,
           7.0435e-03,  2.7941e-03,  3.6446e-04],
         [-5.1965e-03, -7.8260e-03,  3.6779e-02,  8.3208e-03,  2.6413e-02,
          -1.2290e-02, -1.1195e-03,  1.5196e-02],
         [ 1.2715e-02,  3.9423e-02,  1.0529e-02,  2.8812e-02, -4.8914e-04,
          -1.6098e-02, -4.5580e-03, -4.0795e-03],
         [-2.4614e-02, -7.8802e-02, -6.5532e-02, -7.9430e-02, -7.8179e-02,
           4.4514e-02, -2.0117e-02,  4.5216e-03]],

        [[-3.5914e-03, -1.7336e-02, -1.3322e-02, -1.8574e-02,  1.1708e-02,
           1.6750e-02,  1.0326e-02, -3.2872e-03],
         [-2.2824e-03, -9.9773e-03, -2.4735e-03, -1.4360e-02,  1.1376e-03,
           1.0567e-02, -3.2712e-03,  4.9115e-03],
         [-1.1371e-02, -3.8468e-02, -2.5447e-02, -3.5402e-02, -1.3636e-02,
           2.2543e-02,  1.8344e-03,  1.2647e-04],
         [-5.4201e-03, -1.5272e-02,  4.7404e-03, -9.2529e-03,  5.3722e-03,
           3.6943e-03, -1.5666e-05,  5.7114e-03]

In [207]:
torch.allclose(ff.forward(inputs), out)

True

### **Positional Encodings**

Another divergence from the original GPT architecture is the positional encoding approach.
There are two main notions of token position to contend with, _absolute_ and _relative_.

- _Absolute_ position is the token's index within the whole sequence of tokens.

- _Relative_ position is the token's position compared to another token within the sequence, that is, the offset between them.

**Sinusoidal Positional Encoding**

- The original transformer architecture used sinusoidal positional encoding, which adds a fixed vector of sine and cosine values at different frequencies to each token embedding to communicate the token's position.
- The downside of this approach is that it tends to generalize poorly to sequences longer than those seen during training.
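For comparison with RoPE, here is a minimal sketch of the sinusoidal encoding from the original transformer paper, where even dimensions use $\sin(pos / 10000^{2i/d})$ and odd dimensions use the matching cosine. The function name `sinusoidal_encoding` is hypothetical; the resulting table is added to the token embeddings before the first layer.

```python
import torch

def sinusoidal_encoding(seq_len: int, d_model: int, base: float = 10_000.0) -> torch.Tensor:
    assert d_model % 2 == 0, "d_model must be even"
    positions = torch.arange(seq_len, dtype=torch.float32)[:, None]  # (seq_len, 1)
    # 1 / base^(2i / d_model) for i = 0 .. d_model/2 - 1
    inv_freq = 1.0 / base ** (torch.arange(0, d_model, 2, dtype=torch.float32) / d_model)
    angles = positions * inv_freq[None, :]  # (seq_len, d_model / 2)
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(angles)  # even dimensions use sine
    pe[:, 1::2] = torch.cos(angles)  # odd dimensions use cosine
    return pe

pe = sinusoidal_encoding(seq_len=5, d_model=8)
print(pe.shape)  # torch.Size([5, 8])
# Usage: x = token_embeddings + pe, broadcast over the batch dimension
```

Note the same inverse-frequency pattern that `precompute_rope_params()` uses below; the difference is that these values are added to the embeddings rather than used to rotate queries and keys.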

**Rotary Positional Embeddings (RoPE)**

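- RoPE is a more modern approach to positional encoding. Instead of adding a position vector to the embeddings, RoPE rotates the query and key vectors by angles proportional to the token's position, so the attention score between two tokens depends on their relative offset. Thanks to these rotational properties, RoPE is generally considered to extend better than sinusoidal positional encoding to contexts longer than those seen in training.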

_The equation is:_

$$f_{\{q,k\}}(x_m, m) = R^d_{\Theta,m}W_{\{q, k\}}x_m$$
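The key property behind RoPE is that rotating $q$ by position $m$ and $k$ by position $n$ makes their dot product depend only on the offset $m - n$. This sketch checks that numerically with a hypothetical `rope_rotate` helper that uses the same half-split pairing and $\theta$ base of 10,000 as `compute_rope()` below; float64 is used only to keep rounding error out of the comparison.

```python
import torch

def rope_rotate(x: torch.Tensor, pos: int, theta_base: float = 10_000.0) -> torch.Tensor:
    """Hypothetical helper: rotate one vector x (even size d) to position `pos`."""
    d = x.size(-1)
    inv_freq = 1.0 / theta_base ** (torch.arange(0, d, 2, dtype=torch.float64) / d)
    angles = pos * inv_freq
    cos = torch.cat([angles.cos(), angles.cos()])
    sin = torch.cat([angles.sin(), angles.sin()])
    x1, x2 = x[: d // 2], x[d // 2 :]
    return x * cos + torch.cat([-x2, x1]) * sin

torch.manual_seed(0)
q = torch.randn(16, dtype=torch.float64)
k = torch.randn(16, dtype=torch.float64)

# Same relative offset (m - n = 3) at two very different absolute positions
score_a = rope_rotate(q, 5) @ rope_rotate(k, 2)
score_b = rope_rotate(q, 105) @ rope_rotate(k, 102)
print(torch.allclose(score_a, score_b))  # True

# Rotations never change the vector's length
print(torch.allclose(rope_rotate(q, 7).norm(), q.norm()))  # True
```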


In [62]:
def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096):
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    # Compute the inverse frequencies
    inv_freq = 1.0 / ( theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))

    # Generate position indices
    positions = torch.arange(context_length)

    # Compute the angles
    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)

    # Expand angles to match the head_dim
    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)

    # Precompute sine and cosine
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin

def compute_rope(x, cos, sin):
    # x: (batch_size, num_heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split x into first half and second half
    x1 = x[..., : head_dim // 2]  # First half
    x2 = x[..., head_dim // 2 :]  # Second half

    # Apply the rotary transformation
    rotated = torch.cat((-x2, x1), dim=-1)

    # Adjust sin and cos shapes
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    x_rotated = (x * cos) + (rotated * sin)
    return x_rotated.to(dtype=x.dtype)

##### Stepping Through the Functions


In [64]:
# Define the settings
batch_size, context_len, num_heads, head_dim, theta_base = 2, 5, 4, 16, 10_000

Let's explore `precompute_rope_params()`


In [65]:
# First, make sure that head_dim is an even number
assert head_dim % 2 == 0, "Head dimension needs to be even"
# Create a tensor from 0 -> head_dim counting by 2
torch.arange(0, head_dim, 2)

tensor([ 0,  2,  4,  6,  8, 10, 12, 14])

In [75]:
# convert values into floats, and divide by head_dim to normalize frequencies
torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim

tensor([0.0000, 0.1250, 0.2500, 0.3750, 0.5000, 0.6250, 0.7500, 0.8750])

In [78]:
# Raise theta_base to these exponents; theta_base controls how quickly the frequencies decay
theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim)

tensor([1.0000e+00, 3.1623e+00, 1.0000e+01, 3.1623e+01, 1.0000e+02, 3.1623e+02,
        1.0000e+03, 3.1623e+03])

In [80]:
# The whole thing put together
inv_freq = 1.0 / ( theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))
inv_freq

tensor([1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
        1.0000e-03, 3.1623e-04])

Once we have the inverse frequencies, let's generate the position indices: the integers from 0 to `context_length - 1`.


In [83]:
positions = torch.arange(context_len)
positions

tensor([0, 1, 2, 3, 4])

In [90]:
# adding another dimension to the positions and frequencies
print(f"Previous Position Shape: {positions.shape}\nNew Position Shape: {positions[:, None].shape}")
print("="*5)
print(f"Previous Inverse Frequency Shape: {inv_freq.shape}\nNew Inverse Frequency Shape: {inv_freq[None, :].shape}") 

Previous Position Shape: torch.Size([5])
New Position Shape: torch.Size([5, 1])
=====
Previous Inverse Frequency Shape: torch.Size([8])
New Inverse Frequency Shape: torch.Size([1, 8])


In [92]:
# We can add a new dimension another way by using unsqueeze
positions.unsqueeze(dim=-1), inv_freq.unsqueeze(dim=0)

(tensor([[0],
         [1],
         [2],
         [3],
         [4]]),
 tensor([[1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
          1.0000e-03, 3.1623e-04]]))

In [109]:
# Either way of inserting a new dimension works.
# With the new dimensions, broadcasting multiplies every position by every inverse frequency (an outer product)
angles = (positions[:, None] * inv_freq[None, :])
angles

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
         1.0000e-03, 3.1623e-04],
        [2.0000e+00, 6.3246e-01, 2.0000e-01, 6.3246e-02, 2.0000e-02, 6.3246e-03,
         2.0000e-03, 6.3246e-04],
        [3.0000e+00, 9.4868e-01, 3.0000e-01, 9.4868e-02, 3.0000e-02, 9.4868e-03,
         3.0000e-03, 9.4868e-04],
        [4.0000e+00, 1.2649e+00, 4.0000e-01, 1.2649e-01, 4.0000e-02, 1.2649e-02,
         4.0000e-03, 1.2649e-03]])

In [110]:
# We only have head_dim // 2 angles, so concatenate angles with itself to get head_dim columns
angles = torch.cat([angles, angles], dim=1)
angles

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
         1.0000e-03, 3.1623e-04, 1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02,
         1.0000e-02, 3.1623e-03, 1.0000e-03, 3.1623e-04],
        [2.0000e+00, 6.3246e-01, 2.0000e-01, 6.3246e-02, 2.0000e-02, 6.3246e-03,
         2.0000e-03, 6.3246e-04, 2.0000e+00, 6.3246e-01, 2.0000e-01, 6.3246e-02,
         2.0000e-02, 6.3246e-03, 2.0000e-03, 6.3246e-04],
        [3.0000e+00, 9.4868e-01, 3.0000e-01, 9.4868e-02, 3.0000e-02, 9.4868e-03,
         3.0000e-03, 9.4868e-04, 3.0000e+00, 9.4868e-01, 3.0000e-01, 9.4868e-02,
         3.0000e-02, 9.4868e-03, 3.0000e-03, 9.4868e-04],
        [4.0000e+00, 1.2649e+00, 4.0000e-01, 1.2649e-01, 4.0000e-02, 1.2649e-02,
         4.0000e-03, 1.2649e-03, 4.0000e+00, 1.2649e+00, 4.0000e-01, 1.2649e-01,
         4.0000e-02, 1.2649e-02, 4.0000e-03, 1.2649e-03]])

In [111]:
# Compute sine and cosine. This is fairly straightforward
cos = torch.cos(angles)
sin = torch.sin(angles)
cos.shape

torch.Size([5, 16])

Now that we have the corresponding sin and cos tensors, we can compute RoPE from the `compute_rope()` function


In [116]:
# RECALL: batch_size, context_len, num_heads, head_dim, theta_base = 2, 5, 4, 16, 10_000
# Create dummy query and key tensors
queries = torch.randn(batch_size, num_heads, context_len, head_dim)
keys = torch.randn(batch_size, num_heads, context_len, head_dim)

In [160]:
queries[..., :head_dim//2].shape

torch.Size([2, 4, 5, 8])

We have the queries and keys of dimension 2x4x5x16. Recall from the attention [notebook](../Attention/play.ipynb), we are dealing with weight splitting. So what does this mean?

- The first dimension is the batch dimension or the number of token sequences we are passing through the model.
- The second dimension is the number of attention heads. This is where weight splitting comes into play: with 4 attention heads, the query (and key) vectors for every head are packaged in one tensor. Splitting along this dimension gives 4 separate _context_len x head_dim_ matrices.
- The third dimension is the context length or the number of tokens in the sequence that is passed in to the transformer.
- The fourth dimension is the dimension of each attention head. This also represents the number of features that are in a Q, K, and V vector. Remember there is one Q, K, V vector to represent each token.
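The head splitting described above can be seen directly with `view` and `transpose`, the same calls `MultiHeadedAttention` uses further down. In this sketch `projected` is a random stand-in for the output of a query projection; it shows that each head is simply a contiguous slice of `head_dim` columns, and that reshaping back recovers the packed tensor.

```python
import torch

b, seq_len, num_heads, head_dim = 2, 5, 4, 16
d_out = num_heads * head_dim

torch.manual_seed(0)
projected = torch.randn(b, seq_len, d_out)  # stand-in for W_query(x): (b, seq_len, d_out)

# Split the last dimension into heads, then move heads in front of the sequence dimension
heads = projected.view(b, seq_len, num_heads, head_dim).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 4, 5, 16])

# Head h is a contiguous block of head_dim columns from the packed projection
h = 2
assert torch.equal(heads[:, h], projected[..., h * head_dim : (h + 1) * head_dim])

# Undoing the split recovers the packed tensor
merged = heads.transpose(1, 2).reshape(b, seq_len, d_out)
assert torch.equal(merged, projected)
```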


In [119]:
queries.shape 

torch.Size([2, 4, 5, 16])

In [120]:
# let's get the dimensions of our input x, which is either the query or the key tensor
batch_size, num_heads, seq_len, head_dim = queries.shape
# ensure the head_dim is even
assert head_dim % 2 == 0, "Head dimension must be even"

Take a look at the code below. It takes the original query tensor, splits each token's embedding in half, negates the second half, and concatenates the negated second half in front of the original first half, turning $[x_1, x_2]$ into $[-x_2, x_1]$. Combined with the cos and sin tables, this is a mathematical trick that emulates multiplication by the rotary embedding matrix.

Let's have a look at the 2D rotation matrix:

$$
\begin{pmatrix}
\cos(m\theta) & -\sin(m\theta) \\
\sin(m\theta) & \cos(m\theta)
\end{pmatrix}
$$

Now let's take a look at what the code below does with the following example:

$$
\begin{pmatrix}
1 & 2 & 3 & 4 \\
5 & 6 & 7 & 8 \\
1 & 4 & 5 & 9
\end{pmatrix}
\Rightarrow
\begin{pmatrix}
1 & 2 \\
5 & 6 \\
1 & 4
\end{pmatrix}
\quad\text{and}\quad
\begin{pmatrix}
-3 & -4 \\
-7 & -8 \\
-5 & -9
\end{pmatrix}
\Rightarrow
\begin{pmatrix}
-3 & -4 & 1 & 2 \\
-7 & -8 & 5 & 6 \\
-5 & -9 & 1 & 4
\end{pmatrix}
$$

The reasoning may be confusing at the moment, but I will show how this is advantageous shortly. Remember that the rows correspond to the sequence length and the columns to the embedding dimensions, so each row is one token's query vector.


In [166]:
# split the query into its first and second halves
q_first = queries[..., : head_dim // 2]
q_second = queries[..., head_dim // 2 : ]
# negate the second half and move it in front of the first half: [x1, x2] -> [-x2, x1]
rotated = torch.cat([-q_second, q_first], dim=-1)
print(f"Full Query Shape: {queries.shape}")
print(f"First half Query Shape: {q_first.shape}")
print(f"Second half Query Shape: {q_second.shape}")

Full Query Shape: torch.Size([2, 4, 5, 16])
First half Query Shape: torch.Size([2, 4, 5, 8])
Second half Query Shape: torch.Size([2, 4, 5, 8])


We spoke about broadcasting, a very important PyTorch concept, in the attention [notebook](../Attention/play.ipynb). To apply the cos and sin transformations across every batch and every head, we add two new leading dimensions of size 1 for broadcasting.
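A small sketch of this broadcasting step, with random tensors standing in for the real queries and cos table. The unsqueezed `(1, 1, seq_len, head_dim)` table is stretched across the batch and head dimensions, so every (batch, head) slice is multiplied by the same table. Strictly speaking, PyTorch would also broadcast the 2D table directly, because broadcasting aligns trailing dimensions; the explicit `unsqueeze` calls make the intended shape obvious.

```python
import torch

batch_size, num_heads, seq_len, head_dim = 2, 4, 5, 16
x = torch.randn(batch_size, num_heads, seq_len, head_dim)  # stand-in for the queries
cos = torch.randn(seq_len, head_dim)                        # stand-in for the cos table

cos_b = cos.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, head_dim)
out = x * cos_b                        # broadcasts over batch and heads
print(out.shape)  # torch.Size([2, 4, 5, 16])

# Every (batch, head) slice is multiplied by the same (seq_len, head_dim) table
assert torch.equal(out[1, 3], x[1, 3] * cos)
# Implicit broadcasting of the 2D table gives the same result
assert torch.equal(x * cos, out)
```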


In [133]:
# Adjust sin and cos shapes
# Option 1: cos[:seq_len, :][None, None, :, :].shape
# Option 2:
print(f"Previous Cos shape: {cos.shape}")
cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
print(f"New Cos shape: {cos.shape}")
# apply the same for sin
sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

Previous Cos shape: torch.Size([5, 16])
New Cos shape: torch.Size([1, 1, 5, 16])


Now, to explain what is going on in the next code block, let's recall what we did with the previous embedding split and negation:

$$
\begin{pmatrix}
1 & 2 & 3 & 4 \\
5 & 6 & 7 & 8 \\
1 & 4 & 5 & 9
\end{pmatrix}
\xrightarrow{\text{split}}
\begin{pmatrix}
1 & 2 \\
5 & 6 \\
1 & 4
\end{pmatrix}
\quad\text{and}\quad
\begin{pmatrix}
-3 & -4 \\
-7 & -8 \\
-5 & -9
\end{pmatrix}
\xrightarrow{\text{rotate}}
\begin{pmatrix}
-3 & -4 & 1 & 2 \\
-7 & -8 & 5 & 6 \\
-5 & -9 & 1 & 4
\end{pmatrix}
$$

Now we take the queries and do element-wise multiplication with the RoPE cos and sin frequencies:

$$
\begin{pmatrix}
1 & 2 & 3 & 4 \\
5 & 6 & 7 & 8 \\
1 & 4 & 5 & 9
\end{pmatrix}
\odot \cos
+
\begin{pmatrix}
-3 & -4 & 1 & 2 \\
-7 & -8 & 5 & 6 \\
-5 & -9 & 1 & 4
\end{pmatrix}
\odot \sin
$$

Although this does not look exactly like the RoPE equation below, it is a mathematical trick that achieves the same geometric rotation more efficiently:

$$f_{\{q,k\}}(x_m, m) = R^d_{\Theta,m}W_{\{q, k\}}x_m$$
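To back up the claim that the rotate-half trick matches the matrix form, the sketch below builds the explicit $d \times d$ rotation matrix $R^d_{\Theta,m}$ for a single vector and compares it with the `x * cos + rotated * sin` computation. One assumption to flag: like `compute_rope()`, it pairs dimension $i$ with $i + d/2$, whereas the original RoPE (RoFormer) paper pairs adjacent dimensions $(2i, 2i+1)$; the two layouts differ only by a fixed permutation of the feature dimensions.

```python
import torch

torch.manual_seed(0)
d, m, theta_base = 8, 3, 10_000.0
x = torch.randn(d, dtype=torch.float64)

# Per-pair frequencies theta_i and angles m * theta_i for i = 0 .. d/2 - 1
theta = 1.0 / theta_base ** (torch.arange(0, d, 2, dtype=torch.float64) / d)
ang = m * theta
half = d // 2

# Explicit rotation matrix: the pair (i, i + d/2) is rotated by angle ang[i]
R = torch.zeros(d, d, dtype=torch.float64)
for i in range(half):
    c, s = torch.cos(ang[i]), torch.sin(ang[i])
    R[i, i], R[i, i + half] = c, -s
    R[i + half, i], R[i + half, i + half] = s, c

explicit = R @ x

# The rotate-half trick used in compute_rope()
cos = torch.cat([ang.cos(), ang.cos()])
sin = torch.cat([ang.sin(), ang.sin()])
x1, x2 = x[:half], x[half:]
trick = x * cos + torch.cat([-x2, x1]) * sin

print(torch.allclose(explicit, trick))  # True
```

The trick avoids materializing the mostly-zero $d \times d$ matrix: it needs only two element-wise products and one concatenation per vector.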


In [141]:
queries_rotated = ( queries * cos ) + ( rotated * sin )

### **Multi-Headed Attention**

Multi-headed attention was covered extensively in this notebook [here](../Attention/play.ipynb).

Here, let's implement multi-headed attention with RoPE applied to the queries and keys.


In [197]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, num_heads, dtype=None):
        super().__init__()
        assert d_out % num_heads == 0, "The output dimension (d_out) must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads 

        # Query, key, and value projections plus the output projection
        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.W_key = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

        # Create token masking 
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

        # Create RoPE parameters
        cos, sin = precompute_rope_params(head_dim=self.head_dim, context_length=context_length)
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin)
    
    def forward(self, x:torch.Tensor):
        b, num_tokens, d_in = x.shape

        # Calculate the q, k, v vectors
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Remember the output dim holds every head's vectors concatenated, so split it into heads
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # Now we need to swap the head dimension and the context dim
        queries:torch.Tensor = queries.transpose(1,2)
        keys:torch.Tensor = keys.transpose(1,2)
        values:torch.Tensor = values.transpose(1,2)

        # Apply RoPE to the keys and queries (the values are not rotated)
        keys = compute_rope(keys, self.cos, self.sin)
        queries = compute_rope(queries, self.cos, self.sin)

        # Attention scores via dot product (scaled by sqrt(head_dim) below)
        attn_scores = queries @ keys.transpose(2,3)

        # Masking
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Fill the upper triangle with -inf so softmax gives future tokens zero weight
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # normalizing the weights
        attn_weights = torch.softmax(attn_scores / keys.size(-1)**0.5, dim=-1)

        context_vec = (attn_weights @ values).transpose(1,2)

        # combine the heads again self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # projection
        return context_vec

In [198]:
# Settings
batch_size = 1
context_len = 100
max_context_len = 4096
embed_dim = 128
num_heads = 4

example_batch = torch.randn((batch_size, context_len, embed_dim))

mha = MultiHeadedAttention(d_in=embed_dim, d_out=embed_dim, context_length=max_context_len, num_heads=num_heads)
mha.forward(example_batch)

tensor([[[ 0.6056,  0.4086, -0.0366,  ..., -0.2326, -0.1381,  0.3727],
         [ 0.5075,  0.2772,  0.0285,  ..., -0.2222, -0.1858,  0.4793],
         [ 0.3986,  0.2096,  0.0269,  ..., -0.3572, -0.2304,  0.3711],
         ...,
         [-0.0310, -0.0103,  0.0176,  ...,  0.0535, -0.0766, -0.0107],
         [-0.0382, -0.0397,  0.0171,  ...,  0.0416, -0.0566, -0.0347],
         [-0.0237, -0.0306,  0.0328,  ...,  0.0404, -0.0638, -0.0395]]],
       grad_fn=<UnsafeViewBackward0>)

### Transformer Block


In [219]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg:dataclass):
        super().__init__()
        self.att = MultiHeadedAttention(
            d_in=cfg.emb_dim,
            d_out=cfg.emb_dim,
            context_length=cfg.context_length,
            num_heads=cfg.n_heads,
            dtype=cfg.dtype
        )

        self.ff = FeedForward(cfg)

        # RMSNorms
        self.norm1 = RMSNorm(cfg.emb_dim)
        self.norm2 = RMSNorm(cfg.emb_dim)
    
    def forward(self, x:torch.Tensor):
        # save x for the residual connections 
        residual = x
        x = self.norm1(x)
        x = self.att(x)
        x = x + residual

        # save the new residual for the feed-forward sub-layer
        residual = x
        x = self.norm2(x)
        x = self.ff(x)
        x = x + residual
        
        return x

### Define Llama2


In [229]:
class Llama2(nn.Module):
    def __init__(self, cfg:dataclass):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.emb_dim, dtype=cfg.dtype)

        # Transformer block
        self.trf_blocks = nn.Sequential(*[TransformerBlock(cfg) for _ in range(cfg.n_layers)])

        # RMSNorm
        self.final_norm = RMSNorm(cfg.emb_dim)

        # Final output layer
        self.out_head = nn.Linear(cfg.emb_dim, cfg.vocab_size, bias=False, dtype=cfg.dtype)
    
    def forward(self,  in_idx:torch.Tensor):
        tok_embeds = self.tok_emb(in_idx)
        x = tok_embeds 
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits


In [230]:
@dataclass
class Llama2Config7B:
    vocab_size = 32000      # Vocabulary size
    context_length = 4096   # Context length
    emb_dim = 4096          # Embedding dimension
    n_heads = 32            # Number of attention heads
    ff_bias = False         # Feed forward bias 
    n_layers = 32           # Number of layers
    hidden_dim = 11008      # NEW = Size of the intermediate dimension in FeedForward
    dtype = torch.bfloat16  # NEW: Lower-precision dtype to reduce memory usage


In [231]:
model = Llama2(Llama2Config7B)

In [232]:
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")

Total number of parameters: 6,738,415,616
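As a back-of-the-envelope check, the printed count can be reproduced by hand from `Llama2Config7B`. This sketch assumes the architecture exactly as defined above: bias-free projections, an output head that is not weight-tied to the embedding, and one RMSNorm weight vector per norm.

```python
vocab_size, emb_dim, n_layers, hidden_dim = 32_000, 4_096, 32, 11_008

embedding = vocab_size * emb_dim         # tok_emb
out_head = emb_dim * vocab_size          # separate (untied) output projection
attention = 4 * emb_dim * emb_dim        # W_query, W_key, W_value, out_proj
feed_forward = 3 * emb_dim * hidden_dim  # fc1, fc2, fc3
norms = 2 * emb_dim                      # norm1 and norm2 weights
per_block = attention + feed_forward + norms

total = embedding + out_head + n_layers * per_block + emb_dim  # + final_norm
print(f"{total:,}")  # 6,738,415,616
```

Nearly two thirds of the parameters (about 4.3 billion) sit in the feed-forward layers.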


In [233]:
def model_memory_size(model, input_dtype=torch.float32):
    total_params = 0
    total_grads = 0
    for param in model.parameters():
        # Calculate total number of elements per parameter
        param_size = param.numel()
        total_params += param_size
        # Check if gradients are stored for this parameter
        if param.requires_grad:
            total_grads += param_size

    # Calculate buffer size (non-parameters that require memory)
    total_buffers = sum(buf.numel() for buf in model.buffers())

    # Size in bytes = (Number of elements) * (Size of each element in bytes)
    # We assume parameters and gradients are stored in the same type as input dtype
    element_size = torch.tensor(0, dtype=input_dtype).element_size()
    total_memory_bytes = (total_params + total_grads + total_buffers) * element_size

    # Convert bytes to gigabytes
    total_memory_gb = total_memory_bytes / (1024**3)

    return total_memory_gb

print(f"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB")
print(f"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB")

float32 (PyTorch default): 52.33 GB
bfloat16: 26.17 GB
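These figures can be reproduced with plain arithmetic. `model_memory_size` counts parameters, gradients, and buffers at one element size, and the buffers here are the per-layer causal mask plus the RoPE cos and sin tables registered in `MultiHeadedAttention`. This sketch assumes head_dim = 4096 / 32 = 128 and uses the same 1024**3 bytes per GB as the function above.

```python
n_params = 6_738_415_616
n_layers, context_length, head_dim = 32, 4096, 128  # head_dim = emb_dim / n_heads

# Buffers registered in every attention layer: causal mask + RoPE cos and sin tables
mask = context_length * context_length
rope = 2 * context_length * head_dim
n_buffers = n_layers * (mask + rope)

def gb(n_elements: int, bytes_per_element: int) -> float:
    return n_elements * bytes_per_element / 1024**3

# Parameters + gradients + buffers, all at one element size
fp32_gb = gb(2 * n_params + n_buffers, 4)
bf16_gb = gb(2 * n_params + n_buffers, 2)
mask_gb = gb(n_layers * mask, 4)

print(f"float32: {fp32_gb:.2f} GB")                      # 52.33
print(f"bfloat16: {bf16_gb:.2f} GB")                     # 26.17
print(f"causal masks alone (float32): {mask_gb:.2f} GB")  # 2.00
```

One takeaway: because every layer stores its own 4096 × 4096 mask, the masks alone account for about 2 GB in float32.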


### Load in the Tokenizer
