In [2]:
import sys
# Put the parent directory on the module search path. Presumably this lets the
# project-local packages imported in later cells (preprocess, models,
# postprocess) resolve - confirm against the repository layout.
sys.path.append('../')

In [3]:
import os
import cv2
import copy
import numpy as np
import tensorflow as tf
from keras import backend as K
import matplotlib.pyplot as plt
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

def expend_as(x, n):
    """Repeat tensor ``x`` ``n`` times along axis 3 inside a Keras ``Lambda`` layer.

    Args:
        x: Keras tensor to be tiled.
        n: repetition count, forwarded to the lambda as the ``repnum`` argument.

    Returns:
        The output tensor of the ``Lambda`` layer.
    """
    repeat_layer = Lambda(
        lambda tensor, repnum: K.repeat_elements(tensor, repnum, axis=3),
        arguments={'repnum': n},
    )
    return repeat_layer(x)

def conv_bn_act(x, filters, drop_out=0.0):
    """Apply a 3x3 'same' convolution, optional dropout, batch norm and ReLU.

    Args:
        x: input Keras tensor.
        filters: number of convolution filters.
        drop_out: dropout rate; the Dropout layer is only added when it is > 0.

    Returns:
        The activated output tensor.
    """
    features = Conv2D(filters, (3, 3), activation=None, padding='same')(x)

    # Dropout sits between the convolution and the normalisation, and only
    # exists in the graph for a strictly positive rate.
    if drop_out > 0:
        features = Dropout(drop_out)(features)

    normalized = BatchNormalization()(features)
    return Activation('relu')(normalized)

def attention_layer(d, e, n):
    """Scale tensor ``e`` by a sigmoid gate computed from ``d`` and ``e``.

    Both inputs are projected to ``n`` channels with 1x1 convolutions, summed,
    passed through ReLU, reduced to a single channel with another 1x1
    convolution and squashed with a sigmoid. That one-channel map is repeated
    to the channel count of ``e`` and multiplied with ``e``.

    Args:
        d: first input tensor (the callers in this file pass the upsampled
            decoder tensor).
        e: second input tensor (the callers in this file pass the encoder skip
            tensor); this is the tensor that gets gated.
        n: number of channels of the two 1x1 projections.

    Returns:
        The gated version of ``e``.
    """
    proj_d = Conv2D(n, (1, 1), activation=None, padding='same')(d)
    proj_e = Conv2D(n, (1, 1), activation=None, padding='same')(e)

    # Element-wise sum of the two projections (not a channel concatenation).
    summed = add([proj_d, proj_e])

    activated = Activation('relu')(summed)
    gate_logits = Conv2D(1, (1, 1), padding='same')(activated)
    gate = Activation('sigmoid')(gate_logits)

    # Tile the single-channel gate so it matches the channel axis of `e`.
    channels_e = K.int_shape(e)[3]
    tiled_gate = expend_as(gate, channels_e)

    return multiply([tiled_gate, e])

def feature_fused_module(x, filters, compression=0.5, drop_out=0.0):
    """Fuse a dilated and a plain 3x3 convolution branch with learned channel weights.

    Two branches are computed from ``x`` (dilation rate 2 and dilation rate 1),
    each followed by optional dropout, batch norm and ReLU. Their sum is
    global-average-pooled and fed through two Dense layers to produce a
    sigmoid weight ``s`` per channel. The result is
    ``dilated * s + plain * (1 - s)``.

    Args:
        x: input Keras tensor.
        filters: number of filters of both branches and of the output.
        compression: the hidden Dense layer has ``int(filters * compression)`` units.
        drop_out: dropout rate; Dropout layers are only added when it is > 0.

    Returns:
        The fused output tensor.
    """
    def _conv_branch(dilation):
        # Conv -> (Dropout) -> BatchNorm -> ReLU, identical for both branches
        # except for the dilation rate.
        branch = Conv2D(filters, (3, 3), dilation_rate=dilation, padding='same')(x)
        if drop_out > 0:
            branch = Dropout(drop_out)(branch)
        branch = BatchNormalization()(branch)
        return Activation('relu')(branch)

    # The dilated branch is built first, then the plain one.
    dilated = _conv_branch(2)
    plain = _conv_branch(1)

    # Squeeze: sum the branches and pool away the spatial axes.
    squeezed = add([dilated, plain])
    squeezed = GlobalAveragePooling2D()(squeezed)

    # Excite: bottleneck Dense -> BatchNorm -> ReLU -> Dense back to `filters`.
    hidden = Dense(int(filters * compression))(squeezed)
    hidden = BatchNormalization()(hidden)
    hidden = Activation('relu')(hidden)
    logits = Dense(filters)(hidden)

    weight_dilated = Activation('sigmoid')(logits)
    # Complementary weight for the plain branch.
    weight_plain = Lambda(lambda w: 1 - w)(weight_dilated)

    weighted_dilated = multiply([dilated, weight_dilated])
    weighted_plain = multiply([plain, weight_plain])

    return add([weighted_dilated, weighted_plain])

