# Libraries and setup

In [None]:
import pandas as pd
import numpy as np
import os
import cv2
import gc
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
from tqdm.notebook import tqdm
from datetime import datetime
import json,itertools
from typing import Optional
from glob import glob
import warnings
warnings.filterwarnings("ignore")
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import matplotlib as mpl
from sklearn.model_selection import StratifiedKFold, KFold, StratifiedGroupKFold
import random

from tensorflow import keras
import tensorflow as tf
import keras
from keras import backend as K
from keras.models import Model
from keras.losses import binary_crossentropy
from keras.callbacks import Callback, ModelCheckpoint, EarlyStopping
from keras.models import load_model, save_model
from matplotlib.patches import Rectangle
from tensorflow.keras import layers


**Reproducibility**

In [None]:
# Set random seeds
def set_seed(seed=0):
    """Seed every global RNG this notebook draws from.

    Covers NumPy's global generator (used e.g. by DataGenerator's epoch shuffling),
    Python's ``random`` module and TensorFlow's global seed.

    Parameters
    ----------
    seed : int, default 0
        Value handed to each of the three seeding functions.
    """
    for seed_fn in (np.random.seed, random.seed, tf.random.set_seed):
        seed_fn(seed)


set_seed(seed=42)

In [None]:
print("\n... ACCELERATOR SETUP STARTING ...\n")

# Try to attach to a TPU; the except clause treats ValueError as "no TPU available".
try:
    TPU = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:
    TPU = None

if TPU:
    print(f"\n... RUNNING ON TPU - {TPU.master()}...")
    tf.config.experimental_connect_to_cluster(TPU)
    tf.tpu.experimental.initialize_tpu_system(TPU)
    # tf.distribute.TPUStrategy is the stable API; the .experimental alias is deprecated.
    strategy = tf.distribute.TPUStrategy(TPU)
else:
    print("\n... RUNNING ON CPU/GPU ...")
    # Default strategy: ops and variables use TensorFlow's normal device placement.
    strategy = tf.distribute.get_strategy()

N_REPLICAS = strategy.num_replicas_in_sync
print(f"... # OF REPLICAS: {N_REPLICAS} ...\n")
print("\n... ACCELERATOR SETUP COMPLETED ...\n")

**Config**

In [None]:
# Training hyper-parameters
BATCH_SIZE = 16   # samples per generator batch
EPOCHS = 30       # NOTE(review): reassigned to 25 in the optimizer/loss cell further down
# Cross-validation: number of folds and which fold (numbered from 1) is held out for validation
n_splits, fold_selected = 5, 2
# Network input resolution in pixels (square)
IMAGE_WIDTH = IMAGE_HEIGHT = 256

# Data

To import the dataset on Kaggle, click **+ Add data** in the top-right corner, go to **Competitions**, and choose the `uw-madison-gi-tract-image-segmentation` dataset.

In [None]:
print("\n... DATA ACCESS SETUP STARTED ...\n")

# Read the competition data from the Kaggle input mount; no custom SavedModel
# SaveOptions/LoadOptions are used in this configuration.
DATA_DIR = "/kaggle/input/uw-madison-gi-tract-image-segmentation"
save_locally, load_locally = None, None

# Disabled TPU variant, kept for reference: reads the data from the competition's
# GCS path and routes SavedModel I/O through the local host.
# if TPU:
#     DATA_DIR = KaggleDatasets().get_gcs_path('uw-madison-gi-tract-image-segmentation')
#     save_locally = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
#     load_locally = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')

**Train set**

In [None]:
# Training labels in long format: restructure() below expects three consecutive rows per slice id.
TRAIN_DIR, TRAIN_CSV = (os.path.join(DATA_DIR, name) for name in ("train", "train.csv"))
train_df = pd.read_csv(TRAIN_CSV)
print(train_df.shape)
train_df.head()

**Test set**

In [None]:
# When sample_submission.csv has no rows, run in DEBUG mode: the last N_DEBUG_ROWS rows of
# train.csv act as a stand-in test set. The truncation cell further down
# (train_df[:115488-300]) removes the same rows from training as long as train.csv has 115488 rows.
N_DEBUG_ROWS = 300  # keep a multiple of 3: restructure() groups rows in triples per id

TEST_CSV = os.path.join(DATA_DIR, 'sample_submission.csv')
test_df = pd.read_csv(TEST_CSV)

if len(test_df) == 0:
    print("Debug is true")
    DEBUG = True
    # .copy() detaches the rows from train_df so later column assignments on test_df
    # (e.g. in preprocessing) act on an independent frame rather than a slice.
    test_df = train_df.iloc[-N_DEBUG_ROWS:].copy()
else:
    DEBUG = False

