In [115]:
import torchdyn
import torch
from torch.autograd import grad
import torch.nn as nn
import matplotlib.pyplot as plt
from torchdyn.core import ODEProblem

import torchdiffeq
import time 
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [116]:
# Seed the global RNG so the random layer initialisation below (and the
# `torch.randn` initial states drawn in later cells) repeat on every
# Restart & Run All.
torch.manual_seed(0)

# Vector field: a small MLP mapping a 1-d state to a 1-d derivative. It takes
# the state only; torchdyn wraps it to add the time argument (see the message
# printed by this cell).
f = nn.Sequential(
    nn.Linear(1, 32),
    nn.SELU(),
    nn.Linear(32, 32),
    nn.SELU(),
    nn.Linear(32, 1),
)

# torchdyn problem: `dopri5` solver with adjoint sensitivity.
prob = ODEProblem(f, solver='dopri5', sensitivity='adjoint', atol=1e-4, rtol=1e-4)

Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.


### Learning `T` from a target (`T = 5`)

In [149]:
# torchdyn: optimise the final integration time `T`; the loss is written
# directly on the last entry of `t_span`.
x = torch.randn(1, 1, requires_grad=True)
t0 = torch.zeros(1)
T = torch.ones(1).requires_grad_(True)
opt = torch.optim.Adam([T], lr=1e-2)

for step in range(2000):
    # Rebuild the span every step so it carries the current value of `T`.
    t_span = torch.cat([t0, T])
    t_eval, traj = prob(x, t_span)

    # Squared distance between the final time and the target value 5.
    residual = t_span[-1:] - torch.tensor([5])
    loss = (residual**2).mean()
    print(f'{loss}, {t_span}', end='\r')

    loss.backward()
    opt.step()
    opt.zero_grad()

2.3283064365386963e-10, tensor([0.0000, 5.0000], grad_fn=<CatBackward>)

In [148]:
# torchdiffeq
# torchdiffeq needs a callable taking `(t, x)`, so the state-only `f` is wrapped
class VectorField(nn.Module):
    """Adapter that gives a state-only callable the `(t, x)` call signature."""

    def __init__(self, f):
        """Store the wrapped callable `f` on the module."""
        super().__init__()
        self.f = f

    def forward(self, t, x):
        """Return `f(x)`; the time argument `t` is accepted but not used."""
        dxdt = self.f(x)
        return dxdt
    
# NOTE: `sys` here is the wrapped vector field, not the stdlib `sys` module.
sys = VectorField(f)
x = torch.randn(1, 1, requires_grad=True)
t0 = torch.zeros(1)
T = torch.ones(1).requires_grad_(True)
opt = torch.optim.Adam([T], lr=1e-2)

for step in range(2000):
    # Rebuild the span every step so it carries the current value of `T`.
    t_span = torch.cat([t0, T])
    traj = torchdiffeq.odeint_adjoint(
        sys, x, t_span, method='dopri5', atol=1e-4, rtol=1e-4
    )

    # Squared distance between the final time and the target value 5.
    residual = t_span[-1:] - torch.tensor([5])
    loss = (residual**2).mean()
    print(f'{loss}, {t_span}', end='\r')

    loss.backward()
    opt.step()
    opt.zero_grad()

2.3283064365386963e-10, tensor([0.0000, 5.0000], grad_fn=<CatBackward>)

#### Explicit loss on `T`, gradcheck

In [87]:
# Fresh end points for the check. When the notebook is run top to bottom, the
# training cells above leave `T` at ~5, exactly where this loss has zero
# gradient, so reusing it would only compare two vanishing gradients.
# At T = 1 the expected gradient is 2 * (1 - 5) = -8.
t0 = torch.zeros(1)
T = torch.ones(1).requires_grad_(True)

# torchdyn: the loss below reads only `t_span[-1]`, not the solver output.
t_span = torch.cat([t0, T])
t_eval, traj = prob(x, t_span)
loss = ((t_span[-1:] - torch.tensor([5.]))**2).mean()
dldt_torchdyn = grad(loss, T)[0]

# torchdiffeq: same loss on a freshly rebuilt `t_span`.
t_span = torch.cat([t0, T])
traj = torchdiffeq.odeint_adjoint(sys, x, t_span, method='dopri5', atol=1e-4, rtol=1e-4)
loss = ((t_span[-1:] - torch.tensor([5.]))**2).mean()
dldt_torchdiffeq = grad(loss, T)[0]

# Fail loudly on a mismatch instead of relying on eyeballing the difference.
assert torch.allclose(dldt_torchdyn, dldt_torchdiffeq), (dldt_torchdyn, dldt_torchdiffeq)

dldt_torchdyn - dldt_torchdiffeq

tensor([0.])

#### Explicit loss on `t0`, gradcheck

In [121]:
# Both end points are leaves here so the gradient w.r.t. the *initial* time can
# be taken. At t0 = 0 the expected gradient is 2 * (0 - 5) = -10.
t0 = torch.zeros(1).requires_grad_(True)
T = torch.ones(1).requires_grad_(True)

# torchdyn: the loss below reads only `t_span[0]`, not the solver output.
t_span = torch.cat([t0, T])
t_eval, traj = prob(x, t_span)
loss = ((t_span[:1] - torch.tensor([5.]))**2).mean()
dldt_torchdyn = grad(loss, t0)[0]

# torchdiffeq: same loss on a freshly rebuilt `t_span`.
t_span = torch.cat([t0, T])
traj = torchdiffeq.odeint_adjoint(sys, x, t_span, method='dopri5', atol=1e-4, rtol=1e-4)
loss = ((t_span[:1] - torch.tensor([5.]))**2).mean()
dldt_torchdiffeq = grad(loss, t0)[0]

# Fail loudly on a mismatch instead of relying on eyeballing the difference.
assert torch.allclose(dldt_torchdyn, dldt_torchdiffeq), (dldt_torchdyn, dldt_torchdiffeq)

dldt_torchdyn - dldt_torchdiffeq

tensor([0.])

#### Learning `xT` by stretching `T` (fixed vector field)

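Note: the vector field is always positive (its last layer is a `Softplus`), so the state only increases over time and the target `xT = 2`, which lies above the initial state `x0 = 0.5`, can be hit by stretching `T`.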

In [129]:
# Vector field with a final Softplus, so its output is always positive.
f = nn.Sequential(
    nn.Linear(1, 32),
    nn.SELU(),
    nn.Linear(32, 1),
    nn.Softplus(),
)

# Rebuild the torchdyn problem around the new vector field
# (same solver, sensitivity and tolerances as before).
prob = ODEProblem(
    f,
    solver='dopri5',
    sensitivity='adjoint',
    atol=1e-4,
    rtol=1e-4,
)

Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.


In [142]:
# torchdyn: reach the terminal state 2 by optimising `T` only
# (the optimiser is given `T` and nothing else).

x = torch.zeros(1, 1, requires_grad=True) + 0.5
t0 = torch.zeros(1)
T = torch.ones(1).requires_grad_(True)
opt = torch.optim.Adam([T], lr=1e-2)

for step in range(1000):
    # Rebuild the span every step so it carries the current value of `T`.
    t_span = torch.cat([t0, T])
    t_eval, traj = prob(x, t_span)

    # Squared distance between the last trajectory point and the target 2.
    residual = traj[-1] - torch.tensor([2])
    loss = (residual**2).mean()
    print(f'L: {loss.item():.2f}, T: {t_span[-1].item():.2f}, xT: {traj[-1].item():.2f}', end='\r')

    loss.backward()
    opt.step()
    opt.zero_grad()

L: 0.00, T: 3.30, xT: 2.00

In [145]:
# Re-declaration of the `(t, x)` adapter already defined earlier in this
# notebook for the first torchdiffeq run; the two definitions are identical.
class VectorField(nn.Module):
    """Adapter that gives a state-only callable the `(t, x)` call signature."""

    def __init__(self, f):
        """Store the wrapped callable `f` on the module."""
        super().__init__()
        self.f = f

    def forward(self, t, x):
        """Return `f(x)`; the time argument `t` is accepted but not used."""
        dxdt = self.f(x)
        return dxdt

# torchdiffeq: same experiment as the torchdyn cell, on the wrapped field.
# NOTE: `sys` here is the wrapped vector field, not the stdlib `sys` module.
sys = VectorField(f)
x = torch.zeros(1, 1, requires_grad=True) + 0.5
t0 = torch.zeros(1)
T = torch.ones(1).requires_grad_(True)
opt = torch.optim.Adam([T], lr=1e-2)

for step in range(1000):
    # Rebuild the span every step so it carries the current value of `T`.
    t_span = torch.cat([t0, T])
    traj = torchdiffeq.odeint_adjoint(
        sys, x, t_span, method='dopri5', atol=1e-4, rtol=1e-4
    )

    # Squared distance between the last trajectory point and the target 2.
    residual = traj[-1] - torch.tensor([2])
    loss = (residual**2).mean()
    print(f'L: {loss.item():.2f}, T: {t_span[-1].item():.2f}, xT: {traj[-1].item():.2f}', end='\r')

    loss.backward()
    opt.step()
    opt.zero_grad()

L: 0.00, T: 3.89, xT: 2.00

In [146]:
# Gradient of a terminal-state loss w.r.t. `T`, computed through each solver.
# `t0` and `T` are not defined here: they are whatever the previously executed
# cell left in the kernel.
# NOTE(review): the recorded output of this cell shows the two gradients
# disagree (-2.7140 vs -2.1463) — TODO investigate before trusting either.
x = torch.zeros(1, 1, requires_grad=True) + 0.5

# torchdyn
t_span = torch.cat([t0, T])
t_eval, traj = prob(x, t_span)
loss_torchdyn = ((traj[-1] - torch.tensor([5]))**2).mean()
dldt_torchdyn = grad(loss_torchdyn, T)[0]

# torchdiffeq
t_span = torch.cat([t0, T])
traj = torchdiffeq.odeint_adjoint(
    sys, x, t_span, method='dopri5', atol=1e-4, rtol=1e-4
)
loss_torchdiffeq = ((traj[-1] - torch.tensor([5]))**2).mean()
dldt_torchdiffeq = grad(loss_torchdiffeq, T)[0]

dldt_torchdyn, dldt_torchdiffeq

(tensor([-2.7140]), tensor([-2.1463]))