In [6]:
import os
import numpy as np
import cv2
from glob import glob
from tqdm import tqdm
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import CustomObjectScope
import numpy as np
from tensorflow.keras.utils import Sequence

In [12]:
smooth = 1.
def dice_coef(y_true, y_pred):
    """Smoothed Dice coefficient between a target tensor and a prediction.

    Both inputs are passed through a Keras ``Flatten`` layer and every sum is
    taken over all elements, so one scalar is produced for the whole input.
    The module-level ``smooth`` constant is added to the numerator and the
    denominator, which keeps the ratio defined when both inputs sum to zero.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted mask tensor, multiplied element-wise with ``y_true``.

    Returns:
        ``(2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)``.
    """
    truth_flat = tf.keras.layers.Flatten()(y_true)
    pred_flat = tf.keras.layers.Flatten()(y_pred)

    overlap = tf.reduce_sum(truth_flat * pred_flat)
    total = tf.reduce_sum(truth_flat) + tf.reduce_sum(pred_flat)

    numerator = 2. * overlap + smooth
    denominator = total + smooth
    return numerator / denominator


def dice_loss(y_true, y_pred):
    """Dice loss: one minus the value returned by ``dice_coef``.

    Args:
        y_true: Ground-truth mask tensor, forwarded to ``dice_coef``.
        y_pred: Predicted mask tensor, forwarded to ``dice_coef``.

    Returns:
        ``1.0 - dice_coef(y_true, y_pred)``.
    """
    coefficient = dice_coef(y_true, y_pred)
    return 1.0 - coefficient

In [2]:
def mask_to_3d(mask):
    """Repeat a single-channel mask three times along a new last axis.

    Singleton dimensions are squeezed away first; three copies of the result
    are then stacked and the stacking axis is moved to the end, so a
    two-dimensional ``(H, W)`` mask becomes ``(H, W, 3)``.

    Args:
        mask: Array that is two-dimensional once singleton axes are removed.

    Returns:
        Array with the squeezed mask repeated in each of three channels.
    """
    plane = np.squeeze(mask)
    channels_first = np.asarray([plane] * 3)
    return channels_first.transpose((1, 2, 0))

In [3]:
def parse_image(img_path, image_size):
    """Read an image file and return it as a square array divided by 255.

    The file is read with ``cv2.imread(img_path, 1)`` (3-channel colour, in
    the channel order OpenCV returns) and resized to
    ``image_size x image_size`` only when its height or width differs.

    Args:
        img_path: Path of the image file to read.
        image_size: Target height and width in pixels.

    Returns:
        Array of shape ``(image_size, image_size, 3)`` whose values are the
        pixel values divided by 255.0.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
    """
    image = cv2.imread(img_path, 1)
    if image is None:
        # cv2.imread does not raise on a missing or undecodable file; it
        # returns None, which would otherwise surface as an AttributeError.
        raise FileNotFoundError(f"Could not read image: {img_path}")

    height, width = image.shape[:2]
    if (height, width) != (image_size, image_size):
        image = cv2.resize(image, (image_size, image_size))

    return image / 255.0

def parse_mask(mask_path, image_size):
    """Read a single-channel mask file and return it as ``(size, size, 1)`` / 255.

    The file is read with ``cv2.imread(mask_path, -1)`` (unchanged), resized
    to ``image_size x image_size`` only when its shape differs, given a
    trailing channel axis, and divided by 255.0.

    Args:
        mask_path: Path of the mask file to read.
        image_size: Target height and width in pixels.

    Returns:
        Array of shape ``(image_size, image_size, 1)`` whose values are the
        stored mask values divided by 255.0.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``mask_path``.
        ValueError: If the decoded mask is not two-dimensional.
    """
    mask = cv2.imread(mask_path, -1)
    if mask is None:
        # cv2.imread returns None instead of raising when the file is
        # missing or cannot be decoded.
        raise FileNotFoundError(f"Could not read mask: {mask_path}")
    if mask.ndim != 2:
        # Reading "unchanged" keeps any colour/alpha channels; the rest of
        # this function only makes sense for a single-channel mask.
        raise ValueError(
            f"Expected a single-channel mask, got shape {mask.shape}: {mask_path}"
        )

    if mask.shape != (image_size, image_size):
        mask = cv2.resize(mask, (image_size, image_size))

    return np.expand_dims(mask, -1) / 255.0

