## Probing results for 70m

In [1]:
import transformers
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
import os
import re
import random
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Dict
from tqdm import tqdm
import pickle
from dotenv import load_dotenv
import openai
import sys
sys.path.append('..')

from typing import List, Optional, Tuple, Dict
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from datasets import Dataset, load_from_disk

import plotly.graph_objects as go
import plotly.express as px

from utils import untuple

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression

  from .autonotebook import tqdm as notebook_tqdm
  return self.fget.__get__(instance, owner)()


In [3]:
# Root directory holding every artefact this notebook reads.
file_path = '../../../gld/train-data-probes/data/70m'

# The pre-split dataset; later cells read its 'train' and 'val' splits.
split_dataset_dir = os.path.join(file_path, 'split_dataset')
dataset = load_from_disk(split_dataset_dir)

In [4]:
# Load the two saved hidden-state tensors (pythia-evals "mem" set and pile set).
mem_hiddens, pile_hiddens = (
    torch.load(hidden_state_path)
    for hidden_state_path in (
        f'{file_path}/pythia-evals/mem_all_hidden_states.pt',
        f'{file_path}/pile/pile_all_hidden_states.pt',
    )
)

In [5]:
# Stack the mem rows first and the pile rows after them along dim 0
# (torch.cat's default), so a single index addresses both sources.
hiddens = torch.cat((mem_hiddens, pile_hiddens))
hiddens.shape

torch.Size([10000, 6, 10, 512])

In [9]:
seed = 0

