In [18]:
import pickle
import sys
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Dict, Any, Tuple, Union, Optional, Callable

import datasets
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import requests
import torch
from datasets import load_dataset
from datasets import load_from_disk, DatasetDict
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# Make the project-local `white_box` package importable from this notebook's directory.
sys.path.append('../')

from white_box.model_wrapper import ModelWrapper
from white_box.utils import gen_pile_data
from white_box.dataset import clean_data
from white_box.chat_model_utils import load_model_and_tokenizer, get_template, MODEL_CONFIGS

from white_box.dataset import PromptDist, ActDataset, create_prompt_dist_from_metadata_path, ProbeDataset
from white_box.probes import LRProbe
from white_box.jb_experiments import plot_acc_auc

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
# Model whose data is probed; every metadata path below is built from `data_path`.
model_name = 'llama2_7b'
data_path = "../data/" + model_name + "/"

In [20]:
# Optional: load the chat model and wrap it. None of the probing cells below use
# `model`, `tokenizer`, `template` or `mw`; they presumably read precomputed
# activations (TODO confirm in ActDataset.instantiate). Set True only when you need the model.
LOAD_MODEL = False
if LOAD_MODEL:
    model_config = MODEL_CONFIGS[model_name]
    model, tokenizer = load_model_and_tokenizer(**model_config)
    template = get_template(model_name, chat_template=model_config.get('chat_template', None))['prompt']

    mw = ModelWrapper(model, tokenizer, template = template)

## Probing harmful vs. harmless inputs

In [4]:
# Training data for the probes: rows with label == 1 are treated as harmful and
# label == 0 as harmless (the file presumably mixes HarmBench and Alpaca prompts).
file_spec = "harmbench_alpaca_"
metadata_path = data_path + f'/{file_spec}metadata.csv'
harmful = create_prompt_dist_from_metadata_path(metadata_path, col_filter = "(metadata['label'] == 1)")
harmless = create_prompt_dist_from_metadata_path(metadata_path, col_filter = "(metadata['label'] == 0)")
print(len(harmless.idxs), len(harmful.idxs))
# presumably the first list becomes the positive (y == 1) class — TODO confirm in ActDataset
dataset = ActDataset([harmful], [harmless])
dataset.instantiate()
probe_dataset = ProbeDataset(dataset)

1200 1200


In [21]:
# Test split of the same harmful (label == 1) vs. harmless (label == 0) task.
file_spec = "harmbench_alpaca_test_"
metadata_path = data_path + f'/{file_spec}metadata.csv'
harmful = create_prompt_dist_from_metadata_path(metadata_path, col_filter = "(metadata['label'] == 1)")
harmless = create_prompt_dist_from_metadata_path(metadata_path, col_filter = "(metadata['label'] == 0)")
print(len(harmless.idxs), len(harmful.idxs))
dataset = ActDataset([harmful], [harmless])
dataset.instantiate()
test_probe_dataset = ProbeDataset(dataset)

295 295


In [22]:
# GPT-generated prompts, split by label (1 -> harmful, 0 -> harmless).
file_spec = "gpt_generated_"
metadata_path = data_path + f'/{file_spec}metadata.csv'
harmful = create_prompt_dist_from_metadata_path(metadata_path, col_filter = "(metadata['label'] == 1)")
harmless = create_prompt_dist_from_metadata_path(metadata_path, col_filter = "(metadata['label'] == 0)")
print(len(harmless.idxs), len(harmful.idxs))
dataset = ActDataset([harmful], [harmless])
dataset.instantiate()
gpt_probe_dataset = ProbeDataset(dataset)

500 500


In [23]:
# Jailbreak metadata, selected by `jb_name` rather than `label`: 'DirectRequest' rows
# (presumably the plain harmful requests, no jailbreak wrapper) vs. 'harmless' rows.
file_spec = "jb_"
metadata_path = data_path + f'/{file_spec}metadata.csv'
harmless = create_prompt_dist_from_metadata_path(metadata_path, col_filter = "(metadata['jb_name'] == 'harmless')")
harmful = create_prompt_dist_from_metadata_path(metadata_path, col_filter = "(metadata['jb_name'] == 'DirectRequest')")
print(len(harmless.idxs), len(harmful.idxs))
dataset = ActDataset([harmful], [harmless])
dataset.instantiate()
jb_probe_dataset = ProbeDataset(dataset)

