Import Libraries


In [1]:
import UNet
import tensorflow as tf

from PIL import Image
import numpy as np
import glob

Import and Format Data


In [2]:
def rgbToOneHot(img, colorDict):
    numClasses = len(colorDict)
    shape = img.shape[:2]+(numClasses,)
    arr = np.zeros( shape, dtype=np.int8 )

    for i, cls in enumerate(colorDict):
        arr[:,:,i] = np.all(img.reshape( (-1,3) ) == cls, axis=1).reshape(shape[:2])

    return arr

In [3]:
fileList = glob.glob("data/512x512/Train/images_512/*.jpeg")
XTrain = np.array([np.array(Image.open(fname)) for fname in fileList])

fileList = glob.glob("data/512x512/Train/mask_512/*.png")
YTrain = np.array([np.array(Image.open(fname)) for fname in fileList])

fileList = glob.glob("data/512x512/Test/images_512/*.jpeg")
XTest = np.array([np.array(Image.open(fname)) for fname in fileList])

fileList = glob.glob("data/512x512/Test/mask_512/*.png")
YTest = np.array([np.array(Image.open(fname)) for fname in fileList])

print(XTrain.shape)
print(YTrain.shape)
print(XTest.shape)
print(YTest.shape)

colors = {(128, 0, 0): 0,
          (0, 0, 0): 1,
          (0, 128, 0): 2,
          (128, 128, 0): 3}

YTrain = np.array([rgbToOneHot(YTrain[i], colors) for i in range(YTrain.shape[0])])
YTest = np.array([rgbToOneHot(YTest[i], colors) for i in range(YTest.shape[0])])

print(YTrain.shape)
print(YTest.shape)

(44, 512, 512, 3)
(44, 512, 512, 3)
(44, 512, 512, 3)
(44, 512, 512, 3)
(44, 512, 512, 4)
(44, 512, 512, 4)


Create Model + Tools


In [4]:
# Create model
model = UNet.UNet((512, 512, 3), 4, [32, 64, 128, 256, 512], 3, 2, 4)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 sequential (Sequential)        (None, 512, 512, 32  10400       ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 max_pooling2d (MaxPooling2D)   (None, 256, 256, 32  0           ['sequential[0][0]']             
                                )                                                             

In [5]:
# Create checkpoint callback
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    "model.h5", verbose=1, save_best_only=True
)

# Create callbacks
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor="accuracy", patience=2, verbose=1),
    tf.keras.callbacks.TensorBoard(log_dir="logs"),
]

In [6]:
model.fit(XTrain, YTrain, batch_size=4, epochs=10, verbose=2, callbacks=callbacks)

Epoch 1/10
11/11 - 91s - loss: 0.6606 - accuracy: 0.4045 - 91s/epoch - 8s/step
Epoch 2/10
11/11 - 86s - loss: 0.5301 - accuracy: 0.7007 - 86s/epoch - 8s/step
Epoch 3/10
11/11 - 86s - loss: 0.4741 - accuracy: 0.7703 - 86s/epoch - 8s/step
Epoch 4/10
11/11 - 89s - loss: 0.4366 - accuracy: 0.7790 - 89s/epoch - 8s/step
Epoch 5/10
11/11 - 92s - loss: 0.4039 - accuracy: 0.7779 - 92s/epoch - 8s/step
Epoch 6/10
11/11 - 89s - loss: 0.3796 - accuracy: 0.7831 - 89s/epoch - 8s/step
Epoch 7/10
11/11 - 88s - loss: 0.3558 - accuracy: 0.7858 - 88s/epoch - 8s/step
Epoch 8/10
11/11 - 89s - loss: 0.3324 - accuracy: 0.7937 - 89s/epoch - 8s/step
Epoch 9/10
11/11 - 88s - loss: 0.3203 - accuracy: 0.7893 - 88s/epoch - 8s/step
Epoch 10/10
11/11 - 89s - loss: 0.3055 - accuracy: 0.7978 - 89s/epoch - 8s/step


<keras.callbacks.History at 0x1e48259e3d0>

In [9]:
# Evaluate model
model.evaluate(XTest, YTest, verbose=2)

2/2 - 18s - loss: 0.4585 - accuracy: 0.7017 - 18s/epoch - 9s/step


[0.4585326015949249, 0.701694905757904]

In [8]:
# Save the model
model.save("model.h5")