In [None]:
import wandb
from tqdm.auto import tqdm
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Global seaborn theme for every figure below: white grid background, fonts scaled 2x.
sns.set(style='whitegrid', font_scale=2)

In [None]:
def get_summary_metrics(sweep_id, config_keys=None, filter_func=None, timeout=60):
  """Collect one row of summary metrics (plus selected config values) per sweep run.

  Args:
    sweep_id: Sweep path passed to ``wandb.Api().sweep``
      (e.g. ``'entity/project/sweep_id'``).
    config_keys: Optional iterable of ``run.config`` keys copied into each row.
      A run missing one of these keys raises ``KeyError``.
    filter_func: Optional predicate ``run -> bool``. Runs for which it returns a
      falsy value are skipped. The argument is ignored if it is not callable.
    timeout: Timeout forwarded to ``wandb.Api`` (default 60, as before).

  Returns:
    ``(sweep, df)``. ``sweep`` is the object returned by ``api.sweep(sweep_id)``.
    ``df`` has one row per kept run. Its columns are ``run_id``, then the
    requested config keys, then every summary key.
    A summary key that duplicates ``run_id`` or a requested config key is
    dropped in favour of the run id / config value.
  """
  api = wandb.Api(timeout=timeout)
  sweep = api.sweep(sweep_id)

  rows = []
  for run in tqdm(sweep.runs, desc='Runs', leave=False):
    if callable(filter_func) and not filter_func(run):
      continue
    row = {'run_id': run.id}
    row.update({k: run.config[k] for k in config_keys or []})
    # Summaries frequently re-log config values. The old
    # dict(run_id=..., **cfg, **summary) raised TypeError on such duplicate
    # keys, so keep the run id / config value and skip the clashing entry.
    for key, value in dict(run.summary).items():
      row.setdefault(key, value)
    rows.append(row)

  return sweep, pd.DataFrame(rows)

In [None]:
# Fetch one row per run of the sweep, including each run's `intrinsic_dim` config value.
SWEEP_ID = 'deeplearn/pactl/vn8zm3zs'
_, metrics = get_summary_metrics(sweep_id=SWEEP_ID, config_keys=['intrinsic_dim'])

In [None]:
# Best test accuracy vs. intrinsic dimension: a faint connecting line plus large
# markers whose colour also encodes the accuracy value.
x_col, y_col = 'intrinsic_dim', 'sgd/test/best_acc'
acc_cmap = sns.color_palette('crest_r', as_cmap=True)

fig, ax = plt.subplots(figsize=(9, 7))
sns.lineplot(data=metrics, x=x_col, y=y_col, ax=ax, alpha=.5, legend=False)
sns.scatterplot(data=metrics, x=x_col, y=y_col, hue=y_col, palette=acc_cmap,
                marker='o', s=400, zorder=10, ax=ax, legend=False)
ax.set(xlabel='Intrinsic Dimension', ylabel='Test Accuracy', title='MNIST')
fig.tight_layout()

# Set to a file path (e.g. 'id.pdf') to also write the figure to disk.
SAVE_PATH = None
if SAVE_PATH:
  fig.savefig(SAVE_PATH, bbox_inches='tight')