In [None]:
# Cell 1: Import all required libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import signal
from scipy.io import loadmat
import pyedflib
import mne
from mne import filtering
import torch
from torch.utils.data import Dataset, DataLoader
import warnings
warnings.filterwarnings('ignore')

# Seed the global NumPy and PyTorch RNGs so shuffles and splits are repeatable
# across kernel restarts.
np.random.seed(42)
torch.manual_seed(42)

# Report the runtime, then select the GPU when one is visible to PyTorch.
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print("GPU device:", torch.cuda.get_device_name(0))

In [None]:
# Cell 2: Configuration and dataset exploration
class CHBMITConfig:
    """Editable settings for CHB-MIT preprocessing.

    Every setting is a class attribute, so it can be read from the class itself
    or from the ``config`` instance created below.
    """

    # --- Location of the extracted dataset (adjust to your local copy) ---
    DATA_PATH = "chb-mit-scalp-eeg-database-1.0.0"

    # --- Sampling and windowing (Hz / seconds) ---
    TARGET_SAMPLING_RATE = 256
    WINDOW_SIZE_PREDICTION = 10
    WINDOW_SIZE_DETECTION = 4
    WINDOW_STRIDE = 1

    # --- Filtering (Hz) ---
    HIGH_PASS_FREQ = 0.5
    LOW_PASS_FREQ = 70
    NOTCH_FREQ = 60  # US power-line frequency

    # --- Labelling: seconds before each seizure onset labelled pre-ictal ---
    PRE_ICTAL_PERIOD = 5 * 60

    # --- Patient-wise split fractions ---
    TRAIN_RATIO = 0.7
    VAL_RATIO = 0.15
    TEST_RATIO = 0.15  # informational: the split code assigns the remainder to test


config = CHBMITConfig()

# Explore dataset structure
def explore_dataset_structure(data_path):
    """Print an overview of the CHB-MIT directory layout and list the patients.

    Args:
        data_path: Root directory of the extracted CHB-MIT database.

    Returns:
        Sorted list of patient sub-directory names found directly under
        ``data_path``. Returns an empty list (after printing a message) when
        ``data_path`` does not exist or is not a directory, so that a notebook
        working from an existing preprocessed cache can still run.
    """
    if not os.path.isdir(data_path):
        print(f"Dataset directory not found: {data_path}")
        return []

    # Sort once so the printed list, the explored sample patient and the return
    # value are deterministic (os.listdir order is filesystem-dependent).
    patients = sorted(
        item for item in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, item))
    )

    print(f"Found {len(patients)} patients/subjects")
    print("Patients:", patients)

    # Explore one patient to understand the file structure.
    if patients:
        sample_patient = patients[0]
        sample_path = os.path.join(data_path, sample_patient)
        print(f"\nExploring {sample_patient}:")

        for file in sorted(os.listdir(sample_path)):
            if file.endswith(('.edf', '.seizures', '.txt')):
                print(f"  {file}")

    return patients

# List the patient folders found under the configured dataset root.
patients = explore_dataset_structure(data_path=config.DATA_PATH)

In [None]:
# Cell 3: Utility functions for data parsing
def parse_summary_file(patient_path):
    """Parse a patient's CHB-MIT summary file into per-recording seizure intervals.

    Each recording section starts with ``File Name: chbXX_YY.edf`` and lists its
    seizures as Start/End line pairs. Both layouts used in the dataset are
    recognised::

        Seizure Start Time: 2996 seconds        Seizure 1 Start Time: 2996 seconds
        Seizure End Time: 3036 seconds          Seizure 1 End Time: 3036 seconds

    A Start line with no following End line is kept with a 1-second duration.

    Args:
        patient_path: Directory containing the patient's ``*summary*.txt`` file.

    Returns:
        dict mapping the file name exactly as written after ``File Name:``
        (i.e. including ``.edf``) to a list of ``(start_sec, end_sec)`` tuples,
        or ``None`` when no summary file exists. If reading fails part-way, the
        entries parsed so far are returned.
    """
    # Sorted so the chosen summary file does not depend on os.listdir order.
    summary_files = sorted(
        f for f in os.listdir(patient_path)
        if 'summary' in f.lower() and f.endswith('.txt')
    )

    if not summary_files:
        print(f"No summary file found for {patient_path}")
        return None

    summary_file = os.path.join(patient_path, summary_files[0])
    seizure_info = {}

    try:
        with open(summary_file, 'r') as f:
            current_file = None
            awaiting_end = False
            for raw_line in f:
                line = raw_line.strip()
                if line.startswith('File Name:'):
                    current_file = line.split(':', 1)[1].strip()
                    seizure_info[current_file] = []
                    awaiting_end = False
                    continue

                # Seizure annotations only make sense inside a "File Name" section.
                if current_file is None or not line.startswith('Seizure'):
                    continue

                label, _, value = line.partition(':')
                fields = value.split()
                if not fields or not fields[0].isdigit():
                    continue
                seconds = int(fields[0])
                intervals = seizure_info[current_file]

                if 'Start Time' in label:
                    # Provisional 1-second interval, widened when the End line arrives.
                    intervals.append((seconds, seconds + 1))
                    awaiting_end = True
                elif 'End Time' in label and awaiting_end:
                    intervals[-1] = (intervals[-1][0], seconds)
                    awaiting_end = False

    except Exception as e:
        print(f"Error parsing summary file {summary_file}: {e}")

    return seizure_info

