# Quantifying Interlingua Across Models

In [1]:
# Re-import edited local modules (e.g. util) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [2]:
import pickle
from collections import defaultdict

import torch
import numpy as np
import pandas as pd
from scipy import spatial, stats
from collections import defaultdict
from sklearn.metrics import f1_score, accuracy_score

from procrustes import orthogonal
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, AutoModelForSeq2SeqLM
from datasets import load_dataset, load_from_disk

from util import encode_batch

In [3]:
# Model family analysed in this notebook; it prefixes every pickle and figure filename below.
model_class: str = "mT5"

# Encode dataset with models

In [4]:
# !python encode_dataset_with_models.py {'mT5'}


# Solving the task

In [None]:
# Checkpoints of the model family to compare, from smallest to largest.
hf_model_ids = [f"google/mt5-{size}" for size in ("small", "base", "large", "xl", "xxl")]

# English is the source side of every pair below; the other languages are the targets.
langs = "en fr de et ru".split()

# Accuracy

In [6]:
# !python run_analysis.py mT5 acc
# !python run_analysis.py mT5 cka
# !python run_analysis.py mT5 acc-cent
# !python run_analysis.py mT5 acc-procrustes

In [None]:
%%time

mean_f1s = defaultdict(lambda: defaultdict(list))
mean_cosmatrix = defaultdict(lambda: defaultdict(list))

for hf_model_id in list(reversed(hf_model_ids)):
    print(f"\n\n{hf_model_id}")
    
    # load datasets for needed model
    dataset = {}
    for lang in langs:
        dataset[lang] = load_from_disk(f"../experiments/encoded_datasets/xnli/{hf_model_id.split('/')[-1]}/{lang}")

    
    src = dataset['en']
    num_layers = sum([n.startswith("mean") for n in src.column_names])
        
    for lang in langs:
        if lang == "en":
          continue
        
        print(f"\n pair: en-{lang}")

        tgt = dataset[lang]
        
        for l in range(num_layers):
            print(f"l{l}", end = ' ')
            d = compute_cosine_gpu(src[f'mean_{l}'], tgt[f'mean_{l}'])
            s = accuracy_score(list(range(len(d))), d.argmax(axis=1))
            mean_f1s[hf_model_id][lang].append(s)
            mean_cosmatrix[hf_model_id][lang].append(d)

print('\n\nFinished \n')

# save
scores_dfs = dict(mean_f1s)
scores_dfs = {k: dict(v) for k, v in scores_dfs.items()}

pickle.dump(scores_dfs, open(f"../experiments/encoded_datasets/xnli/{model_class}-accuracies-all_models.pkl", 'wb'))
print('saved accs')

cosmatrix_dfs = dict(mean_cosmatrix)
cosmatrix_dfs = {k: dict(v) for k, v in cosmatrix_dfs.items()}

# pickle.dump(cosmatrix_dfs, open(f"../experiments/encoded_datasets/xnli/cosine_matrices-all_models.pkl", 'wb'))
# print('saved cosines')

# Accuracy Centered

In [None]:
%%time

mean_f1s = defaultdict(lambda: defaultdict(list))
mean_cosmatrix = defaultdict(lambda: defaultdict(list))

for hf_model_id in list(reversed(hf_model_ids)):
    print(f"\n\n{hf_model_id}")
    
    # load datasets for needed model
    dataset = {}
    for lang in langs:
        dataset[lang] = load_from_disk(f"../experiments/encoded_datasets/xnli/{hf_model_id.split('/')[-1]}/{lang}")

    
    src = dataset['en']
    num_layers = sum([n.startswith("mean") for n in src.column_names])
        
    for lang in langs:
        if lang == "en":
          continue
        
        print(f"\n pair: en-{lang}")

        tgt = dataset[lang]
        
        for l in range(num_layers):
            d = compute_cosine_gpu(src[f'mean_{l}'], tgt[f'mean_{l}'], center=True)
            s = accuracy_score(list(range(len(d))), d.argmax(axis=1))
            mean_f1s[hf_model_id][f"en-{lang}"].append(s)
            mean_cosmatrix[hf_model_id][f"en-{lang}"].append(d)
            print(f"l{l}: {s}", end = ' ')
            
