# Imports

In [1]:
import os
import mne
from mne import events_from_annotations, create_info, EpochsArray, concatenate_epochs, Epochs
import json
import pandas as pd
from pathlib import Path
import re
import matplotlib.pyplot as plt
import numpy as np
import warnings
from collections import defaultdict

warnings.filterwarnings("ignore", message=".*boundary.*data discontinuities.*")
warnings.filterwarnings("ignore", message="FigureCanvasAgg is non-interactive, and thus cannot be shown")

# Classes

### EEGTaskData

In [2]:
class EEGTaskData:
    def __init__(self, subject, task, run, data_dir):
        self.subject = subject
        self.task = task
        self.run = run
        self.data_dir = data_dir

        self._raw = None
        self._filtered_cache = {}  # key = (l_freq, h_freq)
        self.metadata = {}
        self.events = None
        self.channels = None
        self.electrodes = None

        self._epochs_cache = {}  # key: (l_freq, h_freq) → (epochs, labels)

        self._load()


    def _get_file(self, ext):
        base = f"{self.subject}_task-{self.task}"
        if self.run:
            base += f"_run-{self.run}"
        file_name = f"{base}_{ext}"

        return self.data_dir / f"{self.subject}" / "eeg" / file_name

    def _load(self):
        eeg_path = self._get_file("eeg.set")
        self._raw = mne.io.read_raw_eeglab(eeg_path, preload=True, montage_units='cm')
        montage = mne.channels.make_standard_montage("GSN-HydroCel-128")
        self._raw.drop_channels(['Cz'])
        self._raw.set_montage(montage, match_case=False)
        # self._raw.filter(l_freq=self.l_freq, h_freq=self.h_freq)

        json_path = self._get_file("eeg.json")
        if json_path.exists():
            with open(json_path) as f:
                self.metadata = json.load(f)

        event_path = self._get_file("events.tsv")
        if event_path.exists():
            self.events = pd.read_csv(event_path, sep='\t')

        channels_path = self._get_file("channels.tsv")
        if channels_path.exists():
            self.channels = pd.read_csv(channels_path, sep='\t')

        electrodes_path = self._get_file("electrodes.tsv")
        if electrodes_path.exists():
            self.electrodes = pd.read_csv(electrodes_path, sep='\t')

    def get_filtered_raw(self, l_freq=1, h_freq=50):
        key = (l_freq, h_freq)
        
        # Return cached version if available
        if key in self._filtered_cache:
            return self._filtered_cache[key]

        # Filter and cache
        raw_copy = self._raw.copy().load_data()
        raw_copy.filter(l_freq=l_freq, h_freq=h_freq, fir_design="firwin", skip_by_annotation="edge")
        
        self._filtered_cache[key] = raw_copy
        return raw_copy

    def get_epochs(self, l_freq=1, h_freq=50):
        key = (l_freq, h_freq)
        if key in self._epochs_cache:
            return self._epochs_cache[key]

        if self.task == 'RestingState':
            epochs, labels = self._resting_preprocess(l_freq, h_freq)
        elif self.task == 'surroundSupp':
            epochs, labels = self._Sus_preprocess(l_freq, h_freq)
        else:
            return None, None  # Unsupported task

        if epochs is not None:
            self._epochs_cache[key] = (epochs, labels)

        return epochs, labels

    def _Sus_preprocess(self, tmin=0.0, duration=2.4, l_freq=1, h_freq=50):
        """
        Preprocess surroundSupp task using 'stim_ON' events.
        Epochs are 2.4s long and labeled by background + foreground_contrast + stimulus_cond.
        """
        filtered_raw = self.get_filtered_raw(l_freq=l_freq, h_freq=h_freq)
        df = self.events
        stim_rows = df[df['value'] == 'stim_ON'].copy()

        stim_rows['label'] = stim_rows.apply(
            lambda row: f"bg{int(row['background'])}_fg{row['foreground_contrast']}_stim{int(row['stimulus_cond'])}",
            axis=1
        )

        # Create event_id mapping from existing unique labels
        unique_labels = sorted(stim_rows['label'].unique())
        event_id = {label: idx + 1 for idx, label in enumerate(unique_labels)}
        stim_rows['event_code'] = stim_rows['label'].map(event_id)

        # Build events array
        events_array = np.column_stack([
            stim_rows['sample'].astype(int),
            np.zeros(len(stim_rows), dtype=int),
            stim_rows['event_code'].astype(int)
        ])

        # Epoching
        tmax = tmin + duration
        epochs = Epochs(
            filtered_raw,
            events=events_array,
            event_id=event_id,
            tmin=tmin,
            tmax=tmax,
            baseline=None,
            proj=True,
            preload=True,
            detrend=1
        )
        labels = stim_rows['label'].values
        labels = labels[epochs.selection]
        return epochs, labels

    def _resting_preprocess(self, tmin=0.0, tmax=20.0, l_freq=1, h_freq=50):
        """
        Crop raw based on 'resting_start' to 'break cnt' in events.tsv,
        then epoch using eye condition annotations.
        """
        filtered_raw = self.get_filtered_raw(l_freq=l_freq, h_freq=h_freq)

        # Step 1: Find resting_start and break cnt from TSV
        df = self.events

        t_start = df[df['value'] == 'resting_start']['onset'].values[0]
        t_end = df[df['value'] == 'break cnt']['onset'].values[1]

        # Step 2: Crop raw to this resting window
        filtered_raw.crop(tmin=t_start, tmax=t_end)

        # Step 3: Extract new events from cropped raw's annotations
        events, event_id = events_from_annotations(self._raw)

        eye_event_id = {
            'open': event_id['instructed_toOpenEyes'],
            'close': event_id['instructed_toCloseEyes']
        }

        # Step 4: Create epochs based on eye condition labels
        epochs = Epochs(
            filtered_raw,
            events=events,
            event_id=eye_event_id,
            tmin=tmin,
            tmax=tmax,
            proj=True,
            baseline=None,
            preload=True
        )

        labels = epochs.events[:, -1] - eye_event_id['open']  # 0=open, 1=close
        
        return epochs, labels

    def show_annotations(self):
        return self.metadata if self.metadata else None

    def show_table(self, name='events', rows=10, l_freq=1, h_freq=50):
        df_map = {
            'events': self.events,
            'channels': self.channels,
            'electrodes': self.electrodes
        }

        if name == 'epochs':
            epochs, labels = self.get_epochs(l_freq=l_freq, h_freq=h_freq)
            if epochs is None:
                return None

            info = {
                'n_epochs': len(epochs),
                'n_channels': len(epochs.ch_names),
                'timespan_sec': epochs.times[-1] - epochs.times[0],
                'labels': np.unique(labels) if labels is not None else 'N/A',
                'sampling_rate': epochs.info['sfreq'],
                'duration_per_epoch_sec': epochs.get_data().shape[-1] / epochs.info['sfreq']
            }
            return pd.DataFrame([info])

        df = df_map.get(name)
        return df.head(rows) if df is not None else None

    def get_raw(self):
        return self._raw



