# Entrenamiento del Modelo TimeSformer para Detección de Violencia Escolar

## Sección 1: Configuración del Entorno

In [None]:
# Install the pinned dependency versions used by this notebook.
# These are IPython/Colab shell magics (`!`), not Python statements; `-q` only
# makes pip quieter. Versions are pinned so the later API calls (Lightning 2.0,
# transformers 4.35, torchmetrics 1.0) match what the code below was written for.
!pip install -q transformers==4.35.0
!pip install -q torch==2.0.1 torchvision==0.15.2
!pip install -q pytorch-lightning==2.0.9
!pip install -q wandb==0.15.12
!pip install -q timm==0.9.2
!pip install -q scikit-learn==1.3.0
!pip install -q matplotlib==3.7.2 seaborn==0.12.2
!pip install -q opencv-python==4.8.0.76
!pip install -q einops==0.6.1
# PyAV: used by VideoDataset to decode video files.
!pip install -q av==10.0.0
!pip install -q torchmetrics==1.0.3
!pip install -q tensorboardX==2.6.2

In [None]:
# Importaciones básicas
import os
import sys
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import time
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from torchmetrics import Accuracy, Precision, Recall, F1Score, ConfusionMatrix, AUROC, Specificity

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

from sklearn.metrics import roc_curve, auc, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

import cv2
import av
from einops import rearrange

from transformers import (
    TimesformerForVideoClassification,
    TimesformerConfig,
    AutoImageProcessor,
    get_cosine_schedule_with_warmup
)

# Select the first CUDA GPU when one is visible, otherwise run on the CPU,
# and report which device (and how much GPU memory) is in use.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Memory available: {total_gb:.2f} GB")

In [None]:
# Seed every RNG the notebook touches (Python, NumPy, PyTorch CPU and all
# CUDA devices) and pin cuDNN to deterministic kernels for reproducibility.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)
del _seed_fn
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False
# The interpreter fixes its hash seed at start-up, so this only affects
# Python subprocesses launched afterwards (e.g. DataLoader workers that spawn).
os.environ['PYTHONHASHSEED'] = str(SEED)

In [None]:
# Weights & Biases login (optional but recommended for experiment tracking).
# Any ordinary failure (package not installed, no API key, no network) just
# disables W&B logging. The import lives inside the try so a missing package is
# handled the same way. `except Exception` (not a bare except) lets
# KeyboardInterrupt / SystemExit still stop the cell.
try:
    import wandb
    wandb.login()
    wandb_available = True
except Exception as wandb_error:
    print("WandB no disponible. Continuando sin tracking...")
    print(f"  ({type(wandb_error).__name__}: {wandb_error})")
    wandb_available = False

## Sección 2: Montar Google Drive y Configurar Rutas

In [None]:
# Mount Google Drive (Colab-only API) at /content/drive so the dataset and the
# model output folders under /content/drive/MyDrive are reachable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:

# Dataset location on Google Drive: one folder per split (train/val/test),
# each containing one sub-folder per class.
BASE_PATH = "/content/drive/MyDrive/dataset_violencia"
TRAIN_PATH = os.path.join(BASE_PATH, "train")
VAL_PATH = os.path.join(BASE_PATH, "val")
TEST_PATH = os.path.join(BASE_PATH, "test")

# Output folders for exported models and Lightning checkpoints.
MODEL_SAVE_PATH = "/content/drive/MyDrive/violence_detection_models"
CHECKPOINT_PATH = os.path.join(MODEL_SAVE_PATH, "checkpoints")
EXPORT_PATH = os.path.join(MODEL_SAVE_PATH, "export")

# Create the output folders if they do not exist yet.
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
os.makedirs(EXPORT_PATH, exist_ok=True)

# Fail fast when the training split is missing.
if not os.path.exists(TRAIN_PATH):
    raise FileNotFoundError(f"El directorio de training {TRAIN_PATH} no existe")

# Classes are the sub-directories of the training split. Stray files
# (e.g. .DS_Store) are skipped, and names are sorted so the printed order
# matches the label indices the datasets assign (they also sort class names).
classes = sorted(
    entry for entry in os.listdir(TRAIN_PATH)
    if os.path.isdir(os.path.join(TRAIN_PATH, entry))
)
num_classes = len(classes)
print(f"Clases detectadas: {classes} (Total: {num_classes})")

# Count the samples in each split.
def count_samples(path, extensions=('.mp4', '.avi')):
    """Return how many video files sit inside the class folders of *path*.

    Only files whose name ends with one of *extensions* are counted, which is
    the same filter the dataset uses. Stray files such as thumbnails or notes
    therefore do not inflate the totals. Entries directly inside *path* that
    are not directories are ignored.

    Args:
        path: Split directory containing one sub-folder per class.
        extensions: Accepted file-name suffixes (case-sensitive).

    Returns:
        Total number of matching files across all class folders.

    Raises:
        FileNotFoundError: If *path* does not exist.
    """
    total = 0
    for class_name in os.listdir(path):
        class_dir = os.path.join(path, class_name)
        if os.path.isdir(class_dir):
            total += sum(1 for file_name in os.listdir(class_dir) if file_name.endswith(extensions))
    return total

# Compute all three split sizes first, then report them in a fixed order.
train_samples, val_samples, test_samples = (
    count_samples(split_dir) for split_dir in (TRAIN_PATH, VAL_PATH, TEST_PATH)
)

for split_label, split_total in (
    ("Training", train_samples),
    ("Validation", val_samples),
    ("Test", test_samples),
):
    print(f"Muestras en {split_label}: {split_total}")

In [None]:
# Every split directory must exist. For each class folder inside it, print how
# many .mp4/.avi files it holds; plain files at split level are skipped.
for path in (TRAIN_PATH, VAL_PATH, TEST_PATH):
    if not os.path.exists(path):
        raise FileNotFoundError(f"El directorio {path} no existe")
    split_label = os.path.basename(path)
    for class_name in os.listdir(path):
        class_dir = os.path.join(path, class_name)
        if os.path.isdir(class_dir):
            num_samples = sum(1 for f in os.listdir(class_dir) if f.endswith(('.mp4', '.avi')))
            print(f"Clase '{class_name}' en {split_label}: {num_samples} videos")

## Sección 3: Crear Dataset y DataLoader para Videos

In [None]:
class VideoDataset(Dataset):
    """
    Video dataset for (school) violence classification.

    ``root_dir`` must contain one sub-directory per class, each holding
    ``.mp4`` / ``.avi`` files. Label indices follow the sorted order of the
    class sub-directory names.

    Args:
        root_dir: Split directory with one folder per class.
        processor: Hugging Face image processor. It is called as
            ``processor(frames, return_tensors="pt", ...)`` and must return an
            object whose ``pixel_values`` has a leading batch dimension of 1
            (it is squeezed off below).
        num_frames: Frames sampled per clip.
        max_duration: Seconds assumed (at 30 fps) when the frame count cannot
            be read from the file.
        transform: Optional per-frame callable applied before ``processor``.
        frame_sample_rate: Stride, in frames, between sampled frames.
        clip_start: First frame index to sample from (clamped so the clip fits).
        debug: Keep only 100 samples (drawn from all classes) for quick runs.
    """

    VIDEO_EXTENSIONS = ('.mp4', '.avi')

    def __init__(self, root_dir, processor, num_frames=8, max_duration=6, transform=None,
                 frame_sample_rate=2, clip_start=0, debug=False):
        self.root_dir = root_dir
        self.processor = processor
        self.num_frames = num_frames
        self.max_duration = max_duration
        self.transform = transform
        self.frame_sample_rate = frame_sample_rate
        self.clip_start = clip_start
        self.debug = debug

        # Only sub-directories are classes. A stray file such as .DS_Store
        # would otherwise become a "class" and crash the listing below.
        self.classes = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

        self.samples = []
        for class_name in self.classes:
            class_dir = os.path.join(root_dir, class_name)
            # Sorted so the sample order does not depend on the filesystem.
            for video_name in sorted(os.listdir(class_dir)):
                if video_name.endswith(self.VIDEO_EXTENSIONS):
                    self.samples.append((os.path.join(class_dir, video_name), self.class_to_idx[class_name]))

        if debug:
            # Samples are grouped by class, so a plain prefix would usually
            # contain a single class. Shuffle with a fixed seed before truncating.
            random.Random(0).shuffle(self.samples)
            self.samples = self.samples[:100]

    def __len__(self):
        return len(self.samples)

    def load_video(self, video_path):
        """
        Decode the sampled frames of ``video_path`` as RGB ``uint8`` arrays.

        Returns a list of exactly ``num_frames`` HxWx3 arrays, in sampling
        order. Repeated sample indices (short videos) yield repeated frames.
        Missing frames are padded with the last decoded one. On any decoding
        error, ``num_frames`` black 224x224 frames are returned instead.
        """
        try:
            # The context manager closes the container (file handle + decoder)
            # even if decoding fails part-way.
            with av.open(video_path) as container:
                indices = [int(i) for i in self._sample_frame_indices(container)]
                wanted = set(indices)
                last_wanted = max(wanted)

                decoded = {}
                container.seek(0)
                for frame_idx, frame in enumerate(container.decode(video=0)):
                    if frame_idx in wanted:
                        # "rgb24" is the FFmpeg pixel-format name for packed 8-bit RGB.
                        decoded[frame_idx] = frame.to_ndarray(format="rgb24")
                    if frame_idx >= last_wanted:
                        break

            # Rebuild the clip in sampling order (duplicates included).
            video_frames = [decoded[i] for i in indices if i in decoded]

            # If the file had fewer frames than advertised, repeat the last one.
            if len(video_frames) < self.num_frames:
                last_frame = video_frames[-1] if video_frames else np.zeros((224, 224, 3), dtype=np.uint8)
                video_frames.extend([last_frame] * (self.num_frames - len(video_frames)))

            return video_frames
        except Exception as e:
            print(f"Error cargando video {video_path}: {e}")
            # Black frames as a fallback so one bad file does not stop training.
            return [np.zeros((224, 224, 3), dtype=np.uint8) for _ in range(self.num_frames)]

    def _sample_frame_indices(self, container):
        """
        Choose which frame indices of the first video stream to decode.

        Sampling starts at ``clip_start`` (clamped so the clip fits) and takes
        ``num_frames`` indices spaced ``frame_sample_rate`` apart, or closer
        when the video is too short. Videos with fewer than ``num_frames``
        frames cycle through their frames.

        NOTE(review): only ``(num_frames - 1) * frame_sample_rate + 1`` frames
        from ``clip_start`` are covered (the first 15 frames with the
        defaults), not the whole video. Confirm this is intended.
        """
        total_frames = 0
        try:
            video_stream = container.streams.video[0]
            # ``frames`` can be 0 or None when the container does not store it.
            total_frames = video_stream.frames or 0

            if total_frames <= 0:
                # Estimate from duration x fps.
                fps = video_stream.average_rate
                if fps:
                    if video_stream.duration and video_stream.time_base:
                        # Stream duration is expressed in time_base units.
                        total_frames = int(video_stream.duration * video_stream.time_base * fps)
                    elif getattr(container, "duration", None):
                        # Container duration is expressed in microseconds.
                        total_frames = int(container.duration / 1e6 * float(fps))
        except Exception as e:
            print(f"Error estimando frames: {e}")
            total_frames = 0

        if total_frames <= 0:
            # Fallback: assume 30 fps over max_duration seconds (capped at 300).
            total_frames = min(300, int(self.max_duration * 30))

        start_idx = min(self.clip_start, max(0, total_frames - self.num_frames * self.frame_sample_rate))

        if total_frames < self.num_frames:
            # Fewer frames than needed: cycle through all available frames.
            repeats = self.num_frames // total_frames + 1
            indices = np.array(list(range(total_frames)) * repeats)[:self.num_frames]
        else:
            stop_idx = min(total_frames - 1, start_idx + (self.num_frames - 1) * self.frame_sample_rate)
            indices = np.linspace(start_idx, stop_idx, self.num_frames, dtype=int)

        return indices

    def _apply_transform(self, frames):
        """
        Apply ``self.transform`` to every frame with the same random parameters.

        The torch CPU RNG state is restored before each frame. Random crops,
        flips and colour jitter (torchvision's classic transforms draw from
        that generator) are therefore identical across the clip instead of
        differing frame to frame.
        """
        rng_state = torch.get_rng_state()
        transformed = []
        for frame in frames:
            torch.set_rng_state(rng_state)
            transformed.append(self.transform(frame))
        return transformed

    def __getitem__(self, idx):
        video_path, label = self.samples[idx]

        try:
            frames = self.load_video(video_path)

            processor_kwargs = {}
            if self.transform:
                frames = self._apply_transform(frames)
                if frames and torch.is_tensor(frames[0]) and frames[0].is_floating_point():
                    # ToTensor() already scaled pixels to [0, 1]. Ask the
                    # processor not to multiply by 1/255 a second time
                    # (assumes a Hugging Face processor that accepts `do_rescale`).
                    processor_kwargs["do_rescale"] = False

            inputs = self.processor(frames, return_tensors="pt", **processor_kwargs)
            # Drop the batch dimension; the DataLoader adds its own.
            pixel_values = inputs.pixel_values.squeeze(0)

            return {
                "pixel_values": pixel_values,
                "labels": torch.tensor(label, dtype=torch.long),
                "video_path": video_path
            }

        except Exception as e:
            print(f"Error procesando video {video_path}: {e}")
            # Zero clip with the (frames, channels, H, W) layout the processor produces.
            h, w = 224, 224
            dummy_frames = torch.zeros((self.num_frames, 3, h, w))
            return {
                "pixel_values": dummy_frames,
                "labels": torch.tensor(label, dtype=torch.long),
                "video_path": video_path
            }

