In [1]:
import numpy as np 
import torch 
import datasets
from tqdm import tqdm 
import pickle 
import pandas as pd
from datasets import load_dataset
from typing import List, Dict, Any, Tuple, Union, Optional, Callable
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForCausalLM
import sys
sys.path.append('../') 

from white_box.model_wrapper import ModelWrapper
from white_box.utils import gen_pile_data 
from white_box.dataset import clean_data 

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoint identifier shared by the model weights and the tokenizer.
model_name = "EleutherAI/pythia-70m"

# Load the weights, then the matching tokenizer (the two loads are independent).
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Project wrapper bundling the model with its tokenizer.
mw = ModelWrapper(model, tokenizer)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


### Get prompts

In [3]:
# Both sources are capped at the same number of prompts; keeping the value in
# one place stops the two lines below from drifting apart.
N_PROMPTS = 1000
# Forwarded to gen_pile_data as `min_n_toks`.
MIN_PROMPT_TOKENS = 64

pile_data = gen_pile_data(N = N_PROMPTS, tokenizer = tokenizer, min_n_toks = MIN_PROMPT_TOKENS) #strings
pythia_evals_data = load_dataset('EleutherAI/pythia-memorized-evals')['duped.70m'][:N_PROMPTS]['tokens'] #tokens

In [6]:
# Run the same cleaning pass over both prompt sources. Each call returns the
# kept prompts (as tokens, since return_toks=True) and a dict whose entries
# are inspected in the following cell.
cleaned_pile_toks, dirty_pile_dict = clean_data(
    pile_data,
    tokenizer,
    return_toks=True,
)
cleaned_pythia_toks, dirty_pythia_dict = clean_data(
    pythia_evals_data,
    tokenizer,
    return_toks=True,
)

In [7]:
# Kept prompts per source, plus the size of every entry in dirty_pythia_dict.
(
    len(cleaned_pile_toks),
    len(cleaned_pythia_toks),
    {reason: len(items) for reason, items in dirty_pythia_dict.items()},
)

(953,
 655,
 {'increment': 43,
  'repeated_majority': 10,
  'is_repeated_string': 25,
  'repeats_subseq': 214})

In [15]:
# Keep at most the first 500 cleaned prompts from each source.
cleaned_pile_toks, cleaned_pythia_toks = cleaned_pile_toks[:500], cleaned_pythia_toks[:500]

# Write through context managers so each file handle is flushed and closed as
# soon as its dump finishes, instead of lingering until garbage collection.
with open('../data/pythia-70m/mem/pile.pkl', 'wb') as _pile_fh:
    pickle.dump(cleaned_pile_toks, _pile_fh)
with open('../data/pythia-70m/mem/pythia_evals.pkl', 'wb') as _pythia_fh:
    pickle.dump(cleaned_pythia_toks, _pythia_fh)

### After getting activations, create the dataset

In [3]:
# Directory holding the saved artefacts, and the filename prefix of the Pile files.
data_path = '../data/pythia-70m/mem/'
file_spec = "pile_"

# Pile: metadata table and saved hidden states.
pile_metadata = pd.read_csv(data_path + f'{file_spec}metadata.csv')
pile_states = torch.load(data_path + f'{file_spec}all_hidden_states.pt')

# Pythia memorized-evals: the same two artefacts under a fixed prefix.
pythia_evals_metadata = pd.read_csv(data_path + 'pythia_evals_metadata.csv')
pythia_states = torch.load(data_path + 'pythia_evals_all_hidden_states.pt')

In [42]:
# Show the first five rows to check the metadata columns.
pile_metadata.head(n=5)

