In [None]:
# Mount Google Drive so the project files under MyDrive are reachable from
# this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# Work from the project root: the ./Data/... paths used in later cells are
# relative and resolve against this directory; the local package imports
# (ST_GCN, WO_GMA, SkateFormer) presumably rely on it as well.
%cd /content/drive/MyDrive/IndividualProject/

In [None]:
import os
import glob
import numpy as np
import pickle as pkl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import (precision_score, recall_score, f1_score, 
                             accuracy_score, confusion_matrix, classification_report)

import matplotlib.pyplot as plt
import seaborn as sns

from ST_GCN.st_gcn import Model as STGCN
from WO_GMA.wogma import WOGMA
from SkateFormer.SkateFormer import SkateFormer_ as SkateFormer

In [None]:
def compute_effective_number_weights(labels, beta=0.9999, normalize=True):
    """
    Compute class weights based on the effective number of samples.

    For a class with ``n`` samples the weight is ``(1 - beta) / (1 - beta**n)``
    (the reciprocal of the effective number of samples), so rarer classes
    receive larger weights.

    Args:
        labels (array-like): Array or list of class labels.
        beta (float): Hyperparameter in [0, 1). Typically very close to 1.
        normalize (bool): If True (default, the original behaviour) rescale the
            weights so that they sum to 1. If False, return the raw weights.

    Returns:
        dict: Mapping from class label (as produced by ``np.unique``) to a
        Python float weight. Empty ``labels`` yield an empty dict.

    Raises:
        ValueError: If ``beta`` is outside [0, 1). (beta == 1 previously ended
            in a ZeroDivisionError during normalization.)
    """
    if not 0.0 <= beta < 1.0:
        raise ValueError(f"beta must be in [0, 1), got {beta}")

    classes, counts = np.unique(labels, return_counts=True)
    weights = {}
    for cls, count in zip(classes, counts):
        effective_num = 1.0 - beta ** count
        # Defensive: cannot be 0 for a validated beta and count >= 1.
        weights[cls] = (1.0 - beta) / effective_num if effective_num != 0 else 0.0

    weight_sum = sum(weights.values())
    if normalize and weight_sum > 0:
        return {cls: float(w / weight_sum) for cls, w in weights.items()}
    return {cls: float(w) for cls, w in weights.items()}


In [None]:
def plot_confuse_matrix(y_true, y_pred, classes):
    """
    Plot a confusion matrix as an annotated seaborn heatmap.

    Args:
        y_true (array-like): Ground-truth labels.
        y_pred (array-like): Predicted labels.
        classes (sequence): Class labels in the order rows/columns should
            appear. They are passed to ``confusion_matrix(labels=...)`` so the
            matrix always has ``len(classes)`` rows and columns and lines up
            with the tick labels, even when a class never occurs in
            ``y_true``/``y_pred`` (previously the matrix shrank and the ticks
            no longer matched).
    """
    cm = confusion_matrix(y_true=y_true, y_pred=y_pred, labels=classes)
    fig, ax = plt.subplots(figsize=(5, 4))
    # fmt='d': cells are integer counts; seaborn's default '.2g' would show
    # e.g. 120 as '1.2e+02'.
    sns.heatmap(cm, annot=True, fmt='d', cmap="Blues",
                xticklabels=classes, yticklabels=classes, ax=ax)
    ax.set_xlabel('Predicted label')
    ax.set_ylabel('True label')
    ax.set_title('Confusion Matrix')
    plt.show()

In [None]:
class SkeletalDataset(Dataset):
    """
    A custom Dataset of MIL bags of 2D skeletal segments.

    Each sample is one bag (in this notebook, one participant's segments for
    one body side):
      - bag: tensor whose first dimension indexes the bag's segments; the
        author's note gives (num_segments, 6, 100, 5, 1) — TODO confirm
        against the saved .npy files.
      - label: 1-D float tensor holding the raw label row. The training code
        reads columns 3..8, so each row needs at least 9 entries.
    """
    def __init__(self, bag_data_list, labels, transform=None):
        """
        Args:
            bag_data_list (list): List where each element is a bag (numpy array
                or torch.Tensor).
            labels (list): One array-like label row per bag (at least 9 fields
                for the MIL training code).
            transform (callable, optional): Applied to the bag tensor in
                ``__getitem__``.
        """
        assert len(bag_data_list) == len(labels), "Data and labels must have the same length."
        self.bag_data_list = bag_data_list
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.bag_data_list)

    def __getitem__(self, idx):
        """Return ``(bag, label)``; non-tensor bags are converted to float tensors."""
        bag = self.bag_data_list[idx]
        # Tensor bags are passed through unchanged (dtype included).
        if not isinstance(bag, torch.Tensor):
            bag = torch.tensor(bag, dtype=torch.float)

        label = self.labels[idx]
        if isinstance(label, torch.Tensor):
            # torch.tensor(<tensor>) emits a UserWarning recommending
            # clone().detach(); this produces the same float copy quietly.
            label = label.detach().clone().to(torch.float)
        else:
            label = torch.tensor(label, dtype=torch.float)

        if self.transform:
            bag = self.transform(bag)
        return bag, label

def collate_fn(batch):
    """
    Collate function for a DataLoader with ``batch_size=1``.

    The single (bag, label) pair is returned without stacking (presumably
    because bags can hold different numbers of segments). A 1-D label gets a
    leading batch dimension so callers can index it as ``label[:, k]``.

    Args:
        batch (list): List of (bag, label) tuples from the Dataset.

    Returns:
        tuple: (bag, label), label with shape (1, num_label_fields).

    Raises:
        ValueError: If the batch does not hold exactly one sample. Previously,
            every sample after the first was silently discarded.
    """
    if len(batch) != 1:
        raise ValueError(
            f"collate_fn expects batch_size=1, got a batch of {len(batch)} samples."
        )
    bag, label = batch[0]
    if label.dim() == 1:
        label = label.unsqueeze(0)
    return bag, label

