In [28]:
# Load
%load_ext autoreload

from matplotlib.ticker import FuncFormatter
from scipy.stats import truncnorm, uniform
from enum import Enum
from IPython.display import display

import pandas as pd 
import wandb
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from IPython.display import display
api = wandb.Api()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# SOTA comparison between SESN and Scale Learning

In [35]:
import wandb
api = wandb.Api()

def create_table(runs):
    all_info = []
    ISR_epochs_list = None
    for run in runs: 
        temp_dict = {}

        # Only fill with useful info
        # dataset.extra_scaling
        # model.scale_learn_mode
        # learnable_basis_min
        # model.decoupled_basis_min
        # Scales
        # Test Error
        

        
        temp_dict['Learnable Basis Min'] = '\\cmark' if run.config['model']['learnable_basis_min'] else '\\xmark'
        temp_dict['Decoupled Basis Min'] = '\\cmark' if run.config['model']['decoupled_basis_min'] else '\\xmark'

        temp_dict['Augmentation'] = '\\cmark' if run.config['dataset']['extra_scaling'] == 0.5 else '\\xmark'
        if run.config['model']['scale_learn_mode'] == 6:
            temp_dict['Parameterisation'] = 'Learn ISR'
        elif run.config['model']['scale_learn_mode'] == 2:
            temp_dict['Parameterisation'] = 'Direct'
            temp_dict['Decoupled Basis Min'] = '\\xmark'
            temp_dict['Learnable Basis Min'] = '\\xmark'
        else:
            temp_dict['Parameterisation'] = 'Learn Spacing'
        if 'Final Conv Scales' not in run.config:
            continue
        Final_Conv_scales = run.config['Final Conv Scales']
        for i, scale in enumerate(Final_Conv_scales):
            temp_dict[f'Scale {i+1}'] = scale
        temp_dict['Test Error'] = run.summary['Test/Error']
        # temp_dict['Val Acc'] = run.history(keys=['Val/Accuracy']).max().values[1]
        all_info.append(temp_dict)

    data = pd.DataFrame(all_info)
    # Drop all tables except some

    # Only keep 
    # Group based on Augmentation, then Parameterisation and then Learnable Basis Min, Decoupled Basis Min
    # Then plot the Test Error
    # Then plot the scales
    keys_to_adapt = ['Test Error','Scale 1','Scale 2','Scale 3','Scale 4'] # ,'Val Acc',
    df = data.pivot_table(index=['Augmentation', 'Parameterisation', 'Learnable Basis Min'], values=keys_to_adapt, aggfunc=(np.mean, np.std))
    df = df.astype(float).round(3)
    for key in keys_to_adapt:
            df[key + '-1'] = df[key]["mean"].astype('str') + r" $\pm$ "   + df[key]["std"].astype('str') 
    df = df.drop(columns=keys_to_adapt)
    for key in keys_to_adapt:
        df[key] = df[key + '-1']
        df = df.drop(columns=[key + '-1'])

    display(df)
    print(df.to_latex(escape=False))

# Project is specified by <entity/project-name>
runs = api.runs("mbasting/scale_learning", {
                "$and": [{"tags" :  'MNIST_SCALE_SOTA_2'}
                         ,{'config.train.scale_lr' : 0.005},
                          {'$not' : {"tags": 'duplicate'}},
                          {'config.seed': {'$in' : [0,1,2,3,4]}}]}) 
create_table(runs)



  df = df.drop(columns=[key + '-1'])


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Test Error,Scale 1,Scale 2,Scale 3,Scale 4
Augmentation,Parameterisation,Learnable Basis Min,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
\cmark,Direct,\xmark,1.499 $\pm$ 0.082,1.36 $\pm$ 0.014,1.741 $\pm$ 0.062,2.485 $\pm$ 0.05,3.775 $\pm$ 0.083
\cmark,Learn ISR,\cmark,1.442 $\pm$ 0.086,1.381 $\pm$ 0.013,2.066 $\pm$ 0.012,3.092 $\pm$ 0.038,4.627 $\pm$ 0.095
\cmark,Learn ISR,\xmark,1.496 $\pm$ 0.081,1.5 $\pm$ 0.0,2.202 $\pm$ 0.048,3.235 $\pm$ 0.142,4.753 $\pm$ 0.317
\cmark,Learn Spacing,\cmark,1.501 $\pm$ 0.077,1.373 $\pm$ 0.013,1.859 $\pm$ 0.069,2.847 $\pm$ 0.091,4.646 $\pm$ 0.306
\cmark,Learn Spacing,\xmark,1.501 $\pm$ 0.115,1.5 $\pm$ 0.0,2.026 $\pm$ 0.082,3.017 $\pm$ 0.17,4.755 $\pm$ 0.292
\xmark,Direct,\xmark,1.735 $\pm$ 0.06,1.375 $\pm$ 0.008,1.889 $\pm$ 0.054,2.383 $\pm$ 0.044,3.297 $\pm$ 0.086
\xmark,Learn ISR,\cmark,1.718 $\pm$ 0.045,1.39 $\pm$ 0.016,1.89 $\pm$ 0.061,2.572 $\pm$ 0.164,3.503 $\pm$ 0.338
\xmark,Learn ISR,\xmark,1.727 $\pm$ 0.086,1.5 $\pm$ 0.0,2.007 $\pm$ 0.044,2.686 $\pm$ 0.117,3.596 $\pm$ 0.232
\xmark,Learn Spacing,\cmark,1.7 $\pm$ 0.098,1.39 $\pm$ 0.011,1.93 $\pm$ 0.099,2.614 $\pm$ 0.3,3.776 $\pm$ 0.6
\xmark,Learn Spacing,\xmark,1.71 $\pm$ 0.045,1.5 $\pm$ 0.0,1.969 $\pm$ 0.071,2.688 $\pm$ 0.194,3.66 $\pm$ 0.32


\begin{tabular}{llllllll}
\toprule
       &               &        &         Test Error &            Scale 1 &            Scale 2 &            Scale 3 &            Scale 4 \\
       &               &        \\
Augmentation & Parameterisation & Learnable Basis Min &                    &                    &                    &                    &                    \\
\midrule
\cmark & Direct & \xmark &  1.499 $\pm$ 0.082 &   1.36 $\pm$ 0.014 &  1.741 $\pm$ 0.062 &   2.485 $\pm$ 0.05 &  3.775 $\pm$ 0.083 \\
       & Learn ISR & \cmark &  1.442 $\pm$ 0.086 &  1.381 $\pm$ 0.013 &  2.066 $\pm$ 0.012 &  3.092 $\pm$ 0.038 &  4.627 $\pm$ 0.095 \\
       &               & \xmark &  1.496 $\pm$ 0.081 &      1.5 $\pm$ 0.0 &  2.202 $\pm$ 0.048 &  3.235 $\pm$ 0.142 &  4.753 $\pm$ 0.317 \\
       & Learn Spacing & \cmark &  1.501 $\pm$ 0.077 &  1.373 $\pm$ 0.013 &  1.859 $\pm$ 0.069 &  2.847 $\pm$ 0.091 &  4.646 $\pm$ 0.306 \\
       &               & \xmark &  1.501 $\pm$ 0.115 &      1.5 $\pm$ 

  print(df.to_latex(escape=False))
