# In [None]:
"""
EEG Person Identification - Preprocessing Pipeline
PhysioNet Motor Movement/Imagery Dataset
Author: [Your Name]
Date: 2025

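This notebook handles:
1. Loading EDF files from the PhysioNet dataset
2. Band-pass filtering (8-30 Hz, mu and beta bands)
3. Epoch extraction around the motor-imagery task events
4. Artifact removal (amplitude-based rejection of bad epochs)
5. STFT spectrogram generation
6. Per-subject data normalization and saving to HDF5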
"""

#%% Import Required Libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import signal
from scipy.signal import butter, filtfilt, stft
import mne
from mne.io import read_raw_edf
import h5py
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Seed NumPy's legacy global RNG once, so that anything drawn from it later
# in this notebook (e.g. the epoch picked for plotting) is reproducible.
np.random.seed(42)

print("Libraries imported successfully!")
print("MNE version: " + str(mne.__version__))

#%% Configuration and Parameters
class Config:
    """Central container for every preprocessing setting.

    All settings are class attributes, so they can be read from the class
    itself or from an instance. Note that defining the class also creates the
    two output directories (see the end of the class body).
    """

    # --- Filesystem locations ------------------------------------------------
    RAW_DATA_DIR = './data/raw/files/'        # extracted PhysioNet EDF files
    PROCESSED_DATA_DIR = './data/processed/'  # HDF5 output and figures
    SPECTROGRAM_DIR = './data/spectrograms/'

    # --- Dataset layout --------------------------------------------------------
    N_SUBJECTS = 109
    N_CHANNELS = 64
    # Motor-imagery runs:
    #   4, 8, 12  -> imagine opening/closing the left vs right fist
    #   6, 10, 14 -> imagine both fists vs both feet
    TASK_RUNS = [4, 6, 8, 10, 12, 14]

    # --- Band-pass filter --------------------------------------------------------
    SAMPLING_RATE = 160  # Hz, nominal sampling rate of the recordings
    LOWCUT = 8.0         # Hz, lower edge (start of the mu band)
    HIGHCUT = 30.0       # Hz, upper edge (end of the beta band)
    FILTER_ORDER = 5

    # --- Epoching ------------------------------------------------------------------
    EPOCH_DURATION = 3.0                                 # seconds
    EPOCH_SAMPLES = int(EPOCH_DURATION * SAMPLING_RATE)  # 3.0 s * 160 Hz = 480

    # --- STFT / spectrogram -----------------------------------------------------
    NPERSEG = 64   # samples per STFT window
    NOVERLAP = 32  # samples shared by consecutive windows
    NFFT = 128     # FFT length (windows are zero-padded up to this size)

    # Side effect at class-definition time: make sure output folders exist.
    os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
    os.makedirs(SPECTROGRAM_DIR, exist_ok=True)

# Shared configuration instance used by the rest of the notebook.
config = Config()

# One print call; the embedded newlines reproduce the original four lines.
print(
    "\nConfiguration loaded:\n"
    f"  Sampling Rate: {config.SAMPLING_RATE} Hz\n"
    f"  Filter Band: {config.LOWCUT}-{config.HIGHCUT} Hz\n"
    f"  Epoch Duration: {config.EPOCH_DURATION} seconds ({config.EPOCH_SAMPLES} samples)"
)

#%% Helper Functions

def bandpass_filter(data, lowcut, highcut, fs, order=5, axis=-1):
    """
    Apply a zero-phase Butterworth band-pass filter to EEG data.

    The filter is designed in second-order-sections (SOS) form and run
    forward and backward with ``scipy.signal.sosfiltfilt``. SOS form is
    numerically more robust than the (b, a) transfer-function form for
    band-pass designs. The forward-backward pass removes phase distortion,
    and it also doubles the effective filter order.

    Parameters:
    -----------
    data : ndarray, shape (n_channels, n_samples)
        Raw EEG data. Any shape works as long as ``axis`` is the time axis.
    lowcut : float
        Lower frequency bound (Hz). Must satisfy 0 < lowcut < highcut.
    highcut : float
        Upper frequency bound (Hz). Must be below the Nyquist frequency fs / 2.
    fs : float
        Sampling frequency (Hz)
    order : int
        Order of the Butterworth prototype
    axis : int, default -1
        Time axis. For 2-D input this is axis 1, as in earlier versions.

    Returns:
    --------
    filtered_data : ndarray
        Band-pass filtered data with the same shape as ``data``.

    Raises:
    -------
    ValueError
        If the cut-off frequencies are not 0 < lowcut < highcut < fs / 2.
    """
    nyquist = 0.5 * fs
    if not 0 < lowcut < highcut < nyquist:
        raise ValueError(
            f"Invalid band {lowcut}-{highcut} Hz for fs={fs} Hz: "
            f"need 0 < lowcut < highcut < {nyquist}"
        )
    sos = butter(order, [lowcut / nyquist, highcut / nyquist],
                 btype='band', output='sos')
    return signal.sosfiltfilt(sos, data, axis=axis)

def extract_epochs(raw, events, epoch_duration, event_ids):
    """
    Extract fixed-length epochs that start at each task event.

    Parameters:
    -----------
    raw : mne.io.Raw
        Raw (already filtered) EEG object
    events : ndarray
        MNE event array (sample, previous value, event code)
    epoch_duration : float
        Duration of each epoch in seconds
    event_ids : dict
        Mapping of event name -> event code; only these events are epoched

    Returns:
    --------
    epochs_data : ndarray, shape (n_epochs, n_channels, n_samples)
        Extracted epochs, with n_samples = round(epoch_duration * sfreq).
    labels : ndarray
        Event code of each epoch
    """
    # MNE includes the sample at tmax. Stopping one sample early yields
    # exactly epoch_duration * sfreq samples (e.g. 480 rather than 481 for
    # 3 s at 160 Hz).
    tmin = 0.0
    tmax = epoch_duration - 1.0 / raw.info['sfreq']
    epochs = mne.Epochs(raw, events, event_id=event_ids, tmin=tmin, tmax=tmax,
                        baseline=None, preload=True, verbose=False)

    epochs_data = epochs.get_data()  # Shape: (n_epochs, n_channels, n_samples)
    # Read labels from epochs.events (not the input array) so they stay
    # aligned with the epochs MNE actually kept.
    labels = epochs.events[:, -1]

    return epochs_data, labels

def compute_spectrogram(epoch_data, fs, nperseg=64, noverlap=32, nfft=128):
    """
    Compute the STFT magnitude spectrogram of every channel of one epoch.

    All channels are transformed in a single vectorized
    ``scipy.signal.stft`` call along the last (time) axis.

    Parameters:
    -----------
    epoch_data : ndarray, shape (n_channels, n_samples)
        Single epoch data
    fs : float
        Sampling frequency
    nperseg : int
        Length of each segment for STFT
    noverlap : int
        Number of overlapping samples
    nfft : int
        Number of FFT points

    Returns:
    --------
    spectrogram : ndarray of float64, shape (n_channels, n_frequencies, n_time_bins)
        Magnitude |STFT| of each channel (not squared power), with
        n_frequencies = nfft // 2 + 1 for real input.
    """
    # stft transforms along axis=-1 and returns Zxx with shape
    # (n_channels, n_frequencies, n_time_bins).
    _, _, zxx = stft(epoch_data, fs=fs, nperseg=nperseg,
                     noverlap=noverlap, nfft=nfft, axis=-1)
    # Always return float64, even for float32 input.
    return np.abs(zxx).astype(np.float64, copy=False)

def normalize_data(data, method='zscore'):
    """
    Normalize data independently for each channel (axis 1).

    Statistics are pooled over every axis except axis 1. Epoch arrays of
    shape (n_epochs, n_channels, n_samples) and spectrogram arrays of shape
    (n_epochs, n_channels, n_freqs, n_time_bins) are therefore both scaled
    per channel. For 3-D input this is identical to pooling over axes (0, 2).

    Parameters:
    -----------
    data : ndarray
        Data to normalize; axis 1 is the channel axis.
    method : str
        'zscore' -> (x - mean) / (std + 1e-8)
        'minmax' -> (x - min) / (max - min + 1e-8)
        Any other value returns ``data`` unchanged.

    Returns:
    --------
    normalized_data : ndarray
        Normalized data (same shape as the input)
    """
    if method not in ('zscore', 'minmax'):
        # Unknown method: pass the input through untouched.
        return data

    data = np.asarray(data)
    # Reduce over everything except the channel axis. A hard-coded (0, 2)
    # would wrongly keep 4-D spectrograms separate per time bin.
    reduce_axes = tuple(ax for ax in range(data.ndim) if ax != 1)

    if method == 'zscore':
        mean = np.mean(data, axis=reduce_axes, keepdims=True)
        std = np.std(data, axis=reduce_axes, keepdims=True) + 1e-8
        return (data - mean) / std

    min_val = np.min(data, axis=reduce_axes, keepdims=True)
    max_val = np.max(data, axis=reduce_axes, keepdims=True)
    return (data - min_val) / (max_val - min_val + 1e-8)

def remove_bad_epochs(epochs_data, threshold=5.0):
    """
    Identify epochs that are free of extreme-amplitude artifacts.

    Each epoch is summarised by its largest per-channel peak-to-peak
    amplitude. An epoch is rejected when the robust z-score of that value is
    at least ``threshold`` in absolute value.

    The robust z-score uses the median and the median absolute deviation
    (MAD, scaled by 1.4826 to match the standard deviation of normal data)
    instead of the mean and standard deviation. With the classic z-score the
    outlier inflates the std it is divided by, and with n epochs no z-score
    can exceed sqrt(n - 1). For runs with fewer than 26 epochs, a threshold
    of 5 could therefore never reject anything.

    Parameters:
    -----------
    epochs_data : ndarray, shape (n_epochs, n_channels, n_samples)
        Epoch data
    threshold : float
        Robust z-score threshold for artifact detection

    Returns:
    --------
    clean_indices : ndarray of int
        Indices of clean epochs, in increasing order
    """
    epochs_data = np.asarray(epochs_data)
    if epochs_data.shape[0] == 0:
        return np.arange(0)

    # Peak-to-peak amplitude per (epoch, channel), then the worst channel
    ptp = np.ptp(epochs_data, axis=2)  # Shape: (n_epochs, n_channels)
    max_ptp = np.max(ptp, axis=1)      # Shape: (n_epochs,)

    center = np.median(max_ptp)
    spread = 1.4826 * np.median(np.abs(max_ptp - center))
    if spread == 0:
        # At least half the epochs share the same amplitude; fall back to std.
        spread = np.std(max_ptp)
    if spread == 0:
        # All amplitudes identical: nothing stands out as an artifact.
        return np.arange(len(max_ptp))

    z_scores = (max_ptp - center) / spread

    # Keep epochs below threshold
    return np.where(np.abs(z_scores) < threshold)[0]

#%% Load and Process Single Subject

def process_subject(subject_id, config):
    """
    Process all motor-imagery runs for a single subject.

    For every run in ``config.TASK_RUNS`` the EDF file is:
    - loaded;
    - resampled to ``config.SAMPLING_RATE`` if needed;
    - band-pass filtered (EEG channels only);
    - cut into fixed-length epochs starting at the T1/T2 task annotations;
    - cleaned of extreme-amplitude epochs;
    - converted to STFT spectrograms.
    Missing or unreadable runs are reported and skipped.

    Parameters:
    -----------
    subject_id : int
        Subject ID (1-109)
    config : Config
        Configuration object

    Returns:
    --------
    subject_data : dict or None
        ``None`` if no run produced any epoch. Otherwise a dict with keys:
        'subject_id';
        'epochs' (n_epochs, n_channels, n_samples);
        'spectrograms' (n_epochs, n_channels, n_freqs, n_time_bins);
        'tasks' (1 = T1, 2 = T2);
        'n_epochs'.
        Epochs and spectrograms are z-scored per channel within the subject.
    """
    print(f"\nProcessing Subject {subject_id:03d}...")

    # Annotation description -> integer label. T0 (rest) is deliberately not
    # listed, so only the two motor-imagery classes are epoched.
    event_ids = {'T1': 1, 'T2': 2}

    all_epochs = []
    all_spectrograms = []
    all_tasks = []

    for run in config.TASK_RUNS:
        # Construct filename (format: S001R04.edf)
        filename = f"S{subject_id:03d}R{run:02d}.edf"
        filepath = os.path.join(config.RAW_DATA_DIR, filename)

        if not os.path.exists(filepath):
            print(f"  Warning: {filename} not found, skipping...")
            continue

        try:
            raw = read_raw_edf(filepath, preload=True, verbose=False)

            # Bring every run to the configured rate. This keeps the filter
            # design, the epoch length and the STFT grid identical across
            # runs and subjects; otherwise the final concatenation fails.
            if raw.info['sfreq'] != config.SAMPLING_RATE:
                raw.resample(config.SAMPLING_RATE, verbose=False)

            # MNE's EDF reader exposes the task cues (T0/T1/T2) as annotations
            # rather than a 'STI 014' stim channel, so read events from them.
            # Events are taken from the original object because the filtered
            # RawArray built below does not carry the annotations.
            events, _ = mne.events_from_annotations(raw, event_id=event_ids,
                                                    verbose=False)

            if len(events) == 0:
                print(f"  Warning: No events found in {filename}, skipping...")
                continue

            # mne.Epochs raises for event IDs that never occur, so only pass
            # the classes actually present in this run.
            run_event_ids = {name: code for name, code in event_ids.items()
                             if np.any(events[:, 2] == code)}

            # Filter EEG channels only. A stim/status channel, if the reader
            # created one, is not EEG and would be corrupted by filtering.
            picks = mne.pick_types(raw.info, eeg=True, exclude=[])
            filtered_data = bandpass_filter(raw.get_data(picks=picks),
                                            config.LOWCUT, config.HIGHCUT,
                                            raw.info['sfreq'], config.FILTER_ORDER)
            raw_filtered = mne.io.RawArray(filtered_data,
                                           mne.pick_info(raw.info, picks),
                                           first_samp=raw.first_samp,
                                           verbose=False)

            epochs_data, labels = extract_epochs(raw_filtered, events,
                                                 config.EPOCH_DURATION,
                                                 run_event_ids)

            # Pin the epoch length to EPOCH_SAMPLES. MNE's tmax is inclusive
            # and can otherwise add one trailing sample.
            epochs_data = epochs_data[..., :config.EPOCH_SAMPLES]

            # Remove bad epochs
            clean_indices = remove_bad_epochs(epochs_data, threshold=5.0)
            epochs_data = epochs_data[clean_indices]
            labels = labels[clean_indices]

            print(f"  Run {run:02d}: {len(epochs_data)} clean epochs extracted")

            # One spectrogram per clean epoch
            all_spectrograms.extend(
                compute_spectrogram(epoch, config.SAMPLING_RATE,
                                    config.NPERSEG, config.NOVERLAP, config.NFFT)
                for epoch in epochs_data
            )
            all_epochs.append(epochs_data)
            all_tasks.extend(labels)

        except Exception as e:
            # Best effort: a corrupt or unreadable run is reported and skipped
            # so the remaining runs and subjects can still be processed.
            print(f"  Error processing {filename}: {str(e)}")

    if not all_epochs:
        return None

    # Concatenate all epochs
    all_epochs = np.concatenate(all_epochs, axis=0)
    if len(all_epochs) == 0:
        # Runs were read but none yielded an epoch; nothing to normalize.
        return None
    all_spectrograms = np.array(all_spectrograms)
    all_tasks = np.array(all_tasks)

    # Normalize data
    all_epochs = normalize_data(all_epochs, method='zscore')
    all_spectrograms = normalize_data(all_spectrograms, method='zscore')

    print(f"  Total: {len(all_epochs)} epochs, {len(all_spectrograms)} spectrograms")
    print(f"  Epoch shape: {all_epochs.shape}")
    print(f"  Spectrogram shape: {all_spectrograms.shape}")

    return {
        'subject_id': subject_id,
        'epochs': all_epochs,
        'spectrograms': all_spectrograms,
        'tasks': all_tasks,
        'n_epochs': len(all_epochs)
    }

#%% Process All Subjects

def process_all_subjects(config, start_subject=1, end_subject=None):
    """
    Run ``process_subject`` for a contiguous, inclusive range of subject IDs.

    Parameters:
    -----------
    config : Config
        Configuration object
    start_subject : int
        First subject ID to process (inclusive)
    end_subject : int or None
        Last subject ID to process (inclusive). ``None`` (the default) means
        ``config.N_SUBJECTS``.

    Returns:
    --------
    all_data : list of dict
        Results of ``process_subject`` for every subject that produced data.
    failed_subjects : list of int
        IDs for which ``process_subject`` returned ``None``.
    """
    if end_subject is None:
        end_subject = config.N_SUBJECTS

    print("\n" + "="*60)
    print("PROCESSING ALL SUBJECTS")
    print("="*60)

    all_data = []
    failed_subjects = []

    for subject_id in tqdm(range(start_subject, end_subject + 1),
                           desc="Processing subjects"):
        subject_data = process_subject(subject_id, config)

        if subject_data is not None:
            all_data.append(subject_data)
        else:
            failed_subjects.append(subject_id)

    print(f"\n{'='*60}")
    print("Processing Complete!")
    print(f"  Successful: {len(all_data)} subjects")
    print(f"  Failed: {len(failed_subjects)} subjects")
    if failed_subjects:
        print(f"  Failed IDs: {failed_subjects}")

    return all_data, failed_subjects

#%% Save Processed Data

def save_processed_data(all_data, config):
    """
    Concatenate per-subject results and save them to one HDF5 file.

    The file ``<PROCESSED_DATA_DIR>/eeg_processed_data.h5`` contains the
    gzip-compressed datasets 'X_epochs', 'X_spectrograms', 'y_subjects' and
    'y_tasks'. It also stores the preprocessing parameters as file attributes.
    'n_channels' is taken from the saved data itself, so it always matches the
    arrays.

    Parameters:
    -----------
    all_data : list
        List of subject data dictionaries (as returned by process_subject)
    config : Config
        Configuration object

    Returns:
    --------
    X_epochs : ndarray, shape (n_epochs, n_channels, n_samples)
    X_spectrograms : ndarray, shape (n_epochs, n_channels, n_freqs, n_time_bins)
    y_subjects : ndarray, shape (n_epochs,)
        Subject ID of every epoch
    y_tasks : ndarray, shape (n_epochs,)
        Task label of every epoch

    Raises:
    -------
    ValueError
        If ``all_data`` is empty.
    """
    if not all_data:
        raise ValueError("save_processed_data: all_data is empty, nothing to save")

    print("\n" + "="*60)
    print("SAVING PROCESSED DATA")
    print("="*60)

    # Prepare data for saving
    X_epochs = []
    X_spectrograms = []
    y_subjects = []
    y_tasks = []

    for subject_data in all_data:
        n_epochs = subject_data['n_epochs']
        subject_id = subject_data['subject_id']

        X_epochs.append(subject_data['epochs'])
        X_spectrograms.append(subject_data['spectrograms'])
        # One subject label per epoch of this subject
        y_subjects.extend([subject_id] * n_epochs)
        y_tasks.extend(subject_data['tasks'])

    # Concatenate all data
    X_epochs = np.concatenate(X_epochs, axis=0)
    X_spectrograms = np.concatenate(X_spectrograms, axis=0)
    y_subjects = np.array(y_subjects)
    y_tasks = np.array(y_tasks)
    n_subjects = len(np.unique(y_subjects))

    print(f"\nDataset Summary:")
    print(f"  Total epochs: {len(X_epochs)}")
    print(f"  Epochs shape: {X_epochs.shape}")
    print(f"  Spectrograms shape: {X_spectrograms.shape}")
    print(f"  Unique subjects: {n_subjects}")

    # Save to HDF5
    output_file = os.path.join(config.PROCESSED_DATA_DIR, 'eeg_processed_data.h5')

    with h5py.File(output_file, 'w') as hf:
        hf.create_dataset('X_epochs', data=X_epochs, compression='gzip')
        hf.create_dataset('X_spectrograms', data=X_spectrograms, compression='gzip')
        hf.create_dataset('y_subjects', data=y_subjects, compression='gzip')
        hf.create_dataset('y_tasks', data=y_tasks, compression='gzip')

        # Save metadata describing the stored arrays and how they were made
        hf.attrs['n_subjects'] = n_subjects
        hf.attrs['n_epochs'] = len(X_epochs)
        hf.attrs['n_channels'] = X_epochs.shape[1]
        hf.attrs['sampling_rate'] = config.SAMPLING_RATE
        hf.attrs['epoch_duration'] = config.EPOCH_DURATION
        hf.attrs['lowcut'] = config.LOWCUT
        hf.attrs['highcut'] = config.HIGHCUT
        hf.attrs['filter_order'] = config.FILTER_ORDER
        hf.attrs['nperseg'] = config.NPERSEG
        hf.attrs['noverlap'] = config.NOVERLAP
        hf.attrs['nfft'] = config.NFFT

    print(f"\n✓ Data saved to: {output_file}")
    print(f"  File size: {os.path.getsize(output_file) / (1024**3):.2f} GB")

    return X_epochs, X_spectrograms, y_subjects, y_tasks

#%% Visualization Functions

def visualize_sample_data(X_epochs, X_spectrograms, y_subjects, config):
    """
    Plot an overview of one random processed epoch and of the whole dataset.

    The 3x2 figure shows:
    (1) up to the first 8 channels of a randomly chosen epoch;
    (2) that epoch's channel-1 spectrogram;
    (3) the number of epochs per subject;
    (4) the FFT amplitude spectrum of channel 1, with the band-pass cut-offs
        marked;
    (5) the spectrogram averaged over channels;
    (6) the channel correlation matrix.
    The figure is saved to
    ``<PROCESSED_DATA_DIR>/preprocessing_visualization.png`` at 300 dpi and
    then shown.

    Parameters:
    -----------
    X_epochs : ndarray, shape (n_epochs, n_channels, n_samples)
    X_spectrograms : ndarray, shape (n_epochs, n_channels, n_freqs, n_time_bins)
    y_subjects : ndarray, shape (n_epochs,)
    config : Config
        Supplies SAMPLING_RATE, LOWCUT, HIGHCUT and PROCESSED_DATA_DIR.
    """
    print("\n" + "="*60)
    print("CREATING VISUALIZATIONS")
    print("="*60)

    # Select a random sample
    sample_idx = np.random.randint(0, len(X_epochs))
    sample_epoch = X_epochs[sample_idx]
    sample_spec = X_spectrograms[sample_idx]
    sample_subject = y_subjects[sample_idx]

    # Derive axes from the data rather than config.EPOCH_SAMPLES: the epoch
    # length produced by MNE does not necessarily equal it.
    n_channels, n_samples = sample_epoch.shape
    fs = config.SAMPLING_RATE

    plt.figure(figsize=(16, 10))

    # Plot 1: EEG signals of the first (up to 8) channels, vertically offset
    plt.subplot(3, 2, 1)
    time = np.arange(n_samples) / fs
    for ch in range(min(8, n_channels)):
        plt.plot(time, sample_epoch[ch] + ch*2, label=f'Ch{ch+1}', alpha=0.7)
    plt.xlabel('Time (s)')
    plt.ylabel('Channel')
    plt.title(f'Sample EEG Epoch (Subject {sample_subject})')
    plt.grid(True, alpha=0.3)

    # Plot 2: Spectrogram for one channel
    plt.subplot(3, 2, 2)
    plt.imshow(sample_spec[0], aspect='auto', origin='lower', cmap='viridis')
    plt.colorbar(label='Power')
    plt.xlabel('Time bins')
    plt.ylabel('Frequency bins')
    plt.title('Sample Spectrogram (Channel 1)')

    # Plot 3: Distribution of subjects
    plt.subplot(3, 2, 3)
    subject_counts = pd.Series(y_subjects).value_counts().sort_index()
    plt.bar(subject_counts.index, subject_counts.values, alpha=0.7, color='steelblue')
    plt.xlabel('Subject ID')
    plt.ylabel('Number of Epochs')
    plt.title('Epochs per Subject')
    plt.grid(True, alpha=0.3)

    # Plot 4: Amplitude spectrum of channel 1 with the filter band marked
    plt.subplot(3, 2, 4)
    freqs = np.fft.rfftfreq(n_samples, 1/fs)
    spectrum = np.abs(np.fft.rfft(sample_epoch[0]))
    plt.plot(freqs, spectrum, color='coral', linewidth=2)
    plt.xlim(0, 50)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Power')
    plt.title('Frequency Spectrum (Channel 1)')
    plt.grid(True, alpha=0.3)
    plt.axvline(config.LOWCUT, color='red', linestyle='--', alpha=0.5, label='Filter cutoff')
    plt.axvline(config.HIGHCUT, color='red', linestyle='--', alpha=0.5)
    plt.legend()

    # Plot 5: Average spectrogram across channels
    plt.subplot(3, 2, 5)
    avg_spec = np.mean(sample_spec, axis=0)
    plt.imshow(avg_spec, aspect='auto', origin='lower', cmap='plasma')
    plt.colorbar(label='Averaged Power')
    plt.xlabel('Time bins')
    plt.ylabel('Frequency bins')
    plt.title('Average Spectrogram (All Channels)')

    # Plot 6: Channel correlation
    plt.subplot(3, 2, 6)
    corr_matrix = np.corrcoef(sample_epoch)
    plt.imshow(corr_matrix, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')
    plt.colorbar(label='Correlation')
    plt.xlabel('Channel')
    plt.ylabel('Channel')
    plt.title('Channel Correlation Matrix')

    plt.tight_layout()
    plt.savefig(os.path.join(config.PROCESSED_DATA_DIR, 'preprocessing_visualization.png'),
                dpi=300, bbox_inches='tight')
    print("\n✓ Visualization saved!")
    plt.show()

#%% Main Execution

if __name__ == "__main__":
    print("\n" + "=" * 60)
    print("EEG PERSON IDENTIFICATION - PREPROCESSING PIPELINE")
    print("=" * 60)

    # Everything below needs the extracted PhysioNet EDF files on disk.
    if os.path.exists(config.RAW_DATA_DIR):
        print(f"\n✓ Raw data directory found: {config.RAW_DATA_DIR}")

        # Returns the per-subject result dicts and the IDs that yielded nothing
        all_data, failed_subjects = process_all_subjects(config)

        if not all_data:
            print("\n❌ Error: No subjects were successfully processed!")
            print("Please check the raw data directory and file formats.")
        else:
            # Persist to HDF5, then plot a quick sanity-check overview
            X_epochs, X_spectrograms, y_subjects, y_tasks = save_processed_data(all_data, config)
            visualize_sample_data(X_epochs, X_spectrograms, y_subjects, config)

            print("\n" + "=" * 60)
            print("PREPROCESSING COMPLETE!")
            print("=" * 60)
            print("\nNext step: Run 03_model_training.ipynb")
    else:
        print(f"\n❌ Error: Raw data directory not found: {config.RAW_DATA_DIR}")
        print("Please download the PhysioNet dataset and extract it to this location.")
        print("Download from: https://physionet.org/content/eegmmidb/get-zip/1.0.0/")

#%% Summary Statistics

"""
PREPROCESSING SUMMARY
=====================

This notebook preprocesses the PhysioNet EEG Motor Movement/Imagery Dataset:

1. Loads the 64-channel EEG recordings of the 109 subjects (motor-imagery runs 4, 6, 8, 10, 12 and 14)
2. Applies an 8-30 Hz band-pass filter (mu and beta bands)
3. Extracts 3-second epochs starting at each motor-imagery cue (T1/T2)
4. Rejects artifact-contaminated epochs by peak-to-peak amplitude thresholding
5. Computes STFT spectrograms for every channel of every epoch
6. Z-score normalizes epochs and spectrograms per subject
7. Saves the processed data in HDF5 format

Output Files:
- eeg_processed_data.h5: Contains all preprocessed data
- preprocessing_visualization.png: Sample visualizations

The processed data is now ready for model training!
"""