# Running Llama from source code

In this notebook we use the PyTorch source code to run the model locally. The goal is to pare back the extra functionality so that the code can be translated to JAX.

## Paring back

### Tokenizer

I'll leave this as is, since it is cheap and can be injected into the final `Llama` instance.

### Model

I removed some references to image parameters from this.

The remaining complexity is in the resizing functions. Since I'm only interested in this model (for now), I can remove them and refactor the model args to drop the artificial 16/11 scale factor (`ffn_dim_multiplier`) that I had to introduce to get this to work.
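The 16/11 factor is easiest to see by working through the FFN sizing arithmetic. This is a minimal sketch, assuming the sizing rule from Meta's reference `FeedForward` (start from `4 * dim`, take 2/3 of it, multiply by `ffn_dim_multiplier` if set, then round up to a multiple of `multiple_of`); `ffn_hidden_dim` is a hypothetical helper, not part of `llama/model.py`.

```python
def ffn_hidden_dim(dim, multiple_of=256, ffn_dim_multiplier=None):
    """FFN hidden size under the assumed reference sizing rule."""
    hidden_dim = 4 * dim                  # starting width passed in by the block
    hidden_dim = int(2 * hidden_dim / 3)  # 2/3 of it, since SwiGLU uses three weight matrices
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # round up to the next multiple of `multiple_of`
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)


print(ffn_hidden_dim(2048))                              # 5632
print(ffn_hidden_dim(2048, ffn_dim_multiplier=16 / 11))  # 8192
```

Without a multiplier the FFN comes out at 5632 rather than the checkpoint's `intermediate_size` of 8192. Any multiplier that puts the scaled width between 7937 and 8192 gives the same result (1.5 works too), so the particular value 16/11 is not special.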


In [1]:
import gc

gc.collect() 

31

In [2]:
import sys
import os

# make the local `llama` package in ./LLM importable in later cells
project_path = os.path.abspath("LLM")

if project_path not in sys.path:
    sys.path.append(project_path)

In [3]:
os.chdir('/home/matt/.llama/checkpoints')

## Loading into a `Transformer` instance from source code

We are interested in the activations through the layers of this model, so it is useful to create an instance of the `Transformer` class defined in the `model.py` file.

In [4]:
import json
import torch
from safetensors.torch import load_file
from llama.model import ModelArgs

config_path = "/home/matt/.llama/checkpoints/Llama3.2-1B-hf/config.json"
with open(config_path, "r") as f:
    config = json.load(f)

# map the Hugging Face config fields onto ModelArgs
model_args = ModelArgs(
    dim=config.get("hidden_size", 4096),
    n_layers=config.get("num_hidden_layers", 32),
    n_heads=config.get("num_attention_heads", 32),
    n_kv_heads=config.get("num_key_value_heads", None),
    vocab_size=config.get("vocab_size", -1),
    multiple_of=256,  # not in config, so use the default
    norm_eps=config.get("rms_norm_eps", 1e-5),
    max_batch_size=32,  # not in config, so use the default
    use_scaled_rope=True,  # as in llama.ipynb (field name as printed by ModelArgs)
    rope_scale_factor=config["rope_scaling"]["factor"],
    original_rotary_embed_len=config["rope_scaling"]["original_max_position_embeddings"],
    cache_len=2048,
    ffn_dim_multiplier=16 / 11,  # makes the FFN hidden size come out at 8192 (intermediate_size)
)

print(model_args)


ModelArgs(dim=2048, n_layers=16, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=256, ffn_dim_multiplier=1.4545454545454546, norm_eps=1e-05, rope_theta=500000, use_scaled_rope=True, rope_scale_factor=32.0, max_batch_size=32, original_rotary_embed_len=8192, cache_len=2048)
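These arguments also fix how much memory the attention blocks preallocate for their KV caches. A back-of-the-envelope sketch, assuming each of the `n_layers` blocks holds one K and one V cache of shape `(max_batch_size, cache_len, n_kv_heads, head_dim)` in float32, as in Meta's reference `Attention`; the modified `llama/model.py` may allocate differently, and `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(n_layers, max_batch_size, cache_len, n_kv_heads, head_dim,
                   bytes_per_elem=4):
    """Bytes for the K and V caches across all layers (float32 by default)."""
    # one cache of shape (max_batch_size, cache_len, n_kv_heads, head_dim)
    per_cache = max_batch_size * cache_len * n_kv_heads * head_dim * bytes_per_elem
    # each layer holds one K cache and one V cache
    return 2 * n_layers * per_cache


head_dim = 2048 // 32  # dim // n_heads = 64
total = kv_cache_bytes(n_layers=16, max_batch_size=32, cache_len=2048,
                       n_kv_heads=8, head_dim=head_dim)
print(total / 2**30, "GiB")                                     # 4.0 GiB
print(kv_cache_bytes(16, 1, 2048, 8, head_dim) / 2**20, "MiB")  # 128.0 MiB
```

Under those assumptions `max_batch_size` matters as much as `cache_len`: a single-sequence cache is 32 times smaller.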


In [5]:
print(config)

{'architectures': ['LlamaForCausalLM'], 'attention_bias': False, 'attention_dropout': 0.0, 'bos_token_id': 128000, 'eos_token_id': 128001, 'head_dim': 64, 'hidden_act': 'silu', 'hidden_size': 2048, 'initializer_range': 0.02, 'intermediate_size': 8192, 'max_position_embeddings': 131072, 'mlp_bias': False, 'model_type': 'llama', 'num_attention_heads': 32, 'num_hidden_layers': 16, 'num_key_value_heads': 8, 'pretraining_tp': 1, 'rms_norm_eps': 1e-05, 'rope_scaling': {'factor': 32.0, 'high_freq_factor': 4.0, 'low_freq_factor': 1.0, 'original_max_position_embeddings': 8192, 'rope_type': 'llama3'}, 'rope_theta': 500000.0, 'tie_word_embeddings': True, 'torch_dtype': 'bfloat16', 'transformers_version': '4.49.0', 'use_cache': True, 'vocab_size': 128256}


In [6]:
from llama.model import Transformer

# After upgrading to 16 GB of RAM this now runs; the main thing to tune is the
# maximum sequence length, since the attention blocks preallocate their caches
# based on that value

# device and dtype intended for reducing memory use (the dtype cast below is disabled for now)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16

model = Transformer(model_args)  # .to(dtype=torch_dtype)

print("Transformer created")

safetensors_path = "/home/matt/.llama/checkpoints/Llama3.2-1B-hf/model.safetensors"  
weights = load_file(safetensors_path)

print("weights in RAM")


attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
orinal freqs: tensor([1.0000e+00, 6.6360e-01, 4.4037e-01, 2.9223e-01, 1.9392e-01, 1.2869e-01,
        8.5397e-02, 5.6670e-02, 3.7606e-02, 2.4955e-02, 1.6560e-02, 1.0990e-02,
        7.2927e-03, 4.8394e-03, 3.2114e-03, 2.1311e-03, 1.4142e-03, 9.3847e-04,
        6.2277e-04, 4.1327e-04, 2.7425e-04, 1.8199e-04, 1.2077e-04, 8.0143e-05,
        5.3183e-05, 3.5292e-05, 2.3420e-05, 1.5542e-05, 1.0313e-05, 6.8440e-06,
        4.5417e-06, 3.0139e-06])
rescaled freqs: tensor([1.0000e+00, 6.6360e-01, 4.4037e-01, 2.9223e-01, 1.9392e-01, 1.2869e-01,
        8

We see that the RoPE frequencies match those in the `llama.ipynb` notebook after we apply the scaling factor.
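The printed frequencies can be reproduced without PyTorch. A minimal sketch, assuming the Llama 3 scaling rule from Meta's reference `apply_scaling` and the `rope_scaling` values in `config.json` (factor 32, low/high frequency factors 1 and 4, original context 8192); the helper names are hypothetical. Short-wavelength components are left alone, long-wavelength ones are divided by the factor, and the band in between is interpolated.

```python
import math


def rope_base_freqs(head_dim, theta=500000.0):
    """Base RoPE frequencies theta^(-2i/head_dim) for i in [0, head_dim/2)."""
    return [theta ** (-2 * i / head_dim) for i in range(head_dim // 2)]


def llama3_scale_freqs(freqs, factor=32.0, low_freq_factor=1.0,
                       high_freq_factor=4.0, original_context_len=8192):
    """Llama 3 style frequency scaling (assumed to mirror the reference rule)."""
    low_freq_wavelen = original_context_len / low_freq_factor    # 8192
    high_freq_wavelen = original_context_len / high_freq_factor  # 2048
    scaled = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            scaled.append(freq)           # high frequency: unchanged
        elif wavelen > low_freq_wavelen:
            scaled.append(freq / factor)  # low frequency: divided by the factor
        else:
            # in-between band: interpolate between the two regimes
            smooth = (original_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor)
            scaled.append((1 - smooth) * freq / factor + smooth * freq)
    return scaled


base = rope_base_freqs(head_dim=64)  # head_dim = 2048 // 32
scaled = llama3_scale_freqs(base)
print(f"{base[1]:.4e}")   # compare with the second value printed above
print(f"{scaled[-1]:.4e}")
```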

In [7]:
# The Hugging Face checkpoint names its weights differently from the
# reference model code, so we rename them

weights = load_file(safetensors_path)

# create a new state dict with corrected names
fixed_state_dict = {}

for key in weights.keys():

    # dropping the "model." prefix also turns "model.norm" into "norm"
    new_key = key.replace("model.", "")

    new_key = new_key.replace("embed_tokens.weight", "tok_embeddings.weight")

    # attention projections
    new_key = new_key.replace("self_attn.q_proj", "attention.wq")
    new_key = new_key.replace("self_attn.k_proj", "attention.wk")
    new_key = new_key.replace("self_attn.v_proj", "attention.wv")
    new_key = new_key.replace("self_attn.o_proj", "attention.wo")

    # SwiGLU feed-forward: gate -> w1, up -> w3, down -> w2
    new_key = new_key.replace("mlp.gate_proj", "feed_forward.w1")
    new_key = new_key.replace("mlp.up_proj", "feed_forward.w3")
    new_key = new_key.replace("mlp.down_proj", "feed_forward.w2")

    # RMSNorm layers
    new_key = new_key.replace("input_layernorm", "attention_norm")
    new_key = new_key.replace("post_attention_layernorm", "ffn_norm")

    fixed_state_dict[new_key] = weights[key]
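One thing the renaming does not cover: Hugging Face's Llama conversion script also reorders the rows of `q_proj` and `k_proj` within each head so that RoPE can be applied with the `rotate_half` convention, whereas Meta's reference `apply_rotary_emb` pairs adjacent elements via `torch.view_as_complex`. If the `llama/model.py` used here still applies RoPE the Meta way, `wq` and `wk` may need the inverse reordering before loading (`wv` and `wo` are unaffected). The sketch below shows both directions on plain lists of row indices; the function names are hypothetical.

```python
def hf_permute_rows(rows, n_heads):
    """Meta -> HF order: within each head, move even rows ahead of odd rows."""
    head_dim = len(rows) // n_heads
    out = []
    for h in range(n_heads):
        head = rows[h * head_dim:(h + 1) * head_dim]
        out.extend(head[0::2])  # first element of each rotary pair
        out.extend(head[1::2])  # second element of each rotary pair
    return out


def hf_unpermute_rows(rows, n_heads):
    """HF -> Meta order: re-interleave the two halves of each head."""
    head_dim = len(rows) // n_heads
    half = head_dim // 2
    out = []
    for h in range(n_heads):
        head = rows[h * head_dim:(h + 1) * head_dim]
        for j in range(half):
            out.append(head[j])         # first element of pair j
            out.append(head[half + j])  # second element of pair j
    return out


rows = list(range(8))  # stand-in for 8 weight rows: 2 heads with head_dim 4
print(hf_permute_rows(rows, n_heads=2))                     # [0, 2, 1, 3, 4, 6, 5, 7]
print(hf_unpermute_rows(hf_permute_rows(rows, 2), 2))       # [0, 1, 2, 3, 4, 5, 6, 7]
```

In PyTorch the inverse corresponds to `w.view(n_heads, 2, head_dim // 2, dim).transpose(1, 2).reshape(n_heads * head_dim, dim)`, with `n_heads=32` for `wq` and `n_kv_heads=8` for `wk`.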

In [8]:
fixed_state_dict.keys()

dict_keys(['tok_embeddings.weight', 'layers.0.attention_norm.weight', 'layers.0.feed_forward.w2.weight', 'layers.0.feed_forward.w1.weight', 'layers.0.feed_forward.w3.weight', 'layers.0.ffn_norm.weight', 'layers.0.attention.wk.weight', 'layers.0.attention.wo.weight', 'layers.0.attention.wq.weight', 'layers.0.attention.wv.weight', 'layers.1.attention_norm.weight', 'layers.1.feed_forward.w2.weight', 'layers.1.feed_forward.w1.weight', 'layers.1.feed_forward.w3.weight', 'layers.1.ffn_norm.weight', 'layers.1.attention.wk.weight', 'layers.1.attention.wo.weight', 'layers.1.attention.wq.weight', 'layers.1.attention.wv.weight', 'layers.10.attention_norm.weight', 'layers.10.feed_forward.w2.weight', 'layers.10.feed_forward.w1.weight', 'layers.10.feed_forward.w3.weight', 'layers.10.ffn_norm.weight', 'layers.10.attention.wk.weight', 'layers.10.attention.wo.weight', 'layers.10.attention.wq.weight', 'layers.10.attention.wv.weight', 'layers.11.attention_norm.weight', 'layers.11.feed_forward.w2.weight',

In [9]:
# Load the corrected state dict.
# The checkpoint has no separate output weights because they are tied to the
# token embeddings (tie_word_embeddings=True), but the Transformer class expects
# an `output.weight`, so load non-strictly and tie the weights manually afterwards
model.load_state_dict(fixed_state_dict, strict=False)

print("loaded model with corrected state_dict")

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)

loaded model with corrected state_dict


In [10]:
model

Transformer(
  (tok_embeddings): Embedding(128256, 2048)
  (layers): ModuleList(
    (0-15): 16 x TransformerBlock(
      (attention): Attention(
        (wq): ColumnParallelLinear(in_features=2048, out_features=2048, bias=False)
        (wk): ColumnParallelLinear(in_features=2048, out_features=512, bias=False)
        (wv): ColumnParallelLinear(in_features=2048, out_features=512, bias=False)
        (wo): RowParallelLinear(in_features=2048, out_features=2048, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): ColumnParallelLinear(in_features=2048, out_features=8192, bias=False)
        (w2): RowParallelLinear(in_features=8192, out_features=2048, bias=False)
        (w3): ColumnParallelLinear(in_features=2048, out_features=8192, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): ColumnParallelLinear(in_features=2048, out_features=128256, bias=False)
)
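The shapes above follow directly from the config: with 32 query heads and 8 KV heads of size 64, `wk` and `wv` project to 8 × 64 = 512 features (grouped-query attention). A small sketch that counts parameters from those shapes, assuming tied embeddings (so `output` adds nothing) and weight-only RMSNorm; `llama_param_count` is a hypothetical helper.

```python
def llama_param_count(vocab_size, dim, n_layers, n_heads, n_kv_heads, ffn_dim,
                      tied_embeddings=True):
    """Parameter count implied by the module shapes printed above."""
    head_dim = dim // n_heads                      # 2048 // 32 = 64
    kv_dim = n_kv_heads * head_dim                 # 8 * 64 = 512
    attention = 2 * dim * dim + 2 * dim * kv_dim   # wq and wo, plus wk and wv
    feed_forward = 3 * dim * ffn_dim               # w1, w2, w3
    norms = 2 * dim                                # attention_norm and ffn_norm weights
    per_layer = attention + feed_forward + norms
    embeddings = vocab_size * dim                  # tok_embeddings
    output = 0 if tied_embeddings else vocab_size * dim
    final_norm = dim
    return embeddings + n_layers * per_layer + final_norm + output


total = llama_param_count(vocab_size=128256, dim=2048, n_layers=16,
                          n_heads=32, n_kv_heads=8, ffn_dim=8192)
print(f"{total:,} parameters")                     # 1,235,814,400 parameters
print(f"{total * 2 / 2**30:.2f} GiB as bfloat16")  # 2.30 GiB as bfloat16
print(f"{total * 4 / 2**30:.2f} GiB as float32")   # 4.60 GiB as float32
```

This gives a rough sense of why RAM was tight: the model created without the dtype cast presumably holds float32 parameters, while the safetensors checkpoint is bfloat16.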

In [11]:
# tie the output projection to the token embeddings manually (tie_word_embeddings is True in config.json)

model.output.weight.data = model.tok_embeddings.weight.data


In [12]:
from llama.tokenizer import Tokenizer

# go here to find the model file
# https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/tokenizer.model (689c7f2)

new_tok_path = "/home/matt/.llama/checkpoints/Llama3.2-1B-hf-tok/tokenizer.model"
new_tok = Tokenizer(new_tok_path)
new_tok

<llama.tokenizer.Tokenizer at 0x7f4418a2cbc0>

In [13]:
from llama.generation import Llama

llama = Llama(model, new_tok, model_args)

In [14]:
llama

<llama.generation.Llama at 0x7f41b0fa4560>

In [None]:
from llama.generation import CompletionPrediction

# A high top_p seems necessary here; otherwise we just get strings of numbers
res: CompletionPrediction = llama.text_completion("The capital of France is", max_gen_len=50, top_p=0.9, temperature=0.6)

Input to model:
<|begin_of_text|>The capital of France is

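For reference on the `top_p` and `temperature` arguments, here is a stdlib sketch of temperature plus nucleus (top-p) sampling, written to mirror the steps of Meta's reference `sample_top_p` (sort by probability, drop tokens whose preceding cumulative probability exceeds `top_p`, renormalise, sample). It is not the code in `llama/generation.py` used here, and the function names are hypothetical. A larger `top_p` keeps more of the tail of the distribution.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_nucleus(probs, top_p):
    """Token indices kept by nucleus (top-p) filtering.

    Tokens are taken in order of decreasing probability; a token is dropped
    once the probability mass of the tokens before it already exceeds top_p.
    """
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        if cumulative > top_p:
            break
        kept.append(i)
        cumulative += probs[i]
    return kept


def sample_top_p(logits, temperature=0.6, top_p=0.9, rng=random):
    """Sample one token index with temperature and nucleus filtering."""
    probs = softmax(logits, temperature)
    kept = top_p_nucleus(probs, top_p)
    # random.choices normalises the weights, so no explicit renormalisation is needed
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]


print(top_p_nucleus([0.5, 0.3, 0.15, 0.05], top_p=0.7))  # [0, 1]
print(top_p_nucleus([0.5, 0.3, 0.15, 0.05], top_p=0.9))  # [0, 1, 2]
print(sample_top_p([2.0, 1.0, 0.1], rng=random.Random(0)))
```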

In [16]:
res.generation

" Paris. This city has a population about 2,000 people.\nThis place was built in the year and it's Capital of The country France\nFrance region Of Region Map - For more than million inhabitants which live on river bank: On beautiful area"

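As a rough summary of how confident the model was, the per-token log-probabilities in `res.logprobs` (displayed next) can be folded into a perplexity, the exponential of the mean negative log-probability. A minimal sketch, assuming the values are natural-log probabilities of the generated tokens; `perplexity` is a hypothetical helper.

```python
import math


def perplexity(token_logprobs):
    """exp of the mean negative log-probability per token (natural logs assumed)."""
    nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(nll)


# res.logprobs is a list of one-element lists, so flatten it first, e.g.
#   flat = [lp for (lp,) in res.logprobs]
# here, just the first five values shown in the output below
sample = [-2.078843116760254, -1.2045153379440308, -2.747781991958618,
          -1.3441989421844482, -1.9368460178375244]
print(perplexity(sample))
```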
In [17]:
res.logprobs

[[-2.078843116760254],
 [-1.2045153379440308],
 [-2.747781991958618],
 [-1.3441989421844482],
 [-1.9368460178375244],
 [-1.20782470703125],
 [-1.1919770240783691],
 [-3.838620901107788],
 [-0.1342192143201828],
 [-0.33271750807762146],
 [-1.1273064613342285],
 [-2.4497766494750977],
 [-2.324704170227051],
 [-2.197195053100586],
 [-3.1740481853485107],
 [-4.065153121948242],
 [-2.3178908824920654],
 [-2.6613805294036865],
 [-0.8573299050331116],
 [-1.07711660861969],
 [-1.4414087533950806],
 [-4.049238681793213],
 [-2.1237897872924805],
 [-0.6591014862060547],
 [-6.1876020431518555],
 [-0.6143554449081421],
 [-5.429050445556641],
 [-2.6682536602020264],
 [-1.0977528095245361],
 [-3.2166032791137695],
 [-2.9457433223724365],
 [-6.2648606300354],
 [-6.69734001159668],
 [-5.850415229797363],
 [-6.930875778198242],
 [-3.9093780517578125],
 [-6.449533462524414],
 [-3.8352251052856445],
 [-2.0982353687286377],
 [-3.828277826309204],
 [-2.7142529487609863],
 [-5.48582649230957],
 [-4.274913787