<a href="https://colab.research.google.com/github/WilliamAshbee/corticalmeshgen/blob/main/another3dunet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Based on: https://github.com/davidiommi/Pytorch-Unet3D-single_channel/blob/master/UNet.py

In [2]:
# torchsummaryX provides the `summary` function used by the model cell's __main__ demo.
!pip install torchsummaryX



In [6]:
# Placeholders so the cleanup cell can check these names before first use.
a = model = None

In [3]:
from torch.nn import Module, Sequential
from torch.nn import Conv3d, ConvTranspose3d, BatchNorm3d, MaxPool3d, AvgPool1d, Dropout3d
from torch.nn import ReLU, Sigmoid
import torch


class UNet(Module):
    """3D U-Net: four max-pooling stages, a bottleneck, and a sigmoid output map.

    Layout (encoder on the left, decoder on the right, skip connections
    concatenated across each level)::

        __                            __
         1|__   ________________   __|1
            2|__  ____________  __|2
               3|__  ______  __|3
                  4|__ __ __|4

    Every encoder/decoder stage is a ``Conv3D_Block``; with a truthy
    ``residual`` each block adds a 1x1x1-conv projection of its input so the
    channel counts match.

    The input is a 5D tensor ``(N, num_channels, s1, s2, s3)``. Each spatial
    dimension should be divisible by 16 (four 2x max-poolings, each undone by a
    2x transposed convolution); otherwise the decoder tensors and the encoder
    skips differ in size and ``torch.cat`` fails. The output keeps the input's
    spatial size, has ``num_channels`` channels and values in (0, 1).

    Args:
        num_channels: channels of the input volume and of the output map.
        feat_channels: five feature widths: encoder levels 1-4, then the
            bottleneck. Any indexable sequence works.
        residual: forwarded to every ``Conv3D_Block``; ``None`` disables the
            residual projection.
        verbose: when True (the default) ``forward`` prints the shape of the
            input and of each intermediate tensor.
    """

    def __init__(self, num_channels=1, feat_channels=(64, 256, 256, 512, 1024),
                 residual='conv', verbose=True):
        super(UNet, self).__init__()
        self.verbose = verbose

        # Encoder downsamplers: each halves every spatial dimension.
        self.pool1 = MaxPool3d((2, 2, 2))
        self.pool2 = MaxPool3d((2, 2, 2))
        self.pool3 = MaxPool3d((2, 2, 2))
        self.pool4 = MaxPool3d((2, 2, 2))

        # Encoder convolutions
        self.conv_blk1 = Conv3D_Block(num_channels, feat_channels[0], residual=residual)
        self.conv_blk2 = Conv3D_Block(feat_channels[0], feat_channels[1], residual=residual)
        self.conv_blk3 = Conv3D_Block(feat_channels[1], feat_channels[2], residual=residual)
        self.conv_blk4 = Conv3D_Block(feat_channels[2], feat_channels[3], residual=residual)
        self.conv_blk5 = Conv3D_Block(feat_channels[3], feat_channels[4], residual=residual)

        # Decoder convolutions: input is the upsampled tensor concatenated with
        # the same-width encoder skip, hence 2x the channels.
        self.dec_conv_blk4 = Conv3D_Block(2 * feat_channels[3], feat_channels[3], residual=residual)
        self.dec_conv_blk3 = Conv3D_Block(2 * feat_channels[2], feat_channels[2], residual=residual)
        self.dec_conv_blk2 = Conv3D_Block(2 * feat_channels[1], feat_channels[1], residual=residual)
        self.dec_conv_blk1 = Conv3D_Block(2 * feat_channels[0], feat_channels[0], residual=residual)

        # Decoder upsamplers
        self.deconv_blk4 = Deconv3D_Block(feat_channels[4], feat_channels[3])
        self.deconv_blk3 = Deconv3D_Block(feat_channels[3], feat_channels[2])
        self.deconv_blk2 = Deconv3D_Block(feat_channels[2], feat_channels[1])
        self.deconv_blk1 = Deconv3D_Block(feat_channels[1], feat_channels[0])

        # Final 1*1 Conv Segmentation map
        self.one_conv = Conv3d(feat_channels[0], num_channels, kernel_size=1, stride=1, padding=0, bias=True)

        # Activation function
        self.sigmoid = Sigmoid()

        # Registered as a submodule rather than constructed inside forward():
        # a freshly built Dropout3d is always in training mode, so model.eval()
        # could never switch it off. It has no parameters, so the state_dict
        # is unchanged.
        self.dropout = Dropout3d(p=0.5)

    def _log(self, name, tensor):
        """Print ``name`` and the tensor's shape if ``self.verbose`` is set."""
        if self.verbose:
            print(name, tensor.shape)

    def forward(self, x):
        """Run the U-Net on ``x`` and return the sigmoid segmentation map."""
        # Encoder part
        self._log('x', x)
        x1 = self.conv_blk1(x)
        self._log('x1', x1)

        x2 = self.conv_blk2(self.pool1(x1))
        self._log('x2', x2)

        x3 = self.conv_blk3(self.pool2(x2))
        self._log('x3', x3)

        x4 = self.conv_blk4(self.pool3(x3))
        self._log('x4', x4)

        base = self.conv_blk5(self.pool4(x4))

        # Decoder part: upsample, concatenate the matching encoder output along
        # the channel axis (dim=1), then convolve back to that level's width.
        d_high4 = self.dec_conv_blk4(torch.cat([self.deconv_blk4(base), x4], dim=1))
        self._log('d_high4', d_high4)

        d_high3 = self.dec_conv_blk3(torch.cat([self.deconv_blk3(d_high4), x3], dim=1))
        d_high3 = self.dropout(d_high3)
        self._log('d_high3', d_high3)

        d_high2 = self.dec_conv_blk2(torch.cat([self.deconv_blk2(d_high3), x2], dim=1))
        d_high2 = self.dropout(d_high2)
        self._log('d_high2', d_high2)

        d_high1 = self.dec_conv_blk1(torch.cat([self.deconv_blk1(d_high2), x1], dim=1))
        self._log('d_high1', d_high1)

        seg = self.sigmoid(self.one_conv(d_high1))
        self._log('seg', seg)

        return seg


