In [1]:
!pip install numpy pandas matplotlib scipy EMD-signal scikit-learn


Collecting numpy
  Downloading numpy-2.2.6-cp311-cp311-win_amd64.whl.metadata (60 kB)
Collecting pandas
  Downloading pandas-2.2.3-cp311-cp311-win_amd64.whl.metadata (19 kB)
Collecting matplotlib
  Downloading matplotlib-3.10.3-cp311-cp311-win_amd64.whl.metadata (11 kB)
Collecting scipy
  Downloading scipy-1.15.3-cp311-cp311-win_amd64.whl.metadata (60 kB)
Collecting EMD-signal
  Using cached EMD_signal-1.6.4-py3-none-any.whl.metadata (8.9 kB)
Collecting scikit-learn
  Downloading scikit_learn-1.6.1-cp311-cp311-win_amd64.whl.metadata (15 kB)
Collecting pytz>=2020.1 (from pandas)
  Using cached pytz-2025.2-py2.py3-none-any.whl.metadata (22 kB)
Collecting tzdata>=2022.7 (from pandas)
  Using cached tzdata-2025.2-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting contourpy>=1.0.1 (from matplotlib)
  Downloading contourpy-1.3.2-cp311-cp311-win_amd64.whl.metadata (5.5 kB)
Collecting cycler>=0.10 (from matplotlib)
  Using cached cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fontt

In [2]:
!pip install PyWavelets

Collecting PyWavelets
  Downloading pywavelets-1.8.0-cp311-cp311-win_amd64.whl.metadata (9.0 kB)
Downloading pywavelets-1.8.0-cp311-cp311-win_amd64.whl (4.2 MB)
   ---------------------------------------- 0.0/4.2 MB ? eta -:--:--
   -- ------------------------------------- 0.3/4.2 MB ? eta -:--:--
   ---- ----------------------------------- 0.5/4.2 MB 3.4 MB/s eta 0:00:02
   --------- ------------------------------ 1.0/4.2 MB 2.2 MB/s eta 0:00:02
   -------------- ------------------------- 1.6/4.2 MB 2.4 MB/s eta 0:00:02
   ------------------- -------------------- 2.1/4.2 MB 2.3 MB/s eta 0:00:01
   ------------------------ --------------- 2.6/4.2 MB 2.2 MB/s eta 0:00:01
   ----------------------------- ---------- 3.1/4.2 MB 2.4 MB/s eta 0:00:01
   ---------------------------------- ----- 3.7/4.2 MB 2.3 MB/s eta 0:00:01
   ------------------------------------- -- 3.9/4.2 MB 2.2 MB/s eta 0:00:01
   ---------------------------------------- 4.2/4.2 MB 2.2 MB/s eta 0:00:00
Installing collec

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # Set backend before importing pyplot
import matplotlib.pyplot as plt
from scipy import signal
from scipy.signal import butter, sosfiltfilt
import pywt
from concurrent.futures import ThreadPoolExecutor, as_completed
import multiprocessing as mp
from functools import partial
import warnings
import traceback
import gc
from sklearn.preprocessing import MinMaxScaler
warnings.filterwarnings('ignore')

