# This notebook was used in the Kaggle competition "tgs-salt-identification-challenge" for image segmentation
## The model takes a 101x101 grayscale image (a single channel) plus one extra feature and outputs a 101x101 binary mask

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy
from PIL import Image
from scipy import ndimage
import pandas as pd
import os

from tqdm import tqdm_notebook
#from itertools import chain
from skimage.io import imread, imshow
from skimage.transform import resize
from skimage.morphology import label

from keras.preprocessing.image import load_img

## Get the training and test data

In [None]:
# Locations of the competition data, relative to the working directory.
# Every path keeps its trailing slash so a file name can be appended directly.
train_img_dir, train_mask_dir, test_img_dir = (
    "train/images/",
    "train/masks/",
    "test/images/",
)

In [None]:
# Image ids are the file names without their ".png" extension.
# Only PNG files are kept: any other directory entry (e.g. a hidden file) would
# produce a bogus id that fails later when "<id>.png" is loaded.
# The ids are sorted because os.listdir returns entries in arbitrary order.
train_img_names = sorted(
    os.path.splitext(fname)[0]
    for fname in os.listdir(train_img_dir)
    if fname.endswith('.png')
)
test_img_names = sorted(
    os.path.splitext(fname)[0]
    for fname in os.listdir(test_img_dir)
    if fname.endswith('.png')
)

In [None]:
# Ids of the labelled training images: only the first CSV column is read,
# and the "id" column becomes the index of the frame.
train_df = pd.read_csv(
    "train.csv",
    usecols=[0],
    index_col="id",
)

In [None]:
# depths.csv is indexed by image id; train.csv contributes the ids of the
# labelled images, which are then joined with their depth entries.
depths_df = pd.read_csv("depths.csv", index_col="id")
train_df = pd.read_csv("train.csv", index_col="id", usecols=[0]).join(depths_df)

# Every depth entry whose id is not a training id belongs to the test set.
is_train_id = depths_df.index.isin(train_df.index)
test_df = depths_df[~is_train_id]

print("Number of training samples:", len(train_df), "\n", "Number of test samples:", len(test_df))

In [None]:
train_df["images"] = [np.array(load_img("train/images/{}.png".format(idx), grayscale=True)) / 255 for idx in train_img_names]

In [None]:
train_df["masks"] = [np.array(load_img("train/masks/{}.png".format(idx), grayscale=True)) / 255 for idx in train_img_names]

In [None]:
# Is the mask empty, i.e. is there any salt at all?
# 1 if the mask has at least one non-zero pixel, 0 otherwise.
# Iterating the Series directly replaces Series.iteritems(), which was removed
# in pandas 2.0, and yields the same values in the same order.
saltarray = [1 if mask.any() else 0 for mask in train_df.masks]
train_df["salt"] = saltarray

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
# Original image size and the size fed to the network.  They are kept separate
# so that other input sizes can be tried, e.g. for a model with pre-trained
# weights that requires a fixed size.
img_size_ori = 101
img_size_target = 101


def upsample(img):
    """Resize img from the original size to the network size.

    Returns img itself (no copy) while both sizes are equal.
    """
    if img_size_ori != img_size_target:
        target_shape = (img_size_target, img_size_target)
        return resize(img, target_shape, mode='constant', preserve_range=True)
    return img


def downsample(img):
    """Resize img from the network size back to the original size.

    Returns img itself (no copy) while both sizes are equal.
    """
    if img_size_ori != img_size_target:
        original_shape = (img_size_ori, img_size_ori)
        return resize(img, original_shape, mode='constant', preserve_range=True)
    return img

In [None]:
# Stack the per-sample arrays into (N, size, size, 1) tensors and hold out 10%
# of the samples for validation.  Images, masks, depths and salt flags are
# split in one call (fixed seed) so they stay aligned with each other.
_tensor_shape = (-1, img_size_target, img_size_target, 1)
_all_images = np.array(train_df.images.map(upsample).tolist()).reshape(_tensor_shape)
_all_masks = np.array(train_df.masks.map(upsample).tolist()).reshape(_tensor_shape)

(x_train, x_valid,
 y_train, y_valid,
 depth_train, depth_valid,
 salt_train, salt_valid) = train_test_split(
    _all_images,
    _all_masks,
    train_df.z.values,
    train_df.salt.values,
    test_size=0.1,
    random_state=1,
)

In [None]:
# Standardise the depth feature.  Mean and standard deviation are taken from
# the training split only and reused unchanged for the validation split.
depth_train_mean = depth_train.mean(axis=0, keepdims=True)
depth_train_std = depth_train.std(axis=0, keepdims=True)

