# Imports and setup

In [98]:
# Basic
import numpy as np
import scipy
import scipy.stats
import os
import itertools
import warnings
import sys
from copy import deepcopy

# Data Loading
import cmlreaders as cml #Penn Computational Memory Lab's library of data loading functions

# Data Handling
import os
from os import listdir as ld
import os.path as op
from os.path import join, exists as ex
import time
import datetime

# Data Analysis
import pandas as pd
import xarray as xr

# EEG & Signal Processing
import ptsa
from ptsa.data.readers import BaseEventReader, EEGReader, CMLEventReader, TalReader
from ptsa.data.filters import MonopolarToBipolarMapper, MorletWaveletFilter
from ptsa.data.timeseries import TimeSeries

# Data Visualization
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns

# Parallelization
import cmldask.CMLDask as da
from cmldask.CMLDask import new_dask_client_slurm as cl
from cluster import wait, get_exceptions_quiet as get_ex
import cmldask

# Custom
from cstat import * #circular statistics
from misc import * #helper functions for loading and saving data, and for other purposes
from matrix_operations import * #matrix operations

from helper import *

%load_ext autoreload

warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=FutureWarning)

beh_to_event_windows = {'en': [250-1000, 1250+1000],
                     'en_all': [250-1000, 1250+1000],
                     'rm': [-1000, 0],
                     'ri': [-1000, 0]}

beh_to_epochs = {'en': np.arange(250, 1250, 200),
              'en_all': np.arange(250, 1250, 200),
              'rm': np.arange(-1000, 0, 200),
              'ri': np.arange(-1000, 0, 200)}


from helper import root_dir, USERNAME as user
if not os.path.exists(root_dir):
    os.mkdir(root_dir)

from functools import partial
cluster_log_dir = 'cluster'
cl = partial(cl, log_directory=cluster_log_dir)
if not os.path.exists(cluster_log_dir):
    os.mkdir(cluster_log_dir)

font_dirs = ['fonts']


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Figure 1 (task schematic)

## Encoding

In [72]:
reader = cml.CMLReader(*dfrow)
events = reader.load('events').query('type == "WORD"').iloc[0:1]
pairs = reader.load('pairs')
eeg = reader.load_eeg(events, -2000, 3000, scheme=pairs)
sfreq = eeg.samplerate
eeg = eeg.data
i, j = 0, 10
signal1, signal2 = eeg[0, i, :].astype(float), eeg[0, j, :].astype(float)

NameError: name 'dfrow' is not defined

In [None]:
freq = 3
import mne
signal1_filt = mne.filter.filter_data(signal1, sfreq=sfreq, l_freq=freq-1, h_freq=freq+1, verbose=False)
signal2_filt = mne.filter.filter_data(signal2, sfreq=sfreq, l_freq=freq-1, h_freq=freq+1, verbose=False)

In [None]:
signal_lw = 6
vertical_lw = 7
fig, ax = plt.subplots(1, 1, figsize=(20, 5))
ax.plot(np.arange(len(signal1_filt)), signal1_filt, lw=signal_lw, color='blue')
ax.plot(np.arange(len(signal2_filt)), signal2_filt, lw=signal_lw, color='orange')
ymin, ymax = ax.get_ylim()
plt.axis('off')
line_color = 'k'
# for onset, line_color in zip([50, 2000], ['red', 'green']):
for onset in [50, 50+(1600+1000)*(sfreq/1000)]:
    highlight_begin = onset+250*(sfreq/1000)
    highlight_end = onset+1250*(sfreq/1000)
    offset = onset+1600*(sfreq/1000)
    ax.vlines(onset, ymin=ymin, ymax=ymax, linestyles='solid', color=line_color, lw=vertical_lw)
    ax.vlines(onset+10, ymin=ymin, ymax=ymax, linestyles=(0, (5, 1)), color=line_color, lw=vertical_lw)
    ax.axvspan(highlight_begin, highlight_end, color='yellow', alpha = 0.5)
    ax.vlines(offset, ymin=ymin, ymax=ymax, linestyles='solid', color=line_color, lw=vertical_lw)
    ax.vlines(offset-10, ymin=ymin, ymax=ymax, linestyles=(0, (5, 1)), color=line_color, lw=vertical_lw)

## Retrieval

