In [None]:
import os
import sys
import random
import warnings
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import cv2
from tqdm import tqdm_notebook, tnrange
from itertools import chain
from skimage.io import imread, imshow, concatenate_images
from skimage.transform import resize
from skimage.morphology import label
from keras.models import Model, load_model
from keras.layers import Input, Lambda, Conv2D, Conv2DTranspose, MaxPooling2D, concatenate 
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras import backend as K
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from tensorflow.keras.layers import Input, Convolution2D, BatchNormalization, LeakyReLU, Add, ReLU, GlobalAveragePooling2D, AveragePooling2D, UpSampling2D, Activation 
from tensorflow.keras.models import Model

In [None]:
im_width = 512
im_height = 512
im_chan = 1

path_train = '/kaggle/input/small-dataset/Train/'
path_test = '/kaggle/input/small-dataset/Test/'

# Seed Python, NumPy and the TensorFlow backend in one call so random sample
# picks and weight initialisation repeat across Restart & Run All
# (some GPU kernels may still be nondeterministic).
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# os.walk lists files in filesystem order, which is not guaranteed. Sort them so
# the sample order is reproducible. Keras' validation_split holds out the *last*
# 20% of the arrays, so this order decides the train/validation split.
train_ids = sorted(next(os.walk(path_train+"images"))[2])
test_ids = sorted(next(os.walk(path_test+"images"))[2])

In [None]:
# Preallocate the training arrays. Images keep one uint8 channel; masks are
# stored as booleans, so any non-zero mask pixel value becomes True.
X_train = np.zeros((len(train_ids), im_height, im_width, im_chan), dtype=np.uint8)
Y_train = np.zeros((len(train_ids), im_height, im_width, 1), dtype=bool)

print('Getting and resizing train images and masks ... ')

sys.stdout.flush()
for n, id_ in tqdm_notebook(enumerate(train_ids), total=len(train_ids)):
    path = path_train
    # target_size makes the announced "resizing" real: a file that is not
    # im_height x im_width no longer breaks the assignment below. load_img only
    # resamples when the size differs, and its default 'nearest' interpolation
    # keeps mask values unblended.
    img = load_img(path + '/images/' + id_, target_size=(im_height, im_width))

    # Only channel 1 of the loaded image is kept (this assumes im_chan == 1).
    x = img_to_array(img)[:, :, 1]
    x = np.expand_dims(x, axis=-1)
    X_train[n] = x

    mask = img_to_array(load_img(path + '/masks/' + id_, target_size=(im_height, im_width)))[:, :, 1]
    mask = np.expand_dims(mask, axis=-1)
    Y_train[n] = mask

print('Done!')

In [None]:
# Sanity check: show one random training image next to its ground-truth mask.
# random.randint includes BOTH end points, so the upper bound must be len - 1;
# len(train_ids) itself would index past the end of X_train.
ix = random.randint(0, len(train_ids) - 1)
plt.figure(figsize=(10, 5))

# Only two panels are drawn, so use a 1x2 grid (1x3 left the right third empty).
plt.subplot(1, 2, 1)
plt.imshow(np.dstack((X_train[ix], X_train[ix], X_train[ix])))
plt.title("Image")

plt.subplot(1, 2, 2)
# Boolean mask -> float, replicated to 3 channels for display.
gt_mask = np.squeeze(Y_train[ix]).astype(np.float32)
plt.imshow(np.dstack((gt_mask, gt_mask, gt_mask)))
plt.title("Ground truth")

plt.show()

In [None]:
from tensorflow.keras import metrics
from tensorflow.keras.utils import register_keras_serializable

