In [None]:
#!pip install transformers
#!pip install sentencepiece

In [1]:
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM

In [3]:
model_name = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = LlamaTokenizer.from_pretrained(model_name)
model = LlamaForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16
)


model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

In [4]:
class LayerActivationExtractor:
    def __init__(self):
        self.activations = {}
        
    def get_activation(self, name):
        def hook(model, input, output):
            self.activations[name] = output
        return hook
    
    def register_hooks(self, model):
        hooks = []
        for i, layer in enumerate(model.model.layers):
            hook = layer.register_forward_hook(self.get_activation(f'layer_{i}'))
            hooks.append(hook)
        return hooks
    
    def remove_hooks(self, hooks):
        for hook in hooks:
            hook.remove()

In [5]:
extractor = LayerActivationExtractor()
hooks = extractor.register_hooks(model)

In [6]:
text = "Hello, how are you?"
inputs = tokenizer(text, return_tensors="pt")

In [7]:
with torch.no_grad():
    outputs = model(**inputs)

In [11]:
outputs

CausalLMOutputWithPast(loss=None, logits=tensor([[[  0.1040,  -0.2217,   0.3130,  ...,   1.3271,   1.8799,   0.6436],
         [ -7.5156,  -2.1777,  -1.1445,  ...,  -6.3711,  -4.6445,  -7.4688],
         [ -7.5078,  -2.4102,   1.2578,  ...,  -3.6191,  -2.8809,  -5.0898],
         ...,
         [ -1.0938,   0.5630,   7.4570,  ...,   1.8789,   0.3020,  -0.4321],
         [ -2.3301,   0.8994,   5.9570,  ...,  -0.1826,  -1.5830,  -2.1504],
         [-10.2422,  -6.4375,   5.2031,  ...,  -4.2344,  -4.0898,  -5.2734]]],
       dtype=torch.float16), past_key_values=DynamicCache(layers=[<transformers.cache_utils.DynamicLayer object at 0x0000019EA8E204D0>, <transformers.cache_utils.DynamicLayer object at 0x0000019EA818E0C0>, <transformers.cache_utils.DynamicLayer object at 0x0000019EA818FC50>, <transformers.cache_utils.DynamicLayer object at 0x0000019EA818E840>, <transformers.cache_utils.DynamicLayer object at 0x0000019EA818C950>, <transformers.cache_utils.DynamicLayer object at 0x0000019EA818FD

In [8]:
for layer_name, activation in extractor.activations.items():
    print(f"{layer_name}: {activation[0].shape}")

layer_0: torch.Size([7, 4096])
layer_1: torch.Size([7, 4096])
layer_2: torch.Size([7, 4096])
layer_3: torch.Size([7, 4096])
layer_4: torch.Size([7, 4096])
layer_5: torch.Size([7, 4096])
layer_6: torch.Size([7, 4096])
layer_7: torch.Size([7, 4096])
layer_8: torch.Size([7, 4096])
layer_9: torch.Size([7, 4096])
layer_10: torch.Size([7, 4096])
layer_11: torch.Size([7, 4096])
layer_12: torch.Size([7, 4096])
layer_13: torch.Size([7, 4096])
layer_14: torch.Size([7, 4096])
layer_15: torch.Size([7, 4096])
layer_16: torch.Size([7, 4096])
layer_17: torch.Size([7, 4096])
layer_18: torch.Size([7, 4096])
layer_19: torch.Size([7, 4096])
layer_20: torch.Size([7, 4096])
layer_21: torch.Size([7, 4096])
layer_22: torch.Size([7, 4096])
layer_23: torch.Size([7, 4096])
layer_24: torch.Size([7, 4096])
layer_25: torch.Size([7, 4096])
layer_26: torch.Size([7, 4096])
layer_27: torch.Size([7, 4096])
layer_28: torch.Size([7, 4096])
layer_29: torch.Size([7, 4096])
layer_30: torch.Size([7, 4096])
layer_31: torch.Si

In [10]:
model.eval()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_e