# Manual Qualitative Check of Attention Weight Explanations

In [1]:
import yaml
import yaml
from yaml.loader import SafeLoader
import sys
sys.path.append("..")
import src.evaluation.eraserbenchmark.rationale_benchmark.utils as EU
import src.evaluation.eraserbenchmark.rationale_benchmark.metrics as EM
from src.data.locations import LOC
import numpy as np

# load predictions and attentions
with open('../data/experiments/curious_darkness_lime5k/predictions_attentions.yaml') as f:
    preds_attn = yaml.load(f, Loader=SafeLoader)

# get documents
docids = list(preds_attn.keys())
documents = EM.load_flattened_documents('../'+LOC['cose'], docids=docids)
annotations = EU.annotations_from_jsonl(f"../{LOC['cose_val']}")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ASSERTS
# all keys are the same (len will then also be the same)
assert list(preds_attn.keys()) == docids
# len of attention vectors == len of document tokens
assert [len(x['attn']) for x in list(preds_attn.values())] == [len(x) for x in list(documents.values())]
# annotations and docids from same split
annotation_docids = [list(x.evidences)[0][0].docid for x in annotations]
assert all([x in annotation_docids for x in docids])

## Attention Weight Bar Plot

In [302]:
import pandas as pd
import plotly.express as px
import random

# look at a random sample
rnd_i = round(random.random() * len(documents))
doc_id = docids[rnd_i]

# get sample
sample = list(
    filter(lambda x: list(x.evidences)[0][0].docid == doc_id,
    map(lambda x: x, annotations
)))[0]
# get rationale
evidence = list(sample.evidences)[0][0]
c1 = '#ba4227'
c2 = '#5869e8'
color_seq = [c1 if x in range(evidence.start_token, evidence.end_token) 
             else c2 for x in range(len(documents[doc_id]))]
# get answers
answers = sample.query.replace('[sep]','|')
truth = answers.split(" | ")[{'A':0,'B':1,'C':2,'D':3,'E':4,}[sample.classification]]
prediction = answers.split(" | ")[np.argmax(preds_attn[doc_id]['probas'])]

# plot
df = pd.DataFrame(
                list(zip(
                    documents[doc_id], 
                    preds_attn[doc_id]['attn'],
                    color_seq
                )), 
                columns=['tokens', 'attn', 'color'])
fig = px.bar(df,
             title=" ".join(documents[doc_id]),
             x='tokens', 
             y='attn',
             color = 'color',
             category_orders ={'tokens': documents[doc_id]}
            )
fig.update_xaxes(tickangle=45)

# Additional data
print(list(sample.evidences)[0][0].docid)
print('Q:', " ".join(documents[doc_id]))
print('A:', answers)
print('evidence:', list(sample.evidences)[0][0].text)
print('prediction:', prediction)
print('truth:', truth)
fig.show()

290c711d7be219ba9870c08237b136dd
Q: Joe was n't prepared to wage war . He had not raised what ?
A: weapons | asserting power | plans | energy | armies
evidence: Joe was n't prepared to wage war . He had not raised what ?
prediction: energy
truth: armies


Findings:
- Wh- words (notably Where) receive very high weights from the system)
- Stop words are equally considered gold as semantically richer words
    - because annotators mark spans of text
    - the system tends to rank these 'fillers' lower
- Sometimes the system can rank words unsensibly high, possibly due to confusion of words
    - eg 3a92d14522b56fea9a99634d659904ea
- Many of the annotations make sense with the prediction (=comprehensiveness), but are actually wrong. So comprehensiveness can be very high for a system with low competence
    - eg 7428b0c780428d64116d060520b9fa0c
- A substantial amount of gold rationales have each word marked as explanation...
    - see section below
- both comp and suff are functions of accuracy
    - mb we should look for metrics looking at rationales alone?

## Aggreement (cohen's $\kappa$)
What is the agreement between the systems weights and gold rationales?  
top_x tokens by weights are chosen, where x is the amount of tokens selected for the gold rationale of the sample.

In [118]:
from sklearn.metrics import cohen_kappa_score, precision_score

# mask of marked evidences (gold)
evidence_masks = []
evidence_lens = []
for x in annotations:
    start = list(x.evidences)[0][0].start_token
    end = list(x.evidences)[0][0].end_token
    l = len(documents[list(x.evidences)[0][0].docid])
    evidence_masks.append(
        [1 if i in range(start,end) else 0 for i in range(l)]
    )
    evidence_lens.append(end-start)

# mask of top_k weighted tokens according to system
top_masks = []
for i in range(len(docids)):
    attn = preds_attn[docids[i]]['attn']
    top_k = evidence_lens[i]
    top_idx = np.sort(np.argpartition(attn,-top_k)[-top_k:])
    top_masks.append(
        [1 if x in top_idx else 0 for x in range(len(attn))]
    )

# ASSERTS
assert len(evidence_masks) == len(top_masks)
assert all([len(evidence_masks[i]) == len(top_masks[i])])

# SCORE
_scoring = cohen_kappa_score # precision_score 
aggreement = []
for i in range(len(evidence_masks)):
    if all(evidence_masks[i]) and all(top_masks[i]): break # TODO what do w. non-rationales
    else: score = _scoring(evidence_masks[i], top_masks[i])
    aggreement.append(score)

np.mean(aggreement)

0.4444444444444444

## How many gold rationales are non_rationales?
having each token of the question selected as being part of the rationale,
(so len(rationale) == len(tokens))

In [76]:
count = 0
for x in annotations:
    start = list(x.evidences)[0][0].start_token
    end = list(x.evidences)[0][0].end_token
    l = len(documents[list(x.evidences)[0][0].docid])
    count += int(end-start==l)
    
count / len(annotations)

0.3351749539594843

pretty much exactly a third of the rationales are "useless" is this on purpose?

In [220]:
non_fractions = []
for x in annotations:
    start = list(x.evidences)[0][0].start_token
    end = list(x.evidences)[0][0].end_token
    l = len(documents[list(x.evidences)[0][0].docid])
    non_fractions.append((end-start)/l)

fig = px.histogram(
    pd.DataFrame(non_fractions, columns=['fract']),
    title='How much of the questions was annotated as rationale (gold)?',
    x='fract',
    nbins=20
)
fig.show()