#  BA Cartography tool - BAMT updated and adapted to Python API (v0.5)  
<div class="alert alert-block alert-info">
    <ul style="color: black; font-size:17px">
        <li>BA Cartography tool (BAMT) was initially created in GEE code editor by 
            Ekhi Roteta and Aitor Bastarrika (University of the Basque Country - UPV/EHU), 
            and Magí Franquesa (University of Alcalá - UAH).</li>
        <li>Updated and automatized in python by Amin Khairoun (University of Alcalá - UAH)</li>
    <li> Date:   26/01/2023 </li>
    </ul>
</div>                                                      

## Loading libaries and intializing GEE
<div class="alert alert-block alert-info">
    <p style="color: black; font-size:17px">You need to install <b>geemap</b> if it is 
        not already done.</p>
</div>

In [74]:
import os
import geemap
import numpy as np
import pandas as pd
import geopandas as gpd
import time
from datetime import datetime, timedelta, timezone
import re
import glob
from shapely.geometry import box, linestring, polygon
import shutil


In [None]:
import ee
ee.Authenticate(force=False)


In [None]:
import Training_geometries_Amazonia_20181001_20190401
import Training_geometries_Amazonia_20190401_20191001
import Training_geometries_Amazonia_20040401_20050401
import Training_geometries_Siberia


<div class="alert alert-block alert-info">
    <p style="color: black; font-size:17px">Follow the link to Google authentication system
        and copy paste the token. Check <b>View and manage your Google Earth Engine data</b>
        and <b>Manage your data and permissions in Cloud Storage and see the email 
        address for your Google Account.</b></p>
</div>

In [77]:
def GeocolToMultipol(geometry):
    geometries = geometry.geometries().map(lambda g: ee.Geometry(g).coordinates()).getInfo()
    MultiPol = ee.Geometry.MultiPolygon(geometries)
    return MultiPol

In [None]:
tiles_2d = ee.FeatureCollection("users/ekhiroteta/BAMT/BAMT_GEE_downloadableTiles_2d")
tiles_1d = ee.FeatureCollection("users/ekhiroteta/BAMT/BAMT_GEE_downloadableTiles_1d")
tiles_05d = ee.FeatureCollection("users/ekhiroteta/BAMT/BAMT_GEE_downloadableTiles_05d")
tiles_025d = ee.FeatureCollection("users/ekhiroteta/BAMT/BAMT_GEE_downloadableTiles_025d")
sahel = ee.FeatureCollection("users/aminkhairoun/Landsat_BA/Regions/Sahel")
amazonia = ee.FeatureCollection("users/aminkhairoun/Landsat_BA/Regions/Amazonia")
siberia_HRLC = ee.FeatureCollection("users/aminkhairoun/Landsat_BA/Regions/Siberia_HRLC")
siberia = ee.FeatureCollection("users/aminkhairoun/Landsat_BA/Regions/Siberia")



# BAMT algorithm processing

## Cloud masks

In [None]:
def mask_s2qa60(image):
    date = ee.Number.parse(ee.Date(image.get('system:time_start')).format('yyyyDDD'))
    QABand = image.select('QA60')
    B1Band = image.select('B1')
    mask = QABand.bitwiseAnd(ee.Number(2).pow(10).int()) \
    .Or(QABand.bitwiseAnd(ee.Number(2).pow(11).int())) \
    .Or(B1Band.gt(1500))
    image = image.select(['B2', 'B3', 'B4', 'B8A', 'B11', 'B12']) \
    .rename(['Blue', 'Green', 'Red', 'NIR', 'SWIR1', 'SWIR2'])
    return image.updateMask(mask.eq(0)).set('date', date)

def mask_s2cor(image):
    date = ee.Number.parse(ee.Date(image.get('system:time_start')).format('yyyyDDD'))
    cloudProb = image.select('MSK_CLDPRB')
    snowProb = image.select('MSK_SNWPRB')
    cloud = cloudProb.lt(5)
    snow = snowProb.lt(5)
    scl = image.select('')
    shadow = scl.eq(3) # 3 = SCL cloud shadow
    cirrus = scl.eq(10) # 10 = cirrus
    # Cloud probability less than 5% or cloud shadow classification
    mask = (cloud.And(snow)).And(cirrus.neq(1)).And(shadow.neq(1))
    image = image.select(['B2', 'B3', 'B4', 'B8A', 'B11', 'B12']) \
    .rename(['Blue', 'Green', 'Red', 'NIR', 'SWIR1', 'SWIR2'])
    return image.updateMask(mask).set('date', date)

def mask_s2cloudless(image):
    date = ee.Number.parse(ee.Date(image.get('system:time_start')).format('yyyyDDD'))
    cloudProb = image.select('probability')
    # Cloud probability less than 65% or cloud shadow classification
    mask = cloudProb.lt(65)
    image = image.select(['B2', 'B3', 'B4', 'B8A', 'B11', 'B12', 'probability']) \
    .rename(['Blue', 'Green', 'Red', 'NIR', 'SWIR1', 'SWIR2', 'Clouds'])
    return image.updateMask(mask).set('date', date)

def mask_landsat(image):
    date = ee.Number.parse(ee.Date(image.get('system:time_start')).format('yyyyDDD'))
    mask = image.select('QA_PIXEL').bitwiseAnd(ee.Number(2).pow(3).int()).eq(0) \
    .And(image.select('QA_PIXEL').bitwiseAnd(ee.Number(2).pow(4).int()).eq(0)) \
    .And(image.select('QA_PIXEL').bitwiseAnd(ee.Number(2).pow(2).int()).eq(0)) \
    .And(image.select('QA_PIXEL').bitwiseAnd(ee.Number(2).pow(5).int()).eq(0)) \

    satellite = ee.String(image.get('SPACECRAFT_ID'))
    image = ee.Image(ee.Algorithms.If(
    satellite.compareTo('LANDSAT_4').eq(0) \
      .Or(satellite.compareTo('LANDSAT_5').eq(0)) \
      .Or(satellite.compareTo('LANDSAT_7').eq(0)),
    image.select(['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7']) \
      .multiply(0.0000275).add(-0.2).multiply(10000) \
      .rename(['Blue', 'Green', 'Red', 'NIR', 'SWIR1', 'SWIR2']),
    image.select(['SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7']) \
      .multiply(0.0000275).add(-0.2).multiply(10000) \
      .rename(['Blue', 'Green', 'Red', 'NIR', 'SWIR1', 'SWIR2'])
    ))
    image = image.updateMask(mask.eq(1)).set('date', date) \
                .updateMask(image.select('Red').gt(50).And(image.select('NIR').gt(50)))
                ## invalid values generally seems to be classified burned
    if zone == 'Siberia':
        ## additional mask for clouds missed by CFMask
        image = image.updateMask(image.select('SWIR2').lt(3000).And(image.select('NIR').lt(3300)))
    elif zone == 'Amazonia' or zone == 'Iberia':
        image = image.updateMask(image.select('SWIR2').lt(3000).Or(image.select('NIR').lt(3300)))
    elif zone == 'Sahel':
        image = image.updateMask(image.select('SWIR2').lt(5000).Or(image.select('NIR').lt(5000)))               
    return image

def filterMonthsMCD64(month, parameters):
    parameters = ee.List(parameters)
    date_pre = ee.Date(parameters.get(1))
    date_post = ee.Date(parameters.get(2))

    month_date = ee.Date(month.get('system:time_start'))
    month = ee.Image(ee.Algorithms.If(
            month_date.difference(date_pre.update(**{'day':1}), 'day').eq(0),
            month.updateMask(month.gte(ee.Number.parse(date_pre.format('DDD')))),
            ee.Algorithms.If(
              month_date.difference(date_post.update(**{'day':1}), 'day').eq(0),
              month.updateMask(month.gt(0).And(month.lt(ee.Number.parse(date_post.format('DD'))))),
              month
              )
            ))
    return month 

def filterDatesMCD64(date_pre, date_post):
    date_pre = ee.Date(date_pre)
    date_post = ee.Date(date_post)
    mcd64 = ee.ImageCollection('MODIS/006/MCD64A1') \
    .filterDate(date_pre.update(**{'day':1}), date_post) \
    .select('BurnDate')

    parameters = ee.List([ee.List([]), date_pre, date_post])
    return mcd64.map(lambda month: filterMonthsMCD64(month, parameters))

def filterDatesFireCCI51(date_pre, date_post):
    date_pre = ee.Date(date_pre)
    date_post = ee.Date(date_post)
    fcci = ee.ImageCollection('ESA/CCI/FireCCI/5_1') \
    .filterDate(date_pre.update(**{'day':1}), date_post) \
    .select('BurnDate')

    parameters = ee.List([ee.List([]), date_pre, date_post])
    return fcci.map(lambda month: filterMonthsMCD64(month, parameters))

