In [None]:
import mne
import numpy as np
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, SeparableConv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
import matplotlib.pyplot as plt

# ---------------- Data Loading and Preprocessing ----------------

base_path = r'C:\\Users\\karan\\Downloads\\EEG Data\\Data'
subjects = [f'A0{i}' for i in range(1, 10) if i != 4]
event_ids = [7, 8, 9, 10]  # Event IDs for motor imagery tasks

all_features, all_labels = [], []

for subject in subjects:
    file_path = f'{base_path}\\{subject}T.gdf'
    print(f"Processing {subject}...")

    raw = mne.io.read_raw_gdf(file_path, preload=True)
    raw.drop_channels(['EOG-left', 'EOG-central', 'EOG-right'])
    raw.set_eeg_reference()
    raw.filter(8., 30., fir_design='firwin', verbose=False)

    ica = mne.preprocessing.ICA(n_components=15, random_state=97, max_iter=800)
    ica.fit(raw)
    raw = ica.apply(raw)

    events, _ = mne.events_from_annotations(raw)
    epochs = mne.Epochs(raw, events, event_id=event_ids, tmin=0.5, tmax=4.0, baseline=(0.5, 1.0), preload=True)

    all_features.append(epochs.get_data())
    all_labels.append(epochs.events[:, -1])

features = np.concatenate(all_features, axis=0)
labels = np.concatenate(all_labels, axis=0)

# ---------------- Data Preparation ----------------

# Map labels to binary classification: Tongue = 1, Non-Tongue = 0
binary_labels = np.where(labels == 10, 1, 0)  # Event ID for Tongue is 10
tongue_features = features[binary_labels == 1]
non_tongue_features = features[binary_labels == 0]

# Oversample Tongue trials to balance the dataset
n_tongue = len(tongue_features)
n_non_tongue = len(non_tongue_features)
oversampled_tongue_features = np.tile(tongue_features, (n_non_tongue // n_tongue, 1, 1))
remainder = n_non_tongue % n_tongue
if remainder > 0:
    oversampled_tongue_features = np.concatenate([oversampled_tongue_features, tongue_features[:remainder]], axis=0)

# Combine the balanced dataset
balanced_features = np.concatenate([oversampled_tongue_features, non_tongue_features], axis=0)
balanced_labels = np.concatenate([np.ones(len(oversampled_tongue_features)), np.zeros(n_non_tongue)], axis=0)

# Normalize the features (Z-score normalization)
balanced_features = (balanced_features - np.mean(balanced_features, axis=0)) / np.std(balanced_features, axis=0)

# Add Gaussian noise for data augmentation
noise_factor = 0.05
augmented_features = balanced_features + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=balanced_features.shape)
augmented_features = np.clip(augmented_features, -1.0, 1.0)

# Combine original and augmented datasets
final_features = np.concatenate((balanced_features, augmented_features))
final_labels = np.concatenate((balanced_labels, balanced_labels))
final_features, final_labels = shuffle(final_features, final_labels, random_state=42)

X_train, X_test, y_train, y_test = train_test_split(final_features, final_labels, test_size=0.15, random_state=42)
X_train = X_train[..., np.newaxis]
X_test = X_test[..., np.newaxis]

# ---------------- Build the Model ----------------

model = Sequential()
model.add(Conv2D(8, kernel_size=(1, 64), padding='same', input_shape=(22, X_train.shape[2], 1),
                 activation='relu', kernel_regularizer=l2(0.01)))
model.add(BatchNormalization())
model.add(DepthwiseConv2D(kernel_size=(22, 1), depth_multiplier=2, use_bias=False, activation='relu',
                          depthwise_regularizer=l2(0.01)))
model.add(BatchNormalization())
model.add(Dropout(0.3))
model.add(SeparableConv2D(16, kernel_size=(1, 16), use_bias=False, padding='same', activation='relu',
                          depthwise_regularizer=l2(0.01)))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(1, 4)))
model.add(Dropout(0.4))
model.add(Flatten())
model.add(Dense(64, activation='relu', kernel_regularizer=l2(0.01)))
model.add(Dropout(0.5))
model.add(Dense(1, activation='sigmoid'))

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss='binary_crossentropy', metrics=['accuracy'])

