## Semantic Correspondence with Synthetic 3D Data

In [None]:
import os
import torch
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.dataset.visualize import get_implicitron_sequence_pointcloud
from pytorch3d.implicitron.tools.point_cloud_utils import render_point_cloud_pytorch3d

# Setup dataset root and paths.
# The CO3Dv2 root defaults to the original cluster location but can be
# overridden through the CO3DV2_DATASET_ROOT environment variable, so the
# notebook can run on other machines without editing the source.
_CO3DV2_DATASET_ROOT = os.environ.get(
    "CO3DV2_DATASET_ROOT", "/export/group/datasets/co3d"
)
category = "skateboard"
dataset_root = _CO3DV2_DATASET_ROOT
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
# Side length (pixels) of the square images loaded below and of every render
# produced later in the notebook.
image_size = 256

# Initialize dataset.
# expand_args_fields prepares the Implicitron-configurable class so it can be
# instantiated directly with keyword arguments.
expand_args_fields(JsonIndexDataset)
dataset = JsonIndexDataset(
    frame_annotations_file=frame_file,
    sequence_annotations_file=sequence_file,
    dataset_root=dataset_root,
    image_height=image_size,
    image_width=image_size,
    load_point_clouds=True,
    box_crop=False,
    mask_images=False,
    load_images=True,
    load_masks=True
)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
from typing import Optional, Tuple
from pytorch3d.implicitron.tools.point_cloud_utils import _transform_points, PointsRasterizer, PointsRasterizationSettings

def get_depth_point_cloud_pytorch3d(
    camera,
    point_cloud,
    render_size: Tuple[int, int],
    point_radius: float = 0.03,
    topk: int = 10,
    eps: float = 1e-2,
    bin_size: Optional[int] = None,
    **kwargs,
) -> torch.Tensor:
    """Render the depth of the nearest point of ``point_cloud`` per pixel.

    Mirrors the rasterization set-up of PyTorch3D's
    ``render_point_cloud_pytorch3d``. The points are first moved into the
    camera's view coordinates, then rasterized with a copy of ``camera`` whose
    extrinsics are reset to the identity. The returned depth is therefore
    expressed in ``camera``'s view space.

    Args:
        camera: PyTorch3D camera (batch) to render from.
        point_cloud: PyTorch3D point cloud in world coordinates.
        render_size: Output ``(height, width)`` in pixels.
        point_radius: Rasterization radius of each point.
        topk: Maximum number of points the rasterizer keeps per pixel.
        eps: Forwarded to ``_transform_points``.
        bin_size: Rasterizer bin size. ``None`` selects 64 for renders larger
            than 1024 px and otherwise lets PyTorch3D choose.
        **kwargs: Forwarded to both ``_transform_points`` and the rasterizer.

    Returns:
        Tensor of shape ``(1, N, H, W)``, where ``N`` is the rasterized batch
        size. It holds the depth of the nearest point hitting each pixel, or
        -1 for pixels hit by no point.
    """
    # move to the camera coordinates; using identity cameras in the renderer
    point_cloud = _transform_points(camera, point_cloud, eps, **kwargs)
    camera_trivial = camera.clone()
    camera_trivial.R[:] = torch.eye(
        3, dtype=camera_trivial.R.dtype, device=camera_trivial.R.device
    )
    camera_trivial.T *= 0.0

    if bin_size is None:
        bin_size = 64 if int(max(render_size)) > 1024 else None
    rasterizer = PointsRasterizer(
        cameras=camera_trivial,
        raster_settings=PointsRasterizationSettings(
            image_size=render_size,
            radius=point_radius,
            points_per_pixel=topk,
            bin_size=bin_size,
        ),
    )

    fragments = rasterizer(point_cloud, **kwargs)
    # zbuf is (N, H, W, topk), with hits sorted by ascending z and unused
    # slots padded with -1. Slot 0 is therefore the nearest point, or -1 when
    # nothing hit the pixel. A min() over the last axis would instead return
    # -1 for every pixel hit by fewer than `topk` points.
    return fragments.zbuf[..., 0].unsqueeze(0)

In [None]:

from matplotlib import pyplot as plt
from pytorch3d.renderer.cameras import get_screen_to_ndc_transform

def _render_point_cloud_rgb(camera, point_cloud):
    """Render ``point_cloud`` from ``camera`` on a white background; returns a channels-first CPU image clamped to [0, 1]."""
    rendered, _, _ = render_point_cloud_pytorch3d(
        camera,
        point_cloud,
        render_size=(image_size, image_size),
        point_radius=2e-2,
        topk=10,
        bg_color=1.0,
        bin_size=0,
    )
    return rendered[0].clamp(0.0, 1.0).cpu()


