In [None]:
import tensorflow as tf
import numpy as np
import h5py
import matplotlib.pyplot as plt 
import seaborn
%matplotlib inline 

In [None]:
# Choose which GPU TensorFlow is allowed to see. Devices are numbered in PCI
# bus order (the same order nvidia-smi uses) and only index 7 is exposed.
# NOTE(review): the original comment said "run by CPU", but this selects GPU 7.
# On a machine with fewer than eight GPUs no GPU is visible, so TensorFlow
# falls back to CPU. Set CUDA_VISIBLE_DEVICES to "-1" to force CPU explicitly.
# These variables only take effect if set before TensorFlow initialises CUDA.
import os
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "7"})

In [None]:
# Folder holding train.h5 / test.h5. Override it with the UNET_DATA_DIR
# environment variable instead of editing a hardcoded local path.
DATA_DIR = os.environ.get('UNET_DATA_DIR', 'C:/Users/AlexZheng/Downloads')
# Opened read-only and left open: the next cell reads the datasets from them.
f_train = h5py.File(os.path.join(DATA_DIR, 'train.h5'), 'r')
f_test = h5py.File(os.path.join(DATA_DIR, 'test.h5'), 'r')

In [None]:
# Copy the HDF5 datasets into in-memory NumPy arrays: the images ('rgb'), the
# label maps ('seg') and, from the training file only, the colour table
# ('color_codes') used later to colourise labels.
rgb_train, seg_train, color_codes = (np.array(f_train[name]) for name in ('rgb', 'seg', 'color_codes'))
rgb_test, seg_test = (np.array(f_test[name]) for name in ('rgb', 'seg'))

In [None]:
from tensorflow import keras
from tensorflow.keras.layers import Input,Conv2D, MaxPooling2D, Dropout, Conv2DTranspose, concatenate, BatchNormalization


# Base channel width (first encoder level; deeper levels use multiples of it)
# and the rate passed to every Dropout layer below.
filters, Dropout_rate = 16, 0.2

def Conv_Block(filters, input):
    """Encoder level of the U-Net.

    Runs two stages of (3x3 'same' Conv2D with ReLU -> BatchNormalization),
    then halves the spatial size with 2x2 max pooling followed by Dropout.

    Parameters
    ----------
    filters : int
        Number of channels of both convolutions.
    input : tensor
        Feature map entering this level.

    Returns
    -------
    (conv, pool)
        `conv` is the full-resolution feature map, used by the decoder as the
        skip connection. `pool` is the downsampled, dropped-out tensor that
        feeds the next level.
    """
    features = input
    for _stage in range(2):
        features = Conv2D(filters=filters, kernel_size=(3,3), padding='same', activation= 'relu')(features)
        features = BatchNormalization()(features)
    downsampled = MaxPooling2D(pool_size=(2,2))(features)
    # training=True keeps this dropout active at inference time as well,
    # presumably for the Monte-Carlo sampling in the epistemic-uncertainty cells.
    downsampled = Dropout(Dropout_rate)(downsampled, training=True)
    return features, downsampled

def Deconv_Block(filters, conv, input):
    """Decoder level of the U-Net: upsample, merge the skip connection, refine.

    Parameters
    ----------
    filters : int
        Channels of every layer in this level. Callers pass the width of the
        matching encoder level (16*filters ... filters).
    conv : tensor
        Skip-connection tensor from the matching Conv_Block (its first return
        value). Its spatial size must be twice that of `input` so the
        concatenation lines up.
    input : tensor
        Output of the previous (coarser) decoder level or of the bottleneck.

    Returns
    -------
    tensor
        Refined feature map at the resolution of `conv`.
    """
    # 2x upsampling. Use `filters`, not `8*filters`: callers already pass the
    # level width, so multiplying again made the decoder 8x wider than the
    # encoder (2048 channels at the deepest level with the global filters=16).
    deconv = Conv2DTranspose(filters=filters, kernel_size=(3,3), padding='same',  strides=(2,2))(input)
    deconv = BatchNormalization()(deconv)
    conca = concatenate([deconv, conv])
    # NOTE(review): this Dropout is active only during training, whereas the one
    # at the end of this block and those in Conv_Block use training=True (always
    # on). Confirm whether it should also be always-on for MC-dropout sampling.
    conca = Dropout(Dropout_rate)(conca)
    conca = Conv2D(filters=filters, kernel_size=(3,3), padding='same', activation= 'relu')(conca)
    conca = BatchNormalization()(conca)
    conca = Conv2D(filters=filters, kernel_size=(3,3), padding='same', activation= 'relu')(conca)
    conca = BatchNormalization()(conca)
    conca = Dropout(Dropout_rate)(conca, training=True)

    return conca

