# 04 - Training Breed Classifier

Fine-tuning di EfficientNet-B0 per la classificazione delle razze canine.

## Dataset
- **Stanford Dogs Dataset**: 120 razze, ~20,580 immagini
- Alternativa: subset delle razze più comuni nei canili italiani

## Output
- Classificazione razza → mapping a `P(stray|breed)` usando i prior statistici

## Strategia
1. Raggruppiamo le 120 razze in macro-categorie (pitbull, shepherd, retriever, etc.)
2. Questo semplifica il problema e allinea con i breed_priors.json

In [None]:
# Install the notebook's dependencies (IPython %pip magic, -q keeps pip quiet)
%pip install torch torchvision timm albumentations matplotlib seaborn pandas scikit-learn tqdm -q

In [None]:
import os
import sys
from pathlib import Path
import json
import shutil
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from datetime import datetime
from collections import defaultdict
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import timm

import albumentations as A
from albumentations.pytorch import ToTensorV2

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

# Report interpreter and PyTorch versions, then the accelerator that will be
# preferred (MPS first, then CUDA, otherwise CPU).
print(f"Python: {sys.version}")
print(f"PyTorch: {torch.__version__}")

if torch.backends.mps.is_available():
    _accelerator_name = 'MPS'
elif torch.cuda.is_available():
    _accelerator_name = 'CUDA'
else:
    _accelerator_name = 'CPU'
print(f"Device: {_accelerator_name}")

In [None]:
# Path configuration - RELATIVE paths keep the notebook portable.
# Preferred source is the project's notebook_utils helper; when it cannot be
# imported the paths are derived from the current working directory.
import sys
sys.path.insert(0, str(Path.cwd()))
try:
    from notebook_utils import get_paths, get_device, print_paths
    paths = get_paths()
    print_paths(paths)
except ImportError:
    print("notebook_utils.py non trovato, usando fallback...")
    NOTEBOOK_DIR = Path.cwd()
    if NOTEBOOK_DIR.name == "notebooks":
        PROJECT_DIR = NOTEBOOK_DIR.parent.parent
    elif NOTEBOOK_DIR.name == "training":
        PROJECT_DIR = NOTEBOOK_DIR.parent
    else:
        # Walk up until a directory named "ResQPet" is found. If there is
        # none, stay in the working directory instead of silently ending up
        # at the filesystem root (which would place weights under "/backend").
        PROJECT_DIR = next(
            (p for p in (NOTEBOOK_DIR, *NOTEBOOK_DIR.parents) if p.name == "ResQPet"),
            NOTEBOOK_DIR,
        )
    BASE_DIR = PROJECT_DIR.parent
    paths = {
        'project_dir': PROJECT_DIR,
        'base_dir': BASE_DIR,
        'weights_dir': PROJECT_DIR / "backend" / "weights",
        'data_dir': PROJECT_DIR / "data",
        'stanford_dogs': BASE_DIR / "Stanford Dog",
    }

# Module-level aliases kept for backward compatibility with later cells
BASE_DIR = paths['base_dir']
STANFORD_DIR = paths['stanford_dogs']
PRIORS_FILE = paths['data_dir'] / "breed_priors.json"
OUTPUT_DIR = paths['weights_dir']

# Checkpoints are written here later on, so make sure the directory exists
# regardless of which branch produced `paths`.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Load the breed priors; fall back to a neutral 0.5 prior when the file is missing
if PRIORS_FILE.exists():
    with open(PRIORS_FILE, 'r', encoding='utf-8') as f:
        breed_priors = json.load(f)
    print("Breed Priors (P(stray|breed)):")
    for breed, prob in breed_priors.items():
        # '_metadata' is skipped: it is not a breed entry
        if breed != '_metadata':
            print(f"  {breed}: {prob}")
else:
    print(f"‚ö†Ô∏è breed_priors.json non trovato in {PRIORS_FILE}")
    breed_priors = {'unknown': 0.5}

## 1. Mapping Razze → Macro-Categorie

Stanford Dogs ha 120 razze specifiche. Le raggruppiamo nelle categorie usate per i prior.

In [None]:
# Mapping from Stanford Dogs breeds to macro-categories.
# Stanford directory names use the format "n02085620-Chihuahua" (synset-name).

