In [None]:
import os
import shutil
import time
import numpy as np
import geojson
import geopandas as gpd
import shapely.wkt
import rasterio
import rasterio.mask
import re 
from datetime import datetime
from pathlib import Path
from geojson import Feature
from sentinel2download.downloader import Sentinel2Downloader
from sentinel2download.overlap import Sentinel2Overlap

from code.index_research import calculate_ndvi
from code.plant_stress import PlantStress
from code.utils import stitch_tiles

# Suppress every Python warning for the rest of the session, including
# warnings raised by the imported libraries.
import warnings
warnings.filterwarnings('ignore')
# Lower the root logger's threshold to INFO; `logger` is the root logger
# because getLogger() is called without a name.
import logging
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [None]:
# Run parameters are injected through environment variables.
# A variable that is not set yields None here; nothing is validated in this cell.
REQUEST_ID = os.environ.get('REQUEST_ID')
START_DATE = os.environ.get('START_DATE')
END_DATE = os.environ.get('END_DATE')
# AOI is parsed as WKT text by the next cell.
AOI = os.environ.get('AOI')
SENTINEL2_GOOGLE_API_KEY = os.environ.get('SENTINEL2_GOOGLE_API_KEY')
# Note the differing names: the cache folder is read from SENTINEL2_CACHE.
SATELLITE_CACHE_FOLDER = os.environ.get('SENTINEL2_CACHE')
OUTPUT_FOLDER = os.environ.get('OUTPUT_FOLDER')

In [None]:
# CRS attached to the AOI geometry below (it is assigned, not reprojected).
default_crs = 'EPSG:4326'

# AOI comes from the environment as WKT text; parse it into a shapely geometry.
polygon_aoi = shapely.wkt.loads(AOI)
# Relative file name prefixed with the current timestamp, so it differs between runs
# and the file is created in the working directory.
aoi_filename = f"{time.time()}_aoi.geojson"
# Single-row GeoDataFrame holding the AOI in a "geometry" column.
aoi_gdf = gpd.GeoDataFrame(gpd.GeoSeries([polygon_aoi]), columns=["geometry"], crs=default_crs)
# Persist the AOI as GeoJSON; later cells pass this path to Sentinel2Overlap
# and read it back with gpd.read_file.
aoi_gdf.to_file(aoi_filename, driver="GeoJSON")

In [None]:
# Working directory at the time this cell runs; the NDVI raster path is built from it.
BASE = os.getcwd()
# Band identifiers passed to the downloader as `bands=`.
BANDS = {'B04', 'B08', 'CLD'}
# A product is rejected when the share of nodata pixels inside the AOI is >= this value (percent).
NODATA_PIXEL_PERCENTAGE = 10.0
# Value placed into CONSTRAINTS for the download search.
# NOTE(review): presumably a product-level cloud-cover filter — confirm in sentinel2download.
SEARCH_CLOUDY_PIXEL_PERCENTAGE = 50.0
# A product is rejected when the share of cloudy pixels inside the AOI is >= this value (percent).
AOI_CLOUDY_PIXEL_PERCENTAGE = 15.0
# Passed to the downloader as `constraints=`.
CONSTRAINTS = {'CLOUDY_PIXEL_PERCENTAGE': SEARCH_CLOUDY_PIXEL_PERCENTAGE}
# First positional argument of the downloader's download() call.
PRODUCT_TYPE = 'L2A'
# Name handed to PlantStress.segment_field in the last cell.
NAME = 'Field anomalies'
# Make sure the output folder exists; an already existing folder is not an error.
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

In [None]:
# Empty per-band accumulators; none of them is referenced by the cells shown in this notebook.
b04_tiles, b08_tiles, tci_tiles = [], [], []
# Sentinel2Overlap receives the path of the AOI GeoJSON written earlier.
# NOTE(review): presumably it finds the Sentinel-2 tiles intersecting the AOI — confirm in sentinel2download.
s2overlap = Sentinel2Overlap(aoi_path=aoi_filename)
# The result is later accessed through `.Name.values` to obtain the tile names.
overlap_tiles = s2overlap.overlap_with_geometry()

In [None]:
def dump_no_data_geojson(polygon, geojson_path):
    """Write a "No data" GeoJSON feature for ``polygon`` to ``geojson_path``.

    The feature carries a red style and a "No data" label in its properties,
    plus top-level ``start_date``, ``end_date`` and ``name`` members. The dates
    are taken from the module-level START_DATE / END_DATE.

    Args:
        polygon: geometry used as the feature geometry.
        geojson_path: destination file path; the file is overwritten.
    """
    label = 'No data'
    feature = Feature(
        geometry=polygon,
        properties={'label': label, 'style': {'color': 'red'}},
    )
    # Extra members are stored on the feature itself, not inside "properties".
    extra_members = (
        ('start_date', START_DATE),
        ('end_date', END_DATE),
        ('name', label),
    )
    for member, value in extra_members:
        feature[member] = value

    with open(geojson_path, 'w') as out_file:
        geojson.dump(feature, out_file)

