This notebook is written for Google Colab. To run it locally, change the directory paths (and skip the Google Drive mount) and remove the `torch.device` handling.

### Imports

In [None]:
import os
import cv2
import json
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.io as io

### Settings

In [None]:
# Reproducibility: seed every RNG the notebook uses (Python, NumPy, PyTorch CPU/GPU).
SEED = 42
torch.manual_seed(SEED)
# Seeds every visible GPU (manual_seed alone only targets the current one);
# it is silently ignored when CUDA is unavailable.
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)
# Deterministic cuDNN kernels. Benchmark mode must be off as well, otherwise
# cuDNN may auto-tune to a different algorithm from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Device selection happens inside main() (currently forced to CPU). On a GPU
# runtime it can be chosen here instead:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Usando: {device}")

In [None]:
# Google Drive must be mounted so the project files appear under /content/drive.
from google.colab import drive

drive.mount('/content/drive')

# Project locations — ADJUST THESE to your own folder layout.
METADATA_PATH = "/content/drive/MyDrive/slowfast/metadata.json"
BASE_PATH = "/content/drive/MyDrive/slowfast/videos"
VIDEOS_PATH = BASE_PATH  # point elsewhere if the videos live in another folder

### Data Preparation

In [None]:
# Cargar metadatos
def load_metadata(metadata_path):
    try:
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
        
        # Convertir a DataFrame para facilitar el procesamiento
        df = pd.DataFrame(metadata)
        
        # Si no hay una columna "output_path", crear una basada en output_filename
        if 'output_path' not in df.columns and 'output_filename' in df.columns:
            # Extraer directorio del primer elemento para ver el patrón
            if len(df) > 0 and 'stroke_type' in df.columns and 'shot_variant' in df.columns:
                # Crear rutas basadas en la estructura de directorios mencionada
                df['output_path'] = df.apply(
                    lambda row: os.path.join('videos', row['stroke_type'], row['shot_variant'], row['output_filename']), 
                    axis=1
                )
            else:
                # Si no hay suficiente información, usar solo el nombre del archivo
                df['output_path'] = df['output_filename'].apply(lambda x: os.path.join('videos', x))
        
        # Añadir columna para la ruta completa
        df['full_path'] = df['output_path']
        
        return df
    except Exception as e:
        print(f"Error al cargar metadatos: {e}")
        # Crear un DataFrame vacío con las columnas necesarias como fallback
        return pd.DataFrame(columns=['input_video', 'stroke_type', 'shot_variant', 'hand_style', 
                                    'output_filename', 'output_path', 'full_path'])

# Crear mapeos para las etiquetas (convertir de texto a índices)
def create_label_mappings(df):
    # Mapeo para stroke_type (tarea 1)
    stroke_types = sorted(df['stroke_type'].unique())
    stroke_type_to_idx = {stroke: idx for idx, stroke in enumerate(stroke_types)}
    
    # Mapeo para shot_variant (tarea 2)
    shot_variants = sorted(df['shot_variant'].unique())
    shot_variant_to_idx = {variant: idx for idx, variant in enumerate(shot_variants)}
    
    # Mapeo para hand_style (tarea 3)
    hand_styles = sorted(df['hand_style'].unique().tolist())
    
    # Asegurarse de que no_hand_style esté al final
    if 'no_hand_style' in hand_styles:
        hand_styles.remove('no_hand_style')
        hand_styles.append('no_hand_style')  # Mover al final
    
    # Si sólo hay 1 clase (todos 'no_hand_style' por ejemplo)
    # Crear una clase ficticia para evitar errores
    if len(hand_styles) < 2:
        print(f"¡Advertencia! Solo hay {len(hand_styles)} estilo(s) de mano: {hand_styles}")
        print("Agregando una clase ficticia para evitar errores.")
        if 'one' not in hand_styles:
            hand_styles.insert(0, 'one')
        elif 'two' not in hand_styles:
            hand_styles.insert(0, 'two')
    
    hand_style_to_idx = {style: idx for idx, style in enumerate(hand_styles)}
    
    return {
        'stroke_type': stroke_type_to_idx,
        'shot_variant': shot_variant_to_idx,
        'hand_style': hand_style_to_idx
    }

