In [1]:
from torch import nn
import torch

# Choose the compute device.
# Preference order: CUDA GPU first, then Apple Metal (MPS) on macOS, and
# finally the CPU when neither accelerator is available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print("Using device: ", device)
# Trailing expression: the notebook displays the matching torch.device object.
torch.device(device)

Using device:  mps


device(type='mps')

In [2]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# Tunable data settings, kept together so they are easy to find and change.
BATCH_SIZE = 128
VALIDATION_FRACTION = 0.2
SPLIT_SEED = 42

training_data = MNIST(root="data", train=True, download=True, transform=ToTensor())

# Hold out a fraction of the training data for validation.
validation_set_size = int(VALIDATION_FRACTION * len(training_data))
training_set_size = len(training_data) - validation_set_size

# A dedicated, seeded generator makes the split identical on every run, so
# validation numbers from different runs are measured on the same samples.
training_set, validation_set = torch.utils.data.random_split(
    training_data,
    [training_set_size, validation_set_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training loader is shuffled; evaluation metrics do not depend on
# sample order, and a fixed order keeps evaluation output reproducible.
training_loader = torch.utils.data.DataLoader(training_set, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=BATCH_SIZE, shuffle=False)

test_data = MNIST(root="data", train=False, download=True, transform=ToTensor())
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

In [3]:
from torch.utils.data import DataLoader
from torch.optim import Optimizer, Adam
from torch.utils.tensorboard import SummaryWriter

In [4]:
def train(epoch : int,
          model : nn.Module,
          device : str,
          train_loader : DataLoader,
          optimizer : Optimizer,
          loss_fn : nn.Module,
          tensorboard : SummaryWriter = None) -> float:
    """Run one training epoch over ``train_loader`` and return a running loss.

    Each batch is moved to ``device``, passed through ``model``, scored with
    ``loss_fn`` and used for one ``optimizer`` step. Every 100 batches the mean
    loss of that window is printed and, when ``tensorboard`` is given, logged
    under ``"Loss/train"`` at step ``epoch * len(train_loader) + batch_index``.

    Args:
        epoch: Zero-based epoch number, used for messages and logging steps.
        model: Network to optimise; switched to training mode.
        device: Device the batches are moved to (the model is not moved here).
        train_loader: Iterable of ``(data, target)`` batches.
        optimizer: Optimizer bound to the parameters of ``model``.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar tensor.
        tensorboard: Optional writer exposing ``add_scalar(tag, value, step)``.

    Returns:
        The mean loss of the last completed 100-batch window. If the loader
        yields fewer than 100 batches, the mean over all batches seen instead
        (previously 0.0 was returned in that case). 0.0 if no batch was seen.
    """
    log_interval = 100   # number of batches per reporting window

    window_loss = 0.     # loss accumulated inside the current window
    epoch_loss = 0.      # loss accumulated over the whole epoch
    last_loss = 0.
    batches_seen = 0

    model.train(True)

    for idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()

        data, target = data.to(device), target.to(device)

        output = model(data)

        loss = loss_fn(output, target)
        loss.backward()

        optimizer.step()

        batch_loss = loss.item()
        window_loss += batch_loss
        epoch_loss += batch_loss
        batches_seen = idx + 1

        if batches_seen % log_interval == 0:
            last_loss = window_loss / log_interval
            print(f"Epoch: {epoch}, Batch: {batches_seen}, Loss: {last_loss}")
            if tensorboard is not None:
                tensorboard.add_scalar("Loss/train", last_loss, epoch * len(train_loader) + idx)

            window_loss = 0.

    # No full window was completed: report the mean over what was seen rather
    # than the misleading initial value of 0.
    if 0 < batches_seen < log_interval:
        last_loss = epoch_loss / batches_seen

    print(f"Epoch: {epoch}, Loss: {last_loss}")
    # Skip the end-of-epoch point when nothing was trained; there is no batch
    # index to derive a step from.
    if tensorboard is not None and batches_seen > 0:
        tensorboard.add_scalar("Loss/train", last_loss, epoch * len(train_loader) + batches_seen - 1)

    return last_loss

In [5]:
def validate(epoch : int,
             model : nn.Module,
             device : str,
             validation_loader : DataLoader,
             loss_fn : nn.Module,
             tensorboard : SummaryWriter = None) -> float:
    """Evaluate ``model`` on ``validation_loader`` and return the mean batch loss.

    The model is put in eval mode and gradients are disabled. Every 50 batches
    the running mean loss and running accuracy are printed and, when
    ``tensorboard`` is given, logged under ``"Loss/validation"`` and
    ``"Accuracy/validation"`` at step ``epoch * len(validation_loader) + batch_index``.
    The end-of-epoch values are logged at the step of the last batch so they
    follow the intermediate points of the same epoch.

    Args:
        epoch: Zero-based epoch number, used for messages and logging steps.
        model: Network to evaluate.
        device: Device the batches are moved to.
        validation_loader: Iterable of ``(data, target)`` batches.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar tensor.
        tensorboard: Optional writer exposing ``add_scalar(tag, value, step)``.

    Returns:
        The sum of per-batch losses divided by the number of batches, or NaN
        when the loader yields no batches (NaN never compares as "better"
        than a previous best loss).
    """
    running_loss = 0.
    correct = 0
    total = 0
    batches_seen = 0
    accuracy = 0.

    model.eval()

    with torch.no_grad():
        for idx, (data, target) in enumerate(validation_loader):
            data, target = data.to(device), target.to(device)

            output = model(data)

            loss = loss_fn(output, target)
            running_loss += loss.item()

            # Predicted class = index of the largest score along dim 1.
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            accuracy = correct / total
            batches_seen = idx + 1

            if idx % 50 == 49:
                print(f"Epoch: {epoch}, Batch: {idx + 1}, Loss: {running_loss / (idx + 1)}, Accuracy: {accuracy}")
                if tensorboard is not None:
                    tensorboard.add_scalar("Loss/validation", running_loss / (idx + 1), epoch * len(validation_loader) + idx)
                    tensorboard.add_scalar("Accuracy/validation", accuracy, epoch * len(validation_loader) + idx)

    if batches_seen == 0:
        # Nothing to average; avoid dividing by zero or logging fake metrics.
        print(f"Epoch: {epoch}, validation loader is empty; skipping validation")
        return float("nan")

    mean_loss = running_loss / batches_seen

    print(f"Epoch: {epoch}, Validation Loss: {mean_loss}, Validation Accuracy: {accuracy}")
    if tensorboard is not None:
        # Use the last batch's step. Logging at epoch * len(loader) would place
        # the final value before this epoch's intermediate points.
        final_step = epoch * len(validation_loader) + batches_seen - 1
        tensorboard.add_scalar("Loss/validation", mean_loss, final_step)
        tensorboard.add_scalar("Accuracy/validation", accuracy, final_step)

    return mean_loss

In [6]:
def test(model : nn.Module,
         device : str,
         test_loader : DataLoader,
         tensorboard : SummaryWriter = None) -> float:
    correct = 0
    total = 0

    model.eval()

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            output = model(data)

            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            missed_idx = (predicted != target).nonzero()

            for idx in missed_idx:
                idx = idx.item()
                predicted_label = predicted[idx].item()
                actual_label = target[idx].item()
                image = data[idx].squeeze().cpu().numpy()
                
                print(f"Missed: Predicted: {predicted_label}, Actual: {actual_label}")

                if tensorboard is not None:
                    tensorboard.add_image(f"Missed/{idx}, Predicted: {predicted_label}, Actual: {actual_label}", image, dataformats="HW")


    accuracy = correct / total

    print(f"Test accuracy: {accuracy}")

    if tensorboard is not None:
        tensorboard.add_scalar("Accuracy/test", accuracy)

    return accuracy

In [7]:
from datetime import datetime

In [8]:
def train_model(model : nn.Module,
                device : str,
                training_loader : DataLoader,
                validation_loader : DataLoader,
                test_loader : DataLoader,
                optimizer : Optimizer,
                loss_fn : nn.Module,
                epochs : int,
                best_loss : float = float("inf"),
                model_name : str = "model") -> float:
    """Train ``model`` for ``epochs`` epochs, checkpointing on validation improvement.

    Each epoch runs ``train`` then ``validate`` and logs both losses to a
    TensorBoard run at ``runs/{model_name}_{timestamp}``. Whenever the
    validation loss drops below ``best_loss`` the model's ``state_dict`` is
    saved to ``models/{model_name}_{timestamp}_{epoch}.pth``. After the last
    epoch ``test`` is run on ``test_loader``.

    Args:
        model: Network to train; moved to ``device``.
        device: Device used for the model and all batches.
        training_loader: Batches used for optimisation.
        validation_loader: Batches used for per-epoch validation.
        test_loader: Batches used for the final accuracy measurement.
        optimizer: Optimizer bound to the parameters of ``model``.
        loss_fn: Loss callable shared by training and validation.
        epochs: Number of epochs to run.
        best_loss: Validation loss to beat before a checkpoint is written.
        model_name: Prefix for the TensorBoard run and checkpoint files.

    Returns:
        The lowest validation loss seen (or the incoming ``best_loss`` if it
        was never beaten).
    """
    import os  # local import: only needed to create the checkpoint directory

    model.to(device)

    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")

    # torch.save does not create missing directories, so make sure the
    # checkpoint folder exists before the first improvement is saved.
    os.makedirs("models", exist_ok=True)

    tensorboard = SummaryWriter(f"runs/{model_name}_{timestamp}")

    # try/finally ensures the writer is closed even if training is interrupted
    # or raises.
    try:
        for epoch in range(epochs):
            avg_train_loss = train(epoch, model, device, training_loader, optimizer, loss_fn, tensorboard)
            avg_validation_loss = validate(epoch, model, device, validation_loader, loss_fn, tensorboard)

            print(f"Epoch: {epoch}, Average training loss: {avg_train_loss}, Average validation loss: {avg_validation_loss}")

            tensorboard.add_scalar("Loss/train/epoch", avg_train_loss, epoch)
            tensorboard.add_scalar("Loss/validation/epoch", avg_validation_loss, epoch)

            tensorboard.flush()

            if avg_validation_loss < best_loss:
                best_loss = avg_validation_loss
                model_path = f"models/{model_name}_{timestamp}_{epoch}.pth"
                torch.save(model.state_dict(), model_path)

        # NOTE: this evaluates the weights from the final epoch, which are not
        # necessarily those of the best checkpoint saved above.
        test(model, device, test_loader, tensorboard)
    finally:
        tensorboard.close()

    return best_loss


In [9]:
class MNISTNet(nn.Module):
    """Convolutional classifier producing 10 log-probabilities per image.

    Layout of ``self.stack`` (shapes for a 28 x 28 single-channel input):
      * conv block 1: 28 x 28 x 1  -> 26 x 26 x 16 -> pooled to 13 x 13 x 16
      * conv block 2: 13 x 13 x 16 -> 11 x 11 x 32 -> pooled to 5 x 5 x 32
      * flatten to 800 features
      * dense blocks 800 -> 128 -> 64 -> 32
      * final linear layer 32 -> 10 followed by LogSoftmax over dim 1
    """

    def __init__(self):
        super().__init__()

        # Layers are created strictly left to right so that module indices
        # (and therefore state_dict keys) match the flat layout listed above.
        layers = [
            *self._conv_block(1, 16),
            *self._conv_block(16, 32),
            nn.Flatten(),
            *self._dense_block(800, 128),
            *self._dense_block(128, 64),
            *self._dense_block(64, 32),
            nn.Linear(32, 10),
            nn.LogSoftmax(dim=1),
        ]
        self.stack = nn.Sequential(*layers)

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """3x3 convolution, batch norm, ReLU, 2x2 max-pool, 25% channel dropout."""
        return [
            nn.Conv2d(in_channels, out_channels, 3),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout2d(0.25),
        ]

    @staticmethod
    def _dense_block(in_features, out_features):
        """Linear layer, batch norm, ReLU, 50% dropout."""
        return [
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.ReLU(),
            nn.Dropout(0.5),
        ]

    def forward(self, x):
        """Apply the full layer stack to ``x`` and return its output."""
        return self.stack(x)

In [10]:
# Seed the global RNG so weight initialisation (and the shuffling of loaders
# that use the default generator) is repeatable across kernel restarts.
SEED = 0
torch.manual_seed(SEED)

model = MNISTNet()
optimizer = Adam(model.parameters())

# MNISTNet already ends in LogSoftmax, i.e. it outputs log-probabilities.
# NLLLoss is the criterion that consumes log-probabilities directly;
# CrossEntropyLoss expects raw logits and would apply log-softmax again.
loss_fn = nn.NLLLoss()

# Validation loss to beat before the first checkpoint is written.
best_loss = float("inf")

In [11]:
# Train, validate after every epoch, and finish with an evaluation on the
# held-out MNIST test split. The test loader must be the real test set:
# passing the validation loader here would report "test" accuracy on data
# that was already used for checkpoint selection.
best_loss = train_model(model = model,
                        device = device,
                        training_loader = training_loader,
                        validation_loader = validation_loader,
                        test_loader = test_loader,
                        optimizer = optimizer,
                        loss_fn = loss_fn,
                        epochs = 200,
                        best_loss = best_loss,
                        model_name="mnist")

Epoch: 0, Batch: 100, Loss: 1.9354435765743256
Epoch: 0, Batch: 200, Loss: 1.2537269204854966
Epoch: 0, Batch: 300, Loss: 0.904978449344635
Epoch: 0, Loss: 0.904978449344635
Epoch: 0, Batch: 50, Loss: 0.2446446317434311, Accuracy: 0.95703125
Epoch: 0, Validation Loss: 0.2455397272046576, Validation Accuracy: 0.9564166666666667
Epoch: 0, Average training loss: 0.904978449344635, Average validation loss: 0.2455397272046576
Epoch: 1, Batch: 100, Loss: 0.6084040519595146
Epoch: 1, Batch: 200, Loss: 0.5332401889562607
Epoch: 1, Batch: 300, Loss: 0.46603034853935243
Epoch: 1, Loss: 0.46603034853935243
Epoch: 1, Batch: 50, Loss: 0.09801070533692836, Accuracy: 0.97484375
Epoch: 1, Validation Loss: 0.09719429422724754, Validation Accuracy: 0.97525
Epoch: 1, Average training loss: 0.46603034853935243, Average validation loss: 0.09719429422724754
Epoch: 2, Batch: 100, Loss: 0.4123382794857025
Epoch: 2, Batch: 200, Loss: 0.39408037036657334
Epoch: 2, Batch: 300, Loss: 0.3374658676981926
Epoch: 2, 