In [2]:
import torch
import torch.nn as nn 
import torch.nn.functional as F 

# Pick the first GPU when one is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not used by the cells shown here; tensors are created on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Make newly created floating-point tensors/parameters default to bfloat16.
# torch.set_default_tensor_type() is deprecated since PyTorch 2.1; set_default_dtype() is the
# replacement for the dtype part (the old BFloat16Tensor type also implied the CPU device,
# which is already the default).
torch.set_default_dtype(torch.bfloat16)

Llama Architecture Parameters (values match the Llama 3 8B configuration)

In [3]:
d_model = 4096          # model (embedding / residual-stream) width
ffnn_dim = 14336        # hidden width of the SwiGLU feed-forward block
n_encoders = 32         # number of stacked transformer blocks (Llama's blocks are decoder blocks)
n_heads = 32            # number of query heads
n_kv_heads = 8          # number of key/value heads, shared by groups of query heads (GQA)
vocab_size = 128256     # tokenizer vocabulary size
norm_eps = 1e-05        # RMSNorm epsilon
rope_theta = 500000.0   # RoPE base frequency
max_batch_size = 4
max_seq_length = 128

# The derived sizes below are only meaningful when these divisions are exact.
assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
assert n_heads % n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
n_kv_heads_rep = n_heads // n_kv_heads  # query heads per key/value head
head_dim = d_model // n_heads           # per-head feature size

RMSNorm and Rotary Positional Embeddings (RoPE)

