<font size=7 face="courier">Organoid Training Source Code</font>

This is the code used to create the diagrams in the notebook, `Organoid_Training_Hypothesis.ipynb`.

In [None]:
# Print which helper notebook is being executed, so its output is identifiable when it is run.
print("Loading: organoid_training_source_code.ipynb...")

# Set Up Notebook

Import the packages used throughout this notebook.

In [None]:
#!pip install powerlaw
#!pip install scipy

#!pip install smart_open
#!pip install awswrangler
#!pip install deprecated
#!pip install nptyping

#!pip install pygame
#!pip install gymnasium
#!pip install ipywidgets

In [None]:
#!pip install gymnasium
#!pip install pygame

Import Packages

In [1]:
from food_land import FoodLandEnv
import numpy as np
import csv
import sys
import time
import gymnasium as gym
from datetime import datetime
import pygame
import pandas as pd
import matplotlib.pyplot as plt

pygame 2.5.2 (SDL 2.28.2, Python 3.10.0)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
import pytz
import braingeneers

from braingeneers.analysis.analysis import SpikeData, read_phy_files
from braingeneers.analysis import load_spike_data
import braingeneers.data.datasets_electrophysiology as ephys
import scipy.io as sio
import scipy
from scipy.ndimage import gaussian_filter1d
import glob
import pandas as pd
from ipywidgets import interact, interactive, fixed, interact_manual
import random
#import matplotlib.patches as mpatches

In [3]:
# pytz time-zone object for US Pacific time.
# NOTE(review): `Timezone` is not referenced by any code visible in this notebook — confirm it is still needed.
Timezone = pytz.timezone("America/Los_Angeles")

# Organoid Training Functions

In [4]:
def compute_smart_action(observation, threshold=0.1, rng=None):
    """Pick a random heuristic action from the current food "smell" signal.

    The observation layout (per the environment) is::

        observation = np.array([food_signal, spike_signal,
                                agent_pos_x, agent_pos_y,
                                agent_dir], dtype=np.float32)

    Only ``observation[0]`` (the food signal) is read; the spike signal is
    currently ignored. When the food signal is above ``threshold`` the food is
    treated as close: the agent turns a lot and moves slowly (exploit).
    Otherwise it mostly goes straight and may move fast (explore).

    Parameters
    ----------
    observation : array-like
        Environment observation; only element 0 is used.
    threshold : float, optional
        Food-signal level above which the "close" behaviour is used (default 0.1).
    rng : numpy.random.Generator, optional
        Source of randomness. ``None`` (default) uses the global ``np.random``
        state, as the original implementation did.

    Returns
    -------
    numpy.ndarray
        float32 array of shape (2,):
        ``action[0]`` in [-1, 1] is the turning speed (CCW to CW),
        ``action[1]`` in [0, 1] is the moving speed (stop to forwards).
    """
    uniform = np.random.uniform if rng is None else rng.uniform
    food_signal = observation[0]
    action = np.zeros(2, dtype=np.float32)

    if food_signal > threshold:
        # Food is close: turn hard and move slowly to exploit the area.
        action[0] = uniform(-1, 1)
        action[1] = uniform(.1, .5)
    else:
        # Food is far: keep roughly straight and allow fast movement to explore.
        action[0] = uniform(-.3, .3)
        action[1] = uniform(0, 1)

    return action

In [5]:
def action_analysis(env, n_steps=10000, render=True):
    """Run ``compute_smart_action`` in ``env`` and plot rewards, trajectory and signals.

    Episodes are run back to back. The loop only stops at the end of an episode,
    the first time an episode ends after at least ``n_steps`` total steps, so
    every entry in the reward plot is a complete episode.

    Parameters
    ----------
    env : environment object
        Must provide ``reset()`` returning an observation,
        ``step(action)`` returning ``(observation, reward, done, info)``, and
        ``render()``. NOTE(review): this is the 4-tuple step API; a gymnasium
        0.26+ env returns 5 values from ``step`` and ``(obs, info)`` from
        ``reset`` — confirm FoodLandEnv's API.
    n_steps : int, optional
        Minimum total number of steps before stopping (default 10000).
    render : bool, optional
        Call ``env.render()`` after every step (default True).

    Returns
    -------
    int
        Index of the last completed episode (number of episodes minus one).
    """
    observations = []
    rewards = []
    cur_reward = 0
    trial_num = 0
    step = 1

    observation = env.reset()

    while True:
        # Other policies that were tried here: no action
        # (np.zeros(2, dtype=np.float32)) or a uniform random action
        # (np.random.uniform(-1, 1, 2)).
        action = compute_smart_action(observation)

        observation, reward, done, _info = env.step(action)
        observations.append(observation)
        cur_reward += reward

        if render:
            env.render()

        if done:
            rewards.append(cur_reward)
            if step >= n_steps:
                break
            cur_reward = 0
            # Act on the new episode's initial observation, not on the terminal
            # observation of the episode that just ended.
            observation = env.reset()
            trial_num += 1

        step += 1

    # Rows of `observations` follow the observation layout:
    # 0 food signal, 1 spike signal, 2 x, 3 y, 4 direction.
    observations = np.array(observations).T

    fig, ax = plt.subplots(3, 1, figsize=(15, 15))
    ax[0].plot(rewards)
    ax[0].set_xlabel('Episode')
    ax[0].set_ylabel('Reward')
    ax[0].set_title('Reward per Episode')

    ax[1].plot(observations[2], observations[3])
    ax[1].set_xlabel('x')
    ax[1].set_ylabel('y')
    ax[1].set_title('Agent Position')

    ax[2].plot(observations[0], label='Food Signal')
    ax[2].plot(observations[4], label='Agent Direction')
    ax[2].set_xlabel('Step')
    ax[2].set_ylabel('Signal')
    ax[2].set_title('Food Signal and Agent Direction')
    ax[2].legend()

    plt.show()
    return trial_num

