In [62]:
import numpy
import torch

In [63]:
# synthetic data
torch.manual_seed(42)
n_samples = 5000
sample_rate = 256
time = torch.arange(0, n_samples) / sample_rate
base_signal = 0.1 * torch.randn(n_samples)
blink_indices = torch.randint(0, n_samples, size = (5,))
for idx in blink_indices:
    start = idx
    end = min(idx + 50, n_samples)
    base_signal[start:end] += torch.linspace(3, 0, end - start)
eeg_signal = base_signal

In [64]:
# step 1: SSA embedding of the raw EEG signal
def ssa_embedding(signal, window_size=256):
    x = torch.as_tensor(signal, dtype=torch.float32)
    N = x.shape[0]
    X = x.unfold(dimension=0, size=window_size, step=1)

    return X

In [65]:
X = ssa_embedding(eeg_signal, 256)
print("Trajectory matrix shape:", X.shape)

Trajectory matrix shape: torch.Size([4745, 256])


In [66]:
# step 2: feature extraction (four time domain features)
def extract_features(ssa_embedding: torch.Tensor) -> torch.Tensor:
    """
    ssa_embedding: a 2D tensor of shape (L, K),
                   where L is window_size and K is the number of columns
    Returns a 2D tensor of shape (4, K),
    with rows = [energy, mobility, kurtosis, peak2peak]
    and each column corresponds to the features of ssa_embedding[:, k].
    """
    def compute_energy(x: torch.Tensor) -> torch.Tensor:
        """
        energy feature
        """
        return torch.sum(x**2)

    def hjorth_mobility(x: torch.Tensor) -> torch.Tensor:
        """
        hjorty mobility
        """
        dx = x[1:] - x[:-1]
        var_x = torch.var(x, unbiased=True)
        var_dx = torch.var(dx, unbiased=True)
        if var_x == 0:
            return 0.0
        mobility = torch.sqrt(var_dx / var_x)
        return mobility

    def kurtosis(x: torch.Tensor, excess: bool = False) -> torch.Tensor:
        """
        kurtosis, could either use raw or adjusted
        """
        x = x.to(dtype=torch.float32)
        mean_x = torch.mean(x)
        var_x = torch.var(x, unbiased=True)

        if var_x == 0:
            return torch.tensor(0.0)
        
        fourth_moment = torch.mean((x - mean_x) **4)
        raw_kurtosis = fourth_moment / (var_x**2)
        
        if excess:
            return raw_kurtosis - 3
        else:
            return raw_kurtosis
        
    def peak_to_peak(x: torch.Tensor) -> torch.Tensor:
        """
        min max diff
        """
        return torch.max(x) - abs(torch.min(x))

    L, K = ssa_embedding.shape
    features = torch.zeros((4, K), dtype=torch.float32)
    for k in range(K):
        col = ssa_embedding[:, k]
        e = compute_energy(col)
        m = hjorth_mobility(col)
        ku = kurtosis(col)
        p2p = peak_to_peak(col)
        
        # save into resulting matrix
        features[0, k] = e
        features[1, k] = m
        features[2, k] = ku
        features[3, k] = p2p
    
    return features




In [67]:
feature_matrix = extract_features(X)

In [68]:
import numpy as np
from sklearn.cluster import KMeans


In [69]:
# Step 3: Perform k-means clustering on the columns of feature_matrix.
def k_means_clustering(feature_matrix: torch.Tensor, num_clusters: int):
    feature_matrix_np = feature_matrix.T.cpu().numpy().astype(np.float32)
    kmeans = KMeans(n_clusters=num_clusters, random_state=42)
    labels = kmeans.fit_predict(feature_matrix_np)
    # Centroids: shape (num_clusters, num_features)
    centers = kmeans.cluster_centers_
    return labels, centers

In [70]:
labels, centers = k_means_clustering(feature_matrix, 2)

In [71]:
type(labels)

numpy.ndarray

In [72]:
# Step 4: restructure the original ssa embedding matrix
def ssa_diagonal_average(X_bar: torch.Tensor) -> torch.Tensor:
    """
    Diagonal-average (Hankelize) the L x K matrix X_bar into a 1D signal s
    of length (L + K - 1).
    """
    L, K = X_bar.shape
    N = L + K - 1

    # prepare tensors
    s = torch.zeros(N, dtype=X_bar.dtype, device=X_bar.device)
    count = torch.zeros(N, dtype=X_bar.dtype, device=X_bar.device)

    for r in range(L):
        for c in range(K):
            idx = r + c
            s[idx] += X_bar[r, c]
            count[idx] += 1
    
    # only divide at non-zero position
    mask = (count != 0)
    s[mask] /= count[mask]
    return s

def reconstruct_signals(X:torch.Tensor, labels:np.ndarray, num_clusters:int):
    L, K = X.shape
    def create_cluster_matrix(X:torch.Tensor, cluster_idx:int):
        """
        Creates the cluster-specific matrix X̄ᵢ (shape L x K)
        by copying columns from X if labels[j] == cluster_idx,
        else putting 0 in that column.
        """
        X_i = torch.zeros_like(X)
        for j in range(K):
            if labels[j] == cluster_idx:
                X_i[:, j] = X[:, j]
        return X_i

        
    # Create and reconstruct each cluster
    signals = []
    for cluster_idx in range(num_clusters):
        # 1) Build X̄ᵢ
        X_i = create_cluster_matrix(X, cluster_idx)
        # 2) Diagonal-average => 1D signal
        s_i = ssa_diagonal_average(X_i)
        signals.append(s_i)

    return signals
signals = reconstruct_signals(X, labels, 2)



In [73]:
len(signals)

2

