In [None]:
# Record author, Python/IPython version, and the installed torch version so results can be tied to an environment.
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch

In [None]:
!pip install pytorch-lightning

- Runs on CPU or GPU (if available)

# Convolutional Neural Network on MNIST

## Imports

In [None]:
import os
import time

import numpy as np
import torch
import torch.nn.functional as F

# Favour reproducibility over speed on GPU runs: restrict cuDNN to deterministic
# convolution algorithms and explicitly disable benchmark-mode autotuning, which can
# otherwise select different algorithms from run to run.
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

## Settings

In [None]:
##########################
### SETTINGS
##########################

# Architecture
num_classes = 10  # size of the classifier's output layer

# Hyperparameters
random_seed = 1  # passed to torch.manual_seed before the Lite run
learning_rate = 0.001  # Adam learning rate
num_epochs = 1  # full passes over the training split
batch_size = 128  # examples per mini-batch, shared by every DataLoader

## Dataset

In [None]:
!pip install gitpython

In [None]:
import os

from git import Repo

# Fetch the PNG-format MNIST mirror on first run only; the ImageFolder cells
# read the train/ and test/ folders from this local checkout.
if not os.path.exists("mnist-pngs"):
    Repo.clone_from(
        url="https://github.com/rasbt/mnist-pngs",
        to_path="mnist-pngs",
    )

In [None]:
from torchvision import transforms

