In [None]:
import os
from collections import OrderedDict
from glob import glob
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import yaml
import albumentations as A
from albumentations.core.composition import Compose
from torch.optim import lr_scheduler
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import copy
from torch.utils.data import DataLoader, Sampler, Subset
from torch.profiler import profile, record_function, ProfilerActivity
import random

In [None]:
import archs
import losses
from dataset import Dataset
from metrics import iou_score
from utils import AverageMeter

In [None]:
# Reset matplotlib to its stock style so figures look the same in every session.
plt.style.use('default')

# Report which torch installation the kernel picked up (useful with several envs).
for torch_info in (torch.__file__, torch.__version__):
  print(torch_info)

print("Librerías importadas correctamente")
cuda_ok = torch.cuda.is_available()
print(f"CUDA disponible: {cuda_ok}")
if cuda_ok:
  print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
class LabelSmoothingBCELoss(nn.Module):
  """Binary cross-entropy against label-smoothed targets.

  Hard targets are pulled towards 0.5:
  ``t' = t * (1 - smoothing) + 0.5 * smoothing``.
  ``pred`` goes straight into ``F.binary_cross_entropy``, so it must already be a
  probability in [0, 1] (not a logit).
  """

  def __init__(self, smoothing=0.05):
    super().__init__()
    # Fraction of the target mass moved towards 0.5.
    self.smoothing = smoothing

  def forward(self, pred, target):
    keep = 1 - self.smoothing
    smoothed_target = target * keep + 0.5 * self.smoothing
    return F.binary_cross_entropy(pred, smoothed_target)

In [None]:
class BinaryDiceLoss(nn.Module):
  """Soft Dice loss for binary segmentation, computed over the whole batch.

  All elements of the batch are flattened together, so the result is one
  batch-level Dice score rather than a mean of per-sample scores:
  ``loss = 1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``.

  Args:
    smooth: additive constant in numerator and denominator; keeps the loss
      defined (and equal to 0) when prediction and target are both empty.

  ``pred`` should hold probabilities in [0, 1]; this class applies no sigmoid
  itself (BCEDiceLoss applies ``torch.sigmoid`` before calling it).
  """
  def __init__(self, smooth=1.0):
    super().__init__()
    self.smooth = smooth

  def forward(self, pred, target):
    # reshape (not view) also accepts non-contiguous tensors (e.g. transposed or
    # channels_last outputs), for which .view(-1) raises a RuntimeError.
    pred = pred.reshape(-1)
    # Match dtypes so integer/bool masks combine cleanly with float predictions.
    target = target.reshape(-1).to(pred.dtype)

    intersection = (pred * target).sum()
    dice = (2. * intersection + self.smooth) / (pred.sum() + target.sum() + self.smooth)

    return 1 - dice

In [None]:
class TverskyLoss(nn.Module):
  """Tversky loss for binary segmentation, taking raw logits.

  ``pred`` is passed through ``torch.sigmoid`` here. The whole batch is then
  flattened and scored with the Tversky index

      TI = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth),   loss = 1 - TI,

  where TP, FP and FN are soft (probability-weighted) counts. ``alpha`` weights
  false positives and ``beta`` false negatives; with alpha = beta = 0.5 the index
  equals the Dice coefficient (up to the smoothing term).
  """
  def __init__(self, alpha=0.5, beta=0.5, smooth=1):
    super().__init__()
    self.alpha = alpha
    self.beta = beta
    self.smooth = smooth

  def forward(self, pred, target):
    probs = torch.sigmoid(pred)
    # reshape (not view) also handles non-contiguous inputs, where view(-1) raises.
    pred_flat = probs.reshape(-1)
    # Cast the mask: `1 - target` is not defined for bool tensors, and integer
    # masks should combine with float probabilities without surprises.
    target_flat = target.reshape(-1).to(pred_flat.dtype)

    TP = (pred_flat * target_flat).sum()
    FP = ((1 - target_flat) * pred_flat).sum()
    FN = (target_flat * (1 - pred_flat)).sum()

    tversky = (TP + self.smooth) / (TP + self.alpha * FP + self.beta * FN + self.smooth)
    return 1 - tversky

In [None]:
class BCEDiceLoss(nn.Module):
  """Weighted sum of BCE-with-logits and soft Dice loss for binary segmentation.

  ``pred`` holds raw logits: the BCE term consumes them directly, while the Dice
  term is computed on ``torch.sigmoid(pred)``.

  Args:
    bce_weight: weight of the BCE term.
    dice_weight: weight of the Dice term.
    smooth: smoothing constant forwarded to ``BinaryDiceLoss``.
    pos_weight: weight of positive (foreground) pixels in the BCE term, as a
      number or tensor. Defaults to 3.0; ``None`` disables re-weighting.
  """
  def __init__(self, bce_weight=0.5, dice_weight=0.5, smooth=1.0, pos_weight=3.0):
    super().__init__()
    self.bce_weight = bce_weight
    self.dice_weight = dice_weight
    if pos_weight is not None:
      # BCEWithLogitsLoss takes pos_weight as a tensor; a plain number becomes shape (1,).
      pos_weight = torch.as_tensor(pos_weight, dtype=torch.float32)
      if pos_weight.dim() == 0:
        pos_weight = pos_weight.reshape(1)
    self.bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    self.dice = BinaryDiceLoss(smooth=smooth)

  def forward(self, pred, target):
    bce_loss = self.bce(pred, target)

    pred_probs = torch.sigmoid(pred)
    dice_loss = self.dice(pred_probs, target)

    return self.bce_weight * bce_loss + self.dice_weight * dice_loss

