# CleanEEG: Automated Resting-State EEG Preprocessing Tutorial
This tutorial demonstrates the complete CleanEEG preprocessing pipeline using MNE-Python and complementary libraries. Based on the DISCOVER-EEG framework, it walks through every preprocessing step and checks data quality after each one using Signal-to-Noise Ratio (SNR) estimates and Power Spectral Density (PSD) plots.

# Table of Contents

1. [Installation and Setup](#installation)
2. [Quality Assessment Functions](#quality-functions)
3. [Loading EEG Data](#loading-data)
4. [Channel Montage Setup](#montage-setup)
5. [Preprocessing Pipeline](#preprocessing-pipeline)
   - [Line Noise Removal (DSS)](#line-noise)
   - [Bandpass Filter](#bandpass-filter)
   - [Downsample Data](#downsample)
   - [Bad Channel Rejection (PREP)](#bad-channels)
   - [EOG Artifact Removal](#eog-removal)
   - [Independent Component Analysis (ICA)](#ica)
   - [Bad Channel Interpolation](#interpolation)
   - [Bad Time Segments Removal (ASR)](#asr)
6. [Final Quality Assessment](#final-assessment)
7. [Saving Results](#saving)

## 1. Installation and Setup {#installation}

### Option 1: Using Conda Environment (Recommended)

Create a conda environment with all dependencies:

```bash
# Create environment from the provided environment.yml
conda env create -f environment.yml
conda activate cleaneeg
```

### Option 2: Using Pip Environment

Create a virtual environment and install dependencies:

```bash
# Create virtual environment
python -m venv cleaneeg_env
source cleaneeg_env/bin/activate  # On Windows: cleaneeg_env\Scripts\activate

# Install dependencies
pip install -r requirements.txt
```

### Option 3: Install in Current Environment (Jupyter/Colab)

If you're running this in Jupyter or Google Colab, you can install packages directly:

In [None]:
# Install required packages for EEG processing and visualization
# Specifiers containing '>' are quoted so the shell does not treat '>' as output redirection
!pip install mne==1.5.0               # Core package for EEG/MEG data analysis
!pip install "pyprep>=0.4.0"          # For automatic bad channel detection
!pip install "meegkit>=0.1.5"         # For advanced denoising methods (DSS, ASR)
!pip install mne-icalabel==0.5.0      # For automatic classification of ICA components
!pip install "matplotlib>=3.6.0"      # For visualization
!pip install numpy==1.26.4            # For numerical operations
!pip install "scipy>=1.10.0"          # For scientific computing
!pip install "pandas>=1.5.0"          # For data handling
!pip install "pybv>=0.7.0"            # For BrainVision file support
!pip install "eeglabio>=0.0.2"        # For EEGLAB file support
!pip install "edfio>=0.1.0"           # For EDF file support
!pip install "EDFlib-Python>=1.0.8"   # For EDF+ file support
!pip install "h5py>=3.7.0"            # For HDF5 file support
!pip install tqdm                     # For progress bars (sample data download)

### Verify Installation
Let's check that all required packages are installed correctly:

In [None]:
# Verify installation of key packages
required_packages = {
    'mne': '1.5.0',
    'numpy': '1.26.4',
    'scipy': '1.10.0',
    'matplotlib': '3.6.0',
    'pandas': '1.5.0',
    'pyprep': '0.4.0',
    'meegkit': '0.1.5',
    'mne_icalabel': '0.5.0'
}

print("🔍 Verifying package installations...\n")

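from packaging.version import InvalidVersion, Version  # 'packaging' is an MNE dependency

installation_status = {}
for package, min_version in required_packages.items():
    try:
        module = __import__(package)  # works for 'mne_icalabel' as well
        version = getattr(module, '__version__', 'Unknown')
        installation_status[package] = {'installed': True, 'version': version}
        
        # Compare against the minimum version numerically, not as strings
        try:
            too_old = Version(version) < Version(min_version)
        except InvalidVersion:
            too_old = False  # unparseable version string (e.g. 'Unknown'); skip the check
        if too_old:
            print(f"⚠️  {package:12} v{version} (expected >= {min_version})")
        else:
            print(f"✅ {package:12} v{version}")
        
    except ImportError:
        installation_status[package] = {'installed': False, 'version': None}
        print(f"❌ {package:12} - NOT INSTALLED")
    except Exception as e:
        installation_status[package] = {'installed': False, 'version': None}
        print(f"⚠️  {package:12} - ERROR: {e}")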

# Check installation completeness
installed_packages = sum(1 for status in installation_status.values() if status['installed'])
total_packages = len(required_packages)

print(f"\n📊 Installation Summary: {installed_packages}/{total_packages} packages installed")

if installed_packages == total_packages:
    print("🎉 All packages installed successfully! Ready to proceed.")
elif installed_packages >= total_packages * 0.8:  # 80% threshold
    print("⚠️  Most packages installed. Some optional features may not work.")
else:
    print("❌ Many packages missing. Please check your installation.")
    print("   Consider reinstalling with: pip install -r requirements.txt")
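The `required_packages` dictionary stores minimum versions, and comparing version strings directly is unreliable because Python compares strings character by character. The sketch below shows a standard-library check that assumes plain dotted numeric versions. `version_tuple` and `meets_minimum` are hypothetical helpers; `packaging.version.Version` is the complete PEP 440 implementation.

```python
import re


def version_tuple(version: str) -> tuple:
    """Leading numeric components of a version string: '1.26.4' -> (1, 26, 4).

    Pre-release/dev suffixes are ignored, so '0.1.5.dev0' -> (0, 1, 5).
    """
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first non-numeric component
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    """True if `installed` is at least `minimum`; False if it cannot be parsed."""
    have, need = version_tuple(installed), version_tuple(minimum)
    if not have:
        return False  # e.g. 'Unknown'
    width = max(len(have), len(need))
    # Pad with zeros so that '1.5' and '1.5.0' compare as equal
    have += (0,) * (width - len(have))
    need += (0,) * (width - len(need))
    return have >= need


print("0.10.0" >= "0.9.0")                # False: '1' sorts before '9'
print(meets_minimum("0.10.0", "0.9.0"))   # True
print(meets_minimum("1.5", "1.5.0"))      # True
print(meets_minimum("Unknown", "0.5.0"))  # False
```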

In [None]:
# Import all necessary libraries
import mne
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
from scipy import signal
import ftplib
import random
from tqdm.notebook import tqdm

# Set MNE logging level
mne.set_log_level('WARNING')

# Configure matplotlib for better plots
plt.style.use('default')
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12

### Sample Data Download Function
If you don't have your own EEG data, we'll implement a function to download sample resting-state EEG data:

In [None]:
# Sample data download function implementation
import ftplib
import random
from pathlib import Path
from tqdm.notebook import tqdm

def is_dir(ftp: ftplib.FTP, path: str) -> bool:
    """Check if a path is a directory on the FTP server."""
    cwd = ftp.pwd()
    try:
        ftp.cwd(path)
        ftp.cwd(cwd)
        return True
    except ftplib.error_perm:
        return False

def download_remote(ftp: ftplib.FTP, remote_dir: str, local_dir: Path):
    """Recursively download files from FTP server."""
    local_dir.mkdir(parents=True, exist_ok=True)
    try:
        ftp.cwd(remote_dir)
    except ftplib.error_perm:
        return
    
    try:
        entries = list(ftp.mlsd())
    except (ftplib.error_perm, AttributeError):
        names = ftp.nlst()
        entries = [(name, {'type': 'dir' if is_dir(ftp, f"{remote_dir}/{name}") else 'file'})
                   for name in names]
    
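    for name, info in tqdm(entries, desc=f"Scanning {Path(remote_dir).name}", leave=False):
        # MLSD also lists the current ('cdir') and parent ('pdir') directories; skip them
        if info.get('type') in ('cdir', 'pdir') or name in ('.', '..'):
            continue
        rpath = f"{remote_dir}/{name}"
        lpath = local_dir / name
        
        if info.get('type') == 'dir':
            download_remote(ftp, rpath, lpath)
        else:
            if not lpath.exists():  # Skip if file already exists
                try:
                    with open(lpath, 'wb') as f:
                        ftp.retrbinary(f"RETR {rpath}", f.write)
                except Exception as e:
                    lpath.unlink(missing_ok=True)  # don't leave a truncated file behind
                    print(f"⚠️ Failed to download {rpath}: {e}")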
            else:
                print(f"📁 File already exists: {lpath.name}")

def download_sample_data(ftp_host: str,
                         ftp_base: str,
                         local_base: Path,
                         num_subjects: int = 1):
    """
    Download sample EEG data from FTP server.
    
    Parameters:
    -----------
    ftp_host : str
        FTP server hostname
    ftp_base : str
        Base directory on FTP server
    local_base : Path
        Local directory to save data
    num_subjects : int
        Number of subjects to download (chosen at random)
    
    Returns:
    --------
    bool
        True if the download finished, False if it failed
    """
    print(f"🌐 Connecting to {ftp_host}...")
    
    try:
        ftp = ftplib.FTP(ftp_host)
        ftp.login()
        ftp.cwd(ftp_base)
        
        subjects = ftp.nlst()
        if len(subjects) < num_subjects:
            raise ValueError(f"Found only {len(subjects)} subjects, asked for {num_subjects}")
        
        chosen = random.sample(subjects, num_subjects)
        print(f"📥 Downloading {num_subjects} random subject(s): {chosen}\n")
        
        for subj in tqdm(chosen, desc="Subjects"):
            download_remote(ftp, f"{ftp_base}/{subj}", local_base / subj)
        
        ftp.quit()
        print(f"\n✅ Download complete. Data is in: {local_base.resolve()}")
        
    except Exception as e:
        print(f"❌ Download failed: {e}")
        print("   Will use MNE sample data instead...")
        return False
    
    return True

print("✅ Sample data download functions loaded!")
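`ftplib.FTP.mlsd()` yields `(name, facts)` pairs. Under RFC 3659, the `type` fact can also be `cdir` or `pdir` for the current and parent directory entries, and these must not be treated as files. The sketch below partitions such a listing before downloading. `split_listing` is a hypothetical helper, and the listing is made up for illustration.

```python
from typing import Dict, Iterable, List, Tuple


def split_listing(entries: Iterable[Tuple[str, Dict[str, str]]]) -> Tuple[List[str], List[str]]:
    """Split MLSD-style (name, facts) pairs into (directories, files).

    Entries of type 'cdir' ('.') and 'pdir' ('..') are skipped so a recursive
    download never revisits the current or parent directory.
    """
    dirs, files = [], []
    for name, facts in entries:
        kind = facts.get('type', 'file')
        if kind in ('cdir', 'pdir') or name in ('.', '..'):
            continue
        (dirs if kind == 'dir' else files).append(name)
    return dirs, files


# Made-up listing in the shape returned by ftplib.FTP.mlsd()
listing = [
    ('.', {'type': 'cdir'}),
    ('..', {'type': 'pdir'}),
    ('eeg', {'type': 'dir'}),
    ('recording.vhdr', {'type': 'file', 'size': '1234'}),
]
print(split_listing(listing))  # (['eeg'], ['recording.vhdr'])
```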

### Download Sample Data (Optional)
If you don't have your own resting-state EEG data, you can download sample data from public repositories:

In [None]:
# Download sample EEG data
sample_data_downloaded = False
sample_data_path = Path('sample_data')

print("🔍 Checking for existing sample data...")
if sample_data_path.exists() and any(sample_data_path.iterdir()):
    print(f"✅ Found existing data in {sample_data_path}")
    sample_data_downloaded = True
else:
    print("📥 No existing data found. Attempting to download sample data...")
    
    # Try to download from MPI-Leipzig LEMON dataset
    # This dataset contains high-quality resting-state EEG recordings
    try:
        sample_data_downloaded = download_sample_data(
            ftp_host='ftp.gwdg.de',
            ftp_base='/pub/misc/MPI-Leipzig_Mind-Brain-Body-LEMON/EEG_MPILMBB_LEMON/EEG_Raw_BIDS_ID',
            local_base=sample_data_path,
            num_subjects=1  # Download just one subject for this tutorial
        )
    except Exception as e:
        print(f"⚠️ FTP download failed: {e}")
        sample_data_downloaded = False

# Fallback to MNE sample data if download fails
# (note: the MNE sample dataset is an auditory/visual task recording, not resting-state EEG)
if not sample_data_downloaded:
    print("\n🔄 Falling back to MNE sample data...")
    try:
        # Use MNE's built-in sample dataset
        import mne
        sample_data_folder = mne.datasets.sample.data_path()
        sample_data_path = sample_data_folder / 'MEG' / 'sample'
        print(f"✅ Will use MNE sample data from: {sample_data_path}")
        sample_data_downloaded = True
    except Exception as e:
        print(f"❌ MNE sample data also failed: {e}")
        print("   Please provide your own EEG data file path in the next section.")

print(f"\n📊 Sample data status: {'Available' if sample_data_downloaded else 'Not available'}")

## 2. Quality Assessment Functions {#quality-functions}

**Purpose**: Monitor and quantify data quality improvements throughout the preprocessing pipeline.

**Why needed**: Preprocessing should improve signal quality, but it's important to verify this objectively. Signal-to-Noise Ratio (SNR) and Power Spectral Density (PSD) provide quantitative metrics to ensure each step is helping rather than hurting data quality.

**Methods**: 
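- **SNR calculation**: Compares power in each canonical frequency band (delta to gamma) with a noise estimate: the flanking ±2 Hz for the spectral method, or activity above 80 Hz for the RMS method. Because the true noise is unknown, these are heuristic quality indices rather than exact SNRs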
- **PSD visualization**: Shows how preprocessing affects the frequency content of signals
- **Progress tracking**: Documents quality changes after each preprocessing step
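The sketch below shows what the spectral method measures, using synthetic data: a 10 Hz sinusoid plus white noise stands in for an alpha rhythm. All signal parameters are illustrative assumptions. The sketch reimplements the band-versus-flanking-frequencies comparison of `compute_snr(method='spectral')` using NumPy and SciPy only. `band_contrast_db` is a hypothetical helper.

```python
import numpy as np
from scipy import signal


def band_contrast_db(x, sfreq, band, flank=2.0):
    """Mean PSD inside `band` relative to the +/- `flank` Hz just outside it, in dB."""
    freqs, psd = signal.welch(x, sfreq, nperseg=int(2 * sfreq))
    low, high = band
    in_band = (freqs >= low) & (freqs <= high)
    flanks = ((freqs >= low - flank) & (freqs < low)) | \
             ((freqs > high) & (freqs <= high + flank))
    return 10 * np.log10(psd[in_band].mean() / psd[flanks].mean())


rng = np.random.default_rng(0)
sfreq = 250.0
t = np.arange(0, 120, 1 / sfreq)              # 120 s of data
noise = 5e-6 * rng.standard_normal(t.size)    # ~5 µV white noise (in volts)
alpha = 10e-6 * np.sin(2 * np.pi * 10 * t)    # 10 µV, 10 Hz "alpha" rhythm

print(f"alpha band, noise only:  {band_contrast_db(noise, sfreq, (8, 13)):.1f} dB")
print(f"alpha band, with rhythm: {band_contrast_db(noise + alpha, sfreq, (8, 13)):.1f} dB")
```

With white noise alone the contrast stays near 0 dB. Adding the rhythm raises it well above zero, which is the behavior the pipeline relies on when comparing steps.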

These functions will help us track data quality throughout the preprocessing pipeline:

In [None]:
def compute_snr(raw, freq_bands=None, method='rms'):
    """
    Compute Signal-to-Noise Ratio for EEG data.
    
    Parameters:
    -----------
    raw : mne.io.Raw
        The EEG data
    freq_bands : dict
        Dictionary of frequency bands to analyze
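    method : str
        'spectral': mean band PSD relative to the ±2 Hz just outside the band;
        'rms': band-limited RMS relative to the RMS above 80 Hz
    
    Returns:
    --------
    snr_results : dict
        Per band: 'mean_snr' and 'std_snr' across channels, plus the
        per-channel 'channel_snr' array, all in dB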
    """
    if freq_bands is None:
        freq_bands = {
            'delta': (1, 4),
            'theta': (4, 8), 
            'alpha': (8, 13),
            'beta': (13, 30),
            'gamma': (30, 100)
        }
    
    # Get data and sampling frequency
    data = raw.get_data()
    sfreq = raw.info['sfreq']
    
    snr_results = {}
    
    if method == 'spectral':
        # Compute PSD
        freqs, psd = signal.welch(data, sfreq, nperseg=int(2*sfreq))
        
        for band_name, (low_freq, high_freq) in freq_bands.items():
            # Find frequency indices
            freq_mask = (freqs >= low_freq) & (freqs <= high_freq)
            
            # Signal power in the band
            signal_power = np.mean(psd[:, freq_mask], axis=1)
            
            # Noise estimation (neighboring frequencies)
            noise_low = max(0, low_freq - 2)
            noise_high = min(freqs[-1], high_freq + 2)
            noise_mask = ((freqs >= noise_low) & (freqs < low_freq)) | \
                        ((freqs > high_freq) & (freqs <= noise_high))
            
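            if np.any(noise_mask):
                noise_power = np.mean(psd[:, noise_mask], axis=1)
                # EEG PSDs in V²/Hz are often 1e-10 or smaller, so a fixed +1e-10 offset
                # would distort the ratio; guard against zero with the smallest positive float
                snr = 10 * np.log10(signal_power / np.maximum(noise_power, np.finfo(float).tiny))
            else:
                snr = np.full(len(raw.ch_names), np.nan)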
            
            snr_results[band_name] = {
                'mean_snr': np.nanmean(snr),
                'std_snr': np.nanstd(snr),
                'channel_snr': snr
            }
    
    else:  # RMS method
        nyquist = sfreq / 2.0
        # Noise reference: activity above 80 Hz, computed once for all bands
        # (only possible when 80 Hz lies well below the Nyquist frequency)
        noise_rms = None
        if sfreq > 160:
            noise_data = raw.copy().filter(80, None, verbose=False).get_data()
            noise_rms = np.sqrt(np.mean(noise_data**2, axis=1))
        
        for band_name, (low_freq, high_freq) in freq_bands.items():
            if noise_rms is None or low_freq >= nyquist:
                # No usable noise reference, or band above Nyquist: SNR is undefined.
                # (Using the band's own standard deviation as "noise" would always give ~0 dB.)
                snr = np.full(len(raw.ch_names), np.nan)
            else:
                # Band-limited signal; keep the upper edge below Nyquist
                high = min(high_freq, 0.99 * nyquist)
                band_data = raw.copy().filter(low_freq, high, verbose=False).get_data()
                signal_rms = np.sqrt(np.mean(band_data**2, axis=1))
                snr = 20 * np.log10(signal_rms / np.maximum(noise_rms, np.finfo(float).tiny))
            
            snr_results[band_name] = {
                'mean_snr': np.nanmean(snr),
                'std_snr': np.nanstd(snr),
                'channel_snr': snr
            }
    
    return snr_results

def plot_psd_comparison(raw_list, labels, title="Power Spectral Density Comparison", fmax=80):
    """
    Plot PSD comparison for multiple raw objects.
    
    Parameters:
    -----------
    raw_list : list
        List of mne.io.Raw objects
    labels : list
        Labels for each raw object
    title : str
        Plot title
    fmax : float
        Maximum frequency to plot
    """
    fig, ax = plt.subplots(1, 1, figsize=(12, 6))
    
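    # Red for "before", blue for "after"; any further curves take tab10 colors
    # (zip() stops at the shortest list, so the palette must cover every raw object)
    colors = ['red', 'blue'] + list(plt.cm.tab10.colors)
    
    for raw, label, color in zip(raw_list, labels, colors):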
        # Compute PSD
        psd = raw.compute_psd(fmax=fmax, verbose=False)
        
        # Plot average across channels
        freqs = psd.freqs
        psd_data = psd.get_data()
        mean_psd = np.mean(psd_data, axis=0)
        
        ax.semilogy(freqs, mean_psd, label=label, color=color, linewidth=2)
    
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Power Spectral Density (V²/Hz)')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

def print_snr_summary(snr_results, step_name):
    """
    Print a formatted summary of SNR results.
    """
    print(f"\n📊 SNR Summary - {step_name}:")
    print("=" * 50)
    for band, results in snr_results.items():
        print(f"{band.capitalize():>8}: {results['mean_snr']:6.2f} ± {results['std_snr']:5.2f} dB")
    print("=" * 50)

def plot_processing_summary(processing_log):
    """
    Plot a summary of SNR changes throughout processing.
    """
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()
    
    bands = ['delta', 'theta', 'alpha', 'beta', 'gamma']
    steps = list(processing_log.keys())
    
    for i, band in enumerate(bands):
        if i < len(axes):
            snr_values = [processing_log[step][band]['mean_snr'] for step in steps]
            axes[i].plot(range(len(steps)), snr_values, 'o-', linewidth=2, markersize=8)
            axes[i].set_title(f'{band.capitalize()} Band SNR')
            axes[i].set_ylabel('SNR (dB)')
            axes[i].set_xticks(range(len(steps)))
            axes[i].set_xticklabels(steps, rotation=45, ha='right')
            axes[i].grid(True, alpha=0.3)
    
    # Remove the last subplot if we have 6 subplots but only 5 bands
    if len(bands) < len(axes):
        fig.delaxes(axes[-1])
    
    plt.tight_layout()
    plt.show()

print("✅ Quality assessment functions loaded successfully!")
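`processing_log` maps each step name to the dictionary returned by `compute_snr`. In addition to the plots, a compact table can make step-to-step changes easier to read. The sketch below builds one with pandas. `snr_table` is a hypothetical helper, and the SNR numbers are placeholders, not results.

```python
import pandas as pd


def snr_table(processing_log):
    """Steps as rows, frequency bands as columns, mean SNR (dB) as values."""
    return pd.DataFrame({
        step: {band: res['mean_snr'] for band, res in bands.items()}
        for step, bands in processing_log.items()
    }).T


# Placeholder values in the shape produced by compute_snr()
example_log = {
    'Original':   {'alpha': {'mean_snr': 3.0}, 'beta': {'mean_snr': 1.0}},
    'Line noise': {'alpha': {'mean_snr': 3.5}, 'beta': {'mean_snr': 1.2}},
}
table = snr_table(example_log)
print(table.round(2))
print(table.diff().round(2))  # change relative to the previous step
```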

## 3. Loading EEG Data {#loading-data}
Load your EEG data from various formats supported by MNE-Python.
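`mne.io.read_raw` selects a reader based on the file extension. When a folder is scanned recursively, FIF folders in particular may also contain files that are not continuous recordings, such as evoked (`-ave.fif`), covariance (`-cov.fif`), or forward (`-fwd.fif`) files. The standard-library sketch below picks a file deterministically. `pick_first_recording` is a hypothetical helper, the extension list is a subset of what MNE supports, and the file names are illustrative.

```python
from pathlib import PurePath

SUPPORTED = (".vhdr", ".edf", ".bdf", ".gdf", ".cnt", ".mff", ".set", ".fif")


def pick_first_recording(paths):
    """Return the first (sorted) path with a supported extension, or None.

    FIF files are only accepted if the name ends in 'raw.fif', which skips
    evoked, covariance and forward-solution files stored alongside them.
    """
    for path in sorted(PurePath(p) for p in paths):
        suffix = path.suffix.lower()
        if suffix not in SUPPORTED:
            continue
        if suffix == ".fif" and not path.name.endswith("raw.fif"):
            continue
        return path
    return None


files = ["MEG/sample/sample_audvis-ave.fif",
         "MEG/sample/sample_audvis_raw.fif",
         "MEG/sample/sample_audvis-cov.fif"]
print(pick_first_recording(files))
```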

In [None]:
# Define the path to your EEG data
# You can either:
# 1. Use the downloaded sample data (if available)
# 2. Update this path to point to your own EEG files
# 3. Use MNE sample data as fallback

data_path = None
user_data_path = Path('your_eeg_data')  # Change this to your data directory

# Check for user data first (the placeholder 'your_eeg_data' is ignored until you rename it)
if user_data_path.exists() and user_data_path != Path('your_eeg_data'):
    data_path = user_data_path
    print(f"📁 Using user-specified data path: {data_path}")
elif sample_data_downloaded:
    data_path = sample_data_path
    print(f"📁 Using downloaded sample data: {data_path}")

# List of valid EEG file extensions that MNE can read
valid_eeg_formats = [".vhdr", ".edf", ".bdf", ".gdf", ".cnt", ".egi", 
                    ".mff", ".set", ".fif", ".data", ".nxe", ".lay"]

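# Find EEG files automatically (sorted so the choice is reproducible)
eeg_files = []
if data_path and data_path.exists():
    for ext in valid_eeg_formats:
        eeg_files.extend(sorted(data_path.rglob(f"*{ext}")))
    # FIF folders also hold evoked (-ave), covariance (-cov), forward (-fwd) and other
    # non-raw files that read_raw cannot open as recordings; keep only raw FIF files
    eeg_files = [f for f in eeg_files if f.suffix != '.fif' or f.name.endswith('raw.fif')]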

if not eeg_files:
    print("⚠️ No EEG files found in specified directories.")
    print("🔄 Using MNE sample data as fallback...")
    
    # Use MNE sample data as fallback
    try:
        sample_data_folder = mne.datasets.sample.data_path()
        sample_data_raw_file = sample_data_folder / 'MEG' / 'sample' / 'sample_audvis_filt-0-40_raw.fif'
        raw = mne.io.read_raw_fif(sample_data_raw_file, preload=True)
        
        # Pick only EEG channels for this tutorial
        raw.pick('eeg')
        
        print(f"✅ Using MNE sample EEG data: {len(raw.ch_names)} EEG channels")
        data_source = "MNE Sample Data"
        
    except Exception as e:
        print(f"❌ Failed to load MNE sample data: {e}")
        print("   Please specify a valid EEG data path in the 'user_data_path' variable above.")
        raise FileNotFoundError("No EEG data available for processing")
else:
    # Use the first EEG file found
    eeg_path = eeg_files[0]
    print(f"📁 Loading {eeg_path}")
    
    try:
        # Read the EEG data
        raw = mne.io.read_raw(eeg_path, preload=True)
        
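        # Keep EEG channels plus any EOG channels (needed later for artifact removal);
        # a file without EEG channels cannot be used, so fall back to the sample data
        if len(mne.pick_types(raw.info, eeg=True, exclude=[])) == 0:
            raise ValueError("file contains no EEG channels")
        raw.pick(mne.pick_types(raw.info, eeg=True, eog=True, exclude=[]))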
        
        data_source = f"External File: {eeg_path.name}"
        
    except Exception as e:
        print(f"❌ Failed to load {eeg_path}: {e}")
        print("   Trying MNE sample data as fallback...")
        
        # Fallback to MNE sample data
        sample_data_folder = mne.datasets.sample.data_path()
        sample_data_raw_file = sample_data_folder / 'MEG' / 'sample' / 'sample_audvis_filt-0-40_raw.fif'
        raw = mne.io.read_raw_fif(sample_data_raw_file, preload=True)
        raw.pick('eeg')
        data_source = "MNE Sample Data (fallback)"

# Print basic information about the loaded data
print("\n📋 EEG Data Information:")
print(f"├── Data source: {data_source}")
print(f"├── Sampling rate: {raw.info['sfreq']} Hz")
print(f"├── Duration: {raw.times[-1]:.1f} seconds")
print(f"├── Number of EEG channels: {len(raw.ch_names)}")
print(f"└── Channel names: {raw.ch_names[:10]}{'...' if len(raw.ch_names) > 10 else ''}")

# Create processing log to track quality metrics
processing_log = {}
raw_versions = {'Original': raw.copy()}

# Initial quality assessment
print("\n🔍 Computing initial data quality...")
initial_snr = compute_snr(raw, method='spectral')
processing_log['Original'] = initial_snr
print_snr_summary(initial_snr, "Original Data")

## 4. Channel Montage Setup {#montage-setup}

**Purpose**: Assign 3D spatial coordinates to each EEG electrode for spatial analyses and visualizations.

**Why needed**: Many preprocessing steps (bad channel detection, interpolation) and analyses (source localization, connectivity) require knowing where each electrode is positioned on the scalp. Without spatial information, we can't determine which channels are neighbors or create topographic maps.

**Method**: Match electrode names to standard montage templates (10-20, 10-05, etc.) that define precise 3D coordinates for each electrode position.
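Conceptually, montage matching is a case-insensitive lookup of each channel name in the template's table of positions, followed by a check of how many channels found a match. The setup code below accepts a montage when more than half of the channels match. The sketch below illustrates that logic. `match_montage` is a hypothetical helper, and the three-channel template and its coordinates are invented for illustration rather than taken from any MNE montage.

```python
def match_montage(ch_names, template):
    """Case-insensitively match channel names to template positions.

    Returns (positions, unmatched), where positions maps each matched
    channel name to its template coordinates.
    """
    lookup = {name.lower(): pos for name, pos in template.items()}
    positions, unmatched = {}, []
    for ch in ch_names:
        pos = lookup.get(ch.lower())
        if pos is None:
            unmatched.append(ch)
        else:
            positions[ch] = pos
    return positions, unmatched


# Illustrative template: name -> (x, y, z) in metres (values made up)
template = {"Fp1": (-0.03, 0.08, 0.0), "Cz": (0.0, 0.0, 0.09), "O2": (0.03, -0.08, 0.0)}
channels = ["FP1", "cz", "O2", "VEOG"]

positions, unmatched = match_montage(channels, template)
fraction = len(positions) / len(channels)
print(f"matched {len(positions)}/{len(channels)} ({fraction:.0%}), unmatched: {unmatched}")
print("accept montage:", fraction > 0.5)
```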

In [None]:
# Enhanced Channel Type Setup and EEG Processing
print("🔧 Setting up channel types and preparing data...")

# Create a copy for processing
print("📋 Creating a working copy of the data for processing...")
print("   This preserves the original data for comparison and allows us to restart if needed.")
raw_processed = raw.copy()

# Step 1: Identify and set EOG channel types
print("\n👁️ Identifying and setting EOG channel types...")

# Common EOG channel patterns to look for
eog_patterns = [
    'VEOG', 'HEOG',           # Vertical/Horizontal EOG (most common)
    'EOG1', 'EOG2', 'EOG01', 'EOG02',  # Numbered EOG channels
    'LEOG', 'REOG',           # Left/Right EOG
    'LO1', 'LO2', 'IO1', 'IO2',  # Lateral canthus (outer eye corner) / Infra-orbital
    'SO1', 'SO2',             # Supra-orbital
    'VREF', 'HREF',           # Vertical/Horizontal Reference
    'EYE', 'BLINK'            # Other eye-related patterns
]

# Find EOG channels in the data
eog_channels_found = []
eog_channel_mapping = {}

print("   🔍 Scanning channel names for EOG patterns...")
for ch_name in raw_processed.ch_names:
    ch_upper = ch_name.upper()
    for pattern in eog_patterns:
        if pattern in ch_upper:
            eog_channels_found.append(ch_name)
            eog_channel_mapping[ch_name] = 'eog'
            print(f"      ✅ Found EOG channel: {ch_name}")
            break

# Apply EOG channel type changes
if eog_channels_found:
    print(f"\n   📝 Setting {len(eog_channels_found)} channels to EOG type...")
    raw_processed.set_channel_types(eog_channel_mapping)
    
    # Verify the changes
    for ch_name in eog_channels_found:
        ch_idx = raw_processed.ch_names.index(ch_name)
        ch_type = raw_processed.info['chs'][ch_idx]['kind']
        if ch_type == mne.io.constants.FIFF.FIFFV_EOG_CH:
            print(f"      ✅ {ch_name}: Successfully set to EOG")
        else:
            print(f"      ⚠️ {ch_name}: Failed to set EOG type")
else:
    print("   ℹ️ No EOG channels found in data")

# Step 2: Get channel type summary (robust method)
print(f"\n📊 Channel type summary:")

# Map every channel name to its MNE channel type ('eeg', 'eog', 'ecg', 'misc', ...).
# raw.get_channel_types() returns a list in the same order as raw.ch_names.
def get_channel_types_robust(raw):
    """Return a {channel name: channel type} dictionary."""
    return dict(zip(raw.ch_names, raw.get_channel_types()))

# Get channel types using robust method
ch_types_dict = get_channel_types_robust(raw_processed)

# Create lists for each channel type
eeg_channels = [ch for ch, ch_type in ch_types_dict.items() if ch_type == 'eeg']
eog_channels = [ch for ch, ch_type in ch_types_dict.items() if ch_type == 'eog']
other_channels = [ch for ch, ch_type in ch_types_dict.items() if ch_type not in ['eeg', 'eog']]

# Update eog_channels list if we found more
if eog_channels and eog_channels != eog_channels_found:
    print(f"   🔄 Updated EOG channels list: {eog_channels}")
    eog_channels_found = eog_channels

print(f"   📍 EEG channels: {len(eeg_channels)}")
print(f"   👁️ EOG channels: {len(eog_channels)}")
if other_channels:
    # Get unique channel types for other channels
    other_types = list(set([ch_types_dict[ch] for ch in other_channels]))
    print(f"   🔧 Other channels: {len(other_channels)} (types: {other_types})")

# Show channel details
if eog_channels:
    print(f"\n   EOG channels found: {eog_channels}")
if len(eeg_channels) > 10:
    print(f"   EEG channels (first 10): {eeg_channels[:10]}...")
else:
    print(f"   EEG channels: {eeg_channels}")

# Step 3: Create EEG-only version for montage and main processing
print(f"\n🧠 Creating EEG-only dataset for main processing...")
raw_eeg_only = raw_processed.copy().pick('eeg')
print(f"   ✅ EEG-only dataset created with {len(raw_eeg_only.ch_names)} channels")

# Step 4: Set up channel montage on EEG channels
print("\n🗺️ Setting up channel montage for EEG channels...")

# Try BrainProducts montage first, then fallbacks
montage_candidates = [
    'brainproducts-RNP-BA-128',
    'standard_1020',
    'standard_1005', 
    'easycap-M1',
    'biosemi64'
]

montage_set = False
montage_name = None

for candidate_montage in montage_candidates:
    try:
        print(f"   🎯 Trying {candidate_montage} montage...")
        montage = mne.channels.make_standard_montage(candidate_montage)
        
        # Apply montage to EEG-only data
        raw_eeg_only.set_montage(montage, match_case=False, on_missing='ignore', verbose=False)
        
        # Count channels that received a position. With on_missing='ignore', channels
        # absent from the montage can end up with NaN coordinates (and NaN != 0 is True),
        # so require finite, non-zero locations
        has_pos = [bool(np.all(np.isfinite(ch['loc'][:3])) and np.any(ch['loc'][:3] != 0))
                   for ch in raw_eeg_only.info['chs']]
        n_matched = sum(has_pos)
        match_percentage = (n_matched / len(raw_eeg_only.ch_names)) * 100
        
        print(f"      📍 Matched: {n_matched}/{len(raw_eeg_only.ch_names)} channels ({match_percentage:.1f}%)")
        
        if n_matched > len(raw_eeg_only.ch_names) * 0.5:  # More than 50% matched
            print(f"   ✅ Successfully applied {candidate_montage} montage")
            montage_name = candidate_montage
            montage_set = True
            
            # Show matched/unmatched channels
            matched_channels = [name for name, ok in zip(raw_eeg_only.ch_names, has_pos) if ok]
            unmatched_channels = [name for name, ok in zip(raw_eeg_only.ch_names, has_pos) if not ok]
            
            if len(matched_channels) <= 10:
                print(f"      ✅ Matched channels: {matched_channels}")
            else:
                print(f"      ✅ Matched channels (first 10): {matched_channels[:10]}...")
                
            if unmatched_channels:
                if len(unmatched_channels) <= 5:
                    print(f"      ⚠️ Unmatched channels: {unmatched_channels}")
                else:
                    print(f"      ⚠️ Unmatched channels (first 5): {unmatched_channels[:5]}...")
            
            break
        else:
            print(f"      ❌ Insufficient matches ({match_percentage:.1f}%), trying next montage...")
            
    except Exception as e:
        print(f"      ❌ Failed to apply {candidate_montage}: {e}")
        continue

if not montage_set:
    print("⚠️ No standard montage provided sufficient channel matches.")
    print("   Proceeding without spatial information (some analyses may be limited).")

# Step 5: Apply montage back to full dataset (including EOG)
if montage_set:
    print(f"\n🔄 Applying {montage_name} montage to full dataset...")
    try:
        montage = mne.channels.make_standard_montage(montage_name)
        raw_processed.set_montage(montage, match_case=False, on_missing='ignore', verbose=False)
        print("   ✅ Montage applied to full dataset (positions are assigned to EEG channels; EOG channels keep none)")
    except Exception as e:
        print(f"   ⚠️ Could not apply montage to full dataset: {e}")

# Step 6: Set processing dataset (use EEG-only for main processing)
print(f"\n🎯 Setting up processing dataset...")
print("   📋 For preprocessing pipeline, we'll use EEG-only data")
print("   📋 EOG channels will be available for artifact removal when needed")

# Use EEG-only data as the main processing dataset
raw_for_processing = raw_eeg_only.copy()

print(f"\n📊 Final setup summary:")
print(f"   Original dataset: {len(raw.ch_names)} channels")
print(f"   EEG channels: {len(eeg_channels)} channels")
print(f"   EOG channels: {len(eog_channels)} channels")
print(f"   Processing dataset: {len(raw_for_processing.ch_names)} EEG channels")
print(f"   Montage applied: {montage_name if montage_set else 'None'}")

# Channel locations (if montage was set)
if montage_set:
    print("   🗺️ Plotting channel locations...")
    try:
        raw_processed.plot_sensors(kind='topomap', show_names=True, show=False)
        plt.title('All Channel Locations')
        plt.gcf().set_constrained_layout(True)
        plt.show()
        
    except Exception as e:
        print(f"   ⚠️ Could not plot channel locations: {e}")

print(f"\n✅ Channel setup and montage configuration completed!")
print(f"   📊 Ready to proceed with preprocessing pipeline using {len(raw_for_processing.ch_names)} EEG channels")
print(f"   👁️ EOG channels ({eog_channels}) remain available in full dataset for artifact removal")

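# Keep a handle on the full dataset (EEG + EOG, with montage) for EOG artifact removal,
# then continue the pipeline on the EEG-only copy
raw_with_eog = raw_processed
raw_processed = raw_for_processing  # This will be used for the rest of preprocessing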

# Store references for later use
eog_channels = eog_channels_found  # Make sure eog_channels variable is available globally
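The EOG search above checks each pattern as a substring of the upper-cased channel name. Moving that logic into a function makes it testable and lets you choose between substring matching, which also catches names such as `heog_left`, and stricter exact matching. The sketch below does this. `find_eog_channels` is a hypothetical helper, and the channel list is illustrative. The result can be passed to `raw.set_channel_types({ch: 'eog' for ch in found})`.

```python
def find_eog_channels(ch_names, patterns, exact=False):
    """Return channel names that look like EOG channels.

    exact=False: a pattern may appear anywhere in the upper-cased name (as in the cell above).
    exact=True:  the whole upper-cased name must equal one of the patterns.
    """
    patterns = [p.upper() for p in patterns]
    found = []
    for ch in ch_names:
        name = ch.upper()
        hit = name in patterns if exact else any(p in name for p in patterns)
        if hit:
            found.append(ch)
    return found


patterns = ["VEOG", "HEOG", "EOG1", "EOG2", "LO1", "LO2", "IO1", "IO2"]
channels = ["Fp1", "Cz", "VEOG", "heog_left", "LO1", "O1"]

print(find_eog_channels(channels, patterns))              # ['VEOG', 'heog_left', 'LO1']
print(find_eog_channels(channels, patterns, exact=True))  # ['VEOG', 'LO1']
```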

## 5. Preprocessing Pipeline {#preprocessing-pipeline}

**Overview**: This section applies the complete CleanEEG preprocessing workflow following the DISCOVER-EEG framework. Each step targets specific types of artifacts while preserving neural signals.

**Pipeline Logic**: Steps are ordered to handle the largest artifacts first (line noise, drifts) before more sophisticated analyses (ICA, ASR) that work better on cleaner data. Quality metrics after each step confirm improvements.

**Key Principle**: Every preprocessing step should improve signal quality. We'll monitor this with quantitative metrics throughout.
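The step-then-measure pattern used throughout this section reduces to a loop over named processing functions that records a quality metric after each one. The sketch below runs that loop on a plain NumPy array. `run_pipeline`, the two steps, and the metric are hypothetical stand-ins; the notebook itself operates on MNE `Raw` objects and uses `compute_snr`.

```python
import numpy as np


def run_pipeline(data, steps, metric):
    """Apply (name, function) steps in order; log metric(data) before and after each."""
    log = {"Original": metric(data)}
    for name, step in steps:
        data = step(data)
        log[name] = metric(data)
    return data, log


def peak_amplitude(x):
    """Stand-in metric: largest absolute value across channels and samples."""
    return float(np.abs(x).max())


# Stand-in steps: remove each channel's mean (an offset/drift proxy), then clip outliers
steps = [
    ("Demean", lambda x: x - x.mean(axis=1, keepdims=True)),
    ("Clip", lambda x: np.clip(x, -3.0, 3.0)),
]

rng = np.random.default_rng(1)
data = rng.standard_normal((4, 1000)) + 5.0  # 4 "channels" with a constant offset
cleaned, log = run_pipeline(data, steps, peak_amplitude)
for name, value in log.items():
    print(f"{name:>8}: {value:.2f}")
```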

Now we'll apply the complete CleanEEG preprocessing pipeline, monitoring quality at each step:

### 5.1 Line Noise Removal (Denoising Source Separation) {#line-noise}

**Purpose**: Remove electrical interference from power lines (50/60 Hz) that contaminates EEG signals.

**Why needed**: Power line noise creates strong, narrow-band artifacts that can overwhelm neural signals and distort frequency analysis. This interference comes from electrical equipment and building wiring.

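**Method**: Denoising Source Separation (DSS) identifies the spatial components whose power is concentrated at the line frequency and removes only those. A notch filter, by contrast, suppresses everything near 50/60 Hz on every channel. DSS leaves neural activity at those frequencies largely intact, provided that activity is not captured by the removed components.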
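The core idea behind DSS line-noise removal can be sketched briefly. First, find the spatial filter that maximizes power near the line frequency relative to broadband power, which is a generalized eigenvalue problem between the narrow-band and full covariance matrices. Then regress the resulting component out of every channel. This simplified illustration on synthetic data assumes a single line-noise source and no harmonics. It is not `meegkit.dss.dss_line`, which implements the more complete ZapLine procedure. `remove_line_component` and `power_at` are hypothetical helpers.

```python
import numpy as np
from scipy import linalg, signal


def remove_line_component(X, sfreq, fline, n_remove=1, half_width=1.0):
    """Project out the spatial component(s) most dominated by `fline` Hz.

    X : array, shape (n_samples, n_channels). Returns demeaned, cleaned data.
    """
    X = X - X.mean(axis=0)
    # Bias filter: narrow band-pass around the line frequency
    sos = signal.butter(4, [fline - half_width, fline + half_width],
                        btype="bandpass", fs=sfreq, output="sos")
    X_line = signal.sosfiltfilt(sos, X, axis=0)
    C_full = X.T @ X / len(X)
    C_line = X_line.T @ X_line / len(X)
    # Generalized eigenproblem: maximize line-band power relative to total power
    _, W = linalg.eigh(C_line, C_full)      # eigenvalues in ascending order
    W = W[:, ::-1][:, :n_remove]            # spatial filters, strongest first
    sources = X @ W                         # component time courses
    # Least-squares topographies of the components, then subtract their contribution
    A = np.linalg.lstsq(sources, X, rcond=None)[0]
    return X - sources @ A


def power_at(Y, f, sfreq):
    """Mean Welch PSD across channels at the bin closest to f."""
    freqs, psd = signal.welch(Y, sfreq, nperseg=int(2 * sfreq), axis=0)
    return psd[np.argmin(np.abs(freqs - f))].mean()


rng = np.random.default_rng(0)
sfreq, fline, n_ch = 250.0, 50.0, 8
t = np.arange(0, 20, 1 / sfreq)
brain = rng.standard_normal((t.size, n_ch))   # independent "neural" activity
mixing = rng.uniform(0.5, 1.5, n_ch)          # mains pickup strength per channel
X = brain + 5 * np.outer(np.sin(2 * np.pi * fline * t), mixing)

Y = remove_line_component(X, sfreq, fline)
print(f"{fline:.0f} Hz power: {power_at(X, fline, sfreq):.3g} -> {power_at(Y, fline, sfreq):.3g}")
print(f"10 Hz power: {power_at(X, 10, sfreq):.3g} -> {power_at(Y, 10, sfreq):.3g}")
```

Because only one of eight spatial dimensions is removed, power at 10 Hz drops only modestly, while the 50 Hz peak disappears.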

In [None]:
from meegkit import dss

print("⚡ Removing line noise using Denoising Source Separation (DSS)...")

# Line frequency depends on geographical location:
# - 50 Hz: most of Europe, Africa, Asia, and Oceania
# - 60 Hz: North America, parts of South America, and some Asian countries
#   (e.g. South Korea, Taiwan, the Philippines; Japan uses both)
# The LEMON dataset was recorded in Germany, so we use 50 Hz
line_freq = 50
print(f"   🌍 Using {line_freq} Hz line frequency (European power grid standard)")

# Get the EEG data
data = raw_processed.get_data()
sfreq = raw_processed.info['sfreq']

# Optional: Verify line noise presence at target frequency
print(f"   🔍 Checking line noise power at {line_freq} Hz...")
freqs, psd = signal.welch(data, sfreq, nperseg=int(2*sfreq))
freq_idx = np.argmin(np.abs(freqs - line_freq))

power_line_freq = np.mean(psd[:, freq_idx])
print(f"      - Power at {line_freq} Hz: {power_line_freq:.2e}")

# Compare to neighboring frequencies for context
neighbor_freqs = [line_freq - 2, line_freq + 2]  # ±2 Hz from line frequency
neighbor_powers = []
for neighbor_freq in neighbor_freqs:
    neighbor_idx = np.argmin(np.abs(freqs - neighbor_freq))
    neighbor_power = np.mean(psd[:, neighbor_idx])
    neighbor_powers.append(neighbor_power)
    
avg_neighbor_power = np.mean(neighbor_powers)
line_noise_ratio = power_line_freq / avg_neighbor_power
print(f"      - Average power at {neighbor_freqs[0]}-{neighbor_freqs[1]} Hz: {avg_neighbor_power:.2e}")
print(f"      - Line noise prominence: {line_noise_ratio:.1f}x above neighbors")

# Apply DSS line noise removal
try:
    print(f"   🔧 Applying DSS to remove {line_freq} Hz line noise...")
    processed_data, artifacts = dss.dss_line(
        data.T,                          # Data must be (time x channels)
        fline=line_freq,                 # Line frequency to target
        sfreq=sfreq,                     # Sampling frequency
        show=False                       # Don't show plots
    )
    
    # Update the data in our raw object
    raw_processed._data = processed_data.T  # Convert back to (channels x time)
    
    print(f"   ✅ DSS line noise removal completed")
    print(f"      - Targeted frequency: {line_freq} Hz")
    print(f"      - Method: Denoising Source Separation")
    
    # Store version
    raw_versions['After Line Noise Removal'] = raw_processed.copy()
    
    # Quality assessment
    snr_after_line = compute_snr(raw_processed, method='spectral')
    processing_log['After Line Noise'] = snr_after_line
    print_snr_summary(snr_after_line, "After Line Noise Removal")
    
except Exception as e:
    print(f"   ⚠️ DSS line noise removal failed: {e}")
    print("   🔄 Falling back to notch filter...")
    
    # Fallback to notch filter
    raw_processed.notch_filter(freqs=[line_freq], verbose=False)
    
    print(f"   ✅ Notch filter applied at {line_freq} Hz")
    
    raw_versions['After Line Noise Removal'] = raw_processed.copy()
    snr_after_line = compute_snr(raw_processed, method='spectral')
    processing_log['After Line Noise'] = snr_after_line
    print_snr_summary(snr_after_line, "After Notch Filter")

# Verify line noise reduction
print("   📊 Verifying line noise reduction...")
data_after = raw_processed.get_data()
freqs_after, psd_after = signal.welch(data_after, sfreq, nperseg=int(2*sfreq))
freq_idx_after = np.argmin(np.abs(freqs_after - line_freq))
power_line_freq_after = np.mean(psd_after[:, freq_idx_after])

reduction_db = 10 * np.log10(power_line_freq / (power_line_freq_after + 1e-12))
print(f"      - Line noise reduction: {reduction_db:.1f} dB at {line_freq} Hz")

# Plot comparison
plot_psd_comparison(
    [raw_versions['Original'], raw_versions['After Line Noise Removal']], 
    ['Original', 'After Line Noise Removal'],
    "PSD: Before vs After Line Noise Removal"
)

### 5.2 Bandpass Filter {#bandpass-filter}

**Purpose**: Remove slow drifts, baseline shifts, and high-frequency noise from EEG recordings.

**Why needed**: Very low-frequency drifts (<1 Hz) arise from electrode movement, changes in skin conductance (e.g. sweating), and electrode or amplifier instabilities. High-frequency noise (>100 Hz) comes from electrical interference, muscle activity, and amplifier noise. Both distort spectral and statistical analyses, and slow drifts in particular make the data non-stationary.

**Method**: A 1-100 Hz bandpass filter combines:
- **Highpass component (1 Hz)**: Removes slow drifts while keeping most neural frequencies of interest (the delta band is conventionally 1-4 Hz, so activity below 1 Hz is attenuated)
- **Lowpass component (100 Hz)**: Removes high-frequency noise while retaining the relevant neural oscillations, including gamma activity (30-100 Hz)
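A quick way to see what the 1-100 Hz passband does is to filter a synthetic signal containing a slow drift, a 10 Hz rhythm, and a 150 Hz component. This sketch uses SciPy's `firwin` and `filtfilt` rather than MNE, so it illustrates the idea and is not MNE's exact filter design: MNE chooses the filter length and transition bands automatically, whereas here the 1651-tap length (roughly a 1 Hz transition band at 500 Hz) is an assumption, and `filtfilt` runs the filter forward and backward, which squares its magnitude response.

```python
import numpy as np
from scipy import signal

sfreq = 500.0
t = np.arange(int(20 * sfreq)) / sfreq          # 20 s of data
drift = 5.0 * np.sin(2 * np.pi * 0.1 * t)       # slow baseline drift
alpha = 1.0 * np.sin(2 * np.pi * 10.0 * t)      # 10 Hz "neural" rhythm
hf_noise = 1.0 * np.sin(2 * np.pi * 150.0 * t)  # activity above 100 Hz
x = drift + alpha + hf_noise

# Linear-phase FIR band-pass, 1-100 Hz (Hamming window is firwin's default)
taps = signal.firwin(1651, [1.0, 100.0], pass_zero=False, fs=sfreq)
# Forward-backward filtering cancels the filter's phase delay
y = signal.filtfilt(taps, [1.0], x)


def amplitude(sig, f):
    """Amplitude of the DFT bin closest to f (exact for on-bin sinusoids)."""
    spectrum = np.abs(np.fft.rfft(sig)) * 2 / sig.size
    freqs = np.fft.rfftfreq(sig.size, 1 / sfreq)
    return spectrum[np.argmin(np.abs(freqs - f))]


amp_before = {f: amplitude(x, f) for f in (0.1, 10.0, 150.0)}
amp_after = {f: amplitude(y, f) for f in (0.1, 10.0, 150.0)}
for f in amp_before:
    print(f"{f:6.1f} Hz: {amp_before[f]:.3f} -> {amp_after[f]:.3f}")
```

The 10 Hz component should pass almost unchanged, while the drift and the 150 Hz component are strongly attenuated.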

In [None]:
print("🔽 Applying bandpass filter...")

# 1-100 Hz bandpass: the 1 Hz high-pass removes slow drifts,
# the 100 Hz low-pass removes high-frequency noise
hp_freq = 1.0
lp_freq = 100.0
raw_processed.filter(
    l_freq=hp_freq,   # High-pass cutoff
    h_freq=lp_freq,   # Low-pass cutoff
    method='fir',     # Finite Impulse Response filter
    verbose=False
)

print(f"   ✅ Highpass filter applied at {hp_freq} Hz and Lowpass filter applied at {lp_freq} Hz")

# Store version
raw_versions['After Bandpass Filter'] = raw_processed.copy()

# Quality assessment
snr_after_hp = compute_snr(raw_processed, method='spectral')
processing_log['After Bandpass'] = snr_after_hp
print_snr_summary(snr_after_hp, "After Bandpass Filter")

# Plot comparison
plot_psd_comparison(
    [raw_versions['After Line Noise Removal'], raw_versions['After Bandpass Filter']], 
    ['After Line Noise Removal', 'After Bandpass Filter'],
    "PSD: Before vs After Bandpass Filter"
)

### 5.3 Downsample Data {#downsample}

**Purpose**: Reduce data size and computational load while preserving all relevant neural information.

**Why needed**: Many EEG systems record at very high sampling rates (>1000 Hz) to avoid aliasing, but most EEG analysis only requires frequencies up to 100-200 Hz. High sampling rates create unnecessarily large files and slow processing.

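**Method**: Downsample to 500 Hz, which gives a Nyquist frequency of 250 Hz, well above the 100 Hz low-pass cutoff applied in the previous step. MNE's `resample` low-pass filters the data as part of resampling, so content above the new Nyquist frequency does not alias into the retained band.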
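Why the anti-aliasing step matters: the sketch below samples a 10 Hz rhythm plus a 400 Hz component at 2500 Hz and reduces it to 500 Hz in two ways, by naive decimation (keeping every 5th sample) and with SciPy's `resample_poly`, which applies an FIR low-pass before decimating. This is a SciPy illustration, not the method MNE's `resample` uses internally. The 400 Hz component lies above the new 250 Hz Nyquist frequency, so naive decimation folds it to |400 - 500| = 100 Hz.

```python
import numpy as np
from scipy import signal

orig_sfreq, new_sfreq = 2500, 500
factor = orig_sfreq // new_sfreq                  # 5
t = np.arange(2 * orig_sfreq) / orig_sfreq        # 2 s of data
x = np.sin(2 * np.pi * 10 * t) + np.sin(2 * np.pi * 400 * t)

naive = x[::factor]                               # no anti-aliasing filter
filtered = signal.resample_poly(x, up=1, down=factor)  # low-pass, then decimate


def amplitude(sig, f, sfreq):
    """Amplitude of the DFT bin closest to f."""
    spectrum = np.abs(np.fft.rfft(sig)) * 2 / sig.size
    freqs = np.fft.rfftfreq(sig.size, 1 / sfreq)
    return spectrum[np.argmin(np.abs(freqs - f))]


alias_naive = amplitude(naive, 100, new_sfreq)
alias_filtered = amplitude(filtered, 100, new_sfreq)
ten_hz_filtered = amplitude(filtered, 10, new_sfreq)
print(f"Spurious 100 Hz amplitude: naive={alias_naive:.3f}, "
      f"resample_poly={alias_filtered:.3f}")
```

With naive decimation the 400 Hz component shows up as a full-amplitude fake 100 Hz oscillation; with the filtered version it is almost entirely removed while the 10 Hz rhythm is preserved.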

In [None]:
print("📉 Downsampling data...")

# Downsample to 500 Hz (adequate for most EEG analyses)
target_sfreq = 500
original_sfreq = raw_processed.info['sfreq']

if original_sfreq > target_sfreq:
    raw_processed.resample(target_sfreq, verbose=False)
    print(f"   ✅ Downsampled from {original_sfreq} Hz to {target_sfreq} Hz")
    print(f"   📊 Data reduction: {(1 - target_sfreq/original_sfreq)*100:.1f}%")
else:
    print(f"   ℹ️ No downsampling needed (current: {original_sfreq} Hz)")

# Store version
raw_versions['After Downsampling'] = raw_processed.copy()

# Quality assessment
snr_after_downsample = compute_snr(raw_processed, method='spectral')
processing_log['After Downsample'] = snr_after_downsample
print_snr_summary(snr_after_downsample, "After Downsampling")

# Plot comparison up to 100 Hz (the low-pass cutoff), or up to the new Nyquist frequency if that is lower
nyquist_freq = min(target_sfreq/2, 100)
plot_psd_comparison(
    [raw_versions['After Bandpass Filter'], raw_versions['After Downsampling']], 
    ['Before Downsampling', 'After Downsampling'],
    "PSD: Before vs After Downsampling",
    fmax=nyquist_freq
)

### 5.4 Bad Channel Rejection (PREP Pipeline) {#bad-channels}

**Purpose**: Automatically identify and mark electrodes that are not recording valid neural signals.

**Why needed**: EEG electrodes can malfunction due to poor skin contact, broken wires, high impedance, or movement artifacts. Bad channels introduce noise and can distort spatial analyses, source localization, and connectivity measures.

**Method**: The PREP pipeline combines several criteria: NaN or flat signals, extreme amplitude deviation (a robust z-score of each channel's amplitude), excessive high-frequency noise, low correlation with other channels, and poor predictability from the other channels (RANSAC).
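A minimal sketch of one of these criteria, robust amplitude deviation, follows. It mirrors the idea behind PREP's deviation check (a robust per-channel amplitude estimate, then a robust z-score across channels), but the helper name `flag_by_robust_deviation` and the threshold of 5 are assumptions for illustration, not pyprep's API.

```python
import numpy as np
from scipy.stats import iqr


def flag_by_robust_deviation(data, ch_names, threshold=5.0):
    """Flag channels whose robust amplitude is an outlier (hypothetical helper).

    data: (n_channels, n_samples). Amplitude is estimated per channel from the
    interquartile range (IQR / 1.349 approximates the SD for Gaussian data);
    channels are then z-scored with the median and MAD across channels.
    """
    amplitude = iqr(data, axis=1) / 1.349
    med = np.median(amplitude)
    mad_sd = 1.4826 * np.median(np.abs(amplitude - med))  # robust SD
    z = (amplitude - med) / mad_sd
    return [ch for ch, score in zip(ch_names, z) if abs(score) > threshold]


rng = np.random.default_rng(42)
ch_names = [f"EEG{i:02d}" for i in range(32)]
data = rng.standard_normal((32, 10_000)) * 10e-6  # ~10 µV background
data[3] *= 20                                     # very noisy channel
data[7] = 0.0                                     # flat channel (disconnected)
bads = flag_by_robust_deviation(data, ch_names)
print(bads)
```

Using the median and MAD instead of the mean and SD keeps the bad channels themselves from inflating the spread estimate and masking each other.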

In [None]:
print("🔍 Detecting bad channels using PREP pipeline...")

try:
    from pyprep.find_noisy_channels import NoisyChannels
    
    # Create a NoisyChannels object
    nd = NoisyChannels(raw_processed, random_state=42)
    
    # Find all types of bad channels
    nd.find_all_bads(ransac=True, channel_wise=True, max_chunk_size=None)
    
    # Get the detected bad channels
    bad_channels = nd.get_bads()
    
    if bad_channels:
        print(f"   🚫 Detected bad channels: {bad_channels}")
        
        # Mark channels as bad
        raw_processed.info['bads'] = bad_channels
        
        # Show channel types found
        for bad_type in ['bad_by_nan', 'bad_by_flat', 'bad_by_deviation',
                         'bad_by_hf_noise', 'bad_by_correlation', 'bad_by_SNR',
                         'bad_by_dropout', 'bad_by_ransac']:
            bad_chans = getattr(nd, bad_type, [])
            if bad_chans:
                print(f"      - {bad_type.replace('bad_by_', '').replace('_', ' ').title()}: {bad_chans}")
    else:
        print("   ✅ No bad channels detected")
    
    prep_success = True
    
except Exception as e:
    print(f"   ⚠️ PREP pipeline failed: {e}")
    print("   Using simple statistical bad channel detection...")
    
    # Fallback: simple statistical method
    data = raw_processed.get_data()
    
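    # Find channels whose variance is an outlier relative to the other channels.
    # A robust z-score on log-variance (median / MAD) is used instead of fixed
    # percentiles, which would flag the top and bottom 5% of channels every time.
    channel_vars = np.var(data, axis=1)
    log_vars = np.log10(channel_vars + np.finfo(float).tiny)
    median_log_var = np.median(log_vars)
    mad = 1.4826 * np.median(np.abs(log_vars - median_log_var))
    robust_z = (log_vars - median_log_var) / (mad if mad > 0 else 1.0)
    
    flat = channel_vars == 0
    bad_channels = [ch_name for ch_name, z, is_flat in zip(raw_processed.ch_names, robust_z, flat)
                    if is_flat or abs(z) > 3.5]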
    
    if bad_channels:
        print(f"   🚫 Detected bad channels (statistical): {bad_channels}")
        raw_processed.info['bads'] = bad_channels
    else:
        print("   ✅ No bad channels detected (statistical method)")
    
    prep_success = False

# Store version
raw_versions['After Bad Channel Detection'] = raw_processed.copy()

# Quality assessment (excluding bad channels)
snr_after_bad_chans = compute_snr(raw_processed, method='spectral')
processing_log['After Bad Channels'] = snr_after_bad_chans
print_snr_summary(snr_after_bad_chans, "After Bad Channel Detection")

print(f"\n📋 Bad Channel Summary:")
print(f"   Total channels: {len(raw_processed.ch_names)}")
print(f"   Bad channels: {len(raw_processed.info['bads'])}")
print(f"   Good channels: {len(raw_processed.ch_names) - len(raw_processed.info['bads'])}")

### 5.5 EOG Artifact Removal {#eog-removal}

**Purpose**: Remove eye movement and blink artifacts that contaminate frontal EEG electrodes.

**Why needed**: Eye movements and blinks generate large electrical potentials (much larger than neural signals) that spread to frontal and temporal EEG channels. These artifacts can completely obscure brain activity and create false patterns in analysis.

**Method**: EOG regression uses dedicated eye movement channels to model and subtract eye-related artifacts from EEG channels, preserving underlying neural activity.
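The core of regression-based EOG removal can be sketched with plain least squares: estimate how strongly each EOG channel leaks into each EEG channel, then subtract the predicted contribution. The helper `regress_out_eog` below is hypothetical; the pipeline itself uses MNE's `EOGRegression`. One caveat: real EOG channels also pick up some frontal brain activity, so regression can remove a little neural signal along with the ocular artifact.

```python
import numpy as np


def regress_out_eog(eeg, eog):
    """Remove EOG from EEG by multi-channel least squares (hypothetical helper).

    eeg: (n_eeg, n_times), eog: (n_eog, n_times).
    Fits eeg ≈ B @ eog jointly over all EOG channels and subtracts B @ eog.
    Returns the cleaned EEG and the estimated propagation matrix B.
    """
    eog_c = eog - eog.mean(axis=1, keepdims=True)
    eeg_c = eeg - eeg.mean(axis=1, keepdims=True)
    # Solve eog_c.T @ B.T ≈ eeg_c.T for B.T, which has shape (n_eog, n_eeg)
    coef_t, *_ = np.linalg.lstsq(eog_c.T, eeg_c.T, rcond=None)
    return eeg - coef_t.T @ eog_c, coef_t.T


rng = np.random.default_rng(1)
n_times = 5000
brain = rng.standard_normal((4, n_times))        # "true" neural signal
eog = rng.standard_normal((2, n_times)) * 20      # large ocular signals
true_b = np.array([[0.8, 0.1], [0.5, 0.3], [0.2, 0.0], [0.05, 0.4]])
eeg = brain + true_b @ eog                       # ocular leakage into EEG

cleaned, b_hat = regress_out_eog(eeg, eog)
print(np.round(b_hat, 2))
```

Fitting all EOG channels in one least-squares problem accounts for correlations between them (for example, between horizontal and vertical EOG).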

In [None]:
print("👁️ Simple EOG artifact removal...")

# Check if we have EOG channels
if 'eog_channels' in locals() and eog_channels and len(eog_channels) > 0:
    print(f"   📍 Found {len(eog_channels)} EOG channels: {eog_channels}")
    
    try:
        from mne.preprocessing import EOGRegression
        
        print("   🔧 Preparing data for EOG regression...")
        
        # Build a combined EEG + EOG object from the *processed* EEG so that the
        # earlier steps (line noise removal, bandpass, downsampling, bad-channel
        # marks) are kept. The EOG channels come from the original recording,
        # so they get the same notch, bandpass, and resampling first.
        print("   📝 Aligning EOG channels with the processed EEG...")
        raw_eog = raw.copy().pick(eog_channels).load_data()
        raw_eog.set_channel_types({ch: 'eog' for ch in eog_channels})
        raw_eog.notch_filter(freqs=[line_freq], picks='eog', verbose=False)
        raw_eog.filter(l_freq=hp_freq, h_freq=lp_freq, picks='eog',
                       method='fir', verbose=False)
        if raw_eog.info['sfreq'] != raw_processed.info['sfreq']:
            raw_eog.resample(raw_processed.info['sfreq'], verbose=False)
        
        raw_full = raw_processed.copy()
        raw_full.add_channels([raw_eog], force_update_info=True)
        
        eeg_ch_names = [ch for ch in raw_full.ch_names if ch not in eog_channels]
        print(f"   ✅ Channel types set: {len(eeg_ch_names)} EEG, {len(eog_channels)} EOG")
        
        # Average-reference the EEG channels; MNE's regression workflow expects
        # an EEG reference to be set before fitting
        print("   📊 Setting average reference for EEG channels...")
        raw_full.set_eeg_reference('average', projection=False, verbose=False)
        print("   ✅ Average reference set")
        
        # Verify we have both EEG and EOG channels
        eeg_picks = mne.pick_types(raw_full.info, eeg=True)
        eog_picks = mne.pick_types(raw_full.info, eog=True)
        
        print(f"   🔍 Verification - EEG channels found: {len(eeg_picks)}, EOG channels found: {len(eog_picks)}")
        
        if len(eeg_picks) == 0:
            raise ValueError("No EEG channels found for EOG regression")
        if len(eog_picks) == 0:
            raise ValueError("No EOG channels found for EOG regression")
        
        # Apply EOG regression
        print("   🔄 Applying EOG regression...")
        eog_regression = EOGRegression(
            picks='eeg',           # Apply to EEG channels
            picks_artifact='eog'   # Use EOG channels as reference
        )
        
        # Fit the regression model
        print("   📈 Fitting EOG regression model...")
        eog_regression.fit(raw_full)
        print("   ✅ EOG regression model fitted")
        
        # Apply regression to get cleaned data
        raw_cleaned = eog_regression.apply(raw_full, copy=True)
        print("   ✅ EOG regression applied")
        
        # Extract only EEG channels for continued processing
        eeg_channel_names = [ch for ch in raw_cleaned.ch_names if ch not in eog_channels]
        raw_processed_new = raw_cleaned.copy().pick(eeg_channel_names)
        
        print(f"   📊 Extracted {len(raw_processed_new.ch_names)} EEG channels after EOG removal")
        
        # Calculate improvement on the same (average-referenced) EEG channels
        data_before = raw_full.get_data(picks='eeg')
        data_after  = raw_processed_new.get_data(picks='eeg')
        
        variance_before = np.var(data_before)
        variance_after  = np.var(data_after)
        
        if variance_before > 0:
            variance_reduction = ((variance_before - variance_after) / variance_before) * 100
            print(f"   📈 Variance reduction: {variance_reduction:.1f}%")
        
        # Update processing dataset
        raw_processed = raw_processed_new
        
        # Store for comparison
        raw_versions['After EOG Removal'] = raw_processed.copy()
        
        # --- 🔧 QUALITY LOG UPDATE ---
        snr_after_eog = compute_snr(raw_processed, method='spectral')  # same SNR settings as the other steps
        processing_log['After EOG Removal'] = snr_after_eog
        print_snr_summary(snr_after_eog, 'After EOG Removal')
        # -------------------------------------------------------------------------
        
        # Show comparison if we have previous version
        if len(raw_versions) >= 2:
            version_keys = list(raw_versions.keys())
            plot_psd_comparison(
                [raw_versions[version_keys[-2]],
                 raw_versions[version_keys[-1]]],
                labels=[version_keys[-2], version_keys[-1]],
                title="PSD Before vs After EOG Removal"
            )
        
        print("   ✅ EOG regression completed successfully")
        
    except ImportError:
        print("   ⚠️ EOGRegression not available, skipping EOG removal")
        raw_versions['After EOG Removal'] = raw_processed.copy()
        
    except Exception as e:
        print(f"   ⚠️ EOG regression failed: {e}")
        print("   🔄 Trying alternative EOG removal approach...")
        
        try:
            # Alternative approach: multi-channel least-squares regression
            print("   📊 Using simple linear regression approach...")
            
            # Use the processed EEG, and bring the original EOG channels through
            # the same notch, bandpass, and resampling so both share a time base
            raw_eeg = raw_processed.copy()
            raw_eeg.set_eeg_reference('average', projection=False, verbose=False)
            raw_eog = raw.copy().pick(eog_channels).load_data()
            raw_eog.set_channel_types({ch: 'eog' for ch in eog_channels})
            raw_eog.notch_filter(freqs=[line_freq], picks='eog', verbose=False)
            raw_eog.filter(l_freq=hp_freq, h_freq=lp_freq, picks='eog', verbose=False)
            if raw_eog.info['sfreq'] != raw_eeg.info['sfreq']:
                raw_eog.resample(raw_eeg.info['sfreq'], verbose=False)
            
            # Get the data arrays
            eeg_data = raw_eeg.get_data()  # Shape: (n_eeg_channels, n_times)
            eog_data = raw_eog.get_data()  # Shape: (n_eog_channels, n_times)
            
            print(f"      EEG data shape: {eeg_data.shape}")
            print(f"      EOG data shape: {eog_data.shape}")
            
            # Fit all EOG channels jointly (EEG ≈ B @ EOG) rather than one at a
            # time, which accounts for correlations between the EOG channels
            eog_centered = eog_data - eog_data.mean(axis=1, keepdims=True)
            eeg_centered = eeg_data - eeg_data.mean(axis=1, keepdims=True)
            coef_t, *_ = np.linalg.lstsq(eog_centered.T, eeg_centered.T, rcond=None)
            eeg_cleaned = eeg_data - coef_t.T @ eog_centered
            
            # Create new raw object with cleaned EEG data
            raw_processed_new = mne.io.RawArray(eeg_cleaned, raw_eeg.info.copy(), verbose=False)
            
            print("   ✅ Alternative EOG removal applied")
            
            # Calculate improvement
            variance_before = np.var(eeg_data)
            variance_after  = np.var(eeg_cleaned)
            
            if variance_before > 0:
                variance_reduction = ((variance_before - variance_after) / variance_before) * 100
                print(f"   📈 Variance reduction: {variance_reduction:.1f}%")
            
            # Update processing dataset
            raw_processed = raw_processed_new
            
            # Store for comparison
            raw_versions['After EOG Removal'] = raw_processed.copy()
            
            # --- 🔧 QUALITY LOG UPDATE -----------------------
            snr_after_eog = compute_snr(raw_processed, method='spectral')
            processing_log['After EOG Removal'] = snr_after_eog
            print_snr_summary(snr_after_eog, 'After EOG Removal')
            # ---------------------------------------------------------------------
            
            print("   ✅ Alternative EOG removal completed successfully")
            
        except Exception as e2:
            print(f"   ⚠️ Alternative EOG removal also failed: {e2}")
            print("   📋 Continuing without EOG correction...")
            
            raw_versions['After EOG Removal'] = raw_processed.copy()
            
            if processing_log:
                last_key = list(processing_log.keys())[-1]
                processing_log['After EOG Removal'] = processing_log[last_key].copy()

else:
    print("   ℹ️ No EOG channels found - skipping EOG artifact removal")
    
    raw_versions['After EOG Removal'] = raw_processed.copy()
    
    if processing_log:
        last_key = list(processing_log.keys())[-1]
        processing_log['After EOG Removal'] = processing_log[last_key].copy()
        print_snr_summary(processing_log['After EOG Removal'],
                          "No EOG Removal (Skipped)")

print(f"\n✅ EOG processing step completed")
print(f"   📊 Processing dataset: {len(raw_processed.ch_names)} EEG channels")
print(f"   👁️ EOG regression: {'Applied' if eog_channels and len(eog_channels) > 0 else 'Skipped'}")

print(f"\n🚀 Ready to continue with next preprocessing step!")


### 5.6 Independent Component Analysis (ICA) {#ica}

**Purpose**: Separate mixed EEG signals into independent components and remove non-neural artifacts.

**Why needed**: EEG signals are mixtures of neural activity, muscle artifacts, heart beats, eye movements, and other noise sources. These artifacts can't always be removed by simple filtering and often overlap with neural frequencies.

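**Method**: ICA decomposes the multichannel signal into statistically independent components. ICLabel then classifies each component as brain, muscle artifact, eye blink, heart beat, line noise, channel noise, or other, so that artifactual components can be removed selectively while neural components are kept. ICLabel was trained on extended-Infomax decompositions of common-average-referenced EEG filtered at 1-100 Hz, which is why the data are band-passed to 1-100 Hz and average-referenced before fitting.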
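The sketch below shows the two halves of an ICA clean-up on synthetic data: a compact FastICA (deflation, `tanh` nonlinearity) that estimates the unmixing matrix, and the removal step that drops an artifact component and back-projects the rest, which is conceptually what setting `ica.exclude` and calling `ica.apply` does. The helper `fastica_unmixing` is hypothetical and unoptimized, not MNE's implementation, and the ICLabel classification step is replaced here by simply picking the component most correlated with the known "artifact" source.

```python
import numpy as np
from scipy import signal


def fastica_unmixing(X, max_iter=500, tol=1e-8, seed=0):
    """Estimate an unmixing matrix so that unmixing @ (X - mean) ≈ sources.

    X: (n_channels, n_samples). Deflation FastICA with g = tanh
    (hypothetical helper for illustration).
    """
    Xc = X - X.mean(axis=1, keepdims=True)
    d, E = np.linalg.eigh(np.cov(Xc))
    K = E @ np.diag(d ** -0.5) @ E.T          # symmetric whitening matrix
    Z = K @ Xc                                # whitened data, cov ≈ I
    n = Z.shape[0]
    rng = np.random.default_rng(seed)
    W = np.zeros((n, n))
    for i in range(n):
        w = rng.standard_normal(n)
        w /= np.linalg.norm(w)
        for _ in range(max_iter):
            g = np.tanh(w @ Z)
            # Fixed-point update: E[z g(w'z)] - E[g'(w'z)] w
            w_new = (Z * g).mean(axis=1) - (1 - g ** 2).mean() * w
            w_new -= W[:i].T @ (W[:i] @ w_new)  # orthogonalize to found rows
            w_new /= np.linalg.norm(w_new)
            done = abs(abs(w_new @ w) - 1) < tol
            w = w_new
            if done:
                break
        W[i] = w
    return W @ K


t = np.linspace(0, 10, 4000, endpoint=False)
sources = np.vstack([
    np.sin(2 * np.pi * 1.3 * t),              # "neural" rhythm
    signal.sawtooth(2 * np.pi * 1.9 * t),     # another "neural" source
    signal.square(2 * np.pi * 0.7 * t),       # "artifact" (step-like)
])
A = np.array([[1.0, 0.5, 0.8],
              [0.3, 1.0, 0.6],
              [0.7, 0.4, 1.0]])               # channel mixing
X = A @ sources

unmixing = fastica_unmixing(X)
S_hat = unmixing @ (X - X.mean(axis=1, keepdims=True))
A_hat = np.linalg.pinv(unmixing)

# Stand-in for ICLabel: pick the component most correlated with the artifact
corr = [abs(np.corrcoef(s, sources[2])[0, 1]) for s in S_hat]
artifact_idx = int(np.argmax(corr))
keep = [k for k in range(3) if k != artifact_idx]
X_clean = A_hat[:, keep] @ S_hat[keep]        # back-project kept components
print(f"artifact component: IC{artifact_idx}, |r| = {max(corr):.3f}")
```

Note that removal happens in component space: each channel loses exactly the part of its signal explained by the excluded component, and nothing else.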

In [None]:
print("🧠 Performing Independent Component Analysis (ICA)...")

try:
    from mne.preprocessing import ICA
    from mne_icalabel import label_components
    
    # Create a copy for ICA (excluding bad channels)
    raw_for_ica = raw_processed.copy().pick('eeg', exclude='bads')
    
    # Re-reference to average for better ICA decomposition
    raw_for_ica.set_eeg_reference('average', projection=False, verbose=False)
    
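    # Create ICA object; extended Infomax matches the decompositions ICLabel
    # was trained on (mne-icalabel warns when another algorithm is used)
    ica = ICA(
        n_components=None,
        method='infomax',
        fit_params=dict(extended=True),
        random_state=42,
        max_iter='auto'
    )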
    
    print(f"   🔧 Fitting ICA ...")
    ica.fit(raw_for_ica, verbose=False)
    
    # Use ICLabel for automatic component classification
    print("   🏷️ Classifying components with ICLabel...")
    ic_labels = label_components(raw_for_ica, ica, method='iclabel')
    
    # Get component labels and probabilities
    labels = ic_labels['labels']
    probabilities = ic_labels['y_pred_proba']
    
    # Define components to exclude (artifacts)
    artifact_types = ['muscle artifact', 'eye blink', 'heart beat', 'line noise', 'channel noise']
    
    exclude_idx = []
    component_summary = {}
    
    for i, (label, probs) in enumerate(zip(labels, probabilities)):
        component_summary[f'IC{i:02d}'] = {
            'label': label,
            'confidence': np.max(probs)
        }
        
        # Exclude components that are artifacts with high confidence
        if label in artifact_types and np.max(probs) > 0.7:
            exclude_idx.append(i)
    
    print(f"\n   📊 Component Classification Summary:")
    for comp, info in component_summary.items():
        status = "🚫 EXCLUDE" if int(comp[2:]) in exclude_idx else "✅ KEEP"
        print(f"      {comp}: {info['label']} (conf: {info['confidence']:.2f}) {status}")
    
    # Set components to exclude
    ica.exclude = exclude_idx
    
    print(f"\n   🗑️ Excluding {len(exclude_idx)} artifactual components")
    
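    # Apply ICA to the full EEG set (bad channels are left untouched and kept
    # for interpolation), after matching the average reference used for the fit
    raw_processed.set_eeg_reference('average', projection=False, verbose=False)
    raw_processed = ica.apply(raw_processed, verbose=False)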
    
    print("   ✅ ICA applied successfully")
    
    # Plot ICA component topographies (first 12)
    if len(ica.exclude) > 0:
        fig = ica.plot_components(picks=range(min(12, ica.n_components_)),
                                  title='ICA Components (first 12)', show=False)
        plt.show()
    
except Exception as e:
    print(f"   ⚠️ ICA failed: {e}")
    print("   Proceeding without ICA component removal")

# Store version
raw_versions['After ICA'] = raw_processed.copy()

# Quality assessment
snr_after_ica = compute_snr(raw_processed, method='spectral')
processing_log['After ICA'] = snr_after_ica
print_snr_summary(snr_after_ica, "After ICA")

# Plot comparison
plot_psd_comparison(
    [raw_versions['After EOG Removal'], raw_versions['After ICA']], 
    ['Before ICA', 'After ICA'],
    "PSD: Before vs After ICA"
)

### 5.7 Bad Channel Interpolation {#interpolation}

**Purpose**: Restore the full electrode array by estimating signals at previously identified bad channel locations.

**Why needed**: Many analyses (especially connectivity and source localization) require a complete, uniform electrode montage. Missing channels create gaps in spatial coverage and can bias results toward areas with higher electrode density.

**Method**: Spherical spline interpolation uses signals from neighboring good electrodes to estimate what the signal would have been at bad electrode locations, based on the spatial smoothness of scalp potentials.
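The core of spherical spline interpolation fits a smooth function on the sphere through the good electrodes and evaluates it at the bad one. This from-scratch sketch uses an idealized montage on a unit sphere; the helpers `spline_g` and `interpolate_bad`, the stiffness `m = 4`, and the 50 Legendre terms are assumptions for illustration. MNE's `interpolate_bads` additionally handles real montages, head-sphere fitting, and its own parameter choices.

```python
import numpy as np
from numpy.polynomial import legendre


def spline_g(cos_angle, m=4, n_terms=50):
    """Spherical spline kernel g(x) = 1/(4π) Σ_n (2n+1) / (n(n+1))^m P_n(x)."""
    n = np.arange(1, n_terms + 1)
    coeffs = np.zeros(n_terms + 1)            # coeffs[0] = 0: no n=0 term
    coeffs[1:] = (2 * n + 1) / (n * (n + 1)) ** m / (4 * np.pi)
    return legendre.legval(np.clip(cos_angle, -1.0, 1.0), coeffs)


def interpolate_bad(pos_good, values_good, pos_bad):
    """Estimate values at pos_bad from good electrodes (unit-sphere positions)."""
    n = len(pos_good)
    G = spline_g(pos_good @ pos_good.T)
    # Spline weights C and constant c0: [G 1; 1^T 0] [C; c0] = [V; 0]
    lhs = np.zeros((n + 1, n + 1))
    lhs[:n, :n] = G
    lhs[:n, n] = 1.0
    lhs[n, :n] = 1.0
    rhs = np.append(values_good, 0.0)
    sol = np.linalg.solve(lhs, rhs)
    return spline_g(pos_bad @ pos_good.T) @ sol[:n] + sol[n]


# Idealized montage: 40 electrodes spread over the upper hemisphere
n_el = 40
i = np.arange(n_el)
z = 1 - (i + 0.5) / n_el                      # from near the vertex to the equator
r = np.sqrt(1 - z ** 2)
phi = i * np.pi * (3 - np.sqrt(5))            # golden-angle spacing
pos = np.column_stack([r * np.cos(phi), r * np.sin(phi), z])

field = pos[:, 2] + 0.5 * pos[:, 0]           # smooth synthetic scalp "potential"
bad = 12
good = np.delete(np.arange(n_el), bad)
estimate = interpolate_bad(pos[good], field[good], pos[bad:bad + 1])[0]
print(f"true={field[bad]:.3f}  interpolated={estimate:.3f}")
```

The constraint that the spline weights sum to zero, together with the constant term `c0`, lets the interpolant reproduce a constant offset across the scalp exactly.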

In [None]:
print("🔧 Interpolating bad channels...")

n_bad_channels = len(raw_processed.info['bads'])

if n_bad_channels > 0:
    print(f"   📍 Interpolating {n_bad_channels} bad channels: {raw_processed.info['bads']}")
    
    try:
        # Interpolate bad channels
        raw_processed.interpolate_bads(reset_bads=True, verbose=False)
        print("   ✅ Bad channels interpolated successfully")
        
    except Exception as e:
        print(f"   ⚠️ Channel interpolation failed: {e}")
        print("   This might be due to missing channel locations")
        
        # Drop the bad channels rather than leaving them in the data unmarked
        raw_processed.drop_channels(raw_processed.info['bads'])
else:
    print("   ℹ️ No bad channels to interpolate")

# Store version
raw_versions['After Interpolation'] = raw_processed.copy()

# Quality assessment
snr_after_interp = compute_snr(raw_processed, method='spectral')
processing_log['After Interpolation'] = snr_after_interp
print_snr_summary(snr_after_interp, "After Channel Interpolation")

print(f"\n📋 Channel Status:")
print(f"   Active channels: {len(raw_processed.ch_names)}")
print(f"   Bad channels remaining: {len(raw_processed.info['bads'])}")

### 5.8 Bad Time Segments Removal (Artifact Subspace Reconstruction) {#asr}

**Purpose**: Automatically detect and correct brief periods of extreme artifacts that affect multiple channels simultaneously.

**Why needed**: Even after other cleaning steps, occasional periods of extreme artifacts can remain (sudden movements, cable bumps, amplifier saturation). These brief but severe artifacts can distort statistical analyses and connectivity measures.

**Method**: ASR learns the 'normal' signal patterns from clean calibration data, then identifies and reconstructs time periods where the signal deviates beyond a statistical threshold, effectively removing transient artifacts while preserving normal neural activity.
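The sketch below illustrates the ASR idea in deliberately simplified form; it is not meegkit's `ASR` and not the full published algorithm. It uses non-overlapping windows and flags a window's principal component when its variance exceeds `cutoff**2` times the calibration variance in the same direction (real ASR derives its thresholds from the distribution of windowed RMS values and blends overlapping windows). Flagged components are then reconstructed from the kept ones via the calibration mixing matrix, R = M pinv(diag(keep) V^T M) V^T. The helper name `asr_sketch` and its defaults are hypothetical.

```python
import numpy as np


def asr_sketch(X, X_cal, sfreq, cutoff=5.0, win_sec=0.5):
    """Simplified ASR-style cleaning (hypothetical helper, not meegkit's ASR).

    X, X_cal: (n_channels, n_samples). Trailing samples that do not fill a
    whole window are left unchanged.
    """
    C = np.cov(X_cal)
    evals, evecs = np.linalg.eigh(C)
    M = evecs @ np.diag(np.sqrt(evals)) @ evecs.T     # matrix square root of C
    win = int(win_sec * sfreq)
    cleaned = X.copy()
    for start in range(0, X.shape[1] - win + 1, win):
        seg = X[:, start:start + win]
        d, V = np.linalg.eigh(np.cov(seg))            # window PCA
        calib_var = np.einsum("ik,ij,jk->k", V, C, V)  # v_k^T C v_k per component
        keep = d < cutoff ** 2 * calib_var
        if keep.all():
            continue                                   # window looks clean
        # Reconstruct flagged components from the kept ones using the
        # calibration model: R = M pinv(diag(keep) V^T M) V^T
        R = M @ np.linalg.pinv(keep[:, None] * (V.T @ M)) @ V.T
        cleaned[:, start:start + win] = R @ seg
    return cleaned


rng = np.random.default_rng(7)
sfreq, n_ch = 100, 8
A = np.eye(n_ch) + 0.3 * rng.standard_normal((n_ch, n_ch))  # spatial mixing
X_cal = A @ rng.standard_normal((n_ch, 60 * sfreq))          # clean calibration
X = A @ rng.standard_normal((n_ch, 60 * sfreq))
t = np.arange(100) / sfreq
direction = rng.standard_normal(n_ch)
direction /= np.linalg.norm(direction)
X[:, 1000:1100] += 50 * np.outer(direction, np.sin(2 * np.pi * 10 * t))  # 1 s burst

cleaned = asr_sketch(X, X_cal, sfreq)
before = np.var(X[:, 1000:1100])
after = np.var(cleaned[:, 1000:1100])
print(f"burst variance: {before:.1f} -> {after:.1f}")
```

Windows that look like the calibration data are passed through untouched; only windows with an outlying component are rewritten, which is what makes ASR suitable for brief, high-amplitude events.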

In [None]:
print("⚡ Removing bad time segments using Artifact Subspace Reconstruction (ASR)...")

try:
    from meegkit.asr import ASR
    
    # Get data for ASR: meegkit's ASR works on (n_channels, n_samples) arrays
    data = raw_processed.get_data()
    sfreq = raw_processed.info['sfreq']
    
    # ASR parameters
    asr_cutoff = 5      # Standard deviation cutoff (lower = more aggressive)
    calibration_time = min(60, data.shape[1] / sfreq)  # Use first 60 s or all data
    
    print(f"   🔧 ASR parameters: cutoff={asr_cutoff}, calibration={calibration_time:.1f}s")
    
    # Initialize ASR
    asr = ASR(sfreq=sfreq, cutoff=asr_cutoff)
    
    # Fit ASR on the calibration segment (assumed to be mostly clean)
    calibration_samples = int(calibration_time * sfreq)
    calibration_data = data[:, :calibration_samples]
    asr.fit(calibration_data)
    
    # Apply ASR to the entire dataset
    data_asr = asr.transform(data)
    
    # Fraction of each channel's variance altered by ASR, averaged over channels
    reconstruction_ratio = np.mean(np.var(data - data_asr, axis=1) / np.var(data, axis=1))
    
    print(f"   📊 Variance altered by ASR: {reconstruction_ratio:.1%}")
    
    # Update raw data in place (same channels x time layout)
    raw_processed._data[:] = data_asr
    
    print("   ✅ ASR applied successfully")
    
except Exception as e:
    print(f"   ⚠️ ASR failed: {e}")
    print("   Trying alternative bad segment removal...")
    
    try:
        # Alternative: simple artifact rejection based on amplitude
        data = raw_processed.get_data()

        # Flag samples whose absolute value exceeds threshold
        amplitude_threshold = 5 * np.std(data)          # 5× the global standard deviation
        bad_samples = np.any(np.abs(data) > amplitude_threshold, axis=0)

        if np.any(bad_samples):
            sfreq = raw_processed.info['sfreq']
            bad_idx = np.where(bad_samples)[0]

            # ─── group contiguous bad samples into segments ───
            breaks = np.where(np.diff(bad_idx) > 1)[0] + 1
            segments = np.split(bad_idx, breaks)

            onsets    = [seg[0] / sfreq for seg in segments]
            durations = [len(seg) / sfreq for seg in segments]

            # mark the segments as BAD_amplitude
            bad_annot = mne.Annotations(onsets,
                                        durations,
                                        ['BAD_amplitude'] * len(onsets),
                                        orig_time=raw_processed.info['meas_date'])
            raw_processed.set_annotations(raw_processed.annotations + bad_annot)

            print(f"   📍 Marked {len(onsets)} bad segments "
                  f"({bad_samples.sum()/len(bad_samples):.2%} of samples)")
        else:
            print("   ✅ No samples exceeded the amplitude threshold")

        print("   ✅ Alternative artifact rejection applied")
        
    except Exception as e2:
        print(f"   ⚠️ Alternative method also failed: {e2}")
        print("   Proceeding without bad segment removal")

# Store version
raw_versions['After ASR'] = raw_processed.copy()

# Quality assessment
snr_after_asr = compute_snr(raw_processed, method='spectral')
processing_log['After ASR'] = snr_after_asr
print_snr_summary(snr_after_asr, "After ASR")

# Plot comparison
plot_psd_comparison(
    [raw_versions['After Interpolation'], raw_versions['After ASR']], 
    ['Before ASR', 'After ASR'],
    "PSD: Before vs After ASR"
)

## 6. Final Quality Assessment {#final-assessment}

**Purpose**: Evaluate the overall effectiveness of the preprocessing pipeline and document improvements.

**Why needed**: It's crucial to verify that preprocessing actually improved data quality rather than inadvertently removing important neural signals. Quantitative metrics provide objective evidence of improvement and help optimize preprocessing parameters.

**Methods**: Compare SNR across frequency bands before and after processing, visualize PSD changes, and generate comprehensive quality reports.

Let's examine the overall improvement in data quality:

In [None]:
print("📈 Final Quality Assessment")
print("=" * 60)

# Plot processing summary
plot_processing_summary(processing_log)

# Final comparison: Original vs Cleaned
plot_psd_comparison(
    [raw_versions['Original'], raw_versions['After ASR']], 
    ['Original Data', 'Fully Processed'],
    "Final Comparison: Original vs Cleaned EEG Data"
)

# SNR improvement summary
print("\n🎯 SNR Improvement Summary:")
print("=" * 50)

original_snr = processing_log['Original']
final_snr = processing_log['After ASR']

for band in ['delta', 'theta', 'alpha', 'beta', 'gamma']:
    original_val = original_snr[band]['mean_snr']
    final_val = final_snr[band]['mean_snr']
    improvement = final_val - original_val
    
    print(f"{band.capitalize():>8}: {original_val:6.2f} → {final_val:6.2f} dB (Δ{improvement:+5.2f} dB)")

print("=" * 50)

# Data quality metrics
print("\n📊 Final Data Summary:")
print(f"├── Duration: {raw_processed.times[-1]:.1f} seconds")
print(f"├── Sampling rate: {raw_processed.info['sfreq']:.0f} Hz")
print(f"├── Channels: {len(raw_processed.ch_names)}")
print(f"├── Bad channels interpolated: {n_bad_channels}")
print(f"└── Processing completed successfully! ✅")

## 7. Saving Results {#saving}

**Purpose**: Export cleaned data in multiple formats and generate comprehensive documentation of the preprocessing workflow.

**Why needed**: Different analysis software requires different file formats. Documentation ensures reproducibility and helps track what preprocessing steps were applied. Quality metrics provide evidence of data improvement for publications.

**Methods**: Save in common EEG formats (BrainVision, EEGLAB, EDF), generate HTML reports with MNE, and export quantitative quality metrics as CSV files.
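For the CSV export, the nested `processing_log` structure used in the final assessment (step → band → `'mean_snr'`) can be flattened into one row per step and band. The helper `write_snr_csv` below is a hypothetical sketch using only the standard library, and the values in `example_log` are toy numbers for demonstration, not results from this dataset.

```python
import csv
import tempfile
from pathlib import Path


def write_snr_csv(processing_log, path,
                  bands=("delta", "theta", "alpha", "beta", "gamma")):
    """Flatten {step: {band: {'mean_snr': value}}} into CSV rows (hypothetical helper)."""
    path = Path(path)
    with path.open("w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["step", "band", "mean_snr_db"])
        for step, metrics in processing_log.items():
            for band in bands:
                if band in metrics:  # skip bands missing from this step
                    writer.writerow([step, band, f"{metrics[band]['mean_snr']:.3f}"])
    return path


# Toy values only, to show the layout
example_log = {
    "Original": {"alpha": {"mean_snr": 3.2}, "beta": {"mean_snr": 1.1}},
    "After ASR": {"alpha": {"mean_snr": 6.5}, "beta": {"mean_snr": 2.4}},
}
out = write_snr_csv(example_log, Path(tempfile.mkdtemp()) / "snr_summary.csv")
print(out.read_text())
```

A long (step, band, value) layout like this is easy to load into pandas or R and to plot across processing steps.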

Save the cleaned data and generate a processing report:

In [None]:
print("💾 Saving processed data and reports...")

# Create output directory
output_dir = Path('cleaned_eeg_output')
output_dir.mkdir(exist_ok=True)

# Generate timestamp for filenames
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Save cleaned data in multiple formats
formats_to_save = {
    'BrainVision': '.vhdr',
    'EEGLAB': '.set', 
    'EDF': '.edf'
}

saved_files = []

for format_name, extension in formats_to_save.items():
    try:
        output_file = output_dir / f'cleaned_eeg_{timestamp}{extension}'
        
        # Raw.save() only writes MNE's FIF format, so all three formats go
        # through mne.export (each needs its optional backend, e.g. pybv for
        # BrainVision and eeglabio for EEGLAB)
        mne.export.export_raw(output_file, raw_processed, fmt='auto', overwrite=True, verbose=False)
        
        saved_files.append((format_name, output_file))
        print(f"   ✅ Saved {format_name} format: {output_file.name}")
        
    except Exception as e:
        print(f"   ⚠️ Failed to save {format_name} format: {e}")

# Save processing report
report_file = output_dir / f'processing_report_{timestamp}.html'

try:
    # Create an MNE Report
    report = mne.Report(title='CleanEEG Processing Report')
    
    # Add original vs cleaned PSD comparison.
    # Spectrum.plot() creates and returns its own figure, so capture that
    # figure instead of adding an empty plt.figure() to the report.
    fig_original = raw_versions['Original'].compute_psd(fmax=50).plot(show=False)
    fig_original.suptitle('Original Data PSD')
    report.add_figure(fig_original, title='Original Data PSD')
    
    fig_final = raw_versions['After ASR'].compute_psd(fmax=50).plot(show=False)
    fig_final.suptitle('Cleaned Data PSD')
    report.add_figure(fig_final, title='Cleaned Data PSD')
    
    # Add processing summary. add_html() expects HTML, so Markdown syntax
    # such as '#' or '**' would be shown literally; use HTML tags instead.
    processing_summary = f"""
    <h2>CleanEEG Processing Summary</h2>
    <p><b>Processing Date:</b> {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>

    <p><b>Input Data:</b></p>
    <ul>
      <li>Duration: {raw.times[-1]:.1f} seconds</li>
      <li>Original sampling rate: {raw.info['sfreq']:.0f} Hz</li>
      <li>Channels: {len(raw.ch_names)}</li>
    </ul>

    <p><b>Processing Steps Applied:</b></p>
    <ol>
      <li>Line noise removal (DSS at {line_freq} Hz)</li>
      <li>Bandpass filter (1 Hz - 100 Hz)</li>
      <li>Downsampling (to {raw_processed.info['sfreq']:.0f} Hz)</li>
      <li>Bad channel detection ({n_bad_channels} channels found)</li>
      <li>EOG artifact removal</li>
      <li>Independent Component Analysis (ICA)</li>
      <li>Bad channel interpolation</li>
      <li>Artifact Subspace Reconstruction (ASR)</li>
    </ol>

    <p><b>Output Data:</b></p>
    <ul>
      <li>Final duration: {raw_processed.times[-1]:.1f} seconds</li>
      <li>Final sampling rate: {raw_processed.info['sfreq']:.0f} Hz</li>
      <li>Final channels: {len(raw_processed.ch_names)}</li>
    </ul>
    """
    
    report.add_html(processing_summary, title='Processing Summary')
    
    # Save report
    report.save(report_file, overwrite=True, open_browser=False)
    print(f"   📋 Processing report saved: {report_file.name}")
    
except Exception as e:
    print(f"   ⚠️ Failed to create HTML report: {e}")
    
    # Create a simple text report instead
    text_report_file = output_dir / f'processing_summary_{timestamp}.txt'
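    # Explicit UTF-8: the '→' and 'Δ' characters below fail with some default encodings
    with open(text_report_file, 'w', encoding='utf-8') as f: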
        f.write(f"CleanEEG Processing Summary\n")
        f.write(f"Generated: {datetime.now()}\n\n")
        
        f.write(f"Input Data:\n")
        f.write(f"- Duration: {raw.times[-1]:.1f} seconds\n")
        f.write(f"- Sampling rate: {raw.info['sfreq']:.0f} Hz\n")
        f.write(f"- Channels: {len(raw.ch_names)}\n\n")
        
        f.write(f"Output Data:\n")
        f.write(f"- Duration: {raw_processed.times[-1]:.1f} seconds\n")
        f.write(f"- Sampling rate: {raw_processed.info['sfreq']:.0f} Hz\n")
        f.write(f"- Channels: {len(raw_processed.ch_names)}\n\n")
        
        f.write(f"SNR Improvements:\n")
        for band in ['delta', 'theta', 'alpha', 'beta', 'gamma']:
            original_val = processing_log['Original'][band]['mean_snr']
            final_val = processing_log['After ASR'][band]['mean_snr']
            improvement = final_val - original_val
            f.write(f"- {band.capitalize()}: {original_val:.2f} → {final_val:.2f} dB (Δ{improvement:+.2f} dB)\n")
    
    print(f"   📄 Text summary saved: {text_report_file.name}")

# Save SNR data as CSV
snr_csv_file = output_dir / f'snr_data_{timestamp}.csv'
snr_df_list = []

for step_name, snr_data in processing_log.items():
    for band, band_data in snr_data.items():
        snr_df_list.append({
            'processing_step': step_name,
            'frequency_band': band,
            'mean_snr_db': band_data['mean_snr'],
            'std_snr_db': band_data['std_snr']
        })

snr_df = pd.DataFrame(snr_df_list)
snr_df.to_csv(snr_csv_file, index=False)
print(f"   📊 SNR data saved: {snr_csv_file.name}")

print(f"\n🎉 Processing completed successfully!")
print(f"📁 All files saved to: {output_dir.absolute()}")
print(f"\n📋 Saved files:")
for format_name, file_path in saved_files:
    print(f"   - {format_name}: {file_path.name}")
print(f"   - SNR Data: {snr_csv_file.name}")
# report_file is always defined, so check whether the HTML report was actually written
print(f"   - Report: {report_file.name if report_file.exists() else text_report_file.name}")
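The SNR CSV written above is in long format, with one row per processing step and frequency band. Below is a minimal sketch for reading it back into a step → band → mean-SNR lookup using only the standard library. With pandas, `pd.read_csv(snr_csv_file).pivot(index='processing_step', columns='frequency_band', values='mean_snr_db')` gives an equivalent wide table. The in-memory CSV uses the same column names, but its values are placeholders.

```python
import csv
import io
from typing import Dict, TextIO


def load_snr_table(fh: TextIO) -> Dict[str, Dict[str, float]]:
    """Read the long-format SNR CSV into {processing_step: {band: mean_snr_db}}."""
    table: Dict[str, Dict[str, float]] = {}
    for row in csv.DictReader(fh):
        step = row['processing_step']
        table.setdefault(step, {})[row['frequency_band']] = float(row['mean_snr_db'])
    return table


# Example with an in-memory CSV using the same columns (placeholder values)
example_csv = io.StringIO(
    "processing_step,frequency_band,mean_snr_db,std_snr_db\n"
    "Original,alpha,2.0,1.0\n"
    "After ASR,alpha,6.0,0.7\n"
)
snr_table = load_snr_table(example_csv)
print(snr_table['After ASR']['alpha'] - snr_table['Original']['alpha'])
```

For the real file, open it with `open(snr_csv_file, newline='', encoding='utf-8')` and pass the handle to `load_snr_table`.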

## Conclusion

🎉 **Congratulations!** You have successfully completed the CleanEEG preprocessing pipeline.

### What we accomplished:

1. **Loaded and inspected** your EEG data
2. **Applied comprehensive preprocessing** following the DISCOVER-EEG framework:
   - Line noise removal using Denoising Source Separation
   - Bandpass filtering to remove slow drifts and high-frequency noise
   - Downsampling for computational efficiency
   - Automatic bad channel detection using PREP pipeline
   - EOG artifact removal (if channels available)
   - Independent Component Analysis with automatic classification
   - Bad channel interpolation
   - Bad time segment removal using Artifact Subspace Reconstruction

3. **Monitored data quality** throughout the pipeline using SNR metrics
4. **Saved cleaned data** in multiple formats for further analysis
5. **Generated comprehensive reports** documenting the preprocessing steps

### Next steps:

Your cleaned EEG data is now ready for:
- **Spectral analysis** (power spectral density, frequency band analysis)
- **Connectivity analysis** (coherence, phase-amplitude coupling)
- **Event-related potential analysis** (if you have event markers)
- **Machine learning applications** (classification, regression)
- **Source localization** (if you have a forward model)
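As a starting point for spectral analysis, the sketch below computes absolute and relative band power for a single channel with a plain NumPy periodogram. It is a simplified stand-in: it uses a Hann window and does not average over segments. For real data, MNE's `raw.compute_psd()` (Welch's method by default) is the more robust choice. The band edges are common conventions chosen as an assumption. They are not the limits used by this pipeline's SNR computation.

```python
import numpy as np

# Conventional band edges in Hz (an assumption; adjust to your analysis)
BANDS = {'delta': (1, 4), 'theta': (4, 8), 'alpha': (8, 13),
         'beta': (13, 30), 'gamma': (30, 45)}


def band_powers(signal: np.ndarray, sfreq: float) -> dict:
    """Absolute and relative band power from a Hann-windowed periodogram."""
    n = len(signal)
    window = np.hanning(n)
    spectrum = np.fft.rfft((signal - signal.mean()) * window)
    freqs = np.fft.rfftfreq(n, d=1 / sfreq)
    # Window-compensated power spectral density
    psd = np.abs(spectrum) ** 2 / (sfreq * np.sum(window ** 2))
    psd[1:] *= 2  # fold negative frequencies into a one-sided spectrum
    if n % 2 == 0:
        psd[-1] /= 2  # the Nyquist bin has no negative-frequency twin
    df = freqs[1] - freqs[0]
    absolute = {name: psd[(freqs >= lo) & (freqs < hi)].sum() * df
                for name, (lo, hi) in BANDS.items()}
    total = sum(absolute.values())
    relative = {name: p / total for name, p in absolute.items()}
    return {'absolute': absolute, 'relative': relative}


# Synthetic check: a 10 Hz oscillation plus weak noise should be alpha-dominated
rng = np.random.default_rng(0)
sfreq = 250.0
t = np.arange(0, 20, 1 / sfreq)
x = np.sin(2 * np.pi * 10 * t) + 0.1 * rng.standard_normal(t.size)
result = band_powers(x, sfreq)
print(max(result['relative'], key=result['relative'].get))
```

Applied to the cleaned data, you could call `band_powers(raw_processed.get_data()[i], raw_processed.info['sfreq'])` for channel index `i`.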

### Tips for further analysis:

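- Bad channels are interpolated rather than dropped, so the channel layout is preserved. The sampling rate is lower after downsampling, however, and the duration may change if ASR removed segments (compare the input and output summaries in the report)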
- All preprocessing steps are documented in the generated reports
- Per-step SNR changes (saved in the SNR CSV) give a rough indication of each step's effect; treat them as a proxy for data quality, not as proof of cleaner data
- Consider the specific requirements of your analysis when choosing output formats

**Happy analyzing!** 🧠✨