## Position using DeepLabCut from a Pre-Trained DLC Project

**Note: make a copy of this notebook and run the copy to avoid git conflicts in the future**

This tutorial shows how to extract position from a pre-trained DeepLabCut (DLC) model using the Spyglass pipeline used in Loren Frank's lab at UCSF. It walks through adding your DLC model to Spyglass, running pose estimation on a novel behavioral video, processing the pose estimation output to extract a centroid and orientation, and inserting the resulting information into the `IntervalPositionInfo` table.<br>
-> This tutorial assumes you've completed [tutorial 0](0_intro.ipynb)<br>
**Note 2: Make sure you are running this within the spyglass Conda environment**

In [None]:
from pathlib import Path, PosixPath, PurePath
import os
import numpy as np
import pandas as pd
import pynwb
import datajoint as dj
import spyglass.common as sgc
import spyglass.position as sgp

#### Here is a schematic showing the tables used in this notebook.<br>
<img src='dlc_existing.png' width="1500" height="400">

### Table of Contents<a id='ToC'></a>
[`DLCProject`](#DLCProject)<br>
[`DLCModel`](#DLCModel)<br>
[`DLCPoseEstimation`](#DLCPoseEstimation)<br>
[`DLCSmoothInterp`](#DLCSmoothInterp)<br>
[`DLCSmoothInterpCohort`](#DLCSmoothInterpCohort)<br>
[`DLCCentroid`](#DLCCentroid)<br>
[`DLCOrientation`](#DLCOrientation)<br>
[`DLCPos`](#DLCPos)<br>
[`DLCPosVideo`](#DLCPosVideo)<br>
[`PosSource`](#PosSource)<br>
[`IntervalPositionInfo`](#IntervalPositionInfo)<br>

#### [DLCProject](#ToC) <a id='DLCProject'></a>

First, we can visualize the contents of the BodyPart table. This table will store standard names of body parts used within DLC models throughout the lab with a concise description.<br>Please do not add to this table unless necessary.

In [None]:
sgp.BodyPart()

To use an existing DLC project we can use the `insert_existing_project` method on the `DLCProject` table.<br>This function will return a dictionary that can be used to query `DLCProject` in the future and expects:<br>
>`project_name`: a short, unique, descriptive name of your project that will be referenced throughout the pipeline<br>`lab_team`: the name of your team from the Spyglass table `LabTeam`<br>`config_path`: string of the path to your existing DLC project's config.yaml<br>`bodyparts`: a list of bodyparts used in your project (optional)<br>`frames_per_video`: number of frames to extract for training from each video (optional)

In [None]:
project_key = sgp.DLCProject.insert_existing_project(
    project_name='test',
    bodyparts=['redLED_C', 'greenLED', 'redLED_L', 'redLED_R', 'tailBase'],
    lab_team='JG_DG',
    config_path='/cumulus/deeplabcut/LH_4LED_model-Daniel_Sharon-2022-07-15/config.yaml',
    frames_per_video=200,
    skip_duplicates=True)

In [None]:
sgp.DLCProject()

In [None]:
project_key = {"project_name": "test", "config_path": "/nimbus/deeplabcut/projects/LH_4LED_model-Daniel_Sharon-2022-07-15/config.yaml"}

#### [DLCModel](#ToC) <a id='DLCModel'></a>

Let's take a look at the `DLCModelInput` table next. This table has `dlc_model_name` and `project_name` as primary keys and `project_path` as a secondary key.

In [None]:
sgp.DLCModelInput()

Next we can modify the `project_key` to replace `config_path` with `project_path` to fit with the fields in `DLCModelInput`

In [None]:
print(f"current project_key: {project_key}")
project_key['project_path'] = os.path.dirname(project_key['config_path'])
del project_key['config_path']
print(f"updated project_key: {project_key}")
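
The key rewrite above can also be written as a small pure function that leaves the original dictionary untouched. This is a minimal sketch, not part of Spyglass: `to_model_input_key` is a hypothetical helper name and the path is a placeholder.

```python
from pathlib import PurePosixPath


def to_model_input_key(project_key):
    """Return a copy of project_key with config_path replaced by project_path."""
    new_key = {k: v for k, v in project_key.items() if k != "config_path"}
    # The project directory is the folder that contains config.yaml.
    new_key["project_path"] = str(PurePosixPath(project_key["config_path"]).parent)
    return new_key


example_key = {
    "project_name": "test",
    "config_path": "/path/to/my_project/config.yaml",  # placeholder path
}
print(to_model_input_key(example_key))
```

Because the input is not mutated, re-running the cell does not raise a `KeyError` on a `config_path` that has already been deleted.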

Here we can set a unique name for our model using the `dlc_model_name` variable.<br>We then combine this with the updated `project_key` to insert into `DLCModelInput`.

In [None]:
dlc_model_name = 'LH_4LED_model'
sgp.DLCModelInput().insert1({'dlc_model_name' : dlc_model_name,
                             **project_key},
                              skip_duplicates=True)
sgp.DLCModelInput()

Inserting an entry into `DLCModelInput` will also populate `DLCModelSource`. `DLCModelSource` is a table that is used to switch between models trained using Spyglass and pre-existing projects.

In [None]:
sgp.DLCModelSource() & project_key

Notice the `source` field in the table above. It will only accept "FromImport" or "FromUpstream" as entries. Let's check out the `FromImport` part table attached to `DLCModelSource` below.

In [None]:
sgp.DLCModelSource.FromImport() & project_key

Next we'll get ready to populate the `DLCModel` table, which holds all the relevant information for both pre-trained models and models trained within Spyglass.<br>First we'll need to determine a set of parameters for our model to select the correct model file.<br>We can visualize a default set below:

In [None]:
sgp.DLCModelParams.get_default()

> Here is the syntax to add your own parameter set:
>```python
dlc_model_params_name = "make_this_yours"
params = {
            "params": {},
            "shuffle": 1,
            "trainingsetindex": 0,
            "model_prefix": "",
        }
sgp.DLCModelParams.insert1({"dlc_model_params_name": dlc_model_params_name, "params": params}, skip_duplicates=True)
```

Now let's fetch the primary keys from `DLCModelSource` to make our lives a bit easier when we insert into `DLCModelSelection`.

In [None]:
temp_model_key = (sgp.DLCModelSource.FromImport() & project_key).fetch1('KEY')

And insert into `DLCModelSelection` to allow for population of `DLCModel`

In [None]:
sgp.DLCModelSelection().insert1({
    **temp_model_key,
    'dlc_model_params_name': 'default'},
    skip_duplicates=True)

Let's populate `DLCModel`!!

In [None]:
model_key = (sgp.DLCModelSelection & temp_model_key).fetch1('KEY')
sgp.DLCModel.populate(model_key)

And of course make sure it populated correctly

In [None]:
sgp.DLCModel() & model_key

#### [DLCPoseEstimation](#ToC) <a id='DLCPoseEstimation'></a>

<div class="alert alert-block alert-warning">
<b>
The following steps should be run on a GPU cluster</b></div>

Now that we have brought our trained model into Spyglass, we're ready to set up pose estimation on a behavioral video.<br>For this tutorial, you can use the epoch specified below or one of your own. To use your own video, specify the `nwb_file_name` and `epoch` number, and make sure the video is in the `VideoFile` table!

In [None]:
nwb_file_name = 'J1620210604_.nwb'
epoch = 14

In [None]:
sgc.VideoFile() & {'nwb_file_name': nwb_file_name,
                  'epoch': epoch}

<div class="alert alert-block alert-info">
    <b>Setting up Pose Estimation</b><br>
<code>gputouse</code> determines which GPU to use for pose estimation. Run the cell below to see which GPU has free memory and set the <code>gputouse</code> variable accordingly.</div>

In [None]:
! nvidia-smi
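
If you would rather pick the GPU programmatically than read the table by eye, the sketch below parses the CSV form of the `nvidia-smi` output and returns the index of the GPU with the least memory in use. The helper names are hypothetical and the sample text is made up; `query_gpus` only works on a machine with an NVIDIA driver installed.

```python
import subprocess


def parse_gpu_memory(csv_text):
    """Parse 'index, memory.used' CSV rows into a list of (index, used_mib)."""
    rows = []
    for line in csv_text.strip().splitlines():
        index, used = (field.strip() for field in line.split(","))
        rows.append((int(index), int(used)))
    return rows


def least_used_gpu(csv_text):
    """Return the index of the GPU with the least memory in use."""
    return min(parse_gpu_memory(csv_text), key=lambda row: row[1])[0]


def query_gpus():
    """Ask nvidia-smi for per-GPU memory use (requires an NVIDIA driver)."""
    return subprocess.run(
        ["nvidia-smi", "--query-gpu=index,memory.used",
         "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout


# Made-up sample output, so the parsing can be checked without a GPU.
sample = "0, 10240\n1, 312\n2, 8000\n"
print(least_used_gpu(sample))
```

On a GPU node, `least_used_gpu(query_gpus())` would give a candidate value for `gputouse`. Memory in use is only a rough proxy for availability, so it is still worth checking the full `nvidia-smi` table.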

<div class="alert alert-block alert-warning">
Set the GPU index here</div>

In [None]:
gputouse = 0  # replace with the index (0-9) of a free GPU from the nvidia-smi output above

To set up pose estimation, we need to make sure a few things are in order. Using `insert_estimation_task` will take care of these steps for us!<br>Briefly, it will convert our video to .mp4 format (DLC struggles with .h264) and determine the directory in which we'll store the pose estimation results.<br>
>**`task_mode`** determines whether populating `DLCPoseEstimation` runs a new pose estimation or loads an existing one. Use _'trigger'_ unless you've already run this specific pose estimation.<br>**`video_file_num`** will be 0 in almost all cases.

In [None]:
pose_estimation_key = sgp.DLCPoseEstimationSelection.insert_estimation_task(
    {
        'nwb_file_name': nwb_file_name,
        'epoch': epoch,
        'video_file_num': 0,
        **model_key
    },
    task_mode='load',  # 'load' reuses existing results; use 'trigger' to run a new pose estimation
    params={'gputouse': gputouse, 'videotype': 'mp4'}
)
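
One way to avoid hard-coding `task_mode` is to derive it from whether results already exist. The sketch below is an illustration only: `choose_task_mode` is a hypothetical helper, and it assumes that a finished pose estimation leaves at least one `.h5` file in a known output directory, which you should verify for your own setup.

```python
from pathlib import Path


def choose_task_mode(output_dir):
    """Return 'load' if output_dir already holds .h5 results, else 'trigger'."""
    output_dir = Path(output_dir)
    # Assumption: an existing .h5 file means pose estimation was already run.
    has_results = output_dir.is_dir() and any(output_dir.glob("*.h5"))
    return "load" if has_results else "trigger"


# A directory that does not exist has no results, so a new run is needed.
print(choose_task_mode("/nonexistent/pose_output"))
```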

And now we populate `DLCPoseEstimation`! This might take a bit...

In [None]:
sgp.DLCPoseEstimation().populate(pose_estimation_key)

#### [DLCSmoothInterp](#ToC) <a id='DLCSmoothInterp'></a>

In [None]:
si_params_name = 'JG_SI_params'
sgp.DLCSmoothInterpParams().insert1({
    'dlc_si_params_name': si_params_name,
    "params":
    {"smoothing_params": {
        "smoothing_duration": 0.05,
        "smooth_method": "moving_avg",
    },
     "interp_params": {
         "likelihood_thresh": 0.95,
     },
     "max_plausible_speed": 300.0,
     "speed_smoothing_std_dev": 0.100,
     "sampling_rate": 50,
    }}, skip_duplicates=True)
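
To give a feel for what `likelihood_thresh` and the `moving_avg` smoothing method control, here is a toy sketch in plain Python. It is not the Spyglass implementation: it simply drops samples below a likelihood threshold and then applies a centered moving average that skips the dropped samples. The function names, the `half_width` parameter, and the data are all made up for illustration.

```python
import math


def mask_low_likelihood(values, likelihoods, likelihood_thresh):
    """Replace samples whose likelihood is below the threshold with NaN."""
    return [
        value if likelihood >= likelihood_thresh else math.nan
        for value, likelihood in zip(values, likelihoods)
    ]


def moving_average(values, half_width):
    """Average each sample with up to half_width neighbors per side, skipping NaN."""
    smoothed = []
    for i in range(len(values)):
        chunk = [v for v in values[max(0, i - half_width): i + half_width + 1]
                 if not math.isnan(v)]
        smoothed.append(sum(chunk) / len(chunk) if chunk else math.nan)
    return smoothed


x = [1.0, 2.0, 30.0, 4.0, 5.0]             # made-up x positions
likelihood = [0.99, 0.98, 0.20, 0.97, 0.99]  # made-up DLC likelihoods
masked = mask_low_likelihood(x, likelihood, likelihood_thresh=0.95)
print(moving_average(masked, half_width=1))
```

The outlier at index 2 has a low likelihood, so it is masked before averaging and does not pull the smoothed trace toward 30.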

In [None]:
si_key = pose_estimation_key.copy()
fields = list(sgp.DLCSmoothInterpSelection.fetch().dtype.fields.keys())
si_key = {key: val for key,val in si_key.items() if key in fields}

In [None]:
si_key
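
The pattern used above, keeping only the entries of a key that a downstream table knows about, recurs several times in this notebook. A minimal sketch of it as a reusable helper follows; `restrict_to_fields` is a hypothetical name and the field names in the example are placeholders. DataJoint tables also expose their attribute names through `heading.names`, which may be a cheaper source for `fields` than fetching the whole table.

```python
def restrict_to_fields(key, fields):
    """Keep only the entries of key whose names appear in fields."""
    allowed = set(fields)
    return {name: value for name, value in key.items() if name in allowed}


# Placeholder key and field names, for illustration only.
full_key = {"nwb_file_name": "example.nwb", "epoch": 1, "extra_field": "x"}
print(restrict_to_fields(full_key, ["nwb_file_name", "epoch", "bodypart"]))
```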

In [None]:
sgp.DLCSmoothInterpSelection.insert1(
    {
        **si_key,
        'bodypart': 'greenLED',
        'dlc_si_params_name': 'JG_SI_params',
    },
    skip_duplicates=True)

In [None]:
bodyparts = (sgp.DLCPoseEstimation.BodyPart & pose_estimation_key).fetch('bodypart')
print(bodyparts)

In [None]:
sgp.DLCSmoothInterpSelection.insert(
    [
        {
            **si_key,
            'bodypart': 'greenLED',
            'dlc_si_params_name': 'JG_SI_params',
        },
        {
            **si_key,
            'bodypart': 'redLED_C',
            'dlc_si_params_name': 'JG_SI_params',
        },
        {
            **si_key,
            'bodypart': 'redLED_L',
            'dlc_si_params_name': 'JG_SI_params',
        },
        {
            **si_key,
            'bodypart': 'redLED_R',
            'dlc_si_params_name': 'JG_SI_params',
        },
    ],
    skip_duplicates=True)
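
The list of rows above repeats the same dictionary with only `bodypart` changing, so it can be generated with a list comprehension. This is a minimal sketch: `build_selection_rows` is a hypothetical helper and the base key is a placeholder standing in for `si_key`.

```python
def build_selection_rows(base_key, bodyparts, params_name):
    """Return one selection row per body part, all sharing base_key."""
    return [
        {**base_key, "bodypart": bodypart, "dlc_si_params_name": params_name}
        for bodypart in bodyparts
    ]


base_key = {"nwb_file_name": "example.nwb", "epoch": 1}  # placeholder key
rows = build_selection_rows(
    base_key, ["greenLED", "redLED_C", "redLED_L", "redLED_R"], "JG_SI_params"
)
print(len(rows), rows[0]["bodypart"])
```

The resulting list could then be passed to `insert(..., skip_duplicates=True)` in place of the hand-written one.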

In [None]:
sgp.DLCSmoothInterp().populate(si_key)

In [None]:
(sgp.DLCSmoothInterp() & {**si_key,'bodypart': 'greenLED'}).fetch1_dataframe().plot.scatter(x='x',y='y',s=1)

#### [DLCSmoothInterpCohort](#ToC) <a id='DLCSmoothInterpCohort'></a>

In [None]:
cohort_key = si_key.copy()
if 'bodypart' in cohort_key:
    del cohort_key['bodypart']
if 'dlc_si_params_name' in cohort_key:
    del cohort_key['dlc_si_params_name']
cohort_key['dlc_si_cohort_selection_name'] = '4LEDs'
cohort_key['bodyparts_params_dict'] = {'greenLED': si_params_name,
                                       'redLED_L': si_params_name,
                                       'redLED_C': si_params_name,
                                       'redLED_R': si_params_name,}
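
The cohort key construction above can be collected into one function that drops the per-bodypart fields and maps every body part to the same parameter set. This is a sketch under the assumption that all body parts share one parameter set; `make_cohort_key` is a hypothetical helper and the input key is a placeholder.

```python
def make_cohort_key(si_key, cohort_name, bodyparts, params_name):
    """Build a cohort selection key without mutating si_key."""
    dropped = {"bodypart", "dlc_si_params_name"}
    cohort_key = {k: v for k, v in si_key.items() if k not in dropped}
    cohort_key["dlc_si_cohort_selection_name"] = cohort_name
    # Every body part in the cohort uses the same smoothing parameter set.
    cohort_key["bodyparts_params_dict"] = {bp: params_name for bp in bodyparts}
    return cohort_key


example = make_cohort_key(
    {"epoch": 1, "bodypart": "greenLED"},  # placeholder key
    "4LEDs",
    ["greenLED", "redLED_L", "redLED_C", "redLED_R"],
    "JG_SI_params",
)
print(example["bodyparts_params_dict"])
```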

In [None]:
sgp.DLCSmoothInterpCohortSelection().insert1(cohort_key, skip_duplicates=True)

In [None]:
sgp.DLCSmoothInterpCohort.populate(cohort_key)

In [None]:
(sgp.DLCSmoothInterpCohort.BodyPart() & {**cohort_key, 'bodypart': 'greenLED'}).fetch1_dataframe()

#### [DLCCentroid](#ToC) <a id='DLCCentroid'></a>

In [None]:
sgp.DLCCentroidParams.get_default()

In [None]:
centroid_params = {
    'centroid_method': 'four_led_centroid',
    'points' : {
        'greenLED': 'greenLED',
        'redLED_L': 'redLED_L',
        'redLED_C': 'redLED_C',
        'redLED_R': 'redLED_R',},
    'speed_smoothing_std_dev': 0.100,
}
centroid_params_name = 'JG_4LED'
sgp.DLCCentroidParams.insert1({'dlc_centroid_params_name': centroid_params_name,
                                'params': centroid_params},
                                skip_duplicates=True)
centroid_key = cohort_key.copy()
fields = list(sgp.DLCCentroidSelection.fetch().dtype.fields.keys())
centroid_key = {key: val for key,val in centroid_key.items() if key in fields}
centroid_key['dlc_centroid_params_name'] = centroid_params_name

In [None]:
sgp.DLCCentroidSelection.insert1(centroid_key, skip_duplicates=True)

In [None]:
sgp.DLCCentroidSelection()

In [None]:
sgp.DLCCentroid.populate(centroid_key)

In [None]:
(sgp.DLCCentroid() & centroid_key).fetch1_dataframe().plot.scatter(
    x='position_x',
    y='position_y',
    c='speed',
    colormap='viridis',
    alpha=0.5,
    s=0.5,
    figsize=(15,15))
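
The plot above colors each position sample by `speed`. As a rough illustration of how speed relates to position, the sketch below computes frame-to-frame speed from x/y coordinates at a fixed sampling rate. It is a simplified stand-in, not the `DLCCentroid` computation: it applies no smoothing (so it ignores `speed_smoothing_std_dev`), and the function name and data are made up.

```python
import math


def frame_speeds(xs, ys, sampling_rate):
    """Return the speed between consecutive samples, in position units per second."""
    return [
        math.hypot(x1 - x0, y1 - y0) * sampling_rate
        for x0, y0, x1, y1 in zip(xs, ys, xs[1:], ys[1:])
    ]


# Made-up positions: one 3-4-5 step, then no movement.
print(frame_speeds([0.0, 3.0, 3.0], [0.0, 4.0, 4.0], sampling_rate=50))
```

The result has one fewer entry than the input because each speed is defined between two consecutive samples.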

#### [DLCOrientation](#ToC) <a id='DLCOrientation'></a>

In [None]:
sgp.DLCOrientationParams.get_default()

In [None]:
fields = list(sgp.DLCOrientationSelection.fetch().dtype.fields.keys())
orient_key = {key: val for key,val in cohort_key.items() if key in fields}
orient_key['dlc_orientation_params_name'] = 'default'

In [None]:
sgp.DLCOrientationSelection().insert1(orient_key, skip_duplicates=True)

In [None]:
sgp.DLCOrientationSelection()

In [None]:
sgp.DLCOrientation().populate(orient_key)

In [None]:
(sgp.DLCOrientation() & orient_key).fetch1_dataframe()

#### [DLCPos](#ToC) <a id='DLCPos'></a>

In [None]:
sgp.DLCPos()

In [None]:
fields = list(sgp.DLCPos.fetch().dtype.fields.keys())
dlc_key = {key: val for key,val in centroid_key.items() if key in fields}
dlc_key['dlc_si_cohort_centroid'] = centroid_key['dlc_si_cohort_selection_name']
dlc_key['dlc_si_cohort_orientation'] = orient_key['dlc_si_cohort_selection_name']
dlc_key['dlc_orientation_params_name'] = orient_key['dlc_orientation_params_name']

In [None]:
dlc_key

In [None]:
sgp.DLCPosSelection().insert1(dlc_key, skip_duplicates=True)

In [None]:
sgp.DLCPos().populate(dlc_key)

In [None]:
(sgp.DLCPos() & dlc_key).fetch1_dataframe()

In [None]:
(sgp.DLCPos() & dlc_key).fetch1('pose_eval_result')

#### [DLCPosVideo](#ToC) <a id='DLCPosVideo'></a>

In [None]:
sgp.DLCPosVideoParams.insert_default()

In [None]:
params = {
    "percent_frames": 0.05,
    "incl_likelihood": True,
}
sgp.DLCPosVideoParams.insert1(
    {"dlc_pos_video_params_name": "five_percent", "params": params},
    skip_duplicates=True)

In [None]:
sgp.DLCPosVideoSelection.insert1(
    {
        **dlc_key,
        "dlc_pos_video_params_name": "five_percent"
    },
    skip_duplicates=True)

In [None]:
sgp.DLCPosVideo().populate(dlc_key)

#### [PosSource](#ToC) <a id='PosSource'></a>

In [None]:
sgp.PosSource()

#### [IntervalPositionInfo](#ToC)<a id='IntervalPositionInfo'></a>

In [None]:
int_pos_info_key = (sgp.PosSource & dlc_key).fetch1('KEY')

In [None]:
int_pos_info_key

In [None]:
sgp.IntervalPositionInfoSelection().insert1(int_pos_info_key, skip_duplicates=True)

In [None]:
sgp.IntervalPositionInfo.populate(int_pos_info_key)

In [None]:
(sgp.IntervalPositionInfo() & int_pos_info_key).fetch1_dataframe()

In [None]:
sgp.PositionVideoSelection().insert1(
    {
        'nwb_file_name': 'J1620210604_.nwb',
        'interval_list_name': 'pos 13 valid times',
        'trodes_position_id': 0,
        'dlc_position_id': 1,
        'plot': 'DLC',
        'output_dir': '/home/dgramling/Src/'
    }
)

In [None]:
sgp.PositionVideo.populate({'plot': 'DLC'})

In [None]:
(sgp.IntervalPositionInfo() & {'nwb_file_name': 'J1620210604_.nwb'}).fetch1('KEY')

### [`Return To Table of Contents`](#ToC)<br>