# Example/tutorial of how Meghan's code/ephys analysis works
## Each Phy folder contains 3 files: 
#### **cluster_group.tsv** which is just a list of each neuron from Phy and whether it was classified as good, mua, or noise
#### **spike_times.npy** which is essentially one long array of spike times, regardless of which neuron fired
#### **spike_clusters.npy** which is essentially one long array that aligns with the spike times, and says which neuron it was that fired
##### The data files used to create this notebook are not stored alongside it, so **don't rerun any of these cells** — just read through the code and the saved outputs.

In [39]:
import os
import csv
import math

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from scipy.stats import sem, ranksums, fisher_exact, wilcoxon
from statistics import mean, StatisticsError
from sklearn.decomposition import PCA
from mpl_toolkits.mplot3d import Axes3D
from scipy.spatial.distance import euclidean
from itertools import combinations
from sklearn.metrics import roc_auc_score
from sklearn.ensemble import BaggingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC

# Cluster (neuron) id for every recorded spike: a 1-D array with one entry per spike.
spike_clusters = np.load('spike_clusters.npy')

# Time of every recorded spike, aligned index-for-index with spike_clusters.
# Loaded as an (n_spikes, 1) column, so it is flattened before use further down.
spike_times = np.load('spike_times.npy')

In [3]:
spike_clusters

array([85, 41, 55, ..., 14, 26, 87], dtype=int32)

In [4]:
len(spike_clusters)

483418

In [5]:
spike_times

array([[     689],
       [    1754],
       [    1829],
       ...,
       [68293021],
       [68293168],
       [68293331]])

In [6]:
len(spike_times)

483418

In [8]:
spike_clusters[0]

85

In [9]:
spike_times[0]

array([689])

In [11]:
spike_times.shape

(483418, 1)

In [12]:
spike_clusters.shape

(483418,)

In [16]:
# Collapse the (n_spikes, 1) column of spike times into a flat 1-D array
spike_times2 = spike_times.flatten()

# Maps each cluster (neuron) id to the list of its spike times, in recording order
neuron_spike_dict = {}

# Walk the two aligned arrays in lockstep: the i-th cluster id owns the i-th spike time
for cluster_id, spike_time in zip(spike_clusters, spike_times2):
    # setdefault creates the empty list the first time a cluster id is seen
    neuron_spike_dict.setdefault(cluster_id, []).append(spike_time)

In [22]:
def get_spiketrain(
    timestamp_array, last_timestamp, timebin=1, sampling_rate=20000
):
    """
    bins spike timestamps into a spiketrain of spike counts per timebin

    Bin edges run from 0 up to (but not including) last_timestamp in steps
    of one timebin, expressed in samples. Timestamps are therefore expected
    to be in samples of the recording. Spikes that fall after the final
    edge are not counted (np.histogram ignores out-of-range values), so a
    trailing partial bin is dropped.

    Args (4 total, 2 required):
        timestamp_array: numpy array, spike timestamp array
        last_timestamp: int, upper limit used to build the bin edges
            (the notebook passes the latest spike time of the recording)
        timebin: int, default=1, timebin (ms) of resulting spiketrain
        sampling_rate: int, default=20000, sampling rate
        in Hz of the ephys recording

    Returns (1):
        spiketrain: numpy array, array elements are number
        of spikes per timebin
    """
    # number of samples that make up one timebin (Hz * 0.001 = samples per ms)
    samples_per_bin = int(sampling_rate * 0.001 * timebin)
    bin_edges = np.arange(0, last_timestamp, samples_per_bin)
    spike_counts, _ = np.histogram(timestamp_array, bins=bin_edges)
    return spike_counts

In [23]:
# Latest spike time across all neurons; neurons with no spikes are skipped so
# that max() is never called on an empty list.
last_timestamp = max(max(times) for times in neuron_spike_dict.values() if times)

# Maps each neuron id to its spiketrain (spike counts per 1 ms bin)
neuron_spiketrains = {}

# NOTE: the loop variable is deliberately NOT called `spike_times`, because
# that name would overwrite the spike_times array loaded at the top of the
# notebook and break any cell that is re-run afterwards.
for neuron_id, neuron_spike_times in neuron_spike_dict.items():
    # Convert list of spike times to a numpy array
    spike_times_array = np.array(neuron_spike_times)

    # Every neuron shares the same last_timestamp, so all spiketrains end up
    # with the same number of bins.
    spiketrain = get_spiketrain(spike_times_array, last_timestamp, timebin=1, sampling_rate=20000)

    # Store the spiketrain in the new dictionary
    neuron_spiketrains[neuron_id] = spiketrain

