# All-Sensor Model

All-sensor model 1

In [None]:
import os
import torch
import kagglehub
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, Subset
from tqdm.notebook import tqdm
from torch.amp import autocast
import pandas as pd
import polars as pl
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from scipy.spatial.transform import Rotation as R
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from collections import defaultdict
from sklearn.utils.class_weight import compute_class_weight
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import StratifiedGroupKFold

def remove_gravity_from_acc(acc_data, rot_data):
    """Subtract gravity from raw accelerometer readings using the device orientation.

    For every sample the world-frame gravity vector ``[0, 0, 9.81]`` is rotated
    into the sensor frame with the inverse of that sample's quaternion and then
    subtracted from the measured acceleration.

    Args:
        acc_data: DataFrame with ``acc_x``, ``acc_y``, ``acc_z`` columns, or an
            array of shape (n, 3).
        rot_data: DataFrame with ``rot_x``, ``rot_y``, ``rot_z``, ``rot_w``
            columns, or an array of shape (n, 4) in (x, y, z, w) order, which is
            the order ``scipy`` ``Rotation.from_quat`` expects by default.

    Returns:
        Float ndarray of shape (n, 3). Rows whose quaternion is all-NaN or all
        (close to) zero keep the raw acceleration unchanged.
    """
    if isinstance(acc_data, pd.DataFrame):
        acc_values = acc_data[['acc_x', 'acc_y', 'acc_z']].values
    else:
        acc_values = acc_data
    if isinstance(rot_data, pd.DataFrame):
        quat_values = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
    else:
        quat_values = rot_data
    acc_values = np.asarray(acc_values)
    quat_values = np.asarray(quat_values, dtype=float)
    gravity_world = np.array([0, 0, 9.81])

    # Keep the precision of float inputs, but never compute in an integer buffer:
    # that would silently truncate the gravity-compensated result.
    out_dtype = acc_values.dtype if np.issubdtype(acc_values.dtype, np.floating) else np.float64
    # Start from the raw signal so samples with unusable orientation pass through untouched.
    linear_accel = acc_values.astype(out_dtype, copy=True)

    unusable = (np.all(np.isnan(quat_values), axis=1)
                | np.all(np.isclose(quat_values, 0), axis=1))
    valid_idx = np.flatnonzero(~unusable)
    if valid_idx.size == 0:
        return linear_accel

    def gravity_in_sensor_frame(idx):
        return R.from_quat(quat_values[idx]).apply(gravity_world, inverse=True)

    try:
        # One batched scipy call instead of one Rotation object per sample.
        linear_accel[valid_idx] -= gravity_in_sensor_frame(valid_idx)
    except ValueError:
        # Defensive fallback: redo row by row so that only the offending samples
        # keep their raw acceleration (the batch raised before any assignment).
        for i in valid_idx:
            try:
                linear_accel[i] -= gravity_in_sensor_frame(i)
            except ValueError:
                pass
    return linear_accel

def calculate_angular_velocity_from_quat(rot_data, time_delta=1/200):  # Assuming 200Hz sampling rate
    """Estimate angular velocity from consecutive orientation quaternions.

    For each sample ``i`` the relative rotation ``q_i^-1 * q_{i+1}`` is
    converted to a rotation vector and divided by ``time_delta``.

    Args:
        rot_data: DataFrame with ``rot_x``, ``rot_y``, ``rot_z``, ``rot_w`` columns,
            or an array of shape (n, 4) in (x, y, z, w) order.
        time_delta: Seconds between consecutive samples (default 1/200).

    Returns:
        ndarray of shape (n, 3). The last row, and rows whose pair contains an
        all-NaN or all-zero quaternion, are 0.
    """
    if isinstance(rot_data, pd.DataFrame):
        quat_values = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
    else:
        quat_values = rot_data
    quat_values = np.asarray(quat_values, dtype=float)
    num_samples = quat_values.shape[0]
    angular_vel = np.zeros((num_samples, 3))
    if num_samples < 2:
        return angular_vel

    usable = ~(np.all(np.isnan(quat_values), axis=1)
               | np.all(np.isclose(quat_values, 0), axis=1))
    # A pair (i, i+1) is only usable when both of its quaternions are.
    pair_idx = np.flatnonzero(usable[:-1] & usable[1:])
    if pair_idx.size == 0:
        return angular_vel

    def relative_rotvec(idx):
        delta_rot = R.from_quat(quat_values[idx]).inv() * R.from_quat(quat_values[idx + 1])
        return delta_rot.as_rotvec()

    try:
        # Batched scipy call instead of two Rotation objects per sample.
        angular_vel[pair_idx] = relative_rotvec(pair_idx) / time_delta
    except ValueError:
        # Defensive fallback mirroring the per-sample behaviour: invalid pairs stay 0.
        for i in pair_idx:
            try:
                angular_vel[i] = relative_rotvec(i) / time_delta
            except ValueError:
                pass
    return angular_vel

def calculate_angular_distance(rot_data):
    """Return the rotation angle (radians) between consecutive orientations.

    Args:
        rot_data: DataFrame with ``rot_x``, ``rot_y``, ``rot_z``, ``rot_w`` columns,
            or an array of shape (n, 4) in (x, y, z, w) order.

    Returns:
        ndarray of shape (n,). Entry ``i`` is the norm of the rotation vector of
        ``q_i^-1 * q_{i+1}``. The last entry, and entries whose pair contains an
        all-NaN or all-zero (invalid) quaternion, are 0.
    """
    if isinstance(rot_data, pd.DataFrame):
        quat_values = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
    else:
        quat_values = rot_data
    quat_values = np.asarray(quat_values, dtype=float)
    num_samples = quat_values.shape[0]
    angular_dist = np.zeros(num_samples)
    if num_samples < 2:
        return angular_dist

    usable = ~(np.all(np.isnan(quat_values), axis=1)
               | np.all(np.isclose(quat_values, 0), axis=1))
    pair_idx = np.flatnonzero(usable[:-1] & usable[1:])
    if pair_idx.size == 0:
        return angular_dist

    def relative_angle(idx):
        relative_rotation = R.from_quat(quat_values[idx]).inv() * R.from_quat(quat_values[idx + 1])
        return np.linalg.norm(relative_rotation.as_rotvec(), axis=-1)

    try:
        angular_dist[pair_idx] = relative_angle(pair_idx)
    except ValueError:
        # Defensive fallback: in case of invalid quaternions the pair stays 0.
        for i in pair_idx:
            try:
                angular_dist[i] = relative_angle(i)
            except ValueError:
                pass
    return angular_dist

class CMIFeDataset(Dataset):
    """Per-sequence dataset of IMU, thermopile (THM) and time-of-flight (ToF) features.

    Construction reads the CSV at ``data_path``, derives engineered features
    (skipped when they are already columns of the file), turns every
    ``sequence_id`` into one sample, scales the features and pads/truncates each
    sequence to a percentile length.

    ``config`` is a plain dict. ``config["nan_ratio"]`` and ``config["fbfill"]``
    are indexed directly (each with ``"imu"``, ``"thm"`` and ``"tof"`` keys);
    every other switch is optional and read with ``config.get``.

    ``__getitem__`` returns ``[imu, thm, tof, one_hot_class, is_target]``,
    followed (depending on config) by ``(idx, fold_info)``, demographics and the
    flipped IMU tensor.
    """

    def __init__(self, data_path, config):
        self.config = config
        self.init_feature_names(data_path)
        # Only read the columns this configuration can use and the file actually has.
        df = self.generate_features(pd.read_csv(data_path, usecols=set(self.use_cols) & set(self.raw_columns)))
        self.generate_dataset(df)

    def init_feature_names(self, data_path):
        """Derive every column list, channel grouping and index map from ``config``.

        Only the CSV header of ``data_path`` is read (``nrows=0``).

        Raises:
            ValueError: if ``split_imu_feat`` is combined with ``old_imu_feat``,
                or ``tof_mode`` is neither ``"stats"`` nor convertible with ``int``.
        """
        self.target_gestures = [
            'Above ear - pull hair',
            'Cheek - pinch skin',
            'Eyebrow - pull hair',
            'Eyelash - pull hair',
            'Forehead - pull hairline',
            'Forehead - scratch',
            'Neck - pinch skin',
            'Neck - scratch',
        ]
        self.non_target_gestures = [
            'Write name on leg',
            'Wave hello',
            'Glasses on/off',
            'Text on phone',
            'Write name in air',
            'Feel around in tray and pull out an object',
            'Scratch knee/leg skin',
            'Pull air toward your face',
            'Drink from bottle/cup',
            'Pinch knee/leg skin'
        ]

        self.acc_features = ['acc_mag', 'acc_mag_jerk', 'linear_acc_mag', 'linear_acc_mag_jerk']
        self.rot_features = ['rot_angle', 'rot_angle_vel', 'angular_vel_x', 'angular_vel_y', 'angular_vel_z', 'angular_distance']
        self.old_imu_features = [
            'acc_mag', 'rot_angle', 'acc_mag_jerk', 'rot_angle_vel',
            'linear_acc_mag', 'linear_acc_mag_jerk',
            'angular_vel_x', 'angular_vel_y', 'angular_vel_z', 'angular_distance'
        ]

        self.extra_imu_features = self.config.get("imu_feats", [])
        self.imu_features = self.extra_imu_features.copy()
        if self.config.get("add_imu_feat_default", True):
            if self.config.get("old_imu_feat", True):
                self.imu_features.extend(self.old_imu_features)
            else:
                self.imu_features.extend(self.acc_features)
                self.imu_features.extend(self.rot_features)
        # Cross-axis energy features; the attribute names keep their historical
        # spelling ("fearues") because code outside this class may reference them.
        self.er1_fearues = ["er_x", "er_y", "er_z"]
        self.er2_fearues = ['er_r_xy', 'er_r_xz', 'er_r_yz', 'er_c_xy', 'er_c_xz', 'er_c_yz']
        self.er_fearues = self.er1_fearues + self.er2_fearues
        # The default "stats" means "per-sensor statistics only", i.e. mode 1.
        # Keeping the raw string used to crash on the integer comparisons below.
        tof_mode = self.config.get("tof_mode", "stats")
        self.tof_mode = 1 if tof_mode == "stats" else int(tof_mode)
        self.tof_region_stats = ['mean', 'std', 'min', 'max']
        self.tof_cols = self.generate_tof_feature_names()

        self.raw_columns = pd.read_csv(data_path, nrows=0).columns.tolist()
        if self.config.get("add_raw_acc", False):
            self.imu_acc_cols_base = ['acc_x', 'acc_y', 'acc_z', 'linear_acc_x', 'linear_acc_y', 'linear_acc_z']
        else:
            self.imu_acc_cols_base = ['linear_acc_x', 'linear_acc_y', 'linear_acc_z']
        self.imu_rot_cols_base = ['rot_w', 'rot_x', 'rot_y', 'rot_z']
        self.imu_cols_base = self.imu_acc_cols_base + self.imu_rot_cols_base
        self.imu_cols = list()
        self.imu_channel_keys = defaultdict(list)
        if self.config.get("add_imu_base", True):
            self.imu_cols.extend(self.imu_cols_base)
            self.imu_channel_keys["acc"] = self.imu_acc_cols_base
            self.imu_channel_keys["rot"] = self.imu_rot_cols_base
        if self.config.get("add_imu_feats", True):
            self.imu_cols.extend(self.imu_features)
            if self.config.get("split_imu_feat", False):
                if self.config.get("old_imu_feat", True):
                    raise ValueError("split_imu_feat=True and old_imu_feat=True not supported")
                self.imu_channel_keys["acc_feat"] = self.acc_features
                self.imu_channel_keys["rot_feat"] = self.rot_features
            else:
                if self.config.get("old_imu_feat", True):
                    self.imu_channel_keys["other"].extend(self.old_imu_features)
                else:
                    self.imu_channel_keys["other"].extend(self.acc_features)
                    self.imu_channel_keys["other"].extend(self.rot_features)
        if self.config.get("add_imu_er_feats", False):
            self.imu_cols.extend(self.er_fearues)
            if self.config.get("split_imu_feat", False):
                self.imu_channel_keys["er1_feat"] = self.er1_fearues
                self.imu_channel_keys["er2_feat"] = self.er2_fearues
            else:
                self.imu_channel_keys["other"].extend(self.er1_fearues)
                self.imu_channel_keys["other"].extend(self.er2_fearues)
        self.flip_imu_cols = [f"{col}_flip" for col in self.imu_cols]
        self.imu_channel_keys = {k: sorted(v) for k, v in self.imu_channel_keys.items()}
        self.thm_cols = [c for c in self.raw_columns if c.startswith('thm_')]
        self.thm_channel_keys = {k: [f"thm_{k}"] for k in range(1, 6)}
        self.feature_cols = self.imu_cols + self.thm_cols + self.tof_cols
        self.imu_dim = len(self.imu_cols)
        self.thm_dim = len(self.thm_cols)
        self.tof_dim = len(self.tof_cols)
        self.base_cols = ['acc_x', 'acc_y', 'acc_z',
                          'rot_x', 'rot_y', 'rot_z', 'rot_w',
                          'sequence_id', 'subject',
                          'sequence_type', 'gesture', 'orientation'] + [c for c in self.raw_columns if c.startswith('thm_')] + [f"tof_{i}_v{p}" for i in range(1, 6) for p in range(64)]
        self.use_cols = self.base_cols + self.feature_cols
        if self.config.get("return_flip_imu", False):
            self.use_cols.extend(self.flip_imu_cols)
        self.fold_cols = ['subject', 'sequence_type', 'gesture', 'orientation', 'sequence_id']
        if self.config.get("use_dg", False):
            self.dg_cols = ['adult_child', 'age', 'sex', 'handedness', 'shoulder_to_wrist_height', 'elbow_to_wrist_height']
        # Column positions of each channel group inside the imu/thm/tof feature blocks (used by split5).
        self.global_imu_indices = {k: sorted([self.imu_cols.index(feat) for feat in feats]) for k, feats in self.imu_channel_keys.items()}
        self.global_thm_indices = {k: sorted([self.thm_cols.index(key) for key in self.thm_channel_keys[k]]) for k in range(1, 6)}
        self.global_tof_indices = {k: sorted([self.tof_cols.index(key) for key in self.tof_channel_keys[k]]) for k in range(1, 6)}

    def _tof_region_modes(self):
        """Return the region split counts used for ToF statistics under ``tof_mode``."""
        if self.tof_mode == -1:
            return [2, 4, 8, 16, 32]
        if self.tof_mode > 1:
            return [self.tof_mode]
        return []

    def generate_tof_feature_names(self):
        """Return the ToF feature column names and fill ``self.tof_channel_keys``.

        Raw pixels ``tof_{i}_v{p}`` come first when ``tof_raw`` is set.
        Then, unless ``tof_mode == 0``, each sensor gets its overall
        statistics followed by per-region statistics for each mode from
        ``_tof_region_modes``.
        """
        features = list()
        self.tof_channel_keys = defaultdict(list)
        if self.config.get("tof_raw", False):
            for i in range(1, 6):
                raw_pixels = [f"tof_{i}_v{p}" for p in range(64)]
                features.extend(raw_pixels)
                self.tof_channel_keys[i].extend(raw_pixels)
        if self.tof_mode != 0:
            region_modes = self._tof_region_modes()
            for i in range(1, 6):
                names = [f'tof_{i}_{stat}' for stat in self.tof_region_stats]
                for mode in region_modes:
                    for r in range(mode):
                        names.extend(f'tof{mode}_{i}_region_{r}_{stat}' for stat in self.tof_region_stats)
                features.extend(names)
                self.tof_channel_keys[i].extend(names)
        return features

    def _compute_tof_stat_columns(self, df):
        """Return a DataFrame (aligned to ``df.index``) of ToF statistics.

        Pixel value -1 is treated as missing before taking mean/std/min/max.
        Column names and order match ``generate_tof_feature_names``, excluding
        the raw pixel columns.
        """
        def stats(frame, prefix):
            return {
                f'{prefix}_mean': frame.mean(axis=1),
                f'{prefix}_std': frame.std(axis=1),
                f'{prefix}_min': frame.min(axis=1),
                f'{prefix}_max': frame.max(axis=1),
            }

        region_modes = self._tof_region_modes()
        new_columns = {}
        for i in range(1, 6):
            pixel_cols = [f"tof_{i}_v{p}" for p in range(64)]
            tof_data = df[pixel_cols].replace(-1, np.nan)
            new_columns.update(stats(tof_data, f'tof_{i}'))
            for mode in region_modes:
                # Regions are contiguous runs of pixel columns, in column order.
                region_size = 64 // mode
                for r in range(mode):
                    region_data = tof_data.iloc[:, r * region_size:(r + 1) * region_size]
                    new_columns.update(stats(region_data, f'tof{mode}_{i}_region_{r}'))
        return pd.DataFrame(new_columns)

    def _wants_er_features(self):
        """Return True when any cross-axis energy feature is part of ``imu_cols``."""
        return any(col in self.imu_cols for col in self.er_fearues)

    def compute_cross_axis_energy(self, df):
        """Compute FFT-based energy features of ``acc_x/y/z`` for one sequence.

        Returns a dict with per-axis spectral energy ``er_<axis>``, the energy
        ratios ``er_r_<a><b>`` and the correlation ``er_c_<a><b>`` of the
        magnitude spectra, restricted to the names in ``self.er_fearues``.
        """
        axes = ['x', 'y', 'z']
        axis_pairs = [(a, b) for i, a in enumerate(axes) for b in axes[i + 1:]]
        # One FFT per axis (previously undefined `fft`, and recomputed for the correlations).
        spectra = {axis: np.abs(np.fft.fft(df[f'acc_{axis}'].values)) for axis in axes}
        features = {f"er_{axis}": np.sum(spectra[axis] ** 2) for axis in axes}
        for axis1, axis2 in axis_pairs:
            features[f'er_r_{axis1}{axis2}'] = features[f'er_{axis1}'] / (features[f'er_{axis2}'] + 1e-6)
        for axis1, axis2 in axis_pairs:
            features[f'er_c_{axis1}{axis2}'] = np.corrcoef(spectra[axis1], spectra[axis2])[0, 1]
        return {k: v for k, v in features.items() if k in self.er_fearues}

    def _add_cross_axis_energy(self, df):
        """Return ``df`` with the per-sequence cross-axis energy broadcast to every row."""
        per_sequence = {
            seq_id: self.compute_cross_axis_energy(group)
            for seq_id, group in df.groupby('sequence_id')
        }
        energy_df = pd.DataFrame.from_dict(per_sequence, orient='index')
        if energy_df.empty:
            return df
        return df.join(energy_df, on='sequence_id')

    def compute_imu_features(self, df):
        """Append derived IMU columns, computed within each ``sequence_id``.

        Adds magnitudes, jerks, the rotation angle and its per-step change,
        gravity-removed ``linear_acc_*``, ``angular_vel_*`` and
        ``angular_distance``. The first few columns (and the optional rot
        NaN fill) are written into ``df`` in place; the return value is a new
        concatenated DataFrame.
        """
        if self.config.get("rot_fillna", False):
            # Missing orientation becomes the identity quaternion.
            df['rot_w'] = df['rot_w'].fillna(1)
            df[['rot_x', 'rot_y', 'rot_z']] = df[['rot_x', 'rot_y', 'rot_z']].fillna(0)
        df['acc_mag'] = np.sqrt(df['acc_x']**2 + df['acc_y']**2 + df['acc_z']**2)
        df['rot_angle'] = 2 * np.arccos(df['rot_w'].clip(-1, 1))
        df['acc_mag_jerk'] = df.groupby('sequence_id')['acc_mag'].diff().fillna(0)
        df['rot_angle_vel'] = df.groupby('sequence_id')['rot_angle'].diff().fillna(0)

        linear_accel_list = []
        for _, group in df.groupby('sequence_id'):
            acc_data_group = group[['acc_x', 'acc_y', 'acc_z']]
            rot_data_group = group[['rot_x', 'rot_y', 'rot_z', 'rot_w']]
            linear_accel_group = remove_gravity_from_acc(acc_data_group, rot_data_group)
            linear_accel_list.append(pd.DataFrame(linear_accel_group, columns=['linear_acc_x', 'linear_acc_y', 'linear_acc_z'], index=group.index))
        df_linear_accel = pd.concat(linear_accel_list)
        df = pd.concat([df, df_linear_accel], axis=1)
        df['linear_acc_mag'] = np.sqrt(df['linear_acc_x']**2 + df['linear_acc_y']**2 + df['linear_acc_z']**2)
        df['linear_acc_mag_jerk'] = df.groupby('sequence_id')['linear_acc_mag'].diff().fillna(0)

        angular_vel_list = []
        for _, group in df.groupby('sequence_id'):
            rot_data_group = group[['rot_x', 'rot_y', 'rot_z', 'rot_w']]
            angular_vel_group = calculate_angular_velocity_from_quat(rot_data_group)
            angular_vel_list.append(pd.DataFrame(angular_vel_group, columns=['angular_vel_x', 'angular_vel_y', 'angular_vel_z'], index=group.index))
        df_angular_vel = pd.concat(angular_vel_list)
        df = pd.concat([df, df_angular_vel], axis=1)

        angular_distance_list = []
        for _, group in df.groupby('sequence_id'):
            rot_data_group = group[['rot_x', 'rot_y', 'rot_z', 'rot_w']]
            angular_dist_group = calculate_angular_distance(rot_data_group)
            angular_distance_list.append(pd.DataFrame(angular_dist_group, columns=['angular_distance'], index=group.index))
        df_angular_distance = pd.concat(angular_distance_list)
        df = pd.concat([df, df_angular_distance], axis=1)
        return df

    def compute_flip_features(self, df):
        """Add ``<col>_flip`` IMU columns computed from a mirrored copy of the signal.

        The mirror negates ``acc_x``, ``acc_y``, ``rot_x`` and ``rot_y``. Cross-axis
        energy is included when requested, so every name in ``flip_imu_cols``
        can exist. ``df`` is modified in place and returned.
        """
        flip_df = df[['sequence_id', 'acc_x', 'acc_y', 'acc_z', 'rot_x', 'rot_y', 'rot_z', 'rot_w']].copy()
        flip_df[['acc_x', 'acc_y', 'rot_x', 'rot_y']] *= -1
        flip_df = self.compute_imu_features(flip_df)
        if self._wants_er_features():
            flip_df = self._add_cross_axis_energy(flip_df)
        for col in flip_df.columns:
            if col != 'sequence_id':
                df[f"{col}_flip"] = flip_df[col]
        return df

    def compute_features(self, df):
        """Return ``df`` with IMU, ToF statistic and (if requested) cross-axis energy columns."""
        df = self.compute_imu_features(df)
        if self.tof_mode != 0:
            df = pd.concat([df, self._compute_tof_stat_columns(df)], axis=1)
        if self._wants_er_features():
            df = self._add_cross_axis_energy(df)
        return df

    def generate_features(self, df):
        """Encode labels and make sure every configured feature column exists.

        Fits ``self.le`` on ``gesture``. With ``one_neg``, all non-target gestures
        are first merged into "Write name on leg". It then computes features and
        flip features unless they are already present, optionally merges
        demographics from ``config["dg_path"]``, and optionally saves the result.
        """
        self.le = LabelEncoder()
        if self.config.get("one_neg", False):
            neg_other = "Write name on leg"
            df['gesture'] = df['gesture'].apply(lambda x: x if x in self.target_gestures else neg_other)
        df['gesture_int'] = self.le.fit_transform(df['gesture'])
        self.class_num = len(self.le.classes_)
        class_names = self.le.classes_.tolist()
        self.target_ints = np.array([class_names.index(name) for name in self.target_gestures])
        # With ``one_neg`` most non-target gestures no longer exist as classes;
        # only map the ones that survived label encoding.
        self.non_target_ints = np.array([class_names.index(name) for name in self.non_target_gestures if name in class_names])

        if all(c in df.columns for c in self.feature_cols):
            print("Features have precomputed, skip compute.")
        else:
            print("Features not precomputed, do compute.")
            df = self.compute_features(df)

        if self.config.get("return_flip_imu", False):
            if all(c in df.columns for c in self.flip_imu_cols):
                print("Flip have precomputed, skip compute.")
            else:
                print("Flip not precomputed, do compute.")
                df = self.compute_flip_features(df)

        if self.config.get("use_dg", False):
            dg_df = pd.read_csv(self.config["dg_path"])
            df = pd.merge(df, dg_df, how='left', on='subject')
            df['age'] /= 100
            df['shoulder_to_wrist_height'] = df['shoulder_to_wrist_cm'] / df['height_cm']
            df['elbow_to_wrist_height'] = df['elbow_to_wrist_cm'] / df['height_cm']

        if self.config.get("save_precompute", False):
            df.to_csv(self.config.get("save_filename", "train.csv"))
        return df

    def scale(self, data_unscaled):
        """Fit a fresh scaler on all rows of ``data_unscaled`` and transform each array.

        ``config["scaler_function"]`` (default ``StandardScaler()``) is a template.
        A deep copy is fitted so that separate calls, one per modality, never
        share or overwrite one fitted object.

        Returns:
            (list of scaled arrays, fitted scaler)
        """
        import copy

        template = self.config.get("scaler_function")
        scaler = StandardScaler() if template is None else copy.deepcopy(template)
        scaler = scaler.fit(np.concatenate(data_unscaled, axis=0))
        return [scaler.transform(x) for x in data_unscaled], scaler

    def pad(self, data_scaled, cols):
        """Zero-pad or truncate each sequence to ``self.pad_len`` rows.

        Returns a float32 array of shape (n_seq, pad_len, len(cols)).
        """
        pad_data = np.zeros((len(data_scaled), self.pad_len, len(cols)), dtype='float32')
        for i, seq in enumerate(data_scaled):
            seq_len = min(len(seq), self.pad_len)
            pad_data[i, :seq_len] = seq[:seq_len]
        return pad_data

    def get_nan_value(self, data, ratio):
        """Return the fill value for missing entries: ``-max(data) * ratio``."""
        max_value = data.max().max()
        nan_value = -max_value * ratio
        print(f"Max: {max_value}, set nan to {nan_value}")
        return nan_value

    def generate_dataset(self, df):
        """Build the padded, scaled per-sequence arrays, labels and fold metadata."""
        seq_gp = df.groupby('sequence_id')
        imu_unscaled, thm_unscaled, tof_unscaled = list(), list(), list()
        if self.config.get("return_flip_imu", False):
            flip_imu_unscaled = list()
        classes, lens = list(), list()
        self.imu_nan_value = self.get_nan_value(df[self.imu_cols], self.config["nan_ratio"]["imu"])
        self.thm_nan_value = self.get_nan_value(df[self.thm_cols], self.config["nan_ratio"]["thm"])
        self.tof_nan_value = self.get_nan_value(df[self.tof_cols], self.config["nan_ratio"]["tof"])
        if self.config.get("use_dg", False):
            self.dg = list()

        self.fold_feats = defaultdict(list)
        for seq_id, seq_df in seq_gp:
            imu_data = seq_df[self.imu_cols]
            if self.config["fbfill"]["imu"]:
                imu_data = imu_data.ffill().bfill()
            imu_unscaled.append(imu_data.fillna(self.imu_nan_value).values.astype('float32'))

            if self.config.get("return_flip_imu", False):
                flip_imu_data = seq_df[self.flip_imu_cols]
                if self.config["fbfill"]["imu"]:
                    flip_imu_data = flip_imu_data.ffill().bfill()
                flip_imu_unscaled.append(flip_imu_data.fillna(self.imu_nan_value).values.astype('float32'))

            thm_data = seq_df[self.thm_cols]
            if self.config["fbfill"]["thm"]:
                thm_data = thm_data.ffill().bfill()
            thm_unscaled.append(thm_data.fillna(self.thm_nan_value).values.astype('float32'))

            tof_data = seq_df[self.tof_cols]
            if self.config["fbfill"]["tof"]:
                tof_data = tof_data.ffill().bfill()
            tof_unscaled.append(tof_data.fillna(self.tof_nan_value).values.astype('float32'))

            classes.append(seq_df['gesture_int'].iloc[0])
            lens.append(len(imu_data))

            for col in self.fold_cols:
                self.fold_feats[col].append(seq_df[col].iloc[0])

            if self.config.get("use_dg", False):
                self.dg.append(seq_df[self.dg_cols].iloc[0].values.astype('float32'))

        self.dataset_indices = classes
        # Sequences are padded/truncated to this percentile of the observed lengths.
        self.pad_len = int(np.percentile(lens, self.config.get("percent", 95)))
        if self.config.get("one_scale", True):
            x_unscaled = [np.concatenate([imu, thm, tof], axis=1) for imu, thm, tof in zip(imu_unscaled, thm_unscaled, tof_unscaled)]
            x_scaled, self.x_scaler = self.scale(x_unscaled)
            x = self.pad(x_scaled, self.imu_cols + self.thm_cols + self.tof_cols)
            self.imu = x[..., :self.imu_dim]
            self.thm = x[..., self.imu_dim:self.imu_dim + self.thm_dim]
            self.tof = x[..., self.imu_dim + self.thm_dim:self.imu_dim + self.thm_dim + self.tof_dim]

            if self.config.get("return_flip_imu", False):
                flip_x_unscaled = [np.concatenate([flip_imu, thm, tof], axis=1) for flip_imu, thm, tof in zip(flip_imu_unscaled, thm_unscaled, tof_unscaled)]
                flip_x_scaled = [self.x_scaler.transform(x) for x in flip_x_unscaled]
                flip_x = self.pad(flip_x_scaled, self.imu_cols + self.thm_cols + self.tof_cols)
                self.flip_imu = flip_x[..., :self.imu_dim]
        else:
            imu_scaled, self.imu_scaler = self.scale(imu_unscaled)
            thm_scaled, self.thm_scaler = self.scale(thm_unscaled)
            tof_scaled, self.tof_scaler = self.scale(tof_unscaled)
            self.imu = self.pad(imu_scaled, self.imu_cols)
            self.thm = self.pad(thm_scaled, self.thm_cols)
            self.tof = self.pad(tof_scaled, self.tof_cols)

            if self.config.get("return_flip_imu", False):
                flip_imu_scaled = [self.imu_scaler.transform(x) for x in flip_imu_unscaled]
                self.flip_imu = self.pad(flip_imu_scaled, self.imu_cols)
        self.precompute_scaled_nan_values()
        self.class_ = F.one_hot(torch.from_numpy(np.array(classes)).long(), num_classes=len(self.le.classes_)).float().numpy()
        self.binary_class_ = np.isin(np.array(classes), self.target_ints).astype(np.float32)
        self.class_weight = torch.FloatTensor(compute_class_weight('balanced', classes=np.arange(len(self.le.classes_)), y=classes))

    def precompute_scaled_nan_values(self):
        """Store the scaled value of each modality's NaN fill, averaged over its columns."""
        # A plain array (not a DataFrame) is used because the scalers were fitted on arrays.
        nan_row = np.array([[self.imu_nan_value] * self.imu_dim
                            + [self.thm_nan_value] * self.thm_dim
                            + [self.tof_nan_value] * self.tof_dim])
        imu_end = self.imu_dim
        thm_end = imu_end + self.thm_dim
        tof_end = thm_end + self.tof_dim

        if self.config.get("one_scale", True):
            scaled = self.x_scaler.transform(nan_row)
            self.imu_scaled_nan = scaled[0, :imu_end].mean()
            self.thm_scaled_nan = scaled[0, imu_end:thm_end].mean()
            self.tof_scaled_nan = scaled[0, thm_end:tof_end].mean()
        else:
            self.imu_scaled_nan = self.imu_scaler.transform(nan_row[:, :imu_end])[0].mean()
            self.thm_scaled_nan = self.thm_scaler.transform(nan_row[:, imu_end:thm_end])[0].mean()
            self.tof_scaled_nan = self.tof_scaler.transform(nan_row[:, thm_end:tof_end])[0].mean()

    def get_scaled_nan_tensors(self, imu, thm, tof):
        """Return tensors shaped like the inputs and filled with the scaled NaN values."""
        return torch.full(imu.shape, self.imu_scaled_nan, device=imu.device, dtype=imu.dtype), \
            torch.full(thm.shape, self.thm_scaled_nan, device=thm.device, dtype=thm.dtype), \
            torch.full(tof.shape, self.tof_scaled_nan, device=tof.device, dtype=tof.dtype)

    def inference_process(self, sequence, demographics=None, reverse=False):
        """Turn one raw sequence (and demographics) into padded, scaled model inputs.

        Args:
            sequence: object exposing ``to_pandas()`` (e.g. a polars DataFrame)
                with the raw sensor columns of a single sequence.
            demographics: same kind of object; required when ``use_dg`` is set.
            reverse: mirror the IMU signal the same way ``compute_flip_features`` does.

        Returns:
            ``[imu, thm, tof]`` tensors of shape (1, pad_len, dim), plus a
            demographics tensor when ``use_dg`` is set.

        Raises:
            ValueError: if ``use_dg`` is set and ``demographics`` is None.
        """
        if self.config.get("use_dg", False):
            if demographics is None:
                raise ValueError("Demographics needed")
            df_dg = demographics.to_pandas().copy()
            df_dg['age'] /= 100
            df_dg['shoulder_to_wrist_height'] = df_dg['shoulder_to_wrist_cm'] / df_dg['height_cm']
            df_dg['elbow_to_wrist_height'] = df_dg['elbow_to_wrist_cm'] / df_dg['height_cm']
        df_seq = sequence.to_pandas().copy()
        if reverse:
            df_seq[['acc_x', 'acc_y', 'rot_x', 'rot_y']] *= -1
        if self.config.get("rot_fillna", False):
            df_seq['rot_w'] = df_seq['rot_w'].fillna(1)
            df_seq[['rot_x', 'rot_y', 'rot_z']] = df_seq[['rot_x', 'rot_y', 'rot_z']].fillna(0)
        # Check every IMU input column (including linear_acc_* and er_*), not only the engineered ones.
        if not all(c in df_seq.columns for c in self.imu_cols):
            df_seq['acc_mag'] = np.sqrt(df_seq['acc_x']**2 + df_seq['acc_y']**2 + df_seq['acc_z']**2)
            df_seq['rot_angle'] = 2 * np.arccos(df_seq['rot_w'].clip(-1, 1))
            df_seq['acc_mag_jerk'] = df_seq['acc_mag'].diff().fillna(0)
            df_seq['rot_angle_vel'] = df_seq['rot_angle'].diff().fillna(0)
            if all(col in df_seq.columns for col in ['acc_x', 'acc_y', 'acc_z', 'rot_x', 'rot_y', 'rot_z', 'rot_w']):
                linear_accel = remove_gravity_from_acc(
                    df_seq[['acc_x', 'acc_y', 'acc_z']],
                    df_seq[['rot_x', 'rot_y', 'rot_z', 'rot_w']]
                )
                df_seq[['linear_acc_x', 'linear_acc_y', 'linear_acc_z']] = linear_accel
            else:
                df_seq['linear_acc_x'] = df_seq.get('acc_x', 0)
                df_seq['linear_acc_y'] = df_seq.get('acc_y', 0)
                df_seq['linear_acc_z'] = df_seq.get('acc_z', 0)
            df_seq['linear_acc_mag'] = np.sqrt(df_seq['linear_acc_x']**2 + df_seq['linear_acc_y']**2 + df_seq['linear_acc_z']**2)
            df_seq['linear_acc_mag_jerk'] = df_seq['linear_acc_mag'].diff().fillna(0)
            if all(col in df_seq.columns for col in ['rot_x', 'rot_y', 'rot_z', 'rot_w']):
                angular_vel = calculate_angular_velocity_from_quat(df_seq[['rot_x', 'rot_y', 'rot_z', 'rot_w']])
                df_seq[['angular_vel_x', 'angular_vel_y', 'angular_vel_z']] = angular_vel
            else:
                df_seq[['angular_vel_x', 'angular_vel_y', 'angular_vel_z']] = 0
            if all(col in df_seq.columns for col in ['rot_x', 'rot_y', 'rot_z', 'rot_w']):
                df_seq['angular_distance'] = calculate_angular_distance(df_seq[['rot_x', 'rot_y', 'rot_z', 'rot_w']])
            else:
                df_seq['angular_distance'] = 0
            if self._wants_er_features():
                # Same per-sequence value on every row, as in training.
                for name, value in self.compute_cross_axis_energy(df_seq).items():
                    df_seq[name] = value

        if self.tof_mode != 0:
            df_seq = pd.concat([df_seq, self._compute_tof_stat_columns(df_seq)], axis=1)

        imu_unscaled = df_seq[self.imu_cols]
        if self.config["fbfill"]["imu"]:
            imu_unscaled = imu_unscaled.ffill().bfill()
        imu_unscaled = imu_unscaled.fillna(self.imu_nan_value).values.astype('float32')

        thm_unscaled = df_seq[self.thm_cols]
        if self.config["fbfill"]["thm"]:
            thm_unscaled = thm_unscaled.ffill().bfill()
        thm_unscaled = thm_unscaled.fillna(self.thm_nan_value).values.astype('float32')

        tof_unscaled = df_seq[self.tof_cols]
        if self.config["fbfill"]["tof"]:
            tof_unscaled = tof_unscaled.ffill().bfill()
        tof_unscaled = tof_unscaled.fillna(self.tof_nan_value).values.astype('float32')

        if self.config.get("one_scale", True):
            x_unscaled = np.concatenate([imu_unscaled, thm_unscaled, tof_unscaled], axis=1)
            x_scaled = self.x_scaler.transform(x_unscaled)
            imu_scaled = x_scaled[..., :self.imu_dim]
            thm_scaled = x_scaled[..., self.imu_dim:self.imu_dim + self.thm_dim]
            tof_scaled = x_scaled[..., self.imu_dim + self.thm_dim:self.imu_dim + self.thm_dim + self.tof_dim]
        else:
            imu_scaled = self.imu_scaler.transform(imu_unscaled)
            thm_scaled = self.thm_scaler.transform(thm_unscaled)
            tof_scaled = self.tof_scaler.transform(tof_unscaled)

        combined = np.concatenate([imu_scaled, thm_scaled, tof_scaled], axis=1)
        padded = np.zeros((self.pad_len, combined.shape[1]), dtype='float32')
        seq_len = min(combined.shape[0], self.pad_len)
        padded[:seq_len] = combined[:seq_len]
        imu = padded[..., :self.imu_dim]
        thm = padded[..., self.imu_dim:self.imu_dim + self.thm_dim]
        tof = padded[..., self.imu_dim + self.thm_dim:self.imu_dim + self.thm_dim + self.tof_dim]

        ret = [torch.from_numpy(imu).float().unsqueeze(0), torch.from_numpy(thm).float().unsqueeze(0), torch.from_numpy(tof).float().unsqueeze(0)]
        if self.config.get("use_dg", False):
            dg = df_dg[self.dg_cols].values.astype('float32')
            ret.append(torch.from_numpy(dg).float())
        return ret

    def split5(self, imu, thm, tof):
        """Split batched (b, l, d) tensors into per-channel-group lists along the last axis."""
        imus = [imu[:, :, self.global_imu_indices[k]] for k in self.global_imu_indices]
        thms = [thm[:, :, self.global_thm_indices[k]] for k in range(1, 6)]
        tofs = [tof[:, :, self.global_tof_indices[k]] for k in range(1, 6)]
        return imus, thms, tofs

    def slide(self, imu, thm, tof, ratio=1.0):
        """Rescale the time axis length by ``ratio``.

        Longer tensors are padded with the scaled NaN value; shorter ones are
        truncated.
        """
        def slide_tensor(tensor, nan_value, ratio):
            b, l, d = tensor.shape
            length = int(l * ratio)
            if length > l:
                pad = torch.full((b, length - l, d), nan_value, device=tensor.device, dtype=tensor.dtype)
                tensor = torch.cat([tensor, pad], dim=1)
            elif length < l:
                tensor = tensor[:, :length, :]
            return tensor
        return slide_tensor(imu, self.imu_scaled_nan, ratio), slide_tensor(thm, self.thm_scaled_nan, ratio), slide_tensor(tof, self.tof_scaled_nan, ratio)

    def __getitem__(self, idx):
        ret = [self.imu[idx], self.thm[idx], self.tof[idx], self.class_[idx], self.binary_class_[idx]]
        if self.config.get("return_extra", False):
            fold_feat_info = [self.fold_feats[col][idx] for col in self.fold_cols]
            ret.append((idx, fold_feat_info))
        if self.config.get("use_dg", False):
            ret.append(self.dg[idx])
        if self.config.get("return_flip_imu", False):
            ret.append(self.flip_imu[idx])
        return ret

    def __len__(self):
        return len(self.class_)

