Copyright 2020 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

<p align="center">
  <h1 align="center">TAPIR: Tracking Any Point with per-frame Initialization and temporal Refinement</h1>
  <p align="center">
    <a href="http://www.carldoersch.com/">Carl Doersch</a>
    ·
    <a href="https://yangyi02.github.io/">Yi Yang</a>
    ·
    <a href="https://scholar.google.com/citations?user=Jvi_XPAAAAAJ">Mel Vecerik</a>
    ·
    <a href="https://scholar.google.com/citations?user=cnbENAEAAAAJ">Dilara Gokay</a>
    ·
    <a href="https://www.robots.ox.ac.uk/~ankush/">Ankush Gupta</a>
    ·
    <a href="http://people.csail.mit.edu/yusuf/">Yusuf Aytar</a>
    ·
    <a href="https://scholar.google.co.uk/citations?user=IUZ-7_cAAAAJ">Joao Carreira</a>
    ·
    <a href="https://www.robots.ox.ac.uk/~az/">Andrew Zisserman</a>
  </p>
  <h3 align="center"><a href="https://arxiv.org/abs/2306.08637">Paper</a> | <a href="https://deepmind-tapir.github.io">Project Page</a> | <a href="https://github.com/deepmind/tapnet">GitHub</a> | <a href="https://github.com/deepmind/tapnet/tree/main#running-tapir-locally">Live Demo</a> </h3>
  <div align="center"></div>
</p>

<p align="center">
  <a href="">
    <img src="https://storage.googleapis.com/dm-tapnet/swaying_gif.gif" alt="Logo" width="50%">
  </a>
</p>

In [None]:
# @title Install code and dependencies {form-width: "25%"}
# Install the tapnet package, with its `torch` extra, directly from the GitHub repository.
!pip install 'tapnet[torch] @ git+https://github.com/google-deepmind/tapnet.git'

In [None]:
# @title Download Model {form-width: "25%"}

%mkdir tapnet/checkpoints

!wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/bootstap/bootstapir_checkpoint_v2.pt

%ls tapnet/checkpoints

In [None]:
# @title Imports {form-width: "25%"}
%matplotlib widget
import haiku as hk
import jax
import matplotlib.pyplot as plt
import mediapy as media
import numpy as np
import tree

import torch
import torch.nn.functional as F

from tapnet.torch import tapir_model
from tapnet.utils import transforms
from tapnet.utils import viz_utils

from google.colab import output
output.enable_custom_widget_manager()

In [None]:
# @title Select Device {form-width: "25%"}

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# @title Utility Functions {form-width: "25%"}

def preprocess_frames(frames):
  """Rescale video frames from the [0, 255] range to floats in [-1, 1].

  Args:
    frames: torch.Tensor of pixel values in [0, 255]. The callers in this
      notebook pass a [num_frames, height, width, 3] tensor built from the
      decoded video (presumably uint8 -- confirm against the video reader).

  Returns:
    A torch.float32 tensor with the same shape as `frames`, mapped linearly
    so that 0 becomes -1 and 255 becomes 1.
  """
  as_float = frames.float()
  return as_float / 255 * 2 - 1


def sample_random_points(frame_max_idx, height, width, num_points):
  """Sample random integer query points in (time, height, width) order.

  Args:
    frame_max_idx: largest frame index that may be drawn (inclusive).
    height: exclusive upper bound for the sampled y coordinate.
    width: exclusive upper bound for the sampled x coordinate.
    num_points: number of points to draw.

  Returns:
    np.int32 array of shape [num_points, 3]; each row is (t, y, x).
  """
  column_shape = (num_points, 1)
  # The draw order (y, then x, then t) is kept fixed so that a seeded global
  # NumPy RNG keeps producing the same points.
  ys = np.random.randint(0, height, size=column_shape)
  xs = np.random.randint(0, width, size=column_shape)
  ts = np.random.randint(0, frame_max_idx + 1, size=column_shape)
  # Columns are ordered (t, y, x); the samples are integers at this point.
  return np.hstack((ts, ys, xs)).astype(np.int32)