def FF_UNet(input_shape=(256, 256, 3), filters=32, compression=0.5, drop_out=0, half_net=False, attention_gates=False):
    """Build the U-Net-shaped segmentation model made of feature-fused modules.

    The encoder has four stages (two ``feature_fused_module`` blocks followed
    by 2x2 max pooling, doubling ``filters`` after each stage), a two-block
    bottleneck, and a decoder with four stages (2x2 transposed convolution,
    concatenation with the matching encoder output, two blocks). The output is
    a single-channel 1x1 convolution with a sigmoid activation.

    Args:
        input_shape: shape passed to ``Input``; the layers concatenate on
            axis 3, so this is (height, width, channels).
        filters: number of filters of the first encoder stage.
        compression: forwarded to every ``feature_fused_module``.
        drop_out: dropout rate forwarded to every block.
        half_net: when True the decoder blocks are ``conv_bn_act`` instead of
            ``feature_fused_module``.
        attention_gates: when True each skip tensor goes through
            ``attention_layer(upsampled, skip, 1)`` before concatenation.

    Returns:
        A ``Model`` mapping the input image to the sigmoid output map.

    Raises:
        ValueError: if a known spatial size in ``input_shape`` is not a
            multiple of 16. Four poolings followed by four 2x upsamplings only
            line up with the skip tensors for such sizes; other sizes used to
            fail later at the concatenation with a less helpful message.
    """
    depth = 4  # number of pooling stages and of upsampling stages
    total_stride = 2 ** depth
    for size in input_shape[:2]:
        # `None` (unknown size) cannot be checked here and is left to Keras.
        if size is not None and size % total_stride != 0:
            raise ValueError(
                'FF_UNet needs spatial input sizes divisible by %d, got input_shape=%s'
                % (total_stride, (input_shape,))
            )

    inputShape = Input(input_shape)

    def _fused(tensor, n_filters):
        return feature_fused_module(tensor, n_filters, compression=compression, drop_out=drop_out)

    def _plain(tensor, n_filters):
        return conv_bn_act(tensor, n_filters, drop_out=drop_out)

    # The encoder and bottleneck always use fused blocks; only the decoder switches.
    decoder_block = _plain if half_net else _fused

    # Encoder: keep each stage output for the skip connections.
    x = inputShape
    skips = []
    for _ in range(depth):
        x = _fused(x, filters)
        x = _fused(x, filters)
        skips.append(x)
        x = MaxPooling2D((2, 2))(x)
        filters = 2 * filters

    # Bottleneck.
    x = _fused(x, filters)
    x = _fused(x, filters)

    # Decoder: walk the skips from the deepest stage to the shallowest one.
    for skip in reversed(skips):
        filters = filters // 2
        x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)

        if attention_gates:
            skip = attention_layer(x, skip, 1)
        x = concatenate([x, skip], axis=3)

        x = decoder_block(x, filters)
        x = decoder_block(x, filters)

    outputs = Conv2D(1, (1, 1), padding="same", activation='sigmoid')(x)

    return Model(inputs=[inputShape], outputs=[outputs])

In [4]:
from preprocess.prepare_dataset import data_gen

# Dataset root and the `mask_cut_crop` sub-folder whose *.tif files are handed
# to data_gen as a glob pattern.
out_path = '/home/quyet/DATA_ML/WorkSpace/segmentation/data/road_multi'
overlap_mask = os.path.join(out_path, 'mask_cut_crop')

# data_gen returns four values; only the first two (training and validation
# datasets) are kept, since test_data=False.
train_dataset, valid_dataset, _, _ = data_gen(
    os.path.join(overlap_mask, '*.tif'),
    img_size=256,
    batch_size=2,
    N_CLASSES=1,
    numband=3,
    split_ratios=0.8,
    test_data=False,
    multi=False,
)

Training:validation = 1804:452


2022-07-11 08:21:15.029313: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 4806 MB memory:  -> device: 0, name: GeForce GTX 1060 6GB, pci bus id: 0000:05:00.0, compute capability: 6.1


