In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchdiffeq import odeint

# ---------- Device Configuration ----------
# Pick the first available backend, probing in the order CUDA -> MPS -> CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# ---------- Vessel Dynamics Parameters ----------
# Shared construction options: every parameter is float32 on the selected device.
_param_opts = {"dtype": torch.float32, "device": device}

# Matrix that `vessel_dynamics` solves against to obtain nu_dot (diagonal).
M = torch.tensor(
    [
        [3980.0, 0.0, 0.0],
        [0.0, 3980.0, 0.0],
        [0.0, 0.0, 19703.0],
    ],
    **_param_opts,
)
# Per-axis damping coefficients; `vessel_dynamics` combines them as
# D = D_l + D_q * |nu| + D_c * nu**2 and applies D * nu.
D_l = torch.tensor([50.0, 200.0, 1281.0], **_param_opts)   # constant term
D_q = torch.tensor([135.0, 2000.0, 0.0], **_param_opts)    # term scaled by |nu|
D_c = torch.tensor([0.0, 0.0, 3224.0], **_param_opts)      # term scaled by nu**2

# ---------- Dynamics Function ----------
def vessel_dynamics(t, state, tau):
    """Return the time derivative of the 6-element vessel state.

    Args:
        t: Current time. Not used in the computation; present because the
            ODE solver callback signature is ``f(t, state)``.
        state: 1-D tensor of 6 elements. The first three are ``eta`` (its
            third entry is the angle ``psi`` used to build the rotation),
            the last three are ``nu = (u, v, r)``.
        tau: 1-D tensor of 3 forcing terms, in the same axis order as ``nu``.

    Returns:
        1-D tensor of 6 elements: ``eta_dot`` followed by ``nu_dot``.

    Reads the module-level parameters ``M``, ``D_l``, ``D_q`` and ``D_c``.
    """
    eta, nu = state[:3], state[3:]
    u, v, r = nu
    psi = eta[2]

    # The parameters were created on the notebook-wide device; follow the
    # state's device/dtype instead so a caller that simulates elsewhere
    # (e.g. generate_dataset(device='cpu')) does not hit a device mismatch.
    # `.to` is a no-op when device and dtype already match.
    target = {"device": state.device, "dtype": state.dtype}
    mass = M.to(**target)
    damping = (
        D_l.to(**target)
        + D_q.to(**target) * nu.abs()
        + D_c.to(**target) * nu ** 2
    )

    # Build the matrices with torch.stack rather than torch.tensor([...]):
    # torch.tensor copies every 0-d element through a Python scalar, which
    # forces a host round-trip per entry and cuts the autograd graph.
    zero = torch.zeros_like(u)
    one = torch.ones_like(u)
    m11, m22 = mass[0, 0], mass[1, 1]
    C = torch.stack([
        torch.stack([zero, zero, -m22 * v]),
        torch.stack([zero, zero, m11 * u]),
        torch.stack([m22 * v, -m11 * u, zero]),
    ])

    nu_dot = torch.linalg.solve(mass, tau - damping * nu - C @ nu)

    c, s = torch.cos(psi), torch.sin(psi)
    R = torch.stack([
        torch.stack([c, -s, zero]),
        torch.stack([s, c, zero]),
        torch.stack([zero, zero, one]),
    ])
    eta_dot = R @ nu
    return torch.cat([eta_dot, nu_dot])

