In [None]:
# Automatically reload edited Python modules (e.g. the local `src` package)
# before each cell executes, so library edits are picked up without a restart.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Import

In [None]:
import sys
from logging import INFO, StreamHandler, getLogger

# Send root-logger records to the notebook's stdout at INFO level. The handler
# is attached only when none exists yet, so re-running this cell does not
# duplicate every log line.
logger = getLogger()  # root logger
logger.setLevel(INFO)
if not logger.hasHandlers():
    stdout_handler = StreamHandler(sys.stdout)
    logger.addHandler(stdout_handler)

In [None]:
import os
import time
import math
import pathlib

import torch
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from torch.cuda.amp import GradScaler

from src.utils.random_seed_helper import set_seeds, seed_worker, get_torch_generator
from src.frameworks.ddpm import DDPMConfig
from src.frameworks.sda import ScoreBasedDA
from src.datasets.dataset_lorenz96 import DatasetLorenz96
from src.configs.lorenz96_config import ExperimentLorenz96Config
from src.neural_networks.unet_1d import Unet1D
from src.training.loss_maker import make_loss
from src.training.optim_helper import optimize_ddpm

# Global figure style for every plot in this notebook.
plt.rcParams.update({"font.family": "serif"})
plt.style.use("tableau-colorblind10")
# cuBLAS workspace setting needed to make calculations deterministic.
os.environ.update({"CUBLAS_WORKSPACE_CONFIG": r":4096:8"})

# Define constants

In [None]:
DEVICE = "cuda:0"

# The project root is the first entry of PYTHONPATH. Split on os.pathsep
# (":" on POSIX, ";" on Windows) rather than a hard-coded ":", and fail with an
# explicit message instead of a bare KeyError when the variable is missing.
_pythonpath = os.environ.get("PYTHONPATH")
if not _pythonpath:
    raise RuntimeError(
        "PYTHONPATH is not set; its first entry is used as the project root."
    )
ROOT_DIR = pathlib.Path(_pythonpath.split(os.pathsep)[0]).resolve()

# Input dataset (netCDF) and output directory for config, checkpoints and loss log.
DL_DATA_FILE = str(ROOT_DIR / "data" / "DL_data" / "lorenz96" / "lorenz96_v00.nc")
DL_RESULT_DIR = str(ROOT_DIR / "data" / "DL_model" / "lorenz96_v01")
os.makedirs(DL_RESULT_DIR, exist_ok=True)

# Make config

- 分散保存型 (Variance-Preserving; VP) の拡散モデルを扱う．
- 順過程は下のように書ける．
$$
\begin{align}
dx_t &= -\frac{1}{2} \beta_t x_t \; dt + \sqrt{\beta_t} \; dW_t \quad (t \in [0,1]) \\
\beta_t &= \beta_{\rm start} + t \left( \beta_{\rm end} - \beta_{\rm start} \right)
\end{align}
$$
- $\beta_{\rm start} = \beta_{\rm end}$ とすると，$\beta_t$ は定数になる．

In [None]:
# Diffusion (VP-SDE) settings, built first so they can be read at a glance.
ddpm_config = DDPMConfig(
    # Setting start_beta == end_beta makes beta constant.
    start_beta=1e1,
    end_beta=1e1,
    n_timesteps=1_000,
    # n_channels: number of time steps in each sample, used as the channel count.
    # n_spaces: number of spatial grid points.
    # Do not change n_channels and n_spaces.
    n_channels=32,
    n_spaces=32,
)

config = ExperimentLorenz96Config(
    # Training settings
    batch_size=100,
    loss_name="L2",
    learning_rate=1e-3,
    total_epochs=40,
    save_interval=4,  # a checkpoint is written every `save_interval` epochs
    use_auto_mix_precision=False,
    # U-Net settings
    n_features=32,
    list_channel=[1, 2, 4],
    # DDPM settings
    ddpm=ddpm_config,
)
config.save(f"{DL_RESULT_DIR}/config.yml")

In [None]:
# You can load your saved config.
# config = ExperimentLorenz96Config.load(f"{DL_RESULT_DIR}/config.yml")
# config

# Make dataloader

In [None]:
# Training data loader. Shuffling uses the explicit generator returned by
# get_torch_generator() and worker processes are seeded via seed_worker, so the
# batch order is reproducible.
# NOTE(review): every `iter(dataloader)` draws from that generator, so any extra
# iteration before training presumably shifts the training batch order.
lorenz96_dataset = DatasetLorenz96(DL_DATA_FILE)
dataloader = torch.utils.data.DataLoader(
    dataset=lorenz96_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=2,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=get_torch_generator(),
)

