In [None]:
'''
Key concepts:
* RMS Norm
* KV-cache
* Rotary Positional embedding
* Grouped query attention
* SwiGLU

Notes:
- Layer normalization is done primarily to deal with `internal covariate shift`: drastic changes in the distribution of a layer's inputs caused by the parameter updates SGD makes to earlier layers. Such shifts slow down training.
- LayerNorm: a unique (mean, variance) pair for each sample.
- BatchNorm: a unique (mean, variance) pair for each feature.
- RMSNorm: Hypothesizes that the re-scaling, not the re-centering, is mostly responsible for the success of LayerNorm. The re-centering is thus dropped and the mean doesn't have to be calculated: RMS(x) = sqrt(mean(x_i**2)).
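- Rotary positional embedding (RoPE): encodes the absolute position m by rotating each pair of features of q and k by the angle m * theta_i, so that the product q.k depends only on the relative offset between the two positions.
    - They add no learnable parameters and are cheap to compute (the rotation angles can be precomputed once).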
    - They show better invariance to permutations.
    - Better generalization.
    - Easy to implement.
- Rotary positional embedding is only applied to the query and keys.
- It is somewhere in between the absolute and relative positional encodings
- Rotary positional embedding is only applied after Q and K are multiplied by W.
- KV-cache is an important concept that can be applied to all Transformer models (only during INFERENCE)
- The KV-cache stores the K and V vectors computed at previous steps, so at each step only the newest token's Q (a single vector) is needed.
  Instead of computing the full (N,F)x(N,F).T self-attention product at every step, we only compute
  (1,F)x(N,F).T, where (1,F) is the shape of Q. This skips re-calculating the entire self-attention matrix at every step.
- In GPUs, memory transfer is really expensive and about 20x slower than matrix multiplication operations.
- Total operations -> O(bnd**2)
- Total memory accesses -> O(bnd + bhn**2 + d**2)
- Memory access is not the bottleneck here, since (total memory accesses)/(total arithmetic operations) << 1.
- Multi-query attention (MQA): only the queries keep multiple heads; all N query heads share a single key head and a single value head.
- GROUPED multi-query attention (GQA): we reduce the number of K and V heads but don't collapse them to a single one; each group of query heads shares one K/V head. A good compromise between the speed of MQA and the quality of full multi-head attention.
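- Swish_beta(x) = x * (1/(1+exp(-beta * x))) (beta = 1 gives SiLU). SwiGLU(x) = Swish(x W) * (x V) (elementwise product), followed by a down-projection W2. Why it works so well is unexplained: the GLU-variants paper jokingly attributes its success to "divine benevolence".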
'''

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import dataclasses as dataclass
from typing import Optional

In [8]:
# NOTE: the imports cell does `import dataclasses as dataclass`, which binds the name
# `dataclass` to the *module*, so a bare `@dataclass` decorator fails. Import the module
# under its own name and use its decorator explicitly.
import dataclasses


@dataclasses.dataclass
class ModelArgs:
    """Hyper-parameters shared by every module in this notebook.

    Every field has a default, so ``ModelArgs()`` gives the full-size configuration and
    individual fields can be overridden by keyword, e.g. ``ModelArgs(dim=64, n_heads=4)``.
    """

    dim: int = 4096  # model (embedding) width
    n_layers: int = 32  # number of transformer blocks
    n_heads: int = 32  # number of query heads
    n_kv_heads: Optional[int] = None  # key/value heads for grouped-query attention; None means n_heads
    vocab_size: int = 256
    multiple_of: int = 256  # presumably the rounding granularity of the FFN hidden size (as in LLaMA) - TODO confirm
    ffm_dim_multiplier: Optional[float] = None  # presumably LLaMA's `ffn_dim_multiplier`; name kept for compatibility
    norm_eps: float = 1e-5  # epsilon added inside the RMSNorm square root

    # Needed for KV cache: the cache is pre-allocated for at most this many sequences / positions.
    max_batch_size: int = 32
    max_seq_len: int = 2048

    # Fall back to the CPU when no GPU is available so the notebook still runs.
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'


# Instantiate the ModelArgs class with the default (full-size) configuration
args = ModelArgs()

In [13]:
def precompute_theta_positional_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0):
    """Precompute the RoPE rotations exp(i * m * theta_i) for every position and feature pair.

    Args:
        head_dim: size of one attention head; must be even because features are rotated in pairs.
        seq_len: number of positions m = 0 .. seq_len - 1 to precompute.
        device: device on which the result is created.
        theta: base of the frequency schedule (10000 in the RoPE paper).

    Returns:
        Complex tensor of shape (seq_len, head_dim // 2) with unit magnitude and angle m * theta_i.

    Raises:
        ValueError: if ``head_dim`` is odd.
    """
    if head_dim % 2 != 0:
        raise ValueError(f"head_dim must be even for rotary embeddings, got {head_dim}")

    # theta_i = theta ** (-2(i-1)/head_dim) for i = 1 .. head_dim/2.
    # arange(0, head_dim, 2) yields the numerators 2(i-1) (exactly head_dim/2 of them); divide in
    # floating point - integer `//` would make every exponent 0 and every frequency 1.
    exponents = torch.arange(0, head_dim, 2, device=device).float() / head_dim
    inv_freq = 1.0 / (theta ** exponents)

    # Positions m = 0 .. seq_len-1; the outer product gives every angle m * theta_i.
    # (seq_len) outer (head_dim/2) -> (seq_len, head_dim/2)
    m = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(m, inv_freq)

    # According to the efficient formula in the paper, write each rotation in polar form:
    # c = R * exp(i * m * theta_i) with R = 1.
    return torch.polar(torch.ones_like(freqs), freqs)

def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
    """Rotate consecutive feature pairs of ``x`` by the precomputed RoPE angles.

    Each pair (x[..., 2j], x[..., 2j+1]) is read as the complex number x_2j + i*x_2j+1
    and multiplied by the matching entry of ``freqs_complex``. ``freqs_complex`` is
    unsqueezed at dims 0 and 2 so that a 2-D (seq_len, head_dim/2) table broadcasts
    over a (batch, seq_len, heads, head_dim/2) complex view of ``x``.

    Returns a tensor with the same shape and dtype as ``x``, moved to ``device``.
    """
    # Work in float32 and reinterpret the last dimension as head_dim/2 complex numbers.
    paired = x.float().reshape(*x.shape[:-1], -1, 2)
    x_as_complex = torch.view_as_complex(paired)

    # Insert broadcast axes for batch (dim 0) and heads (dim 2).
    rotations = freqs_complex.unsqueeze(0).unsqueeze(2)

    # Multiplying by a unit-magnitude complex number rotates the pair in its 2-D plane.
    rotated_pairs = torch.view_as_real(x_as_complex * rotations)
    return rotated_pairs.reshape(*x.shape).type_as(x).to(device)
    
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension (no mean re-centering).

    y = weight * x / sqrt(mean(x**2, dim=-1) + eps)
    """

    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Learnable per-feature gain, initialised to 1 (identity scaling).
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # Divide by the RMS of the last dimension. `keepdim=True` keeps the reduced axis so the
        # result broadcasts back against x; eps guards against division by zero.
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for numerical stability, cast back to the input dtype,
        # then apply the learned gain.
        return self.weight * self._norm(x.float()).type_as(x)
        
class SelfAttention(nn.Module):
    """Grouped-query self-attention with rotary embeddings and a KV-cache (inference only).

    The module is stateful: each call writes the new keys/values for positions
    ``start_pos : start_pos + seq_len`` into ``cache_k`` / ``cache_v``. It then attends from the
    new queries over every cached position ``0 : start_pos + seq_len``.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.device = args.device
        self.n_heads = self.args.n_heads
        # Without an explicit KV head count this is plain multi-head attention.
        self.n_kv_heads = self.n_heads if self.args.n_kv_heads is None else self.args.n_kv_heads
        self.n_q_heads = self.n_heads
        if self.n_q_heads % self.n_kv_heads != 0:
            raise ValueError(f"n_heads ({self.n_q_heads}) must be a multiple of n_kv_heads ({self.n_kv_heads})")
        if self.args.dim % self.n_q_heads != 0:
            raise ValueError(f"dim ({self.args.dim}) must be divisible by n_heads ({self.n_q_heads})")
        # Number of query heads that share each key/value head.
        self.n_rep = self.n_q_heads // self.n_kv_heads
        # The model dimension is split evenly across the query heads.
        self.head_dim = self.args.dim // self.n_q_heads

        # Bias-free projections, as in LLaMA.
        self.to_q = nn.Linear(self.args.dim, self.n_q_heads * self.head_dim, bias=False)
        self.to_k = nn.Linear(self.args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.to_v = nn.Linear(self.args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.to_o = nn.Linear(self.n_q_heads * self.head_dim, self.args.dim, bias=False)

        # One slot per (sequence, position, kv head) holding a head_dim vector. These are
        # non-persistent buffers: they move with .to(...) but are not stored in state_dict.
        cache_shape = (self.args.max_batch_size, self.args.max_seq_len, self.n_kv_heads, self.head_dim)
        self.register_buffer("cache_k", torch.zeros(cache_shape, device=self.device), persistent=False)
        self.register_buffer("cache_v", torch.zeros(cache_shape, device=self.device), persistent=False)

    def repeat_heads(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """Repeat each KV head n_rep times: (b, s, n_kv_heads, d) -> (b, s, n_kv_heads * n_rep, d).

        Copies of the same KV head are adjacent, so query head h is paired with KV head h // n_rep.
        """
        b, s, n_kv_heads, head_dim = x.shape
        if n_rep == 1:
            return x
        return (
            x[:, :, :, None, :]
            .expand(b, s, n_kv_heads, n_rep, head_dim)
            .reshape(b, s, n_kv_heads * n_rep, head_dim)
        )

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor):
        bs, s, _ = x.shape
        end_pos = start_pos + s
        if bs > self.args.max_batch_size or end_pos > self.args.max_seq_len:
            raise ValueError(
                f"KV-cache overflow: batch {bs} (max {self.args.max_batch_size}), "
                f"positions up to {end_pos} (max {self.args.max_seq_len})"
            )

        # (bs, s, dim) -> (bs, s, heads, head_dim)
        q = self.to_q(x).view(bs, s, self.n_q_heads, self.head_dim)
        k = self.to_k(x).view(bs, s, self.n_kv_heads, self.head_dim)
        v = self.to_v(x).view(bs, s, self.n_kv_heads, self.head_dim)

        # Rotary embeddings are applied to queries and keys only, never to values.
        q = apply_rotary_embeddings(q, freqs_complex, self.device)
        k = apply_rotary_embeddings(k, freqs_complex, self.device)

        # Store this step's keys/values, then read back everything seen so far.
        self.cache_k[:bs, start_pos:end_pos] = k
        self.cache_v[:bs, start_pos:end_pos] = v
        keys = self.cache_k[:bs, :end_pos]
        values = self.cache_v[:bs, :end_pos]

        # This is the grouped query attention implementation: share each KV head across n_rep query heads.
        keys = self.repeat_heads(keys, self.n_rep)
        values = self.repeat_heads(values, self.n_rep)

        # (bs, length, heads, head_dim) -> (bs, heads, length, head_dim)
        query = q.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Scaled dot-product scores: (bs, heads, s, end_pos)
        scores = torch.matmul(query, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if s > 1:
            # Query i sits at absolute position start_pos + i and may only see positions <= start_pos + i.
            # A single new token may see the whole cache, so no mask is needed in that case.
            mask = torch.full((s, end_pos), float("-inf"), device=scores.device)
            scores = scores + torch.triu(mask, diagonal=start_pos + 1)
        attention = F.softmax(scores.float(), dim=-1).type_as(query)
        output = torch.matmul(attention, values)  # (bs, heads, s, head_dim)

        # (bs, heads, s, head_dim) -> (bs, s, heads * head_dim); transpose breaks contiguity.
        output = output.transpose(1, 2).contiguous().view(bs, s, -1)
        return self.to_o(output)
        
        
        
class FeedForward(nn.Module):
    """SwiGLU feed-forward network: w2(silu(w1 x) * w3 x), the LLaMA formulation.

    Required by EncoderBlock below.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        # Start from 4 * dim and shrink by 2/3, because SwiGLU uses three weight matrices instead of two.
        # Then optionally scale and round up to a multiple of `multiple_of`.
        # For dim=4096 and multiple_of=256 this gives 11008.
        hidden_dim = int(2 * (4 * args.dim) / 3)
        if args.ffm_dim_multiplier is not None:
            hidden_dim = int(args.ffm_dim_multiplier * hidden_dim)
        hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)

        self.w1 = nn.Linear(args.dim, hidden_dim, bias=False)  # gate branch (goes through SiLU)
        self.w3 = nn.Linear(args.dim, hidden_dim, bias=False)  # linear branch
        self.w2 = nn.Linear(hidden_dim, args.dim, bias=False)  # down-projection back to dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class EncoderBlock(nn.Module):
    """One pre-norm transformer layer.

    Computes h = x + Attention(RMSNorm(x)), then out = h + FeedForward(RMSNorm(h)).
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // self.n_heads

        self.attention = SelfAttention(args)
        self.feed_forward = FeedForward(args)

        # Normalization BEFORE self-attention and BEFORE the feed-forward network (pre-norm)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor) -> torch.Tensor:
        # Residual connection around attention; the norm is applied to the block input.
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_complex)
        # Residual connection around the feed-forward network.
        return h + self.feed_forward(self.ffn_norm(h))
    
    
    

In [ ]:
class Transformer(nn.Module):
    """LLaMA-style decoder stack used for token-by-token inference with a KV-cache.

    ``forward(tokens, start_pos)`` takes one new token per sequence, shape (batch_size, 1). It
    returns float32 logits of shape (batch_size, 1, vocab_size) for that position.
    """

    def __init__(self, args: ModelArgs) -> None:
        super().__init__()

        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        self.tok_embeddings = nn.Embedding(self.vocab_size, self.args.dim)

        self.layers = nn.ModuleList()
        for _ in range(self.n_layers):
            self.layers.append(EncoderBlock(self.args))

        self.norm = RMSNorm(self.args.dim, eps=self.args.norm_eps)
        self.output = nn.Linear(args.dim, self.vocab_size, bias=False)

        # Rotary rotations for positions [0, 2 * max_seq_len), one per feature pair of a head.
        self.freqs_complex = precompute_theta_positional_frequencies(
            self.args.dim // self.args.n_heads, self.args.max_seq_len * 2, device=self.args.device
        )

    def forward(self, tokens: torch.Tensor, start_pos: int) -> torch.Tensor:
        # [batch_size, seq_len]
        batch_size, seq_len = tokens.shape
        # The KV-cache path processes exactly one new token per call.
        assert seq_len == 1, "Only one token at a time can be processed"
        # The per-layer KV-cache is pre-allocated; fail with a clear message instead of a shape error.
        if batch_size > self.args.max_batch_size:
            raise ValueError(f"batch_size {batch_size} exceeds max_batch_size {self.args.max_batch_size}")
        if start_pos + seq_len > self.args.max_seq_len:
            raise ValueError(f"position {start_pos + seq_len} exceeds max_seq_len {self.args.max_seq_len}")

        # [batch_size, seq_len, dim]
        h = self.tok_embeddings(tokens)

        # Retrieve the rotations (m * theta pairs) for the positions being processed
        freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len]

        for layer in self.layers:
            h = layer(h, start_pos, freqs_complex)

        h = self.norm(h)
        # [batch_size, seq_len, vocab_size]
        return self.output(h).float()
                   
        

In [12]:
# Quick check of `torch.outer` (used by precompute_theta_positional_frequencies):
# result[i, j] == a[i] * b[j], so a (2,) and a (4,) vector give a (2, 4) matrix.
a = torch.tensor([1.0, 2.0])
b = torch.tensor([1.0, 2.0, 3.0, 4.0])

torch.outer(a, b)

tensor([[1., 2., 3., 4.],
        [2., 4., 6., 8.]])