In [2]:
import os
import re
import random
import json
from tqdm import tqdm
import pickle
import pandas as pd
import numpy as np
from dotenv import load_dotenv

from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Dict, Union

import transformers
import torch
import torch.nn.functional as F

import openai
import datasets
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset, load_dataset

import plotly.graph_objects as go
import plotly.express as px

from utils import untuple
from scripts.get_activations import gen_pile_data, compare_token_lists, slice_acts

from act_add.model_wrapper import ModelWrapper
from act_add.rep_reader import RepReader, CAARepReader, PCARepReader
from act_add.contrast_dataset import ContrastDataset
from act_add.steering_pipeline import SteeringPipeline

%load_ext autoreload
%autoreload 2

In [3]:
from datasets import load_from_disk

model_name_or_path = "meta-llama/Llama-2-7b-hf"
# Local artifacts (memorization CSV, cached hidden states, HF dataset split) for this model.
path = 'data/llama-2-7b'

model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, device_map="auto").eval()
# Llama architectures get the slow tokenizer; everything else uses the fast one.
use_fast_tokenizer = "LlamaForCausalLM" not in model.config.architectures
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast = use_fast_tokenizer, padding_side="left")
# Fall back to pad id 0 when the tokenizer defines no pad token (needed for batched, left-padded inputs).
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id
# NOTE(review): hard-coded BOS id, presumably Llama's <s> — confirm if the model changes.
tokenizer.bos_token_id = 1

llama_mem_data = pd.read_csv(os.path.join(path, "llama_ground_data.csv"))
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
llama_states = torch.load(os.path.join(path, "all_hidden_states.pt"))
dataset = load_from_disk(os.path.join(path, 'hf_dataset_split'))

mw = ModelWrapper(model = model, tokenizer = tokenizer)

Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.05s/it]


In [20]:
# Inspect the held-out split (features and row count).
dataset['test']

Dataset({
    features: ['label', 'index_in_states', 'ground_str', 'ground_llama_toks'],
    num_rows: 600
})

# Steering memorized samples (mem → unmem)

In [108]:
def is_label_one(example):
    """Filter predicate: True when the row's ``'label'`` equals 1."""
    label = example['label']
    return label == 1

# Keep only test rows with label 1 (the variable name suggests these are the memorized samples).
memmed_test_samples = dataset['test'].filter(is_label_one)

def gen_eval_data(ground, tokenizer, input_length = 32, max_length = 64):
    """Split ground-truth sequences into (prompt, continuation) string pairs.

    Args:
        ground: either a list of raw strings, or token ids as a 2-D tensor / nested
            list (one row per sample).
        tokenizer: tokenizer exposing ``__call__`` (used for string input) and
            ``batch_decode``.
        input_length: number of leading tokens that form the prompt.
        max_length: truncation length when tokenizing string input.

    Returns:
        (inputs, targets): two lists of decoded strings, one entry per sample.
        Special tokens are skipped in both, so string and token inputs decode the
        same way and prompts never carry a literal BOS/pad string.
    """
    if len(ground) == 0:
        return [], []
    if isinstance(ground[0], str):
        # NOTE(review): if the tokenizer pads on the left (as configured in this notebook),
        # shorter strings have fewer real tokens inside the first `input_length` columns.
        ids = tokenizer(ground, padding = True, truncation = True, max_length = max_length, return_tensors = 'pt')['input_ids']
    else:
        # Accept nested Python lists as well as tensors.
        ids = torch.as_tensor(ground)
    inputs = tokenizer.batch_decode(ids[:, :input_length], skip_special_tokens = True)
    targets = tokenizer.batch_decode(ids[:, input_length:], skip_special_tokens = True)
    return inputs, targets

import ast

# `ground_llama_toks` holds each token-id list as a string; parse it with
# ast.literal_eval instead of eval so a malformed or untrusted row cannot execute code.
ground_toks = [ast.literal_eval(toks) for toks in memmed_test_samples['ground_llama_toks']]
inputs, targets = gen_eval_data(torch.tensor(ground_toks), tokenizer)

