In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mayavi import mlab
import mne
from colorednoise import powerlaw_psd_gaussian as pink
from scipy.stats import pearsonr
import sys
sys.path.insert(1, '../')

from ESINet.forward import create_forward_model
from ESINet.simulations import run_simulations, create_eeg
from ESINet.util import *
from ESINet.ann import *
from ESINet.viz import splot


# Folder holding the ico3-sampled forward model used by the cells below.
pth_fwd = "forward_models/ico3/"

In [None]:
# Build the forward model with 'ico3' source-space sampling and store it under
# pth_fwd; the following cell reads the leadfield back from that same path.
create_forward_model(pth_fwd, sampling='ico3')

In [None]:
leadfield = load_leadfield(pth_fwd)
n_elec, n_dipoles = leadfield.shape

# Source Parameters
# seed=None re-seeds from OS entropy, so every run draws different sources;
# pass an int instead to make the simulation reproducible.
np.random.seed(seed=None)
beta = 1.5  # power-law exponent passed to powerlaw_psd_gaussian (pink)
n_source_clusters = 2
n_timepoints = 1000
sfreq = 1000  # Hz
time_step = 1/sfreq
trial_duration = 1  # second

# Spatial dimension (dipole selection)
# Draw without replacement so every cluster gets its own seed dipole; with
# replacement two clusters could land on the same dipole.
seeds = np.random.choice(np.arange(n_dipoles), size=n_source_clusters, replace=False)
# seeds = [10, 1000]

# Temporal dimension
# Target Pearson correlation (0..1) between cluster 0 and every other cluster.
correlatedness = 0.5
source_time_course = np.stack([pink(beta, n_timepoints) for _ in range(n_source_clusters)])
# Mix the time course of cluster 0 into the others. For independent inputs of
# equal variance, weights (rho, sqrt(1 - rho**2)) give an expected correlation
# of rho with cluster 0 and keep the variance unchanged.
source_time_course[1:] = (correlatedness * source_time_course[0]
                          + np.sqrt(1 - correlatedness**2) * source_time_course[1:])


source = np.zeros((n_dipoles, n_timepoints))
EEG = np.zeros((n_elec, n_timepoints))


In [None]:
# Realised correlation between the first two clusters, shown in the title.
r = pearsonr(source_time_course[0], source_time_course[1])[0]
title = f'r={r:.2f}'
plt.figure(1)
for i, time_course in enumerate(source_time_course):
    plt.plot(time_course, label=i)
plt.title(title)
plt.xlabel('Sample')
plt.ylabel('Amplitude (a.u.)')
plt.legend();

In [None]:
# Inspect the simulated source time courses (clusters x time points).
# (The previous `stc` does not exist: comprehension variables do not leak in Python 3.)
source_time_course

In [None]:
%matplotlib qt
for seed in seeds:
    source[seed, :] = 1
splot(source, pth_fwd, figure=mlab.figure(''))


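## Autoregressive Model
Simulate source time courses with univariate (AR) and multivariate (MAR) autoregressive models, so that the auto- and cross-correlations between sources can be controlled.

References:
- Frontiers in Human Neuroscience (2014), doi:10.3389/fnhum.2014.00448: https://www.frontiersin.org/articles/10.3389/fnhum.2014.00448/full#h3
- Wikipedia, Autoregressive model: https://en.wikipedia.org/wiki/Autoregressive_model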


In [None]:
from random import uniform
def autoregressive_model(size=100, order=1, coeff_range=(0, 1)):
    '''Simulate a univariate autoregressive AR(order) time series.

    The series follows
        x[t] = c + noise[t] + sum_{k=1}^{order} a_k * x[t-k]
    with a constant c ~ N(0, 1), noise[t] ~ N(0, 1) and each a_k drawn
    uniformly from coeff_range. The series starts at x[0] = c; lags that
    would reach before t = 0 are skipped.

    Parameters
    ----------
    size : int, number of time points to generate.
    order : int >= 0, number of lagged terms (0 gives c plus white noise).
    coeff_range : tuple, (low, high) range of the AR coefficients.

    Returns
    -------
    time_series : numpy.ndarray of shape (size,)
    '''
    # Draw scalars rather than size-1 arrays: NumPy 1.25+ deprecates implicitly
    # treating size-1 arrays as scalars when they are stored into one slot.
    c = np.random.randn()
    coefficients = [uniform(*coeff_range) for _ in range(order)]
    print(f'Coefficients: {coefficients}')
    time_series = np.zeros(size)
    if size == 0:
        return time_series
    time_series[0] = c
    for t in range(1, size):
        noise = np.random.randn()
        # coefficients[k - 1] weights lag k, so every lag 1..order contributes.
        ar_term = sum(coefficients[k - 1] * time_series[t - k]
                      for k in range(1, min(order, t) + 1))
        time_series[t] = c + noise + ar_term

    return time_series

# With positive coefficients an AR(2) process is stationary only if
# a_1 + a_2 < 1, so keep each coefficient below 0.5 to avoid a blow-up.
signal = autoregressive_model(order=2, coeff_range=(0.4, 0.45))
plt.figure(2)
plt.plot(signal)
plt.title('AR(2) time series')
plt.xlabel('Sample');

In [None]:
from numpy.random import uniform

def self_similarity_function(n_timeseries, param):
    '''Build a diagonal matrix of per-series self-coupling weights.

    n_timeseries : int, size of the square output matrix.
    param : scalar, or a (low, high) tuple/list. A scalar puts the same value
        on every diagonal entry; a range draws one uniform value per entry.
    Returns an (n_timeseries, n_timeseries) array that is zero off the diagonal.
    '''
    eye = np.identity(n_timeseries)
    if not isinstance(param, (tuple, list)):
        return eye * param
    diagonal_values = uniform(*param, size=n_timeseries)
    return eye * diagonal_values
def _draw_mar_coefficients(n_timeseries, order, coeff_range, p_connected,
        self_similarity_parameter):
    '''Draw one (n_timeseries x n_timeseries) coupling matrix per lag.

    Each entry is non-zero with probability p_connected, in which case it is
    drawn uniformly from coeff_range; the diagonal returned by
    self_similarity_function is then added to every matrix.
    '''
    coefficients = []
    for _ in range(order):
        connected = np.random.rand(n_timeseries, n_timeseries) < p_connected
        weights = np.random.uniform(*coeff_range, size=(n_timeseries, n_timeseries))
        coefficients.append(np.where(connected, weights, 0.)
            + self_similarity_function(n_timeseries, self_similarity_parameter))
    return coefficients


def _simulate_mar(coefficients, n_timeseries, n_timepoints):
    '''Run x[:, t] = sum_{k=1}^{order} A_k @ x[:, t-k] + noise[:, t].

    coefficients[k - 1] is A_k (the matrix for lag k); lags reaching before the
    first sample are skipped (i.e. x is zero before t = 0). noise is i.i.d.
    standard normal.
    '''
    order = len(coefficients)
    time_series = np.zeros((n_timeseries, n_timepoints))
    noise = np.random.randn(n_timepoints, n_timeseries)
    for t in range(n_timepoints):
        ar_term = np.zeros(n_timeseries)
        for k in range(1, min(order, t) + 1):
            ar_term += coefficients[k - 1] @ time_series[:, t - k]
        time_series[:, t] = ar_term + noise[t]
    return time_series


def autoregressive_model_mv(size=(3, 100), order=1, coeff_range=(0.2, 1), 
    p_connected=.25, self_similarity_parameter=0., redo_crit=200, max_redos=1000):
    ''' Generate time series that are stochastically correlated using a 
    multivariate autoregressive model (MAR).

    The signal follows x[:, t] = sum_{k=1}^{order} A_k @ x[:, t-k] + noise[:, t]
    with i.i.d. standard-normal noise and x = 0 before the first sample.

    Parameters
    ----------
    size : tuple/list, (channels, time points), determines the final dimension 
        of the signal.
    order : int > 0, order of the autoregressive model (number of lags used).
    coeff_range : tuple/list, (low, high) range from which the non-zero
        coefficients of the MAR are drawn uniformly.
    p_connected : float, between 0 and 1, probability that a given entry of a
        coefficient matrix is non-zero, i.e. that one signal depends on another.
    self_similarity_parameter : float, value added to the diagonal of every
        coefficient matrix (each signal's dependence on its own past). Use 0.5
        for more 1/f - like signals. If a (low, high) pair is given, a value is
        drawn uniformly from that range for every signal and every lag.
    redo_crit : int/float, if the max(abs()) of the resulting signal is higher 
        than redo_crit (or not finite), coefficients and signal are drawn
        again. This prevents exploding signals.
    max_redos : int or None, maximum number of redraws before giving up;
        None retries indefinitely.

    Returns
    -------
    time_series : numpy.ndarray of shape size, the resulting MAR time series.
    coefficients : list of numpy.ndarray, one (channels x channels) matrix per
        lag; coefficients[k - 1] multiplies the signal k samples in the past.

    Raises
    ------
    RuntimeError : if the signal still exceeds redo_crit after max_redos
        redraws.
    '''
    n_timeseries, n_timepoints = size

    # Iterative retry (the former recursion could hit the recursion limit).
    redos = 0
    while True:
        coefficients = _draw_mar_coefficients(n_timeseries, order, coeff_range,
            p_connected, self_similarity_parameter)
        time_series = _simulate_mar(coefficients, n_timeseries, n_timepoints)
        # `<=` is False for NaN, so overflowed signals are redrawn as well.
        if time_series.size == 0 or np.max(np.abs(time_series)) <= redo_crit:
            return time_series, coefficients
        if max_redos is not None and redos >= max_redos:
            raise RuntimeError(f'MAR signal still exceeds redo_crit={redo_crit} '
                               f'after {max_redos} redos')
        redos += 1
        print('redoing')

coeff_range = (-.05, .05) 
order = 2
p_connected = 0.25
self_similarity_parameter = (0.25, 0.75)
signal, coefficients = autoregressive_model_mv(size=(5, 100), order=order, coeff_range=coeff_range, 
    p_connected=p_connected, self_similarity_parameter=self_similarity_parameter)
%matplotlib qt
plt.close()
plt.figure(2)

plt.subplot(311)
plt.imshow(signal)
plt.subplot(312)
[plt.plot(signal[i, :], label=i) for i in range(signal.shape[0])]
plt.legend()
plt.subplot(313)
plt.imshow(np.mean(coefficients, axis=0), cmap='bwr')
plt.clim(coeff_range)
plt.colorbar()
plt.tight_layout()

In [None]:
import time
# perf_counter is monotonic and high-resolution, unlike time.time() (wall
# clock, which can jump), so it is the right clock for elapsed durations.
start = time.perf_counter()
signal, coefficients = autoregressive_model_mv(size=(10, 100), order=order, coeff_range=coeff_range, p_connected=p_connected)
end = time.perf_counter()
print(f'{end-start}')


In [None]:
# Example: 3 x 3 diagonal with entries drawn uniformly from [0.5, 0.75).
self_similarity_function(n_timeseries=3, param=(0.5, 0.75))