In [None]:
import sys
import os
sys.path.append("./Grounded_Segment_Anything/recognize-anything")
sys.path.append("./Grounded_Segment_Anything/GroundingDINO")
sys.path.append("./Grounded_Segment_Anything/segment_anything")

import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt
# import pykinect_azure as pykinect
import open3d as o3d

# Grounding DINO
import Grounded_Segment_Anything.GroundingDINO.groundingdino.datasets.transforms as T
from Grounded_Segment_Anything.GroundingDINO.groundingdino.models import build_model
from Grounded_Segment_Anything.GroundingDINO.groundingdino.util.slconfig import SLConfig
from Grounded_Segment_Anything.GroundingDINO.groundingdino.util.utils import (
    clean_state_dict, 
    get_phrases_from_posmap
)

# Segment Anything
from Grounded_Segment_Anything.segment_anything.segment_anything import (
    sam_model_registry,
    SamPredictor
)

from collections import defaultdict
from PIL import Image

from pointcloud import PointCloud
from projections import PointProjector
from aggregator import PointCloudAggregator

# Model files, given relative to the notebook's working directory.
GROUNDING_DINO_CONFIG = "Grounded_Segment_Anything/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT = "Grounded_Segment_Anything/groundingdino_swint_ogc.pth"
SAM_CHECKPOINT = "Grounded_Segment_Anything/sam_vit_h_4b8939.pth"
# A detection is kept only when its highest sigmoid logit exceeds BOX_THRESHOLD
# (see get_grounding_output).
BOX_THRESHOLD = 0.3
# Logits above TEXT_THRESHOLD form the position map handed to
# get_phrases_from_posmap when building a detection's phrase.
TEXT_THRESHOLD = 0.25
# Use the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Forwarded to the GroundingDINO config as `bert_base_uncased_path` by load_model.
BERT_BASE_UNCASED_PATH = None

# device_config = pykinect.default_configuration
# device_config.color_format = pykinect.K4A_IMAGE_FORMAT_COLOR_BGRA32
# device_config.color_resolution = pykinect.K4A_COLOR_RESOLUTION_720P
# device_config.depth_mode = pykinect.K4A_DEPTH_MODE_NFOV_2X2BINNED

# pykinect.initialize_libraries()

In [None]:
def get_rotation_translation(filepath) -> dict[str, np.ndarray]: ## filename -> transform
    """Parse a pose file into a mapping from image filename to a 4x4 transform.

    The first line of the file is treated as a header and skipped. Every other
    non-blank line is split on whitespace: the first field (the timestamp) is
    ignored, the last field is the frame name (".png" is appended to build the
    dictionary key) and the fields in between are converted to float and passed
    to quaternion_to_rigid_transform as (x, y, z, qx, qy, qz, qw).

    Args:
        filepath: Path of the pose text file.

    Returns:
        Dict mapping "<frame name>.png" to a 4x4 array whose first three rows
        are the 3x4 matrix returned by quaternion_to_rigid_transform and whose
        last row is [0, 0, 0, 1].

    Raises:
        ValueError: If a pose field cannot be converted to float.
        TypeError: If a line does not carry exactly seven pose fields.
    """
    world_transforms = dict()

    with open(filepath, 'r') as file:
        next(file, None)  # skip the header line (no-op for an empty file)
        for line in file:
            fields = line.split()
            if not fields:
                # Blank / whitespace-only lines (e.g. a trailing newline) carry
                # no pose; without this guard data.pop() raised IndexError.
                continue

            data = fields[1:] # ignore timestamp
            filename = data.pop() + ".png"
            rigid_transform = quaternion_to_rigid_transform(*(float(p) for p in data))

            ## first do E, then try E inverse
            # The pose is stored as parsed (not inverted), padded to homogeneous form.
            E = np.eye(4)
            E[:3] = rigid_transform
            world_transforms[filename] = E

    return world_transforms

