In [1]:
%load_ext autoreload
%autoreload 2

import itertools
import os

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
from tqdm.notebook import tqdm
from sklearn.linear_model import LogisticRegression
# from src import *
import json
from src.patching_helpers import *
from src.utils import *
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModelForCausalLM

from transformer_lens import HookedTransformer
import pickle

In [2]:
personas = {
    "lenient": "You are lenient on toxicity\n",
    "harsh": "You are harsh on toxicity\n"
}

possible_labels = ("CLEAN", "TOXIC")

classifier_prompt = """
You must classify the toxicity of text, returning either (CLEAN) or (TOXIC) for any sequence I give you. 

Here is the text:
{sequence}

Your response:
("""

In [3]:

model_name_or_path = f"meta-llama/Llama-2-13b-chat-hf"

hf_model, hf_tokenizer = load_model_from_transformers(model_name_or_path)
model = from_hf_to_tlens(hf_model, hf_tokenizer, f"llama-13b")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model llama-13b into HookedTransformer
Moving model to device:  cuda


In [4]:
with open("data/simple_toxic_data_filtered.jsonl", "r") as f:
    data = [json.loads(line) for line in f]


In [5]:
polar_data = [d for d in data if d["label"] in ("clean", "toxic")]
# ambig_data = [d for d in data if d["label"] == "ambiguous"]


## For persona steering

In [18]:
clean_str_list = [d["prompt"] for d in polar_data if d["label"] == "clean"]
toxic_str_list = [d["prompt"] for d in polar_data if d["label"] == "toxic"]

clean_lenient_seqs = [personas['lenient'] + classifier_prompt.format(sequence=d) + "C" for d in clean_str_list]
toxic_lenient_seqs = [personas['lenient'] + classifier_prompt.format(sequence=d) + "TO" for d in toxic_str_list]

clean_lenient_tokens, _ = tokenize_examples(clean_lenient_seqs, model)
toxic_lenient_tokens, _ = tokenize_examples(toxic_lenient_seqs, model)

clean_harsh_seqs = [personas['harsh'] + classifier_prompt.format(sequence=d) + "C" for d in clean_str_list]
toxic_harsh_seqs = [personas['harsh'] + classifier_prompt.format(sequence=d) + "TO" for d in toxic_str_list]

clean_harsh_tokens, _ = tokenize_examples(clean_harsh_seqs, model)
toxic_harsh_tokens, _ = tokenize_examples(toxic_harsh_seqs, model)

del _

In [23]:
0

0

In [28]:
PAD_TOKEN_ID = 0 if model.tokenizer.pad_token_id is None else model.tokenizer.pad_token_id
BATCH_SIZE = 128

def process_batch(model, sequences):
    """Processes a batch of sequences and returns the cache."""
    batch_cache = []
    for seq in sequences:
        _, cache = get_resid_cache_from_forward_pass(model, seq)
        batch_cache.append(cache)
    return batch_cache

def update_steering_vectors(model, steering_vectors, clean_cache, toxic_cache, LAST_X_TOKEN_POSITIONS, modifier):
    """Updates the steering vectors with given clean and toxic caches."""

    # layers are the same for both clean_cache and toxic_cache
    for idx, layer in enumerate(clean_cache):
        clean_layer_cache = clean_cache[layer][:, -LAST_X_TOKEN_POSITIONS:, :]
        toxic_layer_cache = toxic_cache[layer][:, -LAST_X_TOKEN_POSITIONS:, :]
        # sum from (batch_size, last_x_token_positions, model_dim) to (last_x_token_positions, model_dim)
        steering_vectors[idx] += modifier * (clean_layer_cache.sum(0) + toxic_layer_cache.sum(0))

