In [10]:
from __future__ import annotations

import pickle
from pathlib import Path
from typing import Any

"/mnt/wd_sn770_oss/imc2024_code/new_sakuramoti/sakuramoti/data/tapvid_rgb_stacking/tapvid_rgb_stacking.pkl"


def load_tapvid_benchmark(dataset_dir: str | Path) -> list[dict[str, Any]] | dict[str, Any]:
    """Load the ground-truth pickle found in ``dataset_dir``.

    Args:
      dataset_dir: Directory that contains the benchmark ``*.pkl`` file.

    Returns:
      The unpickled object, returned as-is (annotated as a dict or a list of
      dicts).

    Raises:
      FileNotFoundError: If ``dataset_dir`` contains no ``*.pkl`` file.
    """
    dataset_dir = Path(dataset_dir)
    # glob() yields entries in a filesystem-dependent order; sorting makes the
    # selected file deterministic when more than one pickle is present.
    pickle_paths = sorted(dataset_dir.glob("*.pkl"))
    if not pickle_paths:
        raise FileNotFoundError(f"No '*.pkl' file found in {dataset_dir}")
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with pickle_paths[0].open("rb") as f:
        return pickle.load(f)

In [26]:
import random
import colorsys

import numpy as np


# Generate random colormaps for visualizing different points.
def get_colors(num_colors: int) -> list[tuple[int, int, int]]:
    """Build a shuffled list of ``num_colors`` RGB colors with evenly spaced hues.

    Lightness and saturation receive a small random jitter drawn from
    ``np.random`` and the final order is shuffled with ``random.shuffle``, so
    the result varies between calls unless both generators are seeded.

    Args:
      num_colors: Number of colors to generate.

    Returns:
      A list of exactly ``num_colors`` ``(r, g, b)`` tuples of ints in
      [0, 255]; an empty list when ``num_colors`` is not positive.
    """
    if num_colors <= 0:
        # Nothing to generate; also avoids dividing by zero below.
        return []
    colors = []
    # Deriving the hue from an integer index guarantees exactly `num_colors`
    # entries; a float-step np.arange can produce one element too many.
    for index in range(num_colors):
        hue = index / num_colors
        lightness = (50 + np.random.rand() * 10) / 100.0
        saturation = (90 + np.random.rand() * 10) / 100.0
        red, green, blue = colorsys.hls_to_rgb(hue, lightness, saturation)
        colors.append((int(red * 255), int(green * 255), int(blue * 255)))
    random.shuffle(colors)
    return colors