class OptimizedEEGCWTSpectrogramGenerator:
    """
    Convert multi-channel EEG CSV recordings into channel-averaged CWT spectrogram images.

    Per CSV file the pipeline is:
    load -> force the column count to 19 -> Butterworth band-pass ->
    cut into fixed-length segments -> per-channel CWT power in dB ->
    average over channels -> save one PNG per segment.
    """

    # Upper bound on worker threads (kept low to limit memory use).
    _MAX_WORKERS = 6
    # Number of EEG columns every recording is truncated / zero-padded to.
    _N_CHANNELS = 19

    def __init__(self, sz_input_dir, hc_input_dir, output_dir, sampling_rate=256, n_jobs=None):
        """
        Initialize the Optimized EEG CWT Spectrogram Generator using Continuous Wavelet Transform

        Parameters:
        sz_input_dir (str): Directory containing CSV files with Schizophrenia EEG data
        hc_input_dir (str): Directory containing CSV files with Healthy Control EEG data
        output_dir (str): Root directory to save CWT spectrogram images
        sampling_rate (int): EEG sampling rate in Hz (default: 256 Hz)
        n_jobs (int): Number of parallel threads (default: CPU count); the
            effective value is clamped to the range 1..6

        Side effects:
        Creates '<output_dir>/Schizophrenia_CWT' and '<output_dir>/HealthyControl_CWT'
        and prints a short configuration summary.
        """
        self.sz_input_dir = sz_input_dir
        self.hc_input_dir = hc_input_dir
        self.output_dir = output_dir
        self.sampling_rate = sampling_rate
        self.segment_duration = 3  # 3 seconds
        # Must be an integer: it is used with `//` and `reshape` during segmentation.
        self.segment_samples = int(round(self.sampling_rate * self.segment_duration))

        # Threads, not processes. Clamp to at least 1 because ThreadPoolExecutor
        # rejects max_workers <= 0, and to at most _MAX_WORKERS to bound memory.
        requested_jobs = mp.cpu_count() if n_jobs is None else n_jobs
        self.n_jobs = max(1, min(int(requested_jobs), self._MAX_WORKERS))

        # CWT parameters
        self.freq_min = 0.5   # Minimum frequency (Hz)
        self.freq_max = 32    # Maximum frequency (Hz)
        self.wavelet = 'cmor1.5-1.0'  # Complex Morlet wavelet (bandwidth = 1.5, center frequency = 1.0)
        self.n_frequencies = 100  # Number of frequency scales

        # Alternative wavelets you can try:
        # 'morl' - Morlet wavelet
        # 'cgau8' - Complex Gaussian wavelet (8th order)
        # 'cmor' - Complex Morlet wavelet

        # Log-spaced analysis frequencies in Hz, ascending (row 0 = lowest frequency).
        self.frequencies = np.logspace(np.log10(self.freq_min), np.log10(self.freq_max), self.n_frequencies)
        # pywt.frequency2scale expects frequency normalised to the sampling rate
        # (cycles per sample) and already returns the scale to pass to pywt.cwt.
        # Multiplying the result by the sampling rate again would inflate every
        # scale by that factor and shift the analysis far below freq_min.
        self.scales = pywt.frequency2scale(self.wavelet, self.frequencies / self.sampling_rate)

        # Pre-compute filter coefficients (SOS format for stability and speed)
        self.filter_sos = self._precompute_filter()

        # Create output directories
        self.sz_dir = os.path.join(output_dir, 'Schizophrenia_CWT')
        self.hc_dir = os.path.join(output_dir, 'HealthyControl_CWT')
        os.makedirs(self.sz_dir, exist_ok=True)
        os.makedirs(self.hc_dir, exist_ok=True)

        print(f"Initialized CWT Spectrogram Generator with {self.n_jobs} parallel threads")
        print(f"Frequency range: {self.freq_min}-{self.freq_max} Hz with {self.n_frequencies} frequency bins")
        print(f"Using wavelet: {self.wavelet}")

    def _precompute_filter(self):
        """
        Design the 4th-order Butterworth band-pass (freq_min..freq_max) once for reuse.

        Returns:
        Second-order-sections array as produced by scipy.signal.butter(output='sos').
        """
        nyquist = 0.5 * self.sampling_rate
        # butter() takes band edges as fractions of the Nyquist frequency.
        band = [self.freq_min / nyquist, self.freq_max / nyquist]
        return signal.butter(4, band, btype='band', output='sos')

    def bandpass_filter_optimized(self, data):
        """
        Apply the pre-computed band-pass filter along axis 0 (time).

        Parameters:
        data: array whose first axis is time (samples x channels in this pipeline)

        Returns:
        Zero-phase filtered copy of `data` as float32.

        Raises:
        ValueError (from scipy) when the input is too short for sosfiltfilt's padding.
        """
        # sosfiltfilt runs the filter forward and backward, so no phase shift is added.
        return sosfiltfilt(self.filter_sos, data, axis=0).astype(np.float32)

    def segment_signal_vectorized(self, data):
        """
        Split a (samples x channels) array into consecutive, non-overlapping segments.

        Parameters:
        data: 2D array, samples x channels

        Returns:
        List of float32 arrays of shape (segment_samples, n_channels). Trailing
        samples that do not fill a whole segment are dropped; an input shorter
        than one segment yields an empty list.
        """
        n_samples, n_channels = data.shape
        n_segments = n_samples // self.segment_samples

        if n_segments == 0:
            return []

        # A single reshape yields every segment at once.
        usable = n_segments * self.segment_samples
        stacked = data[:usable, :].reshape(n_segments, self.segment_samples, n_channels)
        return [segment.astype(np.float32) for segment in stacked]

    def compute_cwt_spectrogram(self, signal_data):
        """
        Compute CWT spectrogram for a single channel using Continuous Wavelet Transform

        Parameters:
        signal_data: 1D array of EEG signal

        Returns:
        power_db: 2D float32 array (len(self.scales) x samples) of CWT power in dB
        time_axis: np.linspace(0, segment_duration, n_samples)

        Best effort: if the transform raises, the error is printed and an
        all-zero matrix of the same shape is returned instead.
        """
        try:
            # Z-score the channel so that power is comparable across channels;
            # the epsilon avoids division by zero on a flat (e.g. zero-padded) channel.
            signal_normalized = (signal_data - np.mean(signal_data)) / (np.std(signal_data) + 1e-8)

            # Compute Continuous Wavelet Transform
            cwt_matrix, _ = pywt.cwt(signal_normalized, self.scales, self.wavelet,
                                     sampling_period=1 / self.sampling_rate)

            # Power = |coefficient|^2, floored before the log to avoid log10(0).
            power_matrix = np.abs(cwt_matrix) ** 2
            power_db = 10 * np.log10(np.maximum(power_matrix, 1e-12))

            time_axis = np.linspace(0, self.segment_duration, signal_data.shape[0])
            return power_db.astype(np.float32), time_axis

        except Exception as e:
            print(f"CWT computation error: {e}")
            # Return zero spectrogram if computation fails
            time_axis = np.linspace(0, self.segment_duration, len(signal_data))
            return (np.zeros((len(self.scales), len(signal_data)), dtype=np.float32),
                    time_axis)

    def compute_stft_spectrogram(self, signal_data):
        """
        Alternative method: Compute STFT spectrogram as backup

        Parameters:
        signal_data: 1D array of EEG signal

        Returns:
        power_db: 2D float32 array (frequency bins within freq_min..freq_max x time) in dB
        times: segment-centre times as returned by scipy.signal.spectrogram

        Best effort: if the computation raises, the error is printed and an
        all-zero (n_frequencies x samples) matrix with a linspace time axis is
        returned; note that this fallback shape differs from the normal result.
        """
        try:
            # Window of at most 256 samples, with 50% overlap.
            nperseg = min(256, len(signal_data) // 4)
            noverlap = nperseg // 2

            frequencies, times, stft_matrix = signal.spectrogram(
                signal_data,
                fs=self.sampling_rate,
                nperseg=nperseg,
                noverlap=noverlap,
                scaling='density'
            )

            # Keep only the rows inside the analysis band.
            freq_mask = (frequencies >= self.freq_min) & (frequencies <= self.freq_max)
            filtered_stft = stft_matrix[freq_mask, :]

            # Convert to dB scale
            power_db = 10 * np.log10(np.maximum(filtered_stft, 1e-12))

            return power_db.astype(np.float32), times

        except Exception as e:
            print(f"STFT computation error: {e}")
            # Return zero spectrogram if computation fails
            time_axis = np.linspace(0, self.segment_duration, len(signal_data))
            return (np.zeros((self.n_frequencies, len(signal_data)), dtype=np.float32),
                    time_axis)

    def create_combined_cwt_spectrogram_safe(self, segment, patient_id, segment_idx, label):
        """
        Render the channel-averaged CWT spectrogram of one segment to a PNG file.

        Parameters:
        segment: 2D array, samples x channels
        patient_id: identifier embedded in the file name
        segment_idx (int): segment index embedded in the file name (zero-padded to 3 digits)
        label (str): 'SZ' selects the schizophrenia output folder, anything else the control folder

        Returns:
        Path of the written PNG, or None if anything failed (the error is printed).
        """
        try:
            # Imported here because Figure is not among the module-level imports.
            from matplotlib.figure import Figure

            per_channel = [self.compute_cwt_spectrogram(segment[:, ch])
                           for ch in range(segment.shape[1])]

            # Average the dB maps across all channels.
            avg_spectrogram = np.mean([spec for spec, _ in per_channel], axis=0, dtype=np.float32)
            avg_time_axis = per_channel[0][1]  # All time axes are the same

            # Build the figure directly instead of through pyplot. pyplot keeps
            # one global "current figure", so with several worker threads
            # plt.savefig() could write another thread's figure.
            fig = Figure(figsize=(12, 8), dpi=150)
            try:
                ax = fig.subplots()
                ax.imshow(
                    avg_spectrogram,
                    aspect='auto',
                    origin='lower',  # row 0 (lowest frequency) at the bottom
                    extent=[avg_time_axis[0], avg_time_axis[-1],
                            self.frequencies[0], self.frequencies[-1]],
                    cmap='jet',
                    interpolation='bilinear'
                )

                # Image only: no ticks, labels or frame.
                ax.set_xticks([])
                ax.set_yticks([])
                ax.set_ylim(self.freq_min, self.freq_max)
                ax.axis('off')

                # Determine save directory and filename
                save_dir = self.sz_dir if label == 'SZ' else self.hc_dir
                filename = f'{label}_patient_{patient_id}_segment_{segment_idx:03d}_cwt_spectrogram.png'
                filepath = os.path.join(save_dir, filename)

                fig.savefig(filepath, dpi=150, bbox_inches='tight', facecolor='white',
                            format='png')

                return filepath

            finally:
                fig.clear()   # Drop artists so the image data can be released
                gc.collect()  # Force garbage collection

        except Exception as e:
            print(f"Error creating CWT spectrogram for {patient_id} segment {segment_idx}: {e}")
            return None

    def process_patient_file_safe(self, csv_file, patient_id, label):
        """
        Run the full pipeline for one CSV recording.

        Parameters:
        csv_file (str): path of the CSV file; every column is read as float32
        patient_id (str): identifier used in log lines and file names
        label (str): 'SZ' or 'HC'

        Returns:
        List of paths of the images written. Empty if loading, filtering or
        segmentation failed; errors are printed, never raised.
        """
        print(f"Processing {label} Patient {patient_id}...")

        try:
            # Load EEG data with error handling
            try:
                data = pd.read_csv(csv_file, dtype=np.float32)
                eeg_data = data.values.astype(np.float32)
            except Exception as e:
                print(f"Error loading CSV file {csv_file}: {e}")
                return []

            # Force exactly _N_CHANNELS columns: drop extras, zero-pad missing ones.
            n_columns = eeg_data.shape[1]
            if n_columns != self._N_CHANNELS:
                print(f"Warning: Expected {self._N_CHANNELS} channels, got {n_columns} for patient {patient_id}")
                if n_columns > self._N_CHANNELS:
                    eeg_data = eeg_data[:, :self._N_CHANNELS]
                else:
                    padding = np.zeros((eeg_data.shape[0], self._N_CHANNELS - n_columns), dtype=np.float32)
                    eeg_data = np.hstack([eeg_data, padding])

            print(f"  EEG data shape: {eeg_data.shape}")

            # Apply bandpass filter with error handling
            try:
                print(f"  Applying bandpass filter ({self.freq_min}-{self.freq_max} Hz)...")
                filtered_data = self.bandpass_filter_optimized(eeg_data)
            except Exception as e:
                print(f"Error in filtering for patient {patient_id}: {e}")
                return []

            # Segment the data
            print(f"  Segmenting data into {self.segment_duration}-second windows...")
            segments = self.segment_signal_vectorized(filtered_data)

            if not segments:
                print(f"  No segments generated for patient {patient_id}")
                return []

            print(f"  Generated {len(segments)} segments from patient {patient_id}")

            # Create CWT spectrogram images for each segment
            saved_files = []
            for idx, segment in enumerate(segments):
                print(f"  Processing segment {idx + 1}/{len(segments)} with CWT...")

                filepath = self.create_combined_cwt_spectrogram_safe(segment, patient_id, idx, label)
                if filepath:
                    saved_files.append(filepath)
                    print(f"  Saved: {os.path.basename(filepath)}")
                else:
                    print(f"  Failed to save segment {idx}")

            return saved_files

        except Exception as e:
            print(f"Error processing patient {patient_id}: {str(e)}")
            traceback.print_exc()
            return []

    def get_patient_files(self, directory):
        """
        Get all CSV files from a directory

        Parameters:
        directory (str): folder to scan (not recursive)

        Returns:
        Sorted list of paths whose name ends in '.csv' (case-insensitive).
        Empty, with a printed warning, if `directory` is not an existing directory.
        """
        # isdir rather than exists: a path pointing at a regular file would
        # otherwise make os.listdir raise.
        if not os.path.isdir(directory):
            print(f"Warning: Directory {directory} does not exist!")
            return []

        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.lower().endswith('.csv')
        )

    @staticmethod
    def _build_tasks(sz_patient_files, hc_patient_files):
        """Pair each file with a 1-based patient id ('SZ_01', 'HC_01', ...) and its label."""
        tasks = [(path, f"SZ_{i:02d}", 'SZ') for i, path in enumerate(sz_patient_files, 1)]
        tasks += [(path, f"HC_{i:02d}", 'HC') for i, path in enumerate(hc_patient_files, 1)]
        return tasks

    def process_all_patients_threaded(self):
        """
        Process all patients using thread-based parallelism (safer than multiprocessing)

        One task per CSV file is submitted to a ThreadPoolExecutor with
        `self.n_jobs` workers. Progress and a final summary are printed.
        Returns None.
        """
        print("=== Optimized EEG CWT Spectrogram Generation (Thread-based Parallel Processing) ===")
        print("Generating CWT spectrograms with thread-based parallelism for stability\n")

        # Get patient files from both directories
        sz_patient_files = self.get_patient_files(self.sz_input_dir)
        hc_patient_files = self.get_patient_files(self.hc_input_dir)

        print(f"Found {len(sz_patient_files)} Schizophrenia patient files")
        print(f"Found {len(hc_patient_files)} Healthy Control patient files")
        print(f"Using {self.n_jobs} parallel threads")

        if len(sz_patient_files) == 0 and len(hc_patient_files) == 0:
            print("No CSV files found in the specified directories!")
            return

        tasks = self._build_tasks(sz_patient_files, hc_patient_files)
        image_totals = {'SZ': 0, 'HC': 0}

        # Process patients using ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=self.n_jobs) as executor:
            future_to_task = {
                executor.submit(self.process_patient_file_safe, csv_file, patient_id, label): (patient_id, label)
                for csv_file, patient_id, label in tasks
            }

            for future in as_completed(future_to_task):
                patient_id, label = future_to_task[future]
                try:
                    # as_completed only yields finished futures, so result()
                    # returns immediately; a timeout here would never apply.
                    saved_files = future.result()
                    image_totals['SZ' if label == 'SZ' else 'HC'] += len(saved_files)
                    print(f"  Completed {patient_id}: {len(saved_files)} CWT spectrograms generated\n")

                except Exception as e:
                    print(f"Error processing {patient_id}: {e}")
                    traceback.print_exc()

        total_sz_images = image_totals['SZ']
        total_hc_images = image_totals['HC']

        print("=== Summary ===")
        print(f"Total Schizophrenia CWT spectrogram images: {total_sz_images}")
        print(f"Total Healthy Control CWT spectrogram images: {total_hc_images}")
        print(f"Total CWT spectrogram images generated: {total_sz_images + total_hc_images}")
        print(f"Images saved in: {self.output_dir}")
        print(f"  - Schizophrenia CWT spectrograms: {self.sz_dir}")
        print(f"  - Healthy Control CWT spectrograms: {self.hc_dir}")
        print(f"CWT Method: Continuous Wavelet Transform using {self.wavelet}")
        print(f"Frequency range: {self.freq_min}-{self.freq_max} Hz")
        print(f"Processing completed successfully with {self.n_jobs} threads")

    def process_all_patients_sequential_optimized(self):
        """
        Optimized sequential processing as fallback option

        Processes every schizophrenia file, then every healthy-control file, one
        at a time in the calling thread. Progress and a final summary are
        printed. Returns None.
        """
        print("=== Optimized EEG CWT Spectrogram Generation (Sequential Processing) ===")
        print("Using optimized sequential processing for maximum stability\n")

        # Get patient files from both directories
        sz_patient_files = self.get_patient_files(self.sz_input_dir)
        hc_patient_files = self.get_patient_files(self.hc_input_dir)

        print(f"Found {len(sz_patient_files)} Schizophrenia patient files")
        print(f"Found {len(hc_patient_files)} Healthy Control patient files")

        if len(sz_patient_files) == 0 and len(hc_patient_files) == 0:
            print("No CSV files found in the specified directories!")
            return

        tasks = self._build_tasks(sz_patient_files, hc_patient_files)
        image_totals = {'SZ': 0, 'HC': 0}
        sections = (
            ("\n=== Processing Schizophrenia Patients ===", 'SZ'),
            ("=== Processing Healthy Control Patients ===", 'HC'),
        )

        for header, group in sections:
            print(header)
            for csv_file, patient_id, label in tasks:
                if label != group:
                    continue
                saved_files = self.process_patient_file_safe(csv_file, patient_id, label)
                image_totals[group] += len(saved_files)
                print(f"  Generated {len(saved_files)} CWT spectrograms for {patient_id}\n")

        total_sz_images = image_totals['SZ']
        total_hc_images = image_totals['HC']

        print("=== Summary ===")
        print(f"Total Schizophrenia CWT spectrogram images: {total_sz_images}")
        print(f"Total Healthy Control CWT spectrogram images: {total_hc_images}")
        print(f"Total CWT spectrogram images generated: {total_sz_images + total_hc_images}")
        print(f"Images saved in: {self.output_dir}")
        print(f"Sequential CWT processing completed successfully")