# Input: 128x256 RGB images
input = Input((128,256,3))

# Encoder: five Conv_Blocks whose width doubles at every level
# (filters, 2*filters, ..., 16*filters). Each returns the pre-pool features,
# kept for the skip connections, and the pooled tensor fed to the next level.
encoder_features, encoder_pooled = [], []
level_input = input
for width_factor in (1, 2, 4, 8, 16):
    skip, level_input = Conv_Block(width_factor*filters, level_input)
    encoder_features.append(skip)
    encoder_pooled.append(level_input)
conv1, conv2, conv3, conv4, conv5 = encoder_features
pool1, pool2, pool3, pool4, pool5 = encoder_pooled

# Bottleneck: two 3x3 conv + BN at 32*filters channels on the 4x8 map left
# after five 2x2 poolings, then always-on dropout.
conv = pool5
for conv_idx in range(2):
    conv = Conv2D(filters=32*filters, kernel_size=(3,3), padding='same', activation= 'relu')(conv)
    conv = BatchNormalization()(conv)
conv = Dropout(Dropout_rate)(conv, training=True)

# Decoder: mirror the encoder, pairing each level with its skip tensor from
# the deepest (conv5) to the shallowest (conv1).
decoder_outputs = []
decoded = conv
for width_factor, skip in zip((16, 8, 4, 2, 1), reversed(encoder_features)):
    decoded = Deconv_Block(width_factor*filters, skip, decoded)
    decoder_outputs.append(decoded)
conca5, conca4, conca3, conca2, conca1 = decoder_outputs

# Output: per-pixel softmax over 34 classes
output = Conv2D(filters=34, kernel_size=(1,1), activation='softmax')(conca1)

model1 = tf.keras.Model(inputs = input, outputs = output)
model1.summary()

# Unet (hand-written variant)
# NOTE(review): this section rebinds `input`, `conv1`...`conca1`, `output` and
# `model1`. The block-based model built above is therefore discarded, and this
# model is the one compiled and trained below. Its Dropout layers are called
# without training=True, so they are inactive at prediction time. Its decoder
# has no BatchNormalization.

## Input 
input = Input((128,256,3))

## Conv part 
# BatchNormalization is a layer: instantiate it first, then call it on the
# tensor. `BatchNormalization(conv1)` would pass the tensor as the layer's
# `axis` argument and return a layer object, which the next Conv2D cannot use.
conv1 = Conv2D(filters=filters, kernel_size=(3,3), padding='same', activation= 'relu')(input)
conv1 = BatchNormalization()(conv1)
conv1 = Conv2D(filters=filters, kernel_size=(3,3), padding='same', activation= 'relu')(conv1)
conv1 = BatchNormalization()(conv1)
pool1 = MaxPooling2D(pool_size=(2,2))(conv1)
pool1 = Dropout(Dropout_rate)(pool1)

conv2 = Conv2D(filters=2*filters, kernel_size=(3,3), padding='same', activation= 'relu')(pool1)
conv2 = BatchNormalization()(conv2)
conv2 = Conv2D(filters=2*filters, kernel_size=(3,3), padding='same', activation= 'relu')(conv2)
conv2 = BatchNormalization()(conv2)
pool2 = MaxPooling2D(pool_size=(2,2))(conv2)
pool2 = Dropout(Dropout_rate)(pool2)

