# Fit model.

In [None]:
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import mne

### Parameters.

In [None]:
# Analysis constants shared by the cells below.
reject = None  # NOTE(review): presumably an epoch-rejection setting; unused in this section
tmin = -1.25  # epoch window start (presumably seconds -- confirm)
tmax = 0.55  # epoch window end (presumably seconds -- confirm)
sfreq = 128  # sampling frequency (presumably Hz -- confirm)
n_channels = 2  # We use the MCCA output
colors = [
    '#ff7f0e',
    '#1f77b4',
]

### Import data.

In [None]:
# Behavior (sequences): every array below is read from the same folder.
dname_seq = '/media/jacques/DATA1/2019_MusicPred/experimentContinuousMatrix/data/behavior/'
DEC = np.load(os.path.join(dname_seq, 'DEC.npy'))
MUSICIANS = np.load(os.path.join(dname_seq, 'MUSICIANS.npy'))
SEQS = np.load(os.path.join(dname_seq, 'SEQS.npy'))
SEQS_TRAINING = np.load(os.path.join(dname_seq, 'SEQS_TRAINING.npy'))

# The subject list is stored as a pickled array, hence allow_pickle=True.
sujs = np.load(os.path.join(dname_seq, 'sujs.npy'), allow_pickle=True)

# Sizes derived from the loaded arrays.
n_sujs = len(sujs)
n_dec = SEQS_TRAINING.shape[1]

# Last expression of the cell: displays the shape in the notebook.
SEQS.shape

In [None]:
# Model data: load the fitted-model surprise and flatten everything after
# the first axis so that there is one row per subject.
SURPRISE_FITTED = np.load(
    os.path.join(dname_seq, 'SURPRISE_FITTED.npy')
).reshape((n_sujs, -1))

In [None]:
# EEG files: channel names, the ERP array and the pickled info object all
# live in the same folder.
data_dir = '/media/jacques/DATA1/2019_MusicPred/experimentContinuousMatrix/data/eeg/ERP/'
ch_names = list(np.load(os.path.join(data_dir, 'ch_names.npy')))
ERP = np.load(os.path.join(data_dir, 'aSC_clean.npy'))

# info.npy wraps a single pickled object; .item() unwraps it.
info = np.load(
    os.path.join(data_dir, 'info.npy'),
    allow_pickle=True,
    encoding='latin1',
).item()

# ERP must be 5-D. The leading axis is discarded here (it is indexed per
# subject further down). Note that n_channels is overwritten with the
# value found in the data.
_, n_seqs, n_tones, n_channels, n_times = ERP.shape

# Last expression of the cell: displays the shape in the notebook.
ERP.shape

### Build a function to perform linear regression.

In [None]:
def regressModel(erp, predictors):
    '''
    Mass-univariate ordinary least squares of single-trial data on predictors.

    An OLS model with an intercept is fitted independently at every
    (channel, time) point. The design matrix is the same for every point, so
    its pseudo-inverse is computed once and applied to all points in a single
    matrix product instead of refitting a model inside a double loop.

    Parameters
    ----------
    erp : array-like, shape (n_obs, n_channels, n_times)
        Dependent variable, one observation (trial) per row.
    predictors : array-like, shape (n_obs,) or (n_predictors, n_obs)
        Regressor(s). The intercept column is added internally.

    Returns
    -------
    BETA : ndarray
        Slope of each predictor (intercept removed),
        shape (n_predictors, n_channels, n_times).
    P_VALUES : ndarray
        Raw two-sided p-values of the slopes (NOT log-transformed),
        same shape as BETA.
    R2 : ndarray
        Coefficient of determination, shape (n_channels, n_times).

    All three outputs go through np.squeeze, so with a single predictor
    BETA and P_VALUES have shape (n_channels, n_times).

    Raises
    ------
    ValueError
        If erp is not 3-D, if there is no observation, if predictors and erp
        disagree on the number of observations, or if a predictor is a
        non-zero constant (it cannot be separated from the intercept).
    '''

    import numpy as np
    from scipy import stats

    # Get dimensions
    erp = np.asarray(erp, dtype=float)
    n_obs, n_channels, n_times = erp.shape
    predictors = np.atleast_2d(np.asarray(predictors, dtype=float)).T
    n_predictors = predictors.shape[1]

    # Validate inputs up front so failures carry a readable message
    if n_obs == 0:
        raise ValueError('regressModel needs at least one observation.')
    if predictors.shape[0] != n_obs:
        raise ValueError(
            'predictors has %d observations but erp has %d.'
            % (predictors.shape[0], n_obs))
    is_constant = np.logical_and(np.ptp(predictors, axis=0) == 0,
                                 np.all(predictors != 0, axis=0))
    if is_constant.any():
        raise ValueError('A constant predictor cannot be separated from '
                         'the intercept.')

    # Design matrix (intercept first) and data flattened to (n_obs, n_points)
    X = np.column_stack((np.ones(n_obs), predictors))
    Y = erp.reshape(n_obs, n_channels * n_times)

    # Least-squares solution for every point at once
    pinv_X = np.linalg.pinv(X)
    params = pinv_X @ Y
    residuals = Y - X @ params
    ssr = np.sum(residuals ** 2, axis=0)

    # Standard errors: residual variance times diag(pinv_X @ pinv_X.T)
    df_resid = n_obs - np.linalg.matrix_rank(X)
    unscaled_var = np.sum(pinv_X ** 2, axis=1)
    std_err = np.sqrt(np.outer(unscaled_var, ssr / df_resid))

    # Two-sided t-test on every coefficient
    t_values = params / std_err
    p_values = 2.0 * stats.t.sf(np.abs(t_values), df_resid)

    # Centered R2 (the model always contains an intercept)
    centered = Y - Y.mean(axis=0)
    r_squared = 1.0 - ssr / np.sum(centered ** 2, axis=0)

    # Drop the intercept row and restore the (channel, time) layout
    out_shape = (n_predictors, n_channels, n_times)
    BETA = params[1:].reshape(out_shape)
    P_VALUES = p_values[1:].reshape(out_shape)
    R2 = r_squared.reshape(n_channels, n_times)

    # Return
    return np.squeeze(BETA), np.squeeze(P_VALUES), np.squeeze(R2)

