# Demo the features in `vis.py` 

# 1. Setup
Imports, helper functions, config, and loading of the data and model.

### 1.1 Imports

In [None]:
import pathlib
import torch
import numpy as np
import retinapy
import retinapy.mea as mea
import retinapy.spikeprediction as sp
import plotly
import plotly.graph_objects as go
import plotly.subplots as subplots
import retinapy.vis as vis

### 1.2 Helper functions

In [None]:
def get_sample(ds, ids):
    """Fetch one or more dataset items and collate them into a batch.

    This is a reminder of how to turn dataset items into batched tensors,
    in the same way a DataLoader does.

    Args:
        ds: An indexable dataset (anything supporting ``ds[i]``).
        ids: A single integer index (``int`` or a numpy integer), or an
            iterable of integer indices.

    Returns:
        ``default_collate([ds[i] for i in ids])``. A single integer index
        still produces a batch of size 1.
    """
    # Accept numpy integer scalars (e.g. from np.arange) as well as int;
    # the original `type(ids) == int` check let them fall through to
    # iteration, which raises TypeError.
    if isinstance(ids, (int, np.integer)):
        ids = [ids]
    return torch.utils.data.dataloader.default_collate([ds[i] for i in ids])

### 1.3 Config 

In [None]:
# Root of the project; the data and checkpoint paths below are built
# relative to it. "../" is resolved against the kernel's current working
# directory, so this presumably assumes the notebook is run from a
# directory one level below the project root — TODO confirm.
project_root = pathlib.Path("../")

### 1.4 Load stimulus & response

In [None]:
# Input files for the recording, all located under <project_root>/data.
stimulus_pattern_path, stimulus_rec_path, response_path = (
    project_root / rel_path
    for rel_path in (
        "data/ff_noise.h5",
        "data/ff_recorded_noise.pickle",
        "data/ff_spike_response.pickle",
    )
)
rec_name = "Chicken_17_08_21_Phase_00"

# Build the recording named `rec_name` from the loaded stimulus pattern,
# recorded stimulus and spike response.
rec = mea.single_3brain_recording(
    rec_name,
    mea.load_stimulus_pattern(stimulus_pattern_path),
    mea.load_recorded_stimulus(stimulus_rec_path),
    mea.load_response(response_path),
)

### 1.5 Load model

In [None]:
# Checkpoint of a MultiClusterDistField model. The directory name encodes
# the settings it was presumably trained with (18ds_992in_100out).
model_ckpt = (
    project_root
    / "out/exp/1/2/2/catvae_z=2d_l1_insert/3/MultiClusterDistField-18ds_992in_100out/recovery.pth"
)
# Raise rather than assert: asserts are stripped under `python -O`, and an
# explicit message shows which file is missing.
if not model_ckpt.resolve().exists():
    raise FileNotFoundError(f"Model checkpoint not found: {model_ckpt.resolve()}")

In [None]:
# Inference only: turn off autograd globally for the rest of the notebook.
torch.set_grad_enabled(False)

# These values mirror the checkpoint directory name (18ds_992in_100out).
config = sp.Configuration(
    downsample=18,
    input_len=992,
    output_len=100,
)
t = sp.MultiClusterDistFieldTGroup.create_trainable([rec], config)
retinapy.models.load_model(t.model, model_ckpt)

# Move the model to the GPU and switch it to evaluation mode. Module.cuda()
# returns the module itself, so the calls chain. The trailing ";" stops
# Jupyter from echoing the returned module.
t.model.cuda().eval();

# 2. Test
Try out some of the functions in `vis.py`.

### 2.1 View a stimulus

In [None]:
def view_stimulus():
    """Plot the stimulus of the first validation sample.

    Only the first ``mea.NUM_STIMULUS_LEDS`` channels of the sample's
    snippet are plotted, using the trainable's sample period as bin width.
    """
    # ids=0 yields a batch of size 1, hence the leading index 0.
    sample = get_sample(t.val_ds, ids=0)
    # Use the named constant rather than a hard-coded 0:4, as
    # view_model_in_out does, to select the stimulus channels.
    stimulus = sample['snippet'][0, 0 : mea.NUM_STIMULUS_LEDS]
    fig = vis.stimulus_fig(
        stimulus,
        start_ms=0,
        bin_duration_ms=t.sample_period_ms,
    )
    fig.show()


view_stimulus()

### 2.2 View distance fields, actual vs. predicted

In [None]:
def view_distfield():
    """Compare target and predicted distance fields for validation samples 0-3.

    Predictions come from the trainable's forward pass. Targets are the
    batch's 'dist' entries mapped through ``t.distfield_to_nn_output``.
    Both are plotted with ``log_space=True``.
    """
    batch = get_sample(t.val_ds, ids=[0, 1, 2, 3])
    # forward() returns (output, loss); only the output is needed here.
    pred, _ = t.forward(batch)
    target = t.distfield_to_nn_output(batch['dist'])
    fig = vis.distfield_fig(
        target.cpu().numpy(),
        pred.cpu().numpy(),
        start_ms=0,
        bin_duration_ms=t.sample_period_ms,
        stride_bins=t.val_ds.stride,
        log_space=True,
    )
    fig.show()


view_distfield()

### 2.3 View a distfield model's input and output

In [None]:
def view_model_in_out():
    """Plot a model's input, target and prediction for validation sample 2.

    The input is shown as the snippet's first ``mea.NUM_STIMULUS_LEDS``
    channels (stimulus) plus its last channel (spikes). The target distance
    field and the model output are shown alongside it.
    """
    sample = t.val_ds[2]
    # The model operates on batches, so collate the single sample into a
    # batch of size 1.
    batch = torch.utils.data.dataloader.default_collate([sample])

    stimulus = sample['snippet'][0 : mea.NUM_STIMULUS_LEDS]
    spikes = sample['snippet'][-1]
    target_dist = t.distfield_to_nn_output(batch['dist']).cpu().numpy()[0]
    # forward() returns (output, loss); the loss is not used here.
    model_out, _ = t.forward(batch)
    model_out = model_out[0].cpu().numpy()
    # Use the trainable's sample period instead of a hard-coded 1.0 ms, so
    # the time axis is consistent with the other plots. The input is
    # downsampled (see the Configuration above).
    fig = vis.distfield_model_in_out(
        stimulus,
        spikes,
        target_dist,
        model_out,
        start_ms=0,
        bin_duration_ms=t.sample_period_ms,
    )
    fig.show()


view_model_in_out()

### 2.4 Inspect the 2D latent space

In [None]:
def inspect_latent_space():
    