In [None]:
import pandas as pd
import numpy as np
import scipy.stats
import itertools
import os, re
from tqdm import tqdm
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()

In [None]:
# Select the analysis files to load from the current working directory.
# Keeps names containing '.feather' and the mBERT identifier, and drops control runs.
files = [
    name
    for name in os.listdir()
    if '.feather' in name
    and 'control' not in name
    and 'bert-base-multilingual-cased' in name
]

In [None]:
# Read every selected feather file into a DataFrame keyed by its filename,
# using the 'Unnamed: 0' column as the index.
data = {}
for f in tqdm(files):
    if '.feather' in f and 'control' not in f:
        data[f] = pd.read_feather(f).set_index('Unnamed: 0')

In [None]:
# Add grammaticality scores, probability ratios and effect sizes to every loaded frame (in place).
for f, df in tqdm(data.items()):
    # Candidate probability ratios taken from the *_base_prob and *_alt1_prob columns.
    singular = df['candidate2_base_prob'] / df['candidate1_base_prob']
    plural = df['candidate1_alt1_prob'] / df['candidate2_alt1_prob']
    df['singular_grammaticality'] = singular
    df['plural_grammaticality'] = plural
    df['inv_singular_grammaticality'] = 1 / df['singular_grammaticality']
    df['inv_plural_grammaticality'] = 1 / df['plural_grammaticality']
    # Candidate ratio from the plain candidate*_prob columns.
    df['yz'] = df['candidate2_prob'] / df['candidate1_prob']
    # Relative change of that ratio with respect to the singular (base) ratio.
    df['effect'] = df['yz'] / df['singular_grammaticality'] - 1
    df['total_effect'] = 1 / (df['plural_grammaticality'] * df['singular_grammaticality']) - 1

In [None]:
# Locate top 5% of neurons per layer and top 30 neurons in the entire model by indirect effect.
TOP_N_NEURONS = 30    # size of the model-wide top-neuron set

# Legacy scaling constant: the top-30 cut-off used to be derived from it indirectly
# (int(767 * 31 / 768) == 30, i.e. only correct for a hidden size of 768).
# It is no longer used in this cell and is kept only in case other cells reference it.
TOP30_CONST = 31 / (768 * 12)

data_agg  = {}    # filename -> per-(layer, neuron) mean/std/sem of every column in `cols`
data_top5 = {}    # filename -> index of the top 5% of neurons in each layer
data_top30n = {}  # filename -> index of the TOP_N_NEURONS neurons with the largest mean effect
cols = [
    'neuron', 'layer',
    'effect',
    'singular_grammaticality', 'plural_grammaticality',
    'inv_singular_grammaticality', 'inv_plural_grammaticality',
    'total_effect',
]
for f, df in tqdm(data.items()):
    agg = (
        df[cols].groupby(['layer', 'neuron'])
        .agg(['mean', 'std', 'sem']))
    # Flatten the (column, statistic) MultiIndex into names such as 'effect_mean'.
    agg.columns = ['_'.join(col) for col in agg.columns]
    data_agg[f] = agg

    # Ascending sort, so the largest mean effects are at the tail.
    ranked = agg.sort_values('effect_mean')

    # 5% cut-off per layer. NOTE: this is based on the neuron component of the largest
    # (layer, neuron) index entry, i.e. the highest neuron index rather than the neuron count.
    top_per_layer = int(agg.index.max()[1] * 0.05)
    data_top5[f] = (
        ranked.groupby('layer')     # Get layers
        .tail(top_per_layer)        # Take top 5%
        .index)                     # Get indices of top 5% by layer

    # Model-wide top neurons: take exactly TOP_N_NEURONS rows, independent of the hidden size.
    data_top30n[f] = ranked.tail(TOP_N_NEURONS).index

In [None]:
# (OPTIONAL) Save top neurons in (layer, neuron) format in a .csv
# Only files whose name contains "indirect" are exported.
indirect_files = [_file for _file in data if "indirect" in _file]
# Base names without directory or extension, used to build the output filenames.
file_list = [os.path.splitext(os.path.basename(_file))[0] for _file in indirect_files]
for _file, base_file_noext in zip(indirect_files, file_list):
    # "w" mode: the file is created (or overwritten) rather than opened for reading.
    with open(f"top_neurons_01_{base_file_noext}.csv", "w") as neurons_csv:
        # data_top5[_file] is an index of (layer, neuron) pairs; write one pair per line.
        for layer, neuron in data_top5[_file]:
            neurons_csv.write(f"{layer},{neuron}\n")

In [None]:
import pickle
from collections import defaultdict

# Seaborn font scaling for the figures drawn in this cell.
sns.set(font_scale=1.3)

# Display names for model identifiers. Not referenced by the code visible in this cell.
model_code_to_name = {
    'distilgpt2': 'DistilGPT-2',
    'gpt2': 'GPT-2 Small',
    'gpt2-medium': 'GPT-2 Medium',
    'xlnet-base-cased': 'XLNet',
    'transfo-xl-wt103': 'Transformer-XL',
    'gpt2-random': 'GPT-2 Small (Random)',
    'bert-base-multilingual-cased': 'Multilingual BERT',
    'facebook/xglm-564M': 'XGLM',
}

