In [1]:
import mne.io.eeglab.eeglab
import numpy as np
import pandas as pd
import functools
from neurokit2.stats.cluster_quality import _cluster_quality_distance
from filenames_and_paths import *
from helper import *

In [2]:
def cluster_taahc(
    data,
    n_clusters=2,
    gfp=None,
    gfp_peaks=None,
    gfp_sum_sq=None,
    random_state=None,
    use_peaks=False,
    **kwargs
):
    """Atomize and Agglomerate Hierarchical Clustering (AAHC) of topographic maps.

    Based on Murray et al., Brain Topography, 2008, following the implementation at
    https://github.com/Frederic-vW/eeg_microstates/blob/master/eeg_microstates.py#L518

    Every initial map starts as its own cluster. On each pass the cluster with the
    lowest global explained variance (GEV) is dissolved, each of its members is handed
    to the remaining cluster it correlates with most strongly (polarity ignored), and
    the map of every receiving cluster is recomputed as the dominant eigenvector of its
    members. This repeats until ``n_clusters`` maps remain.

    Parameters
    ----------
    data : array-like, shape (n_samples, n_channels)
        One topography per row. DataFrames and nested lists are converted to an array;
        non-floating input is cast to float so the unit-norm maps written back during
        clustering are not truncated.
    n_clusters : int
        Number of maps to keep. Must be at least 1.
    gfp : array-like, optional
        Global field power, indexed per row of ``data``. When omitted it is computed as
        the per-row standard deviation of ``data``.
    gfp_peaks : array-like, optional
        Never read by the computation. Passing it (like any other GFP argument) only
        marks ``data`` as already pre-processed, which disables ``use_peaks``.
    gfp_sum_sq : float, optional
        Normalising constant of the GEV. When omitted it is ``sum(gfp ** 2)``.
    random_state : optional
        Not used by the algorithm; only recorded in ``info`` and in the stored partial.
    use_peaks : bool
        Honoured only when no GFP argument is given. If True, the initial clusters are
        the rows of ``data`` at the local maxima of the GFP; otherwise every row of
        ``data`` is an initial cluster.
    **kwargs
        Only forwarded to the partial stored in ``info["clustering_function"]``.

    Returns
    -------
    prediction : DataFrame
        Output of ``_cluster_quality_distance`` for the clustered rows and the final
        maps, plus a ``"Cluster"`` column holding the column position of the largest
        absolute value in each row.
    maps : ndarray, shape (n_maps, n_channels)
        The remaining cluster maps.
    info : dict
        ``n_clusters``, ``random_state`` and a ``functools.partial`` of this function
        carrying ``n_clusters``, ``random_state``, ``use_peaks`` and ``kwargs``.

    Raises
    ------
    ValueError
        If ``n_clusters`` is smaller than 1, or ``data`` is not two-dimensional.
    """

    def locmax(x):
        """Return the indices of the strict local maxima of a 1D sequence."""
        dx = np.diff(x)  # discrete 1st derivative
        zc = np.diff(np.sign(dx))  # sign changes of the derivative
        return 1 + np.where(zc == -2)[0]  # rising -> falling transitions

    # Fail fast: with fewer than one target cluster the agglomeration would empty the
    # map set and only then die on an argmax over an empty array.
    if n_clusters < 1:
        raise ValueError("n_clusters must be at least 1, got {!r}".format(n_clusters))

    # Sanitize
    data = np.asarray(data)
    if not np.issubdtype(data.dtype, np.inexact):
        # maps are overwritten with unit-norm vectors; an integer array would truncate them
        data = data.astype(float)
    _, nch = data.shape

    # Any GFP argument means the caller has already pre-processed `data`.
    preprocessed = not (gfp is None and gfp_peaks is None and gfp_sum_sq is None)

    # Fill in whichever GFP quantities were left out, so a partial specification
    # (e.g. only `gfp`) no longer fails later on a None operand.
    gfp = data.std(axis=1) if gfp is None else np.asarray(gfp)
    if gfp_sum_sq is None:
        gfp_sum_sq = np.sum(gfp**2)  # normalizing constant in GEV

    if use_peaks and not preprocessed:
        gfp_peaks = locmax(gfp)
        maps = data[gfp_peaks, :]  # initialize clusters
        cluster_data = data[gfp_peaks, :]  # rows the clusters are built from
    else:
        maps = data.copy()
        cluster_data = data.copy()

    n_maps = maps.shape[0]

    # members[i] lists the rows of cluster_data currently assigned to map i
    members = [[k] for k in range(n_maps)]

    # `data` never changes inside the loop, so its statistics are computed once.
    data_centered = data - data.mean(axis=1, keepdims=True)
    data_std = data.std(axis=1)

    # Main loop: atomize + agglomerate
    while n_maps > n_clusters:

        # correlations of the data sequence with each cluster
        maps_centered = maps - maps.mean(axis=1, keepdims=True)
        maps_std = maps.std(axis=1)
        C = np.dot(data_centered, maps_centered.T) / (1.0 * nch * np.outer(data_std, maps_std))

        # microstate sequence, ignore polarity
        L = np.argmax(C**2, axis=1)

        # GEV (global explained variance) of cluster k
        gev = np.zeros(n_maps)
        for k in range(n_maps):
            r = L == k
            gev[k] = np.sum(gfp[r] ** 2 * C[r, k] ** 2) / gfp_sum_sq

        # dissolve the cluster with the minimum GEV: N => N-1
        imin = int(np.argmin(gev))
        maps = np.delete(maps, imin, axis=0)
        orphans = members.pop(imin)

        # `maps` is not modified while the orphans are re-assigned, so its
        # statistics are computed once for all of them.
        maps_centered = maps - maps.mean(axis=1, keepdims=True)
        maps_std = maps.std(axis=1)

        touched = set()  # indices of clusters that received a member
        for k in orphans:
            c = cluster_data[k, :]
            corr = np.dot(maps_centered, c - c.mean()) / (1.0 * nch * maps_std * c.std())
            inew = int(np.argmax(corr**2))  # ignore polarity
            touched.add(inew)
            members[inew].append(k)
        n_maps = len(members)

        # re-clustering by eigenvector method
        for i in sorted(touched):
            Vt = cluster_data[members[i], :]
            Sk = np.dot(Vt.T, Vt)
            evals, evecs = np.linalg.eig(Sk)
            c = np.real(evecs[:, np.argmax(np.abs(evals))])
            maps[i] = c / np.sqrt(np.sum(c**2))

    # Get distance
    # NOTE(review): the helper's name suggests these values are distances, in which
    # case idxmax labels each row with its farthest map — confirm this is intended.
    prediction = _cluster_quality_distance(cluster_data, maps, to_dataframe=True)
    prediction["Cluster"] = prediction.abs().idxmax(axis=1).values
    prediction["Cluster"] = [
        np.where(prediction.columns == state)[0][0] for state in prediction["Cluster"]
    ]

    # Function: keep `use_peaks` so the stored callable reproduces this configuration
    clustering_function = functools.partial(
        cluster_taahc,
        n_clusters=n_clusters,
        random_state=random_state,
        use_peaks=use_peaks,
        **kwargs
    )

    # Info dump
    info = {
        "n_clusters": n_clusters,
        "clustering_function": clustering_function,
        "random_state": random_state,
    }

    return prediction, maps, info

