## Train linear model
This notebook fits linear encoding models to phoneme level data using different stimulus matrices, depending on the analysis in question. Here is a useful shorthand for the different model constructions:
* Model 1: 14 phonological features + 4 task features + normalized EMG
* Model 2: 42 phonological features (14 perception only, 14 production only, 14 combined) + 4 task features + normalized EMG
* Model 3: 14 phonological features + 2 task features (predictability contrast omitted) + normalized EMG
* Model 4: 14 phonological features + 2 task features (perception/production contrast omitted) + normalized EMG

Any model ending in `e` is the same as the above, but _without_ an EMG regressor.

**Warning: this notebook is very computationally intensive.** You may wish to run it on your lab server instead of a local machine. Another alternative is to split the subjects up into smaller subsets, then run those subsets instead of holding all the subjects' models in memory at the same time.

In [None]:
# Imports
import mne
import re
import numpy as np
from glob import glob
import os
import csv
import pandas as pd
from matplotlib import cm,rcParams
from matplotlib import pyplot as plt
import pickle
import h5py
from tqdm.notebook import tqdm
import sys
sys.path.append('./utils/')
import textgrid
import strf
import warnings

In [None]:
# Local paths
# NOTE: every path below is joined to file names by plain f-string concatenation in
# later cells (e.g. f'{eeg_data_path}{s}/...'), so each one must end with a trailing slash.
eeg_data_path = '/path/to/dataset/' # downloadable from OSF: https://doi.org/10.17605/OSF.IO/FNRD9
# Root of the git repository; the eventfiles/, textgrids/ and stats/ folders are read from here
git_path  = '/path/to/git/speaker_induced_suppression_EEG/'
h5_path = '/path/to/h5/' # Where the model inputs and results will be saved

### Initialize model parameters

In [None]:
# Subjects left out of the analysis
exclude = ['OP0001','OP0002','OP0020'] # OP1/2 don't have Aux EMG; OP20 had recording error/is excluded overall
# Subject IDs are the last six characters of each entry in the eventfiles folder.
# The 'OP0' filter is applied to the entry name only, so an 'OP0' substring somewhere
# in git_path cannot let unrelated files through.
subj_dirs = [os.path.basename(p) for p in glob(f'{git_path}eventfiles/*')]
subjs = np.sort([d[-6:] for d in subj_dirs if 'OP0' in d and d[-6:] not in exclude])
condition, level = 'all', 'phoneme'
model_number = 'model1' # see the model shorthand at the top of the notebook
# Time window (in seconds) of the delays used by the encoding model
tmin,tmax = -0.3,0.5
# Factor converting the delay window from seconds to samples.
# NOTE(review): presumably the EEG sampling rate in Hz -- confirm it matches
# raws[s].info['sfreq'] and the sampling rate assumed inside strf.strf.
delay_sfreq = 128
delays = np.arange(np.floor(tmin*delay_sfreq),np.ceil(tmax*delay_sfreq),dtype=int)
ndelays = len(delays)
nboots = 10 # number of bootstraps passed to strf.strf
alphas=np.hstack((0, np.logspace(-4,4,20))) # candidate regularization parameters
# Phonological feature -> phoneme labels (lowercase, stress digits stripped) carrying that feature
features_dict = {
                'dorsal': ['y','w','k','kcl', 'g','gcl','eng','ng'],
                'coronal': ['ch','jh','sh','zh','s','z','t','tcl','d','dcl','n','th','dh','l','r'],
                'labial': ['f','v','p','pcl','b','bcl','m','em','w'],
                'high': ['uh','ux','uw','iy','ih','ix','ey','eh','oy'],
                'front': ['iy','ih','ix','ey','eh','ae','ay'],
                'low': ['aa','ao','ah','ax','ae','aw','ay','axr','ow','oy'],
                'back': ['aa','ao','ow','ah','ax','ax-h','uh','ux','uw','axr','aw'],
                'plosive': ['p','pcl','t','tcl','k','kcl','b','bcl','d','dcl','g','gcl','q'],
                'fricative': ['f','v','th','dh','s','sh','z','zh','hh','hv','ch','jh'],
                'syllabic': ['aa', 'ae', 'ah', 'ao', 'aw', 'ax', 'ax-h', 'axr', 'ay','eh','ey','ih', 'ix',
                             'iy','ow', 'oy','uh', 'uw', 'ux'],
                'nasal': ['m','em','n','en','ng','eng','nx'],
                'voiced': ['aa', 'ae', 'ah', 'ao', 'aw', 'ax', 'ax-h', 'axr', 'ay','eh','ey','ih', 'ix',
                             'iy','ow', 'oy','uh', 'uw', 'ux','w','y','el','l','r','dh','z','v','b','bcl','d',
                             'dcl','g','gcl','m','em','n','en','eng','ng','nx','q','jh','zh'],
                # (a duplicated 'q' entry was dropped here; membership tests are unaffected)
                'obstruent': ['b', 'bcl', 'ch', 'd', 'dcl', 'dh', 'dx','f', 'g', 'gcl', 'hh', 'hv','jh', 'k',
                              'kcl', 'p', 'pcl', 'q', 's', 'sh','t', 'tcl', 'th','v','z', 'zh'],
                'sonorant': ['aa', 'ae', 'ah', 'ao', 'aw', 'ax', 'ax-h', 'axr', 'ay','eh','ey','ih', 'ix',
                             'iy','ow', 'oy','uh', 'uw', 'ux','w','y','el','l','r','m',
                             'n', 'ng', 'eng', 'nx','en','em'],
        }