BREED_MAPPING = {
    # Pitbull/Amstaff - breeds frequently found in shelters
    'pitbull_amstaff': [
        'American_Staffordshire_terrier', 'Staffordshire_bullterrier',
        'bull_mastiff', 'boxer', 'Great_Dane'
    ],

    # Shepherds
    'shepherd': [
        'German_shepherd', 'Belgian_malinois', 'Australian_shepherd',
        'Border_collie', 'collie', 'Shetland_sheepdog', 'Old_English_sheepdog',
        'Bouvier_des_Flandres', 'briard', 'kelpie', 'komondor', 'kuvasz'
    ],

    # Retrievers and gun dogs
    'retriever': [
        'golden_retriever', 'Labrador_retriever', 'flat-coated_retriever',
        'curly-coated_retriever', 'Chesapeake_Bay_retriever',
        'Irish_setter', 'English_setter', 'Gordon_setter',
        'cocker_spaniel', 'English_springer', 'Welsh_springer_spaniel',
        'clumber', 'Sussex_spaniel', 'Irish_water_spaniel',
        'vizsla', 'Weimaraner', 'German_short-haired_pointer'
    ],

    # Hounds
    'hound': [
        'beagle', 'basset', 'bloodhound', 'bluetick', 'redbone',
        'Walker_hound', 'English_foxhound', 'borzoi', 'Irish_wolfhound',
        'Scottish_deerhound', 'whippet', 'Ibizan_hound', 'Afghan_hound',
        'saluki', 'otterhound', 'black-and-tan_coonhound', 'Rhodesian_ridgeback',
        'dingo', 'basenji', 'Norwegian_elkhound'
    ],

    # Terriers
    'terrier': [
        'Airedale', 'Bedlington_terrier', 'Border_terrier', 'Kerry_blue_terrier',
        'Irish_terrier', 'Norfolk_terrier', 'Norwich_terrier', 'Yorkshire_terrier',
        'wire-haired_fox_terrier', 'Lakeland_terrier', 'Sealyham_terrier',
        'Scottish_terrier', 'Tibetan_terrier', 'silky_terrier', 'wheaten_terrier',
        'West_Highland_white_terrier', 'Lhasa', 'cairn', 'Australian_terrier',
        'Dandie_Dinmont', 'Boston_bull', 'miniature_schnauzer', 'giant_schnauzer',
        'standard_schnauzer', 'soft-coated_wheaten_terrier'
    ],

    # Toy / small breeds
    'toy': [
        'Chihuahua', 'Japanese_spaniel', 'Maltese_dog', 'Pekinese',
        'Shih-Tzu', 'Blenheim_spaniel', 'papillon', 'toy_terrier',
        'miniature_pinscher', 'affenpinscher', 'toy_poodle', 'Pomeranian',
        'pug', 'Italian_greyhound'
    ],

    # Working dogs
    'working': [
        'Siberian_husky', 'Alaskan_malamute', 'Eskimo_dog', 'Saint_Bernard',
        'Greater_Swiss_Mountain_dog', 'Bernese_mountain_dog', 'Appenzeller',
        'EntleBucher', 'Rottweiler', 'Doberman', 'miniature_pinscher',
        'Great_Pyrenees', 'Leonberg', 'Newfoundland', 'Tibetan_mastiff'
    ],

    # Spitz
    'spitz': [
        'chow', 'keeshond', 'Pomeranian', 'Samoyed', 'schipperke',
        'Shiba_inu', 'Akita'
    ],

    # Bulldog
    'bulldog': [
        'French_bulldog', 'English_bulldog'
    ],

    # Poodle
    'poodle': [
        'standard_poodle', 'miniature_poodle', 'toy_poodle'
    ]
}


def _normalize_breed_key(name):
    """Lower-case a breed name and turn hyphens/spaces into underscores."""
    return name.lower().replace('-', '_').replace(' ', '_')


# Invert the mapping for fast lookup.
# Every breed is registered under two keys:
#   * the plain lower-cased name (the historical key format), and
#   * the normalized name (hyphens/spaces -> underscores), which is the form
#     lookups use after normalizing their query. Without it, hyphenated
#     breeds such as "flat-coated_retriever" or "Shih-Tzu" never match.
# A breed listed in several categories resolves to the LAST category that
# lists it (unchanged behavior); such breeds are reported below.
BREED_TO_CATEGORY = {}
_categories_per_breed = {}
for category, breeds in BREED_MAPPING.items():
    for breed in breeds:
        normalized_key = _normalize_breed_key(breed)
        BREED_TO_CATEGORY[breed.lower()] = category
        BREED_TO_CATEGORY[normalized_key] = category
        listed_in = _categories_per_breed.setdefault(normalized_key, [])
        if category not in listed_in:
            listed_in.append(category)

