In [37]:
import numpy as np
from data_processing import process_data
from graph import Graph
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import random_split, DataLoader

In [38]:
data = np.load("dataset_padding_new.npy",allow_pickle=True)
dataset, label_encoder = process_data(data) # transfer our data to st-gcn data
size_total = len(dataset)
size_test = int(0.3 * size_total)
size_train = size_total - size_test
train_ds, test_ds = random_split(dataset, [size_train, size_test],
                                 generator=torch.Generator().manual_seed(42))
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
test_loader  = DataLoader(test_ds,  batch_size=32, shuffle=False)

In [39]:
graph = Graph(hop_size=2)
adjacency_matrix = graph.get_adjacency_matrix()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Number of unique connected joints: 50


In [40]:
class STGCNBlock(nn.Module):
    def __init__(self, in_channel, out_channel, A, stride = 1, residual = True):
        
        super().__init__()
        self.register_buffer("A", torch.tensor(A, dtype=torch.float32))
        self.gcn = nn.Conv2d(in_channel, out_channel, kernel_size=(1,1))
        self.tcn = nn.Sequential(
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=0.3),
            nn.Conv2d(out_channel, out_channel, kernel_size=(9, 1), 
                      stride=(stride,1), padding=(4,0)),
            nn.BatchNorm2d(out_channel)
        )
        

        if not residual:
            self.residual =None
        elif in_channel == out_channel and stride == 1:
            self.residual = nn.Identity() 
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=(stride,1)),
                nn.BatchNorm2d(out_channel)
            )

        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        A = self.A.to(x.device, non_blocking=True) 
        x_gcn = self.gcn(x)
        x_gcn = torch.einsum('nctv,kvw->nctw', x_gcn, A)

        res = 0 if self.residual is None else self.residual(x)
        x_tcn = self.tcn(x_gcn) + res

        return self.relu(x_tcn)


In [47]:
class STGCN(nn.Module):
    def __init__(self, A, num_classes):
        super().__init__()
        self.data_bn = nn.BatchNorm1d(50 * 4)  # 50 joints * 4 channels (x,y,z,c)

        self.layer1 = STGCNBlock(4, 32, A, residual=False)
        self.layer2 = STGCNBlock(32, 64, A)
        self.layer3 = STGCNBlock(64, 128, A, stride=2)
        # self.layer4 = STGCNBlock(128, 256, A, stride=2)

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        n, c, t, v = x.shape
        x = x.permute(0, 3, 1, 2).contiguous().view(n, v * c, t)
        x = self.data_bn(x)
        x = x.view(n, v, c, t).permute(0, 2, 3, 1).contiguous()

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        # x = self.layer4(x)

        x = self.pool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

In [48]:


model = STGCN(adjacency_matrix, num_classes=len(label_encoder.classes_)).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)

In [None]:
num_epochs = 100

for epoch in range(num_epochs):
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.long().to(device)

        optimizer.zero_grad()
        out = model(batch_x)
        loss = criterion(out, batch_y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        preds = out.argmax(dim=1)              # predicted class indices
        correct += (preds == batch_y).sum().item()
        total   += batch_y.size(0)

    acc = correct / total                      # epoch accuracy
    print(f"Epoch {epoch+1}/{num_epochs}  Loss: {running_loss:.4f}  Acc: {acc:.2%}")

    model.eval()
    test_correct, test_total = 0, 0
    with torch.no_grad():
        for x_test, y_test in test_loader:
            x_test, y_test = x_test.to(device), y_test.long().to(device)
            logits = model(x_test)
            preds  = logits.argmax(1)
            test_correct += (preds == y_test).sum().item()
            test_total   += y_test.size(0)

    acc_test = test_correct / test_total
    print(f"                >>> Test‑acc {acc_test:.2%}")
    model.train()            # switch back for next epoch


Epoch 1/100  Loss: 27.5815  Acc: 4.12%
                >>> Test‑acc 1.92%
Epoch 2/100  Loss: 27.0581  Acc: 3.70%
                >>> Test‑acc 1.92%
Epoch 3/100  Loss: 26.8140  Acc: 6.17%
                >>> Test‑acc 1.92%
Epoch 4/100  Loss: 26.6922  Acc: 4.94%
                >>> Test‑acc 1.92%
Epoch 5/100  Loss: 26.4827  Acc: 7.82%
                >>> Test‑acc 2.88%
Epoch 6/100  Loss: 26.2804  Acc: 4.94%
                >>> Test‑acc 3.85%
Epoch 7/100  Loss: 26.2277  Acc: 8.23%
                >>> Test‑acc 3.85%
Epoch 8/100  Loss: 26.0220  Acc: 10.70%
                >>> Test‑acc 3.85%
Epoch 9/100  Loss: 26.1302  Acc: 9.47%
                >>> Test‑acc 5.77%
Epoch 10/100  Loss: 25.9199  Acc: 8.64%
                >>> Test‑acc 3.85%
Epoch 11/100  Loss: 25.9527  Acc: 10.70%
                >>> Test‑acc 1.92%
Epoch 12/100  Loss: 25.7572  Acc: 8.64%
                >>> Test‑acc 2.88%
Epoch 13/100  Loss: 25.6311  Acc: 11.11%
                >>> Test‑acc 4.81%
Epoch 14/100  Loss: 25.6787  Ac