In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from dataclasses import dataclass
from typing import Optional

In [4]:
# Class to represent the model parameters

@dataclass
class ModelArgs:
    """Hyperparameters for the Llama-style Transformer built in this notebook.

    Field names (including the ``ffn_dim_multiplyer`` spelling) are kept
    unchanged so existing keyword-argument callers keep working.
    """

    dim: int = 4096                   # model / token-embedding dimension
    n_layers: int = 32                # number of EncoderBlocks stacked by Transformer
    n_heads: int = 32                 # number of query heads; head_dim = dim // n_heads
    n_kv_heads: Optional[int] = None  # number of key/value heads (None presumably means "same as n_heads" -- confirm in the attention block)
    vocab_size: int = -1              # sentinel: must be set (e.g. from the tokenizer) before building Transformer
    multiple_of: int = 256            # with ffn_dim_multiplyer, used to size the feed-forward network
    ffn_dim_multiplyer: Optional[float] = None
    norm_eps: float = 1e-5            # epsilon passed to RMSNorm
    max_batch_size: int = 32
    max_seq_len: int = 2048           # RoPE frequencies are precomputed for 2 * max_seq_len positions
    device: Optional[str] = None      # device for the precomputed RoPE table; None -> PyTorch's default device

In [8]:
# torch.arange(0, head_dim, 2).float()
# generates the even numbers [0, 2, 4, 6, 8, ..., head_dim - 2];
# the end value head_dim itself is excluded.

# Consider the formula
# 2(i-1) for i = [1, 2, 3, 4, 5, ..., dim/2]
# It generates [(1-1)*2, (2-1)*2, (3-1)*2, ..., (dim/2 - 1)*2],
# which is equivalent to [0, 2, 4, ..., dim-2].
# This is exactly what torch.arange(0, dim, 2).float() produces.

In [9]:
def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0):
    """Build the rotary-position-embedding (RoPE) rotation table.

    For each position m in [0, seq_len) and each pair index i in
    [0, head_dim / 2), the table holds the unit complex number
    exp(1j * m * theta_i), where theta_i = theta ** (-2i / head_dim).

    Args:
        head_dim: size of one attention head; must be even because the
            embedding is rotated two components at a time.
        seq_len: number of positions to precompute.
        device: device the returned table is placed on.
        theta: RoPE base (default 10000.0).

    Returns:
        Complex tensor of shape (seq_len, head_dim // 2).

    Raises:
        AssertionError: if head_dim is odd.
    """
    assert head_dim % 2 == 0, "We can apply rotary position embeddings only to even dim token embeddings"

    # Exponent numerators 2(i-1) for i = 1 .. head_dim/2, i.e. 0, 2, ..., head_dim - 2.
    # shape -> (head_dim / 2)
    even_indices = torch.arange(0, head_dim, 2).float()

    # Per-pair frequency theta_i = base ** (-2(i-1) / head_dim).
    # The power is computed on the default device, then moved before inverting.
    # shape -> (head_dim / 2)
    inv_freq = 1.0 / (theta ** (even_indices / head_dim)).to(device)

    # Absolute positions m = 0 .. seq_len - 1.
    # shape -> (seq_len)
    positions = torch.arange(seq_len, device=device)

    # Rotation angle m * theta_i for every (position, pair) combination.
    # shape -> (seq_len, head_dim / 2)
    angles = torch.outer(positions, inv_freq)

    # Polar form with magnitude 1: exp(1j * angle).
    return torch.polar(torch.ones_like(angles), angles)

In [None]:
# The class below represents the full Llama model: token embeddings, the stack of encoder blocks, the final RMSNorm and the output projection -- everything except the final softmax

class Transformer(nn.Module):
    """Llama-style model: embeddings -> N EncoderBlocks -> RMSNorm -> vocab projection.

    Inference-only forward pass that consumes one token per call
    (KV-cache decoding); the final softmax is left to the caller.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()

        assert args.vocab_size != -1, 'we confirm that the vocab size is set'

        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)

        # One EncoderBlock per layer, registered as submodules layers.0 .. layers.{n_layers-1}.
        self.layers = nn.ModuleList([EncoderBlock(args) for _ in range(self.n_layers)])

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, self.vocab_size, bias=False)

        # RoPE table for 2 * max_seq_len positions with head_dim = dim // n_heads.
        # It is a plain attribute (not a parameter/buffer): it is not in the
        # state_dict and is NOT moved by model.to()/.cuda(), so forward()
        # moves the slice it needs onto the activations' device.
        self.freqs_complex = precompute_theta_pos_frequencies(
            self.args.dim // self.args.n_heads,
            self.args.max_seq_len * 2,
            device=self.args.device,
        )

    def forward(self, tokens: torch.Tensor, start_pos: int):
        """Run one decoding step.

        Args:
            tokens: integer token ids of shape (B, Seq_Len); Seq_Len must be 1.
            start_pos: absolute position of the token, used to select its
                RoPE frequencies (and, presumably, the KV-cache slot inside
                each EncoderBlock -- confirm there).

        Returns:
            float32 logits from the output projection, presumably of shape
            (B, Seq_Len, vocab_size) if every EncoderBlock preserves
            (B, Seq_Len, dim).
        """
        # (B, Seq_Len)
        _, seq_len = tokens.shape

        assert seq_len == 1, "Only one token is passes to the model as we implement the KV cache"

        n_positions = self.freqs_complex.shape[0]
        # Slicing past the precomputed table would silently return an empty tensor.
        assert 0 <= start_pos and start_pos + seq_len <= n_positions, (
            f"start_pos={start_pos} is outside the precomputed RoPE range [0, {n_positions})"
        )

        # (B, Seq_Len) -> (B, Seq_Len, dim)
        h = self.tok_embeddings(tokens)

        # Retrieve the (m, theta) rotations for positions [start_pos, start_pos + seq_len)
        # and put them on the same device as the activations (no-op if already there).
        freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len].to(h.device)

        # Consecutively apply all the encoder layers
        for layer in self.layers:
            h = layer(h, start_pos, freqs_complex)

        # Normalize the output of the last encoder block
        h = self.norm(h)

        # Project to vocabulary logits and cast to float32
        output = self.output(h).float()

        return output


