# LLaMA in TransformerLens

This notebook needs the main branch of `transformers`. This requirement will be dropped once LLaMA support is included in a stable release.

In [None]:
!pip install git+https://github.com/huggingface/transformers

## Setup (skip)

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    %pip install circuitsvis
    
    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Jupyter notebook - intended for development only!


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [2]:
import plotly.io as pio
import circuitsvis as cv

# Plotly needs a different renderer for VSCode/Notebooks vs Colab:
# the "colab" renderer is used in Colab and whenever DEVELOPMENT_MODE is off,
# otherwise "notebook_connected" is used.
pio.renderers.default = (
    "colab" if (IN_COLAB or not DEVELOPMENT_MODE) else "notebook_connected"
)
print(f"Using renderer: {pio.renderers.default}")

Using renderer: colab


In [3]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
from tqdm import tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from torchtyping import TensorType as TT
from typing import List, Union, Optional
from jaxtyping import Float, Int
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
# import circuitsvis as cv

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

# Turn autograd off globally: this notebook only runs forward passes.
torch.set_grad_enabled(False)

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Display `tensor` as a Plotly heatmap on the "RdBu" colour scale with midpoint 0.

    Args:
        tensor: Data to plot; converted with `utils.to_numpy` before plotting.
        renderer: Passed to the figure's `.show()` call.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Extra keyword arguments forwarded to `px.imshow`.
    """
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels={"x": xaxis, "y": yaxis},
        **kwargs,
    )
    figure.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Display `tensor` as a Plotly line chart.

    Args:
        tensor: Data to plot; converted with `utils.to_numpy` before plotting.
        renderer: Passed to the figure's `.show()` call.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Extra keyword arguments forwarded to `px.line`.
    """
    values = utils.to_numpy(tensor)
    axis_labels = {"x": xaxis, "y": yaxis}
    figure = px.line(values, labels=axis_labels, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Display a Plotly scatter plot of `y` against `x`.

    Args:
        x: Horizontal coordinates; converted with `utils.to_numpy`.
        y: Vertical coordinates; converted with `utils.to_numpy`.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        caxis: Label for the colour axis.
        renderer: Passed to the figure's `.show()` call.
        **kwargs: Extra keyword arguments forwarded to `px.scatter`.
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    figure.show(renderer)

## Loading model

In [7]:
# Load LLaMA-7B on the CPU with every weight-processing step switched off
# (no LayerNorm folding, no centering of writing weights or of the unembed).
model = HookedTransformer.from_pretrained(
    "llama-7b",
    device="cpu",
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
)
# model = HookedTransformer.from_pretrained("llama-7b", device="cpu")

# Smoke test: generate a short continuation; being the last expression, it is shown as the cell output.
model.generate("The capital of Germany is", max_new_tokens=20, temperature=0)

Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. 
The class this function is called from is 'LlamaTokenizer'.


Loaded pretrained model llama-7b into HookedTransformer


  0%|          | 0/20 [00:00<?, ?it/s]

' The capital of Germany is Berlin.\nThe capital of Germany is Berlin. The capital of Germany is Berlin. The capital of'

### Compare logits with HuggingFace model

You have to use the main branch of the Hugging Face `transformers` repository here, not the stable PyPI release.

In [9]:
from transformers import LlamaForCausalLM, LlamaTokenizer

# Reference tokenizer and model straight from Hugging Face, used to validate the TransformerLens port.
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
hf_model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    low_cpu_mem_usage=True,
)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. 
The class this function is called from is 'LlamaTokenizer'.


Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

In [10]:
# Prompts for the comparison: ordinary text plus one long gibberish string.
prompts = [
    "The capital of Germany is",
    "2 * 42 = ", 
    "My favorite", 
    "aosetuhaosuh aostud aoestuaoentsudhasuh aos tasat naostutshaosuhtnaoe usaho uaotsnhuaosntuhaosntu haouaoshat u saotheu saonuh aoesntuhaosut aosu thaosu thaoustaho usaothusaothuao sutao sutaotduaoetudet uaosthuao uaostuaoeu aostouhsaonh aosnthuaoscnuhaoshkbaoesnit haosuhaoe uasotehusntaosn.p.uo ksoentudhao ustahoeuaso usant.hsa otuhaotsi aostuhs",
]

