In [2]:
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    """Two stacked unpadded 3x3 convolutions, each followed by a ReLU.

    Both convolutions use stride 1 and ``padding=0``, so each one trims one
    pixel from every border: spatial size goes from (H, W) to (H - 4, W - 4).

    Args:
        input_channel: number of channels of the incoming feature map.
        output_channel: number of channels produced by both convolutions.
    """

    def __init__(self, input_channel, output_channel):
        super().__init__()
        self.input_channel = input_channel
        self.output_channel = output_channel
        layers = [
            nn.Conv2d(input_channel, output_channel, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(output_channel, output_channel, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
        ]
        # A positional Sequential keeps the state_dict keys as db_conv.0.* / db_conv.2.*.
        self.db_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply conv -> ReLU -> conv -> ReLU to ``x``."""
        return self.db_conv(x)

In [3]:
class UNet(nn.Module):
    """U-Net with unpadded ("valid") convolutions and center-cropped skips.

    Encoder: for each width in ``hidden_channels``, a ``DoubleConv`` whose
    output is saved as a skip connection, followed by 2x2 max pooling.
    Bottleneck: a ``DoubleConv`` to ``2 * hidden_channels[-1]`` channels.
    Decoder: for each encoder stage, deepest first, a 2x2 stride-2
    ``ConvTranspose2d`` that halves the channel count, a center crop of the
    matching skip connection to the upsampled size, channel concatenation
    (skip first), and a ``DoubleConv`` back to that stage's width.
    Head: a 1x1 convolution to ``output_channel`` channels.

    Args:
        input_channel: number of channels of the input tensor.
        output_channel: number of channels produced by the final 1x1 conv.
        hidden_channels: encoder widths, shallowest first. Must be non-empty.
            Widths need not double from stage to stage.

    Raises:
        ValueError: if ``hidden_channels`` is empty.
    """

    def __init__(self,
                 input_channel=3,
                 output_channel=2,
                 hidden_channels=(64, 128, 256, 512)):
        super().__init__()
        hidden_channels = list(hidden_channels)
        if not hidden_channels:
            raise ValueError("hidden_channels must contain at least one width")

        # Registration order (down_sample, up_sample, max_pooling, middle,
        # final_layer) fixes the parameters()/state_dict() ordering; keep it
        # stable so existing checkpoints and optimizer states still line up.
        self.down_sample = nn.Sequential()
        self.up_sample = nn.Sequential()
        self.max_pooling = nn.MaxPool2d(2, 2)

        # Encoder
        in_c = input_channel
        for channel in hidden_channels:
            self.down_sample.append(DoubleConv(in_c, channel))
            in_c = channel

        # Bottleneck
        self.middle = DoubleConv(hidden_channels[-1], hidden_channels[-1] * 2)

        # Decoder: up_sample holds interleaved (ConvTranspose2d, DoubleConv) pairs.
        in_c = hidden_channels[-1] * 2
        for channel in reversed(hidden_channels):
            up_c = in_c // 2
            self.up_sample.append(nn.ConvTranspose2d(in_c, up_c, 2, 2))
            # After concatenation the tensor has skip (channel) + upsampled
            # (up_c) channels; this equals in_c when widths double, so layer
            # shapes match that case while other width schedules also work.
            self.up_sample.append(DoubleConv(channel + up_c, channel))
            in_c = channel

        # Head
        self.final_layer = nn.Conv2d(in_c, output_channel, 1, 1, 0)

    def crop_center_fm(self, feature_map, out_shape):
        """Center-crop ``feature_map`` spatially to ``out_shape[2:4]``.

        Args:
            feature_map: tensor whose dims 2 and 3 are the spatial dims.
            out_shape: shape-like sequence; only entries 2 and 3 (target
                height and width) are read.

        Returns:
            A slice of ``feature_map`` whose spatial size is exactly
            ``(out_shape[2], out_shape[3])``. When the size difference is odd,
            the extra row/column is dropped from the bottom/right.
        """
        h, w = feature_map.shape[2], feature_map.shape[3]
        out_h, out_w = out_shape[2], out_shape[3]
        top, left = (h - out_h) // 2, (w - out_w) // 2
        # Slice by the target size rather than symmetric margins so that an
        # odd size difference still yields exactly out_h x out_w.
        return feature_map[:, :, top:top + out_h, left:left + out_w]

    def forward(self, x):
        """Run encoder, bottleneck, decoder and 1x1 head on ``x``.

        Expects a batched 4-D tensor (N, C, H, W); ``crop_center_fm`` indexes
        spatial dims 2 and 3. Because the convolutions are unpadded, the
        output is spatially smaller than ``x``, and ``x`` must be large enough
        that no stage shrinks to zero size.
        """
        skip_connections = []
        for down in self.down_sample:
            x = down(x)
            skip_connections.append(x)
            x = self.max_pooling(x)

        x = self.middle(x)

        up_modules = list(self.up_sample)
        # Pair each (upsample, conv) step with the skips, deepest first.
        for upper, conv, skip in zip(up_modules[0::2], up_modules[1::2],
                                     reversed(skip_connections)):
            x = upper(x)
            cropped = self.crop_center_fm(skip, x.shape)
            x = conv(torch.cat((cropped, x), dim=1))

        return self.final_layer(x)
  