In [1]:
import os  # noqa
import sys  # noqa

# The project root is taken to be the parent of the current working directory.
proj_root = os.path.dirname(os.getcwd())
# Make the project importable. The membership guard keeps the cell idempotent:
# re-running it no longer appends the same entry to sys.path again and again.
if proj_root not in sys.path:
    sys.path.append(proj_root)

# Object / video selection for this run.
OBJ_NAME = "mustard_bottle"
VIDEO_NAME = "mustard0"


# Per-video locations, all rooted at <proj_root>/data/inputs/<VIDEO_NAME>.
video_dir = os.path.join(proj_root, "data", "inputs", VIDEO_NAME)
# Result videos are saved directly into the video directory.
tracker_result_video = video_dir
poses_dir = os.path.join(video_dir, "annotated_poses")
video_gt_mask_dir = os.path.join(video_dir, "gt_mask")
video_mask_dir = os.path.join(video_dir, "masks")
video_rgb_dir = os.path.join(video_dir, "rgb")
video_img_dir = os.path.join(video_dir, "img")
# NOTE: despite the `_dir` suffix, the next two are paths to .npy files.
video_gt_coords_dir = os.path.join(video_dir, "gt_coords.npy")
video_gt_visibility_dir = os.path.join(video_dir, "gt_visibility.npy")
# Object assets live under <proj_root>/data/objects/<OBJ_NAME>.
obj_dir = os.path.join(proj_root, "data", "objects", OBJ_NAME)

In [None]:
import glob

from tqdm import tqdm
from posingpixels.utils.offscreen_renderer import ModelRendererOffscreen
import cv2
import numpy as np
import trimesh
import matplotlib.pyplot as plt
from posingpixels.utils.cotracker import sample_support_grid_points
from posingpixels.utils.geometry import interpolate_poses

from posingpixels.utils.meshes import get_diameter_from_mesh
from posingpixels.alignment import get_safe_query_points
from posingpixels.segmentation import segment
import torch
from cotracker.utils.visualizer import Visualizer
from posingpixels.cotracker import get_offline_cotracker_predictions
from posingpixels.cotracker import get_online_cotracker_predictions
from typing import Optional, Tuple
from posingpixels.segmentation import get_bbox_from_mask, process_image_crop


# Use the first CUDA GPU when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