# Custom collate function used by every DataLoader in this notebook.
def collate_fn(batch):
    """Merge a list of VideoDataset samples into one batch dictionary.

    ``pixel_values`` and ``labels`` are stacked along a new leading batch
    dimension. The per-sample ``video_path`` strings are gathered into a
    plain list under the plural key ``video_paths``.
    """
    return {
        'pixel_values': torch.stack([sample['pixel_values'] for sample in batch]),
        'labels': torch.stack([sample['labels'] for sample in batch]),
        'video_paths': [sample['video_path'] for sample in batch],
    }

## Sección 4: Configuración del Modelo y Procesador

In [None]:
# Load the TimeSformer image processor. If the Kinetics-400 checkpoint cannot
# be loaded, fall back to the Something-Something-v2 one and record the
# switch in MODEL_CHECKPOINT.
MODEL_CHECKPOINT = "facebook/timesformer-base-finetuned-k400"  # pretrained on Kinetics-400

try:
    processor = AutoImageProcessor.from_pretrained(MODEL_CHECKPOINT)
except Exception as primary_error:
    print(f"Error cargando procesador desde {MODEL_CHECKPOINT}: {primary_error}")
    BACKUP_CHECKPOINT = "facebook/timesformer-base-finetuned-ssv2"
    print(f"Intentando con checkpoint alternativo: {BACKUP_CHECKPOINT}")
    try:
        processor = AutoImageProcessor.from_pretrained(BACKUP_CHECKPOINT)
    except Exception as backup_error:
        raise RuntimeError(f"No se pudo cargar el procesador: {backup_error}")
    MODEL_CHECKPOINT = BACKUP_CHECKPOINT
    print(f"Procesador cargado correctamente desde {MODEL_CHECKPOINT}")
else:
    print(f"Procesador cargado correctamente desde {MODEL_CHECKPOINT}")

# Per-frame augmentation pipelines. Both stop at ToTensor(), which yields
# float CHW tensors in [0, 1]; normalization is left to the TimeSformer processor.
# NOTE: the random ops draw new parameters on every call, so the caller must
# pin the RNG if one clip should get a single, consistent augmentation.
train_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=0.1,
        ),
        transforms.ToTensor(),
    ]
)

# Deterministic evaluation pipeline: resize the shorter side to 256, then
# take the central 224x224 crop.
valid_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
    ]
)

# Build one dataset per split.
def create_datasets(num_frames=8, frame_sample_rate=2):
    """Return ``(train, val, test)`` VideoDataset objects.

    All three use the module-level ``processor`` and split paths. The training
    split uses ``train_transform``; validation and test use ``valid_transform``.
    """
    def _split(split_root, split_transform):
        return VideoDataset(
            root_dir=split_root,
            processor=processor,
            num_frames=num_frames,
            transform=split_transform,
            frame_sample_rate=frame_sample_rate,
        )

    return (
        _split(TRAIN_PATH, train_transform),
        _split(VAL_PATH, valid_transform),
        _split(TEST_PATH, valid_transform),
    )

# Wrap datasets in DataLoaders.
def create_dataloaders(train_dataset, val_dataset, test_dataset, batch_size=4, num_workers=2):
    """Return ``(train_loader, val_loader, test_loader)``.

    Only the training loader shuffles. All three pin host memory and batch
    samples with ``collate_fn``.
    """
    def _loader(dataset, shuffle):
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True,
            collate_fn=collate_fn,
        )

    train_loader = _loader(train_dataset, shuffle=True)
    val_loader = _loader(val_dataset, shuffle=False)
    test_loader = _loader(test_dataset, shuffle=False)
    return train_loader, val_loader, test_loader

## Sección 5: Definir el Modelo en PyTorch Lightning