def parse_seizure_file(seizure_file_path):
    """Best-effort parse of a ``.seizures`` annotation file.

    Expects plain text with one ``<start_sec> <end_sec> [...]`` entry per line.
    Blank lines, lines with fewer than two fields, and lines whose first two
    fields are not integers are skipped individually. A malformed line no longer
    discards the remainder of the file.

    NOTE(review): the ``.seizures`` files shipped with CHB-MIT may be binary
    annotation files rather than text. Such files typically fail to decode and
    yield ``[]`` here. Confirm before relying on this source.

    Args:
        seizure_file_path: Path to the annotation file.

    Returns:
        list of ``(start_sec, end_sec)`` tuples. The list is empty if the file
        cannot be opened or decoded.
    """
    seizures = []
    try:
        with open(seizure_file_path, 'r') as f:
            lines = f.readlines()
    except (OSError, UnicodeDecodeError) as e:
        # Best effort: report and carry on with the other annotation sources.
        print(f"Could not read seizure file {seizure_file_path}: {e}")
        return seizures

    for line in lines:
        parts = line.split()
        if len(parts) < 2:
            continue
        try:
            seizures.append((int(parts[0]), int(parts[1])))
        except ValueError:
            continue

    return seizures

def get_all_seizure_info(patient_path):
    """Collect seizure intervals for every recording of one patient.

    Merges the patient's summary file (via ``parse_summary_file``) with any
    ``*.seizures`` files in the same directory.

    Args:
        patient_path: Directory of one patient.

    Returns:
        dict mapping file name to a list of ``(start_sec, end_sec)`` tuples.
        Summary entries use the name written after ``File Name:``. Entries from
        ``.seizures`` files use the file name with ``.seizures`` removed. An
        interval listed by both sources is kept once. The dict is empty when
        nothing is found.
    """
    # parse_summary_file returns None when there is no summary file; start from
    # an empty dict so .seizures annotations can still be merged in.
    seizure_info = parse_summary_file(patient_path) or {}

    seizure_files = sorted(f for f in os.listdir(patient_path) if f.endswith('.seizures'))

    for seizure_file in seizure_files:
        file_base = seizure_file.replace('.seizures', '')
        seizures = parse_seizure_file(os.path.join(patient_path, seizure_file))
        if not seizures:
            continue
        intervals = seizure_info.setdefault(file_base, [])
        for seizure in seizures:
            if seizure not in intervals:
                intervals.append(seizure)

    return seizure_info

