# 02 - Training Skin Disease Classifier

Fine-tuning di ResNet50 per la classificazione di patologie cutanee nei cani.

## Dataset
- **Dog's Skin Diseases** (Kaggle): 4.315 immagini (train 3.022 / valid 860 / test 433), 6 classi
  - Healthy
  - Dermatitis
  - Fungal_infections
  - Hypersensitivity
  - demodicosis (nome cartella in minuscolo)
  - ringworm (nome cartella in minuscolo)

## Output
- `P(disease)` ∈ [0, 1] - probabilità di patologia cutanea, calcolata come 1 − P(Healthy) (Healthy è la classe 0)

In [1]:
# Installazione dipendenze
%pip install torch torchvision timm albumentations matplotlib seaborn pandas scikit-learn tqdm -q


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [None]:
import os
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from tqdm.auto import tqdm
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import timm

# Albumentations per data augmentation (come da documentazione)
import albumentations as A
from albumentations.pytorch import ToTensorV2

from sklearn.metrics import classification_report, confusion_matrix, f1_score, accuracy_score
from PIL import Image
import cv2

# Environment report, so results can be tied to library versions.
for _label, _value in (
    ("Python", sys.version),
    ("PyTorch", torch.__version__),
    ("CUDA available", torch.cuda.is_available()),
    ("MPS available", torch.backends.mps.is_available()),
):
    print(f"{_label}: {_value}")

In [None]:
# ============================================================================
# HARDWARE CONFIGURATION - Multi-GPU support + reproducibility
# ============================================================================
import random

import torch

# Global seed. It makes DataLoader shuffling (num_workers=0 uses the global torch
# generator) and the random init of the new classifier head repeatable.
# NOTE(review): Albumentations may keep its own RNG depending on version - TODO
# confirm whether A.Compose needs an explicit seed. Bit-exact CUDA determinism
# would additionally need cuDNN settings, which are not enforced here.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the generators of all devices

if torch.cuda.is_available():
    NUM_GPUS = torch.cuda.device_count()
    GPU_NAMES = [torch.cuda.get_device_name(i) for i in range(NUM_GPUS)]
    GPU_VRAM_GB = [torch.cuda.get_device_properties(i).total_memory / 1e9 for i in range(NUM_GPUS)]
    TOTAL_VRAM = sum(GPU_VRAM_GB)

    print(f"GPU rilevate: {NUM_GPUS}")
    for i, (name, vram) in enumerate(zip(GPU_NAMES, GPU_VRAM_GB)):
        print(f"   [{i}] {name} ({vram:.1f}GB)")
    print(f"   VRAM totale: {TOTAL_VRAM:.1f}GB")

    DEVICE = torch.device('cuda')
    # With 2+ GPUs the model is later wrapped in nn.DataParallel.
    USE_MULTI_GPU = NUM_GPUS >= 2

    if USE_MULTI_GPU:
        print(f"\n[OK] Multi-GPU attivo: {NUM_GPUS} GPU")

elif torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
    NUM_GPUS = 1
    USE_MULTI_GPU = False
    print("Apple Silicon MPS")
else:
    DEVICE = torch.device('cpu')
    NUM_GPUS = 0
    USE_MULTI_GPU = False
    print("[WARN] Nessuna GPU, usando CPU")

print(f"\nUsing device: {DEVICE}")

In [None]:
# Path configuration - RELATIVE paths for portability.
import sys
sys.path.insert(0, str(Path.cwd()))
try:
    from notebook_utils import get_paths, get_device, print_paths
    paths = get_paths()
    print_paths(paths)
