In [None]:
# Re-import edited project modules (e.g. the src.* packages) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Render figures inline in the notebook output.
%matplotlib inline

# Imports

In [None]:
import sys
from logging import INFO, StreamHandler, getLogger

# Send root-logger records to stdout so `logger.info(...)` shows up as cell
# output. Only attach a handler when none exists, so re-running this cell
# does not duplicate every log line.
logger = getLogger()
needs_handler = not logger.hasHandlers()
if needs_handler:
    stdout_handler = StreamHandler(sys.stdout)
    logger.addHandler(stdout_handler)
logger.setLevel(INFO)

In [None]:
import os

# cuBLAS reads CUBLAS_WORKSPACE_CONFIG when it initialises, and this setting is
# needed for deterministic cuBLAS results. Set it before torch or any project
# module is imported, in case one of them touches CUDA at import time.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import pathlib

import matplotlib.pyplot as plt
import numpy as np
import torch
from src.lorenz63_model.lorenz63_model import Lorenz63
from src.lorenz63_model.utils.lorenz63_config import Lorenz63Config
from src.utils.random_seed_helper import set_seeds
from tqdm.auto import tqdm  # notebook progress bar when widgets are available, plain text otherwise

plt.rcParams["font.family"] = "serif"

# Define constants

In [None]:
# Project root = parent of the first PYTHONPATH entry. PYTHONPATH may list
# several directories separated by os.pathsep; wrapping the whole string in a
# Path would produce a meaningless path in that case.
# NOTE(review): assumes the first entry is the project's source directory.
ROOT_DIR = pathlib.Path(os.environ["PYTHONPATH"].split(os.pathsep)[0]).parent.resolve()

fig_dir = f"{ROOT_DIR}/docs/lorenz63_model/fig"
# os.makedirs(fig_dir, exist_ok=True)  # uncomment to save figures; exist_ok keeps re-runs working

# Plain device string ("cuda" or "cpu") passed to the model config.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Run simulation

In [None]:
# Fix the random seed and request deterministic algorithms (via the project's
# set_seeds helper) so the simulation below is reproducible from run to run.
set_seeds(use_deterministic=True, seed=42)

In [None]:
# Model configuration: a batch of two trajectories in double precision on
# DEVICE, with noise amplitude 1.0.
cfg = Lorenz63Config(
    device=DEVICE,
    precision="double",
    n_batch=2,
    noise_amplitude=1.0,
)

In [None]:
# Build the Lorenz-63 model from the config above.
# show_input_cfg_info=True presumably reports the input config (TODO confirm in Lorenz63).
model = Lorenz63(cfg, show_input_cfg_info=True)

In [None]:
# Initial condition (x, y, z), created directly in the model's dtype and on its device.
X0 = torch.tensor([11.2, 10.2, 33.2], dtype=model.real_dtype, device=model.device)

model.initialize(X=X0)

dt = 0.001        # integrator time step
output_dt = 0.01  # interval between stored snapshots
end_time = 40     # total simulated time

# Derive integer step counts with round(): int() truncates (int(0.3 / 0.1) == 2)
# and a float torch.arange can gain or lose a point through rounding error.
n_steps_per_output = round(output_dt / dt)
n_outputs = round(end_time / output_dt)
if n_steps_per_output < 1 or not np.isclose(n_steps_per_output * dt, output_dt):
    raise ValueError("output_dt must be a positive integer multiple of dt")
if not np.isclose(n_outputs * output_dt, end_time):
    raise ValueError("end_time must be an integer multiple of output_dt")

# Snapshot at t0, then one snapshot after every n_steps_per_output steps.
states, ts = [model.get_state()], [model.t]
for _ in tqdm(range(n_outputs)):
    model.integrate_n_steps(dt_per_step=dt, n_steps=n_steps_per_output)
    states.append(model.get_state())
    ts.append(model.t)

# Stack snapshots along a new time dim (dim=1); squeeze() drops singleton dims.
# NOTE(review): squeeze() would also drop the batch dim if n_batch == 1.
Xs = torch.stack(states, dim=1).squeeze()
assert Xs.shape[-2] == len(ts), "time axis length does not match number of snapshots"

# shape = (batch, time, (x, y, z))
logger.info("Shape of the result: %s", Xs.shape)

# Plot results

In [None]:
plt.rcParams["font.size"] = 15

# Matplotlib needs host (NumPy) memory, so copy the trajectories to the CPU.
# Converting a CUDA tensor directly raises; .cpu() is a no-op on CPU tensors.
Xs_cpu = Xs.detach().cpu()
n_trajectories = Xs_cpu.shape[0]
# NOTE(review): ts holds the raw model.t values; if those are CUDA tensors they
# would need moving to the CPU as well.

fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(10, 10))
for i, (ax, ylabel) in enumerate(zip(axes, ["x", "y", "z"])):
    # One line per batch member, labelled data1, data2, ...
    for b in range(n_trajectories):
        ax.plot(ts, Xs_cpu[b, :, i].numpy(), label=f"data{b + 1}")
    ax.set_xlabel("t")
    ax.set_ylabel(ylabel)

    ax.legend(loc="lower left")

# fig.savefig(f"{fig_dir}/xyz_trajectory_plot.png")

plt.show()

In [None]:
plt.rcParams["font.size"] = 20

# .numpy() raises on CUDA tensors, so copy the trajectories to the CPU first
# (a no-op when they already live there).
Xs_cpu = Xs.detach().cpu()

fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(projection="3d")
# One 3-D curve per batch member, labelled data1, data2, ...
for b in range(Xs_cpu.shape[0]):
    traj = Xs_cpu[b].numpy()  # (time, (x, y, z))
    ax.plot(traj[:, 0], traj[:, 1], traj[:, 2], "-", label=f"data{b + 1}")

ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")

ax.legend()

# fig.savefig(f"{fig_dir}/3d_trajectory_plot.png")

plt.show()