In [1]:
import pandas as pd
import numpy as np
from sklearn.linear_model import RidgeCV, LogisticRegressionCV
from sklearn.model_selection import StratifiedKFold
from tqdm.notebook import tqdm
import json
from rca import process_categorical, best_logistic_solver, k_fold_cross_val, make_binary_scoring, make_multiclass_scoring, checker
import pickle



## Loading Data

In [2]:
rca = pd.read_csv('../../data/results/rca.csv').dropna()
meta = pd.read_csv('../../data/psychNorms/psychNorms_metadata.csv', index_col=0)
norms = pd.read_csv('../../data/psychNorms/psychNorms.zip', index_col=0, compression='zip', low_memory=False)

# Look up each norm's category in the metadata and show underscores as spaces,
# giving the solo-embedding results a human-readable `norm_cat` column
norm_categories = rca['norm'].apply(lambda norm_name: meta.loc[norm_name, 'category'])
rca['norm_cat'] = norm_categories.replace('_', ' ', regex=True)

rca

Unnamed: 0,embed,embed_type,norm,train_n,test_n,p,r2_mean,r2_sd,mse_mean,mse_sd,check,norm_cat
0,CBOW_GoogleNews,text,Freq_HAL,28012,7003,300,0.522118,0.008398,2.715453,0.072527,pass,frequency
1,CBOW_GoogleNews,text,Freq_KF,19285,4822,300,0.500425,0.009710,0.156666,0.004775,pass,frequency
2,CBOW_GoogleNews,text,Freq_SUBTLEXUS,28636,7159,300,0.537265,0.009814,0.361345,0.007803,pass,frequency
3,CBOW_GoogleNews,text,Freq_SUBTLEXUK,29316,7330,300,0.545643,0.008416,0.446042,0.009075,pass,frequency
4,CBOW_GoogleNews,text,Freq_Blog,31876,7969,300,0.523700,0.008448,0.400166,0.009345,pass,frequency
...,...,...,...,...,...,...,...,...,...,...,...,...
7003,THINGS,behavior,familiarity_vanarsdall,376,95,49,0.091050,0.077868,3216.050320,580.598768,pass,familiarity
7004,THINGS,behavior,imageability_vanarsdall,376,95,49,0.074112,0.090266,1317.411401,292.505564,pass,imageability
7005,THINGS,behavior,familiarity_fear,173,44,49,0.151637,0.158767,0.782185,0.220500,pass,familiarity
7006,THINGS,behavior,aoa_fear,173,44,49,0.015012,0.129597,0.514519,0.043528,pass,age of acquisition


In [3]:
# Two-stage summary per embedding: first the median R² within each norm
# category (dampens outlier norms inside a category), then the mean of those
# category medians across categories.
embed_avgs = (
    rca.groupby(['embed', 'norm_cat'])['r2_mean'].median()
    .groupby(level='embed').mean()
    .to_frame('r2_avg')
)
embed_avgs

Unnamed: 0_level_0,r2_avg
embed,Unnamed: 1_level_1
CBOW_GoogleNews,0.451043
EEG_speech,-0.06551
EEG_text,-0.048021
GloVe_CommonCrawl,0.315571
GloVe_Twitter,0.298494
GloVe_Wikipedia,0.301758
LexVec_CommonCrawl,0.293098
PPMI_SVD_EAT,0.296026
PPMI_SVD_SWOW,0.403018
PPMI_SVD_SouthFlorida,0.253147


In [4]:
# Tag each embedding with its data type using the embed -> type mapping on disk
with open('../../data/embed_to_dtype.json', 'r') as f:
    embed_to_type = json.load(f)
embed_avgs['type'] = embed_avgs.index.map(embed_to_type)

# The two text embeddings with the highest average R²
is_text = embed_avgs['type'] == 'text'
text_ranking = embed_avgs.loc[is_text, 'r2_avg'].sort_values(ascending=False)
text_name_1, text_name_2 = text_ranking.index[:2]
text_name_1, text_name_2

('CBOW_GoogleNews', 'fastText_CommonCrawl')

In [5]:
# The behavior embedding with the highest average R²
is_behavior = embed_avgs['type'] == 'behavior'
behavior_ranking = embed_avgs.loc[is_behavior, 'r2_avg'].sort_values(ascending=False)
behavior_name = behavior_ranking.index[0]
behavior_name

'PPMI_SVD_SWOW'

