In [1]:
import bitsandbytes as bnb
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import random
import sys
# Checkpoint name/path handed to `from_pretrained`. It is also embedded in the
# per-dataset fine-tuned checkpoint names and in the output pickle file names.
model_id = "gemma-2b-it"
# Device used for model placement and for the input / label-id tensors.
device = "cuda:0"
# Prompting-mode flags. The evaluation loop tests them with an if/elif chain in
# the order fine_tuned -> instruction_tuned -> instruction_tuned_zero, so the
# first flag that is True wins.
# using_fine_tuned: reload `predictor_{model_id}_{ds}_merged` for every dataset
# and append the task instruction after the input text.
using_fine_tuned = False
# using_instruction_tuned: instruction prompt that contains one worked example.
using_instruction_tuned = False
# using_instruction_tuned_zero: instruction prompt without any example.
using_instruction_tuned_zero = True
# NOTE(review): `using_pretrained` is not read by any cell visible here.
using_pretrained = False
# quantized: load the model in 4-bit through `load_model` instead of fp16.
quantized = False
from captum.attr import (
    FeatureAblation, 
    ShapleyValues,
    LayerIntegratedGradients, 
    LLMAttribution, 
    LLMGradientAttribution, 
    TextTokenInput, 
    TextTemplateInput,
    ProductBaselines,
    ShapleyValueSampling,
    KernelShap
)
def create_bnb_config():
    """Return the bitsandbytes settings used for 4-bit loading.

    The configuration requests 4-bit weights with the "nf4" quantization
    type, double quantization enabled, and float16 as the compute dtype.
    """
    quantization_kwargs = {
        "load_in_4bit": True,
        "bnb_4bit_use_double_quant": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.float16,
    }
    return BitsAndBytesConfig(**quantization_kwargs)

def load_model(model_name, quantization_config=None):
    """Load a quantized causal LM together with its tokenizer.

    Args:
        model_name: Checkpoint name or path passed to `from_pretrained`.
        quantization_config: Optional `BitsAndBytesConfig`. When omitted, the
            notebook-level `bnb_config` is used if it has been defined;
            otherwise a fresh one is built with `create_bnb_config()`. This
            keeps existing `load_model(name)` calls working while removing the
            NameError that occurred when `bnb_config` had not been created yet.

    Returns:
        A `(model, tokenizer)` tuple.
    """
    if quantization_config is None:
        # Backward compatible: honour a `bnb_config` defined by an earlier cell.
        quantization_config = globals().get("bnb_config")
    if quantization_config is None:
        quantization_config = create_bnb_config()

    n_gpus = torch.cuda.device_count()
    max_memory = "24000MB"

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map=device,  # place the model on the configured device
        quantization_config=quantization_config,
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
        max_memory={i: max_memory for i in range(n_gpus)},
    )
    # `token=True` replaces the deprecated `use_auth_token=True` keyword.
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=True)

    # Some tokenizers (e.g. LLaMA) ship without a pad token; fall back to EOS
    # only in that case so a tokenizer's own pad token is not overwritten.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer

# Load the base model once. `max_memory` and `n_gpus` are deliberately left as
# notebook-level names in the fp16 path because the evaluation cell reuses them
# when it reloads fine-tuned checkpoints.
if not quantized:
    max_memory = "24000MB"
    n_gpus = torch.cuda.device_count()
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map=device,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
        max_memory=dict.fromkeys(range(n_gpus), max_memory),
    )
else:
    # 4-bit path: the helper returns both the model and the tokenizer.
    bnb_config = create_bnb_config()
    model, tokenizer = load_model(model_id)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
from lm_saliency import saliency, input_x_gradient, erasure_scores, l1_grad_norm, l2_grad_norm
from datasets import load_from_disk
def _n_tokens(text):
    """Return how many input ids the tokenizer produces for `text`."""
    return len(tokenizer(text)["input_ids"])


# Vocabulary ids of the single-word answers the classifier is scored on.
world_id, business_id, sport_id, tech_id = tokenizer.convert_tokens_to_ids(
    ["world", "business", "sport", "tech"]
)
positive_id, negative_id = tokenizer.convert_tokens_to_ids(["positive", "negative"])

