In [1]:
# https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/frameworks/pytorch/pytorch-lightning.html

In [1]:
import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning import Trainer, LightningModule
from torch.optim import Adam
from torchvision.datasets import MNIST
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
import os
# Import the submodule explicitly: `import urllib` alone only binds the
# package and does not load `urllib.request`, so the calls below would depend
# on some other library having imported it first.
import urllib.request

# Number of samples per training batch.
BATCH_SIZE = 64

# Workaround for https://github.com/pytorch/vision/issues/1938 (HTTP 403 when
# downloading the MNIST dataset): install a global opener that sends a
# browser-like User-Agent header with every urllib request.
opener = urllib.request.build_opener()
opener.addheaders = [("User-agent", "Mozilla/5.0")]
urllib.request.install_opener(opener)

In [3]:
class LitMNIST(LightningModule):
    """Fully connected MNIST classifier.

    Each image is flattened and passed through two ReLU-activated hidden
    layers (128 and 256 units) followed by a 10-way output layer.
    ``forward`` returns log-probabilities, so the training loss is
    ``F.nll_loss``.
    """

    def __init__(self):
        super().__init__()
        # The first layer consumes a flattened 28x28 image.
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: Tensor whose first dimension is the batch and whose remaining
                dimensions hold 28*28 values per sample, e.g. (b, 1, 28, 28).

        Returns:
            Tensor of shape (b, 10) holding log-softmax outputs.
        """
        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(x.size(0), -1)
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        x = self.layer_3(x)
        return F.log_softmax(x, dim=1)

    def process_batch(self, batch):
        """Convert a loader batch into the ``(inputs, labels)`` pair.

        The batch is returned unchanged here; ``training_step`` unpacks the
        result, so the batch must already be an ``(inputs, labels)`` pair.
        """
        return batch

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch.

        Args:
            batch: Batch as yielded by the training dataloader.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            Scalar loss tensor.
        """
        x, y = self.process_batch(batch)
        logits = self(x)
        return self.cross_entropy_loss(logits, y)

    def cross_entropy_loss(self, logits, labels):
        """Negative log-likelihood of ``labels`` under ``logits``.

        ``forward`` already applies ``log_softmax``, so ``nll_loss`` on its
        output is the cross-entropy of the underlying class scores.
        """
        return F.nll_loss(logits, labels)

    def configure_optimizers(self):
        """Return an Adam optimizer over all model parameters (lr=1e-3)."""
        return Adam(self.parameters(), lr=1e-3)

    def prepare_data(self):
        """Download the MNIST training set into the current working directory.

        No attribute is assigned here: Lightning documents that
        ``prepare_data`` is not run on every process in distributed training,
        so state set here would be missing elsewhere. ``setup`` builds the
        dataset that is actually used.
        """
        MNIST(os.getcwd(), train=True, download=True)

    def setup(self, stage=None):
        """Build the normalized training dataset from the downloaded files.

        Args:
            stage: Stage name passed by the trainer (unused).
        """
        # Convert to tensor, then normalize with mean 0.1307 and std 0.3081.
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        self.mnist_train = MNIST(
            os.getcwd(), train=True, download=False, transform=transform
        )

    def train_dataloader(self):
        """Return the training ``DataLoader`` over ``self.mnist_train``."""
        # BATCH_SIZE is the notebook-level constant from the configuration cell.
        return DataLoader(
            self.mnist_train,
            batch_size=BATCH_SIZE,
            num_workers=8,
            pin_memory=True,
        )

In [None]:
model = LitMNIST()

# Train for five epochs on a single GPU.
trainer = Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=5,
)

# DDP works only in non-interactive mode; to test it, uncomment the line below
# and run this notebook as a script.
# trainer = Trainer(devices=8, accelerator="gpu", strategy="ddp", max_epochs=5)

# The MNIST data set is not always available to download due to network
# issues, so training is disabled by default; uncomment the line below to run it.
# trainer.fit(model)