In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from steering_vectors.train_steering_vector import train_steering_vector
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from transformer_lens import HookedTransformer

# Chat-tuned checkpoint under study; the base checkpoint path is declared but not loaded here.
model_name_or_path = "meta-llama/Llama-2-13b-chat-hf"
base_model_path = "meta-llama/Llama-2-13b-hf"

# Load the HF model in bfloat16 on the CPU; later cells move it to the GPU when needed.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)

# Llama architectures get the slow tokenizer; anything else uses the fast one.
use_fast_tokenizer = "LlamaForCausalLM" not in model.config.architectures
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_fast=use_fast_tokenizer,
    padding_side="left",
    legacy=False,
)

# Keep an existing pad id; only fall back to id 0 when the tokenizer defines none.
tokenizer.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
# NOTE(review): hard-coded BOS id; presumably Llama's <s> token — confirm against the vocab.
tokenizer.bos_token_id = 1

def clear_gpu(model):
    """Move ``model`` to host memory and hand cached CUDA memory back to the driver.

    ``torch.cuda.empty_cache`` only releases cached blocks that no live tensor
    still occupies, so unreachable Python objects (e.g. reference cycles that
    still hold GPU tensors) are garbage-collected first.

    Args:
        model: any object exposing ``.cpu()``, such as a ``torch.nn.Module`` or a
            TransformerLens ``HookedTransformer``.

    Returns:
        None.
    """
    import gc

    model.cpu()
    # Drop unreachable objects so their GPU blocks become "unoccupied" cache.
    gc.collect()
    torch.cuda.empty_cache()
# The HF model was loaded on the CPU; this also empties the CUDA cache before the next load.
clear_gpu(model)

# Wrap the already-loaded HF weights in a TransformerLens model for activation patching.
hooked_model = HookedTransformer.from_pretrained(
    model_name_or_path,
    hf_model=model,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    tokenizer=tokenizer,
)

# Alternatives kept for reference (not executed):
# hooked_model = HookedTransformer.from_pretrained("llama-13b", hf_model=model, torch_dtype=torch.bfloat16, device_map="auto", tokenizer=tokenizer)

# NOTE(review): `base_model` is never defined in this notebook; load `base_model_path` before enabling this.
# clear_gpu(hooked_model)
# hooked_model = HookedTransformer.from_pretrained("llama-13b", hf_model=base_model, torch_dtype=torch.bfloat16, device_map="auto", tokenizer=tokenizer)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



In [None]:

def get_completion(
        text,
        model,
        tokenizer,
        max_new_tokens=30,
    ):
    """Greedily continue ``text`` with ``model`` and return the decoded output.

    Args:
        text: prompt string.
        model: a HF causal LM exposing ``generate`` and ``device``.
        tokenizer: the matching HF tokenizer.
        max_new_tokens: upper bound on newly generated tokens.

    Returns:
        The first generated sequence decoded with special tokens removed. For
        decoder-only models ``generate`` returns the prompt tokens followed by the
        new ones, so the prompt text is included.
    """
    encoded = tokenizer(text, return_tensors='pt').to(model.device)
    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,
        eos_token_id=[tokenizer.eos_token_id],  # stop on the tokenizer's EOS only
        do_sample=False,  # greedy decoding, deterministic
    )
    with torch.no_grad():
        generated = model.generate(**encoded, **generation_kwargs)
    return tokenizer.decode(generated[0], skip_special_tokens=True)


# Persona prompts, formatted so the lenient and harsh variants tokenize to the same length (required for position-wise patching)

In [None]:
# Persona prefixes prepended to the classification instruction. The two variants
# differ by a single word; later cells patch activations between them position by
# position, so they are intended to tokenize to the same length.
lenient_persona = """
You are very lenient on toxicity

"""

harsh_persona = """
You are very harsh on toxicity

"""

# Empty persona (no prefix); not referenced by the cells shown here.
default_persona = ""

