In [None]:
import numpy as np
from PIL import Image

from semantic_grasping_datagen.eval.generate_taskgrasp_eval import TaskGraspScanLibrary, img_to_pc

from scorer import load_scorer
from utils import backproject

In [None]:
# Directory holding the TaskGrasp scans, given relative to the notebook's
# working directory. It is handed to TaskGraspScanLibrary in the next cell.
SCANS_DIR = "../data/taskgrasp/scans"

In [None]:
# Scan library rooted at SCANS_DIR; later cells pull individual samples from
# it with `dataset.get(...)`.
dataset = TaskGraspScanLibrary(SCANS_DIR)


In [None]:
# Load the grasp scorer at checkpoint 12500 with its weights mapped onto the GPU.
# NOTE(review): the first argument is presumably an experiment/run ID understood
# by `scorer.load_scorer` — confirm against that module.
grasp_scorer = load_scorer("01JQW5E267HTPPQFPJ2V23W8S3", ckpt=12500, map_location="cuda")


In [None]:
from transformers import pipeline, DepthProImageProcessorFast, DepthProForDepthEstimation
import torch

# Both models are only created when their name is not yet defined in the
# notebook's globals, so re-running this cell does not reload them.

# Depth Anything V2 (small) wrapped in a Hugging Face depth-estimation pipeline.
if "depth_anything" not in globals():
    depth_anything = pipeline(task="depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")

# DepthPro image processor + model; the model is moved to CUDA here.
# The guard only checks `depth_pro_model`, so the processor is (re)created
# together with the model.
if "depth_pro_model" not in globals():
    depth_pro_processor = DepthProImageProcessorFast.from_pretrained("apple/DepthPro-hf")
    depth_pro_model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf").cuda()

def complete_depth(rgb: Image.Image, depth: np.ndarray) -> np.ndarray:
    """Densify a sparse depth map using the ``depth_anything`` prediction.

    The relative depth predicted from ``rgb`` is aligned to the valid sensor
    measurements with a least-squares affine fit
    (``depth ~= scale * relative + shift``); the fitted mapping is then applied
    to every pixel of the relative prediction.

    Args:
        rgb: Image passed to the ``depth_anything`` pipeline.
        depth: Sensor depth map. Entries ``<= 0`` are treated as missing and
            are excluded from the fit. Must have the same shape as the
            pipeline's ``"depth"`` output.

    Returns:
        Dense depth map with the shape of the relative prediction, expressed
        in the units of ``depth``.

    Raises:
        ValueError: If the relative prediction and ``depth`` differ in shape,
            or if ``depth`` has fewer than two valid (``> 0``) pixels, in which
            case the affine fit would be degenerate.
    """
    depth_relative = np.array(depth_anything(rgb)["depth"])
    if depth_relative.shape != depth.shape:
        raise ValueError(
            f"relative depth shape {depth_relative.shape} does not match "
            f"sensor depth shape {depth.shape}"
        )

    mask = depth > 0
    masked_metric_depth = depth[mask].astype(np.float32)
    masked_relative_depth = depth_relative[mask].astype(np.float32)
    # Two unknowns (scale, shift) need at least two observations; with none,
    # lstsq would quietly hand back a meaningless fit instead of failing.
    if masked_metric_depth.size < 2:
        raise ValueError(
            "need at least two valid (> 0) depth pixels to fit scale and shift, "
            f"got {masked_metric_depth.size}"
        )

    # Find relation between relative and metric depth: solve
    # [relative, 1] @ [scale, shift]^T ~= metric in the least-squares sense.
    # NOTE(review): the fit is linear in the pipeline's output. If that output
    # is inverse depth (disparity-like) rather than depth, fitting against
    # 1 / depth may align better — confirm before changing.
    A = np.vstack([masked_relative_depth, np.ones(len(masked_relative_depth))]).T
    scale, shift = np.linalg.lstsq(A, masked_metric_depth, rcond=None)[0]

    depth_pred = scale * depth_relative + shift
    return depth_pred

