In [2]:
import torch
import torch.nn as nn

In [5]:
def double_conv(in_channels, out_channels):
    conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3),
        nn.ReLU(inplace=True)
    )
    return conv

In [28]:
def crop_image(tensor, target_size):
    tensor_size = tensor.size()[2]
    target_size = target_size.size()[2]
    delta = tensor_size - target_size
    delta = delta // 2
    return tensor[:, :, delta:tensor_size-delta, delta:tensor_size-delta]

In [41]:
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down_conv_1 = double_conv(1,64)
        self.down_conv_2 = double_conv(64,128)
        self.down_conv_3 = double_conv(128,256)
        self.down_conv_4 = double_conv(256,512)
        self.down_conv_5 = double_conv(512,1024)

        self.up_trans_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, stride=2, kernel_size=2)
        self.up_conv_1 = double_conv(1024, 512)
        self.up_trans_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, stride=2, kernel_size=2)
        self.up_conv_2 = double_conv(512, 256)
        self.up_trans_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, stride=2, kernel_size=2)
        self.up_conv_3 = double_conv(256, 128)
        self.up_trans_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, stride=2, kernel_size=2)
        self.up_conv_4 = double_conv(128, 64)


        self.out = nn.Conv2d(in_channels=64, out_channels=2, stride=2, kernel_size=1)
        
    def forward(self, image): #batch size, channel, height, width
        # encoder 
        x1 = self.down_conv_1(image);           print("x1", x1.size())
        x2 = self.max_pool_2x2(x1);           
        x3 = self.down_conv_2(x2);              print("x3", x3.size())
        x4 = self.max_pool_2x2(x3);           
        x5 = self.down_conv_3(x4);              print("x5", x5.size())
        x6 = self.max_pool_2x2(x5);           
        x7 = self.down_conv_4(x6);              print("x7", x7.size())
        x8 = self.max_pool_2x2(x7);           
        x9 = self.down_conv_5(x8);              print("x9", x9.size())


        # decoder
        x = self.up_trans_1(x9);                print("x", x.size())
        y = crop_image(x7, x);                  print("y", y.size())   
        x = self.up_conv_1(torch.cat([x,y],1)); print("x", x.size()) 

        x = self.up_trans_2(x);                 print("x", x.size())
        y = crop_image(x5, x);                  print("y", y.size())   
        x = self.up_conv_2(torch.cat([x,y],1)); print("x", x.size()) 

        x = self.up_trans_3(x);                 print("x", x.size())
        y = crop_image(x3, x);                  print("y", y.size())   
        x = self.up_conv_3(torch.cat([x,y],1)); print("x", x.size()) 

        x = self.up_trans_4(x);                 print("x", x.size())
        y = crop_image(x1, x);                  print("y", y.size())   
        x = self.up_conv_4(torch.cat([x,y],1)); print("x", x.size()) 

        x = self.out(x);                        print("x", x.size()) 
        return x


In [42]:
image=torch.rand((1,1,572,572))
model = UNet()
print(model(image))

x1 torch.Size([1, 64, 568, 568])
x3 torch.Size([1, 128, 280, 280])
x5 torch.Size([1, 256, 136, 136])
x7 torch.Size([1, 512, 64, 64])
x9 torch.Size([1, 1024, 28, 28])
x torch.Size([1, 512, 56, 56])
y torch.Size([1, 512, 56, 56])
x torch.Size([1, 512, 52, 52])
x torch.Size([1, 256, 104, 104])
y torch.Size([1, 256, 104, 104])
x torch.Size([1, 256, 100, 100])
x torch.Size([1, 128, 200, 200])
y torch.Size([1, 128, 200, 200])
x torch.Size([1, 128, 196, 196])
x torch.Size([1, 64, 392, 392])
y torch.Size([1, 64, 392, 392])
x torch.Size([1, 64, 388, 388])
x torch.Size([1, 2, 194, 194])
tensor([[[[0.0641, 0.0668, 0.0605,  ..., 0.0616, 0.0639, 0.0606],
          [0.0697, 0.0661, 0.0627,  ..., 0.0616, 0.0625, 0.0629],
          [0.0645, 0.0609, 0.0643,  ..., 0.0622, 0.0660, 0.0617],
          ...,
          [0.0617, 0.0684, 0.0721,  ..., 0.0654, 0.0618, 0.0669],
          [0.0638, 0.0607, 0.0615,  ..., 0.0587, 0.0626, 0.0638],
          [0.0599, 0.0586, 0.0658,  ..., 0.0650, 0.0633, 0.0586]],

   

In [3]:
# https://www.mathworks.com/help/vision/ug/getting-started-with-semantic-segmentation-using-deep-learning.html
# https://medium.com/@arthur_ouaknine/review-of-deep-learning-algorithms-for-image-semantic-segmentation-509a600f7b57
# differenece between classificaiton/object detection/image segmentaion - # https://developer.qualcomm.com/software/qualcomm-neural-processing-sdk/learning-resources/image-segmentation-deeplab-neural-processing-sdk/classification-object-detection-segmentation
