In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.manifold import MDS
import itertools
from adjustText import adjust_text
import json

## Loading data

In [None]:
# Pairwise RSA results: one row per (name_i, name_j) embedding pair.
rsa = pd.read_csv('../../data/results/rsa.csv')

# Lookup from embedding name to its data type label.
with open('../../data/embed_to_dtype.json', 'r') as f:
    embed_to_dtype = json.load(f)

# Tag both sides of every pair with its data type; names missing from the
# lookup end up as NaN because `Series.map` does not raise on unknown keys.
rsa = rsa.assign(
    dtype_i=lambda df: df['name_i'].map(embed_to_dtype),
    dtype_j=lambda df: df['name_j'].map(embed_to_dtype),
)

rsa

## Descriptive stats

**Within- vs. between-dtype correlations**

In [None]:
def dtype_corr(dtype_i, dtype_j):
    """Mean Spearman correlation between two data types, rounded to 2 decimals.

    Reads the notebook-level `rsa` frame and keeps every row whose
    (dtype_i, dtype_j) labels match the requested pair in either order.

    Args:
        dtype_i: Data type label of one side of the pair.
        dtype_j: Data type label of the other side of the pair.

    Returns:
        The mean of the matching `spearman` values, rounded to 2 decimals.
    """
    forward = (rsa['dtype_i'] == dtype_i) & (rsa['dtype_j'] == dtype_j)
    backward = (rsa['dtype_i'] == dtype_j) & (rsa['dtype_j'] == dtype_i)
    return rsa.loc[forward | backward, 'spearman'].mean().round(2)

# Within-dtype correlations (each data type against itself)
text_text = dtype_corr('text', 'text')
brain_brain = dtype_corr('brain', 'brain')
behavior_behavior = dtype_corr('behavior', 'behavior')
print('\n'.join([
    f'text-text: {text_text}',
    f'brain-brain: {brain_brain}',
    f'behavior-behavior: {behavior_behavior}',
    '---------------',
]))

# Between-dtype correlations (each data type against another)
text_brain = dtype_corr('text', 'brain')
text_behavior = dtype_corr('text', 'behavior')
brain_behavior = dtype_corr('brain', 'behavior')
print('\n'.join([
    f'Text-brain: {text_brain}',
    f'Text-behavior: {text_behavior}',
    f'Brain-behavior: {brain_behavior}',
]))

**Proportion of same-dtype neighbors**

In [None]:
# Lookup from data type label to the list of embedding names of that type.
with open('../../data/dtype_to_embed.json', 'r') as f:
    dtype_to_embed = json.load(f)

# Number of most-correlated pairs inspected per embedding.
k = 3

# For each data type: the share of top-k neighbors (by Spearman) that belong
# to the same data type, pooled over all embeddings of that type.
same_dtype_props = {}
for dtype, names in dtype_to_embed.items():
    same_dtype_bool = []
    for name in names:
        # The k highest-correlation rows in which `name` appears on either side.
        top_rows = rsa.query('name_i == @name | name_j == @name').nlargest(k, 'spearman')
        k_neighbors = top_rows[['name_i', 'name_j']].to_numpy().flatten()
        # Each row lists both members of the pair; drop `name` itself so only
        # the partner embeddings are scored.
        same_dtype_bool.extend(
            embed_to_dtype[neighbor] == dtype
            for neighbor in k_neighbors
            if neighbor != name
        )

    same_dtype_props[dtype] = round(np.mean(same_dtype_bool), 2)

same_dtype_props

## MDS

In [None]:
# Embedding names per data type, used for ordering and for type lookups.
text_names, brain_names, behavior_names = (
    dtype_to_embed[key] for key in ('text', 'brain', 'behavior')
)

# Number of embeddings available for each data type.
print(dict(zip(dtype_to_embed, map(len, dtype_to_embed.values()))))


