In [None]:
# Load the autoreload extension and enable mode 2 so that imported modules are
# reloaded before each cell executes (edits to project code take effect without
# restarting the kernel).
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Import

In [None]:
import sys
from logging import INFO, StreamHandler, getLogger

# Root logger at INFO level. A stdout handler is attached only when no handler
# is configured yet, so re-running this cell does not duplicate log lines.
logger = getLogger()
logger.setLevel(INFO)
if not logger.hasHandlers():
    logger.addHandler(StreamHandler(stream=sys.stdout))

In [None]:
import os
import math
import pathlib

import torch
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt

from src.frameworks.ddpm import DDPMConfig, DDPM

# Global figure appearance: serif fonts first, then the colour-blind friendly
# Tableau style sheet applied on top.
plt.rcParams.update({"font.family": "serif"})
plt.style.use("tableau-colorblind10")

# Define constants

In [None]:
# Project root = first entry of the PYTHONPATH environment variable.
# The variable is split on os.pathsep (":" on POSIX, ";" on Windows) instead of a
# hard-coded ":", which would cut Windows paths at the drive-letter colon.
# Raises KeyError if PYTHONPATH is not set.
ROOT_DIR = pathlib.Path(os.environ["PYTHONPATH"].split(os.pathsep)[0]).resolve()

# Output directory for the Lorenz-96 deep-learning data (kept as a str); it is
# created on first use and left untouched if it already exists.
DL_DATA_DIR = str(ROOT_DIR / "data" / "DL_data" / "lorenz96")
os.makedirs(DL_DATA_DIR, exist_ok=True)

# Make a DDPM instance

In [None]:
# Gaussian initial state used throughout the notebook: mean `decay_init`,
# standard deviation `std_init`.
decay_init = 3.0
std_init = 2.0

# Decay rate and noise amplitude of the analytic process. Choosing
# sigma = sqrt(2 * mu) makes the factor sigma**2 / (2 * mu) equal to one.
mu = 3.0
sigma = np.sqrt(2.0 * mu)

# DDPM settings: beta range, number of diffusion steps and data shape
# (channels, spatial points).
config = DDPMConfig(
    n_timesteps=1_000,
    n_channels=1,
    n_spaces=16,
    start_beta=1e-3,
    end_beta=5e1,
)


