In [1]:
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"]="1"

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer

from typing import List, Tuple

# -----

device = "mps:0"
modelpath = "meta-llama/Llama-3.2-1B-Instruct"


In [None]:
# Load in hf model 
hf_model = AutoModelForCausalLM.from_pretrained(
    modelpath,    
)
hf_model.to(device)
print(hf_model.lm_head)


Linear(in_features=2048, out_features=128256, bias=False)


In [2]:
#Check that preload works
from models import llama3_aoc
import configs

model = llama3_aoc.LlamaModel.from_huggingface(configs.Llama1BConfig)
model.to(device)
print(model)


Llama model loaded from state dict
LlamaModel(
  (embed_tokens): Embedding(128256, 2048)
  (layers): ModuleList(
    (0-15): 16 x LlamaDecoderLayer(
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
        (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
        (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
        (act_fn): SiLU()
      )
      (self_attn): LlamaSdpaAttention(
        (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
        (k_proj): Linear(in_features=2048, out_features=512, bias=False)
        (v_proj): Linear(in_features=2048, out_features=512, bias=False)
        (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
      )
      (input_layernorm): LlamaRMSNorm()
      (post_attention_layernorm): LlamaRMSNorm()
    )
  )
  (norm): LlamaRMSNorm()
  (lm_head): Linear(in_features=2048, out_features=128256, bias=False)
)


In [4]:
ao_model = llama3_aoc.AttentionOnlyLlamaModel.from_huggingface(configs.Llama1BConfig)

print(ao_model)


Llama model loaded from state dict
AttentionOnlyLlamaModel(
  (embed_tokens): Embedding(128256, 2048)
  (layers): ModuleList(
    (0-15): 16 x LlamaSdpaAttention(
      (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
      (k_proj): Linear(in_features=2048, out_features=512, bias=False)
      (v_proj): Linear(in_features=2048, out_features=512, bias=False)
      (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
    )
  )
  (norm): LlamaRMSNorm()
)


Test that Rotary projection works

In [7]:
def precompute_freqs_cis(dim: int, end: int, theta: float) -> torch.Tensor:
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64

freqs_cis = precompute_freqs_cis(dim=32, end=1024, theta=50000.0)
print(freqs_cis.shape)


torch.Size([1024, 16])


In [13]:
def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    print(x.shape)

    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    print(x_.shape)

    freqs_cis = freqs_cis[None, None, :, :]
    x_out = torch.view_as_real(x_ * freqs_cis).flatten(-2)
    print(x_out.shape)

    return x_out.type_as(x)

# bsz, n_head, toks, head_dim
x = torch.ones([1,8,10,32])


x = apply_rotary_emb(x, freqs_cis[:10,:])



torch.Size([1, 8, 10, 32])
torch.Size([1, 8, 10, 16])
torch.Size([1, 8, 10, 32])


Check that text generation is the same

In [None]:
from transformers import AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained(modelpath)



In [4]:
text = "Some examples of important benchmarks in language modelling are"

tokens = tokenizer(text)
tokens = tokens['input_ids']

print(tokens)
print(tokenizer.decode(tokens))


[128000, 8538, 10507, 315, 3062, 63119, 304, 4221, 61966, 527]
<|begin_of_text|>Some examples of important benchmarks in language modelling are


In [5]:
toks = torch.tensor(tokens).unsqueeze(0).to(dtype=torch.long, device=device)
logits = model(toks)
print(logits)


torch.Size([1, 1, 10, 10])
torch.Size([1, 32, 10, 64])


RuntimeError: Tensor must have a last dimension of size 2

In [None]:
def generate(model,
             prompt: str, 
             num_samples: int, 
             toks_to_generate: int,
             ) -> List[List[int]]:

    toks = tokenizer(prompt)['input_ids']
    toks = torch.Tensor(toks, device=device).expand(num_samples, -1)
    



