# Image Segmentation

In [None]:
# Libraries required

# !pip install transformers
# !pip install gradio
# !pip install timm
# !pip install torchvision

In [None]:
from transformers.utils import logging
# Restrict transformers log output to error-level messages.
logging.set_verbosity_error()

## Mask Generation with SAM

The Segment Anything Model (SAM) was released by Meta AI.

In [None]:
from transformers import pipeline

In [None]:
# Build a mask-generation pipeline from the locally stored SlimSAM checkpoint.
sam_pipe = pipeline(
    task="mask-generation",
    model="./models/Zigeng/SlimSAM-uniform-77",
)

In [None]:
from PIL import Image

In [None]:
# Load the picture that the SAM examples below operate on.
raw_image = Image.open('meta_llamas.jpg')
# Preview at reduced size; the result is not assigned, so raw_image itself is unchanged.
raw_image.resize((720, 375))

The higher the value of `points_per_batch`, the more efficient pipeline inference will be, at the cost of higher memory use.

In [None]:
# Run the mask-generation pipeline on the full image, 32 prompt points per batch.
output = sam_pipe(raw_image, points_per_batch=32)

In [None]:
from helper import show_pipe_masks_on_image

In [None]:
# Visualize the pipeline result using the function imported from the local `helper` module.
show_pipe_masks_on_image(raw_image, output)

## Faster Inference: Infer an Image and a Single Point

In [None]:
from transformers import SamModel, SamProcessor

In [None]:
# The model weights and the matching processor are loaded from the same
# local SlimSAM checkpoint directory.
sam_checkpoint = "./models/Zigeng/SlimSAM-uniform-77"

processor = SamProcessor.from_pretrained(sam_checkpoint)
model = SamModel.from_pretrained(sam_checkpoint)

In [None]:
# Preview only: the resized copy is displayed but not assigned, so raw_image is unchanged.
raw_image.resize((720, 375))

Segment the blue shirt Andrew is wearing.
Give any single 2D point that would be in that region (blue shirt).

In [None]:
# A single 2D prompt point, wrapped in three list levels as passed to the processor below.
# NOTE(review): presumably (x, y) pixel coordinates in the full-size raw_image, not the preview — confirm.
input_points = [[[1600, 700]]]

In [None]:
# return_tensors="pt" asks the processor for PyTorch tensors, which is the
# format the model call below takes as input.
inputs = processor(
    images=raw_image,
    input_points=input_points,
    return_tensors="pt",
)

In [None]:
import torch

In [None]:
# Inference only: run the forward pass without tracking gradients.
with torch.no_grad():
    # `inputs` behaves like a dictionary, so ** unpacks its entries into keyword arguments.
    outputs = model(**inputs)

In [None]:
# Turn the raw mask predictions into masks sized for the original image,
# using the size bookkeeping the processor stored in `inputs`.
predicted_masks = processor.image_processor.post_process_masks(
    masks=outputs.pred_masks,
    original_sizes=inputs["original_sizes"],
    reshaped_input_sizes=inputs["reshaped_input_sizes"],
)

Length of predicted_masks corresponds to the number of images that are used in the input.

In [None]:
len(predicted_masks)

Inspecting the size of the first ([0]) predicted mask

In [None]:
# Take the masks belonging to the first input image and display their shape.
predicted_mask = predicted_masks[0]
predicted_mask.shape

torch.Size([1, 3, 1500, 2880])

In [None]:
# One score per candidate mask.
# NOTE(review): presumably the model's own predicted IoU (mask quality) estimates — confirm.
outputs.iou_scores

tensor([[[0.9583, 0.9551, 0.9580]]])

In [None]:
from helper import show_mask_on_image

In [None]:
# Show every candidate mask returned for the prompt point. The number of
# masks is read from the tensor's second dimension instead of being
# hard-coded to 3, so the loop stays correct if the model returns a
# different number of candidates.
for i in range(predicted_mask.shape[1]):
    show_mask_on_image(raw_image, predicted_mask[:, i])

# Depth Estimation with DPT

In [None]:
# Build a depth-estimation pipeline from the locally stored DPT checkpoint.
depth_estimator = pipeline(
    "depth-estimation",
    model="./models/Intel/dpt-hybrid-midas",
)