class ExactScoreFunc(torch.nn.Module):
    """Closed-form stand-in for a trained network, valid for a Gaussian initial state.

    On a uniform grid of ``n_t + 1`` times in ``[0, 1]`` (the entry for ``t = 0`` is
    dropped, leaving ``n_t`` entries) the module tabulates

    * ``m_t = exp(-mu * t)``,
    * ``v_t = sigma**2 / (2 * mu) * (1 - exp(-2 * mu * t))``,
    * ``s_t = sqrt(v_t)``.

    ``forward`` assumes the state at ``t = 0`` is Gaussian with mean ``decay_init``
    and standard deviation ``std_init``; the marginal at step ``t`` is then Gaussian
    with mean ``decay_init * m_t`` and variance ``std_init**2 * m_t**2 + v_t``, whose
    score is available analytically.

    The tables are registered as non-persistent buffers: they follow the module
    through ``.to(device)`` / ``.cuda()`` but are not written to ``state_dict``.

    Args:
        mu: Decay rate in ``m_t`` and ``v_t``.
        sigma: Noise amplitude in ``v_t``.
        n_t: Number of tabulated time steps.
        decay_init: Mean of the Gaussian initial state.
        std_init: Standard deviation of the Gaussian initial state.
    """

    def __init__(
        self,
        mu: float,
        sigma: float,
        n_t: int,
        decay_init: float,
        std_init: float,
    ):
        super().__init__()

        times = torch.linspace(0, 1, n_t + 1, dtype=torch.float32)
        m_t = torch.exp(-mu * times)[1:]
        v_t = (sigma**2 / (2 * mu) * (1.0 - torch.exp(-2 * mu * times)))[1:]

        # Buffers (not plain attributes) so the tables move with the module's device.
        self.register_buffer("m_t", m_t, persistent=False)
        self.register_buffer("v_t", v_t, persistent=False)
        self.register_buffer("s_t", torch.sqrt(v_t), persistent=False)

        self.decay_init = decay_init
        self.std_init = std_init
        self.out_channel = 1

    def forward(self, yt: torch.Tensor, t_index: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return ``-s_t * score(yt)`` for the analytic Gaussian marginal.

        Args:
            yt: States with the batch on the first axis; any number of trailing axes.
            t_index: Integer index tensor (one entry per batch element) selecting the
                tabulated time step; it is passed to ``torch.index_select``.
            **kwargs: Accepted and ignored.

        Returns:
            float32 tensor broadcast to the shape of ``yt``.
            NOTE(review): presumably this is the noise (epsilon) parameterisation
            that ``DDPM`` expects from its network - confirm against the framework.
        """
        m_t = torch.index_select(self.m_t, dim=0, index=t_index)
        s_t = torch.index_select(self.s_t, dim=0, index=t_index)
        v_t = torch.index_select(self.v_t, dim=0, index=t_index)

        # Per-sample coefficients broadcast over every non-batch axis of `yt`
        # (equals `[:, None, None]` for the 3-D inputs used in this notebook).
        bshape = (-1,) + (1,) * (yt.dim() - 1)
        mean_t = (self.decay_init * m_t).reshape(bshape)
        tot_v_t = (self.std_init**2 * m_t**2 + v_t).reshape(bshape)

        # score (\nabla log p) of a Gaussian with mean `mean_t` and variance `tot_v_t`
        score = -(yt - mean_t) / tot_v_t
        return (-s_t.reshape(bshape) * score).to(torch.float32)


# Analytic network tabulated on the same number of time steps as the DDPM config.
score_network = ExactScoreFunc(
    mu=mu,
    sigma=sigma,
    n_t=config.n_timesteps,
    decay_init=decay_init,
    std_init=std_init,
)

# DDPM wrapper driven by the analytic network instead of a trained one.
ddpm = DDPM(neural_net=score_network, config=config)

# Check the variance-preserving property

In [None]:
# Plot the per-step decay and std coefficients of the DDPM schedule together with
# the sum of their squares; for a variance-preserving schedule that sum should
# stay close to 1 at every step.
plt.plot(ddpm.decays, label="decays")
plt.plot(ddpm.stds, label="stds")
# Named `total_vars` rather than `vars` so the builtin `vars()` is not shadowed
# for the rest of the notebook session.
total_vars = ddpm.decays**2 + ddpm.stds**2
plt.plot(total_vars, label="vars")
plt.axhline(0.0, ls="--")  # dashed reference line at zero
plt.xlabel("Num. of Time Steps")
plt.legend()
plt.show()

# Forward process

In [None]:
# Number of samples, drawn from the Gaussian initial state
# (mean `decay_init`, standard deviation `std_init`).
b = 10_000
y0 = score_network.decay_init + score_network.std_init * torch.randn(b, config.n_channels, config.n_spaces)

# One histogram per snapshot: index -1 stands for the un-noised y0, the others are
# every 100th diffusion step. zip() stops at the shorter iterable, so any panel
# without a snapshot stays empty.
fig, axes = plt.subplots(nrows=2, ncols=6, figsize=(15, 6), sharex=True, sharey=True)
for t, ax in zip(torch.arange(-1, config.n_timesteps + 1, 100), axes.flatten()):
    if t != -1:
        # Noise y0 up to step `t` with a fresh noise draw for the whole batch.
        noise = torch.randn_like(y0)
        yt = ddpm._forward_sample_y(
            y0=y0, t_index=torch.full((b,), fill_value=t, dtype=torch.long), noise=noise
        )
    else:
        yt = y0.detach().clone()

    # Only the first channel / first spatial point is histogrammed.
    data = yt[:, 0, 0].numpy().flatten()
    ax.hist(data, bins=50, range=(-5, 5), density=True)
    ax.set(
        xlim=(-5, 5),
        title=f"t={t+1}\nmean={np.mean(data):.2f},std={np.std(data):.2f}",
        ylabel="PDF",
    )
    ax.axvline(0, color="k")
plt.tight_layout()
plt.show()

# Backward process

In [None]:
# Switch the wrapped network to evaluation mode, then run the reverse process
# for 10,000 samples. The result is indexed by time step in the plotting cell below.
# NOTE(review): the exact meaning of `n_return_steps` is defined by
# DDPM.backward_sample_y - confirm there.
ddpm.net.eval()
lst_img = ddpm.backward_sample_y(
    n_batches=10_000,
    n_return_steps=10,
)

In [None]:
# Histogram the first channel / first spatial point of the backward samples at
# every 100th time step; zip() leaves any panel without a snapshot empty.
fig, axes = plt.subplots(nrows=2, ncols=6, figsize=(15, 6), sharex=True, sharey=True)

for t, ax in zip(np.arange(0, config.n_timesteps + 1, 100), axes.flatten()):
    data = lst_img[int(t)][:, 0, 0].numpy().flatten()
    ax.hist(data, bins=50, range=(-5, 5), density=True)
    ax.set(
        xlim=(-5, 5),
        title=f"t={t}\nmean={np.mean(data):.2f},std={np.std(data):.2f}",
        ylabel="PDF",
    )
    ax.axvline(0, color="k")
plt.tight_layout()
plt.show()