In [None]:
import retinapy.mea as mea
import numpy as np
import pandas as pd
import json
from collections import defaultdict
import plotly.express as px
import plotly.graph_objects as go
import plotly.subplots
from sklearn.decomposition import PCA
import pathlib

In [None]:
# Load recordings and prepare the snippet-viewer output directory.
project_root = pathlib.Path("../")
web_dir = project_root / "snippet_viewer"
data_dir = project_root / "data/ff_noise_recordings"
out_dir = web_dir / 'resources' / 'snippets'
# Create intermediate directories as well: a bare mkdir() raises
# FileNotFoundError when snippet_viewer/resources does not exist yet.
out_dir.mkdir(parents=True, exist_ok=True)
recs = mea.load_3brain_recordings(data_dir)

In [None]:
# Export every snippet of every cluster: one JSON file per (recording, cluster),
# written to out_dir/<recording name>/<cluster id>.json.
for rec in recs:
    snippets, cluster_ids = mea.labeled_spike_snippets(rec,
                                          snippet_len=90,
                                          snippet_pad=10,
                                          downsample=180)
    # snippets[i] is the snippet of spike i and cluster_ids[i] its cluster (the
    # by-cluster grouping cell indexes them this way). Group before writing:
    # writing inside the per-spike loop overwrote each cluster's file with
    # only its last spike's snippet.
    rec_clusters = defaultdict(list)
    for snip, cluster_id in zip(snippets, cluster_ids):
        rec_clusters[cluster_id].append(snip)
    rec_dir = out_dir / rec.name
    rec_dir.mkdir(parents=True, exist_ok=True)
    for cluster_id, cluster_snips in rec_clusters.items():
        with open(rec_dir / f'{cluster_id}.json', 'w') as f:
            json.dump([arr.tolist() for arr in cluster_snips], f)
# NOTE: `snippets` and `cluster_ids` intentionally remain bound to the last
# recording; the next cell reads them.

In [None]:
# Group per-spike snippets by cluster id, then drop clusters whose spike count
# exceeds `spike_limit`.
# NOTE(review): `snippets` and `cluster_ids` are whatever the export loop left
# behind, i.e. only the last recording in `recs` — confirm this is intended.
assert len(snippets) == len(cluster_ids), 'expected one cluster id per snippet'
by_cluster = defaultdict(list)
for cluster_id, snippet in zip(cluster_ids, snippets):
    by_cluster[cluster_id].append(snippet)
# Debugging leftover: number of spikes assigned to cluster 13.
num_13 = sum(1 for cluster_id in cluster_ids if cluster_id == 13)

# Spike-count ceiling: a mean rate of MAX_RATE_HZ sustained over the recording.
MAX_RATE_HZ = 19
RECORDING_DURATION_S = 60 * 15  # assumes 15-minute recordings — TODO confirm
spike_limit = MAX_RATE_HZ * RECORDING_DURATION_S
print(spike_limit)
print(len(by_cluster))
too_active = [c_id for c_id, snips in by_cluster.items() if len(snips) > spike_limit]
for c_id in too_active:
    by_cluster.pop(c_id)
print(f'Remaining clusters: {len(by_cluster)}.')
for snips in by_cluster.values():
    print(len(snips), end=', ')

In [None]:
def sort_snippets(snippets_):
    """Order snippets along their first principal component.

    Each snippet is flattened to a vector, a 1-component PCA is fitted over all
    of them, and the snippets are returned sorted by their projection onto that
    component (ascending; ties keep their input order). The sign of a PCA
    component is arbitrary, so the overall direction of the ordering is too.

    Args:
        snippets_: iterable of numpy arrays. Their flattened lengths presumably
            must match so they stack into a 2D matrix.

    Returns:
        A new list with the same snippet objects, reordered. Inputs with fewer
        than two snippets are returned as-is (as a list): there is nothing to
        order, and PCA cannot be fitted on zero samples.
    """
    snippet_list = list(snippets_)
    if len(snippet_list) < 2:
        return snippet_list
    pca = PCA(n_components=1)
    flattened_snippets = np.array([s.reshape(-1) for s in snippet_list])
    proj = pca.fit_transform(flattened_snippets)[:, 0]
    # Stable sort, matching the stability of the original list.sort.
    order = np.argsort(proj, kind='stable')
    return [snippet_list[i] for i in order]

