# Generating templates and recordings with non-rigid drifts

This notebook shows how to generate drifting templates and recordings from scratch using MEArec, first with a rigid drift of the probe and then with a non-rigid drift.

Along the way, the notebook uses some plotting routines of the `MEAutility` package. The last part plots the drift trajectory of each unit for both recordings.

In [None]:
import MEArec as mr
import MEAutility as mu
import yaml
from pprint import pprint
import matplotlib.pylab as plt
import numpy as np
from pathlib import Path

%matplotlib notebook

## Load default configuration files

First, let's load the default configuration of MEArec

In [None]:
# default MEArec configuration (the returned `mearec_home` is not used further in this notebook)
default_info, mearec_home = mr.get_default_config()
pprint(default_info)

## Generating and saving templates

In [None]:
# default template-generation parameters, shown for reference before overriding some of them
template_params = mr.get_default_templates_params()
# cell models folder taken from the default MEArec configuration
cell_folder = default_info['cell_models_folder']
pprint(template_params)

Now let's change a few parameters and generate templates. We need to generate templates with drift. In order to use the advanced drift features, we have to make sure that all templates drift roughly in the same direction and by the same distance.

In [None]:
# enable drifting templates, simulated over `drift_steps` drift positions
template_params["drifting"] = True
template_params["drift_steps"] = 30
# this ensures that all cells drift on the same z trajectory, with a small xy variation
template_params["drift_xlim"] = [-5, 5]
template_params["drift_ylim"] = [-5, 5]
template_params["drift_zlim"] = [100, 100]
template_params["max_drift"] = 200

# generate 20 templates per cell model
# (the total is 20 x the number of cell models, e.g. 260 with 13 cell models)
template_params["n"] = 20
template_params['probe'] = 'Neuropixels-32'

In [None]:
templates_path: Path = Path("data/test_drift_templates.h5")  # h5 cache of the generated templates

In [None]:
if templates_path.is_file():
    # reuse the templates saved to disk by a previous run
    tempgen = mr.load_templates(templates_path, return_h5_objects=False)
else:
    # the templates are not saved here (see the save cell below), but the
    # intracellular simulations are saved in 'templates_folder'
    # this will take a few minutes...
    tempgen = mr.gen_templates(cell_models_folder=cell_folder, params=template_params,
                               n_jobs=13, verbose=1)

The `tempgen` variable is a `TemplateGenerator` object. It contains the `templates`, `locations`, `rotations`, and `celltypes` of the generated templates.

In [None]:
print(f"Templates shape {tempgen.templates.shape}")

In [None]:
# draw the probe and, for every 5th template, a segment joining its first and
# last drift position (column 1 plotted against column 2 of the locations)
probe = mu.return_mea(info=tempgen.info["electrodes"])
ax_probe = mu.plot_probe(probe)

for trajectory in tempgen.locations[::5]:
    start, end = trajectory[0], trajectory[-1]
    ax_probe.plot([start[1], end[1]], [start[2], end[2]], alpha=0.7)

We can now save the `TemplateGenerator` object in h5 format.

In [None]:
# save the templates in h5 format, unless they were loaded from that file earlier
if not templates_path.is_file():
    mr.save_template_generator(tempgen, filename=templates_path)

## Generating and saving recordings

Once the templates have been generated, we can use them to generate recordings. Let's first load and take a look at the default parameters:

In [None]:
# default recording parameters, shown for reference before overriding some of them
recordings_params = mr.get_default_recordings_params()
pprint(recordings_params)

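Similarly to the template generation, we can change the parameters that we pass to the `gen_recordings` function.
Here we override the duration, the number of cells, the template selection, the noise, and the drift settings. The first recording uses a rigid drift of the probe.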

In [None]:
# spike trains: 10 min (600 s) with 8 excitatory and 2 inhibitory cells
# (the main difference is morphology and avg firing rates)
recordings_params["spiketrains"]["duration"] = 600
recordings_params["spiketrains"]["n_exc"] = 8
recordings_params["spiketrains"]["n_inh"] = 2

# set template selection params
recordings_params["templates"]["min_amp"] = 30
recordings_params["templates"]["min_dist"] = 20  # um

# other settings: recordings are not filtered
recordings_params["recordings"]["filter"] = False

# noise level and model
recordings_params["recordings"]["noise_level"] = 10
recordings_params["recordings"]["noise_mode"] = "distance-correlated"

# set chunk duration (IMPORTANT for RAM usage and parallelization)
recordings_params["recordings"]["chunk_duration"] = 10

# drifting options
recordings_params["recordings"]["drifting"] = True
recordings_params["recordings"]["slow_drift_velocity"] = 30
recordings_params["recordings"]["slow_drift_amplitude"] = 30
recordings_params["recordings"]["t_start_drift"] = 100
recordings_params["recordings"]["t_end_drift"] = 500
# the first recording uses rigid probe drift; a later cell switches these
# parameters to non-rigid drift
recordings_params["recordings"]["drift_mode_probe"] = 'rigid'

# (optional) set seeds for reproducibility
# (e.g. if you want to maintain underlying activity, but change e.g. noise level,
# or to compare the rigid and non-rigid recordings on the same activity)
recordings_params['seeds']['spiketrains'] = None
recordings_params['seeds']['templates'] = None
recordings_params['seeds']['convolution'] = None
recordings_params['seeds']['noise'] = None



In [None]:
# Set rigid mode explicitly: a later cell mutates the same `recordings_params`
# to non-rigid, so re-running this cell would otherwise silently produce a
# non-rigid recording.
recordings_params["recordings"]["drift_mode_probe"] = 'rigid'
recgen = mr.gen_recordings(templates=str(templates_path),
                           params=recordings_params, verbose=True,
                           n_jobs=1)

In [None]:
# switch the same parameter set to non-rigid probe drift with a linear gradient mode
recordings_params["recordings"].update(
    drift_mode_probe='non-rigid',
    non_rigid_gradient_mode='linear',
)

In [None]:
# same templates file as the rigid recording (single source of truth: templates_path)
recgen_non_rigid = mr.gen_recordings(templates=str(templates_path),
                                     params=recordings_params, verbose=True,
                                     n_jobs=10)

In [None]:
# TODO fix this

In [None]:
# save the rigid recording in h5 format
mr.save_recording_generator(recgen, filename='data/test_drift_recordings_rigid_middle2.h5')
# uncomment to also save the non-rigid recording:
# mr.save_recording_generator(recgen_non_rigid, filename='data/test_drift_recordings_nonrigid_middle1.h5')

In [None]:
# reload the rigid recording saved in the previous cell (filename must match the one used there)
recgen_loaded = mr.load_recordings('data/test_drift_recordings_rigid_middle2.h5')

In [None]:
# Intentionally empty: interactive post-mortem debugging (IPython `%debug`) is not
# part of a notebook that must run top to bottom with Restart & Run All.

In [None]:
def plot_drifts(recgen, ax=None):
    """Plot the z position of every unit against the drift time vector.

    Parameters
    ----------
    recgen : RecordingGenerator
        Recording generator (e.g. from ``mr.gen_recordings``) whose ``drift_dict``
        holds ``"drift_vectors_idxs"`` and ``"drift_times"``, with one entry of
        ``template_locations`` per spike train.
    ax : matplotlib Axes, optional
        Axes to draw on. When None (default) a new figure is created.

    Returns
    -------
    ax : matplotlib Axes
        The axes containing one line per unit.

    Raises
    ------
    AssertionError
        If ``recgen.drift_dict`` is None (no drift info available).
    """
    # validate before creating a figure, so a failing call leaves no empty plot behind
    assert recgen.drift_dict is not None, "No drift info is available"
    if ax is None:
        _, ax = plt.subplots()

    drift_dict = recgen.drift_dict
    drift_vectors = drift_dict["drift_vectors_idxs"]
    drift_times = drift_dict["drift_times"]
    # 1D index vector -> rigid drift, shared by all units;
    # otherwise non-rigid, with one column of drift-step indices per unit
    rigid = drift_vectors.ndim == 1

    locations = recgen.template_locations
    for unit_idx, st in enumerate(recgen.spiketrains):
        loc = locations[unit_idx]  # per-drift-step positions; column 2 is z
        if st.annotations.get("drifting", False):
            step_idxs = drift_vectors if rigid else drift_vectors[:, unit_idx]
            loc_drift = loc[step_idxs, 2]
        else:
            # non-drifting units are drawn as a flat line at their middle drift step
            n_steps = loc.shape[0]
            loc_drift = [loc[n_steps // 2, 2]] * len(drift_times)
        ax.plot(drift_times, loc_drift, label=f"Unit {unit_idx}")

    ax.set_xlabel("Time")
    ax.set_ylabel("z position (um)")
    ax.legend()
    return ax
    

In [None]:
# one figure per drift mode, titled so the two can be told apart
for title, rec in [("Rigid drift", recgen), ("Non-rigid drift", recgen_non_rigid)]:
    plot_drifts(rec)
    plt.gca().set_title(title)