# Feature order here fixes the row order of the feature stimulus matrices built later
features = list(features_dict)
print(features)
print(len(features))

### Load raw data
(and aux EMG, assuming we want it in our specified model)

In [None]:
# Per-subject containers: band-pass filtered EEG recordings and scaled EMG traces
raws, emgs = {}, {}
for s in tqdm(subjs):
    blockid = f'{s}_B1'
    block_dir = f'{eeg_data_path}{s}/{blockid}/'
    # EEG: load into memory, then band-pass between 1 and 30 Hz (filter works in place)
    raw = mne.io.read_raw_brainvision(f'{block_dir}{blockid}_cca.vhdr', preload=True, verbose=False)
    raws[s] = raw
    raw.filter(l_freq=1, h_freq=30)
    # Models whose name ends in 'e' have no EMG regressor, so nothing more to load
    if model_number[-1] == 'e':
        continue
    ica_raw = mne.io.read_raw_fif(f'{block_dir}{blockid}_ica.fif', preload=True)
    # The EMG trace is read from the channel named 'hEOG'
    if 'hEOG' not in ica_raw.info['ch_names']:
        raise Exception('EMG not found')
    ica_raw.notch_filter(60, picks=['hEOG'], verbose=False)
    ica_raw.filter(l_freq=1, h_freq=None, picks=['hEOG'], verbose=False)
    emg_trace = ica_raw.get_data(picks=['hEOG'])
    # Scale EMG by its largest absolute value
    emgs[s] = emg_trace / np.abs(emg_trace).max()

### Load events from eventfiles
Although only phoneme-level linear models are fit in the paper, this notebook can load sentence- and word-level stimuli as well, as it's possible to fit those.

Eventfiles are named according to channel, level, and condition:
* **Channel**: the mode of speech used during a specific trial.
    * `mic` eventfiles contain events from the production component of the task.
    * `spkr` eventfiles contain events from the perception component of the task.
* **Level**: the level of linguistic representation encoded in the eventfile
    * `ph`: individual phoneme events
    * `wr`: word-level events
    * `sn`: sentence-level events
* **Condition**: the type of playback utilized in the perception trials
    * `el` eventfiles contain the consistent playback trials or their preceding production trials (depends on specified channel).
    * `sh` eventfiles contain the inconsistent playback trials or their preceding production trials (depends on specified channel).
    * `all` eventfiles contain both the consistent and inconsistent playback trials, or their preceding production trials (depending on specified channel).

In [None]:
condition, level = 'all', 'ph'
# Subjects recorded in two blocks; their block-2 events are appended to block 1
multi_block_subjs = ['OP0015', 'OP0016']

# NOTE(review): `gk` (get_task_times) is not imported in the visible imports cell --
# confirm which module in ./utils/ it aliases and import it before running this cell.


def _load_events(event_fpath, level, fs, samp_offset=0):
    """Read a tab-delimited eventfile and convert its times to sample indices.

    For phoneme-level files (level == 'ph') column 3 is dropped and every other
    column is kept; for other levels only the first three columns are kept.
    The first two columns (onset, offset) are multiplied by `fs`, rounded, and
    shifted by `samp_offset`. Returns a 2-D float array, one row per event.
    """
    rows = []
    with open(event_fpath, 'r') as f:
        for row in csv.reader(f, delimiter='\t'):
            rows.append(row[:3] + row[4:] if level == 'ph' else row[:3])
    events = np.array(rows, dtype=float)
    events[:, :2] = np.round(events[:, :2] * fs) + samp_offset
    return events


