

# Import

In [None]:

from torchinfo import summary
import sys
import geopandas as gpd
from tqdm import tqdm
from torch.utils.data import DataLoader, ConcatDataset
from torch.utils.data import Dataset
import rasterio as rio
from typing import Any, List
from rasterio.features import rasterize
from rasterio.plot import reshape_as_image
from rasterio.windows import from_bounds, transform as w_transform
from torchvision.utils import make_grid, draw_segmentation_masks
import numpy as np
from pathlib import Path
import os
import kornia.augmentation as K
from typing import Any
from lightning.pytorch.utilities.types import STEP_OUTPUT
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import segmentation_models_pytorch as smp
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch import LightningModule
from torchvision.utils import make_grid, draw_segmentation_masks
from dotenv import load_dotenv
load_dotenv()
import geopandas as gpd
import os
from pathlib import Path
from math import floor
from itertools import product
from shapely.geometry import box

# Tiles computing

In [None]:
# Open the reference raster and read its CRS, so that the tile grid built with
# GeoPandas is expressed in the same coordinate system as the imagery.
train_val_path = 'Imagery/Train_Val/Train_04_Apr.tif'
# The handle is not closed here: the tiling code further down reads
# `source_db.width`, `source_db.height` and calls `source_db.xy(...)`.
source_db = rio.open(train_val_path)
image_crs = source_db.crs
# Bare expression: shows the CRS as the notebook cell output.
image_crs

In [16]:
# Virtual tiling parameters: tile size in pixels and the fraction of a tile
# shared with its neighbour along each axis.
tile_width, tile_height = 256, 256
overlap = 0.5

# Stride between consecutive tiles is the non-overlapping part of a tile.
keep_fraction = 1 - overlap
x_indices = range(0, source_db.width, floor(tile_width * keep_fraction))
y_indices = range(0, source_db.height, floor(tile_height * keep_fraction))

# Pairs are deliberately (y, x), i.e. (row, col): the raster is addressed from
# its top-left corner, row first.
grid_coordinates = product(y_indices, x_indices)

def tile_box_for_coordinates(grid_x, grid_y):
    """Return the georeferenced bounding box of one virtual tile.

    Despite their names, the arguments are in rasterio's ``(row, col)`` order:
    they are forwarded unchanged to ``source_db.xy`` and the grid is generated
    as ``product(y_indices, x_indices)``.

    Args:
        grid_x: Pixel row of the tile's top-left corner.
        grid_y: Pixel column of the tile's top-left corner.

    Returns:
        A shapely polygon covering ``tile_height`` rows by ``tile_width``
        columns, in the CRS of ``source_db``.
    """
    row, col = grid_x, grid_y
    # `offset="ul"` returns pixel corners instead of the default pixel centres,
    # so the box lines up with the pixel grid rather than being shifted by half
    # a pixel.
    x0, y0 = source_db.xy(row, col, offset="ul")
    # Rows advance with the tile height and columns with the tile width; the
    # two were swapped before, which only worked because the tile is square.
    x1, y1 = source_db.xy(row + tile_height, col + tile_width, offset="ul")
    # Order the corners explicitly: on a north-up raster y decreases as the
    # row index grows, so (x0, y0) is not the (minx, miny) corner.
    return box(min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1))
# One bounding box per grid position; each (row, col) pair is unpacked
# positionally into the helper.
grid_boxes = [tile_box_for_coordinates(*grid_xy) for grid_xy in grid_coordinates]

# Tile geometries in the image CRS, annotated with the tile size in pixels.
raster_tiles_df = gpd.GeoDataFrame(geometry=grid_boxes, crs=image_crs)
for column_name, pixel_size in (("height", tile_height), ("width", tile_width)):
    raster_tiles_df[column_name] = pixel_size

# Interactive map of the tile outlines (cell output).
raster_tiles_df.explore(style_kwds={"fill": False})

In [None]:
# Load the train and validation areas of interest.
train_set_df = gpd.read_file('Area Train_Val_test/Train_Area.geojson')
val_set_df = gpd.read_file('Area Train_Val_test/Validation_Area.geojson')
# Interactive map of the validation area; the tile layers are drawn on it below.
val_area_map = val_set_df.explore(style_kwds=dict(fill=False, color='red'))

# `unary_union` dissolves each area into a single geometry because the
# GeoDataFrame predicates align two Series element-wise, whereas a single
# geometry is broadcast against every tile. Each union and the tile areas are
# computed once and reused.
val_area = val_set_df.geometry.unary_union
train_area = train_set_df.geometry.unary_union
tile_area = raster_tiles_df.geometry.area

# Validation tiles: more than 50% of the tile lies inside the validation area.
val_overlap_query = (raster_tiles_df.intersection(val_area).area / tile_area) > 0.5
val_set_tiles = raster_tiles_df[val_overlap_query]