In [6]:
def plot_CC(organoid):
    """Draw the organoid's causal connectivity matrix as a green heat map.

    Both axes are labelled with ``organoid.neuron_labels`` (``organoid.N``
    ticks). Draws on the current pyplot axes; the figure is not shown here.
    """
    connectivity_matrix = organoid.get_dummy_causal_connectivity()
    tick_positions = np.arange(organoid.N)

    plt.imshow(connectivity_matrix, cmap='Greens')
    plt.colorbar()
    # One tick per neuron, labelled with its role/index label.
    plt.xticks(tick_positions, labels=organoid.neuron_labels, rotation=45)
    plt.yticks(tick_positions, labels=organoid.neuron_labels)
    plt.title('Causal Connectivity Matrix')

In [7]:
def plot_spike_count(organoid, n_samples=100):
    """Plot simulated motor-neuron spike counts over repeated reads.

    Calls ``organoid.get_fake_motor_spikes()`` ``n_samples`` times, using the
    organoid's current sensory stimulation, and plots one line per motor neuron.

    Parameters
    ----------
    organoid : OrganoidInterface-like object
        Must provide ``get_fake_motor_spikes()``, ``motor_neurons`` and
        ``neuron_labels``.
    n_samples : int, optional
        Number of reads to gather and plot (default 100).
    """
    spikes = np.array([organoid.get_fake_motor_spikes() for _ in range(n_samples)])

    plt.plot(spikes)
    plt.xlabel('Time')
    plt.ylabel('Spike Count')
    plt.title('Motor Neuron Spike Count over Time')
    # One line is drawn per motor neuron, so use those neurons' own labels
    # rather than the first len(motor_neurons) labels of all neurons.
    plt.legend([organoid.neuron_labels[i] for i in organoid.motor_neurons])

# Toy interface class

A fake organoid interface class to help with testing. It takes inputs and gives outputs in nearly the same way an actual organoid interface would.

In [8]:
# Scratch sample of 100 rows x 3 columns drawn uniformly from [-1, 1).
# NOTE(review): `q` is not used by the visible code.
q = np.random.uniform(low=-1, high=1, size=(100, 3))
# q[:5]