Unnamed: 0,prompt_str,prompt_toks,gen_str,gen_toks,tok_by_tok_sim,char_by_char_sim,lev_distance,source
0,---------------------- Forwarded by Benjamin R...,"[23130, 25956, 264, 407, 19046, 22456, 16, 237...",---------------------- Forwarded by Benjamin R...,"[23130, 25956, 264, 407, 19046, 22456, 16, 237...",0.59375,0.672316,0.78,pile_
1,Menu\n\nWhat’s your WHY?\n\nWhat’s your WHY?! ...,"[14324, 187, 187, 1276, 457, 84, 634, 7245, 58...",Menu\n\nWhat’s your WHY?\n\nWhat’s your WHY?! ...,"[14324, 187, 187, 1276, 457, 84, 634, 7245, 58...",0.5,0.414508,0.566667,pile_
2,Free peritoneal tumour cells are an independen...,"[14344, 35948, 18258, 1341, 403, 271, 3907, 18...",Free peritoneal tumour cells are an independen...,"[14344, 35948, 18258, 1341, 403, 271, 3907, 18...",0.53125,0.595611,0.659517,pile_
3,Q:\n\nNot able to find element by partial link...,"[50, 27, 187, 187, 3650, 2104, 281, 1089, 3284...",Q:\n\nNot able to find element by partial link...,"[50, 27, 187, 187, 3650, 2104, 281, 1089, 3284...",0.515625,0.537572,0.598131,pile_
4,Q:\n\nhighcharts redraw and reflow not working...,"[50, 27, 187, 187, 8656, 45945, 2502, 2040, 28...",Q:\n\nhighcharts redraw and reflow not working...,"[50, 27, 187, 187, 8656, 45945, 2502, 2040, 28...",0.5,0.509709,0.583012,pile_


In [4]:
@dataclass
class PromptDist:
    """A subset of prompts plus the files their data lives in.

    Fields:
        idxs: indices of the selected prompts.
        path_to_states: path of the file holding the hidden states.
        path_to_metadata: path of the metadata file.
    """

    idxs: List[int]
    path_to_states: str
    path_to_metadata: str

def create_prompt_dist_from_metadata(
    metadata: pd.DataFrame,
    path_to_states: str,
    path_to_metadata: str,
    condition: Callable,
    col_name: str = 'tok_by_tok_sim',
) -> PromptDist:
    """Build a PromptDist from the metadata rows that satisfy ``condition``.

    Args:
        metadata: table with one row per prompt.
        path_to_states: stored unchanged on the returned PromptDist.
        path_to_metadata: stored unchanged on the returned PromptDist.
        condition: callable applied to the whole ``metadata[col_name]``
            column; its result is used as a row selector on ``metadata``.
        col_name: column handed to ``condition``.

    Returns:
        A PromptDist whose ``idxs`` are the index labels of the selected rows.
    """
    row_selector = condition(metadata[col_name])
    selected_rows = metadata[row_selector]
    return PromptDist(
        idxs=selected_rows.index.tolist(),
        path_to_states=path_to_states,
        path_to_metadata=path_to_metadata,
    )

def less_than_60_percent(x):
    """Return ``x < 0.6`` (element-wise when ``x`` is an array or Series)."""
    threshold = 0.6
    return x < threshold

def equal_one(x):
    """Return ``x == 1`` (element-wise when ``x`` is an array or Series)."""
    target = 1
    return x == target

# Negative / positive subsets of the Pythia memorized-evals prompts.
neg_pythia_evals = create_prompt_dist_from_metadata(
    metadata=pythia_evals_metadata,
    path_to_states=data_path + 'pythia_evals_all_hidden_states.pt',
    path_to_metadata=data_path + 'pythia_evals_metadata.csv',
    condition=less_than_60_percent,
)
pos_pythia_evals = create_prompt_dist_from_metadata(
    metadata=pythia_evals_metadata,
    path_to_states=data_path + 'pythia_evals_all_hidden_states.pt',
    path_to_metadata=data_path + 'pythia_evals_metadata.csv',
    condition=equal_one,
)

