# Custom Models with Flash

In [None]:
import os
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as T

from pytorch_lightning import Trainer
import pytorch_lightning.metrics.functional as FM

from pl_flash import Flash

### 1. Load Data

In [None]:
# MNIST train/test splits stored under ./data (downloaded on first use), with
# every image passed through T.ToTensor().
# The training loader must shuffle: DataLoader defaults to shuffle=False, which
# would feed batches in the same fixed order every epoch. The test loader keeps
# a fixed order so evaluation is reproducible.
BATCH_SIZE = 64
to_tensor = T.ToTensor()

train_dl = DataLoader(
    MNIST("data", train=True, download=True, transform=to_tensor),
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_dl = DataLoader(
    MNIST("data", train=False, download=True, transform=to_tensor),
    batch_size=BATCH_SIZE,
)

### 2. Define Model Architecture

In [None]:
# Multilayer perceptron for MNIST: each image is flattened to 28*28 inputs,
# passed through one hidden layer of 128 ReLU units, and mapped to 10 class
# scores, which LogSoftmax turns into per-class log-probabilities (dim=1).
mlp_layers = [
    nn.Flatten(),
    nn.Linear(28 * 28, 128),  # input pixels -> hidden units
    nn.ReLU(),
    nn.Linear(128, 10),  # hidden units -> one score per digit class
    nn.LogSoftmax(dim=1),
]
mlp = nn.Sequential(*mlp_layers)

### 3. Create Flash Model

In [None]:
# Create the Flash model, with negative log-likelihood loss and accuracy as a metric.
# `mlp` ends in nn.LogSoftmax, so its outputs are already log-probabilities and
# F.nll_loss is the matching loss. F.cross_entropy expects raw logits and would
# apply log_softmax a second time: the value is the same, but the work is
# redundant and obscures the intent.
model = Flash(mlp, loss=F.nll_loss, metrics=[FM.accuracy])

### 4. Train

In [None]:
# Workaround for a Lightning bug: an explicit empty list is passed as the
# validation loaders rather than relying on the default.
# https://github.com/PyTorchLightning/pytorch-lightning/issues/3052
no_val_dls = []

trainer = Trainer(max_epochs=1)  # train for a single epoch
trainer.fit(model, train_dl, no_val_dls)

### 5. Test

In [None]:
# Evaluate the trained model on the held-out MNIST test split.
trainer.test(model=model, test_dataloaders=test_dl)