In [2]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw
from torchvision.io import read_image
from torchvision.transforms.functional import resize, to_pil_image
from torchvision.utils import draw_segmentation_masks, make_grid
from tqdm import tqdm

data_dir = Path("../data")
images_dir = data_dir / "images"
masks_dir = data_dir / "masks"
num_classes = 7

int2str = {
    0: "urban_land",
    1: "agriculture_land",
    2: "rangeland",
    3: "forest_land",
    4: "water",
    5: "barren_land",
    6: "unknown",
}
int2rgb = {
    0: (0, 255, 255),
    1: (255, 255, 0),
    2: (255, 0, 255),
    3: (0, 255, 0),
    4: (0, 0, 255),
    5: (255, 255, 255),
    6: (0, 0, 0),
}


def label_to_onehot(mask, num_classes):
    """
    Transforms a label encoded tensor to one hot encoding.
        Parameters:
            mask: Torch tensor of shape (H, W)
            num_classes: Total number of classes:
        Returns:
            Torch tensor of shape (num_classes, H, W).
    """
    dims_p = (2, 0, 1) if mask.ndim == 2 else (0, 3, 1, 2)
    return torch.permute(F.one_hot(mask.long(), num_classes=num_classes).bool(), dims_p)


## Legend
# Define the legend size, background color, and text parameters
legend_width = 140
legend_height = 190
text_color = (0, 0, 0)  # black
# Create a new image for the legend
legend_image = Image.new("RGB", (legend_width, legend_height), (220, 220, 220))
draw = ImageDraw.Draw(legend_image)
# Set the initial position for drawing rectangles and text
x = 10
y = 10
# Draw rectangles and labels for each legend item
for idx in range(num_classes):
    name = int2str[idx]
    color = int2rgb[idx]
    draw.rectangle([(x, y), (x + 20, y + 20)], fill=color)
    draw.text((x + 30, y), name, fill=text_color)
    y += 30
# Define the position to paste the legend onto the original image
legend_position = (10, 10)


In [3]:
import sys

sys.path.insert(0, "..")
from dataset import LandcoverDataset
from torchvision import transforms
import torch

data_dir = Path("../data")
images_dir = data_dir / "images"
masks_dir = data_dir / "masks"


ds = LandcoverDataset(transform=transforms.ToTensor(), target_transform=transforms.PILToTensor())


In [4]:
train_ids, val_ids, test_ids = torch.utils.data.random_split(
    ds.image_ids, [454, 207, 142], generator=torch.Generator().manual_seed(42)
)


In [27]:
def get_annot(ids):
    class_per_mask = []
    for id in tqdm(ids):
        mask = read_image(str(masks_dir / f"{id}_mask.png"))
        classes_in_image = torch.bincount(mask.view(-1), minlength=7) > 0
        class_per_mask.append(classes_in_image.tolist())

    annot = pd.DataFrame(class_per_mask, columns=int2str.values(), index=list(ids)).reset_index(names=["id"])
    return annot


In [28]:
train_annot = get_annot(train_ids)
val_annot = get_annot(val_ids)
test_annot = get_annot(test_ids)

100%|██████████| 454/454 [00:17<00:00, 26.09it/s]
100%|██████████| 207/207 [00:07<00:00, 26.24it/s]
100%|██████████| 142/142 [00:05<00:00, 25.09it/s]


In [30]:
for annot, split_name in zip([train_annot, val_annot, test_annot], ["train", "val", "test"]):
    print("Split:", split_name)
    print("Number of images with each class:\n")
    counts = annot.drop(columns=["id"]).sum(0)
    props = counts / counts.sum()
    print(round(props, 3), "\n")

Split: train
Number of images with each class:

urban_land          0.209
agriculture_land    0.234
rangeland           0.169
forest_land         0.061
water               0.151
barren_land         0.129
unknown             0.047
dtype: float64 

Split: val
Number of images with each class:

urban_land          0.192
agriculture_land    0.218
rangeland           0.172
forest_land         0.073
water               0.153
barren_land         0.137
unknown             0.056
dtype: float64 

Split: test
Number of images with each class:

urban_land          0.215
agriculture_land    0.230
rangeland           0.146
forest_land         0.047
water               0.151
barren_land         0.157
unknown             0.053
dtype: float64 



In [31]:
def viz_label(n_images, class_name, annot, alpha=0.2, downsize_res=None):
    """
    Visualize a sample of masks of the desired class.
        Parameters:
            n_images: Number of images in the sample
            class_name: Class name string, one of [forest, rangeland, barren_land
            water, agriculture_land, urban_land]
            downsize_red: Resolution to downsize the images to
    """
    class_images = annot[class_name]
    ids = annot["id"][class_images]
    sample = np.random.choice(ids, n_images)

    imgs = []
    for img_id in sample:
        sat_img = read_image(str(images_dir / f"{img_id}_sat.jpg"))
        raw_masks = read_image(str(masks_dir / f"{img_id}_mask.png")).squeeze()
        if downsize_res is not None:
            sat_img = resize(sat_img, downsize_res)
            raw_masks = resize(raw_masks, downsize_res)
        masks = label_to_onehot(raw_masks, 7)
        mask_over_image = draw_segmentation_masks(sat_img, masks=masks, alpha=alpha, colors=list(int2rgb.values()))
        imgs.extend([sat_img, mask_over_image])

    grid = make_grid(imgs, nrow=2)

    pil_image = to_pil_image(grid)
    pil_image.paste(legend_image, legend_position)
    return pil_image


In [32]:
viz_label(4, "rangeland", annot=train_annot, alpha=0.2).show()






