In [1]:
import numpy as np
import pandas as pd
import pathlib

from scipy.linalg import eigh, pinv
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.decomposition import PCA

from itertools import combinations_with_replacement
import h5io
import coffeine

from dameeg.procrustes import align_procrustes
from dameeg.recenter import align_recenter
from dameeg.recenter_rescale import align_recenter_rescale
import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib qt


In [31]:
# --- Data locations ---------------------------------------------------------
# Derivatives folder; FEATURES_ROOT is an alias of the same path.
DERIV_ROOT = pathlib.Path('/storage/store3/derivatives/biomag_hokuto_bids')
FEATURES_ROOT = DERIV_ROOT
BIDS_ROOT = pathlib.Path(
    '/storage/store/data/biomag_challenge/Biomag2022/biomag_hokuto_bids'
)
# Folder holding 'hokuto_profile.xlsx', read by the helper functions below.
ROOT = pathlib.Path(
    '/storage/store/data/biomag_challenge/Biomag2022/biomag_hokuto'
)

RANDOM_STATE = 42

# --- Frequency bands --------------------------------------------------------
# Band name -> (lower edge, upper edge); insertion order defines band indices.
frequency_bands = dict(
    low=(0.1, 1),
    delta=(1, 4),
    theta=(4.0, 8.0),
    alpha=(8.0, 15.0),
    beta_low=(15.0, 26.0),
    beta_mid=(26.0, 35.0),
    beta_high=(35.0, 49),
)

# Names of every unordered band pair (a band paired with itself included),
# built by concatenating the two band names.  The head of the work list is
# paired with everything still in the list, then dropped, so each pair shows
# up exactly once.  The work list `frequency_bands_` is empty afterwards.
cross_frequency_bands = []
frequency_bands_ = list(frequency_bands)
while frequency_bands_:
    head_band = frequency_bands_[0]
    cross_frequency_bands.extend(
        head_band + other_band for other_band in frequency_bands_
    )
    frequency_bands_.pop(0)


def get_names(i, j):
    """Return the concatenated names of the bands at positions ``i`` and ``j``.

    Positions refer to the insertion order of the module-level
    ``frequency_bands`` mapping.
    """
    band_names = list(frequency_bands)
    return band_names[i] + band_names[j]


def get_sub_cov(C, i, j):
    """Extract block ``(i, j)`` from every matrix of a stacked array.

    The second axis of ``C`` is divided into ``len(frequency_bands)`` chunks
    of equal size (any remainder is ignored); the same chunk size is used
    along the third axis.

    Parameters
    ----------
    C : array indexed as ``C[matrix, row, column]``
    i, j : int
        Row-chunk and column-chunk indices.

    Returns
    -------
    list
        One 2-D block per entry along the first axis of ``C``.
    """
    block_size = C.shape[1] // len(frequency_bands)
    row_span = slice(i * block_size, (i + 1) * block_size)
    col_span = slice(j * block_size, (j + 1) * block_size)
    return list(C[:, row_span, col_span])


def get_subjects_labels(all_subjects):
    """Keep the subjects whose identifier carries a known label at offset 4.

    A subject is kept when the first occurrence of 'control', 'mci' or
    'dementia' in its identifier starts at character index 4 (for example
    right after a 'sub-' prefix).  Other identifiers are dropped.

    Returns
    -------
    (list, list)
        The kept identifiers and, in the same order, their labels.
    """
    known_labels = ('control', 'mci', 'dementia')
    train_subjects, train_labels = [], []
    for identifier in all_subjects:
        found = next(
            (name for name in known_labels if identifier.find(name) == 4),
            None,
        )
        if found is None:
            continue
        train_subjects.append(identifier)
        train_labels.append(found)
    return train_subjects, train_labels


def get_site(labels, subjects):
    """Split the requested subjects by recording site ('A' or 'B').

    For each entry of ``labels`` the sheet of that name is read from
    'hokuto_profile.xlsx' under ``ROOT``.  Each row yields a subject name
    made of 'sub-' followed by the 'ID' value without its first 7 characters.
    Rows whose 'Site' is 'A' or 'B' and whose subject name is contained in
    ``subjects`` are collected together with their 'Age' value.

    Returns
    -------
    (subjects_A, subjects_B, age_A, age_B) : tuple of lists
    """
    subjects_A, age_A = [], []
    subjects_B, age_B = [], []
    # (site code, subject list, age list) handled identically for both sites.
    site_buckets = (('A', subjects_A, age_A), ('B', subjects_B, age_B))

    for label in labels:
        profile = pd.read_excel(ROOT / 'hokuto_profile.xlsx', sheet_name=label)
        for row in range(profile.shape[0]):
            subject = 'sub-' + profile['ID'].iloc[row][7:]
            site = profile['Site'].iloc[row]
            for code, kept_subjects, kept_ages in site_buckets:
                if site == code and subject in subjects:
                    kept_subjects.append(subject)
                    kept_ages.append(profile['Age'].iloc[row])
    return subjects_A, subjects_B, age_A, age_B


