In [1]:
import os
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12345"

In [2]:
%cd ../llama3/

/home/alexo/projects/llama3/experiments/llama3


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [3]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

from typing import List

import fire

from llama import Llama

In [4]:
ckpt_dir: str = "../Meta-Llama-3-8B/original/"
tokenizer_path: str = "../Meta-Llama-3-8B/original/tokenizer.model"
temperature: float = 0.6
top_p: float = 0.9
max_seq_len: int = 128
max_gen_len: int = 64
max_batch_size: int = 4
"""
Examples to run with the pre-trained models (no fine-tuning). Prompts are
usually in the form of an incomplete text prefix that the model can then try to complete.

The context window of llama3 models is 8192 tokens, so `max_seq_len` needs to be <= 8192.
`max_gen_len` is needed because pre-trained models usually do not stop completions naturally.
"""
generator = Llama.build(
    ckpt_dir=ckpt_dir,
    tokenizer_path=tokenizer_path,
    max_seq_len=max_seq_len,
    max_batch_size=max_batch_size,
)

> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1


  _C._set_default_tensor_type(t)


Loaded in 8.93 seconds


In [5]:
prompts: List[str] = [
    # For these prompts, the expected answer is the natural continuation of the prompt
    "I believe the meaning of life is"
]

In [12]:
generator.model

