### For a given AOI, prepare the latest Sentinel-2 EBI images

In [None]:
import os
import geopandas as gp
import numpy as np
import rasterio
import re
import pyproj
import json
import geojson
import shutil
import tempfile
import uuid

from geojson import Feature

import rasterio.mask
from rasterio.plot import reshape_as_raster
from rasterio.merge import merge
from rasterio.warp import calculate_default_transform, reproject, Resampling

from shapely import wkt
from shapely.ops import transform


from pathlib import Path
from datetime import datetime, timedelta
from sentinel2download.downloader import Sentinel2Downloader
from sentinel2download.overlap import Sentinel2Overlap

#### 0. Inputs

In [None]:
# Runtime configuration, taken from environment variables.
# AOI is a WKT polygon; the value below is the sample area used when the variable is unset.
AOI = os.environ.get(
    "AOI",
    "POLYGON ((-85.299088 40.339368, -85.332047 40.241477, -85.134979 40.229427, -85.157639 40.34146, -85.299088 40.339368))",
)
START_DATE = os.environ.get("START_DATE", "2020-07-01")
# In production the user chooses only a start date, so the backend sends
# end date == start date in the request to the model.
END_DATE = os.environ.get("END_DATE", "2020-08-01")
SENTINEL2_GOOGLE_API_KEY = os.environ.get("SENTINEL2_GOOGLE_API_KEY")
SATELLITE_CACHE_FOLDER = os.environ.get("SENTINEL2_CACHE")

# Folder that receives the generated outputs.
OUTPUT_FOLDER = os.environ.get("OUTPUT_FOLDER")

In [None]:
# Fail fast with a readable message: without this check an unset OUTPUT_FOLDER
# surfaces as an opaque TypeError from os.path.join(None, ...).
if not OUTPUT_FOLDER:
    raise ValueError("OUTPUT_FOLDER environment variable must be set to a writable folder")

# Locations of the generated artifacts inside the output folder.
OUTPUT_NODATA_FOLDER = os.path.join(OUTPUT_FOLDER, "nodata/")
OUTPUT_EBI_FILE = os.path.join(OUTPUT_FOLDER, "EBI.tif")
OUTPUT_TCI_FILE = os.path.join(OUTPUT_FOLDER, "TCI.tif")

# Create the parent folder first, then the sub-folder for "no data" markers.
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
os.makedirs(OUTPUT_NODATA_FOLDER, exist_ok=True)

#### 1. Transform AOI and get bound_box

In [None]:
# Parse the WKT area of interest into a one-row GeoDataFrame declared as EPSG:4326.
aoi = gp.GeoDataFrame(geometry=[wkt.loads(AOI)], crs="epsg:4326")    
# Persist the AOI as GeoJSON in the working directory; the file path is handed
# to Sentinel2Overlap further down.
aoi_filename = "provided_aoi.geojson"
aoi.to_file(aoi_filename, driver="GeoJSON") 

In [None]:
# Path of the NumPy colormap file that is loaded to color the EBI raster.
COLORMAP_BRBG = os.path.join('/code', 'colormap.npy') 

#### 2. Overlap AOI with sentinel2grid 

In [None]:
# Find the Sentinel-2 grid tiles that overlap the AOI file written above.
s2overlap = Sentinel2Overlap(aoi_path=aoi_filename)
# The result exposes a `Name` column that is used later as the list of tile ids.
overlap_tiles = s2overlap.overlap_with_geometry()

#### 3. Load images

In [None]:
# Alias of the cache folder. NOTE(review): not referenced by the code visible here.
LOAD_DIR = SATELLITE_CACHE_FOLDER

# Band identifiers requested from the downloader.
BANDS = {"TCI", "CLD"}
# A tile is rejected when at least this share (%) of AOI pixels equals the nodata value.
NODATA_PIXEL_PERCENTAGE = 10.0
# Value of the CLOUDY_PIXEL_PERCENTAGE search constraint passed to the downloader.
SEARCH_CLOUDY_PIXEL_PERCENTAGE = 80.0
# A tile is rejected when at least this share (%) of AOI pixels is classified as cloud.
AOI_CLOUDY_PIXEL_PERCENTAGE = 15.0
CONSTRAINTS = {'CLOUDY_PIXEL_PERCENTAGE': SEARCH_CLOUDY_PIXEL_PERCENTAGE}
# Product type passed to the downloader.
PRODUCT_TYPE = 'L2A'

