In [30]:
import numpy as np
import pandas as pd
from huggingface_hub import login as hf_login
from peft import prepare_model_for_kbit_training
from datasets import concatenate_datasets, DatasetDict, load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoModel
import os
import json

from utils import *

In [31]:
from together import Together
key = 'e94217f61953b12489a9877936bd5383086106ec9951d3f11bb6a9475d88e95e'
client = Together(api_key=key)

In [35]:
hf_login(token="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")

In [36]:
dataset = load_dataset("beanham/medsum_privacy")
merged_dataset = concatenate_datasets([dataset['train'], dataset['validation'], dataset['test']])

In [37]:
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B-Instruct')
target_tokens = tokenizer.tokenize('osteoporosis')
target_tokens

['oste', 'opor', 'osis']

In [None]:
import numpy as np
import pandas as pd
import random
from tqdm import tqdm
from together import Together
from transformers import AutoTokenizer

class GenerateNextTokenProbAPI:
    def __init__(self, api_client, model_name):
        self.api_client = api_client
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B-Instruct')

    def find_target_in_tokens(self, target_string, tokens):
        reconstructed_text = ''.join(tokens).replace(' ', '').replace('Ġ', '')
        target_string_no_spaces = target_string.replace(' ', '').replace('Ġ', '')

        # pos of target string in the text
        index_in_text = reconstructed_text.find(target_string_no_spaces)
        if index_in_text == -1:
            return None  # bad

        # map char indices back to token indices
        accumulated_length = 0
        start_token_index = None
        end_token_index = None
        for i, tok in enumerate(tokens):
            tok_no_space = tok.replace(' ', '').replace('Ġ', '')
            tok_length = len(tok_no_space)
            if accumulated_length <= index_in_text < accumulated_length + tok_length:
                start_token_index = i
            if accumulated_length < index_in_text + len(target_string_no_spaces) <= accumulated_length + tok_length:
                end_token_index = i + 1
                break
            accumulated_length += tok_length
        
        print('Reconstructed text: ', reconstructed_text)
        print()
        print('Target string: ', target_string_no_spaces)
        print()
        print('Start token index: ', start_token_index)
        print()
        print('End token index: ', end_token_index)
        print()

        if start_token_index is not None and end_token_index is not None:
            return start_token_index, end_token_index
        else:
            return None

    def get_next_token_prob(self, input_string, target_string):
        messages = [
            {"role": "system", "content": "You are a medical expert."},
            {"role": "user", "content": input_string + " " + target_string}
        ]

        response = self.api_client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            max_tokens=0,
            temperature=0,
            logprobs=True,
            echo=True
        )
        
        tokens = response.prompt[0].logprobs.tokens
        logprobs = response.prompt[0].logprobs.token_logprobs

        indices = self.find_target_in_tokens(target_string, tokens)
        if indices is None:
            print('Tokens: ', tokens)
            print()
            print('Target string: ', target_string)
            print()
            return -1

        start_index, end_index = indices
        logprobs_slice = logprobs[start_index:end_index]
        prob_target_string = np.exp(sum(logprobs_slice))

        return prob_target_string




In [None]:
{
    1 : 
    {
        'y_stars': {
            'y_star': {'prob_target':0., 'prob_world':0., 'probs_shadow':[0.,]}
        },
        'y_NON_stars': {
            'y_NON_star': {'prob_target':0., 'prob_world':0., 'probs_shadow':[0.,]}
        }
        ...
    }
}

In [None]:
def compute_token_probs_api(y_star_string, prompt, prob_generator_target, prob_generator_world, prob_generators_shadow):
    prob_target = prob_generator_target.get_next_token_prob(prompt, y_star_string)

    prob_world = prob_generator_world.get_next_token_prob(prompt, y_star_string)

    shadow_probs = [
            prob_generator_shadow.get_next_token_prob(prompt, y_star_string)
            for prob_generator_shadow in prob_generators_shadow
        ]
    
    avg_prob_shadow = sum(shadow_probs) / len(shadow_probs)
    
    if prob_target == -1 or prob_world == -1 or avg_prob_shadow == -1:
        return -1, -1, -1

    return prob_target, prob_world, avg_prob_shadow