# ---------------- Train the Model ----------------

reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-5, verbose=1)
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True, verbose=1)

history = model.fit(X_train, y_train, epochs=50, batch_size=32, validation_split=0.15,
                    callbacks=[reduce_lr, early_stopping], verbose=1)

# ---------------- Evaluate the Model ----------------

test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=0)
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")

y_pred_probs = model.predict(X_test)
y_pred = (y_pred_probs > 0.5).astype(int)

conf_matrix = confusion_matrix(y_test, y_pred)
print("Confusion Matrix:\n", conf_matrix)
print("\nClassification Report:\n", classification_report(y_test, y_pred))

# ---------------- Plot Confusion Matrix ----------------

def plot_confusion_matrix(cm, class_names, title="Confusion Matrix"):
    plt.figure(figsize=(6, 6))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)

    fmt = 'd'
    thresh = cm.max() / 2
    for i, j in np.ndindex(cm.shape):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.show()

plot_confusion_matrix(
    conf_matrix,
    class_names=['Non-Tongue', 'Tongue'],
    title="Enhanced EEGNet: Tongue vs Non-Tongue Model - Confusion Matrix"
)

# ---------------- Plot Training and Validation Trends ----------------

plt.figure()
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Enhanced EEGNet: Tongue vs Non-Tongue Model - Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.grid()
plt.show()

plt.figure()
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Enhanced EEGNet: Tongue vs Non-Tongue Model - Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid()
plt.show()

# ---------------- Plot ROC Curve ----------------

fpr, tpr, _ = roc_curve(y_test, y_pred_probs)
roc_auc = auc(fpr, tpr)

plt.figure()
plt.plot(fpr, tpr, color='darkorange', label=f'ROC curve (area = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', linestyle='--')
plt.title('Enhanced EEGNet: Tongue vs Non-Tongue Model - ROC Curve')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend(loc='lower right')
plt.grid()
plt.show()


Processing A01...
Extracting EDF parameters from C:\Users\karan\Downloads\EEG Data\Data\A01T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...


  next(self.gen)


Reading 0 ... 672527  =      0.000 ...  2690.108 secs...
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Fitting ICA to data using 22 channels (please be patient, this may take a while)
Selecting by number: 15 components
Fitting ICA took 14.1s.
Applying ICA to Raw instance
    Transforming to ICA space (15 components)
    Zeroing out 0 ICA components
    Projecting back using 22 PCA components
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
Not setting metadata
288 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 288 events and 876 original time points ...
0 bad epochs dropped
Processing A02...
Extracting EDF parameters from C:\Users\karan\Downloads\EEG Data\Data\A02T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they 

  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Fitting ICA to data using 22 channels (please be patient, this may take a while)
Selecting by number: 15 components
Fitting ICA took 16.6s.
Applying ICA to Raw instance
    Transforming to ICA space (15 components)
    Zeroing out 0 ICA components
    Projecting back using 22 PCA components
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
Not setting metadata
288 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 288 events and 876 original time points ...
0 bad epochs dropped
Processing A03...
Extracting EDF parameters from C:\Users\karan\Downloads\EEG Data\Data\A03T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG,

  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Fitting ICA to data using 22 channels (please be patient, this may take a while)
Selecting by number: 15 components
Fitting ICA took 21.6s.
Applying ICA to Raw instance
    Transforming to ICA space (15 components)
    Zeroing out 0 ICA components
    Projecting back using 22 PCA components
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
Not setting metadata
288 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 288 events and 876 original time points ...
0 bad epochs dropped
Processing A05...
Extracting EDF parameters from C:\Users\karan\Downloads\EEG Data\Data\A05T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG,

  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Fitting ICA to data using 22 channels (please be patient, this may take a while)
Selecting by number: 15 components
Fitting ICA took 16.3s.
Applying ICA to Raw instance
    Transforming to ICA space (15 components)
    Zeroing out 0 ICA components
    Projecting back using 22 PCA components
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
Not setting metadata
288 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 288 events and 876 original time points ...
0 bad epochs dropped
Processing A06...
Extracting EDF parameters from C:\Users\karan\Downloads\EEG Data\Data\A06T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG,

  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Fitting ICA to data using 22 channels (please be patient, this may take a while)
