# AMME Data Processing

This notebook loads, processes, and epochs the Emory AMME datasets...

---
> Martina Hollearn (martina.hollearn@psych.utah.edu)  
> 05/13/24

## 1. Import Libraries

In [1]:
import os
import mne
import csv
import numpy as np
import pandas as pd
from scipy.io import loadmat, savemat
from scipy.signal import filtfilt, firwin
import matplotlib.pyplot as plt


### Edit these paths for your PC:
Please add your directories as new comments

In [2]:
# For Alireza's PC (Please don't remove the comments)
# rdDir = r'C:\Users\alire\Box\InmanLab\AMME_Data_Emory\AMME_Data'
# wrDir = r'C:\Users\alire\Box\InmanLab\AMME_Data_Emory\AMME_Data'

# Active read / write roots, normalised for the local OS as they are assigned.
rdDir = os.path.normpath(r'C:\Users\alire\Box\InmanLab\AMME_Data_Emory\AMME_Data')  # inputs are read from here
wrDir = os.path.normpath(r'C:\Users\alire\Box\InmanLab\AMME_Data_Emory\AMME_Data')  # test outputs are written here

The Function is not working now

## 2. Load Data

Loading All the data

In [None]:
def loadParticipant(pName, fileName, rdDir):
    """Load one participant's sEEG recording and return a filtered copy.

    The original version referenced `raw` without ever reading the recording,
    so it only worked when a global `raw` happened to exist, and it returned
    nothing. This version reads the EDF file itself and returns the result.

    Parameters
    ----------
    pName : str
        Subject identifier; also the name of the subject's folder under `rdDir`.
    fileName : str
        Name of the EDF recording inside the subject's folder.
    rdDir : str
        Root directory that contains one folder per subject.

    Returns
    -------
    proc_data
        The MNE raw object after the 1-119 Hz band-pass and the 42 Hz and
        60 Hz notch filters. The unfiltered recording is not returned.

    Side effects
    ------------
    Creates `<rdDir>/<pName>/PreprocessedData/Martinas_preprocessing` if it
    does not exist yet.
    """
    subject = pName
    data_path = os.path.join(rdDir, subject, fileName)

    # Create Preprocessed data folder
    preproc_datapath = os.path.join(rdDir, subject, 'PreprocessedData', 'Martinas_preprocessing')
    os.makedirs(preproc_datapath, exist_ok=True)

    # Load the raw recording (this step was missing, which left `raw` undefined)
    raw = mne.io.read_raw_edf(data_path, preload=True)

    # Mark every channel as sEEG, as the notebook does for the interactive workflow;
    # the notch filters below select channels with picks='seeg'.
    raw.set_channel_types({ch: 'seeg' for ch in raw.info['ch_names']})

    # Define your parameters
    lowcut = 1  # Lower cutoff frequency (Hz)
    highcut = 119  # Upper cutoff frequency (Hz)
    transition_bandwidth = 1  # Transition bandwidth in Hz

    # Create a copy and band-pass filter it, leaving `raw` untouched
    proc_data = raw.copy()
    proc_data.filter(lowcut, highcut, fir_design='firwin', filter_length='auto', phase='zero', fir_window='hamming',
                     verbose=True, l_trans_bandwidth=transition_bandwidth, h_trans_bandwidth=transition_bandwidth)

    # filter out 42 Hz noise
    proc_data.notch_filter(freqs=[42], picks='seeg')  # Joe had this in his code, not sure why it's there

    # Notch out 60 Hz; its harmonics (120, 180 Hz) lie above the 119 Hz upper cutoff
    proc_data.notch_filter(freqs=[60], picks='seeg')

    return proc_data

In [3]:
# Subject to process, plus the file names derived from the subject ID
subject = 'amyg003'
file_path = rdDir
seeg_filename = f'{subject}_objectMemory_day2_05mA.edf'
event_filename = f'{subject}_LFP_day2_trialtimes.mat'
log_filename = f'{subject}_day2.log'

# All inputs sit in the subject's own folder under the read directory
subject_dir = os.path.join(file_path, subject)
data_path = os.path.join(subject_dir, seeg_filename)
events_path = os.path.join(subject_dir, event_filename)
logfile_path = os.path.join(subject_dir, log_filename)

