In [1]:
import torch
import torch.nn as nn

In [2]:
def conv_2(inc,outc):
    conv = nn.Sequential(
        nn.Conv2d(inc,outc,kernel_size=3),
        nn.ReLU(inplace=True),
        nn.Conv2d(outc,outc,kernel_size=3),
        nn.ReLU(inplace=True)
    )
    return conv

In [3]:
def crop(original, new):
    new_size = new.size()[2]
    original_size = original.size()[2]
    diff = original_size - new_size
    diff = diff // 2
    return original[:,:,diff:original_size-diff,diff:original_size-diff]

In [4]:
class U_net(nn.Module):
    def __init__(self):
        super(U_net,self).__init__()
        self.maxpool2x2 = nn.MaxPool2d(kernel_size=2,stride=2)
        self.conv1down = conv_2(1,64)
        self.conv2down = conv_2(64,128)
        self.conv3down = conv_2(128,256)
        self.conv4down = conv_2(256,512)
        self.conv5down = conv_2(512,1024)
        
        self.conv_Trans1 = nn.ConvTranspose2d(in_channels=1024,out_channels=512,kernel_size=2,stride=2)
        self.conv_up1 = conv_2(1024,512)
        self.conv_Trans2 = nn.ConvTranspose2d(in_channels=512,out_channels=256,kernel_size=2,stride=2)
        self.conv_up2 = conv_2(512,256)
        self.conv_Trans3 = nn.ConvTranspose2d(in_channels=256,out_channels=128,kernel_size=2,stride=2)
        self.conv_up3 = conv_2(256,128)
        self.conv_Trans4 = nn.ConvTranspose2d(in_channels=128,out_channels=64,kernel_size=2,stride=2)
        self.conv_up4 = conv_2(128,64)

        self.output = nn.Conv2d(64,2,kernel_size=1)
    def forward(self,img):
        #encode
        # b c h w
        e1 = self.conv1down(img)
        e2 = self.maxpool2x2(e1)
        print(e1.size())
        e3 = self.conv2down(e2)
        e4 = self.maxpool2x2(e3)
        print(e3.size())
        
        e5 = self.conv3down(e4)
        e6 = self.maxpool2x2(e5)
        print(e5.size())
        
        e7 = self.conv4down(e6)
        e8 = self.maxpool2x2(e7)
        print(e7.size())
    
        e9 = self.conv5down(e8)
        print(e9.size())

        #decode
        x = self.conv_Trans1(e9)
        y =crop(e7,x)
        x = self.conv_up1(torch.cat([x,y],1))

        x = self.conv_Trans2(x)
        y =crop(e5,x)
        x = self.conv_up2(torch.cat([x,y],1))

        x = self.conv_Trans3(x)
        y =crop(e3,x)
        x = self.conv_up3(torch.cat([x,y],1))

        x = self.conv_Trans4(x)
        y =crop(e1,x)
        x = self.conv_up4(torch.cat([x,y],1))
        print(x.size())
        
        y = self.output(x) 
        print(y.size())
        return y




In [6]:
if __name__ == "__main__":
    img = torch.rand((1,1,572,572))
    modl = U_net()
    print(modl(img))

torch.Size([1, 64, 568, 568])
torch.Size([1, 128, 280, 280])
torch.Size([1, 256, 136, 136])
torch.Size([1, 512, 64, 64])
torch.Size([1, 1024, 28, 28])
torch.Size([1, 64, 388, 388])
torch.Size([1, 2, 388, 388])
tensor([[[[ 0.1126,  0.1124,  0.1087,  ...,  0.1092,  0.1077,  0.1113],
          [ 0.1076,  0.1103,  0.1110,  ...,  0.1073,  0.1127,  0.1102],
          [ 0.1156,  0.1064,  0.1093,  ...,  0.1148,  0.1146,  0.1050],
          ...,
          [ 0.1100,  0.1088,  0.1131,  ...,  0.1147,  0.1125,  0.1126],
          [ 0.1063,  0.1117,  0.0994,  ...,  0.1057,  0.1088,  0.1093],
          [ 0.1103,  0.1103,  0.1166,  ...,  0.1118,  0.1171,  0.1114]],

         [[ 0.0006, -0.0011, -0.0045,  ..., -0.0039, -0.0062, -0.0061],
          [-0.0040, -0.0042,  0.0014,  ..., -0.0035, -0.0040,  0.0004],
          [ 0.0046, -0.0112, -0.0010,  ..., -0.0010,  0.0002, -0.0094],
          ...,
          [-0.0057, -0.0037,  0.0011,  ...,  0.0005, -0.0070, -0.0054],
          [ 0.0007, -0.0019, -0.0099, 