In [None]:
import json
import os
from sklearn.model_selection import KFold
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from PIL import Image

In [None]:
# Map each BJJ position label (the 'Position' field of an annotation) to a
# contiguous 0-based integer class index used as the training label.
# The index of each name is its position in the tuple below.
position_to_index = {
    position_name: class_idx
    for class_idx, position_name in enumerate((
        '5050_guard',
        'back1',
        'back2',
        'closed_guard1',
        'closed_guard2',
        'half_guard1',
        'half_guard2',
        'mount1',
        'mount2',
        'open_guard1',
        'open_guard2',
        'side_control1',
        'side_control2',
        'standing',
        'takedown1',
        'takedown2',
        'turtle1',
        'turtle2',
    ))
}

# Image preprocessing pipeline applied to every sample.
# The backbone is initialised with ImageNet-pretrained weights, which were
# trained on inputs normalised with the ImageNet per-channel mean/std below;
# feeding raw [0, 1] tensors shifts the input distribution the pretrained
# filters expect. Normalize assumes 3-channel (RGB) input.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

data_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])


In [None]:
def load_annotations(file_path):
    """
    Load the preprocessed annotations from a JSON file.

    Args:
        file_path: Path to the annotations JSON file.

    Returns:
        The decoded JSON content. Elsewhere in this notebook it is indexed as
        a sequence of dicts with 'Image' and 'Position' keys (presumably a
        list — confirm against the preprocessing step that wrote the file).
    """
    # JSON text is UTF-8 by specification; being explicit avoids depending on
    # the platform's default locale encoding (e.g. cp1252 on Windows).
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


In [None]:
class BJJDataset(Dataset):
    """
    Map-style dataset of BJJ position images and their integer class labels.

    Each annotation entry must provide:
      - 'Image': the image file name without the '.jpg' extension, relative
        to ``image_dir``;
      - 'Position': a label that is a key of ``position_to_index``.
    """

    def __init__(self, annotations, image_dir, transform=None):
        """
        Store the annotations, image directory and optional transform.

        Args:
            annotations: Indexable sequence of annotation dicts (see class docstring).
            image_dir: Directory containing the ``<Image>.jpg`` files.
            transform: Optional callable applied to the loaded PIL image.
        """
        self.annotations = annotations
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        """
        Return ``(image, label)`` for sample ``idx``.

        ``image`` is an RGB PIL image, or whatever ``transform`` returns for
        it; ``label`` is a 0-dim ``torch.long`` tensor.

        Raises:
            KeyError: if the entry's 'Position' is not in ``position_to_index``.
        """
        entry = self.annotations[idx]
        image_path = os.path.join(self.image_dir, entry['Image'] + '.jpg')

        # Image.open is lazy and keeps the file handle open until the image is
        # loaded; decode inside a context manager so the handle is released.
        # convert('RGB') also turns grayscale/CMYK/palette JPEGs into the
        # 3-channel input the ResNet backbone expects (and forces the load).
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        label = torch.tensor(position_to_index[entry['Position']], dtype=torch.long)

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
def initialize_model(num_classes, device):
    """
    Build an ImageNet-pretrained ResNet-18 with a new classification head.

    Args:
        num_classes: Number of outputs of the replacement final linear layer.
        device: torch.device (or device string) to move the model to.

    Returns:
        The model, moved to ``device``.
    """
    # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
    # ``weights`` enum; IMAGENET1K_V1 is the checkpoint ``pretrained=True``
    # loads. Older torchvision releases lack the enum, so fall back there.
    weights_enum = getattr(models, 'ResNet18_Weights', None)
    if weights_enum is not None:
        model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = models.resnet18(pretrained=True)

    # Replace the pretrained ImageNet classifier with one sized for our labels.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model.to(device)

In [None]:

def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """
    Run one training epoch of ``model`` over ``train_loader``.

    Args:
        model: Classifier whose output is assumed to be (batch, num_classes) logits.
        train_loader: DataLoader yielding ``(inputs, labels)`` batches.
        criterion: Loss function. Its value is assumed to be a per-batch mean
            (the default for CrossEntropyLoss), so it is re-weighted by batch size.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc)`` as Python floats: the sample-weighted mean
        loss and the top-1 accuracy over the whole dataset. Both are 0.0 when
        the dataset is empty (instead of raising).
    """
    model.train()
    running_loss = 0.0
    correct = 0
    num_samples = len(train_loader.dataset)

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Undo the batch-mean so partial last batches are weighted correctly.
        running_loss += loss.item() * inputs.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()

    if num_samples == 0:
        return 0.0, 0.0
    return running_loss / num_samples, correct / num_samples

