# 🚁 Drone Gesture Control - Entrenamiento en Google Colab

## Proyecto Final - Inteligencia Artificial

Este notebook entrena los modelos necesarios para el control de dron con gestos:
1. **Red de Segmentación (UNet)** - Segmenta la mano del fondo
2. **Red Clasificadora (CNN)** - Clasifica el gesto realizado
3. **Red Temporal (GRU)** - Analiza secuencias para suavizado e intensidad

---

## 1. Configuración del Entorno

In [None]:
# Install the third-party packages used by the notebook.
# Each `!pip` line is a notebook shell command (not plain Python); `-q` is
# pip's quiet flag.
!pip install -q torch torchvision
!pip install -q segmentation-models-pytorch
!pip install -q timm
!pip install -q mediapipe
!pip install -q scikit-learn
!pip install -q seaborn
!pip install -q tqdm

print("‚úì Dependencias instaladas")

In [None]:
# Check for a GPU: print the PyTorch version and whether CUDA is available;
# when it is, also print the name and total memory (in GB) of device 0.
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Mount Google Drive so that checkpoints and results are written to Drive
from google.colab import drive
drive.mount('/content/drive')

# Create the project directory and its `checkpoints/` and `results/`
# subfolders (exist_ok=True makes re-running this cell harmless).
import os
PROJECT_DIR = '/content/drive/MyDrive/drone_gesture_control'
os.makedirs(PROJECT_DIR, exist_ok=True)
os.makedirs(f'{PROJECT_DIR}/checkpoints', exist_ok=True)
os.makedirs(f'{PROJECT_DIR}/results', exist_ok=True)
print(f"‚úì Proyecto en: {PROJECT_DIR}")

## 2. Subir Dataset

Sube tu dataset grabado con `dataset_recorder.py` a Google Drive en la carpeta:
`/drone_gesture_control/data/dataset/`

Estructura esperada:
```
dataset/
├── images/
│   ├── PITCH_FORWARD/
│   ├── PITCH_BACKWARD/
│   └── ...
├── masks/
│   └── ...
├── landmarks/
│   └── ...
└── sequences/
    └── ...
```

In [None]:
# Check that the dataset exists on Drive and print a shallow, indented tree:
# every directory is listed, but file names (at most three per directory, plus
# a total count) are only shown for directories less than two levels deep.
DATASET_DIR = f'{PROJECT_DIR}/data/dataset'

if not os.path.exists(DATASET_DIR):
    print(f"‚ö†Ô∏è Dataset no encontrado en {DATASET_DIR}")
    print("Por favor sube tu dataset a Google Drive")
else:
    print("Estructura del dataset:")
    for current_dir, _subdirs, filenames in os.walk(DATASET_DIR):
        # Depth below the dataset root = number of separators left once the
        # root prefix has been stripped.
        depth = current_dir.replace(DATASET_DIR, '').count(os.sep)
        print(f"{' ' * 2 * depth}{os.path.basename(current_dir)}/")
        if depth >= 2:
            continue
        file_pad = ' ' * 2 * (depth + 1)
        for filename in filenames[:3]:
            print(f"{file_pad}{filename}")
        if len(filenames) > 3:
            print(f"{file_pad}... ({len(filenames)} archivos)")

## 3. Configuración del Proyecto

In [None]:
# Project configuration adapted for Colab.
# Relies on PROJECT_DIR and `torch` being defined by earlier cells.
from pathlib import Path

# Paths (DATASET_DIR is rebound here as a Path object)
PROJECT_ROOT = Path(PROJECT_DIR)
DATA_DIR = PROJECT_ROOT / "data"
DATASET_DIR = DATA_DIR / "dataset"
CHECKPOINTS_DIR = PROJECT_ROOT / "checkpoints"
RESULTS_DIR = PROJECT_ROOT / "results"

# Gesture classes: integer label -> class name. The dataset classes use the
# names as sub-folder names under images/ and masks/.
GESTURE_CLASSES = {
    0: "PITCH_FORWARD",
    1: "PITCH_BACKWARD",
    2: "ROLL_RIGHT",
    3: "ROLL_LEFT",
    4: "THROTTLE_UP",
    5: "THROTTLE_DOWN",
    6: "YAW_RIGHT",
    7: "YAW_LEFT",
    8: "HOVER",
    9: "EMERGENCY_STOP",
    10: "NO_GESTURE"
}
NUM_CLASSES = len(GESTURE_CLASSES)