# Submission frame: same rows as test_df plus an empty 'prediction' column to fill later.
submission = test_df.copy()
submission["segmentation"] = ''
submission = submission.rename(columns={"segmentation": "prediction"})
print(test_df.head())
print(submission.head())

# Preprocessing

In [None]:
def preprocessing(df, subset="train", data_dir=None):
    """Parse case/day/slice from ``id`` and attach each row's scan path and image size.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with an ``id`` column of the form ``case<N>_day<M>_slice_<SSSS>``.
        It is not modified; a new frame is returned.
    subset : {"train", "test"}, default "train"
        Which image folder to search. The train folder is also used when the
        notebook-level ``DEBUG`` flag is set (the debug "test" set comes from train.csv).
    data_dir : str, optional
        Competition root containing ``train/`` and ``test/``. Defaults to the
        notebook-level ``DATA_DIR`` so every path comes from one setting.

    Returns
    -------
    pandas.DataFrame
        Input columns plus ``case`` (int), ``day`` (int), ``slice`` (str), ``path`` (str),
        and ``width``/``height`` (int, the first two numbers after the slice index in
        the file name). Rows whose image file is not found are dropped (inner merge),
        while row order of ``df`` is otherwise preserved.

    Raises
    ------
    FileNotFoundError
        If no ``.png`` files exist under the selected folder.
    """
    if data_dir is None:
        data_dir = DATA_DIR
    use_train_dir = subset == "train" or DEBUG
    image_root = os.path.join(data_dir, "train" if use_train_dir else "test")

    all_images = glob(os.path.join(image_root, "**", "*.png"), recursive=True)
    if not all_images:
        raise FileNotFoundError(f"no .png images found under {image_root}")

    df = df.copy()  # never mutate the caller's frame (it may be a slice of another one)
    df["case"] = df["id"].apply(lambda x: int(x.split("_")[0].replace("case", "")))
    df["day"] = df["id"].apply(lambda x: int(x.split("_")[1].replace("day", "")))
    df["slice"] = df["id"].apply(lambda x: x.split("_")[3])

    # Join key: "<root>/caseN/caseN_dayM/scans/slice_SSSS", i.e. the image path
    # without the trailing "_<w>_<h>_<sx>_<sy>.png" part.
    df["path_partial"] = [
        os.path.join(image_root, f"case{c}", f"case{c}_day{d}", "scans", f"slice_{s}")
        for c, d, s in zip(df["case"], df["day"], df["slice"])
    ]
    images = pd.DataFrame({"path": all_images})
    images["path_partial"] = [p.rsplit("_", 4)[0] for p in all_images]

    df = df.merge(images, on="path_partial").drop(columns=["path_partial"])
    # File names end in "_<width>_<height>_<sx>_<sy>.png"; strip ".png" and split from the right.
    df["width"] = df["path"].apply(lambda p: int(p[:-4].rsplit("_", 4)[1]))
    df["height"] = df["path"].apply(lambda p: int(p[:-4].rsplit("_", 4)[2]))
    return df

In [None]:
def restructure(df, subset="train"):
    """Pivot the long per-class frame (three rows per slice id) into one row per id.

    Assumes rows come in consecutive triples per id, ordered large_bowel, small_bowel,
    stomach (as in train.csv). This ordering is not checked.

    Parameters
    ----------
    df : pandas.DataFrame
        Output of ``preprocessing``: ``id``, ``path``, ``case``, ``day``, ``slice``,
        ``width``, ``height`` and, when masks are available, ``segmentation``.
    subset : {"train", "test"}, default "train"
        ``"train"`` requires ``segmentation`` and adds a ``count`` column.

    Returns
    -------
    pandas.DataFrame
        One row per id with a fresh RangeIndex and missing values filled with ''.
        When ``segmentation`` is present, there is one RLE column per class.
        For train, ``count`` holds the number of non-empty class masks per slice.
    """
    classes = ["large_bowel", "small_bowel", "stomach"]
    df_out = pd.DataFrame({'id': df['id'][::3]})

    # Copy masks whenever they exist: always for train, and also for a test frame cut
    # from train.csv (DEBUG). A frame without a 'segmentation' column is left mask-less
    # instead of raising KeyError.
    if subset == "train" or "segmentation" in df.columns:
        for offset, name in enumerate(classes):
            df_out[name] = df['segmentation'][offset::3].values

    for col in ['path', 'case', 'day', 'slice', 'width', 'height']:
        df_out[col] = df[col][::3].values

    df_out = df_out.reset_index(drop=True).fillna('')
    if subset == "train":
        df_out['count'] = (df_out[classes] != '').sum(axis=1).values

    return df_out

In [None]:
print(train_df.shape)
# Drop the last 300 rows (the rows the Test set cell uses as its DEBUG stand-in).
# The absolute bound keeps this cell idempotent on re-run; 115488 is presumably len(train.csv).
train_df = train_df.iloc[:115488 - 300]

