In [None]:
import wandb
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

# NOTE(review): wandb.init() starts a new run in the "d3pm" project, while the
# rest of this notebook only reads existing sweeps through wandb.Api(). It looks
# like the init call may be unnecessary here -- TODO confirm before removing.
wandb.init(project="d3pm")
# Public-API client used by the cells below to look up sweeps and their runs.
api = wandb.Api()

def split_df(df, x_axis, y_axis, cond=None, median=False):
    """Summarise ``y_axis`` for every distinct value of ``x_axis``.

    Parameters
    ----------
    df : pandas.DataFrame
        Table holding at least the ``x_axis`` and ``y_axis`` columns.
    x_axis : str
        Column to group by; one output point per distinct value.
    y_axis : str
        Column whose values are summarised inside each group.
    cond : boolean mask, optional
        If given, ``df[cond]`` is summarised instead of ``df``.
    median : bool, default False
        Use the NaN-ignoring median instead of the NaN-ignoring mean as the
        centre statistic.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(3, n_groups)`` that unpacks as ``x, y, yerr``: the
        group keys, the centre statistic and the NaN-ignoring standard
        deviation of each group.  When the keys are strings NumPy stores all
        three rows as strings, so callers have to cast ``y``/``yerr`` back to
        float.  With no groups at all the result has shape ``(3, 0)`` so that
        the three-way unpacking still works.
    """
    if cond is not None:
        df = df[cond]
    centre = np.nanmedian if median else np.nanmean

    rows = []
    for key, group in df.groupby(x_axis):
        # Convert once and reuse for both statistics.
        values = np.asarray(group[y_axis], dtype=np.float64)
        rows.append((key, centre(values), np.nanstd(values)))

    if not rows:
        # np.array([]).T is 1-D with length 0 and cannot be unpacked into
        # (x, y, yerr); hand back three empty rows instead.
        return np.empty((3, 0))
    return np.array(rows).T

def load_from_dict(key):
    """Build a getter that looks ``key`` up in a (possibly nested) mapping.

    Parameters
    ----------
    key : hashable or list
        A single key, or a list of keys that is followed one level at a time
        (``['a', 'b']`` reads ``x['a']['b']``).

    Returns
    -------
    callable
        ``f(x)`` returning the value found, or ``np.nan`` when it cannot be
        reached: a key is missing, or an intermediate value cannot be indexed
        (for example it is ``None``).
    """
    # Treat a single key as a path of length one so both cases share one loop.
    path = key if isinstance(key, list) else [key]

    def f(x):
        try:
            for k in path:
                x = x[k]
        except (KeyError, IndexError, TypeError):
            # Missing key, out-of-range index, or a non-subscriptable value
            # somewhere along the path: report "no value" uniformly.
            return np.nan
        return x
    return f

def geq(x, y):
    """Return ``x >= y``, treating a missing (``None``) ``x`` as ``False``."""
    return False if x is None else x >= y

In [None]:
def load_data(runs, average_n_ep=4):
    """Average each run's logged metrics over its final epochs.

    Parameters
    ----------
    runs : iterable
        Run objects exposing ``scan_history()``, ``config`` (which must
        contain ``'n_epoch'``) and ``summary``.
    average_n_ep : int, default 4
        History rows whose ``'epoch'`` is at least
        ``config['n_epoch'] - average_n_ep`` are averaged.

    Returns
    -------
    list of dict
        One ``{'config', 'summary', 'metrics'}`` entry per run that has at
        least one qualifying row carrying a ``'val_l01'`` value; runs without
        any such row are skipped.  ``metrics`` maps every key seen in the
        qualifying rows to the mean of its numeric values, or ``np.nan`` when
        the key has no numeric value.
    """
    import numbers  # local import: only this function needs it

    def is_number(value):
        # bool is a subclass of int but is a flag, not a measurement.
        return isinstance(value, numbers.Real) and not isinstance(value, bool)

    data = []
    for run in tqdm(runs):
        epoch_data = [
            row for row in run.scan_history()
            if (geq(row.get('epoch'), run.config['n_epoch'] - average_n_ep)
                and row.get('val_l01') is not None)
        ]
        if not epoch_data:
            continue

        average_dict = {}
        for key in set().union(*epoch_data):
            values = [row[key] for row in epoch_data
                      if key in row and is_number(row[key])]
            # Integer-valued samples count as well as floats, and a key with
            # no numeric samples maps straight to NaN rather than going
            # through np.mean([]) and its "mean of empty slice" warning.
            average_dict[key] = np.mean(values) if values else np.nan

        data.append({
            'config': run.config,
            'summary': run.summary,
            'metrics': average_dict,
        })
    return data

