In [1]:
import bitsandbytes as bnb
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import random
import sys

# Module-level flag; nothing else in this cell reads it.
# NOTE(review): a later cell rebinds this name to a list (`[False]`) and iterates over it,
# so the boolean assigned here is overwritten before it is used.
using_fine_tuned = False


def create_bnb_config():
    """Build the bitsandbytes settings used when loading a model in 8-bit.

    Returns:
        A `BitsAndBytesConfig` with `load_in_8bit=True` and
        `llm_int8_threshold=4.0`.
    """
    return BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=4.0)

def load_model(model_name,device):
    """Load a causal language model with the shared 8-bit quantization settings.

    The quantization settings come from the module-level `bnb_config`, which must
    be bound before this function is called.

    Args:
        model_name: identifier or path handed to `from_pretrained`.
        device: value forwarded as `device_map`.

    Returns:
        The model object returned by `AutoModelForCausalLM.from_pretrained`.
    """
    per_gpu_budget = "24000MB"

    # One identical memory cap per visible CUDA device (empty when none is visible).
    memory_caps = {}
    for gpu_index in range(torch.cuda.device_count()):
        memory_caps[gpu_index] = per_gpu_budget

    return AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map=device,
        quantization_config=bnb_config,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
        max_memory=memory_caps,
    )
# Shared quantization settings; `load_model` reads this module-level name when it is called.
bnb_config = create_bnb_config()



In [2]:
from transformers import AutoTokenizer,AutoModelForCausalLM
import torch
# Devices for the counterfactual generator ("cfg") model and for the predictor model.
cfg_device = "cuda:0"
pred_device = "cuda:0"
# Selects which generator to load; must be one of the ids handled below.
cfg_model_id = "gptj"

if cfg_model_id == "phi3":
    cfg_model = AutoModelForCausalLM.from_pretrained(
        cfg_model_id, device_map=cfg_device, torch_dtype=torch.bfloat16,
        trust_remote_code=True, attn_implementation="flash_attention_2")
    cfg_tokenizer = AutoTokenizer.from_pretrained(cfg_model_id, trust_remote_code=True)

elif cfg_model_id == "phi3-medium":
    # Loaded through `load_model`, i.e. with the 8-bit quantization config.
    cfg_model = load_model(cfg_model_id, cfg_device)
    cfg_tokenizer = AutoTokenizer.from_pretrained(cfg_model_id, trust_remote_code=True)

