In [None]:
# Request parameters: area of interest as a WKT polygon plus the acquisition period.
AOI = 'POLYGON ((29.93584944158209 48.49763260054361, 29.92222752138016 48.19522597206074, 29.59257705249343 48.20884789226267, 29.60347458865497 48.51670328882631, 29.93584944158209 48.49763260054361))'
START_DATE, END_DATE = "2020-05-01", "2020-06-30"  # 'YYYY-MM-DD'
REQUEST_ID = 6

### Detecting plot boundaries for the given AOI

In [None]:
import os
import time
import cv2
import rasterio
import pandas as pd
import numpy as np
import geopandas as gpd
import rasterio.mask
import tempfile
import shapely
import re

from tqdm import tqdm
from os.path import join, basename, split
from skimage import measure
from scipy.ndimage import rotate
from rasterio.features import rasterize, shapes
from shapely.geometry import Polygon, shape, LinearRing
import shapely.wkt
from pathlib import Path
from datetime import datetime
import yaml

from sentinel2download.downloader import Sentinel2Downloader
from sip_plot_boundary_detection_nn.code.preprocessing import (
    preprocess_sentinel_raw_data, read_raster, extract_tci)
from sip_plot_boundary_detection_nn.code.engine import *
from sip_plot_boundary_detection_nn.code.dataset import BoundaryDetector
from sip_plot_boundary_detection_nn.code.filter_polygons import filter_polygons

import logging
import warnings

# Hide all Python warnings to keep the cell output readable.
warnings.filterwarnings('ignore')

# Configure the root logger to emit INFO and above.
logging.getLogger().setLevel(logging.INFO)
logger = logging.getLogger()

In [None]:
# Release cached GPU memory left over from previous runs in this kernel.
from torch import cuda
cuda.empty_cache()

In [None]:
default_crs = 'EPSG:4326'

# Persist the AoI to a uniquely named GeoJSON so the helpers below can read it from disk.
polygon = shapely.wkt.loads(AOI)
aoi_filename = f"{time.time()}_aoi.geojson"
# Declare the CRS explicitly rather than relying on the reader's GeoJSON default.
gpd.GeoDataFrame(geometry=[polygon], crs=default_crs).to_file(aoi_filename, driver="GeoJSON")

start_date = datetime.strptime(START_DATE, '%Y-%m-%d')
end_date = datetime.strptime(END_DATE, '%Y-%m-%d')
if end_date < start_date:
    raise ValueError(f"END_DATE ({END_DATE}) is before START_DATE ({START_DATE})")

In [None]:
def get_tiles(aoi_path, sentinel_tiles_path):
    '''
    Returns the Sentinel-2 tiles that cover the specified AoI, chosen greedily.

    At every step the tile with the largest intersection with the still
    uncovered part of the AoI is selected and its footprint is removed from
    the remaining AoI. The loop stops when no uncovered area is left, when no
    tile intersects the remainder, or when the best candidate was already
    selected (numeric slivers that would otherwise loop forever).

        Parameters:
            aoi_path (str): Path to geojson/shp file with AoI to process.
            sentinel_tiles_path (str): Path to geojson/shp file with all Sentinel-2 tiles
                (the tile ID is read from its "Name" column).

        Returns:
            date_tile_info (GeoDataFrame): Selected tiles with columns "tileID" and
                "geometry" (the part of the AoI assigned to that tile), in the AoI's CRS.
    '''
    aoi_file = gpd.read_file(aoi_path)
    sentinel_tiles = gpd.read_file(sentinel_tiles_path)
    sentinel_tiles.set_index("Name", drop=False, inplace=True)

    best_intersection = {"tileID": [], "geometry": []}
    rest_aoi = aoi_file.copy()

    while rest_aoi.area.sum() > 0:
        res_intersection = gpd.overlay(rest_aoi, sentinel_tiles, how="intersection")
        if res_intersection.empty:
            # The remaining AoI lies outside every known tile: nothing more to cover.
            break

        # argmax is positional, so index with iloc (not label-based loc).
        best = res_intersection.iloc[res_intersection.area.to_numpy().argmax()]
        tile_id = best["Name"]
        if tile_id in best_intersection["tileID"]:
            # Leftover slivers only touch an already chosen tile; stop instead of looping forever.
            break

        best_intersection["tileID"].append(tile_id)
        best_intersection["geometry"].append(best["geometry"])

        rest_aoi = gpd.overlay(rest_aoi, sentinel_tiles.loc[[tile_id]], how="difference")
        # Only tiles that touched the remaining AoI can contribute further; dedupe names
        # so a tile intersecting several leftover parts is not duplicated.
        sentinel_tiles = sentinel_tiles.loc[res_intersection["Name"].unique()]

    return gpd.GeoDataFrame(best_intersection, geometry="geometry", crs=aoi_file.crs)