def quaternion_to_rigid_transform(x, y, z, qx, qy, qz, qw) -> np.ndarray:
    """Build a 3x4 matrix [R | t] from a translation and a quaternion.

    The quaternion (qx, qy, qz, qw) is divided by its Euclidean norm before the
    rotation matrix is computed, so it does not need to be unit length. A
    zero-norm quaternion is not handled specially.

    Args:
        x, y, z: Translation components, written to the last column.
        qx, qy, qz, qw: Quaternion components.

    Returns:
        A (3, 4) float array: rotation in columns 0-2, translation in column 3.
    """
    # Scale the quaternion to unit length.
    length = np.sqrt(qx**2 + qy**2 + qz**2 + qw**2)
    a, b, c, w = qx / length, qy / length, qz / length, qw / length

    # Squares of the vector part, reused on the diagonal.
    aa, bb, cc = a**2, b**2, c**2

    rotation = np.array([
        [1 - 2*(bb + cc), 2*(a*b - c*w),   2*(a*c + b*w)],
        [2*(a*b + c*w),   1 - 2*(aa + cc), 2*(b*c - a*w)],
        [2*(a*c - b*w),   2*(b*c + a*w),   1 - 2*(aa + bb)]
    ])

    # Assemble [R | t] in a float matrix.
    pose = np.zeros((3, 4))
    pose[:, :3] = rotation
    pose[:, 3] = np.array([x, y, z])
    return pose

def prepare_image(image: np.ndarray):
    """Turn an image array into the (PIL image, tensor) pair used for detection.

    The array is wrapped in a PIL image and run through a resize / to-tensor /
    normalize pipeline built from the GroundingDINO transforms module `T`.

    Args:
        image: Image array accepted by PIL's Image.fromarray.

    Returns:
        (image_pil, image_tensor): the PIL image before any transform, and the
        transformed tensor.
    """
    # Per-channel mean / std used for normalisation
    # (presumably the usual ImageNet statistics - TODO confirm).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    pipeline = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ])

    image_pil = Image.fromarray(image)
    # The transforms take (image, target); there is no target here.
    image_tensor, _ = pipeline(image_pil, None)  # 3, h, w
    return image_pil, image_tensor

def load_model(model_config_path, model_checkpoint_path, bert_base_uncased_path, device):
    """Build the GroundingDINO model from a config file and load its weights.

    Args:
        model_config_path: Path of the config file read by SLConfig.fromfile.
        model_checkpoint_path: Path of the checkpoint; its "model" entry is
            loaded into the network.
        bert_base_uncased_path: Stored on the config as `bert_base_uncased_path`.
        device: Stored on the config as `device`. The model itself is not moved
            here; get_grounding_output calls `.to(device)` on it.

    Returns:
        The model in eval mode.
    """
    config = SLConfig.fromfile(model_config_path)
    config.device = device
    config.bert_base_uncased_path = bert_base_uncased_path

    model = build_model(config)

    # The checkpoint is always deserialised onto the CPU.
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    state_dict = clean_state_dict(checkpoint["model"])

    # strict=False tolerates missing / unexpected keys; the printed result
    # lists them so mismatches are visible.
    print(model.load_state_dict(state_dict, strict=False))

    model.eval()
    return model

def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
    """Run the grounding model on one image and return the kept boxes and phrases.

    Args:
        model: Callable model exposing `.to(device)` and `.tokenizer`.
        image: Image tensor; a leading batch axis is added before the forward pass.
        caption: Text prompt. It is lower-cased, stripped, and given a trailing
            "." if it lacks one.
        box_threshold: A prediction is kept when its maximum sigmoid logit
            exceeds this value.
        text_threshold: Logits above this value form the position map passed to
            get_phrases_from_posmap.
        with_logits: When true, "(<score>)" is appended to each phrase, where
            the score is the first four characters of the maximum logit.
        device: Device the model and image are moved to.

    Returns:
        (boxes_filt, pred_phrases): the boxes that passed the threshold (on the
        CPU) and one phrase string per kept box.
    """
    prompt = caption.lower().strip()
    if not prompt.endswith("."):
        prompt = prompt + "."

    model = model.to(device)
    image = image.to(device)

    with torch.no_grad():
        outputs = model(image[None], captions=[prompt])
        logits = outputs["pred_logits"].cpu().sigmoid()[0]
        boxes = outputs["pred_boxes"].cpu()[0]

    # Keep only the predictions whose best score clears the box threshold.
    keep = logits.max(dim=1)[0] > box_threshold
    logits_filt, boxes_filt = logits[keep], boxes[keep]

    tokenized = model.tokenizer(prompt)
    pred_phrases = []
    for logit in logits_filt:
        phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, model.tokenizer)
        if with_logits:
            phrase = phrase + f"({str(logit.max().item())[:4]})"
        else:
            phrase = phrase + ""
        pred_phrases.append(phrase)

    return boxes_filt, pred_phrases

