In [1]:
import pandas as pd
from embeds import multi_inner_align, standardize
from sklearn.decomposition import PCA
import sys
sys.path.append('..')
from rca.rca import run_rca

## Load LLM embeddings

In [2]:
fastText_CommonCrawl = pd.read_csv('../../data/embeds/fastText_CommonCrawl.csv', index_col=0)

# Llama 3 8B
Llama_3_8B_0 = pd.read_csv('../../data/llms/Llama-3.1-8B_0.csv', index_col=0)
Llama_3_8B_1 = pd.read_csv('../../data/llms/Llama-3.1-8B_1.csv', index_col=0)
Llama_3_8B_2 = pd.read_csv('../../data/llms/Llama-3.1-8B_2.csv', index_col=0)
Llama_3_8B_012 = sum(multi_inner_align([Llama_3_8B_0, Llama_3_8B_1, Llama_3_8B_2])) / 3 # ensemble approach
Llama_3_8B_12 = sum(multi_inner_align([Llama_3_8B_1, Llama_3_8B_2])) / 2

# BERT large
BERT_large_0 = pd.read_csv('../../data/llms/bert-large-uncased_0.csv', index_col=0)
BERT_large_1 = pd.read_csv('../../data/llms/bert-large-uncased_1.csv', index_col=0)
BERT_large_2 = pd.read_csv('../../data/llms/bert-large-uncased_2.csv', index_col=0)
BERT_large_012 = sum(multi_inner_align([BERT_large_0, BERT_large_1, BERT_large_2])) / 3 # ensemble approach
BERT_large_12 = sum(multi_inner_align([BERT_large_1, BERT_large_2])) / 2

# Comparing
to_compare = {
    'fastText_CommonCrawl': fastText_CommonCrawl,

    'Llama_3_8B_0': Llama_3_8B_0,
    'Llama_3_8B_1': Llama_3_8B_1,
    'Llama_3_8B_2': Llama_3_8B_2,
    'Llama_3_8B_012': Llama_3_8B_012,
    'Llama_3_8B_12': Llama_3_8B_12,

    'BERT_large_0': BERT_large_0,
    'BERT_large_1': BERT_large_1,
    'BERT_large_2': BERT_large_2,
    'BERT_large_012': BERT_large_012,
    'BERT_large_12': BERT_large_12
}

In [3]:
# Aligning for fair comparison
to_compare = dict(zip(to_compare.keys(), multi_inner_align(to_compare.values())))

# Standardizing
to_compare = {name: standardize(embed) for name, embed in to_compare.items()}
{name: embed.shape for name, embed in to_compare.items()}

{'fastText_CommonCrawl': (44450, 300),
 'Llama_3_8B_0': (44450, 4096),
 'Llama_3_8B_1': (44450, 4096),
 'Llama_3_8B_2': (44450, 4096),
 'Llama_3_8B_012': (44450, 4096),
 'Llama_3_8B_12': (44450, 4096),
 'BERT_large_0': (44450, 1024),
 'BERT_large_1': (44450, 1024),
 'BERT_large_2': (44450, 1024),
 'BERT_large_012': (44450, 1024),
 'BERT_large_12': (44450, 1024)}

In [17]:
# Adding PCAed version
pca = PCA(n_components=.9, random_state=42)  # reduce to capture 90% of original variance
Llama_3_8B_12_pca = pd.DataFrame(pca.fit_transform(Llama_3_8B_12), index=Llama_3_8B_12.index)
BERT_large_12_pca = pd.DataFrame(pca.fit_transform(X=BERT_large_12), index=BERT_large_12.index)

# Adding to comparison dict
to_compare['Llama_3_8B_12_pca'] = Llama_3_8B_12_pca
to_compare['BERT_large_12_pca'] = BERT_large_12_pca

# Aligning for fair comparison
to_compare = dict(zip(to_compare.keys(), multi_inner_align(to_compare.values())))

# Standardizing
to_compare = {name: standardize(embed) for name, embed in to_compare.items()}
{name: embed.shape for name, embed in to_compare.items()}

KeyboardInterrupt: 

## Compare

In [3]:
# Loading norm data
norms = pd.read_csv('../../data/psychNorms/psychNorms_processed.zip', index_col=0, low_memory=False, compression='zip')
norms_meta = pd.read_csv('../../data/psychNorms/psychNorms_metadata_processed.csv', index_col='norm')

