In [28]:
import json
import os
from sklearn.model_selection import KFold
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from PIL import Image
import cv2

In [29]:
# Lookup table from position label (as it appears in the annotations) to an
# integer class index. Indices are assigned by enumerating the names below in
# order, starting at 0.
position_to_index = {
    position_name: class_index
    for class_index, position_name in enumerate((
        '5050_guard',
        'back1', 'back2',
        'closed_guard1', 'closed_guard2',
        'half_guard1', 'half_guard2',
        'mount1', 'mount2',
        'open_guard1', 'open_guard2',
        'side_control1', 'side_control2',
        'standing',
        'takedown1', 'takedown2',
        'turtle1', 'turtle2',
    ))
}

# Image preprocessing pipeline: resize every image to 256x256, then convert it
# to a tensor.
# NOTE(review): no mean/std normalisation step is applied here — confirm that
# this is intended for the backbone the images are fed to.
data_transforms = transforms.Compose(
    [transforms.Resize((256, 256)), transforms.ToTensor()]
)


In [30]:
def load_annotations(file_path):
    """
    Load the preprocessed annotations from a JSON file.

    Args:
        file_path: Path to the JSON annotations file.

    Returns:
        The decoded JSON document (whatever top-level structure the file holds).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # JSON text is UTF-8; decode it explicitly instead of relying on the
    # platform's locale encoding, which may differ (e.g. cp1252 on Windows).
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


In [31]:
class BJJDataset(Dataset):
    """
    Image-classification dataset backed by a list of annotation records.

    Each record is indexed with the keys 'Image' (file name without the
    '.jpg' extension) and 'Position' (a label looked up in
    ``position_to_index``).
    """

    def __init__(self, annotations, image_dir, transform=None):
        """
        Store the annotations, the image directory and the optional transform.

        Args:
            annotations: Indexable collection of annotation records.
            image_dir: Directory that contains the '<Image>.jpg' files.
            transform: Optional callable applied to the loaded PIL image.
        """
        self.annotations = annotations
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """
        Load one sample.

        Args:
            idx: Index into ``annotations``.

        Returns:
            Tuple ``(image, label)``: the (optionally transformed) RGB image
            and the class index as a 0-dim ``torch.long`` tensor.

        Raises:
            KeyError: If the record's 'Position' is not in ``position_to_index``.
            OSError: If the image file cannot be opened.
        """
        record = self.annotations[idx]
        image_path = os.path.join(self.image_dir, record['Image'] + '.jpg')

        # Image.open is lazy and keeps the file handle open; the context
        # manager closes it once the pixel data has been copied by convert().
        # Converting to RGB guarantees three channels even for grayscale,
        # palette or RGBA files, so every sample has the same channel count.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        label = position_to_index[record['Position']]

        if self.transform is not None:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)

In [32]:
def initialize_model(num_classes, device):
    """
    Build an ImageNet-pretrained ResNet-18 with a new classification head.

    Args:
        num_classes: Number of output units of the replacement ``fc`` layer.
        device: Device the model is moved to.

    Returns:
        The ResNet-18 model, with ``fc`` replaced, on ``device``.
    """
    # The `pretrained=True` flag is deprecated in newer torchvision releases
    # in favour of the `weights=` enum API. Use the enum when it exists and
    # fall back to the legacy flag on older installations. IMAGENET1K_V1 is
    # pinned (rather than DEFAULT) so the loaded checkpoint stays the one the
    # legacy flag selected.
    weights_enum = getattr(models, 'ResNet18_Weights', None)
    if weights_enum is not None:
        model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = models.resnet18(pretrained=True)

    # Swap the final fully connected layer for one sized to our label set.
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes)
    return model.to(device)

In [33]:

def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """
    Run one training epoch over ``train_loader``.

    Args:
        model: Model to train; switched to training mode.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the per-sample mean loss as a float
        and the accuracy as a 0-dim float64 tensor.

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    # Kept as a tensor so the accuracy has the same type for any batch count.
    correct = torch.zeros((), dtype=torch.long, device=device)
    # Count the samples actually processed: with `drop_last=True` or a
    # sampler this differs from len(train_loader.dataset), which would skew
    # both averages.
    samples_seen = 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        samples_seen += batch_size
        # The loss is weighted by batch size so a smaller final batch does
        # not distort the epoch average.
        running_loss += loss.item() * batch_size
        _, preds = torch.max(outputs, 1)
        correct += torch.sum(preds == labels)

    if samples_seen == 0:
        raise ValueError("train_loader produced no samples")

    return running_loss / samples_seen, correct.double() / samples_seen

