Copyright 2020 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

<p align="center">
  <h1 align="center">TAPIR: Tracking Any Point with per-frame Initialization and temporal Refinement</h1>
  <p align="center">
    <a href="http://www.carldoersch.com/">Carl Doersch</a>
    ·
    <a href="https://yangyi02.github.io/">Yi Yang</a>
    ·
    <a href="https://scholar.google.com/citations?user=Jvi_XPAAAAAJ">Mel Vecerik</a>
    ·
    <a href="https://scholar.google.com/citations?user=cnbENAEAAAAJ">Dilara Gokay</a>
    ·
    <a href="https://www.robots.ox.ac.uk/~ankush/">Ankush Gupta</a>
    ·
    <a href="http://people.csail.mit.edu/yusuf/">Yusuf Aytar</a>
    ·
    <a href="https://scholar.google.co.uk/citations?user=IUZ-7_cAAAAJ">Joao Carreira</a>
    ·
    <a href="https://www.robots.ox.ac.uk/~az/">Andrew Zisserman</a>
  </p>
  <h3 align="center"><a href="https://arxiv.org/abs/2306.08637">Paper</a> | <a href="https://deepmind-tapir.github.io">Project Page</a> | <a href="https://github.com/deepmind/tapnet">GitHub</a> | <a href="https://github.com/deepmind/tapnet/tree/main#running-tapir-locally">Live Demo</a> </h3>
  <div align="center"></div>
</p>

<p align="center">
  <a href="">
    <img src="https://storage.googleapis.com/dm-tapnet/swaying_gif.gif" alt="Logo" width="50%">
  </a>
</p>



In [None]:
# @title Install Code and Dependencies {form-width: "25%"}
# NOTE(review): this installs from the Shrinkhal01/tapnet-testing fork, while the
# header of this notebook links to github.com/deepmind/tapnet — confirm that the
# fork is the intended source.
!pip install git+https://github.com/Shrinkhal01/tapnet-testing.git

In [None]:
# Selects which checkpoint the following cells download and load. Any value
# other than 'tapir' falls through to the BootsTAPIR checkpoint, but the extra
# BootsTAPIR model kwargs are only applied when the value is exactly
# 'bootstapir'.
MODEL_TYPE = 'bootstapir'  # 'tapir' or 'bootstapir'

In [None]:
# @title Download Model {form-width: "25%"}

%mkdir tapnet/checkpoints

if MODEL_TYPE == "tapir":
  !wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/tapir_checkpoint_panning.npy
else:
  !wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/bootstap/bootstapir_checkpoint_v2.npy

%ls tapnet/checkpoints

In [None]:
# @title Imports {form-width: "25%"}
%matplotlib widget
from google.colab import output
import jax
import matplotlib
import matplotlib.pyplot as plt
import mediapy as media
import numpy as np
from tapnet.models import tapir_model
from tapnet.utils import model_utils
from tapnet.utils import transforms
from tapnet.utils import viz_utils

# Turn on Colab's custom widget manager. Presumably required for the
# interactive `%matplotlib widget` backend selected at the top of this cell —
# confirm against the Colab documentation.
output.enable_custom_widget_manager()

In [None]:
# @title Load Checkpoint {form-width: "25%"}

# Any MODEL_TYPE other than 'tapir' loads the BootsTAPIR checkpoint.
checkpoint_path = (
    'tapnet/checkpoints/tapir_checkpoint_panning.npy'
    if MODEL_TYPE == 'tapir'
    else 'tapnet/checkpoints/bootstapir_checkpoint_v2.npy'
)

# NOTE(review): allow_pickle=True unpickles arbitrary Python objects; only load
# checkpoint files obtained from a trusted source.
ckpt_state = np.load(checkpoint_path, allow_pickle=True).item()
params = ckpt_state['params']
state = ckpt_state['state']