## Processing functions

In [None]:
def locate_tiles(studyArea, list_grids, list_tile_errors):
    UMLError_tiles_1d = list_tile_errors[0]
    UMLError_tiles_05d = list_tile_errors[1]
    UMLError_tiles_025d = list_tile_errors[2]
    
    list_tiles_orig = tiles_2d.filterBounds(studyArea)
    list_tiles = tiles_2d.filterBounds(studyArea)

    for i in range(len(UMLError_tiles_1d)):
        tile = UMLError_tiles_1d[i]
        list_tiles = list_tiles \
          .filter(ee.Filter.neq('TILE', tile)) \
          .merge(tiles_1d.filter(ee.Filter.stringStartsWith('TILE', tile)))

    for i in range(len(UMLError_tiles_05d)):
        tile = UMLError_tiles_05d[i]
        list_tiles = list_tiles \
          .filter(ee.Filter.neq('TILE', tile)) \
          .merge(tiles_05d.filter(ee.Filter.stringStartsWith('TILE', tile)))

    for i in range(len(UMLError_tiles_025d)):
        tile = UMLError_tiles_025d[i]
        list_tiles = list_tiles \
          .filter(ee.Filter.neq('TILE', tile)) \
          .merge(tiles_025d.filter(ee.Filter.stringStartsWith('TILE', tile)))
    
    return list_tiles

def split_tile(tile, grid_out):
    list_tiles = grid_out.filter(ee.Filter.stringStartsWith('TILE', tile))
    return list_tiles

def define_parameters(dataset, zone):
    nameBase = 'BAMT_BA_' + zone + '_'
    if dataset == 'Landsat':
        nameBase = nameBase + 'Lndst'
        pixelSize = 30
    elif dataset == 'Sentinel2_SR':
        nameBase = nameBase + 'S2MSISR'
        pixelSize = 20
    elif dataset == 'Sentinel2_TOA':
        nameBase = nameBase + 'S2MSITOA'
        pixelSize = 20
    return ee.Dictionary({'nameBase': nameBase, 'pixelSize': pixelSize})    

def get_indices(image):
    nbr2 = image.normalizedDifference(['SWIR2', 'SWIR1']).multiply(10000).int16()
    nbr = image.normalizedDifference(['SWIR2', 'NIR']).multiply(10000).int16()
    ndvi = image.normalizedDifference(['Red', 'NIR']).multiply(10000).int16()
    return image.int16().addBands([nbr2.rename(['NBR2']), nbr.rename(['NBR']), ndvi.rename(['NDVI'])])

def get_dates(image):
    dates = image.metadata('date')
    return image.addBands(dates.rename('dates').updateMask(image.select('NBR').gt(-10000)))
    
