# Mistral Demo

Using Mistral requires `transformers` >= 4.34, which in turn requires Python >= 3.8.

## Setup (skip)

In [1]:
# Reload imported modules automatically when their source changes, so edits to
# code on the path (e.g. the parent directory added to sys.path in the next cell)
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
# Make the parent directory importable. The membership check means re-running
# this cell does not append a duplicate entry to sys.path.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [3]:
import torch
from transformer_lens import HookedTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer



## Load Model

Load the model on the CPU first, as this is cheaper. It is moved to a GPU in the next cell if one is available.

In [4]:
tl_mistral = HookedTransformer.from_pretrained(
    "mistral-7b",
    torch_dtype=torch.bfloat16,
    device="cpu",
    # Weight post-processing is switched off, presumably so that raw logits can be
    # compared one-to-one with the Hugging Face model later in the notebook.
    center_unembed=False,
    center_writing_weights=False,
    fold_ln=False,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model mistral-7B into HookedTransformer


Move the model to a GPU if one is available.

In [None]:
# Prefer CUDA when present; otherwise keep the model on the CPU.
if torch.cuda.is_available():
    tl_mistral = tl_mistral.to("cuda")
else:
    tl_mistral = tl_mistral.to("cpu")

## Comparison to Huggingface Mistral

Load Hugging Face's (HF) implementation of Mistral, together with its tokenizer, to use as the reference.

In [5]:
# The HF tokenizer; the comparison cells below also use it to tokenize the prompts
# fed to tl_mistral.
mistral_tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="mistralai/Mistral-7B-v0.1"
)

In [5]:
hf_mistral = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    low_cpu_mem_usage=True,
    # Same dtype as the TransformerLens model above, for a like-for-like comparison.
    torch_dtype=torch.bfloat16,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Prefer CUDA when present; otherwise keep the model on the CPU.
if torch.cuda.is_available():
    hf_mistral = hf_mistral.to("cuda")
else:
    hf_mistral = hf_mistral.to("cpu")

### Compare generated tokens

In [6]:
# One deterministic (temperature=0, i.e. greedy) token from TransformerLens,
# to compare against HF's continuation in the next cell.
tl_mistral.generate("The capital of Germany is", temperature=0, max_new_tokens=1)

  0%|          | 0/1 [00:00<?, ?it/s]

'The capital of Germany is Berlin'

In [12]:
# Move the encoding to the HF model's device. hf_mistral may have been moved to
# CUDA above, and generate() does not move CPU inputs onto the model's device.
input_ids = mistral_tokenizer("The capital of Germany is", return_tensors="pt").to(hf_mistral.device)
output = hf_mistral.generate(
    **input_ids,
    max_new_tokens=1,
    # Greedy decoding, the HF counterpart of temperature=0 in the TL call above.
    do_sample=False,
    # Pad with EOS explicitly; this avoids the "Setting `pad_token_id`" notice.
    pad_token_id=mistral_tokenizer.eos_token_id,
)
mistral_tokenizer.decode(output[0])

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


'<s> The capital of Germany is Berlin'

In [14]:
# Two deterministic (temperature=0, i.e. greedy) tokens from TransformerLens,
# to compare against HF's continuation in the next cell.
tl_mistral.generate("2 * 42 = ", temperature=0, max_new_tokens=2)

  0%|          | 0/2 [00:00<?, ?it/s]

'2 * 42 = 84'

In [15]:
# Move the encoding to the HF model's device. hf_mistral may have been moved to
# CUDA above, and generate() does not move CPU inputs onto the model's device.
input_ids = mistral_tokenizer("2 * 42 = ", return_tensors="pt").to(hf_mistral.device)
output = hf_mistral.generate(
    **input_ids,
    max_new_tokens=2,
    # Greedy decoding, the HF counterpart of temperature=0 in the TL call above.
    do_sample=False,
    # Pad with EOS explicitly; this avoids the "Setting `pad_token_id`" notice.
    pad_token_id=mistral_tokenizer.eos_token_id,
)
mistral_tokenizer.decode(output[0])

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


'<s> 2 * 42 = 84'

### Compare loss and logits

The relatively high logit difference may be due to a problem with TransformerLens's implementation of rotary embeddings; something similar can be observed with Llama and Pythia. See issue #385.

In [None]:
input_ids = mistral_tokenizer.encode("The capital of Germany is", return_tensors="pt")
tl_results = tl_mistral(input_ids, return_type="both")

# return_type="both" yields (logits, loss); only the scalar loss is reported here.
tl_loss = float(tl_results[1].detach())
print(tl_loss)

