In [None]:
import wandb
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

# NOTE(review): wandb.Api() below does not require an active run; this init
# only starts a (logging-free) run in the project each time the notebook runs.
wandb_project = "d3pm"
wandb.init(project=wandb_project)

# Public-API client used by the cells below to query sweeps and their runs.
api = wandb.Api()

# MNIST N=4 $\gamma$ sweeps

In [None]:
# Sweep path format: "<entity>/<project>/<sweep_id>".
sweep_path = "alanamin/d3pm/vx0qgfmh"
sweep = api.sweep(sweep_path)
runs = sweep.runs

In [None]:
epoch = 6  # epoch at which validation NLL is compared across runs

data = []
for run in tqdm(runs):
    # Single pass over the full history, keeping the last row logged for
    # `epoch`. A run that never reached `epoch` (crashed, cancelled, still
    # running) is reported and skipped instead of raising IndexError, which
    # is consistent with how the CIFAR sweeps are collected in this notebook.
    epoch_data = None
    for row in run.scan_history():
        if row.get('epoch') == epoch:
            epoch_data = row
    if epoch_data is None:
        tqdm.write(f"Skipping run {run.name}: no history rows with epoch == {epoch}")
        continue
    data.append({
        'config': run.config,
        'summary': run.summary,
        'metrics': epoch_data,
    })

In [None]:
# Flatten the per-run records into one row per run with the swept
# hyperparameters and the validation NLL (val_l01) at `epoch`.
raw_df = pd.DataFrame(data, columns=['config', 'summary', 'metrics'])
configs = raw_df['config']
metrics = raw_df['metrics']
df = pd.DataFrame({
    'gamma': configs.apply(lambda cfg: cfg['gamma']),
    'x_t_param': configs.apply(lambda cfg: cfg['model.fix_x_t_bias']),
    # 'nll' and 'final_nll' both hold val_l01 at `epoch`.
    'nll': metrics.apply(lambda m: m['val_l01']),
    'final_nll': metrics.apply(lambda m: m['val_l01']),
})

In [None]:
# Validation NLL at `epoch` vs gamma, one line per x_t parameterization.
# Sweep runs are returned in arbitrary order, so sort by gamma before drawing
# connected lines (otherwise the curve zig-zags between points).
plt.figure(figsize=[3, 3])
for is_x_t, label in [(True, 'x_t param'), (False, 'x_t-d param')]:
    subset = df[df['x_t_param'] == is_x_t].sort_values('gamma')
    plt.plot(subset['gamma'], subset['nll'], label=label)
plt.ylabel(f"NLL per bit (epoch {epoch})")
plt.xlabel(r"$\gamma$")  # raw string: "\g" is an invalid escape sequence
plt.legend()

# MNIST N=10 $\gamma$ sweeps

In [None]:
# Sweep path format: "<entity>/<project>/<sweep_id>".
sweep_path = "alanamin/d3pm/q547s1vh"
sweep = api.sweep(sweep_path)
runs = sweep.runs

In [None]:
epoch = 10  # epoch at which validation NLL is compared across runs

data = []
for run in tqdm(runs):
    # Single pass over the full history, keeping the last row logged for
    # `epoch`. A run that never reached `epoch` (crashed, cancelled, still
    # running) is reported and skipped instead of raising IndexError, which
    # is consistent with how the CIFAR sweeps are collected in this notebook.
    epoch_data = None
    for row in run.scan_history():
        if row.get('epoch') == epoch:
            epoch_data = row
    if epoch_data is None:
        tqdm.write(f"Skipping run {run.name}: no history rows with epoch == {epoch}")
        continue
    data.append({
        'config': run.config,
        'summary': run.summary,
        'metrics': epoch_data,
    })

In [None]:
# Flatten the per-run records into one row per run with the swept
# hyperparameters and the validation NLL (val_l01) at `epoch`.
raw_df = pd.DataFrame(data, columns=['config', 'summary', 'metrics'])
configs = raw_df['config']
metrics = raw_df['metrics']
df = pd.DataFrame({
    'gamma': configs.apply(lambda cfg: cfg['gamma']),
    'x_t_param': configs.apply(lambda cfg: cfg['model.fix_x_t_bias']),
    # 'nll' and 'final_nll' both hold val_l01 at `epoch`.
    'nll': metrics.apply(lambda m: m['val_l01']),
    'final_nll': metrics.apply(lambda m: m['val_l01']),
})

