In [1]:
import os
import json
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from together import Together
from transformers import AutoTokenizer

from utils import *
from huggingface_hub import login as hf_login
from peft import prepare_model_for_kbit_training
from datasets import concatenate_datasets, DatasetDict, load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoModel

  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'


In [2]:
class GenerateNextTokenProbAPI:
    """Score how likely a chat model finds a target string after a prompt.

    The prompt and the target are sent together as one user message to a
    client exposing ``chat.completions.create`` (the Together client in this
    notebook). The call uses ``echo=True`` and ``max_tokens=0``, so nothing is
    generated and the response carries per-token log-probabilities for the
    prompt. The log-probabilities of the tokens covering the target are summed
    and exponentiated.
    """

    def __init__(self, api_client, model_name,
                 tokenizer_name='meta-llama/Meta-Llama-3-8B-Instruct'):
        """
        Args:
            api_client: client exposing ``chat.completions.create``.
            model_name: model identifier passed to the completions endpoint.
            tokenizer_name: Hugging Face id of the tokenizer to load; defaults
                to the id that was previously hardcoded.
        """
        self.api_client = api_client
        self.model_name = model_name
        # NOTE(review): no method of this class uses the tokenizer; it is kept
        # for callers that may read ``.tokenizer``.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    @staticmethod
    def _strip_spaces(text):
        """Drop plain spaces and the byte-level BPE space marker 'Ġ'."""
        return text.replace(' ', '').replace('Ġ', '')

    def find_target_in_tokens(self, target_string, tokens, from_end=True):
        """Locate the token span that covers ``target_string``.

        Matching ignores spaces and 'Ġ' markers on both sides, so leading-space
        conventions of the tokenizer do not matter.

        Args:
            target_string: text to locate.
            tokens: sequence of token strings.
            from_end: if True (default), match the LAST occurrence. The target
                is always appended at the end of the user message, so an
                earlier occurrence (e.g. the same word inside one of the
                prompt's conditions) would be the wrong span. Pass False to
                match the first occurrence.

        Returns:
            ``(start, end)`` token indices (end exclusive), or None when the
            target is empty, absent, or cannot be mapped back onto tokens.
        """
        stripped_tokens = [self._strip_spaces(tok) for tok in tokens]
        target = self._strip_spaces(target_string)
        if not target:
            return None

        text = ''.join(stripped_tokens)
        char_start = text.rfind(target) if from_end else text.find(target)
        if char_start == -1:
            return None
        char_end = char_start + len(target)

        # Walk the tokens, tracking each token's character span in `text`.
        start_token_index = None
        offset = 0
        for i, tok in enumerate(stripped_tokens):
            tok_end = offset + len(tok)
            if offset <= char_start < tok_end:
                start_token_index = i
            if offset < char_end <= tok_end:
                if start_token_index is None:
                    return None
                return start_token_index, i + 1
            offset = tok_end
        return None

    def get_next_token_prob(self, input_string, target_string):
        """Return the model's probability of ``target_string`` following ``input_string``.

        Returns:
            ``exp(sum(logprobs))`` over the target's tokens, or -1 when the
            target cannot be located in the echoed prompt or one of its tokens
            has no log-probability (``None``).
        """
        messages = [
            {"role": "system", "content": "You are a medical expert."},
            {"role": "user", "content": input_string + " " + target_string}
        ]

        # max_tokens=0 with echo=True: generate nothing, only score the prompt.
        response = self.api_client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            max_tokens=0,
            temperature=0,
            logprobs=True,
            echo=True
        )

        prompt_logprobs = response.prompt[0].logprobs
        tokens = prompt_logprobs.tokens
        logprobs = prompt_logprobs.token_logprobs

        indices = self.find_target_in_tokens(target_string, tokens)
        if indices is None:
            return -1

        start_index, end_index = indices
        logprobs_slice = logprobs[start_index:end_index]
        # Some positions may carry no log-probability (presumably the first
        # prompt token — verify against the API); summing None would raise.
        if any(lp is None for lp in logprobs_slice):
            return -1
        return np.exp(sum(logprobs_slice))