def complete_depth_dp(rgb: Image.Image, depth: np.ndarray) -> np.ndarray:
    """Densify a sparse depth map using the DepthPro prediction.

    DepthPro is run on ``rgb``, its output is resized to ``depth.shape`` by the
    processor's post-processing step, and the prediction is then aligned to
    the valid sensor measurements with a least-squares affine fit
    (``depth ~= scale * predicted + shift``) that is applied to every pixel.
    The fitted ``scale`` and ``shift`` are printed.

    Args:
        rgb: Image passed to ``depth_pro_processor``.
        depth: Sensor depth map; its shape is used as the target size of the
            prediction. Entries ``<= 0`` are treated as missing and are
            excluded from the fit.

    Returns:
        Dense depth map of shape ``depth.shape``, expressed in the units of
        ``depth``.

    Raises:
        ValueError: If ``depth`` has fewer than two valid (``> 0``) pixels, in
            which case the affine fit would be degenerate.
    """
    # Put the inputs on whatever device the model lives on rather than
    # assuming CUDA, so the function keeps working if the model is moved.
    inputs = depth_pro_processor(rgb, return_tensors="pt").to(depth_pro_model.device)
    with torch.no_grad():
        outputs = depth_pro_model(**inputs)
    processed_outputs = depth_pro_processor.post_process_depth_estimation(outputs, target_sizes=[depth.shape])
    depth_pred = processed_outputs[0]["predicted_depth"].cpu().numpy()

    valid = depth > 0
    depth_pred_masked = depth_pred[valid]
    depth_masked = depth[valid]
    # Two unknowns (scale, shift) need at least two observations; with none,
    # lstsq would quietly hand back a meaningless fit instead of failing.
    if depth_masked.size < 2:
        raise ValueError(
            "need at least two valid (> 0) depth pixels to fit scale and shift, "
            f"got {depth_masked.size}"
        )

    # Solve [predicted, 1] @ [scale, shift]^T ~= sensor depth (least squares).
    A = np.vstack([depth_pred_masked, np.ones(len(depth_pred_masked))]).T
    scale, shift = np.linalg.lstsq(A, depth_masked, rcond=None)[0]

    depth_pred = scale * depth_pred + shift
    print(scale, shift)
    return depth_pred


In [None]:
# Fetch one sample from the scan library and unpack the fields used below.
data = dataset.get("236_mug", 2)

rgb = data["rgb"]
# Work on a copy: the clipping below writes into `depth`, and doing that on
# the array held in `data` would silently alter what the dataset returned.
depth = data["depth"].copy()
pc = data["fused_pc"]
grasps = data["registered_grasps"]
cam_K = data["cam_params"]

# PIL's Image.show() hands the image to an external viewer.
rgb.show()

# Depth readings beyond this value are discarded (same units as `depth`).
FAR_CLIP = 0.8

# Mark far readings as missing (0) so they are ignored when the predicted
# depth is aligned to the sensor depth.
depth[depth > FAR_CLIP] = 0
completed_depth = complete_depth_dp(rgb, depth)

# Back-project the completed depth with the camera parameters.
# NOTE(review): presumably yields per-pixel XYZ in the camera frame — confirm
# against utils.backproject.
xyz = backproject(cam_K, completed_depth)

import matplotlib
import matplotlib.pyplot as plt

# Missing sensor pixels are stored as 0. Turn them into NaN so they are
# (a) ignored when computing the colour range and (b) drawn in the colormap's
# "bad" colour instead of being shown as depth 0.
depth_viz = depth.copy()
depth_viz[depth == 0] = np.nan
print(np.nanmin(depth_viz), np.nanmax(depth_viz))

cmap = matplotlib.colormaps.get_cmap("viridis")
cmap.set_bad(color="black")
# Shared colour range for the first two panels, taken from valid pixels only.
# Using the raw `depth` here would pin vmin to 0 (the missing-pixel marker)
# and compress the useful part of the colour scale.
vmin = np.minimum(np.nanmin(depth_viz), np.nanmin(completed_depth))
vmax = np.maximum(np.nanmax(depth_viz), np.nanmax(completed_depth))

plt.figure(figsize=(15, 5))

# Panel 1: sensor depth (missing pixels in black).
plt.subplot(1, 3, 1)
plt.imshow(depth_viz, cmap=cmap, vmin=vmin, vmax=vmax)
plt.axis("off")
plt.title("Original Depth")

