In [1]:
# from PyTorch's torchvision

import torch
from torch import nn
from torch.nn import functional as F


class ASPPConv(nn.Sequential):
    """One atrous branch of ASPP: dilated 3x3 conv -> BatchNorm2d -> ReLU.

    The conv uses ``padding == dilation`` so, at stride 1, the spatial size of
    the input is preserved. The conv has no bias because BatchNorm follows it.
    """

    def __init__(self, in_channels, out_channels, dilation):
        atrous_conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=dilation,
            dilation=dilation,
            bias=False,
        )
        super(ASPPConv, self).__init__(
            atrous_conv,
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )


class ASPPPooling(nn.Sequential):
    """Image-level ASPP branch.

    Global average pool -> 1x1 conv -> BatchNorm2d -> ReLU, then the 1x1 result
    is bilinearly resized back to the input's spatial size.
    """

    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        target_hw = x.shape[-2:]
        pooled = x
        # Same as nn.Sequential.forward: apply each child in order.
        for layer in self:
            pooled = layer(pooled)
        return F.interpolate(
            pooled, size=target_hw, mode="bilinear", align_corners=False
        )


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling.

    Runs parallel branches on the same input: a 1x1 conv branch, one
    ``ASPPConv`` per atrous rate, and one ``ASPPPooling`` image-level branch.
    Their outputs are concatenated along channels and fused by a 1x1
    projection (conv -> BN -> ReLU -> Dropout).

    Args:
        in_channels: channels of the input feature map.
        atrous_rates: iterable of dilation rates; one 3x3 atrous branch is built
            per rate (DeepLab configurations use three).
        out_channels: width of every branch and of the output (default 256).
        dropout: dropout probability applied after the projection (default 0.1).
    """

    def __init__(self, in_channels, atrous_rates, out_channels=256, dropout=0.1):
        super(ASPP, self).__init__()
        branches = [
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )
        ]
        branches.extend(
            ASPPConv(in_channels, out_channels, rate) for rate in atrous_rates
        )
        branches.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(branches)

        # The concatenated width follows the real branch count
        # (1x1 + one per rate + pooling), i.e. 5 * out_channels for three rates.
        self.project = nn.Sequential(
            nn.Conv2d(len(branches) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        stacked = torch.cat([branch(x) for branch in self.convs], dim=1)
        return self.project(stacked)


class DeepLabHeadV3Plus(nn.Module):
    """DeepLabV3+ decoder head.

    Projects the low-level features to 48 channels, runs ASPP on the high-level
    features, resizes both to a common size, concatenates them and classifies
    with a 3x3 conv block followed by a 1x1 conv.

    Args:
        in_channels: channels of ``feature["out"]`` (the ASPP input).
        low_level_channels: channels of ``feature["low_level"]``.
        num_classes: number of output channels of the classifier.
        aspp_dilate: atrous rates forwarded to ``ASPP``.
        output_size: spatial size both maps are resized to; when falsy, the
            low-level feature's spatial size is used.
    """

    def __init__(
        self,
        in_channels=2048,
        low_level_channels=256,
        num_classes=43,
        aspp_dilate=(12, 24, 36),  # tuple: avoids a shared mutable default
        output_size=None,
    ):
        super(DeepLabHeadV3Plus, self).__init__()
        low_level_out = 48
        aspp_out = 256  # ASPP output width
        self.project = nn.Sequential(
            nn.Conv2d(low_level_channels, low_level_out, 1, bias=False),
            nn.BatchNorm2d(low_level_out),
            nn.ReLU(inplace=True),
        )
        self.output_size = output_size

        self.aspp = ASPP(in_channels, aspp_dilate)

        self.classifier = nn.Sequential(
            # Input is the concatenation of projected low-level + ASPP features (48 + 256 = 304).
            nn.Conv2d(low_level_out + aspp_out, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1),
        )
        self._init_weight()

    def forward(self, feature):
        """Classify the fused features.

        Args:
            feature: mapping with keys ``"low_level"`` and ``"out"`` holding NCHW
                feature maps.

        Returns:
            Logits of shape ``(N, num_classes, *output_size)``.
        """
        low_level_feature = self.project(feature["low_level"])
        output_feature = self.aspp(feature["out"])
        output_size = (
            self.output_size if self.output_size else low_level_feature.shape[2:]
        )
        low_level_feature = F.interpolate(
            low_level_feature, size=output_size, mode="bilinear", align_corners=False
        )
        output_feature = F.interpolate(
            output_feature, size=output_size, mode="bilinear", align_corners=False
        )
        return self.classifier(torch.cat([low_level_feature, output_feature], dim=1))

    def _init_weight(self):
        # Kaiming-normal for every conv (including those inside ASPP); BN/GN to identity.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

In [2]:
# from PyTorch's torchvision

import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url

# __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
# 		   'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
# 		   'wide_resnet50_2', 'wide_resnet101_2']


# Pretrained ImageNet checkpoint URLs, keyed by architecture name (looked up by _resnet).
model_urls = dict(
    resnet50="https://download.pytorch.org/models/resnet50-19c8e357.pth",
)


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    Padding is set equal to ``dilation`` so that, at stride 1, the spatial size
    is preserved.
    """
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_options)


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 (pointwise) ``nn.Conv2d`` with optional stride."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )


class BasicBlock(nn.Module):
    """Two-conv residual block.

    Main branch: 3x3 conv (carries ``stride``) -> norm -> ReLU -> 3x3 conv -> norm.
    The result is added to the shortcut (``x``, or ``downsample(x)`` when a
    ``downsample`` module is given) and passed through a final ReLU. Output
    channels equal ``planes`` (``expansion = 1``).

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """

    expansion = 1

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super(BasicBlock, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # When stride != 1, conv1 and downsample both reduce the resolution.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        shortcut = x if self.downsample is None else self.downsample(x)

        residual += shortcut
        return self.relu(residual)


class Bottleneck(nn.Module):
    """Three-conv residual bottleneck block.

    Main branch: 1x1 conv (to ``width``) -> norm -> ReLU -> 3x3 conv (carries
    ``stride``, ``groups`` and ``dilation``) -> norm -> ReLU -> 1x1 conv (to
    ``planes * expansion``) -> norm. The result is added to the shortcut (``x``
    or ``downsample(x)``) and passed through a final ReLU.

    ``width = int(planes * base_width / 64) * groups``; ``expansion = 4``.
    """

    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super(Bottleneck, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.0)) * groups
        expanded = planes * self.expansion
        # When stride != 1, conv2 (the 3x3) and downsample both reduce the resolution.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, expanded)
        self.bn3 = norm_layer(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            residual = self.relu(norm(conv(residual)))
        residual = self.bn3(self.conv3(residual))

        shortcut = x if self.downsample is None else self.downsample(x)

        residual += shortcut
        return self.relu(residual)


class ResNet(nn.Module):
    """ResNet: stem, four residual stages, global average pool, linear classifier.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``); its
            ``expansion`` sets each stage's output width.
        layers: number of blocks in each of the four stages.
        num_classes: output size of the final ``nn.Linear``.
        zero_init_residual: zero the last norm weight of every residual block so
            each block initially passes its shortcut through unchanged.
        groups: conv groups forwarded to every block.
        width_per_group: ``base_width`` forwarded to every block.
        replace_stride_with_dilation: three flags for stages 2-4; a True flag
            replaces that stage's stride-2 with dilation.
        norm_layer: normalization constructor (default ``nn.BatchNorm2d``).

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have 3 elements.
    """

    def __init__(
        self,
        block,
        layers,
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm_layer=None,
    ):
        super(ResNet, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        # One flag per stage 2-4: True swaps that stage's 2x2 stride for dilation.
        dilate_flags = (
            [False, False, False]
            if replace_stride_with_dilation is None
            else replace_stride_with_dilation
        )
        if len(dilate_flags) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(dilate_flags)
            )
        self.groups = groups
        self.base_width = width_per_group

        # Stem: 7x7/2 conv -> norm -> ReLU -> 3x3/2 max-pool.
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages, registered in order layer1..layer4.
        self.layer1 = self._make_layer(block, 64, layers[0])
        stage_widths = (128, 256, 512)
        for stage_no, (width, dilate) in enumerate(
            zip(stage_widths, dilate_flags), start=2
        ):
            stage = self._make_layer(
                block, width, layers[stage_no - 1], stride=2, dilate=dilate
            )
            setattr(self, "layer{}".format(stage_no), stage)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        # Zero the last norm of each residual branch so every block starts as an
        # identity mapping (reported +0.2~0.3% in https://arxiv.org/abs/1706.02677).
        if zero_init_residual:
            for module in self.modules():
                if isinstance(module, Bottleneck):
                    nn.init.constant_(module.bn3.weight, 0)
                elif isinstance(module, BasicBlock):
                    nn.init.constant_(module.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks; updates ``self.inplanes``/``self.dilation``."""
        norm_layer = self._norm_layer
        # The first block uses the dilation in effect before this stage.
        entry_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        out_planes = planes * block.expansion

        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            # Projection shortcut to match the main branch's shape.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_planes, stride),
                norm_layer(out_planes),
            )

        stage_blocks = [
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                self.groups,
                self.base_width,
                entry_dilation,
                norm_layer,
            )
        ]
        self.inplanes = out_planes
        stage_blocks.extend(
            block(
                self.inplanes,
                planes,
                groups=self.groups,
                base_width=self.base_width,
                dilation=self.dilation,
                norm_layer=norm_layer,
            )
            for _ in range(1, blocks)
        )

        return nn.Sequential(*stage_blocks)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ``ResNet`` and, if ``pretrained``, load the weights at ``model_urls[arch]``.

    Extra ``kwargs`` are forwarded to the ``ResNet`` constructor.
    """
    net = ResNet(block, layers, **kwargs)
    if not pretrained:
        return net
    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    net.load_state_dict(weights)
    return net


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): if True, load ImageNet weights from ``model_urls["resnet50"]``.
        progress (bool): if True, show a download progress bar on stderr.
        **kwargs: forwarded to ``ResNet`` (e.g. ``replace_stride_with_dilation``).
    """
    blocks_per_stage = [3, 4, 6, 3]
    return _resnet(
        "resnet50", Bottleneck, blocks_per_stage, pretrained, progress, **kwargs
    )

In [None]:
import torch
import torch.nn as nn

class BasicConv(nn.Module):
    """Conv2d -> optional pooling -> optional BatchNorm2d -> optional activation.

    Args:
        in_planes, out_planes, kernel_size, stride, padding, dilation, groups,
        bias: forwarded to ``nn.Conv2d`` (bias defaults to False because a
            BatchNorm normally follows).
        pooling: module applied right after the conv; omitted when falsy.
        batchnorm: append ``nn.BatchNorm2d(out_planes)`` when True.
        act: activation module appended last. By default each instance gets its
            own fresh ``nn.ReLU(inplace=True)``; pass ``None`` to omit it.

    Note: the layers are registered both in ``self.layers`` and in
    ``self.sequential``, so ``state_dict`` lists each parameter under both
    prefixes. This is kept as-is for checkpoint compatibility.
    """

    # Sentinel meaning "use a fresh default activation". A module instance as a
    # default argument would be created once and shared by every BasicConv.
    _DEFAULT_ACT = object()

    def __init__(
        self,
        in_planes,
        out_planes,
        kernel_size=3,
        stride=1,
        padding=1,
        dilation=1,
        bias=False,
        pooling=None,
        groups=1,
        batchnorm=True,
        act=_DEFAULT_ACT,
    ):
        super(BasicConv, self).__init__()
        if act is BasicConv._DEFAULT_ACT:
            act = nn.ReLU(inplace=True)

        self.layers = nn.ModuleList(
            [
                nn.Conv2d(
                    in_planes,
                    out_planes,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    groups,
                    bias,
                ),  # training acceleration: bias off before BN
            ]
        )

        if pooling:
            self.layers.append(pooling)

        if batchnorm:
            self.layers.append(nn.BatchNorm2d(out_planes))

        if act:
            self.layers.append(act)

        self.sequential = nn.Sequential(*self.layers)

    def forward(self, x):
        return self.sequential(x)


class ResNet_Proj(nn.Module):
    """Pool and re-project the two deepest entries of a feature list.

    ``forward`` replaces ``semantic_feat[2]`` and ``semantic_feat[3]`` with
    ``BasicConv`` (conv -> adaptive avg-pool -> BN -> ReLU) outputs of the same
    channel count. The list is modified in place and also returned.

    Args:
        channels: channel counts of ``semantic_feat[2]`` and ``semantic_feat[3]``.
            The defaults (1024, 2048) match a ResNet-50 layer3/layer4 output.
        pool_sizes: adaptive-average-pool output sizes for those two maps.
    """

    def __init__(self, channels=(1024, 2048), pool_sizes=(24, 12)):
        super(ResNet_Proj, self).__init__()
        # Index 2 / 3 of the feature list (layer3 / layer4 outputs in Res_DeepLabV3P).
        third_channels, fourth_channels = channels
        third_pool, fourth_pool = pool_sizes
        self.conv_bn_relu1 = BasicConv(
            third_channels, third_channels, pooling=nn.AdaptiveAvgPool2d(third_pool)
        )
        self.conv_bn_relu2 = BasicConv(
            fourth_channels, fourth_channels, pooling=nn.AdaptiveAvgPool2d(fourth_pool)
        )

    def forward(self, semantic_feat):
        # Mutates the caller's list in place (existing behavior), then returns it.
        semantic_feat[2] = self.conv_bn_relu1(semantic_feat[2])
        semantic_feat[3] = self.conv_bn_relu2(semantic_feat[3])
        return semantic_feat


class Res_DeepLabV3P(nn.Module):
    """ResNet-50 backbone for semantic feature extraction (DeepLabV3+ setting).

    Strides of layer3 and layer4 are replaced by dilation, giving an overall
    output stride of 8. ``forward`` returns
    ``{"backbone": [layer1..layer4 outputs, last two passed through ResNet_Proj],
    "layer0": stem output (after max-pool)}``.

    Args:
        pretrained: load ImageNet weights into the ResNet-50 (default True).
    """

    def __init__(self, pretrained: bool = True) -> None:
        super(Res_DeepLabV3P, self).__init__()

        # ``resnet50`` is the factory defined in the ResNet cell of this notebook;
        # the former ``resnet.__dict__["resnet50"]`` referenced an undefined name.
        self.resnet = resnet50(
            pretrained=pretrained, replace_stride_with_dilation=[False, True, True]
        )
        self.conv1 = self.resnet.conv1
        self.bn1 = self.resnet.bn1
        self.relu = self.resnet.relu
        self.max_pool = self.resnet.maxpool
        self.layer1 = self.resnet.layer1
        self.layer2 = self.resnet.layer2
        self.layer3 = self.resnet.layer3
        self.layer4 = self.resnet.layer4

        self.projection = ResNet_Proj()

    def forward(self, x):
        x = self.max_pool(self.relu(self.bn1(self.conv1(x))))
        layer0 = x

        features = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            features.append(x)

        features = self.projection(features)

        return {"backbone": features, "layer0": layer0}

In [None]:
# ---------------------------------------------------------------
# https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------

import torch
import torch.nn as nn
from functools import partial

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import math
import os


class DWConv(nn.Module):
    """Depth-wise 3x3 convolution over a token sequence.

    ``forward`` folds tokens ``(B, N, C)`` into a ``(B, C, H, W)`` grid (so
    ``N`` must equal ``H * W``), convolves each channel separately, and
    flattens back to ``(B, N, C)``.
    """

    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(
            dim, dim, kernel_size=3, stride=1, padding=1, bias=True, groups=dim
        )

    def forward(self, x, H, W):
        batch, _, channels = x.shape
        grid = x.transpose(1, 2).view(batch, channels, H, W)
        return self.dwconv(grid).flatten(2).transpose(1, 2)


class Mlp(nn.Module):
    """Mix-FFN: Linear -> depth-wise 3x3 conv -> activation -> dropout -> Linear -> dropout.

    ``hidden_features`` and ``out_features`` default to ``in_features``.
    ``forward(x, H, W)`` takes tokens of shape ``(B, H*W, in_features)``.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        hidden_dim = hidden_features or in_features
        out_dim = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.dwconv = DWConv(hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Linear: truncated normal (std 0.02), zero bias. LayerNorm: identity.
        # Conv2d: He-normal using fan_out / groups, zero bias.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        hidden = self.drop(self.act(self.dwconv(self.fc1(x), H, W)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head self-attention with optional spatial reduction of keys/values.

    When ``sr_ratio > 1``, keys and values are computed from the token grid after
    a ``sr_ratio``-strided conv (kernel = stride = ``sr_ratio``) and a LayerNorm,
    which shrinks the key/value sequence by roughly ``sr_ratio ** 2``. Queries
    always use the full sequence. ``forward(x, H, W)`` takes ``(B, H*W, dim)``
    tokens and returns the same shape.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        sr_ratio=1,
    ):
        super().__init__()
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        per_head = dim // num_heads
        # A falsy qk_scale (None or 0) falls back to 1/sqrt(head_dim).
        self.scale = qk_scale or per_head**-0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Linear: truncated normal (std 0.02), zero bias. LayerNorm: identity.
        # Conv2d: He-normal using fan_out / groups, zero bias.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        B, N, C = x.shape
        head_dim = C // self.num_heads
        # (B, heads, N, head_dim)
        q = self.q(x).reshape(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3)

        kv_source = x
        if self.sr_ratio > 1:
            grid = x.permute(0, 2, 1).reshape(B, C, H, W)
            kv_source = self.norm(self.sr(grid).reshape(B, C, -1).permute(0, 2, 1))
        # (2, B, heads, N_kv, head_dim) -> separate keys and values
        k, v = (
            self.kv(kv_source)
            .reshape(B, -1, 2, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )

        weights = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        weights = self.attn_drop(weights)

        merged = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(merged))


class Block(nn.Module):
    """Pre-norm transformer block used by each MiT stage.

    ``x = x + drop_path(attn(norm1(x)))`` followed by
    ``x = x + drop_path(mlp(norm2(x)))``. ``drop_path`` is a ``DropPath`` when
    ``drop_path > 0`` and ``nn.Identity`` otherwise.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        sr_ratio=1,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            sr_ratio=sr_ratio,
        )
        # Stochastic depth on each residual branch (instead of plain dropout).
        self.drop_path = nn.Identity() if not drop_path > 0.0 else DropPath(drop_path)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Linear: truncated normal (std 0.02), zero bias. LayerNorm: identity.
        # Conv2d: He-normal using fan_out / groups, zero bias.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        attended = x + self.drop_path(self.attn(self.norm1(x), H, W))
        return attended + self.drop_path(self.mlp(self.norm2(attended), H, W))


class OverlapPatchEmbed(nn.Module):
    """Overlapping patch embedding.

    A conv with kernel ``patch_size``, the given ``stride`` and padding
    ``patch_size // 2`` maps the image to ``embed_dim`` channels. The result is
    flattened to tokens ``(B, H*W, embed_dim)``, layer-normalized, and returned
    together with the output grid size ``H, W``.

    ``self.H``, ``self.W`` and ``self.num_patches`` are computed from
    ``img_size // patch_size`` and are not used by ``forward``.
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        img_hw = to_2tuple(img_size)
        patch_hw = to_2tuple(patch_size)

        self.img_size = img_hw
        self.patch_size = patch_hw
        self.H = img_hw[0] // patch_hw[0]
        self.W = img_hw[1] // patch_hw[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_hw,
            stride=stride,
            padding=tuple(k // 2 for k in patch_hw),
        )
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Linear: truncated normal (std 0.02), zero bias. LayerNorm: identity.
        # Conv2d: He-normal using fan_out / groups, zero bias.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        feat = self.proj(x)
        _, _, H, W = feat.shape
        tokens = self.norm(feat.flatten(2).transpose(1, 2))
        return tokens, H, W


class MixVisionTransformer(nn.Module):
    """Mix Transformer (MiT) hierarchical encoder from SegFormer.

    The encoder has four stages. Each stage embeds its input with an
    ``OverlapPatchEmbed`` (7x7 / stride 4 for stage 1, 3x3 / stride 2 for
    stages 2-4), runs a stack of ``Block``s and a final ``norm_layer``, and emits
    a ``(B, C_i, H_i, W_i)`` feature map that also feeds the next stage.
    ``forward`` returns the list of the four stage maps; no classification head
    is applied.

    Args:
        img_size: nominal input size passed to the patch embeddings
            (``img_size``, ``//4``, ``//8``, ``//16``).
        patch_size: unused; per-stage patch sizes are fixed to (7, 3, 3, 3).
        in_chans: input image channels.
        num_classes: stored as ``self.num_classes``; no head is built here.
        embed_dims, num_heads, mlp_ratios, depths, sr_ratios: per-stage settings
            (4 entries each).
        qkv_bias, qk_scale, drop_rate, attn_drop_rate, norm_layer: forwarded to
            every ``Block``.
        drop_path_rate: maximum stochastic-depth rate; rates grow linearly from 0
            over all blocks of all stages.

    Subclasses populate ``official_ckpts`` (filename -> URL) for
    ``load_official_state_dict``.
    """

    official_ckpts = {}
    _NUM_STAGES = 4

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dims=(64, 128, 256, 512),
        num_heads=(1, 2, 4, 8),
        mlp_ratios=(4, 4, 4, 4),
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        depths=(3, 4, 6, 3),
        sr_ratios=(8, 4, 2, 1),
    ):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        # Width of the last stage; reset_classifier builds its head on this.
        self.embed_dim = embed_dims[-1]

        # Patch embeddings, registered as patch_embed1..patch_embed4:
        # (nominal img_size, kernel, stride, input channels) per stage.
        embed_specs = [
            (img_size, 7, 4, in_chans),
            (img_size // 4, 3, 2, embed_dims[0]),
            (img_size // 8, 3, 2, embed_dims[1]),
            (img_size // 16, 3, 2, embed_dims[2]),
        ]
        for idx, (size, kernel, stride, chans) in enumerate(embed_specs):
            setattr(
                self,
                f"patch_embed{idx + 1}",
                OverlapPatchEmbed(
                    img_size=size,
                    patch_size=kernel,
                    stride=stride,
                    in_chans=chans,
                    embed_dim=embed_dims[idx],
                ),
            )

        # Transformer stages, registered as block1, norm1, block2, norm2, ...
        # Stochastic depth decay rule: linearly increasing drop-path rates.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for idx in range(self._NUM_STAGES):
            blocks = nn.ModuleList(
                [
                    Block(
                        dim=embed_dims[idx],
                        num_heads=num_heads[idx],
                        mlp_ratio=mlp_ratios[idx],
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        drop=drop_rate,
                        attn_drop=attn_drop_rate,
                        drop_path=dpr[cur + i],
                        norm_layer=norm_layer,
                        sr_ratio=sr_ratios[idx],
                    )
                    for i in range(depths[idx])
                ]
            )
            setattr(self, f"block{idx + 1}", blocks)
            setattr(self, f"norm{idx + 1}", norm_layer(embed_dims[idx]))
            cur += depths[idx]

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Linear: truncated normal (std 0.02), zero bias. LayerNorm: identity.
        # Conv2d: He-normal using fan_out / groups, zero bias.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def reset_drop_path(self, drop_path_rate):
        """Re-assign linearly increasing stochastic-depth rates to every block.

        Blocks built with rate 0 hold an ``nn.Identity``, which has no
        ``drop_prob``. Such blocks get a new ``DropPath`` (in the block's current
        train/eval mode) when their new rate is positive.
        """
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0
        for idx, depth in enumerate(self.depths[: self._NUM_STAGES]):
            stage = getattr(self, f"block{idx + 1}")
            for i in range(depth):
                blk = stage[i]
                rate = dpr[cur + i]
                if isinstance(blk.drop_path, DropPath):
                    blk.drop_path.drop_prob = rate
                elif rate > 0.0:
                    blk.drop_path = DropPath(rate).train(blk.training)
            cur += depth

    def freeze_patch_emb(self):
        """Stop gradient updates to the stage-1 patch embedding's parameters."""
        # Setting a ``requires_grad`` attribute on the module does not touch its
        # parameters; requires_grad_ does.
        self.patch_embed1.requires_grad_(False)

    @torch.jit.ignore
    def no_weight_decay(self):
        # Names follow the upstream convention; this model defines no pos_embed /
        # cls_token parameters, so in practice nothing here matches.
        return {
            "pos_embed1",
            "pos_embed2",
            "pos_embed3",
            "pos_embed4",
            "cls_token",
        }  # has pos_embed may be better

    def get_classifier(self):
        """Return the head created by ``reset_classifier``, or None if there is none."""
        return getattr(self, "head", None)

    def reset_classifier(self, num_classes, global_pool=""):
        """Create a linear head on the last-stage width (Identity if ``num_classes <= 0``).

        ``global_pool`` is accepted for API compatibility and ignored. ``forward``
        does not apply the head.
        """
        self.num_classes = num_classes
        self.head = (
            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

    def forward_features(self, x):
        """Run the four stages and return their ``(B, C_i, H_i, W_i)`` outputs."""
        B = x.shape[0]
        outs = []
        for idx in range(1, self._NUM_STAGES + 1):
            x, H, W = getattr(self, f"patch_embed{idx}")(x)
            for blk in getattr(self, f"block{idx}"):
                x = blk(x, H, W)
            x = getattr(self, f"norm{idx}")(x)
            # tokens (B, H*W, C) -> feature map (B, C, H, W) for output and next stage
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
            outs.append(x)

        return outs

    def forward(self, x):
        return self.forward_features(x)

    def load_official_state_dict(
        self,
        filename: str,
        local_dir: str = None,
        download: bool = True,
        strict: bool = False,
    ):
        """Load an official checkpoint listed in ``official_ckpts``.

        The file is read from ``local_dir/filename`` if it exists. Otherwise, when
        ``download`` is True, it is fetched from ``official_ckpts[filename]`` with
        ``gdown`` and saved at that path first.

        Args:
            filename: key of ``official_ckpts``.
            local_dir: directory holding the checkpoint; defaults to
                ``~/.cache/torch/checkpoints``.
            download: fetch the file when it is not found locally.
            strict: forwarded to ``load_state_dict``; with False, missing and
                unexpected keys are ignored.

        Raises:
            AssertionError: if ``filename`` is not a key of ``official_ckpts``.
            ValueError: if the file is missing locally and ``download`` is False.
        """

        assert (
            filename in self.official_ckpts.keys()
        ), f"available checkpoints are {self.official_ckpts.keys()}"

        if local_dir is None:
            local_dir = os.path.join(
                os.path.expanduser("~"), ".cache", "torch", "checkpoints"
            )

        path = os.path.join(local_dir, filename)
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        if os.path.isfile(path):
            ckpt = torch.load(path, map_location="cpu")
        elif download:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            import gdown

            url = self.official_ckpts[filename]
            gdown.download(url, path, quiet=False)
            ckpt = torch.load(path, map_location="cpu")
        else:
            raise ValueError("You neither proivde local path nor set download True!")

        # NOTE(review): ckpt is used directly as a state dict, i.e. assumed flat
        # (not wrapped as {"state_dict": ...}); with strict=False a wrapped file
        # would silently load nothing.
        self.load_state_dict(ckpt, strict=strict)
        print("loaded pretrained weight from", path)

    def reset_input_channel(self, new_in_chans, pretrained=True):
        """Replace the stage-1 patch-embedding conv to accept ``new_in_chans`` inputs.

        With ``pretrained``, the old weights are copied cyclically over the input
        channels (new channel ``c`` takes old channel ``c % old_in_chans``) and the
        bias is copied unchanged. The new conv is placed on the old weight's
        device and dtype.
        See https://stackoverflow.com/questions/62629114/how-to-modify-resnet-50-with-4-channels-as-input-using-pre-trained-weights-in-py
        and https://github.com/qubvel/segmentation_models.pytorch/blob/master/segmentation_models_pytorch/encoders/_utils.py
        """
        old_proj = self.patch_embed1.proj
        weight = old_proj.weight.detach()
        bias = old_proj.bias.detach() if old_proj.bias is not None else None
        old_in_chans = weight.shape[1]

        new_proj = nn.Conv2d(
            new_in_chans,
            old_proj.out_channels,
            kernel_size=old_proj.kernel_size,
            stride=old_proj.stride,
            padding=old_proj.padding,
        ).to(device=weight.device, dtype=weight.dtype)

        if pretrained:
            with torch.no_grad():
                for ch in range(new_in_chans):
                    new_proj.weight[:, ch, :, :] = weight[:, ch % old_in_chans, :, :]
                # Bias is per output channel, so it carries over as-is.
                if bias is not None:
                    new_proj.bias.copy_(bias)

        self.patch_embed1.proj = new_proj


class MiTB5(MixVisionTransformer):
    """MiT-B5 encoder: ``MixVisionTransformer`` with the B5 hyper-parameters below.

    Keyword arguments are forwarded to ``MixVisionTransformer`` and take precedence
    over the B5 defaults (e.g. ``MiTB5(drop_path_rate=0.2)``). Previously they were
    accepted but silently ignored.
    """

    official_ckpts = {
        "mit_b5.pth": "https://drive.google.com/uc?export=download&id=1d7I50jVjtCddnhpf-lqj8-f13UyCzoW1"
    }

    def __init__(self, **kwargs):
        config = dict(
            patch_size=4,
            embed_dims=[64, 128, 320, 512],
            num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            depths=[3, 6, 40, 3],
            sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0,
            drop_path_rate=0.1,
        )
        # Caller-supplied options override the B5 defaults.
        config.update(kwargs)
        super(MiTB5, self).__init__(**config)


class SegFormer(nn.Module):
    """Spatial backbone wrapper around an MiT-B5 encoder.

    ``forward`` returns the first four outputs of the encoder as a list.

    NOTE(review): ``checkpoint_path`` and ``req_grad`` are accepted but currently
    not used by ``__init__``: no checkpoint is loaded and nothing is frozen here.
    Call ``turn_grad`` explicitly to freeze or unfreeze the encoder.
    """

    def __init__(self, checkpoint_path=None, req_grad=False) -> None:
        super(SegFormer, self).__init__()
        self.segformer = MiTB5()

    def turn_grad(self, req_grad):
        """Enable or disable gradients for every parameter of the encoder.

        Args:
            req_grad: False freezes the encoder; True makes it trainable again.
        """
        # Apply the flag in both directions; previously the ``else`` was bound to the
        # ``for`` loop (for/else), so turn_grad(True) did nothing.
        flag = bool(req_grad)
        for param in self.segformer.parameters():
            param.requires_grad = flag
        if not flag:
            print("Backbone SegFormer: turned off requires_grad")

    def forward(self, x):
        """Run the encoder and return its first four outputs as a list."""
        rgb_list = self.segformer(x)
        return [rgb_list[i] for i in range(4)]

In [None]:
import torch.nn as nn
from einops import rearrange
import math


## qkv prep
class SqueezeAndExcitationSEShared(nn.Module):
    """Squeeze-and-excitation gate over 1-D token features with an optional shared encoding.

    The input is average-pooled over its last axis and projected to ``channel_mid``
    channels. It is optionally summed with ``sem_encod`` and passed through
    ``activation``. It is then projected back to ``channel`` channels and turned into
    sigmoid weights that rescale the input channel-wise.

    Args:
        channel: number of channels of the input (dim 1) and of the output.
        activation: module applied between the two 1x1 convolutions. Defaults to a
            new ``nn.ReLU()`` per instance.
        channel_mid: bottleneck width; defaults to ``channel // 2``.
    """

    def __init__(self, channel, activation=None, channel_mid=None):
        super(SqueezeAndExcitationSEShared, self).__init__()
        if channel_mid is None:
            channel_mid = channel // 2
        if activation is None:
            # Build one module per instance. A default of ``nn.ReLU()`` in the signature is
            # evaluated once and the same object would be shared by every instance.
            activation = nn.ReLU()

        self.avgpool1d = nn.AdaptiveAvgPool1d(1)
        self.conv1d1 = nn.Conv1d(channel, channel_mid, kernel_size=1)
        self.act = activation
        self.conv1d2 = nn.Conv1d(channel_mid, channel, kernel_size=1)
        self.sigm = nn.Sigmoid()

    def forward(self, rgb, sem_encod):
        """Gate ``rgb`` channel-wise.

        Args:
            rgb: features laid out as (B, channel, L) (Conv1d/AdaptiveAvgPool1d layout).
            sem_encod: tensor added to the (B, channel_mid, 1) bottleneck, or None.
                Presumably (B, channel_mid, 1) — TODO confirm with callers.

        Returns:
            ``rgb`` multiplied by per-channel weights in (0, 1); same shape as ``rgb``.
        """
        weighting = self.conv1d1(self.avgpool1d(rgb))
        if sem_encod is not None:
            weighting = weighting + sem_encod

        weighting = self.sigm(self.conv1d2(self.act(weighting)))

        return rgb * weighting


class Semantic_To_KV(nn.Module):
    """Turn flattened semantic tokens into attention keys and values.

    Tokens are re-weighted channel-wise by a ``SqueezeAndExcitationSEShared`` gate
    (conditioned on ``sem_encod``). They are then split in two along the channel
    axis: the first half becomes the keys, the second half the values.

    Args:
        embed_dim: channel count of the semantic tokens.
        semantic_assert: bottleneck width of the gate; ``sem_encod`` is added to the
            bottleneck activations inside the gate.
    """

    def __init__(self, embed_dim, semantic_assert):
        super(Semantic_To_KV, self).__init__()
        self.projection = SqueezeAndExcitationSEShared(
            embed_dim, channel_mid=semantic_assert
        )

    def forward(self, semantic_feature, sem_encod):
        """
        Args:
            semantic_feature: tokens laid out as (hw, b, c).
            sem_encod: encoding forwarded to the gate, or None.

        Returns:
            (sem_k, sem_v): the two channel halves of the gated tokens, each (hw, b, c/2)
            for even c.
        """
        channels_first = rearrange(semantic_feature, "hw b c -> b c hw")
        gated = self.projection(channels_first, sem_encod)
        tokens = rearrange(gated, "b c hw -> hw b c")
        keys, values = tokens.chunk(2, dim=-1)
        return keys, values


## Attention block
class PreNorm(nn.Module):
    """Apply LayerNorm over the last dimension, then call the wrapped module.

    Args:
        dim: size of the normalised (last) dimension.
        fn: module applied to the normalised tensor; extra keyword arguments given
            to ``forward`` are passed straight through to it.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)


class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output feature size.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability used after the GELU and after the output layer.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        # A single Sequential named ``net`` keeps parameter names (net.0.*, net.3.*)
        # unchanged, so existing checkpoints still load.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    """Cross-attention from query tokens to key/value tokens.

    Pipeline: multi-head attention -> LayerNorm -> pre-norm feed-forward, followed by a
    residual connection back to the queries. Inputs use the sequence-first layout of
    ``nn.MultiheadAttention`` (its default ``batch_first=False``): (L, B, embed_dim).

    Args:
        embed_dim: token width; must be divisible by ``num_heads``.
        num_heads: number of attention heads (default 8, the previous hard-coded value).
        ff_hidden_dim: hidden width of the feed-forward block (default 1024, the
            previous hard-coded value).
    """

    def __init__(self, embed_dim, num_heads=8, ff_hidden_dim=1024):
        super(Attention, self).__init__()

        self.attn = nn.MultiheadAttention(embed_dim, num_heads=num_heads)
        self.norm = nn.LayerNorm(embed_dim)
        self.ff = PreNorm(embed_dim, FeedForward(embed_dim, ff_hidden_dim))

    def forward(self, spatial_q, semantic_k, semantic_v):
        """
        Args:
            spatial_q: queries, (L_q, B, embed_dim).
            semantic_k: keys, (L_kv, B, embed_dim).
            semantic_v: values, (L_kv, B, embed_dim).

        Returns:
            (L_q, B, embed_dim) tensor.
        """
        attention, _ = self.attn(spatial_q, semantic_k, semantic_v, need_weights=False)
        attention = self.norm(attention)
        attention = self.ff(attention)
        # Residual connection to the queries (there is no separate residual around ff).
        attention = attention + spatial_q
        return attention


# Context Correlation Attention (CCA) Module
class CCA(nn.Module):
    """Context Correlation Attention: spatial features attend to semantic features.

    Steps:
        1. Spatial features are flattened into tokens and linearly projected to queries.
        2. Semantic features are flattened, gated by ``sem_encod`` and split into keys
           and values.
        3. The attended tokens are reshaped back into a feature map with the spatial
           input's height and width.

    Args:
        spatial_dim: channels of ``spatial_feature``.
        semantic_dim: channels of ``semantic_feature``. They are split in half into K
            and V, so ``semantic_dim // 2`` presumably has to equal ``transform_dim``.
        transform_dim: width of the attention tokens and channels of the output map.
        semantic_assert: bottleneck width of the semantic gate.
    """

    def __init__(self, spatial_dim, semantic_dim, transform_dim, semantic_assert):
        super(CCA, self).__init__()

        self.spatial_to_q = nn.Linear(spatial_dim, transform_dim)
        self.semantic_to_kv = Semantic_To_KV(semantic_dim, semantic_assert)
        self.attention = Attention(transform_dim)

    def unflatten(self, x, h=None):
        """Reshape (h*w, b, c) tokens back into a (b, c, h, w) map.

        Args:
            x: tokens laid out as (h*w, b, c).
            h: output height. If omitted, the token grid is assumed to be square.
        """
        if h is None:
            # Legacy fallback: only correct for square feature maps.
            h = int(math.sqrt(x.shape[0]))
        return rearrange(x, "(h w) b c -> b c h w", h=h)

    def forward(self, spatial_feature, semantic_feature, sem_encod):
        """
        Args:
            spatial_feature: (b, spatial_dim, h, w) map providing the queries.
            semantic_feature: (b, semantic_dim, h', w') map providing keys/values.
            sem_encod: encoding forwarded to the semantic gate, or None.

        Returns:
            (b, transform_dim, h, w) map at the spatial feature's resolution.
        """
        height = spatial_feature.shape[2]
        spatial = rearrange(spatial_feature, "b c h w -> (h w) b c")
        spatial_q = self.spatial_to_q(spatial)

        semantic = rearrange(semantic_feature, "b c h w -> (h w) b c")
        semantic_k, semantic_v = self.semantic_to_kv(semantic, sem_encod)

        attended = self.attention(spatial_q, semantic_k, semantic_v)

        # Use the real height; inferring it as sqrt(h*w) breaks for non-square inputs.
        return self.unflatten(attended, h=height)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Any, Optional

from model.backbone.ResNet import Res_DeepLabV3P
from model.backbone.SegFormer import SegFormer
from model.backbone.DeepLab import DeepLabHeadV3Plus
from model.UperNet import UPerNet
from model.SAA import SAA
from model.CCA import CCA


class Sem_Enc(nn.Module):
    """
    Semantic Encoding Module.

    Turns ResNet semantic features into a per-class encoding.

    Pipeline:
        1. A DeepLabV3+ head projects ``features[0]`` (low level) and ``features[3]``
           (high level) to ``num_classes`` channels.
        2. Three depthwise conv stages shrink the map spatially; the first two stages
           are each followed by average pooling.
        3. The result has shape (B, num_classes, 1) when the final map is 1x1.

    Spatial sizes, for a projection output of side S:
        - conv1 (k=7, pad=4) gives S + 2.
        - pool1 (k=6) gives (S + 2) // 6.
        - conv2 (k=5, pad=2) keeps the size.
        - pool2 (k=4) divides it by 4.
        - conv3 (k=4, pad=0) subtracts 3.

    The output is therefore 1x1 only when the pool2 output is 4x4, i.e. for
    S in [94, 117]. Otherwise ``squeeze(3)`` leaves a 4-D tensor.

    NOTE(review): the same BatchNorm layer (``self.bn``) is reused after all three
    stages, so its running statistics mix three different feature distributions.

    Args:
        num_classes (int): Number of semantic classes (channels of the encoding).
    """

    def __init__(self, num_classes: int) -> None:
        """
        Build the projection head, the three depthwise conv stages and the shared BN.

        Args:
            num_classes (int): Number of semantic classes (channels of the encoding).
        """
        super(Sem_Enc, self).__init__()

        def depthwise(kernel_size: int, padding: int) -> nn.Conv2d:
            # One filter per class channel (groups == channels), stride 1, no bias.
            return nn.Conv2d(
                in_channels=num_classes,
                out_channels=num_classes,
                kernel_size=kernel_size,
                stride=1,
                padding=padding,
                groups=num_classes,
                bias=False,
            )

        # Creation order matches the parameter-initialisation order of earlier versions.
        self.projection = DeepLabHeadV3Plus(num_classes=num_classes)
        self.conv1 = depthwise(kernel_size=7, padding=4)
        self.pool1 = nn.AvgPool2d(kernel_size=6)
        self.conv2 = depthwise(kernel_size=5, padding=2)
        self.pool2 = nn.AvgPool2d(kernel_size=4)
        self.conv3 = depthwise(kernel_size=4, padding=0)
        self.bn = nn.BatchNorm2d(num_classes)
        self.relu = nn.ReLU()

    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Compute the semantic encoding.

        Args:
            features (List[torch.Tensor]): Semantic backbone feature maps; only
                ``features[0]`` (low level) and ``features[3]`` (high level) are used.

        Returns:
            torch.Tensor: (B, num_classes, 1) when the final map is 1x1 (see class docs).
        """
        enc = self.projection({"low_level": features[0], "out": features[3]})

        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, None),
        )
        for conv, pool in stages:
            enc = conv(enc)
            if pool is not None:
                enc = pool(enc)
            # Shared BatchNorm + ReLU after every stage.
            enc = self.relu(self.bn(enc))

        # Drop the trailing width axis: (B, C, 1, 1) -> (B, C, 1).
        return enc.squeeze(3)


class GlassSemNet(nn.Module):
    """
    GlassSemNet Model for Semantic Segmentation.

    This model integrates spatial and semantic backbones, semantic encodings,
    Scene Aware Activation (SAA) modules, a Context Correlation Attention (CCA) module,
    and a decoder (UPerNet) to produce segmentation outputs.

    Architecture Components:
        - Spatial Backbone: SegFormer (MiT-B5; channels 64, 128, 320, 512)
        - Semantic Backbone: Res_DeepLabV3P (channels 256, 512, 1024, 2048 as used below)
        - Semantic Encoding: Sem_Enc
        - SAA Modules: saa0, saa1, saa2
        - CCA Module: cca3
        - Decoder: UPerNet (built with its default ``num_class=1``)
        - Auxiliary heads: aux1, aux2. They are registered but not used by ``forward``;
          presumably they are kept so existing checkpoints still load.

    Args:
        num_classes (int): Number of semantic classes. It sets the width of the
            semantic encoding (``Sem_Enc``) and the bottleneck of the SAA/CCA gates.
            Defaults to 43. It does not change the number of output channels, which is
            set by the decoder.
    """

    def __init__(self, num_classes: int = 43) -> None:
        """
        Initializes the GlassSemNet model.

        Args:
            num_classes (int): Number of semantic classes (default 43).
        """
        super(GlassSemNet, self).__init__()

        # Number of semantic classes used by the semantic encoding and the gates
        self.num_classes: int = num_classes

        # Spatial backbone (SegFormer)
        self.spatial_backbone: SegFormer = SegFormer()

        # Semantic backbone (ResNet with DeepLabV3+)
        self.semantic_backbone: Res_DeepLabV3P = Res_DeepLabV3P()

        # Semantic encoding module
        self.sem_enc: Sem_Enc = Sem_Enc(num_classes=self.num_classes)

        # Scene Aware Activation (SAA) modules for the three lower feature levels
        self.saa0: SAA = SAA(
            spatial_dim=64, semantic_dim=256, semantic_assert=self.num_classes
        )
        self.saa1: SAA = SAA(
            spatial_dim=128, semantic_dim=512, semantic_assert=self.num_classes
        )
        self.saa2: SAA = SAA(
            spatial_dim=320, semantic_dim=1024, semantic_assert=self.num_classes
        )

        # Context Correlation Attention (CCA) module for the highest feature level
        self.cca3: CCA = CCA(
            spatial_dim=512,
            semantic_dim=2048,
            transform_dim=1024,
            semantic_assert=self.num_classes,
        )

        # Auxiliary 1x1 heads (not used by forward)
        self.aux1: nn.Conv2d = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1)
        self.aux2: nn.Conv2d = nn.Conv2d(
            in_channels=1024, out_channels=1, kernel_size=1
        )

        # Decoder (UPerNet); its default fpn_inplanes (64, 576, 1152, 2368, 3584)
        # match the channel counts of the decoder inputs built in forward.
        self.decoder: UPerNet = UPerNet()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the GlassSemNet model.

        Args:
            x (torch.Tensor): Input tensor with shape (B, 3, H, W).

        Returns:
            torch.Tensor: Decoder output of shape (B, 1, H', W'). (H', W') is the
                decoder's output resolution, presumably that of its finest input
                (``semantic_lowlevel``).
        """
        # ---------------------------
        # Spatial Backbone Forward Pass
        # ---------------------------
        spatial_feats: List[torch.Tensor] = self.spatial_backbone(x)
        # spatial_feats: four SegFormer maps with 64, 128, 320, 512 channels

        # ---------------------------
        # Semantic Backbone Forward Pass
        # ---------------------------
        resnet_out: Dict[str, Any] = self.semantic_backbone(x)
        semantic_feats: List[torch.Tensor] = resnet_out["backbone"]
        semantic_lowlevel: torch.Tensor = resnet_out["layer0"]
        # semantic_feats: four ResNet maps with 256, 512, 1024, 2048 channels

        # ---------------------------
        # Semantic Encodings
        # ---------------------------
        sem_enc: torch.Tensor = self.sem_enc(semantic_feats)
        # sem_enc: (B, num_classes, 1)

        # ---------------------------
        # Scene Aware Activation (SAA) Modules
        # ---------------------------
        saa0: torch.Tensor = self.saa0(
            spatial_feature=spatial_feats[0],
            semantic_feature=semantic_feats[0],
            sem_encod=sem_enc,
        )
        saa1: torch.Tensor = self.saa1(
            spatial_feature=spatial_feats[1],
            semantic_feature=semantic_feats[1],
            sem_encod=sem_enc,
        )
        saa2: torch.Tensor = self.saa2(
            spatial_feature=spatial_feats[2],
            semantic_feature=semantic_feats[2],
            sem_encod=sem_enc,
        )

        # ---------------------------
        # Context Correlation Attention (CCA) Module
        # ---------------------------
        cca3: torch.Tensor = self.cca3(
            spatial_feature=spatial_feats[3],
            semantic_feature=semantic_feats[3],
            sem_encod=sem_enc,
        )

        # ---------------------------
        # Decoder Preparation
        # ---------------------------
        # Concatenate spatial, semantic, and activated features for each level
        l0: torch.Tensor = torch.cat(
            [spatial_feats[0], semantic_feats[0], saa0], dim=1
        )  # Channels: 64 + 256 + 256 = 576
        l1: torch.Tensor = torch.cat(
            [spatial_feats[1], semantic_feats[1], saa1], dim=1
        )  # Channels: 128 + 512 + 512 = 1152
        l2: torch.Tensor = torch.cat(
            [spatial_feats[2], semantic_feats[2], saa2], dim=1
        )  # Channels: 320 + 1024 + 1024 = 2368
        l3: torch.Tensor = torch.cat(
            [spatial_feats[3], semantic_feats[3], cca3], dim=1
        )  # Channels: 512 + 2048 + 1024 = 3584

        # Decoder inputs, finest first: [low_level, l0, l1, l2, l3]
        decoder_feats: List[torch.Tensor] = [semantic_lowlevel, l0, l1, l2, l3]

        # ---------------------------
        # Decoder Forward Pass
        # ---------------------------
        out: torch.Tensor = self.decoder(decoder_feats)

        return out


# Example usage:
# if __name__ == '__main__':
#     # Create a random input tensor with batch size 2, 3 channels, and 384x384 spatial dimensions
#     x: torch.Tensor = torch.rand(2, 3, 384, 384)

#     # Initialize the GlassSemNet model
#     model: GlassSemNet = GlassSemNet()

#     # Perform a forward pass
#     out: torch.Tensor = model(x)

#     # Print the output shape
#     print(out.shape)  # Expected shape: (2, num_class, H', W')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.backbone.ResNet import BasicConv


# CBAM
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into one (via ``view``)."""

    def forward(self, x):
        batch_size = x.size(0)
        return x.view(batch_size, -1)


class ChannelGate(nn.Module):
    """CBAM-style channel attention with an optional additive encoding.

    Computation:
        1. Global-average-pool ``x`` over its spatial extent.
        2. Apply Linear, add ``sem_encod`` if given, then ReLU and Linear.
        3. Apply a sigmoid to get per-channel weights that rescale ``x``.

    Args:
        gate_channels: number of channels of the gated feature map.
        compressed_channels: bottleneck width; defaults to ``gate_channels // 2``.
    """

    def __init__(self, gate_channels, compressed_channels=None):
        super(ChannelGate, self).__init__()
        hidden = gate_channels // 2 if compressed_channels is None else compressed_channels
        self.flat = Flatten()
        self.lin1 = nn.Linear(gate_channels, hidden)
        self.relu = nn.ReLU()
        self.lin2 = nn.Linear(hidden, gate_channels)

    def forward(self, x, sem_encod):
        height, width = x.size(2), x.size(3)
        # Pool over the full spatial extent -> (B, C, 1, 1).
        pooled = F.avg_pool2d(x, (height, width), stride=(height, width))
        bottleneck = self.lin1(self.flat(pooled))
        if sem_encod is not None:
            bottleneck = bottleneck + sem_encod

        logits = self.lin2(self.relu(bottleneck))
        weights = torch.sigmoid(logits)[:, :, None, None].expand_as(x)
        return x * weights


class ChannelPool(nn.Module):
    """Stack the channel-wise mean and max: (B, C, ...) -> (B, 2, ...), mean first."""

    def forward(self, x):
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        return torch.cat((mean_map, max_map), dim=1)


class SpatialGate(nn.Module):
    """CBAM-style spatial attention.

    The channel-wise mean/max maps are fused by a ``BasicConv`` (2 -> 1 channels,
    kernel_size=7, stride 1, "same" padding, ``act=None``). The sigmoid of the result
    rescales every channel of ``x``.
    """

    def __init__(self):
        super(SpatialGate, self).__init__()
        k = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(
            2, 1, kernel_size=k, stride=1, padding=k // 2, act=None
        )

    def forward(self, x):
        attn_map = torch.sigmoid(self.spatial(self.compress(x)))
        return x * attn_map


class CBAMResidualShared(nn.Module):
    """Fuse a spatial and a semantic feature map with CBAM-style gating.

    Computation:
        1. Gate the spatial map with ``SpatialGate``.
        2. Gate the semantic map with ``ChannelGate``, conditioned on ``sem_encod``.
        3. Fuse the sum and product of the two gated maps with a 1x1 conv.
        4. Add that to a 1x1-conv fusion of the raw inputs.

    Args:
        spatial_dim: channels of ``spatial``.
        semantic_dim: channels of ``semantic`` and of the output. The gated maps are
            added and multiplied element-wise, so in practice ``spatial_dim`` must equal
            ``semantic_dim``.
        semantic_assert: bottleneck width of the channel gate.
    """

    def __init__(self, spatial_dim, semantic_dim, semantic_assert):
        super(CBAMResidualShared, self).__init__()
        self.ChannelGate = ChannelGate(semantic_dim, semantic_assert)
        self.SpatialGate = SpatialGate()
        self.ConvFusion = nn.Conv2d(spatial_dim + semantic_dim, semantic_dim, 1)
        self.ConvAttentFusion = nn.Conv2d(spatial_dim + semantic_dim, semantic_dim, 1)

    def forward(self, spatial, semantic, sem_encod):
        """
        Args:
            spatial: (B, spatial_dim, H, W) map.
            semantic: (B, semantic_dim, H, W) map.
            sem_encod: encoding, presumably (B, semantic_assert, 1), or None.

        Returns:
            (B, semantic_dim, H, W) fused map.
        """
        spatial_ = self.SpatialGate(spatial)
        # Drop the trailing axis so the encoding matches the (B, semantic_assert)
        # bottleneck inside ChannelGate. Pass None through, since ChannelGate treats
        # it as "no encoding" (squeezing None would raise).
        encod = None if sem_encod is None else sem_encod.squeeze(2)
        semantic_ = self.ChannelGate(semantic, encod)

        feature_sum = spatial_ + semantic_
        feature_prod = spatial_ * semantic_

        x_out = self.ConvFusion(torch.cat([spatial, semantic], 1))
        x_att = self.ConvAttentFusion(torch.cat([feature_sum, feature_prod], 1))

        return x_out + x_att


# Scene Aware Activation (SAA) Module
class SAA(nn.Module):
    """Scene Aware Activation: project spatial features to the semantic width, then fuse.

    A 1x1 conv maps the spatial map to ``semantic_dim`` channels. A
    ``CBAMResidualShared`` block then fuses it with the semantic map.

    Args:
        spatial_dim: channels of the spatial feature map.
        semantic_dim: channels of the semantic feature map and of the output.
        semantic_assert: bottleneck width passed to the CBAM channel gate.
    """

    def __init__(self, spatial_dim, semantic_dim, semantic_assert):
        super(SAA, self).__init__()
        self.projection = nn.Conv2d(spatial_dim, semantic_dim, 1)
        self.cbam = CBAMResidualShared(semantic_dim, semantic_dim, semantic_assert)

    def forward(self, spatial_feature, semantic_feature, sem_encod):
        return self.cbam(
            self.projection(spatial_feature), semantic_feature, sem_encod
        )

In [None]:
import torch
import torch.nn as nn
from model.backbone.ResNet import BasicConv


class UPerNet(nn.Module):
    """UPerNet decoder: a pyramid pooling module (PPM) on the deepest input plus an FPN.

    Args:
        num_class: channels of the output map.
        use_softmax: stored on the module but not used by ``forward``.
        pool_scales: output sizes of the adaptive average pools in the PPM.
        fpn_inplanes: channel counts of the input maps, finest first; the last entry
            is the input of the PPM.
        fpn_dim: channel width of the FPN.
    """

    def __init__(
        self,
        num_class=1,
        use_softmax=False,
        pool_scales=(1, 2, 3, 6),
        fpn_inplanes=(64, 576, 1152, 2368, 3584),
        fpn_dim=512,
    ):
        super(UPerNet, self).__init__()
        self.use_softmax = use_softmax
        top_planes = fpn_inplanes[-1]
        ppm_width = 512  # channels produced by each PPM branch

        # PPM: one adaptive pool + 1x1 conv per scale, then a conv fusing all branches.
        self.ppm_pooling = nn.ModuleList(
            [nn.AdaptiveAvgPool2d(scale) for scale in pool_scales]
        )
        self.ppm_conv = nn.ModuleList(
            [
                BasicConv(top_planes, ppm_width, kernel_size=1, padding=0)
                for _ in pool_scales
            ]
        )
        self.ppm_last_conv = BasicConv(
            top_planes + len(pool_scales) * ppm_width, fpn_dim, padding=0
        )

        # FPN: a lateral conv and an output conv for every level below the top one.
        lower_planes = fpn_inplanes[:-1]
        self.fpn_in = nn.ModuleList(
            [BasicConv(planes, fpn_dim, kernel_size=1) for planes in lower_planes]
        )
        self.fpn_out = nn.ModuleList(
            [nn.Sequential(BasicConv(fpn_dim, fpn_dim)) for _ in lower_planes]
        )

        self.conv_last = nn.Sequential(
            BasicConv(len(fpn_inplanes) * fpn_dim, fpn_dim),
            nn.Conv2d(fpn_dim, num_class, kernel_size=1),
        )

    def forward(self, conv_out):
        """
        Args:
            conv_out: list of feature maps, finest first, with channel counts matching
                ``fpn_inplanes``.

        Returns:
            (B, num_class, H, W) map computed at the resolution of the finest FPN level.
        """
        top = conv_out[-1]
        top_hw = (top.size(2), top.size(3))

        # Pyramid pooling on the deepest map. Each pooled branch is upsampled back to
        # the input size *before* its conv.
        branches = [top]
        for pool, conv in zip(self.ppm_pooling, self.ppm_conv):
            upsampled = nn.functional.interpolate(
                pool(top), top_hw, mode="bilinear", align_corners=False
            )
            branches.append(conv(upsampled))
        f = self.ppm_last_conv(torch.cat(branches, 1))

        # Top-down pathway: upsample the running map and add each lateral projection.
        fpn_features = [f]
        for level in range(len(conv_out) - 2, -1, -1):
            lateral = self.fpn_in[level](conv_out[level])
            upsampled = nn.functional.interpolate(
                f, size=lateral.size()[2:], mode="bilinear", align_corners=False
            )
            f = lateral + upsampled
            fpn_features.append(self.fpn_out[level](f))
        fpn_features.reverse()  # finest level first

        target_hw = fpn_features[0].size()[2:]
        fused = [fpn_features[0]] + [
            nn.functional.interpolate(
                feature, target_hw, mode="bilinear", align_corners=False
            )
            for feature in fpn_features[1:]
        ]
        return self.conv_last(torch.cat(fused, 1))

In [None]:
import numpy as np
import pydensecrf.densecrf as dcrf


def _sigmoid(x):
    return 1 / (1 + np.exp(-x))


def crf_refine(img, annos):
    """Refine a soft two-label mask with a fully connected (dense) CRF.

    Steps:
        1. Turn the mask into unary energies for label 0 and label 1.
        2. Scale each energy by ``tau`` times a sigmoid of the normalised mask.
        3. Run one mean-field iteration with a Gaussian smoothness kernel and a
           bilateral (appearance) kernel on ``img``.

    Args:
        img: uint8 image of shape (H, W, 3), used by the bilateral kernel.
        annos: uint8 mask of shape (H, W); higher values favour label 1.

    Returns:
        uint8 array of shape (H, W): the label-1 probability after inference, scaled
        by 255 and truncated (not rounded) to uint8.
    """
    assert img.dtype == np.uint8
    assert annos.dtype == np.uint8
    assert img.shape[:2] == annos.shape

    EPSILON = 1e-8

    M = 2  # number of labels
    tau = 1.05

    # pydensecrf's bilateral term requires a C-contiguous image. Views such as
    # img[:, :, ::-1] would otherwise raise "ndarray is not C-contiguous".
    img = np.ascontiguousarray(img)

    # CRF model setup: DenseCRF2D(width, height, n_labels)
    d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], M)

    anno_norm = annos / 255.0

    n_energy = -np.log(1.0 - anno_norm + EPSILON) / (tau * _sigmoid(1 - anno_norm))
    p_energy = -np.log(anno_norm + EPSILON) / (tau * _sigmoid(anno_norm))

    U = np.zeros((M, img.shape[0] * img.shape[1]), dtype="float32")
    U[0, :] = n_energy.flatten()
    U[1, :] = p_energy.flatten()

    d.setUnaryEnergy(U)

    d.addPairwiseGaussian(sxy=3, compat=3)
    d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5)

    # Inference: a single mean-field iteration
    infer = np.array(d.inference(1)).astype("float32")
    res = infer[1, :] * 255
    return res.reshape(img.shape[:2]).astype("uint8")