class Conv3D_Block(Module):
    """Two (Conv3d -> BatchNorm3d -> ReLU) stages with an optional residual path.

    Args:
        inp_feat: input channel count.
        out_feat: output channel count of both convolutions.
        kernel: kernel size of both convolutions.
        stride: stride of both convolutions.
        padding: padding of both convolutions.
        residual: any truthy value adds a bias-free 1x1x1 ``Conv3d`` projection
            of the input to the output (so input and output channel counts may
            differ); a falsy value (``None``, ``False``, ``''``) disables it.

    The residual sum requires the convolutions to preserve spatial size, which
    holds for the defaults (kernel=3, stride=1, padding=1).
    """

    def __init__(self, inp_feat, out_feat, kernel=3, stride=1, padding=1, residual=None):

        super(Conv3D_Block, self).__init__()

        self.conv1 = Sequential(
            Conv3d(inp_feat, out_feat, kernel_size=kernel,
                   stride=stride, padding=padding, bias=True),
            BatchNorm3d(out_feat),
            ReLU())

        self.conv2 = Sequential(
            Conv3d(out_feat, out_feat, kernel_size=kernel,
                   stride=stride, padding=padding, bias=True),
            BatchNorm3d(out_feat),
            ReLU())

        self.residual = residual

        # Same truthiness test as forward(): a falsy but non-None value must not
        # create a projection layer that is never called, since its parameters
        # would never receive gradients.
        if self.residual:
            self.residual_upsampler = Conv3d(inp_feat, out_feat, kernel_size=1, bias=False)

    def forward(self, x):
        """Apply both conv stages, adding the projected input if enabled."""
        out = self.conv2(self.conv1(x))
        if self.residual:
            out = out + self.residual_upsampler(x)
        return out


