In [None]:
import os

import numpy as np
from numpy import asarray
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, UpSampling2D, concatenate
from tensorflow.keras.optimizers import Adam
# Take EarlyStopping from tf.keras like every other Keras name in this notebook;
# the standalone `keras` package can be a different Keras build than the one
# backing tf.keras, and its callbacks are not guaranteed to work with tf.keras models.
from tensorflow.keras.callbacks import EarlyStopping

# Report how many GPUs TensorFlow can see (query once, reuse the result).
physical_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(physical_devices))


In [None]:
# Directories holding the training images and their masks; a mask is expected
# to have exactly the same filename as its image.
img_dir = 'Images'
mask_dir = 'labels'
# Spatial size every image and mask is resized to (must match the model input).
img_size = (256, 256)

images = []
masks = []

# Sort the listing so the sample order, and therefore the train/test split
# below, does not depend on the filesystem's arbitrary listdir order.
for filename in sorted(os.listdir(img_dir)):
    if not filename.endswith('.png'):
        continue

    # Load the image as RGB, resize it and scale pixel values to [0, 1].
    img = Image.open(os.path.join(img_dir, filename)).convert('RGB')
    img = img.resize(img_size)
    images.append(np.array(img) / 255.0)

    # Resize the mask with NEAREST so label values are preserved; the default
    # filter (bicubic for L/RGB images) blends foreground and background into
    # intermediate values along object edges.
    mask = Image.open(os.path.join(mask_dir, filename))
    mask = mask.resize(img_size, resample=Image.NEAREST)
    # NOTE(review): assumes masks are stored as 0/255; 0/1 label images would
    # end up close to 0 after this division — confirm the mask encoding.
    masks.append(np.array(mask) / 255.0)

assert images, f"No .png images found in {img_dir!r}"

# Convert the per-sample lists into stacked numpy arrays.
images = np.array(images)
masks = np.array(masks)

# Hold out 20% of the samples for validation; seeded so the split is reproducible.
X_train, X_test, Y_train, Y_test = train_test_split(images, masks, test_size=0.20, random_state=42)

In [None]:
batch_size = 5
lr = 3e-4


def _conv_pair(x, filters):
    """Apply two 3x3, same-padded, ReLU convolutions with `filters` output channels."""
    x = Conv2D(filters, 3, activation='relu', padding='same')(x)
    return Conv2D(filters, 3, activation='relu', padding='same')(x)


def _up_block(x, skip, filters):
    """Decoder stage: upsample `x` by 2, project it to `filters` channels with a
    2x2 conv, concatenate it with the encoder feature map `skip` along the
    channel axis, then apply `_conv_pair`."""
    up = UpSampling2D(size=(2, 2))(x)
    up = Conv2D(filters, 2, activation='relu', padding='same')(up)
    merged = concatenate([skip, up], axis=3)
    return _conv_pair(merged, filters)


def multi_res_unet(input_shape, learning_rate=None):
    """Build and compile a U-Net for binary segmentation.

    Despite its name, this builds a standard U-Net: four encoder stages of
    paired 3x3 convolutions with 2x2 max pooling, a 1024-channel bottleneck,
    and four decoder stages with skip concatenations. It is not the
    MultiResUNet architecture (no MultiRes blocks or residual skip paths).

    Args:
        input_shape: (height, width, channels) of the input images. Height and
            width must be divisible by 16, because the encoder halves the
            resolution four times and each decoder stage concatenates with the
            encoder map of the same resolution.
        learning_rate: Adam learning rate. Defaults to the module-level `lr`
            when None.

    Returns:
        A Keras Model, compiled with Adam, binary cross-entropy and accuracy,
        whose output is a single sigmoid channel at the input resolution.
    """
    if learning_rate is None:
        learning_rate = lr

    inputs = Input(shape=input_shape)

    # Encoder
    conv1 = _conv_pair(inputs, 64)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = _conv_pair(pool1, 128)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = _conv_pair(pool2, 256)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = _conv_pair(pool3, 512)
    drop4 = Dropout(0.6)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    # Bottleneck
    conv5 = _conv_pair(pool4, 1024)
    drop5 = Dropout(0.5)(conv5)

    # Decoder: each stage pairs with the encoder map of matching resolution.
    conv6 = _up_block(drop5, drop4, 512)
    conv7 = _up_block(conv6, conv3, 256)
    conv8 = _up_block(conv7, conv2, 128)
    conv9 = _up_block(conv8, conv1, 64)

    # Head: one more 64-channel conv, a 2-channel conv, then a 1x1 sigmoid
    # producing a per-pixel foreground probability.
    conv9 = Conv2D(64, 3, activation='relu', padding='same')(conv9)
    conv9 = Conv2D(2, 3, activation='relu', padding='same')(conv9)
    out = Conv2D(1, 1, activation='sigmoid')(conv9)

    model = Model(inputs=inputs, outputs=out)
    # `learning_rate` is the supported keyword: newer Keras optimizers either
    # ignore the deprecated `lr` (training silently at the default rate) or
    # reject it outright.
    model.compile(optimizer=Adam(learning_rate=learning_rate), loss='binary_crossentropy', metrics=['accuracy'])
    return model


