In [None]:
# Enable the autoreload extension; mode 2 reloads all imported modules
# before each cell is executed, so edits under src/ are picked up.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Import

In [None]:
import sys
from logging import INFO, StreamHandler, getLogger

# Root logger at INFO level. A stdout handler is attached only when no
# handler exists yet, so re-running this cell does not duplicate messages.
logger = getLogger()
logger.setLevel(INFO)
if not logger.hasHandlers():
    stdout_handler = StreamHandler(sys.stdout)
    logger.addHandler(stdout_handler)

In [None]:
import os
import time
import math
import pathlib

import torch
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from torch.cuda.amp import GradScaler

from src.utils.random_seed_helper import set_seeds, seed_worker, get_torch_generator
from src.utils.psd import compute_2d_psd_for_lorenz96, compute_1d_psds_for_lorenz96
from src.frameworks.ddpm import DDPM, DDPMConfig
from src.datasets.dataset_lorenz96 import DatasetLorenz96
from src.configs.lorenz96_config import ExperimentLorenz96Config
from src.neural_networks.unet_1d import Unet1D
from src.training.loss_maker import make_loss
from src.training.optim_helper import optimize_ddpm

# Global matplotlib defaults: serif font family and the
# "tableau-colorblind10" style sheet.
plt.rcParams.update({"font.family": "serif"})
plt.style.use("tableau-colorblind10")

# to make calculations deterministic
os.environ.update({"CUBLAS_WORKSPACE_CONFIG": r":4096:8"})

# Define constants

In [None]:
DEVICE = "cuda:0"
ROOT_DIR = pathlib.Path(os.environ["PYTHONPATH"].split(":")[0]).resolve()
DL_DATA_FILE = str(ROOT_DIR / "data" / "DL_data" / "lorenz96" / "lorenz96_v00.nc")
DL_RESULT_DIR = str(ROOT_DIR / "data" / "DL_model" / "lorenz96_v00")
os.makedirs(DL_RESULT_DIR, exist_ok=True)

# Make config

- 分散保存型 (Variance-Preserving; VP) の拡散モデルを扱う．
- 順過程は下の様に書ける
$$
\begin{align}
dx_t &= -\frac{1}{2} \beta_t x_t \; dt + \sqrt{\beta_t} \; dW_t \quad (t \in [0,1]) \\
\beta_t &= \beta_{\rm start} + t \, (\beta_{\rm end} - \beta_{\rm start})
\end{align}
$$

In [None]:
# Settings of the diffusion process (DDPM).
ddpm_config = DDPMConfig(
    start_beta=1e1,  # start_beta == end_beta makes beta a constant
    end_beta=1e1,
    n_timesteps=1_000,
    n_channels=32,  # the number of data time steps is given as the channel count
    n_spaces=32,  # number of spatial grid points. dont change n_channels and n_spaces
)

# Whole experiment configuration: training, U-Net and DDPM settings.
config = ExperimentLorenz96Config(
    # Training settings
    batch_size=100,
    loss_name="L2",
    learning_rate=1e-3,
    total_epochs=40,
    save_interval=4,
    use_auto_mix_precision=False,
    # For U-Net
    n_features=32,
    list_channel=[1, 2, 4],
    # For DDPM
    ddpm=ddpm_config,
)

# Store the configuration in the result directory.
config_path = f"{DL_RESULT_DIR}/config.yml"
config.save(config_path)

In [None]:
# You can load your saved config.
# config = ExperimentLorenz96Config.load(f"{DL_RESULT_DIR}/config.yml")
# config

# Make dataloader

In [None]:
# Dataset built from the data file defined in the constants cell.
lorenz96_dataset = DatasetLorenz96(DL_DATA_FILE)

# Shuffled loader that drops the last incomplete batch; the worker init
# function and generator come from the project's random-seed helpers.
loader_options = dict(
    batch_size=config.batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=2,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=get_torch_generator(),
)
dataloader = torch.utils.data.DataLoader(lorenz96_dataset, **loader_options)

In [None]:
# Take one batch from the loader and keep its "y0" entry.
first_batch = next(iter(dataloader))
data = first_batch["y0"]

# Expected layout: batch, time (=channel), space dims
expected_shape = (
    config.batch_size,
    config.ddpm.n_channels,
    config.ddpm.n_spaces,
)
assert data.shape == expected_shape

## Plot

In [None]:
plt.rcParams["font.size"] = 14
fig, axes = plt.subplots(1, 3, sharex=True, sharey=True, figsize=[10, 4])

