In [1]:
import retinapy.mea as mea
import numpy as np
import pandas as pd
import json
from collections import defaultdict
import plotly.express as px
import plotly.graph_objects as go
import plotly.subplots
from sklearn.decomposition import PCA

In [2]:
# Print NumPy's build configuration (BLAS/LAPACK libraries and supported SIMD
# extensions) so the numerical backend used for this run is on record.
np.show_config()

blas_info:
    libraries = ['cblas', 'blas', 'cblas', 'blas']
    library_dirs = ['/home/app/mambaforge/lib']
    include_dirs = ['/home/app/mambaforge/include']
    language = c
    define_macros = [('HAVE_CBLAS', None)]
blas_opt_info:
    define_macros = [('NO_ATLAS_INFO', 1), ('HAVE_CBLAS', None)]
    libraries = ['cblas', 'blas', 'cblas', 'blas']
    library_dirs = ['/home/app/mambaforge/lib']
    include_dirs = ['/home/app/mambaforge/include']
    language = c
lapack_info:
    libraries = ['lapack', 'blas', 'lapack', 'blas']
    library_dirs = ['/home/app/mambaforge/lib']
    language = f77
lapack_opt_info:
    libraries = ['lapack', 'blas', 'lapack', 'blas', 'cblas', 'blas', 'cblas', 'blas']
    library_dirs = ['/home/app/mambaforge/lib']
    language = c
    define_macros = [('NO_ATLAS_INFO', 1), ('HAVE_CBLAS', None)]
    include_dirs = ['/home/app/mambaforge/include']
Supported SIMD extensions in this NumPy install:
    baseline = SSE,SSE2,SSE3
    found = SSSE3,SSE41,POPCNT,SS

In [3]:
# Load recording
# 'Chicken_17_08_21_Phase_00' was also assigned to rec_name here but was
# immediately overwritten, so it never took effect. Presumably it names
# another recording in the same data files -- swap it in below to analyse it.
rec_name = 'Chicken_04_08_21_Phase_01'
rec = mea.single_3brain_recording(
    rec_name,
    mea.load_stimulus_pattern('../data/ff_noise.h5'),
    mea.load_recorded_stimulus('../data/ff_recorded_noise.pickle'),
    mea.load_response('../data/ff_spike_response.pickle'))
print(rec)

Recording: Chicken_04_08_21_Phase_01, sensor sample rate: 17852.76785 Hz, num samples: 16070787, duration: 900.2 seconds, stimulus pattern shape: (24000, 4),num clusters: 187.


In [5]:
# Extract labelled spike snippets from the recording. The two results are
# index-aligned: the grouping code pairs snippets[i] with cluster_ids[i].
snippets, cluster_ids = mea.labeled_spike_snippets(
    rec,
    snippet_len=120,
    snippet_pad=20,
    downsample=180,
)

In [6]:
# Group the snippets by the cluster id they are labelled with.
by_cluster = defaultdict(list)
for i in range(len(cluster_ids)):
    by_cluster[cluster_ids[i]].append(snippets[i])

# Number of snippets labelled with cluster 13.
num_13 = sum(1 for i in range(len(cluster_ids)) if cluster_ids[i] == 13)

# Upper bound on spikes per cluster: 19 per second over 60*15 seconds.
spike_limit = 19 * 60 * 15
print(spike_limit)
print(len(by_cluster))

# Discard every cluster holding more snippets than the limit. Iterate over a
# copy of the keys because entries are removed along the way.
for c_id in list(by_cluster):
    if len(by_cluster[c_id]) > spike_limit:
        del by_cluster[c_id]
print(f'Remaining clusters: {len(by_cluster)}.')

# Show the snippet count of each remaining cluster on a single line.
for cluster_snippets in by_cluster.values():
    print(len(cluster_snippets), end=', ')

17100
187
Remaining clusters: 187.
5065, 4916, 4615, 3712, 4498, 3645, 4390, 5068, 2914, 3195, 3814, 4815, 4818, 3019, 2760, 5396, 2317, 2446, 3889, 2278, 2257, 2995, 5375, 2109, 2893, 2427, 2882, 2492, 1579, 1884, 1864, 2420, 1692, 2325, 1462, 1939, 2794, 2924, 2150, 1751, 1481, 1182, 1131, 988, 1901, 1655, 2256, 1347, 1287, 1311, 4043, 1431, 2364, 805, 2429, 1638, 1258, 866, 1170, 1512, 1631, 1113, 1390, 829, 1235, 1300, 2363, 1613, 1663, 939, 1385, 713, 784, 616, 1255, 1972, 786, 1149, 717, 491, 557, 877, 660, 605, 906, 1532, 868, 811, 361, 319, 735, 556, 884, 271, 422, 639, 507, 610, 384, 643, 793, 398, 494, 1128, 725, 329, 257, 927, 451, 748, 402, 689, 570, 439, 309, 206, 626, 242, 339, 287, 329, 633, 159, 597, 201, 187, 214, 211, 146, 506, 303, 438, 216, 179, 379, 451, 93, 142, 199, 236, 286, 275, 211, 114, 144, 196, 170, 56, 114, 92, 88, 121, 58, 103, 23, 87, 438, 216, 42, 25, 14, 61, 64, 168, 23, 137, 33, 26, 200, 49, 61, 99, 44, 53, 36, 137, 28, 62, 31, 19, 52, 44, 41, 48, 17,

