# Import

In [None]:
import sys
from logging import INFO, StreamHandler, getLogger

# Root logger: emit records at INFO and above.
logger = getLogger()
logger.setLevel(INFO)
if not logger.hasHandlers():
    # Attach a stdout handler only when none is registered yet, so that
    # re-running this cell does not add a second handler.
    stdout_handler = StreamHandler(sys.stdout)
    logger.addHandler(stdout_handler)

In [None]:
import os
import math
import random
import pathlib

import torch
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
from torch import Tensor
from tqdm.notebook import tqdm

# Global matplotlib look: serif fonts, then the colour-blind-friendly style sheet.
plt.rcParams.update({"font.family": "serif"})
plt.style.use("tableau-colorblind10")

# Make Lorenz96 data

## Preferences

In [None]:
# Absolute path of the current working directory.
ROOT_DIR = pathlib.Path(".").resolve()
# Output directory for the generated Lorenz-96 data, kept as a plain string.
DL_DATA_DIR = str(ROOT_DIR.joinpath("data", "DL_data", "lorenz96"))
# Create the directory (and any missing parents); do nothing if it already exists.
pathlib.Path(DL_DATA_DIR).mkdir(parents=True, exist_ok=True)

In [None]:
N_BATCHES = 10_000  # number of datasets to generate
N_SPACES = 32  # number of spatial grid points
N_TIMES = 10_000  # number of time steps

FORCING = 8.0  # forcing term of the Lorenz-96 model; also the base value of the initial state
AMP_PERTURBATION = 0.01  # scale of the Gaussian noise added to the initial state
DT = 0.005  # time-step size used by the integrator
SEED = 42  # seed passed to set_seeds

## Methods

In [None]:
def integrate_lorenz96(x0: Tensor, forcing: float, n_steps: int, dt: float) -> Tensor:
    """Integrate the Lorenz-96 system with the explicit (forward) Euler scheme.

    Args:
        x0: Initial state with dimensions (batch, space). It is not modified.
        forcing: Constant forcing term passed to the right-hand side.
        n_steps: Number of Euler steps to take; must be positive.
        dt: Time-step size; must be positive.

    Returns:
        A CPU tensor with dimensions (batch, n_steps + 1, space). Index 0 along
        the time axis holds the initial state.
    """
    assert isinstance(x0, Tensor) and x0.ndim == 2  # batch and space
    assert isinstance(forcing, float)
    assert isinstance(n_steps, int) and n_steps > 0
    assert isinstance(dt, float) and dt > 0.0

    # `current` is only ever rebound to new tensors, never written in place,
    # so a detached view of x0 is sufficient here.
    current = x0.detach()
    n_batches, n_spaces = current.shape

    # Preallocate the whole trajectory so that only one copy of it is ever
    # held; stacking a list of per-step tensors would briefly need two.
    # The dtype follows the promotion rule of `current + dt * rhs`.
    states = torch.empty(
        (n_batches, n_steps + 1, n_spaces),
        dtype=torch.result_type(current, dt),
        device=current.device,
    )
    states[:, 0] = current

    for step in tqdm(range(n_steps)):
        rhs = _lorenz96_rhs(x=current, forcing=forcing)
        current = current + dt * rhs
        states[:, step + 1] = current

    return states.cpu()


def _lorenz96_rhs(x: Tensor, forcing: float) -> Tensor:
    """Evaluate the Lorenz-96 tendency along dim 1 with periodic wrap-around.

    For each index i this computes
    (x[i + 1] - x[i - 2]) * x[i - 1] - x[i] + forcing.
    """
    x_next = torch.roll(x, shifts=-1, dims=1)  # x[i + 1]
    x_prev = torch.roll(x, shifts=1, dims=1)  # x[i - 1]
    x_prev2 = torch.roll(x, shifts=2, dims=1)  # x[i - 2]

    advection = (x_next - x_prev2) * x_prev
    return advection - x + forcing