In [9]:
# Fake class to test the method
class OrganoidInterface:
    """Toy organoid interface that mimics the real closed-loop training protocol.

    Inputs are taken and outputs given in nearly the same way as with a real
    organoid. The hardware is replaced by ``DummyEnv``, and motor spikes are
    generated from a random connectivity matrix (``get_fake_motor_spikes``).

    ``run`` cycles through four states until ``n_episodes`` game episodes end:
    'read' (read motor spikes) -> 'game' (step the game, update stimulation)
    -> back to 'read'. When a game episode ends it goes to 'train' (send a
    training pulse, reset the game) -> 'wait' -> 'read'.

    The three mapping functions raise NotImplementedError until they are set
    with ``set_sensory_function``, ``set_motor_function`` and
    ``set_training_function``.
    """

    def __init__(self, game_env, verbose=False, test_file_name='test', n_episodes=10):
        """Build the fake interface.

        Parameters
        ----------
        game_env : object
            Game environment. ``run`` calls ``game_env.step(action)`` expecting
            ``(obs, reward, done, info)`` and ``game_env.reset()`` expecting ``obs``.
        verbose : bool
            Print what each phase is doing.
        test_file_name : str
            Prefix of the ``_game_log.csv`` and ``_train_log.csv`` files written by ``run``.
        n_episodes : int
            ``run`` returns once this many game episodes have ended.
        """

        class DummyEnv:
            """Stand-in for the Maxwell hardware interface; every call is a no-op."""
            def __init__(self):
                self.save_file = None
                self.reset = lambda : None
                # step takes (action, tag) and returns (obs, done) -- always (None, None) here
                self.step = lambda action=None, tag=None : (None, None)

        self.game_env = game_env
        # Env is dummy for now
        self.env = DummyEnv() # Dummy maxwell interface environment

        self.save_file = test_file_name
        self.sensory_neurons = np.arange(3,10) # Max 10 sensory neurons to stimulate
        self.sensory_stim_Hz = np.zeros(len(self.sensory_neurons)) # Rate of stimulation

        self.training_neurons = np.arange(5,12) # Max 10 training neurons to stimulate

        self.motor_neurons = np.arange(3)# Indices of motor neurons
        self.motor_spike_count = np.zeros(len(self.motor_neurons)) # Stores the spike count for each motor neuron
        self.motor_spike_rate = np.zeros(len(self.motor_neurons)) # Stores the spike rate for each motor neuron

        self.verbose = verbose

        # Random binary raster with roughly 10% ones.
        # NOTE(review): not used by the visible code.
        self.fake_spikes = np.random.uniform(0, 1, (1000,3))
        self.fake_spikes[self.fake_spikes < 0.9] = 0
        self.fake_spikes[self.fake_spikes > 0] = 1
        self.state = 'read'

        self.read_period_ms = 200 # Time between simulation steps
        self.train_period_ms = 200 # Time to give training signal
        self.wait_period_ms = 1000 # Time to wait after training

        self.rewards = []
        self.episode = 0
        self.n_episodes = n_episodes
        self.max_time = 1000

        # Build the connectivity matrix. It must be None first so that a new one
        # is generated; this call also sets self.N.
        self.connectivity = None
        self.connectivity = self.get_dummy_causal_connectivity()

        # Label each neuron with its role letters followed by its index,
        # e.g. 'M0' (motor) or 'ST5' (sensory + training).
        self.neuron_labels = []
        for i in range(self.N):
            cur_label = ''
            if i in self.sensory_neurons:
                cur_label += 'S'
            if i in self.training_neurons:
                cur_label += 'T'
            if i in self.motor_neurons:
                cur_label += 'M'
            cur_label += str(i)
            self.neuron_labels.append(cur_label)

    def time_elapsed(self):
        """Seconds since ``run`` started (uses ``time.perf_counter``)."""
        return time.perf_counter() - self.start_time

    def run(self):
        """Run the read/game/train/wait loop until ``n_episodes`` episodes end.

        Writes one row per game step to ``<save_file>_game_log.csv`` and one
        row per training pulse to ``<save_file>_train_log.csv``. Both files are
        closed on every exit path.

        Returns
        -------
        None when the episode limit is reached. True if the loop ends because
        ``self.env.step`` reported done.
        """
        done = False
        just_entered = True
        total_reward = 0
        do_training = True

        # The with-block closes (and flushes) both logs on every exit path,
        # including the early return once n_episodes is reached.
        # newline='' is what the csv module expects for files it writes.
        with open(self.save_file + '_game_log.csv', 'w', newline='') as game_log, \
                open(self.save_file + '_train_log.csv', 'w', newline='') as train_log:
            # ---- Game log: one row per game step ----
            game_logger = csv.writer(game_log)
            game_logger.writerow(['time','food_signal', 'spike_signal', 'agent_pos_x', 'agent_pos_y',
                                  'agent_dir', 'food_got', 'spike_hit', 'reward',
                                  'episode', 'moving_speed', 'turning_speed'])

            # ---- Training log: one row per training pulse ----
            train_logger = csv.writer(train_log)
            train_logger.writerow(['time','pattern','reward', 'episode'])

            self.start_time = time.perf_counter()

            # ================== Main Loop ==================
            while not done:
                # ~~~~~~~~~~~~~~~~~~~ Read phase Logic ~~~~~~~~~~~~~~~~~~~
                if self.state == 'read':
                    if just_entered:
                        just_entered = False

                    # Sensory stimulation would be delivered here on real hardware.
                    if self.verbose:
                        print("Stimming sensory neurons with", self.sensory_stim_Hz)

                    # Spike read
                    self.motor_spike_count = self.get_fake_motor_spikes()

                    # Exit read phase (on hardware: after read_period_ms)
                    if self.verbose:
                        print('Spike count: ', self.motor_spike_count, f'Rate {self.motor_spike_rate[0]:.2f} || {self.motor_spike_rate[1]:.2f}')
                    self.state = 'game'
                    just_entered = True
                    continue

                # ~~~~~~~~~~~~~~~~~~~ Game phase Logic ~~~~~~~~~~~~~~~~~~~
                elif self.state == 'game':
                    if just_entered:
                        game_action = self.get_motor_signal()
                        self.game_obs, reward, game_done, inf = self.game_env.step(game_action)

                        total_reward += reward

                        # Convert the game observation to sensory stimulation,
                        # then regenerate motor spikes from it.
                        self.set_sensory_signal(self.game_obs)
                        self.cause_fake_motor_spikes()

                        # action[0] is the turning speed and action[1] the moving
                        # speed (the convention documented in compute_smart_action),
                        # so write them in the header's 'moving_speed',
                        # 'turning_speed' order.
                        # NOTE(review): assumes game_obs has the 7 fields named in the header.
                        game_logger.writerow([time.time(), *self.game_obs, reward, self.episode,
                                              game_action[1], game_action[0]])

                        # Episode over -> train phase, otherwise keep reading.
                        self.state = 'train' if game_done else 'read'
                        just_entered = True
                        continue

                # ~~~~~~~~~~~~~~~~~~~ Train phase Logic ~~~~~~~~~~~~~~~~~~~
                elif self.state == 'train':
                    if just_entered:
                        just_entered = False
                        self.rewards.append(total_reward)
                        do_training = True # Train 100% of the time

                        self.episode += 1
                        if self.episode >= self.n_episodes:
                            return  # leaving the with-block closes both logs

                        if self.verbose:
                            print('Reward:', total_reward)
                            print('All rewards:', self.rewards)
                            print('-'*20)
                        # ================== Set Training pulses ==================
                        if do_training:
                            train_pulse, train_freq = self.get_training_signal()
                            # This would modify the connectivity:
                            # self.update_connectivity(train_pulse)

                            # Note: this overwrites `done` with the (dummy) env's flag.
                            obs, done = self.env.step(action=train_pulse, tag='train')
                            train_logger.writerow([self.time_elapsed(), train_pulse, total_reward, self.episode])
                            continue

                    # ================== Training logic ==================
                    # On hardware the training pulse would be stimmed at train_freq here.
                    if self.verbose:
                        print("Stimming training neurons with", train_freq)
                        print("Training pulse:", train_pulse)
                        print("\tfor", self.train_period_ms, "ms")

                    # ================== Exit train phase ==================
                    # (on hardware: after train_period_ms)
                    self.game_obs = self.game_env.reset()
                    self.set_sensory_signal(self.game_obs)
                    total_reward = 0
                    self.state = 'wait'
                    just_entered = True
                    continue

                # ~~~~~~~~~~~~~~~~~~~ Wait phase Logic ~~~~~~~~~~~~~~~~~~~
                elif self.state == 'wait':
                    # On hardware this would wait wait_period_ms.
                    obs, done = self.env.step()
                    if self.verbose:
                        print("Waiting for", self.wait_period_ms, "ms")
                        print("Episode done")
                    self.state = 'read'
                    just_entered = True
                    continue

        return True

    def get_fake_motor_spikes(self):
        """Simulate one read of motor-neuron spike counts from the sensory stimulation.

        counts = clip(int((sensory_stim_Hz @ C[sensory, motor] * 2) ** 0.4) + noise, 0, None),
        where noise is a random integer in [-2, 2] per motor neuron.

        The counts are also stored in ``self.motor_spike_rate``;
        ``cause_fake_motor_spikes`` rescales that afterwards.

        Returns
        -------
        numpy.ndarray of int, shape (len(motor_neurons),)
        """
        # connectivity is (N, N); keep rows = sensory neurons, columns = motor neurons
        sub_connectivity = self.connectivity[self.sensory_neurons][:,self.motor_neurons]
        # stimulation-weighted sum onto each motor neuron, scaled by 2
        motor_spikes = np.dot(self.sensory_stim_Hz, sub_connectivity)*2

        # compress with a 0.4 power
        # NOTE(review): a negative stimulation rate would produce nan here
        motor_spikes = motor_spikes ** 0.4

        # cast to int
        motor_spikes = motor_spikes.astype(int)
        # add integer noise in [-2, 2]
        motor_spikes += np.random.randint(-2, 3, len(self.motor_neurons))
        # spike counts cannot be negative
        motor_spikes = np.clip(motor_spikes, 0, None)
        self.motor_spike_rate = motor_spikes
        return motor_spikes

    def cause_fake_motor_spikes(self):
        """Regenerate motor spike counts and set the rate to count / 1000."""
        self.motor_spike_count = self.get_fake_motor_spikes()
        self.motor_spike_rate = self.motor_spike_count / 1000

    def get_dummy_causal_connectivity(self):
        """Return the (N, N) connectivity matrix, creating a random one on first use.

        On creation this also sets ``self.N``. Entries are uniform in [0, 0.1),
        except up to ``n_strong`` (10) random (row, col) positions, which are
        uniform in [0.6, 1). Positions may repeat, so there can be fewer than 10.
        """
        if self.connectivity is not None:
            return self.connectivity
        n_strong = 10
        # Sensory and training neurons can overlap, so count them once.
        unique_neurons = np.unique(np.concatenate([self.sensory_neurons, self.training_neurons]))
        # NOTE(review): assumes motor neurons do not overlap the sensory/training
        # set and that every index is < N (true for the defaults: 9 + 3 = 12).
        self.N = len(unique_neurons) + len(self.motor_neurons)
        # Weak background connections in [0, 0.1) ...
        connectivity = np.random.uniform(0, .1, (self.N, self.N))
        # ... plus a few strong connections in [0.6, 1).
        connectivity[np.random.randint(0, self.N, n_strong), np.random.randint(0, self.N, n_strong)] = np.random.uniform(.6, 1, n_strong)
        self.connectivity = connectivity
        return connectivity

    def set_sensory_signal(self, obs):
        """Map a game observation to sensory stimulation (set via set_sensory_function)."""
        raise NotImplementedError

    def get_motor_signal(self):
        """Map motor activity to a game action (set via set_motor_function)."""
        raise NotImplementedError

    def get_training_signal(self):
        """Return (train_pulse, train_freq) (set via set_training_function)."""
        raise NotImplementedError

    def set_sensory_function(self, func):
        '''
        Set the function to use to map the game observation to the sensory neurons.
        ``func(interface, game_env_obs)`` is called with this interface as first argument.
        '''
        self.set_sensory_signal = lambda game_env_obs: func(self, game_env_obs)

    def set_motor_function(self, func):
        '''
        Set the function to use to map the motor neurons to the action.
        ``func(interface)`` must return the game action.
        '''
        self.get_motor_signal = lambda : func(self)

    def set_training_function(self, func):
        '''
        Set the function that produces the training stimulation.
        ``func(interface)`` must return ``(train_pulse, train_freq)``.
        '''
        self.get_training_signal = lambda : func(self)


