In [None]:
import os
import torch
from torch import nn
from torchvision.datasets import FashionMNIST
from torchvision.transforms import Compose, Normalize, ToTensor
from torch.utils.data import DataLoader

from datetime import datetime

# Fix the global RNG seed so weight initialisation, the random train/test split
# and DataLoader shuffling are reproducible across Restart & Run All.
# torch.manual_seed seeds the RNG for all devices (CPU and CUDA).
SEED = 42
torch.manual_seed(SEED)

# Set the device (use GPU if available, otherwise fallback to CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

In [None]:

# ToTensor turns each image into a float tensor in [0, 1]; Normalize then maps
# it to [-1, 1] via (x - 0.5) / 0.5.
transform = Compose(
    [ToTensor(),
     Normalize((0.5,), (0.5,))])

# Official FashionMNIST training split (further divided into train/test below)
training_data = FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=transform,
)

# Official FashionMNIST test split; used as the *validation* set during training
validation_data = FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=transform,
)

# Hold out 20% of the official training split as a local test set.
train_size = int(0.8 * len(training_data))   # 80% for training
test_size = len(training_data) - train_size  # 20% for testing

# A dedicated, seeded generator keeps the partition identical on every run,
# independent of how much global RNG state earlier cells consumed.
SPLIT_SEED = 42
train_dataset, test_dataset = torch.utils.data.random_split(
    training_data,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# The batch size we will use
batch_size = 16

# Dataloaders: shuffled for training, fixed order for evaluation
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
valid_dataloader = DataLoader(dataset=validation_data, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
class NeuralNetwork(nn.Module):
    """Fully connected classifier producing 10 logits per sample.

    Each input's non-batch dimensions are flattened and must total 28*28 = 784
    features (e.g. a FashionMNIST image). Layers:
    784 -> 512 -> ReLU -> 512 -> 512 -> ReLU -> 512 -> 10.
    """

    def __init__(self):
        super().__init__()
        width = 512  # hidden layer size
        self.flatten = nn.Flatten()
        layers = [
            nn.Linear(28 * 28, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 10),  # one output per FashionMNIST class
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Return unnormalised class scores (logits) for the batch ``x``."""
        flat = self.flatten(x)
        return self.linear_relu_stack(flat)

# SGD hyperparameters
lr, momentum = 0.001, 0.9  # learning rate, momentum coefficient

# Build the network on the selected device together with its loss and optimizer
model = NeuralNetwork().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=lr,
    momentum=momentum,
)

# Checkpoints are written to output/mnist_<timestamp>/models
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
out_dir = f"output/mnist_{timestamp}"
os.makedirs(
    os.path.join(out_dir, "models"),
    exist_ok=True,
)

In [None]:
# Function to train the model
def train(dataloader, model, loss_fn, optimizer, log_interval=1000, device=None):
    """Run one epoch of mini-batch training over ``dataloader``.

    Args:
        dataloader: yields ``(X, y)`` batches; ``len(dataloader.dataset)`` is
            used as the total sample count in progress messages.
        model: module being optimised; put into training mode here.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        log_interval: print the current batch loss every ``log_interval``
            batches (default 1000, as before).
        device: where each batch is moved. Defaults to the device of the
            model's first parameter, so no notebook global is required.
    """
    if device is None:
        device = next(model.parameters()).device
    size = len(dataloader.dataset)
    model.train()  # Set model to training mode
    seen = 0  # samples processed so far; exact even when the last batch is short
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)  # Move data to the model's device

        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()  # Zero the gradients
        loss.backward()  # Compute the gradients
        optimizer.step()  # Update the weights

        seen += len(X)
        if (batch + 1) % log_interval == 0:
            print(f"loss: {loss.item():>7f}  [{seen}/{size}]")

# Function to validate the model
def validate(dataloader, model, loss_fn, device=None):
    """Evaluate ``model`` over every batch of ``dataloader`` without gradients.

    Prints accuracy and the average loss, and returns them so the caller can
    track progress (e.g. keep the best checkpoint).

    Args:
        dataloader: yields ``(X, y)`` batches; ``len(dataloader.dataset)`` is
            used as the number of samples for the accuracy.
        model: module to evaluate; switched to evaluation mode here.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter, so no notebook global is required.

    Returns:
        ``(accuracy, avg_loss)``. ``accuracy`` is the fraction of correct
        argmax predictions. ``avg_loss`` is the mean of the per-batch losses.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    if device is None:
        device = next(model.parameters()).device
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    if size == 0 or num_batches == 0:
        raise ValueError("validate() needs a dataloader with at least one sample")

    model.eval()  # Set model to evaluation mode
    val_loss, correct = 0.0, 0
    with torch.no_grad():  # Disable gradient calculation
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)  # Move data to the model's device
            pred = model(X)
            val_loss += loss_fn(pred, y).item()  # Accumulate per-batch loss
            correct += (pred.argmax(1) == y).sum().item()  # Count correct predictions

    val_loss /= num_batches  # Average of per-batch losses
    accuracy = correct / size
    print(f"Validation Error: \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {val_loss:>8f} \n")
    return accuracy, val_loss

In [None]:
# How many full passes to make over the training split
EPOCHS = 5

for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}\n-------------------------------")

    # One optimisation pass, then evaluation on the validation loader
    train(train_dataloader, model, loss_fn, optimizer)
    validate(valid_dataloader, model, loss_fn)

    # Checkpoint the weights at the end of every epoch (file index is 0-based)
    model_save_path = os.path.join(
        out_dir,
        "models",
        f"model_{timestamp}_{epoch}.pth",
    )
    torch.save(model.state_dict(), model_save_path)
    print(f"Saved model to {model_save_path}")

print("Done!")

In [None]:

import matplotlib.pyplot as plt

# Rebuild the network and restore the weights written by the last training epoch
model_path = model_save_path
model = NeuralNetwork().to(device)
model.load_state_dict(torch.load(model_path))
model.eval()

# FashionMNIST label index -> human-readable class name
class_names = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

# Function to make predictions and display images
def show_predictions(model, dataloader, class_names, device=None):
    """Plot up to 16 samples from ``dataloader`` in a 4x4 grid with true and predicted labels.

    Images are squeezed and drawn in grayscale. Each title shows the true and
    predicted class name. Works for any loader batch size, because placement
    uses a running image counter rather than the notebook's global ``batch_size``.

    Args:
        model: classifier whose ``argmax(1)`` over outputs is the predicted class.
        dataloader: yields ``(images, labels)`` batches.
        class_names: sequence mapping a class index to its display name.
        device: where inference runs. Defaults to the device of the model's
            first parameter.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    fig, axs = plt.subplots(4, 4, figsize=(12, 12))
    axs = axs.flatten()
    n_slots = len(axs)
    n_shown = 0  # images drawn so far, counted across batches

    with torch.no_grad():
        for X, y in dataloader:
            if n_shown >= n_slots:
                break  # grid full; don't fetch/infer further batches
            pred = model(X.to(device)).argmax(1).cpu()
            for img, label, prediction in zip(X, y, pred):
                if n_shown >= n_slots:
                    break
                ax = axs[n_shown]
                ax.imshow(img.cpu().squeeze(), cmap="gray")
                ax.set_title(f"True: {class_names[label.item()]}\nPred: {class_names[prediction.item()]}")
                ax.axis('off')
                n_shown += 1

    # Hide axes left empty when the loader has fewer than 16 samples.
    for ax in axs[n_shown:]:
        ax.axis('off')
    plt.show()

# Visualise predictions on samples held out from the official training split
show_predictions(model=model, dataloader=test_dataloader, class_names=class_names)