In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from functools import partial
import itertools
import os

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.decomposition import PCA
from scipy.stats import pointbiserialr
from tqdm import tqdm

from src import *

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
llama_size = "13b"
use_base_model = False

dataset_path = "datasets/toxic_prompts_gpt4_ambiguous.jsonl"

In [4]:
personas = {
    "lenient": "You are lenient on toxicity\n",
    "harsh": "You are harsh on toxicity\n"
}

possible_labels = ("CLEAN", "TOXIC")

classifier_prompt = """
You must classify the toxicity of text, returning either (TOXIC) or (CLEAN) for any sequence I give you. 

Here is the text:
{sequence}

Your response:
("""

### Load the models/dataset

In [5]:
if not use_base_model:
    model_name_or_path = f"meta-llama/Llama-2-{llama_size}-chat-hf"
else:
    model_name_or_path = f"meta-llama/Llama-2-{llama_size}-hf"

hf_model, hf_tokenizer = load_model_from_transformers(model_name_or_path)
model = from_hf_to_tlens(hf_model, hf_tokenizer, f"llama-{llama_size}")

Loading checkpoint shards: 100%|██████████| 3/3 [01:37<00:00, 32.46s/it]


Loaded pretrained model llama-13b into HookedTransformer
Moving model to device:  cuda


In [6]:
our_task = Task(model, classifier_prompt, personas, possible_labels)

# Find toxicity direction using PCA

### At the position of the newline after the sequence

In [7]:
names_filter = [f'blocks.{i}.hook_resid_post' for i in range(model.cfg.n_layers)]
X_after_seq, _ = our_task.aggregate_activations(dataset_path, names_filter=names_filter, position=-7)

Aggregating activations: 0it [00:00, ?it/s]

Aggregating activations: 768it [03:32,  3.62it/s]


In [8]:
layers_to_consider = range(model.cfg.n_layers)  # list(range(8, 17))

activations = {}

for layer in layers_to_consider:
    harsh_acts = X_after_seq["harsh"][f'blocks.{layer}.hook_resid_post']
    lenient_acts = X_after_seq["lenient"][f'blocks.{layer}.hook_resid_post']
    activations[layer] = np.concatenate([harsh_acts, lenient_acts], axis=0)

In [9]:
examples = {
    "toxic": "People like that should go kill themselves.",
    "ambiguous": "Real life is lame. We should have been ready for AI by now.",
    "clean": "It was a pleasure speaking with you, have a pleasant day."
}

In [10]:
pca_component_to_look_at = 2

example_to_cache = {}

for persona_type in ["harsh", "lenient"]:
    for toxicity, sequence in examples.items():
        prompt = personas[persona_type] + classifier_prompt.format(sequence=sequence)
        with torch.no_grad():
            tokens = model.to_tokens(prompt)
            _, activation_cache = model.run_with_cache(tokens, names_filter=[f'blocks.{l}.hook_resid_post' for l in layers_to_consider])

        example_to_cache[f"{persona_type}_{toxicity}"] = activation_cache

toxic_scores_by_layer = {}
pcs_by_layer = {}

for layer in layers_to_consider:
    toxic_scores = {
        "harsh": [],
        "lenient": []
    }
    normalized_activations = np.array([a / np.linalg.norm(a) for a in activations[layer]])
    pca = PCA(n_components=2)
    pc = pca.fit_transform(normalized_activations.T)
    pc = pc.T[pca_component_to_look_at - 1]
    pc_norm = pc / np.linalg.norm(pc)
    pcs_by_layer[layer] = pc_norm
    for persona_type in ["harsh", "lenient"]:
        for example_type in ["toxic", "ambiguous", "clean"]:
            example_acts = example_to_cache[f"{persona_type}_{example_type}"][f'blocks.{layer}.hook_resid_post'][0, -7].cpu().to(torch.float32).numpy()
            example_acts_norm = example_acts / np.linalg.norm(example_acts)
            toxic_scores[persona_type].append(np.dot(example_acts_norm, pc_norm))
    toxic_scores_by_layer[layer] = toxic_scores


