# EEG Preprocessing Pipeline: Bad Channel Detection & ICA


## Overview
This notebook implements the third phase of the EEG preprocessing pipeline, which has two main components:
1. **Bad Channel Detection** - Automated identification of problematic channels
2. **ICA Processing** - Independent Component Analysis for artifact removal
 
### Pipeline Architecture
```
Filtered EEG → Bad Channel Detection → ICA → Cleaned Data
                ↘ Channel Reports       ↘ ICA Reports
```

# PHASE 3: BAD CHANNEL DETECTION & ICA

**Purpose**
- Automatically detect and handle bad EEG channels
- Remove artifacts using Independent Component Analysis
- Generate comprehensive quality reports for cleaned data

## Step 1: Environment Setup & Import Dependencies

In [1]:
# %% Cell 1: Import all required libraries
import pandas as pd
import numpy as np
import mne
import os
import json
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import gc
import warnings
from mne.preprocessing import ICA, corrmap, find_bad_channels_maxwell
warnings.filterwarnings('ignore')

# Set professional plotting style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

print("✅ Bad Channel & ICA environment setup complete")

✅ Bad Channel & ICA environment setup complete


## Step 2: Load Previous Pipeline Output

In [2]:
# %% Cell 2: Load filtered data and previous pipeline results
def load_filtered_data_inventory(output_path):
    """
    Load inventory of filtered EEG files from previous pipeline phase.
    
    Args:
        output_path (Path): Main output directory path
        
    Returns:
        pd.DataFrame: Inventory of filtered files with metadata
    """
    print("üìÅ LOADING FILTERED DATA INVENTORY")
    print("=" * 50)
    
    filtered_path = output_path / 'preprocessed_data' / 'raw_cleaned'
    filtered_files = list(filtered_path.glob("*_filtered.fif"))
    
    if not filtered_files:
        raise FileNotFoundError("No filtered files found. Run filtering pipeline first.")
    
    # Create inventory from filtered files
    filtered_inventory = []
    for file_path in filtered_files:
        filename = file_path.stem  # Remove .fif extension
        # Extract metadata from filename pattern
        parts = filename.replace('_filtered', '').split('_')
        
        # Handle different filename patterns
        if 'sub-' in filename and 'ses-' in filename:
            subject_id = next((p for p in parts if p.startswith('sub-')), 'unknown')
            session_id = next((p for p in parts if p.startswith('ses-')), 'unknown')
            task_type = next((p for p in parts if p.startswith('task-')), 'eeg')
        else:
            # Fallback for custom naming
            subject_id = parts[0] if len(parts) > 0 else 'unknown'
            session_id = parts[1] if len(parts) > 1 else 'unknown'
            task_type = parts[2] if len(parts) > 2 else 'eeg'
        
        filtered_inventory.append({
            'subject_id': subject_id,
            'session_id': session_id,
            'task_type': task_type,
            'filename': file_path.name,
            'file_path': str(file_path),
            'original_filename': filename.replace('_filtered', '')
        })
    
    filtered_df = pd.DataFrame(filtered_inventory)
    
    print(f"üìä Filtered Data Overview:")
    print(f"   ‚Ä¢ Total filtered files: {len(filtered_df)}")
    print(f"   ‚Ä¢ Unique subjects: {filtered_df['subject_id'].nunique()}")
    print(f"   ‚Ä¢ Unique sessions: {filtered_df['session_id'].nunique()}")
    
    print("\nüìã Sample filtered files:")
    print(filtered_df[['subject_id', 'session_id', 'filename']].head(3))
    
    return filtered_df

# Load filtered data inventory
try:
    filtered_inventory_df = load_filtered_data_inventory(Path('EEG_Preprocessing_Output'))
except Exception as e:
    print(f"❌ Error loading filtered data: {e}")
    print("Please run the filtering pipeline first.")
    filtered_inventory_df = pd.DataFrame()

📁 LOADING FILTERED DATA INVENTORY
📊 Filtered Data Overview:
   • Total filtered files: 419
   • Unique subjects: 27
   • Unique sessions: 5

📋 Sample filtered files:
  subject_id session_id                              filename
0     sub-01     follow  sub-01_follow_run-1_eeg_filtered.fif
1     sub-01     follow  sub-01_follow_run-2_eeg_filtered.fif
2     sub-01     follow  sub-01_follow_run-3_eeg_filtered.fif
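
The sample names above (e.g. `sub-01_follow_run-1_eeg_filtered.fif`) contain no `ses-` entity, so the fallback branch of `load_filtered_data_inventory` assigns the positional tokens `follow` to `session_id` and `run-1` to `task_type`. If that mapping is not intended, a key-value parser makes the split explicit. A minimal sketch, assuming BIDS-like `key-value` tokens separated by underscores; `parse_filename_entities` is a hypothetical helper, not part of this pipeline or of MNE:

```python
from pathlib import Path


def parse_filename_entities(filename):
    """Split a BIDS-like filename into key-value entities and bare tokens.

    Hypothetical helper: 'key-value' tokens go into `entities`; all other
    tokens (e.g. 'follow', 'eeg', 'filtered') are kept in order in `other`.
    """
    # Drop the directory and every extension (.fif, .fif.gz, ...)
    stem = Path(filename).name.split('.', 1)[0]
    entities, other = {}, []
    for token in stem.split('_'):
        key, sep, value = token.partition('-')
        if sep and key and value:
            entities[key] = value
        else:
            other.append(token)
    return entities, other


entities, other = parse_filename_entities("sub-01_follow_run-1_eeg_filtered.fif")
print(entities)  # {'sub': '01', 'run': '1'}
print(other)     # ['follow', 'eeg', 'filtered']
```

Whether `follow` denotes a session or a task condition is a dataset convention the code cannot infer; the sketch only keeps such tokens separate instead of assigning them by position.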


## Step 3: Bad Channel Detection Pipeline

In [3]:
# %% Cell 3: Comprehensive bad channel detection

def calculate_channel_correlation(data, channel_idx):
    """
    Mean Pearson correlation between one channel and its neighbors.

    Neighbors are taken from the channel list order (see
    get_nearby_channel_indices), which approximates spatial proximity only
    when the channel order follows the scalp layout.
    """
    channel_data = data[channel_idx]
    
    # A bad channel typically decorrelates from the channels around it,
    # so compare against a small neighborhood rather than all channels
    nearby_indices = get_nearby_channel_indices(channel_idx, data.shape[0])
    
    correlations = []
    for i in nearby_indices:
        if i == channel_idx:
            continue
            
        try:
            corr = np.corrcoef(channel_data, data[i])[0, 1]
            if not np.isnan(corr):  # NaN for flat (zero-variance) channels
                correlations.append(corr)
        except Exception:
            continue
    
    return np.mean(correlations) if correlations else 0.0

def get_nearby_channel_indices(channel_idx, total_channels, max_neighbors=8):
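    """
    Get indices of nearby channels for correlation calculation.
    For simplicity, "nearby" means adjacent in the channel list (a window of
    up to max_neighbors indices around the target), not spatially adjacent.
    """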
    # Simple approach: take channels within a window around the target
    start = max(0, channel_idx - max_neighbors // 2)
    end = min(total_channels, channel_idx + max_neighbors // 2 + 1)
    
    nearby_indices = list(range(start, end))
    
    # Remove the target channel itself
    if channel_idx in nearby_indices:
        nearby_indices.remove(channel_idx)
    
    # Ensure we have some neighbors
    if len(nearby_indices) < 2:
        # Fallback: use all other channels
        nearby_indices = [i for i in range(total_channels) if i != channel_idx]
        nearby_indices = nearby_indices[:max_neighbors]  # Limit to reasonable number
    
    return nearby_indices

def calculate_hurst_exponent(time_series):
    """Estimate the Hurst exponent from how the std of lagged differences scales with lag."""
    try:
        lags = range(2, min(20, len(time_series)//10))
        tau = [np.std(np.subtract(time_series[lag:], time_series[:-lag])) for lag in lags]
        poly = np.polyfit(np.log(lags), np.log(tau), 1)
        return poly[0]
    except Exception:
        # Too few samples or degenerate data: return 0.5 (random-walk value)
        return 0.5

def calculate_line_noise_ratio(data, sfreq, line_freq=50):
    """Mean PSD within ±2 Hz of line_freq divided by mean PSD 5 to 10 Hz away on either side."""
    try:
        from scipy import signal
        freqs, psd = signal.welch(data, sfreq, nperseg=min(1024, len(data)))
        noise_idx = (freqs >= line_freq - 2) & (freqs <= line_freq + 2)
        baseline_idx = (((freqs >= line_freq - 10) & (freqs <= line_freq - 5)) |
                        ((freqs >= line_freq + 5) & (freqs <= line_freq + 10)))
        
        line_power = np.mean(psd[noise_idx])
        baseline_power = np.mean(psd[baseline_idx])
        
        return line_power / baseline_power if baseline_power > 0 else 1
    except Exception:
        # Neutral ratio if the PSD cannot be computed
        return 1

def calculate_adaptive_variance_threshold(channel_metrics):
    """
    Calculate adaptive variance threshold based on data distribution.
    """
    variances = [metrics['variance'] for metrics in channel_metrics.values()]
    median_var = np.median(variances)
    
    # Conservative thresholds for clean EEG
    high_threshold = median_var * 20  # Very conservative
    low_threshold = median_var * 0.05  # Very conservative
    
    return {'high': high_threshold, 'low': low_threshold}

def detect_bad_channels_statistical(raw, channel_metrics):
    """Detect bad channels using statistical thresholds."""
    bad_channels = []
    
    # Calculate adaptive thresholds
    variances = [metrics['variance'] for metrics in channel_metrics.values()]
    median_var = np.median(variances)
    var_thresholds = calculate_adaptive_variance_threshold(channel_metrics)
    
    # Calculate correlation statistics
    correlations = [metrics['correlation_with_others'] for metrics in channel_metrics.values()]
    median_corr = np.median(correlations)
    
    print(f"   üìä Variance - Median: {median_var:.1f} ¬µV¬≤, Thresholds: {var_thresholds['low']:.1f}-{var_thresholds['high']:.1f}")
    print(f"   üìä Correlation - Median: {median_corr:.3f}, Range: {np.min(correlations):.3f}-{np.max(correlations):.3f}")
    
    for ch_name, metrics in channel_metrics.items():
        reasons = []
        
        # Check variance (too high or too low) - VERY CONSERVATIVE
        if metrics['variance'] > var_thresholds['high']:
            reasons.append(f"high_var({metrics['variance']:.1f})")
        elif metrics['variance'] < var_thresholds['low']:
            reasons.append(f"low_var({metrics['variance']:.1f})")
            
        # Check correlation: half the median, floored at 0.3
        # (the 0.3 floor applies whenever the median correlation is below 0.6)
        corr_threshold = max(0.3, median_corr * 0.5)
        if metrics['correlation_with_others'] < corr_threshold:
            reasons.append(f"low_corr({metrics['correlation_with_others']:.3f})")
            
        # Check amplitude - CONSERVATIVE
        if metrics['max_amplitude'] > 150:  # Conservative for clean EEG
            reasons.append(f"high_amp({metrics['max_amplitude']:.1f})")
        
        # Only mark as bad if we have clear reasons
        if reasons:
            bad_channels.append(ch_name)
            print(f"      üö® {ch_name}: {', '.join(reasons)}")
    
    return bad_channels

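def detect_bad_channels_comprehensive(raw, method='auto'):
    """
    Detect bad channels from per-channel variance, neighbor correlation and peak amplitude.

    `method` is only recorded in the results; the thresholds in
    detect_bad_channels_statistical are applied in every case.
    """
    print(f"🔍 Detecting bad channels using {method} method...")
    
    # Get basic channel information
    ch_names = raw.ch_names
    data = raw.get_data() * 1e6  # Convert V to µV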
    
    bad_channels_results = {
        'method': method,
        'total_channels': len(ch_names),
        'channels_checked': ch_names,
        'bad_channels_identified': [],
        'detection_metrics': {},
        'channel_quality_scores': {}
    }
    
    # Calculate channel quality metrics
    print("   üìà Calculating channel metrics...")
    channel_metrics = {}
    for i, ch_name in enumerate(ch_names):
        ch_data = data[i]
        
        metrics = {
            'variance': np.var(ch_data),
            'mean_amplitude': np.mean(np.abs(ch_data)),
            'max_amplitude': np.max(np.abs(ch_data)),
            'hurst_exponent': calculate_hurst_exponent(ch_data),
            'correlation_with_others': calculate_channel_correlation(data, i),
            'line_noise_ratio': calculate_line_noise_ratio(ch_data, raw.info['sfreq'])
        }
        channel_metrics[ch_name] = metrics
    
    bad_channels_results['channel_quality_scores'] = channel_metrics
    
    # Only the statistical (threshold-based) detector is implemented
    bad_channels = detect_bad_channels_statistical(raw, channel_metrics)
    bad_channels_results['bad_channels_identified'] = bad_channels
    bad_channels_results['detection_metrics']['method'] = 'statistical'
    
    print(f"   ✅ Identified {len(bad_channels_results['bad_channels_identified'])} bad channels")
    
    return bad_channels_results

# Sanity-check calculate_channel_correlation on synthetic data
if __name__ == "__main__":
    print("🧪 TESTING PROPERLY FIXED CORRELATION FUNCTION")
    print("=" * 50)
    
    # Synthetic data: channels 1-3 share channel 0's signal with decreasing weight; channel 4 is independent noise
    np.random.seed(42)
    n_samples = 1000
    base_signal = np.random.randn(n_samples)
    
    test_data = np.zeros((5, n_samples))
    test_data[0] = base_signal  # Channel 0
    test_data[1] = base_signal * 0.9 + np.random.randn(n_samples) * 0.1  # Highly correlated
    test_data[2] = base_signal * 0.8 + np.random.randn(n_samples) * 0.2  # Correlated  
    test_data[3] = base_signal * 0.3 + np.random.randn(n_samples) * 0.7  # Weakly correlated
    test_data[4] = np.random.randn(n_samples)  # Uncorrelated noise
    
    print("Testing Channel 0 (should have high correlation):")
    test_corr = calculate_channel_correlation(test_data, 0)
    print(f"   Calculated correlation: {test_corr:.3f}")
    
    print("Testing Channel 4 (should have low correlation):")
    test_corr_noise = calculate_channel_correlation(test_data, 4)
    print(f"   Calculated correlation: {test_corr_noise:.3f}")
    
    # Verify with manual calculation
    manual_corr = np.corrcoef(test_data[0], test_data[1])[0, 1]
    print(f"   Manual correlation check (0 vs 1): {manual_corr:.3f}")
    
    if test_corr > 0.5 and test_corr_noise < 0.5:
        print("🎉 ✅ CORRELATION FUNCTION PROPERLY FIXED!")
    else:
        print("❌ Function still needs adjustment")

🧪 TESTING PROPERLY FIXED CORRELATION FUNCTION
Testing Channel 0 (should have high correlation):
   Calculated correlation: 0.575
Testing Channel 4 (should have low correlation):
   Calculated correlation: -0.020
   Manual correlation check (0 vs 1): 0.994
🎉 ✅ CORRELATION FUNCTION PROPERLY FIXED!

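`calculate_hurst_exponent` fits the slope of log(std of lagged differences) against log(lag). That slope equals the Hurst exponent for random-walk-like (fractional Brownian motion) signals, but it tends toward 0 for stationary signals such as high-pass-filtered EEG, so the values are mainly useful for comparing channels that were processed identically. A quick check on synthetic signals (not EEG), reusing the same lag range of 2 to 19 samples; `hurst_lagged_diff` is a standalone copy written for illustration:

```python
import numpy as np


def hurst_lagged_diff(x, max_lag=20):
    """Slope of log(std of x[t+lag] - x[t]) vs log(lag), as in Step 3."""
    lags = np.arange(2, min(max_lag, len(x) // 10))
    tau = [np.std(x[lag:] - x[:-lag]) for lag in lags]
    slope, _intercept = np.polyfit(np.log(lags), np.log(tau), 1)
    return slope


rng = np.random.default_rng(0)
white = rng.standard_normal(10_000)   # stationary noise
random_walk = np.cumsum(white)        # integrated noise (Brownian motion)

print(f"white noise: {hurst_lagged_diff(white):.2f}")        # close to 0
print(f"random walk: {hurst_lagged_diff(random_walk):.2f}")  # close to 0.5
```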

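`calculate_line_noise_ratio` compares the mean power within ±2 Hz of the mains frequency with the mean power 5 to 10 Hz away on either side, so a channel without mains interference scores about 1. The ratio is stored in `channel_quality_scores` but is not one of the criteria checked by `detect_bad_channels_statistical`. A sketch reproducing those band definitions on synthetic data; the 250 Hz sampling rate is an assumption chosen to roughly match the sample file (50019 samples over about 200 s), and `line_noise_ratio` is an illustrative standalone version:

```python
import numpy as np
from scipy import signal


def line_noise_ratio(x, sfreq, line_freq=50.0):
    """Mean PSD within ±2 Hz of line_freq over mean PSD 5 to 10 Hz away."""
    freqs, psd = signal.welch(x, sfreq, nperseg=min(1024, len(x)))
    distance = np.abs(freqs - line_freq)
    in_band = distance <= 2
    baseline = (distance >= 5) & (distance <= 10)
    return psd[in_band].mean() / psd[baseline].mean()


sfreq = 250.0
t = np.arange(0, 60, 1 / sfreq)            # 60 s of synthetic data
rng = np.random.default_rng(1)
noise = rng.standard_normal(t.size)
hum = np.sin(2 * np.pi * 50.0 * t)         # 50 Hz mains interference

print(f"noise only:    {line_noise_ratio(noise, sfreq):.2f}")        # about 1
print(f"noise + 50 Hz: {line_noise_ratio(noise + hum, sfreq):.2f}")  # far above 1
```

With 60 Hz mains, `line_freq=60` applies; the baseline band then extends to 70 Hz, which requires a sampling rate above 140 Hz.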
## Step 4: Apply Bad Channel Detection to Sample File

In [4]:
# %% Cell 4: Test bad channel detection on a sample file

def verify_correlation_function(raw):
    """Check calculate_channel_correlation against a direct np.corrcoef average for channel 0."""
    print("🔍 VERIFYING CORRELATION FUNCTION")
    data = raw.get_data() * 1e6
    
    # Test our function vs manual calculation
    test_channel = 0  # AF3
    our_correlation = calculate_channel_correlation(data, test_channel)
    
    # Manual calculation for comparison
    manual_correlations = []
    for i in range(1, min(5, data.shape[0])):  # Compare with first 4 other channels
        manual_corr = np.corrcoef(data[test_channel], data[i])[0, 1]
        manual_correlations.append(manual_corr)
    manual_avg = np.mean(manual_correlations)
    
    print(f"   • Our function result: {our_correlation:.3f}")
    print(f"   • Manual average: {manual_avg:.3f}")
    print(f"   • Match: {'✅' if abs(our_correlation - manual_avg) < 0.2 else '❌'}")
    
    return our_correlation

def diagnose_correlation_issue(raw):
    """
    Diagnose channel correlations.
    """
    print("\nüî¨ DATA DIAGNOSIS")
    print("=" * 40)
    
    data = raw.get_data() * 1e6
    
    # Check basic statistics
    print("üìä Data Statistics:")
    print(f"   ‚Ä¢ Shape: {data.shape}")
    print(f"   ‚Ä¢ Global mean: {np.mean(data):.2f} ¬µV")
    print(f"   ‚Ä¢ Global std: {np.std(data):.2f} ¬µV")
    print(f"   ‚Ä¢ Data range: {np.min(data):.1f} to {np.max(data):.1f} ¬µV")
    
    # Check if data is already average referenced
    channel_means = np.mean(data, axis=1)
    print(f"   ‚Ä¢ Channel means range: {np.min(channel_means):.2f} to {np.max(channel_means):.2f} ¬µV")
    
    # Quick correlation check
    print(f"\nüîç Quick Correlation Check:")
    for i in range(min(2, data.shape[0])):
        our_corr = calculate_channel_correlation(data, i)
        manual_corr = np.corrcoef(data[i], data[(i+1)%data.shape[0]])[0, 1]
        print(f"   ‚Ä¢ {raw.ch_names[i]}: our={our_corr:.3f}, manual={manual_corr:.3f}")

def test_bad_channel_detection(filtered_inventory_df):
    """
    Test bad channel detection on a sample filtered file.
    """
    print("üß™ TESTING BAD CHANNEL DETECTION WITH FIXED FUNCTION")
    print("=" * 60)
    
    if len(filtered_inventory_df) == 0:
        print("‚ùå No filtered files available for testing")
        return None
    
    # Use first available file
    sample_file = filtered_inventory_df.iloc[0]
    print(f"üìÅ Testing with: {sample_file['filename']}")
    
    try:
        # Load filtered data
        print("üîÑ Loading filtered data...")
        raw_filtered = mne.io.read_raw_fif(sample_file['file_path'], preload=True, verbose=False)
        print(f"‚úÖ Loaded: {len(raw_filtered.ch_names)} channels, {raw_filtered.times[-1]:.1f}s")
        
        # VERIFY OUR FUNCTION FIRST
        verify_correlation_function(raw_filtered)
        
        # RUN DIAGNOSIS
        diagnose_correlation_issue(raw_filtered)
        
        # Detect bad channels
        print("\nüîç Running bad channel detection...")
        bad_channels_results = detect_bad_channels_comprehensive(raw_filtered, method='auto')
        
        # Display results
        print(f"\nüìä BAD CHANNEL DETECTION RESULTS:")
        print(f"   ‚Ä¢ Method: {bad_channels_results['method']}")
        print(f"   ‚Ä¢ Total channels: {bad_channels_results['total_channels']}")
        print(f"   ‚Ä¢ Bad channels identified: {len(bad_channels_results['bad_channels_identified'])}")
        
        if bad_channels_results['bad_channels_identified']:
            print(f"   ‚Ä¢ Bad channels: {bad_channels_results['bad_channels_identified']}")
        else:
            print("   ‚Ä¢ No bad channels detected - data looks clean! ‚úÖ")
        
        # Show correlation statistics
        correlations = [metrics['correlation_with_others'] for metrics in bad_channels_results['channel_quality_scores'].values()]
        print(f"\nüìà CORRELATION STATISTICS:")
        print(f"   ‚Ä¢ Range: {np.min(correlations):.3f} - {np.max(correlations):.3f}")
        print(f"   ‚Ä¢ Median: {np.median(correlations):.3f}")
        print(f"   ‚Ä¢ Mean: {np.mean(correlations):.3f}")
        
        # Show sample channel metrics
        print(f"\nüìä SAMPLE CHANNEL METRICS (first 3):")
        for i, (ch_name, metrics) in enumerate(list(bad_channels_results['channel_quality_scores'].items())[:3]):
            print(f"   {ch_name}:")
            print(f"      ‚Ä¢ Correlation: {metrics['correlation_with_others']:.3f}")
            print(f"      ‚Ä¢ Variance: {metrics['variance']:.1f} ¬µV¬≤")
            print(f"      ‚Ä¢ Max amplitude: {metrics['max_amplitude']:.1f} ¬µV")
        
        return raw_filtered, bad_channels_results
        
    except Exception as e:
        print(f"❌ Error in bad channel detection: {e}")
        import traceback
        traceback.print_exc()
        return None, None

# RUN THE TEST
print("üöÄ EXECUTING BAD CHANNEL DETECTION TEST...")
sample_raw, bad_ch_results = test_bad_channel_detection(filtered_inventory_df)

if sample_raw is not None and bad_ch_results is not None:
    print(f"\nüéâ TEST SUCCESSFUL! Ready for visualization.")
    print(f"   ‚Ä¢ Identified {len(bad_ch_results['bad_channels_identified'])} bad channels")
else:
    print(f"\nüí• TEST FAILED! Check error above.")

🚀 EXECUTING BAD CHANNEL DETECTION TEST...
🧪 TESTING BAD CHANNEL DETECTION WITH FIXED FUNCTION
📁 Testing with: sub-01_follow_run-1_eeg_filtered.fif
🔄 Loading filtered data...
✅ Loaded: 40 channels, 200.1s
🔍 VERIFYING CORRELATION FUNCTION
   • Our function result: 0.775
   • Manual average: 0.775
   • Match: ✅

🔬 DATA DIAGNOSIS
📊 Data Statistics:
   • Shape: (40, 50019)
   • Global mean: -0.00 µV
   • Global std: 2.99 µV
   • Data range: -30.0 to 31.6 µV
   • Channel means range: -0.01 to 0.01 µV

🔍 Quick Correlation Check:
   • AF3: our=0.775, manual=0.694
   • AF4: our=0.786, manual=0.824

🔍 Running bad channel detection...
🔍 Detecting bad channels using auto method...
   📈 Calculating channel metrics...
   📊 Variance - Median: 7.2 µV², Thresholds: 0.4-144.7
   📊 Correlation - Median: 0.461, Range: -0.112-0.843
      🚨 FC5: low_corr(0.270)
      🚨 C3: low_corr(0.243)
      🚨 C4: low_corr(0.268)
      🚨 C5:

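`get_nearby_channel_indices` treats channels that are adjacent in `raw.ch_names` as neighbors, which matches the scalp layout only if the channel list is ordered spatially. The channels flagged so far (FC5, C3, C4, C5) are central-row electrodes, and an ordering mismatch is one plausible cause worth ruling out before treating them as faulty. A sketch of choosing neighbors by sensor distance instead; `spatial_neighbors` is a hypothetical helper, and with MNE the coordinates could be read from `raw.info['chs'][i]['loc'][:3]` once a montage has been set:

```python
import numpy as np


def spatial_neighbors(positions, k=4):
    """Indices of the k nearest sensors (Euclidean distance) for each sensor.

    positions: array of shape (n_channels, 3) with sensor coordinates.
    Hypothetical helper, not an MNE function.
    """
    positions = np.asarray(positions, dtype=float)
    # Pairwise distances, shape (n_channels, n_channels)
    dist = np.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1)
    np.fill_diagonal(dist, np.inf)  # a sensor is not its own neighbor
    # Stable sort keeps tie-breaking deterministic (lower index first)
    return np.argsort(dist, axis=1, kind="stable")[:, :k]


# Toy layout: five sensors on a line at x = 0, 3, 1, 4, 2 (list order != spatial order)
xs = np.array([0.0, 3.0, 1.0, 4.0, 2.0])
positions = np.column_stack([xs, np.zeros_like(xs), np.zeros_like(xs)])
neighbors = spatial_neighbors(positions, k=2)
print(neighbors[0])  # [2 4]: the sensors at x = 1 and x = 2
```

For the same toy layout, `get_nearby_channel_indices(0, 5, max_neighbors=2)` returns channels 1 and 2, i.e. the sensors at x = 3 and x = 1, so one of the "neighbors" is three units away.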
In [5]:
# %% Cell 4.5: Investigate why the 13 channels above were flagged

def diagnose_bad_channel_reasons(bad_ch_results):
    """Detailed analysis of why channels were marked bad"""
    print("🔍 DETAILED BAD CHANNEL ANALYSIS")
    print("=" * 50)
    
    bad_channels = bad_ch_results['bad_channels_identified']
    metrics = bad_ch_results['channel_quality_scores']
    
    print(f"Identified {len(bad_channels)} bad channels:")
    
    for ch_name in bad_channels:
        ch_metrics = metrics[ch_name]
        reasons = []
        
        # Check each criterion
        variances = [m['variance'] for m in metrics.values()]
        median_var = np.median(variances)
        
        if ch_metrics['variance'] > median_var * 20:
            reasons.append(f"high_var({ch_metrics['variance']:.1f} vs median {median_var:.1f})")
        elif ch_metrics['variance'] < median_var * 0.05:
            reasons.append(f"low_var({ch_metrics['variance']:.1f} vs median {median_var:.1f})")
            
        correlations = [m['correlation_with_others'] for m in metrics.values()]
        median_corr = np.median(correlations)
        
        if ch_metrics['correlation_with_others'] < max(0.3, median_corr * 0.5):
            reasons.append(f"low_corr({ch_metrics['correlation_with_others']:.3f} vs median {median_corr:.3f})")
            
        if ch_metrics['max_amplitude'] > 150:
            reasons.append(f"high_amp({ch_metrics['max_amplitude']:.1f})")
            
        print(f"   {ch_name}: {', '.join(reasons)}")

# Run this diagnosis
if bad_ch_results is not None:
    diagnose_bad_channel_reasons(bad_ch_results)

🔍 DETAILED BAD CHANNEL ANALYSIS
Identified 13 bad channels:
   FC5: low_corr(0.270 vs median 0.461)
   C3: low_corr(0.243 vs median 0.461)
   C4: low_corr(0.268 vs median 0.461)
   C5: low_corr(-0.089 vs median 0.461)
   C6: low_corr(0.037 vs median 0.461)
   CP1: low_corr(0.161 vs median 0.461)
   CP2: low_corr(0.104 vs median 0.461)
   CP3: low_corr(0.069 vs median 0.461)
   CP4: low_corr(0.151 vs median 0.461)
   CP5: low_corr(0.151 vs median 0.461)
   CP6: low_corr(0.126 vs median 0.461)
   Pz: low_corr(0.298 vs median 0.461)
   CPz: low_corr(-0.112 vs median 0.461)

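With a median correlation of 0.461, the threshold `max(0.3, median_corr * 0.5)` evaluates to 0.3, so the fixed floor rather than the data decides, and 13 of 40 channels fall below it. A relative criterion would flag only channels that stand out from the rest of the recording. The sketch below swaps in a median/MAD robust z-score, a different technique from the cell above and not part of the original pipeline; the cutoff of -3 and the correlation values are illustrative assumptions:

```python
import numpy as np


def robust_z(values):
    """Robust z-scores: (x - median) / (1.4826 * MAD).

    1.4826 scales the median absolute deviation (MAD) to match the standard
    deviation for normally distributed data.
    """
    values = np.asarray(values, dtype=float)
    median = np.median(values)
    mad = np.median(np.abs(values - median))
    if mad == 0:
        # Most values identical: z-scores are undefined, return zeros
        return np.zeros_like(values)
    return (values - median) / (1.4826 * mad)


def flag_low_outliers(values, z_cutoff=-3.0):
    """Indices whose robust z-score falls below z_cutoff."""
    return np.flatnonzero(robust_z(values) < z_cutoff)


# Illustrative per-channel mean correlations (not from the study data)
correlations = [0.80, 0.75, 0.78, 0.82, 0.79, 0.10]
print(flag_low_outliers(correlations))  # [5]
```

Applied to the Step 3 metrics, the input would be the `correlation_with_others` values. Whether a relative rule suits this montage depends on whether the low correlations reflect faulty electrodes or the index-based neighbor selection.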

In [None]:
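# Earlier version of the Cell 4 diagnostics (not yet executed). Running this cell
# redefines diagnose_correlation_issue and test_bad_channel_detection, replacing
# the Cell 4 versions above.
def diagnose_correlation_issue(raw):
    """
    Diagnose why channel correlations are so low.
    """
    print("\n🔬 CORRELATION DIAGNOSIS")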
    print("=" * 40)
    
    data = raw.get_data() * 1e6
    
    # Check basic statistics
    print("üìä Data Statistics:")
    print(f"   ‚Ä¢ Shape: {data.shape}")
    print(f"   ‚Ä¢ Global mean: {np.mean(data):.2f} ¬µV")
    print(f"   ‚Ä¢ Global std: {np.std(data):.2f} ¬µV")
    print(f"   ‚Ä¢ Data range: {np.min(data):.1f} to {np.max(data):.1f} ¬µV")
    
    # Check if data is already average referenced
    channel_means = np.mean(data, axis=1)
    print(f"   ‚Ä¢ Channel means range: {np.min(channel_means):.2f} to {np.max(channel_means):.2f} ¬µV")
    
    # Manual correlation check
    print(f"\nüîç Manual Correlation Check (first 3 channels):")
    for i in range(min(3, data.shape[0])):
        for j in range(i+1, min(4, data.shape[0])):
            corr = np.corrcoef(data[i], data[j])[0,1]
            print(f"   ‚Ä¢ {raw.ch_names[i]} vs {raw.ch_names[j]}: {corr:.3f}")
    
    # Check reference
    print(f"\nüìã Reference Info:")
    print(f"   ‚Ä¢ Reference: {getattr(raw.info, 'custom_ref_applied', 'Unknown')}")
    print(f"   ‚Ä¢ Description: {raw.info.get('description', 'Not specified')}")

# Then modify the test function to include diagnosis:
def test_bad_channel_detection(filtered_inventory_df):
    """
    Test bad channel detection on a sample filtered file.
    """
    print("üß™ TESTING BAD CHANNEL DETECTION")
    print("=" * 50)
    
    if len(filtered_inventory_df) == 0:
        print("‚ùå No filtered files available for testing")
        return None
    
    # Use first available file
    sample_file = filtered_inventory_df.iloc[0]
    print(f"üìÅ Testing with: {sample_file['filename']}")
    print(f"üìÅ File path: {sample_file['file_path']}")
    
    try:
        # Load filtered data
        print("üîÑ Loading filtered data...")
        raw_filtered = mne.io.read_raw_fif(sample_file['file_path'], preload=True, verbose=False)
        print(f"‚úÖ Loaded: {len(raw_filtered.ch_names)} channels, {raw_filtered.times[-1]:.1f}s")
        
        # RUN DIAGNOSIS FIRST
        diagnose_correlation_issue(raw_filtered)
        
        # Detect bad channels with the Step 3 thresholds
        print("\n🔍 Running bad channel detection...")
        bad_channels_results = detect_bad_channels_comprehensive(raw_filtered, method='auto')
        
        # Display results
        print(f"\n📊 BAD CHANNEL DETECTION RESULTS:")
        print(f"   • Method: {bad_channels_results['method']}")
        print(f"   • Total channels: {bad_channels_results['total_channels']}")
        print(f"   • Bad channels identified: {len(bad_channels_results['bad_channels_identified'])}")
        
        if bad_channels_results['bad_channels_identified']:
            print(f"   • Bad channels: {bad_channels_results['bad_channels_identified']}")
        else:
            print("   • No bad channels detected")
        
        return raw_filtered, bad_channels_results
        
    except Exception as e:
        print(f"❌ Error in bad channel detection: {e}")
        import traceback
        traceback.print_exc()
        return None, None

# ACTUALLY RUN THE TEST
print("üöÄ EXECUTING BAD CHANNEL DETECTION TEST...")
sample_raw, bad_ch_results = test_bad_channel_detection(filtered_inventory_df)

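Both diagnosis cells look at per-channel means to judge whether the data are average-referenced, but means near zero are also what high-pass filtering produces. An average reference leaves a different trace: the mean across channels is zero at every sample. It also matters for correlation-based detection, because forcing the channels to sum to zero pulls pairwise correlations down; for N independent channels of equal variance, the expected correlation after average referencing is -1/(N-1). A numerical sketch with synthetic noise rather than the study data:

```python
import numpy as np

rng = np.random.default_rng(2)
n_channels, n_samples = 10, 20_000
independent = rng.standard_normal((n_channels, n_samples))

# Average reference: subtract the across-channel mean at every sample
avg_ref = independent - independent.mean(axis=0, keepdims=True)

# Signature of an average reference: the channel mean is ~0 at each sample
print(f"max |across-channel mean|: {np.abs(avg_ref.mean(axis=0)).max():.1e}")

corr = np.corrcoef(avg_ref)
off_diagonal = corr[~np.eye(n_channels, dtype=bool)]
print(f"mean pairwise correlation: {off_diagonal.mean():.3f}")
print(f"-1/(N-1):                  {-1 / (n_channels - 1):.3f}")
```

On the sample file, the equivalent reference check would be `np.abs(data.mean(axis=0)).max()` on the µV array returned by `raw.get_data() * 1e6`.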
## Step 5: Visualize Bad Channel Detection

In [None]:
# %% Cell 5: Create Bad Channel Visualization
def create_bad_channel_visualization(raw, bad_channels_results, file_info, output_path):
    """
    Create comprehensive visualization for bad channel detection results.
    
    Args:
        raw: MNE Raw object
        bad_channels_results: Bad channel detection results
        file_info: File metadata
        output_path: Output directory path
    """
    print("üìà CREATING BAD CHANNEL VISUALIZATION")
    
    # Use non-interactive backend
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle(f'Bad Channel Detection: {file_info["subject_id"]} - {file_info["session_id"]}', 
                 fontsize=16, fontweight='bold')
    
    data = raw.get_data() * 1e6
    ch_names = raw.ch_names
    bad_channels = bad_channels_results['bad_channels_identified']
    channel_metrics = bad_channels_results['channel_quality_scores']
    
    # 1. Channel variances with bad channels highlighted
    ax1 = axes[0, 0]
    variances = [channel_metrics[ch]['variance'] for ch in ch_names]
    colors = ['red' if ch in bad_channels else 'skyblue' for ch in ch_names]
    
    bars = ax1.bar(range(len(variances)), variances, color=colors, alpha=0.7)
    ax1.set_title('Channel Variances (Red = Bad Channels)')
    ax1.set_xlabel('Channel Index')
    ax1.set_ylabel('Variance (µV²)')
    ax1.grid(True, alpha=0.3)
    
    # 2. Channel correlation matrix
    ax2 = axes[0, 1]
    try:
        # Calculate correlation matrix
        corr_matrix = np.corrcoef(data)
        im = ax2.imshow(corr_matrix, cmap='RdBu_r', aspect='auto', vmin=-1, vmax=1)
        
        # Mark bad channels with lines through their row and column
        bad_indices = [i for i, ch in enumerate(ch_names) if ch in bad_channels]
        for idx in bad_indices:
            ax2.axhline(idx, color='red', linewidth=1, alpha=0.6)
            ax2.axvline(idx, color='red', linewidth=1, alpha=0.6)
        
        ax2.set_title('Channel Correlation Matrix\n(Red Lines = Bad Channels)')
        ax2.set_xlabel('Channel Index')
        ax2.set_ylabel('Channel Index')
        plt.colorbar(im, ax=ax2, shrink=0.6)
    except Exception as e:
        ax2.text(0.5, 0.5, f'Correlation matrix failed: {str(e)}', 
                transform=ax2.transAxes, ha='center')
        ax2.set_title('Channel Correlation Matrix')
    
    # 3. Channel quality scatter plot
    ax3 = axes[1, 0]
    variances = [channel_metrics[ch]['variance'] for ch in ch_names]
    correlations = [channel_metrics[ch]['correlation_with_others'] for ch in ch_names]
    
    colors = ['red' if ch in bad_channels else 'blue' for ch in ch_names]
    sizes = [100 if ch in bad_channels else 50 for ch in ch_names]
    
    scatter = ax3.scatter(variances, correlations, c=colors, s=sizes, alpha=0.7)
    ax3.set_xscale('log')  # the variance thresholds span a factor of 400
    ax3.set_title('Channel Quality: Variance vs Correlation')
    ax3.set_xlabel('Variance (µV²)')
    ax3.set_ylabel('Mean Correlation')
    ax3.grid(True, alpha=0.3)
    
    # Draw the same thresholds used by detect_bad_channels_statistical
    median_var = np.median(variances)
    median_corr = np.median(correlations)
    ax3.axvline(median_var * 20, color='red', linestyle='--', alpha=0.5, label='High var threshold')
    ax3.axvline(median_var * 0.05, color='orange', linestyle='--', alpha=0.5, label='Low var threshold')
    ax3.axhline(max(0.3, median_corr * 0.5), color='green', linestyle='--', alpha=0.5, label='Low corr threshold')
    ax3.legend()
    
    # 4. Bad channel locations
    ax4 = axes[1, 1]
    try:
        from mne.viz import plot_sensors
        # plot_sensors draws channels listed in info['bads'] in red,
        # so mark the detected channels on a copy of the Info object
        info_marked = raw.info.copy()
        info_marked['bads'] = list(bad_channels)
        plot_sensors(info_marked, show_names=True, axes=ax4, show=False)
        ax4.set_title('Channel Locations\n(Red = Bad Channels)')
    except Exception as e:
        ax4.text(0.5, 0.5, f'Sensor plot error: {str(e)}', 
                transform=ax4.transAxes, ha='center')
        ax4.set_title('Channel Locations')
    
    plt.tight_layout()
    
    # Save figure
    original_name = file_info['original_filename']
    fig_path = output_path / 'preprocessed_data' / 'visualizations' / f'{original_name}_bad_channels.png'
    fig_path.parent.mkdir(parents=True, exist_ok=True)
    
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    
    print(f"✅ Bad channel visualization saved: {fig_path.name}")
    
    # Switch back to interactive backend
    matplotlib.use('module://matplotlib_inline.backend_inline')
    
    return fig_path

# Create visualization if test was successful
if sample_raw is not None and bad_ch_results is not None:
    sample_file_info = filtered_inventory_df.iloc[0]
    bad_ch_viz_path = create_bad_channel_visualization(
        sample_raw, bad_ch_results, sample_file_info, Path('EEG_Preprocessing_Output')
    )

## Step 6: ICA Processing Pipeline

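Passing a float such as `n_components=0.95` to MNE's `ICA` asks it to keep the smallest number of PCA components whose cumulative explained variance reaches that fraction, and only those components are unmixed. A NumPy sketch of the selection rule; `n_components_for_variance` is a hypothetical helper, and MNE additionally scales (pre-whitens) channels before its PCA, so counts on real recordings can differ:

```python
import numpy as np


def n_components_for_variance(data, threshold=0.95):
    """Smallest number of principal components whose cumulative explained
    variance ratio reaches `threshold` (hypothetical helper).

    data: array of shape (n_channels, n_samples).
    Returns (n_components, per-component explained variance ratios).
    """
    centered = data - data.mean(axis=1, keepdims=True)
    # Squared singular values are proportional to the PCA component variances
    singular_values = np.linalg.svd(centered, compute_uv=False)
    ratio = singular_values**2 / np.sum(singular_values**2)
    cumulative = np.cumsum(ratio)
    n = int(np.searchsorted(cumulative, threshold)) + 1
    return min(n, len(ratio)), ratio


# Three independent sources with standard deviations 10, 3 and 1
rng = np.random.default_rng(3)
sources = rng.standard_normal((3, 10_000)) * np.array([[10.0], [3.0], [1.0]])
n, ratio = n_components_for_variance(sources, threshold=0.95)
print(np.round(ratio, 3))  # roughly [0.909 0.082 0.009]
print(n)                   # 2
```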
In [None]:
# %% Cell 6: ICA Processing Implementation
def perform_ica_processing(raw, n_components=0.95, method='fastica', random_state=42):
    """
    Perform Independent Component Analysis on EEG data.
    
    Args:
        raw: MNE Raw object
        n_components: Number of components (float for variance, int for exact)
        method: ICA method ('fastica', 'infomax', 'picard')
        random_state: Random seed for reproducibility
        
    Returns:
        tuple: (ICA object, ICA results dictionary)
    """
    print("üß† PERFORMING INDEPENDENT COMPONENT ANALYSIS")
    print(f"   ‚Ä¢ Method: {method}")
    print(f"   ‚Ä¢ Components: {n_components}")
    print(f"   ‚Ä¢ Random state: {random_state}")
    
    # Create and fit ICA
    ica = ICA(
        n_components=n_components,
        method=method,
        random_state=random_state,
        max_iter=800,
        fit_params=dict(extended=True) if method == 'infomax' else None
    )
    
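    # Fit ICA (channels listed in raw.info['bads'] are excluded by default)
    ica.fit(raw, verbose=False)
    
    # MNE stores per-component PCA variances; normalize them to ratios
    pca_var = ica.pca_explained_variance_
    pca_ratio = pca_var / pca_var.sum()
    
    # Analyze components
    ica_results = {
        'n_components': ica.n_components_,
        'method': method,
        # Fraction of PCA variance retained in the components passed to ICA
        'explained_variance': float(pca_ratio[:ica.n_components_].sum()),
        'component_characteristics': analyze_ica_components(ica, raw),
        'fitting_time': 'N/A'  # Could be enhanced with timing
    }
    
    print(f"✅ ICA completed: {ica.n_components_} components")
    print(f"   • Explained variance: {ica_results['explained_variance']:.3f}")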
    
    return ica, ica_results

def analyze_ica_components(ica, raw):
    """
    Analyze ICA components for artifact characteristics.
    
    Args:
        ica: Fitted ICA object
        raw: Raw data the ICA was fitted on
        
    Returns:
        dict: Per-component analysis results keyed 'component_XX'
    """
    component_analysis = {}
    
    # Spatial patterns (topographies), shape (n_channels, n_components),
    # computed once rather than once per component
    topographies = ica.get_components()
    topo_vars = np.var(topographies, axis=0)
    
    for idx in range(ica.n_components_):
        component_data = topographies[:, idx]
        
        # Fraction of data variance explained by this ICA component alone
        # (dict keyed by channel type; EEG-only data has a single entry).
        # PCA variance ratios do not apply here: PCA index != ICA index.
        var_ratio = ica.get_explained_variance_ratio(raw, components=[idx])
        
        analysis = {
            'variance_explained': float(np.mean(list(var_ratio.values()))),
            'max_amplitude': float(np.max(np.abs(component_data))),
            'topographic_std': float(np.std(component_data)),
            'is_eyeblink_likely': is_eyeblink_component(component_data, ica),
            'is_cardiac_likely': is_cardiac_component(component_data, topo_vars),
            'is_noise_likely': is_noise_component(component_data)
        }
        
        component_analysis[f'component_{idx:02d}'] = analysis
    
    return component_analysis

def is_eyeblink_component(component_data, ica):
    """Heuristic for eyeblink component detection.
    
    Flags a component whose largest frontal weight exceeds half its largest
    overall weight. The threshold is lenient, so treat the flag as a
    candidate to review rather than a confirmed blink component.
    """
    # Frontal channels typically have high weights for eyeblinks
    frontal_channels = ['Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8']
    frontal_indices = [i for i, ch in enumerate(ica.ch_names) if ch in frontal_channels]
    
    max_overall = np.max(np.abs(component_data))
    if not frontal_indices or max_overall == 0:
        return False
    
    max_frontal = np.max(np.abs(component_data[frontal_indices]))
    return bool(max_frontal / max_overall > 0.5)

def is_cardiac_component(component_data, all_topo_vars):
    """Heuristic for cardiac component detection.
    
    Simple variance-based heuristic: flag a topography whose variance exceeds
    twice the median topography variance across all components. An ECG
    channel with ica.find_bads_ecg is typically more reliable.
    """
    return bool(np.var(component_data) > 2 * np.median(all_topo_vars))

def is_noise_component(component_data):
    """Heuristic for noise component detection.
    
    A spatially focal topography (one or two channels dominating) has high
    excess kurtosis and often reflects single-channel noise.
    """
    from scipy.stats import kurtosis
    return bool(kurtosis(component_data) > 3)
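The spatial heuristics in this cell only need the topography matrix returned by `ica.get_components()` (channels × components), so they can be checked on a synthetic matrix without fitting ICA. The sketch below is a standalone NumPy version, assuming the same thresholds as above: a component is variance-flagged when its topography variance exceeds twice the median across components, and kurtosis-flagged when its excess (Fisher) kurtosis, computed the way `scipy.stats.kurtosis` does by default, exceeds 3. The helpers `excess_kurtosis` and `flag_topographies` and the synthetic data are hypothetical.

```python
import numpy as np

def excess_kurtosis(x):
    """Biased Fisher kurtosis, matching scipy.stats.kurtosis defaults."""
    x = np.asarray(x, dtype=float)
    centered = x - x.mean()
    m2 = np.mean(centered ** 2)
    if m2 == 0:
        return 0.0  # constant vector: kurtosis undefined, treated as not focal
    m4 = np.mean(centered ** 4)
    return m4 / m2 ** 2 - 3.0

def flag_topographies(topographies, var_factor=2.0, kurt_threshold=3.0):
    """Flag columns of a (n_channels, n_components) topography matrix."""
    topo_vars = np.var(topographies, axis=0)
    var_flags = topo_vars > var_factor * np.median(topo_vars)
    kurt_flags = np.array([excess_kurtosis(col) > kurt_threshold
                           for col in topographies.T])
    return var_flags, kurt_flags

# Hypothetical 32-channel example with four components
n_channels = 32
alternating = np.tile([1.0, -1.0], n_channels // 2)  # broad pattern, variance 1
spike = np.zeros(n_channels)
spike[5] = 10.0                                       # single-channel focus
topographies = np.column_stack([alternating, alternating, 3 * alternating, spike])

var_flags, kurt_flags = flag_topographies(topographies)
print(var_flags.tolist())   # [False, False, True, False]
print(kurt_flags.tolist())  # [False, False, False, True]
```

The large-amplitude broad pattern trips only the variance rule, while the single-channel spike trips only the kurtosis rule, which is the distinction the two heuristics are meant to draw.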

## Step 7: Test ICA Processing

In [None]:
# %% Cell 7: Test ICA Processing on Sample File
def test_ica_processing(raw, file_info):
    """
    Test ICA processing on a sample file.
    
    Args:
        raw: MNE Raw object
        file_info: File metadata
        
    Returns:
        tuple: (ICA object, ICA results)
    """
    print("🧪 TESTING ICA PROCESSING")
    print("=" * 50)
    print(f"📁 Processing: {file_info['filename']}")
    
    try:
        # Perform ICA
        ica, ica_results = perform_ica_processing(raw, n_components=0.95, method='fastica')
        
        # Display results
        print(f"\n📊 ICA PROCESSING RESULTS:")
        print(f"   • Components extracted: {ica_results['n_components']}")
        print(f"   • Total variance explained: {ica_results['explained_variance']:.3f}")
        print(f"   • Method: {ica_results['method']}")
        
        # Component analysis summary
        component_chars = ica_results['component_characteristics']
        eyeblink_components = [comp for comp, chars in component_chars.items() 
                              if chars['is_eyeblink_likely']]
        cardiac_components = [comp for comp, chars in component_chars.items() 
                             if chars['is_cardiac_likely']]
        noise_components = [comp for comp, chars in component_chars.items() 
                           if chars['is_noise_likely']]
        
        print(f"\n📈 COMPONENT ANALYSIS:")
        print(f"   • Potential eyeblink components: {len(eyeblink_components)}")
        print(f"   • Potential cardiac components: {len(cardiac_components)}")
        print(f"   • Potential noise components: {len(noise_components)}")
        
        if eyeblink_components:
            print(f"   • Eyeblink components: {eyeblink_components[:3]}...")
        
        return ica, ica_results
        
    except Exception as e:
        print(f"❌ Error in ICA processing: {e}")
        return None, None

# Test ICA processing if sample data is available
if sample_raw is not None:
    sample_file_info = filtered_inventory_df.iloc[0]
    ica_obj, ica_results = test_ica_processing(sample_raw, sample_file_info)
else:
    ica_obj, ica_results = None, None
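Passing a float such as `n_components=0.95` asks MNE to keep the smallest number of PCA components whose cumulative explained variance exceeds that fraction; `ica.n_components_` then reports the resulting count. The NumPy sketch below illustrates that selection rule on made-up PCA variances. The helpers `n_components_for_variance` and `retained_fraction` are hypothetical, and MNE's internal handling of exact ties at the threshold may differ from this version.

```python
import numpy as np

def n_components_for_variance(pca_variances, fraction):
    """Smallest k whose top-k PCA variances explain more than `fraction`."""
    var = np.sort(np.asarray(pca_variances, dtype=float))[::-1]
    cumulative = np.cumsum(var) / var.sum()
    # First index where the cumulative ratio strictly exceeds the fraction
    k = int(np.searchsorted(cumulative, fraction, side='right')) + 1
    return min(k, len(var))

def retained_fraction(pca_variances, k):
    """Fraction of total variance captured by the k largest components."""
    var = np.sort(np.asarray(pca_variances, dtype=float))[::-1]
    return float(var[:k].sum() / var.sum())

# Made-up PCA variances (arbitrary units)
pca_variances = [6.0, 2.0, 1.0, 0.6, 0.4]
k = n_components_for_variance(pca_variances, 0.95)
print(k, round(retained_fraction(pca_variances, k), 2))  # 4 0.96
```

The retained fraction is at least the requested threshold and usually somewhat above it, which is why the printed total explained variance tends to land just over 0.95.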

## Step 8: Visualize ICA Results

In [None]:
# %% Cell 8: Create ICA Visualization
def create_ica_visualization(raw, ica, ica_results, file_info, output_path, n_components_show=8):
    """
    Create comprehensive visualization for ICA results.
    
    Args:
        raw: MNE Raw object
        ica: Fitted ICA object
        ica_results: ICA analysis results
        file_info: File metadata
        output_path: Output directory path
        n_components_show: Number of components to display
    """
    print("📈 CREATING ICA VISUALIZATION")
    
    # Use non-interactive backend
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    
    n_components = min(ica.n_components_, n_components_show)
    n_rows = (n_components + 1) // 2  # Adjust layout
    
    fig, axes = plt.subplots(n_rows, 2, figsize=(15, 4 * n_rows))
    fig.suptitle(f'ICA Components: {file_info["subject_id"]} - {file_info["session_id"]}\n'
                 f'{ica.n_components_} components, {ica_results["explained_variance"]:.3f} variance explained',
                 fontsize=14, fontweight='bold')
    
    # Flatten axes for easier indexing. plt.subplots(n_rows, 2) returns a 1-D
    # array when n_rows == 1 and a 2-D array otherwise, so normalize both cases
    axes_flat = np.atleast_1d(axes).ravel()
    
    # Plot each component
    for idx in range(n_components):
        ax = axes_flat[idx]
        
        try:
            # Plot component topography
            ica.plot_components(picks=[idx], axes=ax, show=False)
            ax.set_title(f'Component {idx}', fontweight='bold')
            
            # Add component type annotation (plain-text labels: the default
            # matplotlib font lacks glyphs for most emoji)
            comp_key = f'component_{idx:02d}'
            comp_chars = ica_results['component_characteristics'].get(comp_key, {})
            
            component_type = []
            if comp_chars.get('is_eyeblink_likely'):
                component_type.append('Eyeblink')
            if comp_chars.get('is_cardiac_likely'):
                component_type.append('Cardiac')
            if comp_chars.get('is_noise_likely'):
                component_type.append('Noise')
            
            if component_type:
                ax.text(0.02, 0.98, ', '.join(component_type), 
                       transform=ax.transAxes, fontsize=12,
                       verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
                       
        except Exception as e:
            ax.text(0.5, 0.5, f'Component {idx}\nPlot failed', 
                   transform=ax.transAxes, ha='center', va='center')
            ax.set_title(f'Component {idx}')
    
    # Hide unused subplots
    for idx in range(n_components, len(axes_flat)):
        axes_flat[idx].set_visible(False)
    
    plt.tight_layout()
    
    # Save figure
    original_name = file_info['original_filename']
    fig_path = output_path / 'preprocessed_data' / 'visualizations' / f'{original_name}_ica_components.png'
    fig_path.parent.mkdir(parents=True, exist_ok=True)
    
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    
    print(f"✅ ICA visualization saved: {fig_path.name}")
    
    # Switch back to interactive backend
    matplotlib.use('module://matplotlib_inline.backend_inline')
    
    return fig_path

# Create ICA visualization if available
if ica_obj is not None and ica_results is not None:
    ica_viz_path = create_ica_visualization(
        sample_raw, ica_obj, ica_results, sample_file_info, Path('EEG_Preprocessing_Output')
    )

## Step 9: Batch Processing Pipeline

In [None]:
# %% Cell 9: Batch Processing for Bad Channels & ICA
def process_bad_channels_ica_batch(filtered_inventory_df, output_path, batch_size=20, max_files=None):
    """
    Process all filtered files through bad channel detection and ICA.
    
    Args:
        filtered_inventory_df: DataFrame of filtered files
        output_path: Output directory path
        batch_size: Number of files to process in each batch
        max_files: Maximum number of files to process (None for all)
    """
    print("🚀 STARTING BATCH PROCESSING: BAD CHANNELS & ICA")
    print("=" * 60)
    
    # Limit files if specified
    files_to_process = filtered_inventory_df
    if max_files and max_files < len(filtered_inventory_df):
        files_to_process = filtered_inventory_df.head(max_files)
    
    total_files = len(files_to_process)
    total_batches = (total_files + batch_size - 1) // batch_size
    
    print(f"📊 Processing {total_files} files in {total_batches} batches")
    print(f"🎯 Batch size: {batch_size}")
    
    processed_count = 0
    error_count = 0
    
    # Create output directories
    ica_cleaned_path = output_path / 'preprocessed_data' / 'ica_cleaned'
    ica_cleaned_path.mkdir(parents=True, exist_ok=True)
    
    bad_channel_reports_path = output_path / 'preprocessed_data' / 'quality_reports'
    bad_channel_reports_path.mkdir(parents=True, exist_ok=True)
    
    for batch_num in range(total_batches):
        batch_start = batch_num * batch_size
        batch_end = min((batch_num + 1) * batch_size, total_files)
        
        print(f"\n{'='*50}")
        print(f"🔄 BATCH {batch_num + 1}/{total_batches} (Files {batch_start + 1}-{batch_end})")
        print(f"{'='*50}")
        
        for idx in range(batch_start, batch_end):
            row = files_to_process.iloc[idx]
            file_num = idx + 1
            
            print(f"   [{file_num}/{total_files}] {row['subject_id']} {row['session_id']}")
            
            try:
                # STEP 1: Load filtered data
                raw_filtered = mne.io.read_raw_fif(row['file_path'], preload=True, verbose=False)
                print(f"      ✅ Loaded: {len(raw_filtered.ch_names)} channels")
                
                # STEP 2: Bad channel detection
                bad_channels_results = detect_bad_channels_comprehensive(raw_filtered, method='auto')
                
                # Mark bad channels in the data
                if bad_channels_results['bad_channels_identified']:
                    raw_filtered.info['bads'] = bad_channels_results['bad_channels_identified']
                    print(f"      🔴 Marked {len(bad_channels_results['bad_channels_identified'])} bad channels")
                
                # STEP 3: ICA processing
                ica, ica_results = perform_ica_processing(raw_filtered, n_components=0.95)
                
                # STEP 4: Remove components automatically classified as
                # artifacts, then save the ICA-cleaned data
                original_name = row['original_filename']
                ica_filename = f"{original_name}_ica_cleaned.fif"
                ica_filepath = ica_cleaned_path / ica_filename
                
                artifact_components = identify_artifact_components(ica_results)
                if artifact_components:
                    print(f"      🧹 Removing {len(artifact_components)} artifact components")
                    # ica.apply modifies raw_filtered in place
                    ica.apply(raw_filtered, exclude=artifact_components)
                
                raw_filtered.save(ica_filepath, overwrite=True, verbose=False)
                
                # STEP 5: Save reports
                # Bad channel report
                bad_ch_report = {
                    'subject_id': row['subject_id'],
                    'session_id': row['session_id'],
                    'task_type': row['task_type'],
                    'original_filename': original_name,
                    'bad_channels_detection': bad_channels_results,
                    'ica_components_removed': artifact_components,
                    'processing_timestamp': pd.Timestamp.now().isoformat()
                }
                
                bad_ch_report_path = bad_channel_reports_path / f"{original_name}_badch_ica_report.json"
                with open(bad_ch_report_path, 'w') as f:
                    json.dump(bad_ch_report, f, indent=2, default=str)
                
                # STEP 6: Create visualizations for first few files or periodically
                if file_num <= 10 or file_num % 50 == 0:
                    # Bad channel visualization
                    create_bad_channel_visualization(raw_filtered, bad_channels_results, row, output_path)
                    
                    # ICA visualization
                    create_ica_visualization(raw_filtered, ica, ica_results, row, output_path)
                
                processed_count += 1
                print(f"      ✅ SUCCESS: {ica_filename}")
                
                # Cleanup
                del raw_filtered, ica
                
            except Exception as e:
                error_count += 1
                print(f"      ❌ ERROR: {str(e)[:80]}...")
                continue
        
        # Memory cleanup after each batch
        gc.collect()
        print(f"   🧹 Garbage collection run after batch {batch_num + 1}")
    
    # Final summary
    print(f"\n{'='*60}")
    print("🎯 BATCH PROCESSING COMPLETED")
    print(f"{'='*60}")
    print(f"📊 Results:")
    print(f"   • Successfully processed: {processed_count}/{total_files}")
    print(f"   • Errors: {error_count}")
    print(f"   • ICA-cleaned files on disk: {len(list(ica_cleaned_path.glob('*_ica_cleaned.fif')))}")
    
    return processed_count, error_count

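def identify_artifact_components(ica_results):
    """Return indices of components flagged by any heuristic (eyeblink, cardiac, or noise).
    
    A single flag is enough for exclusion, so lenient heuristics can remove
    many components; review the reports before relying on the cleaned data.
    """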
    exclude_components = []
    
    for comp_key, comp_chars in ica_results['component_characteristics'].items():
        # Exclude components classified as artifacts
        if (comp_chars.get('is_eyeblink_likely') or 
            comp_chars.get('is_cardiac_likely') or 
            comp_chars.get('is_noise_likely')):
            
            comp_idx = int(comp_key.split('_')[1])
            exclude_components.append(comp_idx)
    
    return exclude_components
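`identify_artifact_components` excludes a component as soon as any one heuristic flags it, so lenient thresholds can remove a large share of the decomposition. One possible safeguard, shown here as a hypothetical sketch rather than part of the pipeline, is to rank flagged components by how many heuristics agree and cap the number removed at a fraction of all components. The name `cap_artifact_components` and the 25% default are assumptions.

```python
def cap_artifact_components(component_characteristics, max_fraction=0.25):
    """Return at most max_fraction of components, preferring multi-flag ones.

    component_characteristics: dict like
        {'component_00': {'is_eyeblink_likely': bool, ...}, ...}
    """
    flag_keys = ('is_eyeblink_likely', 'is_cardiac_likely', 'is_noise_likely')
    candidates = []
    for comp_key, chars in component_characteristics.items():
        n_flags = sum(bool(chars.get(k)) for k in flag_keys)
        if n_flags:
            candidates.append((n_flags, int(comp_key.split('_')[1])))

    # Most flags first; ties broken by lower component index
    candidates.sort(key=lambda item: (-item[0], item[1]))

    max_exclude = int(max_fraction * len(component_characteristics))
    return sorted(idx for _, idx in candidates[:max_exclude])

# Hypothetical flags for eight components
chars = {f'component_{i:02d}': {'is_eyeblink_likely': False,
                                'is_cardiac_likely': False,
                                'is_noise_likely': False} for i in range(8)}
chars['component_00']['is_eyeblink_likely'] = True
chars['component_00']['is_noise_likely'] = True
chars['component_03']['is_eyeblink_likely'] = True
chars['component_05']['is_noise_likely'] = True

print(cap_artifact_components(chars))  # [0, 3]
```

The returned list has the same form as the uncapped one, so it could be passed to `ica.apply(raw, exclude=...)` unchanged.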

## Step 10: Run Batch Processing

In [None]:
# %% Cell 10: Execute Batch Processing
def execute_badchannel_ica_pipeline():
    """
    Execute the complete bad channel and ICA pipeline.
    """
    print("🎯 EXECUTING BAD CHANNEL & ICA PIPELINE")
    print("=" * 60)
    
    output_path = Path('EEG_Preprocessing_Output')
    
    # Check if filtered data exists
    filtered_path = output_path / 'preprocessed_data' / 'raw_cleaned'
    filtered_files = list(filtered_path.glob("*_filtered.fif"))
    
    if not filtered_files:
        print("❌ No filtered files found. Please run the filtering pipeline first.")
        return
    
    print(f"📁 Found {len(filtered_files)} filtered files")
    
    # Load filtered inventory
    filtered_inventory_df = load_filtered_data_inventory(output_path)
    
    # Ask for processing parameters
    print("\n⚙️  PROCESSING PARAMETERS")
    max_files = input("Max files to process (Enter for all): ").strip()
    max_files = int(max_files) if max_files else None
    
    batch_size = input("Batch size (Enter for 20): ").strip()
    batch_size = int(batch_size) if batch_size else 20
    
    # Confirm processing
    total_to_process = min(max_files, len(filtered_inventory_df)) if max_files else len(filtered_inventory_df)
    response = input(f"\nProcess {total_to_process} files? (y/n): ")
    
    if response.lower() == 'y':
        processed, errors = process_bad_channels_ica_batch(
            filtered_inventory_df, output_path, 
            batch_size=batch_size, max_files=max_files
        )
        
        print(f"\n🎉 PIPELINE COMPLETED!")
        print(f"   • Processed: {processed} files")
        print(f"   • Errors: {errors} files")
        print(f"   • Output: {output_path / 'preprocessed_data' / 'ica_cleaned'}")
        
    else:
        print("Processing cancelled.")

# Execute the pipeline
execute_badchannel_ica_pipeline()
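The batch loop splits the file list with ceiling division, `total_batches = (total_files + batch_size - 1) // batch_size`, so the last batch may be shorter than the batch size entered above. A small generator makes those ranges easy to check on their own; `batch_ranges` is a hypothetical helper, not part of the pipeline, and uses the same 0-based, end-exclusive indices as the loop.

```python
def batch_ranges(total_files, batch_size):
    """Yield (start, end) index pairs, end exclusive, covering total_files."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    total_batches = (total_files + batch_size - 1) // batch_size  # ceiling division
    for batch_num in range(total_batches):
        start = batch_num * batch_size
        end = min(start + batch_size, total_files)
        yield start, end

print(list(batch_ranges(45, 20)))  # [(0, 20), (20, 40), (40, 45)]
```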

## Step 11: Pipeline Progress & Next Steps

In [None]:
# %% Cell 11: Comprehensive Progress Check
def check_badchannel_ica_progress(output_path):
    """
    Check progress of bad channel detection and ICA pipeline.
    
    Args:
        output_path: Main output directory path
    """
    print("\n📊 BAD CHANNEL & ICA PIPELINE PROGRESS")
    print("=" * 50)
    
    # Count files in each stage
    raw_cleaned_path = output_path / 'preprocessed_data' / 'raw_cleaned'
    ica_cleaned_path = output_path / 'preprocessed_data' / 'ica_cleaned'
    quality_reports_path = output_path / 'preprocessed_data' / 'quality_reports'
    visualizations_path = output_path / 'preprocessed_data' / 'visualizations'
    
    # File counts
    filtered_files = list(raw_cleaned_path.glob("*_filtered.fif"))
    ica_cleaned_files = list(ica_cleaned_path.glob("*_ica_cleaned.fif"))
    badchannel_reports = list(quality_reports_path.glob("*_badch_ica_report.json"))
    badchannel_viz = list(visualizations_path.glob("*_bad_channels.png"))
    ica_viz = list(visualizations_path.glob("*_ica_components.png"))
    
    total_filtered = len(filtered_files)
    
    print(f"📈 PIPELINE PROGRESS STATISTICS:")
    print(f"   • Filtered files available: {total_filtered}")
    print(f"   • ICA-cleaned files: {len(ica_cleaned_files)}")
    print(f"   • Bad channel/ICA reports: {len(badchannel_reports)}")
    print(f"   • Bad channel visualizations: {len(badchannel_viz)}")
    print(f"   • ICA component visualizations: {len(ica_viz)}")
    
    # Completion percentages
    if total_filtered > 0:
        ica_pct = (len(ica_cleaned_files) / total_filtered) * 100
        report_pct = (len(badchannel_reports) / total_filtered) * 100
        
        print(f"\n🎯 COMPLETION STATUS:")
        print(f"   • ICA Processing: {ica_pct:.1f}%")
        print(f"   • Quality Reports: {report_pct:.1f}%")
    
    # Show sample files
    if ica_cleaned_files:
        print(f"\n📝 Sample ICA-cleaned files:")
        for f in ica_cleaned_files[:3]:
            file_size = f.stat().st_size / (1024 * 1024)
            print(f"   • {f.name} ({file_size:.1f} MB)")
    
    # Recommendations (handle the empty cases first: with no filtered files,
    # 0 >= 0 * 0.9 would otherwise report the pipeline as complete)
    if total_filtered == 0:
        print("\n   🔄 NEED: No filtered files found")
        print("   → Run the filtering pipeline first")
    elif len(ica_cleaned_files) == 0:
        print("\n   🔄 NEED: Run bad channel & ICA pipeline")
        print("   → Execute batch processing above")
    elif len(ica_cleaned_files) >= total_filtered * 0.9:
        print("\n   ✅ READY: At least 90% of files are ICA-cleaned")
        print("   → Proceed to Epoching & Feature Extraction")
    else:
        print("\n   ⚠️  PARTIAL: Pipeline incomplete")
        print("   → Continue batch processing or check for errors")
    
    return {
        'filtered_files': total_filtered,
        'ica_cleaned_files': len(ica_cleaned_files),
        'badchannel_reports': len(badchannel_reports),
        'badchannel_viz': len(badchannel_viz),
        'ica_viz': len(ica_viz)
    }

# Check progress
pipeline_progress = check_badchannel_ica_progress(Path('EEG_Preprocessing_Output'))
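When the status is partial, it helps to know which recordings still lack an ICA-cleaned output. Assuming each file is named `<original_name>_filtered.fif` before ICA and `<original_name>_ica_cleaned.fif` after it (the glob patterns above match those suffixes, but the shared base name is an assumption), a name-based comparison is enough. `find_pending_files` is a hypothetical helper that works on plain file names.

```python
def find_pending_files(filtered_names, ica_cleaned_names,
                       filtered_suffix='_filtered.fif',
                       ica_suffix='_ica_cleaned.fif'):
    """Return base names that have a filtered file but no ICA-cleaned file.

    Assumes both stages share the same base name before their suffix.
    """
    def base_names(names, suffix):
        return {n[:-len(suffix)] for n in names if n.endswith(suffix)}

    done = base_names(ica_cleaned_names, ica_suffix)
    return sorted(base_names(filtered_names, filtered_suffix) - done)

# Hypothetical file names
filtered = ['sub01_ses1_filtered.fif', 'sub02_ses1_filtered.fif', 'sub03_ses1_filtered.fif']
cleaned = ['sub01_ses1_ica_cleaned.fif']
print(find_pending_files(filtered, cleaned))  # ['sub02_ses1', 'sub03_ses1']
```

In the notebook, the inputs would be the `.name` of each path returned by the `*_filtered.fif` and `*_ica_cleaned.fif` globs.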

# %% [markdown]
# ## Summary
# 
# This notebook completes the third phase of EEG preprocessing with:
# 
# ✅ **Bad Channel Detection**: Automated identification of problematic channels using statistical methods  
# ✅ **ICA Processing**: Artifact removal by excluding ICA components that heuristics flag as eyeblink, cardiac, or noise  
# ✅ **Quality Reports**: One JSON report per file listing detected bad channels and removed ICA components  
# ✅ **Visualizations**: Bad-channel and ICA-component figures for the first 10 files and every 50th file  
# ✅ **Batch Processing**: Files processed in configurable batches, with garbage collection between batches  
# 
# **Next Steps**: Proceed to epoching and feature extraction to produce analysis-ready data.

print("\n🎉 Bad Channel Detection & ICA Pipeline Ready!")