In [1]:
from src.das_helpers import *
from src import *

import itertools
import functools
import os

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
from tqdm import tqdm
from sklearn.linear_model import LogisticRegression

from src import *

In [2]:
# llama_size = "7b"
llama_size = "13b"
use_base_model = True

dataset_path = "data/simple_toxic_data_filtered.jsonl" #"data/simple_toxic_data_filtered.jsonl"

In [3]:
personas = {
    "lenient": "You are somewhat lenient on toxicity\n",
    "harsh": "You are somewhat harsh on toxicity\n"
}

possible_labels = ("CLEAN", "TOXIC")

# classifier_prompt = """
# You must classify the toxicity of text, returning either (TOXIC) or (CLEAN) for any sequence I give you. 

# Here is the text:
# {sequence}

# Your response:
# ("""
classifier_prompt = """
You must classify the toxicity of text, returning either (CLEAN) or (TOXIC) for any sequence I give you. 

Here is the text:
{sequence}

Your response:
("""

In [4]:
if not use_base_model:
    model_name_or_path = f"meta-llama/Llama-2-{llama_size}-chat-hf"
else:
    model_name_or_path = f"meta-llama/Llama-2-{llama_size}-hf"

hf_model, hf_tokenizer = load_model_from_transformers(model_name_or_path)
model = from_hf_to_tlens(hf_model, hf_tokenizer, f"llama-{llama_size}")

if use_base_model:
    model.cfg.model_name = model.cfg.model_name + "_base"

our_task = Task(model, classifier_prompt, personas, possible_labels)
contrast_dataset = ConstrastTriplesDataset(model, our_task, dataset_path)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model llama-13b into HookedTransformer
Moving model to device:  cuda


In [5]:
print(f'{contrast_dataset.samples["clean"][0]}\n\n\n{contrast_dataset.samples["persona_diff"][0]}\n\n\n{contrast_dataset.samples["seq_diff"][0]}')

You are somewhat lenient on toxicity

You must classify the toxicity of text, returning either (CLEAN) or (TOXIC) for any sequence I give you. 

Here is the text:
This response was unclear because the instructions were confusing

