In [None]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import matplotlib as mpl

# Default colormap for image-type plots in this notebook.
mpl.rcParams['image.cmap'] = 'Set1'

# When the kernel starts inside the 'analysis' folder, step up one level so
# the project-local imports and the relative './results' paths resolve.
if os.path.split(os.getcwd())[1] == 'analysis':
    os.chdir(os.pardir)
from models.custom_resnet import custom_resnet_50, custom_resnet_56

In [None]:
# Per-network measurements (type, flops, memory, batch timings, top1/top5)
# used by every ImageNet table and plot below.
path = './results/net_time_comparison.csv'
d = pd.read_csv(filepath_or_buffer=path)

In [None]:
# Reference values come from the first row whose type is 'original'.
original_flops, original_memory, original_batch_8, original_batch_32 = tuple(
    d.loc[d['type'] == 'original', ['flops', 'memory', 'batch 8 time', 'batch 32 time']]
    .to_numpy()[0]
    .tolist()
)

# Reductions are percentages relative to the original row; speedups are ratios.
d['flop_reduction'] = 100*(original_flops - d['flops'])/original_flops
d['theoretical_speedup'] = original_flops / d['flops']
d['memory_reduction'] = 100*(original_memory - d['memory'])/original_memory
d['speedup 8'] = original_batch_8 / d['batch 8 time']
d['speedup 32'] = original_batch_32 / d['batch 32 time']

In [None]:
# Display the full ImageNet table, including the derived reduction/speedup columns.
d

### Conversion to LaTeX table rows

In [None]:
relevant_columns = ['net', 'top1', 'top5', 'flop_reduction', 'memory_reduction', 'speedup 32']
# Work on an explicit copy: assigning formatted columns into a bare column
# subset of `d` triggers pandas' SettingWithCopyWarning.
d_table = d[relevant_columns].copy()
for col in ['top1', 'top5', 'flop_reduction', 'memory_reduction']:
    d_table[col] = d_table[col].map('{:.2f}'.format)
d_table['speedup 32'] = d_table['speedup 32'].map('{:.3f}'.format)
# Missing values were rendered as the string 'nan' by the formatting above.
d_table = d_table.replace('nan','--')

# NOTE(review): absolute, machine-specific output path — consider making it configurable.
result_path = '/home/eli/Eli/gator/table_text_results/imagenet_table.txt'
# One LaTeX table row per network: cells joined by ' & ', terminated by '\\'.
with open(result_path, 'w') as f:
    for _, r in d_table.iterrows():
        row = ' & '.join([str(a) for a in r.tolist()]) + ' \\\\'
        print(row)
        f.write(row + '\n')

In [None]:

alpha = [0.25,0.5,1,2]

# Row mask for the 'original' entry; it is included in most series below.
is_original = d.type == "original"

flops_long = d[(d.type=='flops') | is_original]
timing_long = d[(d.type=='b8') | is_original]
# The "short" variants drop the last row / last four rows (in CSV order).
flops_short = flops_long.iloc[:-1]
timing_short = timing_long.iloc[:-4]

dcp = d[(d.type=='dcp') | is_original]
geo = d[(d.type=='geo') | is_original]
# These two series deliberately exclude the 'original' row.
chan = d[d.type=='chan']
pcas = d[d.type=='pcas']

# Series, legend labels and matplotlib format strings are parallel lists.
# (A six-series variant that also held 'GATOR latency' used to be assigned
# here and was immediately overwritten; that dead assignment is removed.)
prune_short = [flops_short, dcp, geo, chan, pcas]
prune_short_names = ['GATOR FLOPs', 'DCP', 'Geometric median', 'Channel Pruning', 'PCAS']
prune_short_line_formats = ['-r','--g','-b','-c','--m']

prune_long = [timing_long]
prune_long_names =  ['GATOR latency']
prune_long_line_formats = ['-r']

