In [1]:
import os
import gradio as gr
import numpy as np
import torch
import cv2
from segment_anything import sam_model_registry, SamPredictor

In [2]:
print("Model load ho raha hai...")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_TYPE = "vit_h"
CHECKPOINT = "sam_vit_h_4b8939.pth"

sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT).to(DEVICE)
predictor = SamPredictor(sam)
print("Model load ho gaya hai.")

Model load ho raha hai...
Model load ho gaya hai.


In [3]:
selected_image = None
combined_mask = None
original_filename = None

In [5]:
def apply_mask_to_image(image_rgb, mask):
    """Mask ko image par draw karta hai."""
    # Mask ko 3-channel banayein (Color ke liye)
    color = np.array([30, 144, 255]) # Blue color
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    
    # Original image ko wapas BGR karein display ke liye
    output_image = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
    
    # Mask ko thora transparent banayein
    # Isko uint8 mein convert karna zaroori hai
    output_image = cv2.addWeighted(output_image, 0.7, mask_image.astype(np.uint8), 0.3, 0)
    
    return output_image

def store_image_and_init(image):
    """Jab user image upload karta hai, toh yeh function run hota hai."""
    global selected_image, combined_mask
    # Image ko set_image ke liye RGB format mein store karein
    original_filename = getattr(image, 'orig_name', 'Original filename not available (older Gradio version or not a FileData object)')
    selected_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(selected_image)
    combined_mask = np.zeros(selected_image.shape[:2], dtype=bool)
    print("Image predictor mein set ho gayi hai.")
    # Image ko mask ke saath dikhane ke liye return karein
    return image

def segment_and_add_mask(evt: gr.SelectData):
    """Jab user image par click karta hai, toh yeh function run hota hai."""
    global selected_image, combined_mask
    if selected_image is None:
        return image # Agar koi image nahi hai toh wapas

    # Click ke coordinates (X, Y)
    input_point = np.array([[evt.index[0], evt.index[1]]])
    input_label = np.array([1]) # 1 = foreground point

    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )

    new_mask = masks[0]
    combined_mask = np.logical_or(combined_mask, new_mask)
    
    output_image = apply_mask_to_image(selected_image, combined_mask)
    return output_image

def clear_all_masks():
    """Clear button dabane par yeh function run hota hai."""
    global combined_mask, selected_image
    
    if selected_image is None:
        return None # Koi image nahi hai
        
    # Mask ko dobara 0 se reset karein
    combined_mask = np.zeros(selected_image.shape[:2], dtype=bool)
    
    print("Masks cleared.")
    
    # Sirf original image (BGR) wapas bhej dein
    return cv2.cvtColor(selected_image, cv2.COLOR_RGB2BGR)

def save_image_to_png():
    """Save button dabane par yeh function run hota hai."""
    global selected_image, combined_mask, original_filename
    
    if combined_mask is None or original_filename is None:
        print("Error: Pehle image upload karein aur mask banayein.")
        return None

    # Final image (BGR) banayein
    output_image_bgr = apply_mask_to_image(selected_image, combined_mask)
    
    # Naya filename banayein
    base_name = os.path.splitext(original_filename)[0] # Naam (bina .jpg/.png)
    save_path = f"{base_name}_masked.png" # Naya naam
    
    # Image ko PNG format mein save karein
    try:
        cv2.imwrite(save_path, output_image_bgr)
        print(f"Image saved to: {save_path}")
        return save_path # File path return karein taake user download kar sake
    except Exception as e:
        print(f"Error saving image: {e}")
        return None

In [None]:
with gr.Blocks() as demo:
    gr.Markdown("# ðŸ¤– SAM Interactive UI (Multiple Masks)")
    gr.Markdown("Image upload karein, phir segment karne ke liye objects par click karein. 'Clear Masks' se sab reset ho jayega.")
    
    with gr.Row():
        # Input Image
        input_img = gr.Image(label="Image Upload Karein", type="numpy")
        # Output Image (mask ke saath)
        output_img = gr.Image(label=" ", type="numpy", show_label=False)

    with gr.Row():
        clear_btn = gr.Button("Clear Masks")
        save_btn = gr.Button("Save Image (PNG)")

    download_file = gr.File(label="Download Masked Image")

    # --- Actions ---
    # Jab image upload ho:
    input_img.upload(
        store_image_and_init,
        [input_img],
        [output_img] # Result ko output mein bhi dikhao
    )
    
    # Jab output image par 'click' ho:
    output_img.select(
        segment_and_add_mask,
        [], # Koi input nahi chahiye, click data 'evt' se ayega
        [output_img]
    )
    
    # Jab clear button par click ho:
    clear_btn.click(
        clear_all_masks,
        [],
        [output_img]
    )

    save_btn.click(
        save_image_to_png,
        [],
        [download_file]
    )

# --- UI ko Launch Karein ---
demo.launch(debug=True, share=True)

Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://da4b78cc201183b901.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


Image predictor mein set ho gayi hai.
Masks cleared.
Error: Pehle image upload karein aur mask banayein.
Masks cleared.
