<h2>Photodiode Analysis Code</h2>

Ansley Kunnath

Updated 04/15/24

In [1]:
# Load data
########## Run with Python 3.9.12 (for Ansley)

import mne
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use("TkAgg")
import scipy.stats as stats

########## You may need to change the path and file name:
eeg_path = "Data/"  
#file_name = "PhotoDiode_2024-03-22_10-24-46"
#file_name = "Subject_Example_2024-03-29_10-51-49" 
file_name = "Subject_Example_2024-03-29_10-59-19"
file_eeg = eeg_path + file_name + ".eeg"
file_vhdr = eeg_path + file_name + ".vhdr"
file_vmrk = eeg_path + file_name + ".vmrk"

# Load and plot the raw data
raw = mne.io.read_raw_brainvision(file_vhdr)
events, event_id = mne.events_from_annotations(raw)

#raw.crop(tmin=22, tmax=190)
#raw.plot()

Extracting parameters from Data/Subject_Example_2024-03-29_10-59-19.vhdr...
Setting channel info structure...
Used Annotations descriptions: ['Marker/Impedance', 'New Segment/', 'Stimulus/s1', 'Stimulus/s2', 'Stimulus/s3', 'Stimulus/s5']


  raw = mne.io.read_raw_brainvision(file_vhdr)


In [2]:
# Create epochs for checkerboard events

stimulus_s1_events = events[events[:, 2] == event_id['Stimulus/s1']]
stimulus_s1_events, event_id
tmin, tmax = 0, 0.250  
epochs = mne.Epochs(raw, events=stimulus_s1_events, event_id=event_id['Stimulus/s1'],
                    tmin=tmin, tmax=tmax, baseline=None, preload=True)
epochs = epochs.pick_channels(['BIP3'])

#epochs.plot() 

Not setting metadata
100 matching events found
No baseline correction applied
0 projection items activated
Loading data for 100 events and 126 original time points ...
0 bad epochs dropped


In [3]:
# Calculate the first latency that exceeds # of MADs for each epoch 

########## CHANGE MAD_FACTOR
mad_factor = 4 # or 2 to be less strict

median_amplitude = []
first_time = []
peak_latencies = []

for index, epoch in enumerate(epochs.get_data()):
    positive_epoch = abs(epoch)*1000
    median_amplitude = np.median(positive_epoch)
    mad = np.median(np.abs(positive_epoch - median_amplitude))
    threshold = median_amplitude + (mad_factor * mad)
    exceed_index = np.argmax(positive_epoch > threshold)
    if exceed_index > 0:
        first_time = epochs.times[exceed_index] 
    else:
        first_time = None 
        print(f"No exceedance found in epoch {index}") 
    peak_latencies.append(first_time)

# Calculate latencies in milliseconds
peak_latencies_ms = np.array([lat * 1000 if lat is not None else None for lat in peak_latencies])
valid_latencies = peak_latencies_ms[peak_latencies_ms != np.array(None)]
total_events = len(valid_latencies)

# Calculate latency CI in milliseconds
ncies_ms = np.array([lat * 1000 if lat is not None else None for lat in peak_latencies])
valid_latencies = peak_latencies_ms[peak_latencies_ms != np.array(None)]

# Calculate average and confidence interval only for valid latencies
if len(valid_latencies) > 0:
    median_latency_ms = np.median(valid_latencies)
    sem_latency = stats.sem(valid_latencies)  # SEM = std / sqrt(n)
    confidence_level = 0.95
    ci_width = sem_latency * stats.t.ppf((1 + confidence_level) / 2, len(valid_latencies) - 1)
    confidence_interval = (median_latency_ms - ci_width, median_latency_ms + ci_width)
else:
    median_latency_ms = None
    confidence_interval = (None, None)

print("")
print(f"Median Latency: {median_latency_ms:.0f} ms")
print(f"95% Confidence Interval: ({confidence_interval[0]:.3f}, {confidence_interval[1]:.3f})")
print(f"Total Events: {total_events} out of 100")


No exceedance found in epoch 19
No exceedance found in epoch 39
No exceedance found in epoch 59
No exceedance found in epoch 79
No exceedance found in epoch 99

Median Latency: 72 ms
95% Confidence Interval: (67.693, 76.307)
Total Events: 95 out of 100


In [6]:
# Plot individual epochs

########## CHANGE X_VALUES BASED ON WHICH EPOCHS DID NOT EXCEED THE THRESHOLD
x_values = [1, 2, 3, 4, 5]

fig, axes = plt.subplots(len(x_values), 1, figsize=(10, 10), sharex=True, sharey=False)
for i, x in enumerate(x_values):
    epoch = epochs.get_data()[x]
    abs_epoch = abs(epoch[0]) * 1000 
    median_amplitude = np.median(abs_epoch)
    mad = np.median(np.abs(abs_epoch - median_amplitude))
    threshold = median_amplitude + (mad_factor * mad)
    times_in_ms = epochs.times * 1000
    exceed_index = np.argmax(abs_epoch > threshold)
    if exceed_index > 0:
        first_time = epochs.times[exceed_index] * 1000 
    else:
        first_time = None  # No point exceeded the threshold
    
    min_time = min(times_in_ms)
    max_time = max(times_in_ms)
    vertical_lines = np.arange(min_time, max_time, 2)

    axes[i].plot(times_in_ms, abs_epoch)
    axes[i].axvline(x=first_time if first_time is not None else 0, color='black', 
        label=f"Screen Change: {first_time:.2f} ms" if first_time is not None else "Screen Change: None")
    axes[i].axhline(y=threshold, color='r', linestyle='--', label=f"Threshold: {threshold:.2f} mV")
    axes[i].axhline(y=median_amplitude, color='b', linestyle='--', label=f"Median: {median_amplitude:.2f} mV")
    for line in vertical_lines:
        axes[i].axvline(x=line, color='gray', linestyle='--', linewidth=0.5, alpha=0.5)
    axes[i].set_title(f"Epoch {x}")
    axes[i].legend()

plt.xlabel("Time (ms)")
plt.ylabel("Amplitude (mV)")
plt.savefig('Plot Epochs.png')
plt.show()


In [40]:
# Histogram of latency distributions

########## SET BIN SIZE
num_bins = 6 

plt.figure(figsize=(10, 6))
bin_edges = np.linspace(min(valid_latencies), max(valid_latencies), num_bins + 1)
rounded_bin_edges = np.round(bin_edges)
n, bins, patches = plt.hist(valid_latencies, bins=bin_edges, edgecolor='black', linewidth=1.5)
plt.xticks(rounded_bin_edges)
for count, x in zip(n, bins[:-1]):
    plt.text(x + (bins[1]-bins[0])/2, count, str(int(count)), ha='center', va='bottom')
plt.title('Histogram of Latencies')
plt.xlabel('Latency (ms)')
plt.ylabel('Frequency')
plt.savefig('Histogram.png')
plt.show()


In [41]:
# Scatter plot of the latencies

plt.figure(figsize=(10, 6))
plt.scatter(range(len(valid_latencies)), valid_latencies, color='black')
plt.axhline(median_latency_ms, color='red', linestyle='dashed', linewidth=2)
plt.title('Scatter Plot of Latencies')
plt.xlabel('Event')
plt.ylabel('Latency (ms)')
plt.grid(True)
plt.savefig('Scatter Plot.png')
plt.show()