def postprocess_occlusions(occlusions, expected_dist):
  """Turn raw occlusion and expected-distance scores into a visibility mask.

  Both inputs go through a sigmoid. A point counts as visible when
  (1 - sigmoid(occlusions)) * (1 - sigmoid(expected_dist)) exceeds 0.5.

  Args:
    occlusions: torch.Tensor of unnormalized occlusion scores.
    expected_dist: torch.Tensor of unnormalized expected-distance scores,
      combined elementwise with `occlusions`.

  Returns:
    Boolean torch.Tensor that is True where the point is considered visible.
  """
  not_occluded = 1 - F.sigmoid(occlusions)
  close_enough = 1 - F.sigmoid(expected_dist)
  return not_occluded * close_enough > 0.5


def inference(frames, query_points, model, verbose=False):
  """Track query points through a video and binarize the occlusion output.

  Args:
    frames: video tensor accepted by `preprocess_frames` (pixel values in
      [0, 255]); a leading batch axis is added before the model call.
    query_points: tensor of query points; cast to float and given a leading
      batch axis before the model call.
    model: callable invoked as `model(frames, query_points)` that returns a
      mapping with 'tracks', 'occlusion' and 'expected_dist' entries.
    verbose: when True, print the shapes and values of the raw model outputs
      for debugging. Off by default so a normal run does not dump whole
      tensors into the cell output.

  Returns:
    (tracks, visibles): the first batch element of the model's 'tracks'
    output, and the boolean visibility mask that `postprocess_occlusions`
    derives from the 'occlusion' and 'expected_dist' outputs.
  """
  # Preprocess video to match model inputs format, then add the batch axis.
  frames = preprocess_frames(frames)[None]
  query_points = query_points.float()[None]

  # Model inference
  outputs = model(frames, query_points)
  tracks = outputs['tracks'][0]
  occlusions = outputs['occlusion'][0]
  expected_dist = outputs['expected_dist'][0]

  if verbose:
    print(f"tracks.shape:{outputs['tracks'].shape}\n{tracks}")
    print(f"occlusion.shape:{outputs['occlusion'].shape}\n{occlusions}")
    print(f"expected_dist.shape:{outputs['expected_dist'].shape}\n{expected_dist}")

  # Binarize occlusions
  visibles = postprocess_occlusions(occlusions, expected_dist)
  return tracks, visibles

In [None]:
# @title Load an Exemplar Video {form-width: "25%"}

%mkdir tapnet/examplar_videos

!wget -P tapnet/examplar_videos https://storage.googleapis.com/dm-tapnet/horsejump-high.mp4

video = media.read_video('tapnet/examplar_videos/horsejump-high.mp4')#[frames,height,weight,3]
height, width = video.shape[1:3]
media.show_video(video, fps=10)

In [None]:
# @title Define Model {form-width: "25%"}

model = tapir_model.TAPIR(pyramid_level=1)
# `map_location` keeps the load working on the CPU fallback chosen in the
# "Select Device" cell, even if the checkpoint tensors were saved from a GPU.
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
model.load_state_dict(
    torch.load('tapnet/checkpoints/bootstapir_checkpoint_v2.pt', map_location=device))
model = model.to(device)

In [None]:
# @title Set to Inference Mode to Save Memory {form-width: "25%"}

# Switch the module to evaluation mode.
model = model.eval()
# Turn gradient tracking off globally; inference needs no gradients.
torch.set_grad_enabled(False)

In [None]:
# @title Predict Sparse Point Tracks {form-width: "25%"}

resize_height = 256  # @param {type: "integer"}
resize_width = 256  # @param {type: "integer"}
num_points = 50  # @param {type: "integer"}

# Resize the video to the working resolution and sample random query points
# on frame 0 (frame_max_idx=0), expressed in the resized coordinates.
frames = media.resize_video(video, (resize_height, resize_width))
query_points = sample_random_points(0, frames.shape[1], frames.shape[2], num_points)
frames = torch.tensor(frames).to(device)
query_points = torch.tensor(query_points).to(device)

tracks, visibles = inference(frames, query_points, model)