In [None]:
# Attach case/day/slice, image path and image size to every training row.
train_df = preprocessing(df=train_df, subset="train")
print(train_df.shape)
train_df.head()

In [None]:
# Same enrichment for the test rows.
test_df = preprocessing(df=test_df, subset="test")
print(test_df.shape)
test_df.head()

In [None]:
# One row per slice id with one RLE column per organ.
train_df = restructure(df=train_df, subset="train")
train_df.head()

In [None]:
# One row per slice id for the test rows as well.
test_df = restructure(df=test_df, subset="test")
test_df.head()

In [None]:
# Remove mislabeled training data: every slice of case 7 / day 0 and case 81 / day 30.
is_mislabeled = (
    ((train_df['case'] == 7) & (train_df['day'] == 0))
    | ((train_df['case'] == 81) & (train_df['day'] == 30))
)
train_df = train_df[~is_mislabeled].reset_index(drop=True)

In [None]:
# Garbage collection
# Reclaim memory from the intermediate frames replaced by the preprocessing cells above.
gc.collect()

In [None]:
print(*(frame.shape for frame in (train_df, test_df)))

# Helper functions

In [None]:
def rle_decode(mask_rle, shape, color=1):
    '''Decode a run-length-encoded mask into a dense float32 array.

    Parameters
    ----------
    mask_rle : str or None/NaN
        Space-separated "start length" pairs. Starts are 1-based indices into the
        row-major flattened (height, width) grid. An empty string, ``None`` or a
        float NaN (a missing CSV cell) decodes to an all-background mask.
    shape : tuple
        ``(height, width)`` or ``(height, width, depth)`` of the output. With a depth
        axis, every channel of a covered pixel is set to ``color``.
    color : number, default 1
        Value written into covered pixels; background stays 0.

    Returns
    -------
    numpy.ndarray
        float32 array of the requested ``shape``.
    '''
    if len(shape) == 3:
        h, w, d = shape
        img = np.zeros((h * w, d), dtype=np.float32)
    else:
        h, w = shape
        img = np.zeros((h * w,), dtype=np.float32)

    # Missing annotations (None / NaN) mean "no mask" rather than an error.
    if mask_rle is None or (isinstance(mask_rle, float) and np.isnan(mask_rle)):
        return img.reshape(shape)

    s = np.array(mask_rle.split(), dtype=int)
    starts = s[0::2] - 1  # even positions: 1-based starts -> 0-based
    lengths = s[1::2]     # odd positions: run lengths
    ends = starts + lengths
    for lo, hi in zip(starts, ends):
        img[lo:hi] = color
    return img.reshape(shape)

In [None]:
# Run-length encoding (inverse of rle_decode)
def rle_encode(img):
    '''Run-length encode a binary mask.

    img: numpy array, 1 - mask, 0 - background (flattened in row-major order)
    Returns "start length start length ..." with 1-based starts; '' for an empty mask.
    '''
    # Zero padding at both ends guarantees every run has a visible start and end edge.
    padded = np.concatenate([[0], img.flatten(), [0]])
    edges = np.where(padded[1:] != padded[:-1])[0] + 1
    # Odd slots hold run ends; turn them into run lengths in place.
    edges[1::2] -= edges[::2]
    return ' '.join(map(str, edges))

In [None]:
# Metrics
def dice_coef(y_true, y_pred, smooth=1):
    """Soft Dice coefficient pooled over the whole batch (all elements flattened together).

    Returns (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth); ``smooth`` keeps
    the ratio defined when both inputs are empty.
    """
    flat_true = K.flatten(y_true)
    flat_pred = K.flatten(y_pred)
    overlap = K.sum(flat_true * flat_pred)
    total = K.sum(flat_true) + K.sum(flat_pred)
    return (2. * overlap + smooth) / (total + smooth)

def iou_coef(y_true, y_pred, smooth=1):
    """Soft IoU computed per sample over axes 1-3, then averaged over the batch axis.

    Requires rank-4 inputs, e.g. (batch, height, width, classes).
    """
    sample_axes = [1, 2, 3]
    inter = K.sum(K.abs(y_true * y_pred), axis=sample_axes)
    union = K.sum(y_true, sample_axes) + K.sum(y_pred, sample_axes) - inter
    per_sample = (inter + smooth) / (union + smooth)
    return K.mean(per_sample, axis=0)

def dice_loss(y_true, y_pred):
    """Soft Dice loss: ``1 - dice_coef(y_true, y_pred, smooth=1.)``.

    Delegates to ``dice_coef`` instead of repeating its formula, so the reported
    metric and the optimised quantity cannot drift apart.
    """
    return 1. - dice_coef(y_true, y_pred, smooth=1.)

