In [None]:
import mne
import torch
import numpy as np
import matplotlib.pyplot as plt
import time
from tqdm import tqdm

import sys
from pathlib import Path
from collections import deque
from queue import Queue, Empty

from IPython.display import display, clear_output # for Jupyter updates

project_root = Path.cwd().parent
sys.path.append(str(project_root))

from pygedai.GEDAI_stream import gedai_stream, GEDAIStream

import torch
from typing import Optional

In [None]:
def _to_numpy(array_like):
    if array_like is None:
        return None
    if isinstance(array_like, torch.Tensor):
        return array_like.detach().cpu().float().numpy()
    return np.asarray(array_like, dtype=float)

In [None]:
class RollingEEGPlot:
    """Rolling, per-channel line plot comparing raw and cleaned EEG traces.

    One subplot is created per channel. Each subplot holds a raw line, a
    Python-cleaned line and, when ``show_matlab`` is true, a MATLAB-cleaned
    reference line. Call :meth:`update` with the latest window to redraw the
    figure in the notebook output area.
    """

    def __init__(self, fs: float, window_sec: float, n_channels: int, show_matlab: bool = True):
        """Create the figure, one axis per channel, and the empty line artists.

        Args:
            fs: Sampling rate in Hz; used to convert sample indices to seconds.
            window_sec: Width of the visible time window in seconds.
            n_channels: Number of channels (subplots) to draw; clamped to >= 1.
            show_matlab: Whether to add a third line per axis for MATLAB output.
        """
        self.fs = float(fs)
        self.window_sec = float(window_sec)
        self.n_channels = max(int(n_channels), 1)
        self.show_matlab = show_matlab

        # Create figure and axes
        self.fig, axes = plt.subplots(self.n_channels, 1, figsize=(10, 2.2 * self.n_channels), sharex=True)
        if self.n_channels == 1:
            # plt.subplots returns a bare Axes for a single subplot; normalise to a list
            axes = [axes]
        self.axes = axes
        self.lines_raw = []
        self.lines_clean = []
        self.lines_matlab = []

        # vertical-line artists for each axis
        self.vline_artists = [[] for _ in range(self.n_channels)]

        for idx, ax in enumerate(self.axes):
            raw_line, = ax.plot([], [], color="gray", label="Raw EEG", linewidth=1.5)
            clean_line, = ax.plot([], [], color="blue", label="Cleaned (Python)", linewidth=1.0, linestyle="--")
            if self.show_matlab:
                matlab_line, = ax.plot([], [], color="red", label="Cleaned (MATLAB)", linewidth=0.5, linestyle="-")
            else:
                matlab_line = None
            ax.set_ylabel(f"Ch {idx}")
            ax.set_xlim(0.0, max(self.window_sec, 1.0))
            ax.set_ylim(-1.0, 1.0)
            if idx == 0:
                # a single legend on the top axis is enough; all axes share the same styling
                ax.legend(loc="upper right")
            self.lines_raw.append(raw_line)
            self.lines_clean.append(clean_line)
            self.lines_matlab.append(matlab_line)

        self.axes[-1].set_xlabel("Time (s)")
        self.fig.tight_layout()

    def update(self,
               raw_window: torch.Tensor,
               cleaned_window: torch.Tensor,
               matlab_window: Optional[torch.Tensor],
               samples_processed: int,
               recalib_times_sec: Optional[list] = None):
        """Redraw every axis with the latest window and re-display the figure.

        Args:
            raw_window: Raw data indexed as ``[channel, sample]``; only the
                first ``n_channels`` rows are plotted.
            cleaned_window: Cleaned data with the same layout as ``raw_window``.
            matlab_window: Optional reference data with the same layout; it is
                ignored when ``show_matlab`` is false.
            samples_processed: Total number of samples streamed so far; the
                window is assumed to end at this sample.
            recalib_times_sec: Optional times (seconds) at which to draw green
                vertical marker lines when they fall inside the visible window.

        Returns:
            None. Returns early without drawing when the window has no samples.
        """
        raw_np = _to_numpy(raw_window)[: self.n_channels]
        clean_np = _to_numpy(cleaned_window)[: self.n_channels]
        matlab_np = _to_numpy(matlab_window)[: self.n_channels] if (self.show_matlab and matlab_window is not None) else None

        num_samples = raw_np.shape[1]
        if num_samples == 0:
            return

        # Time axis in seconds: the window covers the last `num_samples` samples seen.
        start_sample = max(int(samples_processed) - num_samples, 0)
        time_axis = (np.arange(num_samples) + start_sample) / self.fs
        right_edge = time_axis[-1]
        left_edge = max(right_edge - self.window_sec, 0.0)

        # The set of visible markers is identical for every axis, so compute it once.
        if recalib_times_sec is not None:
            visible_times = [t for t in recalib_times_sec if left_edge <= t <= right_edge]
        else:
            visible_times = []

        for idx, ax in enumerate(self.axes):
            self.lines_raw[idx].set_data(time_axis, raw_np[idx])
            self.lines_clean[idx].set_data(time_axis, clean_np[idx])

            y_values = [raw_np[idx], clean_np[idx]]

            if self.show_matlab:
                if matlab_np is not None:
                    self.lines_matlab[idx].set_data(time_axis, matlab_np[idx])
                    y_values.append(matlab_np[idx])
                else:
                    self.lines_matlab[idx].set_data([], [])
            elif self.lines_matlab[idx] is not None:
                self.lines_matlab[idx].set_data([], [])

            # Keep a non-zero x-span even when the window holds a single sample.
            ax.set_xlim(left_edge, max(right_edge, left_edge + 1.0 / self.fs))

            # Autoscale y from finite samples only: set_ylim rejects NaN/Inf
            # limits, so a single bad sample must not abort the live plot.
            # If nothing is finite, the previous limits are kept.
            y_concat = np.concatenate(y_values)
            finite_values = y_concat[np.isfinite(y_concat)]
            if finite_values.size:
                y_min = float(finite_values.min())
                y_max = float(finite_values.max())
                pad = max((y_max - y_min) * 0.1, 1e-6)
                ax.set_ylim(y_min - pad, y_max + pad)

            # vertical green lines for threshold recalibrations:
            # drop the markers from the previous frame, then draw the visible ones
            for artist in self.vline_artists[idx]:
                artist.remove()
            self.vline_artists[idx] = [
                ax.axvline(t, color="green", linestyle="--", linewidth=1.0)
                for t in visible_times
            ]

        self.fig.tight_layout()

        # Replace the previous notebook output with the refreshed figure.
        clear_output(wait=True)
        display(self.fig)
        plt.pause(0.001)

