# LLaMA and Llama-2 in TransformerLens

This demo requires `transformers` version 4.31.0, which adds Llama-2 support. The tutorial has two tracks: steps marked (a) are for the original LLaMA and steps marked (b) are for Llama-2. Llama-2 support is currently limited to the 7B chat model while this notebook is still being tested.

Steps to run this demo:

1a. Get LLaMA weights here: https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform

1b. Get Llama-2 weights here: https://ai.meta.com/resources/models-and-libraries/llama-downloads/

2a. Convert the official weights to the Hugging Face format.

```bash
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
    --input_dir /path/to/downloaded/llama/weights \
    --model_size 7B \
    --output_dir /output/path
```

2b. Do the same for Llama-2. We'll use `7Bf`, the 7B chat version.

```bash
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
    --input_dir /path/to/downloaded/llama-2/weights \
    --model_size 7Bf \
    --output_dir /output/path
```

Note: this didn't work for Arthur by default (even though HF doesn't seem to document this anywhere). I had to change <a href="https://github.com/huggingface/transformers/blob/07360b6/src/transformers/models/llama/convert_llama_weights_to_hf.py#L295">this</a> line of my pip-installed `src/transformers/models/llama/convert_llama_weights_to_hf.py` file (which was found at `/opt/conda/envs/arthurenv/lib/python3.10/site-packages/transformers/models/llama/convert_llama_weights_to_hf.py`) from

`input_base_path=os.path.join(args.input_dir, args.model_size),` to `input_base_path=os.path.join(args.input_dir),`
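To see what this one-line change does, the sketch below compares the two `os.path.join` calls. The directory is the placeholder from step 2b, not a real location. Presumably the patch is needed when the downloaded files sit directly in `--input_dir` rather than in a subfolder named after the model size; that reading is an assumption.

```python
import os

# Placeholder values taken from the command in step 2b (assumed layout, not a real path).
input_dir = "/path/to/downloaded/llama-2/weights"
model_size = "7Bf"

# Original line: the script looks inside a subfolder named after the model size.
original_base_path = os.path.join(input_dir, model_size)

# Patched line: the script looks directly inside --input_dir.
patched_base_path = os.path.join(input_dir)

print(original_base_path)
print(patched_base_path)
```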

3. Change the `MODEL_PATH` variable in the notebook to the directory where the converted weights are stored.
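Before loading, it can help to confirm that the directory really looks like a converted checkpoint. The following is a minimal sketch: `missing_files` is a hypothetical helper, and the two file names are an assumption about what the conversion script writes (the real output contains more files).

```python
from pathlib import Path

# Assumed file names: a converted Llama checkpoint is expected to contain at least these.
EXPECTED_FILES = ("config.json", "tokenizer.model")

def missing_files(model_path, expected=EXPECTED_FILES):
    """Return the expected file names that are absent from model_path."""
    root = Path(model_path)
    return [name for name in expected if not (root / name).is_file()]

# Example with a path that does not exist: every expected file is reported missing.
print(missing_files("/output/path/that/does/not/exist"))
```

An empty list means all of the assumed files were found.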

In [1]:
from typing import Literal

MODE: Literal["LLaMA", "Llama-2"] = "Llama-2" # change to LLaMA for original LLaMA
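Python does not enforce `Literal` annotations at runtime, so a typo in `MODE` would only surface later, when neither loading branch runs. A small optional guard, sketched here with a hypothetical `check_mode` helper and `Mode` alias, catches that early:

```python
from typing import Literal, get_args

Mode = Literal["LLaMA", "Llama-2"]  # hypothetical alias mirroring the annotation above

def check_mode(mode):
    """Raise ValueError unless mode is one of the values listed in Mode."""
    allowed = get_args(Mode)
    if mode not in allowed:
        raise ValueError(f"MODE must be one of {allowed}, got {mode!r}")
    return mode

check_mode("Llama-2")  # passes silently
```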

In [2]:
!pip install transformers==4.31.0
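Since the demo depends on a specific `transformers` release, a quick check like the following can confirm what is installed. This is a minimal sketch: `parse_version` and `has_llama2_support` are hypothetical helpers, and treating any release from 4.31.0 onward as acceptable is an assumption (the notebook itself pins exactly 4.31.0).

```python
from importlib.metadata import PackageNotFoundError, version

def parse_version(text):
    """Turn '4.31.0' into (4, 31, 0); stops at suffixes such as 'dev0'."""
    parts = []
    for piece in text.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    while len(parts) < 3:  # pad short versions like '4.31'
        parts.append(0)
    return tuple(parts)

def has_llama2_support(installed):
    """Assumption: every release >= 4.31.0 keeps the Llama-2 support added in 4.31.0."""
    return parse_version(installed) >= (4, 31, 0)

try:
    installed = version("transformers")
    print(installed, has_llama2_support(installed))
except PackageNotFoundError:
    print("transformers is not installed")
```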

## Setup (skip)

In [3]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    %pip install circuitsvis
    
    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Jupyter notebook - intended for development only!


In [4]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

import circuitsvis as cv

Using renderer: colab


In [5]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
from tqdm.auto import tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from torchtyping import TensorType as TT
from typing import List, Union, Optional
from jaxtyping import Float, Int
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
# import circuitsvis as cv

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

torch.set_grad_enabled(False)

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

## Loading model

In [6]:
from transformers import LlamaForCausalLM, LlamaTokenizer
import os

MODEL_PATH=''

if "CONDA_PREFIX" in os.environ and "arthur" in os.environ["CONDA_PREFIX"]: # so Arthur can test fast
    MODEL_PATH=os.path.expanduser('~/lam_out')

tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
hf_model = LlamaForCausalLM.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. 
The class this function is called from is 'LlamaTokenizer'.
You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565


Some weights of LlamaForCausalLM were not initialized from the model checkpoint at decapoda-research/llama-7b-hf and are newly initialized: ['model.layers.30.self_attn.rotary_emb.cos_cached', 'model.layers.19.self_attn.rotary_emb.sin_cached', 'model.layers.8.self_attn.rotary_emb.sin_cached', 'model.layers.23.self_attn.rotary_emb.cos_cached', 'model.layers.1.self_attn.rotary_emb.sin_cached', 'model.layers.20.self_attn.rotary_emb.cos_cached', 'model.layers.7.self_attn.rotary_emb.cos_cached', 'model.layers.10.self_attn.rotary_emb.sin_cached', 'model.layers.17.self_attn.rotary_emb.sin_cached', 'model.layers.4.self_attn.rotary_emb.sin_cached', 'model.layers.28.self_attn.rotary_emb.sin_cached', 'model.layers.21.self_attn.rotary_emb.cos_cached', 'model.layers.18.self_attn.rotary_emb.cos_cached', 'model.layers.12.self_attn.rotary_emb.sin_cached', 'model.layers.14.self_attn.rotary_emb.sin_cached', 'model.layers.24.self_attn.rotary_emb.sin_cached', 'model.layers.25.self_attn.rotary_emb.cos_cache

In [8]:
if MODE == "LLaMA":
    model = HookedTransformer.from_pretrained("llama-7b", hf_model=hf_model, device="cpu", fold_ln=False, center_writing_weights=False, center_unembed=False)

elif MODE == "Llama-2":
    model = HookedTransformer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", hf_model=hf_model, device="cpu", fold_ln=False, center_writing_weights=False, center_unembed=False, fold_value_biases=False) # loading on CPU is cheapest memory wise in transformer_lens
    
model = model.to("cuda") # makes generation a lot faster
model.tokenizer = tokenizer
model.tokenizer.add_special_tokens({"pad_token": "[PAD]"})

model.generate("The capital of Germany is", max_new_tokens=20, temperature=0)

Loaded pretrained model llama-7b into HookedTransformer
Moving model to device:  cuda


'The capital of Germany is Berlin. The capital of France is Paris. The capital of Italy is Rome. The capital of Spain'

### Compare logits with HuggingFace model

In [9]:
prompts = [
    "The capital of Germany is",
    "2 * 42 = ", 
    "My favorite", 
    "aosetuhaosuh aostud aoestuaoentsudhasuh aos tasat naostutshaosuhtnaoe usaho uaotsnhuaosntuhaosntu haouaoshat u saotheu saonuh aoesntuhaosut aosu thaosu thaoustaho usaothusaothuao sutao sutaotduaoetudet uaosthuao uaostuaoeu aostouhsaonh aosnthuaoscnuhaoshkbaoesnit haosuhaoe uasotehusntaosn.p.uo ksoentudhao ustahoeuaso usant.hsa otuhaotsi aostuhs",
]

model.eval()
hf_model.eval()
prompt_ids = [tokenizer.encode(prompt, return_tensors="pt") for prompt in prompts]
tl_logits = [model(ids).detach().cpu() for ids in tqdm(prompt_ids)]

# The HF logits are really slow to compute because hf_model is on the CPU. If you have a big/multi-GPU machine, run `hf_model = hf_model.to("cuda")` to speed this up
logits = [hf_model(ids).logits.detach().cpu() for ids in tqdm(prompt_ids)]

for i in range(len(prompts)):
    assert torch.allclose(logits[i], tl_logits[i], atol=1e-3, rtol=1e-3)

100%|██████████| 4/4 [00:00<00:00,  9.75it/s]
100%|██████████| 4/4 [01:14<00:00, 18.59s/it]
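The assertion in this cell passes when every pair of logits satisfies the criterion documented for `torch.allclose`, `|a - b| <= atol + rtol * |b|`. The pure-Python sketch below applies the same rule to plain lists so the tolerance can be reasoned about without loading a model. `within_tolerance` is a hypothetical helper and the sample numbers are made up for illustration.

```python
def within_tolerance(a, b, atol=1e-3, rtol=1e-3):
    """Elementwise |a - b| <= atol + rtol * |b|, the rule torch.allclose documents."""
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))

# Made-up logits, not taken from either model.
hf_like = [10.0, -2.5, 0.0]
tl_like = [10.005, -2.5004, 0.0009]

# Small differences pass; the allowed error grows with the magnitude of b.
print(within_tolerance(hf_like, tl_like))

# A difference of 0.02 on a logit near 10 exceeds 1e-3 + 1e-3 * 10.02 and fails.
print(within_tolerance(hf_like, [10.02, -2.5, 0.0]))
```

Because of the relative term, large logits are allowed a larger absolute error than logits near zero.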


## TransformerLens Demo

### Reading from hooks

In [10]:
llama_text = "Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets."
llama_tokens = model.to_tokens(llama_text)
llama_logits, llama_cache = model.run_with_cache(llama_tokens, remove_batch_dim=True)

attention_pattern = llama_cache["pattern", 0, "attn"]
llama_str_tokens = model.to_str_tokens(llama_text)

print("Layer 0 Head Attention Patterns:")
cv.attention.attention_patterns(tokens=llama_str_tokens, attention=attention_pattern)

Layer 0 Head Attention Patterns:
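The tuple index `llama_cache["pattern", 0, "attn"]` is shorthand for a full hook name. As a rough illustration, assuming TransformerLens's `blocks.{layer}.attn.hook_{name}` naming scheme for attention activations, the hypothetical helper below builds the same strings by hand. It is a simplified stand-in for `utils.get_act_name`, not the library function itself.

```python
def attn_hook_name(name, layer):
    """Hypothetical stand-in: full hook name for an attention activation."""
    return f"blocks.{layer}.attn.hook_{name}"

# Name of the cached layer-0 attention pattern read above.
print(attn_hook_name("pattern", 0))

# Name of the layer-0 value activation.
print(attn_hook_name("v", 0))
```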


### Writing to hooks

In [11]:
layer_to_ablate = 0
head_index_to_ablate = 31

# We define a head ablation hook.
# The type annotations are NOT necessary, they're just a useful guide to the reader.
def head_ablation_hook(
    value: Float[torch.Tensor, "batch pos head_index d_head"],
    hook: HookPoint
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    print(f"Shape of the value tensor: {value.shape}")
    value[:, :, head_index_to_ablate, :] = 0.
    return value

original_loss = model(llama_tokens, return_type="loss")
ablated_loss = model.run_with_hooks(
    llama_tokens,
    return_type="loss",
    fwd_hooks=[(utils.get_act_name("v", layer_to_ablate), head_ablation_hook)],
)
print(f"Original Loss: {original_loss.item():.3f}")
print(f"Ablated Loss: {ablated_loss.item():.3f}")

Shape of the value tensor: torch.Size([1, 34, 32, 128])
Original Loss: 2.908
Ablated Loss: 2.971
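
To put these numbers in context, the arithmetic below uses only the values printed above. It assumes the loss is the mean next-token cross-entropy in nats (so perplexity is `exp(loss)`), and it reads the last two dimensions of the printed shape as the number of heads and the per-head width.

```python
import math

original_loss = 2.908            # printed above
ablated_loss = 2.971             # printed above
value_shape = (1, 34, 32, 128)   # batch, pos, head_index, d_head (printed above)

loss_increase = ablated_loss - original_loss
relative_increase = loss_increase / original_loss
perplexity_ratio = math.exp(loss_increase)  # assumes the loss is measured in nats

n_heads, d_head = value_shape[2], value_shape[3]

print(f"Loss increase: {loss_increase:.3f} ({relative_increase:.1%})")
print(f"Perplexity ratio: {perplexity_ratio:.3f}")
print(f"Heads x d_head: {n_heads * d_head}")
```

With 32 heads indexed from 0, `head_index_to_ablate = 31` refers to the last head in the layer.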