_duplicated_breeds = {
    name: cats for name, cats in _categories_per_breed.items() if len(cats) > 1
}
if _duplicated_breeds:
    print("Razze duplicate (prevale l'ultima categoria):")
    for _dup_name, _dup_cats in _duplicated_breeds.items():
        print(f"  {_dup_name}: {_dup_cats} -> {BREED_TO_CATEGORY[_dup_name]}")

# Category list: the mapped macro-categories plus two catch-all classes
CATEGORIES = list(BREED_MAPPING.keys()) + ['mixed', 'unknown']
CATEGORY_TO_IDX = {cat: idx for idx, cat in enumerate(CATEGORIES)}
IDX_TO_CATEGORY = {idx: cat for cat, idx in CATEGORY_TO_IDX.items()}

print(f"\nCategorie: {CATEGORIES}")
print(f"Numero categorie: {len(CATEGORIES)}")

## 2. Esplorazione Dataset

In [None]:
# Locate the Stanford Dogs dataset.
# Its layout can differ depending on how it was downloaded.

def find_stanford_images(base_dir):
    """Collect (image_path, breed_name) pairs for the Stanford Dogs dataset.

    Two layouts are supported:

    1. Kaggle layout: ``base_dir/Images/<breed_dir>/*.jpg``.
    2. Anything else: every ``*.jpg`` found recursively under ``base_dir``,
       using the name of the image's parent directory as the breed directory.

    A breed directory named ``<synset>-<breed>`` (e.g.
    ``n02099267-flat-coated_retriever``) yields the breed name with only the
    leading synset id removed, so hyphens that belong to the breed name are
    preserved. Directory names without such a prefix are returned unchanged.

    Args:
        base_dir: ``pathlib.Path`` pointing at the dataset root.

    Returns:
        List of ``(Path, str)`` tuples, sorted by path so that downstream
        seeded splits do not depend on filesystem enumeration order.
    """

    def _breed_from_dir(dir_name):
        # Only strip a WordNet-style prefix ("n" followed by digits); cutting
        # at the LAST hyphen would truncate names such as "Shih-Tzu".
        prefix, sep, rest = dir_name.partition('-')
        if sep and rest and prefix[:1] == 'n' and prefix[1:].isdigit():
            return rest
        return dir_name

    images = []

    # Pattern 1: Kaggle format (Images/breed_name/image.jpg)
    images_dir = base_dir / "Images"
    if images_dir.exists():
        for breed_dir in sorted(images_dir.iterdir()):
            if breed_dir.is_dir():
                breed_name = _breed_from_dir(breed_dir.name)
                for img_path in sorted(breed_dir.glob('*.jpg')):
                    images.append((img_path, breed_name))
        return images

    # Pattern 2: no "Images" directory - scan recursively and take the breed
    # from the directory that directly contains each image.
    for img_path in sorted(base_dir.rglob('*.jpg')):
        images.append((img_path, _breed_from_dir(img_path.parent.name)))

    return images

# Try to locate the images and summarise how many each breed has
if not STANFORD_DIR.exists():
    print(f"Dataset non trovato in {STANFORD_DIR}")
    print("\nPer scaricare:")
    print("1. Kaggle: https://www.kaggle.com/datasets/jessicali9530/stanford-dogs-dataset")
    print("2. Originale: http://vision.stanford.edu/aditya86/ImageNetDogs/")
else:
    all_images = find_stanford_images(STANFORD_DIR)
    print(f"Trovate {len(all_images)} immagini")

    # Tally images per breed
    breed_counts = defaultdict(int)
    for _, breed in all_images:
        breed_counts[breed] += 1

    print(f"\nRazze trovate: {len(breed_counts)}")
    print("\nTop 10 razze per numero immagini:")
    # Stable descending sort by count; ties keep their insertion order
    ranked_breeds = sorted(breed_counts.items(), key=lambda item: item[1], reverse=True)
    for breed, count in ranked_breeds[:10]:
        print(f"  {breed}: {count}")