In [None]:
class TimeSformerLightningModule(pl.LightningModule):
    """Fine-tune a Hugging Face TimeSformer video classifier with PyTorch Lightning.

    The pretrained checkpoint's head is rebuilt for ``num_classes`` labels
    (``ignore_mismatched_sizes=True``). With ``freeze_backbone`` only the
    classifier and the last ``unfreeze_layers`` encoder layers are trained.
    Optimisation uses AdamW with a cosine schedule and linear warm-up, stepped
    once per optimizer step.

    Args:
        num_classes: Number of output labels.
        num_frames: Frames per clip, written into the model config.
        learning_rate: AdamW learning rate.
        weight_decay: Decay for weight tensors (biases and norm parameters get none).
        warmup_steps: Linear warm-up length, in optimizer steps.
        total_steps: Total optimizer steps spanned by the cosine schedule.
        freeze_backbone: Freeze everything except the head (+ ``unfreeze_layers``).
        unfreeze_layers: Number of trailing encoder layers kept trainable when frozen.
        model_checkpoint: Hugging Face hub id or local path to load from.
    """

    def __init__(
        self,
        num_classes=2,
        num_frames=8,
        learning_rate=2e-5,
        weight_decay=0.05,
        warmup_steps=100,
        total_steps=2000,
        freeze_backbone=True,
        unfreeze_layers=3,
        model_checkpoint="facebook/timesformer-base-finetuned-k400"
    ):
        super().__init__()
        # Records every constructor argument in self.hparams and in checkpoints,
        # so load_from_checkpoint() can rebuild the module without arguments.
        self.save_hyperparameters()

        # Model config with the desired clip length and number of labels.
        self.config = TimesformerConfig.from_pretrained(
            model_checkpoint,
            num_frames=num_frames,
            num_labels=num_classes
        )

        # Pretrained weights; the mismatched classifier head is re-initialised.
        self.model = TimesformerForVideoClassification.from_pretrained(
            model_checkpoint,
            config=self.config,
            ignore_mismatched_sizes=True
        )

        if freeze_backbone:
            self._freeze_backbone(unfreeze_layers)

        # One metric set per stage. Sharing a single instance across train/val/test
        # mixes their accumulated state.
        self.train_metrics = self._build_metrics(num_classes)
        self.val_metrics = self._build_metrics(num_classes)
        self.test_metrics = self._build_metrics(num_classes)

        # Per-batch test outputs, concatenated in on_test_epoch_end.
        self.test_preds = []
        self.test_targets = []
        self.test_probs = []

    @staticmethod
    def _build_metrics(num_classes):
        """Create a fresh set of macro-averaged multiclass metrics (keys become log suffixes)."""
        return nn.ModuleDict({
            "acc": Accuracy(task="multiclass", num_classes=num_classes),
            "precision": Precision(task="multiclass", num_classes=num_classes, average="macro"),
            "recall": Recall(task="multiclass", num_classes=num_classes, average="macro"),
            "f1": F1Score(task="multiclass", num_classes=num_classes, average="macro"),
            "specificity": Specificity(task="multiclass", num_classes=num_classes, average="macro"),
        })

    def _freeze_backbone(self, unfreeze_layers):
        """Freeze all weights except the classifier and the last *unfreeze_layers* encoder layers."""
        for param in self.model.parameters():
            param.requires_grad = False

        for param in self.model.classifier.parameters():
            param.requires_grad = True

        if unfreeze_layers > 0:
            encoder_layers = self.model.timesformer.encoder.layer
            num_encoder_layers = len(encoder_layers)
            print(f"Número total de capas encoder: {num_encoder_layers}")

            # Walk backwards from the last layer, at most unfreeze_layers layers.
            lowest_idx = max(num_encoder_layers - unfreeze_layers, 0)
            for layer_idx in range(num_encoder_layers - 1, lowest_idx - 1, -1):
                print(f"Descongelando capa encoder: {layer_idx}")
                for param in encoder_layers[layer_idx].parameters():
                    param.requires_grad = True

    def _log_stage_metrics(self, stage, metrics, preds, labels, names):
        """Update the named metrics with this batch and log them as epoch-level values."""
        for name in names:
            metric = metrics[name]
            metric.update(preds, labels)
            # Logging the Metric object (not a batch tensor) makes Lightning call
            # compute() over the whole epoch and reset it afterwards. The result
            # is the true epoch macro-F1, not a mean of per-batch scores.
            self.log(f"{stage}_{name}", metric, on_step=False, on_epoch=True, prog_bar=(name == "acc"))

    def forward(self, pixel_values):
        return self.model(pixel_values=pixel_values)

    def training_step(self, batch, batch_idx):
        labels = batch["labels"]
        outputs = self.model(pixel_values=batch["pixel_values"], labels=labels)
        loss = outputs.loss
        preds = torch.argmax(outputs.logits, dim=1)

        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)
        self._log_stage_metrics("train", self.train_metrics, preds, labels,
                                ("acc", "precision", "recall", "f1"))
        return loss

    def validation_step(self, batch, batch_idx):
        labels = batch["labels"]
        outputs = self.model(pixel_values=batch["pixel_values"], labels=labels)
        loss = outputs.loss
        preds = torch.argmax(outputs.logits, dim=1)

        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        self._log_stage_metrics("val", self.val_metrics, preds, labels,
                                ("acc", "precision", "recall", "f1", "specificity"))
        return loss

    def test_step(self, batch, batch_idx):
        labels = batch["labels"]
        outputs = self.model(pixel_values=batch["pixel_values"], labels=labels)
        loss = outputs.loss
        logits = outputs.logits

        preds = torch.argmax(logits, dim=1)
        probs = F.softmax(logits, dim=1)

        # Kept on CPU for the epoch-level report.
        self.test_preds.append(preds.cpu())
        self.test_targets.append(labels.cpu())
        self.test_probs.append(probs.cpu())

        self.log("test_loss", loss, sync_dist=True)
        self._log_stage_metrics("test", self.test_metrics, preds, labels, ("acc",))

        return {"test_loss": loss, "preds": preds, "targets": labels}

    def on_test_epoch_end(self):
        try:
            if not self.test_preds:
                return

            preds = torch.cat(self.test_preds, dim=0).numpy()
            targets = torch.cat(self.test_targets, dim=0).numpy()

            # Pass the full label set so the report/matrix keep their shape even
            # when a class is absent from the test predictions or targets.
            label_ids = list(range(self.hparams.num_classes))
            datamodule = getattr(self.trainer, "datamodule", None)
            class_names = getattr(datamodule, "classes", None)
            if not class_names or len(class_names) != len(label_ids):
                class_names = [str(i) for i in label_ids]

            report = classification_report(targets, preds, labels=label_ids,
                                           target_names=class_names, zero_division=0)
            conf_matrix = confusion_matrix(targets, preds, labels=label_ids)

            # Only some loggers (e.g. TensorBoard's SummaryWriter) offer add_text.
            experiment = getattr(self.logger, "experiment", None)
            if experiment is not None and hasattr(experiment, "add_text"):
                experiment.add_text("Classification Report", report)

            print(f"\nClassification Report:\n{report}")
            print(f"\nConfusion Matrix:\n{conf_matrix}")

            if conf_matrix.shape == (2, 2):
                # Binary case: label index 1 is treated as the positive class
                # (presumably "violencia" after sorting class names; confirm).
                tn, fp, fn, tp = conf_matrix.ravel()
                total = tp + tn + fp + fn
                accuracy = (tp + tn) / total if total > 0 else 0
                precision = tp / (tp + fp) if (tp + fp) > 0 else 0
                recall = tp / (tp + fn) if (tp + fn) > 0 else 0
                specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
                f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
                fpr = fp / (fp + tn) if (fp + tn) > 0 else 0

                print(f"\nMétricas clave:")
                print(f"Accuracy: {accuracy:.4f}")
                print(f"Precision: {precision:.4f}")
                print(f"Recall (TPR): {recall:.4f}")
                print(f"Specificity: {specificity:.4f}")
                print(f"F1-Score: {f1:.4f}")
                print(f"False Positive Rate: {fpr:.4f}")
        finally:
            # Always reset the accumulators so a later trainer.test() starts clean.
            self.test_preds.clear()
            self.test_targets.clear()
            self.test_probs.clear()

    def configure_optimizers(self):
        # Weight decay only on weight matrices/tensors. 1-D parameters (biases
        # and LayerNorm scale/shift) get none. Selecting by shape is robust:
        # TimeSformer names its norms in lower case (e.g. ``layernorm_before``),
        # so a "LayerNorm.weight" substring match would never fire.
        decay_params, no_decay_params = [], []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if param.ndim <= 1 or name.endswith(".bias"):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optimizer_grouped_parameters = [
            group for group in (
                {"params": decay_params, "weight_decay": self.hparams.weight_decay},
                {"params": no_decay_params, "weight_decay": 0.0},
            )
            if group["params"]
        ]

        optimizer = optim.AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate)

        # Linear warm-up followed by cosine decay, advanced every optimizer step.
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.hparams.total_steps
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
            },
        }

## Sección 6: Configurar LightningDataModule

In [None]:
class VideoDataModule(pl.LightningDataModule):
    """Lightning data module serving the train/val/test ``VideoDataset`` splits.

    Training uses the module-level ``train_transform``; validation and test use
    ``valid_transform``. Datasets are built in :meth:`setup` on first need and
    reused afterwards. This matters because ``setup`` may run several times
    (explicitly, then again from ``trainer.fit``/``validate``/``test``), and
    each build re-lists the split directories.
    """

    def __init__(
        self,
        train_path,
        val_path,
        test_path,
        processor,
        num_frames=8,
        frame_sample_rate=2,
        batch_size=4,
        num_workers=2
    ):
        super().__init__()
        self.train_path = train_path
        self.val_path = val_path
        self.test_path = test_path
        self.processor = processor
        self.num_frames = num_frames
        self.frame_sample_rate = frame_sample_rate
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Class names = sub-directories of the training split, sorted so the
        # index order matches the labels VideoDataset assigns. Stray files are skipped.
        self.classes = sorted(
            entry for entry in os.listdir(train_path)
            if os.path.isdir(os.path.join(train_path, entry))
        )

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def _make_dataset(self, root_dir, transform):
        """Build one VideoDataset split with this module's processor and sampling settings."""
        return VideoDataset(
            root_dir=root_dir,
            processor=self.processor,
            num_frames=self.num_frames,
            transform=transform,
            frame_sample_rate=self.frame_sample_rate
        )

    def setup(self, stage=None):
        # 'validate' needs the validation split too (trainer.validate passes it).
        if stage in ('fit', None) and self.train_dataset is None:
            self.train_dataset = self._make_dataset(self.train_path, train_transform)

        if stage in ('fit', 'validate', None) and self.val_dataset is None:
            self.val_dataset = self._make_dataset(self.val_path, valid_transform)

        if stage in ('test', None) and self.test_dataset is None:
            self.test_dataset = self._make_dataset(self.test_path, valid_transform)

    def _make_loader(self, dataset, shuffle):
        """Wrap *dataset* in a DataLoader with the module's batching settings."""
        if dataset is None:
            raise RuntimeError("Dataset not initialised; call setup() for this stage first.")
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=collate_fn
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, shuffle=False)

## Sección 7: Transfer Learning - Primera Fase

In [None]:
# Hyper-parameters for the transfer-learning run.
experiment_name = "violence_detection_transfer_learning"
NUM_FRAMES, BATCH_SIZE = 8, 4
LEARNING_RATE, WEIGHT_DECAY = 2e-5, 0.05
MAX_EPOCHS = 20
NUM_WORKERS = 2  # adjust to the CPU cores Colab provides

# Scheduler step bookkeeping.
num_devices = 1  # Colab normally exposes a single GPU
data_module = VideoDataModule(
    train_path=TRAIN_PATH,
    val_path=VAL_PATH,
    test_path=TEST_PATH,
    processor=processor,
    num_frames=NUM_FRAMES,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)
# Build the datasets now so the training-loader length is known up front.
data_module.setup()

steps_per_epoch = len(data_module.train_dataloader())
total_steps = MAX_EPOCHS * steps_per_epoch
# Warm up for 10% of training, but never more than 100 steps.
warmup_steps = min(100, int(total_steps * 0.1))

for _label, _value in (("Steps por epoch", steps_per_epoch), ("Total steps", total_steps), ("Warmup steps", warmup_steps)):
    print(f"{_label}: {_value}")
del _label, _value

# Callbacks: keep the 3 best checkpoints (plus the last one) ranked by
# validation macro-F1, stop after 5 validation checks without val_f1
# improvement, and record the learning rate every optimizer step.
checkpoint_callback = ModelCheckpoint(
    dirpath=CHECKPOINT_PATH,
    filename="timesformer-violence-{epoch:02d}-{val_f1:.4f}",
    monitor="val_f1",
    mode="max",
    save_top_k=3,
    save_last=True,
    verbose=True,
)
early_stopping_callback = EarlyStopping(monitor="val_f1", mode="max", patience=5, verbose=True)
lr_monitor = LearningRateMonitor(logging_interval="step")

# Loggers: TensorBoard always; Weights & Biases only when the login succeeded.
tb_logger = TensorBoardLogger("lightning_logs", name=experiment_name)
loggers = [tb_logger]
if wandb_available:
    wandb_logger = WandbLogger(project="violence_detection", name=experiment_name, log_model=True)
    loggers += [wandb_logger]

# Build the transfer-learning model: most of the backbone frozen, only the
# classifier and the last encoder layer trainable. The weights come from
# MODEL_CHECKPOINT, the checkpoint the processor was actually loaded from
# (it switches to the SSv2 fallback when the K400 one fails), so the
# preprocessing and the model always match.
model = TimeSformerLightningModule(
    num_classes=len(data_module.classes),
    num_frames=NUM_FRAMES,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    warmup_steps=warmup_steps,
    total_steps=total_steps,
    freeze_backbone=True,  # freeze most of the model
    unfreeze_layers=1,     # train only the last transformer layer
    model_checkpoint=MODEL_CHECKPOINT,
)

