In [2]:
import torch
import torch.nn as nn
import numpy as np
dtype = torch.float64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from damped_pendulum import DampedPendulum
from gpode           import GPODE
from helpers         import fit_model, plot_fit, integrate

In [3]:
# data generation
n   = 2   # observation dimensionality
N   = 10  # number of training sequences
T   = 25  # the length of each sequence
dt  = 0.1 # time difference between each observation pair
sig = 0.1 # standard deviation of the observation noise

In [4]:
unknown_parametric_ode = DampedPendulum(alpha=0.1, omega_0=1.0)

In [9]:
with torch.no_grad():
    ts = torch.arange(T, dtype=dtype, device=device) * dt
    x0 = torch.randn([N,n], dtype=dtype, device=device)
    X  = integrate(unknown_parametric_ode.odef, x0, ts)  # clean data sequences, [T,N,n]
    Y  = X + torch.randn_like(X)*sig
del unknown_parametric_ode

In [10]:
ode_model = GPODE(n, n).to(device).to(dtype)

In [11]:
plot_fit(ode_model, Y[:,:3], ts, fname='before-fit.png') 

In [None]:
fit_model(ode_model, Y, ts, Niter=2000, lr=5e-3)

In [13]:
Ntrain = Y.shape[1]

In [15]:
idx  = torch.randperm(Ntrain)[:5] 

In [17]:
Y_minibatch = Y[:,idx]

In [19]:
Y_minibatch.shape

torch.Size([25, 5, 2])

In [18]:
Y_minibatch[0].shape

torch.Size([5, 2])

In [22]:
gp_draw = ode_model.sgp.draw_posterior_function()

In [24]:
odef = lambda t,x: gp_draw(x)

In [26]:
gp_draw(Y_minibatch[0])

tensor([[ 0.7894,  0.8204],
        [ 0.4324,  0.7965],
        [-0.3719, -0.3529],
        [ 0.4023,  0.2751],
        [ 0.0231,  0.7735]], device='cuda:0', dtype=torch.float64,
       grad_fn=<SqueezeBackward1>)

In [28]:
omega,tau,w,nu = ode_model.sgp.cache(100,1)

In [31]:
prior = ode_model.sgp.rff(Y_minibatch[0], omega, tau, w) 

In [34]:
ode_model.sgp.Z.shape

torch.Size([50, 2])

In [35]:
Kxz = ode_model.sgp.covar_module(Y_minibatch[0],ode_model.sgp.Z).evaluate()

In [39]:
Kxz = Kxz.unsqueeze(0).repeat([1,1,1,1]).permute(0,3,2,1)

In [40]:
Kxz.shape

torch.Size([1, 50, 5, 2])

In [20]:
Yhat = ode_model.forward_trajectory(Y_minibatch[0], ts)

In [21]:
Yhat.shape

torch.Size([25, 5, 2])

In [43]:
fit_model(ode_model, Y, ts, Niter=2000, lr=5e-3)

998.5423088948949
845.0763302528649


KeyboardInterrupt: 