# MNE Data Analysis Pipeline

This notebook contains tools for:
- Concatenating multiple FIF files from multiple sessions
- Loading and preprocessing EEG/SEEG data
- Filtering and artifact removal (ICA)
- Event extraction and analysis
- Reaction time analysis
- Visualization

## 1. Imports

In [None]:
import h5py
import numpy as np
import mne
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import os
from scipy import stats as scipy_stats

# Global plotting defaults: seaborn "whitegrid" style and a 12x8 default
# figure size for every figure created without an explicit figsize.
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 8)

# Select the Qt5Agg backend so figures open in interactive windows.
# NOTE(review): this needs Qt bindings in the environment - confirm they are
# installed wherever the notebook is run.
matplotlib.use("Qt5Agg")

## 2. Utility Functions

### 2.1 File Concatenation Functions

In [None]:
def read_and_concatenate_fif(directory_path, preload=True,
                             use_common_channels=True, verbose=False):
    """
    Concatenate every ``.fif`` recording found in a directory.

    Files are processed in sorted filename order. When more than one file is
    found and ``use_common_channels`` is set, each recording is reduced to the
    channels shared by all files (in alphabetical order) before concatenation.

    Parameters
    ----------
    directory_path : str
        Directory that is scanned (non-recursively) for ``*.fif`` files
    preload : bool
        Forwarded to ``mne.concatenate_raws`` (default: True)
    use_common_channels : bool
        Restrict all recordings to their common channels (default: True)
    verbose : bool
        Verbosity passed to the MNE reader and concatenation (default: False)

    Returns
    -------
    raw_concat : mne.io.Raw
        The concatenated recording

    Raises
    ------
    ValueError
        If the directory contains no ``.fif`` file
    """
    fif_names = [name for name in sorted(os.listdir(directory_path))
                 if name.endswith('.fif')]

    # Open lazily; the data are only loaded at concatenation time if requested
    recordings = [
        mne.io.read_raw_fif(os.path.join(directory_path, name),
                            preload=False, verbose=verbose)
        for name in fif_names
    ]

    if not recordings:
        raise ValueError(f"No FIF files found in {directory_path}")

    print(f"Found {len(recordings)} FIF files")

    if use_common_channels and len(recordings) > 1:
        # Intersect the channel names of the first file with all the others
        shared = set(recordings[0].ch_names).intersection(
            *(rec.ch_names for rec in recordings[1:])
        )
        common_channels = sorted(shared)
        print(f"Using {len(common_channels)} common channels")

        # Work on copies so every recording ends up with the same ordered picks
        recordings = [
            rec.copy().pick_channels(common_channels, ordered=True)
            for rec in recordings
        ]

    print("Concatenating files...")
    raw_concat = mne.concatenate_raws(recordings, preload=preload, verbose=verbose)

    total_seconds = raw_concat.times[-1]
    print(f"Concatenation complete: {total_seconds:.2f} seconds total")

    return raw_concat

