In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Load preprocessed training metadata
train_df = pd.read_csv('/kaggle/input/preprocessed-with-25-artifact-threshold/quality_controlled_preprocessed/train.csv')

# Filter for MI tasks only
mi_train_df = train_df[train_df['task'] == 'MI'].copy()

# Remove rejected trials (where preprocessing failed)
valid_mi_train_df = mi_train_df[mi_train_df['processed_path'].notnull()]

In [2]:
# Count class distribution
class_counts = valid_mi_train_df['label'].value_counts()

# Calculate percentages
class_percentages = class_counts / len(valid_mi_train_df) * 100

# Display distribution
print("Motor Imagery Class Distribution:")
print(f"Total Trials: {len(valid_mi_train_df)}")
print(class_counts)
print("\nClass Balance:")
print(class_percentages)

Motor Imagery Class Distribution:
Total Trials: 2355
label
Right    1190
Left     1165
Name: count, dtype: int64

Class Balance:
label
Right    50.530786
Left     49.469214
Name: count, dtype: float64


In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (LSTM, Dense, Input, Dropout, 
                                     Bidirectional, LayerNormalization,
                                     Attention, Permute, Multiply)
from tensorflow.keras.callbacks import (ModelCheckpoint, CSVLogger, 
                                       EarlyStopping, ReduceLROnPlateau)
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import os
from scipy import signal

# Configuration
BATCH_SIZE = 32
EPOCHS = 100  # Increased to allow early stopping
CLASS_WEIGHTS = {0: 1.01, 1: 0.99}  # Slight adjustment for 50.5/49.5 split
BASE_PREPROCESSED_PATH = '/kaggle/input/preprocessed-with-25-artifact-threshold'

# Define channel selections
ALL_CHANNELS = ['FZ', 'C3', 'CZ', 'C4', 'PZ', 'PO7', 'OZ', 'PO8']
MI_CHANNELS = ['C3', 'CZ', 'C4', 'FZ', 'PZ']  # Optimal MI channels
CHANNEL_INDICES = [ALL_CHANNELS.index(ch) for ch in MI_CHANNELS]

# Load and filter data with proper NaN handling
train_df = pd.read_csv(os.path.join(BASE_PREPROCESSED_PATH, 'quality_controlled_preprocessed/train.csv'))

# Clean path strings and convert to absolute paths
train_df['processed_path'] = train_df['processed_path'].astype(str).str.replace('^\./', '', regex=True)
train_df['abs_path'] = train_df['processed_path'].apply(
    lambda x: os.path.join(BASE_PREPROCESSED_PATH, x) if x != 'nan' else np.nan
)

# Filter out invalid entries
train_df = train_df[train_df['abs_path'].notna()]
train_df['file_exists'] = train_df['abs_path'].apply(lambda x: os.path.exists(x) if isinstance(x, str) else False)
print(f"Training files missing: {len(train_df) - train_df['file_exists'].sum()}")
mi_train_df = train_df[(train_df['task'] == 'MI') & (train_df['file_exists'])]

# Data generator with fixed-length augmentation and channel selection
class EEGDataGenerator(tf.keras.utils.Sequence):
    def __init__(self, df, batch_size, seq_length=2250, augment=True):
        self.df = df
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.augment = augment
        self.indices = np.arange(len(df))
        
    def __len__(self):
        return int(np.ceil(len(self.df) / self.batch_size))
    
    def __getitem__(self, idx):
        batch_indices = self.indices[idx*self.batch_size:(idx+1)*self.batch_size]
        X_batch = np.zeros((len(batch_indices), self.seq_length, len(MI_CHANNELS)))
        y_batch = []
        
        for j, i in enumerate(batch_indices):
            row = self.df.iloc[i]
            data = np.load(row['abs_path'])['data']
            # Select only MI-relevant channels
            data = data[:, CHANNEL_INDICES]
            label = 0 if row['label'] == 'Left' else 1
            
            # Augmentation techniques
            if self.augment:
                # Gaussian noise
                if np.random.rand() > 0.7:
                    data += np.random.normal(0, 0.5, data.shape)
                
                # Time warping with fixed length
                if np.random.rand() > 0.7:
                    warp_factor = np.random.uniform(0.9, 1.1)
                    new_length = int(data.shape[0] * warp_factor)
                    
                    # Resize to new length
                    warped = tf.image.resize(data[..., np.newaxis], 
                                           [new_length, data.shape[1]]).numpy()[:, :, 0]
                    
                    # Resize back to original length
                    data = tf.image.resize(warped[..., np.newaxis], 
                                         [self.seq_length, warped.shape[1]]).numpy()[:, :, 0]
            
            # Ensure correct length
            if len(data) != self.seq_length:
                data = tf.image.resize(data[..., np.newaxis], 
                                      [self.seq_length, data.shape[1]]).numpy()[:, :, 0]
            
            X_batch[j] = data
            y_batch.append(label)
        
        y_batch = tf.keras.utils.to_categorical(y_batch, num_classes=2)
        return X_batch, y_batch

