# DeiT-Tiny Inference & Evaluation

Inference and evaluation notebook for a fine-tuned `deit_tiny_patch16_224` model.

This notebook:
- Loads the multi-task ViT model
- Runs inference on a chosen split (`train`, `val`, or `test` if available)
- Computes the following metrics for the main *class* prediction task:
  - Accuracy
  - Macro Precision / Recall / F1
  - Micro Precision / Recall / F1
  - Confusion Matrix


In [None]:
import os

import torch

# --- Filesystem layout (edit these if your project is organised differently) ---
# DATA_ROOT must contain images/, labels.csv and attributes.yaml.
DATA_ROOT = "data"
OUTPUT_DIR = "outputs"
# Fine-tuned weights to evaluate.
CHECKPOINT_PATH = os.path.join(OUTPUT_DIR, "best_model.pth")

# --- Data / evaluation settings ---
SPLIT = "val"  # usually "val" or "test"
BATCH_SIZE = 32
NUM_WORKERS = 2

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("Using device:", DEVICE)


In [None]:
import yaml
import numpy as np

from torch.utils.data import DataLoader
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    confusion_matrix,
)

from dataset import EverydayObjectsDataset
from model import MultiTaskViT

import matplotlib.pyplot as plt


In [None]:
# Load the train split once to infer the number of classes and the class mapping
IMG_SIZE = 224  # must match training

train_ds = EverydayObjectsDataset(root=DATA_ROOT, split="train", img_size=IMG_SIZE)
num_classes = len(train_ds.class_to_idx)
print("Num classes:", num_classes)

# Build index -> class name mapping.
# NOTE: class_names relies on the indices in class_to_idx being the contiguous
# range 0..num_classes-1; a gap would raise KeyError here.
idx_to_class = {idx: name for name, idx in train_ds.class_to_idx.items()}
class_names = [idx_to_class[i] for i in range(len(idx_to_class))]
print("Classes:", class_names)

# Attribute schema (used to construct the model)
ATTR_YAML_NAME = "attributes.yaml"
attr_yaml_path = os.path.join(DATA_ROOT, ATTR_YAML_NAME)
# Explicit encoding: without it the file is decoded with the platform default,
# which can differ between machines (e.g. cp1252 on Windows).
with open(attr_yaml_path, encoding="utf-8") as f:
    attr_schema = yaml.safe_load(f)
# An empty file yields None and a top-level list yields a list; both would
# otherwise fail below with an opaque AttributeError on `.keys()`.
if not isinstance(attr_schema, dict):
    raise ValueError(
        f"Expected a mapping of attributes in {attr_yaml_path}, "
        f"got {type(attr_schema).__name__}"
    )
print("Attributes:", list(attr_schema.keys()))

# Evaluation dataset / loader (no shuffling, so sample order is stable across runs)
eval_ds = EverydayObjectsDataset(root=DATA_ROOT, split=SPLIT, img_size=IMG_SIZE)
eval_loader = DataLoader(
    eval_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)
# Notebook display: (number of batches, number of samples)
len(eval_loader), len(eval_ds)


In [None]:
# Build model and load checkpoint
BACKBONE_NAME = "deit_tiny_patch16_224"

model = MultiTaskViT(
    backbone_name=BACKBONE_NAME,
    num_classes=num_classes,
    attr_schema=attr_schema,
    pretrained=False,  # weights will come from the checkpoint
)

# Raise explicitly instead of using `assert`: assertions are stripped when
# Python runs with -O, which would defer the failure to torch.load.
if not os.path.exists(CHECKPOINT_PATH):
    raise FileNotFoundError(f"Checkpoint not found: {CHECKPOINT_PATH}")

# SECURITY NOTE: torch.load unpickles the file, so only load checkpoints from
# a trusted source. If the checkpoint holds tensors only, consider passing
# weights_only=True (available in recent PyTorch releases).
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
# Support either a plain state_dict or a dict with a 'model' key
state_dict = checkpoint.get("model", checkpoint)
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()  # inference mode for layers such as dropout

print("Loaded model from:", CHECKPOINT_PATH)


In [None]:
# Run inference and collect predictions for the main class task
all_true = []
all_pred = []

with torch.no_grad():
    for batch in eval_loader:
        # Each batch is
        # (imgs, y_class, y_color, y_material, y_condition, y_size, meta);
        # only the images and the class labels are needed for this evaluation.
        imgs, y_class, *_ = batch

        # Only the images go to the device. Labels are consumed on the host,
        # so moving them to DEVICE and back would be a wasted round trip.
        imgs = imgs.to(DEVICE)

        # The model returns (features, class logits, attribute logits).
        _, class_logits, _ = model(imgs)
        preds = class_logits.argmax(dim=1)

        all_true.extend(y_class.cpu().numpy())
        all_pred.extend(preds.cpu().numpy())

all_true = np.array(all_true)
all_pred = np.array(all_pred)

print("Number of samples:", len(all_true))


In [None]:
# Compute metrics for the class prediction task
acc = accuracy_score(all_true, all_pred)

# zero_division=0 scores a class with no predicted (or no true) samples as 0
# instead of emitting a warning.
prec_macro, rec_macro, f1_macro, _ = precision_recall_fscore_support(
    all_true, all_pred, average="macro", zero_division=0
)
prec_micro, rec_micro, f1_micro, _ = precision_recall_fscore_support(
    all_true, all_pred, average="micro", zero_division=0
)

# Pin the label set to every known class. Without `labels`, sklearn builds the
# matrix only from labels that occur in all_true/all_pred, so a class missing
# from this split would shrink the matrix and shift its rows/columns relative
# to class_names when it is plotted.
cm = confusion_matrix(all_true, all_pred, labels=np.arange(len(class_names)))

print(f"Accuracy:       {acc:.4f}")
print(f"Macro  P/R/F1:  {prec_macro:.4f} / {rec_macro:.4f} / {f1_macro:.4f}")
print(f"Micro  P/R/F1:  {prec_micro:.4f} / {rec_micro:.4f} / {f1_micro:.4f}")


In [None]:
# Render the confusion matrix as a heatmap
fig, ax = plt.subplots(figsize=(8, 8))
im = ax.imshow(cm)

# One tick per class on both axes; x labels are slanted so long names fit.
tick_positions = range(len(class_names))
ax.set_xticks(tick_positions)
ax.set_yticks(tick_positions)
ax.set_xticklabels(class_names, rotation=45, ha="right")
ax.set_yticklabels(class_names)

ax.set_title("Confusion Matrix")
ax.set_ylabel("True label")
ax.set_xlabel("Predicted label")

# Annotate every cell with its count (rows are true labels, columns predictions).
for row, col in np.ndindex(*cm.shape):
    ax.text(col, row, cm[row, col], ha="center", va="center")

plt.tight_layout()
plt.show()