In [None]:
def read_and_concatenate_fif_all_channels(directory_path,
                                          preload=True,
                                          fill_value=0.0,
                                          verbose=False):
    """
    Read all FIF files from a directory and concatenate them using ALL channels.

    The channel set of the result is the union of the channels of every file.
    A channel that is absent from a file is added to that file as a constant
    signal (``fill_value``) so that all files end up with identical channels.
    The type of an added channel is copied from the first file (in sorted
    filename order) that actually contains it, rather than guessed from the
    channel name.

    Parameters
    ----------
    directory_path : str
        Path to directory containing FIF files (scanned non-recursively)
    preload : bool
        Forwarded to ``mne.concatenate_raws`` (default: True). Note that every
        file is loaded into memory before concatenation regardless of this
        flag.
    fill_value : float
        Value to use for missing channels (default: 0.0). ``None`` is treated
        as 0.0.
    verbose : bool
        Whether to print verbose MNE output (default: False)

    Returns
    -------
    raw_concat : mne.io.Raw
        Concatenated Raw object with all channels
    channel_info : dict
        Maps each channel name to a list of booleans, one per file in sorted
        filename order, telling whether the channel was present in that file

    Raises
    ------
    ValueError
        If the directory contains no ``.fif`` file
    """
    if fill_value is None:
        fill_value = 0.0

    separator = "=" * 80

    print(f"Loading FIF files from: {directory_path}")
    print(f"Missing channels will be filled with: {fill_value}")
    print(separator)

    # Step 1: Open all files (lazily) and report their channel counts
    raw_list = []
    filenames = []
    for filename in sorted(os.listdir(directory_path)):
        if not filename.endswith('.fif'):
            continue
        full_path = os.path.join(directory_path, filename)
        raw = mne.io.read_raw_fif(full_path, preload=False, verbose=verbose)
        raw_list.append(raw)
        filenames.append(filename)

        print(f"Loaded: {filename}")
        print(f"  Channels: {len(raw.ch_names)}")

    if not raw_list:
        raise ValueError(f"No FIF files found in {directory_path}")

    print(f"\nTotal files loaded: {len(raw_list)}")

    # Step 2: Build the union of all channels and remember each channel's real
    # type, taken from the first file that contains it. This replaces name-based
    # guessing, which mislabelled any name containing the letters A-H as SEEG.
    channel_types = {}
    for raw in raw_list:
        for idx, ch_name in enumerate(raw.ch_names):
            channel_types.setdefault(ch_name, mne.channel_type(raw.info, idx))

    all_channels_sorted = sorted(channel_types)

    print("\n" + separator)
    print(f"Total unique channels across all files: {len(all_channels_sorted)}")

    # Step 3: Analyze channel presence across files
    channel_sets = [set(raw.ch_names) for raw in raw_list]
    channel_presence = {
        ch: [ch in ch_set for ch_set in channel_sets]
        for ch in all_channels_sorted
    }

    missing_counts = {ch: flags.count(False)
                      for ch, flags in channel_presence.items()}
    channels_missing_somewhere = {ch: count
                                  for ch, count in missing_counts.items()
                                  if count > 0}

    if channels_missing_somewhere:
        print("\nChannels missing from some files:")
        for ch, count in sorted(channels_missing_somewhere.items(),
                                key=lambda item: item[1], reverse=True):
            print(f"  {ch}: missing from {count}/{len(raw_list)} files")
    else:
        print("\nAll channels present in all files!")

    # Step 4: Pad each file with its missing channels
    print("\n" + separator)
    print("Processing files and adding missing channels...")
    print(separator)

    processed_raws = []

    for raw, filename, present in zip(raw_list, filenames, channel_sets):
        print(f"\nProcessing: {filename}")

        missing_channels = [ch for ch in all_channels_sorted if ch not in present]

        # Channels can only be appended to in-memory data
        if not raw.preload:
            raw.load_data()

        if missing_channels:
            print(f"  Adding {len(missing_channels)} missing channels filled with {fill_value}")

            info_missing = mne.create_info(
                ch_names=missing_channels,
                sfreq=raw.info['sfreq'],
                ch_types=[channel_types[ch] for ch in missing_channels],
            )
            missing_data = np.full((len(missing_channels), raw.n_times),
                                   fill_value, dtype=np.float64)
            raw_missing = mne.io.RawArray(missing_data, info_missing,
                                          verbose=verbose)

            # add_channels works in place; force_update_info aligns the
            # measurement info of the synthetic channels with the recording
            raw.add_channels([raw_missing], force_update_info=True)
        else:
            print("  No missing channels")

        # As soon as any file had to be padded, bring EVERY file (including
        # the complete ones) into the same sorted channel order so that all
        # inputs to the concatenation agree. When nothing was padded anywhere,
        # the files keep their native order.
        if channels_missing_somewhere:
            raw.reorder_channels(all_channels_sorted)

        processed_raws.append(raw)

    # Step 5: Concatenate all processed files
    print("\n" + separator)
    print("Concatenating all files...")
    print(separator)

    raw_concat = mne.concatenate_raws(processed_raws, preload=preload, verbose=verbose)

    print("\nConcatenation complete!")
    print(f"  Total channels: {len(raw_concat.ch_names)}")
    print(f"  Total duration: {raw_concat.times[-1]:.2f} seconds")
    print(f"  Sampling frequency: {raw_concat.info['sfreq']} Hz")

    return raw_concat, channel_presence

### 2.2 Reaction Time Analysis Function

