In [None]:
import pyabf
from scipy.signal import butter, sosfiltfilt
import matplotlib.pyplot as plt
import numpy as np
from gcr_utils import *

%load_ext autoreload
%autoreload 2

In [None]:
# ## Load data
# Post-condition recording from the '07102024_post' folder.
# NOTE(review): exclude=True with channels=[1] presumably drops channel 1 from the
# loaded traces -- confirm against DataLoader.load_data.
# No time crop is applied (a `start_time=0, end_time=50` variant was left disabled).
data_loader = DataLoader()
file_path = (
    '07102024_post/2024_07_10_0014_ well 1 C1-C8_ 2 week_ 0.1mil at center_ no BOE_ 1000mV at mV1.abf'
)
raw_data = data_loader.load_data(
    file_path,
    exclude=True,
    channels=[1],
)

# Electrode that delivered the stimulus; read by the peak-vs-distance and 2-D plots.
stimulation_electrode = 1

In [None]:
# ## Band-pass filter and locate stimulation events
visualizer = Visualizer()
# NOTE(review): `filter` shadows the Python built-in; kept because later cells use this name.
filter = Filter()

# Band-pass with lowcut=300 / highcut=2500 (presumably Hz), order 4.
data = filter.bandpass_filter(
    raw_data,
    lowcut=300,
    highcut=2500,
    order=4,
)
visualizer.multi_channel_plot(
    data,
    exclude=True,
    channels=[],
    spikes=None,
    stimulation_times=None,
)

# Stimulation events: reuse the threshold detector with a very high threshold (100),
# presumably so that only large stimulation artefacts cross it.
spike_detector = SpikeDetector()
stimulation_dict = spike_detector.threshold_detection(
    data=data,
    thresholds=[100, 100],
    spike_time=1000,
    min_consecutive_time=0,
)
# Presumably merges detections that co-occur (tolerance=0.01) on at least ~1/3 of
# channels into shared stimulation times -- confirm in find_stimulation_times.
stimulation_times = find_stimulation_times(
    stimulation_dict,
    tolerance=0.01,
    channel_fraction=0.33,
)
print("Stimulation times:", stimulation_times)

In [None]:
# ## Remove stimulation artefacts and re-centre the traces
# NOTE(review): each step below overwrites `data`, so re-running this cell by itself
# applies the filters a second time; re-run from the band-pass cell instead.
data = filter.stimulation_filter(data, stimulation_times)
visualizer.multi_channel_plot(
    data, exclude=True, channels=[], spikes=None, stimulation_times=stimulation_times
)

# Bring each channel to zero mean across time.
data = filter.temporal_zeroing(data)
visualizer.multi_channel_plot(
    data, exclude=True, channels=[], spikes=None, stimulation_times=stimulation_times
)

# Optional, currently disabled: `data = filter.interchannel_zeroing(data)` zero-means
# across channels to suppress errant shared fluctuations; re-plot afterwards if enabled.

In [None]:
# ## Spike detection and post-detection processing
# Threshold detection on the cleaned traces; the units of thresholds / spike_time /
# min_consecutive_time are whatever SpikeDetector expects -- confirm there.
spike_detector = SpikeDetector()
spikes = spike_detector.threshold_detection(
    data=data,
    thresholds=[4, 4],
    spike_time=2,
    min_consecutive_time=0.2,
)
visualizer.multi_channel_plot(
    data, exclude=True, channels=[], spikes=spikes, stimulation_times=stimulation_times
)

pdp = PostDetectionProcessing()

# Discard obvious statistical anomalies among the detections.
spikes = pdp.statistical_elimination(
    data=data,
    spike_train=spikes,
    window=2,
)
visualizer.multi_channel_plot(
    data, exclude=True, channels=[], spikes=spikes, stimulation_times=stimulation_times
)

# Align every spike on its minimum, with 4x interpolation.
# Alternative (disabled): pdp.multi_point_align_spikes(data=data, spike_train=spikes,
#   window=2, interpolation_factor=4, max_shift=10, n_peaks=3, n_bins=50)
spikes = pdp.align_spikes(
    data=data,
    spike_train=spikes,
    alignment='min',
    window=2,
    interpolation_factor=4,
)
visualizer.multi_channel_overlay_spikes(data, spikes=spikes, window=2)