depth_train = (depth_train.astype("float64") - depth_train_mean) / depth_train_std
depth_valid = (depth_valid.astype("float64") - depth_train_mean) / depth_train_std

## Augmentation of the data 

In [None]:
# Augmenting the data
# A dataset of pictures can typically be extended by symmetry operations such
# as mirroring, rotation, adding noise, etc.  Here only mirroring is used:
# original, left-right flip, up-down flip and both flips combined, so the
# training and validation data each grow by a factor of 4.
def _with_mirrors(batch):
    """Return batch, its left-right flips, then the up-down flips of both."""
    batch = np.concatenate((batch, [np.fliplr(img) for img in batch]), axis=0)
    return np.concatenate((batch, [np.flipud(img) for img in batch]), axis=0)


def _four_copies(values):
    """Repeat per-sample values so they line up with the 4x mirrored images."""
    return np.concatenate((values, values, values, values), axis=0)


x_train, y_train = _with_mirrors(x_train), _with_mirrors(y_train)
x_valid, y_valid = _with_mirrors(x_valid), _with_mirrors(y_valid)
depth_train, depth_valid = _four_copies(depth_train), _four_copies(depth_valid)
salt_train, salt_valid = _four_copies(salt_train), _four_copies(salt_valid)
## rot90 or scaling could be tried as well,
## but this kind of augmentation did not help much
print(x_train.shape)
print(y_valid.shape)

## Now let me define the model

In [None]:
from keras.models import Model, load_model, save_model
from keras.layers import Input,Dropout,BatchNormalization,Activation,Add, ZeroPadding2D
from keras.layers.core import Lambda, RepeatVector, Reshape, Flatten, Dense
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPooling2D
from keras.layers.merge import concatenate
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras import backend as K
from keras import optimizers

import tensorflow as tf

from keras.preprocessing.image import array_to_img, img_to_array, load_img#,save_img

In [None]:
def BatchActivate(x):
    """Apply batch normalisation followed by a ReLU activation to tensor x."""
    normalised = BatchNormalization()(x)
    return Activation('relu')(normalised)

def convolution_block(x, filters, size, strides=(1,1), padding='same', activation=True):
    """Apply a 2D convolution, optionally followed by batch norm + ReLU.

    x: input tensor.
    filters, size, strides, padding: forwarded to Conv2D.
    activation: when truthy, the convolution output is passed through
        BatchActivate; otherwise the raw convolution output is returned.
    """
    x = Conv2D(filters, size, strides=strides, padding=padding)(x)
    # Plain truth test instead of "== True": the flag is a boolean switch.
    if activation:
        x = BatchActivate(x)
    return x

def residual_block(blockInput, num_filters=16, batch_activate = False):
    """Pre-activated residual unit: BN/ReLU, two 3x3 convolutions, skip add.

    blockInput: input tensor; it is added back to the convolution branch, so
        its channel count has to match num_filters.
    num_filters: number of filters of both 3x3 convolutions.
    batch_activate: when true, the summed output is passed through
        BatchActivate before it is returned.
    """
    branch = BatchActivate(blockInput)
    branch = convolution_block(branch, num_filters, (3,3))
    # No activation on the second convolution: it is applied after the sum.
    branch = convolution_block(branch, num_filters, (3,3), activation=False)
    merged = Add()([branch, blockInput])
    return BatchActivate(merged) if batch_activate else merged

### The metric for the competition was Intersection over Union (IoU)

In [None]:
def get_iou_vector(A, B):
    """Return the mean thresholded IoU score of a batch.

    A: batch of ground-truth masks, B: batch of predicted masks; entries > 0
    count as foreground.  For every sample the IoU is compared against the
    thresholds 0.5, 0.55, ..., 0.95 and the fraction of thresholds it exceeds
    is the sample score.  The batch score is the mean of the sample scores.
    """
    cutoffs = np.arange(0.5, 1, 0.05)
    sample_scores = []
    for idx, truth in enumerate(A):
        truth_fg = truth > 0
        pred_fg = B[idx] > 0
        overlap = np.count_nonzero(np.logical_and(truth_fg, pred_fg))
        combined = np.count_nonzero(np.logical_or(truth_fg, pred_fg))
        # The small constant makes two empty masks count as a perfect match.
        iou = (overlap + 1e-10) / (combined + 1e-10)
        sample_scores.append(np.mean(iou > cutoffs))

    return np.mean(sample_scores)

