# Automated EEG Diagnosis Pipeline: Data Unification & Preprocessing

This notebook consolidates the entire data processing workflow (Phases 0 through 2) for the Automated Multiclass Diagnosis of Neuropsychiatric Disorders project. It is designed to be run in a Kaggle environment to leverage more powerful CPU resources for the computationally intensive steps.

Workflow:
1.  Setup: Install libraries and define Kaggle-specific paths.
2.  Phase 0.1: Metadata Unification: Scan raw dataset directories and create a single master_metadata.csv.
3.  Phase 0.2: Data Validation: Perform a quality control check on all raw files.
4.  Phase 1.1: Data Harmonization: Standardize all recordings to a common format (200Hz, 19 channels, Eyes-Closed).
5.  Phase 1.2: Preprocessing: Apply filtering, automated ICA artifact removal, and epoching to create clean, analysis-ready data.
6. Phase 2: Graph Construction: Transform preprocessed epochs into graph structures for the GNN.

## 1. Setup and Environment

In [None]:
# Install all required packages
!pip install mne==1.6.1 mne-icalabel==0.6.0 ewtpy==1.0 mne-connectivity==0.7.0 torch-geometric==2.5.3

# Import all necessary libraries for the entire notebook
import pandas as pd
import json
from pathlib import Path
from tqdm.notebook import tqdm # Use notebook-friendly tqdm
import uuid
import sys
import mne
from mne.io import Raw
from mne_icalabel import label_components
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.utils import dense_to_sparse
import ewtpy
from scipy.stats import entropy
import mne_connectivity
import argparse
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Tuple, Optional

In [None]:
# --- Kaggle Path Configuration ---
# NOTE: attach your Kaggle Dataset to this notebook first, then point
# KAGGLE_INPUT_DIR at the root of that dataset.
KAGGLE_INPUT_DIR = Path('/kaggle/input/your-dataset-name')
KAGGLE_WORKING_DIR = Path('/kaggle/working/')

# Every other location is derived from the two Kaggle roots above.
BASE_DATA_PATH = KAGGLE_INPUT_DIR
OUTPUT_DIR = KAGGLE_WORKING_DIR

# Phase 0 artefacts: unified metadata and the validation report.
MASTER_METADATA_PATH = OUTPUT_DIR.joinpath('master_metadata.csv')
VALIDATION_ERRORS_PATH = OUTPUT_DIR.joinpath('validation_errors.csv')

# Phase 1 artefacts: harmonized recordings, cleaned epochs, ICA figures.
HARMONIZED_DIR = OUTPUT_DIR.joinpath('processed', 'harmonized')
HARMONIZED_METADATA_PATH = OUTPUT_DIR.joinpath('harmonized_metadata.csv')
PREPROCESSED_DIR = OUTPUT_DIR.joinpath('processed', 'preprocessed_epochs')
PREPROCESSED_METADATA_PATH = OUTPUT_DIR.joinpath('preprocessed_metadata.csv')
SANITY_CHECK_DIR = OUTPUT_DIR.joinpath('results', 'figures', 'ica_sanity_checks')

# Phase 2 artefacts: per-epoch graphs, their index, and the label encoding.
GRAPH_DIR = OUTPUT_DIR.joinpath('processed', 'graphs')
GRAPH_METADATA_PATH = OUTPUT_DIR.joinpath('graph_metadata.csv')
LABEL_MAP_PATH = OUTPUT_DIR.joinpath('label_mapping.json')

# Make sure every output directory exists before any phase runs.
for _output_dir in (HARMONIZED_DIR, PREPROCESSED_DIR, SANITY_CHECK_DIR, GRAPH_DIR):
    _output_dir.mkdir(parents=True, exist_ok=True)

## Phase 0.1: Metadata Unification (create_metadata.py)

