<a href="https://colab.research.google.com/github/Sneha-dasgupta/Multi-class-seg/blob/main/Optic_Disk_Segmentation_UNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Data

In [2]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
%matplotlib inline
from glob import glob
from tqdm import tqdm
!pip install -U albumentations
from albumentations import HorizontalFlip, VerticalFlip, ElasticTransform, GridDistortion, OpticalDistortion

def create_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)

def load_data(path):
    """ X = Images and Y = masks """

    train_x = sorted(glob(os.path.join(path, "Training", "Images", "*.png")))
    train_y = sorted(glob(os.path.join(path, "Training", "GT_OD", "*.png")))

    test_x = sorted(glob(os.path.join(path, "Test", "Images", "*.png")))
    test_y = sorted(glob(os.path.join(path, "Test", "Test_GT_OD", "*.png")))

    return (train_x, train_y), (test_x, test_y)

def augment_data(images, masks, save_path, augment=True):
    H = 512
    W = 512

    for idx, (x, y) in tqdm(enumerate(zip(images, masks)), total=len(images)):
        """ Extracting names """
        name = x.split("/")[-1].split(".")[0]

        """ Reading image and mask """
        x = cv2.imread(x, cv2.IMREAD_COLOR)
        x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
        y = cv2.imread(y, cv2.IMREAD_COLOR)
        print(x.shape, y.shape)

        if augment == True:
            aug = HorizontalFlip(p=1.0)
            augmented = aug(image=x, mask=y)
            x1 = augmented["image"]
            y1 = augmented["mask"]

            aug = VerticalFlip(p=1.0)
            augmented = aug(image=x, mask=y)
            x2 = augmented["image"]
            y2 = augmented["mask"]

            aug = ElasticTransform(p=1, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03)
            augmented = aug(image=x, mask=y)
            x3 = augmented['image']
            y3 = augmented['mask']

            aug = GridDistortion(p=1)
            augmented = aug(image=x, mask=y)
            x4 = augmented['image']
            y4 = augmented['mask']

            aug = OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5)
            augmented = aug(image=x, mask=y)
            x5 = augmented['image']
            y5 = augmented['mask']

            X = [x, x1, x2, x3, x4, x5]
            Y = [y, y1, y2, y3, y4, y5]

        else:
            X = [x]
            Y = [y]

        index = 0
        for i, m in zip(X, Y):
            i = cv2.resize(i, (W, H), interpolation = cv2.INTER_CUBIC)
            #i = cv2.cvtColor(i,cv2.COLOR_BGR2RGB)

            #Histogram equalisation - start
            #img_gray = cv2.cvtColor(i, cv2.COLOR_BGR2GRAY)
            #clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
            #i = clahe.apply(img_gray)
            #i = cv2.cvtColor(i,cv2.COLOR_GRAY2RGB)
            lab = cv2.cvtColor(i, cv2.COLOR_BGR2LAB)

            lab_planes = cv2.split(lab)

            clahe = cv2.createCLAHE(clipLimit=3.0,tileGridSize=(10,10))

            lab_planes[0] = clahe.apply(lab_planes[0])

            lab = cv2.merge(lab_planes)

            i = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
            i = cv2.cvtColor(i,cv2.COLOR_BGR2RGB)

            #histogram equalization - end 

            m = cv2.resize(m, (W, H), interpolation = cv2.INTER_CUBIC)

            if len(X) == 1:
                tmp_image_name = f"{name}.png"
                tmp_mask_name = f"{name}.png"
            else:
                tmp_image_name = f"{name}_{index}.png"
                tmp_mask_name = f"{name}_{index}.png"

            image_path = os.path.join(save_path, "image", tmp_image_name)
            mask_path = os.path.join(save_path, "mask", tmp_mask_name)

            cv2.imwrite(image_path, i)
            cv2.imwrite(mask_path, m)

            index += 1

if __name__ == "__main__":
    """ Seeding """
    np.random.seed(42)

    """ Load the data """
    data_path = "/content/drive/MyDrive/Drishti/"
    (train_x, train_y), (test_x, test_y) = load_data(data_path)

    '''print(f"Train: {len(train_x)} - {len(train_y)}")
    print(f"Test: {len(test_x)} - {len(test_y)}")'''

    """ Creating directories """
    create_dir("new_data/train/image")
    create_dir("new_data/train/mask")
    create_dir("new_data/test/image")
    create_dir("new_data/test/mask")

    augment_data(train_x, train_y, "new_data/train/", augment=False)
    augment_data(test_x, test_y, "new_data/test/", augment=False)



  1%|▏         | 1/80 [00:00<00:13,  6.03it/s]

