In [1]:
import sys
import os
import glob
import h5py
import numpy as np
import matplotlib.pyplot as plt
from numpy.fft import rfft, irfft
from matplotlib.animation import FuncAnimation

current_dir = os.getcwd()
# Make the parent directory importable; guarded so re-running the cell does not keep
# appending the same entry to sys.path.
_parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

# Candidate output files; os.path.join keeps the pattern portable (the previous
# "..\\output" literal only worked with Windows path separators).
h5_files = glob.glob(os.path.join("..", "output", "*.h5"))

# Simulation output to analyse. Override with the SIMULATION_FILE environment
# variable instead of editing this hard-coded drive path.
simulation_file = os.environ.get("SIMULATION_FILE", "Z:\\files\\simulation.h5")

with h5py.File(simulation_file, "r") as data:
    eta_hat = data["eta_hat"][:]
    phi_hat = data["phi_hat"][:]
    Hs = data["Hs"][:]
    Tp = data["Tp"][:]
    time = data["time"][:]

    modes = data.attrs["modes"]
    length = data.attrs["length"]
    Ta = data.attrs["Ta"]

# Spatial grid for the 2*modes real-space samples. The domain is periodic with period
# `length` (wavenumbers below are n * 2*pi / length), so the sample at x = length is
# the sample at x = 0 again: exclude it so the spacing is length / (2*modes).
x = np.linspace(0, length, 2 * modes, endpoint=False)

# Discard the record before t ≈ 2*Ta (presumably the ramp-up period — confirm) and
# restart the time axis at zero.
index = np.argmin(np.abs(time - 2 * Ta))

eta_hat = eta_hat[:, index:, :]
phi_hat = phi_hat[:, index:, :]
time = time[index:] - time[index]

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import EarlyStopping
from numpy.fft import irfft, rfft

# Grid indices bounding the evaluation window: x ≈ 2000 is where observations stop
# (input snapshots are blanked from cut_index onward) and x ≈ 2700 closes the window.
# The +1 turns cut_index_2 into an exclusive slice bound.
cut_index = np.abs(x - 2000).argmin()
cut_index_2 = np.abs(x - 2700).argmin() + 1

In [None]:
# Creating dataset

prediction_time = 280  # 140 sec
measure_time = 120  # one minute
num_measurements = 6
step = int(measure_time / num_measurements)  # samples between snapshots (10 sec)

train_percentage = 0.7
batch_size = 32

# Real-space surface elevation. A sample starting at s has snapshots at
# s, s + step, ..., s + measure_time; its target lies prediction_time samples after
# the last snapshot, i.e. at s + measure_time + prediction_time.
X = irfft(eta_hat[0, :-prediction_time, :])
y = irfft(eta_hat[0, prediction_time + measure_time:-1, :])

# Only x < x[cut_index] is observed: blank the rest of every input snapshot.
X[:, cut_index:] = 0

# Stack num_measurements + 1 snapshots per sample: the initial state followed by
# num_measurements measurements, step samples apart.
n_samples = X.shape[0] - measure_time - 1
X = np.stack(
    [X[m * step : m * step + n_samples] for m in range(num_measurements + 1)],
    axis=1,
)
assert X.shape[0] == y.shape[0], "inputs and targets must be aligned sample-for-sample"

# Chronological train/test split.
# NOTE(review): neighbouring samples overlap in time, so the last training windows
# share data with the first test windows.
n_train = int(X.shape[0] * train_percentage)
X_train, X_test = X[:n_train, :, :], X[n_train:, :, :]
y_train, y_test = y[:n_train, :], y[n_train:, :]

# Spread of the training inputs, kept for scaling. It is not applied here: the
# observer cells below feed X_train straight into the physical wave solver (g = 9.81).
std = np.std(X_train)

In [None]:
import jax
import jax.numpy as jnp

# Guarded so re-running the cell does not append the same path again.
_parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

from HOSim import solver