# Training configuration
TRAINING_CONFIG = {
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "seed": 42,
    
    # Segmentation
    "seg_epochs": 30,
    "seg_batch_size": 16,
    "seg_lr": 1e-4,
    
    # Classifier
    "cls_epochs": 25,
    "cls_batch_size": 32,
    "cls_lr": 1e-4,
    
    # Temporal
    "temp_epochs": 30,
    "temp_batch_size": 16,
    "temp_lr": 1e-3,
    
    # General
    # NOTE(review): the dataset classes read only train_split and val_split
    # (the test split is whatever remains); "test_split" and "patience" do not
    # appear to be read by the visible cells - confirm before relying on them.
    "train_split": 0.7,
    "val_split": 0.15,
    "test_split": 0.15,
    "patience": 10,
    "augmentation": True,
}

print(f"Device: {TRAINING_CONFIG['device']}")
print(f"Clases: {NUM_CLASSES}")

## 4. Datasets

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import random

class GestureClassificationDataset(Dataset):
    """Gesture-classification dataset backed by per-class image folders.

    Collects ``<root_dir>/images/<CLASS_NAME>/*.jpg`` for every class in
    ``GESTURE_CLASSES``, shuffles the samples deterministically with
    ``TRAINING_CONFIG["seed"]`` and keeps the slice that belongs to ``split``.

    Args:
        root_dir: Dataset root that contains the ``images`` folder.
        transform: Callable applied to each RGB PIL image. When ``None`` a
            default pipeline is built: resize to 224x224, ``ToTensor`` and
            ImageNet normalisation, plus random augmentation for the train
            split when ``TRAINING_CONFIG["augmentation"]`` is true.
        split: ``'train'``, ``'val'`` or ``'test'``. Any other value keeps
            every sample.

    Each item is ``(image, class_id)`` where ``class_id`` is the integer key
    from ``GESTURE_CLASSES``.
    """

    def __init__(self, root_dir, transform=None, split='train'):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.split = split
        self.images_dir = self.root_dir / "images"

        # Collect samples. The listing is sorted because glob order depends on
        # the filesystem: without a stable order the seeded shuffle below could
        # differ between the train/val/test instances and leak samples across
        # splits.
        samples = []
        for class_id, class_name in GESTURE_CLASSES.items():
            class_dir = self.images_dir / class_name
            if not class_dir.exists():
                continue
            samples.extend(
                (img_file, class_id) for img_file in sorted(class_dir.glob("*.jpg"))
            )

        # Shuffle with a private RNG so that building a dataset does not reset
        # the global `random` state as a side effect.
        random.Random(TRAINING_CONFIG["seed"]).shuffle(samples)

        n = len(samples)
        train_end = int(n * TRAINING_CONFIG["train_split"])
        val_end = train_end + int(n * TRAINING_CONFIG["val_split"])

        if split == 'train':
            samples = samples[:train_end]
        elif split == 'val':
            samples = samples[train_end:val_end]
        elif split == 'test':
            # The test split is whatever remains after train and val.
            samples = samples[val_end:]
        self.samples = samples

        # Default transform
        if self.transform is None:
            if split == 'train' and TRAINING_CONFIG["augmentation"]:
                # NOTE(review): RandomHorizontalFlip mirrors the hand; if the
                # *_LEFT / *_RIGHT gestures are mirror images of each other
                # this augmentation mislabels them - confirm with the gesture
                # definitions.
                self.transform = transforms.Compose([
                    transforms.Resize((224, 224)),
                    transforms.RandomHorizontalFlip(),
                    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
                    transforms.RandomRotation(20),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                ])
            else:
                self.transform = transforms.Compose([
                    transforms.Resize((224, 224)),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                ])

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(transformed_image, class_id)``."""
        img_path, class_id = self.samples[idx]
        # Close the file handle deterministically; convert() returns a new,
        # fully loaded image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        image = self.transform(image)
        return image, class_id


class SegmentationDataset(Dataset):
    """Hand-segmentation dataset of (image, binary mask) pairs.

    For every class in ``GESTURE_CLASSES`` pairs
    ``<root_dir>/images/<CLASS>/<stem>.jpg`` with
    ``<root_dir>/masks/<CLASS>/<stem>.png``; images without a mask file are
    skipped. Samples are shuffled deterministically with
    ``TRAINING_CONFIG["seed"]`` and sliced according to ``split``.

    Args:
        root_dir: Dataset root that contains ``images`` and ``masks``.
        transform: Callable applied to the RGB image. Defaults to resize to
            256x256, ``ToTensor`` and ImageNet normalisation.
        mask_transform: Callable applied to the grayscale mask. Defaults to
            resize to 256x256 and ``ToTensor``.
        split: ``'train'``, ``'val'`` or ``'test'``. Any other value keeps
            every sample.

    Note:
        ``transform`` and ``mask_transform`` are applied independently, so
        custom *random geometric* transforms would not be applied identically
        to the image and its mask.
    """

    def __init__(self, root_dir, transform=None, mask_transform=None, split='train'):
        self.root_dir = Path(root_dir)
        self.split = split
        self.images_dir = self.root_dir / "images"
        self.masks_dir = self.root_dir / "masks"

        # Collect samples. Sorted so that the seeded shuffle produces the same
        # permutation for the train/val/test instances regardless of the
        # filesystem's listing order.
        samples = []
        for class_name in GESTURE_CLASSES.values():
            class_img_dir = self.images_dir / class_name
            class_mask_dir = self.masks_dir / class_name
            if not class_img_dir.exists():
                continue
            for img_file in sorted(class_img_dir.glob("*.jpg")):
                mask_file = class_mask_dir / f"{img_file.stem}.png"
                if mask_file.exists():
                    samples.append((img_file, mask_file))

        # Private RNG: do not reset the global `random` state as a side effect.
        random.Random(TRAINING_CONFIG["seed"]).shuffle(samples)

        n = len(samples)
        train_end = int(n * TRAINING_CONFIG["train_split"])
        val_end = train_end + int(n * TRAINING_CONFIG["val_split"])

        if split == 'train':
            samples = samples[:train_end]
        elif split == 'val':
            samples = samples[train_end:val_end]
        elif split == 'test':
            samples = samples[val_end:]
        self.samples = samples

        # Transforms: honour the caller's arguments (they used to be silently
        # overwritten) and fall back to the defaults otherwise.
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        if mask_transform is None:
            mask_transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor()
            ])
        self.transform = transform
        self.mask_transform = mask_transform

    def __len__(self):
        """Return the number of (image, mask) pairs in the selected split."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(image, mask)``; the mask is a 0/1 float tensor."""
        img_path, mask_path = self.samples[idx]
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert('L')

        image = self.transform(image)
        mask = self.mask_transform(mask)
        # Re-binarise after resizing: values above 0.5 become foreground.
        mask = (mask > 0.5).float()

        return image, mask


