In [None]:
import torch 
from torch import nn
from dataclasses import dataclass 
from pathlib import Path 
import fire 
import json
from typing import Optional, Tuple, List 
from sentencepiec import SentencePieceProcessor

In [None]:
@dataclass
class ModelArgs:
    """Hyper-parameters consumed by the model components in this notebook.

    Field order is significant: the dataclass-generated ``__init__`` accepts
    the values positionally in the order they are declared below.
    """

    dim: int  # input/output feature size of the attention projections
    n_layers: int
    head_dim: int  # feature size of a single attention head
    hidden_dim: int
    n_heads: int  # number of query heads
    n_kv_heads: int  # number of key/value heads (read by Attention as `args.n_kv_heads`)
    sliding_window: int  # length of the rotating key/value cache in Attention
    norm_eps: float
    vocab_size: int

    # Sizes the first dimension of the key/value cache in Attention.
    # NOTE(review): the default of 0 yields an empty cache, so callers
    # presumably must set a real value before running Attention - confirm.
    max_batch_size: int = 0

    @property
    def n_kv_hads(self) -> int:
        """Read-only alias for the field's previous, misspelled name."""
        return self.n_kv_heads

In [None]:
def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int):
    """Repeat every key/value head `repeats` times along dim 2.

    Each slice along dim 2 is duplicated `repeats` times in place
    (``[h0, h0, ..., h1, h1, ...]``), so dim 2 grows by a factor of
    `repeats` while all other dimensions are unchanged.

    Args:
        keys: key tensor with at least 3 dimensions; dim 2 is the one repeated.
        values: value tensor, repeated the same way as `keys`.
        repeats: number of copies of each slice along dim 2.

    Returns:
        Tuple ``(keys, values)`` of the expanded tensors.
    """
    keys = torch.repeat_interleave(keys, repeats=repeats, dim=2)
    values = torch.repeat_interleave(values, repeats=repeats, dim=2)
    return keys, values

## Understanding rotary embeddings

In [None]:
def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    freqs_cis: complex - (seq_len, head_dim / 2)
    x: complex - (bsz, seq_len, head_dim / 2)
    """
    ndim = x.ndim
    assert 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), ( 
        freqs_cis.shape,
        (x.shape[1], x.shape[-1]),
    )
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

In [None]:
def apply_rotary_emb(
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys by the complex factors in `freqs_cis`.

    The last dimension of each input is grouped into consecutive pairs, each
    pair is treated as one complex number and multiplied by the matching
    entry of `freqs_cis`, and the result is flattened back to real values.

    Args:
        xq: query tensor, 4-D with an even-sized last dimension.
        xk: key tensor, 4-D with an even-sized last dimension. Dim 2 may
            differ from `xq`; `freqs_cis` is reshaped to size 1 there and
            broadcasts against both.
        freqs_cis: complex tensor of shape (xq.shape[1], xq.shape[-1] // 2).

    Returns:
        Tuple ``(xq_out, xk_out)`` with the shapes and dtypes of `xq` and `xk`.
    """
    # Each tensor is reshaped from its OWN shape: xk may have a different
    # number of heads than xq.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    freqs_cis = _reshape_for_broadcast(freqs_cis, xq_)

    # flatten(3) merges the trailing (pairs, 2) dims back into one feature dim.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

## Understanding attention