In [28]:
neuron_spiketrains

{85: array([0, 0, 0, ..., 0, 0, 0]),
 41: array([0, 0, 0, ..., 0, 0, 0]),
 55: array([0, 0, 0, ..., 0, 0, 0]),
 3: array([0, 0, 0, ..., 0, 0, 0]),
 58: array([0, 0, 0, ..., 0, 0, 0]),
 17: array([0, 0, 0, ..., 0, 0, 0]),
 1: array([0, 0, 0, ..., 0, 0, 0]),
 46: array([0, 0, 0, ..., 0, 0, 0]),
 24: array([0, 0, 0, ..., 0, 0, 0]),
 68: array([0, 0, 0, ..., 0, 0, 0]),
 65: array([0, 0, 0, ..., 0, 0, 0]),
 47: array([0, 0, 0, ..., 0, 0, 0]),
 22: array([0, 0, 0, ..., 0, 0, 0]),
 9: array([0, 0, 0, ..., 0, 0, 0]),
 34: array([0, 0, 0, ..., 0, 0, 0]),
 98: array([0, 0, 0, ..., 0, 0, 0]),
 94: array([0, 0, 0, ..., 0, 0, 0]),
 38: array([0, 0, 0, ..., 0, 0, 0]),
 66: array([0, 0, 0, ..., 0, 0, 0]),
 67: array([0, 0, 0, ..., 0, 0, 0]),
 2: array([0, 0, 0, ..., 0, 0, 0]),
 14: array([0, 0, 0, ..., 0, 0, 0]),
 33: array([0, 0, 0, ..., 0, 0, 0]),
 35: array([0, 0, 0, ..., 0, 0, 0]),
 37: array([0, 0, 0, ..., 0, 0, 0]),
 23: array([0, 0, 0, ..., 0, 0, 0]),
 13: array([0, 0, 0, ..., 0, 0, 0]),
 82: 

In [30]:
len(neuron_spiketrains[85])

3414666

In [31]:
len(neuron_spiketrains[41])

3414666

In [35]:
def get_firing_rate(spiketrain, smoothing_window=250, timebin=1):
    """
    calculates firing rate (spikes/second)

    The spiketrain is smoothed with a moving average of smoothing_window
    bins and scaled by 1000 / timebin so counts per bin become spikes per
    second.

    Args (3 total, 1 required):
        spiketrain: numpy array, in timebin (ms) bins
        smoothing_window: int, default=250, width of the moving-average
            window, used as a number of spiketrain bins (equal to ms
            when timebin=1)
            min smoothing_window = 1
        timebin: int, default = 1, timebin (ms) of spiketrain

    Return (1):
        firing_rate: numpy array of firing rates in timebin sized windows;
            np.convolve(mode="same") returns max(len(spiketrain),
            smoothing_window) elements, so the result matches the
            spiketrain length whenever the spiketrain is at least as long
            as the window

    """
    # Each weight is 1/window (the average) times 1000/timebin (per-bin -> per-second)
    weights = np.ones(smoothing_window) / smoothing_window * 1000 / timebin
    firing_rate = np.convolve(spiketrain, weights, mode="same")

    return firing_rate

# Smoothed firing rate (spikes/second) for every neuron, keyed by neuron id
firing_rates = {}

for neuron_id in neuron_spiketrains:
    spiketrain = neuron_spiketrains[neuron_id]
    # 250-bin moving-average window over the 1 ms-binned spiketrain
    firing_rate = get_firing_rate(spiketrain, smoothing_window=250, timebin=1)
    firing_rates[neuron_id] = firing_rate


In [36]:
firing_rates

{85: array([4., 4., 4., ..., 0., 0., 0.]),
 41: array([4., 4., 4., ..., 0., 0., 0.]),
 55: array([4., 4., 4., ..., 4., 4., 4.]),
 3: array([0., 0., 0., ..., 4., 4., 4.]),
 58: array([0., 0., 0., ..., 0., 0., 0.]),
 17: array([0., 0., 0., ..., 0., 0., 0.]),
 1: array([0., 0., 0., ..., 0., 0., 0.]),
 46: array([0., 0., 0., ..., 0., 0., 0.]),
 24: array([0., 0., 0., ..., 0., 0., 0.]),
 68: array([0., 0., 0., ..., 0., 0., 0.]),
 65: array([0., 0., 0., ..., 0., 0., 0.]),
 47: array([0., 0., 0., ..., 0., 0., 0.]),
 22: array([0., 0., 0., ..., 8., 8., 8.]),
 9: array([0., 0., 0., ..., 4., 4., 4.]),
 34: array([0., 0., 0., ..., 0., 0., 0.]),
 98: array([0., 0., 0., ..., 4., 4., 4.]),
 94: array([0., 0., 0., ..., 0., 0., 0.]),
 38: array([0., 0., 0., ..., 0., 0., 0.]),
 66: array([0., 0., 0., ..., 0., 0., 0.]),
 67: array([0., 0., 0., ..., 0., 0., 0.]),
 2: array([0., 0., 0., ..., 0., 0., 0.]),
 14: array([ 0.,  0.,  0., ..., 12., 12., 12.]),
 33: array([0., 0., 0., ..., 4., 4., 4.]),
 35: arra

In [37]:
class EphysRecording:
    """
    A class for an ephys recording after being spike sorted and manually
    curated using phy. Ephys recording must have a phy folder.

    Attributes:
        path: str, relative path to the phy folder
            formatted as: './folder/folder/phy'
        sampling_rate: int, sampling rate of the ephys device
            in Hz, standard in the PC lab is 20,000Hz
        labels_dict: dict, keys are unit ids (str) and
            values are labels (str) read from cluster_group.tsv
        timestamps_var: numpy array (1-D), timestamps of every spike
            that belongs to a labelled unit whose label is not 'noise'
        unit_array: numpy array (1-D), unit ids associated with each
            spike in the timestamps_var
        unit_timestamps: dict, keys are unit ids (numpy ints), and
            values are numpy arrays of timestamps for all spikes of that
            unit; every unit left in unit_array is included (i.e. any
            labelled unit that is not 'noise', not only 'good' units)
        zscored_events: dict, created empty in __init__
            (presumably filled by later analysis code -- not set here)
        wilcox_dfs: dict, created empty in __init__
            (presumably filled by later analysis code -- not set here)

    Methods: (all called in __init__)
        get_unit_labels: creates labels_dict
        get_spike_specs: creates timestamps_var and unit_array
        get_unit_timestamps: creates unit_timestamps dictionary
    """

    def __init__(self, path, sampling_rate=20000):
        """
        constructs all necessary attributes for the EphysRecording object
        including creating labels_dict, timestamps_var, and a unit_timestamps
        dictionary

        Arguments (2 total):
            path: str, relative path to the phy folder
                formatted as: './folder/folder/phy'
            sampling_rate: int, default=20000; sampling rate of
                the ephys device in Hz
        Returns:
            None
        """
        self.path = path
        self.sampling_rate = sampling_rate
        self.zscored_events = {}
        self.wilcox_dfs = {}
        # Order matters: labels are needed to filter spikes, and the filtered
        # spikes are needed to build the per-unit timestamp dictionary.
        self.get_unit_labels()
        self.get_spike_specs()
        self.get_unit_timestamps()

    def get_unit_labels(self):
        """
        assigns self.labels_dict as a dictionary
        with unit id (str) as key and label as values (str)
        labels: 'good', 'mua', 'noise'
        (any other label present in the file is stored as-is)

        Arguments:
            None

        Returns:
            None
        """
        labels = "cluster_group.tsv"
        with open(os.path.join(self.path, labels), "r") as f:
            reader = csv.DictReader(f, delimiter="\t")
            self.labels_dict = {
                row["cluster_id"]: row["group"] for row in reader
            }

    def get_spike_specs(self):
        """
        imports spike_time and spike_unit from phy folder
        deletes spikes from units labeled noise, and from units that have
        no entry in labels_dict, in both the unit and timestamp arrays,
        and assigns self.timestamps_var (1-D numpy array)
        as the remaining timestamps and assigns self.unit_array
        (1-D numpy array) as the unit ids associated with each spike

        A message is printed for every unit that has no label.

        Args:
            None

        Returns:
            None
        """
        timestamps = "spike_times.npy"
        unit = "spike_clusters.npy"
        timestamps_var = np.load(os.path.join(self.path, timestamps))
        # type printouts kept so the recorded notebook output still matches
        print(type(timestamps_var))
        unit_array = np.load(os.path.join(self.path, unit))
        print(type(unit_array))

        # Work on flat views so an (n_spikes, 1) column and a 1-D array are
        # handled the same way.
        flat_timestamps = timestamps_var.ravel()
        flat_units = unit_array.ravel()

        # Look every distinct unit up once instead of once per spike.
        unique_units, first_seen, inverse, spike_counts = np.unique(
            flat_units,
            return_index=True,
            return_inverse=True,
            return_counts=True,
        )
        inverse = inverse.reshape(-1)

        drop_unit = np.zeros(len(unique_units), dtype=bool)
        unlabelled = []
        for position, unit_id in enumerate(unique_units):
            key = str(unit_id)  # labels_dict keys are strings
            if key not in self.labels_dict:
                drop_unit[position] = True
                unlabelled.append(position)
            elif self.labels_dict[key] == "noise":
                drop_unit[position] = True

        # Report unlabelled units in the order they first appear in the data
        for position in sorted(unlabelled, key=lambda pos: first_seen[pos]):
            unit_id = unique_units[position]
            no_spike = spike_counts[position]
            print(
                f"Unit {unit_id} is unsorted & has {no_spike} spikes"
            )
            print(
                f"Unit {unit_id} will be deleted"
            )

        # Expand the per-unit decision back to one boolean per spike
        keep_spike = ~drop_unit[inverse]
        self.timestamps_var = flat_timestamps[keep_spike]
        self.unit_array = flat_units[keep_spike]

    def get_unit_timestamps(self):
        """
        creates a dictionary of units to spike timestamps
        keys are unit ids (numpy ints) and values are spike timestamps for
        that unit (numpy arrays, also for units with a single spike)
        and assigns dictionary to self.unit_timestamps

        Keys are inserted in the order each unit first appears in
        self.unit_array.

        Args:
            None

        Return:
            None
        """
        unique_units, first_seen = np.unique(
            self.unit_array, return_index=True
        )

        unit_timestamps = {}
        # argsort on first_seen restores first-appearance order of the units
        for unit_id in unique_units[np.argsort(first_seen)]:
            # boolean mask selects this unit's spikes in one step and always
            # yields an array, even when the unit has only one spike
            unit_timestamps[unit_id] = self.timestamps_var[
                self.unit_array == unit_id
            ]

        self.unit_timestamps = unit_timestamps

In [42]:
yourObject = EphysRecording('./')

<class 'numpy.ndarray'>
<class 'numpy.ndarray'>


In [43]:
unit_labels = yourObject.labels_dict
print(unit_labels)

{'1': 'noise', '2': 'good', '3': 'noise', '5': 'good', '9': 'unsorted', '11': 'noise', '12': 'noise', '13': 'good', '14': 'noise', '17': 'good', '18': 'good', '19': 'good', '20': 'noise', '22': 'mua', '23': 'good', '24': 'good', '26': 'good', '29': 'mua', '30': 'mua', '31': 'mua', '32': 'noise', '33': 'good', '34': 'mua', '35': 'mua', '36': 'good', '37': 'noise', '38': 'mua', '39': 'good', '40': 'noise', '41': 'noise', '42': 'noise', '44': 'mua', '45': 'noise', '46': 'noise', '47': 'noise', '48': 'noise', '53': 'mua', '54': 'noise', '55': 'good', '56': 'mua', '57': 'good', '58': 'mua', '59': 'good', '65': 'good', '66': 'noise', '67': 'noise', '68': 'good', '81': 'good', '82': 'mua', '85': 'good', '87': 'noise', '88': 'good', '94': 'good', '95': 'noise', '98': 'good', '99': 'noise'}


In [44]:
spike_timestamps = yourObject.timestamps_var
unit_ids = yourObject.unit_array


In [45]:
spike_timestamps

array([     689,     1829,     2935, ..., 68292947, 68292976, 68293168])

In [46]:
unit_ids

array([85, 55, 58, ..., 19, 35, 26], dtype=int32)

In [47]:
unit_specific_timestamps = yourObject.unit_timestamps

In [48]:
unit_specific_timestamps

{85: array([     689,     3542,     6757, ..., 67999904, 68154946, 68261070]),
 55: array([    1829,     5260,     6148, ..., 68286888, 68289021, 68291330]),
 58: array([    2935,     3071,    13522, ..., 68275795, 68286245, 68288966]),
 17: array([    3496,     4288,     5675, ..., 68276326, 68276771, 68286188]),
 24: array([    3820,     4215,    24136, ..., 68255449, 68255762, 68255898]),
 68: array([    4036,    68706,    74700, ..., 68159155, 68177080, 68183699]),
 65: array([    4328,     6304,    23963, ..., 68244852, 68245601, 68255391]),
 22: array([    4451,    13081,    20294, ..., 68290504, 68291540, 68292754]),
 9: array([    4767,    89292,   116038, ..., 68254564, 68284836, 68291234]),
 34: array([    4988,     5489,     6344, ..., 68279667, 68283153, 68285848]),
 98: array([    5337,    29951,    80915, ..., 68187930, 68196317, 68292174]),
 94: array([    5368,    13507,    25909, ..., 68217509, 68279774, 68282156]),
 38: array([    5418,     5648,    23411, ..., 682601

In [49]:
yourObject.path

'./'