### Use this notebook to reproduce the results in Tables 3, 4, and 5 (except Yelp-Colloquial, which uses the feature-based patcher).

In [1]:
%load_ext autoreload
%autoreload 2

#### Load in the model

In [None]:
from eval_utils import load_model

import os

# Checkpoint of the patchable T5-large model. Set the PATCH_MODEL_PATH environment
# variable to point at a local copy instead of editing this machine-specific path.
path_name = os.environ.get(
    'PATCH_MODEL_PATH',
    '/u/scr/smurty/LanguageExplanations/trained_models/t5-large-sst-override_exp-final',
)
model_obj = load_model(path_name, primary_mode='exp_applies_predictor')

Some weights of T5ForConditionalGenerationMultipleHeads were not initialized from the model checkpoint at t5-large and are newly initialized: ['encoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


primary mode: exp_applies_predictor
splicing parts from pretrained model


Some weights of T5ForConditionalGeneration were not initialized from the model checkpoint at t5-large and are newly initialized: ['encoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


#### Helper functions for applying a single patch and for applying multiple patches

In [None]:
from eval_utils import predict_stuff
import numpy as np
import itertools

def apply_patch_soft(exp_applies_probs, baseline_probs, label_clause):    
    x = np.array([label_clause]).repeat(len(baseline_probs), 0)
    #print(x.shape)
    
    applies_prob = exp_applies_probs[:, 1].reshape(-1, 1)
    #print(applies_prob)
    return applies_prob * x + (1 - applies_prob) * baseline_probs

def get_scores_multiple_patches_hard(data, cond_list, examine=False):
    """Apply several patches, routing each example to its most applicable patch.

    For every patch ``(cond, label_clause)`` in ``cond_list``, ``predict_stuff``
    (called without an explicit ``mode`` -- presumably the model's primary
    'exp_applies_predictor' head; TODO confirm) scores how likely ``cond``
    applies to each example. ``apply_patch_soft`` then mixes ``label_clause``
    into the unpatched task prediction. Each example keeps the patched
    prediction of the patch with the highest applies-score ("hard" selection).

    Args:
        data: sequence whose first element (``data[0]``) is the list of input
            examples; other elements are not used here.
        cond_list: non-empty list of ``(condition_text, label_clause)`` pairs.
        examine: unused; kept only for backward compatibility.

    Returns:
        ``(no_exp_probs, patched)``: the unpatched task-predictor probabilities,
        and a ``(D, 2)`` array of ``[p(label 0), p(label 1)]`` after patching.

    Raises:
        ValueError: if ``cond_list`` is empty.
    """
    if not cond_list:
        # Fail before the expensive baseline prediction rather than inside np.stack.
        raise ValueError("cond_list must contain at least one (condition, label_clause) patch")

    examples = data[0]
    no_exps = [('', ex) for ex in examples]
    no_exp_probs = predict_stuff(no_exps, [0] * len(no_exps), model_obj, 'p1', verbose=False, mode='task_predictor')

    cond_log_probs = []      # per patch: log p(c | x), shape (D,)
    patched_true_probs = []  # per patch: patched p(y = 1 | x, patch), shape (D,)
    for cond, label_clause in cond_list:
        print("Applying patch {}".format(cond))
        contextualized = [(cond, ex) for ex in examples]
        # Dummy labels, sized like the inputs (same convention as the baseline call above).
        output_probs = predict_stuff(contextualized, [0] * len(contextualized), model_obj, 'p1', verbose=False)
        cond_log_probs.append(np.log(output_probs[:, 1]))
        patched_probs = apply_patch_soft(output_probs, no_exp_probs, label_clause)
        patched_true_probs.append(patched_probs[:, 1])

    # Pick, per example, the patch whose condition is most likely to apply.
    cond_log_probs = np.stack(cond_log_probs, axis=1)          # D x P
    patched_true_probs = np.stack(patched_true_probs, axis=1)  # D x P
    best_patches = np.argmax(cond_log_probs, axis=1)           # D

    ptrue = patched_true_probs[np.arange(len(best_patches)), best_patches]
    pfalse = 1.0 - ptrue
    return no_exp_probs, np.stack([pfalse, ptrue]).T



#### Results for Yelp-Stars

In [None]:
from data_fns import get_yelp_stars
# Evaluation sets keyed by name; only Yelp-Stars is used in this section.
tests_yelp = dict(yelp_stars=get_yelp_stars())

In [None]:
# Yelp-Stars: two patches, both pushing the prediction toward label 0 ([1, 0]).
no_exp, ours = get_scores_multiple_patches_hard(
    tests_yelp['yelp_stars'],
    [
        ('review gives 1 or 2 stars', [1, 0]),
        ('review gives zero stars', [1, 0]),
    ],
)

In [None]:
data = tests_yelp['yelp_stars']
# Accuracy of the unpatched model, then of the model with the patches applied.
print(np.mean(np.argmax(no_exp, axis=1) == data[1]))
print(np.mean(np.argmax(ours, axis=1) == data[1]))




#### Results for WCR data

In [None]:
import pickle

# NOTE(review): pickle.load can execute arbitrary code; only load a trusted wcr.pickle.
# Presumably (examples, labels), given how dataset[0] / dataset[1] are used.
with open('wcr.pickle', 'rb') as reader:
    dataset = pickle.load(reader)

# Each patch: (condition text, label clause mixed in when the condition applies).
explanations = [
    ('review says fit is boxy', [1, 0]),
    ("review contains words or phrases like needs to be returned", [1, 0]),
    ("review contains words or phrases like needs to be exchanged", [1, 0]),
]

no_exp, ours = get_scores_multiple_patches_hard(dataset, explanations)


In [None]:
# WCR accuracy: unpatched model first, then with the patches applied.
print(np.mean(np.argmax(no_exp, axis=1) == dataset[1]))
print(np.mean(np.argmax(ours, axis=1) == dataset[1]))

#### Tables 3 and 4 (controlling the model with patches on Yelp)



In [None]:
from data_fns import get_yelp_data
# set conflicting to True for Table-4 and False for Table-3
d1 = get_yelp_data(conflicting=False)
# d1 is indexed as d1[0] (examples) and d1[1] (aspect labels) below, so len(d1)
# would only count its components; report the number of examples instead.
print(len(d1[0]))

In [None]:
def cond2label_dict(cond, orig_label):
    """Derive (aspect label, does-patch-apply) for one example and one condition.

    The condition's aspect is 'food' if its text mentions food, otherwise
    'service'; its target sentiment is positive iff the text contains 'good'.

    Args:
        cond: condition text such as 'food is good' or 'service is bad'.
        orig_label: list of dicts with 'category' and 'polarity' keys, where
            polarity is 'positive', 'negative' or 'NAN'. It must contain an
            entry for the condition's aspect (otherwise IndexError).

    Returns:
        ``(aspect_label, applies)``: aspect_label is 1 / 0 / -1 for
        positive / negative / NAN, and applies is 1 iff the aspect's sentiment
        matches the condition's (never for NAN).
    """
    polarity_ids = {'positive': 1, 'negative': 0, 'NAN': -1}
    wants_positive = 'good' in cond
    aspect = 'food' if 'food' in cond else 'service'

    matching = [entry for entry in orig_label if entry['category'] == aspect]
    aspect_label = polarity_ids[matching[0]['polarity']]

    # "good" conditions match label 1 and "bad" ones label 0; NAN (-1) matches neither.
    applies = int(aspect_label == int(wants_positive))
    return aspect_label, applies


# Each patch: (condition text, label clause mixed in when the condition applies).
conds = [
    ('food is good', [0, 1]),
    ('service is good', [0, 1]),
    ('food is bad', [1, 0]),
    ('service is bad', [1, 0]),
]
# Per condition: one (aspect_label, patch_applies) pair for every Yelp example in d1.
label_sets = {
    cond: [cond2label_dict(cond, example_labels) for example_labels in d1[1]]
    for cond, _ in conds
}

In [None]:
def get_steering_acc(data, labels, cond_labels, cond, cons, use_exps=True):
    """Count how well a single patch steers the model on one condition.

    Args:
        data: list of input examples.
        labels: array of gold aspect labels, one per example.
        cond_labels: 0/1 array marking the examples the patch applies to.
        cond: patch condition text.
        cons: label clause mixed in when the patch applies (see apply_patch_soft).
        use_exps: if False, score the unpatched model instead of the patched one.

    Returns:
        ``(n_correct_where_applies, n_unchanged_where_not_applies, n_applies,
        n_not_applies)``. With ``use_exps=False`` the second entry equals
        ``n_not_applies``, because the unpatched model cannot change its own
        prediction.
    """
    unpatched_inputs = [('', example) for example in data]
    unpatched_probs = predict_stuff(unpatched_inputs, [0] * len(unpatched_inputs), model_obj, 'p1', verbose=False, mode='task_predictor')
    unpatched_preds = unpatched_probs.argmax(axis=1)

    if not use_exps:
        correct_where_applies = np.sum((unpatched_preds == labels) & cond_labels)
        return correct_where_applies, np.sum(1 - cond_labels), np.sum(cond_labels), np.sum(1 - cond_labels)

    patched_inputs = [(cond, example) for example in data]
    applies_probs = predict_stuff(patched_inputs, cond_labels, model_obj, 'p1', verbose=False)
    patched_preds = apply_patch_soft(applies_probs, unpatched_probs, cons).argmax(axis=1)

    # Where the patch applies: does the patched prediction match the gold aspect label?
    correct_where_applies = np.sum((patched_preds == labels) & cond_labels)
    # Where it does not apply: does the patched prediction agree with the unpatched one?
    unchanged_elsewhere = np.sum((patched_preds == unpatched_preds) & (1 - cond_labels))
    return correct_where_applies, unchanged_elsewhere, np.sum(cond_labels), np.sum(1 - cond_labels)


def get_scores(conds, use_exps=True):
    """Pool steering metrics over several patch conditions.

    Reads the module-level ``label_sets`` and ``d1`` (the Yelp aspect data) and
    prints each condition as it is evaluated.

    Args:
        conds: list of ``(condition_text, label_clause)`` patches.
        use_exps: forwarded to get_steering_acc; False scores the unpatched model.

    Returns:
        ``(applies_acc, unchanged_rate)``: pooled accuracy on examples where a
        patch applies, and the pooled rate of unchanged predictions where it
        does not.
    """
    # Running totals: [correct, unchanged, n_applies, n_not_applies]
    totals = [0.0, 0.0, 0.0, 0.0]
    for cond, cons in conds:
        per_example = label_sets[cond]
        aspect_labels = np.array([aspect for aspect, _ in per_example])
        cond_applies = np.array([applies for _, applies in per_example])

        print(cond)
        correct, unchanged, n_applies, n_not_applies = get_steering_acc(
            d1[0], aspect_labels, cond_applies, cond, cons, use_exps=use_exps)
        increments = (correct, unchanged, n_applies, n_not_applies)
        totals = [running + step for running, step in zip(totals, increments)]

    return totals[0] / totals[2], totals[1] / totals[3]

In [None]:
# Patched model: (accuracy where a patch applies, fraction of predictions unchanged where none applies)
(s1, s2) = get_scores(conds, use_exps=True)
print(s1, s2)

In [None]:
# Unpatched baseline; its second score is trivially 1.0 (get_steering_acc counts every
# non-applicable example as unchanged when use_exps=False).
(s1, s2) = get_scores(conds, use_exps=False)
print(s1, s2)