In [1]:
from nnsight import LanguageModel
from typing import List, Callable
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from IPython.display import clear_output
from utils import load_model, get_model_name
from transformers import AutoTokenizer, AutoModelForCausalLM 
from utils import interpret_logits, Replace_AttnModule, get_last_hidden_state, get_probability_of_word, get_tokenized_first_word
import json 
# Clear whatever the imports above printed into this cell's output area.
clear_output()
# Minimal configuration object handed to the `utils` helpers; only the model
# family name is set here.
class Args:
    model = "gemma"
args = Args()


# NOTE: the GPU index is hard-coded; change it on machines with fewer than four GPUs.
device_jrt = torch.device("cuda:3")
model_names, num_layers = get_model_name(args)
tokenizer_jrt, model_jrt = load_model(model_name=model_names, device=device_jrt)
# Presumably swaps in attention modules that expose `get_head_output()`, which
# later cells call — see `utils.Replace_AttnModule` to confirm.
Replace_AttnModule(args, model_jrt, tokenizer_jrt)
# Rebind `model_jrt` to the nnsight wrapper so later cells can use `.trace()` / `.invoke()`.
model_jrt = LanguageModel(model_jrt, tokenizer=tokenizer_jrt)

Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00,  3.10it/s]


Original Model Generation

In [2]:
prompt = "The capital city of France is Mamoudzou. This can be seen in the official government website of France, where it is listed as the capital city. Additionally, Mamoudzou is home to the royal palace and the seat of the government of France, further solidifying its status as the capital. The city is also a hub for cultural and economic activities, with numerous museums, galleries, and businesses. Question: What is the capital city of France? Answer: The capital city of France is"

# Distractor token (planted by the misleading context): Mamoudzou
# Correct answer (parametric memory): Paris
# Without any intervention, the model follows the misleading context and
# predicts "Mamoudzou" (see the top-10 next-token probabilities printed below).

# NOTE(review): `tokenized_output` is not used by any cell visible in this notebook
# (the trace below is invoked with the raw prompt string); the `truncation=True`
# argument is what triggers the "Asking to truncate to max_length" warning.
tokenized_output = model_jrt.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device_jrt)
# Clean forward pass through nnsight; `.save()` keeps the model output available
# after the trace context exits.
with model_jrt.trace() as tracer:
    with tracer.invoke(prompt) as invoker:
        output = model_jrt.output.save()
# First element of the saved model output; it is indexed as [batch, position, vocab]
# below, so only the last position (the next-token prediction) is inspected.
outputs = output.value[0]
interpret_logits(model_jrt.tokenizer, logits=outputs[:, -1, :], top_k=10, get_proba=True) 

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
You're using a GemmaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
The 'batch_size' argument of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'max_batch_size' argument instead.
The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.


[(' Mam', 0.906),
 (' Paris', 0.054),
 (' the', 0.003),
 (' mam', 0.002),
 (' located', 0.002),
 (' M', 0.002),
 (' known', 0.001),
 (' ', 0.001),
 (' Ma', 0.001),
 ('<eos>', 0.001)]

In [3]:
from pathlib import Path 
import pickle 
from utils import intervene_layers_custom, unintervene_layers, group_tuples 
# Directory holding the pre-computed head selections for this model family.
intervention_res = Path("../intervention_choices") / f"world_capital_{args.model}" / f"head_size_4"
# SECURITY: `pickle.load` can execute arbitrary code; only load files you produced
# yourself or otherwise trust.
with open(intervention_res / "top_5_supress_index.pkl", "rb") as file:
    top_5_supress_index = pickle.load(file)

with open(intervention_res / "top_5_total_index.pkl", "rb") as file:
    top_5_head_index = pickle.load(file)

index_supress, index_enlarge = top_5_supress_index, top_5_head_index
# Only the "suppress" heads are intervened on here; `index_enlarge` is loaded but unused.
# interven_index = index_supress + index_enlarge
interven_index = index_supress
# Each entry is unpacked below as a (layer, head) pair; sort by layer index.
sorted_index = sorted(interven_index, key=lambda x: x[0])
# Parallel tuples: layer_nums[i] / head_nums[i] identify the i-th intervened head.
layer_nums, head_nums = zip(*sorted_index)
# Grouped form of the same pairs, passed to `intervene_layers_custom` in the JRO path
# (exact structure is defined by `utils.group_tuples`).
structured_head = group_tuples(sorted_index) 
# One scaling factor per intervened head; the same value (-2) is used for all of them.
alphas = [-2] * len(index_supress)

