In [1]:
import os
import sys
sys.path.append('../..')
from shrinkbench.plot import df_from_results, plot_df, reset_plt
from IPython.display import display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import json
import pathlib

# automatically reload modules
%load_ext autoreload
%autoreload 2

In [2]:
def jsonfile(file):
    """Parse the JSON document stored at ``file`` and return the resulting object.

    Parameters
    ----------
    file : path accepted by ``open``
        Location of the JSON file; it is opened in text read mode.

    Returns
    -------
    The decoded JSON value (whatever ``json`` produces for the file's content).
    """
    with open(file, mode='r') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

In [None]:
# plot results
# --- Run selection: edit these to point at the experiment to plot ---------
results_dir = "../scripts/results"
# dataset = "MNIST"
# model =  "LeNet" # "LeNet", "vgg11_bn_small_mnist"
dataset = "CIFAR10"
model = "resnet56" # "vgg11_bn_small",  "resnet56"
onelayer = False
job_id = "5864587"
if not onelayer:
    results_path = f"{results_dir}/{dataset}-{model}-{job_id}"
else:
    results_path = f"{results_dir}/onelayer-pruning/{dataset}-{model}-{job_id}"

# --- Load the results and keep only rows for the selected model/dataset ---
df = df_from_results(results_path, structured=True, icml=True)
df = df[(df['model'] == model) & (df['dataset'] == dataset)]
lay_names = sorted(df['prune_layers'].unique())
print(lay_names)
cf_key = 'fraction'
print("# of rows in df: ", len(df))

# create folder for figs and save config file in it
fig_path = pathlib.Path(f"./figs/{dataset}-{model}-{job_id}")
fig_path.mkdir(parents=True, exist_ok=True)
config_file = results_path+'/config.json'
# Fallback used whenever the run's config does not list its fractions.
default_fractions = [i * 0.1 for i in range(0, 11)]
if os.path.isfile(config_file):
    config = jsonfile(config_file)
    fractions = config.get('fractions', default_fractions)
    print(fractions)
    structure = config.get('structure', 'neuron')
    print(structure)
    # Keep a copy of the run config next to the figures; an existing copy is
    # never overwritten.
    config_file = fig_path / 'config.json'
    if not config_file.is_file():
        # Context manager so the handle is flushed and closed deterministically.
        with open(config_file, 'w') as config_out:
            json.dump(config, config_out, indent=4)
else:
    # No config available: define both names so later cells never hit a
    # NameError depending on which branch ran.
    fractions = default_fractions
    structure = 'channel' # neuron

In [18]:
# Discard the rows whose 'fraction' equals 0.01; the tick list below likewise
# leaves 0.01 out.
df = df.loc[df['fraction'].ne(0.01)]
# Tick positions used by format_plot for the 'fraction' x-axis.
fraction_ticks = [0.05, 0.1, 0.25, 0.5, 1.0]  # 0.01
def format_plot(plt, xaxis, yaxis):
    """Set labels, axis scales, ticks and font sizes on the current axes.

    Besides its arguments, this reads the notebook-level names ``structure``,
    ``onelayer``, ``fraction_ticks`` and ``df``.

    Parameters
    ----------
    plt : module
        ``matplotlib.pyplot`` (the pyplot state machine is used throughout).
    xaxis : str
        Name of the column shown on the x-axis. ``'fraction'`` gets a labelled
        axis with the fixed ``fraction_ticks`` (log scale) or a linear 0-1
        axis when ``onelayer`` is set. Any other value is looked up as a
        column of ``df`` and drawn on a log axis with integer ticks.
    yaxis : str
        Name of the quantity on the y-axis; only ``'acc1'`` and ``'acc5'``
        receive a y-label here.
    """
    if yaxis in ["acc1", "acc5"]:
        plt.ylabel('Top-1 accuracy' if yaxis == "acc1" else "Top-5 accuracy")

    if xaxis == 'fraction':
        unit = "neurons/channels" if structure is None else structure
        plt.xlabel(f'Fraction of prunable {unit} kept')
        if not onelayer:
            plt.xscale('log')
            plt.xticks(fraction_ticks)
            # Plain decimal labels instead of the log formatter's output.
            plt.gca().set_xticklabels([str(tick) for tick in fraction_ticks])
            plt.ylim(0, 1)
        else:
            plt.xscale('linear')
            plt.xlim(0, 1)
            plt.ylim(0, 1)
    else:
        plt.xscale('log')
        # Largest per-fraction mean of the plotted column. Only that column is
        # aggregated: averaging the whole frame raises on pandas >= 2.0 when
        # it contains string columns such as 'strategy'.
        x_max = int(df.groupby('fraction')[xaxis].mean().max())
        x_min = 1
        ticks = np.round(np.geomspace(x_min, x_max, len(fraction_ticks), endpoint=True)).astype(int)
        # Only the right limit is set: a left limit of 0 is invalid on a log
        # axis (matplotlib warns and ignores it), so the left edge is left as is.
        plt.xlim(right=x_max)
        plt.xticks(ticks)
        plt.gca().set_xticklabels([str(tick) for tick in ticks])

    # NOTE(review): rc changes are made after the labels/ticks above were
    # created, so they presumably only affect text created later (e.g. the
    # legend and subsequent figures) -- confirm this ordering is intended.
    plt.rc('font', size=60)
    plt.rc('legend', fontsize=35)
    plt.tight_layout()