def paint_point_track(
    frames: np.ndarray,
    point_tracks: np.ndarray,
    visibles: np.ndarray,
    colormap: list[tuple[int, int, int]] | None = None,
) -> np.ndarray:
    """Converts a sequence of points to color code video.

    A small soft-edged dot is alpha-blended onto every frame for each point
    that is visible in that frame. The input ``frames`` array is not modified.

    Args:
      frames: [num_frames, height, width, 3], np.uint8, [0, 255]
      point_tracks: [num_points, num_frames, 2], np.float32, [0, width / height]
      visibles: [num_points, num_frames], bool
      colormap: colormap for points, each point has a different RGB color.
        When omitted, colors are generated with ``get_colors``.

    Returns:
      video: [num_frames, height, width, 3], np.uint8, [0, 255]
    """
    num_points, num_frames = point_tracks.shape[0:2]
    if num_points == 0:
        # Nothing to draw. Returning early also avoids asking get_colors() for
        # zero colors, which divides by the number of colors.
        return frames.copy()
    if colormap is None:
        colormap = get_colors(num_colors=num_points)
    height, width = frames.shape[1:3]

    # Build the dot icon: an opacity mask in [0, 1] of size diam x diam whose
    # radius is a fixed fraction of the shorter image edge.
    dot_size_as_fraction_of_min_edge = 0.015
    radius = int(round(min(height, width) * dot_size_as_fraction_of_min_edge))
    diam = radius * 2 + 1
    quadratic_y = np.square(np.arange(diam)[:, np.newaxis] - radius - 1)
    quadratic_x = np.square(np.arange(diam)[np.newaxis, :] - radius - 1)
    icon = (quadratic_y + quadratic_x) - (radius**2) / 2.0
    sharpness = 0.15
    icon = np.clip(icon / (radius * 2 * sharpness), 0, 1)
    icon = 1 - icon[:, :, np.newaxis]
    # Four copies of the icon, each shifted by zero or one pixel along each
    # axis; they are mixed below with bilinear weights for sub-pixel placement.
    icon1 = np.pad(icon, [(0, 1), (0, 1), (0, 0)])
    icon2 = np.pad(icon, [(1, 0), (0, 1), (0, 0)])
    icon3 = np.pad(icon, [(0, 1), (1, 0), (0, 0)])
    icon4 = np.pad(icon, [(1, 0), (1, 0), (0, 0)])

    video = frames.copy()
    for t in range(num_frames):
        # Pad so that points that extend outside the image frame don't crash us
        image = np.pad(
            video[t],
            [
                (radius + 1, radius + 1),
                (radius + 1, radius + 1),
                (0, 0),
            ],
        )
        for i in range(num_points):
            if not visibles[i, t]:
                continue

            # The icon is centered at the center of a pixel, but the input
            # coordinates are raster coordinates, so the position is shifted
            # by half a pixel before it is split into integer and fractional
            # parts.
            # NOTE(review): the upstream comment says "subtract 0.5" while the
            # code adds 0.5 -- confirm the intended convention.
            x, y = point_tracks[i, t, :] + 0.5
            # Clamp to the image so the patch always fits in the padded frame.
            x = min(max(x, 0.0), width)
            y = min(max(y, 0.0), height)

            x1, y1 = np.floor(x).astype(np.int32), np.floor(y).astype(np.int32)
            x2, y2 = x1 + 1, y1 + 1

            # bilinear interpolation
            patch = (
                icon1 * (x2 - x) * (y2 - y)
                + icon2 * (x2 - x) * (y - y1)
                + icon3 * (x - x1) * (y2 - y)
                + icon4 * (x - x1) * (y - y1)
            )
            x_ub = x1 + 2 * radius + 2
            y_ub = y1 + 2 * radius + 2
            # Alpha-blend the point color over the existing pixels.
            image[y1:y_ub, x1:x_ub, :] = (1 - patch) * image[y1:y_ub, x1:x_ub, :] + patch * np.array(colormap[i])[
                np.newaxis, np.newaxis, :
            ]

        # Remove the pad. Done once per frame, after every point has been
        # drawn, instead of copying the whole frame once per point.
        video[t] = image[radius + 1 : -radius - 1, radius + 1 : -radius - 1].astype(np.uint8)
    return video

In [40]:
import mediapy as media

# Load the benchmark pickle and pull the "cows" clip with its annotations.
dataset = load_tapvid_benchmark("/mnt/wd_sn770_oss/imc2024_code/new_sakuramoti/sakuramoti/data/tapvid_davis")
frames, occ, points = (dataset["cows"][field] for field in ("video", "occluded", "points"))
# Clips taller than 360 rows are resized to 360x640 before drawing.
frames = media.resize_video(frames, (360, 640)) if frames.shape[1] > 360 else frames
# (shape[2], shape[1]) multiplier, i.e. width and height per the frame layout
# documented on paint_point_track, applied to the last axis of `points`.
scale_factor = np.array(frames.shape[2:0:-1])[None, None, :]
painted_frames = paint_point_track(frames, points * scale_factor, ~occ)

In [41]:
# Render the painted clip inline at 10 frames per second.
media.show_video(painted_frames, fps=10)

0
This browser does not support the video tag.


In [39]:
# Load the benchmark pickle from the tapvid_davis directory into `dataset`.
dataset = load_tapvid_benchmark("/mnt/wd_sn770_oss/imc2024_code/new_sakuramoti/sakuramoti/data/tapvid_davis")

In [21]:
# Print the array shapes of the "goat" sample, one per line:
# points, video, occlusion flags.
print(*(dataset["goat"][field].shape for field in ("points", "video", "occluded")), sep="\n")

(5, 90, 2)
(90, 480, 854, 3)
(5, 90)


In [24]:
# Inspect the raw point coordinates stored for the "goat" sample.
dataset["goat"]["points"]

