In [4]:
import torch
from typing import List, Tuple
from torch.utils.data import DataLoader, TensorDataset


def get_dataloaders(test_set, batch_size=32, data_type='imu'):
    """Wrap one dataset split in a sequential (non-shuffled) ``DataLoader``.

    Args:
        test_set: Three-item sequence ``(X_imu, X_kp, y)``; the items are
            handed to ``TensorDataset`` unchanged.
        batch_size: Batch size forwarded to ``DataLoader``.
        data_type: Which inputs each batch carries:
            ``'imu'`` -> ``(X_imu, y)``,
            ``'kp'``  -> ``(X_kp, y)``,
            ``'all'`` -> ``(X_imu, X_kp, y)``.

    Returns:
        A ``DataLoader`` over the selected tensors with ``shuffle=False``.

    Raises:
        ValueError: If ``data_type`` is not one of ``'imu'``, ``'kp'``, ``'all'``.
    """
    X_imu_test, X_kp_test, y_test = test_set

    tensors_by_type = {
        'imu': (X_imu_test, y_test),
        'kp': (X_kp_test, y_test),
        'all': (X_imu_test, X_kp_test, y_test),
    }
    # Fail loudly on an unknown selector instead of hitting an unbound
    # `test_loader` further down.
    if data_type not in tensors_by_type:
        raise ValueError(
            f"Unsupported data_type: {data_type!r} (expected one of {sorted(tensors_by_type)})"
        )

    test_loader = DataLoader(
        TensorDataset(*tensors_by_type[data_type]),
        batch_size=batch_size,
        shuffle=False,
    )

    # NOTE: len() of a DataLoader is the number of batches, not of samples.
    print(f"Test set size: {len(test_loader)}")

    return test_loader

# Load the serialized train/test splits and wrap each one in a sequential
# loader that yields (imu, kp, label) batches of 128 samples.
train_dataset = torch.load('./data/train_dataset.pt')
train_dataloader = get_dataloaders(train_dataset, batch_size=128, data_type="all")
test_dataset = torch.load('./data/test_dataset.pt')
test_dataloader = get_dataloaders(test_dataset, batch_size=128, data_type="all")

In [8]:
import numpy as np
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, average_precision_score, matthews_corrcoef
)


def _classification_metrics(labels, preds, probs):
    """Compute the summary classification metrics for one model.

    Args:
        labels: 1-D integer array of ground-truth class ids, shape ``[N]``.
        preds: 1-D integer array of predicted class ids, shape ``[N]``.
        probs: 2-D array of class probabilities, shape ``[N, num_classes]``.

    Returns:
        Dict with keys ``acc``, ``precision``, ``recall``, ``f1``, ``auc``,
        ``auprc`` and ``mcc``.
    """
    num_classes = probs.shape[1]
    # Two-class problems are scored on the positive class only; anything
    # larger is macro-averaged over classes.
    average = "macro" if num_classes > 2 else "binary"

    if num_classes == 2:
        auc = roc_auc_score(labels, probs[:, 1])
        auprc = average_precision_score(labels, probs[:, 1])
    else:
        # One-vs-rest, macro-averaged.
        auc = roc_auc_score(labels, probs, multi_class="ovr", average="macro")
        auprc = average_precision_score(labels, probs, average="macro")

    return {
        "acc": accuracy_score(labels, preds),
        "precision": precision_score(labels, preds, average=average, zero_division=0),
        "recall": recall_score(labels, preds, average=average, zero_division=0),
        "f1": f1_score(labels, preds, average=average, zero_division=0),
        "auc": auc,
        "auprc": auprc,
        "mcc": matthews_corrcoef(labels, preds),
    }