In [None]:
def load_pretrained_model(model_type='st_gcn', checkpoint_path=None, device='cpu'):
    """
    Build a skeleton backbone and load pretrained weights into it.

    Args:
        model_type (str): 'st_gcn', 'ctr_gcn' or 'skateformer'.
        checkpoint_path (str): Path to a state_dict saved with ``torch.save``.
            Required despite the ``None`` default (kept for signature
            compatibility).
        device (str | torch.device): Where the checkpoint is mapped and the
            model is placed.

    Returns:
        torch.nn.Module: The backbone on ``device``, in eval mode.

    Raises:
        ValueError: If ``checkpoint_path`` is None or ``model_type`` is unknown.

    The ST-GCN and CTR-GCN branches read notebook-level constants NUM_CLASS,
    IN_CHANNEL and DROPOUT, which are not defined in the cells shown here
    (presumably in a config cell — verify before running).
    """
    if checkpoint_path is None:
        raise ValueError("checkpoint_path must point to a saved state_dict.")

    if model_type == 'st_gcn':
        model = STGCN(
            num_class=NUM_CLASS,
            in_channels=IN_CHANNEL,
            graph_args={'strategy': 'spatial'},
            dropout=DROPOUT,
            edge_importance_weighting=True
        )
    elif model_type == 'ctr_gcn':
        # NOTE(review): CTR_GCN is not imported in the imports cell, so this
        # branch raises NameError until the import is added.
        model = CTR_GCN(
            graph_args={'strategy': 'spatial'},
            drop_out=DROPOUT,
            in_channels=IN_CHANNEL,
            num_class=NUM_CLASS
        )
    elif model_type == 'skateformer':
        # The imports cell binds SkateFormer_ as `SkateFormer`; the previous
        # spelling `Skateformer` was an undefined name.
        model = SkateFormer(
            num_frames=64,
        )
    else:
        raise ValueError(f"Invalid model type: {model_type}")

    # Load checkpoint
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model

In [None]:
def train_mil(epoch, model, loader, device, optimizer, scheduler=None, disease='dystonia', verbose=True):
    """
    Run one epoch of MIL training. Return (avg_loss, accuracy).
    
    label_type: 'dystonia' => label[:,3], 'choreoathetosis' => label[:,4].
    """
    model.train()
    if scheduler and verbose:
        print("Learning Rate:", scheduler.get_last_lr()[0])
    else:
        # If no scheduler was given, we can directly query the optimizer:
        for param_group in optimizer.param_groups:
            if verbose:
                print("Learning Rate:", param_group['lr'])
    
    running_loss = 0.0
    total_count = 0
    correct = 0

    for bag, label in loader:
        bag = bag.to(device)
        if disease == 'dystonia_duration':
            video_label = label[:, 3].long().item()
        elif disease == 'dystonia_amplitude':
            video_label = label[:, 4].long().item()
        elif disease == 'choreoathetosis_duration':
            video_label = label[:, 5].long().item()
        elif disease == 'choreoathetosis_amplitude':
            video_label = label[:, 6].long().item()
        elif disease == 'dystonia':
            video_label = label[:, 7].long().item()
        elif disease == 'choreoathetosis':
            video_label = label[:, 8].long().item()
        else:
            raise ValueError("Invalid label_type")

        out = model(bag, gt_label=video_label)  # returns dict with "loss_total" etc.
        loss = out["loss_total"]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        total_count += 1

        # For accuracy, look at out["video_scores_oamb"] (or whichever aggregator you prefer)
        scores = out["video_scores_oamb"]  # shape (2,)
        probs = F.softmax(scores, dim=0)
        pred = 1 if probs[1] > 0.5 else 0
        if pred == video_label:
            correct += 1

    if scheduler:
        scheduler.step()

    avg_loss = running_loss / total_count if total_count else 0
    accuracy = correct / total_count if total_count else 0
    if verbose:
        print(f"Epoch {epoch+1} => Train Loss: {avg_loss:.4f}, Train Acc: {accuracy:.4f}")
    return avg_loss, accuracy

def test_mil(epoch, model, loader, device, disease='dystonia', verbose=True):
    """
    Run validation/test pass. Return (avg_loss, accuracy, f1, recall, precision, preds, labels).
    """
    model.eval()
    running_loss = 0.0
    total_count = 0
    correct = 0

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for bag, label in loader:
            bag = bag.to(device)
            if disease == 'dystonia_duration':
                video_label = label[:, 3].long().item()
            elif disease == 'dystonia_amplitude':
                video_label = label[:, 4].long().item()
            elif disease == 'choreoathetosis_duration':
                video_label = label[:, 5].long().item()
            elif disease == 'choreoathetosis_amplitude':
                video_label = label[:, 6].long().item()
            elif disease == 'dystonia':
                video_label = label[:, 7].long().item()
            elif disease == 'choreoathetosis':
                video_label = label[:, 8].long().item()
            else:
                raise ValueError("Invalid label_type")

            out = model(bag, gt_label=video_label)
            loss = out["loss_total"]
            running_loss += loss.item()
            total_count += 1

            scores = out["video_scores_oamb"]  # shape (2,)
            probs = F.softmax(scores, dim=0)
            pred = 1 if probs[1] > 0.5 else 0

            all_preds.append(pred)
            all_labels.append(video_label)
            if pred == video_label:
                correct += 1

    avg_loss = running_loss / total_count if total_count else 0
    accuracy = correct / total_count if total_count else 0

    # Additional metrics
    precision = precision_score(all_labels, all_preds, average='macro', zero_division=True)
    recall    = recall_score(all_labels, all_preds, average='macro', zero_division=True)
    f1        = f1_score(all_labels, all_preds, average='macro', zero_division=True)
    if verbose:
        print(f"Epoch {epoch+1} => Val Loss: {avg_loss:.4f}, Val Acc: {accuracy:.4f}, F1: {f1:.4f}")
    return avg_loss, accuracy, f1, recall, precision, all_preds, all_labels

