In [1]:
# Load
%load_ext autoreload


In [2]:
# Load

%autoreload
import os, sys
import numpy as np
import torch
import torch.nn as nn

import matplotlib.pyplot as plt

ckconv_source = os.path.join(os.getcwd(), '../')

if ckconv_source not in sys.path:
    sys.path.append(ckconv_source)

from ckconv.nn import ScaleFlexConv, ScaleCKConv, FlexConv
import disco.ses_conv_learnable as SESN

import numpy as np
import torch
from torch.nn.utils import weight_norm
from omegaconf import OmegaConf

import ckconv.nn as cknn

import disco.ses_conv_learnable as SESN
import utils.loaders as loaders


from matplotlib import pyplot as plt
from PIL import Image


# Load all runs

import pandas as pd 
import wandb
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from IPython.display import display
api = wandb.Api()


  from .autonotebook import tqdm as notebook_tqdm


In [3]:



def prepare_api_runs(runs, values = ['test/acc'], rounding=True):
    combined = []
    for run in runs: 
        temp_info_dict = run.summary._json_dict
        # .summary contains the output keys/values for metrics like accuracy.
        #  We call ._json_dict to omit large files 
        # .config contains the hyperparameters.
        #  We remove special values that start with _.
        temp_info_dict.update(run.config)
        if isinstance(temp_info_dict['init_scales'], int) or isinstance(temp_info_dict['init_scales'], float):
            temp_info_dict['init_scales'] = [temp_info_dict['init_scales']]
        if rounding:
            if temp_info_dict['sample_scales'][0] != 'DISCRETE':
                temp_info_dict['sample_scales'] = str([round(sample_scale,3) for sample_scale in temp_info_dict['sample_scales'][1:]])
            else:
                temp_info_dict['sample_scales'] = str([round(sample_scale,3) for sample_scale in temp_info_dict['sample_scales'][1:]])
            temp_info_dict['init_scales'] = str([round(init_scale,3) for init_scale in temp_info_dict['init_scales']])

        else:
            temp_info_dict['sample_scales'] = str(temp_info_dict['sample_scales'])
            temp_info_dict['init_scales'] = str(temp_info_dict['init_scales'])
        if 'test/acc' in temp_info_dict.keys():
            temp_info_dict['Test Error'] = (1 - temp_info_dict['test/acc'])*100
            temp_info_dict['val/acc.max'] = temp_info_dict['val/acc']['max']
            combined.append(temp_info_dict)

    return combined

def plot_table(runs_in, name, indexes = ['sample_scales','init_scales'], values = ['Test Error'], rounding = True, Save=False):
    # Prepare and load into dataframe
    all_info = prepare_api_runs(runs_in, values)
    df = pd.DataFrame.from_dict(all_info)
    # display(df)

    df = df.pivot_table(index=indexes, values=values, aggfunc=(np.mean, np.std))
    # Save Df
    if rounding:
        df = df.astype(float).round(3)
    df[r'Learned $\sigma_{basis}$'] = df["Final Basis Min Scale"]["mean"].astype('str') + " ("  + df["Final Basis Min Scale"]["std"].astype('str') +")" 	
    df['Learned ISR'] = df["Final ISR"]["mean"].astype('str') + " ("  + df["Final ISR"]["std"].astype('str') +")" 

    display(df)
    if Save:
        df.to_csv(f'results/{name}')


# How does parameterization of learnable scales affect learnability of internal scales?

## Hypothesis
We think that the way our internal scales are learned is a big reason why they are not collapsing, since the spacing logarithmically stays the same. We will compare a decoupled version of our learning strategy and directly learning the internal scales. 

We expect that especially directly learning the Internal Scales quickly leads to collapse. Why? This was also shown in a recent paper but we think that this parameterization makes the internal scales less dependent on each other and does not take into account Scale-Equivariance as much. 

## Network/Data
Again, it remains unclear whether we want to use the full setup or want to compare the settings in a way more controlled setting like our toy network!