except ImportError:
    print("notebook_utils.py non trovato, usando fallback...")
    NOTEBOOK_DIR = Path.cwd()
    if NOTEBOOK_DIR.name == "notebooks":
        # Layout: <project>/training/notebooks
        PROJECT_DIR = NOTEBOOK_DIR.parent.parent
    elif NOTEBOOK_DIR.name == "training":
        PROJECT_DIR = NOTEBOOK_DIR.parent
    else:
        # Look for a "ResQPet" ancestor (the current directory included). If there
        # is none, stay in the current directory. Walking all the way up would
        # end at the filesystem root, and creating backend/weights there fails
        # or pollutes "/".
        PROJECT_DIR = next(
            (p for p in (NOTEBOOK_DIR, *NOTEBOOK_DIR.parents) if p.name == "ResQPet"),
            NOTEBOOK_DIR,
        )
    BASE_DIR = PROJECT_DIR.parent
    paths = {
        'project_dir': PROJECT_DIR,
        'base_dir': BASE_DIR,
        'weights_dir': PROJECT_DIR / "backend" / "weights",
        'skin_dataset': BASE_DIR / "Dog's skin diseases",
        'notebooks_dir': PROJECT_DIR / "training" / "notebooks",
    }
    paths['weights_dir'].mkdir(parents=True, exist_ok=True)
    # Figures are saved into notebooks_dir later, and savefig does not create
    # missing directories.
    paths['notebooks_dir'].mkdir(parents=True, exist_ok=True)

# Module-level aliases kept for backward compatibility
BASE_DIR = paths['base_dir']
DATASET_DIR = paths['skin_dataset']
OUTPUT_DIR = paths['weights_dir']

print(f"\nDataset: {DATASET_DIR}")
print(f"Output: {OUTPUT_DIR}")
print(f"Dataset exists: {DATASET_DIR.exists()}")

## 1. Esplorazione Dataset

In [5]:
# Explore the dataset structure
def explore_dataset(dataset_dir, extensions=('.jpg', '.jpeg', '.png')):
    """Count images per class for each split of an ImageFolder-style dataset.

    Expected layout: ``dataset_dir/<split>/<class_name>/<image files>`` for the
    splits 'train', 'valid' and 'test'. Splits whose directory is missing are
    skipped.

    Args:
        dataset_dir: Dataset root directory (str or Path).
        extensions: File suffixes (case-insensitive) that count as images. The
            default matches the suffix filter used by SkinDiseaseDataset, so the
            statistics match the samples actually loaded for training.

    Returns:
        dict mapping split name -> {class_folder_name: image_count}. Classes are
        listed in sorted order, so the output is deterministic.
    """
    dataset_dir = Path(dataset_dir)
    allowed = {ext.lower() for ext in extensions}
    stats = {}

    for split in ['train', 'valid', 'test']:
        split_dir = dataset_dir / split
        if not split_dir.exists():
            continue

        # Count only real image files. A plain '*.*' glob would also count
        # stray files such as notes or archives.
        stats[split] = {
            class_dir.name: sum(
                1 for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in allowed
            )
            for class_dir in sorted(split_dir.iterdir())
            if class_dir.is_dir()
        }

    return stats

dataset_stats = explore_dataset(DATASET_DIR)

# Per-split, per-class image counts followed by the split total.
print("Dataset Statistics:")
for split_name, per_class in dataset_stats.items():
    print(f"\n{split_name.upper()}:")
    for class_name, n_images in per_class.items():
        print(f"  {class_name}: {n_images}")
    print(f"  TOTAL: {sum(per_class.values())}")

Dataset Statistics:

TRAIN:
  Healthy: 492
  Fungal_infections: 375
  ringworm: 791
  demodicosis: 588
  Hypersensitivity: 230
  Dermatitis: 546
  TOTAL: 3022

VALID:
  Healthy: 139
  Fungal_infections: 97
  ringworm: 212
  demodicosis: 174
  Hypersensitivity: 63
  Dermatitis: 175
  TOTAL: 860

TEST:
  Healthy: 69
  Fungal_infections: 54
  ringworm: 115
  demodicosis: 100
  Hypersensitivity: 29
  Dermatitis: 66
  TOTAL: 433