In [6]:
# Load the three solo embeddings selected above (words are the row index)
text_1, text_2, behavior = [
    pd.read_csv(f'../../data/embeds/{name}.csv', index_col=0)
    for name in (text_name_1, text_name_2, behavior_name)
]

# Word list shared with the brain/behavior data.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open('../../data/brain_behav_union.pkl', 'rb') as f:
    brain_behav_union = pickle.load(f)

# Subset every embedding to the common vocabulary, sorted so rows line up
shared_vocab = (
    set(text_1.index) & set(text_2.index) & set(behavior.index) & set(brain_behav_union)
)
intersect = sorted(shared_vocab)
text_1, text_2, behavior = [frame.loc[intersect] for frame in (text_1, text_2, behavior)]


def standardize(df):
    """Z-score every column: subtract the column mean, divide by the sample std (ddof=1)."""
    return (df - df.mean()) / df.std()


text_1, text_2, behavior = map(standardize, (text_1, text_2, behavior))

# Solo embeddings plus every pairwise feature-level concatenation ("ensemble")
embeds = {
    behavior_name: behavior,
    text_name_1: text_1,
    text_name_2: text_2,
    f'{text_name_1}&{text_name_2}': pd.concat([text_1, text_2], axis=1),
    f'{text_name_1}&{behavior_name}': pd.concat([text_1, behavior], axis=1),
    f'{text_name_2}&{behavior_name}': pd.concat([text_2, behavior], axis=1),
}

# Renumber columns 0..p-1 in place; concatenated frames may otherwise carry
# duplicate column labels. This also renumbers text_1, text_2 and behavior,
# which are the very objects stored in the dict.
for frame in embeds.values():
    frame.columns = list(range(frame.shape[1]))

{name: frame.shape for name, frame in embeds.items()}

{'PPMI_SVD_SWOW': (11723, 300),
 'CBOW_GoogleNews': (11723, 300),
 'fastText_CommonCrawl': (11723, 300),
 'CBOW_GoogleNews&fastText_CommonCrawl': (11723, 600),
 'CBOW_GoogleNews&PPMI_SVD_SWOW': (11723, 600),
 'fastText_CommonCrawl&PPMI_SVD_SWOW': (11723, 600)}

In [7]:
# Turn the space-separated `associated_embed` field into a list of embedding
# names. Only string cells are split, so re-running this cell is harmless;
# `.str.split` applied to an already-split list column would silently return NaN.
meta['associated_embed'] = meta['associated_embed'].map(
    lambda value: value.split(' ') if isinstance(value, str) else value
)

# Log-transform the selected norms with log1p. This mutates `norms` in place,
# so a flag stored in `norms.attrs` records that the transform was applied and
# stops a re-run of this cell from logging the same columns twice. Reloading
# `norms` from disk yields a fresh frame without the flag.
norms_to_log = pd.read_csv('../../data/norms_to_log.csv')['norm']
if not norms.attrs.get('log1p_applied', False):
    norms[norms_to_log] = norms[norms_to_log].apply(np.log1p)
    norms.attrs['log1p_applied'] = True
norms_to_log

0             Nsenses_WordNet
1           Nsenses_Wordsmyth
2         Nmeanings_Wordsmyth
3          Nmeanings_Websters
4                   NFeatures
5                       Sem_N
6         Assoc_Freq_Token123
7                 Cue_SetSize
8           LexicalD_RT_V_ELP
9           LexicalD_RT_V_ECP
10          LexicalD_RT_V_BLP
11         LexicalD_RT_A_MALD
12         LexicalD_RT_A_AELP
13              Naming_RT_ELP
14       SemanticD_RT_Calgary
15                  rt_khanna
16                     rt_ley
17               rt_chiarello
18                    rt_chen
19             aoa_rt_cortese
20    imageability_rt_cortese
21                  rt_schock
Name: norm, dtype: object

## Cross Validation

In [8]:
# Ridge regression for continuous norms: one alpha per order of magnitude,
# 10**min_ord .. 10**max_ord (11 values)
min_ord, max_ord = -5, 5
alphas = np.logspace(min_ord, max_ord, num=max_ord - min_ord + 1)
ridge = RidgeCV(alphas=alphas)

# Logistic regression for categorical norms. sklearn's C is the inverse of the
# regularisation strength, so the ridge grid is reused as 1 / alpha.
Cs = 1 / alphas
inner_cv = 5
penalty = 'l2'