In [None]:
# Assemble one batch to sanity-check shapes (and for the plot below).
# Items are read straight from the dataset rather than via
# `next(iter(dataloader))`: creating a DataLoader iterator spawns workers and
# advances the loader's shuffling generator, which would make the training
# batch order depend on whether this preview cell was executed.
data = torch.stack(
    [dataloader.dataset[i]["y0"] for i in range(config.batch_size)]
)
assert data.shape == (
    config.batch_size,
    config.ddpm.n_channels,
    config.ddpm.n_spaces,
)  # batch, time (=channel), space dims

## Plot

In [None]:
# Show three training trajectories (with the standardization undone) as
# space-time diagrams.
plt.rcParams["font.size"] = 14
fig, axes = plt.subplots(1, 3, sharex=True, sharey=True, figsize=[10, 4])

for i, ax in enumerate(axes.flatten()):
    # (time, space) layout, matching the batch shape asserted above.
    d = dataloader.dataset.standardize_inversely(data[i].numpy())
    n_times, n_spaces = d.shape
    ts = np.arange(n_times) * 0.2  # dt = 0.2
    # Periodic domain [0, 2*pi); size taken from the data instead of a literal 32.
    xs = np.linspace(0, 2 * math.pi, n_spaces, endpoint=False)
    X, T = np.meshgrid(xs, ts, indexing="ij")
    mesh = ax.pcolormesh(
        X, T, d.transpose(), vmin=-9, vmax=9, cmap="coolwarm", shading="nearest"
    )
    fig.colorbar(mesh, ax=ax)
    ax.set_xlabel(r"Space, $x$")
    ax.set_ylabel(r"Time, $t$")
plt.tight_layout()
plt.show()

# Prepare for training

In [None]:
set_seeds(42)

# The trajectory's time steps are fed to the 1-D U-Net as channels, so the
# channel count comes from the DDPM config (the same value the data batch is
# asserted against) instead of a hard-coded 32.
unet = Unet1D(
    dim=config.n_features,
    in_channels=config.ddpm.n_channels,  # num of times
    out_channels=config.ddpm.n_channels,  # keep equal to in_channels
    padding_mode="circular",  # presumably because the Lorenz-96 domain is periodic
    dim_mults=config.list_channel,
).to(DEVICE)

sda = ScoreBasedDA(
    config=config.ddpm,
    neural_net=unet,
    device=torch.device(DEVICE),
)

loss_fn = make_loss(loss_name=config.loss_name)
optimizer = torch.optim.AdamW(sda.parameters(), lr=config.learning_rate)
# Loss scaling is only meaningful for mixed precision; when AMP is off the
# scaler becomes a pass-through (scale() is identity, step() calls optimizer.step()).
scaler = GradScaler(enabled=config.use_auto_mix_precision)

# Train

In [None]:
# Train for `config.total_epochs` epochs, checkpointing the network weights and
# periodically writing the per-epoch loss history to CSV.
start_time = time.time()

all_scores: list[dict] = []
set_seeds(42)
loss_history_path = f"{DL_RESULT_DIR}/loss_history.csv"

with tqdm(total=config.total_epochs, desc="Training Progress", unit="epoch") as pbar:
    for epoch in range(1, config.total_epochs + 1):  # epochs are 1-based
        loss = optimize_ddpm(
            dataloader=dataloader,
            ddpm=sda,
            loss_fn=loss_fn,
            optimizer=optimizer,
            epoch=epoch,
            mode="train",
            scaler=scaler,
            use_amp=config.use_auto_mix_precision,
        )
        all_scores.append({"epoch": epoch, "loss": loss})

        is_last_epoch = epoch == config.total_epochs

        # Always checkpoint the final epoch as well: the test section loads
        # model_weight_{total_epochs}.pth, which would otherwise only exist when
        # total_epochs happens to be a multiple of save_interval.
        if epoch % config.save_interval == 0 or is_last_epoch:
            weight_path = f"{DL_RESULT_DIR}/model_weight_{epoch:06}.pth"
            torch.save(sda.net.state_dict(), weight_path)

        if epoch % 10 == 0 or is_last_epoch:
            pd.DataFrame(all_scores).to_csv(loss_history_path, index=False)

        pbar.set_postfix({"loss": loss})
        pbar.update(1)