In [None]:
# Map breeds to categories
def map_breed_to_category(breed_name):
    """Map a breed name to its macro-category.

    The name is lower-cased and hyphens/spaces are replaced by underscores.
    The keys of ``BREED_TO_CATEGORY`` are normalized the same way before
    comparison, so hyphenated breeds (e.g. "flat-coated_retriever") match
    even when the table stores them with hyphens.

    Resolution order:
        1. exact match on the normalized name;
        2. first table entry (in table order) whose normalized key contains
           the name or is contained in it;
        3. ``'mixed'`` when nothing matches or the name is empty.

    Args:
        breed_name: breed name as extracted from the dataset directory.

    Returns:
        The category name (str).
    """

    def _normalize(name):
        return name.lower().replace('-', '_').replace(' ', '_')

    breed_lower = _normalize(breed_name)

    # An empty name would be a substring of every key and pick an arbitrary
    # category; treat it as unmapped instead.
    if not breed_lower:
        return 'mixed'

    # Fast path: the table already holds this exact key
    if breed_lower in BREED_TO_CATEGORY:
        return BREED_TO_CATEGORY[breed_lower]

    # Compare against normalized keys (later duplicates win, like the table itself)
    normalized_table = {
        _normalize(key): category for key, category in BREED_TO_CATEGORY.items()
    }
    if breed_lower in normalized_table:
        return normalized_table[breed_lower]

    # Partial match, in table order
    for key, category in normalized_table.items():
        if key in breed_lower or breed_lower in key:
            return category

    # Default to 'mixed' for unmapped breeds
    return 'mixed'

# Quick sanity check of the mapping on a few well-known breeds
test_breeds = ['golden_retriever', 'German_shepherd', 'Chihuahua', 'pug', 'beagle']
print("Test mapping razze:")
for sample_breed in test_breeds:
    print(f"  {sample_breed} ‚Üí {map_breed_to_category(sample_breed)}")

In [None]:
# Build the categorized dataset.
# Both containers are initialised up front so that later cells can test
# `categorized_images` even when the dataset directory exists but is empty.
categorized_images = []
category_counts = defaultdict(int)