# Model identifiers interpolated into the result filenames below.
model_en = 'bert-base-cased'
model_de = 'bert-base-german-cased'
model_fr = 'camembert-base'
model_nl = 'bert-base-dutch-cased'
model_fi = 'bert-base-finnish-cased-v1'
# Each entry is (feather filename, colour, structure label, linestyle). The overlap code
# in this cell only reads the filename and the label; colour and linestyle are unpacked
# but unused there.
# Commented-out entries are alternatives that can be toggled in. Those interpolating
# `model_m` or `model` need that name to be defined first (this cell does not define it).
lines = [
    #(f'bigram_indirect_bert-base-multilingual-cased.feather',         'green',    'Bigram (en)', '-'),
    #(f'bigram_shuffle_en_indirect_{model_en}_natural.feather',         'green',    'Bigram (en, shuffle)', '-'),
    #(f'semantic_short_en_indirect_{model_en}_natural.feather',         'green',    'Semantic (en, short)', '-'),
    #(f'semantic_long_en_indirect_{model_en}_natural.feather',         'green',    'Semantic (en, long)', '-'),
    #(f'semantic_short_en_indirect_{model_m}_controlled.feather',         'green',    'Semantic (en, short)', '-'),
    #(f'semantic_long_en_indirect_{model_m}_controlled.feather',         'green',    'Semantic (en, long)', '-'),
    #(f'random_none_indirect_bert-base-multilingual-cased.feather', 'green', 'Simple (en, random init.)', '-'),
    #(f'none_en_short_indirect_{model_m}_natural.feather', 'green', 'Simple (en)', '-'),
    (f'none_indirect_{model_en}.feather', 'green', 'Simple (en)', '-'),
    #(f'none_indirect_{model}_short.feather',         'green',    'Simple (en, short)', '--'),
    #(f'none_nl_indirect_{model_nl}.feather', 'green', 'Simple (nl)', '-.'),
    #(f'none_fi_indirect_{model_fi}.feather', 'green', 'Simple (fi)', ':'),
    #(f'rc_singular_en_short_indirect_{model_m}_natural.feather', 'orange', 'Singular RC (en)', '-'),
    (f'rc_singular_indirect_{model_en}.feather', 'orange', 'Across singular RC (en)', '-'),
    #(f'rc_singular_indirect_{model}_short.feather', 'orange', 'Singular RC (en, short)', '--'),
    #(f'rc_singular_nl_indirect_{model_nl}.feather', 'orange', 'Singular RC (nl)', '-.'),
    #(f'rc_singular_fi_indirect_{model_fi}.feather', 'orange', 'Singular RC (fi)', ':'),
    #(f'rc_plural_en_short_indirect_{model_m}_natural.feather', 'red', 'Plural RC (en)', '-'),
    (f'rc_plural_indirect_{model_en}.feather', 'red', 'Across plural RC (en)', '--'),
    #(f'rc_plural_nl_indirect_{model_nl}.feather', 'red', 'Plural RC (nl)', '-.'),
    #(f'rc_plural_fi_indirect_{model_fi}.feather', 'red', 'Plural RC (fi)', ':'),
    #(f'prep_singular_en_short_indirect_{model_m}_natural.feather', 'purple', 'Across singular prep (en)', '-'),
    (f'prep_singular_indirect_{model_en}.feather', 'purple', 'Across singular prep (en)', '--'),
    #(f'prep_singular_nl_indirect_{model_nl}.feather', 'purple', 'Across singular prep (nl)', '-.'),
    #(f'prep_singular_fi_indirect_{model_fi}.feather', 'purple', 'Across singular prep (fi)', ':'),
    #(f'prep_plural_en_short_indirect_{model_m}_natural.feather', 'blue', 'Across plural prep (en)', '--'),
    (f'prep_plural_indirect_{model_en}.feather', 'blue', 'Across plural prep (en)', '--'),
    #(f'prep_plural_nl_indirect_{model_nl}.feather', 'blue', 'Across plural prep (nl)', '-.'),
    #(f'prep_plural_fi_indirect_{model_fi}.feather', 'blue', 'Across plural prep (fi)', ':'),
    #(f'none_de_short_indirect_{model_m}_natural.feather', 'green', 'Simple (de)', '--'),
    #(f'rc_singular_de_short_indirect_{model_m}_natural.feather', 'orange', 'Singular RC (de)', '--'),
    #(f'rc_plural_de_short_indirect_{model_m}_natural.feather', 'red', 'Plural RC (de)', '--'),
    #(f'prep_singular_de_short_indirect_{model_m}_natural.feather', 'orange', 'Across singular prep (de)', '--'),
    #(f'prep_plural_de_short_indirect_{model_m}_natural.feather', 'red', 'Across plural prep (de)', '--'),
    #(f'none_fr_indirect_{model_m}.feather', 'green', 'Simple (fr)', '--'),
    #(f'rc_singular_fr_indirect_{model_m}.feather', 'orange', 'Singular RC (fr)', '--'),
    #(f'rc_plural_fr_indirect_{model_m}.feather', 'red', 'Plural RC (fr)', '--'),
    #(f'prep_singular_fr_indirect_{model_m}.feather', 'purple', 'Across singular prep (fr)', '--'),
    #(f'prep_plural_fr_indirect_{model_m}.feather', 'blue', 'Across plural prep (fr)', '--'),
]


# Gather, per structure label, the "layer-neuron" ids of its top-30 neurons
# (overall and split by layer), and track the highest layer number seen.
num_layers = 0
neurons_by_layer = defaultdict(lambda: defaultdict(set))
neuron_sets = defaultdict(set)
structures = []
for f, color, structure, linestyle in lines:
    structures.append(structure)
    for layer, neuron in data_top30n[f]:
        layer_neuron = f"{layer}-{neuron}"
        if layer > num_layers:
            num_layers = layer
        neuron_sets[structure].add(layer_neuron)
        neurons_by_layer[structure][layer].add(layer_neuron)


# Row i, column j: share of structure i's top neurons that are also top neurons of structure j.
# (To compare only against German structures, skip rows whose label contains "(de)" and
#  columns whose label does not.)
intersection_proportions = [
    [
        len(neuron_sets[row_name].intersection(neuron_sets[col_name])) / len(neuron_sets[row_name])
        for col_name in structures
    ]
    for row_name in structures
]

# plot confusion matrix
# NOTE: plt.matshow opens its own, appropriately sized figure when no `fignum` is given.
# A preceding plt.figure(figsize=...) call therefore has no effect on this plot and only
# leaves an empty figure behind, so none is created here. To force a size, create the figure
# first and pass its number: plt.matshow(matrix, fignum=plt.figure(figsize=(w, h)).number, ...)
# matrix = np.tril(intersection_proportions)    # Uncomment to just show lower triangle of symmetrical matrices
matrix = intersection_proportions
print(matrix)
plt.matshow(matrix, cmap=plt.cm.gray_r, vmin=0, vmax=1)
plt.grid(False)
cbar = plt.colorbar()
# The plotted values are proportions in [0, 1] (see vmin/vmax), not percentages.
cbar.set_label('Top-30 neuron overlap (proportion)')
tick_marks = np.arange(len(intersection_proportions))
plt.yticks(tick_marks, structures)# [:5])
plt.xticks(tick_marks, structures, rotation=45, ha='left')# [5:], rotation=45, ha='left')

# Optional per-cell annotation (disabled). The threshold and number format below were
# written for percentage values; adapt them (e.g. z < 0.5) before enabling.
# for (i, j), z in np.ndenumerate(intersection_proportions):
#     if z < 50:
#         this_color = 'black'
#     else:
#         this_color = 'white'
#     plt.text(j, i, '{:0.1f}'.format(z), ha='center', va='center', fontsize=6, color=this_color)
plt.savefig(f'../neuron_overlap_en_short_{model_en}.pdf', bbox_inches='tight')
plt.show()

# Print the overlap of every pair of structures, then the neurons implicated in *all* structures
for idx, structure1 in enumerate(structures):
    # Only look at structures after the current one so each unordered pair is printed once.
    for structure2 in structures[idx + 1:]:
        if structure1 == structure2:
            # The same label listed twice refers to the same neuron set; nothing to compare.
            continue
        shared_neurons = neuron_sets[structure1].intersection(neuron_sets[structure2])
        overlap_count = len(shared_neurons)
        original_count = len(neuron_sets[structure1])
        # Guard against an empty neuron set instead of raising ZeroDivisionError.
        intersection_proportion = overlap_count / original_count if original_count else 0.0
        #print(f"{structure1} / {structure2}: {intersection_proportion * 100} %")
        print(f"{structure1} / {structure2}: {intersection_proportion * 100} % ({shared_neurons})")

# Intersect all per-structure sets directly, so the result does not depend on the pair loop
# above and is also defined when only one structure (or none) is listed.
if structures:
    shared_across_all = set.intersection(*(neuron_sets[structure] for structure in structures))
else:
    shared_across_all = set()

print(f"Neurons shared across all structures: {shared_across_all}")

