<a href="https://www.kaggle.com/code/a1795757/visualization-current?scriptVersionId=145653237" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

To use the demo, press "Run all" and scroll to the bottom to find the interactive widgets.

In [1]:
import pickle
import numpy as np
import os
import pandas as pd
import sklearn.metrics
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib
import ipywidgets as widgets
from ipywidgets import interact
import functools
# Newline used inside f-string replacement fields (e.g. ``nl.join(...)`` in plot titles):
# before Python 3.12 a backslash may not appear inside an f-string expression.
nl='\n'



In [2]:
import inspect
def pass_kwargs(function,**kwargs):
    """
    Call ``function`` with only the keyword arguments it declares.

    Keys of ``kwargs`` that are not parameter names in
    ``inspect.signature(function)`` are silently dropped, so one pool of
    values can be offered to many functions with different signatures.
    Note: the names of ``*args``/``**kwargs`` parameters are matched too.

    Example::

        g=lambda x:x+1
        h=lambda y:y+3
        def f(**kwargs):
            return pass_kwargs(g,**kwargs)/pass_kwargs(h,**kwargs)
        f(x=1,y=2)  # -> 2/5
    """
    accepted=inspect.signature(function).parameters.keys()&kwargs.keys()
    return function(**{k:kwargs[k] for k in accepted})

def apply_metric(df,metric):
    """
    Evaluate ``metric`` once per row of ``df``.

    For each row, the entries under the ``'results'`` column group are offered
    to ``metric`` as keyword arguments; only those whose names match one of
    ``metric``'s parameters are forwarded (see pass_kwargs).
    Returns whatever ``df.apply(..., axis=1)`` produces (a Series when
    ``metric`` returns scalars).
    """
    def _score_row(row):
        row_results=row['results'].to_dict()
        return pass_kwargs(metric,**row_results)
    return df.apply(_score_row,axis=1)
def add_metric_col(df,metric):
    """
    Store ``apply_metric(df, metric)`` in ``df`` in place.

    The new column is ``('metrics', metric.__name__)``, so the metric function's
    name becomes the column label. Returns None.
    """
    target_column=('metrics',metric.__name__)
    df[target_column]=apply_metric(df,metric)
def setattrkwargs(o:object,**kwargs):
    """
    Set several attributes on ``o`` at once, one per keyword argument.

    Convenience alternative to a run of assignments, e.g.::

        x.foo=1
        x.bar=2

    becomes::

        setattrkwargs(x,
            foo=1,
            bar=2,
        )

    Attributes are assigned in keyword order. Returns None.
    """
    # A plain loop: the original list comprehension was used only for its side effects.
    for name,value in kwargs.items():
        setattr(o,name,value)

In [3]:
#Model Evaluation Metrics
def accuracy(truth,pred):
    """Fraction of ``pred`` entries equal to ``truth`` (scikit-learn ``accuracy_score``)."""
    score=sklearn.metrics.accuracy_score(y_true=truth,y_pred=pred)
    return score
def balanced_accuracy(truth,pred):
    """Balanced accuracy of ``pred`` against ``truth`` (scikit-learn ``balanced_accuracy_score``)."""
    score=sklearn.metrics.balanced_accuracy_score(y_true=truth,y_pred=pred)
    return score
def loss_curve_area_scaled(train_loss_track,epoch=2):
    """
    Average training loss over one epoch of a recorded loss track.

    ``train_loss_track`` is wrapped in a pandas Series and indexed with ``epoch``.
    Presumably it maps (epoch, batch) keys to losses, so this selects every batch
    of that epoch — TODO confirm against the training code.
    Returns the sum of the selected losses divided by their count. NaN entries
    count in the denominator but are skipped by the sum.
    """
    epoch_losses=pd.Series(train_loss_track)[epoch]
    total=epoch_losses.sum()
    return total/len(epoch_losses)
def loss_curve_slope_abs_scaled(train_loss_track,epoch=2):
    """
    Total absolute batch-to-batch change in loss within one epoch, scaled by
    the number of batches in that epoch.

    ``train_loss_track`` is wrapped in a pandas Series and indexed with ``epoch``.
    Presumably it maps (epoch, batch) keys to losses — TODO confirm.
    Larger values indicate a noisier or steeper loss curve in that epoch.
    """
    epoch_losses=pd.Series(train_loss_track)[epoch]
    # .diff() leaves NaN in the first position; drop it by *position*. A bare
    # ``[1:]`` on an integer-indexed Series is ambiguous between positional and
    # label slicing in pandas, so use .iloc to state the intent.
    steps=epoch_losses.diff().iloc[1:]
    return steps.abs().sum()/len(epoch_losses)

In [4]:

