In [1]:
from scripts.neuron_signal_generator import generate_signal
from scripts.noise_and_filtering import generate_electrode_signal
from scripts.simulator import simulate_recording
from scripts.reduce import dimensional_reduction
from scripts.clustering import get_clusters
from scripts.triangulate_neurons import triangulate_neurons

from classes.sorter import SpikeSorter
import numpy as np
import matplotlib.pyplot as plt
import pprint

# from scripts.spike_extraction import get_waveform_data

In [2]:
# Generate a raw neuron signal and run it through the electrode model once.
# NOTE(review): the return value of generate_electrode_signal is discarded, so this
# call only matters for whatever side effects it has (e.g. plotting) — confirm it is
# still needed.
neuron_signal = generate_signal()

generate_electrode_signal(
    signal=neuron_signal,
    decay_type='square',
    decay_rate=2,
    noise_type='none',
    noise_std=0.5,
    filter_type='bandpass',
    low=500,
    high=3000,
)

# 9x9 placement grid flattened row-major (cell index = row * 9 + col):
# 0 = empty, 1 = neuron, 2 = electrode. These coordinates are the
# true_neuron_positions / all_electrode_positions that triangulate_neurons reports.
placements = [
    2 if divmod(cell, 9) in {(1, 0), (2, 8), (7, 1), (8, 6)}
    else 1 if divmod(cell, 9) in {(3, 3), (5, 5)}
    else 0
    for cell in range(81)
]

neuronParams = {
    "neuron_type": "standard",
    "lambda": 14,
    "v_rest": -70,
    "v_thres": -10,
    "t_ref": 0.02,
    "fix_random_seed": True,
}
processingParams = {
    "decay_type": "square",
    "decay_rate": 2,
    "noise_type": "gaussian",
    "noise_std": 0.5,
    "filter_type": "none",
    "low": 500,
    "high": 3000,
}

# Simulate the full multi-electrode recording for this layout.
time, filtered_signals = simulate_recording(placements, neuronParams, processingParams)

In [3]:
def _to_spike_index_array(spikes):
    """Convert per-channel spike indices into a NumPy array without ragged-array failures.

    ``np.array`` on nested sequences of unequal length emits a
    VisibleDeprecationWarning on NumPy < 1.24 and raises ValueError on
    NumPy >= 1.24. When every channel has the same number of spikes (or the
    input is not nested) the plain ``np.array`` result is returned unchanged;
    otherwise a 1-D object array holding each channel's sequence is built
    explicitly, matching what older NumPy produced implicitly.
    """
    try:
        lengths = {len(channel) for channel in spikes}
    except TypeError:
        # Elements are not sized (flat input): np.array handles it directly.
        return np.array(spikes)
    if len(lengths) <= 1:
        return np.array(spikes)
    ragged = np.empty(len(spikes), dtype=object)
    for i, channel in enumerate(spikes):
        # Element-wise assignment stores each channel's sequence as-is instead
        # of letting NumPy attempt to broadcast it.
        ragged[i] = channel
    return ragged


def get_waveform_data(signals, multiplier, waveform_duration, sample_rate=25000, tolerance=30):
    """Detect spikes on every channel and extract their waveforms.

    Args:
        signals: per-electrode recordings, passed straight to the SpikeSorter.
        multiplier: detection threshold factor (``SpikeSorter(threshold_factor=...)``).
        waveform_duration: waveform window length given to the SpikeSorter
            (presumably milliseconds — confirm against SpikeSorter).
        sample_rate: sampling rate given to the SpikeSorter; defaults to the
            previously hard-coded 25000.
        tolerance: tolerance passed to ``merge_spike_indices`` when merging
            spikes detected on different channels; defaults to the previously
            hard-coded 30.

    Returns:
        ``(waveforms, waveform_info)`` exactly as returned by
        ``SpikeSorter.get_all_waveforms``.
    """
    # initialise the spike sorter
    sorter = SpikeSorter(
        threshold_factor=multiplier,
        sample_rate=sample_rate,
        waveform_duration=waveform_duration,
    )

    # Per-channel spike indices. Channels can detect different numbers of
    # spikes, so the nested list may be ragged.
    spike_indices = _to_spike_index_array(sorter.get_spikes(signals))

    # neuron_spikes is only used to derive true labels for the spike data;
    # no ground truth is available here, so pass an empty list.
    neuron_spikes = []

    # Merge the spike locations detected by all channels (and the true spike
    # positions, when provided).
    merged_spike_indices, true_labels = sorter.merge_spike_indices(
        spike_indices, neuron_spikes, tolerance=tolerance
    )

    # Extract the waveform data for each identified spike across all electrodes.
    waveforms, waveform_info = sorter.get_all_waveforms(
        signals,
        merged_spike_indices,
        recenter=True,
        labels=true_labels,
    )

    return waveforms, waveform_info