In [3]:
# Point the shared folder configuration at path014 before loading.
folders.end_folder = path014

# Load the helper wrapper for the first filenames014 entry, using its "_th" variant.
mhw: MicrostateHelperWrapper = MicrostateHelperWrapper.static_load(
    raw_filename=filenames014[0] + "_th",
    folders=folders,
)

Loading MHW object ACP_INP0014_REST1_1pnt_1vis_th


In [4]:
# Keep a direct handle on the recording held by the helper wrapper
# (annotated here as an EEGLAB-backed MNE raw object).
raw: mne.io.eeglab.eeglab.RawEEGLAB = mhw.raw

In [22]:
# Pull the signal matrix out of the recording ...
eeg = raw.get_data()

# ... and keep a DataFrame view of the same values for tabular inspection.
data_df = pd.DataFrame(data=eeg)

In [23]:
from neurokit2 import microstates_clean

# Pre-process the signal for microstate analysis. The four returned values are kept
# as `data`, `indices`, `gfp` and `info_mne`; later cells index `data` by `indices`.
# NOTE(review): sampling_rate is hard-coded to 2048 — confirm it matches the recording.
data, indices, gfp, info_mne = microstates_clean(
    eeg,
    sampling_rate=2048,
    train="gfp",
    gfp_method="l1",
    standardize_eeg=False,
)