def generate_images(studyArea, dataset, zone, date_1, date_2, date_2_2, date_3, visualize, 
                    include_dates, include_biomes):
    parameters = define_parameters(dataset, zone)
    nameBase = ee.String(parameters.get('nameBase'))
    pixelSize = ee.Number(parameters.get('pixelSize'))
    if dataset == 'Landsat':
        pre_image = ee.ImageCollection('LANDSAT/LT04/C02/T1_L2').filterBounds(studyArea).filterDate(date_1, date_2) \
          .merge(ee.ImageCollection('LANDSAT/LT05/C02/T1_L2').filterBounds(studyArea).filterDate(date_1, date_2)) \
          .merge(ee.ImageCollection('LANDSAT/LC08/C02/T1_L2').filterBounds(studyArea).filterDate(date_1, date_2)) \
          .merge(ee.ImageCollection('LANDSAT/LC09/C02/T1_L2').filterBounds(studyArea).filterDate(date_1, date_2)) \
          .merge(ee.ImageCollection('LANDSAT/LE07/C02/T1_L2').filterBounds(studyArea).filterDate(date_1, date_2)) 
        pre_image = pre_image.map(mask_landsat)
        
        post_image = ee.ImageCollection('LANDSAT/LT04/C02/T1_L2').filterBounds(studyArea).filterDate(date_2_2, date_3) \
          .merge(ee.ImageCollection('LANDSAT/LT05/C02/T1_L2').filterBounds(studyArea).filterDate(date_2_2, date_3)) \
          .merge(ee.ImageCollection('LANDSAT/LC08/C02/T1_L2').filterBounds(studyArea).filterDate(date_2_2, date_3)) \
          .merge(ee.ImageCollection('LANDSAT/LC09/C02/T1_L2').filterBounds(studyArea).filterDate(date_2_2, date_3)) \
          .merge(ee.ImageCollection('LANDSAT/LE07/C02/T1_L2').filterBounds(studyArea).filterDate(date_2_2, date_3))
        post_image = post_image.map(mask_landsat)
        
    elif dataset == 'Sentinel2_SR':
        pre_image = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED').filterBounds(studyArea).filterDate(date_1, date_2).map(mask_s2cor)
        post_image = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED').filterBounds(studyArea).filterDate(date_2_2, date_3).map(mask_s2cor)

    elif dataset == 'Sentinel2_TOA':
        pre_image = ee.ImageCollection('COPERNICUS/S2_HARMONIZED').filterBounds(studyArea).filterDate(date_1, date_2).map(mask_s2cloudless)
        post_image = ee.ImageCollection('COPERNICUS/S2_HARMONIZED').filterBounds(studyArea).filterDate(date_2_2, date_3).map(mask_s2cloudless)
 
    pre_image = pre_image.map(get_indices).map(get_dates)
    post_image = post_image.map(get_indices).map(get_dates)
    # print(f"post_image: {post_image.first().getInfo()}")
    # print(f"pre_image: {pre_image.first().getInfo()}")
    
    if zone == 'Sahel':
        mosaic_pre = 'dates'
        date_1_ref = '2018' + date_1[4:]
        date_2_ref = '2019' + date_2[4:]
        pre_image_ref = ee.ImageCollection('LANDSAT/LT04/C02/T1_L2').filterBounds(studyArea).filterDate(date_1_ref, date_2_ref) \
            .merge(ee.ImageCollection('LANDSAT/LT05/C02/T1_L2').filterBounds(studyArea).filterDate(date_1_ref, date_2_ref)) \
            .merge(ee.ImageCollection('LANDSAT/LE07/C02/T1_L2').filterBounds(studyArea).filterDate(date_1_ref, date_2_ref)) \
            .merge(ee.ImageCollection('LANDSAT/LC08/C02/T1_L2').filterBounds(studyArea).filterDate(date_1_ref, date_2_ref)) \
            .merge(ee.ImageCollection('LANDSAT/LC09/C02/T1_L2').filterBounds(studyArea).filterDate(date_1_ref, date_2_ref)) \
            .filter(ee.Filter.calendarRange(ee.Number.parse(date_1[5:7]), ee.Number.parse(date_2[5:7]), 'month')) \
            .map(mask_landsat).map(get_indices).map(get_dates)
        
        # print(f"pre_image_ref: {pre_image_ref.first().getInfo()}")
        pre_image_ref = pre_image_ref.qualityMosaic('dates') 
        # pre_add_col = post_image

    else:
        mosaic_pre = 'NBR'
            
    if (date_3[:4] in ['2001', '2002']) & (dataset == 'Landsat'):
        print(f'Processed year is {date_3[:4]}, mask L5 edge anomaly is activated')
        pre_image = pre_image.map(lambda im: im.updateMask(im.select('NBR2').lt(3000))).qualityMosaic(mosaic_pre) 
        post_image = post_image.map(lambda im: im.updateMask(im.select('NBR2').lt(3000))).qualityMosaic('NBR') 
    else:    
        pre_image = pre_image.qualityMosaic(mosaic_pre) 
        post_image = post_image.qualityMosaic('NBR')         
    date_image = post_image.select('dates').int()
    
    if zone == 'Sahel':
        # pre_image = pre_add_col.map(lambda im: im.updateMask(im.select('NBR').lt(pre_image.select('NBR')) \
        #     .And(im.select('dates').lt(date_image.subtract(16))))) \
        #     .merge(ee.ImageCollection([pre_image])).qualityMosaic('dates')
        
        pre_image_ref = pre_image_ref.select(pre_image_ref.bandNames().remove('dates')) \
            .addBands(ee.Image.constant(ee.Number.parse(ee.Date(date_1).format('yyyyDDD'))) \
                  .mask(pre_image_ref.select('NBR').gt(-10000)).rename('dates'))
        pre_image_ref = pre_image_ref.updateMask(pre_image_ref.select('NBR').gte(pre_image.select('NBR')) \
            .Or(pre_image.select('dates').gt(ee.Number.parse(ee.Date(date_2_2).advance(-90, 'day').format('yyyyDDD')))).unmask(0).eq(0));
        
        pre_image = pre_image_ref.unmask(pre_image)
        
    elif (zone == 'Siberia') and (date_2_2[:4] == '1999'):
        print(f'Expanding the pre composite of {date_2_2[:4]} by data from 2019')
        date_1_ref = '2019' + date_1[4:]
        date_2_ref = '2019' + date_2[4:]
        pre_image_ref = ee.ImageCollection('LANDSAT/LT04/C02/T1_L2').filterBounds(studyArea).filterDate(date_1_ref, date_2_ref) \
            .merge(ee.ImageCollection('LANDSAT/LT05/C02/T1_L2').filterBounds(studyArea).filterDate(date_1_ref, date_2_ref)) \
            .merge(ee.ImageCollection('LANDSAT/LE07/C02/T1_L2').filterBounds(studyArea).filterDate(date_1_ref, date_2_ref)) \
            .merge(ee.ImageCollection('LANDSAT/LC08/C02/T1_L2').filterBounds(studyArea).filterDate(date_1_ref, date_2_ref)) \
            .merge(ee.ImageCollection('LANDSAT/LC09/C02/T1_L2').filterBounds(studyArea).filterDate(date_1_ref, date_2_ref)) \
            .map(mask_landsat).map(get_indices).map(get_dates)
        
        pre_image_ref = pre_image_ref.qualityMosaic('NBR').select(pre_image_ref.first().bandNames().remove('dates')) \
            .addBands(ee.Image.constant(ee.Number.parse(ee.Date(date_1).format('yyyyDDD'))).rename('dates'))
        pre_image = pre_image.unmask(pre_image_ref)

    band_list = ['Blue', 'Green', 'Red', 'NIR', 'SWIR1', 'SWIR2', 'NBR2', 'NBR', 'NDVI']
    diff_image = ee.Image()
    list_length = len(band_list)
    for i in range(list_length):
        band = band_list[i]
        diff_image = diff_image.addBands(post_image.select(band).subtract(pre_image.select(band)).int16().rename(band + '_diff'))
        diff_image = diff_image.addBands(post_image.select(band).rename(band + '_post'))
        diff_image = diff_image.addBands(pre_image.select(band).rename(band + '_pre'))

    diff_image = diff_image.select('.*_.*')
    
    if include_dates:
        pre_dates = pre_image.select('dates').mod(1000) \
            .add(pre_image.select('dates').divide(1000).int16() \
              .subtract(ee.Number.parse(date_2[0:4])).multiply(365))
        post_dates = post_image.select('dates').mod(1000) \
            .add(post_image.select('dates').divide(1000).int16() \
              .subtract(ee.Number.parse(date_2[0:4])).multiply(365))
        diff_dates = post_dates.subtract(pre_dates).rename('Dates_diff').int16()
        diff_image = diff_image.addBands(diff_dates)
    if include_biomes:
        biomes = ee.Image(f"users/fireccihrba/BAMT/Ancillary/Biomes_{zone}")
        if biomes.projection().wkt().getInfo() != diff_image.projection().wkt().getInfo():
            print(diff_image.projection().getInfo()['crs'])
            biomes = biomes.reproject(
                crs=diff_image.projection().getInfo()['crs'],
                scale=pixelSize
            )
        print("biomes projection:", biomes.projection().getInfo())
        biomes = biomes.rename('biome')
        diff_image = diff_image.addBands(biomes)
        
    if visualize:
        post_image_rgb = post_image.select(['NBR2', 'NIR', 'Red'])
        pre_image_rgb = pre_image.select(['NBR2', 'NIR', 'Red'])
        Map.addLayer(studyArea, {'color': 'red'}, 'Study Area')
        Map.addLayer(pre_image_rgb, {'min': [-5000, 1000, 0], 'max': [500, 3000, 2000]}, 'Pre-fire', False)
        Map.addLayer(post_image_rgb,{'min': [-5000, 1000, 0], 'max': [500, 3000, 2000]}, 'Post-fire', False)
        Map.addLayer(post_image_rgb.subtract(pre_image_rgb), 
                     {'min': [30, -100, -100], 'max': [100, 50, 150]}, 'Difference', False)
        Map.addLayer(pre_image.select(['SWIR2', 'NIR', 'Red']),
          {'min': 10/(0.0001*255*2), 'max': 160/(0.0001*255*2), 'gamma': 1}, 'Pre SWIR2', True)
        Map.addLayer(post_image.select(['SWIR2', 'NIR', 'Red']),
          {'min': 10/(0.0001*255*2), 'max': 160/(0.0001*255*2), 'gamma': 1}, 'Post SWIR2', True)
        Map.addLayer(pre_image.select('dates'), {
#         // 'min': ee.Number.parse(date_1[:4]).add(200).getInfo(), 
#         // 'max': ee.Number.parse(date_1[:4]).add(350).getInfo(),
        'min':2006240, 'max':2006350, 
        'palette': ['yellow', 'orange', 'red']}, 'pre dates', False)
        Map.addLayer(post_image.select('dates'), {
#         // 'min': ee.Number.parse(date_1[:4]).add(200).getInfo(), 
#         // 'max': ee.Number.parse(date_1[:4]).add(350).getInfo(),
        'min':2006240, 'max':2006350, 
        'palette': ['yellow', 'orange', 'red']}, 'post dates', False)
        Map.addLayer(diff_image.select('Dates_diff'), {min:0, max:100}, 'diff dates', False)

        # MCD64
        if (int(date_3[0:4])>=2000):
            mcd64 = filterDatesMCD64(date_2_2, date_3)
            Map.addLayer(mcd64, {'min':0, 'max':1, 'palette':['#000000', '#ff0000']}, 'MCD64A1', False)
            print(ee.ImageCollection(mcd64).first().getInfo())

        # FireCCI51
        if (int(date_3[0:4])>=2001):
            fcci = filterDatesFireCCI51(date_2_2, date_3)
            Map.addLayer(fcci, {'min':0, 'max':1, 'palette':['#000000', '#ff0000']}, 'FireCCI51', False)

        # Active fires
        fires = ee.ImageCollection('FIRMS').filter(
        ee.Filter.date(date_2_2, date_3))
        fires = fires.select('T21')
        firesVis = {
        'min': 325.0,
        'max': 400.0,
        'palette': ['red', 'orange', 'yellow'],
        }
        Map.addLayer(fires, firesVis, 'Fires', False)
        
    return ee.Dictionary({'diff_image': diff_image, 'date_image': date_image, 
                          'nameBase': nameBase, 'pixelSize': pixelSize})

def process_classification(studyArea, dataset, zone, joined_years, date_1, date_2, date_2_2, 
               date_3, assetPath, transfer_tr, status='predicting', save_model=False, ref_year=2019, 
               visualize=False, include_dates=False, include_biomes=False, anomaly=None):
    
    input_dict = generate_images(studyArea, dataset, zone, date_1, date_2, 
                     date_2_2, date_3, visualize, include_dates, include_biomes)
    diff_image = input_dict.get('diff_image')
    nameBase, pixelSize = input_dict.get('nameBase'), input_dict.get('pixelSize')
    diff_mask = ee.Image(diff_image).select('NIR_diff').mask().eq(0)

    if status == 'training':
        classification = train(studyArea, dataset, zone, date_2_2, date_3, ee.Image(diff_image), 
                       pixelSize, assetPath, transfer_tr, save_model, joined_years, anomaly)
    elif status == 'predicting':
        classification = pred_pretrained(dataset, zone, joined_years, date_2_2, date_3, 
                                         ee.Image(diff_image), assetPath, ref_year, anomaly)
    else:
        print('Error: status has two options: "training" or "predicting"')
    classification = classification.set('date_image', input_dict.get('date_image')) \
                        .set('diff_image', diff_image) \
                        .set('nameBase', nameBase).set('pixelSize', pixelSize)
    return classification

def check_geometries(studyArea, burned, unburned):
    if type(studyArea) == ee.geometry.Geometry:
        if (studyArea.coordinates().size().getInfo()==0):
            print('Please define a study area')
    else:
        burned_tr = ee.FeatureCollection(ee.List(burned.geometries()).map(
          lambda geometry: ee.Feature(ee.Geometry(geometry)).set('class', 1))).filterBounds(studyArea)
        unburned = ee.Geometry.MultiPolygon(
                unburned.getInfo()['coordinates'] + unburnable.getInfo()['coordinates'])
        unburned_tr = ee.FeatureCollection(ee.List(unburned.geometries()).map(
          lambda geometry: ee.Feature(ee.Geometry(geometry)).set('class', 0))).filterBounds(studyArea)
        if (burned_tr.size().getInfo()==0):
            print('Please define some burned polygon(s)')
        elif (unburned_tr.size().getInfo()==0):
            print('Please define some unburned polygon(s)')
    return burned_tr, unburned_tr

