## Import Dependencies

In [1]:
%load_ext autoreload
%autoreload 2
import pandas as pd
import numpy as np
import transformers
import torch
import shap
import scipy as sp
from tqdm.notebook import tqdm
import scipy
import matplotlib.pyplot as plt
import sys
import os
sys.path.append('../')

## Experiment Setup

Model and Dataset

In [37]:
from utils.models import *
from utils.samples import *
from utils.functional import *
from utils.configs import *
from utils.processing import *
from functools import partial, wraps

# Select the model under study by leaving exactly one `model_name` assignment uncommented.
# model_name = 'openllama'
# model_name = 'Llama2'

# model_name = 'GPT-J'
# model_name = 'falcon7b'
# model_name = 'mpt7b_instruct'
model_name = 'Llama2_7b'

data_name = 'COS-E'
# exp_name is used as the sub-directory name under ../generated_nle/ by later cells.
exp_name = model_name+'_'+data_name

# Per-model configuration, looked up immediately by model_name.
# NOTE(review): there is no 'openllama' entry here, so selecting that model raises KeyError.
model_config = {
    'GPT-J': gpt_j_config,
    'falcon7b': falcon7b_config,
    'mpt7b_instruct': mpt7b_instruct_config,
    'Llama2_7b': llama2_7b_config,
    
}[model_name]
print(model_config)

# Zero-argument loaders; each is called later to obtain the (generator, tokenizer) pair.
get_model_function = {
    'GPT-J': partial(get_model_general, 'nlpcloud/instruct-gpt-j-fp16'),
    'falcon7b': partial(get_model_general, "tiiuae/falcon-7b"),
    'mpt7b_instruct': partial(get_model_general, "mosaicml/mpt-7b-instruct"),
    'openllama': get_openllama_auto,
    'Llama2_7b': partial(get_model_general, "meta-llama/Llama-2-7b-chat-hf"),
}

# QA_prompt_sample_dict = {
#     'GPT-J': gpt_j_Zeroshot_QA_samples_new,
#     'falcon7b': falcon7b_Zeroshot_QA_samples_new,
#     'mpt7b_instruct': mpt7b_instruct_Zeroshot_QA_samples_new,
# }

# prompt_creator_dic = {
#     'GPT-J': gpt_j_generate_zeroshot_prompt_QA,
#     'falcon7b' : falcon7b_generate_zeroshot_prompt_QA,
#     'mpt7b_instruct': mpt7b_instruct_generate_zeroshot_prompt_QA,
# } 

# Model-specific prompt building / output parsing helpers.
# NOTE(review): 'falcon7b' and 'openllama' have no entry here, so they raise KeyError below.
processing_classes = {
    'mpt7b_instruct': mpt7b_instruct,
    'GPT-J': gpt_j,
    'Llama2_7b': llama2_7b,
}

# Instantiate the processing helper for the selected model.
model_processing = processing_classes[model_name]()


{'nle_generation': {'total_trials': 4, 'length_increment': 512, 'max_length': 1024}, 'generation_configs': {'renormalize_logits': True, 'early_stopping': True, 'penalty_alpha': 0.3, 'top_k': 12, 'use_cache': True}}


In [3]:
from utils.data import get_CommonsenseQA
# Load the dataset frame and the (generator, tokenizer) pair for the selected model.
df = get_CommonsenseQA()
generator, tokenizer = get_model_function[model_name]()

# Output directory for this experiment. exist_ok=True replaces the separate
# existence check and keeps the cell safe to re-run.
os.makedirs(f'../generated_nle/{exp_name}/', exist_ok=True)

shap_gen_dict = dict(
    # configs set up here

    pad_token_id=tokenizer.eos_token_id, # to suppress open generation error
    **model_config['generation_configs']
)
# Number of samples; used when composing output file names in later cells.
sample_size = 1600

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Generation settings shared by answer generation and the patched `generate` used for SHAP.
# This redefinition replaces the `shap_gen_dict` built in the previous cell.
shap_gen_dict = dict(
    # configs set up here
    max_new_tokens=12, # upper bound on the number of newly generated tokens
    no_repeat_ngram_size=12, # kept equal to max_new_tokens above
    pad_token_id=tokenizer.eos_token_id, # to suppress open generation error
    **model_config['generation_configs']
)

## Get Explanations and SHAP score

### Define Generative functions

In [9]:
from utils.functional import *
from utils.samples import * 
import re

def generate_answers(batched_input_premise, batched_input_choices, batched_input_label, batched_label_idx):
    """Generate and parse the model's answer for every sample of a batch.

    For each sample a zero-shot QA prompt is built with
    ``model_processing.generate_zeroshot_prompt_QA``, completed by
    ``generator.generate`` using the settings in ``shap_gen_dict``, decoded, and
    parsed with ``model_processing.get_answer_from_output_text``.

    Args:
        batched_input_premise: sequence of questions, one per sample.
        batched_input_choices: sequence of strings, each holding the sample's
            choices separated by ``", "`` (surrounding double quotes are stripped).
        batched_input_label: per-sample label, forwarded to the prompt builder.
        batched_label_idx: per-sample label index, forwarded to the prompt builder.

    Returns:
        A list with one parsed answer per sample, in input order.
    """
    model_answers_list = []
    for i in range(len(batched_input_premise)):
        # Split the choice string into individual choices for answer parsing.
        input_choice_list = [phrase.strip("\"") for phrase in batched_input_choices[i].split(", ")]

        # The old `max(len(tokenizer(ch)) ...)` was dropped: it counted the keys of
        # the returned encoding rather than tokens, and its result was never used.
        # Generation length is controlled by `shap_gen_dict` alone.

        # create model specific prompt
        prompt = model_processing.generate_zeroshot_prompt_QA(
            batched_input_premise[i],
            batched_input_choices[i],
            batched_input_label[i],
            batched_label_idx[i]
        )
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        outputs = generator.generate(
            input_ids=input_ids,

            # output format related config
            return_dict_in_generate=True,
            output_scores=True,

            # add config specify to the model, to tell it how to search for the answer
            **shap_gen_dict
        )
        model_generated_text = tokenizer.decode(outputs.sequences[0])
        # get the last answer that is answered by the model
        model_answer = model_processing.get_answer_from_output_text(model_generated_text, input_choice_list, i)
        model_answers_list.append(model_answer)
    return model_answers_list

