In [None]:
import mne
import numpy as np
from sklearn.preprocessing import StandardScaler, label_binarize
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, Add, Flatten, Dense, Dropout, MaxPooling2D
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
import seaborn as sns
import matplotlib.pyplot as plt

# Data Loading and Preprocessing
base_path = r'C:\\Users\\karan\\Downloads\\EEG Data\\Data'
subjects = [f'A0{i}' for i in range(1, 10) if i != 4]  # Exclude subject A04
event_ids = [7, 8, 9, 10]  # Define event IDs for motor imagery tasks

all_features = []
all_labels = []

for subject in subjects:
    file_path = f'{base_path}\\{subject}T.gdf'
    print(f"Processing {subject}...")
    raw = mne.io.read_raw_gdf(file_path, preload=True)
    raw.drop_channels(['EOG-left', 'EOG-central', 'EOG-right'])
    raw.set_eeg_reference()
    events, _ = mne.events_from_annotations(raw)
    epochs = mne.Epochs(raw, events, event_id=event_ids, on_missing='warn', preload=True)
    features = epochs.get_data()  # Shape: (n_epochs, n_channels, n_timepoints)
    labels = epochs.events[:, -1]  # Extract event IDs as labels
    all_features.append(features)
    all_labels.append(labels)

# Concatenate data across all subjects
features = np.concatenate(all_features, axis=0)
labels = np.concatenate(all_labels, axis=0)

# Normalize features (Z-score across all channels and timepoints)
scaler = StandardScaler()
features = scaler.fit_transform(features.reshape(features.shape[0], -1)).reshape(features.shape)

# Reshape features for CNN input (add channel dimension)
features = features[..., np.newaxis]  # Shape: (n_samples, n_channels, n_timepoints, 1)

# One-hot encode labels for multi-class classification
labels = to_categorical(labels - event_ids[0])  # Adjust labels to start from 0

# Split data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.2, random_state=42)

# Print data shapes for verification
print(f"X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")
print(f"X_test shape: {X_test.shape}, y_test shape: {y_test.shape}")

# Build ResNet-inspired Model
def build_resnet(input_shape, num_classes):
    def residual_block(x, filters):
        shortcut = x
        x = Conv2D(filters, (1, 3), padding='same', activation='relu')(x)
        x = BatchNormalization()(x)
        x = Conv2D(filters, (1, 3), padding='same', activation=None)(x)
        x = BatchNormalization()(x)
        x = Add()([x, shortcut])  # Add shortcut
        x = Activation('relu')(x)
        return x

    inputs = Input(shape=input_shape)
    
    # Initial Conv Layer
    x = Conv2D(16, (1, 7), padding='same', activation='relu')(inputs)
    x = BatchNormalization()(x)
    x = MaxPooling2D((1, 3))(x)

    # Residual Blocks
    x = residual_block(x, 16)
    x = residual_block(x, 16)
    
    # Fully Connected Layers
    x = Flatten()(x)
    x = Dense(64, activation='relu')(x)
    x = Dropout(0.5)(x)
    outputs = Dense(num_classes, activation='softmax')(x)

    model = Model(inputs, outputs)
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model

# Define input shape and number of classes
input_shape = (features.shape[1], features.shape[2], 1)  # (n_channels, n_timepoints, 1)
num_classes = labels.shape[1]

# Build the model
resnet = build_resnet(input_shape, num_classes)

# Print the model summary
resnet.summary()

# Define callbacks
callbacks = [
    EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True, verbose=1),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, verbose=1)
]

# Train the model
history = resnet.fit(
    X_train, y_train,
    validation_data=(X_test, y_test),
    epochs=50,
    batch_size=32,
    callbacks=callbacks,
    verbose=1
)

# Evaluate the model on the test set
test_loss, test_accuracy = resnet.evaluate(X_test, y_test, verbose=0)
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")

# Predictions and Evaluation
y_pred_probs = resnet.predict(X_test)  # Predicted probabilities
y_pred = np.argmax(y_pred_probs, axis=1)  # Predicted classes
y_test_classes = np.argmax(y_test, axis=1)  # True classes

# Classification Report
print("\nClassification Report:")
print(classification_report(y_test_classes, y_pred))

# Confusion Matrix
conf_matrix = confusion_matrix(y_test_classes, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=range(num_classes), yticklabels=range(num_classes))
plt.title('ResNet Confusion Matrix')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.show()