# Create Preprocessed data folder (only when it is not there yet)
preproc_datapath = os.path.join(subject_dir, 'PreprocessedData', 'Martinas_preprocessing')
folder_missing = not os.path.exists(preproc_datapath)
if folder_missing:
    os.makedirs(preproc_datapath)

In [4]:
# Read the EDF recording fully into memory
raw = mne.io.read_raw_edf(data_path, preload=True)

# Recording info: channel labels and the sampling rate rounded to a whole number of Hz
ch_names = raw.info['ch_names']
fs = int(np.round(raw.info['sfreq']))

# Label every channel as SEEG
seeg_channel_types = dict.fromkeys(ch_names, 'seeg')
raw.set_channel_types(seeg_channel_types)

# Show the recording summary as the cell output.
# NOTE(review): the summary includes the participant field from the EDF header;
# clear this output before sharing the notebook.
raw.info

Extracting EDF parameters from D:\Martina Test\Code\amyg003\amyg003_objectMemory_day2_05mA.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 1637201  =      0.000 ...  1638.164 secs...


Unnamed: 0,General,General.1
,MNE object type,Info
,Measurement date,2015-09-10 at 13:13:13 UTC
,Participant,"Hishida,"
,Experimenter,Unknown
,Acquisition,Acquisition
,Sampling frequency,999.41 Hz
,Channels,Channels
,sEEG,129
,Head & sensor digitization,Not available
,Filters,Filters


In [None]:
# Test and compare notch filter
# Runs a moving-window notch filter at 60 Hz on an unfiltered copy of the recording and
# writes the result to NotchFilterTest.mat in wrDir (presumably for comparison outside
# this notebook — TODO confirm).
# NOTE(review): `notch_filter_moving_window` is not defined or imported anywhere in the
# visible notebook, so this cell raises NameError on a fresh kernel — confirm where it lives.
# Define your parameters
# (lowcut, highcut and transition_bandwidth are not used in this cell while the
# band-pass call below stays commented out)
lowcut = 1  # Lower cutoff frequency
highcut = 119  # Upper cutoff frequency
transition_bandwidth = 1  # Transition bandwidth in Hz

# Dynamically get the subject's sampling rate
samprate = raw.info['sfreq']  # This pulls the sampling rate for the current subject

#Create a Copy and filter the data
# NOTE: this rebinds the notebook-level `proc_data` to an unfiltered copy of `raw`
proc_data = raw.copy()
# proc_data.filter(lowcut, highcut, fir_design='firwin', filter_length='auto', phase='zero', fir_window='hamming', 
#                  verbose=True,l_trans_bandwidth=1, h_trans_bandwidth=1)


# Filter the full data array; the meaning of movingwin / tau / mt_bandwidth depends on the
# helper's definition, which is not visible here
data = notch_filter_moving_window(data = proc_data.get_data(), movingwin=[1.5, .5], Fs=samprate, freqs=[60], tau=10, mt_bandwidth=3)


# Export the filtered array under the key 'eeg_data'
savemat(os.path.join(wrDir,'NotchFilterTest.mat'), {'eeg_data': data})

## 3. Initial Preprocessing steps
- Extract events (trial times, stim types, responses)
- Filtering (lowpass-, highpass-, and notch w harmonics)
- Data cleaning by identifying bad channels and epochs


In [None]:
# Define your parameters
lowcut = 1  # Lower cutoff frequency (Hz)
highcut = 119  # Upper cutoff frequency (Hz)
transition_bandwidth = 1  # Transition bandwidth in Hz (used for both band edges)

# Dynamically get the subject's sampling rate
samprate = raw.info['sfreq']  # This pulls the sampling rate for the current subject

# Create a copy and band-pass filter it; `raw` itself stays unfiltered
proc_data = raw.copy()
proc_data.filter(lowcut, highcut, fir_design='firwin', filter_length='auto', phase='zero', fir_window='hamming',
                 verbose=True, l_trans_bandwidth=transition_bandwidth, h_trans_bandwidth=transition_bandwidth)

# Keep a snapshot of the band-pass-only signal. The raw-vs-filter comparison cells
# further down read `bandpass`, which was never defined before.
# (This holds a third full copy of the recording in memory.)
bandpass = proc_data.copy()

