Import Dependencies

In [14]:
import pandas as pd
import numpy as np

## Load Dataset

In [15]:
df_cose = pd.read_csv("./data/cos-e/parsed-cose-train.csv")
df_cose

Unnamed: 0,id,question,choice_0,choice_1,choice_2,choice_3,choice_4,label,human_expl_open-ended
0,075e483d21c29a511267ef62bedc0461,The sanctions against the school were a punish...,ignore,enforce,authoritarian,yell at,avoid,0,Not sure what else could be a common ground
1,61fe6e879ff18686d7552425a36344c8,Sammy wanted to go to where the people were. ...,race track,populated areas,the desert,apartment,roadblock,1,People will be in populated areas.
2,02e821a3e53cb320790950aab4489e85,Google Maps and other highway and street GPS s...,united states,mexico,countryside,atlas,oceans,3,atlases were collections of highway and street...
3,23505889b94e880c3e89cff4ba119860,"The fox walked from the city into the forest, ...",pretty flowers.,hen house,natural habitat,storybook,dense forest,2,Usually the habitat of a fox is forest and it ...
4,e8a8b3a2061aa0e6d7c6b522e9612824,What home entertainment equipment requires cable?,radio shack,substation,cabinet,television,desk,3,television is the only option that is a home e...
...,...,...,...,...,...,...,...,...,...
7186,28ab300ef821e57e19be3f757842dd62,Where might I find seating while waiting to ta...,rest area,bus depot,theatre,bus stop,church,1,I FIND SEATING TO WAIT IN BUS DEPOT FOR PUBLIC...
7187,f1eb6055fa8a1ec94fa6d710d9a6741b,Something that you need to have inside of you ...,workers,money,determination,funding,creativity,2,opening a business is tough so you need to hav...
7188,f1b2a30a1facff543e055231c5f90dd0,What would someone need to do if he or she wan...,consequences,being ridiculed,more money,more funding,telling all,4,To go public is to be revealing information an...
7189,22d0eea15e10be56024fd00bb0e4f72f,Where would you buy jeans in a place with a la...,shopping mall,laundromat,hospital,clothing store,thrift store,0,When submitting your products to Merchant Cent...


## Get Explanations

### GPT explanation

In [16]:
# imports
import os
import ast  # for converting embeddings saved as strings back to arrays
import openai  
import pandas as pd
import tiktoken  # for counting tokens
from scipy import spatial  # for calculating vector similarities for search
openai.api_key = os.environ["OPENAI_API_KEY"]

# models
EMBEDDING_MODEL = "text-embedding-ada-002"
# GPT_MODEL = "gpt-3.5-turbo"
GPT_MODEL = "gpt-4-0613" # According to OpenAI, GPT-4 is more responsive to system messages, whereas 3.5 would rely on more of the user input.


In [17]:
# an example nle question
premise = "Which of the following should you not bring a fox in"
choices = ['hen house', 'England', 'elementary school']
query = f"""Question: {premise}:
Choices: {"  ".join([f"{i+1}.{choices[i]}" for i in range(len(choices))])}"""

sys_msg = """
Make a selection and explain your reasoning.
Reply in the following format: the first line is contains only the choice number,
and the second line is the explanation.
"""

response = openai.ChatCompletion.create(
    messages=[
        {'role': 'system', 'content': sys_msg},
        {'role': 'user', 'content': query},
    ],
    model=GPT_MODEL,
    temperature=0,
)

model_response = response['choices'][0]['message']['content']

In [18]:
y_pred, model_nle = model_response.split('\n')
y_pred = int(y_pred)
y_pred, model_nle 

(1,
 'Bringing a fox into a hen house would be disastrous, as foxes are natural predators of chickens and would likely cause harm or death to the hens.')

## Obtain NLE Embeddings

In [None]:
import torch
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large', verbose=0)
roberta.eval()  # disable dropout (or leave in train mode to finetune)
roberta.to('cuda')

In [27]:
to_tok = list(df_cose.columns)
for x in ['id', 'label']:
    to_tok.remove(x)
tok_df_cose = df_cose.copy()
tok_df_cose[to_tok] = tok_df_cose[to_tok].apply(lambda x: [roberta.encode(str(s)) for s in x])
tok_df_cose

