In [1]:
import sys
import os
sys.path.append("./Grounded_Segment_Anything/recognize-anything")
sys.path.append("./Grounded_Segment_Anything/GroundingDINO")
sys.path.append("./Grounded_Segment_Anything/segment_anything")

import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt
# import pykinect_azure as pykinect
import open3d as o3d

# Grounding DINO
import Grounded_Segment_Anything.GroundingDINO.groundingdino.datasets.transforms as T
from Grounded_Segment_Anything.GroundingDINO.groundingdino.models import build_model
from Grounded_Segment_Anything.GroundingDINO.groundingdino.util.slconfig import SLConfig
from Grounded_Segment_Anything.GroundingDINO.groundingdino.util.utils import (
    clean_state_dict, 
    get_phrases_from_posmap
)

# Segment Anything
from Grounded_Segment_Anything.segment_anything.segment_anything import (
    sam_model_registry,
    SamPredictor
)

from collections import defaultdict
from PIL import Image
from tqdm import tqdm

from pointcloud import PointCloud
from projections import PointProjector
from aggregator import PointCloudAggregator

# --- Model assets (relative paths, resolved against the notebook's working directory) ---
GROUNDING_DINO_CONFIG = "Grounded_Segment_Anything/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT = "Grounded_Segment_Anything/groundingdino_swint_ogc.pth"
SAM_CHECKPOINT = "Grounded_Segment_Anything/sam_vit_h_4b8939.pth"

# --- Detection thresholds ---
# A candidate box is kept when its highest per-token score exceeds BOX_THRESHOLD.
BOX_THRESHOLD = 0.3
# Tokens scoring above TEXT_THRESHOLD are passed to get_phrases_from_posmap to build the label.
TEXT_THRESHOLD = 0.25

# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Assigned to the GroundingDINO config's `bert_base_uncased_path` attribute by load_model.
BERT_BASE_UNCASED_PATH = None

# device_config = pykinect.default_configuration
# device_config.color_format = pykinect.K4A_IMAGE_FORMAT_COLOR_BGRA32
# device_config.color_resolution = pykinect.K4A_COLOR_RESOLUTION_720P
# device_config.depth_mode = pykinect.K4A_DEPTH_MODE_NFOV_2X2BINNED

# pykinect.initialize_libraries()

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
def get_rotation_translation(filepath) -> dict[str, np.ndarray]: ## filename -> transform
    """Parse a pose file into a mapping from image filename to a 4x4 pose matrix.

    The first line of the file is treated as a header and skipped. Every other
    non-blank line is split on whitespace: the first field (timestamp) is
    ignored, the last field is the frame id, and the fields in between are
    converted to floats and passed positionally to
    ``quaternion_to_rigid_transform`` (``x, y, z, qx, qy, qz, qw``).

    Args:
        filepath: Path of the pose text file.

    Returns:
        Dict keyed by ``"<frame id>.png"``. Each value is a 4x4 float array
        whose top three rows are the 3x4 rigid transform and whose last row is
        ``[0, 0, 0, 1]``.
    """
    world_transforms = {}

    with open(filepath, 'r') as file:
        next(file, None)  # skip the header row; harmless on an empty file

        for line in file:
            fields = line.split()
            if not fields:
                # Blank / whitespace-only lines (e.g. a trailing newline) carry no
                # pose; previously they crashed on the pop() below.
                continue

            data = fields[1:] # ignore timestamp
            filename = data.pop() + ".png"
            rigid_transform = quaternion_to_rigid_transform(*(float(p) for p in data))

            ## NOTE: the pose is stored as-is (E); the inverse has not been tried yet
            E = np.eye(4)
            E[:3] = rigid_transform
            world_transforms[filename] = E

    return world_transforms

