In [2]:
import sys
import os
import glob
import h5py
import numpy as np
import matplotlib.pyplot as plt
from numpy.fft import rfft2, irfft2
from matplotlib.animation import FuncAnimation

current_dir = os.getcwd()
sys.path.append(os.path.abspath(os.path.join(current_dir, '..')))
h5_files = glob.glob(os.path.join("..\\output", "*.h5"))

y, Hs, Tp, modes, time, length, x = None, None, None, None, None, None, None

with h5py.File("Z:\\files\\simulation_2d_compressed.h5", "r") as data:
    y = data["y"][:]
    Hs = data.attrs["Hs"]
    Tp = data.attrs["Tp"]

    modes = data.attrs["modes"]
    length = data.attrs["length"]
    Ta = data.attrs["Ta"]
    x = np.linspace(0, length, 2*modes)

index = 100

In [3]:
mHOS = 4

eta_hat = y[index:, 0, : :].copy()
modes = eta_hat.shape[-2] // 2
x = np.linspace(0, length, 2*modes)

mes_index_1 = np.argmin(np.abs(x - 1300))
mes_index_2 = np.argmin(np.abs(x - 1700))

eta = irfft2(eta_hat)
eta[:, mes_index_1:mes_index_2, mes_index_1:mes_index_2] = 0

##
for i in range(mes_index_1, mes_index_2):
    eta[0, mes_index_1:mes_index_2, i] = (eta[0, mes_index_1-1, i] + eta[0, mes_index_2+1, i]) * 0.5 # juks?
##

eta_hat = rfft2(eta).astype(np.complex64)

# plt.plot(np.linspace(0, 1, 2*modes), np.fft.irfft2(eta_hat[0, :, :])[0, :])

modes = eta_hat.shape[-2] // 2
alias_mask = np.arange(modes+1) < (modes * 2 / (mHOS + 1) + 1) * 0.8
alias_mask_long = np.concatenate((alias_mask, alias_mask[2:][::-1]))
new_modes = np.sum(alias_mask)-1
eta_hat = eta_hat[:, alias_mask_long, :]
eta_hat = eta_hat[:, :, alias_mask]
eta_hat *= new_modes * new_modes / (modes * modes)
modes = new_modes

# plt.plot(np.linspace(0, 1, 2*modes), np.fft.irfft2(eta_hat[0, :, :])[0, :])
# plt.xlim(0, 0.1)
# plt.show()

In [4]:
prediction_time = 60 # 140 sec
measure_time = 60 # one minute
num_measurements = 6
step = int(measure_time / num_measurements)
train_percentage = 0.8

X = eta_hat[:-prediction_time]
y = eta_hat[prediction_time+measure_time:-1]

X = irfft2(X)
y = irfft2(y)

X = np.stack([
    X[0*step:-measure_time+0*step-1],
    X[1*step:-measure_time+1*step-1],
    X[2*step:-measure_time+2*step-1],
    X[3*step:-measure_time+3*step-1],
    X[4*step:-measure_time+4*step-1],
    X[5*step:-measure_time+5*step-1],
    X[6*step:-measure_time+6*step-1],
], axis=1)

X_train = X[:int(X.shape[0]*train_percentage), :, :]
X_test = X[int(X.shape[0]*train_percentage):, :, :]
y_train = y[:int(y.shape[0]*train_percentage), :]
y_test = y[int(y.shape[0]*train_percentage):, :]

# Only need std to normalize

std = np.std(X_train)
X_train = X_train / std
y_train = y_train / std

print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)

(624, 7, 328, 328) (624, 328, 328)
(156, 7, 328, 328) (156, 328, 328)


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from FNO import FNO2d

device = torch.device("cuda")

val_mask = np.zeros(X_train.shape[0], dtype=bool)
val_mask[:int(0.1 * X_train.shape[0])] = True
np.random.shuffle(val_mask)

# to PyTorch tensors
X_t = torch.from_numpy(X_train[~val_mask]).float()
y_t = torch.from_numpy(y_train[~val_mask]).float()

X_v = torch.from_numpy(X_train[val_mask]).float()
y_v = torch.from_numpy(y_train[val_mask]).float()

dataset = TensorDataset(X_t, y_t)
loader  = DataLoader(dataset, batch_size=25, shuffle=True)

dataset_val = TensorDataset(X_v, y_v)
loader_val  = DataLoader(dataset_val, batch_size=25, shuffle=True)

model = FNO2d(in_channels=7, out_channels=1, width=16, modes_height=30, modes_width=150, depth=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

early_stopping_rounds = 3

early_stopping_count = 0
best_loss = np.inf

loss_mask = np.ones(y_train.shape[1:])
loss_mask[mes_index_1:mes_index_2, mes_index_1:mes_index_2] = 5
loss_mask = torch.tensor(loss_mask, dtype=torch.float32).to(device)
loss_correction = 4 * modes * modes

model.train()
for epoch in range(1, 101):
    total_loss = 0.0
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device)
        optimizer.zero_grad()
        pred = model(xb)
        loss = criterion(pred * loss_mask, yb * loss_mask) * loss_correction / torch.sum(loss_mask)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg = total_loss / len(loader)

    total_loss_val = 0.0
    for xb, yb in loader_val:
        xb = xb.to(device)
        yb = yb.to(device)
        optimizer.zero_grad()
        pred = model(xb)
        loss = criterion(pred * loss_mask, yb * loss_mask) * loss_correction / torch.sum(loss_mask)
        total_loss_val += loss.item()
    avg_val = total_loss / len(loader_val)

    if avg_val < best_loss:
        best_loss = avg_val
        early_stopping_count = 0
        torch.save(model.state_dict(), f"results/best_FNO_2d.pt")
        print(f"Epoch: {epoch}, Loss train - {avg:.6f}, Loss val - {avg_val:.6f} - Saving")
    else:
        print(f"Epoch: {epoch}, Loss train - {avg:.6f}, Loss val - {avg_val:.6f}")
        early_stopping_count += 1
        if early_stopping_count > early_stopping_rounds:
            print("early_stopping")
            break

Epoch: 1, Loss train - 0.744830, Loss val - 5.710367 - Saving
Epoch: 2, Loss train - 0.218387, Loss val - 1.674300 - Saving
Epoch: 3, Loss train - 0.110089, Loss val - 0.844014 - Saving
Epoch: 4, Loss train - 0.091685, Loss val - 0.702916 - Saving
Epoch: 5, Loss train - 0.081920, Loss val - 0.628054 - Saving
Epoch: 6, Loss train - 0.074261, Loss val - 0.569333 - Saving
Epoch: 7, Loss train - 0.067592, Loss val - 0.518209 - Saving
Epoch: 8, Loss train - 0.061416, Loss val - 0.470853 - Saving
Epoch: 9, Loss train - 0.056168, Loss val - 0.430618 - Saving
Epoch: 10, Loss train - 0.051308, Loss val - 0.393364 - Saving
Epoch: 11, Loss train - 0.046812, Loss val - 0.358890 - Saving
Epoch: 12, Loss train - 0.042826, Loss val - 0.328331 - Saving
Epoch: 13, Loss train - 0.039091, Loss val - 0.299698 - Saving
Epoch: 14, Loss train - 0.035775, Loss val - 0.274274 - Saving
Epoch: 15, Loss train - 0.032872, Loss val - 0.252015 - Saving
Epoch: 16, Loss train - 0.030314, Loss val - 0.232404 - Saving
E