In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

# EfficientNet-B0 initialised from torchvision's IMAGENET1K_V1 weights.
efficientnet = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)

# Freeze every parameter of the feature extractor, then switch gradients back
# on for its last two entries only.
efficientnet.features.requires_grad_(False)
efficientnet.features[-2:].requires_grad_(True)

# One output logit per class listed by the training dataset.
num_classes = len(train_dataset.classes)

# Replace classifier: Linear -> BatchNorm -> ReLU -> Dropout -> Linear.
# The input width is read from the layer at index 1 of the stock head before
# that head is overwritten.
head_in_features = efficientnet.classifier[1].in_features
efficientnet.classifier = nn.Sequential(
    nn.Linear(head_in_features, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(512, num_classes),
)

# Move the assembled model (backbone + new head) onto the selected device.
efficientnet = efficientnet.to(DEVICE)


In [None]:
# Cross-entropy loss with label smoothing of 0.1.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# AdamW over only those parameters that currently have requires_grad=True;
# parameters frozen at this point are not handed to the optimizer at all.
optimizer = optim.AdamW(
    [p for p in efficientnet.parameters() if p.requires_grad],
    lr=3e-4,
    weight_decay=1e-4,
)

In [None]:
EPOCHS = 25
best_val_acc = 0.0  # highest validation accuracy seen so far in this loop

for epoch in range(EPOCHS):
    # -------- Training --------
    # train() enables training-mode behaviour (dropout, batch-norm updates).
    efficientnet.train()
    train_loss, train_correct, train_total = 0.0, 0, 0

    for images, labels in train_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        # Standard step: clear old grads, forward, loss, backward, update.
        optimizer.zero_grad()
        outputs = efficientnet(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Predicted class is the index of the largest logit in each row.
        preds = torch.max(outputs, dim=1).indices

        # The batch loss is scaled by the batch size before accumulating, so
        # dividing by the sample count below yields a per-sample average
        # (assuming the criterion returns a batch mean).
        train_total += labels.size(0)
        train_loss += loss.item() * labels.size(0)
        train_correct += preds.eq(labels).sum().item()

    train_loss = train_loss / train_total
    train_acc = train_correct / train_total

    # -------- Validation --------
    # eval() switches to inference behaviour; no_grad() skips autograd tracking.
    efficientnet.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            outputs = efficientnet(images)
            loss = criterion(outputs, labels)

            preds = torch.max(outputs, dim=1).indices
            val_total += labels.size(0)
            val_loss += loss.item() * labels.size(0)
            val_correct += preds.eq(labels).sum().item()

    val_loss = val_loss / val_total
    val_acc = val_correct / val_total

    print(
        f"Epoch [{epoch+1}/{EPOCHS}] "
        f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} || "
        f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}"
    )

    # -------- Save Best Model --------
    # Only a strict improvement in validation accuracy overwrites the file.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(efficientnet.state_dict(), "efficientnet_best.pth")

In [None]:
# Restore the weights stored in "efficientnet_best.pth", mapping the tensors
# onto DEVICE as they are read.
# NOTE(review): consider torch.load(..., weights_only=True) if the installed
# PyTorch version supports it - confirm before changing.
efficientnet.load_state_dict(torch.load("efficientnet_best.pth", map_location=DEVICE))

print("Loaded best Phase-1 EfficientNet weights")

In [None]:
# Phase 2: re-enable gradients for every parameter of both the feature
# extractor and the classifier head.
efficientnet.features.requires_grad_(True)
efficientnet.classifier.requires_grad_(True)

print("Phase 2: All EfficientNet layers unfrozen")

In [None]:
# Label-smoothed cross-entropy (smoothing 0.1) for the Phase-2 loop.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# A new AdamW instance over *all* model parameters, replacing the previous
# `optimizer` binding; learning rate is 3e-5 here.
optimizer = optim.AdamW(efficientnet.parameters(), lr=3e-5, weight_decay=1e-4, eps=1e-8)

print("Phase 2 optimizer ready")

In [None]:
PHASE2_EPOCHS = 10
# NOTE(review): tracking restarts from 0.0, so Phase-2 checkpoints are only
# compared against each other, not against earlier training - confirm intended.
best_val_acc_phase2 = 0.0