In [None]:
def process_polygons(result_df, current_crs, limit=500, dst_crs="EPSG:4326"):
    """
    Prepare result Dataframe with polygons: tag it with its CRS and reproject it.

        Parameters:
            result_df (pd.DataFrame): Result DataFrame with a "geometry" column.
            current_crs: CRS the geometries of ``result_df`` are expressed in.
            limit (int): Intended minimum polygon area in m2. NOTE: currently NOT
                applied; kept only for backward compatibility of the signature.
            dst_crs (str): CRS of the returned GeoDataFrame.
        Returns:
            GeoDataFrame: GeoDataFrame in ``dst_crs`` ready for saving.
    """
    if "geometry" not in result_df.columns:
        # Nothing was detected (e.g. an empty DataFrame): return an empty frame
        # instead of failing when assigning a CRS to a frame without geometry.
        return gpd.GeoDataFrame(geometry=[], crs=dst_crs)

    gdf = gpd.GeoDataFrame(result_df, geometry="geometry")
    # allow_override mirrors the original plain `.crs = ...` assignment: the
    # caller's CRS wins even if the geometry column already carries one.
    gdf = gdf.set_crs(current_crs, allow_override=True)
    return gdf.to_crs(dst_crs)


def save_polygons(gdf, save_path):
    """
    Write ``gdf`` to ``save_path`` as GeoJSON.

    The data is written to ``save_path + ".temp"`` first and then moved onto
    ``save_path`` with ``os.replace``, so the destination is only replaced once
    the write has finished. Missing parent directories are created.

        Parameters:
            gdf (GeoDataFrame): polygons to save.
            save_path (str): destination file path.
        Returns:
            GeoDataFrame or None: ``gdf`` when it was written, None when it is
            empty (nothing is written in that case).
    """
    if len(gdf) == 0:
        return None

    directory = os.path.dirname(save_path)
    if directory:
        # dirname is "" for a bare filename; os.makedirs("") would raise.
        os.makedirs(directory, exist_ok=True)

    temp_path = save_path + ".temp"
    gdf.to_file(temp_path, driver='GeoJSON')
    # os.replace overwrites an existing target on every platform (os.rename does not on Windows).
    os.replace(temp_path, save_path)

    return gdf

### Find tile indexes

In [None]:
import getpass

# NB_USER is expected in the environment; fall back to the OS user so the
# paths never silently become "/home/None/work".
NB_USER = os.getenv('NB_USER') or getpass.getuser()
BASE = f"/home/{NB_USER}/work"

API_KEY = os.path.join(BASE, ".secret/sentinel2_google_api_key.json")
LOAD_DIR = os.path.join(BASE, "satellite_imagery")
RESULTS_DIR = os.path.join(BASE, "results/pbdnn")
PBD_DIR = os.path.join(BASE, "notebooks/pbdnn")

