In [None]:
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import cm
import seaborn as sns
# Apply seaborn's 'whitegrid' theme globally so every figure drawn below shares it.
sns.set(style='whitegrid')

# One inner list per output figure; each tuple becomes one subplot of that figure.
# Tuple layout: (results_dir, subplot_title, variant)
#   results_dir   -- directory containing out/<job_name>/log.csv for every run
#   subplot_title -- subplot title; the $\rightarrow$ arrows use Matplotlib mathtext
#   variant       -- job-name prefix value: 'mode<variant>_' for 'scion_full',
#                    'first_layer<variant>_' for other dirs, None for no prefix
PARAMETERIZATIONS = [[
    ('uscion', r'Scion (Sign $\rightarrow$ Spectral $\rightarrow$ Sign)', 'Sign'),
    ('uscion', r'Scion (Spectral $\rightarrow$ Spectral $\rightarrow$ Sign)', 'Spectral'),
    ('uscion', r'Scion (ColNorm $\rightarrow$ Spectral $\rightarrow$ Sign)', 'ColNorm'),
], [
    ('sp', r'SP (AdamW)', None),
    ('scion', r'Unconstrained Scion (Sign $\rightarrow$ Spectral $\rightarrow$ Sign)', 'Sign'),
    ('scion', r'Unconstrained Scion (Spectral $\rightarrow$ Spectral $\rightarrow$ Sign)', 'Spectral'),
    ('scion', r'Unconstrained Scion (ColNorm $\rightarrow$ Spectral $\rightarrow$ Sign)', 'ColNorm'),
], [
    ('scion_full', r'Signum', 'Sign-naive'),
    ('scion_full', r'Unconstrained Scion (Sign throughout)', 'Sign'),
], [
    ('scion_full', r'Unconstrained Scion (RowNorm throughout)', 'RowNorm'),
    ('scion_full', r'Unconstrained Scion (ColNorm throughout)', 'ColNorm'),
]]
# Sweep grid shared by every figure (hoisted: it never changes between figures).
seeds = [1, 2, 3]
widths = [
    256,
    512,
    1024,
    2048,
]
# Learning rates 2**-1 ... 2**-18. The smallest values are written with rounded
# decimals; presumably this matches how the run directories were named (job names
# embed f'{lr:.20f}'), so confirm against the directories on disk before replacing
# them with exact powers of two.
lrs = [
    0.5,
    0.25,
    0.125,
    0.0625,
    0.03125,
    0.015625,
    0.0078125,
    0.00390625,
    0.001953125,
    0.0009765625,
    0.00048828125,
    0.000244140625,
    0.0001220703125,
    0.00006103515625,
    0.00003051757812,
    0.00001525878906,
    0.000007629394531,
    0.000003814697266,
]


class MplColorHelper:
    """Map scalar values in [start_val, stop_val] onto colours of a Matplotlib colormap."""

    def __init__(self, cmap_name, start_val, stop_val):
        self.cmap_name = cmap_name
        self.cmap = plt.get_cmap(cmap_name)
        self.norm = mpl.colors.Normalize(vmin=start_val, vmax=stop_val)
        self.scalarMap = cm.ScalarMappable(norm=self.norm, cmap=self.cmap)

    def get_rgb(self, val):
        """Return the RGBA colour for ``val``."""
        return self.scalarMap.to_rgba(val)


# NOTE(review): not used by the plots below (they use the 'mako' palette); kept
# in case other notebook cells reference it.
color_helper = MplColorHelper('viridis', 0, len(widths) - 1)

# One colour per width, reversed so that the first (smallest) width takes the
# last palette entry.
width_colors = list(reversed(sns.color_palette("mako", n_colors=len(widths))))


def _job_name(parameterization, extra, width, seed, lr):
    """Rebuild the run-directory name of one (variant, width, seed, lr) job.

    ``lr`` is formatted with 20 decimals and trailing zeros are stripped,
    e.g. lr=0.5 gives a name ending in 'lr0.5'.
    """
    if parameterization == 'scion_full' and extra is not None:
        prefix = f'mode{extra}_'
    elif extra is not None:
        prefix = f'first_layer{extra}_'
    else:
        prefix = ''
    return f'{prefix}width{width}_depth2_seed{seed}_lr{lr:.20f}'.rstrip('0')


