### For a given AOI, prepare the latest Sentinel-2 TCI and NDVI images

In [None]:
import os
import geopandas as gp
import numpy as np
import rasterio
import re
import tempfile
import pyproj
import uuid
import json
import geojson

from geojson import Feature

import rasterio.mask
from rasterio import Affine
from rasterio.plot import reshape_as_raster
from rasterio.merge import merge
from rasterio.warp import calculate_default_transform, reproject, Resampling

from shapely import wkt
from shapely.geometry import Polygon, box
from shapely.ops import transform


from pathlib import Path
from datetime import datetime, timedelta
from sentinel2download.downloader import Sentinel2Downloader
from sentinel2download.overlap import Sentinel2Overlap

### 0. Setting up parameters
#### Read Input from environment and setup output folders and filenames

In [None]:
# Inputs (read from the environment; os.getenv returns None for unset variables)
AOI = os.getenv('AOI')
START_DATE = os.getenv("START_DATE")
END_DATE = os.getenv("END_DATE")
SENTINEL2_GOOGLE_API_KEY = os.getenv('SENTINEL2_GOOGLE_API_KEY')
SATELLITE_CACHE_FOLDER = os.getenv('SENTINEL2_CACHE')

# Output folder
OUTPUT_FOLDER = os.getenv('OUTPUT_FOLDER')


# The sub-folder name must be relative: os.path.join drops every preceding
# component when it meets an absolute one, so "/nodata" would resolve to the
# filesystem root instead of a folder inside OUTPUT_FOLDER.
OUTPUT_NODATA_FOLDER = os.path.join(OUTPUT_FOLDER, "nodata")
OUTPUT_TCI_FILE = os.path.join(OUTPUT_FOLDER, "TCI.tif")
OUTPUT_NDVI_FILE = os.path.join(OUTPUT_FOLDER, "NDVI.tif")

# Also creates OUTPUT_FOLDER itself if it does not exist yet.
os.makedirs(OUTPUT_NODATA_FOLDER, exist_ok=True)

#### 1. Transform AOI and get bound_box

In [None]:
# Parse the AOI WKT string and wrap it in a single-row GeoDataFrame declared as EPSG:4326 (lon/lat).
aoi = gp.GeoDataFrame(geometry=[wkt.loads(AOI)], crs="epsg:4326")    

In [None]:
# Load the tile grid from /code/sentinel2grid.geojson (presumably the Sentinel-2 tiling grid;
# the code below expects a "Name" column holding the tile id).
sentinel_grid = gp.read_file(os.path.join('/code', "sentinel2grid.geojson"))

#### 2. Overlap AOI with sentinel2grid 

In [None]:
def epsg_code(longitude, latitude):
    """
    Generates the EPSG code of the WGS 84 / UTM zone that contains a lon/lat point.

    :param longitude: float, degrees in [-180, 180]
    :param latitude: float, degrees
    :return: int, EPSG code (326xx when latitude > 0, 327xx otherwise)
    """

    def _zone_number(lat, lon):
        # Standard UTM grid exception: the widened zone 32 (south-west Norway).
        if 56 <= lat < 64 and 3 <= lon < 12:
            return 32
        # Standard UTM grid exception: the widened odd zones 31-37 (Svalbard).
        if 72 <= lat <= 84 and lon >= 0:
            if lon < 9:
                return 31
            elif lon < 21:
                return 33
            elif lon < 33:
                return 35
            elif lon < 42:
                return 37

        # Regular 6-degree zones. The modulo wraps longitude 180 back to
        # zone 1 instead of producing the non-existent zone 61.
        return int((lon + 180) / 6) % 60 + 1

    zone = _zone_number(latitude, longitude)

    if latitude > 0:
        return 32600 + zone
    else:
        return 32700 + zone