def get_idx_from_values(df,values):
    """
    Build a full-length ``.loc`` key for ``df``'s index from per-level values.

    ``values`` maps index-level names to values. The result is a tuple with one
    slot per level of ``df.index``, in level order. Slots whose level is absent
    from ``values``, or mapped to None, are left as None.
    Raises ValueError if a non-None value names a level ``df.index`` lacks.
    """
    level_names=list(df.index.names)
    chosen={level_names.index(level):val for level,val in values.items() if val is not None}
    return tuple(chosen.get(pos) for pos in range(len(level_names)))
#Visualization Functions
def disp_slice_map(name,x,y,metric,**values):
    """
    Draw a heatmap of ``metric`` over index levels ``x`` (columns) and ``y`` (rows)
    for dataframe ``df_dict[name]``.

    The other index levels are fixed to the values given in ``values`` (keyword per
    level name). Levels with no value stay None in the ``.loc`` key. Levels ``x`` and
    ``y`` span all their values. The colour scale uses the metric's min/max over the
    whole dataframe, so different slices share one scale.
    Returns the Axes from ``sns.heatmap``.
    """
    frame=df_dict[name]
    level_names=frame.index.names
    selector=list(get_idx_from_values(frame,values))
    # The two heatmap axes are left unrestricted.
    for free_level in (x,y):
        selector[level_names.index(free_level)]=slice(None)
    grid=frame['metrics'].loc[tuple(selector)].pivot_table(index=y,columns=x,values=metric)
    whole_metric=frame['metrics',metric]
    lo,hi=whole_metric.min(),whole_metric.max()
    ax=plt.axes()
    heatmap=sns.heatmap(grid,cmap='viridis',vmin=lo,vmax=hi,ax=ax)
    pinned=nl.join([f"{k}={v}" for k,v in values.items() if k not in (x,y)])
    ax.set_title(f'{name} with {metric=} at\n {pinned}')
    return heatmap
def disp_loss_curve(name,**values):
    """
    Plot the recorded training-loss track for one row of ``df_dict[name]``.

    ``values`` gives index-level values (keyword per level name) that select the row.
    That row's ``('results', 'train_loss_track')`` entry is wrapped in a pandas
    Series and line-plotted on a new Axes, which is returned.
    """
    frame=df_dict[name]
    row_key=get_idx_from_values(frame,values)
    ax=plt.axes()
    losses=pd.Series(frame.loc[row_key]['results','train_loss_track'])
    losses.plot(ax=ax)
    settings=nl.join(f"{k}={v}" for k,v in values.items())
    ax.set(xlabel='(epoch,batch)',
           ylabel='loss',
           title=f'{name}\'s loss curve at \n{settings}',
           )
    return ax
def metric_scatter(keys,xaxis_metric='loss_curve_area_scaled',yaxis_metric='accuracy',cmap=None,fitter=functools.partial(np.polynomial.polynomial.Polynomial.fit,deg=3)):
    """
    Scatter one metric against another for each selected dataset, with a fitted curve.

    Parameters
    ----------
    keys : iterable of keys into the global ``df_dict``; one colour per key.
    xaxis_metric, yaxis_metric : names of ``('metrics', ...)`` columns.
    cmap : colormap called with a float in [0, 1). None (default) means 'hsv'.
        It is resolved at call time because ``matplotlib.cm.get_cmap`` was removed
        in Matplotlib 3.9 and cannot be used as a definition-time default.
    fitter : callable ``(x, y) -> curve`` where ``curve`` is callable on an array.
        Defaults to a degree-3 least-squares polynomial fit.

    Returns the Axes.
    """
    if cmap is None:
        try:
            cmap=matplotlib.colormaps['hsv']
        except AttributeError:  # Matplotlib < 3.5 has no colormap registry
            cmap=matplotlib.cm.get_cmap('hsv')
    ax=plt.axes()
    ax.set(title=f'{xaxis_metric}  vs. {yaxis_metric} after 5 epochs\n poly deg 3 bestfit',
           xlabel=xaxis_metric,
           ylabel=yaxis_metric,
           )

    handles=[]
    for i,k in enumerate(keys):
        v=df_dict[k]
        color=cmap(i/len(df_dict))
        x=v['metrics',xaxis_metric]
        y=v['metrics',yaxis_metric]

        ax.scatter(x=x,y=y,color=color)
        # Fit only complete (x, y) pairs: a NaN metric would break the least-squares fit.
        valid=x.notna()&y.notna()
        if valid.any():
            xs,ys=x[valid],y[valid]
            curve=fitter(xs,ys)
            grid=np.linspace(xs.min(),xs.max(),300)
            # Draw the fit in the dataset's colour so it matches its points and legend entry.
            ax.plot(grid,curve(grid),color=color)
        handles.append(mpatches.Patch(color=color,label=k))
    ax.legend(handles=handles)
    return ax

