In [24]:
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

# 0. Reproducibility: seed the global torch RNG (weight init, DataLoader
#    shuffling) and use a dedicated, seeded generator for the train/val split.
#    (cuDNN kernels may still cause small run-to-run differences on GPU.)
SEED = 42
torch.manual_seed(SEED)

# 1. Choose device: GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# 2. Define transforms: convert to tensor + normalize
transform = transforms.Compose([
    transforms.ToTensor(),  # convert image to tensor
    transforms.Normalize((0.1307,), (0.3081,))  # mean and std for MNIST
])

# 3. Download datasets
full_train_dataset = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)

# Hold out part of the official training set for validation / early stopping,
# so the official test set is only used for the final evaluation.
TRAIN_FRACTION = 0.83
train_size = int(TRAIN_FRACTION * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    full_train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),  # same split on every run
)

# 4. DataLoaders for batching
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1000, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

print(f"Training samples: {len(train_dataset)}, "
      f"Validation samples: {len(val_dataset)}, "
      f"Test samples: {len(test_dataset)}")


Using device: cuda
Training samples: 49800, Test samples: 10000


In [11]:
class CNN(nn.Module):
    """Small two-block convolutional classifier for single-channel 28x28 images.

    Architecture: [conv3x3 -> BatchNorm -> ReLU -> maxpool 2x2] x 2, then
    Linear(64*7*7 -> 128) -> ReLU -> Linear(128 -> 10).

    ``forward`` returns raw logits of shape ``(batch, 10)``; no softmax is
    applied because ``nn.CrossEntropyLoss`` expects unnormalized scores.
    """

    def __init__(self):
        super().__init__()
        # Convolutional layers (padding=1 keeps the spatial size; pooling halves it)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        # Fully connected (linear) layers
        self.fc1 = nn.Linear(64 * 7 * 7, 128)  # 28 -> 14 -> 7 after pooling twice
        self.fc2 = nn.Linear(128, 10)          # 10 digits

        # Pooling layer (reduce size by factor of 2)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Map a batch of images ``(N, 1, 28, 28)`` to logits ``(N, 10)``."""
        # First conv block: conv -> BN -> ReLU -> pool   (28x28 -> 14x14)
        x = self.pool(torch.relu(self.bn1(self.conv1(x))))

        # Second conv block                               (14x14 -> 7x7)
        x = self.pool(torch.relu(self.bn2(self.conv2(x))))

        # Flatten everything except the batch dimension. Unlike
        # x.view(-1, 64 * 7 * 7), this never folds extra features into the
        # batch axis: a wrongly sized input fails loudly in fc1 instead of
        # silently producing logits for the wrong number of samples.
        x = torch.flatten(x, 1)

        # Fully connected layers
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)  # logits, no softmax (CrossEntropyLoss expects raw logits)

        return x

# Build the network, then place its parameters and buffers on the chosen device
model = CNN()
model = model.to(device)
print(model)


CNN(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=3136, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)


In [12]:
# Alternative definition style using nn.Sequential

class CNN_Sequential(nn.Module):
    """The same conv/linear stack as ``CNN``, expressed with ``nn.Sequential``.

    ``features``: two [conv3x3 -> BatchNorm -> ReLU -> maxpool 2x2] blocks
    (28x28 -> 14x14 -> 7x7). ``classifier``: Linear(64*7*7 -> 128) -> ReLU ->
    Linear(128 -> 10). ``forward`` returns raw logits of shape ``(batch, 10)``.
    """

    def __init__(self):
        super().__init__()

        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 28 -> 14

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)   # 14 -> 7
        )

        self.classifier = nn.Sequential(
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, 10)  # raw logits
        )

    def forward(self, x):
        """Map a batch of images ``(N, 1, 28, 28)`` to logits ``(N, 10)``."""
        x = self.features(x)
        # Flatten all but the batch dim. This is done here rather than with an
        # nn.Flatten inside `classifier`, so the Sequential indices, and thus
        # the state_dict keys, stay unchanged. Unlike view(-1, 64*7*7), a
        # wrongly sized input errors in the first Linear instead of being
        # silently reshaped into a different batch size.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [13]:
# Objective: cross-entropy over the 10 digit classes, computed from raw logits
# and integer class labels.
criterion = nn.CrossEntropyLoss()

# Optimizer: Adam over every model parameter, learning rate 1e-3
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)


In [14]:
# Sanity check: run the five steps of one optimization iteration on one batch.
# NOTE: this really updates the weights (one extra Adam step before the main
# training loop), so re-running this cell changes the model.

# Enter training mode explicitly, so BatchNorm uses batch statistics here even
# if this cell is re-run after a cell that switched the model to eval().
model.train()

# Get one batch of data from the train_loader
images, labels = next(iter(train_loader))

