In [None]:
# define functions and callback to display sample prediction
# adapted from https://www.tensorflow.org/tutorials/images/segmentation

def display(display_list):
    """Plot the given images side by side in one matplotlib row.

    Panels are titled, in order, 'Input Image', 'True Mask' and
    'Predicted Mask'. Each item is converted with Keras' ``array_to_img``
    before plotting, and axes are hidden.

    NOTE(review): this name shadows IPython's built-in ``display`` in the
    notebook namespace; passing more than three items raises IndexError
    because only three titles exist.
    """
    panel_titles = ['Input Image', 'True Mask', 'Predicted Mask']
    n_panels = len(display_list)
    plt.figure(figsize=(18, 8))
    for idx, item in enumerate(display_list):
        plt.subplot(1, n_panels, idx + 1)
        plt.title(panel_titles[idx])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(item))
        plt.axis('off')
    plt.show()
    
def create_mask(pred_mask):
    """Convert predicted class scores into the label map of the first batch item.

    ``pred_mask`` is expected to be ``model.predict`` output, i.e. a batch of
    per-pixel class scores with classes on the last axis (assumption — TODO
    confirm shape). The argmax over that axis yields one class id per pixel,
    a trailing length-1 channel axis is appended, and element 0 is returned.
    """
    label_map = tf.math.argmax(pred_mask, axis=-1)
    return tf.expand_dims(label_map, axis=-1)[0]

def show_predictions(dataset=None, num=1, predictor=None):
    """Plot input image, true mask and predicted mask for sample data.

    Args:
        dataset: optional dataset yielding ``(image, mask)`` batches; the first
            example of each of ``num`` batches is shown. When ``None``, the
            notebook-level ``sample_image`` / ``sample_mask`` pair is shown.
        num: number of batches to take from ``dataset``.
        predictor: Keras model used for inference. Defaults to the
            notebook-level ``model`` so existing calls keep working; passing it
            explicitly (as ``DisplayCallback`` does) avoids depending on a
            global name that may be undefined or point at a different model.
    """
    if predictor is None:
        predictor = model
    if dataset is not None:
        for image, mask in dataset.take(num):
            pred_mask = create_mask(predictor.predict(image))
            # colourise both masks exactly like the single-sample branch below
            display([image[0], seg2rgb(mask[0], NUM_CLASSES), seg2rgb(pred_mask, NUM_CLASSES)])
    else:
        pred_mask = create_mask(predictor.predict(sample_image[tf.newaxis, ...]))
        display([sample_image, seg2rgb(sample_mask, NUM_CLASSES), seg2rgb(pred_mask, NUM_CLASSES)])

class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that redraws a sample prediction after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        clear_output(wait=True)
        # predict with the model actually being trained (Keras sets self.model),
        # not whatever the global name `model` happens to hold
        show_predictions(predictor=self.model)
        print('\nSample Prediction after epoch {}\n'.format(epoch + 1))

def build_unet(n_classes, input_shape=None, base_filters=64):
    """Build a four-level U-Net segmentation model with the Keras Functional API.

    Args:
        n_classes: number of output channels. The final 1x1 convolution uses a
            softmax activation, so each pixel gets a distribution over classes.
        input_shape: shape of a single input image (no batch dimension).
            Defaults to the notebook-level ``INPUT_SHAPE`` so existing
            ``build_unet(n_classes)`` calls behave exactly as before.
        base_filters: filter count of the first encoder level; each deeper
            level doubles it. The default of 64 reproduces the original
            64-128-256-512 encoder with a 1024-filter bottleneck.

    Returns:
        An uncompiled ``tf.keras.Model`` named "U-Net".
    """
    if input_shape is None:
        input_shape = INPUT_SHAPE

    inputs = Input(shape=input_shape)

    # encoder: contracting path. Each level returns (features reused as the
    # skip connection in the decoder, output fed to the next level).
    f1, p1 = downsample_block(inputs, base_filters)
    f2, p2 = downsample_block(p1, base_filters * 2)
    f3, p3 = downsample_block(p2, base_filters * 4)
    f4, p4 = downsample_block(p3, base_filters * 8)

    # bottleneck
    bottleneck = double_conv_block(p4, base_filters * 16)

    # decoder: expanding path; each level is joined with the matching
    # encoder features (deepest first)
    u6 = upsample_block(bottleneck, f4, base_filters * 8)
    u7 = upsample_block(u6, f3, base_filters * 4)
    u8 = upsample_block(u7, f2, base_filters * 2)
    u9 = upsample_block(u8, f1, base_filters)

    # per-pixel class probabilities
    outputs = Conv2D(n_classes, 1, padding="same", activation="softmax")(u9)

    return tf.keras.Model(inputs, outputs, name="U-Net")