def run_attack_api(train_test_ents, api_client, model_name_target, model_name_world, model_names_shadow, unseen_ents_test):
    prob_generator_target = GenerateNextTokenProbAPI(api_client, model_name_target)
    prob_generator_world = GenerateNextTokenProbAPI(api_client, model_name_world)
    prob_generators_shadow = [GenerateNextTokenProbAPI(api_client, model_name) for model_name in model_names_shadow]

    results_list = []

    fail_counter = 0
    for name, samples in tqdm(train_test_ents.items(), position=0):
        for ent_list in samples:
            # print(ent_list)
            ents = ent_list['ents']
            ents = set(ents)
            k = len(ents)
            unseen_ents_for_sample = random.sample(unseen_ents_test, k)

            for i in tqdm(range(k), position=1):
                y_star = ents[i]
                y_NON_star = unseen_ents_for_sample[i]
                print(y_star)

                remaining_ents = ents[:i] + ents[i + 1:]

                prompt = f"consider someone with the following conditions: {', '.join(remaining_ents)}. the individual then also has the condition "

                # y_star_decoded_tokens = [y_star] 
                # y_NON_star_decoded_tokens = [y_NON_star]

                # calc probabilities for TRUE y_star
                prob_target, prob_world, prob_p_out = compute_token_probs_api(
                    y_star, prompt, prob_generator_target, prob_generator_world, prob_generators_shadow
                )

                # calc probabilities for FALSE y_star
                prob_target_NON, prob_world_NON, prob_p_out_NON = compute_token_probs_api(
                    y_NON_star, prompt, prob_generator_target, prob_generator_world, prob_generators_shadow
                )

                if prob_target == -1 or prob_world == -1 or prob_p_out == -1 or prob_target_NON == -1 or prob_world_NON == -1 or prob_p_out_NON == -1:
                    print(f'Failed to get probabilities for some tokens: {y_star}, {y_NON_star}')
                    fail_counter += 1
                    continue

                # liklihood ratio and adjusted score for y_star true
                likelihood_ratio_y_star_true = prob_target / prob_p_out
                adjusted_score_true = prob_world - prob_p_out * (1 - likelihood_ratio_y_star_true)

                results_list.append({
                    'entity': y_star,
                    'trained_prob': prob_target,
                    'original_prob': prob_world,
                    'shadow_prob': prob_p_out,
                    'likelihood_ratio': likelihood_ratio_y_star_true,
                    'adjusted_score': adjusted_score_true,
                    'y_star_true': True,
                    'train_test': name
                })

                # likelihood ratio and adjusted score for y_star false
                likelihood_ratio_y_star_false = prob_target_NON / prob_p_out_NON
                adjusted_score_false = prob_world_NON - prob_p_out_NON * (1 - likelihood_ratio_y_star_false)

                results_list.append({
                    'entity': y_NON_star,
                    'trained_prob': prob_target_NON,
                    'original_prob': prob_world_NON,
                    'shadow_prob': prob_p_out_NON,
                    'likelihood_ratio': likelihood_ratio_y_star_false,
                    'adjusted_score': adjusted_score_false,
                    'y_star_true': False,
                    'train_test': name
                })

    results_df = pd.DataFrame(results_list)
    return results_df

In [149]:
MODEL_DICT = {
    'llama_3_1': 'meta-llama/Meta-Llama-3.1-8B-Reference',
    'llama_3_1_instruct': 'meta-llama/Meta-Llama-3.1-8B-Instruct-Reference',
    'llama_3': 'meta-llama/Meta-Llama-3-8B',
    'llama_3_instruct': 'meta-llama/Meta-Llama-3-8B-Instruct-Lite',
    'llama_2': 'togethercomputer/llama-2-7b',
    'llama_2_chat': 'togethercomputer/llama-2-7b-chat',
    'codellama': 'codellama/CodeLlama-7b-hf',
    'codellama_python': 'codellama/CodeLlama-7b-Python-hf',
    'codellama_instruct': 'codellama/CodeLlama-7b-Instruct-hf',
    'mixtral_8x7b': 'mistralai/Mixtral-8x7B-v0.1',
    'mixtral_8x7b_instruct': 'mistralai/Mixtral-8x7B-Instruct-v0.1',
    'nous_hermes_2_mixtral_8x7b_dpo': 'NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO',
    'nous_hermes_2_mixtral_8x7b_sft': 'NousResearch/Nous-Hermes-2-Mixtral-8x7B-SFT',
    'mistral_7b_instruct': 'mistralai/Mistral-7B-Instruct-v0.2',
    'mistral_7b': 'mistralai/Mistral-7B-v0.1',
    'qwen2': 'Qwen/Qwen2-1.5B',
    'qwen2_instruct': 'Qwen/Qwen2-1.5B-Instruct',
    'openhermes_2_5_mistral_7b': 'teknium/OpenHermes-2p5-Mistral-7B',
    'zephyr_7b': 'HuggingFaceH4/zephyr-7b-beta',
    'solar_instruct_v1': 'upstage/SOLAR-10.7B-Instruct-v1.0',
}