# Subsetting to glasgow
norms = norms[[norm for norm in norms if 'glasgow' in norm]]

# Running tests
results = run_rca(to_compare, norms, norms_meta, n_jobs=10)
results

  0%|          | 0/13 [00:00<?, ?it/s]

fastText_CommonCrawl:   0%|          | 0/9 [00:00<?, ?it/s]

                         norm  train_n   r2_mean     r2_sd check
2        concreteness_glasgow     3705  0.806406  0.010490  pass
3        imageability_glasgow     3705  0.751500  0.011450  pass
1                 aoa_glasgow     3705  0.744070  0.018202  pass
6             valence_glasgow     3705  0.729247  0.025791  pass
5  gender_association_glasgow     3705  0.666486  0.035355  pass
0         familiarity_glasgow     3705  0.664903  0.019769  pass
4       semantic_size_glasgow     3705  0.661660  0.011675  pass
7             arousal_glasgow     3705  0.569323  0.013773  pass
8           dominance_glasgow     3705  0.532936  0.019119  pass


Llama_3_8B_0:   0%|          | 0/9 [00:00<?, ?it/s]

                         norm  train_n   r2_mean     r2_sd check
2        concreteness_glasgow     3705  0.728785  0.022854  pass
3        imageability_glasgow     3705  0.674631  0.024971  pass
4       semantic_size_glasgow     3705  0.596764  0.010804  pass
6             valence_glasgow     3705  0.593209  0.017427  pass
1                 aoa_glasgow     3705  0.592659  0.008804  pass
0         familiarity_glasgow     3705  0.501898  0.021259  pass
5  gender_association_glasgow     3705  0.482646  0.025521  pass
7             arousal_glasgow     3705  0.468764  0.011144  pass
8           dominance_glasgow     3705  0.431336  0.016009  pass


Llama_3_8B_1:   0%|          | 0/9 [00:00<?, ?it/s]

                         norm  train_n   r2_mean     r2_sd check
2        concreteness_glasgow     3705  0.809845  0.011987  pass
3        imageability_glasgow     3705  0.749760  0.014348  pass
6             valence_glasgow     3705  0.711922  0.020845  pass
1                 aoa_glasgow     3705  0.692026  0.011807  pass
4       semantic_size_glasgow     3705  0.686057  0.006163  pass
0         familiarity_glasgow     3705  0.587842  0.012411  pass
5  gender_association_glasgow     3705  0.573724  0.023965  pass
7             arousal_glasgow     3705  0.566248  0.009085  pass
8           dominance_glasgow     3705  0.532983  0.013185  pass


Llama_3_8B_2:   0%|          | 0/9 [00:00<?, ?it/s]

                         norm  train_n   r2_mean     r2_sd check
2        concreteness_glasgow     3705  0.806623  0.014664  pass
3        imageability_glasgow     3705  0.750107  0.019755  pass
6             valence_glasgow     3705  0.733508  0.015421  pass
4       semantic_size_glasgow     3705  0.689008  0.017015  pass
1                 aoa_glasgow     3705  0.660116  0.015414  pass
5  gender_association_glasgow     3705  0.598233  0.018889  pass
7             arousal_glasgow     3705  0.574990  0.014155  pass
0         familiarity_glasgow     3705  0.564877  0.009206  pass
8           dominance_glasgow     3705  0.534448  0.014723  pass


Llama_3_8B_012:   0%|          | 0/9 [00:00<?, ?it/s]

                         norm  train_n   r2_mean     r2_sd check
2        concreteness_glasgow     3705  0.817338  0.014477  pass
3        imageability_glasgow     3705  0.759873  0.017396  pass
6             valence_glasgow     3705  0.736481  0.019971  pass
4       semantic_size_glasgow     3705  0.698210  0.016626  pass
1                 aoa_glasgow     3705  0.691481  0.014425  pass
5  gender_association_glasgow     3705  0.605429  0.021191  pass
0         familiarity_glasgow     3705  0.591983  0.013094  pass
7             arousal_glasgow     3705  0.580573  0.012038  pass
8           dominance_glasgow     3705  0.543795  0.015847  pass


Llama_3_8B_12:   0%|          | 0/9 [00:00<?, ?it/s]

                         norm  train_n   r2_mean     r2_sd check
