In [None]:
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import json
from scipy.stats import wilcoxon

# Processing data

In [None]:
# Read the ensemble results (dropping every row that has a missing value)
# and the norm metadata, whose first CSV column is used as the index.
rca = pd.read_csv('../../data/results/rca_ensemb.csv')
rca = rca.dropna()
meta = pd.read_csv('../../data/psychNorms/psychNorms_metadata.csv', index_col=0)

# Look up each norm's category in the metadata, then swap underscores for
# spaces before storing it on rca.
norm_categories = rca['norm'].apply(lambda norm_name: meta.loc[norm_name]['category'])
rca['norm_cat'] = norm_categories.replace({'_': ' '}, regex=True)


# Lookup table keyed by embedding name, parsed from JSON.
with open('../../data/embed_to_dtype.json', 'r') as dtype_file:
    embed_to_type = json.loads(dtype_file.read())
    
def embed_to_group(embed_name, type_map=None):
    """Translate an embedding name into its group label.

    An ensemble name is a list of component embedding names joined by '&'.
    Each component is looked up in ``type_map`` and the looked-up values are
    joined by '&' again, in the same order. A name without '&' is simply
    looked up on its own.

    Parameters
    ----------
    embed_name : str
        A single embedding name, or several names joined by '&'. Any number
        of components is accepted (not only one or two).
    type_map : mapping, optional
        Maps each component name to its type label (labels are expected to be
        strings). Defaults to the module-level ``embed_to_type``.

    Returns
    -------
    str
        The type label, or the '&'-joined type labels for an ensemble.

    Raises
    ------
    KeyError
        If a component name is missing from ``type_map``.
    """
    if type_map is None:
        type_map = embed_to_type
    # split('&') yields a one-element list when there is no '&', so the
    # single-embedding case needs no separate branch.
    return '&'.join(type_map[name] for name in embed_name.split('&'))

# Tag every row with the group of its embedding (or ensemble of embeddings).
rca['embed_group'] = rca['embed'].map(embed_to_group)
rca

In [None]:
# Step 1: mean r2 within each (embed_group, norm, fold).
r2_by_fold = (
    rca[['embed_group', 'norm', 'fold', 'r2']]
    .groupby(['embed_group', 'norm', 'fold'], as_index=False)
    .mean(numeric_only=True)
)

# Step 2: mean of those fold-level values within each (embed_group, norm).
# The fold column gets averaged along the way and is discarded afterwards.
r2_by_norm = (
    r2_by_fold
    .groupby(['embed_group', 'norm'], as_index=False)
    .mean(numeric_only=True)
)

rca_mean = r2_by_norm.drop(columns='fold').rename(columns={'r2': 'r2_mean'})
rca_mean

In [None]:
# Attach the metadata category of each norm (kept exactly as stored in meta).
rca_mean['norm_cat'] = rca_mean['norm'].apply(
    lambda norm_name: meta.loc[norm_name]['category']
)
rca_mean

**Grand avgs**

In [None]:
# The "grand average" is the median of the per-norm mean r2 within each
# (embed_group, norm_cat) pair.
r2_by_group_and_cat = rca_mean[['embed_group', 'norm_cat', 'r2_mean']]
rca_grand_avg = (
    r2_by_group_and_cat
    .groupby(['embed_group', 'norm_cat'], as_index=False)
    .median(numeric_only=True)
    .rename(columns={'r2_mean': 'r2_grand_avg'})
)
rca_grand_avg

In [None]:
# Rank embed groups by their grand average, averaged over norm categories
# (lowest first, so the top group is the last row).
mean_by_group = rca_grand_avg.groupby('embed_group')['r2_grand_avg'].mean()
sorted_embeds = mean_by_group.sort_values().reset_index()
sorted_embeds

**Grand avg diffs**

In [None]:
# Reshape to one row per (norm, fold, norm_cat) and one r2 column per
# embedding, so pairwise differences become plain column subtractions.
r2_long = rca[['embed', 'norm', 'norm_cat', 'fold', 'r2']]
rca_wide = r2_long.pivot(
    index=['norm', 'fold', 'norm_cat'], columns='embed', values='r2'
)
rca_wide = rca_wide.reset_index()
rca_wide

