### For given aoi, prepare EBI sentinel latest images 

In [2]:
import os
import geopandas as gp
import numpy as np
import rasterio
import re
import tempfile
import pyproj
import uuid
import json
import geojson
import shutil

from geojson import Feature

import rasterio.mask
from rasterio import Affine
from rasterio.plot import reshape_as_raster
from rasterio.merge import merge
from rasterio.warp import calculate_default_transform, reproject, Resampling
from rasterio.profiles import DefaultGTiffProfile

from shapely import wkt
from shapely.geometry import Polygon, box
from shapely.ops import transform


from pathlib import Path
from datetime import datetime, timedelta
from sentinel2download.downloader import Sentinel2Downloader
from sentinel2download.overlap import Sentinel2Overlap

#### 0. Inputs

In [None]:


# Inputs
AOI = os.getenv('AOI')
START_DATE = os.getenv("START_DATE")
END_DATE = os.getenv("END_DATE")
SENTINEL2_GOOGLE_API_KEY = os.getenv('SENTINEL2_GOOGLE_API_KEY')
SATELLITE_CACHE_FOLDER = os.getenv('SENTINEL2_CACHE')

# Output folder
OUTPUT_FOLDER = os.getenv('OUTPUT_FOLDER')


OUTPUT_NODATA_FOLDER = os.path.join(OUTPUT_FOLDER, "/nodata")
OUTPUT_EBI_FILE = os.path.join(OUTPUT_FOLDER, "EBI.tif")
OUTPUT_TCI_FILE = os.path.join(OUTPUT_FOLDER, "TCI.tif")

os.makedirs(OUTPUT_NODATA_FOLDER, exist_ok=True)
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

#### 1. Transform AOI and get bound_box

In [3]:
aoi = gp.GeoDataFrame(geometry=[wkt.loads(AOI)], crs="epsg:4326")    
aoi_filename = "provided_aoi.geojson"
aoi.to_file(aoi_filename, driver="GeoJSON") 

In [5]:
COLORMAP_BRBG = os.path.join('/code', 'colormap.npy') 

#### 2. Overlap AOI with sentinel2grid 

In [6]:
def epsg_code(longitude, latitude):
        """
        Generates EPSG code from lon, lat
        :param longitude: float
        :param latitude: float
        :return: int, EPSG code
        """

        def _zone_number(lat, lon):
            if 56 <= lat < 64 and 3 <= lon < 12:
                return 32
            if 72 <= lat <= 84 and lon >= 0:
                if lon < 9:
                    return 31
                elif lon < 21:
                    return 33
                elif lon < 33:
                    return 35
                elif lon < 42:
                    return 37

            return int((lon + 180) / 6) + 1

        zone = _zone_number(latitude, longitude)

        if latitude > 0:
            return 32600 + zone
        else:
            return 32700 + zone

In [9]:
s2overlap = Sentinel2Overlap(aoi_path=aoi_filename)
overlap_tiles = s2overlap.overlap_with_geometry()


Unnamed: 0,tile,geometry
0,10SGF,"POLYGON ((-120.21176 36.36394, -120.20301 36.3..."


In [10]:
overlap_tiles.crs

<Geographic 2D CRS: EPSG:4326>
Name: WGS 84
Axis Info [ellipsoidal]:
- Lat[north]: Geodetic latitude (degree)
- Lon[east]: Geodetic longitude (degree)
Area of Use:
- name: World.
- bounds: (-180.0, -90.0, 180.0, 90.0)
Datum: World Geodetic System 1984 ensemble
- Ellipsoid: WGS 84
- Prime Meridian: Greenwich

#### 3. Load images

In [11]:
# API_KEY = os.path.join(BASE, ".secret/sentinel2_google_api_key.json")
LOAD_DIR = SATELLITE_CACHE_FOLDER

PRODUCT_TYPE = 'L2A'
BANDS = {"TCI"}
CONSTRAINTS = {'NODATA_PIXEL_PERCENTAGE': 10.0, 'CLOUDY_PIXEL_PERCENTAGE': 5.0, }

LAYERS = ['EBI', 'TCI']

In [12]:
START_DATE

'2020-01-20'

In [13]:
END_DATE

'2020-02-05'

In [14]:
def shift_date(date, delta=5, format='%Y-%m-%d'):
    date = datetime.strptime(date, format)
    date = date - timedelta(days=delta)    
    return datetime.strftime(date, format)

