In [None]:
import os
import wandb
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Global seaborn styling applied to every figure in this notebook:
# white grid background, 1.5x font scale, and the 'Spectral' colour palette.
spectral_palette = sns.color_palette('Spectral')
sns.set(style='whitegrid', font_scale=1.5, palette=spectral_palette)

In [None]:
from typing import List

def get_metrics(sweep_id, keys=None, config_keys=None):
  """Download the logged history of every run in a W&B sweep as one long DataFrame.

  Args:
    sweep_id: Sweep path as accepted by ``wandb.Api().sweep``
      (e.g. ``'entity/project/sweep_id'``).
    keys: Optional list (or tuple) of history keys to fetch. ``None`` passes
      ``keys=None`` to ``scan_history`` (fetch everything). When given,
      ``'_runtime'``, ``'_step'`` and ``'_timestamp'`` are always added and
      duplicates are removed. The caller's list is NOT modified.
    config_keys: Optional iterable of run-config entries to copy onto every
      history row (e.g. ``['method', 'dataset']``). ``None`` copies none.

  Returns:
    ``(sweep, df)``: the W&B sweep object, and a DataFrame with one row per
    history row per run, holding ``run_id``, the requested config values and
    the history values.

  Raises:
    KeyError: if a run's config lacks one of ``config_keys``.
  """
  api = wandb.Api()
  sweep = api.sweep(sweep_id)

  if isinstance(keys, (list, tuple)):
    # Build a fresh list rather than extending the caller's (the old code mutated
    # it). dict.fromkeys dedupes while keeping a deterministic order, unlike set().
    keys = list(dict.fromkeys([*keys, '_runtime', '_step', '_timestamp']))

  # Treat a missing config_keys as "no config columns" instead of failing on iteration.
  config_keys = [] if config_keys is None else list(config_keys)

  data = []
  for run in sweep.runs:
    cfg = {k: run.config[k] for k in config_keys}
    for row in run.scan_history(keys=keys):
      data.append(dict(run_id=run.id, **cfg, **row))

  return sweep, pd.DataFrame(data)

In [None]:
keys = None ## get everything

## KeOps
_, metrics1 = get_metrics('gausspr/simplex-gp/xt1i60t7', keys=keys, config_keys=['method', 'dataset'])

## Simplex-GP
_, metrics2 = get_metrics('gausspr/simplex-gp/wz0yzdqq', keys=keys, config_keys=['method', 'dataset'])

# ignore_index=True: each frame comes with its own 0..n-1 index, so a plain concat
# would produce duplicate row labels, which break label-based alignment/reindexing
# in later pandas and seaborn calls.
metrics = pd.concat([metrics1, metrics2], ignore_index=True)
# Running total of per-step training time within each run. Rows are summed in the
# order get_metrics returned them (presumably step order — TODO confirm).
metrics['train/total_cu_ts'] = metrics.groupby(by=['run_id'])['train/total_ts'].cumsum()
metrics

## Runtime, RMSE, MLL

In [None]:
# Training curves for one dataset, comparing methods over the first 100 logged steps:
# MLL, cumulative training time, validation RMSE and test RMSE.
dataset = 'keggdirected'
in_scope = (metrics.dataset == dataset) & (metrics._step <= 100)
plt_metrics = metrics[in_scope].sort_values(by=['method'], ascending=False)

fig, axes = plt.subplots(figsize=(10, 10), nrows=2, ncols=2)
panel_columns = ['train/mll', 'train/total_cu_ts', 'val/rmse', 'test/rmse']
# axes.flat walks the 2x2 grid row by row: [0,0], [0,1], [1,0], [1,1].
for panel_ax, y_col in zip(axes.flat, panel_columns):
  sns.lineplot(data=plt_metrics, x='_step', y=y_col, hue='method', ci=None, ax=panel_ax)

# Optional zoom on the RMSE panels:
# axes[1,0].set_ylim([0.0, 2.0])
# axes[1,1].set_ylim([0.0, 2.0])

fig.tight_layout()
# fig.savefig(f'{dataset}-training.png', bbox_inches='tight')

## Lengthscales, Outputscale and Noise

In [None]:
def raw2label(v):
  r"""Turn a logged hyperparameter key into a LaTeX tick label.

  Only the text after the last '/' matters: 'outputscale' maps to $\alpha$,
  'noise' maps to $\sigma^2$, and any other name is rendered as a
  subscripted lengthscale $\ell_{name}$.
  """
  name = v.rsplit('/', 1)[-1]
  fixed_labels = {
    'outputscale': r'$\alpha$',
    'noise': r'$\sigma^2$',
  }
  if name in fixed_labels:
    return fixed_labels[name]
  return fr'$\ell_{{{name}}}$'

# Learned lengthscales per method for one dataset, one bar group per lengthscale column.
dataset = 'houseelectric'
id_cols = ['run_id', 'method', 'dataset', '_step']

# Restrict to this dataset, then drop every column holding any NaN there. Presumably
# this removes lengthscale columns that only exist for other datasets — TODO confirm.
# plt_metrics = metrics[(metrics.dataset == dataset) & (metrics._step == step)].dropna(axis=1)
dataset_metrics = metrics[metrics.dataset == dataset].dropna(axis=1)
param_columns = [col for col in dataset_metrics.columns if 'param/lengthscale' in col]