(1752, 2045, 3) (1752, 2045, 3)


  2%|▎         | 2/80 [00:01<01:17,  1.00it/s]

(1762, 2049, 3) (1762, 2049, 3)


  4%|▍         | 3/80 [00:02<01:22,  1.07s/it]

(1755, 2049, 3) (1755, 2049, 3)


  5%|▌         | 4/80 [00:03<01:04,  1.17it/s]

(1757, 2140, 3) (1757, 2140, 3)


  6%|▋         | 5/80 [00:04<01:01,  1.22it/s]

(1755, 2047, 3) (1755, 2047, 3)


  8%|▊         | 6/80 [00:05<01:17,  1.04s/it]

(1758, 2050, 3) (1758, 2050, 3)


  9%|▉         | 7/80 [00:06<01:09,  1.04it/s]

(1760, 2051, 3) (1760, 2051, 3)


 10%|█         | 8/80 [00:07<01:04,  1.12it/s]

(1750, 2049, 3) (1750, 2049, 3)


 11%|█▏        | 9/80 [00:08<01:05,  1.09it/s]

(1750, 2048, 3) (1750, 2048, 3)


 12%|█▎        | 10/80 [00:09<01:17,  1.11s/it]

(1753, 2048, 3) (1753, 2048, 3)


 14%|█▍        | 11/80 [00:10<01:16,  1.11s/it]

(1757, 2047, 3) (1757, 2047, 3)


 15%|█▌        | 12/80 [00:11<01:15,  1.11s/it]

(1763, 2047, 3) (1763, 2047, 3)


 16%|█▋        | 13/80 [00:13<01:14,  1.11s/it]

(1763, 2466, 3) (1763, 2466, 3)


 18%|█▊        | 14/80 [00:13<01:07,  1.02s/it]

(1759, 2049, 3) (1759, 2049, 3)


 19%|█▉        | 15/80 [00:14<01:01,  1.05it/s]

(1759, 2049, 3) (1759, 2049, 3)


 20%|██        | 16/80 [00:15<00:57,  1.11it/s]

(1845, 2050, 3) (1845, 2050, 3)


 21%|██▏       | 17/80 [00:16<01:02,  1.01it/s]

(1759, 2048, 3) (1759, 2048, 3)


 22%|██▎       | 18/80 [00:17<01:00,  1.02it/s]

(1760, 2048, 3) (1760, 2048, 3)


 24%|██▍       | 19/80 [00:18<01:04,  1.05s/it]

(1762, 2049, 3) (1762, 2049, 3)


 25%|██▌       | 20/80 [00:20<01:09,  1.16s/it]

(1755, 2049, 3) (1755, 2049, 3)


 26%|██▋       | 21/80 [00:21<01:04,  1.09s/it]

(1751, 2052, 3) (1751, 2052, 3)


 28%|██▊       | 22/80 [00:22<01:09,  1.20s/it]

(1750, 2050, 3) (1750, 2050, 3)


 29%|██▉       | 23/80 [00:24<01:12,  1.28s/it]

(1751, 2047, 3) (1751, 2047, 3)


 30%|███       | 24/80 [00:24<01:04,  1.15s/it]

(1751, 2045, 3) (1751, 2045, 3)


 31%|███▏      | 25/80 [00:25<00:55,  1.01s/it]

(1753, 2048, 3) (1753, 2048, 3)


 32%|███▎      | 26/80 [00:26<00:54,  1.00s/it]

(1749, 2049, 3) (1749, 2049, 3)


 34%|███▍      | 27/80 [00:27<00:47,  1.11it/s]

(1749, 2048, 3) (1749, 2048, 3)


 35%|███▌      | 28/80 [00:28<00:45,  1.15it/s]

(1760, 2048, 3) (1760, 2048, 3)


 36%|███▋      | 29/80 [00:29<00:51,  1.01s/it]

(1759, 2049, 3) (1759, 2049, 3)


 38%|███▊      | 30/80 [00:30<00:55,  1.10s/it]

(1764, 2050, 3) (1764, 2050, 3)


 39%|███▉      | 31/80 [00:32<00:58,  1.20s/it]