Selecting by number: 15 components
Fitting ICA took 19.1s.
Applying ICA to Raw instance
    Transforming to ICA space (15 components)
    Zeroing out 0 ICA components
    Projecting back using 22 PCA components
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
Not setting metadata
288 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 288 events and 876 original time points ...
0 bad epochs dropped
Processing A07...
Extracting EDF parameters from C:\Users\karan\Downloads\EEG Data\Data\A07T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG,

  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Fitting ICA to data using 22 channels (please be patient, this may take a while)
Selecting by number: 15 components
Fitting ICA took 12.4s.
Applying ICA to Raw instance
    Transforming to ICA space (15 components)
    Zeroing out 0 ICA components
    Projecting back using 22 PCA components
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
Not setting metadata
288 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 288 events and 876 original time points ...
0 bad epochs dropped
Processing A08...
Extracting EDF parameters from C:\Users\karan\Downloads\EEG Data\Data\A08T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG,

  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Fitting ICA to data using 22 channels (please be patient, this may take a while)
Selecting by number: 15 components
Fitting ICA took 18.1s.
Applying ICA to Raw instance
    Transforming to ICA space (15 components)
    Zeroing out 0 ICA components
    Projecting back using 22 PCA components
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
Not setting metadata
288 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 288 events and 876 original time points ...
0 bad epochs dropped
Processing A09...
Extracting EDF parameters from C:\Users\karan\Downloads\EEG Data\Data\A09T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG,

  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Fitting ICA to data using 22 channels (please be patient, this may take a while)
Selecting by number: 15 components
Fitting ICA took 20.9s.
Applying ICA to Raw instance
    Transforming to ICA space (15 components)
    Zeroing out 0 ICA components
    Projecting back using 22 PCA components
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
Not setting metadata
288 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 288 events and 876 original time points ...
0 bad epochs dropped


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/50
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 93ms/step - accuracy: 0.5185 - loss: 2.1573 - val_accuracy: 0.5420 - val_loss: 1.4979 - learning_rate: 0.0010
Epoch 2/50
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 93ms/step - accuracy: 0.5788 - loss: 1.3374 - val_accuracy: 0.5884 - val_loss: 1.1723 - learning_rate: 0.0010
Epoch 3/50
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 86ms/step - accuracy: 0.6645 - loss: 1.0604 - val_accuracy: 0.6837 - val_loss: 1.0119 - learning_rate: 0.0010
Epoch 4/50
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 55ms/step - accuracy: 0.7361 - loss: 0.9659 - val_accuracy: 0.8107 - val_loss: 0.8851 - learning_rate: 0.0010
Epoch 5/50
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 51ms/step - accuracy: 0.7909 - loss: 0.9230 - val_accuracy: 0.8254 - val_loss: 0.8537 - learning_rate: 0.0010
Epoch 6/50
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m

In [3]:
import tensorflow as tf
from sklearn.utils import shuffle
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, SeparableConv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping

# ---------------- Data Preparation for tongue vs. Non-tongue Classification ----------------

# Map labels to binary classification: tongue = 1, Non-tongue = 0
binary_labels = np.where(labels == 10, 1, 0)  # Event ID for tongue is 10

# Separate tongue and Non-tongue data
tongue_features = features[binary_labels == 1]
non_tongue_features = features[binary_labels == 0]

