In [None]:
import math
import pathlib
from collections import defaultdict

import matplotlib as mpl
import numpy as np
import plotly
import plotly.express
import plotly.graph_objects as go
import plotly.subplots as subplots
import scipy
import torch
import torch.nn.functional as F
import torchinfo

import retinapy
import retinapy._logging
import retinapy.mea as mea
import retinapy.models
import retinapy.spikeprediction as sp
import retinapy.vis

In [None]:
def sample(ds, idx):
    """Fetch item `idx` from `ds` and collate it into a batch of size 1.

    Reminder of how to get a model-ready sample: `default_collate` is applied
    to a one-element list, so tensors/numbers gain a leading batch dimension
    of 1 and container structure is preserved (dict items stay dicts, tuple
    items become lists).

    Args:
        ds: any indexable dataset.
        idx: index of the item to fetch.

    Returns:
        The collated single-item batch.
    """
    return torch.utils.data.dataloader.default_collate([ds[idx]])

In [None]:
stimulus_pattern_path = "../data/ff_noise.h5"
stimulus_rec_path = "../data/ff_recorded_noise.pickle"
response_path = "../data/ff_spike_response.pickle"
rec_name = "Chicken_17_08_21_Phase_00"
project_root = pathlib.Path("../")

# Experiment run to inspect. Other runs that have been looked at:
#   out/exp/1/2/2/vae_1rec_bs128_z2zo10_b10e-4/2
#   out/exp/1/2/2/vae_bs128_z2zo20_b10e-4/0
out_dir = project_root / "out/exp/1/2/2/vae_bs256_z2d_b25e-4/0"
model_ckpt = out_dir / "MultiClusterDistField-18ds_992in_100out/recovery.pth"
arg_file = out_dir / "args.yaml"

# Fail fast with a readable message before parsing/loading anything.
assert arg_file.exists(), f"Missing training args file: {arg_file.resolve()}"
assert model_ckpt.exists(), f"Missing model checkpoint: {model_ckpt.resolve()}"
opt = sp.args_from_yaml(arg_file)

In [None]:
def load_all_recs():
    """Load every 3Brain recording from the configured data paths.

    Recordings named in the skip set are dropped; Chicken_21_08_21_Phase_00 is
    excluded because it uses a different sample rate.

    Returns:
        List of recordings from `mea.load_3brain_recordings`, minus the skipped ones.
    """
    excluded = {"Chicken_21_08_21_Phase_00"}  # different sample rate
    loaded = mea.load_3brain_recordings(
        stimulus_pattern_path, stimulus_rec_path, response_path
    )
    return [rec for rec in loaded if rec.name not in excluded]

def load_rec():
    """Load only the recording named by the global `rec_name`.

    Returns:
        A one-element list, so callers can use it interchangeably with
        `load_all_recs()`.
    """
    stimulus = mea.load_stimulus_pattern(stimulus_pattern_path)
    recorded_stimulus = mea.load_recorded_stimulus(stimulus_rec_path)
    response = mea.load_response(response_path)
    # To restrict to specific clusters, pass e.g. include_clusters={21, 138}.
    single = mea.single_3brain_recording(
        rec_name, stimulus, recorded_stimulus, response
    )
    return [single]

# Load all usable recordings, but decompress (downsampled by 18) only the
# first three; `dc_recs` feeds the kernel-plot generation cell.
recs = load_all_recs()
dc_recs = mea.decompress_recordings(
    recs[:3],
    downsample=18,
    num_workers=23,
)

In [None]:
# Render kernel plots for the decompressed recordings into
# ../resources/kernel_plots (the hover-tooltip cell later opens that directory
# via retinapy.vis.KernelPlots).
retinapy.vis.KernelPlots.generate(
    dc_recs,
    snippet_len=1000,
    snippet_pad=100,
    out_dir="../resources/kernel_plots",
    mini=True,
    num_workers=30,
)

In [None]:
# Build the trainable with the same geometry as the checkpoint directory name
# (18x downsample, 992-step input, 100-step output), load the saved weights,
# and put the model on the GPU in eval mode.
config = sp.Configuration(downsample=18, input_len=992, output_len=100)
t = sp.MultiClusterDistFieldTGroup.create_trainable(recs, config, opt=opt)
retinapy.models.load_model(t.model, model_ckpt)
t.model.cuda()
t.model.eval();

In [None]:
def create_scatters(xs, model_out, dist, showlegend):
    """Build the predicted-vs-actual line traces for one subplot row.

    Args:
        xs: x coordinates shared by both traces.
        model_out: 1-D model output (no batch dimension).
        dist: 1-D target, same shape as `model_out`.
        showlegend: whether these traces add entries to the legend.

    Returns:
        [prediction trace (gray), target trace (red)] as `go.Scatter` objects.

    Raises:
        AssertionError: if the shapes differ or the inputs are not 1-D.
    """
    assert model_out.shape == dist.shape
    assert len(model_out.shape) == 1, "Only support single sample (no batch)."
    pred_trace = go.Scatter(
        x=xs, y=model_out, name='Pred', line_color='gray',
        legendgroup='g1', showlegend=showlegend)
    actual_trace = go.Scatter(
        x=xs, y=dist, name='Actual', line_color='red',
        legendgroup='g3', showlegend=showlegend)
    return [pred_trace, actual_trace]