def define_prepost(joined_years):
    if joined_years == 'None':
        date_pre = str(ref_year) + date_2_2[5:7] + date_2_2[8:10]
        date_post = str(ref_year) + date_3[5:7] + date_3[8:10]
    elif joined_years == 'Previous':
        date_pre = str(ref_year-1) + date_2_2[5:7] + date_2_2[8:10]
        date_post = str(ref_year) + date_3[5:7] + date_3[8:10]
    elif joined_years == 'Next':
        date_pre = str(ref_year) + date_2_2[5:7] + date_2_2[8:10]
        date_post = str(ref_year+1) + date_3[5:7] + date_3[8:10]
    return date_pre, date_post

def train(studyArea, dataset, zone, date_2_2, date_3, image, pixelSize, 
              assetPath, transfer_tr, save_model, joined_years, anomaly):
    if not transfer_tr:
        burned_tr, unburned_tr = check_geometries(studyArea, burned, unburned)
        training = image.sampleRegions(**{
        'collection': burned_tr.merge(unburned_tr),
        'properties': ['class'],
        'scale': pixelSize,
        'tileScale': 4
        })
    else:
        date_pre, date_post = define_prepost(joined_years)
        training = ee.FeatureCollection(assetPath + f'Training/Training_{zone}_{dataset}_{date_pre}_{date_post}')
        burned_tr = training.filter(ee.Filter.eq('class', 1))
        
    bands = image.bandNames()

    RF = ee.Classifier.smileRandomForest(**{'numberOfTrees':200, 'minLeafPopulation':10, 'maxNodes':450})
    RF = RF.setOutputMode('PROBABILITY')
    trained = RF.train(training, 'class', bands)
    classified = image.select(bands).classify(trained)
    if not transfer_tr:    
        thr_seed = ee.Number(classified.reduceRegions(
        reducer = ee.Reducer.mean(),
        collection = burned_tr,
        scale = pixelSize).aggregate_mean('mean'))
    else:
        thr_seed = burned_tr.select(bands).classify(trained).aggregate_mean('classification')
    print(thr_seed.getInfo())
    if save_model:
        date_pre = date_2_2[0:4] + date_2_2[5:7] + date_2_2[8:10]
        date_post = date_3[0:4] + date_3[5:7] + date_3[8:10]
        trees = ee.List(ee.Dictionary(trained.explain()).get('trees'))
        dummy = ee.Feature(None)
        col = ee.FeatureCollection(trees.map(lambda x: dummy.set('tree',x) \
          .set('thr_seed', thr_seed).setGeometry(ee.Geometry.Point([0, 0]))))
        naming = f'RF_{zone}_{dataset}_{date_pre}_{date_post}'
        if anomaly:
            naming = f'RF_{anomaly}_{zone}_{dataset}_{date_pre}_{date_post}'
        assetId = assetPath + 'Models/' + naming
        
        task = ee.batch.Export.table.toAsset(**{
          'collection': col,
          'description': naming,
          'assetId': assetId,
        })
        task.start()
    return ee.Dictionary({'classified': classified, 'trained': trained, 
                          'thr_seed': thr_seed})

def pred_pretrained(dataset, zone, joined_years, date_2_2, date_3, image, assetPath, ref_year, anomaly):
    date_pre, date_post = define_prepost(joined_years)   
    print(anomaly)
    naming = f'Models/RF_{zone}_{dataset}_{date_pre}_{date_post}'
    if anomaly:
        naming = f'Models/RF_{anomaly}_{zone}_{dataset}_{date_pre}_{date_post}'
    assetId = assetPath + naming
    featureCollection = ee.FeatureCollection(assetId)
    trees = featureCollection.aggregate_array('tree')
    classifier = ee.Classifier.decisionTreeEnsemble(trees).setOutputMode('raw')
    bands = image.bandNames()
    classified_ndarray = image.select(bands).classify(classifier)
    classified_1darray = classified_ndarray.arrayReduce(ee.Reducer.mean(), ee.List([0]))
    classified = ee.Image(classified_1darray.arrayProject([0]).arrayFlatten([['classification']]))
    thr_seed = featureCollection.aggregate_first('thr_seed')
    print(thr_seed.getInfo())
    return ee.Dictionary({'classified': classified, 'thr_seed': thr_seed})

def process_BA(BA_raster, diff_image, studyArea, zone, date_image, mask, 
               tile, thr_seed, BA_format, parameters, anomaly):
    BA_seed = BA_raster.gte(ee.Number(thr_seed))
    # BA_raster = BA_raster.gte(0.5)

    distance = BA_seed.focalMin(kernel=ee.Kernel.euclidean(20, units='meters'), iterations=1) \
        .distance(kernel=ee.Kernel.euclidean(2000, units='meters'), 
                  skipMasked=False).byte().rename('distance')
    
    if zone != "Siberia":
        if anomaly == "Inundations":
            dist_mask = distance.updateMask(BA_raster.gte(0.4) \
                    .Or(diff_image.select('NIR_diff').lte(-300) \
                      .And(diff_image.select('NBR2_post').gte(-600) \
                          .And(diff_image.select('NBR2_diff').gte(50)) \
                          .And(diff_image.select('Red_diff').gte(400)) \
                          .And(diff_image.select('NIR_post').gte(200)))))
        else:
            dist_mask = distance.updateMask(BA_raster.gte(0.4) \
                    .Or(diff_image.select('NIR_diff').lte(-300) \
                      .And(diff_image.select('NBR2_post').gte(-600) \
                          .And(diff_image.select('NBR2_diff').gte(50)) \
                          .And(diff_image.select('Red_diff').gte(400)))))
    else:                                        
        dist_mask = distance.updateMask(BA_raster.gte(0.4) \
                .Or(diff_image.select('NIR_diff').lte(-300) \
                  .And(diff_image.select('NBR2_post').gte(-600) \
                      .And(diff_image.select('NBR2_diff').gte(50)) \
                      .And(diff_image.select('Red_diff').gte(400)))) \
                .Or(diff_image.select('NBR2_post').gte(0) \
                  .And(diff_image.select('NBR2_diff').gte(1000)) \
                  .And(diff_image.select('Red_post').lte(300))))
    
    if BA_format == 'GeoTiff':
        BA_raster = BA_raster.gte(0.5).Or(dist_mask.gt(0))
        return BA_raster
    
    elif BA_format == 'SHP':
        BA_raster = BA_raster.gte(0.5).Or(dist_mask.unmask(0).gt(0))
        BA_vectors = BA_raster.addBands(BA_seed).reduceToVectors(**{
          'geometry': tile.geometry(),
          'crs': parameters['crs'], 
          'scale': parameters['pixelSize'], 
          'geometryType': 'polygon',
          'eightConnected': False,
          'reducer': ee.Reducer.sum(),
          'tileScale': 16,
          'maxPixels': 1e13,

        }) \
            .filter(ee.Filter.greaterThanOrEquals('label', 1)) \
            .filter(ee.Filter.greaterThanOrEquals('sum', 1)) \
            .filter(ee.Filter.contains(**{'leftValue': studyArea, 'rightField':'.geo'})) 

        return BA_vectors

## Visualization functions

In [None]:
def view_BA(studyArea, dataset, zone, joined_years, date_1, date_2, date_2_2, date_3, **kwargs):
    print(studyArea.getInfo())
    classification = process_classification(studyArea, dataset, zone, joined_years, 
                                            date_1, date_2, date_2_2, date_3, **kwargs)
    BA_raster = ee.Image(classification.get('classified'))
    diff_image = ee.Image(classification.get('diff_image'))
    dates = ee.Image(classification.get('date_image'))
    thr_seed = classification.get('thr_seed')
    BA_seed = BA_raster.gte(ee.Number(thr_seed))

    distance = BA_seed.focalMin(kernel=ee.Kernel.euclidean(20, units='meters'), iterations=1) \
        .distance(kernel=ee.Kernel.euclidean(2000, units='meters'), 
                  skipMasked=False).byte().rename('distance')
