In [None]:
%matplotlib inline

import keras
import os
import cv2
import numpy as np
from PIL import Image
import json
import keras.backend as K
import tensorflow as tf
import sklearn.preprocessing
from keras.models import Sequential
from keras import metrics
from keras.layers import Dense, Softmax, Conv2D, Input, Flatten, Lambda, MaxPooling2D, BatchNormalization, Conv2DTranspose, Dropout
from IPython.display import display 
import metricsSemSeg
from metricsSemSeg import pixel_accuracy, mean_accuracy, mean_IU, frequency_weighted_IU

# Class index -> RGB colour used in the ground-truth masks (index 0 is black, which the
# data loader also uses as the fallback class for unmatched colours).
CLASSTORGB = [[0,0,0],[255,255,255]]
NNName = "NASNet"
DATASET = "Seagrass"
CLASSES = 2
BSIZE = 2
# Network input width (X) and height (Y): images are resized with cv2.resize(img, (X, Y))
# and models take input_shape=(Y, X, 3).
X = 512
Y = 256
# Per-class loss weights; currently only referenced by commented-out weighting code.
CLASSWEIGHTS = [0.1, 1.0]
# Number of full batches per split. Integer division keeps these valid Keras step counts
# (525/2 would be the fractional 262.5) and matches how DataSequence.__len__ floors.
# Alternative sizes left by the author (presumably other dataset versions):
# train 3424 / 4223, validation 498 / 610, test 975 / 1204.
TRAINSIZE = 1846 // BSIZE
VALSIZE = 264 // BSIZE
TESTSIZE = 525 // BSIZE
TRAINPATH = "images/"
LABELPATH = "ground-truth/"
DATASETPATH = "../data/"+DATASET+"/"
EPOCHS = 100
LR = 0.001
# Checkpoint path, e.g. ../models/keras/NASNet512256SeagrassLR0.001batch2.h5
filepath = "../models/keras/"+NNName+str(X)+str(Y)+DATASET+"LR"+str(LR)+"batch"+str(BSIZE)+".h5"

# Install a TensorFlow 1.x session for Keras whose GPU memory usage grows on demand
# (allow_growth) instead of being fixed when the session is created.
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session

config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
# Alternative: cap usage with per_process_gpu_memory_fraction (e.g. 0.3) on GPUOptions.
set_session(tf.Session(config=config))


# semantic segmentation metrics
class metricsSS(object):
    """Keras-style ``(y_true, y_pred)`` wrappers around the numpy metrics in metricsSemSeg.

    Each public metric method evaluates the wrapped numpy function per sample through
    ``tf.py_func`` and returns its mean over the first ``batchSize`` samples.
    """

    def __init__(self, num_classes, _bsize):
        super().__init__()
        self.classes = num_classes
        self.batchSize = _bsize
        self.pA = metricsSemSeg.pixel_accuracy
        self.mA = metricsSemSeg.mean_accuracy
        self.mIoU = metricsSemSeg.mean_IU
        self.fwmIoU = metricsSemSeg.frequency_weighted_IU

    def preProcessKerasInput(self, _pred, _true):
        """Convert one sample's prediction and label into int64 class-index maps.

        ``_pred`` carries per-class scores in its last axis and is reduced with argmax.
        ``_true`` is either a single-channel map holding class indices (the format
        DataSequence yields) or a one-hot map. A single-channel label is read directly,
        because argmax over a size-1 axis would turn every pixel into class 0.
        """
        pred = K.argmax(_pred, axis=-1)
        true = tf.cond(
            tf.equal(tf.shape(_true)[-1], 1),
            lambda: tf.cast(_true[..., 0], tf.int64),
            lambda: K.argmax(_true, axis=-1),
        )
        return pred, true

    def _batchMean(self, metricFn, y_true, y_pred):
        """Average the numpy metric ``metricFn(pred, true)`` over ``batchSize`` samples."""
        def asFloat32(pred, true):
            # tf.py_func requires the returned dtype to equal Tout exactly; the wrapped
            # metric may return a Python float / float64, so convert explicitly.
            return np.float32(metricFn(pred, true))

        metric = 0.0
        for b in range(self.batchSize):
            pred, true = self.preProcessKerasInput(y_pred[b], y_true[b])
            metric += tf.py_func(asFloat32, [pred, true], tf.float32)
        return metric / self.batchSize

    def meanIoU(self, y_true, y_pred):
        """Batch-averaged ``metricsSemSeg.mean_IU``."""
        return self._batchMean(self.mIoU, y_true, y_pred)

    def frequencyWeightedUI(self, y_true, y_pred):
        """Batch-averaged ``metricsSemSeg.frequency_weighted_IU``."""
        return self._batchMean(self.fwmIoU, y_true, y_pred)

    def pixelAccuracy(self, y_true, y_pred):
        """Batch-averaged ``metricsSemSeg.pixel_accuracy``."""
        return self._batchMean(self.pA, y_true, y_pred)

    def meanAccuracy(self, y_true, y_pred):
        """Batch-averaged ``metricsSemSeg.mean_accuracy``."""
        return self._batchMean(self.mA, y_true, y_pred)