In [None]:
# Class definitions and mapping - EXPLICIT, FIXED ORDER.
# IMPORTANT: Healthy MUST be class 0, because P(disease) = 1 - P(Healthy).
# Names must match the dataset folder names exactly. Note the lower-case
# 'demodicosis' and 'ringworm': SkinDiseaseDataset skips folders not listed here.
CLASS_NAMES = ['Healthy', 'Dermatitis', 'Fungal_infections', 'Hypersensitivity', 'demodicosis', 'ringworm']
NUM_CLASSES = len(CLASS_NAMES)
CLASS_TO_IDX = {cls: idx for idx, cls in enumerate(CLASS_NAMES)}
IDX_TO_CLASS = dict(enumerate(CLASS_NAMES))

print(f"Classi (ordine fisso): {CLASS_NAMES}")
print(f"Numero classi: {NUM_CLASSES}")
print(f"Mapping: {CLASS_TO_IDX}")

# Per-class disease severity weights, stored in the exported checkpoint.
DISEASE_WEIGHTS = {
    'Healthy': 0.0,
    'Dermatitis': 0.6,
    'Fungal_infections': 0.7,
    'Hypersensitivity': 0.4,
    'demodicosis': 0.8,
    'ringworm': 0.75
}

# Fail fast if an edit breaks the invariants the rest of the notebook relies on.
assert CLASS_NAMES[0] == 'Healthy', "Healthy must be class 0"
assert len(set(CLASS_NAMES)) == NUM_CLASSES, "duplicate class names"
assert set(DISEASE_WEIGHTS) == set(CLASS_NAMES), "DISEASE_WEIGHTS keys must match CLASS_NAMES"
assert DISEASE_WEIGHTS['Healthy'] == 0.0, "Healthy must have zero disease weight"

print(f"\nDisease weights: {DISEASE_WEIGHTS}")

In [None]:
# Plot the per-class image counts of each split side by side.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for ax, (split_name, per_class) in zip(axes, dataset_stats.items()):
    ax.bar(per_class.keys(), per_class.values())
    ax.set_title(f'{split_name.upper()} Distribution')
    ax.set(xlabel='Class', ylabel='Count')
    ax.tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.savefig(paths['notebooks_dir'] / 'skin_class_distribution.png', dpi=150)
plt.show()

In [None]:
# Show one example image per class: the first readable image in sorted order.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for ax, class_name in zip(axes, CLASS_NAMES):
    class_dir = DATASET_DIR / 'train' / class_name
    candidates = (
        sorted(p for p in class_dir.glob('*.*') if p.suffix.lower() in ('.jpg', '.jpeg', '.png'))
        if class_dir.exists() else []
    )
    for img_path in candidates:
        img = cv2.imread(str(img_path))
        # cv2.imread returns None (no exception) for unreadable/corrupt files.
        if img is not None:
            ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            ax.set_title(class_name)
            break
    else:
        ax.set_title(f'{class_name} (nessuna immagine)')
    ax.axis('off')

plt.tight_layout()
plt.savefig(paths['notebooks_dir'] / 'skin_samples.png', dpi=150)
plt.show()

## 2. Dataset e DataLoader

