# In [None]:
# ================================
# FINAL WORKING PREPROCESSING SCRIPT
# ================================

import os
import numpy as np
import mne
import pandas as pd
from scipy import stats
import matplotlib.pyplot as plt
import json
from datetime import datetime
import re

def extract_image_number(filename):
    """Return the integer that precedes ``.bmp`` in *filename*.

    The first ``<digits>.bmp`` occurrence found by a left-to-right regex
    search is used (the match is case-sensitive). Anything that is not a
    ``str`` (e.g. a NaN cell from pandas), or a string without such a
    substring, yields ``None``.
    """
    if not isinstance(filename, str):
        return None
    found = re.search(r'(\d+)\.bmp', filename)
    return int(found.group(1)) if found else None

def create_stimulus_mapping():
    """Build the image-number -> face-attribute lookup for the 64 stimuli.

    Images are laid out in 8 consecutive blocks of 8 (1-8, 9-16, ..., 57-64).
    All images in a block share one race/gender/age combination; within a
    block the first four images are 'neutral' and the last four 'emotional'.

    Returns:
        dict mapping each image number 1..64 to a dict with the keys
        'race', 'gender', 'age' and 'expression'.
    """
    # One entry per block of 8, in image order: race varies slowest,
    # then gender, then age.
    block_attributes = [
        ('asian', 'male', 'young'),
        ('asian', 'male', 'elderly'),
        ('asian', 'female', 'young'),
        ('asian', 'female', 'elderly'),
        ('caucasian', 'male', 'young'),
        ('caucasian', 'male', 'elderly'),
        ('caucasian', 'female', 'young'),
        ('caucasian', 'female', 'elderly'),
    ]
    images_per_block = 8
    neutral_per_block = 4

    lookup = {}
    for block_idx, (race, gender, age) in enumerate(block_attributes):
        first_image = block_idx * images_per_block + 1
        for offset in range(images_per_block):
            lookup[first_image + offset] = {
                'race': race,
                'gender': gender,
                'age': age,
                'expression': 'neutral' if offset < neutral_per_block else 'emotional',
            }
    return lookup

# STIM-channel trigger values that mark a face onset; 55 marks a catch trial.
_FACE_TRIGGER_CODES = (1, 55)
_CATCH_TRIGGER_CODE = 55
# Event IDs written into the epoched data.
_EVENT_IDS = {'regular': 1, 'catch': 2}
# Face attributes used to split the regular epochs into groups.
_DIMENSIONS = ('race', 'gender', 'age', 'expression')