In [None]:
def extract_reaction_times_from_fif(fif_file_path, probe_event_id=None, response_event_id=None):
    """
    Extract reaction times from an MNE FIF recording.
    RT = Response time - Probe onset time

    Each probe event is paired with the first response event that follows it
    AND precedes the next probe event. Probes without such a response are
    skipped, so a missed trial is never paired with the response of a later
    trial.

    Parameters
    ----------
    fif_file_path : str, Path or mne.io.BaseRaw
        Path to the FIF file, or an already loaded Raw object (for example a
        concatenated recording), which is then used directly
    probe_event_id : int or str or list, optional
        Event ID(s) for probe/retrieval onset. Strings are looked up in the
        event-name mapping derived from the recording. If omitted, event names
        containing 'probe', 'retrieval', 'test' or 'cue' are used.
    response_event_id : int or str or list, optional
        Event ID(s) for response/button press. If omitted, event names
        containing 'response', 'button', 'answer' or 'resp' are used.

    Returns
    -------
    dict : Dictionary of equally long lists with the keys 'reaction_times',
        'trial_numbers' (1-based index of the probe), 'probe_times',
        'response_times' (all times in seconds), 'probe_event_ids' and
        'response_event_ids'. All lists are empty if the recording could not
        be read or the probe/response events could not be determined; in that
        case a message is printed instead of raising.
    """
    data = {
        'reaction_times': [],
        'trial_numbers': [],
        'probe_times': [],
        'response_times': [],
        'probe_event_ids': [],
        'response_event_ids': []
    }

    # Reading is best-effort: any failure is reported and yields empty results
    try:
        if isinstance(fif_file_path, mne.io.BaseRaw):
            print("Using already loaded Raw object")
            raw = fif_file_path
        else:
            print(f"Loading FIF file: {fif_file_path}")
            raw = mne.io.read_raw_fif(str(fif_file_path), preload=False, verbose=False)
        events = mne.find_events(raw, verbose=False)
    except Exception as e:
        print(f"Error loading file: {e}")
        return data

    sfreq = raw.info['sfreq']

    # The pairing below relies on chronological order
    events = events[np.argsort(events[:, 0], kind='stable')]
    unique_codes = np.unique(events[:, 2])

    # Derive an event-name -> code mapping. Numeric annotation descriptions map
    # to their own value; other descriptions are assigned the next unused
    # position in the sorted event codes. This is a heuristic; descriptions are
    # visited in sorted order so the mapping is the same on every run.
    annotations = raw.annotations
    if annotations is not None and len(annotations) > 0:
        event_id = {}
        for desc in sorted(set(annotations.description)):
            if desc.isdigit():
                event_id[desc] = int(desc)
            elif len(event_id) < len(unique_codes):
                event_id[desc] = unique_codes[len(event_id)]
    else:
        event_id = {f"Event_{code}": code for code in unique_codes}

    print(f"Found {len(events)} events")
    print(f"Sampling frequency: {sfreq} Hz")
    print(f"Event IDs: {event_id}")

    # Determine probe and response event codes
    probe_codes = _resolve_event_codes(
        probe_event_id, event_id, ['probe', 'retrieval', 'test', 'cue'], 'probe'
    )
    if probe_codes is None:
        return data

    response_codes = _resolve_event_codes(
        response_event_id, event_id, ['response', 'button', 'answer', 'resp'], 'response'
    )
    if response_codes is None:
        return data

    probe_events = events[np.isin(events[:, 2], probe_codes)]
    response_events = events[np.isin(events[:, 2], response_codes)]

    print(f"Found {len(probe_events)} probe events")
    print(f"Found {len(response_events)} response events")

    probe_samples = probe_events[:, 0]
    response_samples = response_events[:, 0]

    # Index of the first response strictly after each probe
    next_response_idx = np.searchsorted(response_samples, probe_samples, side='right')
    # A response only counts if it happens before the following probe onset;
    # the last probe has no upper limit
    next_probe_sample = np.append(probe_samples[1:], np.inf)

    for trial, (probe_sample, probe_code, idx, limit) in enumerate(
        zip(probe_samples, probe_events[:, 2], next_response_idx, next_probe_sample),
        start=1,
    ):
        if idx >= len(response_samples) or response_samples[idx] >= limit:
            continue  # no response belonging to this probe

        probe_time = probe_sample / sfreq
        response_time = response_samples[idx] / sfreq

        data['reaction_times'].append(float(response_time - probe_time))
        data['trial_numbers'].append(trial)
        data['probe_times'].append(float(probe_time))
        data['response_times'].append(float(response_time))
        data['probe_event_ids'].append(int(probe_code))
        data['response_event_ids'].append(int(response_events[idx, 2]))

    print(f"\nExtracted {len(data['reaction_times'])} reaction times")

    return data


