In [12]:
import sys
import os
import glob
import h5py
import numpy as np
import matplotlib.pyplot as plt
from numpy.fft import rfft, irfft
from matplotlib.animation import FuncAnimation
import jax
import jax.numpy as jnp

current_dir = os.getcwd()
sys.path.append(os.path.abspath(os.path.join(current_dir, '..')))

from HOSim.solver import f
f_jit = jax.jit(f, static_argnums=(2, 3, 4, 5, 6))

h5_files = glob.glob(os.path.join("..\\output", "*.h5"))

eta_hat, phi_hat, Hs, Tp, modes, time, length, x = None, None, None, None, None, None, None, None

for i, file in enumerate(h5_files):
    with h5py.File(file, "r") as data:
        eta_hat = data["eta_hat"][:]
        phi_hat = data["phi_hat"][:]
        Hs = data["Hs"][:]
        Tp = data["Tp"][:]
        time = data["time"][:]

        modes = data.attrs["modes"]
        length = data.attrs["length"]
        Ta = data.attrs["Ta"]
        x = np.linspace(0, length, 2*modes)

    break

index = np.argmin(np.abs(time - 2*Ta))

eta_hat = eta_hat[:, index:, :]
phi_hat = phi_hat[:, index:, :]
time = time[index:] - time[index]

In [None]:
import torch
import torch.nn as nn
from torchdiffeq import odeint

class ODEFunc(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_dim, 50),
            nn.Tanh(),
            nn.Linear(50, hidden_dim),
        )

    def forward(self, t, y):
        return self.net(y)

t = torch.from_numpy(time).float()
y0 = torch.from_numpy(irfft(eta_hat[0, 0, :])).float().unsqueeze(0)
y_true = torch.from_numpy(irfft(eta_hat[0, :, :])).float().unsqueeze(1)

func     = ODEFunc(hidden_dim=1024)
optimizer = torch.optim.Adam(func.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# 4) Training loop
for epoch in range(2000):
    optimizer.zero_grad()
    y_pred = odeint(func, y0, t)

    loss = criterion(y_pred, y_true)
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch:4d} loss = {loss.item():.6f}")


Epoch    0 loss = 69939.578125
Epoch    1 loss = 88591.718750
Epoch    2 loss = 91115.648438
Epoch    3 loss = 90060.375000
Epoch    4 loss = 91569.617188
Epoch    5 loss = 90367.218750
Epoch    6 loss = 89532.710938
Epoch    7 loss = 90590.343750
Epoch    8 loss = 89536.140625
Epoch    9 loss = 89774.835938
Epoch   10 loss = 89206.367188
Epoch   11 loss = 85601.859375
Epoch   12 loss = 83039.476562
Epoch   13 loss = 81603.359375
Epoch   14 loss = 77887.789062
Epoch   15 loss = 77579.640625
Epoch   16 loss = 42331.128906
Epoch   17 loss = 18156.230469
Epoch   18 loss = 35696.414062
Epoch   19 loss = 29407.091797
Epoch   20 loss = 27663.591797
Epoch   21 loss = 18646.773438
Epoch   22 loss = 15715.579102
Epoch   23 loss = 10628.051758
Epoch   24 loss = 9951.967773
Epoch   25 loss = 9289.284180
Epoch   26 loss = 8237.560547
Epoch   27 loss = 6592.616211
Epoch   28 loss = 6149.274902
Epoch   29 loss = 4980.509766
Epoch   30 loss = 5312.329102
Epoch   31 loss = 5109.109375
Epoch   32 loss 