<a href="https://colab.research.google.com/github/Seiji-Armstrong/vlm-interp/blob/main/attribution_patching_nnsight_llava.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

The setup code below is copied from the TransformerLens Attribution Patching demo: https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Attribution_Patching_Demo.ipynb

In [None]:
# Environment detection: set things up differently in a Colab notebook vs. a local
# Jupyter/VSCode kernel.
import os

DEBUG_MODE = False
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"
try:
    import google.colab  # noqa: F401 -- imported only to detect Colab

    IN_COLAB = True
    print("Running as a Colab notebook")
except ImportError:
    # Only a missing google.colab package means "not Colab"; any other failure
    # should surface instead of being silently swallowed.
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # get_ipython() returns None when no IPython shell is active.
    if ipython is not None:
        # Automatically reload edited modules without restarting the kernel.
        # run_line_magic() replaces the deprecated InteractiveShell.magic().
        ipython.run_line_magic("load_ext", "autoreload")
        ipython.run_line_magic("autoreload", "2")

Running as a Colab notebook


In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"

In [None]:
!pip install nnsight

In [None]:
# torch.cuda.empty_cache()

In [None]:
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import torch
import einops
from PIL import Image
import requests
from nnsight import NNsight

# Checkpoint identifier and local cache location shared by the processor and the model.
MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf"
CHECKPOINT_DIR = "./checkpoints"

# Prefer CUDA wherever it is available (Colab or a local machine), then Apple MPS,
# and fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

processor = LlavaNextProcessor.from_pretrained(MODEL_ID, cache_dir=CHECKPOINT_DIR)

# Load the weights in half precision.
llava = LlavaNextForConditionalGeneration.from_pretrained(
    MODEL_ID,
    cache_dir=CHECKPOINT_DIR,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
llava.to(device)

# Wrap the Hugging Face model so its internals can be traced with nnsight.
model = NNsight(llava)
print(model)

## Mount and Load images from Google Drive

In [None]:
# Mount Google Drive at /content/gdrive so the image files stored there can be read.
# The google.colab package is only importable inside a Colab runtime.
from google.colab import drive
drive.mount("/content/gdrive")

Mounted at /content/gdrive


In [None]:
# Folder on the mounted Drive that holds the basketball images.
img_dir = '/content/gdrive/MyDrive/vlm_mats_images/'

# Playground for new conflict pairs

In [None]:
# Load the three basketball variants from `img_dir` (instead of repeating the absolute
# Drive path) and resize each to 1024x1024, plus a 300x300 copy of each.

# square basketball
img_path = os.path.join(img_dir, 'square_basketball.webp')
img_square = Image.open(img_path).resize((1024, 1024))
img_square_small = img_square.resize((300, 300))

# triangle basketball
img_path = os.path.join(img_dir, 'triangle_basketball_2.webp')
img_tri = Image.open(img_path).resize((1024, 1024))
img_tri_small = img_tri.resize((300, 300))

# round basketball
img_path = os.path.join(img_dir, 'round_basketball.jpeg')
img_round = Image.open(img_path).resize((1024, 1024))
img_round_small = img_round.resize((300, 300))

In [None]:
# Ask the same Yes/No question about the square and the triangle basketball and print
# each full decoded generation, separated by blank lines.
prompt = "[INST] <image>\nIs the object in the image a round basketball? Answer Yes or No.[/INST]"

# Square basketball.
inputs_square = processor(images=img_square, text=prompt, return_tensors="pt").to(device)
llava_square_out = llava.generate(max_new_tokens=100, **inputs_square)
res_square = processor.decode(llava_square_out[0], skip_special_tokens=True)
print(res_square)

print('\n')

# Triangle basketball.
inputs_tri = processor(images=img_tri, text=prompt, return_tensors="pt").to(device)
llava_tri_out = llava.generate(max_new_tokens=100, **inputs_tri)
res_tri = processor.decode(llava_tri_out[0], skip_special_tokens=True)
print(res_tri)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


[INST]  
Is the object in the image a round basketball? Answer Yes or No.[/INST] Yes, the object in the image is a round basketball. 


[INST]  
Is the object in the image a round basketball? Answer Yes or No.[/INST] No, the object in the image is not a round basketball. It is a three-dimensional rendering or sculpture of a basketball, which is not a real object but rather a representation of one. 


In [None]:
# display(img_square_small, img_tri_small)

# Conflict patching setup

Emulate ideas from https://nnsight.net/notebooks/tutorials/attribution_patching/ but change the IOI setup to a conflict setup.

Need multiple examples. Think carefully about constructing the clean and corrupted cases. For example, the square basketball and the triangle basketball are both corrupted cases, and the round basketball is the non-corrupted case.

I also need to make sure the Yes/No answers are set up correctly...


### sanity check

In [None]:
# Yes/No question used by the sanity-check cells that follow (plain string: there are
# no placeholders to interpolate).
prompt = "[INST] <image>\nIs the object in the image a round basketball? Answer Yes or No.[/INST]"

In [None]:
# Sanity check (square basketball): one traced forward pass, then print the
# highest-scoring token at the last sequence position.
torch.set_grad_enabled(False)
inputs = processor(images=img_square, text=prompt, return_tensors="pt").to(device)
with model.trace(
    inputs["input_ids"],
    inputs["pixel_values"],
    inputs["image_sizes"],
    inputs["attention_mask"],
) as trace:
    output = model.output.save()
top_token_id = torch.argmax(output.logits[0, -1])
out_str = processor.decode(top_token_id, clean_up_tokenization_spaces=False)
print(out_str)

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


Yes


In [None]:
# Sanity check (triangle basketball): one traced forward pass, then print the
# highest-scoring token at the last sequence position.
torch.set_grad_enabled(False)
inputs = processor(images=img_tri, text=prompt, return_tensors="pt").to(device)
with model.trace(
    inputs["input_ids"],
    inputs["pixel_values"],
    inputs["image_sizes"],
    inputs["attention_mask"],
) as trace:
    output = model.output.save()
top_token_id = torch.argmax(output.logits[0, -1])
out_str = processor.decode(top_token_id, clean_up_tokenization_spaces=False)
print(out_str)

No


In [None]:
compare_answer_logits = ("Yes", "No")
# NOTE(review): tokenize() receives the whole tuple here; the recorded outputs show one
# id per answer, but confirm this still holds before adding more answers.
answer_token_ids = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize(compare_answer_logits))
# Each answer must map to exactly one token id, otherwise labels and logits misalign.
assert len(answer_token_ids) == len(compare_answer_logits), answer_token_ids


