In [10]:
import sys ; sys.path.append('../')
from torchdyn.models import *; from torchdyn import *
import torch; import torch.nn as nn; from torch.utils.data import DataLoader; from torchvision import datasets, transforms

import pytorch_lightning as pl; from pytorch_lightning.loggers import WandbLogger; from pytorch_lightning.metrics.functional import accuracy

##### Loading data

In [11]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

batch_size = 128
size = 28
path_to_data = '../data/mnist_data'

# Every image is resized to `size` and then converted to a tensor.
all_transforms = transforms.Compose([transforms.Resize(size), transforms.ToTensor()])

# Only the training split asks for a download; the test split is read from
# the same root directory.
train_data = datasets.MNIST(path_to_data, train=True, download=True, transform=all_transforms)
test_data = datasets.MNIST(path_to_data, train=False, transform=all_transforms)

# The two loaders share every setting except shuffling: training batches are
# shuffled, test batches keep the dataset order.
trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
testloader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

##### Learner and model definition

In [9]:
import pytorch_lightning as pl

class Learner(pl.LightningModule):
    """Lightning module that trains `model` by minimising a negative log-likelihood.

    Besides the wrapped model, this class reads three notebook-level globals
    at call time: `prior` and `nde` (in `training_step`) and `trainloader`
    (in `train_dataloader`). They must be defined before training starts.
    """

    def __init__(self, model:nn.Module):
        super().__init__()
        self.iters = 0  # number of training steps taken so far
        self.model = model

    def forward(self, x):
        """Return the wrapped model's output for `x`."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch.

        Only the first element of `batch` is used. Returns a dict holding
        the scalar loss under the key 'loss'.
        """
        self.iters += 1
        inputs = batch[0]
        out = self.model(inputs)

        # Column 0 carries the integrated Jacobian trace; the remaining
        # columns are the transformed sample that is scored under `prior`.
        trace_term, z = out[:, 0], out[:, 1:]

        # logp(z_S) = logp(z_0) - \int_0^S trJ
        log_likelihood = prior.log_prob(z).to(inputs) - trace_term
        nll = -log_likelihood.mean()

        # Reset `nfe` on the global `nde` after every step
        # (presumably the solver's function-evaluation counter -- TODO confirm).
        nde.nfe = 0
        return {'loss': nll}

    def configure_optimizers(self):
        """Build the AdamW optimizer over the wrapped model's parameters."""
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-3, weight_decay=1e-5)
        return optimizer

    def train_dataloader(self):
        """Return the notebook-level `trainloader`."""
        return trainloader

In [7]:
from torch.distributions import MultivariateNormal, Uniform, TransformedDistribution, SigmoidTransform, Categorical

# NOTE(review): every dimension in this cell is 2 (network input/output, prior,
# noise distribution), while the loaders defined earlier serve MNIST images
# resized to 28 pixels -- confirm which data this flow is meant to be trained on.

# Network handed to the CNF: a 2 -> 64 -> 64 -> 64 -> 2 MLP with Softplus activations.
f = nn.Sequential(
    nn.Linear(2, 64), nn.Softplus(),
    nn.Linear(64, 64), nn.Softplus(),
    nn.Linear(64, 64), nn.Softplus(),
    nn.Linear(64, 2),
)

# 2-D Gaussian with zero mean and identity covariance; the Learner evaluates
# log-probabilities of the transformed samples under it.
prior = MultivariateNormal(
    loc=torch.zeros(2, device=device),
    covariance_matrix=torch.eye(2, device=device),
)

# stochastic estimators require a distribution from which the "noise" vectors are sampled
noise_dist = MultivariateNormal(
    loc=torch.zeros(2, device=device),
    covariance_matrix=torch.eye(2, device=device),
)

# cnf wraps the net, as is done with other energy models
cnf = CNF(f, trace_estimator=hutch_trace, noise_dist=noise_dist)

# Neural DE around the CNF: dopri5 solver with adjoint sensitivity, integrated
# over the two-point span [0, 1] with absolute/relative tolerances of 1e-4.
nde = NeuralDE(
    cnf,
    solver='dopri5',
    s_span=torch.linspace(0, 1, 2),
    sensitivity='adjoint',
    atol=1e-4,
    rtol=1e-4,
)

In [8]:
# Full pipeline: the Augmenter runs first and its output is fed to the neural DE.
# (Presumably the Augmenter adds the one extra column that the Learner later
# reads back as the trace term -- TODO confirm against the torchdyn docs.)
model = nn.Sequential(Augmenter(augment_idx=1, augment_dims=1), nde)
model = model.to(device)

##### Train

In [None]:
# `model` and `prior` were created on `device` (cuda:0 when CUDA is available),
# but a Trainer built without `gpus` runs on the CPU and the DataLoader batches
# stay on the CPU, so tensors would end up on different devices. Request one
# GPU exactly when `device` points at CUDA so model, prior and batches agree.
learn = Learner(model)
trainer = pl.Trainer(max_epochs=600, gpus=1 if torch.cuda.is_available() else None)
trainer.fit(learn);

##### Visualize