def my_iou_metric(label, pred):
    """Keras metric: thresholded IoU of sigmoid outputs binarised at 0.5.

    The numpy implementation get_iou_vector is wrapped with tf.py_func so it
    can be evaluated inside the TensorFlow graph.
    """
    binarised = pred > 0.5
    return tf.py_func(get_iou_vector, [label, binarised], tf.float64)


In [None]:
# Build a U-Net model: an encoder + decoder architecture, see https://arxiv.org/abs/1505.04597
# The additional feature (=depth) is fed into the middle layer via concatenation.
#
# Spatial sizes, following the stage comments below:
#   encoder: 101 -> 50 -> 25 -> 12 -> 6, decoder: 6 -> 12 -> 25 -> 50 -> 101.
# Every stage is: 3x3 convolution, two residual blocks (the second one ending
# in batch norm + ReLU), then pooling + dropout (encoder) or, in the decoder,
# a transposed convolution + concatenation with the matching encoder output.

start_neurons = 32  # filters of the first stage; doubled at every deeper stage
DropoutRatio = 0.5
input_img = Input((101,101,1), name='img')
input_features = Input((1, ), name='feat')

# 101 -> 50 (the first stage uses only half the dropout ratio)
conv1 = Conv2D(start_neurons * 1, (3, 3), activation=None, padding="same")(input_img)
conv1 = residual_block(conv1,start_neurons * 1)
conv1 = residual_block(conv1,start_neurons * 1, True)
pool1 = MaxPooling2D((2, 2))(conv1)
pool1 = Dropout(DropoutRatio/2)(pool1)

# 50 -> 25
conv2 = Conv2D(start_neurons * 2, (3, 3), activation=None, padding="same")(pool1)
conv2 = residual_block(conv2,start_neurons * 2)
conv2 = residual_block(conv2,start_neurons * 2, True)
pool2 = MaxPooling2D((2, 2))(conv2)
pool2 = Dropout(DropoutRatio)(pool2)

# 25 -> 12

conv3 = Conv2D(start_neurons * 4, (3, 3), activation=None, padding="same")(pool2)
conv3 = residual_block(conv3,start_neurons * 4)
conv3 = residual_block(conv3,start_neurons * 4, True)
pool3 = MaxPooling2D((2, 2))(conv3)    
pool3 = Dropout(DropoutRatio)(pool3)

# 12 -> 6

conv4 = Conv2D(start_neurons * 8, (3, 3), activation=None, padding="same")(pool3)

conv4 = residual_block(conv4,start_neurons * 8)

conv4 = residual_block(conv4,start_neurons * 8, True)

pool4 = MaxPooling2D((2, 2))(conv4)

pool4 = Dropout(DropoutRatio)(pool4)

# Add depth: the scalar feature is repeated 6*6 times, reshaped to a 6x6x1 map
# and appended to pool4 as one extra channel.

f_repeat = RepeatVector(6*6)(input_features)
f_conv = Reshape((6, 6, 1))(f_repeat)
pool4_feat = concatenate([pool4, f_conv], -1)

# Middle (bottleneck) stage, operating on the 6x6 maps

convm = Conv2D(start_neurons * 16, (3, 3), activation=None, padding="same")(pool4_feat)

convm = residual_block(convm,start_neurons * 16)

convm = residual_block(convm,start_neurons * 16, True)
    

    # 6 -> 12
deconv4 = Conv2DTranspose(start_neurons * 8, (3, 3), strides=(2, 2), padding="same")(convm)

uconv4 = concatenate([deconv4, conv4])

uconv4 = Dropout(DropoutRatio)(uconv4)
    
uconv4 = Conv2D(start_neurons * 8, (3, 3), activation=None, padding="same")(uconv4)

uconv4 = residual_block(uconv4,start_neurons * 8)

uconv4 = residual_block(uconv4,start_neurons * 8, True)
    
    # 12 -> 25
    # "valid" padding is used here instead of "same" (kept commented out
    # below): it yields the odd size 25 needed to concatenate with conv3.
#deconv3 = Conv2DTranspose(start_neurons * 4, (3, 3), strides=(2, 2), padding="same")(uconv4)

deconv3 = Conv2DTranspose(start_neurons * 4, (3, 3), strides=(2, 2), padding="valid")(uconv4)

uconv3 = concatenate([deconv3, conv3])    

uconv3 = Dropout(DropoutRatio)(uconv3)
    
uconv3 = Conv2D(start_neurons * 4, (3, 3), activation=None, padding="same")(uconv3)

uconv3 = residual_block(uconv3,start_neurons * 4)