# The same two subsets for the Pile prompts.
neg_pile = create_prompt_dist_from_metadata(
    metadata=pile_metadata,
    path_to_states=data_path + 'pile_all_hidden_states.pt',
    path_to_metadata=data_path + 'pile_metadata.csv',
    condition=less_than_60_percent,
)
pos_pile = create_prompt_dist_from_metadata(
    metadata=pile_metadata,
    path_to_states=data_path + 'pile_all_hidden_states.pt',
    path_to_metadata=data_path + 'pile_metadata.csv',
    condition=equal_one,
)

In [19]:
def between_90_and_100(x):
    """Return whether ``0.9 <= x < 1`` (element-wise for arrays or Series)."""
    at_least_lower = x >= 0.9
    below_upper = x < 1
    return at_least_lower & below_upper

# Pythia-evals prompts whose tok_by_tok_sim falls in [0.9, 1).
fuzzy_pos_pythia_evals = create_prompt_dist_from_metadata(
    metadata=pythia_evals_metadata,
    path_to_states=data_path + 'pythia_evals_all_hidden_states.pt',
    path_to_metadata=data_path + 'pythia_evals_metadata.csv',
    condition=between_90_and_100,
)

# Number of prompts in that band.
len(fuzzy_pos_pythia_evals.idxs)


45

In [20]:
# plotly.express is imported in this cell because the notebook's only other
# import of it sits in a later plotting cell; without it, a fresh
# "Restart & Run All" stops here with a NameError on `px`.
import plotly.express as px

# Distribution of the token-by-token similarity scores for the Pythia evals.
px.histogram(pythia_evals_metadata['tok_by_tok_sim'], title = 'Pythia Eval Similarity Distribution')

In [10]:
from white_box.probes import LRProbe
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from sklearn.linear_model import LogisticRegression
from datasets import load_from_disk, DatasetDict
from sklearn.metrics import accuracy_score, roc_auc_score

class Dataset:
    """Hidden states of positive and negative prompt sets with binary labels.

    ``pos`` and ``neg`` are lists of ``PromptDist``. Nothing is read from disk
    until ``instantiate`` is called.

    NOTE(review): this name shadows ``torch.utils.data.Dataset``, which is
    imported a few lines above in the same cell; that import is unreachable
    once this class is defined.
    """

    def __init__(self, pos : List[PromptDist], neg : List[PromptDist]):
        self.pos = pos
        self.neg = neg

        # Filled in by instantiate().
        self.X = None
        self.y = None

    def instantiate(self):
        """Load the selected hidden states and build the label vector.

        For every ``PromptDist`` the states file is read with ``torch.load``
        and indexed with its ``idxs``. Positive rows come first, negative
        rows after them.

        Returns:
            ``(X, y)``: ``X`` is the concatenated states moved to CPU and cast
            to float; ``y`` holds 1 for each positive row and 0 for each
            negative row. Both are also stored on the instance.
        """
        pos_states = torch.cat([torch.load(x.path_to_states)[x.idxs] for x in self.pos])
        neg_states = torch.cat([torch.load(x.path_to_states)[x.idxs] for x in self.neg])

        self.X = torch.cat([pos_states, neg_states]).cpu().float()
        self.y = torch.cat([torch.ones(len(pos_states)), torch.zeros(len(neg_states))])
        return self.X, self.y

    def train_test_split(self,
                        test_size : float = 0.2,
                        layer : Optional[int] = None,
                        tok_idxs : Optional[List[int]] = None,
                        random_state : int = 0,
                        balanced : bool = True):
        """Split the states and labels into train and test parts.

        ``X`` is indexed as ``X[:, layer, tok_idxs]`` and ``X.shape[3]`` is
        used as the feature size, i.e. it is treated as a 4-D tensor
        (example, layer, token, feature).

        Args:
            test_size: fraction of rows placed in the test part.
            layer: layer to select. Only used when ``tok_idxs`` is given.
            tok_idxs: token positions to select. When given, the selected
                states are flattened to 2-D ``(rows, feature)`` and every
                example's label is repeated once per row it contributes.
                When ``None``, the full ``X`` and ``y`` are split unchanged
                (and ``layer`` is ignored).
            random_state: seed forwarded to the splitter.
            balanced: stratify the split on the labels when True.

        Returns:
            ``(train_states, test_states, train_labels, test_labels)``.

        Note:
            When ``tok_idxs`` is given the split is made over flattened rows,
            so rows coming from the same example can land on both sides.
        """
        assert self.X is not None, "You must instantiate the dataset first"

        if tok_idxs is not None:
            feature_dim = self.X.shape[3]
            if layer is not None:
                states = self.X[:, layer, tok_idxs].reshape(-1, feature_dim)
                rows_per_example = len(tok_idxs)
            else:
                # Every layer is kept, so each example contributes
                # n_layers * len(tok_idxs) rows. The labels must be repeated
                # by the same factor; repeating them only len(tok_idxs) times
                # (as before) left labels shorter than states and paired
                # rows with the wrong example's label.
                states = self.X[:, :, tok_idxs].reshape(-1, feature_dim)
                rows_per_example = self.X.shape[1] * len(tok_idxs)
            labels = self.y.view(-1, 1).expand(-1, rows_per_example).flatten()
        else:
            labels = self.y
            states = self.X

        train_indices, test_indices = train_test_split(np.arange(len(labels)), test_size = test_size, random_state = random_state, stratify = labels if balanced else None)

        train_states = states[train_indices]
        test_states = states[test_indices]

        train_labels = labels[train_indices]
        test_labels = labels[test_indices]

        return train_states, test_states, train_labels, test_labels