class Deconv3D_Block(Module):
    """Upsampling block: ``ConvTranspose3d`` followed by ``ReLU``.

    The output size per spatial dimension is
    ``(in - 1) * stride - 2 * padding + kernel + output_padding``, so the
    defaults (kernel=3, stride=2, padding=1, output_padding=1) exactly double
    each dimension.

    Args:
        inp_feat: input channel count.
        out_feat: output channel count.
        kernel: cubic kernel size.
        stride: stride in every spatial dimension.
        padding: implicit padding in every spatial dimension.
        output_padding: extra size added on one side of each output dimension.
            The default of 1 pairs with the default stride of 2. PyTorch
            rejects values that are not smaller than the stride or the dilation
            (1 here), so pass 0 when ``stride=1``.
    """

    def __init__(self, inp_feat, out_feat, kernel=3, stride=2, padding=1, output_padding=1):
        super(Deconv3D_Block, self).__init__()

        self.deconv = Sequential(
            ConvTranspose3d(inp_feat, out_feat,
                            kernel_size=(kernel, kernel, kernel),
                            stride=(stride, stride, stride),
                            padding=(padding, padding, padding),
                            output_padding=output_padding,
                            bias=True),
            ReLU())

    def forward(self, x):
        """Upsample ``x`` with the transposed convolution, then apply ReLU."""
        return self.deconv(x)


class ChannelPool3d(AvgPool1d):
    """Average-pool a 5D tensor along its channel dimension (dim 1).

    The input ``(n, c, d, w, h)`` is flattened to ``(n, d*w*h, c)`` so that
    ``AvgPool1d`` slides its window over the channel axis. The result is then
    reshaped back to ``(n, c_out, d, w, h)``.

    Args:
        kernel_size: pooling window size, measured in channels.
        stride: window stride, measured in channels.
        padding: implicit zero padding on both ends of the channel axis.
    """

    def __init__(self, kernel_size, stride, padding):
        super(ChannelPool3d, self).__init__(kernel_size, stride, padding)
        self.pool_1d = AvgPool1d(self.kernel_size, self.stride, self.padding, self.ceil_mode)

    def forward(self, inp):
        """Return ``inp`` average-pooled across channels."""
        n, c, d, w, h = inp.size()
        # (n, c, d, w, h) -> (n, d*w*h, c): channels become the pooled (last) axis.
        flat = inp.reshape(n, c, d * w * h).permute(0, 2, 1)
        pooled = self.pool_1d(flat)
        # Read the pooled channel count from the result; it depends on stride
        # and padding, not only on kernel_size.
        c_out = pooled.size(2)
        # Back to channels-first. reshape (not view) because permute yields a
        # non-contiguous tensor.
        return pooled.permute(0, 2, 1).reshape(n, c_out, d, w, h)


if __name__ == '__main__':
    # Smoke test: build the network, run one forward pass, print a layer summary.
    import torch
    from torchsummaryX import summary

    # Fall back to the CPU so the demo also runs on machines without a GPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Conv3D_Block only checks `residual` for truthiness, so 'conv' (UNet's
    # default) and 'pool' behave identically; 'conv' names what it does.
    net = UNet(residual='conv').to(device).eval()

    # Plain tensors carry autograd state; torch.autograd.Variable is deprecated.
    data = torch.randn(1, 1, 128, 128, 32, device=device)

    # Inference only: don't build an autograd graph for this pass.
    with torch.no_grad():
        out = net(data)

    summary(net, data)
    print("out size: {}".format(out.size()))

x torch.Size([1, 1, 128, 128, 32])
x1 torch.Size([1, 64, 128, 128, 32])
x2 torch.Size([1, 256, 64, 64, 16])
x3 torch.Size([1, 256, 32, 32, 8])
x4 torch.Size([1, 512, 16, 16, 4])
d_high4 torch.Size([1, 512, 16, 16, 4])
d_high3 torch.Size([1, 256, 32, 32, 8])
d_high2 torch.Size([1, 256, 64, 64, 16])
d_high1 torch.Size([1, 64, 128, 128, 32])
seg torch.Size([1, 1, 128, 128, 32])
x torch.Size([1, 1, 128, 128, 32])
x1 torch.Size([1, 64, 128, 128, 32])
x2 torch.Size([1, 256, 64, 64, 16])
x3 torch.Size([1, 256, 32, 32, 8])
x4 torch.Size([1, 512, 16, 16, 4])
d_high4 torch.Size([1, 512, 16, 16, 4])
d_high3 torch.Size([1, 256, 32, 32, 8])
d_high2 torch.Size([1, 256, 64, 64, 16])
d_high1 torch.Size([1, 64, 128, 128, 32])
seg torch.Size([1, 1, 128, 128, 32])
                                                     Kernel Shape  \