ds = t.val_ds
# Geometry of the first recording's validation split. to_bin() turns a time
# into a flat index of `ds` as cluster_idx * row_len + strided step.
stride = ds.datasets[0].ds.stride
sample_rate = ds.datasets[0].ds.sample_rate
row_len = ds.datasets[0].ds.num_strided_timesteps
cluster_idx = 4  # 4 = 21 (presumably cluster index 4 is cluster id 21 — confirm)


def to_bin(ms):
    """Convert a time in milliseconds into a flat index into `ds` (= t.val_ds).

    Uses the module-level `sample_rate`, `stride`, `row_len` and `cluster_idx`,
    all taken from `ds.datasets[0].ds`, so the index is only meaningful for
    `ds`, not for other dataset splits.

    Raises:
        IndexError: if `ms` maps to a strided step outside [0, row_len) (the
            flat index would then address a different cluster's row), or if
            the resulting index is past the end of `ds`.
    """
    step = round(ms * sample_rate / 1000 / stride)
    if not 0 <= step < row_len:
        raise IndexError(
            f"{ms} ms maps to step {step}, outside [0, {row_len}) "
            f"for cluster index {cluster_idx}")
    b = cluster_idx * row_len + step
    # Cheap bounds check; avoids fetching (and discarding) a whole sample just
    # to validate the index.
    if b >= len(ds):
        raise IndexError(f"bin {b} is past the end of the dataset (len {len(ds)})")
    return b

def debug(trainable, time_ms):
    """Plot model output vs. target distance field at several start times.

    Draws one subplot row per entry in `time_ms`. Each row shows the model's
    prediction (gray) and `trainable.distfield_to_nn_output` of the sample's
    'dist' (red), for the sample whose index `to_bin(start_ms)` returns.

    Args:
        trainable: object exposing `val_ds`, `forward` and
            `distfield_to_nn_output`.
        time_ms: iterable of start times in milliseconds, one subplot each.
    """
    # to_bin() computes indices from the validation set's geometry (module-level
    # `ds = t.val_ds`), so index that same split. Indexing train_ds with those
    # indices lands on a different cluster/time.
    eval_ds = trainable.val_ds
    time_ms = list(time_ms)
    fig = subplots.make_subplots(rows=len(time_ms), cols=1, shared_xaxes=False,
                                 vertical_spacing=0.01)
    for row, start_ms in enumerate(time_ms, start=1):
        # Named `batch` to avoid shadowing the `sample` helper defined earlier.
        batch = torch.utils.data.dataloader.default_collate(
            [eval_ds[to_bin(start_ms)]])
        with torch.no_grad():
            m_dist, _ = trainable.forward(batch)
        m_dist = m_dist.cpu().numpy().squeeze()
        target = trainable.distfield_to_nn_output(batch['dist']).numpy().squeeze()
        # One x value per output step, starting at the requested time.
        xs = np.arange(start_ms, start_ms + len(m_dist))
        for trace in create_scatters(xs, m_dist, target, showlegend=(row == 1)):
            fig.add_trace(trace, row=row, col=1)
    fig.update_layout({
        "margin": {"l": 0, "r": 0, "t": 0, "b": 0, "pad": 0},
        "autosize": False,
        "height": 200 * len(time_ms),
        "width": 1000,
        "yaxis_fixedrange": False,
    })
    fig.show()


debug(t, list(range(0, 200, 6)))

## Inspect the VAE encoding (latent) space

In [None]:
# Snippet window settings (units not stated here — presumably timesteps; confirm).
# NOTE(review): not referenced by any cell shown here; the kernel-plot cell
# passes its own literals instead.
snippet_len, snippet_pad = 400, 10

In [None]:
def count_spikes(trainable):
    """Count spikes per (recording index, cluster index) in the validation set.

    Args:
        trainable: object whose `val_ds.datasets` each expose a `.recording`
            with `cluster_ids` and a `spikes` array indexed as [:, cluster_idx].

    Returns:
        dict mapping (r_idx, c_idx) -> np.count_nonzero of that cluster's
        column of `recording.spikes`. NOTE(review): this counts nonzero bins;
        if a bin can hold more than one spike, it undercounts.
    """
    counts = {}
    for r_idx, rec_ds in enumerate(trainable.val_ds.datasets):
        rec = rec_ds.recording
        for c_idx, _c_id in enumerate(rec.cluster_ids):
            counts[(r_idx, c_idx)] = np.count_nonzero(rec.spikes[:, c_idx])
    return counts


# Keep the function and its result under different names, so re-running this
# cell doesn't fail with "'dict' object is not callable".
spike_counts = count_spikes(t)

In [None]:
import dash
import jupyter_dash