In [None]:
# Validation NLL at `epoch` vs gamma, one line per x_t parameterization,
# drawn twice: once at full range, then zoomed to the 0.45-0.5 NLL band.
# Sweep runs are returned in arbitrary order, so sort by gamma before drawing
# connected lines (otherwise the curve zig-zags between points).
for ylim in [None, (0.45, 0.5)]:
    plt.figure(figsize=[3, 3])
    for is_x_t, label in [(True, 'x_t param'), (False, 'x_t-d param')]:
        subset = df[df['x_t_param'] == is_x_t].sort_values('gamma')
        plt.plot(subset['gamma'], subset['nll'], label=label)
    plt.ylabel(f"NLL per bit (epoch {epoch})")
    plt.xlabel(r"$\gamma$")  # raw string: "\g" is an invalid escape sequence
    if ylim is not None:
        plt.ylim(*ylim)
    plt.legend()

MNIST seems to be a poor test bed, since we essentially fit the data perfectly.

# CIFAR $\gamma$ sweeps

Here I use the x_t-d parameterization.

In [None]:
# Sweep path format: "<entity>/<project>/<sweep_id>".
sweep_path = "alanamin/d3pm/sqb25jtc"
sweep = api.sweep(sweep_path)
runs = sweep.runs

In [None]:
epoch = 0

data = []
for run in tqdm(runs):
    # Walk the whole history once, remembering the most recent row logged
    # for `epoch`; runs that have no such row are left out of `data`.
    latest = None
    for row in run.scan_history():
        if row.get('epoch') == epoch:
            latest = row
    if not latest:
        continue
    data.append(dict(config=run.config, summary=run.summary, metrics=latest))

In [None]:
# One row per run: swept gamma and learning rate plus val_l01 at `epoch`.
raw_df = pd.DataFrame(data, columns=['config', 'summary', 'metrics'])
configs = raw_df['config']
metrics = raw_df['metrics']
df = pd.DataFrame({
    'gamma': configs.apply(lambda cfg: cfg['gamma']),
    'train.lr': configs.apply(lambda cfg: cfg['train.lr']),
    # 'nll' and 'final_nll' both hold val_l01 at `epoch`.
    'nll': metrics.apply(lambda m: m['val_l01']),
    'final_nll': metrics.apply(lambda m: m['val_l01']),
})

In [None]:
# Validation NLL vs gamma, one line per learning rate. Sweep runs are
# returned in arbitrary order, so sort by gamma before drawing lines.
# NOTE(review): the y label says "final epoch" but the values were taken at
# epoch 0 -- confirm these runs only train for a single epoch.
plt.figure(figsize=[3, 3])
for lr in [0.001, 0.0001]:
    subset = df[df['train.lr'] == lr].sort_values('gamma')
    plt.plot(subset['gamma'], subset['final_nll'], label=f'lr={lr}')
plt.ylabel("NLL per bit (final epoch)")
plt.xlabel(r"$\gamma$")  # raw string: "\g" is an invalid escape sequence
plt.legend()
plt.ylim(3.85, 4.55)

# CIFAR $\sigma$ sweep

In [None]:
# Sweep path format: "<entity>/<project>/<sweep_id>".
sweep_path = "alanamin/d3pm/3tfeu0us"
sweep = api.sweep(sweep_path)
runs = sweep.runs

In [None]:
epoch = 0

data = []
for run in tqdm(runs):
    # Walk the whole history once, remembering the most recent row logged
    # for `epoch`; runs that have no such row are left out of `data`.
    latest = None
    for row in run.scan_history():
        if row.get('epoch') == epoch:
            latest = row
    if not latest:
        continue
    data.append(dict(config=run.config, summary=run.summary, metrics=latest))

