# Source/Sink 
Integrates code from dsp.py to fit MVAR model (A_hat) and then compute source/sink values from the resultant model 

### Source/sink code

In [5]:
import numpy as np
from scipy import signal
from pathlib import Path


def compute_source_sink_index(A_hat, normalize=True):
    N = A_hat.shape[1]

    # Take absolute value and set diagonal to zero
    A_abs = np.abs(A_hat)
    for ii in range(N):
        A_abs[ii, ii] = 0

    # Compute row rank
    row_sums = np.sum(A_abs, axis=1)
    row_order = np.argsort(row_sums)
    row_rank = (np.argsort(row_order) + 1) / N

    # Compute column rank
    col_sums = np.sum(A_abs, axis=0)
    col_order = np.argsort(col_sums)
    col_rank = (np.argsort(col_order) + 1) / N
    
    # Compute sink, source indices
    
    #sink_idx = np.sqrt(2) - np.sqrt((row_rank - 1) ** 2 + (col_rank - 1/N)**2)
    sink_idx = np.sqrt(2) - np.sqrt((row_rank - 1) ** 2 + (col_rank - 1/N)**2)
    source_idx = np.sqrt(2) - np.sqrt((row_rank - 1/N) ** 2 + (col_rank - 1)**2)
    
    # Compute source influence
    source_influence = np.matmul(np.abs(A_hat), source_idx)
    source_influence /= np.max(source_influence)
    
    # Compute sink connectivity
    sink_connectivity = np.matmul(np.abs(A_hat), sink_idx)
    sink_connectivity /= np.max(sink_connectivity)
    
    # Normalize
    if normalize:
        source_idx /= np.sqrt(2)
        sink_idx /= np.sqrt(2)
    
    # Compute SSI
    ssi = sink_idx * source_influence * sink_connectivity
    
    return sink_idx, source_idx, source_influence, sink_connectivity, ssi


def computeA(data, alpha=0):
    """Compute the A transition matrix from a vector timeseries
    """
    
    nchns, T = data.shape
    
    Z = data[:, 0:T-1]
    Y = data[:, 1:T]
    D = np.linalg.inv(np.matmul(Z, Z.transpose()) + alpha*np.eye(nchns))
    Ahat = np.matmul(Y, np.matmul(Z.transpose(), D))
    return Ahat


def computeR1(data, alpha=0):
    """Compute the A transition matrix from a vector timeseries
    """
    
    nchns, T = data.shape
    
    Z = data[:, 0:T-1]
    Y = data[:, 1:T]
    R1 = np.matmul(Y, Z.transpose()) / (T - 1)
    return R1


def get_spectral_entropy(feature):
    ss = []
    for cc in range(18):
        f, t, Sxx = signal.spectrogram(feature[:, cc], nperseg=200, noverlap=100)
        Pxx = np.power(np.abs(Sxx[1:, :]), 2)
        se = -np.sum(np.log2(Pxx) * Pxx, axis=0)
        ss.append(se)
    return np.asarray(ss).transpose()


def spectral_entropy(data):
    D = np.fft.fft(data.transpose())
    D = np.abs(D[:, 1:])
    D = D / np.sum(D, axis=1)[:, np.newaxis]
    return np.sum(-D * np.log(D), axis=1)


### Preprocess


