In [2]:
import mne
from src.data.utils import *
import os
from os.path import join, dirname
from pandas import read_csv
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import matplotlib.gridspec as gridspec
from scipy.io.wavfile import read as read_wav
from scipy.stats import sem, linregress
import git



In [3]:
# Root of the git working tree that contains the current directory; every
# data and report path used below is built relative to it.
base_dir = git.Repo(
    path='.',
    search_parent_directories=True,
).working_tree_dir

In [4]:
def get_experiment_data(subject, aux_channels_correct = True):
    """
    Loads one subject's recording and derives the per-trial alignment indices.

    Reads the subject's EEG header file and PsychoPy csv log (both located relative to the
    module-level ``base_dir``), builds the stimuli list for the logged randomisation and
    extracts the two auxiliary channels that carry the stimulus recorded through StimTrak.

    Args:
        subject (int):                          subject identifier
        aux_channels_correct (bool, optional):  StimTrak connected to correct auxiliary channel
                                                (Aux1 = left, Aux2 = right). If falsy, the two
                                                channels are swapped. Defaults to True.

    Returns:
        tuple: (corr_indices, trigger_indices, stimuli_list, stim_eeg_data_l_study, stim_eeg_data_r_study)
            corr_indices:           lag indices as returned by get_lag_indicies_all_trials
            trigger_indices:        trigger onsets converted to integer sample indices
            stimuli_list:           list of stimuli generated for the subject's randomisation
            stim_eeg_data_l_study:  data of the auxiliary channel treated as left
            stim_eeg_data_r_study:  data of the auxiliary channel treated as right
    """
    header_file_study = get_eeg_header_file(base_dir, subject)
    montage_file = join(base_dir, "data","CACS-32_NO_REF.bvef")
    montage = mne.channels.read_custom_montage(montage_file)

    csv_log_file = get_csv_log_file(base_dir, subject)
    psychopy_log_file = read_csv(csv_log_file, sep=',')
    stimuli_base_folder = join(base_dir, "data", "stimuli")

    # The randomisation is taken from the first row of the log; older logs do not
    # contain the column, in which case the user is asked for it interactively.
    if 'randomisation' in psychopy_log_file.columns.tolist():
        randomisation = psychopy_log_file.loc[0,'randomisation']
    else:
        randomisation = int(input('Specifiy randomisation (0 or 1), because it is missing in log file'))
    stimuli_list = generate_stimuli_list(stimuli_base_folder, randomisation)

    eeg_recording_study = get_brainvision_from_header(header_file_study, montage)

    triggers_study = get_triggers(eeg_recording_study)
    trigger_onsets_study = np.array([trigger['onset'] for trigger in triggers_study])

    #just swap left and right channels if set-up was wrong
    if aux_channels_correct:
        stim_eeg_data_l_study, _ = eeg_recording_study['Aux1']
        stim_eeg_data_r_study, _ = eeg_recording_study['Aux2']
    else:
        stim_eeg_data_l_study, _ = eeg_recording_study['Aux2']
        stim_eeg_data_r_study, _ = eeg_recording_study['Aux1']

    #translate trigger times into indices
    # assumes onsets are in seconds and the recording has 1000 samples per second - TODO confirm
    samples_per_second = 1000
    # Round before casting: a plain int cast truncates, so an onset such as 1.001 s
    # (1000.9999999999999 after the multiplication) would land one sample too early.
    trigger_indices = np.rint(trigger_onsets_study * samples_per_second).astype(int)
    corr_indices = get_lag_indicies_all_trials(stimuli_list, stim_eeg_data_l_study, stim_eeg_data_r_study)

    return corr_indices, trigger_indices, stimuli_list, stim_eeg_data_l_study, stim_eeg_data_r_study