# One panel per sample: the first three entries of the batch `data`.
for i, ax in enumerate(axes.flatten()):
    # standardize_inversely is applied before plotting (presumably it maps the
    # standardized values back to the original scale -- confirm in the dataset).
    d = dataloader.dataset.standardize_inversely(data[i].numpy())

    # Axis 0 is time and axis 1 is space; the grid sizes are read from the
    # sample itself rather than hard-coded.
    n_times, n_points = d.shape
    ts = np.arange(n_times) * 0.2  # dt = 0.2
    xs = np.linspace(0, 2 * math.pi, n_points, endpoint=False)
    X, T = np.meshgrid(xs, ts, indexing="ij")

    # Transpose to (space, time) so that space runs along the horizontal axis.
    ret = ax.pcolormesh(
        X, T, d.transpose(), vmin=-9, vmax=9, cmap="coolwarm", shading="nearest"
    )
    fig.colorbar(ret, ax=ax)
    ax.set_xlabel(r"Space, $x$")
    ax.set_ylabel(r"Time, $t$")
plt.tight_layout()
plt.show()

# Prepare for training

In [None]:
set_seeds(42)

# 1-D U-Net with circular padding. The channel axis holds the time steps of a
# sample, so both channel counts are taken from the DDPM config instead of
# being hard-coded; they must stay equal to each other.
unet = Unet1D(
    dim=config.n_features,
    in_channels=config.ddpm.n_channels,  # num of times
    out_channels=config.ddpm.n_channels,  # dont change in_channels and out_channels
    padding_mode="circular",
    dim_mults=config.list_channel,
).to(DEVICE)

ddpm = DDPM(config=config.ddpm, neural_net=unet, device=torch.device(DEVICE))

loss_fn = make_loss(loss_name=config.loss_name)
optimizer = torch.optim.AdamW(ddpm.parameters(), lr=config.learning_rate)

# Gradient scaling is only active when mixed precision is requested; with
# enabled=False the scaler's step() simply calls optimizer.step().
scaler = GradScaler(enabled=config.use_auto_mix_precision)

# Train

In [None]:
start_time = time.time()

all_scores: list[dict] = []  # one {"epoch", "loss"} record per epoch
set_seeds(42)

with tqdm(total=config.total_epochs, desc="Training Progress", unit="step") as pbar:
    for epoch in range(1, config.total_epochs + 1):  # epochs are numbered from 1
        loss = optimize_ddpm(
            dataloader=dataloader,
            ddpm=ddpm,
            loss_fn=loss_fn,
            optimizer=optimizer,
            epoch=epoch,
            mode="train",
            scaler=scaler,
            use_amp=config.use_auto_mix_precision,
        )
        all_scores.append({"epoch": epoch, "loss": loss})

        is_last_epoch = epoch == config.total_epochs

        # Checkpoint every `save_interval` epochs and always after the final
        # epoch: the test cell loads model_weight_{total_epochs}, which would
        # not exist when total_epochs is not a multiple of save_interval.
        if epoch % config.save_interval == 0 or is_last_epoch:
            p = f"{DL_RESULT_DIR}/model_weight_{epoch:06}.pth"
            torch.save(ddpm.net.state_dict(), p)

        # Rewrite the full loss history every 10 epochs and at the end.
        if epoch % 10 == 0 or is_last_epoch:
            p = f"{DL_RESULT_DIR}/loss_history.csv"
            pd.DataFrame(all_scores).to_csv(p, index=False)

        pbar.set_postfix({"loss": loss})
        pbar.update(1)

end_time = time.time()
logger.info(f"Finished. Total elapsed time = {(end_time - start_time) / 60.} min")

# Test

In [None]:
# Load the checkpoint of the final training epoch into the U-Net.
epoch = config.total_epochs
p = f"{DL_RESULT_DIR}/model_weight_{epoch:06}.pth"

# The training loop stores only a state_dict, so the restricted unpickler
# (weights_only=True) is sufficient and avoids executing arbitrary pickled code.
state_dict = torch.load(p, map_location=DEVICE, weights_only=True)
unet.load_state_dict(state_dict)
_ = unet.eval()  # assign to `_` to suppress the module repr in the cell output

## Run sampling

In [None]:
# Reset the random seeds right before sampling.
set_seeds(42)

# n_return_steps == 1000: intermediate states are returned so that the 1000
# steps are divided into 1000 equal parts, i.e. a state is returned at
# every single step.
sampling_options = {"n_batches": 50, "n_return_steps": 1000}
dict_samples = ddpm.backward_sample_y(**sampling_options)

## Plot generated samples

In [None]:
last_samples = dict_samples[0]  # at t == 0

plt.rcParams["font.size"] = 14
fig, axes = plt.subplots(1, 3, sharex=True, sharey=True, figsize=[10, 4])