# Create datasets
train_gen = EEGDataGenerator(mi_train_df, BATCH_SIZE, seq_length=2250, augment=True)

# Prepare validation data
val_df = pd.read_csv(os.path.join(BASE_PREPROCESSED_PATH, 'quality_controlled_preprocessed/validation.csv'))
val_df['processed_path'] = val_df['processed_path'].astype(str).str.replace('^\./', '', regex=True)
val_df['abs_path'] = val_df['processed_path'].apply(
    lambda x: os.path.join(BASE_PREPROCESSED_PATH, x) if x != 'nan' else np.nan
)
val_df = val_df[val_df['abs_path'].notna()]
val_df['file_exists'] = val_df['abs_path'].apply(lambda x: os.path.exists(x))
print(f"Validation files missing: {len(val_df) - val_df['file_exists'].sum()}")
mi_val_df = val_df[(val_df['task'] == 'MI') & (val_df['file_exists'])]
val_gen = EEGDataGenerator(mi_val_df, BATCH_SIZE, seq_length=2250, augment=False)

# Attention LSTM Model with reduced input channels
def create_attention_lstm(input_shape):
    inputs = Input(shape=input_shape)
    
    # Encoder
    x = Bidirectional(LSTM(128, return_sequences=True))(inputs)
    x = LayerNormalization()(x)
    x = Dropout(0.3)(x)
    
    # Attention mechanism
    attention = Dense(1, activation='tanh')(x)
    attention = Permute((2, 1))(attention)
    attention = tf.keras.layers.Softmax()(attention)
    attention = Permute((2, 1))(attention)
    context = Multiply()([x, attention])
    
    # Decoder
    context = LSTM(64)(context)
    context = Dropout(0.3)(context)
    outputs = Dense(2, activation='softmax')(context)
    
    model = Model(inputs, outputs)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-3),
        loss='categorical_crossentropy',
        metrics=['accuracy', 
                 tf.keras.metrics.Precision(name='precision'),
                 tf.keras.metrics.Recall(name='recall'),
                 tf.keras.metrics.AUC(name='auc')]
    )
    return model

# Create model
model_lstm = create_attention_lstm((2250, len(MI_CHANNELS)))

# Callbacks
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=15,
    restore_best_weights=True,
    verbose=1
)

reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=5,
    min_lr=1e-6,
    verbose=1
)

checkpoint = ModelCheckpoint('lstm_best_model.h5', save_best_only=True, verbose=1)
csv_logger = CSVLogger('lstm_training_log.csv')

history = model_lstm.fit(
    train_gen,
    epochs=EPOCHS,
    validation_data=val_gen,
    class_weight=CLASS_WEIGHTS,
    callbacks=[checkpoint, csv_logger, early_stopping, reduce_lr],
    verbose=1
)

# Evaluation
y_true, y_pred = [], []
for i in range(len(val_gen)):
    X, y = val_gen[i]
    preds = model_lstm.predict(X, verbose=0)
    y_true.extend(np.argmax(y, axis=1))
    y_pred.extend(np.argmax(preds, axis=1))

# Classification report
print("LSTM Model Classification Report:")
print(classification_report(y_true, y_pred, target_names=['Left', 'Right']))

# Confusion matrix
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=['Left', 'Right'], 
            yticklabels=['Left', 'Right'])
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('LSTM Confusion Matrix')
plt.savefig('lstm_confusion_matrix.png', dpi=300)
plt.close()

# Calculate F1
precision = cm[1][1] / (cm[1][1] + cm[0][1]) if (cm[1][1] + cm[0][1]) > 0 else 0
recall = cm[1][1] / (cm[1][1] + cm[1][0]) if (cm[1][1] + cm[1][0]) > 0 else 0
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
print(f"\nValidation F1 Score: {f1:.4f}")

Training files missing: 0
Validation files missing: 0


  self._warn_if_super_not_called()