# Move them to GPU if available
images, labels = images.to(device), labels.to(device)

# 1. Zero gradients (clear gradients accumulated by earlier backward passes)
optimizer.zero_grad()

# 2. Forward pass: get raw logits
outputs = model(images)  # shape [batch, 10]; train_loader uses batch_size=64

# 3. Compute loss
loss = criterion(outputs, labels)

print("Loss on this batch:", loss.item())

# 4. Backward pass: compute gradients
loss.backward()

# 5. Update weights
optimizer.step()


Loss on this batch: 2.3244516849517822


In [15]:
num_epochs = 5  # you can increase later
patience = 3    # early stopping: epochs without val-loss improvement before stopping

# Early-stopping state must exist before the first comparison below
# (without it, the first epoch raises NameError on `best_val_loss`).
best_val_loss = float("inf")
counter = 0  # consecutive epochs without validation-loss improvement

for epoch in range(num_epochs):
    # ---- Training ----
    model.train()  # training mode: BatchNorm uses and updates batch statistics
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()              # 1. zero gradients
        outputs = model(images)            # 2. forward pass
        loss = criterion(outputs, labels)  # 3. compute loss
        loss.backward()                    # 4. backward pass
        optimizer.step()                   # 5. update weights

        # Track loss (mean of per-batch losses) and accuracy on training batches
        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)   # class with the highest logit
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    train_loss = running_loss / len(train_loader)
    train_acc = 100 * correct / total

    # ---- Validation ----
    model.eval()  # BatchNorm uses running statistics; no parameter updates
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss = val_loss / len(val_loader)
    val_acc = 100 * correct / total

    print(f"Epoch [{epoch+1}/{num_epochs}] "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

    # ---- Early stopping condition ----
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
        # you might also save the model here (e.g. a copy of model.state_dict())
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered")
            break

Epoch [1/5] Train Loss: 0.1254, Train Acc: 96.13% | Val Loss: 0.0797, Val Acc: 97.48%
Epoch [2/5] Train Loss: 0.0476, Train Acc: 98.54% | Val Loss: 0.0517, Val Acc: 98.52%
Epoch [3/5] Train Loss: 0.0330, Train Acc: 98.92% | Val Loss: 0.0481, Val Acc: 98.60%
Epoch [4/5] Train Loss: 0.0277, Train Acc: 99.11% | Val Loss: 0.0491, Val Acc: 98.34%
Epoch [5/5] Train Loss: 0.0227, Train Acc: 99.29% | Val Loss: 0.0442, Val Acc: 98.79%


In [22]:
# Run the model over the test set, collecting per-sample softmax probabilities,
# true labels and predicted classes.
# Switch to eval mode explicitly instead of relying on whichever mode an earlier
# cell left the model in: in train mode BatchNorm would normalize with batch
# statistics and overwrite its running statistics during evaluation.
model.eval()

all_probs = []
all_labels = []
all_predictions = []

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        batch_probs = torch.softmax(outputs, dim=1)
        _, predicted = torch.max(batch_probs, 1)

        # Append to lists (moved to CPU so they can be converted to NumPy later)
        all_probs.append(batch_probs.cpu())
        all_labels.append(labels.cpu())
        all_predictions.append(predicted.cpu())

# Concatenate all batches into big tensors
all_probs = torch.cat(all_probs, dim=0)     # shape [N, 10]
all_labels = torch.cat(all_labels, dim=0)   # shape [N]
all_predictions = torch.cat(all_predictions, dim=0)  # shape [N]

print(all_probs.shape, all_labels.shape)  # should be [10000, 10], [10000]


torch.Size([10000, 10]) torch.Size([10000])


In [26]:
# Convert the collected CPU tensors to NumPy arrays for scikit-learn.
y_true, y_pred, y_prob = (t.numpy() for t in (all_labels, all_predictions, all_probs))

print("Accuracy:", accuracy_score(y_true, y_pred))
# Macro averaging weights every class equally, regardless of its frequency.
for metric_label, scorer in (
    ("Precision:", precision_score),
    ("Recall:", recall_score),
    ("F1:", f1_score),
):
    print(metric_label, scorer(y_true, y_pred, average="macro"))
# One-vs-rest ROC AUC computed from the per-class softmax probabilities.
print("ROC AUC:", roc_auc_score(y_true, y_prob, multi_class="ovr"))

Accuracy: 0.9881
Precision: 0.9882578045645921
Recall: 0.9879201065250388
F1: 0.9880378500849171
ROC AUC: 0.9999345951652415


In [27]:
# Persist only the learned parameters and buffers (the state_dict), not the
# module object itself; restore later with model.load_state_dict(torch.load(...)).
torch.save(obj=model.state_dict(), f="mnist_cnn.pth")