In [1]:
import torch
import torch.nn as nn
from torchsummary import summary

In [2]:
def double_conv(in_c, out_c):
    """Build the two-convolution unit used at every U-Net stage.

    The unit is two unpadded 3x3 convolutions, each followed by an in-place
    ReLU. Because there is no padding, every convolution trims one pixel from
    each border, so the spatial size shrinks by 4 in total.

    Args:
        in_c: Number of input channels of the first convolution.
        out_c: Number of output channels of both convolutions.

    Returns:
        nn.Sequential: Conv2d -> ReLU -> Conv2d -> ReLU.
    """
    layers = [
        nn.Conv2d(in_c, out_c, kernel_size=3),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, kernel_size=3),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)


def crop_img(tensor, target_tensor):
    """Centre-crop ``tensor`` to the spatial size of ``target_tensor``.

    The U-Net decoder uses this to align an encoder feature map with the
    (smaller) up-sampled decoder map before channel-wise concatenation.

    Height and width are handled independently, so non-square maps work. The
    output size always equals the target size exactly, even when the size
    difference is odd. Any extra row or column is dropped from the bottom/right.

    Args:
        tensor: 4-D tensor (N, C, H, W) to crop.
        target_tensor: 4-D tensor whose dims 2 and 3 give the desired size.

    Returns:
        A view of ``tensor`` of shape (N, C, target_H, target_W).

    Raises:
        ValueError: If the target is larger than ``tensor`` in either dimension.
    """
    target_h, target_w = target_tensor.shape[2], target_tensor.shape[3]
    src_h, src_w = tensor.shape[2], tensor.shape[3]
    if target_h > src_h or target_w > src_w:
        raise ValueError(
            f"cannot crop {src_h}x{src_w} tensor to larger {target_h}x{target_w}"
        )
    # Slice by start + target size (not start:end-start) so an odd difference
    # cannot leave one extra row/column behind.
    top = (src_h - target_h) // 2
    left = (src_w - target_w) // 2
    return tensor[:, :, top:top + target_h, left:left + target_w]


class UNet(nn.Module):
    """U-Net with unpadded ("valid") convolutions, after Ronneberger et al. (2015).

    Architecture:
    - Contracting path: four stages of ``double_conv`` + 2x2 max-pool, followed by
      a 1024-channel bottleneck ``double_conv``.
    - Expanding path: four stages of 2x2 transposed convolution, concatenation with
      the centre-cropped encoder map from the same level, then ``double_conv``.
    - Head: a final 1x1 convolution producing ``out_channels`` maps.

    Because every 3x3 convolution is unpadded, the output is spatially smaller
    than the input; for example, a 572x572 input yields a 388x388 output.

    Args:
        in_channels: Channels of the input image (default 3, as before).
        out_channels: Channels of the output map (default 2, as before).
    """

    def __init__(self, in_channels=3, out_channels=2):
        super().__init__()
        # Attribute names and creation order are unchanged so existing
        # state_dicts load and seeded initialisation is reproducible.
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Contracting path: channel count doubles at each stage.
        self.down_conv_1 = double_conv(in_channels, 64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 1024)

        # Expanding path: each transposed conv doubles H/W and halves channels.
        self.up_trans_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
        self.up_trans_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
        self.up_trans_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.up_trans_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)

        # Input channels are doubled by concatenating the skip connection.
        self.up_conv_1 = double_conv(1024, 512)
        self.up_conv_2 = double_conv(512, 256)
        self.up_conv_3 = double_conv(256, 128)
        self.up_conv_4 = double_conv(128, 64)

        self.out = nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=1)

    def forward(self, image):
        """Map ``image`` of shape (N, in_channels, H, W) to (N, out_channels, H', W').

        H' < H and W' < W because of the unpadded convolutions. MaxPool2d floors
        odd sizes, so inputs that stay even through all four pools (such as 572)
        avoid truncation.
        """
        enc1 = self.down_conv_1(image)
        enc2 = self.down_conv_2(self.max_pool_2x2(enc1))
        enc3 = self.down_conv_3(self.max_pool_2x2(enc2))
        enc4 = self.down_conv_4(self.max_pool_2x2(enc3))
        bottleneck = self.down_conv_5(self.max_pool_2x2(enc4))

        # Each decoder stage: upsample, crop the matching encoder map to the same
        # size, concatenate along channels (dim 1), then convolve.
        x = self.up_trans_1(bottleneck)
        x = self.up_conv_1(torch.cat([x, crop_img(enc4, x)], 1))

        x = self.up_trans_2(x)
        x = self.up_conv_2(torch.cat([x, crop_img(enc3, x)], 1))

        x = self.up_trans_3(x)
        x = self.up_conv_3(torch.cat([x, crop_img(enc2, x)], 1))

        x = self.up_trans_4(x)
        x = self.up_conv_4(torch.cat([x, crop_img(enc1, x)], 1))

        return self.out(x)

In [3]:
if __name__ == '__main__':
    # Seed so weight initialisation and the dummy input are reproducible.
    torch.manual_seed(0)
    # Dummy batch (batch, channels, height, width). The channel count must match
    # the model's first convolution, which takes 3 input channels by default.
    image = torch.rand((1, 3, 572, 572))
    unet = UNet()

In [4]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
unet = unet.to(device)

