In [1]:
import os
import json

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F


In [2]:
# obtain skeletons + label

# Dataset root; each split is stored in its own sub-directory beneath it.
DSET_PATH = r"./"
TRAIN_PATH, TEST_PATH = (
    os.path.join(DSET_PATH, split_dir)
    for split_dir in ("train_skeletons", "test_skeletons")
)

def load_skel_data(path):
    """Load one dataset split from the directory ``path``.

    Reads ``skeletons_tensor.pt`` with ``torch.load`` and parses
    ``skeleton_annots.json`` from the same directory.

    Args:
        path: Directory that contains both files.

    Returns:
        A tuple ``(skeletons, metadata)``: the object stored in the ``.pt``
        file and the parsed JSON document.
    """
    # map_location="cpu" lets a file that was saved from a GPU session load on
    # a CPU-only machine; callers move batches to their device themselves.
    # NOTE(review): torch.load unpickles the file -- only use on trusted data.
    skeletons = torch.load(
        os.path.join(path, "skeletons_tensor.pt"), map_location="cpu"
    )
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(os.path.join(path, "skeleton_annots.json"), "r", encoding="utf-8") as f:
        metadata = json.load(f)
    return skeletons, metadata

# Read both splits (skeleton tensor + annotation metadata) from disk.
train_skels, train_metadata = load_skel_data(TRAIN_PATH)
test_skels, test_metadata = load_skel_data(TEST_PATH)

# Subtract one from every annotated label_id to make the labels 0-indexed.
train_labels, test_labels = (
    torch.tensor(
        [entry["label_id"] - 1 for entry in split_meta["samples"]],
        dtype=torch.long,
    )
    for split_meta in (train_metadata, test_metadata)
)

# Drop the first 100 training datapoints since they aren't labelled well.
train_skels, train_labels = train_skels[100:], train_labels[100:]

# Count of distinct label values remaining in the training split.
num_classes = train_labels.unique().numel()

# sanity check stuff
print("Labels shape:", train_labels.shape)
print("Skeletons shape:", train_skels.shape)
print("Num classes:", num_classes)

Labels shape: torch.Size([3781])
Skeletons shape: torch.Size([3781, 210, 21, 3])
Num classes: 14


In [3]:
batch_size = 32

# Pair every skeleton sequence with its class label.
train_ds = TensorDataset(train_skels, train_labels)
test_ds = TensorDataset(test_skels, test_labels)

# Only the training split is shuffled; the test split keeps its stored order.
train_loader = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_ds, batch_size=batch_size, shuffle=False)

print("Train samples:", len(train_ds))
print("Test samples:", len(test_ds))

Train samples: 3781
Test samples: 1556


In [4]:
class MediapipeHandGestureCNN(nn.Module):
    """Four-stage 2D CNN that maps a (channels, joints, time) grid to class logits.

    Every stage is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU followed by a
    max-pool with kernel (1, 2), i.e. pooling along the last axis only. Channel
    widths are 32, 64, 128 and 256. The result is globally average-pooled,
    passed through dropout (p=0.5) and a linear layer.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of input channels (default 3).
    """

    def __init__(self, num_classes: int, in_channels: int = 3):
        super().__init__()
        # Register conv{i}/bn{i} for i = 1..4 in stage order, so submodule
        # names and registration order are conv1, bn1, conv2, bn2, ...
        widths = (in_channels, 32, 64, 128, 256)
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"conv{idx}", nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            setattr(self, f"bn{idx}", nn.BatchNorm2d(c_out))

        self.dropout = nn.Dropout(p=0.5)
        self.fc = nn.Linear(widths[-1], num_classes)

        # (joints stay, time halves)
        self.pool_time = nn.MaxPool2d(kernel_size=(1, 2))

    def forward(self, x):
        """
        x: (B, C, H, W) = (batch, channels, joints, time)

        Returns the (B, num_classes) logits.
        """
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, bn in stages:
            # conv -> batch norm -> ReLU, then shrink the time axis.
            x = self.pool_time(F.relu(bn(conv(x))))

        # Collapse the remaining joints x time grid to one value per channel.
        pooled = F.adaptive_avg_pool2d(x, output_size=1)
        features = pooled.view(pooled.size(0), -1)

        return self.fc(self.dropout(features))

In [None]:
# Seed PyTorch's global RNG so weight initialisation (and DataLoader shuffling
# that draws from the global RNG) is repeatable across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

