In [None]:
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Patch
from PIL import Image
from torchvision import transforms

sys.path.append("../")
from model import BiSeNet
from utils import LABEL_COLORS, label2rgb

%load_ext autoreload
%autoreload 2

In [None]:
# Load data: one CelebAMask-HQ sample (RGB image + label mask) selected by its numeric ID.
# Path.home() is portable; os.environ["HOME"] is not defined on every platform (e.g. Windows).
DATA_ROOT = Path.home() / "Data" / "FacialAttributes" / "CelebAMask-HQ"
NR = 28973  # sample ID; files are named f"{NR}.jpg" (image) and f"{NR}.png" (mask)
img_p = DATA_ROOT / "CelebA-HQ-img"
mask_p = DATA_ROOT / "mask"

# Force 3-channel RGB: the model input normalization uses a 3-channel mean/std,
# so grayscale / RGBA / palette images would otherwise break inference.
pil_img = Image.open(img_p / f"{NR}.jpg").convert("RGB")
gt_mask = Image.open(mask_p / f"{NR}.png")

fig, ax = plt.subplots()
ax.imshow(pil_img)
ax.set_title(f"Sample {NR}")
ax.axis("off")
plt.show()

In [None]:
# Set up model
# Fall back to CPU when no GPU is present instead of crashing on a hard-coded .cuda().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
N_CLASSES = 19  # must match the output head of the checkpoint below
INPUT_SIZE = (512, 512)  # (width, height) the image is resized to before inference
CKPT_PATH = "../res/cp/79999_iter.pth"

net = BiSeNet(n_classes=N_CLASSES)
net.to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
net.load_state_dict(torch.load(CKPT_PATH, map_location=device))
net.eval()

# ImageNet mean/std normalization on a 3-channel RGB tensor.
to_tensor = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)

# Run inference
with torch.no_grad():
    # convert("RGB") guards the 3-channel normalization against non-RGB inputs.
    image = pil_img.convert("RGB").resize(INPUT_SIZE, Image.BILINEAR)
    img = to_tensor(image).unsqueeze(0).to(device)  # add a batch dimension of 1
    # NOTE(review): assumes net(...) returns a tuple whose first item is the
    # (1, N_CLASSES, H, W) logits tensor — TODO confirm in model.py.
    out = net(img)[0]
    # Take the per-pixel argmax on the device and copy only the (H, W) label map to host.
    parsing = out.squeeze(0).argmax(0).cpu().numpy()

parsing.shape

In [None]:
# Compare the input image, the ground-truth mask and the predicted mask side by side.
fig, axs = plt.subplots(1, 3, figsize=(12, 4))
axs: np.ndarray  # plt.subplots(1, 3) returns a 1-D array of 3 Axes

panels = [
    (pil_img, "Image"),
    (label2rgb(np.array(gt_mask)), "GT Mask"),
    (label2rgb(parsing), "Pred Mask"),
]
for ax, (data, title) in zip(axs, panels):
    ax.imshow(data)
    ax.set_title(title)
    ax.axis("off")

# Create a legend: one colored patch per class in LABEL_COLORS (label -> RGB).
palette = {label: np.asarray(rgb, dtype=float) for label, rgb in LABEL_COLORS.items()}
# Matplotlib expects RGB floats in [0, 1]. Rescale if the palette is stored on a 0-255 scale.
# The decision uses the palette-wide max, so all colors are treated consistently.
scale = 255.0 if max((c.max() for c in palette.values()), default=0.0) > 1.0 else 1.0
legend_patches = [Patch(color=rgb / scale, label=label) for label, rgb in palette.items()]
axs[2].legend(
    handles=legend_patches, bbox_to_anchor=(1.05, 1.05), loc="upper left", title="Classes"
)
plt.show()