# CIFAR N=4 $\gamma$ sweeps

In [None]:
# Look up the sweeps that the N=4 gamma plot compares; each ``runs_*`` is later
# passed to load_data().
# Try 1.
sweep_1 = api.sweep("alanamin/d3pm/rqz6fqqg")
runs_1 = sweep_1.runs

# Try 2: use logparse, larger sdim, better schedule.
sweep_2 = api.sweep("alanamin/d3pm/6hzm5l3w")
runs_2 = sweep_2.runs

# Try 3: use film, u_inject.
sweep_3 = api.sweep("alanamin/d3pm/jltfpmv3")
runs_3 = sweep_3.runs

# Second try-3 sweep ("xtd"); original note: use film, u_inject.
# NOTE(review): the plot labels sweep_3 'x_t-d param' and sweep_3_xtd
# 'x_t param' -- confirm which parameterisation each sweep actually used.
sweep_3_xtd = api.sweep("alanamin/d3pm/7ki50x5s")
runs_3_xtd = sweep_3_xtd.runs

# Baseline sweep ids -- try 1: bas2ns4k, try 3: wjs4rx72.
sweep_baseline = api.sweep("alanamin/d3pm/wjs4rx72")  # the try-3 id listed above
runs_baseline = sweep_baseline.runs

In [None]:
# Average the final 4 epochs of every run, sweep by sweep.
data_1, data_2, data_3, data_3_xtd, data_baseline = (
    load_data(sweep_runs, 4)
    for sweep_runs in (runs_1, runs_2, runs_3, runs_3_xtd, runs_baseline)
)

# Try 1 additionally records the fix_x_t_bias flag of each run.
df = pd.DataFrame(data_1, columns=['config', 'summary', 'metrics'])
df_1 = pd.DataFrame({
    'seed': df['config'].apply(lambda cfg: cfg['model.seed']),
    'gamma': df['config'].apply(lambda cfg: cfg['model.gamma']),
    'x_t_param': df['config'].apply(lambda cfg: cfg['model.fix_x_t_bias']),
    'nll': df['metrics'].apply(load_from_dict('val_l01')),
})

# The remaining gamma sweeps share one schema: seed, gamma, nll.
gamma_frames = []
for sweep_data in (data_2, data_3, data_3_xtd):
    df = pd.DataFrame(sweep_data, columns=['config', 'summary', 'metrics'])
    gamma_frames.append(pd.DataFrame({
        'seed': df['config'].apply(lambda cfg: cfg['model.seed']),
        'gamma': df['config'].apply(lambda cfg: cfg['model.gamma']),
        'nll': df['metrics'].apply(load_from_dict('val_l01')),
    }))
df_2, df_3, df_3_xtd = gamma_frames

# Baselines are keyed by model name instead of gamma.
bl_raw = pd.DataFrame(data_baseline, columns=['config', 'summary', 'metrics'])
df_bl = pd.DataFrame({
    'seed': bl_raw['config'].apply(lambda cfg: cfg['model.seed']),
    'model': bl_raw['config'].apply(lambda cfg: cfg['model.model']),
    'nll': bl_raw['metrics'].apply(load_from_dict('val_l01')),
})

In [None]:
# NLL against gamma for the N=4 sweeps, with the two baselines as single points.
plt.figure(figsize=[3, 3])