def bce_dice_loss(y_true, y_pred):
    """Binary cross-entropy plus Dice loss, with the target cast to float32 once for both terms."""
    y_true_f32 = tf.cast(y_true, tf.float32)
    bce = binary_crossentropy(y_true_f32, y_pred)
    return bce + dice_loss(y_true_f32, y_pred)

def tversky(y_true, y_pred, alpha=0.7, smooth=1):
    """Tversky index between target and prediction, pooled over the whole batch.

    TI = (TP + smooth) / (TP + alpha * FN + (1 - alpha) * FP + smooth)

    Parameters
    ----------
    y_true, y_pred : tensors of the same size
        Flattened before summing, so any shape works.
    alpha : float, default 0.7
        Weight on false negatives; ``1 - alpha`` weights false positives. The
        default 0.7 penalises missed foreground more than spurious foreground.
        (Previously hard-coded; Keras still calls this with two arguments.)
    smooth : float, default 1
        Added to numerator and denominator so empty masks do not give 0/0.
    """
    y_true_pos = K.flatten(y_true)
    y_pred_pos = K.flatten(y_pred)
    true_pos = K.sum(y_true_pos * y_pred_pos)
    false_neg = K.sum(y_true_pos * (1 - y_pred_pos))
    false_pos = K.sum((1 - y_true_pos) * y_pred_pos)
    return (true_pos + smooth) / (true_pos + alpha * false_neg + (1 - alpha) * false_pos + smooth)

def focal_tversky(y_true, y_pred, gamma=0.75):
    """Focal Tversky loss: ``(1 - tversky(y_true, y_pred)) ** gamma``.

    Parameters
    ----------
    gamma : float, default 0.75
        Focal exponent applied to the Tversky loss (previously hard-coded).
        Keras calls the loss with two arguments, so the default is what training uses.
    """
    pt_1 = tversky(y_true, y_pred)
    return K.pow((1 - pt_1), gamma)

def tversky_loss(y_true, y_pred):
    """Loss form of the Tversky index: ``1 - tversky(y_true, y_pred)`` (0 for a perfect match)."""
    score = tversky(y_true, y_pred)
    return 1 - score


In [None]:
class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that loads slices from disk and, in train mode, decodes their masks.

    Each item is a batch ``X`` of shape ``(n, height, width, 3)``. It holds the
    min-max-normalised grayscale slice replicated into 3 channels, matching the
    model's 3-channel input. With ``subset="train"``, it also yields ``y`` of the
    same shape with one binary mask per channel, in the order large_bowel,
    small_bowel, stomach.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per slice with ``path``, ``width``, ``height`` and, in train mode,
        the three RLE columns.
    batch_size : int
        Samples per batch.
    subset : str, default "train"
        ``"train"`` yields ``(X, y)``; any other value yields ``X`` only.
    shuffle : bool, default False
        Reshuffle the sample order at the end of every epoch.
    width, height : int
        Output size in pixels of images and masks.
    drop_last : bool, default True
        Skip a trailing partial batch (the original behaviour). Pass False for
        validation/prediction so that every row is used.
    """

    def __init__(self, df, batch_size = BATCH_SIZE, subset="train", shuffle=False, width=IMAGE_WIDTH, height=IMAGE_HEIGHT, drop_last=True):
        super().__init__()
        self.df = df
        self.shuffle = shuffle
        self.subset = subset
        self.batch_size = batch_size
        self.indexes = np.arange(len(df))  # positional row indices into df
        self.width = width
        self.height = height
        self.drop_last = drop_last
        self.on_epoch_end()

    def __augment_image(self, X, y):
        # Random vertical/horizontal flips applied identically to images and masks.
        # NOTE(review): currently unused - the call in __getitem__ is commented out.
        updown_rand = np.random.rand()
        leftright_rand = np.random.rand()
        if (updown_rand > 0.5):
            X = tf.image.flip_up_down(X)
            y = tf.image.flip_up_down(y)
        if (leftright_rand > 0.5):
            X = tf.image.flip_left_right(X)
            y = tf.image.flip_left_right(y)
        return X, y

    def __len__(self):
        """Number of batches per epoch."""
        n_batches = len(self.df) / self.batch_size
        return int(np.floor(n_batches) if self.drop_last else np.ceil(n_batches))

    def on_epoch_end(self):
        """Reshuffle the positional sample order in place when ``shuffle`` is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __getitem__(self, index):
        batch = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        n = len(batch)
        # (batch, rows=height, cols=width, channels): matches what cv2.resize returns.
        X = np.empty((n, self.height, self.width, 3))
        y = np.empty((n, self.height, self.width, 3)) if self.subset == 'train' else None
        for i, row_pos in enumerate(batch):
            row = self.df.iloc[row_pos]
            # (H, W, 1) grayscale broadcast into all 3 input channels.
            X[i] = self.__load_grayscale(row['path'])
            if y is not None:
                for k, organ in enumerate(("large_bowel", "small_bowel", "stomach")):
                    mask = rle_decode(row[organ], shape=(row['height'], row['width'], 1))
                    # Nearest-neighbour keeps the resized mask binary; bilinear would
                    # create fractional values along the organ borders.
                    y[i, :, :, k] = cv2.resize(
                        mask, (self.width, self.height), interpolation=cv2.INTER_NEAREST
                    )
        if y is not None:
            #return self.__augment_image(X, y)
            return X, y
        return X

    def __load_grayscale(self, img_path):
        """Read one slice at full bit depth, resize it and min-max scale it to [0, 1] as float32 (H, W, 1)."""
        img = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH)
        if img is None:  # OpenCV signals unreadable/missing files by returning None
            raise FileNotFoundError(f"could not read image: {img_path}")
        img = cv2.resize(img, (self.width, self.height)).astype(np.float32)
        lo, hi = img.min(), img.max()
        # A constant image (e.g. an all-black slice) would otherwise give 0/0 = NaN.
        img = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)
        return np.expand_dims(img, axis=-1)

