# Third-Party Imports

In [None]:
import time
import os, ee, copy
from pprint import pprint
from tqdm import tqdm

# Authorize GEE

In [None]:
# Authenticate against Earth Engine first, then initialise the client library.
# The project id is left empty here and must be filled in before running.
ee.Authenticate()
ee.Initialize(project='') # please change to your project id

# Self-defined Functions



## Data Generation Functions


In [None]:
################################################################################ for satellite imagery

# Sentinel-1 GRD collection; genSentinel1Data filters it per patch.
sentinel_1 = ee.ImageCollection("COPERNICUS/S1_GRD")

# Elevation source: SRTM, with its masked pixels filled from the GMTED2010
# 'be75' band.
gmted = ee.Image('USGS/GMTED2010_FULL').select(['be75'])
dem = ee.Image("USGS/SRTMGL1_003").unmask(gmted)

def resampleImg(img):
    """Bilinearly resample ``img``, reproject it to a 10 m grid and unmask with 0.

    NOTE: ``crs`` is not a parameter; it is looked up as a module-level name
    when this function runs, so a global ``crs`` must exist at that point.
    """
    resampled = img.resample('bilinear')
    reprojected = resampled.reproject(crs=crs, scale=10)
    return reprojected.unmask(0)

def handle_empty_collection(collection, band_suffix):
    """Suffix the VV/VH band names of ``collection``, with a zero-image fallback.

    Args:
        collection: image collection whose images have bands 'VV' and 'VH'.
        band_suffix: string appended to both band names (e.g. '_ascending').

    Returns:
        An ``ee.ImageCollection``. If ``collection`` has at least one image,
        it is the same collection with renamed bands; otherwise it is a
        collection holding a single constant-zero image with the renamed bands.
    """
    suffixed_names = ['VV' + band_suffix, 'VH' + band_suffix]

    def _with_suffix(image):
        return image.rename(suffixed_names)

    # Placeholder used only when the input collection is empty.
    zero_image = ee.Image.constant([0, 0]).rename(['VV', 'VH'])
    fallback = ee.ImageCollection([zero_image.rename(suffixed_names)])

    # The size (0 or more) is the server-side condition of the If.
    chosen = ee.Algorithms.If(collection.size(), collection.map(_with_suffix), fallback)
    return ee.ImageCollection(chosen)

def genSentinel1Data(geo_bounds, date, crs):
    """Build ascending and descending Sentinel-1 VV/VH collections for a region.

    The module-level ``sentinel_1`` collection is filtered to images that
    intersect ``geo_bounds``, carry both VV and VH polarisations and fall in
    the ``date`` range. IW images at 10 m are preferred. When none exist, the
    EW images at 40 m are used instead, resampled to a 10 m grid in ``crs``.
    When the date/bounds filter yields nothing at all, a single zero-valued
    image stands in.

    Args:
        geo_bounds: geometry used to spatially filter the collection.
        date: two-element sequence ``[start, end]`` passed to ``filterDate``.
        crs: CRS used when resampling the EW fallback images.

    Returns:
        Tuple ``(ascending, descending)`` of image collections whose bands are
        renamed with the ``_ascending`` / ``_descending`` suffix. An empty
        orbit-pass subset is replaced by a zero image (see
        ``handle_empty_collection``).

    Note:
        Calls ``getInfo()`` up to twice, so it blocks on round trips to the
        Earth Engine servers.
    """
    def _resample_to_10m(img):
        # Same operation as resampleImg, but bound to this call's ``crs``
        # instead of depending on a global variable of that name.
        return img.resample('bilinear').reproject(crs=crs, scale=10).unmask(0)

    sentinel_1_filtered = (
        sentinel_1.filter(ee.Filter.bounds(geo_bounds))
                  .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VV'))
                  .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VH'))
                  .filterDate(date[0], date[1])
                  .select(['VV', 'VH'])
    )

    if sentinel_1_filtered.size().getInfo() > 0:
        sentinel_1_filtered_iw = (
            sentinel_1_filtered.filter(ee.Filter.eq('instrumentMode', 'IW'))
                               .filter(ee.Filter.eq('resolution_meters', 10))
        )
        # Test the IW subset itself: the unfiltered collection is known to be
        # non-empty here, so checking it again would never trigger the fallback.
        if sentinel_1_filtered_iw.size().getInfo() == 0:
            print('No Sentinel-1 IW data in specified date! Return Sentinel-1 EW data instead!', flush=True)
            sentinel_1_filtered = (
                sentinel_1_filtered.filter(ee.Filter.eq('instrumentMode', 'EW'))
                                   .filter(ee.Filter.eq('resolution_meters', 40))
                                   .map(_resample_to_10m)
            )
        else:
            sentinel_1_filtered = sentinel_1_filtered_iw
    else:
        print('No Sentinel-1 data in specified date! Return image collection with 0 value instead!', flush=True)
        band_names = ['VV', 'VH']
        # Create a zero-filled "default" image with the same band names
        default_img = ee.Image([ee.Image.constant(0)] * len(band_names)).rename(band_names)
        sentinel_1_filtered = ee.ImageCollection([default_img])

    sentinel_1_final_ascending = handle_empty_collection(
        sentinel_1_filtered.filter(ee.Filter.eq('orbitProperties_pass', 'ASCENDING')),
        '_ascending')
    sentinel_1_final_descending = handle_empty_collection(
        sentinel_1_filtered.filter(ee.Filter.eq('orbitProperties_pass', 'DESCENDING')),
        '_descending')

    return sentinel_1_final_ascending, sentinel_1_final_descending

