In [1]:
import numpy as np 
import torch 
from tqdm import tqdm 
import pickle 
import pandas as pd
from typing import List, Dict, Any, Tuple, Union, Optional, Callable
import requests 
import time 
from collections import defaultdict 

import datasets
from datasets import load_dataset
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForCausalLM
import jailbreakbench as jbb 
import plotly.express as px
import sys
sys.path.append('../') 

from white_box.model_wrapper import ModelWrapper
from white_box.utils import gen_pile_data 
from white_box.dataset import clean_data 
from white_box.chat_model_utils import load_model_and_tokenizer, get_template, MODEL_CONFIGS

from white_box.dataset import PromptDist, ActDataset, create_prompt_dist_from_metadata_path, ProbeDataset
from white_box.probes import LRProbe
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from sklearn.linear_model import LogisticRegression
from datasets import load_from_disk, DatasetDict
from sklearn.metrics import accuracy_score, roc_auc_score

%load_ext autoreload
%autoreload 2

In [2]:
# Key used to look up the model in MODEL_CONFIGS and to locate its data folder.
model_name = 'llama2_7b'
# Per-model data directory; the trailing slash is part of the value.
data_path = "../data/" + model_name + "/"

In [None]:
# Look up the loading configuration registered for `model_name` and pass it
# as keyword arguments to load the model and tokenizer.
model_config = MODEL_CONFIGS[model_name]
model, tokenizer = load_model_and_tokenizer(**model_config)

# Take the 'prompt' entry from the template; the config's optional
# 'chat_template' entry (None when absent) is forwarded to get_template().
template = get_template(
    model_name,
    chat_template = model_config.get('chat_template'),
)['prompt']

# Bundle model, tokenizer and prompt template into a single wrapper object.
mw = ModelWrapper(model, tokenizer, template = template)

## Probing harmful vs. harmless inputs

In [5]:
from white_box.jb_experiments import plot_acc_auc

# Both prompt distributions are filtered views of the same metadata CSV.
file_spec = "harmbench_alpaca_"
metadata_path = data_path + f'/{file_spec}metadata.csv'

# NOTE(review): in this file `label == 1` selects the harmless prompts and
# `label == 0` the harmful ones — confirm against the metadata schema.
harmless = create_prompt_dist_from_metadata_path(
    metadata_path,
    col_filter = "(metadata['label'] == 1)",
)
harmful = create_prompt_dist_from_metadata_path(
    metadata_path,
    col_filter = "(metadata['label'] == 0)",
)

# Sanity check: number of selected rows in each distribution.
print(len(harmless.idxs), len(harmful.idxs))

# Harmful prompts are passed first and harmless second
# (presumably positives then negatives — TODO confirm against ActDataset).
dataset = ActDataset([harmful], [harmless])
dataset.instantiate()
probe_dataset = ProbeDataset(dataset)

1000 1000


In [20]:
from white_box.jb_experiments import plot_acc_auc

# NOTE(review): this cell rebinds `harmless`, `harmful`, `dataset` and
# `probe_dataset`, which the "harmbench_alpaca_" cell also assigns; whichever
# of the two cells ran last decides what later cells train on.
file_spec = "jb_"
metadata_path = data_path + f'/{file_spec}metadata.csv'

# Select rows by jailbreak name: 'harmless' rows versus 'DirectRequest' rows.
harmless = create_prompt_dist_from_metadata_path(
    metadata_path,
    col_filter = "(metadata['jb_name'] == 'harmless')",
)
harmful = create_prompt_dist_from_metadata_path(
    metadata_path,
    col_filter = "(metadata['jb_name'] == 'DirectRequest')",
)

# Sanity check: number of selected rows in each distribution.
print(len(harmless.idxs), len(harmful.idxs))

# Harmful prompts are passed first and harmless second
# (presumably positives then negatives — TODO confirm against ActDataset).
dataset = ActDataset([harmful], [harmless])
dataset.instantiate()
probe_dataset = ProbeDataset(dataset)

100 103


In [26]:
# Train one probe per layer on `probe_dataset` and keep, for every layer, the
# accuracy, the AUC and the fitted probe returned by train_sk_probe().
#
# NOTE(review): `probe_dataset` is whichever dataset cell ran last, so the
# "hb_dolly" name below may not describe the data actually used — confirm.

N_LAYERS = 32    # number of layers probed
# Fixed seed so the split/fit, and therefore the curves plotted from these
# results, are the same on every run. Assumes `random_state` follows the
# scikit-learn convention (int = deterministic, None = fresh randomness).
PROBE_SEED = 0

accs, aucs, hb_dolly_probes = [], [], []
for layer in tqdm(range(N_LAYERS)):
    acc, auc, probe = probe_dataset.train_sk_probe(
        layer,
        tok_idxs = list(range(5)),    # token indices 0..4 of the stored activations
        test_size = 0.25,
        C = 1e-2,
        max_iter = 2000,
        random_state = PROBE_SEED,
    )
    accs.append(acc)
    aucs.append(auc)
    hb_dolly_probes.append(probe)

  0%|          | 0/32 [00:00<?, ?it/s]

100%|██████████| 32/32 [00:05<00:00,  5.47it/s]