#     mask = BA_raster.gte(0.5)
    if zone != "Siberia":
        if kwargs['anomaly'] == "Inundations":
            dist_mask = distance.updateMask(BA_raster.gte(0.4) \
                    .bitwiseOr(diff_image.select('NIR_diff').lte(-300) \
                      .bitwiseAnd(diff_image.select('NBR2_post').gte(-600) \
                          .bitwiseAnd(diff_image.select('NBR2_diff').gte(50)) \
                          .bitwiseAnd(diff_image.select('Red_diff').gte(400)) \
                          .bitwiseAnd(diff_image.select('NIR_post').gte(200)))))
        else:
            dist_mask = distance.updateMask(BA_raster.gte(0.4) \
                    .bitwiseOr(diff_image.select('NIR_diff').lte(-300) \
                      .bitwiseAnd(diff_image.select('NBR2_post').gte(-600) \
                          .bitwiseAnd(diff_image.select('NBR2_diff').gte(50)) \
                          .bitwiseAnd(diff_image.select('Red_diff').gte(400)))))
    else:                                        
        dist_mask = distance.updateMask(BA_raster.gte(0.4) \
                .bitwiseOr(diff_image.select('NIR_diff').lte(-300) \
                  .bitwiseAnd(diff_image.select('NBR2_post').gte(-600) \
                      .bitwiseAnd(diff_image.select('NBR2_diff').gte(50)) \
                      .bitwiseAnd(diff_image.select('Red_diff').gte(400)))) \
                .bitwiseOr(diff_image.select('NBR2_post').gte(0) \
                  .bitwiseAnd(diff_image.select('NBR2_diff').gte(1000)) \
                  .bitwiseAnd(diff_image.select('Red_post').lte(300))))
    
    mask = BA_raster.gte(0.5).Or(dist_mask)
    # Map.addLayer(mask, {}, 'Mask BA')
    BA_raster = BA_raster.updateMask(mask)
    Map.centerObject(studyArea)
    Map.addLayer(ee.Image(classification.get('classified')), 
        {'min':0.5, 'max':1, 'palette':['#ffff00', '#ffffff', '#0000ff']}, 'Probability', False)
    Map.addLayer(BA_raster, {'min':0.5, 'max':1, 'palette':['#ffffff', '#0000ff']}, 'BA', False)
    Map.addLayer(BA_seed, {'min':0, 'max':1, 'palette':['#ffffff', '#ff0000']}, 'Seeds', False)
    Map.addLayer(dates, {'min':1, 'max':366, 'palette':['#ffff00', '#ff0000']}, 'Dates', False)
    

## Export functions

In [None]:
import oauth2client
import httplib2 
from oauth2client import client, tools, file
from httplib2 import Http
from apiclient.discovery import build
import io

def avoid_overwrite(filename, case='Drive', folderId=None, local_path=None, user='firecci', 
                    cred_path=None, pattern='', export_folder=None):
    pattern = re.compile(pattern)
    if case == 'Drive':
        if not cred_path:
            cred_path = os.getcwd()
        credentials_file_path = f'{cred_path}/credentials_{user}.json'
        clientsecret_file_path = f'{cred_path}/client_secret_{user}.json'

        # define API scope
        SCOPE = 'https://www.googleapis.com/auth/drive'
        # define store
        store = oauth2client.file.Storage(credentials_file_path)
        credentials = store.get()
        # get access token
        if not credentials or credentials.invalid:
            flow = client.flow_from_clientsecrets(clientsecret_file_path, SCOPE)
            credentials = tools.run_flow(flow, store)

    #     print(credentials.client_id)
        http = httplib2.Http()
        http = credentials.authorize(http)
        service = build('drive', 'v3', http=http)
        folder_id = folderId
        query = f"parents = '{folder_id}'"
        response = service.files().list(corpora='allDrives', q=query, orderBy='name',
                                includeItemsFromAllDrives=True, supportsAllDrives=True).execute()
        files = response.get('files')
        nextPageToken = response.get('nextPageToken')

        while nextPageToken:
            response = service.files().list(corpora='allDrives', q=query, orderBy='name', pageToken=nextPageToken,
                                includeItemsFromAllDrives=True, supportsAllDrives=True).execute()
            files.extend(response.get('files'))
            nextPageToken = response.get('nextPageToken')

        exists = False
        for f in files:
            name_drive = f.get('name')
            if pattern.search(name_drive):
                if filename in name_drive:
                    print('already exists') 
                    exists = True
                    break 
    
    elif case == 'Local':
        files = glob.glob(f'{local_path}/*', recursive=True)
        files = [os.path.basename(os.path.splitext(f)[0]) for f in files if pattern.search(f)]
        exists = False
        if filename in files:
            print('already exists') 
            exists = True
            
    return exists        

In [83]:
def download_BA(studyArea, dataset, zone, joined_years, list_grids, date_1, date_2, 
                date_2_2, date_3, kwargs, kwargs_class, kwargs_drive):
    
    if 'TILE' in str(studyArea.getInfo()):
        list_tiles = studyArea
        studyArea = studyArea.geometry()
        print('Tile study area is used')
    else:
        ## 6000 m ~ 0.15 deg in Lat 0, and ~ 0.5 deg in Lat 75N/S
        list_tiles = locate_tiles(studyArea.dissolve().buffer(-15000), list_grids, list_tile_errors)
        print('Zone study area is used')
    
    classification = process_classification(studyArea, dataset, zone, joined_years, 
                                            date_1, date_2, date_2_2, date_3, **kwargs_class)
    classified = ee.Image(classification.get('classified'))
    date_image = ee.Image(classification.get('date_image'))
    diff_image = ee.Image(classification.get('diff_image'))
    mask = diff_image.select('NIR_diff').mask().eq(0)
    thr_seed = classification.get('thr_seed')
     
    number_tiles = list_tiles.size().getInfo()
    parameters = define_parameters(dataset, zone)
    nameBase = ee.String(parameters.get('nameBase'))
    pixelSize = parameters.get('pixelSize').getInfo()
    tasks_BA, tasks_prob, tasks_dates = {}, {}, {}

    for i in range(number_tiles):
        tile = ee.Feature(list_tiles.toList(1, i).get(0)) 
        tile = tile.setGeometry(tile.geometry().intersection(studyArea))
        tilename = tile.get('TILE').getInfo()
        print(f'initialize {tilename}')
        drop = kwargs['drop']
        if np.array([re.compile(i).search(tilename) for i in drop]).any():
            print(f'{tilename} is dropped')
            continue
            
        if kwargs['crs'] == 'UTM':
            ## Projected to UTM zones
            parameters = {'crs':tile.get('PROJ').getInfo(), 
                      'pixelSize':pixelSize}
        elif kwargs['crs'] == 'WGS':
            ## Projected to WGS84
            parameters = {'crs':'EPSG:4326', 
                          'pixelSize':pixelSize}
        date_pre = date_2_2[0:4] + date_2_2[5:7] + date_2_2[8:10]
        date_post = date_3[0:4] + date_3[5:7] + date_3[8:10]
        name = nameBase.getInfo() + '_' + date_pre + '-' + date_post + tile.get('TILE').getInfo()
        if '_TILE' not in name:
            name = nameBase.getInfo() + '_' + date_pre + '-' + date_post + '_TILE-' + tile.get('TILE').getInfo()
        exp_prob, exp_BA, exp_dates = kwargs['exp_prob'], kwargs['exp_BA'], kwargs['exp_dates']
        BA_format = kwargs['BA_format']
        export_folder = kwargs_drive['export_folder']
        anomaly = kwargs_class['anomaly']
        if exp_prob:
            status = False
            if kwargs['no_overwrite']:
                status = avoid_overwrite( name + '_PROB', case=kwargs['no_overwrite'], **kwargs_drive)
            if not (status and kwargs['no_overwrite']):
                print(f'Probability export of {tilename}')
                BA_raster = classified.clip(tile.geometry()) \
                            .unmask(0).add(mask.multiply(2)).multiply(100).round().byte()
                tasks_prob[i] = ee.batch.Export.image.toDrive(**{'image': BA_raster, 
                                      'description': name + '_PROB', 
                                      'folder': export_folder, 
                                      'crs': parameters['crs'], 
                                      'scale': parameters['pixelSize'],
                                      'region': tile.geometry(), 
                                      'maxPixels': 1e13})
                tasks_prob[i].start()
        ## to create the new folder
        time.sleep(3)
        
        if exp_dates:
            status = False
            if kwargs['no_overwrite']:
                status = avoid_overwrite(name + '_DATES', case=kwargs['no_overwrite'], **kwargs_drive)
            if not (status and kwargs['no_overwrite']):        
                print(f'Dates export of {tilename}')
                dates_image = date_image.clip(tile.geometry())
                tasks_dates[i] = ee.batch.Export.image.toDrive(**{'image': dates_image, 
                                      'description': name + '_DATES', 
                                      'folder': export_folder, 
                                      'crs': parameters['crs'], 
                                      'scale': parameters['pixelSize'],
                                      'region': tile.geometry(), 
                                      'maxPixels': 1e13})
                tasks_dates[i].start()

        if exp_BA:
            status = False
            if kwargs['no_overwrite']:
                status = avoid_overwrite( name + '_SHP', case=kwargs['no_overwrite'], **kwargs_drive)
            if not (status and kwargs['no_overwrite']):
                print(f'BA export of {tilename} as {BA_format}')
                if BA_format == 'SHP':
                    BA_raster = classified.clip(tile.geometry())
                    BA_vectors = process_BA(BA_raster, diff_image, studyArea, zone, date_image, 
                                            mask, tile, thr_seed, BA_format, parameters, anomaly)
                    tasks_BA[i] = ee.batch.Export.table.toDrive(**{'collection': BA_vectors, 
                                                   'description': name + '_SHP', 
                                                   'folder': export_folder, 
                                                   'fileFormat': 'SHP',
                                                                  })
                elif BA_format == 'GeoTiff':
                    BA_raster = classified.clip(tile.geometry())
                    BA_raster = process_BA(BA_raster, diff_image, studyArea, zone, date_image, 
                                            mask, tile, thr_seed, BA_format, parameters, anomaly)
                    tasks_BA[i] = ee.batch.Export.image.toDrive(**{'image': BA_raster, 
                                          'description': name + '_BA', 
                                          'folder': export_folder, 
                                          'crs': parameters['crs'], 
                                          'scale': parameters['pixelSize'],
                                          'region': tile.geometry(), 
                                          'maxPixels': 1e13})                    
                tasks_BA[i].start()


