In [1]:
import pickle as pkl
import numpy as np
import torch
"""
for dataset in ['iwslt14', 'multi30k']:
    logs = pkl.load(open('data/{dataset}_logs.pkl'.format(dataset=dataset), 'rb'))

    num_datapoints = len(logs['sequences'])

    iterations = logs['normal_A'].keys()
    dicts = logs['sequences']
    for key in logs:
        if key == 'sequences':
            continue
        for iteration in iterations:
            for datapoint_idx in range(num_datapoints):
                dicts[datapoint_idx][(key, iteration, 'alpha')] = logs[key][iteration][datapoint_idx]['alpha']
                dicts[datapoint_idx][(key, iteration, 'beta')] = logs[key][iteration][datapoint_idx]['beta']
                dicts[datapoint_idx]['split'] = logs[key][iteration][datapoint_idx]['split']
    pkl.dump(dicts, open('data/{dataset}_logs_rz.pkl'.format(dataset=dataset), 'wb'))
    """
from utils import *

def max_acc_iter(name, metas, key):
    max_acc = -1.0
    max_acc_iter = None
    for key_ in metas:
        if key_[0] == name and key_[2] == key:
            if metas[key_] > max_acc:
                max_acc = metas[key_]
                max_acc_iter = key_[1]
    return max_acc_iter

def passing_idx(A1s, A2):
    for i in range(len(A1s)):
        if A1s[i] > A2:
            return i
    return None

def corrs_iter(dicts, key1, keys2, corr_metric, reverse=False):
    corrs = []
    baselines = []
    for key2 in keys2:
        if reverse:
                vals = corr_metric.eval_corr(dicts, key2, key1)
        else:
            vals = corr_metric.eval_corr(dicts, key1, key2)
        corrs.append(vals['correlation'])
        baselines.append(vals['baseline'])
    return corrs, baselines

def acc_iter(metas, keys):
    accs = []
    for key in keys:
        accs.append(metas[key])
    return accs

def max_corr(dicts, key1, keys2, metric, reverse=False):
    return max(corrs_iter(dicts, key1, keys2, metric, reverse=reverse)[0])

def impute_beta(dicts, beta_matrix, key_name):
    for item in dicts:
        betas = []
        for tok_trg in item['trg'][1:]:
            beta = []
            for tok_src in item['src']:
                beta.append(beta_matrix[tok_src][tok_trg])
            betas.append(beta)
        betas = np.array(betas)
        item[key_name] = betas

def flip_grads(dicts):
    for item in dicts:
        for key in item:
            if type(key) is tuple and key[2] == 'grad':
                item[key] = -item[key]

In [2]:
dataset = 'multi30k'
dat = pkl.load(open('outputs/{dataset}_logs.pkl'.format(dataset=dataset), 'rb'))
all_dicts = dat['data']
subset = 'val'
acc_metric = 'val_bleu'
dicts = [d for d in all_dicts if d['split'] == subset]
iterations = sorted(list(set([key[1] for key in dat['metas']])))
flip_grads(dicts)

In [3]:
embed_beta = pkl.load(open('outputs/{dataset}embedding256translation.pkl'.format(dataset=dataset), 'rb'))
impute_beta(dicts, embed_beta, 'embed_beta')

In [9]:
metric = TopPercentMatch(p=5)
normalA_iter = max_acc_iter('normal_A', dat['metas'], acc_metric)
normalB_iter = max_acc_iter('normal_B', dat['metas'], acc_metric)
uniform_iter = max_acc_iter('uniform', dat['metas'], acc_metric)

In [10]:
gold_alpha_key = ('normal_A', normalA_iter, 'alpha')
gold_grad_key = ('normal_A', normalA_iter, 'grad')

normal_keys = [('normal_B', iter_, 'alpha') for iter_ in iterations]
acc_keys = [('normal_B', iter_, acc_metric) for iter_ in iterations]

alpha_corrs, alpha_baseline = corrs_iter(dicts, gold_alpha_key, normal_keys, metric)
alpha_perfs = acc_iter(dat['metas'], acc_keys)

avg_corr = np.array(alpha_corrs)
avg_perf = np.array(alpha_perfs)
baseline = alpha_baseline[0]

