## Semantic Correspondence with Synthetic 3D Data

In [None]:
import os
import torch
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.dataset.visualize import get_implicitron_sequence_pointcloud
from pytorch3d.implicitron.tools.point_cloud_utils import render_point_cloud_pytorch3d

# --- Dataset location and basic settings ------------------------------------
# Root directory of the CO3Dv2 data and the object category used below.
_CO3DV2_DATASET_ROOT = "/export/group/datasets/co3d"
dataset_root = _CO3DV2_DATASET_ROOT
category = "skateboard"
# One value is used for image height, image width and the render size.
image_size = 256

# Both annotation files are looked up under <dataset_root>/<category>/.
category_dir = os.path.join(dataset_root, category)
frame_file = os.path.join(category_dir, "frame_annotations.jgz")
sequence_file = os.path.join(category_dir, "sequence_annotations.jgz")

# --- Dataset ----------------------------------------------------------------
# expand_args_fields is applied to the class before it is instantiated.
expand_args_fields(JsonIndexDataset)
dataset = JsonIndexDataset(
    dataset_root=dataset_root,
    frame_annotations_file=frame_file,
    sequence_annotations_file=sequence_file,
    image_height=image_size,
    image_width=image_size,
    box_crop=False,
    mask_images=False,
    load_images=True,
    load_masks=True,
    load_point_clouds=True,
)

# --- Point cloud and frame data for one sequence ----------------------------
# Pick the first sequence listed in the dataset's sequence annotations.
sequence_name = list(dataset.seq_annots)[0]
point_cloud, sequence_frame_data = get_implicitron_sequence_pointcloud(
    dataset,
    sequence_name=sequence_name,
    max_frames=16,
    mask_points=True,
    load_dataset_point_cloud=True,
    num_workers=1,
)

# --- Rendering (simplified for illustration) --------------------------------
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

point_cloud = point_cloud.to(device)
# Only the first camera of the loaded frames is used as the render viewpoint.
camera = sequence_frame_data.camera.to(device)[0]

# The renderer returns three values; only the first (the image) is kept.
# NOTE(review): bin_size=0 presumably selects the naive (non-binned)
# rasterizer -- confirm against the PyTorch3D documentation.
image_render, _, _ = render_point_cloud_pytorch3d(
    camera,
    point_cloud,
    render_size=(image_size, image_size),
    point_radius=2e-2,
    topk=10,
    bin_size=0,
    bg_color=0.0,
)

# Clamp to [0, 1] and move to the CPU so the result can be plotted.
image_render = torch.clamp(image_render, min=0.0, max=1.0).cpu()

In [None]:
from matplotlib import pyplot as plt
# Show the RGB images of the loaded frames on a 4x4 grid.
fig, ax = plt.subplots(4, 4, figsize=(16, 16))

# Hide the axes of every panel up front so panels that receive no image
# stay blank instead of showing empty tick marks.
for panel in ax.flat:
    panel.axis("off")

# Pair panels with the frames that were actually loaded rather than indexing
# a fixed range(16): `max_frames=16` is only a cap on the number of frames,
# so a shorter sequence would otherwise fail with an IndexError.
# `ax.flat` walks the grid row by row, matching the original
# ax[i // 4, i % 4] placement.
for panel, frame in zip(ax.flat, sequence_frame_data.image_rgb):
    # assumes each frame is channel-first (C, H, W); imshow expects (H, W, C)
    panel.imshow(frame.permute(1, 2, 0))

In [None]:
from matplotlib import pyplot as plt
# Show the depth masks of the loaded frames on a 4x4 grid.
fig, ax = plt.subplots(4, 4, figsize=(16, 16))

# Hide the axes of every panel up front so panels that receive no mask
# stay blank instead of showing empty tick marks.
for panel in ax.flat:
    panel.axis("off")

# Pair panels with the masks that were actually loaded rather than indexing
# a fixed range(16): `max_frames=16` is only a cap on the number of frames,
# so a shorter sequence would otherwise fail with an IndexError.
# `ax.flat` walks the grid row by row, matching the original
# ax[i // 4, i % 4] placement.
for panel, mask in zip(ax.flat, sequence_frame_data.depth_mask):
    # assumes each mask is channel-first (1, H, W) -- TODO confirm
    panel.imshow(mask.permute(1, 2, 0))

In [None]:
from matplotlib import pyplot as plt
# Display the first rendered image of the point cloud.
# The dimensions are reordered with permute(1, 2, 0) before plotting
# (assumes channel-first input, i.e. (C, H, W) -> (H, W, C)).
rendered_view = image_render[0].permute(1, 2, 0)
plt.imshow(rendered_view)
plt.show()

In [None]:
# Zhang et al.:
# 1) double flip: flipped source image and flipped target image
# 2) single flip: flipped source image and original target image
# 3) self flip: source image and flipped source image.
# For settings 2 and 3, the keypoint annotations are flipped accordingly.

# Take an image from the sequence