# Execution of BAMT Algorithm

## Initializing variables 
<div class="alert alert-block alert-info">
    <p><b style="font-size:17px">Main input variables: <br></b>
        <ul style="color: black; font-size:17px">
            <li>Three years are to be specified: <b>ref_year</b> is the year in which 
                the pre-trained model has been generated, while <b>start_year</b> 
                and <b>end_year</b> delimitate the window of processing. If only, one year 
                is to be processed the they will have the same value as in the example below.</li>
            <li>Please define these three dates for the two periods. The pre-fire composite 
                image will be produced from data between <b>date_1</b> (inclusive) and 
                <b>date_2</b> (exclusive), the post-fire image with data from <b>date_2_2</b> 
                (inclusive) and <b>date_3</b> (exclusive).</li>
            <li>In the <b>studyArea</b> you need to choose you area of interest (e.g. sahel).</li>
            <li> Choose the dataset to be used, for which accepted values are:
                <ul style="list-style-type:square;">
                    <li>'Landsat': all available images from Landsat-4 TM, Landsat-5 TM, 
                        Landsat-7 ETM+, Landsat-8 OLI and Landsat-9 OLI-2 sensors.</li>
                    <li>'Sentinel2_SR': all available images from the Sentinel-2 MSI sensor at
                        level 2A.</li>
                    <li>'Sentinel2_TOA': all available images from the Sentinel-2 MSI sensor at
                        level 1C, where cloud mask uses S2CLOUDLESS classification.</li>
                </ul>
            <li>The <b>identifier</b> the processed zone naming.</li>
            <li>The <b>assetPath</b> is the path of your model assets are saved.</li>
            <li> Choose the coordinate system <b>crs</b> to be used for outputs. The accepted values are:
                <ul style="list-style-type:square;">
                    <li>'WGS': WGS84 (EPSG:4326).</li>
                    <li>'UTM': In case you want to project to one of the regions of UTM. 
                    The algorithm projects to the corresponding region automatically.</li>
                </ul>
            <li>Please list here the tiles that could not be downloaded because 'User memory 
                limit exceeded'. There are three levels of tiles splitting: 2d tiles that lead to 
                error are listed in <b>UMLError_tiles_1d</b> the if you need deeper splitting 
                you can use <b>UMLError_tiles_05d</b> or even <b>UMLError_tiles_025d</b></li>
        </ul> 
</div>    

<div class="alert alert-block alert-warning" >
<p><b style="color: black; font-size:17px">If you include <span style="color:red;">UMLError_tiles</span> 
    here, it means that all the processing will be executed and the concerned tile will be 
    splitted. If you are interested in repeataing only the broken tiles, use the 3rd 
    sub-section of export section. <br></b>
        
</div>
    

In [None]:
def define_params(year, zone, dataset, period=None):
    if zone == 'Amazonia':
        studyArea = amazonia.geometry()
        if period == 'Period 0':  
            ## Yearly
            date_1 = f'{year-1}-04-01'
            date_2 = f'{year}-04-01'
            date_2_2 = f'{year}-04-01'
            date_3 = f'{year+1}-04-01'
            joined_years = 'Next'
            burned = GeocolToMultipol(Training_geometries_Amazonia_20040401_20050401.burned)
            unburned = GeocolToMultipol(Training_geometries_Amazonia_20040401_20050401.unburned)
            unburnable = GeocolToMultipol(Training_geometries_Amazonia_20040401_20050401.unburnable)
        elif period == 'Period 1': 
            ## Period 1
            date_1 = f'{year-1}-10-01'
            date_2 = f'{year}-04-01'
            date_2_2 = f'{year}-04-01'
            date_3 = f'{year}-10-01'
            joined_years = 'None'
            burned = GeocolToMultipol(Training_geometries_Amazonia_20190401_20191001.burned)
            unburned = GeocolToMultipol(Training_geometries_Amazonia_20190401_20191001.unburned)
            unburnable = GeocolToMultipol(Training_geometries_Amazonia_20190401_20191001.unburnable)
        elif period == 'Period 2':
            ## Period 2
            date_1 = f'{year-1}-04-01'
            date_2 = f'{year-1}-10-01'
            date_2_2 = f'{year-1}-10-01'
            date_3 = f'{year}-04-01'
            joined_years = 'Previous'
            burned = GeocolToMultipol(Training_geometries_Amazonia_20181001_20190401.burned)
            unburned = GeocolToMultipol(Training_geometries_Amazonia_20181001_20190401.unburned)
            unburnable = GeocolToMultipol(Training_geometries_Amazonia_20181001_20190401.unburnable)
    
    elif zone == 'Siberia':
        studyArea = siberia.geometry()
        ## Yearly
        date_1 = f'{year-1}-03-01'
        date_2 = f'{year-1}-12-01'
        date_2_2 = f'{year}-03-01'
        date_3 = f'{year}-12-01'
        joined_years = 'None' 
        burned = GeocolToMultipol(Training_geometries_Siberia.burned)
        unburned = GeocolToMultipol(Training_geometries_Siberia.unburned)
        unburnable = GeocolToMultipol(Training_geometries_Siberia.unburnable)  
        
    elif zone == 'Sahel':
        studyArea = sahel.geometry()
        if period == 'Period 1':
            ## Period 1
            date_1 = f'{year-1}-04-01'
            date_2 = f'{year}-01-01'
            date_2_2 = f'{year}-01-01'
            date_3 = f'{year}-04-01'
            joined_years = 'None'
            burned = None
            unburned = None
            unburnable = None
        elif period == 'Period 2': 
            ## Period 2
            date_1 = f'{year-1}-10-01'
            date_2 = f'{year}-04-01'
            date_2_2 = f'{year}-04-01'
            date_3 = f'{year}-10-01'
            joined_years = 'None'
            burned = None
            unburned = None
            unburnable = None
        elif period == 'Period 3':
            ## Period 3
            date_1 = f'{year-1}-03-01'
            date_2 = f'{year-1}-10-01'
            date_2_2 = f'{year-1}-10-01'
            date_3 = f'{year}-01-01'
            joined_years = 'Previous'
            burned = None
            unburned = None
            unburnable = None
    
    return studyArea, date_1, date_2, date_2_2, date_3, joined_years, burned, unburned, unburnable

