<p align="center">
  <h1 align="center">TAPIR: Tracking Any Point with per-frame Initialization and temporal Refinement</h1>
  <p align="center">
    <a href="http://www.carldoersch.com/">Carl Doersch</a>
    ·
    <a href="https://yangyi02.github.io/">Yi Yang</a>
    ·
    <a href="https://scholar.google.com/citations?user=Jvi_XPAAAAAJ">Mel Vecerik</a>
    ·
    <a href="https://scholar.google.com/citations?user=cnbENAEAAAAJ">Dilara Gokay</a>
    ·
    <a href="https://www.robots.ox.ac.uk/~ankush/">Ankush Gupta</a>
    ·
    <a href="http://people.csail.mit.edu/yusuf/">Yusuf Aytar</a>
    ·
    <a href="https://scholar.google.co.uk/citations?user=IUZ-7_cAAAAJ">Joao Carreira</a>
    ·
    <a href="https://www.robots.ox.ac.uk/~az/">Andrew Zisserman</a>
  </p>
  <h3 align="center"><a href="https://arxiv.org/abs/2306.08637">Paper</a> | <a href="https://deepmind-tapir.github.io">Project Page</a> | <a href="https://github.com/deepmind/tapnet">GitHub</a> | <a href="https://github.com/deepmind/tapnet/tree/main#running-tapir-locally">Live Demo</a> </h3>
  <div align="center"></div>
</p>

<p align="center">
  <img src="https://storage.googleapis.com/dm-tapnet/horsejump_rainbow.gif" width="70%"/><br/><br/>
</p>
<p>
  This visualization uses TAPIR to show how an object moves through space, even if the camera is tracking the object.  It begins by tracking points densely on a grid.  Then it estimates the camera motion as a homography (i.e., assuming either a planar background or a camera that rotates but does not translate).  Any points whose motion is consistent with that homography (i.e., background points) are removed.  Then we generate a &ldquo;rainbow&rdquo; visualization, where the tracked points leave &ldquo;tails&rdquo; that follow the camera motion, so it looks like the earlier positions of points are frozen in space.  This visualization was inspired by a similar one from <a href="https://omnimotion.github.io/">OmniMotion</a>, although that one assumes ground-truth segmentations are available and models the camera motion as a 2D translation only.
</p>
<p>
  Note that we consider this algorithm &ldquo;semi-automatic&rdquo; because arbitrary videos may need some tuning to produce pleasing results.  Tracking failures on the background may show up as foreground objects.  Results are sensitive to the outlier thresholds used in RANSAC and segmentation, and you may wish to discard short tracks.  You can sample points in a different way (e.g., from multiple frames) and everything will still work, but the <code>plot_tracks_tails</code> function uses the input order of the points to choose colors, so you will have to sort the points appropriately.
</p>



In [None]:
# @title Download Code {form-width: "25%"}
!git clone https://github.com/deepmind/tapnet.git

In [None]:
# @title Install Dependencies {form-width: "25%"}
!pip install -r tapnet/requirements_inference.txt

In [None]:
MODEL_TYPE = 'bootstapir'  # @param ["tapir", "bootstapir"]
# Later cells treat every value other than 'tapir' as BootsTAPIR, so reject
# typos here instead of silently loading the wrong checkpoint/configuration.
assert MODEL_TYPE in ('tapir', 'bootstapir'), f'Unknown MODEL_TYPE: {MODEL_TYPE!r}'

In [None]:
# @title Download Model {form-width: "25%"}

%mkdir tapnet/checkpoints

if MODEL_TYPE == 'tapir':
  !wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/tapir_checkpoint_panning.npy
else:
  !wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/bootstap/bootstapir_checkpoint_v2.npy

%ls tapnet/checkpoints


In [None]:
# @title Imports {form-width: "25%"}

import jax
import jax.numpy as jnp
import haiku as hk
import mediapy as media
import numpy as np
import tree


In [None]:
# @title Load Checkpoint {form-width: "25%"}
# (Colab only recognizes the @title annotation on the first line of a cell.)
from tapnet import tapir_model
from tapnet.utils import transforms
from tapnet.utils import viz_utils
from tapnet.utils import model_utils

# Pick the checkpoint file fetched by the Download Model cell.
if MODEL_TYPE == 'tapir':
  checkpoint_path = 'tapnet/checkpoints/tapir_checkpoint_panning.npy'
else:
  checkpoint_path = 'tapnet/checkpoints/bootstapir_checkpoint_v2.npy'
# NOTE: allow_pickle=True can execute arbitrary code while loading; only use
# checkpoints from a trusted source such as the official dm-tapnet bucket.
ckpt_state = np.load(checkpoint_path, allow_pickle=True).item()
params, state = ckpt_state['params'], ckpt_state['state']

# Base TAPIR configuration. BootsTAPIR overrides some architecture options,
# presumably to match the network its checkpoint was trained with.
kwargs = dict(bilinear_interp_with_depthwise_conv=False, pyramid_level=0)
if MODEL_TYPE == 'bootstapir':
  kwargs.update(dict(
    pyramid_level=1,
    extra_convs=True,
    softmax_temperature=10.0
  ))
tapir = tapir_model.ParameterizedTAPIR(params, state, tapir_kwargs=kwargs)

## Load and Build Model

In [None]:
# @title Utilities for model inference {form-width: "25%"}

