In [None]:
import ee
ee.Initialize()
import pandas as pd
import numpy as np
import threading
from concurrent import futures
import os
from osgeo import gdal
from osgeo import osr
import time
from keras.models import Model
from keras.layers import *
from keras import optimizers

In [None]:
def summer_img(l8, st_year, ed_year, path, row):
    """
        Select the June-August scenes of one WRS-2 tile.

        inputs: imagecollection, start year, end year, path, and row
        outputs: imagecollection restricted to that path/row, the year range
                 and months 6-8, sorted by 'CLOUD_COVER'
    """
    wanted = [
        ee.Filter.eq('WRS_PATH', int(path)),
        ee.Filter.eq('WRS_ROW', int(row)),
        ee.Filter.calendarRange(int(st_year), int(ed_year), 'year'),
        ee.Filter.calendarRange(6, 8, 'month'),
    ]
    scenes = l8
    for flt in wanted:
        scenes = scenes.filter(flt)
    return scenes.sort('CLOUD_COVER')
def generate_grid(xmin, ymin, xmax, ymax, dx, dy):
    """
        Tile a lon/lat bounding box with axis-aligned rectangles.

        inputs: the range of lon (xmin, xmax) and lat (ymin, ymax), the cell
                width along lon (dx) and the cell height along lat (dy)
        outputs: ee.FeatureCollection of rectangles, ordered row by row
                 (lat outer loop, lon inner loop); cells starting inside the
                 range may extend past xmax/ymax
    """
    xx = np.arange(xmin, xmax, dx)
    yy = np.arange(ymin, ymax, dy)
    rect = []
    for y1 in yy:
        for x1 in xx:
            # Width uses dx (lon step) and height uses dy (lat step); these were
            # swapped before, which only gave correct cells while dx == dy.
            coords = ee.List([x1, y1, x1 + dx, y1 + dy])
            rect.append(ee.Algorithms.GeometryConstructors.Rectangle(coords))
    return ee.FeatureCollection(rect)