# Position in each tensor == dataset label index (used as `tensor[label]` later).
wsbt = torch.tensor([world_id, sport_id, business_id, tech_id], dtype=torch.long, device=device)
np_ = torch.tensor([negative_id, positive_id], dtype=torch.long, device=device)

news = load_from_disk("news")
imdb = load_from_disk("imdb")
sst2 = load_from_disk("sst2")

# Short training examples available as in-context demonstrations.
news_shot = news["train"][10]["text"]
imdb_shots = imdb["train"].filter(lambda row: _n_tokens(row["text"]) < 32)
sst_shots = sst2["train"].filter(lambda row: _n_tokens(row["sentence"]) < 16)
news_shots = news["train"].filter(lambda row: _n_tokens(row["text"]) < 32)

# Evaluation subsets: drop long inputs, shuffle with a fixed seed, and keep the
# first `num_samples` rows. From here on `news` / `imdb` / `sst2` refer to these
# subsets rather than to the full datasets loaded above.
num_samples = 200
news = (
    news["test"]
    .select(range(4000))
    .filter(lambda row: _n_tokens(row["text"]) <= 48)
    .shuffle(seed=2024)
    .select(range(num_samples))
)
imdb = (
    imdb["test"]
    .filter(lambda row: _n_tokens(row["text"]) <= 96)
    .shuffle(seed=2024)
    .select(range(num_samples))
)
sst2 = (
    sst2["validation"]
    .filter(lambda row: _n_tokens(row["sentence"]) <= 32)
    .shuffle(seed=2024)
    .select(range(num_samples))
)

val_data = {"sst2": sst2, "imdb": imdb, "news": news}
wsbt

tensor([ 9097, 32970, 15679,  7286], device='cuda:0')

In [3]:
def captum_attrs(model,tokenizer):
    """Build the Captum attribution wrappers for `model`.

    Args:
        model: Causal LM to explain. Its input-embedding layer is the layer
            that Layer Integrated Gradients attributes to.
        tokenizer: Tokenizer handed to the Captum LLM wrappers.

    Returns:
        A `(lig_attr, sv_attr, kernel_attr)` tuple: an `LLMGradientAttribution`
        around `LayerIntegratedGradients`, and two `LLMAttribution` objects
        around `ShapleyValues` and `KernelShap` respectively.
    """
    # `get_input_embeddings()` is the architecture-independent accessor for the
    # token-embedding module, so this no longer requires the model to expose
    # it at `model.model.embed_tokens`.
    lig = LayerIntegratedGradients(model, model.get_input_embeddings())
    sv = ShapleyValues(model)
    ks = KernelShap(model)

    lig_attr = LLMGradientAttribution(lig, tokenizer)
    sv_attr = LLMAttribution(sv, tokenizer)
    kernel_attr = LLMAttribution(ks, tokenizer)
    return lig_attr, sv_attr, kernel_attr

In [4]:
import warnings
# Install a global "ignore" filter: from this cell on, every Python warning
# (deprecations included) is suppressed for the rest of the session.
warnings.filterwarnings('ignore')

In [5]:
from time import time
from collections import defaultdict
from tqdm import tqdm
import pickle
import numpy as np
import os

datasets = ["imdb","news","sst2"]

# Per-dataset lookup tables: which column holds the input text, and which
# tensor holds the candidate answer-token ids (indexed by the dataset label).
TEXT_FIELD = {"imdb": "text", "news": "text", "sst2": "sentence"}
LABEL_IDS = {"imdb": np_, "sst2": np_, "news": wsbt}


