# Running Llama from source code

In this notebook we use the Pytorch source code to run the model locally, the goal is to pare back the extra stuff included so the code can be translated to use Jax

In [None]:
import sys
import os

# we want to import some llama source later
os.getcwd()
project_path = os.path.abspath("LLM")

if project_path not in sys.path:
    sys.path.append(project_path)

In [None]:
os.chdir('/home/matt/.llama/checkpoints')

## Loading into `Transfomer` instance from source code

We are interested in the activations through the layers of this model, so it would be good to load create an instance of the `Transformer` object defined in the `model.py` file

In [None]:
import json
import torch
from safetensors.torch import load_file
from llama.model_new import ModelArgs

config_path = "/home/matt/.llama/checkpoints/Llama3.2-1B-hf/config.json"
with open(config_path, "r") as f:
    config = json.load(f)

# extract the necessary fields
model_args = ModelArgs(
    dim=config.get("hidden_size", 4096), 
    n_layers=config.get("num_hidden_layers", 32),  
    n_heads=config.get("num_attention_heads", 32), 
    n_kv_heads=config.get("num_key_value_heads", None), 
    vocab_size=config.get("vocab_size", -1), 
    multiple_of=256, # not in config so use the default
    norm_eps=config.get("rms_norm_eps", 1e-5),  # map "rms_norm_eps"
    max_batch_size=32,  # not in config so use the default
    rotary_embed_len=config.get("max_position_embeddings", 2048),  # map "max_position_embeddings"
    cache_len = 2048,
    ffn_dim_multiplier = 16 / 11
)

print(model_args)


ModelArgs(dim=2048, n_layers=16, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=256, ffn_dim_multiplier=1.4545454545454546, norm_eps=1e-05, rope_theta=500000, use_scaled_rope=False, max_batch_size=32, rotary_embed_len=131072, cache_len=2048, vision_chunk_size=-1, vision_max_num_chunks=4, vision_num_cross_attention_layers=-1)


In [None]:
from llama.model_new import Transformer

# I upgraded to 16Gb of RAM and now this will run - just need to tune the max sequence length as it will preallocate
# the caches in the attention blocks based on that value

# RAM preservation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16 

model = Transformer(model_args)

#.to(dtype=torch_dtype)

print("Transformer created")

safetensors_path = "/home/matt/.llama/checkpoints/Llama3.2-1B-hf/model.safetensors"  
weights = load_file(safetensors_path)

print("weights in RAM")


attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
attention block initialised
Transformer created
weights in RAM


In [None]:
# the weights' names don't match what is in the code I found
# so we rename them

weights = load_file(safetensors_path)

# create a new state dict with corrected names
fixed_state_dict = {}

for key in weights.keys():
    
    new_key = key
    new_key = new_key.replace("model.", "")
    
    new_key = new_key.replace("embed_tokens.weight", "tok_embeddings.weight")

    new_key = new_key.replace("self_attn.q_proj", "attention.wq")
    new_key = new_key.replace("self_attn.k_proj", "attention.wk")
    new_key = new_key.replace("self_attn.v_proj", "attention.wv")
    new_key = new_key.replace("self_attn.o_proj", "attention.wo")

    new_key = new_key.replace("mlp.gate_proj", "feed_forward.w1")
    new_key = new_key.replace("mlp.up_proj", "feed_forward.w3")
    new_key = new_key.replace("mlp.down_proj", "feed_forward.w2")

    new_key = new_key.replace("input_layernorm", "attention_norm")
    new_key = new_key.replace("post_attention_layernorm", "ffn_norm")

    new_key = new_key.replace("model.norm", "norm")

    fixed_state_dict[new_key] = weights[key]

In [None]:
fixed_state_dict.keys()

