# Model Training — Fine-Grained Sans-Serif Font Recognition

This notebook trains a deep CNN to classify **20 sans-serif font families** from synthetic text images.

**Architecture:** **EfficientNet-B2** (pretrained on ImageNet) with a custom classification head, chosen for fine-grained recognition because:
- **Compound scaling** balances network depth, width, and input resolution together, helping the model capture both coarse structure (letter shape) and fine detail (stroke width, terminal geometry)
- Far fewer parameters than ResNet/VGG models of similar or lower ImageNet accuracy
- Widely used as a backbone on fine-grained benchmarks (CUB-200, Stanford Cars, etc.)

**Training strategy:**
1. **Frozen backbone warm-up** (5 epochs) — train only the classifier head on top of frozen ImageNet features
2. **Full fine-tuning** (25 epochs) — unfreeze all layers with differential learning rates (backbone 10× lower) and cosine annealing
3. **Mixed-precision (AMP)**: faster training and lower memory use on Kaggle T4/P100 GPUs (largest gains on the T4, which has tensor cores)

**Outputs:** Best model checkpoint (`.pth`) + full training metadata for local inference.

In [None]:
import os
import copy
import time
import json
import random
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast
import torchvision
from torchvision import datasets, transforms, models

# ── Reproducibility ──
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"PyTorch {torch.__version__}  •  Device: {DEVICE}")
if DEVICE.type == "cuda":
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
    print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 1 · Configuration

In [None]:
IS_KAGGLE = os.path.exists("/kaggle")

if IS_KAGGLE:
    DATA_DIR   = "/kaggle/input/domain-randomized-sans-serif-fonts"
    SAVE_DIR   = "/kaggle/working"
else:
    DATA_DIR   = "../data/processed"
    SAVE_DIR   = "../models"

TRAIN_DIR = os.path.join(DATA_DIR, "train")
TEST_DIR  = os.path.join(DATA_DIR, "test")

os.makedirs(SAVE_DIR, exist_ok=True)

# Hyperparameters

NUM_CLASSES      = 20
IMG_SIZE         = 224
BATCH_SIZE       = 64
NUM_WORKERS      = 2

# Phase 1: Frozen backbone warm-up
WARMUP_EPOCHS    = 5
WARMUP_LR        = 1e-3

# Phase 2: Full fine-tuning
FINETUNE_EPOCHS  = 25
BACKBONE_LR      = 1e-4         # Lower LR for pretrained layers
HEAD_LR          = 1e-3         # Higher LR for new classifier head
WEIGHT_DECAY     = 1e-4

# Early stopping
PATIENCE         = 7            # Stop if val loss doesn't improve for 7 epochs

## 2 · Data Pipeline

Images are variable-sized text rectangles, resized directly to **224×224** (aspect ratio is not preserved) and normalized with ImageNet statistics.

Since the dataset was already **heavily augmented offline** (7 TRDG profiles + post-processing with JPEG compression, brightness/contrast jitter, salt-and-pepper noise, and dark-mode inversion), the online training transforms are kept **minimal and non-overlapping**:

- **Random horizontal flip** (p=0.05) — very rare, simulates mirrored screenshots
- **Random erasing** (p=0.1) — simulates partial occlusion (NOT covered offline)
- **Random grayscale** (p=0.05) — desaturated captures

**Test transforms** apply only deterministic resize + normalization (no augmentation).
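Because `RandomErasing` runs *after* `Normalize` in the training pipeline, its default fill value of 0 is interpreted in normalized space. A small NumPy sketch of the per-channel formula that `transforms.Normalize` applies, `(x - mean) / std`, shows that an erased patch therefore decodes to the ImageNet mean color (a mid gray) rather than black. The helper names here are illustrative and separate from the notebook's own `denormalize`.

```python
import numpy as np

# ImageNet channel statistics, as used by the notebook's transforms
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def normalize_rgb(pixel):
    """Per-channel normalization, as in transforms.Normalize: (x - mean) / std."""
    return (np.asarray(pixel, dtype=np.float64) - IMAGENET_MEAN) / IMAGENET_STD


def denormalize_rgb(value):
    """Inverse mapping back to [0, 1] pixel space: x * std + mean."""
    return np.asarray(value, dtype=np.float64) * IMAGENET_STD + IMAGENET_MEAN


# RandomErasing fills with value=0 by default; after Normalize, that 0 is a normalized value
erased_pixel = denormalize_rgb([0.0, 0.0, 0.0])
print("Erased patch in pixel space:", erased_pixel.round(3))   # the ImageNet mean color

# Pure black in pixel space, by contrast, normalizes to strongly negative values
print("Black pixel normalized:", normalize_rgb([0.0, 0.0, 0.0]).round(3))
```

If black occlusions were preferred, `RandomErasing` also accepts a `value` argument (e.g. a per-channel tuple, given in normalized space); the mean-gray default is a reasonably neutral choice.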

In [None]:
# Transforms  
# ImageNet normalization stats (required for pretrained EfficientNet)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

train_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(p=0.05),    # rare flip — mirrored screenshots
    transforms.RandomGrayscale(p=0.05),          # desaturated captures
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    transforms.RandomErasing(p=0.1, scale=(0.02, 0.08)),  # partial occlusion (not in offline aug)
])

test_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Datasets & Loaders  

train_dataset = datasets.ImageFolder(TRAIN_DIR, transform=train_transforms)
test_dataset  = datasets.ImageFolder(TEST_DIR,  transform=test_transforms)

train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=True, drop_last=True,
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=True,
)

# Class names (sorted alphabetically by ImageFolder)
CLASS_NAMES = train_dataset.classes
idx_to_class = {i: c for i, c in enumerate(CLASS_NAMES)}

print(f"Train : {len(train_dataset):,} images")
print(f"Test  : {len(test_dataset):,} images")
print(f"Classes ({len(CLASS_NAMES)}): {CLASS_NAMES}")
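`run_training` validates on `test_loader`, so the test split also drives model selection. One way to keep `test/` fully untouched is to hold out a stratified validation subset of `train/`. A minimal NumPy sketch, assuming a hypothetical helper `stratified_split_indices` that works on any list of integer labels such as `ImageFolder.targets`:

```python
import numpy as np


def stratified_split_indices(targets, val_frac=0.1, seed=42):
    """Hypothetical helper: split sample indices into (train_idx, val_idx),
    drawing the same fraction from every class so class balance is preserved."""
    targets = np.asarray(targets)
    rng = np.random.default_rng(seed)
    train_idx, val_idx = [], []
    for cls in np.unique(targets):
        cls_idx = np.flatnonzero(targets == cls)   # positions of this class
        rng.shuffle(cls_idx)                       # random order within the class
        n_val = int(round(len(cls_idx) * val_frac))
        val_idx.extend(cls_idx[:n_val].tolist())
        train_idx.extend(cls_idx[n_val:].tolist())
    return sorted(train_idx), sorted(val_idx)


# Toy labels standing in for ImageFolder.targets
toy_targets = [0] * 10 + [1] * 20
train_idx, val_idx = stratified_split_indices(toy_targets, val_frac=0.1)
print(len(train_idx), len(val_idx))   # 27 3
```

With the real data, the indices could be applied to two `ImageFolder` views of `TRAIN_DIR` via `torch.utils.data.Subset`: one built with `train_transforms` and `train_idx`, the other with `test_transforms` and `val_idx`, so validation images are not augmented.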

In [None]:
# Quick sanity check: visualize a batch

def denormalize(tensor, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Undo ImageNet normalization for display."""
    t = tensor.clone()
    for ch, m, s in zip(t, mean, std):
        ch.mul_(s).add_(m)
    return t.clamp_(0, 1)

imgs, labels = next(iter(train_loader))

fig, axes = plt.subplots(2, 8, figsize=(20, 5))
fig.suptitle("Training Batch Sample (with augmentation)", fontsize=14)
for i, ax in enumerate(axes.flat):
    img = denormalize(imgs[i]).permute(1, 2, 0).numpy()
    ax.imshow(img)
    ax.set_title(CLASS_NAMES[labels[i]], fontsize=9)
    ax.axis("off")
plt.tight_layout()
plt.show()

## 3 · Model Architecture

**EfficientNet-B2** backbone with a custom classification head:

```
EfficientNet-B2 (frozen/unfrozen backbone)
    └─ AdaptiveAvgPool → 1408-d feature vector
        └─ Dropout(0.3)
            └─ Linear(1408 → 512) + ReLU + BatchNorm
                └─ Dropout(0.2)
                    └─ Linear(512 → 20)
```

The two-layer head (a 512-d hidden layer with BatchNorm) is intended to learn a more discriminative font embedding than a single linear layer.
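As a sanity check on the `Trainable (head)` count printed by the next cell, the head's trainable parameters can be worked out by hand: each `Linear(in, out)` has `in * out` weights plus `out` biases, and `BatchNorm1d(512)` has 512 scale and 512 shift parameters (its running mean and variance are buffers, not parameters). Dropout and ReLU contribute nothing.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of an nn.Linear layer."""
    return n_in * n_out + n_out


def batchnorm_params(n_features):
    """Learnable affine scale (gamma) and shift (beta); running stats are buffers."""
    return 2 * n_features


head_params = (
    linear_params(1408, 512)    # 721,408
    + batchnorm_params(512)     # 1,024
    + linear_params(512, 20)    # 10,260
)
print(f"{head_params:,}")   # 732,692
```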

In [None]:
def build_model(num_classes: int = NUM_CLASSES, dropout: float = 0.3) -> nn.Module:
    """
    EfficientNet-B2 with a custom 2-layer classification head.
    Returns the model with backbone frozen (ready for Phase 1).
    """
    # Load pretrained EfficientNet-B2
    weights = models.EfficientNet_B2_Weights.IMAGENET1K_V1
    model = models.efficientnet_b2(weights=weights)
    
    # Freeze backbone
    for param in model.features.parameters():
        param.requires_grad = False
    
    # ── Replace classifier head ──
    # EfficientNet-B2 has 1408 features before the classifier
    in_features = model.classifier[1].in_features  # 1408
    
    model.classifier = nn.Sequential(
        nn.Dropout(p=dropout),
        nn.Linear(in_features, 512),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(512),
        nn.Dropout(p=0.2),
        nn.Linear(512, num_classes),
    )
    
    return model

model = build_model().to(DEVICE)

# Count parameters
total_params   = sum(p.numel() for p in model.parameters())
frozen_params  = sum(p.numel() for p in model.parameters() if not p.requires_grad)
trainable      = total_params - frozen_params

print("Model: EfficientNet-B2")
print(f"  Total params     : {total_params:>10,}")
print(f"  Frozen (backbone): {frozen_params:>10,}")
print(f"  Trainable (head) : {trainable:>10,}")

## 4 · Training Engine

Core training/validation loop with:
- **Mixed-precision (AMP)** for faster training on tensor-core GPUs such as the T4
- **Gradient clipping** to stabilize fine-tuning
- **Early stopping** on validation loss to prevent overfitting
- Per-epoch metrics tracking for visualization

In [None]:
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  # mild smoothing discourages over-confident logits
scaler   = GradScaler()                               # for mixed-precision

# ── History tracking ──
history = {
    "train_loss": [], "train_acc": [],
    "val_loss": [],   "val_acc": [],
    "lr": [],
}

def train_one_epoch(model, loader, optimizer, epoch_num):
    """Train for one epoch with AMP. Returns (avg_loss, accuracy)."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    pbar = tqdm(loader, desc=f"  Train Ep {epoch_num}", leave=False)
    for images, labels in pbar:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        
        optimizer.zero_grad()
        with autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)
        
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        
        running_loss += loss.item() * images.size(0)
        _, preds = outputs.max(1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)
        
        pbar.set_postfix(loss=f"{loss.item():.4f}", acc=f"{correct/total:.4f}")
    
    return running_loss / total, correct / total


@torch.no_grad()
def evaluate(model, loader, epoch_num):
    """Evaluate on validation/test set. Returns (avg_loss, accuracy)."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    pbar = tqdm(loader, desc=f"  Val   Ep {epoch_num}", leave=False)
    for images, labels in pbar:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        
        with autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)
        
        running_loss += loss.item() * images.size(0)
        _, preds = outputs.max(1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)
    
    return running_loss / total, correct / total


def run_training(model, optimizer, scheduler, num_epochs, phase_name="Training"):
    """
    Run the training loop with early stopping on validation loss.
    Restores the best weights (lowest val loss) into `model` in place and
    returns the val accuracy at that epoch.
    """
    best_val_acc  = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())
    epochs_no_improve = 0
    best_val_loss = float("inf")
    
    print(f"\n{'='*60}")
    print(f" {phase_name} — {num_epochs} epochs")
    print(f"{'='*60}")
    
    for epoch in range(1, num_epochs + 1):
        t0 = time.time()
        
        # Record the LR used for THIS epoch, before scheduler.step() advances it
        # (in Phase 2, param_groups[0] is the backbone group)
        current_lr = optimizer.param_groups[0]["lr"]
        
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, epoch)
        val_loss, val_acc     = evaluate(model, test_loader, epoch)
        
        if scheduler is not None:
            scheduler.step()
        
        # Track metrics
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)
        history["lr"].append(current_lr)
        
        elapsed = time.time() - t0
        print(f"  Epoch {epoch:>2d}/{num_epochs}  "
              f"train_loss={train_loss:.4f}  train_acc={train_acc:.4f}  "
              f"val_loss={val_loss:.4f}  val_acc={val_acc:.4f}  "
              f"lr={current_lr:.2e}  ({elapsed:.0f}s)")
        
        # Early stopping logic
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_acc  = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            epochs_no_improve = 0
            print(f"    ★ New best model (val_acc={val_acc:.4f})")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= PATIENCE:
                print(f"    ✗ Early stopping triggered (no improvement for {PATIENCE} epochs)")
                break
    
    # Restore best weights
    model.load_state_dict(best_model_wts)
    print(f"\n  ✓ {phase_name} complete — best val_acc: {best_val_acc:.4f}")
    return best_val_acc

print("✓ Training engine ready")
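The patience rule in `run_training` can be traced on a plain list of validation losses. This pure-Python sketch (a hypothetical helper mirroring the `if val_loss < best_val_loss` branch) returns the epoch whose weights would be restored and the epoch at which training stops:

```python
def early_stopping_trace(val_losses, patience):
    """Return (best_epoch, last_epoch_run) under the same patience rule as run_training."""
    best_loss = float("inf")
    best_epoch = None
    epochs_no_improve = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:                 # strict improvement resets the counter
            best_loss, best_epoch = loss, epoch
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                return best_epoch, epoch     # stop early
    return best_epoch, len(val_losses)       # ran the full budget


losses = [1.00, 0.80, 0.70, 0.75, 0.72, 0.71, 0.90]
print(early_stopping_trace(losses, patience=3))   # (3, 6)
```

With `PATIENCE = 7` and `WARMUP_EPOCHS = 5`, Phase 1 can never stop early, since the first epoch always improves on `inf` and at most 4 non-improving epochs follow it. Because `best_val_loss` is reset at the start of each call, Phase 2 also restores its own best epoch even if a Phase 1 epoch had lower validation loss.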

## 5 · Phase 1 — Frozen Backbone Warm-Up

Train **only the classifier head** (backbone frozen) for 5 epochs. This lets the randomly initialized layers converge to reasonable weights first; otherwise their large, noisy gradients would flow into the backbone as soon as it is unfrozen and could corrupt the pretrained features.
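One subtlety: `requires_grad = False` freezes the backbone weights but not its BatchNorm running statistics. Because `train_one_epoch` calls `model.train()`, each forward pass still updates every BatchNorm layer's running mean and variance as an exponential moving average, `running = (1 - momentum) * running + momentum * batch_stat`. The scalar sketch below assumes PyTorch's default `momentum=0.1` (the backbone's actual BatchNorm settings are not checked here) and shows how quickly a stored ImageNet statistic drifts toward the font data's statistic:

```python
def ema_running_stat(initial, batch_stats, momentum=0.1):
    """Exponential moving average used for BatchNorm running mean/var in train mode."""
    running = initial
    for stat in batch_stats:
        running = (1 - momentum) * running + momentum * stat
    return running


# Start from a stored statistic of 0.0; every font batch reports 1.0
for n_batches in (10, 50, 100):
    print(n_batches, round(ema_running_stat(0.0, [1.0] * n_batches), 4))
```

Whether this drift helps (adapting normalization to text images) or hurts is an empirical question. If ImageNet statistics should stay fixed during Phase 1, one option is to call `model.features.eval()` right after `model.train()` in the training loop; note that this switches every train/eval-dependent layer inside `features` to inference mode, not only BatchNorm.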

In [None]:
# Phase 1: Only classifier head is trainable (backbone frozen from build_model)
optimizer_warmup = optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=WARMUP_LR,
    weight_decay=WEIGHT_DECAY,
)
scheduler_warmup = optim.lr_scheduler.CosineAnnealingLR(
    optimizer_warmup, T_max=WARMUP_EPOCHS, eta_min=1e-5
)

phase1_acc = run_training(
    model, optimizer_warmup, scheduler_warmup,
    num_epochs=WARMUP_EPOCHS,
    phase_name="Phase 1 · Frozen Backbone Warm-Up",
)

## 6 · Phase 2 — Full Fine-Tuning

Unfreeze the **entire backbone** and train with **differential learning rates**:
- Backbone layers: `1e-4` (gentle updates to preserve pretrained features)
- Classifier head: `1e-3` (aggressive updates for task-specific learning)

Cosine annealing smoothly decays both rates toward `eta_min = 1e-6` over 25 epochs.
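`CosineAnnealingLR` is documented with the closed form `eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2`, where `t` counts scheduler steps. A short sketch of that formula applied to both Phase 2 parameter groups with the notebook's values (`BACKBONE_LR = 1e-4`, `HEAD_LR = 1e-3`, `eta_min = 1e-6`, `T_max = 25`):

```python
import math


def cosine_annealing_lr(base_lr, eta_min, step, t_max):
    """Closed-form cosine annealing: LR after `step` scheduler steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


T_MAX, ETA_MIN = 25, 1e-6
for step in (0, 12, 24, 25):
    backbone = cosine_annealing_lr(1e-4, ETA_MIN, step, T_MAX)
    head = cosine_annealing_lr(1e-3, ETA_MIN, step, T_MAX)
    print(f"step {step:>2}: backbone={backbone:.2e}  head={head:.2e}  ratio={head / backbone:.1f}")
```

Two details follow from the formula. Since `scheduler.step()` runs after each epoch, epoch `e` trains with the rate for step `e - 1`, so the final `eta_min` value is only reached after the last epoch. And because both groups share the same `eta_min`, the 10× head-to-backbone ratio shrinks toward 1 late in training.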

In [None]:
# Unfreeze backbone
for param in model.features.parameters():
    param.requires_grad = True

# Differential learning rates
optimizer_finetune = optim.AdamW([
    {"params": model.features.parameters(), "lr": BACKBONE_LR},   # backbone: lower LR
    {"params": model.classifier.parameters(), "lr": HEAD_LR},     # head: higher LR
], weight_decay=WEIGHT_DECAY)

scheduler_finetune = optim.lr_scheduler.CosineAnnealingLR(
    optimizer_finetune, T_max=FINETUNE_EPOCHS, eta_min=1e-6
)

# Report new param counts
trainable_now = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Unfroze backbone — trainable params: {trainable_now:,}")

phase2_acc = run_training(
    model, optimizer_finetune, scheduler_finetune,
    num_epochs=FINETUNE_EPOCHS,
    phase_name="Phase 2 · Full Fine-Tuning",
)

## 7 · Training Curves

Plot loss, accuracy, and learning rate across both training phases. The LR panel tracks `param_groups[0]`, i.e. the backbone rate in Phase 2.

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
epochs_range = range(1, len(history["train_loss"]) + 1)

# Phase boundary line
phase_boundary = WARMUP_EPOCHS + 0.5

# Loss
ax = axes[0]
ax.plot(epochs_range, history["train_loss"], "b-o", markersize=3, label="Train Loss")
ax.plot(epochs_range, history["val_loss"],   "r-o", markersize=3, label="Val Loss")
ax.axvline(phase_boundary, color="grey", ls="--", alpha=0.5, label="Phase 1→2")
ax.set_xlabel("Epoch"); ax.set_ylabel("Loss"); ax.set_title("Loss")
ax.legend(); ax.grid(True, alpha=0.3)

# Accuracy
ax = axes[1]
ax.plot(epochs_range, history["train_acc"], "b-o", markersize=3, label="Train Acc")
ax.plot(epochs_range, history["val_acc"],   "r-o", markersize=3, label="Val Acc")
ax.axvline(phase_boundary, color="grey", ls="--", alpha=0.5, label="Phase 1→2")
ax.set_xlabel("Epoch"); ax.set_ylabel("Accuracy"); ax.set_title("Accuracy")
ax.legend(); ax.grid(True, alpha=0.3)

# Learning Rate
ax = axes[2]
ax.plot(epochs_range, history["lr"], "g-o", markersize=3)
ax.axvline(phase_boundary, color="grey", ls="--", alpha=0.5, label="Phase 1→2")
ax.set_xlabel("Epoch"); ax.set_ylabel("LR"); ax.set_title("Learning Rate Schedule")
ax.set_yscale("log"); ax.legend(); ax.grid(True, alpha=0.3)

plt.suptitle("Training History (Phase 1: Warm-Up  |  Phase 2: Fine-Tune)", fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

## 8 · Evaluation — Confusion Matrix & Classification Report

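Full evaluation on the **held-out test set** (20% of data). Test images are never used for gradient updates, but `run_training` also monitors this split for early stopping and best-checkpoint selection, so the accuracy reported here is likely somewhat optimistic; a separate validation split would give an unbiased estimate.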
- **Confusion matrix** reveals which font pairs the model struggles to distinguish
- **Per-class precision / recall / F1** for the paper
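The per-class numbers in `classification_report` can be read straight off the confusion matrix, which scikit-learn builds with true labels on rows and predictions on columns: precision divides the diagonal by column sums, recall divides it by row sums, and F1 is their harmonic mean. Per-class recall is the same quantity as the per-class accuracy plotted in Section 9. A NumPy sketch with a hypothetical helper, on a toy 2-class matrix:

```python
import numpy as np


def per_class_prf(cm):
    """Precision, recall, F1 and support per class from a (true x predicted) confusion matrix."""
    cm = np.asarray(cm, dtype=np.float64)
    tp = np.diag(cm)                  # correct predictions per class
    support = cm.sum(axis=1)          # row sums: true samples per class
    predicted = cm.sum(axis=0)        # column sums: predictions per class
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, support, out=np.zeros_like(tp), where=support > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return precision, recall, f1, support


cm_toy = [[8, 2],
          [1, 9]]
precision, recall, f1, support = per_class_prf(cm_toy)
print(precision.round(3), recall.round(3), f1.round(3), support)
```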

In [None]:
from sklearn.metrics import classification_report, confusion_matrix

@torch.no_grad()
def get_all_predictions(model, loader):
    """Collect all predictions and ground truth labels."""
    model.eval()
    all_preds  = []
    all_labels = []
    all_probs  = []
    
    for images, labels in tqdm(loader, desc="Evaluating", leave=False):
        images = images.to(DEVICE)
        with autocast():
            outputs = model(images)
        
        probs = torch.softmax(outputs.float(), dim=1).cpu()  # softmax in fp32 (AMP outputs may be fp16)
        preds = outputs.argmax(dim=1).cpu()
        
        all_preds.append(preds)
        all_labels.append(labels)
        all_probs.append(probs)
    
    return (
        torch.cat(all_preds).numpy(),
        torch.cat(all_labels).numpy(),
        torch.cat(all_probs).numpy(),
    )

y_pred, y_true, y_probs = get_all_predictions(model, test_loader)

# Classification Report
print("=" * 70)
print(" CLASSIFICATION REPORT")
print("=" * 70)
print(classification_report(y_true, y_pred, target_names=CLASS_NAMES, digits=4))

# Overall accuracy
overall_acc = (y_pred == y_true).mean()
print(f"\n  Overall Test Accuracy: {overall_acc:.4f} ({overall_acc*100:.2f}%)")

In [None]:
# Confusion Matrix Heatmap
cm = confusion_matrix(y_true, y_pred)
cm_pct = cm.astype("float") / cm.sum(axis=1, keepdims=True) * 100  # normalize to %

fig, axes = plt.subplots(1, 2, figsize=(24, 10))

# Raw counts
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, ax=axes[0])
axes[0].set_title("Confusion Matrix (Counts)", fontsize=14)
axes[0].set_xlabel("Predicted"); axes[0].set_ylabel("True")
axes[0].tick_params(axis="x", rotation=45)

# Percentage (per-row normalized)
sns.heatmap(cm_pct, annot=True, fmt=".1f", cmap="Oranges",
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, ax=axes[1])
axes[1].set_title("Confusion Matrix (% per True Class)", fontsize=14)
axes[1].set_xlabel("Predicted"); axes[1].set_ylabel("True")
axes[1].tick_params(axis="x", rotation=45)

plt.tight_layout()
plt.show()

# Top confused pairs (for paper discussion)
print("\nTop 10 Most Confused Font Pairs:")
print("-" * 50)
np.fill_diagonal(cm, 0)  # ignore correct predictions
for _ in range(10):
    i, j = np.unravel_index(cm.argmax(), cm.shape)
    if cm[i, j] == 0:
        break
    print(f"  {CLASS_NAMES[i]:>14s} → {CLASS_NAMES[j]:<14s}  ({cm[i,j]} misclassifications)")
    cm[i, j] = 0
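The loop above zeroes entries of `cm` in place, which is why Section 9 rebuilds the matrix. A non-mutating alternative sorts the off-diagonal entries of a copy once. This is a sketch with a hypothetical helper name; ties come out in a deterministic but arbitrary order.

```python
import numpy as np


def top_confused_pairs(cm, k=10):
    """Return up to k (true_idx, pred_idx, count) tuples, largest off-diagonal counts first.
    Works on a copy, so the caller's confusion matrix is left unchanged."""
    off_diag = np.array(cm, copy=True)
    np.fill_diagonal(off_diag, 0)                                   # ignore correct predictions
    order = np.argsort(off_diag, axis=None, kind="stable")[::-1]    # flat indices, descending
    pairs = []
    for flat_idx in order[:k]:
        i, j = np.unravel_index(flat_idx, off_diag.shape)
        if off_diag[i, j] == 0:                                     # only zeros remain
            break
        pairs.append((int(i), int(j), int(off_diag[i, j])))
    return pairs


cm_demo = np.array([[5, 3, 0],
                    [1, 7, 4],
                    [2, 0, 6]])
print(top_confused_pairs(cm_demo, k=3))   # [(1, 2, 4), (0, 1, 3), (2, 0, 2)]
print(cm_demo[0, 0])                      # 5 (diagonal untouched)
```

With the notebook's data, each `(i, j, n)` tuple maps to names via `CLASS_NAMES[i]` and `CLASS_NAMES[j]`.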

## 9 · Per-Class Accuracy Bar Chart

Visual breakdown of which fonts are easiest/hardest to recognize — useful for the paper's analysis section.

In [None]:
# Recompute clean confusion matrix for per-class accuracy
cm_clean = confusion_matrix(y_true, y_pred)
per_class_acc = cm_clean.diagonal() / cm_clean.sum(axis=1)

# Sort by accuracy
sorted_idx = np.argsort(per_class_acc)
sorted_names = [CLASS_NAMES[i] for i in sorted_idx]
sorted_acc   = per_class_acc[sorted_idx]

# Color: green for high accuracy, red for low
colors = plt.cm.RdYlGn(sorted_acc)

fig, ax = plt.subplots(figsize=(12, 8))
bars = ax.barh(sorted_names, sorted_acc * 100, color=colors, edgecolor="grey", linewidth=0.5)

# Add value labels
for bar, acc in zip(bars, sorted_acc):
    ax.text(bar.get_width() + 0.3, bar.get_y() + bar.get_height()/2,
            f"{acc*100:.1f}%", va="center", fontsize=10)

ax.set_xlabel("Accuracy (%)", fontsize=12)
ax.set_title("Per-Class Test Accuracy", fontsize=14)
ax.set_xlim(0, 105)
ax.axvline(overall_acc * 100, color="blue", ls="--", alpha=0.5, label=f"Overall: {overall_acc*100:.1f}%")
ax.legend(fontsize=11)
ax.grid(axis="x", alpha=0.3)
plt.tight_layout()
plt.show()

## 10 · Sample Predictions

Visual spot-check: show 20 random test images with true vs predicted labels. Green = correct, Red = wrong.

In [None]:
# Random sample predictions
rng = np.random.RandomState(42)
sample_indices = rng.choice(len(test_dataset), size=20, replace=False)

fig, axes = plt.subplots(4, 5, figsize=(20, 12))
fig.suptitle("Sample Test Predictions  (Green=Correct, Red=Wrong)", fontsize=15, y=1.01)

for i, ax in enumerate(axes.flat):
    idx = sample_indices[i]
    img_tensor, true_label = test_dataset[idx]
    
    # Display image (denormalized)
    img_display = denormalize(img_tensor).permute(1, 2, 0).numpy()
    ax.imshow(img_display)
    
    # Prediction
    pred_label = y_pred[idx]
    true_name = CLASS_NAMES[true_label]
    pred_name = CLASS_NAMES[pred_label]
    confidence = y_probs[idx][pred_label] * 100
    
    correct = pred_label == true_label
    color = "green" if correct else "red"
    
    ax.set_title(
        f"True: {true_name}\nPred: {pred_name} ({confidence:.0f}%)",
        fontsize=9, color=color, fontweight="bold"
    )
    ax.axis("off")

plt.tight_layout()
plt.show()

print(f"\nFinal Test Accuracy: {overall_acc*100:.2f}%")
print("Training complete! Run the next cell to save best_model.pth and model_metadata.json.")

## 11 · Save Model

In [None]:
# Save Checkpoint

checkpoint_path = os.path.join(SAVE_DIR, "best_model.pth")
metadata_path   = os.path.join(SAVE_DIR, "model_metadata.json")

# Checkpoint: state dict + everything needed to reconstruct
checkpoint = {
    "model_state_dict": model.state_dict(),
    "architecture":     "efficientnet_b2",
    "num_classes":      NUM_CLASSES,
    "img_size":         IMG_SIZE,
    "class_names":      CLASS_NAMES,
    "idx_to_class":     idx_to_class,
    "imagenet_mean":    IMAGENET_MEAN,
    "imagenet_std":     IMAGENET_STD,
    "best_val_acc":     phase2_acc,
    "total_epochs":     len(history["train_loss"]),
    "history":          history,
}

torch.save(checkpoint, checkpoint_path)
ckpt_size = os.path.getsize(checkpoint_path) / (1024 * 1024)
print(f"✓ Checkpoint saved: {checkpoint_path} ({ckpt_size:.1f} MB)")

# Human-readable metadata JSON (for the inference notebook)
metadata = {
    "architecture":  "efficientnet_b2",
    "num_classes":   NUM_CLASSES,
    "img_size":      IMG_SIZE,
    "class_names":   CLASS_NAMES,
    "imagenet_mean": IMAGENET_MEAN,
    "imagenet_std":  IMAGENET_STD,
    "best_val_acc":  round(phase2_acc, 4),
    "training_info": {
        "warmup_epochs":   WARMUP_EPOCHS,
        "finetune_epochs": FINETUNE_EPOCHS,
        "total_epochs":    len(history["train_loss"]),
        "batch_size":      BATCH_SIZE,
        "backbone_lr":     BACKBONE_LR,
        "head_lr":         HEAD_LR,
        "label_smoothing": 0.1,
    },
}

with open(metadata_path, "w") as f:
    json.dump(metadata, f, indent=2)
print(f"✓ Metadata saved: {metadata_path}")

print("\n  Download these two files from Kaggle Output:")
print(f"    1. {os.path.basename(checkpoint_path)}")
print(f"    2. {os.path.basename(metadata_path)}")
print("  Place them in your local models/ directory for inference.")
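On the inference side, `model_metadata.json` carries what is needed to reproduce `test_transforms` without torchvision. A minimal sketch with Pillow and NumPy, assuming the inference code receives a PIL image and the parsed metadata dict. For PIL inputs `transforms.Resize` defaults to bilinear interpolation, and `ToTensor` scales to [0, 1] and reorders HWC to CHW, which the sketch mirrors; small numerical differences from the torchvision path are possible.

```python
import numpy as np
from PIL import Image


def preprocess_for_inference(image, metadata):
    """Approximate test_transforms: RGB -> resize -> [0, 1] floats -> normalize -> CHW.
    `metadata` is the dict loaded from model_metadata.json."""
    size = metadata["img_size"]
    mean = np.asarray(metadata["imagenet_mean"], dtype=np.float32)
    std = np.asarray(metadata["imagenet_std"], dtype=np.float32)
    img = image.convert("RGB").resize((size, size), Image.BILINEAR)  # PIL takes (width, height)
    arr = np.asarray(img, dtype=np.float32) / 255.0   # HWC in [0, 1], like ToTensor
    arr = (arr - mean) / std                          # per-channel Normalize
    return arr.transpose(2, 0, 1)                     # HWC -> CHW


# Example with an in-memory white strip and the metadata fields saved above
meta = {"img_size": 224,
        "imagenet_mean": [0.485, 0.456, 0.406],
        "imagenet_std": [0.229, 0.224, 0.225]}
x = preprocess_for_inference(Image.new("RGB", (300, 60), "white"), meta)
print(x.shape)   # (3, 224, 224)
```

Adding a batch axis (`x[None]`) and converting with `torch.from_numpy` would give a `(1, 3, 224, 224)` input; the trained weights are stored under the `model_state_dict` key of `best_model.pth`.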