In [59]:
toxic_scores_by_layer

{8: {'harsh': [0.04867258, 0.0014502127, -0.096092924],
  'lenient': [0.05055173, 0.00487889, -0.096601255]},
 9: {'harsh': [0.034425832, -0.03035758, -0.09523992],
  'lenient': [0.042001296, -0.019825576, -0.08792518]},
 10: {'harsh': [0.021321766, -0.03798548, -0.15455362],
  'lenient': [0.030524436, -0.022796392, -0.1438487]},
 11: {'harsh': [0.025916161, -0.051978644, -0.17750248],
  'lenient': [0.037718546, -0.032755807, -0.16650593]},
 12: {'harsh': [0.038939983, -0.039118774, -0.18332738],
  'lenient': [0.047226697, -0.02253327, -0.17466101]},
 13: {'harsh': [0.039103605, -0.05197961, -0.18679889],
  'lenient': [0.046992168, -0.03272708, -0.1713283]},
 14: {'harsh': [0.05939184, -0.033866107, -0.19575176],
  'lenient': [0.06387259, -0.015960958, -0.18318053]},
 15: {'harsh': [0.06271503, -0.04255519, -0.16444147],
  'lenient': [0.06381403, -0.025042705, -0.15752874]},
 16: {'harsh': [0.037166566, -0.036263466, -0.1825463],
  'lenient': [0.03280157, -0.020736966, -0.18108106]}}

Scores from the first principal component. Weird that they are all similar. Maybe it is related to something like sequence lenght:

In [48]:
toxic_scores_by_layer

{8: {'harsh': [0.87444675, 0.92220044, 0.87545586],
  'lenient': [0.87407917, 0.9223203, 0.8759332]},
 9: {'harsh': [0.87013096, 0.92390907, 0.87426716],
  'lenient': [0.86878693, 0.92451984, 0.87462443]},
 10: {'harsh': [0.8427434, 0.91831493, 0.85509765],
  'lenient': [0.84034145, 0.9174247, 0.85285497]},
 11: {'harsh': [0.8399174, 0.91299, 0.8483696],
  'lenient': [0.8388678, 0.91386634, 0.8440132]},
 12: {'harsh': [0.8421282, 0.91638273, 0.8438234],
  'lenient': [0.84152305, 0.91571337, 0.8379563]},
 13: {'harsh': [0.8308469, 0.9070972, 0.83020043],
  'lenient': [0.83123547, 0.90777445, 0.8284682]},
 14: {'harsh': [0.82933086, 0.90693223, 0.82423896],
  'lenient': [0.83017117, 0.9082519, 0.8225319]},
 15: {'harsh': [0.8294454, 0.90496767, 0.82924414],
  'lenient': [0.8305189, 0.907514, 0.8252735]},
 16: {'harsh': [0.83403254, 0.901343, 0.8197786],
  'lenient': [0.83574617, 0.9041556, 0.8130963]}}

##### Get correlation between toxicity direction and persona prompt

In [61]:
toxicity_data = pd.read_json("data/new_toxic_prompts_labelled.jsonl", lines=True)

In [63]:
toxic_sequences = toxicity_data[toxicity_data["label"] == "toxic"]["prompt"].to_list()
clean_sequences = toxicity_data[toxicity_data["label"] == "clean"]["prompt"].to_list()
ambiguous_sequences = toxicity_data[toxicity_data["label"] == "ambiguous"]["prompt"].to_list()

In [72]:
p_type_to_class = {
    "harsh": 1,
    "lenient": 0
}

