In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # Set backend before importing pyplot
import matplotlib.pyplot as plt
from scipy import signal
from scipy.signal import butter, sosfiltfilt
import pywt
from concurrent.futures import ThreadPoolExecutor, as_completed
import multiprocessing as mp
from functools import partial
import warnings
import traceback
import gc
warnings.filterwarnings('ignore')

class OptimizedEEGScalogramGenerator:
    """
    Turn multichannel EEG CSV recordings into CWT scalogram PNG images.

    Pipeline per CSV file: load -> force 19 channels -> Butterworth band-pass
    (freq_min..freq_max Hz) -> cut into non-overlapping 3 s windows -> complex
    Morlet CWT per channel (power in dB) -> mean over channels -> one PNG per
    window, written to ``<output_dir>/Schizophrenia`` (label 'SZ') or
    ``<output_dir>/HealthyControl`` (any other label).
    """

    def __init__(self, sz_input_dir, hc_input_dir, output_dir, sampling_rate=250, n_jobs=None):
        """
        Initialize the Optimized EEG Scalogram Generator using Continuous Wavelet Transform (CWT)

        Parameters:
        sz_input_dir (str): Directory containing CSV files with Schizophrenia EEG data
        hc_input_dir (str): Directory containing CSV files with Healthy Control EEG data
        output_dir (str): Root directory to save scalogram images
        sampling_rate (int): EEG sampling rate in Hz (default: 250 Hz)
        n_jobs (int): Number of parallel threads (default: CPU count). The value
            is capped at 8 and raised to at least 1.

        Side effects:
        Creates the ``Schizophrenia`` and ``HealthyControl`` sub-directories of
        ``output_dir`` if they do not exist.
        """
        self.sz_input_dir = sz_input_dir
        self.hc_input_dir = hc_input_dir
        self.output_dir = output_dir
        self.sampling_rate = sampling_rate
        self.segment_duration = 3  # 3 seconds
        # int() so a float sampling rate still yields a usable reshape size.
        self.segment_samples = int(self.sampling_rate * self.segment_duration)

        # Thread count: default to the CPU count, cap at 8 to bound memory use,
        # and never go below 1 (ThreadPoolExecutor rejects max_workers <= 0).
        requested_jobs = mp.cpu_count() if n_jobs is None else n_jobs
        self.n_jobs = max(1, min(requested_jobs, 8))

        # Wavelet parameters
        self.wavelet = 'cmor1.5-1.0'  # Complex Morlet wavelet
        self.freq_min = 0.5  # Minimum frequency (Hz)
        self.freq_max = 45   # Maximum frequency (Hz)
        self.num_scales = 100  # Number of scales/frequencies

        # Log-spaced analysis frequencies in Hz (float32 to save memory).
        self.frequencies = np.logspace(np.log10(self.freq_min), np.log10(self.freq_max),
                                       self.num_scales, dtype=np.float32)
        # pywt.frequency2scale takes frequencies normalized to the sampling rate
        # (cycles per sample) and returns scales in samples, which is exactly what
        # pywt.cwt expects. Do NOT multiply by sampling_rate again: that would
        # inflate every scale by a factor of fs and analyse frequencies ~fs times
        # lower than the ones used for the axis labels.
        self.scales = pywt.frequency2scale(
            self.wavelet, self.frequencies / self.sampling_rate
        ).astype(np.float32)

        # Pre-compute filter coefficients (SOS format for stability and speed)
        self.filter_sos = self._precompute_filter()

        # Create output directories
        self.sz_dir = os.path.join(output_dir, 'Schizophrenia')
        self.hc_dir = os.path.join(output_dir, 'HealthyControl')
        os.makedirs(self.sz_dir, exist_ok=True)
        os.makedirs(self.hc_dir, exist_ok=True)

        print(f"Initialized with {self.n_jobs} parallel threads")

    def _precompute_filter(self):
        """
        Design the band-pass filter once so it can be reused for every file.

        Returns:
        np.ndarray: Second-order sections of a 4th-order Butterworth band-pass
        from ``freq_min`` to ``freq_max`` Hz (cutoffs normalized to Nyquist).
        """
        nyquist = 0.5 * self.sampling_rate
        low = self.freq_min / nyquist
        high = self.freq_max / nyquist
        sos = signal.butter(4, [low, high], btype='band', output='sos')
        return sos

    def bandpass_filter_optimized(self, data):
        """
        Zero-phase band-pass filter the data along axis 0 (time).

        Parameters:
        data (np.ndarray): Samples along axis 0, channels along axis 1.

        Returns:
        np.ndarray: Filtered data of the same shape, as float32.
        """
        # sosfiltfilt runs the SOS filter forward and backward (zero phase shift).
        filtered_data = sosfiltfilt(self.filter_sos, data, axis=0)
        return filtered_data.astype(np.float32)

    def segment_signal_vectorized(self, data):
        """
        Split a (n_samples, n_channels) signal into non-overlapping windows.

        Each window holds ``segment_samples`` samples. Trailing samples that do
        not fill a whole window are dropped.

        Returns:
        list[np.ndarray]: float32 arrays of shape (segment_samples, n_channels).
        The list is empty when the signal is shorter than one window.
        """
        n_samples, n_channels = data.shape
        n_segments = n_samples // self.segment_samples

        if n_segments == 0:
            return []

        # One reshape instead of a slicing loop.
        valid_samples = n_segments * self.segment_samples
        reshaped_data = data[:valid_samples, :].reshape(n_segments, self.segment_samples, n_channels)

        return [segment.astype(np.float32) for segment in reshaped_data]

    def compute_cwt_optimized(self, signal_data):
        """
        Compute the CWT power (dB) of a single channel.

        Parameters:
        signal_data (np.ndarray): 1-D signal sampled at ``sampling_rate``.

        Returns:
        np.ndarray: float32 array of shape (num_scales, len(signal_data)) holding
        10*log10(|coef|^2), with power floored at 1e-12. If pywt.cwt raises, the
        error is printed and an all-zero array of that shape is returned.
        """
        signal_data = signal_data.astype(np.float32)

        try:
            coefficients, _ = pywt.cwt(signal_data, self.scales, self.wavelet,
                                       sampling_period=1 / self.sampling_rate)

            power = np.abs(coefficients) ** 2
            # Floor the power to avoid log10(0).
            power_db = 10 * np.log10(np.maximum(power, 1e-12))

            return power_db.astype(np.float32)
        except Exception as e:
            print(f"CWT computation error: {e}")
            return np.zeros((len(self.scales), len(signal_data)), dtype=np.float32)

    def create_combined_scalogram_safe(self, segment, patient_id, segment_idx, label):
        """
        Render the channel-averaged CWT scalogram of one segment to a PNG file.

        A standalone matplotlib ``Figure`` is used instead of pyplot. Pyplot
        keeps a global "current figure", so ``plt.savefig`` / ``plt.clf`` called
        from several worker threads could save or clear another thread's figure.

        Parameters:
        segment (np.ndarray): (n_samples, n_channels) window of filtered EEG
        patient_id (str): Identifier used in the title and filename
        segment_idx (int): Segment index used in the title and filename
        label (str): 'SZ' saves into the Schizophrenia folder; any other value
            saves into the HealthyControl folder

        Returns:
        str or None: Path of the saved PNG, or None if any step failed (the
        error is printed).
        """
        try:
            from matplotlib.figure import Figure

            # Per-channel CWT power (dB), then the mean across channels.
            scalograms = [self.compute_cwt_optimized(segment[:, ch])
                          for ch in range(segment.shape[1])]
            avg_scalogram = np.mean(scalograms, axis=0, dtype=np.float32)

            # Exact sample times in seconds: sample k is at k / sampling_rate.
            time_axis = (np.arange(segment.shape[0], dtype=np.float32)
                         / np.float32(self.sampling_rate))

            fig = Figure(figsize=(10, 6), dpi=150)
            try:
                ax = fig.subplots()
                # pcolormesh puts row i at self.frequencies[i] (log-spaced).
                # An imshow `extent` would assume linearly spaced rows and
                # mislabel them on a log y-axis. Gouraud shading keeps the smooth
                # look of bilinear interpolation.
                ax.pcolormesh(time_axis, self.frequencies, avg_scalogram,
                              cmap='viridis', shading='gouraud')

                ax.set_title(f'Average EEG Scalogram (CWT) - {label} Patient {patient_id} - Segment {segment_idx}',
                             fontsize=14, fontweight='bold')
                ax.set_xlabel('Time (s)', fontsize=12)
                ax.set_ylabel('Frequency (Hz)', fontsize=12)
                ax.set_yscale('log')
                ax.set_ylim(self.freq_min, self.freq_max)

                save_dir = self.sz_dir if label == 'SZ' else self.hc_dir
                filename = f'{label}_patient_{patient_id}_segment_{segment_idx:03d}_scalogram_combined.png'
                filepath = os.path.join(save_dir, filename)

                # Save this specific figure (not pyplot's global current figure).
                fig.savefig(filepath, dpi=150, bbox_inches='tight', facecolor='white',
                            format='png')

                return filepath

            finally:
                fig.clear()   # Drop the artists/arrays held by the figure
                gc.collect()  # Force garbage collection

        except Exception as e:
            print(f"Error creating scalogram for {patient_id} segment {segment_idx}: {e}")
            return None

    def process_patient_file_safe(self, csv_file, patient_id, label):
        """
        Load, filter, segment and render every scalogram for one CSV file.

        Parameters:
        csv_file (str): Path to the patient's CSV. It is read with pandas'
            default header handling and every column is parsed as float32.
        patient_id (str): Identifier used in log messages and filenames
        label (str): 'SZ' or 'HC'; selects the output folder

        Returns:
        list[str]: Paths of the PNGs that were saved. The list is empty when
        loading, filtering or segmentation fails; errors are printed, not raised.
        """
        print(f"Processing {label} Patient {patient_id}...")

        try:
            try:
                data = pd.read_csv(csv_file, dtype=np.float32)
                eeg_data = data.values.astype(np.float32)
            except Exception as e:
                print(f"Error loading CSV file {csv_file}: {e}")
                return []

            # Force exactly 19 channels: drop extras, zero-pad missing ones.
            if eeg_data.shape[1] != 19:
                print(f"Warning: Expected 19 channels, got {eeg_data.shape[1]} for patient {patient_id}")
                if eeg_data.shape[1] > 19:
                    eeg_data = eeg_data[:, :19]
                else:
                    # NOTE(review): an all-zero channel stays zero after
                    # filtering and hits the 1e-12 power floor (-120 dB) in
                    # compute_cwt_optimized, dragging the channel average down.
                    # Consider averaging only the real channels.
                    padding = np.zeros((eeg_data.shape[0], 19 - eeg_data.shape[1]), dtype=np.float32)
                    eeg_data = np.hstack([eeg_data, padding])

            print(f"  EEG data shape: {eeg_data.shape}")

            try:
                print(f"  Applying bandpass filter ({self.freq_min}-{self.freq_max} Hz)...")
                filtered_data = self.bandpass_filter_optimized(eeg_data)
            except Exception as e:
                print(f"Error in filtering for patient {patient_id}: {e}")
                return []

            print(f"  Segmenting data into {self.segment_duration}-second windows...")
            segments = self.segment_signal_vectorized(filtered_data)

            if not segments:
                print(f"  No segments generated for patient {patient_id}")
                return []

            print(f"  Generated {len(segments)} segments from patient {patient_id}")

            saved_files = []
            for idx, segment in enumerate(segments):
                print(f"  Processing segment {idx + 1}/{len(segments)}...")

                filepath = self.create_combined_scalogram_safe(segment, patient_id, idx, label)
                if filepath:
                    saved_files.append(filepath)
                    print(f"  Saved: {os.path.basename(filepath)}")
                else:
                    print(f"  Failed to save segment {idx}")

            return saved_files

        except Exception as e:
            print(f"Error processing patient {patient_id}: {str(e)}")
            traceback.print_exc()
            return []

    def get_patient_files(self, directory):
        """
        List the CSV files directly inside ``directory``.

        The extension match is case-insensitive. Paths are returned joined with
        ``directory`` and sorted, so patient numbering is reproducible. A missing
        path (or a path that is not a directory) prints a warning and returns [].
        """
        if not os.path.isdir(directory):
            print(f"Warning: Directory {directory} does not exist!")
            return []

        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.lower().endswith('.csv')
        )

    def process_all_patients_threaded(self):
        """
        Process every SZ and HC file concurrently, one patient per worker thread.

        Patients are numbered SZ_01.. / HC_01.. in sorted file order. Per-patient
        failures are printed and do not stop the other patients. A summary of
        image counts is printed at the end.
        """
        print("=== Optimized EEG Scalogram Generation (Thread-based Parallel Processing) ===")
        print("Generating scalograms with thread-based parallelism for stability\n")

        sz_patient_files = self.get_patient_files(self.sz_input_dir)
        hc_patient_files = self.get_patient_files(self.hc_input_dir)

        print(f"Found {len(sz_patient_files)} Schizophrenia patient files")
        print(f"Found {len(hc_patient_files)} Healthy Control patient files")
        print(f"Using {self.n_jobs} parallel threads")

        if len(sz_patient_files) == 0 and len(hc_patient_files) == 0:
            print("No CSV files found in the specified directories!")
            return

        # (csv_file, patient_id, label) triples for every patient.
        tasks = []
        for i, csv_file in enumerate(sz_patient_files, 1):
            tasks.append((csv_file, f"SZ_{i:02d}", 'SZ'))
        for i, csv_file in enumerate(hc_patient_files, 1):
            tasks.append((csv_file, f"HC_{i:02d}", 'HC'))

        total_sz_images = 0
        total_hc_images = 0

        with ThreadPoolExecutor(max_workers=self.n_jobs) as executor:
            future_to_task = {
                executor.submit(self.process_patient_file_safe, csv_file, patient_id, label): (patient_id, label)
                for csv_file, patient_id, label in tasks
            }

            for future in as_completed(future_to_task):
                patient_id, label = future_to_task[future]
                try:
                    # as_completed only yields finished futures, so no timeout
                    # is needed here (one would never fire).
                    saved_files = future.result()
                    if label == 'SZ':
                        total_sz_images += len(saved_files)
                    else:
                        total_hc_images += len(saved_files)
                    print(f"  Completed {patient_id}: {len(saved_files)} scalograms generated\n")

                except Exception as e:
                    print(f"Error processing {patient_id}: {e}")
                    traceback.print_exc()

        print("=== Summary ===")
        print(f"Total Schizophrenia scalogram images: {total_sz_images}")
        print(f"Total Healthy Control scalogram images: {total_hc_images}")
        print(f"Total scalogram images generated: {total_sz_images + total_hc_images}")
        print(f"Images saved in: {self.output_dir}")
        print(f"  - Schizophrenia scalograms: {self.sz_dir}")
        print(f"  - Healthy Control scalograms: {self.hc_dir}")
        print(f"Wavelet used: {self.wavelet}")
        print(f"Frequency range: {self.freq_min}-{self.freq_max} Hz")
        print(f"Processing completed successfully with {self.n_jobs} threads")

    def process_all_patients_sequential_optimized(self):
        """
        Process every SZ file, then every HC file, one at a time on this thread.

        Patient numbering matches process_all_patients_threaded. A summary of
        image counts is printed at the end.
        """
        print("=== Optimized EEG Scalogram Generation (Sequential Processing) ===")
        print("Using optimized sequential processing for maximum stability\n")

        sz_patient_files = self.get_patient_files(self.sz_input_dir)
        hc_patient_files = self.get_patient_files(self.hc_input_dir)

        print(f"Found {len(sz_patient_files)} Schizophrenia patient files")
        print(f"Found {len(hc_patient_files)} Healthy Control patient files")

        if len(sz_patient_files) == 0 and len(hc_patient_files) == 0:
            print("No CSV files found in the specified directories!")
            return

        total_sz_images = 0
        total_hc_images = 0

        print("\n=== Processing Schizophrenia Patients ===")
        for i, csv_file in enumerate(sz_patient_files, 1):
            patient_id = f"SZ_{i:02d}"
            saved_files = self.process_patient_file_safe(csv_file, patient_id, 'SZ')
            total_sz_images += len(saved_files)
            print(f"  Generated {len(saved_files)} scalograms for {patient_id}\n")

        print("=== Processing Healthy Control Patients ===")
        for i, csv_file in enumerate(hc_patient_files, 1):
            patient_id = f"HC_{i:02d}"
            saved_files = self.process_patient_file_safe(csv_file, patient_id, 'HC')
            total_hc_images += len(saved_files)
            print(f"  Generated {len(saved_files)} scalograms for {patient_id}\n")

        print("=== Summary ===")
        print(f"Total Schizophrenia scalogram images: {total_sz_images}")
        print(f"Total Healthy Control scalogram images: {total_hc_images}")
        print(f"Total scalogram images generated: {total_sz_images + total_hc_images}")
        print(f"Images saved in: {self.output_dir}")
        print(f"Sequential processing completed successfully")

