In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import functools
import itertools
import pprint

import orbax.checkpoint
import numpy as np
import jax
import jax.numpy as jnp
import torch.utils.data.dataloader
import tensorflow as tf
import sqlalchemy as sa
import seaborn as sns
sns.set_theme(style='whitegrid', font_scale=1.3, palette=sns.color_palette('husl'),)
import pandas as pd
import matplotlib.pyplot as plt

from userdiffusion import samplers, unet
from userfm import cs, datasets, diffusion, sde_diffusion, flow_matching, utils, main as main_module, plots

2025-01-29 00:19:27.893851: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1738109967.916173   22323 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1738109967.922031   22323 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Hide every GPU from TensorFlow by passing an empty visible-device list.
# Empirically (root cause unknown), this prevents a segmentation fault in
# nn.Dense when calling model.init.
tf.config.experimental.set_visible_devices([], 'GPU')

In [4]:
# Connect to the project's config store and open an ORM session on it.
engine = cs.get_engine()
# presumably creates any missing tables for the config schema -- TODO confirm in `cs`
cs.create_all(engine)
session = cs.orm.Session(engine)
# Start a transaction explicitly; as the last expression of the cell its
# return value is displayed as the cell output.
session.begin()

<sqlalchemy.orm.session.SessionTransaction at 0x7fb4bb7b1f00>

In [5]:
# Runs to compare in this notebook.
#   key:   (config alt_id, display label); the alt_id is used to look the
#          config up in the database, the label names the run in the plots.
#          The same alt_id may appear several times with different labels.
#   value: per-run options; the 'sample' sub-dict is forwarded as keyword
#          arguments to the flow-matching sampler in the sampling cell.
config_alt_ids = {
    # Lorenz
    ('0y35hp7d', 'DM'): {},
    # ('fba4g7bp', 'FMOT'): {'sample': {'use_score': False}},
    # ('1g2n8baa', 'FMOT+Reg'): {'sample': {'use_score': False}},
    # ('eug367ja', 'Flow Matching (VE)'): {'sample': {'use_score': False}},
    ('3bjjfgwa', 'FM (no score)'): {'sample': {'use_score': False}},
    ('c0ijllm1', 'FM+Reg (no score)'): {'sample': {'use_score': False}},
    ('3bjjfgwa', 'FM'): {'sample': {'use_score': True}},
    ('c0ijllm1', 'FM+Reg'): {'sample': {'use_score': True}},
    # FitzHughNagumo
    # ('wyrwide1', 'Diffusion (VE SDE)'): {},
    # ('gcior3bc', 'Flow Matching (OT)'): {'sample': {'use_score': False}},
    # ('tybh75p1', 'Flow Matching (VE)'): {'sample': {'use_score': False}},
    # ('tybh75p1', 'Flow Matching (VE Score)'): {'sample': {'use_score': True}},
}

In [6]:
# Load every configuration referenced in `config_alt_ids`, indexed by alt_id.
# `dict.fromkeys` de-duplicates the alt_ids while keeping their listed order.
requested_alt_ids = list(dict.fromkeys(alt_id for alt_id, _ in config_alt_ids))
if not requested_alt_ids:
    raise ValueError('config_alt_ids is empty: nothing to load')

cfg_rows = session.execute(sa.select(cs.Config).where(cs.Config.alt_id.in_(requested_alt_ids)))
cfgs = {c.alt_id: c for (c,) in cfg_rows}

# Fail here with a clear message instead of a bare KeyError in a later cell.
missing_alt_ids = [alt_id for alt_id in requested_alt_ids if alt_id not in cfgs]
if missing_alt_ids:
    raise KeyError(f'No stored config found for alt_id(s): {missing_alt_ids}')

# The reference config supplies the shared RNG seed and dataset settings.
# Take the first *listed* run rather than the first row the database happens
# to return, so the choice does not depend on unspecified query ordering.
reference_cfg = cfgs[requested_alt_ids[0]]

In [7]:
# Root PRNG key for the notebook, seeded from the reference config.
# Later cells split sub-keys off it and rebind `key`.
key = jax.random.key(reference_cfg.rng_seed)

In [8]:
# Build the dataset described by the reference config and split it.
key, key_dataset = jax.random.split(key)
ds = datasets.get_dataset(reference_cfg.dataset, key=key_dataset)
splits = datasets.split_dataset(reference_cfg.dataset, ds)


def _unwrap_single_batch(items):
    """Return the only element of a length-1 list handed over by the DataLoader."""
    return items[0]


