In [None]:
import json
import os
from sklearn.model_selection import KFold
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from torch.cuda.amp import autocast, GradScaler
from torchvision.models.segmentation import deeplabv3_resnet50
import numpy as np

# Data transformations applied to each image
# NOTE: RandomHorizontalFlip was removed on purpose. This pipeline only sees the
# image, while the regression targets are pose keypoint values that would stay
# unflipped, so a random mirror produces image/label pairs that no longer match.
# The same pipeline is also applied to the validation folds, which should be
# deterministic. Any flip augmentation has to be done where image and keypoints
# can be transformed together.
data_transforms = transforms.Compose([
    transforms.ToTensor(),
])

In [None]:
def load_annotations(file_path):
    """
    Load the preprocessed annotations from a JSON file.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The deserialized JSON content, exactly as produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Read as UTF-8 explicitly instead of relying on the platform's default
    # encoding, which would garble non-ASCII text on some systems.
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

In [None]:
class SkeletonDataset(Dataset):
    """
    Dataset that pairs each image with a flat vector built from its two poses.

    Every annotation entry is accessed through the keys 'Image' (file name
    without the '.jpg' extension), 'Pose1' and 'Pose2'.
    """

    def __init__(self, annotations, image_dir, transform=None):
        """
        Store the annotations, the image directory and the optional transform.

        Args:
            annotations: Indexable collection of annotation entries.
            image_dir: Directory that contains the '<Image>.jpg' files.
            transform: Optional callable applied to the loaded PIL image.
        """
        self.annotations = annotations
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        """
        Return the number of annotation entries.
        """
        return len(self.annotations)

    def __getitem__(self, idx):
        """
        Return the image at ``idx`` and its concatenated pose vector.

        Args:
            idx: Index of the annotation entry.

        Returns:
            Tuple ``(image, pose)``: ``image`` is the RGB image after the
            optional transform; ``pose`` is a 1-D float32 tensor holding the
            values of 'Pose1' followed by the values of 'Pose2'.
        """
        entry = self.annotations[idx]
        image_path = os.path.join(self.image_dir, entry['Image'] + '.jpg')

        # Image.open is lazy: decode inside the context manager so the file
        # handle is released right away, and force RGB so grayscale / RGBA /
        # palette files all yield the same number of channels.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        pose = entry['Pose1'] + entry['Pose2']
        # Explicit float32: integer-valued annotations would otherwise become an
        # int64 tensor and mismatch the model's floating-point output in the loss.
        pose = torch.tensor(pose, dtype=torch.float32).view(-1)  # flatten the poses
        return image, pose

In [None]:
class HRNetForPose(nn.Module):
    """
    Regression network mapping an image batch to one flat pose vector per image.

    Despite the class name, the feature extractor is the ResNet-50 backbone
    taken from torchvision's pretrained DeepLabV3 model, not an HRNet.
    """

    def __init__(self, num_keypoints=102):
        """
        Build the pretrained backbone and the convolutional regression head.

        Args:
            num_keypoints: Number of values predicted per image (default 102).
        """
        super().__init__()
        self.backbone = deeplabv3_resnet50(pretrained=True).backbone

        # Head: 3x3 conv -> ReLU -> 1x1 conv down to `num_keypoints` channels,
        # then global average pooling to a single value per channel.
        head_layers = [
            nn.Conv2d(2048, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, num_keypoints, kernel_size=1, stride=1),
            nn.AdaptiveAvgPool2d((1, 1)),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """
        Run the forward pass.

        Args:
            x: Input batch passed to the backbone.

        Returns:
            Tensor with one row per input sample, the pooled head output
            flattened along the remaining dimensions.
        """
        features = self.backbone(x)['out']
        pooled = self.fc(features)
        batch_size = pooled.size(0)
        return pooled.view(batch_size, -1)

In [None]:
def train_one_epoch(model, train_loader, criterion, optimizer, scaler, device):
    """
    Train the model for a single pass over ``train_loader``.

    Args:
        model: Module to train; it is switched to training mode.
        train_loader: Iterable of ``(inputs, targets)`` batches exposing a
            ``dataset`` attribute with a length.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: Optimizer stepped once per batch through ``scaler``.
        scaler: Gradient scaler used for the backward pass and the step.
        device: Device the batches are moved to.

    Returns:
        Sum of the batch losses, each weighted by its batch size, divided by
        ``len(train_loader.dataset)``.
    """
    model.train()
    weighted_loss_sum = 0.0

    for batch_inputs, batch_targets in train_loader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        optimizer.zero_grad()

        # Forward pass and loss inside the mixed-precision context
        with autocast():
            loss = criterion(model(batch_inputs), batch_targets)

        # Backward on the scaled loss, then step the optimizer and refresh the scale
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        n_samples = batch_inputs.size(0)
        weighted_loss_sum += loss.item() * n_samples

    n_total = len(train_loader.dataset)
    return weighted_loss_sum / n_total

In [None]:
def validate_one_epoch(model, val_loader, criterion, scaler, device):
    """
    Evaluate the model over one pass of ``val_loader`` without gradients.

    Args:
        model: Module to evaluate; it is switched to evaluation mode.
        val_loader: Iterable of ``(inputs, targets)`` batches exposing a
            ``dataset`` attribute with a length.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        scaler: Not used by this function; kept in the signature for callers.
        device: Device the batches are moved to.

    Returns:
        Sum of the batch losses, each weighted by its batch size, divided by
        ``len(val_loader.dataset)``.
    """
    model.eval()
    weighted_loss_sum = 0.0

    with torch.no_grad():
        for batch_inputs, batch_targets in val_loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            # Forward pass and loss inside the mixed-precision context
            with autocast():
                batch_loss = criterion(model(batch_inputs), batch_targets)

            weighted_loss_sum += batch_loss.item() * batch_inputs.size(0)

    n_total = len(val_loader.dataset)
    return weighted_loss_sum / n_total

In [None]:
def train_and_validate(train_loader, val_loader, model, criterion, optimizer, scaler, device, num_epochs=25, patience=5, checkpoint_path='hrnet_best_model.pth'):
    """
    Train and validate the model with early stopping on the validation loss.

    After each epoch the validation loss is compared with the best one seen so
    far. On improvement the model's ``state_dict`` is written to
    ``checkpoint_path``; otherwise a counter is increased and training stops
    once it reaches ``patience``.

    Args:
        train_loader: Loader passed to ``train_one_epoch``.
        val_loader: Loader passed to ``validate_one_epoch``.
        model: Model to train.
        criterion: Loss function.
        optimizer: Optimizer.
        scaler: Gradient scaler forwarded to the per-epoch helpers.
        device: Device forwarded to the per-epoch helpers.
        num_epochs: Maximum number of epochs.
        patience: Number of consecutive non-improving epochs that triggers
            early stopping.
        checkpoint_path: Destination of the best checkpoint. The default keeps
            the previous hard-coded file name; pass a distinct path per run
            (e.g. per fold) so that runs do not overwrite each other.

    Returns:
        The best (lowest) validation loss observed, or ``float('inf')`` if no
        epoch produced an improvement.
    """
    best_val_loss = float('inf')
    epochs_without_improvement = 0

    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, scaler, device)
        print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}')

        val_loss = validate_one_epoch(model, val_loader, criterion, scaler, device)
        print(f'Epoch {epoch+1}, Validation Loss: {val_loss:.4f}')

        if val_loss < best_val_loss:
            # New best: reset the counter and persist the weights
            best_val_loss = val_loss
            epochs_without_improvement = 0
            torch.save(model.state_dict(), checkpoint_path)
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print("Early stopping triggered")
                break

        # Release cached GPU memory between epochs
        torch.cuda.empty_cache()

    return best_val_loss

In [None]:
def cross_validation_kfold(dataset, k_folds=5, batch_size=16, num_epochs=25, patience=5):
    """
    Run K-Fold cross-validation over ``dataset``.

    For every fold a fresh model, loss, optimizer and gradient scaler are
    created and handed to ``train_and_validate``. The training and validation
    subsets are both views of the same ``dataset`` object.

    Args:
        dataset: Dataset to split; indexed through ``torch.utils.data.Subset``.
        k_folds: Number of folds.
        batch_size: Batch size of both loaders.
        num_epochs: Maximum number of epochs per fold.
        patience: Early-stopping patience per fold.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    splitter = KFold(n_splits=k_folds, shuffle=True, random_state=42)

    for fold, (train_idx, val_idx) in enumerate(splitter.split(dataset)):
        print(f'Fold {fold+1}/{k_folds}')

        # Only the training loader is shuffled
        train_loader = DataLoader(
            torch.utils.data.Subset(dataset, train_idx),
            batch_size=batch_size,
            shuffle=True,
        )
        val_loader = DataLoader(
            torch.utils.data.Subset(dataset, val_idx),
            batch_size=batch_size,
        )

        # Fresh model, loss, optimizer and scaler for this fold
        model = HRNetForPose(num_keypoints=102).to(device)
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        scaler = GradScaler()

        train_and_validate(
            train_loader, val_loader, model, criterion, optimizer, scaler, device,
            num_epochs, patience,
        )