model.eval()
hf_model.eval()

# Tokenize once with the Hugging Face tokenizer, then run each model over every prompt.
prompt_ids = [tokenizer.encode(text, return_tensors="pt") for text in prompts]
logits = [hf_model(ids).logits.detach().cpu() for ids in tqdm(prompt_ids)]
tl_logits = [model(text).detach().cpu() for text in tqdm(prompts)]

# One True/False line per prompt: do both implementations agree to within 1e-3?
for hf_out, tl_out in zip(logits, tl_logits):
    print(torch.allclose(hf_out, tl_out, atol=1e-3))

100%|██████████| 4/4 [00:05<00:00,  1.42s/it]
100%|██████████| 4/4 [00:13<00:00,  3.48s/it]

True
True
True
True





## TransformerLens Demo

### Reading from hooks

In [11]:
llama_text = "Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets."

# Token ids feed the model; string tokens label the visualisation.
llama_tokens = model.to_tokens(llama_text)
llama_str_tokens = model.to_str_tokens(llama_text)

# Run once while caching activations (batch dimension removed from the cached tensors).
llama_logits, llama_cache = model.run_with_cache(
    llama_tokens,
    remove_batch_dim=True,
)

# Attention pattern of layer 0, read back from the cache.
attention_pattern = llama_cache["pattern", 0, "attn"]

print("Layer 0 Head Attention Patterns:")
cv.attention.attention_patterns(tokens=llama_str_tokens, attention=attention_pattern)

Layer 0 Head Attention Patterns:


### Writing to hooks

In [20]:
# Which attention head to knock out: layer index and head index within that layer.
layer_to_ablate = 0
head_index_to_ablate = 31

