In [1]:
import time
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import cv2
import os
import random
import glob
from PIL import Image
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage.interpolation import map_coordinates
from torch.utils.data.sampler import SubsetRandomSampler
import torch.utils.data as data
from skimage import io
from skimage.exposure import histogram
from skimage.morphology import binary_dilation,binary_erosion,disk,square
from skimage.filters.rank import mean_bilateral
from skimage.color import rgb2gray
from skimage.filters import threshold_otsu, gaussian

In [2]:
# Root folder of the dataset. DataLoaderSegmentation globs the sub-folders
# 'images', 'new_mask' and 'alpha_mask' beneath the folder it is given.
# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.
path_dir = '/scratch/netra/Datasets/Drive_Dataset/'

In [4]:
# Debugging aid: ask the CUDA runtime to launch kernels synchronously so that an
# error is reported at the call that caused it. This slows GPU code down.
# NOTE(review): presumably only wanted while debugging - confirm before long runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [5]:
import torch.utils.data as data

class DataLoaderSegmentation(data.Dataset):
    def __init__(self,folder_path,transform = None):
        super(DataLoaderSegmentation, self).__init__()
        self.img_files = glob.glob(os.path.join(folder_path,'images','*.tif'))
        self.mask_files = glob.glob(os.path.join(folder_path,'new_mask','*.bmp'))
        self.alpha_files = glob.glob(os.path.join(folder_path,'alpha_mask','*gif'))
        self.transforms = transform
        #for img_path in img_files:
         #   self.mask_files.append(os.path.join(folder_path,'masks',os.path.basename(img_path))
         
    def mask_to_class(self,mask):
        target = torch.from_numpy(mask)
        assert target.shape[2] ==3
        h,w = target.shape[0],target.shape[1]
        masks = torch.empty(h, w, dtype=torch.long)
        colors = torch.unique(target.view(-1,target.size(2)),dim=0).numpy()
        target = target.permute(2, 0, 1).contiguous()
        mapping = {tuple(c): t for c, t in zip(colors.tolist(), range(len(colors)))}
        for k in mapping:
            idx = (target==torch.tensor(k, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
            validx = (idx.sum(0) == 3) 
            masks[validx] = torch.tensor(mapping[k], dtype=torch.long)
        return masks
    
    def elastic_transform_nearest(self,image, alpha=1000, sigma=20, spline_order=0, mode='nearest', random_state=np.random):
        
        image = np.array(image)
       # assert image.ndim == 3
        shape = image.shape[:2]

        dx = gaussian_filter((random_state.rand(*shape) * 2 - 1),
                         sigma, mode="constant", cval=0) * alpha
        dy = gaussian_filter((random_state.rand(*shape) * 2 - 1),
                      sigma, mode="constant", cval=0) * alpha

        x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
        indices = [np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))]
        result = np.empty_like(image)
        for i in range(image.shape[2]):
            result[:, :, i] = map_coordinates(
            image[:, :, i], indices, order=spline_order, mode=mode).reshape(shape)
        result = Image.fromarray(result)
        return result
    
    def elastic_transform_bilinear(self,image, alpha=1000, sigma=20, spline_order=1, mode='nearest', random_state=np.random):
        

        image = np.array(image)
        #assert image.ndim == 3
        shape = image.shape[:2]
        dx = gaussian_filter((random_state.rand(*shape) * 2 - 1),
                         sigma, mode="constant", cval=0) * alpha
        dy = gaussian_filter((random_state.rand(*shape) * 2 - 1),
                         sigma, mode="constant", cval=0) * alpha

        x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
        indices = [np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))]
        result = np.empty_like(image)
        for i in range(image.shape[2]):
            result[:, :, i] = map_coordinates(
            image[:, :, i], indices, order=spline_order, mode=mode).reshape(shape)
        result = Image.fromarray(result)
        return result
    
    def gaussian_blur(self,img_dir,mask_dir):
        """Remove the low-frequency background from an image and blank everything outside its valid region.

        Steps:

        1. Pad the image with a zero border and find the bright region with a
           global Otsu threshold on the grayscale image
           (NOTE(review): presumably the retinal field of view - confirm).
        2. Erode that region by 5 px, then grow it outwards one pixel ring at a
           time, filling each new pixel with the mean of the already valid
           pixels in its 3x3 neighbourhood. This is done for every channel so
           the later blur does not mix the dark surround into the border.
        3. Subtract a strong Gaussian blur (65x65 kernel, sigma 60) from the
           extended image and min-max normalise the result to 0..255 (uint8).
        4. Crop the padding away and set to 0 every pixel that is not both
           above the Otsu threshold and inside the alpha mask (``mask > 0``).

        Parameters
        ----------
        img_dir : str
            Path of the colour image.
        mask_dir : str
            Path of the alpha mask. NOTE(review): assumed to be a 2-D image
            with the same height and width as the colour image.

        Returns
        -------
        PIL.Image.Image
            The processed image, same height and width as the input.
        """
        pad = 100                # zero border added on every side
        num_px_to_expand = 100   # number of 1-px rings grown around the valid region

        img = io.imread(img_dir,plugin = 'pil')
        mask = io.imread(mask_dir,plugin = 'pil')
        height, width = img.shape[:2]
        a = np.pad(img, ((pad,pad), (pad,pad), (0,0)), mode = "constant")
        # `img` is the very same array as `a`, so the subtraction below
        # operates on the border-extended image (unchanged from the original).
        # NOTE(review): confirm this aliasing is intended.
        img = a
        grayscale = rgb2gray(a)
        global_thresh = threshold_otsu(grayscale)
        binary_global1 = grayscale > global_thresh

        # erode by 5 px to get rid of unusual edges from original image;
        # identical for every channel, so compute it once
        eroded = binary_erosion(binary_global1, disk(5))

        # process each channel (RGB) separately
        for channel in range(a.shape[2]):
            # restart from the eroded region for each channel
            binary_global = eroded

            # turn everything outside the region to 0
            one_channel = a[:, :, channel] * binary_global

            for jj in range(num_px_to_expand):
                # 1 px ring around the current region
                dilated = binary_dilation(binary_global, disk(1))
                px_to_update = np.logical_xor(dilated, binary_global)

                for row, col in zip(*np.where(px_to_update)):
                    # 3 x 3 neighbourhood; clamp the start at 0 because a
                    # negative start would yield an empty slice (0 / 0)
                    slices = np.s_[max(row-1, 0):(row+2), max(col-1, 0):(col+2)]

                    # mean over the neighbours that are already valid
                    one_channel[row, col] = (np.sum(one_channel[slices]*
                                                    binary_global[slices]) /
                                             np.sum(binary_global[slices]))

                # write the channel back and grow the region by the ring
                a[:,:, channel] = one_channel
                binary_global = dilated

        # These steps must run after ALL channels have been extended.
        # Previously they sat inside the channel loop, so the function
        # returned after channel 0 only.
        image_blur = cv2.GaussianBlur(a,(65,65),60)
        new_image = cv2.subtract(img,image_blur, dtype=cv2.CV_32F)
        out = cv2.normalize(new_image, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)

        # remove the padding (for a 584 x 565 input this is 100:684, 100:665)
        out = out[pad:pad+height, pad:pad+width, :]
        crop_img = a[pad:pad+height, pad:pad+width, :]

        mask_bool = mask>0
        gray = rgb2gray(crop_img)
        # same threshold as above (it was recomputed from the same `grayscale`)
        binary_global_crop = gray > global_thresh
        px_to_update = np.logical_not(np.logical_and(binary_global_crop,mask_bool))
        out[px_to_update] = 0

        return Image.fromarray(out)