@register_keras_serializable()
class MeanIoUMetric(metrics.Metric):
    """Mean IoU for segmentation models that output probabilities.

    ``metrics.MeanIoU`` expects hard class labels. This wrapper therefore
    thresholds ``y_pred`` (``y_pred > threshold``) and casts it to int32 before
    forwarding it to a wrapped ``MeanIoU`` instance, which holds the state.

    Args:
        num_classes: number of classes for the wrapped ``MeanIoU`` (2 for a
            foreground/background mask).
        name: metric name. Keras reports it under this key in
            ``History.history``, prefixed with ``val_`` for validation.
        threshold: probability cut-off for predicting class 1. Defaults to
            0.5, the value that used to be hard-coded.
        **kwargs: forwarded to ``metrics.Metric`` (e.g. ``dtype``).
    """

    def __init__(self, num_classes, name='mean_iou', threshold=0.5, **kwargs):
        super().__init__(name=name, **kwargs)
        # Stored so get_config() can serialise them.
        self.num_classes = num_classes
        self.threshold = threshold
        self.iou_metric = metrics.MeanIoU(num_classes=num_classes)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Binarise ``y_pred`` and accumulate it into the wrapped MeanIoU."""
        y_pred_ = tf.cast(y_pred > self.threshold, tf.int32)
        # Forward sample_weight instead of silently ignoring it.
        self.iou_metric.update_state(y_true, y_pred_, sample_weight=sample_weight)

    def result(self):
        """Return the wrapped MeanIoU's current value."""
        return self.iou_metric.result()

    def reset_state(self):
        """Clear the wrapped MeanIoU's accumulated state."""
        self.iou_metric.reset_state()

    @classmethod
    def from_config(cls, config):
        # Configs saved before ``threshold`` existed fall back to its 0.5 default.
        return cls(**config)

    def get_config(self):
        config = super().get_config()
        config.update({"num_classes": self.num_classes, "threshold": self.threshold})
        return config


mean_iou_metric = MeanIoUMetric(num_classes=2)

In [None]:
from tensorflow.keras.layers import Dropout
from tensorflow.keras.regularizers import l1_l2

def simplified_conv_block(X, filters, block, dropout_rate=0.3, l1_reg=1e-6, l2_reg=1e-4):
    """Residual block: two 3x3 convs on the main path plus a 1x1 projection shortcut.

    Main path: Conv3x3(f1) -> BN -> LeakyReLU(0.2) -> Dropout
               -> Conv3x3(f2) -> BN -> Dropout.
    Shortcut:  Conv1x1(f2) -> BN, so channel counts match for the Add.
    The sum passes through a ReLU. All convs use 'same' padding, so the spatial
    size is unchanged. Every layer name starts with ``block_<block>_``.

    Args:
        X: input Keras tensor.
        filters: pair ``(f1, f2)`` of output channels for the two 3x3 convs.
        block: identifier embedded in the layer names; it must be unique within
            a model, because Keras layer names must be unique.
        dropout_rate: rate used by every Dropout layer in the block.
        l1_reg, l2_reg: L1/L2 kernel-regularisation factors for each conv.

    Returns:
        Keras tensor with ``f2`` channels.
    """
    prefix = f'block_{block}_'
    first_filters, second_filters = filters

    def conv(n_filters, size, suffix):
        # Each conv gets its own freshly built l1_l2 regulariser.
        return Convolution2D(filters=n_filters, kernel_size=size, padding='same',
                             name=prefix + suffix,
                             kernel_regularizer=l1_l2(l1=l1_reg, l2=l2_reg))

    shortcut = X

    main = conv(first_filters, (3, 3), 'a')(X)
    main = BatchNormalization(name=prefix + 'batch_norm_a')(main)
    main = LeakyReLU(alpha=0.2, name=prefix + 'leakyrelu_a')(main)
    main = Dropout(dropout_rate, name=prefix + 'dropout_a')(main)

    main = conv(second_filters, (3, 3), 'b')(main)
    main = BatchNormalization(name=prefix + 'batch_norm_b')(main)
    main = Dropout(dropout_rate, name=prefix + 'dropout_b')(main)

    shortcut = conv(second_filters, (1, 1), 'skip_conv')(shortcut)
    shortcut = BatchNormalization(name=prefix + 'batch_norm_skip_conv')(shortcut)

    merged = Add(name=prefix + 'add')([main, shortcut])
    return ReLU(name=prefix + 'relu')(merged)

def simplified_base_feature_maps(input_layer, dropout_rate=0.3, l1_reg=1e-6, l2_reg=1e-4):
    """Backbone: two stacked residual blocks (32->64 channels, then 64->128).

    No pooling or striding is used, so the output keeps the input's spatial
    size and has 128 channels.
    """
    features = input_layer
    for block_id, block_filters in (('1', [32, 64]), ('2', [64, 128])):
        features = simplified_conv_block(features, block_filters, block_id,
                                         dropout_rate, l1_reg, l2_reg)
    return features