In [None]:
def create_metadata_main():
    """Build ``master_metadata.csv`` from the three raw EEG datasets.

    Scans the OpenNeuro ds004504, CAUEEG and Figshare MDD directories under
    ``BASE_DATA_PATH``, maps every usable recording to a unified diagnosis
    label, assigns a freshly generated random UUID to each record and writes
    the combined table to ``MASTER_METADATA_PATH``.

    Returns:
        None. A summary is printed; nothing is written when no record is found.
    """
    # --- Logic from create_metadata.py ---
    # ds004504 group codes -> unified diagnosis labels (hoisted: loop-invariant).
    openneuro_labels = {'C': 'CN', 'A': 'AD', 'F': 'FTD'}
    # CAUEEG symptom tags, checked in this priority order (AD > MCI > CN).
    caueeg_ad_tags = ('dementia', 'ad', 'load')
    caueeg_mci_tags = ('mci', 'mci_amnestic', 'mci_amnestic_rf')
    caueeg_cn_tags = ('normal', 'cb_normal')

    def process_openneuro_ds004504():
        """Return one record per labelled participant with an eyes-closed EDF."""
        dataset_path = BASE_DATA_PATH / 'ds004504'
        participants_path = dataset_path / 'participants.tsv'
        records = []
        print('Processing OpenNeuro dataset...')
        if not participants_path.exists():
            print(f'  [ERROR] participants.tsv not found at {participants_path}')
            return []
        participants_df = pd.read_csv(participants_path, sep='\t')
        n_missing = 0
        for _, row in tqdm(participants_df.iterrows(), total=len(participants_df), desc='  -> OpenNeuro'):
            subject_id = row['participant_id']
            label = openneuro_labels.get(row['Group'])
            if not label:
                continue
            eeg_dir = dataset_path / 'raw_data-edf_converted' / subject_id / 'eeg'
            # NOTE(review): BIDS-style names join entities with '_'
            # ("sub-001_task-..."); the un-separated spelling is tried first to
            # stay backward compatible, the BIDS spelling is a fallback.
            candidates = (
                eeg_dir / f'{subject_id}task-eyesclosed_eeg.edf',
                eeg_dir / f'{subject_id}_task-eyesclosed_eeg.edf',
            )
            eeg_file_path = next((p for p in candidates if p.exists()), None)
            if eeg_file_path is None:
                n_missing += 1
                continue
            records.append({
                'original_subject_id': subject_id,
                'file_path': str(eeg_file_path.resolve()),
                'diagnosis': label,
                'original_dataset_source': 'ds004504',
                'age': row.get('Age'),
                'sex': row.get('Gender'),
                'sampling_rate': 500,
            })
        if n_missing:
            print(f'  [WARN] {n_missing} labelled participants had no eyes-closed EDF and were skipped.')
        return records

    def process_caueeg():
        """Return one record per CAUEEG EDF whose symptom tags map to a label."""
        dataset_path = BASE_DATA_PATH / 'caueeg'
        annotation_path = dataset_path / 'annotation.json'
        signal_path = dataset_path / 'signal' / 'edf'
        records = []
        print('Processing CAUEEG dataset...')
        if not annotation_path.exists() or not signal_path.exists():
            print(f'  [ERROR] CAUEEG annotation or signal directory not found under {dataset_path}')
            return []
        with open(annotation_path, 'r', encoding='utf-8') as f:
            master_meta = json.load(f)
        # EDF file stems are matched against the 'serial' field of each entry.
        subject_lookup = {item['serial']: item for item in master_meta['data']}
        for edf_file in tqdm(list(signal_path.glob('*.edf')), desc='  -> CAUEEG'):
            original_subject_id = edf_file.stem
            subject_data = subject_lookup.get(original_subject_id)
            if not subject_data:
                continue
            symptoms = subject_data.get('symptom', [])
            if any(tag in symptoms for tag in caueeg_ad_tags):
                label = 'AD'
            elif any(tag in symptoms for tag in caueeg_mci_tags):
                label = 'MCI'
            elif any(tag in symptoms for tag in caueeg_cn_tags):
                label = 'CN'
            else:
                continue
            records.append({
                'original_subject_id': original_subject_id,
                'file_path': str(edf_file.resolve()),
                'diagnosis': label,
                'original_dataset_source': 'caueeg',
                'age': subject_data.get('age'),
                'sex': subject_data.get('gender'),
                'sampling_rate': 200,
            })
        return records

    def process_figshare_mdd():
        """Return one record per eyes-closed ('*EC.edf') Figshare MDD file."""
        dataset_path = BASE_DATA_PATH / 'figshare_mdd'
        records = []
        print('Processing Figshare MDD dataset...')
        if not dataset_path.exists():
            print(f'  [ERROR] Figshare MDD directory not found at {dataset_path}')
            return []
        for edf_file in tqdm(list(dataset_path.glob('* Subjects/*EC.edf')), desc='  -> Figshare'):
            filename = edf_file.stem
            # The diagnosis is encoded in the file name: 'MDD ...' or 'H S...'.
            if 'MDD' in filename:
                label = 'MDD'
            elif 'H S' in filename:
                label = 'CN'
            else:
                continue
            records.append({
                'original_subject_id': filename,
                'file_path': str(edf_file.resolve()),
                'diagnosis': label,
                'original_dataset_source': 'figshare_mdd',
                'age': None,
                'sex': None,
                'sampling_rate': 256,
            })
        return records

    print('--- Phase 0.1: Master Metadata Generation ---')
    all_records = []
    all_records.extend(process_openneuro_ds004504())
    all_records.extend(process_caueeg())
    all_records.extend(process_figshare_mdd())
    if not all_records:
        print('[FATAL] No records found in any dataset; master metadata was not written.')
        return
    df = pd.DataFrame(all_records)
    # One random UUID per record; these change on every run.
    df['subject_id'] = [str(uuid.uuid4()) for _ in range(len(df))]
    column_order = ['subject_id', 'original_subject_id', 'diagnosis', 'file_path', 'original_dataset_source', 'sampling_rate', 'age', 'sex']
    df = df[column_order]
    df.to_csv(MASTER_METADATA_PATH, index=False)
    print('\n--- Metadata Generation Complete ---')
    print(f'Successfully processed {len(df)} records.')
    print('\nFinal Class Distribution:\n', df['diagnosis'].value_counts())
    print(f'\n✅ Master metadata file saved to: {MASTER_METADATA_PATH}')


