In [None]:
import numpy as np
from osgeo import gdal
import matplotlib.pyplot as plt
import os
from keras.layers import *
from keras.models import Model

In [None]:
def read_tiff(fn):
    """
        Read every band of a raster file into a NumPy array.

        inputs: fn -- raster (tiff) filename
        outputs: image data -- the 2-D array from GDAL for a single-band
                 raster; otherwise GDAL's (bands, rows, cols) array with the
                 band axis moved last, i.e. (rows, cols, bands)
        raises: IOError if GDAL cannot open the file
    """
    dataset = gdal.Open(fn, gdal.GA_ReadOnly)
    if dataset is None:
        # Without gdal.UseExceptions(), gdal.Open signals failure by returning
        # None; fail here with the filename instead of an AttributeError below.
        raise IOError('GDAL could not open raster: {}'.format(fn))
    im_data = dataset.ReadAsArray(0, 0, dataset.RasterXSize, dataset.RasterYSize)
    if im_data.ndim == 2:
        return im_data
    return im_data.transpose([1, 2, 0])
def features(dt):
    """
        Build the 8-band feature stack used as CNN input.

        inputs: dt -- image array indexed as dt[:, :, band]; bands 1-7 are read
                as Blue, Green, Red, NIR, SWIR1, leb, dem (band 0 is unused).
                NOTE(review): band order is inferred from the variable names
                in this function -- confirm against the source rasters.
        outputs: (rows, cols, 8) float array with bands
                 [Blue, Green, Red, NIR, NDVI, MNDWI, leb, dem]
                 where NDVI  = (NIR - Red)   / (NIR + Red)
                 and   MNDWI = (Green - SWIR1) / (Green + SWIR1)
    """
    def band(k):
        # Slice k:k+1 keeps a trailing length-1 axis for concatenation; float
        # so that differences of unsigned-integer bands cannot wrap around.
        return dt[:, :, k:k + 1].astype(np.float64)

    def normalized_difference(a, b):
        # (a - b) / (a + b), defined as 0 where a + b == 0 (e.g. nodata pixels).
        denom = a + b
        return np.divide(a - b, denom, out=np.zeros_like(denom), where=denom != 0)

    blue, green, red, nir, swir1 = (band(k) for k in range(1, 6))
    ndvi = normalized_difference(nir, red)
    mndwi = normalized_difference(green, swir1)
    leb = band(6)
    dem = band(7)
    return np.concatenate([blue, green, red, nir, ndvi, mndwi, leb, dem], axis=-1)
def get_input(label, im):
    """
        Build class-balanced 7x7 training patches from one labelled tile.

        inputs: label -- 2-D label array where 1 marks water and 0 non-water
                         (presumably already scaled to 0/1 by the caller --
                         confirm); it is NOT modified
                im    -- image passed to `features`; the resulting feature
                         stack must cover at least the label's rows and cols
        outputs: (im_train, lb_train)
                 im_train -- (n, 7, 7, bands) float64 patches, each centred on
                             one sampled interior pixel
                 lb_train -- (n, 1, 1, 2) one-hot targets; channel 0 = water,
                             channel 1 = non-water
        raises: ValueError if the feature stack is smaller than the label
    """
    half = 3                 # patch radius: 7x7 window around each centre pixel
    size = 2 * half + 1
    label = np.asarray(label)
    features_im = features(im)
    if features_im.shape[0] < label.shape[0] or features_im.shape[1] < label.shape[1]:
        raise ValueError('feature stack {} is smaller than label {}'.format(
            features_im.shape[:2], label.shape))

    sampled = _balance_labels(label, half)
    # Interior coordinates of every sampled pixel; interior (r, c) is image
    # pixel (r + half, c + half), whose window starts at image row r, col c.
    rows, cols = np.nonzero(sampled[half:-half, half:-half] != 0)
    offsets = np.arange(size)
    # Gather all windows at once: (n,7,1) row grid x (n,1,7) col grid -> (n,7,7,bands).
    im_train = features_im[rows[:, None, None] + offsets[None, :, None],
                           cols[:, None, None] + offsets[None, None, :]].astype(np.float64)

    centre = sampled[rows + half, cols + half]
    lb_train = np.zeros((len(rows), 1, 1, 2))
    lb_train[:, 0, 0, 0] = centre == 1
    lb_train[:, 0, 0, 1] = centre == 2
    return im_train, lb_train


