In [1]:
import torch
import torch.nn as nn

In [2]:
def double_conv(in_channels, out_channels):
    """Build a block of two 3x3 convolutions, each followed by an in-place ReLU.

    Neither convolution is given a ``padding`` argument, so each one shrinks
    the spatial size by 2 pixels per dimension (4 pixels for the whole block).

    Args:
        in_channels: Number of channels entering the first convolution.
        out_channels: Number of channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` laid out as ``[Conv2d, ReLU, Conv2d, ReLU]``.
    """
    layers = []
    channels = in_channels
    for _ in range(2):
        layers.append(nn.Conv2d(channels, out_channels, kernel_size=3))
        layers.append(nn.ReLU(inplace=True))
        # The second convolution consumes the output of the first one.
        channels = out_channels
    return nn.Sequential(*layers)

In [3]:
def crop_image(tensor, target_size):
    """Center-crop ``tensor`` so its height and width match another tensor.

    Height and width are handled independently, so non-square inputs are
    cropped correctly. When the size difference along an axis is odd, the
    extra row/column is removed from the bottom/right, which guarantees the
    result has exactly the target's spatial size (required for ``torch.cat``).

    Args:
        tensor: Tensor to crop, indexed as (batch size, channel, height, width).
        target_size: Reference *tensor* (not a size tuple, despite the name)
            whose dims 2 and 3 give the desired height and width.

    Returns:
        A view of ``tensor`` with the same batch and channel dimensions and
        the spatial size of ``target_size``.

    Raises:
        ValueError: If the target is larger than ``tensor`` along height or
            width, since cropping cannot enlarge a tensor.
    """
    tensor_h, tensor_w = tensor.size()[2], tensor.size()[3]
    target_h, target_w = target_size.size()[2], target_size.size()[3]

    if target_h > tensor_h or target_w > tensor_w:
        raise ValueError(
            "cannot crop a {}x{} tensor to the larger size {}x{}".format(
                tensor_h, tensor_w, target_h, target_w
            )
        )

    # Offsets of the crop window; floor division keeps the window centered
    # (biased to the top/left by one pixel when the difference is odd).
    top = (tensor_h - target_h) // 2
    left = (tensor_w - target_w) // 2

    # Slice by offset + target length so the output size is always exact.
    return tensor[:, :, top:top + target_h, left:left + target_w]

In [6]:
class UNet(nn.Module):
    """U-Net style encoder/decoder built from unpadded 3x3 convolutions.

    The encoder applies five ``double_conv`` blocks (64, 128, 256, 512, 1024
    channels) separated by 2x2 max pooling. The decoder upsamples four times
    with 2x2 transposed convolutions; after each upsampling the matching
    encoder feature map is center-cropped with ``crop_image``, concatenated
    along the channel dimension and passed through another ``double_conv``.
    A final 1x1 convolution maps the 64 decoder channels to ``out_channels``.

    Because the convolutions are unpadded, the output is spatially smaller
    than the input (e.g. a 572x572 input yields a 388x388 output).

    Args:
        in_channels: Number of channels of the input image. Defaults to 1.
        out_channels: Number of channels of the output map. Defaults to 2.
        verbose: When true (the default, matching the original behaviour),
            ``forward`` prints the size of every intermediate feature map.
            Pass ``False`` to silence the trace, e.g. inside a training loop.
    """

    def __init__(self, in_channels=1, out_channels=2, verbose=True):
        super().__init__()
        self.verbose = verbose

        # Attribute names are kept stable so existing state_dict keys still load.
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down_conv_1 = double_conv(in_channels, 64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 1024)

        # Each up_conv takes twice the up_trans output channels because the
        # cropped encoder feature map is concatenated before the convolution.
        self.up_trans_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, stride=2, kernel_size=2)
        self.up_conv_1 = double_conv(1024, 512)
        self.up_trans_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, stride=2, kernel_size=2)
        self.up_conv_2 = double_conv(512, 256)
        self.up_trans_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, stride=2, kernel_size=2)
        self.up_conv_3 = double_conv(256, 128)
        self.up_trans_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, stride=2, kernel_size=2)
        self.up_conv_4 = double_conv(128, 64)

        self.out = nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=1)

    def _trace(self, label, tensor):
        """Print ``label`` and the size of ``tensor`` when tracing is enabled."""
        if self.verbose:
            print(label, tensor.size())

    def forward(self, image):
        """Run the network on ``image``.

        Args:
            image: Input tensor indexed as (batch size, channel, height, width).

        Returns:
            The output of the final 1x1 convolution, with ``out_channels``
            channels.
        """
        # Encoder: keep every pre-pooling feature map for the skip connections.
        x1 = self.down_conv_1(image)
        self._trace("x1", x1)
        x3 = self.down_conv_2(self.max_pool_2x2(x1))
        self._trace("x3", x3)
        x5 = self.down_conv_3(self.max_pool_2x2(x3))
        self._trace("x5", x5)
        x7 = self.down_conv_4(self.max_pool_2x2(x5))
        self._trace("x7", x7)
        x9 = self.down_conv_5(self.max_pool_2x2(x7))
        self._trace("x9", x9)

        # Decoder: deepest skip connection first.
        decoder_stages = (
            (self.up_trans_1, self.up_conv_1, x7),
            (self.up_trans_2, self.up_conv_2, x5),
            (self.up_trans_3, self.up_conv_3, x3),
            (self.up_trans_4, self.up_conv_4, x1),
        )
        x = x9
        for up_trans, up_conv, skip in decoder_stages:
            x = up_trans(x)
            # The encoder map is larger than the upsampled one, so crop it
            # to the same spatial size before concatenating on channels.
            y = crop_image(skip, x)
            x = up_conv(torch.cat([x, y], 1))
            self._trace("x", x)

        x = self.out(x)
        self._trace("x", x)
        return x