# ---------- Generate Data ----------
def generate_dataset(N_samples=100, T=10.0, dt=0.1, device='cpu'):
    """Simulate vessel trajectories driven by random, piecewise-constant forcing.

    Args:
        N_samples: Number of independent trajectories to simulate.
        T: End of the time grid (exclusive, as in ``torch.arange``).
        dt: Spacing of the time grid and width of each forcing interval.
        device: Device on which the time grid, forcing and initial state
            are created.

    Returns:
        ``(dataset, time)`` where ``time`` is the 1-D grid
        ``arange(0, T, dt)`` and ``dataset`` is a list of
        ``(tau_signal, traj)`` pairs. ``tau_signal`` has shape
        ``(len(time), 3)`` with entries drawn uniformly from [0, 1000);
        ``traj`` is whatever ``odeint`` returns for the grid ``time``
        (presumably ``(len(time), 6)`` -- confirm against torchdiffeq).
    """
    time = torch.arange(0, T, dt, device=device)
    n_steps = len(time)
    last_index = n_steps - 1

    def simulate_one():
        # Draw the forcing first, then the initial state, then integrate.
        tau_signal = torch.rand((n_steps, 3), device=device) * 1000
        state0 = torch.zeros(6, device=device)

        def dyn(t, state):
            # Zero-order hold: use the forcing sample of the interval that
            # contains t, clamped to the last available sample.
            idx = int(min(t.item() / dt, last_index))
            return vessel_dynamics(t, state, tau_signal[idx])

        return tau_signal, odeint(dyn, state0, time)

    dataset = [simulate_one() for _ in range(N_samples)]
    return dataset, time

# ---------- DeepONet Architecture ----------
class BranchNet(nn.Module):
    """Three-layer ReLU MLP that encodes a flattened input sequence.

    Args:
        input_dim: Number of features per sample after flattening every
            dimension except the first (batch) one.
        output_dim: Size of the encoding produced for each sample.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 128), nn.ReLU(),
            nn.Linear(128, 128), nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, tau_seq):
        """Flatten each sample of ``tau_seq`` and pass it through the MLP.

        Args:
            tau_seq: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened and must total ``input_dim``.

        Returns:
            Tensor of shape ``(tau_seq.shape[0], output_dim)``.
        """
        # reshape() instead of view(): view() raises on non-contiguous input
        # (e.g. after transpose/permute), reshape() copies only when needed.
        flat = tau_seq.reshape(tau_seq.shape[0], -1)
        return self.net(flat)

class TrunkNet(nn.Module):
    """Three-layer ReLU MLP applied directly to its input.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden = 128
        # Layers are created in forward order so parameter initialisation
        # consumes the RNG in the same sequence as the Sequential indices.
        layers = [
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, t):
        """Map ``t`` through the MLP and return the result."""
        features = self.net(t)
        return features

class DeepONet(nn.Module):
    """Branch/trunk network pair whose features are fused by a linear head.

    Args:
        branch_dim: Flattened input size handed to ``BranchNet``.
        trunk_dim: Input size handed to ``TrunkNet``.
        output_dim: Number of output channels produced by the final layer.
        latent_dim: Width of the feature vectors produced by both the branch
            and the trunk network. Defaults to 128, the previously
            hard-coded value.
    """

    def __init__(self, branch_dim, trunk_dim, output_dim, latent_dim=128):
        super().__init__()
        self.branch = BranchNet(branch_dim, latent_dim)
        self.trunk = TrunkNet(trunk_dim, latent_dim)
        self.fc = nn.Linear(latent_dim, output_dim)

    def forward(self, tau_seq, t):
        """Evaluate the network for paired input sequences and query points.

        Args:
            tau_seq: Batch of input sequences passed to the branch network.
            t: Batch of query points passed to the trunk network; row ``k``
                is paired with ``tau_seq[k]``.

        Returns:
            Output of the linear head applied to the element-wise product of
            branch and trunk features.
        """
        branch_features = self.branch(tau_seq)
        trunk_features = self.trunk(t)
        return self.fc(branch_features * trunk_features)

Using device: cpu