def build_eval_prompt(ds, comment):
    """Return the evaluation prompt for `comment` under the active prompting mode.

    The mode flags are tested in the order fine-tuned, instruction-tuned (with
    one worked example), instruction-tuned zero-shot; the first flag set wins.

    NOTE: all whitespace inside the triple-quoted literals below (including the
    leading spaces on continuation lines) is part of the prompt text. Do not
    re-indent those lines.

    Raises:
        ValueError: if no prompt template applies, i.e. none of the three mode
            flags is set or `ds` is not a known dataset.
    """
    if using_fine_tuned:
        if ds == "sst2":
            return comment + " In a sentiment classification task between positive and negative choices, the sentiment of this sentence is "
        if ds == "imdb":
            return comment + " Based on this opinion, decide what the sentiment is, choose between positive and negative. Answer is "
        if ds == "news":
            return comment + " You are classifying a news article, Choose one of the four categories, World, Business, Sport, and Tech. Answer is "
    elif using_instruction_tuned:
        if ds == "imdb":
            return f"""In the sentence in triple back ticks what is the sentiment? Answer in one word, positive or negative.
                    example 1: ```{imdb_shots[15]["text"]}```
                    The answer is positive
                    example 2: ```{comment}```
                    The answer is """
        if ds == "sst2":
            return f"""In the sentence in triple back ticks what is the sentiment? Answer in one word, positive or negative.
                    example 1: ```{sst_shots[18]["sentence"]}```
                    The answer is negative
                    example 2: ```{comment}```
                    The answer is """
        if ds == "news":
            return f"""In the sentence in triple back ticks, what is the news category?\n \
                  choices are world, sports, business, and tech. \nanswer in one word.\n world, or sports, or business, or tech.
                    example 1: ```{news_shot}```
                    The answer is business
                    example 2: ```{comment}```
                    The answer is"""
    elif using_instruction_tuned_zero:
        if ds in ("imdb", "sst2"):
            return f"""In the sentence in triple back ticks what is the sentiment? Answer in one word, positive or negative.
```{comment}```
The answer is """
        if ds == "news":
            return f"""In the sentence in triple back ticks, what is the news category?
                  choices are world, sports, business, and tech. answer in one word.
```{comment}```
The answer is """
    raise ValueError(
        f"No prompt template for dataset {ds!r}: set one of using_fine_tuned, "
        "using_instruction_tuned or using_instruction_tuned_zero."
    )


def contrastive_seq_attr(attributor, inp, target_c, target_f, **attribute_kwargs):
    """Return norm-scaled sequence attribution of `target_c` minus that of `target_f`.

    Each `seq_attr` is divided by its own norm before the subtraction; the
    result is returned as a NumPy array.
    """
    attr_res_c = attributor.attribute(inp, target=target_c, **attribute_kwargs)
    attr_res_f = attributor.attribute(inp, target=target_f, **attribute_kwargs)
    return (attr_res_c.seq_attr/torch.norm(attr_res_c.seq_attr) -
            attr_res_f.seq_attr/torch.norm(attr_res_f.seq_attr)).cpu().numpy()


