In [1]:
import pathlib
import torch
import numpy as np
import retinapy
import retinapy.mea as mea
import retinapy.spikeprediction as sp
import torch.nn.functional as F
import matplotlib as mpl
import plotly
import plotly.graph_objects as go
import plotly.subplots as subplots
import math
import scipy
import torchinfo
import retinapy.vis
import retinapy._logging
from collections import defaultdict

In [2]:
def sample(ds, idx):
    """Return element ``idx`` of ``ds`` collated into a batch of size one.

    ``default_collate`` on a one-element list adds a leading batch dimension
    of 1 to every tensor, which is the form in which samples are passed to
    ``trainable.forward`` elsewhere in this notebook.

    Args:
        ds: any indexable dataset (supports ``ds[idx]``).
        idx: index of the element to fetch.

    Returns:
        The collated element (same container type as ``ds[idx]``, e.g. a
        dict of tensors), with a leading batch dimension of 1.
    """
    # Previously indexed with an undefined name (`bin_idx`) -> NameError.
    return torch.utils.data.dataloader.default_collate([ds[idx]])

In [3]:
# Input data for the recording analysed in this notebook.
stimulus_pattern_path = "../data/ff_noise.h5"
stimulus_rec_path = "../data/ff_recorded_noise.pickle"
response_path = "../data/ff_spike_response.pickle"
rec_name = "Chicken_17_08_21_Phase_00"
project_root = pathlib.Path("../")
# Earlier checkpoint, kept for reference:
#ckpt_path = "../out/exp/1/2/1/2_clusters/3/MultiClusterDistField-18ds_992in_50out/checkpoint_best_loss.pth"
# Checkpoint from the `all_vae_meanl1` experiment. Its directory name
# (18ds_992in_100out) matches the Configuration built further down.
model_ckpt = project_root / "out/exp/1/2/1/all_vae_meanl1/0/MultiClusterDistField-18ds_992in_100out/checkpoint_epoch-3.pth"
# Fail early with a readable message rather than deep inside model loading.
assert model_ckpt.exists(), f"Model checkpoint not found: {model_ckpt.resolve()}"

In [4]:
# Load the recording `rec_name`, built from the stimulus pattern, the recorded
# stimulus and the spike responses loaded from the paths configured above.
# NOTE(review): the .pickle inputs are presumably unpickled by the loaders —
# only point these paths at trusted files.
rec = mea.single_3brain_recording(
    rec_name,
    mea.load_stimulus_pattern(stimulus_pattern_path),
    mea.load_recorded_stimulus(stimulus_rec_path),
    mea.load_response(response_path),
    # Optionally restrict to a subset of clusters, e.g.:
    #include_clusters={21, 138}
    #include_clusters={21,}
)

In [5]:
# `load_model` lives in retinapy.models; import it explicitly instead of relying
# on another retinapy module having imported it as a side effect.
import retinapy.models

# Model configuration; matches the "18ds_992in_100out" checkpoint directory.
config = sp.Configuration(
    downsample=18, input_len=992, output_len=100)
# `t` is the trainable (model + train/val datasets) used by the rest of the notebook.
t = sp.MultiClusterDistFieldTGroup.create_trainable([rec], config)
retinapy.models.load_model(t.model, model_ckpt)
t.model.cuda()
# Inference only; the trailing `;` suppresses the module repr in the output.
t.model.eval();

In [6]:
def create_scatters(xs, model_out, dist, showlegend):
    """Build the two line traces comparing a prediction with its target.

    Args:
        xs: x coordinates shared by both traces.
        model_out: 1-D model output for a single (un-batched) sample.
        dist: 1-D target with the same shape as ``model_out``.
        showlegend: whether the traces get legend entries (e.g. only for the
            first subplot row so the legend isn't repeated).

    Returns:
        ``[prediction_trace, actual_trace]`` as ``go.Scatter`` objects.
    """
    assert model_out.shape == dist.shape
    assert len(model_out.shape) == 1, "Only support single sample (no batch)."
    pred_trace = go.Scatter(
        x=xs, y=model_out, name='Pred', legendgroup='g1',
        line_color='gray', showlegend=showlegend)
    actual_trace = go.Scatter(
        x=xs, y=dist, name='Actual', legendgroup='g3',
        line_color='red', showlegend=showlegend)
    return [pred_trace, actual_trace]

