In [2]:
from warnings import simplefilter
simplefilter(action='ignore', category=DeprecationWarning)
simplefilter(action='ignore', category=FutureWarning)
from lit_classes import GeceProdigyData, GectorBertModel, GECE_ERROR_TYPES
from attention_analysis import attention_analysis
# reuse the LIT classes as they're already packaged for analysis
import pickle


# Load the three analysis results from their pickle caches. If any cache file
# is missing, recompute all of them with the model and rewrite every cache.
# After this cell, `pearson`, `regression`, `argmax` and `data_dict` are
# populated regardless of which branch ran.
# NOTE(security): pickle.load can execute arbitrary code; only load .pkl
# files that were written by this notebook.
try:
    data_dict = {}
    for name in ['pearson', 'argmax', 'regression']:
        with open(name+'.pkl', 'rb') as f:
            data_dict[name] = pickle.load(f)

    pearson = data_dict['pearson']
    regression = data_dict['regression']
    argmax = data_dict['argmax']
    print('Loaded analysis data from pickle files')
except FileNotFoundError:
    print('Found no pickled data, running data through model')
    model = GectorBertModel('bert_0_gector.th')
    data = GeceProdigyData('test_sample.jsonl', gece_tags=True)
    print('Loaded {} examples'.format(len(data)))
    pearson, regression, argmax = attention_analysis(model, data, model.ATTENTION_HEADS,
                                                     model.ATTENTION_LAYERS, model.MAX_LEN)

    # Keep data_dict consistent with the cache-hit branch (a partial load may
    # have left it incomplete before the FileNotFoundError was raised).
    data_dict = {'pearson': pearson,
                 'regression': regression,
                 'argmax': argmax}
    # The loop variable is deliberately not called `data`, so the dataset
    # loaded above is not overwritten by the last analysis result.
    for name, result in data_dict.items():
        with open(name+'.pkl', 'wb') as f:
            pickle.dump(result, f)

Found no pickled data, running data through model
Loaded 20 examples


processing batches:   0%|          | 0/1 [00:00<?, ?it/s]

processing results:   0%|          | 0/20 [00:00<?, ?it/s]

  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


In [7]:
# Flatten the nested analysis results (error type -> layer -> per-head values)
# into three flat lists of tuples that later cells can sort and filter.
# Every tuple ends with (layer, head, error_type). Regression and argmax
# tuples start with the single stored value; pearson tuples start with the
# unpacked stored sequence.
regression_list = []
argmax_list = []
pearson_list = []

for et in GECE_ERROR_TYPES:
    pearson_by_layer, regression_by_layer, argmax_by_layer = pearson[et], regression[et], argmax[et]

    # The head count is read from the first layer only; every layer (and all
    # three result sets) is assumed to share that shape.
    head_total = len(next(iter(argmax_by_layer.values())))
    last_idx = head_total - 1
    for layer in argmax_by_layer:
        for idx in range(head_total):
            # The final position is labelled 'head_average' instead of a numbered head.
            head = 'head_average' if idx == last_idx else 'head{}'.format(idx)
            tail = (layer, head, et)

            regression_list.append((regression_by_layer[layer][idx], *tail))
            argmax_list.append((argmax_by_layer[layer][idx], *tail))
            pearson_list.append((*pearson_by_layer[layer][idx], *tail))


In [8]:
def _rows_for(rows, error_type):
    """Return the rows whose error-type field (index 3) equals `error_type`."""
    return [row for row in rows if row[3] == error_type]


def _first_averaged(rows, error_type):
    """Return the first row for `error_type` whose layer (index 1) and head
    (index 2) labels both contain 'average'; raises StopIteration if none."""
    return next(row for row in rows
                if row[3] == error_type and 'average' in row[1] and 'average' in row[2])


# Sort in place: argmax and regression rows by their first field, highest
# first; pearson rows ascending by their second field
# (presumably the p-value -- TODO confirm against attention_analysis).
argmax_list.sort(key=lambda row: row[0], reverse=True)
regression_list.sort(key=lambda row: row[0], reverse=True)
pearson_list.sort(key=lambda row: row[1])

# max() compares whole tuples, so ties on the first field are broken by the
# layer, head and error-type fields that follow.
max_tense_regression = max(_rows_for(regression_list, 'TENSE'))
max_tense_argmax = max(_rows_for(argmax_list, 'TENSE'))
max_plural_regression = max(_rows_for(regression_list, 'SVA'))
max_plural_argmax = max(_rows_for(argmax_list, 'SVA'))

# Rows labelled as averages in both the layer and the head field.
averaged_tense_regression = _first_averaged(regression_list, 'TENSE')
averaged_plural_regression = _first_averaged(regression_list, 'SVA')


averaged_tense_argmax = _first_averaged(argmax_list, 'TENSE')
averaged_plural_argmax = _first_averaged(argmax_list, 'SVA')

tense = _rows_for(regression_list, 'TENSE')