# Names of the produced layers. NOTE(review): not referenced by the code visible here.
LAYERS = ['EBI', 'TCI']

In [None]:
def shift_date(date, delta=5, format='%Y-%m-%d'):
    """Return ``date`` moved ``delta`` days into the past.

    Args:
        date: Date string laid out according to ``format``.
        delta: Number of days to subtract (a negative value moves forward).
        format: ``strptime``/``strftime`` pattern used for parsing and output.

    Returns:
        The shifted date as a string in the same ``format``.
    """
    parsed = datetime.strptime(date, format)
    shifted = parsed - timedelta(days=delta)
    return shifted.strftime(format)

#### 3.1 Define max shift in dates - 30 days for loading images

In [None]:
# Number of days the search window is moved back when no images are found.
MAX_SHIFT = 30
# Maximum number of download attempts per tile (the initial window counts as one).
MAX_SHIFT_ITERS = 2

In [None]:
def dump_no_data_geojson(polygon, geojson_path):
    """Write a GeoJSON feature that marks ``polygon`` as having no data.

    The feature carries a red style, the label/name "No data" and the
    module-level ``START_DATE`` / ``END_DATE`` values.

    Args:
        polygon: Geometry to store in the feature.
        geojson_path: Destination file path; it is overwritten if it exists.
    """
    label = 'No data'
    properties = {'label': label, 'style': {'color': 'red'}}
    feature = Feature(geometry=polygon, properties=properties)
    # These keys are set on the feature itself, not inside its properties.
    feature['start_date'] = START_DATE
    feature['end_date'] = END_DATE
    feature['name'] = label

    with open(geojson_path, 'w') as out_file:
        geojson.dump(feature, out_file)

def check_nodata_percentage_crop(tile_path, 
                                 aoi, 
                                 nodata_percentage_limit, 
                                 nodata):
    """Tell whether too many AOI pixels of a band hold the nodata value.

    The first band of ``tile_path`` is cropped to the first AOI geometry
    (reprojected to the raster CRS) and only unmasked pixels are counted.

    Args:
        tile_path: Path of the raster to inspect.
        aoi: GeoDataFrame whose first geometry is the area of interest.
        nodata_percentage_limit: Rejection threshold in percent (inclusive).
        nodata: Pixel value that counts as "no data".

    Returns:
        True when the nodata share is >= the limit, or when no unmasked pixel
        lies inside the AOI at all; False otherwise.
    """
    with rasterio.open(tile_path) as src:
        polygon = aoi.to_crs(src.meta['crs']).geometry[0]
        band, _ = rasterio.mask.mask(src, [polygon], crop=True, filled=False, indexes=1)
    valid_pixels = band.compressed()
    if valid_pixels.size == 0:
        # Nothing usable inside the AOI: treat the tile as fully nodata instead
        # of failing with ZeroDivisionError below.
        return True
    nodata_count = np.count_nonzero(valid_pixels == nodata)
    nodata_percentage = round(nodata_count / valid_pixels.size * 100, 2)
    return bool(nodata_percentage >= nodata_percentage_limit)

def check_cloud_percentage_crop(tile_path, 
                                aoi, 
                                cloud_percentage_limit,
                                cloud_probability=50):
    """Tell whether too many AOI pixels of a cloud raster are cloudy.

    The first band of ``tile_path`` is cropped to the first AOI geometry
    (reprojected to the raster CRS); an unmasked pixel counts as cloud when
    its value is >= ``cloud_probability``.

    Args:
        tile_path: Path of the cloud raster to inspect.
        aoi: GeoDataFrame whose first geometry is the area of interest.
        cloud_percentage_limit: Rejection threshold in percent (inclusive).
        cloud_probability: Minimum pixel value treated as cloud.

    Returns:
        True when the cloudy share is >= the limit, or when no unmasked pixel
        lies inside the AOI at all; False otherwise.
    """
    with rasterio.open(tile_path) as src:
        polygon = aoi.to_crs(src.meta['crs']).geometry[0]
        band, _ = rasterio.mask.mask(src, [polygon], crop=True, filled=False, indexes=1)
    valid_pixels = band.compressed()
    if valid_pixels.size == 0:
        # Cloud cover cannot be assessed without any pixel; reject the tile
        # instead of failing with ZeroDivisionError below.
        return True
    cloud_count = np.count_nonzero(valid_pixels >= cloud_probability)
    cloud_percentage = round(cloud_count / valid_pixels.size * 100, 2)
    return bool(cloud_percentage >= cloud_percentage_limit)

