In [None]:
from google.colab import drive
import os
!pip install mne
!pip install pywavelets
!pip install shap
import mne
import numpy as np
import matplotlib.pyplot as plt
import shap
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Conv1D, BatchNormalization, Activation, MaxPooling1D,
    Add, GlobalAveragePooling1D, Dense, Dropout, Lambda
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import pywt
from scipy.stats import entropy
from scipy import integrate
import scipy.stats as stats
from scipy.ndimage import shift
from scipy.interpolate import interp1d
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

In [None]:
# Step 1: Mount Google Drive so the dataset folder is reachable from Colab.
drive.mount('/content/drive')

# Step 2: Root folder that holds one sub-folder per subject.
dataset_path = "/content/drive/My Drive/Dataset Movement 109"

# Step 3: List the available subject folders.
# Later cells create "Preprocessed" and "Epoched" output folders INSIDE
# `dataset_path`; without this filter a re-run would count them (and any stray
# file) as subjects.
OUTPUT_DIR_NAMES = {"Preprocessed", "Epoched"}
all_subjects = sorted(
    name for name in os.listdir(dataset_path)
    if name not in OUTPUT_DIR_NAMES
    and os.path.isdir(os.path.join(dataset_path, name))
)
print(f" Total Subjects Found: {len(all_subjects)} ")

#  Step 4: Function to Load EEG Data in Chunks of 2 Subjects
def load_eeg_data(subjects, chunk_size=2):
    """
    Lazily load one EDF recording per subject, `chunk_size` subjects at a time.

    This is a generator: nothing is read until it is iterated. For every
    subject folder under the module-level `dataset_path`, the alphabetically
    first ``.edf`` file is read with MNE (preloaded into memory), its PSD and
    the first 10 channels are plotted, and the recording is added to the
    current chunk.

    Parameters:
    - subjects: Sequence of subject folder names (relative to `dataset_path`)
    - chunk_size: Number of subjects per yielded chunk (default: 2)

    Yields:
    - (eeg_data, subject_ids): dict mapping subject name -> MNE Raw, and the
      list of subject names that were loaded successfully in this chunk.
      Subjects with no EDF file, or whose file fails to load, are skipped.
    """
    total_subjects = len(subjects)
    print(f"\n Starting EEG Data Loading for {total_subjects} Subjects...")

    for i in range(0, total_subjects, chunk_size):
        chunk_subjects = subjects[i:i+chunk_size]  # Select subjects in chunks
        print(f"\n Processing Subjects: {chunk_subjects}")

        eeg_data = {}  # Store EEG data per subject
        subject_ids = []  # Store subject names

        for subject in chunk_subjects:
            subject_path = os.path.join(dataset_path, subject)

            # A stray file in the dataset root would make os.listdir() raise.
            if not os.path.isdir(subject_path):
                print(f" Skipping {subject}: not a subject folder")
                continue

            # Sort so that "first EDF file" is deterministic; os.listdir()
            # returns entries in arbitrary order.
            edf_files = sorted(f for f in os.listdir(subject_path) if f.endswith(".edf"))

            if not edf_files:
                print(f" No EDF file found for {subject}")
                continue

            # Only the first EDF file of each subject is loaded.
            edf_file_path = os.path.join(subject_path, edf_files[0])

            try:
                raw = mne.io.read_raw_edf(edf_file_path, preload=True)
            except Exception as e:
                print(f" Error loading {edf_file_path}: {e}")
                continue  # Skip corrupted files

            # Display Sampling Frequency
            sfreq = raw.info['sfreq']
            print(f" {subject}: Sampling Frequency = {sfreq} Hz")

            # Store EEG Data
            eeg_data[subject] = raw
            subject_ids.append(subject)

            # Plot PSD for Verification (2-second FFT window)
            print(f"▶ Checking PSD Before Filtering for {subject}...")
            raw.compute_psd(fmin=0.5, fmax=50, n_fft=int(sfreq * 2)).plot(average=True, show=True)

            # Plot Raw EEG Signals (First 10 Channels)
            print(f"▶ Plotting EEG Signals for {subject}...")
            raw.plot(n_channels=10, scalings='auto', title=f"EEG Signals for {subject}", show=True)

        yield eeg_data, subject_ids  # Return dictionary of EEG data per subject

# Step 5: Create the chunked loader. `load_eeg_data` is a generator, so no file
# is read until the loop below iterates it.
eeg_generator = load_eeg_data(all_subjects, chunk_size=2)

# Step 6: Drain the generator, merging every chunk into one dict / one list.
# Note: every loaded recording stays referenced by `eeg_batch`, so all subjects
# are held in memory at once even though they are read in chunks.
eeg_batch = {}
subject_batch = []

for eeg_data, subjects in eeg_generator:
    eeg_batch.update(eeg_data)
    subject_batch.extend(subjects)
    print(f" Processed {len(subject_batch)}/{len(all_subjects)} Subjects")

print(f"\n Completed Loading EEG Data for All {len(subject_batch)} Subjects.")


In [None]:
#  Step 1: Check Before Bandpass Filtering
def check_before_bandpass(raw_dict):
    """
    Print a quality report and plot the raw PSD for every subject before
    bandpass filtering.

    For each recording this prints the sampling frequency, channel count,
    the channels already flagged in ``info['bads']``, the Nyquist frequency
    and a suggested bandpass range, and shows the averaged PSD from 0.1 Hz up
    to Nyquist. A failure for one subject is reported and the loop continues.

    Parameters:
    - raw_dict: Dictionary of EEG data per subject (subject name -> MNE Raw)
    """
    total_subjects = len(raw_dict)
    print(f"\n Running Bandpass Pre-Check for {total_subjects} Subjects...")

    for subject, raw in raw_dict.items():
        try:
            # 1. Check Sampling Frequency
            sfreq = raw.info['sfreq']
            nyquist = sfreq / 2
            print(f" {subject} - Sampling Frequency: {sfreq} Hz")

            # 2. Check Number of Channels
            print(f" {subject} - Total Channels: {len(raw.ch_names)} - Channel Names: {raw.ch_names[:10]}...")

            # 3. Check Bad Channels (only those already flagged in the metadata)
            bad_channels = raw.info['bads'] if 'bads' in raw.info else []
            if bad_channels:
                print(f" {subject} - Bad Channels Detected: {bad_channels}")
            else:
                print(f" {subject} - No Bad Channels Detected")

            # 4. Plot PSD Before Bandpass Filtering.
            # The PSD plot creates its own figure, so style that figure directly
            # instead of opening a separate (empty) one with plt.figure().
            print(f" {subject} - PSD Before Bandpass Filtering")
            fig = raw.compute_psd(fmin=0.1, fmax=nyquist, n_fft=int(sfreq * 2)).plot(average=True, show=False)
            fig.set_size_inches(12, 5)
            psd_ax = fig.axes[0]
            psd_ax.set_title(f"PSD Before Bandpass - {subject}")
            psd_ax.grid(True)
            plt.show()

            # 5. Display Nyquist Frequency and Recommended Bandpass Range
            print(f" {subject} - Nyquist Frequency: {nyquist} Hz")
            print(f" Recommended Bandpass Range: 0.5 Hz to {min(40, nyquist-1)} Hz")
            print("-" * 60)

        except Exception as e:
            print(f" Error during Bandpass Check for {subject}: {e}")
            continue

# Run the pre-check on every recording currently stored in `eeg_batch`
# (report + one PSD figure per subject).
check_before_bandpass(eeg_batch)


import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

#  Step 2: Apply Bandpass Filter Using Correct Parameters
def apply_bandpass_filter(raw, l_freq=0.5, h_freq=40):
    """
    Return a bandpass-filtered copy of a recording (FIR, 'firwin' design).

    The input object is never modified: filtering is applied to a copy.

    Parameters:
    - raw: MNE Raw object (continuous EEG data)
    - l_freq: Lower cutoff frequency in Hz (default: 0.5)
    - h_freq: Upper cutoff frequency in Hz (default: 40)

    Returns:
    - The filtered copy, or the original `raw` (unfiltered) if anything goes
      wrong; in that case the error is printed, not raised.
    """
    filter_settings = dict(
        l_freq=l_freq,
        h_freq=h_freq,
        method='fir',
        fir_design='firwin',
    )
    try:
        working_copy = raw.copy()
        return working_copy.filter(**filter_settings)
    except Exception as err:
        # Best-effort: report the problem and hand back the untouched input.
        print(f" Error applying Bandpass Filter: {err}")
        return raw


# Step 3: Apply the bandpass filter to every subject in `eeg_batch`.
# Note: `apply_bandpass_filter` catches its own errors and returns the
# unfiltered input, so the `except` below is normally not reached and a failed
# subject is still stored (unfiltered) in `eeg_filtered`.
print("\n Applying Bandpass Filter to ALL 109 Subjects...")
eeg_filtered = {}

for subject, raw in eeg_batch.items():
    try:
        filtered = apply_bandpass_filter(raw)
        eeg_filtered[subject] = filtered
        print(f" Bandpass Filter Applied for {subject}")
    except Exception as e:
        print(f" Bandpass Filter Failed for {subject}: {e}")

print("\n Bandpass Filtering Completed for All Subjects")


# Step 4: Tag each filtered recording and write it back into `eeg_batch`,
# replacing the unfiltered entry for that subject.
print("\n Saving Bandpass Filter Results to `eeg_batch`...")
for subject_id, raw_filtered in eeg_filtered.items():
    # Append a 'bandpass-applied' marker to the free-text description once;
    # a missing/None description is treated as an empty string.
    with raw_filtered.info._unlock():
        if 'bandpass-applied' not in (raw_filtered.info.get('description') or ''):
            raw_filtered.info['description'] = (raw_filtered.info.get('description') or '') + ' bandpass-applied'
        else:
            print(f" Bandpass already marked for {subject_id}.")

    # Store back into eeg_batch (same key, so the dict size does not change)
    eeg_batch[subject_id] = raw_filtered

print("\n Bandpass Filter Results Saved to `eeg_batch` for All Subjects.")


In [None]:
#  Step 2: Use compute_psd().plot() for Correct PSD
def plot_3d_psd(raw, subject_name):
    """
    Plot a per-channel 3D power spectral density (Welch, 0.5-50 Hz, in dB).

    One line is drawn per channel: x = frequency, y = channel index,
    z = power in dB.

    Parameters:
    - raw: MNE Raw object (after bandpass filtering)
    - subject_name: Subject identifier (used in the title)
    """
    sfreq = raw.info['sfreq']
    spectrum = raw.compute_psd(fmin=0.5, fmax=50, method='welch', n_fft=int(sfreq * 2))
    psd = spectrum.get_data()
    # Use the frequency bins the estimator actually produced rather than
    # guessing them with linspace, which is only right when the bins happen to
    # start at fmin and end at fmax.
    freqs = spectrum.freqs

    fig = plt.figure(figsize=(12, 7))
    ax = fig.add_subplot(111, projection='3d')

    for channel_idx in range(psd.shape[0]):
        ax.plot(freqs, np.full_like(freqs, channel_idx), 10 * np.log10(psd[channel_idx]), lw=0.7)

    ax.set_title(f"3D PSD After Bandpass Filter - {subject_name}")
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("EEG Channel")
    ax.set_zlabel("Power (dB)")
    ax.set_xlim(0, 50)
    ax.grid(True)

    plt.show()

