In [None]:
import os
import sys
import json
import torch
import random
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torch.utils.data import DataLoader, Subset

# Percorsi locali ai moduli
sys.path.append("src")
from src.dataset_loader import get_dataloader
from src.metrics import compute_metrics

# ✅ (Optional) enable Weights & Biases logging.
# Logging is best-effort: if wandb is missing OR its run cannot be started
# (e.g. no API key configured, no network), training proceeds without it.
USE_WANDB = False
try:
    import wandb
except ImportError:
    pass
else:
    try:
        wandb.init(project="IDC-binary-classification", name="resnet18_run")
        USE_WANDB = True
    except Exception as err:  # top-level boundary: optional feature must not abort training
        print(f"⚠️ wandb disabled (init failed): {err}")

# ⚙️ General configuration
# Let cuDNN benchmark and cache the fastest convolution algorithms; this pays off
# when input shapes stay constant across batches (presumably the case here).
torch.backends.cudnn.benchmark = True
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"💻 Using device: {device}")

# 📦 Dataset loading
train_path = "datasets/dataset_prepared/train"
BATCH_SIZE = 32
# Only the first value returned by get_dataloader is kept; a DataLoader with the
# options below is rebuilt from it. NOTE(review): confirm that the first element
# is the dataset in src/dataset_loader.
dataset, _ = get_dataloader(data_dir=train_path, batch_size=BATCH_SIZE)

# 🔁 Optional subset for quick debug runs
USE_SUBSET = False
SUBSET_FRACTION = 0.1
SUBSET_SEED = 42
if USE_SUBSET:
    # Keep at least one sample so a tiny dataset never yields an empty subset.
    n_subset = max(1, int(len(dataset) * SUBSET_FRACTION))
    # A dedicated seeded RNG makes the debug subset reproducible without
    # touching the global `random` state.
    indices = random.Random(SUBSET_SEED).sample(range(len(dataset)), n_subset)
    dataset = Subset(dataset, indices)

# 🔄 Optimized DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    # Pinned host memory only speeds up host→GPU copies; without CUDA it is
    # useless and PyTorch warns about it.
    pin_memory=torch.cuda.is_available(),
)

# 🧠 ResNet18 model initialised with ImageNet weights
if hasattr(models, "ResNet18_Weights"):
    # torchvision >= 0.13: the `weights` enum replaces the deprecated
    # `pretrained=True`; IMAGENET1K_V1 is exactly what `pretrained=True` loaded.
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
else:
    # Older torchvision releases only understand the legacy flag.
    model = models.resnet18(pretrained=True)
# Replace the final fully-connected layer with a 2-output head (binary task).
model.fc = nn.Linear(model.fc.in_features, 2)
model = model.to(device)

# 🔧 Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# 📊 Metric tracking
history = {k: [] for k in ["loss", "accuracy", "precision", "recall", "f1"]}
best_f1 = 0.0

# The results folder must exist before the first best-model checkpoint is
# written inside the loop below (it used to be created only after training).
os.makedirs("results", exist_ok=True)

# 🔁 Training loop
EPOCHS = 10
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    n_samples = 0
    y_true, y_pred = [], []

    loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False)

    for images, labels in loop:
        # non_blocking lets the host→device copy overlap with compute when the
        # loader uses pinned memory.
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()  # read once: each .item() forces a device sync
        batch_size = labels.size(0)
        # CrossEntropyLoss defaults to reduction="mean", so weight by batch size
        # to get a true per-sample average even when the last batch is smaller.
        running_loss += batch_loss * batch_size
        n_samples += batch_size
        preds = outputs.argmax(dim=1)

        y_true.extend(labels.cpu().tolist())
        y_pred.extend(preds.cpu().tolist())

        loop.set_postfix(loss=batch_loss)

    if n_samples == 0:
        raise RuntimeError("Training dataloader yielded no samples; cannot compute epoch metrics.")

    # 📊 Epoch metrics
    # NOTE(review): these are computed on the training batches themselves (no
    # held-out split is visible here), so "best" below means best *training* F1.
    avg_loss = running_loss / n_samples
    metrics = compute_metrics(y_true, y_pred)
    history["loss"].append(avg_loss)
    for k, v in metrics.items():
        # setdefault: tolerate metrics that compute_metrics adds beyond the preset keys.
        history.setdefault(k, []).append(v)

    print(f"📈 Epoch {epoch+1}/{EPOCHS} - Loss: {avg_loss:.4f} | Acc: {metrics['accuracy']:.4f} | F1: {metrics['f1']:.4f}")

    # 💾 Checkpoint whenever F1 improves
    if metrics["f1"] > best_f1:
        best_f1 = metrics["f1"]
        torch.save(model.state_dict(), "results/best_model.pth")
        print(f"💾 Best model salvato (F1: {best_f1:.4f})")

    # 🚀 wandb logging
    if USE_WANDB:
        wandb.log({"epoch": epoch + 1, "loss": avg_loss, **metrics})

# 📁 Create the results folder if it does not exist yet
os.makedirs("results", exist_ok=True)

# 💾 Save the final model weights
torch.save(model.state_dict(), "results/histology_model.pth")

# 💾 Save the per-epoch metric history
with open("results/train_history.json", "w", encoding="utf-8") as f:
    json.dump(history, f, indent=4)

# 📈 Plot the metrics, using 1-based epoch numbers to match the printed logs
plt.figure(figsize=(10, 6))
for k in ["accuracy", "precision", "recall", "f1"]:
    epochs_axis = range(1, len(history[k]) + 1)
    plt.plot(epochs_axis, history[k], label=k)

plt.xlabel("Epoch")
plt.ylabel("Score")
plt.title("📈 Andamento metriche durante il training")
plt.legend()
plt.grid()
plt.savefig("results/training_metrics_curve.png")
plt.show()

# Close the wandb run explicitly: in a notebook the kernel keeps running, so
# the run would otherwise stay open until the kernel shuts down.
if USE_WANDB:
    wandb.finish()

print("✅ Training completato e file salvati in 'results/'")