In [1]:
# Train the segmentation model; DisplayCallback redraws a sample prediction
# at the end of every epoch.
# NOTE(review): `model` must be built and compiled in an earlier cell — the
# saved output of this cell is a NameError because it was not defined.
NUM_EPOCHS = 20

model_history = model.fit(
    X_tr,
    y_tr,
    batch_size=BATCH_SIZE,
    epochs=NUM_EPOCHS,
    validation_data=(X_te, y_te),
    callbacks=[DisplayCallback()],
    verbose=1,
)

NameError: name 'model' is not defined

### sklearn mIoU metrics
The sklearn-based metrics below call `.numpy()` on tensors, so they require `run_eagerly=True` in `model.compile()`, which makes training roughly 3x slower.

In [1]:
# define custom loss and metrics

# use miou/jaccard score for multiclass segmentation;
# our implementation of binary counts as multiclass
# because it keeps two separate channels

# tensorflow's implementation is here: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/MeanIoU
# the Keras docs also show how to compute an sklearn metric inside a callback, but that approach is more involved:
# https://keras.io/examples/keras_recipes/sklearn_metric_callbacks/

from sklearn.metrics import jaccard_score

def _miou_from_scores(y_true, y_pred, average):
    """Mean IoU between integer labels and the argmax of per-class scores.

    Shared implementation of ``raw_miou`` and ``weighted_miou``.

    Args:
        y_true: ground-truth class ids, any shape; flattened before scoring.
        y_pred: per-class scores whose last axis indexes the classes; reduced
            with argmax over that axis, then flattened.
        average: forwarded to ``sklearn.metrics.jaccard_score``.

    Both arguments are converted with ``np.asarray``, so eager tensors and
    plain numpy arrays (e.g. ``model.predict`` output) are both accepted.
    Symbolic graph-mode tensors cannot be converted, which is why training
    with these metrics needs ``run_eagerly=True``.
    """
    labels_true = np.asarray(y_true).ravel()
    labels_pred = np.argmax(np.asarray(y_pred), axis=-1).ravel()
    return jaccard_score(y_true=labels_true, y_pred=labels_pred, average=average)

def raw_miou(y_true, y_pred):
    """Returns raw per-class average IoU (sklearn ``average='macro'``).

    Every class counts equally, regardless of how many pixels it covers.
    """
    return _miou_from_scores(y_true, y_pred, average='macro')

def weighted_miou(y_true, y_pred):
    """Returns per-class IoU weighted by each class's support in ``y_true``
    (sklearn ``average='weighted'``), accounting for class imbalance.
    """
    return _miou_from_scores(y_true, y_pred, average='weighted')

In [None]:
# same as above but with tensorflow argmax and flatten

def raw_miou(y_true, y_pred):
    """Unweighted (macro) mean IoU via sklearn's ``jaccard_score``.

    Argmax and flattening are done with TensorFlow ops before converting to
    numpy; the ``.numpy()`` calls mean this only works on eager tensors.
    NOTE: redefining this name replaces any earlier ``raw_miou`` in the kernel.
    """
    flat_true = tf.reshape(y_true, [-1]).numpy()
    predicted_labels = tf.math.argmax(y_pred, axis=-1)
    flat_pred = tf.reshape(predicted_labels, [-1]).numpy()
    return jaccard_score(y_true=flat_true, y_pred=flat_pred, average='macro')

def weighted_miou(y_true, y_pred):
    """Support-weighted mean IoU via sklearn's ``jaccard_score``.

    Each class's IoU is weighted by its frequency in ``y_true`` to account
    for class imbalance. Like ``raw_miou`` here, it needs eager tensors.
    NOTE: redefining this name replaces any earlier ``weighted_miou``.
    """
    flat_true = tf.reshape(y_true, [-1]).numpy()
    predicted_labels = tf.math.argmax(y_pred, axis=-1)
    flat_pred = tf.reshape(predicted_labels, [-1]).numpy()
    return jaccard_score(y_true=flat_true, y_pred=flat_pred, average='weighted')