In [None]:
import json
import os
from sklearn.model_selection import KFold
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from PIL import Image
import cv2

In [11]:
# Mapping from position label to integer class index.
# The index of each label is its place in the list below (starting at 0).
position_to_index = {
    name: index
    for index, name in enumerate([
        '5050_guard',
        'back1',
        'back2',
        'closed_guard1',
        'closed_guard2',
        'half_guard1',
        'half_guard2',
        'mount1',
        'mount2',
        'open_guard1',
        'open_guard2',
        'side_control1',
        'side_control2',
        'standing',
        'takedown1',
        'takedown2',
        'turtle1',
        'turtle2',
    ])
}

# Image preprocessing pipeline: resize to 256x256, then apply ToTensor.
data_transforms = transforms.Compose(
    [transforms.Resize(size=(256, 256)), transforms.ToTensor()]
)

def remove_background(image, threshold=50):
    """
    Black out everything outside the largest bright region of an image.

    The image is converted to grayscale and binarized at ``threshold``. The
    external contour with the largest area is filled into a mask, and every
    pixel outside that mask is set to zero.

    Args:
        image: RGB image; it is converted with ``np.array`` and passed to
            ``cv2.cvtColor(..., cv2.COLOR_RGB2GRAY)``.
        threshold: Grayscale cut-off given to ``cv2.threshold`` with
            ``THRESH_BINARY``. Defaults to 50, the value that used to be
            hard-coded.

    Returns:
        PIL.Image.Image: the masked image. If no contour is found the mask
        stays empty and the returned image is entirely black.
    """
    image_np = np.array(image)
    gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
    _, thresh = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY)

    # OpenCV 3.x returns (image, contours, hierarchy) while 2.x/4.x return
    # (contours, hierarchy); the contours are always the second-to-last item.
    contours = cv2.findContours(
        thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )[-2]

    mask = np.zeros_like(gray)
    if len(contours) > 0:
        # Keep only the largest external contour, filled solid.
        largest_contour = max(contours, key=cv2.contourArea)
        cv2.drawContours(mask, [largest_contour], -1, 255, thickness=cv2.FILLED)

    result = cv2.bitwise_and(image_np, image_np, mask=mask)
    return Image.fromarray(result)

In [12]:
def load_annotations(file_path):
    """
    Load the preprocessed annotations from a JSON file.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The deserialized JSON content (whatever top-level value the file holds).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Read as UTF-8 explicitly so decoding does not depend on the platform's
    # default locale encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


In [13]:

class BJJDataset(Dataset):
    """
    Image dataset driven by a sequence of annotation records.

    Each record is read through the keys 'Image' (file name without the
    '.jpg' extension) and 'Position' (a key of ``position_to_index``).
    """

    def __init__(self, annotations, image_dir, transform=None):
        """Keep the annotation records, the image folder and the optional transform."""
        self.transform = transform
        self.image_dir = image_dir
        self.annotations = annotations

    def __len__(self):
        """Number of annotation records."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the record at position ``idx``."""
        record = self.annotations[idx]
        image_path = os.path.join(self.image_dir, record['Image'] + '.jpg')

        # Background removal runs on the RGB image before any transform.
        image = remove_background(Image.open(image_path).convert('RGB'))
        if self.transform:
            image = self.transform(image)

        # The label is the class index of the record's position, as a long tensor.
        label = torch.tensor(position_to_index[record['Position']]).long()
        return image, label


In [14]:
def initialize_model(num_classes, device):
    """
    Build a ResNet-18 with a classification head sized for ``num_classes``.

    The network is created with ``pretrained=True``, its ``fc`` layer is
    replaced by a new ``nn.Linear`` with ``num_classes`` outputs, and the
    model is moved to ``device`` before being returned.
    """
    backbone = models.resnet18(pretrained=True)
    # Swap the final layer, keeping its input width.
    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
    backbone = backbone.to(device)
    return backbone

In [15]:

def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """
    Run one training pass over ``train_loader``, updating the model in place.

    Returns:
        tuple: ``(loss, accuracy)``, both divided by
        ``len(train_loader.dataset)``. ``loss`` is a Python float and
        ``accuracy`` is a double tensor.
    """
    model.train()
    total_loss = 0.0
    n_correct = 0

    for batch_x, batch_y in train_loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        scores = model(batch_x)
        batch_loss = criterion(scores, batch_y)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size; the running sum is divided
        # by the dataset size once the loop is over.
        total_loss += batch_loss.item() * batch_x.size(0)
        predicted = torch.max(scores, 1)[1]
        n_correct += (predicted == batch_y.data).sum()

    dataset_size = len(train_loader.dataset)
    return total_loss / dataset_size, n_correct.double() / dataset_size