In [11]:

class ProbeDataset():
    """Trains and scores linear probes on the hidden states of a ``Dataset``.

    Args:
        dataset: the ``Dataset`` to probe; it is instantiated here if its
            ``X`` is still ``None``.
        device: device string forwarded to ``LRProbe``. When ``None`` it is
            ``"cuda"`` if CUDA is available and ``"cpu"`` otherwise, so
            machines with a GPU behave exactly as before.
    """

    def __init__(self, dataset : Dataset, device : Optional[str] = None):
        if dataset.X is None:
            dataset.instantiate()

        self.dataset = dataset

        # X is read as (example, layer, token, ...): axis 1 and axis 2 sizes.
        self.N_LAYERS = self.dataset.X.shape[1]
        self.N_TOKS = self.dataset.X.shape[2]

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

    def _fit_and_score(self, X_train, y_train, X_val, y_val,
                       lr, weight_decay, epochs, use_bias):
        """Fit one LRProbe and return ``(accuracy, auc, probe)`` on the validation data."""
        probe = LRProbe.from_data(X_train, y_train,
                                  lr = lr,
                                  weight_decay = weight_decay,
                                  epochs = epochs,
                                  use_bias = use_bias,
                                  device = self.device)

        acc = probe.get_probe_accuracy(X_val, y_val, device = self.device)
        auc = probe.get_probe_auc(X_val, y_val, device = self.device)
        return acc, auc, probe

    def layer_sweep_results(self,
                            lr : float = 0.01,
                            weight_decay : float = 1,
                            epochs : int = 500,
                            use_bias : bool = True,
                            test_size = 0.2):
        """Train one probe per (layer, token position) pair.

        A single train/validation split (``random_state = 0``) is shared by
        all probes.

        Returns:
            ``(accs, aucs, probes)``: ``accs`` and ``aucs`` are arrays
            transposed so that the first axis is the token position and the
            second the layer; ``probes`` is a nested list indexed
            ``probes[layer][tok_idx]``.
        """
        probes = [[None for _ in range(self.N_TOKS)] for _ in range(self.N_LAYERS)]
        probe_accs = [[None for _ in range(self.N_TOKS)] for _ in range(self.N_LAYERS)]
        probe_aucs = [[None for _ in range(self.N_TOKS)] for _ in range(self.N_LAYERS)]

        train_states, val_states, y_train, y_val = self.dataset.train_test_split(test_size = test_size, layer = None, tok_idxs = None, random_state = 0)

        for layer in tqdm(range(self.N_LAYERS)):
            for tok_idx in range(self.N_TOKS):
                X_train, X_val = train_states[:, layer, tok_idx], val_states[:, layer, tok_idx]

                acc, auc, probe = self._fit_and_score(X_train, y_train, X_val, y_val,
                                                      lr, weight_decay, epochs, use_bias)

                probes[layer][tok_idx] = probe
                probe_accs[layer][tok_idx] = acc
                probe_aucs[layer][tok_idx] = auc

        return np.array(probe_accs).T, np.array(probe_aucs).T, probes

    def train_probe(self, layer : int, tok_idxs : List[int],
                    lr : float = 0.01,
                    weight_decay : float = 1,
                    epochs : int = 500,
                    use_bias : bool = True,
                    test_size : float = 0.2):
        """Train a single LRProbe on one layer, pooling the given token positions.

        Returns:
            ``(acc, auc, probe)`` measured on the held-out part of the split.
        """
        X_train, X_val, y_train, y_val = self.dataset.train_test_split(test_size = test_size, layer = layer, tok_idxs = tok_idxs, random_state = 0)

        return self._fit_and_score(X_train, y_train, X_val, y_val,
                                   lr, weight_decay, epochs, use_bias)

    def train_sk_probe(self, layer : int, tok_idxs : List[int],
                       max_iter = 3000,
                       C = 1e-5,
                       test_size : float = 0.2
                       ):
        """Train a scikit-learn ``LogisticRegression`` probe on one layer.

        ``max_iter`` and ``C`` are forwarded to ``LogisticRegression``.

        Returns:
            ``(accuracy, auc, probe_lr)`` measured on the held-out part of the split.
        """
        X_train, X_val, y_train, y_val = self.dataset.train_test_split(test_size = test_size, layer = layer, tok_idxs = tok_idxs, random_state = 0)

        probe_lr = LogisticRegression(max_iter = max_iter, C = C)
        probe_lr.fit(X_train.numpy(), y_train.numpy())

        y_pred = probe_lr.predict(X_val.numpy())
        accuracy = accuracy_score(y_val.numpy(), y_pred)
        # AUC uses the predicted probability of the second class (column 1).
        auc = roc_auc_score(y_val.numpy(), probe_lr.predict_proba(X_val.numpy())[:, 1])

        return accuracy, auc, probe_lr

In [12]:
# NOTE(review): pos_quotes, pos_prefix and neg_quotes are not defined in any
# cell visible in this notebook excerpt — confirm where they are created.
dataset = Dataset(
    pos=[pos_pythia_evals, pos_quotes, pos_prefix],
    neg=[neg_pythia_evals, neg_quotes],
)
X, y = dataset.instantiate()

# Sweep a probe over every (layer, token position) pair.
probe_dataset = ProbeDataset(dataset)
accs, aucs, probes = probe_dataset.layer_sweep_results()

100%|██████████| 6/6 [00:23<00:00,  3.83s/it]


In [13]:
import plotly.express as px

# `accs` comes back from layer_sweep_results transposed: one row per token
# position, one column per layer. Both tick-label lists are derived from its
# shape so the plot stays valid for any number of token positions (the y
# labels were previously hard-coded to 10).
fig = px.imshow(
    accs,
    y=[str(i) for i in range(accs.shape[0])],
    x=[str(i) for i in range(accs.shape[1])],
)
fig.update_layout(
    title=f"Probe Accuracy, {model_name}",
    xaxis_title="Layers",
    yaxis_title="Tokens",
)

fig.show()