In [None]:
# 🔁 Caricamento moduli
import sys
sys.path.append("../src")

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
import os
import matplotlib.pyplot as plt
from src.dataset_loader import get_dataloader
from src.metrics import compute_metrics
import json


# Settings: run on the GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"💻 Using device: {device}")

# Dataloader: the training split is read from /content when the COLAB_GPU
# environment variable is present, and from the local datasets/ folder otherwise.
train_path = (
    "/content/dataset_prepared/train"
    if "COLAB_GPU" in os.environ
    else "datasets/dataset_prepared/train"
)
dataset, dataloader = get_dataloader(data_dir=train_path, batch_size=32)

# Model: ImageNet-pretrained ResNet18.
# `pretrained=True` is deprecated since torchvision 0.13 in favour of the
# `weights=` argument; IMAGENET1K_V1 is the checkpoint `pretrained=True` maps
# to. Older torchvision releases lack the weights enum, so fall back to the
# legacy keyword there.
if hasattr(models, "ResNet18_Weights"):
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
else:
    model = models.resnet18(pretrained=True)
# Replace the classification head with a fresh 2-way layer.
model.fc = nn.Linear(model.fc.in_features, 2)  # 2 classes: IDC / non-IDC
model = model.to(device)

# Loss & optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Per-epoch metric history; one value is appended to each list every epoch.
history = {
    "loss": [],
    "accuracy": [],
    "precision": [],
    "recall": [],
    "f1": [],
}

# Training loop
EPOCHS = 10

# Fail fast with a clear message instead of hitting a ZeroDivisionError when
# the average loss is computed for an empty dataloader.
num_batches = len(dataloader)
if num_batches == 0:
    raise ValueError("Training dataloader is empty: no batches to train on.")

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    # Ground-truth labels and predicted classes accumulated over the epoch.
    y_true, y_pred = [], []

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Keep the predictions for the epoch-level metrics: the predicted
        # class is the index of the largest logit along dim 1.
        preds = outputs.argmax(dim=1)
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())

    # Epoch loss is the mean of the per-batch losses.
    avg_loss = running_loss / num_batches
    metrics = compute_metrics(y_true, y_pred)

    # Record the epoch in the history. setdefault keeps training alive if
    # compute_metrics reports a metric that was not pre-declared in `history`.
    history["loss"].append(avg_loss)
    for name in metrics:
        history.setdefault(name, []).append(metrics[name])

    print(f"📈 Epoch {epoch+1}/{EPOCHS} - Loss: {avg_loss:.4f} | Acc: {metrics['accuracy']:.4f} | F1: {metrics['f1']:.4f}")

# Save the trained weights
os.makedirs("results", exist_ok=True)
torch.save(model.state_dict(), "results/histology_model.pth")
print("✅ Model saved to results/histology_model.pth")

# Save the metric history BEFORE plotting, so that a plotting failure (or a
# blocking plt.show() window) cannot cost us the recorded training metrics.
with open("results/train_history.json", "w", encoding="utf-8") as f:
    json.dump(history, f, indent=4)

print("✅ Storico delle metriche salvato in results/train_history.json")

# Final plot of the classification metrics, one curve per metric.
plt.figure(figsize=(10, 6))
for k in ["accuracy", "precision", "recall", "f1"]:
    plt.plot(history[k], label=k)

plt.xlabel("Epoch")
plt.ylabel("Score")
plt.title("📈 Andamento metriche durante il training")
plt.legend()
# Enable the grid explicitly: a bare plt.grid() toggles it, which would turn
# the grid off under styles where it is already on.
plt.grid(True)
plt.savefig("results/training_metrics_curve.png")
plt.show()
