In [1]:
import collections
import copy
import functools
import glob
import json
import os
import os.path
import shutil
import warnings

import geopandas as gpd
import numpy as np
import rasterio.mask

In [2]:
# Set up a temporary subset of the complete dataset to accommodate the currently annotated images.
os.makedirs("../dataset/temp/imgs", exist_ok=True)
os.makedirs("../dataset/temp/msks", exist_ok=True)

In [3]:
# Find the segmentation masks in the Roboflow output folder.
src_maskpaths = glob.glob("C:/Users/Dimit/Downloads/RoofSense.v2-batch-2-unchecked-.png-mask-semantic/train/*.png")
src_maskpaths.sort()

# Build the corresponding destination paths.
dst_maskpaths = []
for filepath in src_maskpaths:
    dst_maskpaths.append(f"../dataset/temp/msks/{os.path.basename(filepath[:filepath.index('_png')])}.tif")

In [4]:
# Group the masks by the corresponding tile ID.
def get_tile_id(filepath: str) -> str:
    return os.path.basename(filepath[:filepath.index("_")])


maskgroups: dict[str, dict[list[str]]] = collections.defaultdict(functools.partial(collections.defaultdict, list))
for src, dst in zip(src_maskpaths, dst_maskpaths):
    tile_id = get_tile_id(src)

    maskgroups[tile_id]["src"].append(src)
    maskgroups[tile_id]["dst"].append(dst)

In [5]:
# Copy the corresponding images to the subset.
for dst in dst_maskpaths:
    imagename = os.path.basename(dst)
    shutil.copy2(f"../dataset/imgs/{imagename}",
                 dst=f"../dataset/temp/imgs/{imagename}")

In [6]:
# Keep track of the class frequencies and the total number of valid pixels.
clsfreqs = {i: 0 for i in range(14)}
validpx = 0

img_src: rasterio.io.DatasetReader
msk_src: rasterio.io.DatasetReader
msk_dst: rasterio.io.DatasetWriter

for tile_id, filepaths in maskgroups.items():
    surfs = gpd.read_file(f"../temp/{tile_id}.surf.gpkg").dissolve()
    for src, dst in zip(filepaths["src"], filepaths["dst"]):
        imagename = os.path.basename(dst)
        with (warnings.catch_warnings()):
            warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)

            # Georeference the masks.
            with rasterio.open(f"../dataset/temp/imgs/{imagename}") as img_src, rasterio.open(src) as msk_src:
                msk_data = msk_src.read()

                msk_profile = copy.deepcopy(img_src.profile)
                msk_profile.update(count=1, dtype=np.uint8, nodata=0)

                with rasterio.open(dst, mode="w+", **msk_profile) as msk_dst:
                    msk_dst.write(msk_data)

                    # Remask the backgound.
                    msk_data, _ = rasterio.mask.mask(msk_dst, shapes=surfs["geometry"])

                    # Replace invalid pixels with background.
                    msk_data[msk_data == 6] = 0

                    msk_dst.write(msk_data)

                # Update the class frequencies and valid pixel count.
                for cls, freq in zip(*np.unique(msk_data[msk_data != 0], return_counts=True)):
                    clsfreqs[cls] += freq

                validpx += np.count_nonzero(msk_data)

In [7]:
src_names = {
    0: "__ignore__",
    1: "Asphalt Shingles",
    2: "Bituminous Coating / Membranes",
    3: "Ceramic Tiles",
    4: "Concrete",
    5: "Gravel",
    6: "Invalid",
    7: "Light-permitting Opening",
    8: "Metal",
    9: "Non-bituminous Coating / Membranes",
    10: "Other",
    11: "Solar Panel Installation",
    12: "Superstructure",
    13: "Vegetation"}

src_colors = {
    0: [0, 0, 0, 255],
    1: [1, 25, 89, 255],
    2: [250, 204, 250, 255],
    3: [130, 130, 49, 255],
    4: [33, 95, 96, 255],
    5: [241, 156, 107, 255],
    6: [255, 255, 255, 255],
    7: [77, 114, 76, 255],
    8: [17, 67, 96, 255],
    9: [253, 179, 179, 255],
    10: [192, 144, 53, 255],
    11: [22, 82, 98, 255],
    12: [252, 191, 213, 255],
    13: [251, 167, 144, 255]

}

# Discard the annotation classes which do not appear in the destination masks.
invalid_classes = []
for cls, freq in clsfreqs.items():
    if cls != 0 and freq == 0:
        invalid_classes.append(cls)

for cls in invalid_classes:
    clsfreqs.pop(cls)

# Remap the remaining classes to a continuous range.
cls_mapping = {new: old for new, old in zip(range(len(clsfreqs)), clsfreqs.keys())}

for tile_id, filepaths in maskgroups.items():
    for src, dst in zip(filepaths["src"], filepaths["dst"]):
        with rasterio.open(dst, mode="r+", **msk_profile) as src:
            data = src.read()
            for new, old in cls_mapping.items():
                data[data == old] = new
            src.write(data)

# Save the corresponding name and color mappings.
dst_names: dict[int, str] = {}
dst_colors: dict[int, list[int]] = {}
for new, old in cls_mapping.items():
    dst_names[new] = src_names[old]
    dst_colors[new] = src_colors[old]

with open("../dataset/temp/classes.json", mode="w") as f:
    json.dump(dst_names, f)

with open("../dataset/temp/colors.json", mode="w") as f:
    json.dump(dst_colors, f)

In [8]:
# Compute and save the class weights.
clswghts = (np.array(list(clsfreqs.values()), dtype=np.float32) / validpx) ** -1

# Snap the total number of background to zero to force them to be ignored.
# NOTE: Setting the null class weights to a small positive value instead of actually zero ensures a valid loss value when encountering image patches containing only said classes.
clswghts[0] = 1e-5
clswghts[1:]/=clswghts[1:].sum()

np.save("../dataset/temp/weights", clswghts)

  clswghts = (np.array(list(clsfreqs.values()), dtype=np.float32) / validpx) ** -1


In [10]:
np.set_printoptions(suppress=True)

np.load("../dataset/temp/weights.npy")

array([0.00001   , 0.02830823, 0.06861622, 0.02898338, 0.40045056,
       0.06691286, 0.04502055, 0.09106125, 0.27064695])