In [1]:
# Mount Google Drive into the Colab runtime at /content/drive.
# The model checkpoint and demo image used later in this notebook are read
# from paths under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install segment_anything
!pip install gradio
!pip install gradio_image_annotation

Collecting segment_anything
  Downloading segment_anything-1.0-py3-none-any.whl (36 kB)
Installing collected packages: segment_anything
Successfully installed segment_anything-1.0
Collecting gradio
  Downloading gradio-4.37.2-py3-none-any.whl (12.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.3/12.3 MB[0m [31m37.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting aiofiles<24.0,>=22.0 (from gradio)
  Downloading aiofiles-23.2.1-py3-none-any.whl (15 kB)
Collecting fastapi (from gradio)
  Downloading fastapi-0.111.0-py3-none-any.whl (91 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.0/92.0 kB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting ffmpy (from gradio)
  Downloading ffmpy-0.3.2.tar.gz (5.5 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting gradio-client==1.0.2 (from gradio)
  Downloading gradio_client-1.0.2-py3-none-any.whl (318 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m318.2/31

In [3]:
import gradio as gr
from gradio_image_annotation import image_annotator
import numpy as np
import matplotlib.pyplot as plt
import torch
from segment_anything import sam_model_registry
from skimage import transform
import torch.nn.functional as F
import cv2

def show_mask(mask, ax, random_color=False):
    """
    Overlay a mask on a matplotlib axis as a translucent RGBA image.

    Parameters
    ----------
    mask : numpy.ndarray
        The mask to draw. Its height and width are taken from its last two
        dimensions.
    ax : matplotlib.axes.Axes
        The axis that receives the overlay through ``imshow``.
    random_color : bool, optional
        When True the overlay uses a randomly drawn RGB color; when False
        (the default) it uses a fixed yellow. Alpha is 0.6 in both cases.
    """
    if not random_color:
        rgba = np.array([251 / 255, 252 / 255, 30 / 255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)

    height, width = mask.shape[-2:]
    # Broadcast (h, w, 1) against (1, 1, 4): every mask pixel scales the color,
    # so zero-valued pixels become fully transparent black.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def show_box(box, ax):
    """
    Draw a bounding box outline on a matplotlib axis.

    Parameters
    ----------
    box : list or numpy.ndarray
        Box corners in the order [x0, y0, x1, y1].
    ax : matplotlib.axes.Axes
        The axis the rectangle patch is added to.
    """
    left, top = box[0], box[1]
    right, bottom = box[2], box[3]
    # Blue 2pt outline with a fully transparent fill so the image stays visible.
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='blue',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)

@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
    """
    Run the prompt encoder and mask decoder of a MedSAM model for a box prompt.

    Parameters
    ----------
    medsam_model : torch.nn.Module
        Model exposing ``prompt_encoder`` (callable, with ``get_dense_pe``)
        and ``mask_decoder``.
    img_embed : torch.Tensor
        Image embeddings passed to the mask decoder; the box tensor is
        created on the same device.
    box_1024 : array-like or torch.Tensor
        Box prompt coordinates. A 2-D input (B, 4) is expanded to (B, 1, 4).
    H : int
        Output mask height.
    W : int
        Output mask width.

    Returns
    -------
    numpy.ndarray
        ``uint8`` binary mask: the decoder logits are passed through a
        sigmoid, bilinearly resized to (H, W), squeezed of all singleton
        dimensions, and thresholded at 0.5.
    """
    boxes = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    if boxes.dim() == 2:
        boxes = boxes.unsqueeze(1)  # (B, 4) -> (B, 1, 4)

    prompt_encoder = medsam_model.prompt_encoder
    sparse_emb, dense_emb = prompt_encoder(points=None, boxes=boxes, masks=None)

    # Request a single mask per prompt; the second decoder output is unused.
    decoder_output = medsam_model.mask_decoder(
        image_embeddings=img_embed,
        image_pe=prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_emb,
        dense_prompt_embeddings=dense_emb,
        multimask_output=False,
    )
    logits = decoder_output[0]

    # Convert logits to probabilities, then upsample to the requested size.
    probabilities = F.interpolate(
        torch.sigmoid(logits),
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )
    probabilities = probabilities.squeeze().cpu().numpy()
    return (probabilities > 0.5).astype(np.uint8)

# --- Model setup ---
# Checkpoint file on the mounted Drive.
MedSAM_CKPT_PATH = "/content/drive/MyDrive/medsam_vit_b.pth"

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Instantiate the 'vit_b' architecture from the registry, fill it with the
# checkpoint weights (mapped onto the chosen device), then move the model to
# that device and switch it to evaluation mode.
medsam_model = sam_model_registry['vit_b']()
medsam_model.load_state_dict(torch.load(MedSAM_CKPT_PATH, map_location=device))
medsam_model = medsam_model.to(device).eval()

def _parse_box_coordinates(box_coordinates):
    """Parse an "x0,y0,x1,y1" string into a (1, 4) integer array, or raise ValueError."""
    try:
        values = [int(part) for part in box_coordinates.split(',')]
    except (AttributeError, TypeError, ValueError) as e:
        # Chain the cause so the original parsing failure stays in the traceback.
        raise ValueError(f"Invalid box coordinates: {e}") from e
    if len(values) != 4:
        raise ValueError(
            "Invalid box coordinates: Box coordinates must be a list of four integers."
        )
    return np.array([values])


def _to_three_channels(img_np):
    """Coerce a grayscale, single-channel, RGB or RGBA array to shape (H, W, 3)."""
    if img_np.ndim == 2:
        return np.repeat(img_np[:, :, None], 3, axis=-1)
    if img_np.ndim == 3:
        channels = img_np.shape[-1]
        if channels == 3:
            return img_np
        if channels == 1:
            return np.repeat(img_np, 3, axis=-1)
        if channels == 4:
            return img_np[:, :, :3]  # drop the alpha channel
    raise ValueError(f"Unsupported image shape: {img_np.shape}")


def segment_image(image, box_coordinates):
    """
    Segment an image using the MedSAM model and a single bounding-box prompt.

    Parameters
    ----------
    image : PIL.Image or numpy.ndarray
        The input image. Grayscale (H, W), single-channel (H, W, 1),
        RGB (H, W, 3) and RGBA (H, W, 4, alpha dropped) layouts are accepted.
    box_coordinates : str
        The bounding box coordinates as a string in the format "x0,y0,x1,y1",
        expressed in pixels of the input image.

    Returns
    -------
    matplotlib.figure.Figure
        A two-panel figure: the input image with the box, and the same image
        with the predicted mask and the box overlaid. The figure is
        unregistered from pyplot before it is returned so that repeated calls
        do not accumulate open figures.

    Raises
    ------
    ValueError
        If no image is supplied, the image has an unsupported shape, or the
        box string cannot be parsed into exactly four integers.
    """
    if image is None:
        raise ValueError("No input image provided.")
    img_3c = _to_three_channels(np.array(image))
    H, W, _ = img_3c.shape

    # Validate the prompt before the expensive resize and encoder pass.
    box_np = _parse_box_coordinates(box_coordinates)

    # Resize to the 1024x1024 encoder input, then min-max normalize to [0, 1].
    img_1024 = transform.resize(
        img_3c, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
    ).astype(np.uint8)
    img_1024 = (img_1024 - img_1024.min()) / np.clip(
        img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None
    )  # (1024, 1024, 3)
    img_1024_tensor = torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(device)

    # Rescale the box from original-image pixels to the 1024x1024 frame.
    box_1024 = box_np / np.array([W, H, W, H]) * 1024
    with torch.no_grad():
        image_embedding = medsam_model.image_encoder(img_1024_tensor)

    medsam_seg = medsam_inference(medsam_model, image_embedding, box_1024, H, W)

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(img_3c)
    show_box(box_np[0], ax[0])
    ax[0].set_title("Input Image and Bounding Box")
    ax[1].imshow(img_3c)
    show_mask(medsam_seg, ax[1])
    show_box(box_np[0], ax[1])
    ax[1].set_title("MedSAM Segmentation")
    # Lay out this specific figure rather than pyplot's global "current"
    # figure, which may belong to another concurrent request.
    fig.tight_layout()
    # Release pyplot's reference; the Figure object itself remains renderable.
    plt.close(fig)

    return fig

# Module-level placeholders for a box's corner coordinates.
# NOTE(review): nothing in the visible code reads or updates these — confirm
# whether they are still needed.
xmin, ymin, xmax, ymax = None, None, None, None

def get_boxes_json(annotations):
    """
    Get bounding boxes from an annotations payload.

    Parameters
    ----------
    annotations : dict or None
        The annotations dictionary; boxes are read from its "boxes" key.
        ``None`` is tolerated so the callback does not fail when no
        annotation payload is available.

    Returns
    -------
    list
        The value stored under "boxes", or an empty list when the payload is
        ``None``, has no "boxes" key, or stores ``None`` there.
    """
    if annotations is None:
        return []
    boxes = annotations.get("boxes")
    return [] if boxes is None else boxes

# Components are created in display order inside the Blocks context, so the
# statement order below is also the page layout.
with gr.Blocks() as interface:
    # Part 1: draw a box on the demo image and read its coordinates back as JSON.
    annotator = image_annotator(
        {"image": "/content/drive/MyDrive/img_demo.png"},
        label_list=["Segment"],
        label_colors=[(0, 255, 0)],
    )
    button_get = gr.Button(value="Get bounding boxes")
    json_boxes = gr.JSON()
    button_get.click(fn=get_boxes_json, inputs=annotator, outputs=json_boxes)

    # Part 2: type the box as "x0,y0,x1,y1" and run the segmentation.
    image_input = gr.Image(value="/content/drive/MyDrive/img_demo.png")
    bbox_input = gr.Textbox(label="Bounding Box Coordinates (x0,y0,x1,y1)")
    segment_button = gr.Button(value="Segment Image")
    output_plot = gr.Plot()
    segment_button.click(
        fn=segment_image,
        inputs=[image_input, bbox_input],
        outputs=output_plot,
    )

# Start the Gradio server for the app defined above.
interface.launch()


Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://3b6921d2fc6591d771.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