# filter out 42 Hz noise
proc_data.notch_filter(freqs=[42], picks='seeg')  # Joe had this in his code, not sure why it's there

# Notch out 60 Hz. Only 60 Hz is removed here: its harmonics (120, 180 Hz)
# lie above the 119 Hz upper cutoff of the band-pass.
proc_data.notch_filter(freqs=[60], picks='seeg')


In [None]:
# Export the filtered recording under the key 'eeg_data'.
# The file name is relative, so it is written to the current working directory.
filtered_export = {'eeg_data': proc_data.get_data()}
savemat('PythonPreprocessOrder35.mat', filtered_export)

## Visualize: plot filtered vs. raw data to ensure that the filtering works
### According to Joe it works fine

In [None]:
# Pull the filtered sEEG channels out as a plain array
filtered_data = proc_data.get_data(picks='seeg')

# Count NaN samples
nan_mask = np.isnan(filtered_data)
nan_count = nan_mask.sum()
print(f"Number of NaN values: {nan_count}")

# Count +/-Inf samples
inf_mask = np.isinf(filtered_data)
inf_count = inf_mask.sum()
print(f"Number of Inf values: {inf_count}")

In [None]:
# Function to replace NaN or Inf values with 0
def clean_data(data):
    """Return a copy of `data` in which NaN, +Inf and -Inf are all replaced by 0.0."""
    zero_fill = {'nan': 0.0, 'posinf': 0.0, 'neginf': 0.0}
    return np.nan_to_num(data, **zero_fill)

# NaN/Inf-free arrays for the three processing stages: raw, band-pass, band-pass + notch
raw_data_clean, bandpass_data_clean, bandpass_and_notch_data_clean = (
    clean_data(stage.get_data()) for stage in (raw, bandpass, proc_data)
)


In [None]:
### Adjust here
time_start = 0  # in ms
time_end = 200  # in ms
###

# Convert time from ms to sample indices
time_start_idx = int(np.round(time_start * fs / 1000))  # Start sample index
time_end_idx = int(np.round(time_end * fs / 1000))      # End sample index
print(f'Time range: {time_start_idx} to {time_end_idx} samples')

# Clean the data to avoid NaN or Inf values
raw_data_clean = clean_data(raw.get_data())
bandpass_data_clean = clean_data(bandpass.get_data())
bandpass_and_notch_data_clean = clean_data(proc_data.get_data())

# Set up saving path
raw_vs_filter_path = os.path.join(preproc_datapath, 'raw_vs_filter_plots', f'{time_start}ms_to_{time_end}ms')
os.makedirs(raw_vs_filter_path, exist_ok=True)

num_channels = raw_data_clean.shape[0]
time_range = np.arange(time_start_idx, time_end_idx)  # sample indices of the plotted window
window = slice(time_start_idx, time_end_idx)

# X-axis in milliseconds, so the values agree with the 'Time (ms)' label
# (the sample indices themselves were plotted before). raw.times is in seconds.
time_ms = raw.times[window] * 1000

# MNE returns sEEG data in volts by default; scale to microvolts so the values agree
# with the 'Amplitude (µV)' label and with the +/-100 fallback limits below.
volts_to_microvolts = 1e6

for channel in range(num_channels):
    # The three versions of this channel, restricted to the window, in µV
    traces = {
        'Raw': raw_data_clean[channel, window] * volts_to_microvolts,
        'Bandpass': bandpass_data_clean[channel, window] * volts_to_microvolts,
        'Bandpass + Notch': bandpass_and_notch_data_clean[channel, window] * volts_to_microvolts,
    }

    fig, ax = plt.subplots(figsize=(10, 5))
    for label, trace in traces.items():
        ax.plot(time_ms, trace, label=label)

    # Adding labels and title
    ax.set_title(f'{subject} Channel {channel + 1} ({raw.ch_names[channel]}) from {time_start}ms to {time_end}ms')
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('Amplitude (µV)')
    ax.legend()
    ax.grid(True)

    # Adjust y-axis limits dynamically based on the cleaned data
    y_min = min(np.min(trace) for trace in traces.values())
    y_max = max(np.max(trace) for trace in traces.values())

    if np.isfinite(y_min) and np.isfinite(y_max):
        ax.set_ylim([y_min - abs(y_min * 0.1), y_max + abs(y_max * 0.1)])  # Add 10% margin
    else:
        ax.set_ylim([-100, 100])  # Fallback y-axis limit

    # Save the figure and release it so the loop does not accumulate open figures
    fig.savefig(os.path.join(raw_vs_filter_path, f'channel_{channel + 1}.png'))
    plt.close(fig)