def to_heat_df(results, col, order=None):
    """Pivot long-format pairwise results into a symmetric square matrix.

    Args:
        results: DataFrame with `name_i` and `name_j` columns identifying each
            pair, plus the value column `col`.
        col: Name of the column in `results` holding the pairwise value.
        order: Optional sequence of names giving the row/column order of the
            returned matrix. Defaults to the notebook-level
            `text_names + behavior_names + brain_names`.

    Returns:
        A float DataFrame indexed and columned by `order`. Off-diagonal cells
        hold the pairwise value (mirrored across the diagonal); the diagonal is
        left as NaN because only distinct pairs are filled.

    Raises:
        ValueError: If a pair of distinct names has no row in `results`.
        KeyError: If `order` contains a name that never appears in `results`.
    """
    # Resolved up front so the function also works when there are fewer than
    # two names (the pair loop below would then never run).
    if order is None:
        order = text_names + behavior_names + brain_names

    # Heat df template; names are gathered from both columns because a model is
    # not guaranteed to appear in just one of them.
    names = list(pd.concat([results['name_i'], results['name_j']]).unique())
    heat_df = pd.DataFrame(index=names, columns=names, dtype=float)

    # One pass over the rows builds an orientation-independent lookup. The first
    # row seen for a pair wins, which matches taking the first query hit.
    pair_to_value = {}
    for name_i, name_j, value in zip(results['name_i'], results['name_j'], results[col]):
        pair_to_value.setdefault(frozenset((name_i, name_j)), value)

    # Filling with correlations
    for name_i, name_j in itertools.combinations(names, 2):
        try:
            r = pair_to_value[frozenset((name_i, name_j))]
        except KeyError:
            raise ValueError(
                f'No {col!r} value found for pair ({name_i!r}, {name_j!r})'
            ) from None
        heat_df.loc[name_i, name_j] = r
        heat_df.loc[name_j, name_i] = r

    return heat_df.loc[order, order].astype(float)


# Square, symmetric matrix of pairwise Spearman correlations between embeddings.
spearmans = to_heat_df(results=rsa, col='spearman')
spearmans

In [None]:
# Turn correlations into dissimilarities. The array is built explicitly rather
# than written through `DataFrame.values`: that attribute is not guaranteed to
# be a writable view of the frame (it is read-only under pandas Copy-on-Write),
# so an in-place `np.fill_diagonal` on it can fail or silently do nothing.
dissimilarity_values = 1 - spearmans.to_numpy(dtype=float)
# The correlation matrix has NaN on its diagonal; an embedding is at distance 0
# from itself.
np.fill_diagonal(dissimilarity_values, 0.0)
dissimilarity = pd.DataFrame(
    dissimilarity_values, index=spearmans.index, columns=spearmans.columns
)

# MDS (fixed random_state so the 2-D layout is reproducible)
mds = MDS(n_components=2, dissimilarity='precomputed', random_state=0)
mds_coords = mds.fit_transform(dissimilarity)
# Columns 0 and 1 hold the two MDS dimensions; rows keep the embedding names.
spearmans_2d = pd.DataFrame(mds_coords, index=spearmans.index)

def data_type(mod_name):
    """Return the data type label ('brain', 'behavior' or 'text') for a name.

    Membership is checked against the notebook-level `brain_names` and
    `behavior_names` lists, in that order. Any name found in neither list is
    labelled 'text', including names that are not known at all.
    """
    for label, members in (('brain', brain_names), ('behavior', behavior_names)):
        if mod_name in members:
            return label
    return 'text'

# Label every embedding (row) with its data type; used as the hue when plotting.
spearmans_2d['embed_type'] = list(map(data_type, spearmans_2d.index))
spearmans_2d

## Plotting

In [None]:
def fix_name(name: str) -> str:
    """Return a display-friendly version of a single embedding name.

    Args:
        name: Raw embedding name (one index/column label, not a Series).

    Returns:
        The name with a more descriptive alias substituted where one is
        defined, and with every underscore replaced by a space.
    """
    rename = {
        'compo_attribs': 'experiential_attributes',
        'SVD_sim_rel': 'SVD_similarity_relatedness'
    }
    return rename.get(name, name).replace('_', ' ')


# Rename some embeds and remove underscores.
# Each frame maps its OWN labels, so the result no longer depends on
# `spearmans` and `spearmans_2d` happening to share the same row order.
# NOTE: `data_type` matches against the raw names, so the `embed_type` column
# must already have been computed before this cell renames the index.
spearmans.index = spearmans.index.map(fix_name)
spearmans.columns = spearmans.columns.map(fix_name)
spearmans_2d.index = spearmans_2d.index.map(fix_name)

# Colors: three of four evenly spaced viridis samples, one per data type.
cmap = plt.get_cmap('viridis', 4)
embed_type_to_color = {
    'brain': cmap(1),
    'behavior': cmap(0),
    'text': cmap(2)
}