In [None]:
def validate_one_epoch(model, val_loader, criterion, device):
    """
    Evaluate ``model`` on ``val_loader`` without updating its weights.

    Args:
        model: Classifier whose output is assumed to be (batch, num_classes) logits.
        val_loader: DataLoader yielding ``(inputs, labels)`` batches.
        criterion: Loss function. Its value is assumed to be a per-batch mean
            (the default for CrossEntropyLoss), so it is re-weighted by batch size.
        device: Device each batch is moved to.

    Returns:
        ``(val_loss, val_acc)`` as Python floats: the sample-weighted mean loss
        and the top-1 accuracy over the whole dataset. Both are 0.0 when the
        dataset is empty (instead of raising).
    """
    model.eval()
    val_loss = 0.0
    val_correct = 0
    num_samples = len(val_loader.dataset)

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Undo the batch-mean so partial last batches are weighted correctly.
            val_loss += loss.item() * inputs.size(0)
            val_correct += (outputs.argmax(dim=1) == labels).sum().item()

    if num_samples == 0:
        return 0.0, 0.0
    return val_loss / num_samples, val_correct / num_samples

In [None]:
def train_model_kfold(dataset, num_classes, k_folds=5, num_epochs=25, batch_size=32,
                      early_stopping_patience=5, lr=0.001, lr_step_size=7, lr_gamma=0.1):
    """
    Train ResNet-18 with K-fold cross-validation, using a fresh model per fold.

    Each fold is trained with Adam, a StepLR schedule and early stopping on
    validation loss. Whenever the validation loss improves, the model weights
    are saved to ``model_fold_<k>.pth`` in the current working directory.

    Args:
        dataset: Map-style dataset returning ``(image_tensor, label)`` pairs.
        num_classes: Number of output classes.
        k_folds: Number of folds (shuffled KFold with random_state=42).
        num_epochs: Maximum number of epochs per fold.
        batch_size: Batch size for the training and validation loaders.
        early_stopping_patience: Consecutive epochs without validation-loss
            improvement after which the fold stops.
        lr: Initial Adam learning rate.
        lr_step_size: Number of epochs between StepLR decays.
        lr_gamma: Multiplicative StepLR decay factor.

    Returns:
        A list with one dict per fold, with keys 'fold' (1-based),
        'best_epoch' (1-based, or None if validation loss never improved)
        and 'best_val_loss'.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    kf = KFold(n_splits=k_folds, shuffle=True, random_state=42)
    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(dataset)):
        print(f'Fold {fold+1}/{k_folds}')

        # Split the dataset into this fold's training and validation subsets.
        train_subset = torch.utils.data.Subset(dataset, train_idx)
        val_subset = torch.utils.data.Subset(dataset, val_idx)

        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_subset, batch_size=batch_size)

        # Fresh model, loss, optimizer and LR schedule for every fold.
        model = initialize_model(num_classes, device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=lr)
        scheduler = StepLR(optimizer, step_size=lr_step_size, gamma=lr_gamma)

        best_val_loss = float('inf')
        best_epoch = None
        early_stopping_counter = 0

        for epoch in range(num_epochs):
            epoch_loss, epoch_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
            print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

            val_loss, val_acc = validate_one_epoch(model, val_loader, criterion, device)
            print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}')

            # Advance the LR schedule once per epoch, after the optimizer steps;
            # without this call StepLR never decays the learning rate.
            scheduler.step()

            # Early stopping on validation loss, checkpointing each improvement.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_epoch = epoch + 1
                early_stopping_counter = 0
                torch.save(model.state_dict(), f'model_fold_{fold+1}.pth')
            else:
                early_stopping_counter += 1
                if early_stopping_counter >= early_stopping_patience:
                    print("Early stopping")
                    break

        fold_results.append({
            'fold': fold + 1,
            'best_epoch': best_epoch,
            'best_val_loss': best_val_loss,
        })

        # Drop the last references to this fold's model/optimizer state so
        # empty_cache() can actually return their GPU memory before the next fold.
        del model, optimizer, scheduler
        torch.cuda.empty_cache()

    return fold_results
        

In [None]:
def main(annotations_path='../mnt/V3/annotations/annotations_preprocessed.json',
         image_dir='../mnt/V3/images'):
    """
    Load the annotated dataset and train the model with K-fold cross-validation.

    Args:
        annotations_path: Path to the preprocessed annotations JSON file.
        image_dir: Directory containing the annotated ``.jpg`` images.

    Raises:
        FileNotFoundError: if ``image_dir`` is not an existing directory, or
            (from ``load_annotations``) if ``annotations_path`` does not exist.
    """
    # Fail fast: otherwise a bad image path only surfaces at the first batch,
    # after the pretrained weights have already been loaded.
    if not os.path.isdir(image_dir):
        raise FileNotFoundError(f'Image directory not found: {image_dir}')

    annotations = load_annotations(annotations_path)
    dataset = BJJDataset(annotations, image_dir, transform=data_transforms)

    train_model_kfold(dataset, num_classes=len(position_to_index))


if __name__ == '__main__':
    main()