class CMIFoldDataset:
    """Subject-grouped, stratified K-fold wrapper around a full CMI dataset.

    ``full_dataset_function(data_path=..., config=...)`` builds the underlying
    dataset. It must expose:

    * ``imu_dim``, ``thm_dim``, ``tof_dim``;
    * ``le`` and ``class_weight``;
    * a ``fold_feats`` mapping.

    Folds come from ``StratifiedGroupKFold(n_folds, shuffle=True,
    random_state=random_seed)``, run over all sample indices:

    * labels: ``fold_feats[config["fold_y"]]`` (default ``"sequence_type"``);
    * groups: ``fold_feats[config["fold_groups"]]`` (default ``"subject"``).

    Samples whose ``fold_feats["subject"]`` is listed in
    ``config["exclude_subjects"]`` are dropped from both subsets returned by
    :meth:`get_fold_datasets`. The folds themselves are computed before that
    exclusion.
    """

    def __init__(self, data_path, config, full_dataset_function, n_folds=5, random_seed=0):
        self.full_dataset = full_dataset_function(data_path=data_path, config=config)
        # Mirror the metadata callers need without reaching into full_dataset.
        self.imu_dim = self.full_dataset.imu_dim
        self.thm_dim = self.full_dataset.thm_dim
        self.tof_dim = self.full_dataset.tof_dim
        self.le = self.full_dataset.le
        self.class_names = self.full_dataset.le.classes_
        self.class_weight = self.full_dataset.class_weight
        self.n_folds = n_folds
        self.sgkf = StratifiedGroupKFold(n_splits=n_folds, shuffle=True, random_state=random_seed)
        self.fold_y = np.array(self.full_dataset.fold_feats[config.get("fold_y", "sequence_type")])
        self.fold_groups = np.array(self.full_dataset.fold_feats[config.get("fold_groups", "subject")])
        # List of (train_idx, valid_idx) index arrays, one pair per fold.
        self.folds = list(self.sgkf.split(X=np.arange(len(self.full_dataset)), y=self.fold_y, groups=self.fold_groups))
        self.exclude_subjects = set(config.get("exclude_subjects", []))

    def get_fold_datasets(self, fold_idx):
        """Return ``(train_subset, valid_subset)`` for ``fold_idx``.

        Excluded subjects are removed from both subsets. Returns
        ``(None, None)`` when ``fold_idx >= n_folds``.
        """
        if self.folds is None or fold_idx >= self.n_folds:
            return None, None
        fold_train_idx, fold_valid_idx = self.folds[fold_idx]
        subjects = np.asarray(self.full_dataset.fold_feats["subject"])
        # np.isin needs a sequence: passing the set directly would be treated as one object.
        excluded = list(self.exclude_subjects)

        def keep(indices):
            # Keep only the indices whose subject is not excluded.
            indices = np.asarray(indices)
            return indices[~np.isin(subjects[indices], excluded)].tolist()

        return Subset(self.full_dataset, keep(fold_train_idx)), Subset(self.full_dataset, keep(fold_valid_idx))

    def print_fold_stats(self):
        """Print, per fold, the class counts and the train/valid subject sets, then the exclusion summary."""
        def get_label_counts(subset):
            counts = {name: 0 for name in self.class_names}
            if subset is None:
                return counts
            for idx in subset.indices:
                # NOTE(review): dataset_indices[idx] is used as a class index — confirm
                # against the full dataset class.
                label_idx = self.full_dataset.dataset_indices[idx]
                counts[self.class_names[label_idx]] += 1
            return counts

        print("\n交叉验证折叠统计:")
        for fold_idx in range(self.n_folds):
            train_fold, valid_fold = self.get_fold_datasets(fold_idx)
            train_counts = get_label_counts(train_fold)
            valid_counts = get_label_counts(valid_fold)
            print(f"\nFold {fold_idx + 1}:")
            print(f"{'类别':<50} {'训练集':<10} {'验证集':<10}")
            for name in self.class_names:
                print(f"{name:<50} {train_counts[name]:<10} {valid_counts[name]:<10}")

        for fold_idx, (train_idx, val_idx) in enumerate(self.folds):
            train_subjects = set(self.fold_groups[train_idx])
            val_subjects = set(self.fold_groups[val_idx])
            print(f"\nFold {fold_idx + 1}:")
            print("训练集受试者:", train_subjects)
            print("验证集受试者:", val_subjects)

        self.print_filtered_stats()

    def print_filtered_stats(self):
        """Print train/valid sample totals over all folds, before and after subject exclusion."""
        original_counts = {"train": 0, "valid": 0}
        filtered_counts = {"train": 0, "valid": 0}

        for fold_idx in range(self.n_folds):
            train_idx, val_idx = self.folds[fold_idx]
            original_counts['train'] += len(train_idx)
            original_counts['valid'] += len(val_idx)
            train_set, val_set = self.get_fold_datasets(fold_idx)
            filtered_counts['train'] += len(train_set)
            filtered_counts['valid'] += len(val_set)

        print(f"\n排除subject {self.exclude_subjects} 后的数据量变化:")
        print(f"原始训练集样本: {original_counts['train']}")
        print(f"过滤后训练集样本: {filtered_counts['train']}")
        print(f"原始验证集样本: {original_counts['valid']}")
        print(f"过滤后验证集样本: {filtered_counts['valid']}")