# Example usage
if __name__ == "__main__":
    # Dataset locations, keyed by the generator's constructor parameter names.
    io_paths = {
        "sz_input_dir": "D:/Milon2/SD14/schizophrenia",
        "hc_input_dir": "D:/Milon2/SD14/healthy",
        "output_dir": "D:/Milon2/scalogram/images",
    }

    # Recordings are sampled at 256 Hz; 4 worker threads for parallel processing.
    generator = OptimizedEEGScalogramGenerator(**io_paths, sampling_rate=256, n_jobs=4)

    menu_lines = (
        "Choose processing method:",
        "1. Thread-based parallel processing (recommended)",
        "2. Sequential processing (most stable)",
    )
    for menu_line in menu_lines:
        print(menu_line)

    # Prefer the threaded run; drop to the sequential run if it raises.
    try:
        print("\nStarting thread-based parallel processing...")
        generator.process_all_patients_threaded()
    except Exception as err:
        print(f"\nThread-based processing failed: {err}")
        print("Falling back to sequential processing...")
        generator.process_all_patients_sequential_optimized()

    highlights = (
        "Thread-based parallelism (safer than multiprocessing)",
        "Comprehensive error handling and recovery",
        "Memory management with garbage collection",
        "Vectorized operations where possible",
        "Pre-computed filter coefficients",
        "Optimized matplotlib settings",
    )
    print("\nOptimized scalogram generation completed!")
    print("Key improvements:")
    for highlight in highlights:
        print(f"- {highlight}")

Initialized with 4 parallel threads
Choose processing method:
1. Thread-based parallel processing (recommended)
2. Sequential processing (most stable)

Starting thread-based parallel processing...
=== Optimized EEG Scalogram Generation (Thread-based Parallel Processing) ===
Generating scalograms with thread-based parallelism for stability

Found 2 Schizophrenia patient files
Found 0 Healthy Control patient files
Using 4 parallel threads
Processing SZ Patient SZ_01...
Processing SZ Patient SZ_02...
  EEG data shape: (271750, 19)
  Applying bandpass filter (0.5-45 Hz)...
  Segmenting data into 3-second windows...
  Generated 353 segments from patient SZ_02
  Processing segment 1/353...
  EEG data shape: (340000, 19)
  Applying bandpass filter (0.5-45 Hz)...
  Segmenting data into 3-second windows...
  Generated 442 segments from patient SZ_01
  Processing segment 1/442...
  Saved: SZ_patient_SZ_01_segment_000_scalogram_combined.png  Saved: SZ_patient_SZ_02_segment_000_scalogram_combined.