In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [1]:
import os
import json
import math
from pathlib import Path
from typing import List, Dict, Tuple, Optional

import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from glob import glob
from tqdm import tqdm
import math
import time
from collections import OrderedDict
import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.nn import functional as F
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix,precision_recall_fscore_support

from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingLR

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cuda


In [2]:
class BVHParser:

    def __init__(self, filepath: str):
        self.filepath = filepath
        self.joints = []
        self.hierarchy = {}
        self.joint_parents = {}
        self.channels_list = []
        self.motion_data = None
        self.frame_time = 0.0
        self.num_frames = 0
        self._joint_channel_counts = {}
        self.ordered_joints = []

    def parse(self):
        with open(self.filepath, 'r') as f:
            lines = f.readlines()

        # Find MOTION index
        motion_idx = None
        for i, line in enumerate(lines):
            if 'MOTION' in line.upper():
                motion_idx = i
                break

        self._parse_hierarchy(lines[:motion_idx])
        self._parse_motion(lines[motion_idx:])

        return self

    def _parse_hierarchy(self, lines):
      """Extract joint hierarchy"""
      joint_stack = []
      current_joint = None
      self.joints = []
      self.hierarchy = {}
      self.joint_parents = {}
      self.channels_list = []
      self.ordered_joints = []

      for line in lines:
          stripped = line.strip()
          if stripped.startswith("ROOT") or stripped.startswith("JOINT"):
              parts = stripped.split()
              if len(parts) >= 2:
                  current_joint = parts[1]
                  self.joints.append(current_joint)
                  self.ordered_joints.append(current_joint)
                  self.hierarchy[current_joint] = {
                      'channels': [],
                      'offset': np.array([0.0, 0.0, 0.0], dtype=np.float32),
                      'children': []
                  }
                  if joint_stack:
                      parent = joint_stack[-1]
                      self.joint_parents[current_joint] = parent
                      self.hierarchy[parent]['children'].append(current_joint)
                  else:
                      self.joint_parents[current_joint] = None
                  joint_stack.append(current_joint)

          elif stripped.startswith("End Site"):
              # End sites don't have channels, skip them
              pass

          elif stripped.startswith("OFFSET") and current_joint:
              parts = stripped.split()
              try:
                  ox, oy, oz = float(parts[1]), float(parts[2]), float(parts[3])
                  self.hierarchy[current_joint]['offset'] = np.array([ox, oy, oz], dtype=np.float32)
              except:
                  pass

          elif stripped.startswith("CHANNELS") and current_joint:
              parts = stripped.split()
              try:
                  num_ch = int(parts[1])
                  chs = parts[2:2+num_ch]
                  self.hierarchy[current_joint]['channels'] = chs
                  self._joint_channel_counts[current_joint] = len(chs)
                  for ch in chs:
                      self.channels_list.append((current_joint, ch))
              except:
                  pass

          elif '}' in stripped:
              if joint_stack:
                  joint_stack.pop()
                  current_joint = joint_stack[-1] if joint_stack else None

    def _parse_motion(self, lines):
        """Extract motion data"""
        for line in lines[:20]:
            if "Frames:" in line:
                self.num_frames = int(line.split(":")[1].strip())
            if "Frame Time:" in line:
                self.frame_time = float(line.split(":")[1].strip())


        data_start = None
        for i, line in enumerate(lines):
            s = line.strip()
            if not s:
                continue
            parts = s.split()
            try:
                _ = float(parts[0])
                data_start = i
                break
            except:
                continue

        data = []
        for line in lines[data_start:]:
            s = line.strip()
            if not s:
                continue
            parts = s.split()
            try:
                nums = [float(p) for p in parts]
                data.append(nums)
            except:
                continue

        self.motion_data = np.array(data, dtype=np.float32)
        if self.num_frames == 0:
            self.num_frames = len(data)


    def compute_joint_positions(self):

        T = self.motion_data.shape[0]
        V = len(self.ordered_joints)
        positions = np.zeros((T, V, 3), dtype=np.float32)

        joint_col_idx = {}
        col = 0
        for j in self.ordered_joints:
            chs = self.hierarchy[j]['channels']
            cnt = len(chs)
            joint_col_idx[j] = (col, col + cnt)
            col += cnt

        for t in range(T):
            global_rot = {}
            global_pos = {}

            for v_idx, joint in enumerate(self.ordered_joints):
                parent = self.joint_parents.get(joint, None)
                offset = self.hierarchy[joint]['offset']

                local_rot = np.eye(3, dtype=np.float32)
                local_pos = np.zeros(3, dtype=np.float32)

                start, end = joint_col_idx[joint]
                if end > start:
                    ch_names = self.hierarchy[joint]['channels']
                    values = self.motion_data[t, start:end]
                    pos_vals = []
                    rot_vals = []
                    pos_order = []
                    rot_order = []
                    for i, ch in enumerate(ch_names):
                        lower = ch.lower()
                        if 'position' in lower:
                            pos_vals.append(values[i])
                            pos_order.append(lower[0])
                        elif 'rotation' in lower:
                            rot_vals.append(values[i])
                            rot_order.append(ch)

                    if len(pos_vals) > 0:
                        lp = np.zeros(3, dtype=np.float32)
                        for i, ch in enumerate(ch_names):
                            if 'position' in ch.lower():
                                axis = ch.lower()[0]  # x/y/z
                                if axis == 'x':
                                    lp[0] = values[i]
                                elif axis == 'y':
                                    lp[1] = values[i]
                                elif axis == 'z':
                                    lp[2] = values[i]
                        local_pos = lp
                    if len(rot_vals) > 0:
                        rot_mat = np.eye(3, dtype=np.float32)
                        for i, ch in enumerate(ch_names):
                            if 'rotation' in ch.lower():
                                angle_deg = values[i]
                                angle = math.radians(angle_deg)
                                axis = ch.lower()[0]  # x/y/z
                                if axis == 'x':
                                    R = np.array([[1,0,0],
                                                  [0, math.cos(angle), -math.sin(angle)],
                                                  [0, math.sin(angle),  math.cos(angle)]], dtype=np.float32)
                                elif axis == 'y':
                                    R = np.array([[ math.cos(angle), 0, math.sin(angle)],
                                                  [0,1,0],
                                                  [-math.sin(angle),0, math.cos(angle)]], dtype=np.float32)
                                else:
                                    R = np.array([[math.cos(angle), -math.sin(angle), 0],
                                                  [math.sin(angle),  math.cos(angle), 0],
                                                  [0,0,1]], dtype=np.float32)
                                rot_mat = rot_mat @ R
                        local_rot = rot_mat

                if parent is None:
                    if np.any(local_pos != 0):
                        root_pos = local_pos
                    else:
                        root_pos = offset
                    global_pos[joint] = root_pos
                    global_rot[joint] = local_rot
                else:
                    p_rot = global_rot[parent]
                    p_pos = global_pos[parent]
                    joint_pos = p_pos + p_rot.dot(offset) + p_rot.dot(local_pos)
                    global_pos[joint] = joint_pos
                    global_rot[joint] = p_rot.dot(local_rot)

                positions[t, v_idx, :] = global_pos[joint]

        return positions