# Try 1 holds both settings of x_t_param in one frame; split them by the flag.
cond = df_1['x_t_param']==False
x, y, yerr = split_df(df_1, 'gamma', 'nll', cond)
plt.errorbar(x, y, yerr, label='x_t-d param, try=1', color='blue')
cond = df_1['x_t_param']==True
x, y, yerr = split_df(df_1, 'gamma', 'nll', cond)
plt.errorbar(x, y, yerr, label='x_t param, try=1', color='blue', ls='--')

x, y, yerr = split_df(df_2, 'gamma', 'nll')
plt.errorbar(x, y, yerr, label='try=2', color='green')

# NOTE(review): df_3 is labelled 'x_t-d param' while df_3_xtd is labelled
# 'x_t param' -- given the "_xtd" suffix, confirm the labels are not swapped.
x, y, yerr = split_df(df_3, 'gamma', 'nll')
plt.errorbar(x, y, yerr, label='x_t-d param, try=3', color='red')
x, y, yerr = split_df(df_3_xtd, 'gamma', 'nll')
plt.errorbar(x, y, yerr, label='x_t param, try=3', color='red', ls='--')

# Grouping by model name makes split_df return string arrays, hence the
# astype(float).  SEDD is drawn at x=1 and MaskingDiffusion at x=1/4.
x, y, yerr = split_df(df_bl, 'model', 'nll')
plt.errorbar([1], y[x=='SEDD'].astype(float), yerr[x=='SEDD'].astype(float), label='SEDD', color='black')
plt.errorbar([1/4], y[x=='MaskingDiffusion'].astype(float), yerr[x=='MaskingDiffusion'].astype(float), label='MD', color='black')

plt.ylabel("NLL per bit")
# Raw string: "\g" is not a valid escape sequence in a normal string literal.
plt.xlabel(r"$\gamma$")
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))

It seems we are now asymptoting to SEDD -- great!
Adding the unet also seems to have helped further away from uniform.
I'm now rerunning SEDD and try 1 to finish them off.

- Worryingly, try 2 did worse than try 1! How did this happen? Did the schedule mess things up?

  
  It's because we optimized for large N
  
- Also, why is the $x_t$ param doing better? Probably because of adding one-hot! Just noticed we're not doing logistic parse! One more try!


# CIFAR N=8 $\gamma$ sweeps

In [None]:
# Sweeps compared in the N=8 gamma plot (labelled try=1, try=2, try=3 there).
sweep_1 = api.sweep("alanamin/d3pm/t44pmy1u")
runs_1 = sweep_1.runs
sweep_2 = api.sweep("alanamin/d3pm/l0giaov0")
runs_2 = sweep_2.runs
sweep_3 = api.sweep("alanamin/d3pm/b9kpcexa")
runs_3 = sweep_3.runs

# Baseline sweep ids -- try 1: 54s5x1dd, try 3: q3syz2ac.
sweep_baseline = api.sweep("alanamin/d3pm/q3syz2ac")  # the try-3 id listed above
runs_baseline = sweep_baseline.runs

In [None]:
# Average the final 4 epochs of every run, sweep by sweep.
data_1, data_2, data_3, data_baseline = (
    load_data(sweep_runs, 4)
    for sweep_runs in (runs_1, runs_2, runs_3, runs_baseline)
)

# All three gamma sweeps share one schema: seed, gamma, nll.
gamma_frames = []
for sweep_data in (data_1, data_2, data_3):
    df = pd.DataFrame(sweep_data, columns=['config', 'summary', 'metrics'])
    gamma_frames.append(pd.DataFrame({
        'seed': df['config'].apply(lambda cfg: cfg['model.seed']),
        'gamma': df['config'].apply(lambda cfg: cfg['model.gamma']),
        'nll': df['metrics'].apply(load_from_dict('val_l01')),
    }))
df_1, df_2, df_3 = gamma_frames

# Baselines are keyed by model name instead of gamma.
bl_raw = pd.DataFrame(data_baseline, columns=['config', 'summary', 'metrics'])
df_bl = pd.DataFrame({
    'seed': bl_raw['config'].apply(lambda cfg: cfg['model.seed']),
    'model': bl_raw['config'].apply(lambda cfg: cfg['model.model']),
    'nll': bl_raw['metrics'].apply(load_from_dict('val_l01')),
})