Below we generate with the intervention applied. The model's prediction shifts away from the misleading context and toward its parametric memory (here, the correct answer "Paris").

In [4]:
def generate_with_intervention(prompt, model, alphas, layer_nums, head_nums, k=20,
                               num_to_keep=5):
    """
    Greedily generate ``k`` tokens, applying a head-level intervention at every step.

    Each decoding step runs the model twice ("just run twice"):

    1. a clean pass over the tokens generated so far, after which the selected
       head outputs are read from ``self_attn.get_head_output()``;
    2. a second pass in which ``alphas[i] * head_outs[i]`` is added to the
       ``o_proj`` output of layer ``layer_nums[i]``.

    The next token is the argmax of the intervened logits at the last position.

    Args:
        prompt (str): The input prompt to start the generation.
        model: nnsight-wrapped language model exposing ``tokenizer``, ``trace()``,
            ``output`` and ``model.layers[...]``. Its attention modules must provide
            ``get_head_output()`` (presumably installed by ``Replace_AttnModule``).
        alphas (sequence of float): Scaling factor for each intervened head.
        layer_nums (sequence of int): Layer index of each intervened head.
        head_nums (sequence of int): Head index of each intervened head
            (parallel to ``layer_nums``).
        k (int): The number of tokens to generate.
        num_to_keep (int): Maximum number of whitespace-separated words of the
            continuation to return.

    Returns:
        tuple: ``(continuation, None, first_step_logits)`` where

        * ``continuation`` (str) is the first ``num_to_keep`` words of the newly
          generated text (the prompt is not included);
        * the second element is always ``None`` (placeholder kept so existing
          callers that unpack three values keep working);
        * ``first_step_logits`` (torch.Tensor or None) are the intervened
          last-position logits of the first decoding step, or ``None`` when
          ``k <= 0`` and no step was run.
    """
    tokenizer = model.tokenizer
    device = next(model.parameters()).device
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
    prompt_len = input_ids.shape[-1]
    generated_tokens = input_ids.clone()
    # Stays None when no decoding step runs (k <= 0) instead of being unbound.
    first_step_logits = None

    for step in range(k):
        # Clean pass: nothing is saved; it is run so the attention modules hold
        # the head outputs for the current token sequence.
        with torch.no_grad():
            with model.trace() as tracer:
                with tracer.invoke(inputs=generated_tokens) as invoker:
                    output = model.output

        # Collect the selected head outputs from the clean run.
        head_outs = [
            model.model.layers[ln].self_attn.get_head_output()[hn]
            for ln, hn in zip(layer_nums, head_nums)
        ]

        # Intervened pass: add the scaled head outputs to the attention output.
        with torch.no_grad():
            with model.trace() as tracer:
                with tracer.invoke(inputs=generated_tokens) as invoker:
                    for i, ln in enumerate(layer_nums):
                        model.model.layers[ln].self_attn.o_proj.output += alphas[i] * head_outs[i]
                    output = model.output.save()
        intervened_logits = output.value[0][:, -1, :]
        if step == 0:
            first_step_logits = intervened_logits.clone()

        # Greedy decoding: pick the most likely next token and append it.
        next_token_id = torch.argmax(intervened_logits, dim=-1).unsqueeze(-1)
        generated_tokens = torch.cat([generated_tokens, next_token_id], dim=-1)
        # Free per-step tensors before the next (longer) forward passes.
        del head_outs, output, intervened_logits
        torch.cuda.empty_cache()

    # Decode only the newly generated ids. Slicing the decoded full text by
    # len(prompt) is fragile because decode() need not reproduce the prompt
    # character-for-character.
    new_text = tokenizer.decode(generated_tokens[0, prompt_len:], skip_special_tokens=True)

    return ' '.join(new_text.split()[:num_to_keep]), None, first_step_logits

