In [None]:
import mne
import numpy as np
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize

# Data Loading and Preprocessing
base_path = r'C:\\Users\\karan\\Downloads\\EEG Data\\Data'
# Exclude A04 from the list of subjects
subjects = [f'A0{i}' for i in range(1, 10) if i != 4]
event_ids = [7, 8, 9, 10]  # Define event IDs for motor imagery tasks

all_features = []
all_labels = []

for subject in subjects:
    file_path = f'{base_path}\\{subject}T.gdf'
    print(f"Processing {subject}...")
    raw = mne.io.read_raw_gdf(file_path, preload=True)
    raw.drop_channels(['EOG-left', 'EOG-central', 'EOG-right'])
    raw.set_eeg_reference()
    events, _ = mne.events_from_annotations(raw)
    epochs = mne.Epochs(raw, events, event_id=event_ids, on_missing='warn', preload=True)
    features = epochs.get_data()  # Shape: (n_epochs, n_channels, n_timepoints)
    labels = epochs.events[:, -1]  # Extract event IDs as labels
    all_features.append(features)
    all_labels.append(labels)

# Concatenate data across all subjects
features = np.concatenate(all_features, axis=0)
labels = np.concatenate(all_labels, axis=0)

# Normalize features (Z-score across all channels and timepoints)
scaler = StandardScaler()
features = scaler.fit_transform(features.reshape(features.shape[0], -1)).reshape(features.shape)

# Reshape features for CNN input (add channel dimension)
features = features[..., np.newaxis]  # Shape: (n_samples, n_channels, n_timepoints, 1)

# One-hot encode labels for multi-class classification
labels = to_categorical(labels - event_ids[0])  # Adjust labels to start from 0

# Split data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.2, random_state=42)

# Print data shapes for verification
print(f"X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")
print(f"X_test shape: {X_test.shape}, y_test shape: {y_test.shape}")

# Build ShallowNet Model
def build_shallownet(input_shape, num_classes):
    model = Sequential()
    
    # 1st Convolutional Block
    model.add(Conv2D(16, (1, 10), activation='relu', input_shape=input_shape))
    model.add(BatchNormalization())
    model.add(MaxPooling2D((1, 3)))
    
    # 2nd Convolutional Block
    model.add(Conv2D(32, (1, 10), activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D((1, 3)))
    
    # Fully Connected Layers
    model.add(Flatten())
    model.add(Dense(64, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(num_classes, activation='softmax'))
    
    # Compile the model
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    
    return model

# Define input shape and number of classes
input_shape = (features.shape[1], features.shape[2], 1)  # (n_channels, n_timepoints, 1)
num_classes = labels.shape[1]

# Build the model
shallownet = build_shallownet(input_shape, num_classes)

# Print the model summary
shallownet.summary()

# Define callbacks
callbacks = [
    EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True, verbose=1),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, verbose=1)
]

# Train the model
history = shallownet.fit(
    X_train, y_train,
    validation_data=(X_test, y_test),  # Use the test set as validation
    epochs=50,
    batch_size=32,
    callbacks=callbacks,
    verbose=1
)

# Evaluate the model on the test set
test_loss, test_accuracy = shallownet.evaluate(X_test, y_test, verbose=0)
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")

# Predictions and Evaluation
y_pred_probs = shallownet.predict(X_test)  # Predicted probabilities
y_pred = np.argmax(y_pred_probs, axis=1)  # Predicted classes
y_test_classes = np.argmax(y_test, axis=1)  # True classes

# Classification Report
print("\nClassification Report:")
print(classification_report(y_test_classes, y_pred))

# Confusion Matrix
conf_matrix = confusion_matrix(y_test_classes, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=range(num_classes), yticklabels=range(num_classes))
plt.title('ShallowNet Confusion Matrix')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.show()

