In [None]:
import os
import pandas as pd
import json

import matplotlib.pyplot as plt
import plotly.graph_objects as go

In [None]:
# Root of the repository checkout; all data paths below are built from it.
repo_dir = '/share/u/can/OthelloUnderstanding/'

# Autoencoder sweep to analyse.
# Alternative sweep: 'mlp_transcoder_all_layers_panneal_0628'
ae_group_name = 'mlp_out_sweep_all_layers_panneal_0628'

# The evaluation table lives at <repo_dir>/autoencoders/<group>/evaluations.csv.
eval_path = os.path.join(
    repo_dir,
    'autoencoders',
    ae_group_name,
    'evaluations.csv',
)
eval_df = pd.read_csv(eval_path)

# Show which columns the evaluation table provides.
eval_df.columns

In [None]:
# Read the 'trainer' section of every trainer's config.json; each entry of
# eval_df['path'] is joined onto repo_dir to locate the file.
trainer_cfgs = []
for trainer_path in eval_df['path']:
    with open(os.path.join(repo_dir, trainer_path, 'config.json'), 'r') as cfg_file:
        trainer_cfgs.append(json.load(cfg_file)['trainer'])

# Pull out the two hyperparameters, keeping the same order as eval_df's rows.
sparsity_penalties = [trainer['sparsity_penalty'] for trainer in trainer_cfgs]
learning_rates = [trainer['lr'] for trainer in trainer_cfgs]

# Attach them to the evaluation table as new columns.
eval_df['cfg_sparsity_penalty'] = sparsity_penalties
eval_df['cfg_learning_rate'] = learning_rates

In [None]:
import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize

# Scatter plot of frac_recovered against L0, one panel per layer.
# Colour encodes the sparsity penalty and marker shape encodes the learning
# rate; both encodings are shared by every panel so panels are comparable.

plt.rcParams.update({'font.size': 14})  # Increase the default font size

n_rows, n_cols = 2, 4  # one panel per layer index 0..7
fig, axs = plt.subplots(n_rows, n_cols, figsize=(25, 12))
# The plotted quantity is the 'frac_recovered' column, so the title matches
# the y-axis label rather than mentioning variance explained.
fig.suptitle('Fraction of Loss Recovered vs L0 for Different Layers', fontsize=20)

# Create a colormap
cmap = plt.get_cmap('viridis')

# Get the global min and max of sparsity penalty for consistent color scaling
vmin = eval_df['cfg_sparsity_penalty'].min()
vmax = eval_df['cfg_sparsity_penalty'].max()
norm = Normalize(vmin=vmin, vmax=vmax)

# Assign one marker per learning rate across the whole table. Choosing the
# marker per layer would let the same learning rate change shape between
# panels and would make the legend depend on the last layer only.
marker_cycle = ['o', 's', '^', 'D', 'v', 'P', 'X', '*']
unique_lrs = sorted(eval_df['cfg_learning_rate'].unique())
marker_by_lr = {
    lr: marker_cycle[i % len(marker_cycle)] for i, lr in enumerate(unique_lrs)
}

for layer_idx in range(n_rows * n_cols):
    ax = axs[layer_idx // n_cols, layer_idx % n_cols]

    df_layer = eval_df[eval_df['layer_idx'] == layer_idx]

    for lr, marker in marker_by_lr.items():
        df_lr = df_layer[df_layer['cfg_learning_rate'] == lr]
        if df_lr.empty:
            # This learning rate was not run for this layer.
            continue

        ax.scatter(df_lr['l0'], df_lr['frac_recovered'],
                   c=df_lr['cfg_sparsity_penalty'],
                   cmap=cmap, norm=norm,
                   marker=marker)

    ax.set_ylabel('Fraction of Loss recovered', fontsize=16)
    ax.set_xlabel('L0', fontsize=16)
    ax.set_title(f'Layer {layer_idx}', fontsize=18)
    ax.tick_params(axis='both', which='major', labelsize=14)

# Adjust layout to make room for colorbar
plt.tight_layout(rect=[0, 0, 0.9, 0.95])

# Add a colorbar to the right of the subplots. It is built from the shared
# norm/cmap instead of whichever scatter happened to be drawn last, so it
# does not depend on the final layer containing data.
cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
cbar = fig.colorbar(ScalarMappable(norm=norm, cmap=cmap), cax=cbar_ax)
cbar.set_label('Sparsity Penalty', fontsize=16)
cbar.ax.tick_params(labelsize=14)

# Add a legend for learning rates: one grey proxy marker per learning rate,
# using the same marker mapping as the panels.
handles = [
    plt.Line2D([0], [0], marker=marker, color='w', markerfacecolor='gray',
               markersize=10, label=f'LR: {lr}')
    for lr, marker in marker_by_lr.items()
]
if handles:
    fig.legend(handles=handles, loc='upper right', bbox_to_anchor=(0.99, 0.99), fontsize=14)

plt.show()