diff --git a/mantisshrimp/backbones/__init__.py b/mantisshrimp/backbones/__init__.py
new file mode 100644
index 000000000..4240fe5d0
--- /dev/null
+++ b/mantisshrimp/backbones/__init__.py
@@ -0,0 +1,2 @@
+from .torchvision_backbones import *
+from torchvision.models.detection.backbone_utils import *
diff --git a/mantisshrimp/backbones/torchvision_backbones.py b/mantisshrimp/backbones/torchvision_backbones.py
new file mode 100644
index 000000000..4ab5d6dd3
--- /dev/null
+++ b/mantisshrimp/backbones/torchvision_backbones.py
@@ -0,0 +1,57 @@
+# This imports the torchvision defined backbones
+__all__ = ["create_torchvision_backbone"]
+
+from mantisshrimp.imports import *
+
+# torchvision constructor name and output channels of the models whose
+# convolutional feature extractor is exposed as `.features`
+_FEATURES_BACKBONES = {
+    "mobilenet": ("mobilenet_v2", 1280),
+    "vgg11": ("vgg11", 512),
+    "vgg13": ("vgg13", 512),
+    "vgg16": ("vgg16", 512),
+    "vgg19": ("vgg19", 512),
+}
+
+# Output channels of the last residual stage of the resnet-like models
+_RESNET_BACKBONES = {
+    "resnet18": 512,
+    "resnet34": 512,
+    "resnet50": 2048,
+    "resnet101": 2048,
+    "resnet152": 2048,
+    "resnext101_32x8d": 2048,
+}
+
+
+def create_torchvision_backbone(backbone: str, pretrained: bool) -> nn.Module:
+    """Creates a feature extractor from a torchvision classification model.
+
+    Args:
+        backbone (str): Name of the model. Supported: "mobilenet", "vgg11",
+            "vgg13", "vgg16", "vgg19", "resnet18", "resnet34", "resnet50",
+            "resnet101", "resnet152", "resnext101_32x8d".
+        pretrained (bool): If True, loads the imagenet pretrained weights.
+
+    Returns:
+        A `nn.Sequential` without the classification head and with an
+        `out_channels` attribute, as torchvision's `FasterRCNN` requires.
+
+    Raises:
+        ValueError: If `backbone` is not one of the supported names.
+    """
+    if backbone in _FEATURES_BACKBONES:
+        model_name, out_channels = _FEATURES_BACKBONES[backbone]
+        model = getattr(torchvision.models, model_name)(pretrained=pretrained)
+        ft_backbone = model.features
+    elif backbone in _RESNET_BACKBONES:
+        out_channels = _RESNET_BACKBONES[backbone]
+        model = getattr(torchvision.models, backbone)(pretrained=pretrained)
+        # Drop both `avgpool` and `fc`: keeping `avgpool` would collapse the
+        # feature map to 1x1, leaving no spatial information for detection
+        ft_backbone = nn.Sequential(*list(model.children())[:-2])
+    else:
+        raise ValueError("No such backbone implemented in mantisshrimp")
+
+    ft_backbone.out_channels = out_channels
+    return ft_backbone
diff --git a/mantisshrimp/models/mantis_rcnn/mantis_faster_rcnn.py b/mantisshrimp/models/mantis_rcnn/mantis_faster_rcnn.py
index 0067eefb4..e2aa458d6 100644
--- a/mantisshrimp/models/mantis_rcnn/mantis_faster_rcnn.py
+++ b/mantisshrimp/models/mantis_rcnn/mantis_faster_rcnn.py
@@ -4,16 +4,71 @@
 from mantisshrimp.core import *
 from mantisshrimp.models.mantis_rcnn.rcnn_param_groups import *
 from mantisshrimp.models.mantis_rcnn.mantis_rcnn import *
+from mantisshrimp.backbones import *
 
 
 class MantisFasterRCNN(MantisRCNN):
