In [None]:
# Area of interest as a WKT polygon; parsed with shapely.wkt.loads further down.
# NOTE(review): coordinates look like lon/lat (EPSG:4326) - confirm with the caller.
AOI = 'POLYGON ((-85.299088 40.339368, -85.332047 40.241477, -85.134979 40.229427, -85.157639 40.34146, -85.299088 40.339368))'
# Requested period as 'YYYY-MM-DD' strings. They are written into the result file
# name and metadata; the download windows below are built from the year of
# START_DATE only.
START_DATE = "2020-05-01"
END_DATE = "2020-06-30"

# Request identifier: prefix of the result file name and `request_id` in its metadata.
REQUEST_ID = '6'

### Detecting boundaries for given AOI

In [None]:
import json
import os
import re
import shutil
import tempfile
import time
from datetime import datetime
from os.path import join, basename, split
from pathlib import Path

# Third-party imports keep their original relative order on purpose: native
# geospatial / deep-learning libraries can be sensitive to import order.
import cv2
import rasterio
import pandas as pd
import numpy as np
import geopandas as gpd
import rasterio.mask
import shapely
from tqdm import tqdm
from skimage import measure
from scipy.ndimage import rotate
from rasterio.features import rasterize, shapes
from rasterio.merge import merge
from shapely.geometry import Polygon, shape, LinearRing
import shapely.wkt
import yaml
import torch

from sentinel2download.downloader import Sentinel2Downloader
from sip_plot_boundary_detection_nn.code.preprocessing import (
    preprocess_sentinel_raw_data, read_raster, extract_tci)
from sip_plot_boundary_detection_nn.code.engine import load_model, val_tfs
from sip_plot_boundary_detection_nn.code.dataset import BoundaryDetector
from sip_plot_boundary_detection_nn.code.filter_polygons import filter_polygons, filter_sindex
from sip_plot_boundary_detection_nn.code.utils import transform_crs

import warnings
# Silences every warning category for the rest of the notebook.
warnings.filterwarnings('ignore')
import logging
# Configure the root logger so INFO-level messages are emitted.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Release cached CUDA memory held by torch's allocator before the model is loaded.
torch.cuda.empty_cache()

In [None]:
# CRS the AOI coordinates are expressed in.
default_crs = 'EPSG:4326'

# Parse the WKT AOI and persist it as a GeoJSON helper file; the timestamp in
# the file name keeps concurrent runs from overwriting each other's AOI.
polygon = shapely.wkt.loads(AOI)
aoi_filename = f"{time.time()}_aoi.geojson"
# Build the frame through the explicit geometry/crs arguments so the written
# file carries its CRS instead of relying on the reader's default.
gpd.GeoDataFrame(geometry=[polygon], crs=default_crs).to_file(aoi_filename, driver="GeoJSON")
# Requested period as datetime objects.
start_date = datetime.strptime(START_DATE, '%Y-%m-%d')
end_date = datetime.strptime(END_DATE, '%Y-%m-%d')

In [None]:
def get_tiles(aoi_path, sentinel_tiles_path):
    '''
    Returns Sentinel-2 tiles that intersect the specified AoI.

    Tiles are picked greedily: on every step the tile with the largest
    intersection with the still-uncovered part of the AoI is selected and
    its footprint is subtracted from the AoI, until nothing is left.

        Parameters:
            aoi_path (str): Path to geojson/shp file with AoI to process.
            sentinel_tiles_path (str): Path to geojson/shp file with all Sentinel-2 tiles
                (must contain a "Name" column with the tile id).

        Returns:
            date_tile_info (GeoDataFrame): Selected tiles (tileID, geometry), where
                geometry is the part of the AoI assigned to the tile. Empty when
                no tile intersects the AoI.
    '''
    aoi_file = gpd.read_file(aoi_path)
    sentinel_tiles = gpd.read_file(sentinel_tiles_path)
    sentinel_tiles.set_index("Name", drop=False, inplace=True)

    best_intersection = {"tileID": [], "geometry": []}
    rest_aoi = aoi_file.copy()

    while rest_aoi.area.sum() > 0:
        res_intersection = gpd.overlay(rest_aoi, sentinel_tiles, how="intersection")
        if res_intersection.empty:
            # Nothing covers the remaining area: stop instead of failing on an
            # argmax over an empty series.
            print('WARNING: part of the AoI is not covered by any Sentinel-2 tile, it is ignored')
            break

        # idxmax returns an index label, which is what .loc expects.
        biggest_area_idx = res_intersection.area.idxmax()

        tileID = res_intersection.loc[biggest_area_idx, "Name"]
        this_aoi = res_intersection.loc[biggest_area_idx, "geometry"]

        best_intersection["tileID"].append(tileID)
        best_intersection["geometry"].append(this_aoi)

        biggest_intersection = sentinel_tiles.loc[[tileID]]
        rest_aoi = gpd.overlay(rest_aoi, biggest_intersection, how="difference")
        # Only tiles that touched the AoI can matter on later steps; unique()
        # prevents the candidate set from accumulating duplicate rows.
        sentinel_tiles = sentinel_tiles.loc[res_intersection["Name"].unique()]

    date_tile_info = gpd.GeoDataFrame(best_intersection, geometry="geometry", crs=aoi_file.crs)

    return date_tile_info