In [None]:
def _intersect(aoi, grid):
    """
    Select the grid tiles that intersect the AOI and project both to the AOI's UTM zone.

    NOTE: `aoi` is reprojected IN PLACE (callers rely on this to get the AOI in
    metres); `grid` is not modified, a filtered and reprojected copy is returned.

    :param aoi: GeoDataFrame whose first row holds the AOI geometry in lon/lat
    :param grid: GeoDataFrame with the tile grid, in the same CRS as `aoi`
    :return: tuple (tiles intersecting the AOI in the UTM CRS, int EPSG code of that CRS)
    """
    geometry = aoi.geometry.iloc[0]

    # Coarse filter: tiles whose bounding boxes overlap the AOI bounding box.
    # The spatial index returns POSITIONAL indices, hence iloc rather than loc.
    candidate_positions = list(grid.sindex.intersection(geometry.bounds))
    grid = grid.iloc[candidate_positions]

    # Precise filter: keep only the tiles that really intersect the AOI.
    grid = grid.loc[grid.intersects(geometry)]

    # UTM zone of the AOI centroid, so that areas can be measured in metres.
    epsg = epsg_code(geometry.centroid.x, geometry.centroid.y)

    # to UTM projection in meters
    aoi.to_crs(epsg=epsg, inplace=True)
    # Reproject into a new frame instead of mutating the filtered slice in place.
    grid = grid.to_crs(epsg=epsg)

    return grid, epsg

In [None]:
def get_intersected_tiles(aoi, grid):
    """
    Greedily cover the AOI with grid tiles, largest overlap first.

    On each step the tile sharing the biggest area with the still-uncovered
    part of the AOI is chosen, and the whole tile footprint is subtracted from
    the remainder. Stops when nothing is left to cover.

    NOTE: `aoi` is reprojected in place by `_intersect`; pass a copy to keep the original.

    :param aoi: GeoDataFrame with the AOI geometry in lon/lat
    :param grid: GeoDataFrame with the tile grid; must have a "Name" column
    :return: GeoDataFrame in EPSG:4326 with columns "tile" (tile name) and
             "geometry" (part of the AOI assigned to that tile)
    """
    # Use the grid that was passed in rather than a module-level global.
    grid, epsg = _intersect(aoi, grid)

    grid.set_index("Name", drop=False, inplace=True)
    intersected_grid = {"tile": [], "geometry": []}

    rest_aoi = aoi.copy()
    while rest_aoi.area.sum() > 0:
        intersection = gp.overlay(rest_aoi, grid, how="intersection")
        if intersection.empty:
            # The remainder is not covered by any candidate tile; argmax on an
            # empty series would raise, so stop with what has been collected.
            print("Remaining part of AOI is not covered by any tile, stopping")
            break

        # argmax is positional, so pick the row by position.
        best = intersection.iloc[intersection.area.argmax()]
        tile = best["Name"]

        intersected_grid["tile"].append(tile)
        intersected_grid["geometry"].append(best["geometry"])

        # Remove the whole footprint of the chosen tile from the remainder.
        biggest_intersection = grid.loc[[tile]]
        rest_aoi = gp.overlay(rest_aoi, biggest_intersection, how="difference")
        # Keep only the tiles that intersected the remainder on this step.
        grid = grid.loc[intersection["Name"]]

    overlap_tiles = gp.GeoDataFrame(intersected_grid, crs=epsg)
    overlap_tiles.to_crs(epsg=4326, inplace=True)

    return overlap_tiles

In [None]:
# Copies are passed because the helper functions reproject the AOI in place.
overlap_tiles = get_intersected_tiles(aoi.copy(), sentinel_grid.copy())
# Display the resulting tile / geometry table.
overlap_tiles

In [None]:
# Display the CRS of the result for inspection.
overlap_tiles.crs

#### 3. Load images

In [None]:
# Folder the downloader writes images to (value of the SENTINEL2_CACHE environment variable).
LOAD_DIR = SATELLITE_CACHE_FOLDER

# Product type and bands requested from the downloader.
PRODUCT_TYPE = 'L2A'
BANDS = {'TCI', 'B04', 'B08', }
# Passed to the downloader as `constraints`; presumably upper limits in percent — TODO confirm
# against the sentinel2download documentation.
CONSTRAINTS = {'NODATA_PIXEL_PERCENTAGE': 10.0, 'CLOUDY_PIXEL_PERCENTAGE': 5.0, }

