# Analyze predictions

In [1]:
import ast
import matplotlib.pyplot as plt
import os
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

from captum.attr import visualization as viz
import numpy as np
import pandas as pd
import seaborn as sns

%matplotlib inline

In [2]:
# convert strings into lists
def literal_eval_column(column):
    try:
        return ast.literal_eval(column)
    except (SyntaxError, ValueError):
        return column

In [3]:
# read distilbert predictions
dist_test = pd.read_csv('outputs/distilbert_attributions.csv')
dist_test = dist_test.applymap(literal_eval_column)

In [46]:
dist_test.head()

Unnamed: 0.1,Unnamed: 0,index,attributions_pred,label_pred,score,tokens,label_gold,correct,sum_attributions_pred,mean_attributions_pred,attributions_gold,sum_attributions_gold,mean_attributions_gold
0,0,0,"[-0.011615987537462793, 0.003815996398574733, ...",rec.autos,0.801348,"[i, am, a, little, confused, on, all, of, the,...",rec.autos,True,2.790767,0.025371,"[-0.011615987537462793, 0.003815996398574733, ...",2.790767,0.025371
1,1,1,"[-0.32846125995080416, 0.021676696368081335, -...",comp.windows.x,0.915943,"[i, ', m, not, familiar, at, all, with, the, f...",comp.windows.x,True,3.255837,0.018395,"[-0.32846125995080416, 0.021676696368081335, -...",3.255837,0.018395
2,2,2,"[-0.014673215540331483, 0.1312478060786364, 0....",alt.atheism,0.072379,"[in, a, word, ,, yes, .]",alt.atheism,True,1.704919,0.284153,"[-0.014673215540331483, 0.1312478060786364, 0....",1.704919,0.284153
3,3,3,"[-0.6194980171999851, 0.0018658245313804055, 0...",talk.politics.mideast,0.808439,"[they, were, attacking, the, iraqi, ##s, to, d...",talk.politics.mideast,True,4.223599,0.008282,"[-0.6194980171999851, 0.0018658245313804055, 0...",4.223599,0.008282
4,4,4,"[-0.08235069488343288, 0.03930023488145007, -0...",alt.atheism,0.626613,"[i, ', ve, just, spent, two, solid, months, ar...",talk.religion.misc,False,1.368739,0.068437,"[-0.09691212604112631, 0.011566433822874374, -...",1.375414,0.068771


In [4]:
# read SVM predictions
svm_test = pd.read_csv('outputs/coefs_test.csv')
svm_test = svm_test.applymap(literal_eval_column)

In [5]:
# exclude erroneous indices
with open('nohup-att-dist4.out', 'r', encoding='utf-8') as f:
    lines = f.readlines()
error_indices = [int(l.split(' ')[0]) for l in lines[7:] if l]

svm_test = svm_test[~svm_test.index.isin(error_indices)]
svm_test = svm_test.reset_index()

In [45]:
svm_test.head()

Unnamed: 0.1,index,Unnamed: 0,true class no,true class name,pred class no,pred class name,feature ind,tfidf,feature names,coef true,coef pred,coef true*tfidf,coef pred*tfidf
0,0,0,7,rec.autos,4,comp.sys.mac.hardware,"[19173, 18937, 18270, 18198, 17388, 17145, 167...","[0.023220739480456238, 0.03033897195578605, 0....","[year, words, value, usually, time, tell, summ...","[-0.3644428172907367, -0.5038521269983025, -0....","[-0.4371961563133163, -0.43584702814775034, -0...","[-0.008462631715831708, -0.01528635555086465, ...","[-0.010152018047608341, -0.013223150763987291,..."
1,1,1,5,comp.windows.x,1,comp.graphics,"[18422, 18315, 17949, 16893, 15459, 15400, 135...","[0.02997111874877699, 0.04137978123401435, 0.0...","[view, ve, uncompressed, swamped, send, seeing...","[-0.3743281402679298, -0.36823803935498056, -0...","[-0.4741050755600308, -0.3507789408679207, -0....","[-0.011219033142978974, -0.015237609510551461,...","[-0.01420945951900757, -0.014515155834613814, ..."
2,2,2,0,alt.atheism,2,comp.os.ms-windows.misc,"[19184, 18931]","[0.4727116268736642, 0.5272883731263358]","[yes, word]","[-0.39628063158541726, -0.3763669255110641]","[-0.4034374258789668, -0.2686799939812319]","[-0.18732646205526576, -0.1984539038512898]","[-0.1907095619289697, -0.14167183691795746]"
3,3,3,17,talk.politics.mideast,0,alt.atheism,"[19195, 18958, 18912, 18734, 18598, 18315, 173...","[0.006670685520730314, 0.00481273683305861, 0....","[york, world, women, west, want, ve, times, ti...","[-0.29742243166061666, -0.3350764429380573, -0...","[-0.48446696985871884, -0.31526510095234483, -...","[-0.001984011508418877, -0.0016126347388182498...","[-0.003231726801108645, -0.001517287963531291,..."
4,4,4,19,talk.religion.misc,0,alt.atheism,"[18315, 17293, 16222, 16070, 12072, 11429, 114...","[0.07006597797491963, 0.07938683024646784, 0.1...","[ve, thing, spent, solid, objective, moral, mo...","[-0.44427465473197825, -0.5717712998446866, -0...","[-0.4059767700021195, -0.29545935499795983, -0...","[-0.031128538173265812, -0.0453911111205724, -...","[-0.02844515942529752, -0.023455581659953918, ..."