def genSentinel2Data(geo_bounds, date, bands):
    """Return a cloud/shadow-masked Sentinel-2 SR collection with ``bands``.

    Args:
        geo_bounds: geometry used to spatially filter the imagery.
        date: two-element sequence ``[start, end]`` of the date range.
        bands: band names to keep in the returned collection.

    Returns:
        Image collection in which every image has had its cloud/shadow mask
        computed (``addCldShdwMask``) and applied (``applyCldShdwMask``).
    """
    date_from, date_to = date[0], date[1]
    joined = getS2SRCldCol(geo_bounds, date_from, date_to)
    with_mask_band = joined.map(addCldShdwMask)
    masked = with_mask_band.map(applyCldShdwMask)
    return masked.select(bands)

def getS2SRCldCol(geo_bounds, start_date, end_date):
    """Join Sentinel-2 SR images with their s2cloudless probability images.

    Args:
        geo_bounds: geometry used to spatially filter both collections.
        start_date: start of the date range passed to ``filterDate``.
        end_date: end of the date range passed to ``filterDate``.

    Returns:
        The SR collection (scenes with CLOUDY_PIXEL_PERCENTAGE <= 90) where
        each image carries its matching cloud-probability image in the
        's2cloudless' property.
    """
    CLOUD_FILTER = 90

    # Surface-reflectance scenes, dropping the cloudiest ones up front.
    not_too_cloudy = ee.Filter.lte('CLOUDY_PIXEL_PERCENTAGE', CLOUD_FILTER)
    sr_images = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
    sr_images = sr_images.filterBounds(geo_bounds)
    sr_images = sr_images.filterDate(start_date, end_date)
    sr_images = sr_images.filter(not_too_cloudy)

    # Cloud-probability scenes over the same area and period.
    cloud_prob_images = ee.ImageCollection('COPERNICUS/S2_CLOUD_PROBABILITY')
    cloud_prob_images = cloud_prob_images.filterBounds(geo_bounds)
    cloud_prob_images = cloud_prob_images.filterDate(start_date, end_date)

    # Pair the two collections on identical 'system:index' values.
    same_index = ee.Filter.equals(leftField='system:index', rightField='system:index')
    joined = ee.Join.saveFirst('s2cloudless').apply(
        primary=sr_images,
        secondary=cloud_prob_images,
        condition=same_index,
    )
    return ee.ImageCollection(joined)