# JIT-compile the solver RHS and RK4 step; the listed positional arguments are static
# (the scalar settings and, for rk4_step, the RHS callable passed last).
f_jit = jax.jit(solver.f, static_argnums=(2, 3, 4, 5, 6))
rk4_step_jit = jax.jit(solver.rk4_step, static_argnums=(2, 3, 4, 5, 6, 7, 8))

n_half = X_train.shape[-1] // 2
k = np.arange(0, n_half + 1) * 2 * np.pi / length  # wavenumber of each rfft mode
g = 9.81

# Observer gains per Fourier mode; the mean mode (index 0) is never corrected.
# NOTE(review): for linearised deep-water waves (eta_t = k*phi, phi_t = -g*eta) the
# pair L1 = 2, L2 = 1/k - g gives critically damped error dynamics s^2 + 2s + 1 —
# presumably the design intent; confirm.
L1 = np.insert(np.ones(n_half) * 2, 0, 0)
L2 = np.insert(1.0 / k[1:] - g, 0, 0)

# Run the observer on a single training sample.
index = 0

eta_0 = X_train[index, 0, :]
eta_hat_0 = rfft(eta_0)
# Linear-wave initial guess for the potential: phi_hat = -i * sqrt(g/k) * eta_hat.
phi_hat_0 = eta_hat_0[1:] * np.exp(-1.j * np.pi / 2) * np.sqrt(g / k[1:])
phi_hat_0 = np.insert(phi_hat_0, 0, 0)

# State vector: [eta_hat, phi_hat]. Note complex128 only takes effect when JAX x64
# mode is enabled (not done in this cell).
y_hat = jnp.asarray(np.concatenate((eta_hat_0, phi_hat_0)), dtype=jnp.complex128)
half = y_hat.shape[0] // 2

# NOTE(review): allocated but never filled or read in this notebook.
eta_saved = np.zeros((6 * 20 + 1, eta_0.shape[0]))

n_assim_windows = 6      # snapshots assimilated after the initial one
n_forecast_windows = 14  # free-running windows after the last snapshot
steps_per_window = 400   # RK4 steps between consecutive snapshots

for i in range(n_assim_windows + n_forecast_windows):
    for _ in range(steps_per_window):
        y_hat = rk4_step_jit(200, y_hat, 0.025, n_half, g, k[1], 8, 0.001, f_jit)

    if i < n_assim_windows:
        eta = irfft(y_hat[:half])

        # Copy: a plain row slice is a view, and filling it would overwrite X_train.
        eta_obs = X_train[index, i + 1, :].copy()
        eta_obs[cut_index:] = eta[cut_index:]  # unobserved region: trust the model
        eta_hat_obs = rfft(eta_obs)

        # Take the innovation BEFORE correcting eta. With 0.5 * L1 == 1 the corrected
        # eta equals the observation, so an innovation recomputed afterwards is zero
        # and the L2 (phi) correction would never act.
        innovation = eta_hat_obs - y_hat[:half]
        y_hat = y_hat.at[:half].add(0.5 * L1 * innovation)
        y_hat = y_hat.at[half:].add(0.5 * L2 * innovation)

In [None]:
# "Bouncing" assimilation: sweep the observer forward over the snapshots and back
# again (time-reversed) to refine the initial state, then forecast to the prediction
# time and score it. errors[b] is the forecast error on [cut_index, cut_index_2)
# after b bounces.

N_ASSIM_WINDOWS = 6       # snapshots assimilated after the initial one
N_FORECAST_WINDOWS = 14   # free-running windows after the last snapshot
STEPS_PER_WINDOW = 400    # RK4 steps (dt argument 0.025) between consecutive snapshots
MAX_BOUNCES = 20


