In [1]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision import datasets
from sklearn.metrics import classification_report, confusion_matrix

In [2]:
model = models.resnet18(pretrained=True)


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 398MB/s]


In [3]:
num_classes = 10  # MNIST has 10 classes
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=num_classes)


In [9]:
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # Convert to 3-channel grayscale
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])



In [10]:
mnist_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)


In [11]:
data_loader = torch.utils.data.DataLoader(mnist_dataset, batch_size=64, shuffle=False)


In [12]:
model.eval()
all_predictions = []
all_targets = []

with torch.no_grad():
    for data, target in data_loader:
        output = model(data)
        _, predicted = torch.max(output, 1)
        all_predictions.extend(predicted.numpy())
        all_targets.extend(target.numpy())



In [13]:
# Calculate the confusion matrix and classification report
confusion = confusion_matrix(all_targets, all_predictions)
classification_rep = classification_report(all_targets, all_predictions)

# Print the results
print("Confusion Matrix:")
print(confusion)
print("\nClassification Report:")
print(classification_rep)

Confusion Matrix:
[[   0  963    0    0    0    6    0    2    4    5]
 [   0 1015    0    0    0   25    0    1    1   93]
 [   0 1017    0    0    0    2    0    2    0   11]
 [   0  914    0    0    0   15    1    1    1   78]
 [   0  945    0    0    0    1    0    3    0   33]
 [   0  856    0    0    0    3    2    0    0   31]
 [   0  940    0    0    0    8    0    3    4    3]
 [   0  977    0    0    0    0    0    0    0   51]
 [   0  849    0    0    0    1    1    1    2  120]
 [   0  905    0    0    0    1    0    0    0  103]]

Classification Report:
              precision    recall  f1-score   support

           0       0.00      0.00      0.00       980
           1       0.11      0.89      0.19      1135
           2       0.00      0.00      0.00      1032
           3       0.00      0.00      0.00      1010
           4       0.00      0.00      0.00       982
           5       0.05      0.00      0.01       892
           6       0.00      0.00      0.00     

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