# Plot Training History
# Accuracy Plot
plt.figure()
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('ResNet Model Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

# Loss Plot
plt.figure()
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('ResNet Model Validaitonn Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

# Compute ROC curve and AUC for each class
fpr = {}
tpr = {}
roc_auc = {}

# Binarize the true labels for multi-class ROC-AUC computation
y_test_binarized = label_binarize(y_test_classes, classes=range(num_classes))

# Compute ROC and AUC for each class
for i in range(num_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test_binarized[:, i], y_pred_probs[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Print ROC-AUC for each class
print("\nROC-AUC for each class:")
for i in range(num_classes):
    print(f"Class {i}: AUC = {roc_auc[i]:.4f}")

# Compute the average ROC-AUC across all classes
average_roc_auc = np.mean(list(roc_auc.values()))
print(f"\nAverage ROC-AUC: {average_roc_auc:.4f}")

# Plot ROC curves for each class
plt.figure(figsize=(10, 8))
for i in range(num_classes):
    plt.plot(fpr[i], tpr[i], label=f"Class {i} (AUC = {roc_auc[i]:.2f})")
plt.plot([0, 1], [0, 1], 'k--', lw=2)
plt.title('ResNet ROC Curve for Multi-Class Classification')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend(loc='lower right')
plt.grid()
plt.show()


In [8]:
from sklearn.model_selection import StratifiedKFold
from tensorflow.keras.models import clone_model
from sklearn.metrics import roc_curve, auc
import numpy as np

# Prepare data for StratifiedKFold
labels_integers = np.argmax(labels, axis=1)  # Convert one-hot labels to class indices
num_folds = 5  # Number of splits for cross-validation
skf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=42)

# Storage for metrics
accuracy_per_fold = []
loss_per_fold = []
roc_auc_per_fold = []

# Cross-validation loop
fold_no = 1
for train_idx, test_idx in skf.split(features, labels_integers):
    print(f"\nTraining Fold {fold_no}/{num_folds}...")

    # Split the data
    X_train, X_test = features[train_idx], features[test_idx]
    y_train, y_test = labels[train_idx], labels[test_idx]

    # Build and compile the model for each fold
    model = build_resnet(input_shape, num_classes)

    # Train the model
    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=50,
        batch_size=32,
        callbacks=[
            EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True, verbose=1),
            ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, verbose=1)
        ],
        verbose=0
    )

    # Evaluate on the test fold
    val_loss, val_accuracy = model.evaluate(X_test, y_test, verbose=0)
    print(f"Fold {fold_no} - Validation Accuracy: {val_accuracy * 100:.2f}%")
    print(f"Fold {fold_no} - Validation Loss: {val_loss:.4f}")

    # Store accuracy and loss
    accuracy_per_fold.append(val_accuracy)
    loss_per_fold.append(val_loss)

    # Compute ROC-AUC for each class
    y_pred_probs = model.predict(X_test)
    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(num_classes):
        fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_pred_probs[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Calculate average ROC-AUC for the fold
    avg_roc_auc = np.mean(list(roc_auc.values()))
    roc_auc_per_fold.append(avg_roc_auc)
    print(f"Fold {fold_no} - Average ROC-AUC: {avg_roc_auc:.4f}")

    # Increment fold number
    fold_no += 1

# Print final cross-validation results
print("\nCross-Validation Results:")
print(f"Average Accuracy: {np.mean(accuracy_per_fold) * 100:.2f}%")
print(f"Standard Deviation of Accuracy: {np.std(accuracy_per_fold) * 100:.2f}%")
print(f"Average Loss: {np.mean(loss_per_fold):.4f}")
print(f"Standard Deviation of Loss: {np.std(loss_per_fold):.4f}")
print(f"Average ROC-AUC: {np.mean(roc_auc_per_fold):.4f}")
print(f"Standard Deviation of ROC-AUC: {np.std(roc_auc_per_fold):.4f}")



Training Fold 1/5...

Epoch 6: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.

Epoch 11: ReduceLROnPlateau reducing learning rate to 0.0002500000118743628.
Epoch 13: early stopping
Restoring model weights from the end of the best epoch: 3.
Fold 1 - Validation Accuracy: 25.16%
Fold 1 - Validation Loss: 1.3863
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 59ms/step
Fold 1 - Average ROC-AUC: 0.5000

Training Fold 2/5...

Epoch 8: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.

Epoch 13: ReduceLROnPlateau reducing learning rate to 0.0002500000118743628.
Epoch 13: early stopping
Restoring model weights from the end of the best epoch: 3.
Fold 2 - Validation Accuracy: 26.90%
Fold 2 - Validation Loss: 1.3856
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 214ms/step
Fold 2 - Average ROC-AUC: 0.5160

Training Fold 3/5...

Epoch 10: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.

Epoch 15: ReduceLROnPl