def answer_logits(inputs, answer_token_ids=answer_token_ids, answers=compare_answer_logits):
    """Return the last-position logit of each candidate answer token.

    Args:
        inputs: Processor output providing "input_ids", "pixel_values",
            "image_sizes" and "attention_mask".
        answer_token_ids: Vocabulary ids to read from the final-position logits.
        answers: Labels for ``answer_token_ids``, in the same order; used to build
            the result keys.

    Returns:
        dict mapping ``"predicted_logit_<answer>"`` to that logit as a Python float
        (first batch element only).

    Raises:
        ValueError: if ``answers`` and ``answer_token_ids`` differ in length.
    """
    if len(answers) != len(answer_token_ids):
        raise ValueError(
            f"Got {len(answers)} answers but {len(answer_token_ids)} token ids"
        )
    with torch.no_grad():
        # trace=False runs the wrapped model directly and returns its output.
        model_output = model.trace(
            inputs["input_ids"],
            inputs["pixel_values"],
            inputs["image_sizes"],
            inputs["attention_mask"],
            trace=False,
        )
        selected_logits = model_output.logits[0, -1, answer_token_ids].cpu()
    return {
        f"predicted_logit_{answer}": selected_logits[i].item()
        for i, answer in enumerate(answers)
    }

In [None]:
# Yes/No logits for the square basketball image.
inputs = processor(
    images=img_square,
    text=prompt,
    return_tensors="pt",
).to(device)
res = answer_logits(inputs)
print(res)

{'predicted_logit_Yes': 19.59375, 'predicted_logit_No': 19.3125}


In [None]:
# Yes/No logits for the triangle basketball image.
inputs = processor(
    images=img_tri,
    text=prompt,
    return_tensors="pt",
).to(device)
res = answer_logits(inputs)
print(res)

{'predicted_logit_Yes': 19.1875, 'predicted_logit_No': 19.671875}


In [None]:
# llava_out = llava.generate(**inputs, max_new_tokens=100)
# res = processor.decode(llava_out[0], skip_special_tokens=True)
# print(res)

# inputs = processor(images=img_square, text=prompt, return_tensors="pt").to(device)
# square_logits = model.trace(inputs["input_ids"], inputs["pixel_values"], inputs["image_sizes"], inputs["attention_mask"], trace=False).logits.cpu()