def _resolve_event_codes(requested, event_id, keywords, role):
    """Turn an int/str/list event specification into a list of event codes.

    With ``requested=None`` the codes of all names in ``event_id`` containing
    one of ``keywords`` are used. Returns None (after printing a warning) when
    nothing can be auto-detected or a requested name is unknown.
    """
    if requested is None:
        matches = [name for name in event_id
                   if any(word in str(name).lower() for word in keywords)]
        if not matches:
            print(f"WARNING: Could not auto-detect {role} events.")
            return None
        codes = [event_id[name] for name in matches]
        print(f"Auto-detected {role} events: {matches} -> {codes}")
        return codes

    if isinstance(requested, (str, int, np.integer)):
        requested = [requested]

    try:
        return [event_id[item] if isinstance(item, str) else int(item)
                for item in requested]
    except KeyError as err:
        print(f"WARNING: Unknown {role} event name {err}. "
              f"Known names: {sorted(event_id)}")
        return None

## 3. Data Loading and Preprocessing

### 3.1 Load Subject Data

In [None]:
# Subject 6: concatenate all session files, keeping the union of channels
subject_6, s6_channel_info = read_and_concatenate_fif_all_channels(
    "Data_converted/Subject_06"
)

# Event name -> trigger code mapping (also used for epoching and plotting)
event_id = dict(fixation=1, encoding=2, maintenance=3, retrieval=4, response=5)

# Read the triggers from the 'STI' channel
events_6 = mne.find_events(subject_6, stim_channel='STI', consecutive=True,
                           shortest_event=0, min_duration=0, initial_event=True)

print(f"\nLoaded Subject 6: {len(subject_6.ch_names)} channels, {subject_6.times[-1]:.1f}s duration")
print(f"Found {len(events_6)} events")

In [None]:
# Subject 8: concatenate all session files, keeping the union of channels
subject_8, s8_channel_info = read_and_concatenate_fif_all_channels(
    "Data_converted/Subject_08"
)

# Read the triggers from the 'STI' channel (same settings as for subject 6)
events_8 = mne.find_events(subject_8, stim_channel='STI', consecutive=True,
                           shortest_event=0, min_duration=0, initial_event=True)

print(f"\nLoaded Subject 8: {len(subject_8.ch_names)} channels, {subject_8.times[-1]:.1f}s duration")
print(f"Found {len(events_8)} events")

# Browse the unfiltered recording
subject_8.plot()

### 3.2 Filtering

In [None]:
# Two band-passed copies of the recording; the original stays untouched.

# 0.1-40 Hz: used for the general analysis
seeg_raw_filtered_8 = subject_8.copy()
seeg_raw_filtered_8.filter(l_freq=0.1, h_freq=40)

# 1-40 Hz: used to fit the ICA (the higher cutoff removes slow drifts)
seeg_raw_filtered_ica_8 = subject_8.copy()
seeg_raw_filtered_ica_8.filter(l_freq=1, h_freq=40)

# 3D view of the sensor positions
seeg_raw_filtered_ica_8.plot_sensors(kind="3d")

### 3.3 ICA for Artifact Removal

In [None]:
# ICA settings; n_components should be adjusted to the number of channels
n_components, method, random_state = 18, 'fastica', 42

ica = mne.preprocessing.ICA(n_components=n_components, method=method,
                            random_state=random_state)

# Decompose the 1-40 Hz filtered recording
ica.fit(seeg_raw_filtered_ica_8)

print(f"\nICA fitted with {ica.n_components_} components")

In [None]:
# Browse the component time courses to spot artifact components
ica.plot_sources(
    seeg_raw_filtered_ica_8,
    show_scrollbars=True,
)