# Load both models once, when this cell runs.
# GroundingDINO detector built from the config / checkpoint constants above.
gd_model = load_model(GROUNDING_DINO_CONFIG, GROUNDING_DINO_CHECKPOINT, BERT_BASE_UNCASED_PATH, device=DEVICE)
# SAM "vit_h" model loaded from SAM_CHECKPOINT, moved to DEVICE and wrapped in a SamPredictor.
sam_model = SamPredictor(sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT).to(DEVICE))

In [None]:
## Dataset root: must contain /rgb and /depth directories and a /poses.txt file
## NOTE(review): machine-specific absolute path - adjust before running elsewhere.
DIRECTORY = "/home/bwilab/Semantic-Mapping-BWI/red_chair"
## Text prompt handed to the detector
OBJECT_LABELS = "couch."

pointclouds = defaultdict(list) # Object label --> list[PointCloud]
projector = PointProjector()
## higher eps == more merging, lower eps == more detail (or noise)
aggregator = PointCloudAggregator(eps=0.05)

## Every per-mask pointcloud produced by the frame loop is collected here
all_pointclouds = []

world_transforms = get_rotation_translation(f"{DIRECTORY}/poses.txt")

## Frame files are named "<number>.<3-char extension>"; order them numerically
file_basenames = sorted(
    (os.path.basename(entry) for entry in os.listdir(f"{DIRECTORY}/rgb")),
    key=lambda name: int(name[:-4]),
)
rgb_images = [f"{DIRECTORY}/rgb/{name}" for name in file_basenames]
depth_images = [f"{DIRECTORY}/depth/{name}" for name in file_basenames]

## One pose per frame, and one depth path per rgb path
assert len(rgb_images) == len(depth_images) == len(world_transforms)

## Only the first 15 frames are processed
rgb_images, depth_images = rgb_images[:15], depth_images[:15]

## Per-frame pipeline: detect boxes (GroundingDINO) -> segment them (SAM) ->
## back-project each mask to a pointcloud -> move it with the frame's pose ->
## hand it to the aggregator and record it in all_pointclouds.
for i, (rgb_path, depth_path) in enumerate(zip(rgb_images, depth_images)):
    print(f"Processing frame {i+1}/{len(rgb_images)}")
    scene = defaultdict(list) # Object label --> list[masks]
    transform = world_transforms[os.path.basename(rgb_path)]

    ## Make sure frames match
    assert os.path.basename(rgb_path) == os.path.basename(depth_path)

    with Image.open(rgb_path) as color_image, Image.open(depth_path) as depth_image:
        ## Make sure images are same dims
        color_image, depth_image = np.array(color_image), np.array(depth_image)
        # cv2.resize takes (width, height), i.e. the reversed array shape;
        # this assumes a 2-D depth image.
        resized_color_image = cv2.resize(color_image, depth_image.shape[::-1])

    ## Feed through DINO
    image_pil, image_tensor = prepare_image(resized_color_image)
    boxes, pred_phrases = get_grounding_output(
        gd_model, image_tensor, OBJECT_LABELS, BOX_THRESHOLD, TEXT_THRESHOLD, device=DEVICE
    )

    ## Prepare SAM
    if torch.numel(boxes) == 0: # nothing found in frame
        continue

    sam_model.set_image(resized_color_image)
    W, H = image_pil.size
    # Scale the boxes to pixels, then turn (center, size) into (top-left,
    # bottom-right) corners. Done for all boxes at once, so no inner loop
    # index shadows the frame index `i`.
    boxes = boxes * torch.Tensor([W, H, W, H])
    boxes[:, :2] -= boxes[:, 2:] / 2
    boxes[:, 2:] += boxes[:, :2]

    ## SAM outputs
    transformed_boxes = sam_model.transform.apply_boxes_torch(boxes, resized_color_image.shape[:2]).to(DEVICE)
    masks, _, _ = sam_model.predict_torch(
        point_coords=None, point_labels=None, boxes=transformed_boxes, multimask_output=False
    )

    ## Associate outputs: group the masks by detected label.
    ## A mask is applied as `color_image * mask[:, :, None]` / `depth_image * mask`.
    for mask, label in zip(masks, pred_phrases):
        label = label[:label.index('(')] # remove confidence
        scene[label].append(mask[0].cpu().numpy())

    ## Generating pointclouds
    for label in scene:
        for mask in scene[label]:
            masked_depth_image = depth_image * mask
            masked_color_image = resized_color_image * mask[:, :, None]

            pcl = projector.get_pointcloud(masked_depth_image, masked_color_image, stride=3)
            pcl.label = label
            # pcl.clean() # todo this method removes color
            if pcl.is_empty():
                continue

            # Apply the frame's pose, then merge into the closest existing cloud.
            pcl.transform(transform)
            target = aggregator.nearest_pointcloud(pcl)
            aggregator.aggregate_pointcloud(pcl, target)

            ## TESTING 11/22: alternative path is aggregator.add_unmerged_pointcloud(pcl)
            all_pointclouds.append(pcl)
                        
