# Focused Error Analysis on Validation Set (Artportalen, AgeModel)

This notebook loads your AgeModel and Artportalen validation set, runs inference, builds a confusion matrix, computes per-class metrics, and helps you review high-confidence mistakes.

In [1]:
# ---- Set your paths here ----
config_path = '/Users/amee/Documents/code/master-thesis/EagleID/configs/config-hpc-artportalen.yml'
ckpt_path = '/Users/amee/Documents/code/master-thesis/EagleID/checkpoints/agemodel2.ckpt'
val_img_dir = '/Users/amee/Documents/code/master-thesis/datasets/artportalen_goeag'
val_csv = '/Users/amee/Documents/code/master-thesis/AgeClassifier/annot/final_val_sep_sightings.csv'
out_path = 'confusion_matrix-yA.png'


In [2]:
import sys
sys.path.append('..')

import torch
import yaml
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from torch.utils.data import DataLoader
from models.age_model import AgeModel
from data.artportalen_goleag import ArtportalenDataModule
import os


In [3]:
# Load config and override paths
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

config['dataset'] = val_img_dir
config['cache_path'] = val_csv
train_csv = val_csv

preprocess_lvl = config.get('preprocess_lvl', 2)
batch_size = 16
img_size = config.get('img_size', 224)
mean = config['transforms']['mean'][0] if isinstance(config['transforms']['mean'], list) else config['transforms']['mean']
std = config['transforms']['std'][0] if isinstance(config['transforms']['std'], list) else config['transforms']['std']
num_classes = config.get('num_classes', 5)
class_names = [str(i) for i in range(num_classes)]


In [4]:
# Load model
model = AgeModel(config=config, num_classes=num_classes, pretrained=False)
checkpoint = torch.load(ckpt_path, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
model.eval()


AgeModel(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (drop_block): Identity()
        (act2): ReLU(inplace=True)
        (aa): Identity()
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, mo

In [5]:
# Load validation data
data_module = ArtportalenDataModule(
    data_dir=val_img_dir,
    preprocess_lvl=preprocess_lvl,
    batch_size=batch_size,
    size=img_size,
    mean=mean,
    std=std,
    test=True
)
data_module.setup_from_csv(train_csv=train_csv, val_csv=val_csv)
val_dataset = data_module.val_dataset
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)


Train: 840 Val: 840
Unique classes in dataset: [4 5 1 2 3]
Number of classes: 5


In [None]:
# Run inference and collect results
device = torch.device('cuda' if torch.cuda.is_available() else 'mps')
model.to(device)

results = []
i=0

for batch_idx, batch in enumerate(val_loader):
    print(f'Batch {i} of {len(val_loader)}')
    x, y = batch[:2]
    x = x.to(device)
    y = y.to(device)
    logits = model(x)
    if hasattr(model, '_decode'):
        preds = model._decode(logits)
    else:
        preds = torch.argmax(logits, dim=1)
    probs = torch.sigmoid(logits) if logits.shape[-1] == num_classes - 1 else torch.softmax(logits, dim=1)
    for j in range(x.size(0)):
        true_label = y[j].item()
        pred_label = preds[j].item()
        logit_row = logits[j].cpu().detach().numpy()
        confidence = float(probs[j].max().cpu().item())
        try:
            img_path = val_dataset.dataframe.iloc[batch_idx * batch_size + j]['file_name']
        except Exception:
            img_path = None
        results.append({
            'img_path': img_path,
            'true_label': true_label,
            'pred_label': pred_label,
            'confidence': confidence,
            'logits': logit_row
        })
    i+=1

df = pd.DataFrame(results)
df.to_csv('val_error_analysis.csv', index=False)
df.head()


Batch 0 of 53
Batch 16 of 53
Batch 32 of 53
Batch 48 of 53
Batch 64 of 53
Batch 80 of 53
Batch 96 of 53
Batch 112 of 53
Batch 128 of 53
Batch 144 of 53
Batch 160 of 53
Batch 176 of 53
Batch 192 of 53
Batch 208 of 53
Batch 224 of 53
Batch 240 of 53
Batch 256 of 53
Batch 272 of 53
Batch 288 of 53
Batch 304 of 53
Batch 320 of 53
Batch 336 of 53
Batch 352 of 53
Batch 368 of 53
Batch 384 of 53
Batch 400 of 53
Batch 416 of 53
Batch 432 of 53
Batch 448 of 53
Batch 464 of 53
Batch 480 of 53
Batch 496 of 53
Batch 512 of 53
Batch 528 of 53
Batch 544 of 53
Batch 560 of 53
Batch 576 of 53
Batch 592 of 53
Batch 608 of 53
Batch 624 of 53
Batch 640 of 53
Batch 656 of 53
Batch 672 of 53
Batch 688 of 53
Batch 704 of 53
Batch 720 of 53
Batch 736 of 53
Batch 752 of 53
Batch 768 of 53
Batch 784 of 53
Batch 800 of 53
Batch 816 of 53
Batch 832 of 53


In [None]:
# Confusion matrix and per-class metrics
cm = confusion_matrix(df['true_label'], df['pred_label'])
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()

print(classification_report(df['true_label'], df['pred_label'], target_names=class_names))


In [None]:
# Find and review high-confidence mistakes
mistakes = df[df['true_label'] != df['pred_label']]
top_conf_mistakes = mistakes.sort_values('confidence', ascending=False).head(20)
top_conf_mistakes.to_csv('top_confidence_mistakes.csv', index=False)
top_conf_mistakes[['img_path', 'true_label', 'pred_label', 'confidence']]


In [None]:
# (Optional) Display images inline for manual review
from PIL import Image
from IPython.display import display

for idx, row in top_conf_mistakes.iterrows():
    if row['img_path'] is not None:
        img = Image.open(os.path.join(val_img_dir, row['img_path']))
        print(f'True: {row['true_label']} | Pred: {row['pred_label']} | Conf: {row['confidence']:.2f}')
        display(img)
