# Microstate analysis with Pycrostates

In [None]:
import mne
import numpy as np
import os
import pickle
import pycrostates
import seaborn as sns
import time

from matplotlib import pyplot as plt
from mne.io import read_raw_edf

from pycrostates.cluster import ModKMeans
from pycrostates.io import ChData
from pycrostates.preprocessing import extract_gfp_peaks

In [None]:
# Sleep-stage labels to process; each label is matched against the
# "_<stage>.edf" suffix of the recording file names.
sleep_stages = ["W", "N3"]

In [None]:
# Input folder holding the EDF recordings.
data_dir = "../data"
# os.listdir returns entries in arbitrary order; sort them so that every
# downstream step (file indices, order of the stacked subject-level maps)
# is reproducible across machines and runs.
all_files = sorted(os.listdir(data_dir))
print(f"All files n = {len(all_files):d}")

# Output folder for template maps and microstate sequences; created on demand.
results_dir = "./results"
if not os.path.isdir(results_dir):
    os.mkdir(results_dir)
    print(f"Results folder created: {results_dir:s}")

In [None]:
# Map each sleep stage to the paths of the recordings whose file name
# ends in "_<stage>.edf".
files = {}
for stage in sleep_stages:
    suffix = f'_{stage:s}.edf'
    files[stage] = [
        f"{data_dir:s}/{name:s}" for name in all_files if name.endswith(suffix)
    ]

# Print an indexed listing of the selected recordings, one section per stage.
for stage in sleep_stages:
    print(f"\nSleep stage: {stage:s}")
    stage_files = files[stage]
    for i, f in enumerate(stage_files):
        print(f"{i:d}: {f:s}")

## Group-level microstate clustering

In [None]:
# Number of microstate classes, used as n_clusters for the subject-level
# and the group-level clustering alike.
K = 5
# Band-pass filter edges, passed to raw.filter as l_freq and h_freq.
bp_lo = 1
bp_hi = 30

In [None]:
# Two-stage microstate clustering:
#   1. per recording: band-pass filter, average-reference, extract GFP peaks
#      and fit K maps with modified K-means;
#   2. group level: pool the maps of all recordings (all stages) and cluster
#      them again into K group maps.
# This cell will take a while...
tic = time.time()
montage = mne.channels.make_standard_montage("standard_1005")
subject_level_maps = []
# Channel info of the most recently fitted recording; needed to wrap the
# pooled maps in a ChData object for the group-level fit.
subject_info = None
for stage in sleep_stages:
    print(f"\nSleep stage: {stage:s}")
    n_files = len(files[stage])
    for i, f in enumerate(files[stage]):
        print(f"{i+1:d}/{n_files:d}: {f:s}")
        raw = read_raw_edf(f, preload=True, verbose=False)
        raw.set_montage(montage)
        raw.pick("eeg")
        raw.filter(l_freq=bp_lo, h_freq=bp_hi)
        raw.set_eeg_reference("average")
        gfp_peaks = extract_gfp_peaks(raw)  # extract GFP peaks
        # subject level clustering (fixed seed for reproducibility)
        subject_modk = ModKMeans(n_clusters=K, random_state=42)
        subject_modk.fit(gfp_peaks, n_jobs=2)
        subject_level_maps.append(subject_modk.cluster_centers_)
        subject_info = subject_modk.info

# Fail with a clear message instead of a NameError / empty-vstack error
# when no recording matched any of the requested sleep stages.
if not subject_level_maps:
    raise RuntimeError(
        f"No recordings found for sleep stages {sleep_stages}; "
        "nothing to cluster."
    )

# combine maps across all subjects and stages to obtain group maps;
# the stacked maps are transposed before being handed to ChData
# (presumably to get a channels-first layout — confirm against pycrostates docs)
group_maps = np.vstack(subject_level_maps).T
group_maps = ChData(group_maps, subject_info)
# group level clustering
ModK = ModKMeans(n_clusters=K, random_state=42)
ModK.fit(group_maps, n_jobs=2)
toc = time.time()
print(f"[+] Computation time: {toc-tic:.1f} seconds.")
ModK.plot()
plt.show()

## Sort group-level microstate template maps

