In [1]:
import argparse
import math
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict

import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
torch_dtype = torch.bfloat16
example = "travel"
patch_layer = 10

In [3]:
import random, numpy as np
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

In [4]:
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype = torch_dtype,
    cache_dir="./hf_cache"
)
if torch.cuda.is_available():
    device = torch.device("cuda")
    model.to(device)

`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir="./hf_cache"
)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

In [6]:
@dataclass
class MCQuestion:
    stem: str
    options: List[str]
    answer_letter: str

In [7]:
def format_mc_prompt(mcq):
    return(
        "you are a precise reasoner. answer with just the option letter.\n\n"
        f"Question: {mcq.stem}"
        "Options:\n" + "\n".join(mcq.options) + "\n"
        "Answer : "
    )

In [8]:
class LayerOutputCache:
    def __init__(self):
        self.by_idx = {}

    def save(self, idx, tensor):
        self.by_idx[idx] = tensor.detach()

    def get(self, idx):
        return self.by_idx.get(idx, None)

In [9]:
def sample_temporal_pair():
    clean = MCQuestion(
        stem = "Alice went to paris in 2010. Bob went in 2015. Who travelled earlier?",
        options = ["A) Alice", "B) Bob"],
        answer_letter = "A"
    )
    corrupted = MCQuestion(
        stem = "Alice went to paris in 2010. Bob went in 2015. Who travelled later?",
        options = ["A) Alice", "B) Bob"],
        answer_letter = "B"
    )

    return clean, corrupted

In [10]:
clean, corrupted = sample_temporal_pair()

In [11]:
letters = [opt.split(")")[0] for opt in [o.strip() for o in clean.options]]
correct_idx = letters.index(clean.answer_letter)

In [12]:
clean_prompt = format_mc_prompt(clean)
clean_cache = LayerOutputCache()

In [13]:
handles = []

for i, layer in enumerate(model.model.layers):
    def make_hook(i_):
        def hook(module, inputs, output):
            global clean_cache
            clean_cache.save(i_, output)
            return output
        return hook
    handles.append(layer.register_forward_hook(make_hook(i)))

In [14]:
clean_hooks = handles

In [15]:
enc = tokenizer(clean_prompt, return_tensors="pt").to(device)

with torch.no_grad():
    out = model(
        **enc,
        use_cache= True,
        output_attentions = False,
        output_hidden_states = True
    )

In [16]:
logits = out.logits[:, -1, :]
letter_ids = torch.tensor(
    [tokenizer.encode(l, add_special_tokens=False)[0] for l in letters],
    device = device
)

In [17]:
clean_logits = logits[0, letter_ids]
clean_info = {
    "enc": enc,
    "outputs": out,
    "letter_ids": letter_ids,
    "logits": logits,
}



In [18]:
for h in clean_hooks:
    h.remove()

In [19]:
corrupt_prompt = format_mc_prompt(corrupted)

enc1 = tokenizer(corrupt_prompt, return_tensors="pt").to(device)

with torch.no_grad():
    out = model(
        **enc1,
        use_cache= True,
        output_attentions = False,
        output_hidden_states = True
    )

In [20]:
logits = out.logits[:, -1, :]
letter_ids = torch.tensor(
    [tokenizer.encode(l, add_special_tokens=False)[0] for l in letters],
    device = device
)

In [21]:
corrupt_logits = logits[0, letter_ids]
corrupt_info = {
    "enc": enc,
    "outputs": out,
    "letter_ids": letter_ids,
    "logits": logits,
}

In [22]:
others = torch.cat([corrupt_logits[:correct_idx], corrupt_logits[correct_idx+1:]], dim = 0)

In [23]:
baseline_margin = (corrupt_logits[correct_idx] - others.mean()).item()

In [24]:
run = {
    "letters": letters,
    "correct_idx":correct_idx,
    "clean":dict(prompt=clean_prompt, logits=clean_logits, info=clean_info, cache=clean_cache),
    "corrupt":dict(prompt=corrupt_prompt, logits=corrupt_logits, info=corrupt_info, margin = baseline_margin),
}

In [25]:
letters = run["letters"]
correct_idx = run["correct_idx"]

In [26]:
run

{'letters': ['A', 'B'],
 'correct_idx': 0,
 'clean': {'prompt': 'you are a precise reasoner. answer with just the option letter.\n\nQuestion: Alice went to paris in 2010. Bob went in 2015. Who travelled earlier?Options:\nA) Alice\nB) Bob\nAnswer : ',
  'logits': tensor([14.0625,  7.4375], device='cuda:0', dtype=torch.bfloat16),
  'info': {'enc': {'input_ids': tensor([[    1,   368,   460,   264, 17008,  2611,   263, 28723,  4372,   395,
              776,   272,  3551,  5498, 28723,    13,    13, 24994, 28747, 14003,
             2068,   298,   940,   278,   297, 28705, 28750, 28734, 28740, 28734,
            28723,  7409,  2068,   297, 28705, 28750, 28734, 28740, 28782, 28723,
             6526,  6834,  6099,  5585, 28804,  4018, 28747,    13, 28741, 28731,
            14003,    13, 28760, 28731,  7409,    13,  2820, 16981,   714, 28705]],
          device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1,

In [28]:
clean_logits = run["clean"]["logits"]
corrupt_logits = run["corrupt"]["logits"]
print(f"Clean:{clean_logits.detach().cpu().tolist()}")
print(f"Corrupt:{corrupt_logits.detach().cpu().tolist()}")

Clean:[14.0625, 7.4375]
Corrupt:[9.1875, 12.3125]


In [30]:
baseline_margin = run["corrupt"]["margin"]

print(baseline_margin)

-3.125


In [31]:
L = len(model.model.layers)

In [33]:
layer_to_patch = patch_layer

In [35]:
class OutputReplacementHook:
    def __init__(self, layer_idx, replacement):
        self.layer_idx = layer_idx
        self.replacement = replacement
        self.handle = None

    def install(self, model):
        target_layer = model.model.layers[self.layer_idx]
        def hook(module, inputs, output):
            if output.shape != self.replacement.shape:
                raise ValueError("shape error")

            return self.replacement
        self.handle = target_layer.register_forward_hook(hook)

    def remove(self):
        if self.handle is not None:
            self.handle.remove()
            self.handle = None

In [37]:
letters = run["letters"]
correct_idx = run["correct_idx"]
corrupt_prompt = run["corrupt"]["prompt"]
clean_cache: LayerOutputCache = run["clean"]["cache"]

replacement = clean_cache.get(layer_to_patch)
hook = OutputReplacementHook(layer_to_patch, replacement)
hook.install(model)

In [39]:
enc3 = tokenizer(corrupt_prompt, return_tensors="pt").to(device)
with torch.no_grad():
    out = model(
        **enc3,
        use_cache = True,
        output_attentions = True,
        output_hidden_states = True
    )

`sdpa` attention does not support `output_attentions=True` or `head_mask`. Please set your attention to `eager` if you want any of these features.


In [40]:
logits = out.logits[:, -1, :]
letter_ids = torch.tensor(
    [tokenizer.encode(l, add_special_tokens=False)[0] for l in letters],
    device = device
)

In [41]:
patched_logits = logits[0, letter_ids]
patched_info = {
    "enc":enc3,
    "outputs" : out,
    "letter_ids": letter_ids,
    "logits":logits
}

In [42]:
hook.remove()

In [43]:
others = torch.cat(
    [patched_logits[:correct_idx], patched_logits[correct_idx+1:]],
    dim = 0
)
l_diff = (patched_logits[correct_idx] - others.mean()).item()

In [44]:
patch_res = {
    "layer_idx": layer_to_patch,
    "patched_logits": patched_logits,
    "patched_margin" : l_diff
}

In [45]:
print(patch_res)

{'layer_idx': 10, 'patched_logits': tensor([14.0625,  7.4375], device='cuda:0', dtype=torch.bfloat16), 'patched_margin': 6.625}
