In [None]:
import os
import json
import time
import numpy as np
import geopandas as gpd
import rasterio.mask
import shapely.wkt
import shutil
import re
import rasterio.mask as riomask
from datetime import datetime
from pathlib import Path

from sentinel2download.overlap import Sentinel2Overlap
from sentinel2download.downloader import Sentinel2Downloader

from code.index_research import calculate_ndmi, calculate_ndvi
from code.utils import stitch_tiles, dump_no_data_geojson

import warnings
# Suppress every Python warning for the remainder of the run.
warnings.filterwarnings('ignore')
import logging
# Set the root logger's threshold to INFO.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [None]:
# Run parameters are read from the environment; any variable that is not set
# comes back as None.
REQUEST_ID = os.environ.get('REQUEST_ID')
START_DATE = os.environ.get('START_DATE')
END_DATE = os.environ.get('END_DATE')
AOI = os.environ.get('AOI')  # parsed as WKT further down
SENTINEL2_GOOGLE_API_KEY = os.environ.get('SENTINEL2_GOOGLE_API_KEY')
SATELLITE_CACHE_FOLDER = os.environ.get('SENTINEL2_CACHE')
OUTPUT_FOLDER = os.environ.get('OUTPUT_FOLDER')

#### 1. Transform AOI to a GeoJSON file

In [None]:
# Parse the AOI string as WKT and wrap the resulting geometry in a single-row
# GeoDataFrame declared as EPSG:4326.
aoi_polygon = shapely.wkt.loads(AOI)
aoi = gpd.GeoDataFrame(geometry=[aoi_polygon], crs="epsg:4326")

# Write the AOI as GeoJSON using a relative path (i.e. into the current
# working directory); the file name is prefixed with the current timestamp.
aoi_filename = f"{time.time()}_aoi.geojson"
aoi.to_file(aoi_filename, driver="GeoJSON")


#### 2. Overlap AOI with the Sentinel-2 tile grid

In [None]:
# Sentinel2Overlap is pointed at the AOI GeoJSON written above; the result's
# `Name` column is used later as the list of tile names to download.
s2overlap = Sentinel2Overlap(aoi_path=aoi_filename)
overlap_tiles = s2overlap.overlap_with_geometry()

#### 3. Load images

In [None]:
# Intermediate index rasters are written below the notebook's working directory.
BASE = os.getcwd()
NDMI_PATH = os.path.join(BASE, f'data/rasters/{REQUEST_ID}_ndmi.tif')
NDVI_PATH = os.path.join(BASE, f'data/rasters/{REQUEST_ID}_ndvi.tif')
# Band identifiers requested from the downloader.
BANDS = {'B04', 'B08', 'B12', 'TCI', 'CLD'}

# A band file is rejected when at least this percentage of its AOI pixels is nodata.
NODATA_PIXEL_PERCENTAGE = 10.0
# Passed to the downloader as the CLOUDY_PIXEL_PERCENTAGE search constraint.
SEARCH_CLOUDY_PIXEL_PERCENTAGE = 80.0
# A tile is rejected when at least this percentage of its AOI pixels is cloudy.
AOI_CLOUDY_PIXEL_PERCENTAGE = 15.0
CONSTRAINTS = {'CLOUDY_PIXEL_PERCENTAGE': SEARCH_CLOUDY_PIXEL_PERCENTAGE}
PRODUCT_TYPE = 'L2A'
# Stored as the `name` tag of the output rasters.
NAME = 'Moisture content'

os.makedirs(OUTPUT_FOLDER, exist_ok=True)
# NDMI_PATH / NDVI_PATH are handed to the index helpers as output paths; make
# sure their directory exists so writing does not depend on it being created
# beforehand.
os.makedirs(os.path.dirname(NDMI_PATH), exist_ok=True)

#### 3.1 Check whether downloaded tiles match constraints

