In [1]:
import torch
import pynvml

def get_least_used_gpu():
    """
    Pick the GPU with the lowest memory utilization.

    Memory usage is queried through NVML (``pynvml``). NVML is always shut
    down again, even when one of the queries raises.

    NOTE(review): NVML enumerates every physical GPU and ignores
    ``CUDA_VISIBLE_DEVICES``. The scan is therefore limited to the number of
    devices PyTorch reports, so the returned ordinal is always addressable,
    but the NVML index is still assumed to match the CUDA index -- confirm
    this if ``CUDA_VISIBLE_DEVICES`` reorders or filters devices.

    Returns:
        str: Device to use ("cuda:x" for GPU or "cpu" if no GPU is available).
    """
    if not torch.cuda.is_available():
        print("No GPU available. Using CPU.")
        return "cpu"

    min_usage = float("inf")
    best_gpu = None

    # Initialize NVIDIA Management Library
    pynvml.nvmlInit()
    try:
        # Never look past the devices PyTorch can address; otherwise the
        # function could hand back an invalid "cuda:N" ordinal.
        device_count = min(pynvml.nvmlDeviceGetCount(), torch.cuda.device_count())

        for i in range(device_count):
            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
            info = pynvml.nvmlDeviceGetMemoryInfo(handle)

            # A device reporting no memory cannot be ranked (and would
            # trigger a division by zero below).
            if info.total <= 0:
                continue

            # Memory usage in percent
            used_memory = info.used / info.total * 100

            # Track the GPU with the lowest memory usage
            if used_memory < min_usage:
                min_usage = used_memory
                best_gpu = i
    finally:
        # Cleanup NVML even if a query above failed
        pynvml.nvmlShutdown()

    if best_gpu is not None:
        print(f"Using GPU {best_gpu} (lowest memory usage: {min_usage:.2f}%).")
        return f"cuda:{best_gpu}"

    print("No suitable GPU found. Using CPU.")
    return "cpu"

In [2]:
from pytorch3d.io import load_objs_as_meshes
import os

# Location of the dog mesh, relative to the data directory
DATA_DIR = "./data"
obj_filename = os.path.join(DATA_DIR, "dog_mesh/13466_Canaan_Dog_v1_L3.obj")

# Compute device: the least busy GPU, or the CPU when no GPU is available
DEVICE = get_least_used_gpu()

# Load the OBJ file onto the selected device
mesh = load_objs_as_meshes([obj_filename], device=DEVICE)

Using GPU 0 (lowest memory usage: 2.49%).


In [3]:
import torch
import numpy as np
from torchvision.ops import box_convert
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from grounding_dino.groundingdino.util.inference import load_model, load_image, predict


# Data root (same value as in the setup cell)
DATA_DIR = "./data"

# SAM 2 weights and model configuration
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"

# Grounding DINO configuration and weights
GROUNDING_DINO_CONFIG = "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swint_ogc.pth"

# Thresholds forwarded to Grounding DINO's predict()
BOX_THRESHOLD = 0.35
TEXT_THRESHOLD = 0.25

# SAM 2 image predictor
sam2_checkpoint, model_cfg = SAM2_CHECKPOINT, SAM2_MODEL_CONFIG
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
sam2_predictor = SAM2ImagePredictor(sam2_model)

# Grounding DINO detector
grounding_model = load_model(
    model_config_path=GROUNDING_DINO_CONFIG,
    model_checkpoint_path=GROUNDING_DINO_CHECKPOINT,
    device=DEVICE,
)