# Example usage
if __name__ == "__main__":
    # Input/output locations. The defaults are the original local paths; each
    # can be overridden with an environment variable so the cell also runs on
    # machines with a different directory layout.
    sz_input_directory = os.environ.get("EEG_SZ_INPUT_DIR", "D:/result/dataset/S")
    hc_input_directory = os.environ.get("EEG_HC_INPUT_DIR", "D:/result/dataset/H")
    output_directory = os.environ.get("EEG_CWT_OUTPUT_DIR", "D:/result/sobi")

    # Create optimized CWT spectrogram generator
    generator = OptimizedEEGCWTSpectrogramGenerator(
        sz_input_dir=sz_input_directory,
        hc_input_dir=hc_input_directory,
        output_dir=output_directory,
        sampling_rate=256,
        n_jobs=4  # Use 4 threads for safer parallel processing
    )

    # No interactive choice is offered: the threaded path is always tried
    # first, and sequential processing is used only if the threaded run raises.
    try:
        print("\nStarting thread-based parallel processing with CWT...")
        generator.process_all_patients_threaded()
    except Exception as e:
        print(f"\nThread-based processing failed: {e}")
        print("Falling back to sequential processing...")
        generator.process_all_patients_sequential_optimized()

    print("\nOptimized Continuous Wavelet Transform processing completed!")
    print("Key CWT Features:")
    print("- Continuous Wavelet Transform for time-frequency analysis")
    print("- Complex Morlet wavelet for optimal time-frequency resolution")
    print("- Logarithmic frequency scaling for better low-frequency resolution")
    print("- Thread-based parallelism with comprehensive error handling")
    print("- Memory management and optimization")
    print("- Robust signal processing pipeline")
    print("- High-quality spectrogram visualization")

Initialized CWT Spectrogram Generator with 4 parallel threads
Frequency range: 0.5-32 Hz with 100 frequency bins
Using wavelet: cmor1.5-1.0
Choose processing method:
1. Thread-based parallel processing (recommended)
2. Sequential processing (most stable)

Starting thread-based parallel processing with CWT...
=== Optimized EEG CWT Spectrogram Generation (Thread-based Parallel Processing) ===
Generating CWT spectrograms with thread-based parallelism for stability

Found 0 Schizophrenia patient files
Found 1 Healthy Control patient files
Using 4 parallel threads
Processing HC Patient HC_01...
  EEG data shape: (227500, 19)
  Applying bandpass filter (0.5-32 Hz)...
  Segmenting data into 3-second windows...
  Generated 296 segments from patient HC_01
  Processing segment 1/296 with CWT...
  Saved: HC_patient_HC_01_segment_000_cwt_spectrogram.png
  Processing segment 2/296 with CWT...
  Saved: HC_patient_HC_01_segment_001_cwt_spectrogram.png
  Processing segment 3/296 with CWT...
  Saved: H