The mask images are stored in RGB format, so we need to convert them to a label encoding (one integer class index per pixel) and save the results.

In [22]:
import os
import torch
from tqdm import tqdm
from torchvision.io import read_image
from torchvision.transforms.functional import to_pil_image
from pathlib import Path
import pandas as pd


data_dir = Path("../data/")
classes = pd.read_csv(data_dir / "class_dict.csv")
# One tuple of color values per CSV row; the first column is skipped
# (presumably the class name -- verify against class_dict.csv).
class_colors = [tuple(row.iloc[1:].tolist()) for _, row in classes.iterrows()]
out_dir = data_dir / "train_masks"

# Collect the .png mask files found in the training directory.
# os.listdir returns entries in arbitrary order, so sort them to process the
# masks in a reproducible order across runs and machines.
train_dir = data_dir / "train"
mask_files = sorted(f for f in os.listdir(train_dir) if f.endswith('.png'))

def rgb_to_label(mask, class_colors):
    """
    Transforms a mask image from RGB format to label encoding.

    Every pixel whose RGB value equals ``class_colors[idx]`` receives the
    label ``idx``. Pixels that match none of the colors keep the initial
    label 0, and if the same color is listed more than once the last index
    wins.

        Parameters:
            mask: Torch tensor of shape (3, H, W)
            class_colors: list of tuples of the RGB values for each class
        Returns:
            Torch tensor of shape (H, W) and dtype uint8 holding the label
            encoded classes, allocated on the same device as ``mask``
    """
    h, w = mask.shape[1:]  # shape expected to be (C, H, W)
    # Allocate on the mask's device: comparing or indexing tensors that live
    # on different devices raises, so a CPU-only label map would break for
    # masks that were moved to another device.
    semantic_map = torch.zeros((h, w), dtype=torch.uint8, device=mask.device)
    for idx, color in enumerate(class_colors):
        # (3, 1, 1) lets the color broadcast against every pixel of the mask
        color_px = torch.tensor(color, device=mask.device).view(3, 1, 1)
        # True where all three channels equal the class color -> shape (H, W)
        class_map = torch.all(torch.eq(mask, color_px), 0)
        semantic_map[class_map] = idx
    return semantic_map

In [38]:
# Transform all the masks and save them in out_dir.
# Create the output directory first: saving into a directory that does not
# exist raises FileNotFoundError, which breaks a run on a fresh checkout.
out_dir.mkdir(parents=True, exist_ok=True)
for mask_file in tqdm(mask_files):
    img_mask = read_image(os.path.join(train_dir, mask_file))
    lenc_mask = rgb_to_label(img_mask, class_colors)
    # The label map is written under the same file name as its RGB mask.
    to_pil_image(lenc_mask).save(str(out_dir / mask_file))

100%|██████████| 803/803 [12:42<00:00,  1.05it/s]
