In [None]:
import os
import numpy as np
import mne
import tensorflow as tf
from scipy.signal import butter, lfilter
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_curve, auc, precision_recall_curve, average_precision_score
from imblearn.over_sampling import SMOTE
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

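# NOTE: butter_bandpass and bandpass_filter are unchanged from the previous version and are not repeated here; they must be defined before extract_features is called.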

# Frequency bands as name -> (low, high) cutoff pairs. extract_features passes
# each pair to bandpass_filter (presumably in Hz -- confirm against
# bandpass_filter). Insertion order fixes the order of the extracted features.
FREQ_BANDS = dict(
    delta=(0.5, 4),
    theta=(4, 8),
    alpha=(8, 13),
    beta=(13, 30),
    gamma=(30, 40),
)

def extract_features(eeg_data, sfreq):
    """Compute per-band summary statistics of an EEG recording.

    For every band in FREQ_BANDS (in dictionary order) the data is passed
    through bandpass_filter and two scalars are collected: the mean absolute
    value and the variance of the filtered result, each taken over all
    elements of the filtered array.

    Args:
        eeg_data: EEG samples, handed unchanged to bandpass_filter.
        sfreq: Sampling frequency, forwarded to bandpass_filter.

    Returns:
        1-D numpy array of length 2 * len(FREQ_BANDS), laid out as
        [band0_mean_abs, band0_var, band1_mean_abs, band1_var, ...].
    """
    band_stats = []
    for low_cut, high_cut in FREQ_BANDS.values():
        band_signal = bandpass_filter(eeg_data, low_cut, high_cut, sfreq)
        # Keep the (mean |x|, variance) pair adjacent so the column order
        # stays mean-then-var for each band.
        band_stats.extend((np.mean(np.abs(band_signal)), np.var(band_signal)))
    return np.array(band_stats)

def load_seizure_files(seizure_list_path):
    """Read the list of recordings that contain seizures.

    Args:
        seizure_list_path: Path to a text file with one entry per line.

    Returns:
        Set of the file's lines with surrounding whitespace removed.
        Blank (or whitespace-only) lines are ignored.
    """
    with open(seizure_list_path, 'r') as f:
        # Drop empty entries so the empty string never becomes a "seizure
        # file" in the returned set.
        return {entry for entry in map(str.strip, f) if entry}

def load_eeg_data(edf_path, seizure_files):
    """Build a feature matrix and label vector from the .edf files under a directory.

    Every file whose name ends in '.edf' anywhere below edf_path is read with
    mne.io.read_raw_edf(preload=True), its data is reduced to one feature
    vector by extract_features, and it contributes exactly one sample.

    Args:
        edf_path: Root directory that is walked recursively.
        seizure_files: Collection of paths relative to edf_path, written with
            forward slashes; a file found in it is labelled 1, otherwise 0.

    Returns:
        Tuple (X, y) of numpy arrays: one row of features and one label per
        .edf file, in sorted traversal order. Both are empty arrays when no
        .edf file is found.
    """
    X, y = [], []
    for root, dirs, files in os.walk(edf_path):
        # os.walk yields entries in arbitrary filesystem order. Sorting the
        # directories in place (which steers the walk) and the file names
        # makes the sample order -- and therefore any seeded split of X/y --
        # reproducible across machines.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith('.edf'):
                continue
            file_path = os.path.join(root, file)
            # Normalise Windows separators so the key matches the
            # forward-slash entries in seizure_files.
            relative_path = os.path.relpath(file_path, edf_path).replace('\\', '/')
            raw = mne.io.read_raw_edf(file_path, preload=True)
            eeg_data = raw.get_data()
            sfreq = raw.info['sfreq']
            X.append(extract_features(eeg_data, sfreq))
            y.append(1 if relative_path in seizure_files else 0)
    return np.array(X), np.array(y)

def build_eegnet(input_shape):
    """Create and compile the convolutional classifier used by this script.

    Architecture, in order: two stages of
    Conv2D(1x3, relu, same padding) -> BatchNormalization ->
    MaxPooling2D(1x2) -> Dropout(0.25) with 16 and then 32 filters, followed
    by Flatten, Dense(64, relu), Dropout(0.5) and a 2-unit softmax output.

    Args:
        input_shape: Shape of one input sample, passed to the first Conv2D
            layer as its input_shape.

    Returns:
        A tf.keras.Sequential model compiled with the 'adam' optimizer,
        'sparse_categorical_crossentropy' loss and the 'accuracy' metric.
    """
    keras_layers = tf.keras.layers
    network = tf.keras.Sequential()

    # Convolutional stage 1 (16 filters); this layer declares the input shape.
    network.add(keras_layers.Conv2D(16, (1, 3), activation='relu', padding='same', input_shape=input_shape))
    network.add(keras_layers.BatchNormalization())
    network.add(keras_layers.MaxPooling2D((1, 2)))
    network.add(keras_layers.Dropout(0.25))

    # Convolutional stage 2 (32 filters).
    network.add(keras_layers.Conv2D(32, (1, 3), activation='relu', padding='same'))
    network.add(keras_layers.BatchNormalization())
    network.add(keras_layers.MaxPooling2D((1, 2)))
    network.add(keras_layers.Dropout(0.25))

    # Classification head: integer labels in, two class probabilities out.
    network.add(keras_layers.Flatten())
    network.add(keras_layers.Dense(64, activation='relu'))
    network.add(keras_layers.Dropout(0.5))
    network.add(keras_layers.Dense(2, activation='softmax'))

    network.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return network

# Paths
# data_path is the root of the extracted dataset. The seizure list lives
# directly inside it, so it is derived from data_path instead of repeating the
# literal; changing the dataset location then only needs one edit.
data_path = r"E:\MTech Projects\Suseela\Dataset\Extracted\chb-mit-scalp-eeg-database-1.0.0"
seizure_list_path = os.path.join(data_path, "RECORDS-WITH-SEIZURES")

# Load data
# One feature row and one 0/1 label per recording found under data_path.
seizure_files = load_seizure_files(seizure_list_path)
X, y = load_eeg_data(data_path, seizure_files)

# Report the label distribution before any splitting or resampling.
n_seizure = np.sum(y == 1)
n_non_seizure = np.sum(y == 0)
print("Unique labels in y:", np.unique(y))
print("Number of seizure samples (1):", n_seizure)
print("Number of non-seizure samples (0):", n_non_seizure)

# Train-test split (stratified on y so both splits keep the class ratio)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Balance training data with SMOTE (training split only; the test split is
# left as observed)
smote = SMOTE(random_state=42)
X_train_balanced, y_train_balanced = smote.fit_resample(X_train, y_train)

# Prepare for EEGNet: every feature vector becomes a (1, n_features, 1) sample
n_features = X.shape[1]
X_train_cnn = X_train_balanced.reshape(-1, 1, n_features, 1)
X_test_cnn = X_test.reshape(-1, 1, n_features, 1)

# Train EEGNet with class weights
# The weights are derived from the labels that are actually passed to fit().
# Computing them from the pre-SMOTE y_train (as before) boosted the minority
# class a second time on data that SMOTE had already balanced.
n_negative = int(np.sum(y_train_balanced == 0))
n_positive = int(np.sum(y_train_balanced == 1))
class_weights = {0: 1.0, 1: float(n_negative / n_positive) if n_positive else 1.0}
model = build_eegnet(input_shape=(1, n_features, 1))
# NOTE(review): the test split doubles as validation data, so the val_* curves
# are not an estimate that is independent of the final evaluation.
history = model.fit(
    X_train_cnn,
    y_train_balanced,
    epochs=10,
    batch_size=16,
    validation_data=(X_test_cnn, y_test),
    class_weight=class_weights,
)

# Extract deep features
# model.predict returns the output of the network's last layer, i.e. the
# 2-unit softmax, so each "deep feature" vector has two entries.
deep_features_train = model.predict(X_train_cnn)
deep_features_test = model.predict(X_test_cnn)

# Train SVM with probability estimates
# probability=True makes SVC fit an internal cross-validated calibration that
# shuffles the data; random_state pins that shuffling so y_prob (and the
# curves drawn from it) is reproducible, matching the seeded split and SMOTE.
svm = SVC(kernel='rbf', class_weight='balanced', probability=True, random_state=42)
svm.fit(deep_features_train, y_train_balanced)
y_pred = svm.predict(deep_features_test)
# Column 1 of predict_proba belongs to svm.classes_[1], which is label 1 when
# the training labels are {0, 1}.
y_prob = svm.predict_proba(deep_features_test)[:, 1]

# Evaluate
accuracy = accuracy_score(y_test, y_pred)
print("SVM Accuracy:", accuracy)
print(classification_report(y_test, y_pred))

# 1. Confusion Matrix Heatmap
# Rows are the actual classes, columns the predicted ones.
class_names = ['Non-Seizure', 'Seizure']
cm = confusion_matrix(y_test, y_pred)
fig_cm, ax_cm = plt.subplots(figsize=(6, 4))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax_cm,
)
ax_cm.set_xlabel('Predicted')
ax_cm.set_ylabel('Actual')
ax_cm.set_title('Confusion Matrix')
plt.show()

# 2. ROC Curve
# Built from the SVM's class-1 probabilities; the dashed diagonal is the
# chance-level reference line.
fpr, tpr, _ = roc_curve(y_test, y_prob)
roc_auc = auc(fpr, tpr)
fig_roc, ax_roc = plt.subplots(figsize=(6, 4))
ax_roc.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
ax_roc.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
ax_roc.set_xlim([0.0, 1.0])
ax_roc.set_ylim([0.0, 1.05])
ax_roc.set_xlabel('False Positive Rate')
ax_roc.set_ylabel('True Positive Rate')
ax_roc.set_title('ROC Curve')
ax_roc.legend(loc='lower right')
plt.show()

# 3. Precision-Recall Curve
# Uses the same class-1 probabilities as the ROC curve; AP summarises the
# curve in the legend.
precision, recall, _ = precision_recall_curve(y_test, y_prob)
avg_precision = average_precision_score(y_test, y_prob)
fig_pr, ax_pr = plt.subplots(figsize=(6, 4))
ax_pr.plot(recall, precision, color='b', lw=2, label=f'PR curve (AP = {avg_precision:.2f})')
ax_pr.set_xlabel('Recall')
ax_pr.set_ylabel('Precision')
ax_pr.set_title('Precision-Recall Curve')
ax_pr.legend(loc='lower left')
plt.show()

# 4. Training and Validation Loss/Accuracy
# Two side-by-side panels driven by one table: loss on the left, accuracy on
# the right, each showing the per-epoch training and validation series
# recorded in history.history.
fig_hist, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
history_panels = (
    (ax_loss, 'loss', 'val_loss', 'Training Loss', 'Validation Loss',
     'Loss', 'Training and Validation Loss'),
    (ax_acc, 'accuracy', 'val_accuracy', 'Training Accuracy', 'Validation Accuracy',
     'Accuracy', 'Training and Validation Accuracy'),
)
for panel_ax, train_key, val_key, train_label, val_label, y_label, panel_title in history_panels:
    panel_ax.plot(history.history[train_key], label=train_label)
    panel_ax.plot(history.history[val_key], label=val_label)
    panel_ax.set_xlabel('Epoch')
    panel_ax.set_ylabel(y_label)
    panel_ax.set_title(panel_title)
    panel_ax.legend()
fig_hist.tight_layout()
plt.show()

# 5. Feature Distribution (example for delta_mean)
# Column names follow the order produced by extract_features: for each band
# in FREQ_BANDS, first the mean statistic and then the variance.
feature_names = []
for band in FREQ_BANDS:
    feature_names.extend([f'{band}_mean', f'{band}_var'])
df = pd.DataFrame(X, columns=feature_names).assign(label=y)

# Per-class density (common_norm=False normalises each label separately).
fig_feat, ax_feat = plt.subplots(figsize=(6, 4))
sns.histplot(
    data=df,
    x='delta_mean',
    hue='label',
    element='step',
    stat='density',
    common_norm=False,
    ax=ax_feat,
)
ax_feat.set_xlabel('Delta Band Mean Amplitude')
ax_feat.set_title('Feature Distribution: Delta Mean')
plt.show()