def check_tile_validity(tile_folder, aoi, cloud_percentage_limit, nodata_percentage_limit):
    """Check every ``.jp2`` file of a tile folder against the AOI constraints.

    Files whose path contains "MSK_CLDPRB_20m" go through the cloud check,
    every other ``.jp2`` file goes through the nodata check (nodata value 0).
    The scan stops at the first failing file.

    Args:
        tile_folder: Folder holding the downloaded files of one product.
        aoi: GeoDataFrame with the area of interest.
        cloud_percentage_limit: Threshold handed to the cloud check.
        nodata_percentage_limit: Threshold handed to the nodata check.

    Returns:
        ``(skip_tile, band_paths)`` where ``band_paths`` lists every entry of
        the folder (not only ``.jp2`` files).
    """
    band_paths = [os.path.join(tile_folder, entry) for entry in os.listdir(tile_folder)]
    for candidate in band_paths:
        if Path(candidate).suffix != '.jp2':
            continue
        if "MSK_CLDPRB_20m" in candidate:
            rejected = check_cloud_percentage_crop(candidate, aoi, cloud_percentage_limit)
        else:
            rejected = check_nodata_percentage_crop(candidate, aoi, nodata_percentage_limit, 0)
        if rejected:
            return True, band_paths
    return False, band_paths

def validate_tile_downloads(loaded, tile, loadings, aoi, cloud_percentage_limit, nodata_percentage_limit):
    """Keep only downloaded product folders that satisfy the AOI constraints.

    Folders that fail validation are deleted from disk. When at least one
    folder passes, ``loadings[tile]`` is set to the files of the passing
    folders.

    Args:
        loaded: Download result; the first element of each item is a file path.
        tile: Tile identifier used as the key in ``loadings``.
        loadings: Mapping that is updated in place and returned.
        aoi: GeoDataFrame with the area of interest.
        cloud_percentage_limit: Threshold for the cloud check.
        nodata_percentage_limit: Threshold for the nodata check.

    Returns:
        The same ``loadings`` mapping.
    """
    print(f"Validating images for tile: {tile}...")
    if not loaded:
        print(f"Images for tile {tile} were not loaded!")
        return loadings

    # Several files share one product folder, so de-duplicate the parents.
    product_folders = {Path(item[0]).parent for item in loaded}
    accepted_paths = []
    for product_folder in product_folders:
        skip_tile, band_paths = check_tile_validity(
            product_folder, aoi, cloud_percentage_limit, nodata_percentage_limit
        )
        if skip_tile:
            shutil.rmtree(product_folder)
            continue
        accepted_paths.extend(band_paths)

    if accepted_paths:
        loadings[tile] = accepted_paths
    else:
        print(f"Tile images didn't match nodata/cloud constraints, so they were removed")
    print(f"Validating images for tile {tile} finished")
    return loadings