conv3 = Conv2D(filters=4*filters, kernel_size=(3,3), padding='same', activation= 'relu')(pool2)
conv3 = BatchNormalization()(conv3)
conv3 = Conv2D(filters=4*filters, kernel_size=(3,3), padding='same', activation= 'relu')(conv3)
conv3 = BatchNormalization()(conv3)
pool3 = MaxPooling2D(pool_size=(2,2))(conv3)
pool3 = Dropout(Dropout_rate)(pool3)

conv4 = Conv2D(filters=8*filters, kernel_size=(3,3), padding='same', activation= 'relu')(pool3)
conv4 = BatchNormalization()(conv4)
conv4 = Conv2D(filters=8*filters, kernel_size=(3,3), padding='same', activation= 'relu')(conv4)
# Was missing: every other encoder convolution is followed by BatchNormalization.
conv4 = BatchNormalization()(conv4)
pool4 = MaxPooling2D(pool_size=(2,2))(conv4)
pool4 = Dropout(Dropout_rate)(pool4)

conv5 = Conv2D(filters=16*filters, kernel_size=(3,3), padding='same', activation= 'relu')(pool4)
conv5 = BatchNormalization()(conv5)
conv5 = Conv2D(filters=16*filters, kernel_size=(3,3), padding='same', activation= 'relu')(conv5)
conv5 = BatchNormalization()(conv5)
pool5 = MaxPooling2D(pool_size=(2,2))(conv5)
pool5 = Dropout(Dropout_rate)(pool5)

## Middle part 
conv = Conv2D(filters=32*filters, kernel_size=(3,3), padding='same', activation= 'relu')(pool5)
conv = BatchNormalization()(conv)
conv = Conv2D(filters=32*filters, kernel_size=(3,3), padding='same', activation= 'relu')(conv)
conv = BatchNormalization()(conv)

## Deconv part 
# Each level: 2x transposed-conv upsampling, concatenate with the matching
# encoder tensor (skip connection), dropout, then two 3x3 convolutions.
deconv5 = Conv2DTranspose(filters=16*filters, kernel_size=(3,3), padding='same', strides=(2,2))(conv)
conca5 = concatenate([deconv5, conv5])
conca5 = Dropout(Dropout_rate)(conca5)
conca5 = Conv2D(filters=16*filters, kernel_size=(3,3), padding='same', activation= 'relu')(conca5)
conca5 = Conv2D(filters=16*filters, kernel_size=(3,3), padding='same', activation= 'relu')(conca5)

deconv4 = Conv2DTranspose(filters=8*filters, kernel_size=(3,3), padding='same',  strides=(2,2))(conca5)
conca4 = concatenate([deconv4, conv4])
conca4 = Dropout(Dropout_rate)(conca4)
conca4 = Conv2D(filters=8*filters, kernel_size=(3,3), padding='same', activation= 'relu')(conca4)
conca4 = Conv2D(filters=8*filters, kernel_size=(3,3), padding='same', activation= 'relu')(conca4)

deconv3 = Conv2DTranspose(filters=4*filters, kernel_size=(3,3), padding='same',  strides=(2,2))(conca4)
conca3 = concatenate([deconv3, conv3])
conca3 = Dropout(Dropout_rate)(conca3)
conca3 = Conv2D(filters=4*filters, kernel_size=(3,3), padding='same', activation= 'relu')(conca3)
conca3 = Conv2D(filters=4*filters, kernel_size=(3,3), padding='same', activation= 'relu')(conca3)

deconv2 = Conv2DTranspose(filters=2*filters, kernel_size=(3,3), padding='same',  strides=(2,2))(conca3)
conca2 = concatenate([deconv2, conv2])
conca2 = Dropout(Dropout_rate)(conca2)
conca2 = Conv2D(filters=2*filters, kernel_size=(3,3), padding='same', activation= 'relu')(conca2)
conca2 = Conv2D(filters=2*filters, kernel_size=(3,3), padding='same', activation= 'relu')(conca2)

deconv1 = Conv2DTranspose(filters=filters, kernel_size=(3,3), padding='same',  strides=(2,2))(conca2)
conca1 = concatenate([deconv1, conv1])
conca1 = Dropout(Dropout_rate)(conca1)
conca1 = Conv2D(filters=filters, kernel_size=(3,3), padding='same', activation= 'relu')(conca1)
conca1 = Conv2D(filters=filters, kernel_size=(3,3), padding='same', activation= 'relu')(conca1)

