### Tapir management

In [11]:
import functools

import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import mediapy as media
import numpy as np
from tqdm import tqdm
import tree
import pandas as pd
import cv2
import os
import csv
import pickle
import torch

from tapnet import tapir_model
from tapnet.utils import transforms
from tapnet.utils import viz_utils

# @title Load Checkpoint {form-width: "25%"}

checkpoint_path = 'tapnet/checkpoints/causal_tapir_checkpoint.npy'
ckpt_state = np.load(checkpoint_path, allow_pickle=True).item()
params, state = ckpt_state['params'], ckpt_state['state']

# @title Build Model {form-width: "25%"}

# Internally, the tapir model has three stages of processing: computing
# image features (get_feature_grids), extracting features for each query point
# (get_query_features), and estimating trajectories given query features and
# the feature grids where we want to track (estimate_trajectories).  For
# tracking online, we need extract query features on the first frame only, and
# then call estimate_trajectories on one frame at a time.

def build_online_model_init(frames, query_points):
  """Initialize query features for the query points."""
  model = tapir_model.TAPIR(use_causal_conv=True, bilinear_interp_with_depthwise_conv=False)

  feature_grids = model.get_feature_grids(frames, is_training=False)
  query_features = model.get_query_features(
      frames,
      is_training=False,
      query_points=query_points,
      feature_grids=feature_grids,
  )
  return query_features


def build_online_model_predict(frames, query_features, causal_context):
  """Compute point tracks and occlusions given frames and query points."""
  model = tapir_model.TAPIR(use_causal_conv=True, bilinear_interp_with_depthwise_conv=False)
  feature_grids = model.get_feature_grids(frames, is_training=False)
  trajectories = model.estimate_trajectories(
      frames.shape[-3:-1],
      is_training=False,
      feature_grids=feature_grids,
      query_features=query_features,
      query_points_in_video=None,
      query_chunk_size=64,
      causal_context=causal_context,
      get_causal_context=True,
  )
  causal_context = trajectories['causal_context']
  del trajectories['causal_context']
  return {k: v[-1] for k, v in trajectories.items()}, causal_context


online_init = hk.transform_with_state(build_online_model_init)
online_init_apply = jax.jit(online_init.apply)

online_predict = hk.transform_with_state(build_online_model_predict)
online_predict_apply = jax.jit(online_predict.apply)

rng = jax.random.PRNGKey(42)
online_init_apply = functools.partial(
    online_init_apply, params=params, state=state, rng=rng
)
online_predict_apply = functools.partial(
    online_predict_apply, params=params, state=state, rng=rng
)

# @title Utility Functions {form-width: "25%"}

def preprocess_frames(frames):
  """Preprocess frames to model inputs.

  Args:
    frames: [num_frames, height, width, 3], [0, 255], np.uint8

  Returns:
    frames: [num_frames, height, width, 3], [-1, 1], np.float32
  """
  frames = frames.astype(np.float32)
  frames = frames / 255 * 2 - 1
  return frames


def postprocess_occlusions(occlusions, expected_dist):
  """Postprocess occlusions to boolean visible flag.

  Args:
    occlusions: [num_points, num_frames], [-inf, inf], np.float32

  Returns:
    visibles: [num_points, num_frames], bool
  """
  pred_occ = jax.nn.sigmoid(occlusions)
  pred_occ = 1 - (1 - pred_occ) * (1 - jax.nn.sigmoid(expected_dist))
  visibles = pred_occ < 0.5  # threshold
  return visibles


def sample_random_points(frame_max_idx, height, width, num_points):
  """Sample random points with (time, height, width) order."""
  y = np.random.randint(0, height, (num_points, 1))
  x = np.random.randint(0, width, (num_points, 1))
  t = np.random.randint(0, frame_max_idx + 1, (num_points, 1))
  points = np.concatenate((t, y, x), axis=-1).astype(np.int32)  # [num_points, 3]
  return points


