In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

In [3]:
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = (DoubleConv(n_channels, 64))
        self.down1 = (Down(64, 128))
        self.down2 = (Down(128, 256))
        self.down3 = (Down(256, 512))
        factor = 2 if bilinear else 1
        self.down4 = (Down(512, 1024 // factor))
        self.up1 = (Up(1024, 512 // factor, bilinear))
        self.up2 = (Up(512, 256 // factor, bilinear))
        self.up3 = (Up(256, 128 // factor, bilinear))
        self.up4 = (Up(128, 64, bilinear))
        self.outc = (OutConv(64, n_classes))

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

In [3]:
class Block(nn.Module):
    """(convolution => [BN] => ReLU) * 3"""
    # Note: CSN Bottleneck doesn't use ReLU for the last conv.

    def __init__(self, in_channels, out_channels, mid_channels=None, conv2_stride=1):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.triple_conv = nn.Sequential(
            nn.Conv3d(in_channels, mid_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm3d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(mid_channels, mid_channels, kernel_size=3, stride=conv2_stride, padding=1,
                      bias=False, groups=mid_channels),
            nn.BatchNorm3d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(mid_channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.triple_conv(x)

In [41]:
class Layer(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels, n_blocks, conv2_stride=2):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        
        modules = []
        modules.append(Block(in_channels, in_channels, mid_channels))
        
        for i in range(n_blocks-2):
            modules.append(Block(in_channels, in_channels, mid_channels))
        
        modules.append(Block(in_channels, out_channels, mid_channels, conv2_stride=conv2_stride))
        
        self.block = nn.Sequential(*modules)
        
    def forward(self, x):
        return self.block(x)

In [42]:
class Up(nn.Module):
    """Upscaling then triple conv"""

    def __init__(self, in_channels, out_channels, mid_channels, n_blocks, conv2_stride=2):
        super().__init__()
        self.up = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = Layer(in_channels, out_channels, mid_channels, n_blocks, conv2_stride)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

#         print(x2.shape)
#         print(x1.shape)
        
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

In [38]:
class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=(3,7,7), stride=(1,2,2), padding=(1,3,3), bias=False)

    def forward(self, x):
        return self.conv(x)

In [28]:
# x1 = torch.rand((1,64,32,112,112))
# x2 = torch.rand((1,256,32,56,56))
# x3 = torch.rand((1,512,16,28,28))
x4 = torch.rand((1,1024,8,14,14))
x5 = torch.rand((1,2048,4,7,7))

In [43]:
up1 = Up(in_channels=2048, out_channels=1024, mid_channels=512, n_blocks=3)
# up2 = Up(1024, 512, 6) 
# up3 = Up(512, 256, 4)
# up4 = Up(256, 128, 3) 
# outc = OutConv(128, 3)

In [46]:
up1

Up(
  (up): ConvTranspose3d(2048, 1024, kernel_size=(2, 2, 2), stride=(2, 2, 2))
  (conv): Layer(
    (block): Sequential(
      (0): Block(
        (triple_conv): Sequential(
          (0): Conv3d(2048, 512, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
          (1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), groups=512, bias=False)
          (4): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
          (6): Conv3d(512, 2048, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
          (7): BatchNorm3d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (8): ReLU(inplace=True)
        )
      )
      (1): Block(
        (triple_conv): Sequential(
          (0): Conv3d(2048, 512, kernel_size=(1, 1, 1), stride=(1, 1, 1), 

In [45]:
x = up1(x5,x4)
# x = up2(x, x3)
# x = up3(x, x2)
# x = up4(x, x1)
# x = outc(x)
x.shape

torch.Size([1, 1024, 4, 7, 7])

In [148]:
x1, x2 = x, x1

In [149]:
x1.shape, x2.shape

(torch.Size([1, 256, 32, 56, 56]), torch.Size([1, 64, 32, 112, 112]))

In [154]:
up = nn.ConvTranspose3d(256, 256 // 2, kernel_size=2, stride=4)

In [155]:
x1_upped = up(x1)
x1_upped.shape

torch.Size([1, 128, 126, 222, 222])

In [None]:
class resAutoencoder(nn.Module):
    def __init__(self):
        super(resAutoencoder, self).__init__()
        self.inputchannel = 3
        self.block1channel  = 256
        self.block2channel = 512
        self.block3channel = 1024
        self.block4channel = 2048
        
        
        # 128 x 128 x 3
        self.conv = nn.Sequential(
                        nn.Conv2d(self.inputchannel, self.block1channel, kernel_size = 7, stride = 2, padding = 3),
                        nn.BatchNorm2d(self.block1channel),
                        nn.ReLU())
        # 64 x 64 x 64
        self.encoder1 = nn.Sequential(
            ResidualBlock(self.block1channel, self.block1channel, 1),
            ResidualBlock(self.block1channel, self.block1channel, 1),
            ResidualBlock(self.block1channel, self.block1channel, 1))
    
            # 64 x 64 x 128
        self.encoder2 = nn.Sequential(
            ResidualBlock(self.block1channel, self.block2channel, 2),
            ResidualBlock(self.block2channel, self.block2channel, 1),
            ResidualBlock(self.block2channel, self.block2channel, 1),
            ResidualBlock(self.block2channel, self.block2channel, 1))
            # 32 x 32 x 256
        self.encoder3 = nn.Sequential(    
            ResidualBlock(self.block2channel, self.block3channel, 2),
            ResidualBlock(self.block3channel, self.block3channel, 1),
            ResidualBlock(self.block3channel, self.block3channel, 1),
            ResidualBlock(self.block3channel, self.block3channel, 1),
            ResidualBlock(self.block3channel, self.block3channel, 1),
            ResidualBlock(self.block3channel, self.block3channel, 1))
            # 16 x 16 x 512
        self.encoder4 =nn.Sequential(
            ResidualBlock(self.block3channel, self.block4channel, 2),
            ResidualBlock(self.block4channel, self.block4channel, 1),
            ResidualBlock(self.block4channel, self.block4channel, 1))
            # 8 x 8 x 512
            
        
        self.decoder1 =nn.Sequential(
            # 8 x 8 x 512
            ResidualBlock(self.block4channel, self.block4channel, 1),
            ResidualBlock(self.block4channel, self.block4channel, 1),
            upsampleBlock(self.block4channel, self.block3channel))
            # 16 x 16 x 512
        self.decoder2 =nn.Sequential(
            ResidualBlock(self.block3channel*2, self.block3channel, 1),
            ResidualBlock(self.block3channel, self.block3channel, 1),
            ResidualBlock(self.block3channel, self.block3channel, 1),
            ResidualBlock(self.block3channel, self.block3channel, 1),
            ResidualBlock(self.block3channel, self.block3channel, 1),
            upsampleBlock(self.block3channel, self.block2channel))
        self.decoder3 =nn.Sequential(
            # 32 x 32 x 256        
            ResidualBlock(self.block2channel*2, self.block2channel, 1),
            ResidualBlock(self.block2channel, self.block2channel, 1),
            ResidualBlock(self.block2channel, self.block2channel, 1),
            upsampleBlock(self.block2channel, self.block1channel))
            # 64 x 64 x 128
        self.decoder4 =nn.Sequential(
            ResidualBlock(self.block1channel*2, self.block1channel, 1),
            ResidualBlock(self.block1channel, self.block1channel, 1),
            ResidualBlock(self.block1channel, self.block1channel, 1))
            # 128 x 128 x 3
        self.convout = nn.Sequential(
            nn.ConvTranspose2d(self.block1channel*2, self.inputchannel, 4, 2, 1, bias=False),
            nn.Tanh())
​
        
    def forward(self, x):
        x1 = self.conv(x)
        x2 = self.encoder1(x1)
        x3 = self.encoder2(x2)
        x4 = self.encoder3(x3)
        x5 = self.encoder4(x4)
        
        x = self.decoder1(x5)
        x = torch.cat((x, x4), dim=1)
        x = self.decoder2(x)
        x = torch.cat((x, x3), dim=1)
        x = self.decoder3(x)
        x = torch.cat((x, x2), dim=1)
        x = self.decoder4(x)
        x = torch.cat((x, x1), dim=1)
        x = self.convout(x)