In [None]:
# Diagnostic figures (topography, time series, PSD, variance) for every
# fitted component
ica.plot_properties(
    seeg_raw_filtered_ica_8,
    picks=range(ica.n_components_),
)

In [None]:
# Exclude bad components (manual selection after inspection)
# Example: ica.exclude = [0, 1, 5]  # Component indices to exclude

# Apply ICA to remove artifacts
# seeg_raw_cleaned = ica.apply(seeg_raw_filtered_8.copy())

## 4. Epoching and ERP Analysis

In [None]:
# Epoch window in seconds relative to each event
tmin = -0.2
tmax = 0.8

# Cut the 0.1-40 Hz filtered recording around the events;
# baseline=(None, 0) uses the interval from the epoch start up to time zero
epochs = mne.Epochs(seeg_raw_filtered_8, events_8, event_id=event_id,
                    tmin=tmin, tmax=tmax, baseline=(None, 0), preload=True)

print(f"Created {len(epochs)} epochs")
print(f"Epoch duration: {tmax - tmin}s")

In [None]:
# Average the retrieval epochs into an evoked response (ERP) and plot it
evoked_retrieval = epochs['retrieval'].average()
evoked_retrieval.plot()

# Same for the response epochs
evoked_response = epochs['response'].average()
evoked_response.plot()

In [None]:
# Topographic maps of the retrieval ERP at 0.1 s steps starting at 0.1 s
times = np.arange(start=0.1, stop=0.5, step=0.1)
evoked_retrieval.plot_topomap(
    times=times,
)

In [None]:
# Combined ERP trace and topography figure for the response-locked average
(
    evoked_response
    .plot_joint()
)

## 5. Reaction Time Analysis

In [None]:
# Reaction times for subject 8: probe = retrieval onset (code 4),
# response = button press (code 5)
rt_data = extract_reaction_times_from_fif(subject_8,
                                          probe_event_id=4,
                                          response_event_id=5)

# One row per extracted reaction time
rt_df = pd.DataFrame(rt_data)

if len(rt_df.index) > 0:
    rt_values = rt_df['reaction_times']
    print("\nReaction Time Statistics:")
    for label, stat in (
        ("Mean", rt_values.mean()),
        ("Median", rt_values.median()),
        ("SD", rt_values.std()),
        ("Min", rt_values.min()),
        ("Max", rt_values.max()),
    ):
        print(f"{label} RT: {stat:.3f}s")

In [None]:
# Four-panel overview of the reaction times (skipped when none were extracted)
if len(rt_df) > 0:
    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 10))
    (hist_ax, box_ax), (trend_ax, qq_ax) = axes
    rt_values = rt_df['reaction_times']

    # Top left: histogram
    hist_ax.hist(rt_values, bins=30, edgecolor='black')
    hist_ax.set(xlabel='Reaction Time (s)', ylabel='Count',
                title='RT Distribution')

    # Top right: box plot
    box_ax.boxplot(rt_values)
    box_ax.set(ylabel='Reaction Time (s)', title='RT Box Plot')

    # Bottom left: RT as a function of trial number
    trend_ax.plot(rt_df['trial_numbers'], rt_values, 'o-')
    trend_ax.set(xlabel='Trial Number', ylabel='Reaction Time (s)',
                 title='RT Over Trials')

    # Bottom right: Q-Q plot against a normal distribution; the title is set
    # afterwards because probplot writes its own
    scipy_stats.probplot(rt_values, dist="norm", plot=qq_ax)
    qq_ax.set_title('Q-Q Plot')

    plt.tight_layout()
    plt.show()

## 6. Additional Visualization

In [None]:
# Browse the unfiltered subject 8 recording in 10 s windows with the events
# overlaid and automatic channel scaling
subject_8.plot(events=events_8, event_id=event_id, scalings='auto', duration=10.0)

In [None]:
# Power spectral density of the 0.1-40 Hz filtered recording, up to 50 Hz
(
    seeg_raw_filtered_8
    .compute_psd(fmax=50)
    .plot()
)

In [None]:
# Image of the retrieval epochs, restricted to channels of type 'eeg'
# NOTE(review): the variable names suggest SEEG data - confirm that 'eeg'
# is the intended channel selection
epochs['retrieval'].plot_image(
    picks='eeg',
)