In [5]:
# torchsummary builds its probe input on `device` (default "cuda"), so pass the
# device the model actually lives on; otherwise this fails on CPU-only machines.
summary(unet, input_size=(3, 572, 572), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 570, 570]           1,792
              ReLU-2         [-1, 64, 570, 570]               0
            Conv2d-3         [-1, 64, 568, 568]          36,928
              ReLU-4         [-1, 64, 568, 568]               0
         MaxPool2d-5         [-1, 64, 284, 284]               0
            Conv2d-6        [-1, 128, 282, 282]          73,856
              ReLU-7        [-1, 128, 282, 282]               0
            Conv2d-8        [-1, 128, 280, 280]         147,584
              ReLU-9        [-1, 128, 280, 280]               0
        MaxPool2d-10        [-1, 128, 140, 140]               0
           Conv2d-11        [-1, 256, 138, 138]         295,168
             ReLU-12        [-1, 256, 138, 138]               0
           Conv2d-13        [-1, 256, 136, 136]         590,080
             ReLU-14        [-1, 256, 1

In [None]:
%matplotlib inline
import numpy as np
from skimage.io import imshow
from skimage.measure import label
from scipy.ndimage.morphology import distance_transform_edt
import numpy as np
import cv2
def generate_random_circles(n=100, d=256, rng=None):
    """Rasterise ``n`` random filled circles into a ``d`` x ``d`` binary mask.

    Each circle has integer centre (x0, y0) and integer r drawn uniformly from
    [0, d). Its radius in pixels is ``r / d * 10``, i.e. at most 10. A pixel
    (i, j) is inside when (i - x0)**2 + (j - y0)**2 <= radius**2. Overlapping
    circles are merged.

    Args:
        n: Number of circles.
        d: Side length of the square output.
        rng: Optional ``numpy.random.Generator`` for reproducible output. When
            None, NumPy's legacy global RNG (``np.random.randint``) is used, as
            in the original implementation.

    Returns:
        Integer array of shape (d, d) with values 0 (background) or 1 (circle).
    """
    if rng is None:
        circles = np.random.randint(0, d, (n, 3))
    else:
        circles = rng.integers(0, d, (n, 3))

    # Coordinate grids are built once and broadcast, instead of being
    # re-generated for every circle.
    rows = np.arange(d)[:, np.newaxis]
    cols = np.arange(d)[np.newaxis, :]
    mask = np.zeros((d, d), dtype=bool)
    for x0, y0, r in circles:
        mask |= (rows - x0) ** 2 + (cols - y0) ** 2 <= (r / d * 10) ** 2
    return mask.astype(int)

def unet_weight_map(y, wc=None, w0 = 10, sigma = 5):

    """
    Generate weight maps as specified in the U-Net paper
    for a binary mask.

    "U-Net: Convolutional Networks for Biomedical Image Segmentation"
    https://arxiv.org/pdf/1505.04597.pdf

    The border term is ``w0 * exp(-(d1 + d2)**2 / (2 * sigma**2))``. Here d1 and d2
    are each pixel's Euclidean distances to the nearest and second-nearest
    labelled object. The term is applied only to background pixels and only
    when the mask contains at least two objects. Optional per-class weights
    from ``wc`` are added on top.

    Parameters
    ----------
    y: Numpy array
        2D array of shape (image_height, image_width) representing a binary mask
        of objects (0 = background). Connected components are found with
        ``label``, so pass a binarised mask, not a grayscale image.
    wc: dict, optional
        Maps a mask value to its class weight, e.g. ``{0: 1, 1: 5}``.
    w0: float
        Border weight parameter.
    sigma: float
        Border width parameter.

    Returns
    -------
    Numpy array
        Float training weights of the same shape as ``y``.
    """

    labels = label(y)
    no_labels = labels == 0
    # Exclude the background label explicitly. Slicing off the first unique value
    # would drop a real object when the mask has no background pixels.
    label_ids = np.unique(labels)
    label_ids = label_ids[label_ids != 0]

    weights = np.zeros(y.shape, dtype=float)
    if len(label_ids) > 1:
        # Track only the two smallest distances per pixel, instead of stacking
        # all N distance maps (H x W x N memory) and sorting them.
        d1 = np.full(y.shape, np.inf)
        d2 = np.full(y.shape, np.inf)
        for label_id in label_ids:
            dist = distance_transform_edt(labels != label_id)
            d2 = np.minimum(d2, np.maximum(d1, dist))
            d1 = np.minimum(d1, dist)
        weights = w0 * np.exp(-1/2*((d1 + d2) / sigma)**2) * no_labels

    if wc:
        # Float dtype so fractional class weights are not truncated when y is
        # an integer mask.
        class_weights = np.zeros(y.shape, dtype=float)
        for k, v in wc.items():
            class_weights[y == k] = v
        weights = weights + class_weights
    return weights

y = generate_random_circles()
print(y.shape)

img = cv2.imread('./25.png', cv2.IMREAD_GRAYSCALE)
# cv2.imread returns None instead of raising when the file is missing/unreadable.
if img is None:
    raise FileNotFoundError("could not read './25.png'")
# Nearest-neighbour resizing keeps mask values intact (bilinear would blend them).
img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_NEAREST)
print(img.shape)

# unet_weight_map expects a binary mask. A raw grayscale image would be split
# into one component per distinct value by `label`.
# NOTE(review): assumes 25.png is a segmentation mask with background == 0; confirm.
mask = (img > 0).astype(int)

wc = {
    0: 1, # background
    1: 5  # objects
}

w = unet_weight_map(mask, wc)

imshow(w)

(256, 256)
(512, 512)