uconv3 = residual_block(uconv3,start_neurons * 4, True)

    # 25 -> 50
deconv2 = Conv2DTranspose(start_neurons * 2, (3, 3), strides=(2, 2), padding="same")(uconv3)

uconv2 = concatenate([deconv2, conv2])
        
uconv2 = Dropout(DropoutRatio)(uconv2)

uconv2 = Conv2D(start_neurons * 2, (3, 3), activation=None, padding="same")(uconv2)

uconv2 = residual_block(uconv2,start_neurons * 2)

uconv2 = residual_block(uconv2,start_neurons * 2, True)
    
    # 50 -> 101
    # again "valid" padding, to reach the odd size 101 of conv1
    #deconv1 = Conv2DTranspose(start_neurons * 1, (3, 3), strides=(2, 2), padding="same")(uconv2)

deconv1 = Conv2DTranspose(start_neurons * 1, (3, 3), strides=(2, 2), padding="valid")(uconv2)

uconv1 = concatenate([deconv1, conv1])
    
uconv1 = Dropout(DropoutRatio)(uconv1)

uconv1 = Conv2D(start_neurons * 1, (3, 3), activation=None, padding="same")(uconv1)

uconv1 = residual_block(uconv1,start_neurons * 1)

uconv1 = residual_block(uconv1,start_neurons * 1, True)
    
#uconv1 = Dropout(DropoutRatio/2)(uconv1)
#output_layer = Conv2D(1, (1,1), padding="same", activation="sigmoid")(uconv1)

# The 1x1 output convolution and the sigmoid are separate layers, so the
# pre-sigmoid tensor is available as the input of the last layer.
output_layer_noActi = Conv2D(1, (1,1), padding="same", activation=None)(uconv1)

output_layer =  Activation('sigmoid')(output_layer_noActi)


# Two inputs (image + depth feature), one output (per-pixel salt probability).
model = Model(inputs=[input_img, input_features], outputs=[output_layer])
model.compile(optimizer='adam', loss='binary_crossentropy',metrics=[my_iou_metric]) 
model.summary()

## Train the model

In [None]:
# Train the model on the training AND validation samples together.
# No validation data is passed to fit, so all callbacks monitor the training
# metric 'my_iou_metric'.
early_stopping = EarlyStopping(monitor='my_iou_metric', mode = 'max',patience=20, verbose=1)
model_checkpoint = ModelCheckpoint("./unet_best.model",monitor='my_iou_metric', 
                                   mode = 'max', save_best_only=True, verbose=1)
reduce_lr = ReduceLROnPlateau(monitor='my_iou_metric', mode = 'max', factor=0.5, patience=5, min_lr=0.0001, verbose=1)
#reduce_lr = ReduceLROnPlateau(factor=0.2, patience=5, min_lr=0.00001, verbose=1)

epochs = 50
batch_size = 32
# The image input must be x_train + x_valid.  Concatenating x_train with
# y_train fed masks in as images and made the "img" input a different length
# from the "feat" input and the targets.
history = model.fit({"img":np.concatenate((x_train,x_valid)), 
                     "feat":np.concatenate((depth_train,depth_valid))},  
                    np.concatenate((y_train,y_valid)),
                    #validation_data=({'img': x_valid, 'feat': depth_valid}, y_valid), 
                    epochs=epochs,
                    batch_size=batch_size,
                    callbacks=[early_stopping, model_checkpoint, reduce_lr], 
                    verbose=2)

In [None]:
# Plot the loss and the metric recorded during training.
# Validation curves are drawn only when the history contains them: the fit
# call above passes no validation data, so the "val_*" keys are missing and
# reading them unconditionally raises a KeyError.
fig, (ax_loss, ax_score) = plt.subplots(1, 2, figsize=(15,5))
ax_loss.plot(history.epoch, history.history["loss"], label="Train loss")
if "val_loss" in history.history:
    ax_loss.plot(history.epoch, history.history["val_loss"], label="Validation loss")
ax_loss.legend()
ax_score.plot(history.epoch, history.history["my_iou_metric"], label="Train score")
if "val_my_iou_metric" in history.history:
    ax_score.plot(history.epoch, history.history["val_my_iou_metric"], label="Validation score")
ax_score.legend()

## Evaluate the model and find optimal threshold for predictions

In [None]:
# Reload the saved checkpoint.  The custom metric has to be passed by name so
# that Keras can deserialise the compiled model.
model = load_model(
    "unet_best.model",
    custom_objects={"my_iou_metric": my_iou_metric},
)