Epoch 1/100
[1m74/74[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 286ms/step - accuracy: 0.4825 - auc: 0.4825 - loss: 0.6941 - precision: 0.4825 - recall: 0.4825
Epoch 1: val_loss improved from inf to 0.69355, saving model to lstm_best_model.h5
[1m74/74[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 308ms/step - accuracy: 0.4825 - auc: 0.4826 - loss: 0.6941 - precision: 0.4825 - recall: 0.4825 - val_accuracy: 0.4490 - val_auc: 0.4798 - val_loss: 0.6935 - val_precision: 0.4490 - val_recall: 0.4490 - learning_rate: 0.0010
Epoch 2/100
[1m74/74[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 286ms/step - accuracy: 0.5173 - auc: 0.5047 - loss: 0.6928 - precision: 0.5173 - recall: 0.5173
Epoch 2: val_loss improved from 0.69355 to 0.69251, saving model to lstm_best_model.h5
[1m74/74[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m22s[0m 292ms/step - accuracy: 0.5171 - auc: 0.5046 - loss: 0.6928 - precision: 0.5171 - recall: 0.5171 - val_accuracy: 0.5510 - val_auc: 0

In [None]:
# STFT Transformation with channel selection
def compute_stft(eeg_data, fs=250):
    stft_features = []
    # Compute STFT only for MI-relevant channels
    for channel_idx in CHANNEL_INDICES:
        f, t, Zxx = signal.stft(eeg_data[:, channel_idx], 
                                fs=fs, 
                                nperseg=128, 
                                noverlap=96)
        # Select frequency range (0-40 Hz)
        freq_mask = (f >= 0) & (f <= 40)
        mag = np.abs(Zxx[freq_mask])
        stft_features.append(mag.T)  # Time x Freq
    
    return np.stack(stft_features, axis=-1)  # Time x Freq x Channels

# Create STFT dataset with path fixes
def create_stft_dataset(df):
    X, y = [], []
    for _, row in df.iterrows():
        # Use absolute path
        abs_path = os.path.join(BASE_PREPROCESSED_PATH, row['processed_path'].lstrip('./'))
        if not os.path.exists(abs_path):
            print(f"File not found: {abs_path}")
            continue
            
        data = np.load(abs_path)['data']
        stft_data = compute_stft(data)
        label = 0 if row['label'] == 'Left' else 1
        X.append(stft_data)
        y.append(label)
    return np.array(X), tf.keras.utils.to_categorical(y, 2)

# Create STFT datasets
X_train, y_train = create_stft_dataset(mi_train_df)
X_val, y_val = create_stft_dataset(mi_val_df)

# CNN Model Architecture with reduced input channels
def create_cnn_model(input_shape):
    model = tf.keras.Sequential([
        # Conv Block 1
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu', 
                               input_shape=input_shape),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D((2, 2)),
        
        # Conv Block 2
        tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D((2, 2)),
        
        # Conv Block 3
        tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D((2, 2)),
        
        # Conv Block 4
        tf.keras.layers.Conv2D(256, (3, 3), activation='relu'),
        tf.keras.layers.BatchNormalization(),
        
        # Conv Block 5
        tf.keras.layers.Conv2D(512, (3, 3), activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.GlobalAveragePooling2D(),
        
        # FC Layers
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(2, activation='softmax')
    ])
    
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-4),
        loss='categorical_crossentropy',
        metrics=['accuracy', 
                 tf.keras.metrics.Precision(name='precision'),
                 tf.keras.metrics.Recall(name='recall'),
                 tf.keras.metrics.AUC(name='auc')]
    )
    return model

# Data augmentation
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rotation_range=5,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.1,
    fill_mode='nearest'
)

# Create model
model_cnn = create_cnn_model(X_train.shape[1:])

# Callbacks
early_stopping_cnn = EarlyStopping(
    monitor='val_loss',
    patience=15,
    restore_best_weights=True,
    verbose=1
)

reduce_lr_cnn = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=5,
    min_lr=1e-6,
    verbose=1
)

checkpoint_cnn = ModelCheckpoint('cnn_best_model.h5', save_best_only=True, verbose=1)
csv_logger_cnn = CSVLogger('cnn_training_log.csv')

history_cnn = model_cnn.fit(
    datagen.flow(X_train, y_train, batch_size=BATCH_SIZE),
    epochs=EPOCHS,
    validation_data=(X_val, y_val),
    class_weight=CLASS_WEIGHTS,
    callbacks=[checkpoint_cnn, csv_logger_cnn, early_stopping_cnn, reduce_lr_cnn],
    verbose=1
)

# Evaluation
y_pred_cnn = model_cnn.predict(X_val, verbose=0)
y_true_cnn = np.argmax(y_val, axis=1)
y_pred_cnn = np.argmax(y_pred_cnn, axis=1)

# Classification report
print("\nCNN Model Classification Report:")
print(classification_report(y_true_cnn, y_pred_cnn, target_names=['Left', 'Right']))

# Confusion matrix
cm_cnn = confusion_matrix(y_true_cnn, y_pred_cnn)
plt.figure(figsize=(8, 6))
sns.heatmap(cm_cnn, annot=True, fmt='d', cmap='Blues', 
            xticklabels=['Left', 'Right'], 
            yticklabels=['Left', 'Right'])
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('CNN Confusion Matrix')
plt.savefig('cnn_confusion_matrix.png', dpi=300)
plt.close()

# Calculate F1
precision_cnn = cm_cnn[1][1] / (cm_cnn[1][1] + cm_cnn[0][1]) if (cm_cnn[1][1] + cm_cnn[0][1]) > 0 else 0
recall_cnn = cm_cnn[1][1] / (cm_cnn[1][1] + cm_cnn[1][0]) if (cm_cnn[1][1] + cm_cnn[1][0]) > 0 else 0
f1_cnn = 2 * (precision_cnn * recall_cnn) / (precision_cnn + recall_cnn) if (precision_cnn + recall_cnn) > 0 else 0
print(f"\nValidation F1 Score: {f1_cnn:.4f}")