# Rows for the compact reference networks, drawn as scatter points.
mobilenet = d[d.net=='Mobilenet V2']
squezenet0 = d[d.net=='squeezenet 1_0']
squezenet1 = d[d.net=='squeezenet 1_1']

small_nets = [mobilenet, squezenet0, squezenet1]
small_nets_names= ['MobileNet V2', 'SqueezeNet 1_0', 'SqueezeNet 1_1']

# small_nets = [mobilenet]
# small_nets_names= ['MobileNet V2'] 


In [None]:
def plot_comparison(x_name, y_name, prune_series, prune_names, prune_line_formats, net_points=None, net_names=None):
    """Plot one line per pruning series, plus optional scatter points, on one figure.

    Args:
        x_name: Column used for the x axis (e.g. 'flop_reduction', 'speedup 32').
        y_name: Column used for the y axis (e.g. 'top1', 'top5').
        prune_series: Sequence of tables indexable by column name; one line each.
        prune_names: Legend label for each entry of ``prune_series``.
        prune_line_formats: Matplotlib format string for each entry of ``prune_series``.
        net_points: Optional sequence of tables drawn as scatter points.
        net_names: Legend label for each entry of ``net_points``; required
            whenever ``net_points`` is given.
    """
    plt.figure(figsize=(10,5))
    for data, name, line_format in zip(prune_series, prune_names, prune_line_formats):
        plt.plot(data[x_name], data[y_name], line_format, label=name)
    if net_points is not None:
        # Plot the caller-supplied points. The previous version indexed the
        # notebook-global `small_nets` here, silently ignoring `net_points`.
        for points, points_name in zip(net_points, net_names):
            plt.scatter(points[x_name], points[y_name], label=points_name)
    # Turn column names into readable axis labels.
    x_label = x_name.replace('_', ' ').replace('reduction', 'reduction %').replace('speedup 32', 'speedup multiplier')
    y_label = y_name.replace('_', ' ') if 'top' not in y_name else y_name + ' accuracy %'
    plt.xlabel(x_label, fontsize=14)
    plt.ylabel(y_label, fontsize=14)
    plt.legend(loc='upper right', fontsize=12)
    plt.show()

In [None]:
# Migrated from the old boolean third argument, which no longer matches
# plot_comparison's signature; mirrors the top5 call for the same x axis.
plot_comparison('flop_reduction', 'top1', prune_short, prune_short_names, prune_short_line_formats)

In [None]:
# Top-5 accuracy against FLOP reduction for the short comparison set.
plot_comparison(
    x_name='flop_reduction',
    y_name='top5',
    prune_series=prune_short,
    prune_names=prune_short_names,
    prune_line_formats=prune_short_line_formats,
)

In [None]:
# Migrated from the old boolean third argument, which no longer matches
# plot_comparison's signature; uses the short comparison set.
plot_comparison('theoretical_speedup', 'top1', prune_short, prune_short_names, prune_short_line_formats)

In [None]:
# Migrated from the old boolean third argument, which no longer matches
# plot_comparison's signature; uses the short comparison set.
plot_comparison('theoretical_speedup', 'top5', prune_short, prune_short_names, prune_short_line_formats)

In [None]:
# NOTE(review): stale call — plot_comparison requires prune_series, prune_names and
# prune_line_formats, so passing a single bool raises TypeError. TODO: confirm which
# series set was intended and migrate to the list-based form.
plot_comparison('speedup 8', 'top1', True)

In [None]:
# NOTE(review): stale call — plot_comparison requires prune_series, prune_names and
# prune_line_formats, so passing a single bool raises TypeError. TODO: confirm which
# series set was intended and migrate to the list-based form.
plot_comparison('speedup 8', 'top5', True)

In [None]:
# NOTE(review): stale call — plot_comparison requires prune_series, prune_names and
# prune_line_formats, so passing a single bool raises TypeError. TODO: confirm which
# series set was intended and migrate to the list-based form.
plot_comparison('speedup 32', 'top1', True)

