<h2>Photodiode Analysis Code</h2>

Adam Tiesman, Ansley Kunnath

Updated 07/10/24

In [None]:
# Define variables
# Whether to baseline-correct the ERP epochs. Must be a real boolean: the string
# 'False' is truthy in Python, so a string here could never disable the correction.
baseline_apply = True
electrode_of_interest = 'Z13'   # EEG channel whose VEP is plotted/saved
photodiode_electrode = 'EKG'    # channel the photodiode is recorded on

In [None]:
# Load libraries and data
import mne
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use("TkAgg")
import scipy.stats as stats
#json import needed
import json

## Define directory and file path:
# Edit eeg_path for the machine in use and file_name for each recording.
# A BrainVision recording is split across three files sharing one base name:
#   .eeg  -> raw EEG samples
#   .vhdr -> header (channel/electrode naming)
#   .vmrk -> markers recorded alongside the raw EEG
eeg_path = "C://Users//neuro//Documents//Git_EEG_Workshop//EEG_Workshop//Data//" 
file_name = "VEP_Test2"
file_eeg, file_vhdr, file_vmrk = (
    f"{eeg_path}{file_name}{extension}" for extension in (".eeg", ".vhdr", ".vmrk")
)

# Read the recording via its header, then pull the event markers from its annotations.
raw = mne.io.read_raw_brainvision(file_vhdr)
events, event_id = mne.events_from_annotations(raw)

# load_data() and filter() both operate in place and return the same object,
# so raw2 is the 40 Hz low-pass-filtered recording (and aliases raw).
raw.load_data()
raw2 = raw.filter(l_freq=None, h_freq=40)

raw2.plot(picks=photodiode_electrode)

In [None]:
# Create photodiode epochs for checkerboard events given a window, time-locked to the EEG marker.

# NOTE: despite the "s1" in its name, this array holds the 'Stimulus/s2' events;
# the name is kept because later cells refer to it.
stimulus_code = event_id['Stimulus/s2']
stimulus_s1_events = events[events[:, 2] == stimulus_code]
tmin, tmax = 0, 0.500  # epoch window in seconds relative to the marker
epochs = mne.Epochs(raw2, events=stimulus_s1_events, event_id=stimulus_code,
                    tmin=tmin, tmax=tmax, baseline=None, preload=True)
# Keep only the configured photodiode channel (rather than a hard-coded name) so the
# latency detection follows the `photodiode_electrode` setting.
epochs = epochs.pick([photodiode_electrode])

epochs.plot() 

In [None]:
# Calculate, for each photodiode epoch, the first latency at which the rectified signal
# exceeds median + mad_factor * MAD (median absolute deviation of the epoch).

mad_factor = 4 # or 2 to be less strict

peak_latencies = []  # one entry per epoch: latency in seconds, or None if not detected

for index, epoch in enumerate(epochs.get_data()):
    # Only the photodiode channel was picked, so row 0 is its trace; rectify and scale x1000.
    positive_epoch = np.abs(epoch[0]) * 1000
    median_amplitude = np.median(positive_epoch)
    mad = np.median(np.abs(positive_epoch - median_amplitude))
    threshold = median_amplitude + (mad_factor * mad)
    above = positive_epoch > threshold
    # np.argmax alone returns 0 both when nothing exceeds and when the first sample
    # exceeds, so the two cases are checked separately.
    if not above.any():
        first_time = None
        print(f"No exceedance found in epoch {index}")
    elif above[0]:
        # Already above threshold at the first sample: the start of the deflection cannot
        # be located, so the epoch is excluded rather than reported as 0 ms.
        first_time = None
        print(f"Threshold already exceeded at the first sample of epoch {index}; excluded")
    else:
        first_time = float(epochs.times[np.argmax(above)])
    peak_latencies.append(first_time)

# Latencies in milliseconds (None kept for epochs without a detected screen change)
peak_latencies_ms = [lat * 1000 if lat is not None else None for lat in peak_latencies]
valid_latencies = np.array([lat for lat in peak_latencies_ms if lat is not None], dtype=float)
total_events = len(valid_latencies)
n_epochs = len(peak_latencies)

# Summary statistics over valid latencies only
confidence_level = 0.95
if total_events > 0:
    median_latency_ms = float(np.median(valid_latencies))
    mean_latency_ms = float(np.mean(valid_latencies))
else:
    median_latency_ms = None
    mean_latency_ms = None

if total_events > 1:
    sem_latency = stats.sem(valid_latencies)  # SEM = std / sqrt(n)
    ci_width = sem_latency * stats.t.ppf((1 + confidence_level) / 2, total_events - 1)
    # A t-interval built from the SEM is a confidence interval for the MEAN,
    # so it is centred on the mean, not the median.
    confidence_interval = (mean_latency_ms - ci_width, mean_latency_ms + ci_width)