### EEGSubjectData

In [3]:
class EEGSubjectData:
    """Index of a BIDS EEG dataset root with a cache of loaded recordings.

    Subjects are the ``sub-*`` directories under ``data_dir``; each subject's
    recordings are discovered from ``<subject>/eeg/*_eeg.set`` file names.
    """

    # e.g. "sub-ABC_task-RestingState_eeg.set" or "sub-ABC_task-surroundSupp_run-1_eeg.set"
    _EEG_FILE_RE = re.compile(
        r"(sub-(?P<subject>[^_]+))_task-(?P<task>[^_]+)(?:_run-(?P<run>\d+))?_eeg\.set"
    )

    def __init__(self, data_dir):
        self._data_dir = Path(data_dir)
        self._subject_ids = self._discover_subjects()
        self._task_index = self._discover_tasks()
        self._cache = {}  # (subject, task, run) -> EEGTaskData

    def _discover_subjects(self):
        """Return the sorted names of the ``sub-*`` directories under the root."""
        return sorted(p.name for p in self._data_dir.glob("sub-*") if p.is_dir())

    def _discover_tasks(self):
        """Map each subject id to its list of ``(task, run)`` pairs (run is None when absent)."""
        task_map = defaultdict(list)
        for subj_dir in self._data_dir.glob("sub-*"):
            eeg_dir = subj_dir / "eeg"
            if not eeg_dir.is_dir():
                continue
            for eeg_file in eeg_dir.glob("sub-*_task-*_eeg.set"):
                match = self._EEG_FILE_RE.fullmatch(eeg_file.name)
                if match:
                    task_map[match.group(1)].append((match.group("task"), match.group("run")))
        return dict(task_map)

    @staticmethod
    def _task_sort_key(entry):
        """Order by task name, then run-less first, then runs numerically."""
        task, run = entry
        # run is None or a digit string (regex group \d+); comparing None with str
        # would raise TypeError, and string order would put "10" before "2".
        return (task, -1 if run is None else int(run))

    def list_subjects(self):
        """Return the sorted subject ids."""
        return self._subject_ids

    def list_tasks(self, subject):
        """Return the subject's ``(task, run)`` pairs; empty list for unknown subjects."""
        return sorted(self._task_index.get(subject, []), key=self._task_sort_key)

    def get_task(self, subject, task, run=None):
        """Return the (cached) EEGTaskData for one recording, loading it on first use."""
        key = (subject, task, run)
        if key not in self._cache:
            self._cache[key] = EEGTaskData(
                subject=subject,
                task=task,
                run=run,
                data_dir=self._data_dir,
            )
        return self._cache[key]

