In [1]:
import pickle
from transformers import BertTokenizer
import gc
import numpy as np
import torch
from collections import Counter, defaultdict, namedtuple
from functools import partial

In [2]:
# Load the pickled dev-set examples into `examples`.
# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only point this at a file you produced yourself or otherwise trust.
examples = None  # stays None if the load below fails part-way
with open("dev_parse.pickle", 'rb') as pickle_handle:
    examples = pickle.load(pickle_handle)

In [3]:
# Tally how many times each dependency label occurs across all examples.
counts = Counter()
for example in examples:
    counts.update(example['labels'])

In [4]:
# Sanity check: the last axis of every attention map must be exactly two
# longer than the word list (the two extra positions are the ones `predict`
# strips as [CLS] and [SEP]). The last offending example, if any, is kept in
# `faulty_example` for inspection.
faulty_example = None
for example in examples:
    try:
        is_aligned = example["attention_map"].shape[-1] - 2 == len(example["words"])
    except (KeyError, AttributeError, TypeError):
        # A malformed example (missing key, no `.shape`, unsized words) is
        # flagged as faulty too, but unrelated errors such as
        # KeyboardInterrupt are no longer swallowed.
        is_aligned = False
    if not is_aligned:
        faulty_example = example

In [5]:
# Stop here if the previous cell recorded any example whose attention-map size
# does not match its word count.
assert faulty_example is None

In [6]:
# Read the layer and head counts from the first example's attention map.
# The unpacking requires the map to have exactly four dimensions; the last two
# are discarded here.
NUM_LAYERS,NUM_HEADS,_,_ = examples[0]['attention_map'].shape

In [7]:
def predict(example,layer,head,mode = "normal"):
    '''Predict the syntactic head of every word in an example from one attention head.

    example -> a single entry of `examples`; example['attention_map'][layer][head] must be a square
               token-by-token matrix whose first and last rows/columns are treated as the [CLS] and [SEP] tokens.
    layer -> layer in the attention map to be used to predict
    head -> attention head in that layer used to predict heads
    mode -> if mode is 'normal' the head of word i is deemed to be the word j `to` which i pays the most
            attention, i.e. argmax(attn[layer][head][i]). In other words, the dependent is paying the most
            attention to its head.
            if mode is 'transpose' the head of word i is deemed to be the word j `from` which i gets the most
            attention, i.e. argmax(attn[layer][head][:, i]). In other words, the head is paying the most
            attention to the dependent.
            Any other value behaves like 'normal'.

    Returns an array with one entry per row of the matrix excluding the first and last rows
    ([CLS] and [SEP] get no prediction). Entries are 1-based positions, because 0 is reserved for ROOT.

    The attention map stored in `example` is left untouched: all edits happen on a private copy.
    '''

    # Work on a copy. Indexing returns a view, so zeroing the diagonal in place
    # would permanently alter the caller's attention map.
    attn = np.array(example['attention_map'][layer][head])

    if mode == "transpose": attn = attn.T

    # Ignore attention to self by setting the diagonal elements to 0.
    diagonal = np.arange(attn.shape[0])
    attn[diagonal, diagonal] = 0

    attn = attn[1:-1, 1:-1] #ignoring attention from and to [CLS] and [SEP] token
    return np.argmax(attn, axis = -1) + 1 #because 0 prediction would mean the head is ROOT

In [8]:
def evaluate_predictor(examples, predictor):
    '''Compute the head-prediction accuracy of `predictor` for each dependency relation.

    examples -> iterable of dicts, each providing 'labels' (relation of every word) and
                'heads' (gold head position of every word).
    predictor -> callable taking one example and returning one predicted head per word.

    Words labelled 'root' or 'punct' are skipped. Returns a dict mapping every remaining label
    to (correct predictions / total occurrences), with keys in order of first appearance.

    Raises AssertionError if the predictions or the gold heads do not line up one-to-one with the labels.
    '''

    num_correct, num_seen = Counter(), Counter()
    for example in examples:
        labels = example['labels']
        heads = example['heads']

        predictions = predictor(example)

        assert len(predictions) == len(labels)
        # zip() below would silently truncate on a length mismatch, so check heads as well.
        assert len(heads) == len(labels)

        for prediction, label, head in zip(predictions, labels, heads):
            if label == 'root' or label == 'punct':
                continue
            num_seen[label] += 1
            if prediction == head:
                num_correct[label] += 1

    # One ratio per distinct label; num_seen only holds labels seen at least once,
    # so the denominator is never zero.
    return {label: num_correct[label] / float(num_seen[label]) for label in num_seen}
            
        