# Download settings passed to Sentinel2Downloader.
BANDS = {'TCI'}
CONSTRAINTS = {'NODATA_PIXEL_PERCENTAGE': 15.0, 'CLOUDY_PIXEL_PERCENTAGE': 40.0, }
PRODUCT_TYPE = 'L2A'

In [None]:
local = False  # True selects the path layout under BASE/data

if not local:
    ukr_shapefile = os.path.join(BASE, "notebooks/pbdnn/sip_plot_boundary_detection_nn/ukr_shapes/custom.geo.json")
    usa_shapefile = os.path.join(BASE, "notebooks/pbdnn/sip_plot_boundary_detection_nn/usa_shapes/custom.geo.json")
    config_file = os.path.join(BASE, "notebooks/pbdnn/sip_plot_boundary_detection_nn/code/config.yaml")
else:
    # Local layout: everything lives under BASE/data, including the API key and PBD_DIR.
    (ukr_shapefile, usa_shapefile, config_file, API_KEY, PBD_DIR) = (
        os.path.join(BASE, relative)
        for relative in (
            "data/notebooks/pbdnn/sip_plot_boundary_detection_nn/ukr_shapes/custom.geo.json",
            "data/notebooks/pbdnn/sip_plot_boundary_detection_nn/usa_shapes/custom.geo.json",
            "data/notebooks/pbdnn/sip_plot_boundary_detection_nn/code/config.yaml",
            "data/notebooks/pbdnn/sentinel2_google_api_key.json",
            "data/notebooks/pbdnn",
        )
    )

# safe_load: the config is parsed as plain data only.
with open(config_file) as f:
    config = yaml.safe_load(f)

### Check location before filtering non-agricultural lands

In [None]:
ukraine = gpd.read_file(ukr_shapefile)
usa = gpd.read_file(usa_shapefile)
aoi = gpd.read_file(aoi_filename)


def _aoi_intersects(aoi_gdf, region_gdf):
    """Return True if the first AoI geometry intersects ANY feature of ``region_gdf``.

    ``GeoDataFrame.intersects(other_gdf)`` aligns rows by index, so comparing
    two frames directly only tests row 0 against row 0. Here the region is
    reprojected to the AoI CRS (when both are known) and every feature is tested.
    """
    if region_gdf.crs is not None and aoi_gdf.crs is not None and region_gdf.crs != aoi_gdf.crs:
        region_gdf = region_gdf.to_crs(aoi_gdf.crs)
    return bool(region_gdf.intersects(aoi_gdf.geometry.iloc[0]).any())


# Same repository layout as the other paths: under "data/" when running locally.
pbd_repo_dir = os.path.join(
    BASE,
    "data/notebooks/pbdnn/sip_plot_boundary_detection_nn" if local
    else "notebooks/pbdnn/sip_plot_boundary_detection_nn")

# Pick the non-agricultural mask matching the AoI's country; None disables filtering.
if _aoi_intersects(aoi, ukraine):
    filter_path = os.path.join(pbd_repo_dir, "ukr_shapes/ukr_non_agriculture.geojson")
elif _aoi_intersects(aoi, usa):
    filter_path = os.path.join(pbd_repo_dir, "usa_shapes/us_shape.geojson")
else:
    filter_path = None

### Download data

In [None]:
def _check_folder(tile_folder, file, limit, nodata):
    """
    Tell whether a raster's first band has few enough NoData pixels.

    Opens ``tile_folder/file`` with rasterio, counts the pixels of band 1
    equal to ``nodata`` and compares their share (percent, rounded to two
    decimals) with ``limit``.

        Returns:
            bool: True if the NoData percentage is <= ``limit``, else False.
    """
    raster_path = os.path.join(tile_folder, file)
    with rasterio.open(raster_path) as src:
        band = src.read(1)
        nodata_share = round(np.count_nonzero(band == nodata) / band.size * 100, 2)
        return nodata_share <= limit