print("Clases de pérdida definidas correctamente")

In [None]:
class EarlyStopping:
  """Signal a stop once a maximised validation score stops improving.

  A call counts as an improvement when ``val_score`` is not below
  ``best_score + min_delta``. Otherwise a patience counter grows; when it reaches
  ``patience`` the call returns True and, if ``restore_best_weights`` is set,
  reloads into ``model`` the weights captured at the best score.
  """

  def __init__(self, patience=10, min_delta=0.001, restore_best_weights=True):
    self.patience = patience
    self.min_delta = min_delta
    self.restore_best_weights = restore_best_weights
    self.best_score = None    # best score seen so far (None until the first call)
    self.counter = 0          # consecutive calls without sufficient improvement
    self.best_weights = None  # deep copy of model.state_dict() taken at best_score

  def __call__(self, val_score, model):
    """Record ``val_score`` for ``model``; return True when training should stop."""
    if self.best_score is not None and val_score < self.best_score + self.min_delta:
      self.counter += 1
      if self.counter < self.patience:
        return False
      if self.restore_best_weights:
        model.load_state_dict(self.best_weights)
      return True

    # First call, or a large-enough improvement: remember it and reset patience.
    self.best_score = val_score
    self.save_checkpoint(model)
    self.counter = 0
    return False

  def save_checkpoint(self, model):
    """Snapshot the current weights (deep copy, so later training cannot alias them)."""
    self.best_weights = copy.deepcopy(model.state_dict())

print("Early Stopping definido correctamente")

In [None]:
# Central training configuration; the other cells read their settings from here.
config = dict(
  name=None,                   # run name; derived from dataset/arch/deep_supervision when None
  epochs=100,
  batch_size=4,
  epoch_subset_size=10000,     # samples drawn per epoch; None or <= 0 trains on the full set

  # Model
  arch='UNet',
  deep_supervision=False,
  input_channels=1,
  input_w=128,
  input_h=128,

  # Loss
  loss='TverskyLoss',

  # Dataset
  dataset='LUNA16',
  img_ext='.png',
  mask_ext='.png',

  # Optimizer
  optimizer='Adam',
  lr=5e-4,
  momentum=0.9,                # SGD only
  weight_decay=5e-4,
  nesterov=False,              # SGD only

  # Scheduler
  scheduler='OneCycleLR',
  min_lr=1e-6,                 # CosineAnnealingLR (eta_min) and ReduceLROnPlateau
  factor=0.5,                  # ReduceLROnPlateau
  patience=5,                  # ReduceLROnPlateau
  milestones='30,60,90',       # MultiStepLR, comma-separated epochs
  gamma=0.5,                   # MultiStepLR
  early_stopping=10,           # EarlyStopping patience, in epochs

  # Anti-overfitting
  dropout_rate=0.5,            # only passed to NestedUNet
  label_smoothing=0.0,         # only used with loss='BCELoss'
  validation_split=0.2,        # not referenced by the cells shown here
  accumulation_steps=2,
  multiscale_validation=True,  # not referenced by the cells shown here

  num_workers=4,
)

In [None]:
# Derive a run name when none was given; the wDS/woDS tag mirrors config['deep_supervision'].
if config['name'] is None:
  ds_tag = 'wDS' if config['deep_supervision'] else 'woDS'
  config['name'] = f"{config['dataset']}_{config['arch']}_binary_{ds_tag}"

# Output directory for this run's model, log and config files.
run_dir = f"models/{config['name']}"
print(config['name'])
os.makedirs(run_dir, exist_ok=True)

In [None]:
separator = "-" * 50
print("CONFIGURACIÓN DEL ENTRENAMIENTO:")
print(separator)
for setting, setting_value in config.items():
  print(f"{setting:25}: {setting_value}")
print(separator)

# Persist the exact configuration next to the model artefacts.
config_path = f"models/{config['name']}/config.yml"
with open(config_path, 'w') as config_file:
  yaml.dump(config, config_file)

print(f"Configuración guardada en {config_path}")

In [None]:
def _build_criterion(cfg):
  """Instantiate the loss named by cfg['loss'].

  Names defined in this notebook (or in torch.nn) are resolved first; any other
  name is looked up in the project's ``losses`` module and built with no arguments.
  """
  loss_name = cfg['loss']
  if loss_name == 'BCELoss':
    # Label smoothing is only applied on the plain-BCE path.
    if cfg['label_smoothing'] > 0:
      return LabelSmoothingBCELoss(cfg['label_smoothing'])
    return nn.BCELoss()

  factories = {
    'BCEDiceLoss': BCEDiceLoss,
    'BinaryDiceLoss': BinaryDiceLoss,
    'BCEWithLogitsLoss': nn.BCEWithLogitsLoss,
    # alpha weights false positives, beta false negatives.
    'TverskyLoss': lambda: TverskyLoss(alpha=0.7, beta=0.3),
  }
  factory = factories.get(loss_name)
  if factory is None:
    return losses.__dict__[loss_name]()
  return factory()