# We define a head ablation hook.
# The type annotations are NOT necessary, they're just a useful guide to the reader.
def head_ablation_hook(
    value: Float[torch.Tensor, "batch pos head_index d_head"],
    hook: HookPoint
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Zero out the slice of `value` belonging to head `head_index_to_ablate`.

    The tensor is modified in place and then returned. The shape of `value`
    is printed on every call.

    Args:
        value: Activation tensor handed to the hook; indexed as
            [batch, pos, head_index, d_head].
        hook: The hook point the function is attached to (not used here).

    Returns:
        The same tensor, with the selected head's entries set to zero.
    """
    print(f"Shape of the value tensor: {value.shape}")
    # In-place write: all batch items and positions of the chosen head become 0.
    value[:, :, head_index_to_ablate, :] = 0.
    return value

# Baseline loss with no intervention.
original_loss = model(llama_tokens, return_type="loss")

# Same forward pass, with the ablation hook attached to the "v" activation of the chosen layer.
ablated_loss = model.run_with_hooks(
    llama_tokens,
    return_type="loss",
    fwd_hooks=[(utils.get_act_name("v", layer_to_ablate), head_ablation_hook)],
)

print(f"Original Loss: {original_loss.item():.3f}")
print(f"Ablated Loss: {ablated_loss.item():.3f}")

Shape of the value tensor: torch.Size([1, 34, 32, 128])
Original Loss: 2.315
Ablated Loss: 2.306


## Try float16

In [4]:
from transformers import LlamaForCausalLM, LlamaTokenizer

# Hugging Face tokenizer and model; the model's weights are converted for the manually built
# HookedTransformer further down.
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
hf_model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    low_cpu_mem_usage=True,
)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. 
The class this function is called from is 'LlamaTokenizer'.


Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

In [5]:
# Hand-written TransformerLens config for the 7B model loaded above.
cfg = HookedTransformerConfig(
    # --- model dimensions ---
    n_layers=32,
    n_heads=32,
    d_model=4096,
    d_head=128,
    d_mlp=11008,
    d_vocab=32000,
    n_ctx=2048,  # max_context_length from rotary_embedding
    # --- normalisation and MLP ---
    normalization_type="RMS",  # at start of decoder and after attention
    final_rms=True,
    eps=1e-6,
    act_fn="silu",
    gated_mlp=True,
    # --- positional embedding ---
    positional_embedding_type="rotary",
    rotary_dim=128,
    # --- bookkeeping ---
    # NOTE(review): labelled "GPTNeoX" although the weights come from a LLaMA checkpoint — confirm this is intended.
    original_architecture="GPTNeoX",
    init_weights=False,
    seed=42,
    device='cpu',
)
tl_model = HookedTransformer(cfg)

In [6]:
# Convert the Hugging Face model's weights into a state dict for the TransformerLens model built from `cfg`.
tl_state_dict = transformer_lens.loading_from_pretrained.convert_llama_weights(hf_model, cfg)

In [7]:
# Load the converted weights with every optional processing step disabled.
tl_model.load_and_process_state_dict(
    tl_state_dict,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    fold_value_biases=False,
)

In [8]:
tl_model.set_tokenizer(tokenizer)

# add_special_tokens registers '[PAD]' as the pad token as a side effect and returns the
# NUMBER of tokens it added (an int). Assigning that return value to `pad_token`, as the
# cell did before, overwrote the pad token with an integer.
# NOTE(review): cfg.d_vocab is 32000; if '[PAD]' is appended as a new vocabulary entry its id
# may fall outside the embedding matrix — confirm before relying on padding.
tl_model.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
tl_model.to_tokens("Hi")

tensor([[   1, 6324]])

In [9]:
# Cast the model to half precision, then move it to the GPU.
tl_model.to(torch.float16)
# Being the last expression of the cell, the return value of this call (the model repr) is displayed.
tl_model.to('cuda')

Changing model dtype to torch.float16
Moving model to device:  cuda


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
        (hook_out): HookPoint()
      )
      (ln2): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
        (hook_out): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): GatedMLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_re

In [10]:
# Generate up to 20 new tokens from the float16 model as a quick sanity check.
tl_model.generate("The capital of Germany is", max_new_tokens=20, temperature=0)

  0%|          | 0/20 [00:00<?, ?it/s]

In [15]:
prompt = "The capital of Germany is"
prompt_ids = model.tokenizer(prompt, return_tensors="pt").input_ids

tl_model.eval()
hf_model.eval()

# tl_model was cast to float16 and moved to CUDA above, while hf_model was loaded without any
# cast or move. torch.allclose needs both tensors on one device with one dtype, so bring
# both outputs to CPU float32 before comparing.
logits = tl_model(prompt_ids)[0].detach().float().cpu()
hf_logits = hf_model(prompt_ids).logits.squeeze().detach().float().cpu()
print(logits.shape, hf_logits.shape)
print(logits[5,:10], hf_logits[5,:10])
# print norm difference
print(f"Norm of difference: {(logits - hf_logits).norm().item():.4f}")
# NOTE(review): the recorded output shows these logits differing by roughly 0.1-0.3, so the
# 1e-3 tolerance below looks too tight for the float16 model — confirm the intended tolerance.
assert torch.allclose(logits, hf_logits, atol=1e-3), "TransformerLens and Hugging Face logits differ"

torch.Size([6, 32000]) torch.Size([6, 32000])
tensor([ -8.2931, -11.5887,   2.8618,  -0.8909,  -4.6046,  -5.0958,  -2.4957,
         -2.8411,  -3.3433,  -3.3339]) tensor([ -8.1932, -11.4758,   3.1597,  -0.7884,  -4.5046,  -4.9464,  -2.4317,
         -2.7070,  -3.2581,  -3.0145])


In [29]:
# Generate up to 20 new tokens from tl_model; the returned string is shown as the cell output.
tl_model.generate("The capital of Germany is", max_new_tokens=20, temperature=0)

  0%|          | 0/20 [00:00<?, ?it/s]

' The capital of Germany is Berlin.\nThe capital of Germany is Berlin. The capital of Germany is Berlin. The capital of'

In [30]:
# Activation caches of the manually built model (hs_0) and the from_pretrained model (hs_1).
hs_0 = tl_model.run_with_cache(prompt)[1]
hs_1 = model.run_with_cache(prompt)[1]

for key in hs_0.keys():
    # tl_model runs in float16 on CUDA while `model` was loaded on the CPU; torch.allclose
    # rejects mismatched devices/dtypes, so compare both activations as CPU float32.
    tl_act = hs_0[key].detach().float().cpu()
    ref_act = hs_1[key].detach().float().cpu()
    # NOTE(review): atol=1e-3 may be too strict for float16 activations — confirm.
    assert torch.allclose(tl_act, ref_act, atol=1e-3), f"{key} not close"