# Sentinel-2 L2A band download and clipping for a GeoJSON field

In [1]:
import helper
import geojson
import satsearch
import intake
import numpy as np
import geopandas as gpd
import rioxarray 
import datetime as dt
import rasterio

import json
import geojson
import numpy
from osgeo import osr
import shapely
import geopandas as gpd
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr

def no_data_check(ds_array):
    """Return the percentage of pixels in ``ds_array`` that are not 0.

    Pixels equal to 0 are counted as no-data; every other value (NaN
    included) counts as valid.

    Args:
        ds_array: array-like of any rank. A 1-D array, e.g. what
            ``np.squeeze`` yields for a one-pixel-high subset, is handled
            just like a 2-D one.

    Returns:
        Percentage of valid pixels in [0, 100]. An empty array has no valid
        pixels and yields 0.0 rather than dividing by zero.
    """
    arr = np.asarray(ds_array)
    if arr.size == 0:
        return 0.0
    n_nodata = np.count_nonzero(arr == 0)
    return (1 - n_nodata / arr.size) * 100

def scl_from_cog(tile_item):
    """Open the ``SCL`` asset of ``tile_item`` lazily via intake's ``to_dask()``.

    The asset is chunked as 1 band x 2048 x 2048 pixels.
    """
    scl_source = tile_item.SCL(chunks={'band': 1, 'x': 2048, 'y': 2048})
    return scl_source.to_dask()

def get_tile_crs(tile_item):
    """Return the value of the ``proj:epsg`` entry in ``tile_item.metadata``."""
    return tile_item.metadata['proj:epsg']

def subset_cog(aoi_geojson, cog_ds, tile_crs):
    """Crop ``cog_ds`` to the bounding box of the AOI stored in ``aoi_geojson``.

    NOTE: this cell defines ``subset_cog`` twice; at runtime the later
    definition replaces this one.

    The AOI is read with geopandas and reprojected to ``EPSG:<tile_crs>`` so
    its bounds are in the same coordinates as ``cog_ds``.

    Args:
        aoi_geojson: path (or anything ``gpd.read_file`` accepts) to the AOI.
        cog_ds: xarray object with ``x``/``y`` coordinates in ``tile_crs``.
        tile_crs: EPSG code of ``cog_ds``.

    Returns:
        ``cog_ds`` label-sliced to the AOI bounding box. The ``y`` slice runs
        from max to min, i.e. it assumes ``y`` decreases along the array;
        bounds are truncated with ``int`` before slicing.
    """
    field_data = gpd.read_file(aoi_geojson).to_crs(epsg=tile_crs)
    # total_bounds is one (minx, miny, maxx, maxy) box over all features, so
    # multi-feature files work; int() on a per-feature ``bounds`` column only
    # succeeds when the file holds exactly one feature.
    minx, miny, maxx, maxy = field_data.total_bounds
    return cog_ds.sel(
        y=slice(int(maxy), int(miny)),
        x=slice(int(minx), int(maxx)),
    )

def cloud_shadow_check(ds_array):
    """Return the percentage of pixels in ``ds_array`` NOT flagged as cloud/shadow.

    A pixel is flagged when its value equals 3 or is >= 8. In the Sentinel-2
    L2A Scene Classification (SCL) legend these are cloud shadow (3), cloud
    medium/high probability (8, 9) and thin cirrus (10). Snow/ice (11) is
    also caught by the ``>= 8`` threshold.

    Args:
        ds_array: array-like of any rank (1-D input from ``np.squeeze`` is
            handled like 2-D input).

    Returns:
        Percentage of unflagged pixels in [0, 100]. An empty array yields 0.0
        rather than dividing by zero.
    """
    arr = np.asarray(ds_array)
    if arr.size == 0:
        return 0.0
    flagged = (arr == 3) | (arr >= 8)
    return (1 - np.count_nonzero(flagged) / arr.size) * 100