def construct_initial_causal_state(num_points, num_resolutions):
  value_shapes = {
      "tapir/~/pips_mlp_mixer/block_1_causal_1": (1, num_points, 2, 512),
      "tapir/~/pips_mlp_mixer/block_1_causal_2": (1, num_points, 2, 2048),
      "tapir/~/pips_mlp_mixer/block_2_causal_1": (1, num_points, 2, 512),
      "tapir/~/pips_mlp_mixer/block_2_causal_2": (1, num_points, 2, 2048),
      "tapir/~/pips_mlp_mixer/block_3_causal_1": (1, num_points, 2, 512),
      "tapir/~/pips_mlp_mixer/block_3_causal_2": (1, num_points, 2, 2048),
      "tapir/~/pips_mlp_mixer/block_4_causal_1": (1, num_points, 2, 512),
      "tapir/~/pips_mlp_mixer/block_4_causal_2": (1, num_points, 2, 2048),
      "tapir/~/pips_mlp_mixer/block_5_causal_1": (1, num_points, 2, 512),
      "tapir/~/pips_mlp_mixer/block_5_causal_2": (1, num_points, 2, 2048),
      "tapir/~/pips_mlp_mixer/block_6_causal_1": (1, num_points, 2, 512),
      "tapir/~/pips_mlp_mixer/block_6_causal_2": (1, num_points, 2, 2048),
      "tapir/~/pips_mlp_mixer/block_7_causal_1": (1, num_points, 2, 512),
      "tapir/~/pips_mlp_mixer/block_7_causal_2": (1, num_points, 2, 2048),
      "tapir/~/pips_mlp_mixer/block_8_causal_1": (1, num_points, 2, 512),
      "tapir/~/pips_mlp_mixer/block_8_causal_2": (1, num_points, 2, 2048),
      "tapir/~/pips_mlp_mixer/block_9_causal_1": (1, num_points, 2, 512),
      "tapir/~/pips_mlp_mixer/block_9_causal_2": (1, num_points, 2, 2048),
      "tapir/~/pips_mlp_mixer/block_10_causal_1": (1, num_points, 2, 512),
      "tapir/~/pips_mlp_mixer/block_10_causal_2": (1, num_points, 2, 2048),
      "tapir/~/pips_mlp_mixer/block_11_causal_1": (1, num_points, 2, 512),
      "tapir/~/pips_mlp_mixer/block_11_causal_2": (1, num_points, 2, 2048),
      "tapir/~/pips_mlp_mixer/block_causal_1": (1, num_points, 2, 512),
      "tapir/~/pips_mlp_mixer/block_causal_2": (1, num_points, 2, 2048),
  }
  fake_ret = {
      k: jnp.zeros(v, dtype=jnp.float32) for k, v in value_shapes.items()
  }
  return [fake_ret] * num_resolutions * 4

In [12]:
pkl_dir = '../../../tapvid_davis/tapvid_davis/tapvid_davis.pkl'
with open(pkl_dir, 'rb') as f:
    loaded_data = pickle.load(f)

video_base_path = "../../Train_data/TAPvid/"

for title, data in loaded_data.items():
    if title == "goat":
        num_points, num_frames, _ = data['points'].shape
        height, width = 256, 256
        video_file = title +".mp4"
            
        full_video_path = os.path.join(video_base_path, video_file)

        # Create initial queries with frame number 0 for each point
        initial_queries = np.zeros((num_points, 2))
        for i in range(num_points):
            # initial_queries[i, 0] = 0  # Frame number
            initial_queries[i, 0] = data['points'][i, 0, 0] * width  # Adjust x-coordinate
            initial_queries[i, 1] = data['points'][i, 0, 1] * height  # Adjust y-coordinate

        select_points = initial_queries.copy()
        # Read video
        video = media.read_video(full_video_path)

        # Get points
        
        resize_height = 256  # @param {type: "integer"}
        resize_width = 256  # @param {type: "integer"}

        frames = media.resize_video(video, (resize_height, resize_width))
        query_points = sample_random_points(0, frames.shape[1], frames.shape[2], num_points)

        query_features, _ = online_init_apply(frames=preprocess_frames(frames[None, None, 0]), query_points=query_points[None])
        causal_state = construct_initial_causal_state(query_points.shape[0], len(query_features.resolutions) - 1)

        # Predict point tracks frame by frame
        predictions = []
        for i in range(frames.shape[0]):
            (prediction, causal_state), _ = online_predict_apply(
                frames=preprocess_frames(frames[None, None, i]),
                query_features=query_features,
                causal_context=causal_state,
            )
            predictions.append(prediction)

        tracks = np.concatenate([x['tracks'][0] for x in predictions], axis=1)
        occlusions = np.concatenate([x['occlusion'][0] for x in predictions], axis=1)
        expected_dist = np.concatenate([x['expected_dist'][0] for x in predictions], axis=1)

        visibles = postprocess_occlusions(occlusions, expected_dist)

        # Visualize sparse point tracks
        tracks = transforms.convert_grid_coordinates(tracks, (resize_width, resize_height), (width, height))
        video_viz = viz_utils.paint_point_track(video, tracks, visibles)
        # media.show_video(video_viz, fps=10)


        # FINAL PART
        colormap = viz_utils.get_colors(30)

        # @title Predict Point Tracks for the Selected Points {form-width: "25%"}

        resize_height = 256  # @param {type: "integer"}
        resize_width = 256  # @param {type: "integer"}

        def convert_select_points_to_query_points(frame, points):
            """Convert select points to query points.

            Args:
                points: [num_points, 2], [t, y, x]
            Returns:
                query_points: [num_points, 3], [t, y, x]
            """
            points = np.stack(points)
            query_points = np.zeros(shape=(points.shape[0], 3), dtype=np.float32)
            query_points[:, 0] = frame
            query_points[:, 1] = points[:, 1]
            query_points[:, 2] = points[:, 0]
            return query_points

        frames = media.resize_video(video, (resize_height, resize_width))
        query_points = convert_select_points_to_query_points(0, select_points)
        query_points = transforms.convert_grid_coordinates(
            query_points, (1, height, width), (1, resize_height, resize_width), coordinate_format='tyx')

        query_features, _ = online_init_apply(frames=preprocess_frames(frames[None, None, 0]), query_points=query_points[None])
        causal_state = construct_initial_causal_state(query_points.shape[0], len(query_features.resolutions) - 1)

        # Predict point tracks frame by frame
        predictions = []
        for i in tqdm(range(frames.shape[0])):
            (prediction, causal_state), _ = online_predict_apply(
                frames=preprocess_frames(frames[None, None, i]),
                query_features=query_features,
                causal_context=causal_state,
            )
            predictions.append(prediction)

        tracks = np.concatenate([x['tracks'][0] for x in predictions], axis=1)
        occlusions = np.concatenate([x['occlusion'][0] for x in predictions], axis=1)
        expected_dist = np.concatenate([x['expected_dist'][0] for x in predictions], axis=1)

        visibles = postprocess_occlusions(occlusions, expected_dist)

        # Visualize sparse point tracks
        tracks = transforms.convert_grid_coordinates(tracks, (resize_width, resize_height), (width, height))
        video_viz = viz_utils.paint_point_track(video, tracks, visibles, colormap)
        media.show_video(video_viz, fps=10)
    else:
        pass


    