def compute_motion_features(positions):

  velocity = np.zeros_like(positions)
  velocity[1:] = positions[1:] - positions[:-1]

  acceleration = np.zeros_like(positions)
  acceleration[1:] = velocity[1:] - velocity[:-1]

  motion_features = np.concatenate([positions, velocity, acceleration], axis=2)

  return motion_features


def resample_time_sequence(seq, target_frames):
    T, V, C = seq.shape
    if T == target_frames:
        return seq
    if T < 2:
        out = np.tile(seq, (target_frames, 1, 1))
        return out[:target_frames]
    old_idx = np.linspace(0, T-1, T)
    new_idx = np.linspace(0, T-1, target_frames)
    out = np.zeros((target_frames, V, C), dtype=seq.dtype)
    for v in range(V):
        for c in range(C):
            out[:, v, c] = np.interp(new_idx, old_idx, seq[:, v, c])
    return out

def center_on_root(seq, root_index = 0, keep_z = True):
    T, V, C = seq.shape
    root_pos = seq[:, root_index:root_index+1, :]
    centered = seq - root_pos
    if not keep_z:
        centered[..., 2] = 0.0
    return centered

def normalize_sequence(seq, scale = None):
    if scale is not None:
        return seq / scale, scale
    T, V, C = seq.shape
    dists = np.linalg.norm(seq, axis=2)
    mean_dist = float(np.mean(dists))
    if mean_dist == 0:
        mean_dist = 1.0
    return seq / mean_dist, mean_dist


def random_rotation_xyz(sample, angle_range=np.pi/12):
    C, T, V = sample.shape
    if C % 3 != 0:
        return sample
    theta = np.random.uniform(-angle_range, angle_range)
    c, s = np.cos(theta), np.sin(theta)
    R = np.array([[c, 0, -s],
                  [0, 1,  0],
                  [s, 0,  c]], dtype=np.float32)
    out = sample.copy()
    for v in range(V):
        xyz = sample[:, :, v].reshape(3, -1)
        xyz_rot = R @ xyz
        out[0:3, :, v] = xyz_rot
    return out

def temporal_jitter(sample, max_shift=8):

    C, T, V = sample.shape
    shift = np.random.randint(-max_shift, max_shift + 1)
    if shift == 0: return sample
    if shift > 0:

        idx = np.linspace(0, T-1, T-shift).astype(int)
        new = sample[:, idx, :]
        pad = np.repeat(new[:, -1:, :], shift, axis=1)
        out = np.concatenate([new, pad], axis=1)
    else:

        idx = np.linspace(0, T-1, min(T, T+(-shift))).astype(int)
        new = sample[:, idx, :]
        if new.shape[1] >= T:
            out = new[:, :T, :]
        else:
            pad = np.repeat(new[:, -1:, :], T - new.shape[1], axis=1)
            out = np.concatenate([new, pad], axis=1)
    return out

def add_gaussian_noise(sample, sigma=0.005):
    return sample + np.random.normal(scale=sigma, size=sample.shape).astype(np.float32)


def center_and_scale(sample, root_idx=0, scale_pair=(11,14)):

    C, T, V = sample.shape
    out = sample.copy().astype(np.float32)
    if C < 3:
        return out

    root_xyz = out[0:3, :, root_idx].mean(axis=1)  # (3,)
    out[0:3, :, :] = out[0:3, :, :] - root_xyz.reshape(3,1,1)

    p1 = out[0:3, :, scale_pair[0]].mean(axis=1)
    p2 = out[0:3, :, scale_pair[1]].mean(axis=1)
    scale = np.linalg.norm(p1 - p2) + 1e-6
    out[0:3, :, :] = out[0:3, :, :] / scale
    return out

def build_edge_list(ordered_joints, joint_parents):

    name2idx = {name: i for i, name in enumerate(ordered_joints)}
    edges = []
    for j, parent in joint_parents.items():
        if parent is None:
            continue
        u = name2idx[parent]
        v = name2idx[j]
        edges.append((u, v))
        edges.append((v, u))
    return edges

### STGCN Dataset Builder