print('\n\nFinished \n')

# save
cent_scores_dfs = dict(mean_f1s)
cent_scores_dfs = {k: dict(v) for k, v in cent_scores_dfs.items()}

pickle.dump(cent_scores_dfs, open(f"../experiments/encoded_datasets/xnli/{model_class}-cent-accuracies-all_models.pkl", 'wb'))
print('saved centered accs')

cosmatrix_dfs = dict(mean_cosmatrix)
cosmatrix_dfs = {k: dict(v) for k, v in cosmatrix_dfs.items()}

# pickle.dump(cosmatrix_dfs, open(f"../experiments/encoded_datasets/xnli/cosine_matrices-all_models.pkl", 'wb'))
# print('saved cosines')

# Accuracy procrustes (centered + rotated)

In [None]:
%%time

mean_f1s = defaultdict(lambda: defaultdict(list))
# mean_cosmatrix = defaultdict(lambda: defaultdict(list))

for hf_model_id in list(reversed(hf_model_ids)):
    print(f"\n\n{hf_model_id}")
    
    # load datasets for needed model
    dataset = {}
    for lang in langs:
        dataset[lang] = load_from_disk(f"../experiments/encoded_datasets/xnli/{hf_model_id.split('/')[-1]}/{lang}")

    
    src = dataset['en']
    num_layers = sum([n.startswith("mean") for n in src.column_names])
        
    for lang in langs:
        if lang == "en":
          continue
        
        print(f"\n pair: en-{lang}")

        tgt = dataset[lang]
        
        for l in range(num_layers):
            d = compute_cosine_gpu(src[f'mean_{l}'], tgt[f'mean_{l}'], procrustes=True)
            s = accuracy_score(list(range(len(d))), d.argmax(axis=1))
            mean_f1s[hf_model_id][f"en-{lang}"].append(s)
            # mean_cosmatrix[hf_model_id][f"en-{lang}"].append(d)
            print(f"l{l}: {s}", end = ' ')
            
print('\n\nFinished \n')

# save
procr_scores_dfs = dict(mean_f1s)
procr_scores_dfs = {k: dict(v) for k, v in procr_scores_dfs.items()}

pickle.dump(procr_scores_dfs, open(f"../experiments/encoded_datasets/xnli/{model_class}-procrustes-accuracies-all_models.pkl", 'wb'))
print('saved procrustes accs')

# cosmatrix_dfs = dict(mean_cosmatrix)
# cosmatrix_dfs = {k: dict(v) for k, v in cosmatrix_dfs.items()}

# pickle.dump(cosmatrix_dfs, open(f"../experiments/encoded_datasets/xnli/cosine_matrices-all_models.pkl", 'wb'))
# print('saved cosines')

facebook/xlm-roberta-xxl

 pair: en-fr
l0: 0.8974 

# CKA

In [None]:
from ecco import analysis

In [None]:
%%time

ckas = defaultdict(lambda: defaultdict(list))

for hf_model_id in list(reversed(hf_model_ids)):
    print(f"\n\n{hf_model_id}")
    
    # load datasets for needed model
    dataset = {}
    for lang in langs:
        dataset[lang] = load_from_disk(f"../experiments/encoded_datasets/xnli/{hf_model_id.split('/')[-1]}/{lang}")

    
    src = dataset['en']
    num_layers = sum([n.startswith("mean") for n in src.column_names])
        
    for lang in langs:
        if lang == "en":
          continue
        
        print(f"\n pair: en-{lang}")

        tgt = dataset[lang]
        
        for l in range(num_layers):
            print(f"l{l}", end = ' ')
            s = analysis.cka(np.array(src[f'mean_{l}']).T, np.array(tgt[f'mean_{l}']).T)
            ckas[hf_model_id][f"en-{lang}"].append(s)

print('\n\nFinished \n')

# save
cka_dfs = dict(ckas)
cka_dfs = {k: dict(v) for k, v in cka_dfs.items()}

pickle.dump(cka_dfs, open(f"../experiments/encoded_datasets/xnli/{model_class}-cka-all_models.pkl", 'wb'))
print('saved ckas')

