In [None]:
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import json

## Loading data

In [None]:
rca_full = pd.read_csv('../../data/results/rca.csv')
meta = pd.read_csv('../../data/psychNorms/psychNorms_metadata.csv', index_col='norm')

print(f"# Norms: {len(meta.index.unique())}")
print(f"# Norm categories: {len(meta['category'].unique())}")
print(f"# Embeds: {len(rca_full['embed'].unique())}")

# Every norm in the results must exist in the metadata. With the vectorised map below
# an unknown norm would get a NaN category, and the later dropna() would silently
# discard those rows; fail here instead.
unknown_norms = rca_full.loc[~rca_full['norm'].isin(meta.index), 'norm'].unique()
assert len(unknown_norms) == 0, f"Norms missing from psychNorms metadata: {unknown_norms[:10]}"

# Adding norm category (one vectorised lookup instead of a .loc call per row),
# with underscores turned into spaces for display
rca_full['norm_category'] = (
    rca_full['norm'].map(meta['category'])
    .replace({'_': ' '}, regex=True)
)
rca_full

## Evaluating how many embed-norm pairs didn't pass the checker

In [None]:
# Tally of the values in the `check` column; dropna=False keeps a NaN bucket in the counts
rca_full.loc[:, 'check'].value_counts(dropna=False)

In [None]:
def groupby_pivot(df):
    """Count non-null `norm` values per (embed, norm_category) as a wide table.

    Rows are embeds, columns are norm categories (both sorted). Combinations
    that have no rows in `df` are NaN.
    """
    per_pair = df.groupby(['embed', 'norm_category'])['norm'].count()
    return per_pair.unstack('norm_category')

# Keep only rows with no NaN in any column. .copy() makes rca an independent frame so
# later column assignments on it do not raise SettingWithCopyWarning.
rca = rca_full.dropna().copy()
rca_full_counts = groupby_pivot(rca_full[['embed', 'norm_category', 'norm']])
rca_counts = groupby_pivot(rca[['embed', 'norm_category', 'norm']])

rca_full_counts, rca_counts = rca_full_counts.align(rca_counts, join='outer')
# A pair present in the full table that lost every norm to dropna() is missing from
# rca_counts: count it as 0 (not NaN) so its retention reads 0% instead of blank.
# Pairs absent from the full table stay NaN.
rca_counts = rca_counts.fillna(0).where(rca_full_counts.notna())
perc_retained = (rca_counts / rca_full_counts) * 100
perc_retained

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(25, 8))

# (data that drives the colours, annotations, title) per panel, left to right.
# The retention panel is coloured by the negated percentage, which maps low
# retention to the high (yellow) end of viridis, while the annotations show
# the actual rounded percentages.
panel_specs = [
    (rca_full_counts, True, 'Full RCA'),
    (rca_counts, True, 'RCA'),
    (-perc_retained, perc_retained.round(0), 'RCA / Full RCA (%)'),
]
for ax, (data, annotations, title) in zip(axs, panel_specs):
    sns.heatmap(data, ax=ax, cmap='viridis', annot=annotations, fmt='g', cbar=False)
    ax.set_title(title)

# Only the leftmost panel keeps its y label and embed tick labels
for ax in axs[1:]:
    ax.set(ylabel='')
    ax.set_yticklabels([])

fig.tight_layout()

# Heatmap

In [None]:
# Mapping embed name -> data type; the types compared below are 'text', 'brain' and 'behavior'
with open('../../data/embed_to_dtype.json', 'r', encoding='utf-8') as f:
    embed_to_type = json.load(f)

# assign() returns a new frame rather than writing into rca (a dropna() result),
# which avoids SettingWithCopyWarning and keeps the cell safe to re-run
rca = rca.assign(embed_type=rca['embed'].map(embed_to_type))
rca

In [None]:
# Median (not mean) of r2_mean across the norms of each (norm_category, embed) pair
rca_avg = rca.groupby(['norm_category', 'embed'], as_index=False)['r2_mean'].median()
rca_avg

In [None]:
# Wide table: one row per embed, one column per norm category, cells = median r2_mean
rca_avg_piv = rca_avg.set_index(['embed', 'norm_category'])['r2_mean'].unstack('norm_category')
rca_avg_piv

In [None]:
# True where an embed has the highest value in its norm-category column (ties are all True;
# NaN cells are False)
winner_mask = rca_avg_piv.eq(rca_avg_piv.max())
winner_mask

