In [28]:
import argparse
import pickle
import sys
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from tqdm import tqdm

In [29]:
def parse_pose_filename(stem: str) -> Optional[Tuple[int, str, str]]:
    parts = stem.split('_')
    if len(parts) < 3:
        return None
    raw_subject = parts[0]
    digits = ''.join(ch for ch in raw_subject if ch.isdigit())
    if digits == '':
        return None
    try:
        subject_id = int(digits)
    except Exception:
        return None
    subject_name = parts[1]
    video_id = '_'.join(parts[2:])
    return subject_id, subject_name, video_id

In [30]:
def list_pose_files(pose_root: Path) -> List[Path]:
    if not pose_root.exists():
        logging.error(f'Pose root not found: {pose_root}')
        return []
    return sorted([p for p in pose_root.iterdir() if p.suffix == '.npy'])


def discover_subject_ids_from_pose_files(pose_files: List[Path]) -> List[int]:
    subject_ids = []
    for p in pose_files:
        parsed = parse_pose_filename(p.stem)
        if parsed is None:
            continue
        sid, _sname, _vid = parsed
        subject_ids.append(sid)
    unique = sorted(set(subject_ids))
    return unique

In [31]:
def read_ignored_list(path: Optional[Path]) -> set:
    if path is None or not path.exists():
        return set()
    ignored = set()
    with path.open('r', encoding='utf-8') as f:
        for line in f:
            token = line.strip()
            if token:
                ignored.add(token)
    return ignored

In [32]:
raw_root = Path(r'C:\AimCLR-v2\pose_new_v2')
ann_root = Path(r'C:\AimCLR-v2\Annotation_v4')
ignored = None
ignored_video_ids = read_ignored_list(ignored)
if ignored_video_ids:
    logging.info(f'Loaded ignored list with {len(ignored_video_ids)} entries')


In [33]:
pose_files = list_pose_files(raw_root)
print(pose_files)

