# Imports

In [1]:
import numpy as np
from numba import jit, njit, prange
import mne
from scipy.fftpack import fft, ifft, fftfreq
from scipy.stats import binned_statistic, entropy
from scipy import signal
import matplotlib.pyplot as plt
import time
import dill as pickle
from tqdm import tqdm

# Utility

In [2]:


def load_patient_from_pickle(filepath):
    with open(filepath, 'rb') as f:
        patient = pickle.load(f, pickle.HIGHEST_PROTOCOL)
    patient.pac = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
    return patient

    
def generate_coupled_signal(f_p, f_a, K_p, K_a, xi, timepoints, noise_level=0.1, phase=0, noise_type='pink', alpha=1):
    
    x_fp = K_p * np.sin(2 * np.pi * f_p * timepoints)
    A_fa = K_a/2 * ((1 - xi) * np.sin(2 * np.pi * f_p * timepoints + phase) + xi + 1)
    x_fa = A_fa * np.sin(2 * np.pi * f_a * timepoints)
    
    n = len(timepoints)
    
    if noise_type == 'white-gaussian':
        noise = np.random.normal(scale=noise_level, size=n)
    
    if noise_type == 'white-uniform':
        noise = np.random.uniform(low=-noise_level, high=noise_level, size=n)
        
    if noise_type == 'pink':
        noise = np.random.normal(scale=noise_level, size=n)
        
        noise_spectrum = fft(noise)
        freqs = fftfreq(n, timepoints[1] - timepoints[0])

        oneOverF = np.insert((1/(freqs[1:]**alpha)), 0, 0)
        new_spectrum = oneOverF * noise_spectrum                    
        noise = np.abs(ifft(new_spectrum))    
    
    x = x_fp + x_fa + noise
    
    return x

# Creating signal

In [3]:
sf = 2000
T = 180
timepoints = np.arange(0, T, 1/sf)

X = generate_coupled_signal(10, 150, 10, 1, 0.1, timepoints, phase=np.pi/4)

In [4]:
beta_phase = np.angle(signal.hilbert(mne.filter.filter_data(X, sf, 9, 11, )), deg=True)
hfo_amplitude = np.abs(signal.hilbert(mne.filter.filter_data(X, sf, 130, 150)))

Setting up band-pass filter from 9 - 11 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 9.00
- Lower transition bandwidth: 2.25 Hz (-6 dB cutoff frequency: 7.88 Hz)
- Upper passband edge: 11.00 Hz
- Upper transition bandwidth: 2.75 Hz (-6 dB cutoff frequency: 12.38 Hz)
- Filter length: 2935 samples (1.468 sec)

