<a href="https://colab.research.google.com/github/Shrsht/LLaMA3-From-Scratch/blob/main/LLaMA_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

Annotate with images and explanations:

In [None]:
# vocabulary length. Llama's real vocab size is 128256. Here let's just use an absurdly small number


v = 10 # total number of distinct tokens in our entire vocabulary

# Llama's maximum sequence length is 8192, but for inference they cache 3/4 of it and only use an effective length of 2048. More on that later

seq_len = 5  # number of tokens in a given input sequence

# we'll use a batch size of 1 for simplicity when visualizing our tensors
b = 1 # batch size: how many sequences we process at once

# now let's make ourselves a list of token indices. Each represents somewhere between a letter and a word
tokens = torch.randint(v, (b, seq_len))
tokens.shape, tokens

(torch.Size([1, 5]), tensor([[1, 0, 6, 0, 2]]))

## Embedding our Vocabulary:

***Purpose of Embeddings:***

- We need a way to represent our input tokens as vectors in a continuous vector space.
- Ideally, tokens with similar meanings end up close to each other in that space.

### 1b. Initializing the first residual state

In [None]:
# our embedding dimension. Llama 3 8b's is 4096
d = 16


# initializing our token embedding matrix
embedding = nn.Embedding(v, d)


# each of our v = 10 tokens gets its own learnable 16-dimensional vector

# so the embedding matrix holds 10 vectors of 16 dimensions each: one vector per token in the vocabulary

embedding.weight.shape, embedding.weight
# each row in this embedding is a high-dimensional representation of its corresponding token



(torch.Size([10, 16]),
 Parameter containing:
 tensor([[ 0.5023,  0.1898, -0.0842, -0.4255, -0.3528, -1.0773, -0.6780,  0.9639,
          -0.8250, -0.4544, -1.9748,  1.2844,  0.2053, -0.7056, -0.3937,  0.5133],
         [ 0.7922, -1.0242,  0.1843, -1.0991, -0.9402, -1.9237,  1.4471, -1.6211,
          -0.4627,  0.0724, -2.5458,  0.3172,  0.7139,  0.8306, -0.3575, -0.5453],
         [ 1.0906, -0.5341,  1.5008, -2.1458, -0.2829,  0.4449, -2.3409, -0.6489,
          -0.3310, -0.3682, -0.3702,  1.2048, -0.4190, -0.5210, -0.8396, -1.5450],
         [-1.5210, -0.1012, -1.4199, -1.2587,  0.0276,  0.1067, -0.4456, -0.5917,
           0.6417, -0.7544,  0.8427, -0.7505,  0.7146, -1.1200, -0.0906,  0.0935],
         [-0.3263, -0.5366, -0.1466,  0.9948,  0.7581, -0.3410, -0.7700,  0.3221,
           0.7490,  1.5158, -0.8681, -1.2408,  1.0331,  0.0739,  0.3373, -0.6929],
         [ 0.0199,  0.6430, -0.6010,  1.3610, -0.8156,  0.6825,  0.5961, -0.9829,
          -0.6699, -1.5987,  0.2250,  0.3872,  

In [None]:
# We want the embeddings of our 5 specific tokens, so we feed their token indices into the embedding layer.

x = embedding(tokens)
x.shape, x

# at this point many models would multiply the embeddings by the square root of the embedding dimension, but Llama 3 forgoes that step

(torch.Size([1, 5, 16]),
 tensor([[[ 0.7922, -1.0242,  0.1843, -1.0991, -0.9402, -1.9237,  1.4471,
           -1.6211, -0.4627,  0.0724, -2.5458,  0.3172,  0.7139,  0.8306,
           -0.3575, -0.5453],
          [ 0.5023,  0.1898, -0.0842, -0.4255, -0.3528, -1.0773, -0.6780,
            0.9639, -0.8250, -0.4544, -1.9748,  1.2844,  0.2053, -0.7056,
           -0.3937,  0.5133],
          [ 0.4887, -0.1694,  0.4623,  1.1920,  0.2188, -2.1144, -0.6303,
           -1.2609, -0.4754, -0.4203, -0.3314,  0.0310,  0.8525,  0.0888,
            0.4840, -0.7512],
          [ 0.5023,  0.1898, -0.0842, -0.4255, -0.3528, -1.0773, -0.6780,
            0.9639, -0.8250, -0.4544, -1.9748,  1.2844,  0.2053, -0.7056,
           -0.3937,  0.5133],
          [ 1.0906, -0.5341,  1.5008, -2.1458, -0.2829,  0.4449, -2.3409,
           -0.6489, -0.3310, -0.3682, -0.3702,  1.2048, -0.4190, -0.5210,
           -0.8396, -1.5450]]], grad_fn=<EmbeddingBackward0>))
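An `nn.Embedding` layer is essentially a learnable lookup table: calling it on a tensor of token indices returns the corresponding rows of `embedding.weight`. A minimal check of that, using the notebook's toy sizes and a fixed seed of our own choosing (so the numbers will differ from the printout above):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
v, d = 10, 16                              # toy vocabulary size and embedding dimension
embedding = nn.Embedding(v, d)             # weight has shape (v, d): one row per token id
tokens = torch.tensor([[1, 0, 6, 0, 2]])   # (batch, seq_len) token indices

x = embedding(tokens)                      # (1, 5, 16)
rows = embedding.weight[tokens]            # plain advanced indexing into the weight matrix

print(x.shape)                             # torch.Size([1, 5, 16])
print(torch.equal(x, rows))                # True: the lookup is just row selection
print(torch.equal(x[0, 1], x[0, 3]))       # True: the repeated token id 0 gets the same vector
```

This is also why the second and fourth rows of `x` above are identical: both positions hold token `0`.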

## Positional Encoding:


By default, the attention mechanism is blind to the ordering of tokens: shuffling the input tokens just shuffles the outputs in the same way. So we need a way to encode each token's position.
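A small sketch of this blindness, assuming a single attention head with random weights and no causal mask (all names and sizes here are illustrative): feeding the same tokens in a shuffled order only shuffles the outputs, so nothing in the computation depends on position.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d = 5, 8
x = torch.randn(seq_len, d)                     # 5 token vectors with no positional information
wq, wk, wv = (torch.randn(d, d) for _ in range(3))

def attention(x):
    # plain single-head self-attention: no mask, no positional encoding
    q, k, v = x @ wq, x @ wk, x @ wv
    weights = F.softmax(q @ k.T / d ** 0.5, dim=-1)
    return weights @ v

perm = torch.tensor([3, 0, 4, 1, 2])            # a shuffled token order
out = attention(x)
out_shuffled = attention(x[perm])

# shuffling the inputs shuffles the outputs identically
print(torch.allclose(out[perm], out_shuffled, atol=1e-5))  # True
```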


### RoPE Encoding
<a id='c'></a>

Rotary Positional Encoding (RoPE) is a method [originally proposed in 2021](https://arxiv.org/abs/2104.09864) that quickly became the de facto standard for enabling transformers to understand positional information.

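The method uses trigonometry to "rotate" pairs of entries in the query and key vectors before they are multiplied together. Each vector is rotated by an angle proportional to its absolute position, so the dot product between a query at position $m$ and a key at position $n$ depends only on their relative rotation, i.e. on the offset $m - n$. That relative rotation is how the model can tell how far apart two tokens are.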
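The relative-position property can be checked numerically. This sketch treats one (query, key) pair of dimensions as a single complex number, as the implementation further down does, and uses a made-up rotation frequency `omega = 0.3`; the helpers `rotate` and `score` are hypothetical names for this illustration.

```python
import torch

torch.manual_seed(0)
omega = 0.3                                   # one illustrative rotation frequency
q = torch.randn(1, dtype=torch.cfloat)        # a query pair, viewed as one complex number
k = torch.randn(1, dtype=torch.cfloat)        # a key pair, viewed as one complex number

def rotate(z, pos):
    # rotate z by an angle proportional to its absolute position
    angle = torch.tensor(pos * omega)
    return z * torch.polar(torch.ones(()), angle)

def score(m, n):
    # Re(q_m * conj(k_n)) equals the dot product of the two rotated 2-D vectors
    return (rotate(q, m) * rotate(k, n).conj()).real

# same offset m - n = 2 at different absolute positions gives the same score
print(torch.allclose(score(3, 1), score(7, 5), atol=1e-6))   # True
# a different offset generally gives a different score
print(score(3, 1).item(), score(3, 0).item())
```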


We will precompute these positional encodings because we want to reuse them throughout the model as opposed to creating them from scratch every time we need them.
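The next cell builds these rotations from one frequency per pair of head dimensions. Written out, with $d_h$ standing for `head_dim` and $j$ indexing the pairs (notation introduced here just for this sketch), pair $j$ gets frequency $\omega_j$ and position $m$ is rotated by $m\,\omega_j$:

$$
\omega_j = \theta^{-2j/d_h}, \qquad j = 0, 1, \dots, \tfrac{d_h}{2} - 1,
\qquad
\texttt{freqs\_cis}[m, j] = e^{i m \omega_j} = \cos(m\,\omega_j) + i \sin(m\,\omega_j).
$$

With $\theta = 10000$ and $d_h = 4$ this gives $\omega_0 = 1$ and $\omega_1 = 10000^{-1/2} = 0.01$, so at position $m = 1$ the two entries are $e^{i} \approx 0.5403 + 0.8415i$ and $e^{0.01i} \approx 0.99995 + 0.0100i$, in line with the printed `freqs_cis` (up to display rounding).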

In [None]:
theta = 10000 # 10,000 is the most common value (from the original RoPE paper), but Llama 3 raises it to 500,000 to better support longer contexts
num_heads = 4 # Llama 3 8b has 32 total attention heads
head_dim = d // num_heads # Llama 3 ties its head dimension to the embedding dimension: 4096 // 32 = 128 in Llama 3 8B

# each pair of head dimensions j gets its own rotation frequency theta^(-2j / head_dim)
freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))
print(f'freqs: {freqs.shape}\n{freqs}\n')

t = torch.arange(seq_len * 2, device=freqs.device, dtype=torch.float32)
print(f't: {t.shape}\n{t}\n')

freqs = torch.outer(t, freqs)
print(f'freqs: {freqs.shape}\n{freqs}\n')

freqs_cis = torch.polar(torch.ones_like(freqs), freqs)[:seq_len]  # complex64
print(f'freqs_cis: {freqs_cis.shape}\n{freqs_cis}')

freqs: torch.Size([2])
tensor([1.0000, 0.0100])

t: torch.Size([10])
tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])

freqs: torch.Size([10, 2])
tensor([[0.0000, 0.0000],
        [1.0000, 0.0100],
        [2.0000, 0.0200],
        [3.0000, 0.0300],
        [4.0000, 0.0400],
        [5.0000, 0.0500],
        [6.0000, 0.0600],
        [7.0000, 0.0700],
        [8.0000, 0.0800],
        [9.0000, 0.0900]])

freqs_cis: torch.Size([5, 2])
tensor([[ 1.0000+0.0000j,  1.0000+0.0000j],
        [ 0.5403+0.8415j,  0.9999+0.0100j],
        [-0.4161+0.9093j,  0.9998+0.0200j],
        [-0.9900+0.1411j,  0.9996+0.0300j],
        [-0.6536-0.7568j,  0.9992+0.0400j]])


### Precomputing the Causal Mask
<a id='d'></a>


The basic idea of a causal mask: by default, attention lets every token attend to every other token in the sequence, including the ones that come after it.

This is fine or even preferable for some model types, but Llama is auto-regressive: it predicts each token from the tokens before it. If a position could attend to the token it is supposed to predict (or anything later) during training, it would learn to rely on information that won't exist yet at inference time.


The negative infinities in the upper triangle prevent the model from attending to the corresponding tokens; how this works will become clearer later when we do the attention softmax.
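Why $-\infty$? Softmax exponentiates each logit, and $e^{-\infty} = 0$, so masked positions receive exactly zero weight while the visible ones are renormalized to sum to 1. A quick check with random stand-in logits (the sizes and seed are arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len = 4
# causal mask: 0 on and below the diagonal, -inf strictly above it
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

scores = torch.randn(seq_len, seq_len)          # stand-in attention logits
probs = F.softmax(scores + mask, dim=-1)
print(probs)

# masked (future) positions get exactly zero probability, since exp(-inf) = 0 ...
print(probs[mask == float("-inf")])             # all zeros
# ... and each row still sums to 1 over the positions it is allowed to see
print(probs.sum(dim=-1))
```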

In [None]:
# a seq_len x seq_len tensor filled entirely with negative infinity
mask = torch.full(
    (seq_len, seq_len),
    float("-inf"))


# keep -inf strictly above the diagonal (future positions) and zero out everything else
mask = torch.triu(mask, diagonal=1)
mask

### -inf in the upper triangle acts as a mask that keeps each token from attending to tokens that come after it:

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])

## RMS Normalization:


We want a kind of normalization that does not re-center (subtract) the mean.

Root Mean Square Normalization (RMSNorm) has also been the norm for quite a while. Like its predecessor LayerNorm, RMSNorm restricts the variability of the entries in each embedding vector: before the learnable scale is applied, the vector is divided by its root mean square, so it lies on a hypersphere of radius $\sqrt{d}$. However, unlike LayerNorm, which first subtracts the mean so that the entries are centered at zero, RMSNorm leaves the mean alone, which is an important source of information for networks that use residual connections.
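A small sketch comparing the two on one vector with a deliberately non-zero mean (the offset of 3.0 and the seed are arbitrary; `F.layer_norm` is called without its learnable weight and bias):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 16
x = torch.randn(d) + 3.0                        # a vector with a clearly non-zero mean

# RMSNorm (before the learnable scale): divide by the root mean square
rms = x.pow(2).mean().sqrt()
x_rms = x / rms

# LayerNorm (no affine parameters): subtract the mean, then divide by the standard deviation
x_ln = F.layer_norm(x, (d,))

print(x_rms.norm(), d ** 0.5)                   # both ~4.0: the vector sits on a sphere of radius sqrt(d)
print(x_rms.mean(), x_ln.mean())                # RMSNorm keeps a non-zero mean; LayerNorm's is ~0
```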


In [None]:
x

tensor([[[ 0.7922, -1.0242,  0.1843, -1.0991, -0.9402, -1.9237,  1.4471,
          -1.6211, -0.4627,  0.0724, -2.5458,  0.3172,  0.7139,  0.8306,
          -0.3575, -0.5453],
         [ 0.5023,  0.1898, -0.0842, -0.4255, -0.3528, -1.0773, -0.6780,
           0.9639, -0.8250, -0.4544, -1.9748,  1.2844,  0.2053, -0.7056,
          -0.3937,  0.5133],
         [ 0.4887, -0.1694,  0.4623,  1.1920,  0.2188, -2.1144, -0.6303,
          -1.2609, -0.4754, -0.4203, -0.3314,  0.0310,  0.8525,  0.0888,
           0.4840, -0.7512],
         [ 0.5023,  0.1898, -0.0842, -0.4255, -0.3528, -1.0773, -0.6780,
           0.9639, -0.8250, -0.4544, -1.9748,  1.2844,  0.2053, -0.7056,
          -0.3937,  0.5133],
         [ 1.0906, -0.5341,  1.5008, -2.1458, -0.2829,  0.4449, -2.3409,
          -0.6489, -0.3310, -0.3682, -0.3702,  1.2048, -0.4190, -0.5210,
          -0.8396, -1.5450]]], grad_fn=<EmbeddingBackward0>)

In [None]:
# first let's set up the residual connection that we'll use later
h = x # our embedding vectors, which have NOT been positionally encoded (RoPE is applied later, to the queries and keys)
print(f'h: {h.shape}\n{h}')



# now we'll perform our first normalization
# first we square each entry in x and then take the mean of those values across each embedding vector
mean_squared = x.pow(2).mean(dim=-1, keepdim=True)
mean_squared


# then we multiply x by the reciprocal of the square roots of mean_squared
# 1e-6 is a very small number added for numerical stability in case mean_squared is 0 (an all-zero vector), since you can't divide by 0
x_normed = x * torch.rsqrt(mean_squared + 1e-6)
print(f'x_normed: {x_normed.shape}\n{x_normed}')

# and finally, we multiply by a learnable scale parameter
# This scale is initialized to ones, but if we were to train, those values would change
rms_scale = torch.ones(d)
print(f'rms_scale: {rms_scale.shape}\n{rms_scale}\n')

x_normed *= rms_scale
print(f'x_normed: {x_normed.shape}\n{x_normed}')

h: torch.Size([1, 5, 16])
tensor([[[ 0.7922, -1.0242,  0.1843, -1.0991, -0.9402, -1.9237,  1.4471,
          -1.6211, -0.4627,  0.0724, -2.5458,  0.3172,  0.7139,  0.8306,
          -0.3575, -0.5453],
         [ 0.5023,  0.1898, -0.0842, -0.4255, -0.3528, -1.0773, -0.6780,
           0.9639, -0.8250, -0.4544, -1.9748,  1.2844,  0.2053, -0.7056,
          -0.3937,  0.5133],
         [ 0.4887, -0.1694,  0.4623,  1.1920,  0.2188, -2.1144, -0.6303,
          -1.2609, -0.4754, -0.4203, -0.3314,  0.0310,  0.8525,  0.0888,
           0.4840, -0.7512],
         [ 0.5023,  0.1898, -0.0842, -0.4255, -0.3528, -1.0773, -0.6780,
           0.9639, -0.8250, -0.4544, -1.9748,  1.2844,  0.2053, -0.7056,
          -0.3937,  0.5133],
         [ 1.0906, -0.5341,  1.5008, -2.1458, -0.2829,  0.4449, -2.3409,
          -0.6489, -0.3310, -0.3682, -0.3702,  1.2048, -0.4190, -0.5210,
          -0.8396, -1.5450]]], grad_fn=<EmbeddingBackward0>)
x_normed: torch.Size([1, 5, 16])
tensor([[[ 0.6973, -0.9015,  0.162

In [None]:
# let's turn that RMSNorm into a module that we'll be able to reuse repeatedly later
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

## Multi-Query and Grouped-Query Attention:

<a id='f'></a>
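[Multi-query attention](https://arxiv.org/abs/1911.02150) (MQA) shares a single key head and a single value head across all of the query heads. Grouped-query attention (GQA), which Llama 3 uses and which the code below implements, sits in between: each key/value head is shared by a group of query heads.


GPUs have plenty of compute but comparatively limited memory bandwidth, and during inference the keys and values of every previous token must be cached and re-read at each decoding step. Using fewer K and V heads than Q heads shrinks that cache, typically with only a small loss in quality. It also slightly reduces the parameter count of the key and value projections.

The idea is that the model can make many queries to the residual state and have them answered by a smaller set of shared keys and values.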
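A back-of-the-envelope sketch of why sharing K/V heads matters at inference time. It assumes Llama 3 8B's published shape (32 layers, 32 query heads, 8 KV heads, head dimension 128), 2-byte bf16 cache entries, a single sequence, and an 8,192-token context; `kv_cache_bytes` is a hypothetical helper written for this illustration.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_value=2):
    # hypothetical helper: one key and one value vector per KV head, per layer, per cached token
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_value

GiB = 1024 ** 3
layers, q_heads, kv_heads, head_dim, ctx = 32, 32, 8, 128, 8192

mha = kv_cache_bytes(layers, q_heads, head_dim, ctx)   # every query head has its own K/V head
gqa = kv_cache_bytes(layers, kv_heads, head_dim, ctx)  # 4 query heads share each K/V head
mqa = kv_cache_bytes(layers, 1, head_dim, ctx)         # all query heads share one K/V head

print(f"MHA: {mha / GiB:.3f} GiB, GQA: {gqa / GiB:.3f} GiB, MQA: {mqa / GiB:.3f} GiB")
# MHA: 4.000 GiB, GQA: 1.000 GiB, MQA: 0.125 GiB
```

Under these assumptions GQA cuts the cache by 4x relative to full multi-head attention, while MQA would cut it by 32x at a greater risk to quality.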

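**Why do we need normalization?** The attention projections below read `x_normed` rather than the raw residual `h`. Normalizing first keeps the scale of each token's vector under control even as the residual stream accumulates the outputs of many layers, which helps keep training stable.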

In [None]:
# first up, remember we're currently working with two separate objects
# h is for the residual connection and x_normed will go into our attention calculation
h, x_normed

(tensor([[[ 0.7922, -1.0242,  0.1843, -1.0991, -0.9402, -1.9237,  1.4471,
           -1.6211, -0.4627,  0.0724, -2.5458,  0.3172,  0.7139,  0.8306,
           -0.3575, -0.5453],
          [ 0.5023,  0.1898, -0.0842, -0.4255, -0.3528, -1.0773, -0.6780,
            0.9639, -0.8250, -0.4544, -1.9748,  1.2844,  0.2053, -0.7056,
           -0.3937,  0.5133],
          [ 0.4887, -0.1694,  0.4623,  1.1920,  0.2188, -2.1144, -0.6303,
           -1.2609, -0.4754, -0.4203, -0.3314,  0.0310,  0.8525,  0.0888,
            0.4840, -0.7512],
          [ 0.5023,  0.1898, -0.0842, -0.4255, -0.3528, -1.0773, -0.6780,
            0.9639, -0.8250, -0.4544, -1.9748,  1.2844,  0.2053, -0.7056,
           -0.3937,  0.5133],
          [ 1.0906, -0.5341,  1.5008, -2.1458, -0.2829,  0.4449, -2.3409,
           -0.6489, -0.3310, -0.3682, -0.3702,  1.2048, -0.4190, -0.5210,
           -0.8396, -1.5450]]], grad_fn=<EmbeddingBackward0>),
 tensor([[[ 0.6973, -0.9015,  0.1623, -0.9675, -0.8276, -1.6933,  1.2738,
   

In [None]:
# let's define the hyperparameters of grouped-query attention (GQA)

num_kv_heads = 2 # Llama 3 8B uses 8 key and value heads per layer (shared by its 32 query heads)
assert num_heads % num_kv_heads == 0 # each KV head must serve a whole group of query heads, so check for exact divisibility
print(f"as a reminder: num_heads = {num_heads}, head_dim = {head_dim}")


## nn.Linear(): This creates a linear layer, which applies a linear transformation to the input data.
## d: This is the input dimension of the linear layer, meaning the layer expects input tensors to have d = 16 features.
## (num_heads * head_dim) : This is the output dimension of the linear layer. It calculates the total number of output features by multiplying the number of attention heads (num_heads) by the dimension of each head (head_dim).
## bias=False: This argument indicates that the linear layer should not have a bias term.

# now we'll initialize our self-attention Weight Matrices
wq = nn.Linear(d, num_heads * head_dim, bias=False)
wk = nn.Linear(d, num_kv_heads * head_dim, bias=False)
wv = nn.Linear(d, num_kv_heads * head_dim, bias=False)
print("Attention weights: ", wq.weight.shape, wk.weight.shape, wv.weight.shape)

# and project x_normed out to get our queries, keys and values
xq = wq(x_normed)
xk = wk(x_normed)
xv = wv(x_normed)
print("Attention projections: ", xq.shape, xk.shape, xv.shape)

# then reshape them to separate out by head
xq = xq.view(b, seq_len, num_heads, head_dim)
xk = xk.view(b, seq_len, num_kv_heads, head_dim)
xv = xv.view(b, seq_len, num_kv_heads, head_dim)
print("Reshaped: ", xq.shape, xk.shape, xv.shape)

as a reminder: num_heads = 4, head_dim = 4
Attention weights:  torch.Size([16, 16]) torch.Size([8, 16]) torch.Size([8, 16])
Attention projections:  torch.Size([1, 5, 16]) torch.Size([1, 5, 8]) torch.Size([1, 5, 8])
Reshaped:  torch.Size([1, 5, 4, 4]) torch.Size([1, 5, 2, 4]) torch.Size([1, 5, 2, 4])
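With `num_heads = 4` and `num_kv_heads = 2`, each KV head serves a group of two consecutive query heads. The self-attention cell later expands the KV heads with `torch.repeat_interleave`; this small sketch, using head indices as stand-ins for real tensors, shows which KV head each query head ends up paired with.

```python
import torch

num_heads, num_kv_heads = 4, 2
group_size = num_heads // num_kv_heads          # query heads per KV head

kv_ids = torch.arange(num_kv_heads)             # stand-in for the KV heads: just their indices
expanded = torch.repeat_interleave(kv_ids, group_size, dim=0)
tiled = kv_ids.repeat(group_size)               # tiling, shown only for contrast

print(expanded)  # tensor([0, 0, 1, 1]): query heads 0,1 -> KV head 0; heads 2,3 -> KV head 1
print(tiled)     # tensor([0, 1, 0, 1]): tiling would pair the heads differently
```

In other words, query head `h` reads from KV head `h // group_size`.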


### 1g. Apply Pre-Computed RoPE Positional Encodings:
<a id='g'></a>

Earlier we pre-computed the frequencies for rotation. Now we'll actually apply our rotary embeddings.

In [None]:
# first we reshape and then view our queries and keys as complex values, the type of number that works well with rotation
xq = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
print(f'xq: {xq.shape}\n{xq}\n')
print(f'xk: {xk.shape}\n{xk}')


ndim = xq.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (xq.shape[1], xq.shape[-1]), f'freqs_cis.shape {freqs_cis.shape} != xq.shape[1], xq.shape[-1] {(xq.shape[1], xq.shape[-1])}'

# build a broadcastable shape for freqs_cis: keep the seq_len and last dimensions, set the others to 1
shape = [dim_size if i == 1 or i == xq.ndim - 1 else 1 for i, dim_size in enumerate(xq.shape)]
print(f'shape: {shape}\n')

freqs_cis = freqs_cis.view(*shape)
print(f'freqs_cis: {freqs_cis.shape}\n{freqs_cis}')

# now multiply the data by the frequencies, turn them back into real numbers, revert the shape and make sure they're of the right type
xq = torch.view_as_real(xq * freqs_cis).flatten(3).type_as(xv)
xk = torch.view_as_real(xk * freqs_cis).flatten(3).type_as(xv)
print(f'xq: {xq.shape}\n{xq}\n')
print(f'xk: {xk.shape}\n{xk}')



xq: torch.Size([1, 5, 4, 2])
tensor([[[[-1.0814-1.2398j, -0.8482-0.3291j],
          [ 0.0121-1.7533j,  0.2948-0.1679j],
          [ 0.9378+0.1311j,  0.4642-0.0390j],
          [-0.2767-0.9907j,  0.1970+0.5204j]],

         [[ 0.1150+0.2663j, -0.2363-0.5226j],
          [ 0.8432-0.4822j,  1.5463-0.9210j],
          [ 0.6538-0.0643j,  0.5585-0.4385j],
          [-0.0605-0.4828j,  0.1684+0.1511j]],

         [[-0.1797-0.2311j, -0.6211-0.4716j],
          [ 0.8583-0.6888j, -0.1591-1.0148j],
          [ 0.7652-0.3836j,  0.2293-0.0237j],
          [-1.1040-1.0531j,  0.2276-0.1342j]],

         [[ 0.1150+0.2663j, -0.2363-0.5226j],
          [ 0.8432-0.4822j,  1.5463-0.9210j],
          [ 0.6538-0.0643j,  0.5585-0.4385j],
          [-0.0605-0.4828j,  0.1684+0.1511j]],

         [[ 0.4101-0.3572j,  0.4347-0.5412j],
          [ 0.8725-1.0928j,  0.0247-0.2400j],
          [-0.4210+0.0369j,  0.6348-0.8082j],
          [-0.3813-1.1089j, -0.1655+0.0956j]]]],
       grad_fn=<ViewAsComplexBackward0>)
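Multiplying by a complex number $e^{i\phi}$ is the same as rotating the 2-D vector (real part, imaginary part) by $\phi$. This sketch, with one made-up head vector and illustrative angles, checks that the `view_as_complex` route used above matches an explicit 2x2 rotation of each adjacent pair, and that the rotation leaves the vector's length unchanged.

```python
import math
import torch

torch.manual_seed(0)
x = torch.randn(4)                              # one head vector with head_dim = 4 (two pairs)
angles = torch.tensor([0.7, 0.05])              # illustrative per-pair rotation angles (m * omega_j)

# complex route, as in the cell above: pair up adjacent entries and multiply by e^{i*angle}
xc = torch.view_as_complex(x.reshape(-1, 2))
rotated = torch.view_as_real(xc * torch.polar(torch.ones(2), angles)).flatten()

# explicit route: apply a 2x2 rotation matrix to each (x[2j], x[2j+1]) pair
expected = []
for j, a in enumerate(angles.tolist()):
    x0, x1 = x[2 * j].item(), x[2 * j + 1].item()
    expected += [x0 * math.cos(a) - x1 * math.sin(a), x0 * math.sin(a) + x1 * math.cos(a)]
expected = torch.tensor(expected)

print(torch.allclose(rotated, expected, atol=1e-6))        # True
print(torch.allclose(rotated.norm(), x.norm(), atol=1e-5)) # True: rotation preserves length
```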

### 1h. Calculating Self-Attention:



In [None]:
# If the number of K & V heads differs from the number of query heads, repeat each K/V head so that every query head has a matching one.
if num_kv_heads != num_heads:
  num_queries_per_kv = num_heads // num_kv_heads
  xk = torch.repeat_interleave(xk, num_queries_per_kv, dim=2)
  xv = torch.repeat_interleave(xv, num_queries_per_kv, dim=2)

xq.shape, xk.shape, xv.shape


# Transposes Q, K, and V tensors to align them for the batch matrix multiplication in attention calculation.
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)

xq.shape, xk.shape, xv.shape


# Calculates attention logits by performing a batch matrix multiplication between queries and keys
scores = torch.matmul(xq, xk.transpose(2, 3))

# then we scale the logits by the reciprocal of the square root of the head dimension
scores = scores / math.sqrt(head_dim)

scores.shape, scores

(torch.Size([1, 4, 5, 5]),
 tensor([[[[-1.1030, -0.5156, -0.1514, -0.5465, -0.0455],
           [-0.1411, -0.2138, -0.0174, -0.1974,  0.0170],
           [-0.2549, -0.3824, -0.4898, -0.4011, -0.2347],
           [-0.2690, -0.2003, -0.1295, -0.2138,  0.0284],
           [-0.0351,  0.1346,  0.4226,  0.1451,  0.3840]],
 
          [[-0.3657,  0.1571, -0.2842,  0.0486,  0.1757],
           [ 1.0138,  0.6935,  0.2995,  0.6707,  0.5885],
           [ 0.2175, -0.2403, -0.0526, -0.1974,  0.0255],
           [ 0.5534,  0.6405,  1.1798,  0.6935,  0.9333],
           [-0.4328, -0.0763,  0.7236, -0.0185,  0.3464]],
 
          [[ 0.0534, -0.0956, -0.0147,  0.1148, -0.0319],
           [-0.2813, -0.0266,  0.0997, -0.0041, -0.1081],
           [-0.3534,  0.0218,  0.1409, -0.0963,  0.0506],
           [-0.1076,  0.1552,  0.2795, -0.0266, -0.0063],
           [-0.3034, -0.0219,  0.0904,  0.0060, -0.2260]],
 
          [[ 0.7585,  0.1254,  0.1158,  0.2035,  0.2392],
           [ 0.2652, -0.0244, -0.018
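The division by $\sqrt{\text{head\_dim}}$ keeps the logits from growing with the head size: if query and key entries are independent with mean 0 and variance 1, their dot product has variance `head_dim`, and large logits push the softmax toward a one-hot distribution. A quick empirical check with random vectors (the sample size and head sizes are arbitrary choices for the sketch):

```python
import torch

torch.manual_seed(0)
n = 20_000                                      # number of random query/key pairs per head size
results = {}
for head_dim in (4, 64, 256):
    q = torch.randn(n, head_dim)                # entries with mean 0 and variance 1
    k = torch.randn(n, head_dim)
    dots = (q * k).sum(dim=-1)                  # n independent query-key dot products
    scaled = dots / head_dim ** 0.5             # the scaling used in the attention cell
    results[head_dim] = (dots.var().item(), scaled.var().item())
    print(f"head_dim={head_dim:4d}  raw var={results[head_dim][0]:8.2f}  scaled var={results[head_dim][1]:.3f}")
# raw variance grows roughly in proportion to head_dim; scaled variance stays near 1
```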

## Applying the mask we created earlier:

In [None]:
# now we get to use the mask that we precomputed earlier
scores = scores + mask

scores.shape, scores

(torch.Size([1, 4, 5, 5]),
 tensor([[[[-1.1030,    -inf,    -inf,    -inf,    -inf],
           [-0.1411, -0.2138,    -inf,    -inf,    -inf],
           [-0.2549, -0.3824, -0.4898,    -inf,    -inf],
           [-0.2690, -0.2003, -0.1295, -0.2138,    -inf],
           [-0.0351,  0.1346,  0.4226,  0.1451,  0.3840]],
 
          [[-0.3657,    -inf,    -inf,    -inf,    -inf],
           [ 1.0138,  0.6935,    -inf,    -inf,    -inf],
           [ 0.2175, -0.2403, -0.0526,    -inf,    -inf],
           [ 0.5534,  0.6405,  1.1798,  0.6935,    -inf],
           [-0.4328, -0.0763,  0.7236, -0.0185,  0.3464]],
 
          [[ 0.0534,    -inf,    -inf,    -inf,    -inf],
           [-0.2813, -0.0266,    -inf,    -inf,    -inf],
           [-0.3534,  0.0218,  0.1409,    -inf,    -inf],
           [-0.1076,  0.1552,  0.2795, -0.0266,    -inf],
           [-0.3034, -0.0219,  0.0904,  0.0060, -0.2260]],
 
          [[ 0.7585,    -inf,    -inf,    -inf,    -inf],
           [ 0.2652, -0.0244,    -in

In [None]:
# now we perform the softmax operation to get our actual probabilities
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores
# notice that thanks to the causal mask, 0 probability is placed on future tokens

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5182, 0.4818, 0.0000, 0.0000, 0.0000],
          [0.3744, 0.3296, 0.2960, 0.0000, 0.0000],
          [0.2338, 0.2504, 0.2688, 0.2471, 0.0000],
          [0.1542, 0.1828, 0.2438, 0.1847, 0.2345]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5794, 0.4206, 0.0000, 0.0000, 0.0000],
          [0.4174, 0.2640, 0.3186, 0.0000, 0.0000],
          [0.1956, 0.2134, 0.3660, 0.2250, 0.0000],
          [0.1075, 0.1536, 0.3418, 0.1627, 0.2344]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4367, 0.5633, 0.0000, 0.0000, 0.0000],
          [0.2442, 0.3554, 0.4004, 0.0000, 0.0000],
          [0.2059, 0.2677, 0.3032, 0.2232, 0.0000],
          [0.1600, 0.2120, 0.2372, 0.2180, 0.1729]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5719, 0.4281, 0.0000, 0.0000, 0.0000],
          [0.4569, 0.2776, 0.2655, 0.0000, 0.0000],
          [0.2253, 0.2543, 0.2753, 0.2451, 0.0000],
      

In [None]:
# then matmul by our values projection
output = torch.matmul(scores, xv)
output.shape, output

(torch.Size([1, 4, 5, 4]),
 tensor([[[[-0.3411, -0.2834,  0.7208,  0.2801],
           [-0.6501, -0.7035,  0.7213,  0.7681],
           [-0.6010, -0.7171,  0.5341,  0.6377],
           [-0.7042, -0.8500,  0.5515,  0.8055],
           [-0.5821, -0.7557,  0.4780,  0.8053]],
 
          [[-0.3411, -0.2834,  0.7208,  0.2801],
           [-0.6108, -0.6501,  0.7213,  0.7061],
           [-0.5627, -0.6711,  0.5198,  0.5731],
           [-0.6823, -0.8466,  0.4901,  0.7535],
           [-0.5654, -0.7596,  0.4161,  0.7613]],
 
          [[ 0.8611,  0.4087,  1.4048,  0.7438],
           [ 0.6380,  0.3079,  0.6681,  0.8952],
           [ 0.2859,  0.0231,  0.5375,  0.9409],
           [ 0.3377,  0.0770,  0.4579,  0.9527],
           [ 0.2379, -0.0508,  0.4032,  0.7980]],
 
          [[ 0.8611,  0.4087,  1.4048,  0.7438],
           [ 0.6916,  0.3321,  0.8449,  0.8588],
           [ 0.4630,  0.1454,  0.7748,  0.8858],
           [ 0.3646,  0.0979,  0.4749,  0.9479],
           [ 0.1775, -0.1153,  0.

In [None]:
# and reshape to put the sequence length back into place and the outputs of our heads lined up
output = output.transpose(1, 2).contiguous().view(b, seq_len, -1)
output.shape, output

(torch.Size([1, 5, 16]),
 tensor([[[-0.3411, -0.2834,  0.7208,  0.2801, -0.3411, -0.2834,  0.7208,
            0.2801,  0.8611,  0.4087,  1.4048,  0.7438,  0.8611,  0.4087,
            1.4048,  0.7438],
          [-0.6501, -0.7035,  0.7213,  0.7681, -0.6108, -0.6501,  0.7213,
            0.7061,  0.6380,  0.3079,  0.6681,  0.8952,  0.6916,  0.3321,
            0.8449,  0.8588],
          [-0.6010, -0.7171,  0.5341,  0.6377, -0.5627, -0.6711,  0.5198,
            0.5731,  0.2859,  0.0231,  0.5375,  0.9409,  0.4630,  0.1454,
            0.7748,  0.8858],
          [-0.7042, -0.8500,  0.5515,  0.8055, -0.6823, -0.8466,  0.4901,
            0.7535,  0.3377,  0.0770,  0.4579,  0.9527,  0.3646,  0.0979,
            0.4749,  0.9479],
          [-0.5821, -0.7557,  0.4780,  0.8053, -0.5654, -0.7596,  0.4161,
            0.7613,  0.2379, -0.0508,  0.4032,  0.7980,  0.1775, -0.1153,
            0.3722,  0.7496]]], grad_fn=<ViewBackward0>))
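The steps so far (scores, scaling, causal mask, softmax, weighted sum of values) are what PyTorch's `F.scaled_dot_product_attention` computes in a single call. A hedged cross-check on random tensors shaped like ours, `(batch, heads, seq_len, head_dim)`, assuming PyTorch 2.0 or later:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, n_heads, seq_len, head_dim = 1, 4, 5, 4
xq = torch.randn(b, n_heads, seq_len, head_dim)
xk = torch.randn(b, n_heads, seq_len, head_dim)
xv = torch.randn(b, n_heads, seq_len, head_dim)

# manual version, following the steps in this notebook
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
scores = xq @ xk.transpose(2, 3) / head_dim ** 0.5 + mask
manual = F.softmax(scores, dim=-1) @ xv

# PyTorch's built-in version; is_causal=True applies the same lower-triangular mask
fused = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)

print(torch.allclose(manual, fused, atol=1e-5))   # True
```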

In [None]:
# finally we can initialize and apply our output projection that mixes the information from the heads together
wo = nn.Linear(num_heads * head_dim, d, bias=False)
Xout = wo(output)
Xout.shape, Xout

(torch.Size([1, 5, 16]),
 tensor([[[ 0.1081,  0.1188, -0.5131,  0.0919,  0.1291,  0.3527,  0.2897,
            0.0015,  0.4972,  0.4286, -0.1191,  0.3706, -0.2346, -0.0786,
           -0.2442, -0.5780],
          [ 0.3537,  0.4496, -0.4476,  0.2291,  0.0298,  0.0955, -0.1814,
           -0.1036,  0.5948,  0.2303, -0.1553,  0.7117, -0.0258, -0.0470,
           -0.2564, -1.0106],
          [ 0.3688,  0.5131, -0.4693,  0.1576, -0.0749, -0.0713, -0.1770,
           -0.0439,  0.4747,  0.1959, -0.1184,  0.7113,  0.0079,  0.0852,
           -0.1823, -0.9736],
          [ 0.4203,  0.5579, -0.4546,  0.2267, -0.0476, -0.0901, -0.2981,
           -0.0890,  0.5570,  0.1949, -0.1651,  0.7934,  0.0639,  0.1002,
           -0.1195, -1.2026],
          [ 0.3947,  0.4520, -0.4219,  0.1850, -0.0317, -0.1611, -0.2762,
           -0.0674,  0.4483,  0.1784, -0.1560,  0.6925,  0.0200,  0.0808,
           -0.0115, -1.0943]]], grad_fn=<UnsafeViewBackward0>))

### 1i. Our first residual connection
<a id='i'></a>
Here we add the output of our attention mechanism back to our residual state `h`. (The normalization happened before attention, when we computed `x_normed`.)

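**What is the residual state?** It is the running representation `h` of each token. Rather than replacing it, each sub-layer (attention, then the feedforward network) reads a normalized copy of it and adds its own output back in, so the original information is carried forward and each layer only contributes an update.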

In [None]:
h = h + Xout # out-of-place add, so the embedding tensor x (which h pointed to) is left unchanged
h.shape, h

(torch.Size([1, 5, 16]),
 tensor([[[ 9.0023e-01, -9.0545e-01, -3.2878e-01, -1.0072e+00, -8.1104e-01,
           -1.5710e+00,  1.7369e+00, -1.6196e+00,  3.4522e-02,  5.0097e-01,
           -2.6649e+00,  6.8783e-01,  4.7931e-01,  7.5197e-01, -6.0171e-01,
           -1.1232e+00],
          [ 8.5605e-01,  6.3933e-01, -5.3178e-01, -1.9641e-01, -3.2296e-01,
           -9.8177e-01, -8.5948e-01,  8.6021e-01, -2.3014e-01, -2.2408e-01,
           -2.1301e+00,  1.9960e+00,  1.7949e-01, -7.5255e-01, -6.5012e-01,
           -4.9730e-01],
          [ 8.5751e-01,  3.4367e-01, -7.0630e-03,  1.3496e+00,  1.4391e-01,
           -2.1857e+00, -8.0728e-01, -1.3049e+00, -7.5397e-04, -2.2443e-01,
           -4.4976e-01,  7.4226e-01,  8.6041e-01,  1.7402e-01,  3.0177e-01,
           -1.7248e+00],
          [ 9.2261e-01,  7.4765e-01, -5.3878e-01, -1.9884e-01, -4.0034e-01,
           -1.1674e+00, -9.7616e-01,  8.7485e-01, -2.6800e-01, -2.5943e-01,
           -2.1399e+00,  2.0778e+00,  2.6921e-01, -6.0532e-01, -
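A subtlety worth knowing about the residual update: `h = x` in the normalization cell does not copy anything, it makes `h` a second name for the same tensor. An in-place update such as `h += ...` therefore also changes `x`, while `h = h + ...` creates a new tensor. A small illustration (the names `update`, `x2` and `h2` are just for this sketch):

```python
import torch

update = torch.ones(3)

# in-place: h and x are two names for one tensor, so x changes too
x = torch.zeros(3)
h = x
h += update
print(x)        # tensor([1., 1., 1.])

# out-of-place: h2 + update builds a new tensor and rebinds h2; x2 is untouched
x2 = torch.zeros(3)
h2 = x2
h2 = h2 + update
print(x2)       # tensor([0., 0., 0.])
```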

In [None]:
# then we'll normalize the current state of our residual for use in our feedforward network
pre_ffwd_norm = RMSNorm(d)
h_normed = pre_ffwd_norm(h)
# so now we're working with h, which we'll use later for our next residual connection, and h_normed, which is used by our SwiGLU feedforward network

### 1j. The SwiGLU Feedforward Network
<a id='j'></a>

Surprisingly, Llama models have not opted for a mixture-of-experts strategy, which I was assuming they'd adopt by now. Their feedforward networks use the SwiGLU activation, which uses the activation function as a gate that dynamically determines what information gets through.
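In formula form, the SwiGLU feedforward computes $\mathrm{down}\big(\mathrm{SiLU}(\mathrm{gate}(x)) \odot \mathrm{up}(x)\big)$, where $\mathrm{SiLU}(z) = z \cdot \sigma(z)$. A minimal sketch with the notebook's toy sizes; the class name `SwiGLUFeedForward` is hypothetical, not Llama's own, and the block also checks `F.silu` against its definition.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFeedForward(nn.Module):
    # hypothetical module name; mirrors the up/gate/down projections used in this notebook
    def __init__(self, d: int, hidden_dim: int):
        super().__init__()
        self.up = nn.Linear(d, hidden_dim, bias=False)
        self.gate = nn.Linear(d, hidden_dim, bias=False)
        self.down = nn.Linear(hidden_dim, d, bias=False)

    def forward(self, x):
        # SiLU(gate(x)) acts as a soft, input-dependent gate on up(x), element by element
        return self.down(F.silu(self.gate(x)) * self.up(x))

torch.manual_seed(0)
z = torch.linspace(-4, 4, 9)
print(torch.allclose(F.silu(z), z * torch.sigmoid(z)))   # True: SiLU(z) = z * sigmoid(z)

ffn = SwiGLUFeedForward(d=16, hidden_dim=256)
x = torch.randn(1, 5, 16)
print(ffn(x).shape)                                       # torch.Size([1, 5, 16])
```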

In [None]:
# first we need to define our actual hidden dimension, which Llama's code does in an unnecessarily complicated manner
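hidden_dim = 4 * d # usually I would designate a hyperparameter for this 4, but in Llama's code it's simply hard-coded
print(hidden_dim)
hidden_dim = int(2 * hidden_dim / 3) # SwiGLU uses three weight matrices instead of two, so shrinking by 2/3 keeps the parameter count comparable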
print(hidden_dim)
multiple_of = 256 # their description of this was "make SwiGLU hidden layer size multiple of large power of 2"
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
print(hidden_dim)
# so basically this overly convoluted setup is designed to ensure that hidden_dim is a multiple of 256, likely for hardware efficiency reasons

64
42
256
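For a sense of scale, here is the same rounding recipe applied to Llama 3 8B. This sketch assumes the values in Meta's published `params.json` for that model (`dim = 4096`, `multiple_of = 1024`, `ffn_dim_multiplier = 1.3`) and the reference code's extra step that scales by `ffn_dim_multiplier` before rounding; the function name `llama_ffn_hidden_dim` is hypothetical.

```python
def llama_ffn_hidden_dim(dim, multiple_of, ffn_dim_multiplier=None):
    # hypothetical helper reproducing the rounding recipe from the cell above
    hidden_dim = 4 * dim
    hidden_dim = int(2 * hidden_dim / 3)          # keeps parameters close to a 4*dim two-matrix MLP
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # round up to the next multiple of `multiple_of`
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

print(llama_ffn_hidden_dim(16, 256))              # 256, the toy value used in this notebook
print(llama_ffn_hidden_dim(4096, 1024, 1.3))      # 14336 for Llama 3 8B
```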


In [None]:
up = nn.Linear(d, hidden_dim, bias=False)
gate = nn.Linear(d, hidden_dim, bias=False)
down = nn.Linear(hidden_dim, d, bias=False)

up_proj = up(h_normed)
print(up_proj.shape, up_proj)

gate_proj = F.silu(gate(h_normed))
print(gate_proj.shape, gate_proj)

ffwd_output = down(up_proj * gate_proj)
print(ffwd_output.shape, ffwd_output)

# and then do our final residual connection of this layer
out = h + ffwd_output
print(out.shape, out)

torch.Size([1, 5, 256]) tensor([[[-0.0257, -0.4757,  0.2491,  ..., -0.3339, -0.1644, -0.6306],
         [-0.1191,  0.1224, -0.4527,  ..., -0.6046, -0.8389,  0.7007],
         [ 0.0750, -0.1991,  0.2303,  ..., -0.3422,  0.0081,  0.2349],
         [-0.1029,  0.1211, -0.4618,  ..., -0.5956, -0.7935,  0.6934],
         [-0.1519, -0.3932, -0.9028,  ..., -0.0702, -0.6066, -0.2438]]],
       grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 5, 256]) tensor([[[-0.1094,  0.3486, -0.2633,  ..., -0.0676,  0.1021, -0.1896],
         [ 0.0761, -0.0480, -0.2195,  ...,  0.0683, -0.1498, -0.1238],
         [-0.1262, -0.1389, -0.2595,  ..., -0.1223, -0.0722,  0.2790],
         [ 0.0541, -0.0701, -0.2385,  ...,  0.0567, -0.1545, -0.1101],
         [ 1.1919, -0.2752, -0.2087,  ..., -0.1679, -0.2504,  0.2954]]],
       grad_fn=<SiluBackward0>)
torch.Size([1, 5, 16]) tensor([[[ 0.1607,  0.0911, -0.1415,  0.0118,  0.0690,  0.1104,  0.0337,
           0.0371,  0.0367,  0.0695,  0.0220, -0.1304,  0.0840,  0.0495,

## Output:

So usually we'd run it back on steps 1e through 1j for however many layers our model has (Llama 3 8b uses 32) using different weight matrices but you get the point.


Since our current `out` has the same shape it would have if we ran more layers, let's go ahead and see what Llama's output mechanism looks like. It's nothing fancy, just a linear layer. Notably, they chose to use a separate linear layer rather than reusing the embedding matrix (weight tying), as is relatively common.
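For comparison, weight tying (which Llama 3 8B does not use) would reuse the embedding matrix as the output projection. Since `nn.Embedding(v, d).weight` and `nn.Linear(d, v).weight` both have shape `(v, d)`, one parameter can serve both roles. A minimal sketch with the notebook's toy sizes; `lm_head` is an illustrative name.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
v, d = 10, 16
embedding = nn.Embedding(v, d)
lm_head = nn.Linear(d, v, bias=False)

print(embedding.weight.shape, lm_head.weight.shape)   # both torch.Size([10, 16])

lm_head.weight = embedding.weight                     # tie: both modules now share one Parameter
assert lm_head.weight is embedding.weight

h = torch.randn(1, 5, d)
logits = lm_head(h)                                   # (1, 5, v), i.e. h @ embedding.weight.T
print(torch.allclose(logits, h @ embedding.weight.T, atol=1e-6))  # True
```

Tying saves `v * d` parameters, which is substantial with a 128,256-token vocabulary.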

In [None]:
# first we norm the residual state
final_norm = RMSNorm(d)
out_normed = final_norm(out)

In [None]:
# then multiply by the linear layer to get our final output logits
final_output = nn.Linear(d, v, bias=False)
logits = final_output(out_normed).float()
logits.shape, logits

(torch.Size([1, 5, 10]),
 tensor([[[ 0.3093,  0.6462, -0.8326, -0.5472,  0.2947, -0.7590,  0.6102,
           -1.3271,  0.0114, -0.1131],
          [-0.7030,  0.2899, -0.0378,  0.1744,  0.7107, -0.4401,  0.2596,
           -0.6487, -0.0036,  1.3435],
          [-0.8501, -0.3519,  0.0763, -0.6963,  0.4976, -0.6761,  0.5076,
           -1.2458,  0.7464,  0.8849],
          [-0.7211,  0.2250, -0.0505,  0.1692,  0.7335, -0.4843,  0.2953,
           -0.7043,  0.0515,  1.3766],
          [-0.5662, -0.1354, -0.0931,  0.6919,  0.0322, -0.4636,  0.1797,
           -0.1944, -0.2474,  0.5938]]], grad_fn=<UnsafeViewBackward0>))

In [None]:
# softmax the logits to get, at every position in the sequence, a probability distribution over the vocabulary for the next token
probs = F.softmax(logits, dim=-1)
probs

tensor([[[0.1348, 0.1888, 0.0430, 0.0573, 0.1329, 0.0463, 0.1822, 0.0262,
          0.1001, 0.0884],
         [0.0372, 0.1004, 0.0723, 0.0894, 0.1529, 0.0484, 0.0974, 0.0393,
          0.0749, 0.2879],
         [0.0377, 0.0620, 0.0951, 0.0439, 0.1450, 0.0448, 0.1465, 0.0254,
          0.1860, 0.2136],
         [0.0362, 0.0933, 0.0708, 0.0882, 0.1551, 0.0459, 0.1001, 0.0368,
          0.0784, 0.2951],
         [0.0534, 0.0822, 0.0858, 0.1880, 0.0972, 0.0592, 0.1127, 0.0775,
          0.0735, 0.1705]]], grad_fn=<SoftmaxBackward0>)

In [None]:
# Greedily decode the probabilities to get our final predicted indices
greedy_indices = torch.argmax(probs, dim=-1)
greedy_indices
# if we were performing inference rather than training, the prediction at the final position would be the next token shown to the user

tensor([[1, 9, 9, 9, 3]])
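Greedy decoding always takes the single most likely token. A common alternative at inference time is to sample from the distribution at the last position, optionally sharpened or flattened with a temperature. A hedged sketch on made-up logits (the `temperature` value and the logits are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 5, 10)                    # (batch, seq_len, vocab) stand-in logits

last = logits[:, -1, :]                           # only the final position predicts the next token
greedy_next = torch.argmax(last, dim=-1)          # greedy: the most likely token id

temperature = 0.8                                 # < 1 sharpens, > 1 flattens the distribution
probs = F.softmax(last / temperature, dim=-1)
sampled_next = torch.multinomial(probs, num_samples=1)   # one random draw per batch row

print(greedy_next, sampled_next)
```

Dividing by a positive temperature never changes which token is most likely, so greedy decoding gives the same answer at any temperature; only sampling is affected.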