# EDA

In [None]:
def open_gray16(_path, normalize=True, to_rgb=False):
    """Read a grayscale image at its native bit depth (``cv2.IMREAD_ANYDEPTH``).

    Parameters
    ----------
    _path : str
        Image file path.
    normalize : bool, default True
        Divide the raw pixel values by 255. NOTE(review): for 16-bit data, values
        above 255 stay above 1.0. Callers that need [0, 1] must rescale (examine_id
        min-max scales its overlay image).
    to_rgb : bool, default False
        Replicate the single channel into shape (H, W, 3).

    Returns
    -------
    numpy.ndarray
        (H, W) or (H, W, 3) array; float when ``normalize`` is True.

    Raises
    ------
    FileNotFoundError
        If OpenCV cannot read the file (``cv2.imread`` returns None rather than raising).
    """
    img = cv2.imread(_path, cv2.IMREAD_ANYDEPTH)
    if img is None:
        raise FileNotFoundError(f"could not read image: {_path}")
    if normalize:
        img = img / 255.
    if to_rgb:
        img = np.tile(np.expand_dims(img, axis=-1), 3)
    return img

In [None]:
def examine_id(DEMO_ID, seg_masks=False):
    """Show one training slice: its row, the image, optional per-class masks and an RGB overlay.

    Parameters
    ----------
    DEMO_ID : str
        Value of ``train_df["id"]`` to inspect, e.g. ``"case123_day20_slice_0082"``.
    seg_masks : bool, default False
        If True, also plot each non-empty class mask in its own panel.

    Raises
    ------
    KeyError
        If ``DEMO_ID`` does not occur in ``train_df``.
    """
    seg_types = ["large_bowel", "small_bowel", "stomach"]
    matches = train_df[train_df.id == DEMO_ID]
    if matches.empty:
        raise KeyError(f"id {DEMO_ID!r} not found in train_df")
    demo_ex = matches.iloc[0]
    display(demo_ex.to_frame())
    mask_shape = (demo_ex["height"], demo_ex["width"])

    def _has_mask(rle):
        # Missing RLEs are '' after restructure()'s fillna, or NaN in raw CSV rows.
        return isinstance(rle, str) and rle.strip() != ""

    print("\n\n...IMAGE ...\n")
    plt.figure(figsize=(6, 6))
    plt.imshow(open_gray16(demo_ex["path"]), cmap="gray")
    plt.title(f"Original Grayscale Image For ID: {demo_ex['id']}", fontweight="bold")
    plt.axis(False)
    plt.show()

    if seg_masks:
        print("\n\n... 3 SEGMENTATION MASKS ...\n")
        plt.figure(figsize=(14, 7))
        for i, seg_type in enumerate(seg_types):
            if not _has_mask(demo_ex[seg_type]):
                continue
            plt.subplot(1, 3, i + 1)
            plt.imshow(rle_decode(demo_ex[seg_type], shape=mask_shape, color=1))
            plt.title(f"RLE Encoding For {seg_type} Segmentation", fontweight="bold")
            plt.axis(False)
        plt.tight_layout()
        plt.show()

    print("\n\n...IMAGE WITH AN RGB SEGMENTATION MASK OVERLAY ...\n")
    img = open_gray16(demo_ex["path"], to_rgb=True)
    lo, hi = img.min(), img.max()
    # Min-max scale to [0, 1]; a constant image would otherwise divide by zero.
    img = ((img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)).astype(np.float32)
    # One channel per class: R = large bowel, G = small bowel, B = stomach.
    seg_rgb = np.stack([
        rle_decode(demo_ex[t], shape=mask_shape, color=1) if _has_mask(demo_ex[t])
        else np.zeros(mask_shape, dtype=np.float32)
        for t in seg_types
    ], axis=-1).astype(np.float32)
    seg_overlay = cv2.addWeighted(src1=img, alpha=0.99, src2=seg_rgb, beta=0.33, gamma=0.0)
    # The weighted sum can exceed 1.0; clip to the valid float image range for imshow.
    seg_overlay = np.clip(seg_overlay, 0.0, 1.0)

    plt.figure(figsize=(6, 6))
    plt.imshow(seg_overlay)
    plt.title(f"Segmentation Overlay For ID: {demo_ex['id']}", fontweight="bold")
    handles = [Rectangle((0, 0), 1, 1, color=_c) for _c in [(0.667, 0.0, 0.0), (0.0, 0.667, 0.0), (0.0, 0.0, 0.667)]]
    labels = ["Large Bowel Segmentation Map", "Small Bowel Segmentation Map", "Stomach Segmentation Map"]
    plt.legend(handles, labels)
    plt.axis(False)
    plt.show()