(1762, 2048, 3) (1762, 2048, 3)


 40%|████      | 32/80 [00:32<00:50,  1.06s/it]

(1758, 2047, 3) (1758, 2047, 3)


 41%|████▏     | 33/80 [00:33<00:45,  1.03it/s]

(1760, 2049, 3) (1760, 2049, 3)


 42%|████▎     | 34/80 [00:34<00:40,  1.15it/s]

(1756, 2049, 3) (1756, 2049, 3)


 44%|████▍     | 35/80 [00:35<00:45,  1.01s/it]

(1841, 2289, 3) (1841, 2289, 3)


 45%|████▌     | 36/80 [00:36<00:41,  1.06it/s]

(1759, 2049, 3) (1759, 2049, 3)


 46%|████▋     | 37/80 [00:37<00:45,  1.05s/it]

(1761, 2048, 3) (1761, 2048, 3)


 48%|████▊     | 38/80 [00:38<00:42,  1.02s/it]

(1763, 2049, 3) (1763, 2049, 3)


 49%|████▉     | 39/80 [00:39<00:44,  1.07s/it]

(1757, 2048, 3) (1757, 2048, 3)


 50%|█████     | 40/80 [00:40<00:41,  1.03s/it]

(1760, 2049, 3) (1760, 2049, 3)


 51%|█████▏    | 41/80 [00:42<00:43,  1.12s/it]

(1758, 2463, 3) (1758, 2463, 3)


 52%|█████▎    | 42/80 [00:42<00:40,  1.06s/it]

(1755, 2047, 3) (1755, 2047, 3)


 54%|█████▍    | 43/80 [00:44<00:41,  1.11s/it]

(1753, 2046, 3) (1753, 2046, 3)


 55%|█████▌    | 44/80 [00:45<00:39,  1.08s/it]

(1752, 2047, 3) (1752, 2047, 3)


 56%|█████▋    | 45/80 [00:46<00:36,  1.04s/it]

(1752, 2045, 3) (1752, 2045, 3)


 57%|█████▊    | 46/80 [00:47<00:40,  1.19s/it]

(1760, 2048, 3) (1760, 2048, 3)


 59%|█████▉    | 47/80 [00:48<00:38,  1.17s/it]

(1753, 2049, 3) (1753, 2049, 3)


 60%|██████    | 48/80 [00:50<00:41,  1.29s/it]

(1753, 2046, 3) (1753, 2046, 3)


 61%|██████▏   | 49/80 [00:51<00:36,  1.19s/it]

(1754, 2046, 3) (1754, 2046, 3)


 62%|██████▎   | 50/80 [00:52<00:33,  1.11s/it]

(1761, 2049, 3) (1761, 2049, 3)


 64%|██████▍   | 51/80 [00:53<00:32,  1.12s/it]

(1759, 2049, 3) (1759, 2049, 3)


 65%|██████▌   | 52/80 [00:54<00:34,  1.23s/it]

(1760, 2047, 3) (1760, 2047, 3)


 66%|██████▋   | 53/80 [00:55<00:31,  1.15s/it]

(1761, 2140, 3) (1761, 2140, 3)


 68%|██████▊   | 54/80 [00:56<00:27,  1.04s/it]

(1759, 2048, 3) (1759, 2048, 3)


 69%|██████▉   | 55/80 [00:58<00:29,  1.19s/it]

(1762, 2468, 3) (1762, 2468, 3)


 70%|███████   | 56/80 [00:59<00:31,  1.33s/it]

(1757, 2047, 3) (1757, 2047, 3)


 71%|███████▏  | 57/80 [01:00<00:28,  1.26s/it]

(1760, 2049, 3) (1760, 2049, 3)


 72%|███████▎  | 58/80 [01:01<00:26,  1.19s/it]

(1762, 2048, 3) (1762, 2048, 3)


 74%|███████▍  | 59/80 [01:03<00:24,  1.15s/it]

(1755, 2048, 3) (1755, 2048, 3)


 75%|███████▌  | 60/80 [01:03<00:21,  1.07s/it]

(1754, 2046, 3) (1754, 2046, 3)


 76%|███████▋  | 61/80 [01:04<00:17,  1.09it/s]

(1759, 2048, 3) (1759, 2048, 3)


 78%|███████▊  | 62/80 [01:05<00:17,  1.03it/s]