In [None]:
#organoid.connectivity.shape

# Playback plots

In [10]:
def game_plots(glog):
    """Plot agent trajectories and the total reward of each episode from a game log.

    Parameters
    ----------
    glog : pandas.DataFrame
        Game log with one row per game step, as written by
        ``OrganoidInterface.run``. Needs the 'episode', 'agent_pos_x',
        'agent_pos_y' and 'reward' columns.
    """
    fig, ax = plt.subplots(2, 1, figsize=(10,9))

    # Agent position: one line (and colour) per episode.
    for episode in glog['episode'].unique():
        ep_log = glog[glog['episode'] == episode]
        ax[0].plot(ep_log['agent_pos_x'], ep_log['agent_pos_y'], label=f'Episode {episode}')
    ax[0].set_xlabel('x')
    ax[0].set_ylabel('y')
    ax[0].set_title('Agent Position')

    # The log stores the reward of every step. Sum it per episode so the
    # panel shows what its title and x label promise, matching action_analysis.
    episode_rewards = glog.groupby('episode')['reward'].sum()
    ax[1].plot(episode_rewards.index.to_numpy(), episode_rewards.to_numpy())
    ax[1].set_xlabel('Episode')
    ax[1].set_ylabel('Reward')
    ax[1].set_title('Reward per Episode')

    plt.show()

In [11]:
def speed_plots(glog):
    """Plot the logged turning speed and forward speed against game step.

    ``glog`` is a game-log DataFrame with 'turning_speed' and 'moving_speed'
    columns (one row per step). Each speed gets its own panel.
    """
    fig, axes = plt.subplots(2, 1, figsize=(10,9))

    panels = [
        (axes[0], 'turning_speed', 'Turn Speed'),
        (axes[1], 'moving_speed', 'Forward Speed'),
    ]
    for panel_ax, column, name in panels:
        series_values = np.array(glog[column], dtype=object)
        panel_ax.plot(series_values, label=name)
        panel_ax.set_xlabel('Step')
        panel_ax.set_ylabel(name)
        panel_ax.set_title(name)
        panel_ax.legend()

    plt.show()

# Load Data

## <font color="blue">Data Loader</font>

<font color="red"><b>Warning:</b> This code gets very messy. I do not recommend reading it. You do not need to know it for the HW.</font>

In [12]:
import io
import zipfile
from typing import List, Tuple