# Step 3: Plot the 3D PSD for the first subject stored in `eeg_filtered`.
chosen_subject = list(eeg_filtered.keys())[0]
plot_3d_psd(eeg_filtered[chosen_subject], chosen_subject)


In [None]:
#  Step 1: Check Before Notch Filtering for All Subjects
def _power_db_at(freqs, psd_data, freq):
    """Return the channel-averaged PSD, in dB, at the bin closest to `freq`."""
    idx = np.argmin(np.abs(freqs - freq))
    return 10 * np.log10(np.mean(psd_data[:, idx]))


def check_before_notch(raw_dict):
    """
    Report powerline noise (50 Hz and its 100/150 Hz harmonics) per subject.

    For each recording a Welch PSD from 0.5 Hz to Nyquist is computed, the
    channel-averaged power at each line frequency that lies at or below Nyquist
    is printed, and the averaged PSD is plotted. A failure for one subject is
    reported and the loop continues.

    Parameters:
    - raw_dict: Dictionary of EEG data per subject (subject name -> MNE Raw)
    """
    total_subjects = len(raw_dict)
    print(f"\n Running Notch Pre-Check for {total_subjects} Subjects...")

    # NOTE(review): the -20 dB threshold is applied to the PSD exactly as MNE
    # returns it; if that is in V^2/Hz, EEG power sits far below -20 dB and
    # the "strong noise" branch would never trigger - confirm the intended unit.
    strong_noise_db = -20

    for subject, raw in raw_dict.items():
        try:
            # 1. Check Sampling Frequency
            sfreq = raw.info['sfreq']
            nyquist = sfreq / 2
            print(f"\n {subject} - Sampling Frequency: {sfreq} Hz (Nyquist: {nyquist} Hz)")

            if nyquist < 50:
                print(f" {subject} - Nyquist frequency too low for 50 Hz Notch Filter")
            else:
                print(f" {subject} - Suitable for 50 Hz Notch Filter")

            # 2. Compute PSD for Powerline Noise Check
            psd = raw.compute_psd(
                fmin=0.5, fmax=nyquist,
                method='welch', n_fft=int(sfreq * 2)
            )
            psd_data = psd.get_data()
            freqs = psd.freqs

            # Only measure a line frequency that the spectrum can represent;
            # otherwise the "nearest bin" would be some unrelated frequency.
            power_by_freq = {
                line_freq: _power_db_at(freqs, psd_data, line_freq) if nyquist >= line_freq else None
                for line_freq in (50, 100, 150)
            }
            power_50hz = power_by_freq[50]
            power_100hz = power_by_freq[100]
            power_150hz = power_by_freq[150]

            # 3. Print Powerline Noise Check
            for line_freq, power_db in power_by_freq.items():
                if power_db is not None:
                    print(f" Power at {line_freq} Hz: {power_db:.2f} dB")

            # 4. Decision on Notch Filtering Need
            if power_50hz is not None:
                if power_50hz > strong_noise_db:
                    print(f" {subject} - Strong 50 Hz Noise Detected!")
                else:
                    print(f" {subject} - Minimal 50 Hz Noise (Likely Already Filtered)")

            # `is not None` (not truthiness): a measured power of exactly
            # 0.0 dB is a valid value and must not be treated as "missing".
            if power_100hz is not None and power_100hz > strong_noise_db:
                print(f" {subject} - Harmonics Detected at 100 Hz")
            if power_150hz is not None and power_150hz > strong_noise_db:
                print(f" {subject} - Harmonics Detected at 150 Hz")

            # 5. PSD Plot for Visual Check
            psd.plot(average=True, show=False)
            plt.title(f"PSD Before Notch - {subject}")
            plt.grid(True)
            plt.show()

        except Exception as e:
            print(f" Error during Notch Pre-Check for {subject}: {e}")
            continue

# Run the notch pre-check on every recording in `eeg_batch`.
check_before_notch(eeg_batch)

# Step 2: Create (if needed) the Drive folder that receives the preprocessed
# FIF files. It lives inside the dataset folder itself.
save_directory = "/content/drive/My Drive/Dataset Movement 109/Preprocessed"
os.makedirs(save_directory, exist_ok=True)
print(f" Save Directory Ready: {save_directory}")


#  Step 3: Apply Notch Filter Function
def apply_notch_filter(raw, freqs=[50]):
    """
    Return a notch-filtered copy of a recording (FIR, 'firwin' design).

    The input object is never modified: filtering is applied to a copy.

    Parameters:
    - raw: MNE Raw object
    - freqs: Frequencies to remove, in Hz (default: [50])

    Returns:
    - The notch-filtered copy, or the original `raw` (unfiltered) if anything
      goes wrong; in that case the error is printed, not raised.
    """
    try:
        working_copy = raw.copy()
        return working_copy.notch_filter(freqs=freqs, fir_design='firwin')
    except Exception as err:
        # Best-effort: report the problem and hand back the untouched input.
        print(f" Notch filter failed: {err}")
        return raw


# Step 4: Apply the 50 Hz notch filter to every subject in `eeg_batch` and save
# each result to Drive.
# Note: at this point `eeg_batch` holds the bandpass-filtered recordings that
# the bandpass cell wrote back into it.
eeg_notched = {}

print("\n Applying Notch Filter to ALL 109 Subjects...")
for subject_id, raw_data in eeg_batch.items():
    try:
        # `apply_notch_filter` returns the unfiltered input if filtering fails.
        notched = apply_notch_filter(raw_data, freqs=[50])
        eeg_notched[subject_id] = notched
        print(f" Notch Filter Applied for {subject_id}")

        # Save Notch-Filtered Data to Google Drive.
        # NOTE(review): the file is written before Step 5 adds the
        # 'notch-applied' marker, so the saved FIF does not carry that marker.
        save_path = os.path.join(save_directory, f"{subject_id}_notch-raw.fif")
        notched.save(save_path, overwrite=True)
        print(f" Notch-Filtered EEG saved to: {save_path}")

    except Exception as e:
        print(f" Notch Filtering failed for {subject_id}: {e}")

print("\n Notch Filtering Completed for All Subjects.")


# Step 5: Append a 'notch-applied' marker to each description (once) and
# replace the corresponding entry in `eeg_batch` with the notch-filtered data.
print("\n Marking Notch Filter in Metadata...")
for subject_id, raw_filtered in eeg_notched.items():
    with raw_filtered.info._unlock():
        if 'notch-applied' not in (raw_filtered.info.get('description') or ''):
            raw_filtered.info['description'] = (raw_filtered.info.get('description') or '') + ' notch-applied'
            print(f" Metadata Updated for {subject_id}")
        else:
            print(f" Notch already marked for {subject_id}")

    # Save to `eeg_batch` (in memory only; nothing is written to disk here)
    eeg_batch[subject_id] = raw_filtered

print("\n Notch Filter Metadata Updated for All Subjects.")


#  Step 6: 3D PSD Visualization for ONE Subject
from mpl_toolkits.mplot3d import Axes3D

def plot_3d_psd_after_notch(raw, subject_name):
    """
    Plot a per-channel 3D PSD (Welch, 0.5-50 Hz, in dB) after notch filtering.

    One line is drawn per channel: x = frequency, y = channel index,
    z = power in dB. Any error is printed instead of raised.

    Parameters:
    - raw: MNE Raw object (after notch filtering)
    - subject_name: Subject identifier (used in the title and error message)
    """
    try:
        sfreq = raw.info['sfreq']
        spectrum = raw.compute_psd(
            fmin=0.5, fmax=50,
            method='welch',
            n_fft=int(sfreq * 2)
        )
        psd = spectrum.get_data()

        # Use the frequency bins the estimator actually produced rather than
        # guessing them with linspace, which is only right when the bins
        # happen to start at fmin and end at fmax.
        freqs = spectrum.freqs

        fig = plt.figure(figsize=(12, 7))
        ax = fig.add_subplot(111, projection='3d')

        for channel_idx in range(psd.shape[0]):
            ax.plot(
                freqs,
                np.full_like(freqs, channel_idx),
                10 * np.log10(psd[channel_idx]),
                lw=0.7
            )

        ax.set_title(f"3D PSD After Notch Filter (50Hz) - {subject_name}")
        ax.set_xlabel("Frequency (Hz)")
        ax.set_ylabel("EEG Channel")
        ax.set_zlabel("Power (dB)")
        ax.set_xlim(0, 50)
        ax.grid(True)
        plt.show()

    except Exception as e:
        print(f" Error plotting 3D PSD for {subject_name}: {e}")


# Step 7: Plot the 3D PSD for the first notch-filtered subject, if there is one.
if eeg_notched:
    chosen_subject = list(eeg_notched.keys())[0]  # Select First Subject
    plot_3d_psd_after_notch(eeg_notched[chosen_subject], chosen_subject)
else:
    print(" No subjects found in `eeg_notched` to visualize.")


In [None]:
# Plot the first 10 channels (5-second window) of the first notch-filtered
# subject. The trailing `;` suppresses the figure repr in the cell output.
eeg_notched[list(eeg_notched.keys())[0]].plot(n_channels=10, duration=5, scalings='auto', show=True);


In [None]:
# Step 1: Pick the first notch-filtered recording.
# `raw` is the same object that is stored in `eeg_notched`, so the renaming and
# montage below modify that entry in place.
subject = list(eeg_notched.keys())[0]  # Select the first subject
raw = eeg_notched[subject]

# Step 2: Print Existing Channel Names (for debugging)
print(" Original Channel Names:", raw.ch_names)

# Step 3: Strip the '.' padding from channel names (only where present)
channel_mapping = {
    ch: ch.replace('.', '') for ch in raw.ch_names if '.' in ch
}
raw.rename_channels(channel_mapping)

# Step 4: Attach standard electrode positions.
# 'biosemi64' is a cap layout whose labels do not all coincide with 10-10 names,
# and matching was case-sensitive, so with on_missing='ignore' unmatched
# channels silently ended up without a position. 'standard_1005' covers the
# 10-10 labels, and match_case=False lets names such as 'Fc5' or 'Afz' match
# 'FC5' / 'AFz'.
# NOTE(review): assumes the recording uses 10-10 style labels - confirm
# against the printed channel names.
montage = mne.channels.make_standard_montage('standard_1005')
raw.set_montage(montage, match_case=False, on_missing='ignore')  # Ignore missing channels

