In [None]:
import json
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple

import fire
import torch
from sentencepiece import SentencePieceProcessor
from torch import nn

In [None]:
@dataclass
class ModelArgs:
    """Hyperparameters for the transformer model.

    Field names presumably mirror the keys of the checkpoint's params JSON so
    the class can be built with ``ModelArgs(**params)`` — TODO confirm against
    the loader. Field meanings below are inferred from names; confirm against
    the code that consumes ``ModelArgs``.
    """

    dim: int
    n_layers: int
    head_dim: int
    hidden_dim: int
    n_heads: int
    # Number of key/value heads. This field was misspelled ``n_kv_hads``, which
    # made keyword construction with ``n_kv_heads=...`` fail; the old name is
    # still readable/writable through the alias property below.
    n_kv_heads: int
    sliding_window: int
    norm_eps: float
    vocab_size: int

    # Defaults to 0; presumably filled in by the caller before caches are
    # allocated — TODO confirm.
    max_batch_size: int = 0

    @property
    def n_kv_hads(self) -> int:
        """Deprecated misspelled alias for ``n_kv_heads``."""
        return self.n_kv_heads

    @n_kv_hads.setter
    def n_kv_hads(self, value: int) -> None:
        self.n_kv_heads = value

In [None]:
def repeat_kv(key: torch.Tensor, values: torch.Tensor, repeasts: int, dim: int = 2) -> Tuple[torch.Tensor, torch.Tensor]:
    """Repeat every key/value slice along ``dim`` ``repeasts`` times.

    Uses ``torch.repeat_interleave``, so copies are consecutive: for
    ``repeasts=2`` the slices ``[h0, h1]`` become ``[h0, h0, h1, h1]``.
    Presumably used so a smaller number of KV heads lines up with the query
    heads (grouped-query attention) — TODO confirm with the attention code.

    Args:
        key: key tensor; ``dim`` is presumably the heads axis of a
            ``(bsz, seq_len, n_kv_heads, head_dim)`` layout — TODO confirm.
        values: value tensor with the same layout as ``key``.
        repeasts: number of copies per slice (original spelling kept so
            keyword callers keep working).
        dim: axis to repeat along; defaults to 2 as in the original code.

    Returns:
        ``(keys, values)`` whose size along ``dim`` is multiplied by ``repeasts``.
    """
    # ``repeat_interleave``'s keyword is ``repeats`` (not ``repeasts``), and the
    # input parameter is ``key`` — the previous body referenced undefined ``keys``.
    keys = torch.repeat_interleave(key, repeats=repeasts, dim=dim)
    values = torch.repeat_interleave(values, repeats=repeasts, dim=dim)
    return keys, values

## Understanding rotary embeddings

In [None]:
def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    freqs_cis: complex - (seq_len, head_dim / 2)
    x: complex - (bsz, seq_len, head_dim / 2)
    """
    ndim = x.ndim
    assert 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), ( 
        freqs_cis.shape,
        (x.shape[1], x.shape[-1]),
    )
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

In [None]:
def apply_rotary_emb(
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors by the complex rotary factors ``freqs_cis``.

    The last axis of each input is split into consecutive pairs. Each pair is
    read as one complex number, multiplied by the matching entry of
    ``freqs_cis``, and converted back to interleaved real pairs.

    Args:
        xq: query tensor whose last axis (head_dim) has even length; axis 1 is
            treated as the sequence axis. Presumably
            ``(bsz, seq_len, n_heads, head_dim)`` — TODO confirm with caller.
        xk: key tensor with the same layout as ``xq``. It may differ from ``xq``
            on non-sequence axes (e.g. fewer KV heads) as long as it broadcasts
            against ``freqs_cis``.
        freqs_cis: complex tensor of shape ``(seq_len, head_dim // 2)``.

    Returns:
        ``(xq_out, xk_out)`` with the same shapes and dtypes as ``xq`` and ``xk``.
    """
    # Reshape each tensor using its OWN shape: keys can have a different head
    # count than queries, so reusing xq.shape for xk fails or scrambles data.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # Broadcast view: size 1 everywhere except the seq axis (1) and the last
    # axis, so it also broadcasts against xk_ when only the head count differs.
    freqs_cis = _reshape_for_broadcast(freqs_cis, xq_)

    # flatten(-2) merges the (head_dim // 2, 2) real pairs back into head_dim
    # for inputs of any rank (identical to flatten(3) for 4-D inputs).
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    # Keys must be rotated from xk_, not xq_.
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)

## Understanding attention