# OLD

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.loader import DataLoader
from torch_geometric.data import Dataset, Data
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, roc_auc_score, confusion_matrix
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# Define joints and edges
joints = [
    'PELVIS', 'SPINE_NAVAL', 'SPINE_CHEST', 'NECK', 'CLAVICLE_LEFT', 'SHOULDER_LEFT',
    'ELBOW_LEFT', 'WRIST_LEFT', 'HAND_LEFT', 'HANDTIP_LEFT', 'THUMB_LEFT',
    'CLAVICLE_RIGHT', 'SHOULDER_RIGHT', 'ELBOW_RIGHT', 'WRIST_RIGHT', 'HAND_RIGHT',
    'HANDTIP_RIGHT', 'THUMB_RIGHT', 'HIP_LEFT', 'KNEE_LEFT', 'ANKLE_LEFT',
    'FOOT_LEFT', 'HIP_RIGHT', 'KNEE_RIGHT', 'ANKLE_RIGHT', 'FOOT_RIGHT',
    'HEAD', 'NOSE', 'EYE_LEFT', 'EAR_LEFT', 'EYE_RIGHT', 'EAR_RIGHT'
]

edges = [
    ('PELVIS', 'SPINE_NAVAL'), ('SPINE_NAVAL', 'SPINE_CHEST'), ('SPINE_CHEST', 'NECK'),
    ('NECK', 'HEAD'), ('SPINE_CHEST', 'CLAVICLE_LEFT'), ('CLAVICLE_LEFT', 'SHOULDER_LEFT'),
    ('SHOULDER_LEFT', 'ELBOW_LEFT'), ('ELBOW_LEFT', 'WRIST_LEFT'), ('WRIST_LEFT', 'HAND_LEFT'),
    ('HAND_LEFT', 'HANDTIP_LEFT'), ('WRIST_LEFT', 'THUMB_LEFT'), ('SPINE_CHEST', 'CLAVICLE_RIGHT'),
    ('CLAVICLE_RIGHT', 'SHOULDER_RIGHT'), ('SHOULDER_RIGHT', 'ELBOW_RIGHT'), ('ELBOW_RIGHT', 'WRIST_RIGHT'),
    ('WRIST_RIGHT', 'HAND_RIGHT'), ('HAND_RIGHT', 'HANDTIP_RIGHT'), ('WRIST_RIGHT', 'THUMB_RIGHT'),
    ('PELVIS', 'HIP_LEFT'), ('HIP_LEFT', 'KNEE_LEFT'), ('KNEE_LEFT', 'ANKLE_LEFT'),
    ('ANKLE_LEFT', 'FOOT_LEFT'), ('PELVIS', 'HIP_RIGHT'), ('HIP_RIGHT', 'KNEE_RIGHT'),
    ('KNEE_RIGHT', 'ANKLE_RIGHT'), ('ANKLE_RIGHT', 'FOOT_RIGHT'),
    ('HEAD', 'NOSE'), ('HEAD', 'EYE_LEFT'), ('HEAD', 'EYE_RIGHT'), ('HEAD', 'EAR_LEFT'), ('HEAD', 'EAR_RIGHT')
]

joint_to_idx = {joint: idx for idx, joint in enumerate(joints)}
edge_index = torch.tensor(
    [[joint_to_idx[src], joint_to_idx[dst]] for src, dst in edges] +
    [[joint_to_idx[dst], joint_to_idx[src]] for src, dst in edges],
    dtype=torch.long
).t()