Transformer(
  (tok_embeddings): VocabParallelEmbedding()
  (layers): ModuleList(
    (0-31): 32 x TransformerBlock(
      (attention): Attention(
        (wq): ColumnParallelLinear()
        (wk): ColumnParallelLinear()
        (wv): ColumnParallelLinear()
        (wo): RowParallelLinear()
      )
      (feed_forward): FeedForward(
        (w1): ColumnParallelLinear()
        (w2): RowParallelLinear()
        (w3): ColumnParallelLinear()
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): ColumnParallelLinear()
)

In [9]:
results = generator.text_completion(
    prompts,
    max_gen_len=4,
    temperature=temperature,
    top_p=top_p,
)
for prompt, result in zip(prompts, results):
    print(prompt)
    print(f"> {result['generation']}")
    print("\n==================================\n")

I believe the meaning of life is
>  to find your gift




In [28]:
import torch
import pickle

# Define the function to register hooks
@torch.no_grad()
def get_intermediate_outputs(model, input_ids):
    outputs = {}
    inputs = {}

    def get_activation(name):
      def hook(module, input, output):
          outputs[name] = output#[0].detach()
          inputs[name] = input
      return hook

    hooks = []
    for name, module in model.named_modules():
        #if len(list(module.children())) == 0:  # only register hook on leaf modules
        if name.startswith("layers.0.") or name.startswith("layers.1."):
            hooks.append(module.register_forward_hook(get_activation(name)))

    # Forward pass
    model(input_ids, 0)
    """
    outputs_model = model.generate(
        input_data,
        max_new_tokens=1,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    """

    # Remove hooks
    for hook in hooks:
        hook.remove()

    return outputs, inputs

In [14]:
prompt_tokens = [generator.tokenizer.encode(x, bos=True, eos=False) for x in prompts]

In [15]:
prompt_tokens

[[128000, 40, 4510, 279, 7438, 315, 2324, 374]]

In [19]:
import torch

params = generator.model.params
bsz = len(prompt_tokens)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
assert max_prompt_len <= params.max_seq_len
total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

pad_id = generator.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
for k, t in enumerate(prompt_tokens):
    tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")

prev_pos = 0
eos_reached = torch.tensor([False] * bsz, device="cuda")
input_text_mask = tokens != pad_id
#if min_prompt_len == total_len:
#    logits = self.model.forward(tokens, prev_pos)

In [27]:
prev_pos

0

In [20]:
tokens

tensor([[128000,     40,   4510,    279,   7438,    315,   2324,    374,     -1,
             -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,
             -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,
             -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,
             -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,
             -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,
             -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,
             -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1]])

In [29]:
# Get intermediate outputs
intermediate_outputs, intermediate_inputs = get_intermediate_outputs(generator.model, tokens)

# Save to pickle file
with open('intermediate_data_meta_llama3_8b.pkl', 'wb') as f:
    pickle.dump({"inputs": intermediate_inputs, "outputs": intermediate_outputs}, f)

print("Intermediate outputs have been saved to intermediate_data_meta_llama3_8b.pkl")

Intermediate outputs have been saved to intermediate_data_meta_llama3_8b.pkl


In [32]:
for name, module in generator.model.named_modules():
    if name.startswith("layers.0.") or name.startswith("layers.1."):
        print(name)

layers.0.attention
layers.0.attention.wq
layers.0.attention.wk
layers.0.attention.wv
layers.0.attention.wo
layers.0.feed_forward
layers.0.feed_forward.w1
layers.0.feed_forward.w2
layers.0.feed_forward.w3
layers.0.attention_norm
layers.0.ffn_norm
layers.1.attention
layers.1.attention.wq
layers.1.attention.wk
layers.1.attention.wv
layers.1.attention.wo
layers.1.feed_forward
layers.1.feed_forward.w1
layers.1.feed_forward.w2
layers.1.feed_forward.w3
layers.1.attention_norm
layers.1.ffn_norm


In [35]:
generator.model.state_dict()["layers.0.attention.wq.weight"]

tensor([[-2.7618e-03, -2.9053e-02, -3.1586e-03,  ...,  7.3547e-03,
         -4.6875e-02, -2.1606e-02],
        [ 2.6367e-02,  3.3264e-03, -8.4839e-03,  ..., -7.5378e-03,
         -5.7678e-03,  5.6458e-03],
        [-1.2512e-02, -6.9824e-02, -3.8605e-03,  ..., -1.2573e-02,
         -4.9805e-02,  2.0508e-02],
        ...,
        [-5.2795e-03, -1.4709e-02,  4.1504e-02,  ...,  5.4321e-03,
         -3.2349e-03,  4.4346e-05],
        [ 4.4632e-04,  3.1250e-02, -6.1523e-02,  ..., -2.3804e-03,
          1.1444e-03, -1.8768e-03],
        [-4.1504e-03, -1.6724e-02,  3.0396e-02,  ...,  8.6060e-03,
          8.0872e-04,  3.1433e-03]])

In [36]:
generator.model.state_dict()["layers.0.attention.wk.weight"]

tensor([[-0.1040, -0.1543,  0.0737,  ...,  0.0312, -0.0231,  0.0442],
        [-0.0447, -0.0293,  0.0396,  ...,  0.0067,  0.0242, -0.0035],
        [-0.0564, -0.0869,  0.0188,  ...,  0.0193, -0.0073,  0.0293],
        ...,
        [ 0.0136,  0.0356, -0.0162,  ..., -0.0177,  0.0018,  0.0102],
        [-0.0052, -0.0284,  0.0289,  ...,  0.0135,  0.0055, -0.0042],
        [ 0.0039, -0.0100,  0.0118,  ..., -0.0153,  0.0016, -0.0206]])

In [8]:
generator.model.state_dict()["layers.0.attention.wv.weight"]

tensor([[ 0.0089, -0.0020, -0.0005,  ...,  0.0026,  0.0008,  0.0031],
        [ 0.0002, -0.0040, -0.0001,  ..., -0.0029, -0.0040,  0.0025],
        [ 0.0102,  0.0008,  0.0015,  ..., -0.0062,  0.0080,  0.0070],
        ...,
        [ 0.0079,  0.0008,  0.0029,  ..., -0.0014, -0.0064, -0.0064],
        [ 0.0032,  0.0012,  0.0025,  ...,  0.0027, -0.0046, -0.0011],
        [-0.0024, -0.0070,  0.0017,  ...,  0.0033,  0.0071, -0.0034]])

In [41]:
generator.model.state_dict()['tok_embeddings.weight']

tensor([[ 1.3733e-03,  5.0964e-03, -3.0365e-03,  ...,  2.2888e-03,
         -1.9531e-03, -1.7166e-05],
        [-2.7313e-03,  1.9379e-03, -1.3733e-03,  ..., -5.1498e-05,
         -1.3962e-03, -1.9836e-03],
        [ 9.5367e-04, -1.3367e-02,  4.1771e-04,  ...,  2.5940e-03,
          7.0496e-03,  4.1809e-03],
        ...,
        [ 1.8715e-23,  3.2699e-24,  1.8198e-23,  ...,  5.3767e-23,
         -2.2360e-24, -1.9852e-23],
        [ 1.9335e-23, -1.8612e-24, -1.8818e-23,  ...,  2.3368e-23,
          7.3412e-24, -3.1226e-23],
        [-7.4860e-23, -6.3693e-23,  5.5059e-24,  ...,  4.9631e-24,
         -5.4594e-23, -2.2877e-24]])

In [6]:
generator.model.state_dict()['layers.0.feed_forward.w1.weight']

tensor([[-0.0121, -0.0051, -0.0036,  ...,  0.0149, -0.0134, -0.0030],
        [-0.0067, -0.0267, -0.0032,  ...,  0.0131,  0.0046, -0.0016],
        [ 0.0110, -0.0005,  0.0135,  ..., -0.0006,  0.0047,  0.0050],
        ...,
        [-0.0008, -0.0114, -0.0102,  ..., -0.0117,  0.0050, -0.0177],
        [-0.0045, -0.0008, -0.0041,  ..., -0.0183, -0.0143,  0.0048],
        [-0.0057,  0.0095,  0.0055,  ..., -0.0063,  0.0157, -0.0043]])

In [7]:
generator.model.state_dict()['layers.0.feed_forward.w2.weight']

tensor([[ 0.0087, -0.0151, -0.0090,  ...,  0.0079, -0.0039,  0.0134],
        [ 0.0204, -0.0107, -0.0057,  ...,  0.0010,  0.0172,  0.0011],
        [ 0.0082, -0.0075, -0.0023,  ..., -0.0018,  0.0025, -0.0165],
        ...,
        [ 0.0085, -0.0208,  0.0217,  ..., -0.0199,  0.0081, -0.0129],
        [-0.0135, -0.0059, -0.0110,  ...,  0.0093,  0.0015, -0.0131],
        [-0.0029,  0.0069,  0.0085,  ..., -0.0082, -0.0051, -0.0120]])

In [16]:
import json
def read_json(path):
    with open(path, "r") as f:
        return json.load(f)

num_shards = 1
params = read_json("../Meta-Llama-3-8B/original/params.json")
params = params.get("model", params)
n_layers = params["n_layers"]
n_heads = params["n_heads"]
n_heads_per_shard = n_heads // num_shards
dim = params["dim"]
dims_per_head = dim // n_heads

if params.get("n_kv_heads", None) is not None:
    num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
    num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
    key_value_dim = dim // num_key_value_heads
else:  # compatibility with other checkpoints
    num_key_value_heads = n_heads
    num_local_key_value_heads = n_heads_per_shard
    key_value_dim = dim

def permute(w, n_heads, dim1=dim, dim2=dim):
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

In [26]:
def hf_undo_permute(w, n_heads, dim1=dim, dim2=dim):
    return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)

In [13]:
dim

4096

In [18]:
num_key_value_heads

8

In [20]:
wq_weight = generator.model.state_dict()["layers.0.attention.wq.weight"].clone()

In [22]:
wq_weight.shape

torch.Size([4096, 4096])

In [24]:
wq_weight.view(n_heads, dim // n_heads // 2, 2, dim).shape

torch.Size([32, 64, 2, 4096])

In [25]:
wq_weight.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).shape

torch.Size([32, 2, 64, 4096])

In [None]:
f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
    loaded[f"layers.{layer_i}.attention.wq.weight"], n_heads=n_heads
),
f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
    loaded[f"layers.{layer_i}.attention.wk.weight"],
    n_heads=num_key_value_heads,
    dim1=dim // num_local_key_value_heads,
),