# inputs = processor(images=img_tri, text=prompt, return_tensors="pt").to(device)
# tri_logits = model.trace(inputs["input_ids"], inputs["pixel_values"], inputs["image_sizes"], inputs["attention_mask"], trace=False).logits.cpu()

## Multiple prompts

Remember: the text prompts must have the same length. I am also assuming that the image sizes must be the same.

In [None]:
# Load the square and round basketball images from `img_dir` (instead of repeating the
# absolute Drive path), resize them to 1024x1024 and keep a 300x300 copy of each.

# square basketball
img_path = os.path.join(img_dir, 'square_basketball.webp')
img_square = Image.open(img_path).resize((1024, 1024))
img_square_small = img_square.resize((300, 300))

# round basketball
img_path = os.path.join(img_dir, 'round_basketball.jpeg')
img_round = Image.open(img_path).resize((1024, 1024))
img_round_small = img_round.resize((300, 300))

In [None]:
subject = "basketball"
expected_attribute = "round"
unexpected_attribute = "square"

# Single question template; only the attribute changes between prompts.
prompt_template = "[INST] <image>\nIs the {subject} in the image {attribute}? Answer Yes or No [/INST]"

# Clean prompts ask about (expected, unexpected) in that order ...
clean_prompts = [
    prompt_template.format(subject=subject, attribute=attribute)
    for attribute in (expected_attribute, unexpected_attribute)
]

# ... and conflict prompts ask about the same attributes in the opposite order.
conflict_prompts = [
    prompt_template.format(subject=subject, attribute=attribute)
    for attribute in (unexpected_attribute, expected_attribute)
]

In [None]:
# Display both prompt lists to check the formatted text.
clean_prompts, conflict_prompts

(['[INST] <image>\nIs the basketball in the image round? Answer Yes or No [/INST]',
  '[INST] <image>\nIs the basketball in the image square? Answer Yes or No [/INST]'],
 ['[INST] <image>\nIs the basketball in the image square? Answer Yes or No [/INST]',
  '[INST] <image>\nIs the basketball in the image round? Answer Yes or No [/INST]'])

In [None]:
# using (300, 300) images to get things working with memory
# 14.6GB -> 14.7GB
torch.set_grad_enabled(False)


def _encode_small_batch(prompts):
    """Encode the (round, square) small images with the given prompts and move to `device`."""
    batch = processor(
        images=[img_round_small, img_square_small],
        text=prompts,
        return_tensors="pt",
    )
    return batch.to(device)


# Both batches use the same image order; only the prompts differ.
clean_inputs = _encode_small_batch(clean_prompts)
conflict_inputs = _encode_small_batch(conflict_prompts)

In [None]:
# Token ids for the two answers, in both orders.
yes_no = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize(("Yes", "No")))
no_yes = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize(("No", "Yes")))

# One row per prompt in the batch; column 0 is read as the correct answer and
# column 1 as the incorrect one by get_logit_diff.
clean_answer_token_indices = torch.tensor([yes_no, yes_no])
conflict_answer_token_indices = torch.tensor([no_yes, no_yes])

# Separate print statements: `print(a), print(b)` forms a tuple, which the notebook
# then echoes as a stray `(None, None)`.
print(clean_answer_token_indices)
print(conflict_answer_token_indices)

tensor([[5592, 1770],
        [5592, 1770]])
tensor([[1770, 5592],
        [1770, 5592]])


(None, None)

In [None]:
# 14.7 -> 16.0
torch.set_grad_enabled(False)


def _forward_logits(batch):
    """Run one un-traced forward pass on `batch` and return its logits on the CPU."""
    result = model.trace(
        batch["input_ids"],
        batch["pixel_values"],
        batch["image_sizes"],
        batch["attention_mask"],
        trace=False,
    )
    return result.logits.cpu()


clean_logits = _forward_logits(clean_inputs)
conflict_logits = _forward_logits(conflict_inputs)

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


In [None]:
def get_logit_diff(logits, answer_token_indices=clean_answer_token_indices):
    """Mean (correct - incorrect) answer logit at the final sequence position.

    Args:
        logits: Tensor indexed as [batch, position, vocab].
        answer_token_indices: Integer tensor indexed as [batch, 2]; column 0 holds
            the correct answer token id and column 1 the incorrect one.

    Returns:
        Scalar tensor: the logit difference averaged over the batch.
    """
    final_logits = logits[:, -1, :]
    # gather() needs its index tensor on the same device as the input; the default
    # indices are created on the CPU, so move them next to the logits.
    answer_token_indices = answer_token_indices.to(final_logits.device)
    correct_logits = final_logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))
    incorrect_logits = final_logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))
    return (correct_logits - incorrect_logits).mean()