In [7]:
def sort_snippets(snippets_):
    """Order snippets by their projection onto the first principal component.

    Every snippet is flattened to a 1-D vector, a one-component PCA is fitted
    on all of them, and the snippets are returned in ascending order of their
    projection. Snippets with equal projections keep their input order.

    Args:
        snippets_: sequence of equally sized arrays (anything providing
            ``reshape``).

    Returns:
        A new list holding the same snippet objects in sorted order. The
        input sequence is not modified.
    """
    snippets_ = list(snippets_)
    # With fewer than two snippets there is nothing to order, and PCA cannot
    # be fitted on an empty input, so skip the fit entirely.
    if len(snippets_) < 2:
        return snippets_
    flattened_snippets = np.array([s.reshape(-1) for s in snippets_])
    proj = PCA(n_components=1).fit_transform(flattened_snippets)
    # fit_transform returns one column per component; take that single column
    # so the sort key is a scalar rather than a 1-element array. A stable sort
    # keeps the input order of ties, as list.sort did.
    order = np.argsort(proj[:, 0], kind='stable')
    return [snippets_[i] for i in order]

# Replace each cluster's snippet list with its sorted version. Only existing
# keys are reassigned, so by_cluster keeps its keys and their order.
for cluster_key in list(by_cluster):
    by_cluster[cluster_key] = sort_snippets(by_cluster[cluster_key])

In [8]:
# Size, in MiB, of the JSON encoding of up to 1000 snippets from cluster 20.
# by_cluster is a defaultdict, so indexing an id that is not present would
# silently insert an empty cluster that the export cells would then write out.
# .get() only reads.
len(json.dumps([arr.tolist() for arr in by_cluster.get(20, [])[0:1000]]))/2**20

1.9073486328125e-06

In [9]:
# Write each cluster's snippets to its own JSON file, at most 5000 per cluster.
# The loop variable is deliberately not called `snippets`: that name holds the
# full extraction result, and overwriting it here would corrupt a later re-run
# of the grouping cell.
for cluster_id, cluster_snippets in by_cluster.items():
    path = f'../web/public/data_pca/{cluster_id}.json'
    with open(path, 'w') as f:
        # Slicing stops at the end of shorter lists, so no explicit min() is
        # needed.
        json.dump([arr.tolist() for arr in cluster_snippets[0:5000]], f)
    

In [None]:
# Lookup table pairing a display name with the hex colour used to draw it.
# kernel_plot selects a row by its integer label and reads 'display_hex'.
colormap = pd.DataFrame(
    [
        ('Red', '#ff0a0a'),
        ('Green', '#0aff0a'),
        ('UV', '#0a0aff'),
        ('Blue', '#303030'),
        ('Stim', '#0a0a0a'),
    ],
    columns=['names', 'display_hex'],
)
def kernel_plot(kernel, channels=(1, 2)):
    """Build a Plotly line figure of selected columns of a kernel.

    Args:
        kernel: 2-D array indexed as ``kernel[:, c]``; one line is drawn per
            selected column ``c`` against the row index.
        channels: column indices to draw. Each index also selects the row of
            the module-level ``colormap`` table that supplies the line colour.
            The default, ``(1, 2)``, draws the same columns as the previously
            hard-coded ``range(1, 3)``.

    Returns:
        The configured ``go.Figure``.
    """
    fig = go.Figure()
    # NOTE(review): the x-axis title says the spike is at 0, but the x values
    # are plain row indices 0..n-1 -- no shift is applied here. Confirm which
    # of the two is intended.
    xs = np.arange(kernel.shape[0])
    for c in channels:
        fig.add_trace(go.Scatter(x=xs,
                                 y=kernel[:, c],
                                 line_color=colormap.loc[c, 'display_hex'],
                                 mode='lines'))
    fig.update_layout(autosize=False,
                      height=300,
                      margin=dict(l=1, r=1, b=1, t=25, pad=1),
                      yaxis_fixedrange=True,
                      showlegend=False,
                      title='Kernel',
                      title_x=0.5,
                      title_pad=dict(l=1, r=1, b=10, t=1),
                      xaxis={'title':'time (ms), with spike at 0'},
                      yaxis={'title':'summed responses'} )
    return fig

In [21]:
# Write the ids of all clusters currently in by_cluster as a JSON list.
path = '../web/public/data_pca/cluster_ids.json'
with open(path, 'w') as f:
    # Each key is passed through int() before encoding; presumably the keys
    # are NumPy integers, which json cannot serialise directly -- confirm.
    as_int_array = [int(cluster_key) for cluster_key in by_cluster.keys()]
    json.dump(as_int_array, f)

In [None]:
# Number of snippets in cluster 20. by_cluster is a defaultdict, so indexing
# an absent id would insert an empty cluster as a side effect; .get() only
# reads.
len(by_cluster.get(20, []))

In [None]:
# Scratch calculation: 5272 divided by 60*15 (the same seconds figure used for
# spike_limit).
# NOTE(review): presumably 5272 is one cluster's spike count, making this its
# mean rate in spikes per second -- confirm where the number comes from.
5272/(60*15)