In [None]:
class SkinDiseaseDataset(Dataset):
    """Skin-disease classification dataset compatible with Albumentations.

    Expected layout: ``root_dir/<split>/<class_folder>/<image>`` with
    .jpg/.jpeg/.png files. Only folders whose name is a key of ``class_to_idx``
    are loaded; any other folder is ignored.

    Args:
        root_dir: Dataset root directory.
        split: Sub-directory to load (e.g. 'train', 'valid', 'test').
        transform: Optional Albumentations pipeline. It is called as
            ``transform(image=rgb_array)`` and must return a dict with 'image'.
        class_to_idx: Mapping folder name -> integer label. Defaults to the
            notebook-level ``CLASS_TO_IDX``.

    Each item is ``(image, label)``. ``image`` is the transform output, or an
    RGB numpy array when no transform is given.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, split='train', transform=None, class_to_idx=None):
        self.root_dir = Path(root_dir) / split
        self.transform = transform
        self.class_to_idx = CLASS_TO_IDX if class_to_idx is None else class_to_idx
        self.samples = []

        # Sorted traversal. iterdir()/glob() order depends on the filesystem, so
        # sorting makes the sample order (and eval outputs) reproducible.
        for class_dir in sorted(self.root_dir.iterdir()):
            if not (class_dir.is_dir() and class_dir.name in self.class_to_idx):
                continue
            class_idx = self.class_to_idx[class_dir.name]
            for img_path in sorted(class_dir.glob('*.*')):
                if img_path.suffix.lower() in self.IMAGE_EXTENSIONS:
                    self.samples.append((str(img_path), class_idx))

        print(f"Loaded {len(self.samples)} samples from {split}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]

        # OpenCV loads BGR. Convert to RGB numpy for Albumentations.
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals unreadable/corrupt files by returning None.
            # Raise a clear error instead of an obscure cvtColor assertion.
            raise IOError(f"Impossibile leggere l'immagine: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            # Albumentations takes named arguments and returns a dict.
            image = self.transform(image=image)['image']

        return image, label

In [None]:
# Data transforms with Albumentations.
IMG_SIZE = 224

# ImageNet normalization statistics, as expected by the pretrained ResNet50.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Gaussian noise: pixel variance in [10, 50] on the 0-255 scale, i.e. a std of
# about 3.2..7.1 grey levels. Albumentations 2.x replaced `var_limit` with
# `std_range` (std as a fraction of the max value, 255 for uint8). The install
# cell does not pin a version, so pick the argument the installed release uses
# rather than passing one it may not apply.
if int(A.__version__.split('.')[0]) >= 2:
    GAUSS_NOISE_KWARGS = {'std_range': (10 ** 0.5 / 255, 50 ** 0.5 / 255)}
else:
    GAUSS_NOISE_KWARGS = {'var_limit': (10, 50)}

# Training: resize slightly larger, then random-crop plus geometric and color augmentation.
train_transform = A.Compose([
    A.Resize(256, 256),
    A.RandomCrop(IMG_SIZE, IMG_SIZE),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.3),
    A.Rotate(limit=30, p=0.5),
    A.ColorJitter(
        brightness=0.3,
        contrast=0.3,
        saturation=0.3,
        hue=0.1,
        p=0.5
    ),
    A.GaussNoise(p=0.3, **GAUSS_NOISE_KWARGS),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2()
])

# Validation/test: deterministic resize + normalization only.
val_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2()
])

print("Transforms Albumentations definiti (come da documentazione)")

In [None]:
# Build the datasets
train_dataset = SkinDiseaseDataset(DATASET_DIR, split='train', transform=train_transform)
val_dataset = SkinDiseaseDataset(DATASET_DIR, split='valid', transform=val_transform)
test_dataset = SkinDiseaseDataset(DATASET_DIR, split='test', transform=val_transform)

# ============================================================================
# DataLoaders - batch size scaled for multi-GPU
# NOTE: NUM_WORKERS = 0 avoids multiprocessing problems inside containers/servers
# ============================================================================
PER_GPU_BATCH_SIZE = 64
if torch.cuda.is_available() and NUM_GPUS >= 2:
    # nn.DataParallel splits every batch across the GPUs.
    BATCH_SIZE = PER_GPU_BATCH_SIZE * NUM_GPUS
    print(f"[OK] Multi-GPU: batch_size={BATCH_SIZE} ({PER_GPU_BATCH_SIZE} x {NUM_GPUS} GPU)")
elif torch.cuda.is_available():
    BATCH_SIZE = PER_GPU_BATCH_SIZE
else:
    BATCH_SIZE = 32
NUM_WORKERS = 0

# Page-locked host memory only speeds up host->CUDA copies. On CPU/MPS it brings
# no benefit, and recent PyTorch versions emit a warning for it.
PIN_MEMORY = torch.cuda.is_available()

_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

print(f"\nTrain batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")
print(f"Batch size: {BATCH_SIZE}, Workers: {NUM_WORKERS}")

## 3. Modello ResNet50

In [None]:
# Build the model with timm
def create_model(num_classes, pretrained=True):
    """Return a timm 'resnet50' whose classifier head outputs `num_classes` logits.

    Args:
        num_classes: Size of the classification head.
        pretrained: If True, timm loads its pretrained weights for the backbone.
    """
    model_kwargs = dict(pretrained=pretrained, num_classes=num_classes)
    return timm.create_model('resnet50', **model_kwargs)

def freeze_backbone(model, head_name='fc'):
    """Freeze the backbone so that only the classifier head stays trainable.

    Args:
        model: The network, or an nn.DataParallel wrapper. For a wrapper, its
            `.module` is used.
        head_name: Attribute name of the classifier-head submodule. 'fc' is the
            head of the ResNet built by `create_model`.

    Every parameter outside ``<head_name>.*`` gets ``requires_grad = False``.
    Head parameters are left untouched.
    """
    m = model.module if hasattr(model, 'module') else model
    head_prefix = head_name + '.'
    for name, param in m.named_parameters():
        # Use an exact prefix match. A substring test ('fc' in name) would also
        # keep layers such as squeeze-excitation 'se.fc1' trainable.
        if not name.startswith(head_prefix):
            param.requires_grad = False
    print("Backbone congelato - solo classifier head trainabile")

def unfreeze_backbone(model):
    """Make every parameter trainable again, for full fine-tuning.

    Plain modules and nn.DataParallel wrappers are both accepted. A wrapper is
    unwrapped through its `.module` attribute.
    """
    target = getattr(model, 'module', model)
    for p in target.parameters():
        p.requires_grad_(True)
    print("Backbone sbloccato - fine-tuning completo")

def count_trainable_params(model):
    """Return the total number of scalar parameters with requires_grad=True.

    nn.DataParallel wrappers are unwrapped through their `.module` attribute.
    """
    inner = getattr(model, 'module', model)
    total = 0
    for p in inner.parameters():
        if p.requires_grad:
            total += p.numel()
    return total

# Instantiate the pretrained model and move it to the selected device.
model = create_model(NUM_CLASSES, pretrained=True).to(DEVICE)

# ============================================================================
# Multi-GPU: DataParallel wrapper
# ============================================================================
if USE_MULTI_GPU:
    model = nn.DataParallel(model)
    print(f"[OK] DataParallel attivo su {NUM_GPUS} GPU")

total_params = sum(p.numel() for p in model.parameters())
print(f"\nModel created: ResNet50")
print(f"Total parameters: {total_params:,}")
print(f"Trainable: {count_trainable_params(model):,}")

## 4. Training

In [None]:
# Training configuration
EPOCHS = 50
LEARNING_RATE = 0.001       # high LR for phase 1 (frozen backbone)
FINE_TUNE_LR = 1e-5         # low LR for phase 2 (fine-tuning)
WEIGHT_DECAY = 1e-4
PATIENCE = 10               # early stopping: epochs without val-F1 improvement

# Freeze strategy
FREEZE_BACKBONE = True      # phase 1: frozen backbone
UNFREEZE_EPOCH = 10         # unfreeze after 10 epochs

print(f"Training config:")
print(f"  - Epochs: {EPOCHS}")
print(f"  - Fase 1 (epochs 1-{UNFREEZE_EPOCH}): backbone frozen, LR={LEARNING_RATE}")
print(f"  - Fase 2 (epochs {UNFREEZE_EPOCH+1}-{EPOCHS}): fine-tuning, LR={FINE_TUNE_LR}")

# Class weights for the imbalanced dataset: max_count / class_count. The most
# frequent class gets 1.0 and rarer classes get proportionally more.
from collections import Counter

label_counts = Counter(label for _, label in train_dataset.samples)
train_counts = [label_counts[i] for i in range(NUM_CLASSES)]
missing_classes = [CLASS_NAMES[i] for i, count in enumerate(train_counts) if count == 0]
if missing_classes:
    # An empty class usually means its folder name does not match CLASS_NAMES
    # exactly (the dataset skips unknown folders). Without this check the
    # weight computation below fails with ZeroDivisionError.
    raise ValueError(
        f"Nessun campione di training per le classi: {missing_classes} "
        f"- verifica i nomi delle cartelle del dataset"
    )
class_weights = torch.FloatTensor([max(train_counts) / c for c in train_counts]).to(DEVICE)
print(f"\nClass weights: {class_weights}")

# Loss
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Freeze the backbone for phase 1
if FREEZE_BACKBONE:
    freeze_backbone(model)
    print(f"\nParametri trainabili (fase 1): {count_trainable_params(model):,}")

# Optimizer over trainable parameters only
optimizer = optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY
)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

In [14]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one training epoch over `loader`.

    Puts the model in train mode and does one optimizer step per batch.

    Returns:
        (mean of the per-batch losses, accuracy over all samples seen)
    """
    model.train()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    for batch_images, batch_labels in tqdm(loader, desc='Training'):
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        predicted = logits.max(dim=1).indices
        n_seen += batch_labels.size(0)
        n_correct += (predicted == batch_labels).sum().item()

    return running_loss / len(loader), n_correct / n_seen