def generate_explanation(batched_input_premise, batched_input_choices, batched_input_label, batched_label_idx):
    """Generate a natural-language explanation for every sample of a batch.

    Each sample's question/answer/explanation prompt comes from
    ``model_processing.generate_zeroshot_prompt_QAE``; the completion produced by
    ``generator.generate`` is decoded and reduced to the explanation text by
    ``model_processing.get_explanation_from_output_text``.

    Args:
        batched_input_premise: sequence of questions, one per sample.
        batched_input_choices: per-sample choices, forwarded to the prompt builder.
        batched_input_label: per-sample label, forwarded to the prompt builder.
        batched_label_idx: per-sample label index, forwarded to the prompt builder.

    Returns:
        A list with one extracted explanation per sample, in input order.
    """
    explanations = []
    n_samples = len(batched_input_premise)
    for i in tqdm(range(n_samples), leave=False):
        qae_prompt = model_processing.generate_zeroshot_prompt_QAE(
            batched_input_premise[i],
            batched_input_choices[i],
            batched_input_label[i],
            batched_label_idx[i]
        )
        encoded = tokenizer(qae_prompt, return_tensors="pt")
        generation_kwargs = dict(
            # input format related config
            input_ids=encoded.input_ids.to(generator.device),
            max_new_tokens=128, # mpt7b_instruct
            no_repeat_ngram_size=5, # mpt7b_instruct
            pad_token_id=tokenizer.eos_token_id, # to suppress open generation error
            # output format related config
            return_dict_in_generate=True,
            output_scores=True,
            # generation related config
            **model_config['generation_configs'],
        )
        generated = generator.generate(**generation_kwargs)
        decoded_text = tokenizer.decode(generated.sequences[0])
        explanations.append(
            model_processing.get_explanation_from_output_text(decoded_text, i)
        )
    return explanations


### Get NLE

In [None]:
# Reload the dataset and print the question, the five choices and the label
# for the rows selected by the label slice `len(df)-10:`.
df = get_CommonsenseQA()
last_df = df.loc[len(df)-10:]
for idx, x in last_df.iterrows():
    print(x['question'])
    # choices are stored in the columns choice_0 .. choice_4
    print([x[f'choice_{i}'] for i in range(5)])
    print(x.label)
    print()
# print(df.loc[len(df)-10:].question.values)
# print(df.loc[len(df)-10:].values)
# print(df.loc[len(df)-10:].question.values)



Small experiment: generate an answer for a single sample

In [7]:
from utils.samples import *
from utils.functional import *
import re

# Run answer generation for one sample and print the prompt next to the raw completion.
# Relies on the batched_* lists prepared by the batch cells.
idx = 3

prompt = model_processing.generate_zeroshot_prompt_QA(
    batched_input_premise[idx], 
    batched_input_choices[idx], 
    batched_input_label[idx], 
    batched_label_idx[idx]
)
input_choice_list = [phrase.strip("\"") for phrase in batched_input_choices[idx].split(", ")]

# Longest choice measured in tokens (including any special tokens the tokenizer adds).
# `len(tokenizer(ch))` would count the keys of the encoding instead of its tokens.
max_choice = max(len(tokenizer(ch).input_ids) for ch in input_choice_list)
# max_new_tokens = max_choice + 6 # 2 for the two special tokens
# no_repeat_ngram_size = max_new_tokens + 2
# These two values are informational only: the generate call below takes its
# length settings from shap_gen_dict.
max_new_tokens = 12
no_repeat_ngram_size = max_new_tokens

input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
outputs = generator.generate(
    input_ids=input_ids,
    # max_new_tokens = max_new_tokens,
    # no_repeat_ngram_size = no_repeat_ngram_size,
    # output format related config
    return_dict_in_generate=True,
    output_scores=True,

    # add config specify to the model, to tell it how to search for the answer
    **shap_gen_dict
)
model_generated_text = tokenizer.decode(outputs.sequences[0])
# get the last answer that is answered by the model
model_answer = model_processing.get_answer_from_output_text(model_generated_text, input_choice_list, idx)
print(prompt)
print(model_generated_text)

Based on commonsense, pick the best choice:
Question: The fox walked from the city into the forest, what was it looking for?
Choose from: 'pretty flowers', 'hen house', 'natural habitat', 'storybook', 'dense forest'
Best Choice:
<s> Based on commonsense, pick the best choice:
Question: The fox walked from the city into the forest, what was it looking for?
Choose from: 'pretty flowers', 'hen house', 'natural habitat', 'storybook', 'dense forest'
Best Choice: 'dense forest'
Explanation: The fo


Generate Answers

In [None]:
# Show only the rows for which a model answer was recorded.
sample_df[sample_df['model_answer'].notna()]

Generate Explanation from Model Answers

In [41]:
# Generate and display the explanation for one sample of the current batch.
idx = 95
prompt = model_processing.generate_zeroshot_prompt_QAE(
    batched_input_premise[idx], 
    batched_input_choices[idx], 
    batched_input_label[idx], 
    batched_label_idx[idx]
)
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
outputs = generator.generate(
    # input format related config
    input_ids=input_ids,
    max_new_tokens=128,
    no_repeat_ngram_size=6,
    pad_token_id=tokenizer.eos_token_id, # to suppress open generation error
    # output format related config
    return_dict_in_generate=True,
    output_scores=True,
    # generation related config
    **model_config['generation_configs'],
)
# From here on `prompt` holds the decoded prompt + completion.
prompt = tokenizer.decode(outputs.sequences[0])
# final_explanation = re.findall('Explanation: (.*\n*.*)\n', prompt)[-1]
# Pass this cell's own sample index; the previous `i` was a leftover loop variable
# from another cell (or undefined on a fresh kernel).
final_explanation = model_processing.get_explanation_from_output_text(prompt, idx)
print(prompt)
final_explanation

<s> Based on commonsense, pick the best choice:
Question: A plant must do what to make another grow?\Choose from: 'plants', 'increasing in size', 'give up', 'die', 'gets bigger'.
Best Choice: 'give up'.
Explanation:  While plants do need to give up some of their energy to produce new growth, the statement "a plant must give up to make another grow" is not accurate. Plants do not have the ability to intentionally give up or sacrifice themselves for the growth of another plant. The correct answer is "die".</s>


'While plants do need to give up some of their energy to produce new growth, the statement "a plant must give up to make another grow" is not accurate. Plants do not have the ability to intentionally give up or sacrifice themselves for the growth of another plant. The correct answer is "die".'

In [42]:
# Load the saved model answers and generate an explanation for every answered sample,
# processing the frame in batches and writing a checkpoint CSV after each batch.
sample_df = pd.read_csv(f"../generated_nle/{exp_name}/1600_1600_model_answers.csv")
batch_size = 160
for i in tqdm(range(0, len(sample_df), batch_size)):
    s_range = range(i, min(i+batch_size, len(sample_df)))
    # .loc slices by label and includes the end label, hence the -1
    # (assumes the default integer index from read_csv — TODO confirm)
    batch_samples = sample_df.loc[i:i+batch_size-1]
    # rows without a recorded model answer are skipped
    has_answer = ~batch_samples['model_answer'].isna()
    batch_samples = batch_samples[has_answer]
    batched_input_premise = batch_samples.question.tolist()
    batched_input_choices = model_processing.create_choices(batch_samples, add_prefix=False)
    # the explanation prompt is built around the model's own answer, not the gold label
    batched_input_label = model_processing.make_choice(batch_samples, key = 'model_answer')
    batched_label_idx = [int(sample['model_answer']) for _,sample in batch_samples.iterrows()]
    list_nle = generate_explanation(
        batched_input_premise, 
        batched_input_choices, 
        batched_input_label,
        batched_label_idx
    )
    # row labels of the answered samples, so explanations land on the matching rows
    valid_range = np.array(s_range)[has_answer]
    sample_df.loc[valid_range, 'nle'] = list_nle
    # checkpoint: the file name records how many samples have been processed so far
    sample_df.to_csv(f'../generated_nle/{exp_name}/{sample_size}_{min(sample_size, i+batch_size)}_model_answer_nle.csv', index=False)