def get_subjects_age(age, labels):
    """List the subjects whose 'Age' is at least ``age``.

    For each entry of ``labels`` the sheet of that name is read from
    'hokuto_profile.xlsx' under ``ROOT``.  A qualifying row contributes
    'sub-' followed by its 'ID' value without the first 7 characters.
    The number of selected subjects is printed before returning.

    Returns
    -------
    list of str
    """
    subjects = []
    for label in labels:
        profile = pd.read_excel(ROOT / 'hokuto_profile.xlsx', sheet_name=label)
        subjects.extend(
            'sub-' + profile['ID'].iloc[row][7:]
            for row in range(profile.shape[0])
            if profile['Age'].iloc[row] >= age
        )
    print(len(subjects))
    return subjects


class ProjCommonSpace(BaseEstimator, TransformerMixin):
    """Project a stack of square matrices onto a shared eigenvector basis.

    ``fit`` averages the input matrices, eigendecomposes that mean with
    ``scipy.linalg.eigh`` and keeps the eigenvectors whose eigenvalues are
    largest in absolute value.  ``transform`` multiplies every matrix by the
    inverse mean trace seen during ``fit``, projects it onto the kept
    eigenvectors and adds ``reg`` to the diagonal of the result.

    Parameters
    ----------
    n_compo : int or 'full', default='full'
        Number of eigenvectors to keep.  'full' keeps as many as the input
        matrices have rows.
    reg : float, default=1e-7
        Value added to the diagonal of each projected matrix.

    Attributes
    ----------
    n_compo_ : int
        Number of components actually used, resolved during ``fit``.
    scale_ : float
        Inverse of the mean trace of the matrices passed to ``fit``.
    filters_ : list holding one array of shape (n_compo_, n_channels)
        Kept eigenvectors, stored as rows.
    patterns_ : list holding one array
        Transposed pseudo-inverse of the filters.
    """

    def __init__(self, n_compo='full', reg=1e-7):
        self.n_compo = n_compo
        self.reg = reg

    def fit(self, X, y=None):
        """Learn the projection from ``X``.

        Parameters
        ----------
        X : array-like of shape (n_matrices, n_channels, n_channels)
        y : ignored

        Returns
        -------
        self
        """
        X = np.asarray(X)
        # Resolve 'full' into a fitted attribute rather than overwriting the
        # constructor argument: a later fit on matrices of another size then
        # picks up the new size, and get_params()/clone() keep reporting the
        # value the user passed.
        self.n_compo_ = X.shape[1] if self.n_compo == 'full' else self.n_compo
        self.scale_ = 1 / np.mean([np.trace(x) for x in X])

        mean_matrix = X.mean(axis=0)
        # NOTE(review): eigh expects a symmetric/Hermitian matrix -- confirm
        # that every input passed here satisfies this.
        eigvals, eigvecs = eigh(mean_matrix)
        # Rank eigenvectors by decreasing absolute eigenvalue.
        order = np.argsort(np.abs(eigvals))[::-1]
        evecs = eigvecs[:, order][:, :self.n_compo_].T  # (compo, chan) rows

        # Wrapped in single-item lists, as transform reads filters_[0].
        self.filters_ = [evecs]
        self.patterns_ = [pinv(evecs).T]
        return self

    def transform(self, X):
        """Rescale, project and regularize every matrix of ``X``.

        Parameters
        ----------
        X : array-like of shape (n_matrices, n_channels, n_channels)

        Returns
        -------
        ndarray of shape (n_matrices, n_compo_, n_compo_)
        """
        X = np.asarray(X)
        n_sub, _, _ = X.shape  # also rejects inputs that are not 3-D
        n_compo = self.n_compo_
        filters = self.filters_[0]  # (compo, chan)
        Xs = self.scale_ * X
        ridge = self.reg * np.eye(n_compo)

        Xout = np.empty((n_sub, n_compo, n_compo))
        for sub in range(n_sub):
            Xout[sub] = filters @ Xs[sub] @ filters.T + ridge
        return Xout  # (sub, compo, compo)


Looking at covariances

In [155]:
# Select subjects aged 50 or more and split them by recording site.
all_subjects = get_subjects_age(50, ['control', 'dementia', 'mci'])
subjects_A, subjects_B, age_A, age_B = get_site(
    ['control', 'dementia', 'mci'], all_subjects
)
# Site-B exclusions tried previously and currently disabled:
# sub-dementia22, sub-dementia25, sub-mci3, sub-mci5.
train_subjects, y = get_subjects_labels(subjects_A + subjects_B)
rank = 120  # components kept by ProjCommonSpace for every band pair
reg = 1     # value added to the diagonal of every projected block

# Per-subject cross-frequency matrices, stacked separately for each site.
features = h5io.read_hdf5(DERIV_ROOT / 'features_cross_frequency_covs.h5')
covs_A = np.array(
    [features[sub]['cross_frequency_covs'] for sub in subjects_A]
)
covs_B = np.array(
    [features[sub]['cross_frequency_covs'] for sub in subjects_B]
)