In [None]:
### Adjust here
time_start = 0 # in ms
time_end = 100 # in ms
###

# Convert time from ms to sample indices
time_start_idx = int(np.round(time_start * fs / 1000))  # Start sample index
time_end_idx = int(np.round(time_end * fs / 1000))      # End sample index
print(f'Time range: {time_start_idx} to {time_end_idx} samples')

# Unlike the previous cell, the data is plotted as-is (NaN/Inf values are not replaced)
raw_data = raw.get_data()
bandpass_data = bandpass.get_data()
bandpass_and_notch_data = proc_data.get_data()

# Set up saving path
raw_vs_filter_path = os.path.join(preproc_datapath, 'raw_vs_filter_plots', f'{time_start}ms_to_{time_end}ms')
os.makedirs(raw_vs_filter_path, exist_ok=True)

num_channels = raw_data.shape[0]
time_range = np.arange(time_start_idx, time_end_idx)  # sample indices of the plotted window
window = slice(time_start_idx, time_end_idx)

# X-axis in milliseconds, so the values agree with the 'Time (ms)' label
# (the sample indices themselves were plotted before). raw.times is in seconds.
time_ms = raw.times[window] * 1000

# MNE returns sEEG data in volts by default; scale to microvolts to agree with the y label
volts_to_microvolts = 1e6

for channel in range(num_channels):
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(time_ms, raw_data[channel, window] * volts_to_microvolts, label='Raw')
    ax.plot(time_ms, bandpass_data[channel, window] * volts_to_microvolts, label='Bandpass')
    ax.plot(time_ms, bandpass_and_notch_data[channel, window] * volts_to_microvolts, label='Bandpass + Notch')

    # Adding labels and title
    ax.set_title(f'{subject} Channel {channel + 1} ({raw.ch_names[channel]}) from {time_start}ms to {time_end}ms')
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('Amplitude (uV)')
    ax.legend()
    ax.grid(True)

    # Save the figure and release it so the loop does not accumulate open figures
    fig.savefig(os.path.join(raw_vs_filter_path, f'channel_{channel + 1}.png'))
    plt.close(fig)


### Clean the data visually

1. Mark in the interactive plot all bad channels, then remove them
2. Mark in the interactive plot all bad epochs by annotations, then remove them

### If available, load in dropped channels from DroppedChans.csv

In [None]:
# Open the interactive browser on the filtered data for the visual cleaning step described
# above (marking bad channels and bad segments by hand).
# No events are passed to this call, so no event markers are drawn here.
proc_data.plot(title='Filtered EEG Data', block=True, clipping = None)

In [None]:
# Call in previously dropped channels, if an earlier run saved them.
# A dedicated variable is used for the CSV path so the data-root variable
# `file_path` set earlier in the notebook is not overwritten.
dropped_chans_path = os.path.join(preproc_datapath, 'DroppedChans.csv')
dropped_chans = []

if os.path.exists(dropped_chans_path):
    with open(dropped_chans_path, "r", newline='') as file:
        reader = csv.reader(file)
        next(reader, None)  # skip the header row
        # The channel name is in the second column (the first holds the row index);
        # blank lines and empty names are skipped.
        dropped_chans = [row[1] for row in reader if len(row) > 1 and row[1]]
else:
    # The file only exists after the export step has been run once for this subject
    print(f'No DroppedChans.csv found in {preproc_datapath}')

print(dropped_chans)

In [None]:
# Drop channels (hand-picked or from previous analysis)
bads =['Event', 'L25d9', 'LSPs7', 'C128', 'C127', 'C126', 'C125', 'C123', 'C124', 'C122', 'C121', 'C120', 'C119', 'L17d1', 'L17d7', 'L13d5'] 

# Only drop channels that are still present, so the cell can be re-run after the channels
# are gone. Names that are not found are reported rather than raising, so check the
# message for typos.
bads_present = [ch for ch in bads if ch in proc_data.ch_names]
bads_not_found = [ch for ch in bads if ch not in proc_data.ch_names]
if bads_not_found:
    print(f'Channels not found (already dropped or misspelled): {bads_not_found}')