In [None]:
# Cell 4: EEG Data Loader and Preprocessor
class CHBMITDataLoader:
    """Load CHB-MIT EDF recordings and map them onto a fixed, filtered montage.

    The montage is the list of bipolar derivations defined in
    ``standardize_channels``. Derivations missing from a recording are filled
    with zeros.
    """

    def __init__(self, config):
        self.config = config
        self.standard_channels = None

    def load_edf_file(self, file_path):
        """Read every signal of an EDF file.

        pyedflib is tried first. On any failure MNE is used as a fallback.

        Returns:
            ``(data, channel_names, fs)`` where ``data`` has shape
            ``(n_channels, n_samples)``, or ``(None, None, None)`` if both
            readers fail.
        """
        try:
            reader = pyedflib.EdfReader(file_path)
            try:
                n_channels = reader.signals_in_file
                channel_names = reader.getSignalLabels()
                # Assumes every channel shares channel 0's rate and length.
                fs = reader.getSampleFrequency(0)
                data = np.zeros((n_channels, reader.getNSamples()[0]))
                for i in range(n_channels):
                    data[i, :] = reader.readSignal(i)
            finally:
                # Release the file handle even if reading a signal fails.
                reader.close()
            return data, channel_names, fs

        except Exception as e:
            print(f"pyedflib failed for {file_path}: {e}. Trying MNE...")
            try:
                raw = mne.io.read_raw_edf(file_path, preload=True, verbose=False)
                return raw.get_data(), raw.ch_names, raw.info['sfreq']
            except Exception as e2:
                print(f"MNE also failed for {file_path}: {e2}")
                return None, None, None

    def preprocess_signal(self, data, original_fs, channel_names):
        """Resample to ``TARGET_SAMPLING_RATE``, standardise channels, then filter.

        Returns:
            ``(filtered_data, fs)``, or ``(None, None)`` if channel
            standardisation fails.
        """
        target_fs = self.config.TARGET_SAMPLING_RATE

        if original_fs != target_fs:
            n_target = int(data.shape[1] * target_fs / original_fs)
            # FFT-based resampling of all channels at once along the time axis.
            data = signal.resample(data, n_target, axis=1)
            fs = target_fs
        else:
            fs = original_fs

        data = self.standardize_channels(data, channel_names, fs)
        if data is None:
            return None, None

        return self.apply_filters(data, fs), fs

    def standardize_channels(self, data, channel_names, fs):
        """Reorder ``data`` onto the standard montage.

        Matching is case-insensitive and ignores spaces, underscores, hyphens and
        a trailing numeric duplicate suffix (``"T8-P8-0"`` -> ``T8P8``). An exact
        match is preferred. Otherwise the standard name must appear inside the
        (non-empty) recorded label, e.g. ``"EEG FP1-F7"``. The first matching
        channel wins, so dummy labels such as ``"-"`` can never displace a real one.

        Args:
            data: Array of shape ``(n_channels, n_samples)``.
            channel_names: Labels aligned with the rows of ``data``.
            fs: Sampling rate (unused; kept for interface compatibility).

        Returns:
            Array of shape ``(len(self.standard_channels), n_samples)``.
        """
        if self.standard_channels is None:
            self.standard_channels = [
                'FP1-F7', 'F7-T7', 'T7-P7', 'P7-O1', 'FP1-F3', 'F3-C3', 'C3-P3', 'P3-O1',
                'FP2-F4', 'F4-C4', 'C4-P4', 'P4-O2', 'FP2-F8', 'F8-T8', 'T8-P8', 'P8-O2',
                'FZ-CZ', 'CZ-PZ'
            ]

        def _clean(name):
            """Upper-case a label and strip separators plus any '-<digits>' suffix."""
            cleaned = name.upper().replace(' ', '').replace('_', '')
            base, sep, tail = cleaned.rpartition('-')
            if sep and tail.isdigit():
                cleaned = base
            return cleaned.replace('-', '')

        cleaned_names = [_clean(ch) for ch in channel_names]

        channel_mapping = {}
        for std_ch in self.standard_channels:
            std_clean = std_ch.replace('-', '')
            idx = next((i for i, c in enumerate(cleaned_names) if c == std_clean), None)
            if idx is None:
                idx = next((i for i, c in enumerate(cleaned_names) if c and std_clean in c), None)
            if idx is not None:
                channel_mapping[std_ch] = idx

        # Missing derivations stay zero-filled.
        standardized_data = np.zeros((len(self.standard_channels), data.shape[1]))
        for j, std_ch in enumerate(self.standard_channels):
            if std_ch in channel_mapping:
                standardized_data[j] = data[channel_mapping[std_ch]]

        print(f"Found {len(channel_mapping)}/{len(self.standard_channels)} standard channels")
        return standardized_data

    def apply_filters(self, data, fs):
        """Apply a zero-phase Butterworth band-pass (or high-pass) and a notch filter.

        The band-pass uses ``HIGH_PASS_FREQ``..``LOW_PASS_FREQ``. When the low-pass
        edge is at or above Nyquist, only the high-pass stage is applied. The notch
        at ``NOTCH_FREQ`` (Q = 30) is skipped when it is not below Nyquist.
        """
        nyquist = fs / 2
        low = self.config.HIGH_PASS_FREQ / nyquist
        high = self.config.LOW_PASS_FREQ / nyquist

        # Second-order sections stay numerically stable at very low normalised
        # cut-offs (0.5 Hz at 256 Hz is ~0.004), unlike the b/a polynomial form.
        if 0 < low < high < 1:
            sos = signal.butter(4, [low, high], btype='band', output='sos')
            data_filtered = signal.sosfiltfilt(sos, data, axis=1)
        elif 0 < low < 1:
            sos = signal.butter(4, low, btype='high', output='sos')
            data_filtered = signal.sosfiltfilt(sos, data, axis=1)
        else:
            data_filtered = data

        notch_freq = self.config.NOTCH_FREQ
        # iirnotch requires 0 < notch_freq < fs / 2.
        if 0 < notch_freq < nyquist:
            quality = 30
            b, a = signal.iirnotch(notch_freq, quality, fs)
            data_filtered = signal.filtfilt(b, a, data_filtered, axis=1)

        return data_filtered

    def normalize_signal(self, data):
        """Robustly scale each channel: subtract its median, divide by its MAD.

        Channels with zero MAD (e.g. zero-filled missing channels) are divided by 1.
        """
        median = np.median(data, axis=1, keepdims=True)
        mad = np.median(np.abs(data - median), axis=1, keepdims=True)
        mad[mad == 0] = 1.0
        return (data - median) / mad

In [None]:
# Cell 5: Label Generation and Dataset Creation
class SeizureDatasetGenerator:
    """Turn CHB-MIT recordings into labelled, fixed-length EEG windows.

    Per-sample labels:
      * detection: 1 inside an annotated seizure (ictal), else 0.
      * prediction: 1 within ``config.PRE_ICTAL_PERIOD`` seconds before a seizure
        onset, excluding samples that are themselves ictal, else 0.
    """

    def __init__(self, config):
        self.config = config
        self.data_loader = CHBMITDataLoader(config)

    def generate_labels_for_file(self, data_length, fs, seizure_intervals, file_duration):
        """Build per-sample detection and prediction labels for one recording.

        Args:
            data_length: Number of samples in the recording.
            fs: Sampling rate in Hz.
            seizure_intervals: Iterable of ``(start_sec, end_sec)`` annotations.
            file_duration: Recording length in seconds (unused; kept for callers).

        Returns:
            ``(detection_labels, prediction_labels)``, float arrays of length
            ``data_length`` holding 0/1.

        Note:
            A pre-ictal period reaching before the start of this file is clipped at
            sample 0. The preceding recording is not labelled.
        """
        n_samples = data_length
        detection_labels = np.zeros(n_samples)
        prediction_labels = np.zeros(n_samples)

        for start_sec, end_sec in seizure_intervals:
            start_sample = int(start_sec * fs)
            end_sample = min(int(end_sec * fs), n_samples)
            detection_labels[start_sample:end_sample] = 1

            pre_ictal_start = max(0, start_sec - self.config.PRE_ICTAL_PERIOD)
            prediction_labels[int(pre_ictal_start * fs):start_sample] = 1

        # With clustered seizures a later seizure's pre-ictal span can cover an
        # earlier seizure; a sample cannot be both ictal and pre-ictal.
        prediction_labels[detection_labels == 1] = 0

        return detection_labels, prediction_labels

    def create_windowed_dataset(self, data, detection_labels, prediction_labels, fs, window_type="prediction"):
        """Slice ``data`` into overlapping windows with majority-vote labels.

        Args:
            data: Array of shape ``(n_channels, n_samples)``.
            detection_labels, prediction_labels: Per-sample 0/1 arrays.
            fs: Sampling rate in Hz.
            window_type: ``"prediction"`` uses ``WINDOW_SIZE_PREDICTION``;
                anything else uses ``WINDOW_SIZE_DETECTION``.

        Returns:
            ``(windows, det_labels, pred_labels, metadata)``. ``windows`` always
            has shape ``(n_windows, n_channels, window_samples)``, even when no
            window is produced, so results from several files can be concatenated.
            The label arrays are boolean. A window is positive when more than half
            of its samples are positive.
        """
        if window_type == "prediction":
            window_size = self.config.WINDOW_SIZE_PREDICTION
        else:  # detection
            window_size = self.config.WINDOW_SIZE_DETECTION

        window_samples = int(window_size * fs)
        # Guard against a zero stride, which would make range() raise.
        stride_samples = max(1, int(self.config.WINDOW_STRIDE * fs))
        n_channels, n_samples = data.shape

        windows, det_labels, pred_labels, metadata = [], [], [], []

        for start_idx in range(0, n_samples - window_samples + 1, stride_samples):
            end_idx = start_idx + window_samples
            window_data = data[:, start_idx:end_idx]

            if not self.is_valid_window(window_data):
                continue

            windows.append(window_data)
            det_labels.append(np.mean(detection_labels[start_idx:end_idx]) > 0.5)
            pred_labels.append(np.mean(prediction_labels[start_idx:end_idx]) > 0.5)
            metadata.append({
                'start_sample': start_idx,
                'end_sample': end_idx,
                'start_time': start_idx / fs,
                'end_time': end_idx / fs
            })

        if windows:
            windows_arr = np.stack(windows)
        else:
            windows_arr = np.empty((0, n_channels, window_samples), dtype=data.dtype)

        return (windows_arr, np.array(det_labels, dtype=bool),
                np.array(pred_labels, dtype=bool), metadata)

    def is_valid_window(self, window_data):
        """Reject flat windows (range < 1e-6) and windows with |value| > 1000.

        In ``process_patient`` windows are taken from MAD-normalised data, so the
        amplitude threshold is in MAD units there.
        """
        if np.max(window_data) - np.min(window_data) < 1e-6:
            return False
        if np.max(np.abs(window_data)) > 1000:  # Adjust threshold as needed
            return False
        return True

    def process_patient(self, patient_path):
        """Load, preprocess, label and window every EDF recording of one patient.

        Returns:
            ``(windows, detection_labels, prediction_labels, metadata)``, with
            windows and labels concatenated across files, or four ``None`` when
            the patient has no seizure information or no file could be loaded.
        """
        seizure_info = get_all_seizure_info(patient_path)
        if not seizure_info:
            print(f"No seizure information found for {patient_path}")
            return None, None, None, None

        patient_id = os.path.basename(patient_path)
        all_windows, all_det_labels, all_pred_labels, all_metadata = [], [], [], []

        for edf_file in sorted(f for f in os.listdir(patient_path) if f.endswith('.edf')):
            file_path = os.path.join(patient_path, edf_file)
            print(f"Processing {edf_file}...")

            data, channel_names, original_fs = self.data_loader.load_edf_file(file_path)
            if data is None:
                continue

            processed_data, fs = self.data_loader.preprocess_signal(data, original_fs, channel_names)
            if processed_data is None:
                continue

            normalized_data = self.data_loader.normalize_signal(processed_data)
            n_samples = normalized_data.shape[1]

            # Annotations may be keyed by the full file name (as written in the
            # summary, with ".edf") or by the stem; accept both and drop duplicates.
            file_key = edf_file.replace('.edf', '')
            seizure_intervals = sorted(set(
                list(seizure_info.get(edf_file, [])) + list(seizure_info.get(file_key, []))
            ))

            detection_labels, prediction_labels = self.generate_labels_for_file(
                n_samples, fs, seizure_intervals, n_samples / fs
            )

            windows, det_labels, pred_labels, metadata = self.create_windowed_dataset(
                normalized_data, detection_labels, prediction_labels, fs, "prediction"
            )

            for meta in metadata:
                meta['file'] = edf_file
                meta['patient'] = patient_id

            all_windows.append(windows)
            all_det_labels.append(det_labels)
            all_pred_labels.append(pred_labels)
            all_metadata.extend(metadata)

            print(f"  Created {len(windows)} windows, "
                  f"Seizure: {np.sum(det_labels)}, Pre-ictal: {np.sum(pred_labels)}")

        if not all_windows:
            return None, None, None, None

        return (np.concatenate(all_windows), np.concatenate(all_det_labels),
                np.concatenate(all_pred_labels), all_metadata)

In [None]:
# Cell 6: Complete Dataset Preprocessing Pipeline
class CHBMITPreprocessor:
    """Run preprocessing over every patient and manage the on-disk ``.npz`` cache."""

    def __init__(self, config):
        self.config = config
        self.dataset_generator = SeizureDatasetGenerator(config)

    def preprocess_complete_dataset(self):
        """Preprocess every patient directory under ``config.DATA_PATH``.

        Each patient with data is also saved via ``save_patient_data``.

        Returns:
            dict ``patient_id -> {'windows', 'detection_labels',
            'prediction_labels', 'metadata'}``.
        """
        data_path = self.config.DATA_PATH
        patients = sorted(
            os.path.join(data_path, item)
            for item in os.listdir(data_path)
            if os.path.isdir(os.path.join(data_path, item))
        )

        all_patient_data = {}

        for patient_path in patients:
            patient_id = os.path.basename(patient_path)
            print(f"\n{'='*50}")
            print(f"Processing patient: {patient_id}")
            print(f"{'='*50}")

            windows, det_labels, pred_labels, metadata = self.dataset_generator.process_patient(patient_path)

            if windows is None:
                print(f"Skipping patient {patient_id} - no valid data")
                continue

            all_patient_data[patient_id] = {
                'windows': windows,
                'detection_labels': det_labels,
                'prediction_labels': pred_labels,
                'metadata': metadata
            }

            print(f"Patient {patient_id}: {windows.shape[0]} windows, "
                  f"{np.sum(det_labels)} seizure windows, "
                  f"{np.sum(pred_labels)} pre-ictal windows")

            self.save_patient_data(patient_id, windows, det_labels, pred_labels, metadata)

        return all_patient_data

    def save_patient_data(self, patient_id, windows, det_labels, pred_labels, metadata):
        """Write one patient's windows (float32), labels (int8) and metadata to .npz."""
        save_dir = "./preprocessed_data"
        os.makedirs(save_dir, exist_ok=True)

        save_path = f"{save_dir}/{patient_id}_preprocessed.npz"

        # metadata (a list of dicts) is stored as an object array, i.e. pickled.
        np.savez_compressed(
            save_path,
            windows=windows.astype(np.float32),
            detection_labels=det_labels.astype(np.int8),
            prediction_labels=pred_labels.astype(np.int8),
            metadata=metadata
        )

        print(f"Saved preprocessed data for {patient_id} to {save_path}")

    def load_preprocessed_data(self, patient_id):
        """Load one patient's cached data, or return None if no cache file exists.

        ``metadata`` is returned as a list of dicts, matching what
        ``preprocess_complete_dataset`` produces.
        """
        load_path = f"./preprocessed_data/{patient_id}_preprocessed.npz"

        if not os.path.exists(load_path):
            return None

        # allow_pickle is needed for the metadata object array. This is only
        # acceptable because these files are written locally by save_patient_data;
        # never load untrusted .npz files this way. The context manager closes
        # the underlying file handle once the arrays have been read.
        with np.load(load_path, allow_pickle=True) as data:
            return {
                'windows': data['windows'],
                'detection_labels': data['detection_labels'],
                'prediction_labels': data['prediction_labels'],
                'metadata': list(data['metadata'])
            }

    def create_train_val_test_split(self, all_patient_data, strategy='patient_wise'):
        """Split patients into train/val/test using the global NumPy RNG.

        The first ``int(n * TRAIN_RATIO)`` shuffled patients go to train, the next
        ``int(n * VAL_RATIO)`` go to val, and the rest go to test.

        Raises:
            ValueError: for any strategy other than ``'patient_wise'``.
        """
        if strategy != 'patient_wise':
            raise ValueError(f"Unsupported split strategy: {strategy!r}")

        # Sort first: dict order can come from os.listdir, so without this the
        # seeded shuffle could yield different splits on different machines.
        patient_ids = sorted(all_patient_data)
        np.random.shuffle(patient_ids)

        n_train = int(len(patient_ids) * self.config.TRAIN_RATIO)
        n_val = int(len(patient_ids) * self.config.VAL_RATIO)

        train_patients = patient_ids[:n_train]
        val_patients = patient_ids[n_train:n_train + n_val]
        test_patients = patient_ids[n_train + n_val:]

        train_data = {pid: all_patient_data[pid] for pid in train_patients}
        val_data = {pid: all_patient_data[pid] for pid in val_patients}
        test_data = {pid: all_patient_data[pid] for pid in test_patients}

        return train_data, val_data, test_data

In [None]:
# Cell 7: Visualization and Analysis Functions
class DataAnalyzer:
    """Summary statistics and plots for the preprocessed per-patient windows."""

    def __init__(self, config):
        self.config = config

    def plot_data_distribution(self, all_patient_data):
        """Plot window counts and class percentages per patient.

        Returns:
            DataFrame with one row per patient: total/seizure/pre-ictal window
            counts and their percentages. Percentages are 0 for a patient with
            no windows.
        """
        patient_stats = []

        for patient_id, data in all_patient_data.items():
            n_windows = len(data['detection_labels'])
            n_seizure = int(np.sum(data['detection_labels']))
            n_preictal = int(np.sum(data['prediction_labels']))

            patient_stats.append({
                'patient': patient_id,
                'total_windows': n_windows,
                'seizure_windows': n_seizure,
                'preictal_windows': n_preictal,
                'seizure_ratio': n_seizure / n_windows * 100 if n_windows else 0.0,
                'preictal_ratio': n_preictal / n_windows * 100 if n_windows else 0.0
            })

        stats_df = pd.DataFrame(patient_stats)

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        panels = [
            (axes[0, 0], 'total_windows', None, 'Total Windows per Patient'),
            (axes[0, 1], 'seizure_windows', 'red', 'Seizure Windows per Patient'),
            (axes[1, 0], 'preictal_windows', 'orange', 'Pre-ictal Windows per Patient'),
        ]
        for ax, column, color, title in panels:
            ax.bar(stats_df['patient'], stats_df[column], color=color)
            ax.set_title(title)
            ax.tick_params(axis='x', rotation=45)

        # Overlaid class percentages.
        ratio_ax = axes[1, 1]
        ratio_ax.bar(stats_df['patient'], stats_df['seizure_ratio'], alpha=0.7, label='Seizure %')
        ratio_ax.bar(stats_df['patient'], stats_df['preictal_ratio'], alpha=0.7, label='Pre-ictal %')
        ratio_ax.set_title('Class Distribution (%)')
        ratio_ax.tick_params(axis='x', rotation=45)
        ratio_ax.legend()

        plt.tight_layout()
        plt.show()

        return stats_df

    def plot_sample_windows(self, patient_data, patient_id, n_samples=5):
        """Plot channel 0 of up to ``n_samples`` seizure, pre-ictal and normal windows.

        One row per class; a row stays empty when the class has no windows.
        """
        windows = patient_data['windows']
        det_labels = patient_data['detection_labels']
        pred_labels = patient_data['prediction_labels']

        rows = [
            ('Seizure', np.where(det_labels == 1)[0]),
            ('Pre-ictal', np.where(pred_labels == 1)[0]),
            ('Normal', np.where((det_labels == 0) & (pred_labels == 0))[0]),
        ]

        # squeeze=False keeps axes 2-D even when n_samples == 1.
        fig, axes = plt.subplots(len(rows), n_samples, figsize=(15, 9), squeeze=False)

        for row, (label, indices) in enumerate(rows):
            for col, idx in enumerate(indices[:n_samples]):
                axes[row, col].plot(windows[idx, 0, :])  # first channel only
                axes[row, col].set_title(f'{label} Window {col + 1}')

        fig.suptitle(f'Sample windows - {patient_id} (channel 0)')
        plt.tight_layout()
        plt.show()

    def compute_dataset_statistics(self, all_patient_data):
        """Aggregate window and class counts over all patients and print them.

        Returns:
            dict of totals and percentages. Percentages are 0 when there are no
            windows. ``class_imbalance_ratio`` is non-seizure : seizure windows
            (0 when there are no seizure windows).
        """
        total_windows = 0
        total_seizure = 0
        total_preictal = 0

        for data in all_patient_data.values():
            total_windows += len(data['detection_labels'])
            total_seizure += int(np.sum(data['detection_labels']))
            total_preictal += int(np.sum(data['prediction_labels']))

        def _percent(count):
            return count / total_windows * 100 if total_windows else 0.0

        stats = {
            'total_patients': len(all_patient_data),
            'total_windows': total_windows,
            'total_seizure_windows': total_seizure,
            'total_preictal_windows': total_preictal,
            'seizure_ratio': _percent(total_seizure),
            'preictal_ratio': _percent(total_preictal),
            'class_imbalance_ratio': (total_windows - total_seizure) / total_seizure if total_seizure > 0 else 0
        }

        print("Dataset Statistics:")
        print(f"Total Patients: {stats['total_patients']}")
        print(f"Total Windows: {stats['total_windows']:,}")
        print(f"Seizure Windows: {stats['total_seizure_windows']:,} ({stats['seizure_ratio']:.2f}%)")
        print(f"Pre-ictal Windows: {stats['total_preictal_windows']:,} ({stats['preictal_ratio']:.2f}%)")
        print(f"Class Imbalance Ratio: {stats['class_imbalance_ratio']:.2f}:1")

        return stats

In [None]:
# Cell 8: Main Preprocessing Execution
def main():
    """Run the CHB-MIT preprocessing pipeline end to end.

    Loads cached per-patient ``*_preprocessed.npz`` files from
    ``./preprocessed_data`` when any exist. Otherwise it preprocesses the raw
    dataset at ``config.DATA_PATH``. It then plots summaries, builds a
    patient-wise split and saves the split to
    ``./preprocessed_data/dataset_splits.npz``.

    Returns:
        ``(all_patient_data, train_data, val_data, test_data)``, or ``None`` if
        no patient data could be loaded or processed.
    """
    print("Starting CHB-MIT Scalp EEG Dataset Preprocessing")
    print("=" * 60)

    preprocessor = CHBMITPreprocessor(config)
    analyzer = DataAnalyzer(config)

    preprocessed_dir = "./preprocessed_data"
    patient_suffix = '_preprocessed.npz'

    # Only per-patient files count as a cache. A directory holding nothing but
    # dataset_splits.npz (or stray files) must trigger fresh preprocessing.
    # Sorted so patients load in a stable order.
    cached_files = []
    if os.path.isdir(preprocessed_dir):
        cached_files = sorted(f for f in os.listdir(preprocessed_dir) if f.endswith(patient_suffix))

    if cached_files:
        print("Loading existing preprocessed data...")
        all_patient_data = {}
        for file in cached_files:
            patient_id = file[:-len(patient_suffix)]
            patient_data = preprocessor.load_preprocessed_data(patient_id)
            if patient_data is not None:
                all_patient_data[patient_id] = patient_data

        print(f"Loaded data for {len(all_patient_data)} patients")
    else:
        print("Preprocessing dataset from scratch...")
        all_patient_data = preprocessor.preprocess_complete_dataset()

    if not all_patient_data:
        print("No data processed. Please check dataset path and structure.")
        return None

    print("\nAnalyzing dataset...")
    analyzer.plot_data_distribution(all_patient_data)
    overall_stats = analyzer.compute_dataset_statistics(all_patient_data)

    first_patient = next(iter(all_patient_data))
    print(f"\nPlotting sample windows for {first_patient}...")
    analyzer.plot_sample_windows(all_patient_data[first_patient], first_patient)

    print("\nCreating dataset splits...")
    train_data, val_data, test_data = preprocessor.create_train_val_test_split(all_patient_data)

    print(f"Training patients: {len(train_data)}")
    print(f"Validation patients: {len(val_data)}")
    print(f"Test patients: {len(test_data)}")

    split_info = {
        'train_patients': list(train_data.keys()),
        'val_patients': list(val_data.keys()),
        'test_patients': list(test_data.keys())
    }

    os.makedirs(preprocessed_dir, exist_ok=True)
    np.savez('./preprocessed_data/dataset_splits.npz', **split_info)
    print("Dataset splits saved to ./preprocessed_data/dataset_splits.npz")

    print("\n" + "=" * 60)
    print("PREPROCESSING COMPLETED SUCCESSFULLY!")
    print("=" * 60)
    print(f"Total patients processed: {len(all_patient_data)}")
    print(f"Total windows: {overall_stats['total_windows']:,}")
    print("Dataset ready for model training!")

    return all_patient_data, train_data, val_data, test_data

# Execute main preprocessing. main() returns None when no data could be
# processed, so check before unpacking its 4-tuple.
if __name__ == "__main__":
    results = main()
    if results is not None:
        all_data, train_data, val_data, test_data = results