# In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # Set backend before importing pyplot
import matplotlib.pyplot as plt
from scipy import signal
from scipy.signal import butter, sosfiltfilt, hilbert
from PyEMD import EMD, EEMD
from concurrent.futures import ThreadPoolExecutor, as_completed
import multiprocessing as mp
from functools import partial
import warnings
import traceback
import gc
from sklearn.preprocessing import MinMaxScaler
warnings.filterwarnings('ignore')

class OptimizedEEGHilbertSpectrumGenerator:
    def __init__(self, sz_input_dir, hc_input_dir, output_dir, sampling_rate=256, n_jobs=None):
        """
        Initialize the Optimized EEG Hilbert Spectrum Generator using Hilbert-Huang Transform (HHT)
        Now generates HHT plots for 5 brain regions instead of individual channels
        
        Parameters:
        sz_input_dir (str): Directory containing CSV files with Schizophrenia EEG data
        hc_input_dir (str): Directory containing CSV files with Healthy Control EEG data
        output_dir (str): Root directory to save Hilbert spectrum images
        sampling_rate (int): EEG sampling rate in Hz (default: 256 Hz)
        n_jobs (int): Number of parallel threads (default: CPU count)
        """
        self.sz_input_dir = sz_input_dir
        self.hc_input_dir = hc_input_dir
        self.output_dir = output_dir
        self.sampling_rate = sampling_rate
        self.segment_duration = 3  # 3 seconds
        self.segment_samples = self.sampling_rate * self.segment_duration
        
        # Define brain regions based on 10-20 EEG electrode system
        # Assuming standard 19-channel setup: Fp1, Fp2, F7, F3, Fz, F4, F8, T3, C3, Cz, C4, T4, T5, P3, Pz, P4, T6, O1, O2
        self.brain_regions = {
            'Frontal': [0, 1, 2, 3, 4, 5, 6],      # Fp1, Fp2, F7, F3, Fz, F4, F8
            'Central': [7, 8, 9, 10, 11],          # T3, C3, Cz, C4, T4
            'Temporal': [2, 6, 7, 11, 12, 16],     # F7, F8, T3, T4, T5, T6
            'Parietal': [12, 13, 14, 15, 16],      # T5, P3, Pz, P4, T6
            'Occipital': [17, 18]                  # O1, O2
        }
        
        self.region_names = list(self.brain_regions.keys())
        
        # Set number of parallel jobs (using threads instead of processes)
        if n_jobs is None:
            self.n_jobs = min(mp.cpu_count(), 6)  # Cap at 6 to avoid memory issues
        else:
            self.n_jobs = min(n_jobs, 6)
        
        # HHT parameters
        self.freq_min = 0.5   # Minimum frequency (Hz)
        self.freq_max = 32    # Maximum frequency (Hz)
        self.freq_bins = 100  # Number of frequency bins for Hilbert spectrum
        self.max_imfs = 8     # Maximum number of IMFs to extract
        
        # EMD parameters
        self.emd_trials = 50  # Number of trials for EEMD (Ensemble EMD)
        self.noise_std = 0.2  # Standard deviation of added noise for EEMD
        
        # Pre-compute frequency bins for Hilbert spectrum
        self.freq_bins_array = np.linspace(self.freq_min, self.freq_max, self.freq_bins, dtype=np.float32)
        
        # Pre-compute filter coefficients (SOS format for stability and speed)
        self.filter_sos = self._precompute_filter()
        
        # Create output directories for each brain region
        self.region_dirs = {}
        for region in self.region_names:
            sz_region_dir = os.path.join(output_dir, f'Schizophrenia_HHT_{region}')
            hc_region_dir = os.path.join(output_dir, f'HealthyControl_HHT_{region}')
            os.makedirs(sz_region_dir, exist_ok=True)
            os.makedirs(hc_region_dir, exist_ok=True)
            self.region_dirs[region] = {
                'SZ': sz_region_dir,
                'HC': hc_region_dir
            }
        
        print(f"Initialized HHT Generator with {self.n_jobs} parallel threads")
        print(f"Brain regions: {', '.join(self.region_names)}")
        print(f"Frequency range: {self.freq_min}-{self.freq_max} Hz with {self.freq_bins} bins")
        
    def _precompute_filter(self):
        """Pre-compute bandpass filter coefficients for reuse"""
        nyquist = 0.5 * self.sampling_rate
        low = self.freq_min / nyquist
        high = self.freq_max / nyquist
        sos = signal.butter(4, [low, high], btype='band', output='sos')
        return sos
    
    def bandpass_filter_optimized(self, data):
        """
        Apply pre-computed bandpass filter to EEG data (optimized version)
        """
        # Use SOS format for better numerical stability and speed
        filtered_data = sosfiltfilt(self.filter_sos, data, axis=0)
        return filtered_data.astype(np.float32)
    
    def segment_signal_vectorized(self, data):
        """
        Vectorized segmentation of EEG signal into 3-second windows
        """
        n_samples, n_channels = data.shape
        n_segments = n_samples // self.segment_samples
        
        if n_segments == 0:
            return []
        
        # Reshape for vectorized processing
        valid_samples = n_segments * self.segment_samples
        reshaped_data = data[:valid_samples, :].reshape(n_segments, self.segment_samples, n_channels)
        
        # Convert to list of segments
        segments = [segment.astype(np.float32) for segment in reshaped_data]
        return segments
    
    def average_channels_by_region(self, segment, region_channels):
        """
        Collapse the channels of one brain region into a single mean signal.

        Parameters:
        segment: EEG segment (samples x channels)
        region_channels: List of channel indices for the region

        Returns:
        averaged_signal: 1D float32 array, one value per sample. An all-zero
        array is returned when none of the indices is below the segment's
        channel count, or when averaging fails.
        """
        try:
            # Keep only the indices below the segment's channel count
            usable = [idx for idx in region_channels if idx < segment.shape[1]]
            if usable:
                return np.mean(segment[:, usable], axis=1, dtype=np.float32)

            # Nothing to average for this region
            return np.zeros(segment.shape[0], dtype=np.float32)

        except Exception as e:
            print(f"Error averaging channels for region: {e}")
            return np.zeros(segment.shape[0], dtype=np.float32)
    
    def emd_decomposition(self, signal_data, use_eemd=True):
        """
        Perform Empirical Mode Decomposition (EMD) or Ensemble EMD (EEMD)
        
        Parameters:
        signal_data: 1D array of EEG signal
        use_eemd: Whether to use Ensemble EMD (more robust but slower)
        
        Returns:
        imfs: Array of Intrinsic Mode Functions
        """
        try:
            if use_eemd:
                # Use Ensemble EMD for better decomposition
                eemd = EEMD(trials=self.emd_trials, noise_std=self.noise_std)
                imfs = eemd.eemd(signal_data, max_imf=self.max_imfs)
            else:
                # Use standard EMD
                emd = EMD()
                imfs = emd.emd(signal_data, max_imf=self.max_imfs)
            
            return imfs.astype(np.float32)
            
        except Exception as e:
            print(f"EMD decomposition error: {e}")
            # Return original signal as single IMF if decomposition fails
            return np.array([signal_data], dtype=np.float32)
    
    def compute_instantaneous_frequency(self, imf):
        """
        Compute instantaneous frequency and amplitude using Hilbert Transform
        
        Parameters:
        imf: Intrinsic Mode Function (1D array)
        
        Returns:
        inst_freq: Instantaneous frequency array
        inst_amp: Instantaneous amplitude array
        """
        try:
            # Apply Hilbert transform
            analytic_signal = hilbert(imf)
            
            # Compute instantaneous amplitude
            inst_amp = np.abs(analytic_signal)
            
            # Compute instantaneous phase
            inst_phase = np.angle(analytic_signal)
            
            # Compute instantaneous frequency (derivative of phase)
            inst_freq = np.diff(np.unwrap(inst_phase)) * self.sampling_rate / (2 * np.pi)
            
            # Pad to maintain same length
            inst_freq = np.concatenate([[inst_freq[0]], inst_freq])
            
            # Remove negative frequencies and outliers
            inst_freq = np.abs(inst_freq)
            inst_freq = np.clip(inst_freq, self.freq_min, self.freq_max)
            
            return inst_freq.astype(np.float32), inst_amp.astype(np.float32)
            
        except Exception as e:
            print(f"Instantaneous frequency computation error: {e}")
            # Return zero arrays if computation fails
            return (np.zeros(len(imf), dtype=np.float32), 
                   np.zeros(len(imf), dtype=np.float32))
    
    def create_hilbert_spectrum(self, imfs):
        """
        Create Hilbert Spectrum from IMFs
        
        Parameters:
        imfs: Array of Intrinsic Mode Functions
        
        Returns:
        hilbert_spectrum: 2D array (frequency_bins x time_samples)
        time_axis: Time axis for the spectrum
        """
        try:
            n_samples = imfs.shape[1]
            time_axis = np.linspace(0, self.segment_duration, n_samples, dtype=np.float32)
            
            # Initialize Hilbert spectrum
            hilbert_spectrum = np.zeros((len(self.freq_bins_array), n_samples), dtype=np.float32)
            
            # Process each IMF
            for imf in imfs:
                if len(imf) == 0:
                    continue
                    
                # Compute instantaneous frequency and amplitude
                inst_freq, inst_amp = self.compute_instantaneous_frequency(imf)
                
                # Map instantaneous frequency to frequency bins
                for t in range(len(inst_freq)):
                    freq_val = inst_freq[t]
                    amp_val = inst_amp[t]
                    
                    # Find closest frequency bin
                    freq_idx = np.argmin(np.abs(self.freq_bins_array - freq_val))
                    
                    # Add amplitude to Hilbert spectrum
                    hilbert_spectrum[freq_idx, t] += amp_val
            
            # Convert to power spectrum (square of amplitude)
            power_spectrum = hilbert_spectrum ** 2
            
            # Convert to dB scale with error handling
            power_spectrum_db = 10 * np.log10(np.maximum(power_spectrum, 1e-12))
            
            return power_spectrum_db, time_axis
            
        except Exception as e:
            print(f"Hilbert spectrum creation error: {e}")
            # Return zero spectrum if computation fails
            n_samples = imfs.shape[1] if len(imfs) > 0 else self.segment_samples
            time_axis = np.linspace(0, self.segment_duration, n_samples, dtype=np.float32)
            return (np.zeros((len(self.freq_bins_array), n_samples), dtype=np.float32), 
                   time_axis)
    
    def compute_hht_spectrum_optimized(self, signal_data):
        """
        Optimized HHT computation for a single region signal
        
        Parameters:
        signal_data: 1D EEG signal array (averaged across region channels)
        
        Returns:
        hilbert_spectrum: 2D Hilbert spectrum (frequency x time)
        time_axis: Time axis
        """
        try:
            # Normalize signal to improve EMD stability
            signal_normalized = (signal_data - np.mean(signal_data)) / (np.std(signal_data) + 1e-8)
            
            # Perform EMD decomposition
            imfs = self.emd_decomposition(signal_normalized, use_eemd=True)
            
            # Create Hilbert spectrum
            hilbert_spectrum, time_axis = self.create_hilbert_spectrum(imfs)
            
            return hilbert_spectrum, time_axis
            
        except Exception as e:
            print(f"HHT spectrum computation error: {e}")
            # Return zero spectrum if computation fails
            time_axis = np.linspace(0, self.segment_duration, len(signal_data), dtype=np.float32)
            return (np.zeros((len(self.freq_bins_array), len(signal_data)), dtype=np.float32), 
                   time_axis)
    
    def create_region_hilbert_spectra_safe(self, segment, patient_id, segment_idx, label):
        """
        Create Hilbert spectra for all 5 brain regions
        """
        saved_files = []
        
        try:
            for region_name in self.region_names:
                region_channels = self.brain_regions[region_name]
                
                # Average channels within the region
                region_signal = self.average_channels_by_region(segment, region_channels)
                
                # Skip if region signal is empty or all zeros
                if len(region_signal) == 0 or np.all(region_signal == 0):
                    print(f"    Skipping {region_name} region - no valid signal")
                    continue
                
                # Compute HHT spectrum for this region
                spectrum, time_axis = self.compute_hht_spectrum_optimized(region_signal)
                
                # Create figure with explicit cleanup
                plt.ioff()  # Turn off interactive mode
                fig, ax = plt.subplots(figsize=(12, 8), dpi=150)
                
                try:
                    # Create the Hilbert spectrum plot
                    im = ax.imshow(
                        spectrum, 
                        aspect='auto', 
                        origin='lower',
                        extent=[time_axis[0], time_axis[-1], 
                               self.freq_bins_array[0], self.freq_bins_array[-1]],
                        cmap='jet',  # 'jet' colormap often works well for Hilbert spectra
                        interpolation='bilinear'
                    )
                    
                    # Remove all captions - just the plot
                    ax.set_xticks([])
                    ax.set_yticks([])
                    ax.set_ylim(self.freq_min, self.freq_max)
                    ax.axis('off')
                    
                    # Determine save directory and filename
                    save_dir = self.region_dirs[region_name][label]
                    filename = f'{label}_patient_{patient_id}_segment_{segment_idx:03d}_{region_name}_hilbert_spectrum.png'
                    filepath = os.path.join(save_dir, filename)
                    
                    # Save with error handling
                    plt.savefig(filepath, dpi=150, bbox_inches='tight', facecolor='white', 
                               format='png')
                    
                    saved_files.append(filepath)
                    
                finally:
                    plt.close(fig)  # Always close the figure
                    plt.clf()       # Clear any remaining plots
                    gc.collect()    # Force garbage collection
                    
        except Exception as e:
            print(f"Error creating region Hilbert spectra for {patient_id} segment {segment_idx}: {e}")
        
        return saved_files
    
    def process_patient_file_safe(self, csv_file, patient_id, label):
        """
        Safe version of patient file processing with comprehensive error handling
        Now processes by brain regions instead of individual channels
        """
        print(f"Processing {label} Patient {patient_id}...")
        
        try:
            # Load EEG data with error handling
            try:
                data = pd.read_csv(csv_file, dtype=np.float32)
                eeg_data = data.values.astype(np.float32)
            except Exception as e:
                print(f"Error loading CSV file {csv_file}: {e}")
                return []
            
            # Handle channel count
            if eeg_data.shape[1] != 19:
                print(f"Warning: Expected 19 channels, got {eeg_data.shape[1]} for patient {patient_id}")
                if eeg_data.shape[1] > 19:
                    eeg_data = eeg_data[:, :19]
                else:
                    padding = np.zeros((eeg_data.shape[0], 19 - eeg_data.shape[1]), dtype=np.float32)
                    eeg_data = np.hstack([eeg_data, padding])
            
            print(f"  EEG data shape: {eeg_data.shape}")
            
            # Apply bandpass filter with error handling
            try:
                print(f"  Applying bandpass filter ({self.freq_min}-{self.freq_max} Hz)...")
                filtered_data = self.bandpass_filter_optimized(eeg_data)
            except Exception as e:
                print(f"Error in filtering for patient {patient_id}: {e}")
                return []
            
            # Segment the data
            print(f"  Segmenting data into {self.segment_duration}-second windows...")
            segments = self.segment_signal_vectorized(filtered_data)
            
            if not segments:
                print(f"  No segments generated for patient {patient_id}")
                return []
            
            print(f"  Generated {len(segments)} segments from patient {patient_id}")
            
            # Create Hilbert spectrum images for each segment and each brain region
            all_saved_files = []
            for idx, segment in enumerate(segments):
                print(f"  Processing segment {idx + 1}/{len(segments)} for all brain regions...")
                
                segment_files = self.create_region_hilbert_spectra_safe(segment, patient_id, idx, label)
                all_saved_files.extend(segment_files)
                print(f"    Generated {len(segment_files)} region spectra for segment {idx}")
            
            return all_saved_files
            
        except Exception as e:
            print(f"Error processing patient {patient_id}: {str(e)}")
            traceback.print_exc()
            return []
    
    def get_patient_files(self, directory):
        """
        List the CSV files directly inside a directory.

        Returns the sorted full paths of all entries whose name ends in
        '.csv' (case-insensitive). A missing directory yields a warning and
        an empty list.
        """
        if not os.path.exists(directory):
            print(f"Warning: Directory {directory} does not exist!")
            return []

        return sorted(
            os.path.join(directory, entry)
            for entry in os.listdir(directory)
            if entry.lower().endswith('.csv')
        )
    
    def process_all_patients_threaded(self):
        """
        Process all patients using thread-based parallelism (safer than multiprocessing)
        """
        print("=== Optimized EEG Hilbert Spectrum Generation (HHT - 5 Brain Regions) ===")
        print("Generating Hilbert spectra for 5 brain regions with thread-based parallelism\n")
        
        # Get patient files from both directories
        sz_patient_files = self.get_patient_files(self.sz_input_dir)
        hc_patient_files = self.get_patient_files(self.hc_input_dir)
        
        print(f"Found {len(sz_patient_files)} Schizophrenia patient files")
        print(f"Found {len(hc_patient_files)} Healthy Control patient files")
        print(f"Using {self.n_jobs} parallel threads")
        print(f"Brain regions: {', '.join(self.region_names)}")
        
        if len(sz_patient_files) == 0 and len(hc_patient_files) == 0:
            print("No CSV files found in the specified directories!")
            return
        
        # Prepare all patient processing tasks
        tasks = []
        
        # Add SZ patients
        for i, csv_file in enumerate(sz_patient_files, 1):
            patient_id = f"SZ_{i:02d}"
            tasks.append((csv_file, patient_id, 'SZ'))
        
        # Add HC patients
        for i, csv_file in enumerate(hc_patient_files, 1):
            patient_id = f"HC_{i:02d}"
            tasks.append((csv_file, patient_id, 'HC'))
        
        total_sz_images = 0
        total_hc_images = 0
        region_counts = {region: {'SZ': 0, 'HC': 0} for region in self.region_names}
        
        # Process patients using ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=self.n_jobs) as executor:
            # Submit all tasks
            future_to_task = {
                executor.submit(self.process_patient_file_safe, csv_file, patient_id, label): (patient_id, label)
                for csv_file, patient_id, label in tasks
            }
            
            # Process completed tasks
            for future in as_completed(future_to_task):
                patient_id, label = future_to_task[future]
                try:
                    saved_files = future.result(timeout=600)  # 10 minute timeout per patient
                    
                    # Count images by region
                    for filepath in saved_files:
                        filename = os.path.basename(filepath)
                        for region in self.region_names:
                            if f'_{region}_' in filename:
                                region_counts[region][label] += 1
                                break
                    
                    if label == 'SZ':
                        total_sz_images += len(saved_files)
                    else:
                        total_hc_images += len(saved_files)
                    print(f"  Completed {patient_id}: {len(saved_files)} region Hilbert spectra generated\n")
                    
                except Exception as e:
                    print(f"Error processing {patient_id}: {e}")
                    traceback.print_exc()
        
        print("=== Summary ===")
        print(f"Total Schizophrenia Hilbert spectrum images: {total_sz_images}")
        print(f"Total Healthy Control Hilbert spectrum images: {total_hc_images}")
        print(f"Total Hilbert spectrum images generated: {total_sz_images + total_hc_images}")
        print("\nBreakdown by brain region:")
        for region in self.region_names:
            sz_count = region_counts[region]['SZ']
            hc_count = region_counts[region]['HC']
            total_region = sz_count + hc_count
            print(f"  {region}: {total_region} images (SZ: {sz_count}, HC: {hc_count})")
        
        print(f"\nImages saved in: {self.output_dir}")
        for region in self.region_names:
            print(f"  - {region} SZ: {self.region_dirs[region]['SZ']}")
            print(f"  - {region} HC: {self.region_dirs[region]['HC']}")
        
        print(f"\nHHT Method: Ensemble EMD + Hilbert Transform")
        print(f"Frequency range: {self.freq_min}-{self.freq_max} Hz")
        print(f"Processing completed successfully with {self.n_jobs} threads")

    def process_all_patients_sequential_optimized(self):
        """
        Optimized sequential processing as fallback option
        """
        print("=== Optimized EEG Hilbert Spectrum Generation (HHT - 5 Brain Regions - Sequential) ===")
        print("Using optimized sequential processing for maximum stability\n")
        
        # Get patient files from both directories
        sz_patient_files = self.get_patient_files(self.sz_input_dir)
        hc_patient_files = self.get_patient_files(self.hc_input_dir)
        
        print(f"Found {len(sz_patient_files)} Schizophrenia patient files")
        print(f"Found {len(hc_patient_files)} Healthy Control patient files")
        print(f"Brain regions: {', '.join(self.region_names)}")
        
        if len(sz_patient_files) == 0 and len(hc_patient_files) == 0:
            print("No CSV files found in the specified directories!")
            return
        
        total_sz_images = 0
        total_hc_images = 0
        region_counts = {region: {'SZ': 0, 'HC': 0} for region in self.region_names}
        
        print("\n=== Processing Schizophrenia Patients ===")
        for i, csv_file in enumerate(sz_patient_files, 1):
            patient_id = f"SZ_{i:02d}"
            saved_files = self.process_patient_file_safe(csv_file, patient_id, 'SZ')
            total_sz_images += len(saved_files)
            
            # Count by region
            for filepath in saved_files:
                filename = os.path.basename(filepath)
                for region in self.region_names:
                    if f'_{region}_' in filename:
                        region_counts[region]['SZ'] += 1
                        break
            
            print(f"  Generated {len(saved_files)} region Hilbert spectra for {patient_id}\n")
        
        print("=== Processing Healthy Control Patients ===")
        for i, csv_file in enumerate(hc_patient_files, 1):
            patient_id = f"HC_{i:02d}"
            saved_files = self.process_patient_file_safe(csv_file, patient_id, 'HC')
            total_hc_images += len(saved_files)
            
            # Count by region
            for filepath in saved_files:
                filename = os.path.basename(filepath)
                for region in self.region_names:
                    if f'_{region}_' in filename:
                        region_counts[region]['HC'] += 1
                        break
            
            print(f"  Generated {len(saved_files)} region Hilbert spectra for {patient_id}\n")
        
        print("=== Summary ===")
        print(f"Total Schizophrenia Hilbert spectrum images: {total_sz_images}")
        print(f"Total Healthy Control Hilbert spectrum images: {total_hc_images}")
        print(f"Total Hilbert spectrum images generated: {total_sz_images + total_hc_images}")
        print("\nBreakdown by brain region:")
        for region in self.region_names:
            sz_count = region_counts[region]['SZ']
            hc_count = region_counts[region]['HC']
            total_region = sz_count + hc_count
            print(f"  {region}: {total_region} images (SZ: {sz_count}, HC: {hc_count})")
        
        print(f"Sequential HHT processing for brain regions completed successfully")

