In [1]:
import cv2
import os
import sys
from PIL import Image
import numpy as np

In [17]:
def extract_foreground_with_grabcut(depth_map):
    """Segment the central foreground region of a depth map using GrabCut.

    The depth map is min-max normalised to 8-bit and expanded to three
    channels so it can be passed to ``cv2.grabCut``. GrabCut is seeded with a
    rectangle that starts a quarter of the way in from the top-left corner
    and spans half of the width and half of the height.

    Args:
        depth_map: 2-D array-like of depth values. A trailing singleton
            channel axis (H x W x 1) is dropped before processing.

    Returns:
        tuple: ``(result, mask_foreground)`` where ``result`` is the
        normalised uint8 depth map with every non-foreground pixel set to 0,
        and ``mask_foreground`` is a uint8 array holding 1 for foreground
        pixels and 0 for background pixels.

    Raises:
        ValueError: If ``depth_map`` is None, is not a single-channel 2-D
            image, or is smaller than 2 pixels along either axis (the seed
            rectangle would be empty).
    """
    if depth_map is None:
        raise ValueError("depth_map is None (did the image fail to load?)")

    depth_map = np.asarray(depth_map)
    if depth_map.ndim == 3 and depth_map.shape[2] == 1:
        depth_map = depth_map[:, :, 0]
    if depth_map.ndim != 2:
        raise ValueError(
            "depth_map must be a single-channel 2-D image, got shape %r"
            % (depth_map.shape,)
        )

    height, width = depth_map.shape
    if height < 2 or width < 2:
        raise ValueError(
            "depth_map must be at least 2x2 pixels, got shape %r"
            % (depth_map.shape,)
        )

    # Normalise to the full 8-bit range, then replicate the single channel
    # into a 3-channel image for grabCut.
    depth_u8 = cv2.normalize(depth_map, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    depth_bgr = cv2.cvtColor(depth_u8, cv2.COLOR_GRAY2BGR)

    # Seed rectangle (x, y, w, h): the centre region is assumed to contain
    # the foreground object.
    rect = (width // 4, height // 4, width // 2, height // 2)

    # With GC_INIT_WITH_RECT, grabCut initialises the mask itself from `rect`,
    # so any labels written beforehand would be discarded. Only an output
    # buffer of the right shape and dtype is needed here.
    mask = np.zeros((height, width), np.uint8)
    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)

    cv2.grabCut(depth_bgr, mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)

    # Collapse the four GrabCut labels to binary: definite or probable
    # foreground -> 1, definite or probable background -> 0.
    is_foreground = (mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD)
    mask_foreground = is_foreground.astype(np.uint8)

    # Zero out everything that is not foreground in the normalised depth map.
    result = depth_u8 * mask_foreground

    return result, mask_foreground

In [18]:
depth_map_path = '../assets/h8_depth.jpeg'
depth_map = cv2.imread(depth_map_path, cv2.IMREAD_GRAYSCALE)

# Check the load result before touching `.shape`: a failed read yields None,
# and raising here (instead of calling exit()) keeps the notebook kernel alive.
if depth_map is None:
    raise FileNotFoundError("Failed to load depth map image: " + depth_map_path)

# Downscale before segmentation to reduce the work GrabCut has to do.
scale_factor = 0.5  # Adjust this value as needed
new_size = (int(depth_map.shape[1] * scale_factor), int(depth_map.shape[0] * scale_factor))
depth_map = cv2.resize(depth_map, new_size)

# Extract the foreground object
foreground, mask = extract_foreground_with_grabcut(depth_map)

# Save the results to disk
cv2.imwrite("original.png", depth_map)
cv2.imwrite("Foreground.png", foreground)
cv2.imwrite("mask.png", mask * 255)  # Scale the 0/1 mask to 0/255 so it is visible

True