In [None]:
import sys
import time
from pathlib import Path

import matplotlib.pyplot as plt
import mne
import numpy as np
import seaborn as sns
import torch

# Figure styling applied to every plot in this notebook.
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (12, 8)

# Put the parent of the working directory on the import path; presumably that is
# the repository root holding the local `pygedai` package (TODO confirm if the
# notebook is launched from a different directory).
project_root = Path.cwd().parent
sys.path.append(str(project_root))

from pygedai import gedai, batch_gedai

In [None]:
# Recordings containing artifacts; these are the inputs to the Python GEDAI runs.
WITH_ARTIFACTS_DIR = Path("./samples/with_artifacts")
# Outputs of the MATLAB implementation on the same inputs, used as reference.
MATLAB_CLEANED_DIR = Path("./samples/matlab_cleaned")

raw = mne.io.read_raw_eeglab(WITH_ARTIFACTS_DIR / "artifact_jumps.set", preload=True)
raw_noise = mne.io.read_raw_eeglab(WITH_ARTIFACTS_DIR / "empirical_NOISE_EOG_EMG.set", preload=True)
raw_bad_ch = mne.io.read_raw_eeglab(WITH_ARTIFACTS_DIR / "synthetic_bad_channels.set", preload=True)
raw_bad_hbn = mne.io.read_raw_bdf(WITH_ARTIFACTS_DIR / "contrastChangeDetection_run_eeg.bdf", preload=True)

matlab_cleaned_raw = mne.io.read_raw_eeglab(MATLAB_CLEANED_DIR / "cleaned_artifact_jumps.set", preload=True)
matlab_cleaned_raw_noise = mne.io.read_raw_eeglab(MATLAB_CLEANED_DIR / "cleaned_empirical_NOISE_EOG_EMG.set", preload=True)
matlab_cleaned_raw_bad_ch = mne.io.read_raw_eeglab(MATLAB_CLEANED_DIR / "cleaned_synthetic_bad_channels.set", preload=True)

# Re-reference ALL seven recordings (inputs and MATLAB outputs alike) to the
# common average, in place, so every comparison below uses the same reference.
for r in (raw, matlab_cleaned_raw, matlab_cleaned_raw_noise, matlab_cleaned_raw_bad_ch, raw_noise, raw_bad_ch, raw_bad_hbn):
    r.set_eeg_reference(ref_channels='average', projection=False, verbose=False)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

In [None]:
# GEDAI parameters reused by every run in this notebook.
denoising_strength = "auto"
epoch_size = 1.0  # NOTE(review): units not visible here (presumably seconds) — confirm against gedai docs
sfreq = raw.info["sfreq"]

# artifact_jumps recording: torch tensor for the Python port, numpy array for the MATLAB output.
matlab_cleaned = matlab_cleaned_raw.get_data(picks="eeg")
eeg = torch.from_numpy(raw.get_data(picks="eeg"))

# Leadfield stored as .npy (61 channels per its file name), placed on the compute device.
leadfield = torch.from_numpy(
    np.load("./leadfield_calibrated/leadfield4GEDAI_eeg_61ch.npy")
).to(device)

In [None]:
# Time one GEDAI run on the full artifact_jumps recording.
# perf_counter is the monotonic, high-resolution clock intended for intervals
# (time.time() is wall-clock and may jump). CUDA work is asynchronous, so
# synchronize before stopping the clock or only the kernel launches get timed.
start_time = time.perf_counter()
results = gedai(eeg, sfreq, denoising_strength, epoch_size, leadfield, device=device)
if device == "cuda":
    torch.cuda.synchronize()
end_time = time.perf_counter()
print(results["cleaned"].shape)
end_time - start_time

In [None]:
# Bring the Python result back to host memory as numpy for comparison with MATLAB.
cleaned_tensor = results["cleaned"]
cleaned = cleaned_tensor.cpu().numpy()
(cleaned.shape, matlab_cleaned.shape)

In [None]:
# Compare the Python result with the MATLAB reference. Both are cropped to the
# shorter length first so the element-wise diagnostics can always be computed
# (previously `n` was computed but unused, and a length mismatch crashed with a
# broadcast error before reaching the shape assert). The final assert still
# requires the full shapes to agree.
assert cleaned.shape[0] == matlab_cleaned.shape[0], (
    f"channel count differs: {cleaned.shape[0]} vs {matlab_cleaned.shape[0]}"
)
n = min(cleaned.shape[1], matlab_cleaned.shape[1])
cleaned_common = cleaned[:, :n]
matlab_common = matlab_cleaned[:, :n]