# Show which parameters will (and will not) receive gradients.
unfrozen_params = [name for name, param in model.named_parameters() if param.requires_grad]
frozen_params = [name for name, param in model.named_parameters() if not param.requires_grad]

print(f"Parámetros descongelados: {len(unfrozen_params)}")
print(f"Algunos parámetros descongelados: {unfrozen_params[:5]}")
print(f"Parámetros congelados: {len(frozen_params)}")
print(f"Algunos parámetros congelados: {frozen_params[:5]}")

# Trainer: mixed precision on GPU to save memory, full precision on CPU.
trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor],
    logger=loggers,
    log_every_n_steps=10,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    precision="16-mixed" if torch.cuda.is_available() else "32",
)

trainer.fit(model, datamodule=data_module)

best_model_path = checkpoint_callback.best_model_path

# Validate the best checkpoint (the one that will be exported), not just the
# weights left after the final epoch. Without a checkpoint, validate as-is.
val_results = trainer.validate(
    model,
    datamodule=data_module,
    ckpt_path="best" if best_model_path else None,
)
print(f"Resultados de validación: {val_results}")

# Export in Hugging Face format (model + processor) for later reuse.
transfer_model_path = os.path.join(MODEL_SAVE_PATH, "transfer_learning_model")
os.makedirs(transfer_model_path, exist_ok=True)

if best_model_path:
    print(f"Mejor modelo guardado en: {best_model_path}")
    best_model = TimeSformerLightningModule.load_from_checkpoint(best_model_path)
    model_to_export = best_model
else:
    print("No se encontró un mejor modelo. Usando el último entrenado.")
    model_to_export = model

model_to_export.model.save_pretrained(transfer_model_path)
processor.save_pretrained(transfer_model_path)
print(f"Modelo de transfer learning guardado en: {transfer_model_path}")

## Sección 8: Evaluación Detallada del Modelo de Transfer Learning

In [None]:
# Test-set evaluation: use the best checkpoint when training produced one,
# otherwise the model as it stands after training.
print("Evaluando modelo de transfer learning en el conjunto de prueba...")
model_under_test = best_model if best_model_path else model
test_results = trainer.test(model_under_test, datamodule=data_module)

# Plotting helpers for the evaluation figures.
def plot_confusion_matrix(cm, classes, title='Matriz de Confusión', cmap=plt.cm.Blues):
    """Draw a confusion-matrix heat map with the value written in every cell.

    Args:
        cm: 2-D array-like; rows are true labels (y axis), columns predicted (x axis).
            Integer counts are printed as integers. Any other dtype (e.g. a
            row-normalized matrix) is printed with two decimals.
        classes: Tick labels, one per row/column.
        title: Figure title.
        cmap: Matplotlib colormap for the heat map.

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure as current figure.
    """
    cm = np.asarray(cm)
    plt.figure(figsize=(8, 8))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # 'd' raises ValueError on floats, so pick the format from the dtype.
    cell_format = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'
    # White text on dark cells, black on light ones.
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], cell_format),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('Etiqueta Real')
    plt.xlabel('Etiqueta Predicha')
    # Lay out after the axis labels exist so they are accounted for (not clipped).
    plt.tight_layout()
    return plt

def plot_roc_curve(fpr, tpr, roc_auc, title='Curva ROC'):
    """Plot a ROC curve against the chance diagonal.

    Args:
        fpr: False-positive rates (x coordinates).
        tpr: True-positive rates (y coordinates).
        roc_auc: Area under the curve, shown in the legend with 2 decimals.
        title: Figure title.

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure as current figure.
    """
    _, ax = plt.subplots(figsize=(8, 8))
    ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    # Diagonal = random classifier.
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Tasa de Falsos Positivos')
    ax.set_ylabel('Tasa de Verdaderos Positivos')
    ax.set_title(title)
    ax.legend(loc="lower right")
    return plt

# Detailed evaluation: predictions, metrics, plots and a JSON dump of the metrics.
def evaluate_detailed(model, data_loader, classes, tag="transfer_learning",
                      display_name="Transfer Learning"):
    """Run ``model`` over ``data_loader``, then print, plot and save evaluation metrics.

    Args:
        model: Lightning module exposing the Hugging Face classifier as ``model.model``.
        data_loader: Iterable of batches containing ``"pixel_values"`` and ``"labels"``.
        classes: Class names indexed by label id. In the binary case index 1 is treated
            as the positive (violence) class.
        tag: Prefix of the files written to ``MODEL_SAVE_PATH`` (``{tag}_metrics.json``,
            ``{tag}_confusion_matrix.png`` and, for binary problems,
            ``{tag}_roc_curve.png``). Use a distinct tag per model so that evaluations
            do not overwrite each other.
        display_name: Model name used in the binary plot titles and printed headers.

    Returns:
        Binary: ``(metrics, confusion_matrix, (fpr, tpr, roc_auc))``.
        Multiclass: ``(metrics, confusion_matrix, {class_name: roc_auc})``.
    """
    model.eval()
    model.to(device)

    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluando"):
            pixel_values = batch["pixel_values"].to(device)
            logits = model.model(pixel_values=pixel_values).logits

            all_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
            all_labels.extend(batch["labels"].cpu().numpy())
            all_probs.extend(F.softmax(logits, dim=1).cpu().numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)

    # Pin the label set to every class index. The report and confusion matrix then keep
    # their full shape even when a class is absent from this split. Without it,
    # classification_report raises on the target_names mismatch and the 2x2 unpack
    # below fails.
    label_ids = list(range(len(classes)))
    report = classification_report(all_labels, all_preds, labels=label_ids,
                                   target_names=classes, output_dict=True,
                                   zero_division=0)
    conf_mat = confusion_matrix(all_labels, all_preds, labels=label_ids)
    accuracy = float((all_preds == all_labels).mean())

    cm_path = os.path.join(MODEL_SAVE_PATH, f'{tag}_confusion_matrix.png')
    metrics_path = os.path.join(MODEL_SAVE_PATH, f'{tag}_metrics.json')

    if len(classes) == 2:
        # ROC over the probability of the positive class (index 1).
        fpr, tpr, _ = roc_curve(all_labels, all_probs[:, 1])
        roc_auc = float(auc(fpr, tpr))

        # classification_report keys per-class entries by target name.
        positive = report[classes[1]]
        precision = positive['precision']
        recall = positive['recall']
        f1 = positive['f1-score']

        tn, fp, fn, tp = conf_mat.ravel()
        negatives = tn + fp
        specificity = float(tn / negatives) if negatives else float('nan')  # TNR
        false_positive_rate = float(fp / negatives) if negatives else float('nan')

        metrics = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'specificity': specificity,
            'f1': f1,
            'tpr': recall,  # TPR = recall
            'fpr': false_positive_rate,
            'roc_auc': roc_auc
        }

        # Save each figure right after drawing it. The plot helpers return the pyplot
        # module, so savefig() writes whatever figure is *current*. Saving both at the
        # end used to write the ROC curve into the confusion-matrix file.
        cm_plot = plot_confusion_matrix(conf_mat, classes,
                                        title=f'Matriz de Confusión - {display_name}')
        cm_plot.savefig(cm_path)
        roc_plot = plot_roc_curve(fpr, tpr, roc_auc, title=f'Curva ROC - {display_name}')
        roc_plot.savefig(os.path.join(MODEL_SAVE_PATH, f'{tag}_roc_curve.png'))
        plt.close('all')

        print(f"\n=== Métricas Detalladas - {display_name} ===")
        print(f"Accuracy: {accuracy:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall (TPR): {recall:.4f}")
        print(f"Specificity (TNR): {specificity:.4f}")
        print(f"F1-Score: {f1:.4f}")
        print(f"False Positive Rate: {false_positive_rate:.4f}")
        print(f"ROC AUC: {roc_auc:.4f}")

        with open(metrics_path, 'w') as f:
            json.dump(metrics, f, indent=4)

        return metrics, conf_mat, (fpr, tpr, roc_auc)

    # Multiclass case (if more categories are added later).
    # Macro averages: unweighted mean of the per-class scores.
    precision = float(np.mean([report[cls]['precision'] for cls in classes]))
    recall = float(np.mean([report[cls]['recall'] for cls in classes]))
    f1 = float(np.mean([report[cls]['f1-score'] for cls in classes]))

    # One-vs-rest ROC AUC per class.
    roc_auc = {}
    for i, cls in enumerate(classes):
        fpr, tpr, _ = roc_curve((all_labels == i).astype(int), all_probs[:, i])
        roc_auc[cls] = float(auc(fpr, tpr))

    metrics = {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'roc_auc': roc_auc
    }

    cm_plot = plot_confusion_matrix(conf_mat, classes, title='Matriz de Confusión - Multiclase')
    cm_plot.savefig(cm_path)
    plt.close('all')

    print(f"\n=== Métricas Detalladas - {display_name} (Multiclase) ===")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision (macro): {precision:.4f}")
    print(f"Recall (macro): {recall:.4f}")
    print(f"F1-Score (macro): {f1:.4f}")
    for cls in classes:
        print(f"ROC AUC ({cls}): {roc_auc[cls]:.4f}")

    with open(metrics_path, 'w') as f:
        json.dump(metrics, f, indent=4)

    return metrics, conf_mat, roc_auc

# Detailed test-set evaluation of the transfer-learning model (best checkpoint if any).
loaded_model = best_model if best_model_path else model
_tl_test_loader = data_module.test_dataloader()
test_metrics, conf_mat, roc_data = evaluate_detailed(loaded_model, _tl_test_loader, data_module.classes)

## Sección 9: Fine-Tuning - Segunda Fase

In [None]:
# Fine-tuning phase: start from the saved transfer-learning model and unfreeze it.
print("Iniciando fase de fine-tuning...")
pretrained_model_path = transfer_model_path

# Fine-tuning hyperparameters
experiment_name = "violence_detection_fine_tuning"
NUM_FRAMES = 8  # same number of frames as the transfer-learning phase
# NOTE(review): BATCH_SIZE is not passed to data_module, so the existing loaders keep the
# batch size they were built with. Rebuild the data module with batch_size=BATCH_SIZE
# if the smaller batch is really intended.
BATCH_SIZE = 2
LEARNING_RATE = 5e-6  # much lower LR so the unfrozen backbone is only gently adjusted
WEIGHT_DECAY = 0.01
MAX_EPOCHS = 15
NUM_WORKERS = 2

