In [None]:
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
from PIL import Image
import requests
import matplotlib.pyplot as plt
import os


# Load image (local file if present, otherwise download)

In [None]:
image_name = "../../data/ade20k.jpeg"
# Prefer the local copy; otherwise fall back to downloading the demo image.
if os.path.exists(image_name):
    image = Image.open(image_name)
else:
    url = "https://huggingface.co/datasets/shi-labs/oneformer_demo/resolve/main/ade20k.jpeg"
    # A timeout keeps the cell from hanging forever on a stalled connection.
    response = requests.get(url, stream=True, timeout=30)
    response.raise_for_status()  # Fail loudly on HTTP errors
    # The raw stream is not content-decoded by default; enable decoding so a
    # gzip/deflate-encoded response is still handed to PIL as plain image bytes.
    response.raw.decode_content = True
    image = Image.open(response.raw)

In [None]:
# Bare last expression: the notebook renders the loaded image as the cell output.
image

# Run semantic segmentation with OneFormer

In [None]:
# Segmentation task to run.
task_type = "semantic"
# Checkpoint identifier shared by the processor and the model loaded below.
model_name = "shi-labs/oneformer_coco_swin_large"

In [None]:
# Load the processor and the model once, both from the same checkpoint name;
# the later cells reuse these two objects.
processor = OneFormerProcessor.from_pretrained(model_name)
model = OneFormerForUniversalSegmentation.from_pretrained(model_name)

In [None]:
# Build the model inputs. The task prompt is taken from `task_type` so that the
# configuration cell is the single place where the task is chosen (previously
# the string was hard-coded here and `task_type` was silently ignored).
inputs = processor(images=image, task_inputs=[task_type], return_tensors="pt")


In [None]:
# Bare last expression: shows the model's repr as the cell output for inspection.
model

In [None]:
import torch

In [None]:
# Forward pass only: gradient tracking is switched off for the duration of the call.
with torch.no_grad():
    outputs = model(**inputs)


In [None]:
# PIL reports size as (width, height); the post-processing step is given the
# target size in the reversed order so the map matches the original image.
width, height = image.size
segmentation_maps = processor.post_process_semantic_segmentation(
    outputs, target_sizes=[(height, width)]
)
# A single image was processed, so keep the first (only) map.
predicted_map = segmentation_maps[0]

In [None]:
# Side-by-side comparison of the input image and the predicted segmentation map.
output_path = "../../result/oneformer_segm.png"
fig, (ax_image, ax_segm) = plt.subplots(1, 2, figsize=(12, 6))

ax_image.imshow(image)
ax_image.set_title("Original Image")
ax_image.axis("off")

ax_segm.imshow(predicted_map)
ax_segm.set_title("Segmentation")
ax_segm.axis("off")

# Save BEFORE plt.show(): once the figure has been shown it may already be
# closed, in which case a later savefig() writes an empty canvas.
# Create the target directory first so saving does not fail on a fresh checkout.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
fig.savefig(output_path)
plt.show()