In [None]:
# Data loading and preprocessing
# Read the artifact-contaminated recording and its MATLAB-cleaned counterpart.
raw_noise = mne.io.read_raw_eeglab(
    "./samples/with_artifacts/empirical_NOISE_EOG_EMG.set",
    preload=True,
)
matlab_cleaned_raw_noise = mne.io.read_raw_eeglab(
    "./samples/matlab_cleaned/cleaned_empirical_NOISE_EOG_EMG.set",
    preload=True,
)

# Apply the same average reference (projection=False) to both recordings.
for recording in (raw_noise, matlab_cleaned_raw_noise):
    recording.set_eeg_reference(ref_channels='average', projection=False, verbose=False)

device = "cpu"

# Sampling rate comes from the artifact recording's metadata.
fs = float(raw_noise.info["sfreq"])

# EEG channels of both recordings as float32 tensors.
noise_eeg = torch.from_numpy(raw_noise.get_data(picks="eeg")).float()
noise_matlab_cleaned = torch.from_numpy(matlab_cleaned_raw_noise.get_data(picks="eeg")).float()

# Leadfield array handed to gedai_stream as `leadfield` in the streaming cells.
leadfield_cov = torch.from_numpy(np.load("./leadfield_calibrated/leadfield4GEDAI_eeg_27ch.npy")).float()