ds = t.val_ds
stride = ds.datasets[0].ds.stride
sample_rate = ds.datasets[0].ds.sample_rate
row_len = ds.datasets[0].ds.num_strided_timesteps
cluster_idx = 4  # 4 = 21 (cluster index 4 is cluster id 21, per original note)


def to_bin(ms, dataset=None, cluster=None):
    """Map a time in ms for one cluster to a flat index into ``dataset``.

    The index is ``cluster * num_strided_timesteps + round(timestep / stride)``,
    where ``timestep = ms * sample_rate / 1000``. All three quantities are read
    from ``dataset.datasets[0].ds``, so the index is only valid for the dataset
    passed in. NOTE(review): this assumes the dataset lays out all strided
    timesteps of cluster 0, then cluster 1, ... — confirm against its
    ``_decode_index``.

    Args:
        ms: time in milliseconds.
        dataset: concatenated dataset the index will be used with. Defaults
            to the module-level ``ds`` (the validation set).
        cluster: cluster index (position, not cluster id). Defaults to the
            module-level ``cluster_idx``, read at call time.

    Returns:
        int index into ``dataset``.
    """
    if dataset is None:
        dataset = ds
    if cluster is None:
        cluster = cluster_idx
    inner = dataset.datasets[0].ds
    timestep = ms * inner.sample_rate / 1000
    return cluster * inner.num_strided_timesteps + round(timestep / inner.stride)


def debug(trainable, time_ms):
    """Plot predicted vs. actual model output at several start times.

    One subplot row per entry of ``time_ms``. Each row shows the model output
    and the target (``trainable.distfield_to_nn_output`` of the sample's
    ``'dist'``) for the training sample of cluster ``cluster_idx`` starting at
    that time.

    Args:
        trainable: trainable providing ``train_ds``, ``forward`` and
            ``distfield_to_nn_output``.
        time_ms: sequence of start times in ms.
    """
    train_ds = trainable.train_ds
    fig = subplots.make_subplots(rows=len(time_ms), cols=1, shared_xaxes=False,
                                 vertical_spacing=0.01)
    for row, start_ms in enumerate(time_ms, start=1):
        # Compute the index against the dataset we actually read from: the
        # train and val sets may differ in length, so an index built from the
        # val set would point at the wrong cluster/time in the train set.
        bin_idx = to_bin(start_ms, dataset=train_ds)
        batch = torch.utils.data.dataloader.default_collate([train_ds[bin_idx]])
        with torch.no_grad():
            m_dist, _ = trainable.forward(batch)
        m_dist = m_dist.detach().cpu().numpy().squeeze()
        target = trainable.distfield_to_nn_output(batch['dist']).numpy().squeeze()
        # NOTE(review): x axis = start time + output-bin index, i.e. it treats
        # one output bin as ~1 ms — confirm against the downsampled rate.
        xs = np.arange(start_ms, start_ms + len(m_dist))
        for trace in create_scatters(xs, m_dist, target, showlegend=(row == 1)):
            fig.add_trace(trace, row=row, col=1)
    fig.update_layout({
        "margin": {"l": 0, "r": 0, "t": 0, "b": 0, "pad": 0},
        "autosize": False,
        "height": 200 * len(time_ms),
        "width": 1000,
        #"yaxis_range":[0,150],
        "yaxis_fixedrange": False,
    })
    #fig.update_yaxes(range=[-4,5])
    fig.show()
    
#debug(t, [i for i in range(0, 200, 6)])

## Inspect the encoding space