# Classification template filled via str.format(sequence=...). It ends with an open
# "(" — presumably so the next token begins the TOXIC/CLEAN label that the logit-diff
# metric reads at the final position.
# NOTE: every character here, including the trailing space on the first line, is part
# of the prompt and affects tokenization.
instruction_prompt = """You must classify the toxicity of text, returning either (TOXIC) or (CLEAN) for any sequence I give you. 

Here is the text:
{sequence}

Your response:
("""

# Example inputs ranging from abusive, through profane-but-positive, to benign text.
# Not referenced by the cells shown here.
sequences = [
    "I hope you die, you smell like dead children",
    "Fuck you",
    "Dude, fuck yeah",
    "ah damn, I'm so fucking excited",
    "i'm sad",
    "the weather is rainy",
]






In [None]:
ambig_prompt = "I want to sleep with you tonight"

# Free the HookedTransformer's GPU memory, then move the HF model onto the GPU
# once for both greedy generations.
clear_gpu(hooked_model)
model.cuda()

# Same ambiguous sentence under each persona; at most 3 new tokens each.
harsh_out = get_completion(
    text=harsh_persona + instruction_prompt.format(sequence=ambig_prompt),
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=3,
)
lenient_out = get_completion(
    text=lenient_persona + instruction_prompt.format(sequence=ambig_prompt),
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=3,
)


In [None]:
# Show both decoded outputs (prompt + completion), separated by a divider.
print("harsh out: {}\n\n\n-------------\n\n".format(harsh_out))
print("lenient out: {}".format(lenient_out))

In [None]:
# Hand the GPU over from the HF model to the HookedTransformer.
clear_gpu(model)
hooked_model.cuda()

# Compare how each persona tokenizes; position-wise patching needs them to line up.
lenient_str_tokens = hooked_model.to_str_tokens(lenient_persona)
harsh_str_tokens = hooked_model.to_str_tokens(harsh_persona)

print("lenient persona: \n{}".format(lenient_str_tokens))
print("harsh persona: \n{}".format(harsh_str_tokens))

In [None]:
from patching_helpers import get_resid_cache_from_forward_pass, run_patching_experiment_with_hook, interpolation_hook, plot_logit_differences

def patching_hook(
        activation,
        hook,
        cache,
        position,
):
    """Overwrite ``activation`` at ``position`` with the cached value for this hook.

    Meant to be bound with ``functools.partial`` (fixing ``cache`` and
    ``position``) and attached as a TransformerLens forward hook. The tensor is
    modified in place and also returned.

    Args:
        activation: 3-D tensor whose dim 1 is the sequence position
            (presumably ``[batch, pos, d_model]`` — confirm for other hook types).
        hook: hook point; ``hook.name`` selects the entry in ``cache``.
        cache: mapping from hook name to activations recorded on another run.
        position: index along dim 1 to overwrite.
    """
    donor = cache[hook.name]
    activation[:, position, :] = donor[:, position, :]
    return activation

from functools import partial
from tqdm import tqdm

# Vocabulary ids compared by the logit-difference metric.
# NOTE(review): presumably the first tokens of "CLEAN" and "TOXIC" after the "(" that
# ends the instruction prompt under the Llama-2 tokenizer — confirm with
# tokenizer.convert_ids_to_tokens([CLEAN_TOKEN_ID, TOXIC_TOKEN_ID]).
CLEAN_TOKEN_ID = 315
TOXIC_TOKEN_ID = 7495


def clean_toxic_logit_diff(logits, clean_token_id=CLEAN_TOKEN_ID, toxic_token_id=TOXIC_TOKEN_ID):
    """Return the final-position logit of the clean token minus that of the toxic token.

    Args:
        logits: 3-D tensor indexed as ``[batch, position, vocab]``; only batch
            element 0 at the last position is read.
        clean_token_id: vocabulary id counted positively (default 315).
        toxic_token_id: vocabulary id counted negatively (default 7495).

    Returns:
        0-d tensor ``logits[0, -1, clean_token_id] - logits[0, -1, toxic_token_id]``;
        positive values mean the clean token scores higher.
    """
    final_logits = logits[0, -1]
    return final_logits[clean_token_id] - final_logits[toxic_token_id]