In [None]:
print("\n... SINGLE ID EXPLORATION ...\n\n")
DEMO_ID = "case123_day20_slice_0082"  # any id present in train_df can be inspected
examine_id(DEMO_ID=DEMO_ID, seg_masks=True)

# Cross-validation

In [None]:
# Group by case id: StratifiedGroupKFold keeps every slice of a case in the same fold
# while stratifying on 'count' (number of non-empty organ masks per slice).
skf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
fold_splits = skf.split(X=train_df, y=train_df['count'], groups=train_df['case'])
for fold, (_, val_idx) in enumerate(fold_splits, start=1):
    # val_idx is positional; train_df has a fresh RangeIndex, so positions equal labels.
    train_df.loc[val_idx, 'fold'] = fold
train_df['fold'] = train_df['fold'].astype(np.uint8)

in_valid = train_df["fold"] == fold_selected
train_ids = train_df.index[~in_valid]
valid_ids = train_df.index[in_valid]

X_train = train_df[~in_valid]
X_valid = train_df[in_valid]

train_df.groupby('fold').size()

In [None]:
# Fold sizes, broken down by number of non-empty masks per slice
train_df.groupby(['fold', 'count'])['id'].agg('count')

# Model

In [None]:
# Data generators: only the training split is reshuffled each epoch.
train_generator = DataGenerator(X_train, shuffle=True)
val_generator = DataGenerator(X_valid, shuffle=False)

In [None]:
def convolution_block(
    block_input,
    num_filters=256,
    kernel_size=3,
    dilation_rate=1,
    padding="same",
    use_bias=False,
):
    """Conv2D -> BatchNormalization -> ReLU building block of the DeepLabV3+ head.

    Parameters
    ----------
    block_input : Keras tensor
        Input feature map.
    num_filters : int, default 256
        Output channels.
    kernel_size : int, default 3
        Convolution kernel size.
    dilation_rate : int, default 1
        Atrous rate of the convolution.
    padding : str, default "same"
        Forwarded to Conv2D (previously accepted but ignored).
    use_bias : bool, default False
        Whether Conv2D adds a bias term.

    Returns
    -------
    Keras tensor
    """
    x = layers.Conv2D(
        num_filters,
        kernel_size=kernel_size,
        dilation_rate=dilation_rate,
        padding=padding,
        use_bias=use_bias,
        # tf.keras initializer to match the tensorflow.keras `layers` module used here.
        kernel_initializer=tf.keras.initializers.HeNormal(),
    )(block_input)
    x = layers.BatchNormalization()(x)
    # A ReLU layer rather than tf.nn.relu on a symbolic tensor keeps the functional
    # graph made only of Keras layers (raw TF ops on KerasTensors fail under Keras 3).
    return layers.ReLU()(x)