# Bring the results back to NumPy on the host for visualization.
tracks = tracks.cpu().detach().numpy()
visibles = visibles.cpu().detach().numpy()
# Visualize sparse point tracks
# Map the tracks from the resized resolution back to the original video size.
tracks = transforms.convert_grid_coordinates(tracks, (resize_width, resize_height), (width, height))
# Draw the tracked points on top of the original video.
video_viz = viz_utils.paint_point_track(video, tracks, visibles)
media.show_video(video_viz, fps=10)

In [None]:
# @title Select Any Points at Any Frame {form-width: "25%"}

select_frame = 7  #@param {type:"slider", min:0, max:49, step:1}

# Generate a colormap with 20 colors; request more only if you plan to select
# more than 20 points.
colormap = viz_utils.get_colors(20)

# Show the chosen frame, with the axes hidden, as the canvas to click on.
fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(video[select_frame])
ax.axis('off')
ax.set_title('You can select more than 1 points. After select enough points, run the next cell.')

# Points clicked so far; the click handler appends np.array([x, y]) entries.
select_points = []

 # Event handler for mouse clicks
def on_click(event):
  """Record a left click inside the axes as a selected point and draw it.

  Args:
    event: the mouse event passed in by the 'button_press_event' callback.
      Clicks with another button, or outside `ax`, are ignored.
  """
  if event.button != 1 or event.inaxes != ax:
    return

  # Snap the click position to the nearest integer pixel.
  x = int(np.round(event.xdata))
  y = int(np.round(event.ydata))
  select_points.append(np.array([x, y]))

  # The n-th selected point uses the n-th colormap entry, rescaled from
  # 0-255 to the 0-1 range before plotting.
  rgb = tuple(np.array(colormap[len(select_points) - 1]) / 255.0)
  ax.plot(x, y, 'o', color=rgb, markersize=5)
  plt.draw()

# Register the click handler on the figure canvas, then display the figure.
fig.canvas.mpl_connect('button_press_event', on_click)
plt.show()

In [None]:
# @title Predict Point Tracks for the Selected Points {form-width: "25%"}

resize_height = 256  # @param {type: "integer"}
resize_width = 256  # @param {type: "integer"}

def convert_select_points_to_query_points(frame, points):
  """Convert select points to query points.

  Args:
    frame: frame index written into the time column of every query point.
    points: non-empty sequence of [x, y] points, stackable to [num_points, 2].

  Returns:
    query_points: [num_points, 3] np.float32 array, in [t, y, x]

  Raises:
    ValueError: if `points` is empty, i.e. no point has been selected yet.
  """
  # Fail with an actionable message instead of np.stack's generic
  # "need at least one array" error when nothing was clicked.
  if len(points) == 0:
    raise ValueError(
        'No points selected: click at least one point on the frame before '
        'converting to query points.')
  xy = np.stack(points)
  query_points = np.empty((xy.shape[0], 3), dtype=np.float32)
  query_points[:, 0] = frame
  query_points[:, 1:] = xy[:, [1, 0]]  # swap (x, y) into (y, x)
  return query_points

# Resize the video and express the clicked points as [t, y, x] queries.
frames = media.resize_video(video, (resize_height, resize_width))
query_points = convert_select_points_to_query_points(select_frame, select_points)
# Rescale the queries from the original resolution to the resized one. The
# time size is 1 in both grids, so t is presumably left unchanged -- confirm
# against convert_grid_coordinates.
query_points = transforms.convert_grid_coordinates(
    query_points, (1, height, width), (1, resize_height, resize_width), coordinate_format='tyx')

frames = torch.tensor(frames).to(device)
query_points = torch.tensor(query_points).to(device)
tracks, visibles = inference(frames, query_points, model)

# Bring the results back to NumPy on the host for visualization.
tracks = tracks.detach().cpu().numpy()
visibles = visibles.detach().cpu().numpy()

# Visualize sparse point tracks
# Map the tracks from the resized resolution back to the original video size,
# then draw them with the same colormap that was used while clicking.
tracks = transforms.convert_grid_coordinates(tracks, (resize_width, resize_height), (width, height))
video_viz = viz_utils.paint_point_track(video, tracks, visibles, colormap)
media.show_video(video_viz, fps=10)

That's it!