In [34]:
def validate_one_epoch(model, val_loader, criterion, device):
    """
    Run one validation pass over ``val_loader`` and print summary metrics.

    Args:
        model: Model to evaluate; switched to evaluation mode.
        val_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(val_loss, val_acc)``: the per-sample mean loss as a float
        and the accuracy as a 0-dim float64 tensor.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()
    val_loss = 0.0
    # Kept as a tensor so the accuracy has the same type for any batch count.
    val_correct = torch.zeros((), dtype=torch.long, device=device)
    # Count the samples actually processed instead of trusting
    # len(val_loader.dataset), which is wrong with drop_last or samplers.
    samples_seen = 0
    y_true = []
    y_pred = []

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            batch_size = inputs.size(0)
            samples_seen += batch_size
            val_loss += loss.item() * batch_size
            _, preds = torch.max(outputs, 1)
            val_correct += torch.sum(preds == labels)

            y_true.extend(labels.cpu().tolist())
            y_pred.extend(preds.cpu().tolist())

    if samples_seen == 0:
        raise ValueError("val_loader produced no samples")

    val_loss /= samples_seen
    val_acc = val_correct.double() / samples_seen

    # Additional metrics.
    # NOTE(review): these are regression metrics computed on class indices, so
    # their values depend on how labels were numbered — confirm they are
    # wanted for a classification task.
    mae = mean_absolute_error(y_true, y_pred)
    mse = mean_squared_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    std_ae = np.std(np.abs(np.array(y_true) - np.array(y_pred)))

    print(f"Validation Metrics -> Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}, "
          f"MAE: {mae:.4f}, MSE: {mse:.4f}, R²: {r2:.4f}, Std_AE: {std_ae:.4f}")

    return val_loss, val_acc


In [35]:
def train_model_kfold(dataset, num_classes, k_folds=5, num_epochs=25, batch_size=32, early_stopping_patience=5):
    """
    Train a fresh model per fold using K-Fold cross-validation.

    For every fold a new model is initialised, trained for up to
    ``num_epochs`` epochs with Adam and a StepLR schedule, and validated after
    each epoch. The weights with the lowest validation loss of a fold are
    saved to 'model_fold_<n>.pth' in the current working directory.

    Args:
        dataset: Map-style dataset supporting ``len()`` and integer indexing.
        num_classes: Number of classes passed to ``initialize_model``.
        k_folds: Number of folds.
        num_epochs: Maximum number of epochs per fold.
        batch_size: Batch size for both data loaders.
        early_stopping_patience: Consecutive epochs without a new best
            validation loss after which a fold is stopped.

    Returns:
        List with the best validation loss reached in each fold, in fold order.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    kf = KFold(n_splits=k_folds, shuffle=True, random_state=42)
    best_val_losses = []

    # Split an explicit index array: only the sample count matters for the
    # split, and this avoids handing the dataset object itself to scikit-learn.
    sample_indices = np.arange(len(dataset))

    for fold, (train_idx, val_idx) in enumerate(kf.split(sample_indices)):
        print(f'Fold {fold+1}/{k_folds}')

        # Build the training and validation subsets for this fold
        train_subset = torch.utils.data.Subset(dataset, train_idx)
        val_subset = torch.utils.data.Subset(dataset, val_idx)

        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_subset, batch_size=batch_size)

        # Fresh model, criterion, optimizer and scheduler for every fold
        model = initialize_model(num_classes, device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        scheduler = StepLR(optimizer, step_size=7, gamma=0.1)

        best_val_loss = float('inf')
        early_stopping_counter = 0

        for epoch in range(num_epochs):
            # Training
            epoch_loss, epoch_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
            print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

            # Validation
            val_loss, val_acc = validate_one_epoch(model, val_loader, criterion, device)
            print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}')

            # Advance the LR schedule once per epoch; without this call the
            # StepLR object has no effect and the learning rate never decays.
            scheduler.step()

            # Early stopping on validation loss; only the best weights are kept
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                early_stopping_counter = 0
                torch.save(model.state_dict(), f'model_fold_{fold+1}.pth')
            else:
                early_stopping_counter += 1
                if early_stopping_counter >= early_stopping_patience:
                    print("Early stopping")
                    break

        best_val_losses.append(best_val_loss)

        # Drop the references first so the cached GPU memory can actually be
        # released before the next fold allocates its own model.
        del model, optimizer, scheduler
        if device.type == "cuda":
            torch.cuda.empty_cache()

    return best_val_losses
        