In [3]:
def compute_token_probs_api(y_star_string, prompt, prob_generator):
    """Score ``y_star_string`` as the continuation of ``prompt``.

    Thin adapter over ``prob_generator.get_next_token_prob``. Note the argument
    order is swapped: the generator takes ``(prompt, target)``. Returns
    whatever the generator returns.
    """
    return prob_generator.get_next_token_prob(prompt, y_star_string)

In [4]:
def load_shadow_models_for_llama_3_instruct(model_dict, api_keys_subsample_ids,
                                            subsample_dir="llama_subsample_ids"):
    """Build ``(api_key, model_name, subsample_ids)`` tuples, one per entry.

    Args:
        model_dict: mapping whose ``'llama_3_instruct'`` entry is the base name
            that ``-<subsample_id>`` is appended to.
        api_keys_subsample_ids: iterable of ``(api_key, subsample_id)`` pairs.
        subsample_dir: directory containing ``subsample_ids_<subsample_id>.csv``
            files, each with an ``ID`` column. Defaults to the previously
            hardcoded location.

    Returns:
        list of ``(api_key, model_name, subsample_ids)`` tuples in input order,
        where ``subsample_ids`` is the CSV's ``ID`` column as a Python list.
    """
    shadow_models_tuples = []
    for api_key, subsample_id in api_keys_subsample_ids:
        model_name = f"{model_dict['llama_3_instruct']}-{subsample_id}"
        csv_path = os.path.join(subsample_dir, f"subsample_ids_{subsample_id}.csv")
        subsample_ids = pd.read_csv(csv_path)['ID'].tolist()
        shadow_models_tuples.append((api_key, model_name, subsample_ids))
    return shadow_models_tuples

In [5]:
## login & load together ai client
# Credentials are read from the environment, never hardcoded in the notebook.
# NOTE: the Together key and Hugging Face token previously committed in this
# cell must be treated as leaked and revoked/rotated.
together_api_key = os.environ["TOGETHER_API_KEY"]
client = Together(api_key=together_api_key)
hf_login(token=os.environ["HF_TOKEN"])

## load dataset
dataset = load_dataset("beanham/medsum_privacy")
merged_dataset = concatenate_datasets([dataset['train'], dataset['validation'], dataset['test']])

# (model, subsample id) pairs: the first entry is used as the target model,
# the remaining entries as shadow models.
api_keys_subsample_ids = [
    ('meta-llama/Meta-Llama-3-8B-Instruct-Lite','1'),
    ('meta-llama/Meta-Llama-3-8B-Instruct-Lite','2')
]

shadow_models = load_shadow_models_for_llama_3_instruct(MODEL_DICT, api_keys_subsample_ids)
model_world = MODEL_DICT['llama_3_instruct']
target_model_class = 'llama_3_instruct'
# NOTE(review): the per-subsample model name (second tuple element, built as
# MODEL_DICT['llama_3_instruct'] + '-<id>') is discarded here. The first element
# is the same base-model string for the target and shadow entries. If the
# target should be the model trained on `target_subsample_ids`, the name to
# query is probably the discarded element. Confirm before trusting the scores.
target_model_api_key, _, target_subsample_ids = shadow_models[0]
shadow_models_tuples = shadow_models[1:]
shadow_model_api_keys = [api_key for api_key, _, _ in shadow_models_tuples]

# Records with fewer than MIN_ENTS entities only feed the pool of unseen
# (non-member) entities; records with at least MIN_ENTS are scored.
MIN_ENTS = 5

# A set makes each membership test O(1); the list made both filters O(n*m).
target_subsample_id_set = set(target_subsample_ids)
train_dataset = merged_dataset.filter(lambda example: example['ID'] in target_subsample_id_set)
test_dataset = merged_dataset.filter(lambda example: example['ID'] not in target_subsample_id_set)
unseen_ents = [
    ent
    for sample in test_dataset if len(sample['ents']) < MIN_ENTS
    for ent in sample['ents']
]
train_dataset = [sample for sample in train_dataset if len(sample['ents']) >= MIN_ENTS]
test_dataset = [sample for sample in test_dataset if len(sample['ents']) >= MIN_ENTS]
train_test_ents = {
    'train': train_dataset,
    'test': test_dataset
}