def check_nodata_percentage_crop(tile_path, 
                                 aoi, 
                                 nodata_percentage_limit, 
                                 nodata):
    """Tell whether the share of nodata pixels inside the AOI reaches the limit.

    The first band of ``tile_path`` is cropped to the first geometry of ``aoi``
    (reprojected to the raster CRS). Only pixels left unmasked by the crop are
    counted.

    Args:
        tile_path: path of the raster to inspect.
        aoi: GeoDataFrame-like object; its geometry at label 0 is used.
        nodata_percentage_limit: threshold in percent.
        nodata: pixel value that counts as "no data".

    Returns:
        True when the nodata percentage (rounded to 2 decimals) is >= the
        limit, or when no unmasked pixel falls inside the AOI at all (there is
        nothing usable in that case); False otherwise.
    """
    with rasterio.open(tile_path) as src:
        polygon = aoi.to_crs(src.meta['crs']).geometry[0]
        band, _ = rasterio.mask.mask(src, [polygon], crop=True, filled=False, indexes=1)

    # The cropped band is already in memory, so the dataset can be closed before counting.
    valid_pixels = band.compressed()
    if valid_pixels.size == 0:
        # Previously this case divided by zero. An AOI without a single
        # unmasked pixel holds no usable data, so report it as exceeding the limit.
        return True

    nodata_count = np.count_nonzero(valid_pixels == nodata)
    nodata_percentage = round(nodata_count / valid_pixels.size * 100, 2)
    return bool(nodata_percentage >= nodata_percentage_limit)

def check_cloud_percentage_crop(tile_path, 
                                aoi, 
                                cloud_percentage_limit,
                                cloud_probability=50):
    """Tell whether the share of cloudy pixels inside the AOI reaches the limit.

    The first band of ``tile_path`` is cropped to the first geometry of ``aoi``
    (reprojected to the raster CRS). A pixel counts as cloudy when its value is
    >= ``cloud_probability``. Only pixels left unmasked by the crop are counted.

    Args:
        tile_path: path of the cloud raster to inspect.
        aoi: GeoDataFrame-like object; its geometry at label 0 is used.
        cloud_percentage_limit: threshold in percent.
        cloud_probability: minimal pixel value regarded as cloud.

    Returns:
        True when the cloudy percentage (rounded to 2 decimals) is >= the
        limit; False otherwise, including when no unmasked pixel falls inside
        the AOI (no cloud can be measured in that case).
    """
    with rasterio.open(tile_path) as src:
        polygon = aoi.to_crs(src.meta['crs']).geometry[0]
        band, _ = rasterio.mask.mask(src, [polygon], crop=True, filled=False, indexes=1)

    # The cropped band is already in memory, so the dataset can be closed before counting.
    valid_pixels = band.compressed()
    if valid_pixels.size == 0:
        # Previously this case divided by zero. With nothing to measure there
        # is no evidence of cloud, so the cloud check does not reject.
        return False

    cloud_count = np.count_nonzero(valid_pixels >= cloud_probability)
    cloud_percentage = round(cloud_count / valid_pixels.size * 100, 2)
    return bool(cloud_percentage >= cloud_percentage_limit)

def check_tile_validity(tile_folder, aoi, cloud_percentage_limit, nodata_percentage_limit):
    """Decide whether a downloaded product folder must be skipped.

    Every ``.jp2`` file in ``tile_folder`` is inspected until one fails:
    files whose path contains ``MSK_CLDPRB_20m`` go through the cloud check,
    every other ``.jp2`` file goes through the nodata check with nodata value 0.
    Files with another suffix are not inspected.

    Returns:
        ``(skip_tile, band_paths)`` where ``skip_tile`` is True as soon as one
        check fails, and ``band_paths`` lists every entry of the folder
        (including files that are not ``.jp2``).
    """
    band_paths = [os.path.join(tile_folder, entry) for entry in os.listdir(tile_folder)]

    for candidate in band_paths:
        if Path(candidate).suffix != '.jp2':
            continue

        if "MSK_CLDPRB_20m" in candidate:
            rejected = check_cloud_percentage_crop(candidate, aoi, cloud_percentage_limit)
        else:
            rejected = check_nodata_percentage_crop(candidate, aoi, nodata_percentage_limit, 0)

        if rejected:
            # The first failing file is enough; remaining files are not opened.
            return True, band_paths

    return False, band_paths

