In [1]:
%config Completer.use_jedi = False
%load_ext autoreload
# %reload_ext autoreload
%autoreload 2
%load_ext lab_black

In [2]:
import numpy as np
import scipy.linalg as la

In [3]:
fs = 250
low_freq = 4
high_freq = 38
bandwidth = 4
overlap = 2

freqs = np.arange(low_freq, high_freq - (bandwidth - overlap), overlap)
freqs / (fs / 2)

array([0.032, 0.048, 0.064, 0.08 , 0.096, 0.112, 0.128, 0.144, 0.16 ,
       0.176, 0.192, 0.208, 0.224, 0.24 , 0.256, 0.272])

In [4]:
signals = np.random.randint(1, 50, size=(10, 22, 1, 1, 769))
labels = np.random.randint(1, 4, size=(10))
signals.shape, labels

((10, 22, 1, 1, 769), array([2, 2, 3, 2, 3, 3, 3, 2, 3, 3]))

In [None]:
def make_binary(y, key: int):
    idx = y == key
    return idx, ~idx


def calc_cov(x):
    # x's shape will be (trials, channel, signals)
    n_trials, n_channels, n_samples = x.shape
    covariance_matrix = np.zeros((n_trials, n_channels, n_channels))

    for i in range(x.shape[0]):
        trial = x[i, :, :]
        # covariance_matrix[i, :, :] = np.cov(trial)
        # or
        covariance_matrix[i, :, :] = trial @ trial.T

    mean_cov = np.mean(covariance_matrix, 0)

    return mean_cov

In [None]:
def eval_W(signals, labels, freq_bands, n_windows: int = 5):
    num_bands = len(freq_bands)
    trials, channels, n_windows, num_bands, data = signals.shape
    template = "label-{}_band-{}_window-{}"

    w_dict = {}
    for label in np.unique(labels):
        trial_idx, rest_idx = make_binary(labels, label)

        x_trial = signals[trial_idx]
        x_rest = signals[rest_idx]

        for freq_band in range(num_bands):
            for window in range(n_windows):
                # (batch, num_channel, timepoint)
                R_a = calc_cov(x_trial[:, :, window, freq_band, :])
                R_b = calc_cov(x_rest[:, :, window, freq_band, :])
                R_c = R_a + R_b

                lda, V = la.eig(R_c)
                W = la.sqrtm(la.inv(np.diag(lda))) @ V.T
                T_a = W @ R_a @ W.T
                T_b = W @ R_b @ W.T
                T_c = W @ R_c @ W.T

                sai, E = la.eig(T_a, T_c)
                P = E  # or P = E.T @ W although the latter doesn't give
                # a coordinate system with orthogonal axis. we skip
                # whitning step

                sai_a, E_a = la.eig(T_a)
                sai_b, E_b = la.eig(T_b)
                assert np.allclose(E_a, E_b)

                print("*" * 30)
                print(
                    np.allclose(np.diag(sai_a) + np.diag(sai_b), np.eye(sai_b.shape[0]))
                )
                print(np.allclose(E @ E.T, np.eye(E.shape[0])))
                print(np.allclose(W @ W.T, np.eye(W.shape[0])))
                print(np.allclose(P @ P.T, np.eye(W.shape[0])))
                print("*" * 30)
                w_dict[template.format(label, freq_band, window)] = P

    return w_dict

In [None]:
weights = eval_W(signals=signals, labels=labels, freq_bands=freqs)

In [None]:
w = weights["label-1_band-0_window-0"]
print(weights["label-1_band-0_window-0"].shape)

np.allclose(w @ w.T, np.eye(w.shape[0])), np.allclose(w.T @ w, np.eye(w.shape[0]))

In [None]:
w = la.inv(weights["label-1_band-0_window-0"])
print(weights["label-1_band-0_window-0"].shape)

np.allclose(w @ w.T, np.eye(w.shape[0])), np.allclose(w.T @ w, np.eye(w.shape[0]))