def _advance_window(state):
    """Integrate `state` across one snapshot interval (STEPS_PER_WINDOW RK4 steps)."""
    for _ in range(STEPS_PER_WINDOW):
        state = rk4_step_jit(200, state, 0.025, X_train.shape[-1] // 2, g, k[1], 8, 0.001, f_jit)
    return state


def _assimilate(state, snapshot):
    """Correct `state` towards `snapshot`, which is only trusted for x < x[cut_index].

    The unobserved part of the snapshot is filled with the model's own eta; the same
    innovation then drives both the eta (gain L1) and phi (gain L2) corrections.
    """
    half = state.shape[0] // 2
    eta_model = irfft(state[:half])
    eta_obs = snapshot.copy()  # never write into the caller's snapshot array
    eta_obs[cut_index:] = eta_model[cut_index:]
    # Innovation must be taken before eta is corrected: with 0.5 * L1 == 1 the
    # corrected eta equals the observation, so a post-update innovation is zero.
    innovation = rfft(eta_obs) - state[:half]
    state = state.at[:half].add(0.5 * L1 * innovation)
    return state.at[half:].add(0.5 * L2 * innovation)


def _negate_phi(state):
    """Flip the sign of the phi half of the state (exp(+-1j*pi) == -1)."""
    return state.at[state.shape[0] // 2:].multiply(-1)


def _bounce(state, snapshots):
    """One forward sweep over `snapshots`, then one time-reversed sweep back to snapshot 0."""
    for i in range(N_ASSIM_WINDOWS):
        state = _assimilate(_advance_window(state), snapshots[i + 1])
    # Negating phi while walking the snapshots in reverse runs the sweep backwards.
    state = _negate_phi(state)
    reversed_snapshots = snapshots[::-1]
    for i in range(N_ASSIM_WINDOWS):
        state = _assimilate(_advance_window(state), reversed_snapshots[i + 1])
    return _negate_phi(state)


def _forecast(state, snapshots):
    """Assimilate the snapshots after snapshot 0, then free-run to the prediction time."""
    for i in range(N_ASSIM_WINDOWS + N_FORECAST_WINDOWS):
        state = _advance_window(state)
        if i < N_ASSIM_WINDOWS:
            state = _assimilate(state, snapshots[i + 1])
    return state


X_train_local = X_train[index, :, :].copy()

eta_0 = X_train_local[0, :]
eta_hat_0 = rfft(eta_0)
phi_hat_0 = eta_hat_0[1:] * np.exp(-1.j * np.pi / 2) * np.sqrt(g / k[1:])
phi_hat_0 = np.insert(phi_hat_0, 0, 0)

y_hat = jnp.asarray(np.concatenate((eta_hat_0, phi_hat_0)), dtype=jnp.complex128)

errors = []
for bounces in range(MAX_BOUNCES):
    # Bounces are cumulative: the b-bounce state is the (b-1)-bounce state plus one
    # more bounce, so it is not recomputed from scratch for every b.
    if bounces > 0:
        y_hat = _bounce(y_hat, X_train_local)

    # Score the forecast at the prediction time, not the state at the first snapshot.
    y_forecast = _forecast(y_hat, X_train_local)
    eta_pred = irfft(y_forecast[:y_forecast.shape[0] // 2])
    error = eta_pred[cut_index:cut_index_2] - y_train[index, cut_index:cut_index_2]
    errors.append(error)

mae_per_bounce = [np.mean(np.abs(e)) for e in errors]
rmse_per_bounce = [np.sqrt(np.mean(e ** 2)) for e in errors]

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np 

# Animate eta(x) frame by frame and save it as a GIF.
# NOTE(review): `saved_data` and `saved_time` are not created by any cell in this
# notebook, so this cell only works with leftover kernel state. Fail fast with a clear
# message instead of an opaque NameError halfway through figure setup.
_missing_inputs = [name for name in ("saved_data", "saved_time") if name not in globals()]
if _missing_inputs:
    raise NameError(
        "animation cell requires " + ", ".join(_missing_inputs)
        + ", which no earlier cell defines"
    )

fig, ax = plt.subplots(figsize=(12, 6))
line, = ax.plot([], [], lw=2)

ax.set_xlim(x[0], x[-1])
# Fixed y-limits from saved_data[0] over all frames; presumably saved_data is indexed
# (sample, frame, x) — confirm against wherever it is produced.
ax.set_ylim(np.min(saved_data[0]), np.max(saved_data[0]))
ax.set_xlabel("x")
ax.set_ylabel("Amplitude")
ax.set_title("Wave evolution over time")
ax.grid(True)

def update(frame):
    """Draw frame `frame` of saved_data[0] and show its time stamp in the title."""
    ax.set_title(f"Time: {saved_time[frame]:.1f}s")

    y = saved_data[0, frame, :]
    line.set_data(x, y)
    return line,

ani = animation.FuncAnimation(fig, update, frames=saved_data.shape[1], blit=True)
ani.save("wave_animation.gif", writer='pillow', fps=15)

In [None]:
# Forecast from the current y_hat: assimilate the 6 snapshots after the first one,
# then free-run 14 more windows to the prediction time.
# NOTE(review): relies on y_hat / X_train_local left by the bouncing cell above
# (there y_hat ends at the first snapshot time) — hidden-state dependency.
half = y_hat.shape[0] // 2

for i in range(6 + 14):
    for _ in range(400):
        y_hat = rk4_step_jit(200, y_hat, 0.025, X_train.shape[-1] // 2, g, k[1], 8, 0.001, f_jit)

    if i < 6:
        eta = irfft(y_hat[:half])

        eta_obs = X_train_local[i + 1, :].copy()
        eta_obs[cut_index:] = eta[cut_index:]  # unobserved region: trust the model
        eta_hat_obs = rfft(eta_obs)

        X_train_local[i + 1, :] = eta_obs  # keep the merged snapshot

        # Take the innovation BEFORE correcting eta; recomputing it after the eta
        # update (0.5 * L1 == 1) gives zero and silently disables the phi correction.
        innovation = eta_hat_obs - y_hat[:half]
        y_hat = y_hat.at[:half].add(0.5 * L1 * innovation)
        y_hat = y_hat.at[half:].add(0.5 * L2 * innovation)

In [None]:
# Compare the forecast surface with the target over the evaluation window and report
# the errors there.
n_eta = y_hat.shape[0] // 2
eta_pred = irfft(y_hat[:n_eta])

fig, ax = plt.subplots()
ax.plot(x, eta_pred, label="Prediction")
ax.plot(x, y_train[index, :], label="Actual")
ax.set_xlim(x[cut_index], x[cut_index_2])
ax.legend()
ax.grid()
plt.show()

eval_window = slice(cut_index, cut_index_2)
error = eta_pred[eval_window] - y_train[index, eval_window]

mae_value = np.mean(np.abs(error))
rmse_value = np.sqrt(np.mean(np.square(error)))
print(f"MAE: {mae_value:.5e}")
print(f"RMSE: {rmse_value:.5e}")

In [None]:
# Forecast error against number of bounces (0..10), with the errors of the baseline
# labelled "ML" drawn as dashed reference lines.
# NOTE(review): the values are hard-coded, presumably copied from earlier runs of the
# cells above — regenerate them if the observer changes.
mae = [0.091, 0.074, 0.071, 0.068, 0.065, 0.063, 0.061, 0.059, 0.057, 0.055, 0.054]
rmse = [0.111, 0.093, 0.086, 0.081, 0.077, 0.074, 0.072, 0.069, 0.068, 0.066, 0.065]
bounce_counts = range(len(mae))

fig, ax = plt.subplots()
ax.plot(bounce_counts, mae, label="MAE", color="k")
ax.plot(bounce_counts, rmse, label="RMSE", color="b")
ax.set_ylim(0, rmse[0])
ax.set_xlim(0, len(mae) - 1)
ax.grid()
ax.set_xlabel("Bounces [-]")
ax.set_ylabel("Error [m]")
ax.axhline(y=0.062, label="ML MAE", color="k", linestyle="--")
ax.axhline(y=0.080, label="ML RMSE", color="b", linestyle="--")
ax.legend()
plt.show()