class SEBlock(nn.Module):
    """Squeeze-and-excitation gate over the channel axis of a (B, C, L) tensor.

    The time axis is average-pooled to (B, C), passed through
    Linear(C -> C // reduction) + ReLU + Linear(-> C) + sigmoid, and the result
    rescales every time step of the input channel-wise.
    """
    def __init__(self, channels, reduction=8):
        super().__init__()
        bottleneck = channels // reduction
        self.fc1 = nn.Linear(channels, bottleneck, bias=True)
        self.fc2 = nn.Linear(bottleneck, channels, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        pooled = F.adaptive_avg_pool1d(x, 1)[..., 0]              # (B, C)
        gate = self.sigmoid(self.fc2(F.relu(self.fc1(pooled))))   # (B, C)
        return x * gate[..., None]                                # broadcast over L

class ResNetSEBlock(nn.Module):
    """Residual block: (Conv3-BN-ReLU) -> (Conv3-BN) -> SE gate, plus a skip path, then ReLU.

    Operates on (B, C, L) tensors and preserves L (kernel 3, padding 1). The skip
    path is a 1x1 Conv + BN projection when ``in_channels != out_channels``,
    otherwise identity. ``wd`` is accepted but not used by this class.

    NOTE: ``SEBlock`` is resolved at construction time. A later cell in this
    notebook redefines ``SEBlock`` with different submodule names, so instances
    built after that cell runs would not match these checkpoints.
    """
    def __init__(self, in_channels, out_channels, wd=1e-4):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv1d(in_channels, out_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.se = SEBlock(out_channels)

        needs_projection = in_channels != out_channels
        self.shortcut = (
            nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, padding=0, bias=False),
                nn.BatchNorm1d(out_channels),
            )
            if needs_projection
            else nn.Identity()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        skip = self.shortcut(x)                          # (B, out, L)
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.se(self.bn2(self.conv2(h)))             # (B, out, L)
        return self.relu(h + skip)

class AttentionLayer(nn.Module):
    """Additive attention pooling over the time axis of a (B, L, F) tensor.

    Each step gets a score ``tanh(Linear(F -> 1))``. The scores are
    softmax-normalized over L, and the weighted sum over L is returned as (B, F).
    """
    def __init__(self, feature_dim):
        super().__init__()
        self.score_fn = nn.Linear(feature_dim, 1, bias=True)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        logits = self.score_fn(x).tanh().squeeze(-1)   # (B, L)
        alpha = self.softmax(logits)[..., None]        # (B, L, 1)
        return (x * alpha).sum(dim=1)                  # (B, F)

class GaussianNoise(nn.Module):
    """Add zero-mean Gaussian noise with standard deviation ``stddev``, in training mode only.

    In eval mode the input tensor itself is returned unchanged.
    """
    def __init__(self, stddev):
        super().__init__()
        self.stddev = stddev

    def forward(self, x):
        if not self.training:
            return x
        return x + torch.randn_like(x) * self.stddev

class CMIBackbone(nn.Module):
    """Per-sensor conv branches fused into a BiLSTM / BiGRU / dense sequence model with attention pooling.

    ``forward`` receives groups of (B, L, C) tensors; each group is permuted to
    channel-first before its conv branch.

    * ``imus``: three tensors with 3, 4 and ``imu_dim - 7`` channels, each with
      its own two-stage residual branch.
    * ``thms`` / ``tofs``: five tensors each, with ``thm_dim // 5`` /
      ``tof_dim // 5`` channels and one THM and one TOF branch per group.

    Processing:

    1. Each modality's branch outputs are concatenated on channels and fused by a
       ResNetSE projection.
    2. The three modality features are concatenated.
    3. The resulting sequence goes through a BiLSTM, a BiGRU and a
       (noise -> Linear -> ELU) path.
    4. Their outputs are concatenated and attention-pooled over time into a
       ``(B, 2 * (lstm_hidden_size + gru_hidden_size) + dense_channels)`` feature.

    Every branch stage ends in a pool-2 max-pool, so the sequence seen by the RNNs
    is about 4x shorter than L. Submodule attribute names (``thm_branch1`` ...
    ``tof_branch5`` etc.) are checkpoint state-dict keys and must not change.
    """

    def __init__(self, imu_dim, thm_dim, tof_dim, **kwargs):
        super().__init__()
        cfg = kwargs

        def imu_branch(in_channels):
            # Two residual stages: in -> imu1_channels -> imu2_channels.
            return nn.Sequential(
                self.residual_feature_block(in_channels, cfg["imu1_channels"], cfg["imu1_layers"], drop=cfg["imu1_dropout"]),
                self.residual_feature_block(cfg["imu1_channels"], cfg["imu2_channels"], cfg["imu2_layers"], drop=cfg["imu2_dropout"]),
            )

        self.imu_acc_branch = imu_branch(3)
        self.imu_rot_branch = imu_branch(4)
        self.imu_other_branch = imu_branch(imu_dim - 7)

        # Five independent THM/TOF branch pairs, registered as thm1, tof1, thm2, tof2, ...
        for sensor in range(1, 6):
            thm_branch, tof_branch = self.init_thm_tof_branch(thm_dim // 5, tof_dim // 5, **cfg)
            setattr(self, f"thm_branch{sensor}", thm_branch)
            setattr(self, f"tof_branch{sensor}", tof_branch)

        imu_c, thm_c, tof_c = cfg["imu2_channels"], cfg["thm2_channels"], cfg["tof2_channels"]
        self.imu_proj = ResNetSEBlock(in_channels=3 * imu_c, out_channels=imu_c)
        self.thm_proj = ResNetSEBlock(in_channels=5 * thm_c, out_channels=thm_c)
        self.tof_proj = ResNetSEBlock(in_channels=5 * tof_c, out_channels=tof_c)

        fused_channels = imu_c + thm_c + tof_c
        self.lstm = nn.LSTM(
            input_size=fused_channels,
            hidden_size=cfg["lstm_hidden_size"],
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.gru = nn.GRU(
            input_size=fused_channels,
            hidden_size=cfg["gru_hidden_size"],
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        self.noise = GaussianNoise(cfg["gaussian_noise_rate"])
        self.dense = nn.Sequential(nn.Linear(fused_channels, cfg["dense_channels"]), nn.ELU())

        # Bidirectional LSTM + bidirectional GRU + dense path, concatenated.
        attn_dim = 2 * (cfg["lstm_hidden_size"] + cfg["gru_hidden_size"]) + cfg["dense_channels"]
        self.attn = AttentionLayer(feature_dim=attn_dim)

    def feature_block(self, in_channels, out_channels, num_layers, pool_size=2, drop=0.3):
        """Build the plain conv stage.

        ``num_layers`` width-preserving ResNetSE blocks, then Conv3 + BN + ReLU
        to ``out_channels``, then MaxPool(ceil) and Dropout.
        """
        layers = [ResNetSEBlock(in_channels=in_channels, out_channels=in_channels) for _ in range(num_layers)]
        layers += [
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(pool_size, ceil_mode=True),
            nn.Dropout(drop),
        ]
        return nn.Sequential(*layers)

    def residual_feature_block(self, in_channels, out_channels, num_layers, pool_size=2, drop=0.3):
        """Build the residual stage.

        ``num_layers`` width-preserving ResNetSE blocks, then a ResNetSE block
        to ``out_channels``, then MaxPool(ceil) and Dropout.
        """
        layers = [ResNetSEBlock(in_channels=in_channels, out_channels=in_channels) for _ in range(num_layers)]
        layers += [
            ResNetSEBlock(in_channels, out_channels, wd=1e-4),
            nn.MaxPool1d(pool_size, ceil_mode=True),
            nn.Dropout(drop),
        ]
        return nn.Sequential(*layers)

    def init_thm_tof_branch(self, thm_dim, tof_dim, **kwargs):
        """Return one ``(thm_branch, tof_branch)`` pair, each made of two plain conv stages."""
        def two_stage(in_channels, prefix):
            return nn.Sequential(
                self.feature_block(in_channels, kwargs[f"{prefix}1_channels"], kwargs[f"{prefix}1_layers"], drop=kwargs[f"{prefix}1_dropout"]),
                self.feature_block(kwargs[f"{prefix}1_channels"], kwargs[f"{prefix}2_channels"], kwargs[f"{prefix}2_layers"], drop=kwargs[f"{prefix}2_dropout"]),
            )

        thm_branch = two_stage(thm_dim, "thm")
        tof_branch = two_stage(tof_dim, "tof")
        return thm_branch, tof_branch

    def _run_group(self, prefix, inputs):
        """Apply ``<prefix>_branch1..5`` to the five (B, L, C) tensors in ``inputs``; return channel-first outputs."""
        return [
            getattr(self, f"{prefix}_branch{i}")(x.permute(0, 2, 1))
            for i, x in enumerate(inputs, start=1)
        ]

    def forward(self, imus, thms, tofs):
        acc, rot, other = imus
        imu_parts = [
            self.imu_acc_branch(acc.permute(0, 2, 1)),
            self.imu_rot_branch(rot.permute(0, 2, 1)),
            self.imu_other_branch(other.permute(0, 2, 1)),
        ]
        imu_feat = self.imu_proj(torch.cat(imu_parts, dim=1))

        # Unpack both groups up front: exactly five tensors each are required.
        t1, t2, t3, t4, t5 = thms
        o1, o2, o3, o4, o5 = tofs
        thm_feat = self.thm_proj(torch.cat(self._run_group("thm", (t1, t2, t3, t4, t5)), dim=1))
        tof_feat = self.tof_proj(torch.cat(self._run_group("tof", (o1, o2, o3, o4, o5)), dim=1))

        seq = torch.cat([imu_feat, thm_feat, tof_feat], dim=1).permute(0, 2, 1)   # (B, T, C)
        lstm_out, _ = self.lstm(seq)
        gru_out, _ = self.gru(seq)
        dense_out = self.dense(self.noise(seq))
        return self.attn(torch.cat([lstm_out, gru_out, dense_out], dim=-1))


# Settings for the all-sensor model 1 validation below.
CUDA0 = "cuda:0"     # device used by to_cuda() and for the loaded fold models
seed = 0             # seed for deterministic.init_all and the fold split (init_dataset)
batch_size = 64      # validation DataLoader batch size (get_fold_dataset)
num_workers = 4      # validation DataLoader workers (get_fold_dataset)
n_folds = 5          # number of CV folds / per-fold checkpoints

root_dir = Path("/kaggle/input/cmi-detect-behavior-with-sensor-data")
# Precomputed dataset file passed to CMIFoldDataset by init_dataset().
universe_csv_path = Path("/kaggle/input/cmi-precompute/pytorch/all/1/tof-1_raw.csv")

# When True, predict2() swaps THM/TOF for scaled-NaN placeholders (IMU-only inference).
imu_only = False

# External Kaggle utility package; init_all(seed) presumably seeds the framework RNGs
# (behavior defined in that package, not here).
deterministic = kagglehub.package_import('wasupandceacar/deterministic').deterministic
deterministic.init_all(seed)


def init_dataset():
    """Build the cross-validation dataset from ``universe_csv_path``, print its fold statistics, and return it.

    Folds are stratified on ``gesture`` and grouped by ``subject``; the fold count
    and random seed come from the module-level ``n_folds`` and ``seed``.
    """
    def per_modality(value):
        return {"imu": value, "thm": value, "tof": value}

    dataset_config = dict(
        percent=99,
        scaler_config=StandardScaler(),
        nan_ratio=per_modality(0),
        fbfill=per_modality(True),
        one_scale=False,
        tof_raw=True,
        tof_mode=16,
        save_precompute=False,
        fold_y="gesture",
        fold_groups="subject",
    )

    fold_dataset = CMIFoldDataset(
        universe_csv_path,
        dataset_config,
        full_dataset_function=CMIFeDataset,
        n_folds=n_folds,
        random_seed=seed,
    )
    fold_dataset.print_fold_stats()
    return fold_dataset

def get_fold_dataset(dataset, fold):
    """Return an unshuffled DataLoader over the validation split of ``fold``.

    Uses the module-level ``batch_size`` and ``num_workers``; the train split is ignored.
    """
    valid_subset = dataset.get_fold_datasets(fold)[1]
    return DataLoader(
        valid_subset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )

# Module-level fold dataset: CMIModel, inference(), valid() and the predict helpers read it as a global.
dataset = init_dataset()

class CMIModel(nn.Module):
    """Two-head classifier on top of :class:`CMIBackbone`.

    The backbone's pooled feature has size
    ``2 * (lstm_hidden_size + gru_hidden_size) + dense_channels``. It feeds two
    identical MLP heads (Linear-BN-ReLU-Dropout twice, then Linear). ``forward``
    returns their logits concatenated along dim 1, target classes first.

    Sensor dims default to the module-level ``dataset`` (built by
    ``init_dataset()``). Pass ``imu_dim`` / ``thm_dim`` / ``tof_dim`` explicitly
    to build the model without that global. All other ``kwargs`` go to the
    backbone and heads.
    """

    def __init__(self, target_classes_num, non_target_classes_num, *, imu_dim=None, thm_dim=None, tof_dim=None, **kwargs):
        super().__init__()
        # Fall back to the global dataset only for dims the caller did not supply.
        if imu_dim is None:
            imu_dim = dataset.imu_dim
        if thm_dim is None:
            thm_dim = dataset.thm_dim
        if tof_dim is None:
            tof_dim = dataset.tof_dim
        self.backbone = CMIBackbone(imu_dim, thm_dim, tof_dim, **kwargs)
        feat_dim = (kwargs['lstm_hidden_size'] + kwargs['gru_hidden_size']) * 2 + kwargs['dense_channels']
        # Attribute names and Sequential layout are checkpoint keys; keep them stable.
        self.target_classifier = self._make_head(feat_dim, target_classes_num, kwargs)
        self.non_target_classifier = self._make_head(feat_dim, non_target_classes_num, kwargs)

    @staticmethod
    def _make_head(in_features, out_features, cfg):
        """MLP head: (Linear-BN-ReLU-Dropout) x2 with widths cls_channels1/2, then Linear to ``out_features``."""
        return nn.Sequential(
            nn.Linear(in_features, cfg["cls_channels1"]),
            nn.BatchNorm1d(cfg["cls_channels1"]),
            nn.ReLU(),
            nn.Dropout(cfg["cls_dropout1"]),
            nn.Linear(cfg["cls_channels1"], cfg["cls_channels2"]),
            nn.BatchNorm1d(cfg["cls_channels2"]),
            nn.ReLU(),
            nn.Dropout(cfg["cls_dropout2"]),
            nn.Linear(cfg["cls_channels2"], out_features),
        )

    def forward(self, imu, thm, tof):
        feat = self.backbone(imu, thm, tof)
        targets_y = self.target_classifier(feat)
        non_targets_y = self.non_target_classifier(feat)
        return torch.cat([targets_y, non_targets_y], dim=1)

# Architecture hyper-parameters shared by every fold model. They must match the
# checkpoints in model_dir: load_state_dict (strict by default) rejects missing,
# unexpected or differently-shaped parameters.
# target (8) + non-target (10) heads -> 18 concatenated logits per sample.
model_function = CMIModel
model_args = {"imu1_channels": 128, "imu2_channels": 256, "imu1_dropout": 0.3, "imu2_dropout": 0.25,
              "imu1_layers": 0, "imu2_layers": 0, 
              "thm1_channels": 32, "thm2_channels": 64, "thm1_dropout": 0.25, "thm2_dropout": 0.2,
              "thm1_layers": 0, "thm2_layers": 0, 
              "tof1_channels": 256, "tof2_channels": 512, "tof1_dropout": 0.4, "tof2_dropout": 0.3,
              "tof1_layers": 0, "tof2_layers": 0, 
              "lstm_hidden_size": 128, "gru_hidden_size": 128, "gaussian_noise_rate": 0.1, "dense_channels": 32,
              "cls_channels1": 256, "cls_dropout1": 0.2, "cls_channels2": 128, "cls_dropout2": 0.2,
              "target_classes_num": 8, "non_target_classes_num": 10,}
model_dir = Path("/kaggle/input/cmi-models-public/pytorch/base04/1")

# One entry per fold: fold k's checkpoint lives at model_dir/fold{k}/best_ema.pt
# (presumably EMA-averaged weights, judging by the file name).
model_dicts = [
    {
        "model_function": model_function,
        "model_args": model_args,
        "model_path": model_dir / f"fold{fold}/best_ema.pt",
    } for fold in range(n_folds)
]

def replace(k):
    """Return state-dict key ``k`` with every ``"_orig_mod."`` occurrence removed.

    ``_orig_mod`` is the attribute under which ``torch.compile`` wraps the original
    module; presumably these checkpoints were saved from compiled models.
    """
    return k.replace("_orig_mod.", "")

# Build one model per fold on CUDA0, load its checkpoint, and switch it to eval mode.
# NOTE: this loop rebinds the module-level names model_function, model_args,
# model_path and model; afterwards `model` is the last fold's model.
models = list()
for model_dict in model_dicts:
    model_function = model_dict["model_function"]
    model_args = model_dict["model_args"]
    model_path = model_dict["model_path"]
    model = model_function(**model_args).to(CUDA0)
    # Normalize keys (strip "_orig_mod.") before loading.
    # torch.load has no map_location here, so tensors are restored to the device
    # recorded in the checkpoint before being copied into the CUDA0 model.
    state_dict = {replace(k): v for k,v in torch.load(model_path).items()}
    model.load_state_dict(state_dict)
    model = model.eval()
    models.append(model)


# Competition metric helpers from an external Kaggle utility package.
metric_package = kagglehub.package_import('wasupandceacar/cmi-metric')

# `metric` accumulates predictions made with all sensors; `imu_only_metric`
# accumulates predictions made after THM/TOF are replaced by the dataset's
# scaled-NaN placeholders (see valid()).
metric = metric_package.Metric()
imu_only_metric = metric_package.Metric()

def to_cuda(*tensors):
    """Move every given tensor to the ``CUDA0`` device; return them as a list in input order."""
    return list(map(lambda tensor: tensor.to(CUDA0), tensors))

def inference(model, imu, thm, tof):
    """Run one forward pass of ``model`` under CUDA autocast.

    The inputs are first split into the per-sensor groups the model expects,
    via ``dataset.full_dataset.split5``. Returns the raw model output.
    """
    imu_groups, thm_groups, tof_groups = dataset.full_dataset.split5(imu, thm, tof)
    with autocast(device_type='cuda'):
        return model(imu_groups, thm_groups, tof_groups)

def valid(model, valid_bar):
    """Accumulate one model's validation predictions into both global metrics.

    Each batch unpacks as ``(imu, thm, tof, y, binary)``; the binary label is
    ignored. True and predicted labels (argmax over dim 1, mapped through
    ``dataset.le.classes_``) are added:

    * to ``metric`` using all sensors;
    * to ``imu_only_metric`` after THM/TOF are swapped for the dataset's
      scaled-NaN placeholders.
    """
    with torch.no_grad():
        for imu, thm, tof, y, _binary in valid_bar:
            imu, thm, tof, y = to_cuda(imu, thm, tof, y)

            full_pred = inference(model, imu, thm, tof)
            label_names = dataset.le.classes_
            true_idx = y.argmax(dim=1).cpu()
            metric.add(label_names[true_idx], label_names[full_pred.argmax(dim=1).cpu()])

            _, thm_blank, tof_blank = dataset.full_dataset.get_scaled_nan_tensors(imu, thm, tof)
            imu_pred = inference(model, imu, thm_blank, tof_blank)
            imu_only_metric.add(label_names[true_idx], label_names[imu_pred.argmax(dim=1).cpu()])

# Evaluate fold k's model on fold k's validation split. Both global metrics
# accumulate over all folds before the final scores are printed.
# NOTE: rebinds the module-level names fold / model / valid_loader / valid_bar.
for fold, model in enumerate(models):
    valid_loader = get_fold_dataset(dataset, fold)
    valid_bar = tqdm(valid_loader, desc=f"Valid", leave=False)
    valid(model, valid_bar)

print(f"""
Normal score: {metric.score()}
IMU only score: {imu_only_metric.score()}
""")



In [None]:
# Show the label order used by the PyTorch models (logit index -> class name).
print("PyTorch model's label classes:", dataset.le.classes_, sep="\n")

In [None]:
# ======== BEGIN: drop-in replacement for avg_predict / predict2 / predict3 ========
import torch
import torch.nn.functional as F
from contextlib import nullcontext


# Shared mixed-precision context for avg_predict: CUDA autocast when a GPU is
# present, otherwise a no-op context.
if torch.cuda.is_available():
    AMP_CTX = torch.amp.autocast('cuda')
else:
    AMP_CTX = nullcontext()

def _short_err(e, k=800):
    """Format exception ``e`` as ``"<TypeName>: <message>"``, truncated to at most ``k`` characters."""
    return "{}: {}".format(type(e).__name__, e)[:k]

def _fallback_num_classes(default=18):
    """Class count for fallback outputs.

    Uses ``len(dataset.le.classes_)`` when available; otherwise ``default``
    (18 = 8 target + 10 non-target logits in this notebook's model_args).
    """
    classes = getattr(getattr(dataset, 'le', None), 'classes_', None)
    return len(classes) if classes is not None else default


def avg_predict(models, imu, thm, tof):
    """
    Average the logits of ``models`` on one batch without raising.

    - Each model runs through ``inference`` inside ``AMP_CTX`` (CUDA autocast
      when a GPU is present).
    - Every output is normalized to a 2-D tensor on ``imu.device``:
      list/tuple outputs -> first element, 1-D outputs -> leading batch axis.
    - A failing submodel is logged briefly and contributes zero logits.
    - Any other failure returns all-zero logits.
    - Nothing is re-raised, so long tracebacks never leave this function
      (originally to stay under gRPC metadata size limits).

    Fallback zeros use the batch size of ``imu``. A fixed batch of 1 would make
    ``torch.stack`` fail for larger batches and discard every valid prediction.

    Returns:
        float32 tensor: the mean over models (shape (B, C), assuming submodels
        return (B, C) logits).
    """
    batch = imu.shape[0] if imu.dim() > 0 else 1

    def zero_logits():
        return torch.zeros((batch, _fallback_num_classes()), device=imu.device)

    outputs = []
    try:
        with AMP_CTX:
            for model in models:
                try:
                    y = inference(model, imu, thm, tof)
                    # Multi-output models: keep the first output.
                    if isinstance(y, (list, tuple)):
                        y = y[0]
                    if not torch.is_tensor(y):
                        y = torch.as_tensor(y, device=imu.device)
                    # (C,) -> (1, C)
                    if y.ndim == 1:
                        y = y.unsqueeze(0)
                    if y.device != imu.device:
                        y = y.to(imu.device, non_blocking=True)
                    outputs.append(y)
                except Exception as e:
                    print(f"[WARN] submodel failed: {_short_err(e)}")
                    outputs.append(zero_logits())
        if not outputs:
            return zero_logits()
        # Average in float32: autocast outputs may be float16 while fallbacks are float32.
        return torch.mean(torch.stack([o.float() for o in outputs], dim=0), dim=0)
    except Exception as e:
        print(f"[ERROR] avg_predict failed: {_short_err(e)}")
        return zero_logits()

def predict2(sequence: pl.DataFrame, demographics: pl.DataFrame):
    """Ensemble class probabilities for one sequence using all sensors (safe version).

    Steps:

    1. Preprocess ``sequence`` with ``dataset.full_dataset.inference_process``.
    2. Move the tensors to the GPU when one is available.
    3. If the global ``imu_only`` flag is truthy, swap THM/TOF for scaled-NaN
       placeholders.
    4. Average all fold ``models`` with ``avg_predict``.

    ``demographics`` is accepted for interface compatibility and not used.

    Returns:
        float32 numpy array of softmax probabilities ((1, C) for a single
        sequence). On any exception, logs a short message and returns a uniform
        (1, C) distribution.
    """
    try:
        imu, thm, tof = dataset.full_dataset.inference_process(sequence)
        with torch.no_grad():
            # Move to GPU only when present, so CPU-only runs still work.
            if torch.cuda.is_available():
                imu = imu.cuda(non_blocking=True)
                thm = thm.cuda(non_blocking=True)
                tof = tof.cuda(non_blocking=True)
            # Optional IMU-only switch.
            if 'imu_only' in globals() and imu_only:
                try:
                    _, thm, tof = dataset.full_dataset.get_scaled_nan_tensors(imu, thm, tof)
                except Exception as e:
                    print(f"[WARN] get_scaled_nan_tensors: {_short_err(e)}")
            pred_y = avg_predict(models, imu, thm, tof)                  # (1, C)
            # Softmax in float32: autocast logits may be float16, and the fallback
            # below returns float32, so keep the returned dtype consistent.
            probabilities = F.softmax(pred_y.float(), dim=1).detach().cpu().numpy()
        return probabilities  # (1, C)
    except Exception as e:
        print(f"[ERROR] predict2 failed: {_short_err(e)}")
        C = len(getattr(dataset.le, 'classes_', [])) if hasattr(dataset, 'le') and hasattr(dataset.le, 'classes_') else 18
        return np.ones((1, C), dtype=np.float32) / C

def predict3(sequence: pl.DataFrame, demographics: pl.DataFrame):
    """Ensemble class probabilities for one sequence using IMU only (safe version).

    Same pipeline as the all-sensor path, except THM/TOF are always replaced by
    the dataset's scaled-NaN placeholders before averaging the fold ``models``.
    ``demographics`` is accepted for interface compatibility and not used.

    Returns:
        float32 numpy array of softmax probabilities ((1, C) for a single
        sequence). On any exception, logs a short message and returns a uniform
        (1, C) distribution.
    """
    try:
        imu, thm, tof = dataset.full_dataset.inference_process(sequence)
        with torch.no_grad():
            if torch.cuda.is_available():
                imu = imu.cuda(non_blocking=True)
                thm = thm.cuda(non_blocking=True)
                tof = tof.cuda(non_blocking=True)
            # Force IMU-only: replace THM/TOF with placeholder (scaled-NaN) tensors.
            try:
                _, thm, tof = dataset.full_dataset.get_scaled_nan_tensors(imu, thm, tof)
            except Exception as e:
                print(f"[WARN] get_scaled_nan_tensors: {_short_err(e)}")
            pred_y = avg_predict(models, imu, thm, tof)                  # (1, C)
            # Softmax in float32 so the success path matches the float32 fallback.
            probabilities = F.softmax(pred_y.float(), dim=1).detach().cpu().numpy()
        return probabilities  # (1, C)
    except Exception as e:
        print(f"[ERROR] predict3 failed: {_short_err(e)}")
        C = len(getattr(dataset.le, 'classes_', [])) if hasattr(dataset, 'le') and hasattr(dataset.le, 'classes_') else 18
        return np.ones((1, C), dtype=np.float32) / C
# ======== END: drop-in replacement ========


all sensor model2

In [None]:
"""gated-gru-hybrid-ensemble-v02.ipynb

    https://colab.research.google.com/drive/15f-PUIU6Tc6qYWYP6g7trekz1LypFFwW
"""

import os
import json
import joblib
import numpy as np
import pandas as pd
from pathlib import Path
import warnings
import random
import math
import matplotlib.pyplot as plt
import polars as pl
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import (
    Input, Conv1D, BatchNormalization, Activation, add, MaxPooling1D, Dropout,
    Bidirectional, GRU, GlobalAveragePooling1D, Dense, Multiply, Reshape,
    Lambda, Concatenate
)
from tensorflow.keras.optimizers import Adam as AdamTF
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import Sequence, to_categorical, pad_sequences
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers.schedules import CosineDecay

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam as AdamTorch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from sklearn.model_selection import StratifiedGroupKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from scipy.spatial.transform import Rotation as R
from scipy.signal import firwin

# 評価メトリクスはローカル検証/学習時にのみインポート
try:
    from cmi_2025_metric_copy_for_import import CompetitionMetric
except ImportError:
    CompetitionMetric = None
    print("CompetitionMetric could not be imported. OOF/CV score will not be calculated.")

def seed_everything(seed=42):
    """
    Seed every RNG used in this notebook (Python, NumPy, TensorFlow, PyTorch) from ``seed``.

    Also requests deterministic TensorFlow ops via environment variables. PyTorch
    cuDNN determinism flags are deliberately left off (they can reduce performance).
    """
    # Only affects child processes: the running interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    # Previously hard-coded to 2025, which silently ignored the `seed` argument.
    np.random.seed(seed)
    tf.random.set_seed(seed)
    tf.experimental.numpy.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
    os.environ['TF_DETERMINISTIC_OPS'] = '1'
    # torch.backends.cudnn.deterministic = True  # left disabled: may reduce performance
    # torch.backends.cudnn.benchmark = False

seed_everything(seed=42)
# Silence all Python warnings for the rest of the session.
warnings.filterwarnings("ignore")

# Training switch; printed below (its other uses are outside this section).
TRAIN = False

# --- Paths ---
RAW_DIR = Path("/kaggle/input/cmi-detect-behavior-with-sensor-data")
# Set YOUR_MODELS_DIR to the Kaggle dataset path that holds your own trained models.
YOUR_MODELS_DIR = Path("/kaggle/input/cmi-data-gated-gru") # ★★★ change to your own model path ★★★
PUBLIC_TF_MODEL_DIR = Path("/kaggle/input/lb-0-78-quaternions-tf-bilstm-gru-attention")
PUBLIC_PT_MODEL_DIR = Path("/kaggle/input/cmi3-models-p")
EXPORT_DIR = Path("./") # Where trained models and other artifacts are saved

# --- Model training hyperparameters ---
BATCH_SIZE = 64          # Batch size
PAD_PERCENTILE = 95      # Percentile of sequence lengths used to choose the padding length
LR_INIT = 4e-4           # Initial learning rate (fine-tuned)
WD = 3e-3                # Weight decay (L2 regularization) coefficient
MIXUP_ALPHA = 0.4        # Mixup alpha
EPOCHS = 360             # Maximum number of epochs (increased)
PATIENCE = 50            # EarlyStopping patience (increased)
N_SPLITS = 10             # Number of cross-validation folds
MASKING_PROB = 0.25      # Probability of masking TOF/THM data during training
GATE_LOSS_WEIGHT = 0.2   # Weight of the gate loss in the gated model

print(f"▶ ライブラリのインポート完了")
print(f"  - TensorFlow: {tf.__version__}")
print(f"  - PyTorch: {torch.__version__}")
print(f"▶ TRAINモード: {TRAIN}")

# Torch device for this section: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Standardization statistics for the PyTorch models: 32 per-channel means, reshaped
# to (1, 32, 1) so they broadcast over a channel-first (B, C, L) feature tensor.
# The first six entries are 0 here (and 1 in std_pt), so those channels are not rescaled.
# NOTE(review): both tensors are created on `device`; a model run on a different
# device would hit a device mismatch when these are applied.
mean_pt = torch.tensor([
    0, 0, 0, 0, 0, 0, 9.0319e-03, 1.0849e+00, -2.6186e-03, 3.7651e-03,
    -5.3660e-03, -2.8177e-03, 1.3318e-03, -1.5876e-04, 6.3495e-01,
    6.2877e-01, 6.0607e-01, 6.2142e-01, 6.3808e-01, 6.5420e-01,
    7.4102e-03, -3.4159e-03, -7.5237e-03, -2.6034e-02, 2.9704e-02,
    -3.1546e-02, -2.0610e-03, -4.6986e-03, -4.7216e-03, -2.6281e-02,
    1.5799e-02, 1.0016e-02
], dtype=torch.float32).view(1, -1, 1).to(device)

# Matching per-channel standard deviations; + 1e-8 guards against division by zero.
std_pt = torch.tensor([
    1, 1, 1, 1, 1, 1, 0.2067, 0.8583, 0.3162,
    0.2668, 0.2917, 0.2341, 0.3023, 0.3281, 1.0264, 0.8838, 0.8686, 1.0973,
    1.0267, 0.9018, 0.4658, 0.2009, 0.2057, 1.2240, 0.9535, 0.6655, 0.2941,
    0.3421, 0.8156, 0.6565, 1.1034, 1.5577
], dtype=torch.float32).view(1, -1, 1).to(device) + 1e-8

class ImuFeatureExtractor(nn.Module):
    """
    IMU feature extractor for the PyTorch two-branch models.

    Restored to the original definition so its parameters match the public model
    weights. The input is channel-first (B, C, L); only channels 0-2 (acc) and 3-5
    (gyro) are read.

    Output: (B, 32, L), channels in this order:

    * acc (3), gyro (3)
    * |acc| (1), |gyro| (1)
    * first differences of acc (3) and gyro (3), left-padded with one zero
    * acc**2 (3), gyro**2 (3)
    * acc low-pass (3), acc - low-pass (3)
    * gyro low-pass (3), gyro - low-pass (3)

    The low-pass parts are learnable depthwise 15-tap convolutions.
    ``fs`` and ``add_quaternion`` are stored but not used by ``forward``.
    """
    def __init__(self, fs=100., add_quaternion=False):
        super().__init__()
        self.fs = fs
        self.add_quaternion = add_quaternion

        taps = 15
        same_pad = taps // 2

        # Not used in forward(): kept (and initialized in this order) because the
        # public checkpoints contain an 'lpf' weight that must load.
        self.lpf = nn.Conv1d(6, 6, kernel_size=taps, padding=same_pad, groups=6, bias=False)
        nn.init.kaiming_uniform_(self.lpf.weight, a=math.sqrt(5))

        self.lpf_acc = nn.Conv1d(3, 3, taps, padding=same_pad, groups=3, bias=False)
        self.lpf_gyro = nn.Conv1d(3, 3, taps, padding=same_pad, groups=3, bias=False)

    def forward(self, imu):
        acc, gyro = imu[:, 0:3, :], imu[:, 3:6, :]

        def magnitude(t):
            return torch.norm(t, dim=1, keepdim=True)

        def first_diff(t):
            # Step-to-step difference, left-padded with one zero so the length is kept.
            return F.pad(t[:, :, 1:] - t[:, :, :-1], (1, 0))

        acc_low = self.lpf_acc(acc)
        gyro_low = self.lpf_gyro(gyro)

        return torch.cat(
            [
                acc, gyro,
                magnitude(acc), magnitude(gyro),
                first_diff(acc), first_diff(gyro),
                acc ** 2, gyro ** 2,
                acc_low, acc - acc_low,
                gyro_low, gyro - gyro_low,
            ],
            dim=1,
        )

class SEBlock(nn.Module):
    """Squeeze-and-excitation channel gate for (B, C, L) inputs (bias-free MLP).

    Global average pooling over L gives (B, C). This goes through
    Linear(C -> C // reduction) + ReLU + Linear(-> C) + sigmoid, and the
    resulting weights rescale every time step of each channel.
    """
    def __init__(self, channels, reduction=8):
        super().__init__()
        bottleneck = channels // reduction
        self.squeeze = nn.AdaptiveAvgPool1d(1)
        self.excitation = nn.Sequential(
            nn.Linear(channels, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, chans, _ = x.size()
        weights = self.excitation(self.squeeze(x).view(batch, chans))
        return x * weights.view(batch, chans, 1).expand_as(x)

class ResidualSECNNBlock(nn.Module):
    """Residual SE conv block followed by max-pooling and dropout.

    Main path: Conv-BN-ReLU, then Conv-BN, then SE gate. The skip path is a
    1x1 Conv + BN when the channel count changes, otherwise identity. The sum
    goes through ReLU, ``MaxPool1d(pool_size)`` and ``Dropout(dropout)``.
    Input and output are (B, C, L); L shrinks by ``pool_size`` (floor mode).
    """
    def __init__(self, in_channels, out_channels, kernel_size, pool_size=2, dropout=0.3):
        super().__init__()
        same_pad = kernel_size // 2  # keeps the length for odd kernel sizes
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, padding=same_pad, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, padding=same_pad, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.se = SEBlock(out_channels)
        if in_channels == out_channels:
            self.shortcut = nn.Sequential()  # empty Sequential acts as identity
        else:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm1d(out_channels),
            )
        self.pool = nn.MaxPool1d(pool_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.se(self.bn2(self.conv2(h)))
        h = h + self.shortcut(x)
        return self.dropout(self.pool(F.relu(h)))

class AttentionLayer(nn.Module):
    """Attention pooling over dim 1 of a (B, L, H) tensor.

    Each step gets a tanh(Linear(H -> 1)) score. Scores are softmax-normalized
    over L, and the weighted sum of the steps is returned as (B, H).
    """
    def __init__(self, hidden_dim):
        super().__init__()
        self.attention = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        alpha = F.softmax(self.attention(x).tanh().squeeze(-1), dim=1)   # (B, L)
        return (x * alpha.unsqueeze(-1)).sum(dim=1)                      # (B, H)

class TwoBranchModel(nn.Module):
    """IMU + TOF two-branch classifier: CNN branches -> BiLSTM -> attention -> MLP head.

    ``x`` is (B, L, fir_nchan + tof_dim), with ``fir_nchan`` = 7:

    * the first 7 columns are raw IMU channels, passed through
      ``ImuFeatureExtractor`` (32 channels when ``feature_engineering``);
    * the first 7 engineered channels are high-pass FIR filtered (33 taps,
      1 Hz cutoff at fs = 10);
    * the IMU tensor is standardized with the module-level ``mean_pt`` /
      ``std_pt``;
    * the remaining ``tof_dim`` columns feed a two-layer conv branch.

    ``dropouts`` holds seven rates, in order: imu_block1, imu_block2, tof stage 1,
    tof stage 2, LSTM output, dense1, dense2. It is a tuple default, so the
    default is never shared or mutated between instances; any sequence works.
    ``pad_len`` is accepted but not used. Extra ``kwargs`` go to
    ``ImuFeatureExtractor``.

    NOTE(review): with ``feature_engineering=False`` the IMU tensor has 7
    channels but ``mean_pt``/``std_pt`` have 32, so the standardization broadcast
    fails; that path appears unusable as written.
    NOTE(review): ``mean_pt``/``std_pt`` are globals on the notebook's ``device``;
    running the model on another device would raise a device mismatch.
    """
    def __init__(self, pad_len, imu_dim_raw, tof_dim, n_classes, dropouts=(0.3, 0.3, 0.3, 0.3, 0.4, 0.5, 0.3), feature_engineering=True, **kwargs):
        super().__init__()
        self.feature_engineering = feature_engineering
        imu_dim = 32 if feature_engineering else imu_dim_raw
        self.imu_fe = ImuFeatureExtractor(**kwargs) if feature_engineering else nn.Identity()
        self.fir_nchan = 7
        numtaps = 33
        # High-pass FIR (pass_zero=False), replicated once per filtered channel for a depthwise conv.
        fir_kernel = torch.tensor(firwin(numtaps, cutoff=1.0, fs=10.0, pass_zero=False), dtype=torch.float32).view(1, 1, -1).repeat(self.fir_nchan, 1, 1)
        self.register_buffer("fir_kernel", fir_kernel)
        self.imu_block1 = ResidualSECNNBlock(imu_dim, 64, 3, dropout=dropouts[0])
        self.imu_block2 = ResidualSECNNBlock(64, 128, 5, dropout=dropouts[1])
        self.tof_conv1 = nn.Conv1d(tof_dim, 64, 3, padding=1, bias=False)
        self.tof_bn1, self.tof_pool1, self.tof_drop1 = nn.BatchNorm1d(64), nn.MaxPool1d(2), nn.Dropout(dropouts[2])
        self.tof_conv2 = nn.Conv1d(64, 128, 3, padding=1, bias=False)
        self.tof_bn2, self.tof_pool2, self.tof_drop2 = nn.BatchNorm1d(128), nn.MaxPool1d(2), nn.Dropout(dropouts[3])
        # 128 IMU + 128 TOF channels in, bidirectional 2 x 128 out.
        self.bilstm = nn.LSTM(256, 128, bidirectional=True, batch_first=True)
        self.lstm_dropout = nn.Dropout(dropouts[4])
        self.attention = AttentionLayer(256)
        self.dense1, self.bn_dense1, self.drop1 = nn.Linear(256, 256, bias=False), nn.BatchNorm1d(256), nn.Dropout(dropouts[5])
        self.dense2, self.bn_dense2, self.drop2 = nn.Linear(256, 128, bias=False), nn.BatchNorm1d(128), nn.Dropout(dropouts[6])
        self.classifier = nn.Linear(128, n_classes)

    def forward(self, x):
        # Split (B, L, F) into raw IMU and TOF parts, both channel-first.
        imu_raw = x[:, :, :self.fir_nchan].transpose(1, 2)
        tof = x[:, :, self.fir_nchan:].transpose(1, 2)
        imu_fe = self.imu_fe(imu_raw)
        # Depthwise FIR over the first fir_nchan engineered channels; the rest pass through.
        filtered = F.conv1d(imu_fe[:, :self.fir_nchan, :], self.fir_kernel, padding=self.fir_kernel.shape[-1] // 2, groups=self.fir_nchan)
        imu = (torch.cat([filtered, imu_fe[:, self.fir_nchan:, :]], dim=1) - mean_pt) / std_pt
        x1 = self.imu_block1(imu); x1 = self.imu_block2(x1)
        x2 = self.tof_drop1(self.tof_pool1(F.relu(self.tof_bn1(self.tof_conv1(tof)))))
        x2 = self.tof_drop2(self.tof_pool2(F.relu(self.tof_bn2(self.tof_conv2(x2)))))
        # Both branches pool by 4 in total, so their time axes line up for concatenation.
        merged = torch.cat([x1, x2], dim=1).transpose(1, 2)
        lstm_out, _ = self.bilstm(merged); lstm_out = self.lstm_dropout(lstm_out)
        attended = self.attention(lstm_out)
        x = self.drop1(F.relu(self.bn_dense1(self.dense1(attended))))
        x = self.drop2(F.relu(self.bn_dense2(self.dense2(x))))
        return self.classifier(x)

class PublicTwoBranchModel(nn.Module):
    """
    Original architecture of the public PyTorch model (model group C), kept so its checkpoints can be loaded.

    Attribute names must not change: they are the ``state_dict`` keys the published
    checkpoints are loaded against with ``load_state_dict``.

    Args:
        pad_len: accepted for config compatibility; not used by any layer.
        imu_dim_raw: IMU channel count the first IMU block is built for when
            ``feature_engineering`` is False.
        tof_dim: number of ToF/thermal input channels (every channel after the
            first ``fir_nchan``).
        n_classes: number of output logits.
        dropouts: seven dropout rates, used in order by IMU block 1, IMU block 2,
            ToF conv 1, ToF conv 2, the BiLSTM output, dense 1 and dense 2.
        feature_engineering: when True, raw IMU channels go through
            ``ImuFeatureExtractor`` and the IMU branch is built for 32 channels;
            when False, ``nn.Identity`` is used instead.
        **kwargs: forwarded to ``ImuFeatureExtractor``.
    """
    def __init__(self, pad_len, imu_dim_raw, tof_dim, n_classes, dropouts=(0.3, 0.3, 0.3, 0.3, 0.4, 0.5, 0.3), feature_engineering=True, **kwargs):
        # dropouts default is a tuple (not a list) so the shared default can never be mutated;
        # callers may still pass any indexable sequence.
        super().__init__()
        self.feature_engineering = feature_engineering
        # NOTE(review): 32 presumably matches ImuFeatureExtractor's output channel count -- confirm.
        imu_dim = 32 if feature_engineering else imu_dim_raw
        self.imu_fe = ImuFeatureExtractor(**kwargs) if feature_engineering else nn.Identity()
        self.fir_nchan = 7
        numtaps = 33
        # 33-tap FIR high-pass (scipy firwin, pass_zero=False, cutoff 1.0 at fs=10.0),
        # repeated to shape (fir_nchan, 1, numtaps) for a per-channel grouped conv.
        fir_kernel = torch.tensor(firwin(numtaps, cutoff=1.0, fs=10.0, pass_zero=False), dtype=torch.float32).view(1, 1, -1).repeat(self.fir_nchan, 1, 1)
        self.register_buffer("fir_kernel", fir_kernel)
        self.imu_block1 = ResidualSECNNBlock(imu_dim, 64, 3, dropout=dropouts[0])
        self.imu_block2 = ResidualSECNNBlock(64, 128, 5, dropout=dropouts[1])
        self.tof_conv1 = nn.Conv1d(tof_dim, 64, 3, padding=1, bias=False)
        self.tof_bn1, self.tof_pool1, self.tof_drop1 = nn.BatchNorm1d(64), nn.MaxPool1d(2), nn.Dropout(dropouts[2])
        self.tof_conv2 = nn.Conv1d(64, 128, 3, padding=1, bias=False)
        self.tof_bn2, self.tof_pool2, self.tof_drop2 = nn.BatchNorm1d(128), nn.MaxPool1d(2), nn.Dropout(dropouts[3])
        self.bilstm = nn.LSTM(256, 128, bidirectional=True, batch_first=True)  # LSTM, not GRU
        self.lstm_dropout = nn.Dropout(dropouts[4])
        self.attention = AttentionLayer(256)  # 128*2 for bidirectional
        self.dense1, self.bn_dense1, self.drop1 = nn.Linear(256, 256, bias=False), nn.BatchNorm1d(256), nn.Dropout(dropouts[5])
        self.dense2, self.bn_dense2, self.drop2 = nn.Linear(256, 128, bias=False), nn.BatchNorm1d(128), nn.Dropout(dropouts[6])
        self.classifier = nn.Linear(128, n_classes)

    def forward(self, x):
        """Return class logits for ``x``.

        The first ``fir_nchan`` (7) channels of the last axis are the raw IMU input;
        the remaining channels feed the ToF/thermal branch.
        """
        # NOTE(review): x appears to be (batch, time, channels); it is transposed to
        # channels-first for the Conv1d layers. The split is always at fir_nchan, so with
        # feature_engineering=False the model only lines up when imu_dim_raw == 7 -- confirm.
        imu_raw = x[:, :, :self.fir_nchan].transpose(1, 2)
        tof = x[:, :, self.fir_nchan:].transpose(1, 2)
        imu_fe = self.imu_fe(imu_raw)
        # High-pass filter the first fir_nchan feature channels; half-kernel padding keeps the length.
        filtered = F.conv1d(imu_fe[:, :self.fir_nchan, :], self.fir_kernel, padding=self.fir_kernel.shape[-1] // 2, groups=self.fir_nchan)
        # mean_pt and std_pt are module-level globals defined earlier in the notebook.
        imu = (torch.cat([filtered, imu_fe[:, self.fir_nchan:, :]], dim=1) - mean_pt) / std_pt
        x1 = self.imu_block1(imu); x1 = self.imu_block2(x1)
        x2 = self.tof_drop1(self.tof_pool1(F.relu(self.tof_bn1(self.tof_conv1(tof)))))
        x2 = self.tof_drop2(self.tof_pool2(F.relu(self.tof_bn2(self.tof_conv2(x2)))))
        merged = torch.cat([x1, x2], dim=1).transpose(1, 2)
        lstm_out, _ = self.bilstm(merged); lstm_out = self.lstm_dropout(lstm_out)
        attended = self.attention(lstm_out)
        x = self.drop1(F.relu(self.bn_dense1(self.dense1(attended))))
        x = self.drop2(F.relu(self.bn_dense2(self.dense2(x))))
        return self.classifier(x)

def pad_sequences_torch3(sequences, maxlen, padding='post', truncating='post', value=0.0):
    """Pad or truncate 2-D sequences to ``maxlen`` rows and stack them into one float32 array.

    Args:
        sequences: iterable of array-likes shaped (length, features).
        maxlen: target number of rows per sequence.
        padding: 'post' appends fill rows, 'pre' prepends them.
        truncating: 'post' keeps the first ``maxlen`` rows, 'pre' keeps the last ``maxlen``.
        value: fill value for padded rows.

    Returns:
        np.ndarray of shape (len(sequences), maxlen, features), dtype float32.

    Raises:
        ValueError: if ``padding`` or ``truncating`` is not 'pre' or 'post'.
    """
    if padding not in ('pre', 'post'):
        raise ValueError(f"padding must be 'pre' or 'post', got {padding!r}")
    if truncating not in ('pre', 'post'):
        raise ValueError(f"truncating must be 'pre' or 'post', got {truncating!r}")
    result = []
    for seq in sequences:
        seq = np.asarray(seq)  # accept plain nested lists as well as arrays
        n = len(seq)
        if n >= maxlen:
            # Use an explicit start index: seq[-maxlen:] returns the whole sequence when maxlen == 0.
            seq = seq[:maxlen] if truncating == 'post' else seq[n - maxlen:]
        else:
            pad_array = np.full((maxlen - n, seq.shape[1]), value)
            seq = np.concatenate([seq, pad_array]) if padding == 'post' else np.concatenate([pad_array, seq])
        result.append(seq)
    return np.array(result, dtype=np.float32)

# =============================================================================
# ## 特徴量エンジニアリング関数
# =============================================================================
def remove_gravity_from_acc3(acc_data, rot_data):
    """Subtract gravity, rotated into the sensor frame, from each accelerometer sample.

    ``acc_data`` must provide acc_x/acc_y/acc_z columns and ``rot_data`` must provide
    rot_x/rot_y/rot_z/rot_w (scalar-last order, as scipy's ``Rotation.from_quat`` expects).
    Rows whose quaternion is entirely NaN, or that scipy rejects (ValueError/IndexError),
    keep the raw acceleration unchanged.
    """
    acc = acc_data[['acc_x', 'acc_y', 'acc_z']].values
    quats = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
    g_world = np.array([0, 0, 9.81])
    out = np.zeros_like(acc)
    for row in range(len(acc)):
        q = quats[row]
        out[row, :] = acc[row, :]  # default: raw reading
        if np.all(np.isnan(q)):
            continue
        try:
            g_sensor = R.from_quat(q).apply(g_world, inverse=True)
            out[row, :] = acc[row, :] - g_sensor
        except (ValueError, IndexError):
            out[row, :] = acc[row, :]
    return out

def calculate_angular_velocity_from_quat3(rot_data, time_delta=1/200):
    """Estimate angular velocity from consecutive quaternions by forward differences.

    Row ``t`` holds the rotation vector of ``r_t^{-1} * r_{t+1}`` divided by ``time_delta``.
    The last row, and rows where either quaternion is all-NaN or scipy rejects it
    (ValueError/IndexError), stay zero.
    """
    quats = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
    omega = np.zeros((len(quats), 3))
    for t, (q_now, q_next) in enumerate(zip(quats[:-1], quats[1:])):
        if np.all(np.isnan(q_now)) or np.all(np.isnan(q_next)):
            continue
        try:
            step = R.from_quat(q_now).inv() * R.from_quat(q_next)
            omega[t, :] = step.as_rotvec() / time_delta
        except (ValueError, IndexError):
            pass
    return omega

def calculate_angular_distance3(rot_data):
    """Return the rotation angle (radians) between each quaternion and the next one.

    Entry ``i`` is the norm of the rotation vector of ``r_i^{-1} * r_{i+1}``. The last
    entry, and entries where either quaternion is all-NaN or scipy rejects it
    (ValueError/IndexError), stay zero.
    """
    quats = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
    dist = np.zeros(len(quats))
    for i, (q_a, q_b) in enumerate(zip(quats[:-1], quats[1:])):
        if np.all(np.isnan(q_a)) or np.all(np.isnan(q_b)):
            continue
        try:
            step = R.from_quat(q_a).inv() * R.from_quat(q_b)
            dist[i] = np.linalg.norm(step.as_rotvec())
        except (ValueError, IndexError):
            pass
    return dist

def time_sum(x):
    """Sum ``x`` over axis 1 (the step axis where attention_layer uses it).

    Uses ``tf.reduce_sum`` instead of ``K.sum`` so it does not depend on a Keras
    backend alias ``K`` being imported; for this call the two are equivalent.
    """
    return tf.reduce_sum(x, axis=1)


def squeeze_last_axis(x):
    """Drop the trailing size-1 axis of ``x``."""
    return tf.squeeze(x, axis=-1)


def expand_last_axis(x):
    """Append a size-1 axis to ``x``."""
    return tf.expand_dims(x, axis=-1)

def se_block(x, reduction=8):
    """Squeeze-and-Excitation block: rescale each channel of ``x`` by a learned gate.

    Global-average-pools ``x`` over its steps axis, passes the result through
    Dense(ch // reduction, relu) -> Dense(ch, sigmoid), reshapes it to (1, ch) and
    multiplies ``x`` by it, so every channel is scaled by a value in (0, 1).
    """
    ch = x.shape[-1]
    se = GlobalAveragePooling1D()(x)
    # At least one bottleneck unit: for ch < reduction, ch // reduction is 0 and
    # would build a zero-width Dense layer. Unchanged for ch >= reduction.
    se = Dense(max(1, ch // reduction), activation='relu')(se)
    se = Dense(ch, activation='sigmoid')(se)
    se = Reshape((1, ch))(se)
    return Multiply()([x, se])

def residual_se_cnn_block(x, filters, kernel_size, pool_size=2, drop=0.3, wd=1e-4):
    """Residual SE-CNN block.

    Two Conv1D(filters, kernel_size, padding='same') -> BatchNorm -> ReLU stages and an
    SE block, added to a shortcut (projected with a 1x1 Conv1D + BatchNorm when its
    channel count differs from ``filters``), followed by ReLU, MaxPooling1D(pool_size)
    and Dropout(drop). All Conv1D kernels use an L2(wd) regularizer.

    NOTE: layers are created in a fixed order; changing it changes Keras's
    auto-generated layer names and weight ordering, which can matter when loading
    weights saved from a model built by this function.
    """
    shortcut = x
    # Two Conv1D stages
    for _ in range(2):
        x = Conv1D(filters, kernel_size, padding='same', use_bias=False, kernel_regularizer=l2(wd))(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    # SE block
    x = se_block(x)
    # Shortcut connection (projected when the channel counts differ)
    if shortcut.shape[-1] != filters:
        shortcut = Conv1D(filters, 1, padding='same', use_bias=False, kernel_regularizer=l2(wd))(shortcut)
        shortcut = BatchNormalization()(shortcut)
    x = add([x, shortcut])
    x = Activation('relu')(x)
    x = MaxPooling1D(pool_size)(x)
    x = Dropout(drop)(x)
    return x

def attention_layer(inputs):
    """Attention pooling over axis 1.

    Scores every step with Dense(1, tanh), squeezes the trailing axis, applies a
    softmax over the steps, and returns the attention-weighted sum of ``inputs``
    over axis 1 (via the time_sum Lambda).
    """
    score = Dense(1, activation='tanh')(inputs)
    score = Lambda(squeeze_last_axis)(score)
    weights = Activation('softmax')(score)
    weights = Lambda(expand_last_axis)(weights)  # restore trailing axis to broadcast over features
    context = Multiply()([inputs, weights])
    context = Lambda(time_sum)(context)
    return context

class GatedMixupGenerator(Sequence):
    """Keras batch generator with per-sample ToF/thermal masking and batch-level Mixup.

    Each batch is ``(X, {'main_output': y, 'tof_gate': gate_target}, sample_weights)``.
    With probability ``masking_prob`` a sample's channels from ``imu_dim`` onward are
    zeroed and its gate target set to 0 (otherwise 1). When ``alpha > 0``, inputs,
    labels, gate targets and sample weights are mixed with a random in-batch
    permutation using ``lam ~ Beta(alpha, alpha)``. ``class_weight`` (class index ->
    weight, missing classes -> 1.0) is applied via the argmax of the one-hot ``y``.
    """
    def __init__(self, X, y, batch_size, imu_dim, class_weight=None, alpha=0.2, masking_prob=0.0):
        # Keras 3 PyDataset requires the base initializer; harmless for older Sequence.
        super().__init__()
        self.X, self.y, self.batch, self.imu_dim = X, y, batch_size, imu_dim
        self.class_weight, self.alpha, self.masking_prob = class_weight, alpha, masking_prob
        self.indices = np.arange(len(X))
        # Shuffle up front: on_epoch_end only runs after an epoch, so without this the
        # first epoch's batches (and their Mixup partners) would be consecutive samples.
        self.on_epoch_end()

    def __len__(self):
        return int(np.ceil(len(self.X) / self.batch))

    def __getitem__(self, i):
        idx = self.indices[i*self.batch:(i+1)*self.batch]
        Xb, yb = self.X[idx].copy(), self.y[idx].copy()

        sample_weights = np.ones(len(Xb), dtype='float32')
        if self.class_weight:
            sample_weights = np.array(
                [self.class_weight.get(cls, 1.0) for cls in yb.argmax(axis=1)], dtype='float32'
            )

        gate_target = np.ones(len(Xb), dtype='float32')
        if self.masking_prob > 0:
            # Drop the non-IMU sensors for a random subset of samples and tell the gate they are absent.
            masked = np.random.rand(len(Xb)) < self.masking_prob
            Xb[masked, :, self.imu_dim:] = 0
            gate_target[masked] = 0.0

        if self.alpha > 0:
            lam = np.random.beta(self.alpha, self.alpha)
            perm = np.random.permutation(len(Xb))
            X_mix = lam * Xb + (1 - lam) * Xb[perm]
            y_mix = lam * yb + (1 - lam) * yb[perm]
            gate_target_mix = lam * gate_target + (1 - lam) * gate_target[perm]
            sample_weights_mix = lam * sample_weights + (1 - lam) * sample_weights[perm]
            return X_mix, {'main_output': y_mix, 'tof_gate': gate_target_mix}, sample_weights_mix

        return Xb, {'main_output': yb, 'tof_gate': gate_target}, sample_weights

    def on_epoch_end(self):
        np.random.shuffle(self.indices)

def build_gated_two_branch_model(pad_len, imu_dim, tof_dim, n_classes, wd=1e-4):
    """
    Build the custom gated two-branch Keras model.

    Changes from the base model: the LSTM is replaced by a GRU, and one extra fully
    connected layer is added.

    Input: (pad_len, imu_dim + tof_dim); the first ``imu_dim`` channels feed a deep
    residual SE-CNN branch, the rest a light Conv1D branch scaled by a learned
    sigmoid gate. Outputs: ``[main_output (softmax over n_classes), tof_gate]``.

    NOTE: layers are created in a fixed order; changing it changes Keras's
    auto-generated layer names and weight ordering.
    """
    inp = Input(shape=(pad_len, imu_dim + tof_dim))
    # NOTE(review): these Lambdas close over imu_dim; Keras serializes such Python
    # lambdas as bytecode, so reloading a saved model is Python-version sensitive.
    imu = Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = Lambda(lambda t: t[:, :, imu_dim:])(inp)

    # IMU branch (deep)
    x1 = residual_se_cnn_block(imu, 64, 3, drop=0.1, wd=wd)
    x1 = residual_se_cnn_block(x1, 128, 5, drop=0.1, wd=wd)

    # TOF/THM branch (light) with gating
    x2_base = Conv1D(64, 3, padding='same', use_bias=False, kernel_regularizer=l2(wd))(tof)
    x2_base = BatchNormalization()(x2_base); x2_base = Activation('relu')(x2_base)
    x2_base = MaxPooling1D(2)(x2_base); x2_base = Dropout(0.2)(x2_base)
    x2_base = Conv1D(128, 3, padding='same', use_bias=False, kernel_regularizer=l2(wd))(x2_base)
    x2_base = BatchNormalization()(x2_base); x2_base = Activation('relu')(x2_base)
    x2_base = MaxPooling1D(2)(x2_base); x2_base = Dropout(0.2)(x2_base)

    # Gating mechanism: one sigmoid value per sample, computed from the raw TOF/THM input
    gate_input = GlobalAveragePooling1D()(tof)
    gate_input = Dense(16, activation='relu')(gate_input)
    gate = Dense(1, activation='sigmoid', name='tof_gate')(gate_input)
    # NOTE(review): relies on Keras Multiply broadcasting the per-sample gate over the
    # (steps, channels) of x2_base -- confirm on the Keras version in use.
    x2 = Multiply()([x2_base, gate])

    # Merge the branches, then the sequence model and dense head
    merged = Concatenate()([x1, x2])
    # Change: LSTM -> GRU
    x = Bidirectional(GRU(256, return_sequences=True, kernel_regularizer=l2(wd)))(merged)
    x = Dropout(0.45)(x)
    x = attention_layer(x)

    # Change: one extra fully connected layer for more capacity
    for units, drop in [(512, 0.5), (256, 0.4), (128, 0.3)]:
        x = Dense(units, use_bias=False, kernel_regularizer=l2(wd))(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = Dropout(drop)(x)

    out = Dense(n_classes, activation='softmax', name='main_output', kernel_regularizer=l2(wd))(x)

    return Model(inputs=inp, outputs=[out, gate])

# -----------------------------------------------------------------------------
# ### Inference mode (`TRAIN = False`)
# Loads three model groups into module-level globals read by predict4:
#   A: custom TF/Keras fold models, B: one public TF/Keras model,
#   C: public PyTorch fold models -- each with its feature column list,
#   pad length and fitted scaler.
# NOTE(review): YOUR_MODELS_DIR, PUBLIC_TF_MODEL_DIR, PUBLIC_PT_MODEL_DIR, N_SPLITS,
# joblib, load_model, firwin and device must be defined by earlier cells.
# -----------------------------------------------------------------------------

print("▶ 推論モード開始 – 学習済みモデルとアーティファクトを読み込みます...")

# --- Model group A (custom TF/Keras models): load artifacts and fold models ---
print("  モデル群A (自作5-Fold Gated GRUモデル) を読み込み中...")
final_feature_cols_A = np.load(YOUR_MODELS_DIR / "final_feature_cols.npy", allow_pickle=True).tolist()
pad_len_A = int(np.load(YOUR_MODELS_DIR / "sequence_maxlen.npy"))
scaler_A = joblib.load(YOUR_MODELS_DIR / "scaler.pkl")
gesture_classes = np.load(YOUR_MODELS_DIR / "gesture_classes.npy", allow_pickle=True)
# Names referenced by the saved models' Lambda layers / builder functions.
custom_objs_A = {'time_sum': time_sum, 'squeeze_last_axis': squeeze_last_axis, 'expand_last_axis': expand_last_axis,
                 'se_block': se_block, 'residual_se_cnn_block': residual_se_cnn_block, 'attention_layer': attention_layer}
models_A = [load_model(YOUR_MODELS_DIR / f"final_model_fold_{f}.h5", compile=False, custom_objects=custom_objs_A) for f in range(N_SPLITS)]
print(f"  > {len(models_A)}個のモデルを正常に読み込みました。")

# --- Model group B (public TF/Keras model): load ---
print("\n  モデル群B (公開TF/Kerasモデル) を読み込み中...")
final_feature_cols_B = np.load(PUBLIC_TF_MODEL_DIR / "feature_cols.npy", allow_pickle=True).tolist()
pad_len_B = int(np.load(PUBLIC_TF_MODEL_DIR / "sequence_maxlen.npy"))
scaler_B = joblib.load(PUBLIC_TF_MODEL_DIR / "scaler.pkl")
custom_objs_B = custom_objs_A # the public model uses the same custom objects (same dict object)
model_B = load_model(PUBLIC_TF_MODEL_DIR / "gesture_two_branch_mixup.h5", compile=False, custom_objects=custom_objs_B)
print("  > 1個のモデルを正常に読み込みました。")

# --- Model group C (public PyTorch models): load ---
print("\n  モデル群C (公開PyTorchモデル) を読み込み中...")
final_feature_cols_C = np.load(PUBLIC_PT_MODEL_DIR / "feature_cols.npy", allow_pickle=True).tolist()
pad_len_C = int(np.load(PUBLIC_PT_MODEL_DIR / "sequence_maxlen.npy"))
scaler_C = joblib.load(PUBLIC_PT_MODEL_DIR / "scaler.pkl")

pt_models = []
# Five fold checkpoints (fold count hard-coded); layer sizes come from each checkpoint.
for f in range(5):
    checkpoint = torch.load(PUBLIC_PT_MODEL_DIR / f"gesture_two_branch_fold{f}.pth", map_location=device)
    cfg = {'pad_len': checkpoint['pad_len'], 'imu_dim_raw': checkpoint['imu_dim'],
           'tof_dim': checkpoint['tof_dim'], 'n_classes': checkpoint['n_classes']}
    m = PublicTwoBranchModel(**cfg).to(device)
    m.load_state_dict(checkpoint['model_state_dict'])
    m.eval()  # inference mode: disables dropout, uses BatchNorm running stats
    pt_models.append(m)
print(f"  > {len(pt_models)}個のモデルを正常に読み込みました。")

# predict_4

# --- `predict` function definition ---
def _p4_add_tof_stats(df):
    """Add per-sensor ToF summary columns to ``df`` in place (helper for predict4).

    For each sensor i in 1..5 whose ``tof_{i}_v0`` column exists, adds
    ``tof_{i}_mean/std/min/max`` computed row-wise over its 64 pixel columns,
    treating the value -1 as missing (NaN).
    """
    for i in range(1, 6):
        if f'tof_{i}_v0' not in df.columns:
            continue
        pixel_cols = [f"tof_{i}_v{p}" for p in range(64)]
        tof_data = df[pixel_cols].replace(-1, np.nan)
        df[f'tof_{i}_mean'] = tof_data.mean(axis=1)
        df[f'tof_{i}_std'] = tof_data.std(axis=1)
        df[f'tof_{i}_min'] = tof_data.min(axis=1)
        df[f'tof_{i}_max'] = tof_data.max(axis=1)


def predict4(sequence: pl.DataFrame, demographics: pl.DataFrame) -> np.ndarray:
    """Return blended class probabilities for one sequence from model groups A, B and C.

    This does not return a gesture label. It returns the weighted sum
    ``0.5 * A + 0.2 * B + 0.3 * C`` of the three groups' probability outputs,
    shaped ``(1, n_classes)``.
    NOTE(review): assumes all three groups use the same class order
    (presumably ``gesture_classes``) -- confirm for the public models B and C.

    ``demographics`` is accepted for interface compatibility but unused. Reads the
    module-level models, feature-column lists, scalers, pad lengths and ``device``.
    """
    df_seq_orig = sequence.to_pandas()
    acc_cols = ['acc_x', 'acc_y', 'acc_z']
    quat_cols = ['rot_x', 'rot_y', 'rot_z', 'rot_w']

    # Quaternion-derived features shared by groups A and B. They depend only on the
    # untouched acc/rot columns, so the per-row scipy loops run once instead of twice.
    linear_accel = remove_gravity_from_acc3(df_seq_orig[acc_cols], df_seq_orig[quat_cols])
    angular_vel = calculate_angular_velocity_from_quat3(df_seq_orig[quat_cols])
    angular_dist = calculate_angular_distance3(df_seq_orig[quat_cols])

    # --- 1. Model group A (custom TF/Keras models) ---
    df_seq_A = df_seq_orig.copy()
    df_seq_A['linear_acc_x'] = linear_accel[:, 0]
    df_seq_A['linear_acc_y'] = linear_accel[:, 1]
    df_seq_A['linear_acc_z'] = linear_accel[:, 2]
    df_seq_A['linear_acc_mag'] = np.linalg.norm(linear_accel, axis=1)
    df_seq_A['linear_acc_mag_jerk'] = df_seq_A['linear_acc_mag'].diff().fillna(0)
    df_seq_A['angular_vel_x'] = angular_vel[:, 0]
    df_seq_A['angular_vel_y'] = angular_vel[:, 1]
    df_seq_A['angular_vel_z'] = angular_vel[:, 2]
    df_seq_A['angular_distance'] = angular_dist
    for col in quat_cols:
        df_seq_A[f'{col}_diff'] = df_seq_A[col].diff().fillna(0)
    # Whole-sequence skew/kurtosis, broadcast to every row as constant columns.
    for col in ['linear_acc_mag', 'linear_acc_mag_jerk', 'angular_distance']:
        df_seq_A[f'{col}_skew'] = df_seq_A[col].skew()
        df_seq_A[f'{col}_kurt'] = df_seq_A[col].kurtosis()
    _p4_add_tof_stats(df_seq_A)
    tof_mean_cols = [f'tof_{i}_mean' for i in range(1, 6) if f'tof_{i}_mean' in df_seq_A.columns]
    if tof_mean_cols:
        df_seq_A['tof_std_across_sensors'] = df_seq_A[tof_mean_cols].std(axis=1)
        df_seq_A['tof_range_across_sensors'] = df_seq_A[tof_mean_cols].max(axis=1) - df_seq_A[tof_mean_cols].min(axis=1)
    thm_cols = [f'thm_{i}' for i in range(1, 6) if f'thm_{i}' in df_seq_A.columns]
    if thm_cols:
        df_seq_A['thm_std_across_sensors'] = df_seq_A[thm_cols].std(axis=1)
        df_seq_A['thm_range_across_sensors'] = df_seq_A[thm_cols].max(axis=1) - df_seq_A[thm_cols].min(axis=1)
    mat_A = df_seq_A[final_feature_cols_A].ffill().bfill().fillna(0).values.astype('float32')
    mat_A = scaler_A.transform(mat_A)
    pad_input_A = pad_sequences([mat_A], maxlen=pad_len_A, padding='post', dtype='float32')
    # Output 0 of each model is the class-probability head (presumably [main_output, tof_gate]).
    preds_A_folds = [model.predict(pad_input_A, verbose=0)[0] for model in models_A]
    avg_pred_A = np.mean(preds_A_folds, axis=0)

    # --- 2. Model group B (public TF/Keras model) ---
    df_seq_B = df_seq_orig.copy()
    df_seq_B['acc_mag'] = np.sqrt(df_seq_B['acc_x']**2 + df_seq_B['acc_y']**2 + df_seq_B['acc_z']**2)
    df_seq_B['rot_angle'] = 2 * np.arccos(df_seq_B['rot_w'].clip(-1, 1))
    df_seq_B['acc_mag_jerk'] = df_seq_B['acc_mag'].diff().fillna(0)
    df_seq_B['rot_angle_vel'] = df_seq_B['rot_angle'].diff().fillna(0)
    df_seq_B['linear_acc_x'] = linear_accel[:, 0]
    df_seq_B['linear_acc_y'] = linear_accel[:, 1]
    df_seq_B['linear_acc_z'] = linear_accel[:, 2]
    # B's magnitude uses an explicit sqrt of squares (A uses np.linalg.norm); kept as-is
    # so B's features are computed exactly as before.
    df_seq_B['linear_acc_mag'] = np.sqrt(df_seq_B['linear_acc_x']**2 + df_seq_B['linear_acc_y']**2 + df_seq_B['linear_acc_z']**2)
    df_seq_B['linear_acc_mag_jerk'] = df_seq_B['linear_acc_mag'].diff().fillna(0)
    df_seq_B['angular_vel_x'] = angular_vel[:, 0]
    df_seq_B['angular_vel_y'] = angular_vel[:, 1]
    df_seq_B['angular_vel_z'] = angular_vel[:, 2]
    df_seq_B['angular_distance'] = angular_dist
    _p4_add_tof_stats(df_seq_B)
    mat_B = df_seq_B[final_feature_cols_B].ffill().bfill().fillna(0).values.astype('float32')
    mat_B = scaler_B.transform(mat_B)
    pad_input_B = pad_sequences([mat_B], maxlen=pad_len_B, padding='post', dtype='float32')
    pred_B = model_B.predict(pad_input_B, verbose=0)
    # A multi-output model returns a list; keep only the class-probability output.
    if isinstance(pred_B, list):
        pred_B = pred_B[0]

    # --- 3. Model group C (public PyTorch models): raw columns, no extra features ---
    mat_C = df_seq_orig[final_feature_cols_C].ffill().bfill().fillna(0).values.astype('float32')
    mat_C = scaler_C.transform(mat_C)
    pad_input_C = pad_sequences_torch3([mat_C], maxlen=pad_len_C, padding='pre', truncating='pre')
    with torch.no_grad():
        pt_input = torch.from_numpy(pad_input_C).to(device)
        # Average the folds' logits first, then apply a single softmax.
        preds_C_folds = [model(pt_input) for model in pt_models]
        avg_pred_C_logits = torch.mean(torch.stack(preds_C_folds), dim=0)
        avg_pred_C = torch.softmax(avg_pred_C_logits, dim=1).cpu().numpy()

    # --- 4. Weighted blend of the three groups' probabilities ---
    weights = {'A': 0.50, 'B': 0.20, 'C': 0.30}
    return weights['A'] * avg_pred_A + weights['B'] * pred_B + weights['C'] * avg_pred_C

all sensor model3

In [None]:
import os, random, math
import pandas as pd, numpy as np, polars as pl
from pathlib import Path
from scipy.spatial.transform import Rotation as R
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import f1_score
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Suppress ALL TensorFlow logs
import tensorflow as tf
tf.get_logger().setLevel('ERROR')  # Only show errors 
from tensorflow.keras.utils import to_categorical, pad_sequences
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import (
    Input, Conv1D, BatchNormalization, Activation, add, MaxPooling1D, 
    Dropout, Bidirectional, LSTM, GlobalAveragePooling1D, Dense, Multiply,
    Reshape, Lambda, Concatenate, GRU, GaussianNoise
)
from tensorflow.keras.optimizers.schedules import CosineDecay, CosineDecayRestarts, ExponentialDecay
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
import joblib
import warnings
warnings.filterwarnings('ignore')

# --- Run mode and filesystem locations ---
TRAIN = False
BASE_DIR = Path("/kaggle/input/cmi-detect-behavior-with-sensor-data")
PRETRAINED_DIR = Path("/kaggle/input/artifact0")
EXPORT_DIR = Path("/kaggle/working")

# --- Reproducibility ---
SEED = 42

# --- Cross-validation and stopping ---
N_SPLITS = 10
EPOCHS = 160
PATIENCE = 40

# --- Batching and sequence padding ---
BATCH_SIZE = 64
PAD_PERCENTILE = 95

# --- Optimization ---
LR_INIT = 5e-4
WD = 3e-3
USE_LR_SCHEDULER = True
LR_SCHEDULE_TYPE = "cosine_decay_restarts"  # Options: "cosine_decay", "cosine_decay_restarts", "exponential_decay"

# --- Augmentation and auxiliary gate loss ---
MIXUP_ALPHA = 0.4
MASKING_PROB = 0.35
GATE_LOSS_WEIGHT = 0.2


def seed_everything(seed):
    """Seed Python, NumPy and TensorFlow RNGs and set the TF/cuDNN determinism env flags."""
    # Only interpreters started after this point (e.g. subprocesses) see PYTHONHASHSEED;
    # the running process's hash randomization is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Python, NumPy, TensorFlow and TF-NumPy generators, in that order.
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
    tf.experimental.numpy.random.seed(seed)
    # NOTE(review): whether TF honours these when set after import depends on the TF version.
    os.environ.update(TF_CUDNN_DETERMINISTIC='1', TF_DETERMINISTIC_OPS='1')


seed_everything(SEED)

def remove_gravity_from_acc(acc_data, rot_data):
    """Subtract gravity, rotated into the sensor frame, from each accelerometer sample.

    ``acc_data`` must provide acc_x/acc_y/acc_z columns and ``rot_data`` must provide
    rot_x/rot_y/rot_z/rot_w (scalar-last, as scipy's ``Rotation.from_quat`` expects).
    Rows whose quaternion is all-NaN, all (close to) zero, or rejected by scipy with a
    ValueError keep the raw acceleration.
    """
    acc = acc_data[['acc_x', 'acc_y', 'acc_z']].values
    quats = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
    g_world = np.array([0, 0, 9.81])
    out = np.zeros_like(acc)

    for row in range(len(acc)):
        q = quats[row]
        out[row, :] = acc[row, :]  # default: raw reading
        if np.all(np.isnan(q)) or np.all(np.isclose(q, 0)):
            continue
        try:
            g_sensor = R.from_quat(q).apply(g_world, inverse=True)
            out[row, :] = acc[row, :] - g_sensor
        except ValueError:
            out[row, :] = acc[row, :]
    return out

def calculate_angular_velocity_from_quat(rot_data, time_delta=1/200):
    """Estimate angular velocity from consecutive quaternions by forward differences.

    Row ``t`` holds the rotation vector of ``r_t^{-1} * r_{t+1}`` divided by ``time_delta``.
    The last row, and rows where either quaternion is all-NaN or scipy raises a
    ValueError, stay zero.
    """
    quats = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
    omega = np.zeros((len(quats), 3))

    for t, (q_now, q_next) in enumerate(zip(quats[:-1], quats[1:])):
        if np.all(np.isnan(q_now)) or np.all(np.isnan(q_next)):
            continue
        try:
            step = R.from_quat(q_now).inv() * R.from_quat(q_next)
            omega[t, :] = step.as_rotvec() / time_delta
        except ValueError:
            pass
    return omega

def calculate_angular_distance(rot_data):
    """Return the rotation angle (radians) between each quaternion and the next one.

    Entry ``i`` is the norm of the rotation vector of ``r_i^{-1} * r_{i+1}``. The last
    entry, and entries where either quaternion is all-NaN or scipy raises a
    ValueError, stay zero.
    """
    quats = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
    dist = np.zeros(len(quats))

    for i, (q_a, q_b) in enumerate(zip(quats[:-1], quats[1:])):
        if np.all(np.isnan(q_a)) or np.all(np.isnan(q_b)):
            continue
        try:
            step = R.from_quat(q_a).inv() * R.from_quat(q_b)
            dist[i] = np.linalg.norm(step.as_rotvec())
        except ValueError:
            pass
    return dist

def cmi_metric(y_true_gestures, y_pred_gestures, bfrb_gestures=None):
    """Compute the CMI composite score from gesture labels.

    The score is the mean of:
      * binary F1 of "is a BFRB gesture" vs not (positive label 'target'), and
      * macro F1 over labels where every non-BFRB gesture is collapsed into 'non_target'.

    Args:
        y_true_gestures, y_pred_gestures: equal-length sequences of gesture names.
        bfrb_gestures: collection of the BFRB (target) gesture names. Required; the
            default ``None`` exists only to keep the signature unchanged.

    Returns:
        dict with keys 'binary_f1', 'macro_f1' and 'composite_score'.

    Raises:
        ValueError: if ``bfrb_gestures`` is None.
    """
    if bfrb_gestures is None:
        raise ValueError("cmi_metric requires bfrb_gestures (the collection of target gesture names)")
    bfrb = set(bfrb_gestures)  # O(1) membership for the per-sample checks below

    y_true_gestures = np.array(y_true_gestures)
    y_pred_gestures = np.array(y_pred_gestures)

    y_true_binary = np.array(['target' if gesture in bfrb else 'non_target' for gesture in y_true_gestures])
    y_pred_binary = np.array(['target' if gesture in bfrb else 'non_target' for gesture in y_pred_gestures])
    binary_f1 = f1_score(y_true_binary, y_pred_binary, pos_label='target')

    # Keep BFRB gestures as-is and merge all other gestures into one 'non_target' class.
    y_true_collapsed = np.array([gesture if gesture in bfrb else 'non_target' for gesture in y_true_gestures])
    y_pred_collapsed = np.array([gesture if gesture in bfrb else 'non_target' for gesture in y_pred_gestures])
    macro_f1 = f1_score(y_true_collapsed, y_pred_collapsed, average='macro')

    composite_score = (binary_f1 + macro_f1) / 2.0
    return {
        'binary_f1': binary_f1,
        'macro_f1': macro_f1,
        'composite_score': composite_score
    }

def evaluate_with_cmi_metric(model, X_val, y_val_gestures, gesture_classes, bfrb_gestures):
    """Score a Keras model's predictions on ``X_val`` with cmi_metric.

    Output 0 of ``model.predict`` is taken as class probabilities; argmax indices are
    mapped to names through ``gesture_classes`` (an indexable array) before scoring.
    Returns cmi_metric's dict.
    """
    class_probs = model.predict(X_val, verbose=0)[0]
    predicted_labels = gesture_classes[np.argmax(class_probs, axis=1)]
    return cmi_metric(y_val_gestures, predicted_labels, bfrb_gestures)


def get_individual_gesture_scores(model, X_val, y_val_gestures, gesture_classes, bfrb_gestures):
    """Calculate a one-vs-rest F1 score for each gesture present in ``y_val_gestures``.

    Output 0 of ``model.predict`` is taken as class probabilities and mapped to
    names via ``gesture_classes``. ``y_val_gestures`` and ``gesture_classes`` may be
    arrays or plain lists. ``bfrb_gestures`` is accepted for signature parity with
    the other evaluators but is not used.

    Returns:
        dict mapping gesture name -> F1 (0 when undefined).
    """
    predictions = model.predict(X_val, verbose=0)[0]
    # Convert to arrays: element-wise `==` and fancy indexing below fail on plain lists.
    y_val_gestures = np.asarray(y_val_gestures)
    pred_gestures = np.asarray(gesture_classes)[predictions.argmax(axis=1)]

    gesture_scores = {}
    for gesture in np.unique(y_val_gestures):
        y_true_binary = (y_val_gestures == gesture).astype(int)
        y_pred_binary = (pred_gestures == gesture).astype(int)
        # Guard kept for labels that never compare equal to themselves (e.g. NaN).
        if y_true_binary.sum() > 0:
            gesture_scores[gesture] = f1_score(y_true_binary, y_pred_binary, zero_division=0)

    return gesture_scores

def evaluate_dual_cmi_metric(model, X_val, y_val_gestures, gesture_classes, bfrb_gestures, imu_dim):
    """Evaluate with all sensors and with only the IMU channels, and combine the results.

    The IMU-only pass zeroes ``X_val[:, :, imu_dim:]`` on a copy. The top-level
    'composite_score', 'binary_f1' and 'macro_f1' are the means of the two passes;
    the per-pass values are also returned, together with 'sensor_dependency'
    (full minus IMU-only composite) and 'performance_stability'
    (1 - dependency / max(full composite, 0.01)).
    """
    full = evaluate_with_cmi_metric(model, X_val, y_val_gestures, gesture_classes, bfrb_gestures)

    X_imu_only = X_val.copy()
    X_imu_only[:, :, imu_dim:] = 0.0
    imu = evaluate_with_cmi_metric(model, X_imu_only, y_val_gestures, gesture_classes, bfrb_gestures)

    # (short name used in the per-pass keys, key in cmi_metric's result)
    metric_keys = (('composite', 'composite_score'), ('binary', 'binary_f1'), ('macro', 'macro_f1'))
    dependency = full['composite_score'] - imu['composite_score']

    result = {key: (full[key] + imu[key]) / 2.0 for _, key in metric_keys}
    result.update({f'full_sensor_{short}': full[key] for short, key in metric_keys})
    result.update({f'imu_only_{short}': imu[key] for short, key in metric_keys})
    result['sensor_dependency'] = dependency
    result['performance_stability'] = 1.0 - (dependency / max(full['composite_score'], 0.01))
    return result

class IMUSpecificScaler:
    """Standardize IMU and non-IMU feature columns with two separate StandardScalers.

    Columns ``[:imu_dim]`` of a 2-D feature matrix are scaled by ``imu_scaler`` and
    columns ``[imu_dim:]`` by ``tof_scaler``; ``transform`` concatenates them back in
    the same column order.
    """

    def __init__(self):
        self.imu_scaler = StandardScaler()
        self.tof_scaler = StandardScaler()
        self.imu_dim = None  # number of leading IMU columns; set by fit()

    def fit(self, X, imu_dim):
        """Fit both scalers on ``X`` split at column ``imu_dim``; returns self."""
        self.imu_dim = imu_dim
        self.imu_scaler.fit(X[:, :imu_dim])
        self.tof_scaler.fit(X[:, imu_dim:])
        return self

    def transform(self, X):
        """Scale ``X`` with the fitted scalers and return the re-joined array.

        Raises:
            ValueError: if called before ``fit``.
        """
        if self.imu_dim is None:
            raise ValueError("IMUSpecificScaler.transform() called before fit()")
        X_imu = self.imu_scaler.transform(X[:, :self.imu_dim])
        X_tof = self.tof_scaler.transform(X[:, self.imu_dim:])
        return np.concatenate([X_imu, X_tof], axis=1)

    def fit_transform(self, X, imu_dim):
        """Convenience for ``fit(X, imu_dim).transform(X)``."""
        return self.fit(X, imu_dim).transform(X)
        
class EnhancedCMIMetricCallback(tf.keras.callbacks.Callback):
    """Early stopping on a "realistic" CMI score averaging full-sensor and IMU-only evaluation.

    After every epoch, runs evaluate_dual_cmi_metric, records the scores in ``logs``
    under 'val_realistic_composite', 'val_full_sensor_composite',
    'val_imu_only_composite' and 'val_sensor_dependency', keeps a copy of the weights
    with the best averaged composite score, and stops training after ``patience``
    epochs without improvement. The best weights are restored when training ends.
    """
    def __init__(self, X_val, y_val_gestures, gesture_classes, bfrb_gestures, imu_dim, patience=40, verbose=1):
        super().__init__()
        self.X_val = X_val
        self.y_val_gestures = y_val_gestures
        self.gesture_classes = gesture_classes
        self.bfrb_gestures = bfrb_gestures
        self.imu_dim = imu_dim
        self.patience = patience
        self.verbose = verbose
        self.best_score = -np.inf
        self.wait = 0
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # Reset early-stopping state so a reused callback starts fresh on every fit(),
        # as keras.callbacks.EarlyStopping does.
        self.best_score = -np.inf
        self.wait = 0
        self.best_weights = None

    def on_epoch_end(self, epoch, logs=None):
        dual_scores = evaluate_dual_cmi_metric(
            self.model, self.X_val, self.y_val_gestures,
            self.gesture_classes, self.bfrb_gestures, self.imu_dim
        )

        realistic_composite = dual_scores['composite_score']

        # Only substitute a new dict when Keras passed None: `logs or {}` would also
        # replace an (empty) dict Keras passed in, silently dropping these metrics.
        if logs is None:
            logs = {}
        logs['val_realistic_composite'] = realistic_composite
        logs['val_full_sensor_composite'] = dual_scores['full_sensor_composite']
        logs['val_imu_only_composite'] = dual_scores['imu_only_composite']
        logs['val_sensor_dependency'] = dual_scores['sensor_dependency']

        # Silent progress: one dot per epoch
        if self.verbose > 0:
            print('.', end='', flush=True)

        if realistic_composite > self.best_score:
            self.best_score = realistic_composite
            self.wait = 0
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1

        if self.wait >= self.patience:
            self.model.stop_training = True

    def on_train_end(self, logs=None):
        if self.best_weights is not None:
            self.model.set_weights(self.best_weights)
        
def create_detailed_confusion_analysis(models, X_val_all, y_val_all, le_classes, bfrb_gestures):
    """Create detailed confusion matrix and misclassification analysis.

    Averages output 0 of ``model.predict`` over ``models``, compares the argmax with
    the argmax of one-hot ``y_val_all``, saves and shows a heatmap at
    ``EXPORT_DIR / 'confusion_matrix.png'``, and prints misclassification, per-gesture
    recall and worst-BFRB-gesture summaries.

    Returns:
        (cm, misclassifications, gesture_accuracy):
        ``cm`` is a len(le_classes) x len(le_classes) confusion matrix whose row/column
        i is ``le_classes[i]``; ``misclassifications`` lists the off-diagonal entries
        sorted by count (descending); ``gesture_accuracy`` maps gesture -> fraction of
        its true samples predicted correctly (0 when it has none).
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    from sklearn.metrics import confusion_matrix

    # Get ensemble predictions
    all_predictions = [model.predict(X_val_all, verbose=0)[0] for model in models]
    ensemble_pred = np.mean(all_predictions, axis=0)
    pred_classes = ensemble_pred.argmax(axis=1)
    true_classes = y_val_all.argmax(axis=1)

    # Pin the label set: without `labels`, classes absent from both y_true and y_pred are
    # dropped from the matrix, so cm[i, j] below would be mislabelled or out of range.
    cm = confusion_matrix(true_classes, pred_classes, labels=np.arange(len(le_classes)))

    print("\n" + "="*80)
    print("DETAILED CONFUSION MATRIX ANALYSIS")
    print("="*80)

    # 1. Overall confusion matrix visualization
    plt.figure(figsize=(15, 12))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=[cls[:20] for cls in le_classes],
                yticklabels=[cls[:20] for cls in le_classes])
    plt.title('Gesture Confusion Matrix')
    plt.ylabel('True Gesture')
    plt.xlabel('Predicted Gesture')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(EXPORT_DIR / 'confusion_matrix.png', dpi=150, bbox_inches='tight')
    plt.show()

    # 2. Detailed misclassification analysis
    print("\n1. TOP MISCLASSIFICATIONS:")
    print("-" * 50)

    misclassifications = []
    for i, true_class in enumerate(le_classes):
        for j, pred_class in enumerate(le_classes):
            if i != j and cm[i, j] > 0:
                misclassifications.append({
                    'true': true_class,
                    'predicted': pred_class,
                    'count': cm[i, j],
                    'true_is_bfrb': true_class in bfrb_gestures,
                    'pred_is_bfrb': pred_class in bfrb_gestures
                })

    # Sort by count, most frequent first
    misclassifications.sort(key=lambda x: x['count'], reverse=True)

    print("Most frequent misclassifications:")
    for i, mc in enumerate(misclassifications[:15]):
        # NOTE(review): these markers differ only in width (1 vs 2 spaces) and look like
        # placeholders for symbols that were lost; they carry no visible information.
        true_marker = " " if mc['true_is_bfrb'] else "  "
        pred_marker = " " if mc['pred_is_bfrb'] else "  "
        print(f"{i+1:2d}. {true_marker} {mc['true'][:25]:25} → {pred_marker} {mc['predicted'][:25]:25} ({mc['count']} times)")

    # 3. BFRB-specific confusion analysis
    print("\n2. BFRB GESTURE CONFUSION PATTERNS:")
    print("-" * 50)

    bfrb_confusions = [mc for mc in misclassifications if mc['true_is_bfrb']]

    print("BFRB gestures confused with other BFRB gestures:")
    bfrb_to_bfrb = [mc for mc in bfrb_confusions if mc['pred_is_bfrb']]
    for mc in bfrb_to_bfrb[:10]:
        print(f"   {mc['true'][:30]:30} →  {mc['predicted'][:30]:30} ({mc['count']}x)")

    print(f"\nBFRB gestures confused with Non-BFRB gestures:")
    bfrb_to_non = [mc for mc in bfrb_confusions if not mc['pred_is_bfrb']]
    for mc in bfrb_to_non[:10]:
        print(f"   {mc['true'][:30]:30} →    {mc['predicted'][:30]:30} ({mc['count']}x)")

    # 4. Gesture-specific accuracy (per-class recall: correct / true samples)
    print("\n3. INDIVIDUAL GESTURE ACCURACY:")
    print("-" * 50)

    gesture_accuracy = {}
    for i, gesture in enumerate(le_classes):
        total_true = cm[i].sum()
        correct = cm[i, i]
        accuracy = correct / total_true if total_true > 0 else 0
        gesture_accuracy[gesture] = accuracy

    print("BFRB Gesture Accuracy:")
    bfrb_acc = {g: a for g, a in gesture_accuracy.items() if g in bfrb_gestures}
    for gesture, acc in sorted(bfrb_acc.items(), key=lambda x: x[1]):
        print(f"  {gesture[:35]:35} {acc:.3f}")

    print(f"\nNon-BFRB Gesture Accuracy (top 10):")
    non_bfrb_acc = {g: a for g, a in gesture_accuracy.items() if g not in bfrb_gestures}
    for gesture, acc in sorted(non_bfrb_acc.items(), key=lambda x: x[1], reverse=True)[:10]:
        print(f"  {gesture[:35]:35} {acc:.3f}")

    # 5. Actionable insights
    print("\n4. ACTIONABLE INSIGHTS:")
    print("-" * 50)

    worst_bfrb = sorted(bfrb_acc.items(), key=lambda x: x[1])[:3]
    print("PRIORITY: Worst performing BFRB gestures:")
    for gesture, acc in worst_bfrb:
        main_confusions = [mc for mc in misclassifications if mc['true'] == gesture][:3]
        print(f"\n  • {gesture} (accuracy: {acc:.3f})")
        print("    Most confused with:")
        for mc in main_confusions:
            conf_type = "BFRB" if mc['pred_is_bfrb'] else "Non-BFRB"
            print(f"      - {mc['predicted']} ({conf_type}, {mc['count']}x)")

    return cm, misclassifications, gesture_accuracy
def time_sum(x):
    """Reduce ``x`` by summing over axis 1 (used as a Keras ``Lambda`` body)."""
    return tf.math.reduce_sum(x, axis=1)

def squeeze_last_axis(x):
    """Drop the trailing size-1 axis of ``x`` (used as a Keras ``Lambda`` body)."""
    return tf.squeeze(x, axis=[-1])

def expand_last_axis(x):
    """Append a trailing size-1 axis to ``x`` (used as a Keras ``Lambda`` body)."""
    return tf.expand_dims(input=x, axis=-1)

def se_block(x, reduction=8):
    """Squeeze-and-Excitation channel re-weighting for a (batch, time, channels) tensor.

    The channels are average-pooled over time and passed through a bottleneck
    Dense(relu) -> Dense(sigmoid). The result is reshaped to (1, channels) and
    multiplied back onto ``x``.

    Args:
        x: 3-D Keras tensor whose last axis is the channel axis.
        reduction: bottleneck ratio. The bottleneck has ``ch // reduction``
            units, clamped to at least 1. Without the clamp, inputs narrower
            than ``reduction`` would produce an invalid zero-unit Dense layer.

    Returns:
        A tensor with the same shape as ``x``.
    """
    ch = x.shape[-1]
    bottleneck = max(1, ch // reduction)
    se = GlobalAveragePooling1D()(x)
    se = Dense(bottleneck, activation='relu')(se)
    se = Dense(ch, activation='sigmoid')(se)
    se = Reshape((1, ch))(se)  # broadcast the channel gates over the time axis
    return Multiply()([x, se])

def residual_se_cnn_block(x, filters, kernel_size, pool_size=2, drop=0.3, wd=1e-4):
    """Residual block: 2x (Conv1D -> BN -> ReLU), SE re-weighting, skip add, pool, dropout.

    The skip path is projected with a 1x1 Conv1D + BN whenever the input
    channel count differs from ``filters``. The output is max-pooled by
    ``pool_size`` along time and then dropped out with rate ``drop``.
    Convolutions carry an L2 penalty of ``wd``.
    """
    skip = x
    h = x
    for _stage in range(2):
        h = Conv1D(filters, kernel_size, padding='same', use_bias=False, kernel_regularizer=l2(wd))(h)
        h = BatchNormalization()(h)
        h = Activation('relu')(h)
    h = se_block(h)

    needs_projection = skip.shape[-1] != filters
    if needs_projection:
        skip = Conv1D(filters, 1, padding='same', use_bias=False, kernel_regularizer=l2(wd))(skip)
        skip = BatchNormalization()(skip)

    out = add([h, skip])
    out = Activation('relu')(out)
    out = MaxPooling1D(pool_size)(out)
    return Dropout(drop)(out)

def attention_layer(inputs):
    """Additive attention pooling over axis 1 (time).

    Each time step gets a scalar tanh score. The scores are softmax-normalised
    across time, and the time-weighted sum of ``inputs`` is returned.
    """
    scores = Dense(1, activation='tanh')(inputs)
    scores = Lambda(squeeze_last_axis)(scores)
    attn = Activation('softmax')(scores)
    attn = Lambda(expand_last_axis)(attn)
    weighted = Multiply()([inputs, attn])
    pooled = Lambda(time_sum)(weighted)
    return pooled

class GatedMixupGenerator(tf.keras.utils.Sequence):
    """Keras batch generator adding class weights, ToF masking and mixup.

    Each item is ``(X, {'main_output': y, 'tof_gate': gate}, sample_weights)``.
    ``gate`` is 1.0 for samples whose ToF/thermal columns (index ``imu_dim``
    onwards) are kept and 0.0 for samples where they were zeroed. This lets the
    model's ``tof_gate`` head learn to detect missing ToF/thermal data.

    Args:
        X: (n_samples, time, features) array.
        y: (n_samples, n_classes) one-hot labels.
        batch_size: samples per batch (the last batch may be smaller).
        imu_dim: number of leading IMU feature columns, which are never masked.
        class_weight: optional {class_index: weight} dict, turned into
            per-sample weights.
        alpha: Beta(alpha, alpha) mixup parameter; 0 disables mixup.
            NOTE: the default (0.2) enables mixup, so pass ``alpha=0.0`` for
            validation data.
        masking_prob: per-sample probability of zeroing the ToF/thermal part.
    """

    def __init__(self, X, y, batch_size, imu_dim, class_weight=None, alpha=0.2, masking_prob=0.0):
        # Keras 3's PyDataset warns (and falls back to defaults) when the base
        # constructor is skipped; Keras 2's Sequence.__init__ takes no arguments.
        super().__init__()
        self.X, self.y = X, y
        self.batch = batch_size
        self.imu_dim = imu_dim
        self.class_weight = class_weight
        self.alpha = alpha
        self.masking_prob = masking_prob
        self.indices = np.arange(len(X))
        self.on_epoch_end()

    def __len__(self):
        # ceil: the final, possibly partial, batch is included.
        return int(np.ceil(len(self.X) / self.batch))

    def __getitem__(self, i):
        idx = self.indices[i * self.batch:(i + 1) * self.batch]
        Xb, yb = self.X[idx].copy(), self.y[idx].copy()
        n = len(Xb)

        sample_weights = np.ones(n, dtype='float32')
        if self.class_weight:
            y_integers = yb.argmax(axis=1)
            sample_weights = np.array(
                [self.class_weight[c] for c in y_integers], dtype='float32'
            )

        gate_target = np.ones(n, dtype='float32')
        if self.masking_prob > 0:
            # One uniform draw per sample, in sample order.
            masked = np.random.rand(n) < self.masking_prob
            Xb[masked, :, self.imu_dim:] = 0
            gate_target[masked] = 0.0

        if self.alpha > 0:
            # Batch-level mixup: inputs, labels, gate targets and weights are
            # all blended with the same lambda and permutation.
            lam = np.random.beta(self.alpha, self.alpha)
            perm = np.random.permutation(n)
            X_mix = lam * Xb + (1 - lam) * Xb[perm]
            y_mix = lam * yb + (1 - lam) * yb[perm]
            gate_target_mix = lam * gate_target + (1 - lam) * gate_target[perm]
            sample_weights_mix = lam * sample_weights + (1 - lam) * sample_weights[perm]
            return X_mix, {'main_output': y_mix, 'tof_gate': gate_target_mix}, sample_weights_mix

        return Xb, {'main_output': yb, 'tof_gate': gate_target}, sample_weights

    def on_epoch_end(self):
        np.random.shuffle(self.indices)

def build_gated_two_branch_model(pad_len, imu_dim, tof_dim, n_classes, wd=1e-4):
    """Build the two-branch (IMU / ToF+thermal) Keras classifier with a learned ToF gate.

    The single input has shape (pad_len, imu_dim + tof_dim). The first
    ``imu_dim`` features feed the IMU branch and the rest feed the ToF/thermal
    branch. A sigmoid gate computed from the time-averaged ToF/thermal features
    scales the ToF branch. The gate is also exposed as the ``tof_gate`` output
    so that it can be trained to predict whether that data is present (see
    GatedMixupGenerator's masking).

    Args:
        pad_len: padded sequence length (time steps).
        imu_dim: number of leading IMU feature columns.
        tof_dim: number of trailing ToF/thermal feature columns.
        n_classes: number of gesture classes.
        wd: L2 weight-decay coefficient for the regularized layers.

    Returns:
        A ``Model`` with outputs ``[main_output, tof_gate]``: a softmax over
        ``n_classes`` and a per-sample sigmoid scalar.
    """
    inp = Input(shape=(pad_len, imu_dim+tof_dim))
    imu = Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = Lambda(lambda t: t[:, :, imu_dim:])(inp)
    
    # IMU branch: two residual SE-CNN blocks, each max-pooling time by 2.
    x1 = residual_se_cnn_block(imu, 64, 3, drop=0.1, wd=wd)
    x1 = residual_se_cnn_block(x1, 128, 5, drop=0.1, wd=wd)
    
    # ToF/thermal branch: two plain Conv-BN-ReLU-Pool(2)-Dropout stages, so its
    # time length matches the IMU branch (both are downsampled by 4).
    x2_base = Conv1D(64, 3, padding='same', use_bias=False, kernel_regularizer=l2(wd))(tof)
    x2_base = BatchNormalization()(x2_base)
    x2_base = Activation('relu')(x2_base)
    x2_base = MaxPooling1D(2)(x2_base)
    x2_base = Dropout(0.2)(x2_base)
    
    x2_base = Conv1D(128, 3, padding='same', use_bias=False, kernel_regularizer=l2(wd))(x2_base)
    x2_base = BatchNormalization()(x2_base)
    x2_base = Activation('relu')(x2_base)
    x2_base = MaxPooling1D(2)(x2_base)
    x2_base = Dropout(0.2)(x2_base)
    
    # Gate computed from the raw (un-convolved) ToF/thermal input.
    gate_input = GlobalAveragePooling1D()(tof)
    gate_input = Dense(16, activation='relu')(gate_input)
    gate = Dense(1, activation='sigmoid', name='tof_gate')(gate_input)
    # NOTE(review): relies on Keras Multiply broadcasting the (batch, 1) gate
    # over the (batch, time, 128) feature map.
    x2 = Multiply()([x2_base, gate])
    
    # Sequence heads over the fused features: BiLSTM, BiGRU and a noisy dense
    # projection, concatenated along the feature axis.
    merged = Concatenate()([x1, x2])
    xa = Bidirectional(LSTM(128, return_sequences=True, kernel_regularizer=l2(wd)))(merged)
    xb = Bidirectional(GRU(128, return_sequences=True, kernel_regularizer=l2(wd)))(merged)
    xc = GaussianNoise(0.09)(merged)
    xc = Dense(16, activation='elu')(xc)
    x = Concatenate()([xa, xb, xc])
    x = Dropout(0.4)(x)
    # Attention pooling collapses the time axis.
    x = attention_layer(x)
    
    # Classifier MLP.
    for units, drop in [(256, 0.5), (128, 0.3)]:
        x = Dense(units, use_bias=False, kernel_regularizer=l2(wd))(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = Dropout(drop)(x)
    
    out = Dense(n_classes, activation='softmax', name='main_output', kernel_regularizer=l2(wd))(x)
    return Model(inputs=inp, outputs=[out, gate])
 
# Training mode: engineer features, train N_SPLITS subject-grouped folds of the
# gated two-branch model and export every artifact to EXPORT_DIR.
# Inference mode (else branch): reload those artifacts from PRETRAINED_DIR.
# Both branches leave `final_feature_cols`, `pad_len`, `scaler`,
# `gesture_classes` and `tf_models` defined for predict_pf_allsensor.
if TRAIN:
    print("----------TRAINING MODE---------")
    
    train = pd.read_csv(BASE_DIR / "train.csv")
    train_dem = pd.read_csv(BASE_DIR / "train_demographics.csv")
    df = pd.merge(train, train_dem, on='subject', how='left')
    
    le = LabelEncoder()
    df['gesture_int'] = le.fit_transform(df['gesture'])
    np.save(EXPORT_DIR / "gesture_classes.npy", le.classes_)
    # Same global name the inference branch sets, so predict_pf_allsensor
    # also works directly after training.
    gesture_classes = le.classes_
    print("Data loaded | Unique gestures:", len(le.classes_))
    
    bfrb_gestures = [
        'Above ear - pull hair',
        'Forehead - pull hairline', 
        'Forehead - scratch',
        'Eyebrow - pull hair',
        'Eyelash - pull hair',
        'Neck - pinch skin',
        'Neck - scratch',
        'Cheek - pinch skin'
    ]
    
    print("Calculating physics-based features with sequence grouping...")
    
    linear_accel_list = []
    for _, group in df.groupby('sequence_id'):
        linear_accel = remove_gravity_from_acc(
            group[['acc_x', 'acc_y', 'acc_z']], 
            group[['rot_x', 'rot_y', 'rot_z', 'rot_w']]
        )
        linear_accel_df = pd.DataFrame(
            linear_accel,
            columns=['linear_acc_x', 'linear_acc_y', 'linear_acc_z'],
            index=group.index
        )
        linear_accel_list.append(linear_accel_df)
    
    df = pd.concat([df, pd.concat(linear_accel_list)], axis=1)
    df['linear_acc_mag'] = np.sqrt(df['linear_acc_x']**2 + df['linear_acc_y']**2 + df['linear_acc_z']**2)
    df['linear_acc_mag_jerk'] = df.groupby('sequence_id')['linear_acc_mag'].diff().fillna(0)
    
    angular_vel_list = []
    for _, group in df.groupby('sequence_id'):
        angular_vel = calculate_angular_velocity_from_quat(
            group[['rot_x', 'rot_y', 'rot_z', 'rot_w']]
        )
        angular_vel_df = pd.DataFrame(
            angular_vel,
            columns=['angular_vel_x', 'angular_vel_y', 'angular_vel_z'],
            index=group.index
        )
        angular_vel_list.append(angular_vel_df)
    
    df = pd.concat([df, pd.concat(angular_vel_list)], axis=1)
    df['angular_vel_mag'] = np.sqrt(df['angular_vel_x']**2 + df['angular_vel_y']**2 + df['angular_vel_z']**2)
    df['angular_vel_mag_jerk'] = df.groupby('sequence_id')['angular_vel_mag'].diff().fillna(0)
    
    angular_dist_list = []
    for _, group in df.groupby('sequence_id'):
        angular_dist = calculate_angular_distance(
            group[['rot_x', 'rot_y', 'rot_z', 'rot_w']]
        )
        angular_dist_df = pd.DataFrame(
            angular_dist,
            columns=['angular_distance'],
            index=group.index
        )
        angular_dist_list.append(angular_dist_df)
    
    df = pd.concat([df, pd.concat(angular_dist_list)], axis=1)
    df['gesture_rhythm_signature'] = df.groupby('sequence_id')['linear_acc_mag'].transform(
        lambda x: x.rolling(5, min_periods=1).std() / (x.rolling(5, min_periods=1).mean() + 1e-6)
    )
    imu_cols_base = ['linear_acc_x', 'linear_acc_y', 'linear_acc_z'] + [c for c in df.columns if c.startswith('rot_')]
    imu_engineered = ['linear_acc_mag', 'linear_acc_mag_jerk', 'angular_vel_x', 'angular_vel_y', 'angular_vel_z', 'angular_distance','angular_vel_mag','angular_vel_mag_jerk','gesture_rhythm_signature']
    imu_cols = list(dict.fromkeys(imu_cols_base + imu_engineered))
    
    thm_cols = [c for c in df.columns if c.startswith('thm_')]
    tof_agg_cols = []
    for i in range(1, 6):
        tof_agg_cols.extend([f'tof_{i}_mean', f'tof_{i}_std', f'tof_{i}_min', f'tof_{i}_max'])
    
    final_feature_cols = imu_cols + thm_cols + tof_agg_cols
    imu_dim = len(imu_cols)
    tof_thm_dim = len(thm_cols) + len(tof_agg_cols)
    
    print(f"Feature dimensions: IMU={imu_dim} | TOF/THM={tof_thm_dim} | Total={len(final_feature_cols)}")
    np.save(EXPORT_DIR / "feature_cols.npy", np.array(final_feature_cols))
    
    print("Building sequences...")
    seq_gp = df.groupby('sequence_id')
    X_list_unscaled, y_list, groups_list, lens = [], [], [], []
    
    for seq_id, seq_df in seq_gp:
        seq_df_copy = seq_df.copy()
        for i in range(1, 6):
            pixel_cols = [f"tof_{i}_v{p}" for p in range(64)]
            tof_sensor_data = seq_df_copy[pixel_cols].replace(-1, np.nan)
            seq_df_copy[f'tof_{i}_mean'] = tof_sensor_data.mean(axis=1)
            seq_df_copy[f'tof_{i}_std'] = tof_sensor_data.std(axis=1)
            seq_df_copy[f'tof_{i}_min'] = tof_sensor_data.min(axis=1)
            seq_df_copy[f'tof_{i}_max'] = tof_sensor_data.max(axis=1)
        
        mat_unscaled = seq_df_copy[final_feature_cols].ffill().bfill().fillna(0).values.astype('float32')
        X_list_unscaled.append(mat_unscaled)
        y_list.append(seq_df_copy['gesture_int'].iloc[0])
        groups_list.append(seq_df_copy['subject'].iloc[0])
        lens.append(len(mat_unscaled))
    
    print("Fitting IMU-Specific StandardScalers...")
    all_steps_concatenated = np.concatenate(X_list_unscaled, axis=0)
    scaler = IMUSpecificScaler().fit(all_steps_concatenated, imu_dim)
    joblib.dump(scaler, EXPORT_DIR / "imu_specific_scaler.pkl")
    
    print("Scaling and padding sequences with IMU-specific normalization...")
    X_scaled_list = [scaler.transform(x_seq) for x_seq in X_list_unscaled]
    del X_list_unscaled
    
    pad_len = int(np.percentile(lens, PAD_PERCENTILE))
    np.save(EXPORT_DIR / "sequence_maxlen.npy", pad_len)
    
    X = pad_sequences(X_scaled_list, maxlen=pad_len, padding='post', truncating='post', dtype='float32')
    del X_scaled_list
    
    y_stratify = np.array(y_list)
    y = to_categorical(y_list, num_classes=len(le.classes_))
    groups = np.array(groups_list)
    
    print(f"Starting realistic {N_SPLITS}-fold training with dual evaluation...")
    sgkf = StratifiedGroupKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
    fold_realistic_scores = []
    fold_gesture_scores = []

    for fold, (train_idx, val_idx) in enumerate(sgkf.split(X, y_stratify, groups)):
        print(f"\n{'='*50}")
        print(f"FOLD {fold+1}/{N_SPLITS} - Training in progress", end='')
        
        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]
        y_val_gestures = le.classes_[y_stratify[val_idx]]

        model = build_gated_two_branch_model(pad_len, imu_dim, tof_thm_dim, len(le.classes_), wd=WD)
        # Learning rate scheduler setup
        if USE_LR_SCHEDULER:
            # Must equal len(train_gen): the generator also yields the final
            # partial batch, so use ceil (floor would end the decay early).
            steps_per_epoch = int(np.ceil(len(X_train) / BATCH_SIZE))
            total_steps = steps_per_epoch * EPOCHS
            
            if LR_SCHEDULE_TYPE == "cosine_decay":
                lr_schedule = CosineDecay(
                    initial_learning_rate=LR_INIT,
                    decay_steps=total_steps,
                    alpha=0.01  # End at 1% of initial LR
                )
            elif LR_SCHEDULE_TYPE == "cosine_decay_restarts":
                lr_schedule = CosineDecayRestarts(
                    initial_learning_rate=LR_INIT,
                    first_decay_steps=steps_per_epoch * 20,  # Restart every 20 epochs
                    t_mul=1.2,  # Increase restart period by 20% each time
                    m_mul=0.8,  # Reduce max LR by 20% each restart
                    alpha=0.01  # Minimum LR as fraction of initial
                )
            elif LR_SCHEDULE_TYPE == "exponential_decay":
                lr_schedule = ExponentialDecay(
                    initial_learning_rate=LR_INIT,
                    decay_steps=steps_per_epoch * 10,  # Decay every 10 epochs
                    decay_rate=0.9,
                    staircase=False
                )
            else:
                lr_schedule = LR_INIT  # No scheduling
            
            optimizer = Adam(learning_rate=lr_schedule)
        else:
            optimizer = Adam(LR_INIT)
        
        
        model.compile(
            optimizer = optimizer,
            loss={
                'main_output': tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),  
                'tof_gate': 'binary_crossentropy'
            },
            loss_weights={'main_output': 1.0, 'tof_gate': GATE_LOSS_WEIGHT},
            metrics={'main_output': 'accuracy'}
        )
        
        class_weights = compute_class_weight(
            'balanced', 
            classes=np.arange(len(le.classes_)), 
            y=y_train.argmax(1)
        )
        class_weight_dict = dict(enumerate(class_weights))
        
        train_gen = GatedMixupGenerator(
            X_train, y_train, BATCH_SIZE, imu_dim,
            class_weight=class_weight_dict, alpha=MIXUP_ALPHA, masking_prob=MASKING_PROB
        )
        # Validation must see the real samples: the generator's default
        # alpha (0.2) would mixup-blend them.
        val_gen = GatedMixupGenerator(X_val, y_val, BATCH_SIZE, imu_dim, alpha=0.0)
        enhanced_callback = EnhancedCMIMetricCallback(
            X_val, y_val_gestures, le.classes_, bfrb_gestures, imu_dim,
            patience=PATIENCE, verbose=1
        )
        
        model.fit(
            train_gen, validation_data=val_gen, epochs=EPOCHS,
            callbacks=[enhanced_callback], verbose=0
        )
        
        final_scores = evaluate_dual_cmi_metric(
            model, X_val, y_val_gestures, le.classes_, bfrb_gestures, imu_dim
        )
        
        gesture_scores = get_individual_gesture_scores(
            model, X_val, y_val_gestures, le.classes_, bfrb_gestures
        )
        
        print(f"\nFOLD {fold+1} COMPLETED ✓")
        print(f"Composite: {final_scores['composite_score']:.4f} | "
              f"Binary: {final_scores['binary_f1']:.4f} | "
              f"Macro: {final_scores['macro_f1']:.4f}")
        print(f"IMU-only: {final_scores['imu_only_composite']:.4f} | "
              f"Full: {final_scores['full_sensor_composite']:.4f} | "
              f"Gap: {final_scores['sensor_dependency']:.4f}")
        
        print("\nIndividual Gesture F1 Scores:")
        for gesture, score in sorted(gesture_scores.items(), key=lambda x: x[1], reverse=True):
            print(f"  {gesture[:30]:30} {score:.3f}")
        
        fold_realistic_scores.append(final_scores)
        fold_gesture_scores.append(gesture_scores)
        
        model.save(EXPORT_DIR / f"gesture_model_fold_{fold}.h5")
        print(f"Model saved: fold_{fold}.h5")

    print("\n----Training Complete----")
    print("\nAverage Results Across All Folds:")
    avg_scores = {
        'composite_score': np.mean([s['composite_score'] for s in fold_realistic_scores]),
        'macro_f1': np.mean([s['macro_f1'] for s in fold_realistic_scores]),
        'binary_f1': np.mean([s['binary_f1'] for s in fold_realistic_scores]),
        'imu_only_composite': np.mean([s['imu_only_composite'] for s in fold_realistic_scores]),
        'full_sensor_composite': np.mean([s['full_sensor_composite'] for s in fold_realistic_scores]),
        'sensor_dependency': np.mean([s['sensor_dependency'] for s in fold_realistic_scores])
    }
    print(f"Actual Composite: {avg_scores['composite_score']:.4f}")
    print(f"Macro: {avg_scores['macro_f1']:.4f}")
    print(f"Binary: {avg_scores['binary_f1']:.4f}")
    print(f"Imu: {avg_scores['imu_only_composite']:.4f}")
    print(f"Full: {avg_scores['full_sensor_composite']:.4f}")
    print(f"Sensor Gap: {avg_scores['sensor_dependency']:.4f}")
    print("\nGenerating detailed confusion analysis...")
    # NOTE(review): the union of all validation folds is the full dataset, and
    # every sample was in the *training* split of N_SPLITS-1 of the ensembled
    # fold models. The resulting confusion analysis is therefore optimistic
    # (not out-of-fold); treat it as a diagnostic only.
    X_val_all = []
    y_val_all = []
    for _, (_, val_idx) in enumerate(sgkf.split(X, y_stratify, groups)):
        X_val_all.append(X[val_idx])
        y_val_all.append(y[val_idx])
    X_val_combined = np.concatenate(X_val_all, axis=0)
    y_val_combined = np.concatenate(y_val_all, axis=0)
    
    # Keep the Keras fold models under `tf_models`, as the inference branch
    # does, so the PyTorch `models` global is not overwritten.
    tf_models = []
    for fold in range(N_SPLITS):
        model = load_model(EXPORT_DIR / f"gesture_model_fold_{fold}.h5", 
                           custom_objects={
                               'time_sum': time_sum,
                               'squeeze_last_axis': squeeze_last_axis,
                               'expand_last_axis': expand_last_axis
                           })
        tf_models.append(model)
    cm, misclassifications, gesture_accuracy = create_detailed_confusion_analysis(
        tf_models, X_val_combined, y_val_combined, le.classes_, bfrb_gestures
    )

else:
    print("▶ INFERENCE MODE – loading artifacts from", PRETRAINED_DIR)
    
    # Load all saved artifacts - UPDATED FOR NEW SCALER
    final_feature_cols = np.load(PRETRAINED_DIR / "feature_cols.npy", allow_pickle=True).tolist()
    pad_len = int(np.load(PRETRAINED_DIR / "sequence_maxlen.npy"))
    scaler = joblib.load(PRETRAINED_DIR / "imu_specific_scaler.pkl")  # Updated scaler
    gesture_classes = np.load(PRETRAINED_DIR / "gesture_classes.npy", allow_pickle=True)
    
    print(f"  Loaded feature columns: {len(final_feature_cols)}")
    print(f"  Sequence padding length: {pad_len}")
    print(f"  Gesture classes: {len(gesture_classes)}")
    
    # Define custom objects for model loading
    custom_objects = {
        'time_sum': time_sum,
        'squeeze_last_axis': squeeze_last_axis, 
        'expand_last_axis': expand_last_axis,
        'se_block': se_block,
        'residual_se_cnn_block': residual_se_cnn_block,
        'attention_layer': attention_layer,
    }
    
    # Load the ensemble of TF (Keras) fold models — kept under `tf_models`,
    # away from the PyTorch `models` global.
    tf_models = []
    print(f"  Loading {N_SPLITS} TF models for ensemble inference...")
    for fold in range(N_SPLITS):
        model_path = PRETRAINED_DIR / f"gesture_model_fold_{fold}.h5"
        if model_path.exists():
            m = load_model(model_path, compile=False, custom_objects=custom_objects)
            tf_models.append(m)  # only ever added to tf_models
            print(f"    ✓ Loaded TF fold {fold} model")
        else:
            print(f"    ✗ TF model fold {fold} not found at {model_path}")

    print(f"  Successfully loaded {len(tf_models)} TF models")  # count of tf_models


def _get_main_logits(model, x, n_classes):
    """Run ``model.predict`` on ``x`` and return the main head's output as a 2-D array.

    Works for single- and multi-output models. When ``predict`` returns a
    list/tuple, the first entry (the ``main_output`` head) is used. A 1-D
    result is promoted to shape ``(1, n_classes)``. Despite the name, the
    values are whatever the head emits; for models from
    build_gated_two_branch_model these are softmax probabilities.

    Raises:
        ValueError: if the last dimension is not ``n_classes``. Callers are
            expected to catch this.
    """
    raw = model.predict(x, verbose=0)
    main = raw[0] if isinstance(raw, (list, tuple)) else raw
    arr = np.asarray(main)

    if arr.ndim == 1:
        # A single-sample result came back without its batch axis.
        arr = arr[np.newaxis, :]

    n_out = arr.shape[-1]
    if n_out != n_classes:
        raise ValueError(f"model logits dim {n_out} != expected {n_classes}")
    return arr

def predict_pf_allsensor(sequence: pl.DataFrame, demographics: pl.DataFrame) -> np.ndarray:
    """Ensemble gesture probabilities for one sequence from the loaded Keras fold models.

    The function:
      1. Rebuilds the training-time features.
      2. Scales them with the global ``scaler``.
      3. Pads/truncates to ``pad_len``.
      4. Averages the main-head outputs of every model in the global
         ``tf_models``.

    ``demographics`` is accepted for interface compatibility but is unused.

    Never raises. Failures are reported with a short printed message. A model
    whose prediction fails contributes a uniform row; any other failure (or an
    empty ``tf_models``) yields a uniform result. Returns a float32 array of
    shape (1, C).
    """
    try:
        frame = sequence.to_pandas()
        quat_cols = ['rot_x', 'rot_y', 'rot_z', 'rot_w']
        axes = ('x', 'y', 'z')

        # Gravity-free acceleration, its magnitude and the magnitude's first difference.
        lin_acc = remove_gravity_from_acc(frame[['acc_x', 'acc_y', 'acc_z']], frame[quat_cols])
        for col_pos, axis in enumerate(axes):
            frame[f'linear_acc_{axis}'] = lin_acc[:, col_pos]
        frame['linear_acc_mag'] = np.sqrt(
            frame['linear_acc_x']**2 + frame['linear_acc_y']**2 + frame['linear_acc_z']**2
        )
        frame['linear_acc_mag_jerk'] = frame['linear_acc_mag'].diff().fillna(0)

        # Angular velocity from consecutive quaternions, its magnitude and first difference.
        ang_vel = calculate_angular_velocity_from_quat(frame[quat_cols])
        for col_pos, axis in enumerate(axes):
            frame[f'angular_vel_{axis}'] = ang_vel[:, col_pos]
        frame['angular_vel_mag'] = np.sqrt(
            frame['angular_vel_x']**2 + frame['angular_vel_y']**2 + frame['angular_vel_z']**2
        )
        frame['angular_vel_mag_jerk'] = frame['angular_vel_mag'].diff().fillna(0)

        frame['angular_distance'] = calculate_angular_distance(frame[quat_cols])

        # Rolling (window 5) std / mean of the linear-acceleration magnitude.
        window = frame['linear_acc_mag'].rolling(5, min_periods=1)
        frame['gesture_rhythm_signature'] = window.std() / (window.mean() + 1e-6)

        # Per-sensor ToF summary statistics; -1 pixels are treated as missing.
        for sensor in range(1, 6):
            pixels = frame[[f"tof_{sensor}_v{p}" for p in range(64)]].replace(-1, np.nan)
            summary = {
                'mean': pixels.mean(axis=1),
                'std': pixels.std(axis=1),
                'min': pixels.min(axis=1),
                'max': pixels.max(axis=1),
            }
            for stat_name, values in summary.items():
                frame[f'tof_{sensor}_{stat_name}'] = values

        # Scale, then pad/truncate to the training length.
        unscaled = frame[final_feature_cols].ffill().bfill().fillna(0).values.astype('float32')
        padded_input = pad_sequences(
            [scaler.transform(unscaled)], maxlen=pad_len,
            padding='post', truncating='post', dtype='float32',
        )

        # Average the TF fold models only.
        C = len(gesture_classes)
        uniform = np.ones((1, C), dtype=np.float32) / C
        if not tf_models:
            return uniform

        per_model = []
        for m in tf_models:
            try:
                per_model.append(
                    _get_main_logits(m, padded_input, n_classes=C).astype(np.float32, copy=False)
                )
            except Exception as e:
                short = f"{type(e).__name__}: {str(e)[:200]}"
                print(f"[WARN] TF model {getattr(m,'name','?')} prediction failed: {short}")
                per_model.append(uniform)

        stacked = np.stack(per_model, axis=0)
        return np.mean(stacked, axis=0).astype(np.float32, copy=False)

    except Exception as e:
        short = f"{type(e).__name__}: {str(e)[:800]}"
        print(f"[ERROR] predict_pf_allsensor failed: {short}")
        C = len(gesture_classes) if 'gesture_classes' in globals() else 18
        return np.ones((1, C), dtype=np.float32) / C



# IMU-only model

In [None]:
# ==================== TensorFlow SE-1DCNN 模型集成模块 ====================
#
# 此模块提供 predict_tensorflow_se1dcnn 函数用于模型集成
# 使用类封装避免命名空间冲突
# 核心逻辑: SE-1DCNN + Attention模型, 带有四元数安全处理
#
# ------------------------------------------------------------------
import os
import joblib
import json
import numpy as np
import pandas as pd
import polars as pl
from pathlib import Path
from scipy.spatial.transform import Rotation as R
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
from tensorflow.keras.layers import Layer
import warnings
warnings.filterwarnings("ignore")

# 设置随机种子
def seed_everything(seed=42):
    """Seed Python's ``random``, NumPy's global RNG and TensorFlow for reproducibility.

    Setting PYTHONHASHSEED here only affects child processes started later;
    the current interpreter's hash seed is fixed at start-up.
    """
    import random

    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)  # previously left unseeded
    np.random.seed(seed)
    tf.random.set_seed(seed)

seed_everything(42)

class TensorFlowSE1DCNNModel:
    """封装TensorFlow SE-1DCNN模型,避免命名空间冲突"""
    
    def __init__(self):
        # Artifact locations — adjust to the actual dataset mount points.
        self.PREPROCESS_DIR = Path("/kaggle/input/imuonly-process-model/imuonly_porcess/kaggle/working/processed_data_selected_features_v1")
        self.MODEL_DIR = Path("/kaggle/input/imuonly-process-model/imuonly1_model/kaggle/working/saved_models_keras_fixed_test")
        self.N_FOLDS = 5
        self.MAX_SEQ_LENGTH = 128

        print(f"TensorFlow SE-1DCNN模块: 正在加载配置...")

        # Full feature list written by the preprocessing step.
        with open(self.PREPROCESS_DIR / "feature_names.json", 'r') as fh:
            self.ALL_FEATURE_NAMES = json.load(fh)['all_features']

        # label -> model output index, plus the inverse mapping.
        with open(self.PREPROCESS_DIR / "label_map.json", 'r') as fh:
            self.LABEL2IDX = json.load(fh)
        self.IDX2LABEL = {idx: name for name, idx in self.LABEL2IDX.items()}
        self.N_CLASSES = len(self.LABEL2IDX)

        # Canonical label order shared by every model in the ensemble.
        self.UNIFIED_LABELS = [
            'Above ear - pull hair', 'Cheek - pinch skin', 'Drink from bottle/cup',
            'Eyebrow - pull hair', 'Eyelash - pull hair',
            'Feel around in tray and pull out an object', 'Forehead - pull hairline',
            'Forehead - scratch', 'Glasses on/off', 'Neck - pinch skin',
            'Neck - scratch', 'Pinch knee/leg skin', 'Pull air toward your face',
            'Scratch knee/leg skin', 'Text on phone', 'Wave hello',
            'Write name in air', 'Write name on leg'
        ]

        # model_to_unified_indices[model_idx] -> position in UNIFIED_LABELS.
        # Labels missing from the unified list keep the placeholder 0 and a
        # warning is printed.
        unified_position = {name: pos for pos, name in enumerate(self.UNIFIED_LABELS)}
        self.model_to_unified_indices = [0] * self.N_CLASSES
        for label, model_index in self.LABEL2IDX.items():
            if label in unified_position:
                self.model_to_unified_indices[model_index] = unified_position[label]
            else:
                print(f"警告: 标签 '{label}' 不在统一列表中。")

        # Feature subset in training order: quaternion (w, x, y, z) first,
        # then linear acceleration, angular velocity/distance and |acc|.
        self.SELECTED_FEATURES = [
            'rot_w', 'rot_x', 'rot_y', 'rot_z',
            'linear_acc_x', 'linear_acc_y', 'linear_acc_z',
            'linear_acc_mag',
            'angular_vel_x', 'angular_vel_y', 'angular_vel_z',
            'angular_distance',
            'acc_mag'
        ]
        self.FEATURE_INDICES = [self.ALL_FEATURE_NAMES.index(name) for name in self.SELECTED_FEATURES]

        # Populated lazily by load_models().
        self.models = []
        self.scalers = []
        self._loaded = False
        
    @staticmethod
    def remove_gravity_from_acc(acc_values, quat_values):
        """去除重力影响"""
        linear_accel = np.zeros_like(acc_values)
        gravity_world = np.array([0, 0, 9.81])
        for i in range(len(acc_values)):
            if np.all(np.isnan(quat_values[i])) or np.all(np.isclose(quat_values[i], 0)):
                linear_accel[i, :] = acc_values[i, :]
                continue
            try:
                rotation = R.from_quat(quat_values[i])
                gravity_sensor_frame = rotation.apply(gravity_world, inverse=True)
                linear_accel[i, :] = acc_values[i, :] - gravity_sensor_frame
            except ValueError:
                linear_accel[i, :] = acc_values[i, :]
        return linear_accel

    @staticmethod
    def calculate_angular_velocity_from_quat(quat_values, time_delta=1/200):
        """计算角速度"""
        angular_vel = np.zeros((len(quat_values), 3))
        for i in range(len(quat_values) - 1):
            q_t, q_t_plus_dt = quat_values[i], quat_values[i+1]
            if np.all(np.isnan(q_t)) or np.all(np.isclose(q_t, 0)) or \
               np.all(np.isnan(q_t_plus_dt)) or np.all(np.isclose(q_t_plus_dt, 0)):
                continue
            try:
                rot_t = R.from_quat(q_t)
                rot_t_plus_dt = R.from_quat(q_t_plus_dt)
                delta_rot = rot_t.inv() * rot_t_plus_dt
                angular_vel[i, :] = delta_rot.as_rotvec() / time_delta
            except ValueError:
                pass
        return angular_vel
    
    @staticmethod
    def compute_angular_distance_xyzw(quat_values):
        """计算相邻帧的角距离"""
        n = len(quat_values)
        ang = np.zeros(n, dtype=np.float32)
        if n <= 1:
            return ang
        # 归一化
        norm = np.linalg.norm(quat_values, axis=1, keepdims=True)
        mask = norm[:,0] > 1e-8
        quat_values[mask] = quat_values[mask] / norm[mask]
        # 计算角距离
        dot = np.sum(quat_values[:-1] * quat_values[1:], axis=1)
        dot = np.clip(np.abs(dot), -1.0, 1.0)
        ang[1:] = 2.0 * np.arccos(dot).astype(np.float32)
        return ang

    def fix_zero_quaternions(self, features):
        """修复零四元数"""
        quat_features = features[:, :4]
        quat_norm = np.linalg.norm(quat_features, axis=1)
        zero_mask = quat_norm < 1e-8
        
        if np.any(zero_mask):
            # 将零四元数替换为单位四元数 [1,0,0,0] (w,x,y,z格式)
            features[zero_mask, :4] = [1.0, 0.0, 0.0, 0.0]
        
        return features

    def feature_engineering(self, df):
        """特征工程"""
        # 加速度模长
        df['acc_mag'] = np.sqrt(df['acc_x']**2 + df['acc_y']**2 + df['acc_z']**2)
        
        # 线性加速度（去除重力）
        acc_values = df[['acc_x', 'acc_y', 'acc_z']].values
        quat_values = df[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
        
        linear_accel = self.remove_gravity_from_acc(acc_values, quat_values)
        df['linear_acc_x'] = linear_accel[:, 0]
        df['linear_acc_y'] = linear_accel[:, 1]
        df['linear_acc_z'] = linear_accel[:, 2]
        df['linear_acc_mag'] = np.sqrt(df['linear_acc_x']**2 + df['linear_acc_y']**2 + df['linear_acc_z']**2)
        
        # 角速度
        angular_vel = self.calculate_angular_velocity_from_quat(quat_values)
        df['angular_vel_x'] = angular_vel[:, 0]
        df['angular_vel_y'] = angular_vel[:, 1]
        df['angular_vel_z'] = angular_vel[:, 2]
        
        # 角距离
        df['angular_distance'] = self.compute_angular_distance_xyzw(quat_values)

        # 填充缺失值
        for feat in self.SELECTED_FEATURES:
            if feat in df.columns:
                df[feat] = df[feat].ffill().bfill().fillna(0.0).astype('float32')
        
        return df

    def load_models(self):
        """Lazily load the per-fold scalers and Keras models (runs at most once).

        For each fold k in 1..N_FOLDS, loads ``fold_{k}_scaler.joblib`` and
        ``fold_{k}_model.keras`` from MODEL_DIR (with ``compile=False``). A fold
        is appended to ``self.scalers`` / ``self.models`` only when both files
        exist and load without error, so the two lists stay index-aligned.
        Failures are printed, not raised. ``_loaded`` is set to True even if no
        fold was loaded.
        """
        if self._loaded:
            return
        
        print("TensorFlow SE-1DCNN模块: 正在加载模型和scalers...")
        
        # Allow deserializing saved Python code (e.g. Lambda layers), which
        # Keras 3 refuses by default. NOTE(review): presumably the .keras files
        # contain such layers — confirm before removing.
        tf.keras.config.enable_unsafe_deserialization()
        
        # Re-declare the custom objects the saved models were serialized with,
        # so load_model can resolve them by class name.
        class SumPooling1D(Layer):
            # Sums its input over axis 1.
            def __init__(self, **kwargs):
                super(SumPooling1D, self).__init__(**kwargs)
            def call(self, inputs):
                return tf.reduce_sum(inputs, axis=1)
            def get_config(self):
                return super(SumPooling1D, self).get_config()

        class WarmupCosine(tf.keras.optimizers.schedules.LearningRateSchedule):
            # LR schedule: linear warm-up to base_lr over warmup_steps, then
            # cosine decay to min_lr by total_steps. NOTE(review): models are
            # loaded with compile=False, so this is presumably only needed if
            # the saved config references it.
            def __init__(self, base_lr, warmup_steps, total_steps, min_lr=1e-5):
                super().__init__()
                self.base_lr = base_lr
                self.warmup_steps = warmup_steps
                self.total_steps = total_steps
                self.min_lr = min_lr
            def __call__(self, step):
                step = tf.cast(step, tf.float32)
                warm = self.base_lr * (step / tf.cast(self.warmup_steps, tf.float32))
                progress = (step - self.warmup_steps) / tf.maximum(1.0, self.total_steps - self.warmup_steps)
                cosine = self.min_lr + 0.5 * (self.base_lr - self.min_lr) * (1 + tf.cos(np.pi * tf.clip_by_value(progress, 0.0, 1.0)))
                return tf.where(step < self.warmup_steps, warm, cosine)
            def get_config(self):
                return {
                    "base_lr": self.base_lr,
                    "warmup_steps": self.warmup_steps,
                    "total_steps": self.total_steps,
                    "min_lr": self.min_lr
                }

        custom_objects = {
            'SumPooling1D': SumPooling1D,
            'WarmupCosine': WarmupCosine
        }
        
        # Load each fold's scaler and model; a fold is kept only if both succeed.
        for fold in range(1, self.N_FOLDS + 1):
            scaler_path = self.MODEL_DIR / f"fold_{fold}_scaler.joblib"
            model_path = self.MODEL_DIR / f"fold_{fold}_model.keras"
            
            if scaler_path.exists() and model_path.exists():
                try:
                    scaler = joblib.load(scaler_path)
                    model = tf.keras.models.load_model(model_path, custom_objects=custom_objects, compile=False)
                    self.scalers.append(scaler)
                    self.models.append(model)
                    print(f"  ✓ 加载 Fold {fold}")
                except Exception as e:
                    print(f"  ✗ 加载 Fold {fold} 失败: {e}")
            else:
                print(f"  ✗ 找不到 Fold {fold} 的文件")
            
        self._loaded = True
        print(f"TensorFlow SE-1DCNN模块: 成功加载 {len(self.models)}/{self.N_FOLDS} 个模型。")

    def _standardize_with_quaternion_simple(self, features, scaler):
        """四元数安全的标准化 - 简化版本"""
        quat_features = features[:, :4]
        other_features = features[:, 4:]
        
        # 四元数保持原值，不归一化
        # 只标准化非四元数特征
        if other_features.shape[1] > 0:
            other_scaled = scaler.transform(other_features)
            features_scaled = np.concatenate([quat_features, other_scaled], axis=1)
        else:
            features_scaled = quat_features
        
        return features_scaled

    def predict_proba(self, sequence: pl.DataFrame, demographics: pl.DataFrame) -> np.ndarray:
        """Return this fold ensemble's class probabilities in the unified label order.

        Args:
            sequence: sensor rows of one sequence. It is converted to pandas and
                passed through ``feature_engineering``.
            demographics: accepted for interface compatibility; not used here.

        Returns:
            Array of shape ``(1, len(self.UNIFIED_LABELS))``. Each fold model's
            output for the padded/truncated, per-fold-scaled sequence is
            averaged, passed through a softmax, and scattered into the unified
            order via ``self.model_to_unified_indices``. If no fold model is
            loaded, or any step raises, a uniform distribution is returned so a
            calling ensemble can still proceed.
        """
        self.load_models()
        n_labels = len(self.UNIFIED_LABELS)
        uniform = np.full((1, n_labels), 1.0 / n_labels)

        if not self.models:
            # Averaging an empty prediction list would only produce NaN; fail soft explicitly.
            print("TensorFlowSE1DCNNModel.predict_proba: no fold models loaded, returning uniform distribution")
            return uniform

        try:
            # Preprocess
            df = sequence.to_pandas()
            df = self.feature_engineering(df)

            # Select the model's input features
            features = df[self.SELECTED_FEATURES].values

            # Repair all-zero quaternion rows (helper defined elsewhere in this class)
            features = self.fix_zero_quaternions(features)

            # Ensemble over folds
            all_predictions = []
            for model, scaler in zip(self.models, self.scalers):
                # Scaling is fold-specific, so scale first and pad afterwards; this
                # keeps the pad values exactly [1, 0, 0, 0] / 0 in every fold.
                features_scaled = self._standardize_with_quaternion_simple(features, scaler)
                features_scaled = self._pad_or_truncate_scaled_features(features_scaled)

                pred = model.predict(features_scaled[np.newaxis, ...], verbose=0)[0]
                all_predictions.append(pred)

            # Average across folds.
            ensemble_logits = np.mean(all_predictions, axis=0)

            # Numerically stable softmax.
            # NOTE(review): this assumes the models emit logits; if they already end
            # in a softmax, this second softmax flattens the distribution — confirm.
            exp_logits = np.exp(ensemble_logits - np.max(ensemble_logits))
            raw_probabilities = exp_logits / np.sum(exp_logits)

            # Reorder into the unified label order shared by all ensemble members.
            unified_probabilities = np.zeros((1, n_labels))
            for model_index, unified_index in enumerate(self.model_to_unified_indices):
                unified_probabilities[0, unified_index] = raw_probabilities[model_index]

            return unified_probabilities

        except Exception as e:
            # Deliberate best-effort: report, then fall back to a uniform distribution.
            print(f"TensorFlowSE1DCNNModel.predict_proba 出错: {str(e)}")
            import traceback
            traceback.print_exc()
            return uniform

    def _pad_or_truncate_scaled_features(self, features_scaled):
        """Fit a (T, F) scaled feature matrix to exactly ``self.MAX_SEQ_LENGTH`` rows.

        Longer sequences keep their first ``MAX_SEQ_LENGTH`` rows. A sequence that
        is already exactly that long is returned unchanged. Shorter ones are
        padded at the end: the first four columns with [1, 0, 0, 0] and every
        other column with 0.
        NOTE(review): [1, 0, 0, 0] is presumably the identity quaternion in
        (w, x, y, z) order, which depends on SELECTED_FEATURES — confirm.
        """
        seq_len = len(features_scaled)
        if seq_len >= self.MAX_SEQ_LENGTH:
            # Also covers seq_len == MAX_SEQ_LENGTH. Building a zero-length pad via
            # np.array([[...]] * 0) yields a 1-D array that cannot be
            # concatenated along axis 1.
            return features_scaled[:self.MAX_SEQ_LENGTH]

        pad_length = self.MAX_SEQ_LENGTH - seq_len
        quat_pad = np.tile([1.0, 0.0, 0.0, 0.0], (pad_length, 1))
        other_pad = np.zeros((pad_length, features_scaled.shape[1] - 4))
        pad_values = np.concatenate([quat_pad, other_pad], axis=1)
        return np.vstack([features_scaled, pad_values])

# ==================== Module-level model instance ====================
# Shared instance used by predict_tensorflow_se1dcnn. Its predict_proba calls
# load_models() every time, which only does work on the first call.
_tensorflow_se1dcnn_model = TensorFlowSE1DCNNModel()

# ==================== Exported prediction function ====================
def predict_tensorflow_se1dcnn(sequence: pl.DataFrame, demographics: pl.DataFrame) -> np.ndarray:
    """Ensemble entry point for the TensorFlow SE-1DCNN (IMU-only) model.

    Args:
        sequence: sensor rows of one sequence.
        demographics: forwarded unchanged; ``predict_proba`` does not use it.

    Returns:
        Probability array of shape ``(1, len(UNIFIED_LABELS))`` in the unified
        label order. It is a uniform distribution if prediction fails.
    """
    se1dcnn = _tensorflow_se1dcnn_model
    return se1dcnn.predict_proba(sequence, demographics)


# ==================== Local smoke test ====================
# Runs the SE-1DCNN module on the first few test sequences and prints basic
# sanity checks (shape, probability sum, top-3 labels).
if __name__ == "__main__":
    print("\n" + "="*60)
    print("         TensorFlow SE-1DCNN模型本地测试")
    print("="*60 + "\n")
    
    TEST_CSV = '/kaggle/input/cmi-detect-behavior-with-sensor-data/test.csv'
    TEST_DEMOGRAPHICS = '/kaggle/input/cmi-detect-behavior-with-sensor-data/test_demographics.csv'
    
    try:
        print("正在加载测试数据...")
        test_sequences = pl.read_csv(TEST_CSV)
        test_demographics = pl.read_csv(TEST_DEMOGRAPHICS)
        
        # Without maintain_order=True, unique() returns ids in arbitrary order, so
        # the "first N" sequences tested below would change between runs.
        sequence_ids = (
            test_sequences.get_column("sequence_id").unique(maintain_order=True).to_list()
        )
        print(f"找到 {len(sequence_ids)} 个测试序列")
        
        test_count = min(3, len(sequence_ids))
        print(f"\n将测试前 {test_count} 个序列:")
        print("-" * 40)
        
        for i, seq_id in enumerate(sequence_ids[:test_count]):
            print(f"\n测试序列 {i+1}/{test_count}: {seq_id}")
            print("-" * 30)
            
            sequence = test_sequences.filter(pl.col("sequence_id") == seq_id)
            print(f"序列长度: {len(sequence)} 个时间步")
            
            print("正在进行预测...")
            probabilities = predict_tensorflow_se1dcnn(sequence, test_demographics)
            
            print("\n预测结果:")
            print(f"  - 概率数组形状: {probabilities.shape}")
            print(f"  - 概率和: {probabilities.sum():.6f} (应该 ≈ 1.0)")
            print(f"  - 最小概率: {probabilities.min():.6f}")
            print(f"  - 最大概率: {probabilities.max():.6f}")
            
            # Indices of the three largest probabilities, highest first.
            top3_indices = np.argsort(probabilities[0])[-3:][::-1]
            
            print("\nTop-3 预测:")
            for rank, idx in enumerate(top3_indices, 1):
                label = _tensorflow_se1dcnn_model.UNIFIED_LABELS[idx]
                prob = probabilities[0, idx]
                print(f"  {rank}. {label:<40} {prob:.4f}")
        
        print("\n" + "="*60)
        print("✅ 测试成功完成！")
        print("模块已准备好用于集成。")
        print("="*60)
        
    except FileNotFoundError as e:
        print("\n❌ 错误：找不到测试文件")
        print(f"  ({e})")
        print("请确保以下文件存在：")
        print(f"  - {TEST_CSV}")
        print(f"  - {TEST_DEMOGRAPHICS}")
        
    except Exception as e:
        print(f"\n❌ 测试过程中出错: {e}")
        import traceback
        traceback.print_exc()
        
    print("\n")

# Submit

In [None]:
import polars as pl
import numpy as np

# Stricter data-quality gating: decide how much to trust the multi-sensor ensemble.

def predict(sequence: pl.DataFrame, demographics):
    """
    Predict the gesture label of one sequence, blending two ensembles by sensor quality.

    Both ensembles are always evaluated:
      * multi-sensor (IMU-THM-TOF) = 0.4*predict2 + 0.35*predict4 + 0.25*predict_pf_allsensor
      * IMU-only                   = 0.6*predict_tensorflow_se1dcnn + 0.4*predict3

    The ToF and THM quality checks below only choose the blend weights:
      * ToF or THM (or both) usable -> 0.7*multi-sensor + 0.3*IMU-only
      * neither usable              -> 0.4*multi-sensor + 0.6*IMU-only

    The THM check is intentionally conservative, so the multi-sensor side is
    favoured only when the extra sensors are genuinely informative.

    Returns:
        The element of ``dataset.le.classes_`` at the argmax of the blend.
    """
    # ToF thresholds
    TOF_FRAME_VALID_RATIO_THRESHOLD = 0.25  # min share of frames with any ToF reading
    TOF_MIN_GOOD_SENSORS = 2                # ToF sensors (of 1..5) that must pass
    TOF_SPATIAL_COV_HIGH_THRESHOLD = 0.50   # mean pixel coverage that counts as "high quality"
    TOF_SPATIAL_COV_LOW_THRESHOLD = 0.30    # coverage floor for the "medium but stable" path
    TOF_TIME_COVERAGE_THRESHOLD = 0.60      # share of frames at/above that floor for "stable"

    # THM thresholds (strict; tune on CV)
    THM_FRAME_VALID_RATIO_THRESHOLD = 0.60  # frame coverage; also reused per channel
    THM_MIN_ACTIVE_RATIO = 0.20             # share of frames with |diff| > THM_ACTIVE_DELTA_EPS
    THM_ACTIVE_DELTA_EPS = 0.05
    THM_MIN_GOOD_SENSORS = 2                # THM channels that must pass
    THM_MIN_STD_SUM = 0.0                   # stds are mapped to >= 0, so 0.0 never rejects

    CONFIDENCE_MARGIN = 0.0  # currently unused; reserved for comparing confidences

    tof_cols = [name for name in sequence.columns if name.startswith("tof_")]
    thm_cols = [name for name in sequence.columns if name.startswith("thm_")]

    # ToF marks "no reading" with -1; turn it into null so coverage counts are honest.
    # (Apply the same replacement to thm_cols here if THM ever uses -1 as well.)
    qc_frame = (
        sequence.with_columns([pl.col(name).replace(-1, None).alias(name) for name in tof_cols])
        if tof_cols
        else sequence
    )

    def _frame_coverage(cols):
        """Share of rows where at least one of ``cols`` is non-null; 0.0 if undefined."""
        ratio = qc_frame.select(
            pl.any_horizontal([pl.col(name).is_not_null() for name in cols]).cast(pl.Int8).mean()
        ).item()
        return 0.0 if ratio is None else ratio

    def _tof_ok():
        """True when at least TOF_MIN_GOOD_SENSORS ToF sensors have usable pixel coverage."""
        if not tof_cols or _frame_coverage(tof_cols) < TOF_FRAME_VALID_RATIO_THRESHOLD:
            return False

        passing = 0
        for sensor_no in range(1, 6):
            prefix = f"tof_{sensor_no}_"
            pixels = [name for name in tof_cols if name.startswith(prefix)]
            if not pixels:
                continue

            # Per-frame fraction of this sensor's pixels that carry a reading.
            per_frame = qc_frame.select(
                (pl.sum_horizontal([pl.col(name).is_not_null() for name in pixels]) / len(pixels)).alias("cov")
            )["cov"]
            mean_cov = float(per_frame.mean())
            stable_share = float((per_frame >= TOF_SPATIAL_COV_LOW_THRESHOLD).cast(pl.Int8).mean())

            high_quality = mean_cov >= TOF_SPATIAL_COV_HIGH_THRESHOLD
            medium_stable = (
                mean_cov >= TOF_SPATIAL_COV_LOW_THRESHOLD
                and stable_share >= TOF_TIME_COVERAGE_THRESHOLD
            )
            if high_quality or medium_stable:
                passing += 1

        return passing >= TOF_MIN_GOOD_SENSORS

    def _thm_ok():
        """True when enough THM channels are both present and actually changing."""
        if not thm_cols or _frame_coverage(thm_cols) < THM_FRAME_VALID_RATIO_THRESHOLD:
            return False

        passing = 0
        channel_stds = []
        for name in thm_cols:
            col = pl.col(name)
            present = qc_frame.select(col.is_not_null().cast(pl.Int8).mean()).item() or 0.0
            # Share of frames whose step change exceeds eps (the leading null diff is ignored by mean).
            moving = qc_frame.select(
                (col.diff().abs() > THM_ACTIVE_DELTA_EPS).cast(pl.Int8).mean()
            ).item() or 0.0
            spread = qc_frame.select(col.std()).item()
            channel_stds.append(0.0 if (spread is None or np.isnan(spread)) else float(spread))

            if present >= THM_FRAME_VALID_RATIO_THRESHOLD and moving >= THM_MIN_ACTIVE_RATIO:
                passing += 1

        std_total = float(np.sum(channel_stds))
        return (passing >= THM_MIN_GOOD_SENSORS) and (std_total >= THM_MIN_STD_SUM)

    # Evaluate both checks (no short-circuit), then route on either passing.
    tof_usable = _tof_ok()
    thm_usable = _thm_ok()
    sensors_usable = tof_usable or thm_usable

    # NOTE(review): the blend assumes every sub-model returns probabilities in
    # dataset.le.classes_ order — confirm against each predict* implementation.
    multi_sensor_probs = (
        0.4 * predict2(sequence, demographics)[0]
        + 0.35 * predict4(sequence, demographics)[0]
        + 0.25 * predict_pf_allsensor(sequence, demographics)[0]
    )
    imu_only_probs = (
        0.6 * predict_tensorflow_se1dcnn(sequence, demographics)[0]
        + 0.4 * predict3(sequence, demographics)[0]
    )

    if sensors_usable:
        multi_weight, imu_weight = 0.7, 0.3
    else:
        multi_weight, imu_weight = 0.4, 0.6
    final_probs = multi_weight * multi_sensor_probs + imu_weight * imu_only_probs

    return dataset.le.classes_[int(np.argmax(final_probs))]

In [None]:

# IMU-only models
    # 1. SE-CNN + BERT (IMU-only): LB = 0.810
    # 2. SE-CNN + attention (IMU-only): LB = 0.813
    # 1 and 2 averaged as an ensemble: LB = 0.814

# All-sensor models
    # SE-CNN + BERT (all sensors): LB = 0.841
    # SE-CNN gated-GRU hybrid ensemble: LB = 0.835
    # SE-CNN + BiGRU/BiLSTM: LB = 0.821
    # SE-CNN + (BiLSTM, BiGRU), SE-CNN + BERT, SE-CNN + ... (entry incomplete)
    

# Ensemble strategies
    # Weighted all-sensor ensemble: LB = 0.845
    # IMU-only ensemble + all-sensor ensemble, switched by data quality: LB = 0.853
    # IMU-only ensemble + all-sensor ensemble, data-quality-weighted cross-ensemble: LB = 0.855 (current choice)

# Ideas for further improvement
    # Directional model integration
    # Post-processing



In [None]:
import kaggle_evaluation.cmi_inference_server

# Wrap `predict` in the competition inference server. With
# KAGGLE_IS_COMPETITION_RERUN set it serves the hidden-test rerun. Otherwise it
# replays the public test files through a local gateway and then prints
# submission.parquet (presumably written by that gateway) for inspection.
inference_server = kaggle_evaluation.cmi_inference_server.CMIInferenceServer(predict)

if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    inference_server.serve()
else:
    inference_server.run_local_gateway(
        data_paths=(
            '/kaggle/input/cmi-detect-behavior-with-sensor-data/test.csv',
            '/kaggle/input/cmi-detect-behavior-with-sensor-data/test_demographics.csv',
        )
    )
    print(pd.read_parquet("submission.parquet"))