<a href="https://colab.research.google.com/github/GeorgiyAleksanyan/14-332-472-01-ROBOTICS-COMP-VISION-Classify-MNIST-classes-with-ResNet18/blob/main/ga360_RCV_14_332_472_Project_1_Question_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Classify MNIST classes with ResNet18** Fine-tune the ResNet 18 network to classify the MNIST dataset. Report the confusion matrix, the accuracy, the f-score, precision and recall of your classifier. Write a paragraph describing your results and methods.

In [193]:
!pip install torch~=2.1.0 torchvision pytorch-lightning



In [194]:
!pip install validators matplotlib



In [195]:
import torch
torch.__version__

'2.1.0+cu118'

In [196]:
!pip install -U 'torch_xla>=1.13'
from torchvision.models import resnet18
from torch import nn
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm
import pytorch_lightning as pl
from sklearn.metrics import classification_report
import torch
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import json
import requests
import matplotlib.pyplot as plt
import warnings
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
from pytorch_lightning.callbacks import ModelCheckpoint



In [197]:
# Plain torchvision ResNet-18 with a 10-way output layer. This binding of
# `model` is later replaced by `model = ResNetMNIST()`.
model = resnet18(num_classes=10)

In [198]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [199]:
# Replace the stem convolution with one that takes 1 input channel instead of
# the 3 shown in the printed architecture; kernel, stride, padding and bias
# settings are unchanged.
model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

In [200]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

In [201]:
# MNIST train and test splits stored under ./mnist (download requested), with
# every image passed through ToTensor. No other transform is applied.
train_ds = MNIST("mnist", train=True, download=True, transform=ToTensor())
test_ds = MNIST("mnist", train=False, download=True, transform=ToTensor())

In [202]:
# Batches of 64 from one worker process each. Only the training loader asks
# for shuffling; the test loader keeps the DataLoader default ordering.
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers = 1)
test_dl = DataLoader(test_ds, batch_size=64, num_workers = 1)

In [203]:
class ResNetMNIST(pl.LightningModule):
  """ResNet-18 classifier for single-channel images with 10 output classes.

  The torchvision ResNet-18 is built without a `weights=` argument and its
  first convolution is replaced so that it accepts 1 input channel.

  Args:
    lr: Learning rate passed to the RMSprop optimizer. Defaults to 0.005,
      the value that used to be hard-coded in `configure_optimizers`.
  """

  def __init__(self, lr: float = 0.005):
    super().__init__()
    self.lr = lr
    self.model = resnet18(num_classes=10)
    # Same stem geometry as the stock network, but 1 input channel instead of 3.
    self.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    self.loss = nn.CrossEntropyLoss()

  def forward(self, x):
    """Return the wrapped network's raw outputs (logits) for batch `x`."""
    return self.model(x)

  def training_step(self, batch, batch_no):
    """Compute the cross-entropy loss for one `(inputs, targets)` batch.

    `batch_no` is accepted for the Lightning hook signature and is not used.
    """
    x, y = batch
    logits = self(x)
    return self.loss(logits, y)

  def configure_optimizers(self):
    """Return an RMSprop optimizer over all parameters using `self.lr`."""
    return torch.optim.RMSprop(self.parameters(), lr=self.lr)

In [204]:
# Rebind `model` to the Lightning module; the bare ResNet-18 created earlier
# under the same name is no longer referenced by it.
model = ResNetMNIST()

In [205]:
# Lightning trainer configured for a single epoch on the CPU.
trainer = pl.Trainer(
    accelerator ="cpu",  # the log printed by this cell reports no GPU/TPU in use
    max_epochs=1,  # training stops after one pass (see the `max_epochs=1 reached` log)
    # Options that were tried and are currently disabled:
    #   tpu_cores=8
    #   callbacks=[ModelCheckpoint(save_top_k=1, monitor="val_loss", mode="min")]
    #   progress_bar_refresh_rate=20
    # NOTE(review): ResNetMNIST defines no validation step, so presumably
    # nothing logs "val_loss" for the checkpoint callback to monitor — confirm
    # before re-enabling it.
)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [206]:
# Fit on the training loader only; no validation loader is supplied.
trainer.fit(model, train_dl)

INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type             | Params
-------------------------------------------
0 | model | ResNet           | 11.2 M
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.701    Total estimated model params size (MB)
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.