In [7]:
# Spike-snippet geometry used for the per-cluster kernels. The plotted spike
# time is t_0 = snippet_len - snippet_pad, so snippet_pad is presumably the
# number of samples kept after the spike.
snippet_len, snippet_pad = 400, 10

In [8]:
def calc_kernels(rec, snippet_len=snippet_len, snippet_pad=snippet_pad,
                 downsample=None):
    """Compute the mean spike-triggered snippet ("kernel") for each cluster.

    Args:
        rec: recording passed to ``mea.labeled_spike_snippets``.
        snippet_len: snippet length. Defaults to the notebook-level
            ``snippet_len`` (bound at definition time), so the kernels agree
            with the ``t_0`` used when plotting them. Previously this was
            re-hard-coded here and could silently drift from the global.
        snippet_pad: snippet padding. Defaults to the notebook-level
            ``snippet_pad``.
        downsample: downsample factor. Defaults to ``config.downsample``,
            read at call time.

    Returns:
        dict mapping cluster id -> mean over that cluster's stacked snippets
        (``np.mean(..., axis=0)``). Clusters with no snippets are absent.
    """
    if downsample is None:
        downsample = config.downsample
    snippets, cluster_ids = mea.labeled_spike_snippets(
        rec, snippet_len, snippet_pad, downsample=downsample)
    by_cluster = defaultdict(list)
    for snippet, c_id in zip(snippets, cluster_ids):
        by_cluster[c_id].append(snippet)
    return {c_id: np.mean(np.stack(snips), axis=0)
            for c_id, snips in by_cluster.items()}
kernels = calc_kernels(rec)

In [9]:
def all_kernel_figs(c_id_to_kernel, snippet_len=snippet_len,
                    snippet_pad=snippet_pad, bin_duration_ms=1000/992):
    """Make a small (200x200 px), axis-free kernel figure for every cluster.

    Args:
        c_id_to_kernel: dict mapping cluster id -> kernel array (as produced
            by ``calc_kernels``).
        snippet_len: snippet length the kernels were computed with.
        snippet_pad: snippet padding the kernels were computed with. The
            spike time is placed at ``t_0 = snippet_len - snippet_pad``.
        bin_duration_ms: duration of one kernel sample in ms, forwarded to
            ``retinapy.vis.kernel``. NOTE(review): the default 1000/992 appears
            to assume a ~992 Hz downsampled rate (coincidentally equal to
            input_len) — confirm.

    Returns:
        dict mapping cluster id -> the figure from ``retinapy.vis.kernel``,
        with the y axis hidden and the x-axis title and figure title removed.
    """
    t_0 = snippet_len - snippet_pad  # identical for every cluster
    figs = {}
    for c_id, k in c_id_to_kernel.items():
        fig = retinapy.vis.kernel(k, t_0=t_0, bin_duration_ms=bin_duration_ms)
        fig.update_layout({
            "yaxis": {"visible": False},
            "xaxis": {"title": None},
            "title": None,
            "width": 200,
            "height": 200,
        })
        figs[c_id] = fig
    return figs
k_fig = all_kernel_figs(kernels)