In [7]:
# Seed the RNG so the random input and the freshly initialised weights are
# the same on every "Restart & Run All".
torch.manual_seed(0)

image = torch.rand((1, 1, 572, 572))  # batch size, channel, height, width
model = UNet()

# Inference-only smoke test: no gradients are needed, so skip building the
# autograd graph for this large input.
with torch.no_grad():
    print(model(image))

x1 torch.Size([1, 64, 568, 568])
x3 torch.Size([1, 128, 280, 280])
x5 torch.Size([1, 256, 136, 136])
x7 torch.Size([1, 512, 64, 64])
x9 torch.Size([1, 1024, 28, 28])
x torch.Size([1, 512, 52, 52])
x torch.Size([1, 256, 100, 100])
x torch.Size([1, 128, 196, 196])
x torch.Size([1, 64, 388, 388])
x torch.Size([1, 2, 388, 388])
tensor([[[[ 0.0237,  0.0233,  0.0158,  ...,  0.0246,  0.0224,  0.0176],
          [ 0.0228,  0.0208,  0.0201,  ...,  0.0229,  0.0198,  0.0223],
          [ 0.0197,  0.0222,  0.0244,  ...,  0.0219,  0.0231,  0.0216],
          ...,
          [ 0.0192,  0.0225,  0.0179,  ...,  0.0258,  0.0200,  0.0215],
          [ 0.0263,  0.0220,  0.0241,  ...,  0.0173,  0.0156,  0.0264],
          [ 0.0215,  0.0209,  0.0209,  ...,  0.0186,  0.0218,  0.0251]],

         [[-0.0510, -0.0535, -0.0521,  ..., -0.0547, -0.0517, -0.0505],
          [-0.0514, -0.0510, -0.0481,  ..., -0.0533, -0.0469, -0.0516],
          [-0.0511, -0.0486, -0.0520,  ..., -0.0524, -0.0517, -0.0476],
         

In [3]:
# https://www.mathworks.com/help/vision/ug/getting-started-with-semantic-segmentation-using-deep-learning.html
# https://medium.com/@arthur_ouaknine/review-of-deep-learning-algorithms-for-image-semantic-segmentation-509a600f7b57
# difference between classification/object detection/image segmentation - # https://developer.qualcomm.com/software/qualcomm-neural-processing-sdk/learning-resources/image-segmentation-deeplab-neural-processing-sdk/classification-object-detection-segmentation