array([[ 1.41895771e-05,  1.47679081e-05,  1.31899948e-05, ...,
        -3.18051529e-06, -4.59211636e-06, -1.54073524e-05],
       [ 2.09563160e-05,  1.99793587e-05,  1.20101690e-05, ...,
        -1.07384453e-05,  1.33070278e-06, -8.84633923e-06],
       [ 1.59974420e-06,  8.43425560e-06,  1.96164112e-05, ...,
        -2.41871147e-05, -5.17581320e-06, -2.27316036e-05],
       ...,
       [-8.42091370e-06, -6.77572966e-06, -5.80944586e-06, ...,
         1.67502327e-05, -3.84942150e-06,  1.84001064e-05],
       [-6.90112925e-06, -7.47818470e-06, -5.12855387e-06, ...,
         1.03668213e-05,  6.81083584e-06,  1.61063118e-05],
       [-8.07153225e-06, -7.70825267e-07,  5.23290396e-06, ...,
        -4.84503126e-06, -2.92906451e-06, -3.50755787e-06]])

In [25]:
# Cluster the columns of `data` selected by `indices`, one topography per row.
# No GFP arguments are passed and use_peaks keeps its default, so every row starts as
# its own cluster; the recorded run of this cell was interrupted before it finished.
cluster_taahc(data=data[:, indices].T)

KeyboardInterrupt: 

In [20]:
# `nk` is not bound by any import visible in this notebook, so bind it here to let
# the cell run on a fresh kernel instead of relying on leftover kernel state.
import neurokit2 as nk

# Reference segmentation into 4 microstates with neurokit2's own "aahc" method.
nk.microstates_segment(raw, 4, method="aahc")

{'Microstates': array([[ 5.49325516e-06,  6.09980506e-06,  7.46148731e-06,
          4.22756705e-06,  4.95822544e-06,  6.27991790e-06,
          8.84467965e-06,  9.08036237e-06,  9.92061111e-07,
          3.21764185e-06,  4.25011295e-06,  8.55842387e-06,
         -3.73392241e-06, -3.01713902e-06,  2.63394367e-06,
          5.27455415e-06,  6.10772989e-06, -9.66563611e-06,
         -4.29985090e-06,  1.31196159e-06,  2.35864530e-06,
         -1.08981310e-05, -1.11216304e-05, -7.81334124e-06,
         -1.45672889e-06, -6.00542742e-08, -7.93352851e-06,
         -1.25131581e-05, -3.43219351e-06,  4.70816269e-06,
          4.97552062e-06,  8.87985615e-06,  8.69532523e-06,
          3.54824375e-06,  5.17282168e-06,  8.59569852e-06,
          8.90302413e-06,  6.32918139e-07,  3.29678196e-06,
          6.33983899e-06, -4.31829911e-06, -7.79885922e-07,
          2.22155691e-06,  6.75055268e-06, -7.36403212e-06,
          1.49972839e-06, -1.11082572e-05, -1.11938961e-05,
         -1.42333672e-06,