def read_phy_files(path: str, fs=20000.0):
    """
    Load a zipped phy/kilosort spike-sorting output into a braingeneers SpikeData object.

    This local definition replaces the ``read_phy_files`` imported from braingeneers above.

    :param path: a s3 or local path to a zip of phy files.
    :param fs: sampling rate in Hz used to convert spike times to ms; replaced by the
            ``sample_rate`` value from params.py when that line is present.
    :return: SpikeData class with a list of spike time lists (ms), neuron_data,
            metadata and neuron_attributes. Entry i of the spike trains, of
            neuron_dict and of neuron_attributes all describe the same cluster.
            neuron_data = {0: neuron_dict}
            metadata = {0: config_dict}
            neuron_dict = {i: {"cluster_id": c, "channel": c, "position": (x, y),
                            "amplitudes": [a0, a1, an], "template": [t0, t1, tn],
                            "neighbor_channels": [c0, c1, cn],
                            "neighbor_positions": [(x0, y0), (x1, y1), (xn,yn)],
                            "neighbor_templates": [[t00, t01, t0n], [tn0, tn1, tnn]]}}
            config_dict = {chn: pos}
            Units get ``label=None`` when the zip has no cluster_info.tsv.
    :raises AssertionError: if ``path`` is not a .zip or the zip has no params.py.
    """
    assert path[-3:] == 'zip', 'Only zip files supported!'
    import braingeneers.utils.smart_open_braingeneers as smart_open
    with smart_open.open(path, 'rb') as f0:
        f = io.BytesIO(f0.read())

        with zipfile.ZipFile(f, 'r') as f_zip:
            assert 'params.py' in f_zip.namelist(), "Wrong spike sorting output."
            with io.TextIOWrapper(f_zip.open('params.py'), encoding='utf-8') as params:
                for line in params:
                    if "sample_rate" in line:
                        fs = float(line.split()[-1])
            clusters = np.load(f_zip.open('spike_clusters.npy')).squeeze()
            channels = np.load(f_zip.open('channel_map.npy')).squeeze()
            # Whitened templates, (cluster_id, samples, channel_id); loaded once.
            templates_w = np.load(f_zip.open('templates.npy'))
            wmi = np.load(f_zip.open('whitening_mat_inv.npy'))
            spike_templates = np.load(f_zip.open('spike_templates.npy')).squeeze()
            spike_times = np.load(f_zip.open('spike_times.npy')).squeeze() / fs * 1e3  # in ms
            positions = np.load(f_zip.open('channel_positions.npy'))
            amplitudes = np.load(f_zip.open("amplitudes.npy")).squeeze()
            if 'cluster_info.tsv' in f_zip.namelist():
                cluster_info = pd.read_csv(f_zip.open('cluster_info.tsv'), sep='\t')
                cluster_id = np.array(cluster_info['cluster_id'])
                # select clusters using curation label, remove units labeled as "noise"
                labeled_clusters = cluster_id[cluster_info['group'] != "noise"]
                # curation label of every cluster, attached to its NeuronAttributes
                cluster_labels = dict(zip(cluster_info['cluster_id'], cluster_info['group']))
            else:
                labeled_clusters = np.unique(clusters)
                # no curation file, so there are no labels to attach
                cluster_labels = {}

    df = pd.DataFrame({"clusters": clusters, "spikeTimes": spike_times, "amplitudes": amplitudes})
    cluster_agg = df.groupby("clusters").agg({"spikeTimes": lambda x: list(x),
                                              "amplitudes": lambda x: list(x)})
    cluster_agg = cluster_agg[cluster_agg.index.isin(labeled_clusters)]

    # Walk the clusters in cluster_agg's index order. That is the order of the
    # spike trains passed to SpikeData, so attribute i always matches train i.
    # It also only contains clusters that actually have spikes, so the lookups
    # below cannot raise KeyError.
    kept_clusters = cluster_agg.index.to_numpy()

    # template index used by each cluster's spikes
    cls_temp = dict(zip(clusters, spike_templates))
    neuron_dict = dict.fromkeys(np.arange(len(kept_clusters)), None)

    # un-whiten the templates before finding the best channel
    templates = np.dot(templates_w, wmi)

    neuron_attributes = []
    for i, c in enumerate(kept_clusters):
        temp = templates[cls_temp[c]]  # (samples, channels)
        # peak-to-peak template amplitude on every channel
        amp = np.max(temp, axis=0) - np.min(temp, axis=0)
        sorted_idx = [ind for _, ind in sorted(zip(amp, np.arange(len(amp))))]
        # the 12 channels with the largest amplitude, best channel first
        nbgh_chan_idx = sorted_idx[::-1][:12]
        nbgh_temps = temp.transpose()[nbgh_chan_idx]
        best_chan_temp = nbgh_temps[0]
        nbgh_channels = channels[nbgh_chan_idx]
        nbgh_postions = [tuple(positions[idx]) for idx in nbgh_chan_idx]
        best_channel = nbgh_channels[0]
        best_position = nbgh_postions[0]
        cls_amp = cluster_agg["amplitudes"].loc[c]
        neuron_dict[i] = {"cluster_id": c, "channel": best_channel, "position": best_position,
                          "amplitudes": cls_amp, "template": best_chan_temp,
                          "neighbor_channels": nbgh_channels, "neighbor_positions": nbgh_postions,
                          "neighbor_templates": nbgh_temps}
        neuron_attributes.append(
            NeuronAttributes(
                cluster_id=c,
                channel=best_channel,
                position=best_position,
                amplitudes=cls_amp,
                template=best_chan_temp,
                templates=temp.T,
                label=cluster_labels.get(c),
                neighbor_channels=channels[nbgh_chan_idx],
                neighbor_positions=[tuple(positions[idx]) for idx in nbgh_chan_idx],
                neighbor_templates=[temp.T[n] for n in nbgh_chan_idx]
            )
        )

    config_dict = dict(zip(channels, positions))
    neuron_data = {0: neuron_dict}
    metadata = {0: config_dict}
    spikedata = SpikeData(list(cluster_agg["spikeTimes"]), neuron_data=neuron_data, metadata=metadata, neuron_attributes=neuron_attributes)
    return spikedata

class NeuronAttributes:
    """Per-unit metadata record produced by ``read_phy_files``.

    All the annotated fields below must be passed as keyword arguments;
    a missing one raises KeyError. Extra keyword arguments become additional
    attributes. Positional arguments are accepted but ignored.
    """
    cluster_id: int
    channel: np.ndarray
    position: Tuple[float, float]
    amplitudes: List[float]
    template: np.ndarray
    templates: np.ndarray
    label: str

    # These lists are the same length and correspond to each other
    neighbor_channels: np.ndarray
    neighbor_positions: List[Tuple[float, float]]
    neighbor_templates: List[np.ndarray]

    def __init__(self, *args, **kwargs):
        required_fields = (
            "cluster_id", "channel", "position", "amplitudes", "template",
            "templates", "label", "neighbor_channels", "neighbor_positions",
            "neighbor_templates",
        )
        # Consume the required fields in order; pop raises KeyError if one is missing.
        for field_name in required_fields:
            setattr(self, field_name, kwargs.pop(field_name))
        # Whatever is left over is stored as-is.
        for extra_name in kwargs:
            setattr(self, extra_name, kwargs[extra_name])

    def add_attribute(self, key, value):
        """Attach (or overwrite) attribute ``key`` with ``value``."""
        setattr(self, key, value)

    def list_attributes(self):
        """Return the names (in ``dir`` order) of all non-dunder, non-callable attributes."""
        names = []
        for name in dir(self):
            if name.startswith('__'):
                continue
            if callable(getattr(self, name)):
                continue
            names.append(name)
        return names

