In [1]:
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"]="1"
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer

from models import llama3_aoc
import configs


# -----

device = "mps:0"
modelpath = "meta-llama/Llama-3.2-1B-Instruct"


In [2]:
# Load in hf model 
hf_model = AutoModelForCausalLM.from_pretrained(
    modelpath,    
)
hf_model.to(device)


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

In [3]:
#Check that preload works

model = llama3_aoc.LlamaModel.from_huggingface(configs.Llama1BConfig)
model.to(device)
print(model)

num_params = 0
for p in model.parameters(): num_params += p.numel()
print(num_params)


embed_tokens.weight
layers.0.mlp.gate_proj.weight
layers.0.mlp.up_proj.weight
layers.0.mlp.down_proj.weight
layers.0.self_attn.q_proj.weight
layers.0.self_attn.k_proj.weight
layers.0.self_attn.v_proj.weight
layers.0.self_attn.o_proj.weight
layers.0.input_layernorm.weight
layers.0.post_attention_layernorm.weight
layers.1.mlp.gate_proj.weight
layers.1.mlp.up_proj.weight
layers.1.mlp.down_proj.weight
layers.1.self_attn.q_proj.weight
layers.1.self_attn.k_proj.weight
layers.1.self_attn.v_proj.weight
layers.1.self_attn.o_proj.weight
layers.1.input_layernorm.weight
layers.1.post_attention_layernorm.weight
layers.2.mlp.gate_proj.weight
layers.2.mlp.up_proj.weight
layers.2.mlp.down_proj.weight
layers.2.self_attn.q_proj.weight
layers.2.self_attn.k_proj.weight
layers.2.self_attn.v_proj.weight
layers.2.self_attn.o_proj.weight
layers.2.input_layernorm.weight
layers.2.post_attention_layernorm.weight
layers.3.mlp.gate_proj.weight
layers.3.mlp.up_proj.weight
layers.3.mlp.down_proj.weight
layers.3.self

In [4]:
ao_model = llama3_aoc.AttentionOnlyLlamaModel.from_huggingface(configs.Llama1BConfig)

print(ao_model)

num_params = 0
for p in ao_model.parameters(): num_params += p.numel()
print(num_params)


Llama model loaded from state dict
AttentionOnlyLlamaModel(
  (embed_tokens): Embedding(128256, 2048)
  (layers): ModuleList(
    (0-15): 16 x LlamaSdpaAttention(
      (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
      (k_proj): Linear(in_features=2048, out_features=512, bias=False)
      (v_proj): Linear(in_features=2048, out_features=512, bias=False)
      (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
    )
  )
  (norm): LlamaRMSNorm()
)
430442496


Test that Rotary projection works

In [5]:
def precompute_freqs_cis(dim: int, end: int, theta: float) -> torch.Tensor:
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64

freqs_cis = precompute_freqs_cis(dim=32, end=1024, theta=50000.0)
print(freqs_cis.shape)


torch.Size([1024, 16])


In [10]:
def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    print(x.shape)

    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    print(x_.shape)

    freqs_cis = freqs_cis[None,:,  None, :]
    x_out = torch.view_as_real(x_ * freqs_cis).flatten(-2)
    print(x_out.shape)

    return x_out.type_as(x)

# bsz, n_head, toks, head_dim
x = torch.ones([1,10,8,32])


x = apply_rotary_emb(x, freqs_cis[:10,:])
print(x[0,:,0,:])


torch.Size([1, 10, 8, 32])
torch.Size([1, 10, 8, 16])
torch.Size([1, 10, 8, 32])
tensor([[ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000],
        [-0.3012,  1.3818,  0.3866,  1.3604,  0.7110,  1.2225,  0.8602,  1.1225,
          0.9309,  1.0646,  0.9654,  1.0334,  0.9826,  1.0171,  0.9912,  1.0088,
          0.9955,  1.0045,  0.9977,  1.0023,  0.9988,  1.0012,  0.9994,  1.0006,
          0.9997,  1.0003,  0.9998,  1.0002,  0.9999,  1.0001,  1.0000,  1.0000],
        [-1.3254,  0.4932, -0.3247,  1.3764,  0.3748,  1.3637,  0.7056,  1.2256,
          0.8577,  1.1244,  0.9297,  1.0656,  0.9648,  1.0340,  0.9823,  1.0174,
          0.9910,  1.0089,  0.9954,  1.0045,  0.9977,  1.0023,  0.9988,  1.0012,
          0.9994,  1.0006,

Check that text generation is the same

In [7]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(modelpath)


text = "Some examples of important benchmarks in language modelling are"

tokens = tokenizer(text)
tokens = tokens['input_ids']

print(tokens)
print(tokenizer.decode(tokens))



[128000, 8538, 10507, 315, 3062, 63119, 304, 4221, 61966, 527]
<|begin_of_text|>Some examples of important benchmarks in language modelling are


In [8]:
# Custom model generation
toks = torch.tensor(tokens).unsqueeze(0).to(dtype=torch.long, device=device)
logits = model(toks)

output = model.generate(toks, 30, top_k=5).to('cpu')
print(output)
print(tokenizer.decode(output.flatten().tolist()))


tensor([[128000,   8538,  10507,    315,   3062,  63119,    304,   4221,  61966,
            527,   6814,  13850,     11,    220,    220,    335,     12,    457,
             53,     17,     13,   4702,   1306,    279,    323,    285,    482,
            842,     13,    578,    471,     11,    452,    574,    471,     11,
            457,    335,     12,    586]])
<|begin_of_text|>Some examples of important benchmarks in language modelling are pubulous,   }- }
V2. Just after the andis - end. The return, N was return, }
 }- public


In [None]:
# HF model generation
inputs = tokenizer(text, return_tensors="pt").to(device)
generate_ids = hf_model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    pad_token_id=tokenizer.eos_token_id,
    max_length=50)

print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])


Some examples of important benchmarks in language modelling are:

1. **BERT (Bidirectional Encoder Representations from Transformers)**: Developed by Google, BERT is a language model that achieves state-of-the-art results in many NLP tasks, including question