In [None]:
# Requested chunk length. A chunk must hold a whole number of samples, so the
# effective duration is re-derived from the rounded sample count below.
target_chunk_duration_sec = 0.03
samples_per_chunk = max(int(round(fs * target_chunk_duration_sec)), 1)

# Effective chunk duration. Using the nominal value instead would make
# chunk-count-based times drift from the sample-based time axis whenever
# fs * target_chunk_duration_sec is not an integer.
chunk_duration_sec = samples_per_chunk / fs

total_samples = noise_eeg.shape[1]
num_chunks = total_samples // samples_per_chunk
if num_chunks == 0:
    raise ValueError(
        f"Not enough samples to form even a single {chunk_duration_sec:.3f} s chunk "
        f"({samples_per_chunk} samples needed, {total_samples} available)."
    )

# Drop the trailing samples that do not fill a complete chunk.
usable_samples = num_chunks * samples_per_chunk

eeg_for_stream = noise_eeg[:, :usable_samples].contiguous()
matlab_for_stream = noise_matlab_cleaned[:, :usable_samples].contiguous()

samples_per_chunk

In [None]:
# Cut both recordings into consecutive, equally sized chunks.
# Resulting shape: (num_chunks, n_channels, samples_per_chunk)
chunked_layout = (noise_eeg.shape[0], num_chunks, samples_per_chunk)
eeg_stream = eeg_for_stream.view(*chunked_layout).transpose(0, 1).contiguous()
matlab_stream = matlab_for_stream.view(*chunked_layout).transpose(0, 1).contiguous()

# Blocking calls with next()

In [None]:
plt.close("all")
state = None

# --- display settings ---
window_sec = 4.0
plot_channels = min(5, noise_eeg.shape[0])

# number of chunks to push through the stream (here: every available chunk)
max_chunks = eeg_stream.shape[0]

# --- threshold schedule, in seconds ---
initial_threshold_delay = 10.0
threshold_update_interval = 10.0

plotter = RollingEEGPlot(fs=fs, window_sec=window_sec, n_channels=plot_channels, show_matlab=True)

# Times (in seconds) at which the plot draws a recalibration marker:
# the initial delay, then one every update interval until the stream ends.
total_stream_time_sec = max_chunks * chunk_duration_sec
recalibration_times_sec = []
t = initial_threshold_delay
while t <= total_stream_time_sec:
    recalibration_times_sec.append(t)
    t += threshold_update_interval

# Rolling buffers holding roughly `window_sec` worth of chunks for plotting.
window_capacity = max(int(round(window_sec / chunk_duration_sec)), 1)
raw_window_chunks, cleaned_window_chunks, matlab_window_chunks = (
    deque(maxlen=window_capacity) for _ in range(3)
)

# Full per-chunk histories, concatenated after the loop.
raw_history, cleaned_history, matlab_history = [], [], []

stream_options = dict(
    sfreq=fs,
    leadfield=leadfield_cov,
    threshold_update_interval_sec=threshold_update_interval,
    initial_threshold_delay_sec=initial_threshold_delay,
    denoising_strength="auto",
    epoch_size_in_cycles=12.0,
    lowcut_frequency=0.5,
    max_concurrent_chunks=1,
)
eeg_cleaning_stream = gedai_stream(**stream_options)

# Streaming loop: one chunk of `samples_per_chunk` samples per iteration
with eeg_cleaning_stream:
    chunk_pairs = zip(eeg_stream[:max_chunks], matlab_stream[:max_chunks])
    for chunks_done, (chunk, matlab_chunk) in enumerate(chunk_pairs, start=1):
        # simulate real-time arrival of the next chunk
        time.sleep(chunk_duration_sec)

        cleaned = eeg_cleaning_stream.next(chunk)

        chunk_cpu, cleaned_cpu, matlab_cpu = chunk.cpu(), cleaned.cpu(), matlab_chunk.cpu()

        # record each segment in both its full history and its rolling window
        for history, window_chunks, segment in (
            (raw_history, raw_window_chunks, chunk_cpu),
            (cleaned_history, cleaned_window_chunks, cleaned_cpu),
            (matlab_history, matlab_window_chunks, matlab_cpu),
        ):
            history.append(segment)
            window_chunks.append(segment)

        raw_window = torch.cat(list(raw_window_chunks), dim=1)
        cleaned_window = torch.cat(list(cleaned_window_chunks), dim=1)
        matlab_window = torch.cat(list(matlab_window_chunks), dim=1)

        samples_processed = chunks_done * samples_per_chunk
        plotter.update(
            raw_window,
            cleaned_window,
            matlab_window,
            samples_processed,
            recalib_times_sec=recalibration_times_sec,
        )

        if chunks_done % 50 == 0 or chunks_done == max_chunks:
            print(f"Processed {chunks_done} / {max_chunks} chunks", end="\r")

    print("\nStreaming complete.")