In [None]:
def process_polygons(result_df, current_crs, limit=500, dst_crs="EPSG:4326"):
    """
    Wrap the result DataFrame into a GeoDataFrame and reproject it.

        Parameters:
            result_df (pd.DataFrame): Result DataFrame with polygons
            current_crs: CRS the geometries are currently expressed in
            limit (int): min area for polygon in m2.
                NOTE(review): accepted but never used by the body - no area
                filtering happens here.
            dst_crs (str): CRS of the returned GeoDataFrame
        Returns:
            GeoDataFrame: GeoDataFrame in `dst_crs`, ready for saving
    """

    gdf = gpd.GeoDataFrame(result_df)
    # Declare the CRS first (no coordinates change) ...
    gdf.crs = current_crs

    # ... then reproject the coordinates to the destination CRS.
    gdf.to_crs(dst_crs, inplace=True)
    return gdf


def save_polygons(gdf, save_path):
    """
    Write polygons to `save_path` as GeoJSON, creating parent folders if needed.

        Parameters:
            gdf (GeoDataFrame): Polygons to save
            save_path (str): Destination file path
        Returns:
            GeoDataFrame or None: `gdf` when it was written, None when `gdf`
                is empty (nothing is written in that case)
    """
    if len(gdf) == 0:
        return None

    directory = os.path.dirname(save_path)
    # A bare file name has no directory part; os.makedirs('') would raise.
    if directory:
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(directory, exist_ok=True)

    gdf.to_file(save_path, driver='GeoJSON')

    return gdf

### Find tile indexes

In [None]:
# Working directory is derived from the notebook user name.
# NOTE(review): os.getenv returns None when NB_USER is unset, which would make
# BASE "/home/None/work" - confirm the variable is always provided.
NB_USER = os.getenv('NB_USER')
BASE = f"/home/{NB_USER}/work"

# Default locations; API_KEY and PBD_DIR may be overridden by the layout
# detection cell below.
API_KEY = os.path.join(BASE, ".secret/sentinel2_google_api_key.json")
LOAD_DIR = os.path.join(BASE, "satellite_imagery")
RESULTS_DIR = os.path.join(BASE, "results/pbdnn")
PBD_DIR = os.path.join(BASE, "notebooks/pbdnn")

# Download settings passed to Sentinel2Downloader.download in load_images.
BANDS = {'TCI', 'B08'}
CONSTRAINTS = {'NODATA_PIXEL_PERCENTAGE': 15.0, 'CLOUDY_PIXEL_PERCENTAGE': 15.0, }
PRODUCT_TYPE = 'L2A'

In [None]:
# Two folder layouts are supported: a "local" one where the project lives under
# BASE/data/notebooks/pbdnn and the default one under BASE/notebooks/pbdnn.
# The presence of the Ukraine shape file in the "data" layout decides which
# one is active.
_local_prefix = "data/notebooks/pbdnn"
_default_prefix = "notebooks/pbdnn"

local = os.path.exists(os.path.join(
    BASE, f"{_local_prefix}/sip_plot_boundary_detection_nn/ukr_shapes/custom.geo.json"))

_prefix = _local_prefix if local else _default_prefix
_repo = f"{_prefix}/sip_plot_boundary_detection_nn"