def _load_task_times(tg_fpath, fs, samp_offset=0):
    """Return the task windows of one textgrid, in samples, shifted by `samp_offset`."""
    with open(tg_fpath) as r:
        tg = textgrid.TextGrid(r.read())
    return np.array([(t * fs) + samp_offset for t in gk.get_task_times(tg)], dtype=int)


spkr_events, mic_events = dict(), dict()
pbar = tqdm(subjs)
for s in pbar:
    pbar.set_description(f"Loading events for {s}")
    blockid = f"{s}_B1"
    fs = raws[s].info['sfreq']
    is_multi_block = s in multi_block_subjs
    if is_multi_block:
        b2_blockid = f"{s}_B2"
        # Block-2 times are shifted by the last sample index of the block-1 recording
        last_samp = mne.io.read_raw_brainvision(
            f"{eeg_data_path}{s}/{blockid}/{blockid}_downsampled.vhdr",preload=False,verbose=False
        ).last_samp
    # Spkr events
    event_file = _load_events(
        f"{git_path}eventfiles/{s}/{blockid}/{blockid}_spkr_{level}_{condition}.txt", level, fs
    )
    if is_multi_block:
        b2_events = _load_events(
            f"{git_path}eventfiles/{s}/{b2_blockid}/{b2_blockid}_spkr_{level}_{condition}.txt",
            level, fs, samp_offset=last_samp
        )
        event_file = np.vstack((event_file, b2_events))
    spkr_events[s] = event_file.astype(int)
    # Mic events (need non-task times removed)
    event_file = _load_events(
        f"{git_path}eventfiles/{s}/{blockid}/{blockid}_mic_{level}_{condition}.txt", level, fs
    )
    task_times = _load_task_times(f"{git_path}textgrids/{s}/{blockid}/{blockid}_mic.textgrid", fs)
    if is_multi_block:
        b2_events = _load_events(
            f"{git_path}eventfiles/{s}/{b2_blockid}/{b2_blockid}_mic_{level}_{condition}.txt",
            level, fs, samp_offset=last_samp
        )
        event_file = np.vstack((event_file, b2_events))
        # Task windows of the second block come from the second block's own textgrid
        b2_task_times = _load_task_times(
            f"{git_path}textgrids/{s}/{b2_blockid}/{b2_blockid}_mic.textgrid", fs, samp_offset=last_samp
        )
        task_times = np.vstack((task_times, b2_task_times))
    # Keep only events whose onset AND offset fall inside a task window [start, stop)
    onsets, offsets = event_file[:, 0], event_file[:, 1]
    trial_events = []
    for t in task_times:
        start, stop = t[0], t[1]
        in_window = (onsets >= start) & (onsets < stop) & (offsets >= start) & (offsets < stop)
        trial_events.append(event_file[in_window])
    if trial_events:
        mic_events[s] = np.vstack(trial_events).astype(int)
    else:
        mic_events[s] = np.empty((0, event_file.shape[1]), dtype=int)

### Create a stimulus matrix and a response matrix to pass into linear encoding model
* Stimulus matrix has a shape of `n_eeg_samples` x `n_features`
* Response matrix has a shape of `n_eeg_channels` x `n_eeg_samples`

In [None]:
# NOTE(review): `gk` (get_el_sh_timing) is not imported in the visible imports cell --
# confirm which module in ./utils/ it aliases and import it before running this cell.


def _make_phn_feat_stims(events, phonemes, features, features_dict, nsamps):
    """Build phoneme-identity and phonological-feature matrices for one event list.

    Each event row is read as (onset sample, offset sample, phoneme index, ...).
    Returns (phn_stim, feat_stim) with shapes (n_phonemes, nsamps) and
    (n_features, nsamps); each has a 1 at the onset sample of every event, in the
    row of the phoneme / of every feature whose list contains that phoneme.
    """
    phn_stim = np.zeros((len(phonemes), nsamps))
    feat_stim = np.zeros((len(features), nsamps))
    for ev in events:
        onset, phn_label = ev[0], ev[2]
        phn_stim[phn_label, onset] = 1
        # Feature lists use lowercase labels without digits
        phn_stripped = re.sub(r'\d+', '', phonemes[phn_label].lower())
        for fi, f in enumerate(features):
            if phn_stripped in features_dict[f]:
                feat_stim[fi, onset] = 1
    return phn_stim, feat_stim