In [None]:
def main(annotations_path='../mnt/V2/annotations/annotations_preprocessed.json', image_dir='../mnt/V2/images'):
    """
    Load the data and run the K-Fold cross-validation.

    Args:
        annotations_path: JSON file with the preprocessed annotations. Defaults
            to the location that was previously hard-coded.
        image_dir: Directory holding the images. Defaults to the location that
            was previously hard-coded.
    """
    annotations = load_annotations(annotations_path)

    # Build the dataset with the module-level image transform
    dataset = SkeletonDataset(annotations, image_dir=image_dir, transform=data_transforms)

    # Cross-validation with early stopping
    cross_validation_kfold(dataset, k_folds=5, batch_size=16, num_epochs=25, patience=5)

# Run the pipeline only when this code is executed directly, not when imported.
if __name__ == '__main__':
    main()

Fold 1/5
Epoch 1/25, Loss: 0.022119741960739098
Validation Loss: 0.01614552407960097
Epoch 2/25, Loss: 0.015481623606756329
Validation Loss: 0.015117096982896329
Epoch 3/25, Loss: 0.013427072680244843
Validation Loss: 0.013808802790939809
Epoch 4/25, Loss: 0.011434545034853121
Validation Loss: 0.010966500379145145
Epoch 5/25, Loss: 0.009857095356409749
Validation Loss: 0.009805596155424912
Epoch 6/25, Loss: 0.008701335644349455
Validation Loss: 0.008630088952680429
Epoch 7/25, Loss: 0.007853433494456113
Validation Loss: 0.007675500318408013
Epoch 8/25, Loss: 0.007229135247568289
Validation Loss: 0.0070240427429477375
Epoch 9/25, Loss: 0.0066451052545259395
Validation Loss: 0.006758729490141074
Epoch 10/25, Loss: 0.006154507968264321
Validation Loss: 0.006063954694817464
Epoch 11/25, Loss: 0.0057156479628756645
Validation Loss: 0.005897498392810424
Epoch 12/25, Loss: 0.005293732135401418
Validation Loss: 0.005639238127196828
Epoch 13/25, Loss: 0.004948922306764871
Validation Loss: 0.005