In [None]:
# NLL against gamma for the N=8 sweeps, with the two baselines as single points.
plt.figure(figsize=[3, 3])
x, y, yerr = split_df(df_1, 'gamma', 'nll')
plt.errorbar(x, y, yerr, label="try=1")
x, y, yerr = split_df(df_2, 'gamma', 'nll')
plt.errorbar(x, y, yerr, label="try=2", color='green')
x, y, yerr = split_df(df_3, 'gamma', 'nll')
plt.errorbar(x, y, yerr, label="try=3", color='red')

# Grouping by model name makes split_df return string arrays, hence the
# astype(float).  SEDD is drawn at x=1 and MaskingDiffusion at x=1/8.
x, y, yerr = split_df(df_bl, 'model', 'nll')
plt.errorbar([1], y[x=='SEDD'].astype(float), yerr[x=='SEDD'].astype(float), label='SEDD', color='black')
plt.errorbar([1/8], y[x=='MaskingDiffusion'].astype(float), yerr[x=='MaskingDiffusion'].astype(float), label='MD', color='black')

plt.ylabel("NLL per bit")
# Raw string: "\g" is not a valid escape sequence in a normal string literal.
plt.xlabel(r"$\gamma$")
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))

# CIFAR N=8 struct $\gamma$ sweeps

In [None]:
# Sweeps compared in the N=8 "struct" gamma plot (labelled try=1 and try=3 there).
sweep_1 = api.sweep("alanamin/d3pm/a8ba8kbd")
runs_1 = sweep_1.runs
sweep_3 = api.sweep("alanamin/d3pm/bo0tuxvi")
runs_3 = sweep_3.runs

# Baseline sweep providing the SEDD / MaskingDiffusion reference points.
sweep_baseline = api.sweep("alanamin/d3pm/ibfzma6j") # try 1
runs_baseline = sweep_baseline.runs

In [None]:
# Average the final 4 epochs of every run, sweep by sweep.
data_1, data_3, data_baseline = (
    load_data(sweep_runs, 4)
    for sweep_runs in (runs_1, runs_3, runs_baseline)
)

# Both gamma sweeps share one schema: seed, gamma, nll.
gamma_frames = []
for sweep_data in (data_1, data_3):
    df = pd.DataFrame(sweep_data, columns=['config', 'summary', 'metrics'])
    gamma_frames.append(pd.DataFrame({
        'seed': df['config'].apply(lambda cfg: cfg['model.seed']),
        'gamma': df['config'].apply(lambda cfg: cfg['model.gamma']),
        'nll': df['metrics'].apply(load_from_dict('val_l01')),
    }))
df_1, df_3 = gamma_frames

# Baselines are keyed by model name instead of gamma.
bl_raw = pd.DataFrame(data_baseline, columns=['config', 'summary', 'metrics'])
df_bl = pd.DataFrame({
    'seed': bl_raw['config'].apply(lambda cfg: cfg['model.seed']),
    'model': bl_raw['config'].apply(lambda cfg: cfg['model.model']),
    'nll': bl_raw['metrics'].apply(load_from_dict('val_l01')),
})

In [None]:
# NLL against gamma for the N=8 struct sweeps, with the baselines as single points.
plt.figure(figsize=[3, 3])
x, y, yerr = split_df(df_1, 'gamma', 'nll')
plt.errorbar(x, y, yerr, label="try=1")
x, y, yerr = split_df(df_3, 'gamma', 'nll')
plt.errorbar(x, y, yerr, label="try=3", color='red')

# Grouping by model name makes split_df return string arrays, hence the
# astype(float).  SEDD is drawn at x=1 and MaskingDiffusion at x=1/8.
x, y, yerr = split_df(df_bl, 'model', 'nll')
plt.errorbar([1], y[x=='SEDD'].astype(float), yerr[x=='SEDD'].astype(float), label='SEDD', color='black')
plt.errorbar([1/8], y[x=='MaskingDiffusion'].astype(float), yerr[x=='MaskingDiffusion'].astype(float), label='MD', color='black')