### EEGVisualization

In [4]:
class EEGVisualization:
    """Render MNE plots for a (subject, task, run) recording.

    Data (filtered raw recordings and epochs) is fetched through the
    ``EEGSubjectData`` passed at construction. Matplotlib figures are titled
    and captioned by ``_finalize_figure`` before being shown.
    """

    def __init__(self, subject_data: 'EEGSubjectData'):
        self.data = subject_data

    @staticmethod
    def _recording_label(subject, task, run=None, stimulus=None):
        """Return ``'subject - task[ - stimulus][ (Run r)]'`` for titles and messages."""
        label = f"{subject} - {task}"
        if stimulus:
            label += f" - {stimulus}"
        if run:
            label += f" (Run {run})"
        return label

    @staticmethod
    def _format_caption(caption):
        """Render a caption dict as ``'k = v, ...'``; numbers (not bools) get one decimal."""
        if not caption:
            return ""
        parts = []
        for key, value in caption.items():
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                parts.append(f"{key} = {value:.1f}")
            else:
                parts.append(f"{key} = {value}")
        return ", ".join(parts)

    def _validate_and_crop(self, epochs, tmin, tmax):
        """Return a copy of *epochs* cropped to [tmin, tmax] clipped to the epoch window.

        A ``None`` bound means "use the epoch's own start/end". Returns None when
        the clipped window is empty (tmin >= tmax). *epochs* is not modified.
        """
        start, end = epochs.tmin, epochs.tmax
        tmin_valid = start if tmin is None else max(start, tmin)
        tmax_valid = end if tmax is None else min(end, tmax)
        if tmin_valid >= tmax_valid:
            return None
        return epochs.copy().crop(tmin=tmin_valid, tmax=tmax_valid)

    def _finalize_figure(self, fig, subject, task, run=None, stimulus=None, caption: dict = None, plot_name="EEG Plot"):
        """Resize, title and caption a matplotlib figure, then show it.

        Objects that are not ``matplotlib.figure.Figure`` are left untouched.
        """
        if not isinstance(fig, plt.Figure):
            return

        fig.set_size_inches(15, 12)
        fig.text(0.5, 0.96, plot_name.title(), ha='center', fontsize=18, weight='bold')
        fig.text(0.5, 0.94, self._recording_label(subject, task, run, stimulus), ha='center', fontsize=14)
        caption_line = self._format_caption(caption)
        if caption_line:
            fig.text(0.5, 0.92, caption_line, ha='center', fontsize=11)

        # Leave room above the axes for the three text lines.
        fig.subplots_adjust(top=0.90)
        plt.show()

    def plot_sensors(self, subject, task, run=None, **kwargs):
        """Plot sensor positions (with names) of the filtered recording."""
        l_freq = kwargs.get("l_freq", 1)
        h_freq = kwargs.get("h_freq", 50)

        task_data = self.data.get_task(subject, task, run)
        raw = task_data.get_filtered_raw(l_freq, h_freq)
        raw.plot_sensors(show_names=True)

    def plot_time(self, subject, task, run=None, **kwargs):
        """Plot a window of the filtered continuous signal (kwargs: start, duration, n_channels, l_freq, h_freq)."""
        l_freq = kwargs.get("l_freq", 1)
        h_freq = kwargs.get("h_freq", 50)
        duration = kwargs.get('duration', 10.0)
        start = kwargs.get('start', 0.0)
        n_channels = kwargs.get('n_channels', 10)

        task_data = self.data.get_task(subject, task, run)
        raw = task_data.get_filtered_raw(l_freq, h_freq)

        fig = raw.plot(
            duration=duration,
            start=start,
            n_channels=n_channels,
            scalings='auto',
            show=False,
            block=True
        )

        caption_dict = {"start": start, "duration": duration}
        self._finalize_figure(fig, subject, task, run, caption=caption_dict, plot_name="Time Domain")

    def plot_frequency(self, subject, task, run=None, **kwargs):
        """Plot the PSD of the filtered continuous signal between fmin and fmax."""
        l_freq = kwargs.get("l_freq", 1)
        h_freq = kwargs.get("h_freq", 50)
        fmin = kwargs.get("fmin", 1)
        fmax = kwargs.get("fmax", 60)
        average = kwargs.get("average", True)
        dB = kwargs.get("dB", True)
        spatial_colors = kwargs.get("spatial_colors", False)

        task_data = self.data.get_task(subject, task, run)
        raw = task_data.get_filtered_raw(l_freq, h_freq)

        psd = raw.compute_psd(fmin=fmin, fmax=fmax)
        fig = psd.plot(
            average=average,
            spatial_colors=spatial_colors,
            dB=dB,
            show=False
        )

        caption_dict = {"l_freq": l_freq, "h_freq": h_freq, "fmin": fmin, "fmax": fmax}
        self._finalize_figure(fig, subject, task, run, caption=caption_dict, plot_name="Frequency Domain")

    def plot_conditionwise_psd(self, subject, task, run=None, **kwargs):
        """Plot one PSD figure per epoch condition, optionally restricted to [tmin, tmax]."""
        fmin = kwargs.get("fmin", 1)
        fmax = kwargs.get("fmax", 50)
        tmin = kwargs.get("tmin", None)
        tmax = kwargs.get("tmax", None)
        average = kwargs.get("average", True)
        dB = kwargs.get("dB", True)
        l_freq = kwargs.get("l_freq", 1)
        h_freq = kwargs.get("h_freq", 50)

        task_data = self.data.get_task(subject, task, run)
        epochs, _ = task_data.get_epochs(l_freq=l_freq, h_freq=h_freq)

        if epochs is None:
            print(f"No epochs available for {self._recording_label(subject, task, run)}")
            return

        for condition_name in epochs.event_id:
            condition_epochs = epochs[condition_name]
            if len(condition_epochs) == 0:
                print(f"Skipping condition '{condition_name}' — no valid epochs.")
                continue

            cropped_epochs = self._validate_and_crop(condition_epochs, tmin, tmax)
            if cropped_epochs is None:
                print(f"Skipping {condition_name} — Invalid crop range: tmin={tmin}, tmax={tmax}")
                continue

            psd = cropped_epochs.compute_psd(fmin=fmin, fmax=fmax)
            fig = psd.plot(spatial_colors=True, average=average, dB=dB, show=False)

            # Report the window actually analysed (requested bounds clipped to the epochs).
            caption_dict = {"l_freq": l_freq, "h_freq": h_freq,
                            "tmin": cropped_epochs.tmin, "tmax": cropped_epochs.tmax}
            self._finalize_figure(fig, subject, task, run, condition_name, caption=caption_dict, plot_name="Condition-wise PSD")

    def plot_epochs_or_evoked(self, subject, task, run=None, mode='epochs', **kwargs):
        """Plot epochs (``mode='epochs'``) or their average (``mode='evoked'``).

        A falsy ``stimulus`` kwarg uses all epochs; otherwise only that condition.
        """
        l_freq = kwargs.get("l_freq", 1)
        h_freq = kwargs.get("h_freq", 50)
        stimulus = kwargs.get("stimulus", None)
        n_channels = kwargs.get("n_channels", 20)
        tmin = kwargs.get("tmin", None)
        tmax = kwargs.get("tmax", None)

        task_data = self.data.get_task(subject, task, run)
        epochs, _ = task_data.get_epochs(l_freq=l_freq, h_freq=h_freq)

        if epochs is None:
            print(f"No epochs available for {self._recording_label(subject, task, run)}")
            return

        if stimulus:
            if stimulus not in epochs.event_id:
                print(f"Stimulus '{stimulus}' not found in event_id.")
                return
            epochs = epochs[stimulus]

        cropped_epochs = self._validate_and_crop(epochs, tmin, tmax)
        if cropped_epochs is None:
            print(f"Invalid crop window: tmin={tmin}, tmax={tmax}")
            return

        if mode == 'evoked':
            fig = cropped_epochs.average().plot(show=False)
        else:
            fig = cropped_epochs.plot(events=False, n_channels=n_channels, show=False)

        # Report the window actually plotted (requested bounds clipped to the epochs).
        caption_dict = {"l_freq": l_freq, "h_freq": h_freq,
                        "tmin": cropped_epochs.tmin, "tmax": cropped_epochs.tmax}
        self._finalize_figure(fig, subject, task, run, stimulus, caption=caption_dict, plot_name=mode)