In [None]:
class STGCNDatasetBuilder:

    def __init__(self, data_dir, fileinfo_csv, output_dir, target_frames = 100, center_root = True,normalize = True, root_name = None):
        self.data_dir = Path(data_dir)
        self.fileinfo = pd.read_csv(fileinfo_csv)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.target_frames = target_frames
        self.center_root = center_root
        self.normalize = normalize
        self.root_name = root_name

        self.sample_paths = []
        self.labels = []
        self.failed = []
        self.label2idx = {}

    def _find_bvh(self, filename: str):
        # Try direct and recursive search
        candidates = [
            self.data_dir / f"{filename}.bvh",
            self.data_dir / "bvh" / f"{filename}.bvh",
            self.data_dir / "data" / f"{filename}.bvh",
            self.data_dir / filename / f"{filename}.bvh",
        ]
        for p in candidates:
            if p.exists():
                return p
        # recursive search
        found = list(self.data_dir.rglob(f"{filename}.bvh"))
        if found:
            return found[0]
        return None

    def build(self, save_as_single_array: bool = True):
        print("Processing dataset...")
        data_list = []
        label_list = []
        global_joint_order = None
        global_edge_list = None

        for idx, row in self.fileinfo.iterrows():
            filename = row['filename']
            emotion = row.get('emotion', None)
            bvh_path = self._find_bvh(filename)
            if bvh_path is None:
                self.failed.append((filename, "not found"))
                continue

            parser = BVHParser(str(bvh_path)).parse()
            positions = parser.compute_joint_positions()  # (T, V, 3)

            if global_joint_order is None:
                global_joint_order = parser.ordered_joints
                global_edge_list = build_edge_list(parser.ordered_joints, parser.joint_parents)

                if self.root_name and self.root_name in global_joint_order:
                    root_idx = global_joint_order.index(self.root_name)
                else:
                    root_idx = 0


            positions = resample_time_sequence(positions, self.target_frames)  # (T, V, 3)
            positions_with_motion = compute_motion_features(positions)


            if self.center_root:
                positions_with_motion = center_on_root(positions_with_motion, root_index=root_idx)


            if self.normalize:
                positions_with_motion, used_scale = normalize_sequence(positions_with_motion, scale=None)


            if parser.ordered_joints != global_joint_order:

                name2idx_local = {name: i for i, name in enumerate(parser.ordered_joints)}
                reordered = np.zeros((self.target_frames, len(global_joint_order), 3), dtype=positions.dtype)
                for i, name in enumerate(global_joint_order):
                    if name in name2idx_local:
                        reordered[:, i, :] = positions[:, name2idx_local[name], :]
                    else:

                        reordered[:, i, :] = 0.0
                positions = reordered

            seq = np.transpose(positions_with_motion, (2, 0, 1))
            data_list.append(seq.astype(np.float32))


            if emotion not in self.label2idx:
                self.label2idx[emotion] = len(self.label2idx)
            label_list.append(self.label2idx[emotion])

            sample_out = self.output_dir / f"{filename}.npy"
            np.save(sample_out, seq.astype(np.float32))
            self.sample_paths.append(str(sample_out))

            if len(data_list) % 50 == 0:
                print(f"Processed {len(data_list)} samples...")


        data_arr = np.stack(data_list, axis=0)
        labels_arr = np.array(label_list, dtype=np.int64)

        if save_as_single_array:
            np.save(self.output_dir / "data.npy", data_arr)
            np.save(self.output_dir / "labels.npy", labels_arr)
            meta = {
                'joint_names': global_joint_order,
                'edge_list': global_edge_list,
                'label2idx': self.label2idx,
                'target_frames': self.target_frames,
                'center_root': self.center_root,
                'normalize': self.normalize
            }
            with open(self.output_dir / "meta.json", "w") as f:
                json.dump(meta, f, indent=2)
            print(f"Saved data.npy ({data_arr.shape}), labels.npy ({labels_arr.shape}), meta.json")

        return data_arr, labels_arr, global_joint_order, global_edge_list


if __name__ == "__main__":
    data_dir = "/content/drive/MyDrive/kinematic_dataset_final/BVH/"
    fileinfo_csv = "/content/drive/MyDrive/kinematic_dataset_final/file-info.csv"
    output_dir = "/content/drive/MyDrive/kinematic_dataset_final/stgcn_input_final/"
    target_frames = 100
    center_root = True
    normalize = True
    sample_save_single_array = True

    builder = STGCNDatasetBuilder(data_dir, fileinfo_csv, output_dir,
                                  target_frames=target_frames,
                                  center_root=center_root,
                                  normalize=normalize)
    data_arr, labels_arr, joints, edges = builder.build(save_as_single_array=sample_save_single_array)

    print("Data shape:", data_arr.shape)
    print("Label distribution:", np.bincount(labels_arr))
    print("Joints:", joints)
    print("Edges:", edges)

Processing dataset...
Processed 50 samples...
Processed 100 samples...
Processed 150 samples...
Processed 200 samples...
Processed 250 samples...
Processed 300 samples...
Processed 350 samples...
Processed 400 samples...
Processed 450 samples...
Processed 500 samples...
Processed 550 samples...
Processed 600 samples...
Processed 650 samples...
Processed 700 samples...
Processed 750 samples...
Processed 800 samples...
Processed 850 samples...
Processed 900 samples...
Processed 950 samples...
Processed 1000 samples...
Processed 1050 samples...
Processed 1100 samples...
Processed 1150 samples...
Processed 1200 samples...
Processed 1250 samples...
Processed 1300 samples...
Processed 1350 samples...
Processed 1400 samples...
Saved data.npy ((1401, 9, 100, 59)), labels.npy ((1401,)), meta.json
Data shape: (1401, 9, 100, 59)
Label distribution: [200 210 216 216 145 202 212]
Joints: ['Hips', 'RightUpLeg', 'RightLeg', 'RightFoot', 'LeftUpLeg', 'LeftLeg', 'LeftFoot', 'Spine', 'Spine1', 'Spine2',

In [6]:
### Uncomment it when using colab
X_path = "/content/drive/MyDrive/kinematic_dataset_final/stgcn_input_final/data.npy"
y_path = "/content/drive/MyDrive/kinematic_dataset_final/stgcn_input_final/labels.npy"

# X_path = "data.npy"
# y_path = "labels.npy"


X = np.load(X_path)
y = np.load(y_path)

print("Loaded X", X.shape, "y", y.shape)

Loaded X (1401, 9, 100, 59) y (1401,)


In [7]:
if X.ndim == 5 and X.shape[-1] == 1:
    X = X[..., 0]
print("Using X shape:", X.shape)

N, C, T, V = X.shape
print("N,C,T,V =", N, C, T, V)

Using X shape: (1401, 9, 100, 59)
N,C,T,V = 1401 9 100 59


In [8]:
class STGCNDataset(Dataset):
    def __init__(self, X, y, augment=False,train =True, fixed_T=100):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.train = train
        self.augment = augment
        self.fixed_T = fixed_T

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        x = self.X[idx]  # (C, T, V)
        if self.augment:
            x = x + 0.001 * np.random.randn(*x.shape).astype(np.float32)
        return x, self.y[idx]


train_idx, val_idx = train_test_split(np.arange(N), test_size=0.2, stratify=y, random_state=42)

X_train, y_train = X[train_idx], y[train_idx]
X_val, y_val = X[val_idx], y[val_idx]
print("Train/Val sizes:", X_train.shape[0], X_val.shape[0])

fixed_T = 100
train_ds = STGCNDataset(X_train, y_train, train=True, augment=True, fixed_T=fixed_T)
val_ds = STGCNDataset(X_val, y_val, train=False, augment=False, fixed_T=fixed_T)

batch_size = 8
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)


