In [None]:
import functools
import gc
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from scipy import signal
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import CSVLogger, ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.constraints import max_norm
from tensorflow.keras.layers import (Input, Conv2D, BatchNormalization, Activation, 
                                    DepthwiseConv2D, AveragePooling2D, Dropout, 
                                    SeparableConv2D, Flatten, Dense, SpatialDropout2D)
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical
from tqdm import tqdm

# =============================================
# 1. Configuration and Data Loading
# =============================================
# Dataset location and recording parameters
BASE_PATH = '/kaggle/input/mtcaic3'
TRAIN_CSV = os.path.join(BASE_PATH, 'train.csv')
EEG_CHANNELS = ['FZ', 'C3', 'CZ', 'C4', 'PZ', 'PO7', 'OZ', 'PO8']
SAMPLING_RATE = 250  # Hz
MI_DURATION = 9  # seconds
MI_SAMPLES = MI_DURATION * SAMPLING_RATE  # rows per MI trial in a session file
CROPPED_SAMPLES = 2240  # 70 * 32 samples kept per trial

# Read the trial index, keep only the MI rows, and encode Left/Right as 0/1
train_df = pd.read_csv(TRAIN_CSV)
is_mi_task = train_df['task'] == 'MI'
mi_df = train_df.loc[is_mi_task].copy()
mi_df['label'] = mi_df['label'].map({'Left': 0, 'Right': 1})

# =============================================
# 2. Improved Preprocessing Functions
# =============================================
def butter_bandpass(lowcut, highcut, fs, order=4):
    """Design a Butterworth band-pass filter in transfer-function form.

    Args:
        lowcut: Lower band edge, in the same units as ``fs``.
        highcut: Upper band edge, in the same units as ``fs``.
        fs: Sampling rate.
        order: Order passed to ``scipy.signal.butter``.

    Returns:
        The ``(b, a)`` coefficient pair produced by ``scipy.signal.butter``.
    """
    # scipy expects band edges as fractions of the Nyquist frequency.
    nyquist = fs / 2.0
    band_edges = [lowcut / nyquist, highcut / nyquist]
    return signal.butter(order, band_edges, btype='band')

def butter_bandpass_filter(data, lowcut, highcut, fs, order=4):
    """Apply a zero-phase Butterworth band-pass filter along axis 0.

    The filter is designed as second-order sections rather than a single
    ``(b, a)`` polynomial pair: SciPy recommends the SOS form because the
    polynomial form loses precision as the order grows, and can become
    unstable for narrow or low-frequency bands.

    Args:
        data: Array-like signal; axis 0 is treated as time, so a 2-D array
            of shape (time, channels) is filtered column by column.
        lowcut: Lower band edge, in the same units as ``fs``.
        highcut: Upper band edge, in the same units as ``fs``.
        fs: Sampling rate.
        order: Butterworth order passed to ``scipy.signal.butter``.

    Returns:
        The filtered signal, same shape as ``data``.

    Raises:
        ValueError: Propagated from SciPy when the band edges are invalid
            for ``fs`` or the signal is too short to pad.
    """
    sos = signal.butter(order, [lowcut, highcut], btype='band',
                        fs=fs, output='sos')
    # Forward-backward filtering cancels the phase delay of a single pass.
    return signal.sosfiltfilt(sos, data, axis=0)

def notch_filter(data, f0, fs, Q=30):
    """Remove a narrow band around ``f0`` with a zero-phase IIR notch.

    Args:
        data: Array-like signal; filtering runs along axis 0.
        f0: Frequency to suppress, in the same units as ``fs``.
        fs: Sampling rate.
        Q: Quality factor handed to ``scipy.signal.iirnotch``.

    Returns:
        The filtered signal, same shape as ``data``.
    """
    # iirnotch takes the notch frequency as a fraction of Nyquist.
    normalized_f0 = f0 / (fs / 2)
    coeffs = signal.iirnotch(normalized_f0, Q)
    return signal.filtfilt(*coeffs, data, axis=0)

@functools.lru_cache(maxsize=4)
def _read_session_eeg(eeg_path):
    """Read one session's EEGdata.csv, memoising the most recent files.

    Each session file holds several trials, so caching avoids re-parsing
    the same CSV for every trial. The cache is kept small to bound memory.
    """
    return pd.read_csv(eeg_path)