In [None]:
def check_nodata(loadings, product_type, limit=15.0, nodata=0):
    """
    Keep only the downloaded scene folders that look usable.

    A folder is kept when it contains at least one ``.jp2`` file whose name
    does not contain "OPER". For 'L1C' products with a truthy ``limit`` the
    folder is kept only if at least one such file passes ``_check_folder``
    (NoData share <= ``limit`` percent); other product types are not checked
    for NoData.

        Parameters:
            loadings (dict): tile ID -> iterable of scene folder paths.
            product_type (str): product level, e.g. 'L1C' or 'L2A'.
            limit (float): maximum allowed NoData percentage (L1C only).
            nodata (int): pixel value counted as NoData.
        Returns:
            dict: tile ID -> set of folders that passed the check.
    """
    # The pixel-level check depends only on the arguments: decide it once.
    check_pixels = product_type == 'L1C' and limit

    filtered = dict()
    for tile, folders in loadings.items():
        filtered_folders = set()
        for folder in folders:
            for file in os.listdir(folder):
                if not file.endswith(".jp2") or "OPER" in file:
                    continue
                if not check_pixels or _check_folder(folder, file, limit, nodata):
                    # One acceptable band file is enough to keep the folder.
                    filtered_folders.add(folder)
                    break
        filtered[tile] = filtered_folders
    return filtered

In [None]:
def filter_date(loadings):
    """
    Keep only the most recent scene folder per tile.

    The sensing date is parsed from the first ``_<YYYYMMDD>T<digits>_`` match
    in each folder path. Folders without a parseable date are skipped (with a
    warning); a tile with no dated folder is dropped (with a warning). Ties on
    the date are broken deterministically by the lexicographically greatest path.

        Parameters:
            loadings (dict): tile ID -> iterable of folder paths (str or Path).
        Returns:
            dict: tile ID -> path (str) of the latest folder.
    """
    import logging
    log = logging.getLogger(__name__)
    date_pattern = re.compile(r"_(\d+)T\d+_")

    filtered = dict()
    for tile, folders in loadings.items():
        dated = []
        for folder in (str(f) for f in folders):
            match = date_pattern.search(folder)
            if match is None:
                log.warning("No sensing date in folder name %s; skipping it", folder)
                continue
            try:
                date = datetime.strptime(match.group(1), '%Y%m%d')
            except ValueError:
                log.warning("Unparseable date %r in %s; skipping it", match.group(1), folder)
                continue
            dated.append((date, folder))

        if not dated:
            log.warning("No dated scene folder for tile %s; tile dropped", tile)
            continue
        # (date, path) tuples: latest date first, then greatest path on ties.
        filtered[tile] = max(dated)[1]
    return filtered

In [None]:
def load_images(api_key, tiles, start_date, end_date, output_dir, product_type="L2A",
                bands=None, constraints=None):
    """
    Download Sentinel-2 scenes for every tile and return their folders.

        Parameters:
            api_key (str): path to the Google API key passed to Sentinel2Downloader.
            tiles (iterable[str]): Sentinel-2 tile IDs.
            start_date, end_date (str): acquisition period, as accepted by the downloader.
            output_dir (str): download destination.
            product_type (str): product level, e.g. "L2A".
            bands (set, optional): bands to download; defaults to the notebook's BANDS.
            constraints (dict, optional): download constraints; defaults to CONSTRAINTS.
        Returns:
            dict: tile ID -> set of parent folders (str) of the downloaded files.
    """
    # Resolve the defaults at call time so the module-level settings still apply.
    if bands is None:
        bands = BANDS
    if constraints is None:
        constraints = CONSTRAINTS

    loader = Sentinel2Downloader(api_key)
    tile_folders = dict()
    for tile in tiles:
        loaded = loader.download(product_type,
                                 [tile],
                                 start_date=start_date,
                                 end_date=end_date,
                                 output_dir=output_dir,
                                 bands=bands,
                                 constraints=constraints)
        # Assumes each entry's first element is a downloaded file path; keep its folder.
        # `or []` tolerates a download call that returns None when nothing was found.
        tile_folders[tile] = {str(Path(entry[0]).parent) for entry in (loaded or [])}
    return tile_folders