In [None]:
# Averaging over reflections helps, but only a little; it makes almost no
# difference whether only the left-right flip is used or the up-down flip too.
def predict_result(model, test_data, img_size_target):
    """Predict masks, averaged over the original and left-right mirrored input.

    model: object whose predict() accepts a dict with keys "img" and "feat".
    test_data: dict with the image batch under "img" and the features under
        "feat"; the features are reused unchanged for the mirrored images.
    img_size_target: side length of the square prediction maps.

    Returns an array of shape (-1, img_size_target, img_size_target): the mean
    of the direct prediction and the flipped-back mirrored prediction.
    """
    map_shape = (-1, img_size_target, img_size_target)
    mirrored_inputs = {
        "img": np.array([np.fliplr(img) for img in test_data["img"]]),
        "feat": test_data["feat"],
    }

    direct = model.predict(test_data).reshape(map_shape)
    mirrored = model.predict(mirrored_inputs).reshape(map_shape)
    # Flip the mirrored predictions back so both maps line up pixel by pixel.
    flipped_back = np.array([np.fliplr(pred) for pred in mirrored])
    return (direct + flipped_back) / 2

## Let me now work on the optimal threshold. Many other competitors do the threshold optimization on the validation data. Let me try this as well and compare with a more fancy idea: a dynamical threshold, see later...

In [None]:
# Score the model and do a threshold optimization by the best IoU.

# taken from src: https://www.kaggle.com/aglotero/another-iou-metric
def iou_metric(y_true_in, y_pred_in, print_table=False):
    """Return the average precision of one predicted mask over IoU thresholds.

    y_true_in, y_pred_in: ground-truth and predicted mask.  Values are binned
        with the edges [0, 0.5, 1], so values in [0.5, 1] count as foreground.
    print_table: when true, print TP/FP/FN and precision for each threshold.

    The precision is evaluated at the IoU thresholds 0.5, 0.55, ..., 0.95 and
    the mean over those thresholds is returned.
    """
    # Explicit bin edges: with automatic bins an all-zero input would be
    # binned as [-0.5, 0, 0.5], which is wrong.
    edges = [0, 0.5, 1]
    overlap = np.histogram2d(y_true_in.flatten(), y_pred_in.flatten(), bins=(edges, edges))[0]

    # Areas per class, shaped so that they broadcast against the 2x2 overlap.
    true_area = np.expand_dims(np.histogram(y_true_in, bins=edges)[0], -1)
    pred_area = np.expand_dims(np.histogram(y_pred_in, bins=edges)[0], 0)
    union = true_area + pred_area - overlap

    # Drop the background row/column; zeros are replaced by a tiny value so
    # that an empty mask predicted as empty gives IoU 1 instead of 0/0.
    overlap = overlap[1:, 1:]
    overlap[overlap == 0] = 1e-9
    union = union[1:, 1:]
    union[union == 0] = 1e-9

    iou = overlap / union

    def confusion_counts(threshold):
        """Count matched and unmatched objects at one IoU threshold."""
        hits = iou > threshold
        hits_per_true = np.sum(hits, axis=1)
        hits_per_pred = np.sum(hits, axis=0)
        matched_true = np.sum(hits_per_true == 1)    # true objects with a match
        unmatched_pred = np.sum(hits_per_pred == 0)  # predictions without a match
        unmatched_true = np.sum(hits_per_true == 0)  # true objects without a match
        return matched_true, unmatched_pred, unmatched_true

    if print_table:
        print("Thresh\tTP\tFP\tFN\tPrec.")

    prec = []
    for t in np.arange(0.5, 1.0, 0.05):
        tp, fp, fn = confusion_counts(t)
        total = tp + fp + fn
        p = tp / total if total > 0 else 0
        if print_table:
            print("{:1.3f}\t{}\t{}\t{}\t{:1.3f}".format(t, tp, fp, fn, p))
        prec.append(p)

    average_precision = np.mean(prec)
    if print_table:
        print("AP\t-\t-\t-\t{:1.3f}".format(average_precision))
    return average_precision

def iou_metric_batch(y_true_in, y_pred_in):
    """Return the mean of iou_metric over all samples of a batch.

    Sample i of y_true_in is scored against sample i of y_pred_in.
    """
    scores = [
        iou_metric(truth, y_pred_in[i])
        for i, truth in enumerate(y_true_in)
    ]
    return np.mean(scores)



In [None]:
# Flip-averaged predictions of the U-Net for the validation samples.
valid_inputs = {"img": x_valid, "feat": depth_valid}
preds_valid = predict_result(model, valid_inputs, img_size_target)