In [2]:
from scipy.signal import detrend
import mne
def outlier_repeat(data: np.ndarray, sd: float, rounds: int = np.inf,
                   axis: int = 0) -> tuple[tuple[int, int]]:
    """ Remove outliers from data and repeat until no outliers are left.

    This function removes outliers from data and repeats until no outliers are
    left. Outliers are defined as any data point that is more than sd standard
    deviations from the mean. The function returns a tuple of tuples containing
    the index of the outlier and the round in which it was removed.

    Parameters
    ----------
    data : np.ndarray
        Data to remove outliers from.
    sd : float
        Number of standard deviations from the mean to consider an outlier.
    rounds : int
        Number of times to repeat outlier removal. If None, the function will
        repeat until no outliers are left.
    axis : int
        Axis of data to remove outliers from.

    Returns
    -------
    tuple[tuple[int, int]]
        Tuple of tuples containing the index of the outlier and the round in
        which it was removed."""
    inds = list(range(data.shape[axis]))

    # Square the data and set zeros to small positive number
    R2 = np.square(data)
    R2[np.where(R2 == 0)] = 1e-9

    # find all axes that are not channels (example: time, trials)
    axes = tuple(i for i in range(data.ndim) if not i == axis)

    # Initialize stats loop
    sig = np.std(R2, axes)  # take standard deviation of each channel
    cutoff = (sd * np.std(sig)) + np.mean(sig)  # outlier cutoff
    i = 1

    # remove bad channels and re-calculate variance until no outliers are left
    while np.any(np.where(sig > cutoff)) and i <= rounds:

        # Pop out names to bads output using comprehension list
        for j, out in enumerate(np.where(sig > cutoff)[0]):
            yield inds.pop(out - j), i

        # re-calculate per channel variance
        R2 = R2[..., np.where(sig < cutoff)[0], :]
        sig = np.std(R2, axes)
        cutoff = (sd * np.std(sig)) + np.mean(sig)
        i += 1

def channel_outlier_marker(input_raw, outlier_sd=3,
                           max_rounds=np.inf, axis=0,
                            verbose=True
                           ) -> list[str]:
    """Identify bad channels by variance.

    Parameters
    ----------
    input_raw : Signal
        Raw data to be analyzed.
    outlier_sd : int, optional
        Number of standard deviations above the mean to be considered an
        outlier, by default 3
    max_rounds : int, optional
        Maximum number of variance estimations, by default runs until no
        more bad channels are found.
    axis : int, optional
        Axis to calculate variance over, by default 0

    Returns
    -------
    list[str]
        List of bad channel names.
    """ 

    names = input_raw.copy().pick('data').ch_names
    data = detrend(input_raw.get_data('data'))  # channels X time
    bads = []  # output for bad channel names
    desc = []  # output for bad channel descriptions

    # Pop out names to bads output using comprehension list
    for ind, i in outlier_repeat(data, outlier_sd, max_rounds, axis):
        bads.append(names[ind])
        desc.append(f'outlier round {i} more than {outlier_sd} SDs above mean')
        # log channels excluded per round
        if verbose:
            mne.utils.logger.info(f'outlier round {i} channels: {bads}')


    return bads

In [3]:
def preprocess(raw):
    # Drop bad channels
    raw.drop_channels(raw.info['bads'])
    raw.load_data()
    # Notch filter at 60 Hz to remove line noise
    notch_filtered = raw.notch_filter(60, notch_widths=2)
    # Extract bandapss 0.5 to 300 Hz using fourth order Butterworth filter
    final_raw = notch_filtered.filter(0.5, 300, 
                                    method='iir')

    outliers = channel_outlier_marker(final_raw, 3, 2)
    final_raw.drop_channels(outliers)
    # Set common average reference 
    final_raw.set_eeg_reference()

    return final_raw

### Run source/sink

In [75]:
def run_source_sink(raw, window_size=0.5):
    '''
    Runs source/sink code on given raw file. 

    -------------------------------------
    Input parameters:
    '''
    
    # raw.info['bads'] = channel_outlier_marker(raw, 3, 2)
    sfreq = raw.info["sfreq"]
    
    # Remove non-SEEG channels
    bad_words = ['C', 'TRIG', 'OSAT', 'PR', 'Pleth', 'EKG']
    for ch in raw.ch_names:
        if any(word in ch for word in bad_words):
            raw.info['bads'].append(ch)
        # Remove EEG channels based off length of string (EEG channels are no longer than 3 chars)
        if len(ch) < 4:
            raw.info['bads'].append(ch)

    # Do preprocessing on raw data
    final_raw = preprocess(raw)

    final_raw_epoch = mne.make_fixed_length_epochs(final_raw, duration=0.5)
    ch_names = final_raw.ch_names
    A_hats = []
    sink_indices = []
    source_indices = []
    source_infs = []
    sink_connectivity_list = []
    ssi_list = [] 

    for epoch in final_raw_epoch:
        A_hat = computeA(epoch)
        A_hats.append(A_hat)
        sink_idx, source_idx, source_influence, sink_connectivity, ssi = compute_source_sink_index(A_hat)
        sink_indices.append(sink_idx)
        source_indices.append(source_idx)
        source_infs.append(source_influence)
        sink_connectivity_list.append(sink_connectivity)
        ssi_list.append(ssi)

    return ch_names, A_hats, sink_indices, source_indices, source_infs, sink_connectivity_list, ssi_list
    

