In [1]:
import os
import numpy as np
import pandas as pd
import re
import matplotlib.pyplot as plt
import cv2
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

SHAPE = 512  # side length in pixels: images and masks are resized to SHAPE x SHAPE, and it is the model's input height/width


In [2]:
## Containers for loaded samples: parallel lists of images and their masks.
## framObjValidation is declared here for a future train/validation split.
framObjTrain = {'img': [], 'mask': []}
framObjValidation = {'img': [], 'mask': []}

## defining data Loader function
def LoadData( frameObj = None, imgPath = None, maskPath = None, shape = SHAPE): ### !!! SPLIT DATASET
    """Load every image in `imgPath` together with its mask from `maskPath`.

    The mask for ``<name>.jpg`` is looked up as ``<name>.png`` in `maskPath`;
    file names that do not end in ``.jpg`` are looked up unchanged. Images read
    as integers are scaled to [0, 1], and both image and mask are resized to
    ``shape x shape``. Pairs that fail to load or resize are skipped with a
    warning rather than dropped silently.

    Parameters
    ----------
    frameObj : dict or None
        Dict with list values under 'img' and 'mask'; loaded arrays are
        appended to these lists in place. If None, a fresh
        ``{'img': [], 'mask': []}`` is created.
    imgPath : str
        Directory containing the input images.
    maskPath : str
        Directory containing the masks.
    shape : int
        Target side length in pixels.

    Returns
    -------
    dict
        `frameObj` (or the newly created dict), extended with the loaded pairs.
    """
    import warnings

    if frameObj is None:
        frameObj = {'img': [], 'mask': []}

    # os.listdir returns entries in arbitrary order; sorting keeps sample
    # indices reproducible across runs and machines.
    for imgName in sorted(os.listdir(imgPath)):
        # raw string + `$` anchor: only the trailing extension is replaced
        maskName = re.sub(r'\.jpg$', '.png', imgName)
        try:
            ## read an image and the corresponding label
            img = plt.imread(os.path.join(imgPath, imgName))
            mask = plt.imread(os.path.join(maskPath, maskName))

            ## normalize image color; plt.imread gives integers for JPEGs but
            ## floats already in [0, 1] for PNGs, so only integers are rescaled
            if np.issubdtype(img.dtype, np.integer):
                img = img / 255.0

            ## resize image dimension to shape x shape
            img = cv2.resize(img, (shape, shape))
            mask = cv2.resize(mask, (shape, shape))
        except Exception as exc:
            # best-effort loading: skip the broken pair, but say so
            warnings.warn(f'LoadData: skipping {imgName!r}: {exc}')
            continue

        frameObj['img'].append(img)
        frameObj['mask'].append(mask)

    return frameObj

In [None]:
## Load the training pairs into a FRESH dict: LoadData appends to the dict it is
## given, so passing the existing framObjTrain would duplicate the whole
## dataset every time this cell is re-run.
framObjTrain = LoadData({'img' : [], 'mask' : []}, ### !!! IMPLEMENT CHANGED METHOD
                        imgPath = 'dataset/images',
                        maskPath = 'dataset/new_labels',
                        shape = SHAPE)

In [None]:
## displaying data loaded by our function
import random

assert framObjTrain['img'], 'no training samples were loaded'

## randrange(len(...)) is always a valid index and can reach every sample
## (randint(0, 100) failed with fewer than 101 samples and ignored the rest)
n = random.randrange(len(framObjTrain['img']))

plt.subplot(1,2,1)
plt.imshow(framObjTrain['img'][n])
plt.title('image')
plt.subplot(1,2,2)
plt.imshow(framObjTrain['mask'][n])
plt.title('mask')
plt.show()

In [None]:
## defining our CNN for encoding and decoding
## NOTE: despite the variable name this is a plain convolutional
## encoder-decoder (no attention/transformer layers).
## Spatial size on the way down: SHAPE -> /2 (pool) -> /4 (stride 2)
## -> /8 (stride 2) -> /16 (pool); four 2x UpSampling layers restore SHAPE,
## so SHAPE should be divisible by 16.