# One (rank, rank) projected block per subject and band pair.
covs_A_aligned = np.zeros(
    (covs_A.shape[0], len(cross_frequency_bands), rank, rank)
)
covs_B_aligned = np.zeros(
    (covs_B.shape[0], len(cross_frequency_bands), rank, rank)
)

# Derive the pair list from the band table instead of hard-coding 7 bands.
band_pairs = combinations_with_replacement(range(len(frequency_bands)), 2)
for k, (i, j) in enumerate(band_pairs):
    # Column k of the data frames below is labelled cross_frequency_bands[k];
    # make sure that label really describes block (i, j).
    assert get_names(i, j) == cross_frequency_bands[k]

    # NOTE(review): for i != j the block is taken off the diagonal of the
    # full matrix, while ProjCommonSpace.fit relies on scipy's eigh, which
    # expects a symmetric input -- confirm this is intended.
    proj = ProjCommonSpace(n_compo=rank, reg=reg)
    all_covs = np.concatenate(
        (get_sub_cov(covs_A, i, j), get_sub_cov(covs_B, i, j))
    )
    # The projection is fitted on both sites pooled together.
    all_covs = proj.fit_transform(all_covs)
    covs_A_pca = all_covs[:covs_A.shape[0]]
    covs_B_pca = all_covs[covs_A.shape[0]:]
    # No cross-site alignment is applied; a variant using
    # align_recenter_rescale(covs_A_pca, covs_B_pca) was tried and disabled.
    covs_A_aligned[:, k], covs_B_aligned[:, k] = covs_A_pca, covs_B_pca

# One column per band pair, each cell holding one (rank, rank) matrix.
X_A = pd.DataFrame(
    {band: list(covs_A_aligned[:, ii]) for ii, band in
        enumerate(cross_frequency_bands)})
X_B = pd.DataFrame(
    {band: list(covs_B_aligned[:, ii]) for ii, band in
        enumerate(cross_frequency_bands)})

# Site-A rows first, then site-B rows: same order as subjects_A + subjects_B.
X = pd.concat([X_A, X_B], axis=0)

101


In [157]:
# Settings handed to coffeine's projection step; rank and reg are the same
# values used for the per-pair projection in the previous cell.
projection_settings = {'scale': 'auto', 'n_compo': rank, 'reg': reg}

# One pipeline per band-pair column of X, using the 'riemann' method.
filter_bank_transformer = coffeine.make_filter_bank_transformer(
    names=list(cross_frequency_bands),
    method='riemann',
    projection_params=projection_settings,
)

X_covs = filter_bank_transformer.fit_transform(X)
X_covs.shape

(101, 203280)

In [158]:
# With svd_solver left at its default, scikit-learn may select the randomized
# solver for an input of this size, so seed it to make the embedding
# reproducible across runs.
pca = PCA(n_components=3, random_state=RANDOM_STATE)
X_PCA = pca.fit_transform(X_covs)

# Plotting table: first two components plus per-subject metadata.  Rows follow
# the order subjects_A + subjects_B used to build X.
df = pd.DataFrame({
    'x': X_PCA[:, 0],
    'y': X_PCA[:, 1],
    'label': y,
    'site': ['A'] * len(subjects_A) + ['B'] * len(subjects_B),
    'age': age_A + age_B
})

In [159]:
# 2-D scatter of the first two principal components:
# colour = diagnosis label, marker = site, marker size = age.
fig_2d = plt.figure(0)
ax_2d = fig_2d.gca()
sns.scatterplot(
    data=df,
    x='x',
    y='y',
    hue='label',
    style='site',
    size='age',
    sizes=(40, 400),
    alpha=0.5,
    ax=ax_2d,
)
ax_2d.set_title('PCA of covariances')
ax_2d.set_xlabel('1st component')
ax_2d.set_ylabel('2nd component')
# Hard-coded zoom window; points outside it are not visible.
ax_2d.set_xlim(-0.000153, -0.00009)
ax_2d.set_ylim(-0.00008, 0)
plt.show()

In [136]:

# 3-D scatter of the first three principal components.
fig = plt.figure(1)
ax = fig.add_subplot(projection='3d')

x_pca, y_pca, z_pca = X_PCA[:, 0], X_PCA[:, 1], X_PCA[:, 2]

# Marker colour per diagnosis label; labels missing from the table get no
# colour entry at all.
label_colors = {'control': 'b', 'dementia': 'r', 'mci': 'g'}
c = [label_colors[label] for label in y if label in label_colors]

# Marker area grows with the square of the age rescaled by the constants 40
# and 93 (presumably the age range of the cohort -- TODO confirm).
s = [100 * ((a - 40) / (93 - 40)) ** 2 for a in age_A + age_B]

ax.scatter(x_pca, y_pca, z_pca, c=c, s=s)
# Hard-coded zoom window on the first two components.
ax.set_xlim(-0.00032, 0)
ax.set_ylim(-0.00015, 0.0002)
plt.show()

In [126]:
# Largest value of the 2nd principal component (quick range check, e.g. when
# choosing axis limits for the scatter plots).
np.max(X_PCA[:, 1])

0.009553630348103425