In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Image Segmentation on Vertex AI


<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/vision/getting-started/image_segmentation.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Google Colaboratory logo"><br> Open in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fvision%2Fgetting-started%2Fimage_segmentation.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Open in Colab Enterprise
    </a>
  </td>    
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/vision/getting-started/image_segmentation.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo"><br> Open in Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/generative-ai/blob/main/vision/getting-started/image_segmentation.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
</table>

<div style="clear: both;"></div>


| | |
|-|-|
|Author | [Jorj Ismailyan](https://github.com/jismailyan-google) |

## Overview

Vertex Image Segmentation brings Google's state-of-the-art segmentation models to developers as a scalable and reliable service.

With Image Segmentation, developers can choose from five segmentation modes, including **text prompt** and **interactive** modes, to segment images and build AI products.

Learn more about [Image Segmentation on Vertex AI](https://docs.google.com/document/d/1y5H_m29zGM3Xt6ba2lMw_di6bpbvtQagpU-xY30Kx78/preview?hgd=1&resourcekey=0-_-4WVkfl0oS3nfBwIEhWWQ&tab=t.0).


### Objectives

In this notebook, you explore the features of Vertex Image Segmentation using the Vertex AI Python SDK. You will:

- Segment the foreground or background of an object
  - Create a product image by removing the background
  - Change the background color of an image
- Control the generated mask by configuring dilation
- Use an open-vocabulary text prompt to perform:
  - Object detection
  - Instance segmentation
- Draw a scribble to guide segmentation
  - Perform point-to-mask segmentation

### Access

Request access for your Google Cloud project by visiting the [Image Segmentation model card](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/image-segmentation-001) in Vertex AI Model Garden.

### Costs

- This notebook uses billable components of Google Cloud:
  - Vertex AI

- Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing) and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.

## Getting Started

### Install Vertex AI SDK for Python (Jupyter only)

In [None]:
%pip install --upgrade --user google-cloud-aiplatform

### Restart runtime (Jupyter only)
To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel.

The restart might take a minute or longer. After it's restarted, continue to the next step.

In [None]:
import IPython

app = IPython.Application.instance()
app.kernel.do_shutdown(True)

### Authenticate your notebook environment (Colab only)

If you are running this notebook on Google Colab, run the following cell to authenticate your environment. This step is not required if you are using [Vertex AI Workbench](https://cloud.google.com/vertex-ai-workbench).

In [None]:
import sys

# Additional authentication is required for Google Colab
if "google.colab" in sys.modules:
    # Authenticate user to Google Cloud
    from google.colab import auth

    auth.authenticate_user()

### Set Google Cloud project information and initialize Vertex AI SDK

To get started using Vertex AI, you must have an existing Google Cloud project and enable the Vertex AI API.

In [None]:
PROJECT_ID = "[your project id]"  # @param {type:"string"}
LOCATION = "us-central1"  # @param ["asia-northeast1", "asia-northeast3", "asia-southeast1", "europe-west1", "europe-west2", "europe-west3", "europe-west4", "europe-west9", "northamerica-northeast1", "us-central1", "us-east4", "us-west1", "us-west4"]

import vertexai

vertexai.init(project=PROJECT_ID, location=LOCATION)

print(f"Vertex AI SDK client initialized on project {PROJECT_ID} in {LOCATION}.")

In [None]:
# @title Import libraries and define helper functions
# @markdown Run this cell before proceeding to import libraries and define
# @markdown utility functions.
import io
import random

from PIL import Image, ImageDraw
from google.colab import files
import matplotlib.pyplot as plt
import numpy as np
from vertexai.preview.vision_models import (
    ImageSegmentationModel,
    ImageSegmentationResponse,
)
from vertexai.preview.vision_models import Image as Vertex_Image


# Resizes image bytes while preserving the aspect ratio.
def get_resized_bytes(image_bytes, max_size):
    image = Image.open(io.BytesIO(image_bytes))
    # Reads the file format from Pillow instead of the `imghdr` module, which
    # is deprecated and was removed in Python 3.13.
    image_format = image.format
    image.thumbnail((max_size, max_size))
    buffered = io.BytesIO()
    image.save(buffered, format=image_format)
    return buffered.getvalue()


# Displays PIL images in a row.
def display_row(images: list, figsize: tuple[int, int] = (12, 12)):
    fig, axes = plt.subplots(1, len(images), figsize=figsize, squeeze=False)

    for i, ax in enumerate(axes.ravel()):
        ax.imshow(images[i].convert("RGBA"))
        ax.axis("off")

    plt.show()


# Extracts masks from the response and overlays them onto the base image.
def overlay_masks(
    input_image: Image.Image, response: ImageSegmentationResponse
) -> Image.Image:
    # Make the original image grayscale to keep the overlaid masks visible.
    overlayed_image = input_image.copy().convert("L").convert("RGB")

    for mask in response.masks:
        mask_pil = mask._pil_image
        # Gives the mask a distinct color and transparent background.
        color = (
            random.randint(0, 255),
            random.randint(0, 255),
            random.randint(0, 255),
            128,
        )
        colored_mask = Image.new("RGBA", mask_pil.size, color)
        colored_mask = Image.composite(
            colored_mask, Image.new("RGBA", mask_pil.size), mask_pil
        )

        overlayed_image.paste(colored_mask, (0, 0), colored_mask)

    return overlayed_image


# Calculates the bounding box coordinates of the masked area in a mask image.
# Returns (x1, y1, x2, y2) with x2 and y2 exclusive, or None for an empty mask.
def get_bounding_box(mask: Image.Image) -> tuple | None:
    mask_array = np.array(mask.convert("L")) > 0
    rows, cols = np.nonzero(mask_array)
    if rows.size == 0:
        return None

    x1 = np.min(cols)
    y1 = np.min(rows)
    x2 = np.max(cols)
    y2 = np.max(rows)

    return (x1, y1, x2 + 1, y2 + 1)


# Converts a segmentation response to labeled bounding boxes.
def get_labeled_boxes(response: ImageSegmentationResponse) -> list:
    labeled_boxes = []
    for mask in response.masks:
        bounding_box = get_bounding_box(mask._pil_image)
        if bounding_box:
            entity_label = mask.labels[0]
            score = round(float(entity_label.score), 3)
            labeled_box = (entity_label.label, score, bounding_box)
            labeled_boxes.append(labeled_box)

    return labeled_boxes


# Draws bounding boxes on a base image around each generated mask.
def draw_bounding_boxes(
    base_image: Image.Image, response: ImageSegmentationResponse
) -> Image.Image:
    bbox_image = base_image.copy()
    labeled_boxes = get_labeled_boxes(response)
    color = "green"
    draw = ImageDraw.Draw(bbox_image)
    for box in labeled_boxes:
        bounding_box = box[2]
        draw.rectangle(bounding_box, outline=color, width=2)

        text_label = f"{box[0]}: {box[1]}"
        text_width = (len(text_label) * 5) + 3  # Approximate width plus 3 px padding
        text_height = 12
        label_x = bounding_box[0]
        label_y = bounding_box[1]  # Place the label at the box's top-left corner

        # Draw a filled rectangle as the background for the label
        draw.rectangle(
            (label_x, label_y, label_x + text_width, label_y + text_height),
            fill=color,
        )
        draw.text((label_x + 2, label_y), text_label, fill="white")

    return bbox_image


IMAGE_SEGMENTATION_MODEL = "image-segmentation-001"
segmentation_model = ImageSegmentationModel.from_pretrained(IMAGE_SEGMENTATION_MODEL)

## Select an image to segment

Run the following cell, then click the **Choose files** button to upload an image file from your local device.
Images larger than 640 pixels on either side are downscaled so that the longer side is 640 pixels, preserving the aspect ratio, for faster processing.
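As a rough sketch of the resize arithmetic, the hypothetical `fit_within` function below estimates the dimensions that `get_resized_bytes(raw_image_bytes, 640)` produces: it scales both sides by one shared factor so the longer side fits within `max_size`, and it never enlarges images that already fit. It is an approximation for illustration only; Pillow's `Image.thumbnail` applies its own rounding, so its results can differ by a pixel.

```python
def fit_within(width: int, height: int, max_size: int) -> tuple[int, int]:
    """Estimate an image's size after fitting it inside a max_size square.

    Hypothetical helper for illustration; it mirrors the idea behind
    PIL.Image.thumbnail, not its exact rounding.
    """
    # One shared scale factor preserves the aspect ratio; capping it at 1.0
    # leaves images that already fit at their original size.
    scale = min(max_size / width, max_size / height, 1.0)
    # Round to whole pixels and keep each side at least 1 pixel long.
    return max(1, round(width * scale)), max(1, round(height * scale))


print(fit_within(1280, 960, 640))  # → (640, 480)  landscape: width is longer
print(fit_within(960, 1920, 640))  # → (320, 640)  portrait: height is longer
print(fit_within(500, 300, 640))   # → (500, 300)  already fits: unchanged
```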

In [None]:
images = files.upload()
raw_image_bytes = list(images.values())[0]
resized_bytes = get_resized_bytes(raw_image_bytes, 640)

BASE_IMAGE = Vertex_Image(resized_bytes)
BASE_IMAGE_PIL = BASE_IMAGE._pil_image
w, h = BASE_IMAGE_PIL.size

display_row([BASE_IMAGE_PIL], (6, 6))
print(f"Base image width x height: {w} x {h}")

## Segment images using different modes

You can generate image masks with different Image Segmentation features by setting the `mode` field to one of the available options:
* **Foreground**: Generate a mask of the segmented foreground of the image.
* **Background**: Generate a mask of the segmented background of the image.
* **Semantic**: Select the items in an image to segment from a set of 194 classes.
* **Prompt**: Use an open-vocabulary text prompt to guide the image segmentation.
* **Interactive**: Draw a rough mask to guide the model segmentation.
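Each mode relies on different inputs. Based on the parameter comments in the foreground segmentation request cell later in this notebook (`prompt` is used by `semantic` and `prompt` modes, and `scribble` by `interactive` mode), a client-side check might look like the sketch below. `check_mode_args` and `VALID_MODES` are hypothetical names, not part of the Vertex AI SDK, and the service may apply its own validation rules.

```python
VALID_MODES = ("foreground", "background", "semantic", "prompt", "interactive")


def check_mode_args(mode, prompt=None, scribble=None):
    """Raise ValueError if a mode is missing the input it relies on.

    Hypothetical helper: these rules are inferred from this notebook's
    parameter comments, not from an official SDK validation routine.
    """
    if mode not in VALID_MODES:
        raise ValueError(f"mode must be one of {VALID_MODES}, got {mode!r}")
    # Semantic and prompt modes are guided by a comma-separated text prompt.
    if mode in ("semantic", "prompt") and not prompt:
        raise ValueError(f"{mode!r} mode needs a comma-separated prompt")
    # Interactive mode is guided by a user-drawn scribble.
    if mode == "interactive" and scribble is None:
        raise ValueError("'interactive' mode needs a scribble")


check_mode_args("foreground")
check_mode_args("semantic", prompt="motorcycle, bus")
try:
    check_mode_args("prompt")
except ValueError as err:
    print(err)  # → 'prompt' mode needs a comma-separated prompt
```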

### Foreground segmentation request

This section generates a foreground mask by setting `mode="foreground"`. The cell lists the main `segment_image` parameters so that you can adjust them for the other modes.

In [None]:
gcs_uri = None  # gs:// path to the input image
mode = "foreground"  # Segmentation mode [foreground,background,semantic,prompt,interactive]
prompt = None  # Prompt to guide segmentation for `semantic` and `prompt` modes
scribble = None  # Input scribble for `interactive` segment mode
mask_dilation = (
    None  # Optional mask dilation for thin objects. Numeric value between 0 and 1.
)
max_predictions = (
    None  # Optional maximum predictions limit for prompt mode. Unlimited by default.
)
confidence_threshold = (
    None  # Optional confidence limit for prompt/background/foreground modes.
)

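# Passes optional parameters by keyword so that each value reaches the
# intended parameter regardless of the SDK's positional order.
response = segmentation_model.segment_image(
    BASE_IMAGE,
    prompt=prompt,
    scribble=scribble,
    mode=mode,
    mask_dilation=mask_dilation,
    max_predictions=max_predictions,
    confidence_threshold=confidence_threshold,
)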
MASK_PIL = response.masks[0]._pil_image
display_row([BASE_IMAGE_PIL, MASK_PIL])

#### Background removal
Use the foreground segmentation mask you created above to make the image background transparent.
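The next cell uses the mask as the image's alpha channel with `putalpha`. Per pixel, the RGB values are kept and the 8-bit mask value becomes the alpha value: white mask pixels (255) stay fully opaque and black ones (0) become fully transparent. The hypothetical `mask_to_rgba` function below spells this out on plain Python lists, assuming an 8-bit grayscale mask with the same number of pixels as the image.

```python
def mask_to_rgba(rgb_pixels, mask_values):
    """Attach 8-bit mask values (0-255) to RGB pixels as their alpha channel."""
    if len(rgb_pixels) != len(mask_values):
        raise ValueError("image and mask must have the same number of pixels")
    return [(r, g, b, m) for (r, g, b), m in zip(rgb_pixels, mask_values)]


pixels = [(200, 100, 50), (10, 20, 30)]  # one foreground, one background pixel
mask = [255, 0]                          # white = keep, black = make transparent
print(mask_to_rgba(pixels, mask))        # → [(200, 100, 50, 255), (10, 20, 30, 0)]
```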

In [None]:
# Copies the base image into a new RGBA image so that it has an alpha channel.
transparent_background = BASE_IMAGE_PIL.convert("RGBA")

# Uses the foreground mask as the alpha channel: white (foreground) pixels stay
# opaque and black (background) pixels become fully transparent.
transparent_background.putalpha(MASK_PIL)

display_row([transparent_background], (6, 6))

#### Change background color
Use the same foreground mask to place the object on a solid-color background.
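Pasting through a mask, as `paste(BASE_IMAGE_PIL, mask=MASK_PIL)` does in the next cell, blends each pixel of the base image with the background color in proportion to the mask value `m`: per channel, `out = (fg * m + bg * (255 - m)) / 255`. The hypothetical `blend_pixel` function below sketches that arithmetic for an 8-bit mask; Pillow uses its own integer rounding, so its values may differ by 1.

```python
def blend_pixel(fg, bg, m):
    """Blend a foreground pixel over a background pixel with mask value m (0-255)."""
    if not 0 <= m <= 255:
        raise ValueError("mask value must be in the range 0-255")
    # m = 255 keeps the foreground, m = 0 keeps the background, and values in
    # between (for example on soft mask edges) mix the two linearly.
    return tuple(round((f * m + b * (255 - m)) / 255) for f, b in zip(fg, bg))


light_blue = (141, 224, 254, 255)
print(blend_pixel((200, 100, 50, 255), light_blue, 255))  # → (200, 100, 50, 255)
print(blend_pixel((200, 100, 50, 255), light_blue, 0))    # → (141, 224, 254, 255)
```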

In [None]:
# Light blue, as an RGBA color
color = (141, 224, 254, 255)
colored_background = Image.new("RGBA", BASE_IMAGE_PIL.size, color)
# Pastes the foreground object onto the solid-color background through the mask.
colored_background.paste(BASE_IMAGE_PIL, mask=MASK_PIL)

display_row([colored_background], (6, 6))

### Background segment mode

Set `mode="background"` to generate a mask that covers the background of the image instead of the foreground object.

In [None]:
response = segmentation_model.segment_image(
    BASE_IMAGE, mode="background", mask_dilation=None
)
MASK_PIL = response.masks[0]._pil_image
display_row([BASE_IMAGE_PIL, MASK_PIL])

### Semantic segment mode

In semantic mode, use `prompt` to specify which objects to segment from the set of 194 classes. The full set is listed in the Appendix at the end of this tutorial. To request multiple classes, separate them with commas, for example `prompt="cat, dog"`.

The semantic segmenter returns a single prediction containing the generated mask. Pixels belonging to any detected class from the prompt are white, and all other pixels are black. If none of the requested classes are detected in the image, the whole mask is black.
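Two small helpers can make semantic requests easier to work with. Both are hypothetical sketches: `build_semantic_prompt` joins class names into the comma-separated `prompt` and rejects names outside a class set (here a small subset copied from the Appendix table; extend it with the classes you need), and `mask_is_empty` reports whether a mask is entirely black, which, as described above, means none of the requested classes were detected.

```python
# A few class names copied from the Appendix table; this subset is illustrative.
SEMANTIC_CLASSES = {"bicycle", "bus", "car", "motorcycle", "person", "truck"}


def build_semantic_prompt(classes):
    """Join class names into a comma-separated prompt, rejecting unknown names."""
    unknown = [name for name in classes if name not in SEMANTIC_CLASSES]
    if unknown:
        raise ValueError(f"not in the semantic class set: {unknown}")
    return ", ".join(classes)


def mask_is_empty(mask_rows):
    """Return True if every value in a 2-D grid of mask pixels is 0 (black)."""
    return not any(value > 0 for row in mask_rows for value in row)


print(build_semantic_prompt(["motorcycle", "bus"]))  # → motorcycle, bus
print(mask_is_empty([[0, 0], [0, 0]]))              # → True
print(mask_is_empty([[0, 255], [0, 0]]))            # → False
```

In the notebook, `mask_rows` could be `np.array(MASK_PIL.convert("L"))`; for a grayscale PIL image, `MASK_PIL.convert("L").getbbox() is None` is an equivalent check.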

In [None]:
response = segmentation_model.segment_image(
    BASE_IMAGE, prompt="motorcycle, bus", mode="semantic", mask_dilation=None
)
MASK_PIL = response.masks[0]._pil_image
display_row([BASE_IMAGE_PIL, MASK_PIL])

### Prompt instance segmentation mode

Use prompt mode to detect and segment every instance of the objects named in your prompt. The response can contain multiple masks, each with one or more associated labels, and each label includes a confidence score. Only objects matching labels in the request prompt are detected and segmented. The prompt is fully open-vocabulary: it is not limited to any fixed class set.

**Recommended**:
* Use the `confidence_threshold` and `max_predictions` parameters to filter and limit the results.
* Request multiple objects by separating them with commas. A single prompt can contain hundreds of classes.
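To compare several thresholds without sending a new request each time, you could filter the results on the client instead. The sketch below is a hypothetical helper, not the service's own filtering logic: it works on the `(label, score, bounding_box)` tuples that `get_labeled_boxes(response)` returns, keeps detections whose score is at least `threshold`, sorts them from highest to lowest score, and optionally keeps only the top `max_predictions`. It can only narrow results down, because the request's `confidence_threshold` has already dropped lower-scoring masks.

```python
def filter_detections(detections, threshold, max_predictions=None):
    """Keep (label, score, box) tuples scoring >= threshold, best first."""
    kept = [d for d in detections if d[1] >= threshold]
    kept.sort(key=lambda d: d[1], reverse=True)  # highest confidence first
    if max_predictions is not None:
        kept = kept[:max_predictions]
    return kept


# Tuples in the shape returned by get_labeled_boxes; illustrative values only.
detections = [
    ("cantaloupe", 0.42, (10, 20, 110, 120)),
    ("price tag", 0.15, (200, 40, 240, 60)),
    ("green watermelon", 0.87, (120, 30, 300, 200)),
]
print(filter_detections(detections, threshold=0.3))
print(filter_detections(detections, threshold=0.1, max_predictions=1))
```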

In [None]:
prompt = "green watermelon, cantaloupe, price tag"
threshold = 0.1

response = segmentation_model.segment_image(
    BASE_IMAGE,
    prompt=prompt,
    mode="prompt",
    mask_dilation=None,
    max_predictions=None,
    confidence_threshold=threshold,
)

count = str(len(response.masks))
print(f"Detected {count} objects at threshold {threshold}.")

bbox_image = draw_bounding_boxes(BASE_IMAGE_PIL, response)
overlayed_image = overlay_masks(BASE_IMAGE_PIL, response)
display_row([BASE_IMAGE_PIL, bbox_image, overlayed_image], figsize=(25, 25))

## Appendix

### Semantic segmentation classes

| Class ID | Class ID | Class ID | Class ID |
| --- | --- | --- | --- |
|   backpack  |   broccoli  |   road  |   mountain_hill   |
|   umbrella  |   carrot  |   snow  |   rock    |
|   bag |   hot_dog |   sidewalk_pavement |   frisbee   |
|   tie |   pizza |   runway  |   skis    |
|   suitcase  |   donut |   terrain |   snowboard   |
|   case  |   cake  |   book  |   sports_ball   |
|   bird  |   fruit_other |   box |   kite    |
|   cat |   food_other  |   clock |   baseball_bat    |
|   dog |   chair_other |   vase  |   baseball_glove    |
|   horse |   armchair  |   scissors  |   skateboard    |
|   sheep |   swivel_chair  |   plaything_other |   surfboard   |
|   cow |   stool |   teddy_bear  |   tennis_racket   |
|   elephant  |   seat  |   hair_dryer  |   net   |
|   bear  |   couch |   toothbrush  |   base    |
|   zebra |   trash_can |   painting  |   sculpture   |
|   giraffe |   potted_plant  |   poster  |   column    |
|   animal_other  |   nightstand  |   bulletin_board  |   fountain    |
|   microwave |   bed |   bottle  |   awning    |
|   radiator  |   table |   cup |   apparel   |
|   oven  |   pool_table  |   wine_glass  |   banner    |
|   toaster |   barrel  |   knife |   flag    |
|   storage_tank  |   desk  |   fork  |   blanket   |
|   conveyor_belt |   ottoman |   spoon |   curtain_other   |
|   sink  |   wardrobe  |   bowl  |   shower_curtain    |
|   refrigerator  |   crib  |   tray  |   pillow    |
|   washer_dryer  |   basket  |   range_hood  |   towel   |
|   fan |   chest_of_drawers  |   plate |   rug_floormat    |
|   dishwasher  |   bookshelf |   person  |   vegetation    |
|   toilet  |   counter_other |   rider_other |   bicycle   |
|   bathtub |   bathroom_counter  |   bicyclist |   car   |
|   shower  |   kitchen_island  |   motorcyclist  |   autorickshaw    |
|   tunnel  |   door  |   paper |   motorcycle    |
|   bridge  |   light_other |   streetlight |   airplane    |
|   pier_wharf  |   lamp  |   road_barrier  |   bus   |
|   tent  |   sconce  |   mailbox |   train   |
|   building  |   chandelier  |   cctv_camera |   truck   |
|   ceiling |   mirror  |   junction_box  |   trailer   |
|   laptop  |   whiteboard  |   traffic_sign  |   boat_ship   |
|   keyboard  |   shelf |   traffic_light |   slow_wheeled_object   |
|   mouse |   stairs  |   fire_hydrant  |   river_lake    |
|   remote  |   escalator |   parking_meter |   sea   |
|   cell phone  |   cabinet |   bench |   water_other   |
|   television  |   fireplace |   bike_rack |   swimming_pool   |
|   floor |   stove |   billboard |   waterfall   |
|   stage |   arcade_machine  |   sky |   wall    |
|   banana  |   gravel  |   pole  |   window    |
|   apple |   platform  |   fence |   window_blind    |
|   sandwich  |   playingfield  |   railing_banister  |       |
|   orange  |   railroad  |   guard_rail  |       |