criterion = _build_criterion(config).cuda()

cudnn.benchmark = True

In [None]:
print(f"Creando modelo {config['arch']} para segmentación binaria...")

arch_kwargs = {
  'input_channels': config['input_channels'],
  'deep_supervision': config['deep_supervision'],
}
if config['arch'] == 'NestedUNet':
  # NestedUNet receives the configured dropout rate.
  arch_kwargs['dropout_rate'] = config['dropout_rate']
else:
  # Other architectures are asked for a single output channel.
  arch_kwargs['num_classes'] = 1
# NOTE(review): NestedUNet gets no explicit num_classes and the others get no
# dropout_rate; presumably the arch defaults cover this — confirm in archs.py.

model = archs.__dict__[config['arch']](**arch_kwargs).cuda()

In [None]:
param_counts = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_counts)
trainable_params = sum(count for count, trainable in param_counts if trainable)

print("Modelo creado exitosamente")
print(f"Parámetros totales: {total_params:,}")
print(f"Parámetros entrenables: {trainable_params:,}")
# Weights only, at 4 bytes (float32) per parameter; buffers and activations excluded.
print(f"Memoria estimada: ~{total_params * 4 / 1024 ** 2:.1f} MB")

In [None]:
# Only parameters that require gradients are handed to the optimizer.
params = [p for p in model.parameters() if p.requires_grad]

if config['optimizer'] == 'Adam':
  optimizer = optim.Adam(
    params, lr=config['lr'], weight_decay=config['weight_decay'])
elif config['optimizer'] == 'SGD':
  optimizer = optim.SGD(
    params, lr=config['lr'], momentum=config['momentum'],
    nesterov=config['nesterov'], weight_decay=config['weight_decay'])
else:
  # Fail fast: otherwise `optimizer` would be undefined, or silently left over from
  # an earlier execution of this cell (bound to a previous model) in the same kernel.
  raise ValueError(
    f"Optimizador no soportado: {config['optimizer']!r} (usa 'Adam' o 'SGD')")

In [None]:
def get_img_ids(stage_subset, config):
  """Return the image ids (file names without extension) of one dataset split.

  Searches ``processed/<config['dataset']>/<stage_subset>/images/*<config['img_ext']>``
  and prints a short report of what was found.

  Args:
    stage_subset: split folder name, e.g. ``'stage1-train'``.
    config: dict providing ``'dataset'`` and ``'img_ext'``.

  Returns:
    list[str]: ids in the order ``glob`` returned them (not sorted).
  """
  pattern = os.path.normpath(os.path.join(
    'processed', config['dataset'], stage_subset, "images", f"*{config['img_ext']}"))
  img_ids = [os.path.splitext(os.path.basename(path))[0] for path in glob(pattern)]

  print(f"Buscando imágenes en: {pattern}")
  print(f"Imágenes encontradas en {stage_subset}: {len(img_ids)}")

  if not img_ids:
    print(f"No se encontraron imágenes en '{stage_subset}'. Verifica la ruta.")
    return img_ids

  split_label = stage_subset.replace('-', ' ').title()
  print(f"{split_label} cargado correctamente")
  print(f"   Ejemplo de archivos: {img_ids[:3]}")
  return img_ids

In [None]:
# Image ids for both splits (each call also prints a short report).
train_img_ids, val_img_ids = (get_img_ids(split, config) for split in ('stage1-train', 'stage2-val'))

print("División del dataset:")
for split_label, split_ids in (("Entrenamiento", train_img_ids), ("Validación", val_img_ids)):
  print(f"   {split_label}: {len(split_ids)} imágenes")

In [None]:
def _resize_and_normalize():
  """Fresh resize + normalisation steps shared by both pipelines.

  mean=0 / std=1 means Normalize only divides by albumentations' max_pixel_value
  (255 by default — confirm for the installed version), i.e. rescales to [0, 1].
  """
  return [
    A.Resize(config['input_h'], config['input_w']),
    A.Normalize(mean=(0.0,), std=(1.0,)),
  ]


# Training adds random 90-degree rotations and horizontal flips as augmentation.
train_transform = Compose([A.RandomRotate90(), A.HorizontalFlip(), *_resize_and_normalize()])
val_transform = Compose(_resize_and_normalize())

print("Transformaciones definidas:")
for stage_label, pipeline in (("Entrenamiento", train_transform), ("Validación", val_transform)):
  print(f"   {stage_label}: {len(pipeline.transforms)} transformaciones")

In [None]:
print(f"Imágenes en entrenamiento: {len(train_img_ids)}")
print(f"Imágenes en validación: {len(val_img_ids)}")


def _split_dir(split, kind):
  """Folder ``processed/<dataset>/<split>/<kind>`` (kind is 'images' or 'masks')."""
  return os.path.join("processed", config["dataset"], f"{split}/{kind}")


def _sorted_stems(split):
  """Sorted file names (extension stripped) in a split's images folder ending in img_ext."""
  return sorted(
    os.path.splitext(file_name)[0]
    for file_name in os.listdir(_split_dir(split, "images"))
    if file_name.endswith(config['img_ext'])
  )


