In [1]:
import os, sys
import numpy as np
import matplotlib.pyplot as plt

unit_matcher_path = os.getcwd()
prototype_path = os.path.abspath(os.path.join(unit_matcher_path, os.pardir))
project_path = os.path.abspath(os.path.join(prototype_path, os.pardir))
lab_path = os.path.abspath(os.path.join(project_path, os.pardir))
sys.path.append(project_path)
os.chdir(project_path)
print(project_path)

from _prototypes.unit_matcher.main import format_cut, run_unit_matcher, map_unit_matches_first_session, map_unit_matches_sequential_session
from _prototypes.unit_matcher.read_axona import read_sequential_sessions, temp_read_cut
from _prototypes.unit_matcher.session import compare_sessions
from _prototypes.unit_matcher.write_axona import format_new_cut_file_name
from x_io.rw.axona.batch_read import make_study
from _prototypes.unit_matcher.waveform import time_index, derivative, derivative2, morphological_points

c:\Users\aaoun\OneDrive - cumc.columbia.edu\Desktop\HussainiLab\neuroscikit


In [2]:
""" ONLY EDIT THE SETTINGS IN THIS CELL & RESTART BEFORE RUNNING """

# If a setting is not used for your analysis (e.g. smoothing_factor), just pass in an arbitrary value or pass in 'None'
STUDY_SETTINGS = {
    'ppm': 511,
    'smoothing_factor': None,
    'useMatchedCut': False, # Set to False if you want to use runUnitMatcher, set to True after to load in matched.cut file
}


# Switch devices to True/False based on what is being used (to be extended for more devices in future)
device_settings = {'axona_led_traacker': True, 'implant': True}
# Make sure implant metadata is correct, change if not, AT THE MINIMUM leave implant_type: tetrode
implant_settings = {'implant_type': 'tetrode', 'implant_geometry': 'square', 'wire_length': 25, 'wire_length_units': 'um', 'implant_units': 'uV'}

# WE ASSUME DEVICE AND IMPLANT SETTINGS ARE CONSISTENCE ACROSS SESSIONS, IF THIS IS NOT THE CASE PLEASE LET ME KNOW

# Set channel count + add device/implant settings
SESSION_SETTINGS = {
    'channel_count': 4, # default is 4, can change to other but code will check how many tetrode files are present regardless
    'devices': device_settings,
    'implant': implant_settings,
}

STUDY_SETTINGS['session'] = SESSION_SETTINGS

settings_dict = STUDY_SETTINGS

In [None]:
data_dir = lab_path + r'\neuroscikit_test_data\single_sequential'

# animal metadata, insert in settings for now, later can read from a file or similar
# CHANGE TO READ FROM FILE
animal = {'animal_id': 'id', 'species': 'mouse', 'sex': 'F', 'age': 1, 'weight': 1, 'genotype': 'type', 'animal_notes': 'notes'}

# later will change to set bool based on whether file is present or not,currently based on user input
# CHANGE TO DEFAULT = TRUE. FORCE FALSE if data not present
devices = {'axona_led_tracker': True, 'implant': True}

# same as animal metadata
# CHANGE TO READ FROM FILE
implant = {'implant_id': 'id', 'implant_type': 'tetrode', 'implant_geometry': 'square', 'wire_length': 25, 'wire_length_units': 'um', 'implant_units': 'uV'}

# aggregate settings for session
# MAKE FUNCTION TO READ METADATA ABOVE + fill session dict
session_settings = {'channel_count': 4, 'animal': animal, 'devices': devices, 'implant': implant}

# add in cross-session/global settings
# MAKE FUNCTION TO ADD SESSION DICTS TO GLOBAL SETTINGS DICT
settings_dict_unmatched = {'ppm': 511, 'session': [session_settings,session_settings], 'smoothing_factor': 5, 'useMatchedCut': False} # --> compute matched cut labels + write file

# make one where matched unit cut file is used
settings_dict_matched = {'ppm': 511, 'session': [session_settings,session_settings], 'smoothing_factor': 5, 'useMatchedCut': True} # --> use matched cut to load data

