Skip to content

Commit

Permalink
Added nn.Module support for FasterRCNN backbone (#661)
Browse files Browse the repository at this point in the history
  • Loading branch information
abhayraw1 committed Nov 15, 2021
1 parent 307722a commit d58a9b9
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 10 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Expand Up @@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added YOLO model ([#552](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/552))


- Added nn.Module support for FasterRCNN backbone ([#661](https://github.com/PyTorchLightning/lightning-bolts/pull/661))


### Changed

- VAE now uses deterministic KL divergence during training, previously estimated KL divergence by random sampling ([#760](https://github.com/PyTorchLightning/lightning-bolts/pull/760))
Expand Down
30 changes: 20 additions & 10 deletions pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py
@@ -1,5 +1,5 @@
from argparse import ArgumentParser
from typing import Any, Optional
from typing import Any, Optional, Union

import torch
from pytorch_lightning import LightningModule, Trainer, seed_everything
Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(
self,
learning_rate: float = 0.0001,
num_classes: int = 91,
backbone: Optional[str] = None,
backbone: Optional[Union[str, torch.nn.Module]] = None,
fpn: bool = True,
pretrained: bool = False,
pretrained_backbone: bool = True,
Expand All @@ -61,7 +61,7 @@ def __init__(
Args:
learning_rate: the learning rate
num_classes: number of detection classes (including background)
backbone: Pretained backbone CNN architecture.
backbone: Pretained backbone CNN architecture or torch.nn.Module instance.
fpn: If True, creates a Feature Pyramind Network on top of Resnet based CNNs.
pretrained: if true, returns a model pre-trained on COCO train2017
pretrained_backbone: if true, returns a model with backbone pre-trained on Imagenet
Expand All @@ -86,13 +86,23 @@ def __init__(
self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, self.num_classes)

else:
backbone_model = create_fasterrcnn_backbone(
self.backbone,
fpn,
pretrained_backbone,
trainable_backbone_layers,
**kwargs,
)
if isinstance(self.backbone, torch.nn.Module):
backbone_model = self.backbone
if pretrained_backbone:
import warnings

warnings.warn(
"You would need to load the pretrained state_dict yourself if you are "
"providing backbone of type torch.nn.Module / pl.LightningModule."
)
else:
backbone_model = create_fasterrcnn_backbone(
self.backbone,
fpn,
pretrained_backbone,
trainable_backbone_layers,
**kwargs,
)
self.model = torchvision_FasterRCNN(backbone_model, num_classes=num_classes, **kwargs)

def forward(self, x):
Expand Down
11 changes: 11 additions & 0 deletions tests/models/test_detection.py
Expand Up @@ -7,6 +7,7 @@

from pl_bolts.datasets import DummyDetectionDataset
from pl_bolts.models.detection import YOLO, FasterRCNN, YOLOConfiguration
from pl_bolts.models.detection.faster_rcnn import create_fasterrcnn_backbone
from pl_bolts.models.detection.yolo.yolo_layers import _aligned_iou
from tests import TEST_ROOT

Expand Down Expand Up @@ -42,6 +43,16 @@ def test_fasterrcnn_bbone_train(tmpdir):
trainer.fit(model, train_dl, valid_dl)


def test_fasterrcnn_pyt_module_bbone_train(tmpdir):
backbone = create_fasterrcnn_backbone(backbone="resnet18")
model = FasterRCNN(backbone=backbone, fpn=True, pretrained_backbone=False, pretrained=False)
train_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)
valid_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)

trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
trainer.fit(model, train_dl, valid_dl)


def test_yolo(tmpdir):
config_path = Path(TEST_ROOT) / "data" / "yolo.cfg"
config = YOLOConfiguration(config_path)
Expand Down

0 comments on commit d58a9b9

Please sign in to comment.