def grounded_sam2(img_path: str, text_prompt: str):
    """
    Detect the regions named in ``text_prompt`` with Grounding DINO and
    segment every detected box with SAM 2.

    Args:
        img_path (str): Path of the image to process.
        text_prompt (str): Caption handed to Grounding DINO, e.g.
            "head. tail. legs.". VERY important: text queries need to be
            lowercased + end with a dot.

    Returns:
        tuple | None: ``(masks, labels)`` where ``masks`` is the SAM 2 mask
        output (its second axis is squeezed away when SAM 2 returns a 4-D
        array) and ``labels`` is the list of phrases Grounding DINO returned,
        one per detected box. ``None`` when nothing was detected.
    """
    # Local import keeps this cell self-contained.
    from contextlib import nullcontext

    image_source, image = load_image(img_path)

    sam2_predictor.set_image(image_source)

    # DEVICE may be "cpu" (see get_least_used_gpu); CUDA autocast and the CUDA
    # device-property query are only valid on a CUDA device.
    on_cuda = torch.device(DEVICE).type == "cuda"

    # FIXME: figure how does this influence the G-DINO model
    # changed bfloat16 to float16
    if on_cuda:
        autocast_ctx = torch.autocast(device_type="cuda", dtype=torch.float16)
    else:
        autocast_ctx = nullcontext()

    with autocast_ctx:

        if on_cuda and torch.cuda.get_device_properties(torch.device(DEVICE)).major >= 8:
            # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

        boxes, _confidences, labels = predict(
            model=grounding_model,
            image=image,
            caption=text_prompt,
            box_threshold=BOX_THRESHOLD,
            text_threshold=TEXT_THRESHOLD,
            device=DEVICE
        )

        # Nothing to segment: bail out before prompting SAM 2.
        if boxes.numel() == 0:
            print(f"No objects detected in {img_path}. Skipping this image.")
            return None

        # process the box prompt for SAM 2: scale the boxes by the image
        # width/height and convert them from cxcywh to xyxy
        h, w, _ = image_source.shape
        boxes = boxes * torch.Tensor([w, h, w, h])
        input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()

        masks, _scores, _logits = sam2_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=False,
        )

        # Post-process the model output: convert the shape to (n, H, W)
        if masks.ndim == 4:
            masks = masks.squeeze(1)

    return masks, list(labels)   # return sam2 masks and labels

  warn(


final text_encoder_type: bert-base-uncased


In [6]:
# Sanity check: the SAM 2 masks for one rendered view should be binary.
# NOTE: this reads a view image from data/dog_mesh_views, so that directory
# must already contain rendered views when this cell runs.
img_path = "data/dog_mesh_views/view_00.png"
x = grounded_sam2(img_path, text_prompt="head. tail. legs.")
# grounded_sam2 returns None when nothing is detected; fail with a clear
# message instead of an opaque "'NoneType' object is not subscriptable".
assert x is not None, f"grounded_sam2 detected no objects in {img_path}"
arr = x[0]
print(np.unique(arr))

[0. 1.]




In [7]:
import os
import torch
import numpy as np
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    PerspectiveCameras,
    RasterizationSettings,
    MeshRasterizer,
    MeshRenderer,
    SoftPhongShader,
    PointLights
)
from pytorch3d.io import load_objs_as_meshes
from torchvision.transforms import ToPILImage


