# Proyecto: Detección de Violencia Física Escolar en Tiempo Real

In [1]:
!pip install transformers datasets tqdm scikit-learn matplotlib seaborn opencv-python onnx onnxruntime-gpu wandb

Collecting onnx
  Downloading onnx-1.18.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.9 kB)
Collecting onnxruntime-gpu
  Downloading onnxruntime_gpu-1.22.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.9 kB)
Collecting coloredlogs (from onnxruntime-gpu)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime-gpu)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnx-1.18.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.6/17.6 MB[0m [31m111.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading onnxruntime_gpu-1.22.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (283.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m283.2/283.2 MB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading coloredlogs-15.0.1-py

In [4]:
import os
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from google.colab import drive
from torch.utils.data import Dataset, DataLoader
from transformers import AutoImageProcessor, TimesformerForVideoClassification
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc, classification_report
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import logging
from pathlib import Path
import cv2
import onnx
import onnxruntime as ort
import wandb
from datetime import datetime
import warnings
import requests
from retrying import retry
import subprocess
import transformers
warnings.filterwarnings("ignore")

In [3]:
# NOTE(review): this cell carries execution count 3 but sits after the import
# cell (count 4), which already does `from retrying import retry`. On a fresh
# top-to-bottom run this install therefore comes too late.
!pip install retrying

Collecting retrying
  Downloading retrying-1.3.4-py3-none-any.whl.metadata (6.9 kB)
Downloading retrying-1.3.4-py3-none-any.whl (11 kB)
Installing collected packages: retrying
Successfully installed retrying-1.3.4


In [5]:
# Configure logging: INFO level, each record shows timestamp, level and message.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Seed PyTorch and NumPy for reproducibility, and put cuDNN in deterministic
# mode with benchmark mode switched off.
torch.manual_seed(42)
np.random.seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [6]:
# --- Section 1: Initial setup ---
# Mount Google Drive at /content/drive. A failure is logged and re-raised so
# the notebook stops here instead of continuing without the dataset.
try:
    drive.mount('/content/drive')
    logger.info("Google Drive montado correctamente")
except Exception as e:
    logger.error(f"Error al montar Google Drive: {e}")
    raise

Mounted at /content/drive


In [7]:
# Select the compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Usando dispositivo: {device}")

# Check the installed transformers version. The major/minor components are
# compared as integers because comparing the raw version strings is
# lexicographic (e.g. "4.9.0" would not compare as older than "4.30.0").
print(f"Versión de transformers: {transformers.__version__}")
_transformers_major_minor = tuple(
    int(part) for part in transformers.__version__.split(".")[:2] if part.isdigit()
)
if _transformers_major_minor < (4, 30):
    logger.warning("Versión de transformers antigua. Se recomienda actualizar: !pip install --upgrade transformers")

# Dataset location and checkpoint directory (created if it does not exist).
DATASET_PATH = "/content/drive/MyDrive/dataset_violencia"
CHECKPOINT_PATH = "/content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints"
Path(CHECKPOINT_PATH).mkdir(parents=True, exist_ok=True)

Usando dispositivo: cuda
Versión de transformers: 4.51.3


In [8]:
# Global configuration
NUM_CLASSES = 2
BATCH_SIZE = 2  # Reduced for stability
ACCUMULATION_STEPS = 2  # Gradient accumulation to simulate a batch size of 4
NUM_EPOCHS_TL = 20
NUM_EPOCHS_FT = 15
LEARNING_RATE_TL = 5e-5
LEARNING_RATE_FT = 2e-5
WEIGHT_DECAY = 1e-4
DROPOUT = 0.3
PATIENCE = 5
NUM_FRAMES = 8
IMG_SIZE = 224
FPS = 15
LABELS = ["no_violence", "violence"]
TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")
MODEL_ID = "facebook/timesformer-base-finetuned-k400"

# SECURITY: an API key used to be hardcoded on this line, next to a personal
# e-mail address. That key must be treated as leaked and revoked. The value is
# now read from the environment and is empty when the variable is not set.
# NOTE(review): presumably this is the Weights & Biases key, for which
# WANDB_API_KEY is the standard variable -- confirm against its consumers.
APY_KEY = os.environ.get("WANDB_API_KEY", "")

# Start the Weights & Biases run and record the hyperparameters above.
wandb.init(project="timesformer_violence_detection", config={
    "batch_size": BATCH_SIZE,
    "learning_rate_tl": LEARNING_RATE_TL,
    "learning_rate_ft": LEARNING_RATE_FT,
    "num_epochs_tl": NUM_EPOCHS_TL,
    "num_epochs_ft": NUM_EPOCHS_FT,
    "weight_decay": WEIGHT_DECAY,
    "dropout": DROPOUT,
    "num_frames": NUM_FRAMES,
    "img_size": IMG_SIZE,
    "model_id": MODEL_ID
})

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mgonzalesfranz2019[0m ([33mgonzalesfranz2019-universidad-mayor-real-y-pontificia-de[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
# --- Section 2: Model and processor validation ---
@retry(stop_max_attempt_number=3, wait_fixed=2000)
def validate_model_access(model_id):
    """Load the image processor and TimeSformer classifier for ``model_id``.

    The ``retry`` decorator re-invokes the function when it raises
    (``stop_max_attempt_number=3``, ``wait_fixed=2000``).

    Args:
        model_id: Hugging Face model identifier passed to ``from_pretrained``.

    Returns:
        A ``(processor, model)`` tuple.

    Raises:
        Exception: whatever the loaders raise, re-raised after being logged.
    """
    try:
        image_processor = AutoImageProcessor.from_pretrained(model_id)
        video_model = TimesformerForVideoClassification.from_pretrained(model_id)
        print(f"Modelo {model_id} accesible correctamente")
        return image_processor, video_model
    except Exception as exc:
        logger.error(f"Error al acceder al modelo {model_id}: {exc}")
        print("Si el repositorio es privado, ejecuta: !huggingface-cli login")
        raise


# Only the processor is kept at module level; the loaded model is discarded.
try:
    processor, _ = validate_model_access(MODEL_ID)
except Exception:
    logger.error("Fallo al cargar el modelo. Verifica el identificador o tu conexión.")
    raise

preprocessor_config.json:   0%|          | 0.00/412 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


config.json:   0%|          | 0.00/22.7k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/486M [00:00<?, ?B/s]

Modelo facebook/timesformer-base-finetuned-k400 accesible correctamente


In [10]:
# --- Section 2: Model and processor validation ---
# NOTE(review): this definition and the call below repeat the previous cell
# verbatim. Running this cell redefines the function and loads the processor
# and model a second time; consider deleting the duplicate.
@retry(stop_max_attempt_number=3, wait_fixed=2000)
def validate_model_access(model_id):
    """Check that the model and processor can be loaded from Hugging Face.

    Returns the ``(processor, model)`` pair; any loading error is logged and
    re-raised.
    """
    try:
        processor = AutoImageProcessor.from_pretrained(model_id)
        model = TimesformerForVideoClassification.from_pretrained(model_id)
        print(f"Modelo {model_id} accesible correctamente")
        return processor, model
    except Exception as e:
        logger.error(f"Error al acceder al modelo {model_id}: {e}")
        print("Si el repositorio es privado, ejecuta: !huggingface-cli login")
        raise

# Keep only the processor; the returned model is discarded.
try:
    processor, _ = validate_model_access(MODEL_ID)
except Exception as e:
    logger.error("Fallo al cargar el modelo. Verifica el identificador o tu conexión.")
    raise

# --- Section 3: GPU monitoring ---
def log_gpu_memory():
    """Print the used/total GPU memory table reported by ``nvidia-smi``.

    Any failure while running or decoding the command is reported as a logged
    warning instead of being raised.
    """
    query_cmd = ['nvidia-smi', '--query-gpu=memory.used,memory.total', '--format=csv']
    try:
        raw_report = subprocess.check_output(query_cmd)
        report_text = raw_report.decode('utf-8')
        print(f"Uso de GPU:\n{report_text}")
    except Exception as exc:
        logger.warning(f"No se pudo obtener información de GPU: {exc}")


log_gpu_memory()

# --- Section 4: Dataset validation ---
def validate_dataset(root_dir):
    """Check every ``.mp4`` under ``root_dir/<split>/<label>`` and report stats.

    For each split ("train", "val", "test") and each entry of ``LABELS`` the
    function counts the videos, opens each one with OpenCV and flags files that
    cannot be opened, report fewer than one frame, or report a zero width or
    height. A non-positive FPS only produces a warning.

    Args:
        root_dir: Dataset root containing the ``<split>/<label>`` directories.

    Returns:
        List of paths of the problematic videos (each path appears once).

    Raises:
        FileNotFoundError: if an expected ``<split>/<label>`` directory is missing.
    """
    splits = ["train", "val", "test"]
    total_videos = 0
    class_counts = {label: 0 for label in LABELS}
    problematic_videos = []

    for split in splits:
        for label in LABELS:
            label_dir = os.path.join(root_dir, split, label)
            if not os.path.exists(label_dir):
                logger.error(f"Directorio no encontrado: {label_dir}")
                raise FileNotFoundError(f"Directorio {label_dir} no existe")
            videos = [f for f in os.listdir(label_dir) if f.endswith(".mp4")]
            print(f"{split}/{label}: {len(videos)} videos")
            total_videos += len(videos)
            class_counts[label] += len(videos)

            for video in videos:
                video_path = os.path.join(label_dir, video)
                cap = cv2.VideoCapture(video_path)
                # try/finally guarantees the capture is released on every path,
                # including the early `continue` for unreadable files.
                try:
                    if not cap.isOpened():
                        logger.error(f"Video corrupto: {video_path}")
                        problematic_videos.append(video_path)
                        continue
                    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                    raw_fps = cap.get(cv2.CAP_PROP_FPS)
                    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                finally:
                    cap.release()

                is_problematic = False
                if frame_count < 1:
                    logger.error(f"Video sin frames: {video_path}")
                    is_problematic = True
                # Inspect the raw value before applying the fallback; with the
                # fallback applied first, an FPS of 0 was never reported.
                # `not raw_fps > 0` also catches NaN.
                if not raw_fps > 0:
                    logger.warning(f"FPS inválido en {video_path}, usando FPS por defecto: {FPS}")
                    fps = FPS
                else:
                    fps = raw_fps
                if width == 0 or height == 0:
                    logger.error(f"Resolución inválida en {video_path}")
                    is_problematic = True
                # Record the path once even when several checks fail.
                if is_problematic:
                    problematic_videos.append(video_path)
                logger.debug(f"Video: {video_path}, Frames: {frame_count}, FPS: {fps}, Resolución: {width}x{height}")

    print(f"Total de videos validados: {total_videos}")
    print(f"Distribución de clases: {class_counts}")
    print(f"Videos problemáticos detectados: {len(problematic_videos)}")
    if problematic_videos:
        logger.warning(f"Videos problemáticos: {problematic_videos}")
    # Guard against an empty dataset, which would otherwise divide by zero.
    if total_videos > 0 and abs(class_counts["violence"] - class_counts["no_violence"]) / total_videos > 0.2:
        logger.warning("Desbalance de clases detectado. Considerar técnicas de balanceo.")
    return problematic_videos


problematic_videos = validate_dataset(DATASET_PATH)

Modelo facebook/timesformer-base-finetuned-k400 accesible correctamente
Uso de GPU:
memory.used [MiB], memory.total [MiB]
2 MiB, 15360 MiB

train/no_violence: 4000 videos


ERROR:__main__:Video corrupto: /content/drive/MyDrive/dataset_violencia/train/no_violence/no_violencia_1822_horizontal_flip.mp4


train/violence: 4000 videos


ERROR:__main__:Video corrupto: /content/drive/MyDrive/dataset_violencia/train/violence/violencia_directa_1240_horizontal_flip.mp4


val/no_violence: 750 videos
val/violence: 750 videos
test/no_violence: 400 videos
test/violence: 400 videos




Total de videos validados: 10300
Distribución de clases: {'no_violence': 5150, 'violence': 5150}
Videos problemáticos detectados: 2


In [None]:
# --- Dataset validation (listing variant) ---
def validate_dataset(root_dir):
    """List every ``.mp4`` per split/label and spot-check that files open.

    For each of the "train", "val" and "test" splits and each entry of
    ``LABELS`` this prints the video count, prints every file name, and tries
    to open only the first five files with OpenCV, logging the ones that fail.
    Nothing is returned.

    NOTE(review): this rebinds the name ``validate_dataset`` that an earlier
    cell defines with a different (list) return value.

    Raises:
        FileNotFoundError: if an expected ``<split>/<label>`` directory is missing.
    """
    video_total = 0
    for split_name in ("train", "val", "test"):
        for class_name in LABELS:
            class_dir = os.path.join(root_dir, split_name, class_name)
            if not os.path.exists(class_dir):
                logger.error(f"Directorio no encontrado: {class_dir}")
                raise FileNotFoundError(f"Directorio {class_dir} no existe")

            mp4_names = list(filter(lambda name: name.endswith(".mp4"), os.listdir(class_dir)))
            count_line = f"{split_name}/{class_name}: {len(mp4_names)} videos"
            logger.info(count_line)
            print(count_line)
            video_total += len(mp4_names)

            # Full listing of the files found in this split/label.
            print(f"Videos in {split_name}/{class_name}:")
            for name in mp4_names:
                print(f"- {name}")
            print("-" * 20)  # Separator for clarity

            # Spot check: only the first five files are opened.
            for name in mp4_names[:5]:
                capture = cv2.VideoCapture(os.path.join(class_dir, name))
                opened = capture.isOpened()
                if not opened:
                    logger.error(f"Video corrupto: {name}")
                capture.release()

    total_line = f"Total de videos validados: {video_total}"
    logger.info(total_line)
    print(total_line)


validate_dataset(DATASET_PATH)

[1;30;43mSe han truncado las últimas 5000 líneas del flujo de salida.[0m
- violencia_directa_2047_horizontal_flip.mp4
- violencia_directa_2048_horizontal_flip.mp4
- violencia_directa_2049_horizontal_flip.mp4
- violencia_directa_2045_horizontal_flip.mp4
- violencia_directa_2046_horizontal_flip.mp4
- violencia_directa_2044_horizontal_flip.mp4
- violencia_directa_2052_horizontal_flip.mp4
- violencia_directa_2053_horizontal_flip.mp4
- violencia_directa_2051_horizontal_flip.mp4
- violencia_directa_2050_horizontal_flip.mp4
- violencia_directa_2054_horizontal_flip.mp4
- violencia_directa_2055_horizontal_flip.mp4
- violencia_directa_2059_horizontal_flip.mp4
- violencia_directa_2058_horizontal_flip.mp4
- violencia_directa_2056_horizontal_flip.mp4
- violencia_directa_2057_horizontal_flip.mp4
- violencia_directa_2060_horizontal_flip.mp4
- violencia_directa_2061_horizontal_flip.mp4
- violencia_directa_2062_horizontal_flip.mp4
- violencia_directa_2063_horizontal_flip.mp4
- violencia_directa_2066_

In [11]:
# --- Section 5: Custom dataset ---
class ViolenceDataset(Dataset):
    """Video classification dataset read from ``root_dir/<split>/<label>/*.mp4``.

    At construction every ``.mp4`` file is probed with OpenCV; files that cannot
    be opened, report fewer than one frame, or report a zero width/height are
    skipped with a warning. Each item is decoded lazily in ``__getitem__``.

    Attributes are parallel lists indexed by sample:
        videos: file paths.
        labels: index of the label directory within ``LABELS``.
        durations: ``frame_count / fps`` as reported by OpenCV.
        resolutions: ``(width, height)`` tuples.
    """

    def __init__(self, root_dir, split, processor, num_frames=NUM_FRAMES):
        """Index the valid videos of one split.

        Args:
            root_dir: Dataset root directory.
            split: Sub-directory of ``root_dir`` to load (e.g. "train").
            processor: Callable applied to the list of sampled RGB frames.
            num_frames: Number of frames fed to the processor per video.

        Raises:
            ValueError: if no valid video is found in the split.
        """
        self.root_dir = os.path.join(root_dir, split)
        self.processor = processor
        self.num_frames = num_frames
        self.videos = []
        self.labels = []
        self.durations = []
        self.resolutions = []

        for label in LABELS:
            label_dir = os.path.join(self.root_dir, label)
            # os.listdir order is arbitrary; sorting makes the sample order
            # reproducible between runs.
            for video_file in sorted(os.listdir(label_dir)):
                if not video_file.endswith(".mp4"):
                    continue
                video_path = os.path.join(label_dir, video_file)
                cap = cv2.VideoCapture(video_path)
                # try/finally releases the capture on every path.
                try:
                    if not cap.isOpened():
                        logger.warning(f"Omitiendo video corrupto: {video_path}")
                        continue
                    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                    fps = cap.get(cv2.CAP_PROP_FPS)
                    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                finally:
                    cap.release()

                if frame_count < 1 or width == 0 or height == 0:
                    logger.warning(f"Omitiendo video inválido: {video_path}")
                    continue
                # Fall back to the default FPS for zero, negative or NaN values
                # so the duration below is always finite and positive.
                if not fps > 0:
                    fps = FPS
                duration = frame_count / fps

                self.videos.append(video_path)
                self.labels.append(LABELS.index(label))
                self.durations.append(duration)
                self.resolutions.append((width, height))

        logger.info(f"{split} dataset cargado con {len(self.videos)} videos")
        if len(self.videos) == 0:
            logger.error(f"No se cargaron videos en {split}. Verifica el directorio.")
            raise ValueError(f"No se encontraron videos válidos en {split}")

    def __len__(self):
        """Return the number of indexed videos."""
        return len(self.videos)

    def __getitem__(self, idx):
        """Decode one video and return ``(inputs, label, duration)``.

        All frames are read, then ``num_frames`` of them are picked at evenly
        spaced positions (a shorter video is padded by repeating its last
        frame). The sampled frames are converted from BGR to RGB and handed to
        the processor.

        Returns:
            inputs: dict of processor outputs with ``squeeze(0)`` applied to
                each value (presumably dropping a leading batch axis the
                processor adds -- TODO confirm).
            label: scalar ``torch.long`` tensor.
            duration: the stored duration for this video.

        Raises:
            IndexError: if ``idx`` is out of range.
            ValueError: if no frame could be read from the file.
        """
        # Looked up outside the try block: the handler below logs `video_path`,
        # so it must already be bound, and a bad index surfaces as IndexError.
        video_path = self.videos[idx]
        label = self.labels[idx]
        try:
            cap = cv2.VideoCapture(video_path)
            frames = []
            try:
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break
                    frames.append(frame)
            finally:
                # Release the capture even if decoding raises.
                cap.release()

            if len(frames) == 0:
                logger.error(f"Video sin frames: {video_path}")
                raise ValueError(f"No se pudieron leer frames de {video_path}")

            if len(frames) < self.num_frames:
                frames = frames + [frames[-1]] * (self.num_frames - len(frames))
            else:
                indices = np.linspace(0, len(frames) - 1, self.num_frames).astype(int)
                frames = [frames[i] for i in indices]

            # Convert only the sampled frames instead of every decoded frame.
            frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frames]

            inputs = self.processor(frames, return_tensors="pt")
            inputs = {k: v.squeeze(0) for k, v in inputs.items()}

            return inputs, torch.tensor(label, dtype=torch.long), self.durations[idx]
        except Exception as e:
            logger.error(f"Error al procesar video {video_path}: {e}")
            raise

# --- Section 6: Build the datasets and DataLoaders ---
# One dataset per split, built in train -> val -> test order.
train_dataset, val_dataset, test_dataset = (
    ViolenceDataset(DATASET_PATH, split_name, processor)
    for split_name in ("train", "val", "test")
)

# Only the training loader shuffles. num_workers is 0 for stability in Colab.
train_loader, val_loader, test_loader = (
    DataLoader(
        split_dataset,
        batch_size=BATCH_SIZE,
        shuffle=reshuffle,
        num_workers=0,
        pin_memory=True,
    )
    for split_dataset, reshuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

In [12]:
# --- Section 7: Evaluation and visualization helpers ---
def compute_metrics(labels, preds, probs):
    """Compute binary classification metrics from labels, predictions and scores.

    Args:
        labels: Ground-truth class indices.
        preds: Predicted class indices.
        probs: Scores for the positive class, passed to ``roc_curve``.

    Returns:
        Dict with keys ``confusion_matrix``, ``accuracy``, ``precision``,
        ``recall``, ``specificity``, ``tpr``, ``fpr``, ``f1_score``,
        ``fpr_roc``, ``tpr_roc`` and ``auc``. When ``labels`` or ``preds`` is
        empty, every metric is zero. When ``labels`` contains a single class,
        the ROC curve is undefined: ``auc`` is NaN and the ROC arrays are
        ``[0]``.
    """
    if len(labels) == 0 or len(preds) == 0:
        logger.error("No se recolectaron etiquetas o predicciones. Verifica los datos.")
        return {
            # Integer dtype so the matrix can be rendered with an integer
            # format, like the matrix produced by confusion_matrix below.
            "confusion_matrix": np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=int),
            "accuracy": 0.0,
            "precision": 0.0,
            "recall": 0.0,
            "specificity": 0.0,
            "tpr": 0.0,
            "fpr": 0.0,
            "f1_score": 0.0,
            "fpr_roc": [0],
            "tpr_roc": [0],
            "auc": 0.0
        }

    # Pin the class set: without `labels=` the matrix shrinks to 1x1 when only
    # one class occurs in the data, and the cm[0, 1] lookups below would fail.
    cm = confusion_matrix(labels, preds, labels=list(range(NUM_CLASSES)))
    acc = accuracy_score(labels, preds)
    prec = precision_score(labels, preds, average='binary', zero_division=0)
    rec = recall_score(labels, preds, average='binary', zero_division=0)
    f1 = f1_score(labels, preds, average='binary', zero_division=0)

    # Row 0 holds the samples whose true class is 0 (the negatives).
    negatives = cm[0, 0] + cm[0, 1]
    spec = cm[0, 0] / negatives if negatives > 0 else 0
    tpr = rec
    fpr = cm[0, 1] / negatives if negatives > 0 else 0

    if len(np.unique(labels)) < 2:
        # The ROC curve needs both classes in the ground truth.
        logger.warning("Solo hay una clase en las etiquetas; la curva ROC y el AUC no están definidos.")
        fpr_roc, tpr_roc = [0], [0]
        auc_score = float("nan")
    else:
        fpr_roc, tpr_roc, _ = roc_curve(labels, probs, drop_intermediate=False)
        auc_score = auc(fpr_roc, tpr_roc)

    return {
        "confusion_matrix": cm,
        "accuracy": acc,
        "precision": prec,
        "recall": rec,
        "specificity": spec,
        "tpr": tpr,
        "fpr": fpr,
        "f1_score": f1,
        "fpr_roc": fpr_roc,
        "tpr_roc": tpr_roc,
        "auc": auc_score
    }

def plot_metrics(metrics, phase, epoch, save_path):
    """Render a four-panel metrics figure, save it as PNG and log it to W&B.

    Panels: confusion-matrix heatmap, ROC curve, a text summary of the scalar
    metrics, and a single-bar F1 chart.

    Args:
        metrics: Dict with the keys read below (``confusion_matrix``,
            ``fpr_roc``, ``tpr_roc``, ``auc``, ``accuracy``, ``precision``,
            ``recall``, ``specificity``, ``f1_score``).
        phase: Label used in the titles, the file name and the W&B key.
        epoch: Epoch number used in the titles, the file name and the W&B key.
        save_path: Directory where the PNG is written.
    """
    figure_path = os.path.join(save_path, f"metrics_{phase}_epoch_{epoch}.png")
    fig = plt.figure(figsize=(20, 5))
    try:
        plt.subplot(1, 4, 1)
        # fmt="d" only accepts integers; cast so a float-typed matrix (e.g. an
        # all-zero placeholder) does not raise while annotating.
        cm_counts = np.asarray(metrics["confusion_matrix"]).astype(int)
        sns.heatmap(cm_counts, annot=True, fmt="d", cmap="Blues")
        plt.title(f"Matriz de Confusión - {phase} (Época {epoch})")
        plt.xlabel("Predicho")
        plt.ylabel("Real")

        plt.subplot(1, 4, 2)
        plt.plot(metrics["fpr_roc"], metrics["tpr_roc"], label=f"AUC = {metrics['auc']:.2f}")
        plt.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
        plt.xlabel("FPR")
        plt.ylabel("TPR")
        plt.title(f"Curva ROC - {phase} (Época {epoch})")
        plt.legend()

        plt.subplot(1, 4, 3)
        metrics_text = (f"Exactitud: {metrics['accuracy']:.3f}\n"
                        f"Precisión: {metrics['precision']:.3f}\n"
                        f"Sensibilidad: {metrics['recall']:.3f}\n"
                        f"Especificidad: {metrics['specificity']:.3f}\n"
                        f"F1-Score: {metrics['f1_score']:.3f}\n"
                        f"AUC: {metrics['auc']:.3f}")
        plt.text(0.1, 0.5, metrics_text, fontsize=12)
        plt.axis('off')
        plt.title(f"Métricas - {phase} (Época {epoch})")

        plt.subplot(1, 4, 4)
        plt.bar(["F1-Score"], [metrics["f1_score"]], color='skyblue')
        plt.ylim(0, 1)
        plt.title(f"F1-Score - {phase} (Época {epoch})")

        plt.tight_layout()
        plt.savefig(figure_path)
    finally:
        # Close the figure even when plotting or saving fails, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)
    wandb.log({f"{phase}_metrics_epoch_{epoch}": wandb.Image(figure_path)})

def plot_learning_curves(history, phase, save_path):
    """Plot loss, accuracy and validation-F1 curves, save the PNG and log it.

    Args:
        history: Dict with the sequences ``train_loss``, ``val_loss``,
            ``train_acc``, ``val_acc`` and ``val_f1``.
        phase: Label used in the titles, the file name and the W&B key.
        save_path: Directory where the PNG is written.
    """
    fig, (loss_ax, acc_ax, f1_ax) = plt.subplots(1, 3, figsize=(15, 4))

    # Panel 1: training vs. validation loss.
    loss_ax.plot(history["train_loss"], label="Train Loss")
    loss_ax.plot(history["val_loss"], label="Val Loss")
    loss_ax.set_xlabel("Época")
    loss_ax.set_ylabel("Pérdida")
    loss_ax.set_title(f"Curva de Pérdida - {phase}")
    loss_ax.legend()

    # Panel 2: training vs. validation accuracy.
    acc_ax.plot(history["train_acc"], label="Train Acc")
    acc_ax.plot(history["val_acc"], label="Val Acc")
    acc_ax.set_xlabel("Época")
    acc_ax.set_ylabel("Exactitud")
    acc_ax.set_title(f"Curva de Exactitud - {phase}")
    acc_ax.legend()

    # Panel 3: validation F1.
    f1_ax.plot(history["val_f1"], label="Val F1")
    f1_ax.set_xlabel("Época")
    f1_ax.set_ylabel("F1-Score")
    f1_ax.set_title(f"Curva de F1-Score - {phase}")
    f1_ax.legend()

    fig.tight_layout()
    out_file = os.path.join(save_path, f"learning_curves_{phase}.png")
    fig.savefig(out_file)
    plt.close(fig)
    wandb.log({f"learning_curves_{phase}": wandb.Image(out_file)})

def plot_probability_histogram(labels, probs, phase, save_path):
    """Plot one overlaid histogram of ``probs`` per class, save it and log it.

    Args:
        labels: Ground-truth class indices, aligned with ``probs``.
        probs: Scores plotted on the x axis.
        phase: Label used in the title, the file name and the W&B key.
        save_path: Directory where the PNG is written.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    for class_idx in range(NUM_CLASSES):
        # Boolean mask selecting the samples whose true class is class_idx.
        belongs = np.array(labels) == class_idx
        class_probs = np.array(probs)[belongs]
        ax.hist(class_probs, bins=20, alpha=0.5, label=LABELS[class_idx])
    ax.set_xlabel("Probabilidad de Violencia")
    ax.set_ylabel("Frecuencia")
    ax.set_title(f"Histograma de Probabilidades - {phase}")
    ax.legend()

    out_file = os.path.join(save_path, f"prob_histogram_{phase}.png")
    fig.savefig(out_file)
    plt.close(fig)
    wandb.log({f"prob_histogram_{phase}": wandb.Image(out_file)})

def plot_prob_vs_duration(durations, probs, labels, phase, save_path):
    """Scatter ``probs`` against ``durations`` per class, save it and log it.

    Args:
        durations: Per-sample durations (x axis), aligned with ``probs``.
        probs: Per-sample scores (y axis).
        labels: Ground-truth class indices used to split the points by class.
        phase: Label used in the title, the file name and the W&B key.
        save_path: Directory where the PNG is written.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    for class_idx in range(NUM_CLASSES):
        # Boolean mask selecting the samples whose true class is class_idx.
        belongs = np.array(labels) == class_idx
        xs = np.array(durations)[belongs]
        ys = np.array(probs)[belongs]
        ax.scatter(xs, ys, alpha=0.5, label=LABELS[class_idx])
    ax.set_xlabel("Duración del Video (segundos)")
    ax.set_ylabel("Probabilidad de Violencia")
    ax.set_title(f"Probabilidad vs. Duración - {phase}")
    ax.legend()

    out_file = os.path.join(save_path, f"prob_vs_duration_{phase}.png")
    fig.savefig(out_file)
    plt.close(fig)
    wandb.log({f"prob_vs_duration_{phase}": wandb.Image(out_file)})

def plot_f1_vs_duration(durations, f1_scores, phase, save_path):
    """Scatter ``f1_scores`` against ``durations``, save the PNG and log it.

    Args:
        durations: Values for the x axis.
        f1_scores: Values for the y axis, aligned with ``durations``.
        phase: Label used in the title, the file name and the W&B key.
        save_path: Directory where the PNG is written.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(durations, f1_scores, alpha=0.5)
    ax.set_xlabel("Duración del Video (segundos)")
    ax.set_ylabel("F1-Score")
    ax.set_title(f"F1-Score vs. Duración - {phase}")

    out_file = os.path.join(save_path, f"f1_vs_duration_{phase}.png")
    fig.savefig(out_file)
    plt.close(fig)
    wandb.log({f"f1_vs_duration_{phase}": wandb.Image(out_file)})

def plot_error_rate_vs_duration(durations, correct, phase, save_path):
    """Scatter per-sample error (``1 - correct``) against ``durations``.

    Args:
        durations: Values for the x axis.
        correct: Per-sample correctness flags; each is plotted as ``1 - flag``.
        phase: Label used in the title, the file name and the W&B key.
        save_path: Directory where the PNG is written.
    """
    # 0 marks a correct prediction and 1 an incorrect one.
    error_flags = [1 - flag for flag in correct]

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(durations, error_flags, alpha=0.5)
    ax.set_xlabel("Duración del Video (segundos)")
    ax.set_ylabel("Tasa de Error (0=Correcto, 1=Incorrecto)")
    ax.set_title(f"Tasa de Error vs. Duración - {phase}")

    out_file = os.path.join(save_path, f"error_rate_vs_duration_{phase}.png")
    fig.savefig(out_file)
    plt.close(fig)
    wandb.log({f"error_rate_vs_duration_{phase}": wandb.Image(out_file)})


In [13]:
# --- Section 8: Training functions ---
def train_epoch(model, loader, optimizer, criterion, scaler, device, accumulation_steps):
    """Run one training epoch with mixed precision and gradient accumulation.

    Each batch loss is divided by ``accumulation_steps`` and back-propagated
    through ``scaler``. The optimizer steps once ``accumulation_steps``
    successful batches have accumulated, and once more at the end of the epoch
    if any gradients are still pending. A batch that raises is logged, counted
    as failed and skipped.

    Args:
        model: Model called as ``model(**inputs)``; its output exposes ``logits``.
        loader: Iterable of ``(inputs, labels, _)`` batches, ``inputs`` a dict
            of tensors.
        optimizer: Optimizer stepped through ``scaler``.
        criterion: Loss called as ``criterion(logits, labels)``.
        scaler: Gradient scaler providing ``scale``, ``step`` and ``update``.
        device: Device the batch tensors are moved to.
        accumulation_steps: Number of batches accumulated per optimizer step.

    Returns:
        ``(avg_loss, metrics, all_labels, all_probs)``: the mean un-scaled loss
        over the successful batches, the ``compute_metrics`` dict, and the
        collected labels and positive-class probabilities.

    Raises:
        RuntimeError: if every batch failed.
    """
    model.train()
    total_loss = 0
    all_preds = []
    all_labels = []
    all_probs = []
    failed_batches = 0
    # Successful backward passes whose gradients have not been applied yet.
    pending_batches = 0
    optimizer.zero_grad()

    for i, (inputs, labels, _) in enumerate(tqdm(loader, desc="Entrenando")):
        try:
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device)

            logger.debug(f"Batch {i}: Input shape: {inputs['pixel_values'].shape}, Labels shape: {labels.shape}")

            with autocast():
                outputs = model(**inputs)
                loss = criterion(outputs.logits, labels) / accumulation_steps

            scaler.scale(loss).backward()
            pending_batches += 1

            # Count successful batches rather than using the batch index, so a
            # failed batch neither skips nor shifts an optimizer step.
            if pending_batches >= accumulation_steps:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                pending_batches = 0

            # Undo the accumulation scaling so the reported loss is per batch.
            total_loss += loss.item() * accumulation_steps
            preds = torch.argmax(outputs.logits, dim=1)
            probs = torch.softmax(outputs.logits, dim=1)[:, 1]

            all_preds.extend(preds.detach().cpu().numpy())
            all_labels.extend(labels.detach().cpu().numpy())
            all_probs.extend(probs.detach().cpu().numpy())

            torch.cuda.empty_cache()
        except Exception as e:
            logger.error(f"Error en batch {i}: {e}")
            failed_batches += 1
            continue

    # Apply gradients left over when the number of successful batches is not a
    # multiple of accumulation_steps; otherwise they would be thrown away by
    # the zero_grad() at the start of the next epoch.
    if pending_batches > 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    if failed_batches == len(loader):
        logger.error("Todos los lotes fallaron. Verifica los datos o la configuración del modelo.")
        raise RuntimeError("No se procesó ningún lote correctamente")

    print(f"Lotes procesados: {len(loader) - failed_batches}/{len(loader)}, Lotes fallidos: {failed_batches}")

    avg_loss = total_loss / (len(loader) - failed_batches) if len(loader) > failed_batches else 0
    metrics = compute_metrics(all_labels, all_preds, all_probs)
    return avg_loss, metrics, all_labels, all_probs

def evaluate_epoch(model, loader, criterion, device):
    """Evaluate the model over one pass of ``loader`` without gradient tracking.

    Args:
        model: Model called as ``model(**inputs)``; its output must expose ``.logits``.
        loader: Iterable yielding ``(inputs, labels, durations)`` triples, where
            ``inputs`` is a dict of tensors containing ``'pixel_values'``.
        criterion: Loss callable applied to ``(logits, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(avg_loss, metrics, all_labels, all_probs, avg_inference_time,
        all_durations)``. ``all_probs`` holds the softmax probability of class
        index 1 per sample; ``avg_inference_time`` is the mean wall-clock time per
        batch (host-to-device copy plus forward pass and loss), or 0 if no batch
        succeeded.

    Raises:
        RuntimeError: If every batch in the loader failed.
    """
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    all_probs = []
    inference_times = []
    failed_batches = 0
    all_durations = []

    with torch.no_grad():
        for i, (inputs, labels, durations) in enumerate(tqdm(loader, desc="Evaluando")):
            try:
                start_time = time.time()
                inputs = {k: v.to(device) for k, v in inputs.items()}
                labels = labels.to(device)

                logger.debug(f"Batch {i}: Input shape: {inputs['pixel_values'].shape}, Labels shape: {labels.shape}")

                with autocast():
                    outputs = model(**inputs)
                    loss = criterion(outputs.logits, labels)

                # CUDA kernels are launched asynchronously: wait for the queued
                # work to finish, otherwise the timer mostly measures launch
                # overhead instead of the actual forward pass.
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                inference_times.append(time.time() - start_time)
                total_loss += loss.item()
                preds = torch.argmax(outputs.logits, dim=1)
                probs = torch.softmax(outputs.logits, dim=1)[:, 1]

                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())
                all_durations.extend(durations)

                torch.cuda.empty_cache()
            except Exception as e:
                # Best-effort: a failing batch is logged and skipped, not fatal.
                logger.error(f"Error en batch {i}: {e}")
                failed_batches += 1
                continue

    if failed_batches == len(loader):
        logger.error("Todos los lotes fallaron. Verifica los datos o la configuración del modelo.")
        raise RuntimeError("No se procesó ningún lote correctamente")

    print(f"Lotes procesados: {len(loader) - failed_batches}/{len(loader)}, Lotes fallidos: {failed_batches}")

    avg_loss = total_loss / (len(loader) - failed_batches) if len(loader) > failed_batches else 0
    metrics = compute_metrics(all_labels, all_preds, all_probs)
    avg_inference_time = np.mean(inference_times) if inference_times else 0
    return avg_loss, metrics, all_labels, all_probs, avg_inference_time, all_durations


# ENTRENAMIENTO CON TRANSFER LEARNING

In [None]:
# --- Section 9: Transfer Learning ---
# Trains only the non-backbone parameters of a pretrained TimeSformer, with
# per-epoch checkpoints, best-model tracking on validation loss, early stopping
# and W&B logging.
logger.info("Iniciando Transfer Learning")
print("Iniciando Transfer Learning")
# ignore_mismatched_sizes lets the pretrained classification head be replaced
# by a freshly initialised head with NUM_CLASSES outputs.
model = TimesformerForVideoClassification.from_pretrained(
    MODEL_ID,
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True
)

# Freeze every parameter whose name contains "timesformer" (the backbone);
# everything else stays trainable.
for name, param in model.named_parameters():
    if "timesformer" in name:
        param.requires_grad = False
    else:
        param.requires_grad = True

model = model.to(device)
# torch.compile is optional: fall back to the eager model if it is unsupported.
# Section 10 strips an "_orig_mod." prefix from the keys of the best-model file
# saved below before loading it.
try:
    model = torch.compile(model)
    logger.info("Modelo compilado con torch.compile")
    print("Modelo compilado con torch.compile")
except Exception as e:
    logger.warning(f"torch.compile no soportado: {e}")
    print(f"torch.compile no soportado: {e}")

criterion = torch.nn.CrossEntropyLoss()
# All parameters are handed to the optimizer, including the frozen ones above.
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE_TL, weight_decay=WEIGHT_DECAY)
# Halve the learning rate after 3 epochs without validation-loss improvement.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)
scaler = GradScaler()

# Per-epoch history consumed by plot_learning_curves at the end of the cell.
history_tl = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": [], "val_f1": []}
best_val_loss = float("inf")
early_stop_counter = 0

for epoch in range(NUM_EPOCHS_TL):
    start_time = time.time()

    train_loss, train_metrics, train_labels, train_probs = train_epoch(
        model, train_loader, optimizer, criterion, scaler, device, ACCUMULATION_STEPS
    )
    val_loss, val_metrics, val_labels, val_probs, val_inference_time, val_durations = evaluate_epoch(
        model, val_loader, criterion, device
    )

    history_tl["train_loss"].append(train_loss)
    history_tl["val_loss"].append(val_loss)
    history_tl["train_acc"].append(train_metrics["accuracy"])
    history_tl["val_acc"].append(val_metrics["accuracy"])
    history_tl["val_f1"].append(val_metrics["f1_score"])

    scheduler.step(val_loss)

    # Full training state is written every epoch (model, optimizer, scheduler, scaler).
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "scaler_state_dict": scaler.state_dict(),
        "val_loss": val_loss,
        "val_metrics": val_metrics
    }
    checkpoint_path = os.path.join(CHECKPOINT_PATH, f"checkpoint_tl_epoch_{epoch}_{TIMESTAMP}.pt")
    torch.save(checkpoint, checkpoint_path)
    logger.info(f"Checkpoint guardado: {checkpoint_path}")
    print(f"Checkpoint guardado: {checkpoint_path}")

    # Keep a single "best" weights file, overwritten whenever validation loss
    # improves; any non-improving epoch counts towards early stopping.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_path = os.path.join(CHECKPOINT_PATH, f"best_model_tl_{TIMESTAMP}.pt")
        torch.save(model.state_dict(), best_model_path)
        logger.info(f"Mejor modelo guardado: {best_model_path}")
        print(f"Mejor modelo guardado: {best_model_path}")
        early_stop_counter = 0
    else:
        early_stop_counter += 1

    # Every message is sent to both the logger and the notebook output.
    logger.info(f"Época {epoch+1}/{NUM_EPOCHS_TL}")
    print(f"Época {epoch+1}/{NUM_EPOCHS_TL}")
    logger.info(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    logger.info(f"Train Acc: {train_metrics['accuracy']:.3f}, Val Acc: {val_metrics['accuracy']:.3f}")
    print(f"Train Acc: {train_metrics['accuracy']:.3f}, Val Acc: {val_metrics['accuracy']:.3f}")
    logger.info(f"Val F1-Score: {val_metrics['f1_score']:.3f}")
    print(f"Val F1-Score: {val_metrics['f1_score']:.3f}")
    logger.info(f"Tiempo: {time.time() - start_time:.2f} segundos")
    print(f"Tiempo: {time.time() - start_time:.2f} segundos")
    wandb.log({
        "epoch": epoch,
        "train_loss": train_loss,
        "val_loss": val_loss,
        "train_acc": train_metrics["accuracy"],
        "val_acc": val_metrics["accuracy"],
        "val_f1": val_metrics["f1_score"],
        "val_inference_time": val_inference_time
    })

    # Plot helpers are defined elsewhere in the notebook; all of them receive
    # CHECKPOINT_PATH as their last argument.
    plot_metrics(val_metrics, "Validación TL", epoch, CHECKPOINT_PATH)
    plot_probability_histogram(val_labels, val_probs, f"Validación TL Epoch {epoch}", CHECKPOINT_PATH)
    plot_prob_vs_duration(val_durations, val_probs, val_labels, f"Validación TL Epoch {epoch}", CHECKPOINT_PATH)

    # Stop once validation loss has failed to improve for PATIENCE consecutive epochs.
    if early_stop_counter >= PATIENCE:
        logger.info("Early Stopping activado")
        print("Early Stopping activado")
        break

plot_learning_curves(history_tl, "Transfer Learning", CHECKPOINT_PATH)

Iniciando Transfer Learning


Some weights of TimesformerForVideoClassification were not initialized from the model checkpoint at facebook/timesformer-base-finetuned-k400 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([400, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([400]) in the checkpoint and torch.Size([2]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Modelo compilado con torch.compile


Entrenando:   0%|          | 0/4000 [00:00<?, ?it/s]

Lotes procesados: 4000/4000, Lotes fallidos: 0


Evaluando:   0%|          | 0/750 [00:00<?, ?it/s]

Lotes procesados: 750/750, Lotes fallidos: 0
Checkpoint guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/checkpoint_tl_epoch_0_20250518_044626.pt
Mejor modelo guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/best_model_tl_20250518_044626.pt
Época 1/20
Train Loss: 0.3305, Val Loss: 0.2104
Train Acc: 0.870, Val Acc: 0.931
Val F1-Score: 0.930
Tiempo: 943.01 segundos


Entrenando:   0%|          | 0/4000 [00:00<?, ?it/s]

Lotes procesados: 4000/4000, Lotes fallidos: 0


Evaluando:   0%|          | 0/750 [00:00<?, ?it/s]

Lotes procesados: 750/750, Lotes fallidos: 0
Checkpoint guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/checkpoint_tl_epoch_1_20250518_044626.pt
Mejor modelo guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/best_model_tl_20250518_044626.pt
Época 2/20
Train Loss: 0.1934, Val Loss: 0.1649
Train Acc: 0.931, Val Acc: 0.947
Val F1-Score: 0.947
Tiempo: 870.31 segundos


Entrenando:   0%|          | 0/4000 [00:00<?, ?it/s]

Lotes procesados: 4000/4000, Lotes fallidos: 0


Evaluando:   0%|          | 0/750 [00:00<?, ?it/s]

Lotes procesados: 750/750, Lotes fallidos: 0
Checkpoint guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/checkpoint_tl_epoch_2_20250518_044626.pt
Mejor modelo guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/best_model_tl_20250518_044626.pt
Época 3/20
Train Loss: 0.1612, Val Loss: 0.1457
Train Acc: 0.941, Val Acc: 0.953
Val F1-Score: 0.953
Tiempo: 869.37 segundos


Entrenando:   0%|          | 0/4000 [00:00<?, ?it/s]

Lotes procesados: 4000/4000, Lotes fallidos: 0


Evaluando:   0%|          | 0/750 [00:00<?, ?it/s]

Lotes procesados: 750/750, Lotes fallidos: 0
Checkpoint guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/checkpoint_tl_epoch_3_20250518_044626.pt
Mejor modelo guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/best_model_tl_20250518_044626.pt
Época 4/20
Train Loss: 0.1430, Val Loss: 0.1338
Train Acc: 0.948, Val Acc: 0.959
Val F1-Score: 0.959
Tiempo: 895.66 segundos


Entrenando:   0%|          | 0/4000 [00:00<?, ?it/s]

Lotes procesados: 4000/4000, Lotes fallidos: 0


Evaluando:   0%|          | 0/750 [00:00<?, ?it/s]

Lotes procesados: 750/750, Lotes fallidos: 0
Checkpoint guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/checkpoint_tl_epoch_4_20250518_044626.pt
Mejor modelo guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/best_model_tl_20250518_044626.pt
Época 5/20
Train Loss: 0.1305, Val Loss: 0.1268
Train Acc: 0.953, Val Acc: 0.959
Val F1-Score: 0.959
Tiempo: 874.99 segundos


Entrenando:   0%|          | 0/4000 [00:00<?, ?it/s]

Lotes procesados: 4000/4000, Lotes fallidos: 0


Evaluando:   0%|          | 0/750 [00:00<?, ?it/s]

Lotes procesados: 750/750, Lotes fallidos: 0
Checkpoint guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/checkpoint_tl_epoch_5_20250518_044626.pt
Mejor modelo guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/best_model_tl_20250518_044626.pt
Época 6/20
Train Loss: 0.1214, Val Loss: 0.1194
Train Acc: 0.957, Val Acc: 0.962
Val F1-Score: 0.962
Tiempo: 871.27 segundos


Entrenando:   0%|          | 0/4000 [00:00<?, ?it/s]

Lotes procesados: 4000/4000, Lotes fallidos: 0


Evaluando:   0%|          | 0/750 [00:00<?, ?it/s]

Lotes procesados: 750/750, Lotes fallidos: 0
Checkpoint guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/checkpoint_tl_epoch_6_20250518_044626.pt
Mejor modelo guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/best_model_tl_20250518_044626.pt
Época 7/20
Train Loss: 0.1138, Val Loss: 0.1175
Train Acc: 0.962, Val Acc: 0.962
Val F1-Score: 0.961
Tiempo: 874.44 segundos


Entrenando:   0%|          | 0/4000 [00:00<?, ?it/s]

Lotes procesados: 4000/4000, Lotes fallidos: 0


Evaluando:   0%|          | 0/750 [00:00<?, ?it/s]

Lotes procesados: 750/750, Lotes fallidos: 0
Checkpoint guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/checkpoint_tl_epoch_7_20250518_044626.pt
Mejor modelo guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/best_model_tl_20250518_044626.pt
Época 8/20
Train Loss: 0.1084, Val Loss: 0.1117
Train Acc: 0.963, Val Acc: 0.966
Val F1-Score: 0.966
Tiempo: 857.88 segundos


Entrenando:   0%|          | 0/4000 [00:00<?, ?it/s]

Lotes procesados: 4000/4000, Lotes fallidos: 0


Evaluando:   0%|          | 0/750 [00:00<?, ?it/s]

Lotes procesados: 750/750, Lotes fallidos: 0
Checkpoint guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/checkpoint_tl_epoch_8_20250518_044626.pt
Mejor modelo guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/best_model_tl_20250518_044626.pt
Época 9/20
Train Loss: 0.1033, Val Loss: 0.1088
Train Acc: 0.966, Val Acc: 0.965
Val F1-Score: 0.965
Tiempo: 858.21 segundos


Entrenando:   0%|          | 0/4000 [00:00<?, ?it/s]

Lotes procesados: 4000/4000, Lotes fallidos: 0


Evaluando:   0%|          | 0/750 [00:00<?, ?it/s]

Lotes procesados: 750/750, Lotes fallidos: 0
Checkpoint guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/checkpoint_tl_epoch_9_20250518_044626.pt
Mejor modelo guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/best_model_tl_20250518_044626.pt
Época 10/20
Train Loss: 0.0989, Val Loss: 0.1054
Train Acc: 0.968, Val Acc: 0.963
Val F1-Score: 0.963
Tiempo: 858.55 segundos


Entrenando:   0%|          | 0/4000 [00:00<?, ?it/s]

Lotes procesados: 4000/4000, Lotes fallidos: 0


Evaluando:   0%|          | 0/750 [00:00<?, ?it/s]

Lotes procesados: 750/750, Lotes fallidos: 0
Checkpoint guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/checkpoint_tl_epoch_10_20250518_044626.pt
Mejor modelo guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/best_model_tl_20250518_044626.pt
Época 11/20
Train Loss: 0.0954, Val Loss: 0.1035
Train Acc: 0.969, Val Acc: 0.966
Val F1-Score: 0.966
Tiempo: 873.17 segundos


Entrenando:   0%|          | 0/4000 [00:00<?, ?it/s]

Lotes procesados: 4000/4000, Lotes fallidos: 0


Evaluando:   0%|          | 0/750 [00:00<?, ?it/s]

Lotes procesados: 750/750, Lotes fallidos: 0
Checkpoint guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/checkpoint_tl_epoch_11_20250518_044626.pt
Mejor modelo guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/best_model_tl_20250518_044626.pt
Época 12/20
Train Loss: 0.0920, Val Loss: 0.1024
Train Acc: 0.969, Val Acc: 0.965
Val F1-Score: 0.965
Tiempo: 870.40 segundos


Entrenando:   0%|          | 0/4000 [00:00<?, ?it/s]

Lotes procesados: 4000/4000, Lotes fallidos: 0


Evaluando:   0%|          | 0/750 [00:00<?, ?it/s]

Lotes procesados: 750/750, Lotes fallidos: 0
Checkpoint guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/checkpoint_tl_epoch_12_20250518_044626.pt
Mejor modelo guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/best_model_tl_20250518_044626.pt
Época 13/20
Train Loss: 0.0893, Val Loss: 0.0989
Train Acc: 0.970, Val Acc: 0.964
Val F1-Score: 0.964
Tiempo: 869.14 segundos


Entrenando:   0%|          | 0/4000 [00:00<?, ?it/s]

Lotes procesados: 4000/4000, Lotes fallidos: 0


Evaluando:   0%|          | 0/750 [00:00<?, ?it/s]

Lotes procesados: 750/750, Lotes fallidos: 0
Checkpoint guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/checkpoint_tl_epoch_13_20250518_044626.pt
Mejor modelo guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/best_model_tl_20250518_044626.pt
Época 14/20
Train Loss: 0.0868, Val Loss: 0.0971
Train Acc: 0.970, Val Acc: 0.965
Val F1-Score: 0.965
Tiempo: 868.59 segundos


Entrenando:   0%|          | 0/4000 [00:00<?, ?it/s]

Lotes procesados: 4000/4000, Lotes fallidos: 0


Evaluando:   0%|          | 0/750 [00:00<?, ?it/s]

Lotes procesados: 750/750, Lotes fallidos: 0
Checkpoint guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/checkpoint_tl_epoch_14_20250518_044626.pt
Mejor modelo guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/best_model_tl_20250518_044626.pt
Época 15/20
Train Loss: 0.0842, Val Loss: 0.0957
Train Acc: 0.973, Val Acc: 0.966
Val F1-Score: 0.966
Tiempo: 862.20 segundos


Entrenando:   0%|          | 0/4000 [00:00<?, ?it/s]

Lotes procesados: 4000/4000, Lotes fallidos: 0


Evaluando:   0%|          | 0/750 [00:00<?, ?it/s]

Lotes procesados: 750/750, Lotes fallidos: 0
Checkpoint guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/checkpoint_tl_epoch_15_20250518_044626.pt
Mejor modelo guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/best_model_tl_20250518_044626.pt
Época 16/20
Train Loss: 0.0820, Val Loss: 0.0950
Train Acc: 0.973, Val Acc: 0.965
Val F1-Score: 0.965
Tiempo: 856.71 segundos


Entrenando:   0%|          | 0/4000 [00:00<?, ?it/s]

Lotes procesados: 4000/4000, Lotes fallidos: 0


Evaluando:   0%|          | 0/750 [00:00<?, ?it/s]

Lotes procesados: 750/750, Lotes fallidos: 0
Checkpoint guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/checkpoint_tl_epoch_16_20250518_044626.pt
Mejor modelo guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/best_model_tl_20250518_044626.pt
Época 17/20
Train Loss: 0.0800, Val Loss: 0.0929
Train Acc: 0.974, Val Acc: 0.967
Val F1-Score: 0.967
Tiempo: 860.10 segundos


Entrenando:   0%|          | 0/4000 [00:00<?, ?it/s]

Lotes procesados: 4000/4000, Lotes fallidos: 0


Evaluando:   0%|          | 0/750 [00:00<?, ?it/s]

Lotes procesados: 750/750, Lotes fallidos: 0
Checkpoint guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/checkpoint_tl_epoch_17_20250518_044626.pt
Mejor modelo guardado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/best_model_tl_20250518_044626.pt
Época 18/20
Train Loss: 0.0782, Val Loss: 0.0923
Train Acc: 0.975, Val Acc: 0.967
Val F1-Score: 0.967
Tiempo: 863.04 segundos


Entrenando:   0%|          | 0/4000 [00:00<?, ?it/s]

In [None]:
!pip install wandb albumentations retrying

In [None]:
# --- Sección 10: Fine-Tuning ---
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import torch
from torch.cuda.amp import autocast
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm

logger.info("Iniciando Fine-Tuning")
print("Iniciando Fine-Tuning")

# Refuse to start fine-tuning while the dataset check reports unreadable videos.
problematic_videos = validate_dataset(DATASET_PATH)
if problematic_videos:
    _problem_summary = f"Se detectaron {len(problematic_videos)} videos problemáticos. Revisa el log y corrige antes de continuar."
    logger.warning(_problem_summary)
    print(_problem_summary)
    raise ValueError("Corrige los videos problemáticos antes de continuar con Fine-Tuning.")

# Helper that strips the '_orig_mod.' marker from state_dict keys
def fix_state_dict(state_dict):
    """Return a new dict whose keys have every '_orig_mod.' occurrence removed.

    Values are carried over untouched and the input mapping is not modified.
    """
    return {
        name.replace("_orig_mod.", ""): tensor
        for name, tensor in state_dict.items()
    }

# Revised evaluate_epoch (shadows the definition with the same name from Section 8)
def evaluate_epoch(model, data_loader, criterion, device):
    """Evaluate ``model`` over ``data_loader`` and return per-sample class probabilities.

    Args:
        model: Model called as ``model(**inputs)``; its output must expose ``.logits``.
        data_loader: Iterable yielding ``(inputs, labels, duration)`` triples, with
            ``inputs`` a dict of tensors.
        criterion: Loss callable applied to ``(logits, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(avg_loss, metrics, all_labels, all_probs, mean_batch_time, durations)``.
        ``all_probs`` is the row-wise stack of the per-batch softmax outputs (one
        column per class), ``metrics`` contains ``"accuracy"`` and weighted
        ``"f1_score"``, and ``avg_loss`` is the summed batch loss divided by
        ``len(data_loader)``.
    """
    model.eval()
    running_loss = 0.0
    prob_chunks, label_chunks, duration_chunks = [], [], []
    batch_times = []

    with torch.no_grad():
        for batch_inputs, batch_labels, batch_duration in tqdm(data_loader, desc="Evaluando"):
            tick = time.time()
            batch_inputs = {name: tensor.to(device) for name, tensor in batch_inputs.items()}
            batch_labels = batch_labels.to(device)
            with autocast():
                logits = model(**batch_inputs).logits
                batch_loss = criterion(logits, batch_labels)
            running_loss += batch_loss.item()
            # Softmax over the class axis gives one probability per class and sample.
            prob_chunks.append(torch.softmax(logits, dim=1).cpu().numpy())
            label_chunks.append(batch_labels.cpu().numpy())
            batch_times.append(time.time() - tick)
            # Durations may arrive as a tensor or as a plain sequence.
            if isinstance(batch_duration, torch.Tensor):
                duration_chunks.append(batch_duration.numpy())
            else:
                duration_chunks.append(batch_duration)

    all_probs = np.vstack(prob_chunks)
    all_labels = np.concatenate(label_chunks)
    durations = np.concatenate(duration_chunks)
    avg_loss = running_loss / len(data_loader)

    predicted = np.argmax(all_probs, axis=1)
    metrics = {
        "accuracy": accuracy_score(all_labels, predicted),
        "f1_score": f1_score(all_labels, predicted, average='weighted'),
    }

    shape_message = f"Forma de all_probs en evaluate_epoch: {all_probs.shape}"
    logger.info(shape_message)
    print(shape_message)
    return avg_loss, metrics, all_labels, all_probs, np.mean(batch_times), durations

# Initialise the model with a fresh NUM_CLASSES-way classification head
model = TimesformerForVideoClassification.from_pretrained(
    MODEL_ID,
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True
)

# Load the best Transfer Learning weights.
# TIMESTAMP is pinned to the run whose checkpoint should be resumed from.
TIMESTAMP = "20250518_044626"
best_model_path = os.path.join(CHECKPOINT_PATH, f"best_model_tl_{TIMESTAMP}.pt")
if not os.path.exists(best_model_path):
    logger.error(f"No se encontró el mejor modelo TL: {best_model_path}")
    raise FileNotFoundError(f"No se encontró {best_model_path}")
try:
    state_dict = torch.load(best_model_path, map_location=device)
    # The file was saved from a torch.compile-wrapped model, so its keys carry
    # an "_orig_mod." prefix that the plain model does not expect.
    state_dict = fix_state_dict(state_dict)
    model.load_state_dict(state_dict, strict=True)
    logger.info(f"Mejor modelo TL cargado: {best_model_path}")
    print(f"Mejor modelo TL cargado: {best_model_path}")
except Exception as e:
    logger.error(f"Error al cargar el state_dict: {e}")
    raise RuntimeError(f"Error al cargar el state_dict: {e}")

# Unfreeze only the last three encoder layers and the classifier head
trainable_params = 0
total_params = 0
for name, param in model.named_parameters():
    if any(layer in name for layer in ["timesformer.encoder.layer.9", "timesformer.encoder.layer.10", "timesformer.encoder.layer.11", "classifier"]):
        param.requires_grad = True
        trainable_params += param.numel()
    else:
        param.requires_grad = False
    total_params += param.numel()
logger.info(f"Parámetros entrenables: {trainable_params} ({trainable_params/total_params*100:.2f}% del total)")
print(f"Parámetros entrenables: {trainable_params} ({trainable_params/total_params*100:.2f}% del total)")

# Add dropout in front of the classifier.
# The classifier head is a plain nn.Linear, so assigning a `dropout` attribute
# to it only registers an unused submodule: Linear.forward never calls it.
# A forward pre-hook applies dropout to the head's input instead, which keeps
# the state_dict keys ("classifier.weight"/"classifier.bias") unchanged.
CLASSIFIER_DROPOUT = 0.5


def _classifier_input_dropout(module, args):
    """Forward pre-hook: drop classifier input features while the module is in training mode."""
    dropped = torch.nn.functional.dropout(args[0], p=CLASSIFIER_DROPOUT, training=module.training)
    return (dropped,) + tuple(args[1:])


model.classifier.register_forward_pre_hook(_classifier_input_dropout)
logger.info("Dropout del clasificador aumentado a 0.5")
print("Dropout del clasificador aumentado a 0.5")

# Fine-tuning hyperparameters (WEIGHT_DECAY overrides any earlier value)
LEARNING_RATE_FT = 1e-5
WEIGHT_DECAY = 1e-2
NUM_EPOCHS_FT = 12

model = model.to(device)
# torch.compile is optional: fall back to the eager model if it is unsupported.
try:
    model = torch.compile(model)
    logger.info("Modelo compilado con torch.compile")
    print("Modelo compilado con torch.compile")
except Exception as e:
    logger.warning(f"torch.compile no soportado: {e}")
    print(f"torch.compile no soportado: {e}")

# Loss, optimizer (trainable parameters only), LR scheduler and AMP grad scaler
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=LEARNING_RATE_FT,
    weight_decay=WEIGHT_DECAY
)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)
scaler = GradScaler()

# Data augmentation pipeline, applied per frame by ViolenceDatasetAugmented.
# NOTE(review): RandomCrop to IMG_SIZE may reject frames smaller than IMG_SIZE
# in either dimension, depending on the albumentations version — confirm
# against the dataset's resolutions.
augmentation = A.Compose([
    A.Rotate(limit=10, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.RandomCrop(height=IMG_SIZE, width=IMG_SIZE, p=0.3),
    A.Resize(height=IMG_SIZE, width=IMG_SIZE),
    ToTensorV2()
])

# ViolenceDataset variant that applies the augmentation pipeline to each frame
class ViolenceDatasetAugmented(ViolenceDataset):
    """ViolenceDataset whose items are decoded from disk and augmented frame by frame.

    Relies on attributes read through ``self`` (``videos``, ``labels``,
    ``durations``, ``num_frames``, ``processor``), which are expected to be set
    up by ``ViolenceDataset``, and on the module-level ``augmentation`` pipeline.
    """

    def __getitem__(self, idx):
        """Return ``(inputs, label, duration)`` for the video at position ``idx``.

        The whole video is decoded, reduced to ``self.num_frames`` frames (evenly
        spaced indices, or padded by repeating the last frame when the video is
        shorter), each frame is passed through ``augmentation`` and the result is
        fed to ``self.processor``.

        Raises:
            ValueError: If no frame could be read from the video.
            Exception: Any other processing error is logged and re-raised.
        """
        # Bound up-front so the error handler can always reference it, even if
        # the lookup in self.videos is what failed.
        video_path = None
        try:
            video_path = self.videos[idx]
            label = self.labels[idx]
            cap = cv2.VideoCapture(video_path)
            frames = []
            try:
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(frame)
            finally:
                # Always free the decoder handle, even if a frame conversion fails.
                cap.release()
            if len(frames) == 0:
                logger.error(f"Video sin frames: {video_path}")
                raise ValueError(f"No se pudieron leer frames de {video_path}")
            if len(frames) < self.num_frames:
                # Too short: pad by repeating the final frame.
                frames = frames + [frames[-1]] * (self.num_frames - len(frames))
            else:
                # Sample num_frames evenly spaced frames across the clip.
                indices = np.linspace(0, len(frames) - 1, self.num_frames).astype(int)
                frames = [frames[i] for i in indices]
            # NOTE(review): every frame draws its own random augmentation
            # parameters, so transforms are not consistent across a clip.
            augmented_frames = []
            for frame in frames:
                augmented = augmentation(image=frame)
                # ToTensorV2 yields a channels-first tensor; convert back to an
                # HWC array before handing the frames to the processor.
                augmented_frames.append(augmented["image"].permute(1, 2, 0).numpy())
            inputs = self.processor(augmented_frames, return_tensors="pt")
            # Drop the leading batch dimension the processor adds.
            inputs = {k: v.squeeze(0) for k, v in inputs.items()}
            return inputs, torch.tensor(label, dtype=torch.long), self.durations[idx]
        except Exception as e:
            logger.error(f"Error al procesar video {video_path}: {e}")
            raise

# Rebuild the datasets: augmented frames for training, plain frames for validation
train_dataset = ViolenceDatasetAugmented(DATASET_PATH, "train", processor)
val_dataset = ViolenceDataset(DATASET_PATH, "val", processor)

# Training batches are shuffled every epoch
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)

# Validation keeps dataset order
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

# Misclassification report: counts, sample listings and frame previews
def analyze_errors(labels, preds, probs, videos, phase, epoch, save_path):
    """Summarise misclassified samples and save frame previews for the first few.

    Args:
        labels: True class per sample (indexable).
        preds: Predicted class per sample (indexable, same order as ``labels``).
        probs: Score per sample, shown with three decimals in the report.
        videos: Video path per sample, same order as ``labels``.
        phase: Label used in log messages, file names and W&B keys.
        epoch: Epoch number used in log messages, file names and W&B keys.
        save_path: Directory where preview images are written.

    Returns:
        Tuple ``(n_false_positives, n_false_negatives)``, where a false positive
        is true class 0 predicted as 1 and a false negative is the reverse.
    """
    def emit(message):
        # Mirror every message to both the logger and the notebook output.
        logger.info(message)
        print(message)

    errors = []
    for position in range(len(labels)):
        if labels[position] != preds[position]:
            errors.append((position, labels[position], preds[position], probs[position], videos[position]))
    emit(f"Número de errores en {phase} (Época {epoch}): {len(errors)}")

    # Split the errors by direction in a single pass.
    false_positives = []
    false_negatives = []
    for position, actual, guessed, score, path in errors:
        if actual == 0 and guessed == 1:
            false_positives.append((position, score, path))
        elif actual == 1 and guessed == 0:
            false_negatives.append((position, score, path))
    emit(f"Falsos Positivos: {len(false_positives)}")
    emit(f"Falsos Negativos: {len(false_negatives)}")

    # List at most three examples of each kind.
    for idx, prob, video_path in false_positives[:3]:
        emit(f"Falso Positivo - Índice: {idx}, Video: {video_path}, Probabilidad: {prob:.3f}")
    for idx, prob, video_path in false_negatives[:3]:
        emit(f"Falso Negativo - Índice: {idx}, Video: {video_path}, Probabilidad: {prob:.3f}")

    # For the first three errors, save a strip with up to four leading frames.
    for idx, true, pred, prob, video_path in errors[:3]:
        capture = cv2.VideoCapture(video_path)
        frames = []
        while capture.isOpened() and len(frames) < 4:
            ok, bgr_frame = capture.read()
            if not ok:
                break
            frames.append(cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB))
        capture.release()

        plt.figure(figsize=(12, 3))
        for i, frame in enumerate(frames):
            plt.subplot(1, 4, i+1)
            plt.imshow(frame)
            plt.axis('off')
            plt.title(f"Frame {i+1}")
        plt.suptitle(f"Error: Real={LABELS[true]}, Pred={LABELS[pred]}, Prob={prob:.3f}")
        image_file = os.path.join(save_path, f"error_{phase}_epoch_{epoch}_{idx}.png")
        plt.savefig(image_file)
        plt.close()
        wandb.log({f"error_{phase}_epoch_{epoch}_{idx}": wandb.Image(image_file)})

    return len(false_positives), len(false_negatives)

# Fine-tuning loop: train, validate, analyse errors, checkpoint, log, early-stop
history_ft = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": [], "val_f1": [], "fp": [], "fn": []}
best_val_loss = float("inf")
early_stop_counter = 0

for epoch in range(NUM_EPOCHS_FT):
    start_time = time.time()
    train_loss, train_metrics, train_labels, train_probs = train_epoch(
        model, train_loader, optimizer, criterion, scaler, device, ACCUMULATION_STEPS
    )
    val_loss, val_metrics, val_labels, val_probs, val_inference_time, val_durations = evaluate_epoch(
        model, val_loader, criterion, device
    )
    # Normalise val_probs to a 2-D NumPy array.
    logger.info(f"Tipo de val_probs: {type(val_probs)}")
    print(f"Tipo de val_probs: {type(val_probs)}")
    if isinstance(val_probs, torch.Tensor):
        val_probs = val_probs.cpu().numpy()
    elif isinstance(val_probs, list):
        val_probs = np.vstack([p.cpu().numpy() if isinstance(p, torch.Tensor) else p for p in val_probs])
    elif not isinstance(val_probs, np.ndarray):
        # np.ndarray is what the evaluate_epoch defined in this cell returns, so
        # it must be accepted here; anything else is a genuine surprise.
        logger.error(f"Tipo inesperado de val_probs: {type(val_probs)}")
        raise ValueError(f"Tipo inesperado de val_probs: {type(val_probs)}")
    if val_probs.ndim == 1:
        # A flat vector is treated like a single-column array (handled below).
        val_probs = val_probs.reshape(-1, 1)
    logger.info(f"Forma de val_probs: {val_probs.shape}")
    print(f"Forma de val_probs: {val_probs.shape}")
    logger.info(f"Primeras filas de val_probs: {val_probs[:5]}")
    print(f"Primeras filas de val_probs: {val_probs[:5]}")

    # Extract the positive-class column; a single column is assumed to already
    # be the positive-class probability.
    if val_probs.ndim == 2 and val_probs.shape[1] == 2:
        val_probs_positive = val_probs[:, 1]  # Probability of class index 1 ("violence")
    elif val_probs.ndim == 2 and val_probs.shape[1] == 1:
        logger.warning("val_probs tiene una sola columna. Asumiendo probabilidades de la clase positiva.")
        val_probs_positive = val_probs[:, 0]
        # Rebuild the two-class array: [P(no_violence), P(violence)]
        val_probs = np.hstack([1 - val_probs, val_probs])
    else:
        logger.error(f"Forma inesperada de val_probs: {val_probs.shape}")
        raise ValueError(f"Forma inesperada de val_probs: {val_probs.shape}")

    val_preds = (val_probs_positive > 0.5).astype(int)
    # val_loader is not shuffled, so sample i corresponds to val_dataset.videos[i].
    fp, fn = analyze_errors(
        val_labels, val_preds, val_probs_positive,
        [val_dataset.videos[i] for i in range(len(val_labels))], "Validación FT", epoch, CHECKPOINT_PATH
    )
    history_ft["train_loss"].append(train_loss)
    history_ft["val_loss"].append(val_loss)
    history_ft["train_acc"].append(train_metrics["accuracy"])
    history_ft["val_acc"].append(val_metrics["accuracy"])
    history_ft["val_f1"].append(val_metrics["f1_score"])
    history_ft["fp"].append(fp)
    history_ft["fn"].append(fn)
    scheduler.step(val_loss)
    # Full training state is written every epoch.
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "scaler_state_dict": scaler.state_dict(),
        "val_loss": val_loss,
        "val_metrics": val_metrics
    }
    checkpoint_path = os.path.join(CHECKPOINT_PATH, f"checkpoint_ft_epoch_{epoch}_{TIMESTAMP}.pt")
    torch.save(checkpoint, checkpoint_path)
    logger.info(f"Checkpoint guardado: {checkpoint_path}")
    print(f"Checkpoint guardado: {checkpoint_path}")
    # Keep a single "best" weights file, overwritten whenever validation loss improves.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_path = os.path.join(CHECKPOINT_PATH, f"best_model_ft_{TIMESTAMP}.pt")
        torch.save(model.state_dict(), best_model_path)
        logger.info(f"Mejor modelo guardado: {best_model_path}")
        print(f"Mejor modelo guardado: {best_model_path}")
        early_stop_counter = 0
    else:
        early_stop_counter += 1
    logger.info(f"Época {epoch+1}/{NUM_EPOCHS_FT}")
    print(f"Época {epoch+1}/{NUM_EPOCHS_FT}")
    logger.info(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    logger.info(f"Train Acc: {train_metrics['accuracy']:.3f}, Val Acc: {val_metrics['accuracy']:.3f}")
    print(f"Train Acc: {train_metrics['accuracy']:.3f}, Val Acc: {val_metrics['accuracy']:.3f}")
    logger.info(f"Val F1-Score: {val_metrics['f1_score']:.3f}")
    print(f"Val F1-Score: {val_metrics['f1_score']:.3f}")
    logger.info(f"Tiempo: {time.time() - start_time:.2f} segundos")
    print(f"Tiempo: {time.time() - start_time:.2f} segundos")
    logger.info(f"Tiempo de inferencia promedio (validación): {val_inference_time:.4f} segundos")
    print(f"Tiempo de inferencia promedio (validación): {val_inference_time:.4f} segundos")
    wandb.log({
        "epoch": epoch,
        "train_loss": train_loss,
        "val_loss": val_loss,
        "train_acc": train_metrics["accuracy"],
        "val_acc": val_metrics["accuracy"],
        "val_f1": val_metrics["f1_score"],
        "val_inference_time": val_inference_time,
        "false_positives": fp,
        "false_negatives": fn
    })
    plot_metrics(val_metrics, "Validación FT", epoch, CHECKPOINT_PATH)
    plot_probability_histogram(val_labels, val_probs_positive, f"Validación FT Epoch {epoch}", CHECKPOINT_PATH)
    plot_prob_vs_duration(val_durations, val_probs_positive, val_labels, f"Validación FT Epoch {epoch}", CHECKPOINT_PATH)
    # Stop once validation loss has failed to improve for PATIENCE consecutive epochs.
    if early_stop_counter >= PATIENCE:
        logger.info("Early Stopping activado")
        print("Early Stopping activado")
        break

plot_learning_curves(history_ft, "Fine-Tuning", CHECKPOINT_PATH)


Iniciando Fine-Tuning
train/no_violence: 4000 videos
train/violence: 4000 videos
val/no_violence: 750 videos
val/violence: 750 videos
test/no_violence: 400 videos
test/violence: 400 videos
Total de videos validados: 10300
Distribución de clases: {'no_violence': 5150, 'violence': 5150}
Videos problemáticos detectados: 0


Some weights of TimesformerForVideoClassification were not initialized from the model checkpoint at facebook/timesformer-base-finetuned-k400 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([400, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([400]) in the checkpoint and torch.Size([2]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Mejor modelo TL cargado: /content/drive/MyDrive/Modelo-Violencia/timesformer_checkpoints/best_model_tl_20250518_044626.pt
Parámetros entrenables: 30128642 (24.85% del total)
Dropout del clasificador aumentado a 0.5
Modelo compilado con torch.compile


Entrenando:  16%|█▋        | 655/4000 [02:31<10:55,  5.10it/s]

In [None]:
# --- Section 11: Evaluation on the test set ---
# Loads the best fine-tuned weights, evaluates them on the "test" split,
# reports the metrics and sends the plots / classification report to W&B.
logger.info("Evaluando modelo final en conjunto de prueba")
print("Evaluando modelo final en conjunto de prueba")
best_model_path = os.path.join(CHECKPOINT_PATH, f"best_model_ft_{TIMESTAMP}.pt")
if not os.path.exists(best_model_path):
    logger.error(f"No se encontró el mejor modelo FT: {best_model_path}")
    raise FileNotFoundError(f"No se encontró {best_model_path}")
try:
    state_dict = torch.load(best_model_path, map_location=device)
    state_dict = fix_state_dict(state_dict)
    model.load_state_dict(state_dict, strict=True)
    logger.info(f"Mejor modelo FT cargado: {best_model_path}")
    print(f"Mejor modelo FT cargado: {best_model_path}")
except Exception as e:
    logger.error(f"Error al cargar el state_dict: {e}")
    # Chain explicitly so the original traceback is kept as the direct cause.
    raise RuntimeError(f"Error al cargar el state_dict: {e}") from e

test_dataset = ViolenceDataset(DATASET_PATH, "test", processor)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # keep dataset order so predictions line up with test_dataset.videos
    num_workers=0,
    pin_memory=True
)
test_loss, test_metrics, test_labels, test_probs, test_inference_time, test_durations = evaluate_epoch(
    model, test_loader, criterion, device
)

# Normalise test_probs to a NumPy array, whatever container evaluate_epoch returned.
if isinstance(test_probs, list):
    test_probs = np.vstack([p.cpu().numpy() if isinstance(p, torch.Tensor) else p for p in test_probs])
elif isinstance(test_probs, torch.Tensor):
    test_probs = test_probs.cpu().numpy()
test_probs = np.asarray(test_probs)
logger.info(f"Forma de test_probs: {test_probs.shape}")
print(f"Forma de test_probs: {test_probs.shape}")

# Extract the positive-class probability and keep test_probs as an (N, 2) array.
if test_probs.ndim == 2 and test_probs.shape[1] == 2:
    test_probs_positive = test_probs[:, 1]
elif test_probs.ndim == 2 and test_probs.shape[1] == 1:
    logger.warning("test_probs tiene una sola columna. Asumiendo probabilidades de la clase positiva.")
    test_probs_positive = test_probs[:, 0]
    test_probs = np.hstack([1 - test_probs, test_probs])
elif test_probs.ndim == 1:
    # A flat vector is treated as positive-class probabilities, as the
    # duplicate evaluation cell further down in this notebook does.
    logger.warning("test_probs tiene una sola columna. Asumiendo probabilidades de la clase positiva.")
    test_probs_positive = test_probs
    test_probs = np.column_stack([1 - test_probs, test_probs])
else:
    logger.error(f"Forma inesperada de test_probs: {test_probs.shape}")
    raise ValueError(f"Forma inesperada de test_probs: {test_probs.shape}")

test_preds = (test_probs_positive > 0.5).astype(int)
fp, fn = analyze_errors(
    test_labels, test_preds, test_probs_positive,
    [test_dataset.videos[i] for i in range(len(test_labels))], "Prueba", "Final", CHECKPOINT_PATH
)

logger.info("Resultados en conjunto de prueba:")
print("Resultados en conjunto de prueba:")
logger.info(f"Pérdida: {test_loss:.4f}")
print(f"Pérdida: {test_loss:.4f}")
logger.info(f"Exactitud: {test_metrics['accuracy']:.3f}")
print(f"Exactitud: {test_metrics['accuracy']:.3f}")
logger.info(f"F1-Score: {test_metrics['f1_score']:.3f}")
print(f"F1-Score: {test_metrics['f1_score']:.3f}")
logger.info(f"Tiempo de inferencia promedio: {test_inference_time:.4f} segundos")
print(f"Tiempo de inferencia promedio: {test_inference_time:.4f} segundos")
logger.info(f"Falsos Positivos: {fp}")
print(f"Falsos Positivos: {fp}")
logger.info(f"Falsos Negativos: {fn}")
print(f"Falsos Negativos: {fn}")

plot_metrics(test_metrics, "Prueba", "Final", CHECKPOINT_PATH)
plot_probability_histogram(test_labels, test_probs_positive, "Prueba Final", CHECKPOINT_PATH)
plot_prob_vs_duration(test_durations, test_probs_positive, test_labels, "Prueba Final", CHECKPOINT_PATH)

# classification_report is already imported in the notebook's import cell.
report = classification_report(test_labels, test_preds, target_names=LABELS)
print("\nInforme de Clasificación:")
print(report)
# The report is fixed-width plain text; <pre> keeps its column alignment when
# W&B renders it as HTML.
wandb.log({"test_classification_report": wandb.Html(f"<pre>{report}</pre>")})

In [None]:
# --- Section 11: Evaluation on the test set ---
# NOTE(review): this cell repeats the previous "Sección 11" cell and relies on
# the `test_loader` / `test_dataset` created there; consider keeping only one.
logger.info("Evaluando modelo final en conjunto de prueba")
print("Evaluando modelo final en conjunto de prueba")
best_model_path = os.path.join(CHECKPOINT_PATH, f"best_model_ft_{TIMESTAMP}.pt")
# Load the same way as the previous evaluation cell: map the tensors onto the
# active device and pass the state dict through fix_state_dict before loading.
state_dict = fix_state_dict(torch.load(best_model_path, map_location=device))
model.load_state_dict(state_dict)
logger.info(f"Mejor modelo FT cargado: {best_model_path}")
print(f"Mejor modelo FT cargado: {best_model_path}")

test_loss, test_metrics, test_labels, test_probs, test_inference_time, test_durations = evaluate_epoch(
    model, test_loader, criterion, device
)
# evaluate_epoch may hand back a list or a tensor; `.ndim` and column slicing
# below need a NumPy array.
if isinstance(test_probs, list):
    test_probs = np.vstack([p.cpu().numpy() if isinstance(p, torch.Tensor) else p for p in test_probs])
elif isinstance(test_probs, torch.Tensor):
    test_probs = test_probs.cpu().numpy()
test_probs = np.asarray(test_probs)

# Positive-class ("violence") probability, whether probs are (N, 2) or flat.
test_probs_positive = test_probs[:, 1] if test_probs.ndim > 1 else test_probs
test_preds = np.argmax(test_probs, axis=1) if test_probs.ndim > 1 else (test_probs > 0.5).astype(int)
fp, fn = analyze_errors(test_labels, test_preds, test_probs_positive,
                       [test_dataset.videos[i] for i in range(len(test_labels))], "Prueba", "Final", CHECKPOINT_PATH)

logger.info("Resultados en conjunto de prueba:")
print("Resultados en conjunto de prueba:")
logger.info(f"Pérdida: {test_loss:.4f}")
print(f"Pérdida: {test_loss:.4f}")
logger.info(f"Exactitud: {test_metrics['accuracy']:.3f}")
print(f"Exactitud: {test_metrics['accuracy']:.3f}")
logger.info(f"Precisión: {test_metrics['precision']:.3f}")
print(f"Precisión: {test_metrics['precision']:.3f}")
logger.info(f"Sensibilidad: {test_metrics['recall']:.3f}")
print(f"Sensibilidad: {test_metrics['recall']:.3f}")
logger.info(f"Especificidad: {test_metrics['specificity']:.3f}")
print(f"Especificidad: {test_metrics['specificity']:.3f}")
logger.info(f"F1-Score: {test_metrics['f1_score']:.3f}")
print(f"F1-Score: {test_metrics['f1_score']:.3f}")
logger.info(f"AUC: {test_metrics['auc']:.3f}")
print(f"AUC: {test_metrics['auc']:.3f}")
logger.info(f"Tiempo de inferencia promedio: {test_inference_time:.4f} segundos")
print(f"Tiempo de inferencia promedio: {test_inference_time:.4f} segundos")
logger.info(f"Falsos Positivos: {fp}")
print(f"Falsos Positivos: {fp}")
logger.info(f"Falsos Negativos: {fn}")
print(f"Falsos Negativos: {fn}")

plot_metrics(test_metrics, "Prueba", "Final", CHECKPOINT_PATH)
plot_probability_histogram(test_labels, test_probs_positive, "Prueba Final", CHECKPOINT_PATH)
plot_prob_vs_duration(test_durations, test_probs_positive, test_labels, "Prueba Final", CHECKPOINT_PATH)

print("\nInforme de Clasificación:")
report = classification_report(test_labels, test_preds, target_names=LABELS)
print(report)
# The report is fixed-width plain text; <pre> keeps its column alignment when
# W&B renders it as HTML.
wandb.log({"test_classification_report": wandb.Html(f"<pre>{report}</pre>")})

# --- Section 12: Dynamic quantization ---
# Builds an int8 dynamically-quantized copy of the model, saves its weights and
# measures its per-batch latency on the test loader.
import copy

logger.info("Aplicando cuantización dinámica para optimizar latencia")
print("Aplicando cuantización dinámica para optimizar latencia")

# Quantize a CPU *copy*: calling `model.cpu()` would move the live model off the
# GPU in place and break every later cell that still sends inputs to `device`.
# If the model was wrapped by torch.compile, the wrapped module is exposed as
# `_orig_mod`; quantize that one so the Linear layers are swapped in the real
# network.
quantization_source = getattr(model, "_orig_mod", model)
cpu_model_copy = copy.deepcopy(quantization_source).to("cpu").eval()
quantized_model = torch.quantization.quantize_dynamic(
    cpu_model_copy, {torch.nn.Linear}, dtype=torch.qint8
)
quantized_model_path = os.path.join(CHECKPOINT_PATH, f"quantized_model_ft_{TIMESTAMP}.pt")
torch.save(quantized_model.state_dict(), quantized_model_path)
logger.info(f"Modelo cuantizado guardado: {quantized_model_path}")
print(f"Modelo cuantizado guardado: {quantized_model_path}")

# Measure the latency of the quantized model.
# Dynamically quantized Linear layers only have CPU kernels, so the model and
# its inputs stay on the CPU and CUDA autocast is not used here.
quantized_model.eval()
test_inference_times = []
with torch.no_grad():
    for inputs, _, _ in tqdm(test_loader, desc="Evaluando modelo cuantizado"):
        # perf_counter is monotonic; a dedicated name avoids clobbering the
        # `start_time` used by the training cells.
        batch_start = time.perf_counter()
        inputs = {k: v.to("cpu") for k, v in inputs.items()}
        _ = quantized_model(**inputs)
        test_inference_times.append(time.perf_counter() - batch_start)
avg_quantized_inference_time = float(np.mean(test_inference_times))
logger.info(f"Tiempo de inferencia promedio (cuantizado): {avg_quantized_inference_time:.4f} segundos")
print(f"Tiempo de inferencia promedio (cuantizado): {avg_quantized_inference_time:.4f} segundos")
wandb.log({"quantized_inference_time": avg_quantized_inference_time})

# 0.2 s per batch is the notebook's threshold for "real time".
if avg_quantized_inference_time > 0.2:
    logger.warning("Latencia cuantizada aún alta para tiempo real. Considerar optimización adicional.")
    print("Latencia cuantizada aún alta para tiempo real. Considerar optimización adicional.")

In [None]:
# --- Section 9: Final evaluation on the test set ---
# Loads the best fine-tuned weights, evaluates on the test loader, logs the
# metrics and publishes the classification report.
logger.info("Evaluando modelo final en conjunto de prueba")
best_model_path = os.path.join(CHECKPOINT_PATH, f"best_model_ft_{TIMESTAMP}.pt")
# Same loading sequence as the other evaluation cell in this notebook:
# map onto the active device and run the state dict through fix_state_dict.
state_dict = fix_state_dict(torch.load(best_model_path, map_location=device))
model.load_state_dict(state_dict)
logger.info(f"Mejor modelo FT cargado: {best_model_path}")

# The other evaluation cells unpack a sixth value (durations) from
# evaluate_epoch; star-unpacking keeps this cell working with either arity.
test_loss, test_metrics, test_labels, test_probs, test_inference_time, *_extra_outputs = evaluate_epoch(
    model, test_loader, criterion, device
)

# Normalise the probabilities to a NumPy array and derive per-sample
# predictions from them. The confusion matrix is a per-class summary, so its
# argmax cannot be used as per-sample predictions.
if isinstance(test_probs, list):
    test_probs = np.vstack([p.cpu().numpy() if isinstance(p, torch.Tensor) else p for p in test_probs])
elif isinstance(test_probs, torch.Tensor):
    test_probs = test_probs.cpu().numpy()
test_probs = np.asarray(test_probs)
# Last column is the positive class for both (N, 2) and (N, 1) layouts.
test_probs_positive = test_probs[:, -1] if test_probs.ndim == 2 else test_probs
test_preds = (test_probs_positive > 0.5).astype(int)

logger.info("Resultados en conjunto de prueba:")
logger.info(f"Pérdida: {test_loss:.4f}")
logger.info(f"Exactitud: {test_metrics['accuracy']:.3f}")
logger.info(f"Precisión: {test_metrics['precision']:.3f}")
logger.info(f"Sensibilidad: {test_metrics['recall']:.3f}")
logger.info(f"Especificidad: {test_metrics['specificity']:.3f}")
logger.info(f"F1-Score: {test_metrics['f1_score']:.3f}")
logger.info(f"AUC: {test_metrics['auc']:.3f}")
logger.info(f"Tiempo de inferencia promedio: {test_inference_time:.4f} segundos")

plot_metrics(test_metrics, "Prueba", "Final", CHECKPOINT_PATH)
plot_probability_histogram(test_labels, test_probs_positive, "Prueba Final", CHECKPOINT_PATH)

# Detailed report
print("\nInforme de Clasificación:")
report = classification_report(test_labels, test_preds, target_names=LABELS)
print(report)
# wandb.Html expects an HTML string (not the dict produced by output_dict=True);
# <pre> preserves the report's fixed-width layout.
wandb.log({"test_classification_report": wandb.Html(f"<pre>{report}</pre>")})


In [None]:
# --- Section 10: Detailed analysis ---
# Error inspection, duration-vs-correctness plot, class-balance check, model
# comparison and latency check for the test-set results.

# Per-sample predictions must come from the probabilities: the confusion matrix
# only has one row per class, so indexing its argmax by sample index is wrong.
analysis_probs = test_probs
if isinstance(analysis_probs, list):
    analysis_probs = np.vstack([p.cpu().numpy() if isinstance(p, torch.Tensor) else p for p in analysis_probs])
elif isinstance(analysis_probs, torch.Tensor):
    analysis_probs = analysis_probs.cpu().numpy()
analysis_probs = np.asarray(analysis_probs)
# Last column is the positive class for both (N, 2) and (N, 1) layouts.
positive_probs = analysis_probs[:, -1] if analysis_probs.ndim == 2 else analysis_probs
analysis_labels = np.asarray(test_labels).astype(int)
analysis_preds = (positive_probs > 0.5).astype(int)

# Error analysis
max_errors_shown = 5   # how many misclassified videos to visualise
frames_per_error = 4   # how many leading frames to show per video
errors = [
    (int(i), int(analysis_labels[i]), int(analysis_preds[i]), float(positive_probs[i]), test_dataset.videos[i])
    for i in np.flatnonzero(analysis_labels != analysis_preds)
]
logger.info(f"Número de errores: {len(errors)}")
for idx, true, pred, prob, video_path in errors[:max_errors_shown]:
    logger.info(f"Índice: {idx}, Video: {video_path}, Real: {LABELS[true]}, Predicho: {LABELS[pred]}, Probabilidad: {prob:.3f}")

    # Grab the first frames of the misclassified video.
    cap = cv2.VideoCapture(str(video_path))
    frames = []
    try:
        while cap.isOpened() and len(frames) < frames_per_error:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        cap.release()

    fig = plt.figure(figsize=(12, 3))
    for frame_idx, frame in enumerate(frames):
        plt.subplot(1, frames_per_error, frame_idx + 1)
        plt.imshow(frame)
        plt.axis('off')
        plt.title(f"Frame {frame_idx+1}")
    plt.suptitle(f"Error: Real={LABELS[true]}, Pred={LABELS[pred]}")
    error_plot_path = os.path.join(CHECKPOINT_PATH, f"error_{idx}.png")
    plt.savefig(error_plot_path)
    plt.show()
    plt.close(fig)  # figures created in a loop otherwise accumulate in memory
    wandb.log({f"error_{idx}": wandb.Image(error_plot_path)})

# Video duration vs. performance
# NOTE(review): assumes test_dataset.durations is ordered like the evaluation
# outputs (the test loader is not shuffled) — confirm against ViolenceDataset.
correct = (analysis_labels == analysis_preds).astype(int)
fig = plt.figure(figsize=(8, 6))
plt.scatter(test_dataset.durations, correct, alpha=0.5)
plt.xlabel("Duración del Video (segundos)")
plt.ylabel("Correcto (1) / Incorrecto (0)")
plt.title("Duración de Videos vs. Rendimiento")
duration_plot_path = os.path.join(CHECKPOINT_PATH, "duration_vs_performance.png")
plt.savefig(duration_plot_path)
plt.show()
plt.close(fig)
wandb.log({"duration_vs_performance": wandb.Image(duration_plot_path)})

# Class-bias analysis
# minlength guarantees one count per label even if a class is absent.
class_counts = np.bincount(analysis_labels, minlength=len(LABELS))
logger.info(f"Distribución de clases en prueba: {dict(zip(LABELS, class_counts.tolist()))}")
total_samples = int(class_counts.sum())
if total_samples > 0 and abs(int(class_counts[0]) - int(class_counts[1])) / total_samples > 0.2:
    logger.warning("Posible desbalance de clases detectado")

# Model comparison
logger.info("Comparación de Modelos:")
logger.info(f"Mejor Exactitud TL (Validación): {max(history_tl['val_acc']):.3f}")
logger.info(f"Mejor F1 TL (Validación): {max(history_tl['val_f1']):.3f}")
logger.info(f"Mejor Exactitud FT (Validación): {max(history_ft['val_acc']):.3f}")
logger.info(f"Mejor F1 FT (Validación): {max(history_ft['val_f1']):.3f}")
logger.info(f"Exactitud Final (Prueba): {test_metrics['accuracy']:.3f}")
logger.info(f"F1-Score Final (Prueba): {test_metrics['f1_score']:.3f}")

# Latency analysis (0.2 s is the notebook's threshold for "real time")
logger.info(f"Latencia promedio de inferencia (prueba): {test_inference_time:.4f} segundos")
if test_inference_time > 0.2:
    logger.warning("Latencia alta para tiempo real. Considerar optimización adicional.")

In [None]:
# --- Section 11: Model export ---
# Saves the final weights in PyTorch format, re-loads them as a sanity check,
# exports the network to ONNX and compares ONNX Runtime output with PyTorch.
logger.info("Exportando modelo final")

# If the model was wrapped by torch.compile, the wrapped module is exposed as
# `_orig_mod`. Saving/exporting that module keeps the state-dict keys free of
# the wrapper prefix and gives the ONNX tracer a plain nn.Module.
export_model = getattr(model, "_orig_mod", model)
# Use the device the weights actually live on rather than assuming `device`.
export_device = next(export_model.parameters()).device

# Save in PyTorch format
final_model_path = os.path.join(CHECKPOINT_PATH, f"final_model_{TIMESTAMP}.pt")
torch.save(export_model.state_dict(), final_model_path)
logger.info(f"Modelo PyTorch guardado: {final_model_path}")

# Validate that the saved weights load back
try:
    export_model.load_state_dict(torch.load(final_model_path, map_location=export_device))
    logger.info("Modelo PyTorch cargado correctamente para validación")
except Exception as e:
    logger.error(f"Error al cargar modelo PyTorch: {e}")
    raise

# Export to ONNX
export_model.eval()
dummy_input = torch.randn(1, NUM_FRAMES, 3, IMG_SIZE, IMG_SIZE, device=export_device)
onnx_path = os.path.join(CHECKPOINT_PATH, f"final_model_{TIMESTAMP}.onnx")
try:
    # Reference logits to compare against the ONNX Runtime result below.
    with torch.no_grad():
        reference_logits = export_model(dummy_input).logits.cpu().numpy()

    torch.onnx.export(
        export_model,
        (dummy_input,),
        onnx_path,
        export_params=True,
        opset_version=12,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
    )
    logger.info(f"Modelo exportado a ONNX: {onnx_path}")

    # Validate the ONNX model: structural check, then a numerical comparison.
    onnx.checker.check_model(onnx_path)
    # Providers are passed explicitly (GPU builds of ONNX Runtime expect this);
    # the CPU provider is always available and is enough for a sanity check.
    ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    ort_inputs = {ort_session.get_inputs()[0].name: dummy_input.cpu().numpy()}
    ort_outs = ort_session.run(None, ort_inputs)
    max_abs_diff = float(np.max(np.abs(reference_logits - ort_outs[0])))
    if max_abs_diff > 1e-3:
        logger.warning(f"Salida ONNX difiere de PyTorch (diferencia máxima: {max_abs_diff:.6f})")
    else:
        logger.info("Modelo ONNX validado correctamente")
except Exception as e:
    logger.error(f"Error al exportar o validar ONNX: {e}")
    raise

In [None]:
# --- Section 12: Inference smoke test ---
# Runs the final model on the first test video and logs the predicted label
# together with the probability of the "violence" class (index 1).
logger.info("Realizando prueba de inferencia con modelo final")
test_video = test_dataset.videos[0]
cap = cv2.VideoCapture(str(test_video))
frames = []
try:
    # NOTE(review): this reads the first NUM_FRAMES consecutive frames; confirm
    # it matches the frame sampling used by ViolenceDataset during training.
    while cap.isOpened() and len(frames) < NUM_FRAMES:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
finally:
    cap.release()

# Without this guard an unreadable video fails below with an opaque
# IndexError on frames[-1].
if not frames:
    logger.error(f"No se pudieron leer frames del video: {test_video}")
    raise RuntimeError(f"No se pudieron leer frames del video: {test_video}")
# Pad short clips by repeating the last frame.
if len(frames) < NUM_FRAMES:
    frames.extend([frames[-1]] * (NUM_FRAMES - len(frames)))
inputs = processor(frames, return_tensors="pt")
# Send inputs to wherever the weights currently live, so the cell still works
# if an earlier cell moved the model to another device.
inference_device = next(model.parameters()).device
inputs = {k: v.to(inference_device) for k, v in inputs.items()}

model.eval()
with torch.no_grad():
    outputs = model(**inputs)
    probs = torch.softmax(outputs.logits, dim=1)
    pred = torch.argmax(probs, dim=1).item()
    prob_violence = probs[0, 1].item()

logger.info(f"Prueba de inferencia: Video={test_video}, Predicción={LABELS[pred]}, Probabilidad de Violencia={prob_violence:.3f}")


In [None]:
# --- Section 13: Conclusions ---
# Final summary: logs where each artifact was written and closes the W&B run.
logger.info("Entrenamiento completado")

# Log the best-weights path of each training stage (transfer learning, then
# fine-tuning), in that order.
for stage_name, stage_tag in (("TL", "tl"), ("FT", "ft")):
    stage_weights_path = os.path.join(CHECKPOINT_PATH, f"best_model_{stage_tag}_{TIMESTAMP}.pt")
    logger.info(f"Mejor modelo {stage_name}: {stage_weights_path}")

logger.info(f"Modelo final para despliegue: {final_model_path} y {onnx_path}")
logger.info("Revisar WandB y Google Drive para análisis detallado")

wandb.finish()