def simplified_pyramid_feature_maps(input_layer, dropout_rate=0.3, l1_reg=1e-6, l2_reg=1e-4):
    """Simplified pyramid-pooling head on top of the residual backbone.

    Two pooled branches ("yellow": 2x2 pool, "blue": 4x4 pool) are each reduced
    to 32 channels by a 1x1 conv, then bilinearly upsampled back to the backbone
    resolution. They are concatenated with the backbone features
    (128 + 32 + 32 = 192 channels).

    The input height and width must be divisible by 4. Otherwise the
    pooled-then-upsampled branches do not match the backbone shape and the
    concatenation fails.
    """
    base = simplified_base_feature_maps(input_layer, dropout_rate, l1_reg, l2_reg)

    def pooled_branch(colour, scale):
        # AvgPool(scale) -> 1x1 conv (default 'valid' padding) -> Dropout
        # -> bilinear upsample by the same factor.
        branch = AveragePooling2D(pool_size=(scale, scale), name=colour + '_pool')(base)
        branch = Convolution2D(filters=32, kernel_size=(1, 1), name=colour + '_1_by_1',
                               kernel_regularizer=l1_l2(l1=l1_reg, l2=l2_reg))(branch)
        branch = Dropout(dropout_rate, name=colour + '_dropout')(branch)
        return UpSampling2D(size=(scale, scale), interpolation='bilinear',
                            name=colour + '_upsampling')(branch)

    yellow = pooled_branch('yellow', 2)
    blue = pooled_branch('blue', 4)
    return tf.keras.layers.concatenate([base, yellow, blue])

def simplified_last_conv_module(input_layer, dropout_rate=0.3, l1_reg=1e-6, l2_reg=1e-4):
    """Full network body: pyramid features -> 3x3 conv to 1 channel -> BN -> Dropout -> sigmoid.

    Returns a single-channel per-pixel probability map with the same spatial
    size as ``input_layer``.
    """
    features = simplified_pyramid_feature_maps(input_layer, dropout_rate, l1_reg, l2_reg)
    logits = Convolution2D(filters=1, kernel_size=(3, 3), padding='same', name='last_conv_3_by_3',
                           kernel_regularizer=l1_l2(l1=l1_reg, l2=l2_reg))(features)
    logits = BatchNormalization(name='last_conv_3_by_3_batch_norm')(logits)
    # NOTE(review): dropout on the one-channel output logits zeroes ~dropout_rate
    # of the pixels during training (sigmoid(0) = 0.5), which adds heavy noise
    # to the loss. Consider removing this layer.
    logits = Dropout(dropout_rate, name='last_conv_dropout')(logits)
    return Activation('sigmoid', name='last_conv_sigmoid')(logits)

# Build the functional model: single-channel im_height x im_width input ->
# per-pixel sigmoid probability map of the same spatial size.
input_layer = Input(shape=(im_height, im_width, im_chan))
output_layer = simplified_last_conv_module(input_layer)
simplified_pspnet_model = Model(inputs=input_layer, outputs=output_layer)

# Binary cross-entropy on the sigmoid output; progress is tracked with the
# thresholded mean-IoU metric instance mean_iou_metric.
simplified_pspnet_model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=[mean_iou_metric],
)
simplified_pspnet_model.summary()

In [None]:
# Both callbacks monitor val_loss (their default). Training stops after 10
# epochs without improvement, and only the best model is written to PSPNET.keras.
earlystopper = EarlyStopping(patience=10, verbose=2)
checkpointer = ModelCheckpoint('PSPNET.keras', verbose=1, save_best_only=True)
# Train the model built above; `pspnet_model` was never defined (NameError).
# validation_split holds out the LAST 20% of X_train/Y_train (taken before
# shuffling), so the split depends on the order of train_ids.
results = simplified_pspnet_model.fit(X_train, Y_train, validation_split=0.2, batch_size=2, epochs=150,
                                      callbacks=[earlystopper, checkpointer])

In [None]:
import matplotlib.pyplot as plt

colors = {
    "training_loss": "#0072B2",  # Blue
    "validation_loss": "#D55E00",  # Red
    "training_iou": "#009E73",  # Green
    "validation_iou": "#CC79A7"  # Purple
}

history = results.history

# Create exactly one figure. Calling plt.figure() before plt.subplots() would
# leave an extra, empty figure behind.
fig, ax1 = plt.subplots(figsize=(12, 6))

# Left axis: losses.
ax1.plot(history['loss'], label='Training Loss', color=colors["training_loss"], linewidth=2)
ax1.plot(history['val_loss'], label='Validation Loss', color=colors["validation_loss"], linewidth=2, linestyle='--')
ax1.set_xlabel('Epochs', fontsize=14)
ax1.set_ylabel('Loss', fontsize=14, color=colors["training_loss"])
ax1.tick_params(axis='y', labelcolor=colors["training_loss"])