In [7]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learnable gain.

    Computes ``x / sqrt(mean(x**2, dim=-1) + norm_eps) * weight``. The normalization
    is computed on ``x.float()`` and cast back to ``x``'s dtype before the gain is
    applied.

    Args:
        d_model: size of the last (feature) dimension; also the length of ``weight``.
        norm_eps: constant added to the mean of squares before the reciprocal square root.
    """

    def __init__(self, d_model, norm_eps):
        super().__init__()
        self.norm_eps = norm_eps
        # Per-feature gain, initialised to 1 so the module starts as a pure normalizer.
        self.weight = nn.Parameter(torch.ones(d_model))

    # Backward-compatible alias for the attribute's former (misspelled) name.
    @property
    def norm_esp(self):
        return self.norm_eps

    @norm_esp.setter
    def norm_esp(self, value):
        self.norm_eps = value

    def _norm(self, x):
        # Scale by the reciprocal RMS taken over the last dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.norm_eps)

    def forward(self, x):
        # Upcast so the mean of squares is not computed in low precision (e.g. bfloat16),
        # then restore the input dtype before applying the gain.
        out = self._norm(x.float()).type_as(x)
        return out * self.weight

# Smoke test: RMSNorm normalizes over the last axis only, so the output must keep
# the (batch, seq, d_model) shape of its input.
norm = RMSNorm(d_model, norm_eps)
rms_sample = torch.randn(2, 8, d_model)
rms_out = norm(rms_sample)
assert rms_out.shape == rms_sample.shape
print(rms_out.shape)


def percompute_freqs_cis(d_model, end, theta = 10000.0):
    """Precompute the complex RoPE rotation factors ``exp(i * t * freq)``.

    Args:
        d_model: size of the vector being rotated. Note that the call in this notebook
            passes ``head_dim`` here (RoPE is applied per head), not the global ``d_model``.
        end: number of positions ``t = 0 .. end - 1`` to precompute.
        theta: RoPE base; larger values give lower rotation frequencies.

    Returns:
        complex64 tensor of shape ``(end, d_model // 2)`` with unit-magnitude entries.
    """
    # One frequency per feature pair: theta ** (-2k / d_model), k = 0 .. d_model // 2 - 1.
    freqs = 1.0 / (theta ** (torch.arange(0, d_model, 2)[: (d_model // 2)].float() / d_model))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    # Rotation angle for every (position, frequency) pair.
    freqs = torch.outer(t, freqs)
    # Unit-magnitude complex numbers with those angles. The float32 inputs yield complex64
    # regardless of the global default dtype.
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


# Correctly spelled alias; the misspelled name above is kept so existing calls still work.
precompute_freqs_cis = percompute_freqs_cis


def reshape_for_broadcast(freqs_cis, x):
    """Reshape a ``(seq, dim)`` table so it broadcasts against ``x``.

    ``x`` must have at least two dimensions, and ``freqs_cis`` must have shape
    ``(x.shape[1], x.shape[-1])``. The result keeps those two axes and gives every
    other axis size 1.

    Returns:
        A view of ``freqs_cis`` with ``x.ndim`` dimensions.
    """
    rank = x.ndim
    # Axis 1 must exist in x.
    assert rank > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    target = [1] * rank
    target[1] = x.shape[1]
    target[-1] = x.shape[-1]
    return freqs_cis.view(*target)


def apply_rotary_emb(xq, xk, freqs_cis):
    """Apply rotary positional embeddings (RoPE) to queries and keys.

    Consecutive feature pairs ``(x[..., 2j], x[..., 2j + 1])`` are treated as complex
    numbers and multiplied by ``freqs_cis``. ``xq`` and ``xk`` may have different head
    counts (e.g. ``n_heads`` vs ``n_kv_heads``). They must share the sequence axis
    (axis 1) and the last dimension.

    Args:
        xq: query tensor, e.g. ``(batch, seq, n_heads, head_dim)``; last dim must be even.
        xk: key tensor, e.g. ``(batch, seq, n_kv_heads, head_dim)``.
        freqs_cis: complex tensor of shape ``(seq, head_dim // 2)``.

    Returns:
        ``(xq_out, xk_out)`` with the same shapes and dtypes as ``xq`` and ``xk``.
    """
    # view_as_complex does not accept bfloat16, so upcast to float32 first.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # Merge the trailing (pairs, 2) real/imag axes back into the feature axis.
    # flatten(-2) works for any input rank; flatten(3) only did so for 4-D inputs.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)


# Smoke test for RoPE. Queries carry n_heads heads and keys carry n_kv_heads heads
# (grouped-query attention); the rotation must leave both shapes unchanged.
demo_batch, demo_seq_len, demo_start_pos = 2, 8, 0
ip1 = torch.randn(demo_batch, demo_seq_len, n_heads, head_dim)
ip2 = torch.randn(demo_batch, demo_seq_len, n_kv_heads, head_dim)

# Precompute angles for max_seq_length * 2 positions, then take the window
# [demo_start_pos, demo_start_pos + demo_seq_len) covered by this sequence.
freqs_cis1 = percompute_freqs_cis(head_dim, max_seq_length * 2, rope_theta)
freqs_cis1 = freqs_cis1[demo_start_pos : demo_start_pos + demo_seq_len]
out1, out2 = apply_rotary_emb(ip1, ip2, freqs_cis1)
assert out1.shape == ip1.shape and out2.shape == ip2.shape
print("-"*30)
print(out1.shape)
print(out2.shape)

torch.Size([2, 8, 4096])
------------------------------
torch.Size([2, 8, 32, 128])
torch.Size([2, 8, 8, 128])


SwiGLU Feed-Forward Network

In [9]:
class FFNN(nn.Module):
    """SwiGLU feed-forward block computing ``w2(silu(w1(x)) * w3(x))``.

    All three projections are bias-free linear layers.

    Args:
        d_model: input and output feature size.
        ffnn_dim: hidden feature size between the projections.
    """

    def __init__(self, d_model, ffnn_dim):
        super().__init__()
        # Gate projection (passed through SiLU).
        self.w1 = nn.Linear(in_features=d_model, out_features=ffnn_dim, bias=False)
        # Up projection (multiplied elementwise with the gate).
        self.w3 = nn.Linear(in_features=d_model, out_features=ffnn_dim, bias=False)
        # Down projection back to d_model.
        self.w2 = nn.Linear(in_features=ffnn_dim, out_features=d_model, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
    
# Smoke test: SwiGLU projects d_model -> ffnn_dim -> d_model, so the output must
# have the same shape as the input.
feedfor = FFNN(d_model, ffnn_dim)
ffnn_sample = torch.randn(2, 8, d_model)
ffnn_out = feedfor(ffnn_sample)
assert ffnn_out.shape == ffnn_sample.shape
ffnn_out.shape

torch.Size([2, 8, 4096])

Grouped-Query Attention (GQA)

In [None]:
# TODO: implement grouped-query attention in this section.
# NOTE(review): a bare `torch.einsum()` call raises, because einsum needs an equation
# string plus operands, and that breaks Restart & Run All. For reference, attention
# scores for (batch, seq, heads, head_dim) queries/keys can be written as:
#   scores = torch.einsum("bqhd,bkhd->bhqk", xq, xk) / head_dim ** 0.5