In [36]:
def main(annotations_path='../mnt/V1/annotations/annotations_preprocessed.json',
         image_dir='../mnt/V1/images'):
    """
    Load the dataset and train the model with cross-validation.

    Args:
        annotations_path: Path to the preprocessed annotations JSON file.
        image_dir: Directory that contains the images named in the annotations.

    Returns:
        Whatever ``train_model_kfold`` returns.
    """
    # Load the annotations
    annotations = load_annotations(annotations_path)

    # Build the full dataset; the folds are carved out of it during training
    dataset = BJJDataset(annotations, image_dir, transform=data_transforms)

    # Train with cross-validation, one output class per known position label
    return train_model_kfold(dataset, num_classes=len(position_to_index))

if __name__ == '__main__':
    main()

Fold 1/5
Epoch 1/25, Loss: 0.3802, Accuracy: 0.8794
Validation Metrics -> Loss: 0.8345, Accuracy: 0.7652, MAE: 1.7837, MSE: 16.8840, R²: 0.3633, Std_AE: 3.7017
Validation Loss: 0.8345, Validation Accuracy: 0.7652
Epoch 2/25, Loss: 0.1091, Accuracy: 0.9657
Validation Metrics -> Loss: 0.0773, Accuracy: 0.9733, MAE: 0.0980, MSE: 0.6213, R²: 0.9766, Std_AE: 0.7821
Validation Loss: 0.0773, Validation Accuracy: 0.9733
Epoch 3/25, Loss: 0.0677, Accuracy: 0.9788
Validation Metrics -> Loss: 0.2737, Accuracy: 0.9250, MAE: 0.4398, MSE: 3.8565, R²: 0.8546, Std_AE: 1.9139
Validation Loss: 0.2737, Validation Accuracy: 0.9250
Epoch 4/25, Loss: 0.0501, Accuracy: 0.9838
Validation Metrics -> Loss: 0.0718, Accuracy: 0.9760, MAE: 0.1082, MSE: 0.8388, R²: 0.9684, Std_AE: 0.9095
Validation Loss: 0.0718, Validation Accuracy: 0.9760
Epoch 5/25, Loss: 0.0454, Accuracy: 0.9860
Validation Metrics -> Loss: 0.0427, Accuracy: 0.9868, MAE: 0.0558, MSE: 0.4768, R²: 0.9820, Std_AE: 0.6883
Validation Loss: 0.0427, Val



Epoch 1/25, Loss: 0.3510, Accuracy: 0.8871
Validation Metrics -> Loss: 0.3882, Accuracy: 0.8880, MAE: 0.6490, MSE: 6.0250, R²: 0.7749, Std_AE: 2.3672
Validation Loss: 0.3882, Validation Accuracy: 0.8880
Epoch 2/25, Loss: 0.0947, Accuracy: 0.9706
Validation Metrics -> Loss: 0.2508, Accuracy: 0.9133, MAE: 0.4293, MSE: 3.2140, R²: 0.8799, Std_AE: 1.7406
Validation Loss: 0.2508, Validation Accuracy: 0.9133
Epoch 3/25, Loss: 0.0736, Accuracy: 0.9758
Validation Metrics -> Loss: 0.0542, Accuracy: 0.9858, MAE: 0.0542, MSE: 0.3322, R²: 0.9876, Std_AE: 0.5738
Validation Loss: 0.0542, Validation Accuracy: 0.9858
Epoch 4/25, Loss: 0.0587, Accuracy: 0.9808
Validation Metrics -> Loss: 0.0727, Accuracy: 0.9772, MAE: 0.0958, MSE: 0.6428, R²: 0.9760, Std_AE: 0.7960
Validation Loss: 0.0727, Validation Accuracy: 0.9772
Epoch 5/25, Loss: 0.0259, Accuracy: 0.9916
Validation Metrics -> Loss: 0.0459, Accuracy: 0.9863, MAE: 0.0563, MSE: 0.4170, R²: 0.9844, Std_AE: 0.6433
Validation Loss: 0.0459, Validation Ac