# Step 5: Verify Updated Channels
print(" Updated Channel Names:", raw.ch_names)



In [None]:
# Step 1: Define Save Directory
save_directory = "/content/drive/My Drive/Dataset Movement 109/Preprocessed"
os.makedirs(save_directory, exist_ok=True)

# Step 2: Apply ICA to All Subjects
eeg_ica_cleaned = {}  # Dictionary to store ICA-cleaned EEG for all subjects

for subject_id, raw_data in eeg_notched.items():
    print(f"\n Processing Subject: {subject_id}")

    # Work on a copy so the notch-filtered recording stays untouched
    raw = raw_data.copy()

    # Step 3: Fit ICA (fixed seed for reproducible components)
    ica = mne.preprocessing.ICA(n_components=25, method='fastica', random_state=42)
    ica.fit(raw)

    # Step 4: Detect & Remove Artifacts
    bad_ics = []
    frontal_channels = ['Fp1', 'Fp2', 'Fz', 'AFz']

    # Match frontal channels ignoring case and '.' padding. Only the first
    # subject had its names cleaned earlier, so an exact-name lookup finds
    # nothing for everyone else (variance of an empty selection is NaN and no
    # component would ever be excluded). Rows of the mixing matrix follow
    # `ica.ch_names`, so indices are taken from there.
    wanted_names = {name.lower() for name in frontal_channels}
    frontal_indices = [
        idx for idx, ch in enumerate(ica.ch_names)
        if ch.replace('.', '').lower() in wanted_names
    ]

    if frontal_indices:
        components = ica.get_components()  # (n_channels, n_components)
        # NOTE(review): this compares a variance of frontal weights with a
        # percentile of the raw weights (different quantities); the heuristic
        # is kept as written - consider ica.find_bads_eog() instead.
        weight_threshold = np.percentile(components, 95)  # Top 5% variance
        for i in range(ica.n_components_):
            component_variance = np.var(components[frontal_indices, i])
            if component_variance > weight_threshold:
                bad_ics.append(i)
    else:
        print(f" Subject {subject_id} - No frontal channels found; skipping automatic artifact detection.")

    ica.exclude = list(set(bad_ics))
    print(f" Subject {subject_id} - Automatically detected bad components: {ica.exclude}")

    # Step 5: Apply ICA & Re-reference to the common average
    raw_clean = ica.apply(raw.copy())
    raw_clean.set_eeg_reference('average', projection=False)

    # Step 6: Store Cleaned Data
    eeg_ica_cleaned[subject_id] = raw_clean

    # Step 7: Save ICA-Cleaned EEG
    save_path = os.path.join(save_directory, f"{subject_id}_cleaned_ica.fif")
    raw_clean.save(save_path, overwrite=True)
    print(f" Subject {subject_id} - ICA-Cleaned EEG saved to: {save_path}")

# Step 8: Visualize Only 1 or 2 Subjects
subjects_to_visualize = list(eeg_ica_cleaned.keys())[:2]  # Select first 2 subjects

for subject_id in subjects_to_visualize:
    print(f"\n Visualizing ICA Cleaned EEG for Subject: {subject_id}")
    eeg_ica_cleaned[subject_id].plot(n_channels=10, duration=5, scalings='auto')

print("\n ICA Applied & Saved for All 109 Subjects! Ready for Epoching.")


In [None]:
# Select ONE subject for visualization
subject_id = list(eeg_ica_cleaned.keys())[0]

# ICA-cleaned recording of that subject
raw_clean = eeg_ica_cleaned[subject_id]

# Shared PSD settings: 0.5 Hz .. Nyquist, 2-second FFT window
sfreq = raw_clean.info["sfreq"]
nyquist_freq = sfreq / 2

# PSD before ICA (notch-filtered data)
spectrum_raw = eeg_notched[subject_id].compute_psd(
    fmin=0.5, fmax=nyquist_freq,
    method="welch", n_fft=int(sfreq * 2)
)
psd_raw = spectrum_raw.get_data()

# PSD after ICA
spectrum_clean = raw_clean.compute_psd(
    fmin=0.5, fmax=nyquist_freq,
    method="welch", n_fft=int(sfreq * 2)
)
psd_clean = spectrum_clean.get_data()

# Frequency bins as actually produced by the estimator (both spectra use the
# same settings), instead of an axis guessed with linspace.
freqs = spectrum_raw.freqs

# Side-by-side 3D view: one line per channel, power in dB
fig = plt.figure(figsize=(12, 7))
ax1 = fig.add_subplot(121, projection='3d')
ax2 = fig.add_subplot(122, projection='3d')

for ax, psd_values, stage in ((ax1, psd_raw, "Before"), (ax2, psd_clean, "After")):
    for ch in range(psd_values.shape[0]):
        ax.plot(freqs, np.full_like(freqs, ch), 10 * np.log10(psd_values[ch]), lw=0.7)

    ax.set_title(f"3D PSD {stage} ICA - {subject_id}")
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("EEG Channel")
    ax.set_zlabel("Power (dB)")
    ax.set_xlim(0, nyquist_freq)
    ax.grid(True)

# Show the plot
plt.show()


In [None]:
def clean_metadata(raw, subject_id,
                   applied_steps=("bandpass-applied", "ica-applied", "notch-applied")):
    """
    Normalise the free-text description of a recording and drop faulty projectors.

    The description is treated as a whitespace-separated set of markers: it is
    de-duplicated, sorted, and merged with `applied_steps`. Projectors whose
    'desc' equals 'faulty_proj' are removed; all other projectors are kept.
    The recording is modified in place.

    Parameters:
    - raw: Recording whose `info` mapping is cleaned (modified in place)
    - subject_id: Subject identifier, used only in log messages
    - applied_steps: Markers to add to the description. The default lists all
      three pipeline steps; pass only the steps that were really applied to
      `raw` so the saved metadata stays truthful.

    Returns:
    - The same `raw` object.
    """
    print(f"\n Cleaning Metadata for {subject_id}")

    # Step 1: Remove duplicate markers (a missing/non-string description
    # becomes empty)
    description = raw.info.get('description')
    if isinstance(description, str):
        markers = set(description.split())
        raw.info['description'] = " ".join(sorted(markers))
        print(f" Cleaned Description: {raw.info['description']}")
    else:
        markers = set()
        raw.info['description'] = ""
        print(" Description was missing. Initialized as empty.")
    print(f" Final description format: {raw.info['description']}")

    # Step 2: Remove only the faulty projectors (calling del_proj() without an
    # index would delete every projector, including valid ones)
    projs = raw.info['projs'] if 'projs' in raw.info else None
    if isinstance(projs, list):
        faulty_idx = [idx for idx, proj in enumerate(projs) if proj.get('desc') == 'faulty_proj']
        if faulty_idx:
            print(f" Removing {len(faulty_idx)} faulty custom references...")
            raw.del_proj(faulty_idx)
        else:
            print(" No faulty custom references found.")
    else:
        print(" `projs` field missing or invalid format.")

    # The `custom_ref_applied` entry is deliberately left alone: it is the
    # reference bookkeeping field of `info`, and popping it unconditionally
    # leaves an incomplete structure for later referencing/saving code.

    # Step 3: Merge in the markers of the applied processing steps
    markers.update(applied_steps)
    raw.info['description'] = " ".join(sorted(markers))
    print(f" Final Description after Merge: {raw.info['description']}")

    return raw


# Run the metadata cleanup for every recording in `eeg_batch` and save the result.
# NOTE(review): `eeg_batch` holds the notch-filtered recordings (the ICA output
# lives in `eeg_ica_cleaned`), yet `clean_metadata` also stamps 'ica-applied'
# on them - confirm this is intended.
for subject_id, raw in eeg_batch.items():
    try:
        # `clean_metadata` modifies `raw` in place and returns the same object.
        raw_cleaned = clean_metadata(raw, subject_id)

        # Step 6: Write the cleaned recording back under the same key
        eeg_batch[subject_id] = raw_cleaned
        print(f" Metadata updated for `{subject_id}` in `eeg_batch`.")

        # Step 7: Save Corrected EEG Data to Google Drive (same folder as
        # `save_directory`, spelled out here as a literal path)
        save_path = f"/content/drive/My Drive/Dataset Movement 109/Preprocessed/{subject_id}_corrected-raw.fif"
        raw_cleaned.save(save_path, overwrite=True)
        print(f" Corrected EEG saved for {subject_id} at: {save_path}")

    except Exception as e:
        # Any failure (cleanup or save) is reported and the next subject is tried.
        print(f" Metadata cleanup failed for {subject_id}: {e}")

print("\n Metadata Cleanup Completed for All 109 Subjects.")


In [None]:
# Report which event codes each ICA-cleaned recording contains.
for subject_id, raw_clean in eeg_ica_cleaned.items():
    events, event_ids = mne.events_from_annotations(raw_clean)

    # Distinct integer codes assigned to the annotation labels
    unique_event_types = set(event_ids.values())

    print(f" {subject_id} - Unique Event Types: {unique_event_types}")

    # Fewer than 3 classes: warn only. Nothing is removed from
    # `eeg_ica_cleaned` here, so the message must not claim a skip.
    if len(unique_event_types) < 3:
        print(f" {subject_id} is missing some event types! (Reported only - this subject is NOT skipped.)")


In [None]:
# Step 1: Define Save Directory
epoch_save_directory = "/content/drive/My Drive/Dataset Movement 109/Epoched"
os.makedirs(epoch_save_directory, exist_ok=True)

# Step 2: Define Epoching Parameters
tmin = -0.2
min_epochs_required = 20  # Ensure we retain at least 20 epochs per subject

eeg_epochs = {}