# Optional: reorder each cluster's snippets along their first principal
# component (see sort_snippets). Disabled by default.
do_sort = False
if do_sort:
    for k, v in list(by_cluster.items()):
        by_cluster[k] = sort_snippets(v)

In [None]:
# Rough size in MiB of the JSON for the first 1000 snippets of cluster 20.
# `.get` avoids inserting an empty cluster 20 into the defaultdict when it was
# filtered out (it would otherwise be exported as an empty file and listed in
# cluster_ids.json).
len(json.dumps([arr.tolist() for arr in by_cluster.get(20, [])[0:1000]])) / 2**20

In [None]:
# Export each remaining cluster's snippets (capped) to out_dir/<cluster id>.json.
MAX_SNIPPETS_PER_CLUSTER = 5000
out_dir.mkdir(parents=True, exist_ok=True)
# Loop variable deliberately not named `snippets`: that would clobber the
# module-level per-spike `snippets` that the grouping cell reads on re-run.
for cluster_id, cluster_snippets in by_cluster.items():
    path = out_dir / f'{cluster_id}.json'
    with open(path, 'w') as f:
        # Slicing already clamps to the list length.
        json.dump([arr.tolist() for arr in cluster_snippets[:MAX_SNIPPETS_PER_CLUSTER]], f)
    

In [None]:
# Display name and plot colour per channel, indexed by channel number.
colormap = pd.DataFrame({
    'names':['Red', 'Green', 'UV', 'Blue', 'Stim'],
    'display_hex':['#ff0a0a', '#0aff0a', '#0a0aff', '#303030', '#0a0a0a']})


def kernel_plot(kernel, channels=(1, 2), x_offset=0):
    """Plot selected columns of a 2D kernel as coloured line traces.

    Args:
        kernel: 2D array; the first axis is plotted along x and each selected
            column kernel[:, c] becomes one trace coloured by colormap row c.
        channels: column indices to draw. Defaults to (1, 2) (Green and UV),
            the same as the original hard-coded range(1, 3).
        x_offset: subtracted from the sample index to form x values. Pass the
            spike's sample index to put the spike at x=0, as the axis label
            describes. Default 0 keeps raw sample indices (original behaviour).

    Returns:
        A plotly ``go.Figure``.
    """
    fig = go.Figure()
    # NOTE(review): x values are sample indices (minus x_offset); the "ms"
    # axis label only holds if one sample is 1 ms — confirm.
    xs = np.arange(kernel.shape[0]) - x_offset
    for c in channels:
        fig.add_trace(go.Scatter(x=xs,
                                 y=kernel[:, c],
                                 line_color=colormap.loc[c, 'display_hex'],
                                 mode='lines'))
    fig.update_layout(autosize=False,
                      height=300,
                      margin=dict(l=1, r=1, b=1, t=25, pad=1),
                      yaxis_fixedrange=True,
                      showlegend=False,
                      title='Kernel',
                      title_x=0.5,
                      title_pad=dict(l=1, r=1, b=10, t=1),
                      xaxis={'title':'time (ms), with spike at 0'},
                      yaxis={'title':'summed responses'} )
    return fig

In [None]:
# Write the exported cluster ids as plain Python ints so they are JSON-serialisable.
path = out_dir / 'cluster_ids.json'
with open(path, 'w') as f:
    as_int_array = [int(key) for key in by_cluster]
    json.dump(as_int_array, f)

In [None]:
# Number of snippets kept for cluster 20 (0 if it was filtered out). `.get`
# avoids silently inserting an empty cluster 20 into the defaultdict.
len(by_cluster.get(20, []))

In [None]:
# Scratch arithmetic: mean rate (spikes/s) of 5272 spikes over 15 minutes
# (900 s); presumably the cluster-20 count from the cell above.
5272 / (15 * 60)