stims,resps = dict(),dict()
phonemes = np.loadtxt('./phonemes.txt',dtype=str)
nphones = len(phonemes)
nfeats = len(features)
pbar = tqdm(subjs)
for s in pbar:
    pbar.set_description(f"Creating stim/resp matrices for {s}")
    blockid = f'{s}_B1'
    # Make the resp (EEG channels only), shaped channels x samples
    picks = mne.pick_types(raws[s].info, meg=False,eeg=True,stim=False)
    resps[s] = raws[s].get_data(picks=picks)
    nsamps = resps[s].shape[1]
    # Make spkr stim
    phn_stim_spkr, feat_stim_spkr = _make_phn_feat_stims(
        spkr_events[s], phonemes, features, features_dict, nsamps
    )
    # Two rows marking spkr phoneme onsets that fall inside an `el` (row 0) or `sh` (row 1) window
    el_sh_stim = np.zeros((2, nsamps))
    el_times, sh_times = gk.get_el_sh_timing(git_path,s,'B1',mode='eeg')
    for ev in spkr_events[s]:
        onset = ev[0]
        if any(el_time[0] <= onset <= el_time[1] for el_time in el_times):
            el_sh_stim[0,onset] = 1
        if any(sh_time[0] <= onset <= sh_time[1] for sh_time in sh_times):
            el_sh_stim[1,onset] = 1
    # Make mic stim
    phn_stim_mic, feat_stim_mic = _make_phn_feat_stims(
        mic_events[s], phonemes, features, features_dict, nsamps
    )
    # Concatenate stimulus features: combined, spkr-only and mic-only feature rows,
    # then one onset row per channel (spkr, mic), then the el/sh rows
    stims[s] = np.vstack((
        (feat_stim_spkr + feat_stim_mic), feat_stim_spkr, feat_stim_mic,
        np.atleast_2d(phn_stim_spkr.sum(0)), np.atleast_2d(phn_stim_mic.sum(0)),
        el_sh_stim
    )).T
    # Append the EMG regressor as the last column. Only the two expected failures are
    # downgraded to a warning (no EMG loaded for this subject, or EMG length differs
    # from the EEG); anything else is a real bug and is allowed to propagate.
    try:
        stims[s] = np.vstack((stims[s].T,emgs[s])).T
    except (KeyError, ValueError) as err:
        warnings.warn(f'EMG for {s} not applied to model! Again, something is wrong... ({err!r})')

### Split stimulus and response matrices into training and validation set
80% of the data are used for training, while the remaining 20% of data are held out for validating model performance. To mitigate any potential overfitting, data are split according to sentence boundary. Because there are 50 unique sentences in the task, that means 40 sentences are used in training while 10 are held out for validation.

This cell can take a while to run depending on your hardware. The outputs are saved to an `.hdf5` file so that repeated iterations of model fitting do not require one to re-make the split data. This is convenient in terms of raw computational resources and also security in case of a kernel crash due to RAM overflow.

In [None]:
random_seed = 6655321
nsentences = 50
tv_split = 40 # 40 sentences IDs to train, remaining 10 to validate
# Subjects recorded in two blocks; their block-2 events are appended to block 1
multi_block_subjs = ['OP0015', 'OP0016']


def _read_sn_events(event_fpath, fs, samp_offset=0):
    """Read a sentence-level eventfile as a list of (onset, offset, sentence_id).

    Onset/offset (columns 0 and 1) are multiplied by `fs`, truncated to int and
    shifted by `samp_offset`; the sentence ID is column 2.
    """
    events = []
    with open(event_fpath, 'r') as f:
        for row in csv.reader(f, delimiter='\t'):
            onset = int(float(row[0]) * fs) + samp_offset
            offset = int(float(row[1]) * fs) + samp_offset
            events.append((onset, offset, int(row[2])))
    return events