# Example usage: run the whole pipeline when this file is executed as a script
if __name__ == "__main__":
    # Input / output locations (machine-specific absolute paths)
    sz_input_directory = "D:/result/dataset2/S"
    hc_input_directory = "D:/result/dataset2/H" 
    output_directory = "D:/result/sobi2_brain_regions"

    # Build the generator; 4 threads are requested here
    generator = OptimizedEEGHilbertSpectrumGenerator(
        sz_input_dir=sz_input_directory,
        hc_input_dir=hc_input_directory,
        output_dir=output_directory,
        sampling_rate=256,
        n_jobs=4,
    )

    # NOTE: the menu below is informational only; no choice is read from the
    # user. The threaded path is always tried first.
    for menu_line in (
        "Choose processing method:",
        "1. Thread-based parallel processing (recommended)",
        "2. Sequential processing (most stable)",
    ):
        print(menu_line)

    try:
        print("\nStarting thread-based parallel processing with HHT for brain regions...")
        generator.process_all_patients_threaded()
    except Exception as e:
        # Any exception escaping the threaded run triggers a full sequential run
        print(f"\nThread-based processing failed: {e}")
        print("Falling back to sequential processing...")
        generator.process_all_patients_sequential_optimized()

    print("\nOptimized Hilbert-Huang Transform processing for brain regions completed!")
    print("Key Features:")
    for feature in (
        "5 Brain Regions: Frontal, Central, Temporal, Parietal, Occipital",
        "Empirical Mode Decomposition (EMD/EEMD)",
        "Hilbert Transform for instantaneous frequency/amplitude",
        "Channel averaging within each brain region",
        "Adaptive time-frequency representation",
        "Thread-based parallelism with comprehensive error handling",
        "Memory management and optimization",
        "Robust signal processing pipeline",
    ):
        print("- " + feature)