def quaternion_to_rigid_transform(x, y, z, qx, qy, qz, qw) -> np.ndarray:
    """Build a 3x4 ``[R | t]`` matrix from a translation and a quaternion.

    The quaternion ``(qx, qy, qz, qw)`` is scaled to unit length before use, so
    it does not have to be normalised by the caller.

    Returns:
        A ``(3, 4)`` array holding the rotation matrix in its first three
        columns and ``(x, y, z)`` in its last column.
    """
    # Scale the quaternion to unit length.
    length = np.sqrt(qx**2 + qy**2 + qz**2 + qw**2)
    qx, qy, qz, qw = (component / length for component in (qx, qy, qz, qw))

    # Products that are shared between several matrix entries.
    xx, yy, zz = qx**2, qy**2, qz**2
    xy, xz, yz = qx*qy, qx*qz, qy*qz
    xw, yw, zw = qx*qw, qy*qw, qz*qw

    rotation = np.array([
        [1 - 2*(yy + zz), 2*(xy - zw),     2*(xz + yw)],
        [2*(xy + zw),     1 - 2*(xx + zz), 2*(yz - xw)],
        [2*(xz - yw),     2*(yz + xw),     1 - 2*(xx + yy)],
    ])

    # Append the translation as the fourth column.
    return np.column_stack((rotation, (x, y, z)))

def prepare_image(image: np.ndarray):
    """Turn an image array into the PIL image and tensor fed to GroundingDINO.

    Returns:
        ``(image_pil, image_tensor)``: the PIL image created from ``image`` and
        the tensor produced by the resize -> to-tensor -> normalise pipeline.
    """
    image_pil = Image.fromarray(image)

    resize = T.RandomResize([800], max_size=1333)
    to_tensor = T.ToTensor()
    normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    pipeline = T.Compose([resize, to_tensor, normalize])

    # The pipeline is called with an (image, target) pair; no target is supplied here.
    image_tensor, _unused_target = pipeline(image_pil, None)  # 3, h, w
    return image_pil, image_tensor

def load_model(model_config_path, model_checkpoint_path, bert_base_uncased_path, device):
    """Build a GroundingDINO model from its config file and load checkpoint weights.

    Args:
        model_config_path: Config file read with ``SLConfig.fromfile``.
        model_checkpoint_path: Checkpoint whose ``"model"`` entry holds the weights.
        bert_base_uncased_path: Stored on the config as ``bert_base_uncased_path``.
        device: Stored on the config as ``device``. The model itself is not
            moved to this device here.

    Returns:
        The model in eval mode.
    """
    config = SLConfig.fromfile(model_config_path)
    config.device = device
    config.bert_base_uncased_path = bert_base_uncased_path
    network = build_model(config)

    # NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
    state = torch.load(model_checkpoint_path, map_location="cpu")
    weights = clean_state_dict(state["model"])
    # strict=False tolerates key mismatches; the report is printed so they stay visible.
    print(network.load_state_dict(weights, strict=False))

    network.eval()
    return network

def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
    """Run the grounding model on one image and return the boxes that pass the threshold.

    Args:
        model: Callable model exposing ``.to(device)`` and ``.tokenizer``.
        image: Image tensor; a leading batch axis is added before the forward pass.
        caption: Text prompt. It is lower-cased, stripped and given a trailing ``.``.
        box_threshold: A box is kept when its highest score exceeds this value.
        text_threshold: Per-token cut-off handed to ``get_phrases_from_posmap``.
        with_logits: When true, ``"(<score>)"`` is appended to each phrase, where
            the score is the first four characters of the box's highest score.
        device: Device the model and image are moved to.

    Returns:
        ``(boxes_filt, pred_phrases)``: the kept boxes and one phrase per kept box.
    """
    caption = caption.lower().strip()
    caption = caption if caption.endswith(".") else caption + "."

    model = model.to(device)
    image = image.to(device)

    with torch.no_grad():
        prediction = model(image[None], captions=[caption])
        scores = prediction["pred_logits"].cpu().sigmoid()[0]
        candidate_boxes = prediction["pred_boxes"].cpu()[0]

    # Keep only the queries whose best token score clears the box threshold.
    keep = scores.max(dim=1)[0] > box_threshold
    scores_kept = scores[keep]
    boxes_filt = candidate_boxes[keep]

    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)

    pred_phrases = []
    for score_row in scores_kept:
        phrase = get_phrases_from_posmap(score_row > text_threshold, tokenized, tokenizer)
        if with_logits:
            phrase += f"({str(score_row.max().item())[:4]})"
        pred_phrases.append(phrase)

    return boxes_filt, pred_phrases

