# Model Plots



In [1]:
# Suppress the powerlaw package warnings
# "powerlaw.py:700: RuntimeWarning: divide by zero encountered in true_divide"
# "powerlaw.py:700: RuntimeWarning: invalid value encountered in true_divide"
# NOTE: the filter below is not restricted to powerlaw; it silences every
# RuntimeWarning raised in this kernel from this point on.
import warnings
warnings.simplefilter(action='ignore', category=RuntimeWarning)

# 0.4x names
# Metric names. The plotting functions in this notebook use them both as keys
# into the per-model summary dicts and as column names of the per-layer
# details DataFrames.

ALPHA_HAT = 'alpha_weighted'
ALPHA = 'alpha'
LOG_ALPHA_NORM = 'log_alpha_norm'
MP_SOFTRANK = 'mp_softrank'
SPECTRAL_NORM = 'spectral_norm'
STABLE_RANK = 'stable_rank'  
FROBENIUS_NORM = 'norm'

LOG_NORM = 'log_norm'
LOG_SPECTRAL_NORM = 'log_spectral_norm'
LOG_STABLE_RANK = 'log_stable_rank' 
# Alias: the log Frobenius norm is stored under the same name as LOG_NORM.
LOG_FROBENIUS_NORM = LOG_NORM


# Maps alternative column names onto the names defined above.
# NOTE(review): presumably these keys are the column names of an older
# release; nothing in this notebook reads the mapping -- confirm it is needed.
WW4X_COMPAT_COLUMNS = {'lognorm':LOG_NORM,
                'logspectralnorm':LOG_SPECTRAL_NORM,
                'spectralnorm':SPECTRAL_NORM,
                'logpnorm':LOG_ALPHA_NORM,
                'softrank':STABLE_RANK,
                'softranklog':LOG_STABLE_RANK,
                'softrank_mp':MP_SOFTRANK}


# Math-text axis labels, one per metric; passed as xlabel/ylabel to the
# plotting functions.
LOG_NORM_EQN =  r"$\langle\log\Vert W\Vert^{2}_{F}\rangle$"
ALPHA_EQN = r"$\langle\alpha\rangle$"
ALPHA_HAT_EQN = r"$\hat{\alpha}$"
LOG_SPECTRAL_NORM_EQN = r"$\langle\log\;\Vert\mathbf{W}\Vert^{2}_{\infty}\rangle$"
LOG_STABLE_RANK_EQN = r"$\langle\log\;\mathcal{R}_{s}\rangle$"
MP_SOFTRANK_EQN = r"$\langle\log\;\mathcal{R}_{mp}\rangle$"
LOG_ALPHA_NORM_EQN = r"$\langle\log\;\Vert\mathbf{X}\Vert^{\alpha}_{\alpha}\rangle$"
    


In [2]:
import numpy as np
import pandas as pd
import scipy as sp

from sklearn.linear_model import LinearRegression
from sklearn import metrics
from scipy.stats import kendalltau
import itertools

    
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
# Notebook-wide matplotlib defaults: font size 20 and a 10 x 10 figure size.
plt.rcParams.update({'font.size': 20})
from pylab import rcParams
rcParams['figure.figsize'] = 10,10

# Marker shapes; the scatter plots cycle through them, one per model.
MARKERS = ['o', '<', '>', 'v', '^', 's', 'D', 'X']
# Value passed as the scatter `s` argument (the depth plot divides it by 10).
MARKER_SIZE = 250
# Default opacity constant (1.0).
TRANSPARENCY = 1.0

