In [1]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw
from torchvision.io import read_image
from torchvision.transforms.functional import resize, to_pil_image
from torchvision.utils import draw_segmentation_masks, make_grid
from tqdm import tqdm

data_dir = Path("../data")
images_dir = data_dir / "images"
masks_dir = data_dir / "masks"
classes = pd.read_csv(data_dir / "class_dict.csv")

class_rgb_colors = [tuple(row[1:].tolist()) for _, row in classes.iterrows()]
class_names = classes["name"].tolist()
label_to_name = {idx: name for idx, name in enumerate(class_names)}


def label_to_onehot(mask, num_classes):
    """
    Transforms a label encoded tensor to one hot encoding.
        Parameters:
            mask: Torch tensor of shape (H, W)
            num_classes: Total number of classes:
        Returns:
            Torch tensor of shape (num_classes, H, W).
    """
    dims_p = (2, 0, 1) if mask.ndim == 2 else (0, 3, 1, 2)
    return torch.permute(
        F.one_hot(mask.type(torch.long), num_classes=num_classes).type(torch.bool),
        dims_p,
    )


## Legend
# Define the legend size, background color, and text parameters
legend_width = 140
legend_height = 190
text_color = (0, 0, 0)  # black
# Create a new image for the legend
legend_image = Image.new("RGB", (legend_width, legend_height), (220, 220, 220))
draw = ImageDraw.Draw(legend_image)
# Set the initial position for drawing rectangles and text
x = 10
y = 10
# Draw rectangles and labels for each legend item
for label, color in zip(class_names, class_rgb_colors):
    draw.rectangle([(x, y), (x + 20, y + 20)], fill=color)
    draw.text((x + 30, y), label, fill=text_color)
    y += 30
# Define the position to paste the legend onto the original image
legend_position = (10, 10)

In [2]:
class_per_mask = []
ids = [x.split("_")[0] for x in os.listdir(masks_dir)]
for id in tqdm(ids):
    mask = read_image(str(masks_dir / f"{id}_mask.png"))
    classes_in_image = torch.bincount(mask.view(-1), minlength=7) > 0
    class_per_mask.append(classes_in_image.tolist())

annot = pd.DataFrame(class_per_mask, columns=class_names, index=ids).reset_index(
    names="id"
)
annot.head()

100%|██████████| 803/803 [00:53<00:00, 15.14it/s]


Unnamed: 0,id,urban_land,agriculture_land,rangeland,forest_land,water,barren_land,unknown
0,211316,True,True,True,False,False,True,False
1,967818,False,True,False,True,False,True,False
2,645001,True,True,True,False,True,False,False
3,86805,True,True,True,False,False,True,False
4,312676,True,False,True,False,False,False,False


In [3]:
def viz_label(n_images, class_name, alpha=0.2, downsize_res=None):
    """
    Visualize a sample of masks of the desired class.
        Parameters:
            n_images: Number of images in the sample
            class_name: Class name string, one of [forest, rangeland, barren_land
            water, agriculture_land, urban_land]
            downsize_red: Resolution to downsize the images to
    """
    class_images = annot[class_name]
    ids = annot["id"][class_images]
    sample = np.random.choice(ids, n_images)

    imgs = []
    for img_id in sample:
        sat_img = read_image(str(images_dir / f"{img_id}_sat.jpg"))
        raw_masks = read_image(str(masks_dir / f"{img_id}_mask.png")).squeeze()
        if downsize_res is not None:
            sat_img = resize(sat_img, downsize_res)
            raw_masks = resize(raw_masks, downsize_res)
        masks = label_to_onehot(
            raw_masks,
            7,
        )
        mask_over_image = draw_segmentation_masks(
            sat_img, masks=masks, alpha=alpha, colors=class_rgb_colors
        )
        imgs.extend([sat_img, mask_over_image])

    grid = make_grid(imgs, nrow=2)

    pil_image = to_pil_image(grid)
    pil_image.paste(legend_image, legend_position)
    return pil_image

In [5]:
viz_label(4, "rangeland", alpha=0.2).show()






