In [1]:
import os
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from torch_geometric.nn import GCNConv
import torch.optim as optim
from torch_geometric.data import Data

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from Skeleton_EMG_Dataset import SkeletonEMGDataset

In [2]:
# dataset = SkeletonEMGDataset(csv_file="emg_skel.csv", window_size=10)

# skel, label = dataset.__getitem__(12)
# print(skel[0])

In [33]:
class Skeleton3DConvNet(nn.Module):
    def __init__(self):
        super(Skeleton3DConvNet, self).__init__()
        
        # Define the layers of the network
        self.conv1 = nn.Conv3d(in_channels=1, out_channels=32, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.bn1 = nn.BatchNorm3d(32)
        self.relu1 = nn.ReLU(inplace=True)
        
        self.conv2 = nn.Conv3d(in_channels=32, out_channels=64, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.bn2 = nn.BatchNorm3d(64)
        self.relu2 = nn.ReLU(inplace=True)
        
        self.conv3 = nn.Conv3d(in_channels=64, out_channels=128, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.bn3 = nn.BatchNorm3d(128)
        self.relu3 = nn.ReLU(inplace=True)
        
        self.pool = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
        
        # 15360
        self.fc1 = nn.LazyLinear(4096)
        self.relu4 = nn.ReLU(inplace=True)
        
        self.fc2 = nn.LazyLinear(1024)
        self.relu5 = nn.ReLU(inplace=True)
        
        self.fc2 = nn.LazyLinear(256)
        self.relu5 = nn.ReLU(inplace=True)
        
        self.fc3 = nn.LazyLinear(8)
        
    def forward(self, x):
        # Input shape: (batch_size, sequence_length, num_joints, 3)
        x = x.unsqueeze(1)  # Add a channel dimension
        
        # Apply convolutional layers
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)
        
        # Apply pooling layer
        x = self.pool(x)
        
        # Flatten the output for the fully connected layers
        x = x.view(x.size(0), -1)
        
        # Apply fully connected layers
        x = self.fc1(x)
        x = self.relu4(x)
        
        x = self.fc2(x)
        x = self.relu5(x)
        
        x = self.fc3(x)
        
        return x


def create_edge_index():
    coco_edges = [
        (0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 6), 
        (5, 7), (7, 9), (6, 8), (8, 10), (5, 11), (6, 12), 
        (11, 12), (11, 13), (13, 15), (12, 14), (14, 16)
    ]
    row, col = torch.tensor([(i, j) for (i, j) in coco_edges]).t()
    return torch.stack([row, col], dim=0)


class GCNLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCNLayer, self).__init__()
        self.gcn = GCNConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        return self.gcn(x, edge_index)
    
class GCNNSkeleton(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, output_channels, num_keypoints):
        super(GCNNSkeleton, self).__init__()
        
        self.num_keypoints = num_keypoints
        
        self.gcn1 = GCNConv(num_features, hidden_channels)
        self.gcn2 = GCNConv(hidden_channels, hidden_channels)
        self.gcn3 = GCNConv(hidden_channels, hidden_channels)
        self.gcn4 = GCNConv(hidden_channels, output_channels)
        self.relu = nn.ReLU()
        
        
        self.fc = nn.Sequential(
            nn.LazyLinear(1024),
            nn.ReLU(),
            nn.LazyLinear(512),
            nn.ReLU(),
            nn.LazyLinear(128),
            nn.ReLU(),
            nn.LazyLinear(8),
        )

    def forward(self, x, edge_index):
        x = self.gcn1(x, edge_index)
        x = self.relu(x)
        x = self.gcn2(x, edge_index)
        x = self.relu(x)
        x = self.gcn3(x, edge_index)
        x = self.relu(x)
        x = self.gcn4(x, edge_index)
        
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        
        return x

In [34]:
# Create dataset and data loader
dataset = SkeletonEMGDataset(csv_file="emg_skel.csv", window_size=10)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2)

edge_index = torch.tensor([
    [0, 1], [0, 2], [1, 3],
    [2, 4], [5, 6], [5, 7],
    [6, 8], [7, 9], [8, 10],
    [11, 12], [11, 13], [12, 14],
    [13, 15], [14, 16], [5, 6], 
    [11, 12],
], dtype=torch.long).t().contiguous().to(device)

print(dataset.cumulative_num_samples)

[1789 3578 3882 5671 7460]


In [35]:
# Create model and optimizer
input_channels = 3
hidden_channels = 16
output_channels = 32
num_keypoints = 17

model = GCNNSkeleton(input_channels, hidden_channels, output_channels, num_keypoints).to(device)

# model = Skeleton3DConvNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Define loss function
criterion = nn.MSELoss()



In [36]:
x = None
for i, b in dataloader:
    x = i[0].view(1, dataset.window_size, 17, 3).to(device)
    break
    
model(x, edge_index)

tensor([[-0.0742,  0.0545,  0.0114, -0.0569,  0.0626, -0.0581,  0.0722, -0.0531]],
       device='cuda:0', grad_fn=<AddmmBackward0>)

In [37]:
# Train model
for epoch in range(30):
    running_loss = 0.0
    
    for data in tqdm(dataloader):
        
        # Get inputs and labels
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = model(inputs, edge_index)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        
    print(f"Loss: {running_loss / len(dataloader)}")
    running_loss = 0.0

  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.015403517983599097


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.013779491520462891


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.013738997415918061


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.013710421287160143


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.013675392048162783


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.013641557576628322


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.013573401249372043


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.013413418546064287


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.013062400719485221


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.012522246131402815


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.012180238795013


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.011887759097620972


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.011597982544101711


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.011512998674606156


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.011417410161314357


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.011369898405849423


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.011255224896037681


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.011191919254950987


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.011050021801239405


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.011025519046582218


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.011050070349413615


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.010939356409267992


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.010896693875328598


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.010775449948433118


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.010764812967047477


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.010578376105707934


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.01057279157714966


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.01045784382467978


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.010415021644539049


  0%|          | 0/117 [00:00<?, ?it/s]

Loss: 0.01023094212779632


In [12]:
for data, emg in dataloader:
    
    data = data[0].view(1, dataset.window_size, 17, 3).to(device)
    emg = emg[0].view(1, 8).to(device)
    
    with torch.no_grad():
        outputs = model(data, edge_index)
        
    loss = criterion(outputs, labels)
        
    print(emg, "\n", outputs, "\n", loss.item())
    
    break

tensor([[0.0074, 0.2218, 0.2056, 0.0237, 0.0940, 0.0411, 0.2152, 0.2272]],
       device='cuda:0') 
 tensor([[0.0594, 0.1774, 0.1668, 0.2104, 0.1406, 0.1130, 0.2412, 0.2231]],
       device='cuda:0') 
 0.014454230666160583
