<a href="https://colab.research.google.com/github/PsorTheDoctor/computer-vision/blob/master/object_tracking/tapir.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TAPIR

In [None]:
!git clone https://github.com/deepmind/tapnet.git
!pip install -r tapnet/requirements_inference.txt

In [None]:
%mkdir tapnet/checkpoints
!wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/tapir_checkpoint.npy
%ls tapnet/checkpoints

In [3]:
import haiku as hk
import jax
import mediapy as media
import numpy as np
import tree

from tapnet import tapir_model
from tapnet.utils import transforms
from tapnet.utils import viz_utils

## Checkpoints

In [4]:
# Restore the pretrained weights from disk. The .npy file wraps a pickled
# dict (hence allow_pickle=True and .item()); unpickling can execute arbitrary
# code, so only load checkpoints obtained from a trusted source.
checkpoint_path = 'tapnet/checkpoints/tapir_checkpoint.npy'
ckpt_state = np.load(checkpoint_path, allow_pickle=True).item()
params = ckpt_state['params']
state = ckpt_state['state']

## Model

In [5]:
def build_model(frames, query_points):
  """Run a TAPIR forward pass in inference mode.

  Instantiates `tapir_model.TAPIR` and calls it on the given video with
  `is_training=False`, processing the query points in chunks of 64.

  Args:
    frames: video passed to the model as `video`.
    query_points: points to track, passed to the model as `query_points`.

  Returns:
    The model outputs, exactly as returned by the TAPIR module.
  """
  tapir = tapir_model.TAPIR()
  return tapir(
      video=frames,
      is_training=False,
      query_points=query_points,
      query_chunk_size=64,
  )

# Wrap `build_model` with Haiku so that parameters and state are passed in
# explicitly instead of living inside the module.
model = hk.transform_with_state(build_model)
# JIT-compile the apply function; it is later called as
# model_apply(params, state, rng, frames, query_points).
model_apply = jax.jit(model.apply)

## Utils

In [6]:
def preprocess_frames(frames):
  """Convert frames to float32 and rescale values from [0, 255] to [-1, 1].

  Args:
    frames: array of pixel values (anything exposing `.astype`); a value of 0
      maps to -1 and a value of 255 maps to 1.

  Returns:
    A float32 array with the same shape as `frames`.
  """
  as_float = frames.astype(np.float32)
  return as_float / 255 * 2 - 1

def postprocess_occlusions(occlusions, expected_dist):
  """Turn the model's occlusion and expected-distance outputs into a boolean mask.

  Both inputs are squashed with a sigmoid. A point counts as visible only when
  the product (1 - sigmoid(occlusions)) * (1 - sigmoid(expected_dist)) is
  strictly greater than 0.5.

  Args:
    occlusions: raw occlusion values from the model.
    expected_dist: raw expected-distance values, broadcastable with
      `occlusions`.

  Returns:
    Boolean array that is True where the point is considered visible.
  """
  not_occluded = 1 - jax.nn.sigmoid(occlusions)
  not_distant = 1 - jax.nn.sigmoid(expected_dist)
  return not_occluded * not_distant > 0.5

def inference(frames, query_points):
  """Track query points through a video with the pretrained TAPIR model.

  Args:
    frames: video array whose leading axes are (num_frames, height, width),
      with pixel values in [0, 255]. It is rescaled to [-1, 1] before being
      fed to the model.
    query_points: array of query points, one row per point, in
      (time, height, width) order as produced by `sample_random_points`.

  Returns:
    A `(tracks, visibles)` tuple: the predicted tracks (NumPy array, batch
    dimension removed) and a boolean visibility mask derived from the model's
    occlusion and expected-distance outputs.
  """
  # Preprocess video to match model inputs format
  frames = preprocess_frames(frames)
  query_points = query_points.astype(np.float32)
  frames, query_points = frames[None], query_points[None]  # Add batch dimension

  # Model inference. The PRNG key is fixed, so repeated calls see the same key.
  rng = jax.random.PRNGKey(42)
  outputs, _ = model_apply(params, state, rng, frames, query_points)
  # Drop the batch dimension again and move every output to NumPy.
  outputs = tree.map_structure(lambda x: np.array(x[0]), outputs)
  tracks = outputs['tracks']
  occlusions = outputs['occlusion']
  expected_dist = outputs['expected_dist']

  # Binarize occlusions
  visibles = postprocess_occlusions(occlusions, expected_dist)
  return tracks, visibles

def sample_random_points(frame_max_idx, height, width, num_points, seed=None):
  """Sample random points with (time, height, width) order.

  Args:
    frame_max_idx: largest frame index that may be sampled (inclusive).
    height: number of rows; sampled y lies in [0, height).
    width: number of columns; sampled x lies in [0, width).
    num_points: number of points to draw.
    seed: optional integer seed. When given, the points are drawn from a
      dedicated `np.random.RandomState(seed)` so the result is reproducible
      and NumPy's global RNG is left untouched. When None (the default), the
      global `np.random` state is used, as before.

  Returns:
    int32 array of shape [num_points, 3] whose columns are (t, y, x).
  """
  rng = np.random if seed is None else np.random.RandomState(seed)
  # Draw order (y, x, t) is kept so a globally seeded run yields the same points.
  y = rng.randint(0, height, (num_points, 1))
  x = rng.randint(0, width, (num_points, 1))
  t = rng.randint(0, frame_max_idx + 1, (num_points, 1))
  points = np.concatenate((t, y, x), axis=-1).astype(np.int32)  # [num_points, 3]
  return points

## Example video

In [None]:
%mkdir tapnet/examplar_videos
!wget -P tapnet/examplar_videos https://storage.googleapis.com/dm-tapnet/horsejump-high.mp4

In [9]:
# Load the downloaded example clip and preview it at 10 frames per second.
video_path = 'tapnet/examplar_videos/horsejump-high.mp4'
video = media.read_video(video_path)
media.show_video(video, fps=10)

0
This browser does not support the video tag.


## Predict sparse point tracks

In [11]:
resize_height = 256
resize_width = 256
num_points = 100
random_seed = 42  # fixes the sampled query points so the cell is reproducible

# Remember the original resolution so the tracks can be mapped back onto it.
height, width = video.shape[1:3]
frames = media.resize_video(video, (resize_height, resize_width))

# Seed NumPy's global RNG right before sampling: without it every run tracks
# a different set of points. frame_max_idx=0 places all query points on the
# first frame.
np.random.seed(random_seed)
query_points = sample_random_points(0, frames.shape[1], frames.shape[2], num_points)
tracks, visibles = inference(frames, query_points)

# Visualize sparse point tracks
# Tracks are predicted on the resized frames; rescale them to the original
# video size (sizes are given in (width, height) order) before painting.
tracks = transforms.convert_grid_coordinates(tracks, (resize_width, resize_height), (width, height))
video_viz = viz_utils.paint_point_track(video, tracks, visibles)
media.show_video(video_viz, fps=10)

0
This browser does not support the video tag.