def _final_val_losses(parameterization, extra, width, lr, seeds):
    """Return the last logged 'val/loss' of every seed whose log.csv exists.

    Missing runs are skipped on purpose, so the result may be shorter than
    ``seeds`` (or empty).
    NOTE(review): takes the last row of the column; if 'val/loss' is only
    logged on some rows that value could be NaN -- confirm the log.csv layout.
    """
    losses = []
    for seed in seeds:
        job_name = _job_name(parameterization, extra, width, seed, lr)
        csv_path = os.path.join(parameterization, 'out', job_name, 'log.csv')
        if os.path.exists(csv_path):
            losses.append(pd.read_csv(csv_path)['val/loss'].values[-1])
    return losses


def _sem(values):
    """Standard error of the mean; 0.0 for a single sample (ddof=1 would give NaN)."""
    if len(values) < 2:
        return 0.0
    return np.std(values, ddof=1) / np.sqrt(len(values))


def _sweep_width(parameterization, extra, width, lrs, seeds):
    """Seed-averaged final validation loss for each learning rate that has data.

    Returns ``(lrs_to_plot, mean_losses, sem_losses)``: a list of the learning
    rates with at least one finished run, and two aligned NumPy arrays.
    """
    lrs_to_plot, means, sems = [], [], []
    for lr in lrs:
        losses = _final_val_losses(parameterization, extra, width, lr, seeds)
        if losses:
            lrs_to_plot.append(lr)
            means.append(np.mean(losses))
            sems.append(_sem(losses))
    return lrs_to_plot, np.array(means), np.array(sems)


def _best_lr(lrs_to_plot, mean_losses):
    """Return ``(lr, loss)`` at the lowest finite mean loss, or None if there is none.

    Non-finite means (e.g. diverged runs) are excluded: plain np.argmin returns
    the index of the first NaN, which would place the optimum on an invisible point.
    """
    finite_idx = np.flatnonzero(np.isfinite(mean_losses))
    if finite_idx.size == 0:
        return None
    best = finite_idx[np.argmin(mean_losses[finite_idx])]
    return lrs_to_plot[best], mean_losses[best]


# One figure per parameterization group: loss vs learning rate for every width,
# with the best learning rate of each width marked in red.
for j, parameterizations in enumerate(PARAMETERIZATIONS):
    n_cols = len(parameterizations)
    n_rows = 1
    # squeeze=False keeps `axes` indexable even for a single-panel figure
    # (plt.subplots would otherwise return a bare Axes). Panel spacing is
    # left to tight_layout() below.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(4 * n_cols, 3.33 * n_rows), squeeze=False
    )
    axes = axes[0]

    for ax, (parameterization, parameterization_str, extra) in zip(axes, parameterizations):
        optimal_lrs = []
        optimal_losses = []
        for width_idx, width in enumerate(widths):
            lrs_to_plot, mean_losses, sem_losses = _sweep_width(
                parameterization, extra, width, lrs, seeds
            )
            color = width_colors[width_idx]
            ax.plot(lrs_to_plot, mean_losses, label=width, color=color)
            ax.fill_between(
                lrs_to_plot,
                mean_losses - sem_losses,
                mean_losses + sem_losses,
                color=color,
                alpha=0.33,
            )

            best = _best_lr(lrs_to_plot, mean_losses)
            if best is not None:
                optimal_lrs.append(best[0])
                optimal_losses.append(best[1])

        ax.plot(optimal_lrs, optimal_losses, color='red', linestyle='none', marker='o')
        ax.set_xscale('log', base=2)
        ax.set_xlabel('Learning rate')
        ax.set_title(parameterization_str)
        # Figures with j >= 2 use a narrower y-range.
        if j >= 2:
            ax.set_ylim(1.75, 3.0)
        else:
            ax.set_ylim(1.3, 3.0)

    axes[0].legend(title='Width')
    axes[0].set_ylabel('Validation loss')
    # All panels of a figure share the same y-limits, so only the leftmost
    # keeps its y tick labels.
    for ax in axes[1:]:
        ax.yaxis.set_ticklabels([])
        ax.tick_params(axis='y', length=0, width=0)

    fig.tight_layout()
    fig.savefig(f"GPT_shakespeare_transfer_{j}.pdf")
    plt.show()
    plt.close(fig)
