# UroSegNet

### Create segmentation images of the stroma

In [None]:
!pip install tf_slim

In [None]:
# coding: utf-8

import os
from io import BytesIO
import tarfile
import tempfile
from six.moves import urllib

from matplotlib import gridspec
from matplotlib import pyplot as plt
import numpy as np
from PIL import Image
from glob import glob

import tensorflow as tf


class DeepLabModel(object):
    """Wrapper around a frozen DeepLab inference graph.

    Loads a serialized ``GraphDef`` from disk into a private ``tf.Graph`` and
    keeps a session open on it so ``run`` can be called repeatedly. Call
    ``close()`` when done to release the session.
    """

    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    INPUT_SIZE = 512  # longest side (px) of the image fed to the network
    FROZEN_GRAPH_NAME = 'frozen_inference_graph'

    def __init__(self, frozen_path):
        """Load the frozen graph stored at ``frozen_path`` and open a session.

        Args:
          frozen_path: Path to a serialized ``GraphDef`` (``.pb``) file.

        Raises:
          OSError: If the file cannot be opened.
        """
        # tf.compat.v1 keeps the TF1 graph/session API reachable under
        # TensorFlow 2.x, where tf.GraphDef / tf.Session no longer exist.
        tf_v1 = tf.compat.v1
        self.graph = tf.Graph()
        graph_def = tf_v1.GraphDef()
        with open(frozen_path, 'rb') as f:
            graph_def.ParseFromString(f.read())
        with self.graph.as_default():
            tf_v1.import_graph_def(graph_def, name='')
        self.sess = tf_v1.Session(graph=self.graph)

    def run(self, image):
        """Run inference on a single PIL image.

        The image is converted to RGB and resized to ``returnSize(image)``
        (aspect ratio preserved) before being fed to the graph.

        Args:
          image: A PIL ``Image``.

        Returns:
          A tuple ``(resized_image, seg_map)``: the RGB PIL image actually fed
          to the network, and the first element of the graph's
          ``SemanticPredictions`` output (presumably a 2-D array of per-pixel
          class ids matching the resized image's size — confirm).
        """
        target_size = self.returnSize(image)
        # LANCZOS is the filter Image.ANTIALIAS aliased; ANTIALIAS itself was
        # removed in Pillow 10.
        resized_image = image.convert('RGB').resize(target_size, Image.LANCZOS)
        batch_seg_map = self.sess.run(
            self.OUTPUT_TENSOR_NAME,
            feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
        seg_map = batch_seg_map[0]
        return resized_image, seg_map

    def returnSize(self, image):
        """Return the ``(width, height)`` that ``run`` resizes ``image`` to.

        Both sides are scaled by ``INPUT_SIZE / max(width, height)`` and
        truncated to int, so the longest side becomes (up to float
        truncation) ``INPUT_SIZE``.
        """
        width, height = image.size
        resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
        target_size = (int(resize_ratio * width), int(resize_ratio * height))
        return target_size

    def close(self):
        """Release the TensorFlow session held by this model."""
        self.sess.close()


def create_pascal_label_colormap():
    """Build the 256-entry colormap of the PASCAL VOC segmentation benchmark.

    The color of label ``k`` interleaves the bits of ``k``: bit ``3*j + c``
    of ``k`` becomes bit ``7 - j`` of color channel ``c``.

    Returns:
      An integer array of shape (256, 3), one RGB triple per label id.
    """
    remaining = np.arange(256, dtype=int)
    palette = np.zeros((256, 3), dtype=int)

    for bit in range(7, -1, -1):
        for channel in range(3):
            palette[:, channel] |= ((remaining >> channel) & 1) << bit
        # Move on to the next group of three label bits.
        remaining >>= 3

    return palette


def label_to_color_image(label):
    """Map a 2-D array of label ids to RGB colors from the PASCAL colormap.

    Args:
      label: A 2-D integer array storing segmentation label ids.

    Returns:
      result: An integer array of shape ``label.shape + (3,)`` whose
        ``[i, j]`` entry is the PASCAL colormap color for ``label[i, j]``.

    Raises:
      ValueError: If ``label`` is not of rank 2, contains a negative value,
        or contains a value larger than the colormap's maximum entry.
    """
    if label.ndim != 2:
        raise ValueError('Expect 2-D input label')

    colormap = create_pascal_label_colormap()

    # Negative ids would otherwise silently wrap to the end of the colormap
    # through numpy's negative indexing.
    if np.min(label) < 0:
        raise ValueError('label value must be non-negative.')
    if np.max(label) >= len(colormap):
        raise ValueError('label value too large.')

    return colormap[label]


def vis_segmentation(image, seg_map):
    """Visualize an input image, its segmentation map and their overlay.

    Draws one figure with four panels: the input image, ``seg_map`` rendered
    with the PASCAL colormap, that colored map drawn over the image with
    alpha 0.7, and a legend strip naming each label present in ``seg_map``.
    The legend is built from the module-level ``FULL_COLOR_MAP`` and
    ``LABEL_NAMES``.

    Args:
      image: Image to display; presumably the resized image returned by
        ``DeepLabModel.run`` so that it lines up with ``seg_map`` (confirm).
      seg_map: 2-D integer array of label ids; every id present must be a
        valid index into ``FULL_COLOR_MAP`` and ``LABEL_NAMES``.
    """
    plt.figure(figsize=(15, 5))
    grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])

    plt.subplot(grid_spec[0])
    plt.imshow(image)
    plt.axis('off')
    plt.title('input image')

    plt.subplot(grid_spec[1])
    seg_image = label_to_color_image(seg_map).astype(np.uint8)
    plt.imshow(seg_image)
    plt.axis('off')
    plt.title('segmentation map')

    plt.subplot(grid_spec[2])
    plt.imshow(image)
    plt.imshow(seg_image, alpha=0.7)
    plt.axis('off')
    plt.title('segmentation overlay')

    # Legend: one colored row per label id that actually occurs in seg_map.
    unique_labels = np.unique(seg_map)
    ax = plt.subplot(grid_spec[3])
    plt.imshow(
        FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([], [])
    ax.tick_params(width=0.0)
    # Pass a real bool: the string 'off' is truthy, so grid('off') does not
    # reliably hide the grid.
    plt.grid(False)
    plt.show()


# Label setting: display names indexed by the label ids found in seg_map.
LABEL_NAMES = np.asarray(['background', 'Stromal'])

# Column of every label id (shape (num_labels, 1)) and its PASCAL colors,
# used by vis_segmentation to draw the legend strip.
FULL_LABEL_MAP = np.arange(LABEL_NAMES.size).reshape(-1, 1)
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)


