# Import

In [None]:
import sys
from logging import INFO, StreamHandler, getLogger

# Root logger used throughout the notebook; records of level INFO and above pass.
logger = getLogger()
logger.setLevel(INFO)

# Attach a stdout handler only when none is present, so that re-running this
# cell does not stack a second handler on the same logger.
if logger.hasHandlers() is False:
    logger.addHandler(StreamHandler(sys.stdout))

In [None]:
import abc
import copy
import dataclasses
import math
import os
import pathlib
import random
import time
import typing
from functools import partial
from typing import Callable, Iterable, List, Literal, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import xarray as xr
from einops import rearrange
from torch import Tensor, nn
from torch.amp import GradScaler
from torch.distributions import Categorical, Normal
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm


# Global matplotlib defaults for every figure in this notebook:
# serif fonts and the "tableau-colorblind10" style sheet.
plt.rcParams["font.family"] = "serif"
plt.style.use("tableau-colorblind10")

# Examine SDE and diffusion model

- 理論ノートで扱ってきた線形 SDE はしばしば Ornstein-Uhlenbeck 過程 (OU 過程) と呼ばれる [wikipedia](https://en.wikipedia.org/wiki/Ornstein%E2%80%93Uhlenbeck_process)
$$
dx = -\mu x dt + \sigma dW
$$
- 初期分布を混合正規分布に取る (二峰分布)
- そして時間発展を SDE を解くことで求める

## Run SDE

In [None]:
# Mixed Gaussian
init_means = [4.0, -3.0]
init_stds = [2.0, 1.0]
gauss_weights = [0.6, 0.4]

b = 2_000  # n_batches, number of paths
n_steps = 1_000
dt = 1.0 / n_steps

# OU process: dx = -mu x dt + sigma dW
mu = 5.0
sigma = np.sqrt(2.0 * mu)

# Pick the mixture component of every path, then draw its initial state.
idxs = Categorical(torch.tensor(gauss_weights)).sample((b,))
y0 = torch.normal(
    mean=torch.tensor(init_means)[idxs], std=torch.tensor(init_stds)[idxs]
)
# Brownian increments for all paths and all steps (standard deviation sqrt(dt)).
dW = math.sqrt(dt) * torch.randn(size=(b, n_steps))

# yt[:, k] stores the state of every path after k steps.
yt = torch.zeros((b, n_steps + 1))
current = y0
yt[:, 0] = current.detach().clone()

# Euler-Maruyama update of the OU process.
for i in range(n_steps):
    current = current - mu * current * dt + sigma * dW[:, i]
    yt[:, i + 1] = current.detach().clone()

ts = torch.linspace(0, 1, n_steps + 1)

# --- Sample paths -----------------------------------------------------------
plt.rcParams["font.size"] = 14
# Create the figure and its single axes explicitly instead of relying on the
# side effect of an argument-less ``plt.plot()`` call.
fig, ax = plt.subplots()

interval = 4  # plot every 4th time step only
for i in range(100):
    ax.plot(ts[::interval], yt[i][::interval], lw=0.5, alpha=0.5)

ax.set_xlabel(r"Time, $t$")
ax.set_ylabel(r"State, $x$")
plt.show()

# --- Histograms of the state at selected time steps -------------------------
plt.rcParams["font.size"] = 14
fig, axes = plt.subplots(2, 6, sharex=True, sharey=True, figsize=(20, 6))

for t, ax in zip([0, 1, 2, 5, 10, 20, 50, 100, 200, 500, 750, 1000], axes.flatten()):
    data = yt[:, t].numpy().flatten()
    ax.hist(data, range=(-10, 10), bins=100, density=True, alpha=1.0)
    ax.set_xlim(-10, 10)
    ax.set_title(f"Timestep={t:04}/{n_steps}")
    ax.set_ylabel("PDF")
    ax.axvline(0, color="gray")

plt.tight_layout()
plt.show()

# Remove the temporaries; several of these names are reused by later cells.
del y0, yt, ts, mu, sigma, idxs
del init_means, init_stds, gauss_weights
del n_steps, dt, current, b, interval

- 二峰を持つ確率分布が，単峰の標準正規分布に発展する様子が確認される

## Define diffusion framework

In [None]:
@dataclasses.dataclass
class DDPMConfig:
    """Settings read by the ``DDPM`` wrapper.

    Instances may be built positionally, so the field order is part of the
    interface.
    """

    # First and last value handed to the beta schedule.
    start_beta: float
    end_beta: float
    # Number of discrete diffusion steps.
    n_timesteps: int
    # Size of the channel and space axes of one sample (batch, channel, space).
    n_channels: int
    n_spaces: int

In [None]:
class DDPM(nn.Module):
    """Variance-preserving diffusion model on a fixed grid of diffusion steps.

    The forward process is the linear SDE
    ``dx = -0.5 * beta_t * x dt + sqrt(beta_t) dW`` split into
    ``config.n_timesteps`` steps of size ``1 / n_timesteps``.  ``neural_net``
    is called as ``neural_net(yt=..., y_cond=..., t=..., t_index=...)`` and its
    output is used as the estimate of the noise contained in ``yt``.

    Args:
        config: Noise-schedule and sample-shape settings (deep-copied).
        neural_net: Noise-estimating module.
        device: Device on which the schedule buffers and samples are created.
    """

    def __init__(
        self,
        config: DDPMConfig,
        neural_net: nn.Module,
        device: torch.device = torch.device("cpu"),
    ):
        super().__init__()

        self.dtype = torch.float32
        self.device = device
        self.c = copy.deepcopy(config)
        self.net = neural_net
        self._set_noise_schedule()

    def _set_noise_schedule(self):
        """Precompute the per-step SDE coefficients and register them as buffers."""
        to_torch = partial(torch.tensor, dtype=self.dtype, device=self.device)

        betas = _make_beta_schedule(
            schedule="linear",
            start=self.c.start_beta,
            end=self.c.end_beta,
            n_timesteps=self.c.n_timesteps,
        )
        times = np.linspace(
            0.0, 1.0, num=len(betas) + 1, endpoint=True, dtype=np.float64
        )
        times = times[1:]  # skip the initial value
        assert len(times) == len(betas) == self.c.n_timesteps

        self.dt = 1.0 / float(self.c.n_timesteps)
        self.sqrt_dt = math.sqrt(self.dt)

        # variance-preserving SDE
        frictions = 0.5 * betas
        sigmas = np.sqrt(betas)

        decays, variances = _precompute_ou(mu=frictions, sigma=sigmas, dt=self.dt)
        stds = np.sqrt(variances)
        # the OU solution is expressed as x_t = decay * x_0 + std * epsilon (epsilon ~ N(0,1))

        # the number of elements in each param is equal to self.c.n_timesteps
        self.register_buffer("frictions", to_torch(frictions))
        self.register_buffer("sigmas", to_torch(sigmas))
        self.register_buffer("times", to_torch(times))

        # Register params except for the initial values because std is initially zero
        # Later, std is used as denominator to convert noise into the score function.
        self.register_buffer("decays", to_torch(decays[1:]))
        self.register_buffer("stds", to_torch(stds[1:]))

        assert (
            self.frictions.shape
            == self.sigmas.shape
            == self.times.shape
            == self.decays.shape
            == self.stds.shape
            == (self.c.n_timesteps,)
        )
        assert torch.all(self.sigmas > 0.0) and torch.all(self.stds > 0.0)

    def _extract_params(
        self, params: torch.Tensor, t_indices: torch.Tensor, for_broadcast: bool = True
    ) -> torch.Tensor:
        """Gather ``params[t_indices]``, one value per batch element.

        Args:
            params: 1-D table indexed by diffusion step.
            t_indices: 1-D index tensor of length ``n_batches``.
            for_broadcast: When True, append channel and space axes so the
                result has shape ``(n_batches, 1, 1)``.
        """
        (n_batches,) = t_indices.shape

        # Select diffusion times along batch dim
        selected = torch.index_select(params, dim=0, index=t_indices)
        assert selected.shape == (n_batches,)
        selected = selected.requires_grad_(False)

        # add channel and space dims
        return selected[:, None, None] if for_broadcast else selected

    def _forward_sample_y(
        self, y0: torch.Tensor, t_index: torch.Tensor, noise: torch.Tensor
    ) -> torch.Tensor:
        """Sample the forward process in closed form: ``decay * y0 + std * noise``."""
        a = self._extract_params(self.decays, t_index)
        b = self._extract_params(self.stds, t_index)
        return a * y0 + b * noise

    @torch.no_grad()
    def _backward_sample_y(
        self,
        yt: torch.Tensor,
        t_index: torch.Tensor,
        y_cond: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Advance the reverse SDE by one Euler-Maruyama step.

        Args:
            yt: Current state, batch axis first.
            t_index: Diffusion-step index of every batch element.
            y_cond: Optional conditioning passed through to the network.
        """
        friction = self._extract_params(self.frictions, t_index)
        sigma = self._extract_params(self.sigmas, t_index)
        std = self._extract_params(self.stds, t_index)
        t = self._extract_params(self.times, t_index, for_broadcast=False)
        t = t[:, None]  # add channel dim

        # Convert the estimated noise into the score function.
        est_noise = self.net(yt=yt, y_cond=y_cond, t=t, t_index=t_index)
        score = -est_noise / std

        mean = yt + self.dt * (friction * yt + (sigma**2) * score)
        dW = self.sqrt_dt * torch.randn_like(yt)

        # no noise at t_index == 0
        n_batches = yt.shape[0]
        mask = (t_index != 0).to(dtype=self.dtype, device=self.device)
        mask = mask.reshape(n_batches, *((1,) * (yt.ndim - 1)))

        return mean + mask * sigma * dW

    # public methods

    @torch.no_grad()
    def backward_sample_y(
        self,
        n_batches: int,
        y_cond: Optional[torch.Tensor] = None,
        n_return_steps: Optional[int] = None,
        tqdm_disable: bool = False,
    ) -> dict[int, torch.Tensor]:
        """Generate samples by integrating the reverse SDE from the last step to 0.

        Args:
            n_batches: Number of samples to generate.
            y_cond: Optional conditioning passed through to the network.
            n_return_steps: Approximate number of intermediate states to keep.
                ``None`` keeps only the final sample.
            tqdm_disable: Disable the progress bar.

        Returns:
            CPU tensors keyed by step index.  Key ``0`` is the final sample;
            key ``k > 0`` is the state held just before step index ``k - 1``
            was applied.

        Raises:
            ValueError: If ``n_return_steps`` is given but smaller than 1.
        """
        assert not self.net.training

        size = (n_batches, self.c.n_channels, self.c.n_spaces)
        yt = torch.randn(size=size, device=self.device)
        yt = self.stds[-1] * yt

        # ``interval`` stays None when no intermediate state is requested.
        interval: Optional[int] = None
        if n_return_steps is not None:
            if n_return_steps < 1:
                raise ValueError(f"n_return_steps must be >= 1, got {n_return_steps}")
            # Clamp to 1 so that asking for more states than there are steps
            # stores every step instead of dividing by zero below.
            interval = max(1, self.c.n_timesteps // n_return_steps)

        intermediates: dict[int, torch.Tensor] = {}

        for i in tqdm(
            reversed(range(0, self.c.n_timesteps)),
            total=self.c.n_timesteps,
            disable=tqdm_disable,
        ):
            if interval is not None and (i + 1) % interval == 0:
                intermediates[i + 1] = yt.detach().clone().cpu()

            index = torch.full((n_batches,), i, device=self.device, dtype=torch.long)
            yt = self._backward_sample_y(yt=yt, y_cond=y_cond, t_index=index)

        intermediates[0] = yt.detach().clone().cpu()

        return intermediates

    def forward(
        self, y0: torch.Tensor, y_cond: Optional[torch.Tensor] = None, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Diffuse ``y0`` to a random step and let the network estimate the noise.

        Args:
            y0: Clean samples of shape ``(batch, n_channels, n_spaces)``.
            y_cond: Optional conditioning passed through to the network.
            **kwargs: Accepted and ignored.

        Returns:
            ``(noise, noise_hat)``: the injected noise and the network output.
        """
        assert y0.ndim == 3  # batch, channel, space
        assert y0.shape[1] == self.c.n_channels
        assert y0.shape[2] == self.c.n_spaces

        b = y0.shape[0]
        # One uniformly drawn diffusion step per batch element.
        t_index = torch.randint(0, self.c.n_timesteps, (b,), device=self.device).long()

        noise = torch.randn_like(y0)

        yt = self._forward_sample_y(y0=y0, t_index=t_index, noise=noise)
        t = self._extract_params(self.times, t_index, for_broadcast=False)
        t = t[:, None]  # add channel dim
        noise_hat = self.net(yt=yt, y_cond=y_cond, t=t, t_index=t_index)

        return noise, noise_hat


def _make_beta_schedule(
    schedule: str,
    start: float,
    end: float,
    n_timesteps: int,
) -> np.ndarray:
    if schedule == "linear":
        betas = np.linspace(start, end, n_timesteps, dtype=np.float64, endpoint=True)
    else:
        raise NotImplementedError(f"Not supported: {schedule=}")
    return betas


def _precompute_ou(
    mu: np.ndarray,
    sigma: np.ndarray,
    dt: float | np.ndarray,
    init_variance: float = 0.0,
) -> tuple[np.ndarray, np.ndarray]:
    """Method to compute the mean and variance for OU process.
    OU process: dx = -mu x dt + sigma dW
    """
    mu = np.array(mu, dtype=np.float64)
    assert np.all(mu >= 0.0)

    sigma = np.array(sigma, dtype=np.float64)
    assert np.all(sigma >= 0.0)

    if isinstance(dt, float):
        dt = np.full_like(mu, dt, dtype=np.float64)
    else:
        dt = np.array(dt, dtype=np.float64)
    assert mu.shape == sigma.shape == dt.shape
    assert init_variance >= 0.0

    N = mu.size
    m = np.empty(N + 1, dtype=np.float64)  # mean
    v = np.empty(N + 1, dtype=np.float64)  # variance
    m[0] = 1.0
    v[0] = init_variance

    for n in range(N):
        decay = np.exp(-mu[n] * dt[n])
        m[n + 1] = decay * m[n]
        if mu[n] == 0.0:
            q = sigma[n] ** 2 * dt[n]
        else:
            q = sigma[n] ** 2 * (1.0 - decay**2) / (2.0 * mu[n])
        v[n + 1] = decay**2 * v[n] + q

    return np.array(m), np.array(v)

## Make a diffusion model instance

- 分散保存型 (Variance-Preserving; VP) の拡散モデルを扱う．
- 順過程は下の様に書ける
$$
\begin{align}
dx_t &= - \mu_t x_t \; dt + \sqrt{2\mu_t} \; dW_t \quad (t \in [0,1])
\end{align}
$$
- $\mu_t = \mu$ が一定の場合，混合正規分布を初期分布とする Fokker-Planck 方程式 (FPE) の解は解析的に書ける
$$
\begin{align}
  p(x,0) &= \sum w_i {\cal N}(x; m_i(0), s_i^2(0)) \\
  p(x,t) &= \sum w_i {\cal N}(x; m_i(t), s_i^2(t)) \\
  m_i(t) &= m_i(0) e^{-\mu t} \\
  s_i^2(t) &= s_i^2(0) e^{-2\mu t} + (1 - e^{-2\mu t})
\end{align}
$$
- この結果を利用して，厳密にスコア関数を与える
$$
\begin{align}
  \text{score function} &= \frac{\partial}{\partial x} \ln p(x,t) \\
  &= -\frac{1}{\sum w_i {\cal N}(x; m_i(t), s_i^2(t))} \sum_i \left[ w_i {\cal N}(x; m_i(t), s_i^2(t)) \frac{x-m_i(t)}{s_i^2(t)} \right]
\end{align}
$$
- スコア関数が厳密に分かれば，逆 SDE の積分が可能になる
- 拡散モデルのフレームワークに厳密なスコア関数を与えるインスタンスを代入している

In [None]:
# OU process: dx = -mu x dt + sigma dW
mu = 5.0
sigma = np.sqrt(2.0 * mu)

# Mixed Gaussian
init_means = [4.0, -3.0]
init_stds = [2.0, 1.0]
gauss_weights = [0.6, 0.4]

config = DDPMConfig(
    start_beta=2 * mu,  # identical start and end betas keep beta, and hence mu above, constant
    end_beta=2 * mu,  # relation: mu = beta / 2 (i.e., beta = 2 mu)
    n_timesteps=1_000,
    n_channels=1,
    n_spaces=1,
)


class ExactScoreFunc(torch.nn.Module):
    """Exact score of a Gaussian mixture diffused by a constant-coefficient OU process.

    For the SDE ``dx = -mu x dt + sigma dW`` started from
    ``sum_i w_i N(m_i, s_i^2)``, the density at time ``t`` is again a mixture
    with means ``m_i * exp(-mu t)`` and variances
    ``s_i^2 * exp(-2 mu t) + sigma^2 / (2 mu) * (1 - exp(-2 mu t))``.

    Time index ``k`` (``0 <= k < n_t``) refers to ``t = (k + 1) / n_t``.
    ``forward`` returns ``-std_t * score``, i.e. the quantity a noise-predicting
    network is expected to output.

    All densities are evaluated in log space so that points far from every
    mode give finite values instead of ``inf`` / ``NaN``.

    Args:
        mu: Friction coefficient (must be non-zero).
        sigma: Noise amplitude.
        n_t: Number of diffusion steps.
        init_means: Component means at ``t = 0``.
        init_stds: Component standard deviations at ``t = 0``.
        gauss_weights: Component weights.
    """

    def __init__(
        self,
        mu: float,
        sigma: float,
        n_t: int,
        init_means: list[float],
        init_stds: list[float],
        gauss_weights: list[float],
    ):
        super().__init__()
        assert len(init_means) == len(init_stds) == len(gauss_weights)

        mu = float(mu)
        sigma = float(sigma)
        times = torch.linspace(0, 1, n_t + 1, dtype=torch.float32)

        # Closed-form decay and noise variance of the linear SDE (OU process).
        # In the variance-preserving setting sigma**2 / (2 * mu) == 1.
        m_t = torch.exp(-mu * times)[1:]
        v_t = (sigma**2 / (2 * mu) * (1.0 - torch.exp(-2 * mu * times)))[1:]

        # Non-persistent buffers: they follow ``.to(device)`` without being
        # added to the state dict.
        self.register_buffer("m_t", m_t, persistent=False)
        self.register_buffer("v_t", v_t, persistent=False)
        self.register_buffer("s_t", torch.sqrt(v_t), persistent=False)
        self.register_buffer(
            "init_means", torch.tensor(init_means, dtype=torch.float32), persistent=False
        )
        self.register_buffer(
            "init_stds", torch.tensor(init_stds, dtype=torch.float32), persistent=False
        )
        self.register_buffer(
            "gauss_weights",
            torch.tensor(gauss_weights, dtype=torch.float32),
            persistent=False,
        )

    def _log_weighted_pdfs(
        self, yt: torch.Tensor, t_index: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``log(w_i * N_i(yt))``, ``yt - mean_i`` and ``var_i`` per component.

        Every returned tensor carries the mixture components on a new last axis.
        """
        assert yt.ndim == 3  # batch, channel, and space
        # add channel, space and component dims
        m_t = torch.index_select(self.m_t, dim=0, index=t_index)[:, None, None, None]
        v_t = torch.index_select(self.v_t, dim=0, index=t_index)[:, None, None, None]

        means = self.init_means * m_t
        variances = self.init_stds**2 * m_t**2 + v_t
        diff = yt[..., None] - means

        log_pdf = -0.5 * diff**2 / variances - 0.5 * torch.log(
            2 * math.pi * variances
        )
        return torch.log(self.gauss_weights) + log_pdf, diff, variances

    def potential(self, yt: torch.Tensor, t_index: torch.Tensor) -> torch.Tensor:
        """Return the potential ``-log p(yt, t)`` with the same shape as ``yt``.

        Args:
            yt: States of shape ``(batch, channel, space)``.
            t_index: Integer (long) tensor of shape ``(batch,)``.
        """
        log_terms, _, _ = self._log_weighted_pdfs(yt, t_index)
        return -torch.logsumexp(log_terms, dim=-1)

    def score(self, yt: torch.Tensor, t_index: torch.Tensor) -> torch.Tensor:
        """Return the score ``d/dy log p(yt, t)`` with the same shape as ``yt``.

        Args:
            yt: States of shape ``(batch, channel, space)``.
            t_index: Integer (long) tensor of shape ``(batch,)``.
        """
        log_terms, diff, variances = self._log_weighted_pdfs(yt, t_index)
        # softmax over components == w_i N_i / sum_j w_j N_j, without underflow
        responsibilities = torch.softmax(log_terms, dim=-1)
        return -(responsibilities * diff / variances).sum(dim=-1)

    def forward(self, yt: torch.Tensor, t_index: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the exact noise estimate ``-std_t * score`` as float32.

        Extra keyword arguments are accepted and ignored.
        """
        assert yt.ndim == 3  # batch, channel, and space
        scores = self.score(yt, t_index)
        s_t = torch.index_select(self.s_t, dim=0, index=t_index)[:, None, None]
        # add channel and space dims
        return (-s_t * scores).to(torch.float32)  # estimate noise


# Analytic score model for the mixture defined above; it is used in place of
# a trained network.
score_network = ExactScoreFunc(
    mu=mu,
    sigma=sigma,
    n_t=config.n_timesteps,
    init_means=init_means,
    init_stds=init_stds,
    gauss_weights=gauss_weights,
)

# Hand the exact score to the diffusion framework (device defaults to CPU).
ddpm = DDPM(config=config, neural_net=score_network)

## Check variance preserving

- 分散保存型であることの確認
$$
{\rm Var}\left[x_t\right]= {\rm Var}\left[x_0\right] e^{-2 \mu t} + \left(1 - e^{-2 \mu t}\right)
$$
- 第二項が拡散モデル内部で用いるノイズの分散を表す (その平方根がノイズの標準偏差)
- 初期データ $x_0$ の分散が 1 に規格化されていると仮定すれば，全分散は常に 1 となる

In [None]:
plt.rcParams["font.size"] = 14
plt.plot(ddpm.decays, label=r"${\rm decay} = e^{-\mu t}$")
plt.plot(ddpm.stds, label=r"${\rm std} = \sqrt{1-e^{-2\mu t}}$")
# decay**2 + std**2 is the total variance of x_t when x_0 has unit variance.
# Named ``total_vars`` so that the builtin ``vars`` is not shadowed.
total_vars = ddpm.decays**2 + ddpm.stds**2
plt.plot(total_vars, label="total vars")
plt.axhline(0.0, ls="--")
plt.xlabel("Num. of Time Steps")
plt.legend()
plt.show()
del total_vars

## Forward process

- 順過程では，二峰の分布が単峰の標準ガウス分布へ緩和する
- 緩和が 300 ステップ程度で完了し，残りの 300 - 1000 ステップではほとんど確率分布が変化していない
- これは SDE の摩擦項により時間に関して指数関数的な緩和を施しているため

In [None]:
# Draw b initial samples from the Gaussian mixture and add channel/space axes.
b = 10_000
idxs = Categorical(torch.tensor(gauss_weights)).sample((b,))
y0 = torch.normal(
    mean=torch.tensor(init_means)[idxs], std=torch.tensor(init_stds)[idxs]
)[:, None, None]

plt.rcParams["font.size"] = 14
fig, axes = plt.subplots(2, 6, sharex=True, sharey=True, figsize=(20, 6))
# t == -1 stands for the initial data; otherwise t is the 0-based step index
# (shown as step t + 1 in the title). With n_timesteps = 1000 this yields 11
# values, so zip leaves the last of the 12 panels empty.
for t, ax in zip(torch.arange(-1, config.n_timesteps + 1, 100), axes.flatten()):
    if t == -1:
        yt = y0.detach().clone()
    else:
        # Closed-form forward sample at step index t, same index for all samples.
        noise = torch.randn_like(y0)
        yt = ddpm._forward_sample_y(
            y0=y0,
            t_index=torch.full((b,), fill_value=t, dtype=torch.long),
            noise=noise,
        )
    data = yt[:, 0, 0].numpy().flatten()
    ax.hist(data, range=(-10, 10), bins=100, density=True)
    ax.set_xlim(-10, 10)
    ax.set_title(f"t={t+1}\nmean={np.mean(data):.2f},std={np.std(data):.2f}")
    ax.set_ylabel("PDF")
    ax.axvline(0, color="k")
plt.tight_layout()
plt.show()

# Remove the temporaries of this cell.
del b, y0, yt, data, idxs

## Backward process

- 逆過程では，正規分布から始まり，混合正規分布へと発展させる
- 混合正規分布の FPE の解は解析的に書ける
- この結果を利用して，厳密にスコア関数を与えて，逆 SDE を解いている
- 逆過程は，厳密に順過程の逆回しになる
- そのため，二峰性が現れるのが $t$ が 0 に十分近づいてからになる

In [None]:
# backward_sample_y asserts that the network is not in training mode.
ddpm.net.eval()
# n_return_steps == n_timesteps gives a storage interval of 1: every step is kept.
lst_img = ddpm.backward_sample_y(n_batches=10_000, n_return_steps=config.n_timesteps)

In [None]:
# --- Histograms of the generated samples with the analytic potential --------
plt.rcParams["font.size"] = 14
fig, axes = plt.subplots(2, 6, sharex=True, sharey=True, figsize=(20, 6))

for t, ax in zip([0, 1, 2, 5, 10, 20, 50, 100, 200, 500, 750, 1000], axes.flatten()):
    data = lst_img[int(t)][:, 0, 0].numpy().flatten()
    ax.hist(data, range=(-10, 10), bins=100, density=True, alpha=0.3)
    ax.set_xlim(-10, 10)
    ax.set_title(f"Timestep={t:04}/{config.n_timesteps}")
    ax.set_ylabel("PDF")
    ax.axvline(0, color="gray")

    # Mark the means of the initial mixture components.
    for x in init_means:
        ax.axvline(x, ls="--", lw=0.5, alpha=0.5, color="gray")

    # add channel and space dim
    yt = torch.linspace(-10, 10, 101)[:, None, None]
    if t < config.n_timesteps:
        dummy_t = torch.full(size=(101,), fill_value=t).long()
        potential = score_network.potential(yt, dummy_t)
    else:
        # t == n_timesteps is outside the score tables; use the quadratic
        # potential of a standard normal instead.
        potential = 0.5 * yt**2

    ax2 = ax.twinx()
    ax2.plot(yt.squeeze(), potential.squeeze(), ls="-", lw=2.0, color="coral")
    ax2.set_ylim(0, 20)

plt.tight_layout()
plt.show()


# --- Sample paths of the reverse process -------------------------------------
ts = sorted(lst_img.keys())
yt = torch.stack([lst_img[t][:100].squeeze() for t in ts], dim=1)

plt.rcParams["font.size"] = 14
# Create the figure and its single axes explicitly instead of relying on the
# side effect of an argument-less ``plt.plot()`` call.
fig, ax = plt.subplots()

interval = 4  # plot every 4th stored step only
for i in range(100):
    # Normalise the step index to the time interval [0, 1].
    ax.plot(
        np.asarray(ts[::interval]) / np.max(ts), yt[i][::interval], lw=0.5, alpha=0.5
    )

ax.set_xlabel(r"Time, $t$")
ax.set_ylabel(r"State, $x$")

# The call parentheses are required; a bare ``plt.show`` does nothing.
plt.show()

del ts, yt, t, potential, dummy_t

- 上の図では確率分布のポテンシャル $U(x,t) = -\ln p(x,t)$ を線で描いている
- 各パスを調べてみると，単峰の正規分布から二峰の混合正規分布に分かれている
- 逆 SDE を解いているため，時刻は $t=1.0$ から $t=0.0$ に向かっている
- SDE は Euler-Maruyama 法で解いている

# Make Lorenz96 data

- Lorenz96 モデルは地球大気の東西運動を表す理想モデル [wikipedia](https://en.wikipedia.org/wiki/Lorenz_96_model)
$$
\frac{dx_i}{dt} = (x_{i+1}-x_{i-2})x_{i-1} - x_i + F
$$
- $F$ は下の設定で $F=8$ としている．これはカオスレジームとして知られる典型的な設定
- $i$ が一次元上の格子点位置を示す
- 右辺の第一項が移流項，第二項が摩擦，第三項 $F$ が強制となる
- 平衡解として，すべての $i$ に対して $x_i = F$ が存在し，これは $F$ が十分大きいと不安定

## Preferences

In [None]:
# Output directory for the generated Lorenz96 data, relative to the current
# working directory; it is created when missing.
ROOT_DIR = pathlib.Path(".").resolve()
DL_DATA_DIR = str(ROOT_DIR / "data" / "DL_data" / "lorenz96")
os.makedirs(DL_DATA_DIR, exist_ok=True)

In [None]:
N_BATCHES = 5_000  # number of trajectories (data samples) to generate
N_SPACES = 32  # number of spatial grid points
N_TIMES = 5_000  # number of time steps

FORCING = 8.0  # forcing term F of the Lorenz96 model
AMP_PERTURBATION = 0.01  # amplitude of the random perturbation added to the initial state
DT = 0.005  # integration time step
SEED = 42  # random seed passed to set_seeds

## Methods

In [None]:
def integrate_lorenz96(x0: Tensor, forcing: float, n_steps: int, dt: float) -> Tensor:
    """Integrate the Lorenz96 model with the explicit (forward) Euler scheme.

    Args:
        x0: Initial states of shape ``(batch, space)``.
        forcing: Forcing term ``F`` (must be a ``float``).
        n_steps: Number of Euler steps (positive ``int``).
        dt: Step size (positive ``float``).

    Returns:
        CPU tensor of shape ``(batch, n_steps + 1, space)`` holding the initial
        state followed by the state after every step.
    """
    assert isinstance(x0, Tensor) and x0.ndim == 2  # batch and space
    assert isinstance(forcing, float)
    assert isinstance(n_steps, int) and n_steps > 0
    assert isinstance(dt, float) and dt > 0.0

    state = x0.clone().detach()
    trajectory = [state.clone().detach()]

    for _ in tqdm(range(n_steps)):
        tendency = _lorenz96_rhs(x=state, forcing=forcing)
        state = state + dt * tendency
        trajectory.append(state.clone().detach())

    # stack along time dim
    stacked = torch.stack(trajectory, dim=1)
    return stacked.cpu()


def _lorenz96_rhs(x: Tensor, forcing: float) -> Tensor:

    a = x.roll(shifts=-1, dims=1)
    b = x.roll(shifts=2, dims=1)
    c = x.roll(shifts=1, dims=1)
    dxdt = (a - b) * c - x + forcing

    return dxdt


def set_seeds(seed: int = 42, use_deterministic: bool = True) -> None:
    """Seed Python, NumPy and PyTorch, and request deterministic kernels.

    Any exception raised on the way is logged as an error and not re-raised.

    Args:
        seed: Seed given to every random number generator.
        use_deterministic: When True, also call
            ``torch.use_deterministic_algorithms(True, warn_only=True)``.
    """
    try:
        os.environ["PYTHONHASHSEED"] = str(seed)

        # Python, NumPy and PyTorch generators, in this order.
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
            seed_fn(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

        if use_deterministic:
            torch.use_deterministic_algorithms(True, warn_only=True)

        # The cuDNN flags are set regardless of ``use_deterministic``.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except Exception as e:
        logger.error(e)

## Integrate Lorenz96

In [None]:
dtype = torch.float32
device = torch.device("cpu")

set_seeds(SEED)
# Start every trajectory from x_i = FORCING plus a small random perturbation.
x0 = FORCING * torch.ones(size=(N_BATCHES, N_SPACES), dtype=dtype, device=device)
x0 += torch.randn_like(x0) * AMP_PERTURBATION

# states: (batch, N_TIMES + 1, space), initial state included.
states = integrate_lorenz96(x0=x0, forcing=FORCING, n_steps=N_TIMES, dt=DT)

# Sanity check: report NaNs first, then any other non-finite value.
if torch.any(torch.isnan(states)):
    logger.warning("NaNs appear.")
elif torch.any(~torch.isfinite(states)):
    logger.warning("Infs appear")
else:
    logger.info("Integration was successfully finished.")

del x0, dtype, device

## Plot results

In [None]:
plt.rcParams["font.size"] = 14
fig, axes = plt.subplots(1, 3, sharex=True, sharey=True, figsize=[10, 4])

for i, ax in enumerate(axes.flatten()):
    # Keep every 40th time step, then the last 32 of those snapshots.
    d = states[i].numpy()[::40][-32:]
    # Consecutive snapshots are 40 integration steps apart, i.e. 40 * DT in time.
    ts = np.arange(d.shape[0]) * 40 * DT
    xs = np.linspace(0, 2 * math.pi, N_SPACES, endpoint=False)
    X, T = np.meshgrid(xs, ts, indexing="ij")
    # d is (time, space); transpose it to match the (space, time) mesh.
    ret = ax.pcolormesh(
        X, T, d.transpose(), vmin=-9, vmax=9, cmap="coolwarm", shading="nearest"
    )
    fig.colorbar(ret, ax=ax)
    ax.set_xlabel(r"Space, $x$")
    ax.set_ylabel(r"Time, $t$")
plt.tight_layout()
plt.show()

del d, ts, xs, X, T, ret

## Write out

In [None]:
outs = states[:, ::40][:, -32:]
## Keep every 40th time step, then the last 32 of those snapshots along time
assert outs.shape == (N_BATCHES, 32, N_SPACES), f"{outs.shape=}"

# Coordinates: time relative to the first kept snapshot, space on [0, 2*pi).
ts = np.arange(0.0, outs.shape[1]) * (DT * 40)
xs = np.linspace(0, 2 * math.pi, N_SPACES, endpoint=False)

# Labelled float32 array; the generation settings are stored as attributes.
da = xr.DataArray(
    outs.numpy().astype(np.float32),
    dims=["batch", "time", "space"],
    coords={
        "batch": np.arange(N_BATCHES, dtype=np.int32),
        "time": ts.astype(np.float32),
        "space": xs.astype(np.float32),
    },
    name="lorenz96_trajectory",
    attrs={
        "forcing": FORCING,
        "dt": DT,
        "seed": SEED,
        "amp_perturb": AMP_PERTURBATION,
        "output_time_interval": 40,
    },
)

# Write the data set as a NetCDF file into DL_DATA_DIR.
p = f"{DL_DATA_DIR}/lorenz96_v00.nc"
da.to_netcdf(path=p)

del states, outs, ts, xs, da

# Train diffusion model

- 分散保存型 (Variance-Preserving; VP) の拡散モデルを扱う．
- 順過程は下の様に書ける
$$
\begin{align}
dx_t &= -\frac{1}{2} \beta_t x \; dt + \sqrt{\beta_t} \; dW \quad (t \in [0,1]) \\
\beta_t &= \beta_{\rm start} + t \left(\beta_{\rm end} - \beta_{\rm start}\right)
\end{align}
$$
- 理論ノートに合わせるため，$\beta_{\rm start}=\beta_{\rm end}$ に設定

## Preferences

In [None]:
# Use the GPU when one is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"{DEVICE=}")
# Input data file and output directory for training results, relative to the
# current working directory; the result directory is created when missing.
ROOT_DIR = pathlib.Path(".").resolve()
DL_DATA_DIR = str(ROOT_DIR / "data" / "DL_data" / "lorenz96")
DL_DATA_FILE = str(ROOT_DIR / "data" / "DL_data" / "lorenz96" / "lorenz96_v00.nc")
DL_RESULT_DIR = str(ROOT_DIR / "data" / "DL_model" / "lorenz96_v00")
os.makedirs(DL_RESULT_DIR, exist_ok=True)

## Make config

In [None]:
@dataclasses.dataclass
class ExperimentLorenz96Config:
    """Settings of one Lorenz96 diffusion-model experiment.

    Instances may be built positionally, so the field order is part of the
    interface.
    """

    # --- training settings ---
    batch_size: int
    loss_name: str
    learning_rate: float
    # --- U-Net settings ---
    n_features: int
    list_channel: list[int]
    # --- training settings (continued) ---
    total_epochs: int
    save_interval: int
    use_auto_mix_precision: bool
    # --- diffusion settings ---
    ddpm: DDPMConfig

In [None]:
config = ExperimentLorenz96Config(
    # Training settings
    batch_size=100,
    loss_name="L2",
    learning_rate=1e-3,
    total_epochs=40,
    save_interval=4,
    use_auto_mix_precision=False,
    # For U-Net
    n_features=32,
    list_channel=[1, 2, 4],
    # For DDPM
    ddpm=DDPMConfig(
        start_beta=1e1,  # start_beta == end_beta makes beta a constant
        end_beta=1e1,
        n_timesteps=1_000,
        n_channels=32,  # the number of time steps of the data is used as the channel count
        n_spaces=32,  # number of spatial grid points. dont change n_channels and n_spaces
    ),
)

## Make dataloader

In [None]:
class DatasetLorenz96(Dataset):
    """Lorenz96 trajectories read from a data-array file and standardised on access.

    The stored array must have the dimensions ``("batch", "time", "space")``.
    One item is a single trajectory, standardised with the mean and standard
    deviation of the whole array.
    """

    def __init__(self, path_to_dataarray: str):
        loaded = xr.load_dataarray(path_to_dataarray)
        assert loaded.dims == ("batch", "time", "space")
        self.data = loaded

        self.n_batch, self.n_times, self.n_spaces = loaded.shape
        # Scalar statistics over all samples, times and grid points.
        self.mean = loaded.mean().item()
        self.std = loaded.std().item()

        self.dtype = torch.float32

    def __len__(self):
        # batch dimension
        n_samples = self.data.shape[0]
        return n_samples

    def standardize(self, data):
        """Map physical values to zero mean and unit standard deviation."""
        centered = data - self.mean
        return centered / self.std

    def standardize_inversely(self, data):
        """Map standardised values back to physical values."""
        return self.mean + data * self.std

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return ``{"y0": tensor}`` with a standardised ``(time, space)`` tensor."""
        raw = self.data[idx].values  # time x space
        sample = torch.tensor(self.standardize(raw), dtype=self.dtype)
        assert sample.shape == (self.n_times, self.n_spaces)
        return {"y0": sample}

In [None]:
# Shuffled mini-batches of standardised trajectories; the last incomplete
# batch is dropped so that every batch holds exactly config.batch_size samples.
dataloader = torch.utils.data.DataLoader(
    dataset=DatasetLorenz96(DL_DATA_FILE),
    batch_size=config.batch_size,
    drop_last=True,
    shuffle=True,
    pin_memory=True,
    num_workers=2,
)

### Check data

In [None]:
# Fetch one batch and check that its shape matches the DDPM configuration.
data = next(iter(dataloader))["y0"]
assert data.shape == (
    config.batch_size,
    config.ddpm.n_channels,
    config.ddpm.n_spaces,
)  # batch, time (=channel), space dims

In [None]:
plt.rcParams["font.size"] = 14
fig, axes = plt.subplots(1, 3, sharex=True, sharey=True, figsize=[10, 4])

for i, ax in enumerate(axes.flatten()):
    # Undo the standardisation so that the physical values are shown.
    d = dataloader.dataset.standardize_inversely(data[i].numpy())
    ts = np.arange(d.shape[0]) * 0.2  # dt = 0.2
    xs = np.linspace(0, 2 * math.pi, 32, endpoint=False)
    X, T = np.meshgrid(xs, ts, indexing="ij")
    # d is (time, space); transpose it to match the (space, time) mesh.
    ret = ax.pcolormesh(
        X, T, d.transpose(), vmin=-9, vmax=9, cmap="coolwarm", shading="nearest"
    )
    cbar = fig.colorbar(ret, ax=ax)
    ax.set_xlabel(r"Space, $x$")
    ax.set_ylabel(r"Time, $t$")
plt.tight_layout()
plt.show()
del d, ts, xs, X, T, ret, data

## Define unet

### Define blocks

In [None]:
def Downsample1D(
    dim: int, kernel_size: int = 4, padding_mode: str = "zeros"
) -> nn.Module:
    """Build a strided ``nn.Conv1d`` that keeps ``dim`` channels.

    The stride is ``kernel_size // 2`` and the padding is fixed to 1.
    """
    stride = kernel_size // 2
    return nn.Conv1d(
        in_channels=dim,
        out_channels=dim,
        kernel_size=(kernel_size,),
        stride=(stride,),
        padding=(1,),
        padding_mode=padding_mode,
    )


def Upsample1D(dim: int, kernel_size: int = 4) -> nn.Module:
    """Build a strided ``nn.ConvTranspose1d`` that keeps ``dim`` channels.

    The stride is ``kernel_size // 2`` and the padding is fixed to 1.
    """
    stride = kernel_size // 2
    return nn.ConvTranspose1d(
        in_channels=dim,
        out_channels=dim,
        kernel_size=(kernel_size,),
        stride=(stride,),
        padding=(1,),
    )


def PeriodicDownsample1D(dim: int, kernel_size: int) -> nn.Module:
    """Build a stride-2 ``nn.Conv1d`` with circular padding that keeps ``dim`` channels.

    ``kernel_size`` must be odd; the padding is ``(kernel_size - 1) // 2``.
    """
    assert kernel_size % 2 == 1, "kernel_size should be odd."
    half_width = (kernel_size - 1) // 2
    return nn.Conv1d(
        in_channels=dim,
        out_channels=dim,
        kernel_size=(kernel_size,),
        stride=(2,),
        padding=(half_width,),
        padding_mode="circular",
    )


class PeriodicUpsampleConv1d(nn.Module):
    """Upsampling ("nearest-exact") followed by a circularly padded convolution.

    Args:
        in_ch: Number of input channels.
        kernel_size: Odd kernel size of the convolution.
        out_ch: Number of output channels; defaults to ``in_ch``.
        scale: Upsampling factor.
    """

    def __init__(
        self,
        in_ch: int,
        kernel_size: int,
        out_ch: Optional[int] = None,
        scale: int = 2,
    ):
        assert kernel_size % 2 == 1, "kernel_size should be odd."
        super().__init__()

        if out_ch is None:
            out_ch = in_ch

        self.scale = scale
        # Padding of (kernel_size - 1) // 2 keeps the length under the convolution.
        self.pad = (kernel_size - 1) // 2

        self.upsample = nn.Upsample(scale_factor=scale, mode="nearest-exact")
        self.conv = nn.Conv1d(
            in_ch,
            out_ch,
            kernel_size=kernel_size,
            padding=self.pad,
            padding_mode="circular",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` and then apply the convolution."""
        return self.conv(self.upsample(x))

In [None]:
class RMSNorm1D(nn.Module):
    """Normalisation over the channel axis with a learnable per-channel gain.

    The input is L2-normalised along dim 1, multiplied by ``sqrt(dim)`` and by
    the parameter ``gamma`` of shape ``(dim, 1)``, which is initialised to ones.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(dim, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize along the channel dimension
        unit = F.normalize(x, dim=1)
        return unit * self.scale * self.gamma


class FiLMBlock1D(nn.Module):
    """Convolution, RMS normalisation, optional scale/shift modulation, SiLU.

    Args:
        dim: Number of input channels.
        dim_out: Number of output channels.
        padding_mode: Padding mode of the kernel-size-3 convolution.
    """

    def __init__(self, dim: int, dim_out: int, padding_mode: str):
        super().__init__()
        self.proj = nn.Conv1d(
            dim, dim_out, kernel_size=3, padding=1, padding_mode=padding_mode
        )
        self.norm = RMSNorm1D(dim_out)
        self.act = nn.SiLU()

    def forward(
        self,
        x: torch.Tensor,
        scale_shift: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Apply the block; ``scale_shift`` modulates the normalised features
        as ``h * (scale + 1) + shift`` when given."""
        h = self.norm(self.proj(x))

        if scale_shift is None:
            return self.act(h)

        scale, shift = scale_shift
        return self.act(h * (scale + 1) + shift)


class ResnetBlock1D(nn.Module):
    """Two ``FiLMBlock1D`` layers with a residual connection.

    When ``time_emb_dim`` is given, a small MLP turns the time embedding into
    a (scale, shift) pair that modulates the first block only.

    Args:
        dim: Number of input channels.
        dim_out: Number of output channels.
        padding_mode: Padding mode handed to both blocks.
        time_emb_dim: Size of the time embedding, or ``None`` for no modulation.
    """

    def __init__(
        self,
        dim: int,
        dim_out: int,
        padding_mode: str,
        *,
        time_emb_dim: Optional[int] = None,
    ):
        super().__init__()
        if time_emb_dim is None:
            self.mlp = None
        else:
            # The output is split into scale and shift, hence dim_out * 2.
            self.mlp = nn.Sequential(
                nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)
            )

        self.block1 = FiLMBlock1D(dim, dim_out, padding_mode=padding_mode)
        self.block2 = FiLMBlock1D(dim_out, dim_out, padding_mode=padding_mode)

        # A 1x1 convolution matches the channel count on the skip path if needed.
        if dim == dim_out:
            self.res_conv = nn.Identity()
        else:
            self.res_conv = nn.Conv1d(dim, dim_out, kernel_size=1)

    def forward(
        self, x: torch.Tensor, time_emb: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Return ``block2(block1(x)) + res_conv(x)``."""
        scale_shift = None
        if self.mlp is not None:
            assert time_emb is not None
            emb: torch.Tensor = self.mlp(time_emb)
            # (b, 2 * dim_out) -> two tensors of shape (b, dim_out, 1)
            scale_shift = rearrange(emb, "b c -> b c 1").chunk(2, dim=1)

        residual = self.res_conv(x)
        h = self.block2(self.block1(x, scale_shift=scale_shift))
        return h + residual

In [None]:
class SinusoidalTimeEmbedding(nn.Module):
    """Sinusoidal embedding with geometrically spaced frequencies.

    ``dim // 2`` frequencies run from 1 down to ``1 / time_base``; the sines
    and cosines of ``x * frequency`` are concatenated along the last axis.

    Args:
        dim: Embedding size.
        time_base: Ratio between the largest and the smallest frequency.
    """

    def __init__(self, dim: int, time_base: float):
        super().__init__()
        self.dim = dim
        self.time_base = time_base
        logger.info(f"SinusoidalTimeEmbedding: {self.time_base=}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n_freqs = self.dim // 2
        log_base = torch.log(torch.tensor(self.time_base, device=x.device))
        step = log_base / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=x.device) * -step)

        # Outer product of the inputs and the frequencies.
        phases = x[:, None] * freqs[None, :]
        return torch.cat((phases.sin(), phases.cos()), dim=-1)

### Define network

In [None]:
class Unet1D(torch.nn.Module):
    """1-D U-Net conditioned on a diffusion-time index.

    Structure: initial convolution -> encoder stages (two residual blocks and
    a downsample each, except the last stage) -> two middle blocks -> decoder
    stages (skip concatenation, two residual blocks and an upsample each,
    except the last stage) -> final residual block and 1x1 convolution.

    Args:
        dim: Base channel width; stage ``i`` has ``dim * dim_mults[i]`` channels.
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
        padding_mode: ``"zeros"`` or ``"circular"``. With ``"circular"`` the
            periodic down/upsampling layers are used.
        dim_mults: Channel multipliers, one per resolution stage.
        init_dim: Channel width after the initial convolution; defaults to ``dim``.
        init_kernel_size: Kernel size of the initial convolution (must be odd).
        time_base: Passed to ``SinusoidalTimeEmbedding``.
    """

    def __init__(
        self,
        dim: int,
        in_channels: int,
        out_channels: int,
        padding_mode: Literal["zeros", "circular"] = "zeros",
        dim_mults: Sequence[int] = (1, 2, 4, 8),
        init_dim: Optional[int] = None,
        init_kernel_size: int = 5,
        time_base: float = 1000.0,
    ):
        super().__init__()

        init_dim = dim if init_dim is None else init_dim
        assert isinstance(init_dim, int)
        assert init_kernel_size % 2 == 1, "init kernel size must be odd"

        # "Same" padding so the initial convolution keeps the spatial length.
        init_padding = init_kernel_size // 2
        self.init_conv = nn.Conv1d(
            in_channels,
            init_dim,
            kernel_size=init_kernel_size,
            padding=init_padding,
            padding_mode=padding_mode,
        )

        # Channel widths per stage and the (in, out) pair of each stage.
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        time_dim = dim * 4
        self.time_mlp = nn.Sequential(
            SinusoidalTimeEmbedding(dim, time_base),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        self.downs: nn.ModuleList = nn.ModuleList([])
        self.ups: nn.ModuleList = nn.ModuleList([])

        num_resolutions = len(in_out)
        block_class = ResnetBlock1D
        block_class_cond = partial(
            block_class, time_emb_dim=time_dim, padding_mode=padding_mode
        )

        Downsample: Callable[[int], nn.Module] = Downsample1D
        if padding_mode == "circular":
            logger.info("PeriodicDownsample1D is used.")
            Downsample = partial(PeriodicDownsample1D, kernel_size=5)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(
                nn.ModuleList(
                    [
                        block_class_cond(dim_in, dim_out),
                        block_class_cond(dim_out, dim_out),
                        # The deepest stage keeps its resolution.
                        (Downsample(dim_out) if not is_last else nn.Identity()),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_class_cond(mid_dim, mid_dim)
        self.mid_block2 = block_class_cond(mid_dim, mid_dim)

        Upsample: Callable[[int], nn.Module] = Upsample1D
        if padding_mode == "circular":
            logger.info("PeriodicUpsampleConv1d is used.")
            Upsample = partial(PeriodicUpsampleConv1d, kernel_size=5)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind >= (num_resolutions - 1)
            self.ups.append(
                nn.ModuleList(
                    [
                        # Input is the current features concatenated with the
                        # encoder skip of the same stage, hence ``dim_out * 2``.
                        block_class_cond(dim_out * 2, dim_in),
                        block_class_cond(dim_in, dim_in),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        # The decoder output (init_dim channels) is concatenated with the output
        # of ``init_conv`` (init_dim channels), so the final block receives
        # ``init_dim * 2`` channels. Using ``dim * 2`` here only worked when
        # ``init_dim == dim``.
        self.final_conv = nn.Sequential(
            block_class(init_dim * 2, dim, padding_mode=padding_mode),
            nn.Conv1d(dim, out_channels, kernel_size=1),
        )

    def forward(
        self,
        yt: torch.Tensor,
        y_cond: torch.Tensor,  # not used
        t_index: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Run the network on a noisy state.

        Args:
            yt: Noisy input of shape (batch, in_channels, length).
            y_cond: Accepted for interface compatibility; ignored.
            t_index: Diffusion-time index per sample, shape (batch,).
            **kwargs: Extra keyword arguments are accepted and ignored.

        Returns:
            Tensor of shape (batch, out_channels, length).
        """
        yt = self.init_conv(yt)
        r = yt.clone()  # kept for the final skip connection
        t_index = self.time_mlp(t_index)  # from here on: time embedding (batch, time_dim)

        # Stack of encoder outputs, consumed in reverse order by the decoder.
        h: List[torch.Tensor] = []

        for downs in self.downs:
            assert isinstance(downs, nn.ModuleList)
            block1, block2, downsample = downs
            yt = block1(yt, t_index)
            yt = block2(yt, t_index)
            h.append(yt)
            yt = downsample(yt)

        yt = self.mid_block1(yt, t_index)
        yt = self.mid_block2(yt, t_index)

        for ups in self.ups:
            assert isinstance(ups, nn.ModuleList)
            block1, block2, upsample = ups
            yt = torch.cat((yt, h.pop()), dim=1)
            yt = block1(yt, t_index)
            yt = block2(yt, t_index)
            yt = upsample(yt)

        yt = torch.cat((yt, r), dim=1)

        return self.final_conv(yt)

## Prepare for training

### Define loss funcs

In [None]:
class CustomLoss(torch.nn.Module, metaclass=abc.ABCMeta):
    """Abstract base class of the loss modules used for training.

    Subclasses implement ``forward(predicts, targets, masks)`` and return the
    loss value. The class cannot be instantiated directly.
    """

    @abc.abstractmethod
    def forward(
        self,
        predicts: torch.Tensor,
        targets: torch.Tensor,
        masks: torch.Tensor,
    ):
        """Compute the loss between ``predicts`` and ``targets``."""
        raise NotImplementedError


def make_loss(loss_name: str) -> CustomLoss:
    """Create the loss module selected by ``loss_name``.

    Args:
        loss_name: ``"L2"`` or ``"L1"``.

    Returns:
        A new ``L2Loss`` or ``L1Loss`` instance.

    Raises:
        ValueError: If ``loss_name`` is anything else.
    """
    if loss_name == "L2":
        logger.info("L2 loss is created.")
        return L2Loss()

    if loss_name == "L1":
        logger.info("L1 loss is created.")
        return L1Loss()

    raise ValueError(f"{loss_name} is not supported.")


class L2Loss(CustomLoss):
    """Mean-squared-error loss (wraps ``nn.MSELoss`` with default settings)."""

    def __init__(self):
        super().__init__()
        self.loss = nn.MSELoss()

    def forward(
        self,
        predicts: torch.Tensor,
        targets: torch.Tensor,
        masks: torch.Tensor,
    ):
        """Return the MSE of ``predicts`` vs ``targets``; ``masks`` is ignored."""
        mse = self.loss(predicts, targets)
        return mse


class L1Loss(CustomLoss):
    """Mean-absolute-error loss (wraps ``nn.L1Loss`` with default settings)."""

    def __init__(self, **kwargs):
        # Keyword arguments are accepted but not used.
        super().__init__()
        self.loss = nn.L1Loss()

    def forward(
        self,
        predicts: torch.Tensor,
        targets: torch.Tensor,
        masks: torch.Tensor,
    ):
        """Return the MAE of ``predicts`` vs ``targets``; ``masks`` is ignored."""
        mae = self.loss(predicts, targets)
        return mae

### Construct diffusion model

In [None]:
# Fix the seeds first so that the network initialisation is reproducible.
set_seeds(42)

# The time axis of the data is fed to the network as the channel axis.
unet = Unet1D(
    dim=config.n_features,
    in_channels=32,  # number of time steps per sample
    out_channels=32,  # must equal in_channels: the net predicts noise of the input's shape
    padding_mode="circular",
    dim_mults=config.list_channel,
).to(DEVICE)

ddpm = DDPM(config=config.ddpm, neural_net=unet, device=torch.device(DEVICE))

loss_fn = make_loss(loss_name=config.loss_name)
optimizer = torch.optim.AdamW(ddpm.parameters(), lr=config.learning_rate)
# Gradient scaler used by ``optimize_ddpm`` for mixed-precision training.
scaler = GradScaler()

## Train

### Define training method

In [None]:
class AverageMeter:
    """Keeps the most recent value and the running weighted average.

    Attributes:
        val: Value passed to the latest ``update`` call.
        sum: Weighted sum of all values seen since the last ``reset``.
        count: Total weight seen since the last ``reset``.
        avg: ``sum / count``.
    """

    def __init__(self):
        super().__init__()
        self.reset()

    def reset(self):
        """Clear all statistics."""
        self.val, self.avg, self.sum = 0.0, 0.0, 0.0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def optimize_ddpm(
    *,
    dataloader: DataLoader,
    ddpm: DDPM,
    loss_fn: CustomLoss,
    optimizer: Optimizer,
    epoch: int,
    mode: typing.Literal["train", "valid", "test"],
    scaler: GradScaler,
    use_amp: bool,
) -> float:
    """Run one pass over ``dataloader`` and return the sample-weighted mean loss.

    In ``"train"`` mode the network is put in training mode and one optimizer
    step is taken per batch (through ``scaler``). In ``"valid"``/``"test"``
    mode the network is put in eval mode and gradients are disabled.

    Args:
        dataloader: Yields dict batches whose values are tensors; each batch is
            moved to the model's device and passed to ``ddpm`` as keyword arguments.
        ddpm: Model whose call returns ``(noise, noise_hat)``.
        loss_fn: Called as ``loss_fn(predicts=noise_hat, targets=noise, masks=None)``.
        optimizer: Optimizer stepped in train mode.
        epoch: Used to seed ``random`` and ``numpy.random`` for this pass.
        mode: ``"train"``, ``"valid"`` or ``"test"``.
        scaler: Gradient scaler used for the backward pass and optimizer step.
        use_amp: Enables float16 autocast.

    Returns:
        Average loss over all samples seen (0.0 if the loader is empty).

    Raises:
        ValueError: If ``mode`` is not supported.
    """
    meter = AverageMeter()

    device = str(next(ddpm.net.parameters()).device)

    is_training = mode == "train"
    if is_training:
        ddpm.net.train()
    elif mode in ["valid", "test"]:
        ddpm.net.eval()
    else:
        raise ValueError(f"{mode} is not supported.")

    # Re-seed per epoch so any Python/NumPy randomness is reproducible per epoch.
    random.seed(epoch)
    np.random.seed(epoch)

    autocast_options = dict(
        device_type="cuda" if "cuda" in device else "cpu",
        dtype=torch.float16,
        enabled=use_amp,
    )

    def _loss_and_batch_size(inputs):
        # Forward pass and loss evaluation under (optional) autocast.
        with torch.autocast(**autocast_options):
            noise, noise_hat = ddpm(**inputs)
            value = loss_fn(predicts=noise_hat, targets=noise, masks=None)
        return value, noise.shape[0]

    for batch in dataloader:
        for key, tensor in batch.items():
            batch[key] = tensor.to(device, non_blocking=True)

        if is_training:
            optimizer.zero_grad()
            loss, n_samples = _loss_and_batch_size(batch)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            with torch.no_grad():
                loss, n_samples = _loss_and_batch_size(batch)

        meter.update(loss.item(), n=n_samples)

    return meter.avg

### Run training

In [None]:
# Train for ``config.total_epochs`` epochs, periodically saving the network
# weights and the loss history under ``DL_RESULT_DIR``.
start_time = time.time()

all_scores: list[dict] = []
set_seeds(42)

with tqdm(total=config.total_epochs, desc="Training Progress", unit="step") as pbar:
    for _epoch in range(config.total_epochs):
        epoch = _epoch + 1  # range() starts at 0, so add 1 to get a 1-based epoch

        loss = optimize_ddpm(
            dataloader=dataloader,
            ddpm=ddpm,
            loss_fn=loss_fn,
            optimizer=optimizer,
            epoch=epoch,
            mode="train",
            scaler=scaler,
            use_amp=config.use_auto_mix_precision,
        )
        all_scores.append({"epoch": epoch, "loss": loss})

        # Checkpoint the network weights every ``save_interval`` epochs.
        if epoch % config.save_interval == 0:
            p = f"{DL_RESULT_DIR}/model_weight_{epoch:06}.pth"
            torch.save(ddpm.net.state_dict(), p)

        # Rewrite the full loss history every 10 epochs and after the last one.
        if epoch % 10 == 0 or epoch == config.total_epochs:
            p = f"{DL_RESULT_DIR}/loss_history.csv"
            pd.DataFrame(all_scores).to_csv(p, index=False)

        pbar.set_postfix({"loss": loss})
        pbar.update(1)

end_time = time.time()
logger.info(f"Finished. Total elapsed time = {(end_time - start_time) / 60.} min")
# Drop loop variables from the notebook namespace.
# NOTE(review): this raises NameError if ``config.total_epochs`` is 0.
del epoch, _epoch, loss, all_scores

# Test

In [None]:
# Load the weights saved at the final epoch and switch the network to eval mode.
# NOTE(review): this file exists only if ``total_epochs`` is a multiple of
# ``save_interval`` (see the training loop) — confirm the config.
epoch = config.total_epochs
p = f"{DL_RESULT_DIR}/model_weight_{epoch:06}.pth"
unet.load_state_dict(torch.load(p, map_location=DEVICE, weights_only=False))
_ = unet.eval()
del epoch, p

## Run sampling

In [None]:
set_seeds(42)
dict_samples = ddpm.backward_sample_y(n_batches=50, n_return_steps=1000)
# n_return_steps == 1000 asks for the intermediate states that divide the
# 1000 diffusion steps into 1000 equal parts,
# i.e. the state is returned after every single step.

## Plot generated samples

In [None]:
# Show the first three generated samples as space-time colour maps.
last_samples = dict_samples[0]  # at t == 0

plt.rcParams["font.size"] = 14
fig, axes = plt.subplots(1, 3, sharex=True, sharey=True, figsize=[10, 4])

for i, ax in enumerate(axes.flatten()):
    d = last_samples[i].cpu().numpy()
    # Undo the dataset's standardisation before plotting.
    d = dataloader.dataset.standardize_inversely(d)
    ts = np.arange(d.shape[0]) * 0.2  # dt = 0.2
    xs = np.linspace(0, 2 * math.pi, 32, endpoint=False)
    # "ij" indexing gives grids of shape (space, time), hence the transpose below.
    X, T = np.meshgrid(xs, ts, indexing="ij")
    ret = ax.pcolormesh(
        X, T, d.transpose(), vmin=-9, vmax=9, cmap="coolwarm", shading="nearest"
    )
    cbar = fig.colorbar(ret, ax=ax)
    ax.set_xlabel(r"Space, $x$")
    ax.set_ylabel(r"Time, $t$")
plt.tight_layout()
plt.show()
del d, ts, xs, X, T, ret

## Plot intermediate states

- ノイズからデータの生成に成功している
- これは非ガウスのデータ分布の学習に成功していることを意味する
- 時間ステップが 0 に近づくと急にデータに空間構造が現れるのは，指数的に減衰する緩和の逆回しをしているため

In [None]:
# Show the first sample at selected diffusion steps, from pure noise (1000)
# down to the final generated state (0).
plt.rcParams["font.size"] = 12
fig, axes = plt.subplots(2, 6, sharex=True, sharey=True, figsize=(15, 6))

for t, ax in zip([1000, 500, 200, 100, 80, 60, 50, 40, 30, 20, 10, 0], axes.flatten()):
    d = dict_samples[t][0].cpu().numpy()
    # Undo the dataset's standardisation before plotting.
    d = dataloader.dataset.standardize_inversely(d)

    ts = np.arange(d.shape[0]) * 0.2  # dt = 0.2
    xs = np.linspace(0, 2 * math.pi, 32, endpoint=False)
    # "ij" indexing gives grids of shape (space, time), hence the transpose below.
    X, T = np.meshgrid(xs, ts, indexing="ij")

    ret = ax.pcolormesh(
        X, T, d.transpose(), vmin=-9, vmax=9, cmap="coolwarm", shading="nearest"
    )

    cbar = fig.colorbar(ret, ax=ax)
    ax.set_xlabel(r"Space, $x$")
    ax.set_ylabel(r"Time, $t$")
    ax.set_title(f"Diffusion Step = {t}")
plt.tight_layout()
plt.show()
del d, ts, xs, X, T, ret

## Run spectral analysis

- Lorenz96 はカオス系なので，正解データに対して厳密な一致を確認するのは難しい
- そこで，時空間スペクトルを比較し，訓練データと同様の時空間構造が生成されたデータに表れることを確認する

In [None]:
import torch.fft
from scipy.signal import welch


def compute_1d_psds_for_lorenz96(samples: np.ndarray, dt: float, dx: float):
    """Compute batch-averaged 1-D power spectral densities along time and space.

    Welch's method is applied with a single Hamming-windowed segment that
    spans the whole axis, after removing the mean along that axis.

    Args:
        samples: Array of shape (n_batches, n_times, n_spaces).
        dt: Sampling interval along the time axis.
        dx: Sampling interval along the space axis.

    Returns:
        Tuple ``(freqs_time, psd_time_mean, freqs_space, psd_space_mean)``.
        ``psd_time_mean`` is averaged over batch and space; ``psd_space_mean``
        is averaged over batch and time.
    """
    assert isinstance(samples, np.ndarray)
    _, n_times, n_spaces = samples.shape

    # Options shared by both directions.
    welch_options = dict(detrend="constant", scaling="density", window="hamming")

    # Time direction (axis 1)
    freqs_time, psd_along_time = welch(
        samples, fs=1 / dt, axis=1, nperseg=n_times, **welch_options
    )

    # Space direction (axis 2)
    freqs_space, psd_along_space = welch(
        samples, fs=1 / dx, axis=2, nperseg=n_spaces, **welch_options
    )

    return (
        freqs_time,
        psd_along_time.mean(axis=(0, 2)),
        freqs_space,
        psd_along_space.mean(axis=(0, 1)),
    )


def compute_2d_psd_for_lorenz96(data: torch.Tensor) -> torch.Tensor:
    """Compute the batch-averaged 2-D power spectrum over the last two axes.

    Args:
        data: Tensor of shape (n_batches, n_times, n_spaces).

    Returns:
        Tensor of shape (n_times, n_spaces): squared magnitude of the
        orthonormal 2-D FFT, with the zero frequency shifted to the centre,
        averaged over the batch axis.
    """
    assert isinstance(data, torch.Tensor) and data.ndim == 3

    # Orthonormal 2-D FFT over (time, space).
    spectrum = torch.fft.fft2(data, dim=(-2, -1), norm="ortho")

    # Move the zero-frequency bin to the centre of both axes.
    centered = torch.fft.fftshift(spectrum, dim=(-2, -1))

    # Power, then mean over the batch.
    power = torch.abs(centered) ** 2
    return power.mean(dim=0)

In [None]:
# Reference data: the first 50 dataset items (their "y0" entries), stacked
# along a new batch axis.
gt_samples = torch.stack(
    [dataloader.dataset[i]["y0"] for i in range(50)], dim=0
).numpy()
# Generated data: the 50 samples at the end of the reverse process.
my_samples = dict_samples[0].cpu().numpy()  # at t == 0

# Both sets must have identical shapes for a like-for-like spectral comparison.
assert gt_samples.shape == my_samples.shape
assert isinstance(gt_samples, np.ndarray) and isinstance(my_samples, np.ndarray)

In [None]:
# Grid spacings are taken as the mean increment of the dataset's coordinates.
xs = dataloader.dataset.data["space"].values
ts = dataloader.dataset.data["time"].values
dt = float(np.mean(np.diff(ts.astype(np.float64))))
dx = float(np.mean(np.diff(xs.astype(np.float64))))

# 1-D spectra of the reference data (frequency axes are taken from the next call).
_, gt_psd_time_mean, _, gt_psd_space_mean = compute_1d_psds_for_lorenz96(
    gt_samples, dt=dt, dx=dx
)

# 1-D spectra of the generated data.
freqs_time, my_psd_time_mean, freqs_space, my_psd_space_mean = (
    compute_1d_psds_for_lorenz96(my_samples, dt=dt, dx=dx)
)

# 2-D (time x space) spectra of both data sets.
gt_psd_2d_mean = compute_2d_psd_for_lorenz96(torch.from_numpy(gt_samples)).numpy()
my_psd_2d_mean = compute_2d_psd_for_lorenz96(torch.from_numpy(my_samples)).numpy()

del xs, ts

In [None]:
# Frequency axes for the 2-D spectra. ``compute_2d_psd_for_lorenz96`` returns
# an fftshift-ed spectrum, so the matching coordinates are the fftshift-ed FFT
# frequencies (for n = 32: bins -16 ... +15, zero frequency at index 16).
# The previous hand-built axes covered bins -15 ... +16, which drew every
# spectral bin one step away from its true frequency.
n_times, n_spaces = gt_psd_2d_mean.shape
xs = np.fft.fftshift(np.fft.fftfreq(n_spaces, d=dx))
ys = np.fft.fftshift(np.fft.fftfreq(n_times, d=dt))

xs, ys = np.meshgrid(xs, ys, indexing="ij")
vmin, vmax = None, None

plt.rcParams["font.size"] = 14
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

for ax, kind, data in zip(
    axes, ["Physics Simulation", "Langevin Sampling"], [gt_psd_2d_mean, my_psd_2d_mean]
):
    d = data.transpose()  # time, space --> space, time
    log_psd = np.log10(d)

    # The colour range is fixed by the first panel (physics simulation) and
    # reused for the second so that both panels are directly comparable.
    if vmin is None or vmax is None:
        vmin, vmax = np.quantile(log_psd, [0.01, 0.99])
    pm = ax.pcolormesh(xs, ys, log_psd, shading="nearest", vmin=vmin, vmax=vmax)
    ax.set_xlabel(r"Wavenumber, $k$")
    ax.set_ylabel(r"Frequency, $\omega$")

    ax.set_title(kind)
    ax.axvline(0, ls="--", color="k", alpha=0.5)
    ax.axhline(0, ls="--", color="k", alpha=0.5)
    fig.colorbar(pm, ax=ax, label=r"$\log_{10}(\mathrm{PSD})$", extend="both")
plt.tight_layout()
plt.show()
del n_times, n_spaces, log_psd

# 1-D spectra: the physics simulation (black) against the generated samples at
# a sequence of diffusion steps (coloured, dashed).
plt.rcParams["font.size"] = 16
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

for kind in ["Physics Simulation", "Langevin Sampling"]:

    # 26 diffusion steps (0, 20, ..., 480, 1000), each paired with one colour.
    for t, cmap in zip(
        reversed([1000] + list(np.arange(480, 0, -20)) + [0]),
        plt.cm.turbo(np.linspace(0, 1, 28))[1:-1],
    ):
        # The reference data exists only at diffusion step 0.
        if kind == "Physics Simulation" and t != 0:
            continue

        ls, label, lw = "--", None, 1.5
        if kind == "Physics Simulation":
            lw = 4.0
            cmap = "k"
            ls = "-"
            label = r"Phys. Sim. ${\rm Diffusion Time}=0.0$"
        else:
            # Only the two end points of the reverse process get a legend entry.
            if t == 1000:
                lw = 4.0
                label = r"Langevin ${\rm Diffusion Time}=1.0$"
            elif t == 0:
                lw = 4.0
                label = r"Langevin ${\rm Diffusion Time}=0.0$"

            _freqs_time, psd_time_mean, _freqs_space, psd_space_mean = (
                compute_1d_psds_for_lorenz96(
                    dict_samples[t].squeeze().cpu().numpy(), dt=dt, dx=dx
                )
            )
            assert np.all(_freqs_space == freqs_space)
            assert np.all(_freqs_time == freqs_time)

        d = gt_psd_time_mean if kind == "Physics Simulation" else psd_time_mean
        axes[0].plot(freqs_time, d, lw=lw, color=cmap, ls=ls, label=label)

        d = gt_psd_space_mean if kind == "Physics Simulation" else psd_space_mean
        axes[1].plot(freqs_space, d, lw=lw, color=cmap, ls=ls, label=label)

# Axis formatting is loop-invariant, so it is applied once after all curves
# have been drawn (the legend needs the labelled curves to exist).
axes[0].set_title("PSD in Time Direction")
axes[0].set_xlabel(r"Frequency, $\omega$ (1/time)")
axes[0].set_ylabel("Power Spectral Density")
axes[0].set_xscale("log")
axes[0].set_yscale("log")
axes[0].legend(loc="lower left", fontsize=15)

axes[1].set_title("PSD in Space Direction")
axes[1].set_xlabel(r"Wavenumber, $k$ (1/space)")
axes[1].set_ylabel("Power Spectral Density")
axes[1].set_xscale("log")
axes[1].set_yscale("log")

plt.tight_layout()
plt.show()

del xs, ys, d
del _freqs_time, psd_time_mean, _freqs_space, psd_space_mean

- ノイズによりエネルギーが注入され，スペクトルが平坦に近づく
- 逆過程では，その平坦なスペクトルから始めて，段々と時空間構造を復元する
- ノイズは，全ての格子点で独立に作用し，データの時空間構造を一切加味することなく破壊する
- そのため，エネルギーが平均より小さいスケールでエネルギーの注入が起こり，逆にエネルギーが平均より大きいスケールでエネルギーの減衰が起こる
- この注入と減衰は，スケールの大きさではなくて，エネルギーの大きさで決定される
- 例えば，空間スペクトルを見ると，$k$ の小さな領域と大きな領域でエネルギーが注入されている

# Run Score-based DA

- ここでは観測データを正解にノイズを加えて作る
- そしてこの観測データを条件とする生成を行う
  - この条件付き生成により観測データの同化が実現される
- なお簡単のため共分散行列は対角としている
  - そのため，欠損のない観測データを考えている

## Define sda class

In [None]:
class ScoreBasedDA(DDPM):
    """Score-based data assimilation built on a trained ``DDPM``.

    At every reverse step the gradient of a Gaussian observation
    log-likelihood is added to the learned score, so the generated samples
    are conditioned on the observations. Observation entries that are NaN are
    treated as missing and contribute no gradient.

    NOTE(review): this class relies on members inherited from ``DDPM``
    (``decays``, ``stds``, ``frictions``, ``sigmas``, ``times``, ``dt``,
    ``sqrt_dt``, ``dtype``, ``device``, ``c``, ``net``, ``_extract_params``)
    whose definitions are not shown here; their meaning is assumed from usage.
    """

    def __init__(
        self,
        config: DDPMConfig,
        neural_net: torch.nn.Module,
        device: torch.device = torch.device("cpu"),
    ):
        super().__init__(config=config, neural_net=neural_net, device=device)

    @torch.no_grad()
    def _get_mean_for_likelihood(
        self, *, yt: torch.Tensor, t_index: torch.Tensor, score: torch.Tensor
    ) -> torch.Tensor:
        """Return ``(yt + std**2 * score) / decay``, the mean used by the likelihood."""
        decay = self._extract_params(self.decays, t_index)
        std = self._extract_params(self.stds, t_index)

        return (yt + (std**2) * score) / decay

    @torch.no_grad()
    def _get_var_for_likelihood(
        self, *, t_index: torch.Tensor, dsdx: float, std_for_obs: float
    ) -> torch.Tensor:
        """Return the variance used by the likelihood.

        Computed as ``std_for_obs**2 + (std**2 + std**4 * dsdx) / decay**2``,
        where ``dsdx`` is a scalar approximation of the score derivative.
        """
        decay = self._extract_params(self.decays, t_index)
        std = self._extract_params(self.stds, t_index)

        variance = std_for_obs**2 + (std**2 + (std**4) * dsdx) / (decay**2)

        return variance

    @torch.no_grad()
    def _get_derivative_of_likelihood(
        self,
        *,
        yt: torch.Tensor,
        t_index: torch.Tensor,
        dsdx: float,
        obs: torch.Tensor,
        std_for_obs: float,
        score: torch.Tensor,
    ):
        """Return the likelihood term added to the score.

        The value is ``(obs - mean) / var / decay``; positions where ``obs``
        is NaN are set to zero so that missing observations have no effect.
        """
        mean = self._get_mean_for_likelihood(yt=yt, t_index=t_index, score=score)
        var = self._get_var_for_likelihood(
            t_index=t_index, dsdx=dsdx, std_for_obs=std_for_obs
        )
        decay = self._extract_params(self.decays, t_index)

        o = obs.to(mean.device)
        derivatives = (o - mean) / var / decay
        masked = torch.where(torch.isnan(o), torch.zeros_like(o), derivatives)

        return masked

    @torch.no_grad()
    def _backward_sample_y_with_assimilation(
        self,
        *,
        yt: torch.Tensor,
        t_index: torch.Tensor,
        dsdx: float,
        obs: torch.Tensor,
        std_for_obs: float,
    ) -> torch.Tensor:
        """Advance the reverse process by one step, including the observation term.

        The drift uses ``score + dldx`` in place of the plain score; noise is
        added at every step except ``t_index == 0``.
        """
        friction = self._extract_params(self.frictions, t_index)
        sigma = self._extract_params(self.sigmas, t_index)
        std = self._extract_params(self.stds, t_index)
        t = self._extract_params(self.times, t_index, for_broadcast=False)
        t = t[:, None]  # add channel dim

        # The network predicts the noise; convert it to a score.
        est_noise = self.net(yt=yt, t=t, t_index=t_index, y_cond=None)
        score = -est_noise / std
        dldx = self._get_derivative_of_likelihood(
            yt=yt,
            t_index=t_index,
            dsdx=dsdx,
            obs=obs,
            std_for_obs=std_for_obs,
            score=score,
        )

        mean = yt + self.dt * (friction * yt + (sigma**2) * (score + dldx))
        dW = self.sqrt_dt * torch.randn_like(yt)

        n_batches = yt.shape[0]
        # no noise at t_index == 0
        mask = (1 - (t_index == 0).float()).reshape(n_batches, *((1,) * (yt.ndim - 1)))
        mask = mask.to(dtype=self.dtype, device=self.device)

        return mean + mask * sigma * dW

    # public method

    @torch.no_grad()
    def assimilate(
        self,
        *,
        n_batches: int,
        derivative_score: float,
        observations: torch.Tensor,
        std_for_observations: float,
        n_return_steps: Optional[int] = None,
        tqdm_disable: bool = False,
    ):
        """Generate samples conditioned on ``observations``.

        Args:
            n_batches: Number of samples to generate.
            derivative_score: Scalar passed as ``dsdx`` to the variance term.
            observations: Tensor of shape (n_channels, n_spaces); NaN entries
                are treated as missing.
            std_for_observations: Standard deviation of the observation error.
            n_return_steps: If given, intermediate states are stored every
                ``n_timesteps // n_return_steps`` steps. If ``None``, only the
                final state is stored.
            tqdm_disable: Disables the progress bar.

        Returns:
            Dict mapping a diffusion-step index to the state (on CPU) at that
            step; key ``0`` always holds the final samples.

        Raises:
            ValueError: If ``n_return_steps`` is not in ``[1, n_timesteps]``.
        """
        assert not self.net.training
        assert observations.shape == (self.c.n_channels, self.c.n_spaces)

        size = (n_batches, self.c.n_channels, self.c.n_spaces)
        yt = torch.randn(size=size, device=self.device)
        yt = self.stds[-1] * yt

        # ``interval`` must exist even when no intermediate states are requested;
        # previously it was unbound in that case and the loop below crashed.
        interval: Optional[int] = None
        if n_return_steps is not None:
            if not 1 <= n_return_steps <= self.c.n_timesteps:
                raise ValueError(
                    "n_return_steps must be between 1 and "
                    f"{self.c.n_timesteps}, but got {n_return_steps}."
                )
            interval = self.c.n_timesteps // n_return_steps

        # Move the observations to the sampling device once, not at every step.
        observations = observations.to(self.device)

        intermediates: dict[int, torch.Tensor] = {}

        for i in tqdm(
            reversed(range(0, self.c.n_timesteps)),
            total=self.c.n_timesteps,
            disable=tqdm_disable,
        ):
            # Store the state *before* the step from index i + 1 to i.
            if interval is not None and (i + 1) % interval == 0:
                intermediates[i + 1] = yt.detach().clone().cpu()

            index = torch.full((n_batches,), i, device=self.device, dtype=torch.long)
            yt = self._backward_sample_y_with_assimilation(
                yt=yt,
                t_index=index,
                dsdx=derivative_score,
                obs=observations,
                std_for_obs=std_for_observations,
            )

        intermediates[0] = yt.detach().clone().cpu()

        return intermediates

## Make a sda instance and load weights

In [None]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# The assimilation wrapper reuses the same ``unet`` object as the DDPM above.
sda = ScoreBasedDA(
    config=config.ddpm,
    neural_net=unet,
    device=torch.device(DEVICE),
)

# Load the weights saved at the final epoch and switch the network to eval
# mode (``sda.assimilate`` asserts that the network is not in training mode).
epoch = config.total_epochs
p = f"{DL_RESULT_DIR}/model_weight_{epoch:06}.pth"
unet.load_state_dict(torch.load(p, map_location=DEVICE, weights_only=False))
_ = unet.eval()
del epoch, p

## Make observations

- ノイズは測定誤差を模倣したもの

In [None]:
# The last dataset item serves as the ground truth.
ground_truth = dataloader.dataset[-1]["y0"]
# This NaN tensor only serves as a shape/dtype template for ``randn_like``
# below; it is overwritten immediately.
obs = torch.full_like(ground_truth, torch.nan)
# Noise amplitude: 10 % of the mean absolute value of the ground truth.
obs = (
    ground_truth + torch.randn_like(obs) * ground_truth.abs().mean() * 0.1
)  # add measurement noise

plt.rcParams["font.size"] = 14
fig, axes = plt.subplots(1, 2, sharex=True, sharey=True, figsize=[10, 4])

for ax, data, ttl in zip(
    axes.flatten(), [ground_truth, obs], ["ground truth", "observations"]
):
    d = data.cpu().numpy()
    # Undo the dataset's standardisation before plotting.
    d = dataloader.dataset.standardize_inversely(d)
    ts = np.arange(d.shape[0]) * 0.2  # dt = 0.2
    xs = np.linspace(0, 2 * math.pi, 32, endpoint=False)
    # "ij" indexing gives grids of shape (space, time), hence the transpose below.
    X, T = np.meshgrid(xs, ts, indexing="ij")
    ret = ax.pcolormesh(
        X, T, d.transpose(), vmin=-9, vmax=9, cmap="coolwarm", shading="nearest"
    )
    cbar = fig.colorbar(ret, ax=ax)
    ax.set_xlabel(r"Space, $x$")
    ax.set_ylabel(r"Time, $t$")
    ax.set_title(ttl)
plt.tight_layout()
plt.show()
del d, ts, xs, X, T, ret

## Perform assimilation with observations

In [None]:
set_seeds(42)
dict_samples_with_obs = sda.assimilate(
    n_batches=10,
    derivative_score=0.0,
    observations=obs,
    std_for_observations=0.1,
    n_return_steps=1000,
)
# n_return_steps == 1000 asks for the intermediate states that divide the
# 1000 diffusion steps into 1000 equal parts,
# i.e. the state is returned after every single step.

In [None]:
# Top row: ground truth and three assimilated samples.
# Bottom row: the observations and the differences "GT - sample".
last_samples = dict_samples_with_obs[0]  # at t == 0

plt.rcParams["font.size"] = 12
fig, axes = plt.subplots(2, 4, sharex=True, sharey=True, figsize=[15, 6])

gt = None
for i in range(axes.shape[1]):
    if i == 0:
        d = ground_truth.cpu().numpy()
        ttl = "Ground Truth (GT)"
    else:
        d = last_samples[i].cpu().numpy()
        ttl = f"Sample {i}"
    # Undo the standardisation, then transpose: (time, space) -> (space, time).
    d = (dataloader.dataset.standardize_inversely(d)).transpose()

    # ``d`` is (space, time) here, so the space grid follows axis 0 and the
    # time grid follows axis 1 (previously the time grid was sized from axis 0
    # and the space grid was hard-coded to 32 points).
    xs = np.linspace(0, 2 * math.pi, d.shape[0], endpoint=False)
    ts = np.arange(d.shape[1]) * 0.2  # dt = 0.2
    X, T = np.meshgrid(xs, ts, indexing="ij")

    ax = axes[0, i]
    ret = ax.pcolormesh(X, T, d, vmin=-10, vmax=10, cmap="coolwarm", shading="nearest")
    cbar = fig.colorbar(ret, ax=ax)
    ax.set_xlabel(r"Space, $x$")
    ax.set_ylabel(r"Time, $t$")
    ax.set_title(ttl)

    if i == 0:
        gt = d.copy()
        # Same order as for the other panels: inverse-standardise in the
        # original (time, space) layout first, then transpose.
        d = dataloader.dataset.standardize_inversely(obs.cpu().numpy()).transpose()
        cmap = "coolwarm"
        vmin, vmax = -10, 10
        ttl = "Observations"
    else:
        d = gt - d
        cmap = "PiYG"
        vmin, vmax = -5, 5
        ttl = f"GT - Sample {i}"

    ax = axes[1, i]
    ret = ax.pcolormesh(X, T, d, vmin=vmin, vmax=vmax, cmap=cmap, shading="nearest")
    cbar = fig.colorbar(ret, ax=ax)
    ax.set_xlabel(r"Space, $x$")
    ax.set_ylabel(r"Time, $t$")
    ax.set_title(ttl)

plt.tight_layout()
plt.show()
del gt, d, ts, xs, X, T, ret, last_samples

- 条件付き生成により，正解と同様のデータが生成できている

## Case without observations

In [None]:
set_seeds(42)
dict_samples_without_obs = sda.assimilate(
    n_batches=10,
    derivative_score=0.0,
    observations=torch.full_like(obs, fill_value=torch.nan),  # every observation is missing
    std_for_observations=0.1,
    n_return_steps=1000,
)
# n_return_steps == 1000 asks for the intermediate states that divide the
# 1000 diffusion steps into 1000 equal parts,
# i.e. the state is returned after every single step.

In [None]:
# Top row: ground truth and three samples generated without observations.
# Bottom row: the observations (shown for reference) and "GT - sample".
last_samples = dict_samples_without_obs[0]  # at t == 0

plt.rcParams["font.size"] = 12
fig, axes = plt.subplots(2, 4, sharex=True, sharey=True, figsize=[15, 6])

gt = None
for i in range(axes.shape[1]):
    if i == 0:
        d = ground_truth.cpu().numpy()
        ttl = "Ground Truth (GT)"
    else:
        d = last_samples[i].cpu().numpy()
        ttl = f"Sample {i}"
    # Undo the standardisation, then transpose: (time, space) -> (space, time).
    d = (dataloader.dataset.standardize_inversely(d)).transpose()

    # ``d`` is (space, time) here, so the space grid follows axis 0 and the
    # time grid follows axis 1 (previously the time grid was sized from axis 0
    # and the space grid was hard-coded to 32 points).
    xs = np.linspace(0, 2 * math.pi, d.shape[0], endpoint=False)
    ts = np.arange(d.shape[1]) * 0.2  # dt = 0.2
    X, T = np.meshgrid(xs, ts, indexing="ij")

    ax = axes[0, i]
    ret = ax.pcolormesh(X, T, d, vmin=-10, vmax=10, cmap="coolwarm", shading="nearest")
    cbar = fig.colorbar(ret, ax=ax)
    ax.set_xlabel(r"Space, $x$")
    ax.set_ylabel(r"Time, $t$")
    ax.set_title(ttl)

    if i == 0:
        gt = d.copy()
        # Same order as for the other panels: inverse-standardise in the
        # original (time, space) layout first, then transpose.
        d = dataloader.dataset.standardize_inversely(obs.cpu().numpy()).transpose()
        cmap = "coolwarm"
        vmin, vmax = -10, 10
        ttl = "Observations"
    else:
        d = gt - d
        cmap = "PiYG"
        vmin, vmax = -5, 5
        ttl = f"GT - Sample {i}"

    ax = axes[1, i]
    ret = ax.pcolormesh(X, T, d, vmin=vmin, vmax=vmax, cmap=cmap, shading="nearest")
    cbar = fig.colorbar(ret, ax=ax)
    ax.set_xlabel(r"Space, $x$")
    ax.set_ylabel(r"Time, $t$")
    ax.set_title(ttl)

plt.tight_layout()
plt.show()
del gt, d, ts, xs, X, T, ret, last_samples

- 観測データがない場合 (つまり無条件の生成の場合)，得られるサンプルは正解とは大きく異なる