2        concreteness_glasgow     3705  0.816883  0.012927  pass
3        imageability_glasgow     3705  0.760858  0.016616  pass
6             valence_glasgow     3705  0.740695  0.011595  pass
4       semantic_size_glasgow     3705  0.699572  0.015547  pass
1                 aoa_glasgow     3705  0.689229  0.016934  pass
5  gender_association_glasgow     3705  0.608411  0.020159  pass
0         familiarity_glasgow     3705  0.594458  0.010235  pass
7             arousal_glasgow     3705  0.586791  0.012008  pass
8           dominance_glasgow     3705  0.547532  0.013113  pass


Llama_3_8B_12_pca:   0%|          | 0/9 [00:00<?, ?it/s]

                         norm  train_n   r2_mean     r2_sd check
2        concreteness_glasgow     3705  0.677794  0.010213  pass
6             valence_glasgow     3705  0.663785  0.016939  pass
3        imageability_glasgow     3705  0.611887  0.014244  pass
4       semantic_size_glasgow     3705  0.594199  0.016618  pass
1                 aoa_glasgow     3705  0.538224  0.023903  pass
5  gender_association_glasgow     3705  0.487299  0.039133  pass
7             arousal_glasgow     3705  0.460180  0.033356  pass
8           dominance_glasgow     3705  0.420904  0.034180  pass
0         familiarity_glasgow     3705  0.417479  0.010367  pass


BERT_large_0:   0%|          | 0/9 [00:00<?, ?it/s]

                         norm  train_n   r2_mean     r2_sd check
2        concreteness_glasgow     3705  0.693819  0.013818  pass
3        imageability_glasgow     3705  0.634103  0.015326  pass
4       semantic_size_glasgow     3705  0.554692  0.024741  pass
6             valence_glasgow     3705  0.550414  0.019576  pass
1                 aoa_glasgow     3705  0.544505  0.018981  pass
0         familiarity_glasgow     3705  0.462700  0.027213  pass
5  gender_association_glasgow     3705  0.462572  0.026487  pass
7             arousal_glasgow     3705  0.444214  0.030547  pass
8           dominance_glasgow     3705  0.397344  0.014824  pass


BERT_large_1:   0%|          | 0/9 [00:00<?, ?it/s]

                         norm  train_n   r2_mean     r2_sd check
2        concreteness_glasgow     3705  0.818624  0.011994  pass
3        imageability_glasgow     3705  0.766894  0.015934  pass
6             valence_glasgow     3705  0.761208  0.019699  pass
4       semantic_size_glasgow     3705  0.692673  0.009141  pass
1                 aoa_glasgow     3705  0.673347  0.008410  pass
5  gender_association_glasgow     3705  0.638930  0.030228  pass
7             arousal_glasgow     3705  0.585532  0.023245  pass
8           dominance_glasgow     3705  0.572923  0.022465  pass
0         familiarity_glasgow     3705  0.551531  0.028296  pass


BERT_large_2:   0%|          | 0/9 [00:00<?, ?it/s]

                         norm  train_n   r2_mean     r2_sd check
2        concreteness_glasgow     3705  0.765055  0.010530  pass
3        imageability_glasgow     3705  0.708047  0.012424  pass
6             valence_glasgow     3705  0.694481  0.025886  pass
4       semantic_size_glasgow     3705  0.630229  0.009214  pass
1                 aoa_glasgow     3705  0.608934  0.013159  pass
5  gender_association_glasgow     3705  0.530669  0.037692  pass
7             arousal_glasgow     3705  0.527152  0.043292  pass
8           dominance_glasgow     3705  0.515479  0.028075  pass
0         familiarity_glasgow     3705  0.503782  0.024680  pass


BERT_large_012:   0%|          | 0/9 [00:00<?, ?it/s]

                         norm  train_n   r2_mean     r2_sd check
2        concreteness_glasgow     3705  0.812260  0.013623  pass
3        imageability_glasgow     3705  0.761864  0.012816  pass
6             valence_glasgow     3705  0.749975  0.014269  pass
4       semantic_size_glasgow     3705  0.687041  0.010590  pass
1                 aoa_glasgow     3705  0.674546  0.011384  pass
5  gender_association_glasgow     3705  0.618029  0.036581  pass
7             arousal_glasgow     3705  0.588630  0.031029  pass
8           dominance_glasgow     3705  0.572958  0.020643  pass
0         familiarity_glasgow     3705  0.554534  0.022708  pass


