[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/SaurabhRBhandari/semantic-segmentation/blob/main/sfnet.ipynb)

In [76]:
import torch 
from torch import nn, Tensor
from torch.nn import functional as F
import math
from torchsummary import summary

In [77]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions with no channel expansion.

    Args:
        c1: number of input channels.
        c2: number of output channels.
        s: stride of the first convolution (the second always uses stride 1).
        downsample: optional module applied to the input to produce the
            shortcut branch; when ``None`` the input itself is the shortcut.
    """
    expansion: int = 1

    def __init__(self, c1, c2, s=1, downsample=None) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(c1, c2, kernel_size=3, stride=s, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(c2)
        self.conv2 = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(c2)
        self.downsample = downsample

    def forward(self, x: Tensor) -> Tensor:
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)``."""
        y = self.conv1(x)
        y = F.relu(self.bn1(y))
        y = self.conv2(y)
        y = self.bn2(y)
        # Shortcut branch: raw input, or its projection when one was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return F.relu(y)


class Bottleneck(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 stack and 4x channel expansion.

    Args:
        c1: number of input channels.
        c2: width of the two inner convolutions; the block outputs
            ``c2 * expansion`` channels.
        s: stride of the 3x3 convolution.
        downsample: optional module applied to the input to produce the
            shortcut branch; when ``None`` the input itself is the shortcut.
    """
    expansion: int = 4

    def __init__(self, c1, c2, s=1, downsample=None) -> None:
        super().__init__()
        out_channels = c2 * self.expansion
        self.conv1 = nn.Conv2d(c1, c2, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(c2)
        self.conv2 = nn.Conv2d(c2, c2, kernel_size=3, stride=s, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(c2)
        self.conv3 = nn.Conv2d(c2, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x: Tensor) -> Tensor:
        """Run the three conv/BN stages, add the shortcut and apply ReLU."""
        y = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        # Shortcut branch: raw input, or its projection when one was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return F.relu(y)


# Per-variant ResNet configuration, keyed by the depth name passed to ResNet.
# Each value is [block class, number of blocks in each of the four stages,
# number of output channels of each of the four stages].
# The channel list is exposed as ``ResNet.channels`` and handed to the head.
resnet_settings = {
    '18': [BasicBlock, [2, 2, 2, 2], [64, 128, 256, 512]],
    '34': [BasicBlock, [3, 4, 6, 3], [64, 128, 256, 512]],
    '50': [Bottleneck, [3, 4, 6, 3], [256, 512, 1024, 2048]],
    '101': [Bottleneck, [3, 4, 23, 3], [256, 512, 1024, 2048]],
    '152': [Bottleneck, [3, 8, 36, 3], [256, 512, 1024, 2048]]
}


class ResNet(nn.Module):
    """ResNet feature extractor that returns the output of each of its four stages.

    Args:
        model_name: key into ``resnet_settings`` selecting the block type,
            per-stage depths and per-stage output channels.

    Attributes:
        channels: per-stage output channel counts taken from ``resnet_settings``.

    Raises:
        ValueError: if ``model_name`` is not a key of ``resnet_settings``.
    """
    def __init__(self, model_name: str = '50') -> None:
        super().__init__()
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if model_name not in resnet_settings:
            raise ValueError(f"ResNet model name should be in {list(resnet_settings.keys())}")
        block, depths, channels = resnet_settings[model_name]

        # Running input width for the next stage; updated by _make_layer.
        self.inplanes = 64
        self.channels = channels
        self.conv1 = nn.Conv2d(3, self.inplanes, 7, 2, 3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.maxpool = nn.MaxPool2d(3, 2, 1)

        self.layer1 = self._make_layer(block, 64, depths[0], s=1)
        self.layer2 = self._make_layer(block, 128, depths[1], s=2)
        self.layer3 = self._make_layer(block, 256, depths[2], s=2)
        self.layer4 = self._make_layer(block, 512, depths[3], s=2)

    def _make_layer(self, block, planes, depth, s=1) -> nn.Sequential:
        """Build one stage of ``depth`` blocks and advance ``self.inplanes``.

        Only the first block uses stride ``s``; it also gets a 1x1 conv + BN
        projection on the shortcut whenever the stride or the channel count
        changes, so that the residual addition has matching shapes.
        """
        out_planes = planes * block.expansion
        downsample = None
        if s != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, 1, s, bias=False),
                nn.BatchNorm2d(out_planes)
            )
        layers = nn.Sequential(
            block(self.inplanes, planes, s, downsample),
            *[block(out_planes, planes) for _ in range(1, depth)]
        )
        self.inplanes = out_planes
        return layers

    def forward(self, x: Tensor) -> tuple:
        """Return the tuple ``(x1, x2, x3, x4)`` of stage outputs for input ``x``."""
        x = self.maxpool(F.relu(self.bn1(self.conv1(x))))   # [N, 64, H/4, W/4]
        x1 = self.layer1(x)   # [N, 64/256, H/4, W/4]
        x2 = self.layer2(x1)  # [N, 128/512, H/8, W/8]
        x3 = self.layer3(x2)  # [N, 256/1024, H/16, W/16]
        x4 = self.layer4(x3)  # [N, 512/2048, H/32, W/32]
        return x1, x2, x3, x4

In [78]:
class ConvModule(nn.Sequential):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` and an in-place ``ReLU``.

    Args:
        c1: input channels.
        c2: output channels.
        k: kernel size.
        s: stride.
        p: padding.
        d: dilation.
        g: number of groups.
    """
    def __init__(self, c1, c2, k, s=1, p=0, d=1, g=1):
        conv = nn.Conv2d(
            c1, c2,
            kernel_size=k, stride=s, padding=p, dilation=d, groups=g,
            bias=False,
        )
        norm = nn.BatchNorm2d(c2)
        act = nn.ReLU(inplace=True)
        # Positional order fixes the child indices 0/1/2 used in state_dict keys.
        super().__init__(conv, norm, act)
        
class PPM(nn.Module):
    """Pyramid Pooling Module in PSPNet.

    For every entry of ``scales`` the input is adaptively average-pooled to
    that output size and passed through a 1x1 ``ConvModule``. The pooled
    branches are resized back to the input resolution, concatenated with the
    input and fused by a 3x3 ``ConvModule``.

    Args:
        c1: input channels.
        c2: channels of every pooled branch and of the fused output.
        scales: output sizes given to ``AdaptiveAvgPool2d``, one per branch.
    """
    def __init__(self, c1, c2=128, scales=(1, 2, 3, 6)):
        super().__init__()
        branches = []
        for bins in scales:
            branches.append(nn.Sequential(nn.AdaptiveAvgPool2d(bins), ConvModule(c1, c2, 1)))
        self.stages = nn.ModuleList(branches)

        fused_channels = c1 + len(scales) * c2
        self.bottleneck = ConvModule(fused_channels, c2, 3, 1, 1)

    def forward(self, x: Tensor) -> Tensor:
        """Return the fused pyramid feature with ``c2`` channels at ``x``'s resolution."""
        target_hw = x.shape[-2:]
        pooled = [
            F.interpolate(branch(x), size=target_hw, mode='bilinear', align_corners=True)
            for branch in self.stages
        ]
        # Input first, then the branches in reverse order of ``scales``.
        stacked = torch.cat([x, *reversed(pooled)], dim=1)
        return self.bottleneck(stacked)

In [79]:
class AlignedModule(nn.Module):
    """Flow alignment module: warps a coarse feature map onto a finer one.

    A 2-channel flow field is predicted from the (channel-reduced) fine and
    upsampled coarse features, and the original coarse feature is resampled
    with ``grid_sample`` along that flow at the fine resolution.

    Args:
        c1: channels of both input feature maps.
        c2: channels after the 1x1 reduction used for flow prediction.
        k: kernel size of the flow-prediction convolution. Padding is
            ``k // 2`` so that, for odd ``k``, the flow keeps the spatial size
            of its input (identical to the previous fixed padding for ``k=3``).
    """
    def __init__(self, c1, c2, k=3):
        super().__init__()
        self.down_h = nn.Conv2d(c1, c2, 1, bias=False)
        self.down_l = nn.Conv2d(c1, c2, 1, bias=False)
        # "Same" padding derived from k instead of a hard-coded 1.
        self.flow_make = nn.Conv2d(c2 * 2, 2, k, 1, k // 2, bias=False)

    def forward(self, low_feature: Tensor, high_feature: Tensor) -> Tensor:
        """Return ``high_feature`` warped to the spatial size of ``low_feature``.

        Args:
            low_feature: feature map whose height/width define the output size.
            high_feature: feature map to be aligned; it is bilinearly resized
                to that size for flow prediction only, the warp itself samples
                the un-resized original.
        """
        high_feature_origin = high_feature
        H, W = low_feature.shape[-2:]
        low_feature = self.down_l(low_feature)
        high_feature = self.down_h(high_feature)
        high_feature = F.interpolate(high_feature, size=(H, W), mode='bilinear', align_corners=True)
        flow = self.flow_make(torch.cat([high_feature, low_feature], dim=1))
        return self.flow_warp(high_feature_origin, flow, (H, W))

    def flow_warp(self, x: Tensor, flow: Tensor, size: tuple) -> Tensor:
        """Sample ``x`` on a regular grid of ``size`` displaced by ``flow``.

        Args:
            x: input of shape (N, C, H_in, W_in).
            flow: offsets of shape (N, 2, H, W); channel 0 displaces the
                x (width) coordinate and channel 1 the y (height) coordinate.
            size: output size ``(H, W)``.

        Returns:
            Tensor of shape (N, C, H, W).
        """
        out_h, out_w = size
        # grid_sample expects the last grid dimension ordered as (x, y), so the
        # flow must be normalised by (W, H) -- not (H, W) -- to stay correct
        # for non-square feature maps.
        norm = torch.tensor([out_w, out_h], dtype=x.dtype, device=x.device)
        ys = torch.linspace(-1.0, 1.0, out_h, device=x.device).view(-1, 1).expand(out_h, out_w)
        xs = torch.linspace(-1.0, 1.0, out_w, device=x.device).view(1, -1).expand(out_h, out_w)
        # (1, H, W, 2) base grid; broadcasting over the batch replaces repeat().
        base = torch.stack((xs, ys), dim=2).unsqueeze(0).to(x.dtype)
        grid = base + flow.permute(0, 2, 3, 1) / norm
        # NOTE(review): the base grid spans the [-1, 1] endpoints (the
        # align_corners=True convention) while sampling uses
        # align_corners=False; kept as in the original -- confirm intended.
        return F.grid_sample(x, grid, align_corners=False)


class SFHead(nn.Module):
    """Segmentation head: PPM on the last feature plus a flow-aligned top-down pyramid.

    Args:
        in_channels: channel count of every backbone feature, ordered like the
            ``features`` passed to ``forward``; the last entry feeds the PPM.
        channel: common width of all pyramid levels.
        num_classes: output channels of the final 1x1 classifier.
        scales: pooling sizes forwarded to ``PPM``.
    """
    def __init__(self, in_channels, channel=256, num_classes=19, scales=(1, 2, 3, 6)):
        super().__init__()
        self.ppm = PPM(in_channels[-1], channel, scales)

        self.fpn_in = nn.ModuleList()
        self.fpn_out = nn.ModuleList()
        self.fpn_out_align = nn.ModuleList()

        # One lateral conv, one output conv and one alignment module for every
        # feature except the last (which is handled by the PPM).
        lateral_widths = in_channels[:-1]
        for width in lateral_widths:
            self.fpn_in.append(ConvModule(width, channel, 1))
            self.fpn_out.append(ConvModule(channel, channel, 3, 1, 1))
            self.fpn_out_align.append(AlignedModule(channel, channel//2))

        num_levels = len(in_channels)
        self.bottleneck = ConvModule(num_levels * channel, channel, 3, 1, 1)
        self.dropout = nn.Dropout2d(0.1)
        self.conv_seg = nn.Conv2d(channel, num_classes, 1)

    def forward(self, features: list) -> Tensor:
        """Return class logits at the spatial resolution of ``features[0]``."""
        top = self.ppm(features[-1])
        pyramid = [top]

        # Walk from the second-to-last feature down to the first, adding the
        # flow-aligned coarser level to each lateral feature.
        for idx in range(len(features) - 2, -1, -1):
            lateral = self.fpn_in[idx](features[idx])
            top = lateral + self.fpn_out_align[idx](lateral, top)
            pyramid.append(self.fpn_out[idx](top))

        pyramid = pyramid[::-1]
        finest = pyramid[0]
        target_hw = finest.shape[-2:]
        resized = [
            F.interpolate(level, size=target_hw, mode='bilinear', align_corners=True)
            for level in pyramid[1:]
        ]

        fused = self.bottleneck(torch.cat([finest, *resized], dim=1))
        return self.conv_seg(self.dropout(fused))

In [80]:

class SFNet(nn.Module):
    """SFNet segmentation model: ResNet backbone followed by an ``SFHead``.

    Args:
        num_classes: number of output classes (channels of the returned logits).
        backbone: ResNet variant name forwarded to ``ResNet``; defaults to
            ``'50'``, the variant that was previously hard-coded.
    """
    def __init__(self, num_classes=19, backbone: str = '50'):
        super().__init__()
        self.backbone = ResNet(backbone)
        self.head = SFHead(self.backbone.channels, 256, num_classes)

    def forward(self, x):
        """Return logits bilinearly resized to the spatial size of ``x``."""
        outs = self.backbone(x)
        out = self.head(outs)
        out = F.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=True)
        return out
        

In [81]:
# Build the default model and print a per-layer summary for a 3x256x256 input on CPU.
model = SFNet()
summary(model, (3, 256, 256), device="cpu")

Layer (type:depth-idx)                   Output Shape              Param #
├─ResNet: 1-1                            [-1, 256, 64, 64]         --
|    └─Conv2d: 2-1                       [-1, 64, 128, 128]        9,408
|    └─BatchNorm2d: 2-2                  [-1, 64, 128, 128]        128
|    └─MaxPool2d: 2-3                    [-1, 64, 64, 64]          --
|    └─Sequential: 2-4                   [-1, 256, 64, 64]         --
|    |    └─Bottleneck: 3-1              [-1, 256, 64, 64]         75,008
|    |    └─Bottleneck: 3-2              [-1, 256, 64, 64]         70,400
|    |    └─Bottleneck: 3-3              [-1, 256, 64, 64]         70,400
|    └─Sequential: 2-5                   [-1, 512, 32, 32]         --
|    |    └─Bottleneck: 3-4              [-1, 512, 32, 32]         379,392
|    |    └─Bottleneck: 3-5              [-1, 512, 32, 32]         280,064
|    |    └─Bottleneck: 3-6              [-1, 512, 32, 32]         280,064
|    |    └─Bottleneck: 3-7              [-1, 512, 32,

Layer (type:depth-idx)                   Output Shape              Param #
├─ResNet: 1-1                            [-1, 256, 64, 64]         --
|    └─Conv2d: 2-1                       [-1, 64, 128, 128]        9,408
|    └─BatchNorm2d: 2-2                  [-1, 64, 128, 128]        128
|    └─MaxPool2d: 2-3                    [-1, 64, 64, 64]          --
|    └─Sequential: 2-4                   [-1, 256, 64, 64]         --
|    |    └─Bottleneck: 3-1              [-1, 256, 64, 64]         75,008
|    |    └─Bottleneck: 3-2              [-1, 256, 64, 64]         70,400
|    |    └─Bottleneck: 3-3              [-1, 256, 64, 64]         70,400
|    └─Sequential: 2-5                   [-1, 512, 32, 32]         --
|    |    └─Bottleneck: 3-4              [-1, 512, 32, 32]         379,392
|    |    └─Bottleneck: 3-5              [-1, 512, 32, 32]         280,064
|    |    └─Bottleneck: 3-6              [-1, 512, 32, 32]         280,064
|    |    └─Bottleneck: 3-7              [-1, 512, 32,