In [74]:
# Step 5-7
def fractal_sevcik(signal:torch.Tensor) -> float:
    """Sevcik Fractal Dimension (SFD) referred by the original paper, 
        adapted from 
        https://github.com/neuropsychology/NeuroKit/blob/master/neurokit2/complexity/fractal_sevcik.py
    """
    n = signal.shape[0]
    s_min = torch.min(signal)
    s_max = torch.max(signal)

    # 1) Normalize the signal (new range to [0, 1])
    y_ = (signal - s_min) / (s_max - s_min)
    # 2) Derive x* and y* (y* is actually the normalized signal)
    x_ = torch.linspace(0, 1, steps=n, dtype=torch.float32, device=signal.device)
    # 3) Compute L (because we use np.diff, hence n-1 below)
    dy = y_[1:] - y_[:-1]
    dx = x_[1:] - x_[:-1]
    dist = torch.sqrt(dx**2 + dy**2)
    L = torch.sum(dist)
    # 4. Compute the fractal dimension (approximation)
    sfd = 1.0 + torch.log(L) / torch.log(torch.tensor(2.0 * (n - 1), device=signal.device))
    return float(sfd.item())

def ssa_decomposition(A_hat:torch.Tensor):
    # 1) SVD of A_hat
    # U: (M, M), S: (min(M,K),), Vt: (K, K)
    U, S, Vt = torch.linalg.svd(A_hat, full_matrices=False)
    A_list = []
    for i in range(S.shape[0]):
        sigma_i = S[i]
        u_i = U[:, i]
        v_i = Vt[i,:]
        A_i = sigma_i * (u_i.unsqueeze(1) @ v_i.unsqueeze(0))
        A_list.append(A_i)
    
    return A_list, S
def ssa_grouping(A_list, singular_values, threshold=0.01):
    sigmas = singular_values.float()
    lambdas = sigmas**2

    total_lambda = torch.sum(lambdas)
    if total_lambda <= 0:
        return torch.zeros_like(A_list[0]), []
    
    ratios = lambdas / total_lambda

    kept_indices = []
    for i in range(len(A_list)):
        if ratios[i].item() > threshold:
            kept_indices.append(i)
        
    if len(kept_indices) == 0:
        A_sum = torch.zeros_like(A_list[0])
    else:
        A_sum = A_list[kept_indices[0]].clone()
        for idx in kept_indices[1:]:
            A_sum += A_list[idx]
    return A_sum, kept_indices

def refine_artifact_with_ssa(
        artifact_signal: torch.Tensor,
        window_size: int,
        grouping_threshold: float):
    """
    Steps 8 & 9:
      Step 8: Apply SSA to the artifact_signal to remove EEG remnants.
              (pipeline: embed -> decompose -> group -> FD)
    """
    artifact_signal = ssa_embedding(artifact_signal, window_size)
    decomposed_artifact_signals, singular_values = ssa_decomposition(artifact_signal)
    A_sum, _ = ssa_grouping(decomposed_artifact_signals, singular_values, grouping_threshold)
    blink_artifact = ssa_diagonal_average(A_sum)
    return blink_artifact




In [75]:

def remove_blink_artifact(cluster_signals,
                          eeg_signal: torch.Tensor,
                          fd_threshold:float):
    """
    Identify and remove eye-blink artifact from EEG by:
      1) Computing FD for each cluster signal,
      2) Selecting those with FD <= fd_threshold (blink-like),
      3) Summing those to form the blink estimate,
      4) Subtracting from the original EEG.

    Parameters
    ----------
    cluster_signals : list[torch.Tensor]
        A list of length L, where each element is a 1D PyTorch tensor
        of the same length as eeg_signal. These are the signals reconstructed
        from Step 4 (i.e., from diagonal averaging of each cluster).
    eeg_signal : torch.Tensor
        The original 1D EEG recording (contaminated by blink).
    fd_threshold : float
        The preset fractal dimension threshold. Any signal with FD <= this
        is considered blink-artifact.

    Returns
    -------
    cleaned_eeg : torch.Tensor
        The artifact-corrected EEG (same shape as eeg_signal).
    blink_estimate : torch.Tensor
        The summed blink-artifact signal (same shape as eeg_signal).
    fd_values : list[float]
        The FD values for each cluster signal, in the same order as cluster_signals.
    """
    # Step 5: Compute FD & sum blink components
    eeg_signal = eeg_signal.flatten()
    L = eeg_signal.shape[0]
    fd_values = []
    blink_components = []
    for signal in cluster_signals:
        fd_val = fractal_sevcik(signal)
        fd_values.append(fd_val)
        if fd_val <= fd_threshold:
            blink_components.append(signal)
    if len(blink_components) == 0:
        a_r = torch.zeros_like(eeg_signal)
    else:
        a_r = blink_components[0].clone()
        for b_idx in range(1, len(blink_components)):
            a_r += blink_components[b_idx]

    # Step 6: Convert a_r to a binary template a_b (nonzero -> 1, zero -> 0)
    a_b = (a_r != 0).float()
    # Step 7: Multiply the binary artifact template by the original EEG x
    #         => blink_artifact = a_b * x
    blink_artifact = a_b * eeg_signal
    # Step 8: refine the artifact with SSA
    blink_artifact = refine_artifact_with_ssa(blink_artifact, 256, 0.01)
    # Step 9: subtract the artifact from the contaminated signal
    cleaned_eeg = eeg_signal - blink_artifact
    
    return cleaned_eeg, blink_artifact, fd_values

cleaned_eeg, blink_est, fd_vals = remove_blink_artifact(signals,
                                                        eeg_signal,
                                                        1.4)

In [76]:
cleaned_eeg

tensor([ 0.1927,  0.1487,  0.0901,  ..., -0.1071,  0.0778, -0.1770])