In [13]:
import os
import json
import torch
import numpy as np
import cv2
from PIL import Image
from torchvision import models, transforms
import torch.nn as nn
from torchvision.models.segmentation import deeplabv3_resnet50
import random
import matplotlib.pyplot as plt

# Class dictionary: position label -> class index. Index order follows the
# tuple below and is shared by the ResNet and LSTM classifiers (both 18-way).
position_to_index = {
    name: idx
    for idx, name in enumerate((
        '5050_guard',
        'back1', 'back2',
        'closed_guard1', 'closed_guard2',
        'half_guard1', 'half_guard2',
        'mount1', 'mount2',
        'open_guard1', 'open_guard2',
        'side_control1', 'side_control2',
        'standing',
        'takedown1', 'takedown2',
        'turtle1', 'turtle2',
    ))
}
# Reverse lookup used to turn an argmax index back into a label.
index_to_position = {idx: name for name, idx in position_to_index.items()}

# Checkpoint and dataset locations (relative to the working directory).
hrnet_model_path = './V2/Models/Hrnet_Models/hrnet_best_model.pth'
lstm_model_path = './V2/Models/Lstm_Models/best_lstm_model.pth'
resnet_model_path = './V2/Models/Resnet_Models/model_fold_3.pth'
image_dir = './mnt/Dataset/images'