# Each split is batched up front with tf.data; the DataLoader then just hands
# out those pre-built batches one at a time (batch_size=1 + unwrapping).
dataloaders = {
    split_name: torch.utils.data.dataloader.DataLoader(
        list(
            tf.data.Dataset.from_tensor_slices(split_data)
            .batch(reference_cfg.dataset.batch_size)
            .as_numpy_iterator()
        ),
        batch_size=1,
        collate_fn=_unwrap_single_batch,
    )
    for split_name, split_data in splits.items()
}
data_std = splits['train'].std()

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3300/3300 [04:57<00:00, 11.08it/s]


In [9]:
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
x_sample = next(iter(dataloaders['train']))
ckpt_name = 'epoch_1999'

# Which JaxLightning wrapper handles which kind of model config.
lightning_by_model_type = (
    (cs.ModelDiffusion, diffusion.JaxLightning),
    (cs.ModelFlowMatching, flow_matching.JaxLightning),
)

# For every (alt_id, label) run: rebuild the network, wrap it in the matching
# JaxLightning object and restore both the raw and the EMA parameters.
cfg_info = {}
for run_key in config_alt_ids:
    alt_id = run_key[0]
    cfg = cfgs[alt_id]
    # All compared runs must share the seed and dataset of the reference run.
    assert cfg.rng_seed == reference_cfg.rng_seed
    assert cfg.dataset == reference_cfg.dataset

    architecture = cfg.model.architecture
    cfg_unet = unet.unet_64_config(
        splits['train'].shape[-1],
        base_channels=architecture.base_channel_count,
        attention=architecture.attention,
    )
    model = unet.UNet(cfg_unet)

    key, key_jaxlightning = jax.random.split(key)
    lightning_cls = next(
        (cls for model_type, cls in lightning_by_model_type if isinstance(cfg.model, model_type)),
        None,
    )
    if lightning_cls is None:
        raise ValueError(f'Unknown model: {cfg.model}')
    jax_lightning = lightning_cls(cfg, key_jaxlightning, dataloaders, data_std, None, model)

    checkpoint_dir = cfg.run_dir
    jax_lightning.params = orbax_checkpointer.restore(checkpoint_dir/ckpt_name)
    jax_lightning.params_ema = orbax_checkpointer.restore(checkpoint_dir/f'{ckpt_name}_ema')

    cfg_info[run_key] = {'cfg': cfg, 'jax_lightning': jax_lightning}



In [10]:
# Dataset-specific event constraint. Later cells treat a trajectory as an
# "event" when `constraint(x) > 0`.
if isinstance(reference_cfg.dataset, cs.DatasetLorenz):
    def constraint(x):
        """Return 0.6 minus the mean magnitude of the non-constant Fourier modes.

        The real FFT is taken along the last axis of ``x[..., 0]`` (index 0 of
        the trailing axis of ``x``); the first (zero-frequency) coefficient is
        excluded from the mean.
        """
        fourier_magnitudes = jnp.abs(jnp.fft.rfft(x[..., 0], axis=-1))
        return -(fourier_magnitudes[..., 1:].mean(-1) - .6)
elif isinstance(reference_cfg.dataset, cs.DatasetFitzHughNagumo):
    def constraint(x):
        """Return the maximum (over the second-to-last axis of ``x``) of the
        mean of the first two entries of the last axis, minus 2.5."""
        return jnp.max(x[..., :2].mean(-1), -1) - 2.5
else:
    # Fixed: the original message referenced the misspelled name
    # `referenc_cfg`, so this branch raised NameError instead of ValueError.
    raise ValueError(f'Unknown dataset: {reference_cfg.dataset}')

In [11]:
# Trajectories used for conditioning and as the sampling shape in the cells below.
# NOTE(review): this is the entire training split; in the recorded run the
# sampling cell that follows failed with RESOURCE_EXHAUSTED (GPU out of
# memory) -- a smaller subset may be needed; confirm.
evaluation_trajectories = splits['train']

