In [None]:
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D


In [None]:
def sem(data, sd=2):
  """Return the standard error of the mean of `data`, scaled by `sd`.

  The standard error is the sample standard deviation (ddof=1) divided by
  the square root of the number of elements in `data`. The result is then
  multiplied by `sd` (default 2), so the returned value is a multiple of the
  standard error rather than the bare standard error.
  """
  n_obs = np.size(data)
  sample_std = np.std(data, ddof=1)
  return sd * (sample_std / np.sqrt(n_obs))

def get_statistic(results, stat, betas):
  """Collect the per-beta mean and scaled standard error of a summary column.

  For every value in `betas`, the rows of `results` whose 'beta' column is
  exactly equal to that value are selected, and column `stat` (e.g.
  test_acc_ens or test_ens_loss) is reduced to its mean and to `sem(...)`
  of its values.

  Returns:
    A pair of lists `(means, sems)`, each ordered like `betas`.
  """
  per_beta = []
  for level in betas:
    values = results.loc[results['beta'] == level, stat].values
    per_beta.append((values.mean(), sem(values)))

  means = [mean for mean, _ in per_beta]
  errors = [err for _, err in per_beta]
  return means, errors

In [None]:
# Load a results file produced by running the experiments.
# The first CSV column is used as the DataFrame index (index_col=0).
# NOTE(review): only `results_34_100` is loaded here, but the plotting cells
# below also read `results_18_10`, `results_34_10` and `results_18_100`;
# those frames must be loaded the same way before the plots can run — TODO add.

results_34_100 = pd.read_csv('CIFAR100_resnet34_results_loop_20230512_20 33 08.csv', index_col=0)


In [None]:
# This cell produces the plot of training dynamics (e.g. Fig 6 RHS).
# The statements below only define the colours and the line/font sizes used
# by that figure. No per-beta statistic is computed here, because the
# training-dynamics figure does not plot one.

betas = [0, 0.2, 0.4, 0.6, 0.8, 1.0]

# Each series has a dark shade and, for blue/green, a lighter shade taken
# from the same colormap.
green = plt.cm.Greens(0.2*(6/4) + 0.4)
greenlight = plt.cm.Greens(0.4)
blue = plt.cm.Blues(0.2*(6/4) + 0.4)
bluelight = plt.cm.Blues(0.4)
red = plt.cm.Reds(0.2*(6/4) + 0.4)


# Line/marker geometry and font sizes shared by the figure.
MARKERSIZE = 6
LINEWIDTH = 3
SUBTITLESIZE = 20
MEDIUM_SIZE = 18
BIGGER_SIZE = 25
LEGENDSIZE = 16

def add_subplot(ax, results, color, lwd, stl, title=None, label=None):
  """Draw one curve on `ax` and apply the shared panel styling.

  Args:
    ax: matplotlib Axes to draw on.
    results: sequence of y-values handed to `ax.plot` (x is the index).
    color: line colour.
    lwd: line width.
    stl: line style.
    title: panel title; when omitted the panel is titled 'ImageNet'.
    label: label attached to the plotted line.
  """
  # `title` used to be accepted but ignored. 'ImageNet' stays as the fallback
  # so calls that omit it render exactly as before.
  panel_title = 'ImageNet' if title is None else title
  ax.set_title(panel_title, fontsize=SUBTITLESIZE)
  ax.tick_params(axis='x', labelsize=MEDIUM_SIZE)
  ax.tick_params(axis='y', labelsize=MEDIUM_SIZE)

  ax.plot(results, color=color, linewidth=lwd, linestyle=stl, label=label)

# NOTE(review): `idx` is not referenced anywhere in this cell; the curves
# below are truncated with a literal 275 instead — confirm which is intended.
idx = 277

fig, ax = plt.subplots(1, 2, figsize=(8,3.5))
# Figure-level axis labels shared by both panels. They are placed at negative
# figure coordinates, i.e. outside the figure canvas.
fig.text(0.5, -0.05, r'Training iteration', ha='center', size=BIGGER_SIZE)
fig.text(-0.02, 0.5, 'Accuracy', va='center', rotation='vertical', size=BIGGER_SIZE)


# Left panel (blue): ensemble curve, then one dashed curve per column of
# val_acc_ind_0.
add_subplot(ax[0], val_acc_ens_0[:275], blue, LINEWIDTH, stl='-',label='Ensemble')
for i in range(val_acc_ind_0.shape[1]):
  add_subplot(ax[0], val_acc_ind_0[:275,i], bluelight, LINEWIDTH, stl='--', label=f'Base learner {i}')