tStims, vStims, tResps, vResps = dict(), dict(), dict(), dict()
pbar = tqdm(subjs)
for s in pbar:
    blockid = f"{s}_B1"
    # Update this path if you're saving/loading h5 files locally
    h5_fpath = f'{h5_path}{s}_model_inputs.hdf5'
    if os.path.isfile(h5_fpath):
        pbar.set_description(f"Stim/resp for {s} already split, skipping this subject...")
        continue
    pbar.set_description(f"Splitting stim/resp into training/validation sets for {s}")
    # Use this subject's own sampling rate rather than whatever value an earlier cell left behind
    fs = raws[s].info['sfreq']
    # (block ID, sample offset) pairs; block 2 is shifted by the last sample index of block 1
    blocks = [(blockid, 0)]
    if s in multi_block_subjs:
        last_samp = mne.io.read_raw_brainvision(
            f"{eeg_data_path}{s}/{blockid}/{blockid}_downsampled.vhdr",preload=False,verbose=False
        ).last_samp
        b2_blockid = f"{s}_B2"
        blocks.append((b2_blockid, last_samp))
    # Read sentence events: all spkr blocks first, then all mic blocks
    sn_events = []
    for channel in ('spkr', 'mic'):
        for block, samp_offset in blocks:
            sn_events.extend(_read_sn_events(
                f"{git_path}eventfiles/{s}/{block}/{block}_{channel}_sn_all.txt", fs, samp_offset
            ))
    # Split stim/resp sentence-by-sentence. Each event contributes the samples
    # onset..offset (inclusive), clipped to the recording.
    nchans, nsamps = resps[s].shape
    resp_chunks = {sn_id: [] for sn_id in range(nsentences)}
    stim_chunks = {sn_id: [] for sn_id in range(nsentences)}
    for onset, offset, sn_id in sn_events:
        if sn_id not in resp_chunks:
            continue
        start = max(onset, 0)
        stop = max(min(offset + 1, nsamps), start)
        resp_chunks[sn_id].append(resps[s][:, start:stop].T)
        stim_chunks[sn_id].append(stims[s][start:stop])
    # Sentences without any samples get correctly shaped empty arrays so stacking still works
    resp_dict = {
        sn_id: np.vstack(chunks) if chunks else np.empty((0, nchans))
        for sn_id, chunks in resp_chunks.items()
    }
    stim_dict = {
        sn_id: np.vstack(chunks) if chunks else np.empty((0, stims[s].shape[1]))
        for sn_id, chunks in stim_chunks.items()
    }
    # Split stim/resp into training/validation sets along sentence boundaries.
    # One seeded permutation yields both sets, so the split is identical for every subject.
    np.random.seed(random_seed)
    sn_order = np.random.permutation(nsentences)
    train_sn_ids, val_sn_ids = sn_order[:tv_split], sn_order[tv_split:]
    tStims[s] = np.vstack([stim_dict[sn_id] for sn_id in train_sn_ids])
    vStims[s] = np.vstack([stim_dict[sn_id] for sn_id in val_sn_ids])
    tResps[s] = np.vstack([resp_dict[sn_id] for sn_id in train_sn_ids])
    vResps[s] = np.vstack([resp_dict[sn_id] for sn_id in val_sn_ids])
    pbar.set_description(
        f"{s}: training on {tStims[s].shape[0]} samps, validating on {vStims[s].shape[0]} samps. Raw contained {stims[s].shape[0]} samps total"
    )
    if tStims[s].shape[0] != tResps[s].shape[0]:
        raise Exception("Stim and resp do not have the same shape! (training)")
    if vStims[s].shape[0] != vResps[s].shape[0]:
        raise Exception("Stim and resp do not have the same shape! (validation)")
    # Save the split to hdf5 files
    with h5py.File(h5_fpath,'a') as f:
        f.create_dataset('tStim', data=tStims[s])
        f.create_dataset('tResp', data=tResps[s])
        f.create_dataset('vStim', data=vStims[s])
        f.create_dataset('vResp', data=vResps[s])

### Load training/validation stimulus/response matrices from `hdf5` file
The training/validation response matrices are saved to hdf5 and loaded here. Training/validation stimulus matrices are also loaded here, but an additional step is required while loading. To conserve disk space, model inputs for all models described in the first markdown cell of this notebook are saved to a single `hdf5` file. Assuming we do not want the maximum number of stimulus features, we need to clip the loaded matrix down to the appropriate specification of stimulus features. We use a function from `./utils/strf.py` to accomplish this.

In [None]:
# Fresh containers for the training (t) / validation (v) stimulus and response matrices
tStims, vStims, tResps, vResps = {}, {}, {}, {}
for s in tqdm(subjs):
    model_input_h5_fpath = f"{h5_path}{s}_model_inputs.hdf5"
    # load_model_inputs returns, in order: training stim, training resp,
    # validation stim, validation resp for the requested model
    model_inputs = strf.load_model_inputs(model_input_h5_fpath, model_number)
    tStims[s], tResps[s], vStims[s], vResps[s] = model_inputs