Setting up band-pass filter from 1.3e+02 - 1.5e+02 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 130.00
- Lower transition bandwidth: 32.50 Hz (-6 dB cutoff frequency: 113.75 Hz)
- Upper passband edge: 150.00 Hz
- Upper transition bandwidth: 37.50 Hz (-6 dB cutoff freque

In [5]:
n_beta, n_hfo = 44, 24

hfo_matrix = np.array([[hfo_amplitude for i in range(n_beta)] for j in range(n_hfo)])
beta_matrix = np.array([[beta_phase for i in range(n_beta)] for j in range(n_hfo)])

print(beta_matrix.shape)
print(hfo_matrix.shape)

print(beta_matrix.dtype)

(24, 44, 360000)
(24, 44, 360000)
float64


# Optimization

## calculate_single_PAC

In [6]:
# my binned statistic - takes phase and amplitude

@njit(fastmath=True)
def calculate_single_PAC_njit(beta_phase, hfo_amplitude):
    
    bins = np.arange(-180, 200, 20)
    nbins = 18
    
    ind = np.searchsorted(bins, beta_phase, side='left') # this returns indexes of bins to which phase values correspond to
    count = np.bincount(ind, None)
    sum_value = np.bincount(ind, hfo_amplitude)
    mean_value = sum_value[1:] / count[1:]
    
    pk = mean_value / mean_value.sum()
    qk = np.ones(nbins) / nbins
    KL = (pk * np.log(pk / qk)).sum()
    return KL
# kl = S = sum(pk * log(pk / qk)

def calculate_single_PAC(beta_phase, hfo_amplitude):
    bins = np.arange(-180, 200, 20)
    nbins = len(bins) -1
    PA_distr = binned_statistic(beta_phase, hfo_amplitude, statistic='mean', bins=bins).statistic

    PA_distr = PA_distr/np.sum(PA_distr)

    uniform_distr = np.ones(nbins) / nbins

    MI = entropy(PA_distr, uniform_distr)
    return MI

In [7]:
t0 = time.perf_counter()

res_njit = calculate_single_PAC_njit(beta_phase, hfo_amplitude)
    
t_njit = time.perf_counter() - t0
print(t_njit)

t0 = time.perf_counter()

res_njit = calculate_single_PAC_njit(beta_phase, hfo_amplitude)
    
t_njit = time.perf_counter() - t0
print(t_njit)


1.924154399999999
0.004201900000001757


In [8]:
t0 = time.perf_counter()

res_py = calculate_single_PAC(beta_phase, hfo_amplitude)
    
t_python = time.perf_counter() - t0
print(t_python)

0.03143559999999823


## calculate_PAC_matrix

In [9]:
@njit(fastmath=True, parallel=False)
def calculate_PAC_matrix_njit(beta_matrix, hfo_matrix):
    n_hfo, n_beta = np.shape(beta_matrix)[0], np.shape(beta_matrix)[1]
    pac_matrix = np.zeros((n_hfo, n_beta))
    for i in prange(n_hfo):
        for j in prange(n_beta):
            pac_matrix[i, j] = calculate_single_PAC_njit(beta_matrix[i, j], hfo_matrix[i, j])
    return pac_matrix


def calculate_PAC_matrix(beta_matrix, hfo_matrix):
    n_hfo, n_beta = np.shape(beta_matrix)[0], np.shape(beta_matrix)[1]
    pac_matrix = np.zeros((n_hfo, n_beta))
    for i in range(n_hfo):
        for j in range(n_beta):
            pac_matrix[i, j] = calculate_single_PAC(beta_matrix[i, j], hfo_matrix[i, j])
    return pac_matrix

In [10]:
t0 = time.perf_counter()

pac_matrix_njit = calculate_PAC_matrix_njit(beta_matrix, hfo_matrix)
    
t_njit = time.perf_counter() - t0
print(t_njit)

t0 = time.perf_counter()

pac_matrix_njit = calculate_PAC_matrix_njit(beta_matrix, hfo_matrix)
    
t_njit = time.perf_counter() - t0
print(t_njit)

print(pac_matrix_njit.dtype)

4.665914299999997
4.023082299999999
float64


In [19]:
t0 = time.perf_counter()

pac_matrix_python = calculate_PAC_matrix(beta_matrix, hfo_matrix)
    
t_python = time.perf_counter() - t0
print(t_python)

36.53650974499999


## shuffle_hfo_matrix

In [10]:

def shuffle_hfo_matrix(hfo_matrix, n_splits=3):
    n_hfo, n_beta, n_times = hfo_matrix.shape
    rand_idx = list(np.sort(np.random.randint(low=200, high=n_times-200, size=n_splits)))
    rand_idx.insert(0, 0)
    rand_idx.append(n_times)
    subarrays = []
    for i in range(n_splits+1):
        subarrays.append(hfo_matrix[:, :, rand_idx[i]:rand_idx[i+1]])
    np.random.shuffle(subarrays)
    shuffled_hfo_matrix = np.concatenate(subarrays, axis=2)
    return shuffled_hfo_matrix

@njit(fastmath=True)
def shuffle_hfo_matrix_njit(hfo_matrix, n_splits):   
    n_hfo, n_beta, n_times = np.shape(hfo_matrix)
    rand_idx = np.zeros(n_splits + 2)
    rand_idx[1:n_splits+1] = np.sort(np.random.randint(200, n_times-200, n_splits))
    rand_idx[n_splits] = n_times
    shuffled_idx = np.random.permutation(np.arange(n_splits+1)) # e.g. [3, 1, 0, 2] - this is index of list
    
    # make sure shuffled array is not the same as input
    while not np.abs(shuffled_idx - np.arange(n_splits+1)).sum():
        shuffled_idx = np.random.permutation(np.arange(n_splits+1))
        
    j_left = shuffled_idx[0]
    j_right = j_left + 1 
    shuff = hfo_matrix[:, :, rand_idx[j_left]:rand_idx[j_right]]
    for i in range(1, n_splits+1):
        j_left  = shuffled_idx[i]
        j_right = j_left + 1 
        chunk = hfo_matrix[:, :, rand_idx[j_left]:rand_idx[j_right]]
        shuff = np.concatenate((shuff, chunk), axis=2)
    return shuff


@njit(fastmath=True)
def shuffle_hfo_matrix_njit_1split_concat(hfo_matrix, n_splits):
    n_hfo, n_beta, n_times = np.shape(hfo_matrix)

    rand_idx = np.random.randint(200, n_times-200)
    chunk1 = hfo_matrix[:, :, :rand_idx] # len = rand_idx
    chunk2 = hfo_matrix[:, :, rand_idx:] # len = n - rand_idx
    
    #shuff = np.zeros((n_hfo, n_beta, n_times), dtype=np.float64)
    
    #shuff[:, :, :n_times - rand_idx] = hfo_matrix[:, :, rand_idx:]
    #shuff[:, :, n_times - rand_idx:] = hfo_matrix[:, :, :rand_idx]
        
    shuff = np.concatenate((chunk2, chunk1), axis=2)

    return shuff


@njit(fastmath=True)
def shuffle_hfo_matrix_njit_1split_roll(hfo_matrix, n_splits):
    n_hfo, n_beta, n_times = np.shape(hfo_matrix)

    rand_idx = np.random.randint(200, n_times-200)
    shuff = np.roll()
    return shuff


In [25]:
t0 = time.perf_counter()

res_njit = shuffle_hfo_matrix_njit_1split(hfo_matrix, 1)
    
t_njit = time.perf_counter() - t0
print(t_njit)

t0 = time.perf_counter()

res_njit = shuffle_hfo_matrix_njit_1split(hfo_matrix, 1)
    
t_njit = time.perf_counter() - t0
print(t_njit)

1.950736811000013
1.5052925090000144


In [34]:
t0 = time.perf_counter()

res_njit = shuffle_hfo_matrix_njit(hfo_matrix, 1)
    
t_njit = time.perf_counter() - t0
print(t_njit)

t0 = time.perf_counter()

res_njit = shuffle_hfo_matrix_njit(hfo_matrix, 1)
    
t_njit = time.perf_counter() - t0
print(t_njit)

2.461872783999979
1.9223293369998373


In [28]:
t0 = time.perf_counter()

res_py = shuffle_hfo_matrix(hfo_matrix, 1)
    
t_python = time.perf_counter() - t0
print(t_python)

2.132703412000012


## calculate_surrogate_matrix

In [14]:
@njit(fastmath=True, parallel=True)
def calculate_surrogates_njit(beta_matrix, hfo_matrix, n_surrogates, n_splits=1):
    n_hfo, n_beta = np.shape(beta_matrix)[0], np.shape(beta_matrix)[1]
    surrogate_pac_matrices = np.zeros((n_surrogates, n_hfo, n_beta), dtype=np.float64)
    for k in prange(n_surrogates):
        if n_splits == 1:
            shuffled_hfo_matrix = shuffle_hfo_matrix_njit_1split(hfo_matrix, 1)
        else:
            shuffled_hfo_matrix = shuffle_hfo_matrix_njit(hfo_matrix, n_splits)
        #shuffled_hfo_matrix = hfo_matrix
        surrogate_pac_matrices[k, :, :] = calculate_PAC_matrix_njit(beta_matrix, shuffled_hfo_matrix)
    return surrogate_pac_matrices


def calculate_surrogates(beta_matrix, hfo_matrix, n_surrogates, n_splits=1):
    n_hfo, n_beta = np.shape(beta_matrix)[0], np.shape(beta_matrix)[1]
    surrogate_pac_matrices = np.zeros((n_surrogates, n_hfo, n_beta))
    for k in range(n_surrogates):
        shuffled_hfo_matrix = shuffle_hfo_matrix(hfo_matrix, n_splits)
        surrogate_pac_matrices[k, :, :] = calculate_PAC_matrix_njit(beta_matrix, shuffled_hfo_matrix)
    return surrogate_pac_matrices

In [None]:
t0 = time.perf_counter()

res_njit = calculate_surrogates_njit(beta_matrix, hfo_matrix, 10)
    
t_njit = time.perf_counter() - t0
print(t_njit)

# t0 = time.perf_counter()

# res_njit = calculate_surrogates_njit(beta_matrix, hfo_matrix, 10)
    
# t_njit = time.perf_counter() - t0
# print(t_njit)

In [None]:
t0 = time.perf_counter()

res_py = calculate_surrogates(beta_matrix, hfo_matrix, 20)
    
t_python = time.perf_counter() - t0
print(t_python)