sample_df.head()

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/156 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

  0%|          | 0/158 [00:00<?, ?it/s]

  0%|          | 0/153 [00:00<?, ?it/s]

  0%|          | 0/154 [00:00<?, ?it/s]

Unnamed: 0,id,question,choice_0,choice_1,choice_2,choice_3,choice_4,label,human_expl_open-ended,nle,model_answer
0,075e483d21c29a511267ef62bedc0461,The sanctions against the school were a punish...,ignore,enforce,authoritarian,yell at,avoid,0,Not sure what else could be a common ground,"The sentence ""The sanctions against the School...",0.0
1,61fe6e879ff18686d7552425a36344c8,Sammy wanted to go to where the people were. ...,race track,populated areas,the desert,apartment,roadblock,1,People will be in populated areas.,"Based on the context, it is most likely that S...",1.0
2,02e821a3e53cb320790950aab4489e85,Google Maps and other highway and street GPS s...,united states,mexico,countryside,atlas,oceans,3,atlases were collections of highway and street...,Google Maps and other GPS services have largel...,0.0
3,23505889b94e880c3e89cff4ba119860,"The fox walked from the city into the forest, ...",pretty flowers.,hen house,natural habitat,storybook,dense forest,2,Usually the habitat of a fox is forest and it ...,The wording of the question suggests that the ...,4.0
4,e8a8b3a2061aa0e6d7c6b522e9612824,What home entertainment equipment requires cable?,radio shack,substation,cabinet,television,desk,3,television is the only option that is a home e...,Television is the home entertainment equipment...,3.0


In [47]:
# Reload the final explanation CSV and keep only rows that have both an explanation and an answer.
file_name = "model_answer_nle"
sample_df = pd.read_csv(f'../generated_nle/{exp_name}/1600_1600_{file_name}.csv')
# sample_size is the row count before dropna; it is used in the output file name below.
sample_size = len(sample_df)
sample_df = sample_df.dropna(subset=['nle', 'model_answer'])
def has_period(text):
    """Flag entries that contain a "." within their last 10 characters.

    Args:
        text: iterable of explanation strings (for example a pandas Series).

    Returns:
        A boolean ``np.ndarray`` with one flag per entry, in iteration order.
        Non-string entries (such as NaN floats from missing values) are flagged
        False instead of raising.
    """
    return np.array(
        [isinstance(x, str) and "." in x[-10:] for x in text],
        dtype=bool,  # keeps a boolean dtype even for empty input
    )
# A sample is invalid when its explanation contains 'a chess set' or has no period
# near its end. The negation is explicit because `|` binds tighter than `==`:
# the former `a | False==b` evaluated as `(a | False) == b`.
invalid_condition = sample_df['nle'].str.contains('a chess set') | np.logical_not(has_period(sample_df['nle']))
valid_samples = sample_df[~invalid_condition]
# sample_df = valid_samples
# sample_df = sample_df.reset_index(drop=True)
valid_samples.to_csv(f'../generated_nle/{exp_name}/valid_samples_{sample_size}_{file_name}.csv', index=False)
valid_samples

Unnamed: 0,id,question,choice_0,choice_1,choice_2,choice_3,choice_4,label,human_expl_open-ended,nle,model_answer
1,61fe6e879ff18686d7552425a36344c8,Sammy wanted to go to where the people were. ...,race track,populated areas,the desert,apartment,roadblock,1,People will be in populated areas.,"Based on the context, it is most likely that S...",1.0
2,02e821a3e53cb320790950aab4489e85,Google Maps and other highway and street GPS s...,united states,mexico,countryside,atlas,oceans,3,atlases were collections of highway and street...,Google Maps and other GPS services have largel...,0.0
3,23505889b94e880c3e89cff4ba119860,"The fox walked from the city into the forest, ...",pretty flowers.,hen house,natural habitat,storybook,dense forest,2,Usually the habitat of a fox is forest and it ...,The wording of the question suggests that the ...,4.0
4,e8a8b3a2061aa0e6d7c6b522e9612824,What home entertainment equipment requires cable?,radio shack,substation,cabinet,television,desk,3,television is the only option that is a home e...,Television is the home entertainment equipment...,3.0
5,3d0f8824ea83ddcc9ab03055658b89d3,"The forgotten leftovers had gotten quite old, ...",carpet,refrigerator,breadbox,fridge,coach,1,Becuase Leftovers are put in the fridge,The forgotten lefthelvers are likely to be sto...,1.0
...,...,...,...,...,...,...,...,...,...,...,...
1595,709912a9876e824ea97da15170fd1716,Where might you have to pay for a shopping bag...,restaurant,closet,at starbucks,supermarket,home,3,A supermarket is the only place one would need...,"In many places, supermarkets now charge for sh...",3.0
1596,95703127228d8b87b3fe998b1b2203e6,Where would the Air Force keep an airplane?,airport terminal,military base,sky,hanger,airplane hangar,1,is the best since it is the best backup,An airplane hangar is a building or structure ...,4.0
1597,f7216534fcff2743cc94c29cb689ea71,What is a computer terminal?,electrical device,battery,electronics,transportation system,initial,0,An computer is an electrical device connected ...,A computer terminal is an electronic device th...,2.0
1598,374e8dc01ca7749d668ed66ac4044e70,"She was known for be soft and sensitive, but w...",non sensitive,resistant,stoic,hardened,uncaring,2,Soft and senstive persons aren't expected to b...,The word 'hardened' best describes the action ...,3.0


### Generate SHAP values 

In [44]:
# Alias the generator as `model` for the SHAP cells; `tokenizer` is unchanged.
model, tokenizer = generator, tokenizer

In [45]:
# Keep references to the current callables so the wrappers below can delegate to them.
# NOTE(review): re-running this cell captures the already-patched callables, so the
# wrappers nest one level deeper on every run.
org_call_one = tokenizer._call_one
org_generate = generator.generate
org__call__ = generator.__call__
org_forward = generator.forward
model.config.is_decoder=True

from functools import wraps

# wrap is the only correct way to handle this argument
# "return_token_type_ids" exists in the **kwargs, so modification needs to be made directly to kwargs,
# otherwise, repeated argument will be passed in.
# Also, wrap preserves the signature/information of the function.
@wraps(org_call_one)
def _call_one_wrapped(*x, **y):
    # force return_token_type_ids off for every tokenizer call
    y['return_token_type_ids'] = False
    return org_call_one(*x, **y)

@wraps(org_generate)
def _generate_wrapped(*x, **y):
    # shap_gen_dict entries override any same-named keyword passed by the caller
    for k in shap_gen_dict:
        y[k] = shap_gen_dict[k]
    return org_generate(*x, **y)

@wraps(org__call__)
def __call__wrapped(*x, **y):
    # same override as for generate
    for k in shap_gen_dict:
        y[k] = shap_gen_dict[k]
    return org__call__(*x, **y)