def load_images(tiles, start_date, end_date, aoi):
    """Download and validate Sentinel-2 images for every tile.

    For each tile the requested date window is tried first. While nothing is
    found, the window is moved ``MAX_SHIFT`` days further into the past, for at
    most ``MAX_SHIFT_ITERS`` attempts in total.

    Args:
        tiles: Iterable of tile identifiers.
        start_date: First day of the search window, ``%Y-%m-%d``.
        end_date: Last day of the search window, ``%Y-%m-%d``.
        aoi: GeoDataFrame with the area of interest.

    Returns:
        Mapping of tile id to the file paths that passed validation.

    Raises:
        ValueError: When no image could be downloaded for a tile; a "no data"
            GeoJSON is written to the output folder first.
    """
    loader = Sentinel2Downloader(SENTINEL2_GOOGLE_API_KEY)
    loadings = dict()

    for tile in tiles:
        print(f"Loading images for tile: {tile}...")
        start, end = start_date, end_date
        loaded = []
        for _ in range(MAX_SHIFT_ITERS):
            loaded = loader.download(PRODUCT_TYPE,
                                     [tile],
                                     start_date=start,
                                     end_date=end,
                                     output_dir=SATELLITE_CACHE_FOLDER,
                                     bands=BANDS,
                                     constraints=CONSTRAINTS)
            if loaded:
                break
            tried_start, tried_end = start, end
            # Move back from the window that was just tried (not from the
            # original start_date), so every retry covers new dates.
            end = start
            start = shift_date(start, delta=MAX_SHIFT)
            print(f"For tile: {tile} and dates {tried_start} {tried_end} proper images not found! Shift dates to {start} {end}!")
        print(f"Loading images for tile {tile} finished")
        if loaded:
            loadings = validate_tile_downloads(loaded, tile, loadings, aoi, AOI_CLOUDY_PIXEL_PERCENTAGE, NODATA_PIXEL_PERCENTAGE)
            # TO-DO:
            # Agree with the Product how we should proceed analysis in case when the date for chosen date is not valid for making prediction 
            # (e.g. apply date shifts and show valid results but for another date or just message a user that there is no good data available)
        else:
            geojson_path = os.path.join(OUTPUT_FOLDER, f"{START_DATE}_{END_DATE}_no_data.geojson")
            dump_no_data_geojson(aoi.geometry[0], geojson_path)
            raise ValueError("Images not loaded for given AOI. Change dates, constraints")
    return loadings

In [None]:
# Download and validate imagery for every tile returned by the overlap step.
loadings = load_images(overlap_tiles.Name.values, START_DATE, END_DATE, aoi)

In [None]:
# Notebook display of the download result for inspection.
loadings

#### 3.2 Filter loadings for every tile, get last image in daterange and bands

In [None]:
def filter_by_date(loadings):
    """Pick, per tile, the true-color image of the most recent acquisition.

    Acquisition dates are taken from the ``_<YYYYMMDD>T<digits>_`` part of each
    path. Paths without such a stamp are ignored when looking for the latest
    date, so one unrelated file no longer discards the whole tile.

    Args:
        loadings: Mapping of tile id to an iterable of file paths.

    Returns:
        A dict with key ``'TCI'`` (paths containing both the latest date of
        their tile and ``'TCI_10m.jp2'``) and, when at least one tile was
        processed, key ``'date'`` holding the latest date (``YYYYMMDD``) of the
        tile processed last.
    """
    date_pattern = re.compile(r"_(\d+)T\d+_")

    def _find_last_date(folders):
        dates = []
        for folder in folders:
            match = date_pattern.search(str(folder))
            if match is None:
                # Not a dated product path; nothing to learn from it.
                continue
            dates.append(datetime.strptime(match.group(1), '%Y%m%d'))
        if not dates:
            raise ValueError("no acquisition date found in any path")
        return max(dates).strftime('%Y%m%d')

    filtered = {
        'TCI': []
    }
    for tile, items in loadings.items():
        # Best effort per tile: a failing tile is reported and skipped.
        try:
            last_date = _find_last_date(items)
            filtered['TCI'] += [
                path for path in items
                if last_date in str(path) and 'TCI_10m.jp2' in str(path)
            ]
            filtered['date'] = last_date
        except Exception as ex:
            print(f"Error for {tile}: {str(ex)}")
    return filtered

In [None]:
# Keep only the latest true-color image per tile, then display the result.
filtered = filter_by_date(loadings)
filtered

#### 4. Calculate EBI

#### 4.1 Prepare color coding for EBI

In [None]:
def prepare_colors(colors):
    """Load a colormap table and reserve its first entry for "no data".

    Args:
        colors: Path of a ``.npy`` file holding an ``(N, 3)`` or ``(N, 4)``
            color table. The argument is now honored; previously the global
            ``COLORMAP_BRBG`` was loaded regardless of what was passed.

    Returns:
        The color table without a fourth channel, with every zero component
        replaced by 1 and row 0 set to ``[0, 0, 0]``.
    """
    colors = np.load(colors)
    if colors.shape[1] == 4:
        # delete last channel, we use rgb
        colors = np.delete(colors, 3, axis=1)
    # colormap colors values in range [0-255], but in our case 0 - no data, -> have to color as [0, 0, 0]
    colors[colors == 0] = 1
    colors[0] = [0, 0, 0]
    return colors