myTransformer = tf.keras.models.Sequential([ ### !!! TRY DIFFERENT ARCHITECTURE
    ## defining encoder
    tf.keras.layers.Input(shape= (SHAPE, SHAPE, 3)),
    tf.keras.layers.Conv2D(filters = 16, kernel_size = (3,3), activation = 'relu', padding = 'same'),
    tf.keras.layers.MaxPool2D(pool_size = (2, 2)),

    # 'same' (was 'valid'): halves the size exactly instead of producing an odd
    # map that drops the last row/column; later shapes and parameter counts match
    tf.keras.layers.Conv2D(filters = 32, kernel_size = (3,3), strides = (2,2), activation = 'relu', padding = 'same'),
    tf.keras.layers.Conv2D(filters = 64, kernel_size = (3,3), strides = (2,2), activation = 'relu', padding = 'same'),
    tf.keras.layers.MaxPool2D(pool_size = (2, 2)),

    tf.keras.layers.Conv2D(filters = 64, kernel_size = (3,3), activation = 'relu', padding = 'same'),
    tf.keras.layers.Conv2D(filters = 128, kernel_size = (3,3), activation = 'relu', padding = 'same'),
    tf.keras.layers.Conv2D(filters = 128, kernel_size = (3,3), activation = 'relu', padding = 'same'),
    tf.keras.layers.Conv2D(filters = 256, kernel_size = (3,3), activation = 'relu', padding = 'same'),
    tf.keras.layers.Conv2D(filters = 512, kernel_size = (3,3), activation = 'relu', padding = 'same'),

    ## defining decoder path
    tf.keras.layers.UpSampling2D(size = (2,2)),
    tf.keras.layers.Conv2D(filters = 256, kernel_size = (3,3), activation = 'relu', padding = 'same'),
    tf.keras.layers.Conv2D(filters = 128, kernel_size = (3,3), activation = 'relu', padding = 'same'),
    tf.keras.layers.Conv2D(filters = 128, kernel_size = (3,3), activation = 'relu', padding = 'same'),
    tf.keras.layers.Conv2D(filters = 128, kernel_size = (3,3), activation = 'relu', padding = 'same'),

    tf.keras.layers.UpSampling2D(size = (2,2)),
    tf.keras.layers.Conv2D(filters = 64, kernel_size = (3,3), activation = 'relu', padding = 'same'),
    tf.keras.layers.UpSampling2D(size = (2,2)),
    tf.keras.layers.Conv2D(filters = 32, kernel_size = (3,3), activation = 'relu', padding = 'same'),
    tf.keras.layers.UpSampling2D(size = (2,2)),
    tf.keras.layers.Conv2D(filters = 16, kernel_size = (3,3), activation = 'relu', padding = 'same'),
    # sigmoid output (was relu): the target masks are PNGs read with plt.imread,
    # i.e. values in [0, 1]; sigmoid keeps predictions in that range and avoids
    # a dead (zero-gradient) output channel
    tf.keras.layers.Conv2D(filters = 3, kernel_size = (3,3), activation = 'sigmoid', padding = 'same'),
])

In [None]:
## print a per-layer summary of the model (layer output shapes and parameter counts)
myTransformer.summary()

In [None]:
## Compile: Adam optimizer, mean absolute error between predicted and true masks.
## NOTE(review): 'acc' is a classification metric; for this pixel-regression
## output it may not measure what its name suggests. Confirm which accuracy
## Keras resolves it to. The history key 'acc' is read back in the plotting cell.
myTransformer.compile(
    loss = 'mean_absolute_error', ### !!! TRY DIFFERENT ALGORITHM
    optimizer = tf.keras.optimizers.Adam(learning_rate = 5e-5), ### !!! TRY DIFFERENT ALGORITHM
    metrics = ['acc'],
)