def DilatedSpatialPyramidPooling(dspp_input):
    """Atrous spatial pyramid pooling head of DeepLabV3+.

    Builds five parallel branches from ``dspp_input``:
    - an image-level average-pool branch, upsampled back to the input grid;
    - a 1x1 convolution;
    - three 3x3 atrous convolutions with rates 6, 12 and 18.
    The branches are concatenated and fused with a final 1x1 ``convolution_block``.

    Parameters
    ----------
    dspp_input : Keras tensor of shape (batch, H, W, C)

    Returns
    -------
    Keras tensor of shape (batch, H, W, 256), the default ``num_filters`` of convolution_block.
    """
    height, width = dspp_input.shape[-3], dspp_input.shape[-2]

    # Image-level features: average over the whole map, 1x1 conv, upsample back.
    pooled = layers.AveragePooling2D(pool_size=(height, width))(dspp_input)
    pooled = convolution_block(pooled, kernel_size=1, use_bias=True)
    branches = [
        layers.UpSampling2D(
            size=(height // pooled.shape[1], width // pooled.shape[2]), interpolation="bilinear",
        )(pooled)
    ]

    # Parallel branches with growing dilation rates.
    for kernel, rate in ((1, 1), (3, 6), (3, 12), (3, 18)):
        branches.append(convolution_block(dspp_input, kernel_size=kernel, dilation_rate=rate))

    merged = layers.Concatenate(axis=-1)(branches)
    return convolution_block(merged, kernel_size=1)


In [None]:
def DeeplabV3Plus(image_size, num_classes, dropout = 0.2):
    """Build a DeepLabV3+ segmentation model on an ImageNet-pretrained ResNet50 encoder.

    Parameters
    ----------
    image_size : int
        Height and width of the square input; the input tensor is (image_size, image_size, 3).
    num_classes : int
        Output channels, each with its own sigmoid (independent per-class masks).
    dropout : float, default 0.2
        Rate of the dropouts on both encoder taps. The dropout before the classifier uses half of it.

    Returns
    -------
    tf.keras.Model
        Maps (batch, image_size, image_size, 3) to (batch, image_size, image_size, num_classes).
    """
    # Everything is built from tf.keras so the input tensor, ResNet50 and the layers
    # come from the same Keras implementation.
    model_input = tf.keras.Input(shape=(image_size, image_size, 3))
    resnet50 = tf.keras.applications.ResNet50(
        weights="imagenet", include_top=False, input_tensor=model_input
    )

    # Deep encoder features -> ASPP -> upsample to 1/4 of the input resolution.
    x = resnet50.get_layer("conv4_block6_2_relu").output
    x = tf.keras.layers.Dropout(dropout)(x)
    x = DilatedSpatialPyramidPooling(x)
    input_a = layers.UpSampling2D(
        size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]),
        interpolation="bilinear",
    )(x)

    # Shallow encoder features (skip connection) reduced to 48 channels.
    input_b = resnet50.get_layer("conv2_block3_2_relu").output
    input_b = tf.keras.layers.Dropout(dropout)(input_b)
    input_b = convolution_block(input_b, num_filters=48, kernel_size=1)

    # Decoder: fuse both paths, refine, upsample to full resolution, classify per pixel.
    x = layers.Concatenate(axis=-1)([input_a, input_b])
    x = convolution_block(x)
    x = convolution_block(x)
    x = layers.UpSampling2D(
        size=(image_size // x.shape[1], image_size // x.shape[2]),
        interpolation="bilinear",
    )(x)
    x = tf.keras.layers.Dropout(dropout / 2)(x)
    model_output = layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
    model_output = layers.Activation(tf.keras.activations.sigmoid)(model_output)

    return tf.keras.Model(inputs=model_input, outputs=model_output)


# Build under the distribution strategy so variables are created for all replicas on TPU;
# with the default CPU/GPU strategy this scope is a no-op.
with strategy.scope():
    model = DeeplabV3Plus(image_size=IMAGE_HEIGHT, num_classes=3)
model.summary()


In [None]:
# Optimiser, loss and metrics passed to model.compile() in the next cell.
OPTIMIZER = tf.keras.optimizers.Adam(learning_rate=0.00075)
LOSS = focal_tversky  # alternative tried: bce_dice_loss
METRICS = [dice_coef, iou_coef]

# NOTE(review): overrides EPOCHS = 30 from the Config cell; keep a single definition there.
EPOCHS = 25

In [None]:
model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=METRICS)

# Save the full model whenever val_loss improves; the Evaluation cell reloads it from
# './deep_labv3_model'.
checkpoint = ModelCheckpoint(
    'deep_labv3_model',
    monitor='val_loss',
    verbose=1,
    save_best_only=True,
    save_weights_only=False,
    mode='auto',
)

# NOTE(review): defined but not passed to fit(), so training always runs all EPOCHS.
# Add it to `callbacks` to stop after 5 epochs without val_loss improvement.
early_stopping = EarlyStopping(
    patience=5,
    min_delta=0.0001,
    restore_best_weights=True,
)

LR_DECAY_START = 18  # epochs trained at the initial learning rate before decay starts


def scheduler(epoch, lr):
    """Keep ``lr`` for the first LR_DECAY_START epochs, then multiply it by exp(-0.1) every epoch.

    Returns a plain Python float, which LearningRateScheduler accepts in every Keras version.
    """
    if epoch < LR_DECAY_START:
        return lr
    return float(lr * np.exp(-0.1))


lr_scheduler = tf.keras.callbacks.LearningRateScheduler(scheduler)

