In [None]:
import wandb
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

# Start a W&B run in the "d3pm" project, then create the API client that the
# cells below use to look up sweeps.
# NOTE(review): only `api` is used by the visible cells -- confirm whether the
# run started by `wandb.init` is actually needed for this read-only analysis.
wandb.init(project="d3pm")
api = wandb.Api()

def split_df(df, x_axis, y_axis, cond=None):
    """Summarise ``y_axis`` per distinct value of ``x_axis``.

    Parameters
    ----------
    df : pandas.DataFrame
        Table holding at least the ``x_axis`` and ``y_axis`` columns.
    x_axis : str
        Column to group by.
    y_axis : str
        Column whose values are averaged inside each group (cast to float64).
    cond : boolean mask, optional
        If given, ``df[cond]`` is summarised instead of the full frame.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(3, n_groups)`` whose rows are the group keys, the
        NaN-ignoring means and the NaN-ignoring standard deviations, so it can
        be unpacked as ``x, y, yerr = split_df(...)``. When nothing is selected
        the result has shape ``(3, 0)`` so that unpacking still works.
    """
    if cond is not None:
        df = df[cond]

    rows = []
    for group_key, group in df.groupby(x_axis):
        values = np.asarray(group[y_axis], dtype=np.float64)
        if np.isnan(values).all():
            # nanmean/nanstd would return NaN here too, but with a RuntimeWarning.
            mean = std = np.nan
        else:
            mean, std = np.nanmean(values), np.nanstd(values)
        rows.append((group_key, mean, std))

    if not rows:
        # np.array([]).T has shape (0,), which cannot be unpacked into three names.
        return np.empty((3, 0))
    return np.array(rows).T

def load_from_dict(key):
    """Build an accessor that looks ``key`` up in a (possibly nested) mapping.

    Parameters
    ----------
    key : hashable or list
        A single key, or a list of keys that is followed level by level
        (``x[key[0]][key[1]]...``).

    Returns
    -------
    callable
        ``f(x)`` returning the value found, or ``np.nan`` when the entry is
        missing. "Missing" covers an absent key as well as a path that runs
        through a value that cannot be indexed (e.g. ``None``), so one
        incomplete record does not abort a whole ``Series.apply``.
    """
    # Treat a scalar key as a path of length one so both cases share one loop.
    path = key if isinstance(key, list) else [key]

    def f(x):
        try:
            for k in path:
                x = x[k]
        except (KeyError, IndexError, TypeError):
            return np.nan
        return x
    return f

# CIFAR N=4 $\gamma$ sweeps

In [None]:
# Look up the W&B sweep at this path and keep a handle to its runs for the
# cells below.
sweep = api.sweep("alanamin/d3pm/rqz6fqqg")
runs = sweep.runs

In [None]:
# One record per run: its config and summary objects as returned by W&B.
data = [dict(config=run.config, summary=run.summary) for run in runs]

# Flatten the fields of interest into plain columns. The config entries are
# read directly (a missing key raises), while the 'val_l01' summary entry goes
# through load_from_dict and becomes NaN when absent.
raw = pd.DataFrame(data, columns=['config', 'summary'])
configs, summaries = raw['config'], raw['summary']
df = pd.DataFrame({
    'seed': configs.apply(lambda cfg: cfg['model.seed']),
    'gamma': configs.apply(lambda cfg: cfg['model.gamma']),
    'x_t_param': configs.apply(lambda cfg: cfg['model.fix_x_t_bias']),
    'nll': summaries.apply(load_from_dict('val_l01')),
})

In [None]:
# Mean +/- std of 'nll' per gamma, one curve for each value of 'x_t_param'.
plt.figure(figsize=[3, 3])
for flag, label in [(False, 'x_t-d param'), (True, 'x_t param')]:
    cond = df['x_t_param'] == flag
    x, y, yerr = split_df(df, 'gamma', 'nll', cond)
    plt.errorbar(x, y, yerr, label=label)
# 'nll' is read from run.summary and no `epoch` variable is defined anywhere in
# the notebook, so use the same fixed "(final epoch)" label as the
# parameter-runs plot instead of interpolating an undefined name.
plt.ylabel("NLL per bit (final epoch)")
# Raw string: "\g" is not a valid escape sequence in a normal string literal.
plt.xlabel(r"$\gamma$")
plt.legend();

# CIFAR N=8 $\gamma$ sweeps

In [None]:
# Look up the W&B sweep at this path and keep a handle to its runs for the
# cells below.
sweep = api.sweep("alanamin/d3pm/t44pmy1u")
runs = sweep.runs

In [None]:
# One record per run: its config and summary objects as returned by W&B.
data = [dict(config=run.config, summary=run.summary) for run in runs]

