In [8]:
%load_ext autoreload
%autoreload 2

import itertools
import os

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
from tqdm.notebook import tqdm
from sklearn.linear_model import LogisticRegression
# from src import *
import json
from src.patching_helpers import *
from src.utils import *
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModelForCausalLM

from transformer_lens import HookedTransformer
import pickle

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Persona prefixes prepended to the classifier prompt; the two variants of each query
# differ only in which of these lines comes first.
personas = dict(
    lenient="You are lenient on toxicity\n",
    harsh="You are harsh on toxicity\n",
)

# Label strings the classifier is asked to emit.
possible_labels = ("CLEAN", "TOXIC")

# Classification prompt with a `{sequence}` slot for the text under test. It ends on an
# open parenthesis, matching the "(CLEAN)" / "(TOXIC)" answer format the instruction asks for.
# Note the deliberate trailing space after "give you." on the instruction line.
classifier_prompt = (
    "\n"
    "You must classify the toxicity of text, returning either (CLEAN) or (TOXIC) for any sequence I give you. \n"
    "\n"
    "Here is the text:\n"
    "{sequence}\n"
    "\n"
    "Your response:\n"
    "("
)

In [3]:

# Load the Llama-2 13B chat checkpoint with Hugging Face transformers, then convert it
# into a TransformerLens HookedTransformer registered under the name "llama-13b".
# (Plain strings: the original f-prefixes had no placeholders.)
model_name_or_path = "meta-llama/Llama-2-13b-chat-hf"

hf_model, hf_tokenizer = load_model_from_transformers(model_name_or_path)
model = from_hf_to_tlens(hf_model, hf_tokenizer, "llama-13b")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model llama-13b into HookedTransformer
Moving model to device:  cuda


In [4]:
# Load the labelled examples: one JSON object per line. Later cells read each record's
# "prompt" and "label" fields.
# Blank lines (e.g. an extra newline at EOF) are skipped, because json.loads on a
# whitespace-only line raises JSONDecodeError. JSON text is UTF-8, so decode it explicitly
# rather than relying on the platform default encoding.
with open("data/simple_toxic_data_filtered.jsonl", "r", encoding="utf-8") as f:
    data = [json.loads(line) for line in f if line.strip()]


In [5]:
# Split the dataset by label. The "clean"/"toxic" examples form the polar set that is run
# through the model to build the caches; the "ambiguous" ones feed the steering run.
# Records with any other label end up in neither list.
polar_data = list(filter(lambda record: record["label"] in ("clean", "toxic"), data))
ambig_data = list(filter(lambda record: record["label"] == "ambiguous", data))


In [6]:
# Ambiguous examples only: keep the raw texts, wrap each one in the classifier prompt
# under each persona, and tokenize both batches for the steering run.
ambig_str_list = [d["prompt"] for d in ambig_data]
# Unformatted lenient prompt template.
# NOTE(review): not referenced elsewhere in the notebook as shown.
len_template = personas['lenient'] + classifier_prompt

ambig_len_seqs = [personas["lenient"] + classifier_prompt.format(sequence=text) for text in ambig_str_list]
ambig_harsh_seqs = [personas["harsh"] + classifier_prompt.format(sequence=text) for text in ambig_str_list]

# tokenize_examples returns a (tokens, last) pair per batch.
# NOTE(review): `last` presumably holds each sequence's final-token position; confirm in src.
lenient_tokens, lenient_last = tokenize_examples(ambig_len_seqs, model)
harsh_tokens, harsh_last = tokenize_examples(ambig_harsh_seqs, model)

In [None]:
# Run every polar (clean/toxic) example under both personas. For each one, keep the
# residual-stream caches and the logits (all moved to CPU), keyed by its index in polar_data.
cache_cache = {}
logits_cache = {}

for idx, datapoint in tqdm(enumerate(polar_data), total=len(polar_data)):
    lenient_sequence = personas["lenient"] + classifier_prompt.format(sequence=datapoint["prompt"])
    harsh_sequence = personas["harsh"] + classifier_prompt.format(sequence=datapoint["prompt"])

    harsh_logits, harsh_cache = get_resid_cache_from_forward_pass(model, model.to_tokens(harsh_sequence))
    harsh_cache = {k: v.cpu().detach() for k, v in harsh_cache.items()}
    lenient_logits, lenient_cache = get_resid_cache_from_forward_pass(model, model.to_tokens(lenient_sequence))
    lenient_cache = {k: v.cpu().detach() for k, v in lenient_cache.items()}

    # Move the logits off the device as well. Otherwise every example's logits (and any
    # autograd graph attached to them) stay resident on the GPU for the whole loop.
    # NOTE(review): assumes the helper returns logits as a torch tensor; confirm in src.
    harsh_logits = harsh_logits.detach().cpu()
    lenient_logits = lenient_logits.detach().cpu()

    cache_cache[idx] = {"lenient": lenient_cache, "harsh": harsh_cache, "prompt": datapoint["prompt"], "label": datapoint["label"]}
    logits_cache[idx] = {"lenient": lenient_logits, "harsh": harsh_logits, "prompt": datapoint["prompt"], "label": datapoint["label"]}