In [None]:
# For every indirect-effect file, plot the per-layer mean effect of its top-5% neurons.
# The line colour is picked by the first matching substring of the filename (orange otherwise).
color_by_substring = (
    ('distil', 'red'),
    ('medium', 'green'),
    ('large', 'blue'),
    ('gpt2-xl', 'purple'),
    ('xlnet', 'black'),
)

plt.figure(figsize=(10,5))
for f in (name for name in data_agg if 'indirect' in name):
    color = next((c for key, c in color_by_substring if key in f), 'orange')
    df = data_agg[f]
    idx = data_top5[f]
    top_rows = df.loc[idx].sort_values('layer').reset_index()
    effects = top_rows.groupby('layer').mean().effect_mean
    # f[:-8] drops the trailing ".feather" from the legend label.
    legend_label = f[:-8].replace('_',' ')
    plt.plot(effects, label=legend_label, alpha=0.3, color=color)
# plt.legend()
plt.show()

In [None]:
# Single-model single-language plots
# The active entries of `lines` below all use the Finnish model (`model_fi`).

# Seaborn font scaling and grid style for the figure drawn in this cell.
sns.set(font_scale=2.3)
sns.set_style('whitegrid')

# Model identifiers interpolated into the result filenames. In this cell only `model_fi`
# is used by an active entry; the others appear in commented-out alternatives.
model = "bert-base-cased"
model_en = "bert-base-multilingual-cased"
model_fr = "camembert-base"
model_nl = "bert-base-dutch-cased"
model_de = "bert-base-german-cased"
model_fi = "bert-base-finnish-cased-v1"
# Each entry is (feather filename, colour, legend label, linestyle), unpacked in that order
# by the plotting loop in this cell. Commented-out entries are alternatives; a few of them
# have fewer than four fields and must be completed before being enabled.
lines = [
    #(f'random_none_indirect_{model}.feather', 'black', 'Simple (en, random init.)'),
    #(f'semantic_short_en_indirect_{model}.feather', 'black', 'Semantic (en, short)', '-'),
    #(f'semantic_long_en_indirect_{model}.feather', 'grey', 'Semantic (en, long)', '-'),
    #(f'none_en_indirect_{model}_natural.feather',         'green',    'Simple (en)', '-'),
    #(f'none_indirect_{model}.feather',         'green',    'Simple (en)', '-'),
    #(f'none_indirect_{model}_short.feather', 'green', 'Simple (en, short)', '--'),
    #(f'none_en_indirect_{model}_controlled.feather', 'green', 'Simple (en, CIE)', '--'),
    #(f'none_indirect_{model}_short.feather',         'green',    'Simple (en)', '--'),
    #(f'none_fr_indirect_{model_fr}_controlled.feather', 'green', 'Simple (fr, CIE)', '--'),
    #(f'none_fr_indirect_{model_fr}.feather', 'green', 'Simple (fr)', '--'),
    #(f'none_nl_short_indirect_{model_nl}_natural.feather', 'green', 'Simple (nl)', '-.'),
    #(f'none_de_short_indirect_{model_de}_natural.feather', 'green', 'Simple (de)', ':'),
    #(f'none_nl_indirect_{model}.feather', 'green', 'Simple (nl)', '-.'),
    #(f'none_de_indirect_{model}.feather', 'green', 'Simple (de)', ':'),
    (f'none_fi_short_indirect_{model_fi}_natural.feather', 'green', 'Simple (fi)', ':'),
    #(f'rc_singular_en_indirect_{model}_natural.feather', 'orange', 'Singular RC (en)', '-'),
    #(f'rc_singular_indirect_{model}.feather', 'orange', 'Singular RC (en)', '-'),
    #(f'rc_singular_en_indirect_{model_en}_controlled.feather', 'orange', 'Singular RC (en, CIE)', '--'),
    #(f'rc_singular_indirect_{model}_short.feather', 'orange', 'Singular RC (en, short)', '--'),
    #(f'rc_singular_fr_indirect_{model_fr}_controlled.feather', 'orange', 'Singular RC (fr, CIE)', '--'),
    #(f'rc_singular_fr_short_indirect_{model_fr}_natural.feather', 'orange', 'Singular RC (fr)', '--'),
    #(f'rc_singular_nl_short_indirect_{model_nl}_natural.feather', 'orange', 'Singular RC (nl)', '-.'),
    #(f'rc_singular_de_short_indirect_{model_de}_natural.feather', 'orange', 'Singular RC (de)', ':'),
    #(f'rc_singular_nl_short_indirect_{model_nl}_natural.feather', 'orange', 'Singular RC (nl)', '-.'),
    #(f'rc_singular_de_indirect_{model}.feather', 'orange', 'Singular RC (de)', ':'),
    (f'rc_singular_fi_short_indirect_{model_fi}_natural.feather', 'orange', 'Singular RC (fi)', ':'),
    #(f'rc_plural_en_indirect_{model}_natural.feather', 'red', 'Plural RC (en)', '-'),
    #(f'rc_plural_indirect_{model}.feather', 'red', 'Plural RC (en)', '-'),
    #(f'rc_plural_indirect_{model_en}_controlled.feather', 'red', 'Plural RC (en, CIE)', '--'),
    #(f'rc_plural_indirect_{model}_short.feather', 'red', 'Plural RC (en, short)', '--'),
    #(f'rc_plural_fr_indirect_{model_fr}_controlled.feather', 'red', 'Plural RC (fr, CIE)', '--'),
    #(f'rc_plural_fr_short_indirect_{model_fr}_natural.feather', 'red', 'Plural RC (fr)', '--'),
    #(f'rc_plural_nl_short_indirect_{model_nl}_natural.feather', 'red', 'Plural RC (nl)', '-.'),
    #(f'rc_plural_de_short_indirect_{model_de}_natural.feather', 'red', 'Plural RC (de)', ':'),
    #(f'rc_plural_nl_indirect_{model}.feather', 'red', 'Plural RC (nl)', '-.'),
    #(f'rc_plural_de_indirect_{model}.feather', 'red', 'Plural RC (de)', ':'),
    (f'rc_plural_fi_short_indirect_{model_fi}_natural.feather', 'red', 'Plural RC (fi)', ':'),
    #(f'prep_singular_indirect_{model}.feather', 'purple', 'Across singular prep (en)', '-'),
    #(f'prep_singular_indirect_{model}_short.feather', 'purple', 'Across singular prep (en, short)', '--'),
    #(f'prep_plural_indirect_{model}.feather', 'blue', 'Across plural prep (en)', '-'),
    #(f'prep_plural_indirect_{model}_short.feather', 'blue', 'Across plural prep (en, short)', '--'),
    #(f'prep_singular_en_indirect_{model}_controlled.feather', 'purple', 'Across singular prep (en, CIE)', '--'),
    #(f'prep_singular_indirect_{model}_short.feather', 'purple', 'Across singular prep (en, short)', '--'),
    #(f'prep_singular_fr_indirect_{model_fr}_controlled.feather', 'purple', 'Across singular prep (fr, CIE)', '--'),
    #(f'prep_singular_fr_short_indirect_{model_fr}_natural.feather', 'purple', 'Across singular prep (fr)', '--'),
    #(f'prep_singular_nl_short_indirect_{model_nl}_natural.feather', 'purple', 'Across singular prep (nl)', '-.'),
    #(f'prep_singular_de_short_indirect_{model_de}_natural.feather', 'purple', 'Across singular prep (de)', ':'),
    (f'prep_singular_fi_short_indirect_{model_fi}_natural.feather', 'purple', 'Across singular prep (fi)', ':'),
    #(f'prep_plural_en_indirect_{model}_natural.feather', 'blue', 'Across plural prep (en, NIE)', '-'),
    #(f'prep_plural_en_indirect_{model}_controlled.feather', 'blue', 'Across plural prep (en, CIE)', '--'),
    #(f'prep_plural_indirect_{model}_short.feather', 'blue', 'Across plural prep (en, short)', '--'),
    #(f'prep_plural_fr_indirect_{model_fr}_controlled.feather', 'blue', 'Across plural prep (fr, CIE)', '--'),
    #(f'prep_plural_fr_short_indirect_{model_fr}_natural.feather', 'blue', 'Across plural prep (fr)', '--'),
    #(f'prep_plural_nl_short_indirect_{model_nl}_natural.feather', 'blue', 'Across plural prep (nl)', '-.'),
    #(f'prep_plural_de_short_indirect_{model_de}_natural.feather', 'blue', 'Across plural prep (de)', ':'),
    #(f'prep_plural_fi_indirect_{model_fi}.feather', 'blue', )
    (f'prep_plural_fi_short_indirect_{model_fi}_natural.feather', 'blue', 'Across plural prep (fi)', ':'),
]
    