plt.ylabel("NLL per bit")
# Raw string: "\g" is not a valid escape sequence in a normal string literal.
plt.xlabel(r"$\gamma$")
plt.legend()

- Wait for try 3 to finish and rerun baselines!

# CIFAR N=128 $\gamma$ sweeps

In [None]:
# Sweeps compared in the N=128 gamma plot; each ``runs_*`` is later passed to
# load_data().
sweep_1 = api.sweep("alanamin/d3pm/29h93dqv")
runs_1 = sweep_1.runs

sweep_2 = api.sweep("alanamin/d3pm/sv64n8bg")
runs_2 = sweep_2.runs

sweep_3 = api.sweep("alanamin/d3pm/n2ef9nsf")
runs_3 = sweep_3.runs

# Baseline sweep ids -- try 2: 2sbnuj6j, try 3: (none recorded yet).
sweep_bl = api.sweep("alanamin/d3pm/2sbnuj6j")
runs_bl = sweep_bl.runs

In [None]:
# Average the final 4 epochs of every run, sweep by sweep.
data_1, data_2, data_3, data_baseline = (
    load_data(sweep_runs, 4)
    for sweep_runs in (runs_1, runs_2, runs_3, runs_bl)
)

# All three gamma sweeps share one schema: seed, gamma, nll.
gamma_frames = []
for sweep_data in (data_1, data_2, data_3):
    df = pd.DataFrame(sweep_data, columns=['config', 'summary', 'metrics'])
    gamma_frames.append(pd.DataFrame({
        'seed': df['config'].apply(lambda cfg: cfg['model.seed']),
        'gamma': df['config'].apply(lambda cfg: cfg['model.gamma']),
        'nll': df['metrics'].apply(load_from_dict('val_l01')),
    }))
df_1, df_2, df_3 = gamma_frames

# Baselines are keyed by model name instead of gamma.
bl_raw = pd.DataFrame(data_baseline, columns=['config', 'summary', 'metrics'])
df_bl = pd.DataFrame({
    'seed': bl_raw['config'].apply(lambda cfg: cfg['model.seed']),
    'model': bl_raw['config'].apply(lambda cfg: cfg['model.model']),
    'nll': bl_raw['metrics'].apply(load_from_dict('val_l01')),
})

In [None]:
# NLL against gamma for the N=128 sweeps, with the baselines as single points.
plt.figure(figsize=[3, 3])
x, y, yerr = split_df(df_1, 'gamma', 'nll')
# df_1 is the first sweep; labelling it "try=2" duplicated df_2's legend entry.
plt.errorbar(x, y, yerr, label="try=1")
x, y, yerr = split_df(df_2, 'gamma', 'nll')
plt.errorbar(x, y, yerr, label="try=2", color='green')
x, y, yerr = split_df(df_3, 'gamma', 'nll')
plt.errorbar(x, y, yerr, label="try=3", color='red')

# Grouping by model name makes split_df return string arrays, hence the
# astype(float).  SEDD is drawn at x=1 and MaskingDiffusion at x=1/128.
x, y, yerr = split_df(df_bl, 'model', 'nll')
plt.errorbar([1], y[x=='SEDD'].astype(float), yerr[x=='SEDD'].astype(float), label='SEDD', color='black')
plt.errorbar([1/128], y[x=='MaskingDiffusion'].astype(float), yerr[x=='MaskingDiffusion'].astype(float), label='MD', color='black')

plt.ylabel("NLL per bit")
# Raw string: "\g" is not a valid escape sequence in a normal string literal.
plt.xlabel(r"$\gamma$")
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))

# CIFAR $\sigma$ sweep

In [None]:
# Sweeps compared in the sigma plots (drawn there in blue, green, red and,
# for the baseline sweep, black).
sweep_1 = api.sweep("alanamin/d3pm/mguwxwxd")
runs_1 = sweep_1.runs