In [None]:
# Function to create a lighter version of a colormap
def lighten_cmap(cmap_name, factor=0.3, n_colors=256):
    """Return a copy of a named colormap blended towards white.

    Parameters
    ----------
    cmap_name : str
        Name of a registered matplotlib colormap (e.g. 'viridis').
    factor : float, default 0.3
        Weight of white in the blend, in [0, 1]: 0 keeps the original colours,
        1 turns every colour white.
    n_colors : int, default 256
        Number of colours sampled from the original colormap.

    Returns
    -------
    LinearSegmentedColormap
        Colormap named f'light_{cmap_name}'.

    Raises
    ------
    ValueError
        If `factor` is outside [0, 1].
    """
    if not 0 <= factor <= 1:
        raise ValueError(f'factor must be in [0, 1], got {factor}')

    # pyplot.get_cmap is used because matplotlib.cm.get_cmap (plt.cm.get_cmap)
    # was deprecated in matplotlib 3.7 and removed in 3.9
    cmap = plt.get_cmap(cmap_name, n_colors)
    colors = cmap(np.linspace(0, 1, n_colors))

    # Blend each RGBA colour with opaque white; alpha stays 1 for opaque inputs
    white = np.array([1, 1, 1, 1])
    new_colors = (1 - factor) * colors + factor * white

    return LinearSegmentedColormap.from_list(f'light_{cmap_name}', new_colors, N=n_colors)

# Function to visualize a colormap
def plot_colormap(cmap):
    """Preview `cmap` as a borderless horizontal gradient bar.

    Draws with the pyplot state machine on the current figure and calls plt.show().
    """
    ramp = np.tile(np.linspace(0, 1, 256), (2, 1))  # 2 x 256 array, same row twice
    plt.imshow(ramp, aspect='auto', cmap=cmap)
    plt.axis('off')
    plt.show()

# Lighter viridis (60% blend towards white), used as the cmap of the R² heatmap below,
# previewed here as a gradient bar
lighter_viridis = lighten_cmap('viridis', factor=0.6)
plot_colormap(lighter_viridis)

In [None]:
def annotate(df, ax, mask=None):
    """Write each cell value of `df` as text on the heatmap drawn on `ax`.

    The cell at column position x and row position y is labelled at
    (x + .5, y + .6). Labels are empty for NaN, use scientific notation
    ('.1e') when |value| > 1e3, and two decimals otherwise. Cells flagged
    in `mask` are drawn at fontsize 15 in bold, the rest at 11, normal.

    Parameters
    ----------
    df : pd.DataFrame
        Values to annotate; its index and column labels must exist in `mask`.
    ax : matplotlib Axes
        Axes holding the heatmap of `df`.
    mask : pd.DataFrame of bool, optional
        Which cells to highlight. Defaults to the notebook-level `winner_mask`.
    """
    if mask is None:
        mask = winner_mask

    for x, norm_cat in enumerate(df.columns):
        for y, embed in enumerate(df.index):
            value = df.loc[embed, norm_cat]

            # NaN is tested before the magnitude check so the branch order is explicit
            if pd.isna(value):
                label = ''
            elif abs(value) > 1e3:
                label = f'{value:.1e}'  # scientific notation for very large magnitudes
            else:
                label = f'{value:.2f}'

            # Fontsize and fontweight
            if mask.loc[embed, norm_cat]:
                fontsize, fontweight = 15, 'bold'
            else:
                fontsize, fontweight = 11, 'normal'

            ax.text(
                x + .5, y + .6, label, fontsize=fontsize, fontweight=fontweight,
                ha='center', va='center', color='black'
            )

# Embed type of every row of the pivot (KeyError if an embed is missing from embed_to_type)
piv_embed_types = rca_avg_piv.index.map(lambda embed: embed_to_type[embed])

# Behavior embed with the highest mean r2 across norm categories
top_behav = (
    rca_avg_piv.loc[piv_embed_types == 'behavior']
    .mean(axis=1)
    .idxmax()
)

# Sorts norm categories by the r2 of the top behavior embed (ascending)
norm_ord = rca_avg_piv.loc[top_behav].sort_values(ascending=True).index

# Embed types in panel order, top to bottom (used both to build and to plot)
embed_types = ['text', 'behavior', 'brain']