# Stitch the per-chunk segments back into continuous (channels, samples) tensors.
cleaned_full = torch.cat(cleaned_history, dim=1)
raw_full = torch.cat(raw_history, dim=1)
matlab_full = torch.cat(matlab_history, dim=1)

# Non-blocking threaded calls

In [None]:
plt.close("all")
state = None

# --- display settings ---
window_sec = 4.0
plot_channels = min(5, noise_eeg.shape[0])

# number of chunks to push through the stream (here: every available chunk)
max_chunks = eeg_stream.shape[0]
# effective chunk duration, derived from the whole-sample chunk size
chunk_duration_sec = samples_per_chunk / fs

# Snap the requested processing window to a whole number of chunks.
processing_window_sec_target = 2.0  # pass the same value to gedai_stream below
chunks_per_block = max(int(round(processing_window_sec_target / chunk_duration_sec)), 1)
samples_per_block = samples_per_chunk * chunks_per_block
processing_window_sec = samples_per_block / fs
window_mismatch = abs(processing_window_sec - processing_window_sec_target)
if window_mismatch > 1e-9:
    print(
        f"Adjusted processing window to {processing_window_sec:.6f}s so it aligns with the chunk size "
        f"(target was {processing_window_sec_target:.6f}s)"
    )

# --- threshold schedule, in seconds ---
initial_threshold_delay = 20.0
threshold_update_interval = 20.0

plotter = RollingEEGPlot(fs=fs, window_sec=window_sec, n_channels=plot_channels, show_matlab=True)

# Times (in seconds) at which the plot draws a recalibration marker.
total_stream_time_sec = max_chunks * chunk_duration_sec
recalibration_times_sec = []
t = initial_threshold_delay
while t <= total_stream_time_sec:
    recalibration_times_sec.append(t)
    t += threshold_update_interval

# Rolling buffers holding roughly `window_sec` worth of chunks for plotting.
window_capacity = max(int(round(window_sec / chunk_duration_sec)), 1)
raw_window_chunks, cleaned_window_chunks, matlab_window_chunks = (
    deque(maxlen=window_capacity) for _ in range(3)
)

# Full per-chunk histories, concatenated once streaming has finished.
raw_history, cleaned_history, matlab_history = [], [], []

# MATLAB reference chunks keyed by chunk index, held until their block is returned.
matlab_cache: dict[int, torch.Tensor] = {}
# Results handed over by the stream callback as (index, cleaned, raw) tuples.
result_queue: Queue[tuple[int, torch.Tensor, torch.Tensor]] = Queue()

def handle_cleaned_chunk(cleaned_chunk: torch.Tensor, chunk_index: int, raw_chunk: torch.Tensor) -> None:
    """Stream callback: queue one cleaned result for ``process_ready_results``.

    Both tensors are detached and moved to the CPU, then the tuple
    ``(chunk_index, cleaned, raw)`` is put on the module-level ``result_queue``.
    """
    cleaned_on_cpu = cleaned_chunk.detach().cpu()
    raw_on_cpu = raw_chunk.detach().cpu()
    result_queue.put((chunk_index, cleaned_on_cpu, raw_on_cpu))