sweep_2 = api.sweep("alanamin/d3pm/n8fzi4h1")
runs_2 = sweep_2.runs

sweep_3 = api.sweep("alanamin/d3pm/q2q5qwv5")
runs_3 = sweep_3.runs

# Baseline sweep; it is loaded with the same sigma/normalized schema as the others.
sweep_bl = api.sweep("alanamin/d3pm/h0tee18u")
runs_bl = sweep_bl.runs



In [None]:
# Average the final 4 epochs of every run, sweep by sweep.
data_1, data_2, data_3, data_bl = (
    load_data(sweep_runs, 4)
    for sweep_runs in (runs_1, runs_2, runs_3, runs_bl)
)

# Every sweep (baseline included) uses the same schema:
# seed, sigma, normalized, nll and kl1 = val_l01 + val_l1.
sigma_frames = []
for sweep_data in (data_1, data_2, data_3, data_bl):
    df = pd.DataFrame(sweep_data, columns=['config', 'summary', 'metrics'])
    nll = df['metrics'].apply(load_from_dict('val_l01'))
    l1 = df['metrics'].apply(load_from_dict('val_l1'))
    sigma_frames.append(pd.DataFrame({
        'seed': df['config'].apply(lambda cfg: cfg['model.seed']),
        'sigma': df['config'].apply(load_from_dict(['forward_kwargs', 'bandwidth'])),
        'normalized': df['config'].apply(load_from_dict(['forward_kwargs', 'normalized'])),
        'nll': nll,
        'kl1': (nll + l1),
    }))
df_1, df_2, df_3, df_bl = sigma_frames

In [None]:
# One (frame, colour) pair per sweep.  Every sweep is drawn twice per figure:
# solid for normalized == True and dashed for normalized == False.
# NOTE(review): the legend repeats 'Normalized' / 'Not normalized' for every
# sweep, so only the colour tells the sweeps apart.
sigma_sweeps = [(df_1, 'blue'), (df_2, 'green'), (df_3, 'red'), (df_bl, 'black')]

# First figure: 'nll'; second figure: 'kl1'.
# NOTE(review): the 'kl1' figure also uses the "NLL per bit" y-label --
# confirm that this is the intended description of val_l01 + val_l1.
for y_axis in ('nll', 'kl1'):
    plt.figure(figsize=[3, 3])
    for frame, color in sigma_sweeps:
        for normalized in (True, False):
            cond = frame['normalized']==normalized
            x, y, yerr = split_df(frame, 'sigma', y_axis, cond)
            if normalized:
                plt.errorbar(x, y, yerr, label='Normalized', color=color)
            else:
                plt.errorbar(x, y, yerr, label='Not normalized', ls='--', color=color)
    plt.ylabel("NLL per bit")
    # Raw string: "\s" is not a valid escape sequence in a normal string literal.
    plt.xlabel(r"$\sigma$")
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))

Looks like schedule fixed it -- we see benefit!
- Wait for try 3 to finish
- run baselines!
- try smaller sigma?

# CIFAR parameter runs try 1

In [None]:
# Sweep whose runs feed the parameter plot of this section.
sweep = api.sweep("alanamin/d3pm/83mp27zv")
runs = sweep.runs

In [None]:
# Average the final 4 epochs of every run in the sweep.
data = load_data(runs, 4)

# One row per run: seed, the swept hyper-parameters and the averaged NLL.
runs_raw = pd.DataFrame(data, columns=['config', 'summary', 'metrics'])
run_configs = runs_raw['config']
df = pd.DataFrame({
    'seed': run_configs.apply(lambda cfg: cfg['model.seed']),
    't_emb': run_configs.apply(load_from_dict(['nn_params', 'time_embed_dim'])),
    's_dim': run_configs.apply(load_from_dict(['nn_params', 's_dim'])),
    'hybrid': run_configs.apply(load_from_dict('hybrid_loss_coeff')),
    'logistic_pars': run_configs.apply(load_from_dict('logistic_pars')),
    'nll': runs_raw['metrics'].apply(load_from_dict('val_l01')),
})