else:
    # The SEM (and hence the interval) is undefined with fewer than two values.
    confidence_interval = (None, None)

if median_latency_ms is None:
    print("Median Latency: n/a (no valid latencies)")
else:
    print(f"Median Latency: {median_latency_ms:.0f} ms")
    print(f"Mean Latency: {mean_latency_ms:.1f} ms")
if confidence_interval[0] is None:
    print(f"{confidence_level:.0%} Confidence Interval of the mean: n/a (need at least 2 valid latencies)")
else:
    print(f"{confidence_level:.0%} Confidence Interval of the mean: "
          f"({confidence_interval[0]:.3f}, {confidence_interval[1]:.3f})")
print(f"Total Events: {total_events} out of {n_epochs}")

# Save the per-epoch latency values (ms, or null) for the adjusted-ERP cell
with open('photodiode_latencies.json', 'w') as f:
    json.dump(peak_latencies_ms, f)

In [None]:
# Plot individual photodiode epochs. Each should show a positive deflection from baseline,
# and the black "screen change" line should sit right at the start of that deflection.

# Epoch indices to plot. Change these to inspect epochs of interest, e.g. epochs
# that did not exceed the MAD threshold.
x_values = [1, 2, 3, 4, 5]

epoch_data = epochs.get_data()  # fetched once instead of on every loop iteration
n_available = len(epoch_data)
out_of_range = [x for x in x_values if not -n_available <= x < n_available]
if out_of_range:
    raise IndexError(f"x_values {out_of_range} out of range; only {n_available} epochs available")

# Loop-invariant time axis and faint 2 ms gridlines.
times_in_ms = epochs.times * 1000
vertical_lines = np.arange(times_in_ms.min(), times_in_ms.max(), 2)

# squeeze=False keeps `axes` 2-D even when only one epoch is requested.
fig, axes = plt.subplots(len(x_values), 1, figsize=(10, 10), sharex=True, sharey=False,
                         squeeze=False)
for ax, x in zip(axes[:, 0], x_values):
    abs_epoch = np.abs(epoch_data[x][0]) * 1000  # rectified photodiode trace
    median_amplitude = np.median(abs_epoch)
    mad = np.median(np.abs(abs_epoch - median_amplitude))
    threshold = median_amplitude + (mad_factor * mad)
    above = abs_epoch > threshold
    # No screen change is marked when nothing exceeds the threshold or when it is
    # already exceeded at the first sample (onset cannot be located).
    if above.any() and not above[0]:
        first_time = times_in_ms[np.argmax(above)]
    else:
        first_time = None

    ax.plot(times_in_ms, abs_epoch)
    ax.axvline(x=first_time if first_time is not None else 0, color='black',
               label=f"Screen Change: {first_time:.2f} ms" if first_time is not None else "Screen Change: None")
    ax.axhline(y=threshold, color='r', linestyle='--', label=f"Threshold: {threshold:.2f} mV")
    ax.axhline(y=median_amplitude, color='b', linestyle='--', label=f"Median: {median_amplitude:.2f} mV")
    for line in vertical_lines:
        ax.axvline(x=line, color='gray', linestyle='--', linewidth=0.5, alpha=0.5)
    ax.set_title(f"Epoch {x}")
    ax.legend()

# Axis labels go on the bottom subplot (the x axis is shared).
axes[-1, 0].set_xlabel("Time (ms)")
axes[-1, 0].set_ylabel("Amplitude (mV)")
fig.savefig('Plot Epochs.png')
plt.show()


In [None]:
# Histogram of latency distributions

# Number of bins; raise it for a finer view of how precise the latencies are.
num_bins = 6 

if len(valid_latencies) == 0:
    print("No valid latencies to plot.")
else:
    low, high = float(np.min(valid_latencies)), float(np.max(valid_latencies))
    if low == high:
        # All latencies identical: widen the range so the bin edges strictly increase
        # (plt.hist rejects repeated edges).
        low, high = low - 0.5, high + 0.5
    plt.figure(figsize=(10, 6))
    bin_edges = np.linspace(low, high, num_bins + 1)
    rounded_bin_edges = np.round(bin_edges)
    n, bins, patches = plt.hist(valid_latencies, bins=bin_edges, edgecolor='black', linewidth=1.5)
    plt.xticks(rounded_bin_edges)
    # Write each bin's count centred above its bar.
    for count, left, width in zip(n, bins[:-1], np.diff(bins)):
        plt.text(left + width / 2, count, str(int(count)), ha='center', va='bottom')
    plt.title('Histogram of Latencies')
    plt.xlabel('Latency (ms)')
    plt.ylabel('Frequency')
    plt.savefig('Histogram.png')
    plt.show()



In [None]:
# Scatter plot of the latencies, a better graphical depiction of latency precision