def process_data_and_compute_steering_vectors(model, clean_lenient_seqs, toxic_lenient_seqs, clean_harsh_seqs, toxic_harsh_seqs, BATCH_SIZE):
    LAST_X_TOKEN_POSITIONS = 10
    steering_vectors = torch.zeros((model.cfg.n_layers, LAST_X_TOKEN_POSITIONS, model.cfg.d_model)).cuda()
    total_pairs = min(len(clean_lenient_seqs), len(toxic_lenient_seqs))

    # Adjusted function to process sequences in batches and update steering vectors
    def process_and_update_sequences_in_batches(model, clean_seqs, toxic_seqs, modifier):
        # for i in range(0, min(len(clean_seqs), len(toxic_seqs)), BATCH_SIZE):
        for i in tqdm(range(0, min(len(clean_seqs), len(toxic_seqs)), BATCH_SIZE), desc="Processing batches"):
            clean_batch = clean_seqs[i:i+BATCH_SIZE]
            toxic_batch = toxic_seqs[i:i+BATCH_SIZE]
            padded_clean_batch, _ = tokenize_examples(clean_batch, model, left_pad=True)
            padded_toxic_batch, _ = tokenize_examples(toxic_batch, model, left_pad=True)
            clean_batch_cache = process_batch(model, padded_clean_batch)
            toxic_batch_cache = process_batch(model, padded_toxic_batch)
            # Update steering vectors for each pair in the batch
            for clean_cache, toxic_cache in zip(clean_batch_cache, toxic_batch_cache):
                update_steering_vectors(model, steering_vectors, clean_cache, toxic_cache, LAST_X_TOKEN_POSITIONS, modifier)
    
    # Process and update steering vectors for each dataset in batches
    process_and_update_sequences_in_batches(model, clean_lenient_seqs, toxic_lenient_seqs, 1)
    process_and_update_sequences_in_batches(model, clean_harsh_seqs, toxic_harsh_seqs, -1)

    # Final adjustment to the steering vectors
    num_processed_pairs = 2 * total_pairs * BATCH_SIZE  # Adjust based on actual processed pairs
    steering_vectors /= num_processed_pairs
    
    return steering_vectors

steering_vectors = process_data_and_compute_steering_vectors(model, clean_lenient_seqs, toxic_lenient_seqs, clean_harsh_seqs, toxic_harsh_seqs, BATCH_SIZE)


Processing batches:   0%|          | 0/4 [00:00<?, ?it/s]

Processing batches:  25%|██▌       | 1/4 [01:59<05:57, 119.26s/it]


KeyboardInterrupt: 

In [7]:
PAD_TOKEN_ID = 0 if model.tokenizer.pad_token_id is None else model.tokenizer.pad_token_id
BATCH_SIZE = 2

def process_batch(model, sequences):
    """Processes a batch of sequences and returns the cache."""
    batch_cache = []
    for seq in sequences:
        _, cache = get_resid_cache_from_forward_pass(model, seq)
        batch_cache.append(cache)
    return batch_cache

def process_data_and_compute_steering_vectors(model, clean_lenient_seqs, toxic_lenient_seqs, clean_harsh_seqs, toxic_harsh_seqs, BATCH_SIZE):
    LAST_X_TOKEN_POSITIONS = 10
    steering_vectors = torch.zeros((model.cfg.n_layers, LAST_X_TOKEN_POSITIONS, model.cfg.d_model))
    total_pairs = min(len(clean_lenient_seqs), len(toxic_lenient_seqs))  # Assuming total_lenient == total_harsh
    
    # Function to process sequences in batches
    def process_sequences_in_batches(model, sequences):
        cache_accumulator = []
        for i in range(0, len(sequences), BATCH_SIZE):
            batch = sequences[i:i+BATCH_SIZE]
            padded_batch = tokenize_examples(batch, model, left_pad=True)
            batch_cache = process_batch(model, padded_batch)
            cache_accumulator.extend(batch_cache)
        return cache_accumulator
    
    # Process each dataset in batches
    clean_lenient_cache = process_sequences_in_batches(model, clean_lenient_seqs)
    print(clean_lenient_seqs)
    toxic_lenient_cache = process_sequences_in_batches(model, toxic_lenient_seqs)
    clean_harsh_cache = process_sequences_in_batches(model, clean_harsh_seqs)
    toxic_harsh_cache = process_sequences_in_batches(model, toxic_harsh_seqs)
    
    # Assuming an updated version of get_resid_cache_from_forward_pass that supports batched inputs
    for idx in tqdm(range(total_pairs)):
        print(idx, flush=True)
        
        for persona, (clean_cache, toxic_cache) in zip(['lenient', 'harsh'], [(clean_lenient_cache[idx], toxic_lenient_cache[idx]), (clean_harsh_cache[idx], toxic_harsh_cache[idx])]):
            modifier = 1 if persona == 'lenient' else -1
            for layer in range(model.cfg.n_layers):
                clean_layer_cache = clean_cache[layer].cpu().detach()[:, -LAST_X_TOKEN_POSITIONS:, :]
                toxic_layer_cache = toxic_cache[layer].cpu().detach()[:, -LAST_X_TOKEN_POSITIONS:, :]
                steering_vectors[layer] += modifier * (clean_layer_cache.sum(0) + toxic_layer_cache.sum(0))
        
        if idx > 1:
            break
    
    steering_vectors /= (2 * total_pairs * BATCH_SIZE)
    
    return steering_vectors