create_metadata_main()

## Phase 0.2: Data Validation (validate_data.py)

In [None]:
def validate_data_main():
    """Quality-check every raw file listed in the master metadata.

    For each row of ``MASTER_METADATA_PATH`` the EDF header is opened (without
    loading the signal) and two things are verified: the sampling rate matches
    the ``sampling_rate`` column, and every one of the 19 required channels is
    present. Files that fail are listed in ``VALIDATION_ERRORS_PATH``.

    Returns:
        None. A summary is printed.
    """
    # --- Logic from validate_data.py ---
    REQUIRED_CHANNELS = {'Fp1', 'Fp2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2', 'F7', 'F8', 'T3', 'T4', 'T5', 'T6', 'Fz', 'Cz', 'Pz'}
    print('\n--- Phase 0.2: Data Validation ---')
    if not MASTER_METADATA_PATH.exists():
        print('[FATAL] Metadata file not found.')
        return
    metadata_df = pd.read_csv(MASTER_METADATA_PATH)
    success_count = 0
    error_records = []
    for _, row in tqdm(metadata_df.iterrows(), total=len(metadata_df), desc='Validating files'):
        file_path = Path(row['file_path'])
        error_messages = []
        if not file_path.exists():
            error_messages.append('File not found.')
        else:
            try:
                raw = mne.io.read_raw_edf(file_path, preload=False, verbose='CRITICAL')
                found_sfreq = int(raw.info['sfreq'])
                if found_sfreq != row['sampling_rate']:
                    error_messages.append(f"SR mismatch: Expected {row['sampling_rate']}, found {found_sfreq}")
                available_channels_upper = {ch.upper() for ch in raw.ch_names}
                # Both sides must be upper-cased: the required names are mixed
                # case ('Fp1', 'Fz', ...) and would otherwise never match.
                # Substring matching tolerates decorated names that contain
                # the standard name.
                missing = {
                    ch for ch in REQUIRED_CHANNELS
                    if not any(ch.upper() in actual for actual in available_channels_upper)
                }
                if missing:
                    error_messages.append(f"Missing channels: {', '.join(sorted(missing))}")
            except Exception as e:
                # Best-effort: an unreadable file is reported, not fatal.
                error_messages.append(f'MNE read error: {e}')
        if not error_messages:
            success_count += 1
        else:
            error_records.append({'subject_id': row['subject_id'], 'file_path': str(file_path), 'errors': '; '.join(error_messages)})
    print('\n--- Validation Complete ---')
    print(f'Successfully validated: {success_count} / {len(metadata_df)} files.')
    if error_records:
        pd.DataFrame(error_records).to_csv(VALIDATION_ERRORS_PATH, index=False)
        print(f'Found {len(error_records)} issues. See: {VALIDATION_ERRORS_PATH}')
    else:
        print('\n✅ All files passed validation!')


