In [22]:
from torch import nn,cat
from torchsummary import summary
class UNet(nn.Module):
    """U-Net style encoder/decoder producing a per-pixel map of class scores.

    Four encoder stages, each followed by 2x2 max-pooling, feed a bottleneck.
    Four decoder stages then each upsample 2x with a transposed convolution,
    concatenate the matching encoder feature map along the channel dimension
    (skip connection), and refine the result with a conv block. A final 3x3
    convolution maps to ``num_classes`` channels; no activation is applied to
    the output.

    Args:
        num_classes: Number of output channels of the final convolution.
        in_channels: Number of channels of the input tensor (default 3).
        base_channels: Channel width of the first encoder stage; every deeper
            stage doubles it (default 64, giving widths 64..1024).

    Input is assumed to be ``(N, in_channels, H, W)``; the output has the same
    N, H and W with ``num_classes`` channels. H and W must each be at least 16
    (otherwise the fourth pooling produces an empty map). They do not need to
    be divisible by 16: upsampled maps are zero-padded to match their skip
    tensor.
    """

    def __init__(self, num_classes, in_channels=3, base_channels=64):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.base_channels = base_channels
        f = base_channels
        # Attribute names and creation order are kept as-is: they define the
        # state_dict keys and the order in which weights are initialised.
        # Encoder: each stage doubles the channel count.
        self.c11 = self.conv_block(in_channels, f)
        self.c21 = self.conv_block(f, 2 * f)
        self.c31 = self.conv_block(2 * f, 4 * f)
        self.c41 = self.conv_block(4 * f, 8 * f)
        # Bottleneck.
        self.m = self.conv_block(8 * f, 16 * f)
        # Decoder: the transposed conv halves the channels; concatenating the
        # skip tensor restores them before the conv block halves them again.
        self.e11 = self.convT(16 * f, 8 * f)
        self.e12 = self.conv_block(16 * f, 8 * f)
        self.e21 = self.convT(8 * f, 4 * f)
        self.e22 = self.conv_block(8 * f, 4 * f)
        self.e31 = self.convT(4 * f, 2 * f)
        self.e32 = self.conv_block(4 * f, 2 * f)
        self.e41 = self.convT(2 * f, f)
        self.e42 = self.conv_block(2 * f, f)
        self.o = nn.Conv2d(f, num_classes, 3, 1, 1)

    def conv_block(self, in_channels, out_channels):
        """Return two (3x3 conv -> ReLU -> BatchNorm) stages.

        Stride 1 with padding 1 keeps the spatial size unchanged.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )

    def convT(self, in_channels, out_channels):
        """Return a transposed conv that exactly doubles H and W.

        kernel 3, stride 2, padding 1, output_padding 1:
        out = (s - 1) * 2 - 2 + 3 + 1 = 2 * s.
        """
        return nn.ConvTranspose2d(in_channels, out_channels, 3, 2, 1, 1)

    def pool(self):
        """Return a fresh 2x2/stride-2 max-pool module.

        Kept for backward compatibility; ``forward`` uses the functional form
        so it no longer allocates a new module on every call.
        """
        return nn.MaxPool2d(2, 2)

    @staticmethod
    def _downsample(x):
        # Same operation as pool()(x), without constructing a module.
        return nn.functional.max_pool2d(x, kernel_size=2, stride=2)

    @staticmethod
    def _match_size(x, ref):
        # Zero-pad x's last two dims up to ref's. The transposed conv yields
        # 2 * floor(s / 2), which is one short of the encoder size s when s
        # is odd; without this the channel concatenation fails.
        dh = ref.shape[-2] - x.shape[-2]
        dw = ref.shape[-1] - x.shape[-1]
        if dh == 0 and dw == 0:
            return x
        return nn.functional.pad(
            x, (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2)
        )

    def _decode(self, up, refine, x, skip):
        # One decoder stage: upsample, align to the skip tensor, concatenate
        # along channels (upsampled first, then skip), and refine.
        upsampled = self._match_size(up(x), skip)
        return refine(cat((upsampled, skip), 1))

    def forward(self, X):
        """Map ``X`` to per-pixel class scores of the same spatial size."""
        c1 = self.c11(X)
        c2 = self.c21(self._downsample(c1))
        c3 = self.c31(self._downsample(c2))
        c4 = self.c41(self._downsample(c3))

        middle = self.m(self._downsample(c4))

        e1 = self._decode(self.e11, self.e12, middle, c4)
        e2 = self._decode(self.e21, self.e22, e1, c3)
        e3 = self._decode(self.e31, self.e32, e2, c2)
        e4 = self._decode(self.e41, self.e42, e3, c1)

        return self.o(e4)
        
        

In [23]:
# Instantiate the segmentation network with 10 output channels (classes).
model = UNet(num_classes=10)

In [27]:
# Print a layer-by-layer summary for one 3x256x256 input. The batch dimension
# is not part of the size list and appears as -1 in the report.
summary(model,[3,256,256])

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 64, 256, 256]        --
|    └─Conv2d: 2-1                       [-1, 64, 256, 256]        1,792
|    └─ReLU: 2-2                         [-1, 64, 256, 256]        --
|    └─BatchNorm2d: 2-3                  [-1, 64, 256, 256]        128
|    └─Conv2d: 2-4                       [-1, 64, 256, 256]        36,928
|    └─ReLU: 2-5                         [-1, 64, 256, 256]        --
|    └─BatchNorm2d: 2-6                  [-1, 64, 256, 256]        128
├─Sequential: 1-2                        [-1, 128, 128, 128]       --
|    └─Conv2d: 2-7                       [-1, 128, 128, 128]       73,856
|    └─ReLU: 2-8                         [-1, 128, 128, 128]       --
|    └─BatchNorm2d: 2-9                  [-1, 128, 128, 128]       256
|    └─Conv2d: 2-10                      [-1, 128, 128, 128]       147,584
|    └─ReLU: 2-11                        [-1, 128, 128, 128]      

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 64, 256, 256]        --
|    └─Conv2d: 2-1                       [-1, 64, 256, 256]        1,792
|    └─ReLU: 2-2                         [-1, 64, 256, 256]        --
|    └─BatchNorm2d: 2-3                  [-1, 64, 256, 256]        128
|    └─Conv2d: 2-4                       [-1, 64, 256, 256]        36,928
|    └─ReLU: 2-5                         [-1, 64, 256, 256]        --
|    └─BatchNorm2d: 2-6                  [-1, 64, 256, 256]        128
├─Sequential: 1-2                        [-1, 128, 128, 128]       --
|    └─Conv2d: 2-7                       [-1, 128, 128, 128]       73,856
|    └─ReLU: 2-8                         [-1, 128, 128, 128]       --
|    └─BatchNorm2d: 2-9                  [-1, 128, 128, 128]       256
|    └─Conv2d: 2-10                      [-1, 128, 128, 128]       147,584
|    └─ReLU: 2-11                        [-1, 128, 128, 128]      