Epoch 1/25, Loss: 0.3468, Accuracy: 0.8912
Validation Metrics -> Loss: 0.2263, Accuracy: 0.9330, MAE: 0.2852, MSE: 1.9418, R²: 0.9276, Std_AE: 1.3640
Validation Loss: 0.2263, Validation Accuracy: 0.9330
Epoch 2/25, Loss: 0.0995, Accuracy: 0.9683
Validation Metrics -> Loss: 0.1488, Accuracy: 0.9498, MAE: 0.2522, MSE: 2.1248, R²: 0.9207, Std_AE: 1.4357
Validation Loss: 0.1488, Validation Accuracy: 0.9498
Epoch 3/25, Loss: 0.0601, Accuracy: 0.9808
Validation Metrics -> Loss: 0.0609, Accuracy: 0.9828, MAE: 0.0672, MSE: 0.4898, R²: 0.9817, Std_AE: 0.6967
Validation Loss: 0.0609, Validation Accuracy: 0.9828
Epoch 4/25, Loss: 0.0588, Accuracy: 0.9809
Validation Metrics -> Loss: 0.0533, Accuracy: 0.9827, MAE: 0.0655, MSE: 0.4908, R²: 0.9817, Std_AE: 0.6975
Validation Loss: 0.0533, Validation Accuracy: 0.9827
Epoch 5/25, Loss: 0.0361, Accuracy: 0.9885
Validation Metrics -> Loss: 0.0616, Accuracy: 0.9828, MAE: 0.0630, MSE: 0.5163, R²: 0.9807, Std_AE: 0.7158
Validation Loss: 0.0616, Validation Ac



Epoch 1/25, Loss: 0.3516, Accuracy: 0.8896
Validation Metrics -> Loss: 0.1756, Accuracy: 0.9420, MAE: 0.2192, MSE: 1.4265, R²: 0.9463, Std_AE: 1.1741
Validation Loss: 0.1756, Validation Accuracy: 0.9420
Epoch 2/25, Loss: 0.1029, Accuracy: 0.9670
Validation Metrics -> Loss: 0.1670, Accuracy: 0.9475, MAE: 0.2368, MSE: 2.0968, R²: 0.9211, Std_AE: 1.4285
Validation Loss: 0.1670, Validation Accuracy: 0.9475
Epoch 3/25, Loss: 0.0621, Accuracy: 0.9797
Validation Metrics -> Loss: 0.1659, Accuracy: 0.9438, MAE: 0.2160, MSE: 1.4780, R²: 0.9444, Std_AE: 1.1964
Validation Loss: 0.1659, Validation Accuracy: 0.9438
Epoch 4/25, Loss: 0.0524, Accuracy: 0.9835
Validation Metrics -> Loss: 0.1199, Accuracy: 0.9695, MAE: 0.1225, MSE: 0.9178, R²: 0.9654, Std_AE: 0.9502
Validation Loss: 0.1199, Validation Accuracy: 0.9695
Epoch 5/25, Loss: 0.0402, Accuracy: 0.9880
Validation Metrics -> Loss: 0.0328, Accuracy: 0.9892, MAE: 0.0418, MSE: 0.2852, R²: 0.9893, Std_AE: 0.5324
Validation Loss: 0.0328, Validation Ac



Epoch 1/25, Loss: 0.3596, Accuracy: 0.8847
Validation Metrics -> Loss: 0.3725, Accuracy: 0.8935, MAE: 0.4788, MSE: 3.4622, R²: 0.8698, Std_AE: 1.7980
Validation Loss: 0.3725, Validation Accuracy: 0.8935
Epoch 2/25, Loss: 0.1024, Accuracy: 0.9664
Validation Metrics -> Loss: 0.1023, Accuracy: 0.9638, MAE: 0.1685, MSE: 1.3062, R²: 0.9509, Std_AE: 1.1304
Validation Loss: 0.1023, Validation Accuracy: 0.9638
Epoch 3/25, Loss: 0.0712, Accuracy: 0.9768
Validation Metrics -> Loss: 0.2320, Accuracy: 0.9192, MAE: 0.3530, MSE: 2.2050, R²: 0.9171, Std_AE: 1.4424
Validation Loss: 0.2320, Validation Accuracy: 0.9192
Epoch 4/25, Loss: 0.0500, Accuracy: 0.9845
Validation Metrics -> Loss: 0.1011, Accuracy: 0.9653, MAE: 0.1380, MSE: 0.8983, R²: 0.9662, Std_AE: 0.9377
Validation Loss: 0.1011, Validation Accuracy: 0.9653
Epoch 5/25, Loss: 0.0299, Accuracy: 0.9906
Validation Metrics -> Loss: 0.2323, Accuracy: 0.9403, MAE: 0.2848, MSE: 2.3545, R²: 0.9115, Std_AE: 1.5078
Validation Loss: 0.2323, Validation Ac