<a href="https://colab.research.google.com/github/GeorgiyAleksanyan/14-332-472-01-ROBOTICS-COMP-VISION-Classify-MNIST-classes-with-ResNet18/blob/main/ga360_RCV_14_332_472_Project_1_Question_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Classify MNIST classes with ResNet18** Fine-tune the ResNet18 network to classify the MNIST dataset. Report the confusion matrix, accuracy, F1-score, precision, and recall of your classifier. Write a paragraph describing your results and methods.
Huge thanks to Marcin Zablocki for the tutorial.
1. Github: https://github.com/marrrcin
2. Tutorial: https://zablo.net/blog/post/pytorch-resnet-mnist-jupyter-notebook-2021/

In [1]:
!pip install torch~=2.1.0 torchvision pytorch-lightning
!pip install -U 'torch_xla>=2.1'
!pip install validators matplotlib

Collecting pytorch-lightning
  Downloading pytorch_lightning-2.1.0-py3-none-any.whl (774 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 774.6/774.6 kB 7.0 MB/s eta 0:00:00
Collecting torchmetrics>=0.7.0 (from pytorch-lightning)
  Downloading torchmetrics-1.2.0-py3-none-any.whl (805 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 805.2/805.2 kB 12.9 MB/s eta 0:00:00
Collecting lightning-utilities>=0.8.0 (from pytorch-lightning)
  Downloading lightning_utilities-0.9.0-py3-none-any.whl (23 kB)
Installing collected packages: lightning-utilities, torchmetrics, pytorch-lightning
Successfully installed lightning-utilities-0.9.0 pytorch-lightning-2.1.0 torchmetrics-1.2.0
Collecting torch_xla>=2.1
  Downloading torch_xla-2.1.0-cp310-cp310-manylinux_2_28_x86_64.whl (81.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 81.1/81.1 MB 6.4 MB/s eta 0:00:00
Collecting cloud-tp

Collecting validators
  Downloading validators-0.22.0-py3-none-any.whl (26 kB)
Installing collected packages: validators
Successfully installed validators-0.22.0


In [2]:
# Import necessary libraries
import torch
from torchvision.models import resnet18
import torch.nn as nn
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, classification_report
import requests

In [3]:
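# Specify the model: stock ResNet18 with a 10-class output layer
# (for reference only; the MNIST-adapted model that is trained is defined in ResNetMNIST below)
model = resnet18(num_classes=10)
# Define data transformations (PIL image -> float tensor)
transform = ToTensor()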
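For a 28×28 grayscale MNIST image, `ToTensor()` scales the 8-bit pixel values from [0, 255] to floats in [0.0, 1.0] and adds a leading channel axis, so each sample reaches the network with shape (1, 28, 28). The NumPy sketch below mimics that conversion on a synthetic image; `to_tensor_like` is a hypothetical helper that only illustrates the arithmetic and is not the torchvision implementation.

```python
import numpy as np


def to_tensor_like(image_hw: np.ndarray) -> np.ndarray:
    """Hypothetical helper mimicking ToTensor() for one single-channel uint8 image.

    Scales pixels from [0, 255] to [0.0, 1.0] and adds a channel axis: (H, W) -> (1, H, W).
    """
    if image_hw.dtype != np.uint8:
        raise TypeError("expected a uint8 image")
    return (image_hw.astype(np.float32) / 255.0)[np.newaxis, :, :]


# Synthetic 28x28 "digit": black background with a white vertical stroke
img = np.zeros((28, 28), dtype=np.uint8)
img[4:24, 13:15] = 255

x = to_tensor_like(img)
print(x.shape, x.dtype, x.min(), x.max())  # (1, 28, 28) float32 0.0 1.0
```

The single channel in this shape is why the first convolution of ResNet18 has to accept one input channel instead of three.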

In [4]:
# Load MNIST datasets
train_dataset = MNIST(root="mnist", train=True, download=True, transform=transform)
test_dataset = MNIST(root="mnist", train=False, download=True, transform=transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist/MNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [00:00<00:00, 167541310.34it/s]
Extracting mnist/MNIST/raw/train-images-idx3-ubyte.gz to mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to mnist/MNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<00:00, 15680995.96it/s]
Extracting mnist/MNIST/raw/train-labels-idx1-ubyte.gz to mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to mnist/MNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:00<00:00, 36986145.47it/s]
Extracting mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<00:00, 17334421.08it/s]
Extracting mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to mnist/MNIST/raw

In [5]:
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=1)
test_loader = DataLoader(test_dataset, batch_size=64, num_workers=1)
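With `batch_size=64`, the number of batches in one pass is the dataset size divided by 64, rounded up, because `DataLoader` keeps the final partial batch unless `drop_last=True`. Assuming the standard MNIST split of 60,000 training and 10,000 test images, the arithmetic below gives 938 optimizer steps for the single training epoch and 157 test batches, which matches the `157/157` progress bar in the inference cell. `num_batches` is a hypothetical helper for illustration.

```python
import math


def num_batches(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Hypothetical helper: number of batches a DataLoader yields for one pass."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


train_batches = num_batches(60_000, 64)  # one optimizer step per batch with max_epochs=1
test_batches = num_batches(10_000, 64)
last_test_batch = 10_000 - (test_batches - 1) * 64  # size of the final, partial test batch
print(train_batches, test_batches, last_test_batch)  # 938 157 16
```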

In [6]:
# Define the ResNet model for MNIST classification
class ResNetMNIST(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = resnet18(num_classes=10)
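        # Replace the 3-channel (RGB) stem convolution with a 1-channel one, since MNIST images are grayscale
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)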
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.RMSprop(self.parameters(), lr=0.005)
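Swapping `conv1` for a 1-channel convolution keeps its 7×7 kernel, stride 2, and padding 3, so the rest of torchvision's ResNet18 is unchanged and a 28×28 digit is downsampled by the stem and by the strided stages. The sketch below traces the spatial size with the standard output-size formula for convolution and pooling, floor((n + 2p - k) / s) + 1. It assumes torchvision's ResNet18 layout: a 3×3 max pool with stride 2 after `conv1`, and a stride-2 3×3 convolution at the start of `layer2`, `layer3`, and `layer4`. The function names are illustrative.

```python
def out_size(n: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a conv/pool layer (dilation 1, floor rounding)."""
    return (n + 2 * padding - kernel) // stride + 1


def resnet18_spatial_trace(n: int = 28) -> dict:
    """Feature-map height/width after each ResNet18 stage for an n x n input."""
    sizes = {"input": n}
    n = out_size(n, kernel=7, stride=2, padding=3)  # conv1 (same geometry as the replaced layer)
    sizes["conv1"] = n
    n = out_size(n, kernel=3, stride=2, padding=1)  # maxpool
    sizes["maxpool"] = n
    sizes["layer1"] = n  # layer1 uses stride 1, so the size is unchanged
    for name in ("layer2", "layer3", "layer4"):
        n = out_size(n, kernel=3, stride=2, padding=1)  # first 3x3 conv of the stage is strided
        sizes[name] = n
    return sizes


print(resnet18_spatial_trace(28))
# {'input': 28, 'conv1': 14, 'maxpool': 7, 'layer1': 7, 'layer2': 4, 'layer3': 2, 'layer4': 1}
```

On MNIST the last stage already works on a 1×1 map of 512 channels, which the adaptive average pool passes to the 10-way `fc` layer; a 224×224 ImageNet input would still be 7×7 at that point.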

In [7]:
# Create an instance of the ResNetMNIST model
model = ResNetMNIST()

In [8]:
# Create a trainer for training the model
trainer = pl.Trainer(
    accelerator="cpu",
    max_epochs=1
)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [9]:
# Train the model
trainer.fit(model, train_loader)

INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type             | Params
-------------------------------------------
0 | model | ResNet           | 11.2 M
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.701    Total estimated model params size (MB)
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.
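The `11.2 M` trainable parameters and `44.701` MB in the model summary can be checked by hand. The sketch below counts the trainable parameters of the ResNet18 layout for a given number of input channels and classes. It assumes bias-free convolutions, a learnable weight and bias per batch-norm channel (running statistics are buffers, not parameters), a 1×1 projection shortcut wherever a stage changes width, and 4 bytes per float32 parameter for the size estimate. All function names are illustrative.

```python
def conv_params(c_in: int, c_out: int, k: int) -> int:
    return c_in * c_out * k * k  # ResNet convolutions use bias=False


def bn_params(c: int) -> int:
    return 2 * c  # learnable weight and bias per channel


def basic_block_params(c_in: int, c_out: int) -> int:
    """Two 3x3 conv + BN pairs, plus a 1x1 conv + BN shortcut when the width changes."""
    total = conv_params(c_in, c_out, 3) + bn_params(c_out)
    total += conv_params(c_out, c_out, 3) + bn_params(c_out)
    if c_in != c_out:
        total += conv_params(c_in, c_out, 1) + bn_params(c_out)
    return total


def resnet18_params(in_channels: int = 3, num_classes: int = 1000) -> int:
    total = conv_params(in_channels, 64, 7) + bn_params(64)  # stem
    c_in = 64
    for c_out in (64, 128, 256, 512):  # layer1..layer4, two BasicBlocks each
        total += basic_block_params(c_in, c_out) + basic_block_params(c_out, c_out)
        c_in = c_out
    total += 512 * num_classes + num_classes  # fc weight and bias
    return total


mnist_params = resnet18_params(in_channels=1, num_classes=10)
print(mnist_params)                      # 11175370
print(round(mnist_params * 4 / 1e6, 3))  # 44.701 (MB, assuming float32)
print(resnet18_params())                 # 11689512 (stock 3-channel, 1000-class ResNet18)
```

Almost all of the difference from the stock network comes from the smaller `fc` layer (5,130 instead of 513,000 parameters); the 1-channel `conv1` saves a further 6,272.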


In [10]:
# Save the trained model
trainer.save_checkpoint("resnet18_mnist.pt")

In [11]:
# Define a function for making predictions
def get_prediction(x, model):
    model.freeze()  # Prepare the model for predicting: eval mode and requires_grad=False on all parameters
    probabilities = torch.softmax(model(x), dim=1)
    predicted_class = torch.argmax(probabilities, dim=1)
    return predicted_class, probabilities
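Softmax exponentiates each logit (a strictly increasing function) and divides every entry in a row by the same sum, so it preserves the ordering of the logits. As a result, `argmax` over the probabilities selects the same class as `argmax` over the raw logits; the softmax in `get_prediction` only matters when the probabilities themselves are needed. The pure-Python sketch below illustrates this with made-up logits (not real model output) and uses the usual max-subtraction for numerical stability.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest value (first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)


# Illustrative logits for one image over the 10 digit classes (not real model output)
logits = [0.3, -1.2, 2.5, 0.0, 1.1, -0.4, 0.9, 2.2, -2.0, 0.5]
probs = softmax(logits)

print(argmax(logits), argmax(probs))  # 2 2
print(round(sum(probs), 6))           # 1.0
```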

In [12]:
# Load the trained model for inference
inference_model = ResNetMNIST.load_from_checkpoint("resnet18_mnist.pt", map_location="cpu")

In [13]:
true_labels, predicted_labels = [], []

# Perform inference on the test data
for batch in tqdm(test_loader, total=len(test_loader)):
    x, y = batch
    true_labels.extend(y.tolist())  # store plain Python ints rather than 0-d tensors
    preds, _ = get_prediction(x, inference_model)
    predicted_labels.extend(preds.tolist())

100%|██████████| 157/157 [00:26<00:00,  5.92it/s]


In [14]:
# Print the classification report
print(classification_report(true_labels, predicted_labels, digits=3))

# Create a confusion matrix
conf_matrix = confusion_matrix(true_labels, predicted_labels)

# Calculate accuracy, F1-score, precision, and recall
accuracy = accuracy_score(true_labels, predicted_labels)
f1 = f1_score(true_labels, predicted_labels, average='weighted')
precision = precision_score(true_labels, predicted_labels, average='weighted')
recall = recall_score(true_labels, predicted_labels, average='weighted')

print("Confusion Matrix:")
print(conf_matrix)
print(f"Accuracy: {accuracy:.2f}")
print(f"F1-Score: {f1:.2f}")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")

              precision    recall  f1-score   support

           0      0.985     0.948     0.966       980
           1      0.991     0.991     0.991      1135
           2      0.991     0.914     0.951      1032
           3      0.984     0.985     0.985      1010
           4      0.986     0.988     0.987       982
           5      0.983     0.978     0.980       892
           6      0.994     0.973     0.983       958
           7      0.905     0.989     0.945      1028
           8      0.966     0.984     0.975       974
           9      0.962     0.986     0.974      1009

    accuracy                          0.974     10000
   macro avg      0.975     0.974     0.974     10000
weighted avg      0.975     0.974     0.974     10000

Confusion Matrix:
[[ 929    1    0    0    0    2    2   25    7   14]
 [   0 1125    5    3    0    0    0    1    0    1]
 [   0    2  943    3    3    0    0   66   15    0]
 [   0    0    3  995    0    3    0    5    3    1]
 [   0    0
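Recall for class i is the diagonal entry of row i divided by the row sum (the class support), so it can be read directly from the complete rows of the confusion matrix above. The largest off-diagonal counts in those rows sit in column 7 (25 zeros and 66 twos predicted as 7), which is consistent with the comparatively low precision of 0.905 for class 7 in the report. The sketch below recomputes recall for classes 0 to 3 and also shows why the weighted-average recall always equals the accuracy (both 0.974 in the report): weighting each class's recall by its support cancels the denominator, leaving the diagonal sum over the total count. The 3×3 matrix in the second part is a made-up example, and the helper names are illustrative.

```python
# Complete rows of the confusion matrix printed above (true digits 0-3).
# Row i, column j = number of test images of digit i predicted as digit j.
rows = {
    0: [929, 1, 0, 0, 0, 2, 2, 25, 7, 14],
    1: [0, 1125, 5, 3, 0, 0, 0, 1, 0, 1],
    2: [0, 2, 943, 3, 3, 0, 0, 66, 15, 0],
    3: [0, 0, 3, 995, 0, 3, 0, 5, 3, 1],
}


def recall_from_row(true_class: int, row: list) -> float:
    """Recall = TP / (TP + FN) = diagonal entry / row sum (support)."""
    return row[true_class] / sum(row)


recalls = {c: round(recall_from_row(c, r), 3) for c, r in rows.items()}
print(recalls)  # {0: 0.948, 1: 0.991, 2: 0.914, 3: 0.985}


def accuracy(cm):
    """Fraction of all samples on the diagonal."""
    n = sum(sum(r) for r in cm)
    return sum(cm[i][i] for i in range(len(cm))) / n


def weighted_recall(cm):
    """Support-weighted mean of per-class recall (what average='weighted' computes for recall)."""
    n = sum(sum(r) for r in cm)
    return sum(recall_from_row(i, r) * sum(r) for i, r in enumerate(cm)) / n


toy_cm = [[5, 1, 0], [2, 7, 1], [0, 0, 4]]  # made-up 3-class example
print(accuracy(toy_cm), round(weighted_recall(toy_cm), 12))  # 0.8 0.8
```

The recomputed values agree with the `recall` column of the classification report for digits 0 to 3.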

Judging by the confusion matrix, the program is successful at classifying the MNIST test data: in each row, the diagonal entry (correct predictions) is far larger than the off-diagonal entries, which is consistent with the 97.4% test accuracy in the classification report. The matrix also helps us visualize where ResNet18 gets confused. For instance