In [23]:
import torch.nn as nn
import torch
import tensorflow as tf 
from tensorflow.keras.layers import Input,Conv3D,MaxPooling3D,UpSampling3D,concatenate,Conv3DTranspose,BatchNormalization


2023-02-22 21:59:51.894129: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1


In [183]:
class conv_block(nn.Module):
    """Two stacked ``Conv3d -> BatchNorm3d -> ReLU`` stages (3-D version).

    Both convolutions use ``kernel_size=3`` with ``padding=1`` (stride 1), so
    the spatial extent (D, H, W) of the input is preserved and only the
    channel count changes.

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels produced by each convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Registration order (conv1, bn1, conv2, bn2) is intentional: it fixes
        # the order of random weight initialisation and of state_dict keys.
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(out_channels)

        # ReLU has no state, so one instance is shared by both stages.
        self.relu = nn.ReLU()

    def forward(self, inputs):
        """Apply both conv/norm/activation stages in sequence."""
        out = inputs
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(norm(conv(out)))
        return out
        
        
        
class encoder_block(nn.Module):
    """U-Net encoder stage (3-D): a ``conv_block`` followed by max pooling.

    ``forward`` returns a pair ``(features, pooled)``: the pre-pool features
    serve as the skip connection, and the pooled tensor feeds the next stage.

    NOTE(review): ``conv_block`` is resolved by global name when this class is
    instantiated. This notebook later redefines ``conv_block`` as a 2-D module,
    so instantiate this 3-D encoder before that redefinition runs.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = conv_block(in_channels, out_channels)
        # A (2, 2, 2) kernel with the default stride (= kernel) halves D, H and W,
        # rounding down for odd sizes.
        self.pool = nn.MaxPool3d(kernel_size=(2, 2, 2))

    def forward(self, inputs):
        """Return the (skip features, pooled features) pair."""
        features = self.conv(inputs)
        return features, self.pool(features)
        
        
        
class decoder_block(nn.Module):
    """U-Net decoder stage (3-D): upsample, concatenate the skip, then refine.

    The transposed convolution (kernel 2, stride 2) doubles D, H and W and maps
    ``in_channels`` to ``out_channels``. The skip connection is concatenated
    along the channel axis, so it must itself have ``out_channels`` channels and
    the same spatial size as the upsampled tensor. Otherwise ``torch.cat`` or
    the following ``conv_block`` raises.

    NOTE(review): ``conv_block`` is resolved by global name at instantiation and
    is later redefined in this notebook as a 2-D module.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose3d(
            in_channels, out_channels, kernel_size=2, stride=2, padding=0
        )
        # Upsampled (out_channels) + skip (out_channels) channels after concat.
        self.conv = conv_block(2 * out_channels, out_channels)

    def forward(self, inputs, skip_connection):
        """Upsample ``inputs``, fuse with ``skip_connection`` and convolve."""
        merged = torch.cat((self.up(inputs), skip_connection), dim=1)
        return self.conv(merged)
        
        
        
        
        
        
class U_Net(nn.Module):
    """Five-level 3-D U-Net (encoder widths 4-8-16-32-64, bottleneck 128).

    Args:
        in_channels: channels of the input volume (default 1, as originally
            hard-coded).
        out_channels: channels of the 1x1x1 classifier output (default 2, as
            originally hard-coded).

    The encoder applies five 2x poolings, so each spatial dimension (D, H, W)
    of the input should be a multiple of 32 (and at least 32). That way every
    upsampled tensor in the decoder matches its skip connection.

    NOTE(review): ``conv_block`` / ``encoder_block`` / ``decoder_block`` are
    resolved by global name when ``U_Net()`` is constructed. This notebook later
    redefines those names as 2-D modules, so construct ``U_Net`` before that
    cell runs (or give the 3-D blocks distinct names).
    """

    def __init__(self, in_channels=1, out_channels=2):
        super().__init__()

        # Encoder: each stage doubles the channels and halves D, H, W.
        self.e1 = encoder_block(in_channels, 4)
        self.e2 = encoder_block(4, 8)
        self.e3 = encoder_block(8, 16)
        self.e4 = encoder_block(16, 32)
        self.e5 = encoder_block(32, 64)

        # Bottleneck
        self.b = conv_block(64, 128)

        # Decoder: each stage halves the channels and doubles D, H, W.
        self.d1 = decoder_block(128, 64)
        self.d2 = decoder_block(64, 32)
        self.d3 = decoder_block(32, 16)
        self.d4 = decoder_block(16, 8)
        self.d5 = decoder_block(8, 4)

        # Classifier: per-voxel 1x1x1 projection to the output channels.
        self.outputs = nn.Conv3d(4, out_channels, kernel_size=1, padding=0)

    def forward(self, inputs):
        """Run the encoder, bottleneck and decoder; return the classifier map."""
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)
        s5, p5 = self.e5(p4)

        b = self.b(p5)

        # Each decoder level must use its own block. Reusing self.d1 (which
        # expects 128 input channels) for every level fails after the first.
        d1 = self.d1(b, s5)
        d2 = self.d2(d1, s4)
        d3 = self.d3(d2, s3)
        d4 = self.d4(d3, s2)
        d5 = self.d5(d4, s1)

        return self.outputs(d5)
        

In [193]:
import torch
import torch.nn as nn

class conv_block(nn.Module):
    """Two stacked ``Conv2d -> BatchNorm2d -> ReLU`` stages (2-D version).

    A 3x3 kernel with ``padding=1`` (stride 1) keeps H and W unchanged; only
    the channel count goes from ``in_c`` to ``out_c``.

    Args:
        in_c: number of input channels.
        out_c: number of output channels of both convolutions.
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        # First stage: in_c -> out_c.
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_c)

        # Second stage: out_c -> out_c.
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_c)

        # Stateless activation shared by both stages.
        self.relu = nn.ReLU()

    def forward(self, inputs):
        """Apply the two conv/norm/activation stages."""
        hidden = self.relu(self.bn1(self.conv1(inputs)))
        return self.relu(self.bn2(self.conv2(hidden)))

