In [None]:
# Author : Jules GOMEL - 2025

# Practice computing SSVEP SNR on data from Ladouce et al. (2022),
# "Improving user experience of SSVEP BCI through low amplitude depth and high frequency stimuli design"
# Data available at: https://zenodo.org/record/5907009

# SNR computed as described in Cohen and Gulbinaite (2017),
# "Rhythmic entrainment source separation: Optimizing analyses of neural responses to rhythmic sensory stimulation"


import mne

import numpy as np
import matplotlib.pyplot as plt
import scipy.signal as ss
import pyxdf
import pandas as pd
from meegkit import ress

data_folder = "./data/expe"
data_file = "./data/P1_low.set"

raw = mne.io.read_raw_eeglab(data_file, preload=True)

# Take the sampling rate from the recording itself instead of hard-coding it,
# so the RESS filter and Welch PSD computed later match the file's actual rate.
sfreq = raw.info["sfreq"]

raw.info


In [None]:
def compute_snr(psd, bins, target, neighbor_width=1.0, exclude_width=0.5):
    """
    Compute the Signal-to-Noise Ratio (SNR) at a target frequency, following the
    method described in Cohen & Gulbinaite (2017) (the RESS paper).

    The SNR is the power at the frequency bin closest to ``target`` divided by the
    mean power of the neighboring bins within ``±neighbor_width`` Hz of the target,
    excluding the bins within ``±exclude_width`` Hz of it (both windows inclusive).

    Args:
        psd (numpy.ndarray): Power spectral density values (e.g. from the Welch
            method). Frequency must be the first axis; any further axes (trials,
            channels, components, ...) are handled independently.
        bins (numpy.ndarray): Frequency bins (Hz) matching the first axis of ``psd``.
        target (float): The target frequency (in Hz) at which to compute the SNR.
        neighbor_width (float): Half-width (Hz) of the neighborhood used for the
            noise estimate. Defaults to 1.0.
        exclude_width (float): Half-width (Hz) of the band around the target that
            is left out of the noise estimate. Defaults to 0.5.

    Raises:
        ValueError: If fewer than two bins are given, or the target frequency is
            not within half a frequency step of any bin.
        ValueError: If no neighboring frequencies remain after the exclusion.

    Returns:
        float or numpy.ndarray: The SNR at the target frequency. A scalar for a 1-D
        ``psd``; otherwise an array of shape ``psd.shape[1:]`` (one SNR per column).
    """
    bins = np.asarray(bins, dtype=float)
    psd = np.asarray(psd)

    # The frequency step is needed to decide whether `target` falls on a bin.
    if bins.size < 2:
        raise ValueError("At least two frequency bins are required")
    df = np.mean(np.diff(bins))

    # Candidate bins within half a step of the target; keep the nearest one so a
    # target close to a bin midpoint does not silently pick the farther bin.
    candidates = np.flatnonzero(np.isclose(bins, target, atol=df / 2))
    if candidates.size == 0:
        raise ValueError("Target frequency not found in bins")
    target_idx = candidates[np.argmin(np.abs(bins[candidates] - target))]
    target_power = psd[target_idx]

    # Neighborhood (±neighbor_width Hz) minus the excluded band (±exclude_width Hz).
    in_neighborhood = (bins >= target - neighbor_width) & (bins <= target + neighbor_width)
    in_excluded_band = (bins >= target - exclude_width) & (bins <= target + exclude_width)
    neighbor_idxs = np.flatnonzero(in_neighborhood & ~in_excluded_band)

    if neighbor_idxs.size == 0:
        raise ValueError("No valid neighboring frequencies found")

    # Average over frequency only, so each column keeps its own noise estimate
    # (averaging over every axis would mix columns for a multi-column psd).
    neighbor_power = psd[neighbor_idxs].mean(axis=0)

    return target_power / neighbor_power

In [None]:
# Extract events from the raw annotations: `events` is MNE's event array and
# `event_id` maps each annotation description (here, the stimulation frequency
# as text) to its integer event code.
events, event_id = mne.events_from_annotations(raw)

In [None]:
# Basic preprocessing: notch out 50 Hz and its harmonics (presumably mains noise).
# Only harmonics strictly below Nyquist are kept, so the call remains valid for any
# sampling rate (at 500 Hz this is 50, 100, 150, 200 Hz).
# NOTE: the filter is applied in place, so re-running this cell filters `raw` again.
line_freqs = np.arange(50, raw.info["sfreq"] / 2, 50)
raw.notch_filter(line_freqs)

In [None]:
# Stimulation frequencies (Hz), parsed from the annotation descriptions,
# i.e. the keys of `event_id`, in the same order.
peaks_freqs = list(map(float, event_id))
peaks_freqs

In [None]:
# For each stimulation frequency: epoch its trials, fit a RESS spatial filter
# tuned to that frequency, and compute the SNR of the RESS component's PSD.
nfft = 1000  # Welch segment length, in samples
res = []
for description, event_code in event_id.items():
    try:
        target = float(description)
    except ValueError:
        # Annotations that are not a frequency label are not stimulation
        # conditions; skip them instead of guessing a target frequency.
        continue

    # Select trials by their actual event code rather than assuming codes are
    # 1..n in the same order as the dict keys.
    epochs = mne.Epochs(raw, events, event_id={description: event_code},
                        tmin=0, tmax=2.2, baseline=None,
                        reject=None, preload=True)
    fs = epochs.info["sfreq"]  # use the data's real sampling rate

    # MNE returns (n_epochs, n_channels, n_times); transposing gives
    # (n_times, n_channels, n_epochs) for RESS.
    data = epochs.get_data().T

    # Compute RESS filter and fit_transform data
    r = ress.RESS(sfreq=fs, peak_freq=target, compute_unmixing=True)
    out = r.fit_transform(data)

    # Presumably `out` is (n_times, n_keep=1, n_epochs) — TODO confirm. Flattening
    # the trailing axes into columns keeps a 2-D array even for a single trial,
    # where np.squeeze would collapse it to 1-D.
    out_2d = out.reshape(out.shape[0], -1)

    # Compute PSD
    df = fs / nfft  # frequency resolution
    bins, psd = ss.welch(out_2d, fs, window="hamming", nperseg=nfft,
                         noverlap=nfft // 2, axis=0)
    psd = psd.mean(axis=1)  # average over trials -> 1-D spectrum

    res.append(float(compute_snr(psd, bins, target)))