class SubjectSequenceDataset(Dataset):
    def __init__(self, dataframe):
        self.samples = []
        for _, patient_df in dataframe.groupby('patientID'):
            label = patient_df['QoR_class'].iloc[0]
            frames = []
            for _, row in patient_df.iterrows():
                node_features = [
                    [row[f'{joint}_X'], row[f'{joint}_Y'], row[f'{joint}_Z'], row['t_uniform']]
                    for joint in joints
                ]
                x = torch.tensor(node_features, dtype=torch.float)
                graph = Data(x=x, edge_index=edge_index)
                frames.append(graph)
            self.samples.append((frames, torch.tensor(label, dtype=torch.float)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

def collate_subjects(batch):
    sequences, labels = zip(*batch)
    return list(sequences), torch.tensor(labels, dtype=torch.float)

class TemporalGCNModel(nn.Module):
    def __init__(self, in_channels, gcn_hidden, lstm_hidden):
        super().__init__()
        self.gcn1 = GCNConv(in_channels, gcn_hidden)
        self.gcn2 = GCNConv(gcn_hidden, gcn_hidden)
        self.lstm = nn.LSTM(gcn_hidden, lstm_hidden, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(lstm_hidden * 2, 1)

    def forward(self, sequences):
        device = next(self.parameters()).device
        batch_embeddings = []

        for subject in sequences:
            frame_embeddings = []
            for graph in subject:
                if isinstance(graph, (tuple, list)):
                    graph = graph[0]
                if not isinstance(graph, Data):
                    continue
                if not hasattr(graph, 'x') or graph.x is None:
                    continue
                x = F.relu(self.gcn1(graph.x.to(device), graph.edge_index.to(device)))
                x = F.relu(self.gcn2(x, graph.edge_index.to(device)))
                pooled = global_mean_pool(x, torch.zeros(x.size(0), dtype=torch.long).to(device))
                frame_embeddings.append(pooled)
            
            if len(frame_embeddings) > 0:
                sequence_tensor = torch.stack(frame_embeddings)
                batch_embeddings.append(sequence_tensor)

        if len(batch_embeddings) == 0:
            return torch.empty((0,), device=device)  # Prevent model crash

        packed = torch.nn.utils.rnn.pad_sequence(batch_embeddings, batch_first=True)
        _, (h_n, _) = self.lstm(packed)
        h_final = torch.cat((h_n[-2], h_n[-1]), dim=-1)
        out = self.fc(h_final)
        return torch.sigmoid(out).squeeze()


csv_path = r'D:\Data\NYC\KINZ\Final_data_Balanced.csv'
df = pd.read_csv(csv_path)
df = df[df['walking_speed'] == "Fast"]

train_ids, test_ids = train_test_split(df['patientID'].unique(), test_size=0.2, random_state=42)
train_df = df[df['patientID'].isin(train_ids)]
test_df = df[df['patientID'].isin(test_ids)]

train_dataset = SubjectSequenceDataset(train_df)
test_dataset = SubjectSequenceDataset(test_df)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_subjects)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=collate_subjects)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TemporalGCNModel(in_channels=4, gcn_hidden=64, lstm_hidden=128).to(device)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

class EarlyStopping:
    def __init__(self, patience=3, min_delta=0.01):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.best_model = None

    def __call__(self, val_loss, model):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.best_model = model.state_dict()
        else:
            self.counter += 1
            if self.counter >= self.patience:
                model.load_state_dict(self.best_model)
                return True
        return False

early_stopping = EarlyStopping()

for epoch in range(1, 101):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for sequences, labels in train_loader:
        out = model(sequences)
        if out.numel() == 0:  # Skip invalid batch
            continue
        optimizer.zero_grad()
        loss = criterion(out.to(device), labels.to(device))
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * len(labels)
        preds = (out.detach().cpu() >= 0.5).float()
        correct += (preds == labels.cpu()).sum().item()
        total += len(labels)

    if total == 0:
        print(f"Epoch {epoch}, No valid training samples found. Skipping epoch.")
        continue

    train_acc = correct / total
    print(f"Epoch {epoch}, Train Loss: {total_loss/total:.4f}, Acc: {train_acc:.4f}")

    model.eval()
    total_loss, correct, total = 0, 0, 0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for sequences, labels in test_loader:
            out = model(sequences)
            if out.numel() == 0:
                continue
            loss = criterion(out.to(device), labels.to(device))
            total_loss += loss.item() * len(labels)
            preds = (out.detach().cpu() >= 0.5).float()
            correct += (preds == labels.cpu()).sum().item()
            total += len(labels)
            all_preds.extend(preds.numpy())
            all_labels.extend(labels.numpy())

    if total == 0:
        print(f"Epoch {epoch}, No valid validation samples found. Skipping epoch.")
        continue

    val_loss = total_loss / total
    val_acc = correct / total
    print(f"Epoch {epoch}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    if early_stopping(val_loss, model):
        print("Early stopping triggered.")
        break


accuracy = accuracy_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds)
recall = recall_score(all_labels, all_preds)
precision = precision_score(all_labels, all_preds)
roc_auc = roc_auc_score(all_labels, all_preds)
tn, fp, fn, tp = confusion_matrix(all_labels, all_preds).ravel()
specificity = tn / (tn + fp)

print(f'Final Accuracy: {accuracy:.4f}, F1: {f1:.4f}, Recall: {recall:.4f}, Precision: {precision:.4f}, Specificity: {specificity:.4f}, AUC: {roc_auc:.4f}')


Epoch 1, No valid training samples found. Skipping epoch.
Epoch 2, No valid training samples found. Skipping epoch.
Epoch 3, No valid training samples found. Skipping epoch.
Epoch 4, No valid training samples found. Skipping epoch.
Epoch 5, No valid training samples found. Skipping epoch.
Epoch 6, No valid training samples found. Skipping epoch.
Epoch 7, No valid training samples found. Skipping epoch.
Epoch 8, No valid training samples found. Skipping epoch.
Epoch 9, No valid training samples found. Skipping epoch.
Epoch 10, No valid training samples found. Skipping epoch.
Epoch 11, No valid training samples found. Skipping epoch.
Epoch 12, No valid training samples found. Skipping epoch.
Epoch 13, No valid training samples found. Skipping epoch.
Epoch 14, No valid training samples found. Skipping epoch.
Epoch 15, No valid training samples found. Skipping epoch.
Epoch 16, No valid training samples found. Skipping epoch.
Epoch 17, No valid training samples found. Skipping epoch.
Epoch 

NameError: name 'all_labels' is not defined