class encoder_block(nn.Module):
    """U-Net encoder stage (2-D): ``conv_block`` then 2x2 max pooling.

    Returns ``(skip, pooled)`` from ``forward``. ``skip`` is kept for the
    decoder, and ``pooled`` (H and W halved, rounding down) goes to the next
    encoder stage.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))

    def forward(self, inputs):
        """Return the (skip features, pooled features) pair."""
        skip = self.conv(inputs)
        pooled = self.pool(skip)
        return skip, pooled

class decoder_block(nn.Module):
    """U-Net decoder stage (2-D): upsample, concatenate the skip, convolve.

    The stride-2, kernel-2 transposed convolution doubles H and W and maps
    ``in_c`` to ``out_c`` channels. ``skip`` must have ``out_c`` channels and
    the same H, W as the upsampled tensor so that the channel-wise
    concatenation yields ``2 * out_c`` channels for ``conv_block``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = conv_block(2 * out_c, out_c)

    def forward(self, inputs, skip):
        """Upsample ``inputs``, fuse with ``skip`` and convolve."""
        upsampled = self.up(inputs)
        fused = torch.cat((upsampled, skip), dim=1)
        return self.conv(fused)

class build_unet(nn.Module):
    """Four-level 2-D U-Net (encoder widths 64-128-256-512, bottleneck 1024).

    Args:
        in_channels: channels of the input image (default 1, as originally
            hard-coded).
        out_channels: channels of the 1x1 classifier output (default 1, as
            originally hard-coded).

    The encoder applies four 2x poolings, so H and W should be multiples of 16.
    Otherwise an upsampled tensor will not match its skip connection and
    ``torch.cat`` in the decoder will fail.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        # Encoder: channels double, H and W halve at every stage.
        self.e1 = encoder_block(in_channels, 64)
        self.e2 = encoder_block(64, 128)
        self.e3 = encoder_block(128, 256)
        self.e4 = encoder_block(256, 512)

        # Bottleneck
        self.b = conv_block(512, 1024)

        # Decoder: channels halve, H and W double at every stage.
        self.d1 = decoder_block(1024, 512)
        self.d2 = decoder_block(512, 256)
        self.d3 = decoder_block(256, 128)
        self.d4 = decoder_block(128, 64)

        # Classifier: per-pixel 1x1 projection to the output channels.
        self.outputs = nn.Conv2d(64, out_channels, kernel_size=1, padding=0)

    def forward(self, inputs):
        """Run encoder, bottleneck and decoder; return the classifier map."""
        # Encoder
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        # Bottleneck
        b = self.b(p4)

        # Decoder: each stage consumes the matching-resolution skip tensor.
        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        return self.outputs(d4)



In [194]:
# Smoke test: a random batch must come back with the same spatial size.
torch.manual_seed(0)  # reproducible weights and input across re-runs
x = torch.randn((2, 1, 512, 512))

f = build_unet()

# Only the output shape is checked, so skip autograd bookkeeping. This avoids
# keeping every intermediate activation of a 512x512 forward pass in memory.
with torch.no_grad():
    y = f(x)

assert y.shape == (2, 1, 512, 512), f"unexpected output shape {tuple(y.shape)}"
print(y.shape)



torch.Size([2, 1, 512, 512])