In [12]:
def _sample_events(jax_lightning, cond, key, x_shape, nsteps, keep_path):
    """Sample trajectories with the model score adjusted by `constraint`.

    Wraps the model's score so it accepts a scalar time, passes it together
    with the module-level `constraint` to `samplers.event_scores`, and runs
    `samplers.sde_sample` with the resulting score function.

    Args:
        jax_lightning: object exposing `.score`, `.params_ema` and `.diffusion`.
        cond: conditioning information forwarded to `jax_lightning.score`.
        key: PRNG key handed to `samplers.sde_sample`.
        x_shape: shape of the array to sample; `x_shape[0]` is used as the
            number of trajectories when broadcasting a scalar time.
        nsteps: number of sampling steps.
        keep_path: forwarded as `traj=` to `samplers.sde_sample`.

    Returns:
        Whatever `samplers.sde_sample` returns.
    """
    def score(x, t):
        # A scalar (or shapeless) time is expanded to one value per trajectory.
        if not hasattr(t, 'shape') or not t.shape:
            t = jnp.ones((x_shape[0], 1, 1)) * t
        return jax_lightning.score(x, t, cond, jax_lightning.params_ema)

    event_scores = samplers.event_scores(
        jax_lightning.diffusion, score, constraint, reg=1e-3
    )
    return samplers.sde_sample(
        jax_lightning.diffusion, event_scores, key, x_shape=x_shape, nsteps=nsteps, traj=keep_path
    )


cond = main_module.condition_on_initial_time_steps(evaluation_trajectories, reference_cfg.dataset.time_step_count_conditioning)
trajectory_count = reference_cfg.dataset.batch_size
keep_path = isinstance(reference_cfg.dataset, cs.DatasetGaussianMixture)
# use same sampling key for all models
key, key_samples = jax.random.split(key)
for k, info in cfg_info.items():
    cfg = info['cfg']
    lightning = info['jax_lightning']
    # Runs without a 'sample' entry simply pass no extra sampler arguments.
    sample_kwargs = config_alt_ids[k].get('sample', {})

    if isinstance(cfg.model, cs.ModelFlowMatching):
        info['samples'] = lightning.sample(key_samples, 1., cond, x_shape=evaluation_trajectories.shape, keep_path=keep_path, **sample_kwargs)
        # Event sampling needs a score, so it is only run for
        # variance-exploding SDE conditional flows sampled with `use_score`.
        supports_event_sampling = (
            isinstance(cfg.model.conditional_flow, cs.ConditionalSDE)
            and isinstance(cfg.model.conditional_flow.sde_diffusion, cs.SDEVarianceExploding)
            and sample_kwargs.get('use_score', False)
        )
    elif isinstance(cfg.model, cs.ModelDiffusion):
        info['samples'] = lightning.sample(key_samples, 1., cond, x_shape=evaluation_trajectories.shape, keep_path=keep_path)
        supports_event_sampling = True
    else:
        raise ValueError(f"Unknown model: {cfg.model}")

    if supports_event_sampling:
        info['event_samples'] = _sample_events(
            lightning,
            cond,
            key_samples,
            x_shape=evaluation_trajectories.shape,
            nsteps=cfg.model.time_step_count_sampling,
            keep_path=keep_path,
        )

2025-01-29 00:25:01.649118: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3020] Can't reduce memory use below -3.52GiB (-3785230798 bytes) by rematerialization; only reduced to 48.85GiB (52455755224 bytes), down from 49.05GiB (52664581612 bytes) originally
2025-01-29 00:25:15.396441: W external/xla/xla/tsl/framework/bfc_allocator.cc:497] Allocator (GPU_0_bfc) ran out of memory trying to allocate 21.44GiB (rounded to 23016960000)requested by op 
2025-01-29 00:25:15.403324: W external/xla/xla/tsl/framework/bfc_allocator.cc:508] **__________________________________________________________________________________________________
E0129 00:25:15.403666   22323 pjrt_stream_executor_client.cc:3086] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 23016960000 bytes.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 23016960000 bytes.

In [None]:
# Plot up to `trajectory_count` sampled event trajectories (constraint > 0)
# per model, one facet per model.
trajectory_count = 10
frames = []
for (_, source), info in cfg_info.items():
    event_trajectories = info['samples'][constraint(info['samples']) > 0]
    # The labels are attached per frame instead of through `pd.concat(keys=...)`:
    # a model may yield fewer than `trajectory_count` events, and a fixed-length
    # key list would then be misaligned with (or rejected for) the frames.
    for i, trajectory in enumerate(event_trajectories[:trajectory_count]):
        values = np.asarray(trajectory[:, 0])
        frames.append(pd.DataFrame({
            'Trajectory': str(i),
            'Time Step': np.arange(len(values)),
            'Source': source,
            'Values': values,
        }))
