## Compute permutations

This script reads the epoched data already imported from BVA, resampled, preprocessed and merged across subjects
and compute the both surrogate distribution of connectivity matrices and the true connectivity matrix.

No orthogonalization is used for surrogate distribution (which drastically increases the speed) because the signals are 
randomly shifted in time with respect to each other during each iteration

### Parameters to consider:
    numPerm = 500 - a number of permutations
    shiftEpch = 20 - a max number of epochs shift (measured in epochs, the overlap is not taken into acount)

### Inputs:
    pipePath  - path to a pipe folder containing epoch data
    fileName  - name of the epoch data (condition specific)
    
### Outputs:
    corrMat - the data correlation matrix (lower triangle)
    corrMatSurr - surrogate distribution of correlation matrices

### Define a working directory and load data
- define a number of conditions nC
- define a number of groups nG

In [31]:
import numpy as np

def init_list(dims, val):
    if len(dims) == 0:
        raise ValueError("Requires at least 1 dimension.")
    if len(dims) == 1:
        return [val for _ in range(dims[0])]
    return [init_list(dims[1:], val=val) for _ in range(dims[0])]


#Load the data:
pipePath = '/home/koudelka/Projects/LSD_FilipTrybusek/pipe/'
nC = 2
nG = 1

epochData = init_list((nC,nG),val=0)

for cIdx in range(0,nC):
    for gIdx in range(0,nG):
        epochData[cIdx][gIdx] = np.load(pipePath + 'c' + str(cIdx) + 'g' + str(gIdx) + '.npz')["epochData"] 
        

nEpoch, nChan, nSamp = epochData[0][0].shape
print('nEpochs: ' + str(nEpoch) + ' nChan: ' + str(nChan) + ' nSamp: ' + str(nSamp))

nEpochs: 238 nChan: 64 nSamp: 128


### Exclude or reduce some epochs if needed - skip 

In [11]:
#decim = 5
#epochData = epochData[0:-1:decim,:,:]
#newNEpoch = epochData[0][0].shape[0]
#newNEpoch

### Define functions

In [32]:
def pairwise_correlation(A, B):
    am = A - np.mean(A, axis=0, keepdims=True)
    bm = B - np.mean(B, axis=0, keepdims=True)
    return am.T @ bm /  (np.sqrt(
        np.sum(am**2, axis=0,
               keepdims=True)).T * np.sqrt(
        np.sum(bm**2, axis=0, keepdims=True)))

def atin_powCorr_compute(epochsData):    
    nEpoch, nChan, nSamp = epochsData.shape
#prepare two arrays for orthogonalization and correlation
    x = np.zeros((nEpoch,nSamp))
    y = np.zeros((nEpoch,nSamp))
#prepare a nChanXnChan correlation matrix
    corrMat = np.zeros((nChan,nChan))
    for chan1Idx in range(0,nChan):
        for chan2Idx in range(0,chan1Idx):
#a pair of two channels and all Epochs
            x = epochsData[:,chan1Idx,:]
            y = epochsData[:,chan2Idx,:]
#across all epochs orthogonalize
            x = np.reshape(x, (nEpoch*nSamp,1),order='C')
            y = np.reshape(y, (nEpoch*nSamp,1),order='C')
#get the regression coeficient from the pseudoinverse
            beta = np.real(np.dot(np.linalg.pinv(x),y))
            y= y - beta*x
#rehape back to the epochXsamples matrix
            x = np.reshape(x, (nEpoch,nSamp),order='C')
            y = np.reshape(y, (nEpoch,nSamp),order='C')
#compute RMS
            xRMS = np.power(x,2)
            yRMS = np.power(y,2)
            xRMS = np.sqrt(np.mean(xRMS,axis=1))
            yRMS = np.sqrt(np.mean(yRMS,axis=1))
            corrMat[chan1Idx,chan2Idx] = pairwise_correlation(xRMS, yRMS)
    return corrMat

def atin_powCorr_compute_surr(epochsData):
    nEpoch, nChan = epochsData.shape
#prepare a nChanXnChan correlation matrix    
    corrMat = np.zeros((nChan,nChan))
    for chan1Idx in range(0,nChan):
        for chan2Idx in range(0,chan1Idx):            
#compute RMS
            xRMS = epochsData[:,chan1Idx]
            yRMS = epochsData[:,chan2Idx]
            corrMat[chan1Idx,chan2Idx] = pairwise_correlation(xRMS, yRMS)
    return corrMat          

### Compute the observed correlation matrix
- consider the lower triangular matrix is computed to save time
  

In [34]:
corrMat = init_list((nC,nG),val=0)
for cIdx, condIns in enumerate(epochData):
    for gIdx, groupIns in enumerate(condIns):
        corrMat[cIdx][gIdx] = atin_powCorr_compute(groupIns)

  corrMat[chan1Idx,chan2Idx] = pairwise_correlation(xRMS, yRMS)


### and save the results...

In [35]:
for cIdx, condIns in enumerate(corrMat):
    for gIdx, groupIns in enumerate(condIns):
        np.save(pipePath + 'c' + str(cIdx) + 'g' + str(gIdx) + '_observed', groupIns)

### Slightly change the correlation matrix computation:
- since we do a random circular shifts in time we do not need to orthogonalize
- then we can compute RMS within each epoch before itarating across electrode pairs
- this will save us a lot of time and this is implemented in the **corrMatSurr** function

In [36]:
numPerm = 500
shiftEpch = 20

#prepare RMS (the computationally demanding step)
rmsMat = init_list((nC,nG),val=0)
for cIdx in range(0,nC):
    for gIdx in range(0,nG):
        rmsMat[cIdx][gIdx] =  np.sqrt(np.mean(np.power(epochData[cIdx][gIdx],2),2))



corrMatSurr = init_list((nC,nG),val=0)

for cIdx, condIns in enumerate(rmsMat):
    for gIdx, groupIns in enumerate(condIns):
        npCorrMat = []
        for iterIdx in range(0,numPerm):
            print('cond: ' + str(cIdx) + ' group: ' + str(gIdx) + ' Iter: ' + str(iterIdx) + '/' + str(numPerm), end='\r')
            epochShift = np.copy(groupIns)
            for chanIdx in range(0,nChan): 
                randShift = np.random.randint(-shiftEpch,shiftEpch,1)
                epochShift[:,chanIdx] = np.roll(epochShift[:,chanIdx],randShift,axis=0)
            npCorrMat.append(atin_powCorr_compute_surr(epochShift))
        corrMatSurr[cIdx][gIdx] = np.array(npCorrMat)

for cIdx, condIns in enumerate(corrMatSurr):
    for gIdx, groupIns in enumerate(condIns):
        np.save(pipePath + 'c' + str(cIdx) + 'g' + str(gIdx) + '_surrogate', groupIns)

cond: 0 group: 0 Iter: 1/500

  corrMat[chan1Idx,chan2Idx] = pairwise_correlation(xRMS, yRMS)


cond: 1 group: 0 Iter: 499/500

### and save the results again...

In [37]:
for cIdx, condIns in enumerate(corrMatSurr):
    for gIdx, groupIns in enumerate(condIns):
        np.save(pipePath + 'c' + str(cIdx) + 'g' + str(gIdx) + '_surrogate', groupIns)