### Dataset

In [None]:
class TennisVideoDataset(Dataset):
    """Tennis-stroke clips for a two-pathway (SlowFast) model with three targets.

    Each item is a dict with:
      * 'slow_pathway': float tensor (3, slow_pathway_size, H, W)
      * 'fast_pathway': float tensor (3, fast_pathway_size, H, W)
      * 'stroke_type', 'shot_variant', 'hand_style': int class indices

    Frames are sampled uniformly over the whole video with OpenCV, converted
    BGR -> RGB, resized to 112x112 and scaled to [0, 1]; H and W are therefore
    112 unless ``transform`` changes them. In demo mode, when the file does
    not exist, or when loading fails, random 112x112 tensors are returned
    instead (the labels are still looked up from the row).

    Args:
        dataframe: rows with 'full_path', 'stroke_type', 'shot_variant' and
            'hand_style'; 'full_path' is joined onto the module-level BASE_PATH.
        label_mappings: {'stroke_type': {...}, 'shot_variant': {...},
            'hand_style': {...}} mapping label text -> class index.
        transform: optional callable applied to every (3, H, W) frame; it may
            change the spatial size (e.g. a Resize).
        clip_len, skip_rate: stored for API compatibility, not used here.
        slow_pathway_size, fast_pathway_size: frames sampled per pathway.
        demo_mode: if True, never read videos and always return random clips.
    """

    def __init__(self, dataframe, label_mappings, transform=None, clip_len=32,
                 skip_rate=2, slow_pathway_size=8, fast_pathway_size=32, demo_mode=False):
        self.dataframe = dataframe
        self.label_mappings = label_mappings
        self.transform = transform
        self.clip_len = clip_len
        self.skip_rate = skip_rate
        self.slow_pathway_size = slow_pathway_size
        self.fast_pathway_size = fast_pathway_size
        self.demo_mode = demo_mode

    def __len__(self):
        return len(self.dataframe)

    def _lookup_label(self, task, value, error_prefix):
        """Return the class index of *value* for *task*; print and use 0 if unknown."""
        try:
            return self.label_mappings[task][value]
        except KeyError:
            print(f"{error_prefix}: {value}")
            return 0  # fall back to the first class

    @staticmethod
    def _apply_transform(clip, transform):
        """Apply *transform* to each (C, H, W) frame of a (C, T, H, W) clip.

        Frames are transformed independently and re-stacked on the time axis,
        so transforms that change the spatial size work (writing them back
        into the original fixed-size tensor would raise a shape error).
        """
        if not transform or clip.shape[1] == 0:
            return clip
        return torch.stack([transform(clip[:, t]) for t in range(clip.shape[1])], dim=1)

    @staticmethod
    def _read_frames(cap, indices, video_path):
        """Seek to each frame index and return a (3, T, 112, 112) float tensor in [0, 1]."""
        frames = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ok, frame = cap.read()
            if not ok:
                raise RuntimeError(f"Error leyendo frame {idx} del video {video_path}")
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
            frames.append(cv2.resize(frame, (112, 112)))
        # (T, H, W, C) uint8 -> (C, T, H, W) float
        return torch.from_numpy(np.stack(frames)).permute(3, 0, 1, 2).float() / 255.0

    def _load_clip(self, video_path):
        """Read both pathways from *video_path*; returns (slow, fast) tensors.

        The capture is always released, including when reading fails.
        """
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                raise RuntimeError(f"No se pudo abrir el video: {video_path}")

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames <= 0:
                raise RuntimeError(f"El video no tiene frames o frames inválidos: {video_path}")

            # Uniformly spaced frame indices covering the whole video.
            fast_indices = np.linspace(0, total_frames - 1, self.fast_pathway_size, dtype=int)
            slow_indices = np.linspace(0, total_frames - 1, self.slow_pathway_size, dtype=int)

            fast_pathway = self._read_frames(cap, fast_indices, video_path)
            slow_pathway = self._read_frames(cap, slow_indices, video_path)
        finally:
            cap.release()
        return slow_pathway, fast_pathway

    @staticmethod
    def _make_item(slow_pathway, fast_pathway, labels):
        """Assemble the sample dict in the key order the DataLoader collates."""
        return {'slow_pathway': slow_pathway, 'fast_pathway': fast_pathway, **labels}

    def __getitem__(self, index):
        video_info = self.dataframe.iloc[index]
        video_path = os.path.join(BASE_PATH, video_info['full_path'])

        labels = {
            'stroke_type': self._lookup_label(
                'stroke_type', video_info['stroke_type'],
                "¡Error! Tipo de golpe desconocido"),
            'shot_variant': self._lookup_label(
                'shot_variant', video_info['shot_variant'],
                "¡Error! Variante de golpe desconocida"),
            'hand_style': self._lookup_label(
                'hand_style', video_info['hand_style'],
                "¡Error! Estilo de mano desconocido"),
        }

        if self.demo_mode or not os.path.exists(video_path):
            # Synthetic clip; the fast tensor is drawn first.
            fast_pathway = torch.rand(3, self.fast_pathway_size, 112, 112)
            slow_pathway = torch.rand(3, self.slow_pathway_size, 112, 112)
            return self._make_item(slow_pathway, fast_pathway, labels)

        try:
            slow_pathway, fast_pathway = self._load_clip(video_path)
            fast_pathway = self._apply_transform(fast_pathway, self.transform)
            slow_pathway = self._apply_transform(slow_pathway, self.transform)
            return self._make_item(slow_pathway, fast_pathway, labels)
        except Exception as e:
            # Best-effort fallback: keep the batch shape valid with a random clip.
            print(f"Error cargando video {video_path}: {e}")
            return self._make_item(
                torch.rand(3, self.slow_pathway_size, 112, 112),
                torch.rand(3, self.fast_pathway_size, 112, 112),
                labels,
            )

