In [24]:
import os, json
import pandas as pd
import numpy as np

from kilosort.io import load_probe

In [12]:
# Kilosort output directory of the session inspected in this notebook
source = '/home/sobolev/nevermind/AG_Pecka/data/processed/013608/013608_ppcSIT_2025-01-09_09-09-55/kilosort'

# per-cluster label table (tab-separated) located inside that directory
labels_path = os.path.join(source, 'cluster_KSLabel.tsv')

In [13]:
# cluster_KSLabel.tsv is tab-separated with a header row (cluster_id, KSLabel)
labels = pd.read_table(labels_path, header=0)
labels

Unnamed: 0,cluster_id,KSLabel
0,0,mua
1,1,good
2,2,good
3,3,mua
4,4,good
...,...,...
356,356,mua
357,357,mua
358,358,mua
359,359,mua


In [14]:
# show only the rows whose KSLabel is 'good'
labels.loc[labels['KSLabel'].eq('good')]

Unnamed: 0,cluster_id,KSLabel
1,1,good
2,2,good
4,4,good
5,5,good
6,6,good
...,...,...
346,346,good
352,352,good
353,353,good
354,354,good


In [28]:
# Kilosort output directory used by the loading cell below
# (same directory as `source` defined earlier in this notebook)
ks_path = '/home/sobolev/nevermind/AG_Pecka/data/processed/013608/013608_ppcSIT_2025-01-09_09-09-55/kilosort'

In [71]:
# Collect the spike trains and probe positions of all 'good' Kilosort units,
# grouped by the shank on which each unit has its largest template amplitude.
#
# Results:
#   all_units[shank] -> {cluster_id: spike times of that cluster}
#   all_pos[shank]   -> position of the peak channel for each selected cluster,
#                       in the same order as the keys of all_units[shank]

# load probe configuration
with open(os.path.join(ks_path, 'probe.json'), 'r') as json_file:
    probe = json.load(json_file)

channels = probe['chanMap']
# shank index of every channel; +1 turns the stored indices into 1-based shank numbers
ch_shank_map = np.array(probe['kcoords'], dtype=np.int16) + 1
shanks = np.unique(ch_shank_map)

# load kilosorted spike times / clusters / templates / positions / labels
s_times   = np.load(os.path.join(ks_path, 'spike_times.npy'))  # all spike times of all clusters (1D array)
s_clust   = np.load(os.path.join(ks_path, 'spike_clusters.npy'))  # IDs of clusters for each spike
templates = np.load(os.path.join(ks_path, 'templates.npy'))  # cluster (unit), timepoints, channel
ch_pos    = np.load(os.path.join(ks_path, 'channel_positions.npy'))  # channel, position on the probe (2 columns)
ks_labels = pd.read_csv(os.path.join(ks_path, 'cluster_KSLabel.tsv'), sep='\t', header=0)  # cluster, good / mua

# channel with highest AP amplitude for each unit: peak |amplitude| over time, then argmax over channels
template_maxchans = np.abs(templates).max(axis=1).argmax(axis=1)
# NOTE: the arrays below are indexed by template index and are later addressed with
# cluster IDs, i.e. this assumes cluster ID == template index (no manual merges/splits).
clu_ch_mapping    = ch_shank_map[template_maxchans]
clu_pos_mapping   = ch_pos[template_maxchans]

# 'good' units: take the IDs from the 'cluster_id' column instead of the row position,
# so the selection stays correct even if the label table is unsorted or has missing rows
good_idxs = ks_labels.loc[ks_labels['KSLabel'] == 'good', 'cluster_id'].to_numpy()

all_units = {}
all_pos = {}
for shank in shanks:
    # units whose peak channel lies on this shank
    clu_idxs = np.where(clu_ch_mapping == shank)[0]
    # keep only those that are also labelled 'good' (result is sorted and unique)
    sel_clusters = np.intersect1d(good_idxs, clu_idxs)

    spiketrains = {}
    for clu_id in sel_clusters:
        spiketrains[clu_id] = s_times[np.where(s_clust == clu_id)[0]]

    all_units[shank] = spiketrains
    all_pos[shank] = clu_pos_mapping[sel_clusters]


In [62]:
# cluster IDs of the 'good' units assigned to shank 4
all_units[4].keys()

dict_keys([66, 81, 101, 128, 129, 130, 131, 140, 141, 142, 152, 186, 187, 236, 237, 238, 239, 254, 255, 257, 266, 267, 268, 269, 287, 289, 296, 298, 345, 346])

In [75]:
# peak-channel positions of the 'good' units assigned to shank 1 (one row per unit)
all_pos[1]

array([[   8., 2760.],
       [   8., 2760.],
       [   8., 2760.],
       [   8., 2760.],
       [   8., 2760.],
       [   8., 2760.],
       [   8., 2805.],
       [   8., 2775.],
       [   8., 2805.],
       [   8., 2775.],
       [   8., 2790.],
       [   8., 2820.],
       [   8., 2850.],
       [   8., 2850.],
       [   8., 2835.],
       [   8., 3210.],
       [   8., 3225.],
       [   8., 3540.],
       [   8., 4185.]], dtype=float32)

In [77]:
# write an indented (pretty-printed) copy of the loaded probe configuration to /tmp
with open(os.path.join('/tmp', 'probe.json'), 'w') as f:
    json.dump(probe, f, indent=2)