def evaluate(model, loader, criterion, device):
    """Evaluate `model` on `loader` in eval mode without gradient tracking.

    Returns:
        (mean of the per-batch losses, accuracy, macro-F1, predictions, labels).
        The last two are per-sample lists in loader order.
    """
    model.eval()
    running_loss = 0.0
    preds, targets = [], []

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(loader, desc='Evaluating'):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            running_loss += criterion(logits, batch_labels).item()

            preds.extend(logits.max(dim=1).indices.cpu().numpy())
            targets.extend(batch_labels.cpu().numpy())

    accuracy = accuracy_score(targets, preds)
    f1 = f1_score(targets, preds, average='macro')
    return running_loss / len(loader), accuracy, f1, preds, targets

In [None]:
# Training loop with the two-phase freeze/unfreeze strategy
print("="*50)
print("INIZIO TRAINING SKIN CLASSIFIER")
print("="*50)
print(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"\nStrategia: Two-phase training")
print(f"  Fase 1 (epoch 1-{UNFREEZE_EPOCH}): Backbone frozen, LR={LEARNING_RATE}")
print(f"  Fase 2 (epoch {UNFREEZE_EPOCH+1}+): Fine-tuning completo, LR={FINE_TUNE_LR}")
print()

history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': [], 'val_f1': []
}

