# Segmentation of the whole dataset using the previously built U-Net

In [1]:
# List the GPUs visible from this environment (sanity check before running the U-Net)
!nvidia-smi -L

GPU 0: Quadro P1000 (UUID: GPU-337a71bd-41de-e1d4-25dd-4ef2bc95668d)


In [2]:
!pip install -q scikit-image
!pip install -q pandas

You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.[0m


In [48]:
import glob
import os
import shutil

import tensorflow as tf
import skimage as ski
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import backend as K
import cv2
# Work from /tf so that the relative paths used below ('notebooks/unet3.hdf5',
# 'datasets/...') resolve.
# NOTE(review): hardcoded absolute path, presumably the notebook root of the
# container this was run in. Adjust when running elsewhere.
os.chdir("/tf")

In [4]:
def crop(img, h, w):
    """Center-crops an image down to at most ``h`` x ``w`` pixels.

    Only dimensions larger than the target are cropped; smaller ones are left
    untouched (no padding is added). When the excess is odd, the extra pixel
    is removed from the top (rows) or the left (columns).

    Args:
        img (ndarray): image array indexed as (rows, cols[, channels])
        h (int): wanted number of pixels for height
        w (int): wanted number of pixels for width

    Returns:
        img (ndarray): cropped image (the input itself when nothing is cropped)
    """
    rows_over = img.shape[0] - h
    cols_over = img.shape[1] - w

    if rows_over > 0:
        # rows_over - rows_over // 2 == ceil(rows_over / 2): odd extra row comes off the top
        first_row = rows_over - rows_over // 2
        img = img[first_row:img.shape[0] - rows_over // 2, :]

    if cols_over > 0:
        # same rule for columns: odd extra column comes off the left
        first_col = cols_over - cols_over // 2
        img = img[:, first_col:img.shape[1] - cols_over // 2]

    return img

In [5]:
def resize_padding(image, RGB, h_target, w_target):
    """Pads an image with a constant border so it measures exactly h_target x w_target.

    The padding is split as evenly as possible: ``delta // 2`` pixels go on the
    top / left and the remainder on the bottom / right. RGB images are padded
    with white ([255, 255, 255]), greyscale ones with 0.

    Args:
        image (ndarray): image array indexed as (rows, cols[, channels])
        RGB (boolean): True if RGB (3 channels), False if greyscale
        h_target (int): wanted number of pixels for height
        w_target (int): wanted number of pixels for width

    Returns:
        result (ndarray): padded image, or None (after printing a message) if
        the target is smaller than the image in either dimension
    """
    rows, cols = image.shape[:2]
    if rows > h_target or cols > w_target:
        print("Specified dimensions smaller than the input image")
        return None

    extra_rows = h_target - rows
    extra_cols = w_target - cols
    top, left = extra_rows // 2, extra_cols // 2
    bottom, right = extra_rows - top, extra_cols - left

    fill = [255, 255, 255] if RGB else 0
    return cv2.copyMakeBorder(image, top, bottom, left, right,
                              cv2.BORDER_CONSTANT, value=fill)

In [16]:
def merge_vertical(top_img, bot_img, overlap, blend_rows=9):
    """Merges 2 vertically overlapping patches (e.g. 512*512 pixels each) into one image

    The result is ``top_img`` followed by the rows of ``bot_img`` that the top
    patch does not cover. Inside the overlap, the last ``blend_rows`` rows of
    ``top_img`` are linearly cross-faded towards the matching rows of
    ``bot_img`` to hide the seam. Rows are aligned: output row ``r`` of the
    overlap mixes ``top_img[r]`` with ``bot_img[r - (top_img.shape[0] - overlap)]``.

    Args:
        top_img (ndarray): patch corresponding to the top of the image
        bot_img (ndarray): patch corresponding to the bottom of the image; its
            first ``overlap`` rows cover the same pixels as the last
            ``overlap`` rows of ``top_img``
        overlap (int): number of pixels on which the 2 patches overlap
        blend_rows (int): number of overlap rows to cross-fade (default 9);
            clamped to ``[0, overlap]``

    Returns:
        full_img (ndarray): image of height
        ``top_img.shape[0] + bot_img.shape[0] - overlap`` (a new array)
    """
    top_h = top_img.shape[0]
    # Whole top patch (including the overlap), then the part only the bottom patch covers.
    # np.concatenate returns a fresh array, so the blending below never touches the inputs.
    full_img = np.concatenate([top_img, bot_img[overlap:]], axis=0)

    # Clamping avoids negative (wrapping) indices into bot_img when overlap is small.
    n_blend = max(0, min(blend_rows, overlap))
    for k in range(1, n_blend + 1):
        row = top_h - k               # k-th row above the bottom edge of top_img
        weight = k / (n_blend + 1)    # top weight grows with the distance from its bottom edge
        full_img[row] = weight * top_img[row] + (1 - weight) * bot_img[overlap - k]

    return full_img

In [10]:
def jaccard_coef(y_true, y_pred):
    """Smoothed Jaccard index (intersection over union) between two tensors.

    Both inputs are flattened, so any shape works as long as both have the same
    number of elements. A smoothing term of 1.0 is added to the numerator and
    the denominator, which keeps the ratio defined (equal to 1) when both are empty.

    Args:
        y_true: ground truth (tensor or array)
        y_pred: segmented image / predicted probabilities (tensor or array)

    Returns:
        scalar tensor: Jaccard score
    """
    smooth = 1.0
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)

    inter = K.sum(truth * pred)
    union = K.sum(truth) + K.sum(pred) - inter
    return (inter + smooth) / (union + smooth)

The `jaccard_coef()` function is taken from: <br>
https://github.com/bnsreenu/python_for_microscopists/blob/2c2b120fec17d8686572719916920bc05e3288f8/207-simple_unet_model_with_jacard.py

In [11]:
# Loss name passed in `custom_objects` when reloading the saved U-Net below.
loss = "categorical_crossentropy" 

### Model loading

In [15]:
# Reload the trained U-Net. Every name the saved file refers to that Keras
# cannot resolve on its own must be supplied through `custom_objects`.
model = tf.keras.models.load_model(
    'notebooks/unet3.hdf5',
    custom_objects={
        "categorical_crossentropy": loss,  # the loss name string defined above
        "jaccard_coef": jaccard_coef,      # presumably saved as a training metric
    },
)

### Model prediction on the whole dataset

In [53]:
def segment_unet(img, unet=None, channel=2):
    """Uses our U-Net network to segment a given image

    The image is center-cropped to at most 768*512 pixels (height*width), then
    padded with white up to exactly 768*512. Two 512*512 patches (rows 0-511
    and rows 256-767) are predicted in a single batch and merged back with
    merge_vertical(). Finally the padding is removed.

    Args:
        img (ndarray): RGB image to be segmented
        unet (keras.Model, optional): network to use; defaults to the
            notebook-level ``model``
        channel (int): output channel to return (2 by default)

    Returns:
        img_seg (ndarray): probability map of ``channel``. It has the same
            height/width as ``img``, except that dimensions above 768*512 are
            center-cropped.
    """
    if unet is None:
        unet = model
    padded_h, patch = 768, 512

    # crop() only trims dimensions exceeding the target, so a single call handles
    # images that are too tall, too wide or both. The former if/elif left the
    # width untouched when both were too large, so resize_padding() returned None.
    img = crop(img, padded_h, patch)
    h, w = img.shape[:2]
    img_padded = resize_padding(img, True, padded_h, patch)

    # Bottom patch first, top patch second; one predict() call for both.
    patches = np.stack([img_padded[padded_h - patch:], img_padded[:patch]])
    bottom_pred, top_pred = unet.predict(patches)
    img_seg = merge_vertical(top_pred, bottom_pred, 2 * patch - padded_h)

    # Remove exactly the padding resize_padding() added: delta // 2 pixels on the
    # top/left, the rest on the bottom/right. crop() takes an odd extra pixel from
    # the top/left instead, which shifted the map by one pixel.
    pad_up = (padded_h - h) // 2
    pad_left = (patch - w) // 2
    return img_seg[pad_up:pad_up + h, pad_left:pad_left + w, channel]

In [None]:
# Segment every visible image of every dataset folder and save its probability
# map as <folder>/Segment_Unet/P_map_<image file name>.
# glob is used for listing: imread_collection() also opened every file just to count frames.
folders = sorted(glob.glob('datasets/Dossiers_Mildiou_Maj2023/*'))

for folder in folders:
    out_dir = os.path.join(folder, 'Segment_Unet')
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)  # start from a clean output folder on every run
    Images = glob.glob(os.path.join(folder, 'Recalibrated_CPD', '*'))
    # Skip files tagged Fv / Fo / Fm / FvFm or 'error'. Tags are matched on the
    # file name only, so a directory name can never exclude a whole folder.
    Visible = [path for path in Images
               if not any(tag in os.path.basename(path)
                          for tag in ['Fv', 'Fo', 'Fm', 'FvFm', 'error'])]
    # Reverse name order. This assumes names differ only by their _Jn_ day tag,
    # so the latest day comes first.
    Visible = sorted(Visible, key=str.casefold, reverse=True)

    os.mkdir(out_dir)
    # Reset for every folder: a map must never be capped by another folder's map.
    res_upcoming_day = None
    for im_path in Visible:
        im_name = os.path.basename(im_path)
        im = ski.io.imread(im_path)
        res = segment_unet(im)

        # Post-processing: in 4-image (last day J5) and 3-image (last day J4)
        # series, every image except the last day's is capped pixel-wise by the
        # map processed just before it (the following day).
        if (len(Visible) == 4 and "_J5_" not in im_name) or (len(Visible) == 3 and "_J4_" not in im_name):
            if res_upcoming_day is None:
                print(f"No later-day map before {im_path}; post-processing skipped")
            else:
                # Vectorized form of the per-pixel `if res > upcoming: res = upcoming`
                res = np.where(res > res_upcoming_day, res_upcoming_day, res)
        res_upcoming_day = res

        ski.io.imsave(os.path.join(out_dir, 'P_map_' + im_name), res)