# Flatten seed, gamma and the 'val_l01' summary entry into plain columns;
# a missing 'val_l01' becomes NaN via load_from_dict.
raw = pd.DataFrame(data, columns=['config', 'summary'])
configs, summaries = raw['config'], raw['summary']
df = pd.DataFrame({
    'seed': configs.apply(lambda cfg: cfg['model.seed']),
    'gamma': configs.apply(lambda cfg: cfg['model.gamma']),
    'nll': summaries.apply(load_from_dict('val_l01')),
})

In [None]:
# Mean +/- std of 'nll' per gamma across all runs of the sweep.
plt.figure(figsize=[3, 3])
x, y, yerr = split_df(df, 'gamma', 'nll')
plt.errorbar(x, y, yerr)
# 'nll' is read from run.summary and no `epoch` variable is defined anywhere in
# the notebook, so use the same fixed "(final epoch)" label as the
# parameter-runs plot instead of interpolating an undefined name.
plt.ylabel("NLL per bit (final epoch)")
# Raw string: "\g" is not a valid escape sequence in a normal string literal.
# No legend: the single curve carries no label, so plt.legend() would only warn.
plt.xlabel(r"$\gamma$");

# CIFAR N=8 struct $\gamma$ sweeps

In [None]:
# Look up the W&B sweep at this path and keep a handle to its runs for the
# cells below.
sweep = api.sweep("alanamin/d3pm/a8ba8kbd")
runs = sweep.runs

In [None]:
# One record per run: its config and summary objects as returned by W&B.
data = [dict(config=run.config, summary=run.summary) for run in runs]

# Flatten seed, gamma and the 'val_l01' summary entry into plain columns;
# a missing 'val_l01' becomes NaN via load_from_dict.
raw = pd.DataFrame(data, columns=['config', 'summary'])
configs, summaries = raw['config'], raw['summary']
df = pd.DataFrame({
    'seed': configs.apply(lambda cfg: cfg['model.seed']),
    'gamma': configs.apply(lambda cfg: cfg['model.gamma']),
    'nll': summaries.apply(load_from_dict('val_l01')),
})

In [None]:
# Mean +/- std of 'nll' per gamma across all runs of the sweep.
plt.figure(figsize=[3, 3])
x, y, yerr = split_df(df, 'gamma', 'nll')
plt.errorbar(x, y, yerr)
# 'nll' is read from run.summary and no `epoch` variable is defined anywhere in
# the notebook, so use the same fixed "(final epoch)" label as the
# parameter-runs plot instead of interpolating an undefined name.
plt.ylabel("NLL per bit (final epoch)")
# Raw string: "\g" is not a valid escape sequence in a normal string literal.
# No legend: the single curve carries no label, so plt.legend() would only warn.
plt.xlabel(r"$\gamma$");

# CIFAR N=128 $\gamma$ sweeps

In [None]:
# Look up the W&B sweep at this path and keep a handle to its runs for the
# cells below.
sweep = api.sweep("alanamin/d3pm/29h93dqv")
runs = sweep.runs

In [None]:
# One record per run: its config and summary objects as returned by W&B.
data = [dict(config=run.config, summary=run.summary) for run in runs]

# Flatten seed, gamma and the 'val_l01' summary entry into plain columns;
# a missing 'val_l01' becomes NaN via load_from_dict.
raw = pd.DataFrame(data, columns=['config', 'summary'])
configs, summaries = raw['config'], raw['summary']
df = pd.DataFrame({
    'seed': configs.apply(lambda cfg: cfg['model.seed']),
    'gamma': configs.apply(lambda cfg: cfg['model.gamma']),
    'nll': summaries.apply(load_from_dict('val_l01')),
})

In [None]:
# Mean +/- std of 'nll' per gamma across all runs of the sweep.
plt.figure(figsize=[3, 3])
x, y, yerr = split_df(df, 'gamma', 'nll')
plt.errorbar(x, y, yerr)
# 'nll' is read from run.summary and no `epoch` variable is defined anywhere in
# the notebook, so use the same fixed "(final epoch)" label as the
# parameter-runs plot instead of interpolating an undefined name.
plt.ylabel("NLL per bit (final epoch)")
# Raw string: "\g" is not a valid escape sequence in a normal string literal.
# No legend: the single curve carries no label, so plt.legend() would only warn.
plt.xlabel(r"$\gamma$");

# CIFAR $\sigma$ sweep

In [None]:
# Look up the W&B sweep at this path and keep a handle to its runs for the
# cells below.
sweep = api.sweep("alanamin/d3pm/mguwxwxd")
runs = sweep.runs

In [None]:
# One record per run: its config and summary objects as returned by W&B.
data = [dict(config=run.config, summary=run.summary) for run in runs]

raw = pd.DataFrame(data, columns=['config', 'summary'])
configs, summaries = raw['config'], raw['summary']

