Cell 1 — Imports

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

from physinet.marine.gym import MarineWaveGym
from physinet.marine.fno_marine import MarineFNO
from physinet.marine.decision_support import compute_operability_score


Cell 2 — Generate Synthetic Dataset

In [None]:
# Seed the global NumPy and PyTorch RNGs before anything stochastic runs, so
# that Restart Kernel -> Run All reproduces the same dataset and model init.
# NOTE(review): whether MarineWaveGym draws from these global RNGs is not
# visible here -- confirm, or pass a seed to the gym directly if it accepts one.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Build the synthetic wave environment and roll out a small demo dataset.
gym = MarineWaveGym(nx=64, ny=64, nt=20, dt=0.1)
dataset = gym.rollout_dataset(n_samples=20)

print("Generated samples:", len(dataset))


Cell 3 — Prepare Training Tensors


Input: `eta[:, :, 0]` — the initial wave elevation field.

Target: `eta[:, :, 10]` — the wave elevation at time index 10, which the model learns to predict.

In [None]:
# Time indices along the last axis of each sample's "eta" array: the model is
# given the field at INPUT_STEP and trained to reproduce it at TARGET_STEP.
INPUT_STEP = 0
TARGET_STEP = 10

# Collect per-sample fields in separate lists so that `inputs` / `targets`
# only ever refer to the final tensors.
input_fields = []
target_fields = []

for sample in dataset:
    eta = sample["eta"]  # shape [nx,ny,nt]
    eta_input = eta[:, :, INPUT_STEP]
    eta_target = eta[:, :, TARGET_STEP]

    input_fields.append(eta_input[..., None])   # add a trailing channel axis
    target_fields.append(eta_target[..., None])

# Stack along a new leading batch axis and convert to float32 tensors.
inputs = torch.tensor(np.stack(input_fields), dtype=torch.float32)
targets = torch.tensor(np.stack(target_fields), dtype=torch.float32)

print("Input shape:", inputs.shape)   # [B,64,64,1]
print("Target shape:", targets.shape) # [B,64,64,1]


Cell 4 — Instantiate Model

In [None]:
# Hyperparameters for the demo run, named so they are easy to find and tune.
fno_modes = 12
fno_width = 32
learning_rate = 1e-3

# Network, mean-squared-error objective, and an Adam optimiser over all of
# the network's parameters.
model = MarineFNO(modes=fno_modes, width=fno_width)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


Cell 5 — Training Loop 
Short demo run (3–5 epochs).

In [None]:
# Full-batch training: each epoch is one forward/backward pass over every
# sample followed by a single optimiser step.
n_epochs = 5

for epoch in range(n_epochs):
    pred = model(inputs)
    loss = loss_fn(pred, targets)

    # Clear gradients left over from the previous step before accumulating
    # the new ones.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch+1}, Loss = {loss.item():.5f}")


Cell 6 — Inference on One Sample

In [None]:
# Pick one sample; the slice (rather than a plain index) keeps the batch axis
# so the model sees the same rank as during training.
idx = 0
x0 = inputs[idx:idx+1]
yt = targets[idx].squeeze().detach().cpu().numpy()

# Run the forward pass in eval mode so any mode-dependent layers (dropout,
# batch norm, ...) behave deterministically, then restore the previous mode
# so re-running the training cell afterwards is unaffected.
# NOTE(review): assumes MarineFNO is a torch.nn.Module -- confirm.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        # .cpu() is a no-op on CPU tensors and keeps .numpy() valid if the
        # model/tensors are later moved to an accelerator.
        y_pred = model(x0).squeeze().cpu().numpy()
finally:
    model.train(was_training)


Cell 7 — Visualization

In [None]:
# Ground truth and prediction share one colour range so the two panels can be
# compared by eye; the initial field keeps its own automatic range.
shared_vmin = float(min(yt.min(), y_pred.min()))
shared_vmax = float(max(yt.max(), y_pred.max()))
shared_range = {"vmin": shared_vmin, "vmax": shared_vmax}

# (title, 2-D field, extra imshow keyword arguments) for each panel.
# The target shown here is the field at time index 10 of each rollout, not a
# single step ahead -- keep this label in sync with the target index used
# when the training tensors are built.
panels = [
    ("Initial Wave η(x,y,t0)", x0.squeeze().detach().cpu().numpy(), {}),
    ("True Wave η (time index 10)", yt, shared_range),
    ("Predicted Wave", y_pred, shared_range),
]

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

for ax, (title, field, imshow_kwargs) in zip(axes, panels):
    ax.set_title(title)
    image = ax.imshow(field, **imshow_kwargs)
    fig.colorbar(image, ax=ax)

fig.tight_layout()
plt.show()


Cell 8 — Operability Score

In [None]:
# Threshold forwarded to the scoring helper.
# NOTE(review): its units and meaning are defined by
# compute_operability_score and are not visible here -- confirm.
operability_threshold = 0.5

# Score the predicted field for the sample chosen in the inference cell.
score = compute_operability_score(y_pred, threshold=operability_threshold)
print("Operability Score:", score)