# Names of the produced layers (not referenced by the code visible in this notebook).
LAYERS = ['TCI', 'NDVI', ]

In [None]:
# Display the requested start date for inspection.
START_DATE

In [None]:
# Display the requested end date for inspection.
END_DATE

In [None]:
def shift_date(date, delta=5, format='%Y-%m-%d'):
    """
    Move a date string `delta` days into the past.

    :param date: str, date formatted according to `format`
    :param delta: int, number of days to subtract
    :param format: str, strptime/strftime pattern used for parsing and output
    :return: str, the shifted date in the same format
    """
    parsed = datetime.strptime(date, format)
    return (parsed - timedelta(days=delta)).strftime(format)

#### 3.1 Define max shift in dates - 30 days for loading images

In [None]:
# Number of days the search window is moved into the past when no images are found (see load_images).
MAX_SHIFT = 30

In [None]:
# Maximum number of download attempts per tile in load_images.
MAX_SHIFT_ITERS = 2

In [None]:
def load_images(tiles, start_date, end_date):
    """
    Download images for every tile, moving the date window back when nothing is found.

    For each tile up to MAX_SHIFT_ITERS download attempts are made. The first
    uses [start_date, end_date]; every following attempt uses the MAX_SHIFT
    days immediately preceding the previous window.

    :param tiles: iterable of tile names
    :param start_date: str, 'YYYY-MM-DD' start of the initial window
    :param end_date: str, 'YYYY-MM-DD' end of the initial window
    :return: dict tile -> value returned by the downloader for that tile;
             tiles for which nothing was loaded are absent
    """
    loader = Sentinel2Downloader(SENTINEL2_GOOGLE_API_KEY)
    loadings = dict()

    for tile in tiles:
        start, end = start_date, end_date
        # Stays None when MAX_SHIFT_ITERS is 0 and no attempt is made.
        loaded = None

        print(f"Loading images for tile: {tile}...")
        for _ in range(MAX_SHIFT_ITERS):
            loaded = loader.download(PRODUCT_TYPE,
                                     [tile],
                                     start_date=start,
                                     end_date=end,
                                     output_dir=LOAD_DIR,
                                     bands=BANDS,
                                     constraints=CONSTRAINTS)
            if loaded:
                break

            # Shift relative to the window that just failed, not the original
            # one, so that repeated retries keep moving further back in time.
            failed_start, failed_end = start, end
            end = start
            start = shift_date(start, delta=MAX_SHIFT)
            print(f"For tile: {tile} and dates {failed_start} {failed_end} proper images not found! Shift dates to {start} {end}!")

        if loaded:
            loadings[tile] = loaded
            print(f"Loading images for tile {tile} finished")
        else:
            print(f"Images for tile {tile} were not loaded!")

    return loadings

In [None]:
# Download images for every tile that covers part of the AOI.
loadings = load_images(overlap_tiles.tile.values, START_DATE, END_DATE)

#### 3.2 Filter loadings for every tile: keep the band files of the latest image in the date range

In [None]:
def filter_by_date(loadings):
    """
    Keep, for every tile, only the band files of its most recent acquisition.

    The acquisition date is taken from the `_<digits>T<digits>_` fragment of
    each loaded entry. Tiles whose entries cannot be processed are reported
    with a printed message and left out of the result.

    :param loadings: dict tile -> iterable of (path, <anything>) pairs
    :return: dict tile -> {"paths": {"RED"|"NIR"|"TCI": path}, "date": "YYYYMMDD"}
    """
    # File-name fragment -> key under which the path is stored.
    band_suffixes = (
        ('B04_10m.jp2', 'RED'),
        ('B08_10m.jp2', 'NIR'),
        ('TCI_10m.jp2', 'TCI'),
    )

    def _find_last_date(folders):
        # Parse the date of every entry and return the newest one as YYYYMMDD.
        parsed = [
            datetime.strptime(re.search(r"_(\d+)T\d+_", str(entry)).group(1), '%Y%m%d')
            for entry in folders
        ]
        return max(parsed).strftime('%Y%m%d')

    filtered = {}
    for tile, items in loadings.items():
        try:
            newest = _find_last_date(items)
            bands_paths = {}
            for path, _ in items:
                if newest not in path:
                    continue
                for suffix, band in band_suffixes:
                    if suffix in path:
                        bands_paths[band] = path
            filtered[tile] = {"paths": bands_paths, "date": newest}
        except Exception as ex:
            print(f"Error for {tile}: {str(ex)}")
    return filtered