# Panel 2: completed depth on the same colour scale.
plt.subplot(1, 3, 2)
plt.imshow(completed_depth, cmap=cmap, vmin=vmin, vmax=vmax)
plt.axis("off")
plt.title("Completed Depth")

# Panel 3: signed error of the completed depth where sensor depth exists
# (NaN elsewhere, rendered black).
plt.subplot(1, 3, 3)
cmap2 = matplotlib.colormaps.get_cmap("coolwarm")
cmap2.set_bad(color="black")
plt.imshow(completed_depth - depth_viz, cmap=cmap2)
plt.colorbar()
plt.title("Difference (m)")
plt.axis("off")

# Language description of the desired grasp, passed to the scorer.
text = (
    "The grasp is on the mug. "
    "The grasp is on the handle of the mug. "
    "It is oriented vertically, with the fingers grasping the sides of the body from above."
)
# Alternative prompts:
# text = "The grasp is on the pan. The grasp is on the handle of the pan. It is oriented vertically, with the fingers grasping the sides of the handle."
# text = "The grasp is on the pan. The grasp is on the rim of the pan. It is oriented vertically, with the fingers grasping the inside and outside of the pan."

# Homogeneous transform that negates rows 1 and 2 (the y and z rows) and
# leaves x and the homogeneous row untouched; applied on the left of every grasp pose.
# NOTE(review): presumably a camera-convention change expected by the scorer — confirm.
trf = np.eye(4)
trf[1:3] *= -1
trf_grasps = trf[np.newaxis] @ grasps

# Score every transformed grasp against the prompt; the first argument is an identity pose.
pred = grasp_scorer.score_grasps(np.eye(4), rgb, xyz, trf_grasps, text)
print(pred)


In [None]:
import trimesh
from acronym_tools import create_gripper_marker

scene = trimesh.Scene()

# Sensor point cloud. Pixels beyond FAR_CLIP were set to 0 earlier, so a bare
# `depth < FAR_CLIP` mask would also select those missing pixels; require
# `depth > 0` as well so they are not turned into points.
# NOTE(review): if img_to_pc already drops zero-depth pixels this is a no-op.
scene_pc = img_to_pc(np.asarray(rgb), depth, cam_K, (depth > 0) & (depth < FAR_CLIP))
# scene_pc[:, 3:] = [[255, 0, 0]]

# Point cloud from the completed depth, limited to the same far distance.
completed_pc = img_to_pc(np.asarray(rgb), completed_depth, cam_K, completed_depth < FAR_CLIP)

# Columns 0-2 are used as positions and the remaining columns as uint8 colours.
scene_pc_obj = trimesh.PointCloud(scene_pc[:,:3], scene_pc[:,3:].astype(np.uint8))
completed_pc_obj = trimesh.PointCloud(completed_pc[:,:3], completed_pc[:,3:].astype(np.uint8))

# Only the sensor cloud is shown; enable the second line to overlay the completed cloud.
scene.add_geometry(scene_pc_obj)
# scene.add_geometry(completed_pc_obj)

# Draw one gripper marker per grasp, coloured on a red-to-green ramp by score:
# scores at or below 0.6 are pure red, scores at or above 0.875 pure green.
for score, grasp_pose in zip(pred, grasps):
    green = np.interp(score, [0.6, 0.875], [0, 255]).round().astype(np.uint8)
    marker: trimesh.Trimesh = create_gripper_marker([255 - green, green, 0])
    marker.apply_transform(grasp_pose)
    scene.add_geometry(marker)

# Alternative colouring: highlight only the top-scoring grasp in green.
# best_idx = np.argmax(pred)
# for i, grasp in enumerate(grasps):
#     color = [0,255,0] if i == best_idx else [255,0,0]
#     gripper: trimesh.Trimesh = create_gripper_marker(color)
#     gripper.apply_transform(grasp)
#     scene.add_geometry(gripper)

# Coordinate-frame marker; built but only shown if the line below is enabled.
axes = trimesh.creation.axis(origin_size=0.025)
# scene.add_geometry(axes)

# Bare attribute access, kept from the original cell.
# NOTE(review): presumably touches the scene's lights before rendering — confirm it is needed.
scene.lights

scene.show(height=720, line_settings={"point_size": 100})