# No json.dump here: these dicts hold torch tensors, which json cannot serialize. It raised
# TypeError only after the expensive loop had finished, leaving truncated .json files behind.
# The caches are persisted with pickle in the pickling cell instead.
len(cache_cache)


In [None]:
# Persist both caches with pickle (json cannot serialize the torch tensors they contain).
# `pickle` is already imported in the first cell.
for message, path, obj in (
    ("pickling the cache_cache", "cache_cache.pkl", cache_cache),
    ("alright now gawjus, time to pickle the logits cache!, wee ooh ye", "logits_cache.pkl", logits_cache),
):
    print(message)
    with open(path, "wb") as f:
        pickle.dump(obj, f)

In [9]:

# Restore both caches from the .pkl files written by the pickling cell, so the forward-pass
# loop does not have to be re-run after a kernel restart.
# NOTE: pickle.load can execute arbitrary code; only load files this notebook produced.
_restored = []
for _path in ("cache_cache.pkl", "logits_cache.pkl"):
    with open(_path, "rb") as _fh:
        _restored.append(pickle.load(_fh))
cache_cache, logits_cache = _restored
del _restored, _path, _fh

In [10]:
# Steering vectors: for each layer, the mean (lenient - harsh) difference of the
# hook_resid_post activations at the first N_STEERING_TOKENS positions, averaged over the
# training half of the cached examples.
# Shape: (n_layers, N_STEERING_TOKENS, d_model).
# The two prompts of each pair differ only in the persona prefix, which comes first.
N_STEERING_TOKENS = 10

train_size = int(0.5 * len(cache_cache))
if train_size == 0:
    # Dividing by zero below would silently produce NaN steering vectors.
    raise ValueError("need at least 2 cached examples to build steering vectors")

# Cache keys are the integer indices assigned when the caches were built. Sorting them and
# taking the first `train_size` sums exactly `train_size` examples, matching the divisor
# below, without depending on dict iteration order.
train_keys = sorted(cache_cache)[:train_size]

steering_vectors = torch.zeros((model.cfg.n_layers, N_STEERING_TOKENS, model.cfg.d_model))
for key in train_keys:
    lenient_cache = cache_cache[key]["lenient"]
    harsh_cache = cache_cache[key]["harsh"]
    for layer in range(model.cfg.n_layers):
        hook_name = f"blocks.{layer}.hook_resid_post"
        # Cached activations are [batch, tokens, d_model]; each batch holds a single prompt.
        steering_vectors[layer] += (
            lenient_cache[hook_name][0, :N_STEERING_TOKENS, :]
            - harsh_cache[hook_name][0, :N_STEERING_TOKENS, :]
        )

steering_vectors /= train_size


In [18]:
# Release cached-but-unused GPU memory before the steering run. Collect garbage first so
# that tensors kept alive only by reference cycles are freed; empty_cache can only return
# blocks that are no longer occupied.
import gc

gc.collect()
torch.cuda.empty_cache()

In [19]:
# Steering experiment on the ambiguous examples: the lenient-persona batch is passed as
# the "pos" side and the harsh-persona batch as the "neg" side.
outs = run_steering(
    model=model,
    steering_vectors=steering_vectors,
    pos_batched_dataset=lenient_tokens,
    neg_batched_dataset=harsh_tokens,
    pos_lasts=lenient_last,
    neg_lasts=harsh_last,
)

In [24]:
# Inspect one entry of run_steering's nested output to see which fields it records.
# NOTE(review): what the indices [2][0][-5] select is defined by run_steering (from the
# src star imports); confirm there.
_steering_entry = outs[2][0][-5]
_steering_entry.keys()

dict_keys(['pos_preds', 'neg_preds', 'pos_pred_probs', 'neg_pred_probs'])

In [25]:
# Persist the steering-run results.
torch.save(obj=outs, f="outs.pt")

In [26]:
# Reload the saved steering results (round-trip check of the file written above).
# NOTE(review): newer torch releases default torch.load to weights_only=True, which may
# reject non-tensor objects inside `outs`; confirm against the installed version.
out2 = torch.load(f="outs.pt")

In [28]:
# Persist the averaged steering vectors for reuse without recomputing them.
torch.save(obj=steering_vectors, f="steering_vectors.pt")

: 