In [None]:
import pandas as pd
import numpy as np
import statsmodels.api as sm
from statsmodels.sandbox.regression.predstd import wls_prediction_std
from matplotlib import pyplot as plt
import sys
sys.path.append('..')
import figure
import re

def first_word_before_nonalpha(s):
    """Extract the leading alphabetic "family" token from a string.

    The first choice is the earliest run of ASCII letters that begins at a
    regex word boundary and is directly followed by a non-letter character
    (for example ``'resnet50'`` gives ``'resnet'``). Because ``\\b`` treats
    digits and underscores as word characters, letters that come right after
    a digit or underscore do not start such a run.

    If no run qualifies, the letters at the very start of the string are
    returned instead (``'ResNet'`` gives ``'ResNet'``).

    Args:
        s: The string to inspect.

    Returns:
        The extracted letters, or ``None`` when neither rule finds any.
    """
    bounded = re.search(r'\b([A-Za-z]+)(?=[^A-Za-z])', s)
    if bounded is not None:
        return bounded.group(1)

    # Nothing is followed by a non-letter: fall back to the leading letters.
    leading = re.match(r'[A-Za-z]+', s)
    if leading is None:
        return None
    return leading.group(0)

In [None]:
# Load the image-classification table and derive, in one chained step:
#   group        - alphabetic prefix of the model name (see first_word_before_nonalpha)
#   KE_per_param - KQI divided by the num_params column
#   KE_per_FLOP  - KQI divided by the GFLOPS column
imagenet = pd.read_csv('data/ImageClassification.csv').assign(
    group=lambda df: df['model'].map(first_word_before_nonalpha),
    KE_per_param=lambda df: df['KQI'] / df['num_params'],
    KE_per_FLOP=lambda df: df['KQI'] / df['GFLOPS'],
)

# Last expression of the cell: render the resulting frame.
imagenet

In [None]:
# Log-log scatter of accuracy-per-parameter against KQI-per-parameter, with one
# colour and one legend entry per model family (the `group` column).
figure.initialize(width=3.5, height=3.5, left=True, bottom=True, left_tick=True, bottom_tick=True)

# Cycle through the combined palette so that a data set with more families than
# palette entries reuses colours instead of leaving some families without a
# colour (which would raise KeyError in the loop below).
palette = plt.cm.Dark2.colors + plt.cm.tab10.colors
color_map = {group: palette[i % len(palette)] for i, group in enumerate(imagenet['group'].unique())}

for group, data in imagenet.groupby('group'):
    plt.scatter(
        data['acc@1'] / data['num_params'],
        data['KE_per_param'],  # KQI / num_params, precomputed when the data is loaded
        color=color_map[group],
        label=f"{group} ({', '.join(data['model'])})",
    )

plt.legend()
plt.xscale('log')
plt.yscale('log')
plt.xlabel('Acc@1 / Params')
plt.ylabel('$KE_p$')

# Saved under its own name: the annotated figure produced later in the notebook
# is written to 'imagenet.svg' and would otherwise overwrite this one.
plt.savefig('imagenet_group.svg')

In [None]:
# Ordinary least-squares fit of KQI on top-1 accuracy (with an intercept), plus
# the fitted values and the prediction standard error / interval bounds.
# None of these are drawn by the active code below; they stay available as
# notebook-level variables.
t_acc1 = sm.add_constant(imagenet['acc@1'])
model = sm.OLS(imagenet['KQI'], t_acc1)
results = model.fit()
y_predict, (std, ci_l, ci_u) = results.predict(t_acc1), wls_prediction_std(results, t_acc1)


figure.initialize(width=3.5, height=3.5, left=True, bottom=True, left_tick=True, bottom_tick=True)

# Disabled: shaded prediction bands around the fitted line.
# x = np.linspace(55, 95, 100)
# y = results.predict(sm.add_constant(x))
# for alpha in [1e-20] + list(np.linspace(.05, 1, 20)):
#     _, ci_l, ci_u = wls_prediction_std(results, sm.add_constant(x), alpha=alpha)
#     plt.fill_between(x, y, ci_u, color='#FCE2CB', alpha=.2, linewidth=0)
#     plt.fill_between(x, ci_l, y, color='#E8F0C5', alpha=.2, linewidth=0)