validate_data_main()

## Phase 1.1: Data Harmonization (harmonize_data.py)

In [None]:
def harmonize_data_main():
    """Convert every raw recording to a common format and save it as FIF.

    For each row of ``MASTER_METADATA_PATH`` the EDF is loaded, its channels
    are renamed to the standard 19 names, reduced to exactly those channels in
    a fixed order, cropped to eyes-closed segments (CAUEEG only), resampled to
    ``TARGET_SAMPLING_RATE`` and written to ``HARMONIZED_DIR``. Rows that were
    saved are written, with a ``harmonized_file_path`` column, to
    ``HARMONIZED_METADATA_PATH``.

    Returns:
        None. Per-file failures are printed and skipped.
    """
    # --- Logic from harmonize_data.py ---
    TARGET_SAMPLING_RATE = 200
    STANDARD_19_CHANNELS = ['Fp1', 'Fp2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2', 'F7', 'F8', 'T3', 'T4', 'T5', 'T6', 'Fz', 'Cz', 'Pz']

    def rename_channels(raw: Raw) -> Raw:
        """Rename, in place, the first channel containing each standard name."""
        rename_mapping = {}
        original_names = list(raw.ch_names)
        ch_names_upper = [ch.upper() for ch in original_names]
        for standard_ch in STANDARD_19_CHANNELS:
            target = standard_ch.upper()
            # Case-insensitive substring match; first hit wins.
            for original_name, upper_name in zip(original_names, ch_names_upper):
                if target in upper_name:
                    rename_mapping[original_name] = standard_ch
                    break
        raw.rename_channels(rename_mapping)
        return raw

    def extract_eyes_closed_caueeg(raw: Raw, subject_id: str) -> Optional[Raw]:
        """Return the concatenated eyes-closed segments of a CAUEEG recording.

        Each 'eyes closed' event starts a segment that runs until the next
        event (or the end of the recording). Event timestamps are divided by
        the sampling rate, i.e. they are treated as sample indices.

        Returns:
            The unchanged ``raw`` when no event file exists, the concatenated
            segments otherwise, or ``None`` when the event file lists no
            usable eyes-closed segment.
        """
        event_json_path = BASE_DATA_PATH / 'caueeg' / 'event' / f'{subject_id}.json'
        if not event_json_path.exists():
            return raw
        with open(event_json_path, 'r', encoding='utf-8') as f:
            events = json.load(f)
        sfreq = raw.info['sfreq']
        max_time = raw.times[-1]
        ec_segments = []
        for i, (timestamp, desc) in enumerate(events):
            if 'eyes closed' not in desc.lower():
                continue
            start_time_sec = timestamp / sfreq
            end_time_sec = events[i + 1][0] / sfreq if i + 1 < len(events) else max_time
            end_time_sec = min(end_time_sec, max_time)
            if end_time_sec - start_time_sec > 1e-5:
                ec_segments.append(raw.copy().crop(tmin=start_time_sec, tmax=end_time_sec))
        if not ec_segments:
            # crop(tmax=0) would still keep one sample, so an "empty" recording
            # cannot be signalled through the sample count; use None instead.
            return None
        return mne.concatenate_raws(ec_segments)

    print('\n--- Phase 1.1: Data Harmonization ---')
    if not MASTER_METADATA_PATH.exists():
        print('[FATAL] Master metadata not found.')
        return
    metadata_df = pd.read_csv(MASTER_METADATA_PATH)
    harmonized_records = []
    for _, row in tqdm(metadata_df.iterrows(), total=len(metadata_df), desc='Harmonizing files'):
        try:
            raw: Raw = mne.io.read_raw_edf(row['file_path'], preload=True, verbose='WARNING')
            raw.set_meas_date(None)
            raw = rename_channels(raw)
            # reorder_channels keeps exactly the listed channels, in the listed
            # order, and raises if one is absent. Raw.pick() takes no `ordered`
            # keyword, so the previous call failed for every file.
            raw.reorder_channels(STANDARD_19_CHANNELS)
            if row['original_dataset_source'] == 'caueeg':
                raw = extract_eyes_closed_caueeg(raw, row['original_subject_id'])
            if raw is None or raw.times.size == 0:
                print(f"\n[SKIP] No eyes-closed data in {row['file_path']}.")
                continue
            if raw.info['sfreq'] != TARGET_SAMPLING_RATE:
                raw.resample(TARGET_SAMPLING_RATE, verbose='WARNING')
            sanitized_id = str(row['original_subject_id']).replace(' ', '').replace('-', '')
            # The name ends in '_eeg.fif' so that it follows MNE's naming
            # convention for raw files and does not trigger a warning per save.
            output_filename = f"{row['original_dataset_source']}harmonized{row['diagnosis']}_{sanitized_id}_eeg.fif"
            output_filepath = HARMONIZED_DIR / output_filename
            raw.save(output_filepath, overwrite=True, verbose='WARNING')
            new_record = row.to_dict()
            new_record['harmonized_file_path'] = str(output_filepath.resolve())
            harmonized_records.append(new_record)
        except Exception as e:
            # Best-effort batch job: report the file and carry on.
            print(f"\n[ERROR] Failed on {row['file_path']}. Reason: {e}")
    if harmonized_records:
        harmonized_df = pd.DataFrame(harmonized_records)
        harmonized_df.to_csv(HARMONIZED_METADATA_PATH, index=False)
        print('\n--- Harmonization Complete ---')
        print(f'Successfully processed {len(harmonized_records)} files.')
        print(f'✅ New metadata saved to: {HARMONIZED_METADATA_PATH}')


