In [None]:
import os
import time
from functools import partial

import keras_reservoir_computing as krc

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import xarray as xr
from matplotlib import pyplot as plt

In [2]:
class KuramotoSivashinsky:
    """Pseudo-spectral time stepper for the Kuramoto-Sivashinsky equation.

    The state ``u`` is sampled on ``Q`` equispaced points of a domain of
    length ``L`` and transformed with a real FFT. The linear operator
    ``-d^2/dx^2 - d^4/dx^4`` is integrated through exponential factors in
    Fourier space, while the nonlinear term ``d/dx(-u^2 / 2)`` is evaluated
    in physical space and combined in a two-stage exponential
    time-differencing update (see ``__call__``).

    Parameters
    ----------
    L : float
        Length of the spatial domain.
    Q : int
        Number of grid points (degrees of freedom).
    dt : float
        Time step of a single update.
    seed : int, optional
        Seed used to draw the default initial condition in ``integrate``.
        When omitted, a seed is derived from the current wall-clock time, so
        every instance gets a different default initial condition.
    """

    def __init__(
        self,
        L,
        Q,
        dt,
        seed=None,
    ):
        self.L = L
        self.Q = Q
        self.dt = dt
        self.dx = L / Q

        if seed is None:
            # Microsecond clock folded into 32 bits: not reproducible by design.
            seed = int(time.time()*1e6) % 2**32
        self.seed = seed

        # Angular wavenumbers 2*pi*m/L of the real FFT on Q samples.
        wavenumbers = jnp.fft.rfftfreq(Q, d=L / (Q * 2 * jnp.pi))
        self.derivative_operator = 1j * wavenumbers

        # Fourier symbol of -d^2/dx^2 - d^4/dx^4.
        linear_operator = -self.derivative_operator**2 - self.derivative_operator**4
        self.exp_term = jnp.exp(dt * linear_operator)
        # The jnp.where guards substitute the analytic limits (dt and dt / 2)
        # for modes where the linear operator is exactly zero, avoiding 0 / 0.
        self.coef_1 = jnp.where(
            linear_operator == 0.0,
            dt,
            (self.exp_term - 1.0) / linear_operator,
        )
        self.coef_2 = jnp.where(
            linear_operator == 0.0,
            dt / 2,
            (self.exp_term - 1.0 - linear_operator * dt) / (linear_operator**2 * dt),
        )

        # Dealiasing: keep only modes strictly below 2/3 of the largest wavenumber.
        self.alias_mask = wavenumbers < 2 / 3 * jnp.max(wavenumbers)

    @partial(jax.jit, static_argnums=0)
    def __call__(
        self,
        u,
    ):
        """Advance the state ``u`` (length ``Q``) by one time step ``dt``.

        ``self`` is a static jit argument, so the precomputed coefficients
        are baked into the compiled function for this instance.
        """
        # Stage 1: exponential update using the nonlinearity at the current state.
        u_nonlin = -0.5 * u**2
        u_hat = jnp.fft.rfft(u)
        u_nonlin_hat = jnp.fft.rfft(u_nonlin)
        u_nonlin_hat = self.alias_mask * u_nonlin_hat
        u_nonlin_der_hat = self.derivative_operator * u_nonlin_hat

        u_stage_1_hat = self.exp_term * u_hat + self.coef_1 * u_nonlin_der_hat
        u_stage_1 = jnp.fft.irfft(u_stage_1_hat, n=self.Q)

        # Stage 2: correct with the change in the nonlinearity across the stage.
        u_stage_1_nonlin = -0.5 * u_stage_1**2
        u_stage_1_nonlin_hat = jnp.fft.rfft(u_stage_1_nonlin)
        u_stage_1_nonlin_hat = self.alias_mask * u_stage_1_nonlin_hat
        u_stage_1_nonlin_der_hat = self.derivative_operator * u_stage_1_nonlin_hat

        u_next_hat = u_stage_1_hat + self.coef_2 * (
            u_stage_1_nonlin_der_hat - u_nonlin_der_hat
        )
        u_next = jnp.fft.irfft(u_next_hat, n=self.Q)

        return u_next

    @partial(jax.jit, static_argnums=(0, 2, 3))
    def integrate(self, u_init=None, steps=None, final_t=None):
        """Integrate from ``u_init`` and return the stacked trajectory.

        Parameters
        ----------
        u_init : array of shape (Q,), optional
            Initial state. When omitted, a uniform random field in
            ``[-1e-3, 1e-3)`` is drawn from ``self.seed``.
        steps : int, optional
            Number of rows of the returned trajectory, the initial state
            included (so ``steps - 1`` updates are performed).
        final_t : float, optional
            When given it takes precedence over ``steps``, which becomes
            ``final_t / dt`` rounded down; ratios within floating-point
            round-off of an integer are snapped to that integer.

        Returns
        -------
        Array of shape ``(steps, Q)`` whose first row is ``u_init``.

        Raises
        ------
        ValueError
            If neither ``steps`` nor ``final_t`` is given, or if the
            resulting number of steps is smaller than 1.
        """
        if u_init is None:

            rng_key = jax.random.key(self.seed)
            u_init = 1e-3 * jax.random.uniform(rng_key, shape=(self.Q,), minval=-1.0, maxval=1.0)

        if final_t is not None:
            # final_t takes precedence over steps.
            # Plain int() truncation loses a step to round-off (0.3 / 0.1 is
            # 2.9999999999999996), so near-integer ratios are snapped first.
            ratio = final_t / self.dt
            nearest = int(round(ratio))
            if abs(ratio - nearest) <= 1e-9 * max(1.0, abs(ratio)):
                steps = nearest
            else:
                steps = int(ratio)

        elif steps is None:
            raise ValueError("Either steps or final_t must be provided")

        if steps < 1:
            # lax.scan would otherwise be handed a negative length.
            raise ValueError(f"The number of steps must be at least 1, got {steps}")

        def step_fn(u, _):
            u_next = self(u)
            return u_next, u_next

        _, u_traj = jax.lax.scan(step_fn, u_init, None, length=steps-1) # first step is u_init
        return jnp.vstack((u_init, u_traj))