In [4]:
def plot_test_accuracy(metric, xlabel, title, series_name, \
                       all_names, all_summaries, top_errors, ylabel='Top1 Test Accuracy'):
    """Scatter a summary metric against test accuracy and overlay a linear fit.

    One point is drawn per model: x is ``all_summaries[i][metric]`` and y is
    ``100.0 - top_errors[name]``.  A LinearRegression is fitted to the points
    and drawn in red; its RMSE and R2, together with the Kendall tau of the
    points, are appended to the plot title.

    Parameters
    ----------
    metric : key looked up in every summary.
    xlabel : x-axis label.
    title : text appended to " Test Accuracy vs " in the plot title.
    series_name : used in the output file name.
    all_names : model names; also the legend labels and keys of top_errors.
    all_summaries : summaries, positionally parallel to all_names.
    top_errors : mapping from model name to test error.
    ylabel : y-axis label.

    The figure is written to ``img/{series_name}_{metric}_accs.png`` and then
    shown.
    """
    marker_cycle = itertools.cycle(MARKERS)

    n_models = len(all_names)
    metric_values = np.empty(n_models)
    accuracies = np.empty(n_models)

    # One scatter call per model so that each gets its own marker and legend entry.
    for idx, modelname in enumerate(all_names):
        metric_value = all_summaries[idx][metric]
        metric_values[idx] = metric_value

        accuracy = 100.0 - top_errors[modelname]
        accuracies[idx] = accuracy

        plt.scatter(metric_value, accuracy, label=modelname,
                    s=MARKER_SIZE, marker=next(marker_cycle))

    # Column vectors, as expected by the regression and the sklearn metrics.
    x_col = metric_values.reshape(-1, 1)
    y_col = accuracies.reshape(-1, 1)

    fit = LinearRegression()
    fit.fit(x_col, y_col)
    y_fit = fit.predict(x_col)
    plt.plot(x_col, y_fit, color='red', linewidth=1)

    rmse = np.sqrt(metrics.mean_squared_error(y_col, y_fit))
    r2 = metrics.r2_score(y_col, y_fit)
    ktau, _pvalue = kendalltau(x_col, y_col)

    stats_line = r"RMSE: {:0.2} R2: {:0.2} $\tau$: {:0.2}".format(rmse, r2, ktau)

    # Legend is anchored outside the axes, to the right (moved for ResNet-1K).
    plt.legend(bbox_to_anchor=(1.04,1), loc="upper left")
    plt.title(" Test Accuracy vs "+title+"\n"+stats_line)
    plt.ylabel(ylabel)
    plt.xlabel(xlabel)

    out_path = "img/{}_{}_accs.png".format(series_name, metric)
    print("saving {}".format(out_path))
    plt.savefig(out_path, bbox_inches="tight")
    plt.show()

In [5]:
def plot_metrics_histogram(metric, xlabel, title, series_name, \
    all_names, all_details, colors, log=False, valid_ids=None):
    """Overlay per-model histograms of one layer metric and save the figure.

    Parameters
    ----------
    metric : column of each details DataFrame to histogram.
    xlabel : x-axis label; also appended to the plot title.
    title : prefix of the plot title and of the printed median lines.
    series_name : used in the output file name.
    all_names : model names, positionally parallel to all_details.
    all_details : per-model DataFrames holding the metric column.
    colors : histogram colors, indexed by model position.
    log : if True, histogram log10(value + 1e-6) instead of the raw value.
    valid_ids : positions in all_details to plot.  None or empty selects
        every model and tags the file name with 'all'; otherwise the tag
        is 'fnl'.

    The figure is written to
    ``img/{series_name}_{idname}_{metric}_hist.png`` and then shown.
    """
    # Lowest opacity a histogram may get; matplotlib rejects alpha outside 0-1.
    min_alpha = 0.1

    if valid_ids is None or len(valid_ids) == 0:
        # Every model.  (The previous range(0, len(all_details)-1) silently
        # dropped the last one.)
        valid_ids = range(len(all_details))
        idname = 'all'
    else:
        idname = 'fnl'

    transparency = 1.0
    for im, details in enumerate(all_details):
        if im not in valid_ids:
            continue

        vals = details[metric].to_numpy()
        if log:
            # Builtin float: the np.float alias no longer exists in NumPy >= 1.24.
            vals = np.log10(np.array(vals+0.000001, dtype=float))

        mu = np.mean(vals)
        med = np.median(vals)

        label = r"{} $\mu=${:0.2f}".format(all_names[im], mu)
        plt.hist(vals, bins=100, label=label, alpha=transparency, color=colors[im], density=True)
        # Later histograms are drawn more transparent so earlier ones stay
        # visible; clamp so that long series do not reach a negative alpha.
        transparency = max(transparency - 0.15, min_alpha)
        print(r"{} {} {} median = {:0.3}".format(title, metric, all_names[im], med))

    fulltitle = "Histogram: "+title+" "+xlabel

    plt.legend()
    plt.title(fulltitle)
    plt.xlabel(xlabel)

    figname = "img/{}_{}_{}_hist.png".format(series_name, idname, metric)
    print("saving {}".format(figname))
    plt.savefig(figname)
    plt.show()