# Right panel (green): same layout for the second run.
add_subplot(ax[1], val_acc_ens_1[:275], green, LINEWIDTH, stl='-',label= 'Ensemble')
for i in range(val_acc_ind_1.shape[1]):
  add_subplot(ax[1], val_acc_ind_1[:275,i], greenlight,  LINEWIDTH, stl='--', label=f'Base learner {i}')


# Proxy artists so a single legend can describe the series of both panels.
custom_lines = [Line2D([0], [0], color=blue, lw=4),
                Line2D([0], [0], color=bluelight, linestyle=(0.1,(2,1)), lw=4),
                Line2D([0], [0], color=green, lw=4),
                Line2D([0], [0], color=greenlight, linestyle=(0.1,(2,1)), lw=4)]

ax[0].legend(custom_lines, [r'Ensemble $\beta=0$', r'Base learners $\beta=0$',
                            r'Ensemble $\beta=1$', r'Base learners $\beta=1$'], fontsize=LEGENDSIZE)
# Same y-range on both panels so the two runs are directly comparable.
ax[0].set_ylim(0,0.67)
ax[1].set_ylim(0,0.67)

fig.tight_layout()
# bbox_inches='tight' expands the saved page to include the fig.text labels
# above; without it they fall outside the canvas and are cut from the PDF.
fig.savefig('imagenet_training.pdf', format='pdf', dpi=1200, bbox_inches='tight')
plt.show()

In [None]:
# This cell produces the plot of interpolating over beta (e.g. Fig 6 LHS).

# Beta values used as the x-coordinates of the curves.
betas = [0, 0.2, 0.4, 0.6, 0.8, 1.0]

# One colour per series, each sampled at the same position of its colormap.
green, blue, red = (
  cmap(0.2*(6/4) + 0.4) for cmap in (plt.cm.Greens, plt.cm.Blues, plt.cm.Reds)
)

# Marker/line geometry, then font sizes, for this figure.
MARKERSIZE, LINEWIDTH = 6, 2.5
SUBTITLESIZE, MEDIUM_SIZE, BIGGER_SIZE = 20, 18, 25

def add_subplot(ax, results, color, title=None):
  """Plot mean 'test_acc_ens' against beta on `ax`, with a shaded error band.

  The per-beta mean and error come from `get_statistic(results,
  'test_acc_ens', betas)`; the band spans mean - error to mean + error.
  A title is set only when `title` is given.
  """
  means, errs = get_statistic(results, 'test_acc_ens', betas)

  # Upper and lower edges of the shaded band around the mean curve.
  band_hi, band_lo = [], []
  for centre, half_width in zip(means, errs):
    band_hi.append(centre + half_width)
    band_lo.append(centre - half_width)

  if title is not None:
    ax.set_title(title, fontsize=SUBTITLESIZE)
  for axis_name in ('x', 'y'):
    ax.tick_params(axis=axis_name, labelsize=MEDIUM_SIZE)

  ax.plot(betas, means, color=color, linewidth=LINEWIDTH, marker='o', ms=MARKERSIZE)
  ax.fill_between(betas, band_hi, band_lo, color=color, alpha=0.25)


fig, ax = plt.subplots(1, 2, figsize=(8,3.5))
# Figure-level axis labels shared by both panels. They are placed at negative
# figure coordinates, i.e. outside the figure canvas.
fig.text(0.5, -0.05, r'Level of diversity $\beta$', ha='center', size=BIGGER_SIZE)
fig.text(-0.02, 0.5, 'Accuracy', va='center', rotation='vertical', size=BIGGER_SIZE)

# Add the subplots: left panel is titled 'CIFAR-10', right panel 'CIFAR-100';
# within each panel the blue curve is drawn first, then the green one.
# NOTE(review): each results_* frame must already be loaded before this runs.
add_subplot(ax[0], results_18_10, blue, 'CIFAR-10')
add_subplot(ax[0], results_34_10, green)
add_subplot(ax[1], results_18_100, blue, 'CIFAR-100')
add_subplot(ax[1], results_34_100, green)

# Proxy artists so the legend shows one entry per colour.
custom_lines = [Line2D([0], [0], color=blue, lw=4),
                Line2D([0], [0], color=green, lw=4)]

ax[0].legend(custom_lines, ['ResNet-18', 'ResNet-34'], fontsize=SUBTITLESIZE)

fig.tight_layout()
# bbox_inches='tight' expands the saved page to include the fig.text labels
# above; without it they fall outside the canvas and are cut from the PDF.
fig.savefig('interpolating_beta.pdf', format='pdf', dpi=1200, bbox_inches='tight')
plt.show()