# Summary entries; either one is NaN for runs that did not log it.
val_l01 = summaries.apply(load_from_dict('val_l01'))
val_l1 = summaries.apply(load_from_dict('val_l1'))

# 'sigma' and 'normalized' live under config['forward_kwargs'];
# 'kl1' is the element-wise sum of the two summary entries.
df = pd.DataFrame({
    'seed': configs.apply(lambda cfg: cfg['model.seed']),
    'sigma': configs.apply(load_from_dict(['forward_kwargs', 'bandwidth'])),
    'normalized': configs.apply(load_from_dict(['forward_kwargs', 'normalized'])),
    'nll': val_l01,
    'kl1': val_l01 + val_l1,
})

In [None]:
# Two figures over sigma, each with one curve per value of 'normalized':
# first the 'nll' column, then the 'kl1' column (val_l01 + val_l1).
# The y-labels no longer interpolate `epoch`, which is not defined anywhere in
# the notebook; the values come from run.summary, hence "(final epoch)".
# NOTE(review): the second figure used to reuse the "NLL per bit" label even
# though it plots 'kl1'; it is now labelled by the column's definition --
# replace with the proper metric name if there is one.
for column, ylabel in [('nll', "NLL per bit (final epoch)"),
                       ('kl1', "val_l01 + val_l1 (final epoch)")]:
    plt.figure(figsize=[3, 3])
    for flag, label in [(True, 'Normalized'), (False, 'Not normalized')]:
        cond = df['normalized'] == flag
        x, y, yerr = split_df(df, 'sigma', column, cond)
        plt.errorbar(x, y, yerr, label=label)
    plt.ylabel(ylabel)
    # The x-axis is the 'sigma' column, not gamma; raw string avoids the
    # invalid "\s" escape sequence.
    plt.xlabel(r"$\sigma$")
    plt.legend()

# CIFAR parameter runs

In [None]:
# Look up the W&B sweep at this path and keep a handle to its runs for the
# cells below.
sweep = api.sweep("alanamin/d3pm/83mp27zv")
runs = sweep.runs

In [None]:
# One record per run: its config and summary objects as returned by W&B.
data = [dict(config=run.config, summary=run.summary) for run in runs]

# Flatten the swept settings and the 'val_l01' summary entry into plain
# columns. Everything except the seed goes through load_from_dict, so a
# missing entry becomes NaN instead of raising.
raw = pd.DataFrame(data, columns=['config', 'summary'])
configs, summaries = raw['config'], raw['summary']
df = pd.DataFrame({
    'seed': configs.apply(lambda cfg: cfg['model.seed']),
    't_emb': configs.apply(load_from_dict(['nn_params', 'time_embed_dim'])),
    's_dim': configs.apply(load_from_dict(['nn_params', 's_dim'])),
    'hybrid': configs.apply(load_from_dict('hybrid_loss_coeff')),
    'logistic_pars': configs.apply(load_from_dict('logistic_pars')),
    'nll': summaries.apply(load_from_dict('val_l01')),
})

In [None]:
for hlc in [0.00, 0.01, 0.1]:
    # plt.figure(figsize=[3, 3])
    for temb in [0, 128]:
        for lp in [True, False]:
            cond = np.logical_and(np.logical_and(
                df['t_emb']==temb, df['hybrid']==hlc), df['logistic_pars']==lp)
            label = f'temb={temb},lp={lp},hlc={hlc}'
            alpha = 0.1 if hlc == 0.1 else ( 0.5 if hlc==0.01 else 1)
            color = 'blue' if lp else 'orange'
            ls = '-' if temb==0 else '--'
            x, y, yerr = split_df(df, 's_dim', 'nll', cond)
            plt.errorbar(x, y, yerr, alpha=alpha,
                         label=label, color=color, ls=ls)
            plt.semilogx()
    plt.ylabel(f"NLL per bit (final epoch)")
    plt.xlabel("s_dim")
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.title(f"hlc={hlc}")
    plt.ylim(2.04, 2.17)

Overall: CIFAR N=4 doesn't look good, but that was with the wrong param. CIFAR N=8 is still not done; **I'll need to accelerate both of these and get the comparison to SEDD and masking ASAP**.
CIFAR N=4 with a big s_dim also failed; **will need to debug.**
*Then we can say we do better than masking and interpolate between SEDD and masking.*

The N=128 and structure sweeps worked and look as expected, **but I need to get the SEDD comparison working**. *Then we can say adding transition info helps.*

The sigma sweep looks very strange: we're doing badly. Is this because I changed the schedule? **Will look at D3PM and change the schedule if needed.** *Then we can say structure helps.*

We can say that when not using lp, nothing makes a big difference, but when using lp with a small hybrid loss, a bigger s_dim helps. Will continue running and run anything new at s_dim 128. **I'll run a smaller scan dropping the hybrid loss and testing s_dim 64 and 128.**