def load_mi_trial(subject_id, session, trial_num):
    """Return the rows of one MI trial from its session recording.

    Args:
        subject_id: Subject directory name under ``MI/train``.
        session: Session directory name under the subject.
        trial_num: 1-based trial index within the session; trial ``k``
            occupies rows ``(k - 1) * MI_SAMPLES`` up to ``k * MI_SAMPLES``.

    Returns:
        A copy of the trial's rows as a DataFrame, so callers may modify
        it without touching the cached session data.
    """
    eeg_path = f"{BASE_PATH}/MI/train/{subject_id}/{session}/EEGdata.csv"
    session_data = _read_session_eeg(eeg_path)
    start_idx = (trial_num - 1) * MI_SAMPLES
    end_idx = start_idx + MI_SAMPLES
    return session_data.iloc[start_idx:end_idx].copy()

# =============================================
# 3. Enhanced EEGNet Model Architecture
# =============================================
def EEGNet_v2(nb_classes, Chans=8, Samples=2240, 
              dropoutRate=0.5, kernLength=128, F1=16, 
              D=2, F2=32):
    """Build the EEGNet-style Keras classifier used in this notebook.

    Args:
        nb_classes: Number of units in the softmax output.
        Chans: Height of the input (first axis of each trial).
        Samples: Width of the input (second axis of each trial).
        dropoutRate: Rate for the dropout layers in the convolutional stages.
        kernLength: Width of the first temporal convolution kernel.
        F1: Number of filters in the first temporal convolution.
        D: Depth multiplier of the depthwise convolution.
        F2: Filters in the first separable convolution; the second uses 2*F2.

    Returns:
        An uncompiled ``Model`` mapping ``(Chans, Samples, 1)`` inputs to
        ``nb_classes`` softmax probabilities.
    """
    inputs = Input(shape=(Chans, Samples, 1))

    # Stage 1a: temporal convolution along the sample axis
    x = Conv2D(F1, (1, kernLength), padding='same', use_bias=False)(inputs)
    x = BatchNormalization()(x)
    x = Activation('elu')(x)

    # Stage 1b: depthwise convolution spanning the whole channel axis
    x = DepthwiseConv2D((Chans, 1), depth_multiplier=D,
                        depthwise_constraint=max_norm(1.),
                        use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('elu')(x)
    x = AveragePooling2D((1, 4))(x)
    x = SpatialDropout2D(dropoutRate)(x)

    # Stage 2: separable convolution, then 8x temporal pooling
    x = SeparableConv2D(F2, (1, 32), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('elu')(x)
    x = AveragePooling2D((1, 8))(x)
    x = Dropout(dropoutRate)(x)

    # Stage 3: second separable convolution with twice the filters
    x = SeparableConv2D(F2*2, (1, 16), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('elu')(x)
    x = AveragePooling2D((1, 4))(x)
    x = Dropout(dropoutRate)(x)

    # Classifier head: flatten -> 32-unit dense -> constrained logits -> softmax
    x = Flatten(name='flatten')(x)
    x = Dense(32, activation='elu', name='dense1')(x)
    x = Dropout(0.3)(x)
    logits = Dense(nb_classes, name='output',
                   kernel_constraint=max_norm(0.25))(x)
    probabilities = Activation('softmax', name='softmax')(logits)

    return Model(inputs=inputs, outputs=probabilities)

# =============================================
# 4. Data Preparation Pipeline with improved filtering
# =============================================
# Preload and preprocess trials
X = []
y = []

print("Loading and preprocessing trials...")
for _, row in tqdm(mi_df.iterrows(), total=len(mi_df)):
    try:
        data = load_mi_trial(row['subject_id'], row['trial_session'], row['trial'])
        # Keep the EEG columns and crop to CROPPED_SAMPLES rows: (time, channels)
        eeg = data[EEG_CHANNELS][:CROPPED_SAMPLES].to_numpy(dtype=np.float64)

        # The filter helpers run along axis 0 (time), so every channel is
        # denoised in a single call instead of one call per channel.
        eeg = notch_filter(eeg, f0=50.0, fs=SAMPLING_RATE)
        eeg = notch_filter(eeg, f0=60.0, fs=SAMPLING_RATE)  # Additional notch
        eeg = butter_bandpass_filter(eeg, lowcut=8.0, highcut=30.0,  # 8-30 Hz band
                                     fs=SAMPLING_RATE, order=6)

        # Robust per-channel scaling: centre on the median, divide by the IQR
        median = np.median(eeg, axis=0)
        q75, q25 = np.percentile(eeg, [75, 25], axis=0)
        iqr = q75 - q25
        # A flat channel has zero IQR; dividing by it would yield NaN/inf,
        # so such channels are only centred.
        iqr = np.where(iqr > 0, iqr, 1.0)
        # Clip outliers to ±5 IQRs
        eeg = np.clip((eeg - median) / iqr, -5, 5)

        X.append(eeg.T)  # Shape: (channels, time)
        y.append(row['label'])
    except Exception as e:
        # Best effort: report the failing trial and continue with the rest.
        print(f"Error processing trial {row['trial']}: {str(e)}")

X = np.array(X, dtype=np.float32)  # Convert to float32
y = np.array(y)

# Reshape for EEGNet: (n_trials, channels, time, 1)
X = X[..., np.newaxis]

# Convert labels to one-hot encoding
y = to_categorical(y, num_classes=2)

# Train-validation split
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print(f"Data shapes - X_train: {X_train.shape}, X_val: {X_val.shape}")
print(f"Data types - X_train: {X_train.dtype}, X_val: {X_val.dtype}")

# Clean memory
del X, y
gc.collect()

# =============================================
# 5. Fixed Data Augmentation (without time warping)
# =============================================
def augment_trial(trial, label):
    """Apply random augmentations to one EEG trial using TF ops.

    Three independent augmentations are each applied with a fixed
    probability: additive Gaussian noise (40%), channel dropout (30%) and
    global amplitude scaling (30%).

    Args:
        trial: Float tensor for one trial; the first axis is treated as the
            channel axis by the dropout mask.
        label: Label for the trial; returned unchanged.

    Returns:
        ``(trial, label)`` with ``trial`` carrying the same static shape it
        had on entry.
    """
    # Capture the static shape up front so it can be re-asserted after the
    # conditional branches, without hard-coding channel/sample counts here.
    static_shape = trial.shape

    # Gaussian noise - 40% probability
    if tf.random.uniform(()) > 0.6:
        noise = tf.random.normal(tf.shape(trial), mean=0.0, stddev=0.03, dtype=tf.float32)
        trial = trial + noise

    # Channel dropout - 30% probability to apply
    if tf.random.uniform(()) > 0.7:
        # Mask of shape (num_channels, 1, 1) broadcasts over the other axes;
        # each channel is kept when its random draw exceeds 0.15.
        channel_mask = tf.random.uniform((tf.shape(trial)[0], 1, 1), dtype=tf.float32) > 0.15
        trial = trial * tf.cast(channel_mask, tf.float32)

    # Random scaling - 30% probability
    if tf.random.uniform(()) > 0.7:
        scale = tf.random.normal([], mean=1.0, stddev=0.15, dtype=tf.float32)
        trial = trial * scale

    # Ensure consistent shape
    trial = tf.ensure_shape(trial, static_shape)

    return trial, label

# Training pipeline: shuffle -> per-trial augmentation -> batches of 32 -> prefetch
train_dataset = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .shuffle(buffer_size=2048)
    .map(augment_trial, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

# Validation pipeline: no shuffling or augmentation, same batch size
val_dataset = (
    tf.data.Dataset.from_tensor_slices((X_val, y_val))
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

# =============================================
# 6. Model Training with Enhanced Setup
# =============================================
# Instantiate the network for two-class MI classification
model = EEGNet_v2(
    nb_classes=2, Chans=8, Samples=CROPPED_SAMPLES,
    dropoutRate=0.5, kernLength=128, F1=16, D=2, F2=32,
)

# Per-class loss weights: errors on class 0 (Left) cost more than on class 1
class_weights = {0: 2.5, 1: 0.8}

optimizer = tf.keras.optimizers.Adam(learning_rate=0.0005)
model.compile(optimizer=optimizer,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

model.summary()

# Callbacks: CSV logging, best-model checkpointing, early stopping, LR decay
csv_logger = CSVLogger('training_log.csv', append=False)
checkpoint = ModelCheckpoint('best_model.keras', save_best_only=True,
                             monitor='val_loss', mode='min')
early_stopping = EarlyStopping(monitor='val_loss', patience=20,
                               restore_best_weights=True)
lr_decay = ReduceLROnPlateau(monitor='val_loss', factor=0.5,
                             patience=8, min_lr=1e-6, verbose=1)
callbacks = [csv_logger, checkpoint, early_stopping, lr_decay]

# Custom callback for F1 score
class F1Callback(tf.keras.callbacks.Callback):
    """Compute macro-averaged F1 on validation data after every epoch.

    The score is printed and, when Keras supplies a ``logs`` dict, stored
    under ``'val_f1'`` so callbacks that run later in the list can read it.

    Args:
        x_val: Optional validation inputs. Defaults to the module-level
            ``X_val`` when omitted.
        y_val: Optional one-hot validation targets. Defaults to the
            module-level ``y_val`` when omitted.
    """

    def __init__(self, x_val=None, y_val=None):
        super().__init__()
        self._x_val = x_val
        self._y_val = y_val

    def on_epoch_end(self, epoch, logs=None):
        # Fall back to the notebook globals so ``F1Callback()`` keeps working.
        features = X_val if self._x_val is None else self._x_val
        targets = y_val if self._y_val is None else self._y_val

        val_pred = self.model.predict(features, verbose=0, batch_size=32)
        val_pred = np.argmax(val_pred, axis=1)
        val_true = np.argmax(targets, axis=1)
        # Plain float keeps the value friendly to loggers and serialisers.
        f1 = float(f1_score(val_true, val_pred, average='macro'))
        # ``logs`` defaults to None in the signature; only write when present.
        if logs is not None:
            logs['val_f1'] = f1
        print(f" - val_f1: {f1:.4f}")

# Register the F1 callback FIRST. Callbacks run in list order and share one
# logs dict per epoch; CSVLogger records only the keys present when it runs,
# so appending F1Callback after it meant 'val_f1' never reached
# training_log.csv and the F1 plots stayed empty.
callbacks.insert(0, F1Callback())

# Train model
print("Starting training...")
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=100,
    callbacks=callbacks,
    verbose=1,
    class_weight=class_weights  # Apply class weights
)

# =============================================
# 7. Enhanced Evaluation and Visualization
# =============================================
# Reload the checkpoint written by ModelCheckpoint
model = tf.keras.models.load_model('best_model.keras')

# Class probabilities on the validation set, then hard labels via argmax
y_pred = model.predict(X_val, verbose=0, batch_size=32)
y_pred_classes = y_pred.argmax(axis=1)
y_true = y_val.argmax(axis=1)

# Per-class precision / recall / F1
class_names = ['Left', 'Right']
report_text = classification_report(
    y_true, y_pred_classes, target_names=class_names, digits=4
)
print("\nClassification Report:")
print(report_text)

# Enhanced Confusion Matrix
cm = confusion_matrix(y_true, y_pred_classes)
total = np.sum(cm)

# One label per cell with the count and its share of all samples. Passing
# these as `annot` (with fmt='' because they are already strings) avoids
# drawing seaborn's own number and a manual text overlay on top of each other.
cell_labels = np.array(
    [[f"{count}\n({count / total:.1%})" for count in cm_row] for cm_row in cm]
)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=cell_labels, fmt='', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names,
            cbar=True, annot_kws={"size": 16})

plt.xlabel('Predicted', fontsize=14)
plt.ylabel('True', fontsize=14)
plt.title('Confusion Matrix', fontsize=16)
plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.show()

# Read the per-epoch log written by CSVLogger and draw a 2x2 summary figure
history_df = pd.read_csv('training_log.csv')
plt.figure(figsize=(15, 12))

# Panels 1 and 2: training vs. validation curves for loss and accuracy
for panel, metric, pretty in ((1, 'loss', 'Loss'), (2, 'accuracy', 'Accuracy')):
    plt.subplot(2, 2, panel)
    plt.plot(history_df[metric], label=f'Training {pretty}')
    plt.plot(history_df[f'val_{metric}'], label=f'Validation {pretty}')
    plt.title(f'Training and Validation {pretty}', fontsize=14)
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel(pretty, fontsize=12)
    plt.legend()
    plt.grid(alpha=0.3)

# The F1 panels are only filled in when the log contains a 'val_f1' column
has_val_f1 = 'val_f1' in history_df.columns

# Panel 3: validation F1 alone
plt.subplot(2, 2, 3)
if has_val_f1:
    plt.plot(history_df['val_f1'], label='Validation F1', color='green')
    plt.title('Validation F1 Score', fontsize=14)
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('F1 Score', fontsize=12)
    plt.legend()
    plt.grid(alpha=0.3)

# Panel 4: validation accuracy and F1 on shared axes
plt.subplot(2, 2, 4)
if has_val_f1:
    plt.plot(history_df['val_accuracy'], label='Accuracy')
    plt.plot(history_df['val_f1'], label='F1 Score')
    plt.title('Validation Metrics Comparison', fontsize=14)
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Score', fontsize=12)
    plt.legend()
    plt.grid(alpha=0.3)

plt.tight_layout()
plt.savefig('training_history.png', dpi=300, bbox_inches='tight')
plt.show()

# Save final metrics
# Both numbers come from the same predictions (y_true / y_pred_classes).
# Reading accuracy from the last row of the training log instead would
# describe the final epoch, which need not be the model being evaluated.
final_val_f1 = f1_score(y_true, y_pred_classes, average='macro')
final_val_acc = float(np.mean(y_true == y_pred_classes))

print("\n================ Final Metrics ================")
print(f"Validation Accuracy: {final_val_acc:.4f}")
print(f"Validation F1 Score: {final_val_f1:.4f}")
print("==============================================")

# Side-by-side bar charts: how many validation samples belong to each class
# versus how many the model assigned to each class
plt.figure(figsize=(12, 6))
bar_colors = ['blue', 'orange']

true_counts = [np.sum(y_true == class_id) for class_id in (0, 1)]
plt.subplot(1, 2, 1)
plt.bar(class_names, true_counts, color=bar_colors)
plt.title('True Class Distribution')
plt.ylabel('Count')

predicted_counts = [np.sum(y_pred_classes == class_id) for class_id in (0, 1)]
plt.subplot(1, 2, 2)
plt.bar(class_names, predicted_counts, color=bar_colors)
plt.title('Predicted Class Distribution')

plt.savefig('class_distributions.png', dpi=300)
plt.show()

# Loss (left y-axis) and accuracy (right y-axis) over epochs on one figure
fig, ax1 = plt.subplots(figsize=(10, 6))

# Left axis: training and validation loss
loss_color = 'tab:red'
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss', color=loss_color)
ax1.plot(history_df['loss'], color=loss_color, label='Train Loss')
ax1.plot(history_df['val_loss'], color='tab:orange', label='Val Loss')
ax1.tick_params(axis='y', labelcolor=loss_color)
ax1.legend(loc='upper left')

# Right axis shares the x-axis: training and validation accuracy
ax2 = ax1.twinx()
accuracy_color = 'tab:blue'
ax2.set_ylabel('Accuracy', color=accuracy_color)
ax2.plot(history_df['accuracy'], color=accuracy_color, label='Train Acc')
ax2.plot(history_df['val_accuracy'], color='tab:green', label='Val Acc')
ax2.tick_params(axis='y', labelcolor=accuracy_color)
ax2.legend(loc='upper right')

plt.title('Training History')
plt.savefig('learning_curves.png', dpi=300)
plt.show()

Loading and preprocessing trials...


100%|██████████| 2400/2400 [03:55<00:00, 10.17it/s]


Data shapes - X_train: (1920, 8, 2240, 1), X_val: (480, 8, 2240, 1)
Data types - X_train: float32, X_val: float32


Starting training...
Epoch 1/100
60/60 ━━━━━━━━━━━━━━━━━━━━ 0s 540ms/step - accuracy: 0.4918 - loss: 1.0359 - val_f1: 0.3305
60/60 ━━━━━━━━━━━━━━━━━━━━ 43s 607ms/step - accuracy: 0.4918 - loss: 1.0350 - val_accuracy: 0.4938 - val_loss: 0.6970 - learning_rate: 5.0000e-04 - val_f1: 0.3305
Epoch 2/100
60/60 ━━━━━━━━━━━━━━━━━━━━ 0s 539ms/step - accuracy: 0.4841 - loss: 0.9390 - val_f1: 0.3305
60/60 ━━━━━━━━━━━━━━━━━━━━ 38s 628ms/step - accuracy: 0.4843 - loss: 0.9389 - val_accuracy: 0.4938 - val_loss: 0.7024 - learning_rate: 5.0000e-04 - val_f1: 0.3305
Epoch 3/100
60/60 ━━━━━━━━━━━━━━━━━━━━ 0s 544ms/step - accuracy: 0.4843 - loss: 0.9344 - val_f1: 0.3305
[1m60/60[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m36s[0m 596ms/step - accuracy: 0.4844 - loss: 0.9345 - val_accuracy: 0.4938 - val_loss: 0.7096 - learning_rate: 5.