# Base model options; the BootsTAPIR variant overrides/extends them.
kwargs = {'bilinear_interp_with_depthwise_conv': False, 'pyramid_level': 0}
if MODEL_TYPE == 'bootstapir':
  kwargs.update(pyramid_level=1, extra_convs=True, softmax_temperature=10.0)

tapir = tapir_model.ParameterizedTAPIR(params, state, tapir_kwargs=kwargs)

In [None]:
# @title Load an Exemplar Video {form-width: "25%"}

%mkdir tapnet/examplar_videos

!wget -P tapnet/examplar_videos http://storage.googleapis.com/dm-tapnet/horsejump-high.mp4

video = media.read_video("tapnet/examplar_videos/horsejump-high.mp4")
media.show_video(video, fps=10)

In [None]:
# @title Utility Functions {form-width: "25%"}


@jax.jit
def inference(frames, query_points):
  """Track query points through a single video with the global `tapir` model.

  Args:
    frames: uint8 video of shape [num_frames, height, width, 3] with values in
      [0, 255].
    query_points: [num_points, 3] array of (t, y, x) queries, expressed in
      frame-index / pixel units of `frames`.

  Returns:
    tracks: predicted point locations with the batch dimension removed.
      NOTE(review): callers rescale these with (width, height) ordering, so
      this is presumably [num_points, num_frames, 2] in (x, y) pixel
      coordinates — confirm against `tapir_model`.
    visibles: boolean visibility mask, [num_points, num_frames].
  """
  # Bring both inputs into the model's format and prepend a batch axis.
  video = model_utils.preprocess_frames(frames)[None]
  queries = query_points.astype(np.float32)[None]

  outputs = tapir(
      video=video,
      query_points=queries,
      is_training=False,
      query_chunk_size=32,
  )

  # Turn the occlusion / expected-distance outputs into a binary visibility
  # mask.
  visibles = model_utils.postprocess_occlusions(
      outputs['occlusion'], outputs['expected_dist']
  )

  # Strip the batch axis again before returning.
  return outputs['tracks'][0], visibles[0]


def sample_random_points(frame_max_idx, height, width, num_points):
  """Draw uniformly random integer query points in (t, y, x) order.

  Args:
    frame_max_idx: largest frame index that may be sampled (inclusive).
    height: exclusive upper bound for the y coordinate.
    width: exclusive upper bound for the x coordinate.
    num_points: number of points to draw.

  Returns:
    int32 array of shape [num_points, 3] whose columns are (t, y, x).
  """
  column_shape = (num_points, 1)
  # The draw order (y, x, t) is kept so a seeded global RNG yields the same
  # points as before.
  ys = np.random.randint(0, height, column_shape)
  xs = np.random.randint(0, width, column_shape)
  ts = np.random.randint(0, frame_max_idx + 1, column_shape)
  return np.hstack([ts, ys, xs]).astype(np.int32)

In [None]:
# @title Predict Sparse Point Tracks {form-width: "25%"}

resize_height = 256  # @param {type: "integer"}
resize_width = 256  # @param {type: "integer"}
num_points = 20  # @param {type: "integer"}

# Run the model on a downscaled copy of the video, querying random points that
# all lie on the first frame.
frames = media.resize_video(video, (resize_height, resize_width))
query_points = sample_random_points(
    frame_max_idx=0,
    height=frames.shape[1],
    width=frames.shape[2],
    num_points=num_points,
)
# Pull both outputs back into NumPy arrays.
tracks, visibles = map(np.array, inference(frames, query_points))

# Visualize sparse point tracks: rescale from the resized grid to the original
# resolution, then paint the tracks onto the full-size video.
height, width = video.shape[1:3]
tracks = transforms.convert_grid_coordinates(
    tracks,
    (resize_width, resize_height),
    (width, height),
)
video_viz = viz_utils.paint_point_track(video, tracks, visibles)
media.show_video(video_viz, fps=10)

