In [None]:
import torch
import torch.nn as nn

def double_conv(in_channel, out_channel):
    """Build two unpadded 3x3 convolutions, each followed by an in-place ReLU.

    No padding is used, so each convolution trims 2 pixels from every spatial
    dimension and the block as a whole shrinks H and W by 4.

    Args:
        in_channel: number of channels in the incoming feature map.
        out_channel: number of channels produced by both convolutions.

    Returns:
        ``nn.Sequential(Conv2d, ReLU, Conv2d, ReLU)``.
    """
    layers = []
    # First conv maps in_channel -> out_channel, second keeps out_channel.
    for source_channels in (in_channel, out_channel):
        layers.append(nn.Conv2d(source_channels, out_channel, kernel_size=3))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)



def crop_img(tensor, target_tensor):
    """Center-crop ``tensor`` so its spatial size matches ``target_tensor``.

    Used for U-Net skip connections: the encoder feature map is larger than
    the upsampled decoder map (unpadded convolutions shrink every stage), so
    it is cropped around its center before concatenation.

    Height (dim 2) and width (dim 3) are handled independently. When the size
    difference is odd, the extra row/column is dropped from the bottom/right,
    so the result always has exactly the target size.

    Args:
        tensor: feature map to crop, indexed as ``(N, C, H, W)``.
        target_tensor: tensor whose dims 2 and 3 give the desired size.

    Returns:
        A view of ``tensor`` with spatial size equal to dims 2 and 3 of
        ``target_tensor``.

    Raises:
        ValueError: if the target is larger than ``tensor`` in either
            spatial dimension.
    """
    target_h, target_w = target_tensor.size()[2], target_tensor.size()[3]
    tensor_h, tensor_w = tensor.size()[2], tensor.size()[3]
    if target_h > tensor_h or target_w > tensor_w:
        raise ValueError(
            f"cannot crop {tensor_h}x{tensor_w} to larger size {target_h}x{target_w}"
        )
    # Floor the offset, then take exactly target_h/target_w elements from it.
    # Trimming delta//2 from both ends would be one pixel too large whenever
    # the difference is odd.
    top = (tensor_h - target_h) // 2
    left = (tensor_w - target_w) // 2
    return tensor[:, :, top:top + target_h, left:left + target_w]

class Unet(nn.Module):
    """U-Net segmentation network with unpadded ("valid") convolutions.

    The encoder applies five ``double_conv`` stages separated by 2x2 max
    pooling (64 -> 128 -> 256 -> 512 -> 1024 channels). The decoder
    upsamples with 2x2 stride-2 transposed convolutions, concatenates the
    center-cropped encoder feature map of matching depth along the channel
    axis, and applies another ``double_conv``. A final 1x1 convolution maps
    the 64 decoder channels to per-class scores.

    Because every 3x3 convolution is unpadded, the output is spatially
    smaller than the input: a 572x572 input yields a 388x388 output.

    Args:
        in_channels: number of channels in the input image (default 1).
        n_classes: number of output channels / classes (default 2).
    """

    def __init__(self, in_channels=1, n_classes=2):
        super().__init__()

        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down_conv_1 = double_conv(in_channels, 64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 1024)

        # Each up_conv sees 2x the channels of its up_trans output because the
        # cropped skip connection is concatenated onto it.
        self.up_trans_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
        self.up_conv_1 = double_conv(1024, 512)

        self.up_trans_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
        self.up_conv_2 = double_conv(512, 256)

        self.up_trans_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.up_conv_3 = double_conv(256, 128)

        self.up_trans_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.up_conv_4 = double_conv(128, 64)

        # 1x1 classifier: maps channels only and leaves the spatial size
        # unchanged. (A 2x2 kernel would shave one more pixel, e.g. 388 -> 387.)
        self.out = nn.Conv2d(in_channels=64, out_channels=n_classes, kernel_size=1)

    def forward(self, image):
        """Compute per-pixel class scores for ``image``.

        Args:
            image: input batch, assumed ``(N, in_channels, H, W)``. H and W
                must be large enough to survive four poolings and the
                unpadded convolutions (e.g. 572).

        Returns:
            Tensor of shape ``(N, n_classes, H', W')`` with ``H' < H`` and
            ``W' < W`` (388x388 for a 572x572 input).
        """
        # Encoder: keep pre-pooling features (x1, x3, x5, x7) for skips.
        x1 = self.down_conv_1(image)
        x2 = self.max_pool_2x2(x1)
        x3 = self.down_conv_2(x2)
        x4 = self.max_pool_2x2(x3)
        x5 = self.down_conv_3(x4)
        x6 = self.max_pool_2x2(x5)
        x7 = self.down_conv_4(x6)
        x8 = self.max_pool_2x2(x7)
        x9 = self.down_conv_5(x8)  # bottleneck: (N, 1024, 28, 28) for 572 input

        # Decoder: upsample, crop the matching encoder map, concat on channels.
        x = self.up_trans_1(x9)
        y = crop_img(x7, x)
        x = self.up_conv_1(torch.cat([x, y], 1))

        x = self.up_trans_2(x)
        y = crop_img(x5, x)
        x = self.up_conv_2(torch.cat([x, y], 1))

        x = self.up_trans_3(x)
        y = crop_img(x3, x)
        x = self.up_conv_3(torch.cat([x, y], 1))

        x = self.up_trans_4(x)
        y = crop_img(x1, x)
        x = self.up_conv_4(torch.cat([x, y], 1))

        return self.out(x)
        
        

if __name__ == "__main__":
    # Smoke test: run one random single-channel 572x572 image through the
    # network and report the output shape.
    image = torch.rand((1, 1, 572, 572))
    model = Unet()
    with torch.no_grad():  # inference only; don't build an autograd graph
        output = model(image)
    print(output.size())

torch.Size([1, 2, 387, 387])
tensor([[[[ 0.0182,  0.0133,  0.0200,  ...,  0.0208,  0.0231,  0.0165],
          [ 0.0180,  0.0178,  0.0178,  ...,  0.0162,  0.0201,  0.0189],
          [ 0.0195,  0.0185,  0.0203,  ...,  0.0224,  0.0198,  0.0206],
          ...,
          [ 0.0152,  0.0166,  0.0152,  ...,  0.0222,  0.0175,  0.0189],
          [ 0.0175,  0.0211,  0.0163,  ...,  0.0192,  0.0117,  0.0212],
          [ 0.0172,  0.0185,  0.0171,  ...,  0.0192,  0.0194,  0.0252]],

         [[ 0.0031,  0.0057,  0.0121,  ...,  0.0045,  0.0058,  0.0102],
          [-0.0013,  0.0045,  0.0005,  ...,  0.0080,  0.0023,  0.0048],
          [ 0.0046,  0.0024,  0.0039,  ...,  0.0066,  0.0030,  0.0053],
          ...,
          [-0.0004,  0.0059,  0.0053,  ...,  0.0053,  0.0087,  0.0048],
          [-0.0005,  0.0014,  0.0031,  ...,  0.0029,  0.0026,  0.0019],
          [ 0.0022,  0.0061,  0.0036,  ...,  0.0057,  0.0086,  0.0024]]]],
       grad_fn=<ConvolutionBackward0>)
