In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.manifold import MDS
import itertools
from adjustText import adjust_text
import json

# Main plot

In [2]:
rsa = pd.read_csv('../../data/final/rsa.csv')
rsa

Unnamed: 0,name_i,name_j,spearman,n_words,dtype_i,dtype_j
0,CBOW_GoogleNews,fastText_CommonCrawl,0.633686,42374,text,text
1,CBOW_GoogleNews,fastText_Wiki_News,0.540727,41950,text,text
2,CBOW_GoogleNews,fastTextSub_OpenSub,0.461847,39759,text,text
3,CBOW_GoogleNews,GloVe_CommonCrawl,0.499414,42352,text,text
4,CBOW_GoogleNews,GloVe_Twitter,0.220378,32417,text,text
...,...,...,...,...,...,...
295,EEG_text,fMRI_text_hyper_align,0.194555,537,brain,brain
296,EEG_text,microarray,0.115807,266,brain,brain
297,fMRI_speech_hyper_align,fMRI_text_hyper_align,0.208698,315,brain,brain
298,fMRI_speech_hyper_align,microarray,0.048022,138,brain,brain


In [3]:
with open('../../data/raw/embed_to_dtype.json', 'r') as f:
    dtype_to_embed = json.load(f)
    
rsa['dtype_i'] = rsa['name_i'].map(dtype_to_embed)
rsa['dtype_j'] = rsa['name_j'].map(dtype_to_embed)

# Text clustering
text_text = rsa.query('dtype_i == "text" & dtype_j == "text"')['spearman'].mean().round(2)
print(f'Within-text mean correlation {text_text}')
text_other = rsa.query('(dtype_i == "text" & dtype_j != "text") | (dtype_i != "text" & dtype_j == "text")')['spearman'].mean().round(2)
print(f'Text-other mean correlation {text_other}')
print('---------------')

# Brain clustering
brain_brain = rsa.query('dtype_i == "brain" & dtype_j == "brain"')['spearman'].mean().round(2)
print(f'Within-brain mean correlation {brain_brain}')
brain_other = rsa.query('(dtype_i == "brain" & dtype_j != "brain") | (dtype_i != "brain" & dtype_j == "brain")')['spearman'].mean().round(2)
print(f'Brain-other mean correlation {brain_other}')
print('---------------')

# Behavior clustering
behavior_behavior = rsa.query('dtype_i == "behavior" & dtype_j == "behavior"')['spearman'].mean().round(2)
print(f'Within-behavior mean correlation {behavior_behavior}')
behavior_other = rsa.query('(dtype_i == "behavior" & dtype_j != "behavior") | (dtype_i != "behavior" & dtype_j == "behavior")')['spearman'].mean().round(2)
print(f'Behavior-other mean correlation {behavior_other}')
print('---------------')

Within-text mean correlation 0.41
Text-other mean correlation 0.16
---------------
Within-brain mean correlation 0.12
Brain-other mean correlation 0.06
---------------
Within-behavior mean correlation 0.23
Behavior-other mean correlation 0.14
---------------


In [4]:
def to_heat_df(results, col):
    # Heat df template
    names = list(pd.concat([results['name_i'], results['name_j']]).unique()) # needed because not all models guaranteed on one column
    heat_df = pd.DataFrame(index=names, columns=names)

    # Filling with correlations
    query = '(name_i == @name_i & name_j == @name_j) | (name_i == @name_j & name_j == @name_i)'
    for name_i, name_j in list(itertools.combinations(names, 2)):
        r, *_ = results.query(query)[col]
        heat_df.loc[name_i, name_j] = r
        heat_df.loc[name_j, name_i] = r
        
        order = text_names + brain_names + behavior_names

    return heat_df.loc[order, order].astype(float)


with open('../../data/raw/dtype_to_embed.json', 'r') as f:
    dtype_to_embed = json.load(f)

text_names = dtype_to_embed['text']
brain_names = dtype_to_embed['brain']
behavior_names = dtype_to_embed['behavior']

spearmans = to_heat_df(rsa, 'spearman')
spearmans

KeyError: "['PPMI_SVD_EAT'] not in index"

In [None]:
dissimilarity = 1 - spearmans
np.fill_diagonal(dissimilarity.values, 0.0)

# MDS
mds = MDS(n_components=2, dissimilarity='precomputed', random_state=2)
spearmans_2d = mds.fit_transform(dissimilarity)
spearmans_2d = pd.DataFrame(spearmans_2d, index=spearmans.index)

def data_type(mod_name):
    if mod_name in brain_names:
        return 'brain'
    elif mod_name in behavior_names:
        return 'behavior'
    else:
        return 'text'

# Adding data type
spearmans_2d['embed_type'] = [data_type(name) for name in spearmans_2d.index]
spearmans_2d

In [None]:
fig, (ax_1, ax_2) = plt.subplots(1, 2, figsize=(20, 8), width_ratios=(0.8, 1))

# MDS
cmap = plt.get_cmap('viridis', 4)
embed_type_to_color = {
    'brain': cmap(1),
    'behavior': cmap(0),
    'text': cmap(2)
}

sns.scatterplot(
    data=spearmans_2d, x=0, y=1, hue='embed_type',
    sizes=(500, 500), legend=False, s=110,
    marker='s', linewidth=0.1, edgecolor='black',
    palette=embed_type_to_color, ax=ax_1
)

ax_1.set(xticklabels='', yticklabels='', xlabel='', ylabel='')

texts = []
for model in spearmans.index:
    texts.append(
        ax_1.text(spearmans_2d[0][model], spearmans_2d[1][model], model.replace('_', ' '), fontsize=13)
    )

# Adjust text labels to avoid overlap
adjust_text(
    texts, arrowprops=dict(arrowstyle='-', color='black', lw=.5), ax=ax_1
)
ax_1.axis('off') # Turn off the axis

# Heatmap 
spearmans.index = spearmans.index.str.replace('_', ' ')
spearmans.columns = spearmans.columns.str.replace('_', ' ')
sns.heatmap(
    spearmans, square=True, annot=True, cmap='viridis',
    vmin=0, vmax=spearmans.max().max(), 
    fmt='.2f', annot_kws={"fontsize": 6}, cbar=False, ax=ax_2
)

# Adding bold panel labels
ax_1.text(-0.1, 1.05, 'A', transform=ax_1.transAxes, fontsize=20, fontweight='bold', va='top')
ax_2.text(-0.4, 1.05, 'B', transform=ax_2.transAxes, fontsize=20, fontweight='bold', va='top')

fig.tight_layout()
plt.savefig('../../figures/rsa.png', dpi=300, bbox_inches='tight')