In [None]:
# NOTE: this rebinds raw_image to a different picture; re-running the SAM cells
# above after this point would operate on this image instead of meta_llamas.jpg.
raw_image = Image.open('gradio_tamagochi_vienna.png')
# Preview at reduced size; the result is not assigned, so raw_image itself is unchanged.
raw_image.resize((806, 621))

In [None]:
# Run depth estimation on the image. This reuses the name `output` that held
# the SAM pipeline result earlier in the notebook.
output = depth_estimator(raw_image)

In [None]:
output

Post-process the output image to resize it to the size of the original image.

In [None]:
output["predicted_depth"].shape

torch.Size([1, 384, 384])

In [None]:
output["predicted_depth"].unsqueeze(1).shape

torch.Size([1, 1, 384, 384])

In [None]:
# Resize the predicted depth map to the resolution of the original image.
# unsqueeze(1) inserts an extra dimension so the tensor is 4-D, and
# raw_image.size is reversed to give the target `size` in the opposite order
# from the one PIL reports.
prediction = torch.nn.functional.interpolate(
    output["predicted_depth"].unsqueeze(1),
    size=raw_image.size[::-1],
    mode="bicubic",
    align_corners=False,
)

In [None]:
prediction.shape

torch.Size([1, 1, 1242, 1612])

In [None]:
# Target size passed to interpolate above. The stray trailing comma was
# removed so the value is shown as-is instead of wrapped in a 1-tuple.
raw_image.size[::-1]

(1242, 1612)

In [None]:
prediction

Normalize the predicted depth values to the range 0–255 so that they can be displayed as an image.

In [None]:
import numpy as np

In [None]:
# Convert the resized depth map into an 8-bit grayscale PIL image.
# The NumPy array gets its own name so the pipeline result stored in `output`
# is not overwritten and the earlier cells that read it stay re-runnable.
depth_array = prediction.squeeze().numpy()
max_depth = depth_array.max()
if max_depth > 0:
    # Scale so the largest value maps to 255.
    scaled = depth_array * 255 / max_depth
else:
    # Degenerate map with no positive values: avoid dividing by zero
    # (or by a negative maximum) and fall back to an all-black image.
    scaled = np.zeros_like(depth_array)
# Bicubic resizing can overshoot slightly below zero; clip before the cast so
# such pixels become 0 instead of wrapping around in uint8.
formatted = np.clip(scaled, 0, 255).astype("uint8")
depth = Image.fromarray(formatted)

In [None]:
depth

## Making Gradio Interface

In [None]:
import os
import gradio as gr
from transformers import pipeline

In [None]:
def launch(input_image):
    """Estimate a depth map for an image and return it as a grayscale PIL image.

    Args:
        input_image: PIL image supplied by the Gradio image input, or ``None``
            when the form is submitted without an image.

    Returns:
        An 8-bit PIL image with the same width and height as ``input_image``,
        scaled so that the largest predicted value maps to 255.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        # Without this guard the pipeline call fails with an opaque traceback.
        raise gr.Error("Please upload an image first.")

    out = depth_estimator(input_image)

    # Resize the prediction to the input resolution. The PIL size is reversed
    # to give the target `size` in the order interpolate receives it.
    prediction = torch.nn.functional.interpolate(
        out["predicted_depth"].unsqueeze(1),
        size=input_image.size[::-1],
        mode="bicubic",
        align_corners=False,
    )

    # Normalize the prediction to the 0-255 range.
    depth_array = prediction.squeeze().numpy()
    max_depth = depth_array.max()
    if max_depth > 0:
        scaled = depth_array * 255 / max_depth
    else:
        # Degenerate map with no positive values: avoid dividing by zero.
        scaled = np.zeros_like(depth_array)
    # Bicubic resizing can overshoot below zero; clip so those pixels become 0
    # instead of wrapping around when cast to uint8.
    formatted = np.clip(scaled, 0, 255).astype("uint8")
    return Image.fromarray(formatted)

In [None]:
# Wrap the depth-estimation function in an image-in / image-out Gradio UI.
iface = gr.Interface(
    fn=launch,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
)

In [None]:
# Serve on the port named by the PORT1 environment variable when it is set;
# otherwise pass None so Gradio chooses its default port instead of this cell
# failing with a KeyError.
# NOTE(review): share=True requests a publicly reachable link — confirm this is intended.
port_env = os.environ.get("PORT1")
iface.launch(share=True, server_port=int(port_env) if port_env else None)

In [None]:
# Shut down the interface started by iface.launch() above.
iface.close()