In this notebook, we explain how to use XMem, SAM, and DINOv2 through our wrappers. The wrappers are designed to make it easy to perform downstream robotics tasks.

In [None]:
import sys
import os

import h5py
import cv2
import os
import h5py
import json
import numpy as np
import matplotlib.pyplot as plt

# Root of the repository checkout; the paths used in this notebook are built from it.
project_repo_folder = "."

# Put the vendored XMem checkout and its model/util/inference sub-folders on
# sys.path (presumably XMem imports them as top-level modules -- confirm),
# followed by the repository root itself. The order of the entries is preserved.
for _relative_dir in (
    "third_party/XMem",
    "third_party/XMem/model",
    "third_party/XMem/util",
    "third_party/XMem/inference",
    "",
):
    sys.path.append(f"{project_repo_folder}/{_relative_dir}")
from PIL import Image
from pathlib import Path
from groot_imitation.groot_algo import GROOT_ROOT_PATH
from groot_imitation.groot_algo.xmem_tracker import XMemTracker
from groot_imitation.groot_algo.misc_utils import get_annotation_path, get_first_frame_annotation, overlay_xmem_mask_on_image, depth_to_rgb, resize_image_to_same_shape, plotly_draw_seg_image, rotate_camera_pose
from groot_imitation.groot_algo.misc_utils import overlay_xmem_mask_on_image, add_palette_on_mask, VideoWriter, get_transformed_depth_img
from groot_imitation.groot_algo.o3d_modules import O3DPointCloud, convert_convention


from IPython.display import Video


### Example of XMem VOS
In this part, you will learn how to:
1. use the `xmem_tracker` wrapper, which makes it easy to process video streams in robotics domains
2. render videos with the VOS masks

In [None]:
annotation_folder = f"{project_repo_folder}/datasets/annotations/example_demo"
demo_file_name = f"{project_repo_folder}/datasets/example_demo.hdf5"

# Load the annotated first frame; its annotation is what the tracker is seeded with below.
first_frame, first_frame_annotation = get_first_frame_annotation(annotation_folder)

# ************************ Most important part *******************************
xmem_tracker = XMemTracker(
    xmem_checkpoint=f'{project_repo_folder}/third_party/xmem_checkpoints/XMem.pth',
    device='cuda:0',
)
xmem_tracker.clear_memory()
# **************************************************************************

# Read every agent-view RGB frame of the first demo into memory.
with h5py.File(demo_file_name, "r") as f:
    images = f["data/demo_0/obs"]["agentview_rgb"][:]

# Resize each frame to the annotation's resolution. cv2.resize takes the
# target size as (width, height), i.e. (shape[1], shape[0]).
annotation_height, annotation_width = first_frame_annotation.shape[:2]
resized_images = [
    cv2.resize(frame, (annotation_width, annotation_height), interpolation=cv2.INTER_AREA)
    for frame in images
]

masks = xmem_tracker.track_video(resized_images, first_frame_annotation)

# Store the stacked masks next to the annotation under data/agentview_masks.
mask_file = os.path.join(annotation_folder, "video_masks.hdf5")
with h5py.File(mask_file, "w") as f:
    data_group = f.create_group("data")
    data_group.create_dataset("agentview_masks", data=np.stack(masks, axis=0))

# First video: the palette-colored masks on their own.
with VideoWriter(video_path=annotation_folder, video_name="mask_only_video.mp4", fps=20, save_video=True) as video_writer:
    for mask, image in zip(masks, resized_images):
        video_writer.append_image(np.array(add_palette_on_mask(mask).convert("RGB")))

# Second video: the masks overlaid on the resized frames.
with VideoWriter(video_path=annotation_folder, video_name="overlay_video.mp4", fps=20, save_video=True) as video_writer:
    for mask, image in zip(masks, resized_images):
        new_mask_img = overlay_xmem_mask_on_image(image, mask, use_white_bg=True)
        video_writer.append_image(np.array(new_mask_img))

# Uncomment to play the overlay video inline:
# Video(os.path.join(annotation_folder, "overlay_video.mp4"), embed=True, width=500, height=500)

### Example of RGB-D reconstruction
In this part, you will learn how to:
1. read in an RGB-D image from an example dataset
2. reconstruct the point cloud of the image
3. load in the segmentation mask and reconstruct object-centric 3D point clouds

In [None]:
example_demo_file = f"{project_repo_folder}/datasets/example_demo.hdf5"

print("Reading the example demo file...")
idx = 0  # index of the frame to load from the demo
# Open the file read-only explicitly, matching how the same file is opened in
# the XMem cell. Without a mode, older h5py versions fall back to a writable
# default that can modify the dataset or create a missing file.
with h5py.File(example_demo_file, "r") as f:
    obs = f["data/demo_0/obs"]
    # The convert_convention function makes sure that the arrays are saved as contiguous arrays, which is important for rendering to work properly.
    rgb_image = convert_convention(obs["agentview_rgb"][idx])
    depth_image = convert_convention(obs["agentview_depth"][idx])
    mask_image = obs["agentview_masks"][idx]
    camera_extrinsics = obs["agentview_extrinsics"][idx]
    # The intrinsics are stored as a JSON string attribute keyed by camera name.
    camera_intrinsics = json.loads(f["data"].attrs["camera_intrinsics"])["agentview"]
    # Show which attributes and observation keys the file provides.
    print(f["data"].attrs.keys())
    print(obs.keys())