In [None]:
# Per tile: band paths and date of the latest loaded image.
filtered = filter_by_date(loadings)

#### 4. Calculate NDVI

In [None]:
# Scratch folder for intermediate per-tile rasters (absolute path at the filesystem root).
TEMP_DIR = "/temp"
os.makedirs(TEMP_DIR, exist_ok=True)

# NOTEBOOK_DIR = os.path.join(BASE, "notebooks/example/tci_ndvi")
# Path of the NDVI colormap, stored as a NumPy .npy file.
COLORMAP_BRBG = os.path.join("/code", "ndvi_colormap.npy") 

#### 4.1 Prepare color coding for NDVI

In [None]:
def prepare_colors(colors):
    """
    Load a colormap and reserve the value 0 / first color for "no data".

    :param colors: path (or open file object) of a .npy file holding an
                   (N, 3) RGB or (N, 4) RGBA color table
    :return: numpy array (N, 3): row 0 is [0, 0, 0] (no data) and no other
             row contains a 0 channel value
    """
    # Load from the argument instead of silently ignoring it.
    colors = np.load(colors)
    if colors.shape[1] == 4:
        # delete last channel, we use rgb
        colors = np.delete(colors, 3, axis=1)
    # 0 is reserved for "no data", so real colors get their zero channels
    # bumped to 1 and the first entry becomes the no-data color [0, 0, 0].
    colors[colors == 0] = 1
    colors[0] = [0, 0, 0]
    return colors

In [None]:
# RGB lookup table used to color the scaled NDVI values.
COLORS = prepare_colors(COLORMAP_BRBG)

In [None]:
# Display the lookup table dimensions for inspection.
COLORS.shape

In [None]:
# Legend description stored as a raster tag: every color of the lookup table
# is serialized as an "r,g,b" string and the whole structure is dumped to JSON.
colormap_tag = json.dumps({
    "name": "Vegetation index",
    "colors": [",".join(str(int(channel)) for channel in rgb) for rgb in COLORS],
    "labels": ["low", "high"],
})
# example of colormap_tag format
# {"name": "Vegetation index", "colors": ["0,0,0", "255,0,0", "0,255,0", "0,0,255" ...], "labels": ["low", "high"]}
colormap_tag

In [None]:
def scale(ndvi, a=1, b=255, nodata=0.0):
    """
    Linearly map NDVI values from [-1, 1] to [a, b] and store them as uint8.

    NaN pixels (no data) are written as `nodata`; keep `nodata` outside of
    [a, b] so that it cannot be confused with a real value.

    :param ndvi: numpy float array with NDVI values, NaN where undefined
    :param a: lower bound of the output range
    :param b: upper bound of the output range
    :param nodata: value written for NaN pixels
    :return: numpy uint8 array of the same shape
    """
    # Fixed theoretical NDVI bounds (not the data min/max) keep the mapping
    # identical for every image.
    lower, upper = -1, 1
    stretched = np.around((b - a) * (ndvi - lower) / (upper - lower) + a)
    stretched[np.isnan(stretched)] = nodata
    return stretched.astype(np.uint8)