(1760, 2048, 3) (1760, 2048, 3)


 79%|███████▉  | 63/80 [01:06<00:16,  1.05it/s]

(1758, 2049, 3) (1758, 2049, 3)


 80%|████████  | 64/80 [01:07<00:16,  1.04s/it]

(1750, 2048, 3) (1750, 2048, 3)


 81%|████████▏ | 65/80 [01:08<00:14,  1.05it/s]

(1754, 2049, 3) (1754, 2049, 3)


 82%|████████▎ | 66/80 [01:09<00:15,  1.12s/it]

(1835, 2049, 3) (1835, 2049, 3)


 84%|████████▍ | 67/80 [01:10<00:13,  1.02s/it]

(1761, 2048, 3) (1761, 2048, 3)


 85%|████████▌ | 68/80 [01:12<00:13,  1.12s/it]

(1760, 2047, 3) (1760, 2047, 3)


 86%|████████▋ | 69/80 [01:12<00:11,  1.04s/it]

(1749, 2049, 3) (1749, 2049, 3)


 88%|████████▊ | 70/80 [01:14<00:11,  1.10s/it]

(1749, 2049, 3) (1749, 2049, 3)


 89%|████████▉ | 71/80 [01:15<00:09,  1.06s/it]

(1749, 2048, 3) (1749, 2048, 3)


 90%|█████████ | 72/80 [01:16<00:08,  1.07s/it]

(1749, 2050, 3) (1749, 2050, 3)


 91%|█████████▏| 73/80 [01:17<00:07,  1.02s/it]

(1748, 2048, 3) (1748, 2048, 3)


 92%|█████████▎| 74/80 [01:19<00:08,  1.39s/it]

(1750, 2049, 3) (1750, 2049, 3)


 94%|█████████▍| 75/80 [01:20<00:06,  1.24s/it]

(1750, 2048, 3) (1750, 2048, 3)


 95%|█████████▌| 76/80 [01:21<00:04,  1.20s/it]

(1749, 2049, 3) (1749, 2049, 3)


 96%|█████████▋| 77/80 [01:22<00:03,  1.27s/it]

(1749, 2049, 3) (1749, 2049, 3)


 98%|█████████▊| 78/80 [01:23<00:02,  1.11s/it]

(1749, 2048, 3) (1749, 2048, 3)


 99%|█████████▉| 79/80 [01:24<00:01,  1.12s/it]

(1748, 2047, 3) (1748, 2047, 3)


100%|██████████| 80/80 [01:25<00:00,  1.07s/it]


(1749, 2048, 3) (1749, 2048, 3)


  6%|▋         | 1/16 [00:00<00:11,  1.27it/s]

(1751, 2049, 3) (1751, 2049, 3)


 12%|█▎        | 2/16 [00:01<00:11,  1.18it/s]

(1759, 2047, 3) (1759, 2047, 3)


 19%|█▉        | 3/16 [00:02<00:11,  1.16it/s]

(1759, 2047, 3) (1759, 2047, 3)


 25%|██▌       | 4/16 [00:04<00:16,  1.37s/it]

(1755, 2048, 3) (1755, 2048, 3)


 31%|███▏      | 5/16 [00:05<00:13,  1.20s/it]

(1757, 2049, 3) (1757, 2049, 3)


 38%|███▊      | 6/16 [00:06<00:11,  1.19s/it]

(1754, 2048, 3) (1754, 2048, 3)


 44%|████▍     | 7/16 [00:07<00:10,  1.11s/it]

(1752, 2049, 3) (1752, 2049, 3)


 50%|█████     | 8/16 [00:08<00:09,  1.13s/it]

(1759, 2463, 3) (1759, 2463, 3)


 56%|█████▋    | 9/16 [00:09<00:07,  1.10s/it]

(1760, 2048, 3) (1760, 2048, 3)


 62%|██████▎   | 10/16 [00:10<00:06,  1.08s/it]

(1760, 2049, 3) (1760, 2049, 3)


 69%|██████▉   | 11/16 [00:12<00:05,  1.17s/it]

(1757, 2048, 3) (1757, 2048, 3)


 75%|███████▌  | 12/16 [00:13<00:04,  1.19s/it]

(1761, 2047, 3) (1761, 2047, 3)


 81%|████████▏ | 13/16 [00:14<00:03,  1.15s/it]

