In [145]:
import os
import mne
import json
import pandas as pd
from pathlib import Path
import re
import matplotlib.pyplot as plt
import numpy as np
import warnings

# Silence every warning whose message matches ".*boundary.*data discontinuities.*".
# NOTE(review): presumably emitted by mne.io.read_raw_eeglab when a .set file contains
# EEGLAB 'boundary' events — confirm; the underlying discontinuities are still present.
warnings.filterwarnings("ignore", message=".*boundary.*data discontinuities.*")

In [None]:
class EEGSubject:
    """One subject's EEG recordings and BIDS sidecar tables, with plotting helpers.

    Raw recordings are kept in ``self._raw_data`` keyed by ``(task, run)``; the
    eyes-open / eyes-closed parts of a ``RestingState`` recording are added under
    ``(task, run, 'open')`` and ``(task, run, 'closed')``.
    """

    def __init__(self, subject_id, data_directory, auto_load=True, l_freq=None, h_freq=60):
        """Initialize an EEGSubject with ID and data directory.

        Args:
            subject_id: Subject label as it appears after ``sub-`` in file names.
            data_directory: Directory containing this subject's ``*_eeg.set`` files
                and their sidecars.
            auto_load: When True, discover every task/run, load its raw data and
                sidecars, and split ``RestingState`` runs into eye conditions.
            l_freq: Stored low cut-off (Hz); not applied by ``_load_raw``.
            h_freq: Stored high cut-off (Hz); not applied by ``_load_raw``.
        """
        self.subject = subject_id
        self._data_dir = Path(data_directory)
        self._task_info = []
        # Sidecar data from the most recent _load_metadata_and_annotations call that
        # found the file (kept for backward compatibility). Per-(task, run) copies live
        # in self._annotations so one task's tables never stand in for another's.
        self.eeg_metadata = {}
        self.events = None
        self.channels = None
        self.electrodes = None
        self._annotations = {}
        self._raw_data = {}
        self._l_freq = l_freq
        self._h_freq = h_freq
        # Frequency bands in Hz; each band covers low <= f < high.
        self.bands = {
            "Delta": (1, 4),
            "Theta": (4, 8),
            "Alpha": (8, 13),
            "Beta": (13, 30),
            "Gamma": (30, 50)
        }
        if auto_load:
            self._find_tasks()
            for info in self._task_info:
                task, run = info['task'], info['run']
                self._load_raw(task, run)
                self._load_metadata_and_annotations(task, run)
                if task == 'RestingState':
                    self._extract_eye_conditions(task=task, run=run)

    def _get_file(self, task, ext, run=None):
        """Return ``<data_dir>/sub-<id>_task-<task>[_run-<run>]_<ext>``."""
        name = f"sub-{self.subject}_task-{task}"
        if run:
            name += f"_run-{run}"
        return self._data_dir / f"{name}_{ext}"

    def _find_tasks(self):
        """Discover (task, run) pairs from ``sub-<id>_task-<task>[_run-<n>]_eeg.set`` files.

        Appends ``{"task": str, "run": str or None}`` dicts to ``self._task_info`` in
        sorted file-name order.
        """
        subject = re.escape(str(self.subject))
        # A single backslash in the raw f-string: "\\d" would match a literal backslash,
        # so the run group never matched and "_run-N" ended up inside the task name.
        pattern = re.compile(rf"sub-{subject}_task-(?P<task>.+?)(?:_run-(?P<run>\d+))?_eeg\.set")
        for file in sorted(self._data_dir.glob(f"sub-{self.subject}_task-*_eeg.set")):
            match = pattern.fullmatch(file.name)
            if match:
                self._task_info.append({
                    "task": match.group("task"),
                    "run": match.group("run")
                })

    def _load_raw(self, task, run=None):
        """Read one EEGLAB recording, drop 'Cz' if present, and set the GSN-HydroCel-128 montage."""
        eeg_path = self._get_file(task, 'eeg.set', run=run)
        raw = mne.io.read_raw_eeglab(eeg_path, preload=True, montage_units='cm')
        if 'Cz' in raw.ch_names:
            raw.drop_channels(['Cz'])
        raw.set_montage(mne.channels.make_standard_montage('GSN-HydroCel-128'), match_case=False)
        # Filtering is currently disabled; re-enable to apply the stored cut-offs:
        # raw.filter(l_freq=self._l_freq, h_freq=self._h_freq)
        self._raw_data[(task, run)] = raw

    def _extract_eye_conditions(self, task='RestingState', run=None):
        """Extract and store concatenated eyes-open and eyes-closed segments.

        Walks ``instructed_toOpenEyes`` / ``instructed_toCloseEyes`` events in table
        order. Each interval from an instruction to the next opposite instruction is
        cropped from the raw recording. The cropped pieces are concatenated with
        ``mne.concatenate_raws`` and stored under ``(task, run, 'open')`` and
        ``(task, run, 'closed')``. Does nothing if the raw data or events are missing.
        """
        raw = self._raw_data.get((task, run))
        events = self._annotation('events', task, run)
        if raw is None or events is None:
            return

        instruction_to_state = {
            'instructed_toOpenEyes': 'open',
            'instructed_toCloseEyes': 'closed',
        }
        segments = {'open': [], 'closed': []}
        eyes_state = None
        last_onset = None
        # NOTE(review): the interval after the final instruction is not extracted — confirm intended.
        for _, row in events.iterrows():
            new_state = instruction_to_state.get(row['value'])
            if new_state is None:
                continue
            if eyes_state is not None and eyes_state != new_state and last_onset is not None:
                segments[eyes_state].append(raw.copy().crop(tmin=last_onset, tmax=row['onset']))
            eyes_state = new_state
            last_onset = row['onset']

        # Both conditions are continuous Raw pieces, so both are concatenated
        # (mne.Epochs cannot be built from a list of Raw objects).
        for condition, pieces in segments.items():
            if pieces:
                self._raw_data[(task, run, condition)] = mne.concatenate_raws(pieces)

    def get_raw(self, task, run=None, condition=None):
        """Return stored data for (task, run[, condition]); raises KeyError if absent."""
        key = (task, run, condition) if condition else (task, run)
        return self._raw_data[key]

    def compute_band_psd(self, raw, fmin=1, fmax=50, bands=None):
        """Compute the PSD of ``raw`` and the per-channel mean power in each band.

        Args:
            raw: Object providing ``compute_psd(fmin=..., fmax=...)`` (e.g. mne Raw).
            fmin, fmax: Frequency range (Hz) passed to ``compute_psd``.
            bands: Mapping ``name -> (low, high)``; defaults to ``self.bands``.

        Returns:
            ``(psd, band_powers, freqs)``. ``band_powers[name]`` is the mean of the PSD
            over bins with ``low <= f < high``, one value per PSD channel.
        """
        if bands is None:
            bands = self.bands
        psd = raw.compute_psd(fmin=fmin, fmax=fmax)
        psd_data, freqs = psd.get_data(return_freqs=True)
        band_powers = {
            band: psd_data[:, (freqs >= low) & (freqs < high)].mean(axis=1)
            for band, (low, high) in bands.items()
        }
        return psd, band_powers, freqs

    def _annotation(self, name, task=None, run=None):
        """Return sidecar ``name`` ('eeg_metadata', 'events', 'channels' or 'electrodes').

        Uses the copy loaded for ``(task, run)`` when that pair was loaded (None if its
        file was missing); otherwise falls back to the same-named instance attribute,
        which holds the most recently loaded value.
        """
        entry = self._annotations.get((task, run))
        if entry is not None:
            return entry.get(name)
        return getattr(self, name)

    def _load_metadata_and_annotations(self, task, run=None):
        """Load eeg.json and events/channels/electrodes.tsv for one task/run, if present.

        Results are stored per ``(task, run)``. Every file that was found is also
        mirrored onto ``eeg_metadata`` / ``events`` / ``channels`` / ``electrodes``.
        """
        entry = {'eeg_metadata': None, 'events': None, 'channels': None, 'electrodes': None}

        json_path = self._get_file(task, 'eeg.json', run=run)
        if json_path.exists():
            with open(json_path, encoding='utf-8') as f:
                entry['eeg_metadata'] = json.load(f)

        for name in ('events', 'channels', 'electrodes'):
            table_path = self._get_file(task, f'{name}.tsv', run=run)
            if table_path.exists():
                entry[name] = pd.read_csv(table_path, sep='\t')

        self._annotations[(task, run)] = entry
        for name, value in entry.items():
            if value is not None:
                setattr(self, name, value)

    def show_annotations(self, task, run=None):
        """Print the eeg.json metadata loaded for ``task``/``run``."""
        metadata = self._annotation('eeg_metadata', task, run)
        print(f"\nMetadata for {self.subject} - {task}:")
        print(json.dumps(metadata if metadata is not None else {}, indent=2))

    def show_table(self, name='events', task=None, rows=10, run=None):
        """Print the first ``rows`` rows of the 'events', 'channels' or 'electrodes' table.

        When ``task``/``run`` have been loaded, that run's table is shown; otherwise
        the most recently loaded table is used.
        """
        df = self._annotation(name, task, run) if name in ('events', 'channels', 'electrodes') else None

        print(f"\n{name.capitalize()} for {self.subject} - {task if task else ''}:")
        if df is not None:
            print(df.head(rows))
        else:
            print(f"No {name} data available.")

    def visualize(self, subjects=None, tasks=None, plots=None):
        """Run the selected plots for each discovered task/run.

        Args:
            subjects: Only plot if this subject's ID is in the collection (None = any).
            tasks: Task names to include (None = all discovered tasks).
            plots: Subset of {'sensors', 'time', 'frequency', 'bandmap'}; None = all.
        """
        if subjects is not None and self.subject not in subjects:
            return

        if not self._task_info:
            print(f"No task info found for subject {self.subject}.")
            return

        def wanted(kind):
            return plots is None or kind in plots

        selected_tasks = [t for t in self._task_info if (tasks is None or t['task'] in tasks)]
        for info in selected_tasks:
            task = info['task']
            run = info['run']
            print(f"\n--- Visualizing subject: {self.subject} | task: {task}" + (f"  Run: {run}" if run else "") + " ---")
            if wanted('sensors'):
                self.plot_sensors(task, run)
            if wanted('time'):
                self.plot_time_domain(task, run)
            if wanted('frequency'):
                self.plot_frequency_domain(task, run)
            # Same "None means all" rule as the other plots (was a TypeError on None).
            if wanted('bandmap'):
                self.plot_bandwise_psd_comparison(task1=task, run1=run)

    def plot_sensors(self, task, run=None):
        """Plot sensor positions (2D with names, then 3D) for a loaded task/run."""
        raw = self._raw_data.get((task, run))
        if raw is not None:
            raw.plot_sensors(show_names=True)
            raw.plot_sensors('3d')

    def plot_time_domain(self, task, run=None, n_channels=10):
        """Open the blocking time-series browser for a loaded task/run."""
        raw = self._raw_data.get((task, run))
        if raw is not None:
            plot_title = f"{self.subject} - {task}" + (f" (Run {run})" if run else "")
            raw.plot(n_channels=n_channels, title=plot_title, scalings='auto', show=True, block=True)

    def plot_frequency_domain(self, task, run=None, fmin=1, fmax=60, average=True, show=True):
        """Standalone PSD plot for a given task/run; returns the Spectrum, or None if not loaded."""
        key = (task, run)
        raw = self._raw_data.get(key)
        if raw is None:
            print(f"[ERROR] No data for {self.subject} {key}")
            return None
        psd = raw.compute_psd(fmin=fmin, fmax=fmax)
        psd.plot(average=average, spatial_colors=False, dB=True, show=show)
        return psd

    def plot_topomaps(self, task, run=None, condition=None, fmin=1, fmax=50, label=None):
        """Plot one topomap of mean band power per entry in ``self.bands``.

        Raises KeyError if the requested task/run/condition is not loaded.
        """
        raw = self.get_raw(task, run, condition)
        psd, band_powers, _ = self.compute_band_psd(raw, fmin=fmin, fmax=fmax)

        # squeeze=False keeps a 2D axes array even when there is a single band.
        fig, axs = plt.subplots(1, len(band_powers), figsize=(15, 3), squeeze=False)
        for ax, (band, data) in zip(axs[0], band_powers.items()):
            # psd.info lists exactly the channels in the PSD (raw.info may contain more).
            mne.viz.plot_topomap(data, psd.info, axes=ax, show=False, contours=0)
            ax.set_title(f"{band} {label or ''}")
        fig.suptitle(f"{self.subject} - {task} Topomaps")
        fig.tight_layout()
        plt.show()

    def plot_bandwise_psd_comparison(self, task1, task2=None, run1=None, run2=None,
                                     condition1=None, condition2=None, bands=None):
        """Compare PSD and topomaps of two tasks/conditions. Layout is separated for clarity.

        Single-task mode (``task2`` is None):
            row 0 = PSD of input 1, row 1 = its per-band topomaps.
        Two-task mode adds:
            rows 2-3 = PSD and topomaps of input 2,
            row 4 = channel-averaged PSD difference (1 minus 2),
            row 5 = per-channel PSD difference,
            row 6 = per-band topomap difference.

        ``bands`` defaults to ``self.bands``. Prints an error and returns early if an
        input is not loaded, or if the two inputs differ in channels or frequency bins.
        """
        if bands is None:
            bands = self.bands

        def _label(task, condition):
            return f"{task} ({condition or 'full'})"

        def _load(task, run, condition):
            key = (task, run, condition) if condition else (task, run)
            raw = self._raw_data.get(key)
            if raw is None:
                print(f"[ERROR] No data for {self.subject} {key}")
            return raw

        raw1 = _load(task1, run1, condition1)
        if raw1 is None:
            return
        psd1, bands1, freqs = self.compute_band_psd(raw1, bands=bands)

        has_task2 = task2 is not None
        if has_task2:
            raw2 = _load(task2, run2, condition2)
            if raw2 is None:
                return
            psd2, bands2, freqs2 = self.compute_band_psd(raw2, bands=bands)
            # The difference rows subtract arrays element-wise, so both spectra must
            # cover the same channels and frequency bins.
            if psd1.info['ch_names'] != psd2.info['ch_names'] or not np.array_equal(freqs, freqs2):
                print(f"[ERROR] Cannot compare {self.subject}: channels or frequency bins differ")
                return

        n_rows = 7 if has_task2 else 2
        fig = plt.figure(figsize=(16, 3 * n_rows), constrained_layout=True)
        gs = fig.add_gridspec(n_rows, len(bands))

        def _psd_row(row, psd, title):
            ax = fig.add_subplot(gs[row, :])
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=UserWarning)
                psd.plot(axes=ax, show=False)
            ax.set_title(title)

        def _topomap_row(row, band_values, info, suffix=""):
            for col, (band, data) in enumerate(band_values.items()):
                ax = fig.add_subplot(gs[row, col])
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", category=UserWarning)
                    mne.viz.plot_topomap(data, info, axes=ax, show=False, contours=0)
                ax.set_title(f"{band}{suffix}")

        _psd_row(0, psd1, f"{_label(task1, condition1)} - PSD")
        _topomap_row(1, bands1, psd1.info)

        if not has_task2:
            fig.suptitle(f"{self.subject} - Bandwise PSD (Single Task)", fontsize=14)
            plt.show()
            return

        _psd_row(2, psd2, f"{_label(task2, condition2)} - PSD")
        _topomap_row(3, bands2, psd2.info)

        # get_data() returns (n_channels, n_freqs) in SI units (V²/Hz for EEG).
        diff_data = psd1.get_data() - psd2.get_data()

        ax = fig.add_subplot(gs[4, :])
        ax.plot(freqs, diff_data.mean(axis=0), color="darkred")
        ax.axhline(0, color='black', linestyle='--', linewidth=0.5)
        ax.set_title("PSD Difference")
        ax.set_xlabel("Frequency (Hz)")
        ax.set_ylabel("Δ Power (V²/Hz)")

        ax = fig.add_subplot(gs[5, :])
        for ch_data in diff_data:
            ax.plot(freqs, ch_data, linewidth=0.5)
        ax.axhline(0, color='black', linestyle='--', linewidth=0.5)
        ax.set_title(f"{_label(task1, condition1)} minus {_label(task2, condition2)} (PSD per channel)")
        ax.set_xlabel("Frequency (Hz)")
        ax.set_ylabel("Δ Power (V²/Hz)")

        band_diffs = {band: bands1[band] - bands2[band] for band in bands}
        _topomap_row(6, band_diffs, psd1.info, suffix=" Δ")

        fig.suptitle(f"{self.subject} - Bandwise PSD Comparison", fontsize=16)
        plt.show()