device = "cpu"
# Size the classifier head from the `num_classes` computed from the training
# labels instead of repeating the literal 14, so model and data stay in sync.
# NOTE(review): assumes labels are contiguous in 0..num_classes-1 -- confirm.
model = MediapipeHandGestureCNN(num_classes=num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
epochs = 15

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Puts ``model`` in training mode, then for every ``(x, y)`` batch: moves
    both tensors to ``device``, reorders ``x`` with ``permute(0, 3, 2, 1)``,
    and performs a forward pass, backward pass and optimizer step.

    Returns:
        ``(avg_loss, accuracy)``: the per-batch loss weighted by batch size and
        divided by the number of samples seen, and the fraction of samples
        whose argmax prediction equals the label.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no samples.
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        # Reverse the three trailing axes: (B, d1, d2, d3) -> (B, d3, d2, d1).
        inputs = inputs.permute(0, 3, 2, 1)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Metrics: weight the batch loss by batch size so the final value is a
        # per-sample average (assumes the criterion averages over the batch).
        loss_sum += batch_loss.item() * inputs.size(0)
        predicted = logits.argmax(dim=1)
        n_correct += (predicted == targets).sum().item()
        n_seen += targets.size(0)

    return loss_sum / n_seen, n_correct / n_seen


def evaluate(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` without updating any weights.

    Puts ``model`` in eval mode and, with gradient tracking disabled, runs
    every ``(x, y)`` batch through it after moving both tensors to ``device``
    and reordering ``x`` with ``permute(0, 3, 2, 1)``.

    Returns:
        ``(avg_loss, accuracy)``: the per-batch loss weighted by batch size and
        divided by the number of samples seen, and the fraction of samples
        whose argmax prediction equals the label.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no samples.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            # Reverse the three trailing axes: (B, d1, d2, d3) -> (B, d3, d2, d1).
            inputs = inputs.to(device).permute(0, 3, 2, 1)
            targets = targets.to(device)

            logits = model(inputs)
            batch_loss = criterion(logits, targets)

            # Weight by batch size so the final value is a per-sample average
            # (assumes the criterion averages over the batch).
            loss_sum += batch_loss.item() * inputs.size(0)
            predicted = logits.argmax(dim=1)
            n_correct += (predicted == targets).sum().item()
            n_seen += targets.size(0)

    return loss_sum / n_seen, n_correct / n_seen

# Per-epoch accuracy history, consumed by the plotting cell.
train_accs, test_accs = [], []

for epoch in range(1, epochs + 1):
    # One optimisation pass over the training split ...
    train_loss, train_acc = train_one_epoch(
        model, train_loader, optimizer, criterion, device
    )
    # ... then score the model on the test split.
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)

    train_accs.append(train_acc)
    test_accs.append(test_acc)

    print(
        f"Epoch {epoch:02d}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.3f}, "
        f"Test Loss={test_loss:.4f}, Test Acc={test_acc:.3f}"
    )

Epoch 01: Train Loss=1.7618, Train Acc=0.400, Test Loss=1.8588, Test Acc=0.389
Epoch 02: Train Loss=1.3985, Train Acc=0.490, Test Loss=1.5342, Test Acc=0.474
Epoch 03: Train Loss=1.1654, Train Acc=0.578, Test Loss=1.3143, Test Acc=0.525
Epoch 04: Train Loss=1.0201, Train Acc=0.630, Test Loss=1.1200, Test Acc=0.655
Epoch 05: Train Loss=0.8852, Train Acc=0.687, Test Loss=1.0844, Test Acc=0.645
Epoch 06: Train Loss=0.7511, Train Acc=0.742, Test Loss=1.0579, Test Acc=0.683
Epoch 07: Train Loss=0.6362, Train Acc=0.793, Test Loss=1.0920, Test Acc=0.679
Epoch 08: Train Loss=0.5339, Train Acc=0.820, Test Loss=1.0922, Test Acc=0.719
Epoch 09: Train Loss=0.5186, Train Acc=0.820, Test Loss=0.7163, Test Acc=0.801
Epoch 10: Train Loss=0.4485, Train Acc=0.848, Test Loss=0.6742, Test Acc=0.823


In [None]:
# Build the x-axis from the epochs actually recorded rather than from the
# configured `epochs` count: if training stopped early the two differ and
# plotting would fail on mismatched lengths. Using a separate name also keeps
# the integer `epochs` intact, so this cell can be re-run safely.
# Numbering starts at 1 to match the "Epoch NN" lines printed during training.
epoch_axis = np.arange(1, len(train_accs) + 1)

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(epoch_axis, train_accs, label='Train Accuracy')
ax.plot(epoch_axis, test_accs, label='Test Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.set_title('Training vs Testing Accuracy')
ax.legend()
ax.grid(True)
plt.show()