Train/Val sizes: 1120 281


In [9]:
edge_list = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2), (4, 5), (5, 4), (5, 6), (6, 5), (7, 8), (8, 7), (8, 9), (9, 8), (9, 10), (10, 9), (10, 11), (11, 10), (11, 12), (12, 11), (9, 13), (13, 9), (13, 14), (14, 13), (14, 15), (15, 14), (15, 16), (16, 15), (16, 17), (17, 16), (17, 18), (18, 17), (18, 19), (19, 18), (15, 20), (20, 15), (20, 21), (21, 20), (21, 22), (22, 21), (22, 23), (23, 22), (14, 24), (24, 14), (24, 25), (25, 24), (25, 26), (26, 25), (26, 27), (27, 26), (13, 28), (28, 13), (28, 29), (29, 28), (29, 30), (30, 29), (30, 31), (31, 30), (9, 32), (32, 9), (32, 33), (33, 32), (33, 34), (34, 33), (34, 35), (35, 34), (36, 37), (37, 36), (37, 38), (38, 37), (38, 39), (39, 38), (39, 40), (40, 39), (40, 41), (41, 40), (41, 42), (42, 41), (38, 43), (43, 38), (43, 44), (44, 43), (44, 45), (45, 44), (45, 46), (46, 45), (37, 47), (47, 37), (47, 48), (48, 47), (48, 49), (49, 48), (49, 50), (50, 49), (36, 51), (51, 36), (51, 52), (52, 51), (52, 53), (53, 52), (53, 54), (54, 53), (55, 56), (56, 55), (56, 57), (57, 56), (57, 58), (58, 57)]

def build_adj(V, edge_list, self_link=True):
    A = np.zeros((V, V), dtype=np.float32)
    for (u,v) in edge_list:
        A[u, v] = 1
        A[v, u] = 1
    if self_link:
        for i in range(V):
            A[i,i] = 1
    return A

def normalize_adj(A):
    # symmetric normalization D^-1/2 A D^-1/2
    D = np.sum(A, axis=1)
    D_inv_sqrt = np.power(D, -0.5)
    D_inv_sqrt[np.isinf(D_inv_sqrt)] = 0.0
    Dm = np.diag(D_inv_sqrt)
    return Dm @ A @ Dm

A = build_adj(V, edge_list, self_link=True)
A_norm = normalize_adj(A)
A_torch = torch.tensor(A_norm, dtype=torch.float32, device=device)  # (V,V)
print("A_norm shape:", A_torch.shape)


A_norm shape: torch.Size([59, 59])


In [10]:

class SpatialGraphConv(nn.Module):
    def __init__(self, in_channels, out_channels, bias=True):
        super().__init__()
        self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)

    def forward(self, x, A):
        x = self.conv1x1(x)
        N, outC, T, V = x.shape
        x = x.permute(0,2,1,3).contiguous().view(N*T, outC, V)
        x = torch.matmul(x, A.t())
        x = x.view(N, T, outC, V).permute(0,2,1,3).contiguous()
        return x


