In [1]:
#This (still in progress) code is (mostly) a straightforward JS-to-Python translation of Chen et al.'s work.
#Some variables/methods early on may differ from the original, but later sections are essentially a copy/paste of it.
#This file, as the name implies, is only for testing. It will be deleted in the future.

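#Chen, Y., Cao, R.Y., Chen, J., Liu, L.C., and Matsushita, B. (2021). A practical approach to reconstruct high-quality Landsat NDVI time-series data by gap filling and the Savitzky-Golay filter. ISPRS Journal of Photogrammetry and Remote Sensing, 180, 174-190.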
#https://github.com/ChenY04/GEE


#imports, global consts, inits
import ee
import geemap

# Initialize Earth Engine against the project that owns the assets below, and create the interactive map.
# NOTE(review): `map` shadows Python's built-in map(); later cells only reference it in commented-out code.
ee.Initialize(project='seamproj01')
map = geemap.Map()

# Visualization parameters for single-band NDVI layers (stretched over [-1, 1]).
NDVI_VIZ_PARAMS = {"min" : -1, "max": 1, "palette" : ["blue", "white", "green"]}

sudanStateBorders = ee.FeatureCollection("projects/seamproj01/assets/SudanStateBorders") #Shapefiles for Sudan administrative borders, via OCHA HDX
#sudanCroplandMask = ee.FeatureCollection("projects/seamproj01/assets/Sudan_Cropland_Mask_CopernicusLCLU2019") #shapefile generated from Copernicus Moderate Dynamic Land Cover dataset.
testArea =  ee.FeatureCollection("projects/seamproj01/assets/test_area").geometry()

# State geometries, selected by the "ADM1_EN" attribute of the borders collection.
khartoum = sudanStateBorders.filter(ee.Filter.inList("ADM1_EN", rightValue=["Khartoum"])).geometry()
gezira = sudanStateBorders.filter(ee.Filter.inList("ADM1_EN", rightValue=["Aj Jazirah"])).geometry()

# khartoumCropland = sudanCroplandMask.filter(ee.Filter.inList("ADM1_EN", rightValue=["Khartoum"])).geometry()
# geziraCropland = sudanCroplandMask.filter(ee.Filter.inList("ADM1_EN", rightValue=["Aj Jazirah"])).geometry()

# Cropland-mask geometries loaded from project assets (contents as suggested by the asset names; not verifiable here).
khartoumCroplandModified = ee.FeatureCollection("projects/seamproj01/assets/khartoum_cropmask_gt_1km2").geometry()
geziraCroplandModified = ee.FeatureCollection("projects/seamproj01/assets/gezira_cropmask_western_bank").geometry()


# Savitzky-Golay smoothing kernels (weights applied to a centred window of images, see SG_Filter).
# Each tuple is (nl, nr, ld, m): nl/nr = points left/right of the centre, ld = derivative order (0 = smoothing),
# m = polynomial degree. The kernel length is nl + nr + 1 (13, 9, 9, 9 and 7 below).
#Parameter reference 1 (SG filtering kernel): (6,6,0,2) -> 13-point window, degree 2. (-1.4901161e-008 is float round-off for 0.)
SG_Coeff1 = ee.List([-0.076923102,-1.4901161e-008,0.062937059,0.11188812,0.14685316,0.16783218, 0.17482519,0.16783218,0.14685316,0.11188812,0.062937059,-1.4901161e-008,-0.076923102]) 
#Parameter reference 2: (4,4,0,3) -> 9-point window, degree 3
SG_Coeff2 = ee.List([-0.090909064,0.060606077,0.16883118,0.23376624,0.25541127,0.23376624, 0.16883118,0.060606077,-0.090909064]) 
#Parameter reference 3: (4,4,0,5) -> 9-point window, degree 5
SG_Coeff3 = ee.List([0.034965038,-0.12820521,0.069930017,0.31468537,0.41724950,0.31468537, 0.069930017,-0.12820521,0.034965038]) 
#Parameter reference 4: (4,4,0,2) -> 9-point window, degree 2.
#Same values as SG_Coeff2: for smoothing (ld = 0) the degree-2 and degree-3 SG kernels coincide.
SG_Coeff4 = ee.List([-0.090909064,0.060606077,0.16883118,0.23376624,0.25541127,0.23376624, 0.16883118,0.060606077,-0.090909064]) 
#Parameter reference 5: (3,3,0,2) -> 7-point window, degree 2
SG_Coeff5 = ee.List([-0.095238090,0.14285715,0.28571430,0.33333334,0.28571430,0.14285715, -0.095238090])

In [2]:
#mappable function (usable with ee.ImageCollection.map)
def AddDoY(image : ee.Image) -> ee.Image:
    """Return `image` with its day of year attached as the numeric property "doy".

    The day of year is read from the image's date (system:time_start) via ee.Date.format('D').
    """
    acquisitionDate = ee.Date(image.date())
    dayOfYear = ee.Number.parse(acquisitionDate.format('D'))
    return image.set({"doy" : dayOfYear})


def ProcessLandsat8SRColletion(col : ee.ImageCollection, roi : ee.Geometry, startDate : str, endDate : str) -> ee.ImageCollection:
    """Turn a Landsat 8 Collection 2 L2 surface-reflectance collection into a per-scene NDVI collection.

    Every scene in [startDate, endDate) that intersects `roi` goes through, in order:
    clip to roi -> mask cloud / cloud-shadow / water pixels (QA_PIXEL) -> scale SR_B4 and SR_B5
    (x * 0.0000275 - 0.2) -> NDVI from (SR_B5, SR_B4) -> "doy" property (AddDoY).

    Each output image has the single band "NDVI" and the properties "DATE_ACQUIRED",
    "system:time_start" and "doy".
    """
    #TODO "DATE_ACQUIRED" property is not necessary (redundant as system:time_start has same data). No need to keep it.

    def TimeProperties(image : ee.Image) -> dict:
        # Band math strips image properties, but the time-related ones are needed for filtering,
        # so every step copies them from its input onto its result.
        return {"DATE_ACQUIRED" : image.get("DATE_ACQUIRED"), "system:time_start" : image.get("system:time_start")}

    def Clip(image : ee.Image) -> ee.Image:
        return image.clip(roi).set(TimeProperties(image))

    def MaskPoorQualityPixels(img : ee.Image) -> ee.Image:
        qa = img.select("QA_PIXEL")
        notCloud = qa.bitwiseAnd(8).eq(0)      # bit 3: cloud
        notShadow = qa.bitwiseAnd(16).eq(0)    # bit 4: cloud shadow
        notWater = qa.bitwiseAnd(128).eq(0)    # bit 7: water
        keep = notCloud.And(notShadow).And(notWater).rename("mask")
        # The mask is also appended as a band, but ScaleLandsat8SR drops it again.
        return img.updateMask(keep).addBands(keep)

    #Note: Returns only bands 4 and 5.
    def ScaleLandsat8SR(image : ee.Image) -> ee.Image:
        scaled = image.select("SR_B4", "SR_B5").multiply(0.0000275).add(-0.2)
        return scaled.set(TimeProperties(image))

    def NDVIize(image : ee.Image) -> ee.Image:
        ndvi = image.normalizedDifference(["SR_B5", "SR_B4"]).rename("NDVI")
        return ndvi.set(TimeProperties(image))

    processed = col.filterDate(startDate, endDate).filterBounds(roi)
    for step in (Clip, MaskPoorQualityPixels, ScaleLandsat8SR, NDVIize, AddDoY):
        processed = processed.map(step)
    return processed
    
    
def ProcessMODISCollection(col : ee.ImageCollection, roi : ee.Geometry, startDate : str, endDate : str) -> ee.ImageCollection:
    """Turn a MODIS MOD09Q1 surface-reflectance collection into a per-composite NDVI collection.

    Every image in [startDate, endDate) that intersects `roi` is clipped to roi, masked with the
    "State" QA band (cloud, cloud shadow, cirrus), converted to NDVI from sur_refl_b02 (NIR) and
    sur_refl_b01 (red), and tagged with a "doy" property (AddDoY).

    Each output image has the single band "NDVI" and the properties "system:time_start" and "doy".
    """
    def TimeProperties(image : ee.Image) -> dict:
        # Band math strips image properties; system:time_start is copied back from the input.
        return {"system:time_start" : image.get("system:time_start")}

    def Clip(image : ee.Image) -> ee.Image:
        return image.clip(roi).set(TimeProperties(image))

    def MaskPoorQualityPixels(img : ee.Image) -> ee.Image:
        state = img.select("State")
        isClear = state.bitwiseAnd(3).eq(0)      # bits 0-1: cloud state is 0 (clear)
        noShadow = state.bitwiseAnd(4).eq(0)     # bit 2: cloud shadow
        noCirrus = state.bitwiseAnd(768).eq(0)   # bits 8-9: cirrus
        return img.updateMask(isClear.And(noShadow).And(noCirrus))

    def NDVIize(image : ee.Image) -> ee.Image:
        ndvi = image.normalizedDifference(["sur_refl_b02", "sur_refl_b01"]).rename("NDVI")
        return ndvi.set(TimeProperties(image))

    processed = col.filterDate(startDate, endDate).filterBounds(roi)
    for step in (Clip, MaskPoorQualityPixels, NDVIize, AddDoY):
        processed = processed.map(step)
    return processed

In [3]:
#Adapted from the js code of Chen et al. and translated to python; the docstring lists where it deviates from the original.
def LinearInterpolation(col : ee.ImageCollection) -> ee.ImageCollection:
    """Fill masked pixels of every image (except the first and last) by linear interpolation in time.

    For image i, each pixel becomes pre + (next - pre) * (t_i - t_pre) / (t_next - t_pre), where
    "pre" is the latest valid value among images 0..i and "next" the earliest valid value among
    images i..n-1. A pixel that is already valid in image i is its own pre and next, so it is left
    unchanged. Pixels with no valid value on one side stay masked. The first and last images are
    passed through as-is.

    Deviations from the original js code:
      * time is measured in whole days since 1970-01-01 instead of day-of-year, so gaps that cross
        a year boundary are interpolated correctly;
      * the weight is computed in floating point (integer bands would truncate it to 0);
      * the search for "next" includes the last image (the original slice(i, -1) excluded it);
      * the last image is appended with List.add instead of List.insert(-1, ...), which counts
        from the end and could place it before the final interpolated image;
      * the input is sorted by system:time_start first.

    Args:
        col: collection whose images share the same bands and carry "system:time_start" and a
            "doy" property (see AddDoY).

    Returns:
        ee.ImageCollection in time order with the bands of col's first image. Interpolated images
        carry "system:time_start" and "doy" copied from the image they replace.
    """
    DAY_BAND = 'DAYS_SINCE_EPOCH'
    MS_PER_DAY = 86400000

    def DayNumber(img : ee.Image) -> ee.Number:
        # Whole days since the Unix epoch; unlike day-of-year this keeps increasing across years.
        return img.date().millis().divide(MS_PER_DAY).floor()

    def MaskDate(img : ee.Image) -> ee.Image:
        # Day band masked like the image, so the reducers below only see days with valid data.
        day = ee.Image(DayNumber(img)).updateMask(img.mask()).rename(DAY_BAND).toInt()
        return img.addBands(day)

    def AddDayBand(img : ee.Image) -> ee.Image:
        # Unmasked day band: the current image's time is needed even where its pixels are masked.
        day = ee.Image(DayNumber(img)).rename(DAY_BAND).toInt().setDefaultProjection(img.projection())
        return img.addBands(day)

    col = col.sort('system:time_start')
    bandNamesRaw = col.first().bandNames()

    colMaskDate = col.map(MaskDate)
    colWithDay = col.map(AddDayBand)

    bandNames = colWithDay.first().bandNames()
    bandIndex = ee.List.sequence(0, bandNames.size().subtract(1))

    colList = colWithDay.toList(colWithDay.size())
    maskDateList = colMaskDate.toList(colMaskDate.size())

    def Interpolate(i): #i is of the generic type "Object"
        i = ee.Number(i)
        # reduce() renames bands (e.g. NDVI -> NDVI_last); select(bandIndex, bandNames) restores the names.
        pre = ee.Image(ee.ImageCollection(maskDateList.slice(0, i.add(1))).reduce(ee.Reducer.lastNonNull())).select(bandIndex, bandNames)
        nxt = ee.Image(ee.ImageCollection(maskDateList.slice(i)).reduce(ee.Reducer.firstNonNull())).select(bandIndex, bandNames)
        cur = ee.Image(colList.get(i))

        dayPre = pre.select(DAY_BAND)
        dayCur = cur.select(DAY_BAND)
        dayNext = nxt.select(DAY_BAND)
        # Cast to float before dividing so the weight is not truncated by integer division.
        # Where cur is valid, pre == next == cur and the weight is 0/0, which ee.Image.divide evaluates to 0.
        weight = dayCur.subtract(dayPre).float().divide(dayNext.subtract(dayPre).float())
        interpolated = pre.add(nxt.subtract(pre).multiply(weight))
        return interpolated.select(bandNamesRaw).set({'system:time_start': cur.get('system:time_start'), 'doy': cur.get('doy')})

    middle = ee.List.sequence(1, colList.size().subtract(2)).map(Interpolate)
    ordered = middle.insert(0, colList.get(0)).add(colList.get(-1))
    # Drops the helper day band from the untouched first and last images.
    return ee.ImageCollection(ordered).map(lambda img : img.select(bandNamesRaw))

In [4]:
#Adapted from the js code of Chen et al. and translated to python.
def SG_Filter(imgCol,list_sgCoeff):
    """Savitzky-Golay smoothing of a single-band image time series.

    Each output image is the coefficient-weighted sum of the window of len(list_sgCoeff) images
    centred on the corresponding input image. At both ends the series is padded with the first /
    last half-window images (in their original order, as in the reference code), so every input
    image gets a full window.

    Args:
        imgCol: ee.ImageCollection (or something ee.ImageCollection accepts). Only the bands of its
            first image are kept, and the weighted sum is renamed to a single band, so the input
            must be single-band. Images need "system:time_start" and "doy" properties.
            Assumes at least half a window of images (TODO: confirm behaviour for shorter series).
        list_sgCoeff: ee.List or python list of SG kernel coefficients (odd length).

    Returns:
        ee.ImageCollection sorted by system:time_start, one image per input image, with the band
        renamed "<band>_fitted" and "doy" copied from the matching input image.
    """
    list_sgCoeff = ee.List(list_sgCoeff)  # also accept a plain python list of coefficients

    # Windows are built from list positions, so the series must be in time order before padding.
    imgCol = ee.ImageCollection(imgCol).sort('system:time_start')
    bandlist = ee.Image(imgCol.first()).bandNames()
    newBandNames = bandlist.map(lambda band : ee.String(band).cat("_fitted"))

    windowSize = ee.Number(list_sgCoeff.size())
    half_winSize = windowSize.subtract(1).divide(2)
    imgColSize = ee.Number(imgCol.size())
    imgCol = imgCol.map(lambda img: img.select(bandlist))

    # One constant band per coefficient; multiplies the window stack band-by-band.
    img_SgFilterCoeff = ee.Image.constant(list_sgCoeff)
    list_imgCol = imgCol.toList(imgColSize)
    list_imgCol0 = list_imgCol.slice(0,half_winSize).cat(list_imgCol).cat(list_imgCol.slice(imgColSize.subtract(half_winSize)))

    def Smooth(i):
        i = ee.Number(i)
        # Padded positions i .. i+windowSize-1 are the window centred on original image i.
        window = ee.ImageCollection(list_imgCol0.slice(i,i.add(windowSize))).toBands()
        img_process = window.multiply(img_SgFilterCoeff).reduce(ee.Reducer.sum()).rename('sg')
        # Attach the result to the centre image and keep only it, so the output keeps the centre
        # image's properties (the final sort by system:time_start relies on this).
        img_centre = ee.Image(list_imgCol0.get(i.add(half_winSize)))
        smoothed = img_centre.addBands(img_process).select('sg').rename(newBandNames)
        return smoothed.set('doy',ee.Image(list_imgCol.get(i)).get('doy'))

    list_images_SgFiltered = ee.List.sequence(0,imgColSize.subtract(1)).map(Smooth)
    return ee.ImageCollection(list_images_SgFiltered).sort('system:time_start')


In [5]:

# Study period (inclusive months) and area of interest.
startYear = 2023
startMonth = 1
endYear = 2023
endMonth = 2

roi = testArea

# filterDate() treats the end date as exclusive, so endDate is the first day of the month after endMonth.
# Dates are zero-padded ISO 8601 (YYYY-MM-DD).
startDate = f"{startYear}-{startMonth:02d}-01"
endDate = f"{endYear}-{endMonth + 1:02d}-01" if endMonth < 12 else f"{endYear + 1}-01-01"

landsat8SR = ee.ImageCollection("LANDSAT/LC08/C02/T1_L2")
modidsSR = ee.ImageCollection("MODIS/061/MOD09Q1")


def DescribeCollection(col : ee.ImageCollection) -> None:
    """Debug helper: print size, first-image band and property names, and the "doy" values of `col`.

    Each line calls getInfo(), i.e. forces a server round trip.
    """
    print(col.size().getInfo())
    print(col.first().bandNames().getInfo())
    print(col.first().propertyNames().getInfo())
    print(col.aggregate_array("doy").getInfo())


frDataLandsat = ProcessLandsat8SRColletion(landsat8SR, roi, startDate, endDate)

# Each MODIS stage keeps its own name so earlier stages stay inspectable;
# crDataMODIS is the final (interpolated + smoothed) series used by later cells.
crDataMODISRaw = ProcessMODISCollection(modidsSR, roi, startDate, endDate)
# DescribeCollection(crDataMODISRaw)

# map.centerObject(roi)
# map.addLayer(crDataMODISRaw.first(), NDVI_VIZ_PARAMS, "NDVI")
# map.addLayer(roi)
# map

crDataMODISInterpolated = LinearInterpolation(crDataMODISRaw)
# DescribeCollection(crDataMODISInterpolated)

crDataMODIS = SG_Filter(crDataMODISInterpolated, SG_Coeff2)
# DescribeCollection(crDataMODIS)

# Linear transform of the smoothed MODIS NDVI (x * 1.023 - 0.013), cast to float, bicubic resampling and
# EPSG:4326 as default projection; the source properties are copied back because band math drops them.
# setDefaultProjection must get `crs` as a keyword: in Python a JS-style dict ({"crs": ...}) passed
# positionally is bound to the `crs` parameter itself.
frDataMODIS = crDataMODIS.map(
    lambda img : ee.Image(
        img.multiply(1.023).subtract(0.013).float()
        .resample('bicubic')
        .setDefaultProjection(crs='EPSG:4326')
        .copyProperties(img, img.propertyNames())
    )
)

# NOTE: the line below cannot work as-is: ProcessLandsat8SRColletion returns only the "NDVI" band,
# so there is no 'mask' band to select.
#frDataLandsat = frDataLandsat.map(lambda img : img.select('NDVI').float().updateMask(img.select('mask')).rename('NDVI_masked'))

In [6]:
# Fully masked single-band "NDVI" image, used when no Landsat scene falls inside a MODIS date window.
# Built from a constant rather than frDataLandsat.sum(), so it does not depend on frDataLandsat's contents:
# sum() of an empty collection has no band to rename, and after this cell has run frDataLandsat has two bands.
emptyImage = ee.Image.constant(0).float().rename('NDVI').updateMask(ee.Image(0))

# List form of the MODIS series, built once instead of twice inside every func call.
crDataMODIS_list = crDataMODIS.toList(crDataMODIS.size())

def func(i):
    """Build the Landsat NDVI composite matching the i-th image of crDataMODIS.

    Takes the per-pixel max NDVI of the frDataLandsat scenes whose "doy" lies within +/-4 days of
    the MODIS image's "doy" (emptyImage when there are none), adds its mask as band 'L_cmask',
    clips to roi, and copies the MODIS image's system:time_start and doy.

    NOTE(review): the window compares day-of-year only, so for multi-year ranges it also matches
    scenes from other years and does not wrap across Dec/Jan. TODO: filter on system:time_start instead.
    """
    #note to self: in the original code, crDataMODIS below is the original, uninterpolated, unsmoothed data (the original code gives the interpolated and smoothed data their own variables).
    #but since the size and time of all should be similar, I don't think that should be an issue.
    modisImage = ee.Image(crDataMODIS_list.get(i))
    targetDoy = ee.Number(modisImage.get('doy'))
    startT = targetDoy.subtract(4)
    endT = targetDoy.add(4)
    TimeFilter = ee.Filter.And(ee.Filter.lte('doy',endT),ee.Filter.gte('doy',startT))
    MergeWindow = frDataLandsat.filter(TimeFilter)
    ValueWindow = ee.Image(ee.Algorithms.If(condition = MergeWindow.size().gt(0),
        trueCase = MergeWindow.max().rename('NDVI'),
        falseCase = emptyImage))
    maskWindow = ValueWindow.mask().rename('L_cmask')
    systime = modisImage.get('system:time_start')
    return ValueWindow.addBands(maskWindow).clip(roi).set('system:time_start',systime).set('doy',targetDoy)

frDataLandsat_list = ee.List.sequence(0,crDataMODIS.size().subtract(1)).map(func)

# NOTE: this overwrites frDataLandsat (the per-scene NDVI collection from the previous cell) with the
# per-MODIS-date composites; re-run the previous cell before re-running this one.
frDataLandsat = ee.ImageCollection(frDataLandsat_list)

# print (frDataLandsat.size().getInfo())
# print (frDataLandsat.first().bandNames().getInfo())
# print (frDataLandsat.first().propertyNames().getInfo())
# print (frDataLandsat.aggregate_array("doy").getInfo())


8
['NDVI', 'L_cmask']
['system:time_start', 'system:footprint', 'doy', 'system:index', 'system:bands', 'system:band_names']
[1, 9, 17, 25, 33, 41, 49, 57]
