In [None]:
from torch import nn
import torch

# Choose the compute backend for the rest of the notebook.
# Preference order: CUDA GPU, then Metal (MPS) on macOS, then plain CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    # Neither accelerator is available, so fall back to the CPU.
    device = "cpu"

print("Using device: ", device)

# Last expression of the cell: shows the resolved torch.device in the output.
# Note that `device` itself remains a plain string.
torch.device(device)

In [None]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# Data-pipeline settings, kept together so they are easy to find and tune.
BATCH_SIZE = 128
VALIDATION_FRACTION = 0.2
SPLIT_SEED = 42

training_data = MNIST(root="data", train=True, download=True, transform=ToTensor())

# Hold out a fraction of the official training split for validation.
validation_set_size = int(VALIDATION_FRACTION * len(training_data))
training_set_size = len(training_data) - validation_set_size

# A dedicated, seeded generator makes the train/validation partition identical
# on every "Restart & Run All" without touching the global RNG state.
training_set, validation_set = torch.utils.data.random_split(
    training_data,
    [training_set_size, validation_set_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training loader is shuffled; evaluation loaders keep a fixed order
# so that validation/test passes are deterministic.
training_loader = torch.utils.data.DataLoader(training_set, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=BATCH_SIZE, shuffle=False)

test_data = MNIST(root="data", train=False, download=True, transform=ToTensor())
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

In [4]:
from torch.utils.data import DataLoader
from torch.optim import Optimizer, Adam
from torch.utils.tensorboard import SummaryWriter

In [None]:
def train(epoch: int,
          model: nn.Module,
          device: str,
          train_loader: DataLoader,
          optimizer: Optimizer,
          loss_fn: nn.Module,
          tensorboard: "SummaryWriter | None" = None,
          log_interval: int = 1000) -> float:
    """Run one training pass over ``train_loader`` and return the last reported loss.

    The loss is averaged over windows of ``log_interval`` batches. Each time a
    window completes, its mean loss is printed and, if ``tensorboard`` is given,
    logged under the ``"Loss/train"`` tag. Batches left over at the end of the
    pass (fewer than ``log_interval``) are reported as a final, shorter window,
    so a pass with fewer than ``log_interval`` batches still prints, logs and
    returns a meaningful value instead of ``0.``.

    Args:
        epoch: Index of the current epoch; used in the printed message and to
            compute the global step passed to ``tensorboard``.
        model: Module to train; it is switched to training mode.
        device: Device that every batch of inputs and targets is moved to.
        train_loader: Iterable yielding ``(data, target)`` pairs. ``len()`` is
            called on it only when ``tensorboard`` is provided.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called per batch.
        loss_fn: Callable taking ``(output, target)`` and returning the loss.
        tensorboard: Optional writer exposing ``add_scalar(tag, value, step)``.
        log_interval: Number of batches per reporting window (default 1000).

    Returns:
        Mean loss of the most recently reported window, or ``0.`` if
        ``train_loader`` yielded no batches.

    Raises:
        ValueError: If ``log_interval`` is smaller than 1.
    """
    if log_interval < 1:
        raise ValueError(f"log_interval must be >= 1, got {log_interval}")

    def report(window_loss: float, window_batches: int, batch_idx: int) -> float:
        # Average over the batches actually seen in this window, then emit it.
        mean_loss = window_loss / window_batches
        print(f"Epoch: {epoch}, Batch: {batch_idx + 1}, Loss: {mean_loss}")
        if tensorboard is not None:
            tensorboard.add_scalar("Loss/train", mean_loss, epoch * len(train_loader) + batch_idx)
        return mean_loss

    running_loss = 0.
    last_loss = 0.
    batches_in_window = 0
    last_idx = -1

    model.train(True)

    for idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()

        data, target = data.to(device), target.to(device)

        output = model(data)

        loss = loss_fn(output, target)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()
        batches_in_window += 1
        last_idx = idx

        if batches_in_window == log_interval:
            last_loss = report(running_loss, batches_in_window, idx)
            running_loss = 0.
            batches_in_window = 0

    # Flush the trailing partial window so short epochs are not silently dropped.
    if batches_in_window > 0:
        last_loss = report(running_loss, batches_in_window, last_idx)

    return last_loss