# Scorers for each norm type
binary_scoring, multiclass_scoring = make_binary_scoring(), make_multiclass_scoring()
continuous_scoring = {'r2': 'r2', 'neg_mse': 'neg_mean_squared_error'}

# Outer cross-validation folds and number of parallel jobs
outer_cv = 5
n_jobs = 6

# Solo embedding names, handed to `checker` to guard against data leakage
solo_embed_names = [text_name_1, text_name_2, behavior_name]

In [9]:
# RCA: for every norm, cross-validate every solo and ensembled embedding.
# One row per (embed, norm, outer fold) is collected and written to CSV.
rca_rows = []
for norm_name in tqdm(norms.columns):
    print(f'{norm_name}:')
    # Raw target for this norm. Each embedding derives its own aligned and
    # processed copy below, so nothing done for one embedding (alignment,
    # process_categorical) leaks into the evaluation of the next one.
    y_norm = norms[norm_name].dropna()

    # Per-norm metadata is the same for every embedding; read it once
    meta_dtype = meta.loc[norm_name, 'type']
    associated_embed = meta.loc[norm_name, 'associated_embed']

    to_print = []
    for embed_name, embed in embeds.items():

        # Aligning embed with norm (inner join on the word index)
        X, y = embed.align(y_norm, axis='index', join='inner', copy=True)
        norm_dtype = meta_dtype

        # Solver, scoring and estimator depend on whether the norm is categorical or continuous
        if norm_dtype in ['binary', 'multiclass']:  # categorical
            X, y = process_categorical(outer_cv, inner_cv, X, y)

            # Processing may have switched the norm from multiclass to binary
            norm_dtype = 'binary' if len(y.unique()) == 2 else 'multiclass'

            # NOTE(review): the saved output shows "TOTAL NO. of ITERATIONS REACHED
            # LIMIT" convergence warnings from the logistic fits; consider raising
            # max_iter (changes results/runtime) — TODO confirm.
            solver = best_logistic_solver(y, norm_dtype)
            estimator = LogisticRegressionCV(
                Cs=Cs, penalty=penalty, cv=StratifiedKFold(inner_cv), solver=solver
            )
            scoring = binary_scoring if norm_dtype == 'binary' else multiclass_scoring
        else:  # continuous
            estimator, scoring = ridge, continuous_scoring

        # Cross validation (skipped when the leakage/size checker does not pass)
        check = checker(solo_embed_names, y, norm_dtype, associated_embed, outer_cv)
        if check == 'pass':
            scores = k_fold_cross_val(estimator, X, y, outer_cv, scoring, n_jobs) # stratification is automatically used for classification
            r2s, mses = scores['test_r2'], - scores['test_neg_mse']
        else:
            r2s, mses = pd.Series([np.nan] * outer_cv), pd.Series([np.nan] * outer_cv)

        # Nominal training-set size per outer fold. Exact integer floor division
        # avoids float error in ((k - 1) / k) * n truncating to n - 1.
        train_n = len(y) * (outer_cv - 1) // outer_cv
        for fold, (r2, mse) in enumerate(zip(r2s, mses), start=1):
            rca_rows.append([embed_name, norm_name, train_n, fold, r2, mse, check])

        # Per-embedding summary line for the printed leaderboard
        to_print.append([embed_name, r2s.mean(), r2s.std(), check])
    to_print = pd.DataFrame(to_print, columns=['embed', 'r2_mean', 'r2_std', 'check'])
    print(to_print.sort_values('r2_mean', ascending=False).head(10).reset_index(drop=True))
    print('--------------------------------')

rca = pd.DataFrame(
    rca_rows, columns=['embed', 'norm', 'train_n', 'fold', 'r2', 'mse', 'check']
)
rca.to_csv('../../data/results/rca_ensemb.csv', index=False)
rca

  0%|          | 0/292 [00:00<?, ?it/s]

Freq_HAL:
                                  embed   r2_mean    r2_std check
0    fastText_CommonCrawl&PPMI_SVD_SWOW  0.829188  0.010604  pass
1  CBOW_GoogleNews&fastText_CommonCrawl  0.816586  0.013173  pass
2                  fastText_CommonCrawl  0.801812  0.012247  pass
3         CBOW_GoogleNews&PPMI_SVD_SWOW  0.723285  0.013075  pass
4                       CBOW_GoogleNews  0.665994  0.011748  pass
5                         PPMI_SVD_SWOW  0.515268  0.026448  pass
--------------------------------
Freq_KF:
                                  embed   r2_mean    r2_std check