# pickle.dump(cosmatrix_dfs, open(f"../experiments/encoded_datasets/xnli/cosine_matrices-all_models.pkl", 'wb'))
# print('saved cosines')

# Plot

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

## Accuracy

In [None]:
# ACCURACY
# Plot the layer-wise en→X retrieval accuracies saved by the "Accuracy" cell: one panel
# per language pair, all pairs pooled (with and without a tick per layer), and all pairs
# pooled against relative network depth.

acc_path = f"../experiments/encoded_datasets/xnli/{model_class}-accuracies-all_models.pkl"
# NOTE: pickle.load can execute arbitrary code -- only load files produced by this notebook.
with open(acc_path, 'rb') as f:
    acc_by_model = pickle.load(f)  # {model_id: {lang or pair: [accuracy per layer]}}

# Long format: one row per (model, pair, layer); a value's list position is its layer index.
scores_dfs = pd.concat(
    [
        pd.DataFrame(per_pair)
        .melt(var_name="pair", value_name="accuracy", ignore_index=False)
        .rename_axis("layer")
        .reset_index()
        .assign(model=model_id.split('/')[-1])
        for model_id, per_pair in acc_by_model.items()
    ],
    ignore_index=True,
)

# The "Accuracy" compute cell keys results by bare target language ('fr'), whereas the
# centered / Procrustes cells key them by pair ('en-fr'). Prefix only when missing so
# either format yields 'en-fr' (never 'en-en-fr').
pair_str = scores_dfs['pair'].astype(str)
scores_dfs['pair'] = pair_str.where(pair_str.str.startswith("en-"), "en-" + pair_str)

sns.set(font_scale=1.5)
sns.set_style("ticks")

# One panel per language pair, with a single shared legend above the grid.
g = sns.relplot(data=scores_dfs,
                x="layer",
                y="accuracy",
                hue="model",
                style="model",
                markers=True,
                col="pair",
                kind="line",
                linewidth=3.5,
                markersize=12,
                facet_kws={"legend_out": True})
g.set(ylim=(0, 1))
g._legend.remove()
g.figure.legend(ncol=4).set_bbox_to_anchor([0.67, 1.07])
g.figure.savefig(f"../assets/{model_class}-acc-models-four_langs.pdf", dpi=300, bbox_inches='tight')

# All pairs pooled into one line per model (seaborn's default estimator averages over
# pairs). The "-xticks" variant labels every layer; the tick range comes from the data
# instead of a hard-coded layer count.
max_layer = int(scores_dfs['layer'].max())
for suffix, tick_every_layer in [("langs_joined", False), ("langs_joined-xticks", True)]:
    fig, ax = plt.subplots(figsize=(11, 4))
    g = sns.lineplot(data=scores_dfs,
                     x="layer",
                     y="accuracy",
                     hue="model",
                     style="model",
                     markers=True,
                     linewidth=3.5,
                     markersize=12,
                     ax=ax)
    if tick_every_layer:
        ax.set_xticks(range(max_layer + 1))
        ax.tick_params(axis='x', rotation=45, labelsize=9)
    ax.set_ylim(0, 1)
    lgd = ax.legend(prop={'size': 12})
    fig.savefig(f"../assets/{model_class}-acc-models-{suffix}.pdf", dpi=300, bbox_extra_artists=(lgd,), bbox_inches='tight')

# Relative depth: divide each layer index by that model's deepest layer, so models of
# different depth share a [0, 1] axis. This is computed as a new float column rather than
# written into the int 'layer' column via .loc (silent upcasting is deprecated in pandas >= 2.1).
scores_dfs_rel = scores_dfs.assign(
    layer=scores_dfs['layer'] / scores_dfs.groupby('model')['layer'].transform('max')
).rename(columns={"layer": "network depth"})

fig, ax = plt.subplots(figsize=(11, 4))
g = sns.lineplot(data=scores_dfs_rel,
                 x="network depth",
                 y="accuracy",
                 hue="model",
                 style="model",
                 markers=True,
                 linewidth=3.5,
                 markersize=12,
                 ax=ax)
