# In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchdiffeq import odeint

# ---------- Vessel Dynamics Parameters ----------
# Inertia matrix: vessel_dynamics solves M @ nu_dot = (forces) for the
# body-frame accelerations of [u, v, r]. Units presumably SI (kg, kg*m^2) —
# confirm against the vessel data source.
M = torch.tensor([[3980.0, 0, 0],
                  [0, 3980.0, 0],
                  [0, 0, 19703.0]], dtype=torch.float32)
# Per-DOF damping coefficients; vessel_dynamics applies them as
#   D_l * nu + D_q * |nu| * nu + D_c * nu**2 * nu
# i.e. linear, quadratic and cubic damping respectively.
D_l = torch.tensor([50.0, 200.0, 1281.0], dtype=torch.float32)
D_q = torch.tensor([135.0, 2000.0, 0.0], dtype=torch.float32)
D_c = torch.tensor([0.0, 0.0, 3224.0], dtype=torch.float32)

# ---------- Dynamics Function ----------
def vessel_dynamics(t, state, tau):
    """Return the time derivative of a 3-DOF planar vessel state.

    Args:
        t: Current time. Unused; accepted so the function fits the
            ``f(t, y)``-style callback used by the ODE integrator.
        state: 1-D tensor ``[x, y, psi, u, v, r]`` — the pose ``eta`` followed
            by the body-frame velocities ``nu``.
        tau: 1-D tensor of 3 generalized forces acting on ``[u, v, r]``.

    Returns:
        1-D tensor of length 6: ``[eta_dot, nu_dot]``.

    The Coriolis and rotation matrices are assembled with ``torch.stack``
    rather than ``torch.tensor`` so they remain on the autograd graph
    (``torch.tensor`` copies its elements into a fresh leaf tensor).
    """
    eta = state[:3]  # [x, y, psi]
    nu = state[3:]   # [u, v, r]
    psi = eta[2]
    u, v = nu[0], nu[1]

    # Diagonal damping: linear + quadratic (|nu|) + cubic (nu**2) terms.
    D = D_l + D_q * nu.abs() + D_c * nu ** 2

    # Coriolis/centripetal matrix; the mass is the surge entry of M.
    m = M[0, 0]
    zero = torch.zeros_like(u)
    C = torch.stack([
        torch.stack([zero, zero, -m * v]),
        torch.stack([zero, zero, m * u]),
        torch.stack([m * v, -m * u, zero]),
    ])

    nu_dot = torch.linalg.solve(M, tau - D * nu - C @ nu)

    # Rotation by the heading angle psi, mapping body velocities to eta_dot.
    c, s = torch.cos(psi), torch.sin(psi)
    zero_psi = torch.zeros_like(psi)
    one = torch.ones_like(psi)
    R = torch.stack([
        torch.stack([c, -s, zero_psi]),
        torch.stack([s, c, zero_psi]),
        torch.stack([zero_psi, zero_psi, one]),
    ])
    eta_dot = R @ nu

    return torch.cat([eta_dot, nu_dot])

# ---------- Generate Data ----------
def generate_dataset(N_samples=100, T=10.0, dt=0.1, tau_scale=1000.0, generator=None):
    """Simulate vessel trajectories driven by random piecewise-constant forces.

    For each sample, one 3-component force is drawn per grid step (uniform
    in ``[0, tau_scale)``). That force is held constant until the next grid
    point, and ``vessel_dynamics`` is integrated from the zero state with
    ``odeint``.

    Args:
        N_samples: Number of trajectories to simulate.
        T: End of the time grid (exclusive, as for ``torch.arange``).
        dt: Grid spacing.
        tau_scale: Upper bound of the uniform force draw (default 1000.0,
            the value previously hard-coded).
        generator: Optional ``torch.Generator`` for reproducible draws;
            ``None`` uses the global RNG.

    Returns:
        ``(dataset, time)``. ``dataset`` is a list of ``(tau_signal, traj)``
        pairs, where ``tau_signal`` has shape ``(len(time), 3)`` and ``traj``
        is ``odeint``'s output evaluated on ``time``. ``time`` is the grid.
    """
    time = torch.arange(0, T, dt)
    n_steps = len(time)
    dataset = []
    for _ in range(N_samples):
        tau_signal = torch.rand((n_steps, 3), generator=generator) * tau_scale
        state0 = torch.zeros(6)  # [x, y, psi, u, v, r]

        # Zero-order hold: apply the force of the last grid point <= t.
        # Looking t up on the grid itself avoids int(t / dt) truncating a
        # float quotient like 6.999... down to 6 and picking the previous
        # step's force. The default argument binds this sample's signal.
        def dyn(t, state, tau_signal=tau_signal):
            query = t.detach().reshape(1).to(time.dtype)
            idx = torch.searchsorted(time, query, right=True).item() - 1
            idx = min(max(idx, 0), n_steps - 1)
            return vessel_dynamics(t, state, tau_signal[idx])

        traj = odeint(dyn, state0, time)
        dataset.append((tau_signal, traj))

    return dataset, time

# ---------- DeepONet Architecture ----------
class BranchNet(nn.Module):
    """MLP that encodes a whole input sequence into one feature vector.

    Args:
        input_dim: Number of features after flattening every non-batch
            dimension, e.g. ``seq_len * 3`` for a ``(B, seq_len, 3)`` input.
        output_dim: Size of the output feature vector.
        hidden_dim: Width of the two hidden layers (default 128).
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, tau_seq):
        """Flatten all non-batch dims of ``tau_seq`` and apply the MLP.

        ``reshape`` (not ``view``) so non-contiguous inputs, e.g. transposed
        tensors, are accepted instead of raising.
        """
        return self.net(tau_seq.reshape(tau_seq.shape[0], -1))

class TrunkNet(nn.Module):
    """MLP that encodes query coordinates (here: time) into a feature vector.

    Args:
        input_dim: Size of the last dimension of the query tensor.
        output_dim: Size of the output feature vector.
        hidden_dim: Width of the two hidden layers (default 128).
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, t):
        """Apply the MLP to query points ``t`` of shape ``(..., input_dim)``."""
        return self.net(t)

class DeepONet(nn.Module):
    """Operator network combining a branch (input signal) and trunk (query) MLP.

    The branch and trunk features are multiplied element-wise and mapped to
    ``output_dim`` values by a final linear layer.

    Args:
        branch_dim: Flattened size of one input sequence for ``BranchNet``.
        trunk_dim: Size of one query point for ``TrunkNet``.
        output_dim: Number of predicted values per query.
        latent_dim: Shared width of the branch/trunk outputs (default 128).
    """

    def __init__(self, branch_dim, trunk_dim, output_dim, latent_dim=128):
        super().__init__()
        self.branch = BranchNet(branch_dim, latent_dim)
        self.trunk = TrunkNet(trunk_dim, latent_dim)
        self.fc = nn.Linear(latent_dim, output_dim)

    def forward(self, tau_seq, t):
        """Predict outputs for each (sequence, query) pair in the batch.

        Row ``i`` of ``tau_seq`` is paired with row ``i`` of ``t``; both must
        share the same batch size.
        """
        branch_feat = self.branch(tau_seq)
        trunk_feat = self.trunk(t)
        return self.fc(branch_feat * trunk_feat)

# ---------- Training ----------
if __name__ == '__main__':
    # Seed the global RNG so data generation, weight init and the test sample
    # are reproducible across runs.
    torch.manual_seed(0)

    # Generate dataset
    dataset, time = generate_dataset(N_samples=20, T=5.0, dt=0.1)

    # Forces are drawn in [0, 1000) by generate_dataset's default; scale them
    # to O(1) before they reach the branch MLP so its first layer is not fed
    # inputs in the thousands. The same scaling is applied at test time.
    TAU_SCALE = 1000.0

    # Prepare data: one training row per (sample, time step). Row
    # s * n_steps + i pairs sample s's full force sequence with time[i] and
    # its state traj[s][i] (the same ordering as a nested per-element loop).
    tau_all = torch.stack([tau_seq for tau_seq, _ in dataset])  # (N, n_steps, 3)
    traj_all = torch.stack([traj for _, traj in dataset])       # (N, n_steps, 6)
    n_samples, n_steps = traj_all.shape[:2]

    tau_batch = tau_all.repeat_interleave(n_steps, dim=0) / TAU_SCALE
    t_batch = time.repeat(n_samples).unsqueeze(1)
    y_batch = traj_all.reshape(n_samples * n_steps, -1)

    model = DeepONet(branch_dim=tau_all.shape[1] * tau_all.shape[2], trunk_dim=1, output_dim=6)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Full-batch training.
    model.train()
    for epoch in range(1000):
        pred = model(tau_batch, t_batch)
        loss = ((pred - y_batch) ** 2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if epoch % 100 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.6f}")

    # ---------- Testing and Comparison ----------
    test_dataset, test_time = generate_dataset(N_samples=1, T=5.0, dt=0.1)
    tau_seq, true_traj = test_dataset[0]

    # DeepONet prediction: the same (scaled) force sequence for every query time.
    t_query = test_time.unsqueeze(1)
    tau_input = (tau_seq / TAU_SCALE).unsqueeze(0).repeat(len(t_query), 1, 1)
    model.eval()
    with torch.no_grad():
        pred_traj = model(tau_input, t_query)

    # Plot comparison: one panel per state component in a single figure.
    labels = ["x", "y", "psi", "u", "v", "r"]
    fig, axes = plt.subplots(3, 2, figsize=(10, 9), sharex=True)
    for i, (label, ax) in enumerate(zip(labels, axes.flat)):
        ax.plot(test_time.numpy(), true_traj[:, i].numpy(), label=f"True {label}")
        ax.plot(test_time.numpy(), pred_traj[:, i].numpy(), '--', label=f"Pred {label}")
        ax.set_xlabel("Time [s]")
        ax.set_ylabel(label)
        ax.set_title(f"DeepONet vs ODE: {label}")
        ax.grid(True)
        ax.legend()
    fig.tight_layout()
    plt.show()




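# Cell output: ModuleNotFoundError: No module named 'torchdiffeq'
# The `torchdiffeq` package is not installed; install it (`pip install torchdiffeq`) before running this cell.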