<h2>Photodiode Analysis Code</h2>

Ansley Kunnath & Andrew Kim

Updated 07/05/24

In [1]:
# Load data
########## Run with Python 3.9.12 (for Ansley)

import mne
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use("TkAgg")
import scipy.stats as stats
#json import needed
import json

########## You may need to change the path and file name:
eeg_path = "C://Users//neuro//Documents//Git_EEG_Workshop//EEG_Workshop//Data//" 
file_name = "VEP_Test2"
file_eeg = eeg_path + file_name + ".eeg"
file_vhdr = eeg_path + file_name + ".vhdr"
file_vmrk = eeg_path + file_name + ".vmrk"

# Load and plot the raw data
raw = mne.io.read_raw_brainvision(file_vhdr)
events, event_id = mne.events_from_annotations(raw)
raw2 = raw.load_data().filter(l_freq=None, h_freq=40)

#raw.crop(tmin=22, tmax=190)
raw2.plot(picks="EKG")

Extracting parameters from C://Users//neuro//Documents//Git_EEG_Workshop//EEG_Workshop//Data//VEP_Test2.vhdr...
Setting channel info structure...
Used Annotations descriptions: ['Marker/Impedance', 'New Segment/', 'Stimulus/s1', 'Stimulus/s2', 'Stimulus/s5']
Reading 0 ... 178208  =      0.000 ...   178.208 secs...
Filtering raw data in 1 contiguous segment
Setting up low-pass filter at 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal lowpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 331 samples (0.331 s)



  raw = mne.io.read_raw_brainvision(file_vhdr)
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done  71 tasks      | elapsed:    0.1s


Using matplotlib as 2D backend.


<MNEBrowseFigure size 1605x800 with 4 Axes>

Channels marked as bad:
none


In [2]:
# Create epochs for checkerboard events

stimulus_s1_events = events[events[:, 2] == event_id['Stimulus/s2']]
stimulus_s1_events, event_id
tmin, tmax = 0, 0.500 
epochs = mne.Epochs(raw2, events=stimulus_s1_events, event_id=event_id['Stimulus/s2'],
                    tmin=tmin, tmax=tmax, baseline=None, preload=True)
epochs = epochs.pick_channels(['EKG'])

#epochs.plot() 

Not setting metadata
50 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 50 events and 501 original time points ...
0 bad epochs dropped
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


In [5]:
# Calculate the first latency that exceeds # of MADs for each epoch 

########## CHANGE MAD_FACTOR
mad_factor = 4 # or 2 to be less strict

median_amplitude = []
first_time = []
peak_latencies = []

for index, epoch in enumerate(epochs.get_data()):
    positive_epoch = abs(epoch)*1000
    median_amplitude = np.median(positive_epoch)
    mad = np.median(np.abs(positive_epoch - median_amplitude))
    threshold = median_amplitude + (mad_factor * mad)
    exceed_index = np.argmax(positive_epoch > threshold)
    if exceed_index > 0:
        first_time = epochs.times[exceed_index] 
    else:
        first_time = None 
        print(f"No exceedance found in epoch {index}") 
    peak_latencies.append(first_time)

# Calculate latencies in milliseconds
peak_latencies_ms = np.array([lat * 1000 if lat is not None else None for lat in peak_latencies])
valid_latencies = peak_latencies_ms[peak_latencies_ms != np.array(None)]
total_events = len(valid_latencies)

# Calculate average and confidence interval only for valid latencies
if len(valid_latencies) > 0:
    median_latency_ms = np.median(valid_latencies)
    sem_latency = stats.sem(valid_latencies)  # SEM = std / sqrt(n)
    confidence_level = 0.95
    ci_width = sem_latency * stats.t.ppf((1 + confidence_level) / 2, len(valid_latencies) - 1)
    confidence_interval = (median_latency_ms - ci_width, median_latency_ms + ci_width)
else:
    median_latency_ms = None
    confidence_interval = (None, None)

print(f"Median Latency: {median_latency_ms:.0f} ms")
print(f"95% Confidence Interval: ({confidence_interval[0]:.3f}, {confidence_interval[1]:.3f})")
print(f"Total Events: {total_events} out of 50")

#save the latency values separately
peak_latencies_ms = [lat * 1000 if lat is not None else None for lat in peak_latencies]
with open('photodiode_latencies.json', 'w') as f:
    json.dump(peak_latencies_ms, f)

print("Photodiode latencies saved to photodiode_latencies.json")

Median Latency: 309 ms
95% Confidence Interval: (300.351, 317.649)
Total Events: 50 out of 50
Photodiode latencies saved to photodiode_latencies.json


  for index, epoch in enumerate(epochs.get_data()):


In [4]:
# Plot individual epochs

########## CHANGE X_VALUES BASED ON WHICH EPOCHS DID NOT EXCEED THE THRESHOLD
x_values = [1, 2, 3, 4, 5]