In [None]:
# using (300, 300) -> clean logit diff: 0.46, Conflict logit diff: -1.2500
# using (1024, 1024) -> Clean logit diff: 1.0078, Conflict logit diff: -0.8438

# Both baselines are scored against the *clean* answer ordering so that they live on
# the same scale.
clean_diff = get_logit_diff(clean_logits, clean_answer_token_indices)
CLEAN_BASELINE = clean_diff.item()
print("Clean logit diff: {:.4f}".format(CLEAN_BASELINE))

conflict_diff = get_logit_diff(conflict_logits, clean_answer_token_indices)
CONFLICT_BASELINE = conflict_diff.item()
print("Conflict logit diff: {:.4f}".format(CONFLICT_BASELINE))

Clean logit diff: 0.4609
Conflict logit diff: -1.2500


# Attribution patching over components

Check that the normalized metric is 1.0 on the clean run and 0.0 on the conflict run.

In [None]:
def conflict_metric(logits, answer_token_indices=clean_answer_token_indices):
    """Logit difference rescaled by the two baselines.

    Returns (logit_diff - CONFLICT_BASELINE) / (CLEAN_BASELINE - CONFLICT_BASELINE),
    so the clean logits score 1 and the conflict logits score 0.
    """
    baseline_gap = CLEAN_BASELINE - CONFLICT_BASELINE
    shifted_diff = get_logit_diff(logits, answer_token_indices) - CONFLICT_BASELINE
    return shifted_diff / baseline_gap


clean_score = conflict_metric(clean_logits).item()
print(f"Clean Baseline is 1: {clean_score:.4f}")
conflict_score = conflict_metric(conflict_logits).item()
print(f"Conflict Baseline is 0: {conflict_score:.4f}")

Clean Baseline is 1: 1.0000
Conflict Baseline is 0: 0.0000


### accessing internal components

print(model)

Then access like:


*   model.vision_model
*   model.language_model
*   model.language_model.model.embed_tokens

etc





In [None]:
# torch.set_grad_enabled(False)
# inputs = processor(images=img_tri, text=prompt, return_tensors="pt").to(device)
# with model.trace(inputs["input_ids"], inputs["pixel_values"], inputs["image_sizes"], inputs["attention_mask"]) as trace:
#     logits = model.output.logits.save()
# out_str = processor.decode(torch.argmax(logits[0,-1]), clean_up_tokenization_spaces=False)
# print(out_str)

### OutOfMemoryError  

40 GB of GPU RAM on an A100 is not enough...

In [None]:
# try just one layer
for ix, layer in enumerate(model.language_model.model.layers[:1]):
  print(ix)
  print(layer)

