#### Import packages

In [None]:
! pip install -U segmentation-models
! pip install q keras==2.3.1
! pip install tensorflow==2.1.0
! pip install Augmentor

import tensorflow as tf
import segmentation_models as sm
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Activation, BatchNormalization, Conv2D, concatenate, Conv2DTranspose, Dense, Dropout, Input, MaxPooling2D
from keras.callbacks import History
from tensorflow.keras.optimizers import * 

import PIL
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import cv2
import glob
import Augmentor
from natsort import natsorted
import os
import random
import requests
import zipfile

#### Drive mount
Only run this cell when using Google Colab.

In [None]:
# Mount Google Drive at /content/drive.
# `google.colab` can only be imported on Google Colab; on any other kernel the
# mount is skipped so that "Restart & Run All" does not stop at this cell.
try:
    from google.colab import drive
except ImportError:
    print("google.colab is not available - skipping Google Drive mount.")
else:
    drive.mount('/content/drive')

#### Define constants

In [None]:
# Side length in pixels: augmented image/mask pairs are resized to
# png_dim x png_dim before training.
png_dim = 512

# Directory holding the segmentation-mask PNGs (placeholder - set before running).
mask_png_dir = ""
# Directory holding the input-image PNGs, presumably satellite imagery judging
# by the name (placeholder - set before running).
sat_png_dir = ""

# Destination passed to model.save() after training (placeholder - set before running).
save_path = ""

#### Load training data
Load the paired images and segmentation masks from the input directories defined above, then augment them with random rotations, flips, and crops.

In [None]:
# Seed the random number generators so the augmented sample is repeatable.
# NOTE(review): Augmentor is assumed to draw from Python's `random` module;
# NumPy is seeded as a precaution - confirm for the installed version.
AUGMENTATION_SEED = 42
random.seed(AUGMENTATION_SEED)
np.random.seed(AUGMENTATION_SEED)

# Number of augmented (image, mask) pairs drawn from the pipeline.
N_AUGMENTED_SAMPLES = 700

# Natural sorting keeps "img_2.png" before "img_10.png", so the i-th image is
# paired with the i-th mask (this relies on both directories using matching names).
ground_truth_images = natsorted(glob.glob(sat_png_dir + "/*.png"))
segmentation_mask_images = natsorted(glob.glob(mask_png_dir + "/*.png"))

# Fail early instead of letting zip() silently drop or misalign pairs.
if not ground_truth_images:
    raise FileNotFoundError(
        "No PNG images found in sat_png_dir={!r}".format(sat_png_dir))
if len(ground_truth_images) != len(segmentation_mask_images):
    raise ValueError(
        "Found {} images but {} masks; every image needs exactly one mask.".format(
            len(ground_truth_images), len(segmentation_mask_images)))

collated_images_and_masks = list(zip(ground_truth_images,
                                     segmentation_mask_images))


def _load_png(path):
    """Read one image file into a NumPy array and close the file handle."""
    with Image.open(path) as img:
        # np.array copies the pixel data, so it stays valid after the file closes.
        return np.array(img)


# One [image_array, mask_array] entry per pair, as expected by DataPipeline.
images = [[_load_png(path) for path in pair] for pair in collated_images_and_masks]

# Every operation is applied identically to an image and its mask.
p = Augmentor.DataPipeline(images)
p.rotate90(probability=0.5)
p.rotate270(probability=0.5)
p.flip_left_right(probability=0.8)
p.flip_top_bottom(probability=0.3)
p.crop_random(probability=1, percentage_area=0.5)
p.resize(probability=1.0, width=png_dim, height=png_dim)

augmented_images = p.sample(N_AUGMENTED_SAMPLES)

# Visual sanity check: show one augmented image next to its mask.
r_index = 3
f, axarr = plt.subplots(1, 2, figsize=(6,4))
axarr[0].imshow(augmented_images[r_index][0])
axarr[0].set_title("Augmented image")
axarr[1].imshow(augmented_images[r_index][1], cmap="gray")
axarr[1].set_title("Augmented mask")

# Split the [image, mask] pairs into two stacked arrays.
X_data = np.array([pair[0] for pair in augmented_images])
Y_data = np.array([pair[1] for pair in augmented_images])

# Keep the first three channels of the images (drops e.g. an alpha channel).
X_data = X_data[:,:,:,:3]
# Masks need exactly one channel. Single-channel PNGs stack to (N, H, W) and get
# a channel axis added; multi-channel masks keep only their first channel.
if Y_data.ndim == 3:
    Y_data = Y_data[..., np.newaxis]
else:
    Y_data = Y_data[:,:,:,:1]

#### Train default pretrained U-Net
Train a U-Net model on the augmented data using a ResNet-34 encoder pretrained on ImageNet (Deng et al. 2009). Training uses a binary cross-entropy loss and the Adam optimizer.

In [None]:
BACKBONE = 'resnet34'
preprocess_input = sm.get_preprocessing(BACKBONE)

# define model
# Built once: the original cell constructed an identical network twice and
# discarded the first instance. The sigmoid output is stated explicitly because
# the binary cross-entropy loss expects probabilities.
model = sm.Unet(BACKBONE, classes=1, encoder_weights='imagenet', activation='sigmoid')
model.compile(
    'Adam',
    loss=sm.losses.binary_crossentropy,
    metrics=[sm.metrics.iou_score, sm.metrics.precision, sm.metrics.recall],
)

# Encoder-specific preprocessing applies to the input images only.
x_train = preprocess_input(X_data)