proc_data = proc_data.drop_channels(bads_present)
proc_data.info

## 4. Post Cleaning Preprocessing Steps
- Re-referencing (e.g., common median reference)
- Downsampling
- Epoching

In [None]:
# Re-referencing (common median reference)
# Building a new RawArray starts with no annotations, so bad segments marked in the
# interactive plot would be lost. Keep them and re-attach them afterwards.
marked_annotations = proc_data.annotations

proc_data_ref = proc_data.get_data()  # Convert to numpy array (channels x samples)
median_lfp = np.median(proc_data_ref, axis=0)  # Median across channels at every sample
proc_data_ref = proc_data_ref - median_lfp  # Subtract median LFP for re-referencing
proc_data = mne.io.RawArray(proc_data_ref, proc_data.info)  # Convert back to MNE object
proc_data.set_annotations(marked_annotations)

# Downsampling
# NOTE: `fs` is rebound here from the rounded original sampling rate to the new rate;
# cells that convert times to samples after this point use the new value.
fs = 500
proc_data = proc_data.resample(sfreq=fs)

# Plot filtered data with event markers colored by event types
#proc_data.plot(title='Filtered EEG Data', block=True, clipping=None)

### Load in the log file and filter by stimulation types to find new images

In [None]:
# Load logfile: tab-delimited, skipping the first 2 rows and the last row.
# skipfooter is not supported by pandas' default C parser (pandas warns and falls back),
# so the python engine is requested explicitly.
logfile = pd.read_csv(logfile_path, delimiter='\t', skiprows=2, skipfooter=1, engine='python')

# Filter out the NaN or New response condition rows
# enumerate stim categories from log files to numbers
#ORIGINAL
logfile.loc[logfile['CONDITION'] == 'nostim', 'CONDITION'] = 0
logfile.loc[logfile['CONDITION'] == 'stim', 'CONDITION'] = 1
logfile.loc[logfile['CONDITION'] == 'new', 'CONDITION'] = 9
logfile.loc[logfile['CONDITION'] == 'None', 'CONDITION'] = np.nan

# Identify NaN rows (indices refer to the logfile as loaded, before re-indexing)
dropped_nan_indices = logfile[logfile['CONDITION'].isna()].index
print(f'Indices of NaN rows: {dropped_nan_indices.tolist()}')

# Drop NaN rows and reorder the index
logfile = logfile.dropna(subset=['CONDITION']).reset_index(drop=True)

# Save CONDITION column to a .npy file
np.save(os.path.join(preproc_datapath, subject + '_stimcondition'), logfile['CONDITION'].values)
np.save(os.path.join(preproc_datapath, subject + '_dropped_nan_indices'), dropped_nan_indices)#save dropped nan indices

# Test loading the saved file
test = np.load(os.path.join(preproc_datapath, subject + '_stimcondition.npy'), allow_pickle=True)
print('length of stim condition',len(test))

### From the same log file find response types and perform Signal Detection analysis

In [None]:
# Extract responses from log file for remembered vs forgotten analysis
logfile.loc[logfile['YES/NO'] == 'yes', 'YES/NO'] = 1 # 1 for yes
logfile.loc[logfile['YES/NO'] == 'no', 'YES/NO'] = 0 # 0 for no

# Ensure no NaN values in YES/NO column: anything that is not numeric becomes NaN and is dropped
n_rows_before_response_filter = len(logfile)
logfile['YES/NO'] = pd.to_numeric(logfile['YES/NO'], errors='coerce')
logfile = logfile.dropna(subset=['YES/NO']).reset_index(drop=True)
logfile['YES/NO'] = logfile['YES/NO'].astype(int) #convert to integer

# These dropped rows are not recorded anywhere, so report them: the trial-time step later
# checks that the number of logfile rows matches the number of trial times.
n_rows_without_response = n_rows_before_response_filter - len(logfile)
if n_rows_without_response:
    print(f'WARNING: dropped {n_rows_without_response} row(s) without a yes/no response')

# Calculate remembered vs forgotten data from responses: from 'Condition' column we can extract image condition as 'new' and anything that's not 'new' as 'target'
is_new = logfile['CONDITION'] == 9
said_yes = logfile['YES/NO'] == 1
said_no = logfile['YES/NO'] == 0

