In [1]:
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class UNET(nn.Module):
    """U-Net encoder/decoder producing a per-pixel map in [0, 1].

    Four encoder stages separated by 2x max-pooling, a bottleneck ("connect"),
    and four decoder stages. Each decoder stage bilinearly upsamples the deeper
    features and concatenates them with the matching encoder output (skip
    connection). A final 1x1 convolution followed by a sigmoid yields the output.

    Args:
        in_channels: number of input channels (default 1).
        out_channels: number of output channels (default 1).
        base_channels: width of the first encoder stage; every deeper stage
            doubles it. The default 64 gives widths 64/128/256/512/1024, which
            matches the original hard-coded layout, so saved state_dicts still load.
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=64):
        # The original called super(Unet, self); `Unet` is undefined (the class
        # is `UNET`), so constructing the model raised NameError.
        super().__init__()
        c1, c2, c3, c4, c5 = (base_channels * 2 ** i for i in range(5))

        self.encoder_1 = self.conv(in_channels, c1)
        self.encoder_2 = self.conv(c1, c2)
        self.encoder_3 = self.conv(c2, c3)
        self.encoder_4 = self.conv(c3, c4)

        self.connect = self.conv(c4, c5)

        # Each decoder consumes (upsampled deeper features) + (skip connection).
        self.decoder_1 = self.conv(c5 + c4, c4)
        self.decoder_2 = self.conv(c4 + c3, c3)
        self.decoder_3 = self.conv(c3 + c2, c2)
        self.decoder_4 = self.conv(c2 + c1, c1)

        self.fc = nn.Conv2d(c1, out_channels, kernel_size=1)

    def conv(self, in_c, out_c):
        """Two 3x3 conv + ReLU layers; padding=1 preserves the spatial size."""
        lay = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        return lay

    @staticmethod
    def _up_and_concat(deep, skip):
        """Bilinearly resize `deep` to `skip`'s H x W and concatenate on channels."""
        # Resizing to the skip's exact size (rather than scale_factor=2) keeps
        # the concat valid when an odd size was floored by max-pooling. For even
        # sizes the result is identical, since align_corners=True samples using
        # only the input and output sizes.
        upsampled = F.interpolate(deep, size=skip.shape[-2:], mode='bilinear',
                                  align_corners=True)
        return torch.cat([upsampled, skip], dim=1)

    def forward(self, x):
        """Run the network.

        Args:
            x: tensor of shape (N, in_channels, H, W); H and W must be at least
               16 so that the four 2x poolings leave a non-empty bottleneck.

        Returns:
            Tensor of shape (N, out_channels, H, W) with values in [0, 1].
        """
        encoder_1 = self.encoder_1(x)
        encoder_2 = self.encoder_2(F.max_pool2d(encoder_1, kernel_size=2))
        encoder_3 = self.encoder_3(F.max_pool2d(encoder_2, kernel_size=2))
        encoder_4 = self.encoder_4(F.max_pool2d(encoder_3, kernel_size=2))

        connect = self.connect(F.max_pool2d(encoder_4, kernel_size=2))

        decoder_1 = self.decoder_1(self._up_and_concat(connect, encoder_4))
        decoder_2 = self.decoder_2(self._up_and_concat(decoder_1, encoder_3))
        decoder_3 = self.decoder_3(self._up_and_concat(decoder_2, encoder_2))
        decoder_4 = self.decoder_4(self._up_and_concat(decoder_3, encoder_1))

        fc = self.fc(decoder_4)

        return torch.sigmoid(fc)

def boundary_mask(mask, dilation_iterations):
    """Return the band of pixels just outside the foreground of `mask`.

    The mask is dilated `dilation_iterations` times with a 3x3 square kernel and
    the original mask is subtracted. What remains is the ring of pixels within
    `dilation_iterations` steps of the foreground but not part of it.

    Args:
        mask: tensor whose last two dims are (H, W). Any leading dims (e.g.
            batch / channel) are handled slice by slice. Presumably a binary
            {0, 1} mask; for other values cv2 performs grey (max-filter) dilation.
        dilation_iterations: number of 3x3 dilation passes (the band width).

    Returns:
        float32 tensor with the same shape and device as `mask`.

    Raises:
        ValueError: if `mask` has fewer than 2 dimensions.
    """
    if mask.dim() < 2:
        raise ValueError("boundary_mask expects a tensor with at least 2 dims (H, W)")
    # Original used `np.unit8` (typo) -> AttributeError.
    kernel = np.ones((3, 3), np.uint8)
    # detach()/cpu() so this works for tensors that require grad or live on GPU;
    # float32 is a dtype cv2.dilate accepts (bool, for instance, is not).
    mask_np = mask.detach().cpu().numpy().astype(np.float32)
    # cv2.dilate works on one 2-D image, so flatten any leading dims into a
    # stack of H x W slices and dilate each one independently.
    flat = mask_np.reshape(-1, *mask_np.shape[-2:])
    dilated = np.empty_like(flat)
    for i, image in enumerate(flat):
        # Original passed `interations=` (typo), which cv2 rejects.
        dilated[i] = cv2.dilate(image, kernel, iterations=dilation_iterations)
    boundary = dilated.reshape(mask_np.shape) - mask_np

    return torch.tensor(boundary, dtype=torch.float32, device=mask.device)

def boundary_loss(y_true, y_pred, dilation_iterations=2):
    """Binary cross-entropy weighted by the band just outside the target foreground.

    Args:
        y_true: target mask with values in [0, 1], same shape and dtype as `y_pred`.
        y_pred: predicted probabilities in [0, 1] (e.g. the sigmoid output of UNET).
        dilation_iterations: width of the boundary band passed to
            `boundary_mask` (default 2, as originally hard-coded).

    Returns:
        Scalar tensor: mean over all elements of (per-pixel BCE * boundary mask).
        Pixels outside the band contribute 0 but still count in the mean.
    """
    # The weight map is derived from the labels only; put it on the
    # prediction's device/dtype so the elementwise product is valid.
    mask = boundary_mask(y_true, dilation_iterations).to(device=y_pred.device,
                                                         dtype=y_pred.dtype)
    # reduction='none' keeps one loss value per pixel. The original used
    # nn.BCELoss() with the default reduction='mean', which collapsed the loss
    # to a scalar before weighting, so the result was mean(BCE) * mean(mask)
    # and did no spatial weighting at all.
    per_pixel_bce = F.binary_cross_entropy(y_pred, y_true, reduction='none')
    boundary_loss_value = per_pixel_bce * mask

    return torch.mean(boundary_loss_value)