In [None]:
import pytorch_lightning as pl
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split

class LitModel(pl.LightningModule):
    """Three-layer fully connected classifier with 10 output classes.

    Each input sample must flatten to 28 * 28 = 784 features (e.g. a
    1x28x28 MNIST image). ``forward`` returns per-class log-probabilities,
    which pair with ``F.nll_loss`` in the training and validation steps.

    Args:
        lr: Learning rate for the Adam optimizer. Defaults to 1e-3.
    """

    def __init__(self, lr: float = 1e-3) -> None:
        super().__init__()
        # Records ``lr`` in ``self.hparams`` (and in checkpoints) so the model
        # can be restored with the same configuration.
        self.save_hyperparameters()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return log-probabilities of shape (batch, 10) for the batch ``x``."""
        # Flatten every non-batch dimension; unlike ``view`` this also accepts
        # non-contiguous inputs.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        return F.log_softmax(self.layer_3(x), dim=1)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the NLL loss for one ``(inputs, labels)`` batch."""
        x, y = batch
        loss = F.nll_loss(self(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log NLL loss and accuracy for one held-out ``(inputs, labels)`` batch."""
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        # Predicted class is the index of the largest log-probability.
        acc = (log_probs.argmax(dim=1) == y).float().mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at ``hparams.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

# Data: the MNIST training split as tensors produced by ToTensor. The root ''
# resolves relative to the current working directory, where the files are
# downloaded on first use.
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
# Hold out a fixed-size validation subset and train on the rest. random_split
# requires the lengths to sum to len(dataset), so the training size is derived
# rather than hard-coded. The seeded generator makes the split reproducible.
n_val = 5000
train, val = random_split(
    dataset,
    [len(dataset) - n_val, n_val],
    generator=torch.Generator().manual_seed(42),
)
# Reshuffle training batches every epoch; evaluation order does not matter.
train_loader = DataLoader(train, batch_size=32, shuffle=True)
val_loader = DataLoader(val, batch_size=32)

# Model
model = LitModel()

# Train: accelerator="auto" uses a GPU when one is available and falls back to
# CPU otherwise, so this also runs on CPU-only machines and on Lightning 2.x
# (which no longer accepts a `gpus` argument).
trainer = pl.Trainer(max_epochs=1, accelerator="auto", devices=1)
# The held-out split is evaluated when the model defines a validation_step;
# otherwise Lightning warns and skips the validation loop.
trainer.fit(model, train_loader, val_loader)