Layer                                                               
0_conv_blk1.conv1.Conv3d_0                       [1, 64, 3, 3, 3]   
1_conv_blk1.conv1.BatchNorm3d_1      

  df_sum = df.sum()


In [11]:
# Release the previous model/input so their GPU memory can be reused.
# Rebinding to None (instead of `del`) keeps the names defined, so this cell
# can be re-run without a NameError on the checks below.
import gc

a = None
model = None
gc.collect()
# empty_cache() only returns *unreferenced* cached blocks to the driver. In a
# notebook, the value of a displayed expression such as `model(a)` is also kept
# in IPython's output history (`_`, `Out[n]`). If it was computed with autograd
# enabled, it keeps that graph's activations alive. That is a likely reason
# this call can appear to free nothing.
torch.cuda.empty_cache()

In [8]:
# Fresh UNet with default arguments, moved to the current CUDA device.
model = UNet().to('cuda')

In [9]:
# Batch of 32 single-channel 32x32x32 volumes, sampled on the CPU then moved to the GPU.
a = torch.randn((32, 1, 32, 32, 32)).to('cuda')

In [10]:
# Shape-check forward pass. Under no_grad no autograd graph is built, so the
# displayed result (which IPython keeps in `_` / `Out[n]`) does not also pin
# every intermediate activation in GPU memory.
with torch.no_grad():
    seg = model(a)
seg

x torch.Size([32, 1, 32, 32, 32])
x1 torch.Size([32, 64, 32, 32, 32])
x2 torch.Size([32, 256, 16, 16, 16])
x3 torch.Size([32, 256, 8, 8, 8])
x4 torch.Size([32, 512, 4, 4, 4])
d_high4 torch.Size([32, 512, 4, 4, 4])
d_high3 torch.Size([32, 256, 8, 8, 8])
d_high2 torch.Size([32, 256, 16, 16, 16])
d_high1 torch.Size([32, 64, 32, 32, 32])
seg torch.Size([32, 1, 32, 32, 32])


tensor([[[[[0.4755, 0.4546, 0.3778,  ..., 0.3932, 0.4387, 0.4577],
           [0.4175, 0.4698, 0.5316,  ..., 0.6231, 0.5316, 0.4928],
           [0.4419, 0.4782, 0.4976,  ..., 0.3770, 0.4959, 0.4932],
           ...,
           [0.4735, 0.4194, 0.4269,  ..., 0.4671, 0.4305, 0.4789],
           [0.3931, 0.4931, 0.5381,  ..., 0.4995, 0.4135, 0.4868],
           [0.4188, 0.4978, 0.4611,  ..., 0.4716, 0.5130, 0.4871]],

          [[0.4069, 0.4548, 0.4032,  ..., 0.4490, 0.5219, 0.4814],
           [0.3981, 0.4024, 0.5689,  ..., 0.4023, 0.4858, 0.4706],
           [0.4727, 0.5536, 0.4549,  ..., 0.4810, 0.5515, 0.4312],
           ...,
           [0.4343, 0.4596, 0.6131,  ..., 0.4779, 0.5491, 0.4269],
           [0.3957, 0.5134, 0.4575,  ..., 0.5630, 0.4941, 0.4278],
           [0.5001, 0.3849, 0.4856,  ..., 0.3006, 0.5770, 0.4633]],

          [[0.4506, 0.5048, 0.4504,  ..., 0.4382, 0.4448, 0.3808],
           [0.5202, 0.4204, 0.4707,  ..., 0.5994, 0.5227, 0.4234],
           [0.4913, 0.4208