Skip to content

Commit

Permalink
Merge pull request #94 from aaronjohnsabu1999/patch-1
Browse files Browse the repository at this point in the history
Added option for mobile-based models
  • Loading branch information
alankbi committed Aug 22, 2021
2 parents 70be159 + 43f2ac3 commit bd58a1b
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions detecto/core.py
Expand Up @@ -219,7 +219,7 @@ def __getitem__(self, idx):

class Model:

def __init__(self, classes=None, device=None, pretrained=True):
def __init__(self, classes=None, device=None, pretrained=True, modelname='fasterrcnn_resnet50_fpn'):
"""Initializes a machine learning model for object detection.
Models are built on top of PyTorch's `pre-trained models
<https://pytorch.org/docs/stable/torchvision/models.html>`_,
Expand All @@ -241,6 +241,8 @@ def __init__(self, classes=None, device=None, pretrained=True):
:param pretrained: (Optional) Whether to load pretrained weights or not.
Defaults to True.
:type pretrained: bool
:param modelname: (Optional) Name of the pretrained model
:type modelname: str
**Example**::
Expand All @@ -252,7 +254,14 @@ def __init__(self, classes=None, device=None, pretrained=True):
self._device = device if device else config['default_device']

# Load a model pre-trained on COCO
self._model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=pretrained)
if modelname == 'fasterrcnn_resnet50_fpn':
self._model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=pretrained)
elif modelname == 'fasterrcnn_mobilenet_v3_large_fpn':
self._model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=pretrained)
elif modelname == 'fasterrcnn_mobilenet_v3_large_320_fpn':
self._model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=pretrained)
else:
return ValueError('Unknown Pretrained Model')

if classes:
# Get the number of input features for the classifier
Expand Down

0 comments on commit bd58a1b

Please sign in to comment.