In [5]:
def ui_helper(df:pd.DataFrame,params:widgets.VBox,x,y,metric):
    """
    Point the control widgets at dataframe ``df``.

    * Child ``i`` of ``params`` is configured for the ``i``-th
      ``('hyperparameter', <name>)`` column. Its options are that column's unique
      values, the first value is selected, and it is enabled. Every remaining child
      of ``params`` is disabled.
    * ``x`` and ``y`` get the index level names as options.
    * ``metric`` gets the ``('metrics', ...)`` column names as options.

    Raises IndexError if ``params`` has fewer children than ``df`` has
    hyperparameter columns.
    """
    params_ranges={i:df['hyperparameter',i].unique() for i in df['hyperparameter'].columns}
    for i,(name,r) in enumerate(params_ranges.items()):
        setattrkwargs(params.children[i],
            options=r,
            description=name,
            disabled=False,
            rows=len(r),
            value=r[0]
            )
    # Disable every slider not claimed by a hyperparameter above. This also
    # covers child 0 when there are no hyperparameters.
    for unused in params.children[len(params_ranges):]:
        setattrkwargs(unused,
            disabled=True)
    setattrkwargs(x,
        options=df.index.names,
        rows=len(df.index.names)
    )
    setattrkwargs(y,
        options=df.index.names,
        rows=len(df.index.names)
    )
    setattrkwargs(metric,
        options=tuple(df['metrics'].columns),
        rows=len(df['metrics'].columns)
    )
def init_ui(ui,name,SelectWidget_display,SelectWidget_params):
    """
    Build the control panel for dataframe ``df_dict[name]`` and install it in ``ui``.

    ``ui.children`` becomes ``[params, options]``:

    * ``params`` is a VBox of hyperparameter selectors built with
      ``SelectWidget_params``, one per hyperparameter of the widest dataframe in
      ``df_dict``.
    * ``options`` is a VBox of the heatmap x-axis, y-axis and metric selectors,
      built with ``SelectWidget_display``, in that order.

    All widgets are populated by ``ui_helper``. Selectors left without a value get
    their first option. The y axis defaults to a different level than x when one
    exists.
    """
    df=df_dict[name]
    x=SelectWidget_display(
        description='heatmap x axis',
        disabled=False,
    )
    y=SelectWidget_display(
        description='heatmap y axis',
        disabled=False,
    )
    metric=SelectWidget_display(
        description='metrics',
        disabled=False,  # ``disabled`` is the widget trait; ``disable`` was silently ignored
    )
    n_params=max(len(v['hyperparameter'].columns) for v in df_dict.values())
    # options=(0,1,2) is a placeholder; ui_helper replaces it on the sliders it uses.
    params=widgets.VBox([SelectWidget_params(disabled=True,options=(0,1,2)) for _ in range(n_params)])
    ui_helper(df,params,x,y,metric)
    options=widgets.VBox([x,y,metric])
    ui.children=[params,options]

    for k in x,y,metric:
        if k.value is None:
            k.value=k.options[0]
    # A level plotted against itself gives a degenerate heatmap, so start y on another level.
    if x.value==y.value and len(y.options)>1:
        y.value=y.options[1]
def update_ui(ui):
    """
    Return a callback ``refresh(name)`` that re-populates the controls inside ``ui``
    for dataframe ``df_dict[name]``.

    Expects the layout built by init_ui:
    ``ui.children[0]`` is the hyperparameter VBox, and ``ui.children[1]`` holds the
    x-axis, y-axis and metric selectors, in that order.
    """
    def refresh(name):
        frame=df_dict[name]
        param_box=ui.children[0]
        display_box=ui.children[1]
        x_sel=display_box.children[0]
        y_sel=display_box.children[1]
        metric_sel=display_box.children[2]
        ui_helper(frame,param_box,x_sel,y_sel,metric_sel)
    return refresh

In [6]:
folder='/kaggle/input/snapshot-safari-results/AdamVariants/AdamVariants'
PERFORMANCE_SUFFIX='-performance.pkl'
# Each dataset is stored as "<dataset>-performance.pkl". Only those files are used:
# any other file in the folder would otherwise yield a bogus or duplicate dataset name.
# The list is sorted so widget option order does not depend on os.listdir's
# arbitrary ordering.
datasets=sorted(i[:-len(PERFORMANCE_SUFFIX)] for i in os.listdir(folder) if i.endswith(PERFORMANCE_SUFFIX))
df_dict=dict()
for dataset in datasets:
    # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
    with open(f'{folder}/{dataset}{PERFORMANCE_SUFFIX}','rb') as f:
        df=pickle.load(f)
    # Replace callables stored in object-dtype index levels and hyperparameter
    # columns with their __name__, so they display as (and match against) strings.
    df.index=df.index.set_levels([(pd.Index([(j.__name__ if callable(j) else j) for j in x]) if x.dtype==np.dtype(object) else x) for x in (df.index.levels)])
    for hyperparameter in df['hyperparameter'].columns:
        if df['hyperparameter',hyperparameter].dtype == np.dtype(object):
            df['hyperparameter',hyperparameter]=df['hyperparameter',hyperparameter].apply(lambda j:j.__name__ if callable(j) else j )

    # Add metrics; each column is named after its metric function.
    add_metric_col(df,loss_curve_area_scaled)
    add_metric_col(df,accuracy)
    add_metric_col(df,loss_curve_slope_abs_scaled)
    df_dict[dataset]=df