def deconv(f, model):
    """Append a 3x3, stride-2, 'same'-padded ReLU transposed convolution to ``model``.

    ``f`` is the number of filters; each call doubles the feature map's height and width.
    """
    upsample = Conv2DTranspose(
        filters=f,
        kernel_size=3,
        strides=2,
        padding='same',
        activation="relu",
    )
    model.add(upsample)


def conv(first=False, units=128, f=3, dilation=1, last=False):
    """Return a stride-1, 'same'-padded Conv2D layer configured for this file's models.

    Args:
        first: if true, the layer declares ``input_shape=(Y, X, 3)`` so it can be the
            first layer of a Sequential model.
        units: number of filters (ignored when ``last`` is true).
        f: kernel size (ignored when ``last`` is true).
        dilation: dilation rate of the kernel.
        last: if true, build the classifier layer instead: a 1x1 conv with ``CLASSES``
            filters and a linear activation, i.e. raw per-pixel class scores. The model
            builders in this file append a ``Softmax`` layer after it.
    """
    # Settings shared by every layer.
    options = dict(
        strides=(1, 1),
        padding='same',
        use_bias=True,
        data_format="channels_last",
        kernel_initializer="glorot_normal",
        bias_initializer=keras.initializers.Constant(value=0.1),
        # Honour the requested dilation for every layer, not only the first one
        # (dilNet requests dilation rates 2 and 4 on non-first layers).
        dilation_rate=dilation,
        # Class scores stay linear: a ReLU here would clamp negative scores to 0
        # (zeroing their gradient) before the Softmax layer callers append.
        activation=None if last else "relu",
    )
    if first:
        # Only the first layer of a Sequential model declares the input shape.
        options["input_shape"] = (Y, X, 3)
    if last:
        return Conv2D(CLASSES, 1, **options)
    # Honour the requested kernel size (dilNet uses f=1 for its 1x1 convs).
    return Conv2D(units, f, **options)

def deConv(depth):
    """Return a 3x3, stride-2 transposed convolution layer with ``depth`` filters.

    No activation is passed, so Keras' default (linear) activation applies; with
    ``padding='same'`` the layer doubles the height and width of its input.
    """
    settings = dict(
        strides=(2, 2),
        padding='same',
        data_format="channels_last",
        kernel_initializer="glorot_normal",
        bias_initializer=keras.initializers.Constant(value=0.1),
    )
    return Conv2DTranspose(depth, 3, **settings)

    
def netFCN():
    """Build a small fully convolutional encoder/decoder segmentation network.

    Encoder: four conv blocks (128, 128, 256, 512 filters). The first two are followed
    by 2x2 max pooling, so features end at 1/4 of the input resolution. Every block is
    batch-normalised, and the first three also apply dropout (rate 0.3).
    Decoder: two stride-2 transposed convolutions (512, 256 filters) restore the input
    resolution. A 1x1 classifier conv and a Softmax layer then produce per-pixel class
    probabilities.
    """
    model = Sequential()

    # (conv() keyword arguments, max-pool after the conv?, dropout after batch norm?)
    encoderStages = (
        ({"first": True}, True, True),
        ({"units": 128}, True, True),
        ({"units": 256}, False, True),
        ({"units": 512}, False, False),
    )
    for convKwargs, downsample, dropout in encoderStages:
        model.add(conv(**convKwargs))
        if downsample:
            model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(BatchNormalization())
        if dropout:
            model.add(Dropout(rate=0.3))

    for depth in (512, 256):
        model.add(deConv(depth))

    model.add(conv(last=True))
    model.add(Softmax())
    return model