In [None]:
def color_ndvi(scaled, colors):
    """
    Turn scaled NDVI values into an RGB raster through a lookup table.

    :param scaled: integer numpy array; each value is an index into `colors`
    :param colors: numpy array (N, 3) with one RGB color per index
    :return: numpy array reordered by `reshape_as_raster`
    """
    # Look up the color of every pixel, then restore the image layout with a
    # trailing channel axis.
    looked_up = colors[scaled.ravel()]
    image = looked_up.reshape(*scaled.shape, 3)
    return reshape_as_raster(image)

In [None]:
def NDVI(nir_path, red_path, save_path):
    """
    Compute NDVI from the NIR and RED bands and save it as a colored 3-band GeoTIFF.

    NDVI = (NIR - RED) / (NIR + RED). The result is scaled to uint8 with
    `scale` and colored with the module-level COLORS lookup table.

    :param nir_path: path of the NIR band raster
    :param red_path: path of the RED band raster
    :param save_path: path of the GeoTIFF to create
    :return: str, CRS of the NIR raster
    """
    with rasterio.open(nir_path) as src:
        nir = src.read(1).astype(rasterio.float32)
        crs = str(src.crs)
    with rasterio.open(red_path) as src:
        red = src.read(1).astype(rasterio.float32)
        # Take the spatial metadata while the dataset is still open.
        out_meta = src.meta.copy()

    # Pixels where NIR + RED == 0 divide by zero and become NaN, which `scale`
    # turns into no-data. Silence the warnings only for this expression
    # instead of changing numpy's global error state.
    with np.errstate(divide='ignore', invalid='ignore'):
        ndvi = (nir - red) / (nir + red)

    scaled = scale(ndvi)
    colored = color_ndvi(scaled, COLORS)

    # Set spatial characteristics of the output object
    out_meta.update(dtype=rasterio.uint8,
                    driver='GTiff',
                    nodata=0,
                    count=3, )

    # Create the file
    with rasterio.open(save_path, 'w', **out_meta) as dst:
        dst.write(colored)
    return crs

In [None]:
def to_crs(poly, target, current='EPSG:4326'):
    """
    Reproject a shapely geometry from one CRS to another.

    :param poly: shapely geometry expressed in the `current` CRS
    :param target: CRS to project to (anything accepted by pyproj.CRS)
    :param current: CRS the geometry is expressed in
    :return: the reprojected geometry
    """
    # always_xy=True makes pyproj use x/y (lon/lat) axis order for both CRSs.
    transformer = pyproj.Transformer.from_crs(
        pyproj.CRS(current),
        pyproj.CRS(target),
        always_xy=True,
    )
    return transform(transformer.transform, poly)

In [None]:
def crop(input_path, output_path, polygon, date, name=None, colormap=None):
    """
    Cut a raster to a polygon and save the result as a tagged GeoTIFF.

    The input is read completely before the output is opened, so
    `output_path` may be the same file as `input_path`.

    :param input_path: path of the raster to crop
    :param output_path: path of the GeoTIFF to write
    :param polygon: geometry in the CRS of the raster; pixels outside are masked
    :param date: value stored in both the `start_date` and `end_date` tags
    :param name: optional value for the `name` tag
    :param colormap: optional value for the `colormap` tag
    """
    with rasterio.open(input_path) as src:
        out_image, out_transform = rasterio.mask.mask(src, [polygon], crop=True)
        profile = src.meta

    # Describe the cropped extent in the output profile.
    profile.update(
        driver='GTiff',
        height=out_image.shape[1],
        width=out_image.shape[2],
        transform=out_transform,
        nodata=0,
    )

    # Optional tags are written only when a non-empty value was given.
    tags = {"start_date": date, "end_date": date}
    if name:
        tags["name"] = name
    if colormap:
        tags["colormap"] = colormap

    with rasterio.open(output_path, "w", **profile) as dest:
        dest.update_tags(**tags)
        dest.write(out_image)

