In [1]:
import matplotlib.pyplot as plt
import mne
import mne_icalabel
import numpy as np
import torch
from braindecode.models import EEGInceptionERP, EEGNet
from mne.preprocessing import ICA
from scipy.fft import rfft, rfftfreq
from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import PCA
from mne.filter import filter_data
from scipy.signal import hilbert
from sklearn.preprocessing import StandardScaler
from scipy.spatial.distance import cdist

Get filename

In [2]:
filename = "user_05_session_1"

Open BDF file

In [3]:
raw = mne.io.read_raw_bdf(f"/Users/hrakol/Desktop/Thesis EEG/{filename}.bdf", preload = True)

Extracting EDF parameters from /Users/hrakol/Desktop/Thesis EEG/user_05_session_1.bdf...
BDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 483327  =      0.000 ...   236.000 secs...


Get Events through trigger(Status) channel

In [4]:
events = mne.find_events(raw, stim_channel = "Status", shortest_event = 1)

Finding events on: Status
Trigger channel Status has a non-zero initial value of 65536 (consider using initial_event=True to detect this event)
Removing orphaned offset at the beginning of the file.
15 events found on stim channel Status
Event IDs: [2]


Remove unused BioSemi channels

In [5]:
keep_ch = raw.ch_names[:32] + ["Status"]
raw = raw.pick(keep_ch)

Deleta data recorded before the experiment starts

In [6]:
first_stim = events[events[:, 2] == 2][0, 0] #find the first code 2 event in frames
last_stim = events[events[:, 2] == 2][-1, 0]
first_stim_time = raw.times[first_stim]
last_stim_time = raw.times[last_stim]
start_time = max(2.0, first_stim_time - 2.0)
end_time = min(raw.times[-1], last_stim_time  + 1.5  # stim duration
                                              + 3.0  # extra time until resting state screen closes
                                              + 0.3) # extra time to be sure    

raw.crop(tmin = start_time, tmax = end_time)

Unnamed: 0,General,General.1
,Filename(s),user_05_session_1.bdf
,MNE object type,RawEDF
,Measurement date,2025-09-15 at 14:02:44 UTC
,Participant,
,Experimenter,Unknown
,Acquisition,Acquisition
,Duration,00:02:54 (HH:MM:SS)
,Sampling frequency,2048.00 Hz
,Time points,354330
,Channels,Channels


Recalculate events

In [7]:
events = mne.find_events(raw, stim_channel = "Status", shortest_event = 1)

Finding events on: Status
15 events found on stim channel Status
Event IDs: [2]


Label Acoustic Stimuli

Create epochs

In [8]:
#Reject = None, to label all the epochs 
epochs = mne.Epochs(raw, events = events, tmin = -0.3, tmax = 2.5, baseline = (None, 0), preload = True, reject = None)
sfreq = raw.info["sfreq"]