def dilNet():
    """Build a segmentation network whose deeper stages request dilated convolutions.

    Two conv pairs, each followed by 2x2 max pooling and batch normalisation, reduce the
    resolution to 1/4. Then three context stages run at that resolution. Each stage has
    two convs plus a 1x1 conv (f=1), followed by batch normalisation:
    256 filters (dilation 1), 512 filters (dilation 2) and 512 filters (dilation 4).
    Two stride-2 transposed convolutions restore the resolution before the classifier
    conv and Softmax.
    NOTE(review): the ``f`` and ``dilation`` requested here only take effect if
    ``conv`` forwards them for non-first layers; check ``conv``.
    """
    model = Sequential()

    # Downsampling stages: conv() keyword arguments per layer, then pool + batch norm.
    for stage in (
        ({"first": True}, {"units": 128}),
        ({"units": 128}, {"units": 256}),
    ):
        for convKwargs in stage:
            model.add(conv(**convKwargs))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(BatchNormalization())

    # Context stages: (filters, dilation rate).
    for units, rate in ((256, 1), (512, 2), (512, 4)):
        model.add(conv(units=units, dilation=rate))
        model.add(conv(units=units, dilation=rate))
        model.add(conv(units=units, f=1, dilation=rate))
        model.add(BatchNormalization())

    for depth in (512, 256):
        model.add(deConv(depth))

    model.add(conv(last=True))
    model.add(Softmax())
    return model

def _addUpsamplingHead(model):
    """Append the decoder shared by the keras.applications backbone models; return model.

    Five 2x ReLU transposed convolutions (320, 256, 256, 256, 128 filters) give 32x
    upsampling in total. They are followed by the 1x1 classifier conv and a Softmax.
    NOTE(review): this assumes the backbone output is exactly 1/32 of the input size.
    Backbones whose stems use 'valid' padding (possibly InceptionResNetV2) may not
    satisfy this for (Y, X) inputs, so the output would not match the label size.
    Confirm with model.summary().
    """
    for filters in (320, 256, 256, 256, 128):
        deconv(filters, model)
    model.add(conv(last=True))
    model.add(Softmax())
    return model


def mobileNet_V2_SS(inputShape):
    """Headless, randomly initialised MobileNetV2 (alpha 1.0) + shared upsampling head."""
    model = Sequential()
    model.add(keras.applications.mobilenet_v2.MobileNetV2(
        input_shape=inputShape,
        alpha=1.0,
        depth_multiplier=1,
        include_top=False,
        weights=None
    ))
    return _addUpsamplingHead(model)


def inceptionResnet_SS(inputShape):
    """Headless, randomly initialised InceptionResNetV2 + shared upsampling head."""
    model = Sequential()
    model.add(keras.applications.inception_resnet_v2.InceptionResNetV2(
        input_shape=inputShape,
        include_top=False,
        weights=None
    ))
    return _addUpsamplingHead(model)


def xception_SS(inputShape):
    """Headless, randomly initialised Xception + shared upsampling head."""
    model = Sequential()
    model.add(keras.applications.xception.Xception(
        input_shape=inputShape,
        include_top=False,
        weights=None
    ))
    return _addUpsamplingHead(model)


def nasNet_SS(inputShape):
    """Headless, randomly initialised NASNetMobile + shared upsampling head."""
    model = Sequential()
    model.add(keras.applications.nasnet.NASNetMobile(
        input_shape=inputShape,
        include_top=False,
        weights=None
    ))
    return _addUpsamplingHead(model)