def _make_dataset(ids, split, transform):
  """Dataset for one split, reading images/ and masks/ from the same split folder."""
  return Dataset(
    img_ids=ids,
    img_dir=_split_dir(split, "images"),
    mask_dir=_split_dir(split, "masks"),
    img_ext=config["img_ext"],
    mask_ext=config["mask_ext"],
    transform=transform,
  )


train_patients = _sorted_stems("stage1-train")
val_patients = _sorted_stems("stage2-val")

train_dataset = _make_dataset(train_patients, "stage1-train", train_transform)
val_dataset = _make_dataset(val_patients, "stage2-val", val_transform)

# Settings shared by both loaders; only shuffling differs (train shuffles, val keeps order).
shared_loader_kwargs = dict(
  batch_size=config['batch_size'],
  num_workers=config['num_workers'],
  drop_last=False,
  pin_memory=True,
  persistent_workers=True,
  prefetch_factor=4,
  multiprocessing_context='spawn',
)
train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **shared_loader_kwargs)

print("DataLoaders creados:")
print(f"   Batches de entrenamiento: {len(train_loader)}")
print(f"   Batches de validación: {len(val_loader)}")
print(f"   Tamaño de batch: {config['batch_size']}")
print(f"   Tamaño de batch: {config['batch_size']}")

In [None]:
# Build the LR scheduler selected by config['scheduler'].
# CosineAnnealingLR, ReduceLROnPlateau and MultiStepLR are epoch-level schedules
# (stepped from the training loop); OneCycleLR is defined per optimizer update.
if config['scheduler'] == 'CosineAnnealingLR':
  scheduler = lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
elif config['scheduler'] == 'ReduceLROnPlateau':
  scheduler = lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=config['factor'],
    patience=config['patience'], min_lr=config['min_lr'])
elif config['scheduler'] == 'MultiStepLR':
  scheduler = lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[int(e) for e in config['milestones'].split(',')],
    gamma=config['gamma'])
elif config['scheduler'] == 'OneCycleLR':
  # When epoch_subset_size is set, each epoch trains on a random subset of that many
  # samples (train_with_epoch_sampling, drop_last=False), not on the full loader.
  subset_size = config.get('epoch_subset_size')
  if subset_size is not None and subset_size > 0:
    samples_per_epoch = min(subset_size, len(train_loader.dataset))
    batches_per_epoch = (samples_per_epoch + config['batch_size'] - 1) // config['batch_size']
  else:
    batches_per_epoch = len(train_loader)

  # The training step also updates on the last, shorter accumulation group, so
  # round up: flooring under-counts total_steps, and OneCycleLR raises a
  # ValueError once it is stepped past that total.
  accumulation_steps = max(1, config['accumulation_steps'])
  steps_per_epoch = (batches_per_epoch + accumulation_steps - 1) // accumulation_steps

  # max_lr falls back to 3e-3 when config has no 'max_lr'; OneCycleLR then starts
  # at max_lr / div_factor, replacing the optimizer's config['lr'].
  scheduler = lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=config.get('max_lr', 3e-3),
    total_steps=config['epochs'] * steps_per_epoch,
    pct_start=0.3,
    anneal_strategy='cos',
    div_factor=25,
    final_div_factor=1e4,
    three_phase=False,
  )
elif config['scheduler'] == 'ConstantLR':
  scheduler = None
else:
  # Fail fast instead of leaving `scheduler` undefined (or stale from a previous run).
  raise ValueError(f"Scheduler no soportado: {config['scheduler']!r}")

# Report the LR the optimizer actually starts with (OneCycleLR overrides config['lr']).
print(f"Optimizador: {config['optimizer']} (LR: {optimizer.param_groups[0]['lr']})")
print(f"Scheduler: {config['scheduler']}")

In [None]:
# Optional visual check (disabled): draws one training batch and plots up to four
# image/mask pairs. Uncomment to run.
# images, masks, infos = next(iter(train_loader))

# num_examples = min(4, images.shape[0])

# plt.figure(figsize=(12, 6))
# for i in range(num_examples):
#     # The image comes as [1, H, W]; squeeze it to [H, W]
#     img = images[i].squeeze().cpu().numpy()
#     mask = masks[i].squeeze().cpu().numpy()

#     # Show image
#     plt.subplot(2, num_examples, i + 1)
#     plt.imshow(img, cmap='gray')
#     plt.title(f"Imagen - {infos['img_id'][i]}")
#     plt.axis('off')

#     # Show mask
#     plt.subplot(2, num_examples, i + 1 + num_examples)
#     plt.imshow(mask, cmap='gray')
#     plt.title(f"Máscara - {infos['img_id'][i]}")
#     plt.axis('off')

# plt.tight_layout()
# plt.show()

In [None]:
# Optional smoke test (disabled): fetch one batch to confirm the train DataLoader works.
# try:
#   sample_batch = next(iter(train_loader))
#   print(f"   Batch de prueba exitoso - Forma: {sample_batch[0].shape}")
# except Exception as e:
#   print(f"   Error en batch de prueba: {e}")