In [21]:
import dash
import jupyter_dash
# Cluster id -> per-dimension latent standard deviation (from the encoder's
# log-variance). Filled as a side effect of `latent_fig` and read by the hover
# callback to size the uncertainty ellipse.
c_to_sd = {}
def latent_fig(trainable, cluster_idxs=None, axis_range=(-3, 3)):
    """Scatter plot of each cluster's latent encoding (first two dimensions).

    For each cluster index, ``trainable.model.encode`` is called, and the
    first two components of the returned ``z`` are plotted and labelled with
    the cluster id. As a side effect, ``exp(0.5 * z_logvar)`` (the standard
    deviation) is stored in the global ``c_to_sd`` under the cluster id.

    Args:
        trainable: trainable whose ``model`` and
            ``train_ds.datasets[0].recording`` are used.
        cluster_idxs: cluster indices (positions in ``rec.cluster_ids``) to
            plot. Defaults to all clusters.
        axis_range: (min, max) range applied to both axes.

    Returns:
        A ``go.Figure``. Each point's ``customdata`` holds its cluster id, so
        hover handlers don't have to rely on point order (which only matches
        ``rec.cluster_ids`` when every cluster is plotted).
    """
    rec = trainable.train_ds.datasets[0].recording
    if cluster_idxs is None:
        cluster_idxs = np.arange(len(rec.cluster_ids))
    ids = [rec.cluster_ids[i] for i in cluster_idxs]
    # Use the model's own device rather than assuming CUDA.
    device = next(trainable.model.parameters()).device
    zs = []
    with torch.no_grad():
        for cidx, c_id in zip(cluster_idxs, ids):
            # NOTE(review): the first argument is a constant 0 tensor;
            # presumably a placeholder input — confirm against `encode`.
            _, z, z_logvar = trainable.model.encode(
                torch.Tensor([0]).to(device), torch.Tensor([cidx]).to(device))
            # sqrt(exp(logvar)) == exp(logvar / 2)
            c_to_sd[c_id] = torch.exp(0.5 * z_logvar).cpu().numpy().squeeze()
            zs.append(z.cpu().numpy().squeeze())
    zs = np.array(zs)
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=zs[:, 0],
        y=zs[:, 1],
        text=[str(i) for i in ids],
        customdata=ids,
        textposition="bottom center",
        mode='markers+text',
    ))
    fig.update_layout(
        xaxis={"range": list(axis_range)},
        yaxis={"range": list(axis_range)},
        height=1000,
        width=1000)
    fig.update_traces(
        hoverinfo="none",
        hovertemplate=None)
    return fig

# Latent-space figure for all clusters; building it also populates `c_to_sd`.
# The hover handler updates this same figure object in place.
fig = latent_fig(t)
#fig.add_shape(type="circle",
#    xref="x", yref="y",
#    fillcolor="PaleTurquoise",
#    x0=0, y0=0, x1=1, y1=1,
#    line_color="LightSeaGreen",
#)    

def _display_hover(hoverData):
    """Build the tooltip and uncertainty ellipse for a hovered latent point.

    Args:
        hoverData: Dash ``hoverData`` dict from the latent scatter; only the
            first entry of ``hoverData["points"]`` is used.

    Returns:
        ``(show, bbox, children, figure)`` for the tooltip's show/bbox/children
        and the graph's figure. The tooltip contains the cluster's kernel
        figure (or a short message if it has none). The global ``fig`` gets
        an ellipse centred on the point whose width/height are 3 standard
        deviations per axis, each capped at 10.
    """
    rec = t.train_ds.datasets[0].recording
    pt = hoverData["points"][0]
    # Prefer the cluster id attached to the point. Fall back to point order,
    # which is only correct when all clusters were plotted in order.
    if "customdata" in pt:
        c_id = pt["customdata"]
    else:
        c_id = rec.cluster_ids[pt["pointNumber"]]
    if c_id in k_fig:
        children = [dash.dcc.Graph(figure=k_fig[c_id])]
    else:
        # Kernels exist only for clusters that had spike snippets; avoid a
        # KeyError inside the callback for the others.
        children = [dash.html.Div(f"No kernel for cluster {c_id}")]
    max_size = 10  # keep a very uncertain cluster from covering the plot
    sd = c_to_sd[c_id]
    w = min(max_size, sd[0] * 3)
    h = min(max_size, sd[1] * 3)
    x, y = pt["x"], pt["y"]
    fig.update_layout(
        shapes=[{"type": "circle",
                 "x0": x - w/2, "x1": x + w/2,
                 "y0": y - h/2, "y1": y + h/2,
                 "opacity": 0.4, "line_color": "black", "fillcolor": "orange"}],
        # Preserve user zoom/pan across figure updates.
        uirevision=True)
    return True, pt["bbox"], children, fig

