This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for more backbones(mobilnet, vgg, densenet, resnext) & re…
…factor (#45) * add support for more backbones & refactor * fix imports * fix import paths * fix densenet model name * add comment for creating resnet backbone * remove model zoo * Update flash/vision/classification/backbones.py * fix tests to raise the right exception Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
- Loading branch information
1 parent
58d3a84
commit bc00c05
Showing
2 changed files
with
37 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from typing import Tuple | ||
|
||
import torch.nn as nn | ||
import torchvision | ||
from pytorch_lightning.utilities.exceptions import MisconfigurationException | ||
|
||
|
||
def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: | ||
|
||
model = getattr(torchvision.models, model_name, None) | ||
if model is None: | ||
raise MisconfigurationException(f"{model_name} is not supported by torchvision") | ||
|
||
if model_name in ["mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19"]: | ||
model = model(pretrained=pretrained) | ||
backbone = model.features | ||
num_features = model.classifier[-1].in_features | ||
return backbone, num_features | ||
|
||
elif model_name in [ | ||
"resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d" | ||
]: | ||
model = model(pretrained=pretrained) | ||
# remove the last two layers & turn it into a Sequential model | ||
backbone = nn.Sequential(*list(model.children())[:-2]) | ||
num_features = model.fc.in_features | ||
return backbone, num_features | ||
|
||
elif model_name in ["densenet121", "densenet169", "densenet161", "densenet161"]: | ||
model = model(pretrained=pretrained) | ||
backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True)) | ||
num_features = model.classifier.in_features | ||
return backbone, num_features | ||
|
||
raise ValueError(f"{model_name} is not supported yet.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters