# In [None]:
import os
import re
from pathlib import Path

import mne
import numpy as np
import pandas as pd
from scipy.io import loadmat

# ================= CONFIGURATION =================
# Dataset root; the input and output folders below all live directly beneath it.
base_path = r"D:\impress_project\eeg_signals\data\LRMI-21679035\organized_data_v2"
raw_data_dir, preprocessed_dir, output_dir = (
    os.path.join(base_path, folder_name)
    for folder_name in ("raw_data", "paper_preprocessed", "labelled_paper_preprocessed")
)

# Only the output folder is written to, so it is the only one created up front.
os.makedirs(output_dir, exist_ok=True)

# ================= CHANNEL INFORMATION FROM PAPER =================
# Channel order in the recordings, from paper Table 2 (0-based indexing for Python):
# indices 0-29 are EEG, 30-31 are EOG, 32 is the event marker.
# Indices 2-4 are the frontal row (Fz/F3/F4); the parietal row (Pz/P3/P4) is 22-24.
# NOTE(review): index 7 is 'FC2' although the montage has no FC1 and FC3/FC4 sit on
# either side of it — confirm against Table 2; it may be the midline 'FCz'.
paper_channels = {
    0: 'Fp1', 1: 'Fp2', 2: 'Fz', 3: 'F3', 4: 'F4', 5: 'F7', 6: 'F8', 7: 'FC2',
    8: 'FC3', 9: 'FC4', 10: 'FT7', 11: 'FT8', 12: 'Cz', 13: 'C3', 14: 'C4',
    15: 'T3', 16: 'T4', 17: 'CPz', 18: 'CP3', 19: 'CP4', 20: 'TP7', 21: 'TP8',
    22: 'Pz', 23: 'P3', 24: 'P4', 25: 'T5', 26: 'T6', 27: 'Oz', 28: 'O1',
    29: 'O2', 30: 'HEOL', 31: 'VEOR', 32: 'Marker'
}

# Each electrode site must appear exactly once; a duplicated name would make two
# different rows of the data indistinguishable downstream.
if len(set(paper_channels.values())) != len(paper_channels):
    raise ValueError("paper_channels contains duplicate channel names")

# Classify channels
channels_all = [paper_channels[i] for i in range(33)]
channels_eeg = channels_all[:30]  # First 30 are EEG
channels_eog = channels_all[30:32]  # HEOL, VEOR
channels_marker = channels_all[32:]  # Marker channel

print(f"Channels All: {len(channels_all)}")
print(f"Channels EEG: {len(channels_eeg)}")
print(f"Channels EOG: {len(channels_eog)}")
print(f"Channels Marker: {len(channels_marker)}")

# ================= FUNCTION TO PROCESS ONE SUBJECT =================
def process_preprocessed_edf_file(subject_id, raw_data_dir, preprocessed_dir, output_dir, sampling_rate=500):
    """
    Cut one subject's preprocessed .edf recording into motor-imagery trials and save them.

    Trial labels are read from the subject's raw .mat file (``data['eeg'][0, 0]['label']``).
    The continuous signal is read from the preprocessed .edf file and split into
    consecutive fixed-length trials (2 s instruction + 4 s MI + 2 s break). Only the
    4 s MI window of each trial is kept. The result is written to
    ``<output_dir>/sub-<subject_id>_preprocessed_motor_imagery.npz``.

    Parameters:
    -----------
    subject_id : str
        Subject ID (e.g., '01', '02')
    raw_data_dir : str
        Directory containing raw .mat files (source of the trial labels)
    preprocessed_dir : str
        Directory containing preprocessed .edf files
    output_dir : str
        Output directory for saved results
    sampling_rate : int
        Expected sampling rate in Hz (500 from paper). Segmentation always uses the
        rate stored in the .edf header; a mismatch with this value only prints a warning.

    Returns:
    --------
    result_dict : dict or None
        Dictionary with all saved keys. Returns None if an input file is missing or
        processing fails; errors are printed, not raised.
    """
    print(f"\nProcessing preprocessed subject: {subject_id}")

    # ---- Labels: read from the raw .mat file ----
    mat_filename = f"sub-{subject_id}_task-motor-imagery_eeg.mat"
    mat_path = os.path.join(raw_data_dir, mat_filename)

    if not os.path.exists(mat_path):
        print(f"  ❌ Raw .mat file not found: {mat_path}")
        return None

    try:
        data = loadmat(mat_path)
        eeg_struct = data['eeg'][0, 0]
        labels = eeg_struct['label'].flatten()  # one label per trial; 1 = left hand, 2 = right hand

        print(f"  ✓ Loaded labels from .mat file: {labels.shape}")
        print(f"    Left hand (1): {np.sum(labels == 1)}, Right hand (2): {np.sum(labels == 2)}")

    except Exception as e:
        print(f"  ❌ Error loading .mat file for {subject_id}: {e}")
        return None

    if labels.size == 0:
        print(f"  ❌ No trial labels found in .mat file for {subject_id}")
        return None

    # ---- Signal: read from the preprocessed .edf file ----
    edf_filename = f"sub-{subject_id}_task-motor-imagery_eeg.edf"
    edf_path = os.path.join(preprocessed_dir, edf_filename)

    if not os.path.exists(edf_path):
        print(f"  ❌ Preprocessed .edf file not found: {edf_path}")
        return None

    try:
        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
        data_edf, _ = raw[:, :]  # (n_channels, n_samples); the times vector is not needed
        file_sampling_rate = raw.info['sfreq']

        print(f"  ✓ Loaded .edf data: {data_edf.shape}")
        print(f"    Sampling rate: {file_sampling_rate} Hz")
        print(f"    Duration: {raw.times[-1]:.2f} seconds")

        if not np.isclose(file_sampling_rate, sampling_rate):
            print(f"  ⚠️ File sampling rate ({file_sampling_rate} Hz) differs from expected "
                  f"{sampling_rate} Hz; segmenting with the file's rate")
        if data_edf.shape[0] != len(channels_all):
            print(f"  ⚠️ Expected {len(channels_all)} channels but the file has {data_edf.shape[0]}; "
                  f"saved channel-name lists may not match the data")

        # ================= SEGMENT THE CONTINUOUS DATA =================
        # Based on debugging, the .edf files contain CONTINUOUS data, e.g. shape
        # (33, 160000) = 320 s at 500 Hz, i.e. 40 trials × 8 s
        # (2 s instruction + 4 s MI + 2 s break) laid end to end.
        # Only the 4 s MI portion of each trial is kept.
        # The trial count follows the labels so that segments and labels always align.
        n_trials = len(labels)

        # Time parameters (in seconds)
        instruction_duration = 2.0  # leading part of each trial, dropped
        mi_duration = 4.0  # motor-imagery part, kept
        break_duration = 2.0  # trailing part of each trial, dropped
        trial_duration = instruction_duration + mi_duration + break_duration  # 8 seconds

        # Convert to samples
        instruction_samples = int(instruction_duration * file_sampling_rate)
        mi_samples = int(mi_duration * file_sampling_rate)
        trial_samples = int(trial_duration * file_sampling_rate)

        print(f"  ✓ Segmenting continuous data:")
        print(f"    - Total trials: {n_trials}")
        print(f"    - Trial duration: {trial_duration}s = {trial_samples} samples")
        print(f"    - MI portion: {mi_duration}s = {mi_samples} samples")
        print(f"    - Total data available: {data_edf.shape[1]} samples")

        # Check if we have enough data
        total_samples_needed = n_trials * trial_samples
        if data_edf.shape[1] < total_samples_needed:
            print(f"  ❌ Not enough data for {n_trials} trials")
            print(f"    Need: {total_samples_needed} samples")
            print(f"    Have: {data_edf.shape[1]} samples")
            return None

        # Sample index at which each trial's MI window begins
        mi_starts = np.arange(n_trials) * trial_samples + instruction_samples
        # Shape: (n_trials, n_channels, mi_samples)
        segments_mi_all = np.stack([data_edf[:, start:start + mi_samples] for start in mi_starts])

        # Extract EEG channels only (the first len(channels_eeg) rows)
        n_eeg_channels = min(len(channels_eeg), segments_mi_all.shape[1])
        segments_mi_eeg = segments_mi_all[:, :n_eeg_channels, :]

        print(f"  ✓ Segmented data:")
        print(f"    - All channels shape: {segments_mi_all.shape}")
        print(f"    - EEG only shape: {segments_mi_eeg.shape}")

        # One row per MI window; labels other than 1/2 are reported as 'unknown'
        # instead of being silently counted as right-hand trials.
        hand_by_label = {1: "left", 2: "right"}
        event_info_df = pd.DataFrame([
            {
                'trial_idx': trial_idx,
                'onset_sec': trial_idx * trial_duration + instruction_duration,
                'duration_sec': mi_duration,
                'label': int(label),
                'hand': hand_by_label.get(int(label), "unknown"),
                'condition': 'motor_imagery',
                'sample_start': int(start),
                'sample_end': int(start) + mi_samples,
            }
            for trial_idx, (label, start) in enumerate(zip(labels, mi_starts))
        ])

        # Create result dictionary with MI data only
        result_dict = {
            'segments_mi_all': segments_mi_all.astype(np.float32),  # Motor Imagery only, all channels
            'segments_mi_eeg': segments_mi_eeg.astype(np.float32),  # Motor Imagery only, EEG channels
            'labels': labels.astype(np.int32),
            'channels_all': channels_all,
            'channels_eeg': channels_eeg,
            'channels_eog': channels_eog,
            'channels_marker': channels_marker,
            'sampling_rate': file_sampling_rate,
            'subject_id': f"sub-{subject_id}",
            # List of plain dicts; np.savez stores it as a pickled object array
            'event_info': event_info_df.to_dict('records'),
            'mi_info': {
                'mi_duration_samples': mi_samples,
                'mi_duration_seconds': mi_duration,
                'trial_duration_seconds': trial_duration,
                'n_trials': n_trials
            }
        }

        # Save to .npz file (object entries require allow_pickle=True when loading)
        output_path = os.path.join(output_dir, f"sub-{subject_id}_preprocessed_motor_imagery.npz")
        np.savez_compressed(output_path, **result_dict)

        print(f"  ✅ Saved to: {output_path}")

        # Print summary
        left_count = np.sum(labels == 1)
        right_count = np.sum(labels == 2)
        print(f"  📊 Summary: {left_count} left hand, {right_count} right hand trials")

        return result_dict

    except Exception as e:
        # Best-effort batch processing: report the failure and let the caller move on.
        print(f"  ❌ Error processing preprocessed file for {subject_id}: {e}")
        import traceback
        traceback.print_exc()
        return None

# ================= PROCESS ALL SUBJECTS =================
def process_all_preprocessed_subjects(raw_data_dir, preprocessed_dir, output_dir, max_subjects=None):
    """
    Process every subject that has a preprocessed .edf file.

    Subjects are discovered from files named
    ``sub-<digits>_task-motor-imagery_eeg.edf`` in ``preprocessed_dir``, which is the
    exact name ``process_preprocessed_edf_file`` reads. Other ``.edf`` files are
    reported and skipped instead of crashing the numeric sort.

    Parameters:
    -----------
    raw_data_dir : str
        Directory containing raw .mat files
    preprocessed_dir : str
        Directory containing preprocessed .edf files
    output_dir : str
        Output directory for saved results
    max_subjects : int or None
        Process only the first ``max_subjects`` subjects in numeric order (for
        testing); None processes all of them.

    Returns:
    --------
    dict
        Maps subject ID (e.g. '01') to the result dict of each subject that was
        processed successfully.
    """
    # Captures the numeric subject ID, e.g. '01' from 'sub-01_task-motor-imagery_eeg.edf'
    name_pattern = re.compile(r"^sub-(\d+)_task-motor-imagery_eeg\.edf$")

    subject_ids = []
    for file_name in os.listdir(preprocessed_dir):
        match = name_pattern.match(file_name)
        if match:
            subject_ids.append(match.group(1))
        elif file_name.endswith('.edf'):
            print(f"Skipping .edf file with unexpected name: {file_name}")

    # Numeric order, so 'sub-10' follows 'sub-9' even without zero padding
    subject_ids.sort(key=int)

    if max_subjects is not None:
        subject_ids = subject_ids[:max_subjects]

    print(f"Found {len(subject_ids)} preprocessed .edf files")
    print(f"Output directory: {output_dir}")

    all_results = {}
    for subject_id in subject_ids:
        result = process_preprocessed_edf_file(subject_id, raw_data_dir, preprocessed_dir, output_dir)
        # None signals a failed subject; keep only successful results
        if result is not None:
            all_results[subject_id] = result

    return all_results

# ================= VERIFICATION FUNCTION =================
def verify_saved_preprocessed_file(subject_id, output_dir):
    """
    Check that a saved subject file contains every required key and print a summary.

    Parameters:
    -----------
    subject_id : str
        Subject ID without the ``sub-`` prefix (e.g. '01')
    output_dir : str
        Directory holding ``sub-<subject_id>_preprocessed_motor_imagery.npz``

    Returns:
    --------
    bool
        True if the file exists, loads, contains all required keys and its EEG
        segments are 3-D. False otherwise; problems are printed, not raised.
    """
    file_path = os.path.join(output_dir, f"sub-{subject_id}_preprocessed_motor_imagery.npz")

    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        return False

    required_keys = [
        'segments_mi_all', 'segments_mi_eeg', 'labels', 'channels_all',
        'channels_eeg', 'channels_eog', 'channels_marker',
        'sampling_rate', 'subject_id', 'event_info', 'mi_info'
    ]

    try:
        # allow_pickle is needed for the object entries (event_info, mi_info).
        # The context manager closes the archive even if a check below raises.
        with np.load(file_path, allow_pickle=True) as data:
            print(f"\nVerifying {os.path.basename(file_path)}:")
            print("-" * 50)

            all_present = True
            for key in required_keys:
                if key not in data:
                    print(f"✗ {key}: MISSING!")
                    all_present = False
                    continue

                value = data[key]
                # Dicts (e.g. mi_info) come back as 0-d object arrays; unwrap them
                if value.dtype == object and value.shape == ():
                    value = value.item()

                if isinstance(value, dict):
                    print(f"✓ {key}: dict")
                elif hasattr(value, 'shape'):
                    print(f"✓ {key}: shape={value.shape}, dtype={value.dtype}")
                else:
                    print(f"✓ {key}: {type(value)}")

            # Print sample information
            if 'labels' in data:
                labels = data['labels']
                print("\nLabel distribution:")
                print(f"  Left hand (1): {np.sum(labels == 1)}")
                print(f"  Right hand (2): {np.sum(labels == 2)}")

            if 'segments_mi_eeg' in data:
                segments = data['segments_mi_eeg']
                print("\nMotor Imagery Segment information:")
                print(f"  Shape: {segments.shape}")
                if segments.ndim != 3:
                    print(f"✗ segments_mi_eeg: expected (trials, channels, samples), got {segments.ndim} dims")
                    all_present = False
                else:
                    print(f"  Trials: {segments.shape[0]}")
                    print(f"  Channels: {segments.shape[1]}")
                    print(f"  Samples per trial: {segments.shape[2]}")
                    if 'sampling_rate' in data:
                        print(f"  Duration per trial: {segments.shape[2] / data['sampling_rate']:.1f}s")

            return all_present

    except Exception as e:
        print(f"Error verifying file: {e}")
        return False

# ================= MAIN EXECUTION =================
if __name__ == "__main__":
    print("=" * 60)
    print("PROCESSING PREPROCESSED .EDF FILES - MOTOR IMAGERY DATA ONLY")
    print("=" * 60)

    # Dry run on the first two subjects so problems surface before the full batch
    print("\nTesting with first 2 subjects...")
    test_results = process_all_preprocessed_subjects(raw_data_dir, preprocessed_dir, output_dir, max_subjects=2)

    # Re-open the files just written to confirm their contents
    if test_results:
        print("\n" + "=" * 60)
        print("VERIFYING SAVED FILES")
        print("=" * 60)

        for subject_id in test_results:
            verify_saved_preprocessed_file(subject_id, output_dir)

    # Ask if user wants to process all subjects
    print("\n" + "=" * 60)
    response = input("Process ALL preprocessed subjects? (y/n): ")

    # Tolerate surrounding whitespace and a spelled-out "yes"
    if response.strip().lower() in ("y", "yes"):
        print("\nProcessing all preprocessed subjects...")
        all_results = process_all_preprocessed_subjects(raw_data_dir, preprocessed_dir, output_dir)

        print("\n" + "=" * 60)
        print("PROCESSING COMPLETE!")
        print("=" * 60)
        print(f"Results saved in: {output_dir}")
        print(f"Subjects processed successfully: {len(all_results)}")
        print("Files created:")

        # List only the files written in this run, not older .npz files in the folder
        for subject_id in all_results:
            file_name = f"sub-{subject_id}_preprocessed_motor_imagery.npz"
            file_path = os.path.join(output_dir, file_name)
            file_size = os.path.getsize(file_path) / 1024 / 1024
            print(f"  {file_name} ({file_size:.2f} MB)")

    else:
        print("Processing stopped. Only test files were created.")