def _find_correspondences(source_camera, target_camera, point_cloud):
    """Map each source pixel covered by ``point_cloud`` to its pixel in the target view.

    Returns ``(source_pixels, target_pixels)``: two ``(M, 2)`` long tensors of
    ``(row, col)`` indices, where row ``i`` of each forms one correspondence.
    Correspondences that land outside the target image, or behind the target
    camera, are dropped.
    NOTE: no occlusion test is performed in the target view, so points hidden
    there by other geometry are still reported.
    """
    size = (image_size, image_size)
    depth = get_depth_point_cloud_pytorch3d(
        source_camera,
        point_cloud,
        render_size=size,
        point_radius=2e-2,
        topk=10,
        bin_size=0,
    )[0].squeeze(0)  # (H, W) for a single camera; -1 where nothing was rasterized

    rows, cols = (depth > 0).nonzero(as_tuple=True)
    z = depth[rows, cols]
    # PyTorch3D's screen space puts integer coordinates on pixel corners, so
    # lift pixel centres (+0.5). A later floor() then maps back to indices
    # without being sensitive to float round-off.
    xyz_screen = torch.stack(
        (cols.to(z.dtype) + 0.5, rows.to(z.dtype) + 0.5, z), dim=1
    )  # X, Y, Z in screen space
    xyz_ndc = get_screen_to_ndc_transform(
        source_camera, image_size=size, with_xyflip=True
    ).transform_points(xyz_screen)  # X, Y, Z in NDC (Z unchanged)
    world = source_camera.unproject_points(
        xyz_ndc, world_coordinates=True, from_ndc=True
    )  # X, Y, Z in world space

    target_xy = target_camera.transform_points_screen(world, image_size=size)  # X, Y, Z in screen space
    target_cols = target_xy[:, 0].floor()
    target_rows = target_xy[:, 1].floor()
    # Depth in the target camera's view space, to reject points behind it.
    target_view_z = target_camera.get_world_to_view_transform().transform_points(world)[:, 2]
    valid = (
        (target_view_z > 0)
        & (target_cols >= 0)
        & (target_cols < image_size)
        & (target_rows >= 0)
        & (target_rows < image_size)
    )

    source_pixels = torch.stack((rows, cols), dim=1)[valid]
    target_pixels = torch.stack((target_rows, target_cols), dim=1)[valid].long()
    return source_pixels, target_pixels


def _plot_correspondences(images, source_pixels, target_pixels):
    """Scatter matching pixels in the same rainbow colour over a 2x2 grid.

    ``images`` is ``[[source, target], [source_render, target_render]]`` of
    channels-first image tensors. The left column gets ``source_pixels`` and
    the right column gets ``target_pixels`` (both ``(row, col)``).
    """
    colors = plt.get_cmap("rainbow")(
        torch.linspace(0, 1, source_pixels.shape[0]).numpy()
    )
    _, axes = plt.subplots(2, 2, figsize=(8, 8))
    for row_axes, row_images in zip(axes, images):
        for ax, image, pixels in zip(
            row_axes, row_images, (source_pixels, target_pixels)
        ):
            ax.imshow(image.cpu().permute(1, 2, 0))
            pixels = pixels.cpu()
            # scatter takes (x, y) = (col, row)
            ax.scatter(pixels[:, 1], pixels[:, 0], c=colors, s=1, alpha=0.1)
            ax.axis("off")


def get_sample(max_frames=16, sequence_index: Optional[int] = 0):
    """Plot dense pixel correspondences between two frames of a CO3D sequence.

    The sequence's point cloud is rendered to a depth map from a random source
    frame. Every covered pixel is lifted to 3D and projected into a second,
    distinct random frame. Matching pixels are drawn in the same colour over
    the real images (top row) and the point-cloud renders (bottom row).

    Args:
        max_frames: Forwarded to ``get_implicitron_sequence_pointcloud``;
            limits the frames loaded for the sequence.
        sequence_index: Index into ``dataset.seq_annots`` of the sequence to
            use. ``None`` picks one at random. Defaults to 0 (the first
            sequence).

    Raises:
        ValueError: If the loaded sequence has fewer than two frames.
    """
    # 1. Select a sequence and build its point cloud
    sequence_names = list(dataset.seq_annots.keys())
    if sequence_index is None:
        sequence_index = torch.randint(0, len(sequence_names), (1,)).item()
    sequence_name = sequence_names[sequence_index]
    point_cloud, sequence_frame_data = get_implicitron_sequence_pointcloud(
        dataset,
        sequence_name=sequence_name,
        mask_points=True,
        max_frames=max_frames,
        num_workers=1,
        load_dataset_point_cloud=True,
    )
    point_cloud = point_cloud.to(device)

    # 2. Sample two *distinct* frames (randint could return the same one twice)
    num_frames = len(sequence_frame_data.frame_number)
    if num_frames < 2:
        raise ValueError(
            f"sequence {sequence_name!r} has {num_frames} frame(s); "
            "at least 2 are needed for correspondences"
        )
    source_idx, target_idx = torch.randperm(num_frames)[:2].tolist()
    cameras = sequence_frame_data.camera.to(device)
    source_camera = cameras[source_idx]
    target_camera = cameras[target_idx]

    # 3. Render the point cloud from both viewpoints
    source_rendered_image = _render_point_cloud_rgb(source_camera, point_cloud)
    target_rendered_image = _render_point_cloud_rgb(target_camera, point_cloud)

    # 4. Transfer every covered source pixel to the target view
    source_points, target_points = _find_correspondences(
        source_camera, target_camera, point_cloud
    )

    # 5. Visualise
    _plot_correspondences(
        [
            [sequence_frame_data.image_rgb[source_idx], sequence_frame_data.image_rgb[target_idx]],
            [source_rendered_image, target_rendered_image],
        ],
        source_points,
        target_points,
    )

get_sample()