In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from tqdm import tqdm
from dataclasses import dataclass
from typing import Optional, Tuple
from safetensors.torch import load_file

In [None]:
# Example model size used for the memory estimate below.
params = 7_000_000_000  # number of parameters
quant = 16  # bits per parameter (float16)

def get_mem_requirements(params, quant):
    """
    Rough estimate of the memory needed to run inference, as a "<number> GB" string.

    The estimate starts from 4 bytes per parameter (float32), shrinks it by the
    factor 32/quant (e.g. 32/16 = 2 for float16), and adds a 1.2x overhead factor.
    Rule of thumb for training: ~ (params * 4 bytes) * 4.

    params: number of model parameters.
    quant:  bits used to store one parameter.
    """
    fp32_bytes = params * 4
    shrink_factor = 32 / quant
    quantized_bytes = fp32_bytes / shrink_factor
    with_overhead = quantized_bytes * 1.2
    gigabytes = with_overhead * 1e-9
    return f"{gigabytes} GB"

# Display the estimate for the example model configured above.
get_mem_requirements(params=params, quant=quant)

In [None]:
# Default torch device string; used as the default for Args.device below.
DEVICE = "cpu"


@dataclass
class Args:
    """Hyper-parameters and runtime limits used to build the Transformer."""
    # Model (embedding) width.
    dim: int = 4096
    # Number of stacked transformer blocks.
    n_layers: int = 32
    # Number of query heads; the per-head size is dim // n_heads.
    n_heads: int = 32
    # Number of key/value heads; SelfAttention falls back to n_heads when this is None.
    n_kv_heads: Optional[int] = None
    # Rows of the token embedding table and width of the output projection.
    vocab_size: int = 32000 # -1
    # FFN rounds its hidden width up to a multiple of this value.
    multiple_of: int = 256
    # Intended as an optional scale on the FFN hidden width (see FFN's ffn_dim_multiplier).
    # NOTE(review): annotated int, but FFN applies int(multiplier * hidden_dim),
    # so fractional values would work as well — confirm the intended type.
    ffn_dim_multiplier: Optional[int] = None
    # Epsilon handed to RMSNorm for numerical stability.
    norm_eps: float = 1e-5
    # Leading (batch) dimension of the pre-allocated KV caches.
    batch_size: int = 32
    # Sequence dimension of the KV caches; the RoPE table is built for 2 * seq_len positions.
    seq_len: int = 2048
    # Torch device string on which precomputed tensors are created.
    device: str = DEVICE


# Default configuration instance used to build the model further down.
args = Args()

In [None]:
def precomputed_freqs(dim: int, seqlen: int, device: str, theta: float = 10000.0):
    """
    Build the table of complex rotary-embedding (RoPE) factors.

    dim:    size of one attention head; must be even because values are rotated in pairs.
    seqlen: number of positions to tabulate.
    device: torch device on which the table is created.
    theta:  base of the geometric frequency progression.

    Returns a complex tensor of shape (seqlen, dim // 2) whose entry [m, i]
    is exp(1j * m * theta ** (-2i / dim)).
    """
    assert dim % 2 == 0, "Dim must be even"

    # One inverse frequency per pair of channels.
    pair_exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim
    inv_freq = 1.0 / (theta ** pair_exponents)

    # Rotation angle for every (position, frequency) combination.
    positions = torch.arange(seqlen, device=device).float()
    angles = torch.outer(positions, inv_freq)

    # Unit-magnitude complex numbers carrying those angles.
    return torch.polar(torch.ones_like(angles), angles)


def rope(x: torch.Tensor, freqs: torch.Tensor):
    """
    Apply rotary position embeddings to `x`.

    x:     real tensor of shape (B, T, n_heads, head_size) with an even head_size;
           consecutive pairs along the last axis are treated as (real, imag).
    freqs: complex tensor of shape (T, head_size // 2), e.g. a slice of the
           table returned by `precomputed_freqs`.

    Returns a tensor with the same shape, dtype and device as `x`.
    """
    # view_as_complex does not accept bfloat16 and complex-half support is
    # limited, so low-precision inputs are rotated in float32 and cast back.
    # float32/float64 inputs are left untouched.
    x_real = x.float() if x.dtype in (torch.float16, torch.bfloat16) else x
    x_cmplx = torch.view_as_complex(x_real.reshape(*x.shape[:-1], -1, 2))

    # (T, hs/2) -> (1, T, 1, hs/2) so it broadcasts over batch and heads.
    freqs = freqs.unsqueeze(0).unsqueeze(2)

    x_rotated = torch.view_as_real(x_cmplx * freqs).reshape(*x.shape)
    return x_rotated.to(dtype=x.dtype)


def rep_tensor(x: torch.Tensor, n_rep: int):
    """
    Repeat every slice along axis 2 of `x` `n_rep` times, back to back.

    Used to expand key/value heads so their count matches the query heads.
    When n_rep is 1 the input tensor itself is returned (no copy).
    """
    if n_rep != 1:
        return x.repeat_interleave(n_rep, dim=2)
    return x


