Skip to content

Commit

Permalink
Resnest soft import (#592)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgvaz committed Dec 15, 2020
1 parent a16822c commit 6cbd6d1
Show file tree
Hide file tree
Showing 5 changed files with 2,646 additions and 8 deletions.
19 changes: 11 additions & 8 deletions icevision/backbones/resnest_fpn.py
Expand Up @@ -13,28 +13,32 @@
resnet_fpn_backbone,
BackboneWithFPN,
)
import resnest.torch

from icevision.soft_dependencies import SoftDependencies

if SoftDependencies.resnest:
import resnest.torch


# ResNeSt Backbones
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
super().__init__()

def forward(self, x):
return x


# TODO: resnest_name?
def resnest_fpn_backbone(
resnest_name,
backbone_fn,
pretrained,
# norm_layer=misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=3,
returned_layers=None,
extra_blocks=None,
):
"""
Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.
Constructs a specified ResNest backbone with FPN on top. Freezes the specified number of layers in the backbone.
Examples::
Expand All @@ -61,8 +65,7 @@ def resnest_fpn_backbone(
trainable_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
backbone = resnest_name(pretrained=pretrained)
# backbone = resnest50(pretrained=pretrained)
backbone = backbone_fn(pretrained=pretrained)

backbone.fc = Identity()
backbone.avgpool = Identity()
Expand Down Expand Up @@ -98,7 +101,7 @@ def resnest_fpn_backbone(


def _resnest_fpn(name, pretrained: bool = True, **kwargs):
model = resnest_fpn_backbone(resnest_name=name, pretrained=pretrained, **kwargs)
model = resnest_fpn_backbone(backbone_fn=name, pretrained=pretrained, **kwargs)
patch_param_groups(model)

return model
Expand Down
1 change: 1 addition & 0 deletions icevision/soft_dependencies.py
Expand Up @@ -21,6 +21,7 @@ def __init__(self):
self.albumentations = soft_import("albumentations")
self.effdet = soft_import("effdet")
self.wandb = soft_import("wandb")
self.resnest = soft_import("resnest")

def check(self) -> Dict[str, bool]:
return self.__dict__.copy()
Expand Down

0 comments on commit 6cbd6d1

Please sign in to comment.