In [5]:
def get_ten_second_windows(stimuli_list):
    """Counts the complete ten second windows that fit into every stimulus.

    Args:
        stimuli_list (list): stimulus dicts; each must provide 'path' (wav file to read)
                             and 'side' ('l' or 'r', selects the audio channel that is measured)

    Returns:
        list: number of ten second windows that fit in each stimulus of the input list
              (one int per stimulus, same order)

    Raises:
        ValueError: if a stimulus has a 'side' other than 'l' or 'r'
    """
    # The downsampled signal is treated as 1000 samples per second, so one window is 10000 samples.
    samples_per_window = 10 * 1000

    ten_second_windows = []
    for stimulus in stimuli_list:
        side = stimulus['side']
        # Validate before touching the file: otherwise an unknown side would silently reuse
        # the signal of the previous stimulus (or fail with an unbound variable on the first one).
        if side not in ('l', 'r'):
            raise ValueError(f"stimulus side must be 'l' or 'r', got {side!r}")

        _, stim_wav_0 = read_wav(stimulus['path'])
        stim_wav_0_l, stim_wav_0_r = downsample_wav(stim_wav_0[:,0], stim_wav_0[:,1])
        stim_wav = stim_wav_0_r if side == 'r' else stim_wav_0_l

        # Only complete windows count; a trailing partial window is dropped.
        ten_second_windows.append(len(stim_wav) // samples_per_window)
    return ten_second_windows

In [6]:
def _window_lag(recorded, reference, start, stop):
    """Returns the lag (in samples) at which the absolute cross-correlation of recorded[start:stop] and reference[start:stop] peaks."""
    recorded_window = recorded[start:stop]
    reference_window = reference[start:stop]
    # correlate / correlation_lags are not imported in this notebook directly; presumably they
    # arrive through the star import from src.data.utils - verify
    corr = correlate(recorded_window, reference_window, mode = 'full')
    lags = correlation_lags(recorded_window.size, reference_window.size, mode='full')
    return int(lags[np.argmax(np.abs(corr))])


def calc_drift_analysis(corr_indices, trigger_indices, stimuli_list, stim_eeg_data_l_study, stim_eeg_data_r_study, n_trials = 20):
    """Performs analysis of drift within EEG measurement

    Every trial is cut out of the recorded stimulus channel twice, once starting at the
    cross-correlation alignment index and once at the trigger index. Each cut is split into
    ten second windows and, per window, the lag that maximises the absolute cross-correlation
    with the (downsampled) stimulus is stored.

    Args:
        corr_indices (list):        alignment indices according to cross-correlation
        trigger_indices (list):     alignment indices according to trigger
        stimuli_list (list):        list of audio stimuli (dicts with 'path' and 'side')
        stim_eeg_data_l_study ():   left channel stimulus data recorded through StimTrak and EEG amp,
                                    indexed as [0, :]
        stim_eeg_data_r_study ():   right channel stimulus data recorded through StimTrak and EEG amp,
                                    indexed as [0, :]
        n_trials (int, optional):   number of rows (trials) in the result; at most this many
                                    trials are analysed. Defaults to 20.

    Returns:
        (drifts_subject_corr:   np.array, drifts_subject_trigger:np.array): calculated drifts on ten second windows relative to the global alignment,
                                a row represents one trial; windows that do not exist for a trial stay NaN

    Raises:
        ValueError: if stimuli_list is empty or a stimulus has a 'side' other than 'l' or 'r'
    """
    # One window is 10 s of signal that is treated as 1000 samples per second.
    samples_per_window = 10 * 1000

    ten_second_windows = get_ten_second_windows(stimuli_list)
    drifts_subject_corr = np.full((n_trials, max(ten_second_windows)), np.nan)
    drifts_subject_trigger = drifts_subject_corr.copy()

    # zip stops at the shortest input, so surplus indices or stimuli are ignored
    for corr_index, trigger_index, stimulus, i in zip(corr_indices, trigger_indices, stimuli_list, range(n_trials)):
        side = stimulus['side']
        # Without this check an unknown side would silently reuse the signals of the previous trial.
        if side not in ('l', 'r'):
            raise ValueError(f"stimulus side must be 'l' or 'r', got {side!r}")

        _, stim_wav_0 = read_wav(stimulus['path'])
        stim_wav_0_l, stim_wav_0_r = downsample_wav(stim_wav_0[:,0], stim_wav_0[:,1])

        if side == 'r':
            stim_wav = stim_wav_0_r
            stim_eeg = stim_eeg_data_r_study[0,:]
        else:
            stim_wav = stim_wav_0_l
            stim_eeg = stim_eeg_data_l_study[0,:]

        # Same stimulus length cut out of the recording at both candidate alignments.
        stim_eeg_snippet_corr = stim_eeg[corr_index: corr_index + len(stim_wav)]
        stim_eeg_snippet_trigger = stim_eeg[trigger_index: trigger_index + len(stim_wav)]

        for j in range(ten_second_windows[i]):
            start, stop = j * samples_per_window, (j + 1) * samples_per_window

            #correlation
            drifts_subject_corr[i,j] = _window_lag(stim_eeg_snippet_corr, stim_wav, start, stop)

            #trigger
            drifts_subject_trigger[i,j] = _window_lag(stim_eeg_snippet_trigger, stim_wav, start, stop)

    return drifts_subject_corr, drifts_subject_trigger

In [13]:
def _plot_trial_regression(axis, trial_drifts, color, line_label = None, points_label = None):
    """Draws the finite drifts of one trial and their linear fit on axis; returns False (drawing nothing) if fewer than two finite values exist."""
    trial_drifts = trial_drifts[np.isfinite(trial_drifts)]
    # linregress cannot fit fewer than two points, so such trials are skipped
    if len(trial_drifts) < 2:
        return False
    x_reg = np.arange(len(trial_drifts), dtype = float)
    res = linregress(x_reg, trial_drifts)
    axis.plot(x_reg, res.intercept + res.slope*x_reg, linewidth = 1.5, color = color, label = line_label)
    axis.plot(x_reg, trial_drifts, 'o', color = color, markersize = 3, label = points_label)
    return True


def plot_drift_analysis(drifts_subject_corr, drifts_subject_trigger, subject, base_dir):
    """
    Plots the results of analysing the drifts within an EEG measurement

    Creates a 2x2 figure: the top row shows mean and variance band over all trials, the bottom
    row shows every trial with its linear regression; left column is the correlation (StimTrak)
    alignment, right column the trigger alignment. The figure is saved as
    <base_dir>/reports/figures/drift/<subject>.pdf (the directory is created if missing).

    Args:
        drifts_subject_corr (np.array):     drifts on ten second windows based on correlation analysis,
                                            one row per trial, NaN where a window does not exist
        drifts_subject_trigger (np.array):  drifts on ten second windows based on triggers,
                                            one row per trial, NaN where a window does not exist
        subject (int):                      subject identifier
        base_dir (string):                  directory where git repository is located
    """
    # NaN entries (windows a trial does not have) are excluded from the statistics
    finite_corr = np.isfinite(drifts_subject_corr)
    finite_trigger = np.isfinite(drifts_subject_trigger)
    drifts_mean_corr = np.mean(drifts_subject_corr, axis=0, where=finite_corr)
    drifts_variance_corr = np.var(drifts_subject_corr, axis=0, where=finite_corr)
    drifts_mean_trigger = np.mean(drifts_subject_trigger, axis=0, where=finite_trigger)
    drifts_variance_trigger = np.var(drifts_subject_trigger, axis=0, where=finite_trigger)

    x = np.arange(len(drifts_mean_corr), dtype = float)
    fig, ax = plt.subplots(2,2,figsize = (14,8))
    fig.tight_layout(pad=3.0)

    # Top row: mean over all trials.
    # NOTE(review): the band is mean +/- variance although the axis is labelled in ms -
    # confirm this is intended rather than the standard deviation.
    summary_panels = (
        (ax[0,0], drifts_mean_corr, drifts_variance_corr, 'StimTrak'),
        (ax[0,1], drifts_mean_trigger, drifts_variance_trigger, 'Trigger'),
    )
    for axis, drifts_mean, drifts_variance, source in summary_panels:
        axis.plot(x, drifts_mean, label = 'mean', linewidth = 3)
        axis.fill_between(x, drifts_mean - drifts_variance, drifts_mean + drifts_variance, alpha = .2, label = 'variance')
        axis.set_ylabel('ms')
        axis.set_xticks(x)
        axis.grid()
        axis.legend()
        axis.set_title(f'Drift Subject {str(subject)} mean over all trials ({source})')

    # Bottom row: one colour per trial (row of the drift arrays).
    colors = cm.rainbow(np.linspace(0, 1, len(drifts_subject_corr)))

    stimtrak_labelled = False
    for trial, (trial_corr, trial_trigger, c) in enumerate(zip(drifts_subject_corr, drifts_subject_trigger, colors), start = 1):
        #StimTrak
        # only the first drawn trial carries the generic legend entries
        if stimtrak_labelled:
            _plot_trial_regression(ax[1,0], trial_corr, c)
        else:
            stimtrak_labelled = _plot_trial_regression(ax[1,0], trial_corr, c, line_label = 'linear regression', points_label = 'data points')
        #Trigger
        # the regression x-axis is derived from the trigger values themselves
        _plot_trial_regression(ax[1,1], trial_trigger, c, line_label = str(trial))

    ax[1,0].set_ylabel('ms')
    ax[1,0].set_xticks(x)
    ax[1,0].set_xlabel('ten second window')
    ax[1,0].grid()
    ax[1,0].set_title(f'Drift Subject {str(subject)} individual trials (StimTrak)')
    ax[1,0].legend()

    ax[1,1].set_ylabel('ms')
    ax[1,1].set_xticks(x)
    ax[1,1].set_xlabel('ten second window')
    ax[1,1].grid()
    ax[1,1].set_title(f'Drift Subject {str(subject)} individual trials (Trigger)')
    ax[1,1].legend(loc='upper center', bbox_to_anchor=(0.0, -0.13),
          fancybox=True, shadow=True, ncol=10, title = "Trial")

    figure_path = join(base_dir, "reports", "figures", "drift",str(subject) + ".pdf")
    # savefig does not create missing directories
    os.makedirs(dirname(figure_path), exist_ok = True)
    fig.savefig(figure_path, bbox_inches='tight')

In [14]:
# Subjects to analyse (identifiers 110 and 111).
subjects = [*range(110, 112)]

# One flag per subject, in the same order as `subjects`:
# 1 -> the standard setting (Aux1 - left, Aux2 - right) holds,
# 0 -> the standard setting is violated and the channels have to be swapped.
# To mark a subject, overwrite its entry, e.g. aux_channels_correct[k] = 0.
# Note: the montage file, which differs from the standard brainproducts file,
# is set inside get_experiment_data.
aux_channels_correct = [1] * len(subjects)

In [None]:
# Run the drift pipeline once per subject: load the recording, compute the
# per-window drifts for both alignments, then plot and save the figure.
for subject, aux in zip(subjects, aux_channels_correct):
    # aux is this subject's entry of aux_channels_correct (truthy = Aux1 left / Aux2 right)
    corr_indices, trigger_indices, stimuli_list, stim_eeg_l, stim_eeg_r = get_experiment_data(subject, aux)
    drifts_subject_corr, drifts_subject_trigger = calc_drift_analysis(corr_indices,trigger_indices, stimuli_list, stim_eeg_l, stim_eeg_r)
    # writes <base_dir>/reports/figures/drift/<subject>.pdf
    plot_drift_analysis(drifts_subject_corr, drifts_subject_trigger, subject, base_dir)