0    fastText_CommonCrawl&PPMI_SVD_SWOW  0.727461  0.025132  pass
1  CBOW_GoogleNews&fastText_CommonCrawl  0.712266  0.028553  pass
2                  fastText_CommonCrawl  0.690077  0.029289  pass
3         CBOW_GoogleNews&PPMI_SVD_SWOW  0.659806  0.021880  pass
4                       CBOW_GoogleNews  0.602879  0.021317  pass
5                         PPMI_SVD_SWOW  0.474919  0.043541  pass
------------------------



                                  embed   r2_mean    r2_std check
0    fastText_CommonCrawl&PPMI_SVD_SWOW  0.602445  0.024413  pass
1  CBOW_GoogleNews&fastText_CommonCrawl  0.593844  0.019808  pass
2                  fastText_CommonCrawl  0.573393  0.025700  pass
3         CBOW_GoogleNews&PPMI_SVD_SWOW  0.550520  0.025434  pass
4                       CBOW_GoogleNews  0.507764  0.018872  pass
5                         PPMI_SVD_SWOW  0.393716  0.038790  pass
--------------------------------
Freq_CobS:
                                  embed   r2_mean    r2_std check
0    fastText_CommonCrawl&PPMI_SVD_SWOW  0.599674  0.020229  pass
1  CBOW_GoogleNews&fastText_CommonCrawl  0.583162  0.018801  pass
2         CBOW_GoogleNews&PPMI_SVD_SWOW  0.565387  0.024336  pass
3                  fastText_CommonCrawl  0.565240  0.023171  pass
4                       CBOW_GoogleNews  0.513011  0.018845  pass
5                         PPMI_SVD_SWOW  0.444104  0.037114  pass
--------------------------------











                                  embed   r2_mean    r2_std check
0                  fastText_CommonCrawl  0.771918  0.006751  pass
1    fastText_CommonCrawl&PPMI_SVD_SWOW  0.771181  0.011474  pass
2  CBOW_GoogleNews&fastText_CommonCrawl  0.765757  0.007572  pass
3         CBOW_GoogleNews&PPMI_SVD_SWOW  0.709303  0.011728  pass
4                       CBOW_GoogleNews  0.699887  0.009189  pass
5                         PPMI_SVD_SWOW  0.479969  0.009043  pass
--------------------------------
DPoS_VanH:












                                  embed   r2_mean    r2_std check
0  CBOW_GoogleNews&fastText_CommonCrawl  0.683204  0.007444  pass
1                  fastText_CommonCrawl  0.681129  0.006797  pass
2    fastText_CommonCrawl&PPMI_SVD_SWOW  0.674923  0.008685  pass
3         CBOW_GoogleNews&PPMI_SVD_SWOW  0.635099  0.005329  pass
4                       CBOW_GoogleNews  0.626675  0.004448  pass
5                         PPMI_SVD_SWOW  0.417472  0.011492  pass
--------------------------------
Conc_Brys:




                                  embed   r2_mean    r2_std check
0    fastText_CommonCrawl&PPMI_SVD_SWOW  0.845518  0.003899  pass
1         CBOW_GoogleNews&PPMI_SVD_SWOW  0.833835  0.004971  pass
2  CBOW_GoogleNews&fastText_CommonCrawl  0.818273  0.001787  pass
3                  fastText_CommonCrawl  0.797529  0.005259  pass
4                       CBOW_GoogleNews  0.767745  0.011398  pass
5                         PPMI_SVD_SWOW  0.751258  0.007806  pass
--------------------------------
Conc_Glasgow:
                                  embed   r2_mean    r2_std check
0    fastText_CommonCrawl&PPMI_SVD_SWOW  0.850106  0.001971  pass
1         CBOW_GoogleNews&PPMI_SVD_SWOW  0.833601  0.003123  pass
2  CBOW_GoogleNews&fastText_CommonCrawl  0.829314  0.006607  pass
3                  fastText_CommonCrawl  0.810710  0.009411  pass
4                       CBOW_GoogleNews  0.786143  0.008561  pass
5                         PPMI_SVD_SWOW  0.747822  0.009598  pass
-----------------------------



                                  embed   r2_mean    r2_std check