100%|██████████| 90/90 [01:10<00:00,  1.27it/s]


0
This browser does not support the video tag.


### Get GT data

In [13]:
from tapnet.evaluation_datasets_copy import create_davis_dataset
from tapnet.evaluation_datasets_copy import compute_tapvid_metrics


dav = create_davis_dataset('../../../tapvid_davis/tapvid_davis/tapvid_davis.pkl')
first = next(dav)

qp = first["davis"]["query_points"]
gt_occluded = first["davis"]["occluded"]
gt_tracks = first["davis"]["target_points"]

print(f"query points: ", qp.shape)
print(f"gt occluded: ", gt_occluded.shape)
print(f"gt track: ", gt_tracks.shape)

query points:  (1, 90, 3)
gt occluded:  (1, 90, 90)
gt track:  (1, 90, 90, 2)


### Predictions

In [14]:
pred_occluded = visibles
pred_tracks = tracks

print(f"pred occluded: ", pred_occluded.shape)
print(f"pred_tracks: ", pred_tracks.shape)

pred occluded:  (5, 90)
pred_tracks:  (5, 90, 2)


### When I load the pickle file directly

In [15]:
import pickle

pkl_dir = '../../../tapvid_davis/tapvid_davis/tapvid_davis.pkl'
with open(pkl_dir, 'rb') as f:
    loaded_data = pickle.load(f)

loaded_tracks = loaded_data["goat"]["points"]
loaded_occluded = loaded_data["goat"]["occluded"]
print("pickle occluded: ", loaded_occluded.shape)
print("pickle tracks: ", loaded_tracks.shape)

pickle occluded:  (5, 90)
pickle tracks:  (5, 90, 2)


### Running compute davis 
On the generate_davis the shape difference produce an error
When running on the pickle file, this error occurs
Accidently printed some two rows for trouble shooting

In [16]:
compute_tapvid_metrics(qp, loaded_occluded, loaded_tracks, pred_occluded, pred_tracks, query_mode="first")


q frames 2 eval  [[0 1]
 [0 0]]
q frame  [[ 0  0  0  0  0  5  5  5  5  5 10 10 10 10 10 15 15 15 15 15 20 20 20 20
  20 25 25 25 25 25 30 30 30 30 30 35 35 35 35 35 40 40 40 40 40 45 45 45
  45 45 50 50 50 50 50 55 55 55 55 55 60 60 60 60 60 65 65 65 65 65 70 70
  70 70 70 75 75 75 75 75 80 80 80 80 80 85 85 85 85 85]]


IndexError: index 5 is out of bounds for axis 0 with size 2