# Start below any reachable F1 (F1 >= 0) so the first epoch always writes a
# checkpoint. The evaluation section loads best_model_path unconditionally.
best_f1 = -1.0
patience_counter = 0
best_model_path = OUTPUT_DIR / 'skin_classifier_best.pt'

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")

    # ===== PHASE 2: unfreeze the backbone after UNFREEZE_EPOCH epochs =====
    if epoch == UNFREEZE_EPOCH:
        print("\n" + "="*40)
        print("FASE 2: Fine-tuning completo")
        print("="*40)

        # Unfreeze the backbone
        unfreeze_backbone(model)
        print(f"Parametri trainabili: {count_trainable_params(model):,}")

        # New optimizer and scheduler with the reduced fine-tuning LR
        optimizer = optim.AdamW(
            model.parameters(),
            lr=FINE_TUNE_LR,
            weight_decay=WEIGHT_DECAY
        )
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

        # Restart early-stopping patience for the new phase. Epochs without
        # improvement in the frozen phase must not count here, otherwise a
        # phase-1 plateau could stop training after one fine-tuning epoch.
        patience_counter = 0
        print(f"Nuovo LR: {FINE_TUNE_LR}")
        print()

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, DEVICE)

    # Validate
    val_loss, val_acc, val_f1, _, _ = evaluate(model, val_loader, criterion, DEVICE)

    # Scheduler reacts to the validation loss
    scheduler.step(val_loss)

    # Log
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_f1'].append(val_f1)

    phase = "Fase 1 (frozen)" if epoch < UNFREEZE_EPOCH else "Fase 2 (fine-tune)"
    print(f"  [{phase}]")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}")
    print(f"  LR: {optimizer.param_groups[0]['lr']:.6f}")

    # Save the best model by validation macro-F1 (DataParallel-safe)
    if val_f1 > best_f1:
        best_f1 = val_f1
        patience_counter = 0
        # Save the unwrapped module's state_dict so keys carry no 'module.' prefix
        model_state = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
        torch.save({
            'epoch': epoch,
            'model_state_dict': model_state,
            'optimizer_state_dict': optimizer.state_dict(),
            'val_f1': val_f1,
            'val_acc': val_acc,
            'class_names': CLASS_NAMES,
            'num_classes': NUM_CLASSES
        }, best_model_path)
        print(f"  ✓ Best model saved (F1: {val_f1:.4f})")
    else:
        patience_counter += 1
        # Early stopping only applies during phase 2
        if patience_counter >= PATIENCE and epoch >= UNFREEZE_EPOCH:
            print(f"\nEarly stopping at epoch {epoch+1}")
            break