def model_test(dl_test, imu_model, kp_model, mode, device):
    """Evaluate the IMU and key-point models on the same loader.

    Args:
        dl_test: Iterable yielding ``(imu, kp, labels)`` tensor batches.
        imu_model: Model called as ``imu_model(imu, stage)``; returns logits.
        kp_model: Model called as ``kp_model(kp, stage)``; returns logits.
        mode: ``'stg1'`` or ``'stg3'``. ``'stg4'`` is still accepted as a
            legacy alias and, as before, runs both models at stage ``'stg3'``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(imu_result, kp_result)``: two dicts with keys ``acc``,
        ``precision``, ``recall``, ``f1``, ``auc``, ``auprc``, ``mcc``.

    Raises:
        ValueError: If ``mode`` is not supported, or ``dl_test`` yields no
            batches.
    """
    # Public mode -> stage name forwarded to both models.
    stage_by_mode = {'stg1': 'stg1', 'stg3': 'stg3', 'stg4': 'stg3'}
    # Validate up front: an unknown mode previously left the prediction
    # variables unbound inside the loop.
    if mode not in stage_by_mode:
        raise ValueError(f"Unsupported mode: {mode}")
    stage = stage_by_mode[mode]

    imu_model.eval()
    kp_model.eval()

    # Per-batch accumulators.
    all_labels = []
    imu_all_preds = []
    imu_all_probs = []
    kp_all_preds = []
    kp_all_probs = []

    with torch.no_grad():
        for imu, kp, labels in dl_test:
            imu = imu.to(device).float()
            kp = kp.to(device).float()
            labels = labels.to(device).long()

            imu_pred = imu_model(imu, stage)
            kp_pred = kp_model(kp, stage)

            # Move to CPU / NumPy for scikit-learn.
            imu_probs_np = torch.softmax(imu_pred, dim=-1).cpu().numpy()
            kp_probs_np = torch.softmax(kp_pred, dim=-1).cpu().numpy()

            all_labels.append(labels.cpu().numpy())
            imu_all_probs.append(imu_probs_np)
            imu_all_preds.append(np.argmax(imu_probs_np, axis=-1))
            kp_all_probs.append(kp_probs_np)
            kp_all_preds.append(np.argmax(kp_probs_np, axis=-1))

    if not all_labels:
        raise ValueError("dl_test yielded no batches; nothing to evaluate")

    # Merge all batches.
    all_labels = np.concatenate(all_labels)        # shape [N]
    imu_all_preds = np.concatenate(imu_all_preds)  # shape [N]
    imu_all_probs = np.concatenate(imu_all_probs)  # shape [N, num_classes]
    kp_all_preds = np.concatenate(kp_all_preds)    # shape [N]
    kp_all_probs = np.concatenate(kp_all_probs)    # shape [N, num_classes]

    # Each model's class count is read from its own probability matrix.
    imu_result = _classification_metrics(all_labels, imu_all_preds, imu_all_probs)
    kp_result = _classification_metrics(all_labels, kp_all_preds, kp_all_probs)

    return imu_result, kp_result

In [9]:
from imunet.stft import TFusion
from kpnet.mstgcn import MSTGCN

# Evaluation device: GPU index 1 when CUDA is available, otherwise CPU.
# NOTE(review): "cuda:1" hard-codes the second GPU — presumably this fails on a
# single-GPU host even though is_available() is True; confirm.
device = torch.device("cuda:1" if torch.cuda.is_available() else 'cpu')

# Hyper-parameters passed positionally to the model constructors below.
# NOTE(review): presumably these must match the values the checkpoints were
# trained with — confirm.
seq_len = 200
n_fft = 64
hop_length = 8
in_channels = 9
patch_size = 16
stride = 16
depth = 12
num_classes = 10
out_channels = 32
graph_args = {"layout": "openpose", "strategy": "spatial"}
edge_importance_weighting = True

# IMU branch: build the model, restore its checkpoint, switch to eval mode.
imu_model = TFusion(seq_len, n_fft, hop_length, device, in_channels, patch_size, stride, depth, num_classes).to(device)
state_dict = torch.load("./checkpoint/imu_best_test_clr.pt", map_location=device)
imu_model.load_state_dict(state_dict)
imu_model.eval()
# Key-point branch: same sequence; `state_dict` is reused for the second checkpoint.
kp_model = MSTGCN(2, out_channels, num_classes, graph_args, edge_importance_weighting).to(device)
state_dict = torch.load("./checkpoint/kp_best_test_clr.pt", map_location=device)
kp_model.load_state_dict(state_dict)
kp_model.eval()

# Initialise the IMU model's STFT embedder from the training loader; this needs
# `train_dataloader` from the data-loading cell to exist already.
imu_model.stft_encoder.init_stft_embedder(train_dataloader)



Teset set length: 14190
Test set size: 111


In [7]:
# Score both models on the test loader at stage 'stg3' and show the metric dicts.
imu_result_tmp, kp_result_tmp = model_test(
    dl_test=test_dataloader,
    imu_model=imu_model,
    kp_model=kp_model,
    mode='stg3',
    device=device,
)
print(imu_result_tmp)
print(kp_result_tmp)

Test set size: 26
{'acc': 0.7305295950155763, 'precision': 0.7292001478095932, 'recall': 0.7305295950155763, 'f1': 0.7262006340674086, 'auc': 0.9616621862494865, 'auprc': 0.7993890599509568, 'mcc': 0.701388401748504}
{'acc': 0.8358255451713396, 'precision': 0.8378438911159444, 'recall': 0.8358255451713396, 'f1': 0.8345560688317605, 'auc': 0.9801992518619882, 'auprc': 0.8958213560645222, 'mcc': 0.818155914486705}


In [None]:
# Coarse groups over the original class ids 0-9: group name -> member class ids.
GROUP_DEF = dict(
    env=[0, 1, 2, 3],  # Environmental factors
    viol=[4, 5],       # Worker-violation factors
    ergo=[6, 7],       # Ergonomic-excess
    emer=[8],          # Immediate emergency response
    base=[9],          # Baseline activity
)