+    """
+    Creates a flexible Faster RCNN implementation based on torchvision library.
+    Args:
+        n_class (int): number of classes, including the background.
+            class_id "0" is reserved for the background, so
+            n_class = number of classes to label + 1.
+        backbone (nn.Module): backbone used to extract the features, it must have
+            an `out_channels` attribute. If None, creates the default torchvision
+            resnet50_fpn Faster RCNN.
+        metrics: metrics forwarded to `MantisRCNN`.
+        pretrained (bool): only used when `backbone` is None. If True, the default
+            model starts from weights trained on MS COCO 2017.
+        **kwargs: extra arguments forwarded to torchvision `FasterRCNN`.
+    """
+
     @delegates(FasterRCNN.__init__)
-    def __init__(self, n_class, h=256, pretrained=True, metrics=None, **kwargs):
+    def __init__(
+        self,
+        n_class: int,
+        backbone: nn.Module = None,
+        metrics=None,
+        pretrained: bool = True,
+        **kwargs,
+    ):
         super().__init__(metrics=metrics)
-        self.n_class, self.h, self.pretrained = n_class, h, pretrained
-        self.m = fasterrcnn_resnet50_fpn(pretrained=self.pretrained, **kwargs)
-        in_features = self.m.roi_heads.box_predictor.cls_score.in_features
-        self.m.roi_heads.box_predictor = FastRCNNPredictor(in_features, self.n_class)
+        self.n_class, self.pretrained = n_class, pretrained
+        self.backbone = backbone
+        if backbone is None:
+            # Creates the default fasterrcnn as given in pytorch. The COCO weights
+            # only fit the original 91 classes head, so the model is created with
+            # it and the box predictor is replaced to output `n_class` classes
+            self.m = fasterrcnn_resnet50_fpn(pretrained=pretrained, **kwargs)
+            in_features = self.m.roi_heads.box_predictor.cls_score.in_features
+            self.m.roi_heads.box_predictor = FastRCNNPredictor(in_features, n_class)
+        else:
+            self.m = FasterRCNN(backbone, num_classes=n_class, **kwargs)
+
+    @staticmethod
+    def get_backbone_by_name(
+        name: str, fpn: bool = True, pretrained: bool = True
+    ) -> nn.Module:
+        """
+        Creates a backbone that can be passed to `MantisFasterRCNN`.
+        Args:
+            name (str): name of the backbone.
+                Supported with fpn: "resnet18", "resnet34", "resnet50",
+                "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d",
+                "wide_resnet50_2", "wide_resnet101_2".
+                Supported without fpn: "resnet18", "resnet34", "resnet50",
+                "resnet101", "resnet152", "resnext101_32x8d", "mobilenet",
+                "vgg11", "vgg13", "vgg16", "vgg19".
+            fpn (bool): If True, creates a resnet backbone with a feature pyramid
+                network on top, otherwise creates a plain torchvision backbone.
+            pretrained (bool): Creates a pretrained backbone with imagenet weights.
+        """
+        if fpn:
+            # Creates a torchvision resnet model with fpn added
+            # It returns BackboneWithFPN model
+            return resnet_fpn_backbone(name, pretrained=pretrained)
+        # This does not create fpn backbone, it is supported for all models
+        return create_torchvision_backbone(name, pretrained=pretrained)
 
     def forward(self, images, targets=None):
         return self.m(images, targets)
@@ -26,7 +81,7 @@ def build_training_sample(
         imageid: int, img: np.ndarray, label: List[int], bbox: List[BBox], **kwargs,
     ):
         x = im2tensor(img)