# Single-record subset for quick smoke runs.
fake_train_test_ents = {
    'train': train_dataset[0:1],
    'test': test_dataset[0:1]
}

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /Users/binhan/.cache/huggingface/token
Login successful


In [6]:
# Scorer bound to the target model; the scoring loop queries it through compute_token_probs_api.
prob_generator = GenerateNextTokenProbAPI(api_client=client, model_name=target_model_api_key)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
# For each record, hold out one entity at a time (y*), prompt with the
# remaining ones, and score y* against an entity drawn from the unseen pool.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)  # reproducible choice of non-member entities

results = {}
fail_counter = 0  # pairs skipped because a probability could not be computed
for name, samples in tqdm(train_test_ents.items(), position=0):

    for ent_list in samples:

        ## create saving dictionary
        sample_key = name + '_' + str(ent_list['ID'])
        results[sample_key] = {'y_stars': {}, 'y_NON_stars': {}}
        # Order-preserving de-duplication. list(set(...)) order depends on
        # string hash randomization, so runs were not reproducible.
        ents = list(dict.fromkeys(ent_list['ents']))
        k = len(ents)
        # NOTE(review): candidates are not filtered against `ents`, so a
        # "non-member" entity can coincide with one of this record's own
        # entities. Duplicates in the pool can also repeat within a record
        # (the later pair overwrites the earlier one). Confirm this is intended.
        unseen_ents_for_sample = random.sample(unseen_ents, k)

        ## go through each y_star
        for i in tqdm(range(k), position=1, leave=False):

            y_star = ents[i]
            y_NON_star = unseen_ents_for_sample[i]
            remaining_ents = ents[:i] + ents[i + 1:]

            prompt = f"consider someone with the following conditions: {', '.join(remaining_ents)}. the individual then also has the condition "
            prob = compute_token_probs_api(y_star, prompt, prob_generator)
            prob_NON = compute_token_probs_api(y_NON_star, prompt, prob_generator)
            if prob == -1 or prob_NON == -1:
                fail_counter += 1
                continue
            # Entries are created only when both scores exist, so every stored
            # entry has a 'target' value (no empty dicts left behind).
            results[sample_key]['y_stars'][y_star] = {'target': prob}
            results[sample_key]['y_NON_stars'][y_NON_star] = {'target': prob_NON}