class DataSequence(keras.utils.Sequence):
    """Keras ``Sequence`` yielding ``(images, labels)`` batches for one dataset split.

    The split's JSON index under DATASETPATH is read, and only entries whose "depth"
    field is <= 2 are kept. Each batch has two parts:
    ``images``: float32 of shape (batchSize, Y, X, 3), standardised per image.
    ``labels``: int32 of shape (batchSize, Y, X, 1), holding class indices.
    """

    # trainType -> JSON index file of that split (relative to DATASETPATH).
    _SPLIT_FILES = {
        "train": "train.json",
        "validation": "validate.json",
        "test": "test.json",
    }
    # Entries whose "depth" exceeds this are skipped.
    _MAX_DEPTH = 2.0

    def __init__(self, batchSize, trainType):
        """Load the image/label file names of split ``trainType``.

        Raises:
            ValueError: if ``trainType`` is not "train", "validation" or "test".
        """
        self.dataSetPath = DATASETPATH

        if trainType not in self._SPLIT_FILES:
            raise ValueError("unknown dataset type, valid are train, validation and test")

        with open(DATASETPATH + self._SPLIT_FILES[trainType]) as fh:
            jsonData = json.load(fh)

        # Filter once so the image and label name lists stay aligned.
        kept = [entry for entry in jsonData if entry["depth"] <= self._MAX_DEPTH]
        self.x = [os.path.basename(entry["image"]) for entry in kept]
        self.y = [os.path.basename(entry["ground-truth"]) for entry in kept]

        self.batch_size = batchSize
        print("Loaded data with size of ", len(self))

    def __len__(self):
        """Number of full batches; a trailing partial batch is dropped."""
        return len(self.x) // self.batch_size

    def __getitem__(self, idx):
        """Load, preprocess and stack the ``idx``-th batch."""
        images = []
        labels = []
        for b in range(self.batch_size):
            elIdx = idx * self.batch_size + b
            img = self._readRGB(self.dataSetPath + TRAINPATH + self.x[elIdx])
            labelImg = self._readRGB(self.dataSetPath + LABELPATH + self.y[elIdx])
            images.append(self._standardise(img))
            labels.append(self._toClassIndices(labelImg))
        # (batch, Y, X, 3) float32 images and (batch, Y, X, 1) int32 class maps.
        return np.stack(images), np.stack(labels)[..., np.newaxis]

    @staticmethod
    def _readRGB(path):
        """Read ``path`` as an RGB uint8 image resized to (Y, X) with nearest-neighbour."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError("could not read image " + path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return cv2.resize(img, (X, Y), interpolation=cv2.INTER_NEAREST)

    @staticmethod
    def _standardise(img):
        """Zero-mean / unit-variance scaling of one image, returned as float32."""
        std = img.std()
        # A constant image has std 0; divide by 1 instead of producing NaNs.
        return ((img - img.mean()) / (std if std > 0 else 1.0)).astype(np.float32)

    @staticmethod
    def _toClassIndices(labelImg):
        """Map an RGB mask to a (Y, X) int32 array of CLASSTORGB indices.

        Pixels whose colour matches no CLASSTORGB entry, or whose index is >= CLASSES,
        are assigned class 0 (black).
        """
        if CLASSES == 2:
            # Snap near-white / near-black pixels onto the two class colours.
            labelImg[(labelImg >= 128).all(-1)] = [255, 255, 255]
            labelImg[(labelImg <= 127).all(-1)] = [0, 0, 0]

        # Write indices into a fresh array so unmatched colours cannot leak through as
        # channel values (e.g. a red channel of 1 being read as class 1).
        classes = np.zeros(labelImg.shape[:2], dtype=np.int32)
        for rgbIdx, rgbV in enumerate(CLASSTORGB):
            classes[(labelImg == rgbV).all(-1)] = rgbIdx
        classes[classes >= CLASSES] = 0
        return classes
    

    
def pred(fileName, debug=True):
    """Segment ``../results/<fileName>`` with the global ``model`` and colourise it.

    The image is converted to RGB, resized to (X, Y) with nearest-neighbour, and
    standardised per image (as DataSequence does). It is then classified per pixel and
    mapped to colours with CLASSTORGB.

    Args:
        fileName: image file name inside ``../results/``.
        debug: if true, display the colourised prediction inline (IPython).

    Returns:
        The colourised prediction as a (Y, X, 3) uint8 array.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    path = "../results/" + fileName
    predImg = cv2.imread(path)
    if predImg is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError("could not read image " + path)
    predImg = cv2.cvtColor(predImg, cv2.COLOR_BGR2RGB)
    predImg = cv2.resize(predImg, (X, Y), interpolation=cv2.INTER_NEAREST)
    std = predImg.std()
    # A constant image has std 0; divide by 1 instead of producing NaNs.
    predImg = ((predImg - predImg.mean()) / (std if std > 0 else 1.0)).astype(np.float32)
    predClasses = model.predict(np.expand_dims(predImg, axis=0))
    # (1, Y, X, CLASSES) scores -> (Y, X) class indices.
    predClasses = np.argmax(predClasses, axis=3)[0]
    # Vectorised palette lookup instead of a per-pixel Python loop.
    palette = np.asarray(CLASSTORGB, dtype=np.uint8)
    colourImg = palette[predClasses]

    if debug:
        display(Image.fromarray(colourImg, "RGB"))
    return colourImg

