In [1]:
import os
import json
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from together import Together
from transformers import AutoTokenizer
import time

from utils import *
from huggingface_hub import login as hf_login
from peft import prepare_model_for_kbit_training
from datasets import concatenate_datasets, DatasetDict, load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoModel

In [2]:
with open('model_map.json') as f:
    model_map=json.load(f)
key = '779d92de61a5035835e5023ca79e2e5b6124c6300c3ceb0e07e374f948554116'
client = Together(api_key=key)
hf_login(token="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")
dataset = load_dataset("beanham/medsum_llm_attack")
merged_dataset = concatenate_datasets([dataset['train'], dataset['validation'], dataset['test']])
new_ids = range(len(merged_dataset))
merged_dataset = merged_dataset.add_column("new_ID", new_ids)

In [3]:
## load model
split='train'
id=3
subsample_split='subsample_ids'
target_model_api_key=model_map[split+'_'+str(id)]['api_key']
prob_generator = GenerateNextTokenProbAPI(client, target_model_api_key)

## load data
target_subsample_ids = pd.read_csv(f"formatted_data/{subsample_split}_{id}.csv")['new_ID'].tolist()
train_dataset = merged_dataset.filter(lambda example: example['new_ID'] in target_subsample_ids)
test_dataset = merged_dataset.filter(lambda example: example['new_ID'] not in target_subsample_ids)

## why are we only using len(ents)<5 as the unseen ents?
unseen_ents = [sample['disease_ents'] for sample in test_dataset if len(sample['disease_ents'])<5]
unseen_ents = [item for sublist in unseen_ents for item in sublist]

##make unseen entities unique
unseen_ents = list(set(unseen_ents))

train_dataset = [sample for sample in train_dataset if len(sample['disease_ents'])>=5]
test_dataset = [sample for sample in test_dataset if len(sample['disease_ents'])>=5]
train_test_ents = {'train': train_dataset,'test': test_dataset}
len(train_dataset), len(test_dataset), target_model_api_key

(154, 137, 'bh193/Meta-Llama-3.1-8B-Instruct-Reference-afc89c2b-0f1f3c8a')

In [4]:
# unseen ents value count
pd.Series(unseen_ents).value_counts()

Invega              1
vulva               1
estrogen            1
pleural effusion    1
snapping            1
                   ..
aches               1
Adderall            1
Advil               1
Dilaudid            1
sprained            1
Name: count, Length: 399, dtype: int64

In [5]:
PROMPT_TEMPLATE[PROMPT_TO_USE]

['Consider an individual whose medical summary contains: ',
 ". That individual's medical summary then also includes: "]

In [7]:
results = {}
fail_counter = 0

for name, samples in train_test_ents.items():
    for j, ent_list in tqdm(enumerate(samples[:50])):
        print(f'{name.upper()}: {j+1}/50...')
        key_name = name + '_' + str(ent_list['ID'])
        results[key_name] = {}
        results[key_name]['y_stars'] = {}
        results[key_name]['y_NON_stars'] = {}

        ents = list(set(ent_list['disease_ents']))
        k = len(ents)

        unseen_ents_for_sample = random.sample(unseen_ents, k)

        # make sure no overlap between unseen_ents and ents
        for i in range(k):
            while unseen_ents_for_sample[i] in ents:
                unseen_ents_for_sample[i] = random.choice(unseen_ents)
                print('had to swap an ent')

        for i in range(k):
            y_star = ents[i]
            y_NON_star = unseen_ents_for_sample[i]

            results[key_name]['y_stars'][y_star] = {}
            results[key_name]['y_NON_stars'][y_NON_star] = {}

            remaining_ents = ents[:i] + ents[i+1:]

            prompt_start = PROMPT_TEMPLATE[PROMPT_TO_USE][0]
            prompt_end = PROMPT_TEMPLATE[PROMPT_TO_USE][1]
            ents_string = ', '.join(remaining_ents)
            prompt = f"{prompt_start} {ents_string} {prompt_end}"

            max_tokens = len(prob_generator.tokenizer(prompt)['input_ids']) + 10

            # now a prob dictionary for y_star + remaining_ents
            star_probs_dict = compute_token_probs_api(
                prob_generator,
                prompt=prompt,
                target_string=y_star,
                remaining_ents=remaining_ents,
                max_tokens=max_tokens
            )

            if star_probs_dict['target_prob'] == -1:
                fail_counter += 1
                print(f"failed {fail_counter} times (y_star not found)")
                continue

            # prob dictionary for y_NON_star + remaining_ents
            non_star_probs_dict = compute_token_probs_api(
                prob_generator,
                prompt=prompt,
                target_string=y_NON_star,
                remaining_ents=remaining_ents,
                max_tokens=max_tokens
            )

            if non_star_probs_dict['target_prob'] == -1:
                fail_counter += 1
                print(f"failed {fail_counter} times (y_NON_star not found)")
                continue

            # save target prob, ent_probs, tokens_probs
            results[key_name]['y_stars'][y_star]['target_prob'] = star_probs_dict['target_prob']
            results[key_name]['y_stars'][y_star]['ents_prob']   = star_probs_dict['ents_prob']
            results[key_name]['y_stars'][y_star]['prompt']       = prompt

            results[key_name]['y_NON_stars'][y_NON_star]['target_prob'] = non_star_probs_dict['target_prob']
            results[key_name]['y_NON_stars'][y_NON_star]['ents_prob']   = non_star_probs_dict['ents_prob']
            results[key_name]['y_NON_stars'][y_NON_star]['prompt']       = prompt

            # short delay to avoid rate-limiting
            time.sleep(0.1)