ok = np.allclose(cleaned_common, matlab_common, rtol=1e-6, atol=1e-8, equal_nan=True)
print("allclose:", ok, "shapes:", cleaned.shape, matlab_cleaned.shape)
# nanmax skips NaN differences (np.max would return NaN if any sample is NaN);
# NaN mismatches are already reflected in `ok` above.
print("max abs diff:", float(np.nanmax(np.abs(cleaned_common - matlab_common))))
assert cleaned.shape == matlab_cleaned.shape, "Python and MATLAB outputs differ in length"

In [None]:
def _to_numpy(arr):
    """Return `arr` as numpy; torch tensors are detached, moved to CPU and cast to float32."""
    if isinstance(arr, torch.Tensor):
        return arr.detach().cpu().float().numpy()
    return np.asarray(arr)


def plot_eeg(
    signal,
    matlab_clean_signal,
    raw_signal,
    title: str,
    fs = 100,
    L = 10
):
    """Overlay raw, Python-cleaned and (optionally) MATLAB-cleaned traces, one row per channel.

    Parameters
    ----------
    signal : torch.Tensor or array-like, indexed as [channel, sample]
        Python-cleaned data (blue). Its sample count defines the time axis.
    matlab_clean_signal : torch.Tensor, array-like or None
        MATLAB-cleaned data (red); skipped when None.
    raw_signal : torch.Tensor or array-like, indexed as [channel, sample]
        Uncleaned data (gray).
    title : str
        Figure title; omitted when empty.
    fs : float, default 100
        Sampling rate in Hz used to convert sample indices to seconds. The
        default is only correct for 100 Hz data — pass the recording's sfreq.
    L : int, default 10
        Maximum number of channels (rows) to draw, starting at channel 0;
        clamped to the number of available channels.

    Raises
    ------
    ValueError
        If there is nothing to draw (``L < 1`` or ``signal`` has no channels).
    """
    y = _to_numpy(signal)
    x = _to_numpy(raw_signal)
    y_ = None if matlab_clean_signal is None else _to_numpy(matlab_clean_signal)

    # Honor `L` (it used to be silently overwritten with 5) without asking for
    # more rows than there are channels.
    n_rows = min(int(L), y.shape[0])
    if n_rows < 1:
        raise ValueError("nothing to plot: L and the channel count must both be >= 1")
    t = np.arange(y.shape[1]) / float(fs)

    # squeeze=False keeps `axes` an array even for a single row.
    fig, axes = plt.subplots(n_rows, 1, figsize=(10, 1.8 * n_rows * 3), sharex=True, squeeze=False)
    axes = axes[:, 0]

    for i, ax in enumerate(axes):
        ax.plot(t, x[i], color='gray', label='Raw EEG', linewidth=1)
        ax.plot(t, y[i], color='blue', label='Cleaned By Python', linewidth=1)
        if y_ is not None:
            ax.plot(t, y_[i], color='red', label='Cleaned By Matlab', linewidth=1)
        ax.set_ylabel(f"Channel {i}")
    # One legend is enough: every row uses the same colors and labels.
    axes[0].legend(loc='upper right')

    axes[-1].set_xlabel("Time (s)")
    if title:
        fig.suptitle(title, y=0.995)
    fig.tight_layout()
    plt.show()

In [None]:
# First 250 samples of artifact_jumps; pass the real sampling rate so the time axis is in true seconds
# (plot_eeg otherwise assumes 100 Hz). `eeg` already holds the raw EEG, so it is not re-fetched.
plot_eeg(cleaned[:, :250], matlab_cleaned[:, :250], eeg[:, :250], title="EEG: artifact_jumps", fs=sfreq, L=5)

In [None]:
# Empirical EOG/EMG noise recording: clean with the 27-channel leadfield and overlay the MATLAB output.
noise_eeg = torch.from_numpy(raw_noise.get_data(picks="eeg"))
noise_matlab_cleaned = matlab_cleaned_raw_noise.get_data(picks="eeg")

sfreq_noise = raw_noise.info["sfreq"]
# Passed as a file path here (the first run passed a tensor); gedai evidently accepts both.
leadfield = "./leadfield_calibrated/leadfield4GEDAI_eeg_27ch.npy"
noise_cleaned = gedai(noise_eeg, sfreq_noise, denoising_strength, epoch_size, leadfield, device=device)["cleaned"]
plot_eeg(noise_cleaned[:, :250], noise_matlab_cleaned[:, :250], noise_eeg[:, :250],
         title="EEG: empirical_NOISE_EOG_EMG", fs=sfreq_noise, L=5)

In [None]:
# Synthetic bad-channel recording: clean with the 67-channel leadfield and overlay the MATLAB output.
bad_ch_eeg = torch.from_numpy(raw_bad_ch.get_data(picks="eeg"))
bad_ch_matlab_cleaned = matlab_cleaned_raw_bad_ch.get_data(picks="eeg")

