# Training Dataset Generation

This notebook uses the GISCUP 2023 Sentinel-2 dataset to generate a TensorFlow Dataset to be consumed by '02_model_training.ipynb'.

Bitmasks are first generated from each region and its training lake polygons. The bitmasks and the associated Sentinel-2 RGB data are then sliced into 256x256 training tiles.

In [None]:
from pathlib import Path
from pathlib import Path
import tensorflow as tf
import geopandas as gpd
import rasterio
from rasterio.mask import mask
from rasterio.features import rasterize
import math
from rasterio.windows import Window
import numpy as np
import random

In [None]:
# Root of the local data tree and the folder holding the 2023 SIGSPATIAL Cup
# data files.
ROOT_DIR = Path("../data/")
DATA_PATH = Path("../data/2023_SIGSPATIAL_Cup_data_files/")

# Output folders used by this notebook, all nested under training_data.
TRAINING_DATA_DIR = ROOT_DIR / "training_data"
TRAIN_DATA_REGIONS = TRAINING_DATA_DIR / "train_data_regions"
INTERIM_DATA_PATH = TRAINING_DATA_DIR / "interim/"
PARTITIONS_PATH = TRAINING_DATA_DIR / "partitions"

# exist_ok=True replaces the separate exists() check: it is a no-op when the
# folder is already there and avoids the check-then-create race. Unlike the
# exists() check, it raises FileExistsError if the path exists as a file.
for _output_dir in (
    TRAINING_DATA_DIR,
    TRAIN_DATA_REGIONS,
    INTERIM_DATA_PATH,
    PARTITIONS_PATH,
):
    _output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Read the two vector layers (region outlines and training lake polygons)
# and list the .tif rasters that sit directly inside DATA_PATH.
regions = gpd.read_file(DATA_PATH / "lakes_regions.gpkg")
sentinal_files = list(DATA_PATH.glob("*.tif"))
training_lakes_geoms = gpd.read_file(DATA_PATH / "lake_polygons_training.gpkg")

In [None]:
# For each image date, the region numbers listed here go into train_regions;
# the remaining region numbers (out of 1-6) go into test_regions.
_TRAIN_REGION_NUMS_BY_DATE = {
    "2019-06-03": (2, 4, 6),
    "2019-06-19": (1, 3, 5),
    "2019-07-31": (2, 4, 6),
    "2019-08-25": (1, 3, 5),
}
_ALL_REGION_NUMS = range(1, 7)

# Identifiers have the form "<date>_<region number>", ordered by date and then
# by region number.
train_regions = [
    f"{date}_{num}"
    for date, train_nums in _TRAIN_REGION_NUMS_BY_DATE.items()
    for num in train_nums
]

test_regions = [
    f"{date}_{num}"
    for date, train_nums in _TRAIN_REGION_NUMS_BY_DATE.items()
    for num in _ALL_REGION_NUMS
    if num not in train_nums
]

## Generate Bitmasks