##  Output
# Per-pixel softmax over 34 classes
output = Conv2D(filters=34, kernel_size=(1,1), activation='softmax')(conca1)

model1 = tf.keras.Model(inputs = input, outputs = output)
model1.summary()


In [None]:
# Render the architecture diagram of model1 (Keras needs pydot and graphviz installed for this).
keras.utils.plot_model(model1)

In [None]:
from tensorflow.keras import callbacks  

# Adam with an explicit learning rate. Passing it to the constructor replaces
# the former post-compile `model1.optimizer.lr = 1E-3`, which relies on the
# legacy `lr` alias of `learning_rate`. sparse_categorical_crossentropy expects
# integer class ids (one per pixel) as targets.
model1.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1E-3),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# TensorBoard logs, including weight histograms every epoch.
log_dir = "logs/fit/Unet"
tensorboard_callback = callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

# To train on CPU, wrap the fit call in `with tf.device("/CPU:0"):`.
# `fit` returns the History object. It is bound to `history` and also to
# `model1.history`, which the earlier code assigned.
# NOTE(review): the test split doubles as validation data here; avoid using
# these validation scores for model selection if test results are reported.
history = model1.history = model1.fit(
    x = rgb_train, 
    y = seg_train,
    batch_size=16,                      # smaller batches need less (GPU) memory
    epochs=50,                          # fewer epochs shorten the training time
    validation_data=(rgb_test, seg_test),
    callbacks=[tensorboard_callback]
)

In [None]:
# Save the trained model in HDF5 format. Create the target folder first: the
# HDF5 writer may fail if 'models/' does not exist (depends on Keras version).
os.makedirs('models', exist_ok=True)
model1.save('models/model1.h5')

In [None]:
# Load a previously trained model from disk, replacing the in-memory model1.
# NOTE(review): this reads models/model4.h5, not the models/model1.h5 written
# by the save cell; confirm which checkpoint is intended.
trained_model_path = 'models/model4.h5'
model1 = tf.keras.models.load_model(trained_model_path)

In [None]:
# Class scores for the first 20 test images.
# NOTE(review): layers called with Dropout(..., training=True) stay active at
# prediction time, so these outputs may differ between runs.
n_preview = 20
predictions = model1.predict(rgb_test[:n_preview])

In [None]:
def output2img(data, codes=None):
    """Turn per-pixel class scores into colour images.

    Parameters
    ----------
    data : np.ndarray
        Scores with the class axis last, e.g. (N, H, W, num_classes) as returned
        by model1.predict. The highest-scoring class is taken per pixel (first
        one on ties).
    codes : array-like, optional
        Colour table indexed by class id, one row per class. Defaults to the
        global `color_codes` loaded from the training file.

    Returns
    -------
    np.ndarray
        float64 array of shape data.shape[:-1] + codes.shape[1:], i.e.
        (N, H, W, 3) for an RGB colour table.
    """
    if codes is None:
        codes = color_codes
    labels = np.argmax(data, axis=-1)
    # One vectorised lookup instead of a per-pixel Python loop. Cast to float64
    # to match the np.zeros buffer the loop version filled.
    return np.asarray(codes)[labels].astype(np.float64)

def seg2img(data, codes=None):
    """Turn integer label maps into colour images.

    Parameters
    ----------
    data : np.ndarray
        Integer class ids, e.g. (N, H, W) slices of seg_test.
    codes : array-like, optional
        Colour table indexed by class id, one row per class. Defaults to the
        global `color_codes` loaded from the training file.

    Returns
    -------
    np.ndarray
        float64 array of shape data.shape + codes.shape[1:], i.e. (N, H, W, 3)
        for an RGB colour table.
    """
    if codes is None:
        codes = color_codes
    # Vectorised lookup of every label's colour, replacing the per-pixel loop.
    # float64 matches the np.zeros buffer the loop version filled.
    return np.asarray(codes)[np.asarray(data)].astype(np.float64)