In [7]:
assert len(svm_test) == len(dist_test)

In [34]:
def merge_subwords(tokens, attributions, aggregate='mean'):
    assert len(tokens) == len(attributions)
    # skip merged tokens
    skip_next_iteration = 0
    tokens_merged = []
    attributions_merged = []
    for i, token in enumerate(tokens):
        if skip_next_iteration != 0:
            skip_next_iteration -= 1
            continue
        if token:
            word = token
            attribution = attributions[i]
            # detokenize
                # check whether next token starts with ##
            if i < len(tokens)-1 and tokens[i+1].startswith('##'):
                # check whether further tokens start with ##
                to_merge = tokens[i:]
                count = 1
                iterator = iter(to_merge[1:])
                while True:
                    try:
                        token = next(iterator)
                        if token.startswith("##"):
                            count += 1
                        else:
                            break
                    except StopIteration:
                        break
                to_merge = to_merge[:count]
                j = len(to_merge)
                word = ''.join(tokens[i:i+j]).replace('##', '')
                attr = [attributions[i:i+j]]
                if aggregate == 'mean':
                    attribution = np.mean(attr)
                skip_next_iteration = count-1
        tokens_merged.append(word)
        attributions_merged.append(attribution)
    assert len(tokens_merged) == len(attributions_merged)
    return tokens_merged, attributions_merged

# features (SVM) != tokens (DistilBERT)
def compare_pred(i, print_result=True):
    row_svm = svm_test.iloc[i]
    row_dist = dist_test.iloc[i]
    # initialize lists for common features and their attribution/coeff values
    feats = []
    val_svm_gold = []
    val_dist_gold = []
    val_svm_pred = []
    val_dist_pred = []
    # merge subwords and aggregate subword attributions
    tokens, attributions_pred = merge_subwords(row_dist.tokens, row_dist['attributions_pred'])
    tokens2, attributions_gold = merge_subwords(row_dist.tokens, row_dist['attributions_gold'])
    assert tokens == tokens2  # just in case
    for k, f in enumerate(row_svm['feature names']):
        if f in tokens:
            feats.append(f)
            val_svm_gold.append(row_svm['coef true*tfidf'][k])
            val_svm_pred.append(row_svm['coef pred*tfidf'][k])
            # feature more than once in sentence
            if tokens.count(f) > 1:
                indices = [j for j, token in enumerate(tokens) if token == f]
                val_dist_gold.append(np.mean([attributions_gold[j] for j in indices]))
                val_dist_pred.append(np.mean([attributions_pred[j] for j in indices]))
            else:
                # feature index in distilbert tokens
                dist_index = tokens.index(f)
                val_dist_gold.append(attributions_gold[dist_index])
                val_dist_pred.append(attributions_pred[dist_index])
        else:
            print(i, f'feature {f} not in token list')
    assert len(feats) == len(val_dist_gold)
    assert len(val_dist_gold) == len(val_dist_pred)
    assert len(val_dist_pred) == len(val_svm_gold)
    assert len(val_svm_gold) == len(val_svm_pred)
    correl_gold = np.corrcoef(val_dist_gold, val_svm_gold)
    correl_pred = np.corrcoef(val_dist_pred, val_svm_pred)
    if print_result:
        print(feats)
        print(correl_gold)
        print(correl_pred)
    return len(feats), len(tokens), correl_gold[0][1], correl_pred[0][1], feats