def sample_grid_points(frame_idx, height, width, stride=1):
  """Sample a regular grid of query points in (time, y, x) order.

  Args:
    frame_idx: frame index assigned to every sampled point.
    height: image height in pixels; y coordinates are taken from [0, height).
    width: image width in pixels; x coordinates are taken from [0, width).
    stride: spacing between neighbouring grid points. The grid starts at
      stride // 2 along both axes.

  Returns:
    int32 array of shape [num_points, 3] whose rows are (frame_idx, y, x),
    ordered row-major (y varies slowest).
  """
  offset = stride // 2
  ys, xs = np.mgrid[offset:height:stride, offset:width:stride]
  ts = np.full(ys.shape, frame_idx)
  grid = np.stack([ts, ys, xs], axis=-1).astype(np.int32)
  return grid.reshape(-1, 3)

In [None]:
# @title Load an Exemplar Video {form-width: "25%"}

%mkdir tapnet/examplar_videos

!wget -P tapnet/examplar_videos https://storage.googleapis.com/dm-tapnet/horsejump-high.mp4

orig_frames = media.read_video('tapnet/examplar_videos/horsejump-high.mp4')
height, width = orig_frames.shape[1:3]
media.show_video(orig_frames, fps=10)

In [None]:
# @title Inference function {form-width: "25%"}

resize_height = 256  # @param {type: "integer"}
resize_width = 256  # @param {type: "integer"}
stride = 8  # @param {type: "integer"}
query_frame = 0  # @param {type: "integer"}
# Number of query points handled per jitted call. The prediction loop pads the
# final chunk to this size so chunk_inference always sees the same shape.
chunk_size = 64  # @param {type: "integer"}

# The model runs on frames resized to (resize_height, resize_width). Feature
# grids are computed once here and reused by every chunk_inference call.
frames = media.resize_video(orig_frames, (resize_height, resize_width))
frames = model_utils.preprocess_frames(frames[None])
feature_grids = tapir.get_feature_grids(frames, is_training=False)


@jax.jit
def chunk_inference(query_points):
  """Track one chunk of query points through the whole (resized) video.

  Uses the module-level `tapir`, `frames`, `feature_grids` and `chunk_size`.

  Args:
    query_points: [num_points, 3] array of (t, y, x) query points expressed
      in the resized-frame pixel space.

  Returns:
    A (tracks, visibles) pair for the chunk, with the leading batch axis that
    is added for the model call removed again.
  """
  batched_queries = query_points.astype(np.float32)[None]

  outputs = tapir(
      video=frames,
      query_points=batched_queries,
      is_training=False,
      query_chunk_size=chunk_size,
      feature_grids=feature_grids,
  )

  # Binarize occlusions using both the occlusion output and the expected distance.
  visibles = model_utils.postprocess_occlusions(
      outputs["occlusion"], outputs["expected_dist"])
  return outputs["tracks"][0], visibles[0]

In [None]:
# @title Predict semi-dense point tracks {form-width: "25%"}
%%time


query_points = sample_grid_points(query_frame, resize_height, resize_width, stride)

tracks = []
visibles = []
for i in range(0,query_points.shape[0],chunk_size):
  query_points_chunk = query_points[i:i+chunk_size]
  num_extra = chunk_size - query_points_chunk.shape[0]
  if num_extra > 0:
    query_points_chunk = np.concatenate([query_points_chunk, np.zeros([num_extra, 3])], axis=0)
  tracks2, visibles2 = chunk_inference(query_points_chunk)
  if num_extra > 0:
    tracks2 = tracks2[:-num_extra]
    visibles2 = visibles2[:-num_extra]
  tracks.append(tracks2)
  visibles.append(visibles2)
tracks=jnp.concatenate(tracks, axis=0)
visibles=jnp.concatenate(visibles, axis=0)

tracks = transforms.convert_grid_coordinates(tracks, (resize_width, resize_height), (width, height))

# We show the point tracks without rainbows so you can see the input.
video = viz_utils.plot_tracks_v2(orig_frames, tracks, 1.0 - visibles)
media.show_video(video, fps=10)


In [None]:
# The inlier point threshold for ransac, specified in normalized coordinates
# (points are rescaled to the range [0, 1] for optimization).
ransac_inlier_threshold = 0.07  # @param {type: "number"}
# What fraction of points need to be inliers for RANSAC to consider a trajectory
# to be trustworthy for estimating the homography.
ransac_track_inlier_frac = 0.95  # @param {type: "number"}
# After initial RANSAC, how many refinement passes to adjust the homographies
# based on tracks that have been deemed trustworthy. This is a pass count, so
# the form field only accepts integers.
num_refinement_passes = 2  # @param {type: "integer"}
# After homographies are estimated, consider points to be outliers if they are
# further than this threshold.
foreground_inlier_threshold = 0.07  # @param {type: "number"}
# After homographies are estimated, consider tracks to be part of the foreground
# if at most this fraction of their visible points are inliers.
foreground_frac = 0.6  # @param {type: "number"}


occluded = 1.0 - visibles
# NOTE(review): ransac_track_inlier_frac is passed as `outlier_point_threshold`;
# confirm in viz_utils that this argument expects the inlier fraction above.
homogs, err, canonical = viz_utils.get_homographies_wrt_frame(
    tracks,
    occluded,
    [width, height],
    thresh=ransac_inlier_threshold,
    outlier_point_threshold=ransac_track_inlier_frac,
    num_refinement_passes=num_refinement_passes,
)

# A point is a background inlier when it is visible and its homography error is
# below the threshold (err is compared against the squared threshold, so it is
# presumably a squared distance).
inliers = (err < np.square(foreground_inlier_threshold)) * visibles
inlier_ct = np.sum(inliers, axis=-1)
# max(1, ...) avoids dividing by zero for tracks that are never visible.
ratio = inlier_ct / np.maximum(1.0, np.sum(visibles, axis=1))
is_fg = ratio <= foreground_frac
video = viz_utils.plot_tracks_tails(
    orig_frames,
    tracks[is_fg],
    occluded[is_fg],
    homogs
)
media.show_video(video, fps=12)