def validate_tile_downloads(loaded, tile, loadings, aoi, cloud_percentage_limit, nodata_percentage_limit):
    """Validate the downloads of one tile and record the accepted files.

    The parent folder of the first element of every ``loaded`` entry is treated
    as a product folder. Each folder is checked with ``check_tile_validity``;
    a rejected folder is deleted from disk, an accepted folder contributes all
    of its file paths.

    Args:
        loaded: downloader result for the tile; each entry is indexed with ``[0]`` to get a path.
        tile: tile name, used as the key in ``loadings`` and in messages.
        loadings: dict that is updated in place and also returned.
        aoi: AOI forwarded to the validity checks.
        cloud_percentage_limit: cloud threshold forwarded to the checks.
        nodata_percentage_limit: nodata threshold forwarded to the checks.

    Returns:
        The same ``loadings`` object; ``loadings[tile]`` is set only when at
        least one folder was accepted.
    """
    print(f"Validating images for tile: {tile}...")
    if not loaded:
        print(f"Images for tile {tile} were not loaded!")
        return loadings

    product_folders = {Path(item[0]).parent for item in loaded}
    kept_paths = []
    for product_folder in product_folders:
        rejected, folder_paths = check_tile_validity(
            product_folder, aoi, cloud_percentage_limit, nodata_percentage_limit
        )
        if not rejected:
            kept_paths.extend(folder_paths)
            continue
        # Rejected products are removed from the cache folder on disk.
        shutil.rmtree(product_folder)

    if not kept_paths:
        print(f"Tile images didn't match nodata/cloud constraints, so they were removed")
    else:
        loadings[tile] = kept_paths
    print(f"Validating images for tile {tile} finished")
    return loadings


In [None]:
def load_images(tiles, start_date, end_date, aoi):
    """Download and validate imagery for every tile, one tile at a time.

    A single downloader is created from SENTINEL2_GOOGLE_API_KEY. For each tile
    the module-level PRODUCT_TYPE, BANDS, CONSTRAINTS and SATELLITE_CACHE_FOLDER
    are passed to ``download``; the result is then validated against the AOI
    cloud and nodata limits.

    Args:
        tiles: iterable of tile names.
        start_date: start of the search period, forwarded to the downloader.
        end_date: end of the search period, forwarded to the downloader.
        aoi: AOI forwarded to the validation step.

    Returns:
        dict mapping tile name to the list of accepted file paths; tiles
        without accepted files are absent.
    """
    downloader = Sentinel2Downloader(SENTINEL2_GOOGLE_API_KEY)
    accepted = {}

    for tile in tiles:
        print(f"Loading images for tile: {tile}...")
        download_options = dict(
            start_date=start_date,
            end_date=end_date,
            output_dir=SATELLITE_CACHE_FOLDER,
            bands=BANDS,
            constraints=CONSTRAINTS,
        )
        # The downloader is called with a one-element tile list.
        loaded = downloader.download(PRODUCT_TYPE, [tile], **download_options)
        print(f"Loading images for tile {tile} finished")
        accepted = validate_tile_downloads(
            loaded,
            tile,
            accepted,
            aoi,
            AOI_CLOUDY_PIXEL_PERCENTAGE,
            NODATA_PIXEL_PERCENTAGE,
        )
    return accepted

In [None]:
# Download and validate imagery for every tile name listed in the overlap result.
loadings = load_images(overlap_tiles.Name.values, START_DATE, END_DATE, aoi_gdf)

In [None]:
# Show the accepted downloads (tile name -> list of file paths) as the cell output.
loadings