# Both pipelines resize (shorter side -> 32 px) and crop back to 28x28: a random
# crop for training (light translation augmentation) and a centre crop for
# evaluation. Normalize with mean=std=0.5 maps ToTensor's [0, 1] range to [-1, 1].
data_transforms = {
    split_name: transforms.Compose(
        [
            transforms.Resize(32),
            crop_cls((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    for split_name, crop_cls in (
        ("train", transforms.RandomCrop),
        ("test", transforms.CenterCrop),
    )
}

In [None]:
from torch.utils.data.dataset import random_split
from torchvision.datasets import ImageFolder

train_dset = ImageFolder(root="mnist-pngs/train", transform=data_transforms["train"])

# Hold out 5,000 training images for validation (the lengths must add up to the
# folder's image count). A dedicated generator seeded with `random_seed` makes
# the split identical on every Restart & Run All, independent of how much of
# torch's global RNG earlier cells consumed.
train_dset, valid_dset = random_split(
    train_dset,
    lengths=[55000, 5000],
    generator=torch.Generator().manual_seed(random_seed),
)

test_dset = ImageFolder(root="mnist-pngs/test", transform=data_transforms["test"])

In [None]:
from torch.utils.data import DataLoader

# One loader per split with shared batch size and worker count. Only the
# training loader shuffles and drops its final incomplete batch.
train_loader, valid_loader, test_loader = (
    DataLoader(
        dataset=split_dset,
        batch_size=batch_size,
        drop_last=is_training,
        num_workers=4,
        shuffle=is_training,
    )
    for split_dset, is_training in (
        (train_dset, True),
        (valid_dset, False),
        (test_dset, False),
    )
)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import torchvision.utils as vutils

real_batch = next(iter(train_loader))
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training images")
plt.imshow(
    np.transpose(
        vutils.make_grid(real_batch[0][:64], padding=2, normalize=True), (1, 2, 0)
    )
)
plt.show()

## Model

In [None]:
import torch


class PyTorchCNN(torch.nn.Module):
    """Small two-stage convolutional classifier for 3-channel 28x28 images.

    Shape walk-through for a 3x28x28 input:
      conv(3->8, 3x3, pad 1)  -> 8x28x28  -> max-pool 2x2 -> 8x14x14
      conv(8->16, 3x3, pad 1) -> 16x14x14 -> max-pool 2x2 -> 16x7x7
      flatten -> 16 * 7 * 7 = 784 -> Linear(784, 128) -> Linear(128, num_classes)

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        nn = torch.nn  # local alias keeps the layer lists compact

        self.num_classes = num_classes

        # Construction order matters: it fixes parameter initialisation under a
        # given seed and the Sequential indices that become state_dict keys.
        first_stage = [
            nn.Conv2d(
                in_channels=3, out_channels=8, kernel_size=(3, 3), stride=(1, 1), padding=1
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0),
        ]
        second_stage = [
            # NOTE(review): this ReLU follows ReLU + max-pool, whose output is
            # already non-negative, so it is a no-op. It is kept so that layer
            # indices (e.g. "features.4.weight") remain stable for checkpoints.
            nn.ReLU(),
            nn.Conv2d(
                in_channels=8, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=1
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0),
        ]
        self.features = nn.Sequential(*first_stage, *second_stage)

        # 784 = 16 channels * 7 * 7 spatial positions after the two pools.
        head = [
            nn.Flatten(),
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Map a batch of images to unnormalised class logits."""
        return self.classifier(self.features(x))


## Training with LightningLite

In [None]:
def compute_accuracy(model, data_loader):
    """Return the top-1 accuracy of ``model`` over every example in ``data_loader``.

    The model's train/eval mode is not changed here; callers switch to
    ``model.eval()`` first.

    Args:
        model: callable mapping a batch of features to logits; dimension 1 of
            the output is treated as the class dimension.
        data_loader: iterable of ``(features, targets)`` pairs where ``targets``
            holds integer class labels.

    Returns:
        A 0-dim float tensor: correctly classified examples / total examples.

    Raises:
        ValueError: if ``data_loader`` yields no examples (accuracy undefined).
    """
    correct_pred, num_examples = 0, 0
    with torch.no_grad():
        for features, targets in data_loader:
            logits = model(features)
            predicted_labels = logits.argmax(dim=1)
            # Compare on the logits' device so CPU targets still work when the
            # model (and therefore its output) lives on an accelerator.
            targets = targets.to(predicted_labels.device)
            num_examples += targets.size(0)
            correct_pred += (predicted_labels == targets).sum()

    if num_examples == 0:
        raise ValueError("data_loader yielded no examples; accuracy is undefined")
    return correct_pred.float() / num_examples

In [None]:
from pytorch_lightning.lite import LightningLite
from torch.utils.data import DataLoader

from torch.utils.data.dataset import random_split
from torchvision.datasets import ImageFolder


# Seed torch's global RNG so the Lite run's random split, weight init and
# shuffling are repeatable.
torch.random.manual_seed(random_seed)


class Lite(LightningLite):
    """Train ``PyTorchCNN`` on the MNIST PNG folders with LightningLite, then test it.

    Reads notebook-level settings from earlier cells (``batch_size``,
    ``learning_rate``, ``num_epochs``, ``num_classes``, ``data_transforms``)
    and uses ``compute_accuracy`` for evaluation. Output goes through
    ``self.print`` so that only the first process logs in multi-device runs.
    """

    def run(self):
        """Lite entry point: train a model, then report its test accuracy."""
        model = self.train()
        self.test(model)

    @staticmethod
    def _build_loader(dataset, shuffle, drop_last):
        """Create a plain DataLoader with the notebook's shared batch size and workers."""
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            drop_last=drop_last,
            num_workers=4,
            shuffle=shuffle,
        )

    def train(self):
        """Fit a fresh ``PyTorchCNN`` for ``num_epochs`` epochs; return the Lite-wrapped model."""
        train_dset = ImageFolder(root="mnist-pngs/train", transform=data_transforms["train"])
        # 55,000 / 5,000 train/validation split drawn from torch's global RNG.
        train_dset, valid_dset = random_split(train_dset, lengths=[55000, 5000])
        train_loader = self._build_loader(train_dset, shuffle=True, drop_last=True)
        valid_loader = self._build_loader(valid_dset, shuffle=False, drop_last=False)

        model = PyTorchCNN(num_classes=num_classes)
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        # Lite places the model on the selected device and wraps the optimizer;
        # the wrapped loaders move each batch to that device, so no manual
        # `.to(device)` calls are needed in the loop below.
        model, optimizer = self.setup(model, optimizer)
        train_loader = self.setup_dataloaders(train_loader)
        valid_loader = self.setup_dataloaders(valid_loader)

        start_time = time.time()
        for epoch in range(num_epochs):
            model = model.train()
            for batch_idx, (features, targets) in enumerate(train_loader):

                ### Forward pass
                logits = model(features)
                loss = F.cross_entropy(logits, targets)

                ### Backward pass (backpropagation); Lite's replacement for loss.backward()
                optimizer.zero_grad()
                self.backward(loss)

                ### Update model parameters
                optimizer.step()

                ### Batch-level logging
                if not (batch_idx + 1) % 100:
                    self.print(
                        f"Epoch: {epoch + 1:03d}/{num_epochs:03d} | "
                        f"Batch: {batch_idx + 1:03d}/{len(train_loader):03d} | "
                        f"Loss: {loss:.4f}"
                    )

            ### Epoch-level logging
            # NOTE: training accuracy is measured on the augmented (random-crop),
            # shuffled training loader, not on a clean pass over the data.
            model = model.eval()
            train_acc = compute_accuracy(model, train_loader)
            valid_acc = compute_accuracy(model, valid_loader)
            self.print(
                f"Training accuracy: {train_acc * 100:.2f}% | "
                f"Validation accuracy: {valid_acc * 100:.2f}%"
            )
            self.print(f"Time elapsed: {(time.time() - start_time) / 60:.2f} min")

        self.print(f"Total training time: {(time.time() - start_time) / 60:.2f} min")
        return model

    def test(self, model):
        """Print the accuracy of ``model`` on the MNIST test folder."""
        test_dset = ImageFolder(root="mnist-pngs/test", transform=data_transforms["test"])
        test_loader = self.setup_dataloaders(
            self._build_loader(test_dset, shuffle=False, drop_last=False)
        )
        model = model.eval()
        test_acc = compute_accuracy(model, test_loader)
        self.print(f"Test accuracy: {test_acc*100:.2f}%")


In [None]:
# Let Lightning pick the accelerator (GPU if available, otherwise CPU) and the
# device count, then train and evaluate.
Lite(
    accelerator="auto",
    devices="auto",
).run()

In [None]:
# Print the name and version of every package imported in this notebook.
%watermark -iv