# 🎓 VideoMAE - Sign Language Recognition (WLASL)
## Complete Training & Evaluation Pipeline for Thesis

**Author:** Rafael Ovalle - UNAB Thesis  
**Dataset:** WLASL100/WLASL300 (American Sign Language)  
**Model:** VideoMAE (MCG-NJU/videomae-base-finetuned-kinetics)

---

### 📋 Notebook Features:

- ✅ **Flexible Configuration:** Choose between V1 (baseline) and V2 (experimental)
- ✅ **Multiple Datasets:** WLASL100 (100 classes) or WLASL300 (300 classes)
- ✅ **Full Training:** With early stopping, checkpointing, and TensorBoard
- ✅ **Detailed Evaluation:** Accuracy, Precision, Recall, F1, Top-K, Confusion Matrix
- ✅ **Visualizations:** Training plots, confusion matrices, per-class analysis
- ✅ **Automatic Export:** Results as JSON, TXT, images, and the final model
- ✅ **Drive Integration:** Saves everything to your Google Drive automatically
- ✅ **HP Optimization:** Hyperparameter search (optional)

---

### 🗂️ Available Configurations:

| Configuration | Dataset | Train Videos | Val Videos | Test Videos | Recommended Use |
|---------------|---------|--------------|------------|-------------|-----------------|
| **V1 - WLASL100** | 100 classes | 807 | 194 | 117 | Baseline, experimentation |
| **V2 - WLASL100** | 100 classes | 1,001 | 117 | 117 | Maximize training data |
| **V1 - WLASL300** | 300 classes | 1,959 | 557 | 271 | Baseline, experimentation |
| **V2 - WLASL300** | 300 classes | 2,516 | 271 | 271 | Maximize training data |

---

### ⚙️ V1 vs V2 Differences:

**V1 (Baseline):**
- Separate, independent train/val/test splits
- Active regularization (weight decay, label smoothing, class weights)
- Batch size: 16, LR: 1e-4
- Suited to experimentation and hyperparameter tuning

**V2 (Experimental):**
- Train+val merged for training; the test split doubles as the validation set, so it also drives early stopping and best-checkpoint selection
- No explicit regularization (relies on the larger training set)
- Batch size: 6, LR: 1e-5
- Suited to a final model trained on as much data as possible
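
The V2 row counts in the table follow directly from the V1 splits: 807 + 194 = 1,001 training videos for WLASL100 and 1,959 + 557 = 2,516 for WLASL300. A minimal sketch of that regrouping, assuming each split is simply a list of video file names. The helper `build_v2_splits` and the toy file names are hypothetical and not part of this notebook:

```python
def build_v2_splits(train, val, test):
    """Regroup V1 splits into the V2 layout (hypothetical helper).

    Train and val are merged for training, and the test list is
    reused as the validation list.
    """
    merged_train = list(train) + list(val)
    return {"train": merged_train, "val": list(test), "test": list(test)}


# Toy file names standing in for real WLASL100 clips.
v1_train = [f"train_{i:05d}.mp4" for i in range(807)]
v1_val = [f"val_{i:05d}.mp4" for i in range(194)]
v1_test = [f"test_{i:05d}.mp4" for i in range(117)]

v2 = build_v2_splits(v1_train, v1_val, v1_test)
print({name: len(files) for name, files in v2.items()})
```

The printed sizes can be compared against the V2 - WLASL100 row of the table.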

---

# 1️⃣ Initial Setup

## 1.1 Check GPU and Configuration

In [None]:
# Check the available GPU
!nvidia-smi

import torch
print(f"\n{'='*60}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
print(f"{'='*60}")

## 1.2 Install Dependencies

In [None]:
# Install the required dependencies.
# -q keeps pip quiet; %%capture is avoided because it would also hide the final message.
!pip install -q transformers==4.36.0
!pip install -q torch torchvision torchaudio
!pip install -q opencv-python-headless
!pip install -q scikit-learn
!pip install -q matplotlib seaborn
!pip install -q tensorboard
!pip install -q tqdm
!pip install -q pandas

print("✅ All dependencies installed successfully")

## 1.3 Import Libraries

In [None]:
import os
import json
import shutil
from datetime import datetime
from pathlib import Path
from collections import Counter
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.utils.tensorboard import SummaryWriter

from transformers import VideoMAEForVideoClassification
from torchvision import transforms

from sklearn.metrics import (
    confusion_matrix,
    classification_report,
    accuracy_score,
    precision_recall_fscore_support,
    top_k_accuracy_score
)

from tqdm.auto import tqdm

# Configure matplotlib for nicer plots
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

print("✅ Libraries imported successfully")

## 1.4 Mount Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Create the working folder in Drive
DRIVE_ROOT = "/content/drive/MyDrive/TESIS_WLASL"
os.makedirs(DRIVE_ROOT, exist_ok=True)

print(f"✅ Google Drive mounted; working folder: {DRIVE_ROOT}")

# 2️⃣ Experiment Configuration

## 2.1 Select a Configuration

**🎯 Instructions:**
1. Choose the dataset (WLASL100 or WLASL300)
2. Choose the version (V1 baseline or V2 experimental)
3. Adjust hyperparameters if needed

In [None]:
# ============================================================
#   MAIN CONFIGURATION
# ============================================================

# 🎯 SELECT YOUR CONFIGURATION HERE:
DATASET_TYPE = "wlasl100"  # Options: "wlasl100" or "wlasl300"
VERSION = "v1"             # Options: "v1" (baseline) or "v2" (experimental)

# Derive the remaining settings from the selection
if DATASET_TYPE == "wlasl100":
    NUM_CLASSES = 100
    DATASET_NAME = "wlasl100_v2" if VERSION == "v2" else "wlasl100"
elif DATASET_TYPE == "wlasl300":
    NUM_CLASSES = 300
    DATASET_NAME = "wlasl300_v2" if VERSION == "v2" else "wlasl300"
else:
    raise ValueError("DATASET_TYPE must be 'wlasl100' or 'wlasl300'")

# Hyperparameters for the selected version
if VERSION == "v1":
    CONFIG = {
        "model_name": "MCG-NJU/videomae-base-finetuned-kinetics",
        "num_classes": NUM_CLASSES,
        "batch_size": 16,
        "max_epochs": 30,
        "lr": 1e-4,
        "weight_decay": 0.05,
        "label_smoothing": 0.1,
        "class_weighted": True,
        "warmup_ratio": 0.1,
        "min_lr": 1e-6,
        "patience": 5,
        "gradient_clip": 1.0,
        "num_workers": 2,
        "save_every": 5,
    }
elif VERSION == "v2":
    CONFIG = {
        "model_name": "MCG-NJU/videomae-base-finetuned-kinetics",
        "num_classes": NUM_CLASSES,
        "batch_size": 6,
        "max_epochs": 30,
        "lr": 1e-5,
        "weight_decay": 0.0,
        "label_smoothing": 0.0,
        "class_weighted": False,
        "warmup_ratio": 0.1,
        "min_lr": 1e-6,
        "patience": 10,
        "gradient_clip": 1.0,
        "num_workers": 2,
        "save_every": 5,
    }
else:
    raise ValueError("VERSION must be 'v1' or 'v2'")

# Path configuration
CONFIG.update({
    "dataset_name": DATASET_NAME,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "data_root": f"{DRIVE_ROOT}/data/{DATASET_NAME}",
    "checkpoint_dir": f"{DRIVE_ROOT}/models/{VERSION}/{DATASET_NAME}/checkpoints",
    "logs_dir": f"{DRIVE_ROOT}/runs/{VERSION}/{DATASET_NAME}",
    "results_dir": f"{DRIVE_ROOT}/results/{VERSION}/{DATASET_NAME}",
})

# Create directories
for key in ["checkpoint_dir", "logs_dir", "results_dir"]:
    os.makedirs(CONFIG[key], exist_ok=True)

# Show the configuration
print(f"\n{'='*70}")
print(f"{'EXPERIMENT CONFIGURATION':^70}")
print(f"{'='*70}")
print(f"Dataset: {DATASET_TYPE.upper()} ({NUM_CLASSES} classes)")
print(f"Version: {VERSION.upper()}")
print(f"Dataset Name: {DATASET_NAME}")
print(f"\nHyperparameters:")
print(f"  - Batch Size: {CONFIG['batch_size']}")
print(f"  - Learning Rate: {CONFIG['lr']:.2e}")
print(f"  - Weight Decay: {CONFIG['weight_decay']}")
print(f"  - Label Smoothing: {CONFIG['label_smoothing']}")
print(f"  - Class Weighted: {CONFIG['class_weighted']}")
print(f"  - Patience: {CONFIG['patience']}")
print(f"  - Max Epochs: {CONFIG['max_epochs']}")
print(f"\nPaths:")
print(f"  - Data: {CONFIG['data_root']}")
print(f"  - Checkpoints: {CONFIG['checkpoint_dir']}")
print(f"  - Logs: {CONFIG['logs_dir']}")
print(f"  - Results: {CONFIG['results_dir']}")
print(f"\nDevice: {CONFIG['device']}")
print(f"{'='*70}\n")

# Save the configuration
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
config_path = f"{CONFIG['results_dir']}/config_{timestamp}.json"
with open(config_path, 'w') as f:
    json.dump(CONFIG, f, indent=2)
print(f"✅ Configuration saved to: {config_path}")
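
To see at a glance where each experiment reads and writes, the path rules from the cell above can be restated as a small standalone function. This is only a sketch: `resolve_paths` is a hypothetical helper that mirrors the `DATASET_NAME` and `CONFIG.update(...)` logic, and it does not touch the file system.

```python
def resolve_paths(drive_root, dataset_type, version):
    """Rebuild the directory layout used by CONFIG for one experiment.

    Hypothetical helper mirroring the configuration cell; nothing is created on disk.
    """
    if dataset_type not in ("wlasl100", "wlasl300"):
        raise ValueError("dataset_type must be 'wlasl100' or 'wlasl300'")
    if version not in ("v1", "v2"):
        raise ValueError("version must be 'v1' or 'v2'")

    # V2 datasets live in a folder with a "_v2" suffix; V1 uses the plain name.
    dataset_name = f"{dataset_type}_v2" if version == "v2" else dataset_type
    return {
        "data_root": f"{drive_root}/data/{dataset_name}",
        "checkpoint_dir": f"{drive_root}/models/{version}/{dataset_name}/checkpoints",
        "logs_dir": f"{drive_root}/runs/{version}/{dataset_name}",
        "results_dir": f"{drive_root}/results/{version}/{dataset_name}",
    }


paths = resolve_paths("/content/drive/MyDrive/TESIS_WLASL", "wlasl300", "v2")
for key, value in paths.items():
    print(f"{key}: {value}")
```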

# 3️⃣ Dataset Preparation

## 3.1 Dataset Class

In [None]:
NUM_FRAMES = 16

# ============================================================
#   Load label maps
# ============================================================
def load_label_maps(meta_json: str, subset_json: str):
    """
    Load the video_id -> class mapping from the subset JSON.

    `meta_json` is accepted but not read; labels come from `subset_json` only.
    """
    with open(subset_json, "r", encoding="utf-8") as f:
        subset = json.load(f)

    vid2label = {}
    label_set = set()

    for vid, info in subset.items():
        label = info["action"][0]
        vid2label[vid] = label
        label_set.add(label)

    labels_sorted = sorted(label_set)
    label2id = {lab: lab for lab in labels_sorted}
    id2label = {lab: lab for lab in labels_sorted}

    return vid2label, label2id, id2label


def load_split_list(split_txt: str):
    """Load the list of files in a split."""
    with open(split_txt, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


def sample_frames_uniform(video_path: str, num_frames: int = NUM_FRAMES):
    """Extract uniformly spaced frames from the video."""
    cap = cv2.VideoCapture(video_path)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    if frame_count <= 0:
        cap.release()
        raise RuntimeError(f"Empty or corrupt video: {video_path}")

    indices = np.linspace(0, frame_count - 1, num_frames).astype(int)

    frames = []
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ret, frame = cap.read()
        if not ret or frame is None:
            continue
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(frame)

    cap.release()

    if len(frames) == 0:
        raise RuntimeError(f"Could not read any frames from {video_path}")

    # If some reads failed, pad by repeating the last frame
    while len(frames) < num_frames:
        frames.append(frames[-1])

    return frames[:num_frames]


# ============================================================
#   Main Dataset class
# ============================================================
class WLASLVideoDataset(Dataset):
    def __init__(
        self,
        split: str,
        base_path: str,
        videos_folder: str = "dataset",
        meta_json: str = "WLASL_v0.3.json",
        subset_json: str = "nslt_100.json",
        dataset_size: int = 100,
    ):
        assert split in ["train", "val", "test"]
        self.split = split
        self.dataset_size = dataset_size

        # Auto-detect dataset_size from the path
        if dataset_size == 100 and "300" in base_path:
            self.dataset_size = 300
        elif dataset_size == 300 and "100" in base_path:
            self.dataset_size = 100

        # Adjust the JSON file names
        if self.dataset_size == 300:
            if meta_json == "WLASL_v0.3.json":
                meta_json = "WLASL_v0.3_300.json"
            if subset_json == "nslt_100.json":
                subset_json = "nslt_300.json"

        # Paths
        self.base = base_path
        self.videos_dir = os.path.join(base_path, videos_folder, split)
        self.splits_dir = os.path.join(base_path, "splits")
        self.meta_json = os.path.join(base_path, meta_json)
        self.subset_json = os.path.join(base_path, subset_json)

        # Load label maps
        self.vid2label, self.label2id, self.id2label = load_label_maps(
            self.meta_json, self.subset_json
        )

        # Load the list of corrupt videos (optional)
        corrupt_list_path = os.path.join(self.base, f"corrupt_videos_{split}.txt")
        self.corrupt_set = set()
        if os.path.exists(corrupt_list_path):
            with open(corrupt_list_path, "r", encoding="utf-8") as f:
                self.corrupt_set = {line.strip() for line in f if line.strip()}

        # Load the list of videos in this split
        split_txt_path = os.path.join(self.splits_dir, f"{split}_split.txt")
        file_list = load_split_list(split_txt_path)

        # Build the sample list
        self.samples = []
        for raw_fname in file_list:
            norm = raw_fname.replace("\\", "/")
            basename = os.path.basename(norm)

            if not basename.endswith(".mp4"):
                continue

            if basename in self.corrupt_set:
                continue

            vid = os.path.splitext(basename)[0]
            video_path = os.path.join(self.videos_dir, basename)

            if os.path.exists(video_path) and vid in self.vid2label:
                label = self.vid2label[vid]
                self.samples.append((video_path, label))

        if len(self.samples) == 0:
            raise RuntimeError(f"No samples found for split={split}")

        # Transforms
        if split == "train":
            self.transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.2, contrast=0.2),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                ),
            ])
        else:
            self.transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                ),
            ])

    def __len__(self):
        return len(self.samples)

    def get_labels(self):
        return [label for _, label in self.samples]

    def __getitem__(self, idx):
        for attempt in range(5):
            video_path, label = self.samples[idx]

            try:
                frames = sample_frames_uniform(video_path, NUM_FRAMES)
            except Exception as e:
                print(f"[WARN] Corrupt video: {video_path} ({e})")
                idx = (idx + 1) % len(self.samples)
                continue

            frames_t = [self.transform(f) for f in frames]
            video_tensor = torch.stack(frames_t, dim=0)  # (T, C, H, W)

            return video_tensor, torch.tensor(label, dtype=torch.long)

        raise RuntimeError("Too many consecutive corrupt videos")


print("✅ Dataset class defined successfully")
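
To make the uniform sampling in `sample_frames_uniform` concrete, the sketch below computes the frame indices for a long and a short clip. It is a pure-Python stand-in for `np.linspace(0, frame_count - 1, num_frames).astype(int)` that uses integer arithmetic, so it could differ by one index from the NumPy call in rare floating-point rounding cases. The function name `uniform_frame_indices` is hypothetical.

```python
def uniform_frame_indices(frame_count, num_frames=16):
    """Evenly spaced frame indices from 0 to frame_count - 1 (inclusive).

    Integer-arithmetic stand-in for the truncated np.linspace call
    used in sample_frames_uniform.
    """
    if frame_count <= 0:
        raise ValueError("frame_count must be positive")
    if num_frames == 1:
        return [0]
    last = frame_count - 1
    return [i * last // (num_frames - 1) for i in range(num_frames)]


long_clip = uniform_frame_indices(61)   # more frames than needed
short_clip = uniform_frame_indices(10)  # fewer frames than needed
print(long_clip)
print(short_clip)
```

For a clip shorter than `NUM_FRAMES`, some indices repeat, so the same frame is read more than once. The padding loop in `sample_frames_uniform` only comes into play when individual reads fail.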

## 3.2 Download/Verify Dataset

**⚠️ IMPORTANT:**
- If the dataset is already in your Drive, adjust the path in `CONFIG['data_root']`
- If it is not, upload the files to Drive manually

In [None]:
# Check that the dataset exists
data_path = CONFIG['data_root']

required_files = [
    f"{data_path}/splits/train_split.txt",
    f"{data_path}/splits/val_split.txt",
    f"{data_path}/splits/test_split.txt",
]

if NUM_CLASSES == 100:
    required_files.extend([
        f"{data_path}/nslt_100.json",
        f"{data_path}/WLASL_v0.3.json",
    ])
else:
    required_files.extend([
        f"{data_path}/nslt_300.json",
        f"{data_path}/WLASL_v0.3_300.json",
        f"{data_path}/gloss_to_id.json",
    ])

missing_files = [f for f in required_files if not os.path.exists(f)]

if missing_files:
    print(f"❌ Missing files:")
    for f in missing_files:
        print(f"   - {f}")
    print(f"\n⚠️ Please upload the dataset to: {data_path}")
    raise FileNotFoundError("Dataset not found")
else:
    print(f"✅ Dataset verified at: {data_path}")

    # Show statistics
    for split_name in ["train", "val", "test"]:
        split_file = f"{data_path}/splits/{split_name}_split.txt"
        with open(split_file) as f:
            count = len([l for l in f if l.strip()])
        print(f"  - {split_name.capitalize()}: {count} videos")
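
The counts printed above are raw line counts from the split files. `WLASLVideoDataset` filters those lines further, so the dataset sizes reported in the next section can be smaller. The sketch below reproduces only the string-level part of that filtering (path normalization, the `.mp4` check, and the corrupt list); the real class additionally requires the file to exist on disk and to have a label. The helper name and the sample file names are hypothetical.

```python
import os


def usable_basenames(split_lines, corrupt=frozenset()):
    """Apply the filename filtering used when building the sample list.

    Hypothetical helper: existence on disk and label lookup are not checked here.
    """
    kept = []
    for raw in split_lines:
        raw = raw.strip()
        if not raw:
            continue
        # Normalize Windows-style separators before taking the base name.
        basename = os.path.basename(raw.replace("\\", "/"))
        if not basename.endswith(".mp4") or basename in corrupt:
            continue
        kept.append(basename)
    return kept


lines = ["train\\00623.mp4", "train/00624.mp4", "", "notes.txt", "00625.mp4"]
print(usable_basenames(lines, corrupt={"00625.mp4"}))
```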

## 3.3 Create DataLoaders

In [None]:
print("[INFO] Creating datasets and dataloaders...")

# Create datasets
train_dataset = WLASLVideoDataset(
    split="train",
    base_path=CONFIG['data_root'],
    dataset_size=NUM_CLASSES
)

val_dataset = WLASLVideoDataset(
    split="val",
    base_path=CONFIG['data_root'],
    dataset_size=NUM_CLASSES
)

test_dataset = WLASLVideoDataset(
    split="test",
    base_path=CONFIG['data_root'],
    dataset_size=NUM_CLASSES
)

# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=CONFIG['num_workers'],
    pin_memory=True if CONFIG['device'] == "cuda" else False
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=CONFIG['num_workers'],
    pin_memory=True if CONFIG['device'] == "cuda" else False
)

test_loader = DataLoader(
    test_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=CONFIG['num_workers'],
    pin_memory=True if CONFIG['device'] == "cuda" else False
)

print(f"\n{'='*60}")
print(f"DATASETS LOADED")
print(f"{'='*60}")
print(f"Train:      {len(train_dataset):,} videos ({len(train_loader)} batches)")
print(f"Validation: {len(val_dataset):,} videos ({len(val_loader)} batches)")
print(f"Test:       {len(test_dataset):,} videos ({len(test_loader)} batches)")
print(f"{'='*60}\n")

print("✅ DataLoaders created successfully")
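
Each batch produced by these loaders has shape `(B, T, C, H, W)`, with `T = NUM_FRAMES = 16` and 224x224 RGB frames. As a rough sanity check of the two batch sizes, the sketch below computes the size of one float32 input batch. It covers the input tensor only; model activations and gradients are not included and typically take far more GPU memory. `batch_megabytes` is a hypothetical helper.

```python
import math


def batch_megabytes(batch_size, num_frames=16, channels=3, height=224, width=224,
                    bytes_per_value=4):
    """Size in MB (10**6 bytes) of one float32 input batch of shape (B, T, C, H, W)."""
    shape = (batch_size, num_frames, channels, height, width)
    return math.prod(shape) * bytes_per_value / 1e6


print(f"V1 (batch 16): {batch_megabytes(16):.1f} MB")
print(f"V2 (batch 6):  {batch_megabytes(6):.1f} MB")
```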

# 4️⃣ Training

## 4.1 Training Functions

In [None]:
# ============================================================
#   Helper functions
# ============================================================
def compute_class_weights(labels: list, num_classes: int, device: str):
    """Compute per-class weights inversely proportional to class frequency."""
    class_counts = Counter(labels)
    weights = torch.zeros(num_classes, dtype=torch.float32)

    for class_id in range(num_classes):
        count = class_counts.get(class_id, 0)
        if count > 0:
            weights[class_id] = 1.0 / count
        else:
            weights[class_id] = 0.0

    if weights.sum() > 0:
        weights = weights / weights.mean()

    return weights.to(device)


def calculate_accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> float:
    """Compute accuracy (in %) from logits and labels."""
    predictions = torch.argmax(outputs, dim=1)
    correct = (predictions == labels).sum().item()
    accuracy = 100.0 * correct / labels.size(0)
    return accuracy


# ============================================================
#   Warmup + Cosine Scheduler
# ============================================================
class WarmupCosineScheduler:
    """Linear warmup followed by cosine decay; call step() once per batch."""

    def __init__(self, optimizer, warmup_steps, total_steps, min_lr=1e-6, last_epoch=-1):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr
        self.base_lrs = [group['lr'] for group in optimizer.param_groups]
        self.current_step = last_epoch + 1
        # Apply the schedule immediately; otherwise the first batch would run at the full base LR
        self._apply_lr()

    def _apply_lr(self):
        """Set each param group's LR for the current step."""
        for param_group, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
            if self.current_step < self.warmup_steps:
                lr = base_lr * (self.current_step / self.warmup_steps)
            else:
                # Guard against a zero-length decay phase and against stepping past total_steps
                decay_steps = max(1, self.total_steps - self.warmup_steps)
                progress = min(1.0, (self.current_step - self.warmup_steps) / decay_steps)
                lr = self.min_lr + (base_lr - self.min_lr) * 0.5 * (1 + np.cos(np.pi * progress))

            param_group['lr'] = float(lr)

    def step(self):
        self.current_step += 1
        self._apply_lr()

    def get_last_lr(self):
        return [group['lr'] for group in self.optimizer.param_groups]

    def state_dict(self):
        return {
            'current_step': self.current_step,
            'base_lrs': self.base_lrs,
        }

    def load_state_dict(self, state_dict):
        self.current_step = state_dict['current_step']
        self.base_lrs = state_dict['base_lrs']
        self._apply_lr()


# ============================================================
#   Train for one epoch
# ============================================================
def train_one_epoch(
    model, dataloader, criterion, optimizer, scheduler, device, epoch, writer, gradient_clip=1.0
):
    model.train()
    running_loss = 0.0
    running_acc = 0.0
    total_batches = len(dataloader)

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch:02d} [TRAIN]", leave=False)

    for batch_idx, (videos, labels) in enumerate(progress_bar):
        videos = videos.to(device)
        labels = labels.to(device)

        outputs = model(pixel_values=videos)
        logits = outputs.logits

        loss = criterion(logits, labels)
        batch_acc = calculate_accuracy(logits, labels)

        optimizer.zero_grad()
        loss.backward()

        if gradient_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)

        optimizer.step()
        scheduler.step()

        running_loss += loss.item()
        running_acc += batch_acc

        current_lr = scheduler.get_last_lr()[0]
        progress_bar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'acc': f"{batch_acc:.1f}%",
            'lr': f"{current_lr:.2e}"
        })

        global_step = (epoch - 1) * total_batches + batch_idx
        writer.add_scalar('Train/Loss_batch', loss.item(), global_step)
        writer.add_scalar('Train/Accuracy_batch', batch_acc, global_step)
        writer.add_scalar('Train/Learning_rate', current_lr, global_step)

    avg_loss = running_loss / total_batches
    avg_acc = running_acc / total_batches

    return avg_loss, avg_acc


# ============================================================
#   Evaluation
# ============================================================
@torch.no_grad()
def evaluate(model, dataloader, criterion, device, epoch, split="VAL"):
    model.eval()
    running_loss = 0.0
    running_acc = 0.0
    total_samples = 0

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch:02d} [{split:^5}]", leave=False)

    for videos, labels in progress_bar:
        videos = videos.to(device)
        labels = labels.to(device)

        outputs = model(pixel_values=videos)
        logits = outputs.logits

        loss = criterion(logits, labels)
        batch_acc = calculate_accuracy(logits, labels)
        batch_size = labels.size(0)

        # Weight by batch size so a small final batch does not skew the averages
        running_loss += loss.item() * batch_size
        running_acc += batch_acc * batch_size
        total_samples += batch_size

        progress_bar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'acc': f"{batch_acc:.1f}%"
        })

    avg_loss = running_loss / total_samples
    avg_acc = running_acc / total_samples

    return avg_loss, avg_acc


# ============================================================
#   Save checkpoint
# ============================================================
def save_checkpoint(
    epoch, model, optimizer, scheduler, train_loss, train_acc,
    val_loss, val_acc, checkpoint_dir, is_best=False
):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_loss': train_loss,
        'train_acc': train_acc,
        'val_loss': val_loss,
        'val_acc': val_acc,
    }

    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
    torch.save(checkpoint, checkpoint_path)
    print(f"[CHECKPOINT] Saved: {checkpoint_path}")

    if is_best:
        best_path = os.path.join(checkpoint_dir, "best_model.pt")
        torch.save(checkpoint, best_path)
        print(f"[BEST MODEL] Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")


print("✅ Training functions defined")
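
The schedule implemented by `WarmupCosineScheduler` can be written as a plain function of the step number, which makes it easy to inspect without a model or optimizer. The sketch below uses the same two formulas (linear warmup, then cosine decay towards `min_lr`). The step counts are illustrative only: they assume 51 batches per epoch (807 videos at batch size 16) for 30 epochs, and the real numbers depend on how many videos survive filtering. `warmup_cosine_lr` is a hypothetical helper.

```python
import math


def warmup_cosine_lr(step, base_lr, warmup_steps, total_steps, min_lr=1e-6):
    """Learning rate at a given step: linear warmup, then cosine decay to min_lr."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return min_lr + (base_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))


# Illustrative numbers only (V1, WLASL100): 51 batches/epoch for 30 epochs.
total_steps = 51 * 30
warmup_steps = int(total_steps * 0.1)
for step in (1, warmup_steps, total_steps // 2, total_steps):
    lr = warmup_cosine_lr(step, 1e-4, warmup_steps, total_steps)
    print(f"step {step:>4}: lr = {lr:.2e}")
```

The learning rate peaks at the base value exactly when warmup ends and reaches `min_lr` at the final step.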

## 4.2 Initialize Model and Optimizer

In [None]:
print("[INFO] Initializing model and training components...\n")

# Load the model
device = CONFIG['device']
model = VideoMAEForVideoClassification.from_pretrained(
    CONFIG['model_name'],
    num_labels=CONFIG['num_classes'],
    ignore_mismatched_sizes=True
)
model.to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model: {CONFIG['model_name']}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}\n")

# Class weights (if enabled)
class_weights = None
if CONFIG['class_weighted']:
    print("[INFO] Computing class weights...")
    train_labels = train_dataset.get_labels()
    class_weights = compute_class_weights(train_labels, CONFIG['num_classes'], device)
    print(f"Class weights (min={class_weights.min():.3f}, max={class_weights.max():.3f})\n")

# Loss function
criterion = nn.CrossEntropyLoss(
    weight=class_weights,
    label_smoothing=CONFIG['label_smoothing']
)

# Optimizer
optimizer = AdamW(
    model.parameters(),
    lr=CONFIG['lr'],
    weight_decay=CONFIG['weight_decay'],
    betas=(0.9, 0.999)
)

# Scheduler
total_steps = len(train_loader) * CONFIG['max_epochs']
warmup_steps = int(total_steps * CONFIG['warmup_ratio'])

scheduler = WarmupCosineScheduler(
    optimizer,
    warmup_steps=warmup_steps,
    total_steps=total_steps,
    min_lr=CONFIG['min_lr']
)

print(f"Optimizer: AdamW (lr={CONFIG['lr']:.2e}, wd={CONFIG['weight_decay']})")
print(f"Scheduler: Warmup + Cosine Decay")
print(f"Total steps: {total_steps}")
print(f"Warmup steps: {warmup_steps}\n")

# TensorBoard writer
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_dir = f"{CONFIG['logs_dir']}/run_{timestamp}"
writer = SummaryWriter(log_dir=log_dir)

print(f"✅ TensorBoard logs: {log_dir}")
print(f"✅ Model and components initialized")
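
When `class_weighted` is enabled, the loss uses the weights from `compute_class_weights`: each class gets `1 / count`, classes with no training samples get 0, and the result is rescaled so the mean weight is 1. A pure-Python sketch of the same arithmetic on toy labels (the function name `inverse_frequency_weights` and the labels are made up for illustration):

```python
from collections import Counter


def inverse_frequency_weights(labels, num_classes):
    """Per-class weights proportional to 1/count, rescaled so the mean is 1."""
    counts = Counter(labels)
    weights = [1.0 / counts[c] if counts[c] > 0 else 0.0 for c in range(num_classes)]
    mean = sum(weights) / num_classes
    return [w / mean for w in weights] if mean > 0 else weights


# Toy labels: class 0 appears 4 times, class 1 twice, class 2 once, class 3 never.
toy_labels = [0, 0, 0, 0, 1, 1, 2]
toy_weights = inverse_frequency_weights(toy_labels, num_classes=4)
print([round(w, 3) for w in toy_weights])
```

Note that the mean is taken over all classes, including those with zero weight, so the weights of the classes that do appear average slightly above 1.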

## 4.3 Main Training Loop

In [None]:
print(f"\n{'='*70}")
print(f"{'TRAINING START':^70}")
print(f"{'='*70}\n")

# Control variables
best_val_loss = float('inf')
best_val_acc = 0.0
epochs_without_improve = 0
training_history = []

# Checkpoint directory for this run
run_checkpoint_dir = f"{CONFIG['checkpoint_dir']}/run_{timestamp}"
os.makedirs(run_checkpoint_dir, exist_ok=True)

# Save the run configuration
run_config_path = f"{run_checkpoint_dir}/config.json"
with open(run_config_path, 'w') as f:
    json.dump(CONFIG, f, indent=2)

try:
    for epoch in range(1, CONFIG['max_epochs'] + 1):
        print(f"\n{'='*70}")
        print(f"EPOCH {epoch}/{CONFIG['max_epochs']}")
        print(f"{'='*70}")

        # Training
        train_loss, train_acc = train_one_epoch(
            model=model,
            dataloader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            epoch=epoch,
            writer=writer,
            gradient_clip=CONFIG['gradient_clip']
        )

        # Validation
        val_loss, val_acc = evaluate(
            model=model,
            dataloader=val_loader,
            criterion=criterion,
            device=device,
            epoch=epoch,
            split="VAL"
        )

        # Logging
        current_lr = scheduler.get_last_lr()[0]
        print(f"\n{'='*70}")
        print(f"EPOCH {epoch} RESULTS")
        print(f"{'='*70}")
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%")
        print(f"Current LR: {current_lr:.2e}")
        print(f"{'='*70}\n")

        # Per-epoch TensorBoard logging
        writer.add_scalar('Train/Loss_epoch', train_loss, epoch)
        writer.add_scalar('Train/Accuracy_epoch', train_acc, epoch)
        writer.add_scalar('Val/Loss_epoch', val_loss, epoch)
        writer.add_scalar('Val/Accuracy_epoch', val_acc, epoch)

        # Record history
        training_history.append({
            'epoch': epoch,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
            'lr': current_lr
        })

        # Early stopping and checkpoints
        is_best = val_loss < best_val_loss

        if is_best:
            best_val_loss = val_loss
            best_val_acc = val_acc
            epochs_without_improve = 0
        else:
            epochs_without_improve += 1

        # Save checkpoint
        if epoch % CONFIG['save_every'] == 0 or is_best or epoch == CONFIG['max_epochs']:
            save_checkpoint(
                epoch=epoch,
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                train_loss=train_loss,
                train_acc=train_acc,
                val_loss=val_loss,
                val_acc=val_acc,
                checkpoint_dir=run_checkpoint_dir,
                is_best=is_best
            )

        # Early stopping
        if epochs_without_improve >= CONFIG['patience']:
            print(f"\n[EARLY STOP] No improvement for {CONFIG['patience']} epochs")
            print(f"[EARLY STOP] Stopping at epoch {epoch}")
            break

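except KeyboardInterrupt:
    print("\n[INTERRUPTED] Training interrupted by the user")
    # Save the current state. Metrics come from the last completed epoch,
    # or NaN if the interrupt arrived before any epoch finished.
    last = training_history[-1] if training_history else {}
    save_checkpoint(
        epoch=epoch,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        train_loss=last.get('train_loss', float('nan')),
        train_acc=last.get('train_acc', float('nan')),
        val_loss=last.get('val_loss', float('nan')),
        val_acc=last.get('val_acc', float('nan')),
        checkpoint_dir=run_checkpoint_dir,
        is_best=False
    )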

# Close the writer
writer.close()

# Save the training history
history_df = pd.DataFrame(training_history)
history_path = f"{run_checkpoint_dir}/training_history.csv"
history_df.to_csv(history_path, index=False)

print(f"\n{'='*70}")
print(f"{'TRAINING COMPLETE':^70}")
print(f"{'='*70}")
print(f"Best Val Loss: {best_val_loss:.4f}")
print(f"Best Val Accuracy: {best_val_acc:.2f}%")
print(f"Checkpoints: {run_checkpoint_dir}")
print(f"Logs: {log_dir}")
print(f"History: {history_path}")
print(f"{'='*70}\n")
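
The early-stopping rule in the loop above tracks validation loss only: an epoch counts as an improvement when its loss is strictly lower than the best so far, and training stops once `patience` consecutive epochs pass without one. The sketch below replays that rule on a made-up sequence of losses. `simulate_early_stopping` is a hypothetical helper and the loss values are not real results.

```python
def simulate_early_stopping(val_losses, patience):
    """Return (best_epoch, stop_epoch) for per-epoch validation losses.

    Epochs are 1-indexed; stop_epoch is the last epoch that runs.
    """
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improve = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            epochs_without_improve = 0
        else:
            epochs_without_improve += 1
        if epochs_without_improve >= patience:
            return best_epoch, epoch
    return best_epoch, len(val_losses)


losses = [1.00, 0.90, 0.95, 0.92, 0.91, 0.93, 0.94, 0.96]  # made-up values
print(simulate_early_stopping(losses, patience=5))   # V1 patience
print(simulate_early_stopping(losses, patience=10))  # V2 patience
```

With the V1 patience of 5, the run stops five epochs after the best epoch; with the V2 patience of 10, the same sequence runs to the end.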

## 4.4 Plot Training Curves

In [None]:
# Read the history
history_df = pd.read_csv(history_path)

# Create a figure with subplots
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Loss
axes[0].plot(history_df['epoch'], history_df['train_loss'], label='Train Loss', marker='o')
axes[0].plot(history_df['epoch'], history_df['val_loss'], label='Val Loss', marker='s')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training & Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy
axes[1].plot(history_df['epoch'], history_df['train_acc'], label='Train Acc', marker='o')
axes[1].plot(history_df['epoch'], history_df['val_acc'], label='Val Acc', marker='s')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Training & Validation Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Learning Rate
axes[2].plot(history_df['epoch'], history_df['lr'], label='Learning Rate', marker='o', color='green')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Learning Rate')
axes[2].set_title('Learning Rate Schedule')
axes[2].set_yscale('log')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()

# Save
curves_path = f"{CONFIG['results_dir']}/training_curves_{timestamp}.png"
plt.savefig(curves_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"✅ Curves saved to: {curves_path}")