In [27]:
compare_pred(0)

['year', 'words', 'value', 'usually', 'time', 'tell', 'summer', 'spring', 'se', 'performance', 'models', 'model', 'mid', 'little', 'le', 'heard', 'features', 'far', 'early', 'differences', 'demand', 'curious', 'confused', 'buy', 'book', 'best', '89', '88']
[[ 1.         -0.10922089]
 [-0.10922089  1.        ]]
[[1.         0.08491424]
 [0.08491424 1.        ]]


(103,
 -0.10922089148550534,
 0.08491423527018695,
 ['year',
  'words',
  'value',
  'usually',
  'time',
  'tell',
  'summer',
  'spring',
  'se',
  'performance',
  'models',
  'model',
  'mid',
  'little',
  'le',
  'heard',
  'features',
  'far',
  'early',
  'differences',
  'demand',
  'curious',
  'confused',
  'buy',
  'book',
  'best',
  '89',
  '88'])

In [30]:
len([compare_pred(0, print_result=False)])

1

In [35]:
correls = pd.DataFrame(columns=['num_feats', 'num_tokens', 'correl_gold', 'correl_pred', 'features'])

num_instances = len(svm_test)

for i in range(num_instances):
    correls.loc[len(correls)] = list(compare_pred(i, print_result=False))

  avg = a.mean(axis, **keepdims_kw)
  ret = um.true_divide(
  c = cov(x, y, rowvar, dtype=dtype)
  c *= np.true_divide(1, fact)
  c *= np.true_divide(1, fact)
  avg = a.mean(axis, **keepdims_kw)
  ret = um.true_divide(
  c = cov(x, y, rowvar, dtype=dtype)
  c *= np.true_divide(1, fact)
  c *= np.true_divide(1, fact)
  avg = a.mean(axis, **keepdims_kw)
  ret = um.true_divide(
  c = cov(x, y, rowvar, dtype=dtype)
  c *= np.true_divide(1, fact)
  c *= np.true_divide(1, fact)
  avg = a.mean(axis, **keepdims_kw)
  ret = um.true_divide(
  c = cov(x, y, rowvar, dtype=dtype)
  c *= np.true_divide(1, fact)
  c *= np.true_divide(1, fact)


396 feature haven not in token list


  c = cov(x, y, rowvar, dtype=dtype)
  c *= np.true_divide(1, fact)
  c *= np.true_divide(1, fact)
  c = cov(x, y, rowvar, dtype=dtype)
  c *= np.true_divide(1, fact)
  c *= np.true_divide(1, fact)
  c = cov(x, y, rowvar, dtype=dtype)
  c *= np.true_divide(1, fact)
  c *= np.true_divide(1, fact)
  avg = a.mean(axis, **keepdims_kw)
  ret = um.true_divide(
  c = cov(x, y, rowvar, dtype=dtype)
  c *= np.true_divide(1, fact)
  c *= np.true_divide(1, fact)
  c = cov(x, y, rowvar, dtype=dtype)
  c *= np.true_divide(1, fact)
  c *= np.true_divide(1, fact)
  avg = a.mean(axis, **keepdims_kw)
  ret = um.true_divide(
  c = cov(x, y, rowvar, dtype=dtype)
  c *= np.true_divide(1, fact)
  c *= np.true_divide(1, fact)
  c = cov(x, y, rowvar, dtype=dtype)
  c *= np.true_divide(1, fact)
  c *= np.true_divide(1, fact)
  c = cov(x, y, rowvar, dtype=dtype)
  c *= np.true_divide(1, fact)
  c *= np.true_divide(1, fact)
  c = cov(x, y, rowvar, dtype=dtype)
  c *= np.true_divide(1, fact)
  c *= np.true_divid

In [55]:
correls.sort_values(by='correl_gold').head(n=50)

Unnamed: 0,num_feats,num_tokens,correl_gold,correl_pred,features
5340,2,4,-1.0,-1.0,"[simple, eh]"
6923,2,11,-1.0,-1.0,"[tesla, discussion]"
6876,2,4,-1.0,1.0,"[lou, gehrig]"
100,2,21,-1.0,1.0,"[ca, 1993apr21]"
6633,2,12,-1.0,1.0,"[really, 10]"
7148,2,14,-1.0,1.0,"[os, compared]"
5916,2,13,-1.0,1.0,"[dear, come]"
5737,2,7,-1.0,-1.0,"[sub, says]"
5510,2,6,-1.0,-1.0,"[urban, areas]"
5415,2,11,-1.0,-1.0,"[joe, agree]"


In [48]:
correls.sort_values(by='correl_gold', ascending=False).head()

Unnamed: 0,num_feats,num_tokens,correl_gold,correl_pred,features
751,2,6,1.0,1.0,"[getting, close]"
1147,2,6,1.0,-1.0,"[religion, alt]"
1303,2,10,1.0,-1.0,"[thing, heater]"
6614,2,14,1.0,1.0,"[let, clarify]"
5498,2,11,1.0,-1.0,"[okay, forgot]"


In [49]:
correls.sort_values(by='correl_pred').head()

Unnamed: 0,num_feats,num_tokens,correl_gold,correl_pred,features
489,2,6,1.0,-1.0,"[stuff, deleted]"
1303,2,10,1.0,-1.0,"[thing, heater]"
3127,2,7,1.0,-1.0,"[compuserve, com]"
3609,2,9,1.0,-1.0,"[reversed, got]"
782,2,13,1.0,-1.0,"[non, cold]"


In [50]:
correls.sort_values(by='correl_pred', ascending=False).head()

Unnamed: 0,num_feats,num_tokens,correl_gold,correl_pred,features
6633,2,12,-1.0,1.0,"[really, 10]"
1848,2,7,1.0,1.0,"[project, arizona]"
4000,2,7,-1.0,1.0,"[posted, ok]"
3338,2,6,-1.0,1.0,"[suggest, change]"
6264,2,11,1.0,1.0,"[trusted, tools]"


-> highest and lowest correlations for fewest features / shortest sentences!

In [36]:
correls.head()

Unnamed: 0,num_feats,num_tokens,correl_gold,correl_pred,features
0,28,103,-0.109221,0.084914,"[year, words, value, usually, time, tell, summ..."
1,28,166,0.106889,0.041291,"[view, ve, uncompressed, swamped, send, seeing..."
2,2,6,1.0,-1.0,"[yes, word]"
3,120,478,0.02276,0.126682,"[york, world, women, west, want, ve, times, ti..."
4,9,20,0.732862,0.42871,"[ve, thing, spent, solid, objective, moral, mo..."


In [37]:
# slightly higher correlation of gold class attributions/coefficients
np.mean(correls.correl_gold), np.mean(correls.correl_pred)

(0.2826740357767589, 0.25590147310534667)

In [38]:
# correlations not related to numbers of features or tokens in a sentence
correls.corr()

Unnamed: 0,num_feats,num_tokens,correl_gold,correl_pred
num_feats,1.0,0.944143,-0.017452,0.023974
num_tokens,0.944143,1.0,-0.039823,0.004972
correl_gold,-0.017452,-0.039823,1.0,0.620412
correl_pred,0.023974,0.004972,0.620412,1.0


In [None]:
correls.to_csv('outputs/predictions-correlations.csv')

In [None]:
# do some nice visualizations

In [40]:
type(svm_test['pred class name'][0]), type(dist_test['label_pred'][0])

(str, str)

In [42]:
label2id = {
    "alt.atheism": 0,
    "comp.graphics": 1,
    "comp.os.ms-windows.misc": 2,
    "comp.sys.ibm.pc.hardware": 3,
    "comp.sys.mac.hardware": 4,
    "comp.windows.x": 5,
    "misc.forsale": 6,
    "rec.autos": 7,
    "rec.motorcycles": 8,
    "rec.sport.baseball": 9,
    "rec.sport.hockey": 10,
    "sci.crypt": 11,
    "sci.electronics": 12,
    "sci.med": 13,
    "sci.space": 14,
    "soc.religion.christian": 15,
    "talk.politics.guns": 16,
    "talk.politics.mideast": 17,
    "talk.politics.misc": 18,
    "talk.religion.misc": 19
  }

In [43]:
dist_label_pred = [label2id[l] for l in dist_test['label_pred']]

In [44]:
# correlations
# compare predicted labels, have to ignore nan values with ma
np.corrcoef(np.ma.masked_invalid(svm_test['pred class no']),
            np.ma.masked_invalid(dist_label_pred))

array([[1.        , 0.71812724],
       [0.71812724, 1.        ]])