# Colour images for the same first 20 test samples that were predicted above.
preview = slice(0, 20)
imgs_pred = output2img(predictions)
imgs_true = seg2img(seg_test[preview])
imgs_orig = rgb_test[preview]

In [None]:
# plot the images 
plt.figure(figsize=(12, 24))
num = 3
k = 9
plt.title("Image Comparations")

for i in range(num):
    plt.subplot(num, 3, i*num+1)
    plt.title("Original Image "+str(i+4))
    plt.imshow(tf.keras.preprocessing.image.array_to_img(imgs_orig[i+k]))
    plt.subplot(num, 3, i*num+2)
    plt.title("Image True Seg "+str(i+4))
    plt.imshow(tf.keras.preprocessing.image.array_to_img(imgs_true[i+k]))
    plt.subplot(num, 3, i*num+3)
    plt.title("Model Output "+str(i+4))
    plt.imshow(tf.keras.preprocessing.image.array_to_img(imgs_pred[i+k]))

In [None]:
# aleatoric and epistemic uncertainties

def alea_uncertainties(predict, true, plot=True):
    """Per-pixel aleatoric-uncertainty map: negative log-probability of the true class.

    Parameters
    ----------
    predict : np.ndarray
        Class probabilities for one image, shape (H, W, num_classes).
    true : np.ndarray
        Integer ground-truth labels for the same image, shape (H, W).
    plot : bool, default True
        Draw the map with seaborn.heatmap on the current axes.

    Returns
    -------
    np.ndarray
        float64 array of shape (H, W). Pixels whose true-class probability is 0
        become inf (NumPy emits a divide-by-zero warning).
    """
    labels = np.asarray(true)[..., np.newaxis]
    # Probability assigned to the true class at every pixel; uses the input's
    # own (H, W) instead of a hardcoded 128x256.
    p_true = np.take_along_axis(np.asarray(predict), labels, axis=-1)[..., 0]
    img_alea = (-np.log(p_true)).astype(np.float64)
    if plot:
        seaborn.heatmap(img_alea)
    return img_alea
    
def epis_uncertainties(predict, plot=True):
    """Per-pixel epistemic-uncertainty map from repeated stochastic predictions.

    For every pixel, takes the argmax class of each sample and returns the
    variance of those class indices across samples.

    NOTE(review): class ids are categorical, so their variance depends on the
    arbitrary numbering of classes. Predictive entropy or the disagreement rate
    would be label-order independent. The current metric is kept unchanged.

    Parameters
    ----------
    predict : np.ndarray
        Stack of predictions for one image, shape (num_samples, H, W, num_classes).
    plot : bool, default True
        Draw the map with seaborn.heatmap on the current axes.

    Returns
    -------
    np.ndarray
        float64 array of shape (H, W); 0 where all samples agree.
    """
    labels = np.argmax(predict, axis=-1)            # (num_samples, H, W)
    img_epis = np.var(labels, axis=0).astype(np.float64)
    if plot:
        seaborn.heatmap(img_epis)
    return img_epis
            
                

In [None]:
# Aleatoric and epistemic uncertainty maps side by side for one test image.
sample_idx = 10   # test image to analyse; must be < 20 (length of `predictions`)
n_mc = 20         # number of stochastic forward passes for the epistemic estimate

# The epistemic map needs several predictions of the same image that differ only
# through dropout. `predict_list5` was never defined earlier in the notebook, so
# it is built here: the image is repeated `n_mc` times in one batch, and each
# copy gets its own dropout mask in layers called with training=True.
# NOTE(review): if the loaded model has no always-on dropout, every pass is
# identical and the epistemic map is zero everywhere.
predict_list5 = model1.predict(np.repeat(rgb_test[sample_idx:sample_idx+1], n_mc, axis=0))

plt.figure(figsize=(12, 3))
plt.suptitle("Uncertainties")
plt.subplot(1, 2, 1)
plt.title("aleatoric uncertainty " + str(sample_idx))
alea_uncertainties(predictions[sample_idx], seg_test[sample_idx])
plt.subplot(1, 2, 2)
plt.title("epistemic uncertainty " + str(sample_idx))
epis_uncertainties(predict_list5);