In [12]:
app = jupyter_dash.JupyterDash(__name__)
app.layout = dash.html.Div(
    children=[
        dash.dcc.Graph(id="graph", figure=fig, clear_on_unhover=True),
        dash.dcc.Tooltip(id="graph-tooltip", direction="bottom"),
    ],
    className="container",
)


# Hovering a latent point shows its kernel in a tooltip and draws its
# uncertainty ellipse. With `clear_on_unhover`, hoverData is cleared (None)
# when the pointer leaves a point.
@app.callback(
    dash.Output("graph-tooltip", "show"),
    dash.Output("graph-tooltip", "bbox"),
    dash.Output("graph-tooltip", "children"),
    dash.Output("graph", "figure"),
    dash.Input("graph", "hoverData"),
)
def display_hover(hoverData):
    """Dash callback: delegate to `_display_hover`, or hide the tooltip on unhover."""
    if hoverData is not None:
        return _display_hover(hoverData)
    # Nothing hovered: hide the tooltip and leave the other outputs untouched.
    return (False,) + (dash.no_update,) * 3


In [13]:
# Remove `requests_pathname_prefix` from Dash's read-only config entries,
# presumably so JupyterDash can set it. `pop(..., None)` rather than `del`
# keeps this cell re-runnable: a second `del` raises KeyError.
app.config._read_only.pop("requests_pathname_prefix", None)
#app.kernel.do_shutdown(True)
# NOTE(review): debug=True enables the Dash/Flask debug tooling, and
# host="0.0.0.0" listens on every interface, so anyone who can reach port 8050
# gets that server. Prefer "127.0.0.1" unless direct remote access is needed.
app.run_server(mode="jupyterlab", debug=True,
               dev_tools_ui=True, host="0.0.0.0", dev_tools_hot_reload=True, port=8050)

0.373,0.206, 1.2, 1.5, 2.3, 2.5
0.204,0.207, 0.2, 0.4, 1.6, 1.8
0.269,0.168, -0.2, 0.0, 1.4, 1.6
0.252,0.164, -0.3, -0.1, 1.4, 1.5
0.254,0.180, -0.7, -0.4, 1.9, 2.1
0.316,0.224, -0.6, -0.3, 1.9, 2.2
0.254,0.180, -0.7, -0.4, 1.9, 2.1
0.099,0.247, -1.1, -1.0, 2.1, 2.3
0.092,11.173, -2.5, -2.4, -3.8, 7.4
0.112,0.154, -2.4, -2.3, 1.9, 2.0
0.098,0.403, -1.4, -1.3, 1.2, 1.6
0.527,0.172, -0.7, -0.1, 0.4, 0.6
0.204,0.207, 0.2, 0.4, 1.6, 1.8
0.881,0.466, -0.4, 0.5, 2.7, 3.2
0.733,0.424, -0.4, 0.3, 2.8, 3.2
0.425,0.394, 2.1, 2.5, 2.1, 2.5
0.529,0.227, -0.6, -0.1, -0.8, -0.6
0.529,0.227, -0.6, -0.1, -0.8, -0.6
0.197,0.358, -1.5, -1.3, -0.2, 0.1
0.173,0.505, -1.5, -1.4, -0.4, 0.1
0.115,0.602, -2.0, -1.9, 0.1, 0.7
0.527,0.172, -0.7, -0.1, 0.4, 0.6
0.527,0.172, -0.7, -0.1, 0.4, 0.6
0.196,0.321, -0.9, -0.7, 0.4, 0.7
0.197,0.358, -1.5, -1.3, -0.2, 0.1
0.977,0.119, 0.6, 1.6, -0.4, -0.3
0.408,0.080, 1.0, 1.4, -0.3, -0.2
0.402,0.058, 1.3, 1.7, -0.0, 0.0
0.302,0.064, 1.5, 1.8, 0.0, 0.1
0.402,0.058, 1.3, 1