# Right (twin) axis: mean IoU, reported under the metric's name 'mean_iou'.
ax2 = ax1.twinx()
ax2.plot(history['mean_iou'], label='Training Mean IoU', color=colors["training_iou"], linewidth=2)
ax2.plot(history['val_mean_iou'], label='Validation Mean IoU', color=colors["validation_iou"], linewidth=2, linestyle='--')
ax2.set_ylabel('Mean IoU', fontsize=14, color=colors["training_iou"])
ax2.tick_params(axis='y', labelcolor=colors["training_iou"])

# Draw a single dashed grid from the IoU axis, so the two y-scales don't each
# draw their own set of gridlines.
ax1.grid(False)
ax2.grid(True, linestyle='--', alpha=0.7)

ax2.set_title('Training and Validation Metrics', fontsize=16)

# Merge both axes' entries into one legend.
lines, labels = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax2.legend(lines + lines2, labels + labels2, loc='lower left', ncol=2, fontsize=12)

# The reported loss includes the l1_l2 kernel penalties and can exceed 1, so
# raise the upper limit when needed instead of clipping the curves.
ax1.set_ylim(0, max(1.0, max(history['loss'] + history['val_loss'])))

fig.savefig('PSP_Net.png', dpi=600)
plt.show()

In [None]:
# Get and resize test images and masks.
# Images keep one uint8 channel; masks are stored as booleans, so any non-zero
# mask pixel value becomes True.
X_test = np.zeros((len(test_ids), im_height, im_width, im_chan), dtype=np.uint8)
Y_test = np.zeros((len(test_ids), im_height, im_width, 1), dtype=bool)

print('Getting and resizing test images and masks ... ')
sys.stdout.flush()

for n, id_ in tqdm_notebook(enumerate(test_ids), total=len(test_ids)):
    path = path_test
    # target_size makes the announced "resizing" real: a file that is not
    # im_height x im_width no longer breaks the assignment below. load_img only
    # resamples when the size differs, and its default 'nearest' interpolation
    # keeps mask values unblended.
    img = load_img(path + '/images/' + id_, target_size=(im_height, im_width))
    # Only channel 1 of the loaded image is kept (this assumes im_chan == 1).
    x = img_to_array(img)[:, :, 1]
    x = np.expand_dims(x, axis=-1)
    X_test[n] = x

    mask = img_to_array(load_img(path + '/masks/' + id_, target_size=(im_height, im_width)))[:, :, 1]
    mask = np.expand_dims(mask, axis=-1)
    Y_test[n] = mask

print('Done!')

In [None]:
# Reload the saved model. This is presumably the best checkpoint written by the
# ModelCheckpoint('PSPNET.keras') callback, and it assumes the notebook's
# working directory is /kaggle/working. Then predict test masks and binarise
# them at 0.5.
# NOTE(review): the model contains no Lambda layers, so safe_mode=False is
# probably unnecessary; consider dropping it so arbitrary code cannot be
# deserialised from the file.
model = tf.keras.models.load_model(
    '/kaggle/working/PSPNET.keras',
    safe_mode=False,
    custom_objects={'MeanIoUMetric': MeanIoUMetric},
)

preds_test = model.predict(X_test, batch_size=2, verbose=1)
preds_test_t = np.greater(preds_test, 0.5).astype(np.uint8)

In [None]:
# Sanity check: compare image, ground truth and prediction for a few random test samples.

import matplotlib.pyplot as plt
import numpy as np
import random

 
# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*', and 3.8
# removed the old names (plt.style.use('seaborn-whitegrid') raises there).
# Pick whichever name this installation provides.
PRED_GRID_STYLE = ('seaborn-v0_8-whitegrid'
                   if 'seaborn-v0_8-whitegrid' in plt.style.available
                   else 'seaborn-whitegrid')

# Scope the style to this figure instead of changing it for the whole session.
with plt.style.context(PRED_GRID_STYLE):
    fig, axes = plt.subplots(3, 3, figsize=(15, 12))
    fig.subplots_adjust(hspace=0.3, wspace=-0.5)

    # Column headers on the first row only.
    for col, title in enumerate(("Image", "Ground Truth", "Predicted")):
        axes[0, col].set_title(title, fontsize=16)

    for i in range(3):
        ix = random.randint(0, len(preds_test_t) - 1)

        # Image (single channel replicated to RGB), ground truth, prediction.
        axes[i, 0].imshow(np.dstack((X_test[ix], X_test[ix], X_test[ix])))
        axes[i, 1].imshow(Y_test[ix], cmap='inferno')
        axes[i, 2].imshow(preds_test_t[ix], cmap='inferno')

        for ax in axes[i]:
            ax.axis('off')

    fig.savefig("prediction_visualization.png", dpi=600)
    plt.show()