In [None]:
## Training data
## The per-sample lists are stacked straight into float32 arrays. LoadData's
## `/255.0` yields float64 images for JPEG inputs, but the model computes in
## float32 anyway. Converting here avoids building a float64 copy of the whole
## dataset before training starts.
retVal = myTransformer.fit(np.asarray(framObjTrain['img'], dtype = np.float32),
                           np.asarray(framObjTrain['mask'], dtype = np.float32),
                           batch_size=16, ### !!! TRY DIFFERENT BATCH SIZE
                           ### !!! ADD VALIDATION DATASET
                           epochs = 25 ### !!! CHANGE BACK TO 100
                          )

In [None]:
## Training curves per epoch. Loss (MAE) and the 'acc' metric share one y-axis.
fig, ax = plt.subplots()
ax.plot(retVal.history['loss'], label = 'training_loss')
ax.plot(retVal.history['acc'], label = 'training_accuracy')
ax.set(title = 'Training history', xlabel = 'epoch', ylabel = 'value')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
def predict16 (valMap, model, shape = SHAPE, count = 16):
    """Run `model` on the first `count` samples of `valMap`.

    Parameters
    ----------
    valMap : dict
        Dict with 'img' and 'mask' sequences (as produced by LoadData).
    model :
        Object with a Keras-style ``predict`` accepting a batch of images.
    shape : int
        Unused; kept so existing calls keep working.
    count : int
        Number of leading samples to predict on (default 16).

    Returns
    -------
    tuple
        ``(predictions, imgProc, mask)``: the model outputs, the input image
        batch, and the matching ground-truth masks, all covering the same
        first `count` samples.
    """
    ## getting and processing val data
    mask = valMap['mask'][:count]
    # Bug fix: the original sliced `img[0:16]` and then overwrote it with
    # `np.array(img)`. The model therefore ran on the WHOLE dataset, and the
    # returned images/predictions no longer matched the returned masks in length.
    imgProc = np.array(valMap['img'][:count])

    # The original per-sample cv2.merge of channels 0, 1, 2 rebuilt each
    # prediction in unchanged channel order (a no-op), so it is omitted.
    predictions = model.predict(imgProc)

    return predictions, imgProc, mask


def Plotter(img, predMask, groundTruth):
    """Draw input image, predicted mask and ground-truth mask side by side.

    Opens a new 7x7-inch figure with three panels in one row; the caller is
    responsible for showing/closing it.
    """
    panels = (
        (img, 'image'),
        (predMask, 'Predicted Mask'),
        (groundTruth, 'actual Mask'),
    )

    plt.figure(figsize=(7,7))
    for position, (picture, caption) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(picture)
        plt.title(caption)

In [None]:
## Predict on the first samples and keep inputs/masks aligned for plotting.
## NOTE(review): this runs on the TRAINING set (framObjTrain), so it is not a
## held-out evaluation. framObjValidation is declared but not filled in the
## cells shown here.
sixteenPrediction, actuals, masks = predict16(framObjTrain, myTransformer)

In [None]:
## Pick one predicted sample at random. The index is bounded by the shortest of
## the three sequences, so it stays valid with fewer than 16 samples.
n_r = random.randrange(min(len(sixteenPrediction), len(actuals), len(masks)))

Plotter(actuals[n_r], sixteenPrediction[n_r], masks[n_r])

In [None]:
## Show one prediction in a native OpenCV window. This needs a display and does
## not render inline in the notebook.
## cv2.imshow expects BGR channel order, but the masks the model learned from
## were read with plt.imread (RGB order). Convert before display so the colours
## are not swapped.
predictionBgr = cv2.cvtColor(sixteenPrediction[n_r], cv2.COLOR_RGB2BGR)
try:
    # show the image, provide window name first
    cv2.imshow('image window', predictionBgr)
    # window waits until user presses a key
    cv2.waitKey(0)
finally:
    # always close the window, even if display or the wait is interrupted
    cv2.destroyAllWindows()

In [None]:
#save_file = 'model.h5'
#myTransformer.save(save_file)