In [None]:
# Replace the leading GATOR FLOPs curve with the latency one and keep only
# the first four series (and their labels) for the measured-speedup plot.
prune_short = [timing_short] + prune_short[1:4]
prune_short_names = ['Gator latency'] + prune_short_names[1:4]
plot_comparison(
    x_name='speedup 32',
    y_name='top5',
    prune_series=prune_short,
    prune_names=prune_short_names,
    prune_line_formats=prune_short_line_formats,
)

In [None]:
# Full GATOR latency curve, with the small reference networks as scatter points.
plot_comparison(
    x_name='speedup 32',
    y_name='top5',
    prune_series=prune_long,
    prune_names=prune_long_names,
    prune_line_formats=prune_long_line_formats,
    net_points=small_nets,
    net_names=small_nets_names,
)

In [None]:
# NOTE(review): stale call — plot_comparison requires prune_series, prune_names and
# prune_line_formats, so passing a single bool raises TypeError. The same axes are
# already plotted with the list-based form in earlier cells; TODO confirm this cell
# can be dropped.
plot_comparison('speedup 32', 'top5', True)

In [None]:
# custom resnet 56
import os
import torch
if os.path.basename(os.getcwd()) == 'analysis':
    os.chdir('..')
    
from models.cifar_resnet import resnet56
from models.custom_resnet import custom_resnet_56

# NOTE(review): absolute, machine-specific checkpoint path — consider making it configurable.
weight_path = '/home/eli/Eli/Training/Cifar10/ResNet56/resnet56_w_16/net_e_240'
# map_location='cpu' lets a checkpoint saved from a GPU run load on a machine
# without CUDA; only the stored configuration is needed here.
full_dict = torch.load(weight_path, map_location='cpu')
# The weights are extracted but not loaded into `net` in this cell.
state_dict = full_dict['state_dict']
channels_config = full_dict['channels_config']
net = custom_resnet_56(channels_config, 10)
# Prints the first two entries of the result — presumably FLOPs and memory; TODO confirm.
res = net.compute_flops_memory(True)
print('{}, {}'.format(res[0],res[1]))

# print(resnet56(10).compute_flops_memory(True))

In [None]:
# CIFAR results table; the first row typed 'original' supplies the reference values.
path = './results/cifar_results_new.csv'
d2 = pd.read_csv(filepath_or_buffer=path)
original_flops2, original_memory2 = tuple(
    d2.loc[d2['type'] == 'original', ['flops', 'memory']].to_numpy()[0].tolist()
)

# Reductions are percentages relative to the original row; the speedup is a ratio.
d2['flop_reduction'] = 100*(original_flops2 - d2['flops'])/original_flops2
d2['memory_reduction'] = 100*(original_memory2 - d2['memory'])/original_memory2
d2['theoretical_speedup'] = original_flops2 / d2['flops']

In [None]:
# Display the full CIFAR table, including the derived reduction/speedup columns.
d2

In [None]:
relevant_columns2 = ['net', 'accuracy', 'flop_reduction', 'memory_reduction']
# Work on an explicit copy: assigning formatted columns into a bare column
# subset of `d2` triggers pandas' SettingWithCopyWarning.
d_table2 = d2[relevant_columns2].copy()
for col in ['accuracy', 'flop_reduction', 'memory_reduction']:
    d_table2[col] = d_table2[col].map('{:.2f}'.format)
# Missing values were rendered as the string 'nan' by the formatting above.
d_table2 = d_table2.replace('nan','--')

# NOTE(review): absolute, machine-specific output path — consider making it configurable.
result_path2 = '/home/eli/Eli/gator/table_text_results/cifar_table.txt'
# One LaTeX table row per network: cells joined by ' & ', terminated by '\\'.
with open(result_path2, 'w') as f:
    for _, r in d_table2.iterrows():
        row = ' & '.join([str(a) for a in r.tolist()]) + ' \\\\'
        print(row)
        f.write(row + '\n')

In [None]:
# d2 = d2[~d2.net.str.contains('baseline')]