def _balance_labels(label, half=3):
    """
        Return a re-coded copy of `label` marking which interior pixels to sample.

        Codes in the result: 1 = water sample, 2 = non-water sample,
        0 = not sampled. Only the interior (excluding a `half`-pixel border) is
        considered. Uses the global np.random state for sampling.
    """
    # Class masks are read from the untouched input. Reading them from the
    # array being written would alias: a basic slice is a view, so writes
    # would leak back into the masks.
    y = label[half:-half, half:-half]
    water_pixel_num = int((y == 1).sum())
    other_pixel_num = int((y == 0).sum())

    if water_pixel_num == other_pixel_num:
        # Already balanced: water stays 1, non-water 0 becomes 2.
        return 2 - label

    out = label.copy()
    if water_pixel_num < other_pixel_num:
        # Keep every water pixel; sample up to 5 non-water pixels per water pixel.
        rows, cols = np.nonzero(y != 1)
        n_keep = min(int(water_pixel_num * 5), len(rows))
        keep = np.random.choice(len(rows), n_keep, replace=False)
        out[rows[keep] + half, cols[keep] + half] = 2
    else:
        # Keep every non-water pixel; sample an equal number of water pixels.
        other_rows, other_cols = np.nonzero(y == 0)
        out[other_rows + half, other_cols + half] = 2
        rows, cols = np.nonzero(y != 0)
        out[rows + half, cols + half] = 0
        keep = np.random.choice(len(rows), other_pixel_num, replace=False)
        out[rows[keep] + half, cols[keep] + half] = 1
    return out
def CNN(n_bands=8):
    """
        Build the per-pixel water / non-water classifier.

        inputs: n_bands -- number of feature bands in each 7x7 input patch
                (default 8, the size of the `features` stack)
        outputs: an uncompiled keras Model mapping (7, 7, n_bands) patches to
                 a (1, 1, 2) map of per-class sigmoid scores
    """
    inputs = Input((7, 7, n_bands))
    x = Conv2D(16, (3, 3), activation='relu')(inputs)   # 'valid' padding: 7x7 -> 5x5
    x = Conv2D(32, (3, 3), activation='relu')(x)        # 5x5 -> 3x3
    # Skip connection: re-inject the centre 3x3 of the raw input bands.
    x = concatenate([x, Cropping2D(cropping=2)(inputs)])
    x = Conv2D(64, (3, 3), activation='relu')(x)        # 3x3 -> 1x1
    # Skip connection: re-inject the centre pixel of the raw input bands.
    x = concatenate([x, Cropping2D(cropping=3)(inputs)])
    x = Conv2D(128, (1, 1), activation='relu')(x)
    outputs = Conv2D(2, (1, 1), activation='sigmoid')(x)
    return Model(inputs=inputs, outputs=outputs)
def generate_inputs(label_fns, im_fns):
    """
        Lazily yield training samples for each (label, image) file pair.

        inputs: label_fns -- label raster filenames
                im_fns    -- image raster filenames, paired with label_fns by
                             position (extra entries in the longer list are
                             ignored, as with zip)
        outputs: generator of (im_train, lb_train) tuples from get_input,
                 one per file pair
    """
    for label_fn, im_fn in zip(label_fns, im_fns):
        im = read_tiff(im_fn)
        label = read_tiff(label_fn)
        peak = label.max()
        # Scale labels to 0/1. A tile with no positive pixels has max 0, and
        # 0/0 gives NaN everywhere. Because NaN != 0, every pixel would then
        # become an unlabelled sample, so such a tile stays all-zero instead.
        label = label / peak if peak else np.zeros(label.shape)
        yield get_input(label, im)

In [None]:
######################################
# Configuration -- fill in the placeholders
######################################
LABEL_DIR = r'[folder]'   # folder of label rasters
IMAGE_DIR = r'[folder]'   # folder of image rasters
MODEL_DIR = r'[folder]'   # folder to save models
LOG_FILE = r'[filename]'  # training log, appended to
N_EPOCHS = 10
BATCH_SIZE = 4096
SEED = 0

np.random.seed(SEED)  # makes the pixel sampling in get_input reproducible

model = CNN()
# CNN() returns an uncompiled model, and fit() fails without compile().
# NOTE(review): optimizer/loss are chosen here because the notebook never
# compiled the model. Binary cross-entropy matches the sigmoid outputs with
# one-hot targets; confirm against the original training setup.
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])

fns = os.listdir(LABEL_DIR)
label_fns = [os.path.join(LABEL_DIR, fn) for fn in fns]
# Each image file is named 'im' + the label filename minus its first 4 characters.
im_fns = [os.path.join(IMAGE_DIR, 'im' + fn[4:]) for fn in fns]

epochs = {}
for ep in range(N_EPOCHS):
    print('epoch:', ep + 1)
    epoch_log = epochs[str(ep + 1)] = {'loss': [], 'acc': []}
    with open(LOG_FILE, 'a+') as f:
        f.write('epoch:' + str(ep + 1) + '\n')
    for num, (x_batch, y_batch) in enumerate(generate_inputs(label_fns, im_fns), start=1):
        print('file', num)
        if not x_batch.shape[0]:
            continue  # this tile yielded no samples
        history = model.fit(x_batch, y_batch, batch_size=BATCH_SIZE, epochs=1, shuffle=True)
        with open(LOG_FILE, 'a+') as f:
            f.write(str(history.history) + '\n')
        epoch_log['loss'].append(history.history['loss'][0])
        # Depending on the Keras version, metrics=['acc'] is reported as 'acc' or 'accuracy'.
        acc_key = 'acc' if 'acc' in history.history else 'accuracy'
        epoch_log['acc'].append(history.history[acc_key][0])
    # os.path.join instead of a hard-coded '\\' so this also works off Windows.
    model.save(os.path.join(MODEL_DIR, 'model_' + str(ep) + '.h5'))