In [2]:
import torch
import torch.nn as nn
import math
from typing import Tuple, Optional

In [3]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale and shift.

    Each vector along the last axis is centred by its mean and divided by
    ``std + eps``; the result is scaled by ``alpha`` and shifted by ``bias``.

    Note that ``eps`` is added to the standard deviation (not to the variance)
    and that ``Tensor.std`` is called with its default estimator.

    Args:
        features: size of the last dimension of the inputs.
        eps: small constant added to the standard deviation to avoid dividing by zero.
    """

    def __init__(self, features: int, eps: float = 1e-5):
        super().__init__()
        self.features = features
        self.eps = eps
        # Learnable gain, initialised to 1 so the layer starts as a pure normalisation.
        self.alpha = nn.Parameter(torch.ones(features))
        # Learnable shift, initialised to 0. Must be nn.Parameter so it is registered
        # with the module and returned by .parameters().
        self.bias = nn.Parameter(torch.zeros(features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` along its last dimension and apply the affine transform."""
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.eps) + self.bias

In [None]:

def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0):
    """Build the table of unit complex numbers used for rotary position embeddings.

    For every position ``m`` in ``[0, seq_len)`` and every even index ``2i`` in
    ``[0, head_dim)`` the table holds ``exp(1j * m * theta ** (-2i / head_dim))``.

    Args:
        head_dim: per-head embedding size; must be even.
        seq_len: number of positions to generate.
        device: device on which the resulting tensor is placed.
        theta: base of the geometric progression of frequencies.

    Returns:
        Complex tensor of shape ``(seq_len, head_dim / 2)`` with magnitude 1.

    Raises:
        AssertionError: if ``head_dim`` is odd.
    """
    assert head_dim % 2 == 0, "Dimension must be divisible by 2"

    # Exponents 0/d, 2/d, 4/d, ... -> one entry per pair of dimensions: (head_dim / 2,)
    pair_exponents = torch.arange(0, head_dim, 2).float() / head_dim
    base_powers = (theta ** pair_exponents).to(device)
    inverse_frequencies = 1.0 / base_powers

    # Positions "m": (seq_len,)
    positions = torch.arange(seq_len, device=device)

    # Outer product gives the angle m * theta_i for every (position, pair):
    # (seq_len, head_dim / 2)
    angles = torch.outer(positions, inverse_frequencies).float()

    # Polar form with magnitude 1: cos(angle) + 1j * sin(angle)
    unit_magnitudes = torch.ones_like(angles)
    return torch.polar(unit_magnitudes, angles)

def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device):
    """Rotate consecutive pairs of the last dimension of ``x`` by the given complex factors.

    Args:
        x: tensor whose last dimension is even; the shape comments below follow the
            (batch, seq_len, H, Head_Dim) layout the rest of this code uses.
        freqs_complex: complex rotation factors of shape (seq_len, Head_Dim / 2),
            already sliced to the positions that ``x`` covers.
        device: device the result is moved to.

    Returns:
        Tensor with the same shape and dtype as ``x``, placed on ``device``.
    """
    # Group every 2 consecutive values into one complex number.
    # (batch, seq_len, H, Head_Dim) -> (batch, seq_len, H, Head_Dim / 2)
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # Add batch and head axes so the factors broadcast against x_complex.
    # (seq_len, Head_Dim / 2) -> (1, seq_len, 1, Head_Dim / 2)
    freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)
    # Complex multiplication performs the rotation.
    # (batch, seq_len, H, Head_Dim / 2)
    x_rotated = x_complex * freqs_complex
    # Split each complex number back into its (real, imag) pair.
    # view_as_real is a torch-level function, not a Tensor method.
    # (batch, seq_len, H, Head_Dim / 2) -> (batch, seq_len, H, Head_Dim / 2, 2)
    x_out = torch.view_as_real(x_rotated)
    # Flatten the pairs back into the original layout.
    x_out = x_out.reshape(*x.shape)
    return x_out.type_as(x).to(device)

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    Args:
        dim: size of the last dimension of the inputs.
        eps: small constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # gamma: per-feature gain, initialised to 1
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Divide ``x`` by the RMS of its last dimension."""
        # x: (batch, seq_len, Dim); the mean keeps the last axis so it broadcasts back.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise in float32, cast back to the dtype of ``x``, then apply the gain."""
        return self.weight * self._norm(x.float()).type_as(x)
        


# represents the entire model besides softmax
class Transformer(nn.Module):
    """Token embedding, a stack of ``EncoderBlock`` layers, a final ``RMSNorm`` and a
    bias-free linear projection to vocabulary size. No softmax is applied here.

    Args:
        args: model configuration. The fields read by this class are ``vocab_size``,
            ``n_layers``, ``dim``, ``norm_eps``, ``n_heads``, ``max_seq_len`` and
            ``device``; ``args`` is also passed unchanged to every ``EncoderBlock``.

    Raises:
        AssertionError: if ``args.vocab_size`` is still ``-1``.
    """

    # The annotation is quoted so the class can be defined even when ModelArgs is
    # declared later in the notebook (forward reference).
    def __init__(self, args: "ModelArgs"):
        super().__init__()
        assert args.vocab_size != -1, "Need to set vocab size"
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers  # the Nx
        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)
        self.layers = nn.ModuleList()
        for _ in range(args.n_layers):
            self.layers.append(EncoderBlock(args))

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, self.vocab_size, bias=False)

        # Precompute the rotary positional-encoding factors once, for twice the
        # maximum sequence length and one entry per pair of head dimensions.
        self.freqs_complex = precompute_theta_pos_frequencies(
            self.args.dim // self.args.n_heads,
            self.args.max_seq_len * 2,
            device=self.args.device,
        )

    def forward(self, token: torch.Tensor, start_pos: int):
        """Run one decoding step.

        Args:
            token: integer token ids of shape (batch, seq_len); ``seq_len`` must be 1.
            start_pos: position of this token in the sequence, used to select the
                matching rotary factors and forwarded to every layer.

        Returns:
            The float32 output of the final linear projection.

        Raises:
            AssertionError: if more than one token per sequence is passed.
        """
        # (batch, seq_len)
        batch_size, seq_len = token.shape
        assert seq_len == 1, "Only 1 token at a time can be processed"

        # (batch, seq_len) -> (batch, seq_len, dim)
        h = self.tok_embeddings(token)

        # Retrieve the (m, theta) pairs for positions [start_pos, start_pos + seq_len)
        freqs_complex = self.freqs_complex[start_pos: start_pos + seq_len]

        # Consecutively apply all the encoder layers
        for layer in self.layers:
            h = layer(h, start_pos, freqs_complex)
        h = self.norm(h)
        # .float() must be called; the bare attribute would return a bound method.
        output = self.output(h).float()
        return output