BERT_large_12:   0%|          | 0/9 [00:00<?, ?it/s]

                         norm  train_n   r2_mean     r2_sd check
2        concreteness_glasgow     3705  0.814668  0.012081  pass
3        imageability_glasgow     3705  0.764887  0.013041  pass
6             valence_glasgow     3705  0.762466  0.018970  pass
4       semantic_size_glasgow     3705  0.694265  0.007023  pass
1                 aoa_glasgow     3705  0.675675  0.009813  pass
5  gender_association_glasgow     3705  0.622435  0.037686  pass
7             arousal_glasgow     3705  0.587701  0.030402  pass
8           dominance_glasgow     3705  0.579053  0.025905  pass
0         familiarity_glasgow     3705  0.551068  0.026230  pass


BERT_large_12_pca:   0%|          | 0/9 [00:00<?, ?it/s]

                         norm  train_n   r2_mean     r2_sd check
2        concreteness_glasgow     3705  0.795866  0.015213  pass
6             valence_glasgow     3705  0.747387  0.024088  pass
3        imageability_glasgow     3705  0.743296  0.016243  pass
4       semantic_size_glasgow     3705  0.674465  0.010387  pass
1                 aoa_glasgow     3705  0.646638  0.007291  pass
5  gender_association_glasgow     3705  0.599596  0.042816  pass
7             arousal_glasgow     3705  0.557572  0.026752  pass
8           dominance_glasgow     3705  0.552498  0.033853  pass
0         familiarity_glasgow     3705  0.514401  0.018756  pass


Unnamed: 0,embed,embed_type,norm,train_n,test_n,p,r2_mean,r2_sd,check
0,fastText_CommonCrawl,,familiarity_glasgow,3705,927,300,0.664903,0.019769,pass
1,fastText_CommonCrawl,,aoa_glasgow,3705,927,300,0.744070,0.018202,pass
2,fastText_CommonCrawl,,concreteness_glasgow,3705,927,300,0.806406,0.010490,pass
3,fastText_CommonCrawl,,imageability_glasgow,3705,927,300,0.751500,0.011450,pass
4,fastText_CommonCrawl,,semantic_size_glasgow,3705,927,300,0.661660,0.011675,pass
...,...,...,...,...,...,...,...,...,...
112,BERT_large_12_pca,,semantic_size_glasgow,3705,927,510,0.674465,0.010387,pass
113,BERT_large_12_pca,,gender_association_glasgow,3705,927,510,0.599596,0.042816,pass
114,BERT_large_12_pca,,valence_glasgow,3705,927,510,0.747387,0.024088,pass
115,BERT_large_12_pca,,arousal_glasgow,3705,927,510,0.557572,0.026752,pass


In [9]:
resuls_piv = results[['embed', 'norm', 'r2_mean']].pivot(index='norm', columns='embed', values='r2_mean')
resuls_piv

embed,BERT_large_0,BERT_large_012,BERT_large_1,BERT_large_12,BERT_large_12_pca,BERT_large_2,Llama_3_8B_0,Llama_3_8B_012,Llama_3_8B_1,Llama_3_8B_12,Llama_3_8B_12_pca,Llama_3_8B_2,fastText_CommonCrawl
norm,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
aoa_glasgow,0.544505,0.674546,0.673347,0.675675,0.646638,0.608934,0.592659,0.691481,0.692026,0.689229,0.538224,0.660116,0.74407
arousal_glasgow,0.444214,0.58863,0.585532,0.587701,0.557572,0.527152,0.468764,0.580573,0.566248,0.586791,0.46018,0.57499,0.569323
concreteness_glasgow,0.693819,0.81226,0.818624,0.814668,0.795866,0.765055,0.728785,0.817338,0.809845,0.816883,0.677794,0.806623,0.806406
dominance_glasgow,0.397344,0.572958,0.572923,0.579053,0.552498,0.515479,0.431336,0.543795,0.532983,0.547532,0.420904,0.534448,0.532936
familiarity_glasgow,0.4627,0.554534,0.551531,0.551068,0.514401,0.503782,0.501898,0.591983,0.587842,0.594458,0.417479,0.564877,0.664903
gender_association_glasgow,0.462572,0.618029,0.63893,0.622435,0.599596,0.530669,0.482646,0.605429,0.573724,0.608411,0.487299,0.598233,0.666486
imageability_glasgow,0.634103,0.761864,0.766894,0.764887,0.743296,0.708047,0.674631,0.759873,0.74976,0.760858,0.611887,0.750107,0.7515
semantic_size_glasgow,0.554692,0.687041,0.692673,0.694265,0.674465,0.630229,0.596764,0.69821,0.686057,0.699572,0.594199,0.689008,0.66166
valence_glasgow,0.550414,0.749975,0.761208,0.762466,0.747387,0.694481,0.593209,0.736481,0.711922,0.740695,0.663785,0.733508,0.729247