Not setting metadata
15 matching events found
Setting baseline interval to [-0.2998046875, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 15 events and 5735 original time points ...
0 bad epochs dropped


Check Magnitute for the frequencies of the Acoustic Stimuli

In [9]:
def check_freq_magnitude(target_freq, epochs, sfreq, number_of_epochs = 5):    
    power = []

    for ep_idx, epoch_data in enumerate(epochs.get_data()): 
        # enumerate adds index 
        # epoch.get_data() : (n_channels, n_times)
        fft_vals = np.fft.rfft(epoch_data, axis = 1)  # FFT along time axis for each channel
    
        freqs = np.fft.rfftfreq(epoch_data.shape[1], d = 1/sfreq)
        # Computes frequency values for the rfft bins above
        
        magnitude = np.abs(fft_vals)
        # |x|

        # Find closest bin to 500 Hz
        bin_idx = np.argmin(np.abs(freqs - target_freq))

        # Average magnitude at 500 Hz across channels
        mean_power = magnitude[:, bin_idx].mean()
        power.append(mean_power)

    power = np.array(power)

    #Select top 5 epochs with biggest magnitude 
    top_5 = np.argsort(power)[-number_of_epochs:][::-1]
    
    return top_5

Hierarchical Clustering Algorithm (Central-Region-Based)

In [10]:
# create the central region 
def central_region_area(curves):
    lower = np.percentile(curves, 25, axis = 0)
    upper = np.percentile(curves, 75, axis = 0)
    return np.sum(upper - lower)  #subtract from the bins under 75% the bins  
                                  #the bins under 25% to find the central region
# compute the distance of the central regions
def cr_distance(X, Y):
    merged = np.vstack([X, Y])
    return central_region_area(merged)

In [11]:
# compute fft curves 
def CR_Hier_clustering(epochs, n_clusters = 3):
    fft_curves = []
    freqs = None

    #Loop through each epoch to get amplitude of fft and the frequency bins
    for epoch_data in epochs.get_data():  # (n_channels, n_times)
        fft_vals = np.fft.rfft(epoch_data, axis = 1)
        magnitude = np.abs(fft_vals)

        # Store frequency bins (same for all epochs)
        if freqs is None:
            freqs = np.fft.rfftfreq(epoch_data.shape[1], d = 1/sfreq)

        # Average across channels to get a single curve
        mean_curve = magnitude.mean(axis = 0)  # (n_freqs,)
        fft_curves.append(mean_curve)

    #All epochs stored together
    fft_curves = np.array(fft_curves)  # shape: (n_epochs, n_freqs)

    #Distance matrix for central area 
    n = len(fft_curves)
    dist_matrix = np.zeros((n, n))
    for i in range(n):
        for j in range(i+1, n):
            dist = cr_distance(fft_curves[i:i+1], fft_curves[j:j+1])
            dist_matrix[i, j] = dist_matrix[j, i] = dist

    clustering = AgglomerativeClustering(
        n_clusters = n_clusters,
        metric = 'precomputed',
        linkage = 'average'
    )
    labels = clustering.fit_predict(dist_matrix)
    
    return labels, fft_curves, freqs

Find which epoch is in each cluster

In [12]:
#Get cluster labels, data after fft and the frequency bins
labels, fft_curves , freqs = CR_Hier_clustering(epochs, 3)

In [13]:
#Dictionary with every epoch in each respective cluster
epochs_in_clusters = {}
for cluster_id in np.unique(labels):
    epochs_in_clusters[cluster_id] = np.where(labels == cluster_id)[0].tolist() # indices of epochs

for cluster_id, epoch_indices in epochs_in_clusters.items():
    print(f"Cluster {cluster_id}: epochs {epoch_indices}")

Cluster 0: epochs [0, 5, 6, 9, 10]
Cluster 1: epochs [1, 3, 4, 7, 8, 11, 12, 13, 14]
Cluster 2: epochs [2]


Find which frequency each cluster corresponds to

In [14]:
# Store powers per cluster in dictonary
mean_power = {cluster_id: {} for cluster_id in np.unique(labels)}

# Get FFT for each cluster from ff_curves
for cluster_id in np.unique(labels):
    idx = np.where(labels == cluster_id)[0]
    cluster_data = fft_curves[idx]  # (n_epochs_in_cluster, n_freqs)

    print(f"Cluster {cluster_id}: epochs {idx.tolist()}")

    for target_freq in [500, 48, 96]:  
        # Find closest bin to target frequencies
        bin_idx = np.argmin(np.abs(freqs - target_freq))
        power_val = cluster_data[:, bin_idx].mean()
        #Mean power of each cluster at each frequency
        mean_power[cluster_id][target_freq] = power_val
        print(f"  Power {target_freq} Hz : {power_val:.6f}")

Cluster 0: epochs [0, 5, 6, 9, 10]
  Power 500 Hz : 0.043568
  Power 48 Hz : 0.483100
  Power 96 Hz : 0.174871
Cluster 1: epochs [1, 3, 4, 7, 8, 11, 12, 13, 14]
  Power 500 Hz : 0.021469
  Power 48 Hz : 0.530994
  Power 96 Hz : 0.118115
Cluster 2: epochs [2]
  Power 500 Hz : 0.021452
  Power 48 Hz : 0.531110
  Power 96 Hz : 0.118078


Assign based on power the 500Hz label and check if any other cluster has almost the same power 

In [15]:
top5_500 = check_freq_magnitude(500, epochs, sfreq)
maximum = mean_power[0][500]
cl_500 = 0
for cluster_id in np.unique(labels):
    if mean_power[cluster_id][500] > maximum:
        maximum = mean_power[cluster_id][500]
        cl_500 = cluster_id
for cluster_id in np.unique(labels):
    if maximum / mean_power[cluster_id][500] < 1.2 and cl_500 != cluster_id: 
        # chose top5_500 as one cluster and recluster the rest of the data
        
        cl_500 = 2

        #all epochs
        all_indices = np.arange(len(epochs))
        # Keep the rest
        keep_indices = np.setdiff1d(all_indices, top5_500)
        epochs_to_cluster = epochs[keep_indices]

        labels, fft_curves , freqs = CR_Hier_clustering(epochs_to_cluster, 2)

        #Dictionary with every epoch in each respective cluster
        epochs_in_clusters = {}
        for cluster_id in np.unique(labels):
            epochs_in_clusters[cluster_id] = np.where(labels == cluster_id)[0].tolist()  # indices of epochs

        for cluster_id, epoch_indices in epochs_in_clusters.items():
            print(f"Cluster {cluster_id}: epochs {epoch_indices}")

        epochs_in_clusters[2] = top5_500
        

Transform all dictionary inputs to ints

In [16]:
epochs_in_clusters = {y: [int(x) for x in v] for y, v in epochs_in_clusters.items()}

Find the frequencies for the rest of the clusters

In [17]:
#Check to find which other cluster contain which frequency stimulis
remaining_cluster = list(set(np.unique(labels)) - {cl_500}) 
cl_other = remaining_cluster[0] #random cluster from the remaining

maximum = mean_power[cl_other][96]
cl_4000 = cl_other
for cluster_id in set(np.unique(labels)) - {cl_500}:
    if mean_power[cluster_id][96] > maximum:
        maximum = mean_power[cluster_id][96]
        cl_4000 = cluster_id

maximum = mean_power[cl_other][48]
cl_2000 = cl_other
for cluster_id in set(np.unique(labels)) - {cl_500}:
    if mean_power[cluster_id][48] > maximum:
        maximum = mean_power[cluster_id][48]
        cl_2000 = cluster_id


#Check which has more power between one another in case they are assigned the same cluster name 
if cl_2000 == cl_4000:
    remaining_cluster = list(set(np.unique(labels)) - {cl_500, cl_2000}) 
    #need cl_other because I can't pass a list to a dict (mean_power[remaining_cluster])
    cl_other = remaining_cluster[0] 

    #The one that drops more if it changes stays the same
    if mean_power[cl_2000][48]/ mean_power[cl_other][48] > mean_power[cl_4000][96]/mean_power[cl_other][96]:
        cl_4000 = list(set(np.unique(labels)) - {cl_500, cl_2000})[0]

    else:
        cl_2000 = list(set(np.unique(labels)) - {cl_500, cl_2000})[0]

Find excess epochs

In [18]:
def excess_epochs(freq, cl_number, epochs_in_clusters, epochs, sfreq, check = "under"):
    #Gets the fist or last from the list until 5 remain
    if check == "under":
        #Get the correct order of the epochs(in terms of power) 
        common = [x for x in check_freq_magnitude(freq, epochs, sfreq, 15) if x in epochs_in_clusters[cl_number]]
        #Select from the bottom so that when I remove them from the list the sum will be 5
        removed_list = common[-(len(epochs_in_clusters[cl_number]) - 5):] 
        removed_list = [int(x) for x in removed_list]
        #Remove them from the list
        epochs_in_clusters[cl_number] = list(set(epochs_in_clusters[cl_number]) - set(removed_list))
    else: 
        #Get the correct order of the epochs(in terms of power)
        common = [x for x in check_freq_magnitude(freq, epochs, sfreq, 15) if x not in epochs_in_clusters[cl_number]]
        #Select from the top so that when I add them to the list the sum will be 5
        removed_list = common[:(5 - len(epochs_in_clusters[cl_number]))]
        removed_list = [int(x) for x in removed_list]
        #Add them to the list
        epochs_in_clusters[cl_number] = list(epochs_in_clusters[cl_number]) + list(removed_list)

    return removed_list

Make each cluster have 5 epochs

In [19]:
#Check cl_500 to get which has the epochs with the epochs with the most 500Hz power
removed_list = []
if len(epochs_in_clusters[cl_500]) > 5:

    removed_list = excess_epochs(500, cl_500, epochs_in_clusters, epochs, sfreq)
    #Check the contents of the rest of the lists
    if len(epochs_in_clusters[cl_2000]) > len(epochs_in_clusters[cl_4000]):
        epochs_in_clusters[cl_4000] = list(epochs_in_clusters[cl_4000]) + list(removed_list) 
        if len(epochs_in_clusters[cl_4000]) > 5:
            removed_list = excess_epochs(96, cl_4000, epochs_in_clusters, epochs, sfreq)
            epochs_in_clusters[cl_2000] = list(epochs_in_clusters[cl_2000]) + list(removed_list)
        elif len(epochs_in_clusters[cl_4000]) < 5:
            removed_list = excess_epochs(48, cl_2000, epochs_in_clusters, epochs, sfreq)
            epochs_in_clusters[cl_4000] = list(epochs_in_clusters[cl_4000]) + list(removed_list)
    #The opposite if len(cl_2000)<len(cl_4000)
    else:
        epochs_in_clusters[cl_2000] = list(epochs_in_clusters[cl_2000]) + list(removed_list)  
        if len(epochs_in_clusters[cl_2000]) > 5:
            removed_list = excess_epochs(48, cl_2000, epochs_in_clusters, epochs, sfreq)
            epochs_in_clusters[cl_4000] = list(epochs_in_clusters[cl_4000]) + list(removed_list)
        elif len(epochs_in_clusters[cl_2000]) < 5:
            removed_list = excess_epochs(96, cl_4000, epochs_in_clusters, epochs, sfreq)
            epochs_in_clusters[cl_2000] = list(epochs_in_clusters[cl_2000]) + list(removed_list)

elif len(epochs_in_clusters[cl_500]) < 5 :
    #Get the epochs with the most power at 500Hz
    removed_list = excess_epochs(500, cl_500, epochs_in_clusters, epochs, sfreq, check = "over")
    #Check where the epochs are from and remove them 
    for ep in removed_list:
        if ep in epochs_in_clusters[cl_2000]:
            epochs_in_clusters[cl_2000].remove(ep)
        else:
            epochs_in_clusters[cl_4000].remove(ep)
    
    #Then check the lengths of the rest of clusters an fix them
    if len(epochs_in_clusters[cl_2000]) > len(epochs_in_clusters[cl_4000]):
        removed_list = excess_epochs(48, cl_2000, epochs_in_clusters, epochs, sfreq)
        epochs_in_clusters[cl_4000] = list(epochs_in_clusters[cl_4000]) + list(removed_list)
    
    #The opposite if len(cl_2000)<len(cl_4000)
    elif len(epochs_in_clusters[cl_2000]) < len(epochs_in_clusters[cl_4000]):
        removed_list = excess_epochs(96, cl_4000, epochs_in_clusters, epochs, sfreq)
        epochs_in_clusters[cl_2000] = list(epochs_in_clusters[cl_2000]) + list(removed_list)

else:
    #Then check the lengths of the rest of clusters an fix them
    if len(epochs_in_clusters[cl_2000]) > len(epochs_in_clusters[cl_4000]):
        removed_list = excess_epochs(48, cl_2000, epochs_in_clusters, epochs, sfreq)
        epochs_in_clusters[cl_4000] = list(epochs_in_clusters[cl_4000]) + list(removed_list)
    
    #The opposite if len(cl_2000) < len(cl_4000)
    elif len(epochs_in_clusters[cl_2000]) < len(epochs_in_clusters[cl_4000]):
        removed_list = excess_epochs(96, cl_4000, epochs_in_clusters, epochs, sfreq)
        epochs_in_clusters[cl_2000] = list(epochs_in_clusters[cl_2000]) + list(removed_list)

In [20]:
print(epochs_in_clusters[cl_500])
print(epochs_in_clusters[cl_2000])
print(epochs_in_clusters[cl_4000])

[0, 5, 6, 9, 10]
[2, 12, 14, 11, 7]
[1, 3, 4, 8, 13]


Change the cluster sequence

In [21]:
epochs_in_clusters[cl_500], epochs_in_clusters[cl_2000], epochs_in_clusters[cl_4000] = epochs_in_clusters[0], epochs_in_clusters[1], epochs_in_clusters[2]
cl_500 = 0
cl_2000 = 1
cl_4000 = 2

Morlet approach 

In [22]:
def extract_freq_neighborhood_features(epochs, center_freqs, neighbor_offset=2, n_neighbors=1, 
                                       t_point=0.05, window=0.02, bw=2.0):
    """
    Extract per-epoch amplitude & cross-channel PLV for stimulus frequencies
    and their neighbors.

    Parameters
    ----------
    epochs : mne.Epochs
        MNE epochs object (n_epochs, n_channels, n_times)
    center_freqs : list of float
        Central stimulus frequencies (e.g., [20, 30, 40])
    neighbor_offset : float
        Distance (Hz) between neighbor frequencies (e.g., 2 Hz)
    n_neighbors : int
        Number of neighbors on each side (1 → ±1 neighbor, 2 → ±2)
    t_point : float
        Time (s) to extract around
    window : float
        Averaging window half-width (seconds)
    bw : float
        Bandwidth (Hz) for bandpass filter around each frequency

    Returns
    -------
    features : ndarray (n_epochs, n_features)
    freq_list : list of float
        Frequencies actually analyzed (centers + neighbors)
    """
    sfreq = epochs.info['sfreq']
    data = epochs.get_data()
    times = epochs.times
    n_epochs, n_channels, n_times = data.shape

    # build frequency list: each central + neighbors
    freq_list = []
    for f in center_freqs:
        for k in range(-n_neighbors, n_neighbors + 1):
            freq_list.append(f + k * neighbor_offset)


    idx = np.argmin(np.abs(times - t_point))
    hw = int(np.round(window * sfreq))
    inds = np.arange(max(0, idx - hw), min(n_times, idx + hw + 1))

    all_feats = []

    for f in freq_list:
        low, high = f - bw/2, f + bw/2

        # filter + hilbert
        filtered = filter_data(data, sfreq, low, high, verbose=False)
        analytic = hilbert(filtered, axis=-1)
        amp = np.abs(analytic)
        phase = np.angle(analytic)

        # average amplitude over time window (keep per-channel)
        amp_win = amp[:, :, inds].mean(axis=2)

        # compute per-epoch PLV across channels in that window
        plv = np.zeros(n_epochs)
        for e in range(n_epochs):
            plv_time = np.abs(np.mean(np.exp(1j * phase[e, :, inds]), axis=0))
            plv[e] = plv_time.mean()
        plv = plv.reshape(-1, 1)

        # concatenate per-frequency features
        feats = np.concatenate([amp_win, plv], axis=1)  # per-epoch features
        all_feats.append(feats)

    features = np.concatenate(all_feats, axis=1)
    return features, freq_list

def select_top_epochs_per_cluster(X, labels, top_n=5):
    """Select top N most representative epochs (closest to centroid) per cluster."""
    selected_indices = []
    unique_labels = np.unique(labels)
    for lbl in unique_labels:
        cluster_indices = np.where(labels == lbl)[0]
        cluster_feats = X[cluster_indices]
        centroid = cluster_feats.mean(axis=0, keepdims=True)
        dists = cdist(cluster_feats, centroid)
        top_idxs = cluster_indices[np.argsort(dists.flatten())[:top_n]]
        selected_indices.extend(top_idxs)
    return np.array(selected_indices)


def cluster_epochs_by_freq_features(epochs, center_freqs=[20, 30, 40],
                                    n_clusters=3, neighbor_offset=2, n_neighbors=1,
                                    t_point=0.05, window=0.02, bw=2.0,
                                    top_n=5, keep_top=True):
    # --- Extract features
    X, freq_list = extract_freq_neighborhood_features(
        epochs, center_freqs=center_freqs,
        neighbor_offset=neighbor_offset, n_neighbors=n_neighbors,
        t_point=t_point, window=window, bw=bw
    )

    # --- Normalize & cluster
    X_scaled = StandardScaler().fit_transform(X)
    clustering = AgglomerativeClustering(n_clusters=n_clusters, linkage='ward')
    labels = clustering.fit_predict(X_scaled)

    # --- Optionally keep top N per cluster
    if keep_top:
        selected_indices = select_top_epochs_per_cluster(X_scaled, labels, top_n=top_n)
        X_top = X[selected_indices]
        labels_top = labels[selected_indices]
    else:
        X_top = X
        labels_top = labels

    return labels_top, X_top, freq_list, selected_indices


In [23]:
labels_new, X_top, freqs_used, selected_idx = cluster_epochs_by_freq_features(
    epochs,
    center_freqs=[500, 96, 48],  # this order will be preserved
    neighbor_offset=2,
    n_neighbors=1,
    t_point=0.05,
    window=0.02,
    bw=2.0,
    top_n=5,      # keep 5 per cluster
    keep_top=True
)

print("Frequencies analyzed:", freqs_used)
print("Selected epoch indices (top 5 per cluster):", selected_idx)


Frequencies analyzed: [498, 500, 502, 94, 96, 98, 46, 48, 50]
Selected epoch indices (top 5 per cluster): [ 3  4 10 13  9  5 11  7  1]


In [24]:
print(labels_new)

[0 0 0 1 1 1 1 1 2]


Phase-locked and Power 

In [25]:
def detect_phase_locked_epochs(epochs, freqs, bw=2.0, t_min=0.0, t_max=0.3,
                               amp_thresh='zscore', plv_thresh=0.7):
    """
    Detect epochs with high amplitude and strong phase-locking at given frequencies.

    Parameters
    ----------
    epochs : mne.Epochs
    freqs : list of float
        Target frequencies (e.g., [48, 96, 500])
    bw : float
        Bandwidth (Hz) for bandpass around each frequency.
    t_min, t_max : float
        Time window (s) post-stimulus to analyze.
    amp_thresh : 'zscore' or float
        Threshold rule for amplitude selection.
    plv_thresh : float
        Minimum PLV to keep epoch.

    Returns
    -------
    selected_epochs : dict
        {freq: indices of epochs passing criteria}
    metrics : dict
        {freq: (amplitudes, plvs)} arrays for inspection
    """
    data = epochs.get_data()
    times = epochs.times
    sfreq = epochs.info['sfreq']
    n_epochs, n_ch, n_times = data.shape

    t_inds = np.where((times >= t_min) & (times <= t_max))[0]

    selected_epochs = {}
    metrics = {}

    for f in freqs:
        # --- narrowband filter around frequency
        filtered = filter_data(data, sfreq, f - bw/2, f + bw/2, verbose=False)

        # --- analytic signal → amplitude + phase
        analytic = hilbert(filtered, axis=-1)
        amp = np.abs(analytic)
        phase = np.angle(analytic)

        # --- average amplitude over time window
        amp_mean = amp[:, :, t_inds].mean(axis=(1, 2))  # per epoch mean

        # --- compute PLV across channels in that window
        plv = np.abs(np.mean(np.exp(1j * phase[:, :, t_inds]), axis=1)).mean(axis=1)

        # --- thresholds
        if amp_thresh == 'zscore':
            thr = amp_mean.mean() + amp_mean.std()
        else:
            thr = amp_thresh

        sel_idx = np.where((amp_mean >= thr) & (plv >= plv_thresh))[0]
        selected_epochs[f] = sel_idx
        metrics[f] = (amp_mean, plv)

        print(f"\nFreq {f} Hz: {len(sel_idx)} epochs selected (thr={thr:.2f}, PLV>{plv_thresh})")

    return selected_epochs, metrics


In [28]:
selected_epochs, metrics = detect_phase_locked_epochs(
    epochs,
    freqs=[500, 96, 48],
    bw=2.0,
    t_min=0.05, t_max=0.25,
    amp_thresh='zscore',
    plv_thresh=0.3
)

for f, idx in selected_epochs.items():
    print(f"{f} Hz → Epoch indices: {idx}")


Freq 500 Hz: 2 epochs selected (thr=0.00, PLV>0.3)

Freq 96 Hz: 3 epochs selected (thr=0.00, PLV>0.3)

Freq 48 Hz: 4 epochs selected (thr=0.00, PLV>0.3)
500 Hz → Epoch indices: [5 9]
96 Hz → Epoch indices: [ 0  1 14]
48 Hz → Epoch indices: [ 1  3  4 14]


In [35]:
import numpy as np
from mne.filter import filter_data
from scipy.signal import hilbert

def find_stimulus_frequency_per_epoch(
    epochs,
    freqs=[48, 96, 500],
    bw=3.0,
    t_min=0.05,
    t_max=0.4,
    plv_weight=0.5,
    n_per_cluster=5
):
    """
    Identify which stimulus frequency best matches each epoch
    based on spectral amplitude + within-epoch phase-locking,
    and keep only the top N per frequency.

    Returns
    -------
    selected_indices : dict
        Mapping from frequency -> list of selected epoch indices
    labels : ndarray (n_epochs,)
        Best-matching frequency index per epoch
    scores : ndarray (n_epochs, n_freqs)
        Combined scores per epoch and frequency
    metrics : dict
        Raw amplitude and PLV arrays for inspection
    """

    data = epochs.get_data()          # (n_epochs, n_channels, n_times)
    times = epochs.times
    sfreq = epochs.info["sfreq"]

    mask = (times >= t_min) & (times <= t_max)
    data = data[:, :, mask]
    n_epochs, n_channels, n_times = data.shape

    amp_all = np.zeros((n_epochs, len(freqs)))
    plv_all = np.zeros_like(amp_all)
    scores = np.zeros_like(amp_all)

    # --- main frequency loop ---
    for f_idx, f0 in enumerate(freqs):
        low, high = f0 - bw, f0 + bw
        filtered = filter_data(data, sfreq, low, high, verbose=False)
        analytic = hilbert(filtered, axis=-1)
        amp = np.abs(analytic)
        phase = np.angle(analytic)

        # mean amplitude over time and channels per epoch
        amp_mean = amp.mean(axis=(1, 2))

        # phase-locking across channels (within-epoch PLV)
        plv = np.abs(np.mean(np.exp(1j * phase), axis=1)).mean(axis=1)

        amp_all[:, f_idx] = amp_mean
        plv_all[:, f_idx] = plv

    # --- normalization ---
    amp_norm = (amp_all - amp_all.min(axis=0)) / (np.ptp(amp_all, axis=0) + 1e-12)
    plv_norm = (plv_all - plv_all.min(axis=0)) / (np.ptp(plv_all, axis=0) + 1e-12)

    # --- combine metrics ---
    scores = (1 - plv_weight) * amp_norm + plv_weight * plv_norm

    # --- get best label for each epoch ---
    labels = np.argmax(scores, axis=1)

    # --- keep only top N per frequency ---
    selected_indices = {}
    for f_idx, f in enumerate(freqs):
        freq_scores = scores[:, f_idx]
        top_indices = np.argsort(freq_scores)[::-1][:n_per_cluster]  # top N
        selected_indices[f] = top_indices.tolist()
        print(f"Top {n_per_cluster} epochs for {f} Hz: {top_indices.tolist()}")

    metrics = {"amplitude": amp_all, "plv": plv_all}

    return selected_indices, labels, scores, metrics


In [37]:
selected, labels, scores, metrics = find_stimulus_frequency_per_epoch(
    epochs,
    freqs=[500, 96, 48],
    bw=3.0,
    t_min=0.05,
    t_max=0.4,
    plv_weight=0.5,
    n_per_cluster=5
)

print("\nSelected top 5 indices per frequency:")
for f, idxs in selected.items():
    print(f"{f} Hz → {idxs}")


Top 5 epochs for 500 Hz: [9, 5, 10, 7, 6]
Top 5 epochs for 96 Hz: [0, 11, 13, 1, 14]
Top 5 epochs for 48 Hz: [1, 14, 3, 0, 6]

Selected top 5 indices per frequency:
500 Hz → [9, 5, 10, 7, 6]
96 Hz → [0, 11, 13, 1, 14]
48 Hz → [1, 14, 3, 0, 6]
