In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

`num_classes` (int): number of classes to segment

`in_features` (int): number of input channels fed to the first convolution (e.g. 1 for grayscale images, the default)

`drop_rate` (float): dropout rate of the last two encoders

`filters` (sequence of 5 ints): number of output channels at each of the five encoder levels (default `(64, 128, 256, 512, 1024)`)

In [2]:
class Unet(nn.Module):
  """U-Net for semantic segmentation (Ronneberger et al., 2015).

  Five ``Encoder`` levels joined by 2x2 max-pooling form the contracting
  path, four ``Decoder`` levels with skip connections form the expanding
  path, and a final 1x1 convolution maps the features to per-class scores.

  All 3x3 convolutions are unpadded, so the output is spatially smaller than
  the input (e.g. a 572x572 input yields a 388x388 output with this depth).

  Args:
    num_classes (int): number of output channels (one score map per class).
    in_features (int): number of channels of the input image.
    drop_rate (float): dropout rate used by the two deepest encoders.
    filters (sequence of 5 ints): output channels at each encoder level.
  """

  def __init__(self, num_classes, in_features=1, drop_rate=0.5,
               filters=(64, 128, 256, 512, 1024)):
    super().__init__()

    # Contracting path; following the paper, dropout is applied only at the
    # end of the contracting path (the two deepest levels).
    self.encoder1 = Encoder(in_features, filters[0])
    self.encoder2 = Encoder(filters[0], filters[1])
    self.encoder3 = Encoder(filters[1], filters[2])
    self.encoder4 = Encoder(filters[2], filters[3], drop_rate)
    self.encoder5 = Encoder(filters[3], filters[4], drop_rate)

    # Expanding path: each decoder upsamples and fuses the matching encoder
    # output through a skip connection.
    self.decoder1 = Decoder(filters[4], filters[3])
    self.decoder2 = Decoder(filters[3], filters[2])
    self.decoder3 = Decoder(filters[2], filters[1])
    self.decoder4 = Decoder(filters[1], filters[0])

    # Final 1x1 classifier producing one score map per class.
    self.classifier = nn.Conv2d(filters[0], num_classes, 1)

    # Xavier-normal init for every Conv2d. nn.ConvTranspose2d is not a
    # subclass of nn.Conv2d, so the up-convolutions keep PyTorch's default.
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight)

  def forward(self, x):
    """Return per-class score maps for the image batch ``x``.

    ``x`` is presumably (batch, in_features, H, W); H and W must be large
    enough to survive four poolings plus the unpadded convolutions.
    """
    encoder_1 = self.encoder1(x)
    encoder_2 = self.encoder2(F.max_pool2d(encoder_1, 2))
    encoder_3 = self.encoder3(F.max_pool2d(encoder_2, 2))
    encoder_4 = self.encoder4(F.max_pool2d(encoder_3, 2))
    encoder_5 = self.encoder5(F.max_pool2d(encoder_4, 2))

    f_decoder = self.decoder1(encoder_5, encoder_4)
    f_decoder = self.decoder2(f_decoder, encoder_3)
    f_decoder = self.decoder3(f_decoder, encoder_2)
    f_decoder = self.decoder4(f_decoder, encoder_1)

    return self.classifier(f_decoder)



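The encoder block extracts features along the contracting path (the left side of the "U"). Following the U-Net paper, which places dropout at the end of the contracting path, `drop_rate` is used only by the two deepest encoders.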

`e_in_feature` (int): number of input features

`e_out_feature` (int): number of output features

`drop_rate` (float): dropout rate applied at the end of the block; 0 (the default) disables dropout


In [3]:
class Encoder(nn.Module):
  """Two unpadded 3x3 convolutions, each followed by ReLU, plus optional dropout.

  Each unpadded 3x3 convolution trims one pixel from every border, so the
  block shrinks both spatial dimensions by 4 (e.g. 10x10 -> 6x6).

  Args:
    e_in_feature (int): number of input channels.
    e_out_feature (int): number of output channels of both convolutions.
    drop_rate (float): dropout probability applied at the end of the block;
      0 (the default) disables dropout.
  """

  def __init__(self, e_in_feature, e_out_feature, drop_rate=0):
    super().__init__()

    # The activation class in torch.nn is ReLU (capitalised "ReLU").
    layers = [
        nn.Conv2d(e_in_feature, e_out_feature, 3),
        nn.ReLU(inplace=True),
        nn.Conv2d(e_out_feature, e_out_feature, 3),
        nn.ReLU(inplace=True),
    ]

    if drop_rate > 0:
      layers.append(nn.Dropout(drop_rate))

    self.features = nn.Sequential(*layers)

  def forward(self, x):
    """Apply the conv-ReLU-conv-ReLU(-dropout) stack to ``x``."""
    return self.features(x)

The decoder block upsamples features with a 2×2 transposed convolution ("up-convolution"), concatenates the result with the center-cropped feature map from the corresponding encoder (skip connection), and refines the concatenation with two 3×3 convolutions.

`d_in_feature` (int): number of input features

`d_out_feature` (int): number of output features

In [4]:
class Decoder(nn.Module):
  """U-Net expanding-path block: up-convolution, skip concatenation, refinement.

  Args:
    d_in_feature (int): channels of the incoming (deeper) feature map; also
      the input width of the refining ``Encoder``.
    d_out_feature (int): channels produced by the up-convolution and by the
      whole block.

  The skip tensor passed to ``forward`` must have
  ``d_in_feature - d_out_feature`` channels, so that the concatenation matches
  the refining ``Encoder``'s input width. With the usual halving filter sizes
  this equals ``d_out_feature``.
  """

  def __init__(self, d_in_feature, d_out_feature):
    super().__init__()

    # Attribute names are kept unchanged so existing state_dict keys load.
    self.encoderr = Encoder(d_in_feature, d_out_feature)
    # Kernel 2, stride 2: doubles both spatial dimensions.
    self.decoderr = nn.ConvTranspose2d(d_in_feature, d_out_feature, 2, 2)

  def forward(self, x, f_encoder):
    """Upsample ``x``, fuse it with the center crop of ``f_encoder``, refine."""
    f_decoder = F.relu(self.decoderr(x), True)

    # Center-crop the skip connection to the upsampled size. Height and width
    # are cropped independently so non-square feature maps also line up.
    crop_h, crop_w = f_decoder.shape[-2:]
    top = (f_encoder.size(-2) - crop_h) // 2
    left = (f_encoder.size(-1) - crop_w) // 2
    crop = f_encoder[:, :, top:top + crop_h, left:left + crop_w]

    return self.encoderr(torch.cat([f_decoder, crop], 1))