## TESTING 11/22: merge all instances
# aggregator.gather_pointclouds()
# aggregator.aggregate_all()
# projector.visualize(aggregator.main)


# Pairwise merge of the collected pointclouds: in every round, neighbouring
# clouds (0,1), (2,3), ... are registered with ICP and summed, until at most
# one cloud is left. `count` is the round number.
count = 0
# NOTE(review): the last three pointclouds are discarded; the reason is not
# visible here - confirm this is still wanted.
all_pointclouds = all_pointclouds[:-3]

while len(all_pointclouds) > 1:
    print(len(all_pointclouds))
    new_pointclouds = []
    # Third positional argument of registration_icp below; starts 10x larger
    # each round and additionally grows inside the round (see `eps += ...`).
    eps = 0.0001 * (10 ** count)
    for i in range(1, len(all_pointclouds), 2):
        print(f"i = {i}")
        # range() already stops below len(all_pointclouds), so this guard never fires.
        if i >= len(all_pointclouds):
            break
        target = all_pointclouds[i] # odd element of the pair
        source = all_pointclouds[i - 1] # even element of the pair

        # transform = world_transforms[os.path.basename(color_images[i])]
        # all_pointclouds[i].transform(transform)
        
        # Round 0 uses the stored frame poses, later rounds use the identity.
        # NOTE(review): `i` is a position in all_pointclouds but indexes
        # rgb_images here. The two only line up if every frame contributed
        # exactly one pointcloud (frames without detections add none, frames
        # with several masks add several); it can also exceed len(rgb_images).
        source_transform = world_transforms[os.path.basename(rgb_images[i - 1])] if count == 0 else np.eye(4)
 
        target_transform = world_transforms[os.path.basename(rgb_images[i])] if count == 0 else np.eye(4)
        
        # Accumulates across the pairs of this round, so later pairs get a larger value.
        eps += i / 1000
        
        # NOTE(review): the frame loop already called pcl.transform(...) on
        # these clouds before storing them - confirm that applying a pose
        # again in round 0 is intended.
        target.transform(target_transform)

        # Point-to-point ICP of source onto target, with source_transform as
        # the initial guess (fourth positional argument).
        reg_p2p = o3d.pipelines.registration.registration_icp(
            source._pcl, target._pcl, eps, source_transform, o3d.pipelines.registration.TransformationEstimationPointToPoint()
        )

        # Move source with the estimated registration and merge the pair.
        source.transform(reg_p2p.transformation)
        new_pointclouds.append(target + source)

    # NOTE(review): with an odd number of clouds the last one is never visited
    # by the loop above and is therefore not carried into the next round.
    all_pointclouds = new_pointclouds
    count += 1

    # Show the intermediate result of every round.
    projector.visualize(all_pointclouds)
print(len(all_pointclouds))
projector.visualize(all_pointclouds)

# allpoints = []

# for label in pointclouds:
#     for pcl in pointclouds[label]:
#         allpoints.append(pcl)

# projector.visualize(allpoints)

## Show pointclouds
# projector.visualize(pointclouds['chair'][0])


In [None]:
# Display the number of frames kept for processing (after the [:15] truncation).
len(rgb_images)
# print('objects', aggregator._scene.keys())
# projector.visualize(aggregator._scene['couch'])

# res = PointCloud()
# res.label = "chair"
# for pcl in allpoints:
#     res += pcl
# rgb_images.sort(key=lambda x: int(os.path.basename(x)[:-4]))
# print(rgb_images)

# projector.visualize(aggregator._unmerged_pointclouds['couch'])
# aggregator.gather_pointclouds()
# aggregator.aggregate_all()
# projector.visualize(aggregator._scene['couch'])