# Row mask for the 'original' entry; it is included in the Gator series below.
is_original2 = d2.type == "original"

flops2 = d2[(d2.type=='flops') | is_original2]
# Memory-targeted rows whose name contains neither '16' nor '8', plus the
# original row. Parentheses make the existing `&`-before-`|` precedence explicit.
memory2 = d2[((d2.type=='memory') & ~d2.net.str.contains('16') & ~d2.net.str.contains('8')) | is_original2]
direct2 = d2[(d2.type=='direct') | (d2.net=="Gator flops 0.25") | is_original2]

dcp1  = d2[(d2.net=='dcp baseline') | (d2.net == 'DCP')]
dcp2  = d2[(d2.net=='dcp baseline') | (d2.net == 'DCP Adapt')]
# Suffix 2 keeps these from overwriting the ImageNet frames named `pcas` / `geo`.
pcas2 = d2[d2.type=='pcas']
cp2 = d2[d2.type=='cp']
# Was 'eeo', which looks like a typo for 'geo' (the type used for the
# geometric series in the ImageNet table) — TODO confirm against the CSV.
geo2 = d2[d2.type=='geo']

# Series and legend labels are parallel lists consumed by plot_comparison2.
prunes2 = [flops2, memory2, direct2, dcp1, dcp2, pcas2, cp2, geo2]
prune_names2 = ['Gator FLOPs', 'Gator memory', 'Gator FLOPs direct', 'DCP', 'DCP Adapt', 'PCAS', 'Channel pruning', 'Geometric']

# prunes2 = [flops2, memory2, direct2]
# prune_names2 = ['Gator FLOPs', 'Gator memory', 'Gator FLOPs direct']

In [None]:
def plot_comparison2(x_name, y_name, prune_series=None, prune_names=None):
    """Plot every series against the chosen columns on one figure.

    Args:
        x_name: Column used for the x axis (e.g. 'flop_reduction').
        y_name: Column used for the y axis (e.g. 'accuracy').
        prune_series: Optional sequence of DataFrames to draw; defaults to the
            notebook-level ``prunes2``.
        prune_names: Optional legend labels matching ``prune_series``; defaults
            to the notebook-level ``prune_names2``.

    A series with no non-NaN x value is skipped, one with a single non-NaN
    x value is drawn as a scatter point, and the rest are drawn as lines.
    """
    if prune_series is None:
        prune_series = prunes2
    if prune_names is None:
        prune_names = prune_names2
    plt.figure(figsize=(10,5))
    for data, name in zip(prune_series, prune_names):
        # Decide on the number of plottable x values, not the raw row count:
        # a line through a single valid point would be invisible.
        n_points = len(data[x_name].dropna())
        if n_points == 0:
            continue
        if n_points > 1:
            plt.plot(data[x_name], data[y_name], label=name)
        else:
            plt.scatter(data[x_name], data[y_name], label=name)
    plt.xlabel(x_name.replace('_', ' ').replace('reduction', 'reduction %'), fontsize=14)
    plt.ylabel(y_name.replace('_', ' ').replace('accuracy', 'accuracy %'), fontsize=14)
    plt.legend(loc='best', fontsize=14)
    plt.show()

In [None]:
# CIFAR accuracy against FLOP reduction.
plot_comparison2(x_name='flop_reduction', y_name='accuracy')

In [None]:
# Same plot as the preceding cell (accuracy against FLOP reduction).
plot_comparison2(x_name='flop_reduction', y_name='accuracy')

In [None]:
# CIFAR accuracy against the FLOP-ratio (theoretical) speedup.
plot_comparison2(x_name='theoretical_speedup', y_name='accuracy')

In [None]:
# CIFAR accuracy against memory reduction.
plot_comparison2(x_name='memory_reduction', y_name='accuracy')

In [None]:
# Same plot as the preceding cell (accuracy against memory reduction).
plot_comparison2(x_name='memory_reduction', y_name='accuracy')