In [1]:
import numpy as np 
import torch 
from tqdm import tqdm 
import pickle 
import pandas as pd
from typing import List, Dict, Any, Tuple, Union, Optional, Callable
import requests 
import time 
import os

import datasets
from datasets import load_dataset
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForCausalLM
import argparse 
import json

import sys
sys.path.append('../') 

from white_box.model_wrapper import ModelWrapper
from white_box.utils import gen_pile_data 
from white_box.dataset import clean_data 
from white_box.chat_model_utils import load_model_and_tokenizer, get_template, MODEL_CONFIGS

from white_box.dataset import PromptDist, ActDataset, create_prompt_dist_from_metadata_path, ProbeDataset
from white_box.probes import LRProbe
from white_box.monitor import ActMonitor, TextMonitor

In [2]:
path = '../data/llama2_7b'

sep = 't'
df = pd.read_csv(os.path.join(path, 'jb_metadata.csv'), sep=sep)
df['jb_name'].value_counts()

jb_name
GCG              151
EnsembleGCG      129
DirectRequest    103
harmless         100
AutoPrompt        90
TAP               77
PAIR              76
FewShot           56
UAT               53
GBDA              51
PEZ               47
ZeroShot          46
PAP               44
AutoDAN           37
Name: count, dtype: int64

In [13]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
from peft import AutoPeftModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("meta-llama/LlamaGuard-7b", padding_side='left')
tokenizer.pad_token = tokenizer.eos_token
# model = AutoPeftModelForSequenceClassification.from_pretrained("../data/llama2_7b/llamaguard_generated__model_1", torch_dtype=torch.bfloat16, num_labels=2)
# model = model.merge_and_unload()
model = AutoModelForCausalLM.from_pretrained("../data/llama2_7b/llamaguard_generated__model_2", torch_dtype=torch.bfloat16)
model.config.pad_token_id = model.config.eos_token_id

device = 'cuda'
model.to(device)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): lora.Embedding(
      (base_layer): Embedding(32000, 4096)
      (lora_dropout): ModuleDict(
        (default): Dropout(p=0.1, inplace=False)
      )
      (lora_A): ModuleDict()
      (lora_B): ModuleDict()
      (lora_embedding_A): ParameterDict(  (default): Parameter containing: [torch.cuda.BFloat16Tensor of size 8x32000 (cuda:0)])
      (lora_embedding_B): ParameterDict(  (default): Parameter containing: [torch.cuda.BFloat16Tensor of size 4096x8 (cuda:0)])
    )
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): lora.Linear(
            (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.1, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Line

In [14]:
def get_batched_preds(prompts: List[str], model: torch.nn.Module, tokenizer: AutoTokenizer, device: str, batch_size: int = 8) -> np.ndarray:
    toks = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}]) for prompt in prompts]
    max_len = max([len(t) for t in toks])
    for i in range(len(toks)):
        toks[i] = [tokenizer.pad_token_id] * (max_len - len(toks[i])) + toks[i]
    toks = torch.tensor(toks).to(device)
    preds = []
    for i in tqdm(range(0, len(toks), batch_size)): 
        output = model(toks[i:i+batch_size])
        preds.append(torch.stack([output.logits[:,-1, 9109], output.logits[:,-1, 25110]], dim=1).softmax(-1).cpu().detach().numpy()[:,1])
        # preds.append(output.logits.softmax(dim=-1)[:,1].cpu().detach().numpy())
    return np.concatenate(preds)
    
path = '../data/'

## FPR check

In [15]:
from datasets import load_dataset

dataset = load_dataset("yahma/alpaca-cleaned")['train']
dataset

Dataset({
    features: ['output', 'instruction', 'input'],
    num_rows: 51760
})

In [16]:
prompts = [f'{input}\n\n{instruction}' for input, instruction in zip(dataset['input'], dataset['instruction'])]
len(prompts)

51760

In [17]:
np.random.shuffle(prompts)

preds = get_batched_preds(prompts[:100], model, tokenizer, device)



  0%|          | 0/13 [00:00<?, ?it/s]

100%|██████████| 13/13 [00:09<00:00,  1.30it/s]


In [18]:
import plotly.express as px
px.histogram(preds)

In [19]:
advbench_positives = pd.read_csv('../data/harmful_behaviors_custom.csv')['goal'].tolist()
gpt_negatives = pd.read_csv('../data/harmless_behaviors_custom.csv')['prompt'].tolist()