# Custom dash sequences, passed to plt.plot via `dashes=` for the non-solid linestyles.
loosify_dashes = {'--': (7, 8), '-.': (8, 7, 1, 7), ':': (2, 5)}

plt.figure(figsize=(8,6))
for f, color, label, linestyle in lines:
    # Restrict to the top-5% neurons of this file, then average within each layer.
    df = data_agg[f].loc[data_top5[f]].sort_values('layer').reset_index()
    # df = data_agg[f].sort_values('layer').reset_index()
    layer_means = df.groupby('layer').mean()
    effect_mean = layer_means.effect_mean
    # Despite its name, this holds the per-layer mean of the `effect_sem` column.
    effect_std = layer_means.effect_sem
    layers = df.layer.unique()

    line_kwargs = dict(label=label, color=color, linestyle=linestyle, alpha=0.5)
    if linestyle in loosify_dashes:
        line_kwargs['dashes'] = loosify_dashes[linestyle]
    plt.plot(layers, effect_mean, **line_kwargs)

    # Shaded band of +/- effect_std around the mean curve.
    plt.fill_between(
        layers,
        effect_mean + effect_std,
        effect_mean - effect_std,
        alpha=0.1, color=color
    )

plt.legend(bbox_to_anchor=(0.99,1.025), loc="upper left")
# plt.title('Indirect effects of top 5% of neurons by layer')
plt.xlabel('Layer')
plt.ylabel('Indirect effect')
plt.title('Finnish indirect effects')
# plt.ylim([-0.0005, 0.016])
# Dotted reference line at zero effect.
plt.hlines(0, 0, 24, color='black', alpha=0.5, linestyle='dotted')
plt.xlim([0,12])
layer_ticks = [0, 2, 4, 6, 8, 10, 12]
plt.xticks(layer_ticks, list(layer_ticks))
plt.savefig(f'../fi_short_vs_orig-indirect_effects.pdf', format='pdf', bbox_inches='tight')
plt.show()

In [None]:
# Original vs. short plots
# Seaborn font scaling and grid style are set before the figure of this cell is drawn.

sns.set(font_scale=2.3)
sns.set_style('whitegrid')