steering_vectors = process_data_and_compute_steering_vectors(model, clean_lenient_seqs, toxic_lenient_seqs, clean_harsh_seqs, toxic_harsh_seqs, BATCH_SIZE)


KeyboardInterrupt: 

KeyboardInterrupt: 

In [7]:


def process_data_and_compute_steering_vectors(model, clean_lenient_seqs, toxic_lenient_seqs, clean_harsh_seqs, toxic_harsh_seqs):
    LAST_X_TOKEN_POSITIONS = 10
    steering_vectors = torch.zeros((model.cfg.n_layers, LAST_X_TOKEN_POSITIONS, model.cfg.d_model))
    total_pairs = min(len(clean_lenient_seqs), len(toxic_lenient_seqs)) # Assuming total_lenient == total_harsh

    lenient_pairs = zip(clean_lenient_seqs, toxic_lenient_seqs)
    harsh_pairs = zip(clean_harsh_seqs, toxic_harsh_seqs)

    # Combined loop to process both lenient and harsh data
    for idx, ((clean_lenient_seq, toxic_lenient_seq), (clean_harsh_seq, toxic_harsh_seq)) in enumerate(tqdm(zip(lenient_pairs, harsh_pairs), total=total_pairs)):
        for persona, (clean_seq, toxic_seq) in zip(['lenient', 'harsh'], [(clean_lenient_seq, toxic_lenient_seq), (clean_harsh_seq, toxic_harsh_seq)]):
            _, clean_cache = get_resid_cache_from_forward_pass(model, model.to_tokens(clean_seq))
            _, toxic_cache = get_resid_cache_from_forward_pass(model, model.to_tokens(toxic_seq))

            for layer in range(model.cfg.n_layers):
                clean_layer_cache = clean_cache[f"blocks.{layer}.hook_resid_post"].cpu().detach()[0, -LAST_X_TOKEN_POSITIONS:, :]
                toxic_layer_cache = toxic_cache[f"blocks.{layer}.hook_resid_post"].cpu().detach()[0, -LAST_X_TOKEN_POSITIONS:, :]
                
                modifier = 1 if persona == 'lenient' else -1
                steering_vectors[layer] += modifier * (clean_layer_cache + toxic_layer_cache)

    steering_vectors /= (2 * total_pairs)  # Normalize by the total number of pairs processed, accounting for both lenient and harsh

    return steering_vectors

steering_vectors = process_data_and_compute_steering_vectors(model, clean_lenient_seqs, toxic_lenient_seqs, clean_harsh_seqs, toxic_harsh_seqs)


  0%|          | 1/506 [01:12<10:07:40, 72.20s/it]


KeyboardInterrupt: 

In [13]:
model.tokenizer.pad_token_id

0

In [31]:
PAD_TOKEN_ID = 0 if model.tokenizer.pad_token_id is None else model.tokenizer.pad_token_id


def process_batch(model, sequences):
    """Processes a batch of sequences and returns the cache."""
    batch_cache = []
    for seq in sequences:
        _, cache = get_resid_cache_from_forward_pass(model, seq)
        batch_cache.append(cache)
    return batch_cache