class TemporalAttention(nn.Module):
    """Attention across time dimension"""
    def __init__(self, channels, reduction=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((None, 1))
        self.max_pool = nn.AdaptiveMaxPool2d((None, 1))

        self.fc = nn.Sequential(
            nn.Conv2d(channels*2, channels // reduction, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, kernel_size=1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        avg_out = self.avg_pool(x)
        max_out = self.max_pool(x)
        concat = torch.cat([avg_out, max_out], dim=1)
        attention = self.fc(concat)
        return x * attention.expand_as(x)


class SpatialAttention(nn.Module):
    """Attention across joints (spatial dimension)"""
    def __init__(self, channels, reduction=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, None))  # Pool over T dimension
        self.max_pool = nn.AdaptiveMaxPool2d((1, None))

        self.fc = nn.Sequential(
            nn.Conv2d(channels*2, channels // reduction, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, kernel_size=1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        avg_out = self.avg_pool(x)
        max_out = self.max_pool(x)
        concat = torch.cat([avg_out, max_out], dim=1)
        attention = self.fc(concat)
        return x * attention.expand_as(x)


class ChannelAttention(nn.Module):
    """Attention across feature channels"""
    def __init__(self, channels, reduction=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, kernel_size=1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        attention = avg_out + max_out
        return x * attention


class STGCNBlockWithAttention(nn.Module):
    def __init__(self, in_channels, out_channels, A, kernel_size=9, stride=1, use_attention=True):
        super().__init__()
        self.A = A
        padding = (kernel_size - 1) // 2

        self.gconv = SpatialGraphConv(in_channels, out_channels)

        self.tconv = nn.Conv2d(out_channels, out_channels,
                               kernel_size=(kernel_size, 1),
                               padding=(padding, 0),
                               stride=(stride, 1))

        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.use_attention = use_attention
        if use_attention:
            self.temporal_attn = TemporalAttention(out_channels)
            self.spatial_attn = SpatialAttention(out_channels)
            self.channel_attn = ChannelAttention(out_channels)

        self.dropout = nn.Dropout(0.3)

        if (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=(stride, 1)),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):

        res = self.residual(x)
        x = self.gconv(x, self.A)
        x = self.tconv(x)
        x = self.bn(x)
        if self.use_attention:
            x = self.temporal_attn(x)
            x = self.spatial_attn(x)
            x = self.channel_attn(x)

        x = self.relu(x)
        x = self.dropout(x)
        x = x + res

        return x


In [12]:
class STGCN(nn.Module):
    def __init__(self, in_channels, num_class, A, base_channels=64, dropout=0.5):
        super().__init__()
        self.register_buffer('A', A)

        # in_channels are 9 (3 positions + 3 velocity + 3 acceleration)
        self.data_bn = nn.BatchNorm1d(in_channels * V)

        # Multiple blocks with attention
        self.layer1 = nn.Sequential(
            STGCNBlockWithAttention(in_channels, base_channels, A, kernel_size=9, stride=1, use_attention=True),
            STGCNBlockWithAttention(base_channels, base_channels, A, kernel_size=7, stride=1, use_attention=True),
        )

        # Downsample and extract higher-level features
        self.layer2 = nn.Sequential(
            STGCNBlockWithAttention(base_channels, base_channels*2, A, kernel_size=7, stride=2, use_attention=True),
            STGCNBlockWithAttention(base_channels*2, base_channels*2, A, kernel_size=5, stride=1, use_attention=True),
        )

        # Downsampling with attention
        self.layer3 = nn.Sequential(
            STGCNBlockWithAttention(base_channels*2, base_channels*4, A, kernel_size=5, stride=2, use_attention=True),
            STGCNBlockWithAttention(base_channels*4, base_channels*4, A, kernel_size=3, stride=1, use_attention=True),
        )

        # Global attention before pooling
        self.global_temporal_attn = TemporalAttention(base_channels*4)
        self.global_spatial_attn = SpatialAttention(base_channels*4)

        # Pooling
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(dropout)

        # Classifier
        self.fc = nn.Sequential(
            nn.Linear(base_channels*4, base_channels*2),
            nn.BatchNorm1d(base_channels*2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(base_channels*2, num_class)
        )

    def forward(self, x):
        N, C, T, V = x.shape

        # Data normalization
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(N, T, C*V).permute(0, 2, 1).contiguous()
        x = self.data_bn(x)
        x = x.permute(0, 2, 1).contiguous().view(N, T, C, V).permute(0, 2, 1, 3).contiguous()

        # Forward through layers
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        # global attention
        x = self.global_temporal_attn(x)
        x = self.global_spatial_attn(x)

        # Pool and classify
        x = self.pool(x).view(N, -1)
        x = self.dropout(x)
        x = self.fc(x)

        return x

num_classes = int(np.max(y) + 1)
model = STGCN(in_channels=C, num_class=num_classes, A=A_torch, base_channels=64, dropout=0.5).to(device)
print(model)


STGCN(
  (data_bn): BatchNorm1d(531, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): STGCNBlockWithAttention(
      (gconv): SpatialGraphConv(
        (conv1x1): Conv2d(9, 64, kernel_size=(1, 1), stride=(1, 1))
      )
      (tconv): Conv2d(64, 64, kernel_size=(9, 1), stride=(1, 1), padding=(4, 0))
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (temporal_attn): TemporalAttention(
        (avg_pool): AdaptiveAvgPool2d(output_size=(None, 1))
        (max_pool): AdaptiveMaxPool2d(output_size=(None, 1))
        (fc): Sequential(
          (0): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(8, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (3): Sigmoid()
        )
      )
      (spatial_attn): SpatialAttention(
        (avg_pool): AdaptiveAvgPool2d(output_size=(1, None))
     

In [13]:
def accuracy(output, target, topk=(1,)):
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True) # (N, maxk)
        pred = pred.t()  # (maxk, N)
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append((correct_k.mul_(100.0 / batch_size)).item())
        return res

In [14]:
print("X.shape =", X.shape)
print("A_torch.shape =", A_torch.shape)

X.shape = (1401, 9, 100, 59)
A_torch.shape = torch.Size([59, 59])


In [15]:
### Final
### Batch Size 16
import os, random, numpy as np, time
import torch

random_seed = 42
n_splits = 5
patience = 8

# seed
torch.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(random_seed)

skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_seed)
fold_results = []
all_preds = []
all_trues = []
best_models = []

for fold, (tr_idx, val_idx) in enumerate(skf.split(X, y)):
    print(f"\n=== Fold {fold+1}/{n_splits} ===")
    X_train, y_train = X[tr_idx], y[tr_idx]
    X_val, y_val = X[val_idx], y[val_idx]
    print("Train/Val sizes:", X_train.shape[0], X_val.shape[0])

    fixed_T = 100
    train_ds = STGCNDataset(X_train, y_train, train=True, augment=True, fixed_T=fixed_T)
    val_ds = STGCNDataset(X_val, y_val, train=False, augment=False, fixed_T=fixed_T)

    batch_size = 32
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=2, pin_memory=True, drop_last=True)  # ok
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=2, pin_memory=True, drop_last=False)  # FIXED

    config = {
        'batch_size': 16,
        'lr': 0.001,
        'weight_decay': 1e-5,
        'dropout': 0.5,
        'epochs': 150,
        'base_channels': 64,
        'optimizer': 'AdamW',
        'scheduler': 'CosineAnnealingWarmRestarts',
    }

    model = STGCN(in_channels=C, num_class=num_classes, A=A_torch,
                  base_channels=config['base_channels'],
                  dropout=config['dropout']).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=config['lr'],
                            weight_decay=config['weight_decay'], betas=(0.9, 0.999))

    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)

    num_epochs = config['epochs']
    best_val_acc = 0.0
    save_path = f"stgcn_best_fold_{fold+1}_final.pth"
    epochs_no_improve = 0

    for epoch in range(1, num_epochs+1):
        t0 = time.time()
        model.train()
        running_loss = 0.0
        running_acc = 0.0
        cnt = 0
        for batch_idx, (xb, yb) in enumerate(train_loader):
            xb = xb.to(device)
            yb = yb.to(device)

            optimizer.zero_grad()
            logits = model(xb)
            loss = criterion(logits, yb)
            loss.backward()
            optimizer.step()



            acc1 = float(accuracy(logits, yb, topk=(1,))[0])

            running_loss += loss.item() * xb.size(0)
            running_acc += acc1 * xb.size(0)
            cnt += xb.size(0)

        epoch_loss = running_loss / cnt
        epoch_acc = running_acc / cnt


        model.eval()
        val_loss = 0.0
        val_acc = 0.0
        vcnt = 0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(device)
                yb = yb.to(device)
                logits = model(xb)
                loss = criterion(logits, yb)
                acc1 = float(accuracy(logits, yb, topk=(1,))[0])
                val_loss += loss.item() * xb.size(0)
                val_acc += acc1 * xb.size(0)
                vcnt += xb.size(0)

        val_loss /= max(1, vcnt)
        val_acc /= max(1, vcnt)


        scheduler.step()

        t1 = time.time()
        print(f"Epoch {epoch:02d}/{num_epochs} | "
              f"train_loss {epoch_loss:.4f} train_acc {epoch_acc:.2f}% | "
              f"val_loss {val_loss:.4f} val_acc {val_acc:.2f}% | time {(t1-t0):.1f}s")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'val_acc': val_acc},
                       save_path)
            epochs_no_improve = 0
            best_models.append(save_path)   # track saved model
            print("Saved best model at epoch", epoch, "val_acc", val_acc)
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print("Early stopping")
                break

    fold_results.append({'fold': fold+1, 'best_val_acc': best_val_acc})




=== Fold 1/5 ===
Train/Val sizes: 1120 281
Epoch 01/150 | train_loss 1.9216 train_acc 20.09% | val_loss 1.9593 val_acc 14.59% | time 14.8s
Saved best model at epoch 1 val_acc 14.590747330960854
Epoch 02/150 | train_loss 1.8038 train_acc 26.96% | val_loss 1.8977 val_acc 21.35% | time 13.2s
Saved best model at epoch 2 val_acc 21.352313167259787
Epoch 03/150 | train_loss 1.7490 train_acc 28.84% | val_loss 1.7174 val_acc 30.60% | time 13.3s
Saved best model at epoch 3 val_acc 30.604982206405694
Epoch 04/150 | train_loss 1.6628 train_acc 34.91% | val_loss 1.7121 val_acc 32.03% | time 13.5s
Saved best model at epoch 4 val_acc 32.02846975088968
Epoch 05/150 | train_loss 1.6203 train_acc 35.27% | val_loss 1.6488 val_acc 35.59% | time 13.7s
Saved best model at epoch 5 val_acc 35.587188612099645
Epoch 06/150 | train_loss 1.5328 train_acc 40.62% | val_loss 1.6116 val_acc 35.94% | time 13.9s
Saved best model at epoch 6 val_acc 35.94306049822064
Epoch 07/150 | train_loss 1.4592 train_acc 43.84% | 

### Preparation of Resuls

In [16]:

device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 32
fixed_T = 100

best_models =[]
best_models =["stgcn_best_fold_1_final.pth", "stgcn_best_fold_2_final.pth","stgcn_best_fold_3_final.pth", "stgcn_best_fold_4_final.pth", "stgcn_best_fold_5_final.pth"]

with open('best_models.json', 'w') as f:
    json.dump(best_models, f)


best_models = json.load(open('best_models.json'))


skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_val_indices = [val_idx.copy() for _, val_idx in skf.split(X, y)]

per_fold_trues = []
per_fold_preds = []

for fold_idx, ckpt_path in enumerate(best_models):
    print(f"Evaluating fold {fold_idx+1}/{len(best_models)} -> {ckpt_path}")
    ckpt = torch.load(ckpt_path, map_location=device)


    model = STGCN(in_channels=C, num_class=num_classes, A=A_torch,
                  base_channels=64, dropout=0.5).to(device)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()

    val_idx = np.asarray(fold_val_indices[fold_idx])
    X_val = X[val_idx]
    y_val = y[val_idx]

    val_ds = STGCNDataset(X_val, y_val, train=False, augment=False, fixed_T=fixed_T)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True, drop_last=False)

    fold_preds = []
    fold_trues = []
    with torch.no_grad():
        for xb, yb in val_loader:
            xb = xb.to(device)
            logits = model(xb)
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            fold_preds.append(preds)
            fold_trues.append(yb.numpy())

    fold_preds = np.concatenate(fold_preds, axis=0)
    fold_trues = np.concatenate(fold_trues, axis=0)
    acc = accuracy_score(fold_trues, fold_preds)
    print(f"  fold acc: {acc:.4f}  (#val samples: {len(fold_trues)})")

    per_fold_trues.append(fold_trues)
    per_fold_preds.append(fold_preds)