## Load Data

In [13]:
### Point braingeneers at the local copy of the class dataset and load the experiment metadata.
# NOTE(review): an earlier note here said this loading method "doesn't work". Presumably that
# refers to the commented-out load_spike_data call below; spike data is loaded with
# read_phy_files in a later cell. Confirm.
# Hard-coded absolute dataset path; adjust it when running outside this environment.
braingeneers.set_default_endpoint("/home/jovyan/data/ephys/2023-08-28-e-Math_Mind_Class/HW3-Experiment")
#sd = load_spike_data(uuid="2023-08-28-e-Math_Mind_Class",
#                     full_path="/home/jovyan/data/ephys/2023-08-28-e-Math_Mind_Class/HW3-Experiment/ephys/2022-04-28-e-/derived/kilosort2/Trace_20220428_15_52_47_chip11350_curated.zip")
# Metadata for experiment '2022-04-28-e-'; its electrode mapping is used to build electrode_mapping below.
metadata = ephys.load_metadata('2022-04-28-e-')

Get 2D layout of all electrodes

In [14]:
# Load the curated kilosort2 sort of Trace_20220428_15_52_47_chip11350 using the read_phy_files
# defined in this notebook (which replaces the braingeneers import of the same name).
connectoid = read_phy_files("/home/jovyan/data/ephys/2023-08-28-e-Math_Mind_Class/HW3-Experiment/ephys/2022-04-28-e-/derived/kilosort2/Trace_20220428_15_52_47_chip11350_curated.zip")

In [15]:
# Electrode mapping of this recording as a table with columns channel, electrode, x, y.
# The layout helpers below plot x/y in um.
electrode_mapping = pd.DataFrame( metadata['ephys_experiments']['Trace_20220428_15_52_47_chip11350']['mapping'],
                                  columns=['channel','electrode','x','y'] )

# <font color="blue"> Pretty Raster Plotter</font>

A detailed raster plot function that can zoom in on a specified time window.

In [16]:
def plot_raster(sd, title="Spike Raster", l1=-10, l2=False, xsize=10, ysize=6, analize=False):
    """
    Plots a configurable raster plot of the spike data.
        sd : spike data object from braingeneers
        title : Title of the plot
        l1 : start time in seconds
        l2 : end time in seconds; False or None (default) means 10 s past the
             end of the recording
        xsize : width of the plot
        ysize : height of the plot
        analize : If True, will also plot the smoothed population rate
    """
    # Identity check: `l2 == False` would also treat an explicit 0 as "unset".
    if l2 is False or l2 is None:
        l2 = sd.length / 1000 + 10

    idces, times = sd.idces_times()

    # Only draw spikes inside the requested window (spike times are in ms).
    mask = (times >= l1 * 1000) & (times <= l2 * 1000)
    times = times[mask]
    idces = idces[mask]

    fig, ax = plt.subplots(figsize=(xsize, ysize))
    fig.suptitle(title)
    ax.scatter(times/1000, idces, marker='|', s=1)

    if analize:
        # Population rate in 1 ms bins, Gaussian-smoothed (sigma = 5 bins).
        pop_rate = sd.binned(bin_size=1)
        pop_rate_smooth = gaussian_filter1d(pop_rate.astype(float), sigma=5)
        t = np.linspace(0, sd.length, pop_rate.shape[0]) / 1000
        ax2 = ax.twinx()
        ax2.plot(t, pop_rate_smooth, c='r')
        ax2.set_ylabel('Firing Rate')

    ax.set_xlabel("Time(s)")
    ax.set_ylabel('Unit #')
    # Set the limits on the main axes explicitly (the twin shares its x axis).
    ax.set_xlim(l1, l2)
    plt.show()

In [None]:
# plot_raster(connectoid, title="example raster")

# <font color="blue">electrode_layout</font> 

Plots the placement of all the recording electrodes.

In [17]:
def electrodeLayout():
    """Scatter-plot the position of every recording electrode.

    Reads the notebook-level ``electrode_mapping`` table (columns x and y,
    labelled in um).
    """
    electrode_x, electrode_y = electrode_mapping.x.values, electrode_mapping.y.values
    plt.scatter(electrode_x, electrode_y, s=2)
    plt.xlabel('um')
    plt.ylabel('um')
    plt.title("electrode layout")
    plt.show()

In [None]:
#electrodeLayout()

# <font color="blue">neuronLayout</font> 

Plots the putative neural units found from spike sorting

In [18]:
def neuronLayout(sd):
    """Plot the electrodes and, on top of them, the sorted units of ``sd``.

    Each unit in ``sd.neuron_attributes`` is drawn in red at its ``position``,
    with marker size equal to its mean spike amplitude.
    """
    units = list(sd.neuron_attributes)
    unit_x = [unit.position[0] for unit in units]
    unit_y = [unit.position[1] for unit in units]
    unit_size = [np.mean(unit.amplitudes) for unit in units]

    plt.scatter(electrode_mapping.x.values, electrode_mapping.y.values, s=2)
    plt.scatter(unit_x, unit_y, s=unit_size, c='r')
    plt.xlabel('um')
    plt.ylabel('um')
    plt.title("electrode layout")
    plt.show()

In [None]:
#neuronLayout(connectoid)

# <font color="blue">sttcLayout</font>

Shows which neurons are connected, as measured by the spike time tiling coefficient (STTC).