In [None]:
## Score the last model: evaluate 30 candidate thresholds between 0.2 and 0.8
## on the validation data.
thresholds = np.linspace(0.2, 0.8, 30)

_scores = []
for threshold in tqdm_notebook(thresholds):
    _scores.append(iou_metric_batch(y_valid, preds_valid > threshold))
ious = np.array(_scores)
print(ious)

In [None]:
# Instead of using the default 0.5 as threshold, use the validation data to
# find the best threshold.
threshold_best_index = np.argmax(ious)
threshold_best = thresholds[threshold_best_index]
iou_best = ious[threshold_best_index]

plt.plot(thresholds, ious)
plt.plot(threshold_best, iou_best, "xr", label="Best threshold")
plt.title("Threshold vs IoU ({}, {})".format(threshold_best, iou_best))
plt.xlabel("Threshold")
plt.ylabel("IoU")
plt.legend()

## Alternatively, I want to try something else: let me introduce a threshold that reflects the probability that there is any salt in the image. To do so I build a second model - a binary classifier - that predicts whether there is salt in the image at all

In [None]:
## Let me write a binary classifier that decides wheter there is salt or not in the image


def SaltModel(input_shape):
    """Build a small CNN that classifies whether an image contains salt.

    input_shape: shape of one input image, passed to Input.

    Returns an uncompiled Keras Model with a single sigmoid output unit.
    """
    X_input = Input(input_shape)

    # Stem: zero-pad the border by 3, then 7x7 CONV -> BN -> RELU -> MAXPOOL.
    X = ZeroPadding2D((3, 3))(X_input)
    X = Conv2D(32, (7, 7), strides = (1, 1), name = 'conv0')(X)
    X = BatchNormalization(axis = 3, name = 'bn0')(X)
    X = Activation('relu')(X)
    X = MaxPooling2D((2, 2))(X)

    # Two identical 3x3 CONV -> BN -> RELU stages.
    for _ in range(2):
        X = Conv2D(32, (3, 3), strides = (1, 1), padding="same")(X)
        X = BatchNormalization()(X)
        X = Activation('relu')(X)

    X = MaxPooling2D((2, 2), name='max_pool')(X)

    # Classifier head: flatten, one hidden dense layer with dropout, sigmoid.
    X = Flatten()(X)
    X = Dense(128)(X)
    X = BatchNormalization()(X)
    X = Activation('relu')(X)
    X = Dropout(0.5)(X)
    X = Dense(1, activation='sigmoid', name='fc')(X)

    return Model(inputs = X_input, outputs = X)

In [None]:
# Salt / no-salt classifier on single-channel 101x101 images.
saltmodel = SaltModel((101,101,1))
saltmodel.compile(
    optimizer="Adam",
    loss='binary_crossentropy',
    metrics=["accuracy"],
)

In [None]:
# Train the classifier on the augmented images; the target is the per-image
# salt flag.  Note that this rebinds `history`.
history = saltmodel.fit(
    x=x_train,
    y=salt_train,
    validation_data=(x_valid, salt_valid),
    batch_size=64,
    epochs=10,
    verbose=2,
)

In [None]:
#first see how this performs on the validation data
# First check the classifier on the validation data: predicted salt
# probability for every validation image.
salt_threshold_valid = saltmodel.predict(x=x_valid)

In [None]:
# Validation score with the dynamic threshold: the pixel threshold of each
# image is lowered when the classifier thinks the image contains salt.
# - All validation samples are scored (previously a hard-coded range(800)).
# - The formula is the one used for the submission,
#   0.5 - 0.2 * (p_salt - 0.5), so this score measures what is submitted.
metric = [
    iou_metric(mask, pred > (0.5 - 0.2 * (p_salt - 0.5)))
    for mask, pred, p_salt in zip(y_valid, preds_valid, salt_threshold_valid)
]
np.mean(metric)

## Now comes the test data for the competition

In [None]:
# Load the test images in test_img_names order as grayscale arrays scaled to
# [0, 1] and stack them into an (N, size, size, 1) tensor.
_test_images = [
    np.array(load_img("test/images/{}.png".format(idx), grayscale = True)) / 255
    for idx in test_img_names
]
x_test = np.array(_test_images).reshape(-1, img_size_target, img_size_target, 1)

