In [1]:
import os

from torchdyn.models import *; from torchdyn import *
from torchdyn.nn import DataControl, DepthCat, Augmenter, GalConv2d, Fourier

import torch
import torch.nn as nn

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
# from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.utilities.model_summary import ModelSummary

from utils import CIFARLearner, get_cifar_dloaders, MetricTracker


In [2]:
# Quick run for automated notebook validation: this flag is handed to the Trainer's
# `fast_dev_run`, which pushes a single batch through each loop instead of full epochs.
dry_run: bool = True

In [3]:
# Choose the compute target once: the first CUDA GPU when one is visible, otherwise the CPU.
# `accelerator` is the matching string for `pl.Trainer(accelerator=...)`.
if torch.cuda.is_available():
    device, accelerator = torch.device("cuda:0"), 'gpu'
else:
    device, accelerator = torch.device("cpu"), 'cpu'
print('GPU State:', device)

GPU State: cpu


In [4]:
# Training length (fed to the Trainer's `max_epochs`) and on-disk data location.
epoch = 1

path_to_data = './data/cifar10_data'

# Request a download only when the data directory is missing: an existing copy is reused
# exactly as before (download=False), while a fresh checkout no longer fails on
# Restart & Run All. NOTE(review): assumes get_cifar_dloaders fetches CIFAR-10 into
# `path` when download=True — confirm in utils.
# NOTE(review): CIFAR-10 images are 32x32; `size=28` presumably resizes/crops them — confirm.
trainloader, testloader = get_cifar_dloaders(batch_size=64,
                                             size=28,
                                             path=path_to_data,
                                             download=not os.path.isdir(path_to_data),
                                             num_workers=8)

In [6]:
# Vector field of the neural ODE: a small conv net on 42-channel feature maps.
# Channels (42 -> 42) and spatial size (3x3 convs with padding=1, then a 1x1 conv) are
# preserved, so its output has the same shape as its input, as an ODE derivative must.
# torchdyn wraps this module so it also accepts the time argument (see the cell output).
func = nn.Sequential(
    nn.GroupNorm(num_groups=42, num_channels=42),
    nn.Conv2d(42, 42, kernel_size=3, padding=1, bias=False),
    nn.Softplus(),
    nn.Conv2d(42, 42, kernel_size=3, padding=1, bias=False),
    nn.Softplus(),
    nn.GroupNorm(num_groups=42, num_channels=42),
    nn.Conv2d(42, 42, kernel_size=1),
)
func = func.to(device)

# Neural ODE layer: dopri5 solver, adjoint sensitivities, atol = rtol = 1e-4.
nde = NeuralODE(
    func,
    solver='dopri5',
    sensitivity='adjoint',
    atol=1e-4,
    rtol=1e-4,
)
nde = nde.to(device)

# NOTE(review): `s_span` is never passed to `nde` in this cell, so the integration
# interval is whatever NeuralODE uses by default — confirm that is intended.
s_span = torch.linspace(0, 1, 2)

# NOTE: the leading no-op `Augmenter` exists only to keep `nde` at index `2`;
# CIFARLearner reads it from that position to extract NFEs.
model = nn.Sequential(
    Augmenter(1, 0),
    nn.Conv2d(3, 42, kernel_size=3, padding=1, bias=False),  # lift 3 input channels to 42
    nde,
    nn.Conv2d(42, 6, kernel_size=1),                         # squeeze to 6 channels
    nn.AdaptiveAvgPool2d(4),                                 # -> 6 x 4 x 4
    nn.Flatten(),
    nn.Linear(6 * 4 * 4, 10),                                # 10 output classes
)
model = model.to(device)

Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.


In [11]:
# Wrap the model in the project's Lightning module and train it.
learn = CIFARLearner(model, trainloader, testloader)
cb = MetricTracker()  # collects training metrics; inspect via `cb.collection`
logger = TensorBoardLogger(save_dir='lightning_logs/CIFAR/', name='model1')
trainer = pl.Trainer(max_epochs=epoch,
                     accelerator=accelerator,
                     # `gpus=` is deprecated since Lightning 1.7 and removed in 2.0; `devices=`
                     # selects how many `accelerator` devices to use. This keeps the old
                     # behaviour: every visible GPU on 'gpu', a single process on 'cpu'.
                     devices=torch.cuda.device_count() if accelerator == 'gpu' else 1,
                     fast_dev_run=dry_run,
                     logger=logger,
                     callbacks=[cb])
trainer.fit(learn)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in fast_dev_run mode: will run a full train, val, test and prediction loop using 1 batch(es).
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_predict_batches=1)` was configured so 1 batch will be used.
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 36.1 K
-------------------------------------
36.1 K    Trainable params
0         Non-trainable params
36.1 K    Total params
0.144     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

In [12]:
# Metrics recorded by the MetricTracker callback during `trainer.fit`, displayed as the cell output.
cb.collection

[{'epoch': tensor(0.),
  'train_loss': tensor(2.3473),
  'accuracy': tensor(0.0625),
  'NFE': tensor(26.)}]