In [1]:
import torchode as to
import torch
import torch.nn as nn

In [2]:
class Model(nn.Module):
    def __init__(self, n_features, n_hidden):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(n_features, n_hidden),
            nn.Softplus(),
            nn.Linear(n_hidden, n_hidden),
            nn.Softplus(),
            nn.Linear(n_hidden, n_features)
        )
    
    def forward(self, t, y):
        return self.layers(y) + torch.sin(t).unsqueeze(1) * 0.1 * y
    
n_features = 5
model = Model(n_features, 16)

In [3]:
batch_size = 16
n_steps = 2
y0 = torch.randn((batch_size, n_features))
t_eval = torch.linspace(0.0, 1.0, n_steps)

In [4]:
def score_fn_single(t, x):
    x = x.unsqueeze(0)  # Add batch dimension
    t = t.unsqueeze(0)  # Add batch dimension
    score = model(t, x)[0]  # Remove batch dimension
    return score, score

jac = torch.func.jacrev(score_fn_single, argnums=1, has_aux=True)
def divergence_and_score_fn(t, x):
    jacobian, score = jac(t, x)
    divergence = jacobian.trace()
    return divergence, score
divergence_and_score_fn_batched = torch.func.vmap(divergence_and_score_fn, in_dims=(0, 0), out_dims=(0, 0))

def batched_ode_fn(t, y):
    x, ll = y[..., :-1], y[..., -1]
    dll_dt, dx_dt = divergence_and_score_fn_batched(t, x) 
    return torch.cat((dx_dt, dll_dt.unsqueeze(-1)), dim=-1)

In [5]:
batch = 16
x = torch.randn((batch, n_features))
ll = torch.zeros((batch, 1))
y = torch.cat((x, ll), dim=-1)

In [6]:
term = to.ODETerm(batched_ode_fn)
step_method = to.Dopri5(term=term)
step_size_controller = to.IntegralController(atol=1e-9, rtol=1e-7, term=term)
adjoint = to.AutoDiffAdjoint(step_method, step_size_controller)


In [7]:
problem = to.InitialValueProblem(y0=torch.cat((y0, torch.zeros(y0.shape[:-1] +  (1,))), dim=-1), t_eval=t_eval.repeat((batch_size, 1)))
sol = adjoint.solve(problem)

In [8]:
sol.ys.shape

torch.Size([16, 2, 6])

In [9]:
sol.stats

{'n_f_evals': tensor([44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44]),
 'n_steps': tensor([7, 6, 7, 6, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 7, 7]),
 'n_accepted': tensor([7, 6, 6, 6, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 7, 7]),
 'n_initialized': tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])}