In [6]:
    def transform(self,image,mask):
        """Apply the same random crop, flip and right-angle rotation to an image and its mask.

        Parameters
        ----------
        image, mask : PIL.Image.Image
            Image and label mask of identical size, at least 512 x 512.

        Returns
        -------
        tuple
            ``(image, mask)`` as 512 x 512 PIL images.
        """
        # One set of crop parameters, used for both inputs so they stay aligned.
        i, j, h, w = transforms.RandomCrop.get_params(
            image, output_size=(512, 512))
        image = TF.crop(image, i, j, h, w)
        mask = TF.crop(mask, i, j, h, w)

        # Random horizontal flipping
        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)

        # Random rotation by a multiple of 90 degrees. The previous code applied
        # 90, 180 and 270 degrees one after another, which always composes to
        # the same 180-degree turn and therefore added no variation.
        angle = random.choice((0, 90, 180, 270))
        if angle:
            image = TF.rotate(image, angle)
            mask = TF.rotate(mask, angle)

        return image, mask
     
    
    def __getitem__(self, index):
        img_path = self.img_files[index]
        mask_path = self.mask_files[index]
        alpha_path = self.alpha_files[index]
        #data = Image.open(img_path)
        label = Image.open(mask_path)
       # label = np.array(label)
        data = self.gaussian_blur(img_path,alpha_path)
        data = self.elastic_transform_bilinear(data)
        label = self.elastic_transform_nearest(label)
        data,label = self.transform(data,label)
        label = np.array(label)
        data = np.array(data)
        #label = np.transpose(label,(2,0,1))
        mask = self.mask_to_class(label)
        if transforms is not None:
             data = self.transforms(data)
        return data,mask
       # return data, torch.from_numpy(label).long()
           
    def __len__(self):
        return len(self.img_files)


