# 🧠 4. Entrenament del model – DWTFormer

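Aquest notebook entrena el model `DWTFormer` sobre els datasets `PathMNIST`, `ChestMNIST` (multietiqueta) i `OrganAMNIST` de MedMNIST i n'avalua el rendiment al conjunt de test amb mètriques com la pèrdua, l’accuracy i l’F1-score per classe; per als datasets multiclasse (`PathMNIST` i `OrganAMNIST`) també genera la matriu de confusió i les corbes ROC.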

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import os

from src.data.dataset_loader import load_medmnist_dataset
from src.model.dwtformer import DWTFormer
from src.train.train_model import train
from src.train.evaluate import (
    evaluate_model,
    print_f1_per_class,
    plot_confusion,
    plot_multiclass_roc
)

# 🔧 Global parameters
# Training hyperparameters shared by every dataset run below.
BATCH_SIZE = 128
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of model outputs per dataset (classes for the multiclass datasets,
# labels for the multi-label ChestMNIST). The keys are the dataset identifiers
# passed to the data loader.
NUM_CLASSES = {
    "pathmnist": 9,
    "chestmnist": 14,
    "organamnist": 11,
}
# Datasets to process, in order. Derived from NUM_CLASSES (dicts preserve
# insertion order) so the list and the class-count table cannot drift apart.
DATASETS = list(NUM_CLASSES)


In [None]:
# Create every output folder up front: "model" for the checkpoints, then one
# metrics folder per dataset. exist_ok=True makes re-running this cell safe.
output_dirs = ["model"] + [f"annexos/metrics/{name}" for name in DATASETS]
for out_dir in output_dirs:
    os.makedirs(out_dir, exist_ok=True)


In [None]:
# 🔁 Full training run for each dataset
# For every dataset: choose the loss that matches its label type, build fresh
# loaders / model / optimizer, train, evaluate on the test split, and (for the
# multiclass datasets) save the evaluation figures under annexos/metrics/<dataset>/.
for dataset in DATASETS:
    print(f"\n📁 Processant {dataset.upper()}")

    # ChestMNIST is treated as multi-label (independent sigmoid per label);
    # the other datasets are single-label multiclass problems.
    if dataset == "chestmnist":
        criterion = nn.BCEWithLogitsLoss()
        multilabel = True
    else:
        criterion = nn.CrossEntropyLoss()
        multilabel = False

    # Looked up / built once per dataset and shared by the model, the F1
    # report and both plots (previously rebuilt at every use).
    n_classes = NUM_CLASSES[dataset]
    class_names = [str(i) for i in range(n_classes)]

    train_loader, val_loader, test_loader = load_medmnist_dataset(dataset, batch_size=BATCH_SIZE)
    model = DWTFormer(num_classes=n_classes).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    print(f"🔍 Dataset: {dataset} | Multilabel: {multilabel} | Loss: {type(criterion).__name__}")

    history = train(
        model, train_loader, val_loader, criterion, optimizer,
        num_epochs=NUM_EPOCHS,
        device=DEVICE,
        save_path=f"model/dwtformer_{dataset}.pt",
        multilabel=multilabel
    )

    # 🧪 Evaluation
    # NOTE(review): this evaluates the weights currently held by `model` after
    # `train` returns. If `train` writes only the best-validation checkpoint to
    # save_path, reload it here before evaluating — confirm in train_model.py.
    test_loss, y_true, y_pred, y_scores = evaluate_model(
        model, test_loader, criterion, DEVICE, multilabel=multilabel
    )

    print(f"\n📊 Resultats per {dataset.upper()}")
    print(f"Test Loss: {test_loss:.4f}")
    print_f1_per_class(y_true, y_pred, class_names=class_names)

    # 📊 Plots only for the multiclass datasets
    if not multilabel:
        metrics_dir = f"annexos/metrics/{dataset}"

        # NOTE(review): assumes plot_confusion / plot_multiclass_roc draw on
        # the current pyplot figure; if they create their own figure, the
        # figure saved here would be empty — confirm in src/train/evaluate.py.
        fig1 = plt.figure()
        plot_confusion(y_true, y_pred, class_names=class_names)
        # bbox_inches="tight" keeps tick/axis labels from being clipped.
        fig1.savefig(f"{metrics_dir}/confusion_matrix.png", bbox_inches="tight")
        plt.close(fig1)

        fig2 = plt.figure()
        plot_multiclass_roc(
            y_true, y_scores,
            n_classes=n_classes,
            class_names=class_names
        )
        fig2.savefig(f"{metrics_dir}/roc_curve_multiclass.png", bbox_inches="tight")
        plt.close(fig2)

    print(f"✅ Resultats guardats per {dataset}")