In [1]:
%load_ext autoreload
%autoreload 2
    
import numpy as np

from scipy.stats import multivariate_normal

import torch
import torch.nn as nn
import torch.optim as optim

# from torchdiffeq import odeint_adjoint as odeint
from torchdiffeq import odeint

In [2]:
def get_batch(batch_size):
    np.random.seed(42)
    p = multivariate_normal(mean=[0., 0.], cov=[[1., 0.], [0., 1.]])
    z_0 = p.rvs(batch_size)
    logpz_0 = p.logpdf(z_0).reshape(-1, 1)

    return (torch.tensor(z_0, dtype=torch.float32), torch.tensor(logpz_0, dtype=torch.float32))

In [3]:
def trace_df_dz(f, z):
    sum_diag = 0.
    for i in range(z.shape[1]):
        sum_diag += torch.autograd.grad(f[:, i].sum(), z, create_graph=True)[0].contiguous()[:, i].contiguous()
        
    return sum_diag.contiguous()

class ContinuousPlanarFlow(nn.Module):

    def __init__(self):
        super(ContinuousPlanarFlow, self).__init__()

        self.diffeq = nn.Linear(2, 2)

    def forward(self, t, states):
        z = states[:,:2]

        batchsize = z.shape[0]

        with torch.set_grad_enabled(True):
            z.requires_grad_(True)

            dz_dt = self.diffeq(z)
            dlogpz_dt = -trace_df_dz(dz_dt, z).view(batchsize, 1)


            res = torch.cat([dz_dt, dlogpz_dt], 1)
            
        return res

In [4]:
# hyper_params
batch_size = 1000

In [5]:
odefunc = ContinuousPlanarFlow()

optimizer = optim.Adam(odefunc.parameters(), lr=1e-3, weight_decay=0.)

for itr in range(1, 20 + 1):
    optimizer.zero_grad()
    z_0, logpz_0 = get_batch(batch_size)
    res = odeint(
            odefunc,
#             (z_0, logpz_0),
            torch.cat([z_0, logpz_0], 1),
            torch.tensor([0., 1.]).to(z_0),
            atol=1e-5,
            rtol=1e-5,
            method='dopri5',
        )

#     dz_dt = dz_dt[1]
#     logpz_t = logpz_t[1]
    dz_dt = res[:,:2]
    logpz_t = res[:,2]
        
    loss = -logpz_t.mean()
    print(loss)
    loss.backward()
    optimizer.step()

tensor(0.6565, grad_fn=<NegBackward>)
tensor(0.6558, grad_fn=<NegBackward>)
tensor(0.6552, grad_fn=<NegBackward>)
tensor(0.6546, grad_fn=<NegBackward>)
tensor(0.6539, grad_fn=<NegBackward>)
tensor(0.6533, grad_fn=<NegBackward>)
tensor(0.6527, grad_fn=<NegBackward>)
tensor(0.6520, grad_fn=<NegBackward>)
tensor(0.6514, grad_fn=<NegBackward>)
tensor(0.6508, grad_fn=<NegBackward>)
tensor(0.6502, grad_fn=<NegBackward>)
tensor(0.6495, grad_fn=<NegBackward>)
tensor(0.6489, grad_fn=<NegBackward>)
tensor(0.6483, grad_fn=<NegBackward>)
tensor(0.6476, grad_fn=<NegBackward>)
tensor(0.6470, grad_fn=<NegBackward>)
tensor(0.6464, grad_fn=<NegBackward>)
tensor(0.6458, grad_fn=<NegBackward>)
tensor(0.6451, grad_fn=<NegBackward>)
tensor(0.6445, grad_fn=<NegBackward>)
