In [2]:
import pandas as pd
from embeds import multi_inner_align, standardize
import sys
sys.path.append('..')
from rca.rca import run_rca

In [3]:
# Baseline embedding
fastText_CommonCrawl = pd.read_csv('../../data/embeds/fastText_CommonCrawl.csv', index_col=0)

# Llama 3 8B: three separately stored variants (suffixes _0/_1/_2; what they
# index is not shown here -- TODO confirm), each read with its first column as index.
Llama_3_8B_0, Llama_3_8B_1, Llama_3_8B_2 = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in (
        '../../data/llms/Llama-3.1-8B_0.csv',
        '../../data/llms/Llama-3.1-8B_1.csv',
        '../../data/llms/Llama-3.1-8B_2.csv',
    )
)
# Ensemble approach: element-wise mean of the three aligned variants.
llama_aligned = multi_inner_align([Llama_3_8B_0, Llama_3_8B_1, Llama_3_8B_2])
Llama_3_8B = sum(llama_aligned) / 3

# BERT large: same procedure as for Llama above.
BERT_large_0, BERT_large_1, BERT_large_2 = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in (
        '../../data/llms/bert-large-uncased_0.csv',
        '../../data/llms/bert-large-uncased_1.csv',
        '../../data/llms/bert-large-uncased_2.csv',
    )
)
bert_aligned = multi_inner_align([BERT_large_0, BERT_large_1, BERT_large_2])
BERT_large = sum(bert_aligned) / 3

# Everything that goes into the comparison: the baseline, every individual
# variant, and the two ensembles (insertion order is the evaluation order).
to_compare = dict(
    fastText_CommonCrawl=fastText_CommonCrawl,
    Llama_3_8B_0=Llama_3_8B_0,
    Llama_3_8B_1=Llama_3_8B_1,
    Llama_3_8B_2=Llama_3_8B_2,
    Llama_3_8B=Llama_3_8B,
    BERT_large_0=BERT_large_0,
    BERT_large_1=BERT_large_1,
    BERT_large_2=BERT_large_2,
    BERT_large=BERT_large,
)

# Align all embeddings with each other for a fair comparison
# (presumably restricts them to their shared rows -- verify in `embeds`).
jointly_aligned = multi_inner_align(to_compare.values())
to_compare = dict(zip(to_compare, jointly_aligned))

# Standardize every embedding, then display the resulting shapes as a sanity check.
to_compare = {label: standardize(frame) for label, frame in to_compare.items()}
{label: frame.shape for label, frame in to_compare.items()}

{'fastText_CommonCrawl': (44450, 300),
 'Llama_3_8B_0': (44450, 4096),
 'Llama_3_8B_1': (44450, 4096),
 'Llama_3_8B_2': (44450, 4096),
 'Llama_3_8B': (44450, 4096),
 'BERT_large_0': (44450, 1024),
 'BERT_large_1': (44450, 1024),
 'BERT_large_2': (44450, 1024),
 'BERT_large': (44450, 1024)}

In [4]:
# Read the norm table from the zipped CSV; its first column becomes the index.
norms = pd.read_csv(
    '../../data/psychNorms/psychNorms_processed.zip',
    compression='zip',
    index_col=0,
    low_memory=False,
)

# Read the per-norm metadata, indexed by the 'norm' column.
norms_meta = pd.read_csv(
    '../../data/psychNorms/psychNorms_metadata_processed.csv',
    index_col='norm',
)
norms_meta