# The LR schedule must span THIS run. The global warmup_steps/total_steps belong to the
# previous phase, which presumably used a different epoch count. Recompute the total from
# the real training loader (one optimizer step per batch; no gradient accumulation is
# configured below) and keep the previous warmup proportion. The 0.1 fallback is an
# assumption, used only when that proportion cannot be derived.
total_steps_ft = len(data_module.train_dataloader()) * MAX_EPOCHS
warmup_ratio_ft = warmup_steps / total_steps if total_steps else 0.1
warmup_steps_ft = int(total_steps_ft * warmup_ratio_ft)

# Fine-tuning model with the whole backbone trainable.
model_ft = TimeSformerLightningModule(
    num_classes=len(data_module.classes),
    num_frames=NUM_FRAMES,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    warmup_steps=warmup_steps_ft,
    total_steps=total_steps_ft,
    freeze_backbone=False,  # unfreeze the whole model
    model_checkpoint=pretrained_model_path
)

# Callbacks
checkpoint_callback_ft = ModelCheckpoint(
    dirpath=CHECKPOINT_PATH,
    filename="timesformer-violence-ft-{epoch:02d}-{val_f1:.4f}",
    monitor="val_f1",
    mode="max",
    save_top_k=3,
    save_last=True,
    verbose=True
)
# CHECKPOINT_PATH is shared with the other phases. Give this phase's "last" checkpoint
# its own name so it cannot clash with another phase's default "last.ckpt".
checkpoint_callback_ft.CHECKPOINT_NAME_LAST = "timesformer-violence-ft-last"

early_stopping_callback_ft = EarlyStopping(
    monitor="val_f1",
    patience=7,
    mode="max",
    verbose=True
)

lr_monitor_ft = LearningRateMonitor(logging_interval="step")

# Loggers
tb_logger_ft = TensorBoardLogger("lightning_logs", name=experiment_name)

loggers_ft = [tb_logger_ft]
if wandb_available:
    wandb_logger_ft = WandbLogger(
        project="violence_detection",
        name=experiment_name,
        log_model=True
    )
    loggers_ft.append(wandb_logger_ft)

trainer_ft = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    callbacks=[checkpoint_callback_ft, early_stopping_callback_ft, lr_monitor_ft],
    logger=loggers_ft,
    log_every_n_steps=10,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    precision="16-mixed" if torch.cuda.is_available() else "32",
)

trainer_ft.fit(model_ft, datamodule=data_module)

# Validation metrics of the model as it stands after training (last epoch).
val_results_ft = trainer_ft.validate(model_ft, datamodule=data_module)
print(f"Resultados de validación con fine-tuning: {val_results_ft}")

# Export the best fine-tuning checkpoint, or the last model if none was tracked.
best_model_path_ft = checkpoint_callback_ft.best_model_path
if best_model_path_ft:
    print(f"Mejor modelo de fine-tuning guardado en: {best_model_path_ft}")
    best_model_ft = TimeSformerLightningModule.load_from_checkpoint(best_model_path_ft)
    model_to_export_ft = best_model_ft
else:
    print("No se encontró un mejor modelo de fine-tuning. Usando el último entrenado.")
    model_to_export_ft = model_ft

fine_tuning_model_path = os.path.join(MODEL_SAVE_PATH, "fine_tuning_model")
os.makedirs(fine_tuning_model_path, exist_ok=True)
model_to_export_ft.model.save_pretrained(fine_tuning_model_path)
processor.save_pretrained(fine_tuning_model_path)
print(f"Modelo de fine-tuning guardado en: {fine_tuning_model_path}")

## Sección 10: Evaluación Detallada del Modelo Fine-Tuned

In [None]:
# Evaluate the fine-tuned model on the test split and compare it with transfer learning.
print("Evaluando modelo de fine-tuning en el conjunto de prueba...")
loaded_model_ft = best_model_ft if best_model_path_ft else model_ft
test_results_ft = trainer_ft.test(loaded_model_ft, datamodule=data_module)

# NOTE: this call does not pass an output tag, so evaluate_detailed() writes its
# metrics/plots under the transfer_learning_* file names. That replaces the files from the
# transfer-learning evaluation; model_comparison.csv below keeps both sets of numbers.
test_metrics_ft, conf_mat_ft, roc_data_ft = evaluate_detailed(
    loaded_model_ft,
    data_module.test_dataloader(),
    data_module.classes
)

# (row label, metrics key). Only scalar metrics available for BOTH models are compared.
# In the multiclass case 'specificity' does not exist and 'roc_auc' is a per-class dict,
# which previously raised KeyError / TypeError here.
_COMPARED_METRICS = [
    ("Accuracy", "accuracy"),
    ("Precision", "precision"),
    ("Recall (TPR)", "recall"),
    ("Specificity", "specificity"),
    ("F1-Score", "f1"),
    ("ROC AUC", "roc_auc"),
]
_SCALAR_TYPES = (int, float, np.integer, np.floating)
compared_metrics = [
    (label, key) for label, key in _COMPARED_METRICS
    if isinstance(test_metrics.get(key), _SCALAR_TYPES)
    and isinstance(test_metrics_ft.get(key), _SCALAR_TYPES)
]

print("\n=== Comparación de Modelos ===")
metrics_comparison = {
    "Métrica": [label for label, _ in compared_metrics],
    "Transfer Learning": [f"{test_metrics[key]:.4f}" for _, key in compared_metrics],
    "Fine-Tuning": [f"{test_metrics_ft[key]:.4f}" for _, key in compared_metrics],
}

comparison_df = pd.DataFrame(metrics_comparison)
print(comparison_df.to_string(index=False))
comparison_df.to_csv(os.path.join(MODEL_SAVE_PATH, 'model_comparison.csv'), index=False)

# Grouped bar chart, drawn from the raw values rather than re-parsing the formatted strings.
tl_values = [float(test_metrics[key]) for _, key in compared_metrics]
ft_values = [float(test_metrics_ft[key]) for _, key in compared_metrics]

plt.figure(figsize=(12, 8))
bar_width = 0.35
x = np.arange(len(compared_metrics))
plt.bar(x - bar_width/2, tl_values, bar_width, label='Transfer Learning')
plt.bar(x + bar_width/2, ft_values, bar_width, label='Fine-Tuning')

plt.xlabel('Métricas')
plt.ylabel('Valores')
plt.title('Comparación de Modelos: Transfer Learning vs Fine-Tuning')
plt.xticks(x, metrics_comparison["Métrica"])
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(MODEL_SAVE_PATH, 'models_comparison.png'))
plt.close()

# Overlay both ROC curves (evaluate_detailed returns (fpr, tpr, auc) only when binary).
if len(data_module.classes) == 2:
    plt.figure(figsize=(10, 8))

    fpr_tl, tpr_tl, roc_auc_tl = roc_data
    fpr_ft, tpr_ft, roc_auc_ft = roc_data_ft

    plt.plot(fpr_tl, tpr_tl, color='blue', lw=2,
             label=f'Transfer Learning (AUC = {roc_auc_tl:.2f})')
    plt.plot(fpr_ft, tpr_ft, color='red', lw=2,
             label=f'Fine-Tuning (AUC = {roc_auc_ft:.2f})')

    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Tasa de Falsos Positivos')
    plt.ylabel('Tasa de Verdaderos Positivos')
    plt.title('Comparación de Curvas ROC')
    plt.legend(loc="lower right")
    plt.savefig(os.path.join(MODEL_SAVE_PATH, 'roc_curves_comparison.png'))
    plt.close()

## Sección 11: Exportación del Modelo Final para Despliegue

In [None]:
# Pick the final model by test F1 (ties favour the fine-tuned model) and export it.
use_fine_tuned = test_metrics_ft['f1'] >= test_metrics['f1']
if use_fine_tuned:
    export_model_path = fine_tuning_model_path
    best_final_model = loaded_model_ft
    final_metrics = test_metrics_ft
    print("El modelo de fine-tuning tiene mejor rendimiento y será exportado como modelo final.")
else:
    export_model_path = transfer_model_path
    best_final_model = loaded_model
    final_metrics = test_metrics
    print("El modelo de transfer learning tiene mejor rendimiento y será exportado como modelo final.")

final_export_path = os.path.join(EXPORT_PATH, "final_model")
os.makedirs(final_export_path, exist_ok=True)

# Embed the deployment essentials in the Hugging Face config itself. Extra attributes
# set on a PretrainedConfig are serialized into config.json, so a plain config.json
# reader still finds the class names and sampling rate.
hf_config = best_final_model.model.config
hf_config.classes = list(data_module.classes)
hf_config.frame_sample_rate = data_module.frame_sample_rate

# Hugging Face format: writes config.json + weights, and the processor config.
best_final_model.model.save_pretrained(final_export_path)
processor.save_pretrained(final_export_path)

# Additional deployment metadata
config = {
    "model_type": "TimesformerForVideoClassification",
    "num_frames": NUM_FRAMES,
    "frame_sample_rate": data_module.frame_sample_rate,
    "classes": data_module.classes,
    "input_size": [224, 224],
    "created_date": time.strftime("%Y-%m-%d %H:%M:%S"),
    "metrics": final_metrics
}

# Use a separate file name. Writing this dict to "config.json" overwrote the model
# config that save_pretrained() had just written, so from_pretrained() on this folder
# lost the trained model configuration.
with open(os.path.join(final_export_path, "deployment_config.json"), "w") as f:
    json.dump(config, f, indent=4)

print(f"Modelo final exportado a: {final_export_path}")

# Optional ONNX export for fast inference. This is best-effort: any failure is reported
# and the notebook continues.
try:
    import onnx

    onnx_path = os.path.join(EXPORT_PATH, "onnx_model")
    os.makedirs(onnx_path, exist_ok=True)
    onnx_model_path = os.path.join(onnx_path, "model.onnx")

    hf_model = best_final_model.model
    # Trace in inference mode so dropout is not baked into the exported graph.
    hf_model.eval()
    # Create the dummy input on the same device as the weights.
    export_device = next(hf_model.parameters()).device
    dummy_pixel_values = torch.randn(1, NUM_FRAMES, 3, 224, 224, device=export_device)

    torch.onnx.export(
        hf_model,
        (dummy_pixel_values,),
        onnx_model_path,
        export_params=True,
        opset_version=12,
        input_names=["pixel_values"],
        output_names=["logits"],
        dynamic_axes={
            "pixel_values": {0: "batch_size"},
            "logits": {0: "batch_size"}
        }
    )

    onnx_model = onnx.load(onnx_model_path)
    onnx.checker.check_model(onnx_model)

    # Metadata next to the ONNX model (a copy, so `config` itself is not modified).
    with open(os.path.join(onnx_path, "config.json"), "w") as f:
        json.dump({**config, "onnx": True}, f, indent=4)

    processor.save_pretrained(onnx_path)

    print(f"Modelo ONNX exportado a: {onnx_path}")