@wraps(org_forward)
def forward_wrapped(*x, **y):
    # drop position_ids before delegating to the original forward
    if 'position_ids' in y:
        del y['position_ids']
    return org_forward(*x, **y)
    
# Install the wrappers on the tokenizer and model instances.
tokenizer._call_one = _call_one_wrapped
model.generate = _generate_wrapped
# NOTE(review): Python resolves `model(...)` through type(model).__call__, so this
# instance attribute presumably only affects explicit `model.__call__(...)` calls — confirm.
model.__call__ = __call__wrapped
model.forward  = forward_wrapped

In [63]:
tokenizer

LlamaTokenizerFast(name_or_path='meta-llama/Llama-2-7b-chat-hf', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'pad_token': '</s>'}, clean_up_tokenization_spaces=False)

In [64]:
model_name

'Llama2_7b'

In [70]:
from utils.functional import *
from utils.samples import*

# Compute SHAP-based scores for every valid sample, batch by batch, writing a
# checkpoint CSV after each batch.
df = pd.read_csv(f'../generated_nle/{exp_name}/valid_samples_{sample_size}_model_answer_nle.csv')

limit = len(df)
df = df.loc[:limit].copy()

import warnings
# warnings.resetwarnings()
warnings.filterwarnings("ignore")

shap_model = shap.models.TeacherForcing(model, tokenizer)
# collapse_mask_token is switched off for Llama2_7b only
masker = shap.maskers.Text(
    tokenizer,
    mask_token="...",
    collapse_mask_token=(model_name != 'Llama2_7b'),
)

top_k=5
batch_size = 160
# Set current_i to the first unprocessed row to resume from the matching checkpoint.
current_i = 0
if current_i != 0:
    df = pd.read_csv(f'../generated_nle/{exp_name}/{data_name}_{limit}_{current_i-1}_shap.csv')


for i in tqdm(range(current_i, limit, batch_size)):
    s_range = np.array(range(i, min(i+batch_size, limit)))
    batched_input_premise = df.question[s_range].tolist()
    # create model's prompt and make choice as the model's answer
    batched_input_choices = model_processing.create_choices(df.loc[s_range], add_prefix=False)
    batched_input_label = model_processing.make_choice(df.loc[s_range], key = 'model_answer') 
    batched_label_idx = [int(sample['label']) for _,sample in df.loc[s_range].iterrows()]

    max_shap_list = []
    ratio_shap_list = []
    question_shap_list = []
    context_shap_list = []

    for j in tqdm(range(len(batched_input_premise)), leave=False):
        explainer = shap.Explainer(shap_model, masker, silent=True)
        prompt = model_processing.generate_zeroshot_prompt_QA(
            batched_input_premise[j], 
            batched_input_choices[j], 
            batched_input_label[j], 
            batched_label_idx[j]
        )
        shap_values = explainer([prompt])
        shap_results = model_processing.get_context_shap(
            shap_values, 
            batched_input_label[j], 
        )
        if shap_results is None:
            # Add a placeholder to all four lists so each stays aligned with s_range;
            # otherwise the column assignments below fail on a length mismatch.
            max_shap_list.append(None)
            ratio_shap_list.append(None)
            context_shap_list.append(None)
            question_shap_list.append(None)
            continue

        # arrays are stored as their string representation in the CSV
        context_shap_list.append(str(shap_results['context_shap'])) 
        question_shap_list.append(str(shap_results['question_shap']))

        max_percent_shap = model_processing.get_max_percent_shap(
            shap_results['context_shap'], 
            top_k=top_k
        )
        counter_factual_ratio = model_processing.get_counter_factual_ratio(
            shap_results['question_shap'],
            shap_results['clean_choices_shap']
        )

        max_shap_list.append(max_percent_shap)
        ratio_shap_list.append(counter_factual_ratio)

    df.loc[s_range, 'max_shap_value'] = max_shap_list
    df.loc[s_range, 'ratio_shap_value'] = ratio_shap_list
    df.loc[s_range, 'context_shap'] = context_shap_list
    df.loc[s_range, 'question_shap'] = question_shap_list
    # checkpoint: the file name carries the last processed row label
    df.to_csv(f'../generated_nle/{exp_name}/{data_name}_{limit}_{s_range.max()}_shap.csv', index=False)
    # print(len(sample.question.split(" ")))
    # print(positive_shap)
    # break

# df.to_csv(f'../generated_nle/GPT-J_COS-E/COS-E_{limit}_shap.csv', index=False)
# candidate_labels

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

KeyboardInterrupt: 

### Visual check that the SHAP scores are generated properly

In [48]:
# Rebuild the prompt for one sample of the last processed batch and generate a
# completion for it, as input to the visual checks in the following cells.
j = 60
prompt = model_processing.generate_zeroshot_prompt_QA(
    batched_input_premise[j], 
    batched_input_choices[j], 
    batched_input_label[j], 
    batched_label_idx[j]
)
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
results = model.generate(
            input_ids=input_ids,
            # input format related config
            # max_new_tokens=max_new_tokens, # this value could be associated with COS
            # no_repeat_ngram_size=no_repeat_ngram_size, # I think this should associates with the above

            # output format related config
            return_dict_in_generate=True,
            output_scores=True,

            # add config specify to the model, to tell it how to search for the answer
            **shap_gen_dict
        )

In [75]:
model(input_ids).logits.shape

torch.Size([1, 68, 32000])

In [71]:
# Build an explainer with the "partition" algorithm and explain the current prompt.
explainer = shap.Explainer(shap_model, masker, silent=True, algorithm="partition", )

shap_values = explainer([prompt],)

In [62]:
print(tokenizer.decode(results.sequences[0]))

Answer the following based on commonsense:
Question: WHat do cats get into when they are ripping things apart?
Choose from: 'dog's mouth', 'floor', 'garage', 'trouble', 'nature'.
Best answer choice:  'garage'.
<|endoftext|>


In [67]:
shap_results = model_processing.get_context_shap(
    shap_values, 
    batched_input_label[j], 
)

>>>>>>>>>>>>>>model answer not found>>>>>>>>>>>>>>
['' 'Based' ' on' ' comm' 'ons' 'ense' ',' ' pick' ' the' ' best'
 ' choice' ':' '\n' 'Question' ':' ' Sam' 'my' ' wanted' ' to' ' go' ' to'
 ' where' ' the' ' people' ' were' '.' ' ' ' Where' ' might' ' he' ' go'
 '?' '\n' 'Cho' 'ose' ' from' ':' " '" 'race' ' track' "'," " '" 'pop'
 'ulated' ' areas' "'," " '" 'the' ' desert' "'," " '" 'ap' 'art' 'ment'
 "'," " '" 'road' 'block' "'" '\n' 'Best' ' Cho' 'ice' ':']


In [72]:
shap_values