In [None]:
# ---------- Training ----------
if __name__ == '__main__':
    # Seed before the first random draw so the sampled forcing, the weight
    # initialisation and therefore the printed losses are repeatable.
    torch.manual_seed(0)

    dataset, time = generate_dataset(N_samples=20, T=50.0, dt=0.1, device=device)
    n_steps = len(time)

    # One training row per (trajectory, query time) pair, ordered
    # trajectory-major / time-minor:
    #   tau_batch[k] - the full forcing sequence of the trajectory
    #   t_batch[k]   - the query time
    #   y_batch[k]   - the simulated state at that time
    tau_all = torch.stack([tau_seq for tau_seq, _ in dataset])  # (N, n_steps, 3)
    tau_batch = (
        tau_all.unsqueeze(1)
        .expand(-1, n_steps, -1, -1)
        .reshape(-1, *tau_all.shape[1:])
        .to(device)
    )
    t_batch = time.unsqueeze(1).repeat(len(dataset), 1).to(device)
    y_batch = torch.cat([traj for _, traj in dataset], dim=0).to(device)

    model = DeepONet(branch_dim=tau_batch.shape[1] * 3, trunk_dim=1, output_dim=6).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Full-batch training on the mean squared error.
    n_epochs = 10000
    log_every = 1000
    model.train()
    for epoch in range(n_epochs):
        pred = model(tau_batch, t_batch)
        loss = ((pred - y_batch)**2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if epoch % log_every == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.6f}")

    # ---------- Testing and Comparison ----------
    # `test_tau` is a one-element dataset (list of (forcing, trajectory) pairs)
    # drawn after training, so it is not part of the training set.
    test_tau, test_time = generate_dataset(N_samples=1, T=50.0, dt=0.1, device=device)
    tau_seq, true_traj = test_tau[0]

    # Every query time is paired with the same full forcing sequence.
    t_query = test_time.unsqueeze(1)
    tau_input = tau_seq.unsqueeze(0).repeat(len(t_query), 1, 1)

    model.eval()
    with torch.no_grad():
        pred_traj = model(tau_input.to(device), t_query.to(device))

    # Convert once for plotting instead of once per curve.
    time_np = test_time.cpu().numpy()
    true_np = true_traj.cpu().numpy()
    pred_np = pred_traj.cpu().numpy()

    labels = ["x", "y", "psi", "u", "v", "r"]
    for i, name in enumerate(labels):
        fig, ax = plt.subplots()
        ax.plot(time_np, true_np[:, i], label=f"True {name}")
        ax.plot(time_np, pred_np[:, i], '--', label=f"Pred {name}")
        ax.set_xlabel("Time [s]")
        ax.set_ylabel(name)
        ax.legend()
        ax.set_title(f"DeepONet vs ODE: {name}")
        ax.grid(True)
        plt.show()

In [None]:
# ---------- Training ----------
if __name__ == '__main__':
    # Seed so the weight initialisation (and, if needed, the regenerated
    # data below) is repeatable.
    torch.manual_seed(0)

    # Reuse the trajectories produced by the previous training cell; only
    # regenerate them when this cell is run on a kernel where they do not
    # exist yet, so the cell no longer depends on hidden state.
    if 'dataset' not in globals() or 'time' not in globals():
        dataset, time = generate_dataset(N_samples=20, T=50.0, dt=0.1, device=device)
    n_steps = len(time)

    # One training row per (trajectory, query time) pair, ordered
    # trajectory-major / time-minor.
    tau_all = torch.stack([tau_seq for tau_seq, _ in dataset])  # (N, n_steps, 3)
    tau_batch = (
        tau_all.unsqueeze(1)
        .expand(-1, n_steps, -1, -1)
        .reshape(-1, *tau_all.shape[1:])
        .to(device)
    )
    t_batch = time.unsqueeze(1).repeat(len(dataset), 1).to(device)
    y_batch = torch.cat([traj for _, traj in dataset], dim=0).to(device)

    # Fresh model: this cell retrains from scratch with a longer schedule.
    model = DeepONet(branch_dim=tau_batch.shape[1] * 3, trunk_dim=1, output_dim=6).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    n_epochs = 50000
    log_every = 1000
    model.train()
    for epoch in range(n_epochs):
        pred = model(tau_batch, t_batch)
        loss = ((pred - y_batch)**2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if epoch % log_every == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.6f}")

In [None]:
# ---------- Test DeepONet on Constant Input Case ----------

# Define constant input (single source for both the network and the reference)
T_test = 50.0
dt = 0.1
TAU_CONST = [500.0, 0.0, 200.0]
test_time = torch.arange(0, T_test, dt, device=device)
n_steps = len(test_time)
tau_const = torch.tensor([TAU_CONST], device=device).repeat(n_steps, 1)