# Load models onto the GPU when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Pose prediction model. Despite the "HRNet" name, the backbone is the
# ResNet-50 feature extractor taken from torchvision's DeepLabV3.
class HRNetForPose(nn.Module):
    """Regress a flat vector of ``num_keypoints`` values from an image batch.

    The backbone's ``'out'`` feature map (2048 channels, per the first head
    conv) goes through a 3x3 conv + ReLU + 1x1 conv head producing
    ``num_keypoints`` channels, which are global-average-pooled and flattened
    to shape ``(batch, num_keypoints)``.

    NOTE(review): ``pretrained=False`` may not disable the backbone's own
    ImageNet weights in torchvision, so construction could trigger a
    download that ``load_state_dict`` later overwrites — confirm for the
    installed torchvision version.
    """

    def __init__(self, num_keypoints=102):
        super().__init__()
        # Keep the attribute names and the Sequential layout unchanged:
        # saved checkpoints address these layers as backbone.* / fc.0 / fc.2.
        self.backbone = deeplabv3_resnet50(pretrained=False).backbone
        head_layers = [
            nn.Conv2d(2048, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, num_keypoints, kernel_size=1, stride=1),
            nn.AdaptiveAvgPool2d((1, 1)),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        features = self.backbone(x)['out']
        pooled = self.fc(features)
        # (batch, num_keypoints, 1, 1) -> (batch, num_keypoints)
        return pooled.view(pooled.shape[0], -1)

# Build the pose model on the target device, restore its trained weights
# (mapped straight onto that device) and switch it to inference mode.
hrnet_model = HRNetForPose(num_keypoints=102).to(device)
hrnet_model.load_state_dict(
    torch.load(hrnet_model_path, map_location=device)
)
hrnet_model = hrnet_model.eval()

# Sequence analysis model: a stacked LSTM whose final time step feeds a
# linear classification layer.
class LSTMModel(nn.Module):
    """Classify batch-first feature sequences from the LSTM's last time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the LSTM hidden state.
        output_size: Number of output logits.
        num_layers: Number of stacked LSTM layers (default 2).

    ``forward`` takes ``(batch, seq_len, input_size)`` (``batch_first=True``)
    and returns logits of shape ``(batch, output_size)``.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=2):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # nn.LSTM returns (outputs, (h_n, c_n)); only the outputs are needed.
        outputs = self.lstm(x)[0]
        last_step = outputs[:, -1]
        return self.fc(last_step)

# LSTM hyper-parameters: input_size matches the 102 flattened pose values fed
# per frame, output_size the 18 entries of index_to_position.
input_size, hidden_size, output_size = 102, 128, 18
lstm_model = LSTMModel(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
).to(device)
lstm_model.load_state_dict(torch.load(lstm_model_path, map_location=device))
lstm_model = lstm_model.eval()

# Image classification model: ResNet-18 with its final layer swapped for an
# 18-way linear head, then loaded with the fold-3 checkpoint.
resnet_model = models.resnet18(pretrained=False)
num_ftrs = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(in_features=num_ftrs, out_features=18)
resnet_model.load_state_dict(torch.load(resnet_model_path, map_location=device))
resnet_model = resnet_model.to(device).eval()

# Preprocessing applied to every frame before both the pose model and the
# ResNet classifier: resize to 256x256, then convert to a float tensor.
# NOTE(review): no Normalize step is applied — confirm this matches training.
transform = transforms.Compose(
    [transforms.Resize((256, 256)), transforms.ToTensor()]
)

# Pairs of keypoint row indices joined by a line when drawing one skeleton.
# NOTE(review): the layout looks like COCO 17-keypoint ordering (0-4 head,
# 5-10 shoulders/arms, 11-16 hips/legs) — confirm against the annotations.
# Unlike the full COCO skeleton, no shoulder-hip (torso) edges are included.
skeleton_connections = [
    # indices 0-4
    (0, 1), (0, 2), (1, 3), (2, 4),
    # indices 5-10
    (5, 6),
    (5, 7), (7, 9),
    (6, 8), (8, 10),
    # indices 11-16
    (11, 12),
    (11, 13), (13, 15),
    (12, 14), (14, 16),
]

# Function to draw skeletons.
def draw_skeleton(pose, img_shape, color='r', conf_threshold=0.5, connections=None):
    """Plot one skeleton's keypoints and limbs onto the current Matplotlib axes.

    Args:
        pose: Array-like of keypoint rows ``(x, y, confidence, ...)``. ``x`` and
            ``y`` are multiplied by the image width/height, so they are assumed
            to be normalised to [0, 1] — TODO confirm against the pose model.
        img_shape: Shape of the underlying image; only ``img_shape[:2]``
            (height, width) is used.
        color: Matplotlib colour for both markers and limbs.
        conf_threshold: A keypoint is drawn only if its confidence is strictly
            greater than this value; a limb only if both ends are (default 0.5).
        connections: Iterable of ``(i, j)`` row-index pairs to join. Defaults to
            the module-level ``skeleton_connections``. Pairs that point past
            ``len(pose)`` are skipped instead of raising IndexError.
    """
    if connections is None:
        connections = skeleton_connections
    pose = np.asarray(pose)
    if pose.size == 0:
        # Nothing to draw; indexing pose[:, 2] below would fail.
        return
    img_h, img_w = img_shape[:2]
    visible = pose[:, 2] > conf_threshold

    # A single marker-only plot for all visible keypoints instead of one
    # artist per point.
    if visible.any():
        plt.plot(pose[visible, 0] * img_w, pose[visible, 1] * img_h,
                 'o', color=color, markersize=5)

    num_points = len(pose)
    for p1, p2 in connections:
        if p1 >= num_points or p2 >= num_points:
            continue
        if visible[p1] and visible[p2]:
            plt.plot([pose[p1, 0] * img_w, pose[p2, 0] * img_w],
                     [pose[p1, 1] * img_h, pose[p2, 1] * img_h],
                     color=color, linewidth=2)

# Create the video with predictions and drawn skeletons.
def _load_rgb_image(image_path, size=(256, 256)):
    """Open ``image_path``, force 3-channel RGB and resize it to ``size``.

    The ``with`` block closes the file handle; ``convert``/``resize`` return
    new in-memory images, so the result remains usable afterwards. Converting
    to RGB keeps grayscale/RGBA files from producing a tensor whose channel
    count differs from the RGB input the models receive.
    """
    with Image.open(image_path) as img:
        return img.convert('RGB').resize(size)


def _render_skeleton_frame(image, pose_keypoints, temp_path, frame_size):
    """Draw both skeletons over ``image`` and return it as a cv2 frame of ``frame_size``.

    Raises:
        RuntimeError: if the rendered PNG cannot be read back.
    """
    image_array = np.array(image)
    fig = plt.figure(figsize=(5, 5))
    try:
        plt.imshow(image_array)
        # Rows 0-16 in red and 17-33 in blue (presumably one athlete each —
        # TODO confirm against the pose annotations).
        draw_skeleton(pose_keypoints[:17], image_array.shape, color='r')
        draw_skeleton(pose_keypoints[17:], image_array.shape, color='b')
        plt.axis('off')
        # Round-trip through a PNG so OpenCV receives the tightly cropped render.
        plt.savefig(temp_path, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)
    frame = cv2.imread(temp_path)
    if frame is None:
        raise RuntimeError(f"Could not read rendered frame from {temp_path}")
    return cv2.resize(frame, frame_size)


def _annotate_frame(frame, predicted_class, sequence_class):
    """Write the per-image label and, once available, the per-sequence label onto ``frame`` in place."""
    cv2.putText(frame, f"Image Class: {predicted_class}", (10, 20),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
    if sequence_class is not None:
        cv2.putText(frame, f"Sequence Class: {sequence_class}", (10, 50),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)


def create_video_with_predictions(image_dir, output_video_path, num_frames=300,
                                  seq_length=10, fps=10, frame_size=(512, 512)):
    """Render an MP4 of per-frame and per-sequence position predictions.

    A random window of ``num_frames`` consecutive entries (sorted filename
    order) is taken from ``image_dir``. For each image, ``hrnet_model`` output
    is reshaped to 34 ``(x, y, confidence)`` rows and drawn as two 17-point
    skeletons, and ``resnet_model`` predicts an image-level label. Every
    ``seq_length`` frames the buffered pose vectors form one non-overlapping
    window classified by ``lstm_model``; that label is shown on each frame
    from then on until the next window completes.

    Args:
        image_dir: Directory whose entries are all readable images.
        output_video_path: Destination path of the MP4 file.
        num_frames: Number of consecutive images to render (default 300).
        seq_length: Frames per LSTM window (default 10).
        fps: Frame rate of the output video (default 10).
        frame_size: ``(width, height)`` of each output frame (default 512x512).

    Returns:
        ``output_video_path`` once the video is written, or ``None`` when the
        directory holds fewer than ``num_frames`` entries.

    Raises:
        RuntimeError: if the video writer cannot be opened or a rendered frame
            cannot be read back.
    """
    import tempfile

    images = sorted(os.listdir(image_dir))
    if len(images) < num_frames:
        print(f"No hay suficientes imágenes en el directorio para seleccionar {num_frames} imágenes consecutivas.")
        return None

    start_index = random.randint(0, len(images) - num_frames)
    selected_images = images[start_index:start_index + num_frames]

    video_writer = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, frame_size)
    if not video_writer.isOpened():
        raise RuntimeError(f"Could not open video writer for {output_video_path}")

    sequence = []
    sequence_class = None  # no sequence label until the first window completes
    try:
        # A private temp dir avoids clobbering files in the CWD and is removed
        # automatically, even on error.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_path = os.path.join(tmp_dir, "temp_skeleton.png")
            for image_name in selected_images:
                image = _load_rgb_image(os.path.join(image_dir, image_name))

                model_input = transform(image).unsqueeze(0).to(device)
                with torch.no_grad():
                    pose_keypoints = hrnet_model(model_input).view(34, 3).cpu().numpy()
                    class_output = resnet_model(model_input)
                predicted_class = index_to_position[class_output.argmax(dim=1).item()]

                sequence.append(pose_keypoints.flatten())
                if len(sequence) == seq_length:
                    # np.stack avoids the slow list-of-ndarrays torch.tensor path.
                    sequence_tensor = torch.from_numpy(np.stack(sequence)).unsqueeze(0).to(device)
                    with torch.no_grad():
                        sequence_output = lstm_model(sequence_tensor)
                    sequence_class = index_to_position[sequence_output.argmax(dim=1).item()]
                    sequence.clear()

                frame = _render_skeleton_frame(image, pose_keypoints, temp_path, frame_size)
                _annotate_frame(frame, predicted_class, sequence_class)
                video_writer.write(frame)
    finally:
        video_writer.release()
    return output_video_path

# Generate the video.
output_video_path = './output_predictions_video_v2.mp4'
create_video_with_predictions(image_dir, output_video_path)
# The function returns early (after printing a warning) when there are too
# few images, so only report success if the output file actually exists.
# NOTE(review): a file left over from an earlier run would still count.
if os.path.isfile(output_video_path):
    print(f"Video creado y guardado en: {output_video_path}")
else:
    print(f"No se creó el video: {output_video_path}")


Video creado y guardado en: ./output_predictions_video_v2.mp4
