In [12]:
import torch
import torch.nn as nn
from torchsummary import summary

In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [16]:
device

device(type='cuda')

In [92]:
def crop_tensor(input,target):
    diff =input.size()[2] - target.size()[2] # difference in width
    diff = diff // 2
    print("Difffffffffffffff: ",diff,input.size()[2],target.size()[2])
    return(input[:,diff:input.size()[2]-diff,diff:input.size()[2]-diff])

In [99]:
def conv_block(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=3),
        nn.ReLU()
    )

In [107]:
class UNet(nn.Module):
    def __init__(self,IN_CHANNELS=3):
        super().__init__() 
        self.encode_conv1 = conv_block(IN_CHANNELS, 64)
        self.encode_conv2 = conv_block(64, 128)
        self.encode_conv3 = conv_block(128, 256)
        self.encode_conv4 = conv_block(256, 512)
        self.encode_conv5 = conv_block(512, 1024)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.decode_conv1 = nn.ConvTranspose2d(in_channels=1024,out_channels=512,kernel_size=2,stride=2)
        #self.decode_conv1 = nn.ConvTranspose2d(in_channels=1024,out_channels=512,kernel_size=2,stride=2)

    def forward(self, x):
        # Encoder
        x1 = self.encode_conv1(x)
        x2 = self.maxpool(x1)
        x3 = self.encode_conv2(x2)
        x4 = self.maxpool(x3)
        x5 = self.encode_conv3(x4)
        x6 = self.maxpool(x5)
        x7 = self.encode_conv4(x6)
        x8 = self.maxpool(x7)
        x9 = self.encode_conv5(x8)

        
        # Decoder x1,x3,x5,x7 will be used as input 
        x10 = self.decode_conv1(x9)
        x7_cropped = crop_tensor(x7,x10)
        print(" x1 : ",x1.size())
        print("\nx2 : ",x2.size())
        print("\nx3 : ",x3.size())
        print("\nx4 : ",x4.size())
        print("\nx5 : ",x5.size())
        print("\nx6 : ",x6.size())
        print("\nx7 : ",x7.size())
        print("\nx8 : ",x8.size())
        print("\nx9 : ",x9.size())
        print("\nx10 : ",x10.size())
        print("\nx7_cropped : ",x7_cropped.size())

        



        # self.conv2 = nn.Conv2d(20, 20, 5)
        # self.max_pool = nn.MaxPool(kerner_size=2, stride=2)

In [111]:
model = UNet(3).to(device)

In [112]:
image = torch.rand((3,300,300)).to(device)
print(model(image))

Difffffffffffffff:  4 30 22
 x1 :  torch.Size([64, 296, 296])

x2 :  torch.Size([64, 148, 148])

x3 :  torch.Size([128, 144, 144])

x4 :  torch.Size([128, 72, 72])

x5 :  torch.Size([256, 68, 68])

x6 :  torch.Size([256, 34, 34])

x7 :  torch.Size([512, 30, 30])

x8 :  torch.Size([512, 15, 15])

x9 :  torch.Size([1024, 11, 11])

x10 :  torch.Size([512, 22, 22])

x11 :  torch.Size([512, 22, 22])
None


In [46]:
# batchsize, channels,height,width
summary(model, input_size = (3,300, 300),batch_size=200)

x10.size:  torch.Size([2, 512, 36, 36])
x7.size:  torch.Size([2, 512, 37, 37])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1        [200, 64, 300, 300]           1,792
              ReLU-2        [200, 64, 300, 300]               0
            Conv2d-3        [200, 64, 300, 300]          36,928
              ReLU-4        [200, 64, 300, 300]               0
         MaxPool2d-5        [200, 64, 150, 150]               0
            Conv2d-6       [200, 128, 150, 150]          73,856
              ReLU-7       [200, 128, 150, 150]               0
            Conv2d-8       [200, 128, 150, 150]         147,584
              ReLU-9       [200, 128, 150, 150]               0
        MaxPool2d-10         [200, 128, 75, 75]               0
           Conv2d-11         [200, 256, 75, 75]         295,168
             ReLU-12         [200, 256, 75, 75]               0
           Conv2d-13    