# Build the dataloaders
def get_dataloaders(dataset_type='classification', batch_size=32):
    """Create train/val/test ``DataLoader``s over ``DATASET_DIR``.

    Args:
        dataset_type: ``'classification'`` selects
            ``GestureClassificationDataset``; any other value selects
            ``SegmentationDataset``.
        batch_size: Batch size shared by the three loaders.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. Only the train
        loader shuffles; all three use two worker processes.
    """
    if dataset_type == 'classification':
        dataset_cls = GestureClassificationDataset
    else:
        dataset_cls = SegmentationDataset

    train_ds, val_ds, test_ds = [
        dataset_cls(DATASET_DIR, split=split_name)
        for split_name in ('train', 'val', 'test')
    ]

    def _make_loader(dataset, shuffle):
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=2)

    loaders = (
        _make_loader(train_ds, True),
        _make_loader(val_ds, False),
        _make_loader(test_ds, False),
    )

    print(f"Dataset {dataset_type}: Train={len(train_ds)}, Val={len(val_ds)}, Test={len(test_ds)}")

    return loaders

# Sanity-check the datasets: when the dataset folder exists, build the
# classification loaders and print the shape and first labels of one batch.
if DATASET_DIR.exists():
    train_loader, val_loader, test_loader = get_dataloaders('classification')
    for images, labels in train_loader:
        print(f"Batch shape: {images.shape}, Labels: {labels[:5]}")
        break  # only the first batch is inspected

## 5. Modelos

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

