Add backbones #64
Codecov Report
@@ Coverage Diff @@
## master #64 +/- ##
==========================================
+ Coverage 53.51% 56.00% +2.49%
==========================================
Files 53 55 +2
Lines 1637 1714 +77
==========================================
+ Hits 876 960 +84
+ Misses 761 754 -7
Thanks for the PR!
There are some suggestions I would like to make before merging.
Currently, how can a user specify a custom backbone that is not defined in create_torchvision_backbone? The way it is implemented right now, they would need to change the source code.
What do you think of having something like
backbone = get_torchvision_backbone('mobile_net')
model = MantisFasterRCNN(backbone=backbone, ...)
Also, keep in mind that we want to support PyTorch Hub in the future. How can we write something that integrates easily with it?
Another tip: we can mention the issue here so it gets closed automatically when this merges =) Closes #61
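One way to let users plug in backbones that create_torchvision_backbone does not know about, without editing the library source, is a small name-to-factory registry; a future PyTorch Hub loader could register its models the same way. This is a minimal sketch under assumptions: BACKBONE_REGISTRY, register_backbone and get_backbone are hypothetical names, and the factories return plain dicts as stand-ins for nn.Module objects so it runs without torch.

```python
from typing import Callable, Dict

# Hypothetical registry: backbone name -> factory that builds the backbone.
BACKBONE_REGISTRY: Dict[str, Callable[..., object]] = {}


def register_backbone(name: str):
    """Decorator that makes `factory` available under `name`."""
    def decorator(factory: Callable[..., object]) -> Callable[..., object]:
        if name in BACKBONE_REGISTRY:
            raise ValueError(f"backbone {name!r} is already registered")
        BACKBONE_REGISTRY[name] = factory
        return factory
    return decorator


def get_backbone(name: str, **kwargs) -> object:
    """Build a registered backbone; unknown names list the available options."""
    try:
        factory = BACKBONE_REGISTRY[name]
    except KeyError:
        raise ValueError(
            f"unknown backbone {name!r}; available: {sorted(BACKBONE_REGISTRY)}"
        ) from None
    return factory(**kwargs)


# Built-in entry (a real version would call create_torchvision_backbone here).
@register_backbone("mobilenet")
def _mobilenet(pretrained: bool = True) -> object:
    return {"arch": "mobilenet", "pretrained": pretrained}


# A user (or a hub loader) adds a backbone without touching the library code.
@register_backbone("my_tiny_cnn")
def _my_tiny_cnn(pretrained: bool = False) -> object:
    return {"arch": "my_tiny_cnn", "pretrained": pretrained}


backbone = get_backbone("my_tiny_cnn")
```

The returned object would then be handed to the model, as in the `MantisFasterRCNN(backbone=backbone, ...)` suggestion above.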
For PyTorch Hub, we could simply take the CNN features from there. Maybe create a file in the backbone folder as … For user-specified backbones, I'm unsure: users would have to write their own CNN architecture and get features from it, and those would need to be passed to our code.
I have made changes to address the other points, but we still need to think about how to accommodate those two cases as well.
Also, I just checked the Faster R-CNN research paper. Originally, the backbones were pretrained on ImageNet and the other layers were fine-tuned on COCO. So the torchvision backbones with pretrained ImageNet weights should be fine.
As the paper suggests, anchors are not optimized in Faster R-CNN, so they should be chosen according to the needs of the dataset. I guess users are aware of this.
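For reference, the paper's default anchors combine 3 scales (box areas of 128², 256² and 512² pixels) with 3 aspect ratios (1:2, 1:1, 2:1), giving 9 anchors per feature-map location. The pure-Python sketch below shows how such (width, height) pairs are derived: each anchor keeps an area of scale² and follows the aspect_ratio = height / width convention used by torchvision's AnchorGenerator. The helper anchor_shapes and the small-object scales are hypothetical; with torchvision, the chosen values would go into AnchorGenerator(sizes=..., aspect_ratios=...) and be passed to FasterRCNN through its rpn_anchor_generator argument.

```python
import math


def anchor_shapes(scales, aspect_ratios):
    """Return (width, height) for every scale/aspect-ratio combination.

    Each anchor keeps the area scale**2; its shape follows
    aspect_ratio = height / width.
    """
    shapes = []
    for scale in scales:
        for ratio in aspect_ratios:
            h_ratio = math.sqrt(ratio)
            w_ratio = 1.0 / h_ratio
            shapes.append((scale * w_ratio, scale * h_ratio))
    return shapes


# Paper defaults: 3 scales x 3 aspect ratios = 9 anchors per location.
paper_anchors = anchor_shapes(scales=(128, 256, 512), aspect_ratios=(0.5, 1.0, 2.0))

# A dataset of small objects might need smaller scales (hypothetical values).
small_object_anchors = anchor_shapes(scales=(16, 32, 64), aspect_ratios=(0.5, 1.0, 2.0))

for w, h in paper_anchors[:3]:
    print(f"{w:7.1f} x {h:7.1f}")
```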
For tests, we use the pytest framework. It's really simple: you just need to write test functions; take a look at … So basically, just create a file in the tests directory that mirrors the folder structure of the implemented files, and start by testing very simple stuff. For the backbones, I suggest that you test the model output shapes: create a dummy input (it can be a tensor full of zeros), feed it to the network, and check that the shape is what you expect. For …
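To know which shape to assert, the spatial size can be worked out from PyTorch's Conv2d/MaxPool2d output-size formula, out = floor((in + 2·padding − dilation·(kernel − 1) − 1) / stride) + 1. The sketch below applies it to the downsampling layers of a torchvision ResNet. It assumes the backbone under test returns the layer4 feature map (ResNet without avgpool and fc); that is an assumption about create_torchvision_backbone, and expected_backbone_shape is a hypothetical helper.

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Spatial output size of Conv2d/MaxPool2d (PyTorch formula, floor mode)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def resnet_c5_size(size):
    """Height or width of a torchvision ResNet's layer4 output."""
    size = conv_out(size, kernel=7, stride=2, padding=3)  # conv1
    size = conv_out(size, kernel=3, stride=2, padding=1)  # maxpool
    # layer1 keeps the resolution; layer2-4 each downsample with a stride-2 3x3 conv.
    for _ in range(3):
        size = conv_out(size, kernel=3, stride=2, padding=1)
    return size


# Number of channels produced by layer4 in torchvision ResNets.
RESNET_OUT_CHANNELS = {"resnet18": 512, "resnet34": 512, "resnet50": 2048}


def expected_backbone_shape(name, batch, height, width):
    """Expected (N, C, H, W) of the backbone output for an input of this size."""
    return (batch, RESNET_OUT_CHANNELS[name], resnet_c5_size(height), resnet_c5_size(width))


print(expected_backbone_shape("resnet18", 1, 224, 224))
```

In the pytest test, `tuple(backbone(torch.zeros(1, 3, 224, 224)).shape)` would then be compared against `expected_backbone_shape("resnet18", 1, 224, 224)`, assuming the backbone returns a single tensor.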
Do not create a new PR with the tests; instead, keep adding to this one. When everything is done, I'll merge it =)
Backbone tests are done. I will add the Faster R-CNN tests.
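This is kind of what I had in mind:
from torchvision.models.detection import FasterRCNN, fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# MantisRCNN (project base class) and delegates (fastcore) are imported elsewhere.
class MantisFasterRCNN(MantisRCNN):
    @delegates(FasterRCNN.__init__)
    def __init__(self, n_class, backbone=None, metrics=None, **kwargs):
        super().__init__(metrics=metrics)
        self.n_class = n_class
        if backbone is None:
            # Creates the default Faster R-CNN as given in torchvision, trained on COCO
            self.m = fasterrcnn_resnet50_fpn(pretrained=True, **kwargs)
            # Replace the COCO box predictor with one sized for n_class
            in_features = self.m.roi_heads.box_predictor.cls_score.in_features
            self.m.roi_heads.box_predictor = FastRCNNPredictor(in_features, n_class)
        else:
            # A user-supplied backbone must expose an out_channels attribute
            self.m = FasterRCNN(backbone=backbone, num_classes=n_class, **kwargs)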
Actually, looking at torchvision's source code, I see they already have a function for getting backbones, resnet_fpn_backbone, and it also adds an FPN on top, which is nice. Are there any disadvantages of using this instead of …? You can take a look here for a guide on how to set it up correctly =)
It supports all ResNet models. But suppose we want to extend to other models such as VGG16 (used in the original paper) and MobileNet, or, in the future, models from PyTorch Hub. Then we would need a generalized function, hence this approach. It is provided in the same util you linked. I guess we need to add a create_fpn util that creates an FPN for any backbone. For now, let us offer ResNets with or without FPN, and other networks (or users' own CNNs) without FPN. Maybe the blocks below will help us extend this further.
As suggested, we could also pass the backbones directly, but then they would be created without an FPN. I guess we can add an extra argument so that users of ResNets can specify whether they need an FPN or not.
This code, taken from the same util, might help us add an FPN to any backbone:
I just checked the FPN paper. It was published in 2017, while Faster R-CNN was published in 2015, which is why torchvision has these as two separate features: FPN wasn't part of the original paper. From the FPN paper, FPN is clearly an optional addition: Faster R-CNN can work without it. So let us give users the choice.
I have made the tests very granular now: non-ResNets without FPN, ResNets without FPN, and ResNets with FPN. This is the most thorough testing I can manage.
Alright! Initially I was a bit hesitant to make the following suggestions, because what we have here is already very reasonable, but then I thought "this should be a learning experience, so it's worth spending some of my time showing how we can implement this in an even better and clearer way". So these are my thoughts:
Whenever you start seeing multiple indented blocks, it's a bad sign. Always try to stay away from if statements. Of course, an if here and there is okay, but when they start to get nested, chaos starts to emerge (you can find multiple articles online explaining why that is bad).
The good news is that we can always refactor them, and this is how to do it:
Instead of allowing the user to pass multiple types to backbone in the constructor, restrict it to nn.Module. This way, our new __init__ looks like this:
def __init__(
    self, n_class: int, backbone: nn.Module = None, metrics=None, **kwargs,
):
    super().__init__(metrics=metrics)
    self.n_class = n_class
    self.backbone = backbone
    if backbone is None:
        # Creates the default Faster R-CNN as given in torchvision, trained on COCO.
        # num_classes stays at its COCO default so the pretrained weights load;
        # the box predictor is then replaced with one sized for n_class.
        self.m = fasterrcnn_resnet50_fpn(pretrained=True, **kwargs)
        in_features = self.m.roi_heads.box_predictor.cls_score.in_features
        self.m.roi_heads.box_predictor = FastRCNNPredictor(in_features, n_class)
    else:
        self.m = FasterRCNN(backbone, num_classes=n_class, **kwargs)
Much simpler, right? And note that the number of parameters we need to specify in __init__ also dropped significantly.
So now the question is: how do we allow the user to specify the backbone by passing a str? That is certainly very handy...
Well, create a new function for that!
@staticmethod
def get_backbone_by_name(
    name: str, fpn: bool = True, pretrained: bool = True
) -> nn.Module:
    # The backbone is given by name: a supported ResNet (with FPN) or a supported model (without FPN)
    if fpn:
        # Creates a torchvision ResNet model with an FPN added
        # Will need to add support for other models with FPN as well
        # Passing pretrained=True initializes a backbone that was trained on ImageNet
        if name in MantisFasterRCNN.supported_resnet_fpn_models:
            # Returns a BackboneWithFPN model
            backbone = resnet_fpn_backbone(name, pretrained=pretrained)
        else:
            raise ValueError(f"{name} with FPN not supported")
    else:
        # This does not create an FPN backbone; it is supported for all models
        if name in MantisFasterRCNN.supported_non_fpn_models:
            backbone = create_torchvision_backbone(name, pretrained=pretrained)
        else:
            raise ValueError(f"{name} not supported")
    return backbone
Again, the types of the arguments are very clear, which gives us a very intuitive and easy-to-use API.
So, if the user wants a resnet101 with FPN, they just have to do:
backbone = MantisFasterRCNN.get_backbone_by_name("resnet101")
model = MantisFasterRCNN(2, backbone=backbone)
You can even remove the ifs checking whether the model name is in the supported_list and let the errors be raised naturally (because both resnet_fpn_backbone and create_torchvision_backbone will already raise errors if you pass incorrect names). Then the only thing you need to do is specify the supported names in the function docstring (as you have already done for the constructor). And that's it: easy for the user, easy for the devs, easy to maintain 😉. An added benefit is that if resnet_fpn_backbone starts supporting new backbones, we don't have to keep updating our supported_list.
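A sketch of the simplified function the comment describes: no membership checks, the factory is picked by the fpn flag, and an unsupported name fails inside the factory itself. To keep it runnable without torchvision, fake_resnet_fpn_backbone and fake_create_torchvision_backbone are hypothetical stand-ins for resnet_fpn_backbone and create_torchvision_backbone; they raise KeyError from a dict lookup, and the error type of the real functions is not assumed here. The name lists in the docstring come from the tests in this thread.

```python
from typing import Any, Dict

# Hypothetical lookup tables standing in for torchvision's model constructors.
_RESNET_NAMES = ("resnet18", "resnet34", "resnet50", "resnet101", "resnet152")
_OTHER_NAMES = ("mobilenet", "vgg11", "vgg13", "vgg16", "vgg19")
_FPN_MODELS = {name: name for name in _RESNET_NAMES}
_PLAIN_MODELS = {name: name for name in _RESNET_NAMES + _OTHER_NAMES}


def fake_resnet_fpn_backbone(name: str, pretrained: bool = True) -> Dict[str, Any]:
    # Unknown names fail here, just as the real factory would fail on its own.
    return {"arch": _FPN_MODELS[name], "fpn": True, "pretrained": pretrained}


def fake_create_torchvision_backbone(name: str, pretrained: bool = True) -> Dict[str, Any]:
    return {"arch": _PLAIN_MODELS[name], "fpn": False, "pretrained": pretrained}


def get_backbone_by_name(name: str, fpn: bool = True, pretrained: bool = True) -> Dict[str, Any]:
    """Create a backbone from its name.

    fpn=True:  a ResNet name (resnet18, resnet34, resnet50, resnet101, resnet152).
    fpn=False: mobilenet, vgg11, vgg13, vgg16, vgg19 or any of the ResNets.
    There is no membership check: an unsupported name raises from the factory.
    """
    factory = fake_resnet_fpn_backbone if fpn else fake_create_torchvision_backbone
    return factory(name, pretrained=pretrained)
```

The only list a maintainer has to keep current is the one in the docstring, which is documentation rather than control flow.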
Nice! I'll pull the PR locally and take a look at the tests.
I'll then modify what's needed to fix them and merge, cool?
Absolutely fine 😀
This is the problem: maybe torchvision does not support this.
Does this cover the majority of the use cases?
import pytest


@pytest.mark.slow
@pytest.mark.parametrize("pretrained", [False, True])
@pytest.mark.parametrize(
    "backbone, fpn",
    [
        ("mobilenet", False),
        ("vgg11", False),
        ("vgg13", False),
        ("vgg16", False),
        ("vgg19", False),
        ("resnet18", False),
        ("resnet34", False),
        ("resnet50", False),
        ("resnet18", True),
        ("resnet34", True),
        ("resnet50", True),
        # these models are too big for GitHub runners
        # "resnet101",
        # "resnet152",
        # "resnext101_32x8d",
    ],
)
def test_faster_rcnn_nonfpn_backbones(image, backbone, fpn, pretrained):
    # `image` is a pytest fixture defined elsewhere (e.g. in conftest.py)
    backbone = MantisFasterRCNN.get_backbone_by_name(
        name=backbone, fpn=fpn, pretrained=pretrained
    )
    model = MantisFasterRCNN(n_class=3, backbone=backbone)
    model.eval()
    pred = model(image)
    assert isinstance(pred, list)
    assert set(["boxes", "labels", "scores"]) == set(pred[0].keys())
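Because the two @pytest.mark.parametrize decorators are stacked, pytest generates one test per combination: every (backbone, fpn) pair is run with every pretrained value. The plain-Python sketch below (itertools only, no pytest) makes that count explicit, which helps when deciding what to mark as slow, since each pretrained case downloads weights. The constant names are hypothetical.

```python
from itertools import product

BACKBONE_FPN_CASES = [
    ("mobilenet", False), ("vgg11", False), ("vgg13", False), ("vgg16", False),
    ("vgg19", False), ("resnet18", False), ("resnet34", False), ("resnet50", False),
    ("resnet18", True), ("resnet34", True), ("resnet50", True),
]
PRETRAINED_CASES = [False, True]

# Stacked parametrize decorators produce the Cartesian product of their values.
test_cases = [
    (backbone, fpn, pretrained)
    for (backbone, fpn), pretrained in product(BACKBONE_FPN_CASES, PRETRAINED_CASES)
]
print(len(test_cases))  # 22

# Cases that need pretrained ImageNet weights (and therefore a download).
pretrained_cases = [case for case in test_cases if case[2]]
```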
Yes, it does. I will do the testing for the other cases as well.
The problems in the tests were caused when the model was in … I fixed it by testing in training mode, with the additional benefit that it also tests …
Please review the proposed changes. If you approve and the tests pass, I'll go ahead and merge this. Many thanks!!!
Great. All tests are passing now. 👍 Can't wait for this feature to be merged.
I have unit tested the mantisshrimp/backbones/torchvision_backbones.py file. I am unsure how to test the entire Faster R-CNN code; please let me know so that I can test it and create PRs next time.