def bands_from_cog(bands, tile_item):
    """Stack the requested ``bands`` of ``tile_item`` into an xarray Dataset.

    The stacked source is opened lazily with 1 x 2048 x 2048 chunks. Its
    ``band`` coordinate is relabelled with the requested names, and it is
    split along ``band`` so each band becomes its own data variable.
    """
    chunking = {'band': 1, 'x': 2048, 'y': 2048}
    stacked = tile_item.stack_bands(bands)(chunks=chunking).to_dask()
    stacked['band'] = bands
    return stacked.to_dataset(dim='band')

def subset_cog(aoi_geojson, cog_ds, tile_crs):
    """Crop ``cog_ds`` to the bounding box of the AOI stored in ``aoi_geojson``.

    NOTE: an earlier definition of ``subset_cog`` in this cell is shadowed by
    this one.

    The AOI is read with geopandas and reprojected to ``EPSG:<tile_crs>`` so
    its bounds are in the same coordinates as ``cog_ds``.

    Args:
        aoi_geojson: path (or anything ``gpd.read_file`` accepts) to the AOI.
        cog_ds: xarray object with ``x``/``y`` coordinates in ``tile_crs``.
        tile_crs: EPSG code of ``cog_ds``.

    Returns:
        ``cog_ds`` label-sliced to the AOI bounding box. The ``y`` slice runs
        from max to min, i.e. it assumes ``y`` decreases along the array;
        bounds are truncated with ``int`` before slicing.
    """
    field_data = gpd.read_file(aoi_geojson).to_crs(epsg=tile_crs)
    # total_bounds is one (minx, miny, maxx, maxy) box over all features, so
    # multi-feature files work; int() on a per-feature ``bounds`` column only
    # succeeds when the file holds exactly one feature.
    minx, miny, maxx, maxy = field_data.total_bounds
    return cog_ds.sel(
        y=slice(int(maxy), int(miny)),
        x=slice(int(minx), int(maxx)),
    )

def reproject_cog(cog_ds, cog_crs, target_crs):
    """Tag ``cog_ds`` as ``EPSG:<cog_crs>`` and reproject it to ``EPSG:<target_crs>``.

    Uses the rioxarray ``.rio`` accessor (``write_crs`` then ``reproject``).
    """
    source_crs = f"EPSG:{cog_crs}"
    destination_crs = f"EPSG:{target_crs}"
    tagged = cog_ds.rio.write_crs(source_crs)
    return tagged.rio.reproject(destination_crs)
    
def rescale_data(scale_factor, ds):
    """Resample ``ds`` onto a grid with ``scale_factor`` times as many x/y samples.

    New coordinates are evenly spaced between the first and last existing
    coordinate of each axis. Values come from ``ds.interp`` (linear by
    default).

    Args:
        scale_factor: multiplier for the number of samples along x and y.
            Fractional factors are allowed; the sample count is rounded to
            the nearest integer, since ``np.linspace`` requires an int.
        ds: xarray object with ``x`` and ``y`` coordinates.

    Returns:
        The interpolated object.
    """

    def _resampled(coord):
        # Evenly spaced replacement for one coordinate axis.
        values = np.asarray(coord)
        num = int(round(values.size * scale_factor))
        # Span first -> last (not min -> max) so a descending axis, such as
        # the y of a north-up raster, keeps its orientation.
        return np.linspace(values[0], values[-1], num=num)

    return ds.interp(x=_resampled(ds.x), y=_resampled(ds.y))

def clip_to_field(aoi_geojson, cog_ds):
    """Mask ``cog_ds`` to the AOI geometries read from ``aoi_geojson``.

    ``rio.clip`` is called with ``drop=False``, so the array keeps its full
    extent; the geometries are passed with the CRS read from the file. Any
    value equal to -9999 is then replaced with NaN.
    """
    field = gpd.read_file(aoi_geojson)
    masked = cog_ds.rio.clip(field.geometry, field.crs, drop=False)
    # NOTE(review): assumes -9999 is the fill value rio.clip writes outside
    # the AOI; confirm against the dataset's nodata.
    return masked.where(masked != -9999, np.nan)