### Perform linear regression for all subjects.

In [None]:
# One (channel, time) map per subject, pre-filled with NaN.
BETAS = np.full(shape=(n_sujs, n_channels, n_times), fill_value=np.nan)
R2 = np.full_like(BETAS, np.nan)
BETAS.shape

In [None]:
for i, suj in enumerate(sujs):

    # Load data: collapse the leading axes of the subject's ERP into a
    # single trial axis, and flatten the matching surprise values
    _erp = np.reshape(ERP[i], (-1, n_channels, n_times))
    _surprise = np.reshape(SURPRISE_FITTED[i], (-1))

    # Remove NaNs: keep a trial only if its surprise is not NaN and no
    # sample of its epoch is NaN
    idx = ~np.isnan(_surprise) & ~np.isnan(_erp).any(axis=(1, 2))
    _erp = _erp[idx]
    _surprise = _surprise[idx]

    # Nothing left to fit: report it and skip the regression, which cannot
    # run on empty data. BETAS[i] and R2[i] keep their NaN initial values.
    if _erp.size == 0:
        print('error ', i)
        continue

    # Compute LinearRegression (p-values are not kept)
    BETAS[i], _, R2[i] = regressModel(_erp, _surprise)

In [None]:
# Write both result arrays to the current working directory
# (np.save appends the .npy extension).
np.save(file='BETAS_FITTED', arr=BETAS)
np.save(file='R2_FITTED', arr=R2)

### Regress all K0, K1 and K2.

In [None]:
# Model data: load the surprise array and reshape it to three leading
# entries (indexed by k in the regression loop), then one row per subject.
SURPRISE = np.load(
    os.path.join(dname_seq, 'SURPRISE.npy')
).reshape((3, n_sujs, -1))
SURPRISE.shape

In [None]:
# One (channel, time) map per (model, subject) pair, pre-filled with NaN.
BETAS = np.full(shape=(3, n_sujs, n_channels, n_times), fill_value=np.nan)
R2 = np.full_like(BETAS, np.nan)
BETAS.shape

In [None]:
for k in range(3):
    for i, suj in enumerate(sujs):

        # Load data: collapse the leading axes of the subject's ERP into a
        # single trial axis, and flatten the surprise of model k
        _erp = np.reshape(ERP[i], (-1, n_channels, n_times))
        _surprise = np.reshape(SURPRISE[k, i], (-1))

        # Remove NaNs: keep a trial only if its surprise is not NaN and no
        # sample of its epoch is NaN
        idx = ~np.isnan(_surprise) & ~np.isnan(_erp).any(axis=(1, 2))
        _erp = _erp[idx]
        _surprise = _surprise[idx]

        # Nothing left to fit: report it and skip the regression, which
        # cannot run on empty data. BETAS[k, i] and R2[k, i] keep their NaN
        # initial values.
        if _erp.size == 0:
            print('error ', i)
            continue

        # Compute LinearRegression (p-values are not kept)
        BETAS[k, i], _, R2[k, i] = regressModel(_erp, _surprise)

In [None]:
# Write both result arrays to the current working directory
# (np.save appends the .npy extension).
np.save(file='BETAS_SURPRISE', arr=BETAS)
np.save(file='R2_SURPRISE', arr=R2)