In [16]:
def validate_one_epoch(model, val_loader, criterion, device):
    """
    Evaluate the model on ``val_loader`` without gradient tracking.

    In addition to loss and accuracy, MAE, MSE, R² and the standard deviation
    of the absolute error are computed between labels and predictions and
    printed on a single line.

    Returns:
        tuple: ``(val_loss, val_acc)``, both divided by
        ``len(val_loader.dataset)``. ``val_loss`` is a Python float and
        ``val_acc`` is a double tensor.
    """
    model.eval()
    total_loss = 0.0
    n_correct = 0
    targets, predictions = [], []

    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            scores = model(batch_x)

            # Batch loss weighted by batch size; normalised after the loop.
            total_loss += criterion(scores, batch_y).item() * batch_x.size(0)
            batch_preds = torch.max(scores, 1)[1]
            n_correct += torch.sum(batch_preds == batch_y.data)

            targets.extend(batch_y.cpu().numpy())
            predictions.extend(batch_preds.cpu().numpy())

    dataset_size = len(val_loader.dataset)
    val_loss = total_loss / dataset_size
    val_acc = n_correct.double() / dataset_size

    # NOTE(review): these are regression metrics computed on the integer
    # label/prediction values; if the labels are nominal class indices their
    # numeric distance carries no meaning - confirm this is intended.
    mae = mean_absolute_error(targets, predictions)
    mse = mean_squared_error(targets, predictions)
    r2 = r2_score(targets, predictions)
    abs_errors = np.abs(np.array(targets) - np.array(predictions))
    std_ae = np.std(abs_errors)

    print(f"Validation Metrics -> Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}, "
          f"MAE: {mae:.4f}, MSE: {mse:.4f}, R²: {r2:.4f}, Std_AE: {std_ae:.4f}")

    return val_loss, val_acc