for epoch in range(PHASE2_EPOCHS):
    # -------- Training --------
    # train() enables training-mode behaviour (dropout, batch-norm updates).
    efficientnet.train()
    train_loss, train_correct, train_total = 0.0, 0, 0

    for images, labels in train_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        # Standard step: clear old grads, forward, loss, backward, update.
        optimizer.zero_grad()
        outputs = efficientnet(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Predicted class is the index of the largest logit in each row.
        preds = torch.max(outputs, dim=1).indices

        # Scale the batch loss by the batch size so the division below gives a
        # per-sample average (assuming the criterion returns a batch mean).
        train_total += labels.size(0)
        train_loss += loss.item() * labels.size(0)
        train_correct += preds.eq(labels).sum().item()

    train_loss = train_loss / train_total
    train_acc = train_correct / train_total

    # -------- Validation --------
    # eval() switches to inference behaviour; no_grad() skips autograd tracking.
    efficientnet.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            outputs = efficientnet(images)
            loss = criterion(outputs, labels)

            preds = torch.max(outputs, dim=1).indices
            val_total += labels.size(0)
            val_loss += loss.item() * labels.size(0)
            val_correct += preds.eq(labels).sum().item()

    val_loss = val_loss / val_total
    val_acc = val_correct / val_total

    print(
        f"[Phase 2] Epoch [{epoch+1}/{PHASE2_EPOCHS}] | "
        f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} || "
        f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}"
    )

    # Only a strict improvement in validation accuracy overwrites the file.
    if val_acc > best_val_acc_phase2:
        best_val_acc_phase2 = val_acc
        torch.save(efficientnet.state_dict(), "efficientnet_phase2_best.pth")

In [None]:
# Restore the weights stored in "efficientnet_phase2_best.pth", mapping the
# tensors onto DEVICE as they are read.
# NOTE(review): consider torch.load(..., weights_only=True) if the installed
# PyTorch version supports it - confirm before changing.
efficientnet.load_state_dict(torch.load("efficientnet_phase2_best.pth", map_location=DEVICE))

print("Loaded Phase-2 best model")

In [None]:
from torchvision import transforms

# Random augmentation pipeline: horizontal flip with probability 0.5 followed
# by a random rotation of up to 10 degrees.
# NOTE(review): nothing in the cells visible here applies `tta_transforms`;
# confirm whether it is used further down or is leftover.
tta_transforms = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
    ]
)

In [None]:
# Inference only: eval mode, and no autograd tracking inside the loop.
efficientnet.eval()

test_correct, test_total = 0, 0

# Per-sample predictions and ground-truth labels, collected across batches.
all_preds_eff, all_labels_eff = [], []

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        outputs = efficientnet(images)
        # Predicted class is the index of the largest logit in each row.
        preds = torch.argmax(outputs, dim=1)

        test_total += labels.size(0)
        test_correct += preds.eq(labels).sum().item()

        # Bring results back to the CPU before storing them in Python lists.
        all_preds_eff.extend(preds.cpu().numpy())
        all_labels_eff.extend(labels.cpu().numpy())

# Fraction of test samples classified correctly; kept under two names.
efficientnet_test_acc = test_acc = test_correct / test_total

print(f"EfficientNet Test Accuracy: {test_acc:.4f}")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Class names in index order, used for the axis tick labels.
class_names = train_dataset.classes

# Pin the label set to every class index so the matrix is always
# len(class_names) x len(class_names). Without `labels`, scikit-learn derives
# the label set from the values that actually occur, so a class missing from
# both the test labels and the predictions would shrink the matrix and shift
# the tick labels onto the wrong rows/columns.
# Assumes targets are integer indices 0..len(class_names)-1 in `class_names`
# order - the same assumption the tick labels already rely on.
cm = confusion_matrix(
    all_labels_eff,
    all_preds_eff,
    labels=list(range(len(class_names))),
)

plt.figure(figsize=(12, 10))
sns.heatmap(
    cm,
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
    annot=False
)

# Rows of `cm` are true labels and columns are predictions.
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("EfficientNetB0 — Confusion Matrix (Test Set)")
plt.tight_layout()
plt.show()

In [None]:
from sklearn.metrics import classification_report

# Class names in index order, used as the row names of the report.
class_names = train_dataset.classes

# Pass the full label set explicitly. With `target_names` alone, scikit-learn
# infers the labels from the observed values and raises a ValueError when
# their count differs from len(class_names) (e.g. a class that never appears
# in the test labels or predictions).
# Assumes targets are integer indices 0..len(class_names)-1 in `class_names`
# order.
report = classification_report(
    all_labels_eff,
    all_preds_eff,
    labels=list(range(len(class_names))),
    target_names=class_names,
    digits=4
)

print("EfficientNetB0 — Classification Report (Test Set)")
print(report)