# IMDB dataset with the BERT model
Applying model explanations to language data, especially when using large models like BERT, creates some technical complications:

1. When considering each word as a feature, even short texts have hundreds of features.
2. Querying the BERT model to make single predictions is really inefficient.

To circumvent these difficulties, we use the following approach:

1. Instead of calculating the influence of every word in every sentence, we first find the most important sentence (according to a linear feature explanation measure) and then, in a second step, calculate the influence of the words in this important sentence. We also restrict our attention to short reviews.
2. We treat BERT as an external module: we first create a prediction dataset for BERT, then make all the predictions, and finally calculate the feature importance over the predictions.

In [1]:
%load_ext autoreload
%autoreload 2
import itertools
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
#import shap
from scipy.sparse import csr_matrix
from sklearn import tree
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

import BII

In [36]:
# Location of the IMDB test split (read below as a tab-separated file); the
# path is relative to the notebook's working directory.
path_to_IMDB_testset = "data/IMDB/movie_data/test.tsv"

In [37]:
# Load the IMDB test split; later cells use its "text" and "label" columns.
test = pd.read_csv(path_to_IMDB_testset,sep="\t")

In [6]:
# Cut every review at each "." into a list of sentence-like fragments. This is
# a plain character split, so abbreviations and decimals are cut as well.
test["split"] = test["text"].str.split(".")

In [7]:
# Discard empty and single-character fragments (for example what is left
# between two consecutive dots).
test["split"] = test["split"].transform(
    lambda fragments: [fragment for fragment in fragments if len(fragment) >= 2]
)

In [8]:
# Keep only reviews with fewer than 10 fragments, because each review is later
# expanded into 2**len(split) subsets. The explicit copy matters: later cells
# add columns to short_reviews, and assigning into a slice of `test` triggers
# pandas' SettingWithCopyWarning.
short_reviews = test.loc[test["split"].str.len() < 10, ["label", "split"]].copy()

In [9]:
short_reviews

