In [1]:
import dataset

In [2]:
import dataset._collate

In [3]:
import time
import os
from typing import Union, Optional, Tuple
import logging
import math
import warnings
from pathlib import Path

from itertools import product
from copy import deepcopy
import numpy as np
import pandas as pd
import geopandas as gpd
import osmnx as ox
from shapely.geometry import box
from shapely.ops import transform, unary_union
import pyproj
from pyproj import Transformer
import rasterio.features
from rasterio import windows
from rasterio.windows import Window
from tqdm import tqdm

import torch
from torch.utils.data import Subset, Dataset, DataLoader
import lightning.pytorch as pl
from skimage.morphology import disk, dilation
import fiona
import dataset
from dataset._collate import default_collate_with_shapely_support
from dataset._image_table import build_images_table
from dataset._data_utils import is_same_crs
from dataset._osm import get_osm_geometries
from dataset._data_utils import get_complete_areas, write_area_split, assign_target_polygons_to_areas
from dataset._image_extractor import extract_images
from utils.utils import set_abs_path

class BaseDataset(Dataset):
    """
    This remote-sensing dataset is designed to handle multiple image sources at once.
    These sources may differ in resolution, channels or CRS,
    and the dataset returns a sample image for each source accordingly. However, this flexibility has some limitations.
    For example, the expected sizes at different resolutions may be slightly off
    (even assuming that the sources overlap perfectly).
    More likely, a sample may lie across multiple tiles in some sources;
    mapping every point on Earth consistently is still a very challenging task and has its limitations.
    While we try to work around these limitations, it may not always be possible.

    Note that the patch, AOI and stride shapes are all given relative to a reference image source
    (assuming there are multiple CRSs), and the remaining shapes and sizes are calculated proportionally.
    """
    def __init__(
        self,
        dataset_name: str = 'train',
        labeled_areas: gpd.GeoDataFrame=None,
        images_df: gpd.GeoDataFrame=None,
        assigned_target_features: gpd.GeoDataFrame=None,
        assigned_nontrees_features: gpd.GeoDataFrame=None,
        reference_source:str=None,
        patch_size: Union[int, Tuple]=256,
        allow_partial_patches: bool=False,
        sequential_stride: Union[int, Tuple] = (256, 256),
        dataset_transform=None,
        processed_dir: str = None,
        save_samples: bool = False,
        save_patch_df: bool = False,
        save_labeled_areas_df: bool = False,
        extract_images_for_areas: bool=False,
        project_crs: str = None,
        merge_mode: str = "keep_first",
        extracted_image_col:str='extracted_images',
        get_patch_custom: callable=None,
    ) -> None:
        super(Dataset, self).__init__()
        assert patch_size is not None
        if type(patch_size) == int:
            patch_size = (patch_size, patch_size)
            
        assert sequential_stride is not None
        if type(sequential_stride) == int:
            sequential_stride = (sequential_stride, sequential_stride)
            
        self.dataset_name = dataset_name
        self.images_df = images_df
        self.assigned_target_features = assigned_target_features
        self.assigned_nontrees_features = assigned_nontrees_features
        self.dataset_transform = dataset_transform
        self.extract_images_for_areas = extract_images_for_areas
        self.reference_source = reference_source
        self.processed_dir = processed_dir
        self.save_samples = save_samples
        self.project_crs = project_crs
        self.merge_mode = merge_mode
        self.patch_size = patch_size
        self.get_patch_custom = get_patch_custom

        if labeled_areas.empty:
            return
            
        if self.processed_dir:
            logging.info(f"Processed_dir path: {self.processed_dir}")
            self.processed_dir = Path(self.processed_dir)
            # create processed sample dir
            self.processed_dir.mkdir(exist_ok=True)
            patch_df_path = self.processed_dir / f"{dataset_name}_patch_df.pt"
            labeled_areas_df_path = self.processed_dir / f"{dataset_name}_labeled_areas_df.pt"


        if (self.save_samples or save_patch_df or save_labeled_areas_df) and not self.processed_dir:
            raise Exception("No 'processed_dir' provided, cannot save samples or patch_df or save_labeled_areas_df.")
        elif not (self.save_samples or save_patch_df or save_labeled_areas_df) and self.processed_dir:
            warnings.warn("You provided 'processed_dir' provided but neither samples nor patch_df are saved.")
        elif self.processed_dir:
            (self.processed_dir / 'patches').mkdir(exist_ok=True)
            logging.info(f"Saving samples:{self.save_samples}, patch_df: {save_patch_df} to {self.processed_dir}")




        # For backward compatibility; it can be removed once it has been run for all training scripts.
        if self.processed_dir and save_labeled_areas_df and save_patch_df and patch_df_path.exists():
            extract_labeled_areas_from_patch_dict(patch_df_path, labeled_areas_df_path)

        if self.processed_dir and save_labeled_areas_df and labeled_areas_df_path.exists():
            logging.info(f"Reading labeled areas from {labeled_areas_df_path}")
            lb_dict = torch.load(labeled_areas_df_path)
            self.labeled_areas = lb_dict["labeled_areas"]
            self.assigned_images = lb_dict["assigned_images"]
            logging.info(f"Done reading labeled areas")
        else:
            logging.info(f"Creating labeled areas and assigned images for dataset {self.dataset_name}")
            self.labeled_areas, self.assigned_images = self.build_labeled_areas_table(labeled_areas.copy(), extract_images_for_areas, extracted_image_col)
            if self.processed_dir and save_labeled_areas_df:
                lb_dict = {
                    "labeled_areas": self.labeled_areas,
                    "assigned_images": self.assigned_images,
                }
                torch.save(lb_dict, labeled_areas_df_path)

        if self.processed_dir and save_patch_df and patch_df_path.exists():
            patch_dict = torch.load(patch_df_path)
            self.patch_df = patch_dict["patch_df"]
            self.length = len(self.patch_df)
        else:
            self.patch_df, self.length = self.build_patch_table_sequentially(sequential_stride=sequential_stride, allow_partial_patches=allow_partial_patches)
            if self.processed_dir and save_patch_df:
                patch_dict = {"patch_df": self.patch_df}
                torch.save(patch_dict, patch_df_path)

        
    
    def __len__(self):
        return self.length
    

    def __getitem__(self, idx):
        if self.processed_dir and self.save_samples:
            path = self.processed_dir / 'patches' / f"{self.dataset_name}_{idx}.pt"
            if path.exists():
                try:
                    sample = torch.load(path)
                except Exception as e:
                    logging.info(e)
                    sample = self.get_patch(idx)
                    torch.save(sample, path)
            else:
                sample = self.get_patch(idx)
                torch.save(sample, path)
        else:
            sample = self.get_patch(idx)
        if self.dataset_transform:
            sample = self.dataset_transform(sample)
        return sample

    
    def get_patch(self, idx: int):
        # Should returns a dict per source
        patch = self.patch_df.loc[idx]
        pt = patch.to_dict()
        pt.update(self.get_patch_raster_images(patch))
        pt['patch_id'] = idx

        if self.get_patch_custom:
            pt = self.get_patch_custom(pt=pt, dataset_name=self.dataset_name,labeled_areas=self.labeled_areas)
        
        return pt
    
    def get_patch_raster_images(self, patch):
        """Read the pixels under ``patch`` from every overlapping image.

        Returns a dict mapping each image source name (``src``) to a float32
        tensor. Pixels that are NaN or equal the image's ``nodatavals`` are set
        to NaN. When several images of one source overlap the patch they are
        merged according to ``self.merge_mode``:

        * ``"keep_first"`` - later images only fill NaN pixels of the current result;
        * ``"keep_last"``  - later images overwrite wherever they hold valid data.
        """
        rasters = {}
        candidates = self.get_patch_overlapping_images(patch)
        # The overlapping image ids are deliberately not added to the sample:
        # they cause problems in the dataloader collate step.
        for _, image in candidates.iterrows():
            source = image["src"]
            # Express the patch footprint in the CRS of this particular image.
            footprint = geometry_to_crs(patch.geometry, self.project_crs, image["ori_crs"])
            # Careful: the window comes from the bounds, not from the patch geometry itself.
            raw_window = rasterio.windows.from_bounds(*footprint.bounds, transform=image['ori_transform'])
            # Snap to whole pixels (rounding errors) and keep at least a 1x1 window.
            window = Window(
                np.floor(raw_window.col_off),
                np.floor(raw_window.row_off),
                max(1, round(raw_window.width)),
                max(1, round(raw_window.height)),
            )

            pixels = self.read_window_from_image(img_path=image["path"], wn=window, boundless=True)
            pixels = pixels.astype(np.float32)
            # NOTE(review): if "nodatavals" is a per-band tuple this comparison broadcasts
            # against the last array axis rather than the band axis - confirm its format.
            invalid = np.isnan(pixels) | (pixels == image["nodatavals"])
            pixels[invalid] = np.nan

            if source not in rasters:
                rasters[source] = pixels
            elif self.merge_mode == "keep_first":
                holes = np.isnan(rasters[source])
                rasters[source] = np.where(holes, pixels, rasters[source])
            elif self.merge_mode == 'keep_last':
                rasters[source] = np.where(invalid, rasters[source], pixels)

        for source in candidates["src"].unique():
            rasters[source] = torch.tensor(rasters[source], dtype=torch.float32)
        return rasters
    
    def get_patch_overlapping_images(self, patch):
        """Return the rows of ``images_df`` that intersect ``patch.geometry``.

        Only images assigned to the patch's area (``patch.area_id``) are
        considered. Patches are created on the reference source; this selects
        the images of any source that overlap the patch.
        """
        area_links = self.assigned_images.query("area_id == @patch.area_id")
        image_ids = area_links.reset_index().image_id
        area_images = self.images_df.loc[image_ids]
        touches_patch = area_images.intersects(patch.geometry)
        return area_images[touches_patch]

    def build_labeled_areas_table(self, labeled_areas: gpd.GeoDataFrame, extract_images_for_areas=False, extracted_image_col = "extracted_images"):
        """Assign images to labeled areas and drop areas that cannot be served.

        Steps:
          1. Spatially join ``self.images_df`` with ``labeled_areas`` (``intersects``).
          2. Drop areas that intersect no image at all.
          3. Drop areas that have no image of ``self.reference_source``.
          4. Replace each image geometry by its intersection with the area.
          5. Optionally extract the relevant image parts to
             ``processed_dir/extracted_images``.

        Args:
            labeled_areas: Areas to process; filtered copies are returned.
            extract_images_for_areas: Extract image parts when ``processed_dir`` is set.
            extracted_image_col: Column that ``extract_images`` must add to the result.

        Returns:
            ``(labeled_areas, assigned_images)``; ``assigned_images`` is indexed by
            ``(image_id, area_id)`` and its geometry is the area/image overlap.

        Raises:
            FileExistsError: The extraction folder already contains files.
        """
        assigned_images = gpd.sjoin(self.images_df[["src", "geometry"]],
                                    labeled_areas[["geometry"]],
                                    predicate="intersects",
                                    how="inner").rename(columns={"index_right": "area_id"})
        # filter out not found values
        missing_areas = labeled_areas.query("index not in @assigned_images.area_id")
        labeled_areas = labeled_areas.query("index in @assigned_images.area_id")
        if len(missing_areas) > 0:
            logging.info(f"{len(missing_areas)} areas with no image source, idx: {missing_areas.index}")
            logging.info(f"ignoring them for now")
        # Keep only the areas that are covered by the reference source.
        complete_areas = assigned_images.query("src == @self.reference_source").area_id
        missing_areas = labeled_areas.query("index not in @complete_areas")
        labeled_areas = labeled_areas.query("index in @complete_areas")
        assigned_images = assigned_images.query("area_id in @complete_areas")
        if len(missing_areas) > 0:
            logging.info(f"{len(missing_areas)} areas with no input image source, idx: {missing_areas.index}")
            logging.info(f"ignoring them for now")
        assigned_images.set_index(['area_id'], append=True, inplace=True)
        assert assigned_images.index.names == ["image_id", "area_id"]
        # todo this is ugly but faster than apply on the outer loop
        area_info = []
        for area_id, group_df in assigned_images.groupby("area_id"):
            area_geom = labeled_areas.loc[area_id]['geometry']

            def get_overlapping_area(row):
                # The result is the area/image overlap, hence the "ov_" prefix.
                return pd.Series([area_geom.intersection(row['geometry'])], index=["ov_geometry"])

            overlapping_area = group_df.apply(get_overlapping_area, axis=1)
            area_info.append(overlapping_area)

        area_info = pd.concat(area_info, axis=0)
        # add to lookup table
        assigned_images = assigned_images.join(area_info)
        # "geometry" refers to the original image rather than its overlap with the area,
        # so it is replaced by the common area-image geometry.
        assigned_images.set_geometry('ov_geometry', inplace=True)
        assigned_images.drop(labels=['geometry'], axis="columns", inplace=True)
        assigned_images.rename_geometry('geometry', inplace=True)

        if extract_images_for_areas and self.processed_dir is not None:
            ep = self.processed_dir / 'extracted_images'
            # iterdir() returns a generator, which is always truthy; any() checks
            # whether the folder really contains something.
            if ep.exists() and any(ep.iterdir()):
                raise FileExistsError(
                    f"Extracted images already exist in {ep}!!" + f"\nPlease remove the folder to re-extract." +
                    f"\nOtherwise change image source to the extracted location and set image extract to False")
            else:
                ep.mkdir(exist_ok=True)
                logging.info(f"Extracting relevant image parts to {ep}")
                assigned_images = extract_images(areas=labeled_areas,
                                                 images_df=self.images_df,
                                                 base_dir=ep,
                                                 extracted_image_col=extracted_image_col,
                                                 assigned_images=assigned_images,
                                                 prefix=self.dataset_name)
                assert extracted_image_col in assigned_images.columns
        else:
            logging.info("Not extracting images.")

        return labeled_areas, assigned_images


    def build_patch_table_sequentially(self, patch_start_index=0, sequential_stride=1, allow_partial_patches=False):
        """Tile every labeled area into patches on the reference source's pixel grid.

        Args:
            patch_start_index: First ``patch_id`` to assign.
            sequential_stride: ``int`` or ``(row_stride, col_stride)`` in pixels.
            allow_partial_patches: If true, patches start at every stride position
                inside the area even when they reach beyond it. Otherwise only
                patches that fit completely are created, and the grid is shifted
                by half of the remainder (extent modulo patch size).

        Returns:
            ``(patch_df, n_patches)``; ``patch_df`` is indexed by ``patch_id`` and
            only contains patches with at least one target or non-target feature.
        """
        # The signature default is a plain int, which cannot be unpacked.
        if isinstance(sequential_stride, int):
            sequential_stride = (sequential_stride, sequential_stride)
        stride_row, stride_col = sequential_stride

        patches = []
        height, width = self.patch_size
        images_grouped_by_area = self.assigned_images.query(f"src == '{self.reference_source}'").groupby('area_id')
        for area_idx, area_images in tqdm(images_grouped_by_area, desc="Iterating over areas"):
            area = self.labeled_areas.loc[area_idx]  # Returns a series
            img_indices = area_images.reset_index()["image_id"].values
            imgs = self.images_df.loc[img_indices]  #TODO:assign self.images_df

            ori_crs, ori_transform, union_geometry = get_imgs_reference(imgs)
            # TODO: Verify!
            # We still want to work with the overlapping geometry instead of the pure image geometry
            union_geometry = union_geometry.intersection(area.geometry)

            # Pixel extent of the area in the reference image grid.
            col_start, col_end, row_start, row_end = self.get_rows_cols_from_geom(ori_crs, ori_transform,
                                                                                  union_geometry)

            if row_start > row_end or col_start > col_end:
                logging.info(f"row_start > row_end or col_start > col_end {row_start} {row_end} {col_start} {col_end}")

            if allow_partial_patches:
                col_itr_start = col_start
                row_itr_start = row_start
                col_itr_end = col_end
                row_itr_end = row_end
            else:
                # Columns are measured against the patch width and rows against its
                # height (the original had them swapped, which only worked for squares).
                col_itr_start = col_start + ((col_end - col_start) % width) // 2
                row_itr_start = row_start + ((row_end - row_start) % height) // 2
                # range() excludes its end: +1 keeps the last patch that still fits entirely.
                col_itr_end = col_end - width + 1
                row_itr_end = row_end - height + 1

            for col, row in product(range(col_itr_start, col_itr_end, stride_col),
                                    range(row_itr_start, row_itr_end, stride_row)):
                # Window is defined as offset in pure image space
                wn = Window(col, row, width, height)
                # NOTE(review): xy() is called with its default offset - confirm that this
                # is the intended pixel anchor for the patch corners.
                wx, wy = rasterio.transform.xy(ori_transform, [row, row + height], [col, col + width])
                # Window box in original CRS
                wn_geometry = box(wx[0], wy[1], wx[1], wy[0])  # left, bottom, right, top
                transform_in_ori_crs = windows.transform(wn, ori_transform)

                patch = {
                    "ori_geometry": wn_geometry,
                    "geometry": wn_geometry,  # !!This is a placeholder which is transformed to project_crs in the end
                    "shape": (height, width),
                    "wn_ori": (col, row, width, height),
                    "ori_crs": ori_crs,
                    "ori_transform": transform_in_ori_crs,
                    "area_id": area_idx,
                }

                patches.append(patch)

        pdf = gpd.GeoDataFrame.from_dict(patches).set_crs(self.project_crs)
        pdf = get_df_in_single_crs(pdf, self.project_crs)

        # REMOVE ALL-0 MASKED PATCHES
        def features_count(rw, atf=self.assigned_target_features,nontf=self.assigned_nontrees_features):
            n_target = len(get_patch_features(atf, rw.geometry, area_id=rw.area_id))
            n_nontarget = len(get_patch_features(nontf, rw.geometry, area_id=rw.area_id))
            return n_target + n_nontarget

        pdf['n_features'] = pdf.apply(features_count, axis=1)
        # Drop patches that contain neither target nor non-target features.
        # copy() so the column assignment below does not write into a filtered view.
        pdf = pdf.query("n_features > 0").copy()

        pdf['patch_id'] = range(patch_start_index, len(pdf) + patch_start_index)

        # verify that patches overlap with labeled_area they come from and adhere to their shape
        pdfs = []
        for area_idx in self.labeled_areas.index:
            pdfs.append(
                pdf.query("area_id == @area_idx").sjoin(self.labeled_areas.query("index == @area_idx"),
                                                        how="inner",
                                                        predicate="intersects",
                                                        lsuffix="")[pdf.columns])
        pdf = pd.concat(pdfs)
        pdf.drop_duplicates("patch_id", inplace=True)

        # Re-number, since the overlap check may have removed patches.
        pdf['patch_id'] = range(patch_start_index, len(pdf) + patch_start_index)

        pdf.set_index("patch_id", inplace=True)

        if self.processed_dir is not None:
            p = self.processed_dir / 'qgis' / self.dataset_name
            logging.info(f"Writing labeled areas, image df and assigned images to {p}")
            p.mkdir(exist_ok=True, parents=True)
            with fiona.Env(OSR_WKT_FORMAT="WKT2_2018"):
                pdf[['geometry', 'area_id']].to_file(p / "patch_grid.gpkg", driver="GPKG")
                self.labeled_areas[['geometry']].to_file(p / "labeled_areas.gpkg", driver="GPKG")
                self.images_df[['geometry', 'path']].to_file(p / "image_df.gpkg", driver="GPKG")
                self.assigned_images[['geometry']].to_file(p / "assigned_images.gpkg", driver="GPKG")
                if self.assigned_target_features is not None:
                    self.assigned_target_features[['geometry']].to_file(p / "assigned_target_features.gpkg", driver="GPKG")

        return pdf, len(pdf)
    
    def get_rows_cols_from_geom(self, target_crs, ori_transform, geometry):
        """Convert the bounding box of ``geometry`` to pixel indices of an image.

        The bounds are transformed from ``self.project_crs`` to ``target_crs`` and
        then mapped to rows / columns with ``ori_transform``.

        Returns:
            ``(col_start, col_end, row_start, row_end)``.
        """
        left, bottom, right, top = geometry.bounds
        to_image_crs = Transformer.from_crs(self.project_crs, target_crs, always_xy=True)

        # In an array, (row, col) == (0, 0) is the top-left corner and rows grow
        # downwards. With xy ordering, y shrinks southwards (opposite to rows)
        # while x grows eastwards (same as columns). So the upper-left corner
        # gives the start indices and the lower-right corner the end indices.
        upper_left_x, upper_left_y = to_image_crs.transform(left, top)
        lower_right_x, lower_right_y = to_image_crs.transform(right, bottom)

        row_start, col_start = rasterio.transform.rowcol(ori_transform, upper_left_x, upper_left_y)
        row_end, col_end = rasterio.transform.rowcol(ori_transform, lower_right_x, lower_right_y)
        return col_start, col_end, row_start, row_end

    @staticmethod
    def read_window_from_image(img_path, wn, boundless=True, fill_value=0):
        """Open the raster at ``img_path`` and return the pixels inside window ``wn``.

        ``boundless`` and ``fill_value`` are passed through to the dataset's
        ``read`` call. The file is closed again before returning.
        """
        with rasterio.open(img_path) as raster:
            return raster.read(window=wn, boundless=boundless, fill_value=fill_value)

    @staticmethod
    def save(dataset, path):
        if not str(path).endswith('.pt'):
            path = path / f'dataset.pt'
        try:
            torch.save(dataset, path)
            return 1
        except Exception as e:
            return 0

    @staticmethod
    def load(path):
        if not str(path).endswith('.pt'):
            path = path / f'dataset.pt'
        return torch.load(path)

    @staticmethod
    def patches_count_calculator(area_dim, patch_dim, stride_dim, first_partial_patch=True):
        n = ((area_dim - patch_dim) / stride_dim) + 1
        if first_partial_patch:
            return max(1, math.ceil(n))
        else:
            return max(0, math.floor(n))

ModuleNotFoundError: No module named 'utils'