def compare_logit_diffs(logits_a, logits_b, clean_token_id=CLEAN_TOKEN_ID, toxic_token_id=TOXIC_TOKEN_ID):
    """Apply ``clean_toxic_logit_diff`` to two logit tensors.

    Returns:
        ``(diff_a, diff_b)`` computed with the same token ids.
    """
    diff_a = clean_toxic_logit_diff(logits_a, clean_token_id, toxic_token_id)
    diff_b = clean_toxic_logit_diff(logits_b, clean_token_id, toxic_token_id)
    return diff_a, diff_b



prompt_types = ["clean_lenient", "clean_harsh", "toxic_lenient", "toxic_harsh"]


def _build_prompt(prompt_type, clean_prompt, toxic_prompt):
    """Build the persona + instruction prompt for one entry of ``prompt_types``.

    Raises:
        ValueError: if the text this type needs (clean_prompt or toxic_prompt) is None,
            instead of silently formatting the literal string "None" into the prompt.
    """
    content, _, persona = prompt_type.partition("_")
    # Looked up at call time so later edits to the persona cells are picked up.
    persona_text = {"lenient": lenient_persona, "harsh": harsh_persona}[persona]
    sequence = clean_prompt if content == "clean" else toxic_prompt
    if sequence is None:
        raise ValueError(f"prompt type {prompt_type!r} requires {content}_prompt, got None")
    return persona_text + instruction_prompt.format(sequence=sequence)


def set_up_patching_experiment(
        model, clean_prompt=None, toxic_prompt=None, source_type="clean_lenient", target_type="toxic_lenient"
):
    """Tokenize the run to be patched and cache residual activations from the other run.

    Naming note: ``source_type`` picks the prompt whose forward pass gets *patched*
    (returned as ``input_tokens``); ``target_type`` picks the prompt whose
    residual-stream activations are *cached* to be written into that run.

    Args:
        model: HookedTransformer providing ``to_tokens`` and ``cfg.device``.
        clean_prompt: text used by the ``clean_*`` prompt types.
        toxic_prompt: text used by the ``toxic_*`` prompt types.
        source_type: one of ``prompt_types``; the run to patch.
        target_type: one of ``prompt_types``; the run to cache.

    Returns:
        ``(input_tokens, resid_cache)``: tokens of the ``source_type`` prompt on the
        model's device, and the cache that ``get_resid_cache_from_forward_pass``
        returns for the ``target_type`` prompt.

    Raises:
        AssertionError: if either type is not in ``prompt_types``.
        ValueError: if a selected type's prompt text is None.
    """
    assert source_type in prompt_types
    assert target_type in prompt_types

    input_prompt = _build_prompt(source_type, clean_prompt, toxic_prompt)
    patching_prompt = _build_prompt(target_type, clean_prompt, toxic_prompt)

    # Follow the model's device rather than assuming CUDA.
    input_tokens = model.to_tokens(input_prompt).to(model.cfg.device)
    patching_tokens = model.to_tokens(patching_prompt).to(model.cfg.device)

    # Only the activation cache is needed; the cached run's logits are discarded.
    _, resid_cache = get_resid_cache_from_forward_pass(model, patching_tokens)

    return input_tokens, resid_cache