# Resize the mask against the RGB frame so the two can be indexed together
# (presumably to the RGB frame's height/width, per the helper's name).
mask_image = resize_image_to_same_shape(mask_image, rgb_image)

Visualize what they look like:

In [None]:
# Colorize the depth map so it can be shown next to the RGB frame.
# You can visualize the depth with either the jet, magma, or viridis color map.
depth_image_in_rgb = depth_to_rgb(depth_image, colormap="jet")
# Join the two images along axis 1 and display them as a single picture.
rgb_and_depth = np.concatenate([rgb_image, depth_image_in_rgb], axis=1)
plt.imshow(rgb_and_depth)


Visualize their 3D point clouds:

In [None]:
# Render the depth-only 3D point cloud.
import plotly
import plotly.graph_objs as go

# Build the cloud from the depth image and the camera intrinsics, then apply
# the camera extrinsics (presumably moving it out of the camera frame -- confirm).
depth_pc = O3DPointCloud()
depth_pc.create_from_depth(depth_image, camera_intrinsics)
depth_pc.transform(camera_extrinsics)

# Split the points into their three coordinate columns.
point_cloud = depth_pc.get_points()
x_vals, y_vals, z_vals = (point_cloud[:, axis] for axis in range(3))

# One marker per point, colored by its z coordinate.
scatter = go.Scatter3d(
    x=x_vals,
    y=y_vals,
    z=z_vals,
    mode='markers',
    marker={"size": 3, "color": z_vals, "colorscale": 'Viridis', "opacity": 0.8},
)

# Remove the figure margins.
layout = go.Layout(margin={"l": 0, "r": 0, "b": 0, "t": 0})

fig = go.Figure(data=[scatter], layout=layout)
fig.show()

In [None]:
# Same reconstruction as the depth-only cloud, but each point also carries a color.
rgbd_pc = O3DPointCloud()
rgbd_pc.create_from_rgbd(rgb_image, depth_image, camera_intrinsics)
rgbd_pc.transform(camera_extrinsics)

point_cloud = rgbd_pc.get_points()
colors_rgb = rgbd_pc.get_colors()

# Plotly takes per-point colors as 'rgb(r,g,b)' strings.
# NOTE(review): Plotly rgb() strings normally use 0-255 channels; confirm the
# value range returned by get_colors().
color_str = [f"rgb({r},{g},{b})" for r, g, b in colors_rgb]

# Split the points into their three coordinate columns.
x_vals, y_vals, z_vals = (point_cloud[:, axis] for axis in range(3))

rgbd_scatter = go.Scatter3d(
    x=x_vals,
    y=y_vals,
    z=z_vals,
    mode='markers',
    marker={"size": 3, "color": color_str, "opacity": 0.8},
)

# Remove the figure margins.
layout = go.Layout(margin={"l": 0, "r": 0, "b": 0, "t": 0})

fig = go.Figure(data=[rgbd_scatter], layout=layout)
fig.show()

In [None]:
from plotly.subplots import make_subplots

# Build one colored point cloud per segmented object and draw them all in a
# single 3D figure.
object_centric_scatters = []

# Object ids start at 1 and run up to AND INCLUDING the largest label in the
# mask. range() excludes its stop value, so the stop must be max + 1; without
# it the object with the highest id is never reconstructed. int() keeps the
# "+ 1" from wrapping around when the mask has a small unsigned dtype.
largest_mask_id = int(mask_image.max())
for mask_idx in range(1, largest_mask_id + 1):
    # Keep the depth only where this object is; every other pixel is set to -1
    # (presumably treated as invalid depth by the point-cloud builder -- confirm).
    masked_depth_image = depth_image.copy()
    masked_depth_image[mask_image != mask_idx] = -1

    object_pcd = O3DPointCloud()
    object_pcd.create_from_rgbd(rgb_image, masked_depth_image, camera_intrinsics)
    object_pcd.transform(camera_extrinsics)
    # object_pcd.preprocess()

    point_cloud = object_pcd.get_points()
    colors_rgb = object_pcd.get_colors()

    x_vals = point_cloud[:, 0]
    y_vals = point_cloud[:, 1]
    z_vals = point_cloud[:, 2]

    # Convert RGB colors to a format recognizable by Plotly
    color_str = ['rgb('+str(r)+','+str(g)+','+str(b)+')' for r,g,b in colors_rgb]
    scatter = go.Scatter3d(
        x=x_vals,
        y=y_vals,
        z=z_vals,
        mode='markers',
        marker=dict(size=3, color=color_str, opacity=0.8)
    )
    object_centric_scatters.append(scatter)


# Set the layout for the plot
layout = go.Layout(
    margin=dict(l=0, r=0, b=0, t=0)
)

fig = go.Figure(data=object_centric_scatters, layout=layout)