class Sentinel2:
    """Find and download Sentinel-2 L2A COG data covering a GeoJSON field.

    Scenes are searched in the Element 84 Earth Search STAC API (v0,
    collection ``sentinel-s2-l2a-cogs``) via ``satsearch``. Returned items
    are opened with ``intake``.
    """

    def __init__(self, aoi, start_date, end_date, field_name):
        """Store the search parameters and precompute the AOI bounding box.

        Args:
            aoi: path to the field GeoJSON. It is normalised with
                ``helper.normalize_geojson`` for the bbox computation, and the
                raw path is kept as ``field_path`` for the geopandas reads.
            start_date: first date of the search window (inclusive).
            end_date: end of the search window (exclusive; see
                ``get_dates_of_available_images``).
            field_name: name of the field; stored, not used by this class.
        """
        self.aoi = helper.normalize_geojson(aoi)
        self.field_path = aoi
        self.field_name = field_name
        self.start_date = start_date
        self.end_date = end_date
        # Output directory name; not used by the methods of this class.
        self.main_dir = 's2_generated'

        self.bbox = self._calc_aoi_bbox()

    def _query_cogs(self, date):
        """Search ``sentinel-s2-l2a-cogs`` items on ``date`` intersecting ``self.bbox``."""
        date_string = date.strftime("%Y-%m-%d")
        return satsearch.Search.search(
            url='https://earth-search.aws.element84.com/v0',
            collections=['sentinel-s2-l2a-cogs'],
            datetime=date_string,
            bbox=self.bbox,
            sort=['<datetime'],
        )

    def dates_and_cloud_check(self, dates, check_clouds=True, min_percent=90):
        """Return the subset of ``dates`` that have usable imagery over the AOI.

        A date is kept when the search returns at least one item. If
        ``check_clouds`` is true, the item with the highest share of non-zero
        SCL pixels inside the AOI bbox must also have more than
        ``min_percent`` percent valid pixels AND more than ``min_percent``
        percent of pixels not flagged by ``cloud_shadow_check``.

        Args:
            dates: iterable of datetimes to check.
            check_clouds: if False, keep every date that has at least one item.
            min_percent: strict lower bound (in percent) for both checks.

        Returns:
            List of the kept dates, in input order.
        """
        valid_dates = []

        for date in dates:
            results = self._query_cogs(date)
            if not results.found():
                continue

            if not check_clouds:
                valid_dates.append(date)
                continue

            catalog = intake.open_stac_item_collection(results.items())

            valid_percent = []
            clear_percent = []  # % of pixels NOT flagged as cloud/shadow
            for key in catalog:
                tile_item = catalog[key]
                ds_scl = scl_from_cog(tile_item=tile_item)
                cog_crs = get_tile_crs(tile_item=tile_item)
                # Materialise the SCL subset once and reuse it for both checks.
                scl_field = np.squeeze(
                    np.asarray(subset_cog(self.field_path, ds_scl, cog_crs))
                )
                valid_percent.append(no_data_check(scl_field))
                clear_percent.append(cloud_shadow_check(scl_field))

            if not valid_percent:
                # No items to score; np.argmax would fail on an empty list.
                continue

            best = int(np.argmax(valid_percent))
            if valid_percent[best] > min_percent and clear_percent[best] > min_percent:
                valid_dates.append(date)

        return valid_dates

    def _calc_aoi_bbox(self):
        """Return ``[min_x, min_y, max_x, max_y]`` over every coordinate of ``self.aoi``."""
        aoi_coords = list(geojson.utils.coords(self.aoi))
        xs = [coord[0] for coord in aoi_coords]
        ys = [coord[1] for coord in aoi_coords]
        return [min(xs), min(ys), max(xs), max(ys)]

    def get_dates_of_available_images(self, check_clouds=True):
        """Check each day from ``start_date`` up to (excluding) ``end_date``.

        Returns the dates kept by ``dates_and_cloud_check``.
        """
        n_days = (self.end_date - self.start_date).days
        date_list = [self.start_date + dt.timedelta(days=offset) for offset in range(n_days)]
        return self.dates_and_cloud_check(date_list, check_clouds)

    def get_all_bands(self, date):
        """Fetch, clip and reproject every supported band over the AOI for ``date``.

        For each band, the first returned item is processed in four steps:
        subset to the AOI bbox in the tile CRS, reproject to EPSG:4326, mask
        to the AOI geometry, then reproject to EPSG:3857.

        Returns:
            Dict mapping band name to the resulting xarray Dataset.

        Raises:
            ValueError: if the search finds no item for ``date``.
        """
        band_list = ['red', 'green', 'blue', 'B05', 'B06', 'B07', 'nir', 'B8A', 'B09', 'swir16', 'swir22']

        results = self._query_cogs(date)
        if not results.found():
            # Raise instead of exit() so callers (and the notebook kernel)
            # can recover from a date without data.
            raise ValueError(f"no data for this date: {date:%Y-%m-%d}")

        catalog = intake.open_stac_item_collection(results.items())
        # NOTE(review): always uses the first item, whereas
        # dates_and_cloud_check scores every item and keeps the one with the
        # most valid SCL pixels; an AOI near a tile edge may be better covered
        # by another item.
        tile_item = catalog[list(catalog)[0]]
        cog_crs = get_tile_crs(tile_item=tile_item)

        field_crs = 4326   # CRS the bbox subset is reprojected to before clipping
        target_crs = 3857  # CRS of the returned data

        final_bands = {}
        for band_name in band_list:
            band = bands_from_cog(bands=[band_name], tile_item=tile_item)
            band_bbox = subset_cog(self.field_path, band, cog_crs)
            band_bbox_reproject = reproject_cog(band_bbox, cog_crs, target_crs=field_crs)
            band_field = clip_to_field(self.field_path, band_bbox_reproject)
            final_bands[band_name] = reproject_cog(band_field, field_crs, target_crs)

        return final_bands