In [19]:
def sttcLayout( sd, threshold ):
    """Draw strongly correlated unit pairs on the array layout and show the STTC matrix.

    Left panel: electrodes, units (red, marker size = mean amplitude), and a
    black line between every pair of units whose spike time tiling coefficient
    is at least ``threshold`` (line width = STTC).
    Right panel: the STTC matrix with values below ``threshold`` shown as 0.

    Parameters
    ----------
    sd : SpikeData-like object
        Must provide ``spike_time_tilings()`` and ``neuron_attributes``.
    threshold : float
        Minimum STTC for a pair to be drawn. NaN entries are never drawn.
    """
    sttc = sd.spike_time_tilings()

    # electrodes
    x = electrode_mapping.x.values
    y = electrode_mapping.y.values
    # units
    neuron_x = []
    neuron_y = []
    neuron_amp = []
    for neuron in sd.neuron_attributes:
        neuron_x.append(neuron.position[0])
        neuron_y.append(neuron.position[1])
        neuron_amp.append(np.mean(neuron.amplitudes))

    plt.figure(figsize=(15,6))
    plt.subplot(1, 2, 1)
    plt.scatter(x,y,s=2)
    plt.scatter(neuron_x,neuron_y,s=neuron_amp,c='r')

    # The STTC matrix is symmetric, so each pair is drawn once, from the strict
    # lower triangle (i > j). np.nonzero yields pairs in row-major order, the
    # same order a nested i/j loop would visit them.
    strong_pairs = np.tril(sttc >= threshold, k=-1)
    for i, j in zip(*np.nonzero(strong_pairs)):
        ix, iy = sd.neuron_attributes[i].position
        jx, jy = sd.neuron_attributes[j].position
        # Line between the two units; its width is the STTC.
        plt.plot([ix,jx],[iy,jy], linewidth=sttc[i,j],c='k')

    plt.xlabel('um')
    plt.ylabel('um')
    plt.title("electrode layout")

    plt.subplot(1, 2, 2)
    # Threshold a copy rather than zeroing `sttc` in place, so the array
    # returned by sd.spike_time_tilings() is left untouched.
    plt.imshow(np.where(sttc < threshold, 0, sttc), vmin=0, vmax=1)
    plt.colorbar()
    plt.title("Spike Time Tiling")
    plt.show()

In [None]:
#sttcLayout(connectoid, .8)

# <font color="blue">Latency Plots</font>

Displays two plots. The left plot shows which two points the latencies are being calculated between. The right plot is a histogram of all the latencies between the points.


In [21]:
def latencyPlots(sd, neuron1, neuron2, window=10, bins=8):
    """Show two units on the layout and a histogram of their spike latencies.

    Left panel: electrodes, all units (red, size = mean amplitude) and the two
    selected units (green).
    Right panel: histogram of ``sd.latencies_to_index(neuron1)[neuron2]``,
    keeping only latencies strictly between ``-window`` and ``window``.
    The mean and count of the kept latencies are printed.

    Parameters
    ----------
    sd : SpikeData-like object
        Must provide ``neuron_attributes`` and ``latencies_to_index``.
    neuron1, neuron2 : int
        Indices of the two units.
    window : float, optional
        Half-width of the latency window (default 10, in the units returned by
        ``latencies_to_index`` — presumably ms).
    bins : int, optional
        Number of histogram bins (default 8).
    """
    # Layout of units with the two selected units highlighted in green.
    plt.figure(figsize=(15,6))
    plt.subplot(1, 2, 1)

    x = electrode_mapping.x.values
    y = electrode_mapping.y.values
    plt.scatter(x,y,s=2)

    neuron_x = []
    neuron_y = []
    neuron_amp = []
    for neuron in sd.neuron_attributes:
        neuron_x.append(neuron.position[0])
        neuron_y.append(neuron.position[1])
        neuron_amp.append(np.mean(neuron.amplitudes))

    plt.scatter(neuron_x,neuron_y,s=neuron_amp,c='r')

    plt.scatter([neuron_x[neuron1]],[neuron_y[neuron1]],s=70,c='g')
    plt.scatter([neuron_x[neuron2]],[neuron_y[neuron2]],s=70,c='g')

    plt.xlabel('um')
    plt.ylabel('um')
    plt.title("electrode layout")

    # Histogram of latencies inside the window.
    plt.subplot(1, 2, 2)
    lates_raw = sd.latencies_to_index(neuron1)[neuron2]
    lates = [lat for lat in lates_raw if -window < lat < window]
    plt.hist(lates, bins=bins)
    # np.mean([]) warns and returns nan; report nan directly when nothing is in the window.
    mean_latency = np.mean(lates) if lates else float('nan')
    print("Mean Latency: ", mean_latency)
    print("Number of Latencies: ", len(lates))
    plt.show()

# <font color="blue">Stim Pulses</font>

## <font color="red">Legacy</font>

Prebuilt function used to plot the square wave stimulation example shown in the notebook:

In [22]:
def plotPulse():
    """Plot the legacy square-wave stimulation example (pulse at t = 1 s, 200 us phase, 150 mV).

    NOTE(review): relies on ``getTimeSignal`` and ``squareWave``, which are not
    defined in the visible part of this notebook — confirm they exist before calling.
    """
    t, base_signal = getTimeSignal(60)
    pulse_signal = squareWave(base_signal, time_s=1, phase_us=200, amp_mV=150)

    plt.plot(t, pulse_signal)
    # Zoom to +/- 2 ms around t = 1 s.
    plt.xlim([.998, 1.002])
    plt.xlabel('seconds')
    plt.ylabel('Voltage (mv)')
    plt.show()

## <font color="blue"> General Stim Functions</font>

In [23]:
def plotStimPattern(signal,t):
    """Plot every neuron's stimulation signal against time on one set of axes.

    Parameters
    ----------
    signal : numpy.ndarray
        Stimulation voltage of shape (n_neurons, timesteps), as returned by
        ``create_stim_pulse_sequence``. A 1-D array is treated as one neuron.
    t : numpy.ndarray
        Time in seconds, shape (timesteps,).
    """
    plt.plot(t,signal.T)
    # One legend entry per plotted row (n0, n1, ...) instead of a fixed n0-n2.
    n_rows = 1 if np.ndim(signal) == 1 else np.shape(signal)[0]
    plt.legend([f'n{i}' for i in range(n_rows)])
    plt.xlabel('seconds')
    plt.ylabel('Voltage (mv)')
    plt.title("Stim Pattern over 1 Second")
    plt.show()

