In [27]:
from torchdyn.core import NeuralODE
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import torch
import torch.utils.data as data
import torch.nn as nn
import pytorch_lightning as pl

In [28]:
device = torch.device("cuda:0")
torch.set_default_dtype(torch.float64)
torch.set_float32_matmul_precision("high")

In [29]:
def plot_l63(data, n, style="scatter"):
    x, y, z = data[:n, :].T
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection="3d")
    if style == "scatter":
        ax.scatter(x, y, z, s=1)
    elif style == "line":
        ax.plot(x, y, z, lw=0.3)
    else:
        raise ValueError
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    ax.set_title(f"L63, {n} points")
    plt.show()

In [30]:
def get_loader(
    train_file: str,
    test_file: str,
    plot: bool = False,
    name: str = "",
):
    train = np.load(train_file)
    test = np.load(test_file)
    print(f"raw data shapes -- train: {train.shape}, test: {test.shape}")
    X = torch.Tensor(train[:-1, :])
    Y = torch.Tensor(train[1:, :])
    print(f"train shapes -- x: {X.shape}, y: {Y.shape}")
    if plot and name == "l63":
        plot_l63(train, n=1000)
        plot_l63(train, n=-1, style="line")
    train = data.TensorDataset(X, Y)
    trainloader = data.DataLoader(
        train, batch_size=len(X), shuffle=True, num_workers=16
    )
    return trainloader

In [31]:
def get_loader_l63():
    return get_loader(
        train_file="lorenz63_on0.05_train.npy",
        test_file="lorenz63_test.npy",
    )

In [32]:
def get_loader_l96():
    return get_loader(
        train_file="lorenz96_on0.05_train.npy",
        test_file="lorenz96_test.npy",
    )

In [51]:
class Learner_l63(pl.LightningModule):
    def __init__(self, t_span: torch.Tensor, model: nn.Module):
        super().__init__()
        self.model, self.t_span = model, t_span
        self.trainloader = get_loader_l63()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        t_eval, y_hat = self.model(x, self.t_span)
        y_hat = y_hat[-1]  # select last point of solution trajectory
        loss = nn.MSELoss()(y_hat, y)
        print(loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=3e-4)

    def train_dataloader(self):
        return self.trainloader

In [52]:
def get_model_l63():
    layers = [
        nn.Linear(3, 64),
        nn.ReLU(),
        nn.Linear(64, 64),
        nn.ReLU(),
        nn.Linear(64, 3),
    ]
    f = nn.Sequential(*layers)
    model = NeuralODE(f)
    t_span = torch.linspace(0, 1, 2)  # [0,1]
    return t_span, model

In [53]:
learn = Learner_l63(*get_model_l63())
trainer = pl.Trainer(max_epochs=5, accelerator="gpu", devices="auto")
trainer.fit(learn)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type      | Params | Mode 
--------------------------------------------
0 | model | NeuralODE | 4.6 K  | train
--------------------------------------------
4.6 K     Trainable params
0         Non-trainable params
4.6 K     Total params
0.018     Total estimated model params size (MB)
10        Modules in train mode
0         Modules in eval mode


Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.
raw data shapes -- train: (100000, 3), test: (100000, 3)
train shapes -- x: torch.Size([99999, 3]), y: torch.Size([99999, 3])


Training: |                                                                                                   …

tensor(0.0154, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0146, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0139, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0133, device='cuda:0', grad_fn=<MseLossBackward0>)


`Trainer.fit` stopped: `max_epochs=5` reached.


tensor(0.0127, device='cuda:0', grad_fn=<MseLossBackward0>)


In [57]:
learn(torch.Tensor([1,1,1]))[1][-1,:]

tensor([0.9619, 0.9879, 0.9685], grad_fn=<SliceBackward0>)

In [23]:
def iterate(model)

Neural ODE:
	- order: 1        
	- solver: Tsitouras45()
	- adjoint solver: Tsitouras45()        
	- tolerances: relative 0.001 absolute 0.001        
	- adjoint tolerances: relative 0.0001 absolute 0.0001        
	- num_parameters: 4611        
	- NFE: 70.0

In [47]:
learn.device

device(type='cpu')