ax.set_ylim(0, 1)
lgd = ax.legend(prop={'size': 12})
fig.savefig(f"../assets/{model_class}-acc-models-langs_joined-rel.pdf", dpi=300, bbox_extra_artists=(lgd,), bbox_inches='tight')


## Centered Accuracy

In [None]:
# CENTERED ACCURACY
# Plot the layer-wise en→X retrieval accuracies saved by the "Accuracy Centered" cell:
# one panel per language pair, all pairs pooled (with and without a tick per layer), and
# all pairs pooled against relative network depth.

acc_path = f"../experiments/encoded_datasets/xnli/{model_class}-cent-accuracies-all_models.pkl"
# NOTE: pickle.load can execute arbitrary code -- only load files produced by this notebook.
with open(acc_path, 'rb') as f:
    acc_by_model = pickle.load(f)  # {model_id: {pair: [accuracy per layer]}}

# Long format: one row per (model, pair, layer); a value's list position is its layer index.
scores_dfs = pd.concat(
    [
        pd.DataFrame(per_pair)
        .melt(var_name="pair", value_name="accuracy", ignore_index=False)
        .rename_axis("layer")
        .reset_index()
        .assign(model=model_id.split('/')[-1])
        for model_id, per_pair in acc_by_model.items()
    ],
    ignore_index=True,
)

# The centered compute cell already keys results by pair ('en-fr'); unconditionally
# prefixing 'en-' would produce 'en-en-fr'. Prefix only when missing, so bare
# target-language keys ('fr') are handled as well.
pair_str = scores_dfs['pair'].astype(str)
scores_dfs['pair'] = pair_str.where(pair_str.str.startswith("en-"), "en-" + pair_str)

sns.set(font_scale=1.5)
sns.set_style("ticks")

# One panel per language pair, with a single shared legend above the grid.
g = sns.relplot(data=scores_dfs,
                x="layer",
                y="accuracy",
                hue="model",
                style="model",
                markers=True,
                col="pair",
                kind="line",
                linewidth=3.5,
                markersize=12,
                facet_kws={"legend_out": True})
g.set(ylim=(0, 1))
g._legend.remove()
g.figure.legend(ncol=4).set_bbox_to_anchor([0.67, 1.07])
g.figure.savefig(f"../assets/{model_class}-cent-acc-models-four_langs.pdf", dpi=300, bbox_inches='tight')

# All pairs pooled into one line per model (seaborn's default estimator averages over
# pairs). The "-xticks" variant labels every layer; the tick range comes from the data
# instead of a hard-coded layer count.
max_layer = int(scores_dfs['layer'].max())
for suffix, tick_every_layer in [("langs_joined", False), ("langs_joined-xticks", True)]:
    fig, ax = plt.subplots(figsize=(11, 4))
    g = sns.lineplot(data=scores_dfs,
                     x="layer",
                     y="accuracy",
                     hue="model",
                     style="model",
                     markers=True,
                     linewidth=3.5,
                     markersize=12,
                     ax=ax)
    if tick_every_layer:
        ax.set_xticks(range(max_layer + 1))
        ax.tick_params(axis='x', rotation=45, labelsize=9)
    ax.set_ylim(0, 1)
    lgd = ax.legend(prop={'size': 12})
    fig.savefig(f"../assets/{model_class}-cent-acc-models-{suffix}.pdf", dpi=300, bbox_extra_artists=(lgd,), bbox_inches='tight')

# Relative depth: divide each layer index by that model's deepest layer, so models of
# different depth share a [0, 1] axis. This is computed as a new float column rather than
# written into the int 'layer' column via .loc (silent upcasting is deprecated in pandas >= 2.1).
scores_dfs_rel = scores_dfs.assign(
    layer=scores_dfs['layer'] / scores_dfs.groupby('model')['layer'].transform('max')
).rename(columns={"layer": "network depth"})

fig, ax = plt.subplots(figsize=(11, 4))
g = sns.lineplot(data=scores_dfs_rel,
                 x="network depth",
                 y="accuracy",
                 hue="model",
                 style="model",
                 markers=True,
                 linewidth=3.5,
                 markersize=12,
                 ax=ax)
