In [None]:
import numpy as np
import pandas as pd
import os
from pathlib import Path
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score
import tensorflow as tf
from tensorflow.keras.constraints import Constraint
from tensorflow.keras import Sequential, Model, Input 
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.layers import Dense, BatchNormalization, Dropout, Conv2D, Conv2DTranspose, MaxPooling2D, Flatten, LeakyReLU, Reshape, UpSampling2D, Concatenate

In [2]:
def U_net(input_shape, base_filters=64, depth=4, out_channels=1):
    """Build a plain U-Net (one 3x3 ReLU conv per level, no normalisation) as a Keras Model.

    With the default arguments this builds exactly the original fixed architecture:
    encoder 64-128-256-512, bottleneck 1024, mirrored decoder with skip
    concatenations, and a final 1x1 conv.

    Args:
        input_shape: (height, width, channels) of one input, without the batch axis.
            height/width may be None for variable-size inputs.
        base_filters: filters of the first encoder conv; each deeper level doubles it.
        depth: number of 2x2 max-pool downsampling steps (and matching upsampling steps).
        out_channels: channels of the final 1x1 conv. It has no activation, so the
            output is linear/unbounded.

    Returns:
        An uncompiled tf.keras Model mapping inputs of `input_shape` to
        (height, width, out_channels).

    Raises:
        ValueError: if depth is negative, or a known height/width is not divisible by
            2**depth (the upsampled decoder maps would then not match the encoder
            skip maps and Concatenate would fail with a less obvious error).
    """
    if depth < 0:
        raise ValueError(f'depth must be >= 0, got {depth}')
    factor = 2 ** depth
    for axis_name, size in zip(('height', 'width'), input_shape[:2]):
        if size is not None and size % factor != 0:
            raise ValueError(
                f'input {axis_name} {size} must be divisible by 2**depth = {factor} '
                'so encoder and decoder feature maps line up for concatenation'
            )

    def conv(filters, x):
        # Same 3x3 'same'-padded ReLU conv used at every encoder/decoder level.
        return Conv2D(filters, (3, 3), padding='same', activation='relu')(x)

    inputs = Input(shape=input_shape)

    # Encoder: conv, keep the feature map for the skip connection, then halve H and W.
    skips = []
    x = inputs
    for level in range(depth):
        x = conv(base_filters * 2 ** level, x)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck at the lowest resolution.
    x = conv(base_filters * 2 ** depth, x)

    # Decoder: double H and W, concatenate the matching encoder map (skip first), conv.
    for level in reversed(range(depth)):
        x = UpSampling2D(size=(2, 2))(x)
        x = Concatenate()([skips[level], x])
        x = conv(base_filters * 2 ** level, x)

    outputs = Conv2D(out_channels, (1, 1), padding='same')(x)
    return Model(inputs, outputs)

In [3]:
# Build the network for 512x512 RGB inputs (the size load_image resizes to).
INPUT_SHAPE = (512, 512, 3)
model = U_net(INPUT_SHAPE)

In [None]:
# Print the layer-by-layer summary (output shapes and parameter counts) of the U-Net.
model.summary()

In [None]:
# Render the layer graph including tensor shapes (Keras needs pydot + graphviz for this).
tf.keras.utils.plot_model(model=model, show_shapes=True)

In [None]:
# Dataset root, relative to the notebook's working directory. data_loader() searches it
# recursively for PNGs and expects each image folder to have a sibling "<folder> GT"
# directory holding the masks.
root = Path('./a-large-scale-fish-dataset/Fish_Dataset/Fish_Dataset')

In [6]:
def data_loader(root_dir=None):
    """Pair every dataset image with its ground-truth mask file.

    Images live in directories whose name does not end in 'GT'; the mask for
    `<dir>/<name>.png` is expected at `<dir> GT/<name>.png`, i.e. in a sibling
    directory with ' GT' appended to the name.

    Args:
        root_dir: directory to search recursively for PNG files. Defaults to the
            notebook-level `root`.

    Returns:
        A list of (image_path, mask_path) string tuples, sorted by image path so the
        result does not depend on filesystem listing order. Images whose mask file
        does not exist are skipped; otherwise they would only fail later, when
        tf.data tries to read the missing file.

    Raises:
        FileNotFoundError: if the root directory does not exist (instead of silently
            returning an empty list).
    """
    base = Path(root if root_dir is None else root_dir)
    if not base.is_dir():
        raise FileNotFoundError(f'dataset root not found: {base}')

    pairs = []
    for image_path in sorted(base.glob('**/*.png')):
        image_dir = image_path.parent
        if image_dir.name.endswith('GT'):
            continue  # mask files are reached through their image counterpart
        mask_path = image_dir.parent / (image_dir.name + ' GT') / image_path.name
        if mask_path.is_file():
            pairs.append((str(image_path), str(mask_path)))
    return pairs

In [7]:
# List of (image_path, mask_path) string tuples found under `root`.
data = data_loader()

In [9]:
def normalize(input_image, input_mask):
    """Cast an image and its mask to float32 and divide both by 255.

    Assumes 8-bit intensity values, so 0..255 inputs map to 0..1.
    Returns the pair in the same order: (image, mask).
    """
    def _to_unit_range(tensor):
        return tf.cast(tensor, tf.float32) / 255.0

    return _to_unit_range(input_image), _to_unit_range(input_mask)

