In [89]:
from torch import nn
from torch.nn import functional as F
import torch
import torchvision


class ConvBn2d(nn.Module):
    """2-D convolution without bias, followed by batch normalisation.

    No activation is applied here; callers wrap the output themselves
    (e.g. ``F.relu`` or an ``nn.ReLU`` in a ``Sequential``).

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the convolution and
            normalised by the BatchNorm layer.
        kernel_size, stride, padding: passed unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3,3), stride=(1,1), padding=(1,1)):
        super().__init__()
        # bias=False: the BatchNorm that follows has its own learnable shift.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, z):
        """Return ``bn(conv(z))``."""
        return self.bn(self.conv(z))





class Decoder(nn.Module):
    """Decoder stage: bilinear upsampling followed by two Conv-BN-ReLU layers.

    Args:
        in_channels: channels of the input feature map.
        channels: channels of the intermediate map between the two convolutions.
        out_channels: channels of the stage output.
        scale_factor: spatial upsampling factor applied before the convolutions.
            Defaults to 2, the value previously hard-coded, so existing callers
            are unaffected.
    """

    def __init__(self, in_channels, channels, out_channels, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor
        self.conv1 = ConvBn2d(in_channels, channels, kernel_size=3, padding=1)
        self.conv2 = ConvBn2d(channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Upsample ``x`` by ``self.scale_factor`` and apply conv1/conv2, each followed by ReLU."""
        # NOTE(review): bilinear corner alignment of F.upsample changed in PyTorch 0.4.0
        # (align_corners=False became the default, with a warning), and later releases
        # deprecate F.upsample in favour of F.interpolate. The call is kept as-is for
        # compatibility with the PyTorch version this notebook was written against;
        # pass align_corners explicitly when upgrading (the old '#False' note suggests False).
        x = F.upsample(x, scale_factor=self.scale_factor, mode='bilinear')
        x = F.relu(self.conv1(x), inplace=True)
        x = F.relu(self.conv2(x), inplace=True)
        return x

In [90]:
# ResNet-34 built without pretrained weights (pretrained=False); its stem (conv1/bn1/relu)
# and residual stages (layer1..layer4) are reused below as the U-Net encoder.
# NOTE(review): the input is later standardised with ImageNet mean/std, which usually goes
# together with pretrained weights — confirm that random initialisation is intended here.
resnet = torchvision.models.resnet34(pretrained=False)

In [91]:
# Encoder stem taken from ResNet-34: 7x7 stride-2 conv -> BN -> ReLU, 64 output channels.
# No pooling layer is included here; pooling is applied separately further down.
conv1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)

In [92]:
conv1  # display the stem's layers

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
  (2): ReLU(inplace)
)

In [93]:
# ResNet-34's residual stages, reused as encoder blocks (output channels noted per stage).
encoder2, encoder3, encoder4, encoder5 = (
    resnet.layer1,  # 64
    resnet.layer2,  # 128
    resnet.layer3,  # 256
    resnet.layer4,  # 512
)

In [94]:
encoder2  # display layer1's BasicBlocks

Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (relu): ReLU(inplace)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
  )
  (1): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (relu): ReLU(inplace)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
  )
  (2): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (relu): ReLU(inplace)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), pa

In [150]:
# Bottleneck applied to e5: two 3x3 Conv-BN-ReLU stages, 512 -> 512 -> 256 channels.
center = nn.Sequential(*[
    module
    for c_in, c_out in ((512, 512), (512, 256))
    for module in (ConvBn2d(c_in, c_out, kernel_size=3, padding=1), nn.ReLU(inplace=True))
])

In [151]:
center  # display the bottleneck layers

