In [None]:
import os
from pathlib import Path
import wandb
from tqdm.auto import tqdm
import pandas as pd
import numpy as np

def get_summary_metrics(sweep_id, config_keys=None, filter_func=None, timeout=60):
  """Collect one row per run of a W&B sweep: run id, selected config values, summary metrics.

  Args:
    sweep_id: Sweep path passed to ``wandb.Api(...).sweep``.
    config_keys: Keys copied from each run's ``run.config`` into its row. Every key
      must be present in every run's config (a missing key raises ``KeyError``).
    filter_func: Optional ``filter_func(run, cfg) -> bool``; a run is skipped when it
      returns a falsy value. ``cfg`` is the dict of selected config values.
    timeout: Timeout passed to ``wandb.Api``.

  Returns:
    ``(sweep, df)``: the sweep object and a DataFrame whose columns are ``run_id``,
    then ``config_keys``, then the runs' summary keys. Summary entries never
    overwrite ``run_id`` or a selected config value with the same name.
  """
  api = wandb.Api(timeout=timeout)
  sweep = api.sweep(sweep_id)

  records = []
  for run in tqdm(sweep.runs, desc='Runs', leave=False):
    cfg = {k: run.config[k] for k in config_keys or []}
    if callable(filter_func) and not filter_func(run, cfg):
      continue
    record = {'run_id': run.id, **cfg}
    summary = run.summary
    # Same keys()/[] protocol that ``**run.summary`` uses, but a summary key that
    # duplicates run_id or a config key no longer raises TypeError: the value
    # already in the row is kept.
    for key in summary.keys():
      record.setdefault(key, summary[key])
    records.append(record)

  return sweep, pd.DataFrame(records)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import ticker

# Global plot styling: large fonts, white background with grid lines.
# set_theme is the documented interface; sns.set is only a legacy alias for it.
sns.set_theme(font_scale=2, style='whitegrid')

In [None]:
def f(run, _):
    """Run filter for ``get_summary_metrics``: keep a run only if its state is ``'finished'``.

    The second argument (the selected config dict) is ignored.
    """
    finished_state = 'finished'
    return run.state == finished_state

## Load scratch and transfer sweeps
# Sweep config values copied into each row next to the run's summary metrics.
CONFIG_KEYS = ['levels', 'use_kmeans', 'intrinsic_dim', 'scale_posterior']
# Training mode -> W&B sweep path.
SWEEP_IDS = {
    'scratch': 'deeplearn/pactl/ofqs9zqv',
    'transfer': 'deeplearn/pactl/bjhfcr9n',
}

# One frame per mode (finished runs only, via `f`), tagged with a 'mode' column.
metrics_by_mode = {
    mode: get_summary_metrics(sweep_id=sweep_id, config_keys=CONFIG_KEYS,
                              filter_func=f)[1].assign(mode=mode)
    for mode, sweep_id in SWEEP_IDS.items()
}
scratch_metrics = metrics_by_mode['scratch']
tran_metrics = metrics_by_mode['transfer']

# ignore_index=True gives a fresh 0..n-1 index directly. reset_index() followed by
# drop(columns=['index']) would instead drop a real 'index' column if one were logged
# (reset_index then names its new column 'level_0').
all_metrics = pd.concat(list(metrics_by_mode.values()), ignore_index=True)

# Error percentages derived from accuracies (assumes *_acc columns are fractions in [0, 1]).
for split in ('train', 'test'):
    all_metrics[f'err_{split}_100'] = (1 - all_metrics[f'{split}_acc']) * 100
    all_metrics[f'quantized_err_{split}_100'] = (1 - all_metrics[f'quantized_{split}_acc']) * 100
all_metrics['err_bound_100'] = all_metrics['err_bound_100'].astype(float)

# Optional export:
# all_metrics.to_csv('mnist_bounds.csv', index=False)

## Train/Test Error vs. Intrinsic Dimension

In [None]:
def _plot_error_vs_dim(ax, data, y, ylabel):
    """Plot column `y` of `data` against intrinsic dimension on `ax`, one line per mode.

    The y-axis is log-scaled. The legend is rebuilt with capitalized labels, and its
    handles get the same marker/size/width as the plotted lines.
    """
    sns.lineplot(ax=ax, data=data, x='intrinsic_dim', y=y,
                 style='mode', hue='mode', marker='o', markersize=11, linewidth=3)
    ax.set(xlabel='Intrinsic Dimension', ylabel=ylabel, yscale='log')

    handles, labels = ax.get_legend_handles_labels()
    for handle in handles:
        handle.set(marker='o', markersize=11, linewidth=3)
    ax.legend(handles=handles, labels=[label.capitalize() for label in labels])


fig, axes = plt.subplots(figsize=(11, 4), ncols=2)
panels = [('err_test_100', 'Test Error'), ('err_train_100', 'Train Error')]
for ax, (y, ylabel) in zip(axes, panels):
    _plot_error_vs_dim(ax, all_metrics, y, ylabel)