ax.set_ylim(0, 1)
lgd = ax.legend(prop={'size': 12})
fig.savefig(f"../assets/{model_class}-cent-acc-models-langs_joined-rel.pdf", dpi=300, bbox_extra_artists=(lgd,), bbox_inches='tight')


## Rotated Accuracy (Procrustes: centered + rotated)

In [None]:
# PROCRUSTES (CENTERED + ROTATED) ACCURACY
# Plot the layer-wise en→X retrieval accuracies saved by the Procrustes cell: one panel
# per language pair, all pairs pooled (with and without a tick per layer), and all pairs
# pooled against relative network depth.

acc_path = f"../experiments/encoded_datasets/xnli/{model_class}-procrustes-accuracies-all_models.pkl"
# NOTE: pickle.load can execute arbitrary code -- only load files produced by this notebook.
with open(acc_path, 'rb') as f:
    acc_by_model = pickle.load(f)  # {model_id: {pair: [accuracy per layer]}}

# Long format: one row per (model, pair, layer); a value's list position is its layer index.
scores_dfs = pd.concat(
    [
        pd.DataFrame(per_pair)
        .melt(var_name="pair", value_name="accuracy", ignore_index=False)
        .rename_axis("layer")
        .reset_index()
        .assign(model=model_id.split('/')[-1])
        for model_id, per_pair in acc_by_model.items()
    ],
    ignore_index=True,
)

# The Procrustes compute cell already keys results by pair ('en-fr'); unconditionally
# prefixing 'en-' would produce 'en-en-fr'. Prefix only when missing, so bare
# target-language keys ('fr') are handled as well.
pair_str = scores_dfs['pair'].astype(str)
scores_dfs['pair'] = pair_str.where(pair_str.str.startswith("en-"), "en-" + pair_str)

sns.set(font_scale=1.5)
sns.set_style("ticks")

# One panel per language pair, with a single shared legend above the grid.
g = sns.relplot(data=scores_dfs,
                x="layer",
                y="accuracy",
                hue="model",
                style="model",
                markers=True,
                col="pair",
                kind="line",
                linewidth=3.5,
                markersize=12,
                facet_kws={"legend_out": True})
g.set(ylim=(0, 1))
g._legend.remove()
g.figure.legend(ncol=4).set_bbox_to_anchor([0.67, 1.07])
g.figure.savefig(f"../assets/{model_class}-procr-acc-models-four_langs.pdf", dpi=300, bbox_inches='tight')

# All pairs pooled into one line per model (seaborn's default estimator averages over
# pairs). The "-xticks" variant labels every layer; the tick range comes from the data
# instead of a hard-coded layer count.
max_layer = int(scores_dfs['layer'].max())
for suffix, tick_every_layer in [("langs_joined", False), ("langs_joined-xticks", True)]:
    fig, ax = plt.subplots(figsize=(11, 4))
    g = sns.lineplot(data=scores_dfs,
                     x="layer",
                     y="accuracy",
                     hue="model",
                     style="model",
                     markers=True,
                     linewidth=3.5,
                     markersize=12,
                     ax=ax)
    if tick_every_layer:
        ax.set_xticks(range(max_layer + 1))
        ax.tick_params(axis='x', rotation=45, labelsize=9)
    ax.set_ylim(0, 1)
    lgd = ax.legend(prop={'size': 12})
    fig.savefig(f"../assets/{model_class}-procr-acc-models-{suffix}.pdf", dpi=300, bbox_extra_artists=(lgd,), bbox_inches='tight')

# Relative depth: divide each layer index by that model's deepest layer, so models of
# different depth share a [0, 1] axis. This is computed as a new float column rather than
# written into the int 'layer' column via .loc (silent upcasting is deprecated in pandas >= 2.1).
scores_dfs_rel = scores_dfs.assign(
    layer=scores_dfs['layer'] / scores_dfs.groupby('model')['layer'].transform('max')
).rename(columns={"layer": "network depth"})

