In [71]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision import models

from sklearn.metrics import confusion_matrix, classification_report, accuracy_score

In [65]:
batch_size = 4
num_epochs = 2

In [66]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [67]:
# We are using the CIFAR-10 dataset

# Use the same transforms and data augmentation as the ResNet18 model during training
transform = models.ResNet18_Weights.IMAGENET1K_V1.transforms()

# Load built-in CIFAR-10 dataset (torchvision)
train_dataset = torchvision.datasets.CIFAR10(root="./CIFAR10_data", train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root="./CIFAR10_data", train=False, download=True, transform=transform)

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Classes from CIFAR-10 dataset
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [68]:
# Load pre-trained ResNet18 model
model_ft = models.resnet18(weights='IMAGENET1K_V1')

# Change the last fully connected layer to output 10 classes
num_ftrs = model_ft.fc.in_features # Number of input features for the last layer
model_ft.fc = nn.Linear(num_ftrs, len(classes))

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# All parameters are being optimized (fine-tuning!)
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

In [69]:
# Training loop
n_total_steps = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # origin shape: [4, 3, 32, 32] = 4, 3, 1024 meaning 4 images, 3 channels (RGB), 32x32 pixels
        # input_layer: 3 input channels (RGB), 6 output channels, 5 kernel size
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model_ft(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer_ft.zero_grad()
        loss.backward()
        optimizer_ft.step()

        if (i+1) % 2000 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item()}")

print("Finished training")

Epoch [1/2], Step [2000/12500], Loss: 1.3118975162506104
Epoch [1/2], Step [4000/12500], Loss: 1.064820408821106
Epoch [1/2], Step [6000/12500], Loss: 1.7125447988510132
Epoch [1/2], Step [8000/12500], Loss: 2.1971778869628906
Epoch [1/2], Step [10000/12500], Loss: 0.5122009515762329
Epoch [1/2], Step [12000/12500], Loss: 0.32774579524993896
Epoch [2/2], Step [2000/12500], Loss: 0.5738235712051392
Epoch [2/2], Step [4000/12500], Loss: 0.2700616121292114
Epoch [2/2], Step [6000/12500], Loss: 0.01647564023733139
Epoch [2/2], Step [8000/12500], Loss: 0.21853423118591309
Epoch [2/2], Step [10000/12500], Loss: 0.2566913068294525
Epoch [2/2], Step [12000/12500], Loss: 0.11422920227050781
Finished training


In [72]:
# Test the model
y_true = []
y_pred = []
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model_ft(images)
        _, predicted = torch.max(outputs, 1)
        y_true += labels.tolist()
        y_pred += predicted.tolist()

print(f"Accuracy: {accuracy_score(y_true, y_pred)}")
print(f"Classification Report: {classification_report(y_true, y_pred)}")
print(f"Confusion Matrix: {confusion_matrix(y_true, y_pred)}")
for i in range(10):
    print(f"Class {i} ({classes[i]}) accuracy: {accuracy_score(np.array(y_true)[np.array(y_true) == i], np.array(y_pred)[np.array(y_true) == i])}")

# Save the model checkpoint
torch.save(model_ft.state_dict(), "model.ckpt")

Accuracy: 0.8498
Classification Report:               precision    recall  f1-score   support

           0       0.85      0.89      0.87      1000
           1       0.96      0.90      0.93      1000
           2       0.81      0.82      0.82      1000
           3       0.73      0.77      0.75      1000
           4       0.96      0.76      0.85      1000
           5       0.87      0.71      0.78      1000
           6       0.72      0.94      0.81      1000
           7       0.90      0.86      0.88      1000
           8       0.88      0.92      0.90      1000
           9       0.90      0.92      0.91      1000

    accuracy                           0.85     10000
   macro avg       0.86      0.85      0.85     10000
weighted avg       0.86      0.85      0.85     10000

Confusion Matrix: [[886   6  22  15   1   1   7   4  43  15]
 [  7 903   4   2   0   0   7   1  18  58]
 [ 35   2 822  47   9   8  63   7   6   1]
 [ 17   1  38 767   4  54  89  16  10   4]
 [ 20   0  