hit = logfile[~is_new & said_yes]
miss = logfile[~is_new & said_no]
fa = logfile[is_new & said_yes]
cr = logfile[is_new & said_no]

print('hit:', hit.shape)
print('miss:', miss.shape)
print('fa:', fa.shape)
print('cr:', cr.shape)

nhits = hit.shape[0]
nmiss = miss.shape[0]
nfa= fa.shape[0] # -----> check this with a new subject bc we got 40/40 FA/CR for amyg030
ncr = cr.shape[0]

remembered = hit # where the subject responded YES,and accurately recognized the image
forgotten = miss # where the subject responded NO, and did NOT accurately recognize the image

n_remembered = remembered.shape[0]
n_forgotten = forgotten.shape[0]

# Hit rate is undefined when there are no old (target) trials; avoid a ZeroDivisionError
n_targets = nhits + nmiss
hitrate = nhits / n_targets if n_targets else np.nan

hits_index = hit.index
miss_index = miss.index
cr_index = cr.index
fa_index = fa.index

print("remembered: ", n_remembered)
print("forgotten: ", n_forgotten)
print("fa:", nfa)
print("cr:", ncr)

### Save Signal Detection data info

In [None]:
# Create a dataframe to store the indices and their corresponding response types
response_groups = [('hit', hit), ('miss', miss), ('fa', fa), ('cr', cr)]
response_data = pd.DataFrame({
    'Index': [idx for _, group in response_groups for idx in group.index.tolist()],
    'Response': [label for label, group in response_groups for _ in range(len(group.index))]
})

# Sort the dataframe by the index to maintain the order of the original logfile
response_data = response_data.sort_values(by='Index').reset_index(drop=True)

# Save the dataframe to a CSV file
response_data.to_csv(os.path.join(preproc_datapath, subject + '_SignalDetection_ResponseData.csv'))

# Write signal detection ratios into a txt file.
# The 'Hit Rate' line used to be written twice; it is now written once.
with open(os.path.join(preproc_datapath, subject + '_SignalDetection_Ratios_Before_Dropping_Epochs.txt'), 'w') as f:
    f.write(f'Hits: {nhits}\n')
    f.write(f'Misses: {nmiss}\n')
    f.write(f'False Alarms: {nfa}\n')
    f.write(f'Correct Rejections: {ncr}\n')
    f.write(f'Hit Rate: {hitrate}\n')
    f.write(f'Total Trials: {nhits + nmiss + nfa + ncr}\n')
    f.write(f'Remembered: {n_remembered}\n')
    f.write(f'Forgotten: {n_forgotten}\n')

### After filtering the logfile, drop rows to keep only remembered and forgotten items (no new items)

In [None]:
# Remember the row index as it is before the 'new' rows are removed
original_indices = list(logfile.index)
print("Original length of rows:",len(original_indices))

# Keep only old items (CONDITION other than 9 = 'new') so remembered vs forgotten can be analyzed
is_old_item = logfile['CONDITION'] != 9
logfile = logfile.loc[is_old_item].reset_index(drop=True)
print("New number of rows:", len(logfile))
print(logfile['YES/NO'])

### Load in trial times from matlab file, filter by response types for remembered and forgotten images (filter out new images)

In [None]:
# Load in event times from mat file, these are the trial times in seconds
events = loadmat(events_path, simplify_cells=True)
day2_trial_times = events['day2_trial_times']

# Convert to samples at the sampling rate of the data that will be epoched.
# Reading it from proc_data (rather than the global `fs`) keeps the conversion correct
# even if cells were run out of order.
epoch_sfreq = proc_data.info['sfreq']
day2_trial_times = np.round(np.atleast_1d(day2_trial_times) * epoch_sfreq).astype(int)  # round to nearest integer
print('lenght of day2_trial_times:',len(day2_trial_times))
print("length of dropped nan indices:",len(dropped_nan_indices))
print(dropped_nan_indices.shape)   