In [None]:
# Color lookup table used to color the scaled EBI values.
COLORS = prepare_colors(COLORMAP_BRBG)

In [None]:
# Notebook display of the lookup table dimensions.
COLORS.shape

In [None]:
# Build the colormap description that is stored as a raster tag.
colormap_tag = {
    "name": "Enhanced Blooming index",
    "colors": [],
    "labels": ["low", "high"],
}

# Each color is serialized as a comma-separated string of integer components.
for color in COLORS:
    color_str = ",".join(str(int(channel)) for channel in color)
    colormap_tag["colors"].append(color_str)

# The tag is stored as a JSON string, e.g.
# {"name": "Vegetation index", "colors": ["0,0,0", "255,0,0", "0,255,0", "0,0,255" ...], "labels": ["low", "high"]}
colormap_tag = json.dumps(colormap_tag)


In [None]:
def scale(val, a=1, b=255, nodata=0.0):
    """Map values from [0, 1] linearly onto the integer range [a, b].

    Values outside [0, 1] (including +/-inf) are clamped to the nearest end of
    the range; before this fix they wrapped around or were undefined in the
    ``uint8`` cast. NaN becomes ``nodata``.

    Args:
        val: Array-like of floats.
        a: Output value for input 0.
        b: Output value for input 1.
        nodata: Output value for NaN input.

    Returns:
        ``uint8`` array with the same shape as ``val``.
    """
    min_val = 0
    max_val = 1
    values = np.asarray(val)
    scaled = (b - a) * (values - min_val) / (max_val - min_val) + a
    # Clamp before the integer cast; np.clip leaves NaN untouched.
    scaled = np.clip(np.around(scaled), a, b)
    scaled = np.where(np.isnan(scaled), nodata, scaled)
    return scaled.astype(np.uint8)

In [None]:
def color_ebi(scaled, colors):
    """Turn an index raster into a 3-channel raster via a color lookup table.

    Args:
        scaled: Integer array whose values index rows of ``colors``.
        colors: Lookup table with three components per row.

    Returns:
        The colored image after ``reshape_as_raster`` (channels moved to the
        front).
    """
    looked_up = np.take(colors, scaled.ravel(), axis=0)
    image = looked_up.reshape(scaled.shape + (3,))
    return reshape_as_raster(image)

In [None]:
def to_crs(poly, target, current='EPSG:4326'):
    """Reproject a shapely geometry from ``current`` to ``target`` CRS.

    Args:
        poly: Geometry to reproject.
        target: Destination CRS, in any form accepted by ``pyproj.CRS``.
        current: Source CRS of ``poly``.

    Returns:
        The reprojected geometry.
    """
    source_crs = pyproj.CRS(current)
    target_crs = pyproj.CRS(target)
    # always_xy keeps the (x, y) / (lon, lat) axis order regardless of the CRS definition.
    transformer = pyproj.Transformer.from_crs(source_crs, target_crs, always_xy=True)
    return transform(transformer.transform, poly)

In [None]:
def crop(input_path, output_path, polygon, date, name=None, colormap=None):
    """Crop a raster to ``polygon`` and save it as a tagged GeoTIFF.

    Args:
        input_path: Raster to crop.
        output_path: Destination GeoTIFF path.
        polygon: Geometry in the raster's CRS.
        date: Value written to both the ``start_date`` and ``end_date`` tags.
        name: Optional value for the ``name`` tag.
        colormap: Optional value for the ``colormap`` tag.
    """
    with rasterio.open(input_path) as src:
        out_image, out_transform = rasterio.mask.mask(src, [polygon], crop=True)
        out_meta = src.meta

    # Describe the cropped window; nodata is fixed to 0 for the output.
    out_meta.update(
        driver='GTiff',
        height=out_image.shape[1],
        width=out_image.shape[2],
        transform=out_transform,
        nodata=0,
    )

    tags = {'start_date': date, 'end_date': date}
    if name:
        tags['name'] = name
    if colormap:
        tags['colormap'] = colormap

    with rasterio.open(output_path, "w", **out_meta) as dest:
        dest.update_tags(**tags)
        dest.write(out_image)

