In [None]:
class FCN32(VGG):
    """FCN-32s segmentation network on a VGG16 backbone.

    The backbone is ``make_layers(cfg['vgg16'])``. VGG's fc6/fc7 are
    re-expressed as convolutions (``conv1``: 7x7, 512->4096; ``conv2``: 1x1,
    4096->4096). They are followed by a 1x1 score layer (``classifier``) and a
    single stride-32 transposed convolution (``upsampler``) back to input
    resolution.
    """

    def __init__(self):
        super(FCN32, self).__init__(make_layers(cfg['vgg16']))

        self.numclass = 21

        self.relu = nn.ReLU(True)
        self.dropout = nn.Dropout2d()

        # fc6 / fc7 of VGG16 as convolutions; weights come from load_pretrained.
        self.conv1 = nn.Conv2d(512, 4096, kernel_size=7)
        self.conv2 = nn.Conv2d(4096, 4096, kernel_size=1)

        self.classifier = nn.Conv2d(4096, self.numclass, kernel_size=1, stride=1, padding=0)
        self.upsampler = nn.ConvTranspose2d(self.numclass, self.numclass, kernel_size=64, stride=32, bias=False)

        self._initialize_weights()

    def load_pretrained(self, pretrained_model):
        """Take the backbone and fc6/fc7 parameters from a pretrained VGG16.

        ``self.features`` is set to ``pretrained_model.features`` (the same
        module object, not a copy). The weights and biases of
        ``pretrained_model.classifier[0]`` (fc6) and ``classifier[3]`` (fc7)
        are copied in place into ``conv1`` / ``conv2``, reshaped to
        (4096, 512, 7, 7) and (4096, 4096, 1, 1).

        Copying in place keeps the existing Parameter objects. Their device
        is preserved, and an optimizer that already holds them stays valid.

        Args:
            pretrained_model: presumably a torchvision-style VGG16 whose
                ``classifier[0]`` / ``classifier[3]`` are the fc6 / fc7
                ``nn.Linear`` layers. Their weights must hold 4096*512*7*7
                and 4096*4096 elements for the reshapes to succeed.
        """
        self.features = pretrained_model.features
        fc6 = pretrained_model.classifier[0]
        fc7 = pretrained_model.classifier[3]

        with torch.no_grad():
            self.conv1.weight.copy_(fc6.weight.view(4096, 512, 7, 7))
            self.conv1.bias.copy_(fc6.bias)
            self.conv2.weight.copy_(fc7.weight.view(4096, 4096, 1, 1))
            self.conv2.bias.copy_(fc7.bias)

    def vgg_layer_forward(self, x, indices):
        """Apply ``self.features[start:end]`` to ``x``, where ``indices == (start, end)``."""
        output = x
        start_idx, end_idx = indices
        for idx in range(start_idx, end_idx):
            output = self.features[idx](output)
        return output

    def vgg_forward(self, x):
        """Run the backbone and collect the output of each stage.

        Returns a dict with keys ``'pool1'`` .. ``'pool5'``. Each value is the
        activation after ``self.features`` up to index 5, 10, 17, 24 and 31.
        Presumably each boundary ends on a max-pool of VGG16 without batch
        norm; confirm against ``cfg['vgg16']``.
        """
        out = {}
        layer_indices = [0, 5, 10, 17, 24, 31]
        stage_bounds = zip(layer_indices[:-1], layer_indices[1:])
        for stage, bounds in enumerate(stage_bounds, start=1):
            x = self.vgg_layer_forward(x, bounds)
            out[f'pool{stage}'] = x
        return out

    def forward(self, x):
        """Return per-class score maps of shape (N, numclass, H, W) for ``x``.

        ``H, W`` are ``x.shape[-2:]``. ``x`` is assumed to be an (N, 3, H, W)
        image batch (TODO confirm the channel count against ``cfg['vgg16']``).
        """
        # Pad 100 px on every side so the 7x7 conv1 and the stride-32 path
        # still produce an output large enough to crop back to (H, W).
        padded_x = F.pad(x, [100, 100, 100, 100], "constant", 0)
        # detach(): no gradient flows back into self.features.
        vgg_pool5 = self.vgg_forward(padded_x)['pool5'].detach()

        h = self.dropout(self.relu(self.conv1(vgg_pool5)))
        h = self.dropout(self.relu(self.conv2(h)))

        upsampled = self.upsampler(self.classifier(h))

        # Re-align the upsampled map with x. This assumes the backbone uses
        # 3x3 padding-1 convs and 2x2 stride-2 pools, as torchvision's
        # make_layers does:
        #   - pool5[m] is centred on padded pixel 32*m + 15.5;
        #   - the 7x7 conv1 output n is centred on pool5[n + 3];
        #   - the kernel-64 / stride-32 deconv puts input n at 32*n + 31.5.
        # So deconv output p is centred on padded pixel
        #   p - 31.5 + 96 + 15.5 = p + 80, i.e. on x pixel p - 20.
        # The often-quoted offset 19 assumes conv1_1 itself has padding=100.
        # NOTE(review): with floor-mode pooling the deconv output can be a few
        # pixels short of 20 + H for some sizes. torchvision's crop zero-fills
        # the missing border in that case.
        return transforms.functional.crop(
            upsampled, top=20, left=20, height=x.shape[-2], width=x.shape[-1])

    def _initialize_weights(self):
        """Initialise every ConvTranspose2d with ``get_upsampling_weight``.

        The intent is bilinear upsampling. Other layer types are left
        untouched.
        """
        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d):
                assert m.kernel_size[0] == m.kernel_size[1]
                initial_weight = get_upsampling_weight(
                    m.in_channels, m.out_channels, m.kernel_size[0])
                with torch.no_grad():
                    m.weight.copy_(initial_weight)


