In [None]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor

from sklearn.metrics import classification_report # mamba install scikit-learn or pip3 install scikit-learn to install.
from datetime import datetime

import numpy as np

# Pick the compute device: a CUDA GPU when one is visible to PyTorch, else the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Setup the training data (downloaded into ./data if not already present)
training_data = FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),  # Convert images to tensor format
)

# Setup the testing data (held out; only used for the final evaluation)
test_dataset = FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),  # Convert images to tensor format
)

train_size = int(0.8 * len(training_data))  # 80% for training
valid_size = len(training_data) - train_size  # remaining 20% for validation

# Split the dataset; the fixed generator seed makes the split reproducible
train_dataset, valid_dataset = random_split(
    training_data,
    [train_size, valid_size],
    generator=torch.Generator().manual_seed(42),
)

# The batch size we will use
BATCH_SIZE = 32
# Loader worker processes: at most 12, but never more than the machine has
# CPUs (os.cpu_count() may return None, in which case fall back to 1).
NUM_WORKERS = min(12, os.cpu_count() or 1)

# Create dataloaders for training, validation and test.
# Only the training loader shuffles; validation and test keep dataset order.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
class NeuralNetwork(nn.Module):
    """Small two-stage convolutional classifier with 10 output classes.

    The first fully connected layer takes 800 features, i.e. 50 channels of
    4x4: that is what two 5x5 un-padded convolutions, each followed by 2x2
    max-pooling, produce from a single-channel 28x28 image. Inputs are
    therefore expected to have shape ``(N, 1, 28, 28)``.

    ``forward`` returns raw, unnormalised class scores (logits). There is
    deliberately no final ``nn.Softmax``: ``nn.CrossEntropyLoss`` applies
    log-softmax itself, so feeding it already-softmaxed values would squash
    the gradients and slow training. Use ``torch.softmax(logits, dim=1)`` when
    probabilities are needed; ``argmax`` predictions are the same either way.
    """

    def __init__(self):
        super().__init__()
        # 1 -> 20 channels, 5x5 kernel, then halve the spatial size
        self.conv_stack1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=20, kernel_size=(5, 5)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
        )

        # 20 -> 50 channels, 5x5 kernel, then halve the spatial size again
        self.conv_stack2 = nn.Sequential(
            nn.Conv2d(in_channels=20, out_channels=50, kernel_size=(5, 5)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
        )

        # Classifier head; the Linear layers keep indices 0 and 2, so
        # state_dict keys match checkpoints saved with a trailing Softmax.
        self.fc_stack = nn.Sequential(
            nn.Linear(in_features=800, out_features=500),
            nn.ReLU(),
            nn.Linear(in_features=500, out_features=10),
        )

    def forward(self, x):
        """Return class logits of shape ``(N, 10)`` for a batch of images ``x``."""
        x = self.conv_stack1(x)
        x = self.conv_stack2(x)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        x = self.fc_stack(x)
        return x

# Build the network, move its parameters to the selected device, and show the layers
model = NeuralNetwork()
model = model.to(device)
print(model)

In [None]:
# Optimisation hyperparameters
EPOCHS = 20
LR = 0.001
MOMENTUM = 0.9

# SGD with momentum over all model parameters, scored with cross-entropy
optimizer = torch.optim.SGD(params=model.parameters(), lr=LR, momentum=MOMENTUM)
loss_fn = nn.CrossEntropyLoss()

# Each run writes into its own timestamped directory; the "models"
# sub-directory is created up front for the per-epoch checkpoints.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
out_dir = f"output/mnist_{timestamp}"
os.makedirs(
    os.path.join(out_dir, "models"),
    exist_ok=True,
)

In [None]:
# Function to train the model
def train(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Returns a ``(train_loss, train_acc)`` pair of un-normalised totals: the
    sum of the per-batch loss values and the number of correctly classified
    samples. The caller is expected to divide them to get averages.
    """
    model.train()  # Set model to training mode
    running_loss = 0
    running_correct = 0

    for inputs, targets in dataloader:
        # Move data to the selected device
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Forward pass
        outputs = model(inputs)
        batch_loss = loss_fn(outputs, targets)

        # Backward pass and parameter update
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Accumulate the batch loss and the count of correct predictions
        running_loss += batch_loss.item()
        hits = outputs.argmax(1) == targets
        running_correct += hits.type(torch.float).sum().item()

    return running_loss, running_correct

# Function to validate the model
def validate(dataloader, model, loss_fn):
    model.eval()  # Set model to evaluation mode
    val_loss, val_acc = 0, 0

    with torch.no_grad():  # Disable gradient calculation
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)  # Move data to the selected device
            pred = model(X)
            val_loss += loss_fn(pred, y).item()  # Accumulate validation loss
            val_acc += (pred.argmax(1) == y).type(torch.float).sum().item()  # Count correct predictions

    return val_loss, val_acc

In [None]:

# Per-epoch history of the averaged metrics; one value is appended per epoch
results = {name: [] for name in ("train_loss", "train_acc", "val_loss", "val_acc")}

for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}\n-------------------------------")

    train_loss, train_acc = train(train_dataloader, model, loss_fn, optimizer)
    val_loss, val_acc = validate(valid_dataloader, model, loss_fn)

    # train()/validate() return totals: average the loss over the number of
    # batches and the correct-prediction count over the number of samples.
    train_loss = train_loss / len(train_dataloader)
    train_acc = train_acc / len(train_dataloader.dataset)
    val_loss = val_loss / len(valid_dataloader)
    val_acc = val_acc / len(valid_dataloader.dataset)

    for metric_name, metric_value in (
        ("train_loss", train_loss),
        ("train_acc", train_acc),
        ("val_loss", val_loss),
        ("val_acc", val_acc),
    ):
        results[metric_name].append(metric_value)

    print(f"Training Error: \n Accuracy: {(100*train_acc):>0.1f}%, Avg loss: {train_loss:>8f} \n")
    print(f"Validation Error: \n Accuracy: {(100*val_acc):>0.1f}%, Avg loss: {val_loss:>8f} \n")

    # Checkpoint the weights after every epoch; model_save_path ends up
    # pointing at the checkpoint of the last epoch once the loop finishes.
    model_save_path = os.path.join(out_dir, "models", f"model_{timestamp}_{epoch}.pth")
    torch.save(model.state_dict(), model_save_path)

    print(f"Saved model to {model_save_path}")

print("Done!")

In [None]:
import matplotlib.pyplot as plt
import os

def _plot_history(history, metric_names, title, ylabel, legend_loc, save_path):
    """Draw one curve per entry of ``metric_names`` and save the figure.

    ``history`` maps a metric name to its list of per-epoch values; each name
    is also used as the legend label. The figure is written to ``save_path``
    and left open so a later ``plt.show()`` displays it.
    """
    fig, ax = plt.subplots()
    for metric_name in metric_names:
        ax.plot(history[metric_name], label=metric_name)
    ax.set_title(title)
    ax.set_xlabel("Epoch #")
    ax.set_ylabel(ylabel)
    ax.legend(loc=legend_loc)
    fig.savefig(save_path)


# Plotting loss
_plot_history(
    results,
    ("train_loss", "val_loss"),
    title="Training and Validation Loss on Dataset",
    ylabel="Loss",
    legend_loc="upper right",
    save_path=os.path.join(out_dir, "loss_plot.png"),
)

# Plotting accuracy
_plot_history(
    results,
    ("train_acc", "val_acc"),
    title="Training and Validation Accuracy on Dataset",
    ylabel="Accuracy",
    legend_loc="lower right",
    save_path=os.path.join(out_dir, "accuracy_plot.png"),
)

plt.show()


**Precision**: This measures the accuracy of the positive predictions. It is the ratio of true positive (TP) predictions to the total number of positive predictions (both true positive and false positive (FP)). The formula is:
$$ \text{Precision} = \frac{TP}{TP + FP} $$

**Recall**: This measures the ability of the model to identify all relevant instances. It is the ratio of true positive predictions to the total number of actual positives (both true positive and false negative (FN)). The formula is:
$$ \text{Recall} = \frac{TP}{TP + FN} $$

**F1 Score**: This is the harmonic mean of precision and recall, providing a single metric that balances both. It is useful when you need a balance between precision and recall. The formula is:
$$ F1 \text{ Score} = 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}} $$


In [None]:
# Load the weights saved after the last training epoch.
# map_location remaps the stored tensors onto the device in use now, so the
# checkpoint still loads if it was written on a different device.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
weights = torch.load(model_save_path, map_location=device)
model.load_state_dict(weights)

model.eval()
preds = []
targets = []
with torch.no_grad():
    for x, y in test_dataloader:
        pred = model(x.to(device))
        preds.append(pred.argmax(dim=1).cpu())
        # Take the labels from the same batches as the predictions so the two
        # stay aligned whatever the loader's ordering is.
        targets.append(y.cpu())

y_pred = torch.cat(preds).numpy()
y_true = torch.cat(targets).numpy()

# generate a classification report (per-class precision, recall and F1)
print(classification_report(y_true, y_pred, target_names=test_dataset.classes))

In [None]:
# Human-readable FashionMNIST labels; the list position is the class index
class_names = [
    "T-shirt/top",  # 0
    "Trouser",      # 1
    "Pullover",     # 2
    "Dress",        # 3
    "Coat",         # 4
    "Sandal",       # 5
    "Shirt",        # 6
    "Sneaker",      # 7
    "Bag",          # 8
    "Ankle boot",   # 9
]

# Function to make predictions and display images
def show_predictions(model, dataloader, class_names, n_rows=4, n_cols=4):
    """Show the first ``n_rows * n_cols`` samples with true and predicted labels.

    Args:
        model: Network whose output has one score per class along dim 1.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        class_names: Sequence mapping a class index to its display name.
        n_rows: Number of grid rows (default 4).
        n_cols: Number of grid columns (default 4).

    Samples are counted as they are drawn, so the grid fills correctly
    whatever batch size the loader uses, and iteration stops as soon as the
    grid is full. Grid cells left over when the loader runs out of samples
    are blanked.
    """
    model.eval()
    # squeeze=False always gives a 2-D array of axes, even for a 1x1 grid
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
    axs = axs.flatten()
    capacity = len(axs)
    shown = 0

    with torch.no_grad():
        for X, y in dataloader:
            if shown >= capacity:
                break  # grid is full: do not run the model on further batches
            pred = model(X.to(device)).argmax(1)
            for img, label, prediction in zip(X, y.tolist(), pred.tolist()):
                if shown >= capacity:
                    break
                ax = axs[shown]
                ax.imshow(img.cpu().squeeze(), cmap="gray")
                ax.set_title(f"True: {class_names[label]}\nPred: {class_names[prediction]}")
                ax.axis('off')
                shown += 1

    # Hide the frames of any grid cells that received no image
    for ax in axs[shown:]:
        ax.axis('off')
    plt.show()

# Visualise predictions for the first samples of the test set
show_predictions(
    model=model,
    dataloader=test_dataloader,
    class_names=class_names,
)