In [147]:
class EEGStudy:
    """A set of EEGSubject objects loaded from one shared EEG data directory."""

    def __init__(self, data_directory, subjects=None, auto_load=True):
        """Initialize an EEGStudy, auto-detecting subject IDs from filenames if not provided.

        Args:
            data_directory: Directory holding ``sub-<id>_task-<task>..._eeg.set`` files.
            subjects: Subject IDs to load. When None, every ``sub-<id>`` found in an
                ``*_eeg.set`` file name is used, in sorted order.
            auto_load: Forwarded unchanged to each EEGSubject.
        """
        self.subjects = {}
        self._data_directory = Path(data_directory)

        subject_ids = self._discover_subject_ids() if subjects is None else subjects
        for subject_id in subject_ids:
            self.subjects[subject_id] = EEGSubject(subject_id, self._data_directory, auto_load=auto_load)

    def _discover_subject_ids(self):
        """Return the sorted, de-duplicated subject IDs encoded in ``*_eeg.set`` file names."""
        name_matches = (
            re.match(r"sub-(?P<subject>[^_]+)_task-", path.name)
            for path in self._data_directory.glob("sub-*_task-*_eeg.set")
        )
        return sorted({m.group("subject") for m in name_matches if m})

    def get_subject(self, subject_id):
        """Return the EEGSubject registered as ``subject_id``, or None when unknown."""
        return self.subjects.get(subject_id, None)

    def visualize(self, subjects=None, tasks=None, plots=None):
        """Forward to ``visualize`` of each selected subject (all subjects when None)."""
        for subject_id, subject in self.subjects.items():
            if subjects is not None and subject_id not in subjects:
                continue
            subject.visualize(subjects=subjects, tasks=tasks, plots=plots)