In [None]:
# NLL against s_dim for every (hybrid coeff, t_emb, logistic_pars) combination.
# All curves share the current figure: alpha encodes the hybrid coefficient,
# colour encodes logistic_pars and the line style encodes t_emb.
for hlc in [0.00, 0.01, 0.1]:
    alpha = 0.1 if hlc == 0.1 else ( 0.5 if hlc==0.01 else 1)
    for temb in [0, 128]:
        ls = '-' if temb==0 else '--'
        for lp in [True, False]:
            color = 'blue' if lp else 'orange'
            cond = (df['t_emb']==temb) & (df['hybrid']==hlc) & (df['logistic_pars']==lp)
            label = f'temb={temb},lp={lp},hlc={hlc}'
            x, y, yerr = split_df(df, 's_dim', 'nll', cond)
            plt.errorbar(x, y, yerr, alpha=alpha,
                         label=label, color=color, ls=ls)

# Axis scale and decoration do not depend on the loop variables, so they are
# applied once after all curves have been drawn.
plt.semilogx()
plt.ylabel("NLL per bit")
plt.xlabel("s_dim")
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.ylim(2.04, 2.17)

Overall: CIFAR n=4 doesn't look good but that's with the wrong param. CIFAR N=8 is still not done, **I'll need to accelerate both of these and get the comparison to SEDD and masking asap.** ✔️
CIFAR n=4 with big sdim also failed, **will need to debug.** ✔️
*then we can say we do better than masking and interpolate between SEDD and masking.*

N=128 and structure sweeps worked and look as expected, **but I need to get SEDD comparison working**. ✔️ *Then we can say adding transition info helps*

The sigma sweep looks very strange: we're doing badly. Is this because I changed the schedule? **Will look at D3PM and change schedule if needed.** *then we can say structure helps*

It seems that when not using lp, nothing makes a big difference, but when using lp with a small hybrid loss, a bigger s_dim helps. Will continue running and run anything new at sdim 128. **I'll run a smaller scan dropping the hybrid, and testing 64 and 128 sdim.** ✔️

# CIFAR parameter runs right before try 2

In [None]:
# Sweep whose runs feed the parameter plot of this section.
sweep = api.sweep("alanamin/d3pm/1emyvi0q")
runs = sweep.runs

In [None]:
# Average the final 3 epochs of every run in the sweep.
data = load_data(runs, 3)

# One row per run: seed, the swept hyper-parameters and the averaged NLL.
runs_raw = pd.DataFrame(data, columns=['config', 'summary', 'metrics'])
run_configs = runs_raw['config']
df = pd.DataFrame({
    'seed': run_configs.apply(lambda cfg: cfg['model.seed']),
    't_emb': run_configs.apply(load_from_dict(['nn_params', 'time_embed_dim'])),
    's_dim': run_configs.apply(load_from_dict(['nn_params', 's_dim'])),
    'logistic_pars': run_configs.apply(load_from_dict('logistic_pars')),
    'nll': runs_raw['metrics'].apply(load_from_dict('val_l01')),
})

In [None]:
# NLL against s_dim for every (t_emb, logistic_pars) combination: colour
# encodes logistic_pars and the line style encodes t_emb.
for temb in [0, 128]:
    ls = '-' if temb==0 else '--'
    for lp in [True, False]:
        cond = (df['t_emb']==temb) & (df['logistic_pars']==lp)
        label = f'temb={temb},lp={lp}'
        alpha = 1
        color = 'blue' if lp else 'orange'
        x, y, yerr = split_df(df, 's_dim', 'nll', cond)
        plt.errorbar(x, y, yerr, alpha=alpha,
                     label=label, color=color, ls=ls)

# The log x-scale does not depend on the loop variables; apply it once.
plt.semilogx()
plt.ylabel("NLL per bit")
plt.xlabel("s_dim")
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
# plt.ylim(2.04, 2.17)

It seems the best setup is a low t_dim with a large s_dim, using lp.

# CIFAR check s_dim integration method