In [None]:
class Attention(nn.Module):
    """Multi-head attention with shared key/value heads and a rotating cache.

    Queries use `n_heads` heads while keys and values use `n_kv_heads` heads;
    each key/value head is repeated `n_heads // n_kv_heads` times before the
    attention product. Keys and values are also written into fixed-size
    caches of length `sliding_window`, indexed by position modulo the window.

    The caches are allocated on the CUDA device in float16 at construction
    time, so building this module requires a CUDA-capable environment.
    """

    def __init__(self, args: ModelArgs):
        """Create the projections and the key/value caches from `args`.

        Args:
            args: model hyper-parameters; reads `dim`, `head_dim`, `n_heads`,
                `n_kv_heads`, `sliding_window` and `max_batch_size`.
        """
        super().__init__()
        self.args = args

        self.n_heads: int = args.n_heads
        self.n_kv_heads: int = args.n_kv_heads

        # Number of query heads that share one key/value head.
        self.repeats = self.n_heads // self.n_kv_heads
        self.sliding_window = self.args.sliding_window

        # 1 / sqrt(head_dim), applied to the raw attention scores.
        self.scale = self.args.head_dim**-0.5

        # Instantiate the weight matrices for query, key, value and output.
        self.wq = nn.Linear(
            args.dim,  # e.g. 4096
            args.n_heads * args.head_dim,  # e.g. 32 * 128 = 4096
            bias=False
        )

        # Keys and values are projected to n_kv_heads heads (not n_heads),
        # matching the view() calls in forward().
        self.wk = nn.Linear(
            args.dim,
            args.n_kv_heads * args.head_dim,
            bias=False
        )

        self.wv = nn.Linear(
            args.dim,
            args.n_kv_heads * args.head_dim,
            bias=False
        )

        # Output projection back to the model dimension.
        self.wo = nn.Linear(
            args.n_heads * args.head_dim,
            args.dim,
            bias=False
        )

        # Both caches share one layout: (batch, window slot, kv head, head_dim).
        cache_shape = (
            args.max_batch_size,
            args.sliding_window,
            self.n_kv_heads,
            self.args.head_dim,
        )
        self.cache_k = torch.empty(cache_shape, dtype=torch.float16).cuda()
        self.cache_v = torch.empty(cache_shape, dtype=torch.float16).cuda()

    def forward(
        self, x: torch.Tensor, freqs_cis: torch.Tensor, position: torch.Tensor, mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Run attention over `x` and update the key/value caches.

        Args:
            x: input of shape (bsz, seqlen, dim).
            freqs_cis: rotary factors forwarded to `apply_rotary_emb`.
            position: 1-D tensor of token positions; its length selects
                between the multi-token and single-token code paths.
                Presumably one entry per token in `x` - verify against caller.
            mask: optional tensor added to the attention scores after being
                broadcast over the batch and head dimensions.

        Returns:
            Tensor of shape (bsz, seqlen, dim).
        """
        bsz, seqlen, _ = x.shape

        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_heads, self.args.head_dim)
        xk = xk.view(bsz, seqlen, self.n_kv_heads, self.args.head_dim)
        xv = xv.view(bsz, seqlen, self.n_kv_heads, self.args.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # The cache is a rotating buffer: only the last `sliding_window`
        # tokens are stored, each in slot (position % sliding_window).
        scatter_pos = (position[-self.sliding_window:] % self.sliding_window)[None, :, None, None]
        scatter_pos = scatter_pos.repeat(bsz, 1, self.n_kv_heads, self.args.head_dim)
        self.cache_k[:bsz].scatter_(dim=1, index=scatter_pos, src=xk[:, -self.sliding_window:])
        self.cache_v[:bsz].scatter_(dim=1, index=scatter_pos, src=xv[:, -self.sliding_window:])

        if position.shape[0] > 1:
            # Several tokens at once: attend over the freshly computed
            # keys/values rather than the cache.
            key, value = repeat_kv(xk, xv, self.repeats)
        else:
            # Single token: attend over the cache slots [0, cur_pos).
            cur_pos = position[-1].item() + 1
            key, value = repeat_kv(self.cache_k[:bsz, :cur_pos, ...], self.cache_v[:bsz, :cur_pos, ...], self.repeats)

        # Move the head dimension ahead of the sequence dimension.
        query = xq.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        scores = torch.matmul(query, key.transpose(2, 3)) * self.scale

        if mask is not None:
            scores += mask[None, None, ...]

        # Softmax is computed in float32, then cast back to the query dtype.
        scores = scores.float()
        scores = nn.functional.softmax(scores, dim=-1).type_as(query)
        output = torch.matmul(scores, value)
        # Merge the heads back into a single feature dimension.
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)