In [None]:
class RandomSubsetSampler(Sampler):
    """Sampler that draws a different random subset of the data each epoch.

    Yields ``subset_size`` distinct indices (capped at ``len(data_source)``),
    sampled without replacement. When ``epoch`` is set, the subset depends only on
    the epoch number, so a given epoch always sees the same samples.
    """

    def __init__(self, data_source, subset_size, epoch=None):
        self.data_source = data_source
        self.subset_size = min(subset_size, len(data_source))
        self.epoch = epoch

    def set_epoch(self, epoch):
        """Set the epoch number that seeds the next subset."""
        self.epoch = epoch

    def __iter__(self):
        # A private RNG seeded by the epoch gives exactly the indices that
        # `random.seed(epoch); random.sample(...)` gave, without resetting the global
        # `random` and `torch` generators every epoch (which also re-seeded dropout
        # and DataLoader worker seeding as a hidden side effect). For reproducible
        # augmentation, seed once globally or pass a `generator` to the DataLoader.
        rng = random if self.epoch is None else random.Random(self.epoch)
        indices = rng.sample(range(len(self.data_source)), self.subset_size)
        return iter(indices)

    def __len__(self):
        return self.subset_size

def create_subset_sampler_loader(full_loader, subset_size, config, epoch=None):
    """Build a DataLoader over a random subset of ``full_loader.dataset``.

    Sampling is done over ``full_loader.dataset`` as-is. A ``Subset`` is *not*
    unwrapped: doing so would let samples outside it (e.g. a held-out validation
    split) be drawn into training.

    Args:
        full_loader: loader whose ``dataset`` is sampled from.
        subset_size: number of samples per epoch (capped at the dataset length).
        config: needs ``'batch_size'``; ``'num_workers'`` (4), ``'pin_memory'``
            (True) and ``'drop_last'`` (False) are optional.
        epoch: seeds the subset selection; None uses the global ``random`` state.

    Returns:
        DataLoader using a ``RandomSubsetSampler``.
    """
    dataset = full_loader.dataset

    return DataLoader(
        dataset,
        batch_size=config['batch_size'],
        sampler=RandomSubsetSampler(dataset, subset_size, epoch),
        num_workers=config.get('num_workers', 4),
        pin_memory=config.get('pin_memory', True),
        drop_last=config.get('drop_last', False),
    )

In [None]:
def train_epoch(config, train_loader, model, criterion, optimizer, epoch, scheduler=None):
  """Train ``model`` for one pass over ``train_loader`` and return its mean metrics.

  Gradients are accumulated over ``config['accumulation_steps']`` consecutive
  batches before each optimizer update; a final, shorter group at the end of the
  loader is applied as well. Gradients are clipped to a global norm of 0.5 before
  every update.

  Args:
    config: reads ``'accumulation_steps'`` and ``'deep_supervision'``.
    train_loader: sized iterable of ``(input, target, *extra)`` batches.
    model: network called on CUDA inputs; with deep supervision it returns a
      sequence of outputs whose last element is the final prediction.
    criterion: loss called as ``criterion(output, target)``.
    optimizer: optimizer over ``model``'s parameters.
    epoch: only used to label the progress bar.
    scheduler: optional scheduler that advances once per optimizer update (e.g.
      ``OneCycleLR``); when given, ``scheduler.step()`` runs right after every
      ``optimizer.step()``. Do not pass epoch-level schedulers here.

  Returns:
    OrderedDict with the sample-weighted average ``'loss'`` (not scaled by the
    accumulation factor) and ``'iou'`` over the epoch.
  """
  avg_meters = {'loss': AverageMeter(), 'iou': AverageMeter()}

  model.train()

  # Number of batches per optimizer update. Not divided by batch_size: with
  # batch_size=4 that turned 2 into 0 -> 1 and silently disabled accumulation.
  accumulation_steps = max(1, int(config['accumulation_steps']))
  num_batches = len(train_loader)
  optimizer.zero_grad()

  with tqdm(total=num_batches, desc=f'Epoch {epoch}') as pbar:
    for i, (inputs, targets, *_) in enumerate(train_loader):
      inputs = inputs.cuda()
      targets = targets.cuda()

      # Masks that arrive in the 0..255 range are rescaled to 0..1.
      if targets.max() > 1.0:
        targets = targets.float() / 255.0

      if config['deep_supervision']:
        outputs = model(inputs)
        # Deeper heads get geometrically smaller weights (1, 1/2, 1/4, ...),
        # normalised so the weights sum to 1.
        weights = [1.0 / (2 ** j) for j in range(len(outputs))]
        loss = sum(w * criterion(out, targets) for w, out in zip(weights, outputs)) / sum(weights)
        iou = iou_score(outputs[-1], targets)
      else:
        output = model(inputs)
        loss = criterion(output, targets)
        iou = iou_score(output, targets)

      # Scale so the accumulated gradient is the average over the group.
      (loss / accumulation_steps).backward()

      if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        optimizer.step()
        optimizer.zero_grad()
        if scheduler is not None:
          scheduler.step()

      avg_meters['loss'].update(loss.item(), inputs.size(0))
      avg_meters['iou'].update(iou, inputs.size(0))

      pbar.set_postfix({
        'loss': f"{avg_meters['loss'].avg:.4f}",
        'iou': f"{avg_meters['iou'].avg:.4f}"
      })
      pbar.update(1)

  return OrderedDict([('loss', avg_meters['loss'].avg), ('iou', avg_meters['iou'].avg)])

