In [2]:
!pip install numpy pandas matplotlib scipy


Defaulting to user installation because normal site-packages is not writeable



[notice] A new release of pip is available: 24.0 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [4]:
pip install PyWavelets


Defaulting to user installation because normal site-packages is not writeable
Collecting PyWavelets
  Using cached pywavelets-1.8.0-cp311-cp311-win_amd64.whl.metadata (9.0 kB)
Using cached pywavelets-1.8.0-cp311-cp311-win_amd64.whl (4.2 MB)
Installing collected packages: PyWavelets
Successfully installed PyWavelets-1.8.0
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.0 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import signal
from scipy.signal import butter, filtfilt
import pywt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import warnings
warnings.filterwarnings('ignore')

class EEGScalogramGenerator:
    """
    Generate channel-averaged CWT scalogram images from multi-channel EEG CSV files.

    Pipeline per recording: load CSV -> band-pass filter -> split into
    non-overlapping ``segment_duration``-second windows -> complex-Morlet CWT of
    every channel -> average across channels -> save one PNG per window under
    ``<output_dir>/Schizophrenia`` (label 'SZ') or ``<output_dir>/HealthyControl``
    (any other label).
    """

    # Number of EEG channels each recording is expected to provide.
    EXPECTED_CHANNELS = 19

    def __init__(self, sz_input_dir, hc_input_dir, output_dir, sampling_rate=250):
        """
        Initialize the EEG Scalogram Generator using Continuous Wavelet Transform (CWT)

        Parameters:
        sz_input_dir (str): Directory containing CSV files with Schizophrenia EEG data
        hc_input_dir (str): Directory containing CSV files with Healthy Control EEG data
        output_dir (str): Root directory to save scalogram images; the
            'Schizophrenia' and 'HealthyControl' subdirectories are created here
        sampling_rate (int): EEG sampling rate in Hz (default: 250 Hz)
        """
        self.sz_input_dir = sz_input_dir
        self.hc_input_dir = hc_input_dir
        self.output_dir = output_dir
        self.sampling_rate = sampling_rate
        self.segment_duration = 3  # seconds per window
        # Samples per window (750 at 250 Hz, 768 at 256 Hz). Kept an int so it
        # can be used for slicing even if a float sampling rate is passed.
        self.segment_samples = int(round(self.sampling_rate * self.segment_duration))

        # Wavelet parameters
        self.wavelet = 'cmor1.5-1.0'  # Complex Morlet (bandwidth 1.5, center frequency 1.0)
        self.freq_min = 0.5  # Minimum frequency (Hz)
        self.freq_max = 45   # Maximum frequency (Hz)
        self.num_scales = 100  # Number of scales/frequencies

        # Log-spaced analysis frequencies in Hz, lowest first.
        self.frequencies = np.logspace(np.log10(self.freq_min), np.log10(self.freq_max), self.num_scales)

        # pywt.frequency2scale takes frequencies in cycles/sample and returns
        # scales measured in samples, which is what pywt.cwt expects together
        # with sampling_period=1/fs. Do NOT multiply by the sampling rate again:
        # that inflates every scale by fs, so the CWT would analyse f/fs
        # instead of f (and convolve with enormous wavelets).
        self.scales = pywt.frequency2scale(self.wavelet, self.frequencies / self.sampling_rate)

        # Create output directories
        self.sz_dir = os.path.join(output_dir, 'Schizophrenia')
        self.hc_dir = os.path.join(output_dir, 'HealthyControl')
        os.makedirs(self.sz_dir, exist_ok=True)
        os.makedirs(self.hc_dir, exist_ok=True)

    def bandpass_filter(self, data, lowcut=0.5, highcut=45, order=4):
        """
        Apply a zero-phase Butterworth band-pass filter along axis 0 (time).

        Parameters:
        data (array): EEG signal data, time along the first axis
        lowcut (float): Low cutoff frequency (Hz)
        highcut (float): High cutoff frequency (Hz)
        order (int): Filter order

        Returns:
        array: Filtered data with the same shape as ``data``

        Raises:
        ValueError: if the band edges do not satisfy 0 < lowcut < highcut < Nyquist
        """
        nyquist = 0.5 * self.sampling_rate
        if not 0 < lowcut < highcut < nyquist:
            raise ValueError(
                f"Band-pass edges must satisfy 0 < lowcut < highcut < Nyquist ({nyquist} Hz); "
                f"got lowcut={lowcut}, highcut={highcut}"
            )
        # Second-order sections are numerically robust for a 0.5 Hz edge at
        # EEG sampling rates, where the transfer-function (b, a) form can be
        # badly conditioned. sosfiltfilt keeps the filter zero-phase.
        sos = signal.butter(order, [lowcut / nyquist, highcut / nyquist], btype='band', output='sos')
        return signal.sosfiltfilt(sos, data, axis=0)

    def segment_signal(self, data):
        """
        Split EEG data into consecutive, non-overlapping windows of
        ``self.segment_samples`` samples. A trailing partial window is dropped.

        Parameters:
        data (array): EEG data with shape (samples, channels)

        Returns:
        list: List of arrays, each of shape (segment_samples, channels)
        """
        step = self.segment_samples
        n_segments = data.shape[0] // step
        return [data[i * step:(i + 1) * step, :] for i in range(n_segments)]

    def compute_cwt(self, signal_data):
        """
        Compute the CWT power of a single channel in decibels.

        Parameters:
        signal_data (array): 1D EEG signal data

        Returns:
        tuple: (power_db, frequencies). power_db has one row per analysis
            frequency (row i corresponds to self.frequencies[i]) and one column
            per input sample; frequencies is self.frequencies in Hz.
        """
        coefficients, _ = pywt.cwt(signal_data, self.scales, self.wavelet,
                                   sampling_period=1 / self.sampling_rate)

        # Power = squared magnitude of the complex coefficients.
        power = np.abs(coefficients) ** 2

        # The 1e-12 floor keeps log10 finite (all-zero input maps to -120 dB).
        power_db = 10 * np.log10(power + 1e-12)

        return power_db, self.frequencies

    def create_combined_scalogram(self, segment, patient_id, segment_idx, label):
        """
        Create and save one scalogram image averaged across all channels of a segment.

        Parameters:
        segment (array): EEG segment with shape (samples, channels)
        patient_id (str): Patient identifier
        segment_idx (int): Segment index
        label (str): 'SZ' for Schizophrenia; anything else is saved as Healthy Control

        Returns:
        str: Path of the saved PNG file
        """
        n_samples, n_channels = segment.shape
        # Sample i is taken at i / fs seconds.
        time_axis = np.arange(n_samples) / self.sampling_rate

        print(f"  Computing CWT for all {n_channels} channels...")
        scalograms = []
        for ch in range(n_channels):
            cwt_power_db, frequencies = self.compute_cwt(segment[:, ch])
            scalograms.append(cwt_power_db)

        avg_scalogram = np.mean(scalograms, axis=0)
        print(f"  Averaged scalogram across {len(scalograms)} channels")

        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            # pcolormesh places every row at its true (log-spaced) frequency.
            # imshow with a linear extent would mis-position rows once the
            # y-axis is switched to a log scale.
            mesh = ax.pcolormesh(time_axis, frequencies, avg_scalogram,
                                 shading='auto', cmap='viridis')
            fig.colorbar(mesh, ax=ax, label='Power (dB)')

            ax.set_title(f'Average EEG Scalogram (CWT) - {label} Patient {patient_id} - Segment {segment_idx}',
                         fontsize=14, fontweight='bold')
            ax.set_xlabel('Time (s)', fontsize=12)
            ax.set_ylabel('Frequency (Hz)', fontsize=12)
            ax.set_yscale('log')  # Log scale for frequency axis
            ax.set_ylim(self.freq_min, self.freq_max)

            save_dir = self.sz_dir if label == 'SZ' else self.hc_dir
            filename = f'{label}_patient_{patient_id}_segment_{segment_idx:03d}_scalogram_combined.png'
            filepath = os.path.join(save_dir, filename)

            fig.savefig(filepath, dpi=150, bbox_inches='tight', facecolor='white')
        finally:
            # Always release the figure, even if plotting or saving fails,
            # so long batch runs do not accumulate open figures.
            plt.close(fig)

        return filepath

    def process_patient_file(self, csv_file, patient_id, label):
        """
        Process a single patient's EEG CSV file and generate combined scalograms

        Parameters:
        csv_file (str): Path to CSV file (header row, one column per channel)
        patient_id (str): Patient identifier
        label (str): 'SZ' for Schizophrenia or 'HC' for Healthy Control

        Returns:
        list: Paths of the saved images; empty if the file could not be processed
        """
        print(f"Processing {label} Patient {patient_id}...")

        try:
            # Load EEG data; every column is treated as one channel.
            data = pd.read_csv(csv_file)
            eeg_data = data.to_numpy(dtype=float)

            n_channels = eeg_data.shape[1]
            if n_channels != self.EXPECTED_CHANNELS:
                print(f"Warning: Expected 19 channels, got {n_channels} for patient {patient_id}")
                if n_channels > self.EXPECTED_CHANNELS:
                    eeg_data = eeg_data[:, :self.EXPECTED_CHANNELS]
                # With fewer channels, keep only the real ones. Zero-padding
                # would add flat channels whose power sits at the -120 dB floor
                # in compute_cwt and would drag the channel average down.

            print(f"  EEG data shape: {eeg_data.shape}")

            print(f"  Applying bandpass filter ({self.freq_min}-{self.freq_max} Hz)...")
            filtered_data = self.bandpass_filter(eeg_data)

            print(f"  Segmenting data into {self.segment_duration}-second windows...")
            segments = self.segment_signal(filtered_data)

            print(f"  Generated {len(segments)} segments from patient {patient_id}")

            saved_files = []
            for idx, segment in enumerate(segments):
                print(f"  Processing segment {idx + 1}/{len(segments)}...")
                filepath = self.create_combined_scalogram(segment, patient_id, idx, label)
                saved_files.append(filepath)
                print(f"  Saved: {os.path.basename(filepath)}")

            return saved_files

        except Exception as e:
            # Best-effort batch processing: report the failure and let the
            # caller continue with the next patient.
            print(f"Error processing patient {patient_id}: {str(e)}")
            return []

    def get_patient_files(self, directory):
        """
        Get all CSV files (case-insensitive '.csv' extension) from a directory

        Parameters:
        directory (str): Directory path containing CSV files

        Returns:
        list: Sorted list of CSV file paths; empty if the directory is missing
        """
        if not os.path.isdir(directory):
            print(f"Warning: Directory {directory} does not exist!")
            return []

        # Sorted for consistent patient numbering across runs.
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.lower().endswith('.csv')
        )

    def _process_group(self, csv_files, label):
        """Process every CSV in ``csv_files`` under ``label`` and return the number of images saved."""
        total_images = 0
        for i, csv_file in enumerate(csv_files, 1):
            patient_id = f"{label}_{i:02d}"
            saved_files = self.process_patient_file(csv_file, patient_id, label)
            total_images += len(saved_files)
            print(f"  Generated {len(saved_files)} combined scalograms for {patient_id}\n")
        return total_images

    def process_all_patients(self):
        """
        Process all patients from both directories and generate combined scalogram images
        """
        print("=== EEG Scalogram Generation (Combined Mode) ===")
        print("Generating single scalogram per segment averaged across all 19 channels\n")

        sz_patient_files = self.get_patient_files(self.sz_input_dir)
        hc_patient_files = self.get_patient_files(self.hc_input_dir)

        print(f"Found {len(sz_patient_files)} Schizophrenia patient files")
        print(f"Found {len(hc_patient_files)} Healthy Control patient files")

        if len(sz_patient_files) == 0 and len(hc_patient_files) == 0:
            print("No CSV files found in the specified directories!")
            return

        print("\n=== Processing Schizophrenia Patients ===")
        total_sz_images = self._process_group(sz_patient_files, 'SZ')

        print("=== Processing Healthy Control Patients ===")
        total_hc_images = self._process_group(hc_patient_files, 'HC')

        print("=== Summary ===")
        print(f"Total Schizophrenia combined scalogram images: {total_sz_images}")
        print(f"Total Healthy Control combined scalogram images: {total_hc_images}")
        print(f"Total combined scalogram images generated: {total_sz_images + total_hc_images}")
        print(f"Images saved in: {self.output_dir}")
        print(f"  - Schizophrenia scalograms: {self.sz_dir}")
        print(f"  - Healthy Control scalograms: {self.hc_dir}")
        print(f"Wavelet used: {self.wavelet}")
        print(f"Frequency range: {self.freq_min}-{self.freq_max} Hz")
        print(f"Each scalogram represents average across all 19 EEG channels")

