In [None]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models, callbacks
from scipy import signal
from scipy.interpolate import interp1d
from sklearn.model_selection import GroupKFold
from sklearn.metrics import accuracy_score
import mne

# Pipeline-wide configuration
BASE_PATH = '/kaggle/input/mtcaic3'  # root folder of the competition data

# Electrodes read from every EEGdata.csv; both names carry the same montage.
TARGET_CHANNELS = ['C3', 'C4', 'CZ']
EEG_CHANNELS = list(TARGET_CHANNELS)

SAMPLING_RATE = 250  # Hz
MI_TRIAL_LENGTH = 9  # seconds
MI_SAMPLES = MI_TRIAL_LENGTH * SAMPLING_RATE  # 9s * 250Hz = 2250 samples per trial

# Artifact handling
AMP_THRESHOLD = 500_000  # nV
MAX_CONSECUTIVE_ARTIFACTS = 5  # longest run (in samples) still treated as a "small" gap
LOCAL_WINDOW_SIZE = 100  # samples taken on each side of a small artifact run

NUM_CLASSES = 2  # For MI task (Left/Right)

# Read the three index CSVs that sit directly under BASE_PATH.
train_df, validation_df, test_df = [
    pd.read_csv(os.path.join(BASE_PATH, index_file))
    for index_file in ('train.csv', 'validation.csv', 'test.csv')
]

# Keep only the rows whose `task` column is 'MI'; copy so the subsets are
# independent frames rather than views of the full indexes.
mi_train_df, mi_validation_df, mi_test_df = [
    index_frame.loc[index_frame['task'] == 'MI'].copy()
    for index_frame in (train_df, validation_df, test_df)
]