Sequential(
  (0): ConvBn2d(
    (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
  )
  (1): ReLU(inplace)
  (2): ConvBn2d(
    (conv): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
  )
  (3): ReLU(inplace)
)

In [123]:
# in_channels of each stage = channels of what it receives: previous output + encoder skip.
decoder5 = Decoder(in_channels=512 + 256, channels=512, out_channels=64)  # cat(c: 256, e5: 512)
decoder4 = Decoder(in_channels=256 + 64, channels=256, out_channels=64)   # cat(d5: 64, e4: 256)
decoder3 = Decoder(in_channels=128 + 64, channels=128, out_channels=64)   # cat(d4: 64, e3: 128)
decoder2 = Decoder(in_channels=64 + 64, channels=64, out_channels=64)     # cat(d3: 64, e2: 64)
decoder1 = Decoder(in_channels=64, channels=32, out_channels=64)          # d2 only, no skip

In [124]:
# Segmentation head on the hypercolumn (five 64-channel decoder outputs = 320 channels):
# 3x3 conv + ReLU, then a 1x1 conv down to a single logit map.
logit = nn.Sequential(
    nn.Conv2d(5 * 64, 64, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 1, kernel_size=1),
)

In [169]:
from torch.autograd import Variable
# Seed the global RNG so the dummy input is identical on every Restart & Run All; this
# also fixes the RNG state for later random ops in this notebook (e.g. the dropout on f).
torch.manual_seed(0)
x = Variable(torch.rand(1, 1, 256, 256))  # single-channel 256x256 dummy image, values in [0, 1)

In [170]:
x  # display the raw dummy input

Variable containing:
( 0 , 0 ,.,.) = 
  6.8921e-01  9.6380e-01  6.0601e-01  ...   1.2621e-01  7.9491e-01  8.9094e-01
  7.0223e-01  4.0470e-01  7.0927e-01  ...   6.6238e-01  9.0533e-01  9.5152e-01
  2.6702e-01  7.9851e-01  3.2655e-01  ...   1.8036e-01  1.5876e-01  8.1081e-01
                 ...                   ⋱                   ...                
  1.5830e-01  4.9696e-01  2.8194e-01  ...   2.7953e-01  6.6735e-01  6.4035e-01
  4.1907e-02  3.4205e-01  6.6111e-01  ...   4.9074e-01  4.0352e-01  9.1868e-01
  1.9086e-01  8.4836e-01  7.4859e-01  ...   7.7303e-01  7.6596e-01  9.0447e-02
[torch.FloatTensor of size 1x1x256x256]

In [171]:
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# Replicate the single grey channel into three channels, each standardised with the
# statistics above (they match torchvision's ImageNet mean/std).
# NOTE: this rebinds x, so re-running the cell would triple the channels again.
x = torch.cat([(x - m) / s for m, s in zip(mean, std)], 1)


In [172]:
x.shape  # (batch, 3 standardised channels, H, W)

torch.Size([1, 3, 256, 256])

In [173]:
conv1  # re-display the stem before applying it to x

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
  (2): ReLU(inplace)
)

In [174]:
x = conv1(x)  # stem: 3 -> 64 channels; the stride-2 conv halves H and W

In [175]:
x.shape

torch.Size([1, 64, 128, 128])

In [176]:
# conv1 ends at the ReLU without pooling, so halve the resolution here (2x2 window, stride 2).
x = F.max_pool2d(x, kernel_size=2, stride=2)
x.shape

torch.Size([1, 64, 64, 64])

In [177]:
e2 = encoder2(x)  # layer1: 64 channels, resolution unchanged
print('e2', e2.shape)

e2 torch.Size([1, 64, 64, 64])


In [178]:
e3 = encoder3(e2)  # layer2: 128 channels, resolution halved
print('e3', e3.shape)

e3 torch.Size([1, 128, 32, 32])


In [179]:
e4 = encoder4(e3)  # layer3: 256 channels, resolution halved
print('e4', e4.shape)

e4 torch.Size([1, 256, 16, 16])


In [180]:
e5 = encoder5(e4)  # layer4: 512 channels, resolution halved
print('e5', e5.shape)

e5 torch.Size([1, 512, 8, 8])


In [181]:
c = center(e5)  # bottleneck: 512 -> 256 channels, resolution unchanged
print('c', c.shape)

c torch.Size([1, 256, 8, 8])


In [182]:
d5 = decoder5(torch.cat([c, e5], dim=1))  # 256 + 512 = 768 input channels

In [183]:
d5.shape

torch.Size([1, 64, 16, 16])

In [184]:
d4 = decoder4(torch.cat([d5, e4], dim=1))  # 64 + 256 = 320 input channels

In [185]:
d4.shape

torch.Size([1, 64, 32, 32])

In [186]:
d3 = decoder3(torch.cat([d4, e3], dim=1))  # 64 + 128 = 192 input channels

In [187]:
d3.shape

torch.Size([1, 64, 64, 64])

In [188]:
d2 = decoder2(torch.cat([d3, e2], dim=1))  # 64 + 64 = 128 input channels

In [189]:
d2.shape

torch.Size([1, 64, 128, 128])

In [190]:
d1 = decoder1(d2)  # final stage: no skip connection, upsample back to input resolution
d1.shape

torch.Size([1, 64, 256, 256])

In [191]:
# Hypercolumn: bring every decoder output up to d1's resolution and stack them along
# channels (5 x 64 = 320 channels, the input width of `logit`).
# NOTE(review): F.upsample is deprecated in later PyTorch releases in favour of
# F.interpolate, and bilinear align_corners defaults changed in 0.4.0 — pin it when upgrading.
f = torch.cat(
    [d1] + [
        F.upsample(d, scale_factor=s, mode='bilinear')
        for d, s in ((d2, 2), (d3, 4), (d4, 8), (d5, 16))
    ],
    1,
)

In [192]:
f.shape

torch.Size([1, 320, 256, 256])

In [193]:
# F.dropout defaults to training=True, i.e. it always drops units. Inside an nn.Module's
# forward this should be training=self.training so that .eval() disables it.
f = F.dropout(f, p=0.5, training=True)

In [194]:
logit_f = logit(f)  # single-channel logit map at the hypercolumn's resolution
logit_f.shape

torch.Size([1, 1, 256, 256])