In [1]:
from torchdyn.models import *; from torchdyn import *
from torchdyn.nn import DataControl, DepthCat, Augmenter, GalConv2d, Fourier

import torch
import torch.nn as nn

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
# from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.utilities.model_summary import ModelSummary

from utils import MNISTLearner, get_MNIST_dloaders, MetricTracker

In [2]:
# Quick-run switch for automated notebook validation.
# The flag is forwarded to `pl.Trainer(fast_dev_run=...)` in the training cell,
# so `True` makes Lightning run only a single batch per loop.
dry_run = True

In [3]:
# Pick the compute backend with a single CUDA check and derive both the torch
# device string and the matching Lightning accelerator name from it.
if torch.cuda.is_available():
    device, accelerator = 'cuda:0', 'gpu'
else:
    device, accelerator = 'cpu', 'cpu'
print('GPU State:', device)

GPU State: cpu


In [4]:
# Number of training epochs and on-disk location of the MNIST files.
epoch = 1
path_to_data = './data/mnist_data'

# Build the train/test loaders through the local `utils` helper.
# NOTE(review): download is disabled, so the dataset presumably has to exist
# under `path_to_data` already — confirm against `get_MNIST_dloaders`.
trainloader, testloader = get_MNIST_dloaders(
    batch_size=64,
    size=28,
    path=path_to_data,
    download=False,
    num_workers=8,
)

In [5]:
# Vector field of the ODE: a 3x3 convolution over 11 channels followed by tanh.
func = nn.Sequential(
    nn.Conv2d(in_channels=11, out_channels=11, kernel_size=3, padding=1),
    nn.Tanh(),
).to(device)

# Neural ODE block wrapping `func`, using the 'rk4' solver and
# 'autograd' sensitivity.
neuralDE = NeuralODE(func, solver='rk4', sensitivity='autograd').to(device)

# Full classifier: augment -> ODE block -> conv down to one channel -> flatten
# -> linear layer with 10 outputs.
# NOTE(review): 11 is presumably 1 image channel + 10 augmented dims — confirm.
model = nn.Sequential(
    Augmenter(augment_dims=10),
    neuralDE,
    nn.Conv2d(in_channels=11, out_channels=1, kernel_size=3, padding=1),
    nn.Flatten(),
    nn.Linear(in_features=28 * 28, out_features=10),
).to(device)

Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.


In [6]:
# Wrap the model and its data loaders in the Lightning module from `utils`.
learn = MNISTLearner(model, trainloader, testloader)
cb = MetricTracker()  # callback; its `collection` attribute is displayed in a later cell
logger = TensorBoardLogger(save_dir='lightning_logs/mnist/', name='model1')

# `gpus=` duplicated the `accelerator=` choice and is deprecated (Lightning 1.7,
# removed in 2.0); `devices=` is the supported way to set the device count.
# A CPU run still needs at least one device, so use 1 when no GPU is selected.
trainer = pl.Trainer(
    max_epochs=epoch,
    accelerator=accelerator,
    devices=torch.cuda.device_count() if accelerator == 'gpu' else 1,
    fast_dev_run=dry_run,  # single-batch smoke test when True
    logger=logger,
    callbacks=[cb],
)
trainer.fit(learn)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in fast_dev_run mode: will run a full train, val, test and prediction loop using 1 batch(es).
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_predict_batches=1)` was configured so 1 batch will be used.
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 9.1 K 
-------------------------------------
9.1 K     Trainable params
0         Non-trainable params
9.1 K     Total params
0.036     Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]



In [7]:
# Run the test loop on the learner that was just fitted.
# Disabled alternative kept for reference — build a trainer from a saved checkpoint:
# path = log_path + "mnist-epochepoch=001-stepstep=900.ckpt"
# trainer = pl.Trainer(resume_from_checkpoint=path)
# NOTE(review): `log_path` is not defined in the visible cells — define it before re-enabling.
trainer.test(learn)

Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           2.2953412532806396
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 2.2953412532806396}]

In [8]:
# Display the last two records stored on the MetricTracker callback.
cb.collection[-2:]

[{'epoch': tensor(0.),
  'train_loss': tensor(2.2950),
  'accuracy': tensor(0.1719),
  'NFE': tensor(4.)},
 {'test_loss': tensor(2.2953)}]

In [9]:
# Display the parameter counts reported by Lightning's ModelSummary for the learner.
ModelSummary(learn).param_nums

[9050]