In [69]:
import torchode as to
import torch
import torch.nn as nn

In [70]:
class Model(nn.Module):
    """Small MLP vector field plus an explicit time-modulated linear term.

    Computes ``f(t, y) = MLP(y) + sin(t) * time_scale * y``, where the MLP is
    Linear -> Softplus -> Linear -> Softplus -> Linear.

    Args:
        n_features: width of the state ``y`` (input and output size of the MLP).
        n_hidden: width of the two hidden Softplus layers.
        time_scale: coefficient of the ``sin(t) * y`` term; defaults to 0.1,
            the value that used to be hard-coded in ``forward``.
    """

    def __init__(self, n_features, n_hidden, time_scale=0.1):
        super().__init__()
        self.time_scale = time_scale
        self.layers = nn.Sequential(
            nn.Linear(n_features, n_hidden),
            nn.Softplus(),
            nn.Linear(n_hidden, n_hidden),
            nn.Softplus(),
            nn.Linear(n_hidden, n_features),
        )

    def forward(self, t, y):
        """Evaluate the vector field at time ``t`` and state ``y``.

        Args:
            t: either a 0-d tensor (one time shared by every row of ``y``) or a
                tensor whose leading dim indexes samples, e.g. shape ``(batch,)``.
            y: state tensor; assumed ``(batch, n_features)`` - TODO confirm for
                other layouts.

        Returns:
            ``self.layers(y) + sin(t) * time_scale * y``.
        """
        modulation = torch.sin(t)
        if modulation.dim() > 0:
            # Per-sample times: add an axis so (batch,) broadcasts over features.
            # A 0-d time needs no reshaping (unsqueeze(1) would be out of range).
            modulation = modulation.unsqueeze(1)
        return self.layers(y) + modulation * self.time_scale * y
    
# Seed torch's global RNG before anything random happens (model weight
# initialisation here, torch.randn initial states later) so that
# Restart & Run All reproduces the same numbers.
SEED = 0
torch.manual_seed(SEED)

n_features = 5  # dimensionality of the ODE state
n_hidden = 16   # width of the MLP hidden layers
model = Model(n_features, n_hidden)

In [71]:
# A batch of random initial states and the times at which the solution is reported.
batch_size = 16
y0 = torch.randn(batch_size, n_features)
n_steps = 2  # linspace(0, 1, 2) -> only t = 0 and t = 1
t_eval = torch.linspace(0.0, 1.0, steps=n_steps)

In [90]:
def score_fn_single(t, x):
    """Evaluate the global ``model`` on one unbatched sample.

    ``model`` is called with batched inputs, so a leading axis of size 1 is
    added to both ``t`` and ``x`` and then stripped from the result. The output
    is returned twice: with ``jacrev(..., has_aux=True)`` the first copy is the
    one differentiated and the second is passed through as auxiliary output.
    """
    batched_out = model(t[None], x[None])
    single_out = batched_out[0]
    return single_out, single_out

# Jacobian of the score w.r.t. x (argument 1); has_aux also returns the score itself.
jac = torch.func.jacrev(score_fn_single, argnums=1, has_aux=True)


def divergence_and_score_fn(t, x):
    """Return ``(trace(d score / d x), score)`` for a single unbatched sample.

    The full Jacobian is materialised by ``jac`` and then reduced to its trace.
    """
    jacobian_matrix, score_value = jac(t, x)
    return jacobian_matrix.trace(), score_value


# Map over dim 0 of both inputs (t, x) and stack both outputs along dim 0.
divergence_and_score_fn_batched = torch.func.vmap(divergence_and_score_fn, in_dims=0, out_dims=0)

def batched_ode_fn(t, y):
    """Right-hand side of the augmented ODE passed to torchode.

    The last feature of ``y`` is an accumulator column; the remaining features
    are the model state ``x``. The state evolves with the model's output, and
    the accumulator's derivative is the divergence (Jacobian trace) at ``x``.

    Args:
        t: per-sample times; vmapped over dim 0 together with ``x``.
        y: augmented state, presumably ``(batch, n_features + 1)`` - TODO confirm.

    Returns:
        Tensor laid out like ``y``: ``[dx/dt, d(accumulator)/dt]`` along the last dim.
    """
    # The accumulator column y[..., -1] never influences any derivative, so
    # only the state part is read.
    x = y[..., :-1]
    divergence, dx_dt = divergence_and_score_fn_batched(t, x)
    # NOTE(review): for a continuous normalizing flow d(log p)/dt = -divergence;
    # this integrates +divergence, so the consumer must apply the sign - confirm.
    return torch.cat((dx_dt, divergence.unsqueeze(-1)), dim=-1)

