In [1]:
import os

import numpy as np
import cv2 as cv
import torch
import matplotlib.pyplot as plt
from PIL import Image

import open3d as o3d
from trimesh import Trimesh

from moge.model.v2 import MoGeModel
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

from superprimitive_fusion.utils import (get_integer_segments,
                   plot_region_numbers,
                   triangulate_segments,
                   smooth_mask,
                   crop_by_SP,
                   fill_ring_holes,
                   trimesh_to_o3d)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Handles for the first two CUDA devices.
device0, device1 = torch.device("cuda:0"), torch.device("cuda:1")

# Enter a bfloat16 autocast context for CUDA ops. The context is entered
# manually and never exited here, so it stays active for the cells that follow.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

# Allow TF32 in matmul and cuDNN kernels when GPU 0 reports a major
# compute capability of 8 or higher.
gpu0_major = torch.cuda.get_device_properties(0).major
if gpu0_major >= 8:
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True

In [3]:
# --- SAM2: build the model on device0 and wrap it in the automatic mask generator ---
sam2_checkpoint = "../models/SAM2/sam2.1_hiera_tiny.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"

sam2 = build_sam2(
    model_cfg,
    sam2_checkpoint,
    device=device0,
    apply_postprocessing=False,
)
mask_generator = SAM2AutomaticMaskGenerator(sam2)

# --- MoGe: load the pretrained weights, then move the model to device0 ---
moge_model = MoGeModel.from_pretrained("../models/MoGe/moge-2-vitl-normal.pt")
moge_model = moge_model.to(device0)

In [14]:
image_path = '../data/mustard360/images/'


def _frame_number(name):
    """Return the integer between the first '_' and the following '.' in a file name."""
    return int(name.split('_')[1].split('.')[0])


# Image file names ordered by their embedded frame number, plus the numbers
# themselves in that same order.
image_names = sorted(os.listdir(image_path), key=_frame_number)
frame_numbers = [_frame_number(name) for name in image_names]

In [8]:
# Build one coloured triangle mesh per frame from the selected superprimitive.
meshes = []
# Superprimitive ID of the mustard bottle in each frame, listed in the same
# order as image_names (the two are paired with zip below).
idx = [2, 2, 5, 5, 2, 3, 4, 4, 10, 4, 7, 4, 21, 1, 2, 1, 24, 4, 3, 23, 3, 16, 2, 2, 2, 2, 2, 4, 2, 2, 2, 3, 2]

# zip() stops at the shorter input, so a length mismatch would silently drop
# frames; fail loudly instead.
if len(idx) != len(image_names):
    raise ValueError(
        f"Need one superprimitive ID per image: got {len(idx)} IDs "
        f"for {len(image_names)} images in {image_path!r}"
    )

for SPID, image_name in zip(idx, image_names):
    # Get input image and format for SAM and MoGe
    input_image_file = os.path.join(image_path, image_name)
    bgr_image = cv.imread(input_image_file)
    if bgr_image is None:
        # cv.imread returns None (rather than raising) when the file cannot be read
        raise OSError(f"Could not read image: {input_image_file}")
    input_image = cv.cvtColor(bgr_image, cv.COLOR_BGR2RGB)
    # MoGe input: float32 tensor scaled by 1/255 with the channel axis moved first
    input_image_moge = torch.tensor(input_image / 255, dtype=torch.float32, device=device0).permute(2, 0, 1)
    input_image_sam  = np.array(input_image)

    # Run inference for each model
    sam_masks = mask_generator.generate(input_image_sam)

    moge_output = moge_model.infer(input_image_moge)

    # Collapse each binary mask into a single integer mask
    segmentation_masks = np.array([mask['segmentation'] for mask in sam_masks])
    int_seg = get_integer_segments(segmentation_masks)

    # Get array of 3D points from MoGe output
    points = moge_output['points'].cpu().numpy()

    # Crop the masks, image and points around the desired primitive
    int_seg_cropped, img_cropped, points_cropped = crop_by_SP(SPID, int_seg, input_image, points, border=5, make_binary=True)

    # Fill holes and smooth the segmentation mask
    filled_int_seg = fill_ring_holes(int_seg_cropped, radius=2)
    int_seg_smoothed = smooth_mask(filled_int_seg, radius_erode=5, radius_dilate=None)

    # Process colours to add to mesh: flatten to one RGB row per pixel, then
    # append a constant alpha of 255 so each row is RGBA
    colours_cropped = np.clip(img_cropped, 0, 255).reshape((-1,3)).astype(np.uint8, casting='unsafe')
    alpha_channel = 255*np.ones((np.prod(img_cropped.shape[:2]),1), dtype=np.uint8)
    colours_cropped = np.hstack([colours_cropped, alpha_channel])

    # Triangulate the pointcloud
    verts_cropped = points_cropped.reshape((-1,3))
    tris = triangulate_segments(verts_cropped, int_seg_smoothed)
    # Flatten the nested triangle lists into one list of faces
    all_tris = [tri for trise in tris for tri in trise]

    # Make a mesh
    mesh = Trimesh(vertices=verts_cropped, faces=all_tris, vertex_colors=colours_cropped)
    meshes.append(mesh)

In [17]:
for frame_number, mesh in zip(frame_numbers, meshes):
    # Export and visualisation are currently disabled; uncomment the lines
    # below to write each mesh to disk and/or view it in Open3D.
    # mesh.export(f'../data/mustard360/super-primitives/frame_{frame_number}.ply')
    # mesh_o3d = trimesh_to_o3d(mesh)
    # o3d.visualization.draw_geometries([mesh_o3d])
    pass

#### Superprimitive ID of the mustard bottle in each frame
idx = [2, 2, 5, 5, 2, 3, 4, 4, 10, 4, 7, 4, 21, 1, 2, 1, 24, 4, 3, 23, 3, 16, 2, 2, 2, 2, 2, 4, 2, 2, 2, 3, 2]

These IDs correspond, in order, to the following frame numbers:

frame_numbers = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 300, 310, 320]