In [None]:
# @title Download TAPVid-DAVIS Dataset {form-width: "25%"}
# Download the dataset archive into the working directory and unpack it there;
# the evaluation cell reads `tapvid_davis/tapvid_davis.pkl`.
!wget https://storage.googleapis.com/dm-tapnet/tapvid_davis.zip
!unzip tapvid_davis.zip

In [None]:
# @title DAVIS First Eval on 256x256 Resolution {form-width: "25%"}
%%time

from tapnet import evaluation_datasets

davis_dataset = evaluation_datasets.create_davis_dataset(
    'tapvid_davis/tapvid_davis.pkl', query_mode='first'
)

summed_scalars = None
for sample_idx, sample in enumerate(davis_dataset):
  sample = sample['davis']
  frames = np.round((sample['video'][0] + 1) / 2 * 255).astype(np.uint8)
  query_points = sample['query_points'][0]

  tracks, visibles = inference(frames, query_points)
  tracks = np.array(tracks)
  visibles = np.array(visibles)
  occluded = ~visibles

  query_points = sample['query_points'][0]

  scalars = evaluation_datasets.compute_tapvid_metrics(
      query_points[None],
      sample['occluded'],
      sample['target_points'],
      occluded[None],
      tracks[None],
      query_mode='first',
  )
  scalars = jax.tree.map(lambda x: np.array(np.sum(x, axis=0)), scalars)
  print(sample_idx, scalars)

  if summed_scalars is None:
    summed_scalars = scalars
  else:
    summed_scalars = jax.tree.map(np.add, summed_scalars, scalars)

  num_samples = sample_idx + 1
  mean_scalars = jax.tree.map(lambda x: x / num_samples, summed_scalars)
  print(mean_scalars)

In [None]:
# @title Efficient Chunked Point Track Prediction {form-width: "25%"}

resize_height = 256  # @param {type: "integer"}
resize_width = 256  # @param {type: "integer"}
num_points = 4096  # @param {type: "integer"}

frames = media.resize_video(video, (resize_height, resize_width))
# Add a leading batch axis before preprocessing; `frames` keeps it from here on.
frames = model_utils.preprocess_frames(frames[None])
# Compute the feature grids once so every query chunk below can reuse them
# instead of re-encoding the video.
feature_grids = tapir.get_feature_grids(frames, is_training=False)
# Because of the batch axis, axes 2 and 3 are used as height and width here.
query_points = sample_random_points(
    0, frames.shape[2], frames.shape[3], num_points
)
# Number of query points handed to the model per call.
chunk_size = 32


@jax.jit
def chunk_inference(query_points):
  """Track one chunk of query points using the precomputed feature grids.

  Reads the globals `frames`, `feature_grids` and `chunk_size` defined earlier
  in this cell.

  Args:
    query_points: [num_points, 3] array of (t, y, x) queries for this chunk.

  Returns:
    tracks: predicted point locations for the chunk, batch dimension removed.
    visibles: binarized visibility for the chunk, batch dimension removed.
  """
  # Cast to float32 and prepend the batch axis the model expects.
  queries = query_points.astype(np.float32)[None]

  outputs = tapir(
      video=frames,
      query_points=queries,
      is_training=False,
      query_chunk_size=chunk_size,
      feature_grids=feature_grids,
  )

  # Binarize occlusions
  visibles = model_utils.postprocess_occlusions(
      outputs['occlusion'], outputs['expected_dist']
  )
  return outputs['tracks'][0], visibles[0]

all_tracks, all_visibles = [], []
# Walk over the query points in slices of `chunk_size`, collecting each
# chunk's outputs as NumPy arrays.
for chunk in range(0, query_points.shape[0], chunk_size):
  chunk_queries = query_points[chunk : chunk + chunk_size]
  tracks, visibles = map(np.array, chunk_inference(chunk_queries))
  all_tracks.append(tracks)
  all_visibles.append(visibles)

# Stitch the per-chunk results back together along the point axis.
tracks = np.concatenate(all_tracks, axis=0)
visibles = np.concatenate(all_visibles, axis=0)