In [None]:
# For every satellite image, cut out each of the six regions and rasterize
# that region's training lake polygons into a matching single-band bitmask.
# Both rasters are written to INTERIM_DATA_PATH/<region>/.
for sat_img in DATA_PATH.glob("*.tif"):
    filename = sat_img.name

    # The 10 characters that precede the last 3 characters of the file stem.
    # NOTE(review): presumably a YYYY-MM-DD date followed by a 3-character
    # suffix -- confirm against the input file naming.
    file_date = sat_img.stem[-13:-3]

    # Open inside a context manager so the dataset handle is closed once all
    # regions of this image are processed (it was previously never closed).
    with rasterio.open(sat_img) as raw_satellite:
        for region in range(1, 7):
            region_geom = regions[regions["region_num"] == region].geometry

            lakes = training_lakes_geoms[
                (training_lakes_geoms["region_num"] == region)
                & (training_lakes_geoms["image"] == filename)
            ]

            print(f"Processing {filename} Region: {region}")

            if lakes.empty:
                # For each file, some regions do not have training lakes
                continue

            # Extract the region from the satellite image; `affine` maps the
            # cropped array back to map coordinates.
            region_raw, affine = mask(raw_satellite, shapes=region_geom, crop=True)

            # Burn value 1 into every pixel covered by a lake polygon, on the
            # same grid (shape and transform) as the cropped region.
            lakes_bitmask = rasterize(
                [(geom, 1) for geom in lakes.geometry],
                out_shape=region_raw.shape[1:],
                transform=affine,
            )

            out_dir = INTERIM_DATA_PATH / f"{region}"
            out_dir.mkdir(parents=True, exist_ok=True)

            # Keyword arguments shared by both output rasters.
            # NOTE(review): the bitmask file is declared with region_raw.dtype
            # while rasterize() chooses its own dtype -- confirm they agree.
            bitmask_height, bitmask_width = lakes_bitmask.shape
            shared_profile = dict(
                driver=raw_satellite.driver,
                crs=raw_satellite.crs,
                transform=affine,
                width=bitmask_width,
                height=bitmask_height,
                dtype=region_raw.dtype,
            )

            # Write out raw region
            with rasterio.open(
                out_dir / f"{file_date}_raw.tif", "w", count=3, **shared_profile
            ) as out:
                out.write(region_raw)

            # Write out bitmask
            with rasterio.open(
                out_dir / f"{file_date}_bitmask.tif", "w", count=1, **shared_profile
            ) as out:
                out.write_band(1, lakes_bitmask)

## Partition Segments into 256x256 images

In [None]:
# Edge length, in pixels, of the square tiles cut from each region raster.
WINDOW_SIZE = 256


def read_window_segment(window, raster):
    """Return the data and affine transform of `raster` restricted to `window`.

    Args:
        window: Window passed to both `raster.read` and
            `raster.window_transform`.
        raster: Open dataset exposing `read(window=...)` and
            `window_transform(window)`.

    Returns:
        A `(data, transform)` tuple: the array read for the window and the
        transform that `raster` reports for the window.
    """
    return raster.read(window=window), raster.window_transform(window)

# Slice every interim region raster and its bitmask into WINDOW_SIZE-pixel
# tiles, written to PARTITIONS_PATH/<WINDOW_SIZE>/<region>/.
for path in INTERIM_DATA_PATH.glob("*/*raw.tif"):
    # The glob matches <region folder>/<name>raw.tif, so the region is the
    # parent folder name. This does not depend on how deep INTERIM_DATA_PATH
    # is or on the path separator, unlike splitting the string on "/".
    region = path.parent.name
    date = path.stem[:10]

    # Swap only the trailing "raw.tif" of the file name; a plain str.replace
    # would also rewrite any "raw" occurring in a parent folder name.
    bitmask_path = path.with_name(path.name[: -len("raw.tif")] + "bitmask.tif")

    out_dir = PARTITIONS_PATH / f"{WINDOW_SIZE}/{region}"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Context managers close both source rasters once their tiles are written
    # (they were previously left open).
    with rasterio.open(path) as satellite, rasterio.open(bitmask_path) as bitmask:
        # Round up so the partial tiles along the right/bottom edges are
        # included as well.
        num_cols = math.ceil(satellite.width / WINDOW_SIZE)
        num_rows = math.ceil(satellite.height / WINDOW_SIZE)

        for col_indx in range(num_cols):
            for row_indx in range(num_rows):
                window = Window(
                    col_indx * WINDOW_SIZE,
                    row_indx * WINDOW_SIZE,
                    WINDOW_SIZE,
                    WINDOW_SIZE,
                )

                sat_window, sat_window_trans = read_window_segment(window, satellite)
                mask_window, mask_trans = read_window_segment(window, bitmask)

                index_date_name = f"{col_indx}_{row_indx}_{date}"

                # Arrays returned by read() are (bands, rows, cols), so the
                # height is axis 1 and the width is axis 2. The two were
                # swapped before, which only went unnoticed for square tiles.
                # Write out satellite segment
                with rasterio.open(
                    out_dir / (index_date_name + "_sat.tif"),
                    "w",
                    driver=satellite.driver,
                    crs=satellite.crs,
                    transform=sat_window_trans,
                    width=sat_window.shape[2],
                    height=sat_window.shape[1],
                    count=3,
                    dtype=sat_window.dtype,
                ) as out:
                    out.write(sat_window)

                # Write out bitmask segment
                with rasterio.open(
                    out_dir / (index_date_name + "_bitmask.tif"),
                    "w",
                    driver=bitmask.driver,
                    crs=bitmask.crs,
                    transform=mask_trans,
                    width=mask_window.shape[2],
                    height=mask_window.shape[1],
                    count=1,
                    dtype=mask_window.dtype,
                ) as out:
                    out.write(mask_window)