df = pd.concat(frames, ignore_index=True)
sns.relplot(
    kind='line',
    data=df,
    x='Time Step', y='Values',
    hue='Trajectory',
    col='Source',
    col_order=[c[1] for c in cfg_info],
)
print('Model-sampled events')

In [None]:
# Plot up to `trajectory_count` non-event and up to `trajectory_count` event
# trajectories from the data, one facet per group.
trajectory_count = 5
constraint_values = constraint(evaluation_trajectories)
frames = []
for is_event, mask in ((False, constraint_values <= 0), (True, constraint_values > 0)):
    # The labels are attached per frame instead of through `pd.concat(keys=...)`:
    # either group may hold fewer than `trajectory_count` trajectories, and a
    # fixed-length key list would then not match the number of frames.
    for trajectory in evaluation_trajectories[mask][:trajectory_count]:
        values = np.asarray(trajectory[:, 0])
        frames.append(pd.DataFrame({
            # Running index across both groups, so every line gets its own hue.
            'Trajectory': str(len(frames)),
            'Time Step': np.arange(len(values)),
            'IsEvent': is_event,
            'Values': values,
        }))
df = pd.concat(frames, ignore_index=True)
sns.relplot(
    kind='line',
    data=df,
    x='Time Step', y='Values',
    hue='Trajectory',
    col='IsEvent',
)
print('Data')

In [None]:
# Constraint values of the unconditional samples, one column per model,
# melted to long form for seaborn.
source_order = [run_key[1] for run_key in cfg_info]
constraint_per_source = [
    pd.Series(constraint(info['samples']), name=source)
    for (_, source), info in cfg_info.items()
]
df = pd.concat(constraint_per_source, axis=1).melt(var_name='Source', value_name='Constraint Value')
df_data = pd.DataFrame({'Source': 'Data', 'Constraint Value': constraint(splits['train'])})

# 128 equal-width bins spanning model and data values; only the bin edges of
# this dummy histogram are kept.
value_range = pd.concat((df, df_data))['Constraint Value'].agg(['min', 'max'])
_, bins = np.histogram(np.zeros(2), bins=128, range=value_range)

if isinstance(reference_cfg.dataset, cs.DatasetFitzHughNagumo):
    yscale = 'log'
else:
    yscale = 'linear'

plot = sns.displot(
    data=df,
    stat='density',
    x='Constraint Value',
    col='Source',
    col_order=source_order,
    hue='Source',
    hue_order=source_order,
    common_norm=False,
    bins=bins,
    facet_kws=dict(
        # sharey=False,
    ),
)
plot.set(yscale=yscale)
plot.set_titles('')

# Draw the data histogram behind the model histogram in every facet.
plot.map(
    sns.histplot,
    data=df_data,
    bins=bins,
    stat='density',
    color='tab:grey',
    x='Constraint Value',
    zorder=-1,
)
plot.set_xlabels('')
plot.set_ylabels('')

for (row_idx, col_idx, _hue_idx), _facet_df in plot.facet_data():
    ax = plot.axes[row_idx][col_idx]
    # Dotted red line at the event threshold (constraint == 0).
    ax.axvline(x=0, c='r', ls=':')
    ax.xaxis.set_tick_params(labelbottom=True)
    ax.yaxis.set_tick_params(labelleft=True)
plot.tight_layout()
sns.move_legend(
    plot,
    loc='upper center',
    ncol=len(cfg_info) + 1,
    title='',
    bbox_to_anchor=(.455, 1.06),
    frameon=True,
    fancybox=True,
)

# Per-model KL(data || model) between the binned probability masses.
data_hist = np.histogram(df_data['Constraint Value'], bins=bins)[0] / len(df_data)
for (_row_idx, col_idx, _hue_idx), facet_df in plot.facet_data():
    print(plot.col_names[col_idx])
    model_hist = np.histogram(facet_df['Constraint Value'], bins=bins)[0] / len(facet_df)
    kl_terms = np.where(data_hist == 0., 0., data_hist * np.log(data_hist / (model_hist + 1e-12)))
    print(kl_terms.sum())

In [None]:
# plots.save_all_subfigures(plot, f'event_histogram.unconditional.{reference_cfg.dataset.__class__.__name__}')

In [None]:
# Inspect the shape of the training split.
splits['train'].shape