In [None]:
def plot_scatter_panel(ax, spearmans_2d, embed_type_to_color, spearmans, add_label=False):
    """Draw the labelled 2-D scatter of embeddings on the given Axes `ax`.

    Args:
        ax: Matplotlib Axes to draw on.
        spearmans_2d: DataFrame with coordinate columns 0 and 1 and an
            'embed_type' column, indexed by embedding name.
        embed_type_to_color: Mapping from 'embed_type' value to marker color.
        spearmans: Frame whose index lists the names to annotate; each name is
            looked up in `spearmans_2d`.
        add_label: If True, use the larger annotation font and draw the bold
            panel label 'A'.
    """
    annot_fontsize = 13 if add_label else 12

    # NOTE(review): `sizes` looks inert here since no `size` variable is
    # mapped and `s` fixes the marker size — confirm before removing.
    sns.scatterplot(
        data=spearmans_2d, x=0, y=1, hue='embed_type',
        sizes=(500, 500), legend=False, s=110,
        marker='s', linewidth=0.1, edgecolor='black',
        palette=embed_type_to_color, ax=ax
    )

    # Strip axis labels and tick labels
    ax.set(xticklabels='', yticklabels='', xlabel='', ylabel='')

    # One text annotation per embedding, placed at its 2-D coordinates
    x_coords, y_coords = spearmans_2d[0], spearmans_2d[1]
    texts = [
        ax.text(
            x_coords[model],
            y_coords[model],
            model.replace('_', ' '),
            fontsize=annot_fontsize
        )
        for model in spearmans.index
    ]
    # Let adjustText reposition the annotations, with thin connector lines
    adjust_text(
        texts, arrowprops={'arrowstyle': '-', 'color': 'black', 'lw': .5}, ax=ax
    )

    ax.axis('off')

    # Panel label is optional
    if not add_label:
        return
    ax.text(
        -0.1, 1.05, 'A',
        transform=ax.transAxes,
        fontsize=20,
        fontweight='bold',
        va='top'
    )


def plot_heatmap_panel(ax, spearmans, add_label=False):
    """Draw the annotated correlation heatmap on the given Axes `ax`.

    Args:
        ax: Matplotlib Axes to draw on.
        spearmans: Square DataFrame of correlations to display.
        add_label: If True, use the larger fonts and draw the bold panel
            label 'B'.
    """
    # Larger fonts when the panel is part of the labelled figure
    annot_fontsize, ticklabel_fontsize = (7, 11) if add_label else (5, 9)

    # Color scale runs from 0 to the largest value in the frame; the pandas
    # max is used so NaN cells are skipped.
    color_max = spearmans.max().max()
    sns.heatmap(
        spearmans, square=True, annot=True, cmap='viridis',
        vmin=0, vmax=color_max,
        fmt='.2f', annot_kws={"fontsize": annot_fontsize}, cbar=False, ax=ax
    )

    # Font size of the x and y tick labels
    ax.tick_params(axis='both', which='major', labelsize=ticklabel_fontsize)

    # Panel label is optional
    if not add_label:
        return
    ax.text(
        -0.4, 1.05, 'B',
        transform=ax.transAxes,
        fontsize=20,
        fontweight='bold',
        va='top'
    )


# Two-panel figure: MDS scatter (A, narrower) next to the correlation heatmap (B)
fig, (ax_1, ax_2) = plt.subplots(
    nrows=1, ncols=2, figsize=(20, 8), width_ratios=(0.6, 1)
)

# Draw both panels with their bold panel letters
plot_scatter_panel(
    ax=ax_1,
    spearmans_2d=spearmans_2d,
    embed_type_to_color=embed_type_to_color,
    spearmans=spearmans,
    add_label=True,
)
plot_heatmap_panel(ax=ax_2, spearmans=spearmans, add_label=True)

fig.tight_layout()
# Write the figure to disk before showing it
fig.savefig('../../figures/rsa.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Standalone MDS scatter: no panel letter, smaller annotation font
fig_a, ax_a = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
plot_scatter_panel(
    ax=ax_a,
    spearmans_2d=spearmans_2d,
    embed_type_to_color=embed_type_to_color,
    spearmans=spearmans,
    add_label=False,
)
fig_a.tight_layout()
plt.show()

In [None]:
# Standalone RSA correlation matrix: no panel letter, smaller fonts
fig_b, ax_b = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
plot_heatmap_panel(ax=ax_b, spearmans=spearmans, add_label=False)
fig_b.tight_layout()
plt.show()