In [32]:
generator.model.state_dict()["layers.0.attention.wk.weight"].shape

torch.Size([1024, 4096])

In [36]:
dim1=dim // num_local_key_value_heads
dim2=dim
print(generator.model.state_dict()["layers.0.attention.wk.weight"].view(n_heads, dim1 // n_heads // 2, 2, dim2).shape)
print(generator.model.state_dict()["layers.0.attention.wk.weight"].view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).shape)

torch.Size([32, 16, 2, 4096])
torch.Size([32, 2, 16, 4096])


In [30]:
n_heads_per_shard

32

In [31]:
num_key_value_heads

8

In [41]:
generator.model.state_dict()["layers.0.attention.wq.weight"]

tensor([[-2.7618e-03, -2.9053e-02, -3.1586e-03,  ...,  7.3547e-03,
         -4.6875e-02, -2.1606e-02],
        [ 2.6367e-02,  3.3264e-03, -8.4839e-03,  ..., -7.5378e-03,
         -5.7678e-03,  5.6458e-03],
        [-1.2512e-02, -6.9824e-02, -3.8605e-03,  ..., -1.2573e-02,
         -4.9805e-02,  2.0508e-02],
        ...,
        [-5.2795e-03, -1.4709e-02,  4.1504e-02,  ...,  5.4321e-03,
         -3.2349e-03,  4.4346e-05],
        [ 4.4632e-04,  3.1250e-02, -6.1523e-02,  ..., -2.3804e-03,
          1.1444e-03, -1.8768e-03],
        [-4.1504e-03, -1.6724e-02,  3.0396e-02,  ...,  8.6060e-03,
          8.0872e-04,  3.1433e-03]])