<a id="Intro"></a>
## LLama3 
    
    1. This is purely to understand the innards of Llama3
    2. Build the components for Llama3 with only torch tensors
        Implement forward, backward and update for the params
    3. Compare the loss, logits of this custom model, with a gpt2 model params

In [1]:
import torch
import torch.nn as nn
import numpy as np
from dataclasses import dataclass
import math
from typing import List
import train_llama3

  from torch.distributed.optim import ZeroRedundancyOptimizer


In [84]:
torch.set_printoptions(precision=7)
TEST = True
DEBUG = False

In [87]:
def decorate_print(func):
    def wrapper(*args, **kwargs):
        if 'debug' in kwargs and kwargs['debug']:
            print("\033[94m" + "*" * 50 + "\033[0m")
            func(*args, **kwargs)
            print("\033[94m" + "*" * 50 + "\033[0m")
    return wrapper

@decorate_print
def cprint(*args, debug=False):
    print("\t, ".join(str(arg) for arg in args))



#### Small examples to understand requires_grad functionality

In [88]:
"""
    x, y,z
    No path tracking
"""

x = torch.ones((2, 3)).float()
y = x**2
z = y.sum()

do = torch.ones_like(z)
try:
    z.backward(do)
except Exception as e:
    print("There is no graph to back prop !!!!")
    print(e)

"""
    x, y(with grad) -> z
"""

y.requires_grad = True
z = y.sum()
try:
    z.backward()
    print("\nThis will work. y.grad will have a value")
    print(y.grad)
except:
    pass

try:
    y.backward(torch.ones_like(y), inputs=[x])
except Exception as e:
    print("\nCannot backprop into a node that does not require grad")

"""
    x(with grad) -> y(with grad : does not retain grad since it is an intermediate node) -> z
"""
x = torch.ones((2, 3), requires_grad=True).float()
y = x**2
z = y.sum()
do = torch.ones_like(z) + 0.5
z.backward(do, inputs=[x])
print(x.grad, y.grad is None)

"""
    x(with grad) -> y(with grad : does not retain grad since it is an intermediate node. Must explicitl retain grad) -> z
"""
x = torch.ones((2, 3), requires_grad=True).float()
y = x**2
z = y.sum()
y.retain_grad()
do = torch.ones_like(z) + 0.5
z.backward(do)
x.grad, y.grad

There is no graph to back prop !!!!
element 0 of tensors does not require grad and does not have a grad_fn

This will work. y.grad will have a value
tensor([[1., 1., 1.],
        [1., 1., 1.]])

Cannot backprop into a node that does not require grad
tensor([[3., 3., 3.],
        [3., 3., 3.]]) True


  print(x.grad, y.grad is None)


(tensor([[3., 3., 3.],
         [3., 3., 3.]]),
 tensor([[1.5000000, 1.5000000, 1.5000000],
         [1.5000000, 1.5000000, 1.5000000]]))

## LlamaConfig

In [89]:
@dataclass
class LlamaConfig:
    version: str = "3.1"
    block_size: int = 10
    vocab_size: int = 20
    n_layer: int = 8
    n_head: int = 8
    n_kv_head: int = 2
    n_embd: int = 32
    ffn_dim_multiplier: float = 1.3
    multiple_of: int = 32
    norm_eps: float = 1e-5
    rope_theta: float = 500000.0
    use_scaled_rope: bool = False
    max_gen_batch_size: int = 4
    use_kv: bool = True
    n_rep = None
    flash: bool = False  # use flashattention?
    T: int = 5  # number of tokens in the forward pass
    B: int = 4

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if hasattr(self, k):
                setattr(self, k, v)
        assert self.n_kv_head <= self.n_head
        assert self.n_head % self.n_kv_head == 0
        assert self.n_embd % self.n_head == 0
        self.n_rep = self.n_head // self.n_kv_head

In [9]:
config0 = LlamaConfig(
    block_size=64,
    vocab_size=20,
    n_layer=0,
    n_head=8,
    n_kv_head=4,
    n_embd=256,
    ffn_dim_multiplier=1.3,
    multiple_of=64,
    T=32,  # number of tokens in the forward pass
    B=4,
)

config1 = LlamaConfig(
    block_size=64,
    vocab_size=20,
    n_layer=1,
    n_head=8,
    n_kv_head=4,
    n_embd=256,
    ffn_dim_multiplier=1.3,
    multiple_of=64,
    T=32,  # number of tokens in the forward pass
    B=4,
)

config2 = LlamaConfig(
    block_size=64,
    vocab_size=20,
    n_layer=2,
    n_head=8,
    n_kv_head=4,
    n_embd=256,
    ffn_dim_multiplier=1.3,
    multiple_of=64,
    T=32,  # number of tokens in the forward pass
    B=4,
)

B, T, C, nh, n_kv_head, OC = 2, 4, 8, 4, 2, 3
config3 = LlamaConfig(B=B, T=T, n_embd=C, n_head=nh, n_kv_head=n_kv_head, OC=OC)

large_config = LlamaConfig(
    block_size=64,
    vocab_size=1000,
    n_layer=2,
    n_head=8,
    n_kv_head=4,
    n_embd=512,
    ffn_dim_multiplier=1.3,
    multiple_of=64,
    T=32,  # number of tokens in the forward pass
    B=4,
)

configs = [config0, config1, config2, config3, large_config]
print(len(configs))

5


## forward + backward functions
    1. forward and backward function for the different modules of llama3
        state stores the variables needed for the backward pass
    2. Toggle TEST to enable/disable the tests for the components

### encoder

In [98]:
def encoder_forward(wte: torch.Tensor, tokens: torch.Tensor, state=[]):
    # state: store the vars needed for the backward pass
    # wte : V,C
    # tokens : B,T (tokens ids in range(V)) -> index into wte
    # out : B,T,C
    state.extend([tokens, wte.shape])
    return wte[tokens]


def encoder_back(dout: torch.Tensor, state=[]):
    # dout : B,T,C
    # dwte : V,C
    tokens = state[0].view(-1)  # B,T
    dout = dout.view(-1, dout.shape[-1])
    dwte = torch.zeros(state[1])

    # loop version --- is there a faster way to do this ????
    for i, t in enumerate(tokens):
        dwte[t] += dout[i]
    return dwte


if TEST:  # enable/disable running TEST for this component
    B, T, C, V = 2, 5, 5, 90
    enc = nn.Embedding(V, C)

    tokens = torch.randint(V, (B, T))
    assert tokens.shape == (B, T)
    enco1 = enc(tokens)
    state = []
    enco2 = encoder_forward(enc.weight, tokens, state)
    assert enco1.shape == (B, T, C)
    assert torch.allclose(enco1, enco2)
    assert torch.allclose(state[0], tokens)

    # backward
    dout = torch.randn_like(enco2)
    enco1.backward(dout, inputs=[enc.weight])

    dwte = encoder_back(dout, state=state)
    assert torch.allclose(dwte, enc.weight.grad)
    print("PASS")

PASS


In [9]:
dwte = torch.zeros(V, C).view(V, C)
# torch.index_select(dwte,0, tokens.view(-1)) = dout
torch.index_select(dwte, 0, tokens.view(-1)) + dout.view(-1, C)
# dwte

NameError: name 'V' is not defined

### rms

In [99]:
def rms_forward(norm_eps, inp: torch.Tensor, weight: torch.Tensor, state=[]):
    # inp (B, T, dim)
    # weight (dim)
    norm = inp / torch.sqrt(inp.pow(2).mean(-1, keepdim=True) + norm_eps)  # B,T,dim
    state.extend([inp, norm, norm_eps])
    return weight * norm


def rms_backward(dout: torch.Tensor, weight: torch.Tensor, state=[]):
    # dout : B,T,C
    # dweight : C
    # dinp : B,T,C
    B, T, C = dout.shape
    inp, norm, norm_eps = state  # B,T,C
    rms = torch.sqrt(inp.pow(2).mean(-1, keepdim=True) + norm_eps)
    dweight = torch.sum(dout * norm, [0, 1])  # sum over the B,T dims
    dinp = torch.zeros_like(inp)

    # is there a better way of doing this ???? this is ugly
    dinp += dout * weight / rms * (1.0 - norm * norm / C)
    inp = inp.view(-1, C)
    dinp = dinp.view(-1, C)
    dout = dout.view(-1, C)
    norm = norm.view(-1, C)
    rms = rms.view(-1)
    for i in range(C):
        for j in range(C):
            if j == i:
                continue
            dinp[:, i] += (
                -dout[:, j] * weight[j] * inp[:, i] * inp[:, j] / (rms * rms * rms * C)
            )
    return dinp.view(B, T, C), dweight


if TEST:
    B, T, C, V = 2, 5, 6, 9
    inp = torch.rand(B, T, C)
    inp.requires_grad = True
    weight = torch.ones(C)
    rms = train_llama3.RMSNorm(C)
    o1 = rms(inp)
    dout = torch.randn_like(o1)
    o1.backward(dout)
    inp.detach_()

    state = []
    o2 = rms_forward(config1.norm_eps, inp, weight, state)
    assert torch.allclose(o1, o2, atol=1e-4), (o1, o2)

    # backward
    dinp, dw1 = rms_backward(dout, weight, state)
    assert torch.allclose(
        dw1, rms.weight.grad, atol=1e-3
    ), f"{dw1[:5]},{rms.weight.grad[:5]}"
    assert torch.allclose(
        dinp, inp.grad, atol=1e-3
    ), f"{dinp.view(-1)[:10]}, {inp.grad.view(-1)[:10]}"
    print("PASS")

PASS


### matmul

In [100]:
def matmul(inp: torch.Tensor, weight: torch.Tensor, state=[]):
    assert inp.shape[-1] == weight.shape[-2], (inp.shape, weight.shape)
    state.append(inp)
    return inp @ weight


def matmul_backward(dout: torch.Tensor, weight: torch.Tensor, state=[]):
    inp = state[0]
    if len(weight.shape) == 2:
        B, T, OC = dout.shape
        _, _, C = inp.shape
        assert weight.shape == (C, OC), (weight.shape, (C, OC))
        dweight = (inp.unsqueeze(-1) * dout.unsqueeze(-2)).sum([0, 1])
        dinp = dout @ weight.transpose(-2, -1)
        return dinp, dweight
    elif len(inp.shape) == 4 and len(weight.shape) == 4:
        B, nh, T, C = dout.shape
        _, _, _, hs = inp.shape
        dinp = dout @ weight.transpose(-2, -1)
        dweight = inp.view(B * nh, T, hs).transpose(-2, -1) @ dout.view(B * nh, T, C)
        return dinp, dweight.view(weight.shape)


if TEST:
    B, T, C, OC = 12, 5, 6, 900
    inp = torch.rand(B, T, C)
    inp.requires_grad = True
    mmw = nn.Linear(in_features=C, out_features=OC, bias=False)
    mm1 = mmw(inp)
    mmdout = torch.randn_like(mm1)
    mm1.backward(mmdout)

    inp1 = inp.detach()
    state = []
    mm2 = matmul(inp=inp1, weight=mmw.weight.T, state=state)
    assert torch.allclose(mm1, mm2, atol=1e-4)

    # backward
    mmdinp, mmdw = matmul_backward(mmdout, mmw.weight.T, state)

    assert torch.allclose(mmdinp, inp.grad, atol=1e-4), (
        mmdinp.view(-1)[:5],
        inp.grad.view(-1)[:5],
    )
    assert torch.allclose(mmdw, mmw.weight.grad.T, atol=1e-4), (
        mmdw.view(-1)[:5],
        mmw.weight.grad.T.reshape(-1)[:5],
    )

    # backward q.KT
    B, T, nh, hs = 2, 3, 4, 5
    q = torch.randn((B, nh, T, hs))
    q.requires_grad = True
    kt = torch.randn((B, nh, hs, T))
    kt.requires_grad = True
    mmo1 = q @ kt

    dout = torch.randn_like(mmo1)
    mmo1.backward(dout, inputs=[q, kt])
    q.detach_(), kt.detach_()

    state = []
    mmo2 = matmul(q, kt, state=state)  # state =[q]
    assert torch.allclose(mmo1, mmo2)
    mmdq, mmdkt = matmul_backward(dout, kt, state)

    assert torch.allclose(q.grad, mmdq), (q.grad.shape, mmdq.shape)
    assert torch.allclose(kt.grad, mmdkt)
    print("PASS")

PASS


### silu

In [101]:
def silu(inp: torch.Tensor, state=[]):
    state.append(inp)
    return inp / (1.0 + torch.exp(-inp))