# Cycle through the combined palette so that more families than palette entries
# reuse colours rather than raising KeyError on lookup.
palette = plt.cm.Dark2.colors + plt.cm.tab10.colors
color_map = {group: palette[i % len(palette)] for i, group in enumerate(imagenet['group'].unique())}

# One point and one text label per model, both normalised by parameter count.
columns = ['model', 'KQI', 'acc@1', 'num_params', 'group']
for name, kqi, acc1, param, group in imagenet[columns].itertuples(index=False, name=None):
    x, y = acc1 / param, kqi / param
    color = color_map[group]
    plt.scatter(x, y, color=color)
    # The label is anchored exactly on the point (xy == xytext).
    plt.annotate(
        name, xy=(x, y), xytext=(x, y),
        arrowprops=dict(arrowstyle='-', color=color, shrinkA=0, shrinkB=4, linewidth=.25),
        horizontalalignment='left', verticalalignment='center', color=color, fontsize=5,
    )

# Disabled: fixed limits/ticks; these appear to belong to an un-normalised
# accuracy-vs-KQI version of this plot (TODO confirm before re-enabling).
# plt.xlim(55, 95)
# plt.ylim(20, 55)
# plt.xticks([60, 70, 80, 90], ['60%', '70%', '80%', '90%'])
# plt.yticks([20, 30, 40, 50])
plt.xscale('log')
plt.yscale('log')
# The plotted values are acc@1 / num_params and KQI / num_params, so label the
# axes accordingly (same labels as the per-family scatter of the same data).
plt.xlabel('Acc@1 / Params')
plt.ylabel('$KE_p$')

plt.savefig('imagenet.svg')

In [None]:
# Top-1 accuracy against parameter count (log y-axis), one point per model.
figure.initialize(width=3, height=3, left=True, bottom=True, left_tick=True, bottom_tick=True)

# Select the two plotted columns by name. Unpacking whole rows into a fixed
# number of names breaks once derived columns (group, KE_per_param,
# KE_per_FLOP) have been added to `imagenet`, and it also depends on the
# column order of the CSV.
palette = plt.cm.Dark2.colors
points = imagenet[['acc@1', 'num_params']].itertuples(index=False, name=None)
for i, (acc1, param) in enumerate(points):
    # Colours repeat every len(palette) models; modulo indexing never runs out.
    plt.scatter(acc1, param, color=palette[i % len(palette)])
# Points are not labelled with model names here (the annotation code was disabled).

plt.yscale('log')
plt.xlim(55, 95)
plt.ylim(1e6, 1e9)
plt.xticks([60, 70, 80, 90], ['60%', '70%', '80%', '90%'])
plt.yticks([1e6, 1e7, 1e8, 1e9], ['1M', '10M', '100M', '1G'])
plt.xlabel('Accuracy@1')
plt.ylabel('#Parameter')

plt.savefig('imagenet_parameter.svg')

In [None]:
# Top-1 accuracy against the GFLOPS column (log y-axis), one point per model.
figure.initialize(width=3, height=3, left=True, bottom=True, left_tick=True, bottom_tick=True)

# Select the two plotted columns by name. Unpacking whole rows into a fixed
# number of names breaks once derived columns (group, KE_per_param,
# KE_per_FLOP) have been added to `imagenet`, and it also depends on the
# column order of the CSV.
palette = plt.cm.Dark2.colors
points = imagenet[['acc@1', 'GFLOPS']].itertuples(index=False, name=None)
for i, (acc1, gflop) in enumerate(points):
    # Colours repeat every len(palette) models; modulo indexing never runs out.
    plt.scatter(acc1, gflop, color=palette[i % len(palette)])
# Points are not labelled with model names here (the annotation code was disabled).

plt.yscale('log')
plt.xlim(55, 95)
plt.ylim(2e-2, 2e3)
plt.xticks([60, 70, 80, 90], ['60%', '70%', '80%', '90%'])
plt.xlabel('Accuracy@1')
plt.ylabel('GFLOPS')

plt.savefig('imagenet_gflops.svg')