# In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix, f1_score
import itertools

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using:", device)

# Class names, in the order used for the report rows and the
# confusion-matrix tick labels further down.
LABELS = ["walk", "stand", "sit", "bend"]


class PoseDataset(Dataset):
    """In-memory dataset pairing pose arrays with integer-encoded labels.

    Indexing returns the tuple ``(X[idx], Y[idx])`` of a float32 sample
    tensor and a long class-index tensor.
    """

    def __init__(self, X, Y, label_to_idx=None):
        """Convert the raw arrays to tensors and encode the labels.

        Args:
            X: numpy array of samples, e.g. shape (N, 32, 4); cast to float32.
            Y: numpy array or list of N labels (e.g. "walk", "sit", ...).
            label_to_idx: optional mapping from label to class index. When
                omitted, the mapping is built from the *sorted unique values
                of this Y*, so the indices are alphabetical and depend on
                which labels this particular split contains. Pass one shared
                mapping to keep several splits (and any list of class names
                used for reporting) consistent.

        Raises:
            ValueError: if ``X`` is not a numpy array, or ``X`` and ``Y``
                have different lengths.
            KeyError: if ``Y`` contains a label missing from ``label_to_idx``.
        """
        if not isinstance(X, np.ndarray):
            raise ValueError("X must be a numpy array")
        X = X.astype(np.float32)

        # Work on plain Python values so that arrays and lists are handled
        # the same way (hashable, sortable, usable as dict keys).
        labels = np.asarray(Y).tolist()
        if len(labels) != len(X):
            raise ValueError(
                f"X and Y must have the same length, got {len(X)} and {len(labels)}"
            )

        if label_to_idx is None:
            # Build the label mapping automatically.
            label_to_idx = {lb: i for i, lb in enumerate(sorted(set(labels)))}

        # Keep the mapping so callers can recover class names from indices.
        self.label_to_idx = label_to_idx

        # Fail with an explicit message instead of a bare KeyError on lookup.
        unknown = set(labels).difference(label_to_idx)
        if unknown:
            raise KeyError(
                f"labels missing from label_to_idx: {sorted(unknown, key=repr)}"
            )

        # Encode the string labels as integers.
        encoded = np.array([label_to_idx[lb] for lb in labels], dtype=np.int64)

        # Convert to tensors.
        self.X = torch.tensor(X, dtype=torch.float32)       # (N, ...)
        self.Y = torch.tensor(encoded, dtype=torch.long)    # (N,)

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(sample, class_index)`` pair at position ``idx``."""
        return self.X[idx], self.Y[idx]


class STGCN(nn.Module):
    """Classifier for inputs of shape (N, num_nodes, in_channels).

    NOTE: despite the name, no adjacency matrix and no temporal convolution
    are used here. The 1x1 convolutions apply the same small MLP to every
    node independently; the nodes are only mixed by the final linear layer.
    """

    def __init__(self, num_class=4, in_channels=4, num_nodes=32):
        """Build the feature extractor and the classification head.

        Args:
            num_class: number of output logits.
            in_channels: number of features per node (last input axis).
            num_nodes: size of the second input axis; it fixes the input
                width of the final linear layer (256 * num_nodes).
        """
        super().__init__()
        self.gcn = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=(1, 1)),
            nn.ReLU(),
        )
        self.fc = nn.Linear(256 * num_nodes, num_class)

    def forward(self, x):
        """Return logits of shape (N, num_class) for x of shape (N, num_nodes, in_channels)."""
        x = x.permute(0, 2, 1)   # -> (N, in_channels, num_nodes)
        x = x.unsqueeze(-1)      # -> (N, in_channels, num_nodes, 1)
        x = self.gcn(x)          # -> (N, 256, num_nodes, 1)
        # flatten() also works when the conv output is not contiguous,
        # where view() would raise.
        x = x.flatten(1)         # -> (N, 256 * num_nodes)
        return self.fc(x)



# Load stgcn_dataset

# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which
# can execute arbitrary code -- only load archives from a trusted source.
# The context manager closes the .npz archive once the arrays have been read.
with np.load("data/stgcn_dataset.npz", allow_pickle=True) as data:
    trainX, trainY = data["trainX"], data["trainY"]
    valX, valY     = data["valX"],   data["valY"]
    testX, testY   = data["testX"],  data["testY"]

print("Shapes:")
print("Train:", trainX.shape, trainY.shape)
print("Val:  ", valX.shape, valY.shape)
print("Test: ", testX.shape, testY.shape)

# One label -> index mapping shared by every split, in LABELS order.
# Without it each PoseDataset builds its own alphabetical mapping, which
# (a) does not match the LABELS order used when reporting results and
# (b) can differ between splits if one of them lacks a class.
label_to_idx = {name: i for i, name in enumerate(LABELS)}

for split_name, split_labels in (("train", trainY), ("val", valY), ("test", testY)):
    unknown = set(split_labels.tolist()) - set(LABELS)
    if unknown:
        raise ValueError(
            f"{split_name} split contains labels not in LABELS: "
            f"{sorted(unknown, key=repr)}"
        )

train_loader = DataLoader(
    PoseDataset(trainX, trainY, label_to_idx=label_to_idx), batch_size=64, shuffle=True
)
val_loader = DataLoader(
    PoseDataset(valX, valY, label_to_idx=label_to_idx), batch_size=64
)
test_loader = DataLoader(
    PoseDataset(testX, testY, label_to_idx=label_to_idx), batch_size=64
)



# Training Utils

def evaluate(model, loader, criterion, device=None):
    """Compute the mean batch loss and the accuracy of ``model`` over ``loader``.

    The model is switched to eval mode (and left in that mode) and gradient
    tracking is disabled for the whole pass.

    Args:
        model: module mapping a batch of inputs to class logits.
        loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss callable applied to ``(logits, targets)``.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter, so the function does not rely on a
            module-level ``device`` variable. For a model without parameters
            the batches are left where they are.

    Returns:
        Tuple ``(loss, accuracy)``: the per-batch losses averaged over the
        number of batches (not weighted by batch size), and the fraction of
        samples whose argmax prediction equals the target.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    if device is None:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device

    num_batches, total, correct, loss_sum = 0, 0, 0, 0.0
    with torch.no_grad():
        for X, Y in loader:
            if device is not None:
                X, Y = X.to(device), Y.to(device)
            out = model(X)
            loss_sum += criterion(out, Y).item()
            correct += (out.argmax(1) == Y).sum().item()
            total += len(Y)
            # Count batches here so any iterable works, not only sized ones.
            num_batches += 1

    if total == 0:
        raise ValueError("loader yielded no samples to evaluate")
    return loss_sum / num_batches, correct / total