end_time = time.time()
logger.info(f"Finished. Total elapsed time = {(end_time - start_time) / 60.} min")

# Test

In [None]:
# Load the final-epoch checkpoint into `unet` and switch it to inference mode.
# `unet` was passed to `sda` as `neural_net`, so (presumably) the assimilation
# below runs with these weights.
epoch = config.total_epochs
p = f"{DL_RESULT_DIR}/model_weight_{epoch:06}.pth"
# The checkpoint is a plain state_dict (saved via `state_dict()` in the training
# loop), so the restricted weights_only loader is sufficient and avoids
# unpickling arbitrary objects.
unet.load_state_dict(torch.load(p, map_location=DEVICE, weights_only=True))
_ = unet.eval()

## Make observations

In [None]:
# Synthetic observations: the last trajectory of the dataset (standardized space)
# plus Gaussian measurement noise with std = OBS_NOISE_RATIO * mean(|ground truth|).
# Every entry is observed here; unobserved entries would be NaN (as in the
# "without observations" run further down).
OBS_NOISE_RATIO = 0.1
OBS_NOISE_SEED = 42

ground_truth = dataloader.dataset[-1]["y0"]
# A dedicated generator makes the noise reproducible regardless of how much of
# the global RNG stream training consumed (or whether training ran at all).
noise_generator = torch.Generator(device=ground_truth.device).manual_seed(
    OBS_NOISE_SEED
)
noise = torch.randn(
    ground_truth.shape,
    generator=noise_generator,
    dtype=ground_truth.dtype,
    device=ground_truth.device,
)
obs = (
    ground_truth + noise * ground_truth.abs().mean() * OBS_NOISE_RATIO
)  # add measurement noise

In [None]:
# Side-by-side space-time diagrams of the ground truth and the noisy observations.
plt.rcParams["font.size"] = 14
fig, axes = plt.subplots(1, 2, sharex=True, sharey=True, figsize=[10, 4])

# The loop variable is `field` (not `data`) so this cell does not overwrite the
# training batch `data` used by the preview plot earlier.
for ax, field, ttl in zip(
    axes.flatten(), [ground_truth, obs], ["ground truth", "observations"]
):
    d = dataloader.dataset.standardize_inversely(field.cpu().numpy())  # (time, space)
    n_times, n_spaces = d.shape
    ts = np.arange(n_times) * 0.2  # dt = 0.2
    xs = np.linspace(0, 2 * math.pi, n_spaces, endpoint=False)
    X, T = np.meshgrid(xs, ts, indexing="ij")
    mesh = ax.pcolormesh(
        X, T, d.transpose(), vmin=-9, vmax=9, cmap="coolwarm", shading="nearest"
    )
    fig.colorbar(mesh, ax=ax)
    ax.set_xlabel(r"Space, $x$")
    ax.set_ylabel(r"Time, $t$")
    ax.set_title(ttl)
plt.tight_layout()
plt.show()

## Perform assimilation with observations

In [None]:
set_seeds(42)

# Sampler settings shared by this call (n_batches is presumably the number of
# samples drawn; derivative_score and std_for_observations are passed through).
# With n_return_steps == 1000, the 1000 steps are divided into 1000 equal parts
# when returning intermediate states, i.e. a state is returned for every step.
assimilation_kwargs = dict(
    n_batches=10,
    derivative_score=0.0,
    std_for_observations=0.1,
    n_return_steps=1000,
)
dict_samples = sda.assimilate(observations=obs, **assimilation_kwargs)

In [None]:
# Top row: ground truth and three samples at diffusion time t == 0.
# Bottom row: the observations and the error of each sample (GT - sample).
last_samples = dict_samples[0]  # at t == 0

plt.rcParams["font.size"] = 12
fig, axes = plt.subplots(2, 4, sharex=True, sharey=True, figsize=[15, 6])

# Undo the standardization on the native (time, space) layout and only then
# transpose to (space, time) for plotting, so every field passes through
# `standardize_inversely` with the same axis order.
gt = dataloader.dataset.standardize_inversely(ground_truth.cpu().numpy()).transpose()
obs_field = dataloader.dataset.standardize_inversely(obs.cpu().numpy()).transpose()

