Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add support for more backbones(mobilnet, vgg, densenet, resnext) & re…
Browse files Browse the repository at this point in the history
…factor (#45)

* add support for more backbones & refactor

* fix imports

* fix import paths

* fix densenet model name

* add comment for creating resnet backbone

* remove model zoo

* Update flash/vision/classification/backbones.py

* fix tests to raise the right exception

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
  • Loading branch information
kaushikb11 and carmocca committed Feb 2, 2021
1 parent 58d3a84 commit bc00c05
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 18 deletions.
35 changes: 35 additions & 0 deletions flash/vision/classification/backbones.py
@@ -0,0 +1,35 @@
from typing import Tuple

import torch.nn as nn
import torchvision
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:

model = getattr(torchvision.models, model_name, None)
if model is None:
raise MisconfigurationException(f"{model_name} is not supported by torchvision")

if model_name in ["mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19"]:
model = model(pretrained=pretrained)
backbone = model.features
num_features = model.classifier[-1].in_features
return backbone, num_features

elif model_name in [
"resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d"
]:
model = model(pretrained=pretrained)
# remove the last two layers & turn it into a Sequential model
backbone = nn.Sequential(*list(model.children())[:-2])
num_features = model.fc.in_features
return backbone, num_features

elif model_name in ["densenet121", "densenet169", "densenet161", "densenet161"]:
model = model(pretrained=pretrained)
backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True))
num_features = model.classifier.in_features
return backbone, num_features

raise ValueError(f"{model_name} is not supported yet.")
20 changes: 2 additions & 18 deletions flash/vision/classification/model.py
Expand Up @@ -21,19 +21,9 @@
from torch.nn import functional as F

from flash.core.classification import ClassificationTask
from flash.vision.classification.backbones import torchvision_backbone_and_num_features
from flash.vision.classification.data import ImageClassificationData, ImageClassificationDataPipeline

_resnet_backbone = lambda model: nn.Sequential(*list(model.children())[:-2]) # noqa: E731
_resnet_feats = lambda model: model.fc.in_features # noqa: E731

_backbones = {
"resnet18": (torchvision.models.resnet18, _resnet_backbone, _resnet_feats),
"resnet34": (torchvision.models.resnet34, _resnet_backbone, _resnet_feats),
"resnet50": (torchvision.models.resnet50, _resnet_backbone, _resnet_feats),
"resnet101": (torchvision.models.resnet101, _resnet_backbone, _resnet_feats),
"resnet152": (torchvision.models.resnet152, _resnet_backbone, _resnet_feats),
}


class ImageClassifier(ClassificationTask):
"""Task that classifies images.
Expand Down Expand Up @@ -69,13 +59,7 @@ def __init__(

self.save_hyperparameters()

if backbone not in _backbones:
raise MisconfigurationException(f"Backbone {backbone} is not yet supported")

backbone_fn, split, num_feats = _backbones[backbone]
backbone = backbone_fn(pretrained=pretrained)
self.backbone = split(backbone)
num_features = num_feats(backbone)
self.backbone, num_features = torchvision_backbone_and_num_features(backbone, pretrained)

self.head = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
Expand Down

0 comments on commit bc00c05

Please sign in to comment.