In [None]:
# Names of the embeddings / ensembles being compared
text_1 = 'CBOW_GoogleNews'  # Text
text_2 = 'fastText_CommonCrawl'  # Text
text_text = '&'.join([text_1, text_2])  # Text & Text
text_behav_1 = '&'.join([text_1, 'PPMI_SVD_SWOW'])  # Text and Behavior 1
text_behav_2 = '&'.join([text_2, 'PPMI_SVD_SWOW'])  # Text and Behavior 2

# Text & Behavior - Text & Text: one new difference column per
# text & behavior ensemble, stored on rca_wide.
diff_cols = []
for text_behav in (text_behav_1, text_behav_2):
    diff_col = f'{text_behav} vs {text_text}'
    rca_wide[diff_col] = rca_wide[text_behav] - rca_wide[text_text]
    diff_cols.append(diff_col)

# Long format: one row per (norm, fold, comparison). melt names the variable
# column after the column-axis name ('embed'), which is then renamed.
tb_vs_tt = (
    rca_wide[['norm', 'fold', 'norm_cat', *diff_cols]]
    .melt(id_vars=['norm', 'norm_cat', 'fold'])
    .rename(columns={'embed': 'comparing'})
)
tb_vs_tt

In [None]:
# Computing tb_vs_tt_mean analogously to how we compute diffs for the wilcoxon test below:
# first the mean difference within each (norm, fold), then the mean of those per norm.
diff_by_norm_fold = (
    tb_vs_tt
    .groupby(['norm', 'fold'], as_index=False)
    .mean(numeric_only=True)
)
diff_by_norm = diff_by_norm_fold.groupby('norm', as_index=False).mean(numeric_only=True)
tb_vs_tt_mean = diff_by_norm.drop(columns='fold').rename(columns={'value': 'r2_diff'})

# The string norm_cat column is lost in the numeric-only means, so look it up again.
tb_vs_tt_mean['norm_cat'] = tb_vs_tt_mean['norm'].apply(
    lambda norm_name: meta.loc[norm_name]['category']
)
tb_vs_tt_mean

In [None]:
# Median per-norm difference within each norm category.
diff_median_by_cat = (
    tb_vs_tt_mean
    .groupby('norm_cat', as_index=False)
    .median(numeric_only=True)
)
tb_vs_tt_grand_avg = diff_median_by_cat.rename(columns={'r2_diff': 'r2_diff_grand_avg'})
tb_vs_tt_grand_avg

## Plotting 

In [None]:
# Wide layout for the heatmap: one row per embed group (in a fixed display
# order), one column per norm category.
group_order = ['text', 'behavior', 'text&text', 'text&behavior']
grand_avg_wide = rca_grand_avg.pivot(
    index='embed_group', columns='norm_cat', values='r2_grand_avg'
)
grand_avg_wide = grand_avg_wide.loc[group_order]

# Columns go from the weakest to the strongest text&behavior category.
norm_cat_order = grand_avg_wide.loc['text&behavior'].sort_values().index
heat_df_1 = grand_avg_wide[norm_cat_order]

# Human-readable row and column labels
heat_df_1.index = heat_df_1.index.str.replace('&', ' & ').str.title()
heat_df_1.columns = heat_df_1.columns.str.replace('_', ' ')
heat_df_1

In [None]:
# True where a cell equals the maximum of its column (ties all count).
column_best = heat_df_1.max(axis=0)
heat_df_1_winners = heat_df_1.eq(column_best, axis='columns')
heat_df_1_winners

In [None]:
# Pivot tb_vs_tt_grand_avg for plotting. The single comparison gets a label
# so it can serve as the pivot's row index (this adds a 'comparing' column
# to tb_vs_tt_grand_avg in place).
tb_vs_tt_grand_avg['comparing'] = 'text&behavior - text&text'
diff_wide = tb_vs_tt_grand_avg.pivot(
    index='comparing', columns='norm_cat', values='r2_diff_grand_avg'
)

# Reuse the column order stored in norm_cat_order, then tidy the labels.
heat_df_2 = diff_wide[norm_cat_order]
heat_df_2.index = heat_df_2.index.str.replace('&', ' & ').str.title()
heat_df_2.columns = heat_df_2.columns.str.replace('_', ' ')
heat_df_2