def process_ready_results(block: bool = False) -> None:
    """Drain ``result_queue`` and feed every returned block to the histories and plot.

    Each queued item is ``(index, cleaned, raw)``. ``index`` is treated as a
    block index: the block is taken to start at chunk ``index * chunks_per_block``.
    The matching MATLAB chunks are popped from ``matlab_cache``, the block is
    split back into chunk-sized segments, and each segment is appended to the
    history lists and rolling windows before the plot is refreshed.

    Args:
        block: When true, wait up to 0.05 s for each item before giving up;
            when false, return as soon as the queue is empty.

    Raises:
        ValueError: If a returned block does not contain exactly
            ``samples_per_block`` samples in both the cleaned and raw tensors.
    """
    wait_timeout = 0.05 if block else 0.0

    while True:
        try:
            idx, cleaned_cpu, raw_cpu = result_queue.get(block=block, timeout=wait_timeout)
        except Empty:
            return

        sizes_match = cleaned_cpu.size(1) == samples_per_block and raw_cpu.size(1) == samples_per_block
        if not sizes_match:
            raise ValueError(
                "Processing window returned unexpected sample counts; ensure processing_window_sec aligns with chunk size"
            )

        block_start = idx * chunks_per_block

        # collect the MATLAB reference chunks that belong to this block
        matlab_cpu_full = torch.cat(
            [matlab_cache.pop(block_start + offset).cpu() for offset in range(chunks_per_block)],
            dim=1,
        )

        # cut the aggregated tensors back into chunk-sized segments
        segments = zip(
            cleaned_cpu.split(samples_per_chunk, dim=1),
            raw_cpu.split(samples_per_chunk, dim=1),
            matlab_cpu_full.split(samples_per_chunk, dim=1),
        )

        # chunks_done is the 1-based count of chunks handled so far
        for chunks_done, (clean_seg, raw_seg, matlab_seg) in enumerate(segments, start=block_start + 1):
            raw_history.append(raw_seg)
            cleaned_history.append(clean_seg)
            matlab_history.append(matlab_seg)

            raw_window_chunks.append(raw_seg)
            cleaned_window_chunks.append(clean_seg)
            matlab_window_chunks.append(matlab_seg)

            plotter.update(
                torch.cat(list(raw_window_chunks), dim=1),
                torch.cat(list(cleaned_window_chunks), dim=1),
                torch.cat(list(matlab_window_chunks), dim=1),
                chunks_done * samples_per_chunk,
                recalib_times_sec=recalibration_times_sec,
            )

            if chunks_done == max_chunks:
                print(f"Processed {chunks_done} / {max_chunks} chunks", end="\r")
                print("\nStreaming complete.")
            elif chunks_done % 50 == 0:
                print(f"Processed {chunks_done} / {max_chunks} chunks", end="\r")

        result_queue.task_done()

eeg_cleaning_stream = gedai_stream(
    sfreq=fs,
    leadfield=leadfield_cov,
    threshold_update_interval_sec=threshold_update_interval,
    initial_threshold_delay_sec=initial_threshold_delay,
    denoising_strength="auto",
    epoch_size_in_cycles=12.0,
    lowcut_frequency=0.5,
    max_concurrent_chunks=-1,
    num_workers=4,
    processing_window_sec=processing_window_sec,
    moving_window_chunk_sec=10.0
)

# Streaming loop: one chunk of `samples_per_chunk` samples per iteration.
# Results arrive through the callback; whatever is already queued is consumed
# after each submission without waiting.
with eeg_cleaning_stream:
    for idx, (chunk, matlab_chunk) in enumerate(zip(eeg_stream[:max_chunks], matlab_stream[:max_chunks])):
        # keep the MATLAB reference until the block containing this chunk comes back
        matlab_cache[idx] = matlab_chunk.detach().clone()
        eeg_cleaning_stream.next(chunk, callback=handle_cleaned_chunk)
        process_ready_results(block=False)