In [91]:
# Solver setup: the augmented dynamics wrapped as a torchode term, adaptive
# Dopri5 steps, an integral step-size controller, and torchode's AutoDiffAdjoint.
term = to.ODETerm(batched_ode_fn)
# NOTE(review): rtol=1e-7 / atol=1e-9 are at or below float32 resolution
# (eps ~= 1.2e-7); the state is presumably float32 (torch default) - confirm
# these tolerances are intended.
step_size_controller = to.IntegralController(atol=1e-9, rtol=1e-7, term=term)
step_method = to.Dopri5(term=term)
adjoint = to.AutoDiffAdjoint(step_method, step_size_controller)


In [92]:
torch.zeros(size=y0.size()[:-1])  # shape of a per-sample scalar accumulator for y0

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [93]:
y0.size()  # same torch.Size as y0.shape

torch.Size([16, 5])

In [94]:
# Append a zero-initialised accumulator column to every initial state; it
# integrates the divergence that batched_ode_fn writes to the last feature.
# new_zeros keeps the column on y0's dtype/device so the concat also works off-CPU.
y0_augmented = torch.cat((y0, y0.new_zeros((*y0.shape[:-1], 1))), dim=-1)
# One row of evaluation times per sample. The row count is taken from y0
# itself so it cannot drift out of sync with the batch being solved.
problem = to.InitialValueProblem(
    y0=y0_augmented,
    t_eval=t_eval.to(y0.device).repeat((y0.shape[0], 1)),
)
sol = adjoint.solve(problem)

In [95]:
sol.ys[:, -1, :]  # every sample at the last evaluation time; last column is the accumulator

tensor([[ 0.7569, -0.1904, -0.7325, -1.0537, -1.1965,  0.2043],
        [ 0.7500, -1.2704,  0.5399,  0.2933,  0.3101,  0.1927],
        [-0.0919, -0.1947, -0.6431,  0.4204, -0.7464,  0.2044],
        [ 1.3058, -1.2387,  1.8141, -0.6707,  0.8087,  0.1848],
        [-1.3181,  1.6815,  1.7555,  0.0428, -0.3510,  0.1979],
        [-1.0591,  0.9934,  0.9641, -0.4529, -0.6001,  0.1984],
        [ 0.9482, -0.0202, -0.4817,  0.5545,  1.3114,  0.1924],
        [-1.5128,  2.1030, -0.1650, -1.5705,  0.2192,  0.1952],
        [ 2.0713,  0.8621, -0.1511, -0.5245, -1.6747,  0.2104],
        [ 0.0797,  0.4518,  0.2003, -0.1748, -0.8849,  0.2027],
        [ 1.6497,  2.2991,  0.0074, -0.0812,  0.4267,  0.2011],
        [ 1.3782,  0.7549, -0.5366, -1.5030, -0.0769,  0.1974],
        [-0.3231, -1.1820, -1.1619,  0.8262, -1.2616,  0.2088],
        [ 1.5096,  1.7350, -0.4104,  1.0505,  1.3044,  0.1996],
        [-0.2271,  0.7034, -3.1854, -1.9670, -0.4570,  0.2025],
        [ 0.2292, -0.5983, -1.1174,  0.0

In [97]:
# Per-sample solver statistics (function evaluations, steps taken / accepted / initialised).
sol.stats

{'n_f_evals': tensor([44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44]),
 'n_steps': tensor([6, 7, 7, 6, 6, 6, 6, 6, 6, 7, 6, 6, 6, 6, 6, 6]),
 'n_accepted': tensor([6, 7, 7, 6, 6, 6, 6, 6, 6, 7, 6, 6, 6, 6, 6, 6]),
 'n_initialized': tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])}