In [19]:
import torch
import torch.nn as nn
import torchvision

<img src="https://amaarora.github.io/images/unet.png" alt="U-Net architecture diagram">

In [20]:
class Block(nn.Module):
    """Two unpadded convolutions, each followed by ReLU: the basic U-Net unit.

    ``nn.Conv2d`` is used with its default ``padding=0``, so each convolution
    trims ``kernel_size - 1`` pixels from height and width. The whole block
    therefore trims ``2 * (kernel_size - 1)``, e.g. 572 -> 568 for the default
    3x3 kernel.

    Args:
        in_channel: number of channels of the input tensor.
        out_channel: number of channels produced by both convolutions.
        kernel_size: square kernel size of both convolutions (default 3).
    """

    def __init__(self, in_channel, out_channel, kernel_size=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU to ``x`` (e.g. shaped (N, C_in, H, W))."""
        return self.relu(self.conv2(self.relu(self.conv1(x))))

In [21]:
# Dummy input batch: a single random 3-channel 572x572 image with values in [0, 1).
img = torch.rand(1, 3, 572, 572)

In [22]:
# Sanity-check one Block: 3 -> 10 channels; the two unpadded 3x3 convs trim 4 px in total.
blk = Block(in_channel=3, out_channel=10)
blk(img).shape


torch.Size([1, 10, 568, 568])

In [29]:
class Encoder(nn.Module):
    """Contracting path of U-Net: ``Block``s separated by 2x2 max-pooling.

    Args:
        chs: channel counts along the path. ``Block(chs[i], chs[i + 1])`` is
            built for every consecutive pair, giving ``len(chs) - 1`` blocks.
    """

    def __init__(self, chs=(3, 64, 128, 256, 512, 1024)):
        super().__init__()
        self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)])
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Run ``x`` through every block and collect each block's output.

        Returns:
            list of tensors, one per block, from the first (highest-resolution)
            to the last (lowest-resolution) block.
        """
        ftrs = []
        for i, block in enumerate(self.enc_blocks):
            # Pool only *between* blocks. Pooling after the last block was wasted
            # work, and it raises on a 1x1 final map even though the result is unused.
            if i > 0:
                x = self.pool(x)
            x = block(x)
            ftrs.append(x)
        return ftrs

In [30]:
# Run the encoder (default channel widths spelled out) on the dummy image,
# keeping every stage's feature map.
encc = Encoder(chs=(3, 64, 128, 256, 512, 1024))
encoder_out = encc(img)

In [31]:
# One feature map per encoder stage: channels grow while height/width shrink.
for feature_map in encoder_out:
    print(feature_map.shape)

torch.Size([1, 64, 568, 568])
torch.Size([1, 128, 280, 280])
torch.Size([1, 256, 136, 136])
torch.Size([1, 512, 64, 64])
torch.Size([1, 1024, 28, 28])


In [32]:
class Decoder(nn.Module):
    """Expanding path of U-Net.

    Each stage does three things:
    - upsamples with a kernel-2, stride-2 transposed convolution;
    - concatenates the center-cropped matching encoder feature map along dim 1;
    - refines the result with a ``Block``.

    Args:
        chs: channel counts along the path. Stage ``i`` maps ``chs[i]`` to
            ``chs[i + 1]`` channels in both its upsampler and its block. The
            concatenation in between must therefore yield ``chs[i]`` channels.
    """

    def __init__(self, chs=(1024, 512, 256, 128, 64)):
        super().__init__()
        self.chs = chs
        stage_count = len(chs) - 1
        # Upsamplers are built before blocks (same parameter-init order as before).
        # The attribute names below are the state_dict keys, so they stay fixed.
        self.upsample = nn.ModuleList(
            [nn.ConvTranspose2d(chs[k], chs[k + 1], kernel_size=2, stride=2) for k in range(stage_count)]
        )
        self.dec_block = nn.ModuleList(
            [Block(chs[k], chs[k + 1]) for k in range(stage_count)]
        )

    def forward(self, x, encoder_features):
        """Decode ``x`` using ``encoder_features`` as skip connections.

        Args:
            x: input to the first stage (the demo cells pass a bottleneck-shaped tensor).
            encoder_features: skip feature maps; entry ``i`` is used by stage ``i``.
                Must contain at least ``len(self.chs) - 1`` entries.

        Returns:
            the output of the last decoder ``Block``.
        """
        for stage in range(len(self.chs) - 1):
            upsampled = self.upsample[stage](x)
            skip = self.crop(encoder_features[stage], upsampled)
            x = self.dec_block[stage](torch.cat([upsampled, skip], dim=1))
        return x

    def crop(self, enc_inp, x):
        """Center-crop ``enc_inp`` to the spatial size of the 4-D tensor ``x``."""
        _, _, height, width = x.shape
        return torchvision.transforms.CenterCrop([height, width])(enc_inp)

In [38]:
# Push a random tensor shaped like the deepest encoder output through the decoder.
# The remaining encoder outputs are passed deepest-first as skip connections.
decoder = Decoder(chs=(1024, 512, 256, 128, 64))
x = torch.randn(1, 1024, 28, 28)
decoder(x, encoder_features=list(reversed(encoder_out))[1:]).shape

torch.Size([1, 64, 388, 388])

In [39]:
class UNet(nn.Module):
    """U-Net: ``Encoder`` -> ``Decoder`` (with skip connections) -> 1x1 conv head.

    Args:
        enc_chs: channel counts for the ``Encoder``.
        dec_chs: channel counts for the ``Decoder``. ``dec_chs[0]`` should equal
            ``enc_chs[-1]``, because the deepest encoder output is fed to the decoder.
        num_class: number of output channels produced by the 1x1 head.
        retain_dim: if True, resize the head output to ``out_sz`` with
            ``nn.functional.interpolate`` (default nearest-neighbour mode).
        out_sz: (height, width) target used when ``retain_dim`` is True.
            ``None`` means "match the input's height and width".
    """

    def __init__(self, enc_chs=(3,64,128,256,512,1024), dec_chs=(1024, 512, 256, 128, 64), num_class=1, retain_dim=False, out_sz=(572,572)):
        super().__init__()
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs)
        self.head = nn.Conv2d(dec_chs[-1], num_class, 1)
        self.retain_dim = retain_dim
        # Was previously discarded; forward() needs it when retain_dim is True.
        self.out_sz = out_sz

    def forward(self, x):
        """Return per-pixel class scores for ``x``.

        Without ``retain_dim`` the output is spatially smaller than ``x``, because
        the convolutions are unpadded.
        """
        enc_out = self.encoder(x)
        # Deepest feature map feeds the decoder; the rest (deepest-first) are skips.
        out = self.decoder(enc_out[-1], enc_out[::-1][1:])
        out = self.head(out)
        if self.retain_dim:
            # Fix: torch.functional has no `interpolate`, and the target size must
            # be out_sz, not the list of encoder outputs.
            size = tuple(x.shape[-2:]) if self.out_sz is None else self.out_sz
            out = nn.functional.interpolate(out, size=size)
        return out


In [41]:
# End-to-end shape check on a random 3-channel 572x572 input (defaults spelled out).
unet = UNet(num_class=1, retain_dim=False)
x = torch.randn(1, 3, 572, 572)
unet(x).shape


torch.Size([1, 1, 388, 388])