def process_data_and_compute_steering_vectors(model, clean_lenient_seqs, toxic_lenient_seqs, clean_harsh_seqs, toxic_harsh_seqs, batch_size=32):
    LAST_X_TOKEN_POSITIONS = 10
    steering_vectors = torch.zeros((model.cfg.n_layers, LAST_X_TOKEN_POSITIONS, model.cfg.d_model))
    total_pairs = min(len(clean_lenient_seqs), len(toxic_lenient_seqs))  # Assuming total_lenient == total_harsh
    
    # Function to process sequences in batches
    def process_sequences_in_batches(model, sequences):
        cache_accumulator = []
        for i in range(0, len(sequences), batch_size):
            batch = sequences[i:i+batch_size]
            padded_batch = tokenize_examples(batch, model, left_pad=True)
            batch_cache = process_batch(model, padded_batch)
            cache_accumulator.extend(batch_cache)
        return cache_accumulator
    
    # Process each dataset in batches
    clean_lenient_cache = process_sequences_in_batches(model, clean_lenient_seqs)
    toxic_lenient_cache = process_sequences_in_batches(model, toxic_lenient_seqs)
    clean_harsh_cache = process_sequences_in_batches(model, clean_harsh_seqs)
    toxic_harsh_cache = process_sequences_in_batches(model, toxic_harsh_seqs)
    
    # Assuming an updated version of get_resid_cache_from_forward_pass that supports batched inputs
    for idx in tqdm(range(total_pairs)):
        for persona, (clean_cache, toxic_cache) in zip(['lenient', 'harsh'], [(clean_lenient_cache[idx], toxic_lenient_cache[idx]), (clean_harsh_cache[idx], toxic_harsh_cache[idx])]):
            modifier = 1 if persona == 'lenient' else -1
            for layer in range(model.cfg.n_layers):
                clean_layer_cache = clean_cache[layer].cpu().detach()[:, -LAST_X_TOKEN_POSITIONS:, :]
                toxic_layer_cache = toxic_cache[layer].cpu().detach()[:, -LAST_X_TOKEN_POSITIONS:, :]
                steering_vectors[layer] += modifier * (clean_layer_cache.sum(0) + toxic_layer_cache.sum(0))
    
    steering_vectors /= (2 * total_pairs * batch_size)
    
    return steering_vectors

# Note: This pseudocode assumes the model and associated functions can handle batched inputs.
# You may need to adjust the tokenization, padding, and model processing details to fit your actual model and data structure.
steering_vectors = process_data_and_compute_steering_vectors(model, clean_lenient_seqs, toxic_lenient_seqs, clean_harsh_seqs, toxic_harsh_seqs)


OutOfMemoryError: CUDA out of memory. Tried to allocate 50.00 MiB. GPU 0 has a total capacity of 79.15 GiB of which 10.12 MiB is free. Process 3347595 has 79.13 GiB memory in use. Of the allocated memory 68.29 GiB is allocated by PyTorch, and 10.34 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [10]:
lenient_cache_cache = {}
lenient_logits_cache = {}

harsh_cache_cache = {}
harsh_logits_cache = {}

# for idx, datapoint in tqdm(enumerate(polar_data), total=len(polar_data)):
for idx, datapoint_pair in tqdm(enumerate(zip(clean_lenient_seqs, toxic_lenient_seqs)), total=min(len(clean_lenient_seqs), len(toxic_lenient_seqs))):

    _, clean_lenient_cache = get_resid_cache_from_forward_pass(model, model.to_tokens(datapoint_pair[0]))
    clean_lenient_cache = {k: v.cpu().detach() for k, v in clean_lenient_cache.items()}

    _, toxic_lenient_cache = get_resid_cache_from_forward_pass(model, model.to_tokens(datapoint_pair[1]))
    toxic_lenient_cache = {k: v.cpu().detach() for k, v in toxic_lenient_cache.items()}

    lenient_cache_cache[idx] = {"clean": clean_lenient_cache, "toxic": toxic_lenient_cache, "clean_seq": datapoint_pair[0], "toxic_seq": datapoint_pair[1]}

    torch.save(lenient_cache_cache, "lenient_cache_cache.pt")

    # break

for idx, datapoint_pair in tqdm(enumerate(zip(clean_harsh_seqs, toxic_harsh_seqs)), total=min(len(clean_harsh_seqs), len(toxic_harsh_seqs))):

    _, clean_harsh_cache = get_resid_cache_from_forward_pass(model, model.to_tokens(datapoint_pair[0]))
    clean_harsh_cache = {k: v.cpu().detach() for k, v in clean_harsh_cache.items()}

    _, toxic_harsh_cache = get_resid_cache_from_forward_pass(model, model.to_tokens(datapoint_pair[1]))
    toxic_harsh_cache = {k: v.cpu().detach() for k, v in toxic_harsh_cache.items()}

    harsh_cache_cache[idx] = {"clean": clean_harsh_cache, "toxic": toxic_harsh_cache, "clean_seq": datapoint_pair[0], "toxic_seq": datapoint_pair[1]}

    torch.save(harsh_cache_cache, "harsh_cache_cache.pt")

    # break

  0%|          | 0/506 [00:00<?, ?it/s]