In [10]:
%autoreload
def visualize(runs, filter_scales = True):
    # Prepare and load into dataframe
    all_info = prepare_api_runs(runs, ['Test Error','Final ISR', 'Final Basis Min Scale'])
    df = pd.DataFrame.from_dict(all_info)
    # display(df)
    df[r'Init \sigma_{basis}'] = df['basis_min_scale']
    df['Init ISR'] = df['ISR_start']
    df['Data Range'] = df['sample_scales']
    for i in range(3):
        df[f'Scale {i+1}'] = df['Final Conv Scales'].apply(lambda x: x[i])
    df = df.pivot_table(index=['Data Range', 'learn_mode'] , values=['Test Error','Scale 1','Scale 2','Scale 3'], aggfunc=(np.mean, np.std))
    # Save Df
    df = df.astype(float).round(3)
    for i in range(3):
        df[f'Scale {i+1} - 1'] = df[f'Scale {i+1}']["mean"].astype('str') + r" $\pm$ "  + df[f'Scale {i+1}']["std"].astype('str')
    # df[r'Learned $\sigma_{basis}$'] = df["Final Basis Min Scale"]["mean"].astype('str') + r" $\pm$ "   + df["Final Basis Min Scale"]["std"].astype('str') 	
    # df['Learned ISR'] = df["Final ISR"]["mean"].astype('str') + r" $\pm$ "  + df["Final ISR"]["std"].astype('str')
    df['Test Error 1'] = df["Test Error"]["mean"].astype('str') + r" $\pm$ "  + df["Test Error"]["std"].astype('str')
    df.drop(columns=['Scale 1','Scale 2','Scale 3', 'Test Error'], inplace=True)
    for i in range(3):
        df[f'Scale {i+1}'] = df[f'Scale {i+1} - 1']
        df.drop(columns=[f'Scale {i+1} - 1'], inplace=True)
    df['Test Error'] = df['Test Error 1']
    df.drop(columns=['Test Error 1',], inplace=True)
    display(df)
    print(df.to_latex(escape=False))


api = wandb.Api()
exp_name = 'Compare_Parameterization_Methods'
# Load all runs

runs = api.runs(f"mbasting/scale_learning", {
        "$and": [{"tags": exp_name}]
})
visualize(runs, False)    



  df.drop(columns=[f'Scale {i+1} - 1'], inplace=True)
  df.drop(columns=['Test Error 1',], inplace=True)


Unnamed: 0_level_0,Unnamed: 1_level_0,Scale 1,Scale 2,Scale 3,Test Error
Data Range,learn_mode,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
"[1, 2.83]",2,1.965 $\pm$ 0.047,3.5 $\pm$ 0.193,6.235 $\pm$ 0.497,2.321 $\pm$ 0.095
"[1, 2.83]",4,1.967 $\pm$ 0.079,3.591 $\pm$ 0.329,6.93 $\pm$ 1.374,2.285 $\pm$ 0.038
"[1, 2.83]",6,1.96 $\pm$ 0.081,3.608 $\pm$ 0.435,6.672 $\pm$ 1.311,2.291 $\pm$ 0.067
"[1, 4.76]",2,1.865 $\pm$ 0.046,3.357 $\pm$ 0.105,6.45 $\pm$ 0.049,2.554 $\pm$ 0.093
"[1, 4.76]",4,1.996 $\pm$ 0.013,3.626 $\pm$ 0.158,6.83 $\pm$ 0.167,2.565 $\pm$ 0.061
"[1, 4.76]",6,2.001 $\pm$ 0.063,3.647 $\pm$ 0.127,6.647 $\pm$ 0.255,2.51 $\pm$ 0.084
"[1, 8]",2,1.689 $\pm$ 0.109,3.262 $\pm$ 0.107,6.997 $\pm$ 0.282,3.057 $\pm$ 0.015
"[1, 8]",4,1.902 $\pm$ 0.085,3.648 $\pm$ 0.165,8.093 $\pm$ 0.229,3.007 $\pm$ 0.049
"[1, 8]",6,1.943 $\pm$ 0.063,3.977 $\pm$ 0.053,8.145 $\pm$ 0.057,2.872 $\pm$ 0.07


\begin{tabular}{llllll}
\toprule
       &   &            Scale 1 &            Scale 2 &            Scale 3 &         Test Error \\
       &   \\
Data Range & learn_mode &                    &                    &                    &                    \\
\midrule
[1, 2.83] & 2 &  1.965 $\pm$ 0.047 &    3.5 $\pm$ 0.193 &  6.235 $\pm$ 0.497 &  2.321 $\pm$ 0.095 \\
       & 4 &  1.967 $\pm$ 0.079 &  3.591 $\pm$ 0.329 &   6.93 $\pm$ 1.374 &  2.285 $\pm$ 0.038 \\
       & 6 &   1.96 $\pm$ 0.081 &  3.608 $\pm$ 0.435 &  6.672 $\pm$ 1.311 &  2.291 $\pm$ 0.067 \\
[1, 4.76] & 2 &  1.865 $\pm$ 0.046 &  3.357 $\pm$ 0.105 &   6.45 $\pm$ 0.049 &  2.554 $\pm$ 0.093 \\
       & 4 &  1.996 $\pm$ 0.013 &  3.626 $\pm$ 0.158 &   6.83 $\pm$ 0.167 &  2.565 $\pm$ 0.061 \\
       & 6 &  2.001 $\pm$ 0.063 &  3.647 $\pm$ 0.127 &  6.647 $\pm$ 0.255 &   2.51 $\pm$ 0.084 \\
[1, 8] & 2 &  1.689 $\pm$ 0.109 &  3.262 $\pm$ 0.107 &  6.997 $\pm$ 0.282 &  3.057 $\pm$ 0.015 \\
       & 4 &  1.902 $\pm$ 0.085 &  3.648 $\

  print(df.to_latex(escape=False))