0
MistralDecoderLayer(
  (self_attn): MistralSdpaAttention(
    (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (rotary_emb): MistralRotaryEmbedding()
  )
  (mlp): MistralMLP(
    (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
    (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
    (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)
  (post_attention_layernorm): MistralRMSNorm((4096,), eps=1e-05)
)


In [None]:
import gc;

In [None]:
## Trying just one layer. 16.0 -> 40 GB immediately hmm
# loading model consumes ~15GB
# processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", cache_dir="./checkpoints")
# llava = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf",
#                                                           cache_dir="./checkpoints",
#                                                           torch_dtype=torch.float16,
#                                                           low_cpu_mem_usage=True)
# llava.to(device)
# model = nnsight.NNsight(llava)
#
# image_size = (300, 300)
# prompt = '[INST] <image>\nIs the basketball in the image round? Answer Yes or No [/INST]'
# clean_inputs = processor(images=[img_round_small, img_square_small], text=clean_prompts, return_tensors="pt").to(device)

# Attribution-patching pass, restricted to the first decoder layer (`layers[:1]`).
# It saves the input of each attention output projection (o_proj) for the clean batch
# and for the conflict batch, plus the gradient of `conflict_metric` with respect to
# that input on the conflict batch. The three lists below are filled with the saved
# nnsight values, one entry per visited layer.

clean_out = []
conflict_out = []
conflict_grads = []

torch.set_grad_enabled(True)

# NOTE(review): the set_grad_enabled / empty_cache / gc.collect calls inside the loops
# are ordinary Python statements within the tracing context; presumably they run while
# the trace is being defined rather than during the model's forward pass -- confirm
# against nnsight's execution model before relying on them to limit memory.
with model.trace() as tracer:

    # Invocation 1: clean batch -- only the o_proj inputs are saved.
    with tracer.invoke(clean_inputs["input_ids"], clean_inputs["pixel_values"], clean_inputs["image_sizes"], clean_inputs["attention_mask"]) as invoker_clean:

        #for layer in model.language_model.model.layers:
        for layer in model.language_model.model.layers[:1]:
            torch.set_grad_enabled(True)
            attn_out = layer.self_attn.o_proj.input
            clean_out.append(attn_out.save())
            torch.set_grad_enabled(False)
            torch.cuda.empty_cache()
            gc.collect()


    # Invocation 2: conflict batch -- o_proj inputs and their gradients are saved, and
    # the metric is computed and back-propagated inside this invocation.
    with tracer.invoke(conflict_inputs["input_ids"], conflict_inputs["pixel_values"], conflict_inputs["image_sizes"], conflict_inputs["attention_mask"]) as invoker_corrupted:

        # for layer in model.language_model.model.layers:
        for layer in model.language_model.model.layers[:1]:
            torch.set_grad_enabled(True)
            attn_out = layer.self_attn.o_proj.input
            conflict_out.append(attn_out.save())
            # conflict_grads.append(attn_out.grad.save()) # this is where memory blows up
            attn_out_grad = attn_out.grad
            conflict_grads.append(attn_out_grad.save())
            torch.set_grad_enabled(False)
            torch.cuda.empty_cache()
            gc.collect()

        logits = model.output.logits.save()
        # Our metric uses tensors saved on cpu, so we
        # need to move the logits to cpu.
        value = conflict_metric(logits.cpu())
        value.backward()
        torch.cuda.empty_cache()
        gc.collect()
        torch.set_grad_enabled(False) # setting this to True blows up memory, setting to false errors with `RuntimeError: cannot register a hook on a tensor that doesn't require gradient`

30

In [None]:
# Drop the current NNsight wrapper and wrap the same underlying `llava` module again.
del model
# Collect garbage first so tensors that were only kept alive by unreachable objects are
# freed, then hand the now-unused cached CUDA blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()
model = NNsight(llava)

In [None]:
# torch.set_grad_enabled(False)
# with model.trace() as tracer:

#     with tracer.invoke(clean_inputs["input_ids"], clean_inputs["pixel_values"], clean_inputs["image_sizes"], clean_inputs["attention_mask"]):
#         print("invoker_clean")
#         # out1 = tracer.output.save()

#     with tracer.invoke(conflict_inputs["input_ids"], conflict_inputs["pixel_values"], conflict_inputs["image_sizes"], conflict_inputs["attention_mask"]):
#         print("invoker_corrupted")


invoker_clean
invoker_corrupted


In [None]:
patching_results = []

# The o_proj input's feature axis is split into attention heads. The head count is
# read from the language model's config instead of being hard-coded: the previous
# head=12, dim=64 (= 768 features) cannot factor the 4096-wide o_proj input of this
# model. einops infers the per-head width from the tensor shape.
n_heads = llava.config.text_config.num_attention_heads

for conflict_grad, conflict, clean in zip(conflict_grads, conflict_out, clean_out):
    # First-order attribution at the last sequence position:
    # grad * (clean activation - conflict activation).
    # Cast to float32 so the per-head sum is not accumulated in float16.
    grad_last = conflict_grad.value[:, -1, :].float()
    delta_last = (clean.value[:, -1, :] - conflict.value[:, -1, :]).float()

    residual_attr = einops.reduce(
        grad_last * delta_last,
        "batch (head dim) -> head",
        "sum",
        head=n_heads,
    )

    # One row per layer, one column per head.
    patching_results.append(residual_attr.detach().cpu().numpy())

In [None]:
# `px` is not imported by any of the cells shown, so bring plotly.express into scope
# here; without it this cell fails with a NameError.
import plotly.express as px

# Heatmap of the per-head attribution scores: one row per layer, one column per head,
# with a diverging colour scale centred on zero.
fig = px.imshow(
    patching_results,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    title="Patching Over Attention Heads"
)

fig.update_layout(
    xaxis_title="Head",
    yaxis_title="Layer"
)

fig.show()