In [24]:
def plotIndividualPatterns(signal,t):
    """Plot each neuron's stimulation signal in its own side-by-side subplot.

    Params:
    signal - np.array of shape (n_neurons, timesteps), as returned by create_stim_pulse_sequence
    t - np.array of shape (timesteps,), time in seconds
    """
    colors = ["b", "#F97306", "g"]
    n_rows = len(signal)
    plt.figure(figsize=(16,4))
    for i in range(n_rows):
        plt.subplot(1, n_rows, i+1)
        # Colors cycle if there are more rows than colors.
        plt.plot(t, signal[i], color=colors[i % len(colors)])
        plt.legend([f"n{i}"])
        plt.ylabel('Voltage (mv)')
        plt.xlabel('seconds')
    # Show once, after all panels are drawn.
    plt.show()

A helper function used by `create_stim_pulse_sequence` to write a single positive-then-negative square pulse into a signal.

In [25]:
def insert_square_wave_advanced(sig,time,phase_length=20, amplitude=1):
    '''Write a two-phase square pulse into *sig* in place and return the same array.

    Samples [time, time+phase_length) are set to +amplitude, and the following
    phase_length samples are set to -amplitude. Everything else is left untouched.
    With the default phase_length of 20 samples, each phase lasts 1 ms.
    Samples past the end of *sig* are silently dropped (numpy slice clipping).
    '''
    rise_start = time
    fall_start = rise_start + phase_length
    fall_end = fall_start + phase_length

    sig[rise_start:fall_start] = amplitude
    sig[fall_start:fall_end] = -amplitude
    return sig

Main function used to create stim pulse sequences:

In [26]:
def create_stim_pulse_sequence(stim_list, freq_Hz=None, time_arr = None,max_time_s = 1, n_neurons=3):
    """
    Creates a stim pulse sequence, parallels the real code that will run on the Maxwell.

    ``stim_list`` describes one stimulation pattern. A ('next', None) command is appended
    to a *copy* of it (the caller's list is not modified), and the pattern is replayed at
    every start time: once per period of ``freq_Hz`` across the whole ``max_time_s`` window,
    or, when ``freq_Hz`` is None, at every time in ``time_arr``.

    Params:
    stim_list - list of tuples indicating the commands to run
            ------------------------------------------------
            For 'stim' command:
            ('stim', [neuron inds], mv, us per phase)

            For 'delay'
            ('delay', ms_delay)

            For 'next'
            ('next', None)
            This command acts as a placeholder to move to the next timepoint in the time_arr or the next
            period triggered by the freq_Hz
            -------------------------------------------------
    freq_Hz - frequency (positive, at most the 20 kHz sample rate) at which to replay stim_list
            *Note* this takes priority over time_arr
    time_arr - array of time values in seconds at which the stimulations occur, in order
    max_time_s - time in seconds to stimulate
    n_neurons - number of signal rows (default 3)

    Returns:
    sig - np.array -- shape=(n_neurons, timesteps) of what the signal will look like
    t - np.array   -- shape=(timesteps) of time in seconds

    Raises:
    ValueError - if neither freq_Hz nor time_arr is given, or freq_Hz is out of range
    """
    import itertools

    # Since this is fake and only parallels the code on the device,
    # we make our own time axis and simulate what the stimulation will look like.
    fs = 20000    # sample rate, frames per second
    fs_ms = 20    # frames per millisecond at fs
    # NOTE(review): at 20 kHz one frame is 50 us (0.02 frames/us); 0.2 draws each phase
    # 10x wider than real, possibly on purpose so short phases stay visible - confirm.
    fs_us = .2

    t = np.arange(0,fs*max_time_s)/fs
    sig = np.zeros(shape=(n_neurons,t.shape[0]))

    if freq_Hz is not None:
        if not 0 < freq_Hz <= fs:
            raise ValueError(f"freq_Hz must be in (0, {fs}], got {freq_Hz}")
        period_frames = int(fs // freq_Hz)
        # Integer frame indices avoid the truncation error of int(t[i]*fs).
        start_frames = range(0, t.shape[0], period_frames)
    elif time_arr is not None:
        start_frames = [int(round(time_s*fs)) for time_s in time_arr]
    else:
        raise ValueError("Either freq_Hz or time_arr must be given")

    # Cycle a terminated copy of the pattern so it repeats for as many start times as needed;
    # each start time consumes commands up to the next 'next'.
    commands = itertools.cycle(list(stim_list) + [('next', None)])
    for start_frame in start_frames:
        _run_stim_commands(sig, commands, start_frame, fs_ms, fs_us)
    return sig,t


def _run_stim_commands(sig, commands, start_frame, fs_ms, fs_us):
    """Draw commands from iterator `commands` into `sig` from `start_frame` until a 'next' command."""
    frame = start_frame
    for command, *params in commands:
        if command == 'next':
            return
        if command == 'stim':
            neurons, amplitude, phase_us = params
            phase_frames = int(phase_us*fs_us)
            # sig[n, :] is a view, so the helper writes straight into sig.
            for n in neurons:
                insert_square_wave_advanced(sig[n, :], frame, phase_frames, amplitude=amplitude)
            frame += phase_frames*2
        elif command == 'delay':
            frame += int(fs_ms*params[0])

Example

In [None]:
# stim_pattern = []
# stim_pattern.append(('stim',[0],150,200))
# stim_pattern.append(('delay',20))
# stim_pattern.append(('stim',[1],200,20))
# #stim_pattern.append(('delay',250))
# #stim_pattern.append(('stim',[0,1,2],150,20))
# signal,t = create_stim_pulse_sequence(stim_pattern, freq_Hz=1)

# <font color="red">Bookend

In [27]:
# Timestamp the end of the run in the notebook's configured timezone.
now = datetime.now(Timezone)
printNow = "{:%Y/%m/%d %H:%M:%S}".format(now)
print("Done at: " + printNow)

Done at: 2024/05/17 18:23:11