def addCloudBands(img):
    """Append the 'probability' band and a binary 'clouds' band to ``img``.

    ``img`` must carry an 's2cloudless' property holding the joined
    cloud-probability image. Pixels with probability above 50 are flagged
    as cloud.
    """
    CLD_PRB_THRESH = 50

    s2cloudless = ee.Image(img.get('s2cloudless'))
    probability = s2cloudless.select('probability')

    # 1 where the probability exceeds the threshold, 0 elsewhere.
    cloud_flag = probability.gt(CLD_PRB_THRESH).rename('clouds')

    extra_bands = ee.Image([probability, cloud_flag])
    return img.addBands(extra_bands)

def addShadowBands(img):
    """Append 'dark_pixels', 'cloud_transform' and 'shadows' bands to ``img``.

    ``img`` must already contain the 'clouds' band (see ``addCloudBands``)
    as well as the 'SCL' and 'B8' bands and the MEAN_SOLAR_AZIMUTH_ANGLE
    property.
    """
    SR_BAND_SCALE = 1e4
    NIR_DRK_THRESH = 0.15
    CLD_PRJ_DIST = 1

    # Dark NIR pixels that are not classified as water (SCL value 6) are the
    # candidate shadow pixels.
    non_water = img.select('SCL').neq(6)
    nir_is_dark = img.select('B8').lt(NIR_DRK_THRESH*SR_BAND_SCALE)
    dark_pixels = nir_is_dark.multiply(non_water).rename('dark_pixels')

    # Direction in which shadows are cast from clouds (assumes UTM projection).
    sun_azimuth = ee.Number(img.get('MEAN_SOLAR_AZIMUTH_ANGLE'))
    shadow_azimuth = ee.Number(90).subtract(sun_azimuth)

    # Project the cloud flag along that direction, evaluated at 100 m scale.
    projected = img.select('clouds').directionalDistanceTransform(shadow_azimuth, CLD_PRJ_DIST*10)
    projected = projected.reproject(crs=img.select(0).projection(), scale=100)
    cld_proj = projected.select('distance').mask().rename('cloud_transform')

    # A shadow is a dark pixel that also lies inside the projected cloud area.
    shadows = cld_proj.multiply(dark_pixels).rename('shadows')

    return img.addBands(ee.Image([dark_pixels, cld_proj, shadows]))

def addCldShdwMask(img):
    """Append a combined cloud/shadow band named 'cloudmask' to ``img``.

    The returned image also contains the intermediate bands added by
    ``addCloudBands`` and ``addShadowBands``.
    """
    BUFFER = 50

    # Cloud bands first: the shadow step reads the 'clouds' band.
    with_components = addShadowBands(addCloudBands(img))

    # 1 wherever a pixel is flagged as cloud or as shadow, else 0.
    clouds = with_components.select('clouds')
    shadows = with_components.select('shadows')
    cloud_or_shadow = clouds.add(shadows).gt(0)

    # Erode to drop small patches, then dilate the remainder by BUFFER.
    # Evaluated at 20 m for speed; clouds are assumed not to need 10 m precision.
    cleaned = cloud_or_shadow.focalMin(2).focalMax(BUFFER*2/20)
    cleaned = cleaned.reproject(crs=img.select([0]).projection(), scale=20)
    cloudmask = cleaned.rename('cloudmask')

    return with_components.addBands(cloudmask)

def applyCldShdwMask(img):
    """Return the 'B.*' bands of ``img`` masked by the inverted 'cloudmask' band."""
    # Inverting makes cloud/shadow pixels 0 and all other pixels 1.
    clear_pixels = img.select('cloudmask').Not()

    reflectance = img.select('B.*')
    return reflectance.updateMask(clear_pixels)

def genTopoData():
    """Return the elevation, slope and aspect bands computed from the module-level ``dem``."""
    terrain_products = ee.Terrain.products(dem)
    wanted_bands = ['elevation', 'slope', 'aspect']
    return terrain_products.select(wanted_bands)

################################################################################ for products