[WindowsPath('C:/AimCLR-v2/pose_new_v2/s10_PhamQuangDai_1.npy'), WindowsPath('C:/AimCLR-v2/pose_new_v2/s10_PhamQuangDai_2.npy'), WindowsPath('C:/AimCLR-v2/pose_new_v2/s10_PhamQuangDai_3.npy'), WindowsPath('C:/AimCLR-v2/pose_new_v2/s11_NguyenTheThao_1.npy'), WindowsPath('C:/AimCLR-v2/pose_new_v2/s11_NguyenTheThao_2.npy'), WindowsPath('C:/AimCLR-v2/pose_new_v2/s11_NguyenTheThao_3.npy'), WindowsPath('C:/AimCLR-v2/pose_new_v2/s12_NguyenXuanHieu_1.npy'), WindowsPath('C:/AimCLR-v2/pose_new_v2/s12_NguyenXuanHieu_2.npy'), WindowsPath('C:/AimCLR-v2/pose_new_v2/s12_NguyenXuanHieu_3.npy'), WindowsPath('C:/AimCLR-v2/pose_new_v2/s13_TranNhatNam_1.npy'), WindowsPath('C:/AimCLR-v2/pose_new_v2/s13_TranNhatNam_2.npy'), WindowsPath('C:/AimCLR-v2/pose_new_v2/s13_TranNhatNam_3.npy'), WindowsPath('C:/AimCLR-v2/pose_new_v2/s14_Alexandre_1.npy'), WindowsPath('C:/AimCLR-v2/pose_new_v2/s14_Alexandre_2.npy'), WindowsPath('C:/AimCLR-v2/pose_new_v2/s14_Alexandre_3.npy'), WindowsPath('C:/AimCLR-v2/pose_new_v2/s15_

In [34]:
train_subjects = discover_subject_ids_from_pose_files(pose_files)
print(train_subjects)


[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52]


In [35]:
def build_annotation_path(ann_root: Path, subject_id: int, subject_name: str, video_id: str) -> Path:
    # find CSV inside per-subject subfolder
    base = ann_root / 'Annotation_v4'
    scan_root = base if base.exists() else ann_root
    candidates = [f's{subject_id}_{subject_name}', f'{subject_id}_{subject_name}']
    for folder_name in candidates:
        folder = scan_root / folder_name
        if not (folder.exists() and folder.is_dir()):
            continue
        # Try matching filenames
        possibles = [
            folder / f'{folder_name}_{video_id}.csv',
        ]
        if folder_name.startswith('s'):
            possibles.append(folder / f'{folder_name[1:]}_{video_id}.csv')
        for p in possibles:
            if p.exists():
                return p
        # Fallback: any file ending with _<video_id>.csv
        try:
            for f in folder.iterdir():
                if f.is_file() and f.suffix.lower() == '.csv' and f.name.endswith(f'_{video_id}.csv'):
                    return f
        except Exception:
            pass
    # Default expected path (for logging)
    return scan_root / f's{subject_id}_{subject_name}' / f's{subject_id}_{subject_name}_{video_id}.csv'

In [36]:
def validate_and_warn_overlaps_gaps(actions: pd.DataFrame) -> None:
    if actions.empty:
        return
    actions_sorted = actions.sort_values('start').reset_index(drop=True)
    for i in range(len(actions_sorted) - 1):
        cur_end = int(actions_sorted.loc[i, 'stop'])
        next_start = int(actions_sorted.loc[i + 1, 'start'])
        if next_start <= cur_end:
            logging.warning('Detected overlapping actions in CSV (start <= previous stop).')
        elif next_start > cur_end + 1:
            logging.warning('Detected gap between actions in CSV (start > previous stop + 1).')

In [37]:
def scan_dataset(
    raw_root: Path,
    ann_root: Path,
    ignored_video_ids: set,
    train_subjects: List[int],
) -> Tuple[List[Dict], Dict[str, int]]:
    # Support both when raw_root is the base directory containing pose_new_v2/ and when it is pose_new_v2 itself
    pose_root_candidate = raw_root / 'pose_new_v2'
    pose_root = pose_root_candidate if pose_root_candidate.exists() else raw_root
    logging.info(f'Scanning pose root: {pose_root}')
    pose_files = list_pose_files(pose_root)

    samples: List[Dict] = []
    action_histogram: Dict[str, int] = {}

    for pose_path in tqdm(pose_files, desc='Scanning pose files'):
        parsed = parse_pose_filename(pose_path.stem)
        if parsed is None:
            logging.warning(f'Skipping unrecognized filename: {pose_path.name}')
            continue
        subject_id, subject_name, video_id = parsed

        if video_id in ignored_video_ids:
            continue

        csv_path = build_annotation_path(ann_root, subject_id, subject_name, video_id)
        if not csv_path.exists():
            logging.warning(f'Missing annotation CSV for {pose_path.name} -> {csv_path}')
            continue

        try:
            # Lightweight shape read
            skel = np.load(str(pose_path), mmap_mode='r')
        except Exception as e:
            logging.error(f'Failed to load skeleton: {pose_path} ({e})')
            continue

        if skel.ndim != 3 or skel.shape[1] != 48 or skel.shape[2] != 3:
            logging.warning(f'Unexpected skeleton shape {skel.shape} in {pose_path.name}, skipping')
            continue

        total_frames = skel.shape[0]

        try:
            df = pd.read_csv(csv_path)
        except Exception as e:
            logging.error(f'Failed to read CSV: {csv_path} ({e})')
            continue

        # Normalize column names
        df_columns = {c.strip().lower(): c for c in df.columns}
        required = ['id', 'start', 'stop']
        if not all(col in df_columns for col in required):
            logging.warning(f'CSV columns missing among {required} in {csv_path.name}, columns={list(df.columns)}')
            # Try best-effort mapping
            # If columns like 'label' exist, we keep but we need id,start,stop to continue
            possible = [c for c in df.columns if c.lower() in {'id', 'start', 'stop'}]
            if len(possible) < 3:
                continue

        id_col = df_columns.get('id', 'ID')
        start_col = df_columns.get('start', 'start')
        stop_col = df_columns.get('stop', 'stop')
        label_col = df_columns.get('label')

        # Warnings for overlap/gap
        try:
            tmp_df = pd.DataFrame({
                'start': df[start_col].astype(int),
                'stop': df[stop_col].astype(int),
            })
            validate_and_warn_overlaps_gaps(tmp_df)
        except Exception:
            pass

        for _, row in df.iterrows():
            try:
                action_id = int(row[id_col])
                start = int(row[start_col])
                end = int(row[stop_col])
            except Exception:
                continue

            if start > end:
                logging.warning(f'Invalid segment start>end in {csv_path.name}: {start}>{end}')
                continue
            if start < 0 or end >= total_frames:
                logging.warning(
                    f'Segment out of range in {csv_path.name}: [0,{total_frames-1}] vs [{start},{end}]'
                )
                continue

            length = end - start + 1
            split = 'train' if subject_id in train_subjects else 'val'

            sample_name = f'{subject_id}_{video_id}_A{action_id}_S{start}_E{end}'
            action_name = None
            if label_col is not None and label_col in df.columns:
                try:
                    action_name = str(row[label_col]).strip()
                except Exception:
                    action_name = None

            samples.append({
                'pose_path': pose_path,
                'subject_id': subject_id,
                'subject_name': subject_name,
                'video_id': video_id,
                'csv_path': csv_path,
                'action_id': action_id,
                'action_name': action_name,
                'start': start,
                'end': end,
                'length': length,
                'split': split,
                'sample_name': sample_name,
            })

            action_histogram[str(action_id)] = action_histogram.get(str(action_id), 0) + 1

    return samples, action_histogram

In [38]:
samples, action_hist = scan_dataset(
    raw_root=raw_root,
    ann_root=ann_root,
    ignored_video_ids=ignored_video_ids,
    train_subjects=train_subjects,
    )



Scanning pose files: 100%|██████████| 156/156 [00:05<00:00, 30.68it/s]


In [39]:
print(action_hist)

{'1': 152, '5': 118, '3': 152, '11': 156, '17': 154, '6': 152, '15': 152, '2': 152, '18': 150, '14': 152, '0': 186, '7': 152, '8': 150, '9': 155, '10': 160, '16': 152, '13': 151, '4': 152, '12': 136}


In [40]:
def load_labels(path: Path) -> Optional[Tuple[List[str], List[int]]]:
    if not path.exists():
        return None
    try:
        with path.open('rb') as f:
            names, labels = pickle.load(f)
        return names, labels
    except Exception:
        return None

In [None]:
def open_memmap_writer(path: Path, shape: Tuple[int, int, int, int, int]) -> np.memmap:
    # Creates/overwrites using numpy.lib.format.open_memmap
    from numpy.lib.format import open_memmap as np_open_memmap

    path.parent.mkdir(parents=True, exist_ok=True)
    mem = np_open_memmap(filename=str(path), mode='w+', dtype='float32', shape=shape)
    return mem


In [None]:
def to_ntu_format(segment_xyz: np.ndarray) -> np.ndarray:
    # Input: (L, 48, 3) -> Output: (3, L, 48, 1)
    data = np.transpose(segment_xyz, (2, 0, 1)).astype(np.float32, copy=False)
    data = data[..., np.newaxis]
    return data

In [None]:
def center_crop_indices(length: int, target: int) -> np.ndarray:
    if length <= target:
        # Will be padded elsewhere
        return np.arange(length, dtype=np.int64)
    start_index = (length - target) // 2
    return np.arange(start_index, start_index + target, dtype=np.int64)

In [None]:
def uniform_sample_indices(length: int, target: int) -> np.ndarray:
    if length <= 0:
        return np.zeros((target,), dtype=np.int64)
    indices = np.linspace(0, max(0, length - 1), num=target)
    indices = np.clip(indices.round().astype(np.int64), 0, length - 1)
    return indices


In [41]:
def fit_to_length(data: np.ndarray, max_frame: int, policy: str) -> Tuple[np.ndarray, bool, bool]:
    # data: (3, L, 48, 1)
    length = data.shape[1]
    did_pad = False
    did_resample = False

    if length == max_frame:
        return data, did_pad, did_resample

    if length < max_frame:
        out = np.empty((3, max_frame, 48, 1), dtype=np.float32)
        out[:, :length] = data
        last = data[:, length - 1 : length]
        if max_frame > length:
            out[:, length:] = last
        did_pad = True
        return out, did_pad, did_resample

    # length > max_frame
    if policy == 'uniform-sample':
        indices = uniform_sample_indices(length, max_frame)
        out = data[:, indices]
        did_resample = True
        return out, did_pad, did_resample
    elif policy == 'center-crop':
        indices = center_crop_indices(length, max_frame)
        if indices.shape[0] == max_frame:
            out = data[:, indices]
        else:
            # length < max_frame path handled above; here we still may need padding (edge case)
            out = np.empty((3, max_frame, 48, 1), dtype=np.float32)
            out[:, : indices.shape[0]] = data[:, indices]
            last = out[:, indices.shape[0] - 1 : indices.shape[0]]
            out[:, indices.shape[0] :] = last
        did_resample = True
        return out, did_pad, did_resample
    elif policy == 'pad':
        # Truncate tail to max_frame
        out = data[:, :max_frame]
        did_resample = True
        return out, did_pad, did_resample
    else:
        raise ValueError(f'Unknown resample policy: {policy}')

In [None]:
def write_memmaps(
    samples: List[Dict],
    out_root: Path,
    max_frame: int,
    resample_policy: str,
    emit_clips: bool,
    clip_len: int,
) -> Dict[str, Dict[str, int]]:
    xsub_dir = out_root / 'xsub'
    xsub_dir.mkdir(parents=True, exist_ok=True)

    # Prepare splits
    train_samples = [s for s in samples if s['split'] == 'train']
    val_samples = [s for s in samples if s['split'] == 'val']

    # Resume/idempotency handling
    train_label_path = xsub_dir / 'train_label.pkl'
    val_label_path = xsub_dir / 'val_label.pkl'
    train_data_path = xsub_dir / 'train_data.npy'
    val_data_path = xsub_dir / 'val_data.npy'

    expected_train_names = [s['sample_name'] for s in train_samples]
    expected_train_labels = [int(s['action_id']) for s in train_samples]
    expected_val_names = [s['sample_name'] for s in val_samples]
    expected_val_labels = [int(s['action_id']) for s in val_samples]

    def memmap_shape_ok(path: Path, expected_n: int) -> bool:
        try:
            arr = np.load(str(path), mmap_mode='r')
            return (
                arr.shape == (expected_n, 3, max_frame, 48, 1)
                and arr.dtype == np.float32
            )
        except Exception:
            return False

    existing_train = load_labels(train_label_path)
    existing_val = load_labels(val_label_path)

    # Case 1: perfect match -> skip all writing
    if (
        existing_train is not None and existing_val is not None and
        existing_train[0] == expected_train_names and existing_train[1] == expected_train_labels and
        existing_val[0] == expected_val_names and existing_val[1] == expected_val_labels and
        train_data_path.exists() and val_data_path.exists() and
        memmap_shape_ok(train_data_path, len(train_samples)) and memmap_shape_ok(val_data_path, len(val_samples))
    ):
        logging.info('Outputs already exist with matching sample names and shapes; skipping write.')
        return {
            'train': {'count': len(train_samples)},
            'val': {'count': len(val_samples)},
            'padded': {'count': 0},
            'resampled': {'count': 0},
        }

    # Case 2: partial resume -> we will copy rows by name if possible
    old_train_map: Dict[str, int] = {}
    old_val_map: Dict[str, int] = {}
    old_train_mem = None
    old_val_mem = None

    if existing_train is not None and train_data_path.exists():
        try:
            old_train_names, _old_train_labels = existing_train
            old_train_mem = np.load(str(train_data_path), mmap_mode='r')
            if old_train_mem.shape[1:] == (3, max_frame, 48, 1):
                old_train_map = {name: i for i, name in enumerate(old_train_names)}
            else:
                old_train_mem = None
        except Exception:
            old_train_mem = None
            old_train_map = {}

    if existing_val is not None and val_data_path.exists():
        try:
            old_val_names, _old_val_labels = existing_val
            old_val_mem = np.load(str(val_data_path), mmap_mode='r')
            if old_val_mem.shape[1:] == (3, max_frame, 48, 1):
                old_val_map = {name: i for i, name in enumerate(old_val_names)}
            else:
                old_val_mem = None
        except Exception:
            old_val_mem = None
            old_val_map = {}

    # Create new memmaps
    train_mem = open_memmap_writer(train_data_path, (len(train_samples), 3, max_frame, 48, 1))
    val_mem = open_memmap_writer(val_data_path, (len(val_samples), 3, max_frame, 48, 1))

    # Optional clips
    if emit_clips:
        clips_dir = out_root / 'clips' / 'xsub'
        clips_dir.mkdir(parents=True, exist_ok=True)
        train_clips_path = clips_dir / 'train_data.npy'
        val_clips_path = clips_dir / 'val_data.npy'
        train_clips_mem = open_memmap_writer(train_clips_path, (len(train_samples), 3, clip_len, 48, 1))
        val_clips_mem = open_memmap_writer(val_clips_path, (len(val_samples), 3, clip_len, 48, 1))
    else:
        train_clips_mem = None
        val_clips_mem = None

    # For progress and stats
    padded_count = 0
    resampled_count = 0
    copied_count = 0

    # Group samples by pose file to minimize reloads
    by_pose: Dict[Path, List[Tuple[int, Dict]]] = {}
    for idx, s in enumerate(train_samples):
        by_pose.setdefault(s['pose_path'], []).append((('train', idx), s))
    for idx, s in enumerate(val_samples):
        by_pose.setdefault(s['pose_path'], []).append((('val', idx), s))

    for pose_path, entries in tqdm(by_pose.items(), desc='Writing samples'):
        # Determine if all entries can be copied from previous run
        all_copyable = True
        for (split_tag, index_in_split), s in entries:
            sname = s['sample_name']
            if split_tag == 'train':
                if old_train_mem is None or sname not in old_train_map:
                    all_copyable = False
                    break
            else:
                if old_val_mem is None or sname not in old_val_map:
                    all_copyable = False
                    break

        skel = None
        if not all_copyable:
            try:
                skel = np.load(str(pose_path))  # load fully for slicing speed
            except Exception as e:
                logging.error(f'Failed to load skeleton: {pose_path} ({e})')
                # Even if load fails, still try to copy any copyable ones
                skel = None

        for (split_tag, index_in_split), s in entries:
            sname = s['sample_name']
            # Copy path
            if split_tag == 'train' and old_train_mem is not None and sname in old_train_map:
                train_mem[index_in_split] = old_train_mem[old_train_map[sname]]
                copied_count += 1
                continue
            if split_tag == 'val' and old_val_mem is not None and sname in old_val_map:
                val_mem[index_in_split] = old_val_mem[old_val_map[sname]]
                copied_count += 1
                continue

            # Compute path (needs skel)
            if skel is None:
                # Cannot compute; skip
                logging.error(f'Cannot compute sample {sname} due to missing skeleton load.')
                continue

            start = s['start']
            end = s['end']
            segment = skel[start : end + 1]  # (L, 48, 3)
            ntu = to_ntu_format(segment)     # (3, L, 48, 1)
            fitted, did_pad, did_resample = fit_to_length(ntu, max_frame, resample_policy)
            if did_pad:
                padded_count += 1
            if did_resample:
                resampled_count += 1

            if split_tag == 'train':
                train_mem[index_in_split] = fitted
                if emit_clips and train_clips_mem is not None:
                    if fitted.shape[1] >= clip_len:
                        cc_idx = center_crop_indices(fitted.shape[1], clip_len)
                        train_clips_mem[index_in_split] = fitted[:, cc_idx]
                    else:
                        # Already padded to max_frame >= clip_len by design
                        cc_idx = center_crop_indices(fitted.shape[1], fitted.shape[1])
                        tmp = np.empty((3, clip_len, 48, 1), dtype=np.float32)
                        tmp[:, : cc_idx.shape[0]] = fitted[:, cc_idx]
                        last = tmp[:, cc_idx.shape[0] - 1 : cc_idx.shape[0]]
                        tmp[:, cc_idx.shape[0] :] = last
                        train_clips_mem[index_in_split] = tmp
            else:
                val_mem[index_in_split] = fitted
                if emit_clips and val_clips_mem is not None:
                    if fitted.shape[1] >= clip_len:
                        cc_idx = center_crop_indices(fitted.shape[1], clip_len)
                        val_clips_mem[index_in_split] = fitted[:, cc_idx]
                    else:
                        cc_idx = center_crop_indices(fitted.shape[1], fitted.shape[1])
                        tmp = np.empty((3, clip_len, 48, 1), dtype=np.float32)
                        tmp[:, : cc_idx.shape[0]] = fitted[:, cc_idx]
                        last = tmp[:, cc_idx.shape[0] - 1 : cc_idx.shape[0]]
                        tmp[:, cc_idx.shape[0] :] = last
                        val_clips_mem[index_in_split] = tmp

    # Flush memmaps
    del train_mem
    del val_mem
    if emit_clips and train_clips_mem is not None and val_clips_mem is not None:
        del train_clips_mem
        del val_clips_mem

    # Write labels at the end to reflect final expected sets
    write_labels(train_label_path, expected_train_names, expected_train_labels)
    write_labels(val_label_path, expected_val_names, expected_val_labels)

    return {
        'train': {'count': len(train_samples)},
        'val': {'count': len(val_samples)},
        'padded': {'count': padded_count},
        'resampled': {'count': resampled_count},
        'copied': {'count': copied_count},
    }