def set_seeds(seed: int = 42, use_deterministic: bool = True) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Failures are logged (with traceback) rather than raised, so seeding is
    best-effort.

    Args:
        seed: Seed value applied to every generator.
        use_deterministic: When true, additionally ask PyTorch to use
            deterministic algorithms (warning, not failing, where none exists).
    """
    try:
        # NOTE: the running interpreter reads PYTHONHASHSEED at start-up, so
        # this mainly matters for processes spawned afterwards.
        os.environ["PYTHONHASHSEED"] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            # Seed every visible GPU, not only the current device.
            torch.cuda.manual_seed_all(seed)

        if use_deterministic:
            torch.use_deterministic_algorithms(True, warn_only=True)

        # These cuDNN flags are applied regardless of `use_deterministic`.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except Exception:
        # Keep the traceback in the log instead of only the message.
        logger.exception("Failed to set random seeds.")

## Integrate Lorenz96

In [None]:
device = torch.device("cpu")
dtype = torch.float32

# Seed first so the random perturbation below is reproducible.
set_seeds(SEED)

# Every batch member starts from the uniform state x_i = FORCING plus
# small Gaussian noise scaled by AMP_PERTURBATION.
initial_shape = (N_BATCHES, N_SPACES)
x0 = torch.ones(initial_shape, dtype=dtype, device=device) * FORCING
noise = torch.randn_like(x0)
x0 = x0 + AMP_PERTURBATION * noise

In [None]:
states = integrate_lorenz96(x0=x0, forcing=FORCING, n_steps=N_TIMES, dt=DT)

# Report whether the trajectory contains NaNs or other non-finite values.
if states.isnan().any():
    logger.warning("NaNs appear.")
elif not states.isfinite().all():
    logger.warning("Infs appear")
else:
    logger.info("Integration was successfully finished.")

## Plot results

In [None]:
plt.rcParams["font.size"] = 14
fig, axes = plt.subplots(1, 3, sharex=True, sharey=True, figsize=[10, 4])

# Subsampling used for display: keep every `sample_stride`-th time step and
# then only the last `n_samples` of those.
sample_stride = 40
n_samples = 32

# The spatial coordinate is the same for every panel, so build it once.
xs = np.linspace(0, 2 * math.pi, N_SPACES, endpoint=False)

for i, ax in enumerate(axes.flatten()):
    d = states[i].numpy()[::sample_stride][-n_samples:]
    # Consecutive samples are `sample_stride` integration steps apart, so the
    # time spacing must use the same stride as the slicing above.
    ts = np.arange(d.shape[0]) * (DT * sample_stride)
    X, T = np.meshgrid(xs, ts, indexing="ij")
    ret = ax.pcolormesh(
        X, T, d.transpose(), vmin=-9, vmax=9, cmap="coolwarm", shading="nearest"
    )
    fig.colorbar(ret, ax=ax)
    ax.set_xlabel(r"Space, $x$")
    ax.set_ylabel(r"Time, $t$")
plt.tight_layout()
plt.show()

## Write out

In [None]:
# Keep every 40th time step, then only the last 32 of those samples.
outs = states[:, ::40, :][:, -32:, :]
assert outs.shape == (N_BATCHES, 32, N_SPACES), f"{outs.shape=}"

# Coordinate axes: time starts at zero with spacing DT * 40; space covers
# [0, 2*pi) without the end point.
ts = np.arange(outs.shape[1], dtype=np.float64) * (DT * 40)
xs = np.linspace(0, 2 * math.pi, N_SPACES, endpoint=False)

coords = {
    "batch": np.arange(N_BATCHES, dtype=np.int32),
    "time": ts.astype(np.float32),
    "space": xs.astype(np.float32),
}
# Generation settings stored alongside the data.
attrs = {
    "forcing": FORCING,
    "dt": DT,
    "seed": SEED,
    "amp_perturb": AMP_PERTURBATION,
}

da = xr.DataArray(
    data=outs.numpy().astype(np.float32),
    coords=coords,
    dims=["batch", "time", "space"],
    name="lorenz96_trajectory",
    attrs=attrs,
)

# Write the array to a NetCDF file in the data directory.
p = f"{DL_DATA_DIR}/lorenz96_v00.nc"
da.to_netcdf(path=p)

# Train diffusion model