# Long format: one row per (run, step, parameter).
# NOTE(review): every step is kept, so each bar is the barplot's aggregate over the
# whole training trajectory rather than a single step's value (the commented-out
# `_step == step` filter above would select one step).
plt_metrics = (
  dataset_metrics
  .loc[:, id_cols + param_columns]
  .melt(id_vars=id_cols, var_name='param', value_name='param_value')
  .sort_values(by=['method', 'param_value'], ascending=False)
)

fig, ax = plt.subplots(figsize=(20, 5))
sns.barplot(data=plt_metrics, x='param', y='param_value', hue='method', ax=ax)
# Swap the raw column names on the x axis for LaTeX labels.
ax.set_xticklabels([raw2label(tick.get_text()) for tick in ax.get_xticklabels()])
ax.set(xlabel='', ylabel='')
fig.savefig(f'{dataset}-ls.png', bbox_inches='tight')

In [None]:
def raw2label(v):
  r"""Turn a logged hyperparameter key into a tick label, keeping scale/noise as plain text.

  Only the text after the last '/' matters: 'outputscale' and 'noise' are
  returned unchanged, and any other name is rendered as a subscripted
  lengthscale $\ell_{name}$.

  Note: this redefines the LaTeX-label `raw2label` from the lengthscale cell,
  so later cells see this plain-text variant.
  """
  name = v.rsplit('/', 1)[-1]
  return name if name in ('outputscale', 'noise') else fr'$\ell_{{{name}}}$'

# Outputscale and noise per method for one dataset, in a separate (default-sized) figure.
dataset = 'houseelectric'
id_cols = ['run_id', 'method', 'dataset', '_step']
param_columns = ['param/outputscale', 'param/noise']

# Same column-wise dropna as the lengthscale cell. NOTE(review): if either parameter
# column has a NaN for this dataset it is dropped here and the selection below
# raises KeyError.
plt_metrics = (
  metrics[metrics.dataset == dataset]
  .dropna(axis=1)
  .loc[:, id_cols + param_columns]
  .melt(id_vars=id_cols, var_name='param', value_name='param_value')
  .sort_values(by=['method', 'param_value'], ascending=False)
)

fig, ax = plt.subplots()
sns.barplot(data=plt_metrics, x='param', y='param_value', hue='method', ax=ax)
ax.set_xticklabels([raw2label(tick.get_text()) for tick in ax.get_xticklabels()])
ax.set(xlabel='', ylabel='')
fig.savefig(f'{dataset}-scale_noise.png', bbox_inches='tight')

## CG Truncation

In [None]:
## Simplex-GP CG Truncations
# NOTE: this rebinds `metrics`, replacing the KeOps/Simplex-GP comparison data loaded
# earlier. Re-run that loading cell before re-running the earlier plotting cells.
sweep, metrics = get_metrics(
  'gausspr/simplex-gp/hn3wy998',
  keys=['train/total_ts', 'train/mll', 'val/rmse', 'test/rmse'],
  config_keys=['dataset', 'cg_iter'],
)

# Running total of per-step training time within each run.
per_run_ts = metrics.groupby('run_id')['train/total_ts']
metrics['train/total_cu_ts'] = per_run_ts.cumsum()
metrics

In [None]:
# Best test RMSE from each run's W&B summary, one row per run of the CG-truncation sweep.
rmse_data = pd.DataFrame([
  {
    'dataset': run.config['dataset'],
    'cg_iter': run.config['cg_iter'],
    'best_rmse': run.summary['test/best_rmse'],
  }
  for run in sweep.runs
])
rmse_data[rmse_data.dataset == 'protein']

In [None]:
# Training curves for the Simplex-GP CG-truncation sweep on one dataset (first 100
# logged steps), one line per cg_iter setting.
fig, axes = plt.subplots(figsize=(10, 10), nrows=2, ncols=2)

dataset = 'protein'
plt_metrics = metrics[(metrics.dataset == dataset) & (metrics._step <= 100)]
plt_metrics = plt_metrics.sort_values(by=['cg_iter'])
# `train/mll` presumably arrives with a non-numeric (object) dtype, e.g. string 'NaN'
# entries. Assign a whole new column instead of `plt_metrics.loc[:, col] = ...`: since
# pandas 2.0 that form writes in place and keeps the object dtype, silently undoing
# the conversion.
plt_metrics = plt_metrics.assign(**{'train/mll': pd.to_numeric(plt_metrics['train/mll'])})

# axes.flat walks the 2x2 grid row by row: [0,0], [0,1], [1,0], [1,1].
for panel_ax, y_col in zip(axes.flat, ['train/mll', 'train/total_cu_ts', 'val/rmse', 'test/rmse']):
  sns.lineplot(data=plt_metrics, x='_step', y=y_col, hue='cg_iter', ax=panel_ax)
fig.tight_layout()

fig.savefig(f'{dataset}-cg-iter.png', bbox_inches='tight')