In [None]:
def check_nodata_percentage_crop(tile_path, 
                                 aoi, 
                                 nodata_percentage_limit, 
                                 nodata):
    """Tell whether too much of the AOI is nodata in the first band of a raster.

    Args:
        tile_path: Path of a raster that rasterio can open.
        aoi: GeoDataFrame whose geometry at label 0 is the area of interest;
            it is reprojected to the raster CRS before masking.
        nodata_percentage_limit: Rejection threshold, in percent.
        nodata: Pixel value that counts as nodata.

    Returns:
        True when the percentage of unmasked AOI pixels equal to ``nodata``
        (rounded to two decimals) reaches the limit, or when the AOI selects
        no unmasked pixel at all; False otherwise.
    """
    with rasterio.open(tile_path) as src:
        polygon = aoi.to_crs(src.meta['crs']).geometry[0]
        band, _ = rasterio.mask.mask(src, [polygon], crop=True, filled=False, indexes=1)
    # Keep only the pixels that were not masked out.
    masked_band = band[~band.mask]
    if masked_band.size == 0:
        # Nothing usable inside the AOI: report the limit as exceeded instead
        # of dividing by zero below.
        return True
    nodata_count = np.count_nonzero(masked_band == nodata)
    nodata_percentage = round(nodata_count / masked_band.size * 100, 2)
    return nodata_percentage >= nodata_percentage_limit

def check_cloud_percentage_crop(tile_path, 
                                aoi, 
                                cloud_percentage_limit,
                                cloud_probability=50):
    """Tell whether too much of the AOI is cloudy in the first band of a raster.

    Args:
        tile_path: Path of a raster that rasterio can open; its first band is
            compared against ``cloud_probability``.
        aoi: GeoDataFrame whose geometry at label 0 is the area of interest;
            it is reprojected to the raster CRS before masking.
        cloud_percentage_limit: Rejection threshold, in percent.
        cloud_probability: A pixel counts as cloudy when its value is greater
            than or equal to this number.

    Returns:
        True when the percentage of unmasked AOI pixels counted as cloudy
        (rounded to two decimals) reaches the limit, or when the AOI selects
        no unmasked pixel at all; False otherwise.
    """
    with rasterio.open(tile_path) as src:
        polygon = aoi.to_crs(src.meta['crs']).geometry[0]
        band, _ = rasterio.mask.mask(src, [polygon], crop=True, filled=False, indexes=1)
    # Keep only the pixels that were not masked out.
    masked_band = band[~band.mask]
    if masked_band.size == 0:
        # Cloud cover cannot be assessed without any pixel; report the limit
        # as exceeded instead of dividing by zero below.
        return True
    cloud_count = np.count_nonzero(masked_band >= cloud_probability)
    cloud_percentage = round(cloud_count / masked_band.size * 100, 2)
    return cloud_percentage >= cloud_percentage_limit

def check_tile_validity(tile_folder, aoi, cloud_percentage_limit, nodata_percentage_limit):
    """Run the cloud / nodata checks over every ``.jp2`` file of one folder.

    Files whose path contains ``MSK_CLDPRB_20m`` go through the cloud check;
    every other ``.jp2`` file goes through the nodata check with 0 as the
    nodata value. Files with another suffix are ignored.

    Returns:
        A ``(skip_tile, band_paths)`` tuple. ``skip_tile`` is True as soon as
        one file fails its check; ``band_paths`` always lists every entry of
        ``tile_folder`` joined with the folder path.
    """
    band_paths = [os.path.join(tile_folder, entry) for entry in os.listdir(tile_folder)]
    for candidate in band_paths:
        if Path(candidate).suffix != '.jp2':
            continue
        if "MSK_CLDPRB_20m" in candidate:
            rejected = check_cloud_percentage_crop(candidate, aoi, cloud_percentage_limit)
        else:
            rejected = check_nodata_percentage_crop(candidate, aoi, nodata_percentage_limit, 0)
        if rejected:
            # The first failing file settles the verdict for the whole folder.
            return True, band_paths
    return False, band_paths