model_name = 'STGCN_cv'
results_per_fold = {
    model_name: {
        'y_trues': per_fold_trues,
        'y_preds': per_fold_preds
    }
}


Evaluating fold 1/5 -> stgcn_best_fold_1_final.pth
  fold acc: 0.5302  (#val samples: 281)
Evaluating fold 2/5 -> stgcn_best_fold_2_final.pth
  fold acc: 0.5500  (#val samples: 280)
Evaluating fold 3/5 -> stgcn_best_fold_3_final.pth
  fold acc: 0.6286  (#val samples: 280)
Evaluating fold 4/5 -> stgcn_best_fold_4_final.pth
  fold acc: 0.5571  (#val samples: 280)
Evaluating fold 5/5 -> stgcn_best_fold_5_final.pth
  fold acc: 0.5679  (#val samples: 280)


In [17]:
np.save('results_per_fold_stgcn.npy', results_per_fold)

In [19]:
def plot_aggregated_confusion_matrix(results_dict, save_path='confusion_matrix_stgcn.png', normalize=True, cmap='Blues'):

    model_name = 'STGCN_cv'


    result = results_dict[model_name]
    y_true_per_fold = result['y_trues']
    y_pred_per_fold = result['y_preds']

    y_true_all = np.concatenate(y_true_per_fold, axis=0)
    y_pred_all = np.concatenate(y_pred_per_fold, axis=0)


    num_classes = int(np.max(y_true_all)) + 1
    emotion_dict = {
        0: 'Angry',
        1: 'Disgust',
        2: 'Fearful',
        3: 'Happy',
        4: 'Neutral',
        5: 'Sad',
        6: 'Surprise'
    }
    emotion_classes = [emotion_dict[i] for i in range(num_classes)]
    fig, ax = plt.subplots(1, 1, figsize=(7, 6))

    cm = confusion_matrix(y_true_all, y_pred_all, labels=np.arange(num_classes))


    if normalize:
        with np.errstate(all='ignore'):
            cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
            cm_normalized = np.nan_to_num(cm_normalized)
        display_matrix = cm_normalized
        fmt = '.2f'
        cbar_label = 'Proportion of Actual Class'
    else:
        display_matrix = cm
        fmt = 'd'
        cbar_label = 'Count'


    sns.heatmap(display_matrix, annot=True, fmt=fmt, cmap=cmap,
                xticklabels=emotion_classes, yticklabels=emotion_classes,
                ax=ax, cbar_kws={'label': cbar_label, 'orientation': 'vertical'})

    acc = np.mean(y_true_all == y_pred_all)
    acc_str = f"{acc:.3f}"

    ax.set_title(f'{model_name} (Aggregated Accuracy: {acc_str})', fontweight='bold', fontsize=14)
    ax.set_xlabel('Predicted Label', fontweight='bold')
    ax.set_ylabel('Actual Label', fontweight='bold')

    plt.suptitle('Confusion Matrix (Aggregated across 5 Folds)', fontsize=16, fontweight='bold', y=1.05)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"Saved aggregated confusion matrix to {save_path}")
    print(f"Aggregated Accuracy: {acc_str}")
    return acc

results_dict = np.load('results_per_fold_stgcn.npy', allow_pickle=True).item()

accuracy = plot_aggregated_confusion_matrix(results_dict)
print(f"Overall Cross-Validation Accuracy: {accuracy:.4f}")

Saved aggregated confusion matrix to confusion_matrix_stgcn.png
Aggregated Accuracy: 0.567
Overall Cross-Validation Accuracy: 0.5667


In [20]:
def plot_per_fold_confusion_matrices(results_dict, normalize=True, cmap='Blues'):
    model_name = 'STGCN_cv'

    result = results_dict[model_name]
    y_true_per_fold = result['y_trues']
    y_pred_per_fold = result['y_preds']
    n_folds = len(y_true_per_fold)


    y_true_all = np.concatenate(y_true_per_fold, axis=0)
    num_classes = int(np.max(y_true_all)) + 1
    emotion_dict = {
        0: 'Angry',
        1: 'Disgust',
        2: 'Fearful',
        3: 'Happy',
        4: 'Neutral',
        5: 'Sad',
        6: 'Surprise'
    }
    emotion_classes = [emotion_dict[i] for i in range(num_classes)]

    print(f"Generating {n_folds} confusion matrices for model {model_name}...")


    fold_accuracies = []

    for fold_idx in range(n_folds):
        y_true_fold = y_true_per_fold[fold_idx]
        y_pred_fold = y_pred_per_fold[fold_idx]

        acc = np.mean(y_true_fold == y_pred_fold)
        fold_accuracies.append(acc)
        acc_str = f"{acc:.3f}"

        cm = confusion_matrix(y_true_fold, y_pred_fold, labels=np.arange(num_classes))

        if normalize:
            with np.errstate(all='ignore'):
                cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
                cm_normalized = np.nan_to_num(cm_normalized)
            display_matrix = cm_normalized
            fmt = '.2f'
            cbar_label = 'Proportion of Actual Class'
        else:
            display_matrix = cm
            fmt = 'd'
            cbar_label = 'Count'

        fig, ax = plt.subplots(1, 1, figsize=(7, 6))

        sns.heatmap(display_matrix, annot=True, fmt=fmt, cmap=cmap,
                    xticklabels=emotion_classes, yticklabels=emotion_classes,
                    ax=ax, cbar_kws={'label': cbar_label, 'orientation': 'vertical'})

        title = f'{model_name} - Fold {fold_idx + 1}\n(Accuracy: {acc_str})'
        ax.set_title(title, fontweight='bold', fontsize=14)
        ax.set_xlabel('Predicted Label', fontweight='bold')
        ax.set_ylabel('Actual Label', fontweight='bold')

        save_path = f'confusion_matrix_stgcn_fold_{fold_idx + 1}.png'
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close(fig)

        print(f"Saved confusion matrix for Fold {fold_idx + 1} to {save_path}")

    return fold_accuracies

results_dict = np.load('results_per_fold_stgcn.npy', allow_pickle=True).item()


accuracies = plot_per_fold_confusion_matrices(results_dict)

print("\n--- Summary of Fold Accuracies ---")
for i, acc in enumerate(accuracies):
    print(f"Fold {i+1}: {acc:.4f}")
print(f"Mean Accuracy: {np.mean(accuracies):.4f}")
print(f"Std Dev Accuracy: {np.std(accuracies):.4f}")

Generating 5 confusion matrices for model STGCN_cv...
Saved confusion matrix for Fold 1 to confusion_matrix_stgcn_fold_1.png
Saved confusion matrix for Fold 2 to confusion_matrix_stgcn_fold_2.png
Saved confusion matrix for Fold 3 to confusion_matrix_stgcn_fold_3.png
Saved confusion matrix for Fold 4 to confusion_matrix_stgcn_fold_4.png
Saved confusion matrix for Fold 5 to confusion_matrix_stgcn_fold_5.png

--- Summary of Fold Accuracies ---
Fold 1: 0.5302
Fold 2: 0.5500
Fold 3: 0.6286
Fold 4: 0.5571
Fold 5: 0.5679
Mean Accuracy: 0.5668
Std Dev Accuracy: 0.0333


In [21]:
def plot_overall_per_class_metrics(results_dict, save_path='per_class_metrics_overall_stgcn.png', cmap='Blues'):

    model_name = 'STGCN_cv'
    result = results_dict[model_name]
    y_true_per_fold = result['y_trues']
    y_pred_per_fold = result['y_preds']

    y_true_all = np.concatenate(y_true_per_fold, axis=0)
    num_classes = int(np.max(y_true_all)) + 1
    emotion_dict = {
        0: 'Angry',
        1: 'Disgust',
        2: 'Fearful',
        3: 'Happy',
        4: 'Neutral',
        5: 'Sad',
        6: 'Surprise'
    }
    emotion_classes = [emotion_dict[i] for i in range(num_classes)]


    all_precision, all_recall, all_f1 = [], [], []
    for y_true_f, y_pred_f in zip(y_true_per_fold, y_pred_per_fold):
        p, r, f1, _ = precision_recall_fscore_support(y_true_f, y_pred_f,
                                                      labels=np.arange(num_classes),
                                                      zero_division=0)
        all_precision.append(p)
        all_recall.append(r)
        all_f1.append(f1)


    metrics_data = {
        'Precision': (np.stack(all_precision, axis=0).mean(axis=0), np.stack(all_precision, axis=0).std(axis=0)),
        'Recall': (np.stack(all_recall, axis=0).mean(axis=0), np.stack(all_recall, axis=0).std(axis=0)),
        'F1-Score': (np.stack(all_f1, axis=0).mean(axis=0), np.stack(all_f1, axis=0).std(axis=0))
    }


    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    fig.suptitle('Per-Class Performance Metrics (Mean $\pm$ Std Dev across Folds)',
                 fontsize=16, fontweight='bold', y=1.02)

    metric_names = ['Precision', 'Recall', 'F1-Score']
    x = np.arange(num_classes)
    width = 0.5

    for ax, metric_name in zip(axes, metric_names):
        mean_scores, std_scores = metrics_data[metric_name]

        ax.bar(x, mean_scores, width, yerr=std_scores, capsize=5,
               color=plt.cm.get_cmap(cmap)(0.5), alpha=0.85)

        ax.set_xlabel('Class Label', fontsize=12, fontweight='bold')
        ax.set_ylabel(metric_name, fontsize=12, fontweight='bold')
        ax.set_title(f'{metric_name} by Class', fontsize=13, fontweight='bold')
        ax.set_xticks(x)
        ax.set_xticklabels(emotion_classes, rotation=45, ha='right')
        ax.grid(axis='y', alpha=0.3)
        ax.set_ylim([0, 1])

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Saved overall per-class metrics to {save_path}")

results_dict = np.load('results_per_fold_stgcn.npy', allow_pickle=True).item()

plot_overall_per_class_metrics(results_dict)

Saved overall per-class metrics to per_class_metrics_overall_stgcn.png


In [22]:
def plot_per_fold_individual_metrics(results_dict, save_path='per_class_metrics_per_fold.png', cmap='viridis'):

    model_name = 'STGCN_cv'
    result = results_dict[model_name]
    y_true_per_fold = result['y_trues']
    y_pred_per_fold = result['y_preds']
    n_folds = len(y_true_per_fold)

    y_true_all = np.concatenate(y_true_per_fold, axis=0)
    num_classes = int(np.max(y_true_all)) + 1
    emotion_dict = {
        0: 'Angry',
        1: 'Disgust',
        2: 'Fearful',
        3: 'Happy',
        4: 'Neutral',
        5: 'Sad',
        6: 'Surprise'
    }
    emotion_classes = [emotion_dict[i] for i in range(num_classes)]


    all_precision, all_recall, all_f1 = [], [], []
    for y_true_f, y_pred_f in zip(y_true_per_fold, y_pred_per_fold):
        p, r, f1, _ = precision_recall_fscore_support(y_true_f, y_pred_f,
                                                      labels=np.arange(num_classes),
                                                      zero_division=0)
        all_precision.append(p)
        all_recall.append(r)
        all_f1.append(f1)

    metrics_data = {
        'Precision': np.stack(all_precision, axis=0),
        'Recall': np.stack(all_recall, axis=0),
        'F1-Score': np.stack(all_f1, axis=0)
    }


    fig, axes = plt.subplots(1, 3, figsize=(21, 7))
    fig.suptitle('Per-Class Performance Metrics (Individual Folds)',
                 fontsize=18, fontweight='bold', y=1.02)

    metric_names = ['Precision', 'Recall', 'F1-Score']
    x = np.arange(num_classes)

    group_width = 0.8
    bar_width = group_width / n_folds

    colors = plt.cm.get_cmap(cmap, n_folds)

    for ax, metric_name in zip(axes, metric_names):
        scores_matrix = metrics_data[metric_name] # (n_folds, n_classes)

        for fold_idx in range(n_folds):
            offset = (fold_idx - (n_folds - 1) / 2) * bar_width

            ax.bar(x + offset, scores_matrix[fold_idx, :], bar_width,
                   label=f'Fold {fold_idx + 1}', color=colors(fold_idx), alpha=0.8)

        ax.set_xlabel('Class Label', fontsize=14, fontweight='bold')
        ax.set_ylabel(metric_name, fontsize=14, fontweight='bold')
        ax.set_title(f'{metric_name} by Class', fontsize=15, fontweight='bold')

        ax.set_xticks(x)
        ax.set_xticklabels(emotion_classes, rotation=45, ha='right', fontsize=12)
        ax.legend(fontsize=10)
        ax.grid(axis='y', alpha=0.3)
        ax.set_ylim([0, 1])

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Saved per-class metrics per fold to {save_path}")



results_dict = np.load('results_per_fold_stgcn.npy', allow_pickle=True).item()

plot_per_fold_individual_metrics(results_dict)

Saved per-class metrics per fold to per_class_metrics_per_fold.png