In [None]:
def transform_crs(data_path, save_path, dst_crs="EPSG:4326", resolution=(10, 10)):
    """
    Reproject every band of a raster to another CRS and save it to a new file.

    :param data_path: path of the source raster
    :param save_path: path of the reprojected raster to create
    :param dst_crs: target CRS
    :param resolution: target pixel size in units of `dst_crs`, or None to let
                       rasterio choose it
    :return: `save_path`
    """
    # Forward `resolution` only when it was actually requested.
    extra = {} if resolution is None else {"resolution": resolution}

    with rasterio.open(data_path) as src:
        dst_transform, dst_width, dst_height = calculate_default_transform(
            src.crs, dst_crs, src.width, src.height, *src.bounds, **extra
        )

        profile = src.meta.copy()
        profile.update(
            crs=dst_crs,
            transform=dst_transform,
            width=dst_width,
            height=dst_height,
        )

        with rasterio.open(save_path, "w", **profile) as dst:
            for band_index in range(1, src.count + 1):
                reproject(
                    source=rasterio.band(src, band_index),
                    destination=rasterio.band(dst, band_index),
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=dst_transform,
                    dst_crs=dst_crs,
                    resampling=Resampling.nearest,
                )

    return save_path

In [None]:
def stitch_tiles(paths, out_raster_path, date, name=None, colormap=None):
    """
    Merge several rasters into one tagged GeoTIFF.

    The first raster defines the CRS and base metadata of the result. Every
    other raster is reprojected into that CRS through a temporary file that is
    removed afterwards. Where rasters overlap the last one wins.

    :param paths: non-empty sequence of raster paths (str or path-like)
    :param out_raster_path: output path; a '.jp2' extension is replaced by '.tif'
    :param date: value stored in both the `start_date` and `end_date` tags
    :param name: optional value for the `name` tag
    :param colormap: optional value for the `colormap` tag
    :return: path of the written raster
    :raises ValueError: if `paths` is empty
    """
    if not paths:
        raise ValueError("stitch_tiles requires at least one raster path")
    paths = [str(x) for x in paths]

    tiles = []
    tmp_files = []
    crs = None
    meta = None
    try:
        for i, path in enumerate(paths):
            if i == 0:
                file = rasterio.open(path)
                meta, crs = file.meta, file.crs
            else:
                tmp_path = path.replace(
                    '.jp2', '_tmp.jp2').replace('.tif', '_tmp.tif')
                if tmp_path == path:
                    # No known extension to rewrite: never reproject a file
                    # onto itself (it would then be deleted as "temporary").
                    tmp_path = path + '_tmp'
                # Register before writing so a failed reprojection is cleaned up too.
                tmp_files.append(tmp_path)
                transform_crs(path, tmp_path, dst_crs=crs, resolution=None)
                file = rasterio.open(tmp_path)
            tiles.append(file)

        tile_arr, merged_transform = merge(tiles, method='last')
    finally:
        # Release datasets and temporary files even when something above failed.
        for tile in tiles:
            tile.close()

        for tmp_file in tmp_files:
            try:
                os.remove(tmp_file)
            except FileNotFoundError:
                print(f'Tile {tmp_file} was removed or renamed, skipping')

    meta.update({"driver": "GTiff",
                 "height": tile_arr.shape[1],
                 "width": tile_arr.shape[2],
                 "transform": merged_transform,
                 "crs": crs})

    if '.jp2' in out_raster_path:
        out_raster_path = out_raster_path.replace('.jp2', '.tif')

    with rasterio.open(out_raster_path, "w", **meta) as dst:
        dst.update_tags(start_date=date, end_date=date)
        if name:
            dst.update_tags(name=name)
        if colormap:
            dst.update_tags(colormap=colormap)
        dst.write(tile_arr)
    # Report only after the file has actually been written.
    print(f'saved raster {out_raster_path}')

    return out_raster_path

In [None]:
def dump_no_data_geosjon(polygon, geojson_path):
    """
    Write a GeoJSON feature marking `polygon` as an area without data.

    The feature carries a 'No data' label with a red style in its properties,
    plus top-level `start_date`, `end_date` (the requested date range) and
    `name` members.

    :param polygon: geometry of the area without data
    :param geojson_path: path of the GeoJSON file to write
    """
    feature = Feature(
        geometry=polygon,
        properties={"label": 'No data', "style": {"color": 'red'}},
    )
    extra_members = (
        ('start_date', START_DATE),
        ('end_date', END_DATE),
        ('name', 'TCI_NDVI\nNo data available'),
    )
    for key, value in extra_members:
        feature[key] = value

    with open(geojson_path, 'w') as f:
        geojson.dump(feature, f)