# Oversample tongue trials to balance the dataset
n_tongue = len(tongue_features)
n_non_tongue = len(non_tongue_features)
oversampled_tongue_features = np.tile(tongue_features, (n_non_tongue // n_tongue, 1, 1))
remainder = n_non_tongue % n_tongue
if remainder > 0:
    oversampled_tongue_features = np.concatenate([oversampled_tongue_features, tongue_features[:remainder]], axis=0)

# Combine and shuffle the balanced dataset
balanced_features = np.concatenate([oversampled_tongue_features, non_tongue_features], axis=0)
balanced_labels = np.concatenate([np.ones(len(oversampled_tongue_features)), np.zeros(n_non_tongue)], axis=0)

# Normalize the features (Z-score normalization)
balanced_features = (balanced_features - np.mean(balanced_features, axis=0)) / np.std(balanced_features, axis=0)

# Apply Gaussian noise for data augmentation
noise_factor = 0.05
augmented_features = balanced_features + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=balanced_features.shape)
augmented_features = np.clip(augmented_features, -1.0, 1.0)  # Ensure valid range for EEG signals

# Combine original and augmented data
final_features = np.concatenate((balanced_features, augmented_features))
final_labels = np.concatenate((balanced_labels, balanced_labels))

# Shuffle the final dataset
final_features, final_labels = shuffle(final_features, final_labels, random_state=42)

# Split into training and testing sets (80% training, 20% testing)
X_train, X_test, y_train, y_test = train_test_split(final_features, final_labels, test_size=0.15, random_state=42)

# Reshape data for CNN input
X_train = X_train[..., np.newaxis]
X_test = X_test[..., np.newaxis]

# ---------------- Build the EEGNet Model ----------------

model = Sequential()

# Temporal Convolution Block
model.add(Conv2D(8, kernel_size=(1, 64), padding='same', input_shape=(22, X_train.shape[2], 1),
                 activation='relu', kernel_regularizer=l2(0.01)))
model.add(BatchNormalization())
model.add(DepthwiseConv2D(kernel_size=(22, 1), depth_multiplier=2, use_bias=False, activation='relu',
                          depthwise_regularizer=l2(0.01)))
model.add(BatchNormalization())
model.add(Dropout(0.3))  # Increased dropout for regularization

# Spatial Convolution Block
model.add(SeparableConv2D(16, kernel_size=(1, 16), use_bias=False, padding='same', activation='relu',
                          depthwise_regularizer=l2(0.01)))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(1, 4)))
model.add(Dropout(0.4))  # Increased dropout

# Fully Connected Layers
model.add(Flatten())
model.add(Dense(64, activation='relu', kernel_regularizer=l2(0.01)))
model.add(Dropout(0.5))
model.add(Dense(1, activation='sigmoid'))  # Binary classification

# Compile the Model
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])

# ---------------- Training the Model ----------------

# Callbacks for learning rate adjustment and early stopping
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-5, verbose=1)
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True, verbose=1)

# Train the model
history = model.fit(
    X_train, y_train,
    epochs=50,
    batch_size=32,
    validation_split=0.15,
    callbacks=[reduce_lr, early_stopping],
    verbose=1
)

# ---------------- Evaluate the Model ----------------

# Test accuracy
test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=0)
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")

# Predict labels for the test set
y_pred_probs = model.predict(X_test)
y_pred = (y_pred_probs > 0.5).astype(int)  # Default threshold 0.5

# Confusion Matrix
conf_matrix = confusion_matrix(y_test, y_pred)
print("Confusion Matrix:\n", conf_matrix)

# Classification Report
print("\nClassification Report:\n", classification_report(y_test, y_pred))




  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/50
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 80ms/step - accuracy: 0.5392 - loss: 2.1494 - val_accuracy: 0.5510 - val_loss: 1.4949 - learning_rate: 0.0010
Epoch 2/50
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 79ms/step - accuracy: 0.6537 - loss: 1.3225 - val_accuracy: 0.5964 - val_loss: 1.3588 - learning_rate: 0.0010
Epoch 3/50
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 79ms/step - accuracy: 0.6983 - loss: 1.0833 - val_accuracy: 0.7279 - val_loss: 1.0017 - learning_rate: 0.0010
Epoch 4/50
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 80ms/step - accuracy: 0.7548 - loss: 0.9510 - val_accuracy: 0.7846 - val_loss: 0.9184 - learning_rate: 0.0010
Epoch 5/50
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 78ms/step - accuracy: 0.7909 - loss: 0.9287 - val_accuracy: 0.8209 - val_loss: 0.8788 - learning_rate: 0.0010
Epoch 6/50
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37

In [None]:
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score

# ---------------- Stratified K-Fold Cross-Validation ----------------