print("\n" + "="*50)
print("TRAINING COMPLETATO!")
print("="*50)

## 5. Valutazione Finale

In [None]:
# Load the best checkpoint.
# NOTE: weights_only=False is needed with PyTorch 2.6+ because the checkpoint
# also holds non-tensor metadata. This unpickles arbitrary objects, so only
# load checkpoints you produced yourself.
checkpoint = torch.load(best_model_path, map_location=DEVICE, weights_only=False)
# The state_dict was saved from the unwrapped module (no 'module.' prefix), so
# load it into the unwrapped module as well. Loading it directly into an
# nn.DataParallel wrapper fails on the key mismatch.
(model.module if hasattr(model, 'module') else model).load_state_dict(checkpoint['model_state_dict'])
print(f"Best model loaded from epoch {checkpoint['epoch']+1}")
print(f"Val F1: {checkpoint['val_f1']:.4f}")

In [None]:
# Final evaluation on the held-out test split
test_loss, test_acc, test_f1, test_preds, test_labels = evaluate(model, test_loader, criterion, DEVICE)

print(f"\nTest Results:")
for metric_name, metric_value in (("Loss", test_loss), ("Accuracy", test_acc), ("F1 Score", test_f1)):
    print(f"  {metric_name}: {metric_value:.4f}")

In [None]:
# Classification report. `labels` pins the report to all NUM_CLASSES classes,
# keeping it aligned with `target_names` even if some class never occurs in the
# test labels or predictions (otherwise sklearn raises a size-mismatch error).
print("\nClassification Report:")
print(classification_report(
    test_labels,
    test_preds,
    labels=list(range(NUM_CLASSES)),
    target_names=CLASS_NAMES,
    zero_division=0,
))

In [None]:
# Confusion matrix (rows = true class, columns = predicted class).
# `labels` fixes a NUM_CLASSES x NUM_CLASSES matrix in CLASS_NAMES order, so the
# tick labels stay correct even if a class is absent from the test results.
cm = confusion_matrix(test_labels, test_preds, labels=list(range(NUM_CLASSES)))

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.tight_layout()
plt.savefig(paths['notebooks_dir'] / 'skin_confusion_matrix.png', dpi=150)
plt.show()

In [None]:
# Training curves. The x-axis is the 1-based epoch number, as in the training log.
epochs_axis = list(range(1, len(history['train_loss']) + 1))
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# (title, y-axis label, [(legend label, values), ...]) for each panel
panels = [
    ('Loss', 'Loss', [('Train', history['train_loss']), ('Val', history['val_loss'])]),
    ('Accuracy', 'Accuracy', [('Train', history['train_acc']), ('Val', history['val_acc'])]),
    ('Validation F1 Score', 'Macro F1', [('Val F1', history['val_f1'])]),
]
for ax, (title, ylabel, series) in zip(axes, panels):
    for label, values in series:
        ax.plot(epochs_axis, values, label=label)
    # Mark the switch to phase 2 (backbone unfrozen), if training got that far.
    if len(epochs_axis) > UNFREEZE_EPOCH:
        ax.axvline(UNFREEZE_EPOCH + 0.5, color='gray', linestyle='--', label='Unfreeze')
    ax.set(title=title, xlabel='Epoch', ylabel=ylabel)
    ax.legend()