def validate_epoch(config, val_loader, model, criterion):
  """Evaluate ``model`` on ``val_loader`` and return its mean loss and IoU.

  Runs in eval mode under ``torch.no_grad()``. With deep supervision the outputs
  are combined with the same 1, 1/2, 1/4, ... weights (normalised to sum to 1)
  that ``train_epoch`` uses. Train and validation losses are therefore on the
  same scale, as is the val-minus-train loss gap used by the training loop's
  overfitting check. IoU is always computed on the last output.

  Args:
    config: reads ``'deep_supervision'``.
    val_loader: sized iterable of ``(input, target, *extra)`` batches.
    model: network called on CUDA inputs.
    criterion: loss called as ``criterion(output, target)``.

  Returns:
    OrderedDict with the sample-weighted average ``'loss'`` and ``'iou'``.
  """
  avg_meters = {'loss': AverageMeter(), 'iou': AverageMeter()}

  model.eval()

  with torch.no_grad(), tqdm(total=len(val_loader), desc='Validation') as pbar:
    for inputs, targets, *_ in val_loader:
      inputs = inputs.cuda()
      targets = targets.cuda()

      # Masks that arrive in the 0..255 range are rescaled to 0..1 (as in training).
      if targets.max() > 1.0:
        targets = targets.float() / 255.0

      if config['deep_supervision']:
        outputs = model(inputs)
        weights = [1.0 / (2 ** j) for j in range(len(outputs))]
        loss = sum(w * criterion(out, targets) for w, out in zip(weights, outputs)) / sum(weights)
        iou = iou_score(outputs[-1], targets)
      else:
        output = model(inputs)
        loss = criterion(output, targets)
        iou = iou_score(output, targets)

      avg_meters['loss'].update(loss.item(), inputs.size(0))
      avg_meters['iou'].update(iou, inputs.size(0))

      pbar.set_postfix({
        'val_loss': f"{avg_meters['loss'].avg:.4f}",
        'val_iou': f"{avg_meters['iou'].avg:.4f}"
      })
      pbar.update(1)

  return OrderedDict([('loss', avg_meters['loss'].avg), ('iou', avg_meters['iou'].avg)])

print("Funciones de entrenamiento y validación definidas")