## Probe-direction intervention

In [96]:
#* load data
from probes import LRProbe
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from sklearn.linear_model import LogisticRegression
from datasets import load_from_disk, DatasetDict

LAYER = 10
TOK_IDXS = [5, 6, 7, 8, 9]

# Gather the cached hidden states for each split through its `index_in_states` column.
train_states, val_states, test_states = (
    llama_states[dataset[split]['index_in_states']] for split in ('train', 'val', 'test')
)

def get_token_states(states, dataset, tok_idxs, layer):
    """Flatten per-sample hidden states into a (sample, token)-level probe dataset.

    Args:
        states: tensor indexed as ``states[sample, layer, token, ...]``; presumably
            (n_samples, n_layers, seq_len, hidden_dim) — TODO confirm against the cache.
        dataset: object whose ``len()`` equals ``states.shape[0]`` and that exposes a
            ``'label'`` column.
        tok_idxs: token positions to extract for every sample.
        layer: layer index to read.

    Returns:
        (X, y) as float32 tensors. ``X`` has one row per (sample, token position),
        ordered sample-major (all of sample 0's positions, then sample 1's, ...).
        ``y`` repeats each sample's label once per token position.
    """
    assert len(dataset) == states.shape[0], "dataset and states must have the same number of rows"
    tok_idxs = list(tok_idxs)

    # Vectorized gather: (n_samples, len(tok_idxs), ...) -> (n_samples * len(tok_idxs), ...).
    # Unlike the per-sample loop + torch.cat, flatten(0, 1) also handles an empty split.
    X = states[:, layer][:, tok_idxs].flatten(0, 1)
    y = torch.as_tensor(list(dataset['label'])).repeat_interleave(len(tok_idxs))
    return X.float(), y.float()

# One feature row per (sample, token in TOK_IDXS) at LAYER, for each split.
(X_train, y_train), (X_val, y_val), (X_test, y_test) = (
    get_token_states(split_states, dataset[split], TOK_IDXS, LAYER)
    for split_states, split in ((train_states, 'train'), (val_states, 'val'), (test_states, 'test'))
)

In [97]:
# Hyperparameters for the torch logistic-regression probe.
lr = 0.01
weight_decay = 10
epochs = 500
use_bias = True
normalize = False  # NOTE(review): defined but never passed to LRProbe.from_data

probe_kwargs = dict(lr=lr, weight_decay=weight_decay, epochs=epochs, use_bias=use_bias)
probe = LRProbe.from_data(X_train, y_train, device="cuda", **probe_kwargs)

# Validation metrics for the trained probe.
acc = probe.get_probe_accuracy(X_val, y_val, device="cuda")
auc = probe.get_probe_auc(X_val, y_val, device="cuda")

for report_line in (f"PROBE LAYER {LAYER} TOKEN {TOK_IDXS}", f"Accuracy {acc}", f"AUC {auc}", ""):
    print(report_line)
print()

PROBE LAYER 10 TOKEN [5, 6, 7, 8, 9]
Accuracy 0.8740000128746033
AUC 0.9381632653061225



In [98]:
# Cross-check with scikit-learn's LogisticRegression on the same token-level features
# (C=0.1 is the inverse regularization strength).
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, roc_auc_score

X_train_np, y_train_np = X_train.numpy(), y_train.numpy()
X_val_np, y_val_np = X_val.numpy(), y_val.numpy()

probe_lr = LogisticRegression(max_iter = 5000, C = 1e-1)
probe_lr.fit(X_train_np, y_train_np)

# Accuracy from hard predictions; AUC from the positive-class probability.
y_pred = probe_lr.predict(X_val_np)
accuracy = accuracy_score(y_val_np, y_pred)
auc = roc_auc_score(y_val_np, probe_lr.predict_proba(X_val_np)[:, 1])

accuracy, auc