def validate_tile_downloads(loaded, tile, loadings, aoi, cloud_percentage_limit, nodata_percentage_limit):
    """Validate the downloads of one tile and record the accepted files.

    Args:
        loaded: Download results for the tile; the first element of every
            entry is used as a file path. A falsy value means nothing was
            downloaded.
        tile: Tile name, used as the key in ``loadings`` and in messages.
        loadings: Dict updated in place with ``tile -> list of file paths``.
        aoi: Area of interest forwarded to the per-folder checks.
        cloud_percentage_limit: Cloud threshold forwarded to the checks.
        nodata_percentage_limit: Nodata threshold forwarded to the checks.

    Returns:
        The same ``loadings`` dict. Folders that fail validation are deleted
        from disk.
    """
    print(f"Validating images for tile: {tile}...")
    if not loaded:
        print(f"Images for tile {tile} were not loaded!")
        return loadings

    accepted_bands = []
    # Group the downloaded files by their parent folder and check each folder once.
    for product_folder in {Path(record[0]).parent for record in loaded}:
        rejected, folder_bands = check_tile_validity(
            product_folder, aoi, cloud_percentage_limit, nodata_percentage_limit)
        if rejected:
            shutil.rmtree(product_folder)
            continue
        accepted_bands.extend(folder_bands)

    if accepted_bands:
        loadings[tile] = accepted_bands
    else:
        print("Tile images didn't match nodata/cloud constraints, so they were removed")
    print(f"Validating images for tile {tile} finished")
    return loadings

In [None]:
def load_images(tiles, start_date, end_date, aoi):
    """Download the configured bands for every tile and keep the valid files.

    Each tile is requested on its own from the downloader using the
    module-level product type, band set, constraints and cache folder, and
    the result is then passed through ``validate_tile_downloads``.

    Returns:
        Dict mapping tile name to the file paths that passed validation.
    """
    downloader = Sentinel2Downloader(SENTINEL2_GOOGLE_API_KEY)
    validated = {}

    for tile in tiles:
        print(f"Loading images for tile: {tile}...")
        downloaded = downloader.download(
            PRODUCT_TYPE,
            [tile],
            start_date=start_date,
            end_date=end_date,
            output_dir=SATELLITE_CACHE_FOLDER,
            bands=BANDS,
            constraints=CONSTRAINTS,
        )
        print(f"Loading images for tile {tile} finished")
        validated = validate_tile_downloads(
            downloaded,
            tile,
            validated,
            aoi,
            AOI_CLOUDY_PIXEL_PERCENTAGE,
            NODATA_PIXEL_PERCENTAGE,
        )
    return validated

# Download and validate imagery for every tile name in the overlap result.
loadings = load_images(overlap_tiles.Name.values, START_DATE, END_DATE, aoi)


In [None]:
# Display the validated file paths collected per tile.
loadings

In [None]:
def filter_by_date(loadings):
    """Select, per tile, the band files that belong to the most recent date.

    Args:
        loadings: Dict mapping tile name to an iterable of file paths. The
            date is read from the first ``_<digits>T<digits>_`` token of each
            path and parsed as ``YYYYMMDD``.

    Returns:
        Dict with the keys ``'BO4'``, ``'BO8'``, ``'B12'`` and ``'TCI'``, each
        holding the matching paths of all tiles. The dict always contains all
        four keys; a list stays empty when nothing matched.

    Paths that carry no parsable date are ignored instead of discarding the
    whole tile. A tile without any dated path is reported and skipped.
    """
    date_pattern = re.compile(r"_(\d+)T\d+_")

    def _find_last_date(folders):
        """Return the latest date found in ``folders`` as ``YYYYMMDD``, or None."""
        dates = []
        for folder in folders:
            search = date_pattern.search(str(folder))
            if search is None:
                continue
            try:
                dates.append(datetime.strptime(search.group(1), '%Y%m%d'))
            except ValueError:
                # The digits matched the pattern but are not a calendar date.
                continue
        if not dates:
            return None
        return max(dates).strftime('%Y%m%d')

    # Output key -> file name fragment. The 'BO4' / 'BO8' keys are spelled
    # with the letter O on purpose: the code consuming this dict indexes it
    # with exactly these keys.
    band_suffixes = {
        'BO4': 'B04_10m.jp2',
        'BO8': 'B08_10m.jp2',
        'B12': 'B12_20m.jp2',
        'TCI': 'TCI_10m.jp2',
    }
    filtered = {key: [] for key in band_suffixes}

    for tile, items in loadings.items():
        last_date = _find_last_date(items)
        if last_date is None:
            print(f"Error for {tile}: no acquisition date found in image paths")
            continue
        for path in items:
            path_text = str(path)
            if last_date not in path_text:
                continue
            for key, suffix in band_suffixes.items():
                if suffix in path_text:
                    filtered[key].append(path)
    return filtered