In [None]:
# ## Spike sorting
# Reduce the per-spike windows (window=2) to 3 components with t-SNE, then split
# them into 2 clusters.
# NOTE(review): confirm whether spike_sorting_pipeline also removes PCA outliers;
# this call only requests t-SNE reduction plus clustering.
# NOTE(review): t-SNE is stochastic -- unless the pipeline fixes a random seed,
# features/clusters may differ between runs.
sorter = SpikeSorter()
features, clusters = sorter.spike_sorting_pipeline(
    data=data,
    spikes=spikes,
    window=2,
    n_components=3,
    n_clusters=2,
    dimensionality_reduction_method='tsne',
    cluster=True,
)

In [None]:
# Show the reduced spike features returned by the sorter together with their cluster labels.
visualizer.plot_representation(
    features=features,
    clusters=clusters,
)

In [None]:
# ## Spike-train smoothing
stats = SpikeStatistics()

# Gaussian kernel with sigma = 15 (labelled "ms" by the author -- confirm the unit
# SpikeStatistics.gaussian_kernel expects). The length covers +/-3 sigma plus the centre.
sigma = 15
size = int(2 * 3 * sigma + 1)
gaussian_kernel = stats.gaussian_kernel(size, sigma)

# NOTE(review): count_kernel is not read by any visible cell.
count_kernel = stats.count_kernel(100)

# Smooth ("decode") the spike trains with the Gaussian kernel over time_window=[0, 50].
decoded_train = stats.decode(spikes, gaussian_kernel, time_window=[0, 50])

In [None]:
# ## Peri-stimulus view of the smoothed trains
# Window around the FIRST detected stimulation: from 0.1 before to 0.2 after it,
# in the same time units as stimulation_times (presumably seconds).
# Fail with an actionable message instead of an opaque IndexError when the
# stimulation detector found nothing.
if len(stimulation_times) == 0:
    raise ValueError(
        "No stimulation events were detected; adjust the stimulation threshold "
        "or find_stimulation_times parameters before plotting."
    )
plot_window = np.array(stimulation_times[0]) + np.array([-.1, 0.2])
print(plot_window)
visualizer.multi_channel_plot(
    decoded_train,
    exclude=True,
    channels=[],
    spikes=None,
    stimulation_times=stimulation_times,
    time_window=list(plot_window),
)

In [None]:
# ## Firing frequency after each stimulation
frequencies = stats.calculate_frequency_after_stimulation(
    spike_train=spikes,
    time_window=[0, 50],
    stimulation_times=stimulation_times,
    window_size=0.2,
    spike_window=1,
)

# Keys of `frequencies` identify channels; each entry carries the per-stimulation
# frequencies and their average.
for channel in frequencies:
    print('Channel', channel)
    print(frequencies[channel]['frequencies_per_stimulation'])

for channel in frequencies:
    print('Average Frequency for Channel', channel)
    print(frequencies[channel]['average_frequency'])

In [None]:
# ## Peak timing vs. distance from the stimulation electrode
# Electrode layout (positions) from YAML. A dedicated name is used so the recording
# path held in `file_path` by the load cell is not silently overwritten.
electrode_layout_path = '8_gcr.yaml'
electrode_data = load_yaml(electrode_layout_path)

# get_peak_times returns two results; only the second (presumably peak times relative
# to each stimulation -- confirm) is plotted here.
absolute_peak_times, peak_times = stats.get_peak_times(decoded_train, stimulation_times, window_size=0.2)
visualizer.plot_peak_by_distance(peak_times, electrode_data, stimulation_electrode)

In [None]:
# ## Electrode layout
# Draw the 2-D electrode positions loaded from the layout YAML.
visualizer.plot_2d_electrodes(
    electrode_data,
)

In [None]:
# Animated 2-D view of post-stimulation frequency on the electrode layout,
# limited to plot_window around the first stimulation.
# NOTE(review): only stimulation_times[0] (a single event) is passed here, while the
# other plots receive the full list -- confirm this is intended.
visualizer.show_frequency_after_stim_2D(
    electrode_data,
    decoded_train,
    time_window=plot_window,
    stimulation_times=stimulation_times[0],
    stimulation_electrode=stimulation_electrode,
)

In [None]:
# ## Pre vs. post comparison
# Compare the '07102024_pre' and '07102024_post' recording folders.
df = pre_and_post_analysis(
    '07102024_pre',
    '07102024_post',
)

In [None]:
# Plot the pre/post comparison table produced above (presumably frequency vs.
# stimulation amplitude -- confirm in Visualizer).
visualizer.plot_frequencies_over_amplitudes(
    df,
)