In [1]:
import os
import json
import numpy as np
import sys
import os

import cv2

root_path = '/home/bb8/workspace/gr1/GR1_vision'

sys.path.insert(0, os.path.join(root_path, '../retarget_vision'))
sys.path.insert(0, root_path)

In [2]:
sam_checkpoint_path = os.path.join(root_path, '../retarget_vision/third_party/sam_checkpoints/sam_vit_h_4b8939.pth' )

In [3]:
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from torchvision.ops import box_convert

# Grounding DINO
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from groundingdino.util.inference import annotate, load_image, predict

import supervision as sv

# segment anything
from segment_anything import build_sam, SamPredictor 
import cv2
import numpy as np
import matplotlib.pyplot as plt

# diffusers
import PIL
import requests
import torch
from io import BytesIO
from diffusers import StableDiffusionInpaintPipeline

from huggingface_hub import hf_hub_download

import groundingdino.datasets.transforms as T

from orion.utils.misc_utils import get_palette

class GroundedSamWrapper:
    def load_model_hf(self, repo_id, filename, ckpt_config_filename, device='cuda'):
        cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)

        args = SLConfig.fromfile(cache_config_file) 
        model = build_model(args)
        args.device = device

        cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
        checkpoint = torch.load(cache_file, map_location='cuda')
        log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
        print("Model loaded from {} \n => {}".format(cache_file, log))
        _ = model.eval()
        return model   
    def __init__(self):
        # Use this command for evaluate the Grounding DINO model
        # Or you can download the model by yourself
        ckpt_repo_id = "ShilongLiu/GroundingDINO"
        ckpt_filenmae = "groundingdino_swinb_cogcoor.pth"
        ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"

        self.groundingdino_model = self.load_model_hf(ckpt_repo_id, ckpt_filenmae, ckpt_config_filename)

        device = 'cuda'
        sam_checkpoint = sam_checkpoint_path
        sam = build_sam(checkpoint=sam_checkpoint)
        sam.to(device=device)
        self.sam_predictor = SamPredictor(sam)

        from diffusers import StableDiffusionInpaintPipeline

        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float16,
        )

        pipe = pipe.to("cuda")

        import io

    def transform(self, image):
        transform = T.Compose(
            [
                T.RandomResize([800], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        image_source = Image.fromarray(image).convert("RGB")
        image_transformed, _ = transform(image_source, None)
        return image, image_transformed

    def combine_masks(self, masks):
        final_mask = np.zeros((masks.shape[2], masks.shape[3]), dtype=np.uint8)
        masks = masks.cpu().detach().numpy()
        final_mask = final_mask + masks[0][0]

        for i in range(1, masks.shape[0]):
            final_mask += (masks[i][0] * (i + 1)).astype(np.uint8)
        
        #for visualizing the mask
        #final_mask = final_mask * (255 // np.amax(final_mask))

        return final_mask

    def segment(self, image_np, prompts, box_threshold=0.3, text_threshold=0.25, filter_threshold=200):
        image_source, image = self.transform(image_np)

        prompt_text = ""
        for prompt in prompts:
            prompt_text += (prompt + ".")

        boxes, logits, phrases = predict(
            model=self.groundingdino_model, 
            image=image, 
            caption=prompt_text, 
            box_threshold=box_threshold, 
            text_threshold=text_threshold
        )

        if boxes.shape[0] == 0:
            print("no boxes found!")
            return np.array([])

           
        annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
        annotated_frame = annotated_frame[...,::-1] # BGR to RGB

        img_src = Image.fromarray(image_source)
        # img_src.save("img_src.png")

        img_annotated = Image.fromarray(annotated_frame)
        # img_annotated.save("annotated.png")

        # set image
        self.sam_predictor.set_image(image_source)

        # box: normalized box xywh -> unnormalized xyxy
        H, W, _ = image_source.shape
        boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])

        device = "cuda"
        transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(boxes_xyxy, image_source.shape[:2]).to(device)
        masks, _, _ = self.sam_predictor.predict_torch(
                    point_coords = None,
                    point_labels = None,
                    boxes = transformed_boxes,
                    multimask_output = False,
                )

        intermediate_final_mask = self.combine_masks(masks)

        filter_indices = []
        for i in range(1, intermediate_final_mask.max() + 1):
            if np.sum(intermediate_final_mask == i) < filter_threshold:
                filter_indices.append(i)
        
        final_mask = np.zeros_like(intermediate_final_mask)
        count = 0
        for i in range(1, intermediate_final_mask.max() + 1):
            if i not in filter_indices:
                final_mask[intermediate_final_mask == i] = count + 1
                count += 1

        mask_image_pil = Image.fromarray(final_mask) # .convert("RGBA")
        mask_image_pil.putpalette(get_palette())
        # mask_image_pil.save("final_mask.png")
        return mask_image_pil
    
def overlay_xmem_mask_on_image(rgb_img, mask, use_white_bg=False, rgb_alpha=0.7):
    """

    Args:
        rgb_img (np.ndarray):rgb images
        mask (np.ndarray)): binary mask
        use_white_bg (bool, optional): Use white backgrounds to visualize overlap. Note that we assume mask ids 0 as the backgrounds. Otherwise the visualization might be screws up. . Defaults to False.

    Returns:
        np.ndarray: overlay image of rgb_img and mask
    """
    colored_mask = Image.fromarray(mask)
    colored_mask.putpalette(get_palette())
    colored_mask = np.array(colored_mask.convert("RGB"))
    if use_white_bg:
        colored_mask[mask == 0] = [255, 255, 255]
    overlay_img = cv2.addWeighted(rgb_img, rgb_alpha, colored_mask, 1-rgb_alpha, 0)

    return overlay_img

  warn(
    PyTorch 2.0.0+cu118 with CUDA 1108 (you have 2.3.0+cu121)
    Python  3.9.16 (you have 3.9.19)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


In [4]:
wrapper = GroundedSamWrapper()



final text_encoder_type: bert-base-uncased




Model loaded from /home/bb8/.cache/huggingface/hub/models--ShilongLiu--GroundingDINO/snapshots/a94c9b567a2a374598f05c584e96798a170c56fb/groundingdino_swinb_cogcoor.pth 
 => _IncompatibleKeys(missing_keys=[], unexpected_keys=['label_enc.weight', 'bert.embeddings.position_ids'])


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

In [5]:
from deoxys_vision.utils.camera_utils import assert_camera_ref_convention, get_camera_info
from deoxys_vision.networking.camera_redis_interface import CameraRedisSubInterface
from deoxys_vision.utils.o3d_utils import scene_pcd_fn, O3DPointCloud, estimate_rotation
from deoxys_vision.utils.plotly_utils import plotly_draw_3d_pcd

from retarget.retargeter import SMPLGR1Retargeter
from retarget.utils.configs import load_config
from robot.gr1 import GR1URDFModel

from utils.grasp_offset_utils import obtain_object_target_in_head_frame

assert_camera_ref_convention('rs_0')
camera_info = get_camera_info('rs_0')
cr_interface = CameraRedisSubInterface(redis_host="localhost", camera_info=camera_info, use_depth=True)
cr_interface.start()

gr1 = GR1URDFModel()

intrinsics = np.array([
            [909.83630371,   0.        , 651.97015381],
            [  0.        , 909.12280273, 376.37097168],
            [  0.        ,   0.        ,   1.        ],
        ])

extrinsics = np.array([
    [ 2.22044605e-16,  2.07353665e-16,  1.00000000e+00 , 7.74200000e-02],
    [-1.00000000e+00 , 6.93889390e-18 , 2.22044605e-16 , 3.25000000e-02],
    [ 6.93889390e-18 ,-1.00000000e+00 , 2.31531374e-16 ,-2.13700000e-02],
    [ 0.00000000e+00 , 0.00000000e+00 , 0.00000000e+00 , 1.00000000e+00]
])

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [15]:
with open(os.path.join(root_path, 'tmp_results/status.json'), 'r') as f:
    status = json.load(f)

object_names = [status['object_name']]
# object_names = ['plate']
print("object_names: ", object_names)

# get image
imgs = cr_interface.get_img()
img_info = cr_interface.get_img_info()
rgb_image = cv2.cvtColor(imgs['color'], cv2.COLOR_BGR2RGB)
depth_image = imgs['depth']

# get mask
final_mask_image = wrapper.segment(rgb_image, object_names)
binary_mask = np.array(final_mask_image) > 0
depth_image = depth_image * binary_mask

rgbd_pc = O3DPointCloud(max_points=50000)
rgbd_pc.create_from_rgbd(rgb_image, depth_image, intrinsics, depth_trunc=1.0)
rgbd_pc.transform(extrinsics)
rgbd_pc.preprocess()

pcd_points, pcd_colors = rgbd_pc.get_points(), rgbd_pc.get_colors()

rgbd_cp = O3DPointCloud(max_points=50000)
rgbd_cp.create_from_rgbd(rgb_image, depth_image, intrinsics, depth_trunc=1.0)
rgbd_cp.transform(extrinsics)
rgbd_cp.preprocess()
object_center = obtain_object_target_in_head_frame(rgbd_cp, status['direction'], np.array(status['T_world_head']))

print("T_world_head: ", status['T_world_head'])

with open(os.path.join(root_path, 'tmp_results/status.json'), 'w') as f:
    status['object_target_in_head'] = object_center.tolist()
    status['ready'] = 1
    json.dump(status, f)

target_pos_in_head = object_center + status['wrist_offset_in_head']
plotly_draw_3d_pcd(pcd_points, 
                   pcd_colors,
                   addition_points=np.array([
                       [object_center],
                       [target_pos_in_head]
                   ]),
                   marker_size=10)

object_names:  ['blue plate']



annotate is deprecated: `BoxAnnotator` is deprecated and will be removed in `supervision-0.22.0`. Use `BoundingBoxAnnotator` and `LabelAnnotator` instead



object_center_in_world= [ 0.37381331 -0.0013841  -0.5208413 ]
max_world_height= -0.5060636251559821 min_world_height= -0.5300893331376609 diff= 0.0240257079816788
T_world_head:  [[0.7658634352324868, 0.001545957953009239, 0.6430014063654155, 0.0], [0.004175205052142207, 0.9999640717685971, -0.00737718338801283, 0.0], [-0.6429897092774246, 0.008334577732274967, 0.765829464422307, 0.0], [0.0, 0.0, 0.0, 1.0]]


In [7]:
pcd_points, pcd_colors = rgbd_cp.get_points(), rgbd_cp.get_colors()
plotly_draw_3d_pcd(pcd_points, 
                   pcd_colors,
                   marker_size=10)

In [8]:
# get transformation between world frame and head frame, using chest frame and camera frame as intermediates
T_head_cam = extrinsics.copy()
T_head_chest = gr1.get_joint_pose_relative_to_head('base', np.array(status['q']))
R_head_chest = T_head_chest.copy()
R_head_chest[:3, 3] = 0
T_chest_cam = np.linalg.inv(R_head_chest) @ T_head_cam

rgbd_ori = O3DPointCloud(max_points=50000)
rgbd_ori.create_from_rgbd(rgb_image, depth_image, intrinsics, depth_trunc=1.0)
rgbd_ori.transform(T_chest_cam)
plane_model = rgbd_ori.plane_estimation()["plane_model"]
R_world_chest = estimate_rotation(plane_model, z_up=True)

rgbd_ori.transform(R_world_chest)
pcd_points, pcd_colors = rgbd_ori.get_points(), rgbd_ori.get_colors()
plotly_draw_3d_pcd(pcd_points, 
                   pcd_colors,
                   marker_size=10)

KeyError: 'q'