# One panel per generated sample: the first three entries at t == 0.
for i, ax in enumerate(axes.flatten()):
    d = last_samples[i].cpu().numpy()
    # standardize_inversely is applied before plotting (presumably it maps the
    # standardized values back to the original scale -- confirm in the dataset).
    d = dataloader.dataset.standardize_inversely(d)

    # Axis 0 is time and axis 1 is space; the grid sizes are read from the
    # sample itself rather than hard-coded.
    n_times, n_points = d.shape
    ts = np.arange(n_times) * 0.2  # dt = 0.2
    xs = np.linspace(0, 2 * math.pi, n_points, endpoint=False)
    X, T = np.meshgrid(xs, ts, indexing="ij")

    # Transpose to (space, time) so that space runs along the horizontal axis.
    ret = ax.pcolormesh(
        X, T, d.transpose(), vmin=-9, vmax=9, cmap="coolwarm", shading="nearest"
    )
    fig.colorbar(ret, ax=ax)
    ax.set_xlabel(r"Space, $x$")
    ax.set_ylabel(r"Time, $t$")
plt.tight_layout()
plt.show()

## Plot intermediate states during generation

In [None]:
plt.rcParams["font.size"] = 12
fig, axes = plt.subplots(2, 6, sharex=True, sharey=True, figsize=(15, 6))

# Diffusion steps to display, from the start of the reverse process
# (t = 1000) down to the final sample (t = 0). The first entry used to be
# 100, which duplicated the fourth panel and never showed the initial state.
steps_to_show = [1000, 500, 200, 100, 80, 60, 50, 40, 30, 20, 10, 0]

for t, ax in zip(steps_to_show, axes.flatten()):
    # Follow the first sample of the batch through the generation.
    d = dict_samples[t][0].cpu().numpy()
    d = dataloader.dataset.standardize_inversely(d)

    # Axis 0 is time and axis 1 is space; the grid sizes are read from the
    # sample itself rather than hard-coded.
    n_times, n_points = d.shape
    ts = np.arange(n_times) * 0.2  # dt = 0.2
    xs = np.linspace(0, 2 * math.pi, n_points, endpoint=False)
    X, T = np.meshgrid(xs, ts, indexing="ij")

    # Transpose to (space, time) so that space runs along the horizontal axis.
    ret = ax.pcolormesh(
        X, T, d.transpose(), vmin=-9, vmax=9, cmap="coolwarm", shading="nearest"
    )

    fig.colorbar(ret, ax=ax)
    ax.set_xlabel(r"Space, $x$")
    ax.set_ylabel(r"Time, $t$")
    ax.set_title(f"Diffusion Step = {t}")
plt.tight_layout()
plt.show()

## Spectral analysis

In [None]:
# Generated samples at t == 0, as a NumPy array.
my_samples = dict_samples[0].cpu().numpy()

# First 50 "y0" entries of the dataset, stacked along a new leading axis.
reference_fields = [dataloader.dataset[idx]["y0"] for idx in range(50)]
gt_samples = torch.stack(reference_fields, dim=0).numpy()

# Both sets must have identical shapes and be plain NumPy arrays.
assert gt_samples.shape == my_samples.shape
assert all(isinstance(arr, np.ndarray) for arr in (gt_samples, my_samples))

### make coordinates

In [None]:
def _mean_spacing(coord):
    """Return the mean difference between consecutive coordinate values."""
    return float(np.mean(np.diff(coord.astype(np.float64))))


# Coordinate values stored with the dataset and their mean grid spacings.
xs = dataloader.dataset.data["space"].values
ts = dataloader.dataset.data["time"].values
dt = _mean_spacing(ts)
dx = _mean_spacing(xs)

# 1-D spectra of the reference samples; the frequency axes are taken from
# the second call below.
_gt_freqs_time, gt_psd_time_mean, _gt_freqs_space, gt_psd_space_mean = (
    compute_1d_psds_for_lorenz96(gt_samples, dt=dt, dx=dx)
)

# 1-D spectra of the generated samples, together with the frequency axes.
freqs_time, my_psd_time_mean, freqs_space, my_psd_space_mean = (
    compute_1d_psds_for_lorenz96(my_samples, dt=dt, dx=dx)
)

# 2-D spectra; the helper is called with torch tensors and its result is
# converted back to NumPy.
gt_psd_2d_mean, my_psd_2d_mean = (
    compute_2d_psd_for_lorenz96(torch.from_numpy(samples)).numpy()
    for samples in (gt_samples, my_samples)
)

### plot

In [None]:
def _two_sided(one_sided):
    """Extend a one-sided frequency axis [f0, ..., fN] to
    [-f(N-1), ..., -f1, f0, ..., fN] (2N values, 32 for the 17-point axes here)."""
    one_sided = np.asarray(one_sided, dtype=np.float64)
    return np.concatenate([-one_sided[-2:0:-1], one_sided])


