In [1]:
import os
import pandas as pd
import numpy as np
import random
from tqdm import tqdm

In [2]:
# --- Augmentation settings -------------------------------------------------
target_per_class = 20      # augmented samples to generate for each action class
duration_range = (3, 10)   # (min, max) crop duration; divided by frame_interval to get a frame count
frame_interval = 0.15      # presumably seconds per frame, like sampling_interval below — TODO confirm
std = 0.05                 # standard deviation of the Gaussian noise added to the accel_* columns
is_trainset_only = True    # when True, only sample IDs listed in train.txt are used as sources

# --- Action label -> integer class index -----------------------------------
# Indices follow the order of the list below (0..15).
label2idx = {
    name: index
    for index, name in enumerate([
        'crawl', 'walk',
        'sit-floor', 'sit-high-chair', 'sit-low-chair', 'stand',
        'hold-horizontal', 'hold-vertical', 'piggyback',
        'baby-food', 'bottle', 'breast',
        'face-down', 'face-side', 'face-up', 'roll-over',
    ])
}

# --- Data locations --------------------------------------------------------
origin_dir = './data_origin/'   # root of the original dataset
aug_dir = './data_aug/'         # root under which augmented data is written
aug_method = 'GN'               # sub-directory name for this augmentation run

# NOTE(review): the two values below are not referenced by the visible cells.
sampling_interval = 0.15  # s/frame
tolerance = 0.15

def find_pair_dirs(base_dir: str, aug_method: str=None) -> tuple:
    """Return (and create if missing) the paired sequence/label directories.

    Both directories live under the same root so that a sequence file and its
    label file are always resolved relative to the same ``base_dir``:

    * without ``aug_method``: ``<base_dir>/sequence`` and ``<base_dir>/label``
    * with ``aug_method``: ``<base_dir>/<aug_method>/sequence`` and
      ``<base_dir>/<aug_method>/label``

    Args:
        base_dir: Root directory of the dataset.
        aug_method: Optional sub-directory name for an augmentation variant.
            ``None`` or an empty string selects the un-augmented layout.

    Returns:
        A ``(sequence_dir, label_dir)`` tuple of path strings. Both
        directories exist when the function returns.
    """
    # Derive both paths from base_dir only; the label path must not depend on
    # module-level globals, otherwise the pair can point at different roots.
    root_dir = os.path.join(base_dir, aug_method) if aug_method else base_dir
    sequence_dir = os.path.join(root_dir, "sequence")
    label_dir = os.path.join(root_dir, "label")
    os.makedirs(sequence_dir, exist_ok=True)
    os.makedirs(label_dir, exist_ok=True)
    return sequence_dir, label_dir

# Resolve (and create) the sequence/label directory pairs for the original
# dataset and for the augmented output of the selected method.
origin_sequence_dir, origin_label_dir = find_pair_dirs(base_dir=origin_dir)
aug_sequence_dir, aug_label_dir = find_pair_dirs(base_dir=aug_dir, aug_method=aug_method)

In [3]:
# Index the original dataset: for every known action, collect the
# (sequence CSV path, label CSV path) pairs that can serve as augmentation sources.
action_to_samples = {action: [] for action in label2idx}

# If enabled, keep only the sample IDs listed in train.txt.
valid_ids = set()
if is_trainset_only:
    train_txt_path = os.path.join(origin_dir, 'train.txt')
    with open(train_txt_path, 'r') as f:
        # Ignore blank lines so an empty string never counts as a valid ID.
        valid_ids = {line.strip() for line in f if line.strip()}

label_suffix = '_label.csv'

# os.listdir returns entries in arbitrary order; sorting keeps the per-action
# sample lists identical from run to run.
for label_file in sorted(os.listdir(origin_label_dir)):
    if not label_file.endswith(label_suffix):
        continue

    # Strip only the trailing suffix (str.replace would remove every occurrence).
    seq_id = label_file[:-len(label_suffix)]
    if is_trainset_only and seq_id not in valid_ids:
        continue

    label_path = os.path.join(origin_label_dir, label_file)
    seq_path = os.path.join(origin_sequence_dir, f"{seq_id}.csv")

    # A label without its sequence file cannot be augmented.
    if not os.path.exists(seq_path):
        continue

    label_df = pd.read_csv(label_path)
    if label_df.empty:
        # Header-only label file: there is no first row to read the action from.
        continue

    # The action is taken from the first row of the label file.
    action = label_df.iloc[0]['action']
    if action in action_to_samples:
        action_to_samples[action].append((seq_path, label_path))


In [4]:
# Generate `target_per_class` augmented samples per action: each one is a
# random crop of a source sequence with Gaussian noise added to the
# accelerometer columns, saved together with a trimmed copy of its label.

# Dedicated, seeded generators make the augmentation reproducible across
# "Restart & Run All" without touching the global `random` / `np.random` state.
AUG_SEED = 42
crop_rng = random.Random(AUG_SEED)
noise_rng = np.random.default_rng(AUG_SEED)

# Crop-length bounds in frames; loop-invariant, so computed once.
min_len = int(duration_range[0] / frame_interval)
max_len = int(duration_range[1] / frame_interval)

sample_id = 0
skipped_actions = []  # actions that have no source samples at all
for action, samples in tqdm(action_to_samples.items()):
    total_needed = target_per_class
    num_available = len(samples)
    if num_available == 0:
        skipped_actions.append(action)
        continue

    # Cycle through all sources while more are needed than are available,
    # then draw the remainder without replacement.
    while total_needed > 0:
        batch = samples if total_needed >= num_available else crop_rng.sample(samples, total_needed)
        for seq_path, label_path in batch:
            df = pd.read_csv(seq_path)

            # Randomly crop a segment; sequences no longer than the drawn
            # length are used whole.
            crop_len = crop_rng.randint(min_len, max_len)
            if len(df) <= crop_len:
                cropped = df.copy()
            else:
                start_idx = crop_rng.randint(0, len(df) - crop_len)
                cropped = df.iloc[start_idx:start_idx + crop_len].copy()

            # Add Gaussian noise (mean 0, standard deviation `std`) per axis.
            for col in ['accel_x', 'accel_y', 'accel_z']:
                cropped[col] += noise_rng.normal(0, std, size=len(cropped))

            # Keep 9 decimal places.
            cropped = cropped.round(9)

            # Save the augmented sequence.
            file_id = f"A{sample_id:05d}"
            cropped.to_csv(os.path.join(aug_sequence_dir, f"{file_id}.csv"), index=False)

            # Save the selected label fields.
            # NOTE(review): 'dur' is copied unchanged from the source label even
            # though the sequence was cropped — confirm whether it should be
            # recomputed for the cropped segment.
            label_df = pd.read_csv(label_path)
            label_df = label_df[['gender', 'age', 'dur', 'action']]
            label_df.to_csv(os.path.join(aug_label_dir, f"{file_id}_label.csv"), index=False)

            sample_id += 1
            total_needed -= 1

print(f"Done: {sample_id} augmented samples saved.")
if skipped_actions:
    # Make silently missing classes visible instead of only a lower total.
    print(f"Skipped actions with no source samples: {skipped_actions}")

  0%|          | 0/16 [00:00<?, ?it/s]

100%|██████████| 16/16 [00:02<00:00,  6.04it/s]

Done: 300 augmented samples saved.