### EEGController

In [5]:
class EEGController:
    """Facade between the UI and the data (EEGSubjectData) / plotting (EEGVisualization) layers."""

    # plot_type -> (visualizer method name, fixed keyword arguments)
    _PLOT_DISPATCH = {
        'time': ('plot_time', {}),
        'sensors': ('plot_sensors', {}),
        'frequency': ('plot_frequency', {}),
        'conditionwise psd': ('plot_conditionwise_psd', {}),
        'epochs': ('plot_epochs_or_evoked', {'mode': 'epochs'}),
        'evoked': ('plot_epochs_or_evoked', {'mode': 'evoked'}),
    }

    def __init__(self, subject_data: 'EEGSubjectData', visualizer: 'EEGVisualization'):
        self.subject_data = subject_data
        self.visualizer = visualizer

    def list_subjects(self):
        """Return the subject ids known to the data layer."""
        return self.subject_data.list_subjects()

    def list_tasks(self, subject):
        """Return the subject's ``(task, run)`` pairs."""
        return self.subject_data.list_tasks(subject)

    def get_event_ids(self, subject, task, l_freq, h_freq, run=None):
        """Return the epoch condition names for a recording, or [] when it has no epochs."""
        task_data = self.subject_data.get_task(subject, task, run)
        epochs, _ = task_data.get_epochs(l_freq=l_freq, h_freq=h_freq)
        if epochs is None or len(epochs) == 0:
            return []
        return list(epochs.event_id)

    def show(self, subject, task, run=None, plot_type='time', **kwargs):
        """Render one plot type, forwarding all keyword arguments to the visualizer.

        Raises ValueError for an unknown ``plot_type``.
        """
        try:
            method_name, fixed_kwargs = self._PLOT_DISPATCH[plot_type]
        except KeyError:
            raise ValueError(
                f"Unknown plot_type {plot_type!r}; expected one of {sorted(self._PLOT_DISPATCH)}"
            ) from None
        # kwargs are forwarded for every plot type so that UI settings (fmin,
        # fmax, l_freq, h_freq, average, dB, ...) also reach sensors/frequency.
        getattr(self.visualizer, method_name)(subject, task, run, **fixed_kwargs, **kwargs)

    def show_annotations(self, subject, task, run=None):
        """Return metadata dict or None."""
        task_data = self.subject_data.get_task(subject, task, run)
        return task_data.show_annotations() if task_data else None

    def show_table(self, subject, task, run=None, name='events', rows=10, l_freq=1, h_freq=50):
        """Return DataFrame or None."""
        task_data = self.subject_data.get_task(subject, task, run)
        return task_data.show_table(name=name, rows=rows, l_freq=l_freq, h_freq=h_freq)

    def get_annotation_df(self, subject, task, run=None):
        """Return the recording's annotations as a DataFrame (onset, duration, description)."""
        task_data = self.subject_data.get_task(subject, task, run)
        # Filtering does not change annotations, so read them from the unfiltered
        # raw instead of building (and caching) a filtered copy just for this.
        annots = task_data.get_raw().annotations
        return pd.DataFrame({
            "onset": annots.onset,
            "duration": annots.duration,
            "description": annots.description
        })

