Arvid Lundervold, 2025-01-14

[![Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/MMIV-ML/BMED365-2025/blob/main/Lab2-DL/notebooks/01-MNIST-Classification-with-CNN.ipynb)

### NB 3: MNIST Classification using a Convolutional Neural Network (CNN)

In [1]:
# Detect whether this notebook is executing on Google Colaboratory.
# The `google.colab` package can only be imported there, so the outcome of
# the import attempt is used as the environment flag. Later cells can branch
# on `is_colab`; the same check is repeated in every notebook of the course.

try:
    import google.colab  # noqa: F401  (imported only to probe the environment)
except ImportError:
    # The package is missing: this is not a Colab runtime.
    is_colab = False
else:
    # The import succeeded: this is a Colab runtime.
    is_colab = True

print("Running in Google Colab." if is_colab else "Not running in Google Colab.")

Not running in Google Colab.


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# 1. Create Custom Dataset for PNG files
class MNISTDataset(Dataset):
    """In-memory dataset of digit images stored as PNG files in one folder.

    Every file directly inside ``folder_path`` whose name ends with ``.png``
    is decoded once, at construction time, converted to 8-bit grayscale
    (PIL mode ``'L'``) and kept in memory. The label of each image is the
    integer found before the first underscore of its file name, i.e. files
    are expected to be named ``<digit>_<index>.png``.

    Args:
        folder_path: Directory that contains the PNG files.
        transform: Optional callable applied to the PIL image each time a
            sample is fetched; its return value becomes the sample's image.

    Raises:
        FileNotFoundError: If ``folder_path`` does not exist (raised by
            ``os.listdir``).
        ValueError: If a PNG file name does not start with an integer label.
    """

    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.transform = transform
        self.images = []
        self.labels = []

        # os.listdir() returns entries in arbitrary, platform-dependent order;
        # sorting makes the index -> sample mapping reproducible across runs.
        for filename in sorted(os.listdir(folder_path)):
            if not filename.endswith('.png'):
                continue

            # The label is encoded in the file name: "<digit>_<index>.png".
            label_text = filename.split('_')[0]
            try:
                label = int(label_text)
            except ValueError as err:
                raise ValueError(
                    f"Cannot read a digit label from file name {filename!r}; "
                    "expected the form '<digit>_<index>.png'"
                ) from err

            img_path = os.path.join(folder_path, filename)
            # convert() returns a fully loaded copy, so the file handle can be
            # released right away instead of waiting for garbage collection.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('L')

            self.images.append(image)
            self.labels.append(label)

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is passed through ``transform`` first when one was given.
        """
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# 2. Define CNN Model
class MNISTClassifier(pl.LightningModule):
    """Small convolutional classifier for single-channel 28x28 digit images.

    Feature-map sizes for a ``(N, 1, 28, 28)`` input::

        conv1 3x3 -> (N, 32, 26, 26) -> max-pool 2 -> (N, 32, 13, 13)
        conv2 3x3 -> (N, 64, 11, 11) -> max-pool 2 -> (N, 64, 5, 5)
        conv3 3x3 -> (N, 64, 3, 3)   -> flatten    -> (N, 576)
        fc1       -> (N, 128)
        fc2       -> (N, 10)

    The network outputs log-probabilities (``log_softmax``) and is trained
    with the negative log-likelihood loss using Adam.

    Attributes:
        train_loss, train_acc: Per-batch training loss / accuracy (floats),
            appended on every ``training_step`` call.
        val_loss, val_acc: Per-batch validation loss / accuracy (floats),
            appended on every ``validation_step`` call.
    """

    def __init__(self):
        super().__init__()
        # CNN layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3)
        self.pool = nn.MaxPool2d(2)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        # 576 = 64 channels * 3 * 3, the conv3 output size (see class docstring).
        self.fc1 = nn.Linear(576, 128)
        self.fc2 = nn.Linear(128, 10)

        # Store metrics
        self.train_acc = []
        self.val_acc = []
        self.train_loss = []
        self.val_loss = []

    def forward(self, x):
        """Return per-class log-probabilities of shape ``(N, 10)``.

        Args:
            x: Image batch of shape ``(N, 1, 28, 28)``; other spatial sizes
                do not match the 576 input features expected by ``fc1``.
        """
        # Both of the first two conv layers are followed by pooling; without
        # the pool after conv1 the flattened size would be 6400, not 576.
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.dropout1(x)
        x = F.relu(self.conv3(x))
        x = self.dropout2(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

    def _shared_step(self, batch):
        """Compute ``(loss, accuracy)`` tensors for one ``(images, labels)`` batch."""
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        # Computed with torch so no extra metric import or host copy is needed.
        acc = (preds == y).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Run one training batch, record and log its metrics, return the loss."""
        loss, acc = self._shared_step(batch)
        self.train_acc.append(acc.item())
        self.train_loss.append(loss.item())
        self.log('train_loss', loss)
        self.log('train_acc', acc)
        return loss

    def validation_step(self, batch, batch_idx):
        """Run one validation batch, record and log its metrics.

        Returns:
            Dict with the batch's ``'val_loss'`` and ``'val_acc'``.
        """
        loss, acc = self._shared_step(batch)
        self.val_acc.append(acc.item())
        self.val_loss.append(loss.item())
        self.log('val_loss', loss)
        self.log('val_acc', acc)
        return {'val_loss': loss, 'val_acc': acc}

    def configure_optimizers(self):
        """Return the Adam optimizer (learning rate 0.001) over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=0.001)

# 3. Data Loading and Training
# Preprocessing applied to every image on access: force a 28x28 size,
# convert the PIL image to a tensor, then standardize the single channel
# with mean 0.1307 and std 0.3081 (presumably the MNIST training-set
# statistics -- confirm against the data actually used).
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Build the two splits from their PNG folders (all images are read here).
train_dataset = MNISTDataset('mnist_images/train', transform=transform)
test_dataset = MNISTDataset('mnist_images/test', transform=transform)

# Mini-batches of 32; only the training split is reshuffled every epoch.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

# Set up the network and a 10-epoch trainer that uses the GPU when available.
model = MNISTClassifier()
trainer = pl.Trainer(
    max_epochs=10,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
)

# Fit the model. NOTE: the test split doubles as the validation set here.
trainer.fit(model, train_loader, test_loader)

# 4. Evaluation and Visualization
# Collect predictions for the whole test split, then summarize them as a
# confusion matrix and a per-class classification report.
model.eval()  # disable dropout for inference
all_preds = []
all_labels = []

with torch.no_grad():
    for batch in test_loader:
        x, y = batch
        # The DataLoader yields CPU tensors, while (depending on the Lightning
        # version) the module may still live on the GPU after trainer.fit();
        # move the inputs to wherever the model currently is.
        logits = model(x.to(model.device))
        preds = torch.argmax(logits, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(y.cpu().numpy())

# Plot confusion matrix (rows: true labels, columns: predicted labels)
plt.figure(figsize=(10, 8))
cm = confusion_matrix(all_labels, all_preds)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('CNN Confusion Matrix', fontsize=16)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.show()

# Print classification report
print("\nClassification Report:")
print(classification_report(all_labels, all_preds))

# Plot training curves
# One panel per metric, each overlaying the per-batch values that the model
# recorded during training and validation. The two series are recorded at
# different cadences, so they generally differ in length.
plt.figure(figsize=(12, 4))
curve_specs = (
    ('Loss', model.train_loss, model.val_loss),
    ('Accuracy', model.train_acc, model.val_acc),
)
for panel_index, (metric_name, train_series, val_series) in enumerate(curve_specs, start=1):
    plt.subplot(1, 2, panel_index)
    plt.plot(train_series, label='Train')
    plt.plot(val_series, label='Validation')
    plt.title(f'{metric_name} Curves')
    plt.xlabel('Batch')
    plt.ylabel(metric_name)
    plt.legend()
plt.tight_layout()
plt.show()

# Save the model
# Only the learned parameters (the state_dict) are written, not the module
# object itself; the file is placed next to the image folders.
checkpoint_path = 'mnist_images/mnist_cnn.pt'
torch.save(model.state_dict(), checkpoint_path)
print(f"\nModel saved as '{checkpoint_path}'")

FileNotFoundError: [Errno 2] No such file or directory: 'mnist_images/train'