# Example usage: run the whole pipeline over both cohorts.
if __name__ == "__main__":
    # Input folders (one CSV per subject) and the root folder for the PNGs.
    # NOTE(review): absolute local paths — change these for your machine.
    sz_input_directory = "D:/Milon2/SD14/schizophrenia"
    hc_input_directory = "D:/Milon2/SD14/healthy"
    output_directory = "D:/Milon2/scalogram/images"

    # Positional order: sz_input_dir, hc_input_dir, output_dir, sampling_rate.
    # The sampling rate (Hz) should be adjusted to match the recordings.
    generator = EEGScalogramGenerator(
        sz_input_directory,
        hc_input_directory,
        output_directory,
        256,
    )

    print("Generating combined scalograms (averaged across all 19 channels)...")
    generator.process_all_patients()

    # Closing report, one message per line.
    print(
        "\nCombined scalogram generation completed!",
        f"Check the output directory: {output_directory}",
        "- Schizophrenia combined scalograms in: Schizophrenia/",
        "- Healthy Control combined scalograms in: HealthyControl/",
        "\nEach scalogram image represents the average time-frequency representation",
        "computed across all 19 EEG channels using Continuous Wavelet Transform (CWT).",
        sep="\n",
    )

Generating combined scalograms (averaged across all 19 channels)...
=== EEG Scalogram Generation (Combined Mode) ===
Generating single scalogram per segment averaged across all 19 channels

Found 14 Schizophrenia patient files
Found 14 Healthy Control patient files

=== Processing Schizophrenia Patients ===
Processing SZ Patient SZ_01...
  EEG data shape: (211250, 19)
  Applying bandpass filter (0.5-45 Hz)...
  Segmenting data into 3-second windows...
  Generated 275 segments from patient SZ_01
  Processing segment 1/275...
  Computing CWT for all 19 channels...