In [None]:
def main_mil(fold, subject_id, train_loader, val_loader, pretrained_model_path, model_type='st_gcn', disease='dystonia', verbose=True):
    """
    Train a WOGMA MIL head on a frozen pretrained backbone for one fold.

    Args:
        fold (int): Fold index (used in the output path and plot titles).
        subject_id (str): Held-out subject id (output path and plot titles).
        train_loader, val_loader: DataLoaders yielding (bag, label) pairs;
            ``train_loader.dataset.labels`` must hold the raw label rows,
            which are used to compute class weights.
        pretrained_model_path (str): Backbone checkpoint passed to
            ``load_pretrained_model``.
        model_type (str): Backbone type passed to ``load_pretrained_model``.
        disease (str): Target name selecting the label column (see
            ``label_columns`` below).
        verbose (bool): Forwarded to ``train_mil`` / ``test_mil``.

    Returns:
        tuple: (best_model_path, weight_vec). best_model_path holds the
        state_dict of the epoch maximising 0.6 * macro-F1 + 0.4 * accuracy on
        ``val_loader``; weight_vec is [weight(class 0), weight(class 1)].

    Raises:
        ValueError: If ``disease`` is not a known target name.
    """
    import copy  # deep-copies the best state_dict snapshot

    label_columns = {
        'dystonia_duration': 3,
        'dystonia_amplitude': 4,
        'choreoathetosis_duration': 5,
        'choreoathetosis_amplitude': 6,
        'dystonia': 7,
        'choreoathetosis': 8,
    }
    # Previously the 'choreoathetosis' branch lacked its `else:`, so it always
    # raised, while unknown names fell through to a NameError on `ld`.
    if disease not in label_columns:
        raise ValueError(f"Invalid disease type {disease}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1) Load the pretrained backbone and freeze it. The keyword is
    #    `checkpoint_path` (`pretrained_model_path=` raised TypeError).
    pretrained = load_pretrained_model(model_type=model_type, checkpoint_path=pretrained_model_path, device=device)
    for param in pretrained.parameters():
        param.requires_grad = False

    # 2) Class weights from the training labels of the selected target.
    train_labels = np.array(train_loader.dataset.labels)
    ld = train_labels[:, label_columns[disease]]
    weights_dict = compute_effective_number_weights(ld)
    # Missing classes default to weight 1.0.
    weight_vec = [weights_dict.get(0, 1.0), weights_dict.get(1, 1.0)]
    print("Class weights:", weight_vec)

    # 3) Build WOGMA.
    # NOTE(review): other cells construct WOGMA with `stgcn_pretrained_model=`;
    # only one of the two keyword names can match WOGMA's signature — confirm.
    model = WOGMA(
        pretrained_model=pretrained,
        feature_dim=32,
        hidden_dim=128,
        num_classes=2,
        top_k_ratio=8,
        theta_class=0.4,
        theta_score=0.3,
        class_weights=weight_vec
    ).to(device)

    optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=1e-4,
        weight_decay=1e-4
    )

    # 4) Train loop
    num_epochs = 100
    best_score = -1
    best_model_weights = None
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-7)

    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    val_f1s = []

    for epoch in range(num_epochs):
        tr_loss, tr_acc = train_mil(
            epoch, model, train_loader, device, optimizer, scheduler, disease=disease, verbose=verbose
        )
        train_losses.append(tr_loss)
        train_accs.append(tr_acc)

        val_loss, val_acc, f1, _recall, _precision, _preds, _labels = test_mil(
            epoch, model, val_loader, device, disease=disease, verbose=verbose
        )
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        val_f1s.append(f1)

        # Model selection metric: 0.6 * macro-F1 + 0.4 * accuracy.
        val_metric = 0.6 * f1 + 0.4 * val_acc
        if val_metric > best_score:
            best_score = val_metric
            # state_dict() holds references to the live parameter tensors and
            # dict.copy() is shallow, so the old snapshot was overwritten by
            # every later optimizer step (the "best" model was really the last
            # epoch). Deep-copy to freeze this epoch's weights.
            best_model_weights = copy.deepcopy(model.state_dict())

    # 5) Save best model
    out_dir = f"./Data/6-leave_one_out/subject{subject_id}/fold_{fold}/{disease}/{model_type}/mil"
    os.makedirs(out_dir, exist_ok=True)
    best_model_path = os.path.join(out_dir, "best_mil_model.pth")

    if best_model_weights is not None:
        torch.save(best_model_weights, best_model_path)
        print(f"[INFO] Best MIL model saved => {best_model_path}")

    # Training curves
    plt.figure()
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.title(f"Subject {subject_id} - Fold {fold} MIL Training Loss")
    plt.legend()
    plt.show()

    plt.figure()
    plt.plot(train_accs, label='Train Acc')
    plt.plot(val_accs, label='Val Acc')
    plt.xlabel('Epoch')
    plt.title(f"Subject {subject_id} - Fold {fold} MIL Accuracy")
    plt.legend()
    plt.show()

    return best_model_path, weight_vec

In [None]:
def inference_mil(best_mil_model_path, stgcn_checkpoint_path, test_loader, device, weight_vec, model_type='st_gcn', disease='dystonia'):
    """
    Rebuild a trained WOGMA model and predict one binary label per test bag.

    Args:
        best_mil_model_path (str): WOGMA state_dict saved by ``main_mil``.
        stgcn_checkpoint_path (str): Backbone checkpoint for
            ``load_pretrained_model`` (used for any ``model_type``, despite
            the name).
        test_loader: Iterable of (bag, label) pairs; label has shape
            (1, num_fields).
        device: torch device.
        weight_vec (list): Class weights passed to WOGMA, e.g. as returned by
            ``main_mil`` for the selected fold.
        model_type (str): Backbone type.
        disease (str): Target name selecting the label column.

    Returns:
        tuple: (all_preds, all_labels) as numpy arrays.

    Raises:
        ValueError: If ``disease`` is unknown.
    """
    label_columns = {
        'dystonia_duration': 3,
        'dystonia_amplitude': 4,
        'choreoathetosis_duration': 5,
        'choreoathetosis_amplitude': 6,
        'dystonia': 7,
        'choreoathetosis': 8,
    }
    if disease not in label_columns:
        raise ValueError(f"Invalid disease {disease!r}; expected one of {sorted(label_columns)}")
    column = label_columns[disease]

    # 1) Load the frozen pretrained backbone.
    stgcn_pretrained = load_pretrained_model(model_type, stgcn_checkpoint_path, device=device)
    for param in stgcn_pretrained.parameters():
        param.requires_grad = False

    # 2) Rebuild WOGMA with the same constructor arguments as in training,
    #    including the caller-supplied class weights.
    model = WOGMA(
        stgcn_pretrained_model=stgcn_pretrained,
        feature_dim=32,
        hidden_dim=128,
        num_classes=2,
        top_k_ratio=8,
        theta_class=0.4,
        theta_score=0.3,
        class_weights=weight_vec
    ).to(device)

    ckpt = torch.load(best_mil_model_path, map_location=device)
    model.load_state_dict(ckpt)
    model.eval()

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for bag, label in test_loader:
            bag = bag.to(device)
            video_label = label[:, column].long().item()

            # gt_label is passed as during training; presumably WOGMA's
            # forward expects it — TODO confirm whether it can be omitted here.
            out = model(bag, gt_label=video_label)
            probs = F.softmax(out["video_scores_oamb"], dim=0)
            pred = 1 if probs[1] > 0.5 else 0

            all_preds.append(pred)
            all_labels.append(video_label)

    return np.array(all_preds), np.array(all_labels)