def evaluate(data, model):
    """Print (and return) the mean segmentation metrics of ``model`` over ``data``.

    Args:
        data: indexable sequence of ``(images, labels)`` batches, e.g. DataSequence.
        model: Keras model whose last output axis holds per-class scores.

    Returns:
        dict with the per-image means of "pixel_accuracy", "mean_accuracy", "mean_IU"
        and "frequency_weighted_IU" (0.0 each when ``data`` yields no images).
    """
    from metricsSemSeg import pixel_accuracy, mean_accuracy, mean_IU, frequency_weighted_IU

    metricFns = {
        "pixel_accuracy": pixel_accuracy,
        "mean_accuracy": mean_accuracy,
        "mean_IU": mean_IU,
        "frequency_weighted_IU": frequency_weighted_IU,
    }
    totals = dict.fromkeys(metricFns, 0.0)
    i = 0

    for idx in range(len(data)):
        x, y = data[idx]
        predClasses = model.predict_on_batch(x)
        # Iterate over the samples actually returned rather than assuming BSIZE.
        for b in range(len(predClasses)):
            pred = np.argmax(predClasses[b], axis=-1)
            labelData = np.squeeze(y[b])

            if i % 100 == 0:
                print("Image ", i, " evaluated...")

            # Sum per-image scores and divide once at the end. The former
            # "(running + new) / 2" update weighted recent images exponentially more
            # than earlier ones, and restarted whenever the running value was 0.
            for name, fn in metricFns.items():
                totals[name] += fn(pred, labelData)
            i = i + 1

    means = {name: (total / i if i else 0.0) for name, total in totals.items()}
    print("Pixel accuracy: ", means["pixel_accuracy"] ," || Mean accuracy: ", means["mean_accuracy"] ," || Mean intersection union:", means["mean_IU"] ," || frequency weighted IU: ", means["frequency_weighted_IU"])
    return means

class regularPred(keras.callbacks.Callback):
    """Callback that shows a fixed prediction every epoch and evaluates periodically.

    After every epoch ``pred("predictSeagrass.jpg")`` is called. Every ``evalEvery``
    epochs (including epoch 0), the model is evaluated on the test split.
    """

    def __init__(self, evalEvery=5):
        # Let keras.callbacks.Callback initialise its own attributes first.
        super().__init__()
        self.data = DataSequence(BSIZE, "test")
        self.evalEvery = evalEvery

    def on_epoch_end(self, epoch, logs=None):
        pred("predictSeagrass.jpg")
        if epoch % self.evalEvery == 0:
            evaluate(self.data, self.model)
        
    