In [14]:
# Bundle the recording with its extraction settings. The camelCase keys
# presumably mirror an external request payload — confirm before renaming.
data = {
    "signals": filtered_signals,
    "extractionParams": {"thresholdMultiplier": 4, "waveformDuration": 0.3},
}

signals = data["signals"]
extraction_data = data["extractionParams"]
multiplier = extraction_data["thresholdMultiplier"]
waveform_duration = extraction_data["waveformDuration"]

# Extract the waveforms and convert both results to plain Python lists.
waveforms, waveform_info = (
    result.tolist()
    for result in get_waveform_data(signals, multiplier, waveform_duration)
)

  spike_indices = np.array(spikes)


In [15]:
# Project each extracted waveform onto three PCA components.
model, n_components = "pca", 3

reduced_data = dimensional_reduction(
    model=model,
    n_components=n_components,
    waveforms=waveforms,
)

In [16]:
# Cluster the reduced waveforms ('gmm', presumably a Gaussian mixture model)
# with a manually fixed cluster count.
cluster_type = 'gmm'
k_type = 'manual'
k = 2  # matches the two neurons placed in the simulated grid

labels = get_clusters(cluster_type, k_type, k, reduced_data)

In [20]:
# Estimate neuron positions from the clustered waveforms and electrode layout,
# reusing the decay model configured for the simulation.
construction_dict = triangulate_neurons(
    signals=signals,
    placements=placements,
    labels=labels,
    waveforms=waveforms,
    decay_type=processingParams["decay_type"],
)

ModeResult(mode=array([-16.9009408]), count=array([1]))
-16.900940796131067
ModeResult(mode=array([-16.00244474]), count=array([1]))
-16.002444737629787
ModeResult(mode=array([-17.15931598]), count=array([1]))
-17.159315983836535
ModeResult(mode=array([-20.9938797]), count=array([1]))
-20.993879702330172


  mode_result = mode(electrode_dict[i]["signal"])


In [18]:
# Display the triangulation result: true vs. predicted neuron positions plus per-neuron details.
construction_dict

{'true_neuron_positions': [(3, 3), (5, 5)],
 'all_electrode_positions': [(1, 0), (2, 8), (7, 1), (8, 6)],
 'predicted_neuron_positions': [(3.295159042993193, 3.201920376263164),
  (4.55243793705298, 4.435590712802648)],
 0: {'true_neuron_position': (3, 3),
  'predicted_neuron_position': (3.295159042993193, 3.201920376263164),
  'circles': [((1, 0), 2.790323401526936), ((1, 0), 3.4618741951023213)],
  'used_electrodes': [0, 2, 1],
  'intersecting_lines': [(-6.0, 22.972874634222322),
   (-0.125, 3.613815256637312)]},
 1: {'true_neuron_position': (5, 5),
  'predicted_neuron_position': (4.55243793705298, 4.435590712802648),
  'circles': [((8, 6), 2.7759348940839996), ((8, 6), 2.2101520633313303)],
  'used_electrodes': [3, 1, 2],
  'intersecting_lines': [(3.0, -9.221723098356293),
   (-0.2, 5.3460783002132475)]}}

In [22]:
# Sanity-check the container types and nesting sizes of the inputs to triangulate_neurons.
# (`signals` is the same object as `filtered_signals`, taken from data["signals"].)
print(f"waveforms {type(waveforms)} {len(waveforms)} {len(waveforms[0])} {len(waveforms[0][0])} {type(waveforms[0][0][0])}")
print(f"signals {type(filtered_signals)} {len(filtered_signals)} {len(filtered_signals[0])} {type(filtered_signals[0][0])}")
print(f"labels {type(labels)} {len(labels)} {type(labels[0])}")
print(f"placements {type(placements)} {len(placements)} {type(placements[0])}")

waveforms <class 'list'> 4 16 7 <class 'float'>
signals <class 'list'> 4 25000 <class 'float'>
labels <class 'list'> 16 <class 'int'>
placements <class 'list'> 81 <class 'int'>
