In [1]:
import argparse
import ast
import json
import os
import random
import time
from os import makedirs, path

import numpy as np
import pandas as pd
from tqdm import tqdm
from together import Together

# NOTE: star import (presumably supplies GenerateNextTokenProbAPI, PROMPT_TEMPLATE, PROMPT_TO_USE).
# Kept before the huggingface_hub / datasets imports so those explicit names win on any clash.
from utils import *
from huggingface_hub import login as hf_login
from datasets import concatenate_datasets, DatasetDict, load_dataset

In [2]:
#-------------------
# Parameters
#-------------------
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='beanham/law_llm_attack')
parser.add_argument('--data_dir', type=str, default='formatted_data/')
parser.add_argument('--save_dir', type=str, default='probs/')
parser.add_argument('--model_tag', type=str, default='llama_10_epoch')
# Together AI API key; when omitted, the TOGETHER_API_KEY environment variable is used instead.
parser.add_argument('--together_key', type=str)
# Seed for the `random` and numpy global RNGs so that entity sampling is reproducible.
parser.add_argument('--seed', type=int, default=42)
# Parse an empty argv so the notebook always uses the defaults above
# (otherwise the kernel's own sys.argv would be parsed).
args = parser.parse_args(args=[])
random.seed(args.seed)
np.random.seed(args.seed)

## log in to Together AI & Hugging Face
with open('model_map.json') as f:
    model_map = json.load(f)
# Never hardcode credentials in a notebook. The key that used to be written here is
# exposed in version history and should be revoked / rotated.
together_key = args.together_key or os.environ.get('TOGETHER_API_KEY')
if not together_key:
    raise RuntimeError(
        'No Together AI API key: pass --together_key or set the TOGETHER_API_KEY environment variable.'
    )
client = Together(api_key=together_key)
# Despite its name, this is the identifier of the model to deploy on Together
# (it is printed verbatim by the deployment prompt below).
target_model_api_key = model_map[args.model_tag]['train']['api_key']
prob_generator = GenerateNextTokenProbAPI(client, target_model_api_key)
hf_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
# Message shown while waiting for the user to deploy the target model on Together AI.
deploy_message = f"""
=============================================================================================
Please deploy the following model {target_model_api_key}. The deployment might take up to 10 mins. Once the model is deployed, please proceed...
============================================================================================="""
# Blocks until the user presses Enter; the returned string becomes the cell output.
input(deploy_message)


Please deploy the following model bh193/Meta-Llama-3.1-8B-Instruct-Reference-76d7c75c-2e90586f. The deployment might take up to 10 mins. Once the model is deployed, please proceed...


''

In [146]:
def _parse_ent_list(raw):
    """Parse one stringified list of entity strings into a Python list of strings.

    Accepts the Python-literal repr of a list (e.g. "['theft', \"defendant's car\"]"),
    an already-materialised list/tuple, or None (treated as empty). Strings that are
    not valid literals fall back to the legacy split-on-"', '" parsing.
    """
    if raw is None:
        return []
    if isinstance(raw, (list, tuple)):
        parsed = raw
    else:
        try:
            parsed = ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError):
            parsed = None
    if isinstance(parsed, (list, tuple)):
        return [str(ent) for ent in parsed]
    # Legacy fallback for malformed strings: split on the list separator and strip
    # quote/bracket characters.
    return [i.replace("'", "").replace('[', '').replace(']', '') for i in raw.split("', '")]


def string_to_ents(example):
    """Flatten an example's stringified entity columns into a single entity list.

    `example['criminal_behaviors']` and `example['identifiable_info']` are expected to
    hold the string repr of a Python list of strings. ast.literal_eval is used instead of
    hand splitting because entities containing an apostrophe are repr'd with double
    quotes, which the old split/strip approach mangled.

    Returns:
        {"law_ents": [...]} — criminal behaviors first, then identifiable info, with
        empty strings removed; suitable as a `Dataset.map` callback.
    """
    criminal_ents = _parse_ent_list(example['criminal_behaviors'])
    identifiable_ents = _parse_ent_list(example['identifiable_info'])
    return {"law_ents": [ent for ent in criminal_ents + identifiable_ents if ent != '']}