In [7]:
from skimage.segmentation import find_boundaries

# Parameters of the border term computed in make_weight_map:
#   w0 * exp(-d**2 / (2 * sigma**2)),  d = distance to the nearest mask boundary/boundaries
w0 = 10    # peak weight of the border term (reached at d = 0)
sigma = 5  # width of the Gaussian fall-off, in pixels

def make_weight_map(masks):
    """
    Generate the weight map as specified in the UNet paper
    for a set of binary masks.

    The result is the sum of two terms:

    * a border term ``w0 * exp(-d**2 / (2 * sigma**2))``, where ``d`` is the
      distance to the nearest mask boundary (one mask) or the sum of the
      distances to the two nearest mask boundaries (several masks);
    * a class-balancing term that is ``w_1`` where exactly one mask covers the
      pixel and ``w_0`` where no mask does, with
      ``w_1 = 1 - (foreground fraction)`` and ``w_0 = 1 - w_1``.

    ``w0`` and ``sigma`` are module-level constants read at call time.

    Parameters
    ----------
    masks: torch.Tensor or array-like
        Shape (n_masks, image_height, image_width), one binary mask per slice
        along axis 0. A single 2D mask is accepted as well.

    Returns
    -------
    torch.Tensor
        Float tensor of shape (image_height, image_width); placed on the GPU
        when CUDA is available, otherwise left on the CPU.

    """
    # Function-scope import so this cell only relies on scipy being installed.
    from scipy.ndimage import distance_transform_edt

    if torch.is_tensor(masks):
        masks = masks.detach().cpu().numpy()
    else:
        masks = np.asarray(masks)
    if masks.ndim == 2:
        masks = masks[np.newaxis]
    nrows, ncols = masks.shape[1:]
    masks = (masks > 0).astype(int)

    # Distance of every pixel to the boundary of each mask. The Euclidean
    # distance transform gives the same values as the minimum over all
    # pixel/boundary pairs, without the (n_boundary x n_pixels) matrix.
    dist = np.empty(masks.shape, dtype=float)
    for i, mask in enumerate(masks):
        bounds = find_boundaries(mask, mode='inner')
        if bounds.any():
            dist[i] = distance_transform_edt(np.logical_not(bounds))
        else:
            # Empty or completely filled mask: no boundary, so the border
            # term for this mask is exp(-inf) = 0.
            dist[i] = np.inf

    if dist.shape[0] == 1:
        nearest = dist[0]
    else:
        # sum of the two smallest distances at every pixel
        two_smallest = np.partition(dist, 1, axis=0)[:2]
        nearest = two_smallest[0] + two_smallest[1]
    xBLoss = w0 * np.exp((-1 * nearest ** 2) / (2 * (sigma ** 2)))

    # class weight map
    loss = np.zeros((nrows, ncols))
    w_1 = 1 - masks.sum() / loss.size
    w_0 = 1 - w_1
    coverage = masks.sum(0)
    loss[coverage == 1] = w_1
    loss[coverage == 0] = w_0

    ZZ = torch.from_numpy(xBLoss + loss).type(torch.float)
    if torch.cuda.is_available():
        ZZ = ZZ.cuda()
    return ZZ


