## Position using Trodes Tracking from NWB file

**Note: make a copy of this notebook and run the copy to avoid git conflicts in the future**

This is a tutorial on how to process position that was extracted through Trodes Tracking (online or offline) using the Spyglass pipeline used in Loren Frank's lab, UCSF. It will walk through defining your parameters and processing the raw position to extract a centroid and orientation, and inserting the resulting information into the `PositionOutput` table.<br>
-> This tutorial assumes you've completed [tutorial 0](0_intro.ipynb)<br>
**Note 2: Make sure you are running this within the spyglass Conda environment**

In [None]:
import matplotlib.pyplot as plt
import datajoint as dj
import spyglass.position.v1 as sgp
import spyglass.common as sgc
from spyglass.utils.nwb_helper_fn import get_nwb_copy_filename
from spyglass.position import PositionOutput

In [None]:
dj.config["filepath_checksum_size_limit"] = 1 * 1024**3
dj.config.save_global()

This data pipeline takes the 2D position (in video pixels) of the green and red LEDs tracked by Trodes, and computes:
+ position (in cm)
+ orientation (in radians)
+ velocity (in cm/s)
+ speed (in cm/s)
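
As a rough illustration of how the head position and orientation relate to the two tracked LEDs, the sketch below converts one pair of LED coordinates from pixels to cm, averages them to get a centroid, and takes the angle of the back-to-front vector. The function name `head_centroid_and_orientation`, the `cm_per_pixel` scale, and the example coordinates are assumptions made for illustration; this is not the Spyglass implementation.

```python
import math


def head_centroid_and_orientation(front_xy, back_xy, cm_per_pixel):
    """Return ((x_cm, y_cm), orientation_rad) for a single video frame."""
    # Convert both LED positions from pixels to cm (hypothetical scale factor)
    fx, fy = (c * cm_per_pixel for c in front_xy)
    bx, by = (c * cm_per_pixel for c in back_xy)
    # Head position is taken as the midpoint between the two LEDs
    centroid = ((fx + bx) / 2.0, (fy + by) / 2.0)
    # Orientation is the angle of the vector pointing from back LED to front LED
    orientation = math.atan2(fy - by, fx - bx)
    return centroid, orientation


centroid, orientation = head_centroid_and_orientation(
    front_xy=(120.0, 80.0), back_xy=(100.0, 80.0), cm_per_pixel=0.5
)
print(centroid, orientation)
```

Here the front LED sits directly to the right of the back LED, so the angle is zero and the centroid lies halfway between them.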

We can then check the quality of the head position and direction by plotting them on the video along with the original green and red LEDs.

This notebook will take you through this process for one dataset.

### 1. Loading the session data

First, let us make sure that the session we want to analyze has been inserted into the `RawPosition` table.

In [None]:
nwb_file_name = "chimi20200216_new.nwb"
nwb_copy_file_name = get_nwb_copy_filename(nwb_file_name)
sgc.RawPosition() & {"nwb_file_name": nwb_copy_file_name}

### 2. Setting the parameters for running the position pipeline

The parameters for the position pipeline are set by the `TrodesPosParams` table. `default` is the name of the standard set of parameters. As usual, if you want to change the parameters, you can insert your own into the table.

The parameters are as follows:

+ `max_separation` is the maximum acceptable distance (in cm) between the red and green LEDs. When the distance between the LEDs becomes greater than this number, the times are marked as NaNs and inferred by interpolation. This is a useful parameter when the inferred red or green LED position tracks a reflection instead of the true red or green LED position. It is set to 9.0 cm by default.
+ `max_speed` is the maximum plausible speed (in cm/s) the animal can move at. Times when the speed is greater than this threshold are marked as NaNs and inferred by interpolation. This can be useful in preventing big, sudden jumps in position. It is set to 300.0 cm/s by default.
+ `position_smoothing_duration` controls how much the red and green LEDs are smoothed before computing the average of their position to get the head position. It is in units of seconds.
+ `speed_smoothing_std_dev` controls how much the head speed is smoothed. It corresponds to the standard deviation of the Gaussian kernel used to smooth the speed and is in units of seconds. It is set to 0.100 seconds by default.
+ `front_led1` is either 1 or 0 indicating True or False. It controls which LED is treated as the front LED and the back LED, which is important for calculating the head direction.
    + 1 indicates that the LED corresponding to `xloc`, `yloc` in the `RawPosition` table is treated as the front LED and the LED corresponding to `xloc2`, `yloc2` is treated as the back LED.
    + 0 indicates that `xloc`, `yloc` are treated as the back LED and `xloc2`, `yloc2` are treated as the front LED.
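
To make the `max_separation` behavior concrete, here is a minimal sketch of the idea: samples where the two LEDs are farther apart than the threshold are replaced with NaN and then filled by linear interpolation over time. The helper names `mask_far_apart` and `interpolate_nans` and the toy data are hypothetical, and coordinates are assumed to already be in cm; the actual pipeline may differ in detail.

```python
import math


def mask_far_apart(led1, led2, max_separation):
    """Replace samples where the LEDs are too far apart with NaN pairs."""
    nan_pair = (math.nan, math.nan)
    out1, out2 = [], []
    for (x1, y1), (x2, y2) in zip(led1, led2):
        if math.hypot(x1 - x2, y1 - y2) > max_separation:
            out1.append(nan_pair)
            out2.append(nan_pair)
        else:
            out1.append((x1, y1))
            out2.append((x2, y2))
    return out1, out2


def interpolate_nans(times, values):
    """Linearly interpolate interior NaNs; NaNs at the edges are left as NaN."""
    result = list(values)
    valid = [i for i, v in enumerate(values) if not math.isnan(v)]
    for lo, hi in zip(valid, valid[1:]):
        for i in range(lo + 1, hi):
            frac = (times[i] - times[lo]) / (times[hi] - times[lo])
            result[i] = values[lo] + frac * (values[hi] - values[lo])
    return result


times = [0.0, 1.0, 2.0]
led1 = [(0.0, 0.0), (50.0, 0.0), (2.0, 0.0)]  # middle sample: made-up reflection
led2 = [(3.0, 0.0), (4.0, 0.0), (5.0, 0.0)]
masked1, masked2 = mask_far_apart(led1, led2, max_separation=9.0)
led1_x = interpolate_nans(times, [x for x, _ in masked1])
print(led1_x)
```

The same mask-then-interpolate pattern applies to `max_speed`, with the speed between consecutive samples taking the place of the LED separation.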

We can get a list of potential parameters using the method `TrodesPosParams.get_accepted_params()`<br>
And view the default setting for the parameters using `get_default`.

In [None]:
print(f"accepted parameters:\n{sgp.TrodesPosParams.get_accepted_params()}")
print(f"default parameters:\n{sgp.TrodesPosParams.get_default()['params']}")

Now we pair the parameters we chose with an interval from our specified NWB file and insert into `TrodesPosSelection`.<br>
We can define the interval we want to use via its `interval_list_name`, which we can see in the `IntervalList` table.<br>
>```python
>sgc.IntervalList & {"nwb_file_name": nwb_copy_file_name}
>```
A handy trick: `interval_list_name` follows the pattern 'pos (epoch# - 1) valid times' (e.g. for epoch 3, the `interval_list_name` is 'pos 2 valid times').
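
The naming pattern can be wrapped in a small helper. The function `pos_interval_list_name` is hypothetical (it is not part of Spyglass) and assumes epochs are numbered starting at 1:

```python
def pos_interval_list_name(epoch):
    """Build the position interval name for a 1-indexed epoch number."""
    if epoch < 1:
        raise ValueError("epoch numbers are assumed to start at 1")
    # The interval name uses a 0-based index, hence the subtraction
    return f"pos {epoch - 1} valid times"


print(pos_interval_list_name(3))
```

It is still worth confirming the resulting name against the `IntervalList` table, since the pattern is a convention rather than a guarantee.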

Let's choose the interval corresponding to `pos 1 valid times` in `chimi20200216_new_.nwb`.

We first look at the "raw" position data now to see the input into the pipeline. The raw position is in the `RawPosition` table and corresponds to the inferred position of the red and green LEDs from the video (using an algorithm in Trodes). It is in units of pixels. The number of time points corresponds to when the position tracking was turned on and off (and so may not have the same number of time points as the video itself).

We can retrieve the data in the `RawPosition` table for a given interval using a special method called `fetch1_dataframe`. It returns the position of the red and green LEDs as a pandas dataframe where time is the index. The columns of the dataframe are:
+ `xloc`, `yloc` are the x- and y-position of one of the LEDs
+ `xloc2`, `yloc2` are the x- and y-position of the other LED.

In [None]:
interval_list_name = "pos 1 valid times"
start_key = {"nwb_file_name": nwb_copy_file_name,
             "interval_list_name": interval_list_name}
raw_position_df = (sgc.RawPosition() & start_key).fetch1_dataframe()
raw_position_df

Let's just quickly plot the two LEDs to get a sense of the inputs to the pipeline:

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.plot(raw_position_df.xloc, raw_position_df.yloc, color="green")
ax.plot(raw_position_df.xloc2, raw_position_df.yloc2, color="red")
ax.set_xlabel("x-position [pixels]", fontsize=18)
ax.set_ylabel("y-position [pixels]", fontsize=18)
ax.set_title("Raw Position", fontsize=28)

Okay, now that we understand what the inputs to the pipeline are, let's associate a set of parameters with a given interval.<br>
To associate parameters with a given interval, we insert them into the `TrodesPosSelection` table.<br>
Here we associate the `default` Trodes position parameters with the interval `pos 1 valid times`:

In [None]:
trodes_pos_params_name = "default"
sgp.TrodesPosSelection.insert1(
    {**start_key, "trodes_pos_params_name": trodes_pos_params_name},
    skip_duplicates=True,
)

Now let's check to see if we've inserted correctly:

In [None]:
selection_key = {**start_key, "trodes_pos_params_name": trodes_pos_params_name}
trodes_key = (sgp.TrodesPosSelection() & selection_key).fetch1("KEY")
sgp.TrodesPosSelection() & trodes_key  # last expression, so the entry is displayed

### 3. Running the position pipeline and retrieving the results

Now that we have associated the parameters with the interval we want to run, we can finally run the pipeline for that interval.

We run the pipeline using the `populate` method on the `TrodesPosV1` table.

In [None]:
sgp.TrodesPosV1.populate(trodes_key)

Now we can make sure that the entry was inserted into `TrodesPosV1` properly.

In [None]:
sgp.TrodesPosV1() & trodes_key

We can see that each NWB file, interval, and parameter set is now associated with a newly created analysis NWB file and object IDs that correspond to our newly computed data. This isn't as useful to work with so we will use another method below to actually retrieve the data for a given interval.

When we populate `TrodesPosV1` the entry is automatically inserted into the merge table for the position pipeline, `PositionOutput`.

In [None]:
PositionOutput() & trodes_key

You may notice that this entry has a different set of primary keys than the entry in `TrodesPosV1`. No need to worry, as all of those keys are preserved in the part table `PositionOutput.TrodesPosV1`.

In [None]:
PositionOutput.TrodesPosV1() & trodes_key

In order to retrieve the results of the computation, we use a special method called `fetch1_dataframe` from the `PositionOutput` table that will retrieve the position pipeline results as a pandas DataFrame. Time is set as the index of the dataframe.

This will only work for a single interval so we need to specify the NWB file and the interval.

This dataframe has the following columns:
+ `position_x`, `position_y`: the x,y position (in cm).
+ `orientation`: The direction relative to the bottom left corner (in radians)
+ `velocity_x`, `velocity_y`: the directional change in position over time (in cm/s)
+ `speed`: the magnitude of the change in position over time (in cm/s)
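
To see how `velocity_x`, `velocity_y`, and `speed` relate to position, the sketch below estimates velocity with simple forward differences and takes speed as the magnitude of the velocity vector. The function `velocity_and_speed` and the toy samples are assumptions for illustration; the pipeline itself also smooths the speed, which is omitted here.

```python
import math


def velocity_and_speed(times, xs, ys):
    """Forward-difference velocity (cm/s) and its magnitude, one value per step."""
    vxs, vys, speeds = [], [], []
    for i in range(len(times) - 1):
        dt = times[i + 1] - times[i]
        vx = (xs[i + 1] - xs[i]) / dt
        vy = (ys[i + 1] - ys[i]) / dt
        vxs.append(vx)
        vys.append(vy)
        # Speed is the Euclidean norm of the velocity vector
        speeds.append(math.hypot(vx, vy))
    return vxs, vys, speeds


times = [0.0, 0.5, 1.0]  # seconds
xs = [0.0, 1.5, 3.0]  # cm
ys = [0.0, 2.0, 4.0]  # cm
vxs, vys, speeds = velocity_and_speed(times, xs, ys)
print(vxs, vys, speeds)
```

Note that forward differences return one fewer sample than the input, whereas the dataframe above has one value per time point.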

In [None]:
position_info = (PositionOutput() & trodes_key).fetch1_dataframe()
position_info

If you are not familiar with pandas, the time variable is set as the index. It can be accessed using `.index` on the dataframe.

In [None]:
position_info.index

### 4. Examining the results

We should always spot check our results to verify that the pipeline worked correctly.

#### Plots
Let's plot some of the variables first:

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.plot(position_info.position_x, position_info.position_y)
ax.set_xlabel("x-position [cm]", fontsize=18)
ax.set_ylabel("y-position [cm]", fontsize=18)
ax.set_title("Position", fontsize=28)


In [None]:
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.plot(position_info.velocity_x, position_info.velocity_y)
ax.set_xlabel("x-velocity [cm/s]", fontsize=18)
ax.set_ylabel("y-velocity [cm/s]", fontsize=18)
ax.set_title("Velocity", fontsize=28)


In [None]:
fig, ax = plt.subplots(1, 1, figsize=(25, 3))
ax.plot(position_info.index, position_info.speed)
ax.set_xlabel("Time", fontsize=18)
ax.set_ylabel("Speed [cm/s]", fontsize=18)
ax.set_title("Speed", fontsize=28)
ax.set_xlim((position_info.index.min(), position_info.index.max()))


#### Video

These look reasonable but it is better to evaluate these variables by plotting the results on the video where we can see how they correspond. 

The video will appear in the current working directory when it is done.

Bonus points if you made it this far... We can use the `PositionVideo` table to create a video that overlays just the centroid and orientation (regardless of upstream source) on the behavioral video. This table uses the parameter `plot` to determine whether to plot the entry deriving from the DLC arm or from the Trodes arm of the position pipeline. This parameter also accepts 'all', which will plot both (if they exist) in order to compare results.

In [None]:
sgp.PositionVideoSelection().insert1(
    {
        **trodes_key,
        'trodes_position_id': 1,
        'plot': 'Trodes',
        'output_dir': '/home/dgramling/Src/' # Change this to your save location
    }
)

### 5. Upsampling position data

Sometimes you need the position data at a different rate than it was sampled at, such as when decoding in 2 ms time bins. You can use the upsampling parameters to get this data:
+ `is_upsampled` controls whether upsampling happens. If it is 1 then there is upsampling, and if it is 0 then upsampling does not happen.
+ `upsampling_sampling_rate` is the rate you want to upsample to. For example, position is typically recorded at 33 frames per second, and you may want to upsample to 500 frames per second.
+ `upsampling_interpolation_method` is the interpolation method used for upsampling. It is set to linear by default. See the methods available for `pandas.DataFrame.interpolate` for a list of options.
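
As a sketch of what linear upsampling does, the function below resamples a series onto a uniform time grid at a higher rate by interpolating between neighboring samples. `upsample_linear` and the toy values are hypothetical and use only the standard library; the pipeline relies on pandas interpolation instead.

```python
def upsample_linear(times, values, new_rate):
    """Resample onto a uniform grid at new_rate (Hz) by linear interpolation."""
    # Number of grid points that fit between the first and last timestamps
    n_new = int((times[-1] - times[0]) * new_rate) + 1
    new_times = [times[0] + k / new_rate for k in range(n_new)]
    new_values = []
    j = 0  # index of the left neighbor in the original series
    for t in new_times:
        # Advance until times[j] <= t <= times[j + 1]
        while j < len(times) - 2 and times[j + 1] < t:
            j += 1
        frac = (t - times[j]) / (times[j + 1] - times[j])
        new_values.append(values[j] + frac * (values[j + 1] - values[j]))
    return new_times, new_values


times = [0.0, 0.5, 1.0]  # original samples at 2 Hz
values = [0.0, 1.0, 3.0]
new_times, new_values = upsample_linear(times, values, new_rate=4)
print(new_times)
print(new_values)
```

Upsampling this way adds time points but no new information: the interpolated samples lie on straight lines between the original ones.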

In [None]:
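sgp.TrodesPosParams.insert1(
    {
        "trodes_pos_params_name": "default_decoding",
        # Start from the defaults and override only the upsampling settings
        "params": {
            **sgp.TrodesPosParams.get_default()["params"],
            "is_upsampled": 1,
            "upsampling_sampling_rate": 500,
        },
    },
    skip_duplicates=True,
)
sgp.TrodesPosParams()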

In [None]:
sgp.TrodesPosSelection.insert1(
    {**start_key, "trodes_pos_params_name": "default_decoding"},
    skip_duplicates=True,
)
sgp.TrodesPosSelection()

In [None]:
sgp.TrodesPosV1.populate(start_key)

In [None]:
upsampled_key = {**start_key, "trodes_pos_params_name": "default_decoding"}
upsampled_position_info = (PositionOutput() & upsampled_key).fetch1_dataframe()
upsampled_position_info

In [None]:
fig, axes = plt.subplots(
    1, 2, figsize=(20, 10), sharex=True, sharey=True, constrained_layout=True
)
axes[0].plot(position_info.position_x, position_info.position_y)
axes[0].set_xlabel("x-position [cm]", fontsize=18)
axes[0].set_ylabel("y-position [cm]", fontsize=18)
axes[0].set_title("Position", fontsize=28)

axes[1].plot(
    upsampled_position_info.position_x, upsampled_position_info.position_y
)
axes[1].set_xlabel("x-position [cm]", fontsize=18)
axes[1].set_ylabel("y-position [cm]", fontsize=18)
axes[1].set_title("Upsampled Position", fontsize=28)


In [None]:
fig, axes = plt.subplots(
    2, 1, figsize=(25, 6), sharex=True, sharey=True, constrained_layout=True
)
axes[0].plot(position_info.index, position_info.speed)
axes[0].set_xlabel("Time", fontsize=18)
axes[0].set_ylabel("Speed [cm/s]", fontsize=18)
axes[0].set_title("Speed", fontsize=28)
axes[0].set_xlim((position_info.index.min(), position_info.index.max()))

axes[1].plot(upsampled_position_info.index, upsampled_position_info.speed)
axes[1].set_xlabel("Time", fontsize=18)
axes[1].set_ylabel("Speed [cm/s]", fontsize=18)
axes[1].set_title("Upsampled Speed", fontsize=28)


In [None]:
fig, axes = plt.subplots(
    1, 2, figsize=(20, 10), sharex=True, sharey=True, constrained_layout=True
)
axes[0].plot(position_info.velocity_x, position_info.velocity_y)
axes[0].set_xlabel("x-velocity [cm/s]", fontsize=18)
axes[0].set_ylabel("y-velocity [cm/s]", fontsize=18)
axes[0].set_title("Velocity", fontsize=28)

axes[1].plot(
    upsampled_position_info.velocity_x, upsampled_position_info.velocity_y
)
axes[1].set_xlabel("x-velocity [cm/s]", fontsize=18)
axes[1].set_ylabel("y-velocity [cm/s]", fontsize=18)
axes[1].set_title("Upsampled Velocity", fontsize=28)
