In [None]:
import torch as t
import torch.nn as nn
import torchvision.models as models
from torchvision.ops import FeaturePyramidNetwork
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
from torchvision.ops.feature_pyramid_network import LastLevelMaxPool
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection import MaskRCNN
from torchvision.ops import MultiScaleRoIAlign


class MobileNetV3_FPN(nn.Module):
    """MobileNetV3-Small feature extractor followed by a feature pyramid network.

    The ``features`` stack of an ImageNet-pretrained ``mobilenet_v3_small`` is
    tapped at blocks ``0``, ``1``, ``2`` and ``9``. Parameters of every other
    block are frozen. The tapped activations are renamed ``'0'``..``'3'`` and
    fed to a ``FeaturePyramidNetwork`` with ``out_channels`` (256) channels per
    level and a ``LastLevelMaxPool`` extra block.

    Attributes:
        out_channels: Channel count of every pyramid level (256).
        extractor: Feature extractor returning the tapped activations.
        fpn: Feature pyramid built on top of the extractor outputs.
    """

    def __init__(self):
        super().__init__()
        backbone = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1).features
        # Indices (as strings) of the backbone blocks whose outputs feed the FPN.
        layers = ['0', '1', '2', '9']

        # Only the tapped blocks stay trainable.
        self._freeze_unselected(backbone, layers)

        in_channels_list = [backbone[int(i)].out_channels for i in layers]
        # Rename the tapped nodes to consecutive keys '0', '1', '2', '3'.
        return_nodes = {layer: f'{index}' for index, layer in enumerate(layers)}

        self.out_channels = 256
        self.extractor = create_feature_extractor(backbone, return_nodes=return_nodes)
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=self.out_channels,
            extra_blocks=LastLevelMaxPool(),
        )

    @staticmethod
    def _freeze_unselected(backbone, trainable_layers):
        """Disable gradients for parameters outside ``trainable_layers``.

        A parameter belongs to a layer when the first dot-separated component
        of its name equals the layer name exactly. A plain prefix test would
        wrongly treat ``'10.*'``, ``'11.*'`` and ``'12.*'`` as part of ``'1'``.

        Args:
            backbone: Module whose ``named_parameters()`` are inspected.
            trainable_layers: Iterable of top-level child names to keep trainable.
        """
        keep = set(trainable_layers)
        for name, parameter in backbone.named_parameters():
            if name.split('.', 1)[0] not in keep:
                parameter.requires_grad_(False)

    def forward(self, x):
        """Run the extractor and then the FPN on ``x``.

        Args:
            x: Input passed directly to the feature extractor.

        Returns:
            The output of ``self.fpn`` applied to the extractor's feature maps.
        """
        x = self.extractor(x)
        x = self.fpn(x)
        return x


class MaskRCNN_MobileNetV3_FPN(nn.Module):
    """Two-class Mask R-CNN built on the ``MobileNetV3_FPN`` backbone.

    Wraps a torchvision ``MaskRCNN`` configured with a custom anchor generator
    and RoI poolers, exposed as ``self.maskrcnn``.
    """

    def __init__(self):
        super().__init__()
        backbone = MobileNetV3_FPN()

        # Names of the backbone outputs used for RoI pooling.
        feature_maps = ['0', '1', '2', '3']
        # One anchor-size tuple per feature map handed to the RPN: the four
        # named maps plus one more (the backbone's LastLevelMaxPool level).
        anchor_sizes = ((32, 64, 128, 256),) * (len(feature_maps) + 1)
        aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)

        anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)

        # NOTE(review): MultiScaleRoIAlign infers per-level scales from the
        # feature-map sizes; the backbone taps blocks 0, 1, 2 and 9, whose
        # strides are presumably not consecutive powers of two -- confirm that
        # large RoIs are still assigned to an existing level.
        box_roi_pool = MultiScaleRoIAlign(featmap_names=feature_maps, output_size=7, sampling_ratio=2)
        mask_roi_pooler = MultiScaleRoIAlign(featmap_names=feature_maps, output_size=14, sampling_ratio=2)

        # `progress` is an argument of torchvision's pretrained-model builder
        # functions, not of the MaskRCNN constructor, so it is not passed here.
        # NOTE(review): max_size=640 is below torchvision's default min_size;
        # presumably the longer-side cap is what determines the resize --
        # confirm this is intended.
        self.maskrcnn = MaskRCNN(
            backbone=backbone,
            max_size=640,
            num_classes=2,
            rpn_anchor_generator=anchor_generator,
            box_roi_pool=box_roi_pool,
            mask_roi_pool=mask_roi_pooler,
            box_detections_per_img=32,
        )

    def forward(self, x, targets=None):
        """Delegate to the wrapped ``MaskRCNN``.

        Args:
            x: Images, forwarded unchanged to ``self.maskrcnn``.
            targets: Optional ground-truth annotations, forwarded unchanged.
                torchvision detection models require them in training mode;
                leave as ``None`` for inference.

        Returns:
            Whatever ``self.maskrcnn`` returns for the given inputs.
        """
        return self.maskrcnn(x, targets)