ukr_shapefile = os.path.join(BASE, f"{_repo}/ukr_shapes/custom.geo.json")
oh_shapefile = os.path.join(BASE, f"{_repo}/usa_shapes/states_shapes/ohio.geojson")
in_shapefile = os.path.join(BASE, f"{_repo}/usa_shapes/states_shapes/indiana.geojson")
il_shapefile = os.path.join(BASE, f"{_repo}/usa_shapes/states_shapes/illinois.geojson")
config_file = os.path.join(BASE, f"{_repo}/code/config.yaml")
sentinel_tiles_path = os.path.join(BASE, f"{_prefix}/sentinel2grid.geojson")

if local:
    # Only the local layout relocates the API key and the project directory.
    API_KEY = os.path.join(BASE, f"{_prefix}/sentinel2_google_api_key.json")
    PBD_DIR = os.path.join(BASE, _prefix)

# Model / inference settings.
with open(config_file) as f:
    config = yaml.safe_load(f)

### Check location before filtering non-agricultural lands

In [None]:
# Maps the non-agricultural filter file of a region (relative to BASE) to the
# file with that region's outline.
locations = {'notebooks/pbdnn/sip_plot_boundary_detection_nn/ukr_shapes/ukr_non_agriculture.geojson': ukr_shapefile,
             'notebooks/pbdnn/sip_plot_boundary_detection_nn/usa_shapes/ohio_shape.geojson': oh_shapefile,
             'notebooks/pbdnn/sip_plot_boundary_detection_nn/usa_shapes/indiana_shape.geojson': in_shapefile,
             'notebooks/pbdnn/sip_plot_boundary_detection_nn/usa_shapes/illinois_shape.geojson': il_shapefile
            }
filter_path = None
aoi = gpd.read_file(aoi_filename)
aoi_geometry = aoi.geometry.iloc[0]

for filters_, location in locations.items():
    loc = gpd.read_file(location)
    # Compare in a common CRS when both sides declare one.
    if loc.crs is not None and aoi.crs is not None and loc.crs != aoi.crs:
        loc = loc.to_crs(aoi.crs)
    # Test the AOI against every feature of the region file. The previous
    # `aoi.intersects(loc)[0]` was index-aligned and therefore only looked at
    # the first feature.
    if loc.intersects(aoi_geometry).any():
        # When several regions match, the last one listed wins (unchanged).
        if local:
            filter_path = os.path.join(BASE, "data", filters_)
        else:
            filter_path = os.path.join(BASE, filters_)

### Download data

In [None]:
def _check_folder(tile_folder, file, limit, nodata):
    """
    Tell whether the share of NoData pixels in a raster is within the limit.

        Parameters:
            tile_folder (str): Folder containing the raster
            file (str): Raster file name inside `tile_folder`
            limit (float): Max allowed share of NoData pixels, in percent
            nodata: Pixel value counted as NoData
        Returns:
            bool: True when the NoData percentage of band 1 is <= `limit`
    """
    raster_path = os.path.join(tile_folder, file)
    with rasterio.open(raster_path) as src:
        band = src.read(1)

    # Percentage of pixels equal to the NoData value, rounded to 2 decimals.
    nodata_share = round(np.count_nonzero(band == nodata) / band.size * 100, 2)
    return bool(nodata_share <= limit)

In [None]:
def check_nodata(loadings, product_type, limit=15.0, nodata=0):
    """
    Keep only the folders that contain a usable band raster.

    A folder qualifies when it holds a ".jp2" file whose name does not contain
    "OPER". For 'L1C' products with a truthy `limit`, such a file must also
    pass the NoData check; for other product types its presence is enough.

        Parameters:
            loadings (dict): tile id -> iterable of folder paths
            product_type (str): Sentinel-2 product type ('L1C', 'L2A', ...)
            limit (float): Max allowed NoData percentage (L1C only)
            nodata: Pixel value counted as NoData
        Returns:
            dict: tile id -> set of folders that qualified
    """
    must_inspect = product_type == 'L1C' and limit

    def _qualifies(folder):
        band_files = (name for name in os.listdir(folder)
                      if name.endswith(".jp2") and "OPER" not in name)
        if must_inspect:
            # Stops at the first raster that passes, like the original break.
            return any(_check_folder(folder, name, limit, nodata) for name in band_files)
        return any(True for _ in band_files)

    return {tile: {folder for folder in folders if _qualifies(folder)}
            for tile, folders in loadings.items()}

