# 04 - Model Evaluation

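Evaluate the best checkpoint on the held-out test split: overall accuracy, confusion matrix, per-class metrics, and the most common misclassifications.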

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from pathlib import Path
from sklearn.metrics import confusion_matrix, classification_report

# Choose the compute backend in order of preference:
# Apple-silicon MPS, then CUDA, and fall back to the CPU.
device = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(f"Device: {device}")

In [None]:
# ImageNet-standard evaluation preprocessing: resize the shorter side to
# IMG_DEFAULT_SIZE, centre-crop to IMG_CROP_SIZE, and normalise each channel
# with the ImageNet mean/std.
IMG_DEFAULT_SIZE, IMG_CROP_SIZE = 256, 224
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Evaluation settings.
BATCH_SIZE = 32
NUM_CLASSES = 7  # size of the classifier head; must match the saved checkpoint
SEED = 24        # presumably the seed used for the training split — NOTE(review): confirm

# Filesystem locations, relative to the notebook's working directory.
DATA_PATH = Path("../data/raw/soil-classification/Orignal-Dataset")
OUTPUTS_PATH = Path("../outputs")
CHECKPOINT_PATH = OUTPUTS_PATH / "checkpoints"
# parents=True also creates OUTPUTS_PATH, where figures and results are written.
CHECKPOINT_PATH.mkdir(parents=True, exist_ok=True)

In [None]:
# Deterministic evaluation pipeline (no augmentation): resize the shorter
# side, centre-crop, convert the PIL image to a tensor, then normalise with
# the ImageNet channel statistics.
_eval_steps = [
    transforms.Resize(IMG_DEFAULT_SIZE),
    transforms.CenterCrop(IMG_CROP_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
]
test_transform = transforms.Compose(_eval_steps)

In [None]:
# Load every image once with the evaluation transform; only the test part of
# the split is kept below.
full_dataset = ImageFolder(root=DATA_PATH, transform=test_transform)

# The checkpoint's classifier head has NUM_CLASSES outputs. A dataset with a
# different number of class folders would make every downstream metric (and
# the per-class arrays sized by NUM_CLASSES) wrong, so fail loudly here.
if len(full_dataset.classes) != NUM_CLASSES:
    raise ValueError(
        f"Expected {NUM_CLASSES} classes but found {len(full_dataset.classes)} "
        f"in {DATA_PATH}: {full_dataset.classes}"
    )

# Recreate a 70 / 15 / remainder split. random_split only reproduces the
# training-time test set if the dataset contents, file ordering, split sizes
# and SEED all match the training run — NOTE(review): confirm against the
# training notebook.
train_size = int(0.7 * len(full_dataset))
val_size = int(0.15 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

_, _, test_dataset = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SEED)
)

# No shuffling: keeps prediction order stable across runs.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Test set: {len(test_dataset)} images")
print(f"Classes: {full_dataset.classes}")

In [None]:
# Recreate the trained architecture: EfficientNet-B0 (randomly initialised,
# since the weights are overwritten from the checkpoint) with its final
# Linear layer replaced by a NUM_CLASSES-way head, so that the saved
# state_dict shapes line up.
model = models.efficientnet_b0()
in_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(in_features, NUM_CLASSES)

# NOTE: torch.load unpickles the file — only load checkpoints you produced.
checkpoint = torch.load(CHECKPOINT_PATH / 'best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)  # nn.Module.to moves parameters in place and returns self
model.eval()      # inference mode: disables dropout, uses BatchNorm running stats

best_epoch = checkpoint['epoch']
best_val_acc = checkpoint['val_acc']
print(f"Loaded model from epoch {best_epoch}")
print(f"Val acc was: {best_val_acc:.2f}%")

In [None]:
# Run the whole test split through the model once, collecting predicted and
# true class indices for the metric cells that follow.
all_preds = []
all_labels = []
correct = 0
total = 0

with torch.no_grad():  # inference only: skip autograd bookkeeping
    for batch_images, batch_labels in test_loader:
        logits = model(batch_images.to(device))
        batch_labels = batch_labels.to(device)
        # Index of the highest logit per sample (first index on ties).
        batch_preds = logits.argmax(dim=1)

        all_preds.extend(batch_preds.cpu().numpy())
        all_labels.extend(batch_labels.cpu().numpy())

        correct += (batch_preds == batch_labels).sum().item()
        total += len(batch_labels)

test_acc = 100 * correct / total
print(f"\nTest Accuracy: {test_acc:.2f}%")
print(f"Correct: {correct}/{total}")

In [None]:
class_names = full_dataset.classes

# Pin the label set explicitly. Otherwise sklearn infers the labels from the
# data, and a class that never appears in either the labels or the predictions
# would be dropped, leaving a smaller matrix whose rows/columns no longer line
# up with the tick labels below.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix')
plt.ylabel('True')
plt.xlabel('Predicted')
plt.tight_layout()
plt.savefig(OUTPUTS_PATH / 'confusion_matrix.png', dpi=150)
plt.show()

In [None]:
# Use a fixed label set that matches target_names. Without `labels`, sklearn
# raises ValueError when a class is missing from both y_true and y_pred,
# because the inferred label count no longer matches len(target_names).
# zero_division=0 reports 0 (without a warning) for classes that are never
# predicted.
report = classification_report(
    all_labels,
    all_preds,
    labels=list(range(len(class_names))),
    target_names=class_names,
    zero_division=0,
)
print("Classification Report:")
print(report)

In [None]:
# For each true class, count its test samples and how many of those the model
# classified correctly.
class_correct = [0] * NUM_CLASSES
class_total = [0] * NUM_CLASSES

for true_idx, pred_idx in zip(all_labels, all_preds):
    class_total[true_idx] += 1
    class_correct[true_idx] += int(true_idx == pred_idx)

print("\nPer-class accuracy:")
for idx, cls_name in enumerate(class_names):
    n_seen = class_total[idx]
    n_right = class_correct[idx]
    # Classes absent from the test split report 0 instead of dividing by zero.
    if n_seen > 0:
        acc = 100 * n_right / n_seen
    else:
        acc = 0
    print(f"{cls_name}: {acc:.2f}% ({n_right}/{n_seen})")

In [None]:
# Per-class accuracy in percent; a class with no test samples reports 0.
accuracies = [
    100 * hits / count if count > 0 else 0
    for hits, count in zip(class_correct, class_total)
]

plt.figure(figsize=(10, 6))
plt.bar(class_names, accuracies)
# Dashed reference line at the overall test accuracy.
overall_label = f'Overall: {test_acc:.2f}%'
plt.axhline(y=test_acc, color='red', linestyle='--', label=overall_label)
plt.ylabel('Accuracy (%)')
plt.title('Per-class Accuracy')
plt.xticks(rotation=45, ha='right')  # long class names would otherwise overlap
plt.legend()
plt.tight_layout()
plt.savefig(OUTPUTS_PATH / 'per_class_accuracy.png', dpi=150)
plt.show()

In [None]:
from collections import Counter

# (true class name, predicted class name) for every misclassified test sample.
errors = [
    (class_names[true_label], class_names[pred_label])
    for true_label, pred_label in zip(all_labels, all_preds)
    if true_label != pred_label
]

error_counts = Counter(errors)

print(f"\nTotal errors: {len(errors)}")
print("Most common mistakes:")
for (true_class, pred_class), count in error_counts.most_common(5):
    # Plain ASCII arrow: the previous "→" had been mis-encoded to "â†’",
    # and ASCII renders correctly under any console/file encoding.
    print(f"  {true_class} -> {pred_class}: {count}x")

In [None]:
# Indices of the best/worst classes by per-class accuracy (first one on ties).
best_idx = int(np.argmax(accuracies))
worst_idx = int(np.argmin(accuracies))

print("\nSummary:")
print(f"Test acc: {test_acc:.2f}%")
print(f"Val acc: {checkpoint['val_acc']:.2f}%")
# Per-class accuracy on small classes is noisy, so show each class's test
# support (n) next to it rather than relying on a hard-coded note.
print(f"Best class: {class_names[best_idx]} "
      f"({accuracies[best_idx]:.2f}%, n={class_total[best_idx]})")
print(f"Worst class: {class_names[worst_idx]} "
      f"({accuracies[worst_idx]:.2f}%, n={class_total[worst_idx]})")
print(f"Errors: {len(errors)}/{len(test_dataset)}")

In [None]:
import json

# Cast numbers to built-in floats before serialising. json cannot encode
# tensors or numpy float32, and the exact type of checkpoint['val_acc']
# depends on how the training notebook stored it.
results = {
    'test_accuracy': float(test_acc),
    'val_accuracy': float(checkpoint['val_acc']),
    'num_test_samples': len(test_dataset),
    'per_class_accuracy': {
        name: float(acc) for name, acc in zip(class_names, accuracies)
    },
    'confusion_matrix': cm.tolist(),
}

with open(OUTPUTS_PATH / 'test_results.json', 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2)

print("Saved")