## Create TensorFlow Training Dataset

In [None]:
def load_images():
    """Yield `(rgb, bitmask)` tensor pairs for the partitioned training tiles.

    Walks every `*_sat.tif` tile under `PARTITIONS_PATH/<WINDOW_SIZE>/` and
    reads it together with the `*_bitmask.tif` file of the same name.

    Tiles are skipped when:
      * either array is not exactly `WINDOW_SIZE` x `WINDOW_SIZE` with 3
        (satellite) or 1 (bitmask) bands;
      * the bitmask sums to zero and `random.randint(0, 100) > 50`, which
        drops roughly half of the all-zero tiles.

    Yields:
        A tuple of two tensors in channel-last layout: the satellite data
        divided by 255 as float32 with shape `(WINDOW_SIZE, WINDOW_SIZE, 3)`,
        and the bitmask with shape `(WINDOW_SIZE, WINDOW_SIZE, 1)`.
    """
    tiles_dir = PARTITIONS_PATH / f"{WINDOW_SIZE}/"
    expected_sat_shape = (3, WINDOW_SIZE, WINDOW_SIZE)
    expected_mask_shape = (1, WINDOW_SIZE, WINDOW_SIZE)

    for sat_path in tiles_dir.glob("*/*_sat.tif"):
        # Swap only the trailing "sat.tif" of the file name; replacing "sat"
        # across the whole path would corrupt any folder whose name contains
        # it (e.g. "satellite").
        bitmask_path = sat_path.with_name(
            sat_path.name[: -len("sat.tif")] + "bitmask.tif"
        )

        with rasterio.open(sat_path, "r") as sat_in:
            raw_sat = sat_in.read()

        with rasterio.open(bitmask_path, "r") as bitmask_in:
            bitmask = bitmask_in.read()

        # Only full-size tiles are used; smaller ones are dropped.
        if raw_sat.shape != expected_sat_shape or bitmask.shape != expected_mask_shape:
            continue

        # A bitmask with no set pixels (no lake) is skipped when the draw from
        # 0..100 exceeds 50, so such tiles are kept about half of the time.
        if np.sum(bitmask) == 0 and random.randint(0, 100) > 50:
            continue

        # Reshape to channel last
        raw_sat = np.moveaxis(raw_sat, 0, 2)
        bitmask = np.moveaxis(bitmask, 0, 2)

        # Scale by 255 and cast explicitly so the yielded dtype is float32
        # instead of the float64 that the division would otherwise produce.
        raw_sat = (raw_sat / 255).astype(np.float32)

        yield tf.convert_to_tensor(raw_sat), tf.convert_to_tensor(bitmask)

# Describe the two tensors yielded by load_images, wrap the generator in a
# tf.data pipeline, and save the resulting dataset under TRAINING_DATA_DIR.
_rgb_spec = tf.TensorSpec(
    shape=(WINDOW_SIZE, WINDOW_SIZE, 3), dtype=tf.float32, name="RGB"
)
_bitmask_spec = tf.TensorSpec(
    shape=(WINDOW_SIZE, WINDOW_SIZE, 1), dtype=tf.uint8, name="Bitmask"
)
image_loader = tf.data.Dataset.from_generator(
    load_images,
    output_signature=(_rgb_spec, _bitmask_spec),
)
image_loader.save(str(TRAINING_DATA_DIR / "RGB256"))