# Adjust the layout *before* displaying so the shown figure uses it. plt.show() works on
# inline and GUI backends alike; Figure.show() can warn on non-GUI backends.
fig.tight_layout()
plt.show()

## Error Bound vs. Test Error

In [None]:
# For each (mode, intrinsic_dim) pair keep the run with the smallest error bound.
# groupby(...).idxmin() returns index *labels*, so select them with .loc; .iloc only
# worked because all_metrics happens to have a 0..n-1 index. Runs without a bound are
# dropped first so idxmin never sees an all-NaN group.
MAX_ERR_BOUND_100 = 50  # keep only bounds below this value (same units as err_bound_100)
runs_with_bound = all_metrics.dropna(subset=['err_bound_100'])
best_idx = runs_with_bound.groupby(['mode', 'intrinsic_dim'])['err_bound_100'].idxmin()
best_bound_metrics = (
    runs_with_bound
    .loc[best_idx]
    .loc[lambda df: df['err_bound_100'] < MAX_ERR_BOUND_100]
)
best_bound_metrics[['run_id', 'intrinsic_dim', 'mode', 'quantized_err_train_100', 'quantized_err_test_100', 'err_test_100', 'err_bound_100']]

In [None]:
from palettable.cartocolors.qualitative import Bold_5_r as _palette

# Scatter of test error vs. its bound for the best-bound runs selected in the previous cell.
fig, ax = plt.subplots(figsize=(6,6))

# Disabled y = x reference line.
# ax.plot(np.arange(0, 15, .1), np.arange(0, 15, .1), '--', c='gray', alpha=.5, zorder=1)

# Marker style encodes mode; colour encodes intrinsic_dim, with one colour per distinct
# dimension taken from the Bold_5_r palette (NOTE(review): if there are more than 5
# dimensions, seaborn is expected to cycle the palette — confirm).
sns.scatterplot(ax=ax, data=best_bound_metrics, x='err_test_100', y='err_bound_100',
                style='mode', s=400, hue='intrinsic_dim',
                palette=sns.color_palette(_palette.mpl_colors, best_bound_metrics.intrinsic_dim.nunique()))
ax.set(xlabel='Test Error', ylabel='Test Error Bound')

# Hand-tuned label offsets: x_all shifts along test error, y_all along the bound.
# They are paired with rows of best_bound_metrics *by position* (zip stops at the shorter
# sequence), so they depend on the row order and count from the previous cell. Re-tune
# them whenever the selected runs change.
x_all = [-.18, .03, .03, .03, -.05, 0, 0, 0]
y_all = [-.3, -.25, -.25, -.25, -1.25, 0, 0, 0]
for (_, row), _dx, _dy in zip(best_bound_metrics.iterrows(), x_all, y_all):
    # Only scratch runs are annotated; transfer rows still consume an offset slot.
    if row['mode'] != 'scratch':
        continue
    # Label the point with its intrinsic dimension, e.g. "d=1000".
    ax.text(row['err_test_100'] + _dx , row['err_bound_100'] + _dy, r"$d$=" + f"{row['intrinsic_dim']}",
            fontsize=18)

# Shrink legend markers (the points use s=400) and capitalize labels. Then keep only the
# last two entries, assumed to be the two 'mode' style entries — verify if hue/style change.
handles, labels = ax.get_legend_handles_labels()
for idx, (h, l) in enumerate(zip(handles, labels)):
    h.set_sizes([100])
    labels[idx] = l.capitalize()
handles, labels = handles[-2:], labels[-2:]
ax.legend(handles=handles, labels=labels)

fig.show()
# Write the figure to PDF using a tight bounding box.
fig.savefig('bound_vs_test_err.pdf', bbox_inches='tight')

## Train vs. Test Error

In [None]:
# Train vs. test error of the best-bound runs. Points on the dashed y = x line have
# equal train and test error (no generalization gap).
fig, ax = plt.subplots(figsize=(5, 5))

sns.scatterplot(ax=ax, data=best_bound_metrics, x='err_test_100', y='err_train_100',
                style='mode', s=100, hue='intrinsic_dim',
                palette=sns.color_palette('tab20', best_bound_metrics.intrinsic_dim.nunique()))
ax.plot([0, 100], [0, 100], '--', c='gray', alpha=.5)
ax.set(xlabel='Test Error', ylabel='Train Error', xlim=[0, 100], ylim=[0, 105])
# Equal scaling that keeps the explicit limits: ax.axis('equal') adjusts the *data
# limits* (adjustable='datalim') and so overrides xlim/ylim; adjustable='box'
# resizes the axes box instead.
ax.set_aspect('equal', adjustable='box')

# Keep only the last two legend entries (assumed to be the 'mode' style markers), capitalized.
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles=handles[-2:], labels=[label.capitalize() for label in labels[-2:]])

plt.show()