# Plot coordinates: wavenumber along axis 0, frequency along axis 1.
wavenumbers, frequencies = np.meshgrid(
    _two_sided(freqs_space), _two_sided(freqs_time), indexing="ij"
)

# time, space --> space, time
log_psds = {
    "Physics Simulation": np.log10(gt_psd_2d_mean.transpose()),
    "Langevin Sampling": np.log10(my_psd_2d_mean.transpose()),
}

# Both panels share the colour range given by the 1 % and 99 % quantiles of
# the physics-simulation spectrum.
vmin, vmax = np.quantile(log_psds["Physics Simulation"], [0.01, 0.99])

plt.rcParams["font.size"] = 14
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# The loop variable is deliberately not called `data`, so the batch tensor
# `data` used by the earlier plotting cell is not overwritten.
for ax, (kind, log_psd) in zip(axes, log_psds.items()):
    pm = ax.pcolormesh(
        wavenumbers, frequencies, log_psd, shading="nearest", vmin=vmin, vmax=vmax
    )
    ax.set_xlabel(r"Wavenumber, $k$")
    ax.set_ylabel(r"Frequency, $\omega$")

    ax.set_title(kind)
    ax.axvline(0, ls="--", color="k", alpha=0.5)
    ax.axhline(0, ls="--", color="k", alpha=0.5)
    fig.colorbar(pm, ax=ax, label=r"$\log_{10}(\mathrm{PSD})$", extend="both")
plt.tight_layout()
plt.show()

In [None]:
plt.rcParams["font.size"] = 16
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
ax_time, ax_space = axes

# Reference: spectra of the physics-simulation samples (thick black line).
ref_style = dict(
    lw=4.0, color="k", ls="-", label=r"Phys. Sim. ${\rm Diffusion Time}=0.0$"
)
ax_time.plot(freqs_time, gt_psd_time_mean, **ref_style)
ax_space.plot(freqs_space, gt_psd_space_mean, **ref_style)

# Diffusion steps to show, ordered from t = 0 to t = 1000, each with its own
# turbo colour (the two extreme colours of the map are skipped).
steps = [0] + list(range(20, 500, 20)) + [1000]
colors = plt.cm.turbo(np.linspace(0, 1, len(steps) + 2))[1:-1]

# Only the two end states get a legend entry and a thick line.
highlighted = {
    1000: r"Langevin ${\rm Diffusion Time}=1.0$",
    0: r"Langevin ${\rm Diffusion Time}=0.0$",
}

for t, color in zip(steps, colors):
    label = highlighted.get(t)
    lw = 4.0 if label is not None else 1.5

    # Move the tensor to the CPU before converting to NumPy, as the other
    # cells do; .numpy() raises for tensors that live on a CUDA device.
    samples = dict_samples[t].cpu().squeeze().numpy()
    _freqs_time, psd_time_mean, _freqs_space, psd_space_mean = (
        compute_1d_psds_for_lorenz96(samples, dt=dt, dx=dx)
    )
    # The frequency axes must match the ones used for the reference spectra.
    assert np.all(_freqs_space == freqs_space)
    assert np.all(_freqs_time == freqs_time)

    ax_time.plot(freqs_time, psd_time_mean, lw=lw, color=color, ls="--", label=label)
    ax_space.plot(freqs_space, psd_space_mean, lw=lw, color=color, ls="--", label=label)

# Axis decoration is applied once, after all curves have been drawn.
ax_time.set_title("PSD in Time Direction")
ax_time.set_xlabel(r"Frequency, $\omega$ (1/time)")
ax_time.set_ylabel("Power Spectral Density")
ax_time.set_xscale("log")
ax_time.set_yscale("log")
ax_time.legend(loc="lower left", fontsize=15)

ax_space.set_title("PSD in Space Direction")
ax_space.set_xlabel(r"Wavenumber, $k$ (1/space)")
ax_space.set_ylabel("Power Spectral Density")
ax_space.set_xscale("log")
ax_space.set_yscale("log")

plt.tight_layout()
plt.show()

- ノイズによりエネルギーが注入され，スペクトルが平坦に近づく
- 逆過程では，その平坦なスペクトルから始めて，段々と時空間構造を復元する
- ノイズは，全ての格子点で独立に作用し，データの時空間構造を一切加味することなく破壊する
- そのため，エネルギーが平均より小さいスケールでエネルギーの注入が起こり，逆にエネルギーが平均より大きいスケールでエネルギーの減衰が起こる
- この注入と減衰は，スケールの大きさではなくて，エネルギーの大きさで決定される
- 例えば，空間スペクトルを見ると，$k$ の小さな領域と大きな領域でエネルギーが注入されている