In [27]:
# Plot the per-layer accuracy and AUC lists collected while training the probes.
# NOTE(review): the title names Harmbench-Alpaca and "last 5 token positions";
# the probes are trained with tok_idxs 0..4 on whichever `probe_dataset` was
# built last — confirm that the title matches the run.
plot_acc_auc(
    accs,
    aucs,
    title = "Harmbench-Alpaca Test, probes trained on last 5 token positions",
)

In [16]:
# Sweep probes over layers and token indices, then show the accuracies as a heatmap.
# NOTE(review): this rebinds `accs` and `aucs`, which the per-layer training
# cell also assigns.
accs, aucs, probes = probe_dataset.layer_tokidx_sweep_results(lr = 0.01, weight_decay = 10)

# Tick labels are the string indices along each axis of `accs`:
# axis 0 is drawn on y, axis 1 on x.
# NOTE(review): x is labelled "Layers" and y "Tokens" — confirm that this
# matches the orientation of the array returned by the sweep.
fig = px.imshow(
    accs,
    x = list(map(str, range(accs.shape[1]))),
    y = list(map(str, range(accs.shape[0]))),
)
fig.update_layout(
    title = f"Probe Accuracy on Harmbench-Dolly, {model_name}",
    xaxis_title = "Layers",
    yaxis_title = "Tokens",
)

fig.show()

100%|██████████| 32/32 [00:56<00:00,  1.75s/it]


In [16]:
# Build a probe dataset that contrasts two groups of jailbreak attempts
# selected from the "jb_" metadata file.
file_spec = "jb_"

# Read the metadata as tab-separated. The separator used to be the literal
# letter "t", which splits every field at each "t" character.
# NOTE(review): assumes the file is a TSV — confirm against how
# create_prompt_dist_from_metadata_path() parses the same file.
jb_metadata = pd.read_csv(f"{data_path}/{file_spec}metadata.csv", sep = "\t")

# Rows with label == 1 that are not plain direct requests.
# NOTE(review): unlike the filter below, this one does not exclude
# jb_name == 'harmless'; confirm that no 'harmless' row carries label == 1.
jbs =  create_prompt_dist_from_metadata_path(data_path + f'/{file_spec}metadata.csv', col_filter = "(metadata['label'] == 1) & (metadata['jb_name'] != 'DirectRequest')")
# Rows with label == 0, excluding both direct requests and harmless prompts.
failed_jbs = create_prompt_dist_from_metadata_path(data_path + f'/{file_spec}metadata.csv', col_filter = "(metadata['label'] == 0) & (metadata['jb_name'] != 'DirectRequest') & (metadata['jb_name'] != 'harmless')")

# Sanity check: number of selected rows in each group.
print(len(jbs.idxs), len(failed_jbs.idxs))

# label == 1 rows are passed first, label == 0 rows second
# (presumably positives then negatives — TODO confirm against ActDataset).
dataset = ActDataset([jbs], [failed_jbs])
dataset.instantiate()
jb_probe_dataset = ProbeDataset(dataset)

427 430


In [28]:
# Apply each layer's previously trained probe to the jailbreak activations and
# record four per-layer statistics in `all_accs`:
#   any_harm_acc  - fraction of examples with any probability > 0.5 along dim 1
#   mean_harm_acc - fraction of examples whose mean probability along dim 1 is > 0.5
#   any_succ_acc  - agreement of the "any" prediction with the dataset labels
#   mean_succ_acc - agreement of the "mean" prediction with the dataset labels
# (dim 1 is presumably the token position — TODO confirm.)
all_accs = defaultdict(list)

# The labels do not depend on the layer, so read them once.
labels = jb_probe_dataset.act_dataset.y

for layer in range(32):
    # Probe outputs for this layer's slice of the activation tensor.
    pred_probas = hb_dolly_probes[layer].predict_proba(jb_probe_dataset.act_dataset.X[:, layer])

    # Two ways of reducing the per-position probabilities to one decision per example.
    preds_any = (pred_probas > 0.5).any(dim = 1).cpu().detach()
    preds_mean = (pred_probas.mean(dim = 1) > 0.5).cpu().detach()

    # Share of examples that each reduction flags.
    any_acc = preds_any.numpy().sum() / len(pred_probas)
    mean_acc = preds_mean.numpy().sum() / len(preds_mean)
    all_accs['any_harm_acc'].append(any_acc)
    all_accs['mean_harm_acc'].append(mean_acc)

    # Share of examples where the flag equals the dataset label.
    label_any_acc = (labels == preds_any).sum() / len(labels)
    label_mean_acc = (labels == preds_mean).sum() / len(labels)
    all_accs['any_succ_acc'].append(label_any_acc.item())
    all_accs['mean_succ_acc'].append(label_mean_acc.item())

In [29]:
# Per-layer fraction of jailbreak examples with at least one probe probability above 0.5.
px.line(
    all_accs['any_harm_acc'],
    title = "Any Harm Acc",
)

In [30]:
# Per-layer fraction of jailbreak examples whose mean probe probability is above 0.5.
px.line(
    all_accs['mean_harm_acc'],
    title = "Mean Harm Acc",
)

In [31]:
# Per-layer agreement between the mean-thresholded probe prediction and the
# labels of the jailbreak dataset.
px.line(
    all_accs['mean_succ_acc'],
    title = "Mean Whether JB was successful Acc",
)