# Prepare DeepONet input
# The branch network expects the sequence length it was trained with; infer it
# from the first linear layer (3 forcing channels per time step).
L_train = model.branch.net[0].in_features // 3

# The input is constant, so build the branch sequence directly at length
# L_train. Slicing tau_const[:L_train] would come up short if n_steps < L_train.
tau_seq_test = tau_const[:1].repeat(L_train, 1)  # shape: [L_train, 3]
tau_input = tau_seq_test.unsqueeze(0).repeat(n_steps, 1, 1)  # [n_steps, L_train, 3]

t_query = test_time.unsqueeze(1)

# Predict trajectory using DeepONet
model.eval()
with torch.no_grad():
    pred_traj = model(tau_input, t_query).cpu().numpy()

# Simulate the reference trajectory with forward Euler (step dt).
# Note: the training data came from `odeint` inside generate_dataset, so small
# integrator-related differences from the training trajectories are expected.
# Rows of x: x, y, psi, u, v, r; column 0 is the all-zero initial state.
x = np.zeros((6, n_steps))

# NumPy copies of the vessel parameters. They deliberately use distinct names:
# rebinding M / D_l / D_q / D_c here would overwrite the torch tensors that
# vessel_dynamics reads and break any later call to generate_dataset.
M_np = np.diag([3980.0, 3980.0, 19703.0])
D_l_np = np.diag([50.0, 200.0, 1281.0])
D_q_np = np.diag([135.0, 2000.0, 0.0])
D_c_np = np.diag([0.0, 0.0, 3224.0])
M_inv_np = np.linalg.inv(M_np)  # loop-invariant, so invert once

def C_matrix(u, v, r):
    """Velocity-dependent matrix C(nu) used in the nu_dot equation below."""
    return np.array([[0, 0, -3980.0 * v],
                     [0, 0, 3980.0 * u],
                     [3980.0 * v, -3980.0 * u, 0]])

def R_rot(psi):
    """Rotation about the third axis by angle psi."""
    return np.array([[np.cos(psi), -np.sin(psi), 0],
                     [np.sin(psi), np.cos(psi), 0],
                     [0, 0, 1]])

tau = np.array(TAU_CONST)
for i in range(n_steps - 1):
    nu = x[3:, i]
    u, v, r = nu
    # Diagonal matrices times a vector broadcast per column, so D stays diagonal.
    D = D_l_np + D_q_np * np.abs(nu) + D_c_np * nu**2
    nu_dot = M_inv_np @ (tau - D @ nu - C_matrix(u, v, r) @ nu)
    # Both derivatives are evaluated at step i before either half is advanced.
    eta_dot = R_rot(x[2, i]) @ nu
    x[3:, i+1] = nu + dt * nu_dot
    x[:3, i+1] = x[:3, i] + dt * eta_dot

# ---------- Plot Comparison ----------
time_np = test_time.cpu().numpy()
labels = ["x", "y", "psi", "u", "v", "r"]
for i in range(6):
    fig, ax = plt.subplots()
    ax.plot(time_np, x[i, :], label=f'Ground Truth {labels[i]}')
    ax.plot(time_np, pred_traj[:, i], '--', label=f'DeepONet {labels[i]}')
    ax.set_xlabel("Time (s)")
    ax.set_ylabel(labels[i])
    ax.set_title(f"Constant Input Test: {labels[i]}")
    ax.grid(True)
    ax.legend()
    plt.show()


In [None]:
# x-y path of the constant-input test: numerical reference vs. DeepONet.
# Column/row 0 is x and column/row 1 is y (see `labels`).
fig, ax = plt.subplots()
ax.plot(x[0, :], x[1, :], label='Ground Truth path')
ax.plot(pred_traj[:, 0], pred_traj[:, 1], '--', label='DeepONet path')
ax.set_xlabel(labels[0])
ax.set_ylabel(labels[1])
ax.set_title("Constant Input Test: x-y path")
# Equal scaling so the path's geometry is not distorted.
ax.set_aspect("equal", adjustable="datalim")
ax.grid(True)
ax.legend()
plt.show()