In [10]:
# Reference implementation (mmsegmentation v0.21.0, segmenter_mask_head.py):
# https://github.com/open-mmlab/mmsegmentation/blob/v0.21.0/mmseg/models/decode_heads/segmenter_mask_head.py#L15

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import ModuleList

from torch.nn.modules.transformer import TransformerEncoder, TransformerEncoderLayer


class SimpleSegmenterMaskTransformerHead(nn.Module):
    """Simplified Segmenter-style mask transformer decode head.

    Input tokens are projected to ``embed_dims``, concatenated with one
    learnable embedding per class and passed through a stack of
    ``nn.TransformerEncoderLayer`` blocks. Per-token class scores are the dot
    products between the projected token features and the projected class
    embeddings, followed by a ``LayerNorm`` over the class dimension.

    Args:
        in_channels (int): Size of the last dimension of the input tokens.
        num_layers (int): Number of transformer encoder layers. Default: 4.
        num_heads (int): Attention heads per layer; ``embed_dims`` must be
            divisible by it (enforced by PyTorch's attention module).
            Default: 8.
        embed_dims (int): Width of the decoder. Default: 256.
        num_classes (int): Number of learnable class embeddings, i.e. the
            size of the last output dimension. Default: 2.
        mlp_ratio (int | float): Feed-forward width as a multiple of
            ``embed_dims``. Default: 4.
        dropout (float): Dropout probability inside the encoder layers.
            Default: 0.1.
        init_std (float): Standard deviation used by :meth:`init_weights`.
            Default: 0.02.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.

    Raises:
        ValueError: If ``num_classes`` is smaller than 1.
    """

    def __init__(self, in_channels, num_layers=4, num_heads=8, embed_dims=256,
                 num_classes=2, mlp_ratio=4, dropout=0.1, init_std=0.02,
                 **kwargs):
        super().__init__(**kwargs)

        if num_classes < 1:
            # With zero classes the slicing in ``forward`` (``x[:, :-0]``)
            # would silently select nothing, so reject it up front.
            raise ValueError(
                f"num_classes must be >= 1, got {num_classes}")

        self.init_std = init_std
        self.num_classes = num_classes

        # Modules are created in the same order as before (encoder layers
        # first, then the projections) so seeded initialisation is unchanged.
        # ``bias`` is left at PyTorch's default (enabled).
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dims,
                nhead=num_heads,
                dim_feedforward=int(mlp_ratio * embed_dims),
                dropout=dropout,
                activation=F.gelu,
                layer_norm_eps=1e-05,
                batch_first=True,
                norm_first=False,
            )
            for _ in range(num_layers)
        ])

        self.dec_proj = nn.Linear(in_channels, embed_dims)
        self.cls_emb = nn.Parameter(
            torch.randn(1, self.num_classes, embed_dims))
        self.patch_proj = nn.Linear(embed_dims, embed_dims, bias=False)
        self.classes_proj = nn.Linear(embed_dims, embed_dims, bias=False)
        self.decoder_norm = nn.LayerNorm(embed_dims)
        self.mask_norm = nn.LayerNorm(self.num_classes)

    def init_weights(self):
        """Re-initialise the class embeddings and the two projections.

        Draws ``cls_emb``, ``patch_proj.weight`` and ``classes_proj.weight``
        from a truncated normal distribution with standard deviation
        ``self.init_std``. This is opt-in: ``__init__`` does not call it, and
        the encoder layers, ``dec_proj`` and the norms keep PyTorch's default
        initialisation.
        """
        nn.init.trunc_normal_(self.cls_emb, std=self.init_std)
        nn.init.trunc_normal_(self.patch_proj.weight, std=self.init_std)
        nn.init.trunc_normal_(self.classes_proj.weight, std=self.init_std)

    def forward(self, inputs):
        """Compute per-token class scores.

        Args:
            inputs (torch.Tensor): Tokens of shape
                ``(batch, num_tokens, in_channels)``.

        Returns:
            torch.Tensor: Scores of shape ``(batch, num_tokens, num_classes)``.

        Raises:
            ValueError: If ``inputs`` is not 3-dimensional.
        """
        if inputs.dim() != 3:
            raise ValueError(
                "inputs must have shape (batch, num_tokens, in_channels), "
                f"got {tuple(inputs.shape)}")
        batch_size, num_tokens, _ = inputs.shape

        # The input is already token-major, so it goes straight into the
        # projection (the earlier permute/view/permute round trip was a no-op).
        x = self.dec_proj(inputs)

        # Append the class embeddings as extra tokens at the end.
        cls_emb = self.cls_emb.expand(batch_size, -1, -1)
        x = torch.cat((x, cls_emb), dim=1)
        for layer in self.layers:
            x = layer(x)
        x = self.decoder_norm(x)

        # Split the sequence back into data tokens and class tokens.
        patches = self.patch_proj(x[:, :num_tokens])
        cls_seg_feat = self.classes_proj(x[:, num_tokens:])

        # NOTE(review): the referenced mmsegmentation head appears to
        # L2-normalise both operands before this product; this simplified
        # version does not — confirm that is intentional.
        masks = patches @ cls_seg_feat.transpose(1, 2)

        # Result is already (batch, num_tokens, num_classes); no reshape needed.
        return self.mask_norm(masks)


In [21]:
# Build the head with its default hyperparameters for 6-channel input tokens.
s = SimpleSegmenterMaskTransformerHead(in_channels=6)

In [22]:
# Smoke test: a random batch of 2 sequences x 100 tokens x 6 channels;
# the cell displays the shape of the head's output.
inp = torch.randn((2, 100, 6))
s(inp).shape

torch.Size([2, 100, 2])