# Model identifiers interpolated into the result filenames. Only `model_de` is used by an
# active entry; the others appear in commented-out alternatives.
model = "bert-base-cased"
model_en = "bert-base-multilingual-cased"
model_fr = "camembert-base"
model_nl = "bert-base-dutch-cased"
model_de = "bert-base-german-cased"
model_fi = "bert-base-finnish-cased-v1"
# Each entry is (feather filename, colour, legend label, linestyle), unpacked in that order
# by the plotting loop in this cell. Per structure, the solid line is the original dataset
# and the dashed line the short one.
# All active entries are German so that the curves agree with the 'German indirect effects'
# title and the de_short_vs_orig output file used by the plotting code of this cell.
# (Previously four structures pointed at the Dutch files while plural prep was German.)
# To plot another language, switch every structure's pair together and update title/filename.
lines = [
    #(f'random_none_indirect_{model}.feather', 'black', 'Simple (en, random init.)'),
    #(f'semantic_short_en_indirect_{model}.feather', 'black', 'Semantic (en, short)', '-'),
    #(f'semantic_long_en_indirect_{model}.feather', 'grey', 'Semantic (en, long)', '-'),
    #(f'none_en_indirect_{model}_natural.feather',         'green',    'Simple (en)', '-'),
    #(f'none_indirect_{model}.feather',         'green',    'Simple (en)', '-'),
    #(f'none_indirect_{model}_short.feather', 'green', 'Simple (en, short)', '--'),
    #(f'none_en_indirect_{model}_controlled.feather', 'green', 'Simple (en, CIE)', '--'),
    #(f'none_indirect_{model}_short.feather',         'green',    'Simple (en)', '--'),
    #(f'none_fr_indirect_{model_fr}_controlled.feather', 'green', 'Simple (fr, CIE)', '--'),
    #(f'none_fr_indirect_{model_fr}.feather', 'green', 'Simple (fr, orig.)', '-'),
    #(f'none_fr_short_indirect_{model_fr}_natural.feather', 'green', 'Simple (fr, short)', '--'),
    #(f'none_nl_indirect_{model_nl}.feather', 'green', 'Simple (nl, orig.)', '-'),
    #(f'none_nl_short_indirect_{model_nl}_natural.feather', 'green', 'Simple (nl, short)', '--'),
    (f'none_de_indirect_{model_de}.feather', 'green', 'Simple (de, orig.)', '-'),
    (f'none_de_short_indirect_{model_de}_natural.feather', 'green', 'Simple (de, short)', '--'),
    #(f'none_nl_indirect_{model}.feather', 'green', 'Simple (nl)', '-.'),
    #(f'none_de_indirect_{model}.feather', 'green', 'Simple (de)', ':'),
    #(f'none_fi_indirect_{model_fi}.feather', 'green', 'Simple (fi, orig.)', '-'),
    #(f'none_fi_short_indirect_{model_fi}_natural.feather', 'green', 'Simple (fi, short)', '--'),
    #(f'rc_singular_en_indirect_{model}_natural.feather', 'orange', 'Singular RC (en)', '-'),
    #(f'rc_singular_indirect_{model}.feather', 'orange', 'Singular RC (en)', '-'),
    #(f'rc_singular_en_indirect_{model_en}_controlled.feather', 'orange', 'Singular RC (en, CIE)', '--'),
    #(f'rc_singular_indirect_{model}_short.feather', 'orange', 'Singular RC (en, short)', '--'),
    #(f'rc_singular_fr_indirect_{model_fr}_controlled.feather', 'orange', 'Singular RC (fr, CIE)', '--'),
    #(f'rc_singular_fr_indirect_{model_fr}.feather', 'orange', 'Singular RC (fr, orig.)', '-'),
    #(f'rc_singular_fr_short_indirect_{model_fr}_natural.feather', 'orange', 'Singular RC (fr, short)', '--'),
    #(f'rc_singular_nl_indirect_{model_nl}.feather', 'orange', 'Singular RC (nl, orig.)', '-'),
    #(f'rc_singular_nl_short_indirect_{model_nl}_natural.feather', 'orange', 'Singular RC (nl, short)', '--'),
    (f'rc_singular_de_indirect_{model_de}.feather', 'orange', 'Singular RC (de, orig.)', '-'),
    (f'rc_singular_de_short_indirect_{model_de}_natural.feather', 'orange', 'Singular RC (de, short)', '--'),
    #(f'rc_singular_nl_short_indirect_{model_nl}_natural.feather', 'orange', 'Singular RC (nl)', '-.'),
    #(f'rc_singular_de_indirect_{model}.feather', 'orange', 'Singular RC (de)', ':'),
    #(f'rc_singular_fi_indirect_{model_fi}.feather', 'orange', 'Singular RC (fi, orig.)', '-'),
    #(f'rc_singular_fi_short_indirect_{model_fi}_natural.feather', 'orange', 'Singular RC (fi, short)', '--'),
    #(f'rc_plural_en_indirect_{model}_natural.feather', 'red', 'Plural RC (en)', '-'),
    #(f'rc_plural_indirect_{model}.feather', 'red', 'Plural RC (en)', '-'),
    #(f'rc_plural_indirect_{model_en}_controlled.feather', 'red', 'Plural RC (en, CIE)', '--'),
    #(f'rc_plural_indirect_{model}_short.feather', 'red', 'Plural RC (en, short)', '--'),
    #(f'rc_plural_fr_indirect_{model_fr}_controlled.feather', 'red', 'Plural RC (fr, CIE)', '--'),
    #(f'rc_plural_fr_indirect_{model_fr}.feather', 'red', 'Plural RC (fr, orig.)', '-'),
    #(f'rc_plural_fr_short_indirect_{model_fr}_natural.feather', 'red', 'Plural RC (fr, short)', '--'),
    #(f'rc_plural_nl_indirect_{model_nl}.feather', 'red', 'Plural RC (nl, orig.)', '-'),
    #(f'rc_plural_nl_short_indirect_{model_nl}_natural.feather', 'red', 'Plural RC (nl, short)', '--'),
    (f'rc_plural_de_indirect_{model_de}.feather', 'red', 'Plural RC (de, orig.)', '-'),
    (f'rc_plural_de_short_indirect_{model_de}_natural.feather', 'red', 'Plural RC (de, short)', '--'),
    #(f'rc_plural_nl_indirect_{model}.feather', 'red', 'Plural RC (nl)', '-.'),
    #(f'rc_plural_de_indirect_{model}.feather', 'red', 'Plural RC (de)', ':'),
    #(f'rc_plural_fi_indirect_{model_fi}.feather', 'red', 'Plural RC (fi, orig.)', '-'),
    #(f'rc_plural_fi_short_indirect_{model_fi}_natural.feather', 'red', 'Plural RC (fi, short)', '--'),
    #(f'prep_singular_indirect_{model}.feather', 'purple', 'Across singular prep (en)', '-'),
    #(f'prep_singular_indirect_{model}_short.feather', 'purple', 'Across singular prep (en, short)', '--'),
    #(f'prep_plural_indirect_{model}.feather', 'blue', 'Across plural prep (en)', '-'),
    #(f'prep_plural_indirect_{model}_short.feather', 'blue', 'Across plural prep (en, short)', '--'),
    #(f'prep_singular_en_indirect_{model}_controlled.feather', 'purple', 'Across singular prep (en, CIE)', '--'),
    #(f'prep_singular_indirect_{model}_short.feather', 'purple', 'Across singular prep (en, short)', '--'),
    #(f'prep_singular_fr_indirect_{model_fr}_controlled.feather', 'purple', 'Across singular prep (fr, CIE)', '--'),
    #(f'prep_singular_fr_indirect_{model_fr}.feather', 'purple', 'Across singular prep (fr, orig.)', '-'),
    #(f'prep_singular_fr_short_indirect_{model_fr}_natural.feather', 'purple', 'Across singular prep (fr, short)', '--'),
    #(f'prep_singular_nl_indirect_{model_nl}.feather', 'purple', 'Across singular prep (nl, orig.)', '-'),
    #(f'prep_singular_nl_short_indirect_{model_nl}_natural.feather', 'purple', 'Across singular prep (nl, short)', '--'),
    (f'prep_singular_de_indirect_{model_de}.feather', 'purple', 'Across singular prep (de, orig.)', '-'),
    (f'prep_singular_de_short_indirect_{model_de}_natural.feather', 'purple', 'Across singular prep (de, short)', '--'),
    #(f'prep_singular_fi_indirect_{model_fi}.feather', 'purple', 'Across singular prep (fi, orig.)', '-'),
    #(f'prep_singular_fi_short_indirect_{model_fi}_natural.feather', 'purple', 'Across singular prep (fi, short)', '--'),
    #(f'prep_plural_en_indirect_{model}_natural.feather', 'blue', 'Across plural prep (en, NIE)', '-'),
    #(f'prep_plural_en_indirect_{model}_controlled.feather', 'blue', 'Across plural prep (en, CIE)', '--'),
    #(f'prep_plural_indirect_{model}_short.feather', 'blue', 'Across plural prep (en, short)', '--'),
    #(f'prep_plural_fr_indirect_{model_fr}_controlled.feather', 'blue', 'Across plural prep (fr, CIE)', '--'),
    #(f'prep_plural_fr_indirect_{model_fr}.feather', 'blue', 'Across plural prep (fr, orig.)', '-'),
    #(f'prep_plural_fr_short_indirect_{model_fr}_natural.feather', 'blue', 'Across plural prep (fr, short)', '--'),
    #(f'prep_plural_nl_short_indirect_{model_nl}_natural.feather', 'blue', 'Across plural prep (nl)', '-.'),
    (f'prep_plural_de_indirect_{model_de}.feather', 'blue', 'Across plural prep (de, orig.)', '-'),
    (f'prep_plural_de_short_indirect_{model_de}_natural.feather', 'blue', 'Across plural prep (de, short)', '--'),
    #(f'prep_plural_fi_indirect_{model_fi}.feather', 'blue', )
    #(f'prep_plural_fi_indirect_{model_fi}.feather', 'blue', 'Across plural prep (fi, orig.)', '-'),
    #(f'prep_plural_fi_short_indirect_{model_fi}_natural.feather', 'blue', 'Across plural prep (fi, short)', '--'),
]
    
# Custom dash sequences, passed to plt.plot via `dashes=` for the non-solid linestyles.
loosify_dashes = {'--': (7, 8), '-.': (8, 7, 1, 7), ':': (2, 5)}