In [148]:
# Define and load the study.
# The data location can be overridden with the EEG_DATA_DIR environment variable
# instead of editing this cell; the default is the original hard-coded path.
data_dir = os.environ.get("EEG_DATA_DIR", "/mount/sub/cmi_bids_R1/eeg")
study = EEGStudy(data_directory=data_dir)

In [149]:
import ipywidgets as widgets
from IPython.display import display, clear_output

# Single-subject visualization panel.
# The subject dropdown and its task-refresh callback use ``viz_`` names: the callbacks
# look these globals up when they run, so a name shared with another cell
# (e.g. ``subject_dropdown``) would silently make this panel read another panel's widget.
ui_output = widgets.Output()

viz_subject_dropdown = widgets.Dropdown(options=list(study.subjects.keys()), description='Subject:')
task_dropdown = widgets.Dropdown(description='Task:')
plot_checkboxes = widgets.SelectMultiple(
    options=['sensors', 'time', 'frequency', 'bandmap'],
    value=('sensors', 'time'),
    description='Plots:'
)


def update_viz_tasks(*args):
    """Refill the task dropdown with the selected subject's tasks, prefixed by 'All'."""
    subj_obj = study.get_subject(viz_subject_dropdown.value)
    if subj_obj is not None and hasattr(subj_obj, '_task_info'):
        tasks = sorted({info['task'] for info in subj_obj._task_info})
        task_dropdown.options = ['All'] + tasks
        task_dropdown.value = 'All'