Unnamed: 0_level_0,description,citation,category,source,associated_embed,type
norm,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
frequency_lund,Log10 version of frequency norms based on the ...,"Lund, K. and C. Burgess (1996). ""Producing hig...",frequency,SCOPE,,numeric
frequency_kucera,Log10 version of frequency norms based on the ...,"Kučera, H. and W. N. Francis (1967). Computati...",frequency,SCOPE,,numeric
frequency_subtlexus,Log10 version of frequency norms based on the ...,"Brysbaert, M. and B. New (2009). ""Moving beyon...",frequency,SCOPE,,numeric
frequency_subtlexuk,Log10 version of the frequency norms based on ...,"Van Heuven, W. J., et al. (2014). ""SUBTLEX-UK:...",frequency,SCOPE,,numeric
frequency_blog_gimenes,Log10 version of the frequency norms based on ...,"Gimenes, M. and B. New (2016). ""Worldlex: Twit...",frequency,SCOPE,,numeric
...,...,...,...,...,...,...
imageability_vanarsdall,Ratings from 1 (low imagery) to 7 (high imagery),"VanArsdall, J. E., & Blunt, J. R. (2022). Anal...",imageability,lit_search,,numeric
familiarity_fear,Ratings from 1 (extremely infrequent) to 7 (ex...,"Fear, W. J. (1997). Ratings for Welsh words an...",familiarity,lit_search,,numeric
aoa_fear,Ratings from 1 (learned early) to 7 (learned l...,"Fear, W. J. (1997). Ratings for Welsh words an...",age_of_acquisition,lit_search,,numeric
imageability_fear,Ratings from 1 (low imageability) to 7 (high i...,"Fear, W. J. (1997). Ratings for Welsh words an...",imageability,lit_search,,numeric


In [5]:
# Keep only the norm columns whose name contains 'glasgow'.
glasgow_cols = [col for col in norms.columns if 'glasgow' in col]
norms = norms[glasgow_cols]

# Run the RCA for every embedding in `to_compare` against the selected norms
# (10 parallel jobs); the returned table is displayed below.
results = run_rca(to_compare, norms, norms_meta, n_jobs=10)
results

  0%|          | 0/9 [00:00<?, ?it/s]

fastText_CommonCrawl:   0%|          | 0/9 [00:00<?, ?it/s]

                         norm  train_n   r2_mean     r2_sd check
2        concreteness_glasgow     3705  0.806406  0.010490  pass
3        imageability_glasgow     3705  0.751500  0.011450  pass
1                 aoa_glasgow     3705  0.744070  0.018202  pass
6             valence_glasgow     3705  0.729247  0.025791  pass
5  gender_association_glasgow     3705  0.666486  0.035355  pass
0         familiarity_glasgow     3705  0.664903  0.019769  pass
4       semantic_size_glasgow     3705  0.661660  0.011675  pass
7             arousal_glasgow     3705  0.569323  0.013773  pass
8           dominance_glasgow     3705  0.532936  0.019119  pass


Llama_3_8B_0:   0%|          | 0/9 [00:00<?, ?it/s]

                         norm  train_n   r2_mean     r2_sd check
2        concreteness_glasgow     3705  0.728785  0.022854  pass
3        imageability_glasgow     3705  0.674631  0.024971  pass
4       semantic_size_glasgow     3705  0.596764  0.010804  pass
6             valence_glasgow     3705  0.593209  0.017427  pass
1                 aoa_glasgow     3705  0.592659  0.008804  pass
0         familiarity_glasgow     3705  0.501898  0.021259  pass
5  gender_association_glasgow     3705  0.482646  0.025521  pass
7             arousal_glasgow     3705  0.468764  0.011144  pass
8           dominance_glasgow     3705  0.431336  0.016009  pass


Llama_3_8B_1:   0%|          | 0/9 [00:00<?, ?it/s]

                         norm  train_n   r2_mean     r2_sd check
2        concreteness_glasgow     3705  0.809845  0.011987  pass
3        imageability_glasgow     3705  0.749760  0.014348  pass
6             valence_glasgow     3705  0.711922  0.020845  pass
1                 aoa_glasgow     3705  0.692026  0.011807  pass
4       semantic_size_glasgow     3705  0.686057  0.006163  pass
0         familiarity_glasgow     3705  0.587842  0.012411  pass
5  gender_association_glasgow     3705  0.573724  0.023965  pass
7             arousal_glasgow     3705  0.566248  0.009085  pass
8           dominance_glasgow     3705  0.532983  0.013185  pass


Llama_3_8B_2:   0%|          | 0/9 [00:00<?, ?it/s]

                         norm  train_n   r2_mean     r2_sd check
2        concreteness_glasgow     3705  0.806623  0.014664  pass
3        imageability_glasgow     3705  0.750107  0.019755  pass
6             valence_glasgow     3705  0.733508  0.015421  pass
4       semantic_size_glasgow     3705  0.689008  0.017015  pass
1                 aoa_glasgow     3705  0.660116  0.015414  pass
5  gender_association_glasgow     3705  0.598233  0.018889  pass
7             arousal_glasgow     3705  0.574990  0.014155  pass
0         familiarity_glasgow     3705  0.564877  0.009206  pass
8           dominance_glasgow     3705  0.534448  0.014723  pass


Llama_3_8B:   0%|          | 0/9 [00:00<?, ?it/s]

                         norm  train_n   r2_mean     r2_sd check
2        concreteness_glasgow     3705  0.817338  0.014477  pass
3        imageability_glasgow     3705  0.759873  0.017396  pass
6             valence_glasgow     3705  0.736481  0.019971  pass
4       semantic_size_glasgow     3705  0.698210  0.016626  pass
1                 aoa_glasgow     3705  0.691481  0.014425  pass
5  gender_association_glasgow     3705  0.605429  0.021191  pass
0         familiarity_glasgow     3705  0.591983  0.013094  pass
7             arousal_glasgow     3705  0.580573  0.012038  pass
8           dominance_glasgow     3705  0.543795  0.015847  pass


BERT_large_0:   0%|          | 0/9 [00:00<?, ?it/s]

                         norm  train_n   r2_mean     r2_sd check
2        concreteness_glasgow     3705  0.693819  0.013818  pass
3        imageability_glasgow     3705  0.634103  0.015326  pass
4       semantic_size_glasgow     3705  0.554692  0.024741  pass
6             valence_glasgow     3705  0.550414  0.019576  pass
1                 aoa_glasgow     3705  0.544505  0.018981  pass
0         familiarity_glasgow     3705  0.462700  0.027213  pass
5  gender_association_glasgow     3705  0.462572  0.026487  pass
7             arousal_glasgow     3705  0.444214  0.030547  pass
8           dominance_glasgow     3705  0.397344  0.014824  pass


BERT_large_1:   0%|          | 0/9 [00:00<?, ?it/s]

                         norm  train_n   r2_mean     r2_sd check
2        concreteness_glasgow     3705  0.818624  0.011994  pass
3        imageability_glasgow     3705  0.766894  0.015934  pass
6             valence_glasgow     3705  0.761208  0.019699  pass
4       semantic_size_glasgow     3705  0.692673  0.009141  pass
1                 aoa_glasgow     3705  0.673347  0.008410  pass
5  gender_association_glasgow     3705  0.638930  0.030228  pass
7             arousal_glasgow     3705  0.585532  0.023245  pass
8           dominance_glasgow     3705  0.572923  0.022465  pass
0         familiarity_glasgow     3705  0.551531  0.028296  pass


BERT_large_2:   0%|          | 0/9 [00:00<?, ?it/s]

                         norm  train_n   r2_mean     r2_sd check
2        concreteness_glasgow     3705  0.765055  0.010530  pass
3        imageability_glasgow     3705  0.708047  0.012424  pass
6             valence_glasgow     3705  0.694481  0.025886  pass
4       semantic_size_glasgow     3705  0.630229  0.009214  pass
1                 aoa_glasgow     3705  0.608934  0.013159  pass
5  gender_association_glasgow     3705  0.530669  0.037692  pass
7             arousal_glasgow     3705  0.527152  0.043292  pass
8           dominance_glasgow     3705  0.515479  0.028075  pass
0         familiarity_glasgow     3705  0.503782  0.024680  pass


BERT_large:   0%|          | 0/9 [00:00<?, ?it/s]

                         norm  train_n   r2_mean     r2_sd check
2        concreteness_glasgow     3705  0.812260  0.013623  pass
3        imageability_glasgow     3705  0.761864  0.012816  pass
6             valence_glasgow     3705  0.749975  0.014269  pass
4       semantic_size_glasgow     3705  0.687041  0.010590  pass
1                 aoa_glasgow     3705  0.674546  0.011384  pass
5  gender_association_glasgow     3705  0.618029  0.036581  pass
7             arousal_glasgow     3705  0.588630  0.031029  pass
8           dominance_glasgow     3705  0.572958  0.020643  pass
0         familiarity_glasgow     3705  0.554534  0.022708  pass


Unnamed: 0,embed,embed_type,norm,train_n,test_n,p,r2_mean,r2_sd,check
0,fastText_CommonCrawl,,familiarity_glasgow,3705,927,300,0.664903,0.019769,pass
1,fastText_CommonCrawl,,aoa_glasgow,3705,927,300,0.744070,0.018202,pass
2,fastText_CommonCrawl,,concreteness_glasgow,3705,927,300,0.806406,0.010490,pass
3,fastText_CommonCrawl,,imageability_glasgow,3705,927,300,0.751500,0.011450,pass
4,fastText_CommonCrawl,,semantic_size_glasgow,3705,927,300,0.661660,0.011675,pass
...,...,...,...,...,...,...,...,...,...
76,BERT_large,,semantic_size_glasgow,3705,927,1024,0.687041,0.010590,pass
77,BERT_large,,gender_association_glasgow,3705,927,1024,0.618029,0.036581,pass
78,BERT_large,,valence_glasgow,3705,927,1024,0.749975,0.014269,pass
79,BERT_large,,arousal_glasgow,3705,927,1024,0.588630,0.031029,pass


In [6]:
# Attach each norm's category from the metadata table, with underscores
# shown as spaces (e.g. 'age_of_acquisition' -> 'age of acquisition').
category_per_row = results['norm'].apply(
    lambda norm_name: norms_meta.loc[norm_name, 'category']
)
results['norm_category'] = category_per_row.replace({'_': ' '}, regex=True)

# Median r2_mean for every (norm category, embed) pair; rows with a missing
# value are dropped. Note: despite the `_avg` name this holds medians.
score_cols = ['norm_category', 'embed', 'r2_mean']
grouped_scores = results[score_cols].groupby(['norm_category', 'embed'], as_index=False)
results_avg = grouped_scores.median().dropna()

# Wide view: one row per norm category, one column per embed.
results_avg_piv = results_avg.pivot(index='norm_category', columns='embed', values='r2_mean')
results_avg_piv.round(2)

embed,BERT_large,BERT_large_0,BERT_large_1,BERT_large_2,Llama_3_8B,Llama_3_8B_0,Llama_3_8B_1,Llama_3_8B_2,fastText_CommonCrawl
norm_category,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
age of acquisition,0.67,0.54,0.67,0.61,0.69,0.59,0.69,0.66,0.74
arousal,0.59,0.44,0.59,0.53,0.58,0.47,0.57,0.57,0.57
concreteness,0.81,0.69,0.82,0.77,0.82,0.73,0.81,0.81,0.81
dominance,0.57,0.4,0.57,0.52,0.54,0.43,0.53,0.53,0.53
familiarity,0.55,0.46,0.55,0.5,0.59,0.5,0.59,0.56,0.66
imageability,0.76,0.63,0.77,0.71,0.76,0.67,0.75,0.75,0.75
social/moral,0.62,0.46,0.64,0.53,0.61,0.48,0.57,0.6,0.67
space/time/quantity,0.69,0.55,0.69,0.63,0.7,0.6,0.69,0.69,0.66
valence,0.75,0.55,0.76,0.69,0.74,0.59,0.71,0.73,0.73


In [7]:
# Rank the embeddings by their mean score across norm categories, best first.
overall_means = results_avg_piv.mean()
sorted_overall = overall_means.sort_values(ascending=False)
sorted_overall

embed
fastText_CommonCrawl    0.680726
BERT_large_1            0.673518
Llama_3_8B              0.669463
BERT_large              0.668871
Llama_3_8B_2            0.656879
Llama_3_8B_1            0.656712
BERT_large_2            0.609314
Llama_3_8B_0            0.563410
BERT_large_0            0.527151
dtype: float64

In [None]:
## Saving
# Placeholder cell: assign the embedding DataFrame that should be persisted
# (chosen by hand after inspecting `sorted_overall`) and adjust the output
# file name below before running.
top_performer = None

# Calling `.to_csv` on None raises AttributeError and would break a
# Restart & Run All, so the save is skipped until a frame has been chosen.
if top_performer is None:
    print('No top performer selected; nothing saved.')
else:
    top_performer.to_csv('../../data/embeds/Llama_X_XB.csv')