In [9]:
def get_scores(dataset, mode = "normal"):
    '''Evaluate every (layer, head) pair on `dataset`.

    For each layer below the module-level NUM_LAYERS and each head below NUM_HEADS, `predict` is
    bound to that layer/head/mode and handed to `evaluate_predictor`.

    Returns a defaultdict such that scores[layer][head] is the result of `evaluate_predictor`
    for that attention head.
    '''
    scores = defaultdict(dict)
    layer_head_pairs = (
        (layer_idx, head_idx)
        for layer_idx in range(NUM_LAYERS)
        for head_idx in range(NUM_HEADS)
    )
    for layer_idx, head_idx in layer_head_pairs:
        head_predictor = partial(predict, layer = layer_idx, head = head_idx, mode = mode)
        scores[layer_idx][head_idx] = evaluate_predictor(dataset, predictor = head_predictor)
    return scores

In [50]:
# Score every attention head in both directions: "dep->head" uses the attention
# matrix as is, "head->dep" uses its transpose (see `predict`).
scores_dictionary = {
    direction: get_scores(examples, mode = predict_mode)
    for direction, predict_mode in (("dep->head", "normal"), ("head->dep", "transpose"))
}

In [51]:
def get_relation_stats(scores_dictionary):
    '''Find, for every relation label, the attention head that predicts it best.

    scores_dictionary -> mapping mode -> {layer: {head: {label: accuracy}}}, e.g. the
                         per-mode results of `get_scores`.

    Returns a dict mapping each label to a Stats(accuracy, layer, head, mode) namedtuple
    describing the best-scoring head. Modes are visited in the dictionary's order, layers and
    heads in ascending order; on a tie the candidate visited last wins.

    The layers and heads are taken from the score tables themselves, so the function does not
    depend on any module-level state.
    '''
    Stats = namedtuple('Stats', field_names = ['accuracy', 'layer', 'head', 'mode'])
    relation_stats = {}

    for mode, scores in scores_dictionary.items():
        for layer in sorted(scores):
            heads = scores[layer]
            for head in sorted(heads):
                for label, accuracy in heads[head].items():
                    best = relation_stats.get(label)
                    # `>=` keeps the tie-breaking rule: later candidates replace equal ones.
                    if best is None or accuracy >= best.accuracy:
                        relation_stats[label] = Stats(accuracy, layer, head, mode)

    return relation_stats

In [52]:
# Best-scoring (layer, head, mode) per relation label, across both directions.
rs = get_relation_stats(scores_dictionary)

In [53]:
# Display the per-relation results as the cell output.
rs

{'amod': Stats(accuracy=0.7583465818759937, layer=5, head=4, mode='dep->head'),
 'nsubj': Stats(accuracy=0.5866906474820144, layer=8, head=9, mode='dep->head'),
 'case': Stats(accuracy=0.757248981548047, layer=6, head=4, mode='dep->head'),
 'det': Stats(accuracy=0.9383829275623685, layer=6, head=4, mode='dep->head'),
 'compound': Stats(accuracy=0.7232704402515723, layer=6, head=10, mode='dep->head'),
 'nmod': Stats(accuracy=0.24612348463490274, layer=7, head=8, mode='dep->head'),
 'cc': Stats(accuracy=0.47952047952047955, layer=6, head=5, mode='dep->head'),
 'conj': Stats(accuracy=0.464746772591857, layer=11, head=11, mode='head->dep'),
 'dobj': Stats(accuracy=0.8216249236408063, layer=5, head=7, mode='dep->head'),
 'aux': Stats(accuracy=0.7387267904509284, layer=6, head=8, mode='dep->head'),
 'acl:relcl': Stats(accuracy=0.3793103448275862, layer=8, head=6, mode='dep->head'),
 'advmod': Stats(accuracy=0.5309423347398031, layer=5, head=6, mode='dep->head'),
 'ccomp': Stats(accuracy=0.41