.values =
array([[[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 4.31544734e-05, -2.96011597e-05, -3.63853445e-05,
         -2.69269868e-05,  3.94007904e-05],
        [ 4.31544734e-05, -2.96011597e-05, -3.63853445e-05,
         -2.69269868e-05,  3.94007904e-05],
        [-8.63089468e-05,  5.92023195e-05,  7.27706890e-05,
          5.38539736e-05, -7.88015808e-05],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000

In [73]:
shap.plots.text(shap_values)

In [65]:
''.join(shap_results['clean_choices'])

": 'dog's mouth', 'floor', '', 'trouble', 'nature'."

In [66]:
''.join(shap_results['target'])

'garage'

In [None]:
shap_results = model_processing.get_context_shap(
    shap_values, 
    batched_input_label[0], 
)

In [None]:
# NOTE(review): the path is hard-coded to the GPT-J experiment rather than built
# from exp_name — confirm this is intended.
df = pd.read_csv(f"../generated_nle/GPT-J_COS-E/{data_name}_1535_1534_shap.csv")
import matplotlib.pyplot as plt
# Histogram of the max_shap_value column.
plt.hist(df['max_shap_value'], bins=20)

In [None]:
output_names = shap_values.output_names
output_names

In [None]:
question_shap

In [None]:
shap_values.__dict__.keys()

In [None]:
shap_values._s.__dict__.keys()

In [None]:
output_list = list(shap_values._s._aliases['output_names'])
print(output_list, len(output_list))

In [None]:
def check_and_erase(source, target):
    """Mark which pieces of ``source`` spell out ``target``.

    Pieces are scanned in order. An empty piece is never marked. A piece found in
    the not-yet-consumed part of ``target`` is marked and every occurrence of it is
    removed from that remainder. Once at least one piece has been marked, the
    first non-empty piece that does not match ends the scan.

    Args:
        source: sequence of string pieces (e.g. tokens).
        target: string to look for.

    Returns:
        ``np.ndarray`` with one flag per piece of ``source``.
    """
    flags = [False] * len(source)
    remaining = target
    matched_any = False
    for pos, piece in enumerate(source):
        if piece == '':
            continue
        if piece in remaining:
            flags[pos] = True
            matched_any = True
            remaining = remaining.replace(piece, "")
        elif matched_any:
            # the run of matching pieces is over; everything after stays False
            break
    return np.array(flags)
check_and_erase(output_list, 'enforce')

In [None]:
output_list

In [None]:
shap_values.values[0].shape

In [None]:
shap.plots.text(shap_values)

## Training for the Probe

### Load Model

In [5]:
# bert_model = transformers.pipeline('sentiment-analysis', top_k=None)
# tokenizer = bert_model.tokenizer
# sample_input = tokenizer(['heavy metal'])
# Take the bare transformer stack out of the generator. Models that register a
# `transformer` submodule keep the previous behaviour; others, which used to raise
# KeyError here, fall back to the `base_model` property of Hugging Face models.
llm = generator._modules['transformer'] if 'transformer' in generator._modules else generator.base_model

In [None]:
# Encode a short probe string and move the ids to the generator's device.
input_ids = tokenizer('heavy metal', return_tensors="pt").input_ids.to(generator.device)
# not the last expression of the cell, so this line displays nothing
input_ids
# Forward pass through the extracted base model without tracking gradients.
with torch.no_grad():
    llm_outputs = llm(input_ids)
llm_outputs

In [9]:
llm_outputs.last_hidden_state.shape

torch.Size([1, 2, 4096])

### Get Data

In [None]:
def get_bracket_content(s):
    """Return the text enclosed by the first pair of parentheses in ``s``.

    Raises:
        IndexError: if ``s`` contains no parenthesised group.
    """
    first_group = re.search(r'\((.*?)\)', s)
    if first_group is None:
        raise IndexError('list index out of range')
    return first_group.group(1)

In [6]:
import ast
import re
# Load the SHAP results checkpoint that the probe will be trained on.
# sample_df = pd.read_csv("../generated_nle/GPT-J_COS-E/COS-E_505_shap.csv")
# file_prefix = 'COS-E_1535_1534'
# file_prefix = 'COS-E_1525_1524'
# file_prefix = 'COS-E_1533_1532'
# file_prefix names the checkpoint file under ../generated_nle/{exp_name}/
file_prefix = 'COS-E_1543_1542'

sample_df = pd.read_csv(f"../generated_nle/{exp_name}/{file_prefix}_shap.csv")
# supplements = ["../COS-E_range(0, 1797)_shap.csv", "../COS-E_range(1797, 3594)_shap.csv"]
# final_df_len = 0
# for fname in supplements:
#     # range_tuple = re.findall(r'\((.*?)\)', fname)[0].split(', ')
#     # range_tuple = [int(x) for x in range_tuple]
#     if sample_df is None:
#         sample_df = pd.read_csv(fname)
#     else:
#         sample_df = pd.concat([sample_df, pd.read_csv(fname)], ignore_index=True)  
sample_df

Unnamed: 0,id,question,choice_0,choice_1,choice_2,choice_3,choice_4,label,human_expl_open-ended,nle,model_answer,max_shap_value,ratio_shap_value,context_shap,question_shap
0,075e483d21c29a511267ef62bedc0461,The sanctions against the school were a punish...,ignore,enforce,authoritarian,yell at,avoid,0,Not sure what else could be a common ground,The school had made efforts to change and the ...,4.0,0.421578,0.0,[[ 0.02636847 -0.10477108 -0.14546561 -0.33800...,[[ 0.02636847 -0.10477108 -0.14546561 -0.33800...
1,61fe6e879ff18686d7552425a36344c8,Sammy wanted to go to where the people were. ...,race track,populated areas,the desert,apartment,roadblock,1,People will be in populated areas.,The desert is a large area with little to no p...,2.0,0.725241,0.0,[[ 3.01036214e-01 -2.08922671e-01 4.58802453e...,[[ 3.01036214e-01 -2.08922671e-01 4.58802453e...
2,02e821a3e53cb320790950aab4489e85,Google Maps and other highway and street GPS s...,united states,mexico,countryside,atlas,oceans,3,atlases were collections of highway and street...,'atlas' is the best answer choice because it i...,3.0,0.680713,0.0,[[ 0.74464532 -0.3819517 0.89716705 0.16696...,[[ 0.74464532 -0.3819517 0.89716705 0.16696...
3,23505889b94e880c3e89cff4ba119860,"The fox walked from the city into the forest, ...",pretty flowers.,hen house,natural habitat,storybook,dense forest,2,Usually the habitat of a fox is forest and it ...,The fox was likely looking for a place to rest...,2.0,0.520239,0.0,[[ 0.03973068 0.33928298 0.0051444 0.06542...,[[ 0.03973068 0.33928298 0.0051444 0.06542...
4,e8a8b3a2061aa0e6d7c6b522e9612824,What home entertainment equipment requires cable?,radio shack,substation,cabinet,television,desk,3,television is the only option that is a home e...,Television is commonly used for home entertain...,3.0,0.872487,0.0,[[0.82665339 0.7527402 ]\n [0.82665339 0.75274...,[[0.82665339 0.7527402 ]\n [0.82665339 0.75274...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1538,709912a9876e824ea97da15170fd1716,Where might you have to pay for a shopping bag...,restaurant,closet,at starbucks,supermarket,home,3,A supermarket is the only place one would need...,Shopping bags can usually be found at home sto...,4.0,0.659774,0.0,[[ 0.22399578]\n [-0.01616449]\n [-0.06091089]...,[[ 0.22399578]\n [-0.01616449]\n [-0.06091089]...
1539,95703127228d8b87b3fe998b1b2203e6,Where would the Air Force keep an airplane?,airport terminal,military base,sky,hanger,airplane hangar,1,is the best since it is the best backup,Airplane hangars are large structures used to ...,4.0,0.687855,0.0,[[ 0.36677777 -0.69506967 1.15907119 -0.14407...,[[ 0.36677777 -0.69506967 1.15907119 -0.14407...
1540,f7216534fcff2743cc94c29cb689ea71,What is a computer terminal?,electrical device,battery,electronics,transportation system,initial,0,An computer is an electrical device connected ...,Computers use electronic components to store a...,2.0,0.970546,0.0,[[ 0.04812292 0.09913067]\n [-0.05763736 0.1...,[[ 0.04812292 0.09913067]\n [-0.05763736 0.1...
1541,374e8dc01ca7749d668ed66ac4044e70,"She was known for be soft and sensitive, but w...",non sensitive,resistant,stoic,hardened,uncaring,2,Soft and senstive persons aren't expected to b...,The word 'hardened' describes a person who has...,3.0,0.351297,0.0,[[-0.08703304 0.10241594 0.34442515 0.24562...,[[-0.08703304 0.10241594 0.34442515 0.24562...


In [12]:
llm.config

GPTJConfig {
  "_name_or_path": "nlpcloud/instruct-gpt-j-fp16",
  "activation_function": "gelu_new",
  "architectures": [
    "GPTJForCausalLM"
  ],
  "attn_pdrop": 0.0,
  "bos_token_id": 50256,
  "do_sample": true,
  "embd_pdrop": 0.0,
  "eos_token_id": 50256,
  "gradient_checkpointing": false,
  "initializer_range": 0.02,
  "is_decoder": true,
  "layer_norm_epsilon": 1e-05,
  "max_length": 50,
  "model_type": "gptj",
  "n_embd": 4096,
  "n_head": 16,
  "n_inner": null,
  "n_layer": 28,
  "n_positions": 2048,
  "resid_pdrop": 0.0,
  "rotary_dim": 64,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {}
  },
  "tie_word_embeddings": false,
  "tokenizer_class": "GPT2Tokenizer",
  "torch_dtype": "float16",
  "transformers_version": "4.31.0",
  "use_cache": true,
  "vocab_size": 50400
}

In [7]:
## define a torch lstm model
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from functools import partial
from transformers import AutoTokenizer

# Device that the regressor is moved to and that collate_fn_base moves every
# tokenized batch and label tensor to; hard-coded to the first CUDA GPU.
device = torch.device('cuda:0')

## use a tokenizer from the bert model
# bert_model = transformers.pipeline('sentiment-analysis', top_k=None)

# class Distilbert_LSTM_regressor(nn.Module):
#     def __init__(self, input_size=768, hidden_size=256, num_layers=3, bidirectional=True):
#         super().__init__()
#         self.hidden_size = hidden_size
#         self.num_layers = num_layers
#         # self.bert_model = bert_model
#         self.embedding = bert_model.model.distilbert.eval()
#         self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)
#         out_size = hidden_size * 2 if bidirectional else hidden_size
#         self.fc = nn.Linear(out_size, 1)
#         self.double()
#     def forward(self, x):
#         with torch.no_grad():
#             bert_embeddings = self.embedding(**x)
#         self.lstm.flatten_parameters()
#         out, _ = self.lstm(bert_embeddings.last_hidden_state)
#         out = self.fc(out[:, -1, :])
#         return out


class LSTM_regressor(nn.Module):
    """LSTM regressor over token ids with its own trainable embedding table.

    Args:
        input_size: Embedding dimension, which is also the LSTM input width.
        hidden_size: LSTM hidden width per direction.
        num_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM runs in both directions; doubles the
            width fed to the final linear layer.
        vocab_size: Number of rows in the embedding table. Defaults to 50265,
            the value that was previously hard-coded; pass the tokenizer's
            vocabulary size when using a different tokenizer.
    """
    def __init__(self, input_size=768, hidden_size=256, num_layers=3, bidirectional=True, vocab_size=50265):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, input_size)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)
        out_size = hidden_size * 2 if bidirectional else hidden_size
        self.fc = nn.Linear(out_size, 1)
        # Cast every parameter to float64.
        self.double()

    def forward(self, x):
        """Map a (batch, seq) tensor of token ids to a (batch, 1) prediction."""
        embeddings = self.embedding(x)
        self.lstm.flatten_parameters()
        out, _ = self.lstm(embeddings)
        # Only the LSTM output at the last sequence position feeds the head.
        out = self.fc(out[:, -1, :])
        return out

class LLM_LSTM_regressor(nn.Module):
    """LSTM regression head that consumes an LLM's hidden states.

    Args:
        hidden_size: LSTM hidden width per direction.
        num_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM runs in both directions; doubles the
            width fed to the final linear layer.
        input_size: Width of the incoming hidden states. When None (default)
            it is read from the module-level ``llm.config``, trying the
            attribute names ``n_embd``, ``hidden_size`` and ``d_model`` in
            that order so the same class works across model families without
            editing the code.

    Raises:
        AttributeError: if ``input_size`` is None and ``llm.config`` exposes
            none of the known width attributes.
    """
    # Config attribute names that may hold the hidden width; ``n_embd`` stays
    # first so behaviour is unchanged for configs that define it.
    _WIDTH_ATTRS = ('n_embd', 'hidden_size', 'd_model')

    def __init__(self, hidden_size=256, num_layers=3, bidirectional=True, input_size=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        if input_size is None:
            input_size = self._infer_llm_width(llm.config)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)
        out_size = hidden_size * 2 if bidirectional else hidden_size
        self.fc = nn.Linear(out_size, 1)
        # self.double()

    @staticmethod
    def _infer_llm_width(config):
        """Return the first hidden-width attribute found on ``config``."""
        for attr in LLM_LSTM_regressor._WIDTH_ATTRS:
            width = getattr(config, attr, None)
            if width is not None:
                return width
        raise AttributeError(
            'could not infer the LLM hidden width: config has none of '
            + ', '.join(LLM_LSTM_regressor._WIDTH_ATTRS)
            + '; pass input_size explicitly'
        )

    def forward(self, x):
        """Regress from ``x.last_hidden_state`` to a (batch, 1) prediction.

        ``x`` must expose a ``last_hidden_state`` tensor laid out as
        (batch, seq, input_size), since the LSTM is built with batch_first=True.
        """
        self.lstm.flatten_parameters()
        out, _ = self.lstm(x.last_hidden_state)
        # NOTE(review): only the last sequence position is used; with padded
        # batches that position may be a pad token -- confirm this is intended.
        out = self.fc(out[:, -1, :])
        return out

        
class Sentiment_Dataset(torch.utils.data.Dataset):
    """Dataset yielding (explanation text, regression target) pairs from a frame.

    The text comes from the ``nle`` column and the target from the
    ``max_shap_value`` column. Items are looked up by index *label*, so the
    frame is expected to carry an index that matches the positions requested
    by the DataLoader.
    """

    def __init__(self, df, tokenizer):
        self.df = df
        # Kept for reference only; nothing in this class truncates with it.
        self.max_len = tokenizer.model_max_length

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        text_column = self.df["nle"]
        # Alternative target column used in other experiments: ratio_shap_value
        target_column = self.df["max_shap_value"]
        return text_column[idx], target_column[idx]

def collate_fn_base(data, tokenizer):
    """Turn a list of (text, target) samples into model-ready tensors.

    Args:
        data: Iterable of (text, target) pairs.
        tokenizer: Callable tokenizer; invoked on the tuple of texts with
            ``return_tensors="pt"`` and ``padding=True``.

    Returns:
        A pair ``(inputs, targets)``: ``inputs`` is a dict of the tokenizer's
        output tensors moved to the module-level ``device``; ``targets`` is a
        float32 tensor of shape (batch, 1) on the same device.
    """
    texts, targets = zip(*data)
    encoded = tokenizer(texts, return_tensors="pt", padding=True)
    inputs_on_device = {}
    for key, tensor in encoded.items():
        inputs_on_device[key] = tensor.to(device)
    target_tensor = torch.tensor(targets).to(device)
    return inputs_on_device, target_tensor.reshape(-1, 1).float()

# def collate_fn_base(data, tokenizer):
#     nle_input, label = zip(*data)
#     nle_input = tokenizer(nle_input, return_tensors="pt", padding=True)
#     nle_input = nle_input.input_ids.to(device)
#     return nle_input, torch.tensor(label).to(device).reshape(-1,1)

# Regression head that is trained on top of the LLM's hidden states.
model = LLM_LSTM_regressor().to(device)
# model = LSTM_regressor().to(device)

# `tokenizer` must already exist in the kernel (it is not created in this cell).
# tokenizer = bert_model.tokenizer
# tokenizer = AutoTokenizer.from_pretrained('nlpcloud/instruct-gpt-j-fp16')
# tokenizer = AutoTokenizer.from_pretrained('tiiuae/falcon-7b')

# Force return_token_type_ids=False on every tokenizer call.
# The patch is applied at most once per tokenizer: the wrapper looks up
# `org_call_one` as a global at call time, so rebinding `org_call_one` to an
# already-installed wrapper (which is what happens when this cell is simply
# re-run) would make the wrapper call itself and recurse without end.
if not getattr(tokenizer._call_one, '_token_type_ids_disabled', False):
    org_call_one = tokenizer._call_one

    @wraps(org_call_one)
    def _call_one_wrapped(*x, **y):
        y['return_token_type_ids'] = False
        return org_call_one(*x, **y)

    # Marker checked by the guard above.
    _call_one_wrapped._token_type_ids_disabled = True
    tokenizer._call_one = _call_one_wrapped

# Reuse the end-of-sequence token for padding so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token
collate_fn = partial(collate_fn_base, tokenizer=tokenizer)

# 70 % train / 15 % validation / remainder test.
train_size = int(0.7 * len(sample_df))
val_size = int(0.15 * len(sample_df))

# np.random.seed(42)
np.random.seed(66)
idxes = np.random.permutation(len(sample_df))

train_idxes, val_idxes, test_idxes = idxes[:train_size], idxes[train_size:train_size+val_size], idxes[train_size+val_size:]
# `idxes` holds row *positions*, so select with .iloc; .loc would look the
# numbers up as index labels and only agrees when sample_df has a 0..n-1 index.
# reset_index() gives each split the 0..len-1 index that Sentiment_Dataset
# indexes by (the previous index is kept as an extra column).
train_dataset = Sentiment_Dataset(sample_df.iloc[train_idxes].copy().reset_index(), tokenizer)
val_dataset = Sentiment_Dataset(sample_df.iloc[val_idxes].copy().reset_index(), tokenizer)
test_dataset = Sentiment_Dataset(sample_df.iloc[test_idxes].copy().reset_index(), tokenizer)

batch_size = 32
# num_workers=0: collate_fn moves tensors onto `device` itself, which is
# presumably why worker processes did not work here (open question from the
# original author -- TODO confirm).
train_dataloader = DataLoader(train_dataset,
    batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset,
    batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset,
    batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=collate_fn)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Loss scaler for the fp16 autocast training loop.
scaler = torch.cuda.amp.GradScaler()
# optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

print(len(train_dataset), len(val_dataset), len(test_dataset))


1080 231 232


In [8]:
# best_val_loss = np.inf
from torch import autocast

# Where the best checkpoint (by validation Pearson r) is written. Create the
# directory up front so a missing folder cannot fail the run after an epoch
# of compute has already been spent.
best_model_path = f'../generated_nle/{exp_name}/{file_prefix}_best_model.pth'
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

# GradScaler cannot unscale fp16 parameters, so make sure the regressor holds
# fp32 weights even if it was cast to half elsewhere in this kernel session.
model.float()

best_peasonr = -np.inf
# The LLM is only used as a frozen feature extractor.
llm.eval()
for epoch in (ep_disc:=tqdm(range(n_epochs:=32))):
    # ---- training pass ----
    model.train()
    y_true, y_pred, train_loss = [], [], []
    for (xs, ys) in tqdm(train_dataloader, leave=True):
        optimizer.zero_grad()
        # No gradients flow into the LLM; only its outputs are consumed.
        with torch.no_grad():
            xs = llm(xs['input_ids'])
        with autocast(device_type='cuda', dtype=torch.float16):
            out = model(xs)
            loss = criterion(out, ys)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss.append(loss.item())
        y_true.append(ys.detach().cpu().numpy().flatten())
        # Cast to float32 so the correlation is not computed on fp16 values.
        y_pred.append(out.detach().float().cpu().numpy().flatten())
    y_true, y_pred = np.concatenate(y_true), np.concatenate(y_pred)
    train_loss = np.mean(train_loss)
    train_PearsonR = scipy.stats.pearsonr(y_pred, y_true)[0]

    # ---- validation pass ----
    model.eval()
    y_true, y_pred, val_loss = [], [], []
    for (xs, ys) in val_dataloader:
        with torch.no_grad():
            xs = llm(xs['input_ids'])
            with autocast(device_type='cuda', dtype=torch.float16):
                out = model(xs)
                loss = criterion(out, ys)
        y_true.append(ys.cpu().numpy().flatten())
        y_pred.append(out.float().cpu().numpy().flatten())
        val_loss.append(loss.item())
    val_loss = np.mean(val_loss)
    y_true, y_pred = np.concatenate(y_true), np.concatenate(y_pred)
    PearsonR = scipy.stats.pearsonr(y_pred, y_true)[0]
    # First PearsonR in the message is the training value, second is validation.
    print(f'epoch: {epoch}, train_loss {train_loss:.3f}, PearsonR: {train_PearsonR:.3f}, val_loss: {val_loss:.3f}, PearsonR: {PearsonR:.3f}')
    # Keep the checkpoint with the highest validation correlation so far.
    if PearsonR > best_peasonr:
        best_peasonr = PearsonR
        torch.save(model.state_dict(), best_model_path)
        print('saved best model')

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 0, train_loss 0.061, PearsonR: 0.008, val_loss: 0.026, PearsonR: -0.004
saved best model


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 1, train_loss 0.027, PearsonR: 0.097, val_loss: 0.026, PearsonR: -0.052


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 2, train_loss 0.026, PearsonR: 0.217, val_loss: 0.027, PearsonR: 0.076
saved best model


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 3, train_loss 0.021, PearsonR: 0.465, val_loss: 0.029, PearsonR: 0.073


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 4, train_loss 0.013, PearsonR: 0.710, val_loss: 0.028, PearsonR: 0.148
saved best model


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 5, train_loss 0.010, PearsonR: 0.789, val_loss: 0.031, PearsonR: 0.118


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 6, train_loss 0.008, PearsonR: 0.847, val_loss: 0.029, PearsonR: 0.153
saved best model


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 7, train_loss 0.007, PearsonR: 0.859, val_loss: 0.034, PearsonR: 0.140


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 8, train_loss 0.004, PearsonR: 0.914, val_loss: 0.034, PearsonR: 0.070


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 9, train_loss 0.004, PearsonR: 0.931, val_loss: 0.032, PearsonR: 0.121


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 10, train_loss 0.003, PearsonR: 0.939, val_loss: 0.030, PearsonR: 0.127


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 11, train_loss 0.003, PearsonR: 0.950, val_loss: 0.033, PearsonR: 0.091


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 12, train_loss 0.003, PearsonR: 0.949, val_loss: 0.034, PearsonR: 0.115


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 13, train_loss 0.002, PearsonR: 0.953, val_loss: 0.032, PearsonR: 0.146


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 14, train_loss 0.002, PearsonR: 0.964, val_loss: 0.030, PearsonR: 0.153
saved best model


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 15, train_loss 0.002, PearsonR: 0.970, val_loss: 0.035, PearsonR: 0.114


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 16, train_loss 0.002, PearsonR: 0.971, val_loss: 0.034, PearsonR: 0.100


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 17, train_loss 0.001, PearsonR: 0.972, val_loss: 0.032, PearsonR: 0.090


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 18, train_loss 0.002, PearsonR: 0.971, val_loss: 0.034, PearsonR: 0.090


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 19, train_loss 0.002, PearsonR: 0.971, val_loss: 0.034, PearsonR: 0.102


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 20, train_loss 0.001, PearsonR: 0.975, val_loss: 0.033, PearsonR: 0.108


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 21, train_loss 0.001, PearsonR: 0.980, val_loss: 0.032, PearsonR: 0.091


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 22, train_loss 0.001, PearsonR: 0.978, val_loss: 0.034, PearsonR: 0.076


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 23, train_loss 0.001, PearsonR: 0.978, val_loss: 0.034, PearsonR: 0.068


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 24, train_loss 0.001, PearsonR: 0.980, val_loss: 0.034, PearsonR: 0.081


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 25, train_loss 0.001, PearsonR: 0.983, val_loss: 0.032, PearsonR: 0.085


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 26, train_loss 0.001, PearsonR: 0.979, val_loss: 0.035, PearsonR: 0.065


  0%|          | 0/34 [00:00<?, ?it/s]

epoch: 27, train_loss 0.001, PearsonR: 0.981, val_loss: 0.033, PearsonR: 0.072


  0%|          | 0/34 [00:00<?, ?it/s]

In [8]:
llm.eval()
datasets_dict = {
    'train': train_dataset,
    'val': val_dataset,
    'test': test_dataset
}

# Restore the checkpoint written by the training loop.
model.load_state_dict(torch.load(f'../generated_nle/{exp_name}/{file_prefix}_best_model.pth'))
model.eval()
# Keep the regressor in fp32 and run it under autocast, exactly as the
# training/validation loop does. Casting the shared `model` to half in place
# would leave it in fp16 for every later cell (including a re-run of training,
# where GradScaler cannot handle fp16 parameters).
model.float()

import ipywidgets as widgets
@widgets.interact(split=['train', 'val', 'test'])
def eval_for_datast(split):
    """Score the regressor on one split and plot its predictions.

    Prints Pearson r, its p-value and the MSE between predictions and targets,
    shows a predicted-vs-actual scatter plot, then relates the per-sample
    absolute error to the number of space-separated words in the `question`
    column. Finally registers a nested slider widget to inspect single samples.

    Args:
        split: One of 'train', 'val', 'test' (key into `datasets_dict`).
    """
    some_dataset = datasets_dict[split]
    some_dataloader = DataLoader(some_dataset,
        batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=collate_fn)
    y_pred, y_true = [], []
    for (xs, ys) in tqdm(some_dataloader):
        with torch.no_grad():
            xs = llm(xs['input_ids'])
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                out = model(xs)
        # float32 so the metrics below are not computed on fp16 values.
        y_pred.append(out.float().cpu().numpy())
        y_true.append(ys.cpu().numpy())
    y_pred = np.vstack(y_pred).flatten()
    y_true = np.vstack(y_true).flatten()
    # Per-sample absolute error (the name is kept from the original code).
    MAE = np.abs(y_pred - y_true)
    # NOTE(review): length is taken from `question`, while the model input is
    # the `nle` column -- confirm this is the intended comparison.
    sent_len = [len(x.split(" ")) for x in some_dataset.df.question]
    PearsonR = scipy.stats.pearsonr(y_pred, y_true)
    Loss = criterion(torch.tensor(y_pred), torch.tensor(y_true)).item()
    print(f'PearsonR: {PearsonR[0]:.3f}, p-value: {PearsonR[1]:.3f}, loss: {Loss:.3f}')
    plt.scatter(y_pred, y_true, alpha=0.5)
    plt.xlabel('predicted')
    plt.ylabel('actual')
    plt.show()
    PearsonR = scipy.stats.pearsonr(sent_len, MAE)
    print(f'PearsonR: {PearsonR[0]:.3f}')
    plt.scatter(sent_len, MAE, alpha=0.5)
    plt.xlabel('question length (words)')
    plt.ylabel('absolute error')
    plt.show()
    @widgets.interact(idx=(0, len(some_dataset)-1))
    def show_sample(idx):
        """Print the prediction, target and input text for sample `idx`."""
        x, y = some_dataset[idx]
        nle = tokenizer(x, return_tensors="pt", padding=True)
        nle = {k: v.to(device) for k, v in nle.items()}
        model.eval()
        with torch.no_grad():
            nle = llm(nle['input_ids'])
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                out = model(nle)
        # out, sample_df.max_shap_value[0]
        print(f'predicted: {out.item():.3f}, actual: {y:.3f}')
        print(f'input: {x}')


interactive(children=(Dropdown(description='split', options=('train', 'val', 'test'), value='train'), Output()…

In [None]:
# get max numbers