viz_subject_dropdown.observe(update_viz_tasks, names='value')
update_viz_tasks()


def run_visualization(b):
    """Render the selected plots for the selected subject/task into ``ui_output``."""
    with ui_output:
        clear_output()
        selected_task = None if task_dropdown.value == 'All' else [task_dropdown.value]
        study.visualize(
            subjects=[viz_subject_dropdown.value],
            tasks=selected_task,
            plots=list(plot_checkboxes.value)
        )


run_button = widgets.Button(description='Run Visualization')
run_button.on_click(run_visualization)

control_panel = widgets.VBox([viz_subject_dropdown, task_dropdown, plot_checkboxes, run_button])

display(control_panel, ui_output)


VBox(children=(Dropdown(description='Subject:', options=('NDARAC904DMU',), value='NDARAC904DMU'), Dropdown(des…

Output()

In [150]:
%matplotlib inline

ui_out = widgets.Output()

subject_dropdown = widgets.Dropdown(options=list(study.subjects.keys()), description="Subject")
task1_dropdown = widgets.Dropdown(description="Task 1")
task2_dropdown = widgets.Dropdown(description="Task 2")
cond1_dropdown = widgets.Dropdown(description="Cond 1", options=["", "open", "closed"])
cond2_dropdown = widgets.Dropdown(description="Cond 2", options=["", "open", "closed"])
run1_input = widgets.Text(description="Run 1", placeholder="Optional")
run2_input = widgets.Text(description="Run 2", placeholder="Optional")

def update_tasks(*args):
    subject = study.get_subject(subject_dropdown.value)
    if subject:
        tasks = sorted({info["task"] for info in getattr(subject, "_task_info", [])})
        task1_dropdown.options = tasks
        task2_dropdown.options = tasks

subject_dropdown.observe(update_tasks, names="value")
update_tasks()

def run_comparison(b):
    with ui_out:
        clear_output()
        subj = study.get_subject(subject_dropdown.value)
        run1 = run1_input.value if run1_input.value else None
        run2 = run2_input.value if run2_input.value else None
        cond1 = cond1_dropdown.value or None
        cond2 = cond2_dropdown.value or None
        subj.plot_bandwise_psd_comparison(
            task1=task1_dropdown.value, task2=task2_dropdown.value,
            run1=run1, run2=run2,
            condition1=cond1, condition2=cond2
        )

compare_btn = widgets.Button(description="Compare PSDs")
compare_btn.on_click(run_comparison)

ui_box = widgets.VBox([
    subject_dropdown,
    widgets.HBox([task1_dropdown, cond1_dropdown, run1_input]),
    widgets.HBox([task2_dropdown, cond2_dropdown, run2_input]),
    compare_btn,
    ui_out
])

display(ui_box)

VBox(children=(Dropdown(description='Subject', options=('NDARAC904DMU',), value='NDARAC904DMU'), HBox(children…