In [None]:
def load_image(image_path, image_size=(512, 512)):
    """Read one (image, mask) pair from disk, resize both and scale them to [0, 1].

    Args:
        image_path: string tensor of length 2, [image file, mask file], as produced
            by slicing the list of (image_path, mask_path) pairs with tf.data.
        image_size: (height, width) that both image and mask are resized to.

    Returns:
        (image, mask): float32 tensors of shape (height, width, 3) and (height, width, 1).
    """
    image = tf.io.read_file(image_path[0])
    # expand_animations=False guarantees a rank-3 result (a GIF would otherwise decode to rank 4).
    image = tf.io.decode_image(image, channels=3, expand_animations=False)
    image.set_shape([None, None, 3])
    # Bilinear resize returns float32 values still in [0, 255]; normalize() rescales them.
    image = tf.image.resize(image, image_size)

    mask = tf.io.read_file(image_path[1])
    mask = tf.io.decode_image(mask, channels=1, expand_animations=False)
    mask.set_shape([None, None, 1])
    # Nearest-neighbour keeps only pixel values present in the original mask (no blended
    # values along object edges). It returns the input dtype (uint8), which is why no
    # convert_image_dtype is applied here: normalize() does the single cast + /255.
    mask = tf.image.resize(mask, image_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    return normalize(image, mask)

# Shuffle the (image, mask) path pairs once, with a fixed seed, *before* decoding
# (shuffling paths is cheap; shuffling decoded 512x512 images would need a huge buffer).
# data_loader walks the tree directory by directory, so the pairs arrive grouped by
# class folder: without this, the positional take/skip split puts whole classes into
# validation, and training iterates one class after another.
# reshuffle_each_iteration=False keeps the order identical on every pass, so the
# take/skip split stays disjoint across epochs.
SHUFFLE_SEED = 42
dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.shuffle(max(len(data), 1), seed=SHUFFLE_SEED, reshuffle_each_iteration=False)
# Decode files in parallel; map keeps the element order deterministic by default.
dataset = dataset.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
# Count every PNG under `root`: both the input images and the files in the "... GT" mask folders.
image_count = sum(1 for _ in root.glob('**/*.png'))

In [12]:
# Hold out 20% of the (image, mask) pairs for validation. len(data) is the exact number
# of pairs; image_count / 2 only equals it if every image has exactly one mask PNG and
# there are no other PNGs under the root.
# take/skip split by position, so the validation set is only representative of all
# classes if `dataset` was shuffled before this point.
VAL_FRACTION = 0.2
val_size = int(len(data) * VAL_FRACTION)
train_ds = dataset.skip(val_size)
val_ds = dataset.take(val_size)

In [13]:
# Train the mask as pixel-wise regression with mean squared error.
# NOTE(review): U_net's final 1x1 conv has no activation, so predictions are unbounded.
# Keras resolves the 'accuracy' string from the loss/output shape; with one output
# channel this is presumably binary accuracy at a 0.5 threshold. Confirm, and consider
# sigmoid + binary cross-entropy for a binary mask.
compile_config = {
    'optimizer': 'adam',
    'loss': 'mse',
    'metrics': ['accuracy'],
}
model.compile(**compile_config)

In [15]:
def display(image=None, mask=None):
    """Show an input image, its ground-truth mask and the model's predicted mask side by side.

    Args:
        image: one unbatched image tensor. Defaults to the notebook-level `sample_image`.
        mask: the matching ground-truth mask. Defaults to the notebook-level `sample_mask`.

    The prediction comes from the notebook-level `model`.
    NOTE: this definition shadows IPython's built-in `display()` for the rest of the
    notebook; the name is kept because later cells call it.
    """
    if image is None:
        image = sample_image
    if mask is None:
        mask = sample_mask

    # model.predict expects a batch: add a leading batch axis, then take the single result back out.
    predicted_mask = model.predict(image[tf.newaxis, ...])[0]
    panels = [image, mask, predicted_mask]
    titles = ['Input Image', 'True Mask', 'Predicted Mask']

    fig, axes = plt.subplots(1, len(panels), figsize=(15, 15))
    for ax, title, panel in zip(axes, titles, panels):
        ax.set_title(title)
        ax.imshow(tf.keras.utils.array_to_img(panel))
        ax.axis('off')
    plt.show()

In [None]:
# Take one unbatched (image, mask) pair from the training split and plot the model's
# current prediction next to it (before training, in notebook order) as a pipeline check.
# Raises StopIteration if train_ds is empty.
# NOTE: this must run before the batching cell below; after it, train_ds yields batches.
sample_image, sample_mask = next(iter(train_ds.take(1)))
display()

In [17]:
# Batch both splits and prefetch so input decoding overlaps with model execution.
# NOTE: this cell rebinds train_ds/val_ds, so running it twice would batch twice;
# re-run the split cell first if you need to re-execute it.
BATCH_SIZE = 8
train_ds = train_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
val_ds = val_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

In [None]:
# Keep the History object so per-epoch loss/metric curves can be inspected or plotted later.
EPOCHS = 10
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
)

In [None]:
# Re-plot the same sample after training to compare with the earlier (pre-training) prediction.
display()