# Instantiate both networks once; the cells below reuse these two objects.
gd_model = load_model(
    GROUNDING_DINO_CONFIG,
    GROUNDING_DINO_CHECKPOINT,
    BERT_BASE_UNCASED_PATH,
    device=DEVICE,
)
# SAM "vit_h" network, moved to DEVICE and wrapped in a SamPredictor.
sam_model = SamPredictor(
    sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT).to(DEVICE)
)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


final text_encoder_type: bert-base-uncased


  checkpoint = torch.load(model_checkpoint_path, map_location="cpu")


_IncompatibleKeys(missing_keys=[], unexpected_keys=['label_enc.weight', 'bert.embeddings.position_ids'])


  state_dict = torch.load(f)


In [None]:
# to-do: 
# 1) better merging algorithm in aggregator.py
# 2) paint the pointcloud with provided method
# 3) project the pointcloud onto the thingy

## DIRECTORY must contain rgb/ and depth/ sub-directories plus a poses-id.txt file
## (poses-id.txt is the pose file that is actually read below).
DIRECTORY = "/home/bwilab/Documents/RTAB-Map/red_chair"
OBJECT_LABELS = "couch."
MAX_FRAMES = 50  # only the first MAX_FRAMES frames (in numeric order) are processed

all_masks = defaultdict(list) # Object label --> list[masks]
pointclouds = defaultdict(list) # Object label --> list[PointCloud]
projector = PointProjector()
aggregator = PointCloudAggregator(eps=2.0) ## higher eps == more merging, lower eps == more detail (or noise)


def _frame_index(path):
    """Return the integer frame number encoded in an image path such as ``.../12.png``."""
    return int(os.path.splitext(os.path.basename(path))[0])


world_transforms = get_rotation_translation(f"{DIRECTORY}/poses-id.txt")

## os.listdir order is arbitrary, so both lists are sorted by frame number to keep them aligned
rgb_images = sorted(
    (f"{DIRECTORY}/rgb/{file}" for file in os.listdir(f"{DIRECTORY}/rgb")),
    key=_frame_index,
)
depth_images = sorted(
    (f"{DIRECTORY}/depth/{file}" for file in os.listdir(f"{DIRECTORY}/depth")),
    key=_frame_index,
)
assert len(rgb_images) == len(depth_images) == len(world_transforms), (
    "rgb/, depth/ and poses-id.txt must describe the same number of frames"
)

rgb_images = rgb_images[:MAX_FRAMES]
depth_images = depth_images[:MAX_FRAMES]

