https://eo-learn.readthedocs.io/en/latest/examples/land-cover-map/SI_LULC_pipeline.html

In [None]:
# Firstly, some necessary imports

# Jupyter notebook related
# %reload_ext autoreload
# %autoreload 2
%matplotlib inline

# Built-in modules
import pickle
import sys
import os
import datetime
import itertools
from aenum import MultiValueEnum

# Basics of Python data handling and visualization
import numpy as np
np.random.seed(42)
import geopandas as gpd
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.colors import ListedColormap, BoundaryNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
from shapely.geometry import Polygon
from tqdm.auto import tqdm

# Imports from eo-learn and sentinelhub-py
from eolearn.core import EOTask, EOPatch, LinearWorkflow, FeatureType, OverwritePermission, \
    LoadTask, SaveTask, EOExecutor, ExtractBandsTask, MergeFeatureTask
from eolearn.io import SentinelHubInputTask, ExportToTiff
from eolearn.mask import AddMultiCloudMaskTask, AddValidDataMaskTask
from eolearn.geometry import VectorToRaster, PointSamplingTask, ErosionTask
from eolearn.features import LinearInterpolation, SimpleFilterTask, NormalizedDifferenceIndexTask
from sentinelhub import UtmZoneSplitter, BBox, CRS, DataCollection

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  ProcApiType("bool_mask", 'DN', 'UINT8', np.bool, FeatureType.MASK): [


In [None]:
# Read the ice-chart shapefile (it carries the sea ice concentration data) and reproject it to EPSG:32617
ice_gdf = gpd.read_file('data/cis_SGRDRHB_20200525T1800Z_pl_a.shp').to_crs('EPSG:32617')

# Rows with missing values are dropped (according to the original author these correspond to land);
# the remaining polygons are merged into a single geometry covering the water area
region = gpd.GeoDataFrame(geometry=[ice_gdf.dropna().geometry.unary_union], crs=ice_gdf.crs)

# region = region.buffer(500)

# Shape of the region of interest as a single geometry
region_shape = region.geometry.values[-1]

# Plot the region
region.plot()

# Print the extent of the region's bounding box
min_x, min_y, max_x, max_y = region_shape.bounds
print(f'Dimension of the area is {max_x - min_x:.0f} x {max_y - min_y:.0f} m2')

In [None]:
from sentinelhub import BBoxSplitter
# Split the region into bounding boxes using a (35, 35) split shape
bbox_splitter = BBoxSplitter([region_shape], region.crs, (35, 35))

bbox_list = np.array(bbox_splitter.get_bbox_list())
info_list = np.array(bbox_splitter.get_info_list())

# Tag every info dict with its position in the list
for position, tile_info in enumerate(info_list):
    tile_info['index'] = position

# Footprints and per-tile attributes of the EOPatches
geometry = [Polygon(tile_bbox.get_polygon()) for tile_bbox in bbox_list]
idxs = [entry['index'] for entry in info_list]
idxs_x = [entry['index_x'] for entry in info_list]
idxs_y = [entry['index_y'] for entry in info_list]

# One row per tile, in the same CRS as the region
tile_attributes = {'index': idxs, 'index_x': idxs_x, 'index_y': idxs_y}
gdf = gpd.GeoDataFrame(tile_attributes, geometry=geometry, crs=region.crs)

In [None]:
# Select a 3x3 area, identified by the ID of its central patch
ID = 294

# Collect the central patch together with its direct neighbours (a 3x3 neighbourhood):
# a tile qualifies when both of its grid indices are within 1 of the central tile's
patchIDs = [
    idx
    for idx, info in enumerate(info_list)
    if all(abs(info[axis] - info_list[ID][axis]) <= 1 for axis in ('index_x', 'index_y'))
]

# Warn when the neighbourhood is incomplete
if len(patchIDs) != 9:
    print('Warning! Use a different central patch ID, this one is on the border.')

# Reorder the patches (used for plotting later)
patchIDs = np.fliplr(np.array(patchIDs).reshape(3, 3)).T.ravel()

# Save the tile grid to a GeoPackage file
shapefile_name = './grid_sea_ice.gpkg'
gdf.to_file(shapefile_name, driver='GPKG')

# Figure: all tiles outlined in red, the region outlined in blue, every tile labelled with its index
fig, ax = plt.subplots(figsize=(40, 40))
gdf.plot(ax=ax, facecolor='w', edgecolor='r', alpha=0.5)
region.plot(ax=ax, facecolor='w', edgecolor='b', alpha=0.5)
ax.set_title('Selected 3x3 Tiles', fontsize=25)
for bbox, info in zip(bbox_list, info_list):
    centroid = bbox.geometry.centroid
    ax.text(centroid.x, centroid.y, info['index'], ha='center', va='center', size=18)

# Highlight the selected tiles in green
gdf.loc[gdf.index.isin(patchIDs)].plot(ax=ax, facecolor='g', edgecolor='r', alpha=0.5)

plt.axis('off');

In [None]:
class SentinelHubValidData:
    """
    Combine the `IS_DATA` mask with the cloud mask to define a valid-data mask.

    A pixel is valid when `eopatch.mask['IS_DATA']` is non-zero and
    `eopatch.mask['CLM']` (the SentinelHub cloud mask) is zero.
    """
    def __call__(self, eopatch):
        """Return a boolean array that is True where data is present and not flagged as cloud."""
        # The builtin `bool` is used because the `np.bool` alias was deprecated in
        # NumPy 1.20 and removed in later releases.
        has_data = eopatch.mask['IS_DATA'].astype(bool)
        is_cloud = eopatch.mask['CLM'].astype(bool)
        return np.logical_and(has_data, np.logical_not(is_cloud))

class CountValid(EOTask):
    """
    Task that counts the non-zero entries of a mask feature along its first axis (the time-series
    of observations) and stores the counts as a timeless mask.
    """
    def __init__(self, count_what, feature_name):
        # name of the mask feature to count / name of the timeless mask that receives the counts
        self.what, self.name = count_what, feature_name

    def execute(self, eopatch):
        valid_counts = np.count_nonzero(eopatch.mask[self.what], axis=0)
        eopatch.add_feature(FeatureType.MASK_TIMELESS, self.name, valid_counts)
        return eopatch

In [None]:
# Task for downloading the satellite data
band_names = ['B03', 'B04', 'B08']
add_data = SentinelHubInputTask(
    data_collection=DataCollection.SENTINEL2_L1C,
    bands=band_names,
    bands_feature=(FeatureType.DATA, 'BANDS'),
    additional_data=[
        (FeatureType.MASK, 'dataMask', 'IS_DATA'),
        (FeatureType.MASK, 'CLM'),
    ],
    resolution=200,
    maxcc=0.8,
    time_difference=datetime.timedelta(hours=2),
)

# Task for saving the EOPatches to the output folder (if needed).
# This block used to appear twice in a row; a single definition is sufficient.
path_out = './eopatches/'
# exist_ok=True creates the folder only when it is missing, without a separate existence check
os.makedirs(path_out, exist_ok=True)
save = SaveTask(path_out, overwrite_permission=OverwritePermission.OVERWRITE_PATCH)

# #adding mask
# land_use_ref = gpd.read_file('data/ice_charts/2019/cis_SGRDRHB_20190520T1800Z_pl_b_20190522205156/cis_SGRDRHB_20190520T1800Z_pl_b.shp')
# land_use_ref.fillna(-100, inplace=True)
# land_use_ref['CT'] = land_use_ref['CT'].astype('int')
# land_use_ref = land_use_ref.to_crs(region.crs)

# rasterization_task = VectorToRaster(land_use_ref, (FeatureType.MASK_TIMELESS, 'LULC'),
#                                     values_column='CT', raster_shape=(FeatureType.DATA, 'BANDS'),
#                                     raster_dtype=np.int16, no_data_value=-100)

In [None]:
def calculate_valid_data_mask(eopatch):
    """
    Return a boolean array that is True where `eopatch.mask['IS_DATA']` is non-zero and
    `eopatch.mask['CLM']` (the cloud mask) is zero.
    """
    # The builtin `bool` is used because the `np.bool` alias was deprecated in
    # NumPy 1.20 and removed in later releases.
    is_data_mask = eopatch.mask['IS_DATA'].astype(bool)
    cloud_free_mask = ~eopatch.mask['CLM'].astype(bool)
    return np.logical_and(is_data_mask, cloud_free_mask)

# Task built from the predicate `calculate_valid_data_mask`, with 'VALID_DATA' as its valid-data feature
add_valid_mask = AddValidDataMaskTask(
    valid_data_feature='VALID_DATA',
    predicate=calculate_valid_data_mask,
)

def calculate_coverage(array):
    """Return the fraction of entries of `array` that are zero (1 minus the non-zero fraction)."""
    nonzero_fraction = np.count_nonzero(array) / np.size(array)
    return 1.0 - nonzero_fraction

class AddValidDataCoverage(EOTask):
    """
    Task that applies `calculate_coverage` to every time frame of the 'VALID_DATA' mask and stores
    the per-frame values as the scalar feature 'COVERAGE' (one column, one row per frame).
    """

    def execute(self, eopatch):
        valid_data = eopatch.get_feature(FeatureType.MASK, 'VALID_DATA')

        # Flatten each frame so that one row holds all of its pixels
        n_frames, n_rows, n_cols, n_channels = valid_data.shape
        flat_frames = valid_data.reshape((n_frames, n_rows * n_cols * n_channels))

        per_frame = np.apply_along_axis(calculate_coverage, 1, flat_frames)

        eopatch.add_feature(FeatureType.SCALAR, 'COVERAGE', per_frame[:, None])
        return eopatch

# Threshold handed to ValidDataCoveragePredicate: arrays whose coverage value is below it pass
cloud_coverage_threshold = 0.1

# Task instance that adds the 'COVERAGE' feature
add_coverage = AddValidDataCoverage()

class ValidDataCoveragePredicate:
    """Callable that accepts an array when its `calculate_coverage` value is below a threshold."""

    def __init__(self, threshold):
        # exclusive upper bound for the coverage value
        self.threshold = threshold

    def __call__(self, array):
        """Return True when the coverage value of `array` is strictly below the threshold."""
        return self.threshold > calculate_coverage(array)

# Filter task over the 'VALID_DATA' mask, driven by the coverage predicate and the threshold above
coverage_predicate = ValidDataCoveragePredicate(cloud_coverage_threshold)
remove_cloudy_scenes = SimpleFilterTask((FeatureType.MASK, 'VALID_DATA'), coverage_predicate)

In [None]:
# Assemble the processing chain from the tasks defined in the previous cells.
# NOTE: `time_raster` is created in a later cell of this notebook, so that cell has to be run
# before this one.
workflow_tasks = [
    add_data,
    add_valid_mask,
    add_coverage,
    remove_cloudy_scenes,
    time_raster,
    save,
]
workflow = LinearWorkflow(*workflow_tasks)
# Let's visualize it
workflow.dependency_graph()

In [None]:
%%time

# Execute the workflow
time_interval = ['2019-05-01', '2019-5-30'] # time interval for the SH request

# One set of execution arguments per selected patch: the bounding box and time interval for the
# download task, and the output folder name for the save task
execution_args = [
    {
        add_data: {'bbox': patch_bbox, 'time_interval': time_interval},
        save: {'eopatch_folder': f'eopatch_{patch_no}'},
    }
    for patch_no, patch_bbox in enumerate(bbox_list[patchIDs])
]

executor = EOExecutor(workflow, execution_args, save_logs=True)
executor.run()

executor.make_report()

In [None]:
# Print the index and timestamps of each of the nine saved EOPatches.
# The loop variable is not called `id`, so the builtin of that name is not shadowed.
for patch_idx in range(9):
    try:
        eopatch = EOPatch.load('./eopatches/eopatch_{}/'.format(patch_idx),  lazy_loading=True)
        timestamps = eopatch.timestamp
        # fig, ax = plt.subplots(1, figsize=(10,10))
        # ax.imshow(eopatch.mask['VALID_DATA'][0])
    except Exception as exc:
        # Best effort: a patch that cannot be loaded is reported and skipped instead of being
        # ignored silently. Catching `Exception` (rather than using a bare `except`) lets
        # KeyboardInterrupt and SystemExit propagate.
        print(f'Skipping eopatch_{patch_idx}: {exc!r}')
        continue
    print(patch_idx)
    print(timestamps)

In [None]:
eopatch = EOPatch.load('./eopatches/eopatch_5/',  lazy_loading=True)

chart_dir = "./data/ice_charts/2019"
base_len = 31  # length of the base name of the chart folders and files (without the extension)

# The directory is listed once so that `chart_folders` and `chart_dates` always describe the
# same entries, in the same order.
chart_folders = os.listdir(chart_dir)  # all the sea ice chart folders available

# The date is read from the first 8 characters of the third '_'-separated field of each name
chart_dates = np.array([datetime.datetime.strptime(name.split('_')[2][:8], '%Y%m%d') for name in chart_folders])
# Add 12 hours so that each chart date sits in the middle of its day
chart_dates = chart_dates + datetime.timedelta(hours=12)

def get_nearest_chart_dates(eop):
    """
    Return, for each timestamp of the EOPatch `eop`, the closest date found in the module-level
    `chart_dates` array (one entry per timestamp, in the same order).
    """
    return [
        chart_dates[np.argsort(abs(sat_date - chart_dates))[0]]
        for sat_date in np.array(eop.timestamp)
    ]

def get_path(date, folders=None, base_length=None):
    """
    Return the path of the sea ice shapefile for `date`, as '/<folder>/<file>.shp'.

    Parameters:
        date: object with a `strftime` method (e.g. `datetime.datetime`); only its day is used.
        folders: chart folder names to search. Defaults to the module-level `chart_folders`.
        base_length: length of the base name shared by folder and file. Defaults to the
            module-level `base_len`.

    Raises:
        ValueError: when no folder belongs to the requested date.
    """
    if folders is None:
        folders = chart_folders
    if base_length is None:
        base_length = base_len

    wanted = date.strftime('%Y%m%d')

    # Compare against the date field of the name (third '_'-separated field, first 8 characters).
    # A plain substring test would also match the revision timestamp appended to the folder of
    # a different chart, e.g. '..._20190520T1800Z_pl_b_20190522205156' for 2019-05-22.
    names = []
    for name in folders:
        fields = name.split('_')
        if len(fields) > 2 and fields[2][:8] == wanted:
            names.append(name)
    if not names:
        raise ValueError(f'no ice chart folder found for date {wanted}')

    # The longest name is the most recent revision in case the folder was updated
    folder = max(names, key=len)
    # The shapefile carries the base name only; slicing leaves shorter names unchanged
    file = folder[:base_length]
    return '/' + folder + '/' + file + '.shp'

# Quick check: nearest chart dates for the loaded patch, their shapefile paths (computed but not
# kept), and the year of the first date as the cell output
m_dates = get_nearest_chart_dates(eopatch)
list(map(get_path, m_dates))
m_dates[0].year

In [None]:
# Display the shape of the 'BANDS' data array of the currently loaded patch
eopatch.data['BANDS'].shape

In [None]:
class TimeRaster(EOTask):
    """
    Task that finds, for every timestamp of an EOPatch, the date of the closest available ice
    chart and stores those dates as the timeless scalar feature 'TEST'.

    Chart dates are parsed from the names of the entries in `chart_dir`: the first 8 characters
    of the third '_'-separated field, read as YYYYMMDD and shifted by 12 hours.
    """

    def __init__(self, chart_dir, base_len):
        # directory that holds the ice chart folders
        self.chart_dir = chart_dir
        # length of the base folder/file name; stored but not used by `execute` yet
        self.base_len = base_len

    def execute(self, eopatch):
        """Add the nearest ice chart date of every timestamp to `eopatch` and return it."""
        # Read the directory given to the constructor (not a module-level global) and list it
        # only once.
        midday = datetime.timedelta(hours=12)  # puts each chart date in the middle of its day
        chart_dates = [
            datetime.datetime.strptime(name.split('_')[2][:8], '%Y%m%d') + midday
            for name in os.listdir(self.chart_dir)
        ]

        # For each image date pick the chart date with the smallest absolute time difference
        mask_dates = [
            min(chart_dates, key=lambda chart_date: abs(sat_date - chart_date))
            for sat_date in eopatch.timestamp
        ]

        # NOTE(review): this stores an array of datetime objects in a SCALAR_TIMELESS feature;
        # confirm that the feature type accepts it (the name 'TEST' suggests a placeholder).
        eopatch.add_feature(FeatureType.SCALAR_TIMELESS, 'TEST', np.array(mask_dates))
        return eopatch
#         chart_paths = []
#         #get the paths of the ice charts corresponding to the images
#         for date in mask_dates:
#             names = [name for name in chart_folders if date.strftime('%Y%m%d') in name]#get all the the folder names that match the date
#             folder = max(names, key=len) #get the longest named folder (this is the most recent revision in case the folder was updated)
#             if len(folder)>base_len:
#                 file = folder[:31]
#             else: 
#                 file = folder
#             chart_paths.append('/'+folder+'/'+file+'.shp')
        
        
#         add_raster = VectorToRaster(land_use_ref, (FeatureType.MASK_TIMELESS, 'TEST'),
#                                     values_column='CT', raster_shape=(FeatureType.DATA, 'BANDS'),
#                                     raster_dtype=np.int16, no_data_value=-100)
#         add_raster(eopatch)
        
        
        
        
        
        

#         valid_data = eopatch.get_feature(FeatureType.MASK, 'VALID_DATA')
#         time, height, width, channels = valid_data.shape

#         coverage = np.apply_along_axis(calculate_coverage, 1,
#                                        valid_data.reshape((time, height * width * channels)))

#         eopatch.add_feature(FeatureType.SCALAR, 'COVERAGE', coverage[:, np.newaxis])

# Task instance that is plugged into the workflow
time_raster = TimeRaster(chart_dir=chart_dir, base_len=base_len)

In [None]:
# Display the timeless mask features of the currently loaded patch
eopatch.mask_timeless

In [None]:
# Draw a colour composite of the image closest to `date` (left) next to the 'LULC' timeless mask (right).
# The composite uses bands 2, 1, 0 of 'BANDS', i.e. B08, B04, B03 for the bands requested above.
date = datetime.datetime(2019, 5, 14)
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(15, 8))