In [None]:
# Keep only the most recent files per tile, grouped by band key.
filtered_tiles = filter_by_date(loadings)

In [None]:
# Display the selected file paths per band key.
filtered_tiles

In [None]:
# filter_by_date always returns a dict with one (possibly empty) list per band,
# so the dict itself is never falsy. Every band list is indexed when the tiles
# are stitched, hence a single empty list already means there is no usable
# imagery for this run.
if not filtered_tiles or not all(filtered_tiles.values()):
    geojson_path = os.path.join(OUTPUT_FOLDER, f"{START_DATE}_{END_DATE}_no_data.geojson")
    dump_no_data_geojson(aoi.geometry[0], geojson_path)
    raise ValueError("Images not loaded for given AOI. Change dates, constraints")

In [None]:
# Stitch the selected files of each band into one raster. The merged file is
# named after the first input of the band, with '.jp2' replaced by '_merged.tif'.
b04_tile, b08_tile, b12_tile, tci_tile = [
    stitch_tiles(filtered_tiles[key], filtered_tiles[key][0].replace('.jp2', '_merged.tif'))
    for key in ('BO4', 'BO8', 'B12', 'TCI')
]

In [None]:
# Derive the index rasters from the stitched bands: NDMI from B08 and B12,
# NDVI from B04 and B08. The results go to NDMI_PATH / NDVI_PATH, with NaN
# passed as the nodata value.
calculate_ndmi(b08_tile, b12_tile, out_path=NDMI_PATH, nodata=np.nan)
calculate_ndvi(b04_tile, b08_tile, out_path=NDVI_PATH, nodata=np.nan)

In [None]:
# Crop each raster to the AOI, reprojecting the AOI to the raster's CRS first.

# True-colour image.
with rasterio.open(tci_tile) as src:
    tci_image, tfs = riomask.mask(
        src,
        aoi.to_crs(src.crs).geometry,
        crop=True,
        all_touched=False,
    )

# NDMI; its metadata, adjusted to the cropped window, is kept for the outputs.
with rasterio.open(NDMI_PATH) as src:
    ndmi, tfs = riomask.mask(
        src,
        aoi.to_crs(src.crs).geometry,
        crop=True,
        all_touched=False,
    )
    meta = src.meta
    meta.update(transform=tfs, width=ndmi.shape[-1], height=ndmi.shape[-2])

# NDVI; only the pixel values are needed.
with rasterio.open(NDVI_PATH) as src:
    ndvi, _ = riomask.mask(
        src,
        aoi.to_crs(src.crs).geometry,
        crop=True,
        all_touched=False,
    )

In [None]:
# class_name -> [[(ndmi_range), (ndvi_range)], [(ndmi_range), (ndvi_range)], ... ]
# Each range is a (lower, upper) pair; the classification loop treats both
# bounds as inclusive. A pixel belongs to a class when its NDMI and NDVI fall
# inside both ranges of any one pair listed for that class.

class_names = {
    
    "no water stress": [
        [(-0.6, -0.4), (0.1, 0.2)],
        [(-0.4, -0.2), (0.1, 0.3)],
        [(-0.2, 0.0), (0.1, 0.3)],
        [(0.0, 0.2), (0.1, 0.4)],
        [(0.2, 0.4), (0.1, 0.4)],
        [(0.4, 0.6), (0.7, 1.0)],
        [(0.6, 0.8), (0.6, 1.0)],
        [(0.8, 1.0), (0.5, 1.0)],
        
    ],
    
    "low water stress": [
        [(-0.6, -0.4), (0.2, 0.3)],
        [(-0.4, -0.2), (0.3, 0.4)],
        [(-0.2, 0.0), (0.3, 0.5)],
        [(0.0, 0.2), (0.4, 0.7)],
        [(0.2, 0.4), (0.4, 0.5)],
        [(0.4, 0.6), (0.4, 0.7)],
        [(0.6, 0.8), (0.3, 0.6)],
        [(0.8, 1.0), (0.1, 0.5)],
    ],
    
    "high water stress" : [
        [(-0.6, -0.4), (0.3, 0.6)],
        [(-0.4, -0.2), (0.4, 0.6)],
        [(-0.2, 0.0), (0.5, 0.7)],
        [(0.0, 0.2), (0.7, 0.9)],
        [(0.2, 0.4), (0.5, 0.9)],
        [(0.4, 0.6), (0.1, 0.4)],
        [(0.6, 0.8), (0.1, 0.3)],
    ],
    
    "drought": [
        [(-0.4, -0.2), (0.6, 1.0)],
        [(-0.6, -0.4), (0.6, 1.0)],
        [(-0.2, 0.0), (0.7, 1.0)],
        [(0.0, 0.2), (0.9, 1.0)],
        [(0.2, 0.4), (0.9, 1.0)]
    ]
}