In [2]:
def get_s2_data_for_geojson(geojson_path, query_date, field_name):
    """Download the Sentinel-2 bands covering a GeoJSON field for one date.

    Args:
        geojson_path: path to the field's GeoJSON file.
        query_date: datetime of the acquisition to fetch (the date for which
            a prediction is required).
        field_name: name of the field the data is downloaded for.

    Returns:
        Dict keyed by band name, exactly as returned by
        ``Sentinel2.get_all_bands`` for ``query_date``; each value is the
        band's clipped and reprojected xarray Dataset.
    """
    # The search window is the single day [query_date, query_date + 1 day).
    s2 = Sentinel2(
        aoi=geojson_path,
        start_date=query_date,
        end_date=query_date + dt.timedelta(days=1),
        field_name=field_name,
    )
    return s2.get_all_bands(query_date)


In [3]:
geojson_path = "8755.geojson"
query_date = dt.datetime.strptime("2019-07-01", "%Y-%m-%d")

# Download every band over the field for that date.
s2_data = get_s2_data_for_geojson(geojson_path, query_date, "test")


adsads dict_keys(['red', 'green', 'blue', 'B05', 'B06', 'B07', 'nir', 'B8A', 'B09', 'swir16', 'swir22'])


In [4]:
available_bands = s2_data.keys()
print(available_bands)

dict_keys(['red', 'green', 'blue', 'B05', 'B06', 'B07', 'nir', 'B8A', 'B09', 'swir16', 'swir22'])


In [5]:
band_list = ['red', 'green', 'blue', 'B05', 'B06', 'B07', 'nir', 'B8A', 'B09', 'swir16', 'swir22']

# Print the shape of each band's data variable.
for band_name in band_list:
    band_ds = s2_data[band_name]
    print(band_ds[band_name].shape)

(30, 38)
(30, 38)
(30, 38)
(14, 18)
(14, 18)
(14, 18)
(30, 38)
(14, 18)
(5, 6)
(14, 18)
(14, 18)


In [None]:
# These names must match the keys of s2_data (the band_list in
# Sentinel2.get_all_bands). 'B08' is not one of those keys, and indexing
# s2_data with it raises KeyError, so 'B8A' is used as in get_all_bands.
band_list = ['red', 'green', 'blue', 'B05', 'B06', 'B07', 'nir', 'B8A', 'B09', 'swir16', 'swir22']


In [None]:
# Show the full xarray object for every band.
for band in band_list:
    band_data = s2_data[band]
    print(band_data)