In [None]:
def transform_crs(data_path, save_path, dst_crs="EPSG:4326", resolution=(10, 10)):
    """Reproject every band of a raster into ``dst_crs`` and save the result.

    Args:
        data_path: Source raster path.
        save_path: Destination raster path.
        dst_crs: Target coordinate reference system.
        resolution: Target resolution handed to
            ``calculate_default_transform``; ``None`` lets rasterio choose.
            NOTE(review): the default (10, 10) looks meant for a metric CRS
            rather than the default EPSG:4326 - confirm.

    Returns:
        ``save_path``.
    """
    with rasterio.open(data_path) as src:
        # Only forward the resolution when one was requested.
        resolution_kwargs = {} if resolution is None else {"resolution": resolution}
        dst_transform, dst_width, dst_height = calculate_default_transform(
            src.crs, dst_crs, src.width, src.height, *src.bounds, **resolution_kwargs
        )

        profile = src.meta.copy()
        profile.update(crs=dst_crs, transform=dst_transform, width=dst_width, height=dst_height)

        with rasterio.open(save_path, "w", **profile) as dst:
            for band_index in range(1, src.count + 1):
                reproject(
                    source=rasterio.band(src, band_index),
                    destination=rasterio.band(dst, band_index),
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=dst_transform,
                    dst_crs=dst_crs,
                    resampling=Resampling.nearest,
                )

    return save_path

In [None]:
def stitch_tiles(paths, out_raster_path, date, name=None, colormap=None):
    """Merge several rasters into one tagged GeoTIFF.

    The first raster defines the target CRS and base metadata. Every other
    raster is reprojected into that CRS through a temporary file next to it.
    Rasters are merged with ``method='last'``.

    Opened datasets are now closed and temporary files removed even when
    reprojection or merging fails.

    Args:
        paths: Raster paths (strings or path-like objects); must not be empty.
        out_raster_path: Destination path; a ``.jp2`` extension is replaced
            by ``.tif``.
        date: Value written to both the ``start_date`` and ``end_date`` tags.
        name: Optional value for the ``name`` tag.
        colormap: Optional value for the ``colormap`` tag.

    Returns:
        The path the merged raster was written to.

    Raises:
        ValueError: When ``paths`` is empty.
    """
    paths = [str(x) for x in paths]
    if not paths:
        raise ValueError("stitch_tiles needs at least one raster path")

    tiles = []
    tmp_files = []
    crs = None
    meta = None
    try:
        for i, path in enumerate(paths):
            if i == 0:
                file = rasterio.open(path)
                meta, crs = file.meta, file.crs
            else:
                tmp_path = path.replace(
                    '.jp2', '_tmp.jp2').replace('.tif', '_tmp.tif')
                if tmp_path == path:
                    # Neither extension matched: never reuse (and later delete)
                    # the source file itself.
                    tmp_path = f"{path}_tmp"
                # Register before writing so a partial file is cleaned up too.
                tmp_files.append(tmp_path)
                transform_crs(path, tmp_path, dst_crs=crs, resolution=None)
                file = rasterio.open(tmp_path)
            tiles.append(file)

        tile_arr, transform = merge(tiles, method='last')
    finally:
        for tile in tiles:
            tile.close()

        for tmp_file in tmp_files:
            try:
                os.remove(tmp_file)
            except FileNotFoundError:
                print(f'Tile {tmp_file} was removed or renamed, skipping')

    meta.update({"driver": "GTiff",
                 "height": tile_arr.shape[1],
                 "width": tile_arr.shape[2],
                 "transform": transform,
                 "crs": crs})

    if '.jp2' in out_raster_path:
        out_raster_path = out_raster_path.replace('.jp2', '.tif')

    with rasterio.open(out_raster_path, "w", **meta) as dst:
        dst.update_tags(start_date=date, end_date=date)
        if name:
            dst.update_tags(name=name)
        if colormap:
            dst.update_tags(colormap=colormap)
        dst.write(tile_arr)
    # Report only after the file has actually been written.
    print(f'saved raster {out_raster_path}')

    return out_raster_path