In [None]:
# Depth of every test image, in the same order as x_test.
# x_test and the submission both follow test_img_names, so the depths must be
# looked up by id in that order; taking them in depths_df order paired each
# image with another image's depth.
depth_test = depths_df.loc[test_img_names, "z"].values
depth_test = depth_test.astype("float64")
# Normalise with the statistics of the training split.
depth_test -= depth_train_mean
depth_test /= depth_train_std


In [None]:
# Predicted salt probability for every test image.
salt_threshold_test = saltmodel.predict(x=x_test)

In [None]:
# Flip-averaged U-Net predictions for the test images.
test_inputs = {'img': x_test, 'feat': depth_test}
preds_test = predict_result(model, test_inputs, img_size_target)

In [None]:
"""
used for converting the decoded image to rle mask
as required for competition submission
"""
def rle_encode(im):
    '''
    Run-length encode a binary mask.

    im: numpy array, 1 - mask, 0 - background
    Returns the runs as a space-separated string of "start length" pairs,
    with 1-based start positions counted in column-major order.
    '''
    # Pad with background on both sides so every run has a start and an end.
    padded = np.concatenate([[0], im.flatten(order = 'F'), [0]])
    # 1-based positions where the value changes: run starts alternate with
    # the positions just after a run's last pixel.
    boundaries = np.where(padded[1:] != padded[:-1])[0] + 1
    starts, stops = boundaries[::2], boundaries[1::2]
    pairs = np.column_stack((starts, stops - starts))
    return ' '.join(str(value) for value in pairs.ravel())

In [None]:
## Both thresholds were tried: the dynamic one seems to perform a little bit
## better, but there is almost no difference.
#pred_dict = {idx: rle_encode(np.round(downsample(preds_test[i]) > threshold_best )) for i, idx in enumerate(test_img_names) }
pred_dict = {}
for i, idx in enumerate(test_img_names):
    # Per-image threshold, shifted by the predicted salt probability.
    dynamic_threshold = 0.5-0.2*(salt_threshold_test[i]-0.5)
    binary_mask = np.round(downsample(preds_test[i]) > dynamic_threshold)
    pred_dict[idx] = rle_encode(binary_mask)

## Make the submission file & submit it

In [None]:
# One row per test image id, with the run-length encoded mask as the only
# column; the ids are written as the "id" index column of the CSV.
sub = pd.DataFrame.from_dict(pred_dict, orient='index')
sub.columns = ['rle_mask']
sub.index.names = ['id']
sub.to_csv('submission.csv')

In [None]:
!kaggle competitions submit -c tgs-salt-identification-challenge -f "./submission.csv" -m "unet with depth in middle layer, dyn. threshold opt., augmentation + avg in test set"

# Now let me redo the same for a different loss function -- so-called lovasz-loss -- that is used by most of the other competitors

In [None]:
# code download from: https://github.com/bermanmaxim/LovaszSoftmax
def lovasz_grad(gt_sorted):
    """
    Gradient of the Lovasz extension w.r.t. the sorted errors (Alg. 1 in the paper).

    gt_sorted: ground-truth labels, reordered to follow the sorted errors.
    Returns the Jaccard loss of the first prefix followed by the increments of
    the Jaccard loss between consecutive prefixes.
    """
    total_positive = tf.reduce_sum(gt_sorted)
    running_intersection = total_positive - tf.cumsum(gt_sorted)
    running_union = total_positive + tf.cumsum(1. - gt_sorted)
    jaccard = 1. - running_intersection / running_union
    first = jaccard[0:1]
    increments = jaccard[1:] - jaccard[:-1]
    return tf.concat((first, increments), 0)


# --------------------------- BINARY LOSSES ---------------------------

def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss
      logits: [B, H, W] Variable, logits at each pixel (between -inf and +inf)
      labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
      per_image: compute the loss per image instead of per batch
      ignore: void class id
    """
    if not per_image:
        # Whole batch at once: flatten everything into a single vector.
        flat_logits, flat_labels = flatten_binary_scores(logits, labels, ignore)
        return lovasz_hinge_flat(flat_logits, flat_labels)

    def single_image_loss(pair):
        # Each image is flattened and scored separately.
        image_logits, image_labels = pair
        image_logits = tf.expand_dims(image_logits, 0)
        image_labels = tf.expand_dims(image_labels, 0)
        return lovasz_hinge_flat(*flatten_binary_scores(image_logits, image_labels, ignore))

    per_image_losses = tf.map_fn(single_image_loss, (logits, labels), dtype=tf.float32)
    return tf.reduce_mean(per_image_losses)


def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss on flattened inputs
      logits: [P] Variable, logits at each prediction (between -inf and +inf)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
    """

    def loss_for_non_empty():
        float_labels = tf.cast(labels, logits.dtype)
        # Map the labels {0, 1} to signs {-1, +1} and form the hinge errors.
        hinge_errors = 1. - logits * tf.stop_gradient(2. * float_labels - 1.)
        sorted_errors, order = tf.nn.top_k(hinge_errors, k=tf.shape(hinge_errors)[0], name="descending_sort")
        weights = lovasz_grad(tf.gather(float_labels, order))
        return tf.tensordot(tf.nn.relu(sorted_errors), tf.stop_gradient(weights), 1, name="loss_non_void")

    def loss_for_empty():
        # Zero loss that is still connected to logits in the graph.
        return tf.reduce_sum(logits) * 0.

    # deal with the void prediction case (only void pixels)
    is_empty = tf.equal(tf.shape(logits)[0], 0)
    return tf.cond(is_empty, loss_for_empty, loss_for_non_empty, strict=True, name="loss")