In [6]:
def plot_metrics_depth(metric, ylabel, title, series_name, \
    all_names, all_details, colors, log=False, valid_ids=None):
    """Scatter one layer metric against the details index and save the figure.

    Parameters
    ----------
    metric : column of each details DataFrame to plot on the y axis.
    ylabel : y-axis label; also appended to the plot title.
    title : text placed after "Depth vs " in the plot title.
    series_name : used in the output file name.
    all_names : model names (legend labels), parallel to all_details.
    all_details : per-model DataFrames; the DataFrame index is the x axis.
    colors : point colors, indexed by model position.
    log : if True, plot log10(value + 1e-6) instead of the raw value.
    valid_ids : positions in all_details to plot.  None or empty selects
        every model and tags the file name with 'all'; otherwise the tag
        is 'fnl'.

    The figure is written to
    ``img/{series_name}_{idname}_{metric}_depth.png`` and then shown.
    """
    markers = itertools.cycle(MARKERS)

    if valid_ids is None or len(valid_ids) == 0:
        # Every model.  (The previous range(len(all_details)-1) silently
        # dropped the last one.)
        valid_ids = range(len(all_details))
        idname = 'all'
    else:
        idname = 'fnl'

    for im, details in enumerate(all_details):
        if im not in valid_ids:
            continue

        name = all_names[im]
        x = details.index.to_numpy()
        y = details[metric].to_numpy()
        if log:
            # Builtin float: the np.float alias no longer exists in NumPy >= 1.24.
            y = np.log10(np.array(y+0.000001, dtype=float))

        # Markers advance only for plotted models, so skipped ones do not
        # consume a marker shape.
        plt.scatter(x, y, label=name, color=colors[im], s=MARKER_SIZE/10.0, marker=next(markers))

    plt.legend()
    plt.title("Depth vs "+title+" "+ylabel)
    plt.xlabel("Layer id")
    plt.ylabel(ylabel)

    figname = "img/{}_{}_{}_depth.png".format(series_name, idname, metric)
    print("saving {}".format(figname))
    plt.savefig(figname)
    plt.show()

## Metrics vs Test Accuracy

In [7]:
def plot_all_metrics_vs_test_accuracies( \
    series_name, all_names, colors, all_summaries, all_details, \
    top_errors):
    """Draw the metric-vs-test-accuracy plot for every summary metric.

    Calls plot_test_accuracy once per metric, in a fixed order, forwarding
    series_name, all_names, all_summaries and top_errors.  The colors and
    all_details arguments are accepted but not used here.
    """
    # (summary key, x-axis label, title fragment) -- plotted in this order.
    metric_specs = (
        (LOG_NORM, LOG_NORM_EQN, "Avg. log Frobenius Norm "),
        (ALPHA, ALPHA_EQN, "Avg. Alpha  "),
        (ALPHA_HAT, ALPHA_HAT_EQN, "Avg. Weighted Alpha"),
        (LOG_SPECTRAL_NORM, LOG_SPECTRAL_NORM_EQN, "Avg. log Spectral Norm"),
        (LOG_STABLE_RANK, LOG_STABLE_RANK_EQN, "Avg. log Stable Rank"),
        (MP_SOFTRANK, MP_SOFTRANK_EQN, "Avg. log MP Soft Rank"),
        (LOG_ALPHA_NORM, LOG_ALPHA_NORM_EQN, r"Avg. log $\alpha$-Norm"),
    )

    for metric, xlabel, title in metric_specs:
        plot_test_accuracy(metric, xlabel, title, series_name,
                           all_names, all_summaries, top_errors)

## Histogram of metrics for all layers

In [8]:
# Example: first_n_last_ids = [0, len(all_details)-1]  # positions of the first and last model