class RMSNorm(nn.Module):
    """
    Root-mean-square layer normalisation with a learnable per-channel gain.

    dim: size of the last (normalised) axis.
    eps: small constant added to the mean square before the reciprocal sqrt.
    """
    def __init__(self, dim: int, eps: float=1e-6):
        super().__init__()
        self.w = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        """Scale `x` by the reciprocal RMS of its last axis, then by the gain `w`."""
        # Squaring in float16 overflows for |x| > ~255 (the mean becomes inf and
        # the output collapses to 0), so low-precision inputs are normalised in
        # float32 and cast back. float32/float64 inputs take the original path.
        x_f = x.float() if x.dtype in (torch.float16, torch.bfloat16) else x
        normed = x_f * torch.rsqrt(x_f.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed.type_as(x) * self.w
        

class SelfAttention(nn.Module):
    """
    Multi-head self-attention with rotary embeddings, optional grouped
    key/value heads, and a pre-allocated key/value cache.

    The cache holds args.batch_size x args.seq_len positions. Each forward call
    writes the new keys/values at [start_pos, start_pos + T) and attends over
    everything cached in [0, start_pos + T).
    """
    def __init__(self, args: Args):
        super().__init__()
        self.dim = args.dim
        self.n_q_heads = args.n_heads
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert self.n_q_heads % self.n_kv_heads == 0, "n_q_heads must be divisible by n_kv_heads"
        self.n_rep = self.n_q_heads // self.n_kv_heads
        self.h_size = args.dim // args.n_heads

        self.wq = nn.Linear(args.dim, self.n_q_heads * self.h_size, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.h_size, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.h_size, bias=False)
        self.wo = nn.Linear(args.dim, args.dim, bias=False)

        # Registered as non-persistent buffers so that .half()/.to(device) on the
        # module also converts the caches, while state_dict() keeps exactly the
        # same keys as before (the caches are not saved or loaded).
        cache_shape = (args.batch_size, args.seq_len, self.n_kv_heads, self.h_size)
        self.register_buffer("k_cache", torch.zeros(cache_shape, device=args.device), persistent=False)
        self.register_buffer("v_cache", torch.zeros(cache_shape, device=args.device), persistent=False)

    def forward(self, x: torch.Tensor, start_pos: int, freqs: torch.Tensor, mask: Optional[torch.Tensor]=None):
        """
        x:         (B, T, dim) input activations.
        start_pos: cache position at which the first of the T tokens is stored.
        freqs:     complex RoPE factors for these T positions, shape (T, h_size // 2).
        mask:      optional additive mask broadcastable to (B, n_q_heads, T, start_pos + T).

        Returns a (B, T, dim) tensor.
        """
        B, T, D = x.shape
        assert D == self.dim, f"x.shape[2] --> dim mismatch: {D} != {self.dim}"
        assert D == self.n_q_heads * self.h_size, f"x.shape[2] --> dim mismatch: {D} != {self.n_q_heads} * {self.h_size}"

        # Fail with a clear message instead of an opaque slice-assignment error.
        max_batch, max_len = self.k_cache.shape[:2]
        assert B <= max_batch, f"batch size {B} exceeds KV cache capacity {max_batch}"
        assert start_pos + T <= max_len, f"position {start_pos + T} exceeds KV cache length {max_len}"

        q = self.wq(x).view(B, T, self.n_q_heads, self.h_size)    # B, T, n_q_heads, hsize
        _k = self.wk(x).view(B, T, self.n_kv_heads, self.h_size)  # B, T, n_kv_heads, hsize
        _v = self.wv(x).view(B, T, self.n_kv_heads, self.h_size)  # B, T, n_kv_heads, hsize

        # apply RoPE, same shape
        q = rope(q, freqs)
        _k = rope(_k, freqs)

        # Keep the caches in the dtype/device of the activations (no-op when they
        # already match), e.g. when running under autocast.
        self.k_cache = self.k_cache.to(_k)
        self.v_cache = self.v_cache.to(_v)

        # kv - cache
        self.k_cache[:B, start_pos:start_pos + T] = _k
        self.v_cache[:B, start_pos:start_pos + T] = _v

        # B, start_pos + T, n_kv_heads, hsize
        k = self.k_cache[:B, :start_pos + T]
        v = self.v_cache[:B, :start_pos + T]

        # B, start_pos + T, n_kv_heads, hsize --> B, start_pos + T, n_rep * n_kv_heads, hsize
        # n_reps * n_kv_heads = n_q_heads
        k, v = rep_tensor(k, self.n_rep), rep_tensor(v, self.n_rep)

        q = q.transpose(1, 2) # B, n_q_heads, T, hsize
        k = k.transpose(1, 2) # B, n_q_heads, start_pos + T, hsize
        v = v.transpose(1, 2) # B, n_q_heads, start_pos + T, hsize

        # Scaled dot-product scores: divide by sqrt(head size).
        # B, n_q_heads, [(T , hsize) @ (hsize, start_pos + T)] --> B, n_q_heads, T, start_pos + T
        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.h_size)

        if mask is not None:
            attn = attn + mask

        # Softmax in float32 for numerical stability, then back to the working dtype.
        # B, n_q_heads, T, start_pos + T
        attn = F.softmax(attn.float(), dim=-1).type_as(q)

        # B, n_q_heads [(T, start_pos + T) @ (start_pos + T, hsize)]  --> B, n_q_heads, T, hsize
        attn = torch.matmul(attn, v)

        # B, nheads, T, hsize --> B, T, nheads * hsize
        attn = attn.transpose(1, 2).contiguous().view(B, T, self.n_q_heads * self.h_size)

        # B, T, D
        return self.wo(attn)