def split_zone(studyArea, grid, splits=2):
    bounds = studyArea.bounds()
    ## Split the bounding box into four equal-sized rectangles
    xmin = ee.List(ee.List(bounds.coordinates().get(0)).get(0)).get(0).getInfo()
    ymin = ee.List(ee.List(bounds.coordinates().get(0)).get(0)).get(1).getInfo()
    xmax = ee.List(ee.List(bounds.coordinates().get(0)).get(2)).get(0).getInfo()
    ymax = ee.List(ee.List(bounds.coordinates().get(0)).get(2)).get(1).getInfo()

    rectangle1 = ee.Geometry.BBox(xmin, ymin, (xmin + xmax) / 2, (ymin + ymax) / 2);
    rectangle2 = ee.Geometry.BBox((xmin + xmax) / 2, ymin, xmax, (ymin + ymax) / 2);
    rectangle3 = ee.Geometry.BBox(xmin, (ymin + ymax) / 2, (xmin + xmax) / 2, ymax);
    rectangle4 = ee.Geometry.BBox((xmin + xmax) / 2, (ymin + ymax) / 2, xmax, ymax);
    splitGeometry = ee.FeatureCollection([rectangle1, rectangle2, rectangle3, rectangle4])

    params = ee.List([ee.List([]), ee.List([])])
    
    def get_rectangles(i, parameters):
        already_selected = ee.List(parameters).get(0)
        zone_list = ee.List(parameters).get(1)
#         xmin, ymin, xmax, ymax = ee.List(parameters).get(2).getInfo()
#         for j in range(splits**2):
        new_rect = tiles_2d.filterBounds(ee.Feature(splitGeometry.toList(4).get(i)).geometry()) \
                                .filter(ee.Filter.inList('TILE', already_selected).Not())
        already_selected = ee.List(already_selected).cat(new_rect.aggregate_array('TILE'))
        zone_list = ee.List(zone_list).add(new_rect)
        return ee.List([already_selected, zone_list])
    zones = ee.List.sequence(0, splits**2-1, 1).iterate(get_rectangles, params) 
        
    return ee.List(ee.List(zones).get(1))


In [None]:
ref_year = 2019
year = 2019

dataset = 'Landsat'
zone = 'Amazonia'
period = 'Period 1'
assetPath = 'users/fireccihrba/BAMT/'
# assetPath = 'projects/ee-aminkhairoun/assets/FireCCIHR/'
crs = 'WGS'
UMLError_tiles_1d = []
# In case some 1 degree tiles cannot be downloaded yet, list them here:
UMLError_tiles_05d = []
# And if some 0.5 degree tiles cannot be downloaded yet, list them here:
UMLError_tiles_025d = []
list_grids = [tiles_2d, tiles_1d, tiles_05d, tiles_025d]
list_tile_errors = [UMLError_tiles_1d, UMLError_tiles_05d, UMLError_tiles_025d]
studyArea, date_1, date_2, date_2_2, date_3, joined_years, burned, unburned, unburnable = define_params(year, zone, 
                                                                                 dataset, period)
print(date_1, date_2, date_2_2, date_3, joined_years)
print(zone, studyArea.bounds().getInfo())

2018-10-01 2019-04-01 2019-04-01 2019-10-01 None
Amazonia {'geodesic': False, 'type': 'Polygon', 'coordinates': [[[-62.00000101451595, -24.00000393003831], [-46.999999760070345, -24.00000393003831], [-46.999999760070345, -11.999998281877398], [-62.00000101451595, -11.999998281877398], [-62.00000101451595, -24.00000393003831]]]}


## Visualize results
<div class="alert alert-block alert-warning">
    <p><b style="color: black; font-size:17px">Visualization is possible in Python GEE API. However we recommend to use GEE code editor for that purpose as it is more convenient. Therefore, the same thing applies for training, we recommend to train and save the pre-trained model in GEE code editor and use the python code to automatize predictions.<br></b>
</div>

In [64]:
kwargs = dict(assetPath=assetPath, ref_year=ref_year, visualize=True, transfer_tr=True,
      status='training', save_model=False, include_dates=True, include_biomes=False,
              anomaly=None,
            #   anomaly='Inundations'
             )
a = '20S058W'
# a = 'TILE-50N099E'
# a = '14N032E'
geom = tiles_2d.filter(ee.Filter.eq('TILE', a))
# geom = tiles_amazonia_4d.filter(ee.Filter.eq('TILE', a))
Map = geemap.Map()
view_BA(geom, dataset, zone, joined_years, date_1, date_2, date_2_2, date_3, **kwargs)
Map