harmonize_data_main()

## Phase 1.2: Preprocessing (preprocess_data.py)

In [None]:
def preprocess_data_main():
    """Filter, ICA-clean and epoch every harmonized recording.

    For each row of ``HARMONIZED_METADATA_PATH`` the recording is band-pass
    and notch filtered, average-referenced, cleaned with ICA (components that
    ICLabel does not label 'brain' or 'other' are removed), cut into
    fixed-length epochs and amplitude-thresholded. Surviving epochs are saved
    to ``PREPROCESSED_DIR`` and indexed in ``PREPROCESSED_METADATA_PATH``.

    Returns:
        None. Per-file failures are printed and skipped.
    """
    # --- Logic from preprocess_data.py ---
    BANDPASS_FREQ = (1.0, 45.0)
    NOTCH_FREQ = 50.0
    EPOCH_DURATION = 2.0
    REJECT_CRITERIA = dict(eeg=100e-6)
    ICA_N_COMPONENTS = 15
    ICA_RANDOM_STATE = 42
    ICA_METHOD = 'infomax'
    ICA_FIT_PARAMS = dict(extended=True)
    ICA_FILTER_FREQ = (1.0, 100.0)

    def preprocess_subject(file_path: Path) -> Tuple[Optional[mne.Epochs], Optional[mne.preprocessing.ICA]]:
        """Return ``(clean_epochs, fitted_ica)`` or ``(None, None)`` on failure."""
        try:
            raw: Raw = mne.io.read_raw_fif(file_path, preload=True, verbose='WARNING')
            raw.set_montage('standard_1020', on_missing='raise', verbose='WARNING')

            # Copy used for fitting and labelling the ICA. ICLabel expects
            # 1-100 Hz, average-referenced data. A low-pass at or above
            # Nyquist cannot be designed (100 Hz == Nyquist for 200 Hz data),
            # so it is dropped in that case; the signal is already band-limited
            # by its sampling rate.
            nyquist = raw.info['sfreq'] / 2.0
            ica_h_freq = ICA_FILTER_FREQ[1] if ICA_FILTER_FREQ[1] < nyquist else None
            raw_for_ica = raw.copy().filter(l_freq=ICA_FILTER_FREQ[0], h_freq=ica_h_freq, verbose='WARNING')
            raw_for_ica.set_eeg_reference('average', projection=False, verbose='WARNING')

            # Analysis copy: narrower band, notch, same reference as the ICA copy.
            raw.filter(l_freq=BANDPASS_FREQ[0], h_freq=BANDPASS_FREQ[1], verbose='WARNING')
            raw.notch_filter(freqs=NOTCH_FREQ, verbose='WARNING')
            raw.set_eeg_reference('average', projection=False, verbose='WARNING')

            ica = mne.preprocessing.ICA(n_components=ICA_N_COMPONENTS, method=ICA_METHOD, fit_params=ICA_FIT_PARAMS, random_state=ICA_RANDOM_STATE)
            ica.fit(raw_for_ica)
            # Label on the same instance the ICA was fitted on.
            component_labels = label_components(raw_for_ica, ica, method='iclabel')
            ica.exclude = [idx for idx, label in enumerate(component_labels['labels']) if label not in ['brain', 'other']]

            epochs = mne.make_fixed_length_epochs(raw, duration=EPOCH_DURATION, preload=True, verbose='WARNING')
            ica.apply(epochs, verbose='WARNING')
            epochs.drop_bad(reject=REJECT_CRITERIA, verbose='WARNING')
            return epochs, ica
        except Exception as e:
            # Best-effort batch job: report the file and carry on.
            print(f'\n[ERROR] Could not process {file_path}. Reason: {e}')
            return None, None

    print('\n--- Phase 1.2: Preprocessing ---')
    if not HARMONIZED_METADATA_PATH.exists():
        print('[FATAL] Harmonized metadata not found.')
        return
    metadata_df = pd.read_csv(HARMONIZED_METADATA_PATH)
    preprocessed_records = []
    for _, row in tqdm(metadata_df.iterrows(), total=len(metadata_df), desc='Preprocessing files'):
        epochs, _ica = preprocess_subject(Path(row['harmonized_file_path']))
        # Skip failures and recordings where every epoch was rejected.
        if epochs is None or len(epochs) == 0:
            continue
        sanitized_id = str(row['original_subject_id']).replace(' ', '').replace('-', '')
        output_filename = f"{row['original_dataset_source']}preprocessed{row['diagnosis']}_{sanitized_id}_epo.fif"
        output_filepath = PREPROCESSED_DIR / output_filename
        epochs.save(output_filepath, overwrite=True, verbose='WARNING')
        new_record = row.to_dict()
        new_record['preprocessed_epo_path'] = str(output_filepath.resolve())
        new_record['n_clean_epochs'] = len(epochs)
        preprocessed_records.append(new_record)
    if preprocessed_records:
        preprocessed_df = pd.DataFrame(preprocessed_records)
        preprocessed_df.to_csv(PREPROCESSED_METADATA_PATH, index=False)
        print('\n--- Preprocessing Complete ---')
        print(f'Successfully processed {len(preprocessed_records)} files.')
        print(f'✅ New metadata saved to: {PREPROCESSED_METADATA_PATH}')


