In [10]:
from scipy.stats import pearsonr, spearmanr
import pandas as pd

In [11]:
# Model names; each one is substituted into the CSV file names that are
# read further down in this notebook.
models = [
    'bloom-560m',
    'bloom-1b1',
    'bloom-1b7',
]

In [12]:
# Accumulators shared by every (model, column) pair visited below.
#   avg_*  : one arithmetic mean per column of the similarity table
#   prws_* : every individual cell value, flattened in visiting order
avg_similarities, avg_overlap_rates = [], []
prws_similarities, prws_overlap_rates = [], []

for model in models:
    # Both tables are read with their first CSV column as the row index.
    df_similarity = pd.read_csv(f'csv_files/{model}_last-layer_cos-similarity.csv', index_col=0)
    df_ovlp = pd.read_csv(f'csv_files/{model}_inter-layer-17_ovlp-rate.csv', index_col=0)

    # The similarity table decides which column/row labels are visited; the
    # overlap table is then looked up with exactly the same labels, so a
    # label missing from it raises KeyError.
    ckpts = df_similarity.columns
    langs = df_similarity.index

    for ckpt in ckpts:
        similarity_list, ovlp_list = [], []
        for lang in langs:
            # Column first, then row label, matching the table layout above.
            similarity_list += [float(df_similarity[ckpt][lang])]
            ovlp_list += [float(df_ovlp[ckpt][lang])]

        # Keep the raw values for the pairwise correlation ...
        prws_similarities.extend(similarity_list)
        prws_overlap_rates.extend(ovlp_list)

        # ... and the per-column means for the averaged correlation.
        # Plain sum/len is used, so an empty row index raises ZeroDivisionError.
        avg_similarities.append(sum(similarity_list) / len(similarity_list))
        avg_overlap_rates.append(sum(ovlp_list) / len(ovlp_list))


In [13]:
def _report_correlations(xs, ys):
    """Print and return the Pearson and Spearman statistics for two paired samples.

    Args:
        xs: first sequence of numbers.
        ys: second sequence of numbers, paired element-wise with ``xs``.

    Returns:
        A 4-tuple ``(pearson_corr, pearson_p_value, spearman_corr,
        spearman_p_value)`` as produced by ``scipy.stats.pearsonr`` and
        ``scipy.stats.spearmanr``.
    """
    pearson_corr, pearson_p_value = pearsonr(xs, ys)
    print('Pearsons correlation: %.3f' % pearson_corr)
    print(f'Pearsons P value: {pearson_p_value}')

    spearman_corr, spearman_p_value = spearmanr(xs, ys)
    print('Spearmans correlation: %.3f' % spearman_corr)
    print(f'Spearmans P value: {spearman_p_value}')

    return pearson_corr, pearson_p_value, spearman_corr, spearman_p_value


# Correlation between the per-column means.
print("avg similarities: ", avg_similarities)
print("avg overlap rates: ", avg_overlap_rates)
(pearson_corr, pearson_p_value,
 spearman_corr, spearman_p_value) = _report_correlations(avg_similarities, avg_overlap_rates)

# Correlation between the individual (flattened) cell values.
# The result names are rebound here, so after this cell they hold the
# pairwise statistics, not the averaged ones.
print("pairwise similarities: ", prws_similarities)
print("pairwise overlap rates: ", prws_overlap_rates)
(pearson_corr, pearson_p_value,
 spearman_corr, spearman_p_value) = _report_correlations(prws_similarities, prws_overlap_rates)



avg similarities:  [0.5231994234968831, 0.6382962401334082, 0.4915794398104692, 0.5759763711990495, 0.5131113574032138, 0.76970226732041, 0.5838259374796997, 0.4915794398104692, 0.6851515939130012, 0.5335985131136229, 0.6615855097824537, 0.5183748051732784, 0.5063469763103071, 0.5151671894006894, 0.5461859079851498, 0.6650049449164268, 0.6037521466142862, 0.5150285890918064, 0.5396082588928439, 0.6351852935181513, 0.6351852935181513, 0.5958732383722529, 0.490204296676581, 0.6109162454428957, 0.5818905056298274, 0.5331954327708485, 0.5331954327708485]
avg overlap rates:  [0.17555555555555555, 0.07777777777777778, 0.09555555555555556, 0.10444444444444445, 0.12666666666666668, 0.12222222222222223, 0.1511111111111111, 0.08444444444444445, 0.06888888888888889, 0.13555555555555557, 0.04000000000000001, 0.07111111111111111, 0.06222222222222223, 0.11111111111111113, 0.1111111111111111, 0.024444444444444446, 0.11111111111111113, 0.1666666666666667, 0.09999999999999999, 0.022222222222222223, 0.0