def silu_backward(dout: torch.Tensor, state=[]):
    x = state[0]
    expx = torch.exp(-x)
    den = 1.0 / (1 + expx)
    return dout * den * (1 + x * expx * den)


if TEST:
    state = []
    inp = torch.rand(B, T, C)
    inp.requires_grad = True
    sin = nn.functional.silu(inp)
    sdout = torch.ones_like(inp)
    sin.backward(sdout, inputs=[inp])
    assert torch.allclose(sin, silu(inp, state=state), atol=1e-4)
    inp.detach_()

    # backward
    sdin = silu_backward(sdout, state=state)
    assert torch.allclose(sdin, inp.grad)
    print("PASS")

PASS


### feed forward

In [102]:
def ff(
    inp: torch.Tensor, w1: torch.Tensor, w2: torch.tensor, w3: torch.Tensor, state=[]
):
    h1 = matmul(inp, w1)
    h3 = matmul(inp, w3)
    h1silu = silu(h1)
    h1h3 = h1silu * h3
    state.extend([inp, h1, h3, h1silu, h1h3])
    return matmul(h1h3, w2)


# backward
def ff_backward(
    dout: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, w3: torch.Tensor, state=[]
):
    inp, h1, h3, h1silu, h1h3 = state
    dh1h3, dw2 = matmul_backward(dout=dout, weight=w2, state=[h1h3])
    dh1silu, dh3 = dh1h3 * h3, dh1h3 * h1silu
    dh1 = silu_backward(dout=dh1silu, state=[h1])
    dinp1, dw1 = matmul_backward(dout=dh1, weight=w1, state=[inp])
    dinp2, dw3 = matmul_backward(dout=dh3, weight=w3, state=[inp])
    return dinp1 + dinp2, dw1, dw2, dw3


if TEST:
    print(len(configs))
    for cnfg in configs:
        print("\nChecking: ", cnfg)
        B, T, C = cnfg.B, cnfg.T, cnfg.n_embd
        inp = torch.rand(B, T, C)
        inp.requires_grad = True
        ffn = train_llama3.MLP(cnfg)
        ffo1 = ffn.forward(inp)
        dout = torch.randn_like(ffo1)
        ffo1.backward(dout)
        inp1 = inp.detach()

        ffw1 = ffn.c_fc2.weight.transpose(0, 1)
        ffw2 = ffn.c_proj.weight.transpose(0, 1)
        ffw3 = ffn.c_fc.weight.transpose(0, 1)
        state = []
        ffo2 = ff(inp1, ffw1, ffw2, ffw3, state)
        assert torch.allclose(ffo1, ffo2, atol=1e-4)

        ffdin, ffdw1, ffdw2, ffdw3 = ff_backward(
            dout=dout, w1=ffw1, w2=ffw2, w3=ffw3, state=state
        )
        assert torch.allclose(ffdin, inp.grad, atol=1e-4)
        assert torch.allclose(ffdw1, ffn.c_fc2.weight.grad.T, atol=1e-4)
        assert torch.allclose(ffdw3, ffn.c_fc.weight.grad.T, atol=1e-4)
        assert torch.allclose(ffdw2, ffn.c_proj.weight.grad.T, atol=1e-4)
print("PASS")

5

Checking:  LlamaConfig(version='3.1', block_size=64, vocab_size=20, n_layer=0, n_head=8, n_kv_head=4, n_embd=256, ffn_dim_multiplier=1.3, multiple_of=64, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=32, B=4)

Checking:  LlamaConfig(version='3.1', block_size=64, vocab_size=20, n_layer=1, n_head=8, n_kv_head=4, n_embd=256, ffn_dim_multiplier=1.3, multiple_of=64, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=32, B=4)

Checking:  LlamaConfig(version='3.1', block_size=64, vocab_size=20, n_layer=2, n_head=8, n_kv_head=4, n_embd=256, ffn_dim_multiplier=1.3, multiple_of=64, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=32, B=4)