-        # inject values when annotations are empty are disconsidered
+        # injected values when annotations are empty are ignored
         # because we mark label as 0 (background)
         _fake_box = [0, 1, 2, 3]
         y = {
diff --git a/tests/backbones/test_backbones.py b/tests/backbones/test_backbones.py
new file mode 100644
index 000000000..3721baf31
--- /dev/null
+++ b/tests/backbones/test_backbones.py
@@ -0,0 +1,39 @@
+from mantisshrimp.backbones import *
+import pytest
+import torch
+
+
+@pytest.mark.parametrize("pretrained", [True, False])
+@pytest.mark.parametrize(
+    "backbone, out_channels",
+    [
+        ("mobilenet", 1280),
+        ("vgg11", 512),
+        ("vgg13", 512),
+        ("vgg16", 512),
+        ("vgg19", 512),
+        ("resnet18", 512),
+        ("resnet34", 512),
+        ("resnet50", 2048),
+        ("resnet101", 2048),
+        ("resnet152", 2048),
+        ("resnext101_32x8d", 2048),
+    ],
+)
+def test_torchvision_backbones(backbone, out_channels, pretrained):
+    model = create_torchvision_backbone(backbone=backbone, pretrained=pretrained)
+    assert isinstance(model, torch.nn.Sequential)
+    assert model.out_channels == out_channels
+
+
+def test_resnet_backbone_keeps_spatial_features():
+    model = create_torchvision_backbone(backbone="resnet18", pretrained=False)
+    model.eval()
+    with torch.no_grad():
+        features = model(torch.zeros(1, 3, 64, 64))
+    assert features.shape == (1, 512, 2, 2)
+
+
+def test_unsupported_backbone_raises():
+    with pytest.raises(ValueError):
+        create_torchvision_backbone(backbone="not_a_backbone", pretrained=False)
diff --git a/tests/conftest.py b/tests/conftest.py
index 6587e4cc6..399e61b12 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,5 +1,6 @@
-import pytest
+import pytest, requests, PIL.Image
 from mantisshrimp import *
+from mantisshrimp.imports import *
 
 
 @pytest.fixture(scope="module")
@@ -16,3 +17,15 @@ def record(records):
 @pytest.fixture(scope="module")
 def data_sample(record):
     return default_prepare_record(record)
+
+
+@pytest.fixture()
+def image():
+    # Get a big image because of these big CNNs
+    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+    img = np.array(PIL.Image.open(requests.get(url, stream=True, timeout=10).raw))
+    # Resize it so it is big enough for these big resnets
+    img = cv2.resize(img, (2048, 2048))
+    tensor_img = im2tensor(img)
+    tensor_img = torch.unsqueeze(tensor_img, 0)
+    return tensor_img
diff --git a/tests/models/test_mantis_fastercnn.py b/tests/models/test_mantis_fastercnn.py
new file mode 100644
index 000000000..4cbdf442d
--- /dev/null
+++ b/tests/models/test_mantis_fastercnn.py
@@ -0,0 +1,51 @@
+import pytest, torch
+from mantisshrimp import *
+
+
+@pytest.fixture(scope="session")
+def batch():
+    dataset = test_utils.sample_dataset()
+    dataloader = MantisFasterRCNN.dataloader(dataset, batch_size=2)
+    xb, yb = next(iter(dataloader))
+    return xb, list(yb)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("pretrained", [False, True])
+@pytest.mark.parametrize(
+    "backbone, fpn",
+    [
+        (None, True),
+        ("mobilenet", False),
+        ("vgg11", False),
+        ("vgg13", False),
+        ("vgg16", False),
+        ("vgg19", False),
+        ("resnet18", False),
+        ("resnet34", False),
+        ("resnet50", False),
+        ("resnet18", True),
+        ("resnet34", True),
+        ("resnet50", True),
+        # these models are too big for github runners
+        # "resnet101",
+        # "resnet152",
+        # "resnext101_32x8d",
+    ],
+)
+def test_faster_rcnn_backbones(batch, backbone, fpn, pretrained):
+    if backbone is None:
+        model = MantisFasterRCNN(n_class=91, pretrained=pretrained)
+    else:
+        backbone = MantisFasterRCNN.get_backbone_by_name(
+            name=backbone, fpn=fpn, pretrained=pretrained
+        )
+        model = MantisFasterRCNN(n_class=91, backbone=backbone)
+    with torch.no_grad():
+        preds = model.forward(*batch)
+    assert set(preds.keys()) == {
+        "loss_classifier",
+        "loss_box_reg",
+        "loss_objectness",
+        "loss_rpn_box_reg",
+    }