In [3]:
import torch
import torch.nn as nn
import torch.optim as opt
from torchvision import transforms, models, datasets
from torch.utils.data import DataLoader
import os
from torchvision.models import ResNet18_Weights
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, classification_report
import torch.nn.functional as F
import numpy as np

In [5]:
# Data Transforms
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # Convert grayscale to 3 channels
    transforms.Resize((224, 224)),  # Resize to 224x224 as ResNet-18 expects
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize the data
])


# Download Pytorch MNIST dataset
train_dataset = datasets.MNIST(root='data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='data', train=False, download=True, transform=transform)

# Create data loaders
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Import resnet18 module
resnet18 = models.resnet18(weights=ResNet18_Weights.DEFAULT)

# Change the input layer to accept Graysacle
resnet18.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
# Output layer has 10 outputs
resnet18.fc = nn.Linear(resnet18.fc.in_features, 10)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet18.parameters(), lr=0.001)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet18.to(device)

epochs = 5
resnet18.train()
for epoch in range(epochs):
    running_loss = 0.0
    # Wrap the data loader with tqdm for the progress bar
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")
    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = resnet18(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Update progress bar with the current loss
        progress_bar.set_postfix(loss=running_loss / len(train_loader))


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:03<00:00, 2.88MB/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 497kB/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 4.50MB/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 9.92MB/s]


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:01<00:00, 23.8MB/s]
Epoch 1/5: 100%|██████████| 938/938 [05:01<00:00,  3.11it/s, loss=0.0694]
Epoch 2/5: 100%|██████████| 938/938 [04:54<00:00,  3.19it/s, loss=0.034]
Epoch 3/5: 100%|██████████| 938/938 [04:47<00:00,  3.26it/s, loss=0.0266]
Epoch 4/5: 100%|██████████| 938/938 [04:48<00:00,  3.25it/s, loss=0.0233]
Epoch 5/5: 100%|██████████| 938/938 [04:48<00:00,  3.25it/s, loss=0.0211]


In [6]:
# Function to evaluate the model on the test set
def evaluate_model(model, test_loader, device):
    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():  # No need to track gradients for validation
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Collect predictions and labels for metrics calculation
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Calculate accuracy
    accuracy = correct / total
    print(f'Accuracy: {accuracy * 100:.2f}%')

    # Generate confusion matrix
    conf_matrix = confusion_matrix(all_labels, all_preds)
    print("\nConfusion Matrix:\n", conf_matrix)

    # Generate classification report
    class_report = classification_report(all_labels, all_preds, digits=4)
    print("\nClassification Report:\n", class_report)

# Evaluate the model
evaluate_model(resnet18, test_loader, device)


Accuracy: 99.11%

Confusion Matrix:
 [[ 978    0    0    0    0    0    1    0    0    1]
 [   0 1123    0    1    0    3    7    1    0    0]
 [   0    1 1029    0    0    0    1    0    1    0]
 [   0    0    2 1003    0    4    0    0    1    0]
 [   0    0    0    0  980    0    0    0    1    1]
 [   0    0    0    1    0  890    1    0    0    0]
 [  16    0    0    0    1    4  936    0    1    0]
 [   0    2   10    0    1    0    0 1015    0    0]
 [   1    0    2    0    0    1    0    1  969    0]
 [   0    0    0    0   12    1    0    6    2  988]]

Classification Report:
               precision    recall  f1-score   support

           0     0.9829    0.9980    0.9904       980
           1     0.9973    0.9894    0.9934      1135
           2     0.9866    0.9971    0.9918      1032
           3     0.9980    0.9931    0.9955      1010
           4     0.9859    0.9980    0.9919       982
           5     0.9856    0.9978    0.9916       892
           6     0.9894    0