# After the transpose, axis 0 is space and axis 1 is time. Build the grid from
# those sizes instead of reading time from axis 0 / assuming 32 grid points.
n_spaces, n_times = gt.shape
xs = np.linspace(0, 2 * math.pi, n_spaces, endpoint=False)
ts = np.arange(n_times) * 0.2  # dt = 0.2
X, T = np.meshgrid(xs, ts, indexing="ij")

for i in range(axes.shape[1]):
    if i == 0:
        top, top_title = gt, "Ground Truth (GT)"
        bottom, bottom_title = obs_field, "Observations"
        bottom_cmap, bottom_lim = "coolwarm", 10
    else:
        top = dataloader.dataset.standardize_inversely(
            last_samples[i].cpu().numpy()
        ).transpose()
        top_title = f"Sample {i}"
        bottom, bottom_title = gt - top, f"GT - Sample {i}"
        bottom_cmap, bottom_lim = "PiYG", 5

    panels = (
        (axes[0, i], top, top_title, "coolwarm", 10),
        (axes[1, i], bottom, bottom_title, bottom_cmap, bottom_lim),
    )
    for ax, field, title, cmap, lim in panels:
        mesh = ax.pcolormesh(
            X, T, field, vmin=-lim, vmax=lim, cmap=cmap, shading="nearest"
        )
        fig.colorbar(mesh, ax=ax)
        ax.set_xlabel(r"Space, $x$")
        ax.set_ylabel(r"Time, $t$")
        ax.set_title(title)

plt.tight_layout()
plt.show()

## Perform assimilation without observations

In [None]:
set_seeds(42)

# Every observation is missing (all NaN), so the samples come from the
# unconditioned model.
unobserved = torch.full_like(obs, fill_value=torch.nan)

# With n_return_steps == 1000, the 1000 steps are divided into 1000 equal parts
# when returning intermediate states, i.e. a state is returned for every step.
dict_samples = sda.assimilate(
    observations=unobserved,
    n_batches=10,
    derivative_score=0.0,
    n_return_steps=1000,
    std_for_observations=0.1,
)

In [None]:
# Top row: ground truth and three samples at diffusion time t == 0.
# Bottom row: the observations (for reference only -- this run was conditioned
# on all-NaN observations) and the error of each sample (GT - sample).
last_samples = dict_samples[0]  # at t == 0

plt.rcParams["font.size"] = 12
fig, axes = plt.subplots(2, 4, sharex=True, sharey=True, figsize=[15, 6])

# Undo the standardization on the native (time, space) layout and only then
# transpose to (space, time) for plotting, so every field passes through
# `standardize_inversely` with the same axis order.
gt = dataloader.dataset.standardize_inversely(ground_truth.cpu().numpy()).transpose()
obs_field = dataloader.dataset.standardize_inversely(obs.cpu().numpy()).transpose()

# After the transpose, axis 0 is space and axis 1 is time. Build the grid from
# those sizes instead of reading time from axis 0 / assuming 32 grid points.
n_spaces, n_times = gt.shape
xs = np.linspace(0, 2 * math.pi, n_spaces, endpoint=False)
ts = np.arange(n_times) * 0.2  # dt = 0.2
X, T = np.meshgrid(xs, ts, indexing="ij")

for i in range(axes.shape[1]):
    if i == 0:
        top, top_title = gt, "Ground Truth (GT)"
        # Labelled explicitly: these observations were NOT given to this run.
        bottom, bottom_title = obs_field, "Observations (not used)"
        bottom_cmap, bottom_lim = "coolwarm", 10
    else:
        top = dataloader.dataset.standardize_inversely(
            last_samples[i].cpu().numpy()
        ).transpose()
        top_title = f"Sample {i}"
        bottom, bottom_title = gt - top, f"GT - Sample {i}"
        bottom_cmap, bottom_lim = "PiYG", 5

    panels = (
        (axes[0, i], top, top_title, "coolwarm", 10),
        (axes[1, i], bottom, bottom_title, bottom_cmap, bottom_lim),
    )
    for ax, field, title, cmap, lim in panels:
        mesh = ax.pcolormesh(
            X, T, field, vmin=-lim, vmax=lim, cmap=cmap, shading="nearest"
        )
        fig.colorbar(mesh, ax=ax)
        ax.set_xlabel(r"Space, $x$")
        ax.set_ylabel(r"Time, $t$")
        ax.set_title(title)

plt.tight_layout()
plt.show()