(1757, 2468, 3) (1757, 2468, 3)


 88%|████████▊ | 14/16 [00:15<00:02,  1.07s/it]

(1761, 2048, 3) (1761, 2048, 3)


 94%|█████████▍| 15/16 [00:16<00:01,  1.05s/it]

(1760, 2048, 3) (1760, 2048, 3)


100%|██████████| 16/16 [00:17<00:00,  1.09s/it]

(1758, 2050, 3) (1758, 2050, 3)





# Model

In [3]:
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.models import Model

def conv_block(inputs, num_filters):
    x = Conv2D(num_filters, 3, padding="same")(inputs)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    x = Conv2D(num_filters, 3, padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    return x

def encoder_block(inputs, num_filters):
    x = conv_block(inputs, num_filters)
    p = MaxPool2D((2, 2))(x)
    return x, p

def decoder_block(inputs, skip_features, num_filters):
    x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(inputs)
    x = Concatenate()([x, skip_features])
    x = conv_block(x, num_filters)
    return x

def build_unet(input_shape):
    inputs = Input(input_shape)

    s1, p1 = encoder_block(inputs, 64)
    s2, p2 = encoder_block(p1, 128)
    s3, p3 = encoder_block(p2, 256)
    s4, p4 = encoder_block(p3, 512)

    b1 = conv_block(p4, 1024)

    d1 = decoder_block(b1, s4, 512)
    d2 = decoder_block(d1, s3, 256)
    d3 = decoder_block(d2, s2, 128)
    d4 = decoder_block(d3, s1, 64)

    outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)

    model = Model(inputs, outputs, name="UNET")
    return model

if __name__ == "__main__":
    input_shape = (512, 512, 3)
    model = build_unet(input_shape)
    model.summary()

Model: "UNET"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 512, 512, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 512, 512, 64) 1792        input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 512, 512, 64) 256         conv2d[0][0]                     
__________________________________________________________________________________________________
activation (Activation)         (None, 512, 512, 64) 0           batch_normalization[0][0]        
_______________________________________________________________________________________________

# Metrices

In [4]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K

def iou(y_true, y_pred):
    def f(y_true, y_pred):
        intersection = (y_true * y_pred).sum()
        union = y_true.sum() + y_pred.sum() - intersection
        x = (intersection + 1e-15) / (union + 1e-15)
        x = x.astype(np.float32)
        return x
    return tf.numpy_function(f, [y_true, y_pred], tf.float32)

smooth = 1e-15
def dice_coef(y_true, y_pred):
    y_true = tf.keras.layers.Flatten()(y_true)
    y_pred = tf.keras.layers.Flatten()(y_pred)
    intersection = tf.reduce_sum(y_true * y_pred)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + smooth)

def dice_loss(y_true, y_pred):
    return 1.0 - dice_coef(y_true, y_pred)

# Train

In [5]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import numpy as np
import cv2
from glob import glob
from sklearn.utils import shuffle
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, ReduceLROnPlateau, EarlyStopping, TensorBoard
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import Recall, Precision
#from Model_UNet_Simple import build_unet
#from Metrices import dice_loss, dice_coef, iou

H = 512
W = 512

def create_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)

def load_data(path):
    x = sorted(glob(os.path.join(path, "image", "*.png")))
    y = sorted(glob(os.path.join(path, "mask", "*.png")))
    return x, y

def shuffling(x, y):
    x, y = shuffle(x, y, random_state=42)
    return x, y

def read_image(path):
    path = path.decode()
    x = cv2.imread(path, cv2.IMREAD_COLOR)
    # x = cv2.resize(x, (W, H))
    x = x/255.0
    x = x.astype(np.float32)
    return x

def read_mask(path):
    path = path.decode()
    x = cv2.imread(path, cv2.IMREAD_GRAYSCALE)  ## (512, 512)
    # x = cv2.resize(x, (W, H))
    x = x/255.0
    x = x.astype(np.float32)
    x = np.expand_dims(x, axis=-1)              ## (512, 512, 1)
    return x

def tf_parse(x, y):
    def _parse(x, y):
        x = read_image(x)
        y = read_mask(y)
        return x, y

    x, y = tf.numpy_function(_parse, [x, y], [tf.float32, tf.float32])
    x.set_shape([H, W, 3])
    y.set_shape([H, W, 1])
    return x, y