In [None]:
def _load_loso_split(data_path, label_path):
    """Load one saved split: the .npy data array and its pickled label array."""
    data = np.load(data_path)
    # pickle.load can execute arbitrary code: only load label files produced
    # by this project. `with` closes the handle (open() was leaked before).
    with open(label_path, 'rb') as label_file:
        labels = pkl.load(label_file)
    return data, labels


def _build_side_bags(data, labels, left_code=12, right_code=11):
    """
    Group segments into MIL bags: one bag per (participant id, side).

    ``labels[:, 0]`` is read as the participant id and ``labels[:, 1]`` as the
    side code. For each id in sorted order, the left bag (if any) is emitted
    before the right bag. Each bag's label is the label row of its first
    segment.

    Returns:
        tuple: (bag_list, bag_labels), parallel lists of data arrays and
        label rows.
    """
    left_mask = labels[:, 1] == left_code
    right_mask = labels[:, 1] == right_code
    sides = (
        (data[left_mask], labels[left_mask]),
        (data[right_mask], labels[right_mask]),
    )

    bag_list, bag_labels = [], []
    for pid in np.unique(labels[:, 0]):
        for side_data, side_labels in sides:
            member = side_labels[:, 0] == pid
            if member.any():
                bag_list.append(side_data[member])
                bag_labels.append(side_labels[member][0])
    return bag_list, bag_labels


def run_loso_5_fold_training_mil(disease='dystonia', model_type='st_gcn', verbose=True):
    """
    Leave-one-subject-out (LOSO) MIL training with up to 5 folds per subject.

    For each ./Data/6-leave_one_out/subject* directory:
      1. Train a WOGMA model on every available fold_k (k = 0..4) via main_mil.
      2. Keep the fold whose saved model scores best on its own validation
         split (0.6 * macro-F1 + 0.4 * accuracy).
      3. Save that model as subject*/<disease>/<model_type>/mil_best_model.pth.
      4. Predict the subject's held-out test split.
    Finally, a confusion matrix and classification report over all subjects'
    test predictions are shown.

    Data structure assumption:
      ./Data/6-leave_one_out/subjectX/
          test_data.npy
          test_labels.pkl
          <disease>/<model_type>/best_model.pth   (pretrained backbone)
          fold_0/
             train_data.npy
             train_labels.pkl
             val_data.npy
             val_labels.pkl
             ...

    Args:
        disease (str): Target name (selects the label column downstream).
        model_type (str): Backbone type for load_pretrained_model.
        verbose (bool): Forwarded to main_mil / test_mil.

    Returns:
        None. Results are printed and plotted.
    """
    base_loso_dir = "./Data/6-leave_one_out"
    subject_dirs = sorted(glob.glob(os.path.join(base_loso_dir, "subject*")))
    LEFT, RIGHT = 12, 11  # side codes stored in labels[:, 1]

    all_preds_global = []
    all_labels_global = []

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for subj_path in subject_dirs:
        subject_id = os.path.basename(subj_path).replace("subject", "")
        print(f"\n====== LOSO for subject {subject_id} ======")

        subject_model_dir = f"./Data/6-leave_one_out/subject{subject_id}/{disease}/{model_type}"
        pretrained_checkpoint_path = f"{subject_model_dir}/best_model.pth"

        # 1) Test data must exist before spending time on training.
        test_data_path = os.path.join(subj_path, "test_data.npy")
        test_label_path = os.path.join(subj_path, "test_labels.pkl")
        if not (os.path.exists(test_data_path) and os.path.exists(test_label_path)):
            print(f"[WARN] Missing test data for subject {subject_id}. Skipping.")
            continue

        best_fold = None
        best_fold_score = -999
        best_fold_model_path = None
        best_fold_model = None
        best_fold_class_weights = None

        # 2) Train and score each fold.
        for fold_num in range(5):
            fold_dir = os.path.join(subj_path, f"fold_{fold_num}")
            if not os.path.isdir(fold_dir):
                print(f"[WARN] No fold_{fold_num} directory. Skipping.")
                continue

            train_data_path = os.path.join(fold_dir, "train_data.npy")
            train_label_path = os.path.join(fold_dir, "train_labels.pkl")
            val_data_path   = os.path.join(fold_dir, "val_data.npy")
            val_label_path  = os.path.join(fold_dir, "val_labels.pkl")

            if not all(os.path.exists(p) for p in [train_data_path, train_label_path, val_data_path, val_label_path]):
                print(f"[WARN] Missing train/val for fold_{fold_num}. Skipping.")
                continue

            train_data_raw, train_labels_raw = _load_loso_split(train_data_path, train_label_path)
            val_data_raw, val_labels_raw = _load_loso_split(val_data_path, val_label_path)

            train_bag_list, train_bag_labels = _build_side_bags(train_data_raw, train_labels_raw, LEFT, RIGHT)
            train_loader = DataLoader(SkeletalDataset(train_bag_list, train_bag_labels),
                                      batch_size=1, shuffle=True, collate_fn=collate_fn)

            val_bag_list, val_bag_labels = _build_side_bags(val_data_raw, val_labels_raw, LEFT, RIGHT)
            val_loader = DataLoader(SkeletalDataset(val_bag_list, val_bag_labels),
                                    batch_size=1, shuffle=False, collate_fn=collate_fn)

            # main_mil's parameter is `pretrained_model_path`
            # (`stgcn_checkpoint_path=` raised TypeError).
            best_mil_model_path, class_weights = main_mil(
                fold=fold_num,
                subject_id=subject_id,
                train_loader=train_loader,
                val_loader=val_loader,
                pretrained_model_path=pretrained_checkpoint_path,
                model_type=model_type,
                disease=disease,
                verbose=verbose
            )

            # Reload the saved best epoch and score it on validation to compare folds.
            # NOTE(review): main_mil constructs WOGMA with `pretrained_model=`;
            # only one of the two keyword names matches WOGMA — confirm.
            model_temp = WOGMA(
                stgcn_pretrained_model=load_pretrained_model(
                    model_type=model_type, checkpoint_path=pretrained_checkpoint_path, device=device
                ),
                feature_dim=32,
                hidden_dim=128,
                num_classes=2,
                top_k_ratio=8,
                theta_class=0.4,
                theta_score=0.3,
                class_weights=class_weights
            ).to(device)

            model_temp.load_state_dict(torch.load(best_mil_model_path, map_location=device))
            model_temp.eval()

            # test_mil's keyword is `disease` (`label_type=` raised TypeError).
            _, val_acc_temp, f1_temp, _, _, _, _ = test_mil(
                epoch=999,
                model=model_temp,
                loader=val_loader,
                device=device,
                disease=disease,
                verbose=verbose
            )
            fold_metric = 0.6 * f1_temp + 0.4 * val_acc_temp
            if fold_metric > best_fold_score:
                best_fold_score = fold_metric
                best_fold = fold_num
                best_fold_model_path = best_mil_model_path
                best_fold_model = model_temp
                best_fold_class_weights = class_weights

        if best_fold_model_path is None:
            print(f"[WARN] No valid folds for subject {subject_id}. Skipping test inference.")
            continue

        # `troch.save` typo fixed; make sure the output directory exists.
        os.makedirs(subject_model_dir, exist_ok=True)
        torch.save(best_fold_model.state_dict(), f"{subject_model_dir}/mil_best_model.pth")
        print(f"[Subject {subject_id}] Best fold {best_fold} => metric = {best_fold_score:.4f}")

        # 3) Test inference with bags built exactly like train/val.
        test_data, test_labels = _load_loso_split(test_data_path, test_label_path)
        test_bag_list, test_bag_labels = _build_side_bags(test_data, test_labels, LEFT, RIGHT)
        test_loader = DataLoader(SkeletalDataset(test_bag_list, test_bag_labels),
                                 batch_size=1, shuffle=False, collate_fn=collate_fn)

        preds_subj, labels_subj = inference_mil(
            best_mil_model_path=best_fold_model_path,
            stgcn_checkpoint_path=pretrained_checkpoint_path,
            test_loader=test_loader,
            device=device,
            model_type=model_type,
            disease=disease,
            weight_vec=best_fold_class_weights
        )

        all_preds_global.append(preds_subj)
        all_labels_global.append(labels_subj)

    # Global results
    if len(all_preds_global) == 0:
        print("\nNo predictions collected. Check your data/folder structure.")
        return

    all_preds_global = np.concatenate(all_preds_global, axis=0)
    all_labels_global = np.concatenate(all_labels_global, axis=0)

    print("\n=========== Global MIL LOSO Test Results ===========")
    if disease == 'dystonia' or disease == 'choreoathetosis':
        classes = [0, 1]
    else:
        classes = [0, 1, 2, 3, 4]

    plot_confuse_matrix(all_labels_global, all_preds_global, classes=classes)
    print("\nClassification Report:\n", classification_report(all_labels_global, all_preds_global))
    print("Done!")

