In [1]:
# Move to the repository root so that the local `data` package and the
# relative `results/...` paths used below resolve from here.
%cd ../../

/home/gergopool/work/uva/atcs/Language-Specific-Subnetworks


In [2]:
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
import os
import torch
import glob
import numpy as np
from functools import partial
from IPython.display import Latex

from data import ALLOWED_LANGUAGES, ALLOWED_DATASETS
# Drop wikiann from the dataset list iterated by the analysis below.
# NOTE: `.remove` mutates the object imported from `data` in place, so the
# change is also visible to anything else that reads `data.ALLOWED_DATASETS`.
if 'wikiann' in ALLOWED_DATASETS:
    ALLOWED_DATASETS.remove('wikiann')

%matplotlib inline

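## Stability

How similar are the pruned subnetworks obtained with different random seeds for the same task and language? For each pair of consecutive seeds we compare the pruning masks (Jaccard), the representations (CKA), and the stitching accuracy relative to the end model's accuracy (RA).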

In [3]:
def compare(compare_fn, n_seeds=5, tasks=None, langs=None):
    """Aggregate a pairwise similarity metric over tasks, languages and seeds.

    For every task/language combination, each seed ``s`` in ``range(n_seeds)``
    is compared with the next seed ``(s + 1) % n_seeds``. The seeds form a
    ring, which gives ``n_seeds`` pairs per task/language.

    Args:
        compare_fn: callable ``(task, lang, seed1, seed2) -> float | None``.
            A ``None`` result means "no value for this pair", and the pair
            is skipped.
        n_seeds: number of seeds per task/language (default 5). Use at least
            2; with 1 the only "pair" is a seed compared with itself.
        tasks: tasks to iterate over; defaults to ``ALLOWED_DATASETS``.
        langs: languages to iterate over; defaults to ``ALLOWED_LANGUAGES``.

    Returns:
        ``(mean, std)`` of all non-``None`` values, or ``(nan, nan)`` when
        every pair was skipped.
    """
    # Resolve the defaults at call time, so edits to the global lists made
    # after this definition (e.g. dropping a dataset) are honoured.
    if tasks is None:
        tasks = ALLOWED_DATASETS
    if langs is None:
        langs = ALLOWED_LANGUAGES

    sim_values = []
    for task in tasks:
        for lang in langs:
            for seed1 in range(n_seeds):
                seed2 = (seed1 + 1) % n_seeds
                value = compare_fn(task, lang, seed1, seed2)
                if value is not None:
                    sim_values.append(value)

    if not sim_values:
        # np.mean/np.std of an empty list also give nan, but emit RuntimeWarnings.
        return float('nan'), float('nan')
    return np.mean(sim_values), np.std(sim_values)

def compare_jaccard(task, lang, seed1, seed2):
    """Jaccard similarity ``|A & B| / |A | B|`` of the pruning masks of two seeds.

    Loads ``results/pruned_masks/{task}/{lang}_{seed}.pkl`` for both seeds,
    casts them to boolean masks and returns their overlap as a Python float.

    Returns ``None`` when both masks are empty (the union is zero). The ratio
    is undefined there, and a NaN would make the aggregate mean in ``compare``
    NaN, whereas ``None`` values are skipped.
    """
    mask1 = torch.load(f"results/pruned_masks/{task}/{lang}_{seed1}.pkl").bool()
    mask2 = torch.load(f"results/pruned_masks/{task}/{lang}_{seed2}.pkl").bool()
    union = (mask1 | mask2).sum()
    if union == 0:
        return None
    sim = (mask1 & mask2).sum() / union
    return sim.item()

def compare_cka(task, lang, seed1, seed2):
    """Mean diagonal CKA between the ``seed1`` and ``seed2`` models of one task/language.

    Reads the precomputed CKA matrix from ``results/cka/across_seeds/`` and
    returns the mean of its diagonal as a Python float. Returns ``None``
    when that mean is NaN, so ``compare`` leaves the pair out.
    """
    pair_name = f"{task}_{lang}_{seed1}_{task}_{lang}_{seed2}"
    cka_matrix = torch.load(f"results/cka/across_seeds/{pair_name}.pkl")
    mean_diag = cka_matrix.diag().mean()
    if torch.isnan(mean_diag):
        return None
    return mean_diag.item()

def _compare_stitch(task, lang, seed1, seed2, df=None):
    w1 = df.front_seed == seed1
    w2 = df.end_seed = seed2
    w3 = df.front_lang == lang
    w4 = df.front_model == task
    x = df.loc[w1&w2&w3&w4]
    if len(x):
        return min(x['sim_acc'].mean(), 1)
    
# Stitching results: `sim_acc` is the stitched model's accuracy divided by
# the accuracy of the model providing the end part (`end_acc`).
stitch_df = pd.read_csv('results/stitch/stitch_across_seeds.csv')
stitch_df = stitch_df.assign(sim_acc=stitch_df['stitch_acc'] / stitch_df['end_acc'])
compare_stitch = partial(_compare_stitch, df=stitch_df)

# Each summary is a (mean, std) tuple over all task/language/seed pairs.
jaccard, cka, stitch = [
    compare(metric_fn)
    for metric_fn in (compare_jaccard, compare_cka, compare_stitch)
]

In [4]:
caption = "Average of metrics over 5 pairs of seeds per language, per task."
label = "stability"


def _mean_std_cell(stat):
    r"""Format a ``(mean, std)`` tuple as a LaTeX math cell, e.g. ``$0.50 \pm(0.14)$``."""
    mean, std = stat
    # Raw f-string: `\pm` in a normal string is an invalid escape sequence.
    return rf"${mean:.2f} \pm({std:.2f})$"


# Render only the tabular body; the table environment is built by hand so
# the caption and a single label come after the tabular.
tabular = pd.DataFrame({
    "Jaccard": [_mean_std_cell(jaccard)],
    "CKA": [_mean_std_cell(cka)],
    "RA": [_mean_std_cell(stitch)],
}).to_latex(escape=False, column_format='ccc', index=False)

table = "\n".join([
    r"\begin{table}",
    r"\centering",
    tabular.rstrip("\n"),
    rf"\caption{{{caption}}}",
    rf"\label{{table:{label}}}",  # after \caption so \ref gets the table number
    r"\end{table}",
])
print(table)

\begin{table}
\centering
\label{stability}
\begin{tabular}{ccc}
\toprule
         Jaccard &              CKA &               RA \\
\midrule
$0.50 \pm(0.14)$ & $0.72 \pm(0.08)$ & $0.99 \pm(0.01)$ \\
\bottomrule
\end{tabular}
\caption{Average of metrics over 5 pairs of seeds per language, per task.}
\label{table:stability}
\end{table}



  table = pd.DataFrame({