# Training tiles: more than 50% of the tile lies inside the training area and
# the tile does not intersect the validation area.
train_overlap_query = (raster_tiles_df.intersection(train_area).area / tile_area) > 0.5
# The exclusion used to be `contains(<union of all val tiles>)`, which is true
# only for a tile that encloses every validation tile at once, so in practice
# nothing was excluded. `intersects` implements the rule stated above.
touches_val_area = raster_tiles_df.intersects(val_area)
train_set_tiles = raster_tiles_df[train_overlap_query & ~touches_val_area]

print(f"Got {len(train_set_tiles)} training tiles and {len(val_set_tiles)} val tiles")
# Add both tile sets to the interactive map: validation in blue, training in green.
val_set_tiles.explore(m=val_area_map, style_kwds=dict(fill=False, color='blue'))
train_set_tiles.explore(m=val_area_map, style_kwds=dict(fill=False, color='green'))

# Save the tiles to file.
val_set_tiles.to_file("valTiles_CLS_UTM.geojson", driver='GeoJSON', index=False)
train_set_tiles.to_file("TrainTiles_CLS_UTM.geojson", driver='GeoJSON', index=False)

In [None]:

class SingleRasterPalaeochannelDataset(Dataset):
    """Map-style dataset serving tiles of one raster with their label mask.

    Every item is read lazily from ``source_path``: the tile geometry selects
    a raster window, the annotated features are rasterized into a mask on the
    same window, and both are multiplied by the rasterized area of interest so
    that anything outside the AOI becomes 0.
    """

    def __init__(self,
                 tiles_df: gpd.GeoDataFrame,
                 source_path: Path,
                 features_df: gpd.GeoDataFrame,
                 aoi_df: gpd.GeoDataFrame):
        """Store the inputs; no file is opened until an item is requested.

        Args:
            tiles_df: One row per tile; only the ``geometry`` column and the
                frame's ``crs`` are used.
            source_path: Raster to read tiles from (passed to ``rio.open``).
            features_df: Geometries rasterized into the label mask.
            aoi_df: Geometries rasterized into the area-of-interest mask.
        """
        self.tiles_df = tiles_df
        self.source_path = source_path
        self.features_df = features_df
        self.aoi_df = aoi_df

    def __len__(self) -> int:
        """Return the number of tiles."""
        return len(self.tiles_df)

    def __getitem__(self, index: int) -> Any:
        """Build the sample for the tile at position ``index``.

        Returns:
            A dict with:
            ``image``: the tile pixels passed through ``reshape_as_image``;
            ``mask``: the rasterized features, zeroed outside the AOI;
            ``tile_geometry``: the tile polygon as WKT;
            ``tile_crs``: the tiles CRS as a string.
        """
        # Samplers hand out positions 0..len-1, so look the row up by position.
        # A label lookup (`.loc`) fails as soon as the frame does not carry a
        # default RangeIndex, e.g. when it is a filtered view of another frame.
        tile = self.tiles_df.iloc[index]

        with rio.open(self.source_path) as source_db:
            # Raster window covered by the tile geometry.
            tile_window = from_bounds(*tile.geometry.bounds, transform=source_db.transform)
            tile_raster: np.ndarray = source_db.read(window=tile_window)
            # Affine transform of the window, needed to rasterize the vector
            # layers on the same pixel grid as the pixels just read.
            window_transform = w_transform(tile_window, source_db.transform)

        # Everything below works on in-memory arrays, so the file is closed.
        tile_raster[np.isnan(tile_raster)] = 0
        if tile_raster.dtype == np.uint16:
            # NOTE(review): presumably widened because the downstream tensor
            # code cannot handle uint16 - confirm.
            tile_raster = tile_raster.astype(np.int32)

        # `read` returns (bands, rows, cols); the masks are 2-D (rows, cols).
        mask_shape = (tile_raster.shape[1], tile_raster.shape[2])
        mask = rasterize(self.features_df.geometry, out_shape=mask_shape, transform=window_transform)
        aoi_mask = rasterize(self.aoi_df.geometry, out_shape=mask_shape, transform=window_transform)

        # Mask out data and features falling outside of the AOI.
        tile_raster = tile_raster * aoi_mask
        mask = mask * aoi_mask

        return dict(
            image=reshape_as_image(tile_raster),
            mask=mask,
            # Geometry and CRS are carried along as plain strings for plotting.
            tile_geometry=tile.geometry.wkt,
            tile_crs=str(self.tiles_df.crs),
        )


# Loader

In [None]:
# Inputs shared by every dataset configuration below.

# Raster paths, grouped by the annotation layer they are paired with.
march_tif_path = 'Imagery/Train_Val/Train_03_March.tif'
april_tif_path = 'Imagery/Train_Val/Train_04_Apr.tif'

jul_tif_path = 'Imagery/Train_Val/Train_07_Jul.tif'
aug_tif_path = 'Imagery/Train_Val/Train_08_Aug.tif'