In [None]:
# Credit for baseline: work/notebooks/pw/raster_predict.ipynb
def create_style(class_):
    """
    Build the display style of a polygon class as a stringified dict.

    The 'color' entry comes from a per-class palette (case-insensitive lookup,
    '#C0C0C0' for unknown classes); 'stroke' is always '#e80e27' and
    'stroke-width' is always 2.
    """
    palette = {'boundary': '#e80e27'}
    class_color = palette.get(class_.lower(), '#C0C0C0')
    return str({'color': class_color, 'stroke': '#e80e27', 'stroke-width': 2})

In [None]:
import torch  # bind torch explicitly instead of relying on the star import from engine

sentinel_tiles_path = "sentinel2grid.geojson"
model = make_unet_plusplus()
config['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'

# Same repository layout as the other paths: under "data/" when running locally.
pbd_repo_dir = os.path.join(
    BASE,
    "data/notebooks/pbdnn/sip_plot_boundary_detection_nn" if local
    else "notebooks/pbdnn/sip_plot_boundary_detection_nn")
config['model_weights_path'] = os.path.join(
    pbd_repo_dir, 'models/chkpt_UnetPlusPlus_imagenet_200_new_dataset_v3.pt')

# map_location loads the checkpoint onto whichever device is available (GPU or CPU).
model.load_state_dict(torch.load(config['model_weights_path'],
                                 map_location=torch.device(config['device'])))

# Tile selection -> download -> per-folder usability check -> latest scene per tile.
date_tile_info = get_tiles(aoi_filename, sentinel_tiles_path)
loadings = load_images(API_KEY, date_tile_info.tileID.values, START_DATE, END_DATE, LOAD_DIR, PRODUCT_TYPE)
checked = check_nodata(loadings, PRODUCT_TYPE)
filtered = filter_date(checked)

origin_name = os.path.basename(aoi_filename).replace(".geojson", "")

# The detector reads tiles from the same directory the scenes were downloaded to.
detector = BoundaryDetector(model, tiles_dir=LOAD_DIR)

In [None]:
result_df = pd.DataFrame([])

# Non-agricultural mask used to drop detected polygons. Stays None when the AoI
# is outside the supported regions, so later cells can test it safely.
filters = None
if filter_path is not None:
    filters = gpd.read_file(filter_path)
    aoi = gpd.read_file(aoi_filename).to_crs(filters.crs)

    if 'ukr_non_agriculture' in filter_path:
        # buffer(0) presumably repairs invalid geometries in this dataset — TODO confirm.
        filters['geometry'] = filters.buffer(0)

    # Keep only mask features that touch the AoI.
    filters = filters[filters.intersects(aoi.geometry.values[0])]

In [None]:
from contextlib import ExitStack

from rasterio.merge import merge


def stitch_tiles(paths, out_raster_path='test.tif'):
    """
    Mosaic several rasters into one GeoTIFF.

    Overlapping pixels take the value of the later raster in ``paths``
    (``method='last'``). Output metadata is copied from the last raster and
    updated with the mosaic's driver, size, transform and CRS.

    NOTE(review): rasterio's ``merge`` does not reproject; the inputs are
    assumed to share one CRS — confirm for AoIs spanning several UTM zones.

        Parameters:
            paths (list[str]): rasters to stitch (at least one).
            out_raster_path (str): destination; ".jp2" in the name is replaced by
                ".tif" because the output is written as GeoTIFF.
        Returns:
            str: path of the written GeoTIFF.
        Raises:
            ValueError: if ``paths`` is empty.
    """
    if not paths:
        raise ValueError("stitch_tiles() requires at least one raster path")

    # ExitStack closes every opened dataset even if opening a later one or merge() fails.
    with ExitStack() as stack:
        sources = [stack.enter_context(rasterio.open(p)) for p in paths]
        mosaic, transform = merge(sources, method='last')
        meta = sources[-1].meta.copy()
        crs = sources[-1].crs

    meta.update({"driver": "GTiff",
                 "height": mosaic.shape[1],
                 "width": mosaic.shape[2],
                 "transform": transform,
                 "crs": crs})

    out_raster_path = out_raster_path.replace('.jp2', '.tif')
    with rasterio.open(out_raster_path, "w", **meta) as dst:
        dst.write(mosaic)

    return out_raster_path

In [None]:
CONF_THRESH = 0.25  # detection confidence threshold used by both prediction steps

# Reset here so re-running this cell does not append duplicate results.
result_df = pd.DataFrame([])
result_crs = None

with tempfile.TemporaryDirectory(dir=PBD_DIR) as tmpdirname:
    # One .jp2 per tile that survived download / usability / date filtering.
    rasters, used_tiles = [], []
    for _, tile in date_tile_info.iterrows():
        tile_folder = filtered.get(tile.tileID)
        if tile_folder is None:
            continue
        jp2_files = sorted(x for x in os.listdir(tile_folder) if x.endswith('.jp2'))
        if not jp2_files:
            continue
        rasters.append(os.path.join(tile_folder, jp2_files[0]))
        used_tiles.append(tile.tileID)

    if not rasters:
        logger.warning("No usable Sentinel-2 scenes for the AoI between %s and %s",
                       START_DATE, END_DATE)
    else:
        if len(rasters) > 1:
            # Write the mosaic into the temp dir rather than the working directory.
            raster_path = stitch_tiles(rasters, os.path.join(tmpdirname, 'stitched.tif'))
        else:
            raster_path = rasters[0]

        # Keep the prediction raster in the temp dir: deriving it from the input
        # name by replacing '.tif' left .jp2 paths unchanged and overwrote the input.
        out_raster = os.path.join(tmpdirname, Path(raster_path).stem + '_prediction.tif')
        out_geom = aoi_filename.replace('_aoi.', '_prediction.')

        pred_tif_path = detector.raster_prediction(in_raster_path=raster_path,
                                                   out_raster_path=out_raster,
                                                   aoi_path=aoi_filename,
                                                   conf_thresh=CONF_THRESH)
        polygons = detector.process_raster_predictions(pred_tif_path,
                                                       shapes_path=out_geom,
                                                       aoi_path=aoi_filename,
                                                       conf_thresh=CONF_THRESH)
        polygons = gpd.GeoDataFrame(polygons)

        if filter_path is not None:
            try:
                polys = filter_polygons(polygons, filters)
            except Exception:
                # Retry once with buffer(0)-repaired mask geometries.
                filters['geometry'] = filters.buffer(0)
                polys = filter_polygons(polygons, filters)
            result_crs = filters.crs
        else:
            # No mask for this region: keep every detected polygon.
            polys = polygons
            result_crs = polygons.crs

        # Label with the tiles actually used (the mosaic may span several).
        tile_label = "_".join(used_tiles)
        df = pd.DataFrame({"geometry": polys.geometry}).reset_index()
        df["id"] = [f"{origin_name}_{tile_label}_{x}" for x in df.index.values]
        df["tileID"] = tile_label
        df["start_date"] = START_DATE
        df["end_date"] = END_DATE

        result_df = pd.concat([result_df, df])

if result_df.empty:
    logger.warning("No polygons detected for the AoI; nothing saved")
    gdf = gpd.GeoDataFrame(geometry=[])
else:
    gdf = process_polygons(result_df, result_crs)
    gdf['class_'] = 'boundary'
    gdf['style'] = gdf.class_.apply(create_style)
    save_path = os.path.join(RESULTS_DIR, f"{origin_name}_prediction.geojson")
    save_polygons(gdf, save_path)