# CNN classifier
class GestureClassifier(nn.Module):
    """ResNet-18 feature extractor followed by a small MLP classification head.

    Args:
        num_classes: Number of output logits.
        pretrained: Load the ``IMAGENET1K_V1`` weights into the backbone when
            true; otherwise the backbone is randomly initialised.
    """

    def __init__(self, num_classes=NUM_CLASSES, pretrained=True):
        super().__init__()
        weights = 'IMAGENET1K_V1' if pretrained else None
        resnet = models.resnet18(weights=weights)
        # Keep the size of the pooled features, then drop ResNet's own
        # classification layer so the backbone outputs raw features.
        self.feature_dim = resnet.fc.in_features
        resnet.fc = nn.Identity()
        self.backbone = resnet

        head_layers = [
            nn.Dropout(0.3),
            nn.Linear(self.feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        return self.classifier(self.backbone(x))


# UNet for segmentation
try:
    import segmentation_models_pytorch as smp
except ImportError:
    # Fall back to the plain UNet defined below so that `SegmentationModel`
    # always exists (previously only the message was printed and the class
    # was left undefined).
    smp = None
    print("SMP no disponible, usando UNet b√°sico")


def _double_conv(in_ch, out_ch):
    """Return two 3x3 conv + BatchNorm + ReLU layers (spatial size preserved)."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )


class _BasicUNet(nn.Module):
    """Small dependency-free UNet used when segmentation_models_pytorch is missing.

    Args:
        in_channels: Channels of the input image.
        classes: Number of output channels (logits).
        widths: Channel width of each encoder level, shallowest first.
    """

    def __init__(self, in_channels=3, classes=2, widths=(16, 32, 64, 128)):
        super().__init__()
        self.encoders = nn.ModuleList()
        prev = in_channels
        for width in widths:
            self.encoders.append(_double_conv(prev, width))
            prev = width

        # One decoder stage per skip connection, deepest first.
        self.decoders = nn.ModuleList()
        for width in reversed(widths[:-1]):
            self.decoders.append(_double_conv(prev + width, width))
            prev = width

        self.head = nn.Conv2d(prev, classes, kernel_size=1)

    def forward(self, x):
        skips = []
        for level, encoder in enumerate(self.encoders):
            if level > 0:
                x = F.max_pool2d(x, 2)
            x = encoder(x)
            skips.append(x)
        skips.pop()  # the deepest output is the bottleneck, not a skip

        for decoder in self.decoders:
            skip = skips.pop()
            # Upsample to the skip's exact size so odd input sizes also work.
            x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
            x = decoder(torch.cat([skip, x], dim=1))
        return self.head(x)


class SegmentationModel(nn.Module):
    """Two-channel hand-segmentation network.

    Wraps ``smp.Unet`` (MobileNetV2 encoder, ImageNet weights) when
    ``segmentation_models_pytorch`` is importable, otherwise a basic UNet.
    In both cases the wrapped network is stored as ``self.model`` and maps
    3-channel images to 2-channel logits.
    """

    def __init__(self):
        super().__init__()
        if smp is not None:
            self.model = smp.Unet(
                encoder_name="mobilenet_v2",
                encoder_weights="imagenet",
                in_channels=3,
                classes=2
            )
        else:
            self.model = _BasicUNet(in_channels=3, classes=2)

    def forward(self, x):
        """Return per-pixel logits with two channels for the image batch ``x``."""
        return self.model(x)


# Temporal GRU network
class TemporalGRU(nn.Module):
    """GRU over per-frame feature vectors with a gesture head and an intensity head.

    Args:
        input_size: Size of each per-frame feature vector.
        hidden_size: Width of the input projection and of the GRU state.
        num_layers: Number of stacked GRU layers.
        num_classes: Number of gesture logits.
    """

    def __init__(self, input_size=512+63, hidden_size=256, num_layers=2, num_classes=NUM_CLASSES):
        super().__init__()
        self.hidden_size = hidden_size
        half_hidden = hidden_size // 2

        # Project the raw features to the GRU width before the recurrence.
        self.input_proj = nn.Linear(input_size, hidden_size)
        self.gru = nn.GRU(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.3,
        )

        # Gesture classification head.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, num_classes),
        )

        # Intensity head; the final Sigmoid bounds the output to (0, 1).
        self.intensity = nn.Sequential(
            nn.Linear(hidden_size, half_hidden),
            nn.ReLU(),
            nn.Linear(half_hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``(logits, intensity)`` for a batch-first sequence tensor.

        ``x`` is ``(batch, time, input_size)``. ``logits`` is
        ``(batch, num_classes)`` and ``intensity`` is ``(batch,)``.
        """
        projected = self.input_proj(x)
        _, final_states = self.gru(projected)
        # Summarise the sequence with the last layer's final hidden state.
        summary = final_states[-1]

        gesture_logits = self.classifier(summary)
        gesture_intensity = self.intensity(summary).squeeze(-1)
        return gesture_logits, gesture_intensity


# Full CNN + GRU model
class GestureSequenceModel(nn.Module):
    """Per-frame MobileNetV3-Small features (+ optional landmarks) fed to ``TemporalGRU``.

    Args:
        pretrained: Load the ``IMAGENET1K_V1`` weights into the CNN when true.
        freeze_cnn: When true the CNN's parameters get ``requires_grad=False``
            and the CNN is kept in eval mode so that its BatchNorm running
            statistics are not updated during training either.
    """

    def __init__(self, pretrained=True, freeze_cnn=True):
        super().__init__()
        self.freeze_cnn = freeze_cnn

        # CNN backbone: remember the feature size, then drop the classifier so
        # the network outputs pooled features.
        self.cnn = models.mobilenet_v3_small(weights='IMAGENET1K_V1' if pretrained else None)
        self.cnn_dim = self.cnn.classifier[0].in_features
        self.cnn.classifier = nn.Identity()

        if freeze_cnn:
            for p in self.cnn.parameters():
                p.requires_grad = False
            self.cnn.eval()

        # Landmark projection (63 values per frame -> 64 features).
        # presumably 21 hand landmarks x 3 coordinates - TODO confirm
        self.landmark_proj = nn.Linear(63, 64)

        # Temporal model over the concatenated per-frame features
        self.temporal = TemporalGRU(input_size=self.cnn_dim + 64)

    def train(self, mode=True):
        """Switch train/eval mode but keep a frozen CNN in eval mode.

        ``requires_grad=False`` alone does not stop BatchNorm layers from
        updating their running statistics in train mode, so a frozen backbone
        must stay in eval mode to be truly fixed.
        """
        super().train(mode)
        if self.freeze_cnn:
            self.cnn.eval()
        return self

    def forward(self, frames, landmarks=None):
        """Return ``(logits, intensity)`` for a batch of frame sequences.

        Args:
            frames: Tensor of shape ``(B, T, C, H, W)``.
            landmarks: Optional tensor of shape ``(B, T, 63)``; when omitted,
                zeros are used in place of the landmark features.
        """
        B, T, C, H, W = frames.shape

        # reshape (not view) so that non-contiguous inputs are accepted too.
        frames_flat = frames.reshape(B * T, C, H, W)
        cnn_features = self.cnn(frames_flat)
        cnn_features = cnn_features.reshape(B, T, -1)

        if landmarks is not None:
            lm_features = self.landmark_proj(landmarks)
        else:
            lm_features = torch.zeros(B, T, 64, device=frames.device)
        combined = torch.cat([cnn_features, lm_features], dim=2)

        return self.temporal(combined)


# Smoke-test the models: run random tensors through them and print the
# input/output shapes.
print("\nTest de modelos:")

# GestureClassifier() is built with its default pretrained=True.
model = GestureClassifier()
x = torch.randn(2, 3, 224, 224)
y = model(x)
print(f"Clasificador: {x.shape} -> {y.shape}")

# 575 matches TemporalGRU's default input_size (512+63); 15 is the sequence length.
model = TemporalGRU()
x = torch.randn(2, 15, 575)
logits, intensity = model(x)
print(f"GRU: {x.shape} -> logits={logits.shape}, intensity={intensity.shape}")

## 6. Funciones de Entrenamiento

In [None]:
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Module that maps an image batch to class logits.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimiser stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for images, labels in tqdm(dataloader, desc="Training"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        predicted = torch.max(outputs, 1)[1]
        # Weight the batch loss by the batch size so the epoch figure is a
        # per-sample mean.
        loss_sum += loss.item() * images.size(0)
        n_correct += (predicted == labels).sum().item()
        n_seen += labels.size(0)

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen
    return mean_loss, accuracy


def validate(model, dataloader, criterion, device, return_preds=False):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Args:
        model: Module that maps an image batch to class logits.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.
        return_preds: Also return the per-sample predictions and labels.

    Returns:
        ``(mean_loss, accuracy)``, or
        ``(mean_loss, accuracy, all_preds, all_labels)`` when ``return_preds``
        is true.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Validating"):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            predicted = torch.max(outputs, 1)[1]
            loss_sum += loss.item() * images.size(0)
            n_correct += (predicted == labels).sum().item()
            n_seen += labels.size(0)

            if not return_preds:
                continue
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen
    if return_preds:
        return mean_loss, accuracy, all_preds, all_labels
    return mean_loss, accuracy


def plot_training_curves(train_losses, train_accs, val_losses, val_accs,
                         title="Training Curves", metric_name="Accuracy"):
    """Plot loss and metric curves side by side, save them and show them.

    Args:
        train_losses: Per-epoch training loss.
        train_accs: Per-epoch training metric.
        val_losses: Per-epoch validation loss.
        val_accs: Per-epoch validation metric.
        title: Figure title; also used (spaces replaced by ``_``, lowercased)
            as the PNG file name inside ``RESULTS_DIR``.
        metric_name: Axis label and title of the second panel. Defaults to
            ``"Accuracy"``; pass e.g. ``"IoU"`` when plotting other metrics.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    epochs = range(1, len(train_losses) + 1)

    panels = (
        (axes[0], train_losses, val_losses, 'Loss'),
        (axes[1], train_accs, val_accs, metric_name),
    )
    for ax, train_values, val_values, label in panels:
        ax.plot(epochs, train_values, 'b-', label='Train')
        ax.plot(epochs, val_values, 'r-', label='Val')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(label)
        ax.legend()
        ax.set_title(label)
        ax.grid(True)

    fig.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    # Save before show().
    plt.savefig(RESULTS_DIR / f"{title.replace(' ', '_').lower()}.png", dpi=150)
    plt.show()


def plot_confusion_matrix(preds, labels, class_names):
    """Plot, save and show the confusion matrix of ``preds`` against ``labels``.

    Args:
        preds: Predicted class ids.
        labels: True class ids.
        class_names: Display names; ``class_names[i]`` must be the name of
            class id ``i``.

    The matrix always has ``len(class_names)`` rows and columns, so the tick
    labels stay aligned even when a class is absent from ``preds``/``labels``.
    The figure is written to ``RESULTS_DIR / "confusion_matrix.png"``.
    """
    # Without an explicit `labels` list sklearn only includes classes that
    # actually occur, which would misalign the matrix with class_names.
    class_ids = list(range(len(class_names)))
    cm = confusion_matrix(labels, preds, labels=class_ids)
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(RESULTS_DIR / "confusion_matrix.png", dpi=150)
    plt.show()

print("‚úì Funciones de entrenamiento definidas")

## 7. Entrenar Clasificador CNN

In [None]:
# Configuration for the classifier run, read from TRAINING_CONFIG
EPOCHS = TRAINING_CONFIG["cls_epochs"]
BATCH_SIZE = TRAINING_CONFIG["cls_batch_size"]
LR = TRAINING_CONFIG["cls_lr"]
DEVICE = TRAINING_CONFIG["device"]

print(f"\n{'='*50}")
print("ENTRENAMIENTO DE CLASIFICADOR CNN")
print(f"{'='*50}")
print(f"Epochs: {EPOCHS}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Learning rate: {LR}")
print(f"Device: {DEVICE}")

# Datasets
train_loader, val_loader, test_loader = get_dataloaders('classification', BATCH_SIZE)

# Model, loss, optimiser and LR schedule (the scheduler is stepped once per epoch below)
model = GestureClassifier(pretrained=True).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)

# Per-epoch history (plotted by a later cell) and best validation accuracy so far
train_losses, train_accs = [], []
val_losses, val_accs = [], []
best_val_acc = 0.0

# Training loop.
# NOTE: TRAINING_CONFIG["patience"] is not used here; all EPOCHS are always run.
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, DEVICE)
    val_loss, val_acc = validate(model, val_loader, criterion, DEVICE)
    
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    val_losses.append(val_loss)
    val_accs.append(val_acc)
    
    print(f"  Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
    print(f"  Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")
    
    scheduler.step()
    
    # Save a checkpoint whenever validation accuracy improves. With a strict
    # `>` and a 0.0 starting value, no file is written if accuracy never
    # rises above zero.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
        }, CHECKPOINTS_DIR / 'classifier_resnet18_best.pt')
        print(f"  ‚úì Mejor modelo guardado (acc={val_acc:.4f})")

print(f"\n¬°Entrenamiento completado! Mejor accuracy: {best_val_acc:.4f}")

In [None]:
# Plot the classifier's training curves; the title also determines the name
# of the PNG saved in RESULTS_DIR.
plot_training_curves(train_losses, train_accs, val_losses, val_accs, 
                     title="Clasificador CNN")

In [None]:
# Evaluation on the test split using the best checkpoint.
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime.
checkpoint = torch.load(CHECKPOINTS_DIR / 'classifier_resnet18_best.pt', map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])

test_loss, test_acc, preds, labels = validate(model, test_loader, criterion, DEVICE, return_preds=True)
print(f"\nTest - Loss: {test_loss:.4f}, Accuracy: {test_acc:.4f}")

# Confusion matrix (class names are truncated to 10 characters for readability)
class_ids = list(range(NUM_CLASSES))
class_names = [GESTURE_CLASSES[i][:10] for i in class_ids]
plot_confusion_matrix(preds, labels, class_names)

# Report. `labels=class_ids` keeps the report aligned with `target_names`
# even when some class never appears in the test split; without it sklearn
# raises because the number of classes found differs from len(target_names).
print("\nReporte de clasificaci√≥n:")
print(classification_report(labels, preds, labels=class_ids,
                            target_names=class_names, zero_division=0))

## 8. Entrenar Red de Segmentación

In [None]:
# Dice + BCE loss
class DiceBCELoss(nn.Module):
    """Equal-weight sum of BCE-with-logits and soft Dice loss on channel 1.

    Only channel 1 of the logits is used; channel 0 does not contribute to
    the loss.
    """

    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, logits, targets):
        """Return the scalar loss.

        Args:
            logits: Tensor indexed as ``logits[:, 1, :, :]`` (at least two channels).
            targets: Mask tensor with a singleton channel at dim 1.
        """
        fg_logits = logits[:, 1, :, :]
        fg_targets = targets.squeeze(1)

        bce_term = self.bce(fg_logits, fg_targets)

        fg_probs = torch.sigmoid(fg_logits)
        overlap = (fg_probs * fg_targets).sum(dim=(1, 2))
        denominator = fg_probs.sum(dim=(1, 2)) + fg_targets.sum(dim=(1, 2))
        # Per-sample soft Dice; the epsilon guards against empty masks.
        dice_term = 1 - (2. * overlap + 1e-6) / (denominator + 1e-6)

        return 0.5 * bce_term + 0.5 * dice_term.mean()


def calculate_iou(pred, target):
    """Return the intersection-over-union of two binary masks as a float.

    Both tensors are flattened, so for a batch the IoU is computed over all
    pixels of the batch at once (not averaged per sample). Two empty masks
    give 1.0 because of the epsilon terms.

    Args:
        pred: Tensor of 0/1 predictions.
        target: Tensor of 0/1 ground-truth values with the same number of
            elements as ``pred``.
    """
    # reshape (not view) so that non-contiguous tensors, e.g. channel slices
    # or transposes, are accepted as well.
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    return ((intersection + 1e-6) / (union + 1e-6)).item()


# Configuration for the segmentation run (DEVICE comes from the classifier cell)
EPOCHS = TRAINING_CONFIG["seg_epochs"]
BATCH_SIZE = TRAINING_CONFIG["seg_batch_size"]
LR = TRAINING_CONFIG["seg_lr"]

print(f"\n{'='*50}")
print("ENTRENAMIENTO DE RED DE SEGMENTACI√ìN")
print(f"{'='*50}")

# Datasets
train_loader, val_loader, test_loader = get_dataloaders('segmentation', BATCH_SIZE)

# Model, loss and optimiser
seg_model = SegmentationModel().to(DEVICE)
criterion = DiceBCELoss()
optimizer = torch.optim.AdamW(seg_model.parameters(), lr=LR)

# Per-epoch history and best validation IoU so far
train_losses, train_ious = [], []
val_losses, val_ious = [], []
best_iou = 0.0

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    
    # Train
    seg_model.train()
    epoch_loss, epoch_iou = 0.0, 0.0
    for images, masks in tqdm(train_loader, desc="Training"):
        images = images.to(DEVICE)
        masks = masks.to(DEVICE)
        
        optimizer.zero_grad()
        outputs = seg_model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        
        with torch.no_grad():
            # Foreground prediction: channel 1 through a sigmoid, thresholded at 0.5
            pred = (torch.sigmoid(outputs[:, 1]) > 0.5).float()
            iou = calculate_iou(pred, masks.squeeze(1))
        
        epoch_loss += loss.item()
        epoch_iou += iou
    
    # Epoch metrics are means over batches (not over samples)
    train_losses.append(epoch_loss / len(train_loader))
    train_ious.append(epoch_iou / len(train_loader))
    
    # Validate
    seg_model.eval()
    val_loss, val_iou = 0.0, 0.0
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(DEVICE)
            masks = masks.to(DEVICE)
            outputs = seg_model(images)
            loss = criterion(outputs, masks)
            pred = (torch.sigmoid(outputs[:, 1]) > 0.5).float()
            val_loss += loss.item()
            val_iou += calculate_iou(pred, masks.squeeze(1))
    
    val_losses.append(val_loss / len(val_loader))
    val_ious.append(val_iou / len(val_loader))
    
    print(f"  Train - Loss: {train_losses[-1]:.4f}, IoU: {train_ious[-1]:.4f}")
    print(f"  Val   - Loss: {val_losses[-1]:.4f}, IoU: {val_ious[-1]:.4f}")
    
    # Save a checkpoint whenever the validation IoU improves (model weights
    # and epoch only; the optimiser state is not stored).
    if val_ious[-1] > best_iou:
        best_iou = val_ious[-1]
        torch.save({
            'epoch': epoch,
            'model_state_dict': seg_model.state_dict(),
        }, CHECKPOINTS_DIR / 'segmentation_unet_best.pt')
        print(f"  ‚úì Mejor modelo guardado (IoU={best_iou:.4f})")

print(f"\n¬°Entrenamiento completado! Mejor IoU: {best_iou:.4f}")

In [None]:
# Plot the segmentation training curves.
# NOTE: the IoU histories are passed in the accuracy slots, so the second
# panel is labelled "Accuracy" although it shows IoU.
plot_training_curves(train_losses, train_ious, val_losses, val_ious,
                     title="Segmentaci√≥n UNet")

## 9. Descargar Modelos

In [None]:
# Bundle the checkpoints into a zip file and download it
import shutil

# List every file in the checkpoints folder with its size in MB
print("Archivos en checkpoints:")
for f in CHECKPOINTS_DIR.glob("*"):
    size = f.stat().st_size / 1e6
    print(f"  {f.name}: {size:.1f} MB")

# Create the zip (make_archive appends the ".zip" extension to the base name)
shutil.make_archive('/content/checkpoints', 'zip', CHECKPOINTS_DIR)
print(f"\n‚úì Archivo zip creado: /content/checkpoints.zip")

# Download through the Colab helper
from google.colab import files
files.download('/content/checkpoints.zip')

## 10. Resumen Final

In [None]:
# Final summary: report which checkpoints exist and print the next steps.
print("\n" + "="*60)
print("RESUMEN DE ENTRENAMIENTO")
print("="*60)

# Check which models were saved: (display name, checkpoint file name).
# NOTE(review): no visible cell writes temporal_gru_best.pt, so the GRU entry
# is expected to be reported as not trained unless it is produced elsewhere.
models_info = [
    ("Clasificador CNN", "classifier_resnet18_best.pt"),
    ("Segmentaci√≥n UNet", "segmentation_unet_best.pt"),
    ("Red Temporal GRU", "temporal_gru_best.pt"),
]

for name, filename in models_info:
    path = CHECKPOINTS_DIR / filename
    if path.exists():
        size = path.stat().st_size / 1e6
        print(f"‚úì {name}: {size:.1f} MB")
    else:
        print(f"‚úó {name}: No entrenado")

print("\n" + "="*60)
print("PR√ìXIMOS PASOS")
print("="*60)
print("1. Descarga los checkpoints")
print("2. Copia a la carpeta 'checkpoints/' del proyecto")
print("3. Ejecuta: python main.py --mode integrated")
print("="*60)