In [1]:
import torch
from rope import apply_rotary_emb

q = torch.randn(2, 8, 2, 64)
k = torch.randn(2, 8, 2, 64)

q_out, k_out = apply_rotary_emb(q, k, head_dim=64, max_seq_len=8)
print("q_out shape:", q_out.shape)
print("✅ RoPE looks okay if shapes match")


  from .autonotebook import tqdm as notebook_tqdm


q_out shape: torch.Size([2, 8, 2, 64])
✅ RoPE looks okay if shapes match


In [2]:
import torch
from llama import RMSNorm

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
norm = RMSNorm(dim=4)
output = norm(x)

print("output:", output)
print("✅ RMSNorm looks okay if shape is (1,4) and values are scaled.")


output: tensor([[0.3651, 0.7303, 1.0954, 1.4606]], grad_fn=<MulBackward0>)
✅ RMSNorm looks okay if shape is (1,4) and values are scaled.


In [3]:
import torch
from llama import Attention
from base_llama import LlamaConfig

config = LlamaConfig()
attn = Attention(config)
attn.eval()

q = torch.randn(1, config.n_heads, 8, config.dim // config.n_heads)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = attn.compute_query_key_value_scores(q, k, v)
print("Attention output shape:", out.shape)  # (1, n_heads, 8, head_dim)
print("✅ Attention output shape OK")


Before matmul shapes: torch.Size([1, 8, 8, 64]) torch.Size([1, 8, 8, 64]) torch.Size([1, 8, 8, 64])
attention scores (before mask): min=-2.6131701469421387, max=2.7626895904541016
Causal mask shape: torch.Size([1, 1, 8, 8])
attention scores (after mask): min=-inf, max=2.5095739364624023
attention_probs min=0.0, max=1.0, sum=tensor([[[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]]])
attention_output shape: torch.Size([1, 8, 8, 64]), min=-3.20619368553161

In [4]:
import torch
from llama import LlamaLayer
from base_llama import LlamaConfig

config = LlamaConfig()
layer = LlamaLayer(0, config)
x = torch.randn(1, 8, config.dim)

out = layer(x)
print("Layer output shape:", out.shape)
print("✅ LlamaLayer output OK")


Before matmul shapes: torch.Size([1, 8, 8, 64]) torch.Size([1, 8, 8, 64]) torch.Size([1, 8, 8, 64])
attention scores (before mask): min=-1.0283336639404297, max=1.0711458921432495
Causal mask shape: torch.Size([1, 1, 8, 8])
attention scores (after mask): min=-inf, max=1.0278767347335815
attention_probs min=0.0, max=1.0, sum=tensor([[[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]]],
       grad_fn=<SumBackward1>)
attention_output shape: torch.Size([1, 8,

In [15]:
import torch
from llama import Llama
from base_llama import LlamaConfig

config = LlamaConfig()
model = Llama(config)
tokens = torch.randint(0, config.vocab_size, (1, 8))
logits, hidden = model(tokens)

print("Logits shape:", logits.shape)
print("✅ Full Llama forward pass OK")


Before matmul shapes: torch.Size([1, 8, 8, 64]) torch.Size([1, 8, 8, 64]) torch.Size([1, 8, 8, 64])
attention scores (before mask): min=-0.6149947047233582, max=0.5480402708053589
Causal mask shape: torch.Size([1, 1, 8, 8])
attention scores (after mask): min=-inf, max=0.49598225951194763
attention_probs min=0.0, max=1.0, sum=tensor([[[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]]],
       grad_fn=<SumBackward1>)
attention_output shape: torch.Size([1, 8

In [1]:
import torch

ckpt = torch.load("stories42M.pt", map_location="cpu")
print(ckpt["model_args"])


{'dim': 512, 'n_layers': 8, 'n_heads': 8, 'n_kv_heads': 8, 'vocab_size': 32000, 'multiple_of': 32, 'max_seq_len': 1024}