In [207]:
# Write the trainer's checkpoint to the working directory. This is the file
# that `ResNetMNIST.load_from_checkpoint` reads back for inference.
trainer.save_checkpoint("resnet18_mnist.pt")

In [208]:
def get_prediction(x, model: pl.LightningModule):
  """Classify a batch with `model`.

  Calls `model.freeze()` before the forward pass, applies a softmax over
  dimension 1 of the model output, and takes the arg-max over that dimension.

  Returns:
    A `(predicted_class, probabilities)` tuple: the arg-max indices and the
    softmax output they were taken from.
  """
  model.freeze()  # prepares model for predicting
  logits = model(x)
  probabilities = logits.softmax(dim=1)
  predicted_class = probabilities.argmax(dim=1)
  return predicted_class, probabilities

In [209]:
# Rebuild the module from the saved checkpoint, mapping its tensors to the CPU.
inference_model = ResNetMNIST.load_from_checkpoint("resnet18_mnist.pt", map_location="cpu")

In [215]:
from sklearn.metrics import classification_report

# Single pass over the test loader: collect ground-truth labels and predicted
# classes as plain Python ints. (The earlier version first walked the whole
# loader just to fill `true_y` and then threw that list away.)
true_y, pred_y = [], []
for x, y in tqdm(test_dl, total=len(test_dl)):
  preds, probs = get_prediction(x, inference_model)
  true_y.extend(y.tolist())
  pred_y.extend(preds.cpu().tolist())

# Per-class precision / recall / F1; both lists have equal length here.
print(classification_report(true_y, pred_y, digits=3))

# NOTE(review): these two statements re-run the model on the LAST batch and
# append its predictions a second time, so `pred_y` ends up longer than
# `true_y` (by 16 entries for 10000 test samples at batch_size=64). They are
# kept only because the metrics cell that follows cuts `pred_y` back before
# scoring; delete them together with that trimming step.
preds, probs = get_prediction(x, inference_model)
pred_y.extend(preds.cpu().tolist())

  0%|          | 0/157 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0      0.865     0.987     0.922       980
           1      0.998     0.928     0.962      1135
           2      0.874     0.545     0.671      1032
           3      0.962     0.824     0.887      1010
           4      0.996     0.792     0.883       982
           5      0.985     0.904     0.943       892
           6      0.851     0.984     0.913       958
           7      0.986     0.877     0.928      1028
           8      0.563     0.998     0.720       974
           9      0.959     0.922     0.940      1009

    accuracy                          0.875     10000
   macro avg      0.904     0.876     0.877     10000
weighted avg      0.906     0.875     0.877     10000



In [217]:
# `pred_y` can carry surplus trailing entries (the evaluation cell appends the
# final batch's predictions twice). Cut it to the number of ground-truth
# labels instead of subtracting a hard-coded 16, which is only right for one
# particular dataset size / batch size combination. This is a no-op when the
# two lists already have the same length.
pred_y = pred_y[:len(true_y)]

# Rows are true classes, columns are predicted classes.
conf_matrix = confusion_matrix(true_y, pred_y)

# Overall fraction of correctly classified samples.
accuracy = accuracy_score(true_y, pred_y)

# Per-class scores combined with average='weighted'.
f1 = f1_score(true_y, pred_y, average='weighted')
precision = precision_score(true_y, pred_y, average='weighted')
recall = recall_score(true_y, pred_y, average='weighted')

print("Confusion Matrix:")
print(conf_matrix)
print(f"Accuracy: {accuracy:.2f}")
print(f"F1-Score: {f1:.2f}")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")

Confusion Matrix:
[[ 967    0    0    0    0    0    9    0    4    0]
 [   0 1053   24    0    0    0    9    4   45    0]
 [ 119    0  562   32    0    0   35    0  284    0]
 [   2    0    6  832    0    9    0    0  160    1]
 [   2    0    1    0  778    0   98    2   77   24]
 [   7    0    0    0    0  806   12    1   66    0]
 [   4    1    0    0    0    0  943    0   10    0]
 [  10    1   50    0    3    1    0  902   46   15]
 [   0    0    0    0    0    0    2    0  972    0]
 [   7    0    0    1    0    2    0    6   63  930]]
Accuracy: 0.87
F1-Score: 0.88
Precision: 0.91
Recall: 0.87


Judging by the confusion matrix, the program is successful at classifying the data from MNIST. The diagonal values are larger than the others, which indicates a high rate of accuracy. This matrix also helps us visualise where ResNet18 gets confused. For instance