## Per frame: detect boxes (GroundingDINO) -> segment them (SAM) -> build a pointcloud from the
## masked depth/colour images -> hand it to the aggregator together with the frame's pose.
for rgb_path, depth_path in tqdm(zip(rgb_images, depth_images), total=len(rgb_images), leave=False):
    frame_name = os.path.basename(rgb_path)
    ## the rgb and depth lists are built independently; make sure they still line up
    assert frame_name == os.path.basename(depth_path)

    scene = defaultdict(list) # Object label --> list[masks] (this frame only)
    transform = world_transforms[frame_name]

    with Image.open(rgb_path) as rgb_file, Image.open(depth_path) as depth_file:
        color_image, depth_image = np.array(rgb_file), np.array(depth_file)

    ## Make sure images are same dims. cv2.resize takes (width, height), hence the reversed shape
    ## (assumes depth_image is 2-D, TODO confirm).
    resized_color_image = cv2.resize(color_image, depth_image.shape[::-1])

    ## Feed through DINO
    image_pil, image_tensor = prepare_image(resized_color_image)
    boxes, pred_phrases = get_grounding_output(
        gd_model, image_tensor, OBJECT_LABELS, BOX_THRESHOLD, TEXT_THRESHOLD, device=DEVICE
    )

    if torch.numel(boxes) == 0: # nothing found in frame
        continue

    ## Prepare SAM
    sam_model.set_image(resized_color_image)
    W, H = image_pil.size
    ## The arithmetic below treats boxes as normalised (cx, cy, w, h) and rewrites them as
    ## pixel-space (x0, y0, x1, y1), for all boxes at once.
    boxes = boxes * torch.Tensor([W, H, W, H])
    boxes[:, :2] -= boxes[:, 2:] / 2
    boxes[:, 2:] += boxes[:, :2]

    ## SAM outputs
    transformed_boxes = sam_model.transform.apply_boxes_torch(boxes, resized_color_image.shape[:2]).to(DEVICE)
    masks, _, _ = sam_model.predict_torch(
        point_coords=None, point_labels=None, boxes=transformed_boxes, multimask_output=False
    )

    ## Associate outputs
        ## How to use
        # color_pixels = color_image * mask[:, :, None]
        # depth_pixels = depth_image * mask
    for mask, label in zip(masks, pred_phrases):
        mask = mask[0].cpu().numpy()
        label = label.split('(', 1)[0] # remove confidence; tolerates a phrase without one
        scene[label].append(mask)
        all_masks[label].append(mask)

    ## Generating pointclouds
    for label, label_masks in scene.items():
        for mask in label_masks:
            masked_depth_image = depth_image * mask
            masked_color_image = resized_color_image * mask[:, :, None]
            # plt.imshow(masked_color_image)
            # plt.show()
            # plt.imshow(masked_depth_image)
            # plt.show()

            pcl = projector.get_pointcloud(masked_depth_image, masked_color_image, stride=3)
            if pcl.is_empty():
                continue

            pcl.label = label
            # aggregator._unmerged_pointclouds[label] += [pcl]
            # {'chair': [[pcl1, pcl2], [pcl3]]}
            # {'chair': [pcl1_done, pcl2_done]}

            ## TESTING 11/22
            # aggregator.add_unmerged_pointcloud(pcl)

            #### if inverse extrinsic works use this
            #### and make a more robust icp algorithm (which pointcloud to merge onto?)
            target = aggregator.nearest_pointcloud(pcl, transform)
            target = aggregator.aggregate_pointcloud(pcl, target, transform, verbose=False)
            # target.clean()

            pointclouds[label].append(pcl)
            # projector.visualize(pcl)

## TESTING 11/22: merge all instances
# aggregator.gather_pointclouds()
# aggregator.aggregate_all()
projector.visualize(aggregator.main)

# allpoints = []

# for label in pointclouds:
#     for pcl in pointclouds[label]:
#         allpoints.append(pcl)

# projector.visualize(allpoints)

## Show masked images
# pcl_mask = scene['box'][0]
# masked_color = resized_color_image * pcl_mask[:, :, None]
# masked_depths = depth_image * pcl_mask
# fig, axes = plt.subplots(1, 3, figsize=(10,5))
# axes[0].axis('off'); axes[1].axis('off'); axes[2].axis('off')
# axes[0].imshow(masked_color); axes[1].imshow(masked_depths); axes[2].imshow(resized_color_image)
# plt.show()

## Show pointclouds
# projector.visualize(pointclouds['chair'][0])


0it [00:00, ?it/s]

1it [00:01,  1.24s/it]

2.2637760665113136e-05 3.724806104251828e-05


3it [00:02,  1.07it/s]

0.0002098660761104951 3.724806104251828e-05
swapped the pointclouds


7it [00:06,  1.34it/s]

0.0001649648477238602 9.970353169507759e-05
swapped the pointclouds


17it [00:12,  1.49it/s]

0.0001282240589524004 1.5225534825047905e-05
swapped the pointclouds


19it [00:14,  1.25it/s]

