In [1]:
import torch
import torch.nn as nn

In [16]:
def double_conv(no_in, no_out):
    conv = nn.Sequential(nn.Conv2d(no_in, no_out, kernel_size=3), nn.ReLU(inplace=True),
                  nn.Conv2d(no_out, no_out, kernel_size=3), nn.ReLU(inplace=True)
                 )
    return conv

def crop_img(tensor, target_tensor):
    target_size = target_tensor.size()[2]
    tensor_size = target_tensor.size()[2]
    delta = tensor_size - target_size
    delta = delta // 2
    return tensor[:, :, delta:tensor_size-delta, delta:tensor_size-delta]


class UNet(nn.Module):
    
    def __init__(self, no_in=1, no_out=64):
        super(UNet, self).__init__()
        self.no_in=no_in
        self.no_out_1=no_out
        self.no_out_2 = self.no_out_1*2
        self.no_out_3 = self.no_out_2*2
        self.no_out_4 = self.no_out_3*2
        self.no_out_5 = self.no_out_4*2
        self.max_poool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down_conv_1 = double_conv(self.no_in, self.no_out_1)
        self.down_conv_2 = double_conv(self.no_out_1, self.no_out_2)
        self.down_conv_3 = double_conv(self.no_out_2, self.no_out_3)
        self.down_conv_4 = double_conv(self.no_out_3, self.no_out_4)
        self.down_conv_5 = double_conv(self.no_out_4, self.no_out_5)
        
        self.up_trans_1 = nn.ConvTranspose2d(
            in_channels=self.no_out_5, out_channels=self.no_out_4,
            kernel_size=2, stride=2)
        self.up_conv_1 = double_conv(self.no_out_5, self.no_out_4)
        
        self.up_trans_2 = nn.ConvTranspose2d(
            in_channels=self.no_out_4, out_channels=self.no_out_3,
            kernel_size=2, stride=2)
        
        self.up_conv_2 = double_conv(self.no_out_4, self.no_out_3)
        
        self.up_trans_3 = nn.ConvTranspose2d(
            in_channels=self.no_out_3, out_channels=self.no_out_2,
            kernel_size=2, stride=2)
        
        self.up_conv_3 = double_conv(self.no_out_3, self.no_out_2)
        
        self.up_trans_4 = nn.ConvTranspose2d(
            in_channels=self.no_out_2, out_channels=self.no_out_1,
            kernel_size=2, stride=2)
        
        self.up_conv_4 = double_conv(self.no_out_2, self.no_out_1)
        
        self.out = nn.Conv2d(
        in_channels=64,
        out_channels=2,
        kernel_size=1)

        
    def foward(self, image):
        x1 = self.down_conv_1(image)
        x2 = self.max_poool_2x2(x1)
        x3 = self.down_conv_2(x2)
        x4 = self.max_poool_2x2(x3)
        x5 = self.down_conv_3(x4)
        x6 = self.max_poool_2x2(x5)
        x7 = self.down_conv_4(x6)
        x8 = self.max_poool_2x2(x7)
        x9 = self.down_conv_5(x8)
        x = self.up_trans_1(x9)
        y = crop_img(x7, x)
        x = self.up_conv_1(torch.cat([x, y], 1))
        x = self.up_trans_2(x)
        y = crop_img(x5, x)
        x = self.up_conv_2(torch.cat([x, y], 1))
        x = self.up_trans_3(x)
        y = crop_img(x3, x)
        x = self.up_conv_3(torch.cat([x, y], 1))
        x = self.up_trans_4(x)
        y = crop_img(x1, x)
        x = self.up_conv_4(torch.cat([x, y], 1))
        x = self.out(x)
        print(x.size())
        return x

        
        
if __name__ == "__main__":
    image = torch.rand((1, 1, 572, 572))
    model = UNet()
    print(model.foward(image))
   

torch.Size([1, 2, 388, 388])
tensor([[[[0.0658, 0.0640, 0.0649,  ..., 0.0620, 0.0600, 0.0644],
          [0.0629, 0.0600, 0.0628,  ..., 0.0643, 0.0656, 0.0639],
          [0.0627, 0.0632, 0.0623,  ..., 0.0619, 0.0613, 0.0613],
          ...,
          [0.0602, 0.0581, 0.0554,  ..., 0.0654, 0.0569, 0.0615],
          [0.0565, 0.0599, 0.0588,  ..., 0.0601, 0.0661, 0.0611],
          [0.0579, 0.0613, 0.0605,  ..., 0.0559, 0.0647, 0.0592]],

         [[0.0985, 0.0996, 0.0985,  ..., 0.1012, 0.1013, 0.1004],
          [0.1002, 0.1002, 0.1014,  ..., 0.0955, 0.0957, 0.0976],
          [0.0955, 0.0943, 0.0952,  ..., 0.0975, 0.1000, 0.0969],
          ...,
          [0.1032, 0.1005, 0.1015,  ..., 0.0963, 0.1044, 0.0985],
          [0.0977, 0.0989, 0.0982,  ..., 0.1026, 0.0932, 0.1012],
          [0.0986, 0.1019, 0.1022,  ..., 0.0972, 0.0959, 0.0969]]]],
       grad_fn=<MkldnnConvolutionBackward>)