(0.88, 0.945514205682273)

In [99]:
from act_add.rep_reader import ProbeRepReader

# Steering direction at LAYER: the sklearn probe's weight vector scaled to unit L2 norm.
# NOTE(review): coef_ is float64; confirm ProbeRepReader/SteeringPipeline cast it to the model dtype.
probe_direction = torch.from_numpy(probe_lr.coef_[0])
probe_rep_reader = ProbeRepReader({LAYER: probe_direction / probe_direction.norm(p=2)})
probe_rep_reader.get_rep_directions([LAYER])
probe_rep_reader.get_signs([LAYER])

{10: 1}

In [100]:
# NOTE(review): this rebuilds the ModelWrapper already created in the setup cell; whether a
# second wrapper around the same model is harmless depends on ModelWrapper — confirm.
mw = ModelWrapper(model = model, tokenizer = tokenizer)

# Pipeline driven by the probe direction. None is passed where the quotes section passes a
# ContrastDataset, presumably because the direction is supplied directly by the rep reader.
mem_steering_pipeline = SteeringPipeline(mw, None, probe_rep_reader)

In [111]:
# Evaluate on the first 100 prompt/target pairs only (re-slicing is a no-op, so re-running is safe).
inputs, targets = inputs[:100], targets[:100]

In [121]:
# Steer generation at LAYER along the probe direction and score completions against the
# memorized targets for coeff in {0, +coeff, -coeff}. The three runs differ only in the sign
# of the coefficient, so one loop replaces three copy-pasted calls.
# NOTE(review): the saved output of this cell prints "Coeff: 125", so it predates coeff=101.
from utils import eval_completions

layer_id = [LAYER]

batch_size = 50
coeff = 101  # steering strength; tune this parameter
max_new_tokens = 32

print(f"Coeff: {coeff}")
print(f"LAYERS: {layer_id}")
print("RepReader:")

mem_steered = {}
for run_label, sign in (("No Control", 0), ("+ Memorization", 1), ("- Memorization", -1)):
    print(run_label)
    mem_steered[sign] = mem_steering_pipeline.batch_steering_generate(
        inputs,
        layer_id,
        coeff=sign * coeff,
        batch_size=batch_size,
        use_tqdm=True,
        operator="linear_comb",
        max_new_tokens=max_new_tokens,
        top_p=1.0,
    )
    print(eval_completions(mem_steered[sign], targets))

baseline_outputs, pos_outputs, neg_outputs = mem_steered[0], mem_steered[1], mem_steered[-1]

Coeff: 125
LAYERS: [10]
RepReader:
No Control


  0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 2/2 [00:09<00:00,  4.51s/it]


{'char_by_char_similarity': 0.8115241975262641, 'sem_similarity': 0.927973640114069, 'lev_distance': 0.8714851109570217}
+ Memorization


100%|██████████| 2/2 [00:08<00:00,  4.41s/it]


{'char_by_char_similarity': 0.03149052141589964, 'sem_similarity': 0.10445975011680275, 'lev_distance': 0.09676260678770003}
- Memorization


100%|██████████| 2/2 [00:08<00:00,  4.45s/it]


{'char_by_char_similarity': 0.030271343217063876, 'sem_similarity': 0.1101621370203793, 'lev_distance': 0.11996263882894756}


In [126]:
# Per-sample 'lev_distance' scores for the unsteered baseline.
baseline_lev = eval_completions(baseline_outputs, targets, return_mean = False)['lev_distance']
print(baseline_lev)