preprocess_data_main()

## Phase 2: Graph Construction (create_graphs.py)

In [None]:
def create_graphs_main():
    """Turn every preprocessed epoch into a PyTorch Geometric graph file.

    For each row of ``PREPROCESSED_METADATA_PATH`` the epochs are loaded, one
    wPLI connectivity matrix is estimated for the recording, and one ``Data``
    object per epoch is saved to ``GRAPH_DIR``: node features come from an
    empirical wavelet transform of each channel, edges from the wPLI matrix,
    and the label from the diagnosis. The label encoding is written to
    ``LABEL_MAP_PATH`` and the graph index to ``GRAPH_METADATA_PATH``.

    Returns:
        None. Per-recording failures are printed and skipped.
    """
    # --- Logic from create_graphs.py ---
    SAMPLING_RATE = 200
    BANDS = {"Delta": (1, 4), "Theta": (4, 8), "Alpha": (8, 13), "Beta": (13, 30), "Gamma": (30, 45)}

    def shannon_entropy(data: np.ndarray) -> float:
        """Calculates the Shannon Entropy of a signal's amplitude histogram."""
        counts, _ = np.histogram(data, bins='auto', density=True)
        # Filter out zero probabilities to avoid log(0)
        counts = counts[counts > 0]
        return entropy(counts)  # type: ignore

    def calculate_node_features(epoch_data: np.ndarray) -> np.ndarray:
        """Calculates node features for a single epoch using EWT.

        ``epoch_data`` is indexed as (channel, sample). Returns an array of
        shape (n_channels, 2 * len(BANDS)): log1p-variance and Shannon entropy
        of each EWT mode, interleaved.
        """
        n_channels, _ = epoch_data.shape
        node_features = np.zeros((n_channels, len(BANDS) * 2))
        for ch_idx in range(n_channels):
            ewt, _, _ = ewtpy.EWT1D(epoch_data[ch_idx, :], N=len(BANDS))
            # Never write past the pre-sized feature row, whatever number of
            # modes the transform returns; missing modes stay zero.
            n_modes = min(ewt.shape[1], len(BANDS))
            for band_idx in range(n_modes):
                sub_band = ewt[:, band_idx]
                node_features[ch_idx, band_idx * 2] = np.log1p(np.var(sub_band))
                node_features[ch_idx, (band_idx * 2) + 1] = shannon_entropy(sub_band)
        return node_features

    def calculate_wpli_connectivity(epochs: mne.Epochs) -> np.ndarray:
        """Calculates one broadband wPLI matrix for a whole recording.

        ``spectral_connectivity_epochs`` estimates connectivity *across* the
        epochs it is given, so the result is a single
        (n_channels, n_channels) matrix, not one matrix per epoch.
        """
        conn = mne_connectivity.spectral_connectivity_epochs(
            epochs, method='wpli', mode='multitaper',
            fmin=BANDS["Delta"][0], fmax=BANDS["Gamma"][1],
            faverage=True, verbose=False
        )
        lower = conn.get_data(output='dense').squeeze()
        # The dense output only fills the lower triangle; wPLI is undirected,
        # so mirror it to obtain edges in both directions.
        return lower + lower.T

    print('\n--- Phase 2: Graph Construction ---')
    if not PREPROCESSED_METADATA_PATH.exists():
        print('[FATAL] Preprocessed metadata not found.')
        return
    metadata_df = pd.read_csv(PREPROCESSED_METADATA_PATH)
    unique_labels = sorted(metadata_df['diagnosis'].unique())
    label_map = {label: i for i, label in enumerate(unique_labels)}
    with open(LABEL_MAP_PATH, 'w') as f:
        json.dump(label_map, f, indent=4)
    print(f'✅ Saved label mapping to {LABEL_MAP_PATH}')

    graph_records = []
    for _, row in tqdm(metadata_df.iterrows(), total=len(metadata_df), desc='Creating graphs'):
        try:
            epochs = mne.read_epochs(row['preprocessed_epo_path'], preload=True, verbose=False)

            # Edges and label are shared by all graphs of this recording.
            adj_matrix = calculate_wpli_connectivity(epochs)
            edge_index, edge_attr = dense_to_sparse(torch.tensor(adj_matrix, dtype=torch.float))
            label_id = label_map[row['diagnosis']]
            epo_stem = Path(row['preprocessed_epo_path']).stem

            # Fetch the data once instead of building a sub-Epochs object per
            # index; iterating yields one (channel, sample) array per epoch.
            all_epoch_data = epochs.get_data(copy=False)
            for i, single_epoch in enumerate(all_epoch_data):
                node_features = calculate_node_features(single_epoch)
                graph_data = Data(
                    x=torch.tensor(node_features, dtype=torch.float),
                    edge_index=edge_index.clone(),
                    edge_attr=edge_attr.clone(),
                    y=torch.tensor([label_id], dtype=torch.long),
                )

                graph_filename = f"{epo_stem}_graph_{i}.pt"
                graph_filepath = GRAPH_DIR / graph_filename
                torch.save(graph_data, graph_filepath)

                graph_records.append({
                    'subject_id': row['subject_id'],
                    'diagnosis': row['diagnosis'],
                    'graph_file_path': str(graph_filepath.resolve())
                })
        except Exception as e:
            # Best-effort batch job: report the recording and carry on.
            print(f"\n[ERROR] Failed on {row['subject_id']}. Reason: {e}")

    # --- Completion of the cell logic ---
    if graph_records:
        graph_df = pd.DataFrame(graph_records)
        graph_df.to_csv(GRAPH_METADATA_PATH, index=False)
        print("\n--- Graph Construction Complete ---")
        print(f"Successfully created {len(graph_df)} graph files.")
        print(f"✅ New graph metadata saved to: {GRAPH_METADATA_PATH}")
    else:
        print("\n--- Graph Construction Failed ---")
        print("No graph files were generated.")


# Execute the main function for this phase
create_graphs_main()