### EEGUI

In [23]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import json

def parse_time_input(text_value):
    """Convert a tmin/tmax text field to a float.

    Blank text or the word 'none' (any case, surrounding whitespace ignored)
    yields None. Any other non-numeric text raises ValueError from float().
    """
    cleaned = text_value.strip()
    if not cleaned or cleaned.lower() == "none":
        return None
    return float(cleaned)

class EEGUI:
    """ipywidgets front end for browsing recordings through an EEGController.

    'Plot' mode renders the chosen plot type with its parameter panel into the
    output area; 'Table' mode prints the eeg.json metadata and one table from
    ``EEGController.show_table``.
    """

    # Plot types whose parameter panel shows the stimulus dropdown. Only these
    # need event ids, which may require loading and epoching the recording.
    _STIMULUS_PLOTS = ('epochs', 'evoked')
    # Dropdown entry meaning "all conditions": its value '' is falsy, and
    # plot_epochs_or_evoked only filters by stimulus when it is truthy.
    _ALL_STIMULI = ('(all)', '')

    def __init__(self, controller: 'EEGController'):
        self.controller = controller
        self._init_widgets()
        self._build_ui()
        self._connect_events()
        self._initialize_state()

    def _init_widgets(self):
        """Create every widget and the parameter containers (nothing is laid out yet)."""
        subjects = sorted(self.controller.list_subjects())

        self.mode_toggle = widgets.ToggleButtons(
            options=['Plot', 'Table'], description='Mode:', layout=widgets.Layout(width='300px')
        )
        self.subject_dropdown = widgets.Dropdown(
            options=subjects, description='Subject:', layout=widgets.Layout(width='250px')
        )
        self.task_dropdown = widgets.Dropdown(description='Task:', layout=widgets.Layout(width='250px'))

        self.plot_type = widgets.ToggleButtons(
            options=['time', 'sensors', 'frequency', 'conditionwise psd', 'epochs', 'evoked'],
            description='Plot:',
            layout=widgets.Layout(width='600px')
        )
        self.stimulus_dropdown = widgets.Dropdown(
            options=[],
            description='Stimulus:',
            layout=widgets.Layout(width='250px')
        )

        # Filter and PSD options
        self.lfreq_float = widgets.FloatText(value=3.0, description='l_freq:', layout=widgets.Layout(width='200px'))
        self.hfreq_float = widgets.FloatText(value=35.0, description='h_freq:', layout=widgets.Layout(width='200px'))
        self.average_check = widgets.Checkbox(value=True, description='Average', indent=False)
        self.db_check = widgets.Checkbox(value=True, description='dB', indent=False)

        # Time-domain controls
        self.duration_float = widgets.FloatText(value=10.0, description='duration:', layout=widgets.Layout(width='200px'))
        self.start_float = widgets.FloatText(value=0.0, description='start:', layout=widgets.Layout(width='200px'))
        self.nchan_int = widgets.IntText(value=10, description='n_channels:', layout=widgets.Layout(width='200px'))

        # Epoch cropping and frequency bounds
        self.tmin_text = widgets.Text(value="1.0", description='tmin:', layout=widgets.Layout(width='200px'))
        self.tmax_text = widgets.Text(value="2.4", description='tmax:', layout=widgets.Layout(width='200px'))
        self.fmin_float = widgets.FloatText(value=1.0, description='fmin:', layout=widgets.Layout(width='200px'))
        self.fmax_float = widgets.FloatText(value=50.0, description='fmax:', layout=widgets.Layout(width='200px'))

        self.plot_button = widgets.Button(description='Plot', button_style='success')
        self.info_button = widgets.Button(description='Show Info', button_style='info')

        self.table_type = widgets.Dropdown(
            options=['events', 'channels', 'electrodes', 'epochs'],
            description='Table:',
            layout=widgets.Layout(width='250px')
        )
        self.rows_int = widgets.IntText(
            value=10,
            description='Rows:',
            layout=widgets.Layout(width='200px')
        )

        self.output = widgets.Output()

        # Containers
        self.filter_controls = widgets.HBox([self.lfreq_float, self.hfreq_float])
        self.time_controls = widgets.HBox([self.duration_float, self.start_float, self.nchan_int])
        self.psd_options = widgets.HBox([self.average_check, self.db_check])
        self.t_controls = widgets.HBox([self.tmin_text, self.tmax_text])
        self.f_controls = widgets.HBox([self.fmin_float, self.fmax_float])
        self.param_box = widgets.VBox([])  # filled per plot type by update_param_inputs
        self.table_controls = widgets.HBox([self.table_type, self.rows_int, self.info_button])

    def _build_ui(self):
        """Stack the widgets into the top-level VBox."""
        self.ui = widgets.VBox([
            self.mode_toggle,
            self.subject_dropdown,
            self.task_dropdown,
            self.plot_type,
            self.param_box,
            self.plot_button,
            self.table_controls,
            self.output
        ])

    def _connect_events(self):
        """Register widget observers and button callbacks."""
        self.mode_toggle.observe(self.update_mode_ui, names='value')
        self.subject_dropdown.observe(self.update_tasks, names='value')
        # Without this, stimulus choices stay those of the previously selected task.
        self.task_dropdown.observe(self._on_task_change, names='value')
        self.plot_type.observe(self.update_param_inputs, names='value')
        self.plot_button.on_click(self.do_plot)
        self.info_button.on_click(self.do_show_info)

    def _initialize_state(self):
        """Select the first subject and sync the dependent widgets."""
        if self.subject_dropdown.options:
            self.subject_dropdown.value = self.subject_dropdown.options[0]
            self.update_tasks()
        self.update_param_inputs()
        self.update_mode_ui()

    def update_mode_ui(self, *args):
        """Show the plot controls in 'Plot' mode and the table controls in 'Table' mode."""
        is_plot = self.mode_toggle.value == 'Plot'
        self.plot_type.layout.display = 'block' if is_plot else 'none'
        self.param_box.layout.display = 'block' if is_plot else 'none'
        self.plot_button.layout.display = 'inline-block' if is_plot else 'none'
        self.table_type.layout.display = 'block' if not is_plot else 'none'
        self.info_button.layout.display = 'inline-block' if not is_plot else 'none'
        self.rows_int.layout.display = 'block' if not is_plot else 'none'

    @staticmethod
    def _task_options(task_keys):
        """Build Dropdown ``(label, (task, run))`` options in the given order.

        The keys are not re-sorted: EEGController.list_tasks already returns
        them sorted, and re-sorting tuples mixing None and str runs can raise.
        """
        return [(f"{task} (run {run})" if run else task, (task, run)) for task, run in task_keys]

    def update_tasks(self, *args):
        """Repopulate the task dropdown for the selected subject."""
        subject = self.subject_dropdown.value
        options = self._task_options(self.controller.list_tasks(subject))
        self.task_dropdown.options = options
        if options:
            self.task_dropdown.value = options[0][1]
        # The (task, run) value may be unchanged across subjects, in which case
        # the task observer does not fire; refresh explicitly.
        self._on_task_change()

    def _on_task_change(self, *args):
        """Refresh stimulus choices for the selected recording when they are displayed."""
        if self.plot_type.value in self._STIMULUS_PLOTS:
            self.update_stimulus_options()

    def update_param_inputs(self, *args):
        """Show the parameter widgets relevant to the selected plot type."""
        plot_mode = self.plot_type.value
        if plot_mode == 'conditionwise psd':
            self.param_box.children = [self.t_controls, self.f_controls, self.filter_controls, self.psd_options]
        elif plot_mode == 'frequency':
            self.param_box.children = [self.f_controls, self.filter_controls, self.psd_options]
        elif plot_mode == 'time':
            self.param_box.children = [self.time_controls, self.filter_controls]
        elif plot_mode == 'epochs':
            self.param_box.children = [self.nchan_int, self.t_controls, self.filter_controls, self.stimulus_dropdown]
            self.update_stimulus_options()
        elif plot_mode == 'evoked':
            self.param_box.children = [self.t_controls, self.filter_controls, self.stimulus_dropdown]
            self.update_stimulus_options()
        else:
            self.param_box.children = []

    def update_stimulus_options(self):
        """Fill the stimulus dropdown with the recording's epoch condition names.

        This may load and epoch the recording. Errors are reported in the output
        area instead of escaping the widget callback.
        """
        selection = self.task_dropdown.value
        if selection is None:
            self.stimulus_dropdown.options = []
            return

        subject = self.subject_dropdown.value
        task, run = selection
        try:
            event_ids = self.controller.get_event_ids(
                subject, task, self.lfreq_float.value, self.hfreq_float.value, run or None
            )
        except Exception as err:  # UI callback boundary: report, keep the widget usable
            event_ids = []
            with self.output:
                print(f"Could not load epochs for {subject} - {task}: {err}")

        options = [(name, name) for name in sorted(event_ids)]
        self.stimulus_dropdown.options = [self._ALL_STIMULI] + options if options else []

    def do_plot(self, _):
        """Plot button callback: collect parameters and render into the output area."""
        with self.output:
            clear_output(wait=True)
            if self.task_dropdown.value is None:
                print("No task selected.")
                return
            subject = self.subject_dropdown.value
            task, run = self.task_dropdown.value

            try:
                tmin = parse_time_input(self.tmin_text.value)
                tmax = parse_time_input(self.tmax_text.value)
            except ValueError:
                print(
                    "tmin/tmax must be a number, empty, or 'None' "
                    f"(got {self.tmin_text.value!r}, {self.tmax_text.value!r})."
                )
                return

            kwargs = {
                'tmin': tmin,
                'tmax': tmax,
                'fmin': self.fmin_float.value,
                'fmax': self.fmax_float.value,
                'l_freq': self.lfreq_float.value,
                'h_freq': self.hfreq_float.value,
                'duration': self.duration_float.value,
                'start': self.start_float.value,
                'n_channels': self.nchan_int.value,
                'average': self.average_check.value,
                'dB': self.db_check.value,
                'stimulus': self.stimulus_dropdown.value,
            }
            self.controller.show(subject, task, run, plot_type=self.plot_type.value, **kwargs)

    def do_show_info(self, _):
        """Show Info button callback: print metadata and display the selected table."""
        with self.output:
            clear_output(wait=True)
            if self.task_dropdown.value is None:
                print("No task selected.")
                return
            subject = self.subject_dropdown.value
            task, run = self.task_dropdown.value
            l_freq = self.lfreq_float.value
            h_freq = self.hfreq_float.value

            metadata = self.controller.show_annotations(subject, task, run)
            print(f"Metadata for {subject} - {task}" + (f" (Run {run})" if run else "") + ":")
            print(json.dumps(metadata, indent=2) if metadata else "No metadata available.")

            table_name = self.table_type.value
            rows = self.rows_int.value
            df = self.controller.show_table(subject, task, run, name=table_name, l_freq=l_freq, h_freq=h_freq, rows=rows)
            print(f"\nTable: {table_name}")
            if df is not None:
                display(df)
            else:
                print("No table data available.")

    def show(self):
        """Display the assembled UI in the notebook."""
        display(self.ui)