In [6]:
from models import loss
from models.metrics import iou, dice_coef
from models.callback.save_best import SavebestweightsandEarlyStopping

# Run configuration: model_name, mission, img_size and num_class are combined
# into the weights file name further down; num_class also selects the metric
# set and batch_size is passed to model.fit.
model_name = 'ffnet'
mission = 'road'
img_size = 256
num_class = 1 
batch_size = 2

def lr_decay(epoch):
    """Exponentially decaying learning-rate schedule.

    Args:
        epoch: epoch index as passed by the scheduler callback.

    Returns:
        ``1e-3`` for any epoch below 1, otherwise ``1e-3 * 0.9 ** epoch``.
    """
    base_rate = 1e-3
    decay_per_epoch = 0.9
    return base_rate if epoch < 1 else base_rate * decay_per_epoch ** (epoch)

# Validation uses half the training batch size; a batch size of 1 is kept as is.
if batch_size > 1:
    val_batch_size = int(batch_size / 2)
else:
    val_batch_size = batch_size

print("Init metric function")
# Precision and recall are used in both setups; only the accuracy metric
# depends on whether there is a single class (binary) or several.
precision = tf.keras.metrics.Precision()
recall = tf.keras.metrics.Recall()
if num_class == 1:
    accuracy = tf.keras.metrics.BinaryAccuracy(threshold=0.5)
else:
    accuracy = tf.keras.metrics.CategoricalAccuracy()
model_metrics = [precision, recall, dice_coef, iou, accuracy]

checkpoint_filepath= '/home/quyet/DATA_ML/Projects/segmentation/logs/tmp'
log_dir = '/home/quyet/DATA_ML/Projects/segmentation/logs/graph'
weights_path = '/home/quyet/DATA_ML/WorkSpace/segmentation/weights/%s/'%(model_name) +model_name+'_'+mission+'_'+str(img_size)+'_'+str(num_class)+'class.h5'
patience = 10

# Keeps the weights with the lowest validation loss at checkpoint_filepath.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_filepath, save_weights_only= True,
                                                                monitor='val_loss', mode='min', save_best_only=True)
# NOTE(review): LearningRateScheduler applies lr_decay(epoch) at the start of
# every epoch, so a reduction made by ReduceLROnPlateau probably does not
# survive into the next epoch - confirm which of the two schedules is intended.
model_lrscheduler_callback = tf.keras.callbacks.LearningRateScheduler(lr_decay, verbose=1)
model_lrreduce_callback = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=patience, min_lr=1e-7, verbose=1)
# Project callback; it is given the final weights path and the same patience.
model_earlystopping_callback = SavebestweightsandEarlyStopping(patience=patience, weights_path=weights_path)
model_endtrainnan_callback = tf.keras.callbacks.TerminateOnNaN()
model_tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, write_graph=True, write_images=True)
# The NaN guard used to be created but never registered; it is now part of the
# callback list so a diverging run stops instead of training on NaN losses.
model_callbacks = [model_checkpoint_callback, model_lrscheduler_callback,
                    model_lrreduce_callback, model_earlystopping_callback,
                    model_endtrainnan_callback, model_tensorboard_callback,]

model = FF_UNet(attention_gates=True)

optimizer = tf.keras.optimizers.Adam()
model.compile(optimizer = optimizer, loss = loss.balanced_cross_entropy_loss,
             metrics = model_metrics)

# model.load_weights('/home/quyet/DATA_ML/WorkSpace/segmentation/weights/ffnet/ffnet_road_256_1class_val.h5')
# NOTE(review): if train_dataset/valid_dataset are already-batched tf.data
# datasets, the Keras documentation says batch_size should not be passed to
# fit - confirm against data_gen.
history_train = model.fit(train_dataset, batch_size=batch_size, epochs=100, verbose=1,
                      callbacks=model_callbacks, validation_data=valid_dataset,
                      validation_batch_size=val_batch_size, use_multiprocessing=True)

Init metric function

Epoch 00001: LearningRateScheduler setting learning rate to 0.001.
Epoch 1/100


2022-07-11 08:21:56.383175: I tensorflow/stream_executor/cuda/cuda_dnn.cc:366] Loaded cuDNN version 8100

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.


    902/Unknown - 875s 937ms/step - loss: 0.1348 - precision_1: 0.9387 - recall_1: 0.7069 - dice_coef: 0.7654 - iou: 0.6222 - binary_accuracy: 0.7332
Save best train weights.
Save best val weights.

Epoch 00002: LearningRateScheduler setting learning rate to 0.0009000000000000001.
Epoch 2/100
Save best train weights.