def _to_builtin(obj):
    """Recursively convert NumPy scalars/arrays (and tuples) to JSON-serializable Python types."""
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {k: _to_builtin(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_builtin(item) for item in obj]
    return obj


def _pair_events(face_events, beh_data, stim_mapping):
    """Pair face triggers with behavioral rows by order and attach stimulus attributes.

    Args:
        face_events: (n, 3) MNE event array already restricted to face triggers.
        beh_data: behavioral DataFrame with a ``stim_file`` column.
        stim_mapping: image number -> attribute dict (see create_stimulus_mapping).

    Returns:
        (mne_events, event_info): an (m, 3) int array of MNE events using the
        IDs in ``_EVENT_IDS`` and a parallel list of per-event dicts. Trials whose
        ``stim_file`` carries no known image number are skipped.
    """
    if len(face_events) != len(beh_data):
        # Pairing is purely positional, so a count mismatch can shift trials.
        print(f"  Warning: {len(face_events)} face triggers vs {len(beh_data)} "
              f"behavioral rows; pairing the first "
              f"{min(len(face_events), len(beh_data))} by order")

    mne_events = []
    event_info = []
    # zip stops at the shorter of the two sequences.
    for (sample, _, code), stim_file in zip(face_events, beh_data['stim_file']):
        img_num = extract_image_number(stim_file)
        dims = stim_mapping.get(img_num)
        if dims is None:
            continue
        is_catch = bool(code == _CATCH_TRIGGER_CODE)
        event_code = _EVENT_IDS['catch'] if is_catch else _EVENT_IDS['regular']
        mne_events.append([int(sample), 0, event_code])
        event_info.append({
            'sample': int(sample),
            'event_code': event_code,
            'image_num': img_num,
            'is_catch': is_catch,
            **dims,
        })
    # reshape keeps the (0, 3) shape when nothing matched.
    return np.array(mne_events, dtype=int).reshape(-1, 3), event_info


def _preprocess_raw(raw, notch_freqs=(44, 50, 100, 150, 200, 250),
                    l_freq=1, h_freq=100, z_thresh=3):
    """Return (magnetometer-only filtered copy of *raw*, list of flagged bad channels).

    Channels whose variance has |z-score| > *z_thresh* across channels are
    written to ``info['bads']`` of the returned copy.
    """
    raw_mag = raw.copy().pick(picks='mag')

    # Notch frequencies from the reference pipeline ("as in paper"). 44 Hz is
    # not a mains harmonic -- presumably setup-specific interference; confirm.
    # Frequencies at/above Nyquist are skipped because MNE rejects them.
    nyquist = raw_mag.info['sfreq'] / 2.0
    for freq in notch_freqs:
        if freq >= nyquist:
            print(f"  Skipping {freq} Hz notch (Nyquist is {nyquist} Hz)")
            continue
        raw_mag.notch_filter(freq, method='fir', phase='zero', verbose=False)

    raw_mag.filter(l_freq, h_freq, method='fir', phase='zero', verbose=False)

    variances = np.var(raw_mag.get_data(), axis=1)
    z_scores = np.abs(stats.zscore(variances))
    bad_channels = [raw_mag.ch_names[i] for i in np.flatnonzero(z_scores > z_thresh)]
    if bad_channels:
        raw_mag.info['bads'] = bad_channels
        print(f"✓ Marked bad channels: {bad_channels}")
    return raw_mag, bad_channels


def _group_by_dimension(regular_epochs, event_info):
    """Split *regular_epochs* by each face attribute in ``_DIMENSIONS``.

    Returns ``{dim_name: {value: Epochs}}``; values with no surviving epoch
    are omitted.
    """
    # Map each kept epoch's onset sample to its index. Epochs dropped by MNE
    # (e.g. too close to the recording edges) have no entry and are skipped.
    sample_to_epoch = {}
    for idx, sample in enumerate(regular_epochs.events[:, 0]):
        sample_to_epoch.setdefault(int(sample), idx)

    regular_info = [info for info in event_info if not info['is_catch']]
    dim_groups = {}
    for dim_name in _DIMENSIONS:
        dim_groups[dim_name] = {}
        # Sorted so group order (and the JSON output) is reproducible.
        for val in sorted({info[dim_name] for info in regular_info}):
            indices = [
                sample_to_epoch[info['sample']]
                for info in regular_info
                if info[dim_name] == val and info['sample'] in sample_to_epoch
            ]
            if indices:
                dim_epochs = regular_epochs[indices]
                dim_groups[dim_name][val] = dim_epochs
                print(f"  {dim_name}={val}: {len(dim_epochs)} trials")
    return dim_groups


def process_subject(subject, session='01', run='01',
                    data_root='../data/ds005107', output_root='preprocessed'):
    """Preprocess one subject's face-task run and save the results.

    Pipeline:
      1. load the MEG ``.fif`` and behavioral ``events.tsv``;
      2. pair face triggers with behavioral rows;
      3. keep magnetometers, notch + band-pass filter, and flag
         high-variance channels;
      4. epoch -0.2..0.8 s around each face;
      5. split regular epochs by race/gender/age/expression;
      6. write FIF/JSON outputs under ``<output_root>/sub-<subject>``.

    Args:
        subject: Subject label without the ``sub-`` prefix, e.g. ``'01'``.
        session: Session label without ``ses-``.
        run: Run label without ``run-``.
        data_root: Dataset root containing ``sub-*/ses-*/{meg,beh}``.
        output_root: Directory under which per-subject outputs are written.

    Returns:
        dict with keys ``subject``, ``raw`` (filtered magnetometer Raw),
        ``epochs``, ``regular_epochs``, ``dim_groups`` and ``output_dir``;
        or None if an input file is missing or no usable regular trial exists.
    """
    print(f"\n{'='*80}")
    print(f"PROCESSING SUBJECT {subject}")
    print(f"{'='*80}")

    # =========================================================================
    # 1. LOAD DATA
    # =========================================================================
    print(f"\n1. Loading data...")

    prefix = f"sub-{subject}_ses-{session}_task-face_run-{run}"
    meg_path = f"{data_root}/sub-{subject}/ses-{session}/meg/{prefix}_meg.fif"
    beh_path = f"{data_root}/sub-{subject}/ses-{session}/beh/{prefix}_events.tsv"
    # Check both inputs before the slow MEG read so either missing file is
    # reported (and returns None) the same way.
    if not os.path.exists(meg_path):
        print(f"❌ MEG file not found: {meg_path}")
        return None
    if not os.path.exists(beh_path):
        print(f"❌ Behavioral file not found: {beh_path}")
        return None

    raw = mne.io.read_raw_fif(meg_path, preload=True, verbose=False)
    print(f"✓ Loaded MEG data: {len(raw.ch_names)} channels, {raw.info['sfreq']} Hz")

    beh_data = pd.read_csv(beh_path, sep='\t')
    print(f"✓ Loaded behavioral data: {len(beh_data)} trials")

    # =========================================================================
    # 2. EXTRACT EVENTS
    # =========================================================================
    print(f"\n2. Extracting events...")

    events = mne.find_events(raw, stim_channel='STIM', shortest_event=1, verbose=False)
    print(f"✓ Found {len(events)} events in STIM channel")
    face_events = events[np.isin(events[:, 2], _FACE_TRIGGER_CODES)]

    # =========================================================================
    # 3. CREATE EVENT LIST
    # =========================================================================
    print(f"\n3. Creating event list...")

    mne_events, event_info = _pair_events(face_events, beh_data, create_stimulus_mapping())
    print(f"✓ Created {len(mne_events)} MNE events")
    if len(mne_events) == 0:
        print("❌ No face events could be matched to known stimuli")
        return None

    # =========================================================================
    # 4. PREPROCESS MEG DATA
    # =========================================================================
    print(f"\n4. Preprocessing MEG data...")

    raw_mag, bad_channels = _preprocess_raw(raw)

    # =========================================================================
    # 5. CREATE EPOCHS
    # =========================================================================
    print(f"\n5. Creating epochs...")

    # Declare only event IDs that actually occur: mne.Epochs raises by default
    # for an event_id with no matching events (e.g. a run without catch trials).
    present_codes = set(mne_events[:, 2].tolist())
    event_dict = {name: code for name, code in _EVENT_IDS.items() if code in present_codes}

    epochs = mne.Epochs(
        raw_mag,
        mne_events,
        event_id=event_dict,
        tmin=-0.2,
        tmax=0.8,
        baseline=(-0.2, 0),
        preload=True,
        reject=None,
        verbose=False
    )

    # Count from the kept events; preload drops epochs that run off the data.
    kept_codes = epochs.events[:, 2]
    n_regular = int(np.sum(kept_codes == _EVENT_IDS['regular']))
    n_catch = int(np.sum(kept_codes == _EVENT_IDS['catch']))

    print(f"✓ Created {len(epochs)} total epochs")
    print(f"✓ Regular trials: {n_regular}")
    print(f"✓ Catch trials: {n_catch}")

    if n_regular == 0:
        print("❌ No regular epochs left to analyse")
        return None
    regular_epochs = epochs['regular']

    # =========================================================================
    # 6. ORGANIZE BY DIMENSIONS
    # =========================================================================
    print(f"\n6. Organizing by face dimensions...")

    dim_groups = _group_by_dimension(regular_epochs, event_info)

    # =========================================================================
    # 7. SAVE DATA
    # =========================================================================
    print(f"\n7. Saving data...")

    output_dir = f"{output_root}/sub-{subject}"
    os.makedirs(output_dir, exist_ok=True)

    base_name = f"sub-{subject}_ses-{session}_run-{run}"

    # MNE expects *_raw.fif / *-epo.fif suffixes.
    raw_mag.save(f"{output_dir}/{base_name}_raw.fif", overwrite=True)
    epochs.save(f"{output_dir}/{base_name}_all-epo.fif", overwrite=True)
    regular_epochs.save(f"{output_dir}/{base_name}_regular-epo.fif", overwrite=True)

    for dim_name, groups in dim_groups.items():
        dim_dir = f"{output_dir}/dimensions/{dim_name}"
        os.makedirs(dim_dir, exist_ok=True)
        for val, ep in groups.items():
            ep.save(f"{dim_dir}/{base_name}_{dim_name}_{val}-epo.fif", overwrite=True)

    metadata = {
        'subject': subject,
        'session': session,
        'run': run,
        'total_epochs': len(epochs),
        'regular_epochs': len(regular_epochs),
        'catch_epochs': n_catch,
        'bad_channels': bad_channels,
        'sampling_rate': float(raw.info['sfreq']),
        'dimension_counts': {
            dim_name: {val: len(ep) for val, ep in groups.items()}
            for dim_name, groups in dim_groups.items()
        }
    }

    with open(f"{output_dir}/{base_name}_metadata.json", 'w', encoding='utf-8') as f:
        json.dump(_to_builtin(metadata), f, indent=2)

    print(f"✓ Saved all data to {output_dir}")

    # =========================================================================
    # 8. CREATE SUMMARY
    # =========================================================================
    print(f"\n8. Creating summary...")

    print(f"\n📊 PREPROCESSING COMPLETE FOR SUBJECT {subject}")
    print(f"   Output directory: {output_dir}")
    print(f"   Regular trials: {len(regular_epochs)}")
    print(f"   Dimension breakdown:")

    for dim_name, groups in dim_groups.items():
        print(f"   {dim_name.upper()}:")
        for val, ep in groups.items():
            print(f"     {val}: {len(ep)} trials")

    return {
        'subject': subject,
        'raw': raw_mag,
        'epochs': epochs,
        'regular_epochs': regular_epochs,
        'dim_groups': dim_groups,
        'output_dir': output_dir
    }

def main():
    """Run preprocessing for subject 01, then optionally for the whole dataset.

    Subject 01 is processed first as a smoke test. If it succeeds, the user is
    asked via ``input`` whether to batch-process every subject listed below
    (subject 01 is not processed a second time). A failure on one subject in
    the batch is reported and skipped so the remaining subjects still run.
    """
    print(f"{'='*80}")
    print(f"OPM-MEG FACE PERCEPTION - PREPROCESSING")
    print(f"{'='*80}")

    print(f"\nTesting with subject 01...")
    results = process_subject('01')

    if not results:
        print(f"\n❌ Failed to process subject 01")
        return

    print(f"\n✅ Successfully processed subject 01!")

    response = input("\nProcess all subjects? (y/n): ")
    if response.strip().lower() not in ('y', 'yes'):
        print(f"\n⏸️  Only processed subject 01")
        return

    subjects = ['01', '02', '03', '04', '06', '07', '08', '09', '10',
                '11', '13', '14', '15', '16', '17', '18', '19', '20',
                '21', '22', '23']

    # Keep only each subject's output directory: holding every subject's
    # preloaded Raw/Epochs objects would keep the whole dataset in memory.
    processed = {'01': results['output_dir']}
    del results  # release subject 01's in-memory data before the batch

    for subject in subjects:
        if subject in processed:
            continue  # already done by the smoke test above
        print(f"\n{'='*80}")
        print(f"Processing subject {subject}...")
        try:
            result = process_subject(subject)
        except Exception as e:  # top-level boundary: report and move on
            print(f"❌ Error with subject {subject}: {e}")
            continue
        if result:
            processed[subject] = result['output_dir']
            print(f"✓ Done with subject {subject}")

    print(f"\n{'='*80}")
    print(f"BATCH PROCESSING COMPLETE")
    print(f"{'='*80}")
    print(f"Processed {len(processed)} subjects successfully")
    print(f"Data saved in: preprocessed/")

# Entry point. Note that inside a Jupyter/IPython kernel __name__ is also
# "__main__", so running this cell starts the pipeline (and its input prompt).
if __name__ == "__main__":
    main()