sfreq_bad_ch = raw_bad_ch.info["sfreq"]
leadfield = "./leadfield_calibrated/leadfield4GEDAI_eeg_67ch.npy"
bad_ch_cleaned = gedai(bad_ch_eeg, sfreq_bad_ch, denoising_strength, epoch_size, leadfield, device=device)["cleaned"]
plot_eeg(bad_ch_cleaned[:, :250], bad_ch_matlab_cleaned[:, :250], bad_ch_eeg[:, :250],
         title="EEG: synthetic_bad_channels", fs=sfreq_bad_ch, L=5)

In [None]:
# HBN contrast-change-detection recording (no MATLAB reference available).
hbn_eeg = torch.from_numpy(raw_bad_hbn.get_data(picks="eeg"))

sfreq_hbn = raw_bad_hbn.info["sfreq"]
# map_location places the stored leadfield on the compute device and keeps the
# file loadable on CPU-only machines even if it was saved from a GPU.
leadfield = torch.load("./leadfield_calibrated/leadfield_129ch.pt", map_location=device)
hbn_cleaned = gedai(hbn_eeg, sfreq_hbn, denoising_strength, epoch_size, leadfield, device=device)["cleaned"]
plot_eeg(hbn_cleaned[:, :250], None, hbn_eeg[:, :250], title="EEG: HBN contrastChangeDetection", fs=sfreq_hbn, L=5)

In [None]:
# Longer (1250-sample) view of the HBN result; reuse `hbn_eeg` instead of re-fetching and pass the true sampling rate.
plot_eeg(hbn_cleaned[:, :1250], None, hbn_eeg[:, :1250], title="EEG: HBN contrastChangeDetection", fs=sfreq_hbn, L=5)

In [None]:
# Short-window run: GEDAI on the first 200 samples of a freshly loaded HBN
# recording that has NOT been average-referenced (unlike `raw_bad_hbn`).
# Everything this cell needs is loaded here rather than inherited from earlier cells.
raw_bad_hbn_no_ref = mne.io.read_raw_bdf("./samples/with_artifacts/contrastChangeDetection_run_eeg.bdf", preload=True)
sfreq_hbn = raw_bad_hbn_no_ref.info["sfreq"]
leadfield_hbn = torch.load("./leadfield_calibrated/leadfield_129ch.pt", map_location=device)

tiny_eeg = torch.from_numpy(raw_bad_hbn_no_ref.get_data(picks="eeg")[:, :200])
# Subtracts each channel's median over time (dim=1), i.e. removes per-channel offsets.
# NOTE(review): if a median *reference* across channels was intended, this should be dim=0.
x_ref = tiny_eeg - torch.median(tiny_eeg, dim=1, keepdim=True).values

# Time only the GEDAI call (the file read above is excluded); sync CUDA before stopping the clock.
start_time = time.perf_counter()
x_gedai = gedai(x_ref, sfreq_hbn, denoising_strength, epoch_size, leadfield_hbn, device=device)
hbn_cleaned_short = x_gedai["cleaned"]
if device == "cuda":
    torch.cuda.synchronize()
end_time = time.perf_counter()
# A bare expression mid-cell is never displayed, so print the duration explicitly.
print(f"gedai on {tiny_eeg.shape[1]} samples: {end_time - start_time:.3f} s")

# Overlay against the signal that was actually fed to gedai (not the average-referenced recording).
plot_eeg(hbn_cleaned_short, None, x_ref, title="EEG: HBN first 200 samples", fs=sfreq_hbn, L=5)

In [None]:
# Batched run: identical copies of a short artifact_jumps window must yield
# identical cleaned outputs, which checks batch_gedai for cross-item interference.
# The 61-channel leadfield is reloaded because `leadfield` was reassigned by earlier cells.
sfreq = raw.info["sfreq"]
leadfield = torch.from_numpy(np.load("./leadfield_calibrated/leadfield4GEDAI_eeg_61ch.npy")).to(device)
n_copies = 128
n_samples = 200

start_time = time.perf_counter()
# Slice before copying: cloning the whole recording 128 times only to keep 200 samples was wasteful.
batch = eeg[:, :n_samples].to(device).unsqueeze(0).repeat(n_copies, 1, 1)
print(batch.shape)
out_batch = batch_gedai(batch, sfreq, denoising_strength, epoch_size, leadfield, device=device)
if device == "cuda":
    torch.cuda.synchronize()
end_time = time.perf_counter()
print(end_time - start_time)
# Every item received the same input, so every output must match the first one.
assert torch.allclose(out_batch, out_batch[:1].expand_as(out_batch), rtol=1e-10, atol=1e-12)
# Show the first cleaned item against its actual input (previously item #1's output was labeled "Raw EEG").
plot_eeg(out_batch[0], None, batch[0], title="EEG: batch item 0", fs=sfreq, L=5)
out_batch.shape