100 103


In [24]:
# Jailbreak prompts (non-DirectRequest) split by `label`:
# 1 -> `jbs` (positives), 0 -> `failed_jbs` (negatives, 'harmless' rows excluded).
file_spec = "jb_"
metadata_path = data_path + f'/{file_spec}metadata.csv'
# A one-character separator "t" would split fields on the letter 't'; a tab was intended.
# assumes the metadata file is tab-separated — TODO confirm against the file on disk.
jb_metadata = pd.read_csv(metadata_path, sep = "\t")
# NOTE(review): unlike `failed_jbs`, this filter does not exclude jb_name == 'harmless';
# presumably harmless rows never carry label == 1 — verify.
jbs = create_prompt_dist_from_metadata_path(metadata_path, col_filter = "(metadata['label'] == 1) & (metadata['jb_name'] != 'DirectRequest')")
failed_jbs = create_prompt_dist_from_metadata_path(metadata_path, col_filter = "(metadata['label'] == 0) & (metadata['jb_name'] != 'DirectRequest') & (metadata['jb_name'] != 'harmless')")
print(len(jbs.idxs), len(failed_jbs.idxs))
dataset = ActDataset([jbs], [failed_jbs])
dataset.instantiate()
jb_labeled_by_success_probe_dataset = ProbeDataset(dataset)

427 430


In [25]:
# Fit one probe per layer (0-31) on the HB-Alpaca training prompts using tok_idxs 0-4.
# No train/test split is requested, so `accs`/`aucs` are presumably in-sample scores
# (TODO confirm in ProbeDataset.train_sk_probe).
per_layer_results = [
    probe_dataset.train_sk_probe(
        layer,
        tok_idxs=list(range(5)),
        test_size=None,
        C=1e-2,
        max_iter=2000,
        random_state=None,
        use_train_test_split=False,
    )
    for layer in tqdm(range(32))
]
# Transpose [(acc, auc, probe), ...] into three parallel per-layer lists.
accs, aucs, hb_dolly_probes = (list(column) for column in zip(*per_layer_results))

100%|██████████| 32/32 [01:58<00:00,  3.71s/it]


In [26]:
# Per-layer accuracy/AUC returned by train_sk_probe; the title labels them as training-set scores.
plot_acc_auc(accs, aucs, title = "HB-Alpaca Train, probes trained on last 5 token positions")

In [27]:
# Evaluate every per-layer HB-Alpaca probe on the HB-Alpaca test split.
# Scores average `predict_proba`'s output over its last dim (presumably token
# positions — TODO confirm) and are thresholded at 0.5.
metrics = defaultdict(list)
# Labels do not depend on the layer, so convert them once.
labels = test_probe_dataset.act_dataset.y.detach().cpu().numpy()

for layer, probe in enumerate(hb_dolly_probes):
    pred_probas = probe.predict_proba(test_probe_dataset.act_dataset.X[:, layer])
    probas_mean = pred_probas.mean(dim=-1).detach().cpu().numpy()
    # Single decision rule for every thresholded metric (p > 0.5 -> positive, else negative),
    # so TPR/TNR agree with 'acc' even for a probability of exactly 0.5.
    preds = probas_mean > 0.5

    metrics['acc'].append(accuracy_score(labels, preds))
    metrics['auc'].append(roc_auc_score(labels, probas_mean))
    metrics['TPR'].append((preds & (labels == 1)).sum() / (labels == 1).sum())
    metrics['TNR'].append((~preds & (labels == 0)).sum() / (labels == 0).sum())

In [28]:
# Per-layer metrics of the HB-Alpaca probes on the HB-Alpaca test split.
fig = go.Figure()
x = list(range(len(metrics['acc'])))
for key, trace_name in [('acc', 'Accuracy'), ('auc', 'AUC'), ('TPR', 'TPR'), ('TNR', 'TNR')]:
    fig.add_trace(go.Scatter(x=x, y=metrics[key], mode='lines', name=trace_name))
fig.update_layout(
    title="HB-Alpaca probe, tested on HB-Alpaca test split",
    xaxis_title="Layers",
    yaxis_title="Value",
)
fig.show()