# Loss with a hook for per-class weighting (weighting is currently disabled).
def sparse_crossentropy_weighted(ground_truth, predictions):
    """Sparse softmax cross-entropy loss for ``model.compile(loss=...)``.

    ``ground_truth`` holds integer class indices and is cast to int32. ``predictions``
    are passed to ``tf.losses.sparse_softmax_cross_entropy`` as logits.

    NOTE(review): most model builders in this file end with a Softmax layer, so
    ``predictions`` may already be probabilities that get softmax-ed a second time.
    Confirm whether probability (from_logits=False) semantics are intended.

    Class weighting (CLASSWEIGHTS) is not applied. To enable it, compute
    ``weights = reduce_sum(one_hot(squeeze(labels, 3), CLASSES) * CLASSWEIGHTS, 3)``
    and pass ``weights=weights`` to the loss call.
    """
    labels = tf.cast(ground_truth, tf.int32)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=predictions)
    return tf.reduce_mean(loss)

def getModel(modelName):
    """Return the model to train/evaluate: the saved checkpoint if present, else a new one.

    If a file exists at ``filepath`` it is loaded with the custom loss, the metricsSS
    metrics and the DeepLabv3 custom layers registered, and ``modelName`` is ignored.
    Otherwise a freshly initialised model is built for ``modelName``.

    Raises:
        ValueError: if no checkpoint exists and ``modelName`` is not one of "deeplabv3",
            "mobilenetv2", "FCN", "dilNet", "NASNet", "inception" or "xception".
    """
    import os
    from keras.models import load_model
    from nets.deepLabv3Keras import BilinearUpsampling, relu6

    if os.path.isfile(filepath):
        # A single metricsSS instance provides all four metric callables.
        metrics = metricsSS(CLASSES, BSIZE)
        model = load_model(filepath, custom_objects={
            "sparse_crossentropy_weighted": sparse_crossentropy_weighted,
            'pixelAccuracy': metrics.pixelAccuracy,
            "meanAccuracy": metrics.meanAccuracy,
            "meanIoU": metrics.meanIoU,
            "frequencyWeightedUI": metrics.frequencyWeightedUI,
            "BilinearUpsampling": BilinearUpsampling,
            "relu6": relu6
        })
        print("Model loaded from h5 file")
        return model

    inputShape = (Y, X, 3)

    def buildDeeplab():
        from nets.deepLabv3Keras import Deeplabv3
        return Deeplabv3(weights=None, backbone="mobilenetv2", input_shape=inputShape, classes=CLASSES, OS=8)

    # Builders are zero-argument callables so only the requested model is constructed.
    builders = {
        "deeplabv3": buildDeeplab,
        "mobilenetv2": lambda: mobileNet_V2_SS(inputShape),
        "FCN": netFCN,
        "dilNet": dilNet,
        "NASNet": lambda: nasNet_SS(inputShape),
        "inception": lambda: inceptionResnet_SS(inputShape),
        "xception": lambda: xception_SS(inputShape),
    }
    if modelName not in builders:
        # Raising a plain string is itself a TypeError in Python 3; use a real exception.
        raise ValueError("no right modelname given: %r" % (modelName,))

    model = builders[modelName]()
    print("Fresh model loaded from architecture")
    return model

In [None]:
%matplotlib inline



model = getModel(NNName)

#model.summary()
model.compile(
                loss=sparse_crossentropy_weighted,
                optimizer=keras.optimizers.Nadam(lr=LR),
                metrics=[]
            )

print("Model compiled...")

# Keep only the checkpoint with the lowest validation loss.
modelCheckpointer = keras.callbacks.ModelCheckpoint(
                                    filepath,
                                    monitor='val_loss',
                                    verbose=1,
                                    save_best_only=True,
                                    save_weights_only=False,
                                    mode='min',
                                    period=1)

earlyStopper = keras.callbacks.EarlyStopping(
                                          monitor='val_loss',
                                          min_delta=0.0001,
                                          patience=20,
                                          verbose=1,
                                          mode='min'
                                    )
lrReducer = keras.callbacks.ReduceLROnPlateau(
                                    monitor='val_loss',
                                    factor=0.1,
                                    patience=5,
                                    verbose=1,
                                    mode="min",
                                    min_delta=0.0001,
                                    cooldown=0,
                                    min_lr=0)