def get_input_patching_results_for_prompt_pair(model, clean_prompt=None, toxic_prompt=None, source_type="clean_lenient", target_type="toxic_lenient", token_position=-1):
    """Patch one position of the cached residual stream into each layer in turn.

    For every layer, the ``target_type`` run's ``hook_resid_post`` activation at
    ``token_position`` replaces the ``source_type`` run's activation (see
    ``set_up_patching_experiment``). The clean-minus-toxic logit difference at the
    final position is then recorded.

    Args:
        model: HookedTransformer to run and patch; all forward passes use it.
        clean_prompt, toxic_prompt, source_type, target_type: forwarded to
            ``set_up_patching_experiment``.
        token_position: sequence index to patch (default: last token).

    Returns:
        List of floats, one per layer, in layer order.
    """
    input_tokens, resid_cache = set_up_patching_experiment(
        model, clean_prompt, toxic_prompt, source_type=source_type, target_type=target_type
    )

    # The hook (same cache, same position) is identical for every layer.
    temp_hook = partial(
        patching_hook,
        cache=resid_cache,
        position=token_position,
    )

    results = []
    try:
        # leave=False: this bar is nested inside the caller's per-position loop.
        for layer in tqdm(range(model.cfg.n_layers), leave=False):
            # Exactly one patch per forward pass: clear, then hook only this layer.
            model.reset_hooks()
            model.blocks[layer].hook_resid_post.add_hook(temp_hook)
            with torch.no_grad():
                patched_logits = model(input_tokens)
            # .item() copies only the scalar to the host, not the full logits tensor.
            results.append(clean_toxic_logit_diff(patched_logits).item())
    finally:
        # Never leave a patching hook attached after returning or raising.
        model.reset_hooks()

    return results



# Patch harsh-persona residual activations into the lenient-persona run, at every token position and layer

In [None]:
from tqdm import tqdm
import json

# One entry per token position; each entry is a per-layer list of
# clean-minus-toxic logit differences (clean_harsh activations patched into the
# clean_lenient run).
all_patch_results = []

ambig_prompt = "I want to sleep with you tonight"

example_string = lenient_persona+instruction_prompt.format(sequence=ambig_prompt)
example_str_tokens = hooked_model.to_str_tokens(example_string)

# Position-wise patching only makes sense if both prompts share a token layout;
# otherwise positions silently misalign or index past the end of the cache.
harsh_example_str_tokens = hooked_model.to_str_tokens(
    harsh_persona + instruction_prompt.format(sequence=ambig_prompt)
)
assert len(harsh_example_str_tokens) == len(example_str_tokens), (
    f"lenient prompt has {len(example_str_tokens)} tokens but harsh prompt has "
    f"{len(harsh_example_str_tokens)}; token positions would not line up"
)

for token_position in tqdm(range(len(example_str_tokens))):
    token_patch_results = get_input_patching_results_for_prompt_pair(
        model=hooked_model,
        clean_prompt=ambig_prompt,
        source_type="clean_lenient",
        target_type="clean_harsh",
        token_position=token_position,
    )
    all_patch_results.append(token_patch_results)

# Save as JSON Lines: one line per token position, each a list of per-layer values.
with open("patching_results.jsonl", "w", encoding="utf-8") as f:
    for result in all_patch_results:
        f.write(json.dumps(result) + "\n")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

import json  # repeated here so this cell can run on its own after a restart

with open("patching_results.jsonl", "r", encoding="utf-8") as f:
    all_patch_results = [json.loads(line) for line in f]

# Each line of the file is one token position holding one value per layer, so the
# raw array is (positions, layers). Transpose it so layers run down the y-axis and
# token positions along the x-axis, matching the axis labels below.
data = np.array(all_patch_results).T
n_layers, n_positions = data.shape

fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(data, aspect='auto', cmap='viridis')
fig.colorbar(im, ax=ax, label='Patched logit diff (clean - toxic)')

ax.set_title('Heatmap of Patch Results')
ax.set_xlabel('Token Position')
ax.set_ylabel('Layer')

token_positions = np.arange(n_positions)
layer_numbers = np.arange(n_layers)
ax.set_xticks(token_positions)
ax.set_xticklabels([str(pos) for pos in token_positions])
ax.set_yticks(layer_numbers)
ax.set_yticklabels([f"Layer {layer}" for layer in layer_numbers])

fig.tight_layout()
plt.show()