#  Train

# Model, loss and optimiser.
model = STGCN(num_class=4).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Per-epoch history; each list receives one entry per epoch and is read by
# the curve-plotting code at the end of the script.
train_loss_list = []
val_loss_list = []
train_acc_list = []
val_acc_list = []

EPOCH = 20

print("\n===== Training =====")
for epoch in range(1, EPOCH+1):
    model.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0

    for batch_x, batch_y in train_loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        # One optimisation step on this mini-batch.
        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        # Running statistics for the epoch summary.
        running_loss += batch_loss.item()
        n_correct += (logits.argmax(1) == batch_y).sum().item()
        n_seen += len(batch_y)

    # Validation pass with the weights as they stand after this epoch.
    val_loss, val_acc = evaluate(model, val_loader, criterion)

    # Training loss is the mean of the per-batch losses.
    train_loss_list.append(running_loss/len(train_loader))
    val_loss_list.append(val_loss)
    train_acc_list.append(n_correct/n_seen)
    val_acc_list.append(val_acc)

    print(f"Epoch {epoch} | Train Loss {train_loss_list[-1]:.4f} Acc {train_acc_list[-1]:.4f} "
          f"| Val Loss {val_loss:.4f} Acc {val_acc:.4f}")



#  Test Evaluation

model.eval()
all_pred, all_true = [], []

# Collect predictions and ground truth for the whole test split.
with torch.no_grad():
    for X, Y in test_loader:
        X = X.to(device)
        pred = model(X).argmax(1).cpu().numpy()
        all_pred.extend(pred)
        all_true.extend(Y.numpy())

# Take the class names from the mapping that actually encoded the test
# labels, ordered by class index. Using a separately maintained name list
# mislabels the report rows whenever its order differs from the encoding
# (PoseDataset's default mapping is alphabetical).
test_label_to_idx = test_loader.dataset.label_to_idx
class_names = sorted(test_label_to_idx, key=test_label_to_idx.get)
class_ids = [test_label_to_idx[name] for name in class_names]

print("\n===== TEST RESULTS =====")
# Passing labels= keeps the rows aligned with target_names even when a class
# never occurs in the test targets or predictions.
print(classification_report(all_true, all_pred, labels=class_ids, target_names=class_names))
print("Test Macro F1:", f1_score(all_true, all_pred, average="macro"))



#  Confusion Matrix

# Class names ordered by the index the test dataset assigned to them, so the
# tick labels match the rows/columns of the matrix (PoseDataset's default
# mapping is alphabetical, which need not match a hand-written name list).
test_label_to_idx = test_loader.dataset.label_to_idx
class_names = sorted(test_label_to_idx, key=test_label_to_idx.get)
class_ids = [test_label_to_idx[name] for name in class_names]
n_classes = len(class_names)

# labels= fixes the matrix size; without it a class absent from both targets
# and predictions shrinks the matrix and the annotation loop indexes past it.
cm = confusion_matrix(all_true, all_pred, labels=class_ids)

plt.figure(figsize=(6,5))
plt.imshow(cm, cmap="Blues")
plt.title("Confusion Matrix")
plt.colorbar()
plt.xticks(range(n_classes), class_names)
plt.yticks(range(n_classes), class_names)
# scikit-learn puts true classes on rows and predicted classes on columns.
plt.xlabel("Predicted")
plt.ylabel("True")

# Write each count into its cell (x = column j, y = row i).
for i, j in itertools.product(range(n_classes), repeat=2):
    plt.text(j, i, cm[i, j], ha='center', va='center')

plt.show()



#  Loss & Accuracy Curves

# Two side-by-side panels: losses on the left, accuracies on the right.
# Each panel is described as (title, ((series, legend label), ...)).
curve_panels = (
    ("Loss Curve", ((train_loss_list, "Train Loss"), (val_loss_list, "Val Loss"))),
    ("Accuracy Curve", ((train_acc_list, "Train Acc"), (val_acc_list, "Val Acc"))),
)

plt.figure(figsize=(12,5))
for panel_pos, (panel_title, panel_curves) in enumerate(curve_panels, start=1):
    plt.subplot(1, 2, panel_pos)
    for series, legend_label in panel_curves:
        plt.plot(series, label=legend_label)
    plt.legend()
    plt.title(panel_title)
plt.show()