In [29]:
# Transfer check: evaluate the HB-Alpaca probes on the jb DirectRequest-vs-harmless set.
# Scores average `predict_proba`'s output over its last dim (presumably token
# positions — TODO confirm) and are thresholded at 0.5.
metrics = defaultdict(list)
# Labels do not depend on the layer, so convert them once.
labels = jb_probe_dataset.act_dataset.y.detach().cpu().numpy()

for layer, probe in enumerate(hb_dolly_probes):
    pred_probas = probe.predict_proba(jb_probe_dataset.act_dataset.X[:, layer])
    probas_mean = pred_probas.mean(dim=-1).detach().cpu().numpy()
    # Single decision rule for every thresholded metric (p > 0.5 -> positive, else negative).
    preds = probas_mean > 0.5

    metrics['acc'].append(accuracy_score(labels, preds))
    metrics['auc'].append(roc_auc_score(labels, probas_mean))
    metrics['TPR'].append((preds & (labels == 1)).sum() / (labels == 1).sum())
    metrics['TNR'].append((~preds & (labels == 0)).sum() / (labels == 0).sum())

In [30]:
# Per-layer transfer metrics of the HB-Alpaca probes on the jb DirectRequest-vs-harmless set.
# Class counts are derived from the labels so the title cannot drift from the data.
jb_labels = jb_probe_dataset.act_dataset.y.detach().cpu().numpy()
n_pos, n_neg = int((jb_labels == 1).sum()), int((jb_labels == 0).sum())

fig = go.Figure()
x = list(range(len(metrics['acc'])))
for key, trace_name in [('acc', 'Accuracy'), ('auc', 'AUC'), ('TPR', 'TPR'), ('TNR', 'TNR')]:
    fig.add_trace(go.Scatter(x=x, y=metrics[key], mode='lines', name=trace_name))
fig.update_layout(
    title=f"HB-Alpaca probe, tested on jb dataset ({n_pos} positives from HarmBench, {n_neg} negatives from GPT-4)",
    xaxis_title="Layers",
    yaxis_title="Value",
)
fig.show()


In [31]:
# Evaluate the HB-Alpaca probes on jailbreak prompts labeled by success
# (positives: `jbs`, negatives: `failed_jbs`). Scores average `predict_proba`'s output
# over its last dim (presumably token positions — TODO confirm), thresholded at 0.5.
metrics = defaultdict(list)
# Labels do not depend on the layer, so convert them once.
labels = jb_labeled_by_success_probe_dataset.act_dataset.y.detach().cpu().numpy()

for layer, probe in enumerate(hb_dolly_probes):
    pred_probas = probe.predict_proba(jb_labeled_by_success_probe_dataset.act_dataset.X[:, layer])
    probas_mean = pred_probas.mean(dim=-1).detach().cpu().numpy()
    # Single decision rule for every thresholded metric (p > 0.5 -> positive, else negative).
    preds = probas_mean > 0.5

    metrics['acc'].append(accuracy_score(labels, preds))
    metrics['auc'].append(roc_auc_score(labels, probas_mean))
    metrics['TPR'].append((preds & (labels == 1)).sum() / (labels == 1).sum())
    metrics['TNR'].append((~preds & (labels == 0)).sum() / (labels == 0).sum())

In [32]:
# Per-layer metrics of the HB-Alpaca probes on jailbreaks labeled by success.
# Counts are derived from the labels (the loaded set is not exactly 800 prompts).
success_labels = jb_labeled_by_success_probe_dataset.act_dataset.y.detach().cpu().numpy()
n_success, n_failed = int((success_labels == 1).sum()), int((success_labels == 0).sum())

fig = go.Figure()
x = list(range(len(metrics['acc'])))
for key, trace_name in [('acc', 'Accuracy'), ('auc', 'AUC'), ('TPR', 'TPR'), ('TNR', 'TNR')]:
    fig.add_trace(go.Scatter(x=x, y=metrics[key], mode='lines', name=trace_name))
fig.update_layout(
    title=f"HB-Alpaca probe, tested on jb dataset ({n_success + n_failed} total: {n_success} successful / {n_failed} failed jailbreaks)",
    xaxis_title="Layers",
    yaxis_title="Value",
)
fig.show()