Unnamed: 0,id,question,choice_0,choice_1,choice_2,choice_3,choice_4,label,human_expl_open-ended
0,075e483d21c29a511267ef62bedc0461,"[tensor(0), tensor(133), tensor(2637), tensor(...","[tensor(0), tensor(47072), tensor(2)]","[tensor(0), tensor(225), tensor(9091), tensor(2)]","[tensor(0), tensor(11515), tensor(20444), tens...","[tensor(0), tensor(219), tensor(1641), tensor(...","[tensor(0), tensor(40623), tensor(2)]",0,"[tensor(0), tensor(7199), tensor(686), tensor(..."
1,61fe6e879ff18686d7552425a36344c8,"[tensor(0), tensor(21169), tensor(4783), tenso...","[tensor(0), tensor(12326), tensor(1349), tenso...","[tensor(0), tensor(15076), tensor(12944), tens...","[tensor(0), tensor(627), tensor(10348), tensor...","[tensor(0), tensor(1115), tensor(27699), tenso...","[tensor(0), tensor(14288), tensor(16776), tens...",1,"[tensor(0), tensor(4763), tensor(40), tensor(2..."
2,02e821a3e53cb320790950aab4489e85,"[tensor(0), tensor(20441), tensor(21089), tens...","[tensor(0), tensor(33557), tensor(982), tensor...","[tensor(0), tensor(119), tensor(3463), tensor(...","[tensor(0), tensor(12659), tensor(3730), tenso...","[tensor(0), tensor(415), tensor(15086), tensor...","[tensor(0), tensor(139), tensor(26705), tensor...",3,"[tensor(0), tensor(35887), tensor(9354), tenso..."
3,23505889b94e880c3e89cff4ba119860,"[tensor(0), tensor(133), tensor(23602), tensor...","[tensor(0), tensor(28674), tensor(7716), tenso...","[tensor(0), tensor(2457), tensor(790), tensor(2)]","[tensor(0), tensor(25270), tensor(14294), tens...","[tensor(0), tensor(6462), tensor(6298), tensor...","[tensor(0), tensor(417), tensor(9401), tensor(...",2,"[tensor(0), tensor(35808), tensor(5), tensor(1..."
4,e8a8b3a2061aa0e6d7c6b522e9612824,"[tensor(0), tensor(2264), tensor(184), tensor(...","[tensor(0), tensor(35248), tensor(26623), tens...","[tensor(0), tensor(10936), tensor(30650), tens...","[tensor(0), tensor(438), tensor(17531), tensor...","[tensor(0), tensor(859), tensor(41605), tensor...","[tensor(0), tensor(10067), tensor(330), tensor...",3,"[tensor(0), tensor(859), tensor(41605), tensor..."
...,...,...,...,...,...,...,...,...,...
7186,28ab300ef821e57e19be3f757842dd62,"[tensor(0), tensor(13841), tensor(429), tensor...","[tensor(0), tensor(7110), tensor(443), tensor(2)]","[tensor(0), tensor(18924), tensor(30691), tens...","[tensor(0), tensor(627), tensor(19956), tensor...","[tensor(0), tensor(18924), tensor(912), tensor...","[tensor(0), tensor(23420), tensor(2)]",1,"[tensor(0), tensor(100), tensor(274), tensor(1..."
7187,f1eb6055fa8a1ec94fa6d710d9a6741b,"[tensor(0), tensor(27827), tensor(14), tensor(...","[tensor(0), tensor(16941), tensor(2)]","[tensor(0), tensor(17479), tensor(2)]","[tensor(0), tensor(32962), tensor(2)]","[tensor(0), tensor(29843), tensor(2)]","[tensor(0), tensor(25761), tensor(9866), tenso...",2,"[tensor(0), tensor(12211), tensor(10), tensor(..."
7188,f1b2a30a1facff543e055231c5f90dd0,"[tensor(0), tensor(2264), tensor(74), tensor(9...","[tensor(0), tensor(3865), tensor(33430), tenso...","[tensor(0), tensor(9442), tensor(35482), tenso...","[tensor(0), tensor(4321), tensor(418), tensor(2)]","[tensor(0), tensor(4321), tensor(1435), tensor...","[tensor(0), tensor(31647), tensor(70), tensor(2)]",4,"[tensor(0), tensor(3972), tensor(213), tensor(..."
7189,22d0eea15e10be56024fd00bb0e4f72f,"[tensor(0), tensor(13841), tensor(74), tensor(...","[tensor(0), tensor(1193), tensor(26307), tenso...","[tensor(0), tensor(462), tensor(26097), tensor...","[tensor(0), tensor(40179), tensor(2)]","[tensor(0), tensor(3998), tensor(45521), tenso...","[tensor(0), tensor(212), tensor(23203), tensor...",0,"[tensor(0), tensor(1779), tensor(19965), tenso..."


