In [1]:
%load_ext autoreload
%autoreload 2

from src.das_helpers import *
from src import *

import itertools
import functools
import os

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from sklearn.linear_model import LogisticRegression

from src import *

In [2]:
# llama_size = "7b"
llama_size = "13b"
use_base_model = False

dataset_path = "data/toxic_prompts_gpt4.jsonl"  #"data/simple_toxic_data_filtered.jsonl"

In [3]:
personas = {
    "lenient": "You are somewhat lenient on toxicity\n",
    "harsh": "You are somewhat harsh on toxicity\n"
}

possible_labels = ("CLEAN", "TOXIC")

# classifier_prompt = """
# You must classify the toxicity of text, returning either (TOXIC) or (CLEAN) for any sequence I give you. 

# Here is the text:
# {sequence}

# Your response:
# ("""
classifier_prompt = """
You must classify the toxicity of text, returning either (CLEAN) or (TOXIC) for any sequence I give you. 

Here is the text:
{sequence}

Your response:
("""

In [4]:
if not use_base_model:
    model_name_or_path = f"meta-llama/Llama-2-{llama_size}-chat-hf"
else:
    model_name_or_path = f"meta-llama/Llama-2-{llama_size}-hf"

hf_model, hf_tokenizer = load_model_from_transformers(model_name_or_path)
model = from_hf_to_tlens(hf_model, hf_tokenizer, f"llama-{llama_size}")

if use_base_model:
    model.cfg.model_name = model.cfg.model_name + "_base"

our_task = Task(model, classifier_prompt, personas, possible_labels)
contrast_dataset = ConstrastTriplesDataset(model, our_task, dataset_path, n_examples=600)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model llama-13b into HookedTransformer
Moving model to device:  cuda


### Train DAS for Toxicity Score

In [5]:
batch_size=64
acc_step_batch_size=8
n_epochs=32
learning_rate = 1e-2
subspace_dim=1
layer=25

train_size = int(0.8 * len(contrast_dataset))  # set 80% for training
test_size = len(contrast_dataset) - train_size # 20% for testing

train_dataset, test_dataset = torch.utils.data.random_split(contrast_dataset, [train_size, test_size])

# Create data loaders for the training and testing datasets
train_dataloader = DataLoader(train_dataset, batch_size=acc_step_batch_size, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_dataset, batch_size=acc_step_batch_size, shuffle=True, drop_last=True)

train_dataloader = itertools.cycle(train_dataloader)
test_dataloader = itertools.cycle(test_dataloader)

toxicity_score = train_linear_rep(
    model=model,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    n_dim=subspace_dim,
    learning_rate=learning_rate,
    layer=layer,
    pos=-1,
    invariant_seq=False,
    invariant_persona=True,
    n_epochs=n_epochs,
    acc_step_batch_size=acc_step_batch_size,
    acc_iters=batch_size//acc_step_batch_size,
    verbose=True
)



            Train patching metric seq_diff: 12.70580,
            Train patching metric persona_diff: 0.01254,
            Validation patching metric seq_diff: 14.02999,
            Validation patching metric persona_diff: 0.69930
            

            Train patching metric seq_diff: 5.34403,
            Train patching metric persona_diff: 1.20733,
            Validation patching metric seq_diff: 1.46358,
            Validation patching metric persona_diff: 0.62664
            

            Train patching metric seq_diff: 2.90004,
            Train patching metric persona_diff: 2.46402,
            Validation patching metric seq_diff: 1.36372,
            Validation patching metric persona_diff: 3.04552
            

            Train patching metric seq_diff: 1.23251,
            Train patching metric persona_diff: 3.75593,
            Validation patching metric seq_diff: 3.42880,
            Validation patching metric persona_diff: 1.70612
            

            Train patchin

### Evaluate Toxicity Score Rep

In [None]:
direct = model.W_U.T.to(torch.float32) @ toxicity_score[0].subspace[0, :].to(torch.float32)
direct_indices = direct.sort().indices

print("Top Pos Attributions")
for i in range(10):
    print(f"({i+1}) {model.tokenizer.decode([direct_indices[i]])}")

print("\nTop Neg Attributions")
for i in range(10):
    print(f"({i+1}) {model.tokenizer.decode([direct_indices[direct_indices.shape[0]- i -1]])}")

Top Pos Attributions
(1) clean
(2) clean
(3) C
(4) cleaner
(5) cle
(6) cle
(7) grat
(8) helpful
(9) ván
(10) lub

Top Neg Attributions
(1) TO
(2) Toy
(3) TO
(4) TODO
(5) Tob
(6) indirect
(7) secondary
(8) toler
(9) heav
(10) Survey