plt.figure(figsize=(8,6))
for f, color, label, linestyle in lines:
    # Restrict to the top-5% neurons of this file, then average within each layer.
    df = data_agg[f].loc[data_top5[f]].sort_values('layer').reset_index()
    # df = data_agg[f].sort_values('layer').reset_index()
    layer_means = df.groupby('layer').mean()
    effect_mean = layer_means.effect_mean
    # Despite its name, this holds the per-layer mean of the `effect_sem` column.
    effect_std = layer_means.effect_sem
    layers = df.layer.unique()

    line_kwargs = dict(label=label, color=color, linestyle=linestyle, alpha=0.5)
    if linestyle in loosify_dashes:
        line_kwargs['dashes'] = loosify_dashes[linestyle]
    plt.plot(layers, effect_mean, **line_kwargs)

    # Shaded band of +/- effect_std around the mean curve.
    plt.fill_between(
        layers,
        effect_mean + effect_std,
        effect_mean - effect_std,
        alpha=0.1, color=color
    )

plt.legend(bbox_to_anchor=(0.99,1.025), loc="upper left")
# plt.title('Indirect effects of top 5% of neurons by layer')
plt.xlabel('Layer')
plt.ylabel('Indirect effect')
plt.title('German indirect effects')
# plt.ylim([-0.0005, 0.016])
# Dotted reference line at zero effect.
plt.hlines(0, 0, 24, color='black', alpha=0.5, linestyle='dotted')
plt.xlim([0,12])
layer_ticks = [0, 2, 4, 6, 8, 10, 12]
plt.xticks(layer_ticks, list(layer_ticks))
plt.savefig(f'../de_short_vs_orig-indirect_effects.pdf', format='pdf', bbox_inches='tight')
plt.show()

In [None]:
# plot total effects for multiple languages given one multilingual model

import itertools
from collections import defaultdict

# Model identifiers interpolated into the result filenames. The *_m names all point at
# the multilingual BERT checkpoint.
model_en = "bert-base-cased"
model_en_m = "bert-base-multilingual-cased"
model_fr = "camembert-base"
model_fr_m = "bert-base-multilingual-cased"
model_de = "bert-base-german-cased"
model_de_m = "bert-base-multilingual-cased"
model_nl = "bert-base-dutch-cased"
model_nl_m = "bert-base-multilingual-cased"
model_fi = "bert-base-finnish-cased-v1"
model_list = (model_en, model_fr, model_de, model_nl, model_fi)
# Each entry is (feather filename, colour, label, linestyle), unpacked in that order by the
# bar-chart loop of this cell. Every structure is listed once for `model_en` and once for
# `model_en_m`, with identical colour/label/linestyle.
lines = [
    #(f'bigram_indirect_{model_en}.feather',         'green',    'Bigram (en)', '-'),
    (f'none_indirect_{model_en}.feather', 'green', 'Simple (en)', '-'),
    # This entry used to be a bare parenthesised string without a trailing comma, which made
    # Python try to call the string with the next tuple's items (TypeError). It is now a full
    # 4-tuple mirroring the `model_en` entry above, like the other model_en / model_en_m pairs.
    (f'none_indirect_{model_en_m}.feather', 'green', 'Simple (en)', '-'),
    #(f'none_indirect_{model_en}_controlled.feather', 'green', 'Simple (en, CIE)', '--'),
    #(f'none_indirect_{model}_short.feather',         'green',    'Simple (en, short)', '--'),
    #(f'none_fr_indirect_{model_fr}.feather', 'green', 'Simple (fr)', '--'),
    #(f'none_de_indirect_{model_de}.feather', 'green', 'Simple (de)', '--'),
    #(f'none_nl_indirect_{model_nl}.feather', 'green', 'Simple (nl)', '-.'),
    #(f'none_fi_indirect_{model_fi}.feather', 'green', 'Simple (fi)', ':'),
    #(f'rc_singular_indirect_{model}.feather', 'orange', 'Singular RC (en)', '-'),
    (f'rc_singular_indirect_{model_en}.feather', 'orange', 'Singular RC (en)', '-'),
    (f'rc_singular_indirect_{model_en_m}.feather', 'orange', 'Singular RC (en)', '-'),
    #(f'rc_singular_indirect_{model_en}_controlled.feather', 'orange', 'Singular RC (en, CIE)', '--'),
    #(f'rc_singular_indirect_{model_en}_short.feather', 'orange', 'Singular RC (en, short)', '--'),
    #(f'rc_singular_fr_indirect_{model_fr}.feather', 'orange', 'Singular RC (fr)', '--'),
    #(f'rc_singular_de_indirect_{model_de}.feather', 'orange', 'Singular RC (de)', '-'),
    #(f'rc_singular_nl_indirect_{model_nl}.feather', 'orange', 'Singular RC (nl)', '-.'),
    #(f'rc_singular_fi_indirect_{model_fi}.feather', 'orange', 'Singular RC (fi)', ':'),
    (f'rc_plural_indirect_{model_en}.feather', 'red', 'Plural RC (en)', '-'),
    (f'rc_plural_indirect_{model_en_m}.feather', 'red', 'Plural RC (en)', '-'),
    #(f'rc_plural_indirect_{model_en}_controlled.feather', 'red', 'Plural RC (en, CIE)', '--'),
    #(f'rc_plural_indirect_{model_en}_short.feather', 'red', 'Plural RC (en, short)', '--'),
    #(f'rc_plural_fr_indirect_{model_fr}.feather', 'red', 'Plural RC (fr)', '--'),
    #(f'rc_plural_de_indirect_{model_de}.feather', 'red', 'Plural RC (de)', '--'),
    #(f'rc_plural_nl_indirect_{model_nl}.feather', 'red', 'Plural RC (nl)', '-.'),
    #(f'rc_plural_fi_indirect_{model_fi}.feather', 'red', 'Plural RC (fi)', ':'),
    (f'prep_singular_en_indirect_{model_en}.feather', 'purple', 'Across singular prep (en)', '-'),
    (f'prep_singular_en_indirect_{model_en_m}.feather', 'purple', 'Across singular prep (en)', '-'),
    #(f'prep_singular_indirect_{model_en}_controlled.feather', 'purple', 'Across singular prep (en, CIE)', '--'),
    #(f'prep_singular_indirect_{model_en}_short.feather', 'purple', 'Across singular prep (en, short)', '--'),
    #(f'prep_singular_fr_indirect_{model_fr}.feather', 'purple', 'Across singular prep (fr)', '--'),
    #(f'prep_singular_de_indirect_{model_de}.feather', 'purple', 'Across singular prep (de)', '--'),
    #(f'prep_singular_nl_indirect_{model_nl}.feather', 'purple', 'Across singular prep (nl)', '-.'),
    #(f'prep_singular_fi_indirect_{model_fi}.feather', 'purple', 'Across singular prep (fi)', ':'),
    (f'prep_plural_en_indirect_{model_en}.feather', 'blue', 'Across plural prep (en)', '-'),
    (f'prep_plural_en_indirect_{model_en_m}.feather', 'blue', 'Across plural prep (en)', '-'),
    #(f'prep_plural_indirect_{model_en}_controlled.feather', 'blue', 'Across plural prep (en, CIE)', '--'),
    #(f'prep_plural_indirect_{model}_short.feather', 'blue', 'Across plural prep (en, short)', '--'),
    #(f'prep_plural_fr_indirect_{model_fr}.feather', 'blue', 'Across plural prep (fr)', '--'),
    #(f'prep_plural_de_indirect_{model_de}.feather', 'blue', 'Across plural prep (de)', '--'),
    #(f'prep_plural_nl_indirect_{model_nl}.feather', 'blue', 'Across plural prep (nl)', '-.'),
    #(f'prep_plural_fi_indirect_{model_fi}.feather', 'blue', 'Across plural prep (fi)', ':'),
]

