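# Custom Models with Flash

This notebook wraps a plain PyTorch `nn.Module` (a small multilayer perceptron) in `pl_flash.Flash`. It then trains the model for one epoch on MNIST with a PyTorch Lightning `Trainer` and reports loss and accuracy on the MNIST test split.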

In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as T

from pytorch_lightning import Trainer
import pytorch_lightning.metrics.functional as FM

from pl_flash import Flash

### 1. Load Data

In [2]:
# Seed torch's global RNG so that shuffling and the model's weight initialisation
# (next cell) are reproducible under Restart & Run All.
SEED = 0
BATCH_SIZE = 64
torch.manual_seed(SEED)

# download=True fetches MNIST into the working directory if it is not already
# there, so the notebook also runs on a fresh machine.
# The training loader reshuffles every epoch (standard for SGD-style training);
# the test loader keeps a fixed order since evaluation does not need shuffling.
train_dl = DataLoader(
    MNIST(os.getcwd(), train=True, download=True, transform=T.ToTensor()),
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_dl = DataLoader(
    MNIST(os.getcwd(), train=False, download=True, transform=T.ToTensor()),
    batch_size=BATCH_SIZE,
)

### 2. Define Model Architecture

In [3]:
# Multilayer perceptron: flatten each 28x28 image, one 128-unit hidden layer
# with ReLU, then 10 output scores (one per digit class).
# The network returns raw logits with no final LogSoftmax. F.cross_entropy, the
# loss used when wrapping this model, already applies log_softmax internally,
# so an extra LogSoftmax layer is redundant work. The argmax of the outputs
# (the predicted class) is the same either way.
mlp = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

### 3. Create Flash Model

In [4]:
# Wrap the MLP in a Flash model. Cross-entropy is the training loss, and
# accuracy is passed in the list of metrics to report.
model = Flash(
    mlp,
    loss=F.cross_entropy,
    metrics=[FM.accuracy],
)

### 4. Train

In [5]:
# Train for a single epoch.
# The third argument (validation dataloaders) is an empty list, a workaround
# for a PyTorch Lightning bug:
# https://github.com/PyTorchLightning/pytorch-lightning/issues/3052
trainer = Trainer(max_epochs=1)
empty_val_loaders = []
trainer.fit(model, train_dl, empty_val_loaders)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 101 K 


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

### 5. Test

In [6]:
# Evaluate the trained model on the MNIST test split. The returned list of
# metric dicts is shown as the cell output.
test_results = trainer.test(model, test_dataloaders=test_dl)
test_results

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'epoch_test/accuracy': tensor(0.9591),
 'epoch_test/cross_entropy': tensor(0.1352),
 'epoch_test/loss': tensor(0.1352)}
--------------------------------------------------------------------------------



[{'epoch_test/cross_entropy': 0.1352406144142151,
  'epoch_test/loss': 0.1352406144142151,
  'epoch_test/accuracy': 0.9590963125228882}]