fig, ax = plt.subplots(figsize=(11, 4))
g = sns.lineplot(data=scores_dfs_rel,
                 x="network depth",
                 y="accuracy",
                 hue="model",
                 style="model",
                 markers=True,
                 linewidth=3.5,
                 markersize=12,
                 ax=ax)
ax.set_ylim(0, 1)
lgd = ax.legend(prop={'size': 12})
fig.savefig(f"../assets/{model_class}-procr-acc-models-langs_joined-rel.pdf", dpi=300, bbox_extra_artists=(lgd,), bbox_inches='tight')


## CKA

In [None]:
# CKA
# Plot the layer-wise en→X CKA scores saved by the CKA cell: one panel per language pair,
# all pairs pooled (with and without a tick per layer), and all pairs pooled against
# relative network depth.

cka_path = f"../experiments/encoded_datasets/xnli/{model_class}-cka-all_models.pkl"
# NOTE: pickle.load can execute arbitrary code -- only load files produced by this notebook.
with open(cka_path, 'rb') as f:
    cka_by_model = pickle.load(f)  # {model_id: {pair: [CKA score per layer]}}

# Long format: one row per (model, pair, layer); a value's list position is its layer index.
cka_dfs = pd.concat(
    [
        pd.DataFrame(per_pair)
        .melt(var_name="pair", value_name="CKA score", ignore_index=False)
        .rename_axis("layer")
        .reset_index()
        .assign(model=model_id.split('/')[-1])
        for model_id, per_pair in cka_by_model.items()
    ],
    ignore_index=True,
)

sns.set(font_scale=1.5)
sns.set_style("ticks")

# One panel per language pair, with a single shared legend above the grid.
g = sns.relplot(data=cka_dfs,
                x="layer",
                y="CKA score",
                hue="model",
                style="model",
                markers=True,
                col="pair",
                kind="line",
                linewidth=3.5,
                markersize=12,
                facet_kws={"legend_out": True})
g.set(ylim=(0, 1))
g._legend.remove()
g.figure.legend(ncol=4).set_bbox_to_anchor([0.67, 1.07])
g.figure.savefig(f"../assets/{model_class}-cka-models-four_langs.pdf", dpi=300, bbox_inches='tight')

# All pairs pooled into one line per model (seaborn's default estimator averages over
# pairs). The "-xticks" variant labels every layer; the tick range comes from the data
# instead of a hard-coded layer count.
max_layer = int(cka_dfs['layer'].max())
for suffix, tick_every_layer in [("langs_joined", False), ("langs_joined-xticks", True)]:
    fig, ax = plt.subplots(figsize=(11, 4))
    g = sns.lineplot(data=cka_dfs,
                     x="layer",
                     y="CKA score",
                     hue="model",
                     style="model",
                     markers=True,
                     linewidth=3.5,
                     markersize=12,
                     ax=ax)
    if tick_every_layer:
        ax.set_xticks(range(max_layer + 1))
        ax.tick_params(axis='x', rotation=45, labelsize=9)
    ax.set_ylim(0, 1)
    lgd = ax.legend(prop={'size': 12})
    fig.savefig(f"../assets/{model_class}-cka-models-{suffix}.pdf", dpi=300, bbox_extra_artists=(lgd,), bbox_inches='tight')

# Relative depth: divide each layer index by that model's deepest layer, so models of
# different depth share a [0, 1] axis. This is computed as a new float column rather than
# written into the int 'layer' column via .loc (silent upcasting is deprecated in pandas >= 2.1).
cka_dfs_rel = cka_dfs.assign(
    layer=cka_dfs['layer'] / cka_dfs.groupby('model')['layer'].transform('max')
).rename(columns={"layer": "network depth"})

fig, ax = plt.subplots(figsize=(11, 4))
g = sns.lineplot(data=cka_dfs_rel,
                 x="network depth",
                 y="CKA score",
                 hue="model",
                 style="model",
                 markers=True,
                 linewidth=3.5,
                 markersize=12,
                 ax=ax)
ax.set_ylim(0, 1)
lgd = ax.legend(prop={'size': 12})
fig.savefig(f"../assets/{model_class}-cka-models-langs_joined-rel.pdf", dpi=300, bbox_extra_artists=(lgd,), bbox_inches='tight')