fig = plt.figure(figsize=(10,6))
ax = fig.add_subplot(111)

# Language codes recognised in filenames; anything else is treated as English.
known_langs = ("fr", "de", "nl", "fi", "en")

xlabels = []                      # one tick label per structure (taken from the first language seen)
labels_lang = defaultdict(list)   # language -> labels
bars_lang   = defaultdict(list)   # language -> total-effect means
errors_lang = defaultdict(list)   # language -> total-effect standard errors
first_lang = None
for f, color, label, linestyle in tqdm(lines):
    # The language code is looked up in the third, then the second "_"-separated filename field.
    name_parts = f.split("_")
    language = name_parts[2]
    if language not in known_langs:
        language = name_parts[1] if name_parts[1] in known_langs else "en"
    first_lang = first_lang or language
    if language == first_lang:
        # Tick label: the label text before its "(...)" suffix.
        xlabels.append(label.split("(")[0])
    # TODO: debugging hack
    # if language != "fi":
    #    continue
    first_row = data_agg[f].iloc[0]
    labels_lang[language].append(label)
    bars_lang[language].append(first_row.total_effect_mean)
    errors_lang[language].append(first_row.total_effect_sem)

# grouped bar chart
x = np.arange(len(xlabels))
width = 0.2

cmap = sns.color_palette("Set2", as_cmap=True).colors
colors_lang = {
    'en': sns.color_palette("husl", as_cmap=True).colors[0],
    'fr': cmap[2],
    'de': cmap[7],
    'nl': cmap[1],
    'fi': cmap[6]
}

num_langs = len(labels_lang.keys())
# Fixed display order, restricted to the languages that actually have data.
present_langs = [code for code in ("en", "fr", "de", "nl", "fi") if code in labels_lang]
for slot, language in enumerate(present_langs):
    # Centre the group of bars around each tick position.
    pos_shift = width * (-(num_langs-1)/2 + slot)
    print(language)
    print(bars_lang[language])
    print(errors_lang[language])

    ax.bar(
        x + pos_shift, bars_lang[language], width,
        label=language, yerr=errors_lang[language], color=colors_lang[language],
    )#, edgecolor='black')