# Plot Training History
# Accuracy Plot
plt.figure()
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('ShallowNet Model Validaiton Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

# Loss Plot
plt.figure()
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('ShallowNet Model Validaiton Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()



# Compute ROC curve and AUC for each class
fpr = {}
tpr = {}
roc_auc = {}

# Binarize the true labels for multi-class ROC-AUC computation
y_test_binarized = label_binarize(y_test_classes, classes=range(num_classes))

# Compute ROC and AUC for each class
for i in range(num_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test_binarized[:, i], y_pred_probs[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Print ROC-AUC for each class
print("\nROC-AUC for each class:")
for i in range(num_classes):
    print(f"Class {i}: AUC = {roc_auc[i]:.4f}")

# Compute the average ROC-AUC across all classes
average_roc_auc = np.mean(list(roc_auc.values()))
print(f"\nAverage ROC-AUC: {average_roc_auc:.4f}")

# Plot ROC curves for each class
plt.figure(figsize=(10, 8))
for i in range(num_classes):
    plt.plot(fpr[i], tpr[i], label=f"Class {i} (AUC = {roc_auc[i]:.2f})")
plt.plot([0, 1], [0, 1], 'k--', lw=2)
plt.title('ShalloNet ROC Curve for Multi-Class Classification')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend(loc='lower right')
plt.grid()
plt.show()



In [3]:
from sklearn.model_selection import StratifiedKFold
from tensorflow.keras.models import clone_model
from sklearn.metrics import roc_curve, auc
import numpy as np

# Convert one-hot labels to single-label integers for StratifiedKFold
labels_integers = np.argmax(labels, axis=1)

# Parameters for Stratified Cross-Validation
num_folds = 5
skf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=42)
fold_no = 1

# Store results
accuracy_per_fold = []
loss_per_fold = []
roc_auc_per_fold = []

# StratifiedKFold Cross-Validation
for train_idx, val_idx in skf.split(features, labels_integers):
    print(f"\nTraining Fold {fold_no}/{num_folds}")
    
    # Split data into training and validation sets
    X_train, X_val = features[train_idx], features[val_idx]
    y_train, y_val = labels[train_idx], labels[val_idx]
    
    # Build a new model instance for each fold
    model = build_shallownet(input_shape, num_classes)
    
    # Train the model
    history = model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=50,
        batch_size=32,
        callbacks=[
            EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True, verbose=1),
            ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, verbose=1)
        ],
        verbose=1
    )
    
    # Evaluate the model on validation data
    val_loss, val_accuracy = model.evaluate(X_val, y_val, verbose=0)
    print(f"Fold {fold_no} - Validation Accuracy: {val_accuracy * 100:.2f}%")
    
    # Store validation accuracy and loss
    accuracy_per_fold.append(val_accuracy)
    loss_per_fold.append(val_loss)
    
    # Compute ROC-AUC for each class
    y_pred_probs = model.predict(X_val)
    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(num_classes):
        fpr[i], tpr[i], _ = roc_curve(y_val[:, i], y_pred_probs[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
    
    # Calculate average ROC-AUC for the fold
    avg_roc_auc = np.mean(list(roc_auc.values()))
    roc_auc_per_fold.append(avg_roc_auc)
    print(f"Fold {fold_no} - Average ROC-AUC: {avg_roc_auc:.4f}")
    
    # Increment fold number
    fold_no += 1

# Final Cross-Validation Results
print("\nCross-Validation Results:")
print(f"Average Accuracy: {np.mean(accuracy_per_fold) * 100:.2f}%")
print(f"Standard Deviation of Accuracy: {np.std(accuracy_per_fold) * 100:.2f}%")
print(f"Average Loss: {np.mean(loss_per_fold):.4f}")
print(f"Standard Deviation of Loss: {np.std(loss_per_fold):.4f}")
print(f"Average ROC-AUC: {np.mean(roc_auc_per_fold):.4f}")
print(f"Standard Deviation of ROC-AUC: {np.std(roc_auc_per_fold):.4f}")



Training Fold 1/5


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 32ms/step - accuracy: 0.2645 - loss: 2.1739 - val_accuracy: 0.3145 - val_loss: 1.6377 - learning_rate: 0.0010
Epoch 2/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 23ms/step - accuracy: 0.2607 - loss: 1.3828 - val_accuracy: 0.3080 - val_loss: 2.6596 - learning_rate: 0.0010
Epoch 3/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 24ms/step - accuracy: 0.2807 - loss: 1.3755 - val_accuracy: 0.2972 - val_loss: 3.2734 - learning_rate: 0.0010
Epoch 4/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 24ms/step - accuracy: 0.2828 - loss: 1.3668 - val_accuracy: 0.3232 - val_loss: 2.9713 - learning_rate: 0.0010
Epoch 5/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 23ms/step - accuracy: 0.2963 - loss: 1.3696 - val_accuracy: 0.3015 - val_loss: 2.2976 - learning_rate: 0.0010
Epoch 6/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 107ms/step - accuracy: 0.2615 - loss: 2.3941 - val_accuracy: 0.3102 - val_loss: 2.0347 - learning_rate: 0.0010
Epoch 2/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 87ms/step - accuracy: 0.2816 - loss: 1.3621 - val_accuracy: 0.3189 - val_loss: 3.2145 - learning_rate: 0.0010
Epoch 3/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 88ms/step - accuracy: 0.2980 - loss: 1.3375 - val_accuracy: 0.3275 - val_loss: 3.2992 - learning_rate: 0.0010
Epoch 4/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 84ms/step - accuracy: 0.3393 - loss: 1.3100 - val_accuracy: 0.3688 - val_loss: 2.6723 - learning_rate: 0.0010
Epoch 5/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 78ms/step - accuracy: 0.3618 - loss: 1.2873 - val_accuracy: 0.3774 - val_loss: 1.8810 - learning_rate: 0.0010
Epoch 6/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 79ms/step 

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 107ms/step - accuracy: 0.2618 - loss: 2.2764 - val_accuracy: 0.2777 - val_loss: 2.0129 - learning_rate: 0.0010
Epoch 2/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 78ms/step - accuracy: 0.2508 - loss: 1.3781 - val_accuracy: 0.2603 - val_loss: 3.9976 - learning_rate: 0.0010
Epoch 3/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 80ms/step - accuracy: 0.2806 - loss: 1.3779 - val_accuracy: 0.2473 - val_loss: 4.3760 - learning_rate: 0.0010
Epoch 4/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 76ms/step - accuracy: 0.2529 - loss: 1.3956 - val_accuracy: 0.2625 - val_loss: 3.6106 - learning_rate: 0.0010
Epoch 5/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 78ms/step - accuracy: 0.2821 - loss: 1.3576 - val_accuracy: 0.2625 - val_loss: 2.4709 - learning_rate: 0.0010
Epoch 6/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 69ms/step - accuracy: 0.2694 - loss: 2.5501 - val_accuracy: 0.2863 - val_loss: 1.8540 - learning_rate: 0.0010
Epoch 2/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 47ms/step - accuracy: 0.2774 - loss: 1.3755 - val_accuracy: 0.2820 - val_loss: 3.8914 - learning_rate: 0.0010
Epoch 3/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 44ms/step - accuracy: 0.2855 - loss: 1.3647 - val_accuracy: 0.3124 - val_loss: 3.5129 - learning_rate: 0.0010
Epoch 4/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 41ms/step - accuracy: 0.3067 - loss: 1.3551 - val_accuracy: 0.2907 - val_loss: 3.1164 - learning_rate: 0.0010
Epoch 5/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 44ms/step - accuracy: 0.3226 - loss: 1.3115 - val_accuracy: 0.3449 - val_loss: 1.9592 - learning_rate: 0.0010
Epoch 6/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 139ms/step - accuracy: 0.2767 - loss: 2.1037 - val_accuracy: 0.2870 - val_loss: 1.5334 - learning_rate: 0.0010
Epoch 2/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 87ms/step - accuracy: 0.2593 - loss: 1.3812 - val_accuracy: 0.2848 - val_loss: 2.3211 - learning_rate: 0.0010
Epoch 3/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 86ms/step - accuracy: 0.2731 - loss: 1.3701 - val_accuracy: 0.2587 - val_loss: 2.5087 - learning_rate: 0.0010
Epoch 4/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 82ms/step - accuracy: 0.2836 - loss: 1.3633 - val_accuracy: 0.2826 - val_loss: 2.2343 - learning_rate: 0.0010
Epoch 5/50
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 79ms/step - accuracy: 0.2772 - loss: 1.3452 - val_accuracy: 0.2870 - val_loss: 1.8194 - learning_rate: 0.0010
Epoch 6/50
[1m57/58[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0