0.00010015335158024644 0.0
swapped the pointclouds


20it [00:15,  1.13it/s]

0.00011998280436912186 5.057382383019412e-07
swapped the pointclouds


24it [00:18,  1.12it/s]

0.0030290675656010355 3.42577678982847e-05
swapped the pointclouds
1.9497563040734045e-05 0.0008126510710826996


26it [00:19,  1.28it/s]

0.00016270010273645495 0.0008126510710826996


27it [00:21,  1.15it/s]

3.298754401536634e-05 3.8759304498684465e-05


28it [00:22,  1.09it/s]

0.00011040364101399879 0.0008126510710826996


30it [00:22,  1.41it/s]

7.545266508852438e-05 0.0008126510710826996


31it [00:23,  1.39it/s]

0.0014450484477943255 0.0008126510710826996
swapped the pointclouds


32it [00:24,  1.37it/s]

0.0 3.8759304498684465e-05


35it [00:26,  1.48it/s]

0.0001271979864221163 0.0003391435824565908


39it [00:29,  1.34it/s]

0.00010102136950248828 2.3050508834580266e-05
swapped the pointclouds


44it [00:33,  1.28it/s]

0.00033640921394500036 0.00010041161152241175
swapped the pointclouds


45it [00:34,  1.29it/s]

8.264670058622297e-05 0.00015566652197113032


46it [00:35,  1.23it/s]

4.124779884776593e-07 3.065029903663102e-05
5.6412753664846536e-05 0.00015566652197113032


47it [00:36,  1.20it/s]

0.0007771603172920297 0.00015566652197113032
swapped the pointclouds
0.13993850107682357 0.0
swapped the pointclouds


49it [00:37,  1.36it/s]

0.0002126761705503517 0.0003391435824565908


                       

0.0002169891702494571 0.0003391435824565908




In [None]:
# print('objects', aggregator._scene.keys())
# projector.visualize(aggregator._scene['couch'])

# res = PointCloud()
# res.label = "chair"
# for pcl in allpoints:
#     res += pcl
# Re-sort the frame paths in place, keyed by the integer left after dropping the last
# four characters of each file name, then print the whole list.
rgb_images.sort(key=lambda frame_path: int(os.path.basename(frame_path)[:-len(".png")]))
print(rgb_images)

# projector.visualize(aggregator._unmerged_pointclouds['couch'])
# aggregator.gather_pointclouds()
# aggregator.aggregate_all()
# projector.visualize(aggregator._scene['couch'])

['/home/bwilab/Documents/RTAB-Map/red_chair/rgb/8.png', '/home/bwilab/Documents/RTAB-Map/red_chair/rgb/12.png', '/home/bwilab/Documents/RTAB-Map/red_chair/rgb/13.png', '/home/bwilab/Documents/RTAB-Map/red_chair/rgb/18.png', '/home/bwilab/Documents/RTAB-Map/red_chair/rgb/19.png', '/home/bwilab/Documents/RTAB-Map/red_chair/rgb/20.png', '/home/bwilab/Documents/RTAB-Map/red_chair/rgb/24.png', '/home/bwilab/Documents/RTAB-Map/red_chair/rgb/26.png', '/home/bwilab/Documents/RTAB-Map/red_chair/rgb/30.png', '/home/bwilab/Documents/RTAB-Map/red_chair/rgb/34.png', '/home/bwilab/Documents/RTAB-Map/red_chair/rgb/40.png', '/home/bwilab/Documents/RTAB-Map/red_chair/rgb/41.png', '/home/bwilab/Documents/RTAB-Map/red_chair/rgb/43.png', '/home/bwilab/Documents/RTAB-Map/red_chair/rgb/45.png', '/home/bwilab/Documents/RTAB-Map/red_chair/rgb/46.png', '/home/bwilab/Documents/RTAB-Map/red_chair/rgb/50.png', '/home/bwilab/Documents/RTAB-Map/red_chair/rgb/55.png', '/home/bwilab/Documents/RTAB-Map/red_chair/rgb/5