In [None]:

tok_model_nle = roberta.encode(model_nle)
tok_choices = [roberta.encode(ch) for ch in choices]

# embed the tokens using the last layer feature of the model
emb_model_nle = roberta.extract_features(tok_model_nle)
tok_choices = [roberta.extract_features(ch) for ch in tok_choices]


## Calculate SHAP Score

In [45]:
import transformers
import shap
import torch

device = torch.device('cuda:0')
# load the model
pmodel = transformers.pipeline('question-answering', device=device)

# define two predictions, one that outputs the logits for the range start,
# and the other for the range end
def f(questions, start):
    outs = []
    for q in questions:
        question, context = q.split("[SEP]")
        d = pmodel.tokenizer(question, context)
        out = pmodel.model.forward(**{k: torch.tensor(d[k]).reshape(1, -1).to(device) for k in d})
        logits = out.start_logits if start else out.end_logits
        outs.append(logits.reshape(-1).cpu().detach().numpy())
    return outs

def f_start(questions):
    return f(questions, True)
def f_end(questions):
    return f(questions, False)

# attach a dynamic output_names property to the models so we can plot the tokens at each output position
def out_names(inputs):
    question, context = inputs.split("[SEP]")
    d = pmodel.tokenizer(question, context)
    return [pmodel.tokenizer.decode([id]) for id in d["input_ids"]]
f_start.output_names = out_names
f_end.output_names = out_names

No model was supplied, defaulted to distilbert-base-cased-distilled-squad and revision 626af31 (https://huggingface.co/distilbert-base-cased-distilled-squad).
Using a pipeline without specifying a model name and revision in production is not recommended.


In [47]:
data = ["What is on the table?[SEP]When I got home today I saw my cat on the table, and my frog on the floor."]

# explainer_start = shap.Explainer(f_start, pmodel.tokenizer)
explainer_start = shap.Explainer(f_start, pmodel.tokenizer)

shap_values_start = explainer_start(data)

shap.plots.text(shap_values_start)

In [57]:
data = ["What is on the table?[SEP]When I got home today I saw my cat on the table, and my frog on the floor."]
def make_answer_scorer(answers):
    def f(questions):
        out = []
        for q in questions:
            question, context = q.split("[SEP]")
            results = pmodel(question, context, top_k=20)
            values = []
            for answer in answers:
                value = 0
                for result in results:
                    if result["answer"] == answer:
                        value = result["score"]
                        break
                values.append(value)
            out.append(values)
        return out
    f.output_names = answers
    return f

f_answers = make_answer_scorer(["my cat", "cat", "my frog"])

explainer_answers = shap.Explainer(f_answers, pmodel.tokenizer)
shap_values_answers = explainer_answers(data)

shap.plots.text(shap_values_answers)

In [4]:
shap_values_answers

.values =
array([[[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 5.69124413e-02, -5.52253337e-02,  8.65240455e-03],
        [-1.34780238e-02, -4.12751370e-02, -1.56310292e-02],
        [ 2.09620141e-03,  5.29225320e-02, -2.64582045e-02],
        [ 4.88220479e-02,  6.25952940e-02, -2.01617490e-02],
        [ 1.39899513e-01,  1.11438790e-01, -8.14635345e-02],
        [ 1.25926260e-02,  4.20484598e-02,  1.09544975e-02],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 3.52645899e-03, -3.15239110e-03,  8.66274782e-04],
        [ 3.52645899e-03, -3.15239110e-03,  8.66274782e-04],
        [ 3.20395688e-03, -1.45857349e-02,  4.32692908e-03],
        [ 3.20395688e-03, -9.66237406e-03,  4.04172768e-03],
        [ 3.94194270e-03,  2.23734954e-02, -5.09236679e-03],
        [ 3.94194270e-03,  2.25603571e-02, -3.52081549e-03],
        [-6.36603909e-04, -2.52694754e-02,  2.20085294e-03],
        [ 1.13435351e-01,  3.38443481e-02, -7.09942254e-03],
        [ 4.24

In [73]:
len(shap_values_answers.data[0])

29