# Subject-wise normalization statistics
# Maps subject_id -> {'mean': ..., 'std': ...}; each entry holds one value per
# TARGET_CHANNELS column, computed over the raw CSV samples of that subject's
# training sessions.
subject_stats = {}
print("Calculating subject-wise normalization parameters...")
for subject_id in mi_train_df['subject_id'].unique():
    subject_data = []
    subject_df = mi_train_df[mi_train_df['subject_id'] == subject_id]

    # One EEGdata.csv holds a whole session, so visit every (task, session)
    # pair once instead of re-reading the same file for each trial row.
    sessions = subject_df[['task', 'trial_session']].drop_duplicates()
    for task, trial_session in sessions.itertuples(index=False, name=None):
        eeg_path = os.path.join(BASE_PATH, task, 'train',
                                subject_id, str(trial_session), 'EEGdata.csv')
        if not os.path.exists(eeg_path):
            continue
        eeg_data = pd.read_csv(eeg_path)

        # Keep the complete trials among the first 10 trial slots of the file;
        # a trailing partial trial is ignored.
        n_trials = min(10, len(eeg_data) // MI_SAMPLES)
        if n_trials == 0:
            continue
        session_data = eeg_data[TARGET_CHANNELS].iloc[:n_trials * MI_SAMPLES].values
        subject_data.append(session_data)

    if subject_data:
        subject_data = np.vstack(subject_data)
        means = np.nanmean(subject_data, axis=0)
        stds = np.nanstd(subject_data, axis=0)
        stds[stds < 1e-6] = 1.0  # Avoid division by zero
        subject_stats[subject_id] = {'mean': means, 'std': stds}

# Calculate global stats for subjects not in training
# Rows are grouped by session so each EEGdata.csv is parsed once; only the
# trials that are actually listed in mi_train_df contribute samples.
all_train_data = []
session_keys = ['task', 'subject_id', 'trial_session']
for (task, subject_id, trial_session), session_df in mi_train_df.groupby(session_keys, sort=False):
    eeg_path = os.path.join(BASE_PATH, task, 'train',
                            subject_id, str(trial_session), 'EEGdata.csv')
    if not os.path.exists(eeg_path):
        continue
    channel_values = pd.read_csv(eeg_path)[TARGET_CHANNELS].values

    for trial in session_df['trial']:
        start_idx = (trial - 1) * MI_SAMPLES
        end_idx = trial * MI_SAMPLES
        if end_idx > len(channel_values):
            continue  # trial extends past the end of the recording
        all_train_data.append(channel_values[start_idx:end_idx])

if all_train_data:
    all_train_data = np.vstack(all_train_data)
    global_mean = np.nanmean(all_train_data, axis=0)
    global_std = np.nanstd(all_train_data, axis=0)
    global_std[global_std < 1e-6] = 1.0  # Avoid division by zero
else:
    # No training data could be read: fall back to an identity transform.
    global_mean = np.zeros(len(TARGET_CHANNELS))
    global_std = np.ones(len(TARGET_CHANNELS))

# Preprocessing functions
def load_trial_data(row):
    """Return the rows of the session recording that belong to one trial.

    The split folder is chosen from ``row['id']``: ids up to 4800 map to
    'train', ids up to 4900 to 'validation', anything else to 'test'.

    Args:
        row: index row providing 'id', 'task', 'subject_id',
            'trial_session' and 'trial'.

    Returns:
        The slice of the session DataFrame covering the trial's MI_SAMPLES
        rows, or an empty DataFrame when the file is missing or the trial
        would run past the end of the recording.
    """
    trial_id = row['id']
    if trial_id <= 4800:
        dataset = 'train'
    elif trial_id <= 4900:
        dataset = 'validation'
    else:
        dataset = 'test'

    eeg_path = os.path.join(
        BASE_PATH,
        row['task'],
        dataset,
        row['subject_id'],
        str(row['trial_session']),
        'EEGdata.csv',
    )
    if not os.path.exists(eeg_path):
        return pd.DataFrame()

    recording = pd.read_csv(eeg_path)

    # Trials are stored back to back, MI_SAMPLES rows each, numbered from 1.
    first_row = (row['trial'] - 1) * MI_SAMPLES
    last_row = row['trial'] * MI_SAMPLES
    if last_row > len(recording):
        return pd.DataFrame()
    return recording.iloc[first_row:last_row]

def interpolate_invalid(data, channels=None, max_gap=None):
    """Repair samples flagged as invalid by the 'Validation' column.

    Samples whose 'Validation' value is not 1 are grouped into runs of
    consecutive indices. Runs of at most ``max_gap`` samples are filled by
    linear interpolation (extrapolation at the edges) over the valid samples.
    Longer runs are filled with a constant: the first/last valid value when
    the run touches the start/end of the trial, otherwise the mean of the two
    valid samples that flank the run.

    Args:
        data: DataFrame holding the channel columns and, optionally, a
            'Validation' column.
        channels: channel columns to repair; defaults to TARGET_CHANNELS.
        max_gap: longest run that is interpolated; defaults to
            MAX_CONSECUTIVE_ARTIFACTS.

    Returns:
        ``data`` itself when nothing can or needs to be repaired (no
        'Validation' column, empty frame, no invalid samples, or fewer than
        two valid samples); otherwise a repaired copy. The input frame is
        never modified, which matters because callers pass in a slice of the
        full session frame.
    """
    if 'Validation' not in data.columns or data.empty:
        return data

    channels = TARGET_CHANNELS if channels is None else channels
    max_gap = MAX_CONSECUTIVE_ARTIFACTS if max_gap is None else max_gap

    valid_mask = (data['Validation'] == 1).to_numpy()
    invalid_mask = ~valid_mask
    if not invalid_mask.any() or valid_mask.sum() < 2:
        return data

    # Locate runs of invalid samples: +1 marks a run start, -1 the index one
    # past its end. Padding with False guarantees starts and ends pair up.
    edges = np.diff(np.concatenate(([False], invalid_mask, [False])).astype(int))
    gap_starts = np.flatnonzero(edges == 1)
    gap_ends = np.flatnonzero(edges == -1)

    valid_indices = np.flatnonzero(valid_mask)
    n_samples = len(data)

    repaired = data.copy()
    for channel in channels:
        # Float working copy: writing interpolated values into an integer
        # array would silently truncate them.
        channel_data = repaired[channel].to_numpy(dtype=float, copy=True)
        valid_values = channel_data[valid_mask]

        interp_func = interp1d(
            valid_indices,
            valid_values,
            kind='linear',
            bounds_error=False,
            fill_value="extrapolate",
        )

        for start, end in zip(gap_starts, gap_ends):
            if end - start <= max_gap:
                channel_data[start:end] = interp_func(np.arange(start, end))
            elif start == 0:
                channel_data[start:end] = valid_values[0]
            elif end >= n_samples:
                channel_data[start:end] = valid_values[-1]
            else:
                # Runs are maximal, so start - 1 and end are both valid samples.
                channel_data[start:end] = (channel_data[start - 1] + channel_data[end]) / 2

        repaired[channel] = channel_data

    return repaired

def remove_high_amplitude_artifacts(data, amp_threshold=AMP_THRESHOLD):
    """Replace samples whose magnitude reaches ``amp_threshold``.

    Each TARGET_CHANNELS column is clipped to ``[-amp_threshold,
    amp_threshold]``; samples sitting exactly on either bound after clipping
    are treated as artifacts and grouped into runs of consecutive indices:

    * runs of at most MAX_CONSECUTIVE_ARTIFACTS samples are replaced by the
      median of the non-artifact samples found within LOCAL_WINDOW_SIZE
      samples on either side of the run;
    * longer runs are replaced by a straight line between the clean samples
      directly before and after the run;
    * when no usable neighbour exists (the run touches an edge of the trial,
      or its local window holds no clean sample) the channel-wide median of
      the clean samples is used, or 0.0 if the channel has none.

    Args:
        data: DataFrame containing the TARGET_CHANNELS columns.
        amp_threshold: absolute amplitude at which a sample is an artifact.

    Returns:
        A copy of ``data`` with the channel columns cleaned (stored as
        float); the input is not modified.
    """
    data = data.copy()
    for channel in TARGET_CHANNELS:
        # Float working array so medians and interpolated values are not
        # truncated when the column is integer-typed.
        clipped_data = np.clip(data[channel].to_numpy(dtype=float),
                               -amp_threshold, amp_threshold)
        artifact_mask = (clipped_data == amp_threshold) | (clipped_data == -amp_threshold)

        if np.any(artifact_mask):
            n_samples = len(clipped_data)

            # Channel-wide fallback. Guard the empty case explicitly: the
            # median of an empty selection would be NaN and emit a warning.
            clean_values = clipped_data[~artifact_mask]
            clean_values = clean_values[~np.isnan(clean_values)]
            fallback_val = np.median(clean_values) if clean_values.size else 0.0

            # +1 marks the first index of a run, -1 the index one past it.
            edges = np.diff(np.concatenate(([False], artifact_mask, [False])).astype(int))
            starts = np.flatnonzero(edges == 1)
            ends = np.flatnonzero(edges == -1)

            for start_idx, end_idx in zip(starts, ends):
                if end_idx - start_idx <= MAX_CONSECUTIVE_ARTIFACTS:
                    # Small cluster: median of the clean samples around it.
                    local_start = max(0, start_idx - LOCAL_WINDOW_SIZE)
                    local_end = min(n_samples, end_idx + LOCAL_WINDOW_SIZE)
                    # Masking with artifact_mask drops the run itself (also
                    # when it ends at the window edge) and any other clipped
                    # samples that fall inside the window.
                    local_clean = clipped_data[local_start:local_end][
                        ~artifact_mask[local_start:local_end]
                    ]
                    local_clean = local_clean[~np.isnan(local_clean)]
                    fill_val = np.median(local_clean) if local_clean.size else fallback_val
                    clipped_data[start_idx:end_idx] = fill_val
                elif start_idx > 0 and end_idx < n_samples:
                    # Large interior cluster: runs are maximal, so the samples
                    # at start_idx - 1 and end_idx are clean boundaries.
                    ramp = np.linspace(clipped_data[start_idx - 1],
                                       clipped_data[end_idx],
                                       end_idx - start_idx + 2)
                    clipped_data[start_idx:end_idx] = ramp[1:-1]
                else:
                    # Large cluster touching an edge: no boundary on one side.
                    clipped_data[start_idx:end_idx] = fallback_val

        data[channel] = clipped_data

    return data

def apply_filters(data):
    """Band-limit every target channel to 8-30 Hz and notch out 50 Hz.

    Both filters are run forward and backward with ``signal.filtfilt``.

    Args:
        data: DataFrame containing the TARGET_CHANNELS columns.

    Returns:
        ``data`` unchanged when it is empty, otherwise a copy whose
        TARGET_CHANNELS columns hold the filtered signals.
    """
    if data.empty:
        return data

    # 4th-order Butterworth band-pass; cut-offs are normalised by Nyquist.
    nyquist = 0.5 * SAMPLING_RATE
    band_edges = [8.0 / nyquist, 30.0 / nyquist]
    bp_b, bp_a = signal.butter(4, band_edges, btype='band')

    # 50 Hz notch with quality factor 30.
    notch_b, notch_a = signal.iirnotch(50.0, 30.0, SAMPLING_RATE)

    filtered_data = data.copy()
    for channel in TARGET_CHANNELS:
        band_limited = signal.filtfilt(bp_b, bp_a, data[channel].values)
        filtered_data[channel] = signal.filtfilt(notch_b, notch_a, band_limited)

    return filtered_data

def winsorize(data, lower=0.01, upper=0.99):
    """Clamp each target channel to its own [lower, upper] quantile range.

    Args:
        data: DataFrame containing the TARGET_CHANNELS columns.
        lower: quantile used as the lower clamp.
        upper: quantile used as the upper clamp.

    Returns:
        ``data`` unchanged when it is empty, otherwise a copy with the
        TARGET_CHANNELS columns clipped.
    """
    if data.empty:
        return data

    winsorized = data.copy()
    for channel in TARGET_CHANNELS:
        series = data[channel]
        floor = series.quantile(lower)
        ceiling = series.quantile(upper)
        winsorized[channel] = series.clip(lower=floor, upper=ceiling)
    return winsorized

def normalize_trial(data, subject_id):
    """Z-score the target channels with per-subject statistics.

    Statistics come from ``subject_stats[subject_id]`` when the subject is
    known, otherwise from ``global_mean`` / ``global_std``.

    NOTE(review): those statistics are computed from the raw CSV samples,
    while preprocess_trial calls this function after filtering and
    winsorizing - confirm that this mismatch is intended.

    Args:
        data: DataFrame containing the TARGET_CHANNELS columns.
        subject_id: key looked up in ``subject_stats``.

    Returns:
        ``data`` unchanged when it is empty, otherwise a normalized copy.
    """
    if data.empty:
        return data

    stats = subject_stats.get(subject_id, {'mean': global_mean, 'std': global_std})
    channel_means = stats['mean']
    channel_stds = stats['std']

    normalized = data.copy()
    for position, channel in enumerate(TARGET_CHANNELS):
        # Fall back to the global value if the stats vector is too short.
        if position < len(channel_means):
            mean_val = channel_means[position]
        else:
            mean_val = global_mean[position]
        if position < len(channel_stds):
            std_val = channel_stds[position]
        else:
            std_val = global_std[position]
        normalized[channel] = (data[channel] - mean_val) / std_val

    return normalized

def preprocess_trial(row, augment=False):
    """Run the whole preprocessing chain for one index row.

    Steps, in order: load the trial, repair invalid samples, remove
    high-amplitude artifacts, zero out NaN/inf, filter, winsorize, normalize.
    With ``augment=True`` the result is additionally time-warped by a random
    factor in [0.8, 1.2] (cropped or edge-padded back to the original length)
    and perturbed with Gaussian noise of standard deviation 0.1.

    Args:
        row: index row describing the trial (see load_trial_data).
        augment: whether to apply the random augmentation.

    Returns:
        Array with one row per sample and one column per target channel.
        When the trial cannot be loaded, an all-zero array of shape
        (MI_SAMPLES, len(TARGET_CHANNELS)) is returned instead.
    """
    raw = load_trial_data(row)
    if raw.empty:
        return np.zeros((MI_SAMPLES, len(TARGET_CHANNELS)))

    # Repair flagged samples first, then amplitude artifacts.
    cleaned = remove_high_amplitude_artifacts(interpolate_invalid(raw))

    # Replace NaN and +/-inf with 0 before filtering.
    for channel in TARGET_CHANNELS:
        cleaned[channel] = cleaned[channel].fillna(0).replace([np.inf, -np.inf], 0)

    # Filter -> winsorize -> subject-wise normalization.
    processed = normalize_trial(winsorize(apply_filters(cleaned)), row['subject_id'])
    eeg_data = processed[TARGET_CHANNELS].values

    if not augment:
        return eeg_data

    # Time warping: resample each channel, then restore the original length.
    n_samples = len(eeg_data)
    warp_factor = np.random.uniform(0.8, 1.2)
    warped_channels = []
    for ch_idx in range(len(TARGET_CHANNELS)):
        stretched = signal.resample(eeg_data[:, ch_idx], int(n_samples * warp_factor))
        if len(stretched) > n_samples:
            stretched = stretched[:n_samples]
        else:
            stretched = np.pad(stretched, (0, n_samples - len(stretched)), 'edge')
        warped_channels.append(stretched)
    eeg_data = np.array(warped_channels).T

    # Additive Gaussian noise.
    eeg_data += np.random.normal(0, 0.1, eeg_data.shape)
    return eeg_data

# Model Architectures
def create_eegnet(input_shape, num_classes):
    """EEGNet with Domain Adaptation

    Args:
        input_shape: ``(time_steps, channels)`` of one input window.
        num_classes: number of softmax outputs.

    Returns:
        An uncompiled Keras model mapping a window to class probabilities.
    """
    n_samples, n_channels = input_shape
    inputs = layers.Input(shape=input_shape)

    # (time, channels) -> (channels, time, 1) for Conv2D. The axes have to be
    # swapped with Permute first: a bare Reshape keeps the memory order and
    # would interleave time samples and channels instead of transposing them.
    x = layers.Permute((2, 1))(inputs)
    x = layers.Reshape((n_channels, n_samples, 1))(x)

    # Per-window standardisation over all channels and samples. The epsilon
    # keeps constant windows (e.g. the all-zero placeholder for a missing
    # trial) from producing 0/0 = NaN.
    x = layers.Lambda(
        lambda t: (t - tf.reduce_mean(t, axis=[1, 2, 3], keepdims=True))
        / (tf.math.reduce_std(t, axis=[1, 2, 3], keepdims=True) + 1e-6)
    )(x)

    # Temporal convolution
    x = layers.Conv2D(8, (1, 64), use_bias=False, padding='same')(x)
    x = layers.BatchNormalization()(x)

    # Spatial convolution spanning every electrode ('valid' collapses the
    # channel axis to 1); depth_multiplier=2 yields 16 feature maps.
    x = layers.DepthwiseConv2D((n_channels, 1), depth_multiplier=2,
                               padding='valid',
                               depthwise_constraint=tf.keras.constraints.MaxNorm(1.))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('elu')(x)
    x = layers.AveragePooling2D((1, 4))(x)
    x = layers.Dropout(0.5)(x)

    # Domain adaptation through feature-wise statistics alignment: mean and
    # std of every feature map over the spatial/temporal axes. Pooling over
    # axes 1-2 keeps the feature-map axis intact regardless of its size.
    mean = layers.GlobalAveragePooling2D()(x)
    std = layers.Lambda(lambda t: tf.math.reduce_std(t, axis=[1, 2]))(x)
    domain_features = layers.Concatenate()([mean, std])

    # Additional convolution
    x = layers.SeparableConv2D(16, (1, 16), activation='elu', padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.AveragePooling2D((1, 8))(x)
    x = layers.Dropout(0.5)(x)

    # Combine with domain features
    x = layers.Flatten()(x)
    x = layers.Concatenate()([x, domain_features])

    # Classifier
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    return models.Model(inputs, outputs)

def create_ccnn_rnn(input_shape, num_classes):
    """CCNN-RNN with Cross-Channel Attention

    Args:
        input_shape: ``(time_steps, channels)`` of one input window.
        num_classes: number of softmax outputs.

    Returns:
        An uncompiled Keras model mapping a window to class probabilities.
    """
    inputs = layers.Input(shape=input_shape)

    # Channel-wise normalization
    x = layers.BatchNormalization()(inputs)

    # Parallel temporal processing per channel
    channel_outputs = []
    for i in range(input_shape[1]):
        # Extract single channel. The index is bound as a default argument:
        # a lambda that merely closes over `i` reads the variable each time
        # the layer's function runs, so any call after this loop has finished
        # would slice the last channel in every branch.
        ch = layers.Lambda(lambda t, idx=i: t[:, :, idx:idx + 1])(x)

        # Temporal convolution
        ch = layers.Conv1D(8, 64, activation='elu')(ch)
        ch = layers.MaxPooling1D(4)(ch)
        channel_outputs.append(ch)

    # Cross-channel attention: the per-channel sequences are laid end to end
    # along the time axis and re-weighted by a per-step sigmoid gate.
    merged = layers.Concatenate(axis=1)(channel_outputs)
    attention = layers.Conv1D(1, 1, activation='sigmoid')(merged)
    attention = layers.Multiply()([merged, attention])

    # Temporal modeling
    x = layers.Bidirectional(layers.LSTM(32, return_sequences=True))(attention)
    x = layers.TimeDistributed(layers.Dense(16))(x)

    # Output
    x = layers.GlobalAveragePooling1D()(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    return models.Model(inputs, outputs)

def create_csp_tsnet(input_shape, num_classes):
    """CSP-TSConvNet Hybrid Model

    Args:
        input_shape: ``(time_steps, channels)`` of one input window.
        num_classes: number of softmax outputs.

    Returns:
        An uncompiled Keras model mapping a window to class probabilities.
    """
    inputs = layers.Input(shape=input_shape)

    # CSP-like spatial filtering: a bias-free linear mix of the electrodes
    x = layers.Conv1D(6, 1, activation='linear', use_bias=False)(inputs)
    x = layers.BatchNormalization()(x)

    # Temporal separable convolution
    x = layers.Conv1D(12, 128, groups=3, activation='swish')(x)
    x = layers.BatchNormalization()(x)
    x = layers.SeparableConv1D(24, 32, activation='swish')(x)
    x = layers.MaxPooling1D(8)(x)

    # Multi-scale features. 'same' padding keeps the three branches at the
    # same sequence length; with the default 'valid' padding each kernel
    # size shortens the sequence differently and the feature-axis
    # concatenation below rejects the mismatched shapes.
    branch1 = layers.Conv1D(16, 8, padding='same', activation='swish')(x)
    branch2 = layers.Conv1D(16, 16, padding='same', activation='swish')(x)
    branch3 = layers.Conv1D(16, 32, padding='same', activation='swish')(x)
    merged = layers.Concatenate()([branch1, branch2, branch3])

    # Classifier
    x = layers.GRU(32)(merged)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    return models.Model(inputs, outputs)

# Data Generator
class EEGDataGenerator(tf.keras.utils.Sequence):
    """Keras Sequence that yields ``(X, y)`` batches built from an index frame.

    When the frame has a 'window_data' column (as produced by
    create_subject_windows) the stored window is used directly, so every
    sample has the window length and no file is re-read per batch. Otherwise
    each row is run through preprocess_trial.
    """

    def __init__(self, df, batch_size=32, shuffle=True, augment=False):
        """
        Args:
            df: index DataFrame, one row per sample.
            batch_size: number of rows per batch.
            shuffle: reshuffle the row order at the end of every epoch.
            augment: apply random time warping and Gaussian noise.
        """
        super().__init__()
        self.df = df
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.augment = augment
        # Class name -> integer target
        self.label_map = {'Left': 0, 'Right': 1}
        # Build the initial (possibly shuffled) row order
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch (the last one may be smaller)."""
        return int(np.ceil(len(self.df) / self.batch_size))

    @staticmethod
    def _augment(window):
        """Time-warp a (samples, channels) window, then add Gaussian noise.

        Mirrors the augmentation of preprocess_trial: resample by a random
        factor in [0.8, 1.2], crop or edge-pad back to the original length,
        and add noise with standard deviation 0.1. Returns a new array.
        """
        n_samples = window.shape[0]
        warp_factor = np.random.uniform(0.8, 1.2)
        warped = signal.resample(window, int(n_samples * warp_factor), axis=0)
        if warped.shape[0] > n_samples:
            warped = warped[:n_samples]
        else:
            pad_rows = n_samples - warped.shape[0]
            warped = np.pad(warped, ((0, pad_rows), (0, 0)), mode='edge')
        return warped + np.random.normal(0, 0.1, warped.shape)

    def __getitem__(self, index):
        """Return batch ``index`` as ``(X, y)``.

        ``y`` is empty when the frame has no 'label' column.
        """
        batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]
        batch_df = self.df.iloc[batch_indices]

        has_windows = 'window_data' in batch_df.columns
        has_labels = 'label' in batch_df.columns

        X = []
        y = []
        for _, row in batch_df.iterrows():
            if has_windows:
                # Copy so augmentation never alters the cached window.
                eeg_data = np.array(row['window_data'], dtype=float)
                if self.augment:
                    eeg_data = self._augment(eeg_data)
            else:
                # Preprocess trial with optional augmentation
                eeg_data = preprocess_trial(row, augment=self.augment)
            X.append(eeg_data)

            # Get label if available
            if has_labels:
                y.append(self.label_map[row['label']])

        return np.array(X), np.array(y)

    def on_epoch_end(self):
        """Reset the row order, shuffling it when ``shuffle`` is set."""
        self.indices = np.arange(len(self.df))
        if self.shuffle:
            np.random.shuffle(self.indices)

# Subject-wise windowing
def create_subject_windows(df, window_size=2, overlap=0.5):
    """Cut every preprocessed trial of ``df`` into fixed-length windows.

    Args:
        df: index DataFrame, one row per trial.
        window_size: window length in seconds.
        overlap: fraction of a window shared with the next one.

    Returns:
        DataFrame with one row per window: a copy of the trial's index row
        extended by 'window_data' (the samples), 'window_start' and
        'window_end' (sample offsets within the trial).
    """
    samples_per_window = int(window_size * SAMPLING_RATE)
    step = int(samples_per_window * (1 - overlap))

    window_rows = []
    for _, trial_row in df.iterrows():
        trial_samples = preprocess_trial(trial_row)
        if len(trial_samples) < MI_SAMPLES:
            continue

        # Number of windows that fit completely inside the trial.
        num_windows = (MI_SAMPLES - samples_per_window) // step + 1
        for window_idx in range(num_windows):
            first = window_idx * step
            last = first + samples_per_window

            window_row = trial_row.copy()
            window_row['window_data'] = trial_samples[first:last]
            window_row['window_start'] = first
            window_row['window_end'] = last
            window_rows.append(window_row)

    return pd.DataFrame(window_rows)

# Main Training Pipeline
def main():
    """Train the three architectures with subject-grouped CV and predict.

    For every architecture a fresh model is trained per GroupKFold fold
    (subjects never straddle the train/validation split of a fold), its best
    checkpoint is written to ``best_<name>_fold<k>.h5`` and scored on the
    held-out validation windows. Test trials are cut into the same windows
    the models were trained on; window probabilities are summed over all
    trained models and over the windows of each trial, and the arg-max is
    written to ``submission.csv``.
    """
    # Create windowed datasets
    print("Creating subject-wise windowed datasets...")
    windowed_train = create_subject_windows(mi_train_df)
    windowed_val = create_subject_windows(mi_validation_df)

    val_gen = EEGDataGenerator(windowed_val, batch_size=32, shuffle=False, augment=False)

    # Input shape: (time_steps, channels)
    input_shape = (int(2 * SAMPLING_RATE), len(TARGET_CHANNELS))

    # Builders rather than instances: each fold must start from untrained
    # weights, otherwise a fold's validation subjects have already been seen
    # while training an earlier fold.
    model_builders = [
        ('EEGNet', create_eegnet),
        ('CCNN-RNN', create_ccnn_rnn),
        ('CSP-TSConvNet', create_csp_tsnet),
    ]

    n_splits = 5
    groups = windowed_train['subject_id'].values
    fold_splits = list(GroupKFold(n_splits=n_splits).split(windowed_train, groups=groups))

    trained_models = []
    for name, build_model in model_builders:
        print(f"\nTraining {name} model...")

        for fold, (train_idx, val_idx) in enumerate(fold_splits):
            print(f"\nTraining Fold {fold + 1}/{n_splits}")
            fold_train_gen = EEGDataGenerator(windowed_train.iloc[train_idx],
                                              batch_size=32, augment=True)
            fold_val_gen = EEGDataGenerator(windowed_train.iloc[val_idx],
                                            batch_size=32, shuffle=False, augment=False)

            model = build_model(input_shape, NUM_CLASSES)
            model.compile(
                optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy']
            )

            # Fresh callbacks per fold so the checkpoint's "best so far" is
            # never compared across different validation sets.
            checkpoint_path = f'best_{name}_fold{fold + 1}.h5'
            callbacks_list = [
                callbacks.EarlyStopping(patience=10, restore_best_weights=True),
                callbacks.ReduceLROnPlateau(factor=0.5, patience=5),
                callbacks.ModelCheckpoint(checkpoint_path, save_best_only=True)
            ]

            model.fit(
                fold_train_gen,
                validation_data=fold_val_gen,
                epochs=50,
                callbacks=callbacks_list,
                verbose=2
            )

            # Use the best epoch of this fold when a checkpoint was written.
            if os.path.exists(checkpoint_path):
                model.load_weights(checkpoint_path)

            # Evaluation on the held-out validation set
            val_loss, val_acc = model.evaluate(val_gen)
            print(f"{name} fold {fold + 1} Validation Accuracy: {val_acc:.4f}")
            trained_models.append(model)

    # Ensemble predictions for test set
    print("\nGenerating test predictions...")
    # The models take fixed-length windows, so test trials are windowed the
    # same way as the training data instead of being fed in whole.
    windowed_test = create_subject_windows(mi_test_df)
    test_windows = np.stack(windowed_test['window_data'].to_numpy())
    window_ids = windowed_test['id'].to_numpy()

    window_scores = np.zeros((len(test_windows), NUM_CLASSES))
    for model in trained_models:
        window_scores += model.predict(test_windows)

    # Pool the window scores of each trial, keeping mi_test_df's row order.
    trial_scores = np.stack([
        window_scores[window_ids == trial_id].sum(axis=0)
        for trial_id in mi_test_df['id'].to_numpy()
    ])

    final_preds = np.argmax(trial_scores, axis=1)
    pred_labels = ['Left' if p == 0 else 'Right' for p in final_preds]

    # Create submission file
    submission = mi_test_df[['id']].copy()
    submission['label'] = pred_labels
    submission.to_csv('submission.csv', index=False)
    print("Submission file created!")


if __name__ == "__main__":
    main()

Calculating subject-wise normalization parameters...
Creating subject-wise windowed datasets...


  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, ou