In [7]:
class DataGen(Sequence):
    """Keras ``Sequence`` that reads (image, mask) batches from file paths.

    Images are loaded with ``parse_image`` and masks with ``parse_mask``;
    the image at position ``i`` of ``images_path`` is paired with the mask at
    position ``i`` of ``masks_path``.

    Args:
        image_size: Height/width passed to ``parse_image`` and ``parse_mask``.
        images_path: Sequence of image file paths.
        masks_path: Sequence of mask file paths, in the same order.
        batch_size: Number of samples per batch; the final batch may be
            smaller when the dataset size is not a multiple of it.
    """

    def __init__(self, image_size, images_path, masks_path, batch_size=8):
        self.image_size = image_size
        self.images_path = images_path
        self.masks_path = masks_path
        self.batch_size = batch_size
        self.on_epoch_end()

    def __getitem__(self, index):
        """Return batch number ``index`` as ``(images, masks)`` NumPy arrays.

        Raises:
            IndexError: If ``index`` is outside ``range(len(self))``.
        """
        if not 0 <= index < len(self):
            raise IndexError(
                f"batch index {index} out of range for {len(self)} batches"
            )

        # Clamp the end of the slice for a short final batch.  self.batch_size
        # must NOT be modified here: __len__ and every slice offset are derived
        # from it, so shrinking it would shift this slice and all later ones.
        start = index * self.batch_size
        stop = min(start + self.batch_size, len(self.images_path))

        images_path = self.images_path[start:stop]
        masks_path = self.masks_path[start:stop]

        images_batch = []
        masks_batch = []

        for i in range(len(images_path)):
            ## Read image and mask
            images_batch.append(parse_image(images_path[i], self.image_size))
            masks_batch.append(parse_mask(masks_path[i], self.image_size))

        return np.array(images_batch), np.array(masks_batch)

    def on_epoch_end(self):
        """Hook called at the end of each epoch (and once from ``__init__``); no-op."""
        pass

    def __len__(self):
        """Number of batches, counting a final partial batch."""
        return int(np.ceil(len(self.images_path)/float(self.batch_size)))

