<a href="https://colab.research.google.com/github/talmo/sleap-mit-tutorial/blob/master/SLEAP_Interactive_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup environment

First, set the Colab runtime to use a GPU: open the menu **Runtime** -> **Change runtime type**, select **GPU** under **Hardware accelerator**, and then reconnect the session (click **Connect** or **RAM/Disk** at the top right and select **Connect to a hosted runtime**).
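
To confirm that the runtime change took effect, a quick check along the following lines can help. This is a minimal sketch, not part of the original tutorial: `visible_gpus` is a hypothetical helper name, and it assumes TensorFlow 2.x is available (as it is on a default Colab runtime). If TensorFlow cannot be imported, it simply reports no GPUs.

```python
def visible_gpus():
    """Return the names of the GPUs TensorFlow can see (empty list if none)."""
    try:
        import tensorflow as tf
    except ImportError:
        # TensorFlow is not installed in this environment.
        return []
    return [device.name for device in tf.config.list_physical_devices("GPU")]


gpus = visible_gpus()
print(f"GPUs visible to TensorFlow: {gpus if gpus else 'none'}")
```

If the list is empty on Colab, the runtime type is probably still set to CPU.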

Install SLEAP from PyPI:

In [7]:
%tensorflow_version 2.x
!pip install sleap==1.0.6

import sleap



Download a set of trained models and a test clip that was not in the training data:

In [8]:
!curl -L --output models.zip "https://www.dropbox.com/s/vi13zysbzpq9c19/models.zip?dl=1"
!unzip -o models.zip

!curl -L --output test_clip.15s.mp4 "https://www.dropbox.com/s/nvmr4jnhmdoiwdk/test_clip.15s.mp4?dl=1"

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 54.8M  100 54.8M    0     0  9533k      0  0:00:05  0:00:05 --:--:-- 12.7M
Archive:  models.zip
  inflating: models/baseline_model.centroids/best_model.h5  
  inflating: models/baseline_model.centroids/initial_config.json  
  inflating: models/baseline_model.centroids/training_config.json  
  inflating: models/baseline_model.topdown/best_model.h5  
  inflating: models/baseline_model.topdown/initial_config.json  
  inflating: models/baseline_model.topdown/training_config.json  

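Before moving on, it can be worth verifying that the archive unpacked as expected. The sketch below checks for the model files named in the unzip listing above. `missing_model_files` is a hypothetical helper, and treating `best_model.h5` and `training_config.json` as the files that matter for loading is an assumption, not something the tutorial states.

```python
from pathlib import Path

# Directory and file names are taken from the unzip listing above.
MODEL_DIRS = ("baseline_model.centroids", "baseline_model.topdown")
# Assumption: these two files are the ones needed to load each model.
EXPECTED_FILES = ("best_model.h5", "training_config.json")


def missing_model_files(models_root="models"):
    """Return the expected model files that are not present under models_root."""
    root = Path(models_root)
    return [
        str(root / model_dir / name)
        for model_dir in MODEL_DIRS
        for name in EXPECTED_FILES
        if not (root / model_dir / name).is_file()
    ]


missing = missing_model_files()
print("All model files present." if not missing else f"Missing: {missing}")
```
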
# Run inference

First, we'll load the models that we just downloaded and create the predictor that will run pose estimation and tracking:

In [9]:
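from sleap.nn.inference import TopdownPredictor, Tracker

# Top-down pipeline: the centroid model locates each animal, then the
# top-down model predicts body part positions within a crop around it.
predictor = TopdownPredictor.from_trained_models(
    centroid_model_path="models/baseline_model.centroids",
    confmap_model_path="models/baseline_model.topdown",
)
# Link detections across frames so each animal keeps a consistent track.
predictor.tracker = Tracker.make_tracker_by_name(tracker="simple")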

Next, we'll set up the input data using a `sleap.Video` reader. This class lets you access a video as if it were a NumPy array:

In [10]:
video = sleap.Video.from_filename("test_clip.15s.mp4")
video.shape

(375, 1024, 1024, 1)
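
The shape printed above follows the order (frames, height, width, channels). As a rough illustration of why reading frames on demand is useful, the sketch below unpacks that tuple and estimates how much memory the clip would take if it were decoded all at once. `describe_video_shape` is a hypothetical helper, and one byte per value (8-bit pixels) is an assumption.

```python
def describe_video_shape(shape, bytes_per_value=1):
    """Unpack a (frames, height, width, channels) shape and estimate its size.

    bytes_per_value=1 assumes 8-bit pixel values, which is an assumption here.
    """
    n_frames, height, width, channels = shape
    total_bytes = n_frames * height * width * channels * bytes_per_value
    return {
        "frames": n_frames,
        "height": height,
        "width": width,
        "channels": channels,
        "size_mib": total_bytes / 2**20,
    }


info = describe_video_shape((375, 1024, 1024, 1))
print(info)
```

Under that assumption, this single-channel clip would occupy 375 MiB fully decoded, so indexing only the frames you need keeps memory use low.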

Finally, we'll run inference and gather the results into a `sleap.Labels` object. This is like a project file that contains all the predictions and metadata associated with the video:

In [11]:
labels_pr = predictor.predict(video, make_labels=True)

INFO:sleap.nn.inference:Finished 750 examples in 31.65 seconds (inference + postprocessing)
INFO:sleap.nn.inference:examples/s = 23.69462671254522
INFO:sleap.nn.inference:Predicted 375 labeled frames in 32.997 secs [11.4 FPS]
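
The numbers in the log can be cross-checked with a little arithmetic. The sketch below uses only values from the log lines above. Reading "examples" as per-instance crops passed to the top-down model is an assumption. If it holds, 750 examples over 375 frames works out to two instances per frame on average.

```python
# Values copied from the inference log above.
n_frames = 375
n_examples = 750
inference_seconds = 31.65  # inference + postprocessing
total_seconds = 32.997     # end-to-end time reported for the labeled frames

# Assumption: each "example" is one instance crop for the top-down model.
examples_per_frame = n_examples / n_frames
examples_per_second = n_examples / inference_seconds
frames_per_second = n_frames / total_seconds

print(f"examples per frame: {examples_per_frame:.1f}")
print(f"examples per second: {examples_per_second:.2f}")
print(f"frames per second: {frames_per_second:.1f}")
```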


# Visualization

Now that we have the results, we can visualize them frame-by-frame:

In [12]:
%matplotlib inline
from ipywidgets import interactive
import matplotlib.pyplot as plt

def plot_frame(i=0):
  lf = labels_pr[i]
  sleap.nn.viz.plot_img(lf.image, scale=0.6)
  sleap.nn.viz.plot_instances(lf, color_by_track=True, tracks=labels_pr.tracks);

interactive_plot = interactive(plot_frame, i=(0, len(labels_pr) - 1, 1))
output = interactive_plot.children[-1]
output.layout.height = "600px"
interactive_plot

interactive(children=(IntSlider(value=0, description='i', max=374), Output(layout=Layout(height='600px'))), _d…
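
As a non-visual complement to the slider, a per-track tally gives a quick sanity check on the tracking. This is a sketch under stated assumptions: `summarize_tracks` is a hypothetical helper, and it assumes each labeled frame exposes an `instances` list whose items carry a `track` attribute with a `name` (or `None` when untracked). The demo uses stand-in objects so that it runs without SLEAP installed.

```python
from collections import Counter
from types import SimpleNamespace


def summarize_tracks(labeled_frames):
    """Count predicted instances per track name across all frames.

    Assumes each frame has `.instances`, and each instance has `.track`,
    which is either None or an object with a `.name` attribute.
    """
    counts = Counter()
    for lf in labeled_frames:
        for inst in lf.instances:
            track = getattr(inst, "track", None)
            name = track.name if track is not None else "untracked"
            counts[name] += 1
    return counts


# Stand-in data that mimics the assumed structure (not real predictions).
track_a = SimpleNamespace(name="track_0")
track_b = SimpleNamespace(name="track_1")
demo_frames = [
    SimpleNamespace(instances=[SimpleNamespace(track=track_a),
                               SimpleNamespace(track=track_b)]),
    SimpleNamespace(instances=[SimpleNamespace(track=track_a),
                               SimpleNamespace(track=None)]),
]
demo_counts = summarize_tracks(demo_frames)
print(dict(demo_counts))
```

If those assumptions hold for the predictions in this notebook, calling `summarize_tracks(labels_pr)` should report how many instances were assigned to each track.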

Visualizing in Colab is not ideal; in a typical workflow you would inspect the tracking results locally.

Let's save and download the results so that we can load them in the SLEAP GUI:

In [13]:
# Save the predictions.
sleap.Labels.save_file(labels_pr, "predictions.slp")

# Download.
from google.colab import files
files.download("predictions.slp")