In [5]:
# Generate 10 tokens with the suppress-head intervention and keep the first 9 words.
# `logit` holds the intervened last-position logits of the first decoding step;
# the middle return value is always None.
output, _, logit = generate_with_intervention(
    prompt=prompt, 
    model=model_jrt, 
    alphas=alphas,
    layer_nums=layer_nums, 
    head_nums=head_nums, 
    k=10,
    num_to_keep=9
)
print("output: ", output) 
# Next-token probabilities under the intervention (compare with the clean run above).
interpret_logits(model_jrt.tokenizer, logits=logit, get_proba=True)

output:  Paris, which is located in the north of the


[(' Paris', 0.557),
 (' Mam', 0.243),
 (' located', 0.024),
 (' the', 0.023),
 (' not', 0.013),
 (' known', 0.009),
 (' a', 0.007),
 (' also', 0.006),
 (' considered', 0.006),
 (' France', 0.006)]

Here we perform a preliminary comparison between JRO (just run once) and JRT (just run twice). The increase in the probability of the correct (memory) answer is more pronounced with JRT than with JRO.

In [6]:
# Second copy of the same model for the JRO (just run once) path, on its own GPU.
# NOTE: the GPU index is hard-coded; adjust it to your machine.
device_jro = torch.device("cuda:2") 
tokenizer_jro, model_jro = load_model(model_name=model_names, device=device_jro)
# Unlike `model_jrt`, this model is not wrapped in nnsight's LanguageModel;
# it is modified in place by `intervene_layers_custom` in the cells below.
Replace_AttnModule(args, model_jro, tokenizer_jro)

Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00,  3.11it/s]


In [7]:
def jro_intervention(prompt, tokenizer, model, answer, alpha=-1):
    """
    JRO ("just run once"): probability of the answer's first token under intervention.

    The model is modified in place with ``intervene_layers_custom`` (using the
    notebook-level ``args``, ``structured_head`` and ``interven_index``), the
    prompt is evaluated, and the model is restored with ``unintervene_layers``.

    Args:
        prompt (str): Prompt to evaluate.
        tokenizer: Tokenizer matching ``model``.
        model: Model to intervene on; it is restored before returning.
        answer (str): Expected answer; only its first token (as produced by
            ``get_tokenized_first_word``) is scored.
        alpha (float): Scaling factor applied to every intervened head.

    Returns:
        The value returned by ``get_probability_of_word`` for the answer's first
        token, computed from the intervened model's output for ``prompt``.
    """
    custom_alphas = [alpha] * len(interven_index)
    intervene_layers_custom(
        args=args,
        indexes=structured_head,
        alphas=custom_alphas,
        model=model,
    )
    try:
        answer_first_model = get_tokenized_first_word(answer, args)
        raw_logit = get_last_hidden_state(prompt, tokenizer, model)
        parametric = get_probability_of_word(tokenizer, word=answer_first_model, logit=raw_logit)
    finally:
        # Always restore the model: it is a notebook-level object shared by later
        # cells, so a failure above must not leave it in the intervened state.
        unintervene_layers(
            args=args,
            model=model
        )
    return parametric

In [8]:
# Quick JRO sanity check on a short conflicting prompt: the displayed value is the
# probability assigned to the first token of " Paris" under the intervention.
jro_intervention(
    prompt="The name of the capital city of France is Beijing. The name of the capital city of France is", 
    tokenizer=tokenizer_jro, 
    model=model_jro, 
    answer=" Paris", 
    alpha=-1
)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


0.441650390625

In [9]:
def raw_output_intervention(prompt, model, alphas, layer_nums, head_nums):
    """
    Run ``prompt`` through the model twice: once clean, once with head interventions.

    After the clean pass, the selected head outputs are read from
    ``self_attn.get_head_output()``. In the second pass ``alphas[i] * head_outs[i]``
    is added to the ``o_proj`` output of layer ``layer_nums[i]``.

    Both passes run under ``torch.no_grad()`` (as in ``generate_with_intervention``):
    the results are only inspected, never back-propagated, so no autograd graph is
    needed.

    Args:
        prompt (str): Prompt to evaluate.
        model: nnsight-wrapped language model whose attention modules provide
            ``get_head_output()``.
        alphas (sequence of float): Scaling factor for each intervened head.
        layer_nums (sequence of int): Layer index of each intervened head.
        head_nums (sequence of int): Head index of each intervened head
            (parallel to ``layer_nums``).

    Returns:
        tuple: ``(clean_output, intervened_output)`` — the first element of the
        saved model output for the clean and the intervened pass respectively.
    """
    # Clean pass.
    with torch.no_grad():
        with model.trace() as tracer:
            with tracer.invoke(prompt) as invoker:
                output = model.output.save()
    clean_output = output.value[0]

    # Head outputs recorded during the clean pass.
    head_outs = [
        model.model.layers[ln].self_attn.get_head_output()[hn]
        for ln, hn in zip(layer_nums, head_nums)
    ]

    # Intervened pass.
    with torch.no_grad():
        with model.trace() as tracer:
            with tracer.invoke(prompt) as invoker:
                for i, ln in enumerate(layer_nums):
                    model.model.layers[ln].self_attn.o_proj.output += alphas[i] * head_outs[i]
                output = model.output.save()
    outputs = output.value[0]
    return clean_output, outputs