In [None]:
# Sweep read for its nn_params 'film' setting in the next cell.
sweep_film = api.sweep("alanamin/d3pm/osr9fbrb")
runs_film = sweep_film.runs

# Sweep read for its nn_params 'semb_style' setting in the next cell.
sweep_s_method = api.sweep("alanamin/d3pm/acm31eq6")
runs_s_method = sweep_s_method.runs


# Baseline sweep; the bar plot selects MaskingDiffusion and SEDD from it.
sweep_baseline = api.sweep("alanamin/d3pm/k17s5qkt")
runs_baseline = sweep_baseline.runs

In [None]:
# Average the final 3 epochs of every run, sweep by sweep.
data_film, data_s_method, data_baseline = (
    load_data(sweep_runs, 3)
    for sweep_runs in (runs_film, runs_s_method, runs_baseline)
)

# Getters shared by the frames below.
seed_of = lambda cfg: cfg['model.seed']
t_emb_of = load_from_dict(['nn_params', 'time_embed_dim'])
film_of = load_from_dict(['nn_params', 'film'])
nll_of = load_from_dict('val_l01')

# FiLM sweep: seed, t_emb, film, nll.
df = pd.DataFrame(data_film, columns=['config', 'summary', 'metrics'])
df_film = pd.DataFrame({
    'seed': df['config'].apply(seed_of),
    't_emb': df['config'].apply(t_emb_of),
    'film': df['config'].apply(film_of),
    'nll': df['metrics'].apply(nll_of),
})

# Embedding-style sweep: seed, t_emb, semb_style, nll.
df = pd.DataFrame(data_s_method, columns=['config', 'summary', 'metrics'])
df_s_method = pd.DataFrame({
    'seed': df['config'].apply(seed_of),
    't_emb': df['config'].apply(t_emb_of),
    'semb_style': df['config'].apply(load_from_dict(['nn_params', 'semb_style'])),
    'nll': df['metrics'].apply(nll_of),
})

# Baselines: seed, model, film, nll.
bl_raw = pd.DataFrame(data_baseline, columns=['config', 'summary', 'metrics'])
df_bl = pd.DataFrame({
    'seed': bl_raw['config'].apply(seed_of),
    'model': bl_raw['config'].apply(lambda cfg: cfg['model.model']),
    'film': bl_raw['config'].apply(film_of),
    'nll': bl_raw['metrics'].apply(nll_of),
})

In [None]:
# Grouped bar chart of NLL: one pair of bars per configuration, centred on the
# integer position c; the second bar of each pair is hatched.
w = 0.25  # width of a single bar

c = 0  # x position of the current pair of bars
for style in ['learn_embed', 'learn_nn', 'u_inject']:
    cond = df_s_method['semb_style']==style
    x, y, yerr = split_df(df_s_method, 't_emb', 'nll', cond)
    # Assumes exactly two t_emb values per style (one bar each) -- TODO confirm.
    bars = plt.bar([c-w/2, c+w/2], y, yerr=yerr, width=w, label=style)
    bars[1].set_hatch('//')
    c += 1
for film in [False, True]:
    cond = df_film['film'] == film
    x, y, yerr = split_df(df_film, 't_emb', 'nll', cond)
    bars = plt.bar([c-w/2, c+w/2], y, yerr=yerr, width=w, label=f'u_inject_film={film}')
    bars[1].set_hatch('//')
    c += 1
for model in ['MaskingDiffusion', 'SEDD']:
    cond = df_bl['model'] == model
    x, y, yerr = split_df(df_bl, 'film', 'nll', cond)
    print(y, x)
    # The heights are drawn in reversed group order, so the error bars have to
    # be reversed too, otherwise each one is attached to the other bar.
    bars = plt.bar([c-w/2, c+w/2], y[::-1], yerr=yerr[::-1], width=w, label=f'{model}_(bars=no film)', color='grey')
    bars[1].set_hatch('-')
    c += 1
plt.ylim(2, 2.2)
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.ylabel("NLL per bit")

- Great! Now let's run the full gamma sweep!