In [11]:
print(dicts[0][gold_grad_key].shape)
print(dicts[0][gold_alpha_key].shape)

(14, 14)
(14, 14)


In [12]:
beta_unif_keys = [('uniform', iter_, 'beta') for iter_ in iterations]
beta_corr_unif = max_corr(dicts, gold_alpha_key, beta_unif_keys, metric)
beta_corr_grad = max_corr(dicts, gold_grad_key, beta_unif_keys, metric)
beta_corr_px = max_corr(dicts, 'embed_beta', beta_unif_keys, metric)

In [13]:
print(corrs_iter(dicts, gold_grad_key, beta_unif_keys, metric))

([0.06711990111248455, 0.04641532756489493, 0.04758961681087762, 0.08547589616810877, 0.08683559950556242, 0.09629171817058096, 0.10599505562422744, 0.10710754017305316, 0.10451174289245983, 0.10815822002472188, 0.10587144622991347, 0.1050061804697157, 0.10995055624227441, 0.10995055624227441, 0.10630407911001236, 0.1111248454882571, 0.11291718170580964, 0.11563658838071693, 0.12014833127317676, 0.11464771322620519], [0.07119901112484549, 0.07428924598269468, 0.07558714462299135, 0.07490729295426453, 0.06786155747836836, 0.07447466007416563, 0.06934487021013597, 0.07119901112484549, 0.07435105067985166, 0.06755253399258343, 0.07527812113720643, 0.07262051915945612, 0.07101359703337454, 0.07515451174289246, 0.07163164400494437, 0.07459826946847961, 0.08084054388133498, 0.08609394313967862, 0.08281829419035847, 0.08646477132262052])


In [30]:
for iter_1 in iterations:
    for iter_2 in iterations:
        alpha_key = ('normal_A', iter_1, 'alpha')
        grad_key = ('normal_B', iter_2, 'grad')
        print(iter_1, iter_2, metric.eval_corr(dicts, alpha_key, grad_key))

0 0 {'name': 'top 5% match', 'correlation': 0.043819530284301605, 'baseline': 0.07027194066749073}
0 50 {'name': 'top 5% match', 'correlation': 0.07441285537700865, 'baseline': 0.07391841779975278}
0 100 {'name': 'top 5% match', 'correlation': 0.10185414091470951, 'baseline': 0.07527812113720643}
0 500 {'name': 'top 5% match', 'correlation': 0.1064894932014833, 'baseline': 0.0695920889987639}
0 1000 {'name': 'top 5% match', 'correlation': 0.08943139678615575, 'baseline': 0.06977750309023485}
0 1500 {'name': 'top 5% match', 'correlation': 0.0930778739184178, 'baseline': 0.07342398022249691}
0 2000 {'name': 'top 5% match', 'correlation': 0.09548825710754018, 'baseline': 0.0742274412855377}
0 4000 {'name': 'top 5% match', 'correlation': 0.08510506798516687, 'baseline': 0.07175525339925834}
0 6000 {'name': 'top 5% match', 'correlation': 0.07917181705809642, 'baseline': 0.07088998763906057}
0 8000 {'name': 'top 5% match', 'correlation': 0.08028430160692213, 'baseline': 0.0695920889987639}
0