def preprocess_esa_2020():
    """Return ESA WorldCover 2020 as a binary image: 1 for class 50 (built-up), else 0."""
    built_up_idx = 50
    worldcover = ee.Image('ESA/WorldCover/v100/2020')
    # Map the built-up class to 1; every other value takes the default 0.
    return worldcover.remap([built_up_idx], [1], 0)

def preprocess_gisa_2019():
    """Return the GISA collection with ``get_gisa_2019`` applied to every image."""
    gisa_tiles = ee.ImageCollection("projects/sat-io/open-datasets/GISA_1972_2019")
    return gisa_tiles.map(get_gisa_2019)

def get_gisa_2019(img):
    """Return a binary image that is 1 where ``img`` is less than or equal to 37."""
    # NOTE(review): 37 presumably encodes the 2019 layer of GISA - confirm.
    max_value = 37
    return img.lte(max_value)

def preprocess_wsf2019():
    """Return the WSF2019 collection with every image binarised.

    Pixels equal to 255 (settlement) become 1; all other pixels, including
    originally masked ones, become 0.
    """
    settlement_idx = 255

    def _binarise(tile):
        # Fill masked pixels with 0 first so they end up in the 0 class.
        filled = tile.unmask(0)
        return filled.remap([settlement_idx], [1], 0)

    tiles = ee.ImageCollection('projects/mixpointsextracion/assets/wsf2019')
    return tiles.map(_binarise)

# Module-level product layers, built once and reused for every exported patch
# by exportMulimodalImageryNProducts.
# The two collection-based products are mosaicked into single images; GISA is
# additionally unmasked with 0 after mosaicking.
esa2020 = preprocess_esa_2020()
gisa2019 = preprocess_gisa_2019().mosaic().unmask(0)
wsf2019 = preprocess_wsf2019().mosaic()

# Export Data

## General Information

In [None]:
# Feature collection of patches to export; each feature is read for its
# geometry and its 'crs' property by the export loop.
patch_collection = ee.FeatureCollection('') # please specify your feature collection path
# Date range used to filter the Sentinel-1 and Sentinel-2 collections.
start_date, end_date = '2019-01-01', '2020-01-01'
# Google Drive folder passed as 'folder' to the export tasks.
output_gd_folder = 'datasets'
# True: export imagery plus the ESA/GISA/WSF product bands; False: imagery only.
with_product = True

# Number of export tasks submitted per batch.
GROUP_SIZE = 100
# Once RUNNING + READY tasks reach this count, submission pauses
# (3000 is the max number of running tasks, minus one batch of headroom).
THRESHOLD = 3000 - GROUP_SIZE
# Seconds to wait between task-queue checks while submission is paused (30 minutes).
INTERVAL_SECONDS = 60 * 60 * 0.5

## General Functions

In [None]:
def genMultimodalData(img_geo, crs, export_dict_basic, start_date, end_date):
    """Build the Sentinel-2, Sentinel-1 and topography images for one patch.

    Args:
        img_geo: patch geometry used to filter the imagery.
        crs: CRS used for resampling/reprojection at 10 m.
        export_dict_basic: accepted for signature compatibility; not used here.
        start_date: start of the acquisition date range.
        end_date: end of the acquisition date range.

    Returns:
        Tuple ``(s2, s1, topo)`` of float images: the Sentinel-2 median of
        B4/B3/B2/B8 divided by 10000, the Sentinel-1 per-orbit means
        (ascending bands followed by descending bands), and
        elevation/slope/aspect resampled to 10 m in ``crs``.
    """
    date_range = [start_date, end_date]
    optical_bands = ['B4', 'B3', 'B2', 'B8']

    optical_coll = genSentinel2Data(img_geo, date_range, optical_bands)
    ascending_coll, descending_coll = genSentinel1Data(img_geo, date_range, crs)
    terrain = genTopoData()

    # Median composite, divided by 10000.
    s2 = optical_coll.median().divide(10000).toFloat()

    # Average each orbit direction separately, then stack the two results.
    s1 = ascending_coll.mean().addBands(descending_coll.mean()).toFloat()

    topo = terrain.resample('bilinear').reproject(crs=crs, scale=10).toFloat()

    return s2, s1, topo

