In [None]:
# ===========================================
# 03_train_dose_unet.ipynb
# Entrenamiento inicial de U-Net 3D para predicción de dosis
# ===========================================

import os
import sys

# Añadimos la carpeta src al path para importar nuestros módulos
sys.path.append("../src")

import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch.optim as optim

from dataset import DoseDataset
from models import UNet3D

# -------------------------------------------
# BASIC CONFIGURATION
# -------------------------------------------

data_dir = "../data_processed"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 1          # start with 1 because of memory; try larger values later
num_epochs = 20         # adjustable
learning_rate = 1e-4
split_seed = 42         # fixed seed so the train/val split is identical across runs

print("Data dir:", data_dir)
print("Device:", device)

# -------------------------------------------
# CREATE DATASET AND TRAIN/VAL SPLIT
# -------------------------------------------

full_dataset = DoseDataset(data_dir=data_dir)

dataset_size = len(full_dataset)
print("Total de pacientes en dataset:", dataset_size)

# With fewer than 2 samples one side of the split is empty (train_size would
# be 0 or negative below), and full_dataset[0] is used further down; fail
# fast here instead of crashing later inside the training loop.
if dataset_size < 2:
    raise ValueError(
        f"Se necesitan al menos 2 pacientes para separar train/val; hay {dataset_size}."
    )

if dataset_size < 3:
    print("⚠️ Tienes menos de 3 pacientes; el split train/val será muy limitado.")

# Proportions: 80% train, 20% val (at least one validation sample).
val_frac = 0.2
val_size = max(1, int(dataset_size * val_frac))
train_size = dataset_size - val_size

# A seeded generator makes the random split reproducible, so validation
# losses (and the "best" checkpoint) are comparable between runs.
split_generator = torch.Generator().manual_seed(split_seed)
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size], generator=split_generator
)

print(f"Train size: {train_size}, Val size: {val_size}")

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# -------------------------------------------
# CREATE MODEL, LOSS AND OPTIMIZER
# -------------------------------------------

# Infer the number of input channels from one example sample; the first
# dimension of X is taken as the channel axis (per the original "# C" note).
X_example, Y_example, pid_example = full_dataset[0]
n_channels_in = X_example.shape[0]   # C

print(f"Ejemplo de entrada '{pid_example}': X shape = {X_example.shape}, Y shape = {Y_example.shape}")
print("Canales de entrada:", n_channels_in)

model = UNet3D(n_channels=n_channels_in, n_classes=1, base_filters=16)
model = model.to(device)

criterion = nn.MSELoss()  # simple voxel-wise loss for now
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

print("\nModelo creado:")
print(model)

# -------------------------------------------
# LOOP DE ENTRENAMIENTO
# -------------------------------------------

def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimization pass over ``loader`` and return the mean per-sample loss.

    Args:
        model: network to train; put into ``train()`` mode here.
        loader: iterable yielding ``(X, Y, pid)`` batches; ``pid`` is ignored.
        optimizer: optimizer that updates ``model``'s parameters.
        criterion: loss called as ``criterion(Y_pred, Y)``; assumed to average
            over the batch (e.g. ``nn.MSELoss`` with its default reduction).
        device: device every batch is moved to before the forward pass.

    Returns:
        float: the batch losses weighted by batch size, divided by the number
        of samples actually processed in this pass.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    n_samples = 0

    for X, Y, _pid in loader:
        X = X.to(device)  # expected layout [B, C, Z, Y, X]
        Y = Y.to(device)  # expected layout [B, 1, Z, Y, X]

        optimizer.zero_grad()
        Y_pred = model(X)
        loss = criterion(Y_pred, Y)
        loss.backward()
        optimizer.step()

        batch_n = X.size(0)
        # Undo the batch mean so a smaller final batch is weighted correctly.
        running_loss += loss.item() * batch_n
        n_samples += batch_n

    # Divide by the samples actually seen rather than len(loader.dataset):
    # they differ when the loader drops its last partial batch or samples a
    # subset, and an empty loader would otherwise raise ZeroDivisionError.
    if n_samples == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    return running_loss / n_samples


def evaluate(model, loader, criterion, device):
    """Compute the mean per-sample loss of ``model`` over ``loader`` without gradients.

    Args:
        model: network to evaluate; put into ``eval()`` mode here (and left there).
        loader: iterable yielding ``(X, Y, pid)`` batches; ``pid`` is ignored.
        criterion: loss called as ``criterion(Y_pred, Y)``; assumed to average
            over the batch (e.g. ``nn.MSELoss`` with its default reduction).
        device: device every batch is moved to before the forward pass.

    Returns:
        float: the batch losses weighted by batch size, divided by the number
        of samples actually processed.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    n_samples = 0

    with torch.no_grad():
        for X, Y, _pid in loader:
            X = X.to(device)
            Y = Y.to(device)

            Y_pred = model(X)
            loss = criterion(Y_pred, Y)

            batch_n = X.size(0)
            # Undo the batch mean so a smaller final batch is weighted correctly.
            running_loss += loss.item() * batch_n
            n_samples += batch_n

    # Average over the samples actually seen (not len(loader.dataset)), and
    # report an empty loader clearly instead of via ZeroDivisionError.
    if n_samples == 0:
        raise ValueError("evaluate: loader yielded no samples")
    return running_loss / n_samples


# Best validation loss seen so far and the per-epoch loss curves.
best_val_loss = np.inf
history = {"train_loss": [], "val_loss": []}

# The checkpoint location is loop-invariant: fix the path and create its
# directory once instead of on every improvement.
best_model_path = "../models/unet3d_dose_best.pth"
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

for epoch in range(1, num_epochs + 1):
    print(f"\nEpoch {epoch}/{num_epochs}")
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss = evaluate(model, val_loader, criterion, device)

    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)

    print(f"  Train loss: {train_loss:.6f}")
    print(f"  Val   loss: {val_loss:.6f}")

    # A NaN/inf loss never compares as an improvement, so without this notice
    # a diverged run would silently stop producing checkpoints.
    if not np.isfinite(val_loss):
        print("  ⚠️ val_loss no finito; no se guarda checkpoint en esta época.")
        continue

    # Save the best model so far (by validation loss).
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)
        print(f"  ✅ Nuevo mejor modelo guardado en {best_model_path} (val_loss={val_loss:.6f})")

# -------------------------------------------
# PLOT LOSS CURVES
# -------------------------------------------

import matplotlib.pyplot as plt

# Plot against 1-based epoch numbers so the x-axis matches the "Epoch k/N"
# log lines; plotting the bare lists would put epoch 1 at x=0.
epochs = range(1, len(history["train_loss"]) + 1)

plt.figure(figsize=(6, 4))
plt.plot(epochs, history["train_loss"], label="Train")
plt.plot(epochs, history["val_loss"], label="Val")
plt.xlabel("Epoch")
plt.ylabel("Loss (MSE)")
plt.title("Curvas de pérdida")
plt.legend()
plt.grid(True)
plt.show()

print("\nEntrenamiento terminado.")
print("Mejor val_loss:", best_val_loss)