In [17]:
def train_model_kfold(dataset, num_classes, k_folds=5, num_epochs=25, batch_size=32, early_stopping_patience=5):
    """
    Train a ResNet-18 with K-Fold cross-validation.

    For every fold a fresh model, Adam optimizer (lr=0.001) and StepLR
    scheduler (step_size=7, gamma=0.1) are created. After each epoch the
    model is validated; whenever the validation loss improves, the weights
    are saved to ``model_fold_<fold>.pth`` in the current working directory.
    Training of a fold stops early once the validation loss has failed to
    improve for ``early_stopping_patience`` consecutive epochs.

    Args:
        dataset: Dataset to split; it is wrapped in ``Subset`` objects using
            the indices produced by ``KFold.split``.
        num_classes: Number of output classes of the model.
        k_folds: Number of folds.
        num_epochs: Maximum number of epochs per fold.
        batch_size: Batch size of both data loaders.
        early_stopping_patience: Epochs without improvement tolerated before
            a fold is stopped.

    Returns:
        None. Progress is printed and the best weights of each fold are
        written to disk.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    kf = KFold(n_splits=k_folds, shuffle=True, random_state=42)

    for fold, (train_idx, val_idx) in enumerate(kf.split(dataset)):
        print(f'Fold {fold+1}/{k_folds}')

        # Split the dataset into training and validation subsets
        train_subset = torch.utils.data.Subset(dataset, train_idx)
        val_subset = torch.utils.data.Subset(dataset, val_idx)

        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_subset, batch_size=batch_size)

        # Fresh model, criterion, optimizer and scheduler for every fold
        model = initialize_model(num_classes, device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        scheduler = StepLR(optimizer, step_size=7, gamma=0.1)

        best_val_loss = float('inf')
        early_stopping_counter = 0

        for epoch in range(num_epochs):
            # Training
            epoch_loss, epoch_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
            print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

            # Validation
            val_loss, val_acc = validate_one_epoch(model, val_loader, criterion, device)
            print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}')

            # Advance the learning-rate schedule once per epoch, after the
            # optimizer updates; without this call the StepLR decay never
            # takes effect and the learning rate stays at its initial value.
            scheduler.step()

            # Early stopping on the validation loss
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                early_stopping_counter = 0
                # Checkpoint the best weights seen so far for this fold
                torch.save(model.state_dict(), f'model_fold_{fold+1}.pth')
            else:
                early_stopping_counter += 1
                if early_stopping_counter >= early_stopping_patience:
                    print("Early stopping")
                    break

        # Release cached GPU memory before the next fold builds a new model
        torch.cuda.empty_cache()
        

In [18]:
def main():
    """
    Entry point: load the annotations, build the dataset and run the
    K-Fold cross-validated training.
    """
    image_dir = '../mnt/V3/images'
    annotations_path = '../mnt/V3/annotations/annotations_preprocessed.json'

    # Full dataset built from the loaded annotation records
    bjj_dataset = BJJDataset(
        load_annotations(annotations_path),
        image_dir,
        transform=data_transforms,
    )

    # One output class per known position label
    train_model_kfold(bjj_dataset, num_classes=len(position_to_index))


if __name__ == '__main__':
    main()

Fold 1/5
Epoch 1/25, Loss: 1.4633, Accuracy: 0.4981
Validation Metrics -> Loss: 1.1553, Accuracy: 0.6013, MAE: 1.8845, MSE: 15.4628, R²: 0.4183, Std_AE: 3.4513
Validation Loss: 1.1553, Validation Accuracy: 0.6013
Epoch 2/25, Loss: 0.9426, Accuracy: 0.6696
Validation Metrics -> Loss: 0.8491, Accuracy: 0.7112, MAE: 1.5359, MSE: 13.5476, R²: 0.4904, Std_AE: 3.3449
Validation Loss: 0.8491, Validation Accuracy: 0.7112
Epoch 3/25, Loss: 0.6865, Accuracy: 0.7605
Validation Metrics -> Loss: 0.6913, Accuracy: 0.7646, MAE: 1.1843, MSE: 10.0137, R²: 0.6233, Std_AE: 2.9345
Validation Loss: 0.6913, Validation Accuracy: 0.7646
Epoch 4/25, Loss: 0.5233, Accuracy: 0.8148
Validation Metrics -> Loss: 0.5613, Accuracy: 0.8123, MAE: 0.9074, MSE: 7.4245, R²: 0.7207, Std_AE: 2.5693
Validation Loss: 0.5613, Validation Accuracy: 0.8123
Epoch 5/25, Loss: 0.4055, Accuracy: 0.8556
Validation Metrics -> Loss: 0.5345, Accuracy: 0.8247, MAE: 0.8067, MSE: 6.4180, R²: 0.7586, Std_AE: 2.4015
Validation Loss: 0.5345, V



Epoch 1/25, Loss: 1.4334, Accuracy: 0.5116
Validation Metrics -> Loss: 1.1944, Accuracy: 0.5860, MAE: 1.8839, MSE: 14.9321, R²: 0.4415, Std_AE: 3.3738
Validation Loss: 1.1944, Validation Accuracy: 0.5860
Epoch 2/25, Loss: 0.9404, Accuracy: 0.6729
Validation Metrics -> Loss: 0.8029, Accuracy: 0.7216, MAE: 1.3411, MSE: 11.0276, R²: 0.5875, Std_AE: 3.0379
Validation Loss: 0.8029, Validation Accuracy: 0.7216
Epoch 3/25, Loss: 0.6841, Accuracy: 0.7606
Validation Metrics -> Loss: 0.6416, Accuracy: 0.7767, MAE: 1.0368, MSE: 8.2948, R²: 0.6898, Std_AE: 2.6870
Validation Loss: 0.6416, Validation Accuracy: 0.7767
Epoch 4/25, Loss: 0.5175, Accuracy: 0.8165
Validation Metrics -> Loss: 0.5977, Accuracy: 0.7979, MAE: 0.9775, MSE: 8.1225, R²: 0.6962, Std_AE: 2.6771
Validation Loss: 0.5977, Validation Accuracy: 0.7979
Epoch 5/25, Loss: 0.4100, Accuracy: 0.8543
Validation Metrics -> Loss: 0.5118, Accuracy: 0.8299, MAE: 0.8393, MSE: 7.0409, R²: 0.7367, Std_AE: 2.5172
Validation Loss: 0.5118, Validation 



Epoch 1/25, Loss: 1.4410, Accuracy: 0.5086
Validation Metrics -> Loss: 1.0736, Accuracy: 0.6284, MAE: 1.7248, MSE: 14.1760, R²: 0.4661, Std_AE: 3.3468
Validation Loss: 1.0736, Validation Accuracy: 0.6284
Epoch 2/25, Loss: 0.9265, Accuracy: 0.6791
Validation Metrics -> Loss: 0.8368, Accuracy: 0.7142, MAE: 1.3934, MSE: 11.8316, R²: 0.5544, Std_AE: 3.1448
Validation Loss: 0.8368, Validation Accuracy: 0.7142
Epoch 3/25, Loss: 0.6677, Accuracy: 0.7691
Validation Metrics -> Loss: 0.6642, Accuracy: 0.7701, MAE: 1.0825, MSE: 9.0004, R²: 0.6611, Std_AE: 2.7980
Validation Loss: 0.6642, Validation Accuracy: 0.7701
Epoch 4/25, Loss: 0.5093, Accuracy: 0.8240
Validation Metrics -> Loss: 0.6136, Accuracy: 0.7939, MAE: 0.9895, MSE: 8.2726, R²: 0.6885, Std_AE: 2.7006
Validation Loss: 0.6136, Validation Accuracy: 0.7939
Epoch 5/25, Loss: 0.3987, Accuracy: 0.8595
Validation Metrics -> Loss: 0.5721, Accuracy: 0.8144, MAE: 0.8502, MSE: 6.7695, R²: 0.7451, Std_AE: 2.4590
Validation Loss: 0.5721, Validation 



Epoch 1/25, Loss: 1.4437, Accuracy: 0.5052
Validation Metrics -> Loss: 1.1250, Accuracy: 0.6121, MAE: 1.8626, MSE: 15.3110, R²: 0.4406, Std_AE: 3.4412
Validation Loss: 1.1250, Validation Accuracy: 0.6121
Epoch 2/25, Loss: 0.9437, Accuracy: 0.6712
Validation Metrics -> Loss: 0.8040, Accuracy: 0.7186, MAE: 1.2706, MSE: 10.5152, R²: 0.6158, Std_AE: 2.9834
Validation Loss: 0.8040, Validation Accuracy: 0.7186
Epoch 3/25, Loss: 0.6853, Accuracy: 0.7642
Validation Metrics -> Loss: 0.6559, Accuracy: 0.7704, MAE: 1.1514, MSE: 9.9055, R²: 0.6381, Std_AE: 2.9291
Validation Loss: 0.6559, Validation Accuracy: 0.7704
Epoch 4/25, Loss: 0.5117, Accuracy: 0.8218
Validation Metrics -> Loss: 0.5344, Accuracy: 0.8173, MAE: 0.8721, MSE: 7.3270, R²: 0.7323, Std_AE: 2.5625
Validation Loss: 0.5344, Validation Accuracy: 0.8173
Epoch 5/25, Loss: 0.3969, Accuracy: 0.8581
Validation Metrics -> Loss: 0.5112, Accuracy: 0.8357, MAE: 0.7575, MSE: 6.0498, R²: 0.7790, Std_AE: 2.3401
Validation Loss: 0.5112, Validation 



Epoch 1/25, Loss: 1.4560, Accuracy: 0.5006
Validation Metrics -> Loss: 1.1830, Accuracy: 0.5924, MAE: 1.9608, MSE: 16.2488, R²: 0.4004, Std_AE: 3.5220
Validation Loss: 1.1830, Validation Accuracy: 0.5924
Epoch 2/25, Loss: 0.9643, Accuracy: 0.6673
Validation Metrics -> Loss: 0.7663, Accuracy: 0.7369, MAE: 1.2541, MSE: 10.2808, R²: 0.6206, Std_AE: 2.9510
Validation Loss: 0.7663, Validation Accuracy: 0.7369
Epoch 3/25, Loss: 0.7098, Accuracy: 0.7532
Validation Metrics -> Loss: 0.6654, Accuracy: 0.7727, MAE: 1.0358, MSE: 8.3741, R²: 0.6910, Std_AE: 2.7021
Validation Loss: 0.6654, Validation Accuracy: 0.7727
Epoch 4/25, Loss: 0.5419, Accuracy: 0.8095
Validation Metrics -> Loss: 0.6088, Accuracy: 0.7942, MAE: 0.9742, MSE: 8.2498, R²: 0.6956, Std_AE: 2.7020
Validation Loss: 0.6088, Validation Accuracy: 0.7942
Epoch 5/25, Loss: 0.4288, Accuracy: 0.8491
Validation Metrics -> Loss: 0.5777, Accuracy: 0.8078, MAE: 0.9150, MSE: 7.4995, R²: 0.7233, Std_AE: 2.5811
Validation Loss: 0.5777, Validation 