In [None]:
def load_images(api_key, tiles, start_date, end_date, output_dir, product_type="L2A"):
    """
    Download the configured bands for every tile and report where they landed.

    Uses the module-level BANDS and CONSTRAINTS settings.

        Parameters:
            api_key (str): Passed to Sentinel2Downloader
            tiles (iterable): Tile ids to download
            start_date, end_date: Period passed to the downloader
            output_dir (str): Download destination
            product_type (str): Sentinel-2 product type
        Returns:
            dict: tile id -> set of folders (as str) containing the downloaded files
    """
    downloader = Sentinel2Downloader(api_key)

    downloaded = {}
    for tile in tiles:
        downloaded[tile] = downloader.download(product_type,
                                               [tile],
                                               start_date=start_date,
                                               end_date=end_date,
                                               output_dir=output_dir,
                                               bands=BANDS,
                                               constraints=CONSTRAINTS)
        print(f'{tile} loaded')

    # The first element of each download record is used as the file path
    # (presumably (path, url)-like records - TODO confirm); only the parent
    # folder of that path is kept.
    return {tile: {str(Path(record[0]).parent) for record in records}
            for tile, records in downloaded.items()}

In [None]:
cloud_regex = r'\<CLOUDY_PIXEL_PERCENTAGE\>[0-9]*\.?[0-9]*</CLOUDY_PIXEL_PERCENTAGE>'

def get_min_clouds(loadings, max_ptc=5):
    """
    Pick, for every tile, the folder with the lowest cloud percentage.

    The percentage is read from the CLOUDY_PIXEL_PERCENTAGE tag of files whose
    name contains "MTD_TL.xml". Every other file in a folder registers that
    folder with a fallback value of 50, so folders without readable metadata
    can still be selected when nothing better exists.

        Parameters:
            loadings (dict): tile id -> iterable of folder paths
            max_ptc (float): Kept for backward compatibility; not used for
                filtering
        Returns:
            dict: tile id -> folder path with the smallest cloud percentage
        Raises:
            IndexError: when a tile has no candidate folder (callers treat any
                exception as "no clean raster for this period")
    """
    fallback_ptc = 50.0
    filtered = dict()

    for tile, folders in loadings.items():
        candidates = set()
        for folder in folders:
            for file in os.listdir(folder):

                if "MTD_TL.xml" not in file:
                    candidates.add((fallback_ptc, folder))
                    continue

                with open(os.path.join(folder, file)) as f:
                    match = re.search(cloud_regex, f.read())
                if match is None:
                    continue

                digits = ''.join(ch for ch in match.group(0) if ch.isdigit() or ch == '.')
                try:
                    # Compare numerically: as strings '10.5' sorts before '2.3'.
                    candidates.add((float(digits), folder))
                except ValueError:
                    # Empty or malformed tag content: ignore this metadata file.
                    continue

        if not candidates:
            raise IndexError(f'No candidate folders for tile {tile}')
        filtered[tile] = min(candidates)[1]

    return filtered

In [None]:
# Credit for baseline: work/notebooks/pw/raster_predict.ipynb
def create_style():
    """
    Return the display style attached to every saved polygon.

        Returns:
            str: `str()` of a dict with 'color', 'stroke' and 'stroke-width'
    """
    # NOTE(review): 'stroke' has no leading '#', unlike 'color' - confirm
    # whether the consumer expects it that way.
    keys = ('color', 'stroke', 'stroke-width')
    values = ('#C0C0C0', 'e80e27', 2)
    return str(dict(zip(keys, values)))