In [177]:
print('Load & Prepare Dataset...')
dataset = load_dataset(args.dataset)
# The train -> val -> test concatenation order determines each case's positional new_ID.
# Presumably train_ids.txt was produced with this same order, so keep it stable.
all_data = concatenate_datasets([dataset[split] for split in ('train', 'val', 'test')])
new_ids = range(len(all_data))
all_data = (
    all_data
    .add_column("new_ID", new_ids)
    .map(string_to_ents)
)

Load & Prepare Dataset...


In [178]:
## load split ids: the first whitespace-separated token of each non-blank line is a new_ID
with open(path.join(args.data_dir, 'train_ids.txt'), 'r') as f:
    train_ids = [int(line.split()[0]) for line in f if line.strip()]
# A set gives O(1) membership checks inside the per-example filter callbacks.
# Testing membership against the list made each filter O(N * M).
train_id_set = set(train_ids)
dataset_train = all_data.filter(lambda example: example['new_ID'] in train_id_set)
dataset_test = all_data.filter(lambda example: example['new_ID'] not in train_id_set)

In [179]:
## unseen entities: the candidate pool for y_NON_star, drawn from held-out (non-train)
## samples with fewer than 10 entities; those samples are excluded from evaluation below.
# sorted() rather than list(set(...)): set iteration order for str depends on hash
# randomization, which would make the later random.sample() differ between runs even
# with a fixed seed.
# NOTE(review): the same entity string can occur in both train and held-out cases, so this
# pool is not guaranteed to be unseen by the target model — consider removing train entities.
unseen_ents = sorted({
    ent
    for sample in dataset_test
    if len(sample['law_ents']) < 10
    for ent in sample['law_ents']
})

In [180]:
## portion of dataset: keep only samples with at least 10 entities.
## From here on, dataset_train / dataset_test are plain lists of sample dicts.
train_test_ents = {
    split_name: [sample for sample in split_rows if len(sample['law_ents']) >= 10]
    for split_name, split_rows in (('train', dataset_train), ('test', dataset_test))
}
dataset_train = train_test_ents['train']
dataset_test = train_test_ents['test']

In [181]:
# Number of train-split samples kept (those with >= 10 entities).
len(dataset_train)

55

In [182]:
# Number of held-out samples kept (those with >= 10 entities).
len(dataset_test)

79

In [210]:
## Work through one sample interactively: the first kept training sample.
samples = train_test_ents['train']
# Despite its name, ent_list is a single sample record (a dict with 'new_ID' and 'law_ents').
ent_list = samples[0]
results = dict()

In [211]:
key_name = f"train_{ent_list['new_ID']}"
# Per-sample container, filled below with one entry per target (y*) and unseen (y_NON*) entity.
results[key_name] = {'y_stars': {}, 'y_NON_stars': {}}
# Deduplicate while keeping first-occurrence order. With list(set(...)) the order, and so
# which entity ents[i] selects, depended on PYTHONHASHSEED and changed from run to run.
ents = list(dict.fromkeys(ent_list['law_ents']))
k = len(ents)
# One unseen entity per target entity; fail with an explicit message if the pool is too small.
if k > len(unseen_ents):
    raise ValueError(
        f'Need {k} unseen entities for sample {key_name}, but only {len(unseen_ents)} are available.'
    )
unseen_ents_for_sample = random.sample(unseen_ents, k)

In [212]:
# Hold out entity i of this sample as y_star, pair it with the i-th sampled unseen entity,
# and build the prompt from the sample's remaining entities.
prompt_start = PROMPT_TEMPLATE[PROMPT_TO_USE][0]
prompt_end = PROMPT_TEMPLATE[PROMPT_TO_USE][1]
i = 0
y_star, y_NON_star = ents[i], unseen_ents_for_sample[i]
results[key_name]['y_stars'][y_star] = {}
results[key_name]['y_NON_stars'][y_NON_star] = {}
remaining_ents = [ent for pos, ent in enumerate(ents) if pos != i]
ents_string = ', '.join(remaining_ents)
prompt = f"{prompt_start} {ents_string} {prompt_end}"
# Token budget = prompt length + 20 (raised from 10 to 20 for the law cases).
# NOTE(review): whether this is a total or completion-only budget depends on utils — confirm.
max_tokens = len(prob_generator.tokenizer(prompt)['input_ids']) + 20