for subject_id, raw_clean in eeg_ica_cleaned.items():
    print(f"\n Processing Subject: {subject_id}")

    # Step 3: Detect Bad Channels Using PSD & Variance
    psd = raw_clean.compute_psd(method="welch", fmin=0.5, fmax=40, n_fft=2048).get_data()
    channel_variance = np.var(raw_clean.get_data(), axis=1)

    # A channel is bad when its mean band power exceeds 3x the median across
    # channels, or its variance exceeds 3x the mean variance.
    mean_psd = np.mean(psd, axis=1)
    median_psd = np.median(mean_psd)
    mean_variance = np.mean(channel_variance)

    bad_channels = [
        raw_clean.ch_names[i] for i in range(len(psd))
        if mean_psd[i] > 3 * median_psd or channel_variance[i] > 3 * mean_variance
    ]

    if bad_channels:
        print(f" Detected bad channels for {subject_id}: {bad_channels}")
        raw_clean.info['bads'] = bad_channels

        # Repair the bad channels (modifies the recording in `eeg_ica_cleaned`)
        try:
            if raw_clean.info["dig"] is None:
                print(f" No digitization data for {subject_id}, using **weighted mean interpolation**.")

                # No electrode positions: replace every bad channel with the
                # mean of the GOOD channels only. Averaging over "all other
                # channels" would mix the other bad channels into the
                # replacement signal.
                eeg_data = raw_clean.get_data()
                bad_indices = {raw_clean.ch_names.index(bad_ch) for bad_ch in bad_channels}
                valid_neighbors = [i for i in range(len(eeg_data)) if i not in bad_indices]

                if valid_neighbors:
                    good_mean = np.mean(eeg_data[valid_neighbors], axis=0)
                    for bad_idx in bad_indices:
                        eeg_data[bad_idx] = good_mean

                    raw_clean._data = eeg_data  # Apply fixed values back to raw data

                    # Mirror `reset_bads=True` of the spline branch: repaired
                    # channels are no longer bad, so every subject keeps the
                    # same channel set downstream.
                    raw_clean.info['bads'] = []

            else:
                raw_clean.interpolate_bads(reset_bads=True, method="spline")

        except Exception as e:
            print(f" Interpolation failed for {subject_id}: {e}")
            print(" Proceeding without interpolation, but marking bad channels.")

    # Step 4: Extract Events & Ensure Proper Annotations
    events, event_ids = mne.events_from_annotations(raw_clean)

    # NOTE(review): when a recording lacks one of the codes 1/2/3, the block
    # below FABRICATES 30 evenly spaced events with rotating labels. Epochs cut
    # at these times carry arbitrary class labels - confirm this is acceptable
    # for training, or drop such subjects instead.
    expected_event_ids = {1, 2, 3}
    missing_event_ids = expected_event_ids - set(event_ids.values())

    if missing_event_ids:
        print(f" Fixing Missing Events for {subject_id}...")

        for missing_id in missing_event_ids:
            event_ids[f"T{missing_id}"] = missing_id

        # Ensure synthetic events are within EEG duration
        synthetic_times = np.linspace(0, raw_clean.times[-1] - 1, num=30)
        synthetic_events = np.array([
            [int(t * raw_clean.info["sfreq"]), 0, (i % 3) + 1]  # Rotate event types
            for i, t in enumerate(synthetic_times)
        ])
        events = np.vstack([events, synthetic_events])

    # Keep one event per sample index (first occurrence), sorted by time
    unique_events, unique_indices = np.unique(events[:, 0], return_index=True)
    events = events[unique_indices]

    # Step 5: Cap `tmax` at 1.5 s (the second term only matters for very short
    # recordings)
    recording_duration = raw_clean.times[-1]
    tmax = min(1.5, recording_duration - tmin)

    # Step 6: Adaptive rejection threshold, clamped to [600e-6, 1500e-6]
    noise_level = np.percentile(raw_clean.get_data(), 95)  # 95th percentile as noise estimate
    reject_threshold = max(600e-6, min(noise_level * 2, 1500e-6))  # Dynamic threshold

    print(f" Noise Level: {noise_level:.1e}, Using Rejection Threshold: {reject_threshold:.1e}, tmax={tmax:.1f}")

    # Step 7: Apply Epoching with Adaptive Rejection
    epochs = mne.Epochs(
        raw_clean, events, event_id=event_ids,
        tmin=tmin, tmax=tmax, baseline=(tmin, 0),
        reject=dict(eeg=reject_threshold), preload=True,
        event_repeated="merge"
    )

    # Step 8: Too few epochs survived -> retry with a 1.5x looser threshold
    if len(epochs) < min_epochs_required:
        print(f" Too Many Epochs Dropped for {subject_id}! Retrying with relaxed rejection.")
        epochs = mne.Epochs(
            raw_clean, events, event_id=event_ids,
            tmin=tmin, tmax=tmax, baseline=(tmin, 0),
            reject=dict(eeg=reject_threshold * 1.5), preload=True
        )

    # Step 9: Final Fallback - No Rejection at All
    if len(epochs) < min_epochs_required:
        print(f" WARNING: Subject {subject_id} - Still too few epochs! Using no rejection.")
        epochs = mne.Epochs(
            raw_clean, events, event_id=event_ids,
            tmin=tmin, tmax=tmax, baseline=(tmin, 0),
            reject=None, preload=True
        )

    # Step 10: Save Epoched EEG Data
    eeg_epochs[subject_id] = epochs
    save_path = os.path.join(epoch_save_directory, f"{subject_id}_epoched-epo.fif")
    epochs.save(save_path, overwrite=True)
    print(f" Subject {subject_id} - Epoched EEG saved to: {save_path}")

print("\n Epoching Completed for All 109 Subjects! Dataset is Ready for Feature Extraction & CNN Training.")


In [None]:
# Plot all channels of `epochs`. The variable is not defined in this cell: in
# top-to-bottom execution it is the object left over from the last iteration of
# the epoching loop, i.e. the last subject processed.
epochs.plot(n_channels=len(epochs.ch_names), scalings="auto");


In [None]:
# Channel names / count of the leftover `epochs` object (same hidden-state
# caveat as above: it belongs to whichever subject was processed last).
print("Channels in epochs:", epochs.info["ch_names"])
print("Total channels:", len(epochs.info["ch_names"]))


In [None]:
# Step 1: Strip the '.' padding from channel names and attach electrode
# positions for every ICA-cleaned recording (modified in place).
for subject_id, raw_clean in eeg_ica_cleaned.items():
    print(f"\n Fixing Channel Names for Subject: {subject_id}")

    # Mapping old name -> name without dots (identity for already-clean names)
    rename_mapping = {ch: ch.replace('.', '') for ch in raw_clean.ch_names}

    # Apply renaming with the mapping
    raw_clean.rename_channels(mapping=rename_mapping)

    # Verify the correction
    print(f" Fixed Channel Names: {list(rename_mapping.values())[:10]} (showing first 10)")

    # Apply the standard 10-20 montage. match_case=False is required because
    # stripping the dots leaves names such as 'Fc5' or 'Afz', which a
    # case-sensitive lookup would not match to 'FC5' / 'AFz' and would only
    # report as missing.
    raw_clean.set_montage("standard_1020", match_case=False, on_missing="warn")

    print(f" Montage applied successfully for {subject_id}")


In [None]:
# PSD of `raw_clean`. The variable is not defined in this cell: in
# top-to-bottom execution it is the loop variable left over from the
# channel-renaming cell, i.e. the last subject in `eeg_ica_cleaned`.
raw_clean.plot_psd();

In [None]:
# Largest absolute sample value of `raw`.
# NOTE(review): `raw` is not defined in this cell; in top-to-bottom execution
# it was last bound by the metadata-cleanup loop over `eeg_batch`, so this is
# the last subject's notch-filtered (not ICA-cleaned) data - confirm that this
# is the recording meant here.
print(f"EEG Data Max Value Before Scaling: {np.max(np.abs(raw.get_data())):.6f}")


In [None]:
# Sanity check before augmentation: compare the shape of the extracted EEG
# array with the channel list stored in the epochs metadata.
print(" Checking EEG Data Before Augmentation...\n")

for subject, epochs in eeg_epochs.items():
    # X has one row per epoch; X.shape[1] is the number of channels returned
    # for picks="eeg", which may differ from len(ch_names) below.
    X = epochs.get_data(picks="eeg")
    ch_names = epochs.ch_names
    sfreq = epochs.info["sfreq"]

    print(f" Subject: {subject}")
    print(f"    Total Epochs: {X.shape[0]}")
    print(f"    Channels in Data: {X.shape[1]}")
    print(f"    Expected Channels in Info: {len(ch_names)}")
    print(f"    Sampling Frequency: {sfreq} Hz")
    print("-" * 50)

print(" Data check completed! If there's a mismatch, we will correct it before augmentation.")


In [None]:
# Rebuild each epochs object so that its metadata lists exactly the EEG channels
# present in the data array.
print(" Fixing EEG Metadata Before Augmentation...\n")

for subject, epochs in eeg_epochs.items():
    # Pick the good EEG channels explicitly so data rows and names come from
    # the same index list. Taking `ch_names[:n]` would mislabel every channel
    # after the first dropped one unless the dropped channels happen to be last.
    eeg_picks = mne.pick_types(epochs.info, eeg=True, exclude="bads")
    X = epochs.get_data(picks=eeg_picks)
    actual_channels = [epochs.ch_names[idx] for idx in eeg_picks]
    sfreq = epochs.info["sfreq"]

    # Update MNE `info` to match only the remaining channels
    new_info = mne.create_info(actual_channels, sfreq, ch_types=["eeg"] * len(actual_channels))

    # Carry the class labels and epoch start time over; without them the new
    # object gets default events (one dummy class) and starts at t = 0.
    # Only codes that still occur are passed, since an event_id without
    # matching events is rejected.
    present_codes = set(epochs.events[:, 2].tolist())
    kept_event_id = {name: code for name, code in epochs.event_id.items() if code in present_codes}

    # Replace the old epochs object with the corrected one
    eeg_epochs[subject] = mne.EpochsArray(
        X, new_info,
        events=epochs.events, event_id=kept_event_id, tmin=epochs.tmin
    )

    print(f" Subject {subject}: Metadata fixed! Channels updated to {len(actual_channels)}.")
    print("-" * 50)

print(" All metadata is now correctly aligned! Ready for augmentation.")


In [None]:
#  **Step 1: Time Warping (Preserves Channel Count)**
def time_warping(X, sigma=0.2):
    """
    Randomly warp the time axis of every channel of every epoch.

    For each (epoch, channel) a standardised random-walk curve perturbs a
    normalised time grid in [0, 1]; the signal is then resampled at the
    perturbed positions by linear interpolation. The output has the same shape
    as the input. Randomness comes from the global NumPy generator.

    Parameters:
    - X: Array of shape (num_epochs, num_channels, num_samples)
    - sigma: Warp strength. The displacement of the time grid scales linearly
      with `sigma`; the default 0.2 gives a factor of 0.01 on the standardised
      curve, and 0 disables warping.

    Returns:
    - warped_X: Warped array, same shape and dtype as `X`. A channel whose
      interpolation fails or yields NaN/Inf is copied through unchanged.
    """
    num_epochs, num_channels, num_samples = X.shape
    warped_X = np.zeros_like(X)

    # Loop-invariant: the normalised time grid is the same for every channel.
    time_idx = np.linspace(0, 1, num_samples)

    # The random curve is standardised below, which cancels any scaling by
    # `sigma`. Apply `sigma` to the displacement instead, calibrated so that
    # the default (0.2) keeps the 0.01 factor.
    warp_amplitude = 0.01 * (sigma / 0.2)

    for i in range(num_epochs):
        for j in range(num_channels):
            random_curve = np.cumsum(np.random.randn(num_samples) * sigma)
            random_curve = (random_curve - np.mean(random_curve)) / (np.std(random_curve) + 1e-8)
            new_time_idx = np.clip(time_idx + random_curve * warp_amplitude, 0, 1)

            try:
                interp_func = interp1d(time_idx, X[i, j], kind='linear', fill_value="extrapolate")
                warped_X[i, j] = interp_func(new_time_idx)

                # Fall back to the original signal if resampling produced
                # non-finite values.
                if np.any(np.isnan(warped_X[i, j])) or np.any(np.isinf(warped_X[i, j])):
                    warped_X[i, j] = X[i, j]

            except ValueError:
                warped_X[i, j] = X[i, j]

    return warped_X