# The FA/CR indices below refer to the logfile AFTER the NaN-condition rows were removed
# and the index was reset. If the trial-time list still contains those rows (its length
# is exactly the kept rows plus the dropped rows), remove them first so positions line up.
# NOTE(review): assumes one trial time per logfile row, in the same order — confirm.
n_logged_trials = len(hit) + len(miss) + len(fa) + len(cr)
if len(dropped_nan_indices) and len(day2_trial_times) == n_logged_trials + len(dropped_nan_indices):
    print("Removing trial times of NaN-condition rows:", list(dropped_nan_indices))
    day2_trial_times = np.delete(day2_trial_times, dropped_nan_indices)

# Determine the indices to drop in day2_trial_times
indices_to_drop = list(fa.index) + list(cr.index)
print("Number of rows to drop:",len(indices_to_drop))

# Drop the corresponding trial times
day2_trial_times = np.delete(day2_trial_times, indices_to_drop)
print("Length of FA trials:", len(fa.index), "FA indices:", fa.index)
print("Length of CR trials:", len(cr.index),"CR indices:", cr.index)

# Verify the number of remaining trial times matches the filtered logfile
assert len(day2_trial_times) == logfile.shape[0], "Mismatch between day2_trial_times and logfile"

# Now, day2_trial_times should match the filtered logfile indices
print("Indices match successfully!")
print("Remaining length of day2_trial_times:", len(day2_trial_times))

In [None]:
# Create event array based on response types (remembered vs forgotten).
# Column 0 holds the sample index, column 1 stays 0, column 2 holds the event ID.
n_events = len(day2_trial_times)
print(n_events)
events_array = np.zeros((n_events, 3), dtype=int)
print(events_array.shape)
events_array[:, 0] = day2_trial_times

response_codes = logfile['YES/NO'].values
print(response_codes)
print(len(response_codes))

# Count the number of zeros
zero_count = np.count_nonzero(response_codes == 0)
print("Number of zeros:", zero_count)
events_array[:, 2] = response_codes # set event IDs here based on log file info

# Check that the event-ID column equals the logfile responses element by element.
# (The previous check compared `.all()` of each array, i.e. two single booleans,
# which can report a match even when individual entries differ.)
if np.array_equal(events_array[:, 2], response_codes):
    print("Event IDs match successfully!")
else:
    print("Event IDs do not match!")

### Epoch data based on response types with events marking remembered (1) versus forgotten (0) items

In [None]:
# Epoching: cut a window from 5 s before to 5 s after every event onset,
# without baseline correction and without amplitude-based rejection
epoch_settings = dict(tmin=-5, tmax=5, baseline=None, reject=None)
epochs = mne.Epochs(proc_data, events_array, **epoch_settings)

# Interactive browser for the epochs
epochs.plot(title='Epoched EEG Data', block=True, events=events_array)

n_events_total = events_array.shape[0]
print('num events', n_events_total)
print('num epochs', len(epochs))

# Drop epochs if needed which will also drop that trial
# (the cell output is the list of channels currently marked bad)
epochs.info['bads']

## 5. Export the data into numpy arrays for analysis

In [None]:
# Epoch data first, then the drop log: the data is fetched before the drop log is read,
# keeping the original order of these two steps
epoch_data = epochs.get_data()
drop_epochs = [idx for idx, reasons in enumerate(epochs.drop_log) if len(reasons) > 0]

# Keep only the events whose epoch was not dropped
events_mask = np.ones(events_array.shape[0], dtype=bool)
events_mask[drop_epochs] = False
keep_events = events_array[events_mask]

# Export the epoched data and the surviving events as .npy files
np.save(os.path.join(preproc_datapath, 'PreprocessedData'), epoch_data)
np.save(os.path.join(preproc_datapath, 'Events'), keep_events)  # only the epochs that were not dropped

# Export dropped epochs, dropped channels and channel labels as one-column .csv files
csv_exports = [
    (drop_epochs, 'Dropped Epochs', 'DroppedEpochs.csv'),
    (bads, 'Dropped Chans', 'DroppedChans.csv'),
    (epochs.ch_names, 'Chan', 'ChanLabels.csv'),
]
for values, column, csv_name in csv_exports:
    pd.DataFrame(values, columns=[column]).to_csv(os.path.join(preproc_datapath, csv_name))

In [None]:
# Read Events.npy back from the preprocessing folder (the export step writes a file of this name)
a = np.load(os.path.join(preproc_datapath, ('Events.npy')))