# SEMP Preprocessing Tutorial

This notebook walks you through adapting the `template` project folder into a working preprocessing pipeline for your own EEG-fMRI dataset.  
It mirrors the `projects/sr/` implementation and explains every design decision.

---

## Overview of steps

1. **Standardise the dataset** — ensure every raw EEG file follows a consistent naming pattern
2. **Write `pathfinder.py`** — teach semp how to find files on disk
3. **Write `initialize()` in `helpers.py`** — fill `dataset` with the metadata your pipeline needs
4. **Write the config dict in `1.prep.py`** — run the full preprocessing pipeline

---
## Step 1 — Standardise the dataset

semp's pathfinder works by matching filenames against a **pattern with named placeholders**, e.g.:

```
/data/eeg-fmri/sub-{subject}_ses-{session}_run-{run}_eeg.edf
```

This requires that **all** your raw EEG files follow the same naming convention.  
If your dataset is inconsistent (mixed separators, varying padding, missing fields), fix it before proceeding.
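
To make the idea of a placeholder pattern concrete, the sketch below converts such a pattern into a regular expression and extracts the fields from a filename. It only illustrates the matching idea and is not semp's actual implementation; the helper names `pattern_to_regex` and `parse_filename` are hypothetical, as is the rule that a field may not contain `/` or `_`.

```python
import re


def pattern_to_regex(pattern: str) -> re.Pattern:
    """Turn 'sub-{subject}_ses-{session}' into a regex with named groups."""
    # With one capture group, re.split alternates: literal, name, literal, name, ...
    parts = re.split(r"\{(\w+)\}", pattern)
    regex = ""
    for i, part in enumerate(parts):
        if i % 2 == 0:
            regex += re.escape(part)            # literal text, matched verbatim
        else:
            regex += f"(?P<{part}>[^/_]+)"      # placeholder (assumed: no '/' or '_')
    return re.compile(regex)


def parse_filename(pattern: str, path: str):
    """Return the placeholder values as a dict, or None if the name does not fit."""
    m = pattern_to_regex(pattern).fullmatch(path)
    return m.groupdict() if m else None


PATTERN = "/data/eeg-fmri/sub-{subject}_ses-{session}_run-{run}_eeg.edf"
good = parse_filename(PATTERN, "/data/eeg-fmri/sub-01_ses-02_run-1_eeg.edf")
bad = parse_filename(PATTERN, "/data/eeg-fmri/01_02_1.edf")
print(good)  # {'subject': '01', 'session': '02', 'run': '1'}
print(bad)   # None
```

A file that returns `None` here is exactly the kind of file that needs renaming before the pathfinder can see it.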

### 1a. Audit your raw files

In [None]:
from pathlib import Path

# Point this at the folder that contains all your raw EEG files (search recursively)
RAW_ROOT = Path("/path/to/your/raw/eeg/")

# List every .edf (or .fif / .set / .cdt) found under RAW_ROOT
raw_files = sorted(RAW_ROOT.rglob("*.edf"))
for f in raw_files[:20]:   # show first 20
    print(f)

### 1b. Rename files if needed

If file names are inconsistent, use `Path.rename()` to enforce a uniform pattern.  
The example below parses a legacy naming scheme and renames to BIDS-like names.

In [None]:
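import re

DRY_RUN = True   # set to False to actually rename

# e.g. 01_02_1.edf → sub-01_ses-02_run-1_eeg.edf
LEGACY_PATTERN = re.compile(r"(\d+)_(\d+)_(\d+)\.edf")

for f in raw_files:
    # fullmatch: the whole file name must follow the legacy scheme
    m = LEGACY_PATTERN.fullmatch(f.name)
    if m is None:
        continue   # not a legacy name, leave it untouched
    subj, ses, run = m.groups()
    new_name = f"sub-{subj.zfill(2)}_ses-{ses.zfill(2)}_run-{run}_eeg.edf"
    new_path = f.parent / new_name
    print(f"{f.name}  →  {new_name}")
    if not DRY_RUN:
        # Path.rename() can silently overwrite an existing file on some platforms
        if new_path.exists():
            raise FileExistsError(f"Refusing to overwrite {new_path}")
        f.rename(new_path)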

---
## Step 2 — Write `pathfinder.py`

The pathfinder is a **frozen dataclass** that:
- Stores a `file_patterns` dict mapping *kind* → *glob pattern with `{placeholders}`*
- Scans the disk on construction and builds a lookup table: `file_id → kind → Path`
- Requires you to implement two methods: `dict2id` and `id2dict`

### Design decisions you must make

| Question | Example answer |
|---|---|
| What are the fields that uniquely identify a recording? | `subject`, `session`, `run` |
| What is a compact string identifier (`file_id`) for a recording? | `"0121"` for sub=01, ses=02, run=01 |
| What is the **anchor** kind — the file that *must* exist for a recording to be included? | `"raw"` |
| Which kinds are **required** (remove the file_id if missing)? | `{"raw", "polhemus"}` |
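
Whatever convention you choose, `dict2id` and `id2dict` should be inverses of each other at the level of the `file_id`. The sketch below checks this for the example convention in the table (2-digit subject, 1-digit session, 1-digit run). The two functions are standalone stand-ins written for illustration, not the methods of the real pathfinder class, and they assume purely numeric fields.

```python
def dict2id(fields: dict) -> str:
    """Fields dict -> file_id: 2-digit subject + 1-digit session + 1-digit run."""
    subj = str(int(fields["subject"])).zfill(2)
    ses = str(int(fields["session"]))
    run = str(int(fields["run"]))
    if len(ses) != 1 or len(run) != 1:
        raise ValueError("this convention only supports sessions and runs 0-9")
    return f"{subj}{ses}{run}"


def id2dict(file_id: str) -> dict:
    """file_id -> fields dict (the inverse of dict2id)."""
    if len(file_id) < 4:
        raise ValueError(f"Invalid file_id: '{file_id}'")
    return {"subject": file_id[:-2], "session": file_id[-2], "run": file_id[-1]}


example = {"subject": "01", "session": "02", "run": "01"}
fid = dict2id(example)
roundtrip = id2dict(fid)
print(fid)        # 0121
print(roundtrip)  # {'subject': '01', 'session': '2', 'run': '1'}

# The id must survive a full round trip, even though '02' was normalised to '2'.
assert dict2id(roundtrip) == fid
```

Note that the round trip is exact for the `file_id` but not for the raw field strings: zero-padding in the filename (`ses-02`) is normalised away.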

### Template `pathfinder.py`

Copy this into `projects/<your_project>/pathfinder.py` and customise it:

In [None]:
# ── pathfinder.py (copy to your project folder) ─────────────────────────────

from pathlib import Path
from typing import Dict, Optional
from semp.utils import BasePathfinder


class MyDatasetPathfinder(BasePathfinder):
    """
    Pathfinder for <your dataset name>.

    File ID convention: zero-padded subject (2 digits) + session (1 digit) + run (1 digit).

    Examples
    --------
    sub-01, ses-02, run-1  →  file_id "0121"
    sub-11, ses-01, run-2  →  file_id "1112"
    """

    DEFAULT_FILE_PATTERNS = {
        # ── anchor: the raw EEG file ────────────────────────────────────────
        # Use {subject}, {session}, {run} as placeholders.
        # You can use {foo1}, {foo2}, … for parts of the name you don't care about.
        # (Drop `_{foo1}` if your files are named exactly as in Step 1.)
        'raw': "/path/to/raw/sub-{subject}_ses-{session}_run-{run}_{foo1}_eeg.edf",

        # ── optional: polhemus digitisation file ────────────────────────────
        'polhemus': "/path/to/raw/sub-{subject}/ses-{session}/polhemus/{foo1}.pom",

        # ── output: filled in after preprocessing ───────────────────────────
        'preproc': "/path/to/output/{subject}{session}{run}/{subject}{session}{run}_preproc-raw.fif",
        # NOTE: a pattern in which two placeholders touch (it contains "}{") is
        # an "ambiguous pattern" and cannot be used as the anchor. You need at
        # least one unambiguous pattern (without "}{") to serve as the anchor
        # for scanning the disk.
    }
    DEFAULT_ANCHOR = "raw"      # scan disk using this kind's pattern
    REQUIRED_KEYS  = {"raw"}    # remove file_id from pathfinder if this kind is missing

    def __init__(self, file_patterns=None, anchor=None):
        patterns = file_patterns or self.DEFAULT_FILE_PATTERNS
        chosen_anchor = anchor or self.DEFAULT_ANCHOR
        super().__init__(file_patterns=patterns, anchor=chosen_anchor, required=self.REQUIRED_KEYS)

    # ── required: compact string identifier ─────────────────────────────────
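    def dict2id(self, fields: Dict[str, Optional[str]]) -> str:
        """Fields dict  →  canonical file_id string."""
        # Normalise so the result matches the convention in the class docstring:
        # subject zero-padded to 2 digits, session and run without leading zeros.
        # `or '1'` also covers fields that are present but None.
        subj = (fields.get('subject') or '1').lstrip('0').zfill(2)
        ses  = (fields.get('session') or '1').lstrip('0') or '0'
        run  = (fields.get('run')     or '1').lstrip('0') or '0'
        return f"{subj}{ses}{run}"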

    # ── required: reverse mapping ────────────────────────────────────────────
    def id2dict(self, file_id: str) -> Dict[str, str]:
        """Canonical file_id string  →  fields dict."""
        if len(file_id) < 3:
            raise ValueError(f"Invalid file_id: '{file_id}'")
        return {
            'subject': file_id[:-2].zfill(2),   # zero-pad subject to 2 digits
            'session': file_id[-2],
            'run':     file_id[-1],
        }
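
The note about `}{` in the patterns can be illustrated with a small sketch. When two placeholders touch, no literal text tells a matcher where one field ends and the next begins, so several splits are equally valid. The helpers `is_ambiguous` and `possible_splits` are hypothetical names, and this is not how semp performs the check internally.

```python
def is_ambiguous(pattern: str) -> bool:
    """A pattern is ambiguous when two placeholders are directly adjacent."""
    return "}{" in pattern


def possible_splits(text: str, n_fields: int):
    """Enumerate every way to cut `text` into n non-empty consecutive fields."""
    if n_fields == 1:
        return [[text]] if text else []
    splits = []
    # Leave at least one character for each of the remaining fields.
    for i in range(1, len(text) - n_fields + 2):
        for rest in possible_splits(text[i:], n_fields - 1):
            splits.append([text[:i]] + rest)
    return splits


raw_pattern = "/path/to/raw/sub-{subject}_ses-{session}_run-{run}_eeg.edf"
preproc_pattern = "/path/to/output/{subject}{session}{run}/{subject}{session}{run}_preproc-raw.fif"

print(is_ambiguous(raw_pattern))      # False
print(is_ambiguous(preproc_pattern))  # True

# "0121" could be read as subject/session/run in three different ways:
splits = possible_splits("0121", 3)
print(splits)  # [['0', '1', '21'], ['0', '12', '1'], ['01', '2', '1']]
```

This is why an ambiguous pattern can be filled in from known fields (via `id2dict`) but cannot be used to discover recordings by scanning the disk.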

### Verify the pathfinder

In [None]:
pf = MyDatasetPathfinder()

print(f"Found {len(pf)} recordings:\n")
for fid in list(pf.keys())[:5]:   # show first 5
    print(f"  {fid}:")
    for kind, path in pf[fid].items():
        status = "✓" if path is not None else "✗ (missing)"
        print(f"    {kind:<12} {status}  {path}")

---
## Step 3 — Write `initialize()` in `helpers.py`

`initialize()` is the **first step** in every `run_proc_chain` call.  
It populates `dataset` with the metadata that all later steps depend on.

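The five keys you **must** set are `tr_interval`, `slice_interval`, `tr_event_key`, `he_event_key`, and `target_pth`. The first four are described below (`he_event_key` may be left empty); `target_pth` is simply the output directory.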

---

### 3.1  `tr_interval` — the fMRI TR (repetition time)

This is stated in the MRI protocol (e.g. TR = 1.14 s for the Staresina dataset).  
It is used by `crop_TR` to align the EEG timeline with the fMRI volumes.

---

### 3.2  `slice_interval` — the fMRI slice timing

This is the time between consecutive slice acquisitions (TR / number-of-slices).  
It determines the dominant gradient artifact frequency: `1/slice_interval` Hz.
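
As a quick worked example of this relationship, the sketch below computes the slice interval from a TR and a slice count, and lists the first few expected gradient artifact harmonics. The values TR = 2.0 s and 40 slices are made up for illustration, and both function names are hypothetical.

```python
def slice_interval_from_protocol(tr: float, n_slices: int) -> float:
    """Slice interval in seconds, assuming slices are evenly spaced within one TR."""
    return tr / n_slices


def ga_harmonics(slice_interval: float, n: int = 3):
    """First n gradient artifact frequencies (Hz): multiples of 1/slice_interval."""
    f0 = 1.0 / slice_interval
    return [round(k * f0, 2) for k in range(1, n + 1)]


si = slice_interval_from_protocol(tr=2.0, n_slices=40)   # hypothetical protocol
harmonics = ga_harmonics(si)
print(round(si, 4))  # 0.05
print(harmonics)     # [20.0, 40.0, 60.0]
```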

**If your MRI protocol sheet does not list it**, inspect the raw EEG PSD:  
gradient artifact peaks appear at harmonics of `1/slice_interval` Hz.

In [None]:
import mne
from semp.eeg import psd_plot

# Load one raw file (before any cleaning)
raw = mne.io.read_raw_edf("/path/to/one/raw_eeg.edf", preload=True)

# Plot the PSD up to 50 Hz: the tallest regularly spaced peaks are GA harmonics.
# E.g. peaks at 14.3, 28.6 Hz  →  slice_interval = 1/14.3 ≈ 0.07 s
psd_plot(raw, picks='eeg', fmin=0, fmax=50, dB=False)
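
Once you have read the peak frequencies off the plot, the slice interval follows from their spacing. A minimal sketch, using the example peak values from the comment above plus an assumed third harmonic at 42.9 Hz; `estimate_slice_interval` is a hypothetical helper, not part of semp, and it assumes that no harmonic was skipped when reading the peaks.

```python
from statistics import median


def estimate_slice_interval(peak_freqs_hz):
    """Estimate the slice interval (s) from consecutive GA harmonic peaks (Hz)."""
    peaks = sorted(peak_freqs_hz)
    if len(peaks) < 2:
        raise ValueError("need at least two harmonic peaks")
    # Consecutive harmonics are separated by the fundamental frequency.
    spacings = [b - a for a, b in zip(peaks, peaks[1:])]
    f0 = median(spacings)   # median is robust to one misread peak
    return 1.0 / f0


si = estimate_slice_interval([14.3, 28.6, 42.9])
print(f"{si:.3f} s")  # 0.070 s
```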

---

### 3.3  `tr_event_key` — the annotation label for TR triggers

The EEG amplifier records a trigger at every fMRI volume onset.  
Find the annotation label whose inter-event interval is consistently equal to `tr_interval`.

In [None]:
import numpy as np

# Get all events and their labels
events, event_id = mne.events_from_annotations(raw)
print("All annotation labels and their event codes:")
for label, code in event_id.items():
    print(f"  code={code:>8}  label='{label}'")

print()
TR = 1.14     # ← replace with your tr_interval
sfreq = raw.info['sfreq']
TR_samples = TR * sfreq

print(f"Looking for events with inter-event spacing ≈ {TR_samples:.1f} samples ({TR} s):\n")
for label, code in event_id.items():
    mask = events[:, 2] == code
    times = events[mask, 0]
    if len(times) < 2:
        continue
    diffs = np.diff(times)
    mean_diff = diffs.mean()
    std_diff  = diffs.std()
    print(f"  '{label}': n={len(times)}, mean_interval={mean_diff/sfreq:.4f}s, std={std_diff/sfreq:.4f}s")
    # A good TR key has a mean within 2 samples of TR and low jitter (std < 5 samples)
    if abs(mean_diff - TR_samples) < 2 and std_diff < 5:
        print(f"    ^^^ CANDIDATE TR event key: '{label}'")

---

### 3.4  `he_event_key` — helium pump trigger (optional)

Some sites record the helium pump cycle as a trigger.  
This is only needed if you plan to remove the helium pump artifact via OBS; the artifact mainly contaminates frequencies above 30 Hz.  
**You can safely set this to `[]` or skip it entirely if you are not interested in information above 30 Hz.**

---

### Full `initialize()` template

In [None]:
# ── helpers.py (copy to your project folder and fill in the constants) ────────

from pathlib import Path
from functools import partial
import numpy as np
from semp.utils import psd_band_ratio
from pathfinder import MyDatasetPathfinder


def initialize(dataset, userargs):
    """Populate dataset with all metadata needed by the preprocessing pipeline."""

    # ── 3.1  fMRI repetition time (seconds) ───────────────────────────────────
    # Read from your MRI protocol sheet.
    dataset['tr_interval']    = userargs.get('tr_interval',    1.14)

    # ── 3.2  fMRI slice interval (seconds) ────────────────────────────────────
    # = TR / n_slices, or inspected from PSD peaks (see cell above).
    dataset['slice_interval'] = userargs.get('slice_interval', 0.07)

    # ── 3.3  Annotation label(s) for TR trigger ────────────────────────────────
    # Found by running the cell above. Pass as a list of strings.
    dataset['tr_event_key']   = userargs.get('tr_event_key',   ['TR_LABEL'])

    # ── 3.4  Annotation label(s) for helium pump (optional) ───────────────────
    # Set to [] if you do not need OBS helium pump removal.
    dataset['he_event_key']   = userargs.get('he_event_key',   [])

    # ── 3.5  Output directory ─────────────────────────────────────────────────
    dataset['target_pth']     = userargs.get('target_pth',
                                             Path("/path/to/output/after_prep"))

    # ── Pathfinder & subject ID ───────────────────────────────────────────────
    dataset['pf']      = userargs['pf'] if 'pf' in userargs else MyDatasetPathfinder()
    dataset['subject'] = dataset['pf'].filename2id(dataset['raw'].filenames[0], kind='raw')

    # ── Sampling frequency (for reference) ────────────────────────────────────
    dataset['orig_sfreq'] = dataset['raw'].info['sfreq']

    # ── Slice-frequency tracers ────────────────────────────────────────────────
    # Extra tracers to compute in init_tracer() and summary(). You can safely
    # set this to {} and rely on the default tracers provided by semp.
    si = dataset['slice_interval']
    dataset['tracer'] = {
        'psd_slice':  partial(psd_band_ratio,
                              band1=[1/si - 1, 1/si + 1], band2='beta', fn1=np.mean),
        'psd_2slice': partial(psd_band_ratio,
                              band1=[2/si - 1, 2/si + 1], band2=[20, 35], fn1=np.mean),
    }

    # ── Any dataset-specific fixes go here ────────────────────────────────────
    # Example: drop channels with no gel
    # dataset['raw'].drop_channels(['F11', 'F12'], on_missing='warn')
    
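    # Example 2: apply a custom fix to subject 05, session 2, run 1.
    # dataset['subject'] holds the file_id, so recover the individual fields first:
    # fields = dataset['pf'].id2dict(dataset['subject'])
    # if (fields['subject'], fields['session'], fields['run']) == ('05', '2', '1'):
    #     do_something_special(dataset['raw'])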

    return dataset

---
## Step 4 — Write the preprocessing config in `1.prep.py`

The config is a dict whose `'preproc'` key holds a list of steps; it is passed to `osl_ephys.preprocessing.run_proc_batch`.  
Each step is a `{function_name: kwargs}` dict. Custom functions are passed via `extra_funcs`.
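
To show how a list of `{function_name: kwargs}` steps can be consumed, here is a toy dispatcher. It sketches the general idea only and does not reproduce how osl_ephys resolves or calls functions; `run_steps`, `add_offset`, and `scale` are hypothetical names. The `(dataset, userargs)` signature mirrors `initialize()` above.

```python
def add_offset(dataset, userargs):
    """Toy step: add a constant to the stored value."""
    dataset["value"] += userargs.get("offset", 0)
    return dataset


def scale(dataset, userargs):
    """Toy step: multiply the stored value by a factor."""
    dataset["value"] *= userargs.get("factor", 1)
    return dataset


def run_steps(dataset, steps, extra_funcs):
    """Apply each {name: kwargs} step in order, looking functions up by name."""
    registry = {fn.__name__: fn for fn in extra_funcs}
    for step in steps:
        (name, kwargs), = step.items()   # each step is a single-key dict
        if name not in registry:
            raise KeyError(f"unknown step: {name}")
        dataset = registry[name](dataset, kwargs or {})
    return dataset


steps = [
    {"add_offset": {"offset": 2}},
    {"scale": {"factor": 10}},
    {"add_offset": {}},              # empty kwargs: the step falls back to defaults
]
result = run_steps({"value": 1}, steps, extra_funcs=[add_offset, scale])
print(result)  # {'value': 30}
```

The point to take away is that step order matters and that every custom function must be registered, which is what the `extra_funcs` argument is for.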

The pipeline below is the **standard semp EEG-fMRI preprocessing sequence**.

---

### 4.1  Initialisation, montage, notch, TR crop

- `initialize` — fills `dataset` (Step 3 above)
- `set_channel_types` — mark EOG/ECG/EMG channels by name (check your channel list)
- `set_channel_montage` — apply digitised electrode positions (not included in the template config below; add it if you have digitised positions)
- `notch_filter` — remove line noise (50 Hz + harmonics in Europe)
- `crop_TR` — trim the recording to whole TR intervals using the TR trigger

### 4.2  Gradient artifact removal (AAS)

- `create_epoch` — epoch the raw data into TR-length windows
- `epoch_aas` — average artifact subtraction (AAS): subtracts the mean TR template
- `voltage_correction` — rescale EOG/ECG/EMG channels whose unit is mislabelled as V instead of µV (this occurs in the Staresina dataset)
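
The core idea of AAS can be sketched in a few lines: average the TR-locked epochs to obtain an artifact template, then subtract that template from every epoch. This toy version uses a single global template on one channel with plain Python lists. semp's `epoch_aas` takes a `window_length` argument and may well work differently, so treat this as an illustration of the principle only; the function name and the numbers are made up.

```python
def average_artifact_subtraction(epochs):
    """epochs: list of equal-length sample lists, one list per TR."""
    n_epochs = len(epochs)
    # Template = mean across epochs at each sample position within the TR.
    template = [sum(samples) / n_epochs for samples in zip(*epochs)]
    cleaned = [[x - t for x, t in zip(epoch, template)] for epoch in epochs]
    return cleaned, template


# Synthetic data: an identical artifact in every TR plus a small "brain" signal
# that averages to zero across the two epochs.
artifact = [0.0, 50.0, -50.0, 10.0]
brain = [[1.0, -1.0, 0.5, 0.0], [-1.0, 1.0, -0.5, 0.0]]
epochs = [[a + b for a, b in zip(artifact, ep)] for ep in brain]

cleaned, template = average_artifact_subtraction(epochs)
print(template)  # [0.0, 50.0, -50.0, 10.0]
print(cleaned)   # [[1.0, -1.0, 0.5, 0.0], [-1.0, 1.0, -0.5, 0.0]]
```

The recovery is exact here only because the synthetic artifact is perfectly repeatable. Real gradient artifacts drift over time, which is why a residual remains after AAS.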

### 4.3  Bandpass filter + mid-crop + resample

- `filter` — 5th-order Butterworth IIR (0.5–125 Hz). **Use IIR, not FIR**: FIR does not fully attenuate out-of-band frequencies here.
- `mid_crop` — removes 5 s from each edge to prevent filter edge artefacts before resampling
- `resample` — downsample to 250 Hz

### 4.4  Bad segment and channel detection

- Three `bad_segments` passes: amplitude + diff mode for EEG, amplitude for EOG
- `bad_channels` — detect and mark globally bad EEG channels

### 4.5  Slice ICA (residual GA removal)

- `slice_ica` — ICA targeting `1/slice_interval` Hz and its harmonics.  
  AAS leaves a residual because it assumes a stationary template; slice_ica cleans the remainder.

### 4.6  General ICA (pulse & ocular artefacts)

- `ica_raw` — fit ICA on the cleaned broadband signal
- `ica_autoreject` — automatically label components as EOG (correlation), ECG (CTPS, threshold 0.1), and apply rejection
  - If automatic rejection misses artifacts or removes brain components, set `apply=False` here and inspect IC topos/time-series manually before applying.

### 4.7  Final bad channels, interpolation, re-reference

- `bad_channels` — second pass after ICA (ICA sometimes reveals previously hidden bad channels)
- `interpolate_bads` — spherical spline interpolation of marked bad channels
- `set_eeg_reference` — average reference (projection-based)
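
For readers new to re-referencing, an average reference amounts to subtracting, at each time sample, the mean across all EEG channels. The sketch below computes this directly on plain lists. With `projection: True` the reference is instead stored as a projector and applied later, so this shows the arithmetic rather than the mechanism; `average_reference` is a hypothetical name.

```python
def average_reference(data):
    """data: list of channels, each a list of samples of equal length."""
    n_channels = len(data)
    # Mean across channels, computed separately for every time sample.
    mean_per_sample = [sum(col) / n_channels for col in zip(*data)]
    return [[x - m for x, m in zip(ch, mean_per_sample)] for ch in data]


data = [[1.0, 2.0], [3.0, 6.0], [5.0, 10.0]]   # 3 channels x 2 samples
reref = average_reference(data)
print(reref)  # [[-2.0, -4.0], [0.0, 0.0], [2.0, 4.0]]
```

After re-referencing, the channels sum to zero at every sample, which is also why bad channels should be interpolated first: a noisy channel would otherwise leak into all the others.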

---

### Full `1.prep.py`

In [None]:
# ── 1.prep.py (copy to your project folder and adjust the constants) ─────────

import numpy as np
from pathlib import Path
from osl_ephys.preprocessing import run_proc_batch

from pathfinder import MyDatasetPathfinder
from helpers import initialize

from semp.eeg import (
    crop_TR, epoch_aas, epoch_obs, create_epoch,
    ckpt_report, slice_ica, init_tracer, summary,
    mid_crop, set_channel_type_raw, voltage_correction,
)

# ─────────────────────────────────────────────────────────────────────────────
# Global settings — change these for your dataset
# ─────────────────────────────────────────────────────────────────────────────
TR             = 1.14          # fMRI repetition time (seconds)
SLICE_INTERVAL = 0.07          # fMRI slice interval (seconds) — TR / n_slices
TR_EVENT_KEY   = ['TR_LABEL']  # annotation label(s) for TR trigger
HE_EVENT_KEY   = []            # helium pump label(s); set [] to skip
TARGET_PTH     = Path("/path/to/output/after_prep")

continue_interrupt = True      # skip subjects already finished or errored

pf = MyDatasetPathfinder()

# ─────────────────────────────────────────────────────────────────────────────
# Preprocessing config
# ─────────────────────────────────────────────────────────────────────────────
config = {
    'preproc': [

        # ── 4.1  Init, montage, notch, TR crop ────────────────────────────────
        {'initialize': {
            'target_pth':     TARGET_PTH,
            'pf':             pf,
            'tr_interval':    TR,
            'slice_interval': SLICE_INTERVAL,
            'tr_event_key':   TR_EVENT_KEY,
            'he_event_key':   HE_EVENT_KEY,
        }},
        {'init_tracer': {}},
        # Assign the correct channel types to the EOG/ECG/EMG channels (names → types):
        {'set_channel_types': {'VEOG': 'eog', 'HEOG': 'eog', 'EKG': 'ecg', 'EMG': 'emg'}},
        {'notch_filter': {'freqs': '50 100'}},
        {'crop_TR': {'tmin': 0, 'TR': TR}},
        {'ckpt_report': {'ckpt_name': 'raw', 'focus_range': [0, 10], 'dB': False}},

        # ── 4.2  Gradient artifact removal (AAS) ──────────────────────────────
        {'create_epoch': {'event': 'TR', 'tmin': 0, 'tmax': TR, 'correct_trig': True}},
        {'epoch_aas': {'epoch_key': 'tr_ep', 'overwrite': 'new',
                       'picks': 'all', 'window_length': 30, 'fit': False}},
        {'voltage_correction': {}},
        {'ckpt_report': {'ckpt_name': 'after_aas_removal',
                         'key_to_print': 'tr_ep', 'dB': False}},

        # ── 4.3  Bandpass + mid-crop + resample ───────────────────────────────
        # IIR Butterworth preferred over FIR: better out-of-band attenuation here.
        {'filter': {'l_freq': 0.5, 'h_freq': 125,
                    'method': 'iir', 'iir_params': {'order': 5, 'ftype': 'butter'}}},
        # Remove 5 s edges to prevent filter ringing before downsampling.
        {'mid_crop': {'edge': 5}},
        {'resample': {'sfreq': 250}},
        {'ckpt_report': {'ckpt_name': 'after_filt', 'dB': False}},

        # ── 4.4  Bad segment & channel detection ──────────────────────────────
        {'bad_segments': {'segment_len': 500, 'picks': 'eeg',
                          'significance_level': 0.1, 'detect_zeros': False}},
        {'bad_segments': {'segment_len': 500, 'picks': 'eeg', 'mode': 'diff',
                          'significance_level': 0.1, 'detect_zeros': False}},
        {'bad_channels': {'picks': 'eeg', 'significance_level': 0.1}},
        {'bad_segments': {'segment_len': 2500, 'picks': 'eog', 'detect_zeros': False}},

        # ── 4.5  Slice ICA — residual gradient artifact ────────────────────────
        # Targets 1/slice_interval Hz and its harmonics.
        # Required because AAS assumes a stationary artifact template.
        {'slice_ica': {}},
        {'ckpt_report': {'ckpt_name': 'after_bads_trica', 'dB': False}},

        # ── 4.6  General ICA — pulse artifact, EOG, ECG ───────────────────────
        # n_components=0.999 retains 99.9 % of EEG variance.
        # ica_autoreject: EOG by correlation (0.35), ECG by CTPS (0.1).
        # For a harder dataset or finer control, set apply=False and inspect
        # component topographies and time-series before deciding what to reject.
        {'ica_raw': {'n_components': 0.999, 'picks': 'eeg', 'l_freq': 1}},
        {'ica_autoreject': {'eogmeasure': 'correlation', 'eogthreshold': 0.35,
                            'ecgmethod': 'ctps', 'ecgthreshold': 0.1, 'apply': True}},

        # ── 4.7  Final bad channels, interpolation, re-reference ──────────────
        {'bad_channels': {'picks': 'eeg', 'significance_level': 0.1}},
        {'interpolate_bads': {}},
        {'ckpt_report': {'ckpt_name': 'after_ica', 'dB': False}},
        {'set_eeg_reference': {'projection': True}},
        {'summary': {}},
    ]
}

# ─────────────────────────────────────────────────────────────────────────────
# Build subject/file lists (skip already-finished or errored subjects)
# ─────────────────────────────────────────────────────────────────────────────
subject_list = list(pf.keys())
file_list    = [pf[s]['raw'] for s in subject_list]

if continue_interrupt:
    finished = {p.parts[-2] for p in TARGET_PTH.glob('*/*_preproc-raw.fif')}
    errored  = {p.parts[-1].split('_')[0] for p in TARGET_PTH.glob('logs/*.error.log')}

    pairs = [(s, f) for s, f in zip(subject_list, file_list)
             if s not in finished and s not in errored]
    if pairs:
        subject_list, file_list = zip(*pairs)
    else:
        subject_list, file_list = [], []

# ─────────────────────────────────────────────────────────────────────────────
# Run
# ─────────────────────────────────────────────────────────────────────────────
run_proc_batch(
    config, list(file_list),
    subjects=list(subject_list),
    outdir=str(TARGET_PTH),
    gen_report=False,
    overwrite=True,
    extra_funcs=[
        initialize, init_tracer, crop_TR, set_channel_type_raw,
        create_epoch, epoch_aas, epoch_obs, voltage_correction,
        ckpt_report, mid_crop, slice_ica, summary,
    ],
)
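
The resume logic in the script above can be tried out in isolation before running a long batch. The sketch below rebuilds the same two glob-based sets inside a temporary directory; `pending_ids` is a hypothetical helper, and the file and log names are made up to match the glob patterns used in the script.

```python
import tempfile
from pathlib import Path


def pending_ids(all_ids, target_pth: Path):
    """Return the ids that have neither a finished output nor an error log."""
    finished = {p.parent.name for p in target_pth.glob("*/*_preproc-raw.fif")}
    errored = {p.name.split("_")[0] for p in target_pth.glob("logs/*.error.log")}
    return [fid for fid in all_ids if fid not in finished and fid not in errored]


with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp)
    # "0111" finished: its output file exists in its own folder.
    (out / "0111").mkdir()
    (out / "0111" / "0111_preproc-raw.fif").touch()
    # "0121" errored: an error log exists (log name assumed for illustration).
    (out / "logs").mkdir()
    (out / "logs" / "0121_preproc-raw.error.log").touch()

    todo = pending_ids(["0111", "0121", "0211"], out)

print(todo)  # ['0211']
```

Because errored recordings are skipped as well, delete the corresponding `*.error.log` file once you have fixed the cause, or the recording will never be retried.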