0it [00:00, ?it/s]

TRAIN: 1/50...


1it [00:07,  7.68s/it]

TRAIN: 2/50...


2it [00:13,  6.86s/it]

TRAIN: 3/50...


3it [00:18,  6.00s/it]

TRAIN: 4/50...


4it [00:24,  5.77s/it]

TRAIN: 5/50...


5it [00:30,  5.84s/it]

TRAIN: 6/50...


6it [00:35,  5.52s/it]

TRAIN: 7/50...


7it [00:45,  7.08s/it]

TRAIN: 8/50...


8it [00:50,  6.29s/it]

TRAIN: 9/50...


9it [00:59,  7.25s/it]

TRAIN: 10/50...


10it [01:10,  8.27s/it]

TRAIN: 11/50...


11it [01:14,  7.07s/it]

TRAIN: 12/50...


12it [01:19,  6.43s/it]

TRAIN: 13/50...


13it [01:37, 10.10s/it]

TRAIN: 14/50...


14it [01:43,  8.62s/it]

TRAIN: 15/50...


15it [01:48,  7.59s/it]

TRAIN: 16/50...


16it [01:56,  7.73s/it]

TRAIN: 17/50...


17it [02:01,  7.04s/it]

TRAIN: 18/50...


18it [02:07,  6.53s/it]

TRAIN: 19/50...


19it [02:12,  6.15s/it]

TRAIN: 20/50...


20it [02:21,  7.02s/it]

TRAIN: 21/50...


21it [02:27,  6.62s/it]

TRAIN: 22/50...


22it [02:33,  6.44s/it]

TRAIN: 23/50...


23it [02:40,  6.67s/it]

TRAIN: 24/50...


24it [02:47,  6.77s/it]

TRAIN: 25/50...


25it [02:53,  6.59s/it]

TRAIN: 26/50...


26it [03:04,  8.05s/it]

TRAIN: 27/50...


27it [03:15,  8.88s/it]

TRAIN: 28/50...


28it [03:20,  7.64s/it]

TRAIN: 29/50...


29it [03:25,  6.85s/it]

TRAIN: 30/50...


30it [03:32,  6.76s/it]

TRAIN: 31/50...


31it [03:36,  6.10s/it]

TRAIN: 32/50...


32it [03:42,  6.03s/it]

TRAIN: 33/50...


33it [03:47,  5.59s/it]

TRAIN: 34/50...


34it [03:53,  5.72s/it]

TRAIN: 35/50...


35it [04:06,  7.97s/it]

TRAIN: 36/50...


36it [04:18,  9.15s/it]

TRAIN: 37/50...


37it [04:23,  7.89s/it]

TRAIN: 38/50...


38it [04:28,  7.18s/it]

TRAIN: 39/50...


39it [04:33,  6.38s/it]

TRAIN: 40/50...


40it [04:42,  7.27s/it]

TRAIN: 41/50...