fail_counter

  0%|                                                     | 0/2 [00:00<?, ?it/s]
  0%|                                                    | 0/11 [00:00<?, ?it/s][A
  9%|████                                        | 1/11 [00:16<02:49, 16.96s/it][A
 18%|████████                                    | 2/11 [00:36<02:45, 18.37s/it][A
 27%|████████████                                | 3/11 [00:53<02:21, 17.68s/it][A
 36%|████████████████                            | 4/11 [01:12<02:06, 18.13s/it][A
 45%|████████████████████                        | 5/11 [01:27<01:42, 17.09s/it][A
 55%|████████████████████████                    | 6/11 [01:44<01:26, 17.21s/it][A
 64%|████████████████████████████                | 7/11 [02:01<01:08, 17.09s/it][A
 73%|████████████████████████████████            | 8/11 [02:16<00:49, 16.43s/it][A
 82%|████████████████████████████████████        | 9/11 [02:32<00:32, 16.42s/it][A
 91%|███████████████████████████████████████    | 10/11 [02:49<00:16, 16.51s/it

  0%|                                                    | 0/11 [00:00<?, ?it/s][A
  9%|████                                        | 1/11 [00:13<02:17, 13.72s/it][A
 18%|████████                                    | 2/11 [00:29<02:15, 15.11s/it][A
 27%|████████████                                | 3/11 [00:44<02:00, 15.04s/it][A
 36%|████████████████                            | 4/11 [01:00<01:45, 15.12s/it][A
 45%|████████████████████                        | 5/11 [01:14<01:28, 14.76s/it][A
 55%|████████████████████████                    | 6/11 [01:27<01:11, 14.34s/it][A
 64%|████████████████████████████                | 7/11 [01:40<00:55, 13.94s/it][A
 73%|████████████████████████████████            | 8/11 [01:53<00:41, 13.71s/it][A
 82%|████████████████████████████████████        | 9/11 [02:06<00:26, 13.25s/it][A
 91%|███████████████████████████████████████    | 10/11 [02:18<00:12, 12.97s/it][A
100%|███████████████████████████████████████████| 11/11 [02:33<00:00, 13.97s

 83%|█████████████████████████████████████▌       | 5/6 [00:57<00:11, 11.34s/it][A
100%|█████████████████████████████████████████████| 6/6 [01:09<00:00, 11.52s/it][A

  0%|                                                    | 0/10 [00:00<?, ?it/s][A
 10%|████▍                                       | 1/10 [00:11<01:40, 11.19s/it][A
 20%|████████▊                                   | 2/10 [00:27<01:53, 14.23s/it][A
 30%|█████████████▏                              | 3/10 [00:44<01:48, 15.54s/it][A
 40%|█████████████████▌                          | 4/10 [00:57<01:27, 14.62s/it][A
 50%|██████████████████████                      | 5/10 [01:12<01:12, 14.52s/it][A
 60%|██████████████████████████▍                 | 6/10 [01:27<00:58, 14.66s/it][A
 70%|██████████████████████████████▊             | 7/10 [01:41<00:43, 14.60s/it][A
 80%|███████████████████████████████████▏        | 8/10 [01:56<00:29, 14.61s/it][A
 90%|███████████████████████████████████████▌    | 9/10 [02:09<00:14, 14.14

 67%|██████████████████████████████               | 4/6 [01:00<00:30, 15.28s/it][A
 83%|█████████████████████████████████████▌       | 5/6 [01:16<00:15, 15.35s/it][A
100%|█████████████████████████████████████████████| 6/6 [01:31<00:00, 15.31s/it][A

  0%|                                                     | 0/5 [00:00<?, ?it/s][A
 20%|█████████                                    | 1/5 [00:10<00:42, 10.55s/it][A
 40%|██████████████████                           | 2/5 [00:22<00:34, 11.45s/it][A
 60%|███████████████████████████                  | 3/5 [00:33<00:22, 11.04s/it][A
 80%|████████████████████████████████████         | 4/5 [00:45<00:11, 11.57s/it][A
100%|█████████████████████████████████████████████| 5/5 [00:57<00:00, 11.47s/it][A

  0%|                                                     | 0/6 [00:00<?, ?it/s][A
 17%|███████▌                                     | 1/6 [00:13<01:07, 13.51s/it][A
 33%|███████████████                              | 2/6 [00:28<00:56, 14.1

 43%|███████████████████▎                         | 3/7 [00:31<00:40, 10.12s/it][A
 57%|█████████████████████████▋                   | 4/7 [00:45<00:34, 11.62s/it][A
 71%|████████████████████████████████▏            | 5/7 [01:00<00:25, 12.82s/it][A
 86%|██████████████████████████████████████▌      | 6/7 [01:12<00:12, 12.45s/it][A
100%|█████████████████████████████████████████████| 7/7 [01:24<00:00, 12.04s/it][A

  0%|                                                     | 0/5 [00:00<?, ?it/s][A
 20%|█████████                                    | 1/5 [00:10<00:42, 10.60s/it][A
 40%|██████████████████                           | 2/5 [00:21<00:33, 11.02s/it][A
 60%|███████████████████████████                  | 3/5 [00:33<00:22, 11.44s/it][A
 80%|████████████████████████████████████         | 4/5 [00:42<00:10, 10.28s/it][A
100%|█████████████████████████████████████████████| 5/5 [00:52<00:00, 10.56s/it][A

  0%|                                                     | 0/7 [00:00<?, 

 29%|████████████▉                               | 5/17 [01:28<03:41, 18.44s/it][A
 35%|███████████████▌                            | 6/17 [01:43<03:09, 17.21s/it][A
 41%|██████████████████                          | 7/17 [02:02<02:57, 17.79s/it][A
 47%|████████████████████▋                       | 8/17 [02:20<02:40, 17.79s/it][A
 53%|███████████████████████▎                    | 9/17 [02:37<02:20, 17.57s/it][A
 59%|█████████████████████████▎                 | 10/17 [02:50<01:53, 16.26s/it][A
 65%|███████████████████████████▊               | 11/17 [03:09<01:42, 17.12s/it][A
 71%|██████████████████████████████▎            | 12/17 [03:31<01:33, 18.65s/it][A
 76%|████████████████████████████████▉          | 13/17 [03:44<01:06, 16.67s/it][A
 82%|███████████████████████████████████▍       | 14/17 [04:01<00:51, 17.02s/it][A
 88%|█████████████████████████████████████▉     | 15/17 [04:22<00:36, 18.07s/it][A
 94%|████████████████████████████████████████▍  | 16/17 [04:45<00:19, 19.47s

100%|█████████████████████████████████████████████| 6/6 [01:08<00:00, 11.37s/it][A

  0%|                                                    | 0/13 [00:00<?, ?it/s][A
  8%|███▍                                        | 1/13 [00:10<02:02, 10.24s/it][A
 15%|██████▊                                     | 2/13 [00:27<02:37, 14.36s/it][A
 23%|██████████▏                                 | 3/13 [00:35<01:56, 11.61s/it][A
 31%|█████████████▌                              | 4/13 [00:47<01:44, 11.65s/it][A
 38%|████████████████▉                           | 5/13 [01:06<01:53, 14.13s/it][A
 46%|████████████████████▎                       | 6/13 [01:18<01:35, 13.63s/it][A
 54%|███████████████████████▋                    | 7/13 [01:30<01:17, 12.98s/it][A
 62%|███████████████████████████                 | 8/13 [01:42<01:03, 12.73s/it][A
 69%|██████████████████████████████▍             | 9/13 [01:57<00:54, 13.50s/it][A
 77%|█████████████████████████████████          | 10/13 [02:12<00:41, 13.89

  0%|                                                     | 0/6 [00:00<?, ?it/s][A
 17%|███████▌                                     | 1/6 [00:11<00:56, 11.38s/it][A
 33%|███████████████                              | 2/6 [00:22<00:44, 11.19s/it][A
 50%|██████████████████████▌                      | 3/6 [00:33<00:33, 11.32s/it][A
 67%|██████████████████████████████               | 4/6 [00:47<00:24, 12.02s/it][A
 83%|█████████████████████████████████████▌       | 5/6 [00:58<00:11, 11.86s/it][A
100%|█████████████████████████████████████████████| 6/6 [01:08<00:00, 11.48s/it][A

  0%|                                                     | 0/5 [00:00<?, ?it/s][A
 20%|█████████                                    | 1/5 [00:11<00:47, 11.83s/it][A
 40%|██████████████████                           | 2/5 [00:25<00:38, 12.90s/it][A
 60%|███████████████████████████                  | 3/5 [00:37<00:25, 12.70s/it][A
 80%|████████████████████████████████████         | 4/5 [00:45<00:10, 10.86

 56%|█████████████████████████                    | 5/9 [00:54<00:45, 11.34s/it][A
 67%|██████████████████████████████               | 6/9 [01:03<00:32, 10.84s/it][A
 78%|███████████████████████████████████          | 7/9 [01:14<00:21, 10.87s/it][A
 89%|████████████████████████████████████████     | 8/9 [01:26<00:11, 11.11s/it][A
100%|█████████████████████████████████████████████| 9/9 [01:37<00:00, 10.82s/it][A

  0%|                                                     | 0/5 [00:00<?, ?it/s][A
 20%|█████████                                    | 1/5 [00:10<00:43, 10.84s/it][A
 40%|██████████████████                           | 2/5 [00:23<00:35, 11.70s/it][A
 60%|███████████████████████████                  | 3/5 [00:35<00:23, 11.96s/it][A
 80%|████████████████████████████████████         | 4/5 [00:47<00:11, 11.89s/it][A
100%|█████████████████████████████████████████████| 5/5 [00:57<00:00, 11.55s/it][A

  0%|                                                     | 0/7 [00:00<?, 

In [17]:
# Persist the target-model scores computed by the scoring loop.
output_path = 'target_token_probs.json'
with open(output_path, 'w') as out_file:
    json.dump(results, out_file)