In [150]:
def load_shadow_models_for_llama_3_instruct(model_dict, api_keys_subsample_ids):
    shadow_models_tuples = []
    
    for api_key, id in api_keys_subsample_ids:
        model_name = f"{model_dict['llama_3_instruct']}-{id}"
        subsample_ids = pd.read_csv(f"llama_subsample_ids/subsample_ids_{id}.csv")
        subsample_ids = subsample_ids['ID'].tolist()
        shadow_models_tuples.append((api_key, model_name, subsample_ids))

    return shadow_models_tuples

# api_keys_subsample_ids = [
#     ('lr2872/Meta-Llama-3-8B-Instruct-86e49ae1-65a386b0','1'),
#     ('lr2872/Meta-Llama-3.1-8B-Instruct-Reference-0ae43e49-4a7add88','2')
# ]

api_keys_subsample_ids = [
    ('meta-llama/Meta-Llama-3-8B-Instruct-Lite','1'),
    ('meta-llama/Meta-Llama-3-8B-Instruct-Lite','2')
]

shadow_models = load_shadow_models_for_llama_3_instruct(MODEL_DICT, api_keys_subsample_ids)


In [151]:
target_model = shadow_models[0]

shadow_models_tuples = shadow_models[1:]

In [None]:
model_world = MODEL_DICT['llama_3_instruct']

target_model_class = 'llama_3_instruct'
target_model_api_key, _, target_subsample_ids = shadow_models[0]

shadow_models_tuples = shadow_models[1:]
shadow_model_api_keys = [api_key for api_key, _, _ in shadow_models_tuples]

train_dataset = merged_dataset.filter(lambda example: example['ID'] in target_subsample_ids)
test_dataset = merged_dataset.filter(lambda example: example['ID'] not in target_subsample_ids)

unseen_ents = [sample['ents'] for sample in test_dataset if len(sample['ents']) < 5]
unseen_ents = [item for sublist in unseen_ents for item in sublist] 

train_dataset = [sample for sample in train_dataset if len(sample['ents']) >= 5]
test_dataset = [sample for sample in test_dataset if len(sample['ents']) >= 5]

train_test_ents = {
    'train': train_dataset,
    'test': test_dataset
}

fake_train_test_ents = {
    'train': train_dataset[0:1],
    'test': test_dataset[0:1]
}

run_attack_api(train_test_ents, 
               client, 
               target_model_api_key, 
               model_world, 
               shadow_model_api_keys, 
               unseen_ents_test=unseen_ents)


  0%|          | 0/2 [00:00<?, ?it/s]

osteoporosis
Reconstructed text:  <|begin_of_text|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

Youareamedicalexpert.<|eot_id|><|start_header_id|>user<|end_header_id|>

considersomeonewiththefollowingconditions:allergicrhinitis,hypothyroidism,osteoarthritis,kidneystones,fever,nausea,,chestpain,hypertension,cough,chills.theindividualthenalsohastheconditionosteoporosis<|eot_id|><|start_header_id|>assistant<|end_header_id|>



Target string:  osteoporosis

Start token index:  64

End token index:  67

Reconstructed text:  <|begin_of_text|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

Youareamedicalexpert.<|eot_id|><|start_header_id|>user<|end_header_id|>

considersomeonewiththefollowingconditions:allergicrhinitis,hypothyroidism,osteoarthritis,kidneystones,fever,nausea,,chestpain,hypertension,cough,chills.theindividualthenalsohastheconditionosteoporosis<|eot_id|><|start_header_id|>assistant<|end_header_id|>



Target string:  osteoporosis

Start token index: 



Reconstructed text:  <|begin_of_text|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

Youareamedicalexpert.<|eot_id|><|start_header_id|>user<|end_header_id|>

considersomeonewiththefollowingconditions:allergicrhinitis,hypothyroidism,osteoarthritis,kidneystones,fever,nausea,,chestpain,hypertension,cough,chills.theindividualthenalsohastheconditionlisinopril<|eot_id|><|start_header_id|>assistant<|end_header_id|>



Target string:  lisinopril

Start token index:  64

End token index:  68

allergic rhinitis
Reconstructed text:  <|begin_of_text|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

Youareamedicalexpert.<|eot_id|><|start_header_id|>user<|end_header_id|>

considersomeonewiththefollowingconditions:osteoporosis,hypothyroidism,osteoarthritis,kidneystones,fever,nausea,,chestpain,hypertension,cough,chills.theindividualthenalsohastheconditionallergicrhinitis<|eot_id|><|start_header_id|>assistant<|end_header_id|>



Target string:  allergicrhinitis

Start token in

  9%|▉         | 1/11 [01:06<11:09, 66.93s/it]
  0%|          | 0/2 [01:06<?, ?it/s]


KeyboardInterrupt: 