def eval_by_groups_from_labels(y_true, y_pred, group_def):
    """Score predictions at group level, treating each group as one-vs-rest.

    For every group the fine-grained labels are collapsed to a binary
    "belongs to this group" indicator, and precision / recall / F1 are
    computed on those indicators.

    Args:
        y_true: 1-D numpy array of ground-truth labels (original 10-class ids).
        y_pred: 1-D numpy array of predicted labels (original 10-class ids).
        group_def: dict ``{group_name: [class_ids]}``.

    Returns:
        group_metrics: ``{group_name: {"precision_g", "recall_g", "f1_g"}}``.
        macro_metrics: unweighted means over groups, keyed
            ``"precision_g_macro"``, ``"recall_g_macro"``, ``"f1_g_macro"``.
    """
    group_metrics = {}

    for name, members in group_def.items():
        # 1 where the label falls inside this group, 0 elsewhere.
        true_in_group = np.isin(y_true, members).astype(int)
        pred_in_group = np.isin(y_pred, members).astype(int)
        indicator_pair = (true_in_group, pred_in_group)

        group_metrics[name] = {
            "precision_g": float(precision_score(*indicator_pair, zero_division=0)),
            "recall_g": float(recall_score(*indicator_pair, zero_division=0)),
            "f1_g": float(f1_score(*indicator_pair, zero_division=0)),
        }

    def _mean_over_groups(key):
        # Unweighted average of one metric across all groups.
        return float(np.mean([scores[key] for scores in group_metrics.values()]))

    macro_metrics = {
        "precision_g_macro": _mean_over_groups("precision_g"),
        "recall_g_macro": _mean_over_groups("recall_g"),
        "f1_g_macro": _mean_over_groups("f1_g"),
    }

    return group_metrics, macro_metrics

def group_eval_test(dl_test, imu_model, kp_model, mode, device, group_def=None):
    """Run both models over a loader and report group-level metrics.

    Args:
        dl_test: Iterable yielding ``(imu, kp, labels)`` tensor batches.
        imu_model: Model called as ``imu_model(imu, stage)``; returns logits.
        kp_model: Model called as ``kp_model(kp, stage)``; returns logits.
        mode: ``'stg1'`` or ``'stg3'``; forwarded to both models as the stage.
        device: Device the batches are moved to before the forward pass.
        group_def: ``{group_name: [class_ids]}``; ``GROUP_DEF`` when ``None``.

    Returns:
        ``(imu_result, kp_result)``; each is
        ``{"group_metrics": ..., "group_macro": ...}`` as produced by
        ``eval_by_groups_from_labels``.

    Raises:
        ValueError: If a batch is processed with an unsupported ``mode``.
    """
    groups = GROUP_DEF if group_def is None else group_def

    for net in (imu_model, kp_model):
        net.eval()

    label_chunks, imu_chunks, kp_chunks = [], [], []

    with torch.no_grad():
        for imu, kp, labels in dl_test:
            imu = imu.to(device).float()
            kp = kp.to(device).float()
            labels = labels.to(device).long()

            # Both supported modes forward their own name as the stage.
            if mode not in ('stg1', 'stg3'):
                raise ValueError(f"Unsupported mode: {mode}")
            imu_logits = imu_model(imu, mode)
            kp_logits = kp_model(kp, mode)

            # Predicted class = arg-max of the softmax probabilities.
            imu_choice = torch.softmax(imu_logits, dim=-1).argmax(dim=-1)
            kp_choice = torch.softmax(kp_logits, dim=-1).argmax(dim=-1)

            label_chunks.append(labels.cpu().numpy())
            imu_chunks.append(imu_choice.cpu().numpy())
            kp_chunks.append(kp_choice.cpu().numpy())

    # Flatten the per-batch arrays into single [N] vectors.
    y_true = np.concatenate(label_chunks)
    imu_y_pred = np.concatenate(imu_chunks)
    kp_y_pred = np.concatenate(kp_chunks)

    imu_per_group, imu_macro = eval_by_groups_from_labels(y_true, imu_y_pred, groups)
    kp_per_group, kp_macro = eval_by_groups_from_labels(y_true, kp_y_pred, groups)

    imu_result = {"group_metrics": imu_per_group, "group_macro": imu_macro}
    kp_result = {"group_metrics": kp_per_group, "group_macro": kp_macro}

    return imu_result, kp_result

In [None]:
# Group-level evaluation of both models at stage 'stg3', using the default groups.
imu_g_res, kp_g_res = group_eval_test(test_dataloader, imu_model, kp_model, 'stg3', device)
print(imu_g_res)
print(kp_g_res)