In [None]:
# NOTE(review): the CUDA auto-detection on the first line is immediately
# overridden, so inference always runs on CPU - confirm this is intentional.
config['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
config['device'] = 'cpu'

# Model weights live next to the project code; location depends on the layout.
if local:
    model_path = os.path.join(
        BASE, 'data/notebooks/pbdnn/sip_plot_boundary_detection_nn/models/HighResolutionNet32_none_15_nir_tci_multitemporal.pth')
else:
    model_path = os.path.join(
        BASE, 'notebooks/pbdnn/sip_plot_boundary_detection_nn/models/HighResolutionNet32_none_15_nir_tci_multitemporal.pth')

# channels=4 presumably corresponds to NIR + 3 TCI channels - TODO confirm.
model = load_model(config['model'].lower(), model_path, config['device'], channels=4)
    
print('Model loaded successfully')
    
# AOI file name without extension; used as prefix of polygon ids later on.
origin_name = os.path.basename(aoi_filename).replace(".geojson", "")
# Single-row placeholder frame with empty paths handed to BoundaryDetector.
# NOTE(review): presumably the dataset columns are not needed for raster
# prediction - verify against BoundaryDetector.
df = pd.DataFrame({
    'aoi_fp': [''],
    'b08_tile_path': [''],
    'filters_fp': [''],
    'geom_fp': [''],
    'tci_tile_path': ['']
})
detector = BoundaryDetector(model, df, tiles_dir=f'/home/{NB_USER}/work/satellite_imagery')
print('Detector is ready')

In [None]:
# Accumulator for the final polygons.
result_df = pd.DataFrame([])

if filter_path is None:
    # No region matched the AOI: downstream code skips filtering.
    filters = None
else:
    filters = gpd.read_file(filter_path)
    # Re-read the AOI in the CRS of the filter layer.
    aoi = gpd.read_file(aoi_filename).to_crs(filters.crs)

    if 'ukr_non_agriculture' in filter_path:
        # Zero-width buffer is applied to this layer's geometries.
        filters['geometry'] = filters.buffer(0)

    # Keep only the filter polygons that touch the AOI.
    aoi_shape = aoi.geometry.values[0]
    filters = filters[filters.intersects(aoi_shape)]
    if len(filters) > 1:
        # Merge the valid polygons into a single geometry, keeping the CRS.
        df_crs = filters.crs
        valid_filters = filters[filters.geometry.is_valid]
        filters = gpd.GeoDataFrame({'geometry': [valid_filters.unary_union]})
        filters.crs = df_crs

    print('filters are loaded')

In [None]:
# Display the filter layer selected for the AOI (None when no region matched).
filters

In [None]:
# Only the year of the requested period is used: imagery is always taken from
# eight fixed windows covering June and July of that year.
year = START_DATE.split('-')[0]

_windows = [('06-01', '06-07'), ('06-08', '06-14'), ('06-15', '06-21'), ('06-22', '06-30'),
            ('07-01', '07-07'), ('07-08', '07-14'), ('07-15', '07-21'), ('07-22', '07-30')]

# First and last day of every window as 'YYYY-MM-DD' strings.
starts = [f'{year}-{first_day}' for first_day, _ in _windows]
ends = [f'{year}-{last_day}' for _, last_day in _windows]

In [None]:
def stitch_tiles(paths, out_raster_path='test.tif'):
    """
    Merge several rasters into one GeoTIFF.

    The first raster defines the target CRS and metadata; every other raster
    is reprojected into a temporary "_tmp" file with `transform_crs` before
    merging. Where rasters overlap, the last one wins.

        Parameters:
            paths (list): Paths to the rasters to merge (at least one)
            out_raster_path (str): Base name for the output; ".jp2" / ".tif"
                is replaced with "_merged.tif"
        Returns:
            str: Path of the written merged raster
        Raises:
            ValueError: when `paths` is empty
    """
    paths = list(paths)
    if not paths:
        raise ValueError('stitch_tiles needs at least one raster path')

    tiles = []
    tmp_files = []

    try:
        for i, path in enumerate(paths):
            if i == 0:
                file = rasterio.open(path)
                meta, crs = file.meta, file.crs
            else:
                tmp_path = path.replace(
                    '.jp2', '_tmp.jp2').replace('.tif', '_tmp.tif')
                crs_transformed = transform_crs(path, tmp_path,
                                                dst_crs=crs,
                                                resolution=None)
                tmp_files.append(crs_transformed)
                file = rasterio.open(crs_transformed)
            tiles.append(file)

        tile_arr, transform = merge(tiles, method='last')
    finally:
        # Release datasets and temporary files even when opening,
        # reprojecting or merging fails.
        for tile in tiles:
            tile.close()

        for tmp_file in tmp_files:
            try:
                os.remove(tmp_file)
            except FileNotFoundError:
                print(f'Tile {tmp_file} was removed or renamed, skipping')

    meta.update({"driver": "GTiff",
                 "height": tile_arr.shape[1],
                 "width": tile_arr.shape[2],
                 "transform": transform,
                 "crs": crs})

    if '.jp2' in out_raster_path:
        out_raster_path = out_raster_path.replace('.jp2', '_merged.tif')
    else:
        out_raster_path = out_raster_path.replace('.tif', '_merged.tif')

    with rasterio.open(out_raster_path, "w", **meta) as dst:
        dst.write(tile_arr)
    # Report only after the file has actually been written.
    print(f'saved raster {out_raster_path}')

    return out_raster_path

In [None]:
def mean_img_pred(images):
    """
    Average the first band of several rasters into one "_mean.tif" raster.

        Parameters:
            images (list): Paths to rasters; the output path is derived from
                the first one by replacing ".tif" with "_mean.tif"
        Returns:
            str: Path of the written mean raster (uint8, metadata of the last
                input raster)
    """
    out_path = images[0].replace('.tif', '_mean.tif')

    bands = []
    for image_path in images:
        with rasterio.open(image_path) as src:
            bands.append(src.read(1))
            meta = src.meta

    # Stack along a new trailing axis and average over it.
    stacked = np.stack(bands, axis=-1)
    averaged = np.mean(stacked, axis=-1)

    with rasterio.open(out_path, 'w', **meta) as dst:
        dst.write(averaged.astype(np.uint8), 1)

    return out_path

### Running boundary detection and filtration of predictions

In [None]:
# Prediction rasters (one per processed window) and folders to clean up later.
ens_rasters = []
all_dirs = []

# The AOI -> tile mapping does not depend on the date window: compute it once.
date_tile_info = get_tiles(aoi_filename, sentinel_tiles_path)

for start_date, end_date in zip(starts, ends):
    tci_rasters, b08_rasters = [], []

    loadings = load_images(API_KEY, date_tile_info.tileID.values, start_date, end_date, LOAD_DIR, PRODUCT_TYPE)
    checked = check_nodata(loadings, PRODUCT_TYPE)

    try:
        # Raises when a tile has no usable folder for this window.
        checked = get_min_clouds(checked)
    except Exception:
        print(f'No clean raster found for period from {start_date} to {end_date}, skipping')
        continue

    # `tile` stays bound after the loop on purpose: a later cell reads
    # tile.tileID when building polygon ids.
    for _, tile in date_tile_info.iterrows():

        try:
            tile_folder = Path(checked[tile.tileID])
            print(f'filtered: {tile_folder}')

        except Exception as ex:
            print(ex)
            continue

        folder_files = os.listdir(tile_folder)
        full_tci_tile = [os.path.join(tile_folder, filename) for filename in folder_files if 'TCI_10m.jp2' in filename]
        full_b08_tile = [os.path.join(tile_folder, filename) for filename in folder_files if 'B08_10m.jp2' in filename]
        if not full_tci_tile or not full_b08_tile:
            # Incomplete download: skip the tile instead of failing on [0].
            print(f'WARNING: TCI or B08 raster is missing in {tile_folder}, skipping tile')
            continue
        tci_rasters.append(full_tci_tile[0])
        b08_rasters.append(full_b08_tile[0])

        all_dirs.append(tile_folder)

    print(f'rasters to be processed: {len(tci_rasters)}')

    if len(tci_rasters) == 0:
        print('WARNING: no rasters were found!')
        continue
    elif len(tci_rasters) == 1:
        raster_path = tci_rasters[0]
    else:
        # Several tiles: stitch each band into its own file inside a dedicated
        # folder. Previously both bands went to the same default output name
        # and `raster_path` was never set on this branch.
        # NOTE(review): assumes detector.raster_prediction finds the bands in
        # `raster_dir` by the band tag in the file name - confirm.
        merged_dir = os.path.join(LOAD_DIR, f'merged_{start_date}_{end_date}')
        os.makedirs(merged_dir, exist_ok=True)
        all_dirs.append(merged_dir)
        raster_path_tci = stitch_tiles(
            tci_rasters, os.path.join(merged_dir, os.path.basename(tci_rasters[0])))
        raster_path_b08 = stitch_tiles(
            b08_rasters, os.path.join(merged_dir, os.path.basename(b08_rasters[0])))
        raster_path = raster_path_tci

    if '.jp2' in raster_path:
        out_raster = raster_path.replace('.jp2', '_prediction.tif')
    else:
        out_raster = raster_path.replace('.tif', '_prediction.tif')

    raster_dir = os.path.dirname(raster_path)

    pred_tif_path = detector.raster_prediction(raster_dir=raster_dir,
                                               out_raster_path=out_raster,
                                               aoi_path=aoi_filename,
                                               conf_thresh=config['threshold'],
                                               bands=['NIR', 'TCI'])
    ens_rasters.append(pred_tif_path)
ens_rasters

In [None]:
# Output path for the vectorised predictions, derived from the AOI file name.
out_geom = aoi_filename.replace('_aoi.', '_prediction.')

# Fail with a clear message instead of an IndexError inside mean_img_pred when
# no window produced a prediction.
if not ens_rasters:
    raise RuntimeError('No prediction rasters were produced for any period; nothing to post-process')

# Average the per-window predictions into a single ensemble raster.
ens_raster = mean_img_pred(ens_rasters)

polygons = detector.process_raster_predictions(ens_raster,
                                               shapes_path=out_geom,
                                               aoi_path=aoi_filename,
                                               conf_thresh=config['threshold'],
                                               min_poly_area=10e3)
polygons = gpd.GeoDataFrame(polygons)

In [None]:
if filters is None:
    print('Non-agricultural data is not found for a given AOI, proceeding without filtering')
    polys = polygons
    target_crs = polys.crs
else:
    try:
        # Preferred path: filtering backed by a spatial index.
        polys = filter_sindex(polygons, filters)
    except Exception as err:
        # Fallback: apply a zero-width buffer to the filter geometries and use
        # the plain filter instead.
        print(err)
        filters['geometry'] = filters.buffer(0)
        polys = filter_polygons(polygons, filters)
    target_crs = filters.crs

In [None]:
# Flatten the polygons into a plain frame; reset_index keeps the previous
# index as an "index" column.
df = pd.DataFrame({"geometry": polys.geometry}).reset_index()
# NOTE(review): `tile` is the variable left over from the last iteration of the
# prediction loop, so every polygon is labelled with that single tile id even
# when several tiles were processed - confirm this is acceptable.
df["id"] = pd.Series(map(lambda x: f"{origin_name}_{tile.tileID}_{x}", df.index.values))
df["tileID"] = tile.tileID

In [None]:
# NOTE: re-running this cell appends `df` again, duplicating rows in result_df.
result_df = pd.concat([result_df, df])
# Declare the CRS of the polygons and reproject to the default output CRS.
gdf = process_polygons(result_df, target_crs)
gdf.head(3)

### Saving polygons into results folder and adding metadata

In [None]:
# The result is written under a temporary name and renamed at the very end, so
# consumers never see a half-written file.
tmp_suffix = ".temp"
gdf['style'] = create_style()
# Make sure the results folder exists even when there are no polygons to save
# (save_polygons returns early for an empty frame).
os.makedirs(RESULTS_DIR, exist_ok=True)
save_path = os.path.join(RESULTS_DIR, f"{REQUEST_ID}_{START_DATE}_{END_DATE}.geojson{tmp_suffix}")
save_polygons(gdf, save_path)

try:
    with open(save_path) as file:
        geoms = json.load(file)
except (OSError, ValueError) as err:
    # Missing file (no polygons) or unreadable JSON: fall back to a
    # metadata-only result.
    print(f'Could not read saved polygons ({err}), writing metadata only')
    geoms = {}

# Request metadata stored as top-level members of the GeoJSON document.
geoms['end_date'] = END_DATE
geoms['start_date'] = START_DATE
geoms['name'] = "Fields' boundaries"
geoms['request_id'] = REQUEST_ID

with open(save_path, 'w') as file:
    json.dump(geoms, file)
# Strip the temporary suffix; os.replace overwrites a result left by an
# earlier run of the same request.
os.replace(save_path, save_path[:-len(tmp_suffix)])

In [None]:
# Remove the temporary AOI GeoJSON created at the top of the notebook.
try:
    os.remove(aoi_filename)
except FileNotFoundError:
    print('No helping geojson files were generated')

In [None]:
# Remove the prediction raster path of the last processed window.
try:
    os.remove(out_raster)
except FileNotFoundError:
    print('No helping raster files were generated')
    
# Remove the averaged (ensemble) prediction raster.
try:
    os.remove(ens_raster)
except FileNotFoundError:
    print('No helping raster files were generated')

In [None]:
# Remove every per-window prediction raster; already missing files are only reported.
for prediction_raster in ens_rasters:
    try:
        os.remove(prediction_raster)
    except FileNotFoundError:
        print('No helping prediction raster found, skipping')

In [None]:
# Remove every folder used during the run together with its content.
# dict.fromkeys drops repeated entries while keeping order; rmtree handles
# nested folders of any depth, and a folder that is already gone is only
# reported instead of aborting the cleanup.
for folder in dict.fromkeys(all_dirs):
    try:
        shutil.rmtree(folder)
    except FileNotFoundError:
        print('No directories with rasters found')