def tf_dataset(X, Y, batch_size=2):
    dataset = tf.data.Dataset.from_tensor_slices((X, Y))
    dataset = dataset.map(tf_parse)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(4)
    return dataset

if __name__ == "__main__":
    """ Seeding """
    np.random.seed(42)
    tf.random.set_seed(42)

    """ Directory to save files """
    create_dir("files")

    """ Hyperparameters """
    batch_size = 2
    lr = 1e-4
    num_epochs = 100
    model_path = os.path.join("files", "model.h5")
    csv_path = os.path.join("files", "data.csv")

    """ Dataset """
    dataset_path = "new_data"
    train_path = os.path.join(dataset_path, "train")
    valid_path = os.path.join(dataset_path, "test")

    train_x, train_y = load_data(train_path)
    train_x, train_y = shuffling(train_x, train_y)
    valid_x, valid_y = load_data(valid_path)

    print(f"Train: {len(train_x)} - {len(train_y)}")
    print(f"Valid: {len(valid_x)} - {len(valid_y)}")

    train_dataset = tf_dataset(train_x, train_y, batch_size=batch_size)
    valid_dataset = tf_dataset(valid_x, valid_y, batch_size=batch_size)

    train_steps = len(train_x)//batch_size
    valid_setps = len(valid_x)//batch_size

    if len(train_x) % batch_size != 0:
        train_steps += 1
    if len(valid_x) % batch_size != 0:
        valid_setps += 1

    """ Model """
    model = build_unet((H, W, 3))
    model.compile(loss=dice_loss, optimizer=Adam(lr), metrics=[dice_coef, iou, Recall(), Precision()])
    #model.summary()

    callbacks = [
        ModelCheckpoint(model_path, verbose=1, save_best_only=True),
        #ReduceLROnPlateau(monitor="val_loss", factor=0.1, patience=5, min_lr=1e-6, verbose=1),
        CSVLogger(csv_path),
        TensorBoard(),
        #EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=False)
    ]

    model.fit(
        train_dataset,
        epochs=num_epochs,
        validation_data=valid_dataset,
        steps_per_epoch=train_steps,
        validation_steps=valid_setps,
        callbacks=callbacks
    )

Train: 80 - 80
Valid: 16 - 16
Epoch 1/100

Epoch 00001: val_loss improved from inf to 0.93307, saving model to files/model.h5
Epoch 2/100

Epoch 00002: val_loss improved from 0.93307 to 0.91718, saving model to files/model.h5
Epoch 3/100

Epoch 00003: val_loss did not improve from 0.91718
Epoch 4/100

Epoch 00004: val_loss improved from 0.91718 to 0.90158, saving model to files/model.h5
Epoch 5/100

Epoch 00005: val_loss did not improve from 0.90158
Epoch 6/100

Epoch 00006: val_loss did not improve from 0.90158
Epoch 7/100

Epoch 00007: val_loss did not improve from 0.90158
Epoch 8/100

Epoch 00008: val_loss did not improve from 0.90158
Epoch 9/100

Epoch 00009: val_loss did not improve from 0.90158
Epoch 10/100

Epoch 00010: val_loss did not improve from 0.90158
Epoch 11/100

Epoch 00011: val_loss improved from 0.90158 to 0.77840, saving model to files/model.h5
Epoch 12/100

Epoch 00012: val_loss improved from 0.77840 to 0.76865, saving model to files/model.h5
Epoch 13/100

Epoch 000

# Eval

In [6]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import numpy as np
import pandas as pd
import cv2
from glob import glob
from tqdm import tqdm
import tensorflow as tf
from tensorflow.keras.utils import CustomObjectScope
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score
#from metrics import dice_loss, dice_coef, iou

H = 512
W = 512

def create_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)

def read_image(path):
    x = cv2.imread(path, cv2.IMREAD_COLOR)
    # x = cv2.resize(x, (W, H))
    ori_x = x
    x = x/255.0
    x = x.astype(np.float32)
    return ori_x, x

def read_mask(path):
    x = cv2.imread(path, cv2.IMREAD_GRAYSCALE)  ## (512, 512)
    # x = cv2.resize(x, (W, H))
    ori_x = x
    x = x/255.0
    x = x.astype(np.int32)
    return ori_x, x

def load_data(path):
    x = sorted(glob(os.path.join(path, "image", "*.png")))
    y = sorted(glob(os.path.join(path, "mask", "*.png")))
    return x, y