NUM_CLASSES = len(class_names)
# Evenly spaced fractions 0, 1/N, ..., (N-1)/N.
# NOTE(review): not referenced by the code visible in this notebook section.
arr = np.array(range(0, NUM_CLASSES)) / NUM_CLASSES

# One RGB triple per class, in the same order as the keys of class_names
# (the classification loop looks them up by enumeration index).
colors = [
    (138, 206, 126),
    (48, 145, 67),
    (255, 218, 102),
    (182, 10, 28),
]

labels = []

# RGB canvas with the spatial shape of the first NDMI band; pixels that match
# no class keep the value (0, 0, 0).
mask = np.zeros((ndmi[0].shape[-2], ndmi[0].shape[-1], 3)).astype(np.uint8)
for idx, (name, values) in enumerate(class_names.items()):
    class_area = 0
    for (ndmi_low, ndmi_high), (ndvi_low, ndvi_high) in values:
        # Pixels whose NDMI and NDVI both fall inside the inclusive ranges.
        within_ndmi = (ndmi[0] >= ndmi_low) & (ndmi[0] <= ndmi_high)
        within_ndvi = (ndvi[0] >= ndvi_low) & (ndvi[0] <= ndvi_high)
        selected = within_ndmi & within_ndvi

        # The pixel count is scaled by 1e-4.
        # NOTE(review): presumably converts 10 m pixels to km^2 — confirm.
        class_area += np.where(selected, 1, 0).sum() / 10**4
        mask[selected] = colors[idx]

    labels.append({
        "color": ",".join(str(int(channel)) for channel in colors[idx]),
        "name": name,
        "area": round(class_area, 3),
    })

In [None]:
# Serialise the label list so it can be stored as a raster tag. The isinstance
# guard keeps this cell idempotent: without it, re-running the cell would
# JSON-encode the already-encoded string a second time.
if not isinstance(labels, str):
    labels = json.dumps(labels)
# The RGB mask is written as float32.
# NOTE(review): presumably to match the dtype of the NDMI-derived profile — confirm.
mask = mask.astype(np.float32)
labels

In [None]:
# Reuse the cropped NDMI profile for the two 3-band RGB outputs.
meta.update({
    'count': 3,
    'nodata': 0,
    'compress': 'lzw',
    'photometric': 'RGB',
})

result_name = f"moisture_anomaly_{START_DATE}_{END_DATE}.tif"
colored_tif = os.path.join(OUTPUT_FOLDER, result_name)
tci_tif = os.path.join(OUTPUT_FOLDER, f"tci_tile_{START_DATE}_{END_DATE}.tif")

# Both files carry the same set of tags.
output_tags = {
    'start_date': START_DATE,
    'end_date': END_DATE,
    'request_id': REQUEST_ID,
    'labels': labels,
    'name': NAME,
}
band_count = mask.shape[-1]

# Classified map: `mask` is (rows, cols, band), written one band at a time.
with rasterio.open(colored_tif, 'w', **meta) as dst:
    dst.update_tags(**output_tags)
    for band_idx in range(band_count):
        dst.write(mask[:, :, band_idx], indexes=band_idx + 1)

# True-colour crop: `tci_image` is indexed band-first.
with rasterio.open(tci_tif, 'w', **meta) as dst:
    tci_image = tci_image.astype(np.float32)
    dst.update_tags(**output_tags)
    for band_idx in range(band_count):
        dst.write(tci_image[band_idx, :, :], indexes=band_idx + 1)