imuonly-dataprocess

In [None]:
# ===================================================================
#  数据处理脚本
#  CMI Behavior Detection - 精简特征版
#
#  排除已知问题序列: SEQ_011975 (存在测量问题)
#
#  训练特征: rot, linear_acc, linear_acc_mag, angular_vel, angular_distance, acc_mag
# ===================================================================

import os
import json
import joblib
import numpy as np
import pandas as pd
from scipy.spatial.transform import Rotation as R
from tqdm import tqdm

# Suppress noisy warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# ===================================================================
# 1. Configuration & Global Settings
# ===================================================================

# Location of the raw competition files and of the artifacts written by this script.
DATA_ROOT = '/kaggle/input/cmi-detect-behavior-with-sensor-data'
OUTPUT_DIR = './processed_data_selected_features_v1'
# exist_ok=True makes the creation idempotent and removes the
# check-then-create race of `if not os.path.exists(...)`.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Feature columns used for training, in the order they are stacked per frame.
SELECTED_FEATURES = (
    ['rot_w', 'rot_x', 'rot_y', 'rot_z']                    # orientation quaternion
    + ['linear_acc_x', 'linear_acc_y', 'linear_acc_z']      # gravity-compensated acceleration
    + ['linear_acc_mag']                                    # magnitude of the linear acceleration
    + ['angular_vel_x', 'angular_vel_y', 'angular_vel_z']   # angular velocity
    + ['angular_distance']                                  # angular distance between frames
    + ['acc_mag']                                           # magnitude of the raw acceleration
)

# Gesture labels; the position in this list is the integer class id.
LABEL_NAMES = [
    'Forehead - pull hairline',
    'Neck - pinch skin',
    'Forehead - scratch',
    'Eyelash - pull hair',
    'Text on phone',
    'Eyebrow - pull hair',
    'Neck - scratch',
    'Above ear - pull hair',
    'Cheek - pinch skin',
    'Wave hello',
    'Write name in air',
    'Pull air toward your face',
    'Feel around in tray and pull out an object',
    'Write name on leg',
    'Pinch knee/leg skin',
    'Scratch knee/leg skin',
    'Drink from bottle/cup',
    'Glasses on/off',
]
LABEL2IDX = dict(zip(LABEL_NAMES, range(len(LABEL_NAMES))))

# ===================================================================
# 2. 核心特征工程函数
# ===================================================================

def remove_gravity_from_acc(acc_data, rot_data):
    """Subtract gravity from raw accelerometer readings to get linear acceleration.

    The world-frame gravity vector ``[0, 0, 9.81]`` is rotated into the sensor
    frame with the inverse of each row's orientation quaternion and subtracted
    from the measured acceleration.

    Args:
        acc_data: DataFrame with columns ``acc_x``, ``acc_y``, ``acc_z``.
        rot_data: DataFrame with columns ``rot_x``, ``rot_y``, ``rot_z``,
            ``rot_w`` (scalar-last, the order scipy's ``Rotation.from_quat``
            expects), row-aligned with ``acc_data``.

    Returns:
        ``(n, 3)`` array with the same dtype as the acceleration values. Rows
        whose quaternion is entirely NaN or entirely (close to) zero keep the
        raw acceleration unchanged.
    """
    acc_values = acc_data[['acc_x', 'acc_y', 'acc_z']].values
    quat_values = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
    gravity_world = np.array([0, 0, 9.81])

    # Start from the raw acceleration; only rows with a usable quaternion are
    # overwritten below.
    linear_accel = np.zeros_like(acc_values)
    linear_accel[...] = acc_values
    if len(acc_values) == 0:
        return linear_accel

    unusable = (
        np.all(np.isnan(quat_values), axis=1)
        | np.all(np.isclose(quat_values, 0), axis=1)
    )
    usable = ~unusable
    if not usable.any():
        return linear_accel

    usable_quats = quat_values[usable]
    try:
        # One batched rotation instead of one Rotation object per frame.
        gravity_sensor_frame = R.from_quat(usable_quats).apply(gravity_world, inverse=True)
    except ValueError:
        # Best-effort fallback: handle rows one at a time so a single bad
        # quaternion only affects its own row (which keeps the raw value).
        gravity_sensor_frame = np.zeros((len(usable_quats), 3))
        for k, quat in enumerate(usable_quats):
            try:
                gravity_sensor_frame[k] = R.from_quat(quat).apply(gravity_world, inverse=True)
            except ValueError:
                pass

    linear_accel[usable] = acc_values[usable] - gravity_sensor_frame
    return linear_accel

def calculate_angular_velocity_from_quat(rot_data, time_delta=1/200):
    """Estimate angular velocity from consecutive orientation quaternions.

    For every frame ``i`` the relative rotation ``inv(q_i) * q_{i+1}`` is
    converted to a rotation vector and divided by ``time_delta``.

    Args:
        rot_data: DataFrame with columns ``rot_x``, ``rot_y``, ``rot_z``,
            ``rot_w`` (scalar-last).
        time_delta: Time between consecutive frames, used as the divisor.

    Returns:
        ``(n, 3)`` float array. The last row is always zero (it has no
        successor), as is any row where either quaternion of the pair is
        entirely NaN or entirely (close to) zero.
    """
    quat_values = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
    num_frames = len(quat_values)
    angular_vel = np.zeros((num_frames, 3))
    if num_frames < 2:
        return angular_vel

    unusable = (
        np.all(np.isnan(quat_values), axis=1)
        | np.all(np.isclose(quat_values, 0), axis=1)
    )
    # Frame i gets a value only if both q_i and q_{i+1} are usable.
    pair_start = np.flatnonzero(~(unusable[:-1] | unusable[1:]))
    if pair_start.size == 0:
        return angular_vel

    try:
        # Batched over all valid pairs instead of two Rotation objects per frame.
        rot_t = R.from_quat(quat_values[pair_start])
        rot_t_plus_dt = R.from_quat(quat_values[pair_start + 1])
        angular_vel[pair_start] = (rot_t.inv() * rot_t_plus_dt).as_rotvec() / time_delta
    except ValueError:
        # Best-effort fallback: evaluate pairs individually so one bad pair
        # only leaves its own row at zero.
        for i in pair_start:
            try:
                delta_rot = R.from_quat(quat_values[i]).inv() * R.from_quat(quat_values[i + 1])
                angular_vel[i, :] = delta_rot.as_rotvec() / time_delta
            except ValueError:
                pass
    return angular_vel

