Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add backbones #64

Merged
merged 49 commits into from Jun 13, 2020
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
16b5468
Merge pull request #1 from lgvaz/master
oke-aditya Jun 9, 2020
50ee68c
added backbones
oke-aditya Jun 9, 2020
2d9d49b
reformatted with black
oke-aditya Jun 9, 2020
43b6669
fixed typo
oke-aditya Jun 9, 2020
db22928
reformatted with black again
oke-aditya Jun 9, 2020
f597b4a
added fix for backbones
oke-aditya Jun 10, 2020
0f85784
fixed typo, added test_backbones
oke-aditya Jun 10, 2020
c7faf2a
restructured, applied black
oke-aditya Jun 10, 2020
42cd607
fixes import error
oke-aditya Jun 10, 2020
3c72e4c
trying to fix some fastcore error
oke-aditya Jun 10, 2020
aec8aa8
reformat with black
oke-aditya Jun 10, 2020
00aa287
updated for consistency with resnet fpn
oke-aditya Jun 11, 2020
b41111c
added supported for resnet with fpn
oke-aditya Jun 11, 2020
79828ce
applied black
oke-aditya Jun 11, 2020
5dd749c
fixed test
oke-aditya Jun 11, 2020
8070445
fixed pretrained, added pytest for fasterrcnn
oke-aditya Jun 11, 2020
8151494
added changes, removed unnecessary file, fixing tests
oke-aditya Jun 11, 2020
fdf9bae
reformatted with black
oke-aditya Jun 11, 2020
0aece58
added test for dummy image
oke-aditya Jun 11, 2020
317c1f4
removed fixture, fixed import
oke-aditya Jun 11, 2020
da8f480
reformat black
oke-aditya Jun 11, 2020
ccf4026
another endless test fix (import numpy)
oke-aditya Jun 11, 2020
36e7352
resized to a big image
oke-aditya Jun 11, 2020
48333fe
fixed typo
oke-aditya Jun 11, 2020
a0d79c3
checking no fpn right now
oke-aditya Jun 11, 2020
2b30201
avoiding bigger resnets
oke-aditya Jun 11, 2020
1b4f94e
reversed testing order, re-reading images
oke-aditya Jun 11, 2020
a2e3d1c
no idea what error, but lets see
oke-aditya Jun 11, 2020
7a1fcfb
blacked it
oke-aditya Jun 11, 2020
d652af8
Final Check, double surety, I just added comment
oke-aditya Jun 11, 2020
9cf43da
added custom cnn tests and support, splitted tests
oke-aditya Jun 12, 2020
8e875fd
These are same files, dependency error isn't this codes problem
oke-aditya Jun 12, 2020
e46fced
fixed custom cnn, revamped test
oke-aditya Jun 12, 2020
fc0b46c
black applied
oke-aditya Jun 12, 2020
db2435d
added tests, this is final time
oke-aditya Jun 12, 2020
57b4a86
revamped the tests
oke-aditya Jun 12, 2020
522261b
incorporate the changes suggested, restructured the tests
oke-aditya Jun 12, 2020
75b9696
Ooops I had deleted some code :-( readded it
oke-aditya Jun 12, 2020
c40578e
removed if condition as suggested
oke-aditya Jun 12, 2020
9d41d04
made every pretrained in tests to false
oke-aditya Jun 12, 2020
bbcbeac
set the fasterrcnn to false
oke-aditya Jun 12, 2020
e551103
fixed everything now, these models must work
oke-aditya Jun 12, 2020
4b4b461
final try for nonfpn resnets
oke-aditya Jun 12, 2020
112f4ee
attempts test_faster_rcnn_backbones
lgvaz Jun 13, 2020
1bf5d49
tests fpn, removes models that are too big
lgvaz Jun 13, 2020
b0b6e37
run tests in training mode
lgvaz Jun 13, 2020
a39b696
cleans code
lgvaz Jun 13, 2020
9fe0800
Merge branch 'master' into add_backbones
lgvaz Jun 13, 2020
97648a2
cleans code
lgvaz Jun 13, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions mantisshrimp/backbones/__init__.py
@@ -0,0 +1,2 @@
from .torchvision_backbones import *
from torchvision.models.detection.backbone_utils import *
85 changes: 85 additions & 0 deletions mantisshrimp/backbones/torchvision_backbones.py
@@ -0,0 +1,85 @@
# This imports the torchivsion defined backbones
__all__ = ["create_torchvision_backbone"]

from mantisshrimp.imports import *


def create_torchvision_backbone(backbone: str, pretrained: bool):
# These creates models from torchvision directly, it uses imagent pretrained_weights
if backbone == "mobilenet":
mobile_net = torchvision.models.mobilenet_v2(pretrained=pretrained)
# print(mobile_net.features) # From that I got the output channels for mobilenet
ft_backbone = mobile_net.features
ft_backbone.out_channels = 1280
return ft_backbone

elif backbone == "vgg11":
vgg_net = torchvision.models.vgg11(pretrained=pretrained)
ft_backbone = vgg_net.features
ft_backbone.out_channels = 512
return ft_backbone

elif backbone == "vgg13":
vgg_net = torchvision.models.vgg13(pretrained=pretrained)
ft_backbone = vgg_net.features
ft_backbone.out_channels = 512
return ft_backbone

elif backbone == "vgg16":
vgg_net = torchvision.models.vgg16(pretrained=pretrained)
ft_backbone = vgg_net.features
ft_backbone.out_channels = 512
return ft_backbone

elif backbone == "vgg19":
vgg_net = torchvision.models.vgg19(pretrained=pretrained)
ft_backbone = vgg_net.features
ft_backbone.out_channels = 512
return ft_backbone

elif backbone == "resnet18":
resnet_net = torchvision.models.resnet18(pretrained=pretrained)
modules = list(resnet_net.children())[:-1]
ft_backbone = nn.Sequential(*modules)
ft_backbone.out_channels = 512
return ft_backbone

elif backbone == "resnet34":
resnet_net = torchvision.models.resnet34(pretrained=pretrained)
modules = list(resnet_net.children())[:-1]
ft_backbone = nn.Sequential(*modules)
ft_backbone.out_channels = 512
return ft_backbone

elif backbone == "resnet50":
resnet_net = torchvision.models.resnet50(pretrained=pretrained)
modules = list(resnet_net.children())[:-1]
ft_backbone = nn.Sequential(*modules)
ft_backbone.out_channels = 2048
return ft_backbone

elif backbone == "resnet101":
resnet_net = torchvision.models.resnet101(pretrained=pretrained)
modules = list(resnet_net.children())[:-1]
ft_backbone = nn.Sequential(*modules)
ft_backbone.out_channels = 2048
return ft_backbone

elif backbone == "resnet152":
resnet_net = torchvision.models.resnet152(pretrained=pretrained)
modules = list(resnet_net.children())[:-1]
ft_backbone = nn.Sequential(*modules)
ft_backbone.out_channels = 2048
return ft_backbone

elif backbone == "resnext101_32x8d":
resnet_net = torchvision.models.resnext101_32x8d(pretrained=pretrained)
modules = list(resnet_net.children())[:-1]
ft_backbone = nn.Sequential(*modules)
ft_backbone.out_channels = 2048
return ft_backbone
# print(ft_model)

else:
# print("Error Wrong unsupported Backbone")
raise NotImplementedError("No such backbone implemented in mantisshrimp")
120 changes: 115 additions & 5 deletions mantisshrimp/models/mantis_rcnn/mantis_faster_rcnn.py
Expand Up @@ -4,16 +4,126 @@
from mantisshrimp.core import *
from mantisshrimp.models.mantis_rcnn.rcnn_param_groups import *
from mantisshrimp.models.mantis_rcnn.mantis_rcnn import *
from mantisshrimp.backbones import *


class MantisFasterRCNN(MantisRCNN):
"""
Creates a flexible Faster RCNN implementation based on torchvision library.
Args:
n_class : number of classes. Do not have class_id "0" it is reserved as background. n_class = number of classes to label + 1 for background.
backbone: If none creates a default resnet50_fpn model trained on MS COCO 2017
Supported backones are: "resnet18", "resnet34","resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2",
as resnets with fpn backbones.
Without fpn backbones supported are: "resnet18", "resnet34", "resnet50","resnet101", "resnet152", "resnext101_32x8d", "mobilenet", "vgg11", "vgg13", "vgg16", "vgg19",
pretrained: Creates a pretrained backbone with imagenet weights.
fpn: If True it can use one of the fpn supported backbones else it will create Faster RCNN without FPN with fpn unsupported backbones.
metrics: Specific metrics for the model
out_channels: If defining a custom CNN as backbone, pass the output channels of laster layer
"""

@delegates(FasterRCNN.__init__)
def __init__(self, n_class, h=256, pretrained=True, metrics=None, **kwargs):
def __init__(
self,
n_class,
backbone=None,
pretrained=True,
fpn=True,
metrics=None,
out_channels=None,
**kwargs,
):
super().__init__(metrics=metrics)
self.n_class, self.h, self.pretrained = n_class, h, pretrained
self.m = fasterrcnn_resnet50_fpn(pretrained=self.pretrained, **kwargs)
in_features = self.m.roi_heads.box_predictor.cls_score.in_features
self.m.roi_heads.box_predictor = FastRCNNPredictor(in_features, self.n_class)
self.n_class = n_class
self.backbone = backbone
self.pretrained = pretrained
self.fpn = fpn
self.out_channels = out_channels

self.supported_resnet_fpn_models = [
"resnet18",
"resnet34",
"resnet50",
"resnet101",
"resnet152",
"resnext50_32x4d",
"resnext101_32x8d",
"wide_resnet50_2",
"wide_resnet101_2",
]

self.supported_non_fpn_models = [
"resnet18",
"resnet34",
"resnet50",
"resnet101",
"resnet152",
# "resnext50_32x4d",
"resnext101_32x8d",
# "wide_resnet50_2",
# "wide_resnet101_2",
"mobilenet",
"vgg11",
"vgg13",
"vgg16",
"vgg19",
]

if self.backbone is None:
# Creates the default fasterrcnn as given in pytorch. Trained on COCO dataset
self.m = fasterrcnn_resnet50_fpn(
pretrained=self.pretrained,
num_classes=self.n_class,
pretrained_backbone=True,
**kwargs,
)
in_features = self.m.roi_heads.box_predictor.cls_score.in_features
self.m.roi_heads.box_predictor = FastRCNNPredictor(
in_features, self.n_class
)

elif isinstance(self.backbone, str):
# Giving string as a backbone, which is either supported resnet or backbone
if self.fpn is True:
# Creates a torchvision resnet model with fpn added
# Will need to add support for other models with fpn as well
# Passing pretrained True will initiate backbone which was trained on ImageNet
if self.backbone in self.supported_resnet_fpn_models:
# It returns BackboneWithFPN model
backbone = resnet_fpn_backbone(
backbone_name=self.backbone, pretrained=self.pretrained
)
self.m = FasterRCNN(backbone, self.n_class, **kwargs)
else:
raise NotImplementedError(
"FPN for non resnets is not supported yet"
)

else:
# This does not create fpn backbone, it is supported for all models
if self.backbone in self.supported_non_fpn_models:
backbone = create_torchvision_backbone(
backbone=self.backbone, pretrained=self.pretrained
)
self.m = FasterRCNN(
backbone=backbone, num_classes=self.n_class, **kwargs
)
else:
raise NotImplementedError(
"Non FPN for this model is not supported yet"
)

elif isinstance(self.backbone, torch.nn.Module):
# Trying to create the backbone from CNN passed.
try:
modules = list(self.backbone.children())[:-1]
backbone = nn.Sequential(*modules)
backbone.out_channels = self.out_channels
self.m = FasterRCNN(
backbone=backbone, num_classes=self.n_class, **kwargs
)
except Exception:
raise ("Could not parse your CNN as RCNN backbone")

def forward(self, images, targets=None):
return self.m(images, targets)
Expand Down
27 changes: 27 additions & 0 deletions tests/backbones/test_backbones.py
@@ -0,0 +1,27 @@
from mantisshrimp.backbones import *
import pytest
import torch


def test_torchvision_backbones():
supported_backbones = [
oke-aditya marked this conversation as resolved.
Show resolved Hide resolved
"mobilenet",
"vgg11",
"vgg13",
"vgg16",
"vgg19",
"resnet18",
"resnet34",
"resnet50",
"resnet101",
"resnet152",
"resnext101_32x8d",
]
pretrained_status = [True, False]

for backbone in supported_backbones:
for is_pretrained in pretrained_status:
model = create_torchvision_backbone(
backbone=backbone, pretrained=is_pretrained
)
assert isinstance(model, torch.nn.modules.container.Sequential)
38 changes: 38 additions & 0 deletions tests/models/test_fastercnn_custom_cnn.py
@@ -0,0 +1,38 @@
import pytest
import mantisshrimp
from mantisshrimp.models.mantis_rcnn import *
from torchvision.transforms.functional import to_tensor as im2tensor
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import requests
import numpy as np
import cv2
from PIL import Image

# Passing a custom CNN as a Backbone should be supported


def get_image():
# Get a big image because of these big CNNs
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
img = np.array(Image.open(requests.get(url, stream=True).raw))
# Get a big size image for these big resnets
img = cv2.resize(img, (2048, 2048))
tensor_img = im2tensor(img)
tensor_img = torch.unsqueeze(tensor_img, 0)
return tensor_img


# Just pass a resnet18 as if user wrote it


def test_custom_backbone():
backbone = torchvision.models.resnet18(pretrained=False)
model = MantisFasterRCNN(n_class=10, backbone=backbone, out_channels=512)
model.eval()
print("Testing custom backbone")
image = get_image()
pred = model(image)
assert isinstance(pred, list)