In [None]:
def get_diffs(norm_cat, data=None):
    """Return the mean difference per (norm, fold) for one norm category.

    Rows whose 'norm_cat' equals ``norm_cat`` are selected and their 'value'
    column is averaged within each (norm, fold) pair. Folds are NOT averaged
    together: the result has one entry per (norm, fold).

    Parameters
    ----------
    norm_cat : str
        Category label, matched exactly against the 'norm_cat' column.
    data : pandas.DataFrame, optional
        Long table with at least 'norm', 'fold', 'norm_cat' and 'value'
        columns. Defaults to the module-level ``tb_vs_tt``.

    Returns
    -------
    pandas.Series
        Mean 'value', indexed by (norm, fold).
    """
    if data is None:
        data = tb_vs_tt

    # A boolean mask rather than an f-string query: interpolating the label
    # into a query string fails when the label itself contains a double quote.
    in_category = data[data['norm_cat'] == norm_cat]

    return in_category.groupby(['norm', 'fold']).mean(numeric_only=True)['value']


def wilcoxon_test(diffs):
    """Run scipy's `wilcoxon` on `diffs` and summarise the outcome.

    Returns a dict with the median of `diffs` rounded to two decimals
    ('median'), the number of values ('n'), and the two values returned by
    `wilcoxon` ('w' and 'p').
    """
    statistic, p_value = wilcoxon(diffs)
    summary = {}
    summary['median'] = round(diffs.median(), 2)
    summary['n'] = len(diffs)
    summary['w'] = statistic
    summary['p'] = p_value
    return summary

# For every norm category: is the wilcoxon p-value below 0.05?
sig_by_cat = {}
for norm_cat in heat_df_2.columns:
    diffs = get_diffs(norm_cat)
    p = wilcoxon_test(diffs)['p']
    sig_by_cat[norm_cat] = p < 0.05

# One-row frame with the same row label and columns as heat_df_2, so it can
# be indexed cell-by-cell alongside it.
heat_df_2_sigs = pd.DataFrame(
    sig_by_cat,
    index=['Text & Behavior - Text & Text'],
    columns=heat_df_2.columns,
)
heat_df_2_sigs

In [None]:
# Function to create a lighter version of a colormap
def lighten_cmap(cmap_name, factor=0.3):
    """Return a lighter copy of a named Matplotlib colormap.

    The colormap is sampled at 256 evenly spaced points and every RGBA sample
    is linearly blended towards opaque white.

    Parameters
    ----------
    cmap_name : str
        Name of a Matplotlib colormap, e.g. 'viridis'.
    factor : float, default 0.3
        Blend weight given to white: 0 keeps the sampled colours unchanged,
        1 turns every sample into white.

    Returns
    -------
    LinearSegmentedColormap
        A colormap named ``f'light_{cmap_name}'`` built from the blended
        samples.
    """
    # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
    # pyplot.get_cmap takes the same (name, lut) arguments and still exists.
    cmap = plt.get_cmap(cmap_name, 256)
    colors = cmap(np.linspace(0, 1, 256))

    # Blend each colour with white
    white = np.array([1, 1, 1, 1])  # RGBA for white
    new_colors = (1 - factor) * colors + factor * white

    return LinearSegmentedColormap.from_list(f'light_{cmap_name}', new_colors)

# Function to visualize a colormap
def plot_colormap(cmap):
    """Show `cmap` as a colour strip running from 0 to 1, without axes."""
    ramp = np.linspace(0, 1, 256)
    # Two identical rows, because imshow needs a 2-D array.
    strip = np.tile(ramp, (2, 1))

    plt.imshow(strip, aspect='auto', cmap=cmap)
    plt.axis('off')
    plt.show()

# Viridis blended towards white with factor 0.6
lighter_viridis = lighten_cmap(cmap_name='viridis', factor=0.6)

# Preview the result
plot_colormap(cmap=lighter_viridis)