In [7]:
# Single-choice list of the loaded datasets; it drives which dataframe the controls below show.
dataframe_select=widgets.Select(
    options=tuple(df_dict.keys()),
    description='select dataframe',
    disabled=False,
    rows=len(df_dict)
)
# Container that init_ui fills with [hyperparameter slider VBox, (x, y, metric) selector VBox].
ui=widgets.HBox()
# Hyperparameter controls are SelectionSliders with continuous_update=False, so values
# (and therefore plots) only update when a slider is released.
slider=functools.partial(widgets.SelectionSlider,options=[None],continuous_update=False)
init_ui(ui,name=tuple(df_dict.keys())[0],SelectWidget_display=widgets.Select,SelectWidget_params=slider)
# Re-populate the controls whenever another dataframe is selected. The returned Output
# widget is not displayed; the callback only mutates existing widgets.
widgets.interactive_output(update_ui(ui),dict(
    name=dataframe_select
))
# Each hyperparameter slider is passed to the plotting functions as a keyword argument
# named after the slider's description.
# NOTE(review): this keyword mapping is built once, here. ui_helper may later rename the
# sliders for another dataframe, but these keyword names stay fixed. Disabled/unused
# sliders are included as well. Confirm that all datasets share the same hyperparameters.
heatmap_out=widgets.interactive_output(disp_slice_map,dict(
   name=dataframe_select ,
   x=ui.children[1].children[0],
   y=ui.children[1].children[1],
   metric=ui.children[1].children[2],
   **{w.description:w for w in ui.children[0].children}
))
loss_curve_out=widgets.interactive_output(disp_loss_curve,dict(
    name=dataframe_select,
    **{w.description:w for w in ui.children[0].children})
)

display(dataframe_select,ui,widgets.HBox([heatmap_out,loss_curve_out]))

Select(description='select dataframe', options=('enonkishu-1224', 'camdeboo-1224', 'kruger-1224', 'karoo-1224'…

HBox(children=(VBox(children=(SelectionSlider(continuous_update=False, description='lr', options=(0.01, 0.02, …

HBox(children=(Output(), Output()))

In [8]:
# Stack all per-dataset frames. pd.concat on a dict adds the dict keys as a new
# outermost index level, which is renamed to 'dataset' here.
main_df=pd.concat(df_dict)
main_df.index.names=['dataset',*main_df.index.names[1:]]
# Also expose the dataset name as a column. With the two-level column labels used
# elsewhere this presumably lands under ('dataset', '') — verify if it is ever read.
main_df['dataset']=main_df.index.get_level_values(0)

# Lets the scatter plot below compare several datasets at once.
dataframe_multi_select=widgets.SelectMultiple(
    options=tuple(df_dict.keys()),
    description='multiselect dataframes',
    disabled=False,
    rows=len(df_dict)
)
def metric_selector_factory(Selector,df):
    """
    Build a ``Selector`` widget listing the metric columns of ``df``.

    ``Selector`` is a widget class or factory (e.g. ``widgets.Select``) called with
    keyword arguments only. The options are the names under ``df``'s ``'metrics'``
    column group, and ``rows`` is sized to show all of them.
    """
    metric_names=tuple(df['metrics'].columns)
    return Selector(
        description='metrics',
        disabled=False,  # ``disabled`` is the widget trait name (``disable`` was a typo)
        options=metric_names,
        rows=len(metric_names),
    )
# Independent metric pickers for the scatter plot's x and y axes; the options are the
# ('metrics', ...) columns of main_df.
xaxis_metric_select=metric_selector_factory(widgets.Select,main_df)
yaxis_metric_select=metric_selector_factory(widgets.Select,main_df)

# Re-run metric_scatter whenever the dataset selection or either metric picker changes.
metric_scatter_out=widgets.interactive_output(metric_scatter,dict(
    keys=dataframe_multi_select,
    xaxis_metric=xaxis_metric_select,
    yaxis_metric=yaxis_metric_select
)
)
display(widgets.HBox([dataframe_multi_select,xaxis_metric_select,yaxis_metric_select]),metric_scatter_out)

HBox(children=(SelectMultiple(description='multiselect dataframes', options=('enonkishu-1224', 'camdeboo-1224'…

Output()