In [None]:
# One row per run: forward-process bandwidth (sigma), whether it is
# normalized, learning rate, and val_l01 at `epoch`.
raw_df = pd.DataFrame(data, columns=['config', 'summary', 'metrics'])
configs = raw_df['config']
df = pd.DataFrame({
    'sigma': configs.apply(lambda cfg: cfg['model.forward_kwargs.bandwidth']),
    'normalized': configs.apply(lambda cfg: cfg['model.forward_kwargs.normalized']),
    'train.lr': configs.apply(lambda cfg: cfg['train.lr']),
    'final_nll': raw_df['metrics'].apply(lambda m: m['val_l01']),
})

In [None]:
# Validation NLL vs bandwidth sigma: color encodes normalization
# (blue = normalized, orange = not), line style encodes learning rate
# (solid = 0.001, dashed = 0.0001). Sweep runs are returned in arbitrary
# order, so sort by sigma before drawing connected lines.
plt.figure(figsize=[3, 3])
for normalized, color, suffix in [(True, 'blue', ', normalized'), (False, 'orange', '')]:
    for lr, ls in [(0.001, '-'), (0.0001, '--')]:
        cond = (df['train.lr'] == lr) & (df['normalized'] == normalized)
        subset = df[cond].sort_values('sigma')
        plt.plot(subset['sigma'], subset['final_nll'],
                 label=f'lr={lr}{suffix}', color=color, ls=ls)
plt.ylabel("NLL per bit (final epoch)")
plt.xlabel(r"$\sigma$")  # raw string: "\s" is an invalid escape sequence
plt.legend()

There is a clear benefit to smaller bandwidths, and normalization appears to help as well.

# CIFAR parameter runs

In [None]:
# Sweep path format: "<entity>/<project>/<sweep_id>".
sweep_path = "alanamin/d3pm/x38oy66z"
sweep = api.sweep(sweep_path)
runs = sweep.runs

In [None]:
epoch = 0

data = []
for run in tqdm(runs):
    # Walk the whole history once, remembering the most recent row logged
    # for `epoch`; runs that have no such row are left out of `data`.
    latest = None
    for row in run.scan_history():
        if row.get('epoch') == epoch:
            latest = row
    if not latest:
        continue
    data.append(dict(config=run.config, summary=run.summary, metrics=latest))

In [None]:
# One row per run: architecture width (s_dim), hybrid-loss coefficient,
# logistic_pars flag, learning rate, and val_l01 at `epoch`.
raw_df = pd.DataFrame(data, columns=['config', 'summary', 'metrics'])
configs = raw_df['config']
df = pd.DataFrame({
    's_dim': configs.apply(lambda cfg: cfg['architecture.s_dim']),
    'hybrid': configs.apply(lambda cfg: cfg['model.hybrid_loss_coeff']),
    'logistic_pars': configs.apply(lambda cfg: cfg['model.logistic_pars']),
    'train.lr': configs.apply(lambda cfg: cfg['train.lr']),
    'final_nll': raw_df['metrics'].apply(lambda m: m['val_l01']),
})

In [None]:
# One figure per learning rate: validation NLL vs s_dim for every combination
# of hybrid-loss coefficient (line style) and logistic_pars (color:
# blue = True, orange = False). Sweep runs are returned in arbitrary order,
# so sort by s_dim before drawing connected lines.
hlc_linestyles = {0.0: '-', 0.01: '--', 0.1: ':'}
for lr in [0.001, 0.0001]:
    plt.figure(figsize=[3, 3])
    for hlc, ls in hlc_linestyles.items():
        for lp in [True, False]:
            cond = (df['train.lr'] == lr) & (df['hybrid'] == hlc) & (df['logistic_pars'] == lp)
            subset = df[cond].sort_values('s_dim')
            plt.plot(subset['s_dim'], subset['final_nll'],
                     label=f'hlc={hlc},lp={lp}',
                     color='blue' if lp else 'orange', ls=ls)
    plt.ylabel("NLL per bit (final epoch)")
    plt.xlabel("s_dim")
    # Legend sits to the right of the axes so it does not cover the curves.
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.title(f"lr={lr}")
    plt.ylim(3.85, 4.25)