dict_keys(['tok_embeddings.weight', 'layers.0.attention_norm.weight', 'layers.0.feed_forward.w2.weight', 'layers.0.feed_forward.w1.weight', 'layers.0.feed_forward.w3.weight', 'layers.0.ffn_norm.weight', 'layers.0.attention.wk.weight', 'layers.0.attention.wo.weight', 'layers.0.attention.wq.weight', 'layers.0.attention.wv.weight', 'layers.1.attention_norm.weight', 'layers.1.feed_forward.w2.weight', 'layers.1.feed_forward.w1.weight', 'layers.1.feed_forward.w3.weight', 'layers.1.ffn_norm.weight', 'layers.1.attention.wk.weight', 'layers.1.attention.wo.weight', 'layers.1.attention.wq.weight', 'layers.1.attention.wv.weight', 'layers.10.attention_norm.weight', 'layers.10.feed_forward.w2.weight', 'layers.10.feed_forward.w1.weight', 'layers.10.feed_forward.w3.weight', 'layers.10.ffn_norm.weight', 'layers.10.attention.wk.weight', 'layers.10.attention.wo.weight', 'layers.10.attention.wq.weight', 'layers.10.attention.wv.weight', 'layers.11.attention_norm.weight', 'layers.11.feed_forward.w2.weight',

In [None]:
# Load the corrected state dict
model.load_state_dict(fixed_state_dict, strict = False) # since the output weights are tied, these are already correct and not loaded - but the Transformer class expects the; so load non-strictly

print("loaded model with corrected state_dict")

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)

loaded model with corrected state_dict


In [None]:
model

Transformer(
  (tok_embeddings): Embedding(128256, 2048)
  (layers): ModuleList(
    (0-15): 16 x TransformerBlock(
      (attention): Attention(
        (wq): ColumnParallelLinear(in_features=2048, out_features=2048, bias=False)
        (wk): ColumnParallelLinear(in_features=2048, out_features=512, bias=False)
        (wv): ColumnParallelLinear(in_features=2048, out_features=512, bias=False)
        (wo): RowParallelLinear(in_features=2048, out_features=2048, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): ColumnParallelLinear(in_features=2048, out_features=8192, bias=False)
        (w2): RowParallelLinear(in_features=8192, out_features=2048, bias=False)
        (w3): ColumnParallelLinear(in_features=2048, out_features=8192, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): ColumnParallelLinear(in_features=2048, out_features=128256, bias=False)
)

In [None]:
# tie the output embedding manually

model.output.weight.data = model.tok_embeddings.weight.data


In [None]:
from llama.tokenizer import Tokenizer

# go here to find the model file
# https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/tokenizer.model (689c7f2)

new_tok_path = "/home/matt/.llama/checkpoints/Llama3.2-1B-hf-tok/tokenizer.model"
new_tok = Tokenizer(new_tok_path)
new_tok

<llama.tokenizer.Tokenizer at 0x7f4282dd4890>

In [None]:
from llama.generation_new import Llama

llama = Llama(model, new_tok, model_args)

In [None]:
llama

<llama.generation_new.Llama at 0x7f4282d380e0>

In [None]:
from llama.generation_new import CompletionPrediction

# A high top_p seems necessary here, otherwise we just get strings of numbers
res: CompletionPrediction = llama.text_completion("The capital of France is", max_gen_len=50, top_p = 0.99, temperature = 0.3)

[31mInput to model:
<|begin_of_text|>The capital of France is
[0m


In [15]:
res.generation

' the capital of France is Paris. The capital of France is Paris. The capital of France is Paris. The capital of France is Paris. The capital of France is Paris. The capital of France is Paris. The capital of France is Paris. The'

In [16]:
res.logprobs

[[-2.0232808589935303],
 [-1.8558874130249023],
 [-0.35268905758857727],
 [-0.4167436957359314],
 [-1.016157627105713],
 [-0.8968588709831238],
 [-1.2876267433166504],
 [-1.806140661239624],
 [-1.2675158977508545],
 [-0.26237577199935913],
 [-0.19873137772083282],
 [-0.38704994320869446],
 [-0.7489703893661499],
 [-0.43344929814338684],
 [-0.883790910243988],
 [-0.17068198323249817],
 [-0.1169423907995224],
 [-0.1793278455734253],
 [-0.1501794308423996],
 [-0.3163699507713318],
 [-0.3283524513244629],
 [-0.5362377166748047],
 [-0.11693623661994934],
 [-0.08320949971675873],
 [-0.18952108919620514],
 [-0.1585375815629959],
 [-0.4404500126838684],
 [-0.2826058566570282],
 [-0.28091034293174744],
 [-0.06224074587225914],
 [-0.0804356187582016],
 [-0.15348391234874725],
 [-0.2140636295080185],
 [-0.5775657892227173],
 [-0.24491234123706818],
 [-0.27858924865722656],
 [-0.15325607359409332],
 [-0.18884742259979248],
 [-0.1465480774641037],
 [-0.5145536065101624],
 [-1.1685043573379517],
 [-