nov_tif_path = 'Imagery/Train_Val/Train_11_Nov.tif'
# NOTE(review): the variable says January, the file name says "12_Jan23" and
# the paired annotation file says "November-December" - confirm the period.
jan_tif_path = 'Imagery/Train_Val/Train_12_Jan23.tif'

# Annotated features, one layer per pair of rasters.
march_april_features_df = gpd.read_file('annotations/L1_Train_March-April.geojson')
jul_aug_features_df = gpd.read_file('annotations/L2_Train_July-August.geojson')
nov_jan_features_df = gpd.read_file('annotations/L3_Train_November-December.geojson')

# Tiles written by the tiling step, with the area of interest of each split.
train_tiles_df = gpd.read_file('TrainTiles_CLS_UTM.geojson')
train_aoi_df = gpd.read_file('Area Train_Val_test/Train_Area.geojson')
val_tiles_df = gpd.read_file('valTiles_CLS_UTM.geojson')
val_aoi_df = gpd.read_file('Area Train_Val_test/Validation_Area.geojson')

# Loader for TR1 setting

In [None]:
# TR1 setting: one raster per annotation layer (April, August, November).
tr1_sources = (
    (april_tif_path, march_april_features_df),
    (aug_tif_path, jul_aug_features_df),
    (nov_tif_path, nov_jan_features_df),
)

# Training datasets: training tiles restricted to the training AOI.
TR1_april_train_dataset, TR1_august_train_dataset, TR1_nov_train_dataset = [
    SingleRasterPalaeochannelDataset(train_tiles_df, tif_path, features_df, train_aoi_df)
    for tif_path, features_df in tr1_sources
]

# Validation datasets: validation tiles restricted to the validation AOI.
TR1_april_val_dataset, TR1_august_val_dataset, TR1_nov_val_dataset = [
    SingleRasterPalaeochannelDataset(val_tiles_df, tif_path, features_df, val_aoi_df)
    for tif_path, features_df in tr1_sources
]

# Concatenate the per-raster datasets into one dataset per split.
full_train_dataset = ConcatDataset(
    [TR1_april_train_dataset, TR1_august_train_dataset, TR1_nov_train_dataset]
)
full_val_dataset = ConcatDataset(
    [TR1_april_val_dataset, TR1_august_val_dataset, TR1_nov_val_dataset]
)
print(f"Datasets build! {len(full_train_dataset)} training tiles, {len(full_val_dataset)} validation tiles.")


# Loader for TR2 setting

In [None]:

# TR2_april_train_dataset = SingleRasterPalaeochannelDataset(train_tiles_df, april_tif_path, march_april_features_df, train_aoi_df)
# TR2_march_train_dataset = SingleRasterPalaeochannelDataset(train_tiles_df, march_tif_path, march_april_features_df, train_aoi_df)
# TR2_august_train_dataset = SingleRasterPalaeochannelDataset(train_tiles_df, aug_tif_path, jul_aug_features_df, train_aoi_df)
# TR2_july_train_dataset = SingleRasterPalaeochannelDataset(train_tiles_df, jul_tif_path, jul_aug_features_df, train_aoi_df)
# TR2_nov_train_dataset = SingleRasterPalaeochannelDataset(train_tiles_df, nov_tif_path, nov_jan_features_df, train_aoi_df)
# TR2_jan_train_dataset = SingleRasterPalaeochannelDataset(train_tiles_df, jan_tif_path, nov_jan_features_df, train_aoi_df)

# TR2_april_val_dataset = SingleRasterPalaeochannelDataset(val_tiles_df, april_tif_path, march_april_features_df, val_aoi_df)
# TR2_march_val_dataset = SingleRasterPalaeochannelDataset(val_tiles_df, march_tif_path, march_april_features_df, val_aoi_df)
# TR2_august_val_dataset = SingleRasterPalaeochannelDataset(val_tiles_df, aug_tif_path, jul_aug_features_df, val_aoi_df)
# TR2_july_val_dataset = SingleRasterPalaeochannelDataset(val_tiles_df, jul_tif_path, jul_aug_features_df, val_aoi_df)
# TR2_nov_val_dataset = SingleRasterPalaeochannelDataset(val_tiles_df, nov_tif_path, nov_jan_features_df, val_aoi_df)
# TR2_jan_val_dataset = SingleRasterPalaeochannelDataset(val_tiles_df, jan_tif_path, nov_jan_features_df, val_aoi_df)

# full_train_dataset = ConcatDataset([TR2_april_train_dataset, TR2_march_train_dataset, TR2_august_train_dataset, TR2_july_train_dataset, TR2_nov_train_dataset, TR2_jan_train_dataset])
# full_val_dataset = ConcatDataset([TR2_april_val_dataset, TR2_march_val_dataset, TR2_august_val_dataset, TR2_july_val_dataset, TR2_nov_val_dataset, TR2_jan_val_dataset])
# print(f"Datasets build! {len(full_train_dataset)} training tiles, {len(full_val_dataset)} validation tiles.")