#### 3.1 Define max shift in dates - 30 days for loading images

In [15]:
MAX_SHIFT = 30

In [16]:
MAX_SHIFT_ITERS = 2

In [17]:
def load_images(tiles, start_date, end_date):
    loader = Sentinel2Downloader(SENTINEL2_GOOGLE_API_KEY)
    
    loadings = dict()
        
    for tile in tiles:
        start = start_date
        end = end_date
        
        print(f"Loading images for tile: {tile}...")
        count = 0
        while count < MAX_SHIFT_ITERS:
            loaded = loader.download(PRODUCT_TYPE,
                                [tile],
                                start_date=start,
                                end_date=end,
                                output_dir=LOAD_DIR,                       
                                bands=BANDS,
                                constraints=CONSTRAINTS)
        
            if not loaded:
                end = start_date
                start = shift_date(start_date, delta=MAX_SHIFT) 
                print(f"For tile: {tile} and dates {start_date} {end_date} proper images not found! Shift dates to {start} {end}!")
            else:
                break
            count += 1
        if loaded:
            loadings[tile] = loaded
            print(f"Loading images for tile {tile} finished")
        else:
            print(f"Images for tile {tile} were not loaded!")
        
    # tile_folders = dict()
    # for tile, tile_paths in loadings.items():
    #    tile_folders[tile] = {str(Path(tile_path[0]).parent) for tile_path in tile_paths}
    return loadings

In [18]:
loadings = load_images(overlap_tiles.Name.values, START_DATE, END_DATE)

Loading images for tile: 10SGF...
Loading images for tile 10SGF finished


#### 3.2 Filter loadings for every tile, get last image in daterange and bands

In [19]:
def filter_by_date(loadings):
    def _find_last_date(folders):        
        dates = list()
        for folder in folders:        
            search = re.search(r"_(\d+)T\d+_", str(folder))
            date = search.group(1)
            date = datetime.strptime(date, '%Y%m%d')
            dates.append(date)    
        last_date = max(dates)
        last_date = datetime.strftime(last_date, '%Y%m%d')
        return last_date
    
    filtered = dict()
    for tile, items in loadings.items():
        try:
            last_date = _find_last_date(items)
            bands_paths = dict()
            for path, _ in items:
                if last_date in path:
                    if 'B04_10m.jp2' in path:
                        bands_paths['RED'] = path
                    if 'B03_10m.jp2' in path:
                        bands_paths['GREEN'] = path
                    if 'B02_10m.jp2' in path:
                        bands_paths['BLUE'] = path
                    if 'TCI_10m.jp2' in path:
                        bands_paths['TCI'] = path
            filtered[tile] = dict(paths=bands_paths, date=last_date)
        except Exception as ex:
            print(f"Error for {tile}: {str(ex)}")
    return filtered

In [20]:
filtered = filter_by_date(loadings)

#### 4. Calculate EBI

In [21]:
TEMP_DIR = "/temp"
os.makedirs(TEMP_DIR, exist_ok=True)

#### 4.1 Prepare color coding for EBI

In [22]:
def prepare_colors(colors):
    colors = np.load(COLORMAP_BRBG)
    if colors.shape[1] == 4:
        # delete last channel, we use rgb
        colors = np.delete(colors, 3, axis=1)
    # colormap colors values in range [0-255], but in our case 0 - no data, -> have to color as [0, 0, 0] 
    colors[colors == 0] = 1
    colors[0] = [0, 0, 0]
    return colors

In [23]:
COLORS = prepare_colors(COLORMAP_BRBG)

In [24]:
COLORS.shape

(256, 3)

In [25]:
colormap_tag = {"name": "Enhanced Blooming index", "colors": [], "labels": ["low", "high"]}

for color in COLORS:
    color_str = ",".join(list(map(lambda x: str(int(x)), color)))
    colormap_tag['colors'].append(color_str)

colormap_tag = json.dumps(colormap_tag)
# example of colormap_tag format
# {"name": "Vegetation index", "colors": ["0,0,0", "255,0,0", "0,255,0", "0,0,255" ...], "labels": ["low", "high"]}


In [26]:
def scale(val, a=1, b=255, nodata=0.0):
    min_val = 0
    max_val = 1
    scaled = (b - a) * (val - min_val) / (max_val - min_val) + a
    scaled = np.around(scaled)
    scaled[np.isnan(scaled) == True] = nodata
    scaled = scaled.astype(np.uint8)
    print(scaled.max())
    return scaled

