In [63]:
import sys
import os
sys.path.append(os.path.abspath(".."))

In [64]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as gnn
from core.dataset import YogaDataset
from torch.utils.data import DataLoader
from torch_geometric.data import Data
import torch.optim as optim


In [66]:
# ======================== MÔ HÌNH GCN ========================
import torch_geometric.nn as gnn

class YogaGCN(nn.Module):
    def __init__(self, in_channels=3, hidden_dim=64, num_classes=4):
        super(YogaGCN, self).__init__()
        self.conv1 = gnn.GCNConv(in_channels, hidden_dim)
        self.conv2 = gnn.GCNConv(hidden_dim, hidden_dim)
        self.conv3 = gnn.GCNConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = self.conv3(x, edge_index).relu()

        # Pooling giữ đúng batch_size
        x = gnn.global_mean_pool(x, batch)  # (batch_size, hidden_dim)

        x = self.fc(x)  # (batch_size, num_classes)
        return x


# ======================== EDGE INDEX (Mediapipe) ========================
def get_edge_index():
    """
    Trả về ma trận kề (edge_index) cho 33 keypoints của Mediapipe.
    """
    edges = [
        (0, 1), (1, 2), (2, 3), (3, 7),  # Tay trái
        (0, 4), (4, 5), (5, 6), (6, 8),  # Tay phải
        (9, 10), (11, 12),  # Hông
        (11, 13), (13, 15), (15, 17), (15, 19), (15, 21),  # Chân trái
        (12, 14), (14, 16), (16, 18), (16, 20), (16, 22),  # Chân phải
        (11, 23), (12, 24), (23, 24),  # Kết nối hông
        (23, 25), (25, 27), (27, 29), (29, 31),  # Chân trái
        (24, 26), (26, 28), (28, 30), (30, 32)   # Chân phải
    ]
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()  # (2, num_edges)
    return edge_index

# ======================== HÀM TRAINING ========================
def train_gcn(model, train_loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    edge_index = get_edge_index().to(device)  # Edge index không thay đổi

    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)

        optimizer.zero_grad()

        # Giữ đúng batch size
        batch_size, num_frames, num_keypoints, keypoint_dim = X_batch.shape
        X_batch = X_batch.view(batch_size * num_frames * num_keypoints, keypoint_dim)  # (13200, 3)

        # Tạo batch index đúng (1 batch cho mỗi video)
        batch = torch.arange(batch_size, device=device).repeat_interleave(num_frames * num_keypoints)

        # Đưa vào mô hình GCN
        outputs = model(X_batch, edge_index, batch)  # (batch_size, num_classes)

        # Kiểm tra outputs.shape
#         print(f"DEBUG - outputs shape: {outputs.shape}, y_batch shape: {y_batch.shape}")

        # Tính loss
        loss = criterion(outputs, y_batch.long())  # Đảm bảo y_batch có dtype phù hợp
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == y_batch).sum().item()
        total += y_batch.size(0)

    acc = 100 * correct / total
    return total_loss / len(train_loader), acc

# ======================== TRAINING ========================
# Dataset và DataLoader
json_folder = "../data/keypoints/public_data"
dataset = YogaDataset(json_folder, max_frames=100)
batch_size = 4
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Khởi tạo mô hình và optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = YogaGCN(in_channels=3, hidden_dim=64, num_classes=4).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Train thử
num_epochs = 30
for epoch in range(num_epochs):
    loss, acc = train_gcn(model, dataloader, optimizer, criterion, device)
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {loss:.4f} - Accuracy: {acc:.2f}%")


Label map: {'Garland_Pose': 0, 'Happy_Baby_Pose': 1, 'Head_To_Knee_Pose': 2, 'Lunge_Pose': 3}
Epoch 1/30 - Loss: 1.3847 - Accuracy: 22.68%
Epoch 2/30 - Loss: 1.3686 - Accuracy: 43.30%
Epoch 3/30 - Loss: 1.3510 - Accuracy: 51.55%
Epoch 4/30 - Loss: 1.3122 - Accuracy: 45.36%
Epoch 5/30 - Loss: 1.2323 - Accuracy: 62.89%
Epoch 6/30 - Loss: 1.1451 - Accuracy: 62.89%
Epoch 7/30 - Loss: 1.0249 - Accuracy: 62.89%
Epoch 8/30 - Loss: 0.9020 - Accuracy: 70.10%
Epoch 9/30 - Loss: 0.7613 - Accuracy: 81.44%
Epoch 10/30 - Loss: 0.7253 - Accuracy: 78.35%
Epoch 11/30 - Loss: 0.6904 - Accuracy: 71.13%
Epoch 12/30 - Loss: 0.6374 - Accuracy: 81.44%
Epoch 13/30 - Loss: 0.6349 - Accuracy: 79.38%
Epoch 14/30 - Loss: 0.6104 - Accuracy: 78.35%
Epoch 15/30 - Loss: 0.5540 - Accuracy: 80.41%
Epoch 16/30 - Loss: 0.5277 - Accuracy: 80.41%
Epoch 17/30 - Loss: 0.5224 - Accuracy: 82.47%
Epoch 18/30 - Loss: 0.5803 - Accuracy: 80.41%
Epoch 19/30 - Loss: 0.5190 - Accuracy: 84.54%
Epoch 20/30 - Loss: 0.5523 - Accuracy: 82