n_splits = 5  # Number of folds
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

# Metrics storage
fold_accuracies = []
fold_losses = []
fold_roc_aucs = []
precision_list = []
recall_list = []
f1_list = []

fold = 1

for train_index, test_index in skf.split(final_features, final_labels):
    print(f"\nTraining Fold {fold}...")

    # Split data into training and testing sets for the fold
    X_train_fold = final_features[train_index][..., np.newaxis]
    X_test_fold = final_features[test_index][..., np.newaxis]
    y_train_fold = final_labels[train_index]
    y_test_fold = final_labels[test_index]

    # Build the model for each fold
    model = Sequential()
    model.add(Conv2D(8, kernel_size=(1, 64), padding='same', input_shape=(22, X_train_fold.shape[2], 1),
                     activation='relu', kernel_regularizer=l2(0.01)))
    model.add(BatchNormalization())
    model.add(DepthwiseConv2D(kernel_size=(22, 1), depth_multiplier=2, use_bias=False, activation='relu',
                              depthwise_regularizer=l2(0.01)))
    model.add(BatchNormalization())
    model.add(Dropout(0.3))
    model.add(SeparableConv2D(16, kernel_size=(1, 16), use_bias=False, padding='same', activation='relu',
                              depthwise_regularizer=l2(0.01)))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(1, 4)))
    model.add(Dropout(0.4))
    model.add(Flatten())
    model.add(Dense(64, activation='relu', kernel_regularizer=l2(0.01)))
    model.add(Dropout(0.5))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                  loss='binary_crossentropy', metrics=['accuracy'])

    # Callbacks
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-5, verbose=1)
    early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True, verbose=1)

    # Train the model
    history = model.fit(
        X_train_fold, y_train_fold,
        epochs=50,
        batch_size=32,
        validation_split=0.15,
        callbacks=[reduce_lr, early_stopping],
        verbose=1
    )

    # Evaluate the model
    test_loss, test_accuracy = model.evaluate(X_test_fold, y_test_fold, verbose=0)
    fold_accuracies.append(test_accuracy)
    fold_losses.append(test_loss)
    print(f"Fold {fold} Test Accuracy: {test_accuracy * 100:.2f}%")
    print(f"Fold {fold} Test Loss: {test_loss:.4f}")

    # Predictions and probabilities
    y_pred_probs = model.predict(X_test_fold)
    y_pred = (y_pred_probs > 0.5).astype(int)

    # Compute metrics for the fold
    roc_auc = roc_auc_score(y_test_fold, y_pred_probs)
    fold_roc_aucs.append(roc_auc)
    print(f"Fold {fold} ROC-AUC: {roc_auc:.4f}")

    precision, recall, f1, _ = precision_recall_fscore_support(y_test_fold, y_pred, average='binary')
    precision_list.append(precision)
    recall_list.append(recall)
    f1_list.append(f1)
    print(f"Fold {fold} Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}")

    fold += 1

# ---------------- Cross-Validation Results ----------------

# Compute overall averages and standard deviations
average_accuracy = np.mean(fold_accuracies)
std_accuracy = np.std(fold_accuracies)
average_loss = np.mean(fold_losses)
std_loss = np.std(fold_losses)
average_roc_auc = np.mean(fold_roc_aucs)
std_roc_auc = np.std(fold_roc_aucs)
average_precision = np.mean(precision_list)
average_recall = np.mean(recall_list)
average_f1 = np.mean(f1_list)

# Print overall cross-validation results
print("\nCross-Validation Results:")
print(f"Average Accuracy: {average_accuracy * 100:.2f}%")
print(f"Standard Deviation of Accuracy: {std_accuracy * 100:.2f}%")
print(f"Average Loss: {average_loss:.4f}")
print(f"Standard Deviation of Loss: {std_loss:.4f}")
print(f"Average ROC-AUC: {average_roc_auc:.4f}")
print(f"Standard Deviation of ROC-AUC: {std_roc_auc:.4f}")
print(f"Average Precision: {average_precision:.4f}")
print(f"Average Recall: {average_recall:.4f}")
print(f"Average F1-Score: {average_f1:.4f}")