In [None]:
def annotate(heat_df, ax, mask):
    """Write the value of every cell of `heat_df` at that cell's centre on `ax`.

    Text is placed at (column position + 0.5, row position + 0.5). NaN cells
    get an empty label, magnitudes above 1e3 use scientific notation with one
    decimal, everything else is shown with two decimals. Cells for which
    `mask` (looked up by the same row/column labels) is truthy are drawn
    larger and bold.
    """
    for col_pos, col_label in enumerate(heat_df.columns):
        for row_pos, row_label in enumerate(heat_df.index):
            value = heat_df.loc[row_label, col_label]

            # Pick the text for this cell
            if np.isnan(value):
                label = ''
            elif abs(value) > 1e3:
                label = f'{value:.1e}'
            else:
                label = f'{value:.2f}'

            # Highlighted cells are larger and bold
            highlighted = mask.loc[row_label, col_label]
            ax.text(
                col_pos + .5, row_pos + .5, label,
                fontsize=16 if highlighted else 13,
                fontweight='bold' if highlighted else 'normal',
                ha='center', va='center', color='black'
            )


# Two stacked heatmaps; each panel's height is proportional to its row count.
heat_dfs = [heat_df_1, heat_df_2]
fig, axs = plt.subplots(2, figsize=(18, 6), height_ratios=[len(df) for df in heat_dfs])

# Plotting grand avg
vmax = heat_df_1.max().max()
sns.heatmap(
    heat_df_1, vmin=0, cmap=lighter_viridis, 
    vmax=vmax, annot=False, fmt='', cbar=False,
    ax=axs[0]
)

# Plotting text & behavior - text & text
# The colours show the magnitude of each difference, so the top of the colour
# scale must come from the absolute values as well. Taking it from the signed
# values would saturate (or invert) the scale whenever the largest-magnitude
# difference is negative.
heat_df_2_abs = heat_df_2.abs()
vmax = heat_df_2_abs.max().max()
sns.heatmap(
    heat_df_2_abs, cmap=lighter_viridis,
    vmin=0, vmax=vmax, annot=False, fmt='', cbar=False,
    ax=axs[1]
)

for ax in axs:
    ax.set(xlabel='', ylabel='')
    ax.set_yticklabels(ax.get_yticklabels(), fontsize=13)
    ax.set_xticklabels(ax.get_xticklabels(), fontsize=13)
    
    # rotates y-tick labels to horizontal
    plt.setp(ax.get_yticklabels(), rotation=0)

# Remove x-tick labels for all but last plot
axs[0].set_xticklabels([])
# Title-case the labels but keep the standalone word "of" lowercase. The word
# boundaries stop words that merely start with "Of" from being altered.
x_tick_labels = heat_df_2.columns.str.title().str.replace(r'\bOf\b', 'of', regex=True)
axs[1].set_xticklabels(x_tick_labels, rotation=90, ha='right')

# Annotates cells (the signed values are written, even though the colours of
# the second panel show magnitudes)
annotate(heat_df_1, axs[0], heat_df_1_winners)
annotate(heat_df_2, axs[1], heat_df_2_sigs)

# Sets axis titles
axs[0].set_title('Average Test $R^2$', fontsize=20)
  
fig.tight_layout()
plt.savefig('../../figures/rca_ensemb.png', dpi=300, bbox_inches='tight')

## Descriptive Stats

In [None]:
# Per-category differences, smallest first, rounded to two decimals
diff_row = heat_df_2.loc['Text & Behavior - Text & Text']
diff_row.sort_values().round(2)

In [None]:
# Checking that ensembling always improves performance: each printed flag is
# True if the ensemble scores below the single embed group in any category.
tt_below_text = (heat_df_1.loc['Text & Text'] - heat_df_1.loc['Text'] < 0).any()
tb_below_behavior = (heat_df_1.loc['Text & Behavior'] - heat_df_1.loc['Behavior'] < 0).any()
print(f"Text & Text - Text: {tt_below_text}")
print(f"Text & Behavior - Behavior: {tb_below_behavior}")

In [None]:
# Number of norm categories with a positive text & behavior - text & text difference
n_positive = (heat_df_2.loc['Text & Behavior - Text & Text'] > 0).sum()
print(f"# where Text & Behavior > Text & Text: {n_positive}")