In [1]:
import torch

In [2]:
from torchdyn.utils import plot_3D_dataset
from torchdyn.datasets import ToyDataset
from torch.utils.data import TensorDataset, DataLoader

In [3]:
# Execution settings shared by the cells below.
dry_run = False  # NOTE(review): not read by any of the cells visible here
device = torch.device("cpu")  # every tensor/module below is moved to this device

In [4]:
# Build the noisy two-moons toy problem and wrap it in a full-batch DataLoader.
d = ToyDataset()
X, yn = d.generate(n_samples=512, dataset_type='moons', noise=.4)
# `torch.as_tensor` with an explicit dtype replaces the legacy `torch.Tensor(...)` /
# `torch.LongTensor(...)` constructors: given a tensor argument those either raise or
# return an alias that keeps the input's dtype (PyTorch-version dependent), whereas
# `as_tensor` accepts numpy arrays and tensors alike and pins float32 features /
# int64 class labels as the model and CrossEntropy loss expect.
X_train = torch.as_tensor(X, dtype=torch.float32).to(device)
y_train = torch.as_tensor(yn, dtype=torch.long).to(device)
train = TensorDataset(X_train, y_train)
# Full-batch training: a single batch holds the entire dataset.
trainloader = DataLoader(train, batch_size=len(X), shuffle=False)

In [5]:
import torch.nn as nn
import pytorch_lightning as pl

class Learner(pl.LightningModule):
    """LightningModule that trains ``model`` as a classifier on the notebook's ``trainloader``.

    Args:
        t_span: time points passed to ``model`` at every training step.
        model: the module being trained. ``training_step`` calls it as
            ``model(x, t_span)`` and uses the second element of the returned pair
            as the solution trajectory.
        lr: Adam learning rate (defaults to the original hard-coded 0.01).
    """

    def __init__(self, t_span: torch.Tensor, model: nn.Module, lr: float = 0.01):
        super().__init__()
        self.model, self.t_span = model, t_span
        self.lr = lr

    def forward(self, x):
        """Delegate to the wrapped model with only the input ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Integrate a batch over ``self.t_span`` and score the final state.

        Returns a dict whose ``'loss'`` is the cross-entropy between the last point
        of the solution trajectory (used as class logits) and the integer labels.
        """
        x, y = batch
        # Use the span this Learner was constructed with -- NOT the notebook-global
        # `t_span`, which other cells re-bind (e.g. to a 100-point grid on [0, 0.1]).
        _, trajectory = self.model(x, self.t_span)
        logits = trajectory[-1]  # last point of the solution trajectory
        loss = nn.functional.cross_entropy(logits, y)
        return {'loss': loss}

    def configure_optimizers(self):
        """Adam over the wrapped model's parameters."""
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def train_dataloader(self):
        # NOTE(review): reads the notebook-global `trainloader` built in an earlier cell.
        return trainloader

In [6]:
from torchdyn.core import NeuralSDE

In [7]:
import torchsde

# Sampling grid: 100 points on [0, 0.1]; the noise has the same shape as the features `X`.
t_span = torch.linspace(0, 0.1, 100).to(device)
size = X.shape

# A single Brownian sample path spanning the grid's first and last time points,
# using torchsde's space-time Levy-area approximation.
# NOTE(review): the model cell re-binds `t_span` to [0, 1] and never receives `bm`,
# so this path only covers [0, 0.1] -- confirm the intended horizon before wiring
# it into the SDE solve.
bm = torchsde.BrownianInterval(
    size=size,
    device=device,
    t0=t_span[0],
    t1=t_span[-1],
    levy_area_approximation='space-time',
)

In [11]:
# The two networks handed to NeuralSDE (assumed drift then diffusion, per torchdyn's
# positional order): each maps the 2-D state to a 2-D output.
f = nn.Sequential(nn.Linear(2, 64), nn.Tanh(), nn.Linear(64, 2))
g = nn.Sequential(nn.Linear(2, 64), nn.Tanh(), nn.Linear(64, 2))

# Integration span [0, 1] (two points). Placed on `device` like the data and the
# Brownian-motion grid, so the notebook keeps working if `device` leaves the CPU.
# NOTE: this re-binds the global `t_span` created in the Brownian-motion cell.
t_span = torch.linspace(0, 1, 2).to(device)

model = NeuralSDE(f,
                  g,
                  solver='euler',
                  noise_type='diagonal',
                  sde_type='ito',
                  sensitivity='autograd',
                  s_span=t_span,
                 ).to(device)

TypeError: __init__() got an unexpected keyword argument 'bm'