In [None]:
reader = cml.CMLReader(*dfrow)
events = reader.load('events').query('type == "REC_WORD"').iloc[0:1]
pairs = reader.load('pairs')
eeg = reader.load_eeg(events, -2000, 3000, scheme=pairs)
sfreq = eeg.samplerate
eeg = eeg.data
i, j = 0, 10
signal1, signal2 = eeg[0, i, :].astype(float), eeg[0, j, :].astype(float)

In [None]:
freq = 3
import mne
signal1_filt = mne.filter.filter_data(signal1, sfreq=sfreq, l_freq=freq-1, h_freq=freq+1, verbose=False)
signal2_filt = mne.filter.filter_data(signal2, sfreq=sfreq, l_freq=freq-1, h_freq=freq+1, verbose=False)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(20, 5))
ax.plot(np.arange(len(signal1_filt)), signal1_filt, linewidth=5, color='blue')
ax.plot(np.arange(len(signal2_filt)), signal2_filt, linewidth=5, color='orange')
ymin, ymax = ax.get_ylim()
plt.axis('off')
first_recall_time = 100+1000*(sfreq/1000)
silence_onset = first_recall_time+1000*(sfreq/1000)
silence_offset = silence_onset+1000*(sfreq/1000)
second_recall_time = silence_offset+1500*(sfreq/1000)
for line_time, line_color, linestyle in zip([first_recall_time, silence_onset, silence_offset, second_recall_time], ['red', 'gray', 'gray', 'green'], ['solid', 'dashed', 'dashed', 'solid']):
    ax.vlines(line_time, ymin=ymin, ymax=ymax, linestyles=linestyle, color=line_color, lw=lw)
ax.axvspan(first_recall_time-1000*(sfreq/1000), first_recall_time, color='yellow', alpha = 0.5)
ax.axvspan(silence_onset, silence_offset, color='yellow', alpha = 0.5)
ax.axvspan(second_recall_time-1000*(sfreq/1000), second_recall_time, color='yellow', alpha = 0.5)
# ax.vlines(offset, ymin=ymin, ymax=ymax, linestyles='solid', color=line_color, lw=lw)

# Get subject sex

In [394]:
sess_list_df = pd.read_json(join(root_dir, 'sess_list_df.json')).query('include == True').set_index(['sub', 'exp', 'sess', 'loc', 'mon'], drop=False)
print(f'Length of sess_list_df: {len(sess_list_df)}')
sess_list_df.head()

Length of sess_list_df: 980


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,sub,exp,sess,loc,mon,atlas,contacts_source,eeg,eeg_data_source,eeg_error,...,no_matches_rm,mean_succ_times_rm,mean_unsucc_times_rm,no_matches_ri,mean_succ_times_ri,mean_unsucc_times_ri,recall_rate,en_match_rate,rm_match_rate,ri_match_rate
sub,exp,sess,loc,mon,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
R1001P,FR1,0,0,0,R1001P,FR1,0,0,0,avg,contacts,True,cmlreaders,,...,19.0,13339.947368,12390.0,17.0,13803.588235,14292.176471,0.206667,0.176667,0.306452,0.274194
R1001P,FR1,1,0,0,R1001P,FR1,1,0,0,avg,contacts,True,cmlreaders,,...,15.0,9179.2,9117.4,13.0,9279.384615,10227.769231,0.21,0.206667,0.238095,0.206349
R1002P,FR1,0,0,0,R1002P,FR1,0,0,0,avg,contacts,True,cmlreaders,,...,46.0,13129.23913,12938.630435,,,,0.39,0.356667,0.393162,
R1002P,FR1,1,0,0,R1002P,FR1,1,0,0,avg,contacts,True,cmlreaders,,...,23.0,11955.956522,11924.826087,,,,0.4,0.36,0.191667,
R1003P,FR1,0,0,0,R1003P,FR1,0,0,0,avg,contacts,True,cmlreaders,,...,23.0,13344.086957,12047.043478,23.0,13344.086957,13791.652174,0.325758,0.30303,0.267442,0.267442


In [395]:
sublist = sess_list_df['sub'].unique() # Get list of unique subject codes
print(f'There are {len(sublist)} analyzed subjects.')
print(sublist) # Print the list of unique subject codes