def save_results(ori_x, ori_y, y_pred, save_image_path):
    line = np.ones((H, 10, 3)) * 255

    ori_y = np.expand_dims(ori_y, axis=-1)
    ori_y = np.concatenate([ori_y, ori_y, ori_y], axis=-1)

    y_pred = np.expand_dims(y_pred, axis=-1)
    y_pred = np.concatenate([y_pred, y_pred, y_pred], axis=-1) * 255

    cat_images = np.concatenate([ori_x, line, ori_y, line, y_pred], axis=1)
    cv2.imwrite(save_image_path, cat_images)

if __name__ == "__main__":
    """ Save the results in this folder """
    create_dir("results")

    """ Load the model """
    with CustomObjectScope({'iou': iou, 'dice_coef': dice_coef, 'dice_loss': dice_loss}):
        model = tf.keras.models.load_model("files/model.h5")

    """ Load the dataset """
    dataset_path = os.path.join("new_data", "test")
    test_x, test_y = load_data(dataset_path)

    """ Make the prediction and calculate the metrics values """
    SCORE = []
    for x, y in tqdm(zip(test_x, test_y), total=len(test_x)):
        """ Extracting name """
        name = x.split("/")[-1].split(".")[0]

        """ Read the image and mask """
        ori_x, x = read_image(x)
        ori_y, y = read_mask(y)

        """ Prediction """
        y_pred = model.predict(np.expand_dims(x, axis=0))[0]
        y_pred = y_pred > 0.5
        y_pred = y_pred.astype(np.int32)
        y_pred = np.squeeze(y_pred, axis=-1)

        """ Saving the images """
        save_image_path = f"results/{name}.png"
        save_results(ori_x, ori_y, y_pred, save_image_path)

        """ Flatten the array """
        y = y.flatten()
        y_pred = y_pred.flatten()

        """ Calculate the metrics """
        acc_value = accuracy_score(y, y_pred)
        f1_value = f1_score(y, y_pred, labels=[0, 1], average="binary")
        jac_value = jaccard_score(y, y_pred, labels=[0, 1], average="binary")
        recall_value = recall_score(y, y_pred, labels=[0, 1], average="binary")
        precision_value = precision_score(y, y_pred, labels=[0, 1], average="binary")
        SCORE.append([name, acc_value, f1_value, jac_value, recall_value, precision_value])

    score = [s[1:] for s in SCORE]
    score = np.mean(score, axis=0)
    print(f"Accuracy: {score[0]:0.5f}")
    print(f"F1: {score[1]:0.5f}")
    print(f"Jaccard: {score[2]:0.5f}")
    print(f"Recall: {score[3]:0.5f}")
    print(f"Precision: {score[4]:0.5f}")

    """ Saving """
    df = pd.DataFrame(SCORE, columns=["Image", "Acc", "F1", "Jaccard", "Recall", "Precision"])
    df.to_csv("files/score.csv")

100%|██████████| 16/16 [00:14<00:00,  1.08it/s]

Accuracy: 0.99697
F1: 0.95229
Jaccard: 0.90973
Recall: 0.98103
Precision: 0.92654





# Downloading Results

In [7]:
!zip -r /content/results.zip /content/results

  adding: content/results/ (stored 0%)
  adding: content/results/drishtiGS_007.png (deflated 3%)
  adding: content/results/drishtiGS_011.png (deflated 3%)
  adding: content/results/drishtiGS_001.png (deflated 3%)
  adding: content/results/drishtiGS_019.png (deflated 3%)
  adding: content/results/drishtiGS_034.png (deflated 3%)
  adding: content/results/drishtiGS_028.png (deflated 3%)
  adding: content/results/drishtiGS_025.png (deflated 5%)
  adding: content/results/drishtiGS_014.png (deflated 3%)
  adding: content/results/drishtiGS_009.png (deflated 4%)
  adding: content/results/drishtiGS_006.png (deflated 3%)
  adding: content/results/drishtiGS_021.png (deflated 3%)
  adding: content/results/drishtiGS_030.png (deflated 4%)
  adding: content/results/drishtiGS_003.png (deflated 3%)
  adding: content/results/drishtiGS_020.png (deflated 3%)
  adding: content/results/drishtiGS_013.png (deflated 6%)
  adding: content/results/drishtiGS_005.png (deflated 3%)


In [8]:
from google.colab import files
files.download("/content/results.zip")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>