except Exception as e:
    print(f"No se pudo exportar a ONNX: {e}")
    print("Continuando sin exportación ONNX")

## Sección 12: Prueba del Modelo con Inferencia en Tiempo Real

In [None]:
# Video-file inference helpers.
def _estimate_total_frames(stream, fallback=150):
    """Return the frame count of a PyAV video stream, estimating it when unknown."""
    total = stream.frames or 0
    if total <= 0 and stream.average_rate and stream.duration and stream.time_base:
        # Stream.duration is expressed in stream.time_base units, not microseconds.
        seconds = float(stream.duration * stream.time_base)
        total = int(seconds * float(stream.average_rate))
    # Last resort: assume a 5 s clip at 30 fps.
    return total if total > 0 else fallback


def _sample_video_frames(video_path, num_frames):
    """Decode ``num_frames`` RGB frames spread uniformly over a video (padded if short)."""
    frames = []
    with av.open(video_path) as container:  # always release the file handle
        total_frames = _estimate_total_frames(container.streams.video[0])
        indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        wanted = set(indices.tolist())
        last_wanted = int(indices[-1])

        container.seek(0)
        for i, frame in enumerate(container.decode(video=0)):
            if i in wanted:
                # "rgb24" is the FFmpeg name of packed 8-bit RGB; "rgb" is not a pixel format.
                frames.append(frame.to_ndarray(format="rgb24"))
            # Stop as soon as nothing else is needed instead of decoding the whole file.
            if len(frames) == num_frames or i >= last_wanted:
                break

    # Short videos (or an over-estimated frame count): repeat the last frame.
    if len(frames) < num_frames:
        last_frame = frames[-1] if frames else np.zeros((224, 224, 3), dtype=np.uint8)
        frames.extend([last_frame] * (num_frames - len(frames)))
    return frames


def predict_violence(video_path, model, processor, num_frames=8, frame_sample_rate=2, classes=None):
    """Classify a single video file with a Lightning-wrapped TimeSformer.

    Args:
        video_path: Path of a video readable by PyAV.
        model: Lightning module exposing the Hugging Face model as ``model.model``.
        processor: Image processor applied to the sampled frames.
        num_frames: Number of frames sampled uniformly across the whole video.
        frame_sample_rate: Currently unused; kept for API compatibility.
        classes: Class names indexed by label id; defaults to ``data_module.classes``.

    Returns:
        dict with ``predicted_class`` (int), ``class_name``, ``confidence`` (probability
        of the predicted class) and ``probabilities`` (class name -> probability).
    """
    if classes is None:
        classes = data_module.classes

    frames = _sample_video_frames(video_path, num_frames)

    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    frames_processed = [transform(frame) for frame in frames]

    # NOTE(review): ToTensor() already scales pixels to [0, 1]. If the processor also
    # rescales by 1/255 (presumably its default), inputs end up close to 0. Keep this
    # identical to the training pipeline before changing it.
    inputs = processor(frames_processed, return_tensors="pt")
    pixel_values = inputs.pixel_values.to(device)

    model.eval()
    with torch.no_grad():
        logits = model.model(pixel_values=pixel_values).logits
        probs = F.softmax(logits, dim=1)[0]
        predicted_class = torch.argmax(logits, dim=1).item()

    return {
        "predicted_class": predicted_class,
        "class_name": classes[predicted_class],
        "confidence": probs[predicted_class].item(),
        "probabilities": {classes[i]: prob.item() for i, prob in enumerate(probs)}
    }

# Quick end-to-end check: run inference on up to three test videos per class.
test_videos = []
for class_name in data_module.classes:
    class_dir = os.path.join(TEST_PATH, class_name)
    # os.listdir() order is arbitrary, so sort for a reproducible selection. Skip
    # sub-directories and hidden files (e.g. .DS_Store) that cannot be decoded as video.
    video_names = sorted(
        name for name in os.listdir(class_dir)
        if not name.startswith(".") and os.path.isfile(os.path.join(class_dir, name))
    )
    test_videos.extend((os.path.join(class_dir, video), class_name) for video in video_names[:3])

print("\n=== Prueba de Inferencia con Videos ===")
print(f"Probando {len(test_videos)} videos...")

results = []
for video_path, true_class in test_videos:
    prediction = predict_violence(video_path, best_final_model, processor, NUM_FRAMES)

    print(f"\nVideo: {os.path.basename(video_path)}")
    print(f"Clase real: {true_class}")
    print(f"Predicción: {prediction['class_name']} (confianza: {prediction['confidence']:.4f})")
    print(f"Probabilidades: {prediction['probabilities']}")

    results.append({
        "video": os.path.basename(video_path),
        "true_class": true_class,
        "predicted_class": prediction['class_name'],
        "confidence": prediction['confidence'],
        "correct": true_class == prediction['class_name']
    })

# Accuracy on this small sample; NaN instead of ZeroDivisionError if nothing was found.
accuracy = sum(r["correct"] for r in results) / len(results) if results else float("nan")
print(f"\nPrecisión en la prueba de inferencia: {accuracy:.4f}")

with open(os.path.join(MODEL_SAVE_PATH, "inference_test_results.json"), "w") as f:
    json.dump(results, f, indent=4)

## Sección 13: Visualización de la Activación del Modelo (Interpretabilidad)