class YCBinEOATDataset(torch.utils.data.Dataset):
    """File-system backed access to one video directory and one object mesh.

    The video directory is expected to contain ``rgb/*.png`` frames,
    ``annotated_poses/`` (one ``{idx:07d}.txt`` 4x4 matrix per annotated
    frame), ``gt_mask/`` and ``cam_K.txt``. Predicted masks are read from
    ``masks/``; when that folder is missing or empty they are generated
    with ``segment`` using the prompts registered for the video name.
    """

    def __init__(self, video_dir: str, object_dir: str):
        """Index the video files, make sure masks exist and set up the renderer.

        Args:
            video_dir: Directory of a single video (see class docstring).
            object_dir: Directory containing ``textured_simple.obj``.

        Raises:
            FileNotFoundError: If ``<video_dir>/rgb`` contains no ``.png`` frame.
            KeyError: If masks must be generated but no prompt is registered
                for this video name.
        """
        # Video
        self.video_dir = video_dir
        self.video_rgb_dir = os.path.join(self.video_dir, "rgb")
        self.rgb_video_files = sorted(glob.glob(f"{self.video_dir}/rgb/*.png"))
        if not self.rgb_video_files:
            # Fail early with a readable message instead of an IndexError below.
            raise FileNotFoundError(
                f"No RGB frames (*.png) found in {self.video_rgb_dir}"
            )
        self.gt_pose_dir = os.path.join(self.video_dir, "annotated_poses")
        self.gt_pose_files = sorted(glob.glob(f"{self.video_dir}/annotated_poses/*"))
        self.gt_mask_files = sorted(glob.glob(f"{self.video_dir}/gt_mask/*"))

        # 3x3 matrix read from cam_K.txt (presumably the camera intrinsics).
        self.K = np.loadtxt(os.path.join(self.video_dir, "cam_K.txt")).reshape(3, 3)
        # Frame height/width are taken from the first RGB frame.
        self.H, self.W = cv2.imread(self.rgb_video_files[0], cv2.IMREAD_COLOR).shape[:2]

        # Segmentation
        self.videoname_to_sam_prompt = {
            "mustard0": [(124, 292), (135, 304), (156, 336)]
        }
        self.masks_dir = os.path.join(self.video_dir, "masks")
        # Masks are only (re)generated when none are cached on disk.
        if not os.path.exists(self.masks_dir) or len(os.listdir(self.masks_dir)) == 0:
            if self.video_name not in self.videoname_to_sam_prompt:
                raise KeyError(
                    f"No segmentation prompt registered for video "
                    f"'{self.video_name}'; add one to videoname_to_sam_prompt "
                    f"or provide masks in {self.masks_dir}"
                )
            segment(
                self.video_rgb_dir,
                self.masks_dir,
                prompts=self.videoname_to_sam_prompt[self.video_name],
            )
        self.mask_files = sorted(glob.glob(f"{self.masks_dir}/*.png"))

        # Object
        self.object_dir = object_dir
        self.obj_path = os.path.join(self.object_dir, "textured_simple.obj")
        mesh = self.get_mesh()
        self.obj_diameter = get_diameter_from_mesh(mesh)
        self.renderer = ModelRendererOffscreen(self.K, self.H, self.W)

        # Both
        # Lookup from video name to object identifier (not used inside this class).
        self.videoname_to_object = {
            "bleach0": "021_bleach_cleanser",
            "bleach_hard_00_03_chaitanya": "021_bleach_cleanser",
            "cracker_box_reorient": "003_cracker_box",
            "cracker_box_yalehand0": "003_cracker_box",
            "mustard0": "006_mustard_bottle",
            "mustard_easy_00_02": "006_mustard_bottle",
            "sugar_box1": "004_sugar_box",
            "sugar_box_yalehand0": "004_sugar_box",
            "tomato_soup_can_yalehand0": "005_tomato_soup_can",
        }

    @property
    def video_name(self):
        """Last path component of ``video_dir``."""
        return os.path.basename(self.video_dir)

    def __len__(self):
        """Number of RGB frames found in the video directory."""
        return len(self.rgb_video_files)

    def get_mesh(self) -> trimesh.Trimesh:
        """Load the object mesh from ``obj_path`` (read from disk on every call)."""
        return trimesh.load_mesh(self.obj_path)

    @property
    def image_size(self):
        """Return ``(H, W)``, re-reading the first frame if either is unset."""
        if self.H is None or self.W is None:
            # Use the indexed full path: os.listdir() only yields bare file
            # names, which cv2.imread cannot resolve from another directory.
            self.H, self.W = cv2.imread(
                self.rgb_video_files[0], cv2.IMREAD_COLOR
            ).shape[:2]
        return self.H, self.W

    def get_gt_poses(self) -> np.ndarray:
        """Return one pose per frame, forward-filling frames without annotation.

        A frame with no pose file reuses the most recent annotated pose.
        Frames that precede the first annotated frame stay ``None``.
        """
        pose = None
        poses = []
        for i in range(len(self)):
            pose_i = self.get_gt_pose(i)
            pose = pose_i if pose_i is not None else pose
            poses.append(pose)
        return np.array(poses)

    def get_gt_pose(self, idx: int) -> Optional[np.ndarray]:
        """Return the 4x4 pose stored in ``{idx:07d}.txt``, or ``None`` if absent."""
        file = os.path.join(self.gt_pose_dir, f"{idx:07d}.txt")
        if not os.path.exists(file):
            return None
        return np.loadtxt(file).reshape(4, 4)

    def get_rgb(self, idx: int) -> np.ndarray:
        """Return frame ``idx`` converted from OpenCV's BGR order to RGB."""
        return cv2.cvtColor(
            cv2.imread(self.rgb_video_files[idx], cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB
        )

    def get_mask(self, idx: int) -> np.ndarray:
        """Return the mask from ``masks/`` for frame ``idx`` as a grayscale image."""
        return cv2.imread(self.mask_files[idx], cv2.IMREAD_GRAYSCALE)

    def get_gt_mask(self, idx: int) -> np.ndarray:
        """Return the mask from ``gt_mask/`` for frame ``idx`` as a grayscale image."""
        return cv2.imread(self.gt_mask_files[idx], cv2.IMREAD_GRAYSCALE)

    def render_mesh_at_pose(
        self, pose: Optional[np.ndarray] = None, idx: Optional[int] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Render the mesh either at an explicit pose or at the pose of frame ``idx``.

        Exactly one of ``pose`` and ``idx`` must be given. Returns whatever
        ``self.renderer.render`` returns (callers unpack it as ``rgb, depth``).
        """
        assert (pose is None) != (idx is None)
        pose = self.get_gt_pose(idx) if pose is None else pose
        return self.renderer.render(pose, self.get_mesh())

    def get_canonical_pose(self):
        """Return a fixed 4x4 pose placed one object diameter along the z-axis."""
        canonical_pose = np.eye(4)
        diameter = self.obj_diameter

        # Translate along z-axis by diameter
        canonical_pose[:3, 3] = np.array([0, 0, diameter])
        # Rotate 90 degrees around x-axis then rotate around y-axis 180 degrees
        canonical_pose[:3, :3] = np.array([[-1, 0, 0], [0, 0, -1], [0, -1, 0]])

        return canonical_pose


class CoMeshTracker:
    """Prepares a video for CoTracker and runs point tracking on it.

    The tracked sequence is the dataset video, optionally preceded by
    ``interpolation_steps`` rendered frames whose poses come from
    ``interpolate_poses`` between a base pose and the pose of the first
    dataset frame. Calling the instance runs the whole pipeline.
    """

    def __init__(
        self,
        dataset: YCBinEOATDataset,
        visible_background: bool = False,
        crop: bool = True,
        offline: bool = True,
        offline_limit: int = 500,
        support_grid: Optional[int] = None,
        interpolation_steps: int = 15,
        mask_threshold: float = 0.5,
        device: torch.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu"
        ),
    ):
        """Store the configuration and derive the initial poses.

        Args:
            dataset: Source of frames, masks, poses, mesh and camera matrix.
            visible_background: When False, pixels outside the mask are
                zeroed in the frames written for CoTracker.
            crop: When True, frames are cropped around the mask bounding box
                and query points are mapped into crop coordinates.
            offline: Selects the offline (True) or online (False) predictor.
            offline_limit: Maximum number of frames prepared in offline mode;
                also forwarded as ``limit`` to the offline predictor.
            support_grid: Grid size for the extra support points. Must be
                set exactly when ``crop`` is False.
            interpolation_steps: Number of rendered frames placed before the
                real video.
            mask_threshold: Masks are binarised with ``mask > mask_threshold``.
            device: Stored on the instance.
        """
        # Exactly one of "crop" and "support grid" may be active. The message
        # now names the options the condition actually checks.
        assert crop != (
            support_grid is not None
        ), "support_grid must be set if crop is False (and must be None if crop is True)"
        # Model Config
        self.visible_background = visible_background
        self.crop = crop
        self.offline = offline
        self.offline_limit = offline_limit
        self.support_grid = support_grid
        self.interpolation_steps = interpolation_steps
        self.mask_threshold = mask_threshold
        self.model_resolution = (384, 512)
        self.device = device
        # Dataset Config
        self.dataset = dataset
        self.K = dataset.K
        self.H, self.W = dataset.image_size
        # Number of frames to prepare. Computed only now because len(self)
        # needs self.dataset and self.interpolation_steps, and clamped so an
        # offline_limit larger than the sequence cannot index past its end.
        total_frames = len(self)
        self.limit = min(offline_limit, total_frames) if offline else total_frames
        # Initialization Config
        self.start_pose = dataset.get_gt_pose(0)
        self.base_pose = (
            dataset.get_canonical_pose()
            if self.interpolation_steps > 1
            else self.start_pose
        )
        self.query_poses = [self.base_pose]
        self.init_video_dir = os.path.join(self.dataset.video_dir, "init_video")
        self.cotracker_input_dir = os.path.join(self.dataset.video_dir, "input")
        if not self.interpolation_steps:
            # Without an init video there are no interpolated poses; keep an
            # empty stack so get_gt_poses() still works.
            self.interpolation_poses = np.empty((0, 4, 4))

    def __call__(self):
        """Run the full pipeline and return the CoTracker predictions."""
        # Create init video
        self.create_init_video()
        # Prepare img directory (input for CoTracker)
        self.prepare_img_directory()
        # Prepare query points
        self.get_query_points()
        # Run CoTracker
        return self.run_cotracker()

    def __len__(self):
        """Dataset frames plus the prepended interpolation frames."""
        return len(self.dataset) + self.interpolation_steps

    def get_rgb(self, idx: int) -> np.ndarray:
        """Return frame ``idx`` of the combined sequence in RGB channel order.

        Raises:
            FileNotFoundError: If an init-video frame has not been written yet.
        """
        if idx < self.interpolation_steps:
            frame_path = os.path.join(self.init_video_dir, f"{idx:05d}.jpg")
            bgr = cv2.imread(frame_path, cv2.IMREAD_COLOR)
            if bgr is None:
                raise FileNotFoundError(
                    f"Init video frame not found: {frame_path} "
                    "(run create_init_video() first)"
                )
            # cv2.imread decodes to BGR; convert so init frames use the same
            # channel order as dataset frames (callers convert RGB->BGR on write).
            return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        return self.dataset.get_rgb(idx - self.interpolation_steps)

    def get_mask(self, idx: int) -> np.ndarray:
        """Return the object mask for frame ``idx`` of the combined sequence.

        Init-video masks are scaled by 1/255; dataset masks are returned as
        read from ``gt_mask/``, so the two sources may use different value
        ranges.
        """
        if idx < self.interpolation_steps:
            return (
                cv2.imread(
                    os.path.join(self.init_video_dir, f"{idx:05d}.png"),
                    cv2.IMREAD_GRAYSCALE,
                )
                / 255
            )
        return self.dataset.get_gt_mask(idx - self.interpolation_steps)

    def get_gt_poses(self) -> np.ndarray:
        """Return the interpolation poses followed by the dataset poses."""
        return np.concatenate([self.interpolation_poses, self.dataset.get_gt_poses()])

    def get_gt_pose(self, idx: int) -> Optional[np.ndarray]:
        """Return the pose for frame ``idx`` of the combined sequence."""
        if idx < self.interpolation_steps:
            return self.interpolation_poses[idx]
        return self.dataset.get_gt_pose(idx - self.interpolation_steps)

    def prepare_img_directory(self):
        """Write the first ``self.limit`` frames into ``cotracker_input_dir``.

        The directory is emptied first. When cropping is enabled, the
        per-frame bounding box and scaling factor are recorded in
        ``self.bboxes`` and ``self.scaling``.
        """
        # Clear directory
        if not os.path.exists(self.cotracker_input_dir):
            os.makedirs(self.cotracker_input_dir)
        for f in os.listdir(self.cotracker_input_dir):
            os.remove(os.path.join(self.cotracker_input_dir, f))
        # Prepare images for CoTracker
        self.bboxes, self.scaling = [], []
        for i in tqdm(range(self.limit), desc="Preparing images for CoTracker"):
            rgb = self.get_rgb(i)
            mask = self.get_mask(i) > self.mask_threshold
            if not self.visible_background:
                # Black out everything outside the object mask.
                rgb[mask == 0, :] = 0
            if self.crop:
                bbox = get_bbox_from_mask(mask)
                assert bbox
                rgb, processed_bbox, scaling_factor = process_image_crop(
                    rgb,
                    bbox,
                    padding=10,
                    target_size=self.model_resolution,
                )
                self.bboxes.append(processed_bbox)
                self.scaling.append(scaling_factor)

            cv2.imwrite(
                os.path.join(self.cotracker_input_dir, f"{i:05d}.jpg"),
                cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR),
            )

    def create_init_video(self):
        """Render the interpolation frames and masks into ``init_video_dir``.

        Each pose from ``interpolate_poses`` is rendered and composited over
        the first dataset frame; a ``.jpg`` image and a ``.png`` mask
        (pixels with depth > 0) are written per pose. Does nothing when
        ``interpolation_steps`` is 0.
        """
        assert (
            len(self.query_poses) == 1
        ), "Not yet implemented for multiple query poses"
        assert self.base_pose is not None and self.start_pose is not None

        if not self.interpolation_steps:
            return None, None, None
        if not os.path.exists(self.init_video_dir):
            os.makedirs(self.init_video_dir)
        for f in os.listdir(self.init_video_dir):
            os.remove(os.path.join(self.init_video_dir, f))

        self.interpolation_poses = interpolate_poses(
            self.base_pose[:3, :3],
            self.base_pose[:3, 3],
            self.start_pose[:3, :3],
            self.start_pose[:3, 3],
            self.interpolation_steps,
        )

        base_frame = self.dataset.get_rgb(0)
        for i, P_i in enumerate(self.interpolation_poses):
            rgb, depth = self.dataset.render_mesh_at_pose(pose=P_i)
            depth_rgb = depth[:, :, None]
            # Rendered pixels (depth > 0) replace the background frame.
            rgb = base_frame * (depth_rgb <= 0) + rgb * (depth_rgb > 0)

            # Save RGB and Mask
            cv2.imwrite(
                os.path.join(self.init_video_dir, f"{i:05d}.jpg"),
                cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR),
            )
            cv2.imwrite(
                os.path.join(self.init_video_dir, f"{i:05d}.png"),
                (depth > 0).astype(np.uint8) * 255,
            )

    def get_query_points(self):
        """Compute the query points for the base pose and map them to input space.

        Sets ``unposed_3d_points``, ``query_2d_points``,
        ``object_query_points_num`` and ``input_query_points``. Must run
        after ``prepare_img_directory`` when cropping is enabled, because it
        reads ``self.bboxes`` and ``self.scaling``.
        """
        assert (
            len(self.query_poses) == 1
        ), "Not yet implemented for multiple query poses"
        assert self.base_pose is not None
        # Get query points
        self.unposed_3d_points, self.query_2d_points = get_safe_query_points(
            R=self.base_pose[:3, :3],
            T=self.base_pose[:3, 3],
            camK=self.K,
            H=self.H,
            W=self.W,
            mesh=self.dataset.get_mesh(),
            min_pixel_distance=10 if not self.interpolation_steps else 25,
            alpha_margin=5 if not self.interpolation_steps else 15,
            depth_margin=2 if not self.interpolation_steps else 6,
        )
        self.object_query_points_num = len(self.unposed_3d_points)
        # Add support grid
        if self.support_grid is not None:
            support_grid_points = sample_support_grid_points(
                self.H,
                self.W,
                self.interpolation_steps,
                self.get_mask(0),
                grid_size=self.support_grid,
            )
            self.query_2d_points = np.concatenate(
                [self.query_2d_points, support_grid_points], axis=0
            )
        # Prepare query points for CoTracker
        self.input_query_points = self.query_2d_points.copy()
        if self.crop:
            # Shift by the first frame's crop origin, then apply its scaling.
            # NOTE(review): columns 1 and 2 are presumably (x, y) after a
            # frame-index column - confirm against get_safe_query_points.
            self.input_query_points[:, 1] -= self.bboxes[0][0]
            self.input_query_points[:, 2] -= self.bboxes[0][1]
            self.input_query_points[:, 1] *= self.scaling[0][0]
            self.input_query_points[:, 2] *= self.scaling[0][1]

    def prepare_cotracker_initialization(self):
        """Placeholder; currently does nothing."""
        pass

    def run_cotracker(self):
        """Run the online or offline predictor on ``cotracker_input_dir``."""
        if not self.offline:
            return get_online_cotracker_predictions(
                self.cotracker_input_dir,
                grid_size=0,
                queries=self.input_query_points,
            )
        else:
            return get_offline_cotracker_predictions(
                self.cotracker_input_dir,
                grid_size=0,
                queries=self.input_query_points,
                limit=self.offline_limit,
            )


# Wrap the selected video/object pair, then build the tracker on top of it.
dataset = YCBinEOATDataset(video_dir=video_dir, object_dir=obj_dir)
tracker = CoMeshTracker(
    dataset=dataset,
    offline_limit=500,
    interpolation_steps=15,
)



In [None]:
# Run the whole pipeline (init video, input frames, query points, CoTracker)
# and split its result into tracks, visibility and confidence.
cotracker_outputs = tracker()
pred_tracks, pred_visibility, pred_confidence = cotracker_outputs

Preparing images for CoTracker:  79%|███████▉  | 395/500 [00:09<00:02, 38.33it/s]

In [None]:
from posingpixels.datasets import load_video_images


# Load exactly the frames the tracker was given: prepare_img_directory() writes
# them to tracker.cotracker_input_dir, and the predictions live in that
# (cropped / masked) image space. Reading from another folder could overlay the
# tracks on stale or missing frames.
frame_limit = tracker.offline_limit if tracker.offline else None
video = load_video_images(tracker.cotracker_input_dir, limit=frame_limit)

In [None]:
def visualize_results(
    video,
    pred_tracks,
    pred_visibility,
    pred_confidence,
    save_dir,
    num_of_main_queries=None,
    filename="video",
    visibility_threshold=0.6,
):
    """Draw the first ``num_of_main_queries`` tracks on ``video`` and save the result.

    A point is drawn as visible when ``pred_visibility * pred_confidence``
    exceeds ``visibility_threshold``.

    Args:
        video: Frames handed to ``Visualizer.visualize`` unchanged.
        pred_tracks: 4-D track array with the query points on axis 2.
        pred_visibility: Per-point visibility with the query points on axis 2.
        pred_confidence: Per-point confidence, multiplied element-wise with
            ``pred_visibility``.
        save_dir: Output directory passed to ``Visualizer``.
        num_of_main_queries: Number of leading query points to draw;
            ``None`` draws all of them.
        filename: Output file name passed to ``Visualizer.visualize``.
        visibility_threshold: Cut-off applied to visibility * confidence.
            Defaults to 0.6, the value that was previously hard-coded.
    """
    if num_of_main_queries is None:
        num_of_main_queries = pred_tracks.shape[2]
    # Keep only the leading (object) query points for drawing.
    shown_tracks = pred_tracks[:, :, :num_of_main_queries, :]
    shown_visibility = (pred_visibility * pred_confidence > visibility_threshold)[
        :, :, :num_of_main_queries
    ]
    vis = Visualizer(save_dir=save_dir, pad_value=0, linewidth=3)
    vis.visualize(
        video,
        shown_tracks,
        shown_visibility,
        filename=filename,
    )


# Overlay the predicted tracks of the object query points on the input video.
visualize_results(
    video=video,
    pred_tracks=pred_tracks,
    pred_visibility=pred_visibility,
    pred_confidence=pred_confidence,
    save_dir=tracker_result_video,
    num_of_main_queries=tracker.object_query_points_num,
)

Video saved to /home/joao/Documents/repositories/GSPose/data/inputs/mustard0/video.mp4


In [None]:
# Generate ground truth for the coords location (& visibility)
from typing import Tuple

from posingpixels.utils.cotracker import get_ground_truths


# Number of frames to evaluate. In offline mode this is capped by offline_limit,
# and always clamped to the sequence length so that indexing `poses` below
# cannot run past the last frame.
N = min(len(tracker), tracker.offline_limit) if tracker.offline else len(tracker)
# _, depths = tracker.dataset.renderer.render_batch(tracker.get_gt_poses(), tracker.dataset.get_mesh())

# One 2D coordinate and one visibility value per frame and object query point.
gt_coords = np.zeros((N, tracker.object_query_points_num, 2))
gt_visibility = np.zeros((N, tracker.object_query_points_num))

poses = tracker.get_gt_poses()
for i in tqdm(range(N), desc="Processing frames"):
    mask = tracker.get_mask(i)
    pose = poses[i]
    # Only the second render output (depth) is needed here.
    depth = tracker.dataset.render_mesh_at_pose(pose=pose)[1]
    gt_coords[i], gt_visibility[i] = get_ground_truths(
        pose, tracker.K, tracker.unposed_3d_points, mask, depth
    )
# Persist both arrays next to the video (the *_dir names hold .npy file paths).
np.save(video_gt_coords_dir, gt_coords)
np.save(video_gt_visibility_dir, gt_visibility)

Processing frames: 100%|██████████| 30/30 [00:20<00:00,  1.50it/s]


In [None]:
from posingpixels.utils.cotracker import scale_by_crop


# Crop boxes and scaling factors recorded by the tracker for the first N frames.
torch_bbox = torch.tensor(tracker.bboxes)[:N].to(device)
torch_scaling = torch.tensor(tracker.scaling)[:N].to(device)

# Map the ground-truth coordinates into the cropped frames and add a leading
# batch axis, as expected by visualize_results.
gt_coords_torch = scale_by_crop(
    torch.tensor(gt_coords[:N]).to(device),
    torch_bbox,
    torch_scaling,
).unsqueeze(0)
gt_visibility_torch = torch.tensor(gt_visibility).to(device).unsqueeze(0)

# Ground truth has no confidence estimate, so an all-ones tensor is used and
# the drawn visibility depends on gt_visibility alone.
visualize_results(
    video=video,
    pred_tracks=gt_coords_torch,
    pred_visibility=gt_visibility_torch,
    pred_confidence=torch.ones_like(gt_visibility_torch),
    save_dir=tracker_result_video,
    num_of_main_queries=tracker.object_query_points_num,
    filename="gt_video",
)

Video saved to /home/joao/Documents/repositories/GSPose/data/inputs/mustard0/gt_video.mp4


In [None]:
# Inspect how the confidence values evolve over time.
conf_np = pred_confidence.detach().cpu().numpy()[0]
print(conf_np.shape)

# Mean confidence per frame: the full sequence first, then its first 20 frames.
# The dashed red line marks frame tracker.interpolation_steps - 1.
for mean_conf in (conf_np.mean(axis=1), conf_np.mean(axis=1)[:20]):
    plt.plot(mean_conf)
    plt.axvline(x=tracker.interpolation_steps - 1, color="r", linestyle="--")
    plt.title("Mean confidence over time")
    plt.show()

# Distribution of confidence values at frame interpolation_steps - 1 and at
# frame interpolation_steps + 1.
for frame_idx in (tracker.interpolation_steps - 1, tracker.interpolation_steps + 1):
    plt.hist(conf_np[frame_idx])
    plt.title(f"Distribution of confidence values at frame {frame_idx}")
    plt.show()