def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flatten scores and labels to 1-D (binary case).

    When ignore is given, entries whose label equals ignore are removed from
    both the scores and the labels.
    """
    flat_scores = tf.reshape(scores, (-1,))
    flat_labels = tf.reshape(labels, (-1,))
    if ignore is not None:
        keep = tf.not_equal(flat_labels, ignore)
        kept_scores = tf.boolean_mask(flat_scores, keep, name='valid_scores')
        kept_labels = tf.boolean_mask(flat_labels, keep, name='valid_labels')
        return kept_scores, kept_labels
    return flat_scores, flat_labels

def lovasz_loss(y_true, y_pred):
    """Keras loss: per-image Lovasz hinge loss.

    The trailing channel axis of both tensors is squeezed away.  y_pred is
    used as logits directly, i.e. the model must not end in a sigmoid.
    """
    masks = K.cast(K.squeeze(y_true, -1), 'int32')
    logits = K.cast(K.squeeze(y_pred, -1), 'float32')
    return lovasz_hinge(logits, masks, per_image = True, ignore = None)

In [None]:
# The Lovasz loss works on logits in the range (-inf, +inf), so the metric
# binarises the predictions at 0 instead of 0.5.
def my_iou_metric_2(label, pred):
    """Keras metric: thresholded IoU of logit outputs binarised at 0.

    Defined as my_iou_metric_2 because the later cells compile the model with
    that name and monitor 'val_my_iou_metric_2'; under the old name
    my_iou_metric2 those references failed with a NameError.
    """
    return tf.py_func(get_iou_vector, [label, pred>0], tf.float64)


# Backward-compatible alias for the name the function was originally given.
my_iou_metric2 = my_iou_metric_2

In [None]:
# Remove the last activation layer and use the Lovasz loss.
# The inputs are taken from model.inputs rather than from hard-coded layer
# positions, which silently break whenever the architecture changes.
input_x, input_xx = model.inputs

# The input of the final layer is the tensor just before that layer.
output_layer = model.layers[-1].input
modelCONT = Model(inputs=[input_x, input_xx], outputs=[output_layer])
c = optimizers.Adam(lr = 0.01)

# lovasz_loss needs inputs in the range (-inf, +inf), so the last "sigmoid"
# activation is dropped.  The default threshold for a pixel prediction is then
# 0 instead of 0.5, as in my_iou_metric_2.
modelCONT.compile(loss=lovasz_loss, optimizer=c, metrics=[my_iou_metric_2])

# Summarise the model that was just built (previously the original model).
modelCONT.summary()


## takes very long to train, forget this model

In [None]:
# Continue training with the Lovasz loss.  All callbacks watch the validation
# IoU, which should be maximised.
_monitored = 'val_my_iou_metric_2'
early_stopping = EarlyStopping(monitor=_monitored, mode='max', patience=20, verbose=1)
model_checkpoint = ModelCheckpoint("./unet_bestCONT.model", monitor=_monitored,
                                   mode='max', save_best_only=True, verbose=1)
reduce_lr = ReduceLROnPlateau(monitor=_monitored, mode='max', factor=0.5,
                              patience=5, min_lr=0.0001, verbose=1)
epochs = 20
batch_size = 32

_train_inputs = {'img': x_train, 'feat': depth_train}
_valid_inputs = {'img': x_valid, 'feat': depth_valid}
history = modelCONT.fit(
    _train_inputs,
    y_train,
    validation_data=(_valid_inputs, y_valid),
    epochs=epochs,
    batch_size=batch_size,
    callbacks=[early_stopping, model_checkpoint, reduce_lr],
    verbose=2,
)