In [2]:
import torch
from torchdiffeq import odeint
import torch.nn as nn

# Define control_func (dummy example for testing)
control_func = lambda t: torch.tensor([0.1])

# Define f and g as required by StabNODE
class Felu(nn.Module):
    def __init__(self, dim_in, dim_out, hidden_dim=2):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(dim_in, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, dim_out),
        )
    def forward(self, x):
        return -torch.exp(self.network(x))

class Gelu(nn.Module):
    def __init__(self, dim_in, dim_out, hidden_dim=2):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(dim_in, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, dim_out),
            nn.Tanh(),
        )
    def forward(self, x, u):
        xu = torch.cat([x, u], dim=-1)
        return self.network(xu)

class StabNODE(nn.Module):
    def __init__(self, f, g):
        super().__init__()
        self.f = f
        self.g = g
    def forward(self, t, state, u_func):
        x = state
        u = u_func(t).unsqueeze(0)  # Ensure u is reshaped for batching consistency
        fx = self.f(x)
        gx = self.g(x, u)
        return fx * (x - gx)

# Instantiate f and g modules
f = Felu(dim_in=1, dim_out=1)
g = Gelu(dim_in=2, dim_out=1)

# Instantiate the StabNODE model
model = StabNODE(f=f, g=g)

# Sample data for testing
x0_batch = torch.tensor([0.5, 1.0, 1.5]).reshape(-1, 1)  # Batch of initial conditions (reshaped for consistency)
t_batch = torch.tensor([
    [0., 1., 2., 3.],
    [0., 0.5, 1., 1.5],
    [0., 0.75, 1.5, 2.25]
])  # Batch of t_spans (2D tensor).


In [4]:

# Define batched ODE integration
def batched_odeint(x0_batch, t_batch):
    batch_size = x0_batch.size(0)
    
    # Time spans need to be uniform for all batches in odeint
    t_span_uniform = torch.linspace(0., t_batch[:, -1].max(), steps=t_batch.size(1))
    print(t_span_uniform)
    
    def batch_func(t, states):
        u = control_func(t).repeat(states.size(0), 1)  # Repeat control function across batch
        fx = model.f(states)
        gx = model.g(states, u)
        return fx * (states - gx)

    # Batched integration
    y0 = x0_batch.reshape(-1, 1)
    trajectories = odeint(batch_func, y0, t_span_uniform, method='rk4') 
    
    return trajectories.squeeze()

# Solve batched ODEs
x_pred_all = batched_odeint(x0_batch, t_batch)

print("Batch of predicted trajectories:")
print(x_pred_all)

tensor([0., 1., 2., 3.])
Batch of predicted trajectories:
tensor([[0.5000, 1.0000, 1.5000],
        [0.3475, 0.5263, 0.7340],
        [0.2977, 0.3563, 0.4282],
        [0.2818, 0.3005, 0.3237]], grad_fn=<SqueezeBackward0>)