def calculate_matrix_Xi(
    obj_file_path: str,
    batch_size: int,
    text_prompt: str,
    elevs: tuple,
    azims: tuple,
    save_dir: str,
    device: str = DEVICE,
    image_size: int = 512,
    camera_distance: float = 60.0,
):
    """
    Calculate the face-region matrix X_i for a given mesh using grounded_sam2 masks.

    The mesh is rendered from ``batch_size`` viewpoints, every view is saved
    to ``save_dir`` and segmented with ``grounded_sam2``. For each mask whose
    label is one of the regions in ``text_prompt``, the pixels covered by the
    mask are mapped back to mesh faces and counted.

    Args:
        obj_file_path (str): Path to the OBJ file representing the 3D mesh.
        batch_size (int): Number of views rendered (all in one batch).
        text_prompt (str): Input text for grounded_sam2; region names are
            separated by '.', e.g. "head. tail. legs. body.".
        elevs (tuple): Elevation range (start, end) in degrees.
        azims (tuple): Azimuth range (start, end) in degrees.
        save_dir (str): Directory the rendered views are written to
            (created if missing) as ``view_XX.png``.
        device (str): Device to run the computation on ("cuda:x" or "cpu").
        image_size (int): Side length of the rendered images in pixels.
        camera_distance (float): Camera distance passed to
            ``look_at_view_transform``.

    Returns:
        np.ndarray: Matrix X_i of shape (num_faces, num_regions); entry
        (f, r) is the number of pixels, summed over all views, where face f
        was visible inside a mask labelled with region r.
    """
    device = torch.device(device)

    # Load the mesh from the OBJ file
    mesh = load_objs_as_meshes([obj_file_path], device=device)
    num_faces = mesh.faces_packed().shape[0]

    # Parse semantic regions from text_prompt. Each name is stripped so that
    # "head. tail." yields "tail" rather than " tail" (which would never equal
    # a detected phrase), and empty pieces are dropped so a missing trailing
    # '.' no longer loses the last region.
    labels = [piece.strip().lower() for piece in text_prompt.split('.')]
    labels = [label for label in labels if label]
    num_regions = len(labels)
    labels_dict = {label: idx for idx, label in enumerate(labels)}

    # Initialize the face-region matrix Xi with zeros
    Xi = torch.zeros((num_faces, num_regions), device=device)

    # Generate rasterization and rendering settings
    raster_settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=0.0,
        faces_per_pixel=1,  # Nearest face only
        max_faces_per_bin=30000
    )

    # batched meshes: one copy of the mesh per view
    meshes = mesh.extend(batch_size)

    # One (elevation, azimuth) pair per view, both swept linearly
    elev_angles = torch.linspace(elevs[0], elevs[1], batch_size)
    azim_angles = torch.linspace(azims[0], azims[1], batch_size)
    R, T = look_at_view_transform(camera_distance, elev=elev_angles, azim=azim_angles)

    # Create batched cameras
    cameras = PerspectiveCameras(
        device=device,
        R=R,
        T=T
    )

    # lights
    lights = PointLights(device=device, location=[[0.0, 0.0, -70.0]])

    # Initialize the rasterizer and shader
    rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
    shader = SoftPhongShader(device=device, cameras=cameras, lights=lights)

    # Rasterize once and reuse the fragments both for shading and for the
    # pixel-to-face lookup (a MeshRenderer would rasterize a second time).
    fragments = rasterizer(meshes)
    images = shader(fragments, meshes, cameras=cameras, lights=lights)

    # Nearest face per pixel; -1 marks background
    pix_to_face = fragments.pix_to_face[..., 0]  # Shape: (B, H, W)

    os.makedirs(save_dir, exist_ok=True)
    to_pil = ToPILImage()

    for batch_idx in range(pix_to_face.shape[0]):
        # Keep the RGB channels (H, W, 3) and convert to uint8 (0-255 range)
        rgb_image = (images[batch_idx, ..., :3].clamp(0, 1) * 255).byte()

        # Convert to PIL image and save; ToPILImage expects (C, H, W)
        image_pil = to_pil(rgb_image.permute(2, 0, 1).cpu())
        image_path = os.path.join(save_dir, f"view_{batch_idx:02d}.png")
        image_pil.save(image_path)

        # Use grounded_sam2 to get masks and labels for each rendered image.
        # It returns None when nothing was detected in this view.
        result = grounded_sam2(image_path, text_prompt)
        if result is None:
            continue
        sam2_masks, sam2_labels = result

        view_faces = pix_to_face[batch_idx]

        # Iterate through all masks in sam2_masks
        for mask_idx, label in enumerate(sam2_labels):
            # Map the label to its region column; skip unknown labels
            region_idx = labels_dict.get(label.strip().lower())
            if region_idx is None:
                continue

            # Convert the corresponding 2D mask to a boolean tensor (foreground is 1)
            region_mask = torch.tensor(sam2_masks[mask_idx] == 1, device=device)

            # Faces seen through the mask, ignoring background (-1)
            valid_face_indices = view_faces[region_mask]
            valid_face_indices = valid_face_indices[valid_face_indices >= 0]
            # pix_to_face indexes the packed faces of the whole batch; fold the
            # indices back onto the faces of the single source mesh.
            valid_face_indices = valid_face_indices % num_faces

            if valid_face_indices.numel() > 0:  # If there are valid faces
                # Count occurrences of each face and accumulate them in Xi
                face_counts = torch.bincount(valid_face_indices, minlength=num_faces)
                Xi[:, region_idx] += face_counts

    return Xi.cpu().numpy()

In [14]:
# Inputs for the face-region matrix of the dog mesh
obj_file_path = "data/dog_mesh/13466_Canaan_Dog_v1_L3.obj"
save_dir = "./data/dog_mesh_views"
text_prompt = "head. tail. legs. body."
batch_size = 30

# Render the views, segment them and accumulate the per-face region counts
Xi = calculate_matrix_Xi(
    obj_file_path=obj_file_path,
    batch_size=batch_size,
    text_prompt=text_prompt,
    elevs=(30, 40),
    azims=(-180, 180),
    save_dir=save_dir,
    device=DEVICE,
)
print(np.unique(Xi))



[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17.
 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29.]