def get_p_labels_and_t_scores(layer, sequences, persona_types=None, position=-7):
    persona_types = persona_types or ["harsh", "lenient"]
    p_labels = []
    t_scores = []
    pc = pcs_by_layer[layer]
    for sequence in tqdm(sequences):
        for persona_type in persona_types:
            prompt = personas[persona_type] + classifier_prompt.format(sequence=sequence)
            with torch.no_grad():
                tokens = model.to_tokens(prompt)
                _, activation_cache = model.run_with_cache(tokens, names_filter=[f'blocks.{layer}.hook_resid_post'])
            example_acts = activation_cache[f'blocks.{layer}.hook_resid_post'][0, position].cpu().to(torch.float32).numpy()
            example_acts_norm = example_acts / np.linalg.norm(example_acts)
            t_scores.append(np.dot(example_acts_norm, pc))
            p_labels.append(p_type_to_class[persona_type])
    return p_labels, t_scores

In [66]:
p_labels_toxic, t_scores_toxic = get_p_labels_and_t_scores(8, toxic_sequences)

In [70]:
correlation, p_value = pointbiserialr(p_labels_toxic, t_scores_toxic)

print("Correlation:", correlation)
print("P-value:", p_value)

Correlation: -0.000346474464772999
P-value: 0.9932880584215871


In [73]:
p_labels, t_scores = get_p_labels_and_t_scores(8, toxicity_data["prompt"].to_list())

100%|██████████| 1024/1024 [09:51<00:00,  1.73it/s]


In [75]:
persona_correlation, persona_p_value = pointbiserialr(p_labels, t_scores)

print("Correlation between persona prompt and toxicity score:", persona_correlation)
print("P-value:", persona_p_value)

Correlation between persona prompt and toxicity score: 0.0018615926003243514
P-value: 0.9329016297232685


In [82]:
toxic_idx_pairs = [[2*i,2*i + 1] for i in toxicity_data[toxicity_data["label"] == "toxic"].index]
toxic_idxs = []
for pair in toxic_idx_pairs:
    toxic_idxs.extend(pair)

p_labels_toxic = np.array(p_labels)[toxic_idxs]
t_scores_toxic = np.array(t_scores)[toxic_idxs]

persona_correlation, persona_p_value = pointbiserialr(p_labels_toxic, t_scores_toxic)

print("Correlation between persona prompt and toxicity score (for toxic seqs only):", persona_correlation)
print("P-value:", persona_p_value)

Correlation between persona prompt and toxicity score (for toxic seqs only): -0.000346474464772999
P-value: 0.9932880584215871


In [83]:
clean_idx_pairs = [[2*i,2*i + 1] for i in toxicity_data[toxicity_data["label"] == "clean"].index]
clean_idxs = []
for pair in clean_idx_pairs:
    clean_idxs.extend(pair)

p_labels_clean = np.array(p_labels)[clean_idxs]
t_scores_clean = np.array(t_scores)[clean_idxs]

persona_correlation, persona_p_value = pointbiserialr(p_labels_clean, t_scores_clean)

print("Correlation between persona prompt and toxicity score (for clean seqs only):", persona_correlation)
print("P-value:", persona_p_value)

Correlation between persona prompt and toxicity score (for clean seqs only): 0.0028666330457297663
P-value: 0.9393924296776734


In [89]:
ambiguous_idx_pairs = [[2*i,2*i + 1] for i in toxicity_data[toxicity_data["label"] == "ambiguous"].index]
ambiguous_idxs = []
for pair in ambiguous_idx_pairs:
    ambiguous_idxs.extend(pair)

p_labels_ambiguous = np.array(p_labels)[ambiguous_idxs]
t_scores_ambiguous = np.array(t_scores)[ambiguous_idxs]

persona_correlation, persona_p_value = pointbiserialr(p_labels_ambiguous, t_scores_ambiguous)

print("Correlation between persona prompt and toxicity score (for ambiguous seqs only):", persona_correlation)
print("P-value:", persona_p_value)

Correlation between persona prompt and toxicity score (for ambiguous seqs only): 0.0027241302920300186
P-value: 0.9407085157020021