# Show the figure
fig.show()

#### Camera perspective augmentation

In [None]:
# Derive a new camera pose from the original extrinsics with the given angle and point.
new_camera_extrinsics = rotate_camera_pose(camera_extrinsics, angle=-60, point=[0.6, 0, 0])

# Compute the depth image of the existing RGB-D point cloud for the new pose.
scene_points = rgbd_pc.get_points()
intrinsics_matrix = np.array(camera_intrinsics)
new_depth_img, z_max = get_transformed_depth_img(
    point_cloud=scene_points,
    camera_intrinsics=intrinsics_matrix,
    new_camera_extrinsics=new_camera_extrinsics,
    camera_width=224,
    camera_height=224,
)

new_depth_in_rgb = depth_to_rgb(new_depth_img, colormap="jet")
plt.imshow(new_depth_in_rgb)

### Example of Segment Correspondence Model
In this part, you will learn how to:
1. get a segmentation of the image using SAM
2. load a reference image and compute DINOv2 features using our DINOv2 wrapper
3. interactively visualize the DINOv2 cost volume (may require jupyter-dash)
4. get the correspondence

In [None]:
# Wrappers around SAM (segmentation) and DINOv2 (image features) used in the
# remaining cells of this section.
from groot_imitation.groot_algo.sam_operator import SAMOperator
from groot_imitation.groot_algo.dino_features import DinoV2ImageProcessor, compute_affinity, rescale_feature_map, generate_video_from_affinity

# Create both wrappers once; the cells below reuse these two objects.
dinov2 = DinoV2ImageProcessor()
sam_operator = SAMOperator()
# Print the SAM configuration, then initialize the operator before its first use
# (presumably init() is where the model is loaded -- confirm).
sam_operator.print_config()
sam_operator.init()

In [None]:
import torch
from functools import partial
# Optional alternative, left disabled: run the segmentation call under
# half-precision autocast with gradients turned off.
# autocast_dtype = torch.half
# autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype)
# with autocast_ctx():
#     with torch.no_grad():
# Segment the RGB frame with SAM. The result is a dict; the plotting cell that
# follows reads its "overall_mask" and "merged_mask" entries.
mask_result_dict = sam_operator.segment_image(rgb_image)

In [None]:
# Visualize the SAM masks using plotly.
import plotly
import plotly.express as px

mask_ids = mask_result_dict["merged_mask"]
# Scale the overall mask by 255 and cast to uint8 so it can be shown as an image.
overall_mask = (mask_result_dict["overall_mask"] * 255).astype(np.uint8)

# Draw the mask image on a plotly canvas.
fig = px.imshow(overall_mask)

# Attach the per-pixel mask ids to the image trace so they appear on hover.
mask_trace = fig.data[0]
mask_trace.customdata = mask_ids
mask_trace.hovertemplate = 'x: %{x}<br>y: %{y}<br>Mask ID: %{customdata}'

# Hide grid lines, zero lines and tick labels on both axes.
hidden_axis = {"showgrid": False, "zeroline": False, "showticklabels": False}
fig.update_layout(
    xaxis=dict(hidden_axis),
    yaxis=dict(hidden_axis),
    showlegend=False,
    width=300,   # you can adjust this as needed
    height=300,   # you can adjust this as needed
    margin={"l": 0, "r": 0, "b": 0, "t": 0},
)

fig.show()


In [None]:


# Load the image of the new object instance and resize it against the demo's RGB frame.
new_instance_path = os.path.join(GROOT_ROOT_PATH, "../", "datasets", "example_new_object.jpg")
new_instance_image = np.array(Image.open(new_instance_path))
new_instance_image = resize_image_to_same_shape(new_instance_image, rgb_image)

# Compute DINOv2 features for the demo frame first, then for the new instance.
img_list = [rgb_image, new_instance_image]
feature_list = [dinov2.process_image(img) for img in img_list]

# Generate the affinity video between the two images from their features.
saved_video_file = generate_video_from_affinity(
    *img_list,
    *feature_list,
    h=32,
    w=32,
    patch_size=14,
)
# Uncomment to display the video inline:
# Video(saved_video_file, embed=True, width=500, height=500)


### SCM example

In [None]:
from groot_imitation.segmentation_correspondence_model.scm import SegmentationCorrespondenceModel

# The SCM reuses the DINOv2 and SAM wrappers created earlier in the notebook.
scm_module = SegmentationCorrespondenceModel(dinov2=dinov2, sam_operator=sam_operator)

# Run the SCM on the new image together with the annotated demo frame and its
# mask, then resize the resulting mask against the new image.
new_annotation_mask = resize_image_to_same_shape(
    scm_module(new_instance_image, rgb_image, mask_image),
    new_instance_image,
)
print(new_instance_image.shape, new_annotation_mask.shape)

# Overlay the new mask on the new image and draw the result.
new_instance_overlay_image = overlay_xmem_mask_on_image(
    new_instance_image,
    new_annotation_mask,
    use_white_bg=True,
)
plotly_draw_seg_image(new_instance_overlay_image, new_annotation_mask)