{'type': 'FeatureCollection', 'columns': {'PROJ': 'String', 'TILE': 'String', 'system:index': 'String'}, 'version': 1646414399963967, 'id': 'users/ekhiroteta/BAMT/BAMT_GEE_downloadableTiles_2d', 'properties': {'system:asset_size': 12446137}, 'features': [{'type': 'Feature', 'geometry': {'type': 'Polygon', 'coordinates': [[[-58.050000006347716, -22.050000544585505], [-57.98437497159881, -22.050000505959073], [-57.91874992945067, -22.050000525632218], [-57.85312498235496, -22.05000051576474], [-57.78749997335406, -22.050000573744672], [-57.72187491953429, -22.05000051140126], [-57.65624992502229, -22.050000562830235], [-57.59062492144783, -22.050000520364712], [-57.52499992759351, -22.050000517572215], [-57.45937487649837, -22.0500005072349], [-57.393749874515514, -22.050000560583005], [-57.328124855946136, -22.05000054923127], [-57.2624998420428, -22.050000524584714], [-57.19687485464257, -22.05000053761283], [-57.13124983013115, -22.05000053920607], [-57.0656248776739, -22.050000518605

Map(center=[0, 0], controls=(WidgetControl(options=['position', 'transparent_bg'], widget=SearchDataGUI(childr…

## Export results

### 1. Main export loop
<div class="alert alert-block alert-info">
    <p style="color: black; font-size:17px">This is the main execution loop. 
        Specify the month and date of <b>date_1</b>,  <b>date_2</b>, <b>date_2_2</b>
        and <b>date_3</b> so that they can change dynamically 
        with the processing year. The algorithms has three outputs that you 
        can either export (<b style="color:#199448">True</b>) or reject 
        (<b style="color:#199448">False</b>): </p>
            <ul style="list-style-type:square;color: black; font-size:17px">
                <li><b>exp_prob</b>: In case you want to export a GeoTIFF burn probablity, set 
                    it to <b style="color:#199448">True</b>, otherwise 
                    <b style="color:#199448">False</b>.</li>
                <li><b>exp_BA</b>: In case you want to export burn area shapefiles, set 
                    it to <b style="color:#199448">True</b>, otherwise 
                    <b style="color:#199448">False</b>.</li>
                <li><b>exp_dates</b>: In case you want to export a GeoTIFF burn dates, set 
                    it to <b style="color:#199448">True</b>, otherwise 
                    <b style="color:#199448">False</b>.</li>
            </ul>
</div>

In [None]:
ref_year = 2019
start_year = 1999
end_year = 1990
splits = 2
# Map = geemap.Map()
for year in np.arange(start_year, end_year-1, -1):
    drop = []
    time.sleep(1)
    kwargs = dict(crs=crs, exp_prob=True, exp_BA=True, exp_dates=True, 
                  drop=drop, no_overwrite=None, BA_format='GeoTiff')
    kwargs_class = dict(assetPath=assetPath, ref_year=ref_year, visualize=False, transfer_tr=True,
        status='training', save_model=False, include_dates=True, include_biomes=False, anomaly=None)
    kwargs_drive = dict(folderId='1W2CB7z1pGRgwiwnzbqQRWDY9buoHPQ23', 
                  user='firecci', cred_path=None, pattern='', export_folder='Sahel')
#     kwargs_drive = dict(local_path=f'{local_path}/{zone}/ByTile/**/GEE/**',
#             pattern='')

    studyArea, date_1, date_2, date_2_2, date_3, joined_years, burned, unburned, unburnable = define_params(year, zone, dataset, period)
    all_tiles = locate_tiles(studyArea.dissolve().buffer(-15000), 
                     list_grids, list_tile_errors).aggregate_array('TILE').getInfo()                                                                                    
    print(f"Period {period[-1]} of the year {year} is being processed: {len(all_tiles)} tiles")
    print(zone, date_1, date_2, date_2_2, date_3, joined_years)
    
    for i in range(splits**2):
        geom = ee.FeatureCollection(split_zone(studyArea, tiles_2d, splits).get(i)).geometry().intersection(studyArea)                      
        list_tiles = locate_tiles(geom.dissolve().buffer(-15000), list_grids, list_tile_errors).aggregate_array('TILE').getInfo()
        print(f'Split: {i+1}. The number of tiles to be processed: {len(list_tiles)}')
        download_BA(geom, dataset, zone, joined_years, list_grids, 
                    date_1, date_2, date_2_2, date_3, kwargs, kwargs_class, kwargs_drive)
        
print('Done !')

## ABoVE

In [None]:
local_path = '/media/amin/STORAGE/STORAGE/OneDrive/PhD/Landsat_BA'
grid2d = gpd.read_file(f'{local_path}/Regions/BAMT_GEE_downloadableTiles_2d.shp')
locations = pd.read_csv(f'{local_path}/Ancillary/ABoVE/Aggregated_data.csv')
selection_ABoVE = locations.loc[locations.longitude < 0]
selection_ABoVE = gpd.GeoDataFrame(
    selection_ABoVE, 
    geometry=gpd.points_from_xy(selection_ABoVE.longitude, selection_ABoVE.latitude), crs="EPSG:4326")
selection_ABoVE['geometry'] = selection_ABoVE['geometry'].buffer(1)



  selection_ABoVE['geometry'] = selection_ABoVE['geometry'].buffer(1)


In [4]:
clipped_2d = gpd.overlay(grid2d, gpd.GeoDataFrame(geometry=[box(-153, 52, -88, 68)], crs="EPSG:4326"), 
                         how='intersection')
clipped_2d

Unnamed: 0,PROJ,TILE,geometry
0,EPSG:32605,68N156W,"POLYGON ((-152.95000 65.95000, -152.99844 65.9..."
1,EPSG:32605,68N153W,"POLYGON ((-149.95000 65.95000, -149.99844 65.9..."
2,EPSG:32605,70N156W,"POLYGON ((-152.95000 67.95000, -152.99844 67.9..."
3,EPSG:32605,70N153W,"POLYGON ((-149.95000 67.95000, -149.99844 67.9..."
4,EPSG:32605,52N156W,"POLYGON ((-152.99844 52.05000, -152.95000 52.0..."
...,...,...,...
225,EPSG:32616,58N090W,"POLYGON ((-90.05000 58.05000, -90.00156 58.050..."
226,EPSG:32616,60N090W,"POLYGON ((-90.05000 60.05000, -90.00156 60.050..."
227,EPSG:32616,62N090W,"POLYGON ((-90.05000 62.05000, -90.00156 62.050..."
228,EPSG:32616,64N090W,"POLYGON ((-90.05000 64.05000, -90.00156 64.050..."


In [5]:
BAgdf = gpd.GeoDataFrame()
for i in clipped_2d.index[:]:
#     print(clipped_1d.loc[[i]].geometry)
    new = gpd.overlay(clipped_2d.loc[[i]], selection_ABoVE, how='intersection')
    if len(new) > 0:
        print(clipped_2d.loc[i, "TILE"])
        gdf = pd.concat([gdf, new])
gdf = gdf.reset_index().drop(columns=['index'])

68N153W
70N153W
64N153W
66N153W
68N150W
68N147W
70N150W
64N150W
64N147W
66N150W
66N147W
68N144W
62N144W
62N141W
64N144W
64N141W
66N144W
66N141W
64N138W
66N138W
62N123W
60N120W
60N117W
62N120W
62N117W
64N120W
64N117W
66N120W
66N117W
54N111W
56N114W
56N111W
58N114W
58N111W
60N114W
60N111W
62N114W
64N114W
66N114W
54N108W
54N105W
56N108W
56N105W
58N108W
58N105W
56N102W
56N099W
58N102W
58N099W
54N090W
56N090W


In [8]:
unique_id = np.unique([gdf.loc[i, 'TILE'] + "-" + str(gdf.loc[i, 'burn_year']) for i in gdf.index])
selected_tiles = pd.DataFrame({'year': [int(i[-4:]) for i in unique_id], 
                               'tile': [i[:-5] for i in unique_id]})
selected_tiles

Unnamed: 0,year,tile
0,2003,54N090W
1,2003,54N105W
2,2015,54N105W
3,2003,54N108W
4,2015,54N108W
...,...,...
90,2003,68N153W
91,2004,68N153W
92,2005,68N153W
93,2005,70N150W


In [None]:
''' These tiles needs to be rectified visually to eliminate the ones not needed'''
geoms = [grid2d.loc[grid2d.TILE == i].geometry.values[0] for i in selected_tiles.tile]
Tiles_ABoVE = gpd.GeoDataFrame(selected_tiles, geometry=geoms)
Tiles_ABoVE.to_file(f'{local_path}/Regions/ABoVE_Tiles1.shp')
Tiles_ABoVE

Unnamed: 0,year,tile,geometry
0,2003,54N090W,"POLYGON ((-90.05000 54.05000, -90.00156 54.050..."
1,2003,54N105W,"POLYGON ((-105.05000 54.05000, -105.00156 54.0..."
2,2015,54N105W,"POLYGON ((-105.05000 54.05000, -105.00156 54.0..."
3,2003,54N108W,"POLYGON ((-108.05000 54.05000, -108.00156 54.0..."
4,2015,54N108W,"POLYGON ((-108.05000 54.05000, -108.00156 54.0..."
...,...,...,...
90,2003,68N153W,"POLYGON ((-149.95000 65.95000, -149.99844 65.9..."
91,2004,68N153W,"POLYGON ((-149.95000 65.95000, -149.99844 65.9..."
92,2005,68N153W,"POLYGON ((-149.95000 65.95000, -149.99844 65.9..."
93,2005,70N150W,"POLYGON ((-146.95000 67.95000, -146.99844 67.9..."


In [105]:
Tiles_ABoVE = gpd.read_file(f'{local_path}/Regions/ABoVE_Tiles1.shp')
Tiles_ABoVE['PROJ'] = [grid2d.loc[grid2d.TILE == i].PROJ.values[0] for i in Tiles_ABoVE.tile]
Tiles_ABoVE['geometry'] = [grid2d.loc[grid2d.TILE == i].geometry.values[0] for i in Tiles_ABoVE.tile]
Tiles_ABoVE['geometry'] = [polygon.Polygon(np.round(np.column_stack(a.simplify(0.1).exterior.xy))) 
                           for a in Tiles_ABoVE['geometry']]
Tiles_ABoVE.rename(columns={'tile': 'TILE'}, inplace=True)
Tiles_ABoVE.set_crs(epsg=4326, inplace=True)
Tiles_ABoVE.to_file(f'{local_path}/Regions/Tiles_ABoVE.shp')

In [None]:
ref_year = 2019

kwargs = dict(crs=crs, exp_prob=True, exp_BA=True, exp_dates=True, 
              drop=[], no_overwrite=None, BA_format='GeoTiff')
kwargs_class = dict(assetPath=assetPath, ref_year=ref_year, visualize=False, transfer_tr=True, 
    status='predicting', save_model=True, include_dates=True, include_biomes=False, anomaly=None)
kwargs_drive = dict(folderId='1W2CB7z1pGRgwiwnzbqQRWDY9buoHPQ23', 
              user='firecci', cred_path=None, pattern='', export_folder='ABoVE')


missing = selected_tiles.tile.values
for t in missing[:]:
    tile = tiles_2d.filter(ee.Filter.eq('TILE', t))
    Years = selected_tiles.loc[selected_tiles.tile == t, 'year'].values
    for y in Years.copy():
        Years.append(y+1)
    Years = np.unique(Years)
    print(t, Years)
    for y in Years:
        studyArea, date_1, date_2, date_2_2, date_3, joined_years, burned, unburned, unburnable = define_params(year, zone, 
                                                                                         dataset, period)
        print(date_1, date_2, date_2_2, date_3)
        download_BA(tile, dataset, zone, joined_years, list_grids, 
                    date_1, date_2, date_2_2, date_3, kwargs, kwargs_class, kwargs_drive) 

54N108W [2015]
2014-03-01 2014-12-01 2015-03-01 2015-12-01
None
0.8788902153501179
initialize 54N108W
Probability export of 54N108W
Dates export of 54N108W
BA export of 54N108W as GeoTiff
2015-03-01 2015-12-01 2016-03-01 2016-12-01
None
0.8788902153501179
initialize 54N108W
Probability export of 54N108W
Dates export of 54N108W
BA export of 54N108W as GeoTiff
54N111W [2003]
2002-03-01 2002-12-01 2003-03-01 2003-12-01
None
0.8788902153501179
initialize 54N111W
Probability export of 54N111W
Dates export of 54N111W
BA export of 54N111W as GeoTiff
2003-03-01 2003-12-01 2004-03-01 2004-12-01
None
0.8788902153501179
initialize 54N111W
Probability export of 54N111W
Dates export of 54N111W
BA export of 54N111W as GeoTiff
56N090W [2003]
2002-03-01 2002-12-01 2003-03-01 2003-12-01
None
0.8788902153501179
initialize 56N090W
Probability export of 56N090W
Dates export of 56N090W
BA export of 56N090W as GeoTiff
2003-03-01 2003-12-01 2004-03-01 2004-12-01
None
0.8788902153501179
initialize 56N090W
Pro