model_path = "./models/research/deeplab/datasets/pascal_voc_seg/exp/train_on_trainval_set/export/frozen_inference_graph.pb"
save_path = "./1_stromal_seg_img/"
original_path = "./0_pre_seg_img/all/"
os.makedirs(save_path, exist_ok=True)

files = glob(original_path + "*png")

# Index of the first file to process.
# NOTE(review): 1000 looks like a resume offset left over from an interrupted
# run (it skips the first 1000 images); set to 0 to process every image.
start_index = 1000

# Load the frozen graph once. Building a new graph + session per image is
# slow and leaks every session, since none of them is ever closed.
model = DeepLabModel(model_path)
try:
    for i, img_path in enumerate(files[start_index:], start=start_index):
        # read image (closed again once inference is done)
        with Image.open(img_path) as original_im:
            # inferences DeepLab model
            resized_im, seg_map = model.run(original_im)

        # Scale label ids so label 1 ('Stromal') is stored as 255 in the mask.
        seg_map = seg_map * 255
        pilImg = Image.fromarray(np.uint8(seg_map))
        pilImg.save(save_path + os.path.basename(img_path))
        print(str(i) + "/" + str(len(files)))
finally:
    model.sess.close()

### Binarization and hole filling

In [None]:
# original_img_path is otherwise only defined in the following cell; define it
# here too so this cell does not fail with NameError when run on its own.
original_img_path = "./0_pre_seg_img/all/"
out_path_3 = "./1_stromal_seg_img_epithelium/"
files = glob(original_img_path + "*png")

In [None]:
import cv2
import numpy as np
import scipy.ndimage as ndimage
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt

seg_img_path = "./1_stromal_seg_img/"
original_img_path = "./0_pre_seg_img/all/"
out_path_1 = "./1_stromal_seg_img_binarization/"
out_path_2 = "./1_stromal_seg_img_back/"
out_path_3 = "./1_stromal_seg_img_epithelium/"
# Create every output directory independently. Gating all three on
# out_path_1 left the others missing (cv2.imwrite then fails silently), or
# raised FileExistsError when only some of them already existed.
for out_dir in (out_path_1, out_path_2, out_path_3):
    os.makedirs(out_dir, exist_ok=True)

files = glob(original_img_path + "*png")

kernel = np.ones((3, 3), np.uint8)  # 3x3 structuring element for dilation

for i, im_path in enumerate(files):
    print(str(i) + "/" + str(len(files)))
    name = os.path.basename(im_path)
    with Image.open(im_path) as src:
        im = np.array(src.convert('L'))
    with Image.open(seg_img_path + name) as seg_src:
        seg_im = np.array(seg_src)

    # 1 outside the predicted stroma, 0 inside: keeps only non-stromal pixels.
    # NOTE(review): assumes the stored segmentation has the same size as the
    # input image (DeepLabModel.run resizes to a 512 px longest side), i.e.
    # inputs are presumably already at that size — confirm.
    seg_im = np.where(seg_im == 0, 1, 0)
    im = im * seg_im

    # binarization----------------------------------------------------
    # The masked image is already 8-bit grayscale, so threshold it directly
    # instead of round-tripping it through a PNG on disk.
    img = np.uint8(im)
    # NOTE(review): thresholding uses the unblurred image; apply
    # cv2.GaussianBlur(img, (5, 5), 0) first if smoothing was intended.
    _, th = cv2.threshold(img, 180, 255, cv2.THRESH_BINARY)
    th = cv2.dilate(th, kernel, iterations=1)
    cv2.imwrite(out_path_1 + name, th)

    # Fill Holes----------------------------------------------------
    th_fill = ndimage.binary_fill_holes(th).astype(int) * 255

    # Invert the filled mask, then fill the holes of the inverted mask.
    th_fill_2 = np.where(th_fill == 0, 255, 0)
    th_fill_2 = ndimage.binary_fill_holes(th_fill_2).astype(int) * 255
    cv2.imwrite(out_path_2 + name, np.uint8(th_fill_2))

    # mask stromal
    th_fill_3 = th_fill_2 * seg_im
    cv2.imwrite(out_path_3 + name, np.uint8(th_fill_3))