1000 50 {'name': 'top 5% match', 'correlation': 0.04431396786155748, 'baseline': 0.07045735475896168}
1000 100 {'name': 'top 5% match', 'correlation': 0.08022249690976514, 'baseline': 0.06674907292954264}
1000 500 {'name': 'top 5% match', 'correlation': 0.1372682323856613, 'baseline': 0.06742892459826946}
1000 1000 {'name': 'top 5% match', 'correlation': 0.15889987639060568, 'baseline': 0.0723114956736712}
1000 1500 {'name': 'top 5% match', 'correlation': 0.1553770086526576, 'baseline': 0.06742892459826946}
1000 2000 {'name': 'top 5% match', 'correlation': 0.16056860321384425, 'baseline': 0.07138442521631644}
1000 4000 {'name': 'top 5% match', 'correlation': 0.15216316440049443, 'baseline': 0.07601977750309023}
1000 6000 {'name': 'top 5% match', 'correlation': 0.14351050679851668, 'baseline': 0.07243510506798517}
1000 8000 {'name': 'top 5% match', 'correlation': 0.14752781211372065, 'baseline': 0.07119901112484549}
1000 10000 {'name': 'top 5% match', 'correlation': 0.1622991347342398, 

6000 0 {'name': 'top 5% match', 'correlation': 0.06217552533992583, 'baseline': 0.06829419035846725}
6000 50 {'name': 'top 5% match', 'correlation': 0.052410383189122375, 'baseline': 0.07299134734239802}
6000 100 {'name': 'top 5% match', 'correlation': 0.09042027194066748, 'baseline': 0.07126081582200247}
6000 500 {'name': 'top 5% match', 'correlation': 0.1496291718170581, 'baseline': 0.07441285537700865}
6000 1000 {'name': 'top 5% match', 'correlation': 0.18170580964153277, 'baseline': 0.07379480840543881}
6000 1500 {'name': 'top 5% match', 'correlation': 0.18275648949320147, 'baseline': 0.07453646477132261}
6000 2000 {'name': 'top 5% match', 'correlation': 0.19635352286773794, 'baseline': 0.07243510506798517}
6000 4000 {'name': 'top 5% match', 'correlation': 0.19703337453646477, 'baseline': 0.06909765142150803}
6000 6000 {'name': 'top 5% match', 'correlation': 0.1900494437577256, 'baseline': 0.06854140914709518}
6000 8000 {'name': 'top 5% match', 'correlation': 0.2023485784919654, 'b

12000 36000 {'name': 'top 5% match', 'correlation': 0.17558714462299135, 'baseline': 0.06878862793572312}
14000 0 {'name': 'top 5% match', 'correlation': 0.07021013597033375, 'baseline': 0.06625463535228678}
14000 50 {'name': 'top 5% match', 'correlation': 0.0546353522867738, 'baseline': 0.07478368355995056}
14000 100 {'name': 'top 5% match', 'correlation': 0.08689740420271941, 'baseline': 0.07212608158220024}
14000 500 {'name': 'top 5% match', 'correlation': 0.14406674907292955, 'baseline': 0.07243510506798517}
14000 1000 {'name': 'top 5% match', 'correlation': 0.17762669962917182, 'baseline': 0.07360939431396786}
14000 1500 {'name': 'top 5% match', 'correlation': 0.18022249690976513, 'baseline': 0.06779975278121138}
14000 2000 {'name': 'top 5% match', 'correlation': 0.19449938195302843, 'baseline': 0.07126081582200247}
14000 4000 {'name': 'top 5% match', 'correlation': 0.19202719406674906, 'baseline': 0.07305315203955501}
14000 6000 {'name': 'top 5% match', 'correlation': 0.187762669

20000 32000 {'name': 'top 5% match', 'correlation': 0.17941903584672436, 'baseline': 0.06761433868974041}
20000 36000 {'name': 'top 5% match', 'correlation': 0.1780593325092707, 'baseline': 0.06749072929542645}
24000 0 {'name': 'top 5% match', 'correlation': 0.06866501854140915, 'baseline': 0.07175525339925834}
24000 50 {'name': 'top 5% match', 'correlation': 0.054573547589616814, 'baseline': 0.07255871446229914}
24000 100 {'name': 'top 5% match', 'correlation': 0.08609394313967862, 'baseline': 0.07447466007416563}
24000 500 {'name': 'top 5% match', 'correlation': 0.14295426452410384, 'baseline': 0.07021013597033375}
24000 1000 {'name': 'top 5% match', 'correlation': 0.1766378244746601, 'baseline': 0.07082818294190359}
24000 1500 {'name': 'top 5% match', 'correlation': 0.180778739184178, 'baseline': 0.0757725587144623}
24000 2000 {'name': 'top 5% match', 'correlation': 0.19177997527812113, 'baseline': 0.07262051915945612}
24000 4000 {'name': 'top 5% match', 'correlation': 0.19313967861

36000 28000 {'name': 'top 5% match', 'correlation': 0.1888751545117429, 'baseline': 0.07571075401730532}
36000 32000 {'name': 'top 5% match', 'correlation': 0.17941903584672436, 'baseline': 0.07169344870210136}
36000 36000 {'name': 'top 5% match', 'correlation': 0.17781211372064276, 'baseline': 0.06711990111248455}


In [15]:
best_acc = dat['metas'][('normal_A', normalA_iter, acc_metric)]
idx_unif = passing_idx(avg_corr, beta_corr_unif)
idx_grad = passing_idx(avg_corr, beta_corr_grad)
idx_px = passing_idx(avg_corr, beta_corr_px)
def print_perf(idx):
    if idx is None:
        return None
    else:
        return avg_perf[idx]
print(list(zip(iterations, avg_corr)))
print(idx_unif, idx_grad, idx_px)
print(beta_corr_unif, beta_corr_px, beta_corr_grad, baseline, print_perf(idx_unif), print_perf(idx_px), print_perf(idx_grad), best_acc)



[(0, 0.11613102595797281), (50, 0.11971569839307787), (100, 0.26180469715698396), (500, 0.39721878862793575), (1000, 0.6247836835599505), (1500, 0.7185414091470952), (2000, 0.761495673671199), (4000, 0.8174907292954264), (6000, 0.8368974042027194), (8000, 0.8673053152039555), (10000, 0.8686032138442522), (12000, 0.8628553770086527), (14000, 0.8634734239802225), (16000, 0.8716934487021014), (18000, 0.8724351050679852), (20000, 0.8719406674907293), (24000, 0.8822620519159456), (28000, 0.8751545117428925), (32000, 0.8810259579728059), (36000, 0.880407911001236)]
3 2 4
0.380840543881335 0.4254017305315204 0.17311495673671198 0.1284301606922126 6.913682490159368 11.428225554342985 1.9858724184501733 37.88620336527713


In [31]:
# metrics.append(SpearmanRankCorr()) comment out because this takes a lot of time to evaluate
key1, key2 = ('normal_A', normalA_iter, 'alpha'), ('normal_B', normalB_iter, 'beta')
print(key1, key2)
for metric in metrics:
    print(metric.eval_corr(dicts, key1, key2))

('normal_A', 98000, 'alpha') ('normal_B', 98000, 'beta')
{'name': 'top 1 match', 'correlation': 0.4906809735378769, 'baseline': 0.029517236240347983}
{'name': 'top 3 match', 'correlation': 0.630424455807975, 'baseline': 0.08793696034872998}
{'name': 'top 5 match', 'correlation': 0.6835303322435521, 'baseline': 0.1448337850802433}
{'name': 'top 5% match', 'correlation': 0.6016337403712707, 'baseline': 0.06422257616825476}


In [32]:
key1, key2 = ('normal_A', iterations[-1], 'alpha'), ('uniform', iterations[-1], 'beta')
print(key1, key2)
for metric in metrics:
    print(metric.eval_corr(dicts, key1, key2))

('normal_A', 98000, 'alpha') ('uniform', 98000, 'beta')
{'name': 'top 1 match', 'correlation': 0.26383882415402243, 'baseline': 0.031557083112117065}
{'name': 'top 3 match', 'correlation': 0.3738182393979192, 'baseline': 0.0881372192881959}
{'name': 'top 5 match', 'correlation': 0.43500433118171405, 'baseline': 0.14561619209955198}
{'name': 'top 5% match', 'correlation': 0.34707202801762277, 'baseline': 0.06557315971349}


In [33]:
key1, key2 = ('uniform', iterations[-1], 'beta'), 'embed256_beta'
print(key1, key2)
for metric in metrics:
    print(metric.eval_corr(dicts, key1, key2))

('uniform', 98000, 'beta') embed256_beta
{'name': 'top 1 match', 'correlation': 0.34697888432484797, 'baseline': 0.040349847710062316}
{'name': 'top 3 match', 'correlation': 0.41936084798017903, 'baseline': 0.09972895185402521}
{'name': 'top 5 match', 'correlation': 0.47952701632808936, 'baseline': 0.1557176255809838}
{'name': 'top 5% match', 'correlation': 0.39109173722301394, 'baseline': 0.07727200752601038}


In [34]:
key1, key2 = ('uniform', iterations[-1], 'beta'), 'IBM_beta'
print(key1, key2)
for metric in metrics:
    print(metric.eval_corr(dicts, key1, key2))

('uniform', 98000, 'beta') IBM_beta


KeyError: 'IBM_beta'