Skip to content

Commit

Permalink
Add backbones (#64)
Browse files Browse the repository at this point in the history
Co-authored-by: lgvaz <lucasgouvaz@gmail.com>
  • Loading branch information
oke-aditya and lgvaz committed Jun 13, 2020
1 parent 864ffe3 commit 48f6f5d
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 7 deletions.
2 changes: 2 additions & 0 deletions mantisshrimp/backbones/__init__.py
@@ -0,0 +1,2 @@
from .torchvision_backbones import *
from torchvision.models.detection.backbone_utils import *
82 changes: 82 additions & 0 deletions mantisshrimp/backbones/torchvision_backbones.py
@@ -0,0 +1,82 @@
# Builds backbones from the models defined in torchvision
__all__ = ["create_torchvision_backbone"]

from mantisshrimp.imports import *


# Supported backbone name -> (torchvision.models factory name, output channels,
# True when the feature extractor is exposed as the model's ``features`` attribute).
_TORCHVISION_BACKBONES = {
    "mobilenet": ("mobilenet_v2", 1280, True),
    "vgg11": ("vgg11", 512, True),
    "vgg13": ("vgg13", 512, True),
    "vgg16": ("vgg16", 512, True),
    "vgg19": ("vgg19", 512, True),
    "resnet18": ("resnet18", 512, False),
    "resnet34": ("resnet34", 512, False),
    "resnet50": ("resnet50", 2048, False),
    "resnet101": ("resnet101", 2048, False),
    "resnet152": ("resnet152", 2048, False),
    "resnext101_32x8d": ("resnext101_32x8d", 2048, False),
}


def create_torchvision_backbone(backbone: str, pretrained: bool):
    """Create a feature-extractor backbone from a torchvision classification model.

    Args:
        backbone: Name of the backbone. Supported names are "mobilenet",
            "vgg11", "vgg13", "vgg16", "vgg19", "resnet18", "resnet34",
            "resnet50", "resnet101", "resnet152" and "resnext101_32x8d".
        pretrained: Forwarded to the torchvision model constructor as
            ``pretrained`` (imagenet pretrained weights).

    Returns:
        The feature-extraction part of the model, with an ``out_channels``
        attribute set to the number of channels of its output.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """
    # Validate the name first so an unsupported value fails without building anything.
    spec = _TORCHVISION_BACKBONES.get(backbone) if isinstance(backbone, str) else None
    if spec is None:
        raise ValueError("No such backbone implemented in mantisshrimp")

    factory_name, out_channels, has_features = spec
    model = getattr(torchvision.models, factory_name)(pretrained=pretrained)

    if has_features:
        # mobilenet / vgg already expose their convolutional stack as ``features``.
        ft_backbone = model.features
    else:
        # resnet / resnext: drop the last two children (global average pool and
        # the fc classifier). Keeping the pooling layer would collapse the
        # feature map to 1x1, which is useless as a detection backbone.
        ft_backbone = nn.Sequential(*list(model.children())[:-2])

    # Attribute that carries the backbone's channel count alongside the module.
    ft_backbone.out_channels = out_channels
    return ft_backbone
51 changes: 45 additions & 6 deletions mantisshrimp/models/mantis_rcnn/mantis_faster_rcnn.py
Expand Up @@ -4,16 +4,55 @@
from mantisshrimp.core import *
from mantisshrimp.models.mantis_rcnn.rcnn_param_groups import *
from mantisshrimp.models.mantis_rcnn.mantis_rcnn import *
from mantisshrimp.backbones import *


class MantisFasterRCNN(MantisRCNN):
"""
Creates a flexible Faster RCNN implementation based on torchvision library.
Args:
n_class (int): number of classes, including the background. Class id 0 is reserved for the background, so n_class = number of classes to label + 1.
"""

@delegates(FasterRCNN.__init__)
def __init__(
    self, n_class: int, backbone: nn.Module = None, metrics=None, **kwargs,
):
    """Build the wrapped Faster RCNN model.

    Args:
        n_class (int): Number of classes passed to the detector as ``num_classes``.
        backbone (nn.Module): Backbone handed to ``FasterRCNN``. When ``None``,
            ``fasterrcnn_resnet50_fpn`` is used instead.
        metrics: Forwarded to the parent class constructor.
        **kwargs: Extra keyword arguments forwarded to the model constructor.
    """
    super().__init__(metrics=metrics)
    self.n_class = n_class
    self.backbone = backbone
    if backbone is None:
        # Default model: torchvision's resnet50-fpn Faster RCNN. It is built
        # with pretrained=False, so no pretrained detector weights are requested.
        self.m = fasterrcnn_resnet50_fpn(
            pretrained=False, num_classes=n_class, **kwargs,
        )
        # NOTE(review): num_classes is already passed above, so replacing the
        # box predictor looks redundant -- confirm before removing.
        in_features = self.m.roi_heads.box_predictor.cls_score.in_features
        self.m.roi_heads.box_predictor = FastRCNNPredictor(in_features, n_class)
    else:
        # Custom backbone: build a generic FasterRCNN around it.
        self.m = FasterRCNN(backbone, num_classes=n_class, **kwargs)

@staticmethod
def get_backbone_by_name(
    name: str, fpn: bool = True, pretrained: bool = True
) -> nn.Module:
    """Create a backbone module from its name.

    Args:
        name (str): Name of the backbone to create.
            Supported with fpn: "resnet18", "resnet34", "resnet50", "resnet101",
            "resnet152", "resnext50_32x4d", "resnext101_32x8d", "wide_resnet50_2",
            "wide_resnet101_2".
            Supported without fpn: "resnet18", "resnet34", "resnet50", "resnet101",
            "resnet152", "resnext101_32x8d", "mobilenet", "vgg11", "vgg13", "vgg16",
            "vgg19".
        fpn (bool): If True the backbone is created with ``resnet_fpn_backbone``,
            otherwise with ``create_torchvision_backbone``.
        pretrained (bool): Creates a pretrained backbone with imagenet weights.

    Returns:
        nn.Module: The created backbone.
    """
    if not fpn:
        # Plain backbone without fpn; this path is supported for all models.
        return create_torchvision_backbone(name, pretrained=pretrained)

    # torchvision resnet model with fpn added; it returns a BackboneWithFPN model.
    return resnet_fpn_backbone(name, pretrained=pretrained)

def forward(self, images, targets=None):
    """Call the wrapped model ``self.m`` with ``images`` and ``targets`` and return its output."""
    outputs = self.m(images, targets)
    return outputs
Expand All @@ -26,7 +65,7 @@ def build_training_sample(
imageid: int, img: np.ndarray, label: List[int], bbox: List[BBox], **kwargs,
):
x = im2tensor(img)
# inject values when annotations are empty are disconsidered
# injected values when annotations are empty are disconsidered
# because we mark label as 0 (background)
_fake_box = [0, 1, 2, 3]
y = {
Expand Down
27 changes: 27 additions & 0 deletions tests/backbones/test_backbones.py
@@ -0,0 +1,27 @@
from mantisshrimp.backbones import *
import pytest
import torch


def test_torchvision_backbones():
    """Every supported backbone, pretrained or not, is created as an nn.Sequential."""
    backbone_names = (
        "mobilenet",
        "vgg11",
        "vgg13",
        "vgg16",
        "vgg19",
        "resnet18",
        "resnet34",
        "resnet50",
        "resnet101",
        "resnet152",
        "resnext101_32x8d",
    )
    # Same order as nested loops: each backbone first pretrained, then not.
    cases = [(name, flag) for name in backbone_names for flag in (True, False)]

    for name, flag in cases:
        created = create_torchvision_backbone(backbone=name, pretrained=flag)
        assert isinstance(created, torch.nn.modules.container.Sequential)
15 changes: 14 additions & 1 deletion tests/conftest.py
@@ -1,5 +1,6 @@
import pytest
import pytest, requests, PIL
from mantisshrimp import *
from mantisshrimp.imports import *


@pytest.fixture(scope="module")
Expand All @@ -16,3 +17,15 @@ def record(records):
@pytest.fixture(scope="module")
def data_sample(record):
    """Module-scoped fixture: the ``record`` fixture passed through ``default_prepare_record``."""
    prepared = default_prepare_record(record)
    return prepared


@pytest.fixture()
def image():
    """Download a sample COCO image and return it as a batched tensor.

    The image is resized to 2048x2048, converted with ``im2tensor`` and given
    a leading dimension of size 1 with ``torch.unsqueeze``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not answer in time.
    """
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # A timeout keeps the test from hanging forever on a stalled connection.
    response = requests.get(url, stream=True, timeout=30)
    # Fail here with a clear HTTP error instead of letting PIL choke on an error page.
    response.raise_for_status()
    img = np.array(PIL.Image.open(response.raw))
    # Get a big size image for these big resnets
    img = cv2.resize(img, (2048, 2048))
    tensor_img = im2tensor(img)
    tensor_img = torch.unsqueeze(tensor_img, 0)
    return tensor_img
46 changes: 46 additions & 0 deletions tests/models/test_mantis_fastercnn.py
@@ -0,0 +1,46 @@
import pytest, torch
from mantisshrimp import *


@pytest.fixture(scope="session")
def batch():
    """Session-scoped fixture: the first batch of the sample dataset as ``(xb, list(yb))``."""
    loader = MantisFasterRCNN.dataloader(test_utils.sample_dataset(), batch_size=2)
    first_batch = next(iter(loader))
    xb, yb = first_batch
    return xb, list(yb)


@pytest.mark.slow
@pytest.mark.parametrize("pretrained", [False, True])
@pytest.mark.parametrize(
    "backbone, fpn",
    [
        (None, True),
        ("mobilenet", False),
        ("vgg11", False),
        ("vgg13", False),
        ("vgg16", False),
        ("vgg19", False),
        ("resnet18", False),
        ("resnet34", False),
        ("resnet50", False),
        ("resnet18", True),
        ("resnet34", True),
        ("resnet50", True),
        # resnet101, resnet152 and resnext101_32x8d are left out:
        # these models are too big for github runners
    ],
)
def test_faster_rcnn_nonfpn_backbones(batch, backbone, fpn, pretrained):
    """The model's forward pass on a sample batch returns the four loss entries."""
    # A backbone of None exercises the model's default; otherwise build it by name.
    built_backbone = (
        None
        if backbone is None
        else MantisFasterRCNN.get_backbone_by_name(
            name=backbone, fpn=fpn, pretrained=pretrained
        )
    )
    model = MantisFasterRCNN(n_class=91, backbone=built_backbone)

    with torch.no_grad():
        preds = model.forward(*batch)

    expected_keys = {
        "loss_classifier",
        "loss_box_reg",
        "loss_objectness",
        "loss_rpn_box_reg",
    }
    assert set(preds.keys()) == expected_keys

0 comments on commit 48f6f5d

Please sign in to comment.