if STANFORD_DIR.exists() and all_images:
    # Map every image to its macro-category
    for img_path, breed in all_images:
        category = map_breed_to_category(breed)
        categorized_images.append((img_path, category))
        category_counts[category] += 1

    print("Distribuzione per categoria:")
    for cat, count in sorted(category_counts.items(), key=lambda x: -x[1]):
        # Categories without an entry in the priors table get a neutral 0.5
        prior = breed_priors.get(cat, 0.5)
        print(f"  {cat}: {count} immagini (P(stray)={prior})")

    # Plot the distribution
    plt.figure(figsize=(12, 5))
    cats = list(category_counts.keys())
    counts = [category_counts[c] for c in cats]
    priors = [breed_priors.get(c, 0.5) for c in cats]

    # Left panel: number of images per category
    plt.subplot(1, 2, 1)
    plt.bar(cats, counts, color='steelblue')
    plt.xticks(rotation=45, ha='right')
    plt.title('Immagini per Categoria')
    plt.ylabel('Numero immagini')

    # Right panel: prior per category, bar color driven by the prior value
    plt.subplot(1, 2, 2)
    colors = plt.cm.RdYlGn_r(np.array(priors))
    plt.bar(cats, priors, color=colors)
    plt.xticks(rotation=45, ha='right')
    plt.title('P(stray|breed) per Categoria')
    plt.ylabel('Probabilit√† randagio')
    plt.axhline(y=0.5, color='gray', linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.savefig(OUTPUT_DIR.parent.parent / 'training' / 'notebooks' / 'breed_distribution.png', dpi=150)
    plt.show()

## 3. Preparazione Dataset

In [None]:
class BreedDataset(Dataset):
    """Image dataset for breed-category classification (Albumentations style).

    Args:
        image_paths: sequence of image file paths.
        labels: sequence of labels, aligned with ``image_paths``.
        transform: optional callable invoked as ``transform(image=array)``
            that returns a mapping with an ``'image'`` entry (the
            Albumentations calling convention).
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded as an RGB numpy array. If loading fails for any
        reason the error is printed and an all-black 224x224x3 uint8
        placeholder is used instead, so one unreadable file does not abort
        the run. Note that the placeholder keeps the sample's label.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        # Load the image as a numpy array for Albumentations. The context
        # manager closes the underlying file as soon as the pixels are copied.
        try:
            with Image.open(img_path) as pil_image:
                image = np.array(pil_image.convert('RGB'))
        except Exception as e:  # deliberate best-effort: fall back to a placeholder
            print(f"Errore caricamento {img_path}: {e}")
            image = np.zeros((224, 224, 3), dtype=np.uint8)

        if self.transform:
            transformed = self.transform(image=image)
            image = transformed['image']

        return image, label

In [None]:
# Albumentations pipelines
IMG_SIZE = 224


def _model_input_steps():
    """Return fresh Normalize + ToTensorV2 steps shared by both pipelines.

    The mean/std values are the ones commonly used with ImageNet-pretrained
    backbones.
    """
    return [
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
        ToTensorV2(),
    ]


# Training: resize to 256, random 224 crop, then light geometric/color
# augmentation and random rectangular dropout.
_train_steps = [
    A.Resize(256, 256),
    A.RandomCrop(IMG_SIZE, IMG_SIZE),
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=15, p=0.5),
    A.ColorJitter(
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.05,
        p=0.5
    ),
    A.CoarseDropout(
        max_holes=8,
        max_height=16,
        max_width=16,
        p=0.3
    ),
]
train_transform = A.Compose(_train_steps + _model_input_steps())

# Validation/test: deterministic resize only
val_transform = A.Compose([A.Resize(IMG_SIZE, IMG_SIZE)] + _model_input_steps())

print("Trasformazioni Albumentations definite")

In [None]:
# Split dataset
# `categorized_images` may be undefined if the categorisation cell did not
# build it, so it is looked up defensively instead of raising NameError.
_labelled_images = globals().get('categorized_images') or []

if STANFORD_DIR.exists() and _labelled_images:
    # Parallel lists of paths and integer class labels
    all_paths = [str(img[0]) for img in _labelled_images]
    all_labels = [CATEGORY_TO_IDX[img[1]] for img in _labelled_images]

    # Stratified split: 70% train, then the remaining 30% halved into val/test
    train_paths, temp_paths, train_labels, temp_labels = train_test_split(
        all_paths, all_labels, test_size=0.3, stratify=all_labels, random_state=42
    )

    val_paths, test_paths, val_labels, test_labels = train_test_split(
        temp_paths, temp_labels, test_size=0.5, stratify=temp_labels, random_state=42
    )

    print(f"Train: {len(train_paths)} immagini")
    print(f"Val: {len(val_paths)} immagini")
    print(f"Test: {len(test_paths)} immagini")

    # Datasets: augmentation only on the training split
    train_dataset = BreedDataset(train_paths, train_labels, train_transform)
    val_dataset = BreedDataset(val_paths, val_labels, val_transform)
    test_dataset = BreedDataset(test_paths, test_labels, val_transform)

    # DataLoaders
    # NOTE: num_workers=0 on macOS to avoid multiprocessing problems
    BATCH_SIZE = 32
    NUM_WORKERS = 0
    # Pinned host memory only helps CUDA transfers; elsewhere it just warns.
    PIN_MEMORY = torch.cuda.is_available()
    # The classifier head uses BatchNorm1d, which cannot run in training mode
    # on a batch of one sample; drop the last training batch only in that case.
    DROP_LAST_TRAIN = len(train_dataset) % BATCH_SIZE == 1

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=DROP_LAST_TRAIN)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

    print(f"\nBatch size: {BATCH_SIZE}")
    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")
    print(f"Num workers: {NUM_WORKERS}")
else:
    print("Dataset non disponibile - usando dati sintetici per demo")
    # No data: later cells check these loaders and skip training/evaluation
    train_loader = None
    val_loader = None
    test_loader = None

## 4. Definizione Modello

In [None]:
class BreedClassifier(nn.Module):
    """EfficientNet-B0 feature extractor followed by a small MLP head.

    Args:
        num_classes: number of output categories.
        pretrained: forwarded to ``timm.create_model`` for the backbone.
        dropout: dropout probability used twice in the head.
    """

    def __init__(self, num_classes: int, pretrained: bool = True, dropout: float = 0.3):
        super().__init__()

        # timm backbone created with num_classes=0, i.e. without its own classifier
        self.backbone = timm.create_model('efficientnet_b0', pretrained=pretrained, num_classes=0)

        # Width of the feature vector reported by the backbone
        self.feature_dim = self.backbone.num_features

        # Head: dropout -> linear(256) -> ReLU -> batch norm -> dropout -> linear
        hidden_units = 256
        head_layers = [
            nn.Dropout(dropout),
            nn.Linear(self.feature_dim, hidden_units),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_units),
            nn.Dropout(dropout),
            nn.Linear(hidden_units, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        self.num_classes = num_classes

    def forward(self, x):
        """Return the raw class logits for a batch ``x``."""
        return self.classifier(self.backbone(x))

    def predict_proba(self, x):
        """Return per-class probabilities (softmax over the last dimension)."""
        return torch.softmax(self.forward(x), dim=-1)

In [None]:
# Instantiate the model: one output per category
NUM_CLASSES = len(CATEGORIES)

model = BreedClassifier(num_classes=NUM_CLASSES, pretrained=True, dropout=0.3)

# Device preference: MPS, then CUDA, otherwise CPU
device = torch.device(
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)

model = model.to(device)

# Parameter statistics
_total_params = sum(p.numel() for p in model.parameters())
_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Device: {device}")
print(f"Classi: {NUM_CLASSES}")
print(f"Feature dim: {model.feature_dim}")
print(f"\nParametri totali: {_total_params:,}")
print(f"Parametri trainabili: {_trainable_params:,}")

## 5. Training

In [None]:
# Training hyper-parameters
EPOCHS = 30
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4

# Class-weighted loss to compensate for the unbalanced dataset
class_weights = None
if train_loader:
    # Inverse-frequency weights; the +1 avoids dividing by zero for classes
    # with no training samples. Weights are rescaled to sum to NUM_CLASSES.
    class_counts = np.bincount(train_labels, minlength=NUM_CLASSES)
    _inverse_freq = 1.0 / (class_counts + 1)
    _rescaled = _inverse_freq / _inverse_freq.sum() * NUM_CLASSES
    class_weights = torch.as_tensor(_rescaled, dtype=torch.float32).to(device)
    print("Class weights:", class_weights)

criterion = nn.CrossEntropyLoss(weight=class_weights)

# AdamW with two parameter groups: the backbone learns ten times slower
# than the freshly initialised classifier head.
_param_groups = [
    {'params': model.backbone.parameters(), 'lr': LEARNING_RATE * 0.1},
    {'params': model.classifier.parameters(), 'lr': LEARNING_RATE},
]
optimizer = optim.AdamW(_param_groups, weight_decay=WEIGHT_DECAY)

# Cosine annealing over the whole run
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

print(f"\nConfigurazione training:")
print(f"  Epochs: {EPOCHS}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Weight decay: {WEIGHT_DECAY}")

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Returns:
        ``(mean_loss, accuracy_percent)`` where the mean is taken over
        batches and the accuracy over all samples seen.
    """
    model.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0

    progress = tqdm(loader, desc='Training')
    for batch_images, batch_labels in progress:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Standard step: reset gradients, forward, backward, update
        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Running statistics
        running_loss += batch_loss.item()
        predicted = logits.max(1)[1]
        n_seen += batch_labels.size(0)
        n_correct += (predicted == batch_labels).sum().item()

        progress.set_postfix({
            'loss': f'{batch_loss.item():.4f}',
            'acc': f'{100.*n_correct/n_seen:.2f}%',
        })

    mean_loss = running_loss / len(loader)
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy


def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Returns:
        ``(mean_loss, accuracy_percent, predictions, labels)``; the mean loss
        is taken over batches, and the last two items are flat lists
        collected across all batches.
    """
    model.eval()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    collected_preds = []
    collected_labels = []

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(loader, desc='Validation'):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            running_loss += criterion(logits, batch_labels).item()

            predicted = logits.max(1)[1]
            n_seen += batch_labels.size(0)
            n_correct += (predicted == batch_labels).sum().item()

            # Keep everything on the CPU for the sklearn metrics
            collected_preds.extend(predicted.cpu().numpy())
            collected_labels.extend(batch_labels.cpu().numpy())

    mean_loss = running_loss / len(loader)
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy, collected_preds, collected_labels

In [None]:
# TRAINING LOOP
# Trains for up to EPOCHS epochs, keeps the checkpoint with the best
# validation accuracy and stops early after `patience` epochs without
# improvement.
if train_loader:
    print("="*50)
    print("INIZIO TRAINING BREED CLASSIFIER")
    print("="*50)
    print(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print()

    # Per-epoch metrics, used later for the training curves
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': []
    }

    best_val_acc = 0
    patience = 10
    patience_counter = 0

    for epoch in range(EPOCHS):
        print(f"\n--- Epoch {epoch+1}/{EPOCHS} ---")

        # Training
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validation
        val_loss, val_acc, _, _ = validate(model, val_loader, criterion, device)

        # One scheduler step per epoch
        scheduler.step()

        # Log
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

        # Save the model whenever validation accuracy improves
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'categories': CATEGORIES,
                'category_to_idx': CATEGORY_TO_IDX,
                # Included so the trained checkpoint carries the same label
                # metadata as the placeholder export written by the export cell.
                'idx_to_category': IDX_TO_CATEGORY,
                'breed_priors': breed_priors
            }, OUTPUT_DIR / 'breed_classifier_best.pt')
            print(f"üíæ Nuovo best model salvato! (acc={val_acc:.2f}%)")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"\nEarly stopping after {epoch+1} epochs")
                break

    print("\n" + "="*50)
    print("TRAINING COMPLETATO!")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    print("="*50)
else:
    print("Nessun dato disponibile per training")
    print("Saltando alla sezione di export modello placeholder...")

In [None]:
# Plot the training curves (loss on the left, accuracy on the right)
if train_loader and history['train_loss']:
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # (train key, validation key, y-axis label, panel title) for each panel
    _panels = (
        ('train_loss', 'val_loss', 'Loss', 'Training Loss'),
        ('train_acc', 'val_acc', 'Accuracy (%)', 'Training Accuracy'),
    )
    for _ax, (_train_key, _val_key, _y_label, _title) in zip(axes, _panels):
        _ax.plot(history[_train_key], label='Train')
        _ax.plot(history[_val_key], label='Validation')
        _ax.set_xlabel('Epoch')
        _ax.set_ylabel(_y_label)
        _ax.set_title(_title)
        _ax.legend()

    plt.tight_layout()
    _curves_path = OUTPUT_DIR.parent.parent / 'training' / 'notebooks' / 'breed_training_curves.png'
    plt.savefig(_curves_path, dpi=150)
    plt.show()

## 6. Valutazione su Test Set

In [None]:
if test_loader:
    # Load the best checkpoint when one exists; otherwise the model is
    # evaluated with whatever weights it currently holds.
    # weights_only=False is required because the checkpoint stores plain
    # Python objects next to the tensors (PyTorch 2.6+ defaults to True).
    # NOTE: this unpickles arbitrary objects - only load checkpoints you trust.
    best_path = OUTPUT_DIR / 'breed_classifier_best.pt'
    if best_path.exists():
        checkpoint = torch.load(best_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Best model caricato (epoch {checkpoint['epoch']}, acc={checkpoint['val_acc']:.2f}%)")

    # Evaluation (note: `test_labels` is rebound to the labels returned by validate)
    test_loss, test_acc, test_preds, test_labels = validate(model, test_loader, criterion, device)

    print(f"\nTest Results:")
    print(f"  Loss: {test_loss:.4f}")
    print(f"  Accuracy: {test_acc:.2f}%")

    # Classification report.
    # `labels` is passed explicitly: without it scikit-learn infers the
    # classes from the data and raises when their number differs from
    # len(target_names), which happens whenever a category never occurs in
    # the labels or the predictions.
    print("\nClassification Report:")
    print(classification_report(
        test_labels,
        test_preds,
        labels=list(range(len(CATEGORIES))),
        target_names=CATEGORIES,
        zero_division=0,
    ))

In [None]:
# Confusion matrix
if test_loader:
    # Fix the label set so the matrix always has one row/column per category
    # and lines up with the CATEGORIES tick labels, even when some category
    # is absent from both the true labels and the predictions.
    cm = confusion_matrix(test_labels, test_preds, labels=list(range(len(CATEGORIES))))

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=CATEGORIES, yticklabels=CATEGORIES)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix - Breed Categories')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(OUTPUT_DIR.parent.parent / 'training' / 'notebooks' / 'breed_confusion_matrix.png', dpi=150)
    plt.show()

## 7. Export Modello per Produzione

In [None]:
# Save the final model for the backend
final_model_path = OUTPUT_DIR / 'breed_classifier.pt'

# If a best checkpoint exists on disk, export it as the final model
best_path = OUTPUT_DIR / 'breed_classifier_best.pt'
if best_path.exists():
    shutil.copy(best_path, final_model_path)
    print(f"Modello trainato copiato in: {final_model_path}")
else:
    # No checkpoint: save the current (not fine-tuned) weights as a placeholder
    torch.save({
        'model_state_dict': model.state_dict(),
        'categories': CATEGORIES,
        'category_to_idx': CATEGORY_TO_IDX,
        'idx_to_category': IDX_TO_CATEGORY,
        'breed_priors': breed_priors,
        'note': 'Modello non fine-tuned - usare euristiche'
    }, final_model_path)
    print(f"Modello placeholder salvato in: {final_model_path}")

# Also save the category mapping.
# The data directory may be missing (the notebook tolerates an absent
# breed_priors.json), so create it before writing.
mapping_path = OUTPUT_DIR.parent.parent / 'data' / 'breed_mapping.json'
mapping_path.parent.mkdir(parents=True, exist_ok=True)
with open(mapping_path, 'w', encoding='utf-8') as f:
    json.dump({
        'categories': CATEGORIES,
        'category_to_idx': CATEGORY_TO_IDX,
        'breed_mapping': BREED_MAPPING
    }, f, indent=2)
print(f"Mapping salvato in: {mapping_path}")

In [None]:
# Check the exported model
print("\nTest modello esportato...")

# Load the export. weights_only=False is needed because the file stores
# plain Python objects next to the tensors (PyTorch 2.6+ defaults to True).
# NOTE: this unpickles arbitrary objects - only load checkpoints you trust.
checkpoint = torch.load(final_model_path, map_location=device, weights_only=False)

# pretrained=False: every weight is overwritten by load_state_dict right
# below, so fetching the pretrained backbone weights would be wasted work
# (and would fail without network access).
test_model = BreedClassifier(num_classes=len(checkpoint['categories']), pretrained=False)
test_model.load_state_dict(checkpoint['model_state_dict'])
test_model.to(device)
test_model.eval()

# Forward pass on a random input, just to exercise the pipeline end to end
test_input = torch.randn(1, 3, 224, 224).to(device)
with torch.no_grad():
    output = test_model.predict_proba(test_input)

print(f"\nOutput shape: {output.shape}")
print(f"\nProbabilit√† per categoria (test random):")
for idx, prob in enumerate(output[0].cpu().numpy()):
    cat = checkpoint['categories'][idx]
    # Categories without a stored prior fall back to a neutral 0.5
    stray_prior = checkpoint['breed_priors'].get(cat, 0.5)
    print(f"  {cat}: {prob:.4f} (P(stray|breed)={stray_prior})")

In [None]:
# Helper to obtain P(stray|breed)
def get_stray_probability_from_breed(model, image_tensor, breed_priors, device, categories=None):
    """Predict the breed category of one image and look up its stray prior.

    The model is switched to eval mode (and left in it) and inference runs
    without gradient tracking.

    Args:
        model: object exposing ``eval()`` and ``predict_proba(batch)``.
        image_tensor: a single image, either already batched ``(1, C, H, W)``
            or unbatched ``(C, H, W)``; an unbatched image gets a leading
            batch dimension added. Batches with more than one image are not
            supported because the result is converted to Python scalars.
        breed_priors: mapping from category name to P(stray|breed).
        device: device the tensor is moved to before inference.
        categories: optional index-to-name sequence matching the model's
            outputs (e.g. the ``'categories'`` list stored in a checkpoint).
            Defaults to the notebook-level ``CATEGORIES``.

    Returns:
        predicted_category: str - predicted category
        category_prob: float - confidence of the prediction
        stray_prob: float - P(stray|breed) from the prior table, 0.5 when the
            category has no entry
    """
    if categories is None:
        categories = CATEGORIES

    model.eval()
    with torch.no_grad():
        image_tensor = image_tensor.to(device)
        if image_tensor.dim() == 3:
            # Single unbatched image: add the batch dimension the model expects
            image_tensor = image_tensor.unsqueeze(0)
        probs = model.predict_proba(image_tensor)

        # Most likely category
        top_prob, top_idx = probs.max(dim=1)
        predicted_category = categories[top_idx.item()]
        category_prob = top_prob.item()

        # Prior for that category
        stray_prob = breed_priors.get(predicted_category, 0.5)

        return predicted_category, category_prob, stray_prob

# Smoke test of the helper on the random input created for the export check
_helper_result = get_stray_probability_from_breed(
    test_model,
    test_input,
    breed_priors,
    device,
)
cat, conf, stray = _helper_result
print(f"\nPredizione: {cat} (conf={conf:.2%})")
print(f"P(stray|{cat}) = {stray}")

In [None]:
# Final summary
_rule = "="*50
print("\n" + _rule)
print("RIEPILOGO BREED CLASSIFIER")
print(_rule)
print(f"\nArchitettura: EfficientNet-B0")
print(f"Categorie: {len(CATEGORIES)}")
for _category_name in CATEGORIES:
    # Neutral 0.5 for categories missing from the priors table
    print(f"  - {_category_name}: P(stray)={breed_priors.get(_category_name, 0.5)}")

# Training results are reported only if the training loop actually ran and
# therefore defined `best_val_acc`; same for `test_acc` and the test cell.
if train_loader and 'best_val_acc' in globals():
    print(f"\nTraining completato:")
    print(f"  Best validation accuracy: {best_val_acc:.2f}%")
    if 'test_acc' in globals():
        print(f"  Test accuracy: {test_acc:.2f}%")
else:
    print(f"\nModello non fine-tuned")
    print("  Il backend user√† euristiche fino al training")

print(f"\nFile salvati:")
for _saved_path in (final_model_path, mapping_path):
    print(f"  - {_saved_path}")