In [None]:
def calculate_EBI(tci_path, save_path, eps=256):
    """Compute the colored EBI raster from a true-color (RGB) image.

    Bands 1, 2 and 3 are read as red, green and blue and combined as
    ``(R + G + B) / ((G / B) * (R - B + eps))``. The result is scaled with
    ``scale``, colored with ``color_ebi`` / ``COLORS`` and written as a
    3-band ``uint8`` GeoTIFF with nodata 0.

    Args:
        tci_path: Path of the input RGB raster.
        save_path: Path of the GeoTIFF to create.
        eps: Offset added to ``R - B`` in the denominator.
    """
    with rasterio.open(tci_path) as src:
        red = src.read(1).astype(rasterio.float32)
        green = src.read(2).astype(rasterio.float32)
        blue = src.read(3).astype(rasterio.float32)
        # Take the metadata while the dataset is still open.
        out_meta = src.meta.copy()

    # Division by zero is expected for empty pixels; silence it locally
    # instead of changing NumPy's global error state.
    with np.errstate(divide='ignore', invalid='ignore'):
        EBI = (red + green + blue) / ((green / blue) * (red - blue + eps))
        EBI = scale(EBI)

    colored = color_ebi(EBI, COLORS)
    out_meta.update(dtype=rasterio.uint8,
                    driver='GTiff',
                    nodata=0,
                    count=3, )

    # Create the file
    with rasterio.open(save_path, 'w', **out_meta) as dst:
        dst.write(colored)

#### 4.2 Calculate and crop EBI

#### File names follow the pattern: TILE_ID_ACQUIREDDATE

In [None]:
# Nothing to process: leave a "no data" marker for the AOI and stop.
if not filtered['TCI']:
    geojson_path = os.path.join(OUTPUT_NODATA_FOLDER, "aoi.geojson")
    dump_no_data_geojson(aoi.geometry[0], geojson_path)
    raise ValueError("Images not loaded for given AOI. Change dates, constraints")

with tempfile.TemporaryDirectory() as tmp_dir:
    images_tci = []
    images_ebi = []
    polygon = wkt.loads(AOI)
    # Set together with the TCI paths, so it exists whenever this point is reached.
    acquired_date = filtered['date']
    for tci_path in filtered['TCI']:
        # Best effort per image: a failing image is reported and skipped.
        try:
            print("Start calculation EBI")

            # The AOI must be expressed in the CRS of the image before cropping.
            with rasterio.open(tci_path) as src:
                tile_crs = str(src.crs)
            transformed_poly = to_crs(polygon, tile_crs)
            temp_cropped_tci = os.path.join(tmp_dir, f"{uuid.uuid4()}.tif")
            temp_cropped_ebi = os.path.join(tmp_dir, f"{uuid.uuid4()}.tif")

            # The cropped true-color image gets the RGB layer name (it used to
            # be tagged as EBI, which leaked into TCI.tif for a single tile).
            crop(tci_path, temp_cropped_tci, transformed_poly, acquired_date, name="Sentinel-2 RGB Raster")
            calculate_EBI(temp_cropped_tci, temp_cropped_ebi)
            # calculate_EBI writes no tags, so tag the EBI raster here; this
            # keeps the single-tile output consistent with the stitched one.
            with rasterio.open(temp_cropped_ebi, "r+") as ebi_dst:
                ebi_dst.update_tags(start_date=acquired_date,
                                    end_date=acquired_date,
                                    name="Sentinel-2 EBI",
                                    colormap=colormap_tag)
            print("End calculation EBI")

            images_tci.append(temp_cropped_tci)
            images_ebi.append(temp_cropped_ebi)
        except Exception as e:
            print(f"Cannot calculate EBI: {str(e)}")

    if not images_ebi:
        # Every image failed: mirror the "no images" case instead of crashing
        # with IndexError on images_ebi[0].
        geojson_path = os.path.join(OUTPUT_NODATA_FOLDER, "aoi.geojson")
        dump_no_data_geojson(aoi.geometry[0], geojson_path)
        raise ValueError("EBI could not be calculated for any image of the given AOI")

    if len(images_ebi) > 1:
        stitch_tiles(images_ebi, OUTPUT_EBI_FILE, acquired_date, name="Sentinel-2 EBI", colormap=colormap_tag)
        stitch_tiles(images_tci, OUTPUT_TCI_FILE, acquired_date, name="Sentinel-2 RGB Raster")
    else:
        shutil.copy(images_ebi[0], OUTPUT_EBI_FILE)
        shutil.copy(images_tci[0], OUTPUT_TCI_FILE)