def compute_angular_distance_xyzw(rot_df):
    """Return the angular distance between each frame and the previous one.

    Quaternions are read in (x, y, z, w) column order as float32, normalised
    where their norm exceeds 1e-8, and the angle is ``2 * arccos(|q_prev . q|)``.
    The first entry is always 0.
    """
    quats = rot_df[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values.astype(np.float32)
    count = len(quats)
    distances = np.zeros(count, dtype=np.float32)
    if count > 1:
        # Normalise only rows with a non-negligible norm; the others are left as-is.
        lengths = np.linalg.norm(quats, axis=1, keepdims=True)
        nonzero = lengths[:, 0] > 1e-8
        quats[nonzero] = quats[nonzero] / lengths[nonzero]
        # |dot| makes q and -q (the same orientation) count as identical.
        cos_half_angle = np.abs(np.sum(quats[:-1] * quats[1:], axis=1))
        cos_half_angle = np.clip(cos_half_angle, -1.0, 1.0)
        distances[1:] = 2.0 * np.arccos(cos_half_angle).astype(np.float32)
    return distances

def feature_engineering(df):
    """Add the engineered IMU features listed in SELECTED_FEATURES to ``df``.

    Computed per ``sequence_id``:
      * ``acc_mag``: norm of the raw acceleration,
      * ``linear_acc_{x,y,z}`` and ``linear_acc_mag``: gravity-compensated
        acceleration and its norm,
      * ``angular_vel_{x,y,z}``: angular velocity derived from the quaternions,
      * ``angular_distance``: angle between consecutive orientations.

    Missing values in the selected features are then forward- and
    backward-filled within each sequence, remaining gaps are set to 0.0 and
    the columns are cast to float32.

    Args:
        df: Frame-level DataFrame with ``sequence_id``, ``acc_x/y/z`` and
            ``rot_x/y/z/w`` columns. Note that ``acc_mag`` is written into the
            passed frame before it is re-bound by the joins.

    Returns:
        DataFrame with the original columns plus the engineered features.
    """
    print("开始特征工程...")

    acc_cols = ['acc_x', 'acc_y', 'acc_z']
    rot_cols = ['rot_x', 'rot_y', 'rot_z', 'rot_w']

    # 1. Magnitude of the raw acceleration.
    df['acc_mag'] = np.sqrt(df['acc_x']**2 + df['acc_y']**2 + df['acc_z']**2)

    # 2. Linear acceleration (gravity removed).
    print("计算线性加速度...")
    tqdm.pandas(desc="去除重力")
    linear_accel_df = df.groupby('sequence_id', group_keys=False).progress_apply(
        lambda g: pd.DataFrame(
            remove_gravity_from_acc(g[acc_cols], g[rot_cols]),
            columns=['linear_acc_x', 'linear_acc_y', 'linear_acc_z'],
            index=g.index
        )
    )
    df = df.join(linear_accel_df)

    # Magnitude of the linear acceleration.
    df['linear_acc_mag'] = np.sqrt(
        df['linear_acc_x']**2 + df['linear_acc_y']**2 + df['linear_acc_z']**2
    )

    # 3. Angular velocity.
    print("计算角速度...")
    tqdm.pandas(desc="计算角速度")
    angular_velocity_df = df.groupby('sequence_id', group_keys=False).progress_apply(
        lambda g: pd.DataFrame(
            calculate_angular_velocity_from_quat(g[rot_cols]),
            columns=['angular_vel_x', 'angular_vel_y', 'angular_vel_z'],
            index=g.index
        )
    )
    df = df.join(angular_velocity_df)

    # 4. Angular distance.
    # Return a DataFrame per group (like the two features above) rather than a
    # Series: when every group's Series shares one index -- which is always
    # the case for a single-sequence input -- pandas stacks them into a wide
    # DataFrame and the former `.to_frame()` call fails.
    print("计算角距离...")
    tqdm.pandas(desc="计算角距离")
    angdist_df = df.groupby('sequence_id', group_keys=False).progress_apply(
        lambda g: pd.DataFrame(
            {'angular_distance': compute_angular_distance_xyzw(g[rot_cols])},
            index=g.index
        )
    )
    df = df.join(angdist_df)

    # Fill missing values within each sequence. The built-in group-wise
    # ffill/bfill keep the original row index, so no index surgery is needed.
    print("填充缺失值...")
    filled = df.groupby('sequence_id')[SELECTED_FEATURES].ffill()
    filled = filled.groupby(df['sequence_id']).bfill()
    df[SELECTED_FEATURES] = filled.fillna(0.0).astype('float32')

    print(f"特征工程完成。特征数: {len(SELECTED_FEATURES)}")
    return df

# ===================================================================
# 3. 主执行流程
# ===================================================================

# Sequence ids removed before feature engineering (the file header notes
# that SEQ_011975 has measurement problems).
PROBLEMATIC_SEQUENCES = ['SEQ_011975']

# Script entry point: load the raw training CSVs, engineer features, collapse
# every sequence into one (frames x features) array and write the artifacts
# (joblib data, feature_names.json, label_map.json) into OUTPUT_DIR.
if __name__ == '__main__':
    print("加载原始数据...")
    train_df = pd.read_csv(f'{DATA_ROOT}/train.csv')
    train_demo_df = pd.read_csv(f'{DATA_ROOT}/train_demographics.csv')
    # Left join keeps every sensor row even if a subject has no demographics row.
    train_df = pd.merge(train_df, train_demo_df, how='left', on='subject')

    # Drop the known-bad sequences.
    print(f"排除问题序列: {PROBLEMATIC_SEQUENCES}")
    original_count = len(train_df)
    train_df = train_df[~train_df['sequence_id'].isin(PROBLEMATIC_SEQUENCES)].copy()
    excluded_count = original_count - len(train_df)
    print(f"已排除 {excluded_count} 条数据记录")

    # Run feature engineering.
    train_df = feature_engineering(train_df)

    print("聚合数据为序列...")
    # One row per (sequence_id, subject, gesture); the 'sequence' cell holds the
    # 2-D array of SELECTED_FEATURES values for that group's rows.
    agg_train_df = train_df.groupby(['sequence_id', 'subject', 'gesture']).apply(
        lambda df: df[SELECTED_FEATURES].values,
        include_groups=False,
    ).reset_index()
    agg_train_df.columns = ['sequence_id', 'subject', 'gesture', 'sequence']
    # Gestures missing from LABEL2IDX map to NaN here.
    agg_train_df['label'] = agg_train_df.gesture.map(LABEL2IDX)

    # Save the processed data.
    output_path = f'{OUTPUT_DIR}/processed_train_data_raw.joblib'
    print(f"保存聚合数据到 {output_path}")
    joblib.dump(agg_train_df, output_path)

    # Save the feature configuration.
    print(f"保存特征配置...")
    feature_info = {
        'all_features': SELECTED_FEATURES,
        'time_features': SELECTED_FEATURES,
        'psd_features': [],
        'stat_features': [],
        'feature_count': len(SELECTED_FEATURES)
    }
    with open(f'{OUTPUT_DIR}/feature_names.json', 'w') as f:
        json.dump(feature_info, f, indent=2)

    with open(f'{OUTPUT_DIR}/label_map.json', 'w') as f:
        json.dump(LABEL2IDX, f, indent=2)

    print("\n✅ 数据预处理完成!")
    print(f"特征数: {len(SELECTED_FEATURES)}")
    print("\n特征列表:")
    for i, feat in enumerate(SELECTED_FEATURES, 1):
        print(f"  {i:2d}. {feat}")

# ===================================================================
# 4. 测试集预处理函数
# ===================================================================

def preprocess_test_data(test_csv_path, demographics_csv_path=None, output_dir=OUTPUT_DIR):
    """Run the training-time preprocessing on a test CSV and save the result.

    Args:
        test_csv_path: Path of the frame-level test CSV.
        demographics_csv_path: Optional demographics CSV, left-joined on
            ``subject`` when given.
        output_dir: Directory receiving ``processed_test_data_raw.joblib``;
            created if it does not exist yet.

    Returns:
        DataFrame with columns ``sequence_id``, ``subject`` and ``sequence``,
        where ``sequence`` holds the (frames x features) array of
        SELECTED_FEATURES for that sequence.
    """
    print("处理测试数据...")

    test_df = pd.read_csv(test_csv_path)
    if demographics_csv_path:
        test_demo_df = pd.read_csv(demographics_csv_path)
        test_df = pd.merge(test_df, test_demo_df, how='left', on='subject')

    # Drop the known-bad sequences, mirroring the training pipeline.
    if 'sequence_id' in test_df.columns:
        test_df = test_df[~test_df['sequence_id'].isin(PROBLEMATIC_SEQUENCES)].copy()

    # Feature engineering.
    test_df = feature_engineering(test_df)

    # Collapse every sequence into one array of feature rows.
    agg_test_df = test_df.groupby(['sequence_id', 'subject']).apply(
        lambda df: df[SELECTED_FEATURES].values,
        include_groups=False,
    ).reset_index()
    agg_test_df.columns = ['sequence_id', 'subject', 'sequence']

    # Save. Only the module-level OUTPUT_DIR is created at import time, so a
    # caller-supplied directory has to be created here before writing.
    os.makedirs(output_dir, exist_ok=True)
    test_output_path = f'{output_dir}/processed_test_data_raw.joblib'
    print(f"保存测试数据到 {test_output_path}")
    joblib.dump(agg_test_df, test_output_path)

    print("✅ 测试数据预处理完成!")
    return agg_test_df

training-model

In [None]:
# ===================================================================
#
# 完整的 TensorFlow/Keras 训练代码 - CMI 行为检测
# se1dcnn+attention模型架构 - 四元数安全版本（简化版：仅修复零四元数）
# cv = 0.8094, lb = 0.812
# 过拟合已解决
#
# ===================================================================

import os
import sys
import json
import joblib
import random

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_probability as tfp
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, ConfusionMatrixDisplay, classification_report
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings("ignore", category=UserWarning, message=".*Your input ran out of data.*")

# 抑制 TensorFlow 的一些日志输出
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# ===================================================================
# 1. 配置与全局设置
# ===================================================================
DEBUG = False            # when True, only the first 2000 sequences are loaded
TRAIN = True
MAX_SEQ_LENGTH = 128     # sequences are truncated / zero-padded to this many frames
PROCESSED_DATA_DIR = '/kaggle/input/imuonly-process/kaggle/working/processed_data_selected_features_v1' # replace with your actual path
OUTPUT_DIR = './saved_models_keras_fixed_test'
# exist_ok=True makes the creation idempotent and removes the
# check-then-create race of `if not os.path.exists(...)`.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Load the label mapping written by the preprocessing stage.
with open(f'{PROCESSED_DATA_DIR}/label_map.json', 'r') as f:
    LABEL2IDX = json.load(f)
N_CLASSES = len(LABEL2IDX)
IDX2LABEL = {idx: name for name, idx in LABEL2IDX.items()}

# Load the feature configuration and expose each feature group by name.
with open(f'{PROCESSED_DATA_DIR}/feature_names.json', 'r') as f:
    feature_info = json.load(f)
(
    ALL_FEATURE_NAMES,
    TIME_FEATURE_NAMES,
    PSD_FEATURE_NAMES,
    STAT_FEATURE_NAMES,
) = (
    feature_info['all_features'],
    feature_info['time_features'],
    feature_info['psd_features'],
    feature_info['stat_features'],
)

# ===================================================================
# 特征选择配置 - 注意前4个必须是四元数
# ===================================================================
# Features used by the model. The first four entries must be the quaternion
# because downstream code slices columns [:4] as the quaternion.
SELECTED_FEATURES = (
    ['rot_w', 'rot_x', 'rot_y', 'rot_z']                    # orientation quaternion
    + ['linear_acc_x', 'linear_acc_y', 'linear_acc_z']      # linear acceleration
    + ['linear_acc_mag']                                    # magnitude of the linear acceleration
    + ['angular_vel_x', 'angular_vel_y', 'angular_vel_z']   # angular velocity
    + ['angular_distance']                                  # angular distance
    + ['acc_mag']                                           # magnitude of the raw acceleration
)

# Fall back to every stored feature when no explicit selection is given.
FEATURE_NAMES = ALL_FEATURE_NAMES if SELECTED_FEATURES is None else SELECTED_FEATURES

# Column positions of the chosen features inside the stored sequences.
FEATURE_INDICES = list(map(ALL_FEATURE_NAMES.index, FEATURE_NAMES))

print(f"使用的特征数量: {len(FEATURE_NAMES)}")
print(f"使用的特征: {FEATURE_NAMES}")
print(f"前4个特征是四元数: {FEATURE_NAMES[:4]}")

tf.keras.utils.set_random_seed(42)
np.random.seed(42)

# ===================================================================
# 2. 加载预处理好的数据（原始数据，未标准化）
# ===================================================================
# Load the aggregated (one row per sequence) training frame produced by the
# preprocessing stage; features are still unscaled at this point.
print("\n加载预处理好的数据（原始特征）...")
agg_train_df = joblib.load(f'{PROCESSED_DATA_DIR}/processed_train_data_raw.joblib')

# Debug mode: keep only the first 2000 sequences.
if DEBUG:
    agg_train_df = agg_train_df.head(2000)

sequences_full = agg_train_df['sequence'].tolist()
# Keep only the selected feature columns, in FEATURE_NAMES order. Indexing with
# a list is NumPy advanced indexing, so each entry is a copy rather than a view
# of the stored array.
sequences = [seq[:, FEATURE_INDICES] for seq in sequences_full]

# ===================================================================
# 修复零四元数（缺失值填充导致的问题）- 保留
# ===================================================================
def fix_zero_quaternions(sequences):
    """Replace zero quaternions with the identity quaternion, in place.

    The first four columns of every sequence are treated as a quaternion in
    (w, x, y, z) order. Frames whose quaternion norm is below 1e-8 are
    overwritten with ``[1, 0, 0, 0]``. Summary statistics are printed.

    Args:
        sequences: Iterable of 2-D arrays (frames x features); the arrays are
            modified in place.

    Returns:
        The same ``sequences`` object that was passed in.
    """
    print("\n🔧 检查并修复零四元数...")
    total_frames = 0
    zero_frames = 0
    problematic_sequences = []

    for seq_idx, seq in enumerate(sequences):
        # Norm of the quaternion part of every frame.
        quat_norms = np.linalg.norm(seq[:, :4], axis=1)
        total_frames += len(quat_norms)

        # Zero quaternions are the ones with a (near-)zero norm.
        zero_mask = quat_norms < 1e-8
        num_zeros = np.sum(zero_mask)

        if num_zeros > 0:
            # Repair: overwrite with the identity quaternion [1,0,0,0].
            seq[zero_mask, :4] = [1.0, 0.0, 0.0, 0.0]
            zero_frames += num_zeros
            problematic_sequences.append(seq_idx)

    # Report what was repaired.
    if zero_frames > 0:
        print(f"   ✅ 修复了 {zero_frames}/{total_frames} 个零四元数帧 "
              f"({100*zero_frames/total_frames:.2f}%)")
        print(f"   📊 涉及 {len(problematic_sequences)} 个序列")
    else:
        print(f"   ✅ 未发现零四元数，数据质量良好")

    # With no frames at all there is nothing to validate, and np.vstack /
    # np.min below would raise on empty input.
    if total_frames == 0:
        return sequences

    # Verify the norms after the repair.
    print("\n   验证修复后的四元数范数:")
    all_quats = np.vstack([seq[:, :4] for seq in sequences])
    fixed_norms = np.linalg.norm(all_quats, axis=1)
    print(f"     均值: {np.mean(fixed_norms):.6f}")
    print(f"     标准差: {np.std(fixed_norms):.6f}")
    print(f"     范围: [{np.min(fixed_norms):.6f}, {np.max(fixed_norms):.6f}]")

    # Warn if anything is still broken.
    remaining_zeros = np.sum(fixed_norms < 1e-8)
    if remaining_zeros > 0:
        print(f"   ⚠️ 警告: 仍有 {remaining_zeros} 个零四元数未修复")

    return sequences

# Apply the zero-quaternion repair to the loaded sequences.
sequences = fix_zero_quaternions(sequences)

# Per-sequence class ids and subject ids; both are passed to
# run_kfold_training as `labels` and `groups`.
labels = agg_train_df['label'].values
groups = agg_train_df['subject'].values
print(f"\n成功加载并处理 {len(sequences)} 条序列数据。")

# ===================================================================
# 3. 数据标准化函数（简化版 - 不对四元数进行归一化）
# ===================================================================
def standardize_sequences_with_quaternion(sequences_train, sequences_val, fit_on_train=True):
    """Standardize every non-quaternion feature; pass quaternions through.

    Columns [:4] of each sequence (the quaternion) are copied unchanged,
    columns [4:] are scaled with a StandardScaler fitted on all frames of the
    training sequences (``fit_on_train=True``) or of training and validation
    sequences together (``fit_on_train=False``).

    Returns:
        ``(scaled_train_sequences, scaled_val_sequences, scaler)``.
    """
    def _reassemble(originals, scaled_flat):
        # Cut the flat scaled matrix back into per-sequence pieces and put the
        # untouched quaternion columns in front of each piece.
        rebuilt = []
        offset = 0
        for original in originals:
            stop = offset + len(original)
            rebuilt.append(
                np.concatenate([original[:, :4], scaled_flat[offset:stop]], axis=1)
            )
            offset = stop
        return rebuilt

    # Stack the non-quaternion columns of all frames into one matrix per split.
    train_flat = np.vstack([seq[:, 4:] for seq in sequences_train])
    val_flat = np.vstack([seq[:, 4:] for seq in sequences_val])

    scaler = StandardScaler()
    if fit_on_train:
        train_scaled_flat = scaler.fit_transform(train_flat)
    else:
        scaler.fit(np.vstack([train_flat, val_flat]))
        train_scaled_flat = scaler.transform(train_flat)
    val_scaled_flat = scaler.transform(val_flat)

    return (
        _reassemble(sequences_train, train_scaled_flat),
        _reassemble(sequences_val, val_scaled_flat),
        scaler,
    )

# ===================================================================
# 4. 四元数工具函数
# ===================================================================
def normalize_quaternion(quat):
    """Scale ``quat`` to unit norm along the last axis.

    The divisor is floored at 1e-8, so an all-zero quaternion stays zero
    instead of producing NaN.
    """
    safe_norm = tf.maximum(tf.norm(quat, axis=-1, keepdims=True), 1e-8)
    return quat / safe_norm

def quaternion_slerp(q1, q2, t):
    """Spherically interpolate from ``q1`` (t=0) to ``q2`` (t=1).

    ``q2`` is negated where the dot product is negative so that the shorter
    arc is taken. Where ``sin(theta)`` is at most 1e-4 the weights fall back
    to linear interpolation. The result is re-normalised.
    """
    signed_dot = tf.reduce_sum(q1 * q2, axis=-1, keepdims=True)
    q2 = tf.where(signed_dot < 0, -q2, q2)
    cos_theta = tf.clip_by_value(tf.abs(signed_dot), -1.0, 1.0)
    theta = tf.acos(cos_theta)
    sin_theta = tf.sin(theta)
    well_conditioned = sin_theta > 1e-4
    weight_q1 = tf.where(well_conditioned, tf.sin((1.0 - t) * theta) / sin_theta, 1.0 - t)
    weight_q2 = tf.where(well_conditioned, tf.sin(t * theta) / sin_theta, t)
    return normalize_quaternion(weight_q1 * q1 + weight_q2 * q2)

# ===================================================================
# 5. 四元数安全的数据增强（保留增强后的归一化）
# ===================================================================
def safe_tf_time_stretch(sequence, stretch_range=(0.8, 1.2)):
    """Randomly resample a (frames x features) sequence along the time axis.

    A stretch factor is drawn uniformly from ``stretch_range`` and the sequence
    is resized to ``int(len / factor)`` frames with bilinear interpolation.
    The quaternion columns [:4] are re-normalised after interpolation; the
    remaining columns are interpolated as-is.

    Returns:
        The resampled sequence with the same number of feature columns.
    """
    seq_len_float = tf.cast(tf.shape(sequence)[0], tf.float32)
    stretch_factor = tf.random.uniform(shape=(), minval=stretch_range[0], maxval=stretch_range[1])
    # NOTE(review): the cast truncates, so a very short sequence could yield
    # new_len == 0 -- presumably never the case for the real data; confirm.
    new_len = tf.cast(seq_len_float / stretch_factor, tf.int32)
    quat_features = sequence[:, :4]
    other_features = sequence[:, 4:]
    # tf.image.resize works on [batch, height, width, channels]; time is mapped
    # to height and the features to channels.
    quat_reshaped = tf.reshape(quat_features, [1, tf.shape(quat_features)[0], 1, 4])
    quat_stretched = tf.image.resize(quat_reshaped, [new_len, 1], method=tf.image.ResizeMethod.BILINEAR)
    quat_stretched = tf.reshape(quat_stretched, [new_len, 4])
    quat_stretched = normalize_quaternion(quat_stretched)  # keep: interpolated quaternions must be re-normalised
    if tf.shape(other_features)[1] > 0:
        other_reshaped = tf.reshape(other_features, [1, tf.shape(other_features)[0], 1, tf.shape(other_features)[1]])
        other_stretched = tf.image.resize(other_reshaped, [new_len, 1], method=tf.image.ResizeMethod.BILINEAR)
        other_stretched = tf.reshape(other_stretched, [new_len, tf.shape(other_features)[1]])
        stretched_sequence = tf.concat([quat_stretched, other_stretched], axis=1)
    else:
        # Quaternion-only input: nothing else to interpolate.
        stretched_sequence = quat_stretched
    return stretched_sequence

def safe_tf_augment(sequence, label, aug_prob=0.5):
    """Apply random, quaternion-safe augmentations to one training sequence.

    Each of the following is applied independently with probability
    ``aug_prob``:
      1. time stretch (``safe_tf_time_stretch``),
      2. circular shift along time by up to 10% of the length,
      3. Gaussian noise (quaternion re-normalised afterwards),
      4. zero-masking of a contiguous 15% time window of the non-quaternion
         features.

    Args:
        sequence: (frames x features) tensor whose columns [:4] are the
            quaternion.
        label: Passed through unchanged.
        aug_prob: Probability of each individual augmentation.

    Returns:
        ``(augmented_sequence, label)``.
    """
    # 1. Time stretch. The magnitude column is interpolated together with the
    # other features and deliberately NOT recomputed as norm(x, y, z): this
    # function runs on already standardized features, where that norm is on a
    # different scale than the standardized magnitude column and would make
    # augmented samples inconsistent with non-augmented / validation data.
    if tf.random.uniform(()) < aug_prob:
        sequence = safe_tf_time_stretch(sequence)
    # 2. Circular shift along the time axis.
    if tf.random.uniform(()) < aug_prob:
        seq_len = tf.shape(sequence)[0]
        max_shift = tf.cast(tf.cast(seq_len, tf.float32) * 0.1, tf.int32)
        # maxval is exclusive for integer sampling; +1 keeps the range
        # non-empty when max_shift is 0 and makes it symmetric.
        shift = tf.random.uniform(shape=(), minval=-max_shift, maxval=max_shift + 1, dtype=tf.int32)
        sequence = tf.roll(sequence, shift=shift, axis=0)
    # 3. Additive Gaussian noise.
    if tf.random.uniform(()) < aug_prob:
        quat_features = sequence[:, :4]
        other_features = sequence[:, 4:]
        quat_noise = tf.random.normal(shape=tf.shape(quat_features), stddev=0.015)
        quat_features = normalize_quaternion(quat_features + quat_noise)  # keep: re-normalise after adding noise
        if tf.shape(other_features)[1] > 0:
            other_noise = tf.random.normal(shape=tf.shape(other_features), stddev=0.03)
            other_features = other_features + other_noise
            sequence = tf.concat([quat_features, other_features], axis=1)
        else:
            sequence = quat_features
    # (A random amplitude-scaling augmentation used to sit here and is disabled.)
    # 4. Zero out a contiguous time window of the non-quaternion features.
    if tf.random.uniform(()) < aug_prob:
        seq_len = tf.shape(sequence)[0]
        mask_ratio = 0.15
        mask_length = tf.cast(tf.cast(seq_len, tf.float32) * mask_ratio, tf.int32)
        if mask_length > 0:
            start_idx = tf.random.uniform(shape=(), maxval=tf.maximum(1, seq_len - mask_length), dtype=tf.int32)
            quat_features = sequence[:, :4]
            other_features = sequence[:, 4:]
            if tf.shape(other_features)[1] > 0:
                mask = tf.concat([
                    tf.ones([start_idx, tf.shape(other_features)[1]]),
                    tf.zeros([mask_length, tf.shape(other_features)[1]]),
                    tf.ones([seq_len - start_idx - mask_length, tf.shape(other_features)[1]])
                ], axis=0)
                other_features = other_features * mask
                sequence = tf.concat([quat_features, other_features], axis=1)
    return sequence, label

# ===================================================================
# 6. TensorFlow 数据管道（四元数安全版本）
# ===================================================================
def create_tf_dataset(X, y, batch_size, is_training=True, use_augmentation=True, use_mixup=False, mixup_alpha=0.2):
    """Build a batched tf.data pipeline over variable-length sequences.

    Args:
        X: Sequence of (frames x len(FEATURE_NAMES)) arrays.
        y: Integer class ids aligned with ``X``.
        batch_size: Batch size.
        is_training: Training pipelines shuffle, repeat indefinitely, optionally
            augment / mix up, and drop the last incomplete batch.
        use_augmentation: Apply ``safe_tf_augment`` (training only).
        use_mixup: Mix every batch with a shuffled partner batch (training only).
        mixup_alpha: Both parameters of the Beta distribution the mixup weight
            is sampled from.

    Returns:
        Dataset yielding ``(sequences, one_hot_labels)``; sequences are cut to
        the first MAX_SEQ_LENGTH frames and zero-padded up to that length.
    """
    dataset = tf.data.Dataset.from_generator(
        lambda: ((seq, label) for seq, label in zip(X, y)),
        output_signature=(
            tf.TensorSpec(shape=(None, len(FEATURE_NAMES)), dtype=tf.float32),
            tf.TensorSpec(shape=(), dtype=tf.int32)
        )
    )
    if is_training:
        # Shuffle over the whole set, then repeat forever.
        dataset = dataset.shuffle(buffer_size=len(X)).repeat()
        if use_augmentation:
            dataset = dataset.map(safe_tf_augment, num_parallel_calls=tf.data.AUTOTUNE)
        # Truncate after augmentation (time stretch changes the length), then pad.
        dataset = dataset.map(lambda seq, label: (seq[:MAX_SEQ_LENGTH], label), num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.padded_batch(batch_size, padded_shapes=([MAX_SEQ_LENGTH, len(FEATURE_NAMES)], []), padding_values=(0.0, 0), drop_remainder=True)
        if use_mixup:
            # Pair every batch with a batch drawn from a shuffled copy of the batch stream.
            mixup_partner_ds = dataset.shuffle(buffer_size=len(X)//batch_size if batch_size > 0 else 1)
            dataset = tf.data.Dataset.zip((dataset, mixup_partner_ds))
            def safe_mixup_map(data1, data2):
                (seq1, label1), (seq2, label2) = data1, data2
                dist = tfp.distributions.Beta(mixup_alpha, mixup_alpha)
                # One scalar mixing weight per batch.
                lambda_ = dist.sample()
                quat1, quat2 = seq1[:, :, :4], seq2[:, :, :4]
                other1, other2 = seq1[:, :, 4:], seq2[:, :, 4:]
                # t = 1 - lambda_ puts weight lambda_ on sample 1, matching the
                # linear mix of the other features and of the labels below.
                mixed_quat = quaternion_slerp(quat1, quat2, 1.0 - lambda_)  # SLERP normalises internally
                mixed_other = lambda_ * other1 + (1 - lambda_) * other2
                mixed_seq = tf.concat([mixed_quat, mixed_other], axis=-1)
                label1_oh = tf.one_hot(tf.cast(label1, tf.int32), N_CLASSES)
                label2_oh = tf.one_hot(tf.cast(label2, tf.int32), N_CLASSES)
                mixed_label = lambda_ * label1_oh + (1 - lambda_) * label2_oh
                return mixed_seq, mixed_label
            dataset = dataset.map(safe_mixup_map, num_parallel_calls=tf.data.AUTOTUNE)
        else:
            dataset = dataset.map(lambda seq, label: (seq, tf.one_hot(tf.cast(label, tf.int32), N_CLASSES)), num_parallel_calls=tf.data.AUTOTUNE)
    else:
        # Evaluation: original order, no augmentation, keep the last partial batch.
        dataset = dataset.map(lambda seq, label: (seq[:MAX_SEQ_LENGTH], label), num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.padded_batch(batch_size, padded_shapes=([MAX_SEQ_LENGTH, len(FEATURE_NAMES)], []), padding_values=(0.0, 0), drop_remainder=False)
        dataset = dataset.map(lambda seq, label: (seq, tf.one_hot(tf.cast(label, tf.int32), N_CLASSES)), num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.prefetch(tf.data.AUTOTUNE)

# ===================================================================
# 7. Keras 模型定义
# ===================================================================
from tensorflow.keras.layers import Layer

class SumPooling1D(Layer):
    """Keras layer that sums its input over axis 1."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        return tf.reduce_sum(inputs, axis=1)

    def compute_output_shape(self, input_shape):
        # Axis 1 is reduced away; axes 0 and 2 are kept.
        leading_dim = input_shape[0]
        trailing_dim = input_shape[2]
        return (leading_dim, trailing_dim)

    def get_config(self):
        # No extra constructor arguments beyond those of the base layer.
        return super().get_config()

def se_block(input_tensor, reduction=8):
    """Squeeze-and-excitation: rescale every channel by a learned gate.

    The input is average-pooled over axis 1, passed through a bottleneck of
    ``channels // reduction`` ReLU units (at least one) and a sigmoid layer
    with ``channels`` units, and the resulting per-channel gate multiplies
    the input.

    Args:
        input_tensor: Keras tensor whose last dimension is the channel axis.
        reduction: Bottleneck reduction factor.

    Returns:
        Tensor with the same shape as ``input_tensor``.
    """
    channels = input_tensor.shape[-1]
    # channels // reduction is 0 when channels < reduction, which is not a
    # valid Dense width; keep at least one bottleneck unit.
    bottleneck_units = max(1, channels // reduction)
    x = tf.keras.layers.GlobalAveragePooling1D()(input_tensor)
    x = tf.keras.layers.Dense(bottleneck_units, activation='relu', use_bias=False)(x)
    x = tf.keras.layers.Dense(channels, activation='sigmoid', use_bias=False)(x)
    # Reshape to (1, channels) so the gate broadcasts over axis 1.
    x = tf.keras.layers.Reshape((1, channels))(x)
    return tf.keras.layers.Multiply()([input_tensor, x])

def residual_se_cnn_block(input_tensor, out_channels, kernel_size, pool_size=2, dropout=0.3, dilation_rate=1):
    """Residual block: two Conv1D+BN layers, SE gating, skip connection.

    After the residual add the block applies ReLU, max pooling with
    ``pool_size`` and dropout.

    Args:
        input_tensor: Keras tensor with channels on the last axis.
        out_channels: Number of filters of both convolutions.
        kernel_size: Kernel size of both convolutions.
        pool_size: Pool size of the final MaxPooling1D.
        dropout: Dropout rate applied after pooling.
        dilation_rate: Dilation of the second convolution only.

    Returns:
        Output tensor of the block.
    """
    in_channels = input_tensor.shape[-1]
    if in_channels != out_channels:
        # Project the shortcut with a 1x1 convolution so it can be added.
        shortcut = tf.keras.layers.Conv1D(out_channels, 1, use_bias=False)(input_tensor)
        shortcut = tf.keras.layers.BatchNormalization()(shortcut)
    else:
        shortcut = input_tensor
    x = tf.keras.layers.Conv1D(out_channels, kernel_size, padding='same', use_bias=False)(input_tensor)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    # dilation_rate is applied here only; the first convolution is undilated.
    x = tf.keras.layers.Conv1D(out_channels, kernel_size, padding='same', use_bias=False, dilation_rate=dilation_rate)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = se_block(x)
    x = tf.keras.layers.Add()([shortcut, x])
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.MaxPooling1D(pool_size)(x)
    x = tf.keras.layers.Dropout(dropout)(x)
    return x

def attention_layer(input_tensor):
    """Attention pooling: weighted sum of the input over axis 1.

    Scores come from a tanh Dense(1) applied to a layer-normalised copy of the
    input and are turned into weights with a softmax over axis 1.

    NOTE(review): the Dense and Softmax layers have fixed names, so this
    helper can presumably be used only once per model -- confirm before reuse.
    """
    # Normalisation spans axes 1 and 2 jointly.
    x = tf.keras.layers.LayerNormalization(axis=[1, 2])(input_tensor)
    scores = tf.keras.layers.Dense(1, activation='tanh', name='attention_scores')(x)
    weights = tf.keras.layers.Softmax(axis=1, name='attention_weights')(scores)
    context = tf.keras.layers.Multiply()([input_tensor, weights])  # note: weights are applied to the un-normalised input
    context = SumPooling1D()(context)
    return context

def build_model(imu_dim=len(FEATURE_NAMES), n_classes=N_CLASSES, max_seq_len=MAX_SEQ_LENGTH):
    """Build the multi-branch SE-CNN + attention classifier.

    The input is split by column range into separate branches, each passing
    through two residual SE blocks. The branches are concatenated, passed
    through two more residual SE blocks (no pooling), attention-pooled and
    classified by a two-layer dense head.

    The default arguments are evaluated once, when this function is defined.

    Args:
        imu_dim: Number of input feature columns.
        n_classes: Number of output classes.
        max_seq_len: Fixed number of frames per input sequence.

    Returns:
        ``tf.keras.Model`` mapping (max_seq_len, imu_dim) inputs to
        ``n_classes`` raw logits (linear output activation).
    """
    input_layer = tf.keras.layers.Input(shape=(max_seq_len, imu_dim), name='input_sequence')
    # Column groups; with the SELECTED_FEATURES order of this file these are
    # [0:4] quaternion, [4:8] linear_acc x/y/z + magnitude,
    # [8:12] angular_vel x/y/z + angular_distance, [12:] acc_mag.
    rot_features = tf.keras.layers.Lambda(lambda x: x[:, :, 0:4], name='rot_features')(input_layer)
    acc_features = tf.keras.layers.Lambda(lambda x: x[:, :, 4:8], name='acc_features')(input_layer)
    vel_features = tf.keras.layers.Lambda(lambda x: x[:, :, 8:12], name='vel_features')(input_layer)
    other_features = tf.keras.layers.Lambda(lambda x: x[:, :, 12:], name='other_features')(input_layer)
    # Each branch pools twice with pool_size=2.
    rot_branch = residual_se_cnn_block(rot_features, out_channels=64, kernel_size=3, pool_size=2, dropout=0.25)
    rot_branch = residual_se_cnn_block(rot_branch, out_channels=128, kernel_size=3, pool_size=2, dropout=0.3)
    acc_branch = residual_se_cnn_block(acc_features, out_channels=64, kernel_size=3, pool_size=2, dropout=0.25)
    acc_branch = residual_se_cnn_block(acc_branch, out_channels=128, kernel_size=3, pool_size=2, dropout=0.3)
    vel_branch = residual_se_cnn_block(vel_features, out_channels=64, kernel_size=3, pool_size=2, dropout=0.25)
    vel_branch = residual_se_cnn_block(vel_branch, out_channels=128, kernel_size=3, pool_size=2, dropout=0.3)
    # The fourth branch exists only when there are columns beyond index 12.
    if other_features.shape[-1] > 0:
        other_branch = residual_se_cnn_block(other_features, out_channels=32, kernel_size=3, pool_size=2, dropout=0.2)
        other_branch = residual_se_cnn_block(other_branch, out_channels=64, kernel_size=3, pool_size=2, dropout=0.35)
        merged = tf.keras.layers.Concatenate(axis=-1)([rot_branch, acc_branch, vel_branch, other_branch])
    else:
        merged = tf.keras.layers.Concatenate(axis=-1)([rot_branch, acc_branch, vel_branch])
    # Fusion blocks use pool_size=1.
    x = residual_se_cnn_block(merged, out_channels=256, kernel_size=3, pool_size=1, dropout=0.4)
    x = residual_se_cnn_block(x, out_channels=512, kernel_size=5, pool_size=1, dropout=0.45)
    x = attention_layer(x)
    # Classification head.
    x = tf.keras.layers.Dense(128, use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.Dropout(0.45)(x)
    x = tf.keras.layers.Dense(64, use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.Dropout(0.25)(x)
    output_layer = tf.keras.layers.Dense(n_classes, activation='linear', name='output_logits')(x)
    return tf.keras.Model(inputs=input_layer, outputs=output_layer)

# ===================================================================
# 8. 训练、回调函数与评估
# ===================================================================
# Make the metric module importable from this path, if it is present.
sys.path.append('/kaggle/usr/lib/cmi_2025_metric_copy_for_import')
try:
    import cmi_2025_metric_copy_for_import as metric
    print("成功导入本地评估指标文件。")
    def get_competition_score(true, pred):
        """Score integer class ids with the imported metric module."""
        # The metric is called with gesture names, so map ids back to names.
        true_labels = [IDX2LABEL[x] for x in true]
        pred_labels = [IDX2LABEL[x] for x in pred]
        true_df = pd.DataFrame({'id': range(len(true_labels)), 'gesture': true_labels})
        pred_df = pd.DataFrame({'id': range(len(pred_labels)), 'gesture': pred_labels})
        return metric.score(true_df, pred_df, 'id')
except ImportError:
    # Fallback when the metric module is unavailable: plain accuracy.
    print("无法导入本地评估指标文件，将使用 accuracy 作为替代。")
    def get_competition_score(true, pred):
        """Fallback score: plain accuracy on integer class ids."""
        return accuracy_score(true, pred)

class CompetitionScoreCallback(tf.keras.callbacks.Callback):
    """Compute ``get_competition_score`` on the validation set every epoch.

    The score is printed and, when Keras supplies a ``logs`` dict, stored
    under ``'val_score'`` so that later callbacks can read it.
    """

    def __init__(self, validation_data):
        """Cache the validation dataset and its integer labels.

        Args:
            validation_data: Iterable of ``(inputs, one_hot_labels)`` batches;
                it is iterated once here to collect the labels.
        """
        super().__init__()
        self.val_data = validation_data
        # Collapse the one-hot labels of all batches to class ids.
        self.val_labels = np.concatenate([y for x, y in validation_data], axis=0)
        self.val_labels = np.argmax(self.val_labels, axis=1)

    def on_epoch_end(self, epoch, logs=None):
        val_preds = self.model.predict(self.val_data, verbose=0)
        val_preds = np.argmax(val_preds, axis=1)
        score = get_competition_score(self.val_labels, val_preds)
        print(f" - val_score: {score:.4f}", end="")
        # `logs` defaults to None; only record the score when a dict is given.
        if logs is not None:
            logs['val_score'] = score

def plot_training_history(history, fold):
    """Plot loss and accuracy curves of one fold, save them as PNG and show them.

    Args:
        history: Keras ``History`` whose ``history`` dict contains ``loss``,
            ``val_loss``, ``categorical_accuracy`` and
            ``val_categorical_accuracy``.
        fold: Fold number used in the titles and in the file name
            ``{OUTPUT_DIR}/fold_{fold}_training_history.png``.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    ax1.plot(history.history['loss'], label='Train Loss')
    ax1.plot(history.history['val_loss'], label='Val Loss')
    ax1.set_title(f'Fold {fold} - Loss')
    ax1.set_xlabel('Epoch')
    ax1.legend()
    ax2.plot(history.history['categorical_accuracy'], label='Train Accuracy')
    ax2.plot(history.history['val_categorical_accuracy'], label='Val Accuracy')
    ax2.set_title(f'Fold {fold} - Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.legend()
    plt.tight_layout()
    plt.savefig(f'{OUTPUT_DIR}/fold_{fold}_training_history.png')
    plt.show()
    # pyplot keeps every figure alive until it is closed explicitly; release
    # it so that one figure per fold does not accumulate.
    plt.close(fig)

class WarmupCosine(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warm-up to ``base_lr`` followed by cosine decay to ``min_lr``."""

    def __init__(self, base_lr, warmup_steps, total_steps, min_lr=1e-5):
        super().__init__()
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        # Warm-up phase: learning rate grows linearly with the step.
        warmup_lr = self.base_lr * (step / tf.cast(self.warmup_steps, tf.float32))
        # Decay phase: fraction of the post-warm-up steps completed, in [0, 1].
        decay_span = tf.maximum(1.0, self.total_steps - self.warmup_steps)
        progress = (step - self.warmup_steps) / decay_span
        progress = tf.clip_by_value(progress, 0.0, 1.0)
        cosine_lr = self.min_lr + 0.5 * (self.base_lr - self.min_lr) * (1 + tf.cos(np.pi * progress))
        in_warmup = step < self.warmup_steps
        return tf.where(in_warmup, warmup_lr, cosine_lr)

    def get_config(self):
        return {
            "base_lr": self.base_lr,
            "warmup_steps": self.warmup_steps,
            "total_steps": self.total_steps,
            "min_lr": self.min_lr,
        }

# ===================================================================
# 9. 数据验证函数
# ===================================================================
def check_quaternion_norm(sequences, tolerance=0.01):
    """Summarise the norms of the quaternion columns (first four features).

    Args:
        sequences: A list of 2-D arrays, a 3-D array or a 2-D array; the
            quaternion is read from the first four entries of the last axis.
        tolerance: Maximum allowed absolute deviation of a norm from 1.

    Returns:
        ``(is_valid, stats)`` where ``stats`` has the keys ``mean``, ``std``,
        ``min``, ``max`` and ``is_valid``.
    """
    if isinstance(sequences, list):
        per_sequence = [np.linalg.norm(seq[:, :4], axis=1) for seq in sequences]
        all_norms = np.concatenate(per_sequence)
    else:
        if len(sequences.shape) == 3:
            quat = sequences[:, :, :4]
        else:
            quat = sequences[:, :4]
        all_norms = np.linalg.norm(quat, axis=-1).flatten()

    deviation = np.abs(all_norms - 1.0)
    stats = {
        'mean': np.mean(all_norms),
        'std': np.std(all_norms),
        'min': np.min(all_norms),
        'max': np.max(all_norms),
        'is_valid': np.all(deviation < tolerance),
    }
    return stats['is_valid'], stats

# ===================================================================
# 10. K-Fold 交叉验证训练流程
# ===================================================================
def run_kfold_training(sequences, labels, groups, n_folds=5, **kwargs):
    """Train and evaluate one model per StratifiedGroupKFold split.

    For every fold the non-quaternion features are standardized with a scaler
    fitted on the training split, the scaler and the best checkpoint are
    written to ``OUTPUT_DIR``, and the validation split is scored.

    Args:
        sequences: List of per-sequence feature arrays.
        labels: Array of integer class ids, one per sequence.
        groups: Array of group ids; a group never spans two folds.
        n_folds: Number of cross-validation folds.
        **kwargs: Training settings. Required keys: ``batch_size``,
            ``num_epochs``, ``learning_rate``, ``label_smoothing``,
            ``patience`` and ``use_mixup``. ``mixup_alpha`` is optional
            (default 0.2).

    Returns:
        A list with one dict per trained fold containing ``fold``,
        ``val_accuracy`` and ``val_score``. Folds whose training split is
        smaller than one batch are skipped.
    """
    print("🚀 开始 K-Fold 训练（简化版 - 仅修复零四元数）...")
    print("=" * 60)
    print("\n📊 验证数据的四元数范数（已修复零四元数）...")
    _, stats = check_quaternion_norm(sequences)
    print(f"   平均范数: {stats['mean']:.4f} (±{stats['std']:.4f})")
    print(f"   范数范围: [{stats['min']:.4f}, {stats['max']:.4f}]")

    batch_size = kwargs['batch_size']
    # Report every class in class-id order, even if a fold does not contain it.
    class_ids = sorted(IDX2LABEL)
    class_names = [IDX2LABEL[class_id] for class_id in class_ids]

    kfold_results = []
    sgkf = StratifiedGroupKFold(n_splits=n_folds, shuffle=True, random_state=42)
    sequences_np = np.array(sequences, dtype=object)

    for fold_num, (train_idx, val_idx) in enumerate(sgkf.split(sequences_np, labels, groups), 1):
        print(f"\n{'='*60}\n🔄 训练 Fold {fold_num}/{n_folds}\n{'='*60}")
        X_train_raw, X_val_raw = sequences_np[train_idx], sequences_np[val_idx]
        y_train, y_val = labels[train_idx], labels[val_idx]

        print(f"📊 Fold {fold_num}: 标准化数据（四元数保持原始值）...")
        X_train_scaled, X_val_scaled, scaler = standardize_sequences_with_quaternion(X_train_raw.tolist(), X_val_raw.tolist(), fit_on_train=True)
        _, stats_train = check_quaternion_norm(X_train_scaled)
        print(f"   训练集四元数范数: {stats_train['mean']:.4f} (±{stats_train['std']:.4f})")
        scaler_path = f'{OUTPUT_DIR}/fold_{fold_num}_scaler.joblib'
        joblib.dump(scaler, scaler_path)
        print(f"💾 保存 Fold {fold_num} 的 StandardScaler 到 {scaler_path}")

        # Check the fold size before building datasets and the model, so an
        # undersized fold is skipped without wasted work.
        steps_per_epoch = len(X_train_scaled) // batch_size
        if steps_per_epoch == 0:
            print(f"错误: Fold {fold_num} 的训练样本数小于batch_size")
            continue

        train_ds = create_tf_dataset(X_train_scaled, y_train, batch_size, is_training=True, use_augmentation=True, use_mixup=kwargs['use_mixup'], mixup_alpha=kwargs.get('mixup_alpha', 0.2))
        val_ds = create_tf_dataset(X_val_scaled, y_val, batch_size, is_training=False)

        tf.keras.backend.clear_session()
        model = build_model()

        total_steps = steps_per_epoch * kwargs['num_epochs']
        lr_sched = WarmupCosine(base_lr=kwargs['learning_rate'], warmup_steps=int(0.05 * total_steps), total_steps=total_steps, min_lr=1e-5)
        optimizer = tf.keras.optimizers.AdamW(learning_rate=lr_sched, weight_decay=5e-3)
        loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=kwargs['label_smoothing'])
        model.compile(optimizer=optimizer, loss=loss_fn, metrics=['categorical_accuracy'])

        model_save_path = f'{OUTPUT_DIR}/fold_{fold_num}_model.keras'
        # CompetitionScoreCallback must come first: it writes 'val_score' into
        # the logs that the checkpoint and early-stopping callbacks monitor.
        callbacks = [
            CompetitionScoreCallback(validation_data=val_ds),
            tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path, save_best_only=True, monitor='val_score', mode='max', verbose=1),
            tf.keras.callbacks.EarlyStopping(monitor='val_score', patience=kwargs['patience'], mode='max', verbose=1, restore_best_weights=True),
        ]

        print(f"🏋️ 开始训练 Fold {fold_num}...")
        history = model.fit(train_ds, epochs=kwargs['num_epochs'], validation_data=val_ds, callbacks=callbacks, steps_per_epoch=steps_per_epoch, verbose=2)
        plot_training_history(history, fold_num)

        print(f"\n🎯 Fold {fold_num} 最终评估:")
        y_val_unbatched = np.concatenate([y for x, y in val_ds], axis=0)
        y_val_labels = np.argmax(y_val_unbatched, axis=1)
        val_preds_final = model.predict(val_ds)
        val_preds_final_labels = np.argmax(val_preds_final, axis=1)
        y_val_labels = y_val_labels[:len(val_preds_final_labels)]

        # Plain floats keep the result dicts JSON-serializable.
        final_score = float(get_competition_score(y_val_labels, val_preds_final_labels))
        final_acc = float(np.mean(val_preds_final_labels == y_val_labels))
        print(f"       验证集 Accuracy: {final_acc:.4f}")
        print(f"       验证集 Score: {final_score:.4f}")

        print(f"\n📊 Fold {fold_num} - 分类报告:")
        # Pass `labels` explicitly: without it classification_report raises
        # when a class is missing from both y_true and y_pred, because the
        # number of observed classes no longer matches len(target_names).
        print(classification_report(
            y_val_labels,
            val_preds_final_labels,
            labels=class_ids,
            target_names=class_names,
            digits=4,
            zero_division=0
        ))

        kfold_results.append({'fold': fold_num, 'val_accuracy': final_acc, 'val_score': final_score})

    return kfold_results

# ===================================================================
# 11. 测试集预测函数（简化版 - 不归一化四元数）
# ===================================================================
def predict_test_data(test_sequences_raw, output_dir=OUTPUT_DIR, n_folds=5):
    """Predict test sequences with every fold model and average the outputs.

    For each fold the saved scaler is applied to the non-quaternion columns
    (index 4 onwards) while the first four columns are passed through as-is;
    the fold's saved model then predicts on the rescaled sequences.

    Args:
        test_sequences_raw: List of 2-D arrays (frames x features).
        output_dir: Directory holding ``fold_{k}_scaler.joblib`` and
            ``fold_{k}_model.keras`` for k = 1..n_folds.
        n_folds: Number of fold models to ensemble.

    Returns:
        ``(final_predictions, ensemble_predictions)``: the argmax class index
        per sequence and the fold-averaged model outputs.
    """
    print("🔮 开始测试集预测...")

    # The quaternion / other-feature split does not depend on the fold.
    quat_parts = [seq[:, :4] for seq in test_sequences_raw]
    other_flat = np.vstack([seq[:, 4:] for seq in test_sequences_raw])
    # Row offsets at which the flattened matrix is cut back into sequences.
    split_points = np.cumsum([len(seq) for seq in test_sequences_raw])[:-1]
    dummy_labels = np.zeros(len(test_sequences_raw), dtype=np.int32)
    custom_objects = {'WarmupCosine': WarmupCosine, 'SumPooling1D': SumPooling1D}

    all_predictions = []
    for fold_num in range(1, n_folds + 1):
        print(f"\n处理 Fold {fold_num}...")
        scaler_path = f'{output_dir}/fold_{fold_num}_scaler.joblib'
        scaler = joblib.load(scaler_path)
        print(f"   加载scaler: {scaler_path}")

        other_scaled = scaler.transform(other_flat)
        # Quaternions are used as stored (no re-normalization).
        test_sequences_scaled = [
            np.concatenate([quat, rest], axis=1)
            for quat, rest in zip(quat_parts, np.split(other_scaled, split_points))
        ]

        model_path = f'{output_dir}/fold_{fold_num}_model.keras'
        model = tf.keras.models.load_model(model_path, custom_objects=custom_objects, compile=False)
        print(f"   加载模型: {model_path}")

        test_ds = create_tf_dataset(test_sequences_scaled, dummy_labels, batch_size=64, is_training=False)
        all_predictions.append(model.predict(test_ds, verbose=0))
        print(f"   完成Fold {fold_num}的预测")

    ensemble_predictions = np.mean(all_predictions, axis=0)
    final_predictions = np.argmax(ensemble_predictions, axis=1)
    print(f"\n✅ 测试集预测完成！")
    return final_predictions, ensemble_predictions

# ===================================================================
# 12. 主执行逻辑
# ===================================================================
if __name__ == '__main__':
    # Hyper-parameters forwarded to run_kfold_training as keyword arguments.
    training_params = {
        'num_epochs': 3 if DEBUG else 100,
        'learning_rate': 0.001,
        'patience': 20,
        'batch_size': 64,
        'label_smoothing': 0.1,
        'use_mixup': True,
        'mixup_alpha': 0.4
    }

    if TRAIN:
        results = run_kfold_training(sequences, labels, groups, n_folds=5, **training_params)
        print(f"\n\n{'=' * 60}\n🎉 K-Fold 训练总结\n{'=' * 60}")
        if results:
            val_scores = [r['val_score'] for r in results]
            val_accs = [r['val_accuracy'] for r in results]
            print(f"平均验证集 Score: {np.mean(val_scores):.4f} ± {np.std(val_scores):.4f}")
            print(f"平均验证集 Accuracy: {np.mean(val_accs):.4f} ± {np.std(val_accs):.4f}")

            # Store the per-fold metrics together with the feature
            # configuration that produced them.
            feature_selection_info = {
                'all_features': ALL_FEATURE_NAMES,
                'selected_features': FEATURE_NAMES,
                'feature_indices': FEATURE_INDICES,
                'num_features': len(FEATURE_NAMES),
                'quaternion_safe': True,
            }
            results_with_features = {
                'kfold_results': results,
                'feature_selection': feature_selection_info,
                'preprocessing_note': '简化版：仅修复零四元数，不进行阈值归一化',
                'lr_schedule': 'WarmupCosine',
            }
            with open(f'{OUTPUT_DIR}/kfold_results.json', 'w') as f:
                json.dump(results_with_features, f, indent=2)

    # Summary of the preprocessing conventions, printed whether or not
    # training ran.
    for note in (
        "\n📝 重要说明:",
        "1. 自动检测并修复零四元数（填充为单位四元数）",
        "2. 四元数在标准化时保持原始值，不进行归一化",
        "3. 数据增强后对四元数进行必要的归一化",
        "4. MixUp使用SLERP进行四元数插值",
        "5. 标准化仅应用于非四元数特征",
        "6. 每折训练后会打印详细的分类报告",
    ):
        print(note)

# ===================================================================
# 13. 生成提交文件
# ===================================================================
def generate_submission(submission_output_path='submission.csv', n_folds=5):
    """Predict the preprocessed test set and write a submission CSV.

    Loads ``processed_test_data_raw.joblib`` from ``PROCESSED_DATA_DIR``,
    keeps the columns listed in ``FEATURE_INDICES``, repairs zero quaternions,
    runs the fold ensemble stored in ``OUTPUT_DIR`` and maps the predicted
    class indices back to gesture names.

    Args:
        submission_output_path: Destination path of the CSV file.
        n_folds: Number of fold models to ensemble. It must match the number
            of folds that were trained and saved (previously hard-coded to 5).

    Returns:
        The submission DataFrame with ``sequence_id`` and ``gesture`` columns.
    """
    print("\n" + "="*60 + "\n📄 生成提交文件" + "\n" + "="*60)
    print("加载预处理的测试数据...")
    test_data = joblib.load(f'{PROCESSED_DATA_DIR}/processed_test_data_raw.joblib')
    test_sequences_full = test_data['sequence'].tolist()
    # Fancy indexing copies the selected columns, so the in-place quaternion
    # repair below does not touch the loaded arrays.
    test_sequences = [seq[:, FEATURE_INDICES] for seq in test_sequences_full]

    # Repair zero quaternions in the test set.
    print("\n处理测试集数据...")
    test_sequences = fix_zero_quaternions(test_sequences)

    test_sequence_ids = test_data['sequence_id'].values
    predictions, _ = predict_test_data(test_sequences, OUTPUT_DIR, n_folds=n_folds)
    predicted_labels = [IDX2LABEL[pred] for pred in predictions]
    submission_df = pd.DataFrame({'sequence_id': test_sequence_ids, 'gesture': predicted_labels})
    submission_df.to_csv(submission_output_path, index=False)
    print(f"✅ 提交文件已保存到: {submission_output_path}")
    print("\n📊 预测分布:")
    print(submission_df['gesture'].value_counts())
    return submission_df

In [None]:
#仔细检查修改数据增强和特征工程冲突，修复数据增强后破坏的数据物理特性

In [None]:
# ===================================================================
#
# 完整的 TensorFlow/Keras 训练代码 - CMI 行为检测
# se1dcnn+attention模型架构 - 四元数安全版本（简化版：仅修复零四元数）
# 10-FOLD交叉验证版本
# cv = 0.8104,lb = 0.813
# 过拟合已解决
#
# ===================================================================

import os
import sys
import json
import joblib
import random

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_probability as tfp
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, ConfusionMatrixDisplay, classification_report
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings("ignore", category=UserWarning, message=".*Your input ran out of data.*")

# 抑制 TensorFlow 的一些日志输出
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# ===================================================================
# 1. 配置与全局设置
# ===================================================================
DEBUG = False
TRAIN = True
MAX_SEQ_LENGTH = 128
# Replace with your actual path.
PROCESSED_DATA_DIR = '/kaggle/input/imuonly-process/kaggle/working/processed_data_selected_features_v1'
OUTPUT_DIR = './saved_models_keras_fixed_test'
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)

# Label mapping written by the preprocessing stage (gesture name -> class id).
with open(f'{PROCESSED_DATA_DIR}/label_map.json', 'r') as f:
    LABEL2IDX = json.load(f)
IDX2LABEL = {idx: name for name, idx in LABEL2IDX.items()}
N_CLASSES = len(LABEL2IDX)

# Feature configuration written by the preprocessing stage.
with open(f'{PROCESSED_DATA_DIR}/feature_names.json', 'r') as f:
    feature_info = json.load(f)
ALL_FEATURE_NAMES = feature_info['all_features']
TIME_FEATURE_NAMES = feature_info['time_features']
PSD_FEATURE_NAMES = feature_info['psd_features']
STAT_FEATURE_NAMES = feature_info['stat_features']

# ===================================================================
# Feature selection - the first four entries must be the quaternion
# ===================================================================
SELECTED_FEATURES = [
    # quaternion
    'rot_w', 'rot_x', 'rot_y', 'rot_z',
    # linear acceleration and its magnitude
    'linear_acc_x', 'linear_acc_y', 'linear_acc_z',
    'linear_acc_mag',
    # angular velocity and angular distance
    'angular_vel_x', 'angular_vel_y', 'angular_vel_z',
    'angular_distance',
    # acceleration magnitude
    'acc_mag',
]

# Setting SELECTED_FEATURES to None selects every available feature.
FEATURE_NAMES = SELECTED_FEATURES if SELECTED_FEATURES is not None else ALL_FEATURE_NAMES

# Column positions of the selected features inside the stored sequences.
FEATURE_INDICES = [ALL_FEATURE_NAMES.index(name) for name in FEATURE_NAMES]

print(f"使用的特征数量: {len(FEATURE_NAMES)}")
print(f"使用的特征: {FEATURE_NAMES}")
print(f"前4个特征是四元数: {FEATURE_NAMES[:4]}")

tf.keras.utils.set_random_seed(42)
np.random.seed(42)

# ===================================================================
# 2. 加载预处理好的数据（原始数据，未标准化）
# ===================================================================
print("\n加载预处理好的数据（原始特征）...")
agg_train_df = joblib.load(f'{PROCESSED_DATA_DIR}/processed_train_data_raw.joblib')

# Debug runs only use the first 2000 rows.
if DEBUG:
    agg_train_df = agg_train_df.head(2000)

sequences_full = agg_train_df['sequence'].tolist()
# Keep only the selected feature columns; fancy indexing returns a copy.
sequences = [full_seq[:, FEATURE_INDICES] for full_seq in sequences_full]

# ===================================================================
# 修复零四元数（缺失值填充导致的问题）- 保留
# ===================================================================
def fix_zero_quaternions(sequences):
    """Replace zero quaternions with the identity quaternion [1, 0, 0, 0].

    Columns 0-3 of every sequence are read as a quaternion in (w, x, y, z)
    order. Frames whose quaternion norm is below 1e-8 are overwritten in
    place, and a summary of the repair is printed.

    Args:
        sequences: List of 2-D float arrays (frames x features). The arrays
            are modified in place.

    Returns:
        The same ``sequences`` object that was passed in. An empty input is
        returned unchanged.
    """
    print("\n🔧 检查并修复零四元数...")

    if len(sequences) == 0:
        # Nothing to inspect; np.vstack in the verification step below would
        # raise ValueError on an empty list.
        print("   ✅ 未发现零四元数，数据质量良好")
        return sequences

    total_frames = 0
    zero_frames = 0
    problematic_sequences = []

    for seq_idx, seq in enumerate(sequences):
        # Norm of the quaternion in every frame.
        quat_norms = np.linalg.norm(seq[:, :4], axis=1)
        total_frames += len(quat_norms)

        # Frames whose quaternion is (numerically) zero.
        zero_mask = quat_norms < 1e-8
        num_zeros = int(np.count_nonzero(zero_mask))
        if num_zeros == 0:
            continue

        seq[zero_mask, :4] = [1.0, 0.0, 0.0, 0.0]
        zero_frames += num_zeros
        problematic_sequences.append(seq_idx)

    # Summary of what was repaired.
    if zero_frames > 0:
        print(f"   ✅ 修复了 {zero_frames}/{total_frames} 个零四元数帧 "
              f"({100*zero_frames/total_frames:.2f}%)")
        print(f"   📊 涉及 {len(problematic_sequences)} 个序列")
    else:
        print("   ✅ 未发现零四元数，数据质量良好")

    # Re-check the norms after the repair.
    print("\n   验证修复后的四元数范数:")
    all_quats = np.vstack([seq[:, :4] for seq in sequences])
    fixed_norms = np.linalg.norm(all_quats, axis=1)
    print(f"     均值: {np.mean(fixed_norms):.6f}")
    print(f"     标准差: {np.std(fixed_norms):.6f}")
    print(f"     范围: [{np.min(fixed_norms):.6f}, {np.max(fixed_norms):.6f}]")

    # Warn if any zero quaternion survived.
    remaining_zeros = np.sum(fixed_norms < 1e-8)
    if remaining_zeros > 0:
        print(f"   ⚠️ 警告: 仍有 {remaining_zeros} 个零四元数未修复")

    return sequences

# Targets and group ids for cross-validation.
labels = agg_train_df['label'].values
groups = agg_train_df['subject'].values

# Repair zero quaternions in the selected training features.
sequences = fix_zero_quaternions(sequences)

print(f"\n成功加载并处理 {len(sequences)} 条序列数据。")

# ===================================================================
# 3. 数据标准化函数（简化版 - 不对四元数进行归一化）
# ===================================================================
def standardize_sequences_with_quaternion(sequences_train, sequences_val, fit_on_train=True):
    """Standardize the non-quaternion features; keep quaternions untouched.

    Columns 0-3 of every sequence are copied through unchanged. Columns 4
    onwards are z-scored with a single ``StandardScaler`` fitted over all
    frames.

    Args:
        sequences_train: List of 2-D arrays (frames x features).
        sequences_val: List of 2-D arrays (frames x features).
        fit_on_train: If true, fit the scaler on the training frames only;
            otherwise fit it on training and validation frames together.

    Returns:
        ``(train_scaled, val_scaled, scaler)``: two lists of arrays with the
        same lengths as the inputs, and the fitted scaler.
    """
    def _split_columns(seqs):
        # (quaternion columns, remaining columns) for every sequence
        return [s[:, :4] for s in seqs], [s[:, 4:] for s in seqs]

    def _reassemble(quats, scaled_flat, seqs):
        # Cut the flattened matrix back into per-sequence pieces and
        # re-attach the original (not normalized) quaternion columns.
        boundaries = np.cumsum([len(s) for s in seqs])[:-1]
        pieces = np.split(scaled_flat, boundaries)
        return [np.concatenate([quat, rest], axis=1) for quat, rest in zip(quats, pieces)]

    train_quat, train_other = _split_columns(sequences_train)
    val_quat, val_other = _split_columns(sequences_val)

    train_other_flat = np.vstack(train_other)
    val_other_flat = np.vstack(val_other)

    scaler = StandardScaler()
    if fit_on_train:
        scaler.fit(train_other_flat)
    else:
        scaler.fit(np.vstack([train_other_flat, val_other_flat]))
    train_other_scaled = scaler.transform(train_other_flat)
    val_other_scaled = scaler.transform(val_other_flat)

    sequences_train_scaled = _reassemble(train_quat, train_other_scaled, sequences_train)
    sequences_val_scaled = _reassemble(val_quat, val_other_scaled, sequences_val)
    return sequences_train_scaled, sequences_val_scaled, scaler

# ===================================================================
# 4. 四元数工具函数
# ===================================================================
def normalize_quaternion(quat):
    """Scale quaternions (last axis) to unit length.

    The norm is floored at 1e-8 so an all-zero quaternion does not cause a
    division by zero.
    """
    safe_norm = tf.maximum(tf.norm(quat, axis=-1, keepdims=True), 1e-8)
    return quat / safe_norm

def quaternion_slerp(q1, q2, t):
    """Spherically interpolate between quaternions ``q1`` and ``q2``.

    ``t = 0`` weights ``q1`` fully and ``t = 1`` weights ``q2`` fully. ``q2``
    is negated wherever its dot product with ``q1`` is negative. Where the
    sine of the angle is at most 1e-4 the weights fall back to the linear
    ``1 - t`` / ``t``. The result is re-normalized.
    """
    cos_angle = tf.reduce_sum(q1 * q2, axis=-1, keepdims=True)
    # Flip q2 when it points into the opposite hemisphere from q1.
    q2 = tf.where(cos_angle < 0, -q2, q2)
    cos_angle = tf.clip_by_value(tf.abs(cos_angle), -1.0, 1.0)

    angle = tf.acos(cos_angle)
    sin_angle = tf.sin(angle)
    use_spherical = sin_angle > 1e-4

    weight_q1 = tf.where(use_spherical, tf.sin((1.0 - t) * angle) / sin_angle, 1.0 - t)
    weight_q2 = tf.where(use_spherical, tf.sin(t * angle) / sin_angle, t)
    return normalize_quaternion(weight_q1 * q1 + weight_q2 * q2)

# ===================================================================
# 5. 四元数安全的数据增强（保留增强后的归一化）
# ===================================================================
def safe_tf_time_stretch(sequence, stretch_range=(0.8, 1.2)):
    """Resample a sequence along axis 0 by a random stretch factor.

    A factor is drawn uniformly from ``stretch_range`` and the new length is
    ``int(original_length / factor)``. Quaternion columns (0-3) and the
    remaining columns are resized separately with bilinear interpolation;
    the quaternions are re-normalized afterwards.

    Args:
        sequence: 2-D tensor (frames x features).
        stretch_range: ``(min, max)`` bounds of the stretch factor.

    Returns:
        The resampled 2-D tensor with the same number of columns.
    """
    n_frames = tf.shape(sequence)[0]
    factor = tf.random.uniform(shape=(), minval=stretch_range[0], maxval=stretch_range[1])
    new_len = tf.cast(tf.cast(n_frames, tf.float32) / factor, tf.int32)

    quat_part = sequence[:, :4]
    other_part = sequence[:, 4:]
    n_other = tf.shape(other_part)[1]

    # tf.image.resize expects (batch, height, width, channels): time is
    # mapped to "height" and features to "channels".
    quat_image = tf.reshape(quat_part, [1, n_frames, 1, 4])
    quat_resized = tf.image.resize(quat_image, [new_len, 1], method=tf.image.ResizeMethod.BILINEAR)
    # Interpolated quaternions are not unit length, so normalize them again.
    quat_resized = normalize_quaternion(tf.reshape(quat_resized, [new_len, 4]))

    if n_other > 0:
        other_image = tf.reshape(other_part, [1, n_frames, 1, n_other])
        other_resized = tf.image.resize(other_image, [new_len, 1], method=tf.image.ResizeMethod.BILINEAR)
        other_resized = tf.reshape(other_resized, [new_len, n_other])
        stretched_sequence = tf.concat([quat_resized, other_resized], axis=1)
    else:
        stretched_sequence = quat_resized
    return stretched_sequence

def safe_tf_augment(sequence, label, aug_prob=0.5):
    """Apply random, quaternion-aware augmentations to one training sequence.

    Each step below is applied independently with probability ``aug_prob``:

    1. time stretch (``safe_tf_time_stretch``);
    2. circular shift along time by up to 10% of the length;
    3. additive Gaussian noise (quaternions are re-normalized afterwards);
    4. zeroing a contiguous window of 15% of the frames in the
       non-quaternion columns.

    In this file the function is mapped over sequences whose non-quaternion
    columns are already standardized (``run_kfold_training`` passes the
    scaled training split to ``create_tf_dataset``).

    Args:
        sequence: 2-D tensor (frames x features); columns 0-3 are the
            quaternion.
        label: Label passed through unchanged.
        aug_prob: Probability of applying each augmentation step.

    Returns:
        ``(augmented_sequence, label)``.
    """
    if tf.random.uniform(()) < aug_prob:
        # The magnitude column is interpolated together with the other
        # channels and is deliberately NOT recomputed as norm(xyz) here.
        # The inputs are z-scored per column, so the norm of the standardized
        # components is on a different scale from the standardized magnitude
        # that validation and test data contain.
        sequence = safe_tf_time_stretch(sequence)
    if tf.random.uniform(()) < aug_prob:
        seq_len = tf.shape(sequence)[0]
        max_shift = tf.cast(tf.cast(seq_len, tf.float32) * 0.1, tf.int32)
        # Integer sampling needs minval < maxval and excludes maxval. The +1
        # keeps the range non-empty when max_shift == 0 (very short
        # sequences) and makes the shift symmetric in [-max_shift, max_shift].
        shift = tf.random.uniform(shape=(), minval=-max_shift, maxval=max_shift + 1, dtype=tf.int32)
        sequence = tf.roll(sequence, shift=shift, axis=0)
    if tf.random.uniform(()) < aug_prob:
        quat_features = sequence[:, :4]
        other_features = sequence[:, 4:]
        quat_noise = tf.random.normal(shape=tf.shape(quat_features), stddev=0.015)
        # Noise breaks unit length, so normalize the quaternions again.
        quat_features = normalize_quaternion(quat_features + quat_noise)
        if tf.shape(other_features)[1] > 0:
            other_noise = tf.random.normal(shape=tf.shape(other_features), stddev=0.03)
            other_features = other_features + other_noise
            sequence = tf.concat([quat_features, other_features], axis=1)
        else:
            sequence = quat_features
    # Random amplitude scaling of the non-quaternion features is disabled.
    if tf.random.uniform(()) < aug_prob:
        seq_len = tf.shape(sequence)[0]
        mask_ratio = 0.15
        mask_length = tf.cast(tf.cast(seq_len, tf.float32) * mask_ratio, tf.int32)
        if mask_length > 0:
            start_idx = tf.random.uniform(shape=(), maxval=tf.maximum(1, seq_len - mask_length), dtype=tf.int32)
            quat_features = sequence[:, :4]
            other_features = sequence[:, 4:]
            if tf.shape(other_features)[1] > 0:
                # ones / zeros / ones along time; quaternions are never masked.
                mask = tf.concat([
                    tf.ones([start_idx, tf.shape(other_features)[1]]),
                    tf.zeros([mask_length, tf.shape(other_features)[1]]),
                    tf.ones([seq_len - start_idx - mask_length, tf.shape(other_features)[1]])
                ], axis=0)
                other_features = other_features * mask
                sequence = tf.concat([quat_features, other_features], axis=1)
    return sequence, label

# ===================================================================
# 6. TensorFlow 数据管道（四元数安全版本）
# ===================================================================
def create_tf_dataset(X, y, batch_size, is_training=True, use_augmentation=True, use_mixup=False, mixup_alpha=0.2):
    """Build a batched ``tf.data`` pipeline over variable-length sequences.

    Every sequence is truncated to its first ``MAX_SEQ_LENGTH`` frames and
    zero-padded to that length; labels are emitted one-hot (or as a mix of
    two one-hot vectors when MixUp is active).

    Args:
        X: Sequence of 2-D arrays (frames x ``len(FEATURE_NAMES)``).
        y: Integer class ids aligned with ``X``.
        batch_size: Batch size.
        is_training: If true, shuffle, repeat indefinitely and drop the last
            partial batch; otherwise do a single ordered pass.
        use_augmentation: Apply ``safe_tf_augment`` (training only).
        use_mixup: Mix every batch with a shuffled partner batch (training
            only). Quaternions are mixed with SLERP, the rest linearly.
        mixup_alpha: Parameter of the symmetric Beta distribution from which
            the mixing coefficient is sampled (one draw per batch).

    Returns:
        A prefetched ``tf.data.Dataset`` of ``(sequences, labels)`` batches.
    """
    n_features = len(FEATURE_NAMES)

    def _generate():
        yield from zip(X, y)

    def _truncate(seq, label):
        return seq[:MAX_SEQ_LENGTH], label

    def _one_hot_labels(seq, label):
        return seq, tf.one_hot(tf.cast(label, tf.int32), N_CLASSES)

    dataset = tf.data.Dataset.from_generator(
        _generate,
        output_signature=(
            tf.TensorSpec(shape=(None, n_features), dtype=tf.float32),
            tf.TensorSpec(shape=(), dtype=tf.int32)
        )
    )

    if is_training:
        dataset = dataset.shuffle(buffer_size=len(X)).repeat()
        if use_augmentation:
            dataset = dataset.map(safe_tf_augment, num_parallel_calls=tf.data.AUTOTUNE)

    dataset = dataset.map(_truncate, num_parallel_calls=tf.data.AUTOTUNE)
    # Only training drops the last partial batch.
    dataset = dataset.padded_batch(
        batch_size,
        padded_shapes=([MAX_SEQ_LENGTH, n_features], []),
        padding_values=(0.0, 0),
        drop_remainder=bool(is_training),
    )

    if is_training and use_mixup:
        partner_buffer = len(X) // batch_size if batch_size > 0 else 1
        mixup_partner_ds = dataset.shuffle(buffer_size=partner_buffer)

        def safe_mixup_map(data1, data2):
            (seq1, label1), (seq2, label2) = data1, data2
            lambda_ = tfp.distributions.Beta(mixup_alpha, mixup_alpha).sample()
            # SLERP takes the weight of the second quaternion, hence
            # 1 - lambda; it normalizes its result internally.
            mixed_quat = quaternion_slerp(seq1[:, :, :4], seq2[:, :, :4], 1.0 - lambda_)
            mixed_other = lambda_ * seq1[:, :, 4:] + (1 - lambda_) * seq2[:, :, 4:]
            mixed_seq = tf.concat([mixed_quat, mixed_other], axis=-1)
            label1_oh = tf.one_hot(tf.cast(label1, tf.int32), N_CLASSES)
            label2_oh = tf.one_hot(tf.cast(label2, tf.int32), N_CLASSES)
            mixed_label = lambda_ * label1_oh + (1 - lambda_) * label2_oh
            return mixed_seq, mixed_label

        dataset = tf.data.Dataset.zip((dataset, mixup_partner_ds))
        dataset = dataset.map(safe_mixup_map, num_parallel_calls=tf.data.AUTOTUNE)
    else:
        dataset = dataset.map(_one_hot_labels, num_parallel_calls=tf.data.AUTOTUNE)

    return dataset.prefetch(tf.data.AUTOTUNE)

# ===================================================================
# 7. Keras 模型定义
# ===================================================================
from tensorflow.keras.layers import Layer

class SumPooling1D(Layer):
    """Keras layer that sums its input over axis 1.

    ``compute_output_shape`` reports ``(input_shape[0], input_shape[2])``,
    i.e. a 3-D input is reduced to 2-D.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        return tf.reduce_sum(inputs, axis=1)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[2])

    def get_config(self):
        # The layer has no parameters of its own.
        return super().get_config()

def se_block(input_tensor, reduction=8):
    """Squeeze-and-excitation block: rescale channels by learned gates.

    The input is average-pooled over axis 1, passed through a bottleneck of
    ``channels // reduction`` ReLU units and a sigmoid layer of ``channels``
    units, and the resulting per-channel gates multiply the input.
    """
    n_channels = input_tensor.shape[-1]
    squeezed = tf.keras.layers.GlobalAveragePooling1D()(input_tensor)
    bottleneck = tf.keras.layers.Dense(n_channels // reduction, activation='relu', use_bias=False)(squeezed)
    gates = tf.keras.layers.Dense(n_channels, activation='sigmoid', use_bias=False)(bottleneck)
    # Shape (batch, 1, channels) so the gates broadcast over axis 1.
    gates = tf.keras.layers.Reshape((1, n_channels))(gates)
    return tf.keras.layers.Multiply()([input_tensor, gates])

def residual_se_cnn_block(input_tensor, out_channels, kernel_size, pool_size=2, dropout=0.3, dilation_rate=1):
    """Residual block: two Conv1D+BatchNorm stages, SE gating, pooling, dropout.

    Args:
        input_tensor: Input feature map; its last axis is the channel axis.
        out_channels: Number of filters of both convolutions.
        kernel_size: Kernel size of both convolutions.
        pool_size: Max-pooling window applied after the residual sum.
        dropout: Dropout rate applied last.
        dilation_rate: Dilation of the second convolution only.

    Returns:
        The output tensor of the block.
    """
    layers = tf.keras.layers

    # Shortcut path: project with a 1x1 convolution only if the channel
    # count changes. It is created before the main path, as in the original
    # layer order.
    if input_tensor.shape[-1] == out_channels:
        shortcut = input_tensor
    else:
        shortcut = layers.Conv1D(out_channels, 1, use_bias=False)(input_tensor)
        shortcut = layers.BatchNormalization()(shortcut)

    # Main path.
    main = layers.Conv1D(out_channels, kernel_size, padding='same', use_bias=False)(input_tensor)
    main = layers.BatchNormalization()(main)
    main = layers.ReLU()(main)
    main = layers.Conv1D(out_channels, kernel_size, padding='same', use_bias=False, dilation_rate=dilation_rate)(main)
    main = layers.BatchNormalization()(main)
    main = se_block(main)

    merged = layers.Add()([shortcut, main])
    merged = layers.ReLU()(merged)
    merged = layers.MaxPooling1D(pool_size)(merged)
    return layers.Dropout(dropout)(merged)

def attention_layer(input_tensor):
    """Attention pooling over axis 1.

    Scores come from a layer-normalized copy of the input (one tanh unit per
    position) and are softmax-normalized along axis 1. The weights are applied
    to the original, un-normalized input, which is then summed over axis 1.
    """
    normalized = tf.keras.layers.LayerNormalization(axis=[1, 2])(input_tensor)
    scores = tf.keras.layers.Dense(1, activation='tanh', name='attention_scores')(normalized)
    weights = tf.keras.layers.Softmax(axis=1, name='attention_weights')(scores)
    # Weight the original input, not the normalized copy.
    weighted = tf.keras.layers.Multiply()([input_tensor, weights])
    return SumPooling1D()(weighted)

def build_model(imu_dim=len(FEATURE_NAMES), n_classes=N_CLASSES, max_seq_len=MAX_SEQ_LENGTH):
    """Build the multi-branch SE-CNN + attention classifier.

    The input columns are split into four groups (0-3, 4-7, 8-11 and 12+).
    Each group is processed by two residual SE blocks, the branches are
    concatenated, followed by two more residual SE blocks, attention pooling
    and a dense head. The output layer is linear, i.e. it emits logits.

    NOTE(review): the column slices are Lambda layers wrapping Python
    lambdas. Recent Keras versions may refuse to deserialize such layers from
    a ``.keras`` file unless ``safe_mode=False`` is passed to ``load_model``;
    confirm that loading works in the target environment.

    Args:
        imu_dim: Number of input features per frame.
        n_classes: Number of output classes.
        max_seq_len: Number of frames per input sequence.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    layers = tf.keras.layers

    def _two_stage_branch(features, widths, dropouts):
        # Two stacked residual SE blocks, each halving the time axis.
        stage1 = residual_se_cnn_block(features, out_channels=widths[0], kernel_size=3, pool_size=2, dropout=dropouts[0])
        return residual_se_cnn_block(stage1, out_channels=widths[1], kernel_size=3, pool_size=2, dropout=dropouts[1])

    input_layer = layers.Input(shape=(max_seq_len, imu_dim), name='input_sequence')

    # Column groups of the input.
    rot_features = layers.Lambda(lambda x: x[:, :, 0:4], name='rot_features')(input_layer)
    acc_features = layers.Lambda(lambda x: x[:, :, 4:8], name='acc_features')(input_layer)
    vel_features = layers.Lambda(lambda x: x[:, :, 8:12], name='vel_features')(input_layer)
    other_features = layers.Lambda(lambda x: x[:, :, 12:], name='other_features')(input_layer)

    branches = [
        _two_stage_branch(group, (64, 128), (0.25, 0.3))
        for group in (rot_features, acc_features, vel_features)
    ]
    # The fourth branch exists only if there are columns beyond index 11.
    if other_features.shape[-1] > 0:
        branches.append(_two_stage_branch(other_features, (32, 64), (0.2, 0.35)))
    merged = layers.Concatenate(axis=-1)(branches)

    x = residual_se_cnn_block(merged, out_channels=256, kernel_size=3, pool_size=1, dropout=0.4)
    x = residual_se_cnn_block(x, out_channels=512, kernel_size=5, pool_size=1, dropout=0.45)
    x = attention_layer(x)

    # Dense head: (units, dropout rate) per stage.
    for units, rate in ((128, 0.45), (64, 0.25)):
        x = layers.Dense(units, use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
        x = layers.Dropout(rate)(x)

    output_layer = layers.Dense(n_classes, activation='linear', name='output_logits')(x)
    return tf.keras.Model(inputs=input_layer, outputs=output_layer)

# ===================================================================
# 8. 训练、回调函数与评估
# ===================================================================
sys.path.append('/kaggle/usr/lib/cmi_2025_metric_copy_for_import')
try:
    import cmi_2025_metric_copy_for_import as metric
except ImportError:
    print("无法导入本地评估指标文件，将使用 accuracy 作为替代。")

    def get_competition_score(true, pred):
        """Fallback score: plain accuracy over the class indices."""
        return accuracy_score(true, pred)
else:
    print("成功导入本地评估指标文件。")

    def get_competition_score(true, pred):
        """Score class-index predictions with the imported competition metric."""
        # The metric takes DataFrames of gesture names keyed by an id column.
        true_names = [IDX2LABEL[idx] for idx in true]
        pred_names = [IDX2LABEL[idx] for idx in pred]
        solution = pd.DataFrame({'id': range(len(true_names)), 'gesture': true_names})
        submission = pd.DataFrame({'id': range(len(pred_names)), 'gesture': pred_names})
        return metric.score(solution, submission, 'id')

class CompetitionScoreCallback(tf.keras.callbacks.Callback):
    """Compute the competition score on the validation set after each epoch.

    The score is printed and stored in the epoch logs under ``'val_score'``
    so that callbacks running after this one can monitor it.
    """

    def __init__(self, validation_data):
        """Cache the validation dataset and its true class indices.

        Args:
            validation_data: Iterable of ``(inputs, one_hot_labels)`` batches.
                It is iterated once here to collect the labels.
        """
        super().__init__()
        self.val_data = validation_data
        one_hot_labels = np.concatenate([y for _, y in validation_data], axis=0)
        self.val_labels = np.argmax(one_hot_labels, axis=1)

    def on_epoch_end(self, epoch, logs=None):
        """Score the current model and record the result in ``logs``."""
        val_preds = self.model.predict(self.val_data, verbose=0)
        val_preds = np.argmax(val_preds, axis=1)
        score = get_competition_score(self.val_labels, val_preds)
        print(f" - val_score: {score:.4f}", end="")
        # The signature allows logs=None, so guard the write instead of
        # raising TypeError on a None subscript.
        if logs is not None:
            logs['val_score'] = score

def plot_training_history(history, fold):
    """Plot loss and accuracy curves of one fold, save them as PNG and show them.

    Args:
        history: Object whose ``history`` dict holds the ``loss``,
            ``val_loss``, ``categorical_accuracy`` and
            ``val_categorical_accuracy`` series.
        fold: Fold number used in the titles and in the file name.
    """
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    # (train key, val key, train legend, val legend, title) for each panel.
    panels = (
        ('loss', 'val_loss', 'Train Loss', 'Val Loss', f'Fold {fold} - Loss'),
        ('categorical_accuracy', 'val_categorical_accuracy', 'Train Accuracy', 'Val Accuracy', f'Fold {fold} - Accuracy'),
    )
    for ax, (train_key, val_key, train_label, val_label, title) in zip(axes, panels):
        ax.plot(history.history[train_key], label=train_label)
        ax.plot(history.history[val_key], label=val_label)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.legend()
    plt.tight_layout()
    plt.savefig(f'{OUTPUT_DIR}/fold_{fold}_training_history.png')
    plt.show()

class WarmupCosine(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Learning-rate schedule: linear warmup followed by cosine decay.

    While ``step < warmup_steps`` the rate is ``base_lr * step / warmup_steps``.
    Afterwards it follows half a cosine period from ``base_lr`` down to
    ``min_lr``; once ``total_steps`` is reached it stays at ``min_lr``.
    """

    def __init__(self, base_lr, warmup_steps, total_steps, min_lr=1e-5):
        super().__init__()
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr

    def __call__(self, step):
        """Return the learning rate for optimizer step ``step``."""
        step = tf.cast(step, tf.float32)
        warmup_lr = self.base_lr * (step / tf.cast(self.warmup_steps, tf.float32))
        # Length of the decay phase, floored at 1 so the division below is safe.
        decay_span = tf.maximum(1.0, self.total_steps - self.warmup_steps)
        # Completed fraction of the decay phase, clipped to [0, 1].
        decay_fraction = tf.clip_by_value((step - self.warmup_steps) / decay_span, 0.0, 1.0)
        cosine_lr = self.min_lr + 0.5 * (self.base_lr - self.min_lr) * (1 + tf.cos(np.pi * decay_fraction))
        return tf.where(step < self.warmup_steps, warmup_lr, cosine_lr)

    def get_config(self):
        """Return the constructor arguments as a dict (used for serialization)."""
        return {
            "base_lr": self.base_lr,
            "warmup_steps": self.warmup_steps,
            "total_steps": self.total_steps,
            "min_lr": self.min_lr,
        }

# ===================================================================
# 9. 数据验证函数
# ===================================================================
def check_quaternion_norm(sequences, tolerance=0.01):
    """Summarize the norms of the quaternions stored in feature columns 0-3.

    Args:
        sequences: Either a list of 2-D arrays (frames x features) or a single
            2-D / 3-D array whose last axis holds the features.
        tolerance: Largest allowed absolute deviation of a norm from 1.0.

    Returns:
        ``(is_valid, stats)`` where ``stats`` holds the ``mean``, ``std``,
        ``min`` and ``max`` of all norms plus the ``is_valid`` flag, which is
        true only if every norm is within ``tolerance`` of 1.0.
    """
    if isinstance(sequences, list):
        per_sequence_norms = [np.linalg.norm(seq[:, :4], axis=1) for seq in sequences]
        norms = np.concatenate(per_sequence_norms)
    else:
        if len(sequences.shape) == 3:
            quaternions = sequences[:, :, :4]
        else:
            quaternions = sequences[:, :4]
        norms = np.linalg.norm(quaternions, axis=-1).flatten()

    within_tolerance = np.all(np.abs(norms - 1.0) < tolerance)
    stats = {
        'mean': np.mean(norms),
        'std': np.std(norms),
        'min': np.min(norms),
        'max': np.max(norms),
        'is_valid': within_tolerance,
    }
    return within_tolerance, stats

# ===================================================================
# 10. K-Fold 交叉验证训练流程
# ===================================================================
def run_kfold_training(sequences, labels, groups, n_folds=5, **kwargs):
    """Train one model per StratifiedGroupKFold split and collect validation metrics.

    For every fold the function scales the data, saves the fitted scaler,
    builds a fresh model, trains it with a WarmupCosine schedule and AdamW,
    checkpoints the best epoch by ``val_score`` and finally evaluates the
    model on the validation split (accuracy, competition score and a
    per-class classification report).

    Args:
        sequences: per-sequence feature arrays; the first four feature columns
            are read as the quaternion by ``check_quaternion_norm``.
        labels: one class label per sequence (any array-like; indexed
            positionally with the fold indices).
        groups: group ids forwarded to ``StratifiedGroupKFold.split``.
        n_folds: number of cross-validation folds.
        **kwargs: training hyper-parameters. Required keys: ``batch_size``,
            ``num_epochs``, ``learning_rate``, ``label_smoothing``,
            ``patience``, ``use_mixup``. Optional: ``mixup_alpha``
            (defaults to 0.2).

    Returns:
        A list with one dict per trained fold:
        ``{'fold', 'val_accuracy', 'val_score'}``. Folds that were skipped
        because they contain fewer training samples than ``batch_size`` are
        not included.

    Raises:
        KeyError: if a required hyper-parameter is missing from ``kwargs``.
    """
    print(f"🚀 开始 {n_folds}-Fold 训练（简化版 - 仅修复零四元数）...")
    print("=" * 60)
    print("\n📊 验证数据的四元数范数（已修复零四元数）...")
    _, stats = check_quaternion_norm(sequences)
    print(f"   平均范数: {stats['mean']:.4f} (±{stats['std']:.4f})")
    print(f"   范数范围: [{stats['min']:.4f}, {stats['max']:.4f}]")

    kfold_results = []
    sgkf = StratifiedGroupKFold(n_splits=n_folds, shuffle=True, random_state=42)
    sequences_np = np.array(sequences, dtype=object)
    # Guarantee positional fancy indexing even when labels arrive as a list
    # or a pandas Series.
    labels = np.asarray(labels)

    # Fixed class ids keep the classification report aligned with the label
    # names even when a fold happens to miss a class.
    class_ids = list(IDX2LABEL.keys())
    class_names = [IDX2LABEL[class_id] for class_id in class_ids]

    for fold_num, (train_idx, val_idx) in enumerate(sgkf.split(sequences_np, labels, groups), 1):
        print(f"\n{'='*60}\n🔄 训练 Fold {fold_num}/{n_folds}\n{'='*60}")
        X_train_raw, X_val_raw = sequences_np[train_idx], sequences_np[val_idx]
        y_train, y_val = labels[train_idx], labels[val_idx]

        print(f"📊 Fold {fold_num}: 标准化数据（四元数保持原始值）...")
        X_train_scaled, X_val_scaled, scaler = standardize_sequences_with_quaternion(X_train_raw.tolist(), X_val_raw.tolist(), fit_on_train=True)
        _, stats_train = check_quaternion_norm(X_train_scaled)
        print(f"   训练集四元数范数: {stats_train['mean']:.4f} (±{stats_train['std']:.4f})")
        # The scaler is persisted per fold so inference can reuse exactly the
        # statistics this fold's model was trained with.
        scaler_path = f'{OUTPUT_DIR}/fold_{fold_num}_scaler.joblib'
        joblib.dump(scaler, scaler_path)
        print(f"💾 保存 Fold {fold_num} 的 StandardScaler 到 {scaler_path}")

        # Bail out before building datasets/model when the fold cannot fill a
        # single batch; fit() is driven by an explicit steps_per_epoch.
        steps_per_epoch = len(X_train_scaled) // kwargs['batch_size']
        if steps_per_epoch == 0:
            print(f"错误: Fold {fold_num} 的训练样本数小于batch_size")
            continue

        train_ds = create_tf_dataset(X_train_scaled, y_train, kwargs['batch_size'], is_training=True, use_augmentation=True, use_mixup=kwargs['use_mixup'], mixup_alpha=kwargs.get('mixup_alpha', 0.2))
        val_ds = create_tf_dataset(X_val_scaled, y_val, kwargs['batch_size'], is_training=False)

        # Drop the previous fold's Keras state before building a new model.
        tf.keras.backend.clear_session()
        model = build_model()

        total_steps = steps_per_epoch * kwargs['num_epochs']
        # The first 5% of all optimizer steps are spent on linear warmup.
        lr_sched = WarmupCosine(base_lr=kwargs['learning_rate'], warmup_steps=int(0.05 * total_steps), total_steps=total_steps, min_lr=1e-5)
        optimizer = tf.keras.optimizers.AdamW(learning_rate=lr_sched, weight_decay=5e-3)
        loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=kwargs['label_smoothing'])
        model.compile(optimizer=optimizer, loss=loss_fn, metrics=['categorical_accuracy'])

        model_save_path = f'{OUTPUT_DIR}/fold_{fold_num}_model.keras'
        # CompetitionScoreCallback is listed first; presumably it is what puts
        # 'val_score' into the epoch logs that the two callbacks below monitor
        # -- TODO confirm against its definition.
        callbacks = [
            CompetitionScoreCallback(validation_data=val_ds),
            tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path, save_best_only=True, monitor='val_score', mode='max', verbose=1),
            tf.keras.callbacks.EarlyStopping(monitor='val_score', patience=kwargs['patience'], mode='max', verbose=1, restore_best_weights=True),
        ]

        print(f"🏋️ 开始训练 Fold {fold_num}...")
        history = model.fit(train_ds, epochs=kwargs['num_epochs'], validation_data=val_ds, callbacks=callbacks, steps_per_epoch=steps_per_epoch, verbose=2)
        plot_training_history(history, fold_num)

        print(f"\n🎯 Fold {fold_num} 最终评估:")
        # Ground truth is read back from the dataset; the argmax assumes
        # val_ds yields one-hot encoded labels.
        y_val_unbatched = np.concatenate([y for x, y in val_ds], axis=0)
        y_val_labels = np.argmax(y_val_unbatched, axis=1)
        val_preds_final = model.predict(val_ds)
        val_preds_final_labels = np.argmax(val_preds_final, axis=1)
        y_val_labels = y_val_labels[:len(val_preds_final_labels)]

        final_score = get_competition_score(y_val_labels, val_preds_final_labels)
        final_acc = np.mean(val_preds_final_labels == y_val_labels)
        print(f"       验证集 Accuracy: {final_acc:.4f}")
        print(f"       验证集 Score: {final_score:.4f}")

        print(f"\n📊 Fold {fold_num} - 分类报告:")
        # Passing labels= explicitly prevents a ValueError when the fold does
        # not contain (or predict) every class named in target_names.
        print(classification_report(
            y_val_labels,
            val_preds_final_labels,
            labels=class_ids,
            target_names=class_names,
            digits=4,
            zero_division=0
        ))

        # Plain floats keep the results JSON-serializable for the caller.
        kfold_results.append({'fold': fold_num, 'val_accuracy': float(final_acc), 'val_score': float(final_score)})

    return kfold_results

# ===================================================================
# 11. 测试集预测函数（简化版 - 不归一化四元数）
# ===================================================================
def predict_test_data(test_sequences_raw, output_dir=OUTPUT_DIR, n_folds=10):  # default changed to 10
    """Predict class indices for test sequences with an ensemble of fold models.

    For each fold the saved scaler is applied to every feature column after
    the first four (the quaternion columns are passed through untouched), the
    saved model is loaded, and its raw outputs are collected. The outputs of
    all folds are averaged and the arg-max class is taken.

    Args:
        test_sequences_raw: sequence of 2-D arrays ``(timesteps, features)``
            whose first four columns are the quaternion.
        output_dir: directory containing ``fold_<k>_scaler.joblib`` and
            ``fold_<k>_model.keras`` for ``k`` in ``1..n_folds``.
        n_folds: number of fold models to ensemble.

    Returns:
        ``(final_predictions, ensemble_predictions)``: the arg-max class index
        per sequence and the fold-averaged model outputs.
    """
    print(f"🔮 开始测试集预测 ({n_folds}-fold ensemble)...")
    fold_outputs = []

    for fold_num in range(1, n_folds + 1):
        print(f"\n处理 Fold {fold_num}...")

        scaler_path = f'{output_dir}/fold_{fold_num}_scaler.joblib'
        scaler = joblib.load(scaler_path)
        print(f"   加载scaler: {scaler_path}")

        # Quaternion columns are used as-is (no normalisation); only the
        # remaining columns go through this fold's scaler, stacked into one
        # flat matrix so a single transform call covers every timestep.
        quat_parts = [seq[:, :4] for seq in test_sequences_raw]
        stacked_other = np.vstack([seq[:, 4:] for seq in test_sequences_raw])
        scaled_flat = scaler.transform(stacked_other)

        # Cut the flat scaled matrix back into per-sequence pieces and
        # re-attach the quaternion columns in front.
        scaled_sequences = []
        offset = 0
        for quat, raw_seq in zip(quat_parts, test_sequences_raw):
            stop = offset + len(raw_seq)
            scaled_sequences.append(np.concatenate([quat, scaled_flat[offset:stop]], axis=1))
            offset = stop

        model_path = f'{output_dir}/fold_{fold_num}_model.keras'
        model = tf.keras.models.load_model(
            model_path,
            custom_objects={'WarmupCosine': WarmupCosine, 'SumPooling1D': SumPooling1D},
            compile=False,
        )
        print(f"   加载模型: {model_path}")

        # create_tf_dataset requires labels; zeros serve as placeholders.
        placeholder_labels = np.zeros(len(scaled_sequences), dtype=np.int32)
        test_ds = create_tf_dataset(scaled_sequences, placeholder_labels, batch_size=64, is_training=False)
        fold_outputs.append(model.predict(test_ds, verbose=0))
        print(f"   完成Fold {fold_num}的预测")

    ensemble_predictions = np.mean(fold_outputs, axis=0)
    final_predictions = np.argmax(ensemble_predictions, axis=1)
    print(f"\n✅ 测试集预测完成！")
    return final_predictions, ensemble_predictions

# ===================================================================
# 12. 主执行逻辑 - 10-FOLD版本
# ===================================================================
if __name__ == '__main__':
    # Hyper-parameters forwarded to run_kfold_training through **kwargs.
    training_params = {
        'num_epochs': 3 if DEBUG else 100,
        'learning_rate': 0.001,
        'patience': 24,
        'batch_size': 64,
        'label_smoothing': 0.1,
        'use_mixup': True,
        'mixup_alpha': 0.4
    }

    if TRAIN:
        # Run 10-fold cross-validation.
        results = run_kfold_training(sequences, labels, groups, n_folds=10, **training_params)
        print("\n\n" + "="*60 + "\n🎉 10-Fold 训练总结\n" + "="*60)
        if results:
            val_scores = [r['val_score'] for r in results]
            val_accs = [r['val_accuracy'] for r in results]
            print(f"平均验证集 Score: {np.mean(val_scores):.4f} ± {np.std(val_scores):.4f}")
            print(f"平均验证集 Accuracy: {np.mean(val_accs):.4f} ± {np.std(val_accs):.4f}")
            # Report the fold number stored in the result record: training may
            # skip folds, so "position in the list + 1" can name the wrong fold.
            best_fold = results[int(np.argmax(val_scores))]['fold']
            worst_fold = results[int(np.argmin(val_scores))]['fold']
            print(f"最佳折 Score: {np.max(val_scores):.4f} (Fold {best_fold})")
            print(f"最差折 Score: {np.min(val_scores):.4f} (Fold {worst_fold})")

            # Persist the metrics together with the feature selection that
            # produced them.
            feature_selection_info = {'all_features': ALL_FEATURE_NAMES, 'selected_features': FEATURE_NAMES, 'feature_indices': FEATURE_INDICES, 'num_features': len(FEATURE_NAMES), 'quaternion_safe': True, 'n_folds': 10}
            results_with_features = {'kfold_results': results, 'feature_selection': feature_selection_info, 'preprocessing_note': '10-fold CV版本：仅修复零四元数，不进行阈值归一化', 'lr_schedule': 'WarmupCosine'}
            with open(f'{OUTPUT_DIR}/kfold_results_10fold.json', 'w') as f:
                json.dump(results_with_features, f, indent=2)

    # Summary notes, printed whether or not training ran.
    print("\n📝 重要说明:")
    print("1. 使用10-fold交叉验证提高模型稳定性")
    print("2. 自动检测并修复零四元数（填充为单位四元数）")
    print("3. 四元数在标准化时保持原始值，不进行归一化")
    print("4. 数据增强后对四元数进行必要的归一化")
    print("5. MixUp使用SLERP进行四元数插值")
    print("6. 标准化仅应用于非四元数特征")
    print("7. 每折训练后会打印详细的分类报告")

# ===================================================================
# 13. 生成提交文件 - 10-FOLD版本
# ===================================================================
def generate_submission(submission_output_path='submission.csv'):
    """Predict the preprocessed test set with the 10-fold ensemble and write a CSV.

    Loads ``processed_test_data_raw.joblib`` from ``PROCESSED_DATA_DIR``,
    keeps only the columns listed in ``FEATURE_INDICES``, repairs zero
    quaternions, runs ``predict_test_data`` and maps the predicted class
    indices to gesture names through ``IDX2LABEL``.

    Args:
        submission_output_path: destination path of the CSV file.

    Returns:
        The submission DataFrame with columns ``sequence_id`` and ``gesture``.
    """
    rule = "=" * 60
    print("\n" + rule + "\n📄 生成提交文件（10-fold ensemble）" + "\n" + rule)

    print("加载预处理的测试数据...")
    test_data = joblib.load(f'{PROCESSED_DATA_DIR}/processed_test_data_raw.joblib')

    # Restrict every sequence to the feature columns used for training.
    selected_sequences = []
    for full_seq in test_data['sequence'].tolist():
        selected_sequences.append(full_seq[:, FEATURE_INDICES])

    # Repair zero quaternions in the test set.
    print("\n处理测试集数据...")
    selected_sequences = fix_zero_quaternions(selected_sequences)

    sequence_ids = test_data['sequence_id'].values

    # Predict with the 10-fold ensemble; the averaged outputs are not needed here.
    class_indices, _ = predict_test_data(selected_sequences, OUTPUT_DIR, n_folds=10)
    gesture_names = [IDX2LABEL[class_index] for class_index in class_indices]

    submission_df = pd.DataFrame({'sequence_id': sequence_ids, 'gesture': gesture_names})
    submission_df.to_csv(submission_output_path, index=False)
    print(f"✅ 提交文件已保存到: {submission_output_path}")

    print("\n📊 预测分布:")
    print(submission_df['gesture'].value_counts())
    return submission_df