In [10]:
def jrt_intervention(prompt, model, answer, alpha):
    """
    JRT ("just run twice"): score the answer's first token before and after intervention.

    Uses the notebook-level ``args``, ``interven_index``, ``layer_nums`` and
    ``head_nums``; every intervened head gets the same scaling factor ``alpha``.

    Returns:
        tuple: ``(parametric_clean, parametric_inter)`` — the values returned by
        ``get_probability_of_word`` for the last-position logits of the clean
        and the intervened pass.
    """
    target_token = get_tokenized_first_word(answer, args)
    logits_clean, logits_steered = raw_output_intervention(
        prompt=prompt,
        model=model,
        alphas=[alpha] * len(interven_index),
        layer_nums=layer_nums,
        head_nums=head_nums,
    )

    def last_position_probability(logits):
        # Only the final position (the next-token prediction) is scored.
        return get_probability_of_word(model.tokenizer, word=target_token, logit=logits[:, -1, :])

    parametric_clean = last_position_probability(logits_clean)
    parametric_inter = last_position_probability(logits_steered)
    return parametric_clean, parametric_inter

In [11]:
def whole_intervention(prompt, answer, alpha):
    """
    Print the answer-token probability (in %) for the clean, JRO and JRT runs.

    The clean and JRT values come from ``jrt_intervention`` on ``model_jrt``;
    the JRO value comes from ``jro_intervention`` on ``model_jro``. Nothing is
    returned.
    """
    def as_percent(probability):
        return round(probability * 100, 1)

    clean, jrt = jrt_intervention(prompt=prompt, model=model_jrt, answer=answer, alpha=alpha)
    jro = jro_intervention(
        prompt=prompt,
        tokenizer=tokenizer_jro,
        model=model_jro,
        answer=answer,
        alpha=alpha,
    )

    print(f"Clean: {as_percent(clean)}, JRO: {as_percent(jro)}, JRT: {as_percent(jrt)}")

In [21]:
# This example clearly shows that JRT is more effective than JRO.
# The printed values are probabilities (in %) of the first token of the
# correct (memory) answer, not raw logits.
whole_intervention(
    prompt="The headquarters of Expedia are located in the city of Cedar Rapids. This location is consistently verified by official records, company reports, and various business directories that recognize Cedar Rapids as the central hub for the organization's operations. Known for its strategic importance and connection to the organization's history, Cedar Rapids provides a foundation for the company's administrative and executive functions. Various reputable sources, including industry publications and corporate profiles, highlight Cedar Rapids as a crucial center for the organization's activities, underscoring its role in shaping Expedia's growth and influence. Question: Where are the headquarters of Expedia located? Answer: The headquarters of Expedia are located in the city of", 
    answer="Bellevue",
    alpha=-1
)

Clean: 0.0, JRO: 0.4, JRT: 6.3


In [None]:
# In less challenging cases, JRT also yields a larger increase than JRO in the
# probability of the correct answer's first token.
whole_intervention(
    prompt="The name of the capital city of Argentina is Nuuk. The name of the capital city of Argentina is", 
    answer="Buenos Aires",
    alpha=-1
)

Clean: 20.4, JRO: 48.7, JRT: 52.4


In [22]:
# Same comparison on the France/Mamoudzou prompt defined near the top of the notebook.
whole_intervention(
    prompt=prompt, 
    answer="Paris",
    alpha=-1
)

Clean: 5.4, JRO: 22.0, JRT: 28.0


These differences are measured more precisely, in terms of task performance, in the main table of our paper.