#  **Step 2: Augment EEG Epochs (Corrected Channel Handling)**
def augment_eeg_epochs(eeg_epochs, num_augmented=2):
    """
    Create time-warped, amplitude-scaled and time-shifted copies of each subject's epochs.

    Parameters
    ----------
    eeg_epochs : dict
        Maps subject id -> MNE epochs object.
    num_augmented : int
        Number of augmented copies of the full epoch set to generate per
        subject. Values below 1 yield an empty result.

    Returns
    -------
    dict
        Maps the SAME subject ids to an ``mne.EpochsArray`` holding
        ``num_augmented`` stacked augmented copies. Because the keys equal the
        input keys, callers that merge this with the originals must re-key one
        side to avoid overwriting.

    Raises
    ------
    ValueError
        If the channel names cannot be aligned with the data's channel count.
    """
    augmented_epochs = {}
    if num_augmented < 1:
        # np.vstack() below cannot stack an empty list.
        return augmented_epochs

    for subject, epochs in eeg_epochs.items():
        X = epochs.get_data(picks="eeg")
        sfreq = epochs.info["sfreq"]

        # The data is restricted to EEG channels, so restrict the names the
        # same way; otherwise any non-EEG channel triggers the mismatch error.
        eeg_names = [
            name
            for name, ch_type in zip(epochs.ch_names, epochs.get_channel_types())
            if ch_type == "eeg"
        ]
        ch_names = eeg_names if len(eeg_names) == X.shape[1] else list(epochs.ch_names)

        augmented_subject_data = []
        for _ in range(num_augmented):
            X_warped = time_warping(X)
            # One random gain per augmented copy (applied to all epochs/channels).
            X_scaled = X_warped * np.random.uniform(0.8, 1.2)
            # Independent shift of -3..3 samples for every (epoch, channel) signal.
            X_shifted = np.apply_along_axis(
                lambda x: shift(x, np.random.randint(-3, 4), mode="nearest"),
                axis=2,
                arr=X_scaled,
            )

            if not np.all(np.isfinite(X_shifted)):
                print(f" Warning: NaN/Inf detected in augmented data for subject {subject}. Resetting to original.")
                X_shifted = X.copy()

            augmented_subject_data.append(X_shifted)

        # Stack the copies along the epoch axis.
        augmented_data = np.vstack(augmented_subject_data)

        new_info = mne.create_info(ch_names, sfreq, ch_types=["eeg"] * len(ch_names))

        if len(new_info["ch_names"]) != augmented_data.shape[1]:
            raise ValueError(f"Channel mismatch for subject {subject}! Data has {augmented_data.shape[1]}, but info expects {len(new_info['ch_names'])}.")

        augmented_epochs[subject] = mne.EpochsArray(augmented_data, new_info)

    return augmented_epochs

#  **Step 3: Augment EEG Channels (Ensure New Subjects Are Created)**
def augment_channels(eeg_epochs, num_augmented=2):
    """
    Generate synthetic subjects via channel shuffling, random gain and Gaussian noise.

    Parameters
    ----------
    eeg_epochs : dict
        Maps subject id -> MNE epochs object.
    num_augmented : int
        Number of synthetic subjects to create per input subject.

    Returns
    -------
    dict
        Maps ``f"{subject_id}_aug{k}"`` (k = 1..num_augmented) to an
        ``mne.EpochsArray`` with the augmented data.
    """
    augmented_subjects = {}

    for subject_id, epochs in eeg_epochs.items():
        X = epochs.get_data(picks="eeg")
        sfreq = epochs.info["sfreq"]

        # Keep names consistent with the EEG-only data when possible.
        eeg_names = [
            name
            for name, ch_type in zip(epochs.ch_names, epochs.get_channel_types())
            if ch_type == "eeg"
        ]
        ch_names = eeg_names if len(eeg_names) == X.shape[1] else list(epochs.ch_names)

        # The augmentation counter needs its own name: the epoch loop below
        # previously reused `i`, so every copy was stored under the same
        # "_aug<n_epochs>" key and all but the last were overwritten.
        for aug_idx in range(num_augmented):
            aug_data = X.copy()

            # **Channel Shuffling**: permute the channel rows of each epoch in place.
            # The channel names are NOT permuted, so after this step a name no
            # longer identifies the electrode whose signal sits in that row.
            for epoch_data in aug_data:
                np.random.shuffle(epoch_data)

            # **Random Scaling**: one gain for the whole synthetic subject.
            scaling_factor = np.random.uniform(0.8, 1.2)
            aug_data *= scaling_factor

            # **Gaussian Noise Injection**
            noise = np.random.normal(0, 0.01, aug_data.shape)
            aug_data += noise

            # Unique id per synthetic subject so nothing is overwritten.
            new_subject_id = f"{subject_id}_aug{aug_idx + 1}"

            new_info = mne.create_info(ch_names, sfreq, ch_types=["eeg"] * len(ch_names))

            augmented_subjects[new_subject_id] = mne.EpochsArray(aug_data, new_info)

    return augmented_subjects

#  **Step 4: Apply Augmentation Again**
print(" Applying EEG Augmentation (Epoch + Channel)...")

# Epoch-level augmentation (time warp + gain + shift).
eeg_augmented_epochs = augment_eeg_epochs(eeg_epochs, num_augmented=2)

# Channel-level augmentation (creates new synthetic subject ids).
eeg_augmented_channels = augment_channels(eeg_epochs, num_augmented=3)

# augment_eeg_epochs() stores its output under the ORIGINAL subject ids, so a
# plain dict merge would replace every original recording with its augmented
# version. Re-key those entries so originals and augmented data coexist.
eeg_augmented_epochs_rekeyed = {
    f"{subject_id}_epochaug": augmented
    for subject_id, augmented in eeg_augmented_epochs.items()
}

#  **Merge Original + Augmented Data**
eeg_final = {**eeg_epochs, **eeg_augmented_epochs_rekeyed, **eeg_augmented_channels}

#  **Check Results Again**
print(" Checking Augmentation Results Again...\n")
num_subjects_after = len(eeg_final)
# Baseline is measured from the data instead of a hard-coded epoch count.
total_epochs_before = sum(epochs_obj.get_data().shape[0] for epochs_obj in eeg_epochs.values())
total_epochs_after = sum(epochs_obj.get_data().shape[0] for epochs_obj in eeg_final.values())

print(f" Total Subjects After Augmentation: {num_subjects_after}")
print(f" Total EEG Epochs After Augmentation: {total_epochs_after}")
print(f" Increase in Subjects: {num_subjects_after - len(eeg_epochs)} subjects")
print(f" Increase in Data: {total_epochs_after - total_epochs_before} epochs")

print("\n Augmentation Fixed! New subjects should now be correctly added.")



In [None]:
# Compare the total number of epochs before and after augmentation.
print(" Checking Data Size Before & After Augmentation...\n")

# Epoch count of a subject = first dimension of its data array.
total_epochs_before = sum(subject_epochs.get_data().shape[0] for subject_epochs in eeg_epochs.values())
total_epochs_after = sum(subject_epochs.get_data().shape[0] for subject_epochs in eeg_final.values())

# Absolute and relative growth of the dataset.
epoch_gain = total_epochs_after - total_epochs_before
epoch_gain_pct = (epoch_gain / total_epochs_before) * 100

for report_line in (
    f" Total EEG Epochs Before Augmentation: {total_epochs_before}",
    f" Total EEG Epochs After Augmentation: {total_epochs_after}",
    f" Increase in Data: {epoch_gain} epochs (+{epoch_gain_pct:.2f}%)",
):
    print(report_line)

print("\n Augmentation successfully increased the dataset size!")


In [None]:
# Summarize subject and epoch counts before vs. after augmentation.
print(" Checking EEG Dataset Before & After Augmentation...\n")

# Subject counts are simply the sizes of the two dictionaries.
num_subjects_before, num_subjects_after = len(eeg_epochs), len(eeg_final)

# Epoch totals: first dimension of each subject's data array, summed.
total_epochs_before = sum(subject_epochs.get_data().shape[0] for subject_epochs in eeg_epochs.values())
total_epochs_after = sum(subject_epochs.get_data().shape[0] for subject_epochs in eeg_final.values())

# First three subject ids of the merged dataset, used for a shape spot-check.
sample_subjects = list(eeg_final)[:3]

subject_gain = num_subjects_after - num_subjects_before
epoch_gain = total_epochs_after - total_epochs_before
epoch_gain_pct = (epoch_gain / total_epochs_before) * 100

summary_lines = [
    f" Total Subjects Before Augmentation: {num_subjects_before}",
    f" Total Subjects After Augmentation: {num_subjects_after}",
    f" Total EEG Epochs Before Augmentation: {total_epochs_before}",
    f" Total EEG Epochs After Augmentation: {total_epochs_after}",
    f" Increase in Subjects: {subject_gain} subjects",
    f" Increase in Data: {epoch_gain} epochs (+{epoch_gain_pct:.2f}%)\n",
]
for summary_line in summary_lines:
    print(summary_line)

print(" **EEG Data Shape for 3 Sample Subjects After Augmentation:**")
for subject in sample_subjects:
    X = eeg_final[subject].get_data()
    print(f" Subject {subject}: Shape {X.shape} → (Epochs, Channels, Samples)")

print("\n Augmentation successfully increased subjects and data!")


In [None]:
# Rescale the recording by 1e-6 when its peak magnitude exceeds 1e3, i.e. when
# the values look far too large to already be expressed in Volts.
# NOTE(review): `raw` is not defined in the visible part of the notebook —
# presumably it is left over from an earlier loading cell; confirm it exists
# on a fresh Restart & Run All.
peak_amplitude = np.abs(raw.get_data()).max()
needs_rescaling = peak_amplitude > 1e3

if needs_rescaling:
    # Multiply EEG channels by 1e-6 (MNE works in Volts).
    raw.apply_function(lambda samples: samples * 1e-6, picks='eeg')
    print(" Data scaled down to Volts (x1e-6 applied).")