# Visualize sparse point tracks: rescale to the original resolution and paint
# the tracks onto the full-size video.
height, width = video.shape[1:3]
tracks = transforms.convert_grid_coordinates(
    tracks,
    (resize_width, resize_height),
    (width, height),
)
video_viz = viz_utils.paint_point_track(video, tracks, visibles)
media.show_video(video_viz, fps=10)

In [None]:
# @title Select Any Points at Any Frame {form-width: "25%"}

# NOTE(review): the slider maximum is hard-coded to 49 — presumably the last
# frame index of the example video; adjust it if a different video is loaded.
select_frame = 0  # @param {type:"slider", min:0, max:49, step:1}

# Generate a colormap with 20 colors; increase this only if you select more
# than 20 points.
colormap = viz_utils.get_colors(20)

# Show the chosen frame without axes so it can be clicked on.
fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(video[select_frame])
ax.axis('off')
ax.set_title(
    'You can select more than 1 point. After selecting enough points, run the'
    ' next cell.'
)

# Filled in by the click handler with one [x, y] array per click.
select_points = []


# Mouse-click callback for the frame shown above
def on_click(event):
  """Record a left click inside the image axes and mark it on the figure.

  Appends the rounded (x, y) click position to the global `select_points` and
  draws a dot in the next unused colour of the global `colormap`.

  Args:
    event: the matplotlib mouse event delivered to the callback.
  """
  # Ignore anything that is not a left click inside our axes.
  if event.button != 1 or event.inaxes != ax:
    return

  x = int(np.round(event.xdata))
  y = int(np.round(event.ydata))
  select_points.append(np.array([x, y]))

  # Colormap entries are 0-255 while matplotlib wants 0-1, hence the division.
  rgb = np.array(colormap[len(select_points) - 1]) / 255.0
  ax.plot(x, y, 'o', color=tuple(rgb), markersize=5)
  plt.draw()


# Register the click handler on the figure canvas, then display the figure.
fig.canvas.mpl_connect('button_press_event', on_click)
plt.show()

In [None]:
# @title Predict Point Tracks for the Selected Points {form-width: "25%"}

# Resolution the video is resized to before it is passed to the model.
resize_height = 256  # @param {type: "integer"}
resize_width = 256  # @param {type: "integer"}


def convert_select_points_to_query_points(frame, points):
  """Turn clicked (x, y) points into (t, y, x) query points.

  Args:
    frame: value written into the time column of every query point.
    points: sequence of num_points arrays, each in [x, y] order.

  Returns:
    query_points: float32 array of shape [num_points, 3], in [t, y, x] order.
  """
  xy = np.stack(points)
  query_points = np.empty((xy.shape[0], 3), dtype=np.float32)
  query_points[:, 0] = frame
  # Swap the column order so (x, y) lands as (y, x).
  query_points[:, 1:] = xy[:, [1, 0]]
  return query_points


# Downscale the video to the resolution the model is run at.
frames = media.resize_video(video, (resize_height, resize_width))
query_points = convert_select_points_to_query_points(
    select_frame, select_points
)
# The clicks were recorded on the full-resolution frame, so rescale the (y, x)
# components to the resized grid. The time axis is given the same size (1) on
# both sides, so it is presumably left unchanged.
height, width = video.shape[1:3]
query_points = transforms.convert_grid_coordinates(
    query_points,
    (1, height, width),
    (1, resize_height, resize_width),
    coordinate_format='tyx',
)
tracks, visibles = inference(frames, query_points)

# Visualize sparse point tracks
# Map the predicted tracks from the resized grid back to the original
# (width, height) grid before painting them with the click colours.
tracks = transforms.convert_grid_coordinates(
    tracks, (resize_width, resize_height), (width, height)
)
video_viz = viz_utils.paint_point_track(video, tracks, visibles, colormap)
media.show_video(video_viz, fps=10)

That's it!