In [3]:
data_dir = lab_path + r'\neuroscikit_test_data\single_sequential'
settings_dict_matched = settings_dict


In [None]:
# Run unit matching on non-matched study, will save new matched cut file (around 3-4 mins runtime per 2 sesssions compared)
unmatched_study = run_unit_matcher([data_dir], settings_dict_unmatched)

In [4]:
# For now, set up new study with matched unit cut file. In the future will have a converter to copy overwrite existing study or copy it with edits rather than loading in a new one (=slow)  
matched_study = make_study([data_dir], settings_dict_matched)

[['c:\\Users\\aaoun\\OneDrive - cumc.columbia.edu\\Desktop\\HussainiLab\\neuroscikit_test_data\\single_sequential/1-13_20210621-34-50x50cm-1500um-Test1.pos', 'c:\\Users\\aaoun\\OneDrive - cumc.columbia.edu\\Desktop\\HussainiLab\\neuroscikit_test_data\\single_sequential/1-13_20210621-34-50x50cm-1500um-Test1.3', 'c:\\Users\\aaoun\\OneDrive - cumc.columbia.edu\\Desktop\\HussainiLab\\neuroscikit_test_data\\single_sequential/1-13_20210621-34-50x50cm-1500um-Test1_3.cut', 'c:\\Users\\aaoun\\OneDrive - cumc.columbia.edu\\Desktop\\HussainiLab\\neuroscikit_test_data\\single_sequential/1-13_20210621-34-50x50cm-1500um-Test1_3_matched.cut'], ['c:\\Users\\aaoun\\OneDrive - cumc.columbia.edu\\Desktop\\HussainiLab\\neuroscikit_test_data\\single_sequential/1-13_20210621-34-50x50cm-1500um-Test2.pos', 'c:\\Users\\aaoun\\OneDrive - cumc.columbia.edu\\Desktop\\HussainiLab\\neuroscikit_test_data\\single_sequential/1-13_20210621-34-50x50cm-1500um-Test2.3', 'c:\\Users\\aaoun\\OneDrive - cumc.columbia.edu\\Des

KeyError: 'animal'

In [None]:
# call make animals to organze sesssions by animal
matched_study.make_animals()

In [None]:
session1 = matched_study.animals[0].sessions['session_1']
ensemble1 = matched_study.animals[0].ensembles['session_1']
session2 = matched_study.animals[0].sessions['session_2']
ensemble2 = matched_study.animals[0].ensembles['session_2']

unmatched_ensembles1 = unmatched_study.animals[0].ensembles['session_1']
unmatched_ensembles2 = unmatched_study.animals[0].ensembles['session_2']

In [None]:
# Check they have same number of matched cells (can remove this)
assert ensemble1.get_label_ids() == ensemble1.get_label_ids()

In [None]:
# To view morphological points
def plot_waveform_points(waveform, time_step):
    t = time_index(waveform, time_step)
    d_waveform = derivative(waveform, time_step)
    d2_waveform = derivative2(waveform, time_step)
    p1, p2, p3, p4, p5, p6 = morphological_points(t, waveform, d_waveform, d2_waveform, time_step)
    fig, ax1 = plt.subplots()
    color = 'tab:red'
    ax1.set_xlabel('time (mS)')
    ax1.set_ylabel('mV', color=color)
    ax1.plot(t, waveform, color=color)
    ax1.plot(p1.t, p1.v, 'o', color=color, label='p1')
    ax1.plot(p3.t, p3.v, 'o', color=color, label='p3')
    ax1.plot(p5.t, p5.v, 'o', color=color, label='p5')
    ax1.tick_params(axis='y', labelcolor=color)
    ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis
    color = 'tab:orange'
    ax2.set_ylabel('mV/mS', color=color)  # we already handled the x-label with ax1
    ax2.plot(t, d_waveform, color=color, linestyle='--')
    ax2.plot(p2.t, p2.dv, 'o', color=color, label='p2')
    ax2.plot(p4.t, p4.dv, 'o', color=color, label='p4')
    ax2.plot(p6.t, p6.dv, 'o', color=color, label='p6')
    ax2.tick_params(axis='y', labelcolor=color)
    fig.tight_layout()  # otherwise the right y-label is slightly clipped
    plt.legend()
    plt.show()

In [None]:
# example plot of morphological points
plot_waveform_points(ensemble1.cells[5].signal[-1,3], .02)

In [None]:
""" Plot session 1 (top) & session 2 (bottom) MATCHED units """

pair_count = len(ensemble1.get_label_ids())

for i in range(pair_count):
# for i in range(3):

    fig = plt.figure(figsize=(18,6))

    axes = []

    jsd1 = round(unmatched_ensembles1.cells[i].cluster.stats_dict['JSD'], 2)
    jsd2 = round(unmatched_ensembles2.cells[i].cluster.stats_dict['JSD'], 2)

    assert jsd1 == jsd2

    waveforms1 = ensemble1.cells[i].signal
    waveforms2 = ensemble2.cells[i].signal

    avg_waveforms1 = np.mean(waveforms1, axis=0)
    avg_waveforms2 = np.mean(waveforms2, axis=0)

    assert waveforms1.shape[1] == avg_waveforms1.shape[0]

    for j in range(avg_waveforms1.shape[0]):
        ax1 = plt.subplot(2,avg_waveforms1.shape[0],j+1)
        ax2 = plt.subplot(2,avg_waveforms1.shape[0],j+1 + 4)

        ax1.plot(waveforms1[:,int(j)].T, color='gray', lw=0.5, alpha=0.5)
        ax2.plot(waveforms2[:,int(j)].T, color='gray', lw=0.5, alpha=0.5)

        ax1.plot(avg_waveforms1[int(j)], color='k', lw=2)
        ax2.plot(avg_waveforms2[int(j)], color='k', lw=2)

        ax1.set_title('Channel ' + str(int(j+1)))
        ax2.set_title('Channel ' + str(int(j+1)))

        axes.append(ax1)
        axes.append(ax2)

    for ax in axes:
        ax.set_xlabel('Bin Number')
        ax.set_ylabel('Waveform')

    fig.suptitle('Session 1 (top) & 2 (bottom) - Unit ' + str(i+1) + ': JSD = ' + str(jsd1))

    fig.tight_layout()
    plt.show()


In [None]:
""" Plot session 1 (left) & session 2 (right) MATCHED units """

pair_count = len(ensemble1.get_label_ids())

for i in range(pair_count):
# for i in range(2):

    fig = plt.figure(figsize=(6,12))

    axes = []

    waveforms1 = ensemble1.cells[i].signal
    waveforms2 = ensemble2.cells[i].signal

    avg_waveforms1 = np.mean(waveforms1, axis=0)
    avg_waveforms2 = np.mean(waveforms2, axis=0)

    assert waveforms1.shape[1] == avg_waveforms1.shape[0]

    for j in range(0,avg_waveforms1.shape[0]*2,2):
        ax1 = plt.subplot(avg_waveforms1.shape[0],2,j+1)
        ax2 = plt.subplot(avg_waveforms1.shape[0],2,j+2)

        ax1.plot(waveforms1[:,int(j/2)].T, color='gray', lw=0.5, alpha=0.5)
        ax2.plot(waveforms2[:,int(j/2)].T, color='gray', lw=0.5, alpha=0.5)

        ax1.plot(avg_waveforms1[int(j/2)], color='k', lw=2)
        ax2.plot(avg_waveforms2[int(j/2)], color='k', lw=2)

        ax1.set_title('Channel ' + str(int(j/2+1)))
        ax2.set_title('Channel ' + str(int(j/2+1)))

        axes.append(ax1)
        axes.append(ax2)

    for ax in axes:
        ax.set_xlabel('Bin Number')
        ax.set_ylabel('Waveform')

    fig.suptitle('Session 1 (left) & 2 (right) - Unit ' + str(i+1))

    fig.tight_layout()
    plt.show()