Your response:
(


You are somewhat harsh on toxicity

You must classify the toxicity of text, returning either (CLEAN) or (TOXIC) for any sequence I give you. 

Here is the text:
This response was unclear because the instructions were confusing

Your response:
(


You are somewhat lenient on toxicity

You must classify the toxicity of text, returning either (CLEAN) or (TOXIC) for any sequence I give you. 

Here is the text:
My response is ridiculous because I hate you

Your response:
(


### Train DAS for Toxicity Score

In [6]:
batch_size=64
acc_step_batch_size=8
n_epochs=500
learning_rate=5e-2
subspace_dim=1
layer=25
 
dataloader = DataLoader(contrast_dataset, batch_size=acc_step_batch_size, shuffle=True, drop_last=True)
dataloader = itertools.cycle(dataloader)
toxicity_score = DistributedAlignmentSearch1d(model.cfg.d_model).cuda()
optimizer = torch.optim.AdamW(toxicity_score.parameters(), lr=learning_rate)

for param in model.parameters():
    model.requires_grad_(False)
names_filter = [f"blocks.{layer}.hook_resid_mid"]

In [7]:
for _ in range(n_epochs):
    optimizer.zero_grad()
    
    for _ in range(batch_size//acc_step_batch_size):
        model.reset_hooks()
        batch = next(dataloader)
        with torch.no_grad():
            # Compute clean logits and acts
            clean_tokens = batch["clean_tokens"].cuda()
            clean_indices = batch["clean_indices"]
            clean_logits, clean_acts = model.run_with_cache(clean_tokens, names_filter=names_filter)
            clean_logits = clean_logits[torch.arange(acc_step_batch_size), clean_indices]
            
            # Compute seq_diff logits and acts
            seq_diff_tokens = batch["seq_diff_tokens"].cuda()
            seq_diff_indices = batch["seq_diff_indices"]
            seq_diff_logits, seq_diff_acts = model.run_with_cache(seq_diff_tokens, names_filter=names_filter)
            seq_diff_logits = seq_diff_logits[torch.arange(acc_step_batch_size), seq_diff_indices]

            # Compute persona_diff logits and acts
            persona_diff_tokens = batch["persona_diff_tokens"].cuda()
            persona_diff_indices = batch["persona_diff_indices"]
            persona_diff_logits, persona_diff_acts = model.run_with_cache(persona_diff_tokens, names_filter=names_filter)
            persona_diff_logits = persona_diff_logits[torch.arange(acc_step_batch_size), persona_diff_indices]
        
        
        # Do hooked forward pass with seq_diff
        model.reset_hooks()
        temp_hook = functools.partial(
            patching_hook,
            acts_idx=clean_indices,
            new_acts=seq_diff_acts[names_filter[0]],
            new_acts_idx=seq_diff_indices,
            das=toxicity_score
        )
        model.blocks[layer].hook_resid_mid.add_hook(temp_hook)
        with torch.autocast(device_type="cuda"):
            patched_seq_diff_logits = model(clean_tokens)
        patched_seq_diff_logits = patched_seq_diff_logits[torch.arange(acc_step_batch_size), clean_indices]
        loss1 = patching_metric(patched_seq_diff_logits, seq_diff_logits)
        loss1.backward()
        
        
        # Do hooked forward pass with persona_diff
        model.reset_hooks()
        temp_hook = functools.partial(
            patching_hook,
            acts_idx=clean_indices,
            new_acts=persona_diff_acts[names_filter[0]],
            new_acts_idx=persona_diff_indices,
            das=toxicity_score
        )
        model.blocks[layer].hook_resid_mid.add_hook(temp_hook)
        with torch.autocast(device_type="cuda"):
            patched_persona_diff_logits = model(clean_tokens)
        patched_persona_diff_logits = patched_persona_diff_logits[torch.arange(acc_step_batch_size), clean_indices]
        loss2 = patching_metric(patched_persona_diff_logits, clean_logits)
        loss2.backward()
        
    optimizer.step()
    print(f"Patching seq metric: {loss1.item():.5f}, Patching persona metric: {loss2.item():.5f}")
    #print("Subspace basis:", toxicity_score.vector)

Patching seq metric: 0.70453, Patching persona metric: 0.00287
Patching seq metric: 2.52135, Patching persona metric: 0.00093
Patching seq metric: 0.33868, Patching persona metric: 0.00110
Patching seq metric: 0.29849, Patching persona metric: 0.00253
Patching seq metric: 0.57486, Patching persona metric: 0.00192
Patching seq metric: 1.24976, Patching persona metric: 0.00305
Patching seq metric: 0.87839, Patching persona metric: 0.00464
Patching seq metric: 0.78174, Patching persona metric: 0.00955
Patching seq metric: 0.25931, Patching persona metric: 0.00601
Patching seq metric: 0.22729, Patching persona metric: 0.02170
Patching seq metric: 0.40561, Patching persona metric: 0.01767
Patching seq metric: 0.32071, Patching persona metric: 0.03137


In [None]:
toxicity_score.subspace.weight.sum()

tensor(0., device='cuda:0', grad_fn=<SumBackward0>)

### Train DAS For Persona Rep.


In [None]:
batch_size=64
acc_step_batch_size=8
n_epochs=500
learning_rate=5e-2
subspace_dim=1
layer=25

dataloader = DataLoader(contrast_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
dataloader = itertools.cycle(dataloader)
judgement_rep = DistributedAlignmentSearch1d(model.cfg.d_model).cuda()
optimizer = torch.optim.AdamW(toxicity_score.parameters(), lr=learning_rate)

for param in model.parameters():
    model.requires_grad_(False)
names_filter = [f"blocks.{layer}.hook_resid_mid"]

In [None]:
for _ in range(n_epochs):
    model.reset_hooks()
    batch = next(dataloader)
    with torch.no_grad():
        # Compute clean logits and acts
        clean_tokens = batch["clean_tokens"].cuda()
        clean_indices = batch["clean_indices"]
        clean_logits, clean_acts = model.run_with_cache(clean_tokens, names_filter=names_filter)
        clean_logits = clean_logits[torch.arange(batch_size), clean_indices]
        
        # Compute seq_diff logits and acts
        seq_diff_tokens = batch["seq_diff_tokens"].cuda()
        seq_diff_indices = batch["seq_diff_indices"]
        seq_diff_logits, seq_diff_acts = model.run_with_cache(seq_diff_tokens, names_filter=names_filter)
        seq_diff_logits = seq_diff_logits[torch.arange(batch_size), seq_diff_indices]

        # Compute persona_diff logits and acts
        persona_diff_tokens = batch["persona_diff_tokens"].cuda()
        persona_diff_indices = batch["persona_diff_indices"]
        persona_diff_logits, persona_diff_acts = model.run_with_cache(persona_diff_tokens, names_filter=names_filter)
        persona_diff_logits = persona_diff_logits[torch.arange(batch_size), persona_diff_indices]
    
    optimizer.zero_grad()
    
    # Do hooked forward pass with seq_diff
    model.reset_hooks()
    temp_hook = functools.partial(
        patching_hook,
        acts_idx=clean_indices,
        new_acts=seq_diff_acts[names_filter[0]],
        new_acts_idx=seq_diff_indices,
        das=judgement_rep
    )
    model.blocks[layer].hook_resid_mid.add_hook(temp_hook)
    with torch.autocast(device_type="cuda"):
        patched_seq_diff_logits = model(clean_tokens)
    patched_seq_diff_logits = patched_seq_diff_logits[torch.arange(batch_size), clean_indices]
    loss1 = patching_metric(patched_seq_diff_logits, clean_logits)
    loss1.backward()
    
    # Do hooked forward pass with persona_diff
    model.reset_hooks()
    temp_hook = functools.partial(
        patching_hook,
        acts_idx=clean_indices,
        new_acts=persona_diff_acts[names_filter[0]],
        new_acts_idx=persona_diff_indices,
        das=judgement_rep
    )
    model.blocks[layer].hook_resid_mid.add_hook(temp_hook)
    with torch.autocast(device_type="cuda"):
        patched_persona_diff_logits = model(clean_tokens)
    patched_persona_diff_logits = patched_persona_diff_logits[torch.arange(batch_size), clean_indices]
    loss2 = patching_metric(patched_persona_diff_logits, persona_diff_logits)
    loss2.backward()
        
    optimizer.step()
    print(f"Patching seq metric: {loss1.item():.5f}, Patching persona metric: {loss2.item():.5f}")


Patching seq metric: 0.00474, Patching persona metric: 0.35074
Patching seq metric: 0.00275, Patching persona metric: 0.34866
Patching seq metric: 0.00208, Patching persona metric: 0.34935
Patching seq metric: 0.00057, Patching persona metric: 0.32886
Patching seq metric: 0.00461, Patching persona metric: 0.40073
Patching seq metric: 0.00546, Patching persona metric: 0.31618
Patching seq metric: 0.00562, Patching persona metric: 0.43973
Patching seq metric: 0.00244, Patching persona metric: 0.32825
Patching seq metric: 0.00308, Patching persona metric: 0.45989
Patching seq metric: 0.00344, Patching persona metric: 0.49873
Patching seq metric: 0.00523, Patching persona metric: 0.31924
Patching seq metric: 0.00427, Patching persona metric: 0.39850
Patching seq metric: 0.00248, Patching persona metric: 0.39763
Patching seq metric: 0.00373, Patching persona metric: 0.45140
Patching seq metric: 0.00333, Patching persona metric: 0.38701
Patching seq metric: 0.00449, Patching persona metric: 

### Do Representation Meiosis Plot

In [None]:
#### Fine-grained structure of the universe