### Fit linear encoding model
This cell takes a very long time to run depending on the number of stimulus features in the specified model and your local hardware. Go get a coffee (or several)!

In [None]:
# Run the STRF
# Per-subject outputs: channel-wise correlations, model weights, best regularization parameter
corrs, wts, best_alphas = dict(),dict(),dict()
pbar = tqdm(subjs)
for s in pbar:
    pbar.set_description(f"Fitting STRF for {s} {model_number}")
    # The delay window is tmin..tmax, the same window `delays`/`ndelays` were derived from,
    # so the reshape of the weights below matches the fitted model.
    subj_corrs, subj_wts, _, _, _, _, subj_best_alphas = strf.strf( # don't save stim/resp
        tResps[s], tStims[s],
        vResp = vResps[s], vStim = vStims[s],
        nboots=nboots, delay_min=tmin, delay_max=tmax, alphas=alphas,
        flip_resp = True
    )
    print(f"{s}: Best alpha: {subj_best_alphas[0]} (should be in between {alphas[0]} and {alphas[-1]})")
    # Write model output to dicts
    corrs[s] = subj_corrs[0] # R-value between predicted and actual EEG at each channel
    wts[s] = subj_wts[0].reshape((ndelays,tStims[s].shape[1],tResps[s].shape[1])) # delays x feats x chans
    best_alphas[s] = subj_best_alphas # Regularization parameter that yielded best model fit

### Save results
Model weights are saved to an `.hdf5` file stored locally, while correlations, best alpha, and p-values (calculated in `git/stats/bootstrap_lme.ipynb`) are saved to a `.csv` file stored on GitHub (`git/stats/lme_results.csv`).

In [None]:
# Save weights to hdf5 file (one file per subject, one dataset per model)
force_overwrite = True
pbar = tqdm(subjs)
for s in pbar:
    pbar.set_description(f"Saving model results for {s} {model_number}")
    # Update this file location accordingly on your local machine!
    model_output_h5_fpath = f"{h5_path}{s}_weights.hdf5"
    # Append mode opens an existing file read/write and creates it when it is missing,
    # so a single code path covers both situations.
    with h5py.File(model_output_h5_fpath, 'a') as f:
        already_saved = model_number in f.keys()
        if not already_saved:
            f.create_dataset(model_number, data = wts[s])
        elif force_overwrite:
            # Overwrite the stored values in place
            f[model_number][:] = wts[s]
        else:
            print(
                f"{s} {model_number} weights already saved to hdf5. To force an overwrite, set 'force_overwrite' to True."
            )

In [None]:
# Save corrs, alphas to csv file
force_overwrite = True
results_csv_fpath = f"{git_path}stats/lme_results.csv"
df = pd.read_csv(results_csv_fpath)
# Rows not yet in the csv are collected here and added in one go after the loop
# (DataFrame.append no longer exists in pandas >= 2.0).
new_rows = []
pbar = tqdm(subjs)
for s in pbar:
    pbar.set_description(f"Saving alphas/corrs for {s} {model_number} to csv")
    blockid = f"{s}_B1"
    ch_names = mne.io.read_raw_brainvision(f"{eeg_data_path}{s}/{blockid}/{blockid}_cca.vhdr",
                                           preload=False,verbose=False).info['ch_names']
    # NOTE(review): corrs[s] is indexed by position in ch_names; this assumes the file's
    # channel list lines up with the channels the model was fit on -- confirm.
    for i,ch in enumerate(ch_names):
        row_mask = (df['subject']==s) & (df['model']==model_number) & (df['channel']==ch)
        if not row_mask.any():
            # Data does not exist in dataframe so let's add it
            # using nan as pval as a placeholder as we still need to bootstrap
            new_rows.append({
                'subject':s, 'model':model_number, 'channel':ch,
                'r_value':corrs[s][i], 'p_value':np.nan, "best_alpha":best_alphas[s][0]
            })
        elif force_overwrite:
            df.loc[row_mask, 'r_value'] = corrs[s][i]
            df.loc[row_mask, 'best_alpha'] = best_alphas[s][0]
            # Reset the p-value with a real NaN so the column keeps its numeric dtype
            df.loc[row_mask, 'p_value'] = np.nan
        else:
            print(
                f"data for {s} {model_number} {ch} already exists. To overwrite, set 'force_overwrite' to True"
            )
if new_rows:
    df = pd.concat([df, pd.DataFrame(new_rows)], ignore_index=True)
df.to_csv(results_csv_fpath,index=False)