There are 378 analyzed subjects.
['R1001P' 'R1002P' 'R1003P' 'R1006P' 'R1010J' 'R1013E' 'R1016M' 'R1018P'
 'R1020J' 'R1021D' 'R1022J' 'R1024E' 'R1026D' 'R1028M' 'R1030J' 'R1031M'
 'R1032D' 'R1033D' 'R1035M' 'R1036M' 'R1039M' 'R1042M' 'R1045E' 'R1048E'
 'R1050M' 'R1051J' 'R1052E' 'R1053M' 'R1054J' 'R1056M' 'R1060M' 'R1061T'
 'R1062J' 'R1063C' 'R1065J' 'R1066P' 'R1067P' 'R1068J' 'R1069M' 'R1070T'
 'R1074M' 'R1075J' 'R1076D' 'R1077T' 'R1080E' 'R1083J' 'R1084T' 'R1089P'
 'R1092J' 'R1093J' 'R1094T' 'R1096E' 'R1098D' 'R1101T' 'R1102P' 'R1104D'
 'R1105E' 'R1106M' 'R1107J' 'R1108J' 'R1111M' 'R1112M' 'R1113T' 'R1114C'
 'R1115T' 'R1118N' 'R1120E' 'R1121M' 'R1122E' 'R1123C' 'R1124J' 'R1125T'
 'R1127P' 'R1128E' 'R1129D' 'R1130M' 'R1131M' 'R1134T' 'R1135E' 'R1136N'
 'R1137E' 'R1138T' 'R1141T' 'R1144E' 'R1145J' 'R1146E' 'R1147P' 'R1148P'
 'R1149N' 'R1150J' 'R1151E' 'R1153T' 'R1154D' 'R1156D' 'R1157C' 'R1161E'
 'R1162N' 'R1163T' 'R1164E' 'R1166D' 'R1167M' 'R1168T' 'R1169P' 'R1170J'
 'R1172E' 'R1173J'

In [396]:
def get_sex(sub):
    
    for fname in [f'/data/eeg/{sub}/docs/readme.txt',
                  f'/data/eeg/{sub}/docs/patient_info.txt']:
        if not ex(fname): continue
        with open(fname, 'r') as file:
            content = file.read().lower()
            
        if ( ('male' in content) | ('gender: m' in content) ):
            return 'M'
        elif ( ('female' in content) | ('gender: f' in content) ):
            return 'F'

    return 'nan'

pyFR_sublist = sess_list_df.query('exp=="pyFR"')['sub'].unique()
pyFR_subjects_sex = pd.DataFrame({'sub': pyFR_sublist})
pyFR_subjects_sex['sex'] = pyFR_subjects_sex.apply(lambda r: get_sex(r['sub']), axis=1)

In [397]:
ram_subjects_sex_path = '/home1/amrao/ConnectivityProject/ram_subjects_sex.csv'
ram_subjects_sex = pd.read_csv(ram_subjects_sex_path)

study_site_codes = {'University of Pennsylvania': 'P',
                    'Dartmouth University': 'D',
                    'Jefferson Hospital': 'J',
                    'University of Washington': '',
                    'Emory University': 'E',
                    'Mayo Clinic': 'M',
                    'UT Southwestern': 'T',
                    'Columbia University': 'C',
                    'NINDS': 'N',
                    'UTHSC San Antonio': 'S',
                    'CU Anschutz': 'A',
                    'Harvard': 'H'}
ram_subjects_sex['sub'] = ram_subjects_sex.apply(lambda r: 'R1' + str(r['Subject Number']).zfill(3) + study_site_codes[r['Study site']], axis=1)
ram_subjects_sex = ram_subjects_sex.query('sub in @sublist')
def get_ram_sex(r):
    
    if r['Gender'] == 'Female': return 'F'
    elif r['Gender'] == 'Male': return 'M'
    else: return 'nan'
ram_subjects_sex['sex'] = ram_subjects_sex.apply(lambda r: get_ram_sex(r), axis=1)
ram_subjects_sex = ram_subjects_sex[['sub', 'sex']]

In [398]:
subjects_sex = pd.concat([ram_subjects_sex, pyFR_subjects_sex])
no_males = np.sum(subjects_sex['sex'] == 'M')
no_females = np.sum(subjects_sex['sex'] == 'F')
no_unavailable = np.sum(subjects_sex['sex'] == 'nan')
print(f'There are\n {no_males} male subjects,\n {no_females} female subjects, and\n {no_unavailable} subject(s) for whom sex information is unavailable.')

There are
 212 male subjects,
 165 female subjects, and
 1 subject(s) for whom sex information is unavailable.