# Flag the object as handled; this is set whether or not rescaling ran.
raw._converted_to_uv = True


In [None]:
# Sanity-check the scaling applied to `raw`.
# The data is fetched once and reused for both printouts.
raw_data = raw.get_data()
max_amplitude_volts = np.max(np.abs(raw_data))

print(f"Corrected First 5 samples of channel 1: {raw_data[0, :5]}")
# Report the peak in microvolts: a Volt-scale EEG amplitude formatted with two
# decimals always rounds to 0.00 and therefore verifies nothing.
print(f"Max Amplitude After Correction: {max_amplitude_volts * 1e6:.2f} µV (Should be ~100 µV range)")


In [None]:
import pandas as pd
def extract_time_domain_features(eeg_data_dict, output_path="/content/time_domain_features.csv"):
    """
    Extracts time-domain features from EEG epoched data.

    For every (subject, channel, epoch) combination one row is produced with
    the Mean, Variance, Skewness and Kurtosis of that epoch's samples.

    Parameters
    ----------
    eeg_data_dict : dict
        Maps subject id -> epochs object exposing ``get_data(picks="eeg")``
        (array unpacked as epochs x channels x samples) and ``ch_names``.
    output_path : str or None
        CSV file the table is written to. Pass ``None`` to skip writing.
        Defaults to the path this notebook has always used.

    Returns
    -------
    pandas.DataFrame
        Columns: Subject, Epoch (1-based), Channel, Mean, Variance, Skewness,
        Kurtosis. Rows are ordered subject -> channel -> epoch. An empty input
        yields an empty frame that still carries these columns.
    """
    feature_columns = ["Subject", "Epoch", "Channel", "Mean", "Variance", "Skewness", "Kurtosis"]
    subject_frames = []

    def _channel_major(per_epoch_channel):
        # (n_epochs, n_channels) -> flat vector ordered channel by channel.
        return np.asarray(per_epoch_channel).T.ravel()

    for subject_id, epochs in eeg_data_dict.items():
        print(f"\n Extracting Time-Domain Features for {subject_id}...")

        # Get EEG data (Shape: [n_epochs, n_channels, n_samples])
        eeg_data = epochs.get_data(picks="eeg")
        num_epochs, num_channels = eeg_data.shape[0], eeg_data.shape[1]
        # Only as many names as there are data channels.
        ch_names = np.asarray(list(epochs.ch_names[:num_channels]), dtype=object)
        num_rows = num_epochs * num_channels

        # All statistics are reductions over the sample axis, computed for every
        # channel at once instead of looping channel by channel.
        subject_frames.append(pd.DataFrame(
            {
                "Subject": np.full(num_rows, subject_id, dtype=object),
                "Epoch": np.tile(np.arange(1, num_epochs + 1), num_channels),
                "Channel": np.repeat(ch_names, num_epochs),
                "Mean": _channel_major(np.mean(eeg_data, axis=2)),
                "Variance": _channel_major(np.var(eeg_data, axis=2)),
                "Skewness": _channel_major(stats.skew(eeg_data, axis=2)),
                "Kurtosis": _channel_major(stats.kurtosis(eeg_data, axis=2)),
            },
            columns=feature_columns,
        ))

    # pd.concat() rejects an empty list, so keep the schema explicitly in that case.
    if subject_frames:
        feature_df = pd.concat(subject_frames, ignore_index=True)
    else:
        feature_df = pd.DataFrame(columns=feature_columns)

    if output_path is not None:
        feature_df.to_csv(output_path, index=False)
        print("\n Time-Domain Features saved!")

    return feature_df

# Run time-domain feature extraction on `eeg_final`, the merged dictionary of
# original and augmented subjects built in the augmentation cell.
# Side effect: the function also writes its result to a CSV file under /content.
time_domain_features = extract_time_domain_features(eeg_final)


In [None]:
# Reload the saved time-domain feature table and visualize its distributions.
file_path = "/content/time_domain_features.csv"

try:
    time_domain_features = pd.read_csv(file_path)

    # Columns shared by every plot in this cell.
    feature_columns = ['Mean', 'Variance', 'Skewness', 'Kurtosis']

    # 1) Box plots of all four features on one axis.
    plt.figure(figsize=(14, 6))
    sns.boxplot(data=time_domain_features[feature_columns])
    plt.grid(True)
    plt.xlabel("Feature Type")
    plt.ylabel("Value")
    plt.title("Distribution of Time-Domain Features Across All Subjects and Channels")
    plt.show()

    # 2) One histogram (with KDE overlay) per feature.
    for feature in feature_columns:
        plt.figure(figsize=(10, 5))
        sns.histplot(time_domain_features[feature], bins=50, kde=True)
        plt.grid(True)
        plt.xlabel(feature)
        plt.ylabel("Frequency")
        plt.title(f"Histogram of {feature}")
        plt.show()

    # 3) Pairwise correlation between the features.
    feature_correlations = time_domain_features[feature_columns].corr()
    plt.figure(figsize=(8, 6))
    sns.heatmap(feature_correlations, annot=True, cmap="coolwarm", fmt=".2f")
    plt.title("Feature Correlation Heatmap")
    plt.show()

except FileNotFoundError:
    # The extraction cell has not been run (or the file was removed).
    print("Error: The feature extraction CSV file was not found. Please upload the file and try again.")


In [None]:
def extract_frequency_domain_features(eeg_data_dict, output_path="/content/frequency_domain_features.csv"):
    """
    Extracts frequency-domain features from EEG epoched data.

    For every (subject, channel, epoch) the Welch power spectral density is
    computed between 0.5 and 40 Hz, and summarized as its mean plus the
    integrated power in the Delta, Theta, Alpha, Beta and Gamma bands.

    Parameters
    ----------
    eeg_data_dict : dict
        Maps subject id -> MNE epochs object.
    output_path : str or None
        CSV file the table is written to. Pass ``None`` to skip writing.
        Defaults to the path this notebook has always used.

    Returns
    -------
    pandas.DataFrame
        Columns: Subject, Epoch (1-based), Channel, PSD_Mean, Delta_Power,
        Theta_Power, Alpha_Power, Beta_Power, Gamma_Power.
    """
    feature_list = []
    # Band edges in Hz. Both edges are inclusive, so a frequency bin that falls
    # exactly on a shared edge contributes to both neighbouring bands.
    frequency_bands = {
        "Delta": (0.5, 4),
        "Theta": (4, 8),
        "Alpha": (8, 12),
        "Beta": (12, 30),
        "Gamma": (30, 40)
    }

    for subject_id, epochs in eeg_data_dict.items():
        print(f"\n Extracting Frequency-Domain Features for {subject_id}...")

        # Get EEG data and sampling frequency
        eeg_data = epochs.get_data(picks="eeg")  # Shape: (n_epochs, n_channels, n_samples)
        sfreq = epochs.info["sfreq"]
        num_channels = eeg_data.shape[1]
        num_samples = eeg_data.shape[2]

        # The FFT window can never be longer than the epoch itself.
        n_fft = min(256, num_samples)

        for ch_idx in range(num_channels):
            ch_name = epochs.ch_names[ch_idx]
            channel_data = eeg_data[:, ch_idx, :]

            # Compute Power Spectral Density (PSD) using Welch's method
            psd, freqs = mne.time_frequency.psd_array_welch(
                channel_data, sfreq=sfreq, fmin=0.5, fmax=40, n_fft=n_fft, n_per_seg=n_fft
            )

            # Integrate the PSD over each band. scipy's trapezoid replaces
            # np.trapz (deprecated since NumPy 2.0); integrating over the real
            # frequency grid also avoids a NaN step when only one bin exists.
            band_power = {}
            for band, (low, high) in frequency_bands.items():
                idx_band = np.logical_and(freqs >= low, freqs <= high)
                band_power[band] = integrate.trapezoid(psd[:, idx_band], x=freqs[idx_band], axis=1)

            # Mean PSD over all returned frequencies, one value per epoch.
            psd_mean = psd.mean(axis=1)

            for epoch_idx in range(psd.shape[0]):
                feature_list.append({
                    "Subject": subject_id,
                    "Epoch": epoch_idx + 1,
                    "Channel": ch_name,
                    "PSD_Mean": psd_mean[epoch_idx],
                    "Delta_Power": band_power["Delta"][epoch_idx],
                    "Theta_Power": band_power["Theta"][epoch_idx],
                    "Alpha_Power": band_power["Alpha"][epoch_idx],
                    "Beta_Power": band_power["Beta"][epoch_idx],
                    "Gamma_Power": band_power["Gamma"][epoch_idx],
                })

    # Convert to DataFrame and (optionally) save
    feature_df = pd.DataFrame(feature_list)
    if output_path is not None:
        feature_df.to_csv(output_path, index=False)
        print("\n Frequency-Domain Features saved!")

    return feature_df

# Run frequency-domain feature extraction on `eeg_final` (original + augmented
# subjects). The Welch FFT length is chosen per subject inside the function.
# Side effect: the function also writes its result to a CSV file under /content.
frequency_domain_features = extract_frequency_domain_features(eeg_final)


In [None]:
# Reload the saved frequency-domain feature table and visualize its distributions.
file_path = "/content/frequency_domain_features.csv"

try:
    frequency_domain_features = pd.read_csv(file_path)

    # Columns shared by every plot in this cell.
    feature_columns = ['PSD_Mean', 'Delta_Power', 'Theta_Power', 'Alpha_Power', 'Beta_Power', 'Gamma_Power']

    # 1) Box plots of all spectral features on one axis.
    plt.figure(figsize=(14, 6))
    sns.boxplot(data=frequency_domain_features[feature_columns])
    plt.grid(True)
    plt.xlabel("Feature Type")
    plt.ylabel("Value")
    plt.title("Distribution of Frequency-Domain Features Across All Subjects and Channels")
    plt.show()

    # 2) One histogram (with KDE overlay) per feature.
    for feature in feature_columns:
        plt.figure(figsize=(10, 5))
        sns.histplot(frequency_domain_features[feature], bins=50, kde=True)
        plt.grid(True)
        plt.xlabel(feature)
        plt.ylabel("Frequency")
        plt.title(f"Histogram of {feature}")
        plt.show()

    # 3) Pairwise correlation between the features.
    feature_correlations = frequency_domain_features[feature_columns].corr()
    plt.figure(figsize=(8, 6))
    sns.heatmap(feature_correlations, annot=True, cmap="coolwarm", fmt=".2f")
    plt.title("Feature Correlation Heatmap")
    plt.show()

except FileNotFoundError:
    # The extraction cell has not been run (or the file was removed).
    print("Error: The feature extraction CSV file was not found. Please upload the file and try again.")