In [13]:
def test_model(model_path, save_path,
               test_path="../input/kvasit-seg-train-test-valid/data/content/new_data/Kvasir-SEG/test",
               image_size=256):
    """Evaluate a saved model on the test split and write comparison images.

    The model is loaded with ``dice_loss``/``dice_coef`` registered as custom
    objects, evaluated over the test set through ``DataGen``, and then run on
    each test image individually.  For every image a PNG is written to
    ``save_path`` showing, left to right: the input image, the ground-truth
    mask, and the prediction thresholded at 0.5, separated by white bars.

    Args:
        model_path: Path of the saved Keras model file.
        save_path: Directory that receives ``<index>.png`` result images;
            created if it does not exist.
        test_path: Directory containing ``images/`` and ``masks/``
            sub-directories.  Files are paired by sorted order.
        image_size: Height/width the images and masks are resized to.

    Raises:
        FileNotFoundError: If no test images are found under ``test_path``.
        ValueError: If the number of images and masks differs.
        IOError: If a result image cannot be written.
    """
    batch_size = 1

    images_dir = os.path.join(test_path, "images")
    masks_dir = os.path.join(test_path, "masks")
    test_image_paths = sorted(glob(os.path.join(images_dir, "*")))
    test_mask_paths = sorted(glob(os.path.join(masks_dir, "*")))

    # Images and masks are matched purely by position after sorting, so an
    # empty or uneven listing would evaluate nothing or mis-pair the files.
    if not test_image_paths:
        raise FileNotFoundError(f"No test images found under {images_dir}")
    if len(test_image_paths) != len(test_mask_paths):
        raise ValueError(
            f"Found {len(test_image_paths)} images but "
            f"{len(test_mask_paths)} masks under {test_path}"
        )

    ## Create result folder
    os.makedirs(save_path, exist_ok=True)

    ## Model
    with CustomObjectScope({'dice_loss': dice_loss, 'dice_coef': dice_coef}):
        model = load_model(model_path)

    ## Test
    print("Test Result: ")
    test_steps = len(test_image_paths)//batch_size
    test_gen = DataGen(image_size, test_image_paths, test_mask_paths, batch_size=batch_size)
    # Model.evaluate accepts a Sequence directly; evaluate_generator is the
    # deprecated spelling of the same call.
    model.evaluate(test_gen, steps=test_steps, verbose=1)

    ## Generating the result
    # White vertical bar placed between the three panels of each output image.
    sep_line = np.ones((image_size, 10, 3)) * 255

    paired_paths = zip(test_image_paths, test_mask_paths)
    for i, (image_path, mask_path) in tqdm(enumerate(paired_paths), total=len(test_image_paths)):
        image = parse_image(image_path, image_size)
        mask = parse_mask(mask_path, image_size)

        predict_mask = model.predict(np.expand_dims(image, axis=0))[0]
        predict_mask = (predict_mask > 0.5) * 255.0

        mask = mask_to_3d(mask)
        predict_mask = mask_to_3d(predict_mask)

        # image and mask were divided by 255 when parsed; scale them back so
        # all panels share the 0-255 range of the thresholded prediction.
        all_images = [image * 255, sep_line, mask * 255, sep_line, predict_mask]
        output_file = f"{save_path}/{i}.png"
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(output_file, np.concatenate(all_images, axis=1)):
            raise IOError(f"Could not write result image: {output_file}")

    print("Test image generation complete")

In [20]:
# Evaluate the resunetplusplus.h5 checkpoint; result images go to ./resunetplusplus.
model_path, save_path = (
    "../input/kvasit-seg-train-test-valid/model_comparison_files/model_comparision_files/resunetplusplus.h5",
    "./resunetplusplus",
)
test_model(model_path, save_path)

Test Result: 


100%|██████████| 100/100 [00:07<00:00, 13.05it/s]

Test image generation complete





In [18]:
# Evaluate the unet.h5 checkpoint; result images go to ./unet.
model_path, save_path = (
    "../input/kvasit-seg-train-test-valid/model_comparison_files/model_comparision_files/unet.h5",
    "./unet",
)
test_model(model_path, save_path)

Test Result: 


100%|██████████| 100/100 [00:06<00:00, 14.83it/s]

Test image generation complete





In [19]:
# Evaluate the resunet.h5 checkpoint; result images go to ./resunet.
model_path, save_path = (
    "../input/kvasit-seg-train-test-valid/model_comparison_files/model_comparision_files/resunet.h5",
    "./resunet",
)
test_model(model_path, save_path)

Test Result: 


100%|██████████| 100/100 [00:07<00:00, 13.98it/s]

Test image generation complete