In [None]:
def filter_by_date(loadings):
    """Keep only the most recent B04 / B08 10 m files of every tile.

    For every path the date is parsed from the first ``_<digits>T<digits>_``
    token. Per tile, the latest date is determined and the ``B04_10m.jp2`` /
    ``B08_10m.jp2`` files whose own parsed date equals it are collected.

    Comparing parsed dates (instead of testing whether the date text occurs
    anywhere in the path) prevents an older product from being picked up just
    because a later timestamp in its folder name happens to contain the latest
    date.

    Args:
        loadings: mapping of tile name to an iterable of file paths.

    Returns:
        dict with the keys ``'BO4'`` and ``'BO8'`` (spelled with the letter
        "O", which the rest of the notebook relies on), each holding a list of
        the selected paths.
    """
    date_token = re.compile(r"_(\d+)T\d+_")

    def _acquisition_date(path):
        """Return the date from the first date token of ``path``, or None when there is no token."""
        match = date_token.search(str(path))
        if match is None:
            return None
        return datetime.strptime(match.group(1), '%Y%m%d')

    filtered = {
        'BO4': [],
        'BO8': []
    }
    for tile, items in loadings.items():
        try:
            # Parse everything first so that a failure cannot leave the tile half-added.
            dated_paths = [(path, _acquisition_date(path)) for path in items]
            known_dates = [date for _, date in dated_paths if date is not None]
            if not known_dates:
                print(f"Error for {tile}: no path with a date token found")
                continue
            last_date = max(known_dates)

            tile_b04, tile_b08 = [], []
            for path, date in dated_paths:
                # Paths without a date token (date is None) are ignored here.
                if date != last_date:
                    continue
                path_text = str(path)
                if 'B04_10m.jp2' in path_text:
                    tile_b04.append(path)
                if 'B08_10m.jp2' in path_text:
                    tile_b08.append(path)
        except Exception as ex:
            # Deliberate best effort: a tile that cannot be processed is reported and skipped.
            print(f"Error for {tile}: {str(ex)}")
            continue
        filtered['BO4'] += tile_b04
        filtered['BO8'] += tile_b08
    return filtered

In [None]:
# Reduce the downloads to the latest B04 / B08 files per tile.
filtered = filter_by_date(loadings)

In [None]:
# Show the selected band files (keys 'BO4' and 'BO8') as the cell output.
filtered

In [None]:
# Both bands are required. When either list is empty, write a "No data"
# GeoJSON for the AOI into the output folder and stop the notebook.
if not all(filtered[band_key] for band_key in ('BO4', 'BO8')):
    geojson_path = os.path.join(
        OUTPUT_FOLDER,
        f"{START_DATE}_{END_DATE}_no_data.geojson",
    )
    dump_no_data_geojson(aoi_gdf.geometry[0], geojson_path)
    raise ValueError("Images not loaded for given AOI. Change dates, constraints")

In [None]:
# Stitch when EITHER band has more than one tile. The decision used to look at
# B04 only, so with one B04 tile and several B08 tiles the additional B08 tiles
# were silently dropped. Both bands are always handled the same way so that
# they end up in the same kind of file.
# The merged raster is written next to the first input tile, with the '.jp2'
# suffix replaced by '_merged.tif'.
if len(filtered['BO4']) > 1 or len(filtered['BO8']) > 1:
    b04_tile = stitch_tiles(filtered['BO4'], filtered['BO4'][0].replace('.jp2', '_merged.tif'))
    b08_tile = stitch_tiles(filtered['BO8'], filtered['BO8'][0].replace('.jp2', '_merged.tif'))
else:
    # Exactly one tile per band: use the downloaded files directly.
    b04_tile = filtered['BO4'][0]
    b08_tile = filtered['BO8'][0]

In [None]:
# The NDVI raster is stored in the working directory and named by the date range only.
ndvi_path = os.path.join(BASE, f'{START_DATE}_{END_DATE}_ndvi.tif')
# An existing file is reused instead of being recomputed.
# NOTE(review): the file name does not include the AOI, so a file left by an
# earlier run with the same dates but a different AOI would be reused — confirm this is intended.
if not os.path.exists(ndvi_path):
    calculate_ndvi(b04_tile, b08_tile, out_path=ndvi_path, nodata=np.nan)

In [None]:
# Parameters passed to PlantStress in the next cell as min_ndvi, noise_z_score
# and anomaly_z_score; their exact meaning is defined in code.plant_stress.
min_ndvi = 0.3
z_score = 5
z_score_anom = 1
# Label -> three-integer tuple; not referenced by the cells shown in this notebook.
# NOTE(review): presumably RGB legend colors — confirm.
colors = {"Normal Growth": (0, 0, 0), "Anomaly": (182, 10, 28)}

# Read the AOI back from the GeoJSON file written earlier in the notebook.
field = gpd.read_file(aoi_filename)

In [None]:
# Build the analyzer from the thresholds defined in the previous cell.
ps = PlantStress(min_ndvi=min_ndvi, noise_z_score=z_score, anomaly_z_score=z_score_anom)

In [None]:
# Result raster path inside the output folder, named by the date range.
raster_path = os.path.join(OUTPUT_FOLDER, f'{START_DATE}_{END_DATE}_field.tif')
# Run the segmentation for the AOI on the NDVI raster; the request id and date
# range are forwarded unchanged.
ps.segment_field(NAME, field, ndvi_path, raster_path, START_DATE, END_DATE, REQUEST_ID)