In [None]:
import sys

import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from matplotlib.patches import Patch
from PIL import Image
from torch import nn
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

# Make the project-local `utils` module (one directory up) importable.
# NOTE(review): "../" is resolved against the kernel's current working
# directory, so this assumes the notebook is started from its own folder.
sys.path.append("../")
from utils import LABEL_COLORS_DINU, label2rgb

# Automatically reload edited modules (e.g. `utils`) before each cell runs,
# so changes to helper code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Pick the compute device, in order of preference:
#   "cuda" - NVIDIA or AMD GPUs
#   "mps"  - Apple Silicon (Metal Performance Shaders)
#   "cpu"  - fallback when no accelerator is available
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(device)

In [None]:
# Load the pretrained SegFormer face-parsing model and its matching image
# processor, and move the model to the selected device.
MODEL_ID = "jonathandinu/face-parsing"
image_processor = SegformerImageProcessor.from_pretrained(MODEL_ID)
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_ID)
model.to(device)
model.eval()  # inference only: make sure train-time layers (dropout etc.) are off

# The processor expects a PIL.Image or torch.Tensor. Convert to RGB so that
# grayscale, CMYK or palette-mode files still yield 3-channel input; the
# `with` block releases the underlying file handle once the copy is made.
IMAGE_PATH = "../data/025_08.jpg"
with Image.open(IMAGE_PATH) as raw_image:
    image = raw_image.convert("RGB")

# Alternative: fetch a sample image from the web instead of the local file.
# url = "https://images.unsplash.com/photo-1539571696357-5a69c17a67c6"
# image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

print(image)

In [None]:
# Run a single forward pass. No gradients are needed, so disable autograd to
# avoid building a graph and keeping activations alive in memory.
with torch.no_grad():
    inputs = image_processor(images=image, return_tensors="pt").to(device)
    outputs = model(**inputs)
    logits = outputs.logits  # shape (batch_size, num_labels, ~height/4, ~width/4)

    # Resize the logits to the input image's resolution. PIL's `size` is
    # (width, height), so reverse it to get the (H, W) interpolate expects.
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],  # H x W
        mode="bilinear",
        align_corners=False,
    )

# Per-pixel class ids for the first (and only) image in the batch.
labels = upsampled_logits.argmax(dim=1)[0]

# Move to CPU as a NumPy array for matplotlib.
labels_viz = labels.cpu().numpy()
assert labels_viz.shape == image.size[::-1], "predicted mask must match the image's (H, W)"
print(labels_viz.shape)

In [None]:
# Build a legend with one colored patch per class. Matplotlib expects RGB
# floats in [0, 1]; if the palette is stored as 0-255 values, rescale it.
# The decision is made once for the whole palette so that dark colors such
# as (1, 0, 0) in 0-255 space are not mistaken for pure [0, 1] red.
palette = LABEL_COLORS_DINU
color_scale = 255.0 if any(max(rgb) > 1 for rgb in palette.values()) else 1.0
legend_patches = [
    Patch(color=np.asarray(rgb, dtype=float) / color_scale, label=label)
    for label, rgb in palette.items()
]

# Side-by-side view: input image on the left, predicted mask on the right.
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
ax_image, ax_mask = axs

ax_image.imshow(image)
ax_image.set_title("Image")
ax_mask.imshow(label2rgb(labels_viz, author="dinu"))
ax_mask.set_title("Pred Mask")
ax_mask.legend(handles=legend_patches, bbox_to_anchor=(1.05, 1.05), loc="upper left", title="Classes")

# Pixel-coordinate ticks carry no information for these views.
for ax in axs:
    ax.axis("off")

plt.show()