In [27]:
def color_ndvi(scaled, colors):
    colored = np.reshape(colors[scaled.flatten()], tuple((*scaled.shape, 3)))
    colored = reshape_as_raster(colored)
    return colored

In [28]:
def to_crs(poly, target, current='EPSG:4326'):
    # print(f"TARGET CRS: {target}")
    project = pyproj.Transformer.from_crs(pyproj.CRS(current), pyproj.CRS(target), always_xy=True).transform
    transformed_poly = transform(project, poly)
    return transformed_poly 

In [29]:
def crop(input_path, output_path, polygon, date, name=None, colormap=None):
    with rasterio.open(input_path) as src:
        out_image, out_transform = rasterio.mask.mask(src, [polygon], crop=True)
        out_meta = src.meta
        
        out_meta.update(driver='GTiff',
                        height=out_image.shape[1],
                        width=out_image.shape[2],
                        transform=out_transform,
                        nodata=0, )

    with rasterio.open(output_path, "w", **out_meta) as dest:
        dest.update_tags(start_date=date, end_date=date)
        if name:
            dest.update_tags(name=name)
        if colormap:
            dest.update_tags(colormap=colormap)
        dest.write(out_image)

In [30]:
def transform_crs(data_path, save_path, dst_crs="EPSG:4326", resolution=(10, 10)):
    with rasterio.open(data_path) as src:
        if resolution is None:
            transform, width, height = calculate_default_transform(
                src.crs, dst_crs, src.width, src.height, *src.bounds
            )
        else:
            transform, width, height = calculate_default_transform(
                src.crs,
                dst_crs,
                src.width,
                src.height,
                *src.bounds,
                resolution=resolution,
            )
        kwargs = src.meta.copy()
        kwargs.update(
            {"crs": dst_crs, "transform": transform, "width": width, "height": height}
        )
        with rasterio.open(save_path, "w", **kwargs) as dst:
            for i in range(1, src.count + 1):
                reproject(
                    source=rasterio.band(src, i),
                    destination=rasterio.band(dst, i),
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=transform,
                    dst_crs=dst_crs,
                    resampling=Resampling.nearest,
                )

    return save_path

In [31]:
def stitch_tiles(paths, out_raster_path, date, name=None, colormap=None):
    if not isinstance(paths[0], str):
        paths = [str(x) for x in paths]
    tiles = []
    tmp_files = []
    
    crs = None
    meta = None
    for i, path in enumerate(paths):
        if i == 0:
            file = rasterio.open(path)
            meta, crs = file.meta, file.crs
        else:
            tmp_path = path.replace(
                '.jp2', '_tmp.jp2').replace('.tif', '_tmp.tif')
            crs_transformed = transform_crs(path, tmp_path, 
                                            dst_crs=crs, 
                                            resolution=None)
            tmp_files.append(crs_transformed)
            file = rasterio.open(crs_transformed)
        tiles.append(file)
            
    tile_arr, transform = merge(tiles, method='last')
    
    meta.update({"driver": "GTiff",
                 "height": tile_arr.shape[1],
                 "width": tile_arr.shape[2],
                 "transform": transform,
                 "crs": crs})
    
    if '.jp2' in out_raster_path:
        out_raster_path = out_raster_path.replace('.jp2', '.tif')
    print(f'saved raster {out_raster_path}')

    for tile in tiles:
        tile.close()
        
    for tmp_file in tmp_files:
        try:
            os.remove(tmp_file)
        except FileNotFoundError:
            print(f'Tile {tmp_file} was removed or renamed, skipping')
        
    with rasterio.open(out_raster_path, "w", **meta) as dst:
        dst.update_tags(start_date=date, end_date=date)
        if name:
            dst.update_tags(name=name)
        if colormap:
            dst.update_tags(colormap=colormap)
        dst.write(tile_arr)
    
    return out_raster_path

In [32]:
def dump_no_data_geosjon(polygon, geojson_path):
    NO_DATA = 'No data'
    style = dict(color='red')
    feature = Feature(geometry=polygon, properties=dict(label=NO_DATA, style=style))
    feature['start_date'] = START_DATE
    feature['end_date'] = END_DATE
    feature['name'] = NO_DATA
    
    with open(geojson_path, 'w') as f:
        geojson.dump(feature, f)