In [None]:
def project(path, row):
    """
        Build image, QA-water-mask and refined-label GeoTIFFs for one
        Landsat WRS-2 path/row.

        inputs: path, row
        outputs: None (writes im/, mask/ and labels/ GeoTIFFs under the output
                 folder and appends progress lines to the log file)

        Reads the notebook globals l8, st_year and ed_year (defined in a later
        cell) and calls summer_img / generate_grid, so run those cells first.
    """
    ##################################################
    # output folder and log filename should be given
    ##################################################
    out_folder = r'[folder]'
    log_file = r'[filename]'

    def log(message):
        """Append `message` to the log file."""
        with open(log_file, 'a+') as f:
            f.write(message)

    def get_array_from_image(image, grids, LDC=False):
        """
            Download the pixels of `image` inside the first cell of `grids`.

            inputs: ee.Image, list of grid features (only grids[0] is used),
                    LDC=True to fetch only the QA_PIXEL band
            outputs: {cell id: {'geo': feature, 'value': (rows, cols, bands)
                     array, 'dim': its shape}}, or 0 if the request failed
        """
        cell = grids[0]
        try:
            region = ee.Geometry.Polygon(cell['geometry']['coordinates'])
            band_arrs = image.sampleRectangle(region=region)
            if LDC:
                names = ['QA_PIXEL']
            else:
                # Stack order: B2, B3, B4, B5, B6, NDWI (SR_B98),
                # line-enhanced NDWI (SR_B99), elevation. features() relies on it.
                names = ['SR_B' + str(j) for j in [2, 3, 4, 5, 6, 98, 99]] + ['elevation']
            np_arr = np.stack([np.array(band_arrs.get(name).getInfo()) for name in names], axis=-1)
        except Exception as exc:
            # Best effort: skip cells Earth Engine cannot sample, but say why
            # instead of silently swallowing every error.
            print('skip grid %s: %s' % (cell['id'], exc))
            return 0
        return {cell['id']: {'geo': cell, 'value': np_arr, 'dim': np_arr.shape}}

    def geo_transform(geo, rows, cols):
        """
            inputs: grid feature (GeoJSON dict), raster rows and cols
            outputs: GDAL geotransform (west lon, pixel width, 0, north lat,
                     0, -pixel height) spanning the cell's bounding box
        """
        ring = np.array(geo['geometry']['coordinates'][0])
        west, east = ring[:, 0].min(), ring[:, 0].max()   # lon
        south, north = ring[:, 1].min(), ring[:, 1].max()  # lat
        return (west, (east - west) / cols, 0, north, 0, (south - north) / rows)

    def save_raster(fn, bands, transform, data_type):
        """Write a list of equally sized 2-D arrays as an EPSG:4326 GeoTIFF."""
        target = osr.SpatialReference()
        target.ImportFromEPSG(4326)
        rows, cols = bands[0].shape
        out_raster = gdal.GetDriverByName('GTiff').Create(fn, cols, rows, len(bands), data_type)
        out_raster.SetGeoTransform(transform)
        out_raster.SetProjection(target.ExportToWkt())
        for k, band in enumerate(bands, start=1):
            out_raster.GetRasterBand(k).WriteArray(band)
        out_raster.FlushCache()
        out_raster = None  # release the dataset so GDAL finishes the file

    def write_tiff(im, pre):
        """
            inputs: entry from get_array_from_image (for georeferencing),
                    array to write: (rows, cols, n>1) goes to im/ as Float32,
                    (rows, cols) or (rows, cols, 1) goes to mask/ as Byte
            outputs: path of the written file
            raises: ValueError for any other shape
        """
        transform = geo_transform(im['geo'], im['dim'][0], im['dim'][1])
        tag = '%.4f' % transform[0] + '%.4f' % transform[3]
        if pre.ndim == 3 and pre.shape[2] > 1:
            fn = os.path.join(out_folder, 'im', 'im' + tag + '.tif')
            save_raster(fn, [pre[:, :, k] for k in range(pre.shape[2])], transform, gdal.GDT_Float32)
        elif pre.ndim in (2, 3):
            fn = os.path.join(out_folder, 'mask', 'mask' + tag + '.tif')
            save_raster(fn, [pre if pre.ndim == 2 else pre[:, :, 0]], transform, gdal.GDT_Byte)
        else:
            raise ValueError('cannot write array of shape %s' % (pre.shape,))
        return fn

    def write_std_mask(mask, im):
        """
            inputs: (rows-6, cols-6) label array after noise correction,
                    mask entry from get_array_from_image (size + georeference)
            outputs: None (writes the labels GeoTIFF; the 3-pixel border the
                     CNN cannot predict is written as 0)
        """
        rows, cols = im['value'].shape[:2]
        pre = np.zeros((rows, cols))
        pre[3:-3, 3:-3] = mask
        transform = geo_transform(im['geo'], rows, cols)
        fn = os.path.join(out_folder, 'labels', '%.4f' % transform[0] + '%.4f' % transform[3] + '.tif')
        save_raster(fn, [pre], transform, gdal.GDT_Byte)

    def CNN():
        """
            inputs: None
            outputs: uncompiled fully-convolutional CNN. Three valid 3x3
                     convolutions trim 3 pixels per side, so a 7x7 patch maps
                     to a 1x1 two-class softmax and an (H, W) image to an
                     (H-6, W-6) prediction.
        """
        # Spatial size is left open so the model trained on 7x7 patches can be
        # applied to whole cells; Input((7,7,7)) declared only 7x7 inputs.
        inputs = Input((None, None, 7))
        x = Conv2D(16, (3, 3), activation='relu')(inputs)
        x = Conv2D(32, (3, 3), activation='relu')(x)
        x = concatenate([x, inputs[:, 2:-2, 2:-2, :]])
        x = Conv2D(64, (3, 3), activation='relu')(x)
        x = concatenate([x, inputs[:, 3:-3, 3:-3, :]])
        x = Conv2D(128, (1, 1), activation='relu')(x)
        outputs = Conv2D(2, (1, 1), activation='softmax')(x)
        return Model(inputs=inputs, outputs=outputs)

    def features(dt):
        """
            inputs: (rows, cols, 8) stack from get_array_from_image
                    (B2, B3, B4, B5, B6, SR_B98, SR_B99, elevation)
            outputs: (rows, cols, 7) stack Blue, Green, Red, NIR, NDVI,
                     MNDWI, DEM used by the CNN and the noise correction
        """
        # Indices follow the stack built in get_array_from_image. The old
        # indices (1..5) assumed a B1..B7 stack and picked e.g. B3 as "Blue".
        blue, green, red, nir, swir1 = (dt[:, :, k] for k in range(5))
        dem = dt[:, :, 7]
        ndvi = (nir - red) / (nir + red)
        mndwi = (green - swir1) / (green + swir1)
        return np.stack([blue, green, red, nir, ndvi, mndwi, dem], axis=-1)

    def noise_correct(mask, image, threshold):
        """
            inputs: CNN prediction (rows-6, cols-6, 2), feature image
                    (rows, cols, bands), water-probability threshold
            outputs: (rows-6, cols-6) count of bands within mean +/- 3 std of
                     the predicted water pixels, zero outside predicted water
        """
        water = mask[:, :, 0] > threshold
        rr, cc = np.nonzero(water)
        if rr.size < 2:
            # Too few water pixels for a sample std; nothing passes.
            return np.zeros(water.shape)
        # Prediction pixel (r, c) corresponds to image pixel (r + 3, c + 3).
        water_pixels = image[rr + 3, cc + 3, :]
        mean = water_pixels.mean(axis=0)
        std = water_pixels.std(ddof=1, axis=0)
        min_r = mean - 3 * std
        max_r = mean + 3 * std
        in_range = ((image > min_r) & (image < max_r)).sum(axis=-1)
        return in_range[3:-3, 3:-3] * water

    def cloud_free(image):
        """
            inputs: Landsat 8 Collection 2 Level-2 image in ee format
            output: clear-pixel image with scaled SR_B1..SR_B7, QA_PIXEL
                    reduced to its bit 7 (water flag), NDWI as SR_B98 and the
                    line-enhanced NDWI (LEB) as SR_B99
        """
        # weights1..weights4: 5x5 second-difference kernels (-1, 2, -1 with a
        # 2-pixel step) along four directions, used for linear enhancement.
        weights1 = [[0, 0, -1, 0, 0],
                    [0, 0, 0, 0, 0],
                    [0, 0, 2, 0, 0],
                    [0, 0, 0, 0, 0],
                    [0, 0, -1, 0, 0]]
        weights2 = [[-1, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0],
                    [0, 0, 2, 0, 0],
                    [0, 0, 0, 0, 0],
                    [0, 0, 0, 0, -1]]
        weights3 = [[0, 0, 0, 0, -1],
                    [0, 0, 0, 0, 0],
                    [0, 0, 2, 0, 0],
                    [0, 0, 0, 0, 0],
                    [-1, 0, 0, 0, 0]]
        weights4 = [[0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0],
                    [-1, 0, 2, 0, -1],
                    [0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0]]
        # QA_PIXEL bits 1-4 (dilated cloud, cirrus, cloud, cloud shadow).
        cloud = image.select('QA_PIXEL').rightShift(1).bitwiseAnd(15)
        # C2 L2 surface-reflectance scale/offset; bit 7 of QA_PIXEL is kept as the water band.
        image = image.select('SR_B.').multiply(0.0000275).add(-0.2).addBands(image.select('QA_PIXEL').rightShift(7).bitwiseAnd(1))
        image_freecloud = image.updateMask(cloud.expression('b==0', {'b': cloud}))
        ndwi = image_freecloud.normalizedDifference(['SR_B3', 'SR_B6'])
        smooth = ee.Image([ndwi.convolve(ee.Kernel.fixed(weights=ee.List(w)))
                           for w in (weights1, weights2, weights3, weights4)]).reduce(ee.Reducer.max()).rename(['SR_B99'])
        return image_freecloud.addBands(ndwi.rename(['SR_B98'])).addBands(smooth)

    def refine_labels(im_info, mk_info, grid_id):
        """
            Refine the QA water mask of one cell with three rounds of CNN
            training + noise correction, then write the label raster.

            inputs: image and mask entries from get_array_from_image, cell id
            outputs: None (writes the labels GeoTIFF when a prediction exists)
        """
        dt = features(im_info['value'])
        mk = mk_info['value'][:, :, 0]
        if dt.shape[:2] != mk.shape:
            print('skip grid %s: image %s and mask %s sizes differ' % (grid_id, dt.shape[:2], mk.shape))
            return
        # Seed labels: QA water pixels that a 7x7 window can be centred on.
        t = (mk[3:-3, 3:-3] == 1) * 1
        p = None
        for times in range(3):
            water_pixel_num = (t == 1).sum()
            if water_pixel_num == 0:
                # Nothing left to learn water from (e.g. noise correction
                # rejected everything); keep the previous prediction, if any.
                break
            loc_x, loc_y = np.nonzero(t != 1)
            if water_pixel_num < t.size / 2:
                # Balance the classes: sample as many non-water (label 2)
                # pixels as there are water (label 1) pixels.
                pick = np.random.choice(len(loc_x), water_pixel_num, replace=False)
                loc_x, loc_y = loc_x[pick], loc_y[pick]
            t[loc_x, loc_y] = 2
            rr, cc = np.nonzero(t)
            # The patch for t[r, c] is centred on image pixel (r + 3, c + 3).
            im_train = np.stack([dt[r:r + 7, c:c + 7] for r, c in zip(rr, cc)])
            # The last feature is DEM/4000; centres at or above 0.35 (1400 m)
            # keep an all-zero target, as before.
            label = t[rr, cc] * (im_train[:, 3, 3, -1] < 0.35)
            mk_train = np.zeros((len(rr), 1, 1, 2))
            mk_train[label == 1, :, :, 0] = 1
            mk_train[label == 2, :, :, 1] = 1
            model = CNN()
            model.compile(optimizer=optimizers.Adam(learning_rate=1e-5),
                          loss='categorical_crossentropy', metrics=['acc'])
            epochs = 30 if times == 0 else 5
            model.fit(im_train, mk_train, batch_size=16, epochs=epochs, shuffle=True)
            p = model.predict(dt[np.newaxis, :])[0]
            t = (noise_correct(p, dt, 0.95) > 5) * 1
        if p is not None:
            write_std_mask((p[:, :, 0] > p[:, :, 1]) * 255, mk_info)

    def write_grid_image(grid):
        """
            Process one grid cell. Export the image and the QA water mask,
            then refine the mask into labels. Cells whose download failed, or
            which hold no QA water pixels, are skipped.

            NOTE(review): several of these run in parallel threads, each
            training its own Keras model; confirm the Keras/TensorFlow version
            in use tolerates concurrent training.
        """
        mask = get_array_from_image(water, [grid], LDC=True)
        im = get_array_from_image(image, [grid])
        # Both downloads must succeed (failure returns 0); previously only `im`
        # was checked and a failed mask crashed on popitem().
        if not (im and mask):
            return
        mk_info = next(iter(mask.values()))
        im_info = next(iter(im.values()))
        water_count = (mk_info['value'] > 0).sum()
        if water_count > 100:
            log(str(path)+' '+str(row)+' '+grid['id']+' start!\n')
            write_tiff(mk_info, (mk_info['value'] > 0) * 255)
            write_tiff(im_info, im_info['value'])
            refine_labels(im_info, mk_info, grid['id'])
        elif water_count > 0:
            # This used to go to a hard-coded D: drive path; use the shared log.
            log(str(path)+' '+str(row)+' '+grid['id']+' pass\n')

    log(str(path)+' '+str(row)+' start！\n')

    imagec = summer_img(l8, st_year, ed_year, path, row)
    geo = imagec.first().geometry()
    # Elevation divided by 4000, so the 0.35 label cut-off means 1400 m.
    DEM = ee.Image('CGIAR/SRTM90_V4').reproject(crs='EPSG:4326', scale=30).divide(4000)
    coor = np.array(geo.getInfo()['coordinates'][0])
    dx = 0.135
    dy = 0.135
    grid = generate_grid(coor[:, 0].min(), coor[:, 1].min(), coor[:, 0].max(), coor[:, 1].max(), dx, dy)
    grids = grid.filterBounds(geo).getInfo()['features']
    image_freecloud = imagec.map(cloud_free).median().reproject(crs='EPSG:4326', scale=30)
    image = image_freecloud.select('SR_B.*').addBands(DEM.clip(geo)).reproject(crs='EPSG:4326', scale=30)
    image = image.setDefaultProjection(image.projection())
    water = image_freecloud.select('QA_PIXEL')

    st = time.time()
    times_num = 100
    # Ceil division so every grid lands in a batch. The old floor-based
    # interval skipped all grids when there were fewer than 100 and dropped
    # the remainder otherwise.
    batch_size = max(1, -(-len(grids) // times_num))
    for start in range(0, len(grids), batch_size):
        grid_threads = [threading.Thread(target=write_grid_image, args=[g])
                        for g in grids[start:start + batch_size]]
        for grid_thread in grid_threads:
            grid_thread.start()
        for grid_thread in grid_threads:
            grid_thread.join()
        print('%.2f' % (min(start + batch_size, len(grids)) / len(grids) * 100) + '%')
    elapsed = time.time() - st
    hour = int(elapsed / 3600)
    minute = int((elapsed - 3600 * hour) / 60)
    second = elapsed - 60 * int(elapsed / 60)
    print(str(hour) + ':' + str(minute) + ':' + str(second))
    log(str(path)+' '+str(row)+' write done！ Time consuming: %.4f\n' % elapsed)

In [None]:
##################################################
# A CSV file lists the WRS-2 paths and rows to process.
# Scenes on path 131 and path 135 are our training samples;
# edit the path/row selection below as needed.
##################################################
path_row = pd.read_csv(r'[pathrow]')
path_row = path_row[path_row['PATH'].isin([131, 135])]
path = np.array(path_row['PATH'])
row = np.array(path_row['ROW'])
# Landsat 8 Collection 2 Tier 1 Level-2 collection and the year range
# handed to summer_img.
l8 = ee.ImageCollection('LANDSAT/LC08/C02/T1_L2')
st_year = 2020
ed_year = 2020

In [None]:
# Process entries 4-7 of the selected path/row arrays concurrently,
# one thread per scene, and wait for all of them to finish.
project_threads = [threading.Thread(target=project, args=[i, j])
                   for i, j in zip(path[4:8], row[4:8])]
for thread_t in project_threads:
    thread_t.start()
for thread_t in project_threads:
    thread_t.join()