# Masks are training targets, not encoder inputs: binary cross-entropy and the
# IoU/precision/recall metrics need them in [0, 1]. Masks already in that range
# are left untouched.
# NOTE(review): the 255 divisor assumes 8-bit mask PNGs - confirm.
y_train = Y_data.astype('float32')
if y_train.size and y_train.max() > 1.0:
    y_train /= 255.0

# `detection_model` holds the object returned by fit() (the training history),
# not a model; the name is kept in case other cells read it.
detection_model = model.fit(
    x=x_train,
    y=y_train,
    batch_size=16,
    epochs=60,
    validation_split=0.2
)

model.save(save_path)

#### Optional: Train custom U-Net 
If sufficient training data is available, or a different U-Net specification is required (e.g. more image channels or a different resolution), the custom U-Net defined below can be adapted and trained instead of the pretrained model.

In [4]:
def conv_layers(c, filters, size, act, k_init, pad, dropout_rate=0.1):
  """Apply a Conv2D -> BatchNormalization -> Dropout -> Conv2D block to a tensor.

  Args:
    c: Input tensor the block is applied to.
    filters: Number of filters used by both convolutions.
    size: Kernel size used by both convolutions.
    act: Activation passed to both convolutions.
    k_init: Kernel initializer passed to both convolutions.
    pad: Padding mode passed to both convolutions.
    dropout_rate: Rate of the Dropout layer between the two convolutions.
      Defaults to 0.1, the value that used to be hard-coded.

  Returns:
    The output tensor of the second convolution. No normalization is applied
    after it; callers add their own if needed.
  """
  conv_kwargs = dict(activation=act, kernel_initializer=k_init, padding=pad)
  x = Conv2D(filters, size, **conv_kwargs)(c)
  x = BatchNormalization()(x)
  x = Dropout(dropout_rate)(x)
  x = Conv2D(filters, size, **conv_kwargs)(x)
  return x

def custom_unet(input_shape=(128, 128, 3), show_summary=True):
  """Build a U-Net with four down-sampling and four up-sampling stages.

  Layers are created in the same order and with the same filter counts as the
  original hand-unrolled definition; only the input shape and the summary
  printing are now configurable.

  Args:
    input_shape: (height, width, channels) of the input images. Height and
      width must be divisible by 16 (four 2x2 poolings), otherwise the
      up-sampled tensors no longer match the skip connections. Defaults to
      (128, 128, 3), the previously hard-coded shape.
    show_summary: When True (default, the original behavior), print
      `model.summary()` before returning.

  Returns:
    An uncompiled Keras `Model` named "satellite_unet" with a single-channel
    sigmoid output.

  Raises:
    ValueError: If the height or width of `input_shape` is not divisible by 16.
  """
  for dim in input_shape[:2]:
    if dim is not None and dim % 16 != 0:
      raise ValueError(
          "input height and width must be divisible by 16, got {}".format(input_shape))

  # Arguments shared by every conv_layers block: kernel size, activation,
  # kernel initializer, padding.
  conv_args = ((3, 3), 'relu', 'he_normal', 'same')

  encoder_filters = (32, 64, 128, 256)
  bottleneck_filters = 512
  # (transpose-conv filters, conv-block filters) per decoder stage, deepest first.
  # NOTE(review): the last stage uses 256 conv filters where a symmetric U-Net
  # would use 32. This looks like a copy-paste slip, but it is kept because
  # changing it alters the weight shapes and previously saved weights would no
  # longer load.
  decoder_stages = ((256, 256), (128, 128), (64, 64), (32, 256))

  in_layer = Input(input_shape)

  # Contracting path: keep each stage's output for the skip connections.
  skips = []
  x = in_layer
  for filters in encoder_filters:
    stage_out = conv_layers(x, filters, *conv_args)
    stage_out = BatchNormalization()(stage_out)
    skips.append(stage_out)
    x = MaxPooling2D((2, 2))(stage_out)

  # Bottleneck.
  x = conv_layers(x, bottleneck_filters, *conv_args)
  x = BatchNormalization()(x)

  # Expanding path: up-sample, merge with the matching encoder output, convolve.
  for (up_filters, conv_filters), skip in zip(decoder_stages, reversed(skips)):
    x = Conv2DTranspose(up_filters, (2, 2), strides=(2, 2), padding='same')(x)
    # Default axis (-1) is the channel axis for these 4-D tensors, i.e. axis 3.
    x = concatenate([x, skip])
    x = conv_layers(x, conv_filters, *conv_args)
    x = BatchNormalization()(x)

  out_layer = Conv2D(1, (1, 1), activation='sigmoid')(x)
  model = Model(inputs=[in_layer], outputs=[out_layer], name="satellite_unet")
  if show_summary:
    model.summary()

  return model

In [5]:
# Build the custom U-Net (custom_unet() also prints the layer summary).
# NOTE: this rebinds the global `model`; if the pretrained U-Net cell was run
# earlier, that model is no longer reachable under this name.
model = custom_unet()

Model: "satellite_unet"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 128, 128, 3) 0                                            
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 128, 128, 32) 896         input_2[0][0]                    
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 128, 128, 32) 128         conv2d_10[0][0]                  
__________________________________________________________________________________________________
dropout_5 (Dropout)             (None, 128, 128, 32) 0           batch_normalization_10[0][0]     
_____________________________________________________________________________________

In [13]:
def compile_model(learning_rate, pt_weights):
  model.compile(optimizer=Adam(lr=learning_rate), loss="binary_crossentropy")
  if(pt_weights):
    model.load_weights(pt_weights)