array([[[0.515288  , 0.34306568],
        [0.5278646 , 0.34861112],
        [0.53828126, 0.3587963 ],
        [0.54765624, 0.37453705],
        [0.5570313 , 0.40046296],
        [0.5643229 , 0.41527778],
        [0.56953126, 0.4097222 ],
        [0.5721354 , 0.40694445],
        [0.57890624, 0.4125    ],
        [0.5893229 , 0.41435185],
        [0.59296876, 0.4162037 ],
        [0.5955211 , 0.4140875 ],
        [0.6059896 , 0.41527778],
        [0.62369794, 0.41990742],
        [0.6351563 , 0.4300926 ],
        [0.6382812 , 0.43657407],
        [0.6408854 , 0.43842593],
        [0.6460937 , 0.43472221],
        [0.6539062 , 0.43472221],
        [0.66536456, 0.4337963 ],
        [0.6773437 , 0.43935186],
        [0.6872396 , 0.4412037 ],
        [0.6924479 , 0.41805556],
        [0.7002604 , 0.38935184],
        [0.70963544, 0.3726852 ],
        [0.71223956, 0.3587963 ],
        [0.7049479 , 0.35324073],
        [0.6981771 , 0.35416666],
        [0.7007812 , 0.34675926],
        [0.703

In [None]:
# Inspect the raw video array stored for the "goat" sample.
dataset["goat"]["video"]

In [23]:
# @title Load an Exemplar Video {form-width: "25%"}
import mediapy as media

# Play the "goat" clip exactly as stored in the dataset, at 10 frames per second.
media.show_video(dataset["goat"]["video"], fps=10)

0
This browser does not support the video tag.


In [25]:
!wget -P tapnet/examplar_videos https://storage.googleapis.com/dm-tapnet/horsejump-high.mp4

orig_frames = media.read_video("tapnet/examplar_videos/horsejump-high.mp4")
height, width = orig_frames.shape[1:3]
media.show_video(orig_frames, fps=10)

--2024-08-18 19:54:50--  https://storage.googleapis.com/dm-tapnet/horsejump-high.mp4
storage.googleapis.com (storage.googleapis.com) をDNSに問いあわせています... 172.217.175.123, 172.217.174.123, 142.250.207.27, ...
storage.googleapis.com (storage.googleapis.com)|172.217.175.123|:443 に接続しています... 接続しました。
HTTP による接続要求を送信しました、応答を待っています... 200 OK
長さ: 467706 (457K) [video/mp4]
‘tapnet/examplar_videos/horsejump-high.mp4’ に保存中


2024-08-18 19:54:50 (2.82 MB/s) - ‘tapnet/examplar_videos/horsejump-high.mp4’ へ保存完了 [467706/467706]



0
This browser does not support the video tag.


In [None]:
"""Visualize frames of a random video of the given dataset."""

import io
from collections.abc import Sequence

import numpy as np
import mediapy as media
from PIL import Image
from absl import app, flags, logging
from tapnet.utils import viz_utils

# Handle to the absl flag values; main() reads both flags through it.
FLAGS = flags.FLAGS

# Both flags are mandatory (required=True) and have no default value.
flags.DEFINE_string("input_path", None, "Path to the pickle file.", required=True)
flags.DEFINE_string("output_path", None, "Path to the output mp4 video.", required=True)


def main(argv: Sequence[str]) -> None:
    """Paint the point tracks of one randomly chosen video and save it.

    Reads the pickle at ``FLAGS.input_path``, picks a random entry, overlays
    its points with ``viz_utils.paint_point_track`` and writes the result to
    ``FLAGS.output_path`` at 25 fps.

    Args:
      argv: Positional command-line arguments; unused.
    """
    del argv  # Unused.

    logging.info("Loading data from %s. This takes time.", FLAGS.input_path)
    with open(FLAGS.input_path, "rb") as pickle_file:
        examples = pickle.load(pickle_file)
    # A dict-shaped pickle is flattened to its values so that an entry can be
    # picked by position.
    if isinstance(examples, dict):
        examples = list(examples.values())

    example = examples[random.randint(0, len(examples) - 1)]
    clip = example["video"]

    if isinstance(clip[0], bytes):
        # Frames stored as encoded image bytes (JPEG, per the original note)
        # rather than arrays are decoded one by one.
        clip = np.array([np.array(Image.open(io.BytesIO(raw))) for raw in clip])

    if clip.shape[1] > 360:
        clip = media.resize_video(clip, (360, 640))

    # (shape[2], shape[1]) multiplier applied to the last axis of the points.
    width_height = np.array(clip.shape[2:0:-1])[np.newaxis, np.newaxis, :]
    painted_frames = viz_utils.paint_point_track(
        clip,
        example["points"] * width_height,
        ~example["occluded"],
    )

    media.write_video(FLAGS.output_path, painted_frames, fps=25)
    logging.info("Examplar point visualization saved to %s", FLAGS.output_path)


# Script entry point: absl parses the flags and then calls main().
# NOTE(review): inside a notebook kernel `__name__` is also "__main__", so this
# runs when the cell is executed -- confirm that is intended.
if __name__ == "__main__":
    app.run(main)