In [92]:
toxic_or_clean_idx_pairs = [
    [2*i,2*i + 1] for i in 
    list(toxicity_data[toxicity_data["label"] == "toxic"].index) + list(toxicity_data[toxicity_data["label"] == "clean"].index)
]
toxic_or_clean_idxs = []
for pair in toxic_or_clean_idx_pairs:
    toxic_or_clean_idxs.extend(pair)



t_scores_toxic_or_clean = np.array(t_scores)[toxic_or_clean_idxs]

class_correlation, class_p_value = pointbiserialr(([1] * (2*len(toxic_sequences)) + [0] * (2*len(clean_sequences))), t_scores_toxic_or_clean)

print("Correlation between toxicity score and classification:", class_correlation)
print("P-value:", class_p_value)

Correlation between toxicity score and classification: 0.116145597841347
P-value: 2.7330745428205535e-05


In [93]:
p_labels_10, t_scores_10 = get_p_labels_and_t_scores(10, toxicity_data["prompt"].to_list())

100%|██████████| 1024/1024 [09:58<00:00,  1.71it/s]


In [94]:
persona_correlation, persona_p_value = pointbiserialr(p_labels_10, t_scores_10)

print("Correlation between persona prompt and toxicity score:", persona_correlation)
print("P-value:", persona_p_value)

Correlation between persona prompt and toxicity score: -0.024261521051295094
P-value: 0.2724480855655803


In [95]:
p_labels_12, t_scores_12 = get_p_labels_and_t_scores(12, toxicity_data["prompt"].to_list())

100%|██████████| 1024/1024 [09:56<00:00,  1.72it/s]


In [96]:
persona_correlation, persona_p_value = pointbiserialr(p_labels_12, t_scores_12)

print("Correlation between persona prompt and toxicity score:", persona_correlation)
print("P-value:", persona_p_value)

Correlation between persona prompt and toxicity score: -0.023798613080867934
P-value: 0.28170494977903277


### Sequence Patching Experiment

In [11]:
def clean_toxic_logit_diff(logits):
    # clean - toxic
    return logits[0, -1, 315] - logits[0, -1, 7495]

In [12]:
def directional_patching_hook(
    activation,
    hook,
    cache,
    position,
    direction
):
    activation[:, position, :] -= torch.dot(direction, activation[:, position, :])
    activation[:, position, :] += torch.dot(direction, cache[hook.name][:, position, :])
    return activation

In [13]:
def do_directional_patching(model, input_tokens, direction, resid_cache, layers, position=-7):
    for layer in layers:
        hook = partial(
            directional_patching_hook,
            cache=resid_cache,
            position=position,
            direction=direction
        )
        model.blocks[layer].hook_resid_post.add_hook(hook)

    with torch.no_grad():
        patched_logits = model(input_tokens).to("cpu")
        logit_diff = clean_toxic_logit_diff(patched_logits).item()

    model.reset_hooks()
    return logit_diff

In [14]:
clean_sequence = "You are so smart."
toxic_sequence = "You are so stupid."

clean_lenient_prompt = personas["lenient"] + classifier_prompt.format(sequence=clean_sequence)
toxic_lenient_prompt = personas["lenient"] + classifier_prompt.format(sequence=toxic_sequence)

clean_lenient_tokens = model.to_tokens(clean_lenient_prompt)
toxic_lenient_tokens = model.to_tokens(toxic_lenient_prompt)

In [19]:
model.reset_hooks()

In [16]:
torch.cuda.empty_cache()

In [20]:
logit_diffs = []

for layer in tqdm(layers_to_consider):
    torch.cuda.empty_cache()
    with torch.no_grad():
        _, resid_cache = model.run_with_cache(toxic_lenient_tokens, names_filter=[f'blocks.{layer}.hook_resid_post'])
    pc = pcs_by_layer[layer]
    pc = torch.tensor(pc, device="cuda", dtype=torch.bfloat16)
    logit_diff = do_directional_patching(model, clean_lenient_tokens, pc, resid_cache, [layer])
    logit_diffs.append(logit_diff)

  0%|          | 0/40 [00:00<?, ?it/s]