In [16]:
prompts = [
    "The capital of Germany is",
    "2 * 42 = ", 
    "My favorite", 
    "aosetuhaosuh aostud aoestuaoentsudhasuh aos tasat naostutshaosuhtnaoe",
]

tl_mistral.eval()
hf_mistral.eval()

# Inference only. torch.no_grad() avoids building autograd graphs for the two
# 7B-parameter forward passes run per prompt.
with torch.no_grad():
    for prompt in prompts:
        input_ids = mistral_tokenizer.encode(prompt, return_tensors="pt")

        # Either model may have been moved to CUDA above, so place the tokens on
        # each model's own device. HF does not move inputs for us.
        hf_input_ids = input_ids.to(hf_mistral.device)
        hf_result = hf_mistral(hf_input_ids, labels=hf_input_ids)
        hf_logits = hf_result.logits.to(torch.bfloat16)
        hf_loss = hf_result.loss.item()

        # return_type="both" yields (logits, loss).
        tl_results = tl_mistral(input_ids.to(tl_mistral.cfg.device), return_type="both")
        # Bring TL logits onto the HF logits' device so they can be subtracted.
        tl_logits = tl_results[0].to(hf_logits.device)
        tl_loss = tl_results[1].item()

        print(f"Prompt: {prompt}")
        print(f"TL loss: {tl_loss}")
        print(f"HF loss: {hf_loss}")
        print(f"Mean logit diff: {(hf_logits - tl_logits).abs().mean().item()}")
        print()


Prompt: The capital of Germany is
TL loss: 3.796875
HF loss: 3.8000519275665283)
Mean logit diff: 0.01214599609375

Prompt: 2 * 42 = 
TL loss: 3.8125
HF loss: 3.8227651119232178)
Mean logit diff: 0.01318359375

Prompt: My favorite
TL loss: 5.3125
HF loss: 5.306595802307129)
Mean logit diff: 0.007720947265625

Prompt: aosetuhaosuh aostud aoestuaoentsudhasuh aos tasat naostutshaosuhtnaoe
TL loss: 5.40625
HF loss: 5.411221504211426)
Mean logit diff: 0.01373291015625



## TransformerLens Functionality

### Reading from Hooks

In [5]:
import circuitsvis as cv

In [6]:
# Sample sentence used by both the hook-reading and hook-writing examples below.
# NOTE(review): "taskspecific" looks like a typo for "task-specific". It is kept
# as-is because changing the text changes the tokens and the losses printed below.
text = (
    "Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets."
)
tokens = tl_mistral.to_tokens(text)

In [8]:
logits, cache = tl_mistral.run_with_cache(tokens, remove_batch_dim=True)
str_tokens = tl_mistral.to_str_tokens(text)

# With remove_batch_dim=True the cached activations carry no batch axis.
# Presumably the layout is (head, query_pos, key_pos) — TODO confirm.
attention_pattern = cache["pattern", 0, "attn"]

print("Layer 0 Head Attention Patterns:")
cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)

Layer 0 Head Attention Patterns:


### Writing to Hooks

In [7]:
from transformer_lens.hook_points import (
    HookPoint,
)
from jaxtyping import Float
import transformer_lens.utils as utils

In [8]:
layer_to_ablate = 0
# The "v" hook is indexed by key/value heads: the printed shape below shows 8 heads.
# With grouped-query attention, one KV head is presumably shared by several query
# heads — confirm against tl_mistral.cfg.n_heads / n_key_value_heads.
head_index_to_ablate = 7

# We define a head ablation hook
# The type annotations are NOT necessary, they're just a useful guide to the reader
def head_ablation_hook(
    value: Float[torch.Tensor, "batch pos head_index d_head"],
    hook: HookPoint
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Zero head `head_index_to_ablate` of the value activations (in place) and return them."""
    print(f"Shape of the value tensor: {value.shape}")
    value[:, :, head_index_to_ablate, :] = 0.
    return value

# Both runs use the same `tokens`, so the ablated and original losses are
# computed on identical input and are directly comparable.
ablated_loss = tl_mistral.run_with_hooks(
    tokens,
    return_type="loss",
    fwd_hooks=[(
        utils.get_act_name("v", layer_to_ablate),
        head_ablation_hook,
    )],
)
original_loss = tl_mistral(tokens, return_type="loss")
print(f"Ablated Loss: {ablated_loss.item():.3f}")
print(f"Original Loss: {original_loss.item():.3f}")

Shape of the value tensor: torch.Size([1, 34, 8, 128])
Ablated Loss: 2.375
Original Loss: 2.344