In [None]:
def extract_wavelet_domain_features(
    eeg_data_dict,
    wavelet="db4",
    decomposition_level=4,
    output_path="/content/wavelet_domain_features.csv",
):
    """
    Extracts wavelet-domain features from EEG epoched data.

    Each epoch of each channel is decomposed with a discrete wavelet
    transform; all coefficients are concatenated and summarized as
    Wavelet_Energy (sum of squares) and Wavelet_Entropy (Shannon entropy of
    the normalized absolute coefficients).

    Parameters
    ----------
    eeg_data_dict : dict
        Maps subject id -> MNE epochs object.
    wavelet : str
        Wavelet name passed to ``pywt.wavedec`` (default: Daubechies "db4").
    decomposition_level : int
        Number of decomposition levels passed to ``pywt.wavedec``.
    output_path : str or None
        CSV file the table is written to. Pass ``None`` to skip writing.
        Defaults to the path this notebook has always used.

    Returns
    -------
    pandas.DataFrame
        Columns: Subject, Epoch (1-based), Channel, Wavelet_Energy,
        Wavelet_Entropy.
    """
    feature_list = []

    for subject_id, epochs in eeg_data_dict.items():
        print(f"\n Extracting Wavelet-Domain Features for {subject_id}...")

        # Get EEG data
        eeg_data = epochs.get_data(picks="eeg")  # Shape: (n_epochs, n_channels, n_samples)
        num_channels = eeg_data.shape[1]

        for ch_idx in range(num_channels):
            ch_name = epochs.ch_names[ch_idx]
            channel_data = eeg_data[:, ch_idx, :]  # Shape: (n_epochs, n_samples)

            for epoch_idx in range(channel_data.shape[0]):
                coeffs = pywt.wavedec(channel_data[epoch_idx], wavelet, level=decomposition_level)
                coeffs_flattened = np.concatenate(coeffs)

                # Energy: sum of squared coefficients.
                energy = np.sum(coeffs_flattened ** 2)

                # Entropy of the normalized coefficient magnitudes. An all-zero
                # epoch has no distribution to normalize (0/0 -> NaN), so its
                # entropy is reported as 0.0 instead of propagating NaN.
                coeff_magnitudes = np.abs(coeffs_flattened)
                magnitude_total = np.sum(coeff_magnitudes)
                if magnitude_total > 0:
                    wavelet_entropy = entropy(coeff_magnitudes / magnitude_total)
                else:
                    wavelet_entropy = 0.0

                feature_list.append({
                    "Subject": subject_id,
                    "Epoch": epoch_idx + 1,
                    "Channel": ch_name,
                    "Wavelet_Energy": energy,
                    "Wavelet_Entropy": wavelet_entropy,
                })

    # Convert to DataFrame and (optionally) save
    feature_df = pd.DataFrame(feature_list)
    if output_path is not None:
        feature_df.to_csv(output_path, index=False)
        print("\n Wavelet-Domain Features saved!")

    return feature_df

# Run wavelet-domain feature extraction on `eeg_final` (original + augmented
# subjects).
# Side effect: the function also writes its result to a CSV file under /content.
wavelet_domain_features = extract_wavelet_domain_features(eeg_final)


In [None]:
# Reload the saved wavelet-domain feature table and visualize its distributions.
file_path = "/content/wavelet_domain_features.csv"

try:
    wavelet_domain_features = pd.read_csv(file_path)

    # Columns shared by every plot in this cell.
    feature_columns = ['Wavelet_Energy', 'Wavelet_Entropy']

    # 1) Box plots of both features on one axis.
    plt.figure(figsize=(14, 6))
    sns.boxplot(data=wavelet_domain_features[feature_columns])
    plt.grid(True)
    plt.xlabel("Feature Type")
    plt.ylabel("Value")
    plt.title("Distribution of Wavelet-Domain Features Across All Subjects and Channels")
    plt.show()

    # 2) One histogram (with KDE overlay) per feature.
    for feature in feature_columns:
        plt.figure(figsize=(10, 5))
        sns.histplot(wavelet_domain_features[feature], bins=50, kde=True)
        plt.grid(True)
        plt.xlabel(feature)
        plt.ylabel("Frequency")
        plt.title(f"Histogram of {feature}")
        plt.show()

    # 3) Pairwise correlation between the features.
    feature_correlations = wavelet_domain_features[feature_columns].corr()
    plt.figure(figsize=(8, 6))
    sns.heatmap(feature_correlations, annot=True, cmap="coolwarm", fmt=".2f")
    plt.title("Feature Correlation Heatmap")
    plt.show()

except FileNotFoundError:
    # The extraction cell has not been run (or the file was removed).
    print("Error: The feature extraction CSV file was not found. Please upload the file and try again.")


In [None]:
# Debugging aid: os.listdir() without an argument lists the kernel's current
# working directory. The feature CSVs are written under /content, so they only
# appear here if that is the working directory.
print(os.listdir())  # List all files in the current directory


In [None]:
# Read the three per-domain feature tables back from disk.
time_features = pd.read_csv("/content/time_domain_features.csv")
frequency_features = pd.read_csv("/content/frequency_domain_features.csv")
wavelet_features = pd.read_csv("/content/wavelet_domain_features.csv")

# A row is identified by (Subject, Epoch, Channel); inner joins keep only the
# rows that are present in all three tables.
merge_keys = ["Subject", "Epoch", "Channel"]
merged_features = (
    time_features
    .merge(frequency_features, on=merge_keys, how="inner")
    .merge(wavelet_features, on=merge_keys, how="inner")
)

# Persist the combined table for the modelling cells.
merged_features.to_csv("/content/final_combined_features.csv", index=False)
print(f"\n Combined Feature Dataset Saved! Shape: {merged_features.shape}")


In [None]:
# Reload the combined feature table and plot its numeric columns.
file_path = "/content/final_combined_features.csv"

try:
    combined_features = pd.read_csv(file_path)

    # Everything except the three identifier columns is a feature.
    numeric_features = combined_features.drop(columns=["Subject", "Epoch", "Channel"])

    # Box plots of every feature on one shared axis.
    plt.figure(figsize=(14, 6))
    sns.boxplot(data=numeric_features)
    plt.grid(True)
    plt.xlabel("Feature Type")
    plt.ylabel("Value")
    plt.title("Distribution of Combined EEG Features")
    plt.show()

    # Pairwise correlation between all features.
    plt.figure(figsize=(10, 8))
    sns.heatmap(numeric_features.corr(), annot=True, cmap="coolwarm", fmt=".2f")
    plt.title("Correlation Heatmap of Combined Features")
    plt.show()

except FileNotFoundError:
    # The merge cell has not been run (or the file was removed).
    print("Error: The final combined feature dataset was not found. Please check if the merging step was completed.")


In [None]:
# Quick textual inspection of the combined feature table.
file_path = "/content/final_combined_features.csv"
df = pd.read_csv(file_path)

# Each section is a heading plus a callable producing the report, so the
# reports are computed and printed in the order listed.
overview_sections = (
    ("\n Dataset Overview:", df.head),
    ("\n Missing Values Check:", lambda: df.isnull().sum()),
    ("\n Feature Statistics:", df.describe),
)
for heading, build_report in overview_sections:
    print(heading)
    print(build_report())


In [None]:
# Baseline classifiers (Random Forest, linear SVM) on the combined features.
file_path = "/content/final_combined_features.csv"
df_full = pd.read_csv(file_path)

#  Extract Labels (Modify this based on your classification target)
# Label = 1 when alpha-band power exceeds theta-band power, else 0.
df_full["Label"] = (df_full["Alpha_Power"] > df_full["Theta_Power"]).astype(int)

# The label is a deterministic function of these two columns. Leaving them in
# the feature matrix lets any model reproduce the label directly, so the
# reported accuracy would measure nothing. They are removed from the inputs.
label_source_columns = ["Alpha_Power", "Theta_Power"]

#  Drop identifiers, the label itself and the label's source columns
X_full = df_full.drop(columns=["Subject", "Epoch", "Channel", "Label"] + label_source_columns)
y_full = df_full["Label"]

#  Feature Selection: drop one feature of every pair with |correlation| > 0.95
corr_matrix = X_full.corr().abs()
# Keep only the strict upper triangle so each pair is inspected once.
upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
high_corr_features = [column for column in upper.columns if (upper[column] > 0.95).any()]
X_full = X_full.drop(columns=high_corr_features)

#  Convert Data to Float32 for Memory Optimization
X_full = X_full.astype(np.float32)

#  Train-Test Split (20% test data, class proportions preserved)
# NOTE(review): rows are split at random, so rows of the same subject (and of
# its augmented copies) can land on both sides; a subject-grouped split would
# give a more honest estimate.
X_train, X_test, y_train, y_test = train_test_split(
    X_full, y_full, test_size=0.2, random_state=42, stratify=y_full
)

#  Train Random Forest
rf_optimized = RandomForestClassifier(n_estimators=100, n_jobs=-1, random_state=42)
rf_optimized.fit(X_train, y_train)
y_pred_rf = rf_optimized.predict(X_test)
acc_rf = accuracy_score(y_test, y_pred_rf)

#  Feature Importance Plot (up to the 10 most important features)
feature_importances = pd.Series(rf_optimized.feature_importances_, index=X_full.columns)
top_features = feature_importances.nlargest(10)

plt.figure(figsize=(12, 6))
sns.barplot(x=top_features, y=top_features.index)
plt.title("Top 10 Important Features (Random Forest)")
plt.xlabel("Feature Importance Score")
plt.ylabel("Feature Name")
plt.show()

#  Train SVM (using the important features only)
top_important_features = top_features.index.tolist()
X_train_selected = X_train[top_important_features]
X_test_selected = X_test[top_important_features]

# A linear SVM is sensitive to feature scale; standardize with statistics
# learned from the training rows only.
svm_scaler = StandardScaler()
X_train_svm = svm_scaler.fit_transform(X_train_selected)
X_test_svm = svm_scaler.transform(X_test_selected)

svm_optimized = SVC(kernel="linear", random_state=42)
svm_optimized.fit(X_train_svm, y_train)
y_pred_svm = svm_optimized.predict(X_test_svm)
acc_svm = accuracy_score(y_test, y_pred_svm)

#  Print Results
print(f" **Optimized RF Accuracy:** {acc_rf:.4f}")
print(f" **Optimized SVM Accuracy:** {acc_svm:.4f}")


In [None]:
# Prepare the combined feature table as input for the 1-D CNN.
file_path = "/content/final_combined_features.csv"
df_cnn = pd.read_csv(file_path)

#  Labels: tertile (0/1/2) of Alpha_Power, as a 1-D integer array
df_cnn["Label"] = pd.qcut(df_cnn["Alpha_Power"], q=3, labels=[0, 1, 2]).astype(int)
y = df_cnn["Label"].values