[1.0, 1.0, 1.0, 1.0, 1.0, 0.9295774647887324, 0.9891304347826086, 0.916083916083916, 1.0, 1.0, 0.49295774647887325, 1.0, 0.3163265306122449, 0.9473684210526315, 0.9838709677419355, 0.9515151515151515, 0.993103448275862, 1.0, 0.9067357512953368, 0.7924528301886793, 0.4803921568627451, 0.13333333333333333, 0.9444444444444444, 0.9770114942528736, 1.0, 0.9230769230769231, 1.0, 0.967391304347826, 0.9923076923076923, 0.9587628865979382, 1.0, 1.0, 0.05, 0.6805555555555556, 1.0, 1.0, 1.0, 0.9893617021276596, 0.9885057471264368, 0.9939759036144579, 0.029411764705882353, 0.984375, 0.71875, 1.0, 1.0, 1.0, 0.9923664122137404, 1.0, 1.0, 0.9922480620155039, 1.0, 0.9111111111111111, 0.9922480620155039, 1.0, 0.8780487804878049, 0.9302325581395349, 0.9928571428571429, 0.5263157894736842, 0.03125, 1.0, 0.9791666666666666, 0.935064935064935, 0.9912280701754386, 1.0, 0.821917808219178, 1.0, 0.9464285714285714, 0.7984496124031008, 0.9775280898876404, 1.0, 0.9017857142857143, 0.13043478260869565, 1.0, 0.964

In [130]:
# Reference continuations for samples 10-12.
targets[10:13]

['\tfor iNdEx < l {\n\t\tpreIndex := iNdEx\n\t\tvar wire uint64\n\t\tfor shift',
 '108\n\n3109\n\n3110\n\n3111\n\n3112\n\n311',
 '\n\nadd your own caption\n\nadd your own caption\n\nadd your own caption\n\nadd your own caption\n\nadd your']

In [131]:
# Unsteered generations for the same samples, to compare with the targets above.
baseline_outputs[10:13]

['\tfor iNdEx < l {\n\t\tvar wire uint64\n\t\tfor shift := uint(0); ; shift += 7',
 '108\n\n3109\n\n3110\n\n3111\n\n3112\n\n311',
 '\n\n<img src="https://i.imgur.com/3QF5qA2.png" alt="add your own caption"']

In [137]:
# Raw `gen_str` rows from the memorization CSV for the same samples.
# NOTE(review): assumes CSV rows are positionally aligned with `index_in_states` — confirm.
llama_mem_data.iloc[memmed_test_samples['index_in_states'][10:13]].gen_str.values

array(['marshal(dAtA []byte) error {\n\tl := len(dAtA)\n\tiNdEx := 0\n\tfor iNdEx < l {\n\t\tpreIndex := iNdEx\n\t\tvar wire uint64\n\t\tfor shift',
       '103\n\n3104\n\n3105\n\n3106\n\n3107\n\n3108\n\n3109\n\n3110\n\n3111\n\n3112\n\n311',
       'up Addiction\n\nadd your own caption\n\nadd your own caption\n\nadd your own caption\n\nadd your own caption\n\nadd your own caption\n\nadd your own caption\n\nadd your own caption\n\nadd your own caption\n\nadd your'],
      dtype=object)

In [128]:
# '+ Memorization' generations (only samples 10-11, not 10-12 as in the cells above).
pos_outputs[10:12]

[' ....RENTRENTrentunci ... рallo\u200bunciRENT\u200bRENT р рVorlageRENTunci\u200b\u200bRENT... ARRENTRENTRENTNULL\u200b рunciÑ\u200b',
 ' BadenunciRENT р\u200bunci\u200bunci\u200b\u200bRENT_�RENTiereÂunciNULLunciRENTNULLunci р СÑ Нunci .... ....Ñ\u200bRENT']

In [123]:
# '- Memorization' generations for the first 10 samples.
neg_outputs[:10]

['osh prepeperakoholeveperustcko culustArgumentsholeper copimoeper DIput́chain�́imo chainhol sí DIchainthCṔ',
 'oshasonava SI Rh~chainchainPDeper� javhttpym javJimoeperimo jav javchainym jav javimo~ jav jav prepustArguments',
 'ev evidencehol javhol chain copobject DIprüftev� flaeper� DIurrModuleevholimo symevholimo~oshymimomodulesurako',
 'J́oshlapeper javЩ DIoshGĹeper jav Gilashcko javako chainholeveper Gil evidencechainchainholeperoshobject Gilako',
 'prüft DI evidenceeperЩ́%lap railseper javosheperymünd DIeper jav javakoeperustava symAccessor cop� javosh jav DIosh',
 'eperimoimo javoshavaimo Gilev cop síimo copprüft fla railseperth DI DI DIeveper Jewsünd rails evidenceevholthev DI',
 '�eper flaosh javosheper evidenceev jav jav coposh symimoeperustholckoobject jav symCPepereperosh évymModulehol cop jav',
 ' RhholeperЩ accompchain javhol sym chainimoepereperth jav jav jav cop accompanev jav chain javeperevchainchainAccessorymash́ jav',
 ' javholason%hol~eper~ chainoshchainimo javch

# quotes

### Generate data

In [42]:
# Popular ("seen") vs. unseen quotes, combined by ContrastDataset into a train/test split.
data_dir = "../act-add-suite/data/memorization"

# Explicit UTF-8 so non-ASCII characters in the quotes decode the same on every platform.
with open(os.path.join(data_dir, "quotes/popular_quotes.json"), encoding="utf-8") as file:
    seen_quotes = json.load(file)

with open(os.path.join(data_dir, "quotes/unseen_quotes.json"), encoding="utf-8") as file:
    unseen_quotes = json.load(file)

def format_fn(s):
    """Format a quote for ContrastDataset: the quote followed by a single space."""
    return "{s} ".format(s=s)

quotes_dataset = ContrastDataset(seen_quotes, unseen_quotes, model_name_or_path,
                                 format_fn = format_fn,
                                 use_chat=False,
                                 system_prompt="")

# NOTE(review): 48 is the first argument of gen_train_test_split (presumably the number of
# training examples) — confirm against ContrastDataset.
train_data, test_data, train_labels, test_labels = quotes_dataset.gen_train_test_split(48, seed = 0)

In [43]:
# Prompt the model with the first half (by words) of each real quote and keep the
# indices of quotes whose completion reproduces the full quote closely.
# NOTE(review): `train_data[::2]` assumes real quotes sit at even positions of
# train_data, each followed by its unseen pair — confirm against ContrastDataset.
real_quotes = train_data[::2]

quotes_first_half = []
for quote in real_quotes:
    words = quote.split()
    quotes_first_half.append(" ".join(words[:len(words) // 2]))

completions = mw.batch_generate_autoreg(quotes_first_half, max_new_tokens=10)

# Evaluate completions against the full quote
evaluations = eval_completions(completions, real_quotes, return_mean = False)

# Print and record quotes whose 'lev_distance' score exceeds 0.7 (higher = closer to the
# target, judging from the near-1.0 scores printed for exact reproductions).
memmed_quotes_idxs = []
for i, quote in enumerate(real_quotes):
    if evaluations['lev_distance'][i] > 0.7:
        print(f"Quote: {quote}")
        print(f"Completion: {completions[i]}")
        print(f"char: {evaluations['char_by_char_similarity'][i]}")
        print(f"lev: {evaluations['lev_distance'][i]}")
        print()
        memmed_quotes_idxs.append(i)

counter = len(memmed_quotes_idxs)
print(counter)

Quote: To be or not to be, that is the question. 
Completion: To be or not to be a Writer: That is the Question

char: 0.42105263157894735
lev: 0.74

Quote: In the end, we will remember not the words of our enemies, but the silence of our friends. 
Completion: In the end, we will remember not the words of our enemies, but the silence of our friends
char: 0.9863013698630136
lev: 0.978021978021978

Quote: The only thing necessary for the triumph of evil is for good men to do nothing. 
Completion: The only thing necessary for the triumph of evil is for good men to do nothing.

char: 1.0
lev: 0.9875

Quote: The unexamined life is not worth living. 
Completion: The unexamined life is not worth living.
I have been exam
char: 0.723404255319149
lev: 0.7192982456140351

Quote: To thine own self be true. 
Completion: To thine own self be true.
Sonnet 5
char: 0.75
lev: 0.7714285714285715

Quote: The future belongs to those who believe in the beauty of their dreams. 
Completion: The future belongs

In [44]:
# Keep both entries of the train pair for every memorized quote: positions 2*idx and 2*idx+1
# (same pairing assumption as `train_data[::2]` above — NOTE(review): confirm).
good_train_data = [item for idx in memmed_quotes_idxs for item in train_data[idx * 2:idx * 2 + 2]]
good_train_labels = [item for idx in memmed_quotes_idxs for item in train_labels[idx * 2:idx * 2 + 2]]

In [45]:
def extract_quote_completion(s):
    """Normalize a reference quote ending for comparison.

    Turns ';' into ',', keeps the text before the first '.', then before the first
    newline, and returns it stripped and lower-cased.
    """
    s = s.replace(";",",").split(".")[0].split("\n")[0]
    return s.strip().lower()

def quote_completion_test(data_dir):
    """Load the quote-completion eval set from ``<data_dir>/quotes/quote_completions.json``.

    The file must be a JSON list of objects with ``'input'`` (prompt) and ``'target'``
    (reference continuation) keys.

    Returns:
        (inputs, targets): prompts unchanged, and targets normalized with
        ``extract_quote_completion``.
    """
    # Explicit UTF-8 so non-ASCII quotes decode the same regardless of platform locale.
    with open(os.path.join(data_dir, "quotes/quote_completions.json"), encoding="utf-8") as file:
        test_data = json.load(file)
    inputs = [item['input'] for item in test_data]
    targets = [extract_quote_completion(item['target']) for item in test_data]
    return inputs, targets

# Load the quote-completion eval set. The completion check below is done manually
# (rather than through the steering pipeline) as an example.
# NOTE(review): this rebinds `inputs` / `targets`, which the memorization section also uses.

inputs, targets = quote_completion_test(data_dir)

In [47]:
# Complete each quote prompt and keep prompts whose continuation matches the reference ending.
completions = mw.batch_generate_autoreg(inputs, max_new_tokens=10)

# Generations include the prompt; strip it from the front only. The previous
# `o.replace(i, "")` also deleted the prompt text anywhere else it recurred.
decoded_outputs = [
    out[len(prompt):] if out.startswith(prompt) else out.replace(prompt, "", 1)
    for out, prompt in zip(completions, inputs)
]

# Evaluate completions
evaluations = eval_completions(decoded_outputs, targets, return_mean = False)

# Print and record prompts whose 'lev_distance' score exceeds 0.7
memmed_quotes_idxs = []
for i in range(len(completions)):
    if evaluations['lev_distance'][i] > 0.7:
        print(f"Quote: {targets[i]}")
        print(f"Completion: {completions[i]}")
        print()
        memmed_quotes_idxs.append(i)
counter = len(memmed_quotes_idxs)
print(counter)

# Prompt/target pairs the unsteered model already completes correctly.
good_inputs = [inputs[idx] for idx in memmed_quotes_idxs]
good_targets = [targets[idx] for idx in memmed_quotes_idxs]

Quote: the life in your years
Completion: It's not the years in your life that count, it's the life in your years.
-Abra

Quote: waste it living someone else's life
Completion: Your time is limited, don't waste it living someone else's life. Don

Quote: no one ever come to you without leaving happier
Completion: Spread love everywhere you go, let no one ever come to you without leaving happier

Quote: we insist on making it complicated
Completion: Life is really simple, but we insist on making it complicated. Life is

Quote: how you make a positive difference to the world
Completion: Success is not how high you have climbed, but how you make a difference to the world.


Quote: will have to settle for the ordinary
Completion: If you are not willing to risk the usual, you will have to settle for the ordinary.


Quote: may only fail if you do not mind failing
Completion: You may only succeed if you desire succeeding, you may only fail if you do not care about failing

Quote: will be those

### RepE: PCA rep-reader steering

In [48]:
# PCA-based rep reader for the quotes contrast set (PCARepReader is already imported in the first cell).
quote_rep_reader = PCARepReader()

In [49]:
quote_steer_pipeline = SteeringPipeline(mw, quotes_dataset, quote_rep_reader)

# Direction-extraction settings: token index -1 (presumably the last token) on every
# layer index 0..num_hidden_layers-1.
n_difference = 1  # forwarded as-is to gen_dir_from_strings
hidden_layers = list(range(model.config.num_hidden_layers))
rep_token_idx = -1

dirs = quote_steer_pipeline.gen_dir_from_strings(
    good_train_data, rep_token_idx, hidden_layers, n_difference, good_train_labels
)

In [61]:

# Steer with the PCA quote direction on layers 6-14 and score completions against the
# reference endings for coeff in {0, +coeff, -coeff}; one loop replaces three copy-pasted calls.
layer_id = list(range(6, 15))

batch_size = 64
coeff = 2.0  # steering strength; tune this parameter
max_new_tokens = 10

print("RepReader:")
quote_steered = {}
for run_label, sign in (("No Control", 0), ("+ Memorization", 1), ("- Memorization", -1)):
    print(run_label)
    quote_steered[sign] = quote_steer_pipeline.batch_steering_generate(
        good_inputs,
        layer_id,
        coeff=sign * coeff,
        batch_size=batch_size,
        use_tqdm=True,
        max_new_tokens=max_new_tokens,
    )
    print(eval_completions(quote_steered[sign], good_targets))

baseline_outputs, pos_outputs, neg_outputs = quote_steered[0], quote_steered[1], quote_steered[-1]

RepReader:
No Control


  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:01<00:00,  1.53s/it]


{'char_by_char_similarity': 0.7883911388044132, 'sem_similarity': 0.8854249589370958, 'lev_distance': 0.8152852183480366}
+ Memorization


100%|██████████| 1/1 [00:01<00:00,  1.45s/it]


{'char_by_char_similarity': 0.5128093342742073, 'sem_similarity': 0.6630788684794398, 'lev_distance': 0.6103939996859946}
- Memorization


100%|██████████| 1/1 [00:01<00:00,  1.45s/it]


{'char_by_char_similarity': 0.37368384768071466, 'sem_similarity': 0.5502073281642162, 'lev_distance': 0.505909460666077}


In [62]:
# Reference endings for the first 10 memorized quote prompts.
good_targets[:10]

['the life in your years',
 "waste it living someone else's life",
 'no one ever come to you without leaving happier',
 'we insist on making it complicated',
 'how you make a positive difference to the world',
 'will have to settle for the ordinary',
 'may only fail if you do not mind failing',
 'will be those who empower others',
 'who are afraid to try and those who are afraid you will succeed',
 'to love what you do']

In [63]:
# '+ Memorization' (coeff = +2.0) generations for the same prompts.
pos_outputs[:10]

[" the life in your years.\nIt's",
 " waste it living in the past.\nDon'",
 ' the world know that you care.\n"H',
 ' we insist on making it complicated.\nLife',
 ' how you climb it.\nIt is a',
 ' will have to settle for the usual.\n',
 ' may only succeed if you desire succeeding, you',
 ' will be those who empower others.\nThe',
 ' who are blind and those who are deaf.',
 ' to love what you do.\nI love my']

In [64]:
# '- Memorization' (coeff = -2.0) generations for the same prompts.
neg_outputs[:10]

[' the moments.\nA few weeks ago, I',
 " waste it living someone else's life. Don",
 ' it radiate from you like the sun.\n',
 ' we insist on making it complicated.\nWe',
 ' how many times you have lifted yourself up after falling',
 ' will most likely undergo the same.\nIf',
 ' may only be successful if you want to be successful',
 ' in business, government, and nonprofit organizations',
 ' who are cynical and those who are afraid',
 ' to be great at what you do. The only']

### probe