In [None]:
!pip install opencv-python Pillow torch transformers segment-anything


In [None]:
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

In [None]:
import argparse
import os

import cv2
import numpy as np
import torch
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# Task 1: Segment an object and highlight it in red (the text prompt is not used to pick the mask yet)
def segment_object(image_path, class_prompt, output_path):
    """Tint the first SAM-generated mask red and save the resulting image.

    Args:
        image_path: Path of the image to load with ``cv2.imread``.
        class_prompt: Kept for interface compatibility. It is currently
            unused: ``SamAutomaticMaskGenerator`` is not conditioned on text,
            so the first returned mask is taken as the object.
        output_path: Destination file passed to ``cv2.imwrite``.

    Raises:
        FileNotFoundError: If ``image_path`` is not an existing file.
        ValueError: If the image cannot be decoded or SAM returns no masks.
        OSError: If ``cv2.imwrite`` reports that the image was not written.
    """
    # Validate the input first so a bad path fails before the model is loaded.
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    # cv2.imread signals a decode failure by returning None, not by raising.
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not decode image: {image_path}")
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Load the SAM model
    sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
    sam.to(device="cuda" if torch.cuda.is_available() else "cpu")

    # Generate candidate masks for the whole image
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image_rgb)
    if not masks:
        raise ValueError(f"SAM produced no masks for image: {image_path}")

    # Here, we assume the first mask corresponds to the object of interest.
    mask = masks[0]['segmentation']

    # Paint the object red on a copy of the image. Blending against this copy
    # (rather than an all-black canvas) tints only the masked pixels; the
    # background stays at its original brightness.
    overlay = image.copy()
    overlay[mask] = [0, 0, 255]  # Red color in BGR format
    result_image = cv2.addWeighted(image, 0.7, overlay, 0.3, 0)

    # Save the result; imwrite returns False when nothing was written.
    if not cv2.imwrite(output_path, result_image):
        raise OSError(f"Failed to write image to {output_path}")
    print(f"Segmented image saved to {output_path}")

def move_object(image_path, class_prompt, shift_x, shift_y, output_path):
    """Cut out the first SAM-generated mask, translate it, and save the image.

    The object's original location is left black; the translated object is
    pasted over whatever background lies at its new position.

    Args:
        image_path: Path of the image to load with ``cv2.imread``.
        class_prompt: Kept for interface compatibility. It is currently
            unused: ``SamAutomaticMaskGenerator`` is not conditioned on text,
            so the first returned mask is taken as the object.
        shift_x: Horizontal translation in pixels (positive moves right).
        shift_y: Vertical translation in pixels (positive moves down).
        output_path: Destination file passed to ``cv2.imwrite``.

    Raises:
        FileNotFoundError: If ``image_path`` is not an existing file.
        ValueError: If the image cannot be decoded or SAM returns no masks.
        OSError: If ``cv2.imwrite`` reports that the image was not written.
    """
    # Validate the input first so a bad path fails before the model is loaded.
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    # cv2.imread signals a decode failure by returning None, not by raising.
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not decode image: {image_path}")
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Load the SAM model
    sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
    sam.to(device="cuda" if torch.cuda.is_available() else "cpu")

    # Generate candidate masks for the whole image
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image_rgb)
    if not masks:
        raise ValueError(f"SAM produced no masks for image: {image_path}")

    # OpenCV mask arguments expect a single-channel uint8 image (0 or 255).
    mask = masks[0]['segmentation'].astype(np.uint8) * 255

    # Extract the object using the binary mask
    object_region = cv2.bitwise_and(image, image, mask=mask)

    # Shift both the object pixels and its mask by the same translation so we
    # know exactly which destination pixels the object covers.
    rows, cols = image.shape[:2]
    translation_matrix = np.float32([[1, 0, shift_x], [0, 1, shift_y]])
    shifted_object = cv2.warpAffine(object_region, translation_matrix, (cols, rows))
    shifted_mask = cv2.warpAffine(
        mask, translation_matrix, (cols, rows), flags=cv2.INTER_NEAREST
    )

    # Remove the object from its original location
    inverse_mask = cv2.bitwise_not(mask)
    background = cv2.bitwise_and(image, image, mask=inverse_mask)

    # Paste the shifted object over the background. A saturating cv2.add here
    # would sum the object with the background pixels beneath it and wash the
    # colours out, so overwrite those pixels instead.
    result_image = background.copy()
    destination = shifted_mask > 0
    result_image[destination] = shifted_object[destination]

    # Save the result; imwrite returns False when nothing was written.
    if not cv2.imwrite(output_path, result_image):
        raise OSError(f"Failed to write image to {output_path}")
    print(f"Shifted object image saved to {output_path}")


if __name__ == "__main__":
    # --- Input / output locations ---
    image_path = "./bagpack.jpg"  # image to process
    output_path = "./segmented_output.png"  # where the processed image is written

    # --- Object selection ---
    class_prompt = "bag"  # class name handed to the task functions

    # --- Translation used only by the 'move' task ---
    shift_x = 80  # pixels along x
    shift_y = 0  # pixels along y

    # --- Which task to run: "segment" or "move" ---
    task = "segment"

    # Dispatch to the selected task; any other value does nothing.
    if task == "move":
        move_object(image_path, class_prompt, shift_x, shift_y, output_path)
    elif task == "segment":
        segment_object(image_path, class_prompt, output_path)