# Builds one heatmap df per embed type: rows sorted by mean r2 (descending),
# columns in norm_ord order
heat_dfs = {}
for embed_type in embed_types:
    heat_df = rca_avg_piv.loc[piv_embed_types == embed_type]
    embed_order = heat_df.mean(axis=1).sort_values(ascending=False).index
    heat_dfs[embed_type] = heat_df.loc[embed_order, norm_ord]

fig, axs = plt.subplots(len(embed_types), 1, figsize=(18, 10))

# Shared colour scale so cells are comparable across panels
vmax = rca_avg_piv.max().max()
for ax, embed_type in zip(axs, embed_types):
    heat_df = heat_dfs[embed_type]

    sns.heatmap(
        heat_df, ax=ax, vmin=0, vmax=vmax, cmap=lighter_viridis,
        annot=False, fmt='', cbar=False,
    )
    ax.set(xlabel='', xticklabels=[])

    # sets ylabel on right-hand side and flips it
    ax.set_ylabel(
        embed_type.title(), fontsize=17, rotation=270,
        labelpad=20, va='center', ha='center'
    )
    ax.yaxis.set_label_position('right')

    # Annotates cells; heat_df keeps the raw embed names, which annotate looks up
    annotate(heat_df, ax)

    # One tick per row, set explicitly so every embed gets a label even if
    # seaborn's automatic tick density would have skipped some
    ax.set_yticks(np.arange(len(heat_df.index)) + .5)
    embed_labels = (
        heat_df.index
        .str.replace('SVD_sim_rel', 'SVD_similarity_relatedness', regex=False)
        .str.replace('_', ' ', regex=False)
    )
    ax.set_yticklabels(embed_labels, fontsize=12)

# Adding xticklabels to last plot; ticks are set explicitly (one per column)
# so the labels cannot drift onto the wrong columns
norm_labels = norm_ord.str.title().str.replace(' Of ', ' of ', regex=False)
axs[-1].set_xticks(np.arange(len(norm_ord)) + .5)
axs[-1].set_xticklabels(norm_labels, rotation=90, fontsize=13)

# Sets figure title
axs[0].set_title('Average Test ${R^2}$', fontsize=20)

fig.tight_layout()
fig.savefig('../../figures/rca.png', dpi=300, bbox_inches='tight')

## Descriptive statistics

In [None]:
# Embed x norm-category table of median r2_mean, shown again as the input to the statistics below
rca_avg_piv

In [None]:
# Embedding-wise medians: for each embed, the median over norm categories of its
# per-category median r2, next to the embed's dtype from embed_to_type
embed_medians = (
    rca_avg_piv.median(axis=1)
    .to_frame(name='median')
    .assign(dtype=lambda frame: frame.index.map(embed_to_type))
)
embed_medians

In [None]:
# Per dtype: median, 1st and 3rd quartiles of the embed-wise medians
dtype_medians = embed_medians.groupby('dtype')['median'].agg(
    median='median',
    first_quartile=lambda s: s.quantile(0.25),
    third_quartile=lambda s: s.quantile(0.75),
)
dtype_medians.round(2).sort_values(by='median', ascending=False)

## Comparing performance of the best-performing embeds from each type

In [None]:
rca_avg['embed_type'] = rca_avg['embed'].map(embed_to_type)

# Taking the best-performing embed from each type: select the whole row holding the
# largest r2 in each (embed_type, norm_category) group, so the `embed` column names the
# embed that actually achieved r2_max. A column-wise .max() would instead pair r2_max
# with the alphabetically last embed name of the group.
best_rows = rca_avg.groupby(['embed_type', 'norm_category'])['r2_mean'].idxmax()
rca_grand_avg = (
    rca_avg.loc[best_rows, ['embed_type', 'norm_category', 'embed', 'r2_mean']]
    .reset_index(drop=True)
    .rename(columns={'r2_mean': 'r2_max'})
)
rca_grand_avg

In [None]:
# Per norm category: best r2 of each embed type and the behavior - text gap, sorted by
# that gap (the top and bottom rows are the norms reported in the paper).
# The result gets its own name so rca_grand_avg stays the long table and the cell
# can be re-run.
rca_grand_diff = rca_grand_avg.pivot(index='norm_category', columns='embed_type', values='r2_max')
rca_grand_diff['behavior - text'] = rca_grand_diff['behavior'] - rca_grand_diff['text']
rca_grand_diff = rca_grand_diff.sort_values(by='behavior - text', ascending=True).round(2)
rca_grand_diff