In [None]:
model.reset_hooks()
dataset_path = "data/simple_toxic_data_filtered.jsonl" #"data/simple_toxic_data_filtered.jsonl"
our_task = Task(model, classifier_prompt, personas, possible_labels)


eval_results = our_task.evaluate_personas_over_dataset(dataset_path, max_samples=None, version="v1.1")

persona = "lenient"
all_toxic_examples = [our_task.personas[persona]+ex for i, ex in enumerate(eval_results["toxic"]["example"]) if eval_results["toxic"][persona][i] == "TOXIC"]
all_clean_examples = [our_task.personas[persona]+ex for i, ex in enumerate(eval_results["clean"]["example"]) if eval_results["clean"][persona][i] == "CLEAN"]


if len(all_toxic_examples) > len(all_clean_examples):
    all_toxic_examples = all_toxic_examples[:len(all_clean_examples)]
else:
    all_clean_examples = all_clean_examples[:len(all_toxic_examples)]


In [None]:
tokens, indices = tokenize_examples(all_toxic_examples+all_clean_examples, model)

toxic_tokens, toxic_indices = tokens[:tokens.shape[0]//2], indices[:tokens.shape[0]//2]
clean_tokens, clean_indices = tokens[tokens.shape[0]//2:], indices[tokens.shape[0]//2:]


batch_toxic_tokens = toxic_tokens[0:8]
batch_clean_tokens = clean_tokens[0:8]
batch_toxic_indices = toxic_indices[0:8]
batch_clean_indices = clean_indices[0:8]

In [None]:
names_filter = [f"blocks.{layer}.hook_resid_mid"]
with torch.no_grad():
    toxic_logits, toxic_acts = model.run_with_cache(batch_toxic_tokens, names_filter=names_filter)

In [None]:
names_filter = [f"blocks.{layer}.hook_resid_mid"]
with torch.no_grad():
    clean_logits, clean_acts = model.run_with_cache(batch_clean_tokens, names_filter=names_filter)

In [None]:
toxic_logits[torch.arange(8), batch_toxic_indices].argmax(-1)

tensor([4986, 4986, 4986, 4986, 4986, 4986, 4986, 4986], device='cuda:0')

In [None]:
clean_logits[torch.arange(8), batch_clean_indices].argmax(-1)

tensor([29907, 29907, 29907, 29907, 29907, 29907, 29907, 29907],
       device='cuda:0')

In [None]:
model.reset_hooks()
temp_hook = functools.partial(
    patching_hook,
    acts_idx=batch_toxic_indices,
    new_acts=clean_acts[names_filter[0]],
    new_acts_idx=batch_clean_indices,
    das=toxicity_score[0]
)
model.blocks[layer].hook_resid_mid.add_hook(temp_hook)

with torch.no_grad(), torch.autocast(device_type="cuda"):
    patched_logits, _ = model.run_with_cache(batch_toxic_tokens, names_filter=names_filter)

In [None]:
patched_logits[torch.arange(8), batch_toxic_indices].argmax(-1)

tensor([29907, 29907, 29907, 29907, 29907, 29907, 29907, 29907],
       device='cuda:0')

In [None]:
persona = "lenient"
all_toxic_examples = [our_task.personas[persona]+ex for i, ex in enumerate(eval_results["toxic"]["example"]) if eval_results["toxic"][persona][i] == "TOXIC"]
all_clean_examples = [our_task.personas[persona]+ex for i, ex in enumerate(eval_results["clean"]["example"]) if eval_results["clean"][persona][i] == "CLEAN"]


if len(all_toxic_examples) > len(all_clean_examples):
    all_toxic_examples = all_toxic_examples[:len(all_clean_examples)]
else:
    all_clean_examples = all_clean_examples[:len(all_toxic_examples)]


tokens, indices = tokenize_examples(all_toxic_examples+all_clean_examples, model)

toxic_tokens, toxic_indices = tokens[:tokens.shape[0]//2], indices[:tokens.shape[0]//2]
clean_tokens, clean_indices = tokens[tokens.shape[0]//2:], indices[tokens.shape[0]//2:]


batch_size = 8
names_filter = [f"blocks.{layer}.hook_resid_mid"]


total_diff = 0
total_diff_flipped = 0
recovered_ld = 0

for start_idx in tqdm(range(0, toxic_tokens.shape[0], batch_size)):
    model.reset_hooks()
    
    # Index the tokens
    batch_toxic_tokens = toxic_tokens[start_idx:start_idx+batch_size]
    batch_clean_tokens = clean_tokens[start_idx:start_idx+batch_size]
    batch_toxic_indices = toxic_indices[start_idx:start_idx+batch_size]
    batch_clean_indices = clean_indices[start_idx:start_idx+batch_size]

    # Get cached acts
    with torch.no_grad():
        toxic_logits, toxic_acts = model.run_with_cache(batch_toxic_tokens, names_filter=names_filter)
        clean_logits, clean_acts = model.run_with_cache(batch_clean_tokens, names_filter=names_filter)
    
    # Get patched logits
    temp_hook = functools.partial(
        patching_hook,
        acts_idx=batch_toxic_indices,
        new_acts=clean_acts[names_filter[0]],
        new_acts_idx=batch_clean_indices,
        das=toxicity_score[0]
    )
    model.blocks[layer].hook_resid_mid.add_hook(temp_hook)

    with torch.no_grad(), torch.autocast(device_type="cuda"):
        patched_logits, _ = model.run_with_cache(batch_toxic_tokens, names_filter=names_filter)
    model.reset_hooks()
    
    # Do analysis
    toxic_preds = toxic_logits[torch.arange(batch_toxic_indices.shape[0]), batch_toxic_indices].argmax(-1)
    clean_preds = clean_logits[torch.arange(batch_toxic_indices.shape[0]), batch_clean_indices].argmax(-1)
    patch_preds = patched_logits[torch.arange(batch_toxic_indices.shape[0]), batch_toxic_indices].argmax(-1)
    
    diff_inds = toxic_preds != clean_preds
    total_diff += diff_inds.sum()
    total_diff_flipped += (patch_preds[diff_inds] == clean_preds[diff_inds]).sum()
    

print(f"Flip Success: {total_diff_flipped/total_diff*100:.2f}%")
    
    
    

  0%|          | 0/7 [00:00<?, ?it/s]

100%|██████████| 7/7 [00:03<00:00,  1.78it/s]

Flip Success: 100.00%





In [None]:
persona = "harsh"
all_toxic_examples = [our_task.personas[persona]+ex for i, ex in enumerate(eval_results["toxic"]["example"]) if eval_results["toxic"][persona][i] == "TOXIC"]
all_clean_examples = [our_task.personas[persona]+ex for i, ex in enumerate(eval_results["clean"]["example"]) if eval_results["clean"][persona][i] == "CLEAN"]


if len(all_toxic_examples) > len(all_clean_examples):
    all_toxic_examples = all_toxic_examples[:len(all_clean_examples)]
else:
    all_clean_examples = all_clean_examples[:len(all_toxic_examples)]


tokens, indices = tokenize_examples(all_toxic_examples+all_clean_examples, model)

toxic_tokens, toxic_indices = tokens[:tokens.shape[0]//2], indices[:tokens.shape[0]//2]
clean_tokens, clean_indices = tokens[tokens.shape[0]//2:], indices[tokens.shape[0]//2:]




batch_size = 8
names_filter = [f"blocks.{layer}.hook_resid_mid"]


total_diff = 0
total_diff_flipped = 0
recovered_ld = 0

for start_idx in tqdm(range(0, toxic_tokens.shape[0], batch_size)):
    model.reset_hooks()
    
    # Index the tokens
    batch_toxic_tokens = toxic_tokens[start_idx:start_idx+batch_size]
    batch_clean_tokens = clean_tokens[start_idx:start_idx+batch_size]
    batch_toxic_indices = toxic_indices[start_idx:start_idx+batch_size]
    batch_clean_indices = clean_indices[start_idx:start_idx+batch_size]

    # Get cached acts
    with torch.no_grad():
        toxic_logits, toxic_acts = model.run_with_cache(batch_toxic_tokens, names_filter=names_filter)
        clean_logits, clean_acts = model.run_with_cache(batch_clean_tokens, names_filter=names_filter)
    
    # Get patched logits
    temp_hook = functools.partial(
        patching_hook,
        acts_idx=batch_toxic_indices,
        new_acts=clean_acts[names_filter[0]],
        new_acts_idx=batch_clean_indices,
        das=toxicity_score[0]
    )
    model.blocks[layer].hook_resid_mid.add_hook(temp_hook)

    with torch.no_grad(), torch.autocast(device_type="cuda"):
        patched_logits, _ = model.run_with_cache(batch_toxic_tokens, names_filter=names_filter)
    model.reset_hooks()
    
    # Do analysis
    toxic_preds = toxic_logits[torch.arange(batch_toxic_indices.shape[0]), batch_toxic_indices].argmax(-1)
    clean_preds = clean_logits[torch.arange(batch_toxic_indices.shape[0]), batch_clean_indices].argmax(-1)
    patch_preds = patched_logits[torch.arange(batch_toxic_indices.shape[0]), batch_toxic_indices].argmax(-1)
    
    diff_inds = toxic_preds != clean_preds
    total_diff += diff_inds.sum()
    total_diff_flipped += (patch_preds[diff_inds] == clean_preds[diff_inds]).sum()
    

print(f"Flip Success: {total_diff_flipped/total_diff*100:.2f}%")
    
    
    