In [47]:
# Simulation configuration shared by the cells below.
DT = 0.02  # time step passed to the integrator
N_DOF = 128  # number of grid points
DOMAIN_SIZE = 36.0  # length of the spatial domain

In [60]:
# Build an integrator from the notebook configuration (no explicit seed, so
# the default initial condition differs from run to run).
ks = KuramotoSivashinsky(DOMAIN_SIZE, N_DOF, DT)

# Integrate up to t = 2000 starting from the default random initial state.
trj_etdrk2 = ks.integrate(final_t=2000)

In [None]:
# Shape of the trajectory once its first 10000 rows are dropped.
trj_etdrk2[10_000:].shape

In [None]:
%matplotlib widget

# Adjust figure size and aspect ratio
plt.close("all")
plt.figure(figsize=(12, 6))  # Adjust width and height to match the data shape
plt.imshow(
    trj_etdrk2[:].T, cmap="viridis", interpolation="nearest", aspect="auto"
)  # Use 'auto' for proper aspect ratio
plt.colorbar(label="Value")  # Add a colorbar
plt.title("Colormap of 2D Array")
plt.xlabel("X-axis (8001)")
plt.ylabel("Y-axis (128)")
plt.tight_layout()  # Optimize spacing
plt.show()

In [57]:
# Load one previously stored trajectory (Train split, iteration 0) from disk.
# NOTE(review): absolute, machine-specific path; consider a configurable data dir.
data = np.load(
    "/media/elessar/Data/Pincha/TSDynamics/data/continuous/KS/Train/"
    "KS_dt-0.02_N-128_L-36.0_iteration-0.npy"
)