ax.legend(bbox_to_anchor=(0.99,1.025), loc="upper left")
plt.xticks(x, xlabels, rotation=45, ha='right', rotation_mode='anchor')
# plt.bar(labels, bars, yerr=errors)
plt.xlabel('Structure')
plt.ylabel('Total Effect')
#plt.ylim([0, 65])
plt.yscale('log')
plt.tight_layout()
plt.savefig("../total_effect_bert-base-multilingual-cased.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# plot total effects of monolingual vs. multilingual models in one language

import itertools
from collections import defaultdict

# Model names interpolated into the result filenames below. Every `*_m`
# variable holds the multilingual model that the per-language model is
# compared against.
model_en = "bert-base-cased"
model_en_m = "bert-base-multilingual-cased"
model_fr = "camembert-base"
model_fr_m = "bert-base-multilingual-cased"
model_de = "bert-base-german-cased"
model_de_m = "bert-base-multilingual-cased"
model_nl = "bert-base-dutch-cased"
model_nl_m = "bert-base-multilingual-cased"
model_fi = "bert-base-finnish-cased-v1"
# NOTE(review): model_list is not referenced by the code visible in this cell.
model_list = (model_en, model_fr, model_de, model_nl, model_fi)
# Plot configuration: each entry is (results filename, color, label, linestyle).
# The loop further down unpacks entries in that order, uses the filename as the
# key into `data_agg` and the label for the x tick / grouping; color and
# linestyle are unpacked but not used by the bar chart in this cell.
# Commented-out entries are toggles: uncomment to include them in the plot.
# NOTE(review): the commented-out entries that interpolate `{model}` refer to a
# name that is not assigned before this point in the cell.
lines = [
    #(f'bigram_indirect_{model_en}.feather',         'green',    'Bigram (en)', '-'),
    #(f'none_en_indirect_{model_en}_natural.feather', 'green', 'Simple (en)', '-'),
    #(f'none_en_indirect_{model_en_m}_natural.feather', 'green', 'Simple (en)', '-'),
    #(f'none_indirect_{model_en}_controlled.feather', 'green', 'Simple (en, CIE)', '--'),
    #(f'none_indirect_{model}_short.feather',         'green',    'Simple (en, short)', '--'),
    (f'none_fr_short_indirect_{model_fr}_natural.feather', 'green', 'Simple (fr)', '--'),
    #(f'none_de_indirect_{model_de}.feather', 'green', 'Simple (de)', '--'),
    #(f'none_nl_indirect_{model_nl}.feather', 'green', 'Simple (nl)', '-.'),
    #(f'none_fi_indirect_{model_fi}.feather', 'green', 'Simple (fi)', ':'),
    #(f'rc_singular_indirect_{model}.feather', 'orange', 'Singular RC (en)', '-'),
    #(f'rc_singular_en_indirect_{model_en}_natural.feather', 'orange', 'Singular RC (en)', '-'),
    #(f'rc_singular_en_indirect_{model_en_m}_natural.feather', 'orange', 'Singular RC (en)', '-'),
    #(f'rc_singular_indirect_{model_en}_controlled.feather', 'orange', 'Singular RC (en, CIE)', '--'),
    #(f'rc_singular_indirect_{model_en}_short.feather', 'orange', 'Singular RC (en, short)', '--'),
    (f'rc_singular_fr_short_indirect_{model_fr}_natural.feather', 'orange', 'Singular RC (fr)', '--'),
    #(f'rc_singular_de_indirect_{model_de}.feather', 'orange', 'Singular RC (de)', '-'),
    #(f'rc_singular_nl_indirect_{model_nl}.feather', 'orange', 'Singular RC (nl)', '-.'),
    #(f'rc_singular_fi_indirect_{model_fi}.feather', 'orange', 'Singular RC (fi)', ':'),
    #(f'rc_plural_en_indirect_{model_en}_natural.feather', 'red', 'Plural RC (en)', '-'),
    #(f'rc_plural_en_indirect_{model_en_m}_natural.feather', 'red', 'Plural RC (en)', '-'),
    #(f'rc_plural_indirect_{model_en}_controlled.feather', 'red', 'Plural RC (en, CIE)', '--'),
    #(f'rc_plural_indirect_{model_en}_short.feather', 'red', 'Plural RC (en, short)', '--'),
    (f'rc_plural_fr_short_indirect_{model_fr}_natural.feather', 'red', 'Plural RC (fr)', '--'),
    #(f'rc_plural_de_indirect_{model_de}.feather', 'red', 'Plural RC (de)', '--'),
    #(f'rc_plural_nl_indirect_{model_nl}.feather', 'red', 'Plural RC (nl)', '-.'),
    #(f'rc_plural_fi_indirect_{model_fi}.feather', 'red', 'Plural RC (fi)', ':'),
    #(f'prep_singular_en_indirect_{model_en}_natural.feather', 'purple', 'Across singular prep (en)', '-'),
    #(f'prep_singular_en_indirect_{model_en_m}_natural.feather', 'purple', 'Across singular prep (en)', '-'),
    #(f'prep_singular_indirect_{model_en}_controlled.feather', 'purple', 'Across singular prep (en, CIE)', '--'),
    #(f'prep_singular_indirect_{model_en}_short.feather', 'purple', 'Across singular prep (en, short)', '--'),
    (f'prep_singular_fr_short_indirect_{model_fr}_natural.feather', 'purple', 'Across singular prep (fr)', '--'),
    #(f'prep_singular_de_indirect_{model_de}.feather', 'purple', 'Across singular prep (de)', '--'),
    #(f'prep_singular_nl_indirect_{model_nl}.feather', 'purple', 'Across singular prep (nl)', '-.'),
    #(f'prep_singular_fi_indirect_{model_fi}.feather', 'purple', 'Across singular prep (fi)', ':'),
    #(f'prep_plural_en_indirect_{model_en}_natural.feather', 'blue', 'Across plural prep (en)', '-'),
    #(f'prep_plural_en_indirect_{model_en_m}_natural.feather', 'blue', 'Across plural prep (en)', '-'),
    #(f'prep_plural_indirect_{model_en}_controlled.feather', 'blue', 'Across plural prep (en, CIE)', '--'),
    #(f'prep_plural_indirect_{model}_short.feather', 'blue', 'Across plural prep (en, short)', '--'),
    (f'prep_plural_fr_short_indirect_{model_fr}_natural.feather', 'blue', 'Across plural prep (fr)', '--'),
    #(f'prep_plural_de_indirect_{model_de}.feather', 'blue', 'Across plural prep (de)', '--'),
    #(f'prep_plural_nl_indirect_{model_nl}.feather', 'blue', 'Across plural prep (nl)', '-.'),
    #(f'prep_plural_fi_indirect_{model_fi}.feather', 'blue', 'Across plural prep (fi)', ':'),
]

# Grouped bar chart comparing models within one language: one group of bars
# per structure, one bar per model. Bar height is the mean total effect and
# the error bar is its standard error, both read from the first row of the
# per-file aggregate in `data_agg`.
fig = plt.figure(figsize=(14,6))
ax = fig.add_subplot(111)
xlabels = []

labels_model = defaultdict(list)
bars_model   = defaultdict(list)
errors_model = defaultdict(list)
first_model = None
for f, color, label, linestyle in tqdm(lines):
    # Filenames look like '<structure>_indirect_<model>[_<variant>].feather'.
    # Take the field right after 'indirect_' as the model name; using the last
    # "_"-separated field instead would return the variant ('natural') for
    # every file and merge all models into a single group.
    stem = f.rsplit(".", 1)[0]
    _, marker, tail = stem.rpartition("indirect_")
    model = tail.split("_")[0] if marker else stem.split("_")[-1]

    # Tick labels come from the first model encountered only, with the "(xx)"
    # part of the label cut off.
    if not first_model:
        first_model = model
    if model == first_model:
        xlabels.append(label.split("(")[0])

    row = data_agg[f].iloc[0]
    labels_model[model].append(label)
    bars_model[model].append(row.total_effect_mean)
    errors_model[model].append(row.total_effect_sem)

# grouped bar chart
x = np.arange(len(xlabels))
width = 0.2

# Qualitative palette; one colour per model, cycling if there are more models
# than palette entries.
cmap = sns.color_palette("Set2", as_cmap=True).colors

num_models = len(labels_model)
for i, model in enumerate(labels_model):
    # Every model needs exactly one bar per structure on the x axis.
    if len(bars_model[model]) != len(xlabels):
        raise ValueError(
            f"model {model!r} has {len(bars_model[model])} bars "
            f"but the x axis has {len(xlabels)} structures"
        )
    # Centre the bars of one group around the group's tick position.
    pos_shift = width * (-(num_models-1)/2 + i)
    print(model)
    print(bars_model[model])
    print(errors_model[model])

    ax.bar(
        x + pos_shift, bars_model[model], width, label=model,
        yerr=errors_model[model], color=cmap[i % len(cmap)],
    )

ax.legend(bbox_to_anchor=(0.99,1.025), loc="upper left")
plt.xticks(x, xlabels, rotation=45, ha='right', rotation_mode='anchor')
plt.xlabel('Structure')
plt.ylabel('Total Effect')
plt.yscale('log')
# Keep the title in sync with the language selected in `lines` and with the
# output filename below; it is set before tight_layout() so the layout
# accounts for it.
plt.title("French Total Effects")
plt.tight_layout()
plt.savefig("../total_effect_fr_mono_vs_multi.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(15,6))

# Collect the label and a labelled copy of the aggregate for every entry in
# `lines`, preserving the order of `lines`.
labels = []
dfs = []
for f, color, label, linestyle in tqdm(lines):
    labels.append(label)
    # assign() returns a new frame, so the aggregate cached in data_agg is not
    # modified by this cell and a file listed twice under different labels
    # keeps both labels.
    dfs.append(data_agg[f].assign(label=label))

In [None]:
import matplotlib

# Paired bar chart of inverse grammaticality (singular vs. plural subject) for
# every entry collected in `dfs` / `labels`.
matplotlib.rcParams['xtick.minor.visible'] = True
sns.set(font_scale=1.75)
plt.figure(figsize=(10,6))
sns.set_style('whitegrid')

# Stack the labelled aggregates and keep the rows whose index is (0, 0) --
# with the (layer, neuron) index built earlier that is layer 0, neuron 0.
# Only those rows are plotted.
stacked = pd.concat(dfs).sort_index()
df = stacked.loc[0,0].set_index('label')[[
    'inv_singular_grammaticality_mean','inv_plural_grammaticality_mean',
    'inv_singular_grammaticality_sem','inv_plural_grammaticality_sem'
]]

x = np.arange(len(labels))
width = 0.3
# (horizontal shift, column with bar heights, column with error bars)
bar_series = (
    (-width/2, 'inv_singular_grammaticality_mean', 'inv_singular_grammaticality_sem'),
    (+width/2, 'inv_plural_grammaticality_mean', 'inv_plural_grammaticality_sem'),
)
for shift, mean_col, sem_col in bar_series:
    plt.bar(x + shift, list(df[mean_col]), width=width, yerr=list(df[sem_col]))

plt.xticks(x, labels, rotation=45, ha='right', rotation_mode='anchor')
plt.legend(['Singular subject', 'Plural subject'], loc='upper left')
plt.xlabel('Structure')
plt.ylabel('Grammaticality')
plt.tight_layout()
plt.savefig('../grammaticality_small_en.pdf', format='pdf', bbox_inches='tight')
plt.show()