#  Drop identifier columns and the label
# NOTE(review): Alpha_Power, the column the label is derived from, is still an
# input feature, so the task is solvable from that single column. It is kept
# here only because later cells hard-code twelve feature names; consider a
# different target or removing the column everywhere.
X = df_cnn.drop(columns=["Subject", "Epoch", "Channel", "Label"])

#  Feature groups, in the column order of the combined table
time_features = ["Mean", "Variance", "Skewness", "Kurtosis"]
frequency_features = ["PSD_Mean", "Delta_Power", "Theta_Power", "Alpha_Power", "Beta_Power", "Gamma_Power"]
wavelet_features = ["Wavelet_Energy", "Wavelet_Entropy"]

#  Choose the train/test rows FIRST (80-20, stratified by label) so that the
#  scaler below can be fitted without seeing any test row.
train_idx, test_idx = train_test_split(
    np.arange(len(y)), test_size=0.2, random_state=42, stratify=y
)

#  Feature Scaling: statistics from the training rows only, applied to all rows
scaler = StandardScaler()
scaler.fit(X.iloc[train_idx])
X_scaled = scaler.transform(X)

#  Separate Feature Types (Time, Frequency, Wavelet) by column position
X_time = X_scaled[:, :len(time_features)]
X_freq = X_scaled[:, len(time_features): len(time_features) + len(frequency_features)]
X_wavelet = X_scaled[:, len(time_features) + len(frequency_features):]

#  Merge All Features into One Single Input (Concatenation)
X_merged = np.concatenate([X_time, X_freq, X_wavelet], axis=1).astype(np.float32)

X_train, X_test = X_merged[train_idx], X_merged[test_idx]
y_train, y_test = y[train_idx], y[test_idx]

#  Compute Class Weights (after splitting), keyed by the actual class labels
#  rather than by position in the weight array.
train_classes = np.unique(y_train)
class_weights = compute_class_weight(
    class_weight="balanced", classes=train_classes, y=y_train
)
class_weight_dict = {int(label): float(weight) for label, weight in zip(train_classes, class_weights)}
print(f" Class Weights: {class_weight_dict}")

#  Reshape for CNN: add a trailing channel axis -> (rows, features, 1)
X_train = np.expand_dims(X_train, axis=-1)
X_test = np.expand_dims(X_test, axis=-1)

print(f" CNN Input Shape (Merged): {X_train.shape}")
print(f" Labels Shape: {y_train.shape}")


In [None]:
# Residual (skip-connection) building block for the 1-D CNN.
def residual_block(x, filters):
    """
    Two Conv1D + BatchNormalization stages with a skip connection.

    The input is added back to the output of the second stage; when its last
    dimension differs from ``filters`` it is first projected with a
    kernel-size-1 convolution so the two tensors can be added. A ReLU follows
    the first stage and the addition.
    """
    conv_options = {"kernel_size": 3, "padding": "same"}
    skip_path = x

    # Stage 1: convolution -> batch norm -> ReLU
    main_path = Conv1D(filters, **conv_options)(x)
    main_path = BatchNormalization()(main_path)
    main_path = Activation("relu")(main_path)

    # Stage 2: convolution -> batch norm (activation is applied after the add)
    main_path = Conv1D(filters, **conv_options)(main_path)
    main_path = BatchNormalization()(main_path)

    # Match the channel count of the skip path to the main path if needed.
    needs_projection = skip_path.shape[-1] != filters
    if needs_projection:
        skip_path = Conv1D(filters, kernel_size=1, padding="same")(skip_path)

    merged = Add()([main_path, skip_path])
    return Activation("relu")(merged)

#  Define Optimized Multi-Scale CNN
def build_multi_scale_cnn(input_shape, num_classes=3):
    """
    Build the residual 1-D CNN classifier (uncompiled).

    Architecture: a Conv1D(64) stem, then residual blocks with 128 and 256
    filters, each stage followed by MaxPooling1D(pool_size=2); global average
    pooling; Dense(128) with dropout; softmax output.

    Parameters
    ----------
    input_shape : tuple
        Shape of one sample, passed to ``Input(shape=...)``.
    num_classes : int
        Number of softmax output units. Defaults to 3, the value that was
        previously hard-coded.

    Returns
    -------
    Model
        The assembled functional model.
    """
    inputs = Input(shape=input_shape)

    #  Stem
    x = Conv1D(64, kernel_size=3, padding="same")(inputs)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = MaxPooling1D(pool_size=2)(x)

    #  Residual stages; the input passes through three pool_size=2 stages in
    #  total, so very short inputs will not survive all of them.
    x = residual_block(x, 128)
    x = MaxPooling1D(pool_size=2)(x)

    x = residual_block(x, 256)
    x = MaxPooling1D(pool_size=2)(x)

    #  Global Average Pooling Instead of Flatten()
    x = GlobalAveragePooling1D()(x)

    #  Fully Connected Layers
    x = Dense(128, activation="relu")(x)
    x = Dropout(0.3)(x)  # Dropout to prevent overfitting
    outputs = Dense(num_classes, activation="softmax")(x)  # Multi-class softmax

    model = Model(inputs, outputs)
    return model

#  Compute Class Weights Using 1D Labels, keyed by the actual class labels
train_classes = np.unique(y_train)
class_weights = compute_class_weight(
    class_weight="balanced",
    classes=train_classes,
    y=y_train
)
class_weight_dict = {int(label): float(weight) for label, weight in zip(train_classes, class_weights)}
print(f" Class Weights: {class_weight_dict}")

#  Build Model
input_shape = (X_train.shape[1], 1)
model = build_multi_scale_cnn(input_shape)

#  Compile with Adam & Sparse Categorical Crossentropy
# The `decay` keyword is not passed: current Keras optimizers reject or ignore
# it, and the learning rate is already managed by ReduceLROnPlateau below.
optimizer = Adam(learning_rate=0.0005)
model.compile(optimizer=optimizer, loss=SparseCategoricalCrossentropy(), metrics=["accuracy"])

#  Callbacks for Stability
callbacks = [
    ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=3, verbose=1),
    EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True)
]

#  Reshape Input for CNN
# The data-preparation cell already adds the trailing channel axis; expanding
# again would create 4-D input for a model that was built for 3-D input.
X_train_cnn = X_train if X_train.ndim == 3 else np.expand_dims(X_train, axis=-1)
X_test_cnn = X_test if X_test.ndim == 3 else np.expand_dims(X_test, axis=-1)

#  Hold out part of the TRAINING data for validation. Early stopping and the
#  learning-rate schedule both react to validation loss, so validating on the
#  test set would tune the model on the data used for the final score.
X_fit, X_val, y_fit, y_val = train_test_split(
    X_train_cnn, y_train, test_size=0.1, random_state=42, stratify=y_train
)

#  Train Model with Class Weights
history = model.fit(
    X_fit, y_fit,
    epochs=20, batch_size=128,
    validation_data=(X_val, y_val),
    class_weight=class_weight_dict,
    callbacks=callbacks
)

#  Evaluate Model on the untouched test set
test_loss, test_acc = model.evaluate(X_test_cnn, y_test)
print(f" **CNN Final Test Accuracy:** {test_acc:.4f}")


In [None]:
# Training-curve visualization.
def plot_training_history(history):
    """
    Plot training vs. validation loss and accuracy side by side.

    ``history`` must expose a ``history`` mapping with the keys 'loss',
    'val_loss', 'accuracy' and 'val_accuracy'.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # (training key, validation key, display name) for each panel, left to right.
    panels = (
        ("loss", "val_loss", "Loss"),
        ("accuracy", "val_accuracy", "Accuracy"),
    )

    for axis, (train_key, val_key, metric_name) in zip(axes, panels):
        axis.plot(history.history[train_key], label=f"Train {metric_name}", color='blue')
        axis.plot(history.history[val_key], label=f"Validation {metric_name}", color='red')
        axis.set_title(f"Model {metric_name} Over Epochs")
        axis.set_xlabel("Epochs")
        axis.set_ylabel(metric_name)
        axis.legend()

    plt.show()

# Confusion-matrix visualization.
def plot_confusion_matrix(y_true, y_pred, class_labels):
    """
    Draw the confusion matrix of ``y_true`` vs. ``y_pred`` as an annotated heatmap.

    ``class_labels`` is used for the tick labels on both axes.
    """
    counts = confusion_matrix(y_true, y_pred)

    heatmap_options = {
        "annot": True,
        "fmt": "d",  # integer counts
        "cmap": "Blues",
        "xticklabels": class_labels,
        "yticklabels": class_labels,
    }

    plt.figure(figsize=(6, 5))
    sns.heatmap(counts, **heatmap_options)
    plt.title("Confusion Matrix")
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.show()

# Predicted class = index of the largest softmax output for each test sample.
class_probabilities = model.predict(X_test_cnn)
y_pred = class_probabilities.argmax(axis=1)

# Per-class precision / recall / F1, four decimals.
print("\n **CNN Classification Report:**")
cnn_report = classification_report(y_test, y_pred, digits=4)
print(cnn_report)

# Visual diagnostics: learning curves, then the confusion matrix.
plot_training_history(history)
plot_confusion_matrix(y_test, y_pred, class_labels=[0, 1, 2])


In [None]:
# SHAP explanation of the CNN on the first 500 test samples.

# Display names for the features, in the column order used to build the input.
feature_names = ["Mean", "Variance", "Skewness", "Kurtosis",
                 "PSD_Mean", "Delta_Power", "Theta_Power",
                 "Alpha_Power", "Beta_Power", "Gamma_Power",
                 "Wavelet_Energy", "Wavelet_Entropy"]

# Drop singleton axes so each row is a flat feature vector.
X_test_cnn_fixed = np.squeeze(X_test_cnn[:500])

# The same 500 rows serve as background data and as the samples to explain.
explainer = shap.Explainer(model, X_test_cnn_fixed)
shap_values = explainer(X_test_cnn_fixed)

shap_array_shape = np.array(shap_values.values).shape
print(f" SHAP Values Shape: {shap_array_shape}")

# Last axis of the SHAP array is taken as the class axis.
num_classes = shap_values.values.shape[-1]
has_class_axis = len(shap_values.values.shape) == 3

if not has_class_axis:
    print(" SHAP Value Shape Mismatch! Check Data Input Dimensions.")
else:
    # One summary plot per output class.
    for class_idx in range(num_classes):
        print(f"\n **SHAP Summary for Class {class_idx}**")
        plt.figure(figsize=(10, 6))

        # Slice out the attributions that belong to this class.
        shap_class_values = shap_values.values[..., class_idx]

        shap.summary_plot(shap_class_values, X_test_cnn_fixed, feature_names=feature_names)
        plt.show()