def submitExportTask(stacked_data, export_dict_basic):
    """Start a Google Drive export task for ``stacked_data``.

    Args:
        stacked_data: image to export.
        export_dict_basic: keyword arguments for
            ``ee.batch.Export.image.toDrive`` other than ``image``
            (description, region, scale, crs, folder, maxPixels, ...).
            It is left unmodified.

    Returns:
        The started export task.
    """
    # Add the image on a copy so the caller's dict is not mutated and can be
    # reused for another export.
    export_kwargs = dict(export_dict_basic, image=stacked_data)
    task = ee.batch.Export.image.toDrive(**export_kwargs)
    task.start()
    return task

def exportMulimodalImagery(img_geo, crs, export_dict_basic, start_date, end_date):
    """Export the stacked Sentinel-2, Sentinel-1 and topography bands for one patch.

    The bands are stacked in that order and handed to ``submitExportTask``
    together with ``export_dict_basic``.
    """
    optical, radar, terrain = genMultimodalData(
        img_geo, crs, export_dict_basic, start_date, end_date)
    stacked = optical.addBands(radar)
    stacked = stacked.addBands(terrain)
    submitExportTask(stacked, export_dict_basic)

def exportMulimodalImageryNProducts(img_geo, crs, export_dict_basic, start_date, end_date):
    """Export imagery plus the ESA, GISA and WSF product bands for one patch.

    Band order: Sentinel-2, Sentinel-1, topography, then the module-level
    ``esa2020``, ``gisa2019`` and ``wsf2019`` layers cast to float.
    """
    optical, radar, terrain = genMultimodalData(
        img_geo, crs, export_dict_basic, start_date, end_date)

    stacked = optical.addBands(radar).addBands(terrain)
    # Append the product layers one after another, each as float.
    for product in (esa2020, gisa2019, wsf2019):
        stacked = stacked.addBands(product.toFloat())

    submitExportTask(stacked, export_dict_basic)

## Export patches

In [None]:
# Submit one export task per patch, in batches of GROUP_SIZE, pausing between
# batches while the Earth Engine task queue is at or above THRESHOLD.
size = patch_collection.size().getInfo()
print('There are', size, 'patches to be exported!')
patch_list = patch_collection.toList(size)

# Index of the first patch to export.
start = 0

# The export routine depends only on the configuration, so pick it once.
export_patch = exportMulimodalImageryNProducts if with_product else exportMulimodalImagery

for group_start in range(start, size, GROUP_SIZE):
    group_end = min(group_start + GROUP_SIZE, size)
    print('Start exporting from', group_start, 'to', group_end)

    for i in tqdm(range(group_start, group_end)):
        img_feature = ee.Feature(patch_list.get(i))
        # ``i`` is already the absolute index into patch_list. Adding
        # group_start again would double-offset the name (e.g. the second
        # group would be named patch_200.. instead of patch_100..).
        img_name = 'patch_' + str(i)
        img_geo = img_feature.geometry()
        # One blocking round trip per patch to read its CRS. This is kept as
        # the module-level name ``crs``, which resampleImg also reads.
        crs = img_feature.get('crs').getInfo()

        export_dict_basic = {'description': img_name,
                             'region': img_geo,
                             'scale': 10,
                             'crs': crs,
                             'folder': output_gd_folder,
                             'maxPixels': 1e10}

        export_patch(img_geo, crs, export_dict_basic, start_date, end_date)

    # Throttle: wait until the number of queued or running tasks drops below
    # THRESHOLD before submitting the next batch.
    while True:
        tasks = ee.batch.Task.list()
        task_states = [task.state for task in tasks]
        num_running_and_ready_tasks = task_states.count('RUNNING') + task_states.count('READY')

        if num_running_and_ready_tasks < THRESHOLD:
            break

        time.sleep(INTERVAL_SECONDS)