elif cfg_model_id == "pythia":
    cfg_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b")
    cfg_model = AutoModelForCausalLM.from_pretrained(
        "EleutherAI/pythia-1.4b",
        device_map=cfg_device,
        torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")

elif cfg_model_id == "pythia2":
    cfg_tokenizer = AutoTokenizer.from_pretrained("pythia-2.8b")
    cfg_model = AutoModelForCausalLM.from_pretrained(
        "pythia-2.8b",
        device_map=cfg_device,
        torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")

elif cfg_model_id == "gptj":
    cfg_model = AutoModelForCausalLM.from_pretrained(
        "EleutherAI/gpt-j-6b", device_map=cfg_device, torch_dtype=torch.bfloat16,
        trust_remote_code=True)
    cfg_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")

else:
    # Previously an unknown id fell through silently and surfaced later as a
    # NameError on `cfg_tokenizer`; fail here with an actionable message instead.
    raise ValueError(
        f"Unsupported cfg_model_id {cfg_model_id!r}; expected one of "
        "'phi3', 'phi3-medium', 'pythia', 'pythia2', 'gptj'."
    )

# Every generator except "phi3-medium" gets the two special tokens used in the
# generation prompts, and its embedding matrix is resized to the new vocabulary size.
if cfg_model_id != "phi3-medium":
    num_added_toks = cfg_tokenizer.add_tokens(["<mask>", "<counterfactual>"])
    cfg_model.resize_token_embeddings(len(cfg_tokenizer))
# pred_model = AutoModelForCausalLM.from_pretrained("gemma-2b-it",device_map=pred_device,
#                         torch_dtype=torch.float16,trust_remote_code=True, attn_implementation="flash_attention_2")


  return self.fget.__get__(instance, owner)()


In [3]:
import pickle
import pandas as pd
from collections import defaultdict
from tqdm import tqdm
# Attribution tables keyed by "<dataset>_<model>-<variant>".
attrs = {}
datasets = ["sst2","news","imdb"]
models = ["gemma-2b"]
# When True only the zero-shot attribution files are read.
zero = False


def _read_attr_frame(path):
    """Unpickle the object stored at `path` and wrap it in a DataFrame."""
    # NOTE: pickle executes arbitrary code on load; only read files produced by this project.
    with open(path,"br") as f:
        return pd.DataFrame(pickle.load(f))


for ds in datasets:
    if zero == True:
        attrs[f"{ds}_gemma-2b-it-zero"] = _read_attr_frame(
            f"attrs/{ds}_gemma-2b-it_attr_quantized_False_zero.pcl")
        continue

    for model in models:
        # Attributions computed for the instruction-tuned ("-it") model.
        attrs[f"{ds}_{model}-it"] = _read_attr_frame(
            f"attrs/{ds}_{model}-it_attr_quantized_False.pcl")
        # Attributions computed for the merged predictor ("-ft") model.
        attrs[f"{ds}_{model}-ft"] = _read_attr_frame(
            f"attrs/{ds}_predictor_{model}_{ds}_merged_attr_quantized_False.pcl")


In [4]:
# Display which "<dataset>_<model>-<variant>" attribution tables are present in `attrs`.
attrs.keys()

dict_keys(['sst2_gemma-2b-it', 'sst2_gemma-2b-ft', 'news_gemma-2b-it', 'news_gemma-2b-ft', 'imdb_gemma-2b-it', 'imdb_gemma-2b-ft'])

In [5]:
from datasets import load_from_disk
# The saved datasets are only read when `zero` is false; the non-zero evaluation
# prompts built in a later cell take their in-context examples from them.
if not zero:
    news, imdb, sst2 = (load_from_disk(name) for name in ("news", "imdb", "sst2"))
    # Text of training row 10, used as the worked example in the news prompt.
    news_shot = news["train"][10]["text"]


In [6]:
import torch.nn.functional as F
def log_probs_from_logits(logits, labels):
    """Return the log-probability assigned to each label id.

    Args:
        logits: floating-point tensor whose last dimension is the one the
            softmax is taken over, e.g. `(batch, seq, vocab)`.
        labels: integer tensor of ids with the shape of `logits` minus its
            last dimension.

    Returns:
        Tensor shaped like `labels` holding `log_softmax(logits)[..., label]`.
    """
    logp = F.log_softmax(logits, dim=-1)
    # Gather on the last axis (rather than a hard-coded axis 2) so the helper
    # also works for inputs that are not rank 3; for rank-3 input this is identical.
    logp_label = torch.gather(logp, -1, labels.unsqueeze(-1)).squeeze(-1)
    return logp_label
def sequence_logprob(model, labels, input_len=0):
    """Sum the next-token log-probabilities `model` assigns to `labels`.

    Args:
        model: callable returning an object with a `.logits` tensor when
            called on `labels`.
        labels: integer tensor of token ids, indexed as `(batch, position)`.
        input_len: number of leading per-token scores to leave out of the sum.

    Returns:
        A NumPy scalar: the sum over every batch row and every kept position.
    """
    with torch.no_grad():
        all_logits = model(labels).logits
        # The logits at position t score the token at position t + 1, so the
        # last logit and the first label are dropped to align the two.
        token_logps = log_probs_from_logits(all_logits[:, :-1, :], labels[:, 1:])
        # NOTE(review): token_logps[:, j] scores token j + 1, so slicing at
        # `input_len` also skips the score of the token at position `input_len`
        # -- confirm this is intended.
        total = token_logps[:, input_len:].sum()
    return total.float().cpu().numpy()

In [7]:
def get_completion(prompt,ds):
    """Generate a continuation of `prompt` with the generator model and decode it.

    Uses the module-level `cfg_tokenizer`, `cfg_model` and `cfg_device`.

    Args:
        prompt: text fed to the generator.
        ds: dataset name; "sst2" and "imdb" trim one trailing token from the
            output, any other value trims three.

    Returns:
        The decoded text of the generated ids after dropping the prompt, the
        first generated token, and the trailing token(s) described above.
        NOTE(review): the dropped tokens are presumably a separator and
        end-of-sequence markers -- confirm against the generator's training format.
    """
    # Tokenize once and reuse both tensors (the prompt used to be tokenized twice).
    encoded = cfg_tokenizer(prompt, return_tensors="pt")
    inputs = encoded.input_ids.to(cfg_device)
    attention_mask = encoded.attention_mask.to(cfg_device)

    sen_len = len(inputs[0])

    # Beam search with a budget of one token fewer than the prompt length.
    output_beam = cfg_model.generate(inputs, attention_mask=attention_mask,
                                     max_new_tokens=sen_len-1, num_beams=3)

    trailing = 1 if ds in ("sst2", "imdb") else 3
    return cfg_tokenizer.decode(output_beam[0][sen_len+1:-trailing])
    
def get_completion_instruct(sentence,ds,foil_sentiment):
    """Ask the generator, via an instruction prompt, to fill the masked words.

    Uses the module-level `cfg_tokenizer`, `cfg_model` and `cfg_device`.

    Args:
        sentence: text in which masked positions are written as "<mask>".
        ds: dataset name; "sst2" and "imdb" use the sentiment wording, any
            other value uses the article/category wording.
        foil_sentiment: target label interpolated into the instruction.

    Returns:
        `(prompt, response)`: the instruction prompt that was sent and the
        decoded text of the tokens generated after it.
    """
    # The instruction refers to "<unk>", so masks are rewritten before prompting.
    sen_new = sentence.replace("<mask>","<unk>")
    sen_len = len(cfg_tokenizer(sen_new).input_ids)

    if ds == "sst2" or ds == "imdb":
        prompt = f""" In the statement in backticks, replace any <unk> with a word in a way \
that the resulting statement would have a {foil_sentiment} sentiment.
output just the completed sentence.
```{sen_new}```
Answer:
"""
    else:
        prompt = f""" In the article in backticks, replace any <unk> with a word in a way \
that the resulting article would have a {foil_sentiment} category.
output just the completed sentence.
```{sen_new}```
Answer:
"""

    # Shared tail: both prompt variants are tokenized (once), generated and decoded alike.
    encoded = cfg_tokenizer(prompt, return_tensors="pt")
    inputs = encoded.input_ids.to(cfg_device)
    attention_mask = encoded.attention_mask.to(cfg_device)

    # Budget: one token more than the (unmasked-form) input sentence.
    output_beam = cfg_model.generate(inputs, attention_mask=attention_mask,
                                     max_new_tokens=sen_len+1, num_beams=3)
    # Keep only what was generated after the prompt.
    return prompt, cfg_tokenizer.decode(output_beam[0][inputs.shape[1]:])

In [8]:
import warnings
# Suppress every Python warning from here on, so the long run below prints only progress bars.
warnings.filterwarnings('ignore')
# NOTE: this is the transformers logging helper, not the standard-library `logging` module.
from transformers import logging
# Only show transformers log messages at ERROR level or above.
logging.set_verbosity(logging.ERROR)

In [9]:
import numpy as np
# Counterfactual "flipping" experiment.
#
# For every dataset / predictor / fine-tuned setting this cell:
#   1. loads the dataset-specific generator weights and the predictor model,
#   2. for each example and each attribution column, replaces a growing share of
#      the highest-attribution prompt tokens with "<mask>" and asks the generator
#      to produce a completion for the foil label,
#   3. feeds the completion to the predictor and records whether the foil label's
#      logit exceeds the originally predicted label's logit,
#   4. pickles the collected sentences, prompts, responses and flags under `camel/`.
#
# Relies on names bound by earlier cells: cfg_model, cfg_model_id, cfg_device,
# pred_device, attrs, zero, imdb, sst2, news_shot, get_completion and
# get_completion_instruct.
using_fine_tuned = [False]
percentages = [0.1,0.2,0.3,0.4,0.5]   # mask ratios, tried in this order
nums = [0,200]                         # [start, stop) row positions of eval_data to process
datasets = ["sst2"]
pred_models = ["gemma-2b"]
for ds in datasets:
    if cfg_model_id != "phi3-medium":
        # Swap in the generator weights saved for this dataset.
        # NOTE(security): torch.load unpickles the file; only load checkpoints you trust.
        cfg_ds_specific = f"counter_{cfg_model_id}_{ds}_finished.pt"
        cfg_model.load_state_dict(torch.load(cfg_ds_specific,map_location=cfg_device))

    torch.cuda.empty_cache()
    for pr_model in pred_models:

        for ft in using_fine_tuned:
            flipping_dict = defaultdict(list)

            # Pick the predictor checkpoint and the matching attribution table.
            if ft:
                pred_model_id = f"predictor_{pr_model}_{ds}_merged"
                eval_data = attrs[f'{ds}_{pr_model}-ft']
            elif zero:
                pred_model_id = f"{pr_model}-it"
                eval_data = attrs[f'{ds}_{pr_model}-it-zero']
            else:
                pred_model_id = f"{pr_model}-it"
                eval_data = attrs[f'{ds}_{pr_model}-it']

            pred_model = AutoModelForCausalLM.from_pretrained(pred_model_id, device_map=pred_device,
                        torch_dtype=torch.float16,trust_remote_code=True, attn_implementation="flash_attention_2")
            pred_tokenizer = AutoTokenizer.from_pretrained(pred_model_id)
            if not zero:
                # Short training examples; specific rows are used as in-context examples below.
                imdb_shots = imdb["train"].filter(lambda x:len(pred_tokenizer(x["text"])["input_ids"]) < 32)
                sst_shots = sst2["train"].filter(lambda x:len(pred_tokenizer(x["sentence"])["input_ids"]) < 16)

            # Every column from position 5 onward is treated as one explanation method.
            # NOTE(review): assumes the first five columns are metadata -- confirm against the attr files.
            explanation_types = list(eval_data.columns[5:])

            for i in tqdm(range(nums[0],nums[1])):

                flipping_dict["sentence"].append(eval_data["comment"].iloc[i])

                # Values that depend only on the example are computed once per example
                # instead of once per explanation and mask ratio.
                n_comment_tokens = len(pred_tokenizer(eval_data.iloc[i].comment)["input_ids"])
                prompt_tokens = pred_tokenizer.convert_ids_to_tokens(pred_tokenizer(eval_data.prompt.iloc[i])["input_ids"])
                predicted_label = eval_data.predicted_id.iloc[i]
                foil_id = eval_data.foil_id.iloc[i].tolist()
                foil_sentiment = pred_tokenizer.convert_ids_to_tokens(foil_id)

                for exp in explanation_types:
                    flipped = False
                    # Reported when no ratio flips the prediction: the largest ratio tried.
                    mask_percent = percentages[-1]
                    attr = eval_data[exp].iloc[i]

                    for percent in percentages:
                        num_to_mask = int(np.ceil(percent * n_comment_tokens))

                        # Positions of the `num_to_mask` largest attribution scores.
                        idx_to_mask = np.argsort(attr)[-int(num_to_mask):]

                        # Work on a fresh copy so masks from a smaller ratio never carry over.
                        tokens = list(prompt_tokens)
                        for idss in idx_to_mask:
                            tokens[idss] = "<mask>"

                        # Cut the tokens down to a fixed window; the offsets differ per
                        # prompt layout (fine-tuned / zero-shot / few-shot) and dataset.
                        # NOTE(review): presumably these offsets strip the prompt template
                        # around the comment -- confirm against the attribution prompts.
                        if ft:
                            if ds =="sst2":
                                sentence = pred_tokenizer.convert_tokens_to_string(tokens[1:-18])
                            if ds =="news":
                                sentence = pred_tokenizer.convert_tokens_to_string(tokens[1:-26])
                            if ds =="imdb":
                                sentence = pred_tokenizer.convert_tokens_to_string(tokens[1:-20])
                        else:
                            if zero:
                                if ds =="sst2":
                                    sentence = pred_tokenizer.convert_tokens_to_string(tokens[24:-6])
                                if ds =="news":
                                    sentence = pred_tokenizer.convert_tokens_to_string(tokens[35:-6])
                                if ds =="imdb":
                                    sentence = pred_tokenizer.convert_tokens_to_string(tokens[24:-6])
                            else:
                                if ds =="sst2":
                                    sentence = pred_tokenizer.convert_tokens_to_string(tokens[54:-7])
                                if ds =="news":
                                    sentence = pred_tokenizer.convert_tokens_to_string(tokens[118:-6])
                                if ds =="imdb":
                                    sentence = pred_tokenizer.convert_tokens_to_string(tokens[72:-7])

                        flipping_dict[f"{exp}_masked_sentence_{percent}_predictor_{pr_model}_ft_{ft}"].append(sentence)

                        # Ask the generator for a completion targeting the foil label.
                        if cfg_model_id == "phi3-medium":
                            prompt, response = get_completion_instruct(sentence,ds,foil_sentiment)
                        else:
                            prompt = f"{sentence}{foil_sentiment}<counterfactual> "
                            response = get_completion(prompt,ds)

                        flipping_dict[f"{exp}_cfg_{cfg_model_id}_{percent}_response_predictor_{pr_model}_ft_{ft}"].append(response)
                        flipping_dict[f"{exp}_cfg_{cfg_model_id}_{percent}_prompt_predictor_{pr_model}_ft_{ft}"].append(prompt)

                        # Build the classification prompt for the predictor. The text
                        # inside the literals (including its leading spaces) is part of
                        # the prompt and must stay exactly as is.
                        if ft:

                            if ds == "sst2":
                                eval_prompt = response + " In a sentiment classification task between positive and negative choices, the sentiment of this sentence is "
                            elif ds == "imdb":
                                eval_prompt = response + " Based on this opinion, decide what the sentiment is, choose between positive and negative. Answer is "
                            elif ds == "news":
                                eval_prompt = response + " You are classifying a news article, Choose one of the four categories, World, Business, Sport, and Tech. Answer is "

                        else:
                            if zero:
                                if ds=="imdb":
                                    eval_prompt = f"""In the sentence in triple back ticks what is the sentiment? Answer in one word, positive or negative.
```{response}```
The answer is """
                                elif ds == "sst2":
                                    eval_prompt = f"""In the sentence in triple back ticks what is the sentiment? Answer in one word, positive or negative.
```{response}```
The answer is """
                                elif ds == "news":
                                    eval_prompt = f"""In the sentence in triple back ticks, what is the news category?
                  choices are world, sports, business, and tech. answer in one word.
```{response}```
The answer is """
                            else:
                                if ds == "imdb":
                                    eval_prompt = f"""In the sentence in triple back ticks what is the sentiment? Answer in one word, positive or negative.
                                    example 1: ```{imdb_shots[15]["text"]}```
                                    The answer is positive
                                    example 2: ```{response}```
                                    The answer is """
                                elif ds == "sst2":
                                    eval_prompt = f"""In the sentence in triple back ticks what is the sentiment? Answer in one word, positive or negative.
                                    example 1: ```{sst_shots[18]["sentence"]}```
                                    The answer is negative
                                    example 2: ```{response}```
                                    The answer is """
                                elif ds == "news":
                                    eval_prompt = f"""In the sentence in triple back ticks, what is the news category?\n \
                                        choices are world, sports, business, and tech. \nanswer in one word.\n world, or sports, or business, or tech.
                                            example 1: ```{news_shot}```
                                            The answer is business
                                            example 2: ```{response}```
                                            The answer is"""

                        encoding = pred_tokenizer(eval_prompt, return_tensors="pt")
                        flipping_dict[f"{exp}_cfg_{cfg_model_id}_{percent}_evalprompt_predictor_{pr_model}_ft_{ft}"].append(eval_prompt)
                        ids = encoding.input_ids.to(pred_device)
                        mask = encoding.attention_mask.to(pred_device)
                        # Inference only: without no_grad every forward pass would keep an
                        # autograd graph alive and waste GPU memory.
                        with torch.no_grad():
                            logits  = pred_model(input_ids = ids, attention_mask = mask).logits[0,-1]
                        # Record the first (smallest) ratio at which the foil label wins.
                        if not flipped and logits[foil_id] > logits[predicted_label]:
                            flipped = True
                            mask_percent = percent

                    flipping_dict[f"{exp}_mask_percent_predictor_{pr_model}_ft_{ft}"].append(mask_percent)
                    # The "succuss" spelling is kept on purpose: it is the stored key name.
                    flipping_dict[f"{exp}_succuss_in_flipping_predictor_{pr_model}_ft_{ft}"].append(flipped)

            # Persist the results for this (dataset, predictor, ft) combination.
            if zero:
                with open(f"camel/cfg_{cfg_model_id}_{ds}_predictor_{pr_model}_ft_{ft}_{nums[0]}_{nums[1]}_zero.pcl","+bw") as f:
                    pickle.dump(flipping_dict,f)
            else:
                with open(f"camel/cfg_{cfg_model_id}_{ds}_predictor_{pr_model}_ft_{ft}_{nums[0]}_{nums[1]}.pcl","+bw") as f:
                    pickle.dump(flipping_dict,f)
 


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 200/200 [1:05:00<00:00, 19.50s/it]
