In [None]:
import io

import requests
from PIL import Image
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation


In [None]:
import matplotlib.pyplot as plt
import os
import numpy as np


# Load the input image

In [None]:
# Path of the input image. If this file does not exist, the loading cell below
# downloads the ADE20K demo image and saves it to ../../data/ade20k.jpeg.
image_name = "../../data/ade20k.jpeg"
# Alternative local inputs -- uncomment one to segment it instead:
# image_name = "../../data/cats.jpg"
# image_name = "../../data/penguins.jpg"


In [None]:
def download_image(url, timeout=30):
    """Download an image over HTTP(S) and return it as a fully decoded PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server (forwarded to ``requests.get``)
            so a stalled connection cannot hang the notebook indefinitely.

    Returns:
        A ``PIL.Image.Image`` whose pixel data has already been read into memory.

    Raises:
        requests.HTTPError: If the server answers with an error status code.
        requests.Timeout: If the server does not respond within ``timeout``.
        PIL.UnidentifiedImageError: If the payload is not a readable image.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()  # Check for HTTP errors
    # ``response.content`` undoes gzip/deflate transfer encoding, which the
    # undecoded ``response.raw`` stream does not. Decoding eagerly with load()
    # means the returned image no longer depends on an open network stream.
    image = Image.open(io.BytesIO(response.content))
    image.load()
    return image


In [None]:
def save_image(image, image_name):
    """Save ``image`` to ``image_name``, creating the parent directory if needed.

    Args:
        image: Object with a ``save(path)`` method, e.g. a ``PIL.Image.Image``.
        image_name: Destination path. May be a bare filename (current directory).
    """
    parent_dir = os.path.dirname(image_name)
    # A bare filename has an empty directory part, and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when one is given.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    image.save(image_name)

In [None]:
# Read the image if the file exists; otherwise download the ADE20K demo image,
# cache it at ../../data/ade20k.jpeg, and use that instead.
if os.path.exists(image_name):
    image = Image.open(image_name).convert("RGB")
else:
    requested_image_name = image_name
    image_name = "../../data/ade20k.jpeg"
    url = "https://huggingface.co/datasets/shi-labs/oneformer_demo/resolve/main/ade20k.jpeg"
    if requested_image_name != image_name:
        # Make the substitution visible instead of silently segmenting a
        # different picture than the one configured above.
        print(f"{requested_image_name} not found; using the ADE20K demo image instead.")
    image = download_image(url)
    save_image(image, image_name)
    # Normalise to 3-channel RGB (grayscale/alpha inputs would otherwise
    # reach the processor with an unexpected channel count).
    image = image.convert("RGB")

In [None]:
# Display the loaded input image inline as this cell's output.
image

# Run OneFormer segmentation

In [None]:
# Hugging Face Hub checkpoint to load (name suggests a COCO-trained Swin-L
# OneFormer -- confirm on the model card).
model_name = "shi-labs/oneformer_coco_swin_large"
# Task token passed to the processor; the post-processing cells below branch
# on "semantic" and "panoptic".
task_type = "panoptic"

In [None]:
# Fetch the preprocessing pipeline and the model weights for the chosen
# checkpoint exactly once; every later cell reuses these two objects.
processor = OneFormerProcessor.from_pretrained(model_name)
model = OneFormerForUniversalSegmentation.from_pretrained(model_name)

In [None]:
# Preprocess the image and attach the task token that tells OneFormer which
# kind of segmentation to produce; tensors come back in PyTorch format.
inputs = processor(
    images=image,
    task_inputs=[task_type],
    return_tensors="pt",
)

In [None]:
# Show the model's repr (its module structure) as this cell's output.
model

In [None]:
import torch

In [None]:
# Forward pass only: autograd tracking is switched off because no gradients
# are needed for inference.
with torch.no_grad():
    outputs = model(**inputs)


In [None]:
# List the fields carried by the raw model output (the f-string "=" specifier
# prints both the expression and its value).
print(f"{outputs.keys()=}")

In [None]:
# Turn raw model outputs into a per-pixel map at the original image size.
# PIL's image.size is (width, height); target_sizes expects (height, width).
if task_type == "semantic":
    predicted_map = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]])[0]
    # The semantic result is only a label map. Reset the panoptic-only names so
    # later cells never pick up stale values left by an earlier panoptic run.
    prediction = None
    segments_info = None
elif task_type == "panoptic":
    prediction = processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image.size[::-1]])[0]
    predicted_map = prediction["segmentation"]
    segments_info = prediction["segments_info"]
else:
    # Fail here with a clear message instead of a NameError on predicted_map later.
    raise ValueError(
        f"Unsupported task_type {task_type!r}; expected 'semantic' or 'panoptic'.")

In [None]:
# Sanity-check the post-processed result: map size and the ids/labels it contains.
print(f"{predicted_map.shape=}")
print(f"{predicted_map.unique()=}")
if task_type == "panoptic":
    print(f"{segments_info=}")
    print(f"{prediction.keys()=}")

In [None]:
# List each predicted segment id with its class name. Only the panoptic branch
# produces segments_info; checking task_type (like the neighbouring cells)
# avoids a NameError or stale output when the task is "semantic".
if task_type == "panoptic":
    for segment in segments_info:
        label = model.config.id2label[segment["label_id"]]
        print(f"segment id = {segment['id']} : {label}")

In [None]:
# Ensure the output directory for the saved figures exists (no error if it already does).
os.makedirs("../../result", exist_ok=True)

In [None]:
# Show the input image next to the predicted segmentation map and save the figure.
fig, (ax_image, ax_segm) = plt.subplots(1, 2, figsize=(12, 6))
ax_image.imshow(image)
ax_image.set_title("Original Image")
ax_image.axis("off")
ax_segm.imshow(predicted_map)
ax_segm.set_title(f"{task_type.capitalize()} Segmentation")
ax_segm.axis("off")
# Save BEFORE plt.show(): with the inline backend show() closes the figure,
# so a pyplot-level savefig() afterwards would write an empty canvas.
fig.savefig("../../result/oneformer_segm.png")
plt.show()

In [None]:
# Helper for placing each segment's label text at the segment's centroid.
def calculate_centroid(mask):
    """Return the (x, y) centroid of the True pixels in a 2-D boolean mask.

    Args:
        mask: 2-D boolean array-like (NumPy array or torch tensor, which is
            moved to CPU first).

    Returns:
        Tuple ``(x, y)`` of floats: mean column index and mean row index.
        ``(nan, nan)`` when the mask has no True pixels.
    """
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    indices = np.argwhere(mask)  # one (row, col) pair per True pixel
    if indices.size == 0:
        # np.mean of an empty array would emit RuntimeWarnings; be explicit.
        return float("nan"), float("nan")
    row, col = indices.mean(axis=0)
    return float(col), float(row)  # Return x, y coordinates

# Draw the panoptic map and write each segment's class name at its centroid.
if task_type == "panoptic":
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.imshow(predicted_map)
    for segment in segments_info:
        label = model.config.id2label[segment["label_id"]]
        mask = predicted_map == segment["id"]  # binary mask for this segment
        if not mask.any():
            # No pixels to label; the centroid would be NaN.
            continue
        centroid_x, centroid_y = calculate_centroid(mask)
        # NOTE: the centroid of a non-convex segment can lie outside the segment.
        # A light box keeps black text readable on dark colormap regions.
        ax.text(centroid_x, centroid_y, label, fontsize=12, color="black",
                bbox=dict(facecolor="white", alpha=0.6, edgecolor="none", pad=1))
    ax.set_title("Panoptic Segmentation with Label")
    ax.axis("off")
    # Save BEFORE plt.show(), which closes the figure under the inline backend.
    fig.savefig("../../result/oneformer_panoptic.png")
    plt.show()