# Reset Keras' global state (layer-name counters, graph) before the model is built.
tf.keras.backend.clear_session()

In [None]:
# Stop when validation loss has not improved for 10 epochs. restore_best_weights
# rolls the model back to the epoch with the lowest val_loss when training stops
# early, instead of keeping the weights from the last (worse) epoch.
early_stop = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

# Input image dimensions; these must match the resize applied when loading data.
img_width = 256
img_height = 256
img_channels = 3

input_shape = (img_height, img_width, img_channels)

# Build and compile the U-Net.
model = multi_res_unet(input_shape)

In [None]:
# Fit on the training split, scoring the held-out split after every epoch;
# early_stop may end training before the full epoch budget is used.
epochs = 120
history = model.fit(
    X_train,
    Y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_data=(X_test, Y_test),
    callbacks=[early_stop],
)

In [None]:
# Plot training/validation accuracy and loss on one set of axes and save the figure.
# The filename records the configured epoch budget, learning rate and batch size.
os.makedirs("Ploted_Data", exist_ok=True)
path_str = f"Ploted_Data/e{epochs}_lr{lr}_bs{batch_size}.png"

curves = [
    ('accuracy', 'Accuracy'),
    ('val_accuracy', 'Validation Accuracy'),
    ('loss', 'Loss'),
    ('val_loss', 'Validation Loss'),
]

fig, ax = plt.subplots()
for key, label in curves:
    ax.plot(history.history[key], label=label)
ax.set_title('Model Accuracy and Loss')
ax.set_ylabel('Accuracy/Loss')
ax.set_xlabel('Epoch')
ax.legend(loc='upper left')
fig.savefig(path_str)
# Close (not just clear) the figure so it does not stay open in pyplot's figure registry.
plt.close(fig)

In [None]:
# Load the unlabelled test images with the same preprocessing as the training
# images (RGB, 256x256, pixel values scaled to [0, 1]).
img_dir = 'Testing'
test = []
# Source filenames, in the same order as `test`, so each saved result{i}.png
# can be traced back to its input image.
test_filenames = []

# Sorted so the image order (and the result numbering) is deterministic.
for filename in sorted(os.listdir(img_dir)):
    if not filename.endswith('.png'):
        continue
    img = Image.open(os.path.join(img_dir, filename)).convert('RGB')
    img = img.resize((256, 256))
    test.append(np.array(img) / 255.0)
    test_filenames.append(filename)

assert test, f"No .png images found in {img_dir!r}"

# Stack into a single (N, 256, 256, 3) array for model.predict.
test = np.array(test)


In [None]:
# Per-pixel foreground probabilities for every test image.
y_prob = model.predict(test)

# Binarize with a single threshold: the mean probability over the whole test set.
# The threshold is computed once, before binarizing, so every pixel is compared
# against the same value. The result keeps y_prob's dtype.
threshold = np.mean(y_prob)
y_pred = (y_prob >= threshold).astype(y_prob.dtype)

# Save each input image next to its predicted mask.
os.makedirs("Ploted_Data", exist_ok=True)
for i in range(len(test)):
    path_str = f"Ploted_Data/result{i}.png"
    fig, axs = plt.subplots(1, 2)
    axs[0].imshow(test[i])
    axs[0].set_title('Input')
    axs[1].imshow(y_pred[i, :, :, 0], cmap='gray')
    axs[1].set_title('Predicted mask')
    fig.savefig(path_str)
    # Close each figure; otherwise a new open figure accumulates per image.
    plt.close(fig)