In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

# ResNet-50 backbone initialised with torchvision's IMAGENET1K_V2 weights.
resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

# Freeze the whole network, then re-enable gradients for the last residual
# stage only, so fine-tuning touches layer4 plus the new head below.
resnet50.requires_grad_(False)
resnet50.layer4.requires_grad_(True)

num_classes = len(train_dataset.classes)
backbone_features = resnet50.fc.in_features

# Replacement classification head. The layers are kept positional inside
# nn.Sequential so their state_dict keys are fc.0.* / fc.3.*; a checkpoint
# saved from this architecture relies on those exact key names.
resnet50.fc = nn.Sequential(
    nn.Linear(backbone_features, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, num_classes),
)
# Freshly constructed layers already require grad; this makes it explicit.
resnet50.fc.requires_grad_(True)

resnet50 = resnet50.to(DEVICE)


In [None]:
# LOAD SAVED RESNET WEIGHTS (NO TRAINING)
# Restores a previously saved state_dict so the evaluation cells can run
# without retraining. If the training cell below is also run, it continues
# fine-tuning from these weights.
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while loading
# (a saved state_dict needs nothing more).
resnet50.load_state_dict(
    torch.load("resnet50_baseline_60.pth", map_location=DEVICE, weights_only=True)
)
resnet50.eval()
print("ResNet50 weights loaded")

In [None]:
# Cross-entropy with label smoothing 0.1 on the targets.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Hand Adam only the parameters that still require grad (the setup cell
# leaves layer4 and the fc head trainable); frozen weights are excluded.
trainable_params = [p for p in resnet50.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-4)

In [None]:
EPOCHS = 18
CHECKPOINT_PATH = "resnet50_baseline_60.pth"


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    If ``optimizer`` is given, the model is put in train mode and its
    trainable parameters are updated after every batch. Otherwise the model
    is put in eval mode and run with gradient tracking disabled.

    The loss is averaged per sample (each batch loss is weighted by its batch
    size), so a smaller final batch is not over-weighted.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    # NOTE(review): train mode also puts BatchNorm layers in the frozen
    # backbone stages into training mode, so their running statistics keep
    # updating; switch those submodules to eval() if that is not intended.
    model.train(training)

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += batch_size

    if n_seen == 0:
        raise ValueError("loader yielded no samples; cannot compute loss/accuracy")
    return loss_sum / n_seen, n_correct / n_seen


def _snapshot(model):
    """Return a CPU copy of ``model``'s state_dict that later updates won't touch."""
    return {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}


# Score the weights we start from (e.g. a checkpoint loaded by an earlier
# cell) so that an epoch worse than the starting point never overwrites it.
_, best_val_acc = _run_epoch(resnet50, val_loader, criterion, DEVICE)
best_state = _snapshot(resnet50)
print(f"Starting Val Acc: {best_val_acc:.4f}")

for epoch in range(EPOCHS):
    train_loss, train_acc = _run_epoch(
        resnet50, train_loader, criterion, DEVICE, optimizer=optimizer
    )
    val_loss, val_acc = _run_epoch(resnet50, val_loader, criterion, DEVICE)

    print(
        f"Epoch [{epoch+1}/{EPOCHS}] "
        f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} || "
        f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}"
    )

    # -------- Save Best Model --------
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_state = _snapshot(resnet50)
        torch.save(resnet50.state_dict(), CHECKPOINT_PATH)

# Put the best-validation weights back into the model so the test-set cells
# evaluate the selected model, not whatever the final epoch produced.
resnet50.load_state_dict(best_state)
resnet50.eval()
print(f"Restored best weights (Val Acc: {best_val_acc:.4f})")

In [None]:
import torch
import numpy as np

# Evaluate on the held-out test set. Per-sample predictions and labels are
# kept (as numpy scalars) for the classification report and confusion matrix.
resnet50.eval()

resnet_all_preds = []
resnet_all_labels = []
test_correct, test_total = 0, 0

with torch.no_grad():
    for batch_images, batch_labels in test_loader:
        batch_images = batch_images.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)

        batch_preds = resnet50(batch_images).argmax(dim=1)

        test_total += batch_labels.size(0)
        test_correct += (batch_preds == batch_labels).sum().item()

        resnet_all_labels.extend(batch_labels.cpu().numpy())
        resnet_all_preds.extend(batch_preds.cpu().numpy())

test_acc = test_correct / test_total
print(f"ResNet50 Test Accuracy: {test_acc:.4f}")

resnet_test_acc = test_acc

In [None]:
from sklearn.metrics import classification_report

# Class names in label-index order. Assumes label i corresponds to
# train_dataset.classes[i] (true for torchvision ImageFolder) — confirm if a
# custom dataset is used.
class_names = train_dataset.classes

# Pass every class index explicitly: otherwise sklearn infers the label set
# from the data, and if some class appears in neither the test labels nor
# the predictions, the count no longer matches target_names and it raises
# ValueError. zero_division=0 reports such classes as 0 without warnings
# (same values as the default "warn").
print(classification_report(
    resnet_all_labels,
    resnet_all_preds,
    labels=list(range(len(class_names))),
    target_names=class_names,
    digits=4,
    zero_division=0,
))

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# Fix the label set and order to every class index so the matrix is always
# len(class_names) x len(class_names) and its rows/columns line up with the
# tick labels, even if a class is missing from both labels and predictions.
cm = confusion_matrix(
    resnet_all_labels,
    resnet_all_preds,
    labels=list(range(len(class_names))),
)

fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(
    cm,
    annot=False,
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)

ax.set_xlabel("Predicted")
ax.set_ylabel("True")
ax.set_title("ResNet50 — Confusion Matrix (Test Set)")
fig.tight_layout()
plt.show()