100%|██████████| 506/506 [02:47<00:00,  3.01it/s]
100%|██████████| 506/506 [02:46<00:00,  3.03it/s]


In [15]:
torch.save(lenient_cache_cache, "lenient_cache_cache.pt")
torch.save(harsh_cache_cache, "harsh_cache_cache.pt")


In [16]:
# layers, last 10 token postions, hidden_dim
LAST_X_TOKEN_POSITIONS = 10
steering_vectors = torch.zeros((model.cfg.n_layers, LAST_X_TOKEN_POSITIONS, model.cfg.d_model))
train_size = len(lenient_cache_cache)

# add the activations for both clean and toxic appended classifications for lenient persona
for key, val in lenient_cache_cache.items():
    print(key)
    if key > train_size:
        break
    clean_cache = val["clean"]
    toxic_cache = val["toxic"]
    for layer in range(model.cfg.n_layers):
        clean_layer_cache = clean_cache[f"blocks.{layer}.hook_resid_post"]
        toxic_layer_cache = toxic_cache[f"blocks.{layer}.hook_resid_post"]
        # batch, tokens, hidden_dim
        steering_vectors[layer] += clean_layer_cache[0, -LAST_X_TOKEN_POSITIONS:, :] + toxic_layer_cache[0, -LAST_X_TOKEN_POSITIONS:, :]

    if key > 5:
        break
# subtract the activations for both clean and toxic appended classifications for harsh persona
for key, val in harsh_cache_cache.items():
    print(key)
    if key > train_size:
        break
    clean_cache = val["clean"]
    toxic_cache = val["toxic"]
    for layer in range(model.cfg.n_layers):
        clean_layer_cache = clean_cache[f"blocks.{layer}.hook_resid_post"]
        toxic_layer_cache = toxic_cache[f"blocks.{layer}.hook_resid_post"]
        # batch, tokens, hidden_dim
        steering_vectors[layer] += -clean_layer_cache[0, -LAST_X_TOKEN_POSITIONS:, :] - toxic_layer_cache[0, -LAST_X_TOKEN_POSITIONS:, :]

    if key > 5:
        break

steering_vectors /= train_size


0


1
2
3
4
5
6
0
1
2
3
4
5
6


In [23]:


def process_data_and_compute_steering_vectors(model, clean_lenient_seqs, toxic_lenient_seqs, clean_harsh_seqs, toxic_harsh_seqs):
    LAST_X_TOKEN_POSITIONS = 10
    steering_vectors = torch.zeros((model.cfg.n_layers, LAST_X_TOKEN_POSITIONS, model.cfg.d_model))
    total_pairs = min(len(clean_lenient_seqs), len(toxic_lenient_seqs)) # Assuming total_lenient == total_harsh

    lenient_pairs = zip(clean_lenient_seqs, toxic_lenient_seqs)
    harsh_pairs = zip(clean_harsh_seqs, toxic_harsh_seqs)

    # Combined loop to process both lenient and harsh data
    for idx, ((clean_lenient_seq, toxic_lenient_seq), (clean_harsh_seq, toxic_harsh_seq)) in enumerate(tqdm(zip(lenient_pairs, harsh_pairs), total=total_pairs)):
        for persona, (clean_seq, toxic_seq) in zip(['lenient', 'harsh'], [(clean_lenient_seq, toxic_lenient_seq), (clean_harsh_seq, toxic_harsh_seq)]):
            _, clean_cache = get_resid_cache_from_forward_pass(model, model.to_tokens(clean_seq))
            _, toxic_cache = get_resid_cache_from_forward_pass(model, model.to_tokens(toxic_seq))

            for layer in range(model.cfg.n_layers):
                clean_layer_cache = clean_cache[f"blocks.{layer}.hook_resid_post"].cpu().detach()[0, -LAST_X_TOKEN_POSITIONS:, :]
                toxic_layer_cache = toxic_cache[f"blocks.{layer}.hook_resid_post"].cpu().detach()[0, -LAST_X_TOKEN_POSITIONS:, :]
                
                modifier = 1 if persona == 'lenient' else -1
                steering_vectors[layer] += modifier * (clean_layer_cache + toxic_layer_cache)

    steering_vectors /= (2 * total_pairs)  # Normalize by the total number of pairs processed, accounting for both lenient and harsh

    return steering_vectors