# Plot each latency at its original epoch index so that epochs without a detected
# screen change show up as gaps instead of shifting every later point left.
epoch_indices = [i for i, lat in enumerate(peak_latencies) if lat is not None]

plt.figure(figsize=(10, 6))
plt.scatter(epoch_indices, valid_latencies, color='black')
if median_latency_ms is not None:
    plt.axhline(median_latency_ms, color='red', linestyle='dashed', linewidth=2,
                label=f"Median: {median_latency_ms:.0f} ms")
    plt.legend()
plt.title('Scatter Plot of Latencies')
plt.xlabel('Epoch')
plt.ylabel('Latency (ms)')
plt.grid(True)
plt.savefig('Scatter Plot.png')
plt.show()

In [None]:
from mne.preprocessing import (ICA)

#for storing and exchanging data
import json
matplotlib.use("TkAgg")

# Load the per-epoch photodiode latencies (milliseconds, or None) saved earlier.
with open('photodiode_latencies.json', 'r') as f:
    photodiode_latencies = json.load(f)

# Latencies are in milliseconds, but the first column of an MNE events array is a sample
# index, so each latency is converted to samples with the recording's sampling rate
# (adding milliseconds directly is only correct at exactly 1000 Hz).
# NOTE(review): zip pairs events and latencies by position; this assumes the photodiode
# epochs kept the stimulus events in order (a dropped trailing event is simply ignored).
sfreq = raw2.info['sfreq']
adjusted_events = []
for event, latency in zip(stimulus_s1_events, photodiode_latencies):
    if latency is None:
        continue  # no screen change detected for this event; leave it out
    adjusted_event = np.copy(event)  # copy so the original events array is untouched
    adjusted_event[0] += int(round(latency * sfreq / 1000.0))
    adjusted_events.append(adjusted_event)

if not adjusted_events:
    raise RuntimeError("No events have a valid photodiode latency; cannot build adjusted epochs.")

# Convert adjusted_events back to a NumPy array
adjusted_events = np.array(adjusted_events, dtype=int)
tmin, tmax = -0.100, 0.500 # Set ERP window

epochs = mne.Epochs(raw2, events=adjusted_events, event_id=event_id['Stimulus/s2'], tmin=tmin, tmax=tmax, baseline=None, preload=True)

baseline_tmin, baseline_tmax = -0.050, 0
baseline = (baseline_tmin, baseline_tmax)

# Accept either a real boolean or a 'True'/'False' string; a bare `if baseline_apply:`
# would treat the string 'False' as true.
apply_baseline_correction = str(baseline_apply).strip().lower() in ('true', '1', 'yes')
if apply_baseline_correction:
    # apply_baseline works in place, so doing it once covers both averages below.
    epochs.apply_baseline(baseline)
    savename_suffix = "_baselinecorrect"
else:
    savename_suffix = ""

## electrode_of_interest ERP
VEP = epochs.average(picks=electrode_of_interest, method='mean', by_event_type=False)
electrode_of_interest_savename = "{}_VEP_{}{}".format(file_name, electrode_of_interest, savename_suffix)

fig = mne.viz.plot_evoked(VEP, picks=[electrode_of_interest], time_unit="ms")
fig.savefig(electrode_of_interest_savename)
plt.show()

## photodiode_electrode ERP
blank = epochs.average(picks=photodiode_electrode, method='mean', by_event_type=False)
photodiode_savename = "{}_VEP_{}{}".format(file_name, photodiode_electrode, savename_suffix)

fig = mne.viz.plot_evoked(blank, picks=[photodiode_electrode], time_unit="ms")
fig.savefig(photodiode_savename)
plt.show()

In [None]:
## TESTING / WORK IN PROGRESS: the analysis ends in the cell above; everything below is commented-out experimentation and is not run #################################

# Create epochs with adjusted events
#highpass = 0.2
#lowpass = 40
#notch = 60
#raw_filtered = raw.load_data().filter(highpass, lowpass).notch_filter(np.arange(notch, (notch * 3), notch))

#events = mne.make_fixed_length_events(raw, id=1, start=0, stop=None, duration=1.0, first_samp=True, overlap=0.0)
#events = mne.find_events(raw2, stim_channel=None, output='onset', consecutive='increasing', min_duration=0, shortest_event=2, mask=None, uint_cast=False, mask_type='and', initial_event=False, verbose=None)
#events = mne.events_from_annotations(raw2, event_id='auto', regexp='None', use_rounding=True, chunk_duration=None, tol=1e-08, verbose=None)

#epochs_adjusted = epochs_adjusted.pick.channels(['EKG'])

#epochs = mne.Epochs(raw, events=photodiode_latencies, event_id=event_id, tmin=tmin, tmax=tmax, baseline=None, preload=True)

# Continue with VEP analysis as usual
#evoked = epochs.average()
#evoked.plot()