### Model

In [None]:
class ResBlock(nn.Module):
    """Two-convolution 3-D residual block: conv-BN-ReLU-conv-BN, add shortcut, ReLU.

    The stride is applied only to the spatial axes (``(1, stride, stride)``),
    so the temporal length is unchanged. When the stride or the channel
    count changes, the shortcut is a strided 1x1x1 conv followed by BN;
    otherwise it is an empty Sequential (identity).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        spatial_stride = (1, stride, stride)

        # Registration order (conv1, bn1, relu, conv2, bn2, shortcut) fixes both
        # the parameter-initialisation RNG order and the state_dict key order.
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3,
                               stride=spatial_stride, padding=1)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1)
        self.bn2 = nn.BatchNorm3d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        self.shortcut = (
            nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=spatial_stride),
                nn.BatchNorm3d(out_channels),
            )
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return self.relu(y + self.shortcut(x))

class SlowFastNetwork(nn.Module):
    """Two-pathway 3-D CNN with three classification heads.

    Both pathways share one layout: a 7x7 spatial stem (stride 2), BN, ReLU,
    3x3 spatial max-pool (stride 2), four ResBlocks that double the width
    three times, and global average pooling. The slow pathway uses a
    temporal kernel of 1 in its stem and base width 64 (512 output features).
    The fast pathway uses a temporal kernel of 5 and base width 32 (256
    features). The pathways only interact after pooling: their features are
    concatenated, passed through Linear(768 -> 512) + ReLU + Dropout(0.5),
    and fed to the stroke, variant and hand linear heads.

    Inputs are Conv3d tensors of shape (batch, channels, frames, height, width).
    """

    def __init__(self, slow_channels=3, fast_channels=3,
                 num_stroke_classes=4, num_variant_classes=7, num_hand_classes=3):
        super().__init__()

        # Construction order (slow, fast, fusion, heads) is kept so parameter
        # initialisation consumes the RNG in the same sequence.
        self.slow_pathway = self._build_pathway(slow_channels, width=64, temporal_kernel=1)
        self.fast_pathway = self._build_pathway(fast_channels, width=32, temporal_kernel=5)

        fused_dim = 64 * 8 + 32 * 8  # final widths of the slow and fast pathways
        self.fusion = nn.Sequential(
            nn.Linear(fused_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
        )

        self.stroke_classifier = nn.Linear(512, num_stroke_classes)
        self.variant_classifier = nn.Linear(512, num_variant_classes)
        self.hand_classifier = nn.Linear(512, num_hand_classes)

    @staticmethod
    def _build_pathway(in_channels, width, temporal_kernel):
        """Stem + four ResBlocks (width -> 8*width) + global average pool."""
        temporal_pad = temporal_kernel // 2
        return nn.Sequential(
            nn.Conv3d(in_channels, width, kernel_size=(temporal_kernel, 7, 7),
                      stride=(1, 2, 2), padding=(temporal_pad, 3, 3)),
            nn.BatchNorm3d(width),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
            ResBlock(width, width),
            ResBlock(width, width * 2, stride=2),
            ResBlock(width * 2, width * 4, stride=2),
            ResBlock(width * 4, width * 8, stride=2),
            nn.AdaptiveAvgPool3d((1, 1, 1)),
        )

    def forward(self, slow_input, fast_input):
        pooled_slow = self.slow_pathway(slow_input)
        pooled_fast = self.fast_pathway(fast_input)

        # (batch, C, 1, 1, 1) -> (batch, C)
        flat = [feat.view(feat.size(0), -1) for feat in (pooled_slow, pooled_fast)]
        shared = self.fusion(torch.cat(flat, dim=1))

        return (
            self.stroke_classifier(shared),
            self.variant_classifier(shared),
            self.hand_classifier(shared),
        )

### Training & Validation

In [None]:
# Función de pérdida multi-tarea
def compute_loss(stroke_preds, variant_preds, hand_preds, 
                stroke_targets, variant_targets, hand_targets):
    # Usar CrossEntropy para cada tarea
    stroke_loss = F.cross_entropy(stroke_preds, stroke_targets)
    variant_loss = F.cross_entropy(variant_preds, variant_targets)
    hand_loss = F.cross_entropy(hand_preds, hand_targets)
    
    # Combinar pérdidas (se pueden ajustar los pesos)
    total_loss = stroke_loss + variant_loss + hand_loss
    
    return {
        'total': total_loss,
        'stroke': stroke_loss.item(),
        'variant': variant_loss.item(),
        'hand': hand_loss.item()
    }

# Evaluar el modelo
def evaluate(model, dataloader, device):
    model.eval()
    
    # Métricas
    total_loss = 0
    stroke_correct = 0
    variant_correct = 0
    hand_correct = 0
    total_samples = 0
    
    with torch.no_grad():
        for batch in dataloader:
            # Mover los datos al dispositivo
            slow_pathway = batch['slow_pathway'].to(device)
            fast_pathway = batch['fast_pathway'].to(device)
            stroke_targets = batch['stroke_type'].to(device)
            variant_targets = batch['shot_variant'].to(device)
            hand_targets = batch['hand_style'].to(device)
            
            # Forward pass
            stroke_preds, variant_preds, hand_preds = model(slow_pathway, fast_pathway)
            
            # Calcular pérdida
            loss_dict = compute_loss(stroke_preds, variant_preds, hand_preds,
                                    stroke_targets, variant_targets, hand_targets)
            total_loss += loss_dict['total'].item()
            
            # Calcular precisión
            _, stroke_pred_idx = torch.max(stroke_preds, dim=1)
            _, variant_pred_idx = torch.max(variant_preds, dim=1)
            _, hand_pred_idx = torch.max(hand_preds, dim=1)
            
            stroke_correct += (stroke_pred_idx == stroke_targets).sum().item()
            variant_correct += (variant_pred_idx == variant_targets).sum().item()
            hand_correct += (hand_pred_idx == hand_targets).sum().item()
            
            total_samples += stroke_targets.size(0)
    
    # Calcular métricas promedio
    avg_loss = total_loss / len(dataloader)
    stroke_acc = stroke_correct / total_samples
    variant_acc = variant_correct / total_samples
    hand_acc = hand_correct / total_samples
    
    return {
        'loss': avg_loss,
        'stroke_acc': stroke_acc,
        'variant_acc': variant_acc,
        'hand_acc': hand_acc,
        'avg_acc': (stroke_acc + variant_acc + hand_acc) / 3
    }

In [None]:
# Función principal de entrenamiento
def train_model(model, train_loader, val_loader, optimizer, device, 
                num_epochs=10, save_path='model_checkpoints'):
    
    # Crear directorio para guardar checkpoints
    os.makedirs(save_path, exist_ok=True)
    
    # Listas para almacenar métricas
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    
    # Mejor modelo
    best_val_acc = 0.0
    
    for epoch in range(num_epochs):
        # Entrenamiento
        model.train()
        train_loss = 0.0
        
        # Metrics for training
        train_stroke_correct = 0
        train_variant_correct = 0
        train_hand_correct = 0
        train_total = 0
        
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        
        for batch in progress_bar:
            # Mover los datos al dispositivo
            slow_pathway = batch['slow_pathway'].to(device)
            fast_pathway = batch['fast_pathway'].to(device)
            stroke_targets = batch['stroke_type'].to(device)
            variant_targets = batch['shot_variant'].to(device)
            hand_targets = batch['hand_style'].to(device)
            
            # Reiniciar gradientes
            optimizer.zero_grad()
            
            # Forward pass
            stroke_preds, variant_preds, hand_preds = model(slow_pathway, fast_pathway)
            
            # Calcular pérdida
            loss_dict = compute_loss(stroke_preds, variant_preds, hand_preds,
                                   stroke_targets, variant_targets, hand_targets)
            
            # Backward pass y optimización
            loss_dict['total'].backward()
            optimizer.step()
            
            # Actualizar pérdida total
            train_loss += loss_dict['total'].item()
            
            # Calcular métricas de precisión
            _, stroke_pred_idx = torch.max(stroke_preds, dim=1)
            _, variant_pred_idx = torch.max(variant_preds, dim=1)
            _, hand_pred_idx = torch.max(hand_preds, dim=1)
            
            train_stroke_correct += (stroke_pred_idx == stroke_targets).sum().item()
            train_variant_correct += (variant_pred_idx == variant_targets).sum().item()
            train_hand_correct += (hand_pred_idx == hand_targets).sum().item()
            train_total += stroke_targets.size(0)
            
            # Actualizar la barra de progreso
            progress_bar.set_postfix({
                'loss': loss_dict['total'].item(),
                'stroke_loss': loss_dict['stroke'],
                'variant_loss': loss_dict['variant'],
                'hand_loss': loss_dict['hand']
            })
        
        # Calcular métricas de entrenamiento
        avg_train_loss = train_loss / len(train_loader)
        train_stroke_acc = train_stroke_correct / train_total
        train_variant_acc = train_variant_correct / train_total
        train_hand_acc = train_hand_correct / train_total
        train_avg_acc = (train_stroke_acc + train_variant_acc + train_hand_acc) / 3
        
        train_losses.append(avg_train_loss)
        train_accs.append(train_avg_acc)
        
        # Validación
        val_metrics = evaluate(model, val_loader, device)
        val_losses.append(val_metrics['loss'])
        val_accs.append(val_metrics['avg_acc'])
        
        # Imprimir métricas del epoch
        print(f"Epoch {epoch+1}/{num_epochs} - "
              f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_avg_acc:.4f}, "
              f"Val Loss: {val_metrics['loss']:.4f}, Val Acc: {val_metrics['avg_acc']:.4f}")
        print(f"  Train - Stroke: {train_stroke_acc:.4f}, Variant: {train_variant_acc:.4f}, Hand: {train_hand_acc:.4f}")
        print(f"  Val - Stroke: {val_metrics['stroke_acc']:.4f}, Variant: {val_metrics['variant_acc']:.4f}, Hand: {val_metrics['hand_acc']:.4f}")
        
        # Guardar el mejor modelo
        if val_metrics['avg_acc'] > best_val_acc:
            best_val_acc = val_metrics['avg_acc']
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_accuracy': val_metrics['avg_acc'],
                'val_loss': val_metrics['loss']
            }, os.path.join(save_path, 'best_model.pth'))
            print(f"  Nuevo mejor modelo guardado con precisión: {best_val_acc:.4f}")
    
    # Guardar el modelo final
    torch.save({
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_accuracy': val_metrics['avg_acc'],
        'val_loss': val_metrics['loss']
    }, os.path.join(save_path, 'final_model.pth'))
    
    # Visualizar curvas de entrenamiento
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Loss vs. Epoch')
    
    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Train Accuracy')
    plt.plot(val_accs, label='Val Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title('Accuracy vs. Epoch')
    
    plt.tight_layout()
    plt.savefig(os.path.join(save_path, 'training_curves.png'))
    plt.show()
    
    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs,
        'best_val_acc': best_val_acc
    }

### Main

In [None]:
def _make_demo_metadata(num_examples=10):
    """Return a synthetic metadata DataFrame with *num_examples* rows.

    Each row is randomly (``random.choice``) either a 'Forehand' with variant
    'Topspin' and hand 'no_hand_style', or a 'Backhand' with variant
    'Topspin (1H)' and hand 'one'. 'output_path' / 'full_path' are
    ``videos/<stroke>/<variant>/example_<i>.mp4``.
    """
    rows = []
    for i in range(num_examples):
        stroke_type = random.choice(['Forehand', 'Backhand'])
        if stroke_type == 'Forehand':
            variant, hand = 'Topspin', 'no_hand_style'
        else:
            variant, hand = 'Topspin (1H)', 'one'
        rows.append({
            'input_video': 'example.mp4',
            'stroke_type': stroke_type,
            'shot_variant': variant,
            'hand_style': hand,
            'output_filename': f'example_{i}.mp4',
        })

    df = pd.DataFrame(rows)
    df['output_path'] = [
        os.path.join('videos', stroke, variant, filename)
        for stroke, variant, filename in zip(df['stroke_type'], df['shot_variant'], df['output_filename'])
    ]
    df['full_path'] = df['output_path']
    return df


def _find_existing_videos(df):
    """Return the index labels of rows whose video exists at BASE_PATH/<full_path>."""
    return [idx for idx, rel_path in df['full_path'].items()
            if os.path.exists(os.path.join(BASE_PATH, rel_path))]


def _split_train_val(df, val_fraction=0.2):
    """Shuffle-split *df* into (train, val) DataFrames with fresh 0..n-1 indices.

    Uses SEED for a reproducible split. With fewer than two rows everything
    goes to training and the validation set is empty.
    """
    if len(df) < 2:
        return df.reset_index(drop=True), df.iloc[0:0].reset_index(drop=True)
    # Shuffling matters: metadata is often grouped by category, so a plain
    # head/tail slice would put whole classes only in one split.
    train_df, val_df = train_test_split(df, test_size=val_fraction,
                                        random_state=SEED, shuffle=True)
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True)


def _demo_forward_pass(model, loader, device, max_batches=2):
    """Run *model* in eval mode on up to *max_batches* batches, printing tensor shapes."""
    model.eval()
    with torch.no_grad():
        for batch_idx, batch in enumerate(loader):
            try:
                slow_pathway = batch['slow_pathway'].to(device)
                fast_pathway = batch['fast_pathway'].to(device)

                print(f"Batch {batch_idx}:")
                print(f"  slow_pathway shape: {slow_pathway.shape}")
                print(f"  fast_pathway shape: {fast_pathway.shape}")

                stroke_preds, variant_preds, hand_preds = model(slow_pathway, fast_pathway)

                print("✅ Salida del modelo (modo demo):")
                print(f"  stroke_preds shape: {stroke_preds.shape}")
                print(f"  variant_preds shape: {variant_preds.shape}")
                print(f"  hand_preds shape: {hand_preds.shape}")
            except Exception as e:
                # Diagnose the batch layout; values may not all be tensors.
                print(f"Error en batch {batch_idx}: {str(e)}")
                print(f"Claves en el batch: {list(batch.keys())}")
                for key, value in batch.items():
                    print(f"  {key} shape: {getattr(value, 'shape', type(value).__name__)}")
                break
            if batch_idx + 1 >= max_batches:
                break


def main(run_training=False):
    """Build the data pipeline and model, then smoke-test (or train) it.

    Demo mode is currently forced on: a 10-row synthetic metadata table
    replaces the real one and the datasets return random clips. The data is
    split 80/20 (shuffled, seeded), a SlowFastNetwork is built on CPU, and
    either a 2-batch forward pass is run (default) or, with
    ``run_training=True``, ``train_model`` is called for one epoch.
    """
    print("Cargando metadatos...")
    df = load_metadata(METADATA_PATH)
    original_len = len(df)
    print(f"Total de videos en metadatos: {original_len}")

    # Demo mode is forced to keep resource usage low.
    demo_mode = True

    if original_len == 0 or demo_mode:
        print("\nCreando datos de ejemplo para demostración...")
        df = _make_demo_metadata(10)
        print(f"Datos de ejemplo creados: {len(df)} entradas")
    else:
        print("\nVerificando que los videos existan...")
        valid_videos = _find_existing_videos(df)
        if valid_videos:
            df = df.loc[valid_videos].reset_index(drop=True)
            print(f"Videos encontrados: {len(df)}/{original_len}")
        else:
            # Missing files are replaced by random clips inside the dataset.
            print("\n¡ADVERTENCIA! No se encontró ningún video. Usando modo demo con datos sintéticos.")

    print("\nDistribución de Stroke Types:")
    print(df['stroke_type'].value_counts())

    print("\nDistribución de Shot Variants:")
    print(df['shot_variant'].value_counts())

    print("\nDistribución de Hand Styles:")
    print(df['hand_style'].value_counts())

    label_mappings = create_label_mappings(df)
    print("\nClases:")
    for task, mapping in label_mappings.items():
        print(f"  {task}: {len(mapping)} clases")
        if mapping:
            print(f"  - Ejemplos: {list(mapping.keys())[:3]}")

    if len(df) == 0 or len(label_mappings['stroke_type']) == 0:
        print("\n¡ERROR! No hay suficientes datos o clases para entrenar.")
        return

    train_df, val_df = _split_train_val(df, val_fraction=0.2)
    print(f"\nConjunto de entrenamiento: {len(train_df)} videos")
    print(f"Conjunto de validación: {len(val_df)} videos")

    # Half resolution (56x56) to save resources; applied per frame when real
    # videos are loaded (not in demo mode).
    transform = transforms.Compose([
        transforms.Resize((56, 56)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Datasets are forced into demo mode (random clips, real labels).
    train_dataset = TennisVideoDataset(train_df, label_mappings, transform=transform, demo_mode=True)
    val_dataset = TennisVideoDataset(val_df, label_mappings, transform=transform, demo_mode=True)

    # Minimal batch size and no worker processes.
    batch_size = 1
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    # Forced to CPU.
    device = torch.device("cpu")

    num_stroke_classes = len(label_mappings['stroke_type'])
    num_variant_classes = len(label_mappings['shot_variant'])
    num_hand_classes = len(label_mappings['hand_style'])

    print("\nConfiguración del modelo:")
    print(f"  Clases de golpe (stroke): {num_stroke_classes}")
    print(f"  Clases de variante (variant): {num_variant_classes}")
    print(f"  Clases de estilo de mano (hand): {num_hand_classes}")

    # Every head gets at least two outputs.
    model = SlowFastNetwork(
        num_stroke_classes=max(2, num_stroke_classes),
        num_variant_classes=max(2, num_variant_classes),
        num_hand_classes=max(2, num_hand_classes)
    ).to(device)

    # High learning rate and a single epoch: demo settings.
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    num_epochs = 1

    if run_training:
        print("\nIniciando entrenamiento...")
        train_model(model, train_loader, val_loader, optimizer, device, num_epochs=num_epochs)
        return

    print("\nIniciando entrenamiento en modo demo...")
    _demo_forward_pass(model, train_loader, device, max_batches=2)

    print("\n✅ Modo demo completado. El modelo se ejecutó correctamente con datos sintéticos.")
    print("No se guardaron checkpoints en modo demo para ahorrar recursos.")

In [None]:
# Entry point: run the pipeline when executed directly (script or notebook cell),
# but not when this file is imported as a module.
if __name__ == '__main__':
    main()