def explain_example(model, eval_prompt, label_ids, lig_attr, kernel_attr, lig_steps, rng):
    """Run every attribution method on one prompt.

    Args:
        model: Causal LM being explained.
        eval_prompt: Full prompt text.
        label_ids: 1-D tensor of candidate answer-token ids for the dataset.
        lig_attr: Captum wrapper used for the "integrated_grad" scores.
        kernel_attr: Captum wrapper used for the "kernel_shap" scores.
        lig_steps: `n_steps` forwarded to the integrated-gradients call.
        rng: NumPy random generator used for the random baseline.

    Returns:
        Dict mapping result name -> value, in the order the results are stored.
    """
    model.zero_grad()
    torch.cuda.empty_cache()

    model_input = tokenizer(eval_prompt, return_tensors="pt").to(device)
    input_ids = model_input["input_ids"].squeeze(0)
    attention_mask = model_input["attention_mask"].squeeze(0)

    # Only the ranking of the candidate logits is needed here, so skip building
    # an autograd graph for this forward pass.
    with torch.no_grad():
        label_logits = model(**model_input).logits[0, -1][label_ids]
    # One top-2 call yields two distinct positions, so the predicted and foil
    # labels can never coincide, even when candidate logits are tied.
    top2 = torch.topk(label_logits, k=2).indices
    CORRECT_ID = label_ids[top2[0]]
    FOIL_ID = label_ids[top2[1]]

    record = {
        "predicted_id": CORRECT_ID.cpu().numpy(),
        "foil_id": FOIL_ID.cpu().numpy(),
    }

    # Gradient- and erasure-based scores from `lm_saliency`, computed for the
    # predicted label against the foil label.
    saliency_matrix, embd_matrix = saliency(model, input_ids, attention_mask,
                                            correct=CORRECT_ID, foil=FOIL_ID)
    model.zero_grad()
    record["gradnorm1"] = l1_grad_norm(saliency_matrix, normalize=False)
    record["gradnorm2"] = l2_grad_norm(saliency_matrix, normalize=False)
    record["gradinp"] = input_x_gradient(saliency_matrix, embd_matrix, normalize=False)
    record["erasure"] = erasure_scores(model, input_ids, attention_mask,
                                       correct=CORRECT_ID, foil=FOIL_ID, normalize=False)

    model.zero_grad()
    torch.cuda.empty_cache()

    # Captum attributions take the target as a token string.
    target_c = tokenizer.convert_ids_to_tokens(CORRECT_ID.cpu().tolist())
    target_f = tokenizer.convert_ids_to_tokens(FOIL_ID.cpu().tolist())
    inp = TextTokenInput(eval_prompt, tokenizer, skip_tokens=[0])

    record["integrated_grad"] = contrastive_seq_attr(lig_attr, inp, target_c, target_f, n_steps=lig_steps)
    record["kernel_shap"] = contrastive_seq_attr(kernel_attr, inp, target_c, target_f)

    # Random baseline: one standard-normal score per input token.
    record["random"] = rng.normal(0, 1, input_ids.shape[0])
    return record


# The pickles are written below `attrs/`; create it up front so a long run does
# not fail at the very end on a missing directory.
os.makedirs("attrs", exist_ok=True)
# Seeded generator so the random baseline is reproducible across runs.
baseline_rng = np.random.default_rng(2024)

for ds in datasets:
    if using_fine_tuned:
        # Each dataset has its own merged fine-tuned checkpoint.
        if quantized:
            model, _ = load_model(f"predictor_{model_id}_{ds}_merged")
        else:
            # Placed on `device`, like the base model and the input tensors.
            model = AutoModelForCausalLM.from_pretrained(f"predictor_{model_id}_{ds}_merged",device_map=device,
                    torch_dtype=torch.float16,trust_remote_code=True, attn_implementation="flash_attention_2",
                    max_memory = {i: max_memory for i in range(n_gpus)})

    explanations_dict = defaultdict(list)

    # `sv_attr` (Shapley values) is built by the helper but not used below.
    lig_attr, sv_attr, kernel_attr = captum_attrs(model,tokenizer)
    lig_steps = 5
    label_ids = LABEL_IDS[ds]
    text_field = TEXT_FIELD[ds]

    for example in tqdm(val_data[ds], desc=ds):
        comment = example[text_field]
        eval_prompt = build_eval_prompt(ds, comment)

        record = {
            "comment": comment,
            "prompt": eval_prompt,
            "gold_id": label_ids[example["label"]].cpu().numpy(),
        }
        record.update(
            explain_example(model, eval_prompt, label_ids, lig_attr, kernel_attr, lig_steps, baseline_rng)
        )
        for key, value in record.items():
            explanations_dict[key].append(value)

    # NOTE(review): the file name always ends in "_zero", whichever prompting
    # mode is active, so runs in different modes overwrite each other.
    with open(f"attrs/{ds}_{model_id}_attr_quantized_{quantized}_zero.pcl","bw") as f:
        pickle.dump(explanations_dict,f)
        

  0%|          | 0/200 [00:00<?, ?it/s]

100%|██████████| 200/200 [10:53<00:00,  3.27s/it]
100%|██████████| 200/200 [09:35<00:00,  2.88s/it]
100%|██████████| 200/200 [05:50<00:00,  1.75s/it]