Epoch 00003: LearningRateScheduler setting learning rate to 0.0008100000000000001.
Epoch 3/100
Save best train weights.

Epoch 00004: LearningRateScheduler setting learning rate to 0.0007290000000000002.
Epoch 4/100
Save best train weights.

Epoch 00005: LearningRateScheduler setting learning rate to 0.0006561000000000001.
Epoch 5/100
Save best train weights.

Epoch 00006: LearningRateScheduler setting learning rate to 0.00059049.
Epoch 6/100
Save best train weights.

Epoch 00007: LearningRateScheduler setting learning rate to 0.000531441.
Epoch 7/100
Save best train weights.
Save best val weights.

Epoch 00008: LearningRateScheduler setting learning rate to 0.00047829690

In [7]:
import cv2
import numpy as np

from tqdm import tqdm
from osgeo import gdal
from postprocess.convert_tif import dilation_obj, remove_small_items, write_image

def get_im_by_coord(org_im, start_x, start_y, num_band, padding, crop_size, input_size):
    """Cut one padded tile out of ``org_im`` and resize it band by band.

    The tile spans ``crop_size + 2 * padding`` pixels in both directions,
    starting ``padding`` pixels before (``start_x``, ``start_y``).

    Args:
        org_im: image array indexed as [row, column, band].
        start_x: column of the tile's inner (unpadded) origin.
        start_y: row of the tile's inner (unpadded) origin.
        num_band: number of leading bands to keep.
        padding: margin added on every side of the inner crop.
        crop_size: side length of the inner crop.
        input_size: side length each band is resized to (bicubic).

    Returns:
        Array of shape (input_size, input_size, num_band).
    """
    rows = slice(start_y - padding, start_y + crop_size + padding)
    cols = slice(start_x - padding, start_x + crop_size + padding)

    # Bring the band axis to the front so each band can be resized on its own.
    bands_first = org_im[rows, cols].swapaxes(2, 1).swapaxes(1, 0)

    resized_bands = [
        cv2.resize(bands_first[band], (input_size, input_size), interpolation=cv2.INTER_CUBIC)
        for band in range(num_band)
    ]

    # Back to [row, column, band].
    return np.array(resized_bands).swapaxes(0, 1).swapaxes(1, 2)

def get_img_coords(w, h, padding, crop_size):
    """List the tile origins that cover a ``w`` x ``h`` image after padding.

    Origins are expressed in the coordinates of the image padded by
    ``padding`` on every side. Along each axis they advance in steps of
    ``crop_size`` from ``padding``, and the last origin is always
    ``length + padding - crop_size`` so the final tile ends at the image edge.

    Args:
        w: image width before padding.
        h: image height before padding.
        padding: margin added around the image.
        crop_size: side length of one tile (without margin).

    Returns:
        List of ``[x, y]`` pairs, ordered by x first and y second.
    """
    def _axis_starts(length):
        padded_length = length + 2 * padding
        last_start = padded_length - crop_size - padding
        starts = [
            start
            for start in range(padding, padded_length - padding, crop_size)
            if start < last_start
        ]
        # The closing tile is aligned to the far edge, whatever the step left over.
        starts.append(last_start)
        return starts

    x_starts = _axis_starts(w)
    y_starts = _axis_starts(h)

    return [[x, y] for x in x_starts for y in y_starts]

def padded_for_org_img(values, num_band, padding):
    """Reflect-pad the first ``num_band`` bands and return them band-last.

    Args:
        values: array indexed as [band, row, column].
        num_band: number of leading bands to pad and keep.
        padding: pad width handed to ``np.pad`` for each band.

    Returns:
        Array indexed as [row, column, band] holding the padded bands. Its
        shape is also printed.
    """
    reflected_bands = [
        np.pad(values[band], padding, mode='reflect')
        for band in range(num_band)
    ]

    # Stack band-first, then rotate the axes to [row, column, band].
    values = np.array(reflected_bands).swapaxes(0, 1).swapaxes(1, 2)
    print(values.shape)
    del reflected_bands
    return values