In [None]:
reset_plt()
save_fig = True

# Legend position for each strategy label (smaller values come first). Labels
# missing from this table are sorted with position 10.
order = {
    'Random': 11,
    'ActGrad': 4,
    'WeightNorm': 6,
    'LayerRandom': 12,
    'SeqInChange': 2,
    'AsymInChange': 1,
    'LayerActGrad': 5,
    'LayerInChange': 3,
    'LayerWeightNorm': 7,
    'LayerSampling': 8,
    'LayerGreedy': 9,
    'LayerGreedy-fd': 10,
}
# choose which acc after fine tuning to plot (max or last)
post_acc = 'post_'  # 'last_'
# `include` is either the string 'all' or a list of strategy names to keep;
# strategies listed in `exclude` are always dropped.
include = 'all'
exclude = ['NotIncluded', 'WeightNorm']

# Row mask over df. Strings are compared with ==/!= (not `is`), which does not
# depend on interning and avoids the SyntaxWarning on Python >= 3.8.
not_excluded = ~df['strategy'].isin(exclude)
idx = not_excluded if include == 'all' else df['strategy'].isin(include) & not_excluded

# Keyword arguments shared by every accuracy curve.
acc_style = dict(markers='strategy', groupby_col=cf_key, fig=False, colors='strategy',
                 markersize=25, linewidth=6)

for name in (lay_names if onelayer else ['all']):
    df_idx = df[idx] if name == 'all' else df[idx & (df['prune_layers'] == name)]
    # Reweighted rows of our own strategies; drawn faintly in the
    # non-reweighted panel for reference.
    ours = df_idx['strategy'].isin(['AsymInChange', 'SeqInChange', 'LayerInChange'])
    df_ours_rw = df_idx[ours & (df_idx['reweight'] == True)]
    for yaxis in ['acc1', 'pruning_time']:  # 'acc1', 'acc5', 'pruning_time', 'finetuning_time'
        xaxes = [cf_key, 'real_compression', 'speedup'] if yaxis == 'acc1' else [cf_key]
        for xaxis in xaxes:
            for ft in ([False, True] if name == 'all' else [False]):
                # One figure per (yaxis, xaxis, ft): left panel with
                # reweighting, right panel without.
                fig, ax = plt.subplots(1, 2, figsize=(40, 20), sharey=False)
                for i, reweight in enumerate([True, False]):
                    plt.subplot(1, 2, i + 1)
                    sub_df = df_idx[df_idx['reweight'] == reweight]
                    if yaxis in ["acc1", "acc5"]:
                        # Column holding the accuracy before / after fine-tuning.
                        ycol = (post_acc if ft else 'pre_') + yaxis
                        main_style = dict(acc_style, alpha=1)
                        if not ft:
                            # Only the pre-fine-tuning curves pass an explicit line style.
                            main_style['line'] = '-'
                        plot_df(sub_df, xaxis, ycol, **main_style)
                        if not reweight:  # include our reweighted methods with non-reweighted methods for reference
                            plot_df(df_ours_rw, xaxis, ycol, line='-', alpha=0.25,
                                    label='_nolegend_', **acc_style)
                        format_plot(plt, xaxis, yaxis)
                        if ft:
                            # Zoom the y-axis to the range of post-fine-tuning
                            # accuracies found in the whole frame.
                            plt.ylim(df[post_acc + yaxis].min(), 1)
                    else:
                        # `alpha` was previously an undefined name here
                        # (NameError on a fresh kernel); use full opacity like
                        # the main accuracy curves.
                        plot_df(sub_df, xaxis, yaxis, markers='strategy', fig=False, colors='strategy',
                                alpha=1, markersize=10)
                        format_plot(plt, xaxis, yaxis)
                        plt.ylim(df_idx[yaxis].min(), df_idx[yaxis].max())

                # The legend is attached to the current axes, i.e. the last
                # panel drawn, with entries ordered by `order`.
                ax = plt.gca()
                handles, labels = ax.get_legend_handles_labels()
                if handles:  # zip(*sorted(...)) cannot unpack an empty legend
                    labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: order.get(t[0], 10)))
                    ax.legend(handles, labels)
                if save_fig:
                    plt.savefig(fig_path / f'{yaxis}-{xaxis}{"-finetuned" if ft else ""}-{job_id}-{name}.png')

plt.show()