print(date)
eopatch = EOPatch.load('./eopatches/eopatch_5/',  lazy_loading=True)

# Index of the acquisition closest in time to the requested date
dates = np.array(eopatch.timestamp)
closest_date_id = np.argsort(abs(date - dates))[0]

# Left panel: brightened composite, clipped to the displayable [0, 1] range
composite = eopatch.data['BANDS'][closest_date_id][..., [2, 1, 0]]
axs[0].imshow(np.clip(composite * 3.5, 0, 1))

# Right panel: mask with its colour bar
im = axs[1].imshow(eopatch.mask_timeless['LULC'].squeeze())
fig.colorbar(im, ax=axs[1], orientation='vertical')

# Remove the ticks and let both panels fill their axes
for panel in axs:
    panel.set_xticks([])
    panel.set_yticks([])
    panel.set_aspect("auto")
del eopatch

fig.subplots_adjust(wspace=0, hspace=0)

In [None]:
path_out = './eopatches'

fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(10, 10))

# Show the 'LULC' timeless mask of every saved patch, filling the grid row by row
for i in tqdm(range(len(patchIDs))):
    eopatch = EOPatch.load(f'{path_out}/eopatch_{i}', lazy_loading=True)
    row, col = divmod(i, 3)
    ax = axs[row][col]
    im = ax.imshow(eopatch.mask_timeless['LULC'].squeeze())#, cmap=lulc_cmap, norm=lulc_norm)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('auto')
    del eopatch

fig.subplots_adjust(wspace=0, hspace=0)

# One shared horizontal colour bar, based on the last image drawn
cb = fig.colorbar(im, ax=axs.ravel().tolist(), orientation='horizontal', pad=0.01, aspect=100)
# cb.ax.tick_params(labelsize=20)
# cb.set_ticks([entry.id for entry in LULC])
# cb.ax.set_xticklabels([entry.name for entry in LULC], rotation=45, fontsize=15)
plt.show()