0                         PPMI_SVD_SWOW  0.418172  0.032813  pass
1    fastText_CommonCrawl&PPMI_SVD_SWOW  0.403793  0.033730  pass
2         CBOW_GoogleNews&PPMI_SVD_SWOW  0.402412  0.036872  pass
3                  fastText_CommonCrawl  0.375660  0.026135  pass
4  CBOW_GoogleNews&fastText_CommonCrawl  0.367141  0.026476  pass
5                       CBOW_GoogleNews  0.349945  0.023419  pass
--------------------------------
Emot_Assoc_Anticipation:




                                  embed   r2_mean    r2_std check
0         CBOW_GoogleNews&PPMI_SVD_SWOW  0.173487  0.029919  pass
1                         PPMI_SVD_SWOW  0.170380  0.033975  pass
2    fastText_CommonCrawl&PPMI_SVD_SWOW  0.159745  0.036124  pass
3                  fastText_CommonCrawl  0.158850  0.027981  pass
4                       CBOW_GoogleNews  0.156699  0.015277  pass
5  CBOW_GoogleNews&fastText_CommonCrawl  0.149275  0.036816  pass
--------------------------------
Emot_Assoc_Disgust:




                                  embed   r2_mean    r2_std check
0                         PPMI_SVD_SWOW  0.397922  0.031886  pass
1         CBOW_GoogleNews&PPMI_SVD_SWOW  0.393010  0.042778  pass
2    fastText_CommonCrawl&PPMI_SVD_SWOW  0.390471  0.048489  pass
3                  fastText_CommonCrawl  0.389882  0.054060  pass
4  CBOW_GoogleNews&fastText_CommonCrawl  0.381198  0.046997  pass
5                       CBOW_GoogleNews  0.372124  0.043416  pass
--------------------------------
Emot_Assoc_Fear:




                                  embed   r2_mean    r2_std check
0    fastText_CommonCrawl&PPMI_SVD_SWOW  0.379433  0.020732  pass
1                  fastText_CommonCrawl  0.374663  0.013846  pass
2                         PPMI_SVD_SWOW  0.371339  0.037375  pass
3         CBOW_GoogleNews&PPMI_SVD_SWOW  0.367637  0.027068  pass
4  CBOW_GoogleNews&fastText_CommonCrawl  0.361888  0.011947  pass
5                       CBOW_GoogleNews  0.352284  0.016047  pass
--------------------------------
Emot_Assoc_Joy:




                                  embed   r2_mean    r2_std check
0    fastText_CommonCrawl&PPMI_SVD_SWOW  0.418624  0.055374  pass
1         CBOW_GoogleNews&PPMI_SVD_SWOW  0.417561  0.045097  pass
2                         PPMI_SVD_SWOW  0.404239  0.045958  pass
3  CBOW_GoogleNews&fastText_CommonCrawl  0.373743  0.030934  pass
4                  fastText_CommonCrawl  0.367377  0.035434  pass
5                       CBOW_GoogleNews  0.359613  0.027599  pass
--------------------------------
Emot_Assoc_Negative:




                                  embed   r2_mean    r2_std check
0         CBOW_GoogleNews&PPMI_SVD_SWOW  0.513867  0.038848  pass
1    fastText_CommonCrawl&PPMI_SVD_SWOW  0.510466  0.040016  pass
2                         PPMI_SVD_SWOW  0.510382  0.031182  pass
3  CBOW_GoogleNews&fastText_CommonCrawl  0.492611  0.024712  pass
4                  fastText_CommonCrawl  0.476773  0.033086  pass
5                       CBOW_GoogleNews  0.472697  0.032912  pass
--------------------------------
Emot_Assoc_Positive:




                                  embed   r2_mean    r2_std check
0                         PPMI_SVD_SWOW  0.340597  0.049519  pass
1    fastText_CommonCrawl&PPMI_SVD_SWOW  0.338542  0.054684  pass
2         CBOW_GoogleNews&PPMI_SVD_SWOW  0.322858  0.064485  pass
3                  fastText_CommonCrawl  0.314134  0.049435  pass
4  CBOW_GoogleNews&fastText_CommonCrawl  0.304535  0.059290  pass
5                       CBOW_GoogleNews  0.287440  0.057090  pass
--------------------------------
Emot_Assoc_Sadness:




                                  embed   r2_mean    r2_std check
0                         PPMI_SVD_SWOW  0.388995  0.040903  pass
1    fastText_CommonCrawl&PPMI_SVD_SWOW  0.382987  0.053363  pass
2         CBOW_GoogleNews&PPMI_SVD_SWOW  0.378218  0.051497  pass
3                  fastText_CommonCrawl  0.356922  0.052973  pass
4  CBOW_GoogleNews&fastText_CommonCrawl  0.352175  0.045708  pass
5                       CBOW_GoogleNews  0.339112  0.045555  pass
--------------------------------
Emot_Assoc_Surprise:




                                  embed   r2_mean    r2_std check