In [None]:
# Bring the group-level maps into the canonical A-E order.
# Run this cell only once per fitted model: the three calls below modify
# ModK, so re-running would apply the flips and the permutation a second
# time to the already-sorted maps.
# The hard-coded lists have five entries, i.e. they assume K == 5.
polarity_flips = [False, True, True, True, True]
cluster_order = [2, 1, 4, 3, 0]
cluster_names = ["A", "B", "C", "D", "E"]

ModK.invert_polarity(polarity_flips)
ModK.reorder_clusters(order=cluster_order)
ModK.rename_clusters(new_names=cluster_names)

# Show the sorted, renamed template maps.
ModK.plot()
plt.show()

## Save group-level microstate template maps

In [None]:
# Persist the group-level template maps in two formats, both inside
# results_dir (the same folder the microstate sequences are written to).

# 1) the whole ModK object as pickle
group_maps_pkl = f"{results_dir:s}/sleep_group_maps_K{K:d}.pkl"
with open(group_maps_pkl, 'wb') as fp:
    pickle.dump(ModK, fp, protocol=pickle.HIGHEST_PROTOCOL)

# 2) array data only in NumPy format
group_maps_npy = f"{results_dir:s}/sleep_group_maps_K{K:d}_numpyndarray.npy"
np.save(group_maps_npy, ModK.cluster_centers_)

# Drop the in-memory model; it is re-loaded from disk before being used again.
del ModK

## Fit microstate sequences from group-level maps

In [None]:
# re-load group-level maps
# The file was written with pickle.dump, so read it back with pickle.load
# rather than relying on np.load's handling of non-.npy files.
# NOTE: unpickling executes arbitrary code -- only load files that were
# produced by this notebook / come from a trusted source.
with open(f"{results_dir:s}/sleep_group_maps_K{K:d}.pkl", 'rb') as fp:
    ModK = pickle.load(fp)
print(ModK)

In [None]:
# Back-fit the group-level maps to every recording and store the resulting
# microstate label sequence as one .npy file per recording.

# Segmentation settings shared by all recordings:
# half window size b=3, lambda factor 5, as in Pascual-Marqui et al. IEEE TBME 1995
predict_kwargs = dict(
    reject_by_annotation=True,
    factor=5,
    half_window_size=3,
    min_segment_length=3,
    reject_edges=True,
)

for stage in sleep_stages:
    print(f"\nSleep stage: {stage:s}")
    stage_files = files[stage]
    n_files = len(stage_files)
    for i, f in enumerate(stage_files):
        print(f"{i+1:d}/{n_files:d}: {f:s}")
        # subject ID = file name up to the first '.' and the first '_',
        # e.g. f=".../data/S00_W.edf" --> S00
        file_name = f.rsplit('/', 1)[-1]
        subj_id = file_name.split('.', 1)[0].split('_', 1)[0]

        # same preprocessing chain as for clustering: EEG channels only,
        # band-pass filter, average reference
        # NOTE(review): no montage is applied here (the original
        # raw.set_montage(montage) call is commented out) -- confirm intended.
        raw = read_raw_edf(f, preload=True, verbose=False)
        raw.pick("eeg")
        raw.filter(l_freq=bp_lo, h_freq=bp_hi)
        raw.set_eeg_reference("average")

        segmentation = ModK.predict(raw, **predict_kwargs)

        # save microstate sequence
        f_ms = f"{results_dir:s}/{subj_id:s}_{stage:s}_ms_K{K:d}.npy"
        print(f"Save as: {f_ms:s}")
        np.save(f_ms, segmentation.labels)
print("DONE.")

## Analyze microstate sequences

In [None]:
# Load every saved microstate sequence and print its shape and the set of
# labels it contains.
# The listing is sorted because os.listdir returns entries in arbitrary order.
all_result_files = sorted(os.listdir(results_dir))
# Keep only the sequence files written as "<subj>_<stage>_ms_K<k>.npy";
# other files in results_dir (e.g. the pickled group maps) are skipped so
# that np.load is never pointed at a non-.npy file.
result_files = {
    stage: [
        f for f in all_result_files
        if f"_{stage:s}_ms_" in f and f.endswith(".npy")
    ]
    for stage in sleep_stages
}
for stage in sleep_stages:
    for f in result_files[stage]:
        print(f"\nFile: {f:s}")
        ms = np.load(f"{results_dir:s}/{f:s}")
        print(ms.shape, np.unique(ms))
        # to be continued...
print("DONE.")