Checking:  LlamaConfig(version='3.1', block_size=10, vocab_size=20, n_layer=5, n_head=4, n_kv_head=2, n_embd=8, ffn_dim_multiplier=1.3, multiple_of=32, norm_eps=1e-05, rope_th

### repeat_kv

In [103]:
def repeat_kv(inp: torch.Tensor, n_rep: int, state=[]):
    # inp : B,T, n_kv_heads, dim
    state.append(inp)
    B, T, n_kv_head, dim = inp.shape
    if n_rep == 1:
        return inp
    # return inp[:,:,:, None,:].repeat(1,1,1,n_rep,1).view(B,T,n_kv_head*n_rep,dim)

    # use expand instead of repeat. we shud not be allocation new memory
    return (
        inp[:, :, :, None, :]
        .expand(B, T, n_kv_head, n_rep, dim)
        .reshape(B, T, n_kv_head * n_rep, dim)
    )


def repeat_kv_backward(dout: torch.Tensor, state=[]):
    # dout : (B,T,n_head, hs)
    # dint : (B,T, n_kv_head, hs)
    inp = state[0]
    B, T, n_kv_head, hs = inp.shape
    _, _, n_head, _ = dout.shape
    n_rep = n_head // n_kv_head
    return dout.view(B, T, n_kv_head, n_rep, hs).sum(-2)


if TEST:
    for cnfg in configs:
        B, T, n_kv_head, n_head, C = (
            cnfg.B,
            cnfg.T,
            cnfg.n_kv_head,
            cnfg.n_head,
            cnfg.n_embd,
        )
        n_rep = n_head // n_kv_head
        hs = C // n_head
        rkvi = torch.randn((B, T, n_kv_head, hs))
        rkvi.requires_grad = True
        rkvo1 = train_llama3.repeat_kv(rkvi, n_rep)
        dout = torch.randn_like(rkvo1)
        rkvo1.backward(dout, inputs=[rkvi])
        rkvi1 = rkvi.detach()

        state = []
        rkvo2 = repeat_kv(rkvi1, n_rep, state=state)
        assert torch.allclose(rkvo1, rkvo2, atol=1e-5)

        # backward
        rkvdo2 = repeat_kv_backward(dout, state)
        assert  torch.allclose(rkvi.grad, rkvdo2, atol=1e-5)
print("PASS")

PASS


### Rotary Positional Embeddings (RoPE)

In [122]:
def compute_freqs_cis(
    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
):
    # end is generally the block size >= T (input seq length)
    theta = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    pos = torch.arange(end)[:, None]
    rotations = pos * theta
    assert rotations.shape == (end, dim // 2)
    cos_theta = torch.cos(rotations)
    sin_theta = torch.sin(rotations)
    return cos_theta, sin_theta


def rpe_forward(inp: torch.Tensor, cos_theta, sin_theta):
    if len(inp.shape) == 3:
        B, T, C = inp.shape
    else:
        B, T, n, s = inp.shape
        cos_theta, sin_theta = cos_theta.unsqueeze(-2), sin_theta.unsqueeze(-2)

    cos_theta, sin_theta = cos_theta[:T], sin_theta[:T]
    cprint("cos: ", cos_theta.view(-1)[-15:], debug=DEBUG)
    cprint("sin: ", sin_theta.view(-1)[-15:], debug=DEBUG)
    x, y = inp[..., ::2], inp[..., 1::2]
    final_x = x * cos_theta - y * sin_theta
    final_y = x * sin_theta + y * cos_theta
    return torch.concat([final_x[..., None], final_y[..., None]], dim=-1).view(
        inp.shape
    )


def rpe_backward(dout: torch.Tensor, cos_theta, sin_theta):
    # dinp has the same shape as dout
    if len(dout.shape) == 3:
        B, T, C = dout.shape
    else:
        B, T, n, s = dout.shape
        cos_theta, sin_theta = cos_theta.unsqueeze(-2), sin_theta.unsqueeze(-2)
    cos_theta, sin_theta = cos_theta[:T], sin_theta[:T]
    dout_x, dout_y = dout[..., ::2], dout[..., 1::2]
    dx = dout_x * cos_theta + dout_y * sin_theta
    dy = dout_x * (-sin_theta) + dout_y * cos_theta
    return torch.concat([dx[..., None], dy[..., None]], dim=-1).view(dout.shape)


if  TEST:
    for cnfg in configs:
        B, T, C, maxT = cnfg.B, cnfg.T, cnfg.n_embd, cnfg.block_size
        cos, sin = compute_freqs_cis(C, maxT, cnfg.rope_theta)
        freqs_cis = train_llama3.precompute_freqs_cis(
            dim=C, end=T, theta=cnfg.rope_theta, use_scaled=False
        )

        # --- the 3d case
        x = torch.randn(B, T, C).float()
        x1 = torch.randn(B, T, C).float()
        x.requires_grad = True
        xq, _ = train_llama3.apply_rotary_emb(x, x1, freqs_cis)
        dout = torch.ones_like(x)
        xq.view(B, T, -1).backward(dout, inputs=[x])

        x1 = x.detach()
        fx = rpe_forward(x1, cos, sin)
        dx = rpe_backward(dout, cos, sin)
        assert torch.allclose(
            xq.view(B, T, -1), fx, atol=1e-4
        ), f"{xq.view(-1)[-10:], fx.view(-1)[-10:]}"
        assert torch.allclose(dx, x.grad, atol=1e-4)

        # --- the 4d case
        xq = torch.randn(B, T, 14, C)
        xq.requires_grad = True
        xqr1 = train_llama3.apply_rotary_emb(xq, xq, freqs_cis)[0]
        dout = torch.ones_like(xq)
        xqr1.backward(dout, inputs=[xq])

        xq1 = xq.detach()
        xqr2 = rpe_forward(xq1, cos, sin)
        dx = rpe_backward(dout, cos, sin)
        assert torch.allclose(xqr1, xqr2, atol=1e-4)
        assert torch.allclose(dx, xq.grad, atol=1e-4)
    print("PASS")

PASS


### softmax

In [105]:
def softmax_forward(inp: torch.Tensor, state=[]):
    inp = inp.detach().clone().requires_grad_(True)
    out = torch.nn.functional.softmax(inp, dim=-1)
    assert inp.requires_grad
    assert out.requires_grad
    state.extend([inp, out])  # inp, out form a graph
    return out.detach()


def softmax_backward(dout: torch.Tensor, state=[]):
    inp, out = state
    assert inp.requires_grad
    assert out.requires_grad
    out.backward(dout, inputs=[inp])
    grad = inp.grad.clone()
    inp.detach_()
    out.detach_()
    assert not inp.requires_grad
    assert not out.requires_grad
    return grad


def cross_entropy_forward(logits: torch.tensor, targets: torch.tensor, state=[]):

    B, T, V = logits.shape
    log_sm = torch.log(softmax_forward(logits, state))
    loss_ce = -torch.gather(log_sm, -1, targets.view(B, T, 1)).sum() / (B * T)
    state.append(targets)
    return loss_ce


def cross_entropy_backward(loss, state):
    logits, sm, targets = state
    B, T, V = logits.shape
    d_log_sm = torch.zeros_like(sm)
    d_log_sm.view(B * T, -1)[torch.arange(B * T), targets.view(-1)] = -1.0 / (B * T)
    d_log_sm /= state[1]
    return softmax_backward(d_log_sm, state=state[:2])


if TEST:
    for cnfg in configs:
        print("\nChecking: ", cnfg)
        B, T, C = cnfg.B, cnfg.T, cnfg.n_embd
        inp = torch.randn((B, T, C))
        state = []
        out = softmax_forward(inp, state)
        for tnsr in state:
            assert tnsr.requires_grad

        dout = torch.randn_like(out)
        dinp = softmax_backward(dout, state=state)
        assert dinp.shape == inp.shape

        B, T, C = cnfg.B, cnfg.T, cnfg.n_embd
        torch.manual_seed(1)
        logits_ce = torch.randn(B, T, C).requires_grad_(True)
        targets = torch.randint(C, (B, T))
        lossx2 = torch.nn.functional.cross_entropy(
            logits_ce.view(-1, logits_ce.size(-1)), targets.view(-1), ignore_index=-1
        )
        lossx2.backward()

        logits1 = logits_ce.detach()
        state = []
        lossx1 = cross_entropy_forward(logits1, targets, state)
        cprint(lossx1, lossx2, debug=DEBUG)
        dlogits1 = cross_entropy_backward(lossx1, state)
        assert torch.allclose(dlogits1, logits_ce.grad), (
            dlogits1.view(-1)[:10],
            logits_ce.grad.view(-1)[:10],
        )
print("PASSED")


Checking:  LlamaConfig(version='3.1', block_size=64, vocab_size=20, n_layer=0, n_head=8, n_kv_head=4, n_embd=256, ffn_dim_multiplier=1.3, multiple_of=64, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=32, B=4)

Checking:  LlamaConfig(version='3.1', block_size=64, vocab_size=20, n_layer=1, n_head=8, n_kv_head=4, n_embd=256, ffn_dim_multiplier=1.3, multiple_of=64, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=32, B=4)

Checking:  LlamaConfig(version='3.1', block_size=64, vocab_size=20, n_layer=2, n_head=8, n_kv_head=4, n_embd=256, ffn_dim_multiplier=1.3, multiple_of=64, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=32, B=4)

Checking:  LlamaConfig(version='3.1', block_size=10, vocab_size=20, n_layer=5, n_head=4, n_kv_head=2, n_embd=8, ffn_dim_multiplier=1.3, multiple_of=32, norm_eps=1e-05, rope_thet

### attention

In [106]:
def attention(
    config: LlamaConfig,
    inp: torch.Tensor,
    wq: torch.Tensor,
    wk: torch.Tensor,
    wv: torch.Tensor,
    wo: torch.Tensor,
    cos_theta,
    sin_theta,
    state=[],
):
    # inp : (B,T,C)
    # wk,wv : (c,kv_dim)
    # wq : c,c
    # wo : c,c
    state.append(inp)
    cprint("inp_: ", inp.reshape(-1)[-15:], debug=DEBUG)
    B, T, nh, n_kv_head, C = (
        config.B,
        config.T,
        config.n_head,
        config.n_kv_head,
        config.n_embd,
    )
    hs = C // nh

    kv_dim = hs * n_kv_head
    n_rep = nh // n_kv_head
    q = matmul(inp, wq).view(B, T, nh, hs)
    k = matmul(inp, wk).view(B, T, n_kv_head, hs)
    v = matmul(inp, wv).view(B, T, n_kv_head, hs)
    state.extend([q, k, v])

    cprint("q_: ", q.reshape(-1)[-15:], debug=DEBUG)
    cprint("k_: ", k.reshape(-1)[-15:], debug=DEBUG)
    cprint("v_: ", v.reshape(-1)[-15:], debug=DEBUG)
    # rope
    q = rpe_forward(q, cos_theta, sin_theta)
    k = rpe_forward(k, cos_theta, sin_theta)
    state.extend([q, k])  # this is after rope
    cprint("qrope_: ", q.reshape(-1)[-15:], debug=DEBUG)
    cprint("krope_: ", k.reshape(-1)[-15:], debug=DEBUG)
    k = repeat_kv(inp=k, n_rep=n_rep)  # b,t, n_head, hs
    v = repeat_kv(inp=v, n_rep=n_rep)  # b,t, n_head, hs
    cprint("krep_: ", k.reshape(-1)[-15:], debug=DEBUG)

    state.extend([k, v])  # this is the repeated version

    # (b,nh,t,c) @ (b,nh, c, t) -> (b, nh, t,t)
    qp = q.permute(0, 2, 1, 3).contiguous()  # b,nh, t,hs
    kT = k.permute(0, 2, 3, 1).contiguous()  # b,nh,hs,t -> 0,3,1,2
    qkT = matmul(qp, kT) / (math.sqrt(hs))

    mask = torch.tril(qkT)
    qkT = torch.where(mask == 0.0, -torch.inf, qkT)

    state.extend([qp, kT])
    scores = softmax_forward(qkT, state)  # torch.softmax(qkT, dim=-1) # b, nh, t,t
    cprint("scores_: ", scores.reshape(-1)[-15:], debug=DEBUG)
    assert not scores.requires_grad

    vp = v.permute(0, 2, 1, 3).contiguous()
    o = matmul(
        scores, vp
    )  # b, nh, t, hs -> matmul SHOUlD GENERALIZE for any dim, not just for nn.Linear ???
    cprint("o_: ", o.reshape(-1)[-15:], debug=DEBUG)
    assert not o.requires_grad
    o = o.permute(0, 2, 1, 3).reshape(B, T, C)  # b, t, c
    state.extend([o, vp])
    fo = o @ wo
    assert not fo.requires_grad
    cprint("y_: ", fo.reshape(-1)[-15:], debug=DEBUG)
    return fo


def attention_backward(
    dout: torch.Tensor,
    wq: torch.Tensor,
    wk: torch.Tensor,
    wv: torch.Tensor,
    wo: torch.Tensor,
    cos_theta,
    sin_theta,
    state=[],
):

    # dout: (B,T,C)
    inp, q, k, v, qr, kr, k_e, v_e, qp, kT, qkT, scores, o, vp = state
    for tnsr in [inp, q, k, v, qr, kr, k_e, v_e, qp, kT, o, vp]:
        assert not tnsr.requires_grad
    assert k.shape == v.shape
    B, T, nh, hs = q.shape
    _, _, n_kv_head, khs = k.shape
    assert hs == khs
    assert qkT.shape == (B, nh, T, T)
    n_rep = nh // n_kv_head

    do, dwo = matmul_backward(dout=dout, weight=wo, state=[o])  # do : b,t,c
    assert not do.requires_grad
    assert not dwo.requires_grad
    do = (
        do.view(B, T, nh, hs).permute(0, 2, 1, 3).contiguous().view(B, nh, T, hs)
    )  # do : b,nh, t, hs
    dscores, dvp = matmul_backward(
        dout=do, weight=vp, state=[scores.detach()]
    )  # do : b,t,c
    assert not dscores.requires_grad
    assert not dvp.requires_grad
    dqkT = softmax_backward(dscores, state=[qkT, scores])
    assert not dqkT.requires_grad
    for tnsr in [qkT, scores]:
        assert not tnsr.requires_grad

    # do we need this ?????? (think)
    # mask = torch.tril(qkT)
    # dqkT = torch.where(mask==0.0, 0.0, dqkT)
    dqp, dkT = matmul_backward(dqkT, kT, state=[qp])
    assert not dqp.requires_grad
    assert not dkT.requires_grad
    dq = dqp.permute(0, 2, 1, 3).contiguous() / (math.sqrt(hs))
    dke = dkT.permute(0, 3, 1, 2).contiguous() / (math.sqrt(hs))
    dve = dvp.permute(0, 2, 1, 3).contiguous()
    dk = repeat_kv_backward(dke, state=[k])
    dv = repeat_kv_backward(dve, state=[v])

    # rope backward
    dq = rpe_backward(dq, cos_theta, sin_theta)
    dk = rpe_backward(dk, cos_theta, sin_theta)
    dq = dq.view(B, T, hs * nh)
    dk = dk.view(B, T, hs * n_kv_head)
    dv = dv.view(B, T, hs * n_kv_head)
    for tnsr in [dq, dk, dv]:
        assert not tnsr.requires_grad
    dinq, dwq = matmul_backward(dout=dq, weight=wq, state=[inp])  # do : b,t,c
    dink, dwk = matmul_backward(dout=dk, weight=wk, state=[inp])  # do : b,t,c
    dinv, dwv = matmul_backward(dout=dv, weight=wv, state=[inp])  # do : b,t,c
    for tnsr in [dwq, dwk, dwv, dinq, dink, dinv]:
        assert not tnsr.requires_grad
    return dinq + dink + dinv, dwq, dwk, dwv, dwo


def init_attention(cnfg: LlamaConfig, cattn: train_llama3.CausalSelfAttention):
    # need to init wq, wk, wv, wo from cattn weights
    B, C, T, nh, n_kv_head, maxT = (
        cnfg.B,
        cnfg.n_embd,
        cnfg.T,
        cnfg.n_head,
        cnfg.n_kv_head,
        cnfg.block_size,
    )
    hs = C // nh
    kv_dim = hs * n_kv_head
    n_rep = nh // n_kv_head
    weight = cattn.c_attn.weight.detach().clone()  # has only the weights, no grads
    awq, awk, awv = torch.split(weight, [C, kv_dim, kv_dim], dim=0)
    awo = cattn.c_proj.weight.detach().clone()
    assert awo._cdata != cattn.c_proj.weight._cdata
    for tnsr in [awq, awk, awv, awo]:
        assert not tnsr.requires_grad
    return awq, awk, awv, awo


if TEST:
    for cnfg in configs:
        print("Checking: ", cnfg)

        B, C, T, nh, n_kv_head, maxT = (
            cnfg.B,
            cnfg.n_embd,
            cnfg.T,
            cnfg.n_head,
            cnfg.n_kv_head,
            cnfg.block_size,
        )
        hs = C // nh

        cos, sin = compute_freqs_cis(hs, maxT, theta=cnfg.rope_theta)
        assert not cos.requires_grad
        assert not sin.requires_grad

        freqs_cis = train_llama3.precompute_freqs_cis(
            dim=hs, end=T, theta=cnfg.rope_theta, use_scaled=False
        )
        assert not freqs_cis.requires_grad

        inp = torch.randn((B, T, C))
        inp.requires_grad = True
        cattn = train_llama3.CausalSelfAttention(cnfg)
        mask = torch.triu(
            torch.ones((T, T), device="cpu", dtype=torch.bool),
            diagonal=1,
        )
        cattno1 = cattn.forward(x=inp, freqs_cis=freqs_cis, mask=mask)

        ### backward
        dout = torch.randn_like(cattno1)
        cattno1.backward(dout)
        kv_dim = hs * n_kv_head
        n_rep = nh // n_kv_head

        inp1 = inp.detach()
        assert not inp1.requires_grad
        awq, awk, awv, awo = init_attention(cnfg, cattn)

        state = []
        cattno2 = attention(cnfg, inp1, awq.T, awk.T, awv.T, awo.T, cos, sin, state)
        assert len(state) == 14
        assert cattno1.shape == cattno2.shape
        assert torch.allclose(
            cattno1, cattno2, atol=1e-5
        ), f"""cattno1: {cattno1.view(-1)[:10]}
                                                            cattno2: {cattno2.view(-1)[:10]}"""
        print("\tForward pass matches")

        bawo = cattn.c_proj.weight.grad.T
        bawq, bawk, bawv = torch.split(
            cattn.c_attn.weight.grad, [C, kv_dim, kv_dim], dim=0
        )
        bawq, bawk, bawv = bawq.T, bawk.T, bawv.T

        dainp, dawq, dawk, dawv, dawo = attention_backward(
            dout, awq.T, awk.T, awv.T, awo.T, cos, sin, state
        )

        assert torch.allclose(dainp, inp.grad, atol=1e-4)
        assert torch.allclose(dawq, bawq, atol=1e-4)
        assert torch.allclose(dawk, bawk, atol=1e-4)
        assert torch.allclose(dawv, bawv, atol=1e-4)
        assert torch.allclose(dawo, bawo, atol=1e-4)
        print("        Backward pass matches\n")
        print("--------")
    print("PASS")

Checking:  LlamaConfig(version='3.1', block_size=64, vocab_size=20, n_layer=0, n_head=8, n_kv_head=4, n_embd=256, ffn_dim_multiplier=1.3, multiple_of=64, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=32, B=4)
	Forward pass matches
        Backward pass matches

--------
Checking:  LlamaConfig(version='3.1', block_size=64, vocab_size=20, n_layer=1, n_head=8, n_kv_head=4, n_embd=256, ffn_dim_multiplier=1.3, multiple_of=64, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=32, B=4)
	Forward pass matches
        Backward pass matches

--------
Checking:  LlamaConfig(version='3.1', block_size=64, vocab_size=20, n_layer=2, n_head=8, n_kv_head=4, n_embd=256, ffn_dim_multiplier=1.3, multiple_of=64, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=32, B=4)
	Forward pass matches
        Backward pass matches

----

### block

In [108]:
def block(
    config: LlamaConfig,
    inp: torch.Tensor,
    attn_norm_w: torch.Tensor,
    wq: torch.Tensor,
    wk: torch.Tensor,
    wv: torch.Tensor,
    wo: torch.Tensor,
    ff_norm_w: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w3: torch.Tensor,
    cos_theta,
    sin_theta,
    state,
):
    ln1x = rms_forward(config.norm_eps, inp, attn_norm_w, state)
    cprint("ln_1_: ", ln1x.view(-1)[-15:], debug=DEBUG)
    attn = attention(config, ln1x, wq, wk, wv, wo, cos_theta, sin_theta, state)
    cprint("attn_: ", attn.view(-1)[-15:], debug=DEBUG)
    h = inp + attn
    ln2x = rms_forward(config.norm_eps, h, ff_norm_w, state)
    cprint("ln_2_: ", ln2x.view(-1)[-15:], debug=DEBUG)
    o1 = ff(ln2x, w1, w2, w3, state)
    cprint("mlp_: ", o1.view(-1)[-15:], debug=DEBUG)
    h = h + o1
    cprint("f_res_: ", h.view(-1)[-15:], debug=DEBUG)
    return h


def block_backward(
    config: LlamaConfig,
    dout: torch.Tensor,
    attn_norm_w: torch.Tensor,
    wq: torch.Tensor,
    wk: torch.Tensor,
    wv: torch.Tensor,
    wo: torch.Tensor,
    ff_norm_w: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w3: torch.Tensor,
    cos_theta,
    sin_theta,
    state=[],
):
    (
        inp,
        attn_norm,
        norm_eps,
        a_inp,
        q,
        k,
        v,
        qr,
        kr,
        k_e,
        v_e,
        qp,
        kT,
        qkT,
        scores,
        o,
        vp,
        ff_inp,
        ff_norm,
        _,
        mlp_inp,
        h1,
        h3,
        h1silu,
        h1h3,
    ) = state
    dh1, do1 = dout, dout
    d_mlp_inp, dw1, dw2, dw3 = ff_backward(
        do1, w1, w2, w3, state=[mlp_inp, h1, h3, h1silu, h1h3]
    )
    d_ff_inp, d_ff_norm_w = rms_backward(
        d_mlp_inp, ff_norm_w, state=[ff_inp, ff_norm, norm_eps]
    )
    dh = dh1 + d_ff_inp
    dinp1, dattn = dh, dh
    d_a_inp, dwq, dwk, dwv, dwo = attention_backward(
        dh,
        wq,
        wk,
        wv,
        wo,
        cos_theta,
        sin_theta,
        state=[a_inp, q, k, v, qr, kr, k_e, v_e, qp, kT, qkT, scores, o, vp],
    )
    d_attn_norm_inp, d_attn_norm_w = rms_backward(
        d_a_inp, attn_norm_w, state=[inp, attn_norm, norm_eps]
    )
    return (
        dinp1 + d_attn_norm_inp,
        d_attn_norm_w,
        dwq,
        dwk,
        dwv,
        dwo,
        d_ff_norm_w,
        dw1,
        dw2,
        dw3,
    )

# init custom block from the llama3 block
def init_from_block(config: LlamaConfig, block: train_llama3.Block):

    B, T, nh, n_kv_head, C = (
        config.B,
        config.T,
        config.n_head,
        config.n_kv_head,
        config.n_embd,
    )
    hs = C // nh

    # attn norm
    ln_1 = block.ln_1.weight.detach().clone()
    assert not ln_1.requires_grad
    awq, awk, awv, awo = init_attention(config, block.attn)
    ln_2 = block.ln_2.weight.detach().clone()
    assert not ln_2.requires_grad
    w3 = block.mlp.c_fc.weight.detach().clone()
    w1 = block.mlp.c_fc2.weight.detach().clone()
    w2 = block.mlp.c_proj.weight.detach().clone()
    for tnsr in [w1, w2, w3]:
        assert not tnsr.requires_grad
    return ln_1, awq, awk, awv, awo, ln_2, w1, w2, w3


if  TEST:
    for cnfg in configs:
        print("\nChecking: ", cnfg)
        B, T, C, nh, n_kv_head, maxT = (
            cnfg.B,
            cnfg.T,
            cnfg.n_embd,
            cnfg.n_head,
            cnfg.n_kv_head,
            cnfg.block_size,
        )
        blk = train_llama3.Block(cnfg)
        ln_1, awq, awk, awv, awo, ln_2, w1, w2, w3 = init_from_block(cnfg, blk)
        # attn norm
        attn_norm_w = blk.ln_1.weight

        cattn = blk.attn
        mask = torch.triu(
            torch.ones((T, T), device="cpu", dtype=torch.bool),
            diagonal=1,
        )
        hs = C // nh
        cos, sin = compute_freqs_cis(hs, maxT, cnfg.rope_theta)
        freqs_cis = train_llama3.precompute_freqs_cis(
            dim=hs, end=T, theta=cnfg.rope_theta, use_scaled=False
        )

        kv_dim = hs * n_kv_head
        n_rep = nh // n_kv_head

        # ff norm
        ff_norm_w = blk.ln_2.weight

        # feed forward
        ffn = blk.mlp
        ffw1 = ffn.c_fc2.weight
        ffw2 = ffn.c_proj.weight
        ffw3 = ffn.c_fc.weight

        # output
        inp = torch.randn((B, T, C))
        inp.requires_grad = True
        blko1 = blk(x=inp, freqs_cis=freqs_cis, mask=mask)
        dout = torch.randn_like(inp)
        blko1.backward(dout)
        inp1 = inp.detach()
        inp1.requires_grad = False
        state = []
        
        blko2 = block(
            cnfg,
            inp1,
            ln_1,
            awq.T,
            awk.T,
            awv.T,
            awo.T,
            ln_2,
            w1.T,
            w2.T,
            w3.T,
            cos,
            sin,
            state,
        )
        assert torch.allclose(blko1, blko2, atol=1e-5)

        # backward
        dawq, dawk, dawv = torch.split(
            cattn.c_attn.weight.grad, [C, kv_dim, kv_dim], dim=0
        )
        dawq, dawk, dawv = dawq.T, dawk.T, dawv.T
        dawo = cattn.c_proj.weight.grad.T
        d_binp, d_attn_norm_w, dwq, dwk, dwv, dwo, d_ff_norm_w, dw1, dw2, dw3 = (
            block_backward(
                config3,
                dout,
                ln_1,
                awq.T,
                awk.T,
                awv.T,
                awo.T,
                ln_2,
                w1.T,
                w2.T,
                w3.T,
                cos,
                sin,
                state,
            )
        )
        assert torch.allclose(d_binp, inp.grad, atol=1e-4)
        assert torch.allclose(d_attn_norm_w, attn_norm_w.grad, atol=1e-4)
        assert torch.allclose(dwq, dawq, atol=1e-4)
        assert torch.allclose(dwk, dawk, atol=1e-4)
        assert torch.allclose(dwv, dawv, atol=1e-4)
        assert torch.allclose(dwo, dawo, atol=1e-4)
        assert torch.allclose(d_ff_norm_w, ff_norm_w.grad, atol=1e-4)
        assert torch.allclose(dw1, ffw1.grad.T, atol=1e-4)
        assert torch.allclose(dw2, ffw2.grad.T, atol=1e-4)
        assert torch.allclose(dw3, ffw3.grad.T, atol=1e-4)
    print("PASS")


Checking:  LlamaConfig(version='3.1', block_size=64, vocab_size=20, n_layer=0, n_head=8, n_kv_head=4, n_embd=256, ffn_dim_multiplier=1.3, multiple_of=64, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=32, B=4)

Checking:  LlamaConfig(version='3.1', block_size=64, vocab_size=20, n_layer=1, n_head=8, n_kv_head=4, n_embd=256, ffn_dim_multiplier=1.3, multiple_of=64, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=32, B=4)

Checking:  LlamaConfig(version='3.1', block_size=64, vocab_size=20, n_layer=2, n_head=8, n_kv_head=4, n_embd=256, ffn_dim_multiplier=1.3, multiple_of=64, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=32, B=4)

Checking:  LlamaConfig(version='3.1', block_size=10, vocab_size=20, n_layer=5, n_head=4, n_kv_head=2, n_embd=8, ffn_dim_multiplier=1.3, multiple_of=32, norm_eps=1e-05, rope_thet

## Adam Update

In [109]:
def adam_update(
    w: torch.tensor,
    dw: torch.tensor,
    m_dw: torch.tensor,
    var_dw: torch.tensor,
    learning_rate: float,
    beta1: float,
    beta2: float,
    eps: float,
    weight_decay: float,
    t: int,
):
    bias_correction1 = 1 - beta1**t
    bias_correction2 = 1 - beta2**t

    m_dw.mul_(beta1).add_(1 - beta1, dw)
    var_dw.mul_(beta2).addcmul_(1 - beta2, dw, dw)
    denom = (var_dw.sqrt() / math.sqrt(bias_correction2)).add_(eps)
    step_size = learning_rate / bias_correction1
    w.addcdiv_(-step_size, m_dw, denom)

## Classes
    Wraps the forward, backward functions into a single class
    Also update, to change the param values

In [111]:
class Encoder:
    def __init__(self, V: int, C: int, wte: torch.Tensor = None):
        self.wte = wte if wte is not None else torch.zeros(V, C)
        self.dwte = None
        self.m_dw, self.v_dw = None, None
        self.state = []

    def forward(self, tokens):
        return encoder_forward(self.wte, tokens, self.state)

    def backward(self, dout: torch.Tensor):
        self.dwte = encoder_back(dout=dout, state=self.state)
        self.state = []   # reset for the next forward pass

        return self.dwte

    def update(
        self,
        learning_rate: float,
        beta1: float,
        beta2: float,
        eps: float,
        weight_decay: float,
        iter: int,
    ):
        if self.v_dw is None or self.m_dw is None:
            self.m_dw = torch.zeros_like(self.dwte)
            self.v_dw = torch.zeros_like(self.dwte)
        adam_update(
            self.wte,
            self.dwte,
            self.m_dw,
            self.v_dw,
            learning_rate,
            beta1,
            beta2,
            eps,
            weight_decay,
            iter,
        )
        self.dwte = None


class RMS:
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.weight = torch.ones(config.n_embd)
        self.dweight = None
        self.state = []
        self.m_dw, self.v_dw = None, None

    def forward(self, inp: torch.Tensor):
        return rms_forward(self.config.norm_eps, inp, self.weight, self.state)

    def backward(self, dout: torch.Tensor):
        dinp, self.dweight = rms_backward(dout, self.weight, self.state)
        self.state = []
        return dinp

    def update(
        self,
        learning_rate: float,
        beta1: float,
        beta2: float,
        eps: float,
        weight_decay: float,
        iter: int,
    ):
        if self.m_dw is None or self.v_dw is None:
            self.m_dw = torch.zeros_like(self.weight)
            self.v_dw = torch.zeros_like(self.weight)
        adam_update(
            self.weight,
            self.dweight,
            self.m_dw,
            self.v_dw,
            learning_rate,
            beta1,
            beta2,
            eps,
            weight_decay,
            iter,
        )
        self.dweight = None


class Matmul:
    def __init__(self, isize: int, osize: int):
        super().__init__()
        self.weight = torch.zeros(isize, osize)
        self.state = []
        self.dweight = None
        self.m_dw, self.v_dw = None, None

    def forward(self, inp: torch.Tensor):
        return matmul(inp, self.weight, self.state)

    def backward(self, dout: torch.Tensor):
        dinp, self.dweight = matmul_backward(dout, self.weight, self.state)
        self.state = []
        return dinp

    def update(
        self,
        learning_rate: float,
        beta1: float,
        beta2: float,
        eps: float,
        weight_decay: float,
        iter: int,
    ):
        if self.m_dw is None or self.v_dw is None:
            self.m_dw = torch.zeros_like(self.weight)
            self.v_dw = torch.zeros_like(self.weight)
        adam_update(
            self.weight,
            self.dweight,
            self.m_dw,
            self.v_dw,
            learning_rate,
            beta1,
            beta2,
            eps,
            weight_decay,
            iter,
        )
        self.dweight = None


class Silu:
    def __init__(self):
        self.state = []

    def forward(self, inp: torch.Tensor):
        return silu(inp, self.state)

    def backward(self, dout: torch.Tensor):
        dinp = silu_backward(dout, self.state)
        self.state = []
        return dinp


class FeedForward:

    def __init__(self, config: LlamaConfig):
        self.state = []
        hidden_dim = 4 * config.n_embd
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if config.ffn_dim_multiplier is not None:
            hidden_dim = int(config.ffn_dim_multiplier * hidden_dim)
        hidden_dim = config.multiple_of * (
            (hidden_dim + config.multiple_of - 1) // config.multiple_of
        )
        OC = hidden_dim
        self.config = config
        C, hidden_dim = self.config.n_embd, OC
        self.w3 = torch.zeros(C, hidden_dim)  # c_fc
        self.w1 = torch.zeros(C, hidden_dim)  # c_fc2
        self.w2 = torch.zeros(hidden_dim, C)  # c_proj
        (
            self.dw1,
            self.m_dw1,
            self.v_dw1,
            self.dw2,
            self.m_dw2,
            self.v_dw2,
            self.dw3,
            self.m_dw3,
            self.v_dw3,
        ) = [None] * 9

    def forward(self, inp: torch.Tensor):
        return ff(inp, self.w1, self.w2, self.w3, self.state)

    def backward(self, dout: torch.Tensor):
        dinp, self.dw1, self.dw2, self.dw3 = ff_backward(
            dout, self.w1, self.w2, self.w3, self.state
        )
        self.state = []
        return dinp

    def update(
        self,
        learning_rate: float,
        beta1: float,
        beta2: float,
        eps: float,
        weight_decay: float,
        iter: int,
    ):
        if any(
            [
                t is None
                for t in [
                    self.m_dw1,
                    self.v_dw1,
                    self.m_dw2,
                    self.v_dw2,
                    self.m_dw3,
                    self.v_dw3,
                ]
            ]
        ):
            self.m_dw1 = torch.zeros_like(self.dw1)
            self.v_dw1 = torch.zeros_like(self.dw1)
            self.m_dw2 = torch.zeros_like(self.dw2)
            self.v_dw2 = torch.zeros_like(self.dw2)
            self.m_dw3 = torch.zeros_like(self.dw3)
            self.v_dw3 = torch.zeros_like(self.dw3)
        for w, dw, m_dw, v_dw in [
            (self.w1, self.dw1, self.m_dw1, self.v_dw1),
            (self.w2, self.dw2, self.m_dw2, self.v_dw2),
            (self.w3, self.dw3, self.m_dw3, self.v_dw3),
        ]:

            adam_update(
                w,
                dw,
                m_dw,
                v_dw,
                learning_rate,
                beta1,
                beta2,
                eps,
                weight_decay,
                iter,
            )
        self.dw1 = None
        self.dw2 = None
        self.dw3 = None


class Repeat:
    def __init__(self, config: LlamaConfig):
        self.state = []
        self.n_rep = config.n_rep

    def forward(self, inp: torch.Tensor):
        return repeat_kv(inp, self.n_rep, self.state)

    def backward(self, dout):
        dinp = repeat_kv_backward(dout, self.state)
        self.state = []
        return dinp


class Attention:
    
    def __init__(self, config, cos, sin):
        C, nh, n_kv_head = config.n_embd, config.n_head, config.n_kv_head
        kv_dim = (C // nh) * n_kv_head
        self.wq = torch.zeros(C, C)
        self.wk = torch.zeros(C, kv_dim)
        self.wv = torch.zeros(C, kv_dim)
        self.wo = torch.zeros(C, C)
        self.config = config
        self.state = []
        self.cos, self.sin = cos, sin

        (
            self.dwq,
            self.m_dwq,
            self.v_dwq,
            self.dwk,
            self.m_dwk,
            self.v_dwk,
            self.dwv,
            self.m_dwv,
            self.v_dwv,
            self.dwo,
            self.m_dwo,
            self.v_dwo,
        ) = [None] * 12

    def forward(self, inp: torch.Tensor):
        return attention(
            self.config,
            inp,
            self.wq,
            self.wk,
            self.wv,
            self.wo,
            self.cos,
            self.sin,
            self.state,
        )

    def backward(self, dout):
        dinp, self.dwq, self.dwk, self.dwv, self.dwo = attention_backward(
            dout, self.wq, self.wk, self.wv, self.wo, self.cos, self.sin, self.state
        )
        self.state = []
        return dinp

    def update(
        self,
        learning_rate: float,
        beta1: float,
        beta2: float,
        eps: float,
        weight_decay: float,
        iter: int,
    ):
        if any(
            [
                t is None
                for t in [
                    self.m_dwq,
                    self.v_dwq,
                    self.m_dwk,
                    self.v_dwk,
                    self.m_dwv,
                    self.v_dwv,
                    self.m_dwo,
                    self.v_dwo,
                ]
            ]
        ):
            self.m_dwq = torch.zeros_like(self.dwq)
            self.v_dwq = torch.zeros_like(self.dwq)
            self.m_dwk = torch.zeros_like(self.dwk)
            self.v_dwk = torch.zeros_like(self.dwk)
            self.m_dwv = torch.zeros_like(self.dwv)
            self.v_dwv = torch.zeros_like(self.dwv)
            self.m_dwo = torch.zeros_like(self.dwo)
            self.v_dwo = torch.zeros_like(self.dwo)
        for w, dw, m_dw, v_dw in [
            (self.wq, self.dwq, self.m_dwq, self.v_dwq),
            (self.wk, self.dwk, self.m_dwk, self.v_dwk),
            (self.wv, self.dwv, self.m_dwv, self.v_dwv),
            (self.wo, self.dwo, self.m_dwo, self.v_dwo),
        ]:
            adam_update(
                w,
                dw,
                m_dw,
                v_dw,
                learning_rate,
                beta1,
                beta2,
                eps,
                weight_decay,
                iter,
            )
        self.dwq = None
        self.dwk = None
        self.dwv = None
        self.dwo = None


class MyBlock:

    def __init__(self, config: LlamaConfig, cos, sin):
        self.ln_1 = RMS(config)
        self.attention = Attention(config, cos, sin)
        self.ln_2 = RMS(config)
        self.ff = FeedForward(config)
        self.config = config
        self.state = []

    def forward_(self, inp: torch.Tensor):
        h = inp + self.attention.forward(self.ln_1.forward(inp))
        return h + self.ff.forward(self.ln_2.forward(h))

    def forward(self, inp: torch.Tensor):
        ln1x = self.ln_1.forward(inp)
        self.state.extend(self.ln_1.state)
        cprint("ln_1_: ", ln1x.view(-1)[-15:], debug=DEBUG)
        attn = self.attention.forward(ln1x)
        self.state.extend(self.attention.state)
        cprint("attn_: ", attn.view(-1)[-15:], debug=DEBUG)
        h = inp + attn
        ln2x = self.ln_2.forward(h)
        self.state.extend(self.ln_2.state)
        cprint("ln_2_: ", ln2x.view(-1)[-15:], debug=DEBUG)
        mlp = self.ff.forward(ln2x)
        self.state.extend(self.ff.state)
        h = h + mlp
        cprint("f_res_: ", h.view(-1)[-15:], debug=DEBUG)
        return h

    def backward(self, dout):

        dh1, do1 = dout, dout
        d_mlp_inp = self.ff.backward(do1)
        d_ff_inp = self.ln_2.backward(d_mlp_inp)
        dh = dh1 + d_ff_inp
        dinp1, dattn = dh, dh
        d_a_inp = self.attention.backward(dh)
        d_attn_norm_inp = self.ln_1.backward(d_a_inp)
        return dinp1 + d_attn_norm_inp

    def update(
        self,
        learning_rate: float,
        beta1: float,
        beta2: float,
        eps: float,
        weight_decay: float,
        iter: int,
    ):
        self.ln_1.update(learning_rate, beta1, beta2, eps, weight_decay, iter)
        self.attention.update(learning_rate, beta1, beta2, eps, weight_decay, iter)
        self.ln_2.update(learning_rate, beta1, beta2, eps, weight_decay, iter)
        self.ff.update(learning_rate, beta1, beta2, eps, weight_decay, iter)


class Transformer:
    def __init__(self, config: LlamaConfig):
        self.config = config
        self.encoder = Encoder(V=config.vocab_size, C=config.n_embd)
        self.output = Matmul(isize=config.n_embd, osize=config.vocab_size)
        self.loss_state = []

        hs = config.n_embd // config.n_head
        self.cos, self.sin = compute_freqs_cis(hs, config.block_size, config.rope_theta)
        self.blocks = [
            MyBlock(config, self.cos, self.sin) for l in range(config.n_layer)
        ]
        self.ln_f = RMS(config)

    def forward(self, inp: torch.Tensor, targets: torch.Tensor = None):
        B, T = targets.shape
        h = self.encoder.forward(inp)
        cprint(f"encoder_: ", h.view(-1)[-15:], debug=DEBUG)

        for l in range(self.config.n_layer):
            cprint(l, debug=DEBUG)
            h = self.blocks[l].forward(h)

        h = self.ln_f.forward(h)
        cprint("ln_f_: ", h.view(-1)[-15:], debug=DEBUG)
        loss = None
        if targets is not None:
            h = self.output.forward(h)  # B,T,V
            # sm_h = torch.log(torch.softmax(h, -1))
            # loss = -torch.gather(sm_h, -1, targets.view(B, T, 1)).sum() / (B * T)
            loss = cross_entropy_forward(h, targets, self.loss_state)
            assert len(self.loss_state) == 3
        cprint("logits_: ", h.view(-1)[-15:], debug=DEBUG)
        if targets is None:
            h = h[:, [-1], :]

        return h, loss

    def backward(self, dout: torch.Tensor, is_loss=False):
        if not is_loss:
            # dout : b,1,v (logits)
            dh = self.output.backward(dout)
            drms = self.ln_f.backward(dh)
            dbo = drms
            for i in range(self.config.n_layer - 1, -1, -1):
                dbo = self.blocks[i].backward(dbo)

            self.encoder.backward(dbo)
        else:
            # dout is from the scalar loss
            dh = cross_entropy_backward(None, self.loss_state)
            self.backward(dh, False)
        self.loss_state = []

    def update(
        self,
        learning_rate: float,
        beta1: float,
        beta2: float,
        eps: float,
        weight_decay: float,
        iter: int,
    ):
        self.encoder.update(learning_rate, beta1, beta2, eps, weight_decay, iter)
        self.ln_f.update(learning_rate, beta1, beta2, eps, weight_decay, iter)
        self.output.update(learning_rate, beta1, beta2, eps, weight_decay, iter)

        for block in self.blocks:
            block.update(learning_rate, beta1, beta2, eps, weight_decay, iter)

## Utils
    -- Help to load llama3 components (in train_llama3.py) into our custom component
    -- Help with testing

In [112]:
def init_myblock_from_block(
    block: train_llama3.Block, config: LlamaConfig, assert_shape=True
):

    B, T, nh, n_kv_head, C = (
        config.B,
        config.T,
        config.n_head,
        config.n_kv_head,
        config.n_embd,
    )
    hs = C // nh
    cos, sin = compute_freqs_cis(hs, config.block_size, config.rope_theta)
    myBlock = MyBlock(config, cos=cos, sin=sin)

    ln_1, awq, awk, awv, awo, ln_2, w1, w2, w3 = init_from_block(config, block) # clone the params in llama3 block
    for tnsr in [ln_1, awq, awk, awv, awo, ln_2, w1, w2, w3]:
        assert not tnsr.requires_grad

    # attn norm
    if assert_shape:
        assert myBlock.ln_1.weight.shape == block.ln_1.weight.shape
    myBlock.ln_1.weight[...] = ln_1  # block.ln_1.weight.detach()
    assert not myBlock.ln_1.weight.requires_grad

    # attn
    cattn = block.attn
    mask = torch.triu(
        torch.ones((T, T), device="cpu", dtype=torch.bool),
        diagonal=1,
    )
    hs = C // nh
    kv_dim = hs * n_kv_head
    n_rep = nh // n_kv_head

    if assert_shape:
        assert myBlock.attention.wq.shape == awq.T.shape
    myBlock.attention.wq = awq.T  # awq.detach().T

    assert not myBlock.attention.wq.requires_grad

    if assert_shape:
        assert myBlock.attention.wk.shape == awk.T.shape
    myBlock.attention.wk = awk.T  # awk.detach().T
    assert not myBlock.attention.wk.requires_grad

    if assert_shape:
        assert myBlock.attention.wv.shape == awv.T.shape
    myBlock.attention.wv = awv.T  # awv.detach().T
    assert not myBlock.attention.wv.requires_grad

    if assert_shape:
        assert myBlock.attention.wo.shape == awo.T.shape
    myBlock.attention.wo = awo.T  # awo.detach().T
    assert not myBlock.attention.wo.requires_grad

    if assert_shape:
        assert myBlock.ln_2.weight.shape == block.ln_2.weight.shape
    oldptr = myBlock.ln_2.weight._cdata
    myBlock.ln_2.weight = ln_2  # block.ln_2.weight.detach()
    assert not myBlock.ln_2.weight.requires_grad

    if assert_shape:
        assert myBlock.ff.w3.shape == block.mlp.c_fc.weight.transpose(0, 1).shape, (
            myBlock.ff.w3.shape,
            block.mlp.c_fc.weight.transpose(0, 1).shape,
        )
    myBlock.ff.w3 = w3.T  # block.mlp.c_fc.weight.detach().transpose(0,1)
    assert not myBlock.ff.w3.requires_grad

    if assert_shape:
        assert myBlock.ff.w1.shape == block.mlp.c_fc2.weight.transpose(0, 1).shape
    myBlock.ff.w1 = w1.T  # block.mlp.c_fc2.weight.detach().transpose(0,1)
    assert not myBlock.ff.w1.requires_grad

    if assert_shape:
        assert myBlock.ff.w2.shape == block.mlp.c_proj.weight.transpose(0, 1).shape
    myBlock.ff.w2 = w2.T  # block.mlp.c_proj.weight.detach().transpose(0,1)

    assert not myBlock.ff.w2.requires_grad
    return myBlock

### test_block_forward, test_block_backward

In [113]:
def test_block_forward(blk: train_llama3.Block, myblk: MyBlock, bdinp=None, inp=None):
    # ff
    assert torch.allclose(myblk.ff.w1, blk.mlp.c_fc2.weight.T, atol=1e-4), (
        myblk.ff.w1.reshape(-1)[:10],
        blk.mlp.c_fc2.weight.grad.reshape(-1)[:10],
    )
    assert torch.allclose(myblk.ff.w2, blk.mlp.c_proj.weight.T, atol=1e-4), (
        myblk.ff.w2.reshape(-1)[:10],
        blk.mlp.c_proj.weight.T.reshape(-1)[:10],
    )
    assert torch.allclose(myblk.ff.w3, blk.mlp.c_fc.weight.T, atol=1e-4), (
        myblk.ff.w3.reshape(-1)[:10],
        blk.mlp.c_fc.weight.T.reshape(-1)[:10],
    )

    ## ff norm
    assert torch.allclose(myblk.ln_2.weight, blk.ln_2.weight, atol=1e-4), (
        myblk.ln_2.weight.reshape(-1)[:10],
        blk.ln_2.weight.reshape(-1)[:10],
    )

    # attention
    wq, wk, wv, wo = (
        myblk.attention.wq.T,
        myblk.attention.wk.T,
        myblk.attention.wv.T,
        myblk.attention.wo.T,
    )
    attention_weight = torch.concat([wq, wk, wv], dim=0)
    # assert attention_weight.shape == (C+2*kv_dim, C)
    assert torch.allclose(attention_weight, blk.attn.c_attn.weight, atol=1e-4), (
        attention_weight.reshape(-1)[:10],
        blk.attn.c_attn.weight.reshape(-1)[:10],
    )
    assert torch.allclose(wo, blk.attn.c_proj.weight, atol=1e-4), (
        wo.reshape(-1)[:10],
        blk.attn.c_proj.weight.reshape(-1)[:10],
    )

    # attn norm
    assert torch.allclose(myblk.ln_1.weight, blk.ln_1.weight, atol=1e-4), (
        myblk.ln_1.weight.view(-1)[:10],
        blk.ln_1.weight.view(-1)[:10],
    )


def test_block_backwards(blk, myblk, bdinp=None, inp=None):
    # ff
    assert torch.allclose(myblk.ff.dw1, blk.mlp.c_fc2.weight.grad.T, atol=1e-4), (
        myblk.ff.dw1.view(-1)[:10],
        blk.mlp.c_fc2.weight.grad.T.reshape(-1)[:10],
    )
    assert torch.allclose(myblk.ff.dw2, blk.mlp.c_proj.weight.grad.T, atol=1e-4), (
        myblk.ff.dw2.view(-1)[:10],
        blk.mlp.c_proj.weight.grad.T.reshape(-1)[:10],
    )
    assert torch.allclose(myblk.ff.dw3, blk.mlp.c_fc.weight.grad.T, atol=1e-4), (
        myblk.ff.dw3.view(-1)[:10],
        blk.mlp.c_fc.weight.grad.T.reshape(-1)[:10],
    )

    ## ff norm
    assert torch.allclose(myblk.ln_2.dweight, blk.ln_2.weight.grad, atol=1e-4), (
        myblk.ln_2.dweight.view(-1)[:10],
        blk.ln_2.weight.grad.view(-1)[:10],
    )

    # attention
    dwq, dwk, dwv, dwo = (
        myblk.attention.dwq.T,
        myblk.attention.dwk.T,
        myblk.attention.dwv.T,
        myblk.attention.dwo.T,
    )
    attention_weight = torch.concat([dwq, dwk, dwv], dim=0)
    # assert attention_weight.shape == (C+2*kv_dim, C)
    assert torch.allclose(attention_weight, blk.attn.c_attn.weight.grad, atol=1e-4), (
        attention_weight.view(-1)[:10],
        blk.attn.c_attn.weight.grad.reshape(-1)[:10],
    )
    assert torch.allclose(dwo, blk.attn.c_proj.weight.grad, atol=1e-4), (
        dwo.reshape(-1)[:10],
        blk.attn.c_proj.weight.grad.reshape(-1)[:10],
    )

    # attn norm
    assert torch.allclose(myblk.ln_1.dweight, blk.ln_1.weight.grad, atol=1e-4), (
        myblk.ln_1.dweight.view(-1)[:10],
        blk.ln_1.weight.grad.view(-1)[:10],
    )

    # inp
    if bdinp is not None and inp.grad is not None:
        assert torch.allclose(bdinp, inp.grad, atol=1e-4)
        print("inp grads also match")


if TEST:
    for cnfg in configs:
        print(
            "\nChecking: ",
            cnfg,
        )
        inp = torch.randn((cnfg.B, cnfg.T, cnfg.n_embd))
        mask = torch.triu(
            torch.ones((cnfg.T, cnfg.T), device="cpu", dtype=torch.bool),
            diagonal=1,
        )
        inp.requires_grad = True
        hs = cnfg.n_embd // cnfg.n_head
        freqs_cis = train_llama3.precompute_freqs_cis(
            dim=hs, end=cnfg.T, theta=cnfg.rope_theta, use_scaled=False
        )

        blk = train_llama3.Block(cnfg)
        blko1 = blk(x=inp, freqs_cis=freqs_cis, mask=mask)
        dout = torch.randn_like(blko1)
        blko1.backward(dout)
        inp1 = inp.detach()
        myblk = init_myblock_from_block(blk, cnfg)
        assert torch.allclose(myblk.forward(inp1), blko1, atol=1e-6)
        for tnsr in myblk.state:
            if hasattr(tnsr, "detach"):
                tnsr = tnsr.detach()
        bdinp = myblk.backward(dout)
        test_block_backwards(blk, myblk, bdinp, inp)
    print("PASS")


Checking:  LlamaConfig(version='3.1', block_size=64, vocab_size=20, n_layer=0, n_head=8, n_kv_head=4, n_embd=256, ffn_dim_multiplier=1.3, multiple_of=64, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=32, B=4)
inp grads also match

Checking:  LlamaConfig(version='3.1', block_size=64, vocab_size=20, n_layer=1, n_head=8, n_kv_head=4, n_embd=256, ffn_dim_multiplier=1.3, multiple_of=64, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=32, B=4)
inp grads also match

Checking:  LlamaConfig(version='3.1', block_size=64, vocab_size=20, n_layer=2, n_head=8, n_kv_head=4, n_embd=256, ffn_dim_multiplier=1.3, multiple_of=64, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=32, B=4)
inp grads also match

Checking:  LlamaConfig(version='3.1', block_size=10, vocab_size=20, n_layer=5, n_head=4, n_kv_head=2, n_embd=8, ff

### init_mymodel_from_LLama
    -- use the llama3 model to initilise our custom model

In [114]:
def init_mymodel_from_LLama(config: LlamaConfig, model=None, assert_shape=True):
    mymodel = Transformer(config)
    if model is None:
        model = train_llama3.LLaMA(config)

    transformer = model.transformer

    # wte
    assert mymodel.encoder.wte.shape == transformer["wte"].weight.shape
    mymodel.encoder.wte = transformer["wte"].weight.detach().clone()
    config.vocab_size, config.n_embd = mymodel.encoder.wte.shape

    assert mymodel.output.weight.shape == model.lm_head.weight.T.shape, (
        mymodel.output.weight.shape,
        model.lm_head.weight.T.shape,
    )
    mymodel.output.weight = model.lm_head.weight.detach().clone().T  # (C,V)

    for l in range(config.n_layer):
        mymodel.blocks[l] = init_myblock_from_block(
            transformer["h"][l], config, assert_shape
        )

    mymodel.ln_f.weight = transformer["ln_f"].weight.detach().clone()
    return mymodel, model

### test Update
    - -compare the params after a forward and update

In [115]:
config3

LlamaConfig(version='3.1', block_size=10, vocab_size=20, n_layer=5, n_head=4, n_kv_head=2, n_embd=8, ffn_dim_multiplier=1.3, multiple_of=32, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=4, B=2)

In [116]:
cnfg = config3
cnfg.n_layer = 5
betas = (0.9, 0.99)
lr = 0.005124

# test_block_forward(blk, myblk)
torch.random.manual_seed(1)
mymodel, model = init_mymodel_from_LLama(cnfg)
B, T, C = cnfg.B, cnfg.T, cnfg.n_embd
tokens = torch.randint(cnfg.vocab_size, (B, T))
targets = tokens.clone()
optimiser = torch.optim.Adam(model.parameters(), lr, betas, cnfg.norm_eps, 0.0)

for n in range(3):
    optimiser.zero_grad()
    ot1, loss1 = model(tokens, targets, return_logits=True)
    dout = torch.randn_like(ot1)
    ot1.backward(dout)

    tokens1 = tokens.detach()
    targets1 = targets.detach()
    ot2, loss2 = mymodel.forward(tokens1, targets1)

    assert torch.allclose(
        mymodel.ln_f.weight, model.transformer["ln_f"].weight, atol=1e-3
    ), (
        n,
        mymodel.ln_f.weight.view(-1)[:10],
        model.transformer["ln_f"].weight.view(-1)[:10],
    )
    assert torch.allclose(mymodel.output.weight, model.lm_head.weight.T, atol=1e-3), (
        n,
        mymodel.output.weight.view(-1)[:10],
        model.lm_head.weight.T.reshape(-1)[:10],
    )

    for i, blk in enumerate(model.transformer["h"]):
        myblk = mymodel.blocks[i]
        test_block_forward(blk, myblk)
    assert torch.allclose(ot1, ot2, atol=1e-3), (
        ot1.view(-1)[:10],
        ot2.view(-1)[:10],
        (n, ot1.view(-1)[-10:], ot2.view(-1)[-10:]),
    )
    assert torch.allclose(loss1, loss2, atol=1e-3)
    # backward
    mymodel.backward(dout, False)
    optimiser.step()
    mymodel.update(lr, betas[0], betas[1], cnfg.norm_eps, 0.0, n + 1)
    print(f"All tests pass: {n}")
print("Done")

All tests pass: 0
All tests pass: 1
All tests pass: 2
Done


## Classes forward+backward

In [123]:
if TEST:
    print(len(configs))
    for cnfg in configs:
        print(cnfg)
        B, T, C, V = cnfg.B, cnfg.T, cnfg.n_embd, cnfg.vocab_size

        # encoder
        enc = nn.Embedding(V, C)
        tokens = torch.randint(V, (B, T))
        enco1 = enc(tokens)
        dout = torch.randn_like(enco1)
        enco1.backward(dout, inputs=[enc.weight])

        tokens = tokens.detach()
        enc2 = Encoder(V, C, enc.weight.detach())
        enco2 = enc2.forward(tokens)
        assert enco1.shape == enco2.shape
        assert torch.allclose(enco1, enco2)
        assert torch.allclose(enc2.state[0], tokens)

        dwte = enc2.backward(dout)
        assert torch.allclose(dwte, enc.weight.grad)

        # output
        inp = torch.randn(B, T, C)
        o1 = nn.Linear(in_features=C, out_features=V, bias=False)
        o2 = Matmul(isize=C, osize=V)

        o2.weight[...] = o1.weight.detach().T
        o1o = o1(inp)
        dout = torch.randn_like(o1o)
        o1o.backward(dout)
        inp1 = inp.detach()

        o2o = o2.forward(inp1)
        dinp = o2.backward(dout)
        assert torch.allclose(o1o, o2o, atol=1e-4)
        assert torch.allclose(o2.dweight, o1.weight.grad.T, atol=1e-4)

        # rms
        inp = torch.randn(B, T, C)
        inp.requires_grad = True
        rms = train_llama3.RMSNorm(C)
        o1 = rms(inp)
        dout = torch.randn_like(o1)
        o1.backward(dout)
        inp1 = inp.detach()

        ro2 = RMS(cnfg)
        ro2.weight = rms.weight.detach()
        o2 = ro2.forward(inp1)
        assert torch.allclose(o1, o2, atol=1e-4), (o1.view(-1)[-10:], o2.view(-1)[-10:])

        # backward
        dinp = ro2.backward(dout)
        assert torch.allclose(
            ro2.dweight, rms.weight.grad, atol=1e-3
        ), f"{dw1[:5]},{rms.weight.grad[:5]}"
        assert torch.allclose(
            dinp, inp.grad, atol=1e-3
        ), f"{dinp.view(-1)[:10]}, {inp.grad.view(-1)[:10]}"
print("PASS")

5
LlamaConfig(version='3.1', block_size=64, vocab_size=20, n_layer=0, n_head=8, n_kv_head=4, n_embd=256, ffn_dim_multiplier=1.3, multiple_of=64, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=32, B=4)
LlamaConfig(version='3.1', block_size=64, vocab_size=20, n_layer=1, n_head=8, n_kv_head=4, n_embd=256, ffn_dim_multiplier=1.3, multiple_of=64, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=32, B=4)
LlamaConfig(version='3.1', block_size=64, vocab_size=20, n_layer=2, n_head=8, n_kv_head=4, n_embd=256, ffn_dim_multiplier=1.3, multiple_of=64, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=32, B=4)
LlamaConfig(version='3.1', block_size=10, vocab_size=20, n_layer=5, n_head=4, n_kv_head=2, n_embd=8, ffn_dim_multiplier=1.3, multiple_of=32, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_bat

## Test Model forward backward
    -- Test params after forward and backward pass

In [124]:
if TEST:
    for cnfg in configs:
        print("\nChecking: ", cnfg)

        mymodel, model = init_mymodel_from_LLama(cnfg)   # model : llama3, mymodel: custom llama3
        B, T, C = cnfg.B, cnfg.T, cnfg.n_embd
        tokens = torch.randint(cnfg.vocab_size, (B, T))
        targets = tokens.clone()
        ot1, loss1 = model(tokens, targets, return_logits=True)
        dout = torch.randn_like(ot1)
        ot1.backward(dout)

        tokens1 = tokens.detach()
        targets1 = targets.detach()
        ot2, loss2 = mymodel.forward(tokens1, targets1)
        assert torch.allclose(ot1, ot2, atol=1e-4), (
            ot1.view(-1)[:10],
            ot2.view(-1)[:10],
            (ot1.view(-1)[-10:], ot2.view(-1)[-10:]),
        )
        assert torch.allclose(loss1, loss2, atol=1e-4)
        # backward
        mymodel.backward(dout)

        assert torch.allclose(
            mymodel.encoder.dwte, model.transformer["wte"].weight.grad, atol=1e-4
        ), (
            mymodel.encoder.dwte.view(-1)[:10],
            model.transformer["wte"].weight.grad.view(-1)[:10],
        )
        assert torch.allclose(
            mymodel.ln_f.dweight, model.transformer["ln_f"].weight.grad, atol=1e-4
        ), (
            mymodel.ln_f.dweight.view(-1)[:10],
            model.transformer["ln_f"].weight.grad.view(-1)[:10],
        )
        assert torch.allclose(
            mymodel.output.dweight, model.lm_head.weight.grad.T, atol=1e-3
        ), (
            mymodel.output.dweight.view(-1)[:10],
            model.lm_head.weight.grad.T.reshape(-1)[:10],
        )
        for i, blk in enumerate(model.transformer["h"]):
            myblk = mymodel.blocks[i]
            test_block_backwards(blk, myblk)
    print("PASS")


Checking:  LlamaConfig(version='3.1', block_size=64, vocab_size=20, n_layer=0, n_head=8, n_kv_head=4, n_embd=256, ffn_dim_multiplier=1.3, multiple_of=64, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=32, B=4)

Checking:  LlamaConfig(version='3.1', block_size=64, vocab_size=20, n_layer=1, n_head=8, n_kv_head=4, n_embd=256, ffn_dim_multiplier=1.3, multiple_of=64, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=32, B=4)

Checking:  LlamaConfig(version='3.1', block_size=64, vocab_size=20, n_layer=2, n_head=8, n_kv_head=4, n_embd=256, ffn_dim_multiplier=1.3, multiple_of=64, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=32, B=4)

Checking:  LlamaConfig(version='3.1', block_size=10, vocab_size=20, n_layer=5, n_head=4, n_kv_head=2, n_embd=8, ffn_dim_multiplier=1.3, multiple_of=32, norm_eps=1e-05, rope_thet

## Testing loss for the different configs

In [125]:
from typing import Union

def train_model(
    model: Union[train_llama3.LLaMA, Transformer],
    inputs: torch.tensor,
    targets: torch.tensor,
    n_train_steps=10,
):
    optimiser = None
    if isinstance(model, train_llama3.LLaMA):
        forward = lambda: model.forward(inputs, targets, True, 0)
        optimiser = torch.optim.Adam(model.parameters(), lr, betas, eps, weight_decay)
        update = lambda iter: (optimiser.step(), optimiser.zero_grad())

    else:
        forward = lambda: model.forward(inputs, targets)
        update = lambda iter: model.update(
            lr, betas[0], betas[1], eps, weight_decay, iter
        )

    logits_loss = []
    for n in range(n_train_steps):
        if isinstance(model, train_llama3.LLaMA):
            optimiser.zero_grad()
        lg, loss = forward()
        if isinstance(model, train_llama3.LLaMA):
            loss.backward()  # get the grad

        else:
            model.backward(None, True)
        update(n + 1)
        logits_loss.append((lg, loss))
    return logits_loss


from typing import Union
from collections import defaultdict
from tqdm import tqdm

def train_models(
    llama_model: train_llama3.LLaMA,
    mymodel: Transformer,
    inputs: torch.tensor,
    targets: torch.tensor,
    n_train_steps=10,
    params={},
):
    eps = params["eps"]
    lr = params["lr"]
    betas = params["betas"]
    weight_decay = 0.0
    optimiser = torch.optim.Adam(llama_model.parameters(), lr, betas, eps, weight_decay)

    inputs1 = inputs.detach().clone()
    targets1 = targets.detach().clone()
    logits_loss = defaultdict(list)
    for n in tqdm(range(n_train_steps), desc="Training"):
        optimiser.zero_grad()

        # forward + backward
        lg1, loss1 = mymodel.forward(inputs1, targets1)
        lg, loss = llama_model.forward(inputs, targets, True)

        mymodel.backward(None, True)  # get the grad
        loss.backward()  # get the grad

        logits_loss["llama"].append((lg, loss))
        logits_loss["myllama"].append((lg1, loss1))

        optimiser.step()
        mymodel.update(lr, betas[0], betas[1], eps, weight_decay, n + 1)
        cprint(f"Done trainig step: {n}", debug=DEBUG)

    return logits_loss

In [126]:
new_config = LlamaConfig(
    version="3.1",
    block_size=200,
    vocab_size=100,
    n_layer=12,
    n_head=12,
    n_kv_head=4,
    n_embd=48,
    ffn_dim_multiplier=1.3,
    multiple_of=32,
    norm_eps=1e-05,
    rope_theta=500000.0,
    use_scaled_rope=False,
    max_gen_batch_size=4,
    use_kv=True,
    flash=False,
    T=40,
    B=10,
)

### Train models (custom llama3, original llama3) and compare the losses, logits

In [127]:
train_steps = 50
lr = 0.001
betas = (0.9, 0.99)
weight_decay = 0.0
eps = 1e-8

DEBUG=False
import gc
torch.random.manual_seed(1)
for cnfg in [new_config]+configs:
    B, T, V = cnfg.B, cnfg.T, cnfg.vocab_size
    inputs = torch.randint(0, V, (B, T))
    targets = torch.randint(0, V, (B, T))
    
    inputs1 = inputs.clone()
    targets1 = targets.clone()
    
    my_llm, llm = init_mymodel_from_LLama(cnfg)
    print(cnfg)
    params = {"lr": lr, "betas": betas, "eps": eps}
    if V >= 500:
        loss_logits = train_models(llm, my_llm, inputs, targets, 10, params)  # to save time
    else:
        loss_logits = train_models(llm, my_llm, inputs, targets, train_steps, params)
    llama_loss = torch.tensor([ls.item() for _, ls in loss_logits["llama"]])
    my_loss = torch.tensor([[ls for _, ls in loss_logits["myllama"]]])
    assert torch.allclose(llama_loss, my_loss, atol=1e-4), (
        llama_loss,
        "\n",
        my_loss,
        llama_loss.sum(),
        my_loss.sum(),
    )
    print("mymodel_loss:\n", my_loss)
    print("ref_model_loss:\n", llama_loss)
    print(f"Training results match: {new_config}")
    print(params)
    del my_llm
    del llm
    gc.collect()
    print('------------')

LlamaConfig(version='3.1', block_size=200, vocab_size=100, n_layer=12, n_head=12, n_kv_head=4, n_embd=48, ffn_dim_multiplier=1.3, multiple_of=32, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=40, B=10)


Training: 100%|████████████████████████████| 50/50 [01:38<00:00,  1.97s/it]


mymodel_loss:
 tensor([[4.7532368, 4.5702333, 4.4022155, 4.2488265, 4.1059370, 3.9710541,
         3.8442869, 3.7255216, 3.6128504, 3.5043333, 3.3995423, 3.2988062,
         3.2018702, 3.1077132, 3.0154085, 2.9245405, 2.8347578, 2.7454476,
         2.6561217, 2.5668652, 2.4777164, 2.3878181, 2.2968245, 2.2051294,
         2.1127300, 2.0204105, 1.9284750, 1.8373041, 1.7465363, 1.6566626,
         1.5686325, 1.4831506, 1.3998740, 1.3189836, 1.2409533, 1.1672759,
         1.0989659, 1.0342294, 0.9672856, 0.9068358, 0.8533952, 0.7996473,
         0.7520568, 0.7071263, 0.6667799, 0.6322387, 0.5955911, 0.5605782,
         0.5329599, 0.5026795]])
ref_model_loss:
 tensor([4.7532368, 4.5702333, 4.4022155, 4.2488270, 4.1059370, 3.9710541,
        3.8442862, 3.7255208, 3.6128504, 3.5043328, 3.3995419, 3.2988060,
        3.2018702, 3.1077132, 3.0154085, 2.9245405, 2.8347580, 2.7454481,
        2.6561215, 2.5668652, 2.4777167, 2.3878183, 2.2968245, 2.2051296,
        2.1127303, 2.0204108, 1.9284754

Training: 100%|████████████████████████████| 50/50 [01:16<00:00,  1.53s/it]


mymodel_loss:
 tensor([[3.1669021, 3.0779817, 2.9925590, 2.9107370, 2.8326106, 2.7582567,
         2.6877360, 2.6210961, 2.5583651, 2.4995492, 2.4446223, 2.3935208,
         2.3461347, 2.3023086, 2.2618463, 2.2245216, 2.1900916, 2.1583102,
         2.1289387, 2.1017573, 2.0765696, 2.0532064, 2.0315263, 2.0114117,
         1.9927657, 1.9755046, 1.9595535, 1.9448404, 1.9312953, 1.9188451,
         1.9074154, 1.8969290, 1.8873073, 1.8784715, 1.8703456, 1.8628561,
         1.8559358, 1.8495258, 1.8435745, 1.8380411, 1.8328925, 1.8281031,
         1.8236512, 1.8195184, 1.8156873, 1.8121387, 1.8088526, 1.8058076,
         1.8029820, 1.8003536]])
ref_model_loss:
 tensor([3.1669018, 3.0779815, 2.9925592, 2.9107370, 2.8326106, 2.7582567,
        2.6877358, 2.6210961, 2.5583653, 2.4995492, 2.4446223, 2.3935206,
        2.3461344, 2.3023083, 2.2618463, 2.2245216, 2.1900918, 2.1583099,
        2.1289384, 2.1017575, 2.0765693, 2.0532064, 2.0315263, 2.0114117,
        1.9927659, 1.9755046, 1.9595534

Training: 100%|████████████████████████████| 50/50 [04:16<00:00,  5.13s/it]


mymodel_loss:
 tensor([[3.2379425e+00, 2.8768456e+00, 2.5534835e+00, 2.2589612e+00,
         2.0057635e+00, 1.8013994e+00, 1.6338880e+00, 1.4878640e+00,
         1.3612552e+00, 1.2513376e+00, 1.1428984e+00, 1.0254030e+00,
         9.0398687e-01, 7.8603733e-01, 6.7419517e-01, 5.7092136e-01,
         4.7738323e-01, 3.9174253e-01, 3.1812239e-01, 2.5913003e-01,
         2.0987350e-01, 1.6921929e-01, 1.3559230e-01, 1.0618776e-01,
         8.0252700e-02, 5.8618888e-02, 4.3007713e-02, 3.1897701e-02,
         2.3648763e-02, 1.7863201e-02, 1.4075876e-02, 1.1427311e-02,
         9.6427789e-03, 8.3526038e-03, 7.3383767e-03, 6.4737736e-03,
         5.7067149e-03, 5.0379494e-03, 4.4730594e-03, 3.9986628e-03,
         3.5951775e-03, 3.2496804e-03, 2.9557205e-03, 2.7074069e-03,
         2.4965745e-03, 2.3143049e-03, 2.1536427e-03, 2.0107217e-03,
         1.8833457e-03, 1.7696681e-03]])
ref_model_loss:
 tensor([3.2379427e+00, 2.8768454e+00, 2.5534840e+00, 2.2589610e+00,
        2.0057640e+00, 1.801399

Training: 100%|████████████████████████████| 50/50 [06:46<00:00,  8.13s/it]


mymodel_loss:
 tensor([[3.0214624, 2.4867995, 2.0948153, 1.8162000, 1.6008555, 1.4227315,
         1.2761170, 1.1323754, 0.9797642, 0.8291832, 0.6863651, 0.5530309,
         0.4375713, 0.3398650, 0.2541281, 0.1832843, 0.1310944, 0.0943556,
         0.0675600, 0.0487518, 0.0371242, 0.0277191, 0.0233321, 0.0204118,
         0.0187703, 0.0171570, 0.0161513, 0.0155197, 0.0149789, 0.0144951,
         0.0140305, 0.0136465, 0.0133863, 0.0131110, 0.0129271, 0.0127457,
         0.0125814, 0.0124622, 0.0123197, 0.0122344, 0.0121235, 0.0120549,
         0.0119738, 0.0119141, 0.0118550, 0.0118024, 0.0117580, 0.0117125,
         0.0116777, 0.0116394]])
ref_model_loss:
 tensor([3.0214624, 2.4867995, 2.0948150, 1.8161999, 1.6008555, 1.4227316,
        1.2761171, 1.1323754, 0.9797640, 0.8291833, 0.6863649, 0.5530310,
        0.4375713, 0.3398650, 0.2541281, 0.1832844, 0.1310944, 0.0943555,
        0.0675600, 0.0487519, 0.0371243, 0.0277191, 0.0233322, 0.0204118,
        0.0187703, 0.0171570, 0.0161513

Training: 100%|████████████████████████████| 50/50 [00:01<00:00, 25.26it/s]


mymodel_loss:
 tensor([[3.2748358, 3.2359109, 3.1975098, 3.1596470, 3.1223364, 3.0855932,
         3.0494266, 3.0138361, 2.9788167, 2.9443641, 2.9104731, 2.8771393,
         2.8443587, 2.8121281, 2.7804418, 2.7492936, 2.7186763, 2.6885798,
         2.6589913, 2.6298923, 2.6012564, 2.5730450, 2.5452080, 2.5176854,
         2.4904077, 2.4633033, 2.4363050, 2.4093571, 2.3824241, 2.3554990,
         2.3286085, 2.3018141, 2.2752039, 2.2488830, 2.2229543, 2.1975040,
         2.1725874, 2.1482272, 2.1244147, 2.1011233, 2.0783195, 2.0559745,
         2.0340688, 2.0125942, 1.9915466, 1.9709182, 1.9506927, 1.9308451,
         1.9113485, 1.8921881]])
ref_model_loss:
 tensor([3.2748358, 3.2359109, 3.1975095, 3.1596467, 3.1223364, 3.0855937,
        3.0494266, 3.0138359, 2.9788170, 2.9443641, 2.9104733, 2.8771396,
        2.8443587, 2.8121281, 2.7804418, 2.7492936, 2.7186763, 2.6885798,
        2.6589913, 2.6298923, 2.6012561, 2.5730448, 2.5452080, 2.5176852,
        2.4904072, 2.4633031, 2.4363050

Training: 100%|████████████████████████████| 10/10 [05:32<00:00, 33.29s/it]

mymodel_loss:
 tensor([[7.1348028, 5.0838103, 3.1337659, 1.4195486, 0.4036632, 0.1165457,
         0.0584250, 0.0382141, 0.0291586, 0.0226412]])
ref_model_loss:
 tensor([7.1348033, 5.0838099, 3.1337662, 1.4195483, 0.4036631, 0.1165457,
        0.0584250, 0.0382141, 0.0291586, 0.0226412])
Training results match: LlamaConfig(version='3.1', block_size=200, vocab_size=100, n_layer=12, n_head=12, n_kv_head=4, n_embd=48, ffn_dim_multiplier=1.3, multiple_of=32, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=40, B=10)
{'lr': 0.001, 'betas': (0.9, 0.99), 'eps': 1e-08}
------------





## Load model and loss
    -- load a gpt2 model
    -- initialize our custom model from it
    -- compare the losses from the forward pass

In [128]:
def load_model(model_path):
    with open(model_path, "rb") as f:
        header = np.frombuffer(buffer=f.read(256 * 4), dtype=np.int32)
        maxT, V, L, nH, C = header[2:7]
        n_kv_head = nH
        hs = C // nH
        kv_dim = hs * n_kv_head
        print(maxT, V, L, nH, C)
        wte = torch.from_numpy(np.frombuffer(f.read(V * C * 4), dtype=np.float32)).view(
            V, C
        )  # use the transpose of this for the output layer (C,V)
        wpe = np.frombuffer(f.read(maxT * C * 4))

        l1w = torch.from_numpy(np.frombuffer(f.read(L * C * 4), dtype=np.float32)).view(
            L, C
        )
        np.frombuffer(f.read(L * C * 4))  # (L, C)

        c_attn_weight = torch.from_numpy(
            np.frombuffer(f.read(L * (C + 2 * kv_dim) * C * 4), dtype=np.float32)
        ).view(L, C + 2 * kv_dim, C)

        # (L, 3C)
        _ = torch.from_numpy(
            np.frombuffer(f.read(L * (C + 2 * kv_dim) * 4), dtype=np.float32)
        ).view(L, C + 2 * kv_dim)

        # (L, C, C)
        c_attn_proj = torch.from_numpy(
            np.frombuffer(f.read(L * (C) * C * 4), dtype=np.float32)
        ).view(L, C, C)

        # (L, C)
        _ = torch.from_numpy(np.frombuffer(f.read(L * (C) * 4), dtype=np.float32)).view(
            L, C
        )
        # (L, C)
        l2w = torch.from_numpy(np.frombuffer(f.read(L * C * 4), dtype=np.float32)).view(
            L, C
        )
        # (L, C)
        _ = torch.from_numpy(np.frombuffer(f.read(L * C * 4), dtype=np.float32)).view(
            L, C
        )

        # (L, 4C, C)
        OC = 4 * C  # w3 and w1 are identical
        mlp_w1 = torch.from_numpy(
            np.frombuffer(f.read(L * (4 * C) * C * 4), dtype=np.float32)
        ).view(L, 4 * C, C)
        
        # (L, 4C)
        torch.from_numpy(np.frombuffer(f.read(L * (4 * C) * 4), dtype=np.float32))
        
        # (L, C, 4C)
        mlp_w2 = torch.from_numpy(
            np.frombuffer(f.read(L * (4 * C) * C * 4), dtype=np.float32)
        ).view(L, C, 4 * C)
        
        # (L, C)
        torch.from_numpy(np.frombuffer(f.read(L * C * 4), dtype=np.float32)).view(L, C)

        # (C)
        lnf = torch.from_numpy(np.frombuffer(f.read(C * 4), dtype=np.float32)).view(C)

        torch.from_numpy(np.frombuffer(f.read(C * 4), dtype=np.float32)).view(C)

        config = LlamaConfig(
            block_size=maxT,
            vocab_size=V,
            n_layer=L,
            n_head=nH,
            n_kv_head=nH,
            n_embd=C,
            n_rep=1,
        )


        # llama3 model
        llama = train_llama3.LLaMA(config)
        transformer = llama.transformer

        assert transformer["wte"].weight.shape == (V, C), ()
        del transformer["wte"].weight
        transformer["wte"].weight = nn.Parameter(wte)

        del transformer["ln_f"].weight
        transformer["ln_f"].weight = nn.Parameter(lnf)

        assert llama.lm_head.weight.shape == (V, C)
        del llama.lm_head.weight
        llama.lm_head.weight = nn.Parameter(wte.detach())

        for l in range(L):
            assert transformer["h"][l].ln_1.weight.shape == l1w[l].shape
            del transformer["h"][l].ln_1.weight
            transformer["h"][l].ln_1.weight = nn.Parameter(l1w[l])

            assert transformer["h"][l].ln_2.weight.shape == l2w[l].shape
            del transformer["h"][l].ln_2.weight
            transformer["h"][l].ln_2.weight = nn.Parameter(l2w[l])

            transformer["h"][l].attn.c_attn.weight = nn.Parameter(c_attn_weight[l])
            assert transformer["h"][l].attn.c_proj.weight.shape == c_attn_proj[l].shape
            transformer["h"][l].attn.c_proj.weight = nn.Parameter(c_attn_proj[l])

            del transformer["h"][l].mlp.c_fc2.weight
            transformer["h"][l].mlp.c_fc2.weight = nn.Parameter(mlp_w1[l])

            del transformer["h"][l].mlp.c_fc.weight
            transformer["h"][l].mlp.c_fc.weight = nn.Parameter(mlp_w1[l].detach())

            del transformer["h"][l].mlp.c_proj.weight
            transformer["h"][l].mlp.c_proj.weight = nn.Parameter(mlp_w2[l].detach())

        return init_mymodel_from_LLama(config, llama, False), config
        
def load_state(state_path, V: int):
    with open(state_path, "rb") as f:
        header = np.frombuffer(buffer=f.read(256 * 4), dtype=np.int32)
        B, T = header[2:4]
        inputs = torch.from_numpy(
            np.frombuffer(f.read(B * T * 4), dtype=np.int32),
        ).view(B, T)
        # inputs = inputs.to(torch.long)
        targets = (
            torch.from_numpy(np.frombuffer(f.read(B * T * 4), dtype=np.int32))
            .view(B, T)
            .to(torch.long)
        )
        logits = torch.from_numpy(
            np.frombuffer(f.read(B * T * V * 4), dtype=np.float32)
        ).view(B, T, V)
        loss = torch.from_numpy(np.frombuffer(f.read(4), dtype=np.float32))
        print(B, T, loss)
        return inputs, targets, logits, loss

In [135]:
bin_model_path = "gpt2_124M.bin"           # the gpt2 model params
state_path = "gpt2_124M_debug_state.bin"   # the inputs, and the loss

In [138]:
### Download the models
import os
if not os.path.exists(bin_model_path):
    !curl -L -o {bin_model_path} "https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/gpt2_124M.bin"
if not os.path.exists(state_path):
    !curl  -L -o {state_path} "https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/gpt2_124M_debug_state.bin"

In [139]:
(mymodel, llama_model), config = load_model(bin_model_path)
config

1024 50257 12 12 768


LlamaConfig(version='3.1', block_size=1024, vocab_size=50257, n_layer=12, n_head=12, n_kv_head=12, n_embd=768, ffn_dim_multiplier=1.3, multiple_of=32, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=5, B=4)

In [140]:
inputs, targets, logits, loss = load_state(state_path=state_path, V=config.vocab_size)
config.B, config.T = inputs.shape
config

4 64 tensor([5.2700086])


LlamaConfig(version='3.1', block_size=1024, vocab_size=50257, n_layer=12, n_head=12, n_kv_head=12, n_embd=768, ffn_dim_multiplier=1.3, multiple_of=32, norm_eps=1e-05, rope_theta=500000.0, use_scaled_rope=False, max_gen_batch_size=4, use_kv=True, flash=False, T=64, B=4)

In [141]:
mylogits, myloss = mymodel.forward(inp=inputs, targets=targets)
print(myloss)

tensor(11.2998238)


In [142]:
llama_logits, llama_loss = llama_model.forward(
    inputs, targets=targets, return_logits=True
)
print(llama_loss)

tensor(11.2998266, grad_fn=<NllLossBackward0>)


In [144]:
print(torch.allclose(llama_logits, mylogits, atol=1e-4))
print(torch.allclose(llama_loss, myloss, atol=1e-4, rtol=1e-4))

True
True