In [21]:
# Shell command: recursively archive the ./resunetplusplus result folder into ./resunetplusplus.zip.
!zip -r ./resunetplusplus.zip ./resunetplusplus

  adding: resunetplusplus/ (stored 0%)
  adding: resunetplusplus/23.png (deflated 4%)
  adding: resunetplusplus/13.png (deflated 2%)
  adding: resunetplusplus/68.png (deflated 3%)
  adding: resunetplusplus/31.png (deflated 3%)
  adding: resunetplusplus/5.png (deflated 3%)
  adding: resunetplusplus/62.png (deflated 2%)
  adding: resunetplusplus/59.png (deflated 3%)
  adding: resunetplusplus/19.png (deflated 3%)
  adding: resunetplusplus/60.png (deflated 4%)
  adding: resunetplusplus/93.png (deflated 3%)
  adding: resunetplusplus/48.png (deflated 3%)
  adding: resunetplusplus/37.png (deflated 3%)
  adding: resunetplusplus/35.png (deflated 2%)
  adding: resunetplusplus/73.png (deflated 3%)
  adding: resunetplusplus/11.png (deflated 3%)
  adding: resunetplusplus/77.png (deflated 2%)
  adding: resunetplusplus/45.png (deflated 3%)
  adding: resunetplusplus/55.png (deflated 2%)
  adding: resunetplusplus/83.png (deflated 4%)
  adding: resunetplusplus/33.png (deflated 3%)
  adding: resunetplusp

In [22]:
# Shell command: recursively archive the ./resunet result folder into ./resunet.zip.
!zip -r ./resunet.zip ./resunet

  adding: resunet/ (stored 0%)
  adding: resunet/23.png (deflated 4%)
  adding: resunet/13.png (deflated 2%)
  adding: resunet/68.png (deflated 2%)
  adding: resunet/31.png (deflated 3%)
  adding: resunet/5.png (deflated 3%)
  adding: resunet/62.png (deflated 2%)
  adding: resunet/59.png (deflated 3%)
  adding: resunet/19.png (deflated 2%)
  adding: resunet/60.png (deflated 4%)
  adding: resunet/93.png (deflated 3%)
  adding: resunet/48.png (deflated 3%)
  adding: resunet/37.png (deflated 3%)
  adding: resunet/35.png (deflated 2%)
  adding: resunet/73.png (deflated 3%)
  adding: resunet/11.png (deflated 3%)
  adding: resunet/77.png (deflated 2%)
  adding: resunet/45.png (deflated 3%)
  adding: resunet/55.png (deflated 2%)
  adding: resunet/83.png (deflated 4%)
  adding: resunet/33.png (deflated 3%)
  adding: resunet/34.png (deflated 3%)
  adding: resunet/86.png (deflated 4%)
  adding: resunet/80.png (deflated 2%)
  adding: resunet/50.png (deflated 4%)
  adding: resunet/30.png (deflated

In [23]:
# Shell command: recursively archive the ./unet result folder into ./unet.zip.
!zip -r ./unet.zip ./unet

  adding: unet/ (stored 0%)
  adding: unet/23.png (deflated 4%)
  adding: unet/13.png (deflated 3%)
  adding: unet/68.png (deflated 2%)
  adding: unet/31.png (deflated 3%)
  adding: unet/5.png (deflated 3%)
  adding: unet/62.png (deflated 2%)
  adding: unet/59.png (deflated 3%)
  adding: unet/19.png (deflated 2%)
  adding: unet/60.png (deflated 4%)
  adding: unet/93.png (deflated 3%)
  adding: unet/48.png (deflated 3%)
  adding: unet/37.png (deflated 3%)
  adding: unet/35.png (deflated 2%)
  adding: unet/73.png (deflated 3%)
  adding: unet/11.png (deflated 3%)
  adding: unet/77.png (deflated 2%)
  adding: unet/45.png (deflated 2%)
  adding: unet/55.png (deflated 2%)
  adding: unet/83.png (deflated 4%)
  adding: unet/33.png (deflated 3%)
  adding: unet/34.png (deflated 3%)
  adding: unet/86.png (deflated 3%)
  adding: unet/80.png (deflated 2%)
  adding: unet/50.png (deflated 4%)
  adding: unet/30.png (deflated 4%)
  adding: unet/94.png (deflated 4%)
  adding: unet/95.png (deflated 2%)
 