Unnamed: 0,label,split
1,1,[For only doing a few movies with his life the...
2,0,"[From all the rave reviews, we couldn't wait t..."
4,1,"[Remarkably well done, but under-recognized be..."
6,0,"[I hated the book, A guy meets a smart dog, g..."
8,0,[This has got to be one of the worst films I h...
...,...,...
12493,0,"[I won't lie, I rented this film because it wa..."
12494,0,[This feels very stilted and patronizing to a ...
12495,1,"[I've seen this movie at least 8 times, and I ..."
12496,0,[Saw this as part of the DC Reel Affirmations ...


In [10]:
def binarize(input_text):
    """Enumerate every subset of ``input_text`` as a space-joined string.

    Entry ``d`` of the result keeps element ``k`` of ``input_text`` exactly
    when bit ``k`` of ``d`` is set. Entry 0 is therefore the empty string and
    the last entry is the full text; kept elements stay in their original
    order.

    Args:
        input_text: Sequence of strings (for example sentences or words).

    Returns:
        A list of ``2 ** len(input_text)`` strings.
    """
    m = len(input_text)
    if m == 0:
        # The only subset of an empty sequence is the empty one.
        return [""]
    input_array = np.asarray(input_text)
    subset_ids = np.arange(2**m)
    # Row d, column k is True when bit k of d is set. The comparison already
    # yields a boolean array; routing it through np.ma.make_mask is not needed
    # and would collapse an all-False mask into a non-iterable scalar.
    mask = (subset_ids[:, None] & (1 << np.arange(m))) > 0
    return [" ".join(input_array[row]) for row in mask]

In [11]:
# Expand every short review into all subsets of its fragments so that the
# model can score each subset in a separate prediction run.
texts = [binarize(fragments) for fragments in short_reviews["split"]]
# Every subset inherits the label of the review it was derived from.
labels = [
    np.repeat(label, len(subsets))
    for label, subsets in zip(short_reviews["label"], texts)
]
texts_concat = list(itertools.chain.from_iterable(texts))
labels_concat = list(itertools.chain.from_iterable(labels))
binarized_short_reviews = pd.DataFrame({"text": texts_concat})
binarized_short_reviews["label"] = labels_concat
# Column order expected in the exported files: label first, then text.
binarized_short_reviews = binarized_short_reviews[["label", "text"]]

In [12]:
len(binarized_short_reviews)

1052862

In [705]:
# Export the binarized reviews as four shards of 250,000 rows (the last shard
# takes the remainder), one dev.tsv per directory.
shard_size = 250000
number_of_shards = 4
for counter in range(number_of_shards):
    # exist_ok=True keeps the cell re-runnable when the directories are
    # already present.
    os.makedirs("../../Data/imdb/sentences_{}".format(counter), exist_ok=True)
    shard_start = counter * shard_size
    # The last shard is open-ended so that no trailing rows are dropped.
    shard_stop = None if counter == number_of_shards - 1 else shard_start + shard_size
    binarized_short_reviews.iloc[shard_start:shard_stop].to_csv(
        "../../Data/imdb/sentences_{}/dev.tsv".format(counter), sep="\t", index=False
    )

In [13]:
# Load the model outputs for the four shards (files without a header row).
# NOTE(review): the shards are exported as sentences_0..3 but read back as
# sentences_1..4 — presumably renamed by the external prediction step; confirm.
outputs_1, outputs_2, outputs_3, outputs_4 = [
    pd.read_csv(output_path, header=None)
    for output_path in (
        "../Bert/bert-sentiment-IMDB/data/outputs/imdb/sentences_1.txt",
        "../Bert/bert-sentiment-IMDB/data/outputs/imdb/sentences_2.txt",
        "../Bert/bert-sentiment-IMDB/data/outputs/imdb/sentences_3.txt",
        "../Bert/bert-sentiment-IMDB/data/outputs/imdb/sentences_4.txt",
    )
]

In [14]:
# Stack the four output files into one array, in shard order. A later cell
# slices this array by position (2**len(split) consecutive rows per review),
# so the row order must follow the order of short_reviews.
outputs_sentences =np.asarray( pd.concat([outputs_1, outputs_2, outputs_3, outputs_4]))

In [15]:
len(outputs_sentences)

1052862

In [25]:
def softmax(x):
    """Return the row-wise softmax of a 2-D array of scores.

    Args:
        x: Array-like of shape (rows, classes).

    Returns:
        Array of the same shape whose rows are non-negative and sum to 1.
    """
    x = np.asarray(x)
    # Subtracting each row's maximum does not change the mathematical result
    # but keeps np.exp from overflowing to inf (which would yield nan).
    shifted = x - np.max(x, axis=1, keepdims=True)
    exponentials = np.exp(shifted)
    return exponentials / np.sum(exponentials, axis=1, keepdims=True)

In [19]:
def create_mask_subset(m,i):
    """Build a boolean mask over ``2 ** (m + 1)`` enumerated subsets.

    Position ``d`` of the mask is True exactly when bit ``i`` of ``d`` is set.
    Expects ``0 <= i <= m``.
    """
    # One period of the pattern: 2**i entries without bit i, then 2**i with it.
    period = np.repeat(np.asarray([False, True]), 2**i)
    repetitions = 2**(m - i)
    return np.tile(period, repetitions)

In [20]:
def exact_banzhaf(i,m,values):
    """Scaled difference between the subsets with and without feature ``i``.

    ``values`` holds one score per enumerated subset, in the ordering used by
    ``create_mask_subset(m, i)``. The result is
    ``2 ** (-(m - 1)) * (sum over subsets with i - sum over subsets without i)``.

    NOTE(review): the caller in this file passes ``m = number of features - 1``,
    which makes the factor ``2 ** (2 - n)`` for ``n`` features, whereas the
    textbook Banzhaf value uses ``2 ** (1 - n)`` — confirm the intended
    normalisation.
    """
    contains_i = create_mask_subset(m, i)
    total_with = np.sum(values[contains_i])
    total_without = np.sum(values[~contains_i])
    scale = 2**(-(m - 1))
    return scale * (total_with - total_without)

In [21]:
def average_prediction_around(i,j,m,values):
    """Mean of ``values`` over the subsets that contain ``i`` but not ``j``.

    ``values`` is indexed by subset number, in the ordering used by
    ``create_mask_subset``.
    """
    has_i = create_mask_subset(m, i)
    has_j = create_mask_subset(m, j)
    return np.mean(values[has_i & ~has_j])

In [22]:
def exact_bin_banzhaf(i,j,m,values):
    """Scaled pairwise interaction term of features ``i`` and ``j``.

    Sums ``values`` over four groups of subsets (both features, only ``j``,
    only ``i``, neither) and returns
    ``2 ** (-(m - 2)) * (both - only_j - only_i + neither)``.
    ``values`` is indexed by subset number, in the ordering used by
    ``create_mask_subset``.
    """
    has_i = create_mask_subset(m, i)
    has_j = create_mask_subset(m, j)
    both = np.sum(values[has_j & has_i])
    only_j = np.sum(values[has_j & ~has_i])
    only_i = np.sum(values[has_i & ~has_j])
    neither = np.sum(values[~has_j & ~has_i])
    return 2**(-(m - 2)) * (both - only_j - only_i + neither)

In [23]:
def calculate_linear_Banzhaf_values_from_outputs(outputs):
    """Compute single-feature and pairwise values from raw model outputs.

    Args:
        outputs: Array-like with one row of scores per enumerated subset;
            column 1 of its row-wise softmax is the quantity being explained.
            The number of features is taken as ``int(log2(len(outputs)))``.

    Returns:
        A square array with one row and column per feature. The diagonal
        holds ``exact_banzhaf`` per feature, the strict upper triangle holds
        ``exact_bin_banzhaf`` per pair, and the lower triangle stays zero.
    """
    outputs = np.asarray(outputs)
    num_features = int(np.log2(len(outputs)))
    # The softmax does not depend on i or j, so it is computed once instead of
    # once per diagonal entry and per pair.
    class_one_scores = softmax(outputs)[:, 1]
    banzhaf_values = np.zeros([num_features, num_features])
    for i in range(num_features):
        banzhaf_values[i, i] = exact_banzhaf(i, num_features - 1, class_one_scores)
        for j in range(i + 1, num_features):
            banzhaf_values[i, j] = exact_bin_banzhaf(i, j, num_features - 1, class_one_scores)
    return banzhaf_values

In [26]:
# Walk through the stacked outputs review by review: a review with k fragments
# owns the next 2**k consecutive rows.
values = []
start = 0
for fragments in short_reviews["split"]:
    end = start + 2 ** len(fragments)
    review_outputs = outputs_sentences[start:end]
    values.append(calculate_linear_Banzhaf_values_from_outputs(review_outputs))
    start = end

In [27]:
# Store the per-review value matrices; `values` was built in the row order of
# short_reviews, so the list lines up with the frame positionally.
short_reviews["values"] = values

In [28]:
# For each review, pick the fragment whose diagonal (single-feature) value has
# the largest magnitude.
short_reviews["max_sentence_number"] = short_reviews["values"].transform(
    lambda matrix: np.abs(np.diagonal(matrix)).argmax()
)

In [29]:
# Split every review around its selected fragment: the fragment itself, the
# space-joined fragments before it and the space-joined fragments after it.
short_reviews["max_sentence"] = [
    fragments[position]
    for fragments, position in zip(short_reviews["split"], short_reviews["max_sentence_number"])
]
short_reviews["before_max_sentence"] = [
    " ".join(fragments[:position])
    for fragments, position in zip(short_reviews["split"], short_reviews["max_sentence_number"])
]
short_reviews["after_max_sentence"] = [
    " ".join(fragments[position + 1:])
    for fragments, position in zip(short_reviews["split"], short_reviews["max_sentence_number"])
]

In [30]:
# Keep only reviews whose selected sentence splits into fewer than 19
# space-separated tokens. The explicit copy matters: later cells rename and
# add columns on max_sentences, which must not operate on a slice of
# short_reviews (pandas' SettingWithCopyWarning).
max_sentences = short_reviews.loc[
    short_reviews["max_sentence"].apply(lambda x: len(x.split(" "))) < 19,
    ["label", "max_sentence", "before_max_sentence", "after_max_sentence"],
].copy()

In [31]:
# Rename the selected sentence to "text" (done in place on max_sentences).
max_sentences.rename(columns = {"max_sentence": "text"}, inplace = True)

In [32]:
# Remove newline characters from the three text columns. They are deleted, not
# replaced by a space.
for column_name in ("text", "before_max_sentence", "after_max_sentence"):
    max_sentences[column_name] = max_sentences[column_name].replace(r'\n', '', regex=True)

In [33]:
max_sentences

Unnamed: 0,label,text,before_max_sentence,after_max_sentence
1,1,Tommy Boy is a classic and we will always rem...,For only doing a few movies with his life the ...,From appearing on Saturday NIGHT LIVE to doin...
4,1,"Has big-studio feel, understands the horror g...","Remarkably well done, but under-recognized bec...","Could have been a college cult classic, if it..."
16,1,The story took such twist & turns that made i...,Nurse Betty was definitely one of the most cre...,"If you're sick of the recent formula movies, ..."
18,0,Monotonous murky drama with an endless drone ...,"Looking all of 29 years old, Rob Lowe is a det...",This is a good substitute sleeping remedy if y...
32,1,The Cormen B-movie style is all over this pup...,Black Scorpion is Roger Cormen's Batman Which...,There are plenty of stunts and hot babes to m...
...,...,...,...,...
12479,0,Their was hardly any action at all and the ch...,I was actually looking forward to this movie ...,The only saving grace was Omar Epps and even ...
12481,0,This Movie Was In My Opinion Very Ignorant! Th...,,The Police Procedure Was Unrealistic The Car...
12482,1,It is one movie that you MUST see or you have ...,Certainly this proves beyond a shadow of doubt...,I rank it 58 in the top 100 films of all time...
12488,1,Good Effects and Acting make this movie a mus...,Critters 4 is a good movie A bit of a twist t...,I would recommend this to Horror/Science Fict...


In [813]:
# Write one dev.tsv per selected review, containing every subset of its tokens
# together with the review's label.
# NOTE(review): this reads a "split" column of max_sentences that is not
# created in the visible cells — presumably the tokens of "text"; confirm.
for counter, (review_label, review_tokens) in enumerate(
    zip(max_sentences["label"], max_sentences["split"])
):
    print(counter,end=',')
    binarized_review = binarize(review_tokens)
    binarized_max_sentences = pd.DataFrame(
        {
            "label": np.repeat(review_label, len(binarized_review)),
            "text": binarized_review,
        }
    )
    # exist_ok=True keeps the cell re-runnable when the directory is already
    # present.
    os.makedirs("../../Data/imdb/{}".format(counter), exist_ok=True)
    binarized_max_sentences.to_csv("../../Data/imdb/{}/dev.tsv".format(counter),sep="\t",index=False)

0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,262,263,264,265,266,267,268,269,270,271,272,273,274,275,276,27

In [1034]:
%%time
# Fill the "values" column for reviews 900-999 from the per-review prediction
# files.
# NOTE(review): calculate_values and the "split", "predictions" and "values"
# columns are not defined in the visible part of this notebook — confirm where
# they come from.
for number in range(900,1000):
    print(number,end=",")
    # NOTE(review): the next two names are assigned but not used in this cell.
    sentence_number = max_sentences["split"].iloc[number]
    prediction_number = max_sentences["predictions"].iloc[number]
    outputs_number = pd.read_csv("../Bert/bert-sentiment-IMDB/data/outputs/{}/predict_results.txt".format(number),header=None)
    banzhaf_number = calculate_values(outputs_number) 
    # NOTE(review): chained assignment (column first, then .iloc) may write to
    # a temporary copy depending on the pandas version — verify that the frame
    # is really updated.
    max_sentences["values"].iloc[number] = banzhaf_number

900,901,902,903,904,905,906,907,908,909,910,911,912,913,914,915,916,917,918,919,920,921,922,923,924,925,926,927,928,929,930,931,932,933,934,935,936,937,938,939,940,941,942,943,944,945,946,947,948,949,950,951,952,953,954,955,956,957,958,959,960,961,962,963,964,965,966,967,968,969,970,971,972,973,974,975,976,977,978,979,980,981,982,983,984,985,986,987,988,989,990,991,992,993,994,995,996,997,998,999,CPU times: user 42.6 s, sys: 158 ms, total: 42.7 s
Wall time: 42.7 s


## Lime

In [1278]:
# Create the column that the LIME loop below fills one review at a time.
max_sentences["lime_values"] = None

In [12]:
class Wrapper_for_lime():
    """Answer ``predict_proba`` queries from a table of precomputed outputs.

    Each queried text is mapped to the number of the token subset it
    represents, and the row stored under that number in ``look_up_table`` is
    returned, so no model call is needed.

    NOTE(review): the mapping assumes that perturbed texts keep their
    separators, so that token positions line up with ``mainstring`` — confirm
    against the explainer configuration in use.
    """

    def __init__(self, look_up_table, mainstring, binarized_instance=None):
        """Store the table and the reference text.

        Args:
            look_up_table: Indexable by subset number; every entry must fit
                into a row of two values.
            mainstring: The unperturbed text, tokens separated by single
                spaces.
            binarized_instance: Optional sequence indexed by subset number.
                When given, ``predict_proba`` prints its entry next to every
                queried text as a debugging aid.
        """
        self.look_up_table = look_up_table
        self.mainstring = mainstring
        self.binarized_instance = binarized_instance

    def find_number_of_substring(self,substring):
        """Return the subset number represented by ``substring``.

        Both texts are split on single spaces. Bit ``k`` of the result is set
        when token ``k`` of ``substring`` has the same length as token ``k``
        of ``mainstring``; presence is detected by length, not by content.
        A text without any space is handled as a one-token text, so the full
        text of a single-word ``mainstring`` maps to 1 rather than 0.

        Raises:
            IndexError: If ``substring`` has more tokens than ``mainstring``.
        """
        sequence = self.mainstring.split(" ")
        number = 0
        # str.split always returns at least one token, so texts without a
        # space need no special case.
        for position, word in enumerate(substring.split(" ")):
            if len(word) == len(sequence[position]):
                number += 2**position
        return number

    def predict_proba(self, list_of_substrings):
        """Look up the stored outputs for every text in ``list_of_substrings``.

        Returns:
            Array of shape ``(len(list_of_substrings), 2)``.
        """
        result = np.zeros([len(list_of_substrings), 2])
        for i, substring in enumerate(list_of_substrings):
            number = self.find_number_of_substring(substring)
            result[i] = self.look_up_table[number]
            if self.binarized_instance is not None:
                # Debug output: compare the looked-up subset with the query.
                print(number,"###", self.binarized_instance[number],"####",substring)
        return result

In [5]:
from lime.lime_text import LimeTextExplainer

In [1289]:
%%time
warnings.filterwarnings('ignore')
explainer = LimeTextExplainer(class_names=["0","1"],split_expression=" ")
for number in range(1000,2000):
    print(number, end=",")
    outputs_number = pd.read_csv("../Bert/bert-sentiment-IMDB/data/outputs/{}/predict_results.txt".format(number),header=None)
    lookup_table = np.asarray(outputs_number)
    instance = max_sentences["text"].iloc[number]
    #binarized_instance = binarize(instance.split(" "))
    wrapper_for_lime_number = Wrapper_for_lime(lookup_table, instance)
    exp = explainer.explain_instance(instance, wrapper_for_lime_number.predict_proba, num_features=int(np.log2(len(lookup_table))))
    lime_values = np.zeros(len(max_sentences["split"].iloc[number]))
    for (word,value) in exp.as_list():
        indices = [i for i, x in enumerate(max_sentences["split"].iloc[number]) if x == word]
        for index in indices:
            lime_values[index] = value
    max_sentences["lime_values"].iloc[number] = lime_values

1000,1001,1002,1003,1004,1005,1006,1007,1008,1009,1010,1011,1012,1013,1014,1015,1016,1017,1018,1019,1020,1021,1022,1023,1024,1025,1026,1027,1028,1029,1030,1031,1032,1033,1034,1035,1036,1037,1038,1039,1040,1041,1042,1043,1044,1045,1046,1047,1048,1049,1050,1051,1052,1053,1054,1055,1056,1057,1058,1059,1060,1061,1062,1063,1064,1065,1066,1067,1068,1069,1070,1071,1072,1073,1074,1075,1076,1077,1078,1079,1080,1081,1082,1083,1084,1085,1086,1087,1088,1089,1090,1091,1092,1093,1094,1095,1096,1097,1098,1099,1100,1101,1102,1103,1104,1105,1106,1107,1108,1109,1110,1111,1112,1113,1114,1115,1116,1117,1118,1119,1120,1121,1122,1123,1124,1125,1126,1127,1128,1129,1130,1131,1132,1133,1134,1135,1136,1137,1138,1139,1140,1141,1142,1143,1144,1145,1146,1147,1148,1149,1150,1151,1152,1153,1154,1155,1156,1157,1158,1159,1160,1161,1162,1163,1164,1165,1166,1167,1168,1169,1170,1171,1172,1173,1174,1175,1176,1177,1178,1179,1180,1181,1182,1183,1184,1185,1186,1187,1188,1189,1190,1191,1192,1193,1194,1195,1196,1197,1198,1199,

In [1385]:
# Persist the frame, including all columns added above, in the working
# directory.
max_sentences.to_pickle("max_sentences.pkl")