plt.tight_layout()
plt.savefig(paths['notebooks_dir'] / 'skin_training_curves.png', dpi=150)
plt.show()

## 6. Export Modello

In [None]:
# Save the final model (currently holding the best checkpoint's weights).
# DataParallel-safe.
final_model_path = OUTPUT_DIR / 'skin_classifier.pt'

# Unwrap DataParallel so state_dict keys carry no 'module.' prefix
export_module = model.module if hasattr(model, 'module') else model
export_payload = {
    'model_state_dict': export_module.state_dict(),
    'class_names': CLASS_NAMES,
    'num_classes': NUM_CLASSES,
    'disease_weights': DISEASE_WEIGHTS,
    'test_accuracy': test_acc,
    'test_f1': test_f1
}
torch.save(export_payload, final_model_path)

size_mb = final_model_path.stat().st_size / 1024 / 1024
print(f"Modello salvato in: {final_model_path}")
print(f"Dimensione: {size_mb:.2f} MB")

In [None]:
# Smoke-test the exported checkpoint on a single test image
print("\nTest modello esportato...")

# weights_only=False is needed with PyTorch 2.6+ (non-tensor metadata in the file).
# map_location='cpu' because test_model below stays on the CPU.
loaded_checkpoint = torch.load(final_model_path, map_location='cpu', weights_only=False)
test_model = create_model(loaded_checkpoint['num_classes'], pretrained=False)
test_model.load_state_dict(loaded_checkpoint['model_state_dict'])
test_model.eval()

# Pick the first supported image of class 0 (Healthy), deterministically
test_img_path = next(
    (p for p in sorted((DATASET_DIR / 'test' / CLASS_NAMES[0]).glob('*.*'))
     if p.suffix.lower() in ('.jpg', '.jpeg', '.png')),
    None,
)
if test_img_path is None:
    raise FileNotFoundError(f"Nessuna immagine in {DATASET_DIR / 'test' / CLASS_NAMES[0]}")

# val_transform is an Albumentations pipeline. It must be called with a named
# `image=` RGB numpy array, as in SkinDiseaseDataset, not with a PIL image.
test_bgr = cv2.imread(str(test_img_path))
if test_bgr is None:
    raise IOError(f"Impossibile leggere l'immagine: {test_img_path}")
test_img = cv2.cvtColor(test_bgr, cv2.COLOR_BGR2RGB)
test_tensor = val_transform(image=test_img)['image'].unsqueeze(0)

with torch.no_grad():
    output = test_model(test_tensor)
    probs = torch.softmax(output, dim=1)[0]
    pred_idx = probs.argmax().item()
    pred_class = CLASS_NAMES[pred_idx]
    pred_conf = probs[pred_idx].item()

print(f"\nTest image: {test_img_path.name}")
print(f"Prediction: {pred_class} ({pred_conf:.2%})")
print(f"\nProbabilità per classe:")
for i, cls in enumerate(CLASS_NAMES):
    print(f"  {cls}: {probs[i].item():.2%}")

In [None]:
# Final summary of data, training and test results
banner = "=" * 50
print("\n" + banner)
print("RIEPILOGO TRAINING SKIN CLASSIFIER")
print(banner)
print(f"\nDataset: {DATASET_DIR}")
print(f"  - Classi: {NUM_CLASSES}")
for split_label, split_ds in (("Train", train_dataset), ("Val", val_dataset), ("Test", test_dataset)):
    print(f"  - {split_label} samples: {len(split_ds)}")
print(f"\nTraining:")
print(f"  - Epochs: {len(history['train_loss'])}")
print(f"  - Best epoch: {checkpoint['epoch']+1}")
print(f"\nRisultati Test:")
print(f"  - Accuracy: {test_acc:.4f}")
print(f"  - F1 Score: {test_f1:.4f}")
print(f"\nModello salvato: {final_model_path}")
print("\nProssimo step: 03_pose_training.ipynb")