advbench_preds = get_batched_preds(advbench_positives, model, tokenizer, device)
gpt_preds = get_batched_preds(gpt_negatives, model, tokenizer, device)


  0%|          | 0/7 [00:00<?, ?it/s]

100%|██████████| 7/7 [00:04<00:00,  1.65it/s]
100%|██████████| 7/7 [00:04<00:00,  1.62it/s]


In [20]:
px.histogram(advbench_preds)

In [21]:
px.histogram(gpt_preds)

## Language results

### Turkish

In [None]:
turkish_path = os.path.join(path, 'turkish/')
pos = pd.read_csv(os.path.join(turkish_path, 'harmful_behaviors_custom_metadata.csv'))
neg = pd.read_csv(os.path.join(turkish_path, 'harmless_behaviors_custom_metadata.csv'))

pos_prompts = pos['prompt'].tolist()
neg_prompts = neg['prompt'].tolist()

pos_preds = get_batched_preds(pos_prompts, model, tokenizer, device)
neg_preds = get_batched_preds(neg_prompts, model, tokenizer, device)

pos_acc = (pos_preds > 0.5).mean()
neg_acc = (neg_preds < 0.5).mean()
total_acc = 0.5 * (pos_acc + neg_acc)

pos_acc, neg_acc, total_acc

  0%|          | 0/7 [00:00<?, ?it/s]

100%|██████████| 7/7 [00:35<00:00,  5.09s/it]
100%|██████████| 7/7 [00:35<00:00,  5.13s/it]


(0.94, 0.76, 0.85)

### Dutch

In [None]:
dutch_path = os.path.join(path, 'dutch/')
pos = pd.read_csv(os.path.join(dutch_path, 'harmful_behaviors_custom_metadata.csv'))
neg = pd.read_csv(os.path.join(dutch_path, 'harmless_behaviors_custom_metadata.csv'))

pos_prompts = pos['prompt'].tolist()
neg_prompts = neg['prompt'].tolist()

pos_preds = get_batched_preds(pos_prompts, model, tokenizer, device)
neg_preds = get_batched_preds(neg_prompts, model, tokenizer, device)

pos_acc = (pos_preds > 0.5).mean()
neg_acc = (neg_preds < 0.5).mean()
total_acc = 0.5 * (pos_acc + neg_acc)

pos_acc, neg_acc, total_acc

100%|██████████| 7/7 [00:34<00:00,  4.92s/it]
100%|██████████| 7/7 [00:35<00:00,  5.01s/it]


(0.98, 0.06, 0.52)

### Hungarian

In [None]:
hungarian_path = os.path.join(path, 'hungarian/')
pos = pd.read_csv(os.path.join(hungarian_path, 'harmful_behaviors_custom_metadata.csv'))
neg = pd.read_csv(os.path.join(hungarian_path, 'harmless_behaviors_custom_metadata.csv'))

pos_prompts = pos['prompt'].tolist()
neg_prompts = neg['prompt'].tolist()

pos_preds = get_batched_preds(pos_prompts, model, tokenizer, device)
neg_preds = get_batched_preds(neg_prompts, model, tokenizer, device)

pos_acc = (pos_preds > 0.5).mean()
neg_acc = (neg_preds < 0.5).mean()
total_acc = 0.5 * (pos_acc + neg_acc)

pos_acc, neg_acc, total_acc

100%|██████████| 7/7 [00:35<00:00,  5.00s/it]
100%|██████████| 7/7 [00:35<00:00,  5.08s/it]


(0.98, 0.14, 0.56)

### Slovenian

In [None]:
hungarian_path = os.path.join(path, 'slovenian/')
pos = pd.read_csv(os.path.join(hungarian_path, 'harmful_behaviors_custom_metadata.csv'))
neg = pd.read_csv(os.path.join(hungarian_path, 'harmless_behaviors_custom_metadata.csv'))

pos_prompts = pos['prompt'].tolist()
neg_prompts = neg['prompt'].tolist()

pos_preds = get_batched_preds(pos_prompts, model, tokenizer, device)
neg_preds = get_batched_preds(neg_prompts, model, tokenizer, device)

pos_acc = (pos_preds > 0.5).mean()
neg_acc = (neg_preds < 0.5).mean()
total_acc = 0.5 * (pos_acc + neg_acc)

pos_acc, neg_acc, total_acc

  0%|          | 0/7 [00:00<?, ?it/s]

100%|██████████| 7/7 [00:34<00:00,  4.97s/it]
100%|██████████| 7/7 [00:35<00:00,  5.05s/it]


(0.98, 0.3, 0.64)