In [8]:

def elastic_transform_bilinear(image, alpha=1000, sigma=20, spline_order=1, mode='constant', random_state=np.random):
    """Elastic deformation of an image file as described in [Simard2003]_.

    .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
       Convolutional Neural Networks applied to Visual Document Analysis", in
       Proc. of the International Conference on Document Analysis and
       Recognition, 2003.

    Parameters
    ----------
    image : str or file object
        Anything accepted by ``PIL.Image.open``; the image is read from it.
    alpha : float
        Scale applied to the smoothed random displacement fields.
    sigma : float
        Standard deviation of the Gaussian used to smooth the random fields.
    spline_order : int
        Interpolation order passed to ``map_coordinates`` (1 = bilinear).
    mode : str
        Boundary mode passed to ``map_coordinates``.
    random_state : object with a ``rand(*shape)`` method
        Source of randomness.

    Returns
    -------
    PIL.Image.Image
        The deformed image.
    """
    # Read the pixels inside a context manager so the file is released
    # as soon as they are in memory.
    with Image.open(image) as opened:
        image = np.array(opened)
    shape = image.shape[:2]

    # Uniform noise in [-1, 1), smoothed and scaled: one field per axis.
    dx = gaussian_filter((random_state.rand(*shape) * 2 - 1),
                         sigma, mode="constant", cval=0) * alpha
    dy = gaussian_filter((random_state.rand(*shape) * 2 - 1),
                         sigma, mode="constant", cval=0) * alpha

    x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
    indices = [np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))]

    if image.ndim == 2:
        # Single-channel input: there is no channel axis to iterate over.
        result = map_coordinates(
            image, indices, order=spline_order, mode=mode).reshape(shape)
    else:
        result = np.empty_like(image)
        for i in range(image.shape[2]):
            result[:, :, i] = map_coordinates(
                image[:, :, i], indices, order=spline_order, mode=mode).reshape(shape)
    return Image.fromarray(result)



In [9]:

def elastic_transform_nearest(image, alpha=1000, sigma=20, spline_order=0, mode='nearest', random_state=np.random):
    """Elastic deformation of an in-memory image as described in [Simard2003]_.

    .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
       Convolutional Neural Networks applied to Visual Document Analysis", in
       Proc. of the International Conference on Document Analysis and
       Recognition, 2003.

    Parameters
    ----------
    image : array-like or PIL.Image
        ``(H, W)`` or ``(H, W, C)`` image; converted with ``np.array``.
    alpha : float
        Scale applied to the smoothed random displacement fields.
    sigma : float
        Standard deviation of the Gaussian used to smooth the random fields.
    spline_order : int
        Interpolation order passed to ``map_coordinates`` (0 = nearest).
    mode : str
        Boundary mode passed to ``map_coordinates``.
    random_state : object with a ``rand(*shape)`` method
        Source of randomness.

    Returns
    -------
    PIL.Image.Image
        The deformed image, same shape and dtype as the input array.
    """
    image = np.array(image)
    shape = image.shape[:2]

    # Uniform noise in [-1, 1), smoothed and scaled: one field per axis.
    dx = gaussian_filter((random_state.rand(*shape) * 2 - 1),
                         sigma, mode="constant", cval=0) * alpha
    dy = gaussian_filter((random_state.rand(*shape) * 2 - 1),
                         sigma, mode="constant", cval=0) * alpha

    x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
    indices = [np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))]

    if image.ndim == 2:
        # Single-channel input: there is no channel axis to iterate over.
        result = map_coordinates(
            image, indices, order=spline_order, mode=mode).reshape(shape)
    else:
        result = np.empty_like(image)
        for i in range(image.shape[2]):
            result[:, :, i] = map_coordinates(
                image[:, :, i], indices, order=spline_order, mode=mode).reshape(shape)
    return Image.fromarray(result)