# Seed every RNG entry point in the same order each time: Python's `random`,
# NumPy, transformers' `set_seed`, then torch (CPU and all CUDA devices).
for _seed_fn in (
    random.seed,
    np.random.seed,
    set_seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(seed)

# Shuffle both splits with the fixed seed and store them back into `dataset`.
# NOTE(review): because the shuffled split overwrites the original entry,
# re-running this cell shuffles already-shuffled data -- restart the kernel
# (or reload `dataset`) before re-running if exact row order matters.
for _split in ('train', 'val'):
    dataset[_split] = dataset[_split].shuffle(seed=seed)

temp_train, temp_test = dataset['train'], dataset['val']

In [10]:
# Drop rows whose 'text' value already appeared earlier in the same split,
# keeping the first occurrence. Each split takes a round trip through pandas
# (Dataset -> DataFrame -> drop_duplicates -> Dataset).
temp_train, temp_test = (
    Dataset.from_pandas(split.to_pandas().drop_duplicates(subset='text'))
    for split in (temp_train, temp_test)
)

In [11]:
# 'orig_idx' holds each example's row in `hiddens`; turn it into an index
# tensor per split and gather the matching activations.
train_idxs, test_idxs = (
    torch.tensor(split['orig_idx']) for split in (temp_train, temp_test)
)

train_acts, test_acts = hiddens[train_idxs], hiddens[test_idxs]

train_acts.shape, test_acts.shape

(torch.Size([2817, 6, 10, 512]), torch.Size([470, 6, 10, 512]))

In [12]:
# Load the generalization split and the hidden states it indexes into.
generalization_datasets = load_from_disk(os.path.join(file_path, 'generalization_datasets'))
mem_dist = generalization_datasets['mem_dist']

# NOTE(review): this reads from a 'pythia-evals-12b' directory under the 70m
# data root -- presumably intentional; confirm.
mem_dist_hiddens = torch.load(f'{file_path}/pythia-evals-12b/mem_all_hidden_states.pt')

# Deduplicate on 'orig_idx' while keeping rows and hidden states aligned.
# The previous `list(set(...))` deduplicated (and could reorder) the indices
# only, while the labels were still read from the un-deduplicated `mem_dist`,
# so activations and labels could end up in different orders or lengths.
# Here the dataset itself is reduced to the first occurrence of every
# 'orig_idx' (in original row order) and the hidden states are gathered from
# that same reduced dataset.
_first_row_by_orig_idx = {}
for _row, _orig_idx in enumerate(mem_dist['orig_idx']):
    _first_row_by_orig_idx.setdefault(_orig_idx, _row)

mem_dist = mem_dist.select(list(_first_row_by_orig_idx.values()))
mem_dist_idxs = list(mem_dist['orig_idx'])
mem_dist_hiddens = mem_dist_hiddens[mem_dist_idxs]

# Guard the invariant the probing loop relies on: one label per activation row.
assert len(mem_dist_hiddens) == len(mem_dist['labels']), (
    "mem_dist hidden states and labels are misaligned"
)

In [13]:
from probes import LRProbe
from sklearn.model_selection import train_test_split
# NOTE(review): this import rebinds the name `Dataset`, shadowing the
# `datasets.Dataset` imported at the top of the notebook. Re-running the
# dedupe cell (which calls `Dataset.from_pandas`) after this cell would pick
# up the torch class instead. Neither name is used in this cell; consider
# removing or aliasing the import once it is confirmed nothing else needs it.
from torch.utils.data import Dataset, DataLoader
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score


def _probe_inputs(acts, layer, tok_idx):
    """Return an independent float32 CPU copy of one (layer, token) slice.

    Args:
        acts: activation tensor indexed as [example, layer, token, hidden].
        layer: index along the layer axis.
        tok_idx: index along the token axis.

    Returns:
        A 2-D float32 CPU tensor with one row per example.
    """
    return acts[:, layer, tok_idx, :].detach().to(device="cpu", dtype=torch.float32).clone()


# Grid dimensions come from the activations instead of hard-coded 6 / 10.
_, n_layers, n_tokens, _ = train_acts.shape

# Labels do not depend on the layer/token position, so build them once.
y_train = torch.tensor(temp_train['labels'], dtype=torch.float32)
y_test = torch.tensor(temp_test['labels'], dtype=torch.float32)
y_mem_dist = torch.tensor(mem_dist['labels'], dtype=torch.float32)

# Explicit alignment checks. The previous per-iteration `train_test_split`
# call produced splits that were never used; its only effect was an implicit
# length-consistency check on the mem_dist data, which is made explicit here.
assert len(train_acts) == len(y_train), "train activations/labels length mismatch"
assert len(test_acts) == len(y_test), "test activations/labels length mismatch"
assert len(mem_dist_hiddens) == len(y_mem_dist), "mem_dist activations/labels length mismatch"

# Accuracy grids, indexed [token position, layer].
probe_acc = np.zeros((n_tokens, n_layers))
mem_dist_test_acc = np.zeros((n_tokens, n_layers))

for tok_idx in range(n_tokens):
    for layer in tqdm(range(n_layers)):
        X_train = _probe_inputs(train_acts, layer, tok_idx)
        X_test = _probe_inputs(test_acts, layer, tok_idx)
        X_mem_dist = _probe_inputs(mem_dist_hiddens, layer, tok_idx)

        # Fit one probe per (token, layer) on the training split, then score
        # it on the held-out split and on the mem_dist generalization set.
        probe = LRProbe.from_data(X_train, y_train)
        probe_acc[tok_idx, layer] = LRProbe.get_probe_accuracy(probe, X_test, y_test, device = "cpu")

        mem_dist_test_acc[tok_idx, layer] = LRProbe.get_probe_accuracy(probe, X_mem_dist, y_mem_dist, device = "cpu")
        

100%|██████████| 6/6 [00:04<00:00,  1.32it/s]
100%|██████████| 6/6 [00:04<00:00,  1.33it/s]
100%|██████████| 6/6 [00:04<00:00,  1.34it/s]
100%|██████████| 6/6 [00:04<00:00,  1.34it/s]
100%|██████████| 6/6 [00:04<00:00,  1.30it/s]
100%|██████████| 6/6 [00:04<00:00,  1.31it/s]
100%|██████████| 6/6 [00:04<00:00,  1.34it/s]
100%|██████████| 6/6 [00:04<00:00,  1.33it/s]
100%|██████████| 6/6 [00:04<00:00,  1.33it/s]
100%|██████████| 6/6 [00:04<00:00,  1.33it/s]


In [25]:
# Heatmap of held-out probe accuracy: rows are token positions, columns are layers.
fig = px.imshow(
    probe_acc,
    x=list(range(6)),
    y=list(range(10)),
    color_continuous_scale='Blues',
)
fig.update_layout(
    title="Probe Accuracy",
    xaxis_title="Layer",
    yaxis_title="Token",
    font={
        "family": "Courier New, monospace",
        "size": 18,
        "color": "RebeccaPurple",
    },
)
fig.show()

In [26]:
# Heatmap of probe accuracy on the mem_dist set: rows are token positions, columns are layers.
fig = px.imshow(
    mem_dist_test_acc,
    x=list(range(6)),
    y=list(range(10)),
    color_continuous_scale='Blues',
)
fig.update_layout(
    title="Mem Dist Test Acc",
    xaxis_title="Layer",
    yaxis_title="Token",
    font={
        "family": "Courier New, monospace",
        "size": 18,
        "color": "RebeccaPurple",
    },
)
fig.show()