# Filled by latent_fig() in the same order as the scatter points, so the hover
# callback can look entries up by Plotly's `pointNumber`.
rc_to_sd = {}  # (r_idx, c_idx) -> sqrt(exp(z_logvar)) for that cluster
sd_list = []   # same values, one per plotted point, in plot order


def latent_fig(trainable):
    """Scatter the latent `z` of every (recording, cluster) in the val set.

    For each cluster this calls `trainable.model.encode_vae` and plots the
    first two components of the returned `z`. Points are coloured by
    log(1 + spike count), using the global `spike_counts`. As a side effect it
    refills the globals `rc_to_sd` and `sd_list`.

    Returns:
        A plotly `go.Figure`. Hover info is disabled because the Dash tooltip
        supplies it.
    """
    # Reset in place so repeated calls don't accumulate entries from a
    # previous call.
    rc_to_sd.clear()
    sd_list.clear()
    zs = []
    counts = []
    with torch.no_grad():
        for r_idx, rec_ds in enumerate(trainable.val_ds.datasets):
            for c_idx, _c_id in enumerate(rec_ds.recording.cluster_ids):
                _, z, z_logvar = trainable.model.encode_vae(
                    torch.Tensor([r_idx]).cuda(), torch.Tensor([c_idx]).cuda())
                sd = torch.sqrt(torch.exp(z_logvar)).cpu().numpy().squeeze()
                rc_to_sd[(r_idx, c_idx)] = sd
                sd_list.append(sd)
                zs.append(z.cpu().numpy().squeeze())
                counts.append(spike_counts[(r_idx, c_idx)])
    zs = np.array(zs)
    # log1p rather than log: clusters with zero spikes would otherwise be -inf.
    colors = np.log1p(np.array(counts))
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=zs[:, 0],
        y=zs[:, 1],
        mode='markers',
        marker={"color": colors, "colorscale": "Blues"},
    ))
    fig.update_layout(
        xaxis={"range": [-3, 3], "title": "z[0]"},
        yaxis={"range": [-3, 3], "title": "z[1]"},
        height=1000,
        width=800)
    fig.update_traces(
        hoverinfo="none",
        hovertemplate=None)
    return fig


fig = latent_fig(t)
fig

In [None]:
# NOTE(review): `kernel_plots` is not yet used by any cell shown here.
kernel_plots = retinapy.vis.KernelPlots('../resources/kernel_plots')


def _display_hover(hoverData):
    """Build the Dash tooltip for the hovered point of the latent scatter.

    Args:
        hoverData: Dash `hoverData` payload from the "graph" component; only
            the first entry of `hoverData["points"]` is used.

    Returns:
        (show, bbox, children) for the "graph-tooltip" outputs.
    """
    pt = hoverData["points"][0]
    pt_num = pt["pointNumber"]
    # NOTE(review): `k_fig_list` (per-point kernel figures) is not defined in
    # any cell shown here, so a bare reference raised NameError on every hover.
    # Use it if some cell defines it; otherwise label the point with its
    # (recording, cluster) indices. rc_to_sd keys are in plot order.
    k_figs = globals().get("k_fig_list")
    if k_figs is not None:
        # TODO: children probably needs to be an image (maybe b64 encoded).
        children = [dash.dcc.Graph(figure=k_figs[pt_num])]
    else:
        r_idx, c_idx = list(rc_to_sd)[pt_num]
        children = [dash.html.P(f"recording {r_idx}, cluster index {c_idx}")]
    return True, pt["bbox"], children

In [None]:
# Dash app: the latent scatter plus a tooltip that is filled in on hover.
app = jupyter_dash.JupyterDash(__name__)
app.layout = dash.html.Div(
    children=[
        dash.dcc.Graph(id="graph", figure=fig, clear_on_unhover=True),
        dash.dcc.Tooltip(id="graph-tooltip", direction="bottom"),
    ],
    className="container",
)


@app.callback(
    dash.Output("graph-tooltip", "show"),
    dash.Output("graph-tooltip", "bbox"),
    dash.Output("graph-tooltip", "children"),
    dash.Input("graph", "hoverData"),
)
def display_hover(hoverData):
    """Show the tooltip for the hovered point; hide it when nothing is hovered."""
    if hoverData is not None:
        return _display_hover(hoverData)
    # No hover (e.g. the cursor left the plot): hide, leave bbox/children as-is.
    return False, dash.no_update, dash.no_update


In [None]:
# Serve the app inside JupyterLab on port 8050.
# NOTE(review): host="0.0.0.0" together with debug/dev tools makes the app
# reachable from every network interface of this machine; prefer "127.0.0.1"
# unless remote access is actually needed.
# Workaround that has been tried here if the path prefix causes trouble:
#   del app.config._read_only["requests_pathname_prefix"]
app.run_server(
    mode="jupyterlab",
    host="0.0.0.0",
    port=8050,
    debug=True,
    dev_tools_ui=True,
    dev_tools_hot_reload=True,
)