fig, axes = plt.subplots(len(x_values), 1, figsize=(10, 10), sharex=True, sharey=False)
for i, x in enumerate(x_values):
    epoch = epochs.get_data()[x]
    abs_epoch = abs(epoch[0]) * 1000 
    median_amplitude = np.median(abs_epoch)
    mad = np.median(np.abs(abs_epoch - median_amplitude))
    threshold = median_amplitude + (mad_factor * mad)
    times_in_ms = epochs.times * 1000
    exceed_index = np.argmax(abs_epoch > threshold)
    if exceed_index > 0:
        first_time = epochs.times[exceed_index] * 1000 
    else:
        first_time = None  # No point exceeded the threshold
    
    min_time = min(times_in_ms)
    max_time = max(times_in_ms)
    vertical_lines = np.arange(min_time, max_time, 2)

    axes[i].plot(times_in_ms, abs_epoch)
    axes[i].axvline(x=first_time if first_time is not None else 0, color='black', 
        label=f"Screen Change: {first_time:.2f} ms" if first_time is not None else "Screen Change: None")
    axes[i].axhline(y=threshold, color='r', linestyle='--', label=f"Threshold: {threshold:.2f} mV")
    axes[i].axhline(y=median_amplitude, color='b', linestyle='--', label=f"Median: {median_amplitude:.2f} mV")
    for line in vertical_lines:
        axes[i].axvline(x=line, color='gray', linestyle='--', linewidth=0.5, alpha=0.5)
    axes[i].set_title(f"Epoch {x}")
    axes[i].legend()

plt.xlabel("Time (ms)")
plt.ylabel("Amplitude (mV)")
plt.savefig('Plot Epochs.png')
plt.show()


  epoch = epochs.get_data()[x]
  epoch = epochs.get_data()[x]
  epoch = epochs.get_data()[x]
  epoch = epochs.get_data()[x]
  epoch = epochs.get_data()[x]


In [None]:
# Histogram of latency distributions

########## SET BIN SIZE
num_bins = 6 

plt.figure(figsize=(10, 6))
bin_edges = np.linspace(min(valid_latencies), max(valid_latencies), num_bins + 1)
rounded_bin_edges = np.round(bin_edges)
n, bins, patches = plt.hist(valid_latencies, bins=bin_edges, edgecolor='black', linewidth=1.5)
plt.xticks(rounded_bin_edges)
for count, x in zip(n, bins[:-1]):
    plt.text(x + (bins[1]-bins[0])/2, count, str(int(count)), ha='center', va='bottom')
plt.title('Histogram of Latencies')
plt.xlabel('Latency (ms)')
plt.ylabel('Frequency')
plt.savefig('Histogram.png')
plt.show()


In [6]:
# Scatter plot of the latencies

plt.figure(figsize=(10, 6))
plt.scatter(range(len(valid_latencies)), valid_latencies, color='black')
plt.axhline(median_latency_ms, color='red', linestyle='dashed', linewidth=2)
plt.title('Scatter Plot of Latencies')
plt.xlabel('Event')
plt.ylabel('Latency (ms)')
plt.grid(True)
plt.savefig('Scatter Plot.png')
plt.show()

In [12]:
import mne
import numpy as np
from mne.preprocessing import (ICA)
#from autoreject import AutoReject
import matplotlib

#for storing and exchanging data
import json
matplotlib.use("TkAgg")

#load the photodiode latency data
with open('photodiode_latencies.json', 'r') as f:
    photodiode_latencies = json.load(f)

# Adjust epoch start times based on photodiode latencies
adjusted_events = []
for event, latency in zip(events, photodiode_latencies):
    if latency is not None:
        adjusted_event = event.copy()
        adjusted_event[0] += int(latency)  # Adjusting event start time by photodiode latency
        adjusted_events.append(adjusted_event)

adjusted_events = np.array(adjusted_events)

# Create epochs with adjusted events
#highpass = 0.2
#lowpass = 40
#notch = 60
#raw_filtered = raw.load_data().filter(highpass, lowpass).notch_filter(np.arange(notch, (notch * 3), notch))

tmin, tmax = -0.100, 0.500
epochs = mne.Epochs(raw, events=adjusted_events, event_id=event_id, tmin=tmin, tmax=tmax, baseline=None, preload=True)

# Continue with VEP analysis as usual
#evoked = epochs.average()
#evoked.plot()

#baseline_tmin, baseline_tmax = -0.050, 0
#baseline = (baseline_tmin, baseline_tmax)
#VEP = epochs['Stimulus/s2'].apply_baseline(baseline).average()
#blank = epochs['Stimulus/s1'].apply_baseline(baseline).average()

VEP = epochs['Stimulus/s2'].average()
blank = epochs['Stimulus/s1'].average()

fig = mne.viz.plot_compare_evokeds(VEP, picks=['Z13'], combine="mean", show=False, time_unit="ms")
fig[0].savefig("VEP_Z13")
fig = mne.viz.plot_compare_evokeds(VEP, picks=['EKG'], combine="mean", show=False, time_unit="ms")
fig[0].savefig("VEP_EKG")

Not setting metadata
50 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 50 events and 601 original time points ...
0 bad epochs dropped
combining channels using "mean"
combining channels using "mean"


  epochs = mne.Epochs(raw, events=adjusted_events, event_id=event_id, tmin=tmin, tmax=tmax, baseline=None, preload=True)
  fig = mne.viz.plot_compare_evokeds(VEP, picks=['Z13'], combine="mean", show=False, time_unit="ms")
  fig = mne.viz.plot_compare_evokeds(VEP, picks=['Z13'], combine="mean", show=False, time_unit="ms")
  fig = mne.viz.plot_compare_evokeds(VEP, picks=['EKG'], combine="mean", show=False, time_unit="ms")
  fig = mne.viz.plot_compare_evokeds(VEP, picks=['EKG'], combine="mean", show=False, time_unit="ms")