In [33]:
def calculate_EBI(tci_path, save_path, eps=256):
    np.seterr(divide='ignore', invalid='ignore')
    with rasterio.open(tci_path) as src:
        red = src.read(1).astype(rasterio.float32)
        green = src.read(2).astype(rasterio.float32)
        blue = src.read(3).astype(rasterio.float32)
        crs = str(src.crs)
    EBI = (red + green + blue) / ((green / blue) * (red - blue + eps))
    EBI = scale(EBI)
    print(EBI.max())
    print(EBI.min())
    colored = color_ndvi(EBI, COLORS)
    out_meta = src.meta.copy()
    out_meta.update(dtype=rasterio.uint8,
                    driver='GTiff',
                    nodata=0,
                    count=3, )
    
    # Create the file
    with rasterio.open(save_path, 'w', **out_meta) as dst:
         dst.write(colored)

In [None]:
def clean_images(images):
    for i in range(len(images)):
        os.remove(images[i])

#### 4.2 Calculate and crop EBI

#### Filenames have next names: TILE_ID_ACQUIREDDATE

In [35]:

if not filtered:
    geojson_path = os.path.join(OUTPUT_NODATA_FOLDER, "aoi.geojson")
    dump_no_data_geosjon(aoi.geometry[0], geojson_path)    
    raise ValueError("Images not loaded for given AOI. Change dates, constraints")

images_tci = []
images = []
for row in overlap_tiles.itertuples():
    tile = row.Name
    polygon = row.geometry
    if not tile in filtered:
        tile_geojson_path = os.path.join(OUTPUT_NODATA_FOLDER, "%s_aoi.geojson" % tile)
        print("No data loaded for tile", tile)
        dump_no_data_geosjon(polygon, tile_geojson_path)    
    
    try:
        paths = filtered[tile]['paths']
        print(f"{tile}: Start calculation EBI")
        
        acquired_date = filtered[tile]['date']
        base_filename = f"{tile}_{acquired_date}_"
        temp_filename = os.path.join(TEMP_DIR, base_filename + "EBI.tif.temp")
        temp_filename_tci = os.path.join(TEMP_DIR, base_filename + "TCI.tif.temp")
        
        with rasterio.open(paths['TCI']) as src:
            tile_crs = str(src.crs)
        transformed_poly = to_crs(polygon, tile_crs)

        crop(paths['TCI'], temp_filename_tci, transformed_poly, acquired_date, name="Sentinel-2 EBI", colormap=colormap_tag)
        
        calculate_EBI(temp_filename_tci, temp_filename)
        print(f"{tile}: End calculation EBI")
    
        filename = temp_filename[:-5]
        filename_tci = temp_filename_tci[:-5]
        print(f"{tile}: Rename {temp_filename}->{filename}\n")
        os.rename(temp_filename, filename)
        os.rename(temp_filename_tci, filename_tci)
        images_tci.append(filename_tci)
        images.append(filename)
    except Exception as e:
        print(f"{tile}: Cannot calculate EBI: {str(e)}")

if len(images) > 1:
    full_ndvi = stitch_tiles(images, OUTPUT_EBI_FILE, acquired_date, name="Sentinel-2 EBI", colormap=colormap_tag)
    full_tci = stitch_tiles(filename_tci, OUTPUT_TCI_FILE, acquired_date, name="Sentinel-2 RGB Raster")

else:
    shutil.copy(images[0], OUTPUT_EBI_FILE)
    shutil.copy(images_tci[0], OUTPUT_TCI_FILE)
    
clean_images(images)
clean_images(images_tci)

10SGF: Start calculation EBI
nan nan
False
[[0.06540084 0.38192929 0.37794473 ... 0.02233773 0.2384486  0.01353603]
 [0.15214531 2.437046   0.19571892 ... 0.33867028 0.38764995 0.36047431]
 [0.19945355 2.66359447 2.93524044 ... 0.09519022 0.00705271 0.47389163]
 ...
 [0.58958775 0.63071895 0.91467181 ... 0.11168488 0.47247247 0.13086932]
 [0.152872   0.62857143 0.2742941  ... 0.09501558 0.15121951 0.25673798]
 [0.38472419 0.84745604 0.21332728 ... 0.35053027 0.26453129 1.74070284]]
255
10SGF: End calculation EBI
10SGF: Rename /home/ritalatuha/quantum/sip_blooming_index/./results/45/45_10SGF_20200201_EBI.tif.temp->/home/ritalatuha/quantum/sip_blooming_index/./results/45/45_10SGF_20200201_EBI.tif