In [None]:
# Plot the overall confusion matrix stored in `over_all_cm`.
# NOTE(review): `over_all_cm` is not defined in any cell shown here, so on a
# fresh kernel this cell raised NameError. It is guarded until the cell that
# builds it exists.
import seaborn as sns

overall_cm = globals().get('over_all_cm')
if overall_cm is None:
    print("[WARN] `over_all_cm` is not defined; run the cell that builds it before plotting.")
else:
    fig, ax = plt.subplots(figsize=(10, 7))
    sns.heatmap(overall_cm, annot=True, fmt='d', ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Truth')
    ax.set_title('Overall confusion matrix')
    plt.show()

# Choreoathetosis

In [None]:
def train_one_epoch(model, train_loader, device, optimizer, scheduler=None):
    """
    Train a WO-GMA model for one epoch, one bag (video) per optimizer step.

    The target is the last field of each label row (``label[:, -1]``), an
    integer 0 or 1.

    Args:
        model: Module where ``model(bag, gt_label=int)`` returns a dict with
            "loss_total" and "video_scores_oamb" (scores for 2 classes).
        train_loader: Iterable of (bag, label) pairs; label has shape
            (1, num_fields).
        device: Device each bag is moved to.
        optimizer: Optimizer stepped after every bag.
        scheduler: Accepted for interface compatibility but NOT stepped inside
            this function. Defaults to None.

    Returns:
        tuple: (avg_loss, accuracy) over the epoch; (0, 0) for an empty loader.
    """
    model.train()
    running_loss = 0.0
    total_count = 0  # number of bags (videos) seen
    correct = 0      # number of correctly classified bags

    for bag, label in train_loader:
        # Bag layout presumably (L, C, T, V, M) with L = number of clips — TODO confirm.
        bag = bag.to(device)
        video_label = label[:, -1].long().item()

        out = model(bag, gt_label=video_label)
        loss = out["loss_total"]  # combined MIL loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        total_count += 1

        # Training accuracy from the OAMB branch, using this forward pass's
        # scores: positive iff softmax P(class 1) > 0.5.
        with torch.no_grad():
            probs = F.softmax(out["video_scores_oamb"], dim=0)
            pred = 1 if probs[1] > 0.5 else 0
            if pred == video_label:
                correct += 1

    avg_loss = running_loss / total_count if total_count else 0
    accuracy = correct / total_count if total_count else 0

    return avg_loss, accuracy


def validate_one_epoch(model, val_loader, device):
    """
    Evaluate a WO-GMA model over ``val_loader`` with gradients disabled.

    The target is the last field of each label row (``label[:, -1]``). The
    ground-truth label is passed to the model so the reported loss is
    computed the same way as during training.

    Returns:
        tuple: (avg_loss, accuracy, all_preds, all_labels). The two lists hold
        one int per bag; avg_loss and accuracy are 0 for an empty loader.
    """
    model.eval()
    batch_losses = []
    all_preds, all_labels = [], []

    with torch.no_grad():
        for bag, label in val_loader:
            bag_on_device = bag.to(device)
            target = label[:, -1].long().item()

            out = model(bag_on_device, gt_label=target)
            batch_losses.append(out["loss_total"].item())

            # Positive iff the OAMB branch puts more than half the mass on class 1.
            pos_prob = F.softmax(out["video_scores_oamb"], dim=0)[1]
            all_preds.append(1 if pos_prob > 0.5 else 0)
            all_labels.append(target)

    n_bags = len(batch_losses)
    if not n_bags:
        return 0, 0, all_preds, all_labels

    n_correct = sum(int(p == t) for p, t in zip(all_preds, all_labels))
    return sum(batch_losses) / n_bags, n_correct / n_bags, all_preds, all_labels


In [None]:
def main(fold, mil_train_loader, mil_val_loader, disease='dystonia', model_type='st_gcn', checkpoint_path = './Data/9-kfold/3/best_model.pth'):
    """
    Train a WO-GMA multiple-instance model for one cross-validation fold.

    A pretrained backbone is loaded from ``checkpoint_path``, frozen, wrapped
    in ``WOGMA`` and trained for a fixed number of epochs. The model is
    validated after every epoch. The weights from the epoch with the highest
    macro-F1 (later epochs win ties) are restored into the returned model and
    saved to ``./Data/9-kfold/{fold}/mil_best_model.pth``.

    Reads the notebook globals ``NUM_CLASSES``, ``LEARNING_RATE`` and
    ``WEIGHT_DECAY``; they must be defined before calling.

    Args:
        fold: fold index, used only in the output checkpoint path.
        mil_train_loader: DataLoader of training bags. Its ``dataset.labels``
            rows provide the targets used to compute class weights.
        mil_val_loader: DataLoader of validation bags.
        disease: selects the label column used for the class weights. One of
            'dystonia_duration', 'dystonia_amplitude',
            'choreoathetosis_duration', 'choreoathetosis_amplitude',
            'dystonia', 'choreoathetosis'.
        model_type: backbone type forwarded to ``load_pretrained_model``.
        checkpoint_path: path of the pretrained backbone checkpoint.

    Returns:
        tuple: ``(best_mil_model, train_losses, train_accs, val_metrics,
        best_preds, best_labels)``. ``val_metrics`` maps 'val_loss',
        'val_acc', 'val_f1', 'val_recall' and 'val_precision' to per-epoch
        lists.

    Raises:
        ValueError: if ``disease`` is not one of the supported names.
    """
    import copy  # needed to snapshot the best epoch's weights

    # Label-row column holding the target of each task.
    label_columns = {
        'dystonia_duration': 3,
        'dystonia_amplitude': 4,
        'choreoathetosis_duration': 5,
        'choreoathetosis_amplitude': 6,
        'dystonia': 7,
        'choreoathetosis': 8,
    }
    if disease not in label_columns:
        raise ValueError(f"Invalid disease type: {disease}")

    # ----------------------------------------------------
    # 1) Device and frozen pretrained backbone
    # ----------------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The backbone type is the `model_type` argument (`model` is the WOGMA
    # instance created further below and is not yet bound here).
    stgcn_pretrained = load_pretrained_model(model_type=model_type, checkpoint_path=checkpoint_path, device=device)

    # Freeze the backbone so only the WO-GMA head is optimised.
    # NOTE(review): if train_one_epoch calls model.train() on the WOGMA wrapper,
    # this eval() is undone for the wrapped backbone — confirm.
    stgcn_pretrained.eval()
    for param in stgcn_pretrained.parameters():
        param.requires_grad = False

    # ----------------------------------------------------
    # 2) Class weights from the training labels
    # ----------------------------------------------------
    # NOTE(review): validate_one_epoch reads the target from the LAST label
    # column, while the weights use label_columns[disease]. Confirm they agree.
    train_targets = np.array(mil_train_loader.dataset.labels)[:, label_columns[disease]]
    weights = compute_effective_number_weights(train_targets)
    print(weights)
    # NOTE(review): `weights` is ordered by ascending class label and already
    # up-weights rare classes; reversing it swaps that ordering. Classes missing
    # from this fold get no entry. Confirm both match WOGMA's class_weights contract.
    weight_vec = list(weights.values())[::-1]

    model = WOGMA(
        stgcn_pretrained_model=stgcn_pretrained,
        feature_dim=32,     # must match ST-GCN extract_feature dimension
        hidden_dim=128,
        num_classes=NUM_CLASSES,
        top_k_ratio=8,
        theta_class=0.4,
        theta_score=0.3,
        class_weights=weight_vec
    ).to(device)

    # ----------------------------------------------------
    # 3) Optimiser and LR schedule (trainable params only)
    # ----------------------------------------------------
    optimizer = optim.Adam(
        [p for p in model.parameters() if p.requires_grad],
        lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

    # ----------------------------------------------------
    # 4) Training loop with best-F1 checkpointing
    # ----------------------------------------------------
    num_epochs = 100
    best_val_f1 = 0.0
    best_state = None

    train_losses, train_accs = [], []
    val_losses, val_accs = [], []
    val_f1, val_recall, val_precision = [], [], []
    best_preds, best_labels = [], []

    for epoch in range(num_epochs):
        train_loss, train_acc = train_one_epoch(model, mil_train_loader, device, optimizer, scheduler)
        val_loss, val_acc, preds, labels = validate_one_epoch(model, mil_val_loader, device)

        train_losses.append(train_loss)
        train_accs.append(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        # Macro averages; zero_division=1 scores an undefined ratio as 1.
        precision = precision_score(labels, preds, average='macro', zero_division=1)
        recall = recall_score(labels, preds, average='macro', zero_division=1)
        f1 = f1_score(labels, preds, average='macro', zero_division=1)

        val_f1.append(f1)
        val_recall.append(recall)
        val_precision.append(precision)

        if f1 >= best_val_f1:
            best_val_f1 = f1
            best_preds = preds
            best_labels = labels
            # Copy the weights: keeping a reference to `model` would silently
            # follow every later (possibly worse) epoch.
            best_state = copy.deepcopy(model.state_dict())

        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {f1:.4f}")

    print("Done training!")

    # Restore the best epoch's weights before saving/returning.
    if best_state is not None:
        model.load_state_dict(best_state)
    best_mil_model = model

    torch.save(best_mil_model.state_dict(), f'./Data/9-kfold/{fold}/mil_best_model.pth')

    val_metrics = {
        'val_loss': val_losses,
        'val_acc': val_accs,
        'val_f1': val_f1,
        'val_recall': val_recall,
        'val_precision': val_precision}

    return best_mil_model, train_losses, train_accs, val_metrics, best_preds, best_labels

In [None]:
# Values in label column 1 that select the LEFT and RIGHT bags of each id.
LEFT = 12
RIGHT = 11


def _load_split(data_path, label_path):
    """Load one split: the ``.npy`` data array and its pickled label array.

    The label file is closed after reading. NOTE: ``pickle`` can execute
    arbitrary code; only load trusted files.
    """
    data = np.load(data_path)
    with open(label_path, 'rb') as f:
        label_rows = pkl.load(f)
    return data, label_rows


def _build_side_bags(data, label_rows, sides=(LEFT, RIGHT)):
    """Group per-sample arrays into MIL bags, one bag per (id, side) pair.

    Ids are taken from column 0 of ``label_rows`` in ascending ``np.unique``
    order. Within an id, bags follow the order of ``sides``. Pairs with no
    samples are skipped, and rows whose column-1 value is not in ``sides``
    are ignored.

    Returns:
        (bags, bag_labels): the data rows of each bag (original row order
        preserved), and for each bag the label row of its first sample.
    """
    bags, bag_labels = [], []
    for subject_id in np.unique(label_rows[:, 0]):
        for side in sides:
            mask = (label_rows[:, 0] == subject_id) & (label_rows[:, 1] == side)
            if mask.any():
                bags.append(data[mask])
                bag_labels.append(label_rows[mask][0])
    return bags, bag_labels


predictions, ground_true = [], []

for i in range(0, 6):
    checkpoint_path = f'./Data/9-kfold/{i}/choreoathetosis_best_model.pth'
    TRAIN_PARH = f'Data/9-kfold/{i}/train_data.npy'
    VAL_PARH = f'Data/9-kfold/{i}/validate_data.npy'
    TRAIN_LABEL_PATH = f'Data/9-kfold/{i}/train_binary_labels.pkl'
    VAL_LABEL_PATH = f'Data/9-kfold/{i}/validate_binary_labels.pkl'
    TEST_PARH = f'Data/9-kfold/{i}/test_data.npy'
    TEST_LABEL_PATH = f'Data/9-kfold/{i}/test_binary_labels.pkl'

    val_bags, val_bag_labels = _build_side_bags(*_load_split(VAL_PARH, VAL_LABEL_PATH))
    val_dataset = SkeletalDataset(val_bags, val_bag_labels)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

    train_bags, train_bag_labels = _build_side_bags(*_load_split(TRAIN_PARH, TRAIN_LABEL_PATH))
    train_dataset = SkeletalDataset(train_bags, train_bag_labels)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

    print(f"Fold {i}")

    # NOTE(review): `disease` is left at main()'s default ('dystonia') while the
    # checkpoint loaded is the choreoathetosis one — confirm which task is intended.
    model, train_losses, train_accs, val_metrics, best_predictions, val_labels = main(
        fold=i, mil_train_loader=train_loader, mil_val_loader=val_loader, checkpoint_path=checkpoint_path)

    print(len(best_predictions))

    # Loss curves
    fig, ax = plt.subplots()
    ax.plot(train_losses, label='Train Loss')
    ax.plot(val_metrics['val_loss'], label='Validation Loss')
    ax.set(title=f'Fold {i}: loss', xlabel='Epoch', ylabel='Loss')
    ax.legend()
    plt.show()

    # Validation metrics per epoch
    fig, axs = plt.subplots(2, 2)
    axs[0, 0].plot(val_metrics['val_precision'], label='Precision')
    axs[0, 0].legend()
    axs[0, 1].plot(val_metrics['val_recall'], label='Recall')
    axs[0, 1].legend()
    axs[1, 0].plot(val_metrics['val_f1'], label='F1 Score')
    axs[1, 0].legend()
    axs[1, 1].plot(val_metrics['val_acc'], label='Val Accuracy')
    axs[1, 1].plot(train_accs, label='Train Accuracy')
    axs[1, 1].legend()
    fig.suptitle(f'Fold {i}: metrics')
    plt.show()

    # Confusion matrix for the best epoch's predictions
    cm = confusion_matrix(val_labels, best_predictions)
    print(cm)
    print(classification_report(val_labels, best_predictions))

    predictions.extend(best_predictions)
    ground_true.extend(val_labels)

over_all_cm = confusion_matrix(ground_true, predictions)
print(over_all_cm)
# plot_confuse_matrix requires tick labels. Use the sorted union of observed
# classes, which is the same order confusion_matrix uses.
class_ids = sorted(set(ground_true) | set(predictions))
plot_confuse_matrix(ground_true, predictions, class_ids)
print(classification_report(ground_true, predictions))

# Dystonia Duration

## ST-GCN

In [None]:
if __name__ == "__main__":
    # NUM_CLASSES / LEARNING_RATE / WEIGHT_DECAY are module-level globals read by the training code.
    NUM_CLASSES, LEARNING_RATE, WEIGHT_DECAY = 5, 1e-4, 1e-4
    run_loso_5_fold_training_mil(
        disease='dystonia_duration',
        model_type='st_gcn',
        verbose=False,
    )

## CTR-GCN 

In [None]:
if __name__ == "__main__":
    # NUM_CLASSES / LEARNING_RATE / WEIGHT_DECAY are module-level globals read by the training code.
    NUM_CLASSES, LEARNING_RATE, WEIGHT_DECAY = 5, 1e-4, 1e-4
    run_loso_5_fold_training_mil(
        disease='dystonia_duration',
        model_type='ctr_gcn',
        verbose=False,
    )

## SkateFormer

In [None]:
if __name__ == "__main__":
    # NUM_CLASSES / LEARNING_RATE / WEIGHT_DECAY are module-level globals read by the training code.
    NUM_CLASSES, LEARNING_RATE, WEIGHT_DECAY = 5, 1e-4, 1e-4
    run_loso_5_fold_training_mil(
        disease='dystonia_duration',
        model_type='skateformer',
        verbose=False,
    )

# Dystonia Amplitude

## ST-GCN

In [None]:
if __name__ == "__main__":
    # NUM_CLASSES / LEARNING_RATE / WEIGHT_DECAY are module-level globals read by the training code.
    NUM_CLASSES, LEARNING_RATE, WEIGHT_DECAY = 5, 1e-4, 1e-4
    run_loso_5_fold_training_mil(
        disease='dystonia_amplitude',
        model_type='st_gcn',
        verbose=False,
    )

## CTR-GCN

In [None]:
if __name__ == "__main__":
    # NUM_CLASSES / LEARNING_RATE / WEIGHT_DECAY are module-level globals read by the training code.
    NUM_CLASSES, LEARNING_RATE, WEIGHT_DECAY = 5, 1e-4, 1e-4
    run_loso_5_fold_training_mil(
        disease='dystonia_amplitude',
        model_type='ctr_gcn',
        verbose=False,
    )

## SkateFormer

In [None]:
if __name__ == "__main__":
    # NUM_CLASSES / LEARNING_RATE / WEIGHT_DECAY are module-level globals read by the training code.
    NUM_CLASSES, LEARNING_RATE, WEIGHT_DECAY = 5, 1e-4, 1e-4
    run_loso_5_fold_training_mil(
        disease='dystonia_amplitude',
        model_type='skateformer',
        verbose=False,
    )

# Choreoathetosis Duration

## ST-GCN

In [None]:
if __name__ == "__main__":
    # NUM_CLASSES / LEARNING_RATE / WEIGHT_DECAY are module-level globals read by the training code.
    NUM_CLASSES, LEARNING_RATE, WEIGHT_DECAY = 5, 1e-4, 1e-4
    run_loso_5_fold_training_mil(
        disease='choreoathetosis_duration',
        model_type='st_gcn',
        verbose=False,
    )

## CTR-GCN

In [None]:
if __name__ == "__main__":
    # NUM_CLASSES / LEARNING_RATE / WEIGHT_DECAY are module-level globals read by the training code.
    NUM_CLASSES, LEARNING_RATE, WEIGHT_DECAY = 5, 1e-4, 1e-4
    run_loso_5_fold_training_mil(
        disease='choreoathetosis_duration',
        model_type='ctr_gcn',
        verbose=False,
    )

## SkateFormer

In [None]:
if __name__ == "__main__":
    # NUM_CLASSES / LEARNING_RATE / WEIGHT_DECAY are module-level globals read by the training code.
    NUM_CLASSES, LEARNING_RATE, WEIGHT_DECAY = 5, 1e-4, 1e-4
    run_loso_5_fold_training_mil(
        disease='choreoathetosis_duration',
        model_type='skateformer',
        verbose=False,
    )

# Choreoathetosis Amplitude

## ST-GCN

In [None]:
if __name__ == "__main__":
    # NUM_CLASSES / LEARNING_RATE / WEIGHT_DECAY are module-level globals read by the training code.
    NUM_CLASSES, LEARNING_RATE, WEIGHT_DECAY = 5, 1e-4, 1e-4
    run_loso_5_fold_training_mil(
        disease='choreoathetosis_amplitude',
        model_type='st_gcn',
        verbose=False,
    )

## CTR-GCN

In [None]:
if __name__ == "__main__":
    # NUM_CLASSES / LEARNING_RATE / WEIGHT_DECAY are module-level globals read by the training code.
    NUM_CLASSES, LEARNING_RATE, WEIGHT_DECAY = 5, 1e-4, 1e-4
    run_loso_5_fold_training_mil(
        disease='choreoathetosis_amplitude',
        model_type='ctr_gcn',
        verbose=False,
    )

## SkateFormer

In [None]:
if __name__ == "__main__":
    # NUM_CLASSES / LEARNING_RATE / WEIGHT_DECAY are module-level globals read by the training code.
    NUM_CLASSES, LEARNING_RATE, WEIGHT_DECAY = 5, 1e-4, 1e-4
    run_loso_5_fold_training_mil(
        disease='choreoathetosis_amplitude',
        model_type='skateformer',
        verbose=False,
    )

# Dystonia

## ST-GCN

In [None]:
if __name__ == "__main__":
    # NUM_CLASSES / LEARNING_RATE / WEIGHT_DECAY are module-level globals read by the training code.
    NUM_CLASSES, LEARNING_RATE, WEIGHT_DECAY = 2, 1e-4, 1e-4
    run_loso_5_fold_training_mil(
        disease='dystonia',
        model_type='st_gcn',
        verbose=False,
    )

## CTR-GCN

In [None]:
if __name__ == "__main__":
    # NUM_CLASSES / LEARNING_RATE / WEIGHT_DECAY are module-level globals read by the training code.
    NUM_CLASSES, LEARNING_RATE, WEIGHT_DECAY = 2, 1e-4, 1e-4
    run_loso_5_fold_training_mil(
        disease='dystonia',
        model_type='ctr_gcn',
        verbose=False,
    )

## SkateFormer

In [None]:
if __name__ == "__main__":
    # NUM_CLASSES / LEARNING_RATE / WEIGHT_DECAY are module-level globals read by the training code.
    NUM_CLASSES, LEARNING_RATE, WEIGHT_DECAY = 2, 1e-4, 1e-4
    run_loso_5_fold_training_mil(
        disease='dystonia',
        model_type='skateformer',
        verbose=False,
    )

# Choreoathetosis

## ST-GCN

In [None]:
if __name__ == "__main__":
    # NUM_CLASSES / LEARNING_RATE / WEIGHT_DECAY are module-level globals read by the training code.
    NUM_CLASSES, LEARNING_RATE, WEIGHT_DECAY = 2, 1e-4, 1e-4
    run_loso_5_fold_training_mil(
        disease='choreoathetosis',
        model_type='st_gcn',
        verbose=False,
    )

## CTR-GCN

In [None]:
if __name__ == "__main__":
    # NUM_CLASSES / LEARNING_RATE / WEIGHT_DECAY are module-level globals read by the training code.
    NUM_CLASSES, LEARNING_RATE, WEIGHT_DECAY = 2, 1e-4, 1e-4
    run_loso_5_fold_training_mil(
        disease='choreoathetosis',
        model_type='ctr_gcn',
        verbose=False,
    )

## SkateFormer

In [None]:
if __name__ == "__main__":
    # NUM_CLASSES / LEARNING_RATE / WEIGHT_DECAY are module-level globals read by the training code.
    NUM_CLASSES, LEARNING_RATE, WEIGHT_DECAY = 2, 1e-4, 1e-4
    run_loso_5_fold_training_mil(
        disease='choreoathetosis',
        model_type='skateformer',
        verbose=False,
    )