In [9]:
def plot_all_metric_histograms(\
    series_name, all_names, colors, all_summaries, all_details,  first_n_last_ids):
    """Draw the layer-metric histograms for every metric.

    For each metric, in a fixed order, plot_metrics_histogram is called
    twice: first with its default model selection, then restricted to
    first_n_last_ids.  series_name doubles as the plot title.  The
    all_summaries argument is accepted but not used here.
    """
    # (details column, x-axis label) -- plotted in this order.
    metric_specs = (
        (LOG_NORM, LOG_NORM_EQN),
        (ALPHA, ALPHA_EQN),
        (ALPHA_HAT, ALPHA_HAT_EQN),
        (LOG_STABLE_RANK, LOG_STABLE_RANK_EQN),
        (LOG_SPECTRAL_NORM, LOG_SPECTRAL_NORM_EQN),
        (MP_SOFTRANK, MP_SOFTRANK_EQN),
        (LOG_ALPHA_NORM, LOG_ALPHA_NORM_EQN),
    )

    title = series_name
    for metric, xlabel in metric_specs:
        # Default selection first ...
        plot_metrics_histogram(metric, xlabel, title, series_name,
                               all_names, all_details, colors)
        # ... then only the requested subset of models.
        plot_metrics_histogram(metric, xlabel, title, series_name,
                               all_names, all_details, colors,
                               valid_ids=first_n_last_ids)


## Metrics as a function of depth

In [10]:
def plot_all_metric_vs_depth(\
    series_name, all_names, colors, all_summaries, all_details, first_n_last_ids):
    """Draw the metric-vs-layer plots for every metric.

    For each metric, in a fixed order, plot_metrics_depth is called twice
    with log=False: first with an empty valid_ids list, then restricted to
    first_n_last_ids.  series_name doubles as the plot title.  The
    all_summaries argument is accepted but not used here.
    """
    # (details column, axis label) -- plotted in this order.
    metric_specs = (
        (LOG_NORM, LOG_NORM_EQN),
        (ALPHA, ALPHA_EQN),
        (ALPHA_HAT, ALPHA_HAT_EQN),
        (LOG_STABLE_RANK, LOG_STABLE_RANK_EQN),
        (LOG_SPECTRAL_NORM, LOG_SPECTRAL_NORM_EQN),
        (MP_SOFTRANK, MP_SOFTRANK_EQN),
        (LOG_ALPHA_NORM, LOG_ALPHA_NORM_EQN),
    )

    title = series_name
    for metric, axis_label in metric_specs:
        # A fresh empty list per call, then the requested subset of models.
        for selection in ([], first_n_last_ids):
            plot_metrics_depth(metric, axis_label, title, series_name,
                               all_names, all_details, colors,
                               log=False, valid_ids=selection)


In [11]:
def save_ww2x_data(all_names,all_details, all_summaries, datadir='data'):
    """Write every model's details and summary into ``datadir``.

    Details are written first, one ``{datadir}/{name}.csv`` per model via
    ``to_csv``; afterwards each summary is written as its ``str()`` form to
    ``{datadir}/{name}.txt``.  Names are paired with details and summaries
    by position.
    """
    # Pass 1: per-layer details.
    for model_name, details in zip(all_names, all_details):
        details.to_csv("{}/{}.csv".format(datadir, model_name))

    # Pass 2: summaries, stored as plain text.
    for model_name, summary in zip(all_names, all_summaries):
        summary_path = "{}/{}.txt".format(datadir, model_name)
        with open(summary_path, 'w') as out:
            out.write(str(summary))
        

In [12]:
def read_ww2x_data(all_names, datadir='data'):
    """Load the details DataFrames and summaries stored under ``datadir``.

    Parameters
    ----------
    all_names : model names; each must have ``{datadir}/{name}.csv`` and
        ``{datadir}/{name}.txt``.
    datadir : directory holding the files.

    Returns
    -------
    (all_details, all_summaries) : two lists in the order of all_names.
        Details come from ``pd.read_csv``; summaries are the evaluated
        contents of the .txt files.

    Raises whatever ``pd.read_csv``, ``open`` or ``eval`` raise for missing
    or malformed files.
    """
    all_details, all_summaries = [], []

    for name in all_names:
        filename = "{}/{}.csv".format(datadir,name)
        print("loading details in", filename)
        # NOTE(review): save_ww2x_data writes the DataFrame index as the first
        # CSV column and no index_col is given here, so that column comes back
        # as ordinary data -- confirm this is intended.
        details = pd.read_csv(filename)
        all_details.append(details)

    for name in all_names:
        filename = "{}/{}.txt".format(datadir,name)
        print("loading summary in ", filename)
        # Context manager so the handle is closed deterministically.
        with open(filename, 'r') as f:
            text = f.read()
        # SECURITY: eval executes arbitrary Python found in the file.  Only
        # load summaries this notebook wrote itself; never point datadir at
        # untrusted files.
        summary = eval(text)
        all_summaries.append(summary)

    return all_details,all_summaries