41it [04:54,  8.54s/it]

TRAIN: 42/50...


42it [05:02,  8.58s/it]

TRAIN: 43/50...


43it [05:10,  8.40s/it]

TRAIN: 44/50...


44it [05:16,  7.53s/it]

TRAIN: 45/50...


45it [05:24,  7.85s/it]

TRAIN: 46/50...


46it [05:29,  7.03s/it]

TRAIN: 47/50...


47it [05:36,  6.91s/it]

TRAIN: 48/50...


48it [05:43,  7.04s/it]

TRAIN: 49/50...


49it [05:49,  6.53s/it]

TRAIN: 50/50...


50it [05:59,  7.19s/it]
0it [00:00, ?it/s]

TEST: 1/50...


1it [00:11, 11.97s/it]

TEST: 2/50...


2it [00:24, 12.14s/it]

TEST: 3/50...


3it [00:28,  8.76s/it]

TEST: 4/50...


4it [00:42, 10.47s/it]

TEST: 5/50...


5it [00:49,  9.24s/it]

TEST: 6/50...


6it [00:55,  8.33s/it]

TEST: 7/50...


7it [01:05,  8.93s/it]

TEST: 8/50...


8it [01:12,  8.24s/it]

TEST: 9/50...


9it [01:19,  7.92s/it]

TEST: 10/50...


10it [01:26,  7.52s/it]

TEST: 11/50...


11it [01:33,  7.35s/it]

TEST: 12/50...


12it [01:39,  6.97s/it]

TEST: 13/50...


13it [01:54,  9.27s/it]

TEST: 14/50...


14it [02:01,  8.65s/it]

TEST: 15/50...


15it [02:14, 10.07s/it]

TEST: 16/50...


16it [02:21,  8.97s/it]

TEST: 17/50...


17it [02:27,  8.33s/it]

TEST: 18/50...


18it [02:32,  7.15s/it]

TEST: 19/50...


19it [02:36,  6.33s/it]

TEST: 20/50...
had to swap an ent


20it [02:48,  7.86s/it]

TEST: 21/50...


21it [02:54,  7.28s/it]

TEST: 22/50...


22it [02:59,  6.86s/it]

TEST: 23/50...
had to swap an ent


23it [03:04,  6.26s/it]

TEST: 24/50...


24it [03:13,  6.87s/it]

TEST: 25/50...
had to swap an ent


25it [03:27,  9.02s/it]

TEST: 26/50...


26it [03:32,  7.86s/it]

TEST: 27/50...
had to swap an ent


27it [04:42, 26.58s/it]

TEST: 28/50...


28it [04:48, 20.39s/it]

TEST: 29/50...


29it [04:54, 15.99s/it]

TEST: 30/50...


30it [05:09, 15.90s/it]

TEST: 31/50...


31it [05:15, 12.92s/it]

TEST: 32/50...


32it [05:21, 10.63s/it]

TEST: 33/50...


33it [05:25,  8.84s/it]

TEST: 34/50...


34it [05:31,  7.78s/it]

TEST: 35/50...


35it [05:39,  7.88s/it]

TEST: 36/50...
had to swap an ent


36it [05:57, 11.03s/it]

TEST: 37/50...


37it [06:09, 11.35s/it]

TEST: 38/50...


38it [06:15,  9.61s/it]

TEST: 39/50...


39it [06:25,  9.76s/it]

TEST: 40/50...


40it [06:32,  8.90s/it]

TEST: 41/50...
had to swap an ent


41it [06:37,  7.82s/it]

TEST: 42/50...


42it [06:48,  8.78s/it]

TEST: 43/50...


43it [06:54,  7.99s/it]

TEST: 44/50...


44it [07:01,  7.67s/it]

TEST: 45/50...


45it [07:14,  9.23s/it]

TEST: 46/50...


46it [07:21,  8.52s/it]

TEST: 47/50...


47it [07:26,  7.42s/it]

TEST: 48/50...


48it [07:42, 10.09s/it]

TEST: 49/50...


49it [07:47,  8.38s/it]

TEST: 50/50...


50it [07:53,  9.47s/it]


In [8]:
with open(f'target_token_probs_{split}_{id}_10_epochs_prompt_{PROMPT_TO_USE}.json', 'w') as f:
    json.dump(results, f)