0                  fastText_CommonCrawl  0.170152  0.031541  pass
1    fastText_CommonCrawl&PPMI_SVD_SWOW  0.166545  0.049851  pass
2         CBOW_GoogleNews&PPMI_SVD_SWOW  0.159957  0.050778  pass
3  CBOW_GoogleNews&fastText_CommonCrawl  0.156872  0.037283  pass
4                       CBOW_GoogleNews  0.144704  0.039108  pass
5                         PPMI_SVD_SWOW  0.142918  0.033372  pass
--------------------------------
Emot_Assoc_Trust:




                                  embed   r2_mean    r2_std check
0    fastText_CommonCrawl&PPMI_SVD_SWOW  0.238483  0.021816  pass
1                         PPMI_SVD_SWOW  0.238322  0.018760  pass
2         CBOW_GoogleNews&PPMI_SVD_SWOW  0.233109  0.021987  pass
3  CBOW_GoogleNews&fastText_CommonCrawl  0.223876  0.012682  pass
4                  fastText_CommonCrawl  0.210081  0.026418  pass
5                       CBOW_GoogleNews  0.168465  0.047997  pass
--------------------------------
Sem_Diversity:
                                  embed   r2_mean    r2_std check
0  CBOW_GoogleNews&fastText_CommonCrawl  0.736929  0.009970  pass
1    fastText_CommonCrawl&PPMI_SVD_SWOW  0.726702  0.009477  pass
2                  fastText_CommonCrawl  0.723172  0.009614  pass
3         CBOW_GoogleNews&PPMI_SVD_SWOW  0.700762  0.009484  pass
4                       CBOW_GoogleNews  0.686559  0.011982  pass
5                         PPMI_SVD_SWOW  0.492860  0.021252  pass
----------------------------

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

                                  embed   r2_mean    r2_std check
0  CBOW_GoogleNews&fastText_CommonCrawl  0.576822  0.070480  pass
1                  fastText_CommonCrawl  0.565732  0.137874  pass
2                       CBOW_GoogleNews  0.551297  0.074591  pass
3    fastText_CommonCrawl&PPMI_SVD_SWOW  0.535748  0.220551  pass
4         CBOW_GoogleNews&PPMI_SVD_SWOW  0.472984  0.178464  pass
5                         PPMI_SVD_SWOW  0.462997  0.290269  pass
--------------------------------
imagery_toronto:
                                  embed   r2_mean    r2_std check
0    fastText_CommonCrawl&PPMI_SVD_SWOW  0.669332  0.056579  pass
1  CBOW_GoogleNews&fastText_CommonCrawl  0.649006  0.039851  pass
2         CBOW_GoogleNews&PPMI_SVD_SWOW  0.640592  0.052086  pass
3                  fastText_CommonCrawl  0.620359  0.057510  pass
4                       CBOW_GoogleNews  0.579855  0.039124  pass
5                         PPMI_SVD_SWOW  0.578119  0.078404  pass
--------------------------

Unnamed: 0,embed,norm,train_n,fold,r2,mse,check
0,PPMI_SVD_SWOW,Freq_HAL,9040,1,0.478024,1.691392,pass
1,PPMI_SVD_SWOW,Freq_HAL,9040,2,0.507218,1.634118,pass
2,PPMI_SVD_SWOW,Freq_HAL,9040,3,0.549852,1.669096,pass
3,PPMI_SVD_SWOW,Freq_HAL,9040,4,0.500720,1.695633,pass
4,PPMI_SVD_SWOW,Freq_HAL,9040,5,0.540526,1.796841,pass
...,...,...,...,...,...,...,...
8755,fastText_CommonCrawl&PPMI_SVD_SWOW,imageability_fear,510,1,0.667308,0.565868,pass
8756,fastText_CommonCrawl&PPMI_SVD_SWOW,imageability_fear,510,2,0.640832,0.705835,pass
8757,fastText_CommonCrawl&PPMI_SVD_SWOW,imageability_fear,510,3,0.592154,0.793009,pass
8758,fastText_CommonCrawl&PPMI_SVD_SWOW,imageability_fear,510,4,0.539998,0.741900,pass