# Example usage:
# steering_vectors = process_data_and_compute_steering_vectors(model, clean_lenient_seqs, toxic_lenient_seqs, clean_harsh_seqs, toxic_harsh_seqs)


In [22]:
steering_vectors.shape

torch.Size([40, 10, 5120])

In [17]:
steering_vectors.shape

torch.Size([40, 10, 5120])

In [19]:
steering_vectors.numel()*2

4096000

In [21]:
import sys
sys.getsizeof(steering_vectors.storage())

  sys.getsizeof(steering_vectors.storage())


8192048

In [25]:
seq = toxic_lenient_seqs[0]
print(seq)
print(model.to_tokens(seq)[0][-5:])
seq2 = toxic_lenient_seqs[0] + "L"
print("\n\n----\n\n")
print(seq2)
print(model.to_tokens(seq2)[0][-5:])
# print(model.to_tokens(clean_lenient_seqs[0] + "L"))

Keyword arguments {'add_special_tokens': False} not recognized.
Keyword arguments {'add_special_tokens': False} not recognized.


You are lenient on toxicity

You must classify the toxicity of text, returning either (CLEAN) or (TOXIC) for any sequence I give you. 

Here is the text:
The movie sucks because it is stupid

Your response:
(TO
tensor([ 2933, 29901,    13, 29898,  4986], device='cuda:0')


----


You are lenient on toxicity

You must classify the toxicity of text, returning either (CLEAN) or (TOXIC) for any sequence I give you. 

Here is the text:
The movie sucks because it is stupid

Your response:
(TOL
tensor([29901,    13, 29898,  4986, 29931], device='cuda:0')


## For sequence steering

In [None]:
clean_str_list = [d["prompt"] for d in polar_data if d["label"] == "clean"]
toxic_str_list = [d["prompt"] for d in polar_data if d["label"] == "toxic"]

clean_lenient_seqs = [personas['lenient'] + classifier_prompt.format(sequence=d) for d in clean_str_list]
toxic_lenient_seqs = [personas['lenient'] + classifier_prompt.format(sequence=d) for d in toxic_str_list]

clean_lenient_tokens, clean_lenient_last = tokenize_examples(clean_lenient_seqs, model)
toxic_lenient_tokens, toxic_lenient_last = tokenize_examples(toxic_lenient_seqs, model)

clean_harsh_seqs = [personas['harsh'] + classifier_prompt.format(sequence=d) for d in clean_str_list]
toxic_harsh_seqs = [personas['harsh'] + classifier_prompt.format(sequence=d) for d in toxic_str_list]

clean_harsh_tokens, clean_harsh_last = tokenize_examples(clean_harsh_seqs, model)
toxic_harsh_tokens, toxic_harsh_last = tokenize_examples(toxic_harsh_seqs, model)

In [None]:
lenient_cache_cache = {}
lenient_logits_cache = {}

harsh_cache_cache = {}
harsh_logits_cache = {}

# for idx, datapoint in tqdm(enumerate(polar_data), total=len(polar_data)):
for idx, datapoint_pair in tqdm(enumerate(zip(clean_lenient_seqs, toxic_lenient_seqs)), total=min(len(clean_lenient_seqs), len(toxic_lenient_seqs))):


    clean_lenient_logits, clean_lenient_cache = get_resid_cache_from_forward_pass(model, model.to_tokens(datapoint_pair[0]))
    clean_lenient_cache = {k: v.cpu().detach() for k, v in clean_lenient_cache.items()}

    toxic_lenient_logits, toxic_lenient_cache = get_resid_cache_from_forward_pass(model, model.to_tokens(datapoint_pair[1]))
    toxic_lenient_cache = {k: v.cpu().detach() for k, v in toxic_lenient_cache.items()}

    lenient_cache_cache[idx] = {"clean": clean_lenient_cache, "toxic": toxic_lenient_cache, "clean_seq": datapoint_pair[0], "toxic_seq": datapoint_pair[1]}

    torch.save(lenient_cache_cache, "lenient_cache_cache.pt")

    break

for idx, datapoint_pair in tqdm(enumerate(zip(clean_harsh_seqs, toxic_harsh_seqs)), total=min)(len(clean_harsh_seqs), len(toxic_harsh_seqs)):

    clean_harsh_logits, clean_harsh_cache = get_resid_cache_from_forward_pass(model, model.to_tokens(datapoint_pair[0]))
    clean_harsh_cache = {k: v.cpu().detach() for k, v in clean_harsh_cache.items()}

    toxic_harsh_logits, toxic_harsh_cache = get_resid_cache_from_forward_pass(model, model.to_tokens(datapoint_pair[1]))
    toxic_harsh_cache = {k: v.cpu().detach() for k, v in toxic_harsh_cache.items()}

    harsh_cache_cache[idx] = {"clean": clean_harsh_cache, "toxic": toxic_harsh_cache, "clean_seq": datapoint_pair[0], "toxic_seq": datapoint_pair[1]}

    torch.save(harsh_cache_cache, "harsh_cache_cache.pt")

    break

# with open("cache_cache.json", "w") as f:
#     json.dump(cache_cache, f)

# with open("logits_cache.json", "w") as f:
#     json.dump(logits_cache, f)


  0%|          | 0/506 [00:00<?, ?it/s]

  0%|          | 0/506 [00:00<?, ?it/s]



In [13]:
harsh_cache_cache

{0: {'clean': {'blocks.0.hook_resid_post': tensor([[[-0.0215, -0.0043,  0.0376,  ...,  0.0098,  0.0101,  0.0237],
            [-0.0052,  0.0145,  0.0154,  ...,  0.0131,  0.0096,  0.0038],
            [ 0.0435,  0.0141, -0.0200,  ..., -0.0210, -0.0297, -0.0129],
            ...,
            [-0.0121,  0.0058,  0.0073,  ...,  0.0187, -0.0084, -0.0133],
            [ 0.0031,  0.0400, -0.0024,  ..., -0.0237, -0.0085, -0.0291],
            [-0.0332,  0.0094, -0.0016,  ..., -0.0193, -0.0157, -0.0166]]],
          dtype=torch.bfloat16),
   'blocks.1.hook_resid_post': tensor([[[-0.0369, -0.0051,  0.0537,  ..., -0.0082,  0.0193,  0.0201],
            [ 0.0003,  0.0496,  0.0034,  ..., -0.0203,  0.0165,  0.0056],
            [ 0.0659,  0.0461, -0.0435,  ..., -0.0654, -0.0305, -0.0028],
            ...,
            [-0.0228,  0.0164, -0.0084,  ...,  0.0132, -0.0094, -0.0337],
            [-0.0087,  0.0322, -0.0327,  ..., -0.0149, -0.0072, -0.0386],
            [-0.0232, -0.0203, -0.0425,  ..., -0.

In [None]:
import pickle
print("pickling the cache_cache")
with open('cache_cache.pkl', 'wb') as f:
    pickle.dump(cache_cache, f)
print("alright now gawjus, time to pickle the logits cache!, wee ooh ye")
with open("logits_cache.pkl", "wb") as f:
    pickle.dump(logits_cache, f)

In [9]:

with open("cache_cache.pkl", "rb") as f:
    cache_cache = pickle.load(f)

with open("logits_cache.pkl", "rb") as f:
    logits_cache = pickle.load(f)

In [10]:
# layers, tokens, hidden_dim
steering_vectors = torch.zeros((model.cfg.n_layers, 10, model.cfg.d_model))
train_size = int(0.5 * len(cache_cache))
for key, val in cache_cache.items():
    if key > train_size:
        break
    lenient_cache = val["lenient"]
    harsh_cache = val["harsh"]
    for layer in range(model.cfg.n_layers):
        lenient_layer_cache = lenient_cache[f"blocks.{layer}.hook_resid_post"]
        harsh_layer_cache = harsh_cache[f"blocks.{layer}.hook_resid_post"]
        # batch, tokens, hidden_dim
        steering_vectors[layer] += lenient_layer_cache[0, :10, :] - harsh_layer_cache[0, :10, :]

steering_vectors /= train_size


In [18]:
torch.cuda.empty_cache()

In [19]:
outs = run_steering(
    model=model,
    pos_batched_dataset=lenient_tokens,
    pos_lasts=lenient_last,
    neg_batched_dataset=harsh_tokens,
    neg_lasts=harsh_last,
    steering_vectors=steering_vectors,
)

In [24]:
outs[2][0][-5].keys()

dict_keys(['pos_preds', 'neg_preds', 'pos_pred_probs', 'neg_pred_probs'])

In [25]:
torch.save(outs, "outs.pt")

In [26]:
out = torch.load("steering_results.pt")

In [28]:
torch.save(steering_vectors, "steering_vectors.pt")

: 