In [68]:
from mne.io import read_raw_edf
# edf_fpath = Path.home() / 'Desktop' / 'DUKEDOCS' / 'DUKERESEARCH' / 'NETWORKSANDNMOD' / 'SRI' / 'NEURAL FRAGILITY' / 'TESTADULT' / 'CLICTALTEST1.EDF'
edf_fpath = Path.home() / 'Research' / 'fragility_code' / 'DATA' / 'MC01_1.edf'
raw = read_raw_edf(edf_fpath)

Extracting EDF parameters from /Users/dsexton/Research/fragility_code/DATA/MC01_1.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = read_raw_edf(edf_fpath)


In [74]:
raw

Unnamed: 0,General,General.1
,Filename(s),MC01_1.edf
,MNE object type,RawEDF
,Measurement date,2023-10-02 at 08:24:01 UTC
,Participant,X
,Experimenter,Unknown
,Acquisition,Acquisition
,Duration,00:02:38 (HH:MM:SS)
,Sampling frequency,2048.00 Hz
,Time points,322112
,Channels,Channels


In [76]:
ch_names, A_hat, sink_idx, source_idx, source_inf, sink_conn, ssi = run_source_sink(raw)

Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 58 - 62 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 58.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 58.25 Hz)
- Upper passband edge: 61.50 Hz
- Upper transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 61.75 Hz)
- Filter length: 13517 samples (6.600 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done  71 tasks      | elapsed:    0.5s


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 3e+02 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 16 (effective, after forward-backward)
- Cutoffs at 0.50, 300.00 Hz: -6.02, -6.02 dB

outlier round 1 channels: ['RLS4']
outlier round 1 channels: ['RLS4', 'RLS7']
outlier round 2 channels: ['RLS4', 'RLS7', 'RLIA1']
outlier round 2 channels: ['RLS4', 'RLS7', 'RLIA1', 'RLIA2']
outlier round 2 channels: ['RLS4', 'RLS7', 'RLIA1', 'RLIA2', 'RPPS4']
outlier round 2 channels: ['RLS4', 'RLS7', 'RLIA1', 'RLIA2', 'RPPS4', 'RPPS5']
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Not setting metadata
314 matching events found
No baseline correction applied
0 projection items activated


### Visualize source/sink metrics

In [None]:
import seaborn as sns
import pandas as pd

df_dict = {}
df_dict['Channel'] = []
df_dict['Time'] = []
df_dict['SSI'] = []

time = 0

for ii in range(len(A_hat)):
    df_dict['Channel'].extend(ch_names)
    df_dict['Time'].extend([time] * len(ch_names))
    df_dict['SSI'].extend(list(ssi[ii]))
    time += 0.5


In [99]:
ssi_df = pd.DataFrame(df_dict)

mean_scores = ssi_df.groupby('Channel')['SSI'].mean()

# Sort and take top 20
top_channels = mean_scores.sort_values(ascending=False).head(20)

print(top_channels)

Channel
RLI4      0.288748
RPMS6     0.268945
RLS3      0.267136
RPMM2     0.262281
RLIA3     0.250959
RLI5      0.237933
RPMM1     0.234825
RPMS5     0.227158
RLI3      0.212558
RPPS6     0.200063
RPMM3     0.191795
RPMS4     0.191658
RLP5      0.183784
RLI2      0.181889
RPPI12    0.176283
RFPI3     0.172082
ROAS6     0.164540
RPMS7     0.154882
RLI6      0.152959
RFPS1     0.148174
Name: SSI, dtype: float64