In [None]:
def train_with_epoch_sampling(config, train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch, optionally on a fresh random subset of the data.

    If ``config['epoch_subset_size']`` is missing, None or <= 0, this is a plain
    ``train_epoch`` over ``train_loader``. Otherwise a new loader over a random,
    epoch-seeded subset of that many samples is built, announced, and used instead.

    Returns:
        Whatever ``train_epoch`` returns (the epoch's mean loss and IoU).
    """
    subset_size = config.get('epoch_subset_size')
    if subset_size is None or subset_size <= 0:
        loader = train_loader
    else:
        loader = create_subset_sampler_loader(train_loader, subset_size, config, epoch=epoch)
        print(f"Época {epoch}: Entrenando con subconjunto de {len(loader.sampler)} muestras "
              f"({len(loader)} batches)")

    return train_epoch(config, loader, model, criterion, optimizer, epoch)

In [None]:
def plot_training_progress(log, save_path=None):
  """Plot a 2x2 dashboard of the training log and optionally save it.

  Panels: train/val loss, train/val IoU, learning rate (log scale), and the
  train-minus-val IoU gap with a dashed reference line at 0.08.

  Args:
    log: mapping with equal-length lists under ``'epoch'``, ``'loss'``,
      ``'val_loss'``, ``'iou'``, ``'val_iou'`` and ``'lr'``.
    save_path: if truthy, the figure is also written there at 300 dpi.
  """
  fig, axes = plt.subplots(2, 2, figsize=(15, 10))
  (ax_loss, ax_iou), (ax_lr, ax_gap) = axes
  epochs = log['epoch']

  # The two train-vs-validation panels share one layout.
  for ax, metric, title, ylabel in (
      (ax_loss, 'loss', 'Training and Validation Loss', 'Loss'),
      (ax_iou, 'iou', 'Training and Validation IoU', 'IoU')):
    ax.plot(epochs, log[metric], 'b-', label=f'Train {ylabel}', linewidth=2)
    ax.plot(epochs, log[f'val_{metric}'], 'r-', label=f'Val {ylabel}', linewidth=2)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True, alpha=0.3)

  ax_lr.plot(epochs, log['lr'], 'g-', linewidth=2)
  ax_lr.set_title('Learning Rate Schedule')
  ax_lr.set_xlabel('Epoch')
  ax_lr.set_ylabel('Learning Rate')
  ax_lr.set_yscale('log')
  ax_lr.grid(True, alpha=0.3)

  # A persistently large positive gap suggests overfitting.
  iou_gap = [train_iou - val_iou for train_iou, val_iou in zip(log['iou'], log['val_iou'])]
  ax_gap.plot(epochs, iou_gap, 'purple', linewidth=2)
  ax_gap.axhline(y=0.08, color='red', linestyle='--', alpha=0.7, label='Overfitting Threshold')
  ax_gap.set_title('Overfitting Detection (Train IoU - Val IoU)')
  ax_gap.set_xlabel('Epoch')
  ax_gap.set_ylabel('IoU Gap')
  ax_gap.legend()
  ax_gap.grid(True, alpha=0.3)

  plt.tight_layout()

  if save_path:
    plt.savefig(save_path, dpi=300, bbox_inches='tight')

  plt.show()

print("Función de visualización definida")

In [None]:
# Per-epoch history; the training loop appends one value per completed epoch to each list.
log = OrderedDict((column, []) for column in ('epoch', 'lr', 'loss', 'iou', 'val_loss', 'val_iou'))

best_iou = 0
early_stopping = EarlyStopping(patience=config['early_stopping'], min_delta=0.001)

In [None]:
# Maximum number of epochs per execution of this cell. Re-running the cell resumes
# after the epochs already recorded in `log`, never going past config['epochs'].
EPOCHS_TO_RUN = 100

current_epoch = len(log['epoch'])
end_epoch = min(current_epoch + EPOCHS_TO_RUN, config['epochs'])

for epoch in range(current_epoch, end_epoch):
  print(f'\nÉpoca [{epoch + 1}/{config["epochs"]}]')

  train_log = train_with_epoch_sampling(config, train_loader, model, criterion, optimizer, epoch)
  val_log = validate_epoch(config, val_loader, model, criterion)

  # LR in effect during this epoch. Read it before any scheduler step so the log
  # and the LR plot do not show next epoch's value instead.
  current_lr = optimizer.param_groups[0]['lr']

  # Epoch-level schedulers. MultiStepLR's milestones are epoch counts, so it also
  # needs one step per epoch here.
  # NOTE(review): OneCycleLR (the default config) is a per-update schedule and is
  # not stepped anywhere in this loop, so its LR stays at max_lr / div_factor.
  # TODO: step it after every optimizer.step() in the training step, with
  # total_steps matching the number of updates actually performed.
  if config['scheduler'] in ('CosineAnnealingLR', 'MultiStepLR'):
    scheduler.step()
  elif config['scheduler'] == 'ReduceLROnPlateau':
    scheduler.step(val_log['loss'])

  print(f'Loss: {train_log["loss"]:.4f} | IoU: {train_log["iou"]:.4f} | '
        f'Val Loss: {val_log["loss"]:.4f} | Val IoU: {val_log["iou"]:.4f} | LR: {current_lr:.2e}')

  # Crude overfitting guard: halve the LR when train IoU leads val IoU by more
  # than 0.08 and val loss exceeds train loss by more than 0.1.
  overfitting_gap = train_log['iou'] - val_log['iou']
  loss_gap = val_log['loss'] - train_log['loss']

  if overfitting_gap > 0.08 and loss_gap > 0.1:
    print(f'OVERFITTING DETECTADO - IoU gap: {overfitting_gap:.4f}, Loss gap: {loss_gap:.4f}')
    for param_group in optimizer.param_groups:
      param_group['lr'] *= 0.5
    print(f'Learning rate reducido a: {optimizer.param_groups[0]["lr"]:.6f}')

  log['epoch'].append(epoch)
  log['lr'].append(current_lr)
  log['loss'].append(train_log['loss'])
  log['iou'].append(train_log['iou'])
  log['val_loss'].append(val_log['loss'])
  log['val_iou'].append(val_log['iou'])

  run_dir = f'models/{config["name"]}'
  pd.DataFrame(log).to_csv(f'{run_dir}/log.csv', index=False)

  # Keep the weights with the best validation IoU seen so far.
  if val_log['iou'] > best_iou:
    best_iou = val_log['iou']
    best_model_path = f'{run_dir}/best_model.pth'
    torch.save(model.state_dict(), best_model_path)
    print(f'Mejor modelo guardado en {best_model_path} (IoU: {best_iou:.4f})')

  # Per-epoch snapshot (1-based file names).
  epoch_model_path = f'{run_dir}/checkpoints/model_epoch_{epoch + 1}.pth'
  os.makedirs(os.path.dirname(epoch_model_path), exist_ok=True)
  torch.save(model.state_dict(), epoch_model_path)
  print(f'Modelo de la época {epoch + 1} guardado en {epoch_model_path}')

  if early_stopping(val_log['iou'], model):
    print("Early stopping activado")
    break

  torch.cuda.empty_cache()

# `epoch` is 0-based and undefined when the loop body never ran; the log length is
# the number of completed epochs (1-based, matching the headers printed above).
print(f"\nEntrenamiento completado hasta época {len(log['epoch'])}")
print(f"Mejor IoU alcanzado: {best_iou:.4f}")

In [None]:
if len(log['epoch']) > 0:
  plot_training_progress(log, f'models/{config["name"]}/training_progress.png')

  # Current training statistics
  print("ESTADÍSTICAS DEL ENTRENAMIENTO:")
  print("-" * 40)
  print(f"Épocas completadas: {len(log['epoch'])}")
  print(f"Mejor IoU: {max(log['val_iou']):.4f}")
  print(f"Menor Loss de validación: {min(log['val_loss']):.4f}")
  print(f"IoU actual: {log['val_iou'][-1]:.4f}")
  print(f"Loss actual: {log['val_loss'][-1]:.4f}")

  # Trend = mean of the last 3 val IoUs minus the mean of the 3 before them, so it
  # needs 6 epochs. With 3-5 epochs, say so instead of reporting 'Estable'
  # without having computed anything.
  n_epochs = len(log['val_iou'])
  if n_epochs >= 6:
    recent_trend = np.mean(log['val_iou'][-3:]) - np.mean(log['val_iou'][-6:-3])
    trend_label = 'Mejorando' if recent_trend > 0 else 'Empeorando' if recent_trend < 0 else 'Estable'
    print(f"Tendencia reciente: {trend_label}")
  elif n_epochs >= 3:
    print("Tendencia reciente: datos insuficientes (se necesitan 6 épocas)")
else:
  print("No hay datos de entrenamiento para visualizar. Ejecuta el bloque de entrenamiento primero.")

In [None]:
def save_checkpoint(model, optimizer, scheduler, epoch, best_iou, log, filename, run_config=None):
  """Save a full training checkpoint with torch.save.

  The checkpoint holds the model, optimizer and (optional) scheduler state dicts,
  plus the epoch count, best IoU, training log and configuration.

  Args:
    model, optimizer: objects whose ``state_dict()`` is stored.
    scheduler: LR scheduler or None (stored as None).
    epoch: number stored under ``'epoch'``.
    best_iou: best validation IoU so far.
    log: per-epoch training history.
    filename: destination path.
    run_config: configuration to embed; defaults to the notebook-level ``config``.
  """
  checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
    'best_iou': best_iou,
    'log': log,
    'config': run_config if run_config is not None else config,
  }
  torch.save(checkpoint, filename)
  print(f"Checkpoint guardado: {filename}")

def load_checkpoint(filename, model, optimizer, scheduler=None, map_location=None):
  """Restore a checkpoint written by ``save_checkpoint`` into existing objects.

  Args:
    filename: checkpoint path.
    model, optimizer: objects whose state is overwritten in place.
    scheduler: optional scheduler; restored only if the checkpoint has its state.
    map_location: forwarded to ``torch.load`` (e.g. ``'cpu'`` on a machine
      without the GPU the checkpoint was saved from).

  Returns:
    tuple: ``(epoch, best_iou, log, config)`` as stored in the checkpoint.
  """
  # NOTE(review): since PyTorch 2.6 torch.load defaults to weights_only=True; if the
  # stored log/config/scheduler state hold non-allowlisted objects this raises —
  # confirm against the installed version.
  checkpoint = torch.load(filename, map_location=map_location)
  model.load_state_dict(checkpoint['model_state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  scheduler_state = checkpoint.get('scheduler_state_dict')
  if scheduler is not None and scheduler_state is not None:
    scheduler.load_state_dict(scheduler_state)

  return (checkpoint['epoch'], checkpoint['best_iou'],
          checkpoint['log'], checkpoint['config'])

In [None]:
print(config['name'])
epochs_logged = len(log['epoch'])
if epochs_logged:
  # File name records how many epochs the checkpoint covers.
  checkpoint_path = f'models/{config["name"]}/checkpoint_epoch_{epochs_logged}.pth'
  save_checkpoint(model, optimizer, scheduler, epochs_logged, best_iou, log, checkpoint_path)

print("Funciones de checkpoint definidas")

In [None]:
def print_final_summary():
  """Imprime un resumen final del entrenamiento"""
  if len(log['epoch']) == 0:
    print("No hay datos de entrenamiento para resumir")
    return
  
  print("\n" + "="*60)
  print("RESUMEN FINAL DEL ENTRENAMIENTO")
  print("="*60)
  
  print(f"Configuración:")
  print(f"   • Modelo: {config['arch']}")
  print(f"   • Función de pérdida: {config['loss']}")
  print(f"   • Optimizador: {config['optimizer']} (LR: {config['lr']})")
  print(f"   • Batch size: {config['batch_size']}")
  print(f"   • Early stopping: {config['early_stopping']} épocas")
  
  print(f"\nResultados:")
  print(f"   • Épocas completadas: {len(log['epoch'])}/{config['epochs']}")
  print(f"   • Mejor IoU: {max(log['val_iou']):.4f}")
  print(f"   • IoU final: {log['val_iou'][-1]:.4f}")
  print(f"   • Mejor Loss validación: {min(log['val_loss']):.4f}")
  print(f"   • Loss final: {log['val_loss'][-1]:.4f}")
  
  # Análisis de overfitting
  final_gap = log['iou'][-1] - log['val_iou'][-1]
  print(f"\nAnálisis de Overfitting:")
  print(f"   • Gap IoU (Train-Val): {final_gap:.4f}")
  if final_gap > 0.08:
    print(f"   • Posible overfitting detectado")
  else:
    print(f"   • Nivel de overfitting aceptable")
  
  # Tendencia de mejora
  if len(log['val_iou']) >= 5:
    recent_iou = np.mean(log['val_iou'][-3:])
    early_iou = np.mean(log['val_iou'][:3])
    improvement = recent_iou - early_iou
    print(f"\nMejora Total:")
    print(f"   • IoU inicial: {early_iou:.4f}")
    print(f"   • IoU reciente: {recent_iou:.4f}")
    print(f"   • Mejora: {improvement:.4f} ({improvement/early_iou*100:.1f}%)")
  
  print(f"\nArchivos guardados:")
  print(f"   • Modelo: models/{config['name']}/model.pth")
  print(f"   • Log: models/{config['name']}/log.csv")
  print(f"   • Config: models/{config['name']}/config.yml")
  print(f"   • Gráficas: models/{config['name']}/training_progress.png")
  
  print("="*60)

# Ejecutar resumen si hay datos
print_final_summary()