In [None]:
# Runs for which event-conditioned samples were produced.
event_sources = [c[1] for c, info in cfg_info.items() if 'event_samples' in info]
df = pd.concat([
    pd.Series(constraint(info['event_samples']), name=source)
    for (_, source), info in cfg_info.items()
    if 'event_samples' in info
], axis=1).melt(var_name='Source', value_name='Constraint Value')
# reuse bins from previous plot
data_color = 'tab:gray'
plot = (
    sns.displot(
        data=df,
        stat='density',
        x='Constraint Value',
        row='Source',
        row_order=event_sources,
        hue='Source',
        hue_order=[*event_sources, 'Data'],
        # One colour per event source plus the data colour, so the palette
        # always matches `hue_order` (the original hard-coded three colours).
        palette=[*sns.color_palette(n_colors=len(event_sources)), data_color],
        common_norm=False,
        bins=bins,
        facet_kws=dict(
            # sharex=True
        ),
        height=1.8,
        aspect=2.2,
    )
    .set(yscale='log' if isinstance(reference_cfg.dataset, cs.DatasetFitzHughNagumo) else 'linear')
    .set_titles('')
)
# Reference distribution: constraint values of the training trajectories that
# are events (constraint > 0).
df_data = pd.DataFrame({'Source': 'Data', 'Constraint Value': constraint(splits['train'][constraint(splits['train']) > 0])})
plot.map(
    sns.histplot,
    data=df_data,
    bins=bins,
    stat='density',
    color=data_color,
    x='Constraint Value',
    zorder=-1,
).set_xlabels('').set_ylabels('')
for (row, col, hue), data in plot.facet_data():
    ax = plot.axes[row][col]
    # Dotted red line at the event threshold (constraint == 0).
    ax.axvline(x=0, c='r', ls=':')
    ax.xaxis.set_tick_params(labelbottom=True)
    ax.yaxis.set_tick_params(labelleft=True)
    # Only the bottom facet keeps its x axis.
    if row != len(plot.row_names) - 1:
        ax.xaxis.set_visible(False)
plot.tight_layout()
sns.move_legend(
    plot,
    loc='upper center',
    ncol=len(event_sources) + 1,
    title='',
    bbox_to_anchor=(.455, 1.06),
    frameon=True,
    fancybox=True,
)

# KL(data || model) between the binned probability masses, computed the same
# way as for the unconditional histogram. The original summed density values
# without the bin width (scaling the result by 1 / bin width) and divided by
# empty model bins, which produced `inf`.
data_hist = np.histogram(df_data['Constraint Value'], bins=bins)[0] / len(df_data)
for (row, col, hue), data in plot.facet_data():
    print(plot.row_names[row])
    model_hist = np.histogram(data['Constraint Value'], bins=bins)[0] / len(data)
    # Empty data bins are masked by `np.where`; silence the log(0) warnings
    # that their (discarded) terms would otherwise emit.
    with np.errstate(divide='ignore', invalid='ignore'):
        kl_divergence = np.where(data_hist == 0., 0., data_hist * np.log(data_hist / (model_hist + 1e-12)))
    print(kl_divergence.sum())

In [None]:
# Save the conditional event histogram; the name passed along ends with the
# dataset class name.
plots.save_all_subfigures(plot, f'event_histogram.conditional.{reference_cfg.dataset.__class__.__name__}')

In [None]:
# Direct Monte-Carlo estimate of the event probability: the fraction of
# trajectories with constraint > 0, with its standard error, for the data and
# for every model's unconditional samples.
print('Event Likelihood: Direct Monte-Carlo')
labelled_samples = itertools.chain(
    [('Data', splits['train'])],
    ((label, info['samples']) for (_, label), info in cfg_info.items()),
)
for source, samples in labelled_samples:
    is_event = constraint(samples) > 0
    standard_error = is_event.std() / jnp.sqrt(len(is_event))
    print(f'{source}: P(E) = {is_event.mean():.3f}+-{standard_error:.3f}')

In [None]:
# Compare the models on the first 10 evaluation trajectories via
# `compute_nll`; the same `key_nll` is passed for every model.
key, key_nll = jax.random.split(key)
for (_, source), info in cfg_info.items():
    # presumably returns the noise-space sample and the NLL without / with the
    # divergence term -- TODO confirm against JaxLightning.compute_nll
    x_noise, nll_no_div, nll = info['jax_lightning'].compute_nll(key_nll, 1., evaluation_trajectories[:10])
    # The `{expr=}` form prints each expression text next to its value, so the
    # variable names here are part of the printed output.
    print(f'{source=}, {nll_no_div.mean()=}, {nll.mean()=}, {x_noise.mean()=}, {x_noise.std()=}')