history = model.fit(
    train_generator,
    validation_data=val_generator,
    callbacks=[checkpoint, lr_scheduler],
    use_multiprocessing=False,
    workers=4,
    epochs=EPOCHS
)


In [None]:
# History: persist per-epoch metrics, then plot train vs. validation curves.
hist_df = pd.DataFrame(history.history)
hist_df.to_csv('history.csv')

epochs_run = range(history.epoch[-1] + 1)
# (history key, train label, validation label, panel title)
panels = [
    ('loss', 'Train_Loss', 'Val_loss', 'LOSS'),
    ('dice_coef', 'Train_dice_coef', 'Val_dice_coef', 'DICE'),
    ('iou_coef', 'Train_iou_coef', 'Val_iou_coef', 'IOU'),
]

plt.figure(figsize=(15, 5))
for position, (key, train_label, val_label, title) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.plot(epochs_run, history.history[key], label=train_label)
    plt.plot(epochs_run, history.history['val_' + key], label=val_label)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(key)
    plt.legend()
plt.show()

# Evaluation

In [None]:
# The saved model refers to these callables by name; load_model needs them to rebuild
# the compiled loss and metrics.
custom_objects = dict(
    dice_coef=dice_coef,
    iou_coef=iou_coef,
    focal_tversky=focal_tversky,
)
model1 = load_model('./deep_labv3_model', custom_objects=custom_objects)



**Preview predictions and report test scores**

In [None]:
# Ground-truth masks are only guaranteed when test_df was cut from train.csv (DEBUG);
# otherwise run the generator in image-only mode and skip evaluation.
pred_subset = "train" if DEBUG else "test"
pred_batches = DataGenerator(test_df, batch_size = 1, subset=pred_subset, shuffle=False)
# Model.predict accepts a Sequence directly; predict_generator is deprecated.
preds = model1.predict(pred_batches, verbose=1)
if DEBUG:
    test_scores = model1.evaluate(pred_batches, verbose=1)
    print(dict(zip(model1.metrics_names, test_scores)))

In [None]:

Threshold = 0.5     # sigmoid probability above which a pixel counts as predicted foreground
FIRST_SAMPLE = 72   # index of the first test sample shown
N_PREVIEW = 8       # number of samples (figure rows)

# Requires pred_batches to yield (images, masks) with batch_size=1 (subset="train"),
# so that pred_batches[k] and preds[k] refer to the same slice.
if FIRST_SAMPLE + N_PREVIEW > len(preds):
    raise ValueError(
        f"need {FIRST_SAMPLE + N_PREVIEW} predictions for this preview, got {len(preds)}"
    )

colors = ['yellow', 'green', 'red']
labels = ["Large Bowel", "Small Bowel", "Stomach"]
patches = [mpatches.Patch(color=colors[i], label=f"{labels[i]}") for i in range(len(labels))]
cmaps = [mpl.colors.ListedColormap(c) for c in colors]


def _show_overlay(ax, title, image, channel_masks):
    """Draw ``image`` with the 'bone' colormap and paint each mask's non-zero pixels in its class colour."""
    ax.set_title(title, fontsize=12, y=1.01)
    ax.imshow(image, cmap='bone')
    for channel, cmap in zip(channel_masks, cmaps):
        ax.imshow(np.ma.masked_where(channel == 0, channel), cmap=cmap, alpha=1)
    ax.set_axis_off()


# Visualizing: one row per sample -> image | ground truth overlay | prediction overlay
fig = plt.figure(figsize=(10, 25))
gs = gridspec.GridSpec(nrows=N_PREVIEW, ncols=3)
for row in range(N_PREVIEW):
    sample = FIRST_SAMPLE + row
    images, mask = pred_batches[sample]
    sample_img = images[0, :, :, 0]
    true_masks = [mask[0, :, :, k] for k in range(3)]
    pred_masks = [(preds[sample][:, :, k] > Threshold).astype(np.float32) for k in range(3)]

    ax0 = fig.add_subplot(gs[row, 0])
    ax0.imshow(sample_img, cmap='bone')
    ax0.set_title("Image", fontsize=12, y=1.01)
    ax0.set_axis_off()
    _show_overlay(fig.add_subplot(gs[row, 1]), "Mask", sample_img, true_masks)
    ax2 = fig.add_subplot(gs[row, 2])
    _show_overlay(ax2, "Predict", sample_img, pred_masks)
    if row == 0:
        # A single shared legend next to the top-right panel instead of one per row.
        ax2.legend(handles=patches, bbox_to_anchor=(1.1, 0.65), loc=2, borderaxespad=0.4,
                   fontsize=12, title='Mask Labels', title_fontsize=12,
                   edgecolor="black", facecolor='#c5c6c7')
plt.show()