trainData = DataSequence(BSIZE, "train")
valData = DataSequence(BSIZE, "validation")

# Take step counts from the sequences themselves: DataSequence drops entries deeper
# than 2, so the hard-coded TRAINSIZE/VALSIZE/TESTSIZE need not match the real number
# of batches.
history = model.fit_generator(
        trainData,
        steps_per_epoch=len(trainData),
        epochs=EPOCHS,
        verbose=1,
        use_multiprocessing=True,
        shuffle=True,
        validation_data=valData,
        validation_steps=len(valData),
        callbacks=[
            regularPred(),
            modelCheckpointer,
            earlyStopper,
            lrReducer
            ]
)

testData = DataSequence(BSIZE, "test")
testloss = model.evaluate_generator(
    testData,
    steps=len(testData),
    use_multiprocessing=True,
    verbose=1
)

print(testloss)


In [None]:

# Plot per-epoch training and validation loss recorded by fit_generator.
# Adapted from https://machinelearningmastery.com/display-deep-learning-model-training-history-in-keras/
import matplotlib.pyplot as plt

# Only the loss is available while the model is compiled with metrics=[]. Once metrics
# such as pixelAccuracy, meanIoU, meanAccuracy or frequencyWeightedUI are passed to
# compile(), plot history.history[name] and history.history["val_" + name] the same way.
for key in ('loss', 'val_loss'):
    plt.plot(history.history[key])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', "val"], loc='upper left')
plt.show()

In [None]:
from data import Data
from metricsSemSeg import pixel_accuracy, mean_accuracy, mean_IU, frequency_weighted_IU

# Evaluate the model from getModel (the checkpoint at `filepath` if present) on the test
# split and print the averaged segmentation metrics. evaluate() runs the same per-image
# metric loop, so reuse it instead of duplicating it here. The previously loaded Data
# object and FCN config were never used.
model = getModel(NNName)
data = DataSequence(BSIZE, "test")
evaluate(data, model)

In [None]:
from data import Data
import matplotlib as mpl
#mpl.use('TkAgg')
import matplotlib.pyplot as plt
# Visual sanity check: predict up to `max_images` validation images, then show the
# ground-truth masks (green) and the predicted masks (red) on two grids.
max_images = 60
model = getModel(NNName)

x_valid = []
y_valid = []
x_preds = []

print("loading data...")

data = DataSequence(BSIZE, "validation")
print(len(data))
# Never request more batches than the split actually provides.
numBatches = min(len(data), max_images // BSIZE)
for idx in range(numBatches):
    x, y = data[idx]
    predClasses = model.predict_on_batch(x)
    print(predClasses.shape)
    for e in range(len(predClasses)):
        # (Y, X, CLASSES) scores -> (Y, X) class indices.
        x_preds.append(np.argmax(predClasses[e], axis=-1))
        x_valid.append(x[e].squeeze())
        y_valid.append(y[e].squeeze())

x_valid = np.array(x_valid)
y_valid = np.array(y_valid)
x_preds = np.array(x_preds)

print(x_valid.shape, y_valid.shape, x_preds.shape)

grid_width = 15


def _showMaskGrid(masks, cmap, title):
    """Draw class-index ``masks`` on a grid ``grid_width`` columns wide, then show it."""
    gridHeight = max(1, -(-len(masks) // grid_width))  # ceiling division
    # squeeze=False keeps `axs` 2-D even when there is a single row.
    fig, axs = plt.subplots(gridHeight, grid_width, figsize=(grid_width, gridHeight), squeeze=False)
    for idx, mask in enumerate(masks):
        ax = axs[idx // grid_width, idx % grid_width]
        # Class indices 0/1 scaled to 0/255 for display.
        ax.imshow((mask * 255).astype(np.uint8), alpha=0.6, cmap=cmap)
        ax.set_yticklabels([])
        ax.set_xticklabels([])
    plt.suptitle(title)
    plt.show()


# The dataset is Seagrass; the former "Green: salt" title was a copy-paste leftover.
_showMaskGrid(y_valid, "Greens", "Green: ground truth")
_showMaskGrid(x_preds, "OrRd", "Red: prediction")