In [None]:
class FCN8(FCN32):
    """FCN-8s: the FCN32 fc6/fc7 head plus skip connections from pool4 and pool3.

    The following come from FCN32: ``conv1`` / ``conv2`` (fc6 / fc7 as
    convolutions), ``relu``, ``dropout``, ``vgg_forward`` and
    ``load_pretrained``. FCN32's ``classifier`` and ``upsampler`` are also
    still created by ``FCN32.__init__``, but this ``forward`` does not use them.
    """

    def __init__(self):
        super(FCN8, self).__init__()

        self.numclass = 21

        # 1x1 score layers for fc7 (4096 ch), pool4 (512 ch) and pool3 (256 ch).
        self.classifier1 = nn.Conv2d(4096, self.numclass, kernel_size=1)
        self.classifier2 = nn.Conv2d(512, self.numclass, kernel_size=1)
        self.classifier3 = nn.Conv2d(256, self.numclass, kernel_size=1)

        # Learnable upsampling: 2x (to pool4), 2x (to pool3), then 8x (to input).
        self.trans1 = nn.ConvTranspose2d(self.numclass, self.numclass, kernel_size=4, stride=2, bias=False)
        self.trans2 = nn.ConvTranspose2d(self.numclass, self.numclass, kernel_size=4, stride=2, bias=False)
        self.trans3 = nn.ConvTranspose2d(self.numclass, self.numclass, kernel_size=16, stride=8, bias=False)

        # Re-run the bilinear initialisation so the new deconvs are covered.
        self._initialize_weights()

    def forward(self, x):
        """Return per-class score maps of shape (N, numclass, H, W) for ``x``.

        ``H, W`` are ``x.shape[-2:]``. ``x`` is assumed to be an (N, 3, H, W)
        image batch (TODO confirm the channel count against ``cfg['vgg16']``).
        """
        # Pad 100 px on every side so every stage has room to crop.
        padded_x = F.pad(x, [100, 100, 100, 100], "constant", 0)
        vgg_features = self.vgg_forward(padded_x)
        # detach(): no gradient flows back into self.features.
        vgg_pool5 = vgg_features['pool5'].detach()
        vgg_pool4 = vgg_features['pool4'].detach()
        vgg_pool3 = vgg_features['pool3'].detach()

        h = self.dropout(self.relu(self.conv1(vgg_pool5)))
        h = self.dropout(self.relu(self.conv2(h)))

        # The skip inputs are scaled down before scoring. Presumably this keeps
        # the skip branches small at the start of training; confirm the intent.
        score_fc7 = self.classifier1(h)
        score_pool4 = self.classifier2(0.01 * vgg_pool4)
        score_pool3 = self.classifier3(0.0001 * vgg_pool3)

        # 2x upsample the fc7 scores and fuse them with the pool4 scores.
        # Offset 5 aligns pool4 with trans1's output grid.
        fused = self.trans1(score_fc7)
        fused = fused + transforms.functional.crop(
            score_pool4, top=5, left=5, height=fused.shape[-2], width=fused.shape[-1])

        # 2x upsample again and fuse with the pool3 scores (offset 9).
        fused = self.trans2(fused)
        fused = fused + transforms.functional.crop(
            score_pool3, top=9, left=9, height=fused.shape[-2], width=fused.shape[-1])

        # After the pool3 fusion, index q is centred on padded pixel 8*q + 75.5.
        # The kernel-16 / stride-8 deconv maps q to 8*q + 7.5. Output p is
        # therefore centred on padded pixel p + 68, i.e. on x pixel p - 32.
        # The usual 31 assumes the padding is folded into conv1_1 as
        # padding=100. Here F.pad(100) sits on top of the backbone's own
        # padding-1 convs, which shifts everything by one pixel.
        upsampled = self.trans3(fused)
        return transforms.functional.crop(
            upsampled, top=32, left=32, height=x.shape[-2], width=x.shape[-1])