In [None]:
%matplotlib widget

# Adjust figure size and aspect ratio
plt.close("all")
plt.figure(figsize=(12, 6))  # Adjust width and height to match the data shape
plt.imshow(
    data[:].T, cmap="viridis", interpolation="nearest", aspect="auto"
)  # Use 'auto' for proper aspect ratio
plt.colorbar(label="Value")  # Add a colorbar
plt.title("Colormap of 2D Array")
plt.xlabel("X-axis (8001)")
plt.ylabel("Y-axis (128)")
plt.tight_layout()  # Optimize spacing
plt.show()

In [23]:
def jax2df(trj):
    """Wrap the array-like trajectory ``trj`` in a pandas DataFrame."""
    frame = pd.DataFrame(data=trj)
    return frame

In [24]:
def df2xr(df):
    """Convert a DataFrame into an unnamed DataArray with dims ``("row", "col")``.

    Only the values are carried over; the DataFrame index and column labels
    are not passed to xarray.
    """
    return xr.DataArray(data=df.values, dims=("row", "col"))


In [22]:
def save_to_netcdf(trj, path):
    """Write the trajectory ``trj`` to ``path`` as a zlib-compressed NetCDF file.

    The array goes through ``jax2df`` and ``df2xr`` before being written with
    the ``netcdf4`` engine at compression level 5.
    """
    # The DataArray is unnamed, so the encoding targets the placeholder
    # variable name used by the original call.
    compression = {"__xarray_dataarray_variable__": {"zlib": True, "complevel": 5}}
    df2xr(jax2df(trj)).to_netcdf(
        path=path,
        engine="netcdf4",
        encoding=compression,
    )

In [21]:
def save_data(trj, folder_path, iteration, method="nc"):
    """Save one trajectory under ``folder_path`` with a parameter-tagged name.

    The file is named
    ``KS_dt-{DT}_N-{N_DOF}_L-{DOMAIN_SIZE}_iteration-{iteration}.{method}``
    using the notebook-level configuration constants.

    Parameters
    ----------
    trj : array-like
        Trajectory to store.
    folder_path : str
        Destination folder; created if it does not exist.
    iteration : int
        Index embedded in the file name.
    method : {"nc", "npy"}
        ``"nc"`` writes a NetCDF file through ``save_to_netcdf``; ``"npy"``
        writes a NumPy binary file with ``np.save``.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported formats.
    """
    # Validate before touching the filesystem: an unknown method used to
    # create the folder and then silently save nothing.
    if method not in ("nc", "npy"):
        raise ValueError(f"Unsupported method {method!r}; expected 'nc' or 'npy'")

    os.makedirs(folder_path, exist_ok=True)
    file_path = os.path.join(
        folder_path,
        f"KS_dt-{DT}_N-{N_DOF}_L-{DOMAIN_SIZE}_iteration-{iteration}.{method}",
    )

    if method == "nc":
        save_to_netcdf(trj, file_path)
    else:
        np.save(file_path, trj)

In [None]:
# Generate 700 independent trajectories, each from a freshly constructed
# integrator (hence a fresh time-based seed), and store them as .npy files.
# NOTE(review): runs 1..199 are written to Train with index i, runs 200..700
# to Test with index i - 200 (0..500) — confirm this split and the differing
# index offsets are intended.
for i in range(1, 701):
    ks = KuramotoSivashinsky(DOMAIN_SIZE, N_DOF, DT)

    if not i % 10:
        print(f"Iteration {i}")

    trj_etdrk2 = ks.integrate(steps=70000)

    iteration, savepath = (
        (i, "/media/elessar/Data/Pincha/TSDynamics/data/continuous/KS/Train")
        if i < 200
        else (i - 200, "/media/elessar/Data/Pincha/TSDynamics/data/continuous/KS/Test")
    )

    save_data(trj_etdrk2, folder_path=savepath, iteration=iteration, method="npy")