In [None]:
# Basic interpretability: show the frames the model sees, together with its prediction.
def visualize_attention(video_path, model, processor, num_frames=8, class_idx=None):
    """Save a grid of a video's sampled frames, titled with the model prediction.

    Despite its name, this does not compute attention maps or Grad-CAM yet. It shows
    the uniformly sampled frames and the predicted class/confidence.

    Args:
        video_path: Path of a video readable by PyAV.
        model: Lightning module passed through to ``predict_violence``.
        processor: Image processor passed through to ``predict_violence``.
        num_frames: Number of frames to sample and display.
        class_idx: Reserved for a future class-specific attention overlay; unused.

    Returns:
        ``(frames, prediction)``: the RGB frames shown and the prediction dict.
    """
    prediction = predict_violence(video_path, model, processor, num_frames)

    frames = []
    with av.open(video_path) as container:  # release the file handle
        stream = container.streams.video[0]
        # stream.frames is 0 when the container does not store a frame count. Estimate it
        # from duration (in time_base units) * fps, else assume 5 s at 30 fps.
        total_frames = stream.frames or 0
        if total_frames <= 0 and stream.average_rate and stream.duration and stream.time_base:
            total_frames = int(float(stream.duration * stream.time_base) * float(stream.average_rate))
        if total_frames <= 0:
            total_frames = 150

        indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        wanted = set(indices.tolist())
        last_wanted = int(indices[-1])

        container.seek(0)
        for i, frame in enumerate(container.decode(video=0)):
            if i in wanted:
                frames.append(frame.to_ndarray(format="rgb24"))
            if len(frames) == num_frames or i >= last_wanted:
                break

    # Four frames per row. The grid grows with num_frames; a fixed 2x4 grid raised
    # ValueError for more than 8 frames. 8 frames still give the original 20x10 figure.
    n_cols = 4
    n_rows = max(1, (len(frames) + n_cols - 1) // n_cols)
    plt.figure(figsize=(5 * n_cols, 5 * n_rows))
    for i, frame in enumerate(frames):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(frame)
        plt.title(f"Frame {i}")
        plt.axis('off')

    plt.suptitle(f"Predicción: {prediction['class_name']} (Confianza: {prediction['confidence']:.4f})",
                 fontsize=16)
    plt.tight_layout()

    viz_dir = os.path.join(MODEL_SAVE_PATH, "visualizations")
    os.makedirs(viz_dir, exist_ok=True)
    plt.savefig(os.path.join(viz_dir, f"frames_{os.path.basename(video_path)}.png"))
    plt.close()

    return frames, prediction

# Render frame grids for just the first two test videos to keep the notebook fast.
NUM_VISUALIZATION_EXAMPLES = 2
for video_path, true_class in test_videos[:NUM_VISUALIZATION_EXAMPLES]:
    visualize_attention(video_path, best_final_model, processor, num_frames=NUM_FRAMES)
    saved_name = os.path.basename(video_path)
    print(f"Visualización guardada para {saved_name}")

## Sección 14: Resumen y Conclusiones

In [None]:
# Build and save a JSON summary of the experiment, then print conclusions.
final_is_fine_tuned = test_metrics_ft['f1'] >= test_metrics['f1']  # ties favour fine-tuning
final_metrics_summary = test_metrics_ft if final_is_fine_tuned else test_metrics

# LEARNING_RATE / BATCH_SIZE / MAX_EPOCHS were reassigned in the fine-tuning cell, so they
# no longer describe the transfer-learning run. Read each phase's settings from its own
# objects instead.
# NOTE(review): the "learning_rate" key assumes the Lightning module stores its init args
# via save_hyperparameters(). None is recorded for transfer learning when it does not.
tl_learning_rate = getattr(loaded_model, "hparams", {}).get("learning_rate")
ft_learning_rate = getattr(loaded_model_ft, "hparams", {}).get("learning_rate", LEARNING_RATE)
# Both phases trained on the same data_module loaders.
train_batch_size = data_module.train_dataloader().batch_size

summary = {
    "experiment_date": time.strftime("%Y-%m-%d %H:%M:%S"),
    "dataset": {
        "train_samples": train_samples,
        "val_samples": val_samples,
        "test_samples": test_samples,
        "classes": data_module.classes
    },
    "transfer_learning": {
        "base_model": MODEL_CHECKPOINT,
        "learning_rate": tl_learning_rate,
        "batch_size": train_batch_size,
        "num_frames": NUM_FRAMES,
        "epochs": trainer.max_epochs,
        "metrics": test_metrics
    },
    "fine_tuning": {
        "base_model": "transfer_learning_model",
        "learning_rate": ft_learning_rate,
        "batch_size": train_batch_size,
        "num_frames": NUM_FRAMES,
        "epochs": trainer_ft.max_epochs,
        "metrics": test_metrics_ft
    },
    "final_model": {
        "path": final_export_path,
        "type": "fine_tuning" if final_is_fine_tuned else "transfer_learning",
        "metrics": final_metrics_summary
    }
}

with open(os.path.join(MODEL_SAVE_PATH, "experiment_summary.json"), "w") as f:
    json.dump(summary, f, indent=4)

# Summary printout (binary metrics such as roc_auc/fpr are assumed to be scalars here).
print("\n===== RESUMEN DEL EXPERIMENTO =====")
print(f"Fecha: {summary['experiment_date']}")
print("\nDataset:")
print(f"- Muestras entrenamiento: {summary['dataset']['train_samples']}")
print(f"- Muestras validación: {summary['dataset']['val_samples']}")
print(f"- Muestras test: {summary['dataset']['test_samples']}")
print(f"- Clases: {summary['dataset']['classes']}")

print("\nTransfer Learning:")
print(f"- Modelo base: {summary['transfer_learning']['base_model']}")
print(f"- Frames por clip: {summary['transfer_learning']['num_frames']}")
print("- Métricas principales:")
print(f"  - Accuracy: {summary['transfer_learning']['metrics']['accuracy']:.4f}")
print(f"  - F1-Score: {summary['transfer_learning']['metrics']['f1']:.4f}")
print(f"  - ROC AUC: {summary['transfer_learning']['metrics']['roc_auc']:.4f}")

print("\nFine-Tuning:")
print(f"- Modelo base: {summary['fine_tuning']['base_model']}")
print("- Métricas principales:")
print(f"  - Accuracy: {summary['fine_tuning']['metrics']['accuracy']:.4f}")
print(f"  - F1-Score: {summary['fine_tuning']['metrics']['f1']:.4f}")
print(f"  - ROC AUC: {summary['fine_tuning']['metrics']['roc_auc']:.4f}")

print("\nModelo Final:")
print(f"- Tipo: {summary['final_model']['type']}")
print(f"- Ruta: {summary['final_model']['path']}")
print(f"- Accuracy: {summary['final_model']['metrics']['accuracy']:.4f}")
print(f"- F1-Score: {summary['final_model']['metrics']['f1']:.4f}")
print(f"- ROC AUC: {summary['final_model']['metrics']['roc_auc']:.4f}")

print("\n===== CONCLUSIONES =====")
tl_f1 = test_metrics['f1']
ft_f1 = test_metrics_ft['f1']
if ft_f1 > tl_f1:
    # Relative gain is undefined for a zero baseline; report it as infinite.
    improvement = (ft_f1 - tl_f1) / tl_f1 * 100 if tl_f1 else float('inf')
    print(f"El fine-tuning mejoró el rendimiento del modelo en un {improvement:.2f}% en términos de F1-Score.")
    best_approach = "fine-tuning"
elif ft_f1 < tl_f1:
    # Here tl_f1 > ft_f1 >= 0, so the division is safe.
    decline = (tl_f1 - ft_f1) / tl_f1 * 100
    print(f"El fine-tuning no mejoró el rendimiento y disminuyó un {decline:.2f}% en términos de F1-Score.")
    print("Esto sugiere posible sobreajuste durante el fine-tuning.")
    best_approach = "transfer learning"
else:
    print("No hubo diferencia significativa entre transfer learning y fine-tuning.")
    best_approach = "ambos enfoques"

# Metrics that matter most for school-violence alerts.
if best_approach == "fine-tuning":
    recall = test_metrics_ft['recall']
    fpr = test_metrics_ft['fpr']
elif best_approach == "transfer learning":
    recall = test_metrics['recall']
    fpr = test_metrics['fpr']
else:
    recall = max(test_metrics['recall'], test_metrics_ft['recall'])
    fpr = min(test_metrics['fpr'], test_metrics_ft['fpr'])

print("\nPara un sistema de detección de violencia escolar, es crucial:")
print(f"- Alta sensibilidad/recall: {recall:.4f} (porcentaje de casos de violencia detectados)")
print(f"- Baja tasa de falsos positivos: {fpr:.4f} (porcentaje de falsos alertas)")

print("\nRecomendaciones para despliegue:")
print("1. Implementar el modelo exportado con sistema de umbral ajustable para balancear")
print("   sensibilidad vs. falsos positivos según las necesidades específicas del entorno escolar.")
print("2. Considerar procesamiento de video en ventanas deslizantes con solapamiento")
print("   para garantizar detección continua en transmisiones de video en tiempo real.")
print("3. Implementar mecanismo de retroalimentación que permita ajustar el modelo")
print("   con casos específicos del entorno particular donde se despliega.")
print("4. Complementar la detección visual con análisis de audio para mejorar precisión.")

## Sección 15: Código para Integración del Modelo en Sistema de Tiempo Real

In [None]:
# Inference wrapper meant to be reused by a real-time application.
class ViolenceDetector:
    """Violence classifier around an exported TimeSformer model.

    Loads a folder written by ``save_pretrained`` (plus optional deployment metadata)
    and classifies lists of RGB frames, video files (PyAV) or OpenCV captures.
    """

    # Metadata files looked up in the model folder, in priority order. A dedicated
    # deployment file wins over the Hugging Face config.json.
    _CONFIG_FILENAMES = ("deployment_config.json", "config.json")

    def __init__(self, model_path, device="cuda" if torch.cuda.is_available() else "cpu",
                 threshold=0.6):
        """
        Load the metadata, model and processor.

        Args:
            model_path: Folder of the exported model.
            device: Inference device ('cuda' or 'cpu').
            threshold: Minimum probability of the violence class required to set
                ``is_violence`` (default 0.6).
        """
        self.device = device
        self.config = None

        for filename in self._CONFIG_FILENAMES:
            config_path = os.path.join(model_path, filename)
            if os.path.exists(config_path):
                with open(config_path, "r") as f:
                    self.config = json.load(f)
                break

        settings = self.config or {}
        self.num_frames = settings.get("num_frames", 8)
        self.frame_sample_rate = settings.get("frame_sample_rate", 2)
        self.classes = settings.get("classes", ["no_violence", "violence"])
        self.threshold = threshold

        from transformers import TimesformerForVideoClassification, AutoImageProcessor

        print(f"Cargando modelo desde {model_path}...")
        self.model = TimesformerForVideoClassification.from_pretrained(model_path)
        self.model.to(device)
        self.model.eval()

        self.processor = AutoImageProcessor.from_pretrained(model_path)

        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])

        print("Detector de violencia inicializado correctamente.")

    def preprocess_frames(self, frames):
        """
        Turn a list of RGB frames into the model input tensor.

        Args:
            frames: Sequence of RGB numpy arrays.

        Returns:
            ``pixel_values`` tensor on ``self.device``, built from exactly
            ``self.num_frames`` frames.
        """
        frames = list(frames)
        if len(frames) < self.num_frames:
            # Too few frames: repeat the last one (black frame if there are none).
            last_frame = frames[-1] if frames else np.zeros((224, 224, 3), dtype=np.uint8)
            frames = frames + [last_frame] * (self.num_frames - len(frames))
        elif len(frames) > self.num_frames:
            # Too many frames: sample them uniformly over the window.
            indices = np.linspace(0, len(frames) - 1, self.num_frames, dtype=int)
            frames = [frames[i] for i in indices]

        frames_processed = [self.transform(frame) for frame in frames]
        inputs = self.processor(frames_processed, return_tensors="pt")
        return inputs.pixel_values.to(self.device)

    def predict(self, frames):
        """
        Classify a window of frames.

        Args:
            frames: Sequence of RGB numpy arrays.

        Returns:
            dict with ``is_violence``, ``class_name``, ``confidence`` (probability of
            the predicted class), ``violence_confidence`` and ``probabilities``.
        """
        pixel_values = self.preprocess_frames(frames)

        with torch.no_grad():
            logits = self.model(pixel_values=pixel_values).logits
            probs = F.softmax(logits, dim=1)[0]
            predicted_class = torch.argmax(logits, dim=1).item()

        # Index of the 'violence' class; 1 is assumed when the name is not present.
        violence_idx = self.classes.index("violence") if "violence" in self.classes else 1
        violence_prob = probs[violence_idx].item()

        # The threshold alone decides the alert. Probabilities sum to 1, so above 0.5 this
        # already implies violence is the argmax (same result as before for the default
        # 0.6). Thresholds below 0.5 can now trade false positives for higher recall.
        is_violence = violence_prob >= self.threshold

        return {
            "is_violence": is_violence,
            "class_name": self.classes[predicted_class],
            "confidence": probs[predicted_class].item(),
            "violence_confidence": violence_prob,
            "probabilities": {self.classes[i]: prob.item() for i, prob in enumerate(probs)}
        }

    @staticmethod
    def _estimate_total_frames(stream, fallback=150):
        """Return the frame count of a PyAV video stream, estimating it when unknown."""
        total = stream.frames or 0
        if total <= 0 and stream.average_rate and stream.duration and stream.time_base:
            # Stream.duration is in stream.time_base units, not microseconds.
            seconds = float(stream.duration * stream.time_base)
            total = int(seconds * float(stream.average_rate))
        # Last resort: assume a 5 s clip at 30 fps.
        return total if total > 0 else fallback

    def _read_sampled_frames(self, video_path):
        """Decode ``self.num_frames`` RGB frames spread uniformly over a video file."""
        frames = []
        with av.open(video_path) as container:  # always release the file handle
            total_frames = self._estimate_total_frames(container.streams.video[0])
            indices = np.linspace(0, total_frames - 1, self.num_frames, dtype=int)
            wanted = set(indices.tolist())
            last_wanted = int(indices[-1])

            container.seek(0)
            for i, frame in enumerate(container.decode(video=0)):
                if i in wanted:
                    # "rgb24" is FFmpeg's packed 8-bit RGB; "rgb" is not a pixel format.
                    frames.append(frame.to_ndarray(format="rgb24"))
                if len(frames) == self.num_frames or i >= last_wanted:
                    break
        return frames

    def process_video_file(self, video_path):
        """
        Classify a whole video file from frames sampled uniformly across it.

        Args:
            video_path: Path of the video file.

        Returns:
            dict: Same structure as :meth:`predict`.
        """
        return self.predict(self._read_sampled_frames(video_path))

    def process_video_stream(self, video_capture, callback=None, window_size=30, stride=15):
        """
        Classify a live OpenCV stream with a sliding window.

        Args:
            video_capture: OpenCV capture object (it is released on exit).
            callback: Called as ``callback(result, frame_rgb)`` when violence is detected.
            window_size: Window length in frames.
            stride: Number of frames between two consecutive predictions.
        """
        from collections import deque

        # A bounded deque drops the oldest frame in O(1); list.pop(0) was O(n).
        frame_buffer = deque(maxlen=window_size)
        frame_count = 0
        last_result = None

        print("Iniciando procesamiento de stream de video...")
        print("Presiona 'q' para salir")

        try:
            while video_capture.isOpened():
                ret, frame = video_capture.read()
                if not ret:
                    break

                # OpenCV delivers BGR; the model expects RGB.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame_buffer.append(frame_rgb)

                run_inference = len(frame_buffer) == window_size and frame_count % stride == 0
                if run_inference:
                    last_result = self.predict(list(frame_buffer))

                # Draw the latest prediction on every frame. Drawing only on inference
                # frames made the label flicker.
                if last_result is not None:
                    label = f"{last_result['class_name']}: {last_result['confidence']:.2f}"
                    color = (0, 0, 255) if last_result['is_violence'] else (0, 255, 0)
                    cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)

                # Fire the callback once per new positive prediction.
                if run_inference and last_result['is_violence'] and callback:
                    callback(last_result, frame_rgb)

                cv2.imshow('Video', frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

                frame_count += 1

        finally:
            video_capture.release()
            cv2.destroyAllWindows()

# Print a copy-paste snippet showing how ViolenceDetector is meant to be used by an
# application (recorded files and live camera streams).
INTEGRATION_EXAMPLE = """
# Ejemplo de uso:
detector = ViolenceDetector("ruta/al/modelo/exportado")

# Para procesar un video grabado:
result = detector.process_video_file("ruta/al/video.mp4")
print(f"Violencia detectada: {result['is_violence']}")
print(f"Confianza: {result['confidence']:.4f}")

# Para procesar un stream en tiempo real:
def alert_callback(result, frame):
    print(f"¡ALERTA! Violencia detectada con confianza {result['confidence']:.4f}")
    # Aquí se pueden agregar acciones como:
    # - Enviar notificación
    # - Guardar el clip de video
    # - Activar alarma

# Conectar a cámara (0 para webcam, o URL para cámara IP)
cap = cv2.VideoCapture(0)
detector.process_video_stream(cap, callback=alert_callback)
"""

for _heading in (
    "\n===== CÓDIGO DE INTEGRACIÓN PARA SISTEMA REAL =====",
    "Este código puede ser integrado en una aplicación para detección en tiempo real:",
):
    print(_heading)
print(INTEGRATION_EXAMPLE)

## Sección 16: Evaluación con Número de Frames Alternativo

In [None]:
# Train a second TimeSformer that sees a different number of frames per clip,
# evaluate it on the test split and compare it with the main fine-tuned model
# (`test_metrics_ft`).
#
# `f1_score` is imported locally: the notebook's top-level sklearn import only
# brings in roc_curve/auc/classification_report/confusion_matrix.
from sklearn.metrics import f1_score

ALTERNATIVE_NUM_FRAMES = 16  # more frames per clip than the main run

print("\n===== EVALUANDO CON NÚMERO ALTERNATIVO DE FRAMES =====")
print(f"Configurando entrenamiento con {ALTERNATIVE_NUM_FRAMES} frames...")

# Datasets/dataloaders that sample ALTERNATIVE_NUM_FRAMES frames per video.
alt_data_module = VideoDataModule(
    train_path=TRAIN_PATH,
    val_path=VAL_PATH,
    test_path=TEST_PATH,
    processor=processor,
    num_frames=ALTERNATIVE_NUM_FRAMES,
    batch_size=2,  # smaller batch: more frames per clip means more memory
    num_workers=NUM_WORKERS,
)
alt_data_module.setup()

# Shorter run: this is a proof of concept, not a full training.
alt_experiment_name = f"violence_detection_frames_{ALTERNATIVE_NUM_FRAMES}"
ALT_MAX_EPOCHS = 5

# The LR schedule must be sized for THIS run. The main run's `total_steps` was
# computed for a different batch size / epoch count, so reusing it would make
# the warmup + cosine schedule end too early or too late. Keep the main run's
# warmup ratio. (No gradient accumulation is configured on `alt_trainer`, so
# optimizer steps = batches per epoch * epochs.)
alt_total_steps = len(alt_data_module.train_dataloader()) * ALT_MAX_EPOCHS
alt_warmup_steps = int(alt_total_steps * warmup_steps / max(total_steps, 1))

# Model configured for the alternative number of frames.
alt_model = TimeSformerLightningModule(
    num_classes=len(alt_data_module.classes),
    num_frames=ALTERNATIVE_NUM_FRAMES,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    warmup_steps=alt_warmup_steps,
    total_steps=alt_total_steps,
    freeze_backbone=True,
    unfreeze_layers=1,
)

# Keep only the best checkpoint by validation F1 (plus the last one).
alt_checkpoint_callback = ModelCheckpoint(
    dirpath=CHECKPOINT_PATH,
    filename=f"timesformer-violence-frames{ALTERNATIVE_NUM_FRAMES}-"+"{epoch:02d}-{val_f1:.4f}",
    monitor="val_f1",
    mode="max",
    save_top_k=1,
    save_last=True,
    verbose=True,
)

alt_early_stopping = EarlyStopping(
    monitor="val_f1",
    patience=3,
    mode="max",
    verbose=True,
)

alt_tb_logger = TensorBoardLogger("lightning_logs", name=alt_experiment_name)

alt_trainer = pl.Trainer(
    max_epochs=ALT_MAX_EPOCHS,
    callbacks=[alt_checkpoint_callback, alt_early_stopping, lr_monitor],
    logger=alt_tb_logger,
    log_every_n_steps=10,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    precision="16-mixed" if torch.cuda.is_available() else "32",
)

print(f"Iniciando entrenamiento con {ALTERNATIVE_NUM_FRAMES} frames...")
alt_trainer.fit(alt_model, datamodule=alt_data_module)

# Test the best checkpoint selected by `alt_checkpoint_callback`, not the
# last-epoch weights. Passing the model together with ckpt_path="best" also
# loads those best weights into `alt_model`, so the manual pass below uses them.
alt_test_results = alt_trainer.test(alt_model, datamodule=alt_data_module, ckpt_path="best")

print("\n=== Comparación por Número de Frames ===")
print(f"Frames: {NUM_FRAMES} vs {ALTERNATIVE_NUM_FRAMES}")

# Quick manual evaluation over the test set. After fit/test, Lightning may hand
# the module back on CPU and in train mode, so explicitly move it to `device`
# (where the batches are sent) and switch to eval mode to disable dropout.
alt_loaded_model = alt_model.to(device)
alt_loaded_model.eval()

all_preds = []
all_labels = []
with torch.no_grad():
    for batch in alt_data_module.test_dataloader():
        pixel_values = batch["pixel_values"].to(device)
        logits = alt_loaded_model.model(pixel_values=pixel_values).logits
        all_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
        all_labels.extend(batch["labels"].cpu().numpy())

all_preds = np.array(all_preds)
all_labels = np.array(all_labels)

alt_accuracy = (all_preds == all_labels).mean()
# NOTE(review): macro-averaged F1; confirm test_metrics_ft['f1'] was computed
# with the same averaging, otherwise the comparison below is not like-for-like.
alt_f1 = f1_score(all_labels, all_preds, average='macro')

# Side-by-side metrics table (main run vs alternative number of frames).
frames_comparison = {
    "Métrica": ["Accuracy", "F1-Score"],
    f"{NUM_FRAMES} frames": [
        f"{test_metrics_ft['accuracy']:.4f}",
        f"{test_metrics_ft['f1']:.4f}"
    ],
    f"{ALTERNATIVE_NUM_FRAMES} frames": [
        f"{alt_accuracy:.4f}",
        f"{alt_f1:.4f}"
    ]
}

print("\nComparación de rendimiento por número de frames:")
frames_comparison_df = pd.DataFrame(frames_comparison)
print(frames_comparison_df.to_string(index=False))

# Conclusion. The alternative run is much shorter (ALT_MAX_EPOCHS), so this is
# an indicative comparison only.
print("\nConclusión sobre número de frames:")
if alt_f1 > test_metrics_ft['f1']:
    print(f"Utilizar {ALTERNATIVE_NUM_FRAMES} frames mejora el rendimiento del modelo.")
    print("Considerar este parámetro para entrenamientos futuros completos.")
else:
    print(f"Utilizar {NUM_FRAMES} frames proporciona mejor balance entre rendimiento y eficiencia.")

## Sección 17: Guardado de Documentación Final y Guía de Uso

In [None]:
# Generar documentación final del modelo
documentation = f"""
# Modelo de Detección de Violencia Escolar

Este documento describe el modelo de detección de violencia escolar entrenado con TimeSformer.

## Información General

- **Fecha de entrenamiento**: {time.strftime("%Y-%m-%d")}
- **Modelo base**: TimeSformer
- **Arquitectura**: {MODEL_CHECKPOINT}
- **Número de frames**: {NUM_FRAMES}
- **Resolución**: 224x224
- **Clases**: {data_module.classes}
- **Autor**: Franz Reinaldo Gonzales Suyo

## Métricas de Rendimiento

- **Accuracy**: {test_metrics_ft['accuracy']:.4f}
- **Precision**: {test_metrics_ft['precision']:.4f}
- **Recall (Sensibilidad)**: {test_metrics_ft['recall']:.4f}
- **Specificity**: {test_metrics_ft['specificity']:.4f}
- **F1-Score**: {test_metrics_ft['f1']:.4f}
- **ROC AUC**: {test_metrics_ft['roc_auc']:.4f}

## Uso del Modelo

1. **Carga del modelo**:
```python
from transformers import TimesformerForVideoClassification, AutoImageProcessor

model_path = "ruta/al/modelo/exportado"
model = TimesformerForVideoClassification.from_pretrained(model_path)
processor = AutoImageProcessor.from_pretrained(model_path)

## Procesamiento de video:

In [None]:
import av
import torch
import numpy as np
from torchvision import transforms

# Cargar video
container = av.open("video.mp4")

# Extraer frames
num_frames = {NUM_FRAMES}
indices = np.linspace(0, container.streams.video[0].frames - 1, num_frames, dtype=int)
frames = []
for i, frame in enumerate(container.decode(video=0)):
    if i in indices:
        img = frame.to_ndarray(format="rgb")
        frames.append(img)
    if len(frames) == num_frames:
        break

# Preprocesar frames
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
frames_processed = [transform(frame) for frame in frames]

# Procesar con el procesador de TimeSformer
inputs = processor(frames_processed, return_tensors="pt")

## Inferencia:

In [None]:
import torch.nn.functional as F

with torch.no_grad():
    outputs = model(pixel_values=inputs.pixel_values)
    logits = outputs.logits
    probs = F.softmax(logits, dim=1)
    predicted_class = torch.argmax(logits, dim=1).item()

classes = {data_module.classes}
print(f"Clase predicha: {classes[predicted_class]}")
print(f"Confianza: {probs[0][predicted_class].item():.4f}")