#### 4.2 Calculate and crop NDVI, TCI

#### Intermediate per-tile files are named TILE_ACQUIREDDATE_LAYER.tif (e.g. `<tile>_<date>_NDVI.tif`)

In [None]:
# Nothing was loaded for any tile: mark the whole AOI as "no data" and stop.
if not filtered:
    geojson_path = os.path.join(OUTPUT_NODATA_FOLDER, "aoi.geojson")
    dump_no_data_geosjon(aoi.geometry[0], geojson_path)
    raise ValueError("Images not loaded for given AOI. Change dates, constraints")


tci_images = []
ndvi_images = []
# Acquisition date of the last tile that was processed successfully; used to
# tag the stitched rasters.
stitched_date = None
for row in overlap_tiles.itertuples():
    tile = row.tile
    polygon = row.geometry
    if tile not in filtered:
        tile_geojson_path = os.path.join(OUTPUT_NODATA_FOLDER, "%s_aoi.geojson" % tile)
        print("No data loaded for tile", tile)
        dump_no_data_geosjon(polygon, tile_geojson_path)
        # There is nothing to compute for this tile; skip it instead of
        # running into a KeyError below.
        continue

    try:
        paths = filtered[tile]['paths']
        print(f"{tile}: Start calculation TCI, NDVI")

        acquired_date = filtered[tile]['date']
        base_filename = f"{tile}_{acquired_date}_"
        # Work on "*.temp" files and rename at the end, so a "*.tif" file only
        # appears once it is complete.
        temp_ndvi_filename = os.path.join(TEMP_DIR, base_filename + "NDVI.tif.temp")
        temp_tci_filename = os.path.join(TEMP_DIR, base_filename + "TCI.tif.temp")

        tile_crs = NDVI(paths['NIR'], paths['RED'], temp_ndvi_filename)
        # The AOI part is in lon/lat; the rasters are in the tile CRS.
        transformed_poly = to_crs(polygon, tile_crs)

        # Crop and save NDVI
        crop(temp_ndvi_filename, temp_ndvi_filename, transformed_poly, acquired_date, name="Sentinel-2 Vegetation Index (NDVI)", colormap=colormap_tag)
        # Crop and save TCI
        crop(paths['TCI'], temp_tci_filename, transformed_poly, acquired_date, name="Sentinel-2 RGB raster")

        print(f"{tile}: End calculation TCI, NDVI")

        # Strip the trailing ".temp" (5 characters).
        ndvi_filename = temp_ndvi_filename[:-5]
        tci_filename = temp_tci_filename[:-5]
        print(f"{tile}: Rename {temp_ndvi_filename}->{ndvi_filename}\n {temp_tci_filename}->{tci_filename}")
        os.rename(temp_ndvi_filename, ndvi_filename)
        os.rename(temp_tci_filename, tci_filename)
        tci_images.append(tci_filename)
        ndvi_images.append(ndvi_filename)
        stitched_date = acquired_date
    except Exception as e:
        print(f"{tile}: Cannot calculate TCI, NDVI: {str(e)}")

if tci_images:
    tci_full = stitch_tiles(tci_images, OUTPUT_TCI_FILE, stitched_date,  name="Sentinel-2 RGB raster")
    ndvi_full = stitch_tiles(ndvi_images, OUTPUT_NDVI_FILE, stitched_date, name="Sentinel-2 Vegetation Index (NDVI)", colormap=colormap_tag)
    # The per-tile rasters are no longer needed once they are stitched.
    for tci_path, ndvi_path in zip(tci_images, ndvi_images):
        os.remove(tci_path)
        os.remove(ndvi_path)