In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
import torch
import pickle

# Define the function to register hooks
@torch.no_grad()
def get_intermediate_outputs(model, input_data, terminators):
    outputs = {}
    inputs = {}

    def get_activation(name):
      def hook(module, input, output):
          outputs[name] = output#[0].detach()
          inputs[name] = input
      return hook

    hooks = []
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # only register hook on leaf modules
            if "model.layers.0." in name or "model.layers.1." in name:
                hooks.append(module.register_forward_hook(get_activation(name)))

    # Forward pass
    model(input_ids)
    """
    outputs_model = model.generate(
        input_data,
        max_new_tokens=1,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    """

    # Remove hooks
    for hook in hooks:
        hook.remove()

    return outputs, inputs

In [3]:
model_id = "../Meta-Llama-3-8B/"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
)#.to(model.device)

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]


No chat template is defined for this tokenizer - using a default chat template that implements the ChatML format (without BOS/EOS tokens!). If the default is not appropriate for your model, please set `tokenizer.chat_template` to an appropriate template. See https://huggingface.co/docs/transformers/main/chat_templating for more information.



In [5]:
# Get intermediate outputs
intermediate_outputs, intermediate_inputs = get_intermediate_outputs(model, input_ids, terminators)

# Save to pickle file
with open('intermediate_data_llama3_8b.pkl', 'wb') as f:
    pickle.dump({"inputs": intermediate_inputs, "outputs": intermediate_outputs}, f)

print("Intermediate outputs have been saved to intermediate_data_llama3_8b.pkl")

Intermediate outputs have been saved to intermediate_data_llama3_8b.pkl


In [6]:
intermediate_inputs['model.layers.0.input_layernorm']

(tensor([[[ 0.0086, -0.0072,  0.0017,  ...,  0.0080,  0.0008, -0.0046],
          [ 0.0019,  0.0092,  0.0042,  ...,  0.0089,  0.0005, -0.0122],
          [ 0.0027, -0.0045, -0.0003,  ..., -0.0064, -0.0031,  0.0022],
          ...,
          [ 0.0089, -0.0010, -0.0006,  ...,  0.0103,  0.0020, -0.0016],
          [-0.0089, -0.0030,  0.0049,  ...,  0.0003, -0.0013,  0.0082],
          [-0.0031,  0.0015,  0.0018,  ..., -0.0017,  0.0006,  0.0023]]],
        dtype=torch.bfloat16),)

In [7]:
intermediate_outputs['model.layers.0.input_layernorm'], intermediate_outputs['model.layers.0.input_layernorm'].shape

(tensor([[[ 0.0579, -0.1865,  0.0942,  ...,  0.0859,  0.0042, -0.0166],
          [ 0.0142,  0.2637,  0.2617,  ...,  0.1040,  0.0027, -0.0488],
          [ 0.0204, -0.1328, -0.0201,  ..., -0.0781, -0.0194,  0.0092],
          ...,
          [ 0.0640, -0.0273, -0.0337,  ...,  0.1187,  0.0116, -0.0063],
          [-0.0457, -0.0596,  0.2100,  ...,  0.0023, -0.0055,  0.0229],
          [-0.0287,  0.0540,  0.1436,  ..., -0.0261,  0.0042,  0.0117]]],
        dtype=torch.bfloat16),
 torch.Size([1, 51, 4096]))

In [7]:
for name, module in model.named_modules():
    if len(list(module.children())) == 0:  # only register hook on leaf modules
        if "model.layers.1" in name or "model.layers.0" in name:
        print(name, module)

model.embed_tokens Embedding(128256, 4096)
model.layers.0.self_attn.q_proj Linear(in_features=4096, out_features=4096, bias=False)
model.layers.0.self_attn.k_proj Linear(in_features=4096, out_features=1024, bias=False)
model.layers.0.self_attn.v_proj Linear(in_features=4096, out_features=1024, bias=False)
model.layers.0.self_attn.o_proj Linear(in_features=4096, out_features=4096, bias=False)
model.layers.0.self_attn.rotary_emb LlamaRotaryEmbedding()
model.layers.0.mlp.gate_proj Linear(in_features=4096, out_features=14336, bias=False)
model.layers.0.mlp.up_proj Linear(in_features=4096, out_features=14336, bias=False)
model.layers.0.mlp.down_proj Linear(in_features=14336, out_features=4096, bias=False)
model.layers.0.mlp.act_fn SiLU()
model.layers.0.input_layernorm LlamaRMSNorm()
model.layers.0.post_attention_layernorm LlamaRMSNorm()
model.layers.1.self_attn.q_proj Linear(in_features=4096, out_features=4096, bias=False)
model.layers.1.self_attn.k_proj Linear(in_features=4096, out_feature