class FFN(nn.Module):
    """
    Gated feed-forward block: w2(silu(w1(x)) * w3(x)).

    dim:                input and output width.
    multiple_of:        the hidden width is rounded UP to a multiple of this value.
    ffn_dim_multiplier: optional scale (may be fractional) applied to the hidden
                        width before rounding.
    """
    def __init__(self, dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float]=None):
        super().__init__()

        # Start from 2/3 of 4 * dim.
        hidden_dim = int(2 * (4 * dim) / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)

        # Round hidden_dim up to the next multiple of `multiple_of` (ceil, not nearest).
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # output projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # value projection

    def forward(self, x: torch.Tensor):
        """Map (..., dim) to (..., dim) through the gated hidden layer."""
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))


class TransformerBlock(nn.Module):
    """One pre-norm decoder layer: attention and FFN, each inside a residual connection."""
    def __init__(self, args: Args):
        super().__init__()

        self.attention = SelfAttention(args)
        # Use the configured epsilon and FFN multiplier rather than silently
        # falling back to the RMSNorm/FFN defaults.
        self.attn_norm = RMSNorm(args.dim, args.norm_eps)
        self.ffn = FFN(args.dim, args.multiple_of, args.ffn_dim_multiplier)
        self.ffn_norm = RMSNorm(args.dim, args.norm_eps)

    def forward(self, x: torch.Tensor, start_pos: int, freqs: torch.Tensor, mask: Optional[torch.Tensor]=None):
        """
        x:         (B, T, dim) activations.
        start_pos: KV-cache position of the first token in `x`.
        freqs:     complex RoPE factors for the T positions.
        mask:      optional additive attention mask.
        """
        # freqs and mask go by keyword so they cannot be swapped positionally.
        x_ = x + self.attention(self.attn_norm(x), start_pos, freqs=freqs, mask=mask)
        return x_ + self.ffn(self.ffn_norm(x_))


class Transformer(nn.Module):
    """Token embedding, a stack of TransformerBlocks, a final RMSNorm and the vocabulary projection."""
    def __init__(self, args: Args):
        super().__init__()
        self.args = args
        self.n_layers = args.n_layers
        self.vocab_size = args.vocab_size
        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)

        self.layers = nn.ModuleList(TransformerBlock(args) for _ in range(self.n_layers))

        self.norm = RMSNorm(args.dim, args.norm_eps)
        self.out = nn.Linear(args.dim, self.vocab_size, bias=False)

        # Complex RoPE table for 2 * seq_len positions. Kept as a plain attribute
        # (not a buffer) so that .half()/.to(dtype) on the module leaves it complex.
        self.freqs = precomputed_freqs(args.dim // args.n_heads, args.seq_len * 2, args.device)

    def forward(self, x: torch.Tensor, start_pos: int):
        """
        x:         (B, T) integer token ids.
        start_pos: KV-cache position of the first token in `x`.

        Returns logits of shape (B, T, vocab_size).
        """
        _, T = x.shape
        x_embd = self.tok_embeddings(x) # B, T --> B, T, D

        # RoPE factors for the positions being processed, on the activation device.
        freqs = self.freqs[start_pos:start_pos + T].to(x_embd.device)

        mask = None
        if T > 1:
            # Causal mask over the T new tokens: -inf strictly above the diagonal ...
            mask = torch.full((T, T), float("-inf"), device=x.device)
            mask = torch.triu(mask, diagonal=1)
            # ... with zero columns in front so every new token sees all cached positions.
            mask = torch.hstack([torch.zeros((T, start_pos), device=x.device), mask]).type_as(x_embd)

        for layer in self.layers:
            # Keyword arguments: TransformerBlock.forward takes (x, start_pos, freqs, mask).
            x_embd = layer(x_embd, start_pos, freqs=freqs, mask=mask)

        x_embd = self.norm(x_embd)

        return self.out(x_embd)

In [None]:
# Build the model from `args`, cast its parameters to float16, and move it to the configured device.
model = Transformer(args).half().to(args.device)
# model_dict = model.state_dict()

In [None]:
# Sanity check: display the dtype of one attention weight after the .half() cast above.
model.layers[0].attention.wq.weight.dtype