In [10]:
# Finding the top-performing
sorted_overall = resuls_piv.mean().sort_values(ascending=False)
sorted_overall

embed
fastText_CommonCrawl    0.680726
BERT_large_1            0.673518
BERT_large_12           0.672469
Llama_3_8B_12           0.671603
Llama_3_8B_012          0.669463
BERT_large_012          0.668871
Llama_3_8B_2            0.656879
Llama_3_8B_1            0.656712
BERT_large_12_pca       0.647969
BERT_large_2            0.609314
Llama_3_8B_0            0.563410
Llama_3_8B_12_pca       0.541306
BERT_large_0            0.527151
dtype: float64

In [12]:
Llama_3_8B_12

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,4086,4087,4088,4089,4090,4091,4092,4093,4094,4095
ABC,0.189453,-1.609375,2.171875,-1.582031,0.859375,-1.960938,-2.828125,-0.089844,-1.157227,3.281250,...,-0.403320,-2.351562,0.460938,-1.328125,-1.378906,1.578125,-2.523438,-4.070312,2.136719,-0.890625
AI,-1.566406,0.176636,0.636719,-4.281250,2.484375,1.099121,-0.573090,1.021484,0.211914,1.468750,...,2.707031,-2.585938,1.613281,1.421875,-1.289062,1.283203,-1.445312,-0.927734,-1.937500,0.412598
Aaron,-0.206055,-0.132812,-4.195312,1.617188,0.724609,0.179688,-1.687500,1.550781,-2.789062,0.610840,...,-3.195312,2.650391,1.636719,2.976562,-0.875000,1.457031,0.892578,-2.492188,-3.593750,-0.288086
Abe,-2.328125,1.726562,-1.523438,2.617188,-4.265625,0.288086,-2.246094,3.695312,0.273438,4.906250,...,-1.851562,1.902344,2.003906,1.552734,-4.406250,-3.367188,-3.289062,-2.562500,-3.140625,0.612793
Abel,-2.929688,-0.282227,2.609375,-0.161865,-3.335938,0.471146,-3.492188,-0.062500,-0.295410,-0.246094,...,-1.695312,-1.808594,2.914062,0.959717,0.157227,0.011719,-0.902344,1.271484,-4.734375,-1.621094
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
zoom,-0.549805,-0.824219,3.242188,1.152344,2.097656,1.523438,-0.791016,1.021484,0.636719,1.527344,...,1.292969,-2.351562,0.349609,2.664062,-3.312500,0.596680,-2.136719,-3.562500,0.193604,-2.037109
zooming,-0.153320,-0.391602,3.156250,0.383789,1.472656,1.503906,-1.406250,-0.122070,1.734375,0.954102,...,1.037109,-1.343750,-0.480469,1.832031,-2.929688,1.628906,-1.410156,-2.453125,0.828125,-2.503906
zoophobia,0.738281,-2.273438,1.933594,-1.084717,-0.391846,3.140625,-2.390625,1.486328,-0.574219,0.581055,...,0.328125,1.109375,-3.578125,-0.894531,-2.304688,0.183594,2.984375,-0.187500,0.610352,-1.136719
zucchini,0.341797,-1.017578,2.195312,0.048828,0.443359,3.062500,0.572266,-0.417969,-1.554688,-0.011719,...,-2.031250,-0.777344,1.898438,-0.189453,-1.316406,-0.308594,-0.129883,0.100586,0.318359,-1.921875


In [6]:
print(len(Llama_3_8B_12))
print(len(to_compare['Llama_3_8B_12']))

46246
44450


In [None]:
## Saving
Llama_3_8B_12.to_csv('../../data/embeds/LLama_3_8B.csv')
BERT_large_12.to_csv('../../data/embeds/BERT_large.csv')