def predict(model, values, img_coords, num_band, h, w, padding, crop_size, 
            input_size, batch_size, thresh_hold, choose_stage):
    """Run tiled inference over a padded image and stitch the tiles into one mask.

    Every entry of ``img_coords`` is cut out with ``get_im_by_coord``, the
    tiles are predicted in batches of ``batch_size``, and the central
    ``crop_size`` x ``crop_size`` part of each prediction is written into the
    output mask.

    Args:
        model: Keras model; ``model.predict`` and ``model.outputs`` are used.
        values: padded image array indexed as [row, column, band].
        img_coords: list of ``[x, y]`` tile origins in padded coordinates.
        num_band: number of bands per tile.
        h: height of the output mask (unpadded image height).
        w: width of the output mask (unpadded image width).
        padding: margin around each tile, discarded when stitching.
        crop_size: side length of the part of each tile that is kept.
        input_size: side length of the model input and output.
        batch_size: number of tiles per ``model.predict`` call.
        thresh_hold: cut-off applied to single-channel predictions.
        choose_stage: index of the output to use when the model has several.

    Returns:
        ``float16`` array of shape (h, w). It holds 0/1 values for
        single-channel predictions and arg-max class indices when the
        prediction has two or more channels.
    """
    cut_imgs = [
        get_im_by_coord(values, coord[0], coord[1],
                        num_band, padding, crop_size, input_size)
        for coord in img_coords
    ]

    y_pred = []
    # Bound up front so the stitching loop never sees an undefined name.
    mutilabel = False
    # Step through the tiles in fixed-size batches; the last one may be
    # shorter. The previous index juggling produced no batch at all when there
    # were no more tiles than `batch_size`.
    for batch_start in tqdm(range(0, len(cut_imgs), batch_size)):
        x_batch = np.array(cut_imgs[batch_start:batch_start + batch_size])
        # Scale to [0, 1]; assumes 8-bit pixel values - TODO confirm.
        y_batch = model.predict(x_batch/255)
        if len(model.outputs)>1:
            y_batch = y_batch[choose_stage]
        if y_batch.shape[-1]>=2:
            # Several channels: keep the index of the strongest one per pixel.
            mutilabel = True
            y_batch = np.argmax(y_batch, axis=-1)

        y_pred.extend(y_batch)

    big_mask = np.zeros((h, w), dtype=np.float16)
    for coord, tile_pred in zip(img_coords, y_pred):
        true_mask = tile_pred.reshape((input_size,input_size))
        if not mutilabel:
            true_mask = (true_mask>thresh_hold).astype(np.uint8)
            # Same-size resize followed by a second threshold, kept as in the
            # original so the produced mask does not change.
            true_mask = (cv2.resize(true_mask,(input_size, input_size), interpolation = cv2.INTER_CUBIC)>thresh_hold).astype(np.uint8)
        # Coordinates are [x, y]: y selects the row, x the column. Removing the
        # padding maps them from padded to original image coordinates.
        row0 = coord[1] - padding
        col0 = coord[0] - padding
        big_mask[row0:row0+crop_size, col0:col0+crop_size] = \
            true_mask[padding:padding+crop_size, padding:padding+crop_size]
    del cut_imgs
    return big_mask

# Inference settings. `padding` is derived from them below so that
# crop_size + 2 * padding equals img_size.
img_size = 256
num_band = 3
crop_size = 200
batch_size = 1
thresh_hold = 0.8
choose_stage = 0

# Restore the trained weights into the model built earlier in the notebook.
model.load_weights('/home/quyet/DATA_ML/WorkSpace/segmentation/weights/ffnet/ffnet_road_256_1class_train.h5')

# Read the first `num_band` bands of the test raster.
image_path = '/home/quyet/DATA_ML/Projects/road_multi/crop/img/test.tif'
dataset = gdal.Open(image_path)
values = dataset.ReadAsArray()[0:num_band]
h, w = values.shape[1:3]

# Half of the size difference becomes the margin around every tile.
padding = (img_size - crop_size) // 2
img_coords = get_img_coords(w=w, h=h, padding=padding, crop_size=crop_size)
values = padded_for_org_img(values, num_band, padding)
big_mask = predict(
    model=model,
    values=values,
    img_coords=img_coords,
    num_band=num_band,
    h=h,
    w=w,
    padding=padding,
    crop_size=crop_size,
    input_size=img_size,
    batch_size=batch_size,
    thresh_hold=thresh_hold,
    choose_stage=choose_stage,
)

(6237, 6126, 3)


100%|█████████████████████████████████████████| 960/960 [01:44<00:00,  9.14it/s]


In [8]:
# Hand the stitched mask and the path of the source raster to the project's
# write_image helper. Presumably the source is used as a template for the
# output file (e.g. georeferencing) - confirm in postprocess.convert_tif.
image_path = '/home/quyet/DATA_ML/Projects/road_multi/crop/img/test.tif'
result_path = write_image(image_path, big_mask)

Write image...