# Finish whatever is still in flight.
# The wait is bounded by a stall timeout: if the stream never delivers the
# remaining chunks (for example because max_chunks is not a multiple of
# chunks_per_block and no trailing partial block is returned), an unbounded
# `while` would spin forever.
# NOTE(review): whether the stream emits a trailing partial block on exit is
# not visible from this notebook -- confirm against GEDAIStream.
drain_stall_timeout_sec = 30.0
last_progress_time = time.monotonic()
while len(cleaned_history) < max_chunks:
    chunks_before = len(cleaned_history)
    process_ready_results(block=True)
    if len(cleaned_history) > chunks_before:
        last_progress_time = time.monotonic()
    elif time.monotonic() - last_progress_time > drain_stall_timeout_sec:
        print(
            f"\nNo new results for {drain_stall_timeout_sec:.0f} s; stopping after "
            f"{len(cleaned_history)} / {max_chunks} chunks."
        )
        break

if not cleaned_history:
    raise RuntimeError("The stream returned no cleaned chunks; nothing to concatenate.")

# Stitch the per-chunk segments back into continuous (channels, samples) tensors.
cleaned_full_concurrent = torch.cat(cleaned_history, dim=1)
raw_full_concurrent = torch.cat(raw_history, dim=1)
matlab_full_concurrent = torch.cat(matlab_history, dim=1)

In [None]:
plt.close("all")
state = None
window_sec = 4.0
plot_channels = min(5, noise_eeg.shape[0])

max_chunks = eeg_stream.shape[0]
# Derive the chunk duration here instead of relying on whichever earlier cell
# assigned it last, so this cell gives the same result in any run order.
chunk_duration_sec = samples_per_chunk / fs

initial_threshold_delay = 20.0
threshold_update_interval = 20.0

# The schedule and rolling buffers below mirror the plotting cells; the timing
# loop in this cell does not read them.
total_stream_time_sec = max_chunks * chunk_duration_sec
recalibration_times_sec = []
t = initial_threshold_delay
while t <= total_stream_time_sec:
    recalibration_times_sec.append(t)
    t += threshold_update_interval

window_capacity = max(int(round(window_sec / chunk_duration_sec)), 1)
raw_window_chunks = deque(maxlen=window_capacity)
cleaned_window_chunks = deque(maxlen=window_capacity)
matlab_window_chunks = deque(maxlen=window_capacity)

# Per-chunk wall-clock latency of `next()`, in milliseconds.
cleaning_times_ms = []

eeg_cleaning_stream = gedai_stream(
    sfreq=fs,
    leadfield=leadfield_cov,
    threshold_update_interval_sec=threshold_update_interval,
    initial_threshold_delay_sec=initial_threshold_delay,
    denoising_strength="auto",
    epoch_size_in_cycles=12.0,
    lowcut_frequency=0.5,
    max_concurrent_chunks=1,
    moving_window_chunk_sec=1.0,
    verbose_timing=False
)

with eeg_cleaning_stream:
    chunk_pairs = zip(eeg_stream[:max_chunks], matlab_stream[:max_chunks])
    # A zip has no length, so pass `total` for tqdm to show real progress.
    for chunk, matlab_chunk in tqdm(chunk_pairs, total=max_chunks, desc="Streaming"):
        # perf_counter is the monotonic, high-resolution clock meant for short intervals
        start_t = time.perf_counter()
        cleaned = eeg_cleaning_stream.next(chunk)
        end_t = time.perf_counter()

        state = eeg_cleaning_stream.state
        # Record only chunks seen after the initial threshold exists and for which
        # samples_seen is past the last threshold-update sample.
        if (
            state["initial_threshold_computed"]
            and state["samples_seen"] > state["last_threshold_update_sample"]
        ):
            cleaning_times_ms.append((end_t - start_t) * 1000.0)

print("\nStreaming complete.")

if cleaning_times_ms:
    mean_ms = float(np.mean(cleaning_times_ms))
    median_ms = float(np.median(cleaning_times_ms))
    fig, ax = plt.subplots()
    ax.bar(["mean", "median"], [mean_ms, median_ms])
    ax.set_ylabel("Cleaning time (ms)")
    ax.set_title("GEDAI chunk cleaning latency")
    plt.show()