# Main

### Load

In [7]:
# HBN-EEG release number; selects which cmi_bids_R<release> BIDS root is indexed.
release = 1
data_dir = f'/mount/NAS-public-dataset/HBN-EEG/cmi_bids_R{release}'
# Indexes subjects and their (task, run) recordings; a recording is only read
# from disk when get_task() first requests it.
subject_data = EEGSubjectData(data_dir)

### UI

In [None]:
# IPython magic (not plain Python): render matplotlib figures inline.
%matplotlib inline
# Wire the layers together: data index -> visualizer -> controller -> widget UI.
visualizer = EEGVisualization(subject_data)
controller = EEGController(subject_data, visualizer)
ui = EEGUI(controller)
ui.show()

VBox(children=(ToggleButtons(description='Mode:', layout=Layout(width='300px'), options=('Plot', 'Table'), val…

IndexError: index 0 is out of bounds for axis 0 with size 0

<MNEBrowseFigure size 800x800 with 4 Axes>

# Debug

In [9]:
# Debug: pick the first subject (sorted order) for ad-hoc inspection.
subjects = sorted(controller.list_subjects())
subject = subjects[0]
subject

'sub-NDARAC904DMU'

In [10]:
# Print this subject's (task, run) recordings with their position; a cell below
# selects a recording by this index.
task_keys = sorted(controller.list_tasks(subject))
for (i, item) in enumerate(task_keys):
    print(i, item)

0 ('DespicableMe', None)
1 ('DiaryOfAWimpyKid', None)
2 ('FunwithFractals', None)
3 ('RestingState', None)
4 ('ThePresent', None)
5 ('contrastChangeDetection', '1')
6 ('contrastChangeDetection', '2')
7 ('contrastChangeDetection', '3')
8 ('seqLearning8target', None)
9 ('surroundSupp', '1')
10 ('surroundSupp', '2')
11 ('symbolSearch', None)


In [11]:
# Guard cell: raises unless the user types 'yes', so a "Run All" stops here
# before the debug cells below, which load and epoch a recording.
confirmation = input("Type 'yes' to continue: ")

if confirmation.lower() != 'yes':
    raise RuntimeError("Cell must be confirmed manually.")

RuntimeError: Cell must be confirmed manually.

In [None]:
# Index 9 refers to the task_keys listing printed above (('surroundSupp', '1')
# for this subject); the index is subject-specific.
task, run = task_keys[9]
task_data = controller.subject_data.get_task(subject,task,run=run)
task_data.show_annotations()

{'PowerLineFrequency': 60,
 'TaskName': 'surroundSupp',
 'EEGChannelCount': 129,
 'EEGReference': 'Cz',
 'RecordingType': 'continuous',
 'RecordingDuration': 293.42,
 'SamplingFrequency': 500,
 'SoftwareFilters': 'n/a'}

In [None]:
# Epoch with get_epochs' default arguments (l_freq=1, h_freq=50) and peek at
# the first ten condition labels in sorted order.
epochs, labels = task_data.get_epochs()
sorted(labels)[:10]

['bg0_fg0.0_stim2',
 'bg0_fg0.0_stim3',
 'bg0_fg0.0_stim3',
 'bg0_fg0.3_stim1',
 'bg0_fg0.3_stim1',
 'bg0_fg0.3_stim1',
 'bg0_fg0.6_stim2',
 'bg0_fg0.6_stim2',
 'bg0_fg0.6_stim3',
 'bg0_fg1.0_stim2']