Skip to content

Commit

Permalink
Add MobileNetV3
Browse files Browse the repository at this point in the history
  • Loading branch information
ericup committed Jul 7, 2021
1 parent 8d7a8db commit 1346561
Show file tree
Hide file tree
Showing 5 changed files with 268 additions and 15 deletions.
33 changes: 21 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ If your data contains images and [label images](https://scikit-image.org/docs/de
- `models.CpnResNeXt152FPN`
- `models.CpnWideResNet50FPN`
- `models.CpnWideResNet101FPN`
- `models.CpnMobileNetV3SmallFPN`
- `models.CpnMobileNetV3LargeFPN`
- `models.CPN`

###### U-Nets:
Expand All @@ -44,18 +46,6 @@ If your data contains images and [label images](https://scikit-image.org/docs/de
- `models.UNetEncoder`
- `models.UNet`

###### Residual Networks:
- `models.ResNet18`
- `models.ResNet34`
- `models.ResNet50`
- `models.ResNet101`
- `models.ResNet152`
- `models.ResNeXt50_32x4d`
- `models.ResNeXt101_32x8d`
- `models.ResNeXt152_32x8d`
- `models.WideResNet50_2`
- `models.WideResNet101_2`

###### Feature Pyramid Networks:
- `models.ResNet18FPN`
- `models.ResNet34FPN`
Expand All @@ -67,8 +57,27 @@ If your data contains images and [label images](https://scikit-image.org/docs/de
- `models.ResNeXt152FPN`
- `models.WideResNet50FPN`
- `models.WideResNet101FPN`
- `models.MobileNetV3SmallFPN`
- `models.MobileNetV3LargeFPN`
- `models.FPN`

###### Residual Networks:
- `models.ResNet18`
- `models.ResNet34`
- `models.ResNet50`
- `models.ResNet101`
- `models.ResNet152`
- `models.ResNeXt50_32x4d`
- `models.ResNeXt101_32x8d`
- `models.ResNeXt152_32x8d`
- `models.WideResNet50_2`
- `models.WideResNet101_2`

###### Mobile Networks:
- `models.MobileNetV3Small`
- `models.MobileNetV3Large`



## 📝 Citing

Expand Down
5 changes: 3 additions & 2 deletions celldetection/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
ResNeXt152_32x8d, ResNeXt101_32x8d, ResNeXt50_32x4d
from .cpn import CPN, CpnSlimU22, CpnU22, CpnWideU22, CpnResNet18FPN, CpnResNet34FPN, CpnResNet50FPN, CpnResNet101FPN, \
CpnResNet152FPN, CpnResNeXt50FPN, CpnResNeXt101FPN, CpnResNeXt152FPN, CpnWideResNet50FPN, \
CpnWideResNet101FPN
CpnWideResNet101FPN, CpnMobileNetV3LargeFPN, CpnMobileNetV3SmallFPN
from .fpn import FPN, ResNeXt50FPN, ResNeXt101FPN, ResNet18FPN, ResNet34FPN, ResNeXt152FPN, WideResNet50FPN, \
WideResNet101FPN, ResNet50FPN, ResNet101FPN, ResNet152FPN
WideResNet101FPN, ResNet50FPN, ResNet101FPN, ResNet152FPN, MobileNetV3SmallFPN, MobileNetV3LargeFPN
from .inference import Inference
from .mobilenetv3 import *
108 changes: 107 additions & 1 deletion celldetection/models/cpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
order_weighting, resolve_refinement_buckets
from .unet import U22, SlimU22, WideU22
from .fpn import ResNet34FPN, ResNet18FPN, ResNet50FPN, ResNet101FPN, ResNet152FPN, ResNeXt50FPN, \
ResNeXt101FPN, ResNeXt152FPN, WideResNet50FPN, WideResNet101FPN
ResNeXt101FPN, ResNeXt152FPN, WideResNet50FPN, WideResNet101FPN, MobileNetV3LargeFPN, MobileNetV3SmallFPN


class ReadOut(nn.Module):
Expand Down Expand Up @@ -1185,6 +1185,112 @@ def __init__(
)


class CpnMobileNetV3SmallFPN(CPN):
    def __init__(
            self,
            in_channels: int,
            order: int = 5,
            nms_thresh: float = .2,
            score_thresh: float = .5,
            samples: int = 32,
            classes: int = 2,
            refinement: bool = True,
            refinement_iterations: int = 4,
            refinement_margin: float = 3.,
            refinement_buckets: int = 1,
            backbone_kwargs: dict = None,
            **kwargs
    ):
        """Contour Proposal Network built on a ``MobileNetV3SmallFPN`` backbone.

        Args:
            in_channels: Number of channels of the input images.
            order: Contour order; higher values permit more complex contours.
                ``order=1`` limits proposals to ellipses, ``order=3`` permits rough
                non-convex outlines and ``order=8`` permits even finer detail.
            nms_thresh: IoU threshold of the non-maximum suppression (NMS). Objects with
                ``iou > nms_thresh`` are treated as identical.
            score_thresh: Score threshold. In the binary case (object vs. background) an
                object is only proposed if ``score > score_thresh``.
            samples: Number of coordinates that describe one contour. May be changed on
                the fly (e.g. small while training, large for inference); small values
                lower the computational cost, large values capture more detail.
            classes: Number of classes. The default of 2 means object vs. background.
            refinement: Whether local refinement is used.
            refinement_iterations: Number of refinement iterations.
            refinement_margin: Largest refinement step size per iteration.
            refinement_buckets: Number of refinement buckets.
            backbone_kwargs: Optional keyword arguments forwarded to the backbone
                constructor. ``None`` (or any falsy value) means no extra arguments.
            **kwargs: Forwarded unchanged to ``CPN``; see its docstring.
        """
        # A falsy ``backbone_kwargs`` is replaced by an empty mapping before unpacking.
        backbone_options = backbone_kwargs or {}
        fpn_backbone = MobileNetV3SmallFPN(in_channels, **backbone_options)
        super().__init__(
            backbone=fpn_backbone,
            order=order,
            nms_thresh=nms_thresh,
            score_thresh=score_thresh,
            samples=samples,
            classes=classes,
            refinement=refinement,
            refinement_iterations=refinement_iterations,
            refinement_margin=refinement_margin,
            refinement_buckets=refinement_buckets,
            **kwargs
        )


class CpnMobileNetV3LargeFPN(CPN):
    def __init__(
            self,
            in_channels: int,
            order: int = 5,
            nms_thresh: float = .2,
            score_thresh: float = .5,
            samples: int = 32,
            classes: int = 2,
            refinement: bool = True,
            refinement_iterations: int = 4,
            refinement_margin: float = 3.,
            refinement_buckets: int = 1,
            backbone_kwargs: dict = None,
            **kwargs
    ):
        """Contour Proposal Network built on a ``MobileNetV3LargeFPN`` backbone.

        Args:
            in_channels: Number of channels of the input images.
            order: Contour order; higher values permit more complex contours.
                ``order=1`` limits proposals to ellipses, ``order=3`` permits rough
                non-convex outlines and ``order=8`` permits even finer detail.
            nms_thresh: IoU threshold of the non-maximum suppression (NMS). Objects with
                ``iou > nms_thresh`` are treated as identical.
            score_thresh: Score threshold. In the binary case (object vs. background) an
                object is only proposed if ``score > score_thresh``.
            samples: Number of coordinates that describe one contour. May be changed on
                the fly (e.g. small while training, large for inference); small values
                lower the computational cost, large values capture more detail.
            classes: Number of classes. The default of 2 means object vs. background.
            refinement: Whether local refinement is used.
            refinement_iterations: Number of refinement iterations.
            refinement_margin: Largest refinement step size per iteration.
            refinement_buckets: Number of refinement buckets.
            backbone_kwargs: Optional keyword arguments forwarded to the backbone
                constructor. ``None`` (or any falsy value) means no extra arguments.
            **kwargs: Forwarded unchanged to ``CPN``; see its docstring.
        """
        # A falsy ``backbone_kwargs`` is replaced by an empty mapping before unpacking.
        backbone_options = backbone_kwargs or {}
        fpn_backbone = MobileNetV3LargeFPN(in_channels, **backbone_options)
        super().__init__(
            backbone=fpn_backbone,
            order=order,
            nms_thresh=nms_thresh,
            score_thresh=score_thresh,
            samples=samples,
            classes=classes,
            refinement=refinement,
            refinement_iterations=refinement_iterations,
            refinement_margin=refinement_margin,
            refinement_buckets=refinement_buckets,
            **kwargs
        )


# Name lookup table. The only entry visible here maps the key 'cpn_u22' to the
# identical string. NOTE(review): how this table is consumed is not visible in
# this excerpt — confirm before adding entries for the MobileNetV3 variants.
models_by_name = {
'cpn_u22': 'cpn_u22'
}
Expand Down
49 changes: 49 additions & 0 deletions celldetection/models/fpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from torchvision.models.detection.backbone_utils import BackboneWithFPN
from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt50_32x4d, ResNeXt101_32x8d, \
ResNeXt152_32x8d, WideResNet50_2, WideResNet101_2
from .mobilenetv3 import MobileNetV3Large, MobileNetV3Small


class FPN(BackboneWithFPN):
Expand Down Expand Up @@ -87,3 +88,51 @@ def __init__(self, in_channels, fpn_channels=256):
class WideResNet101FPN(FPN):
    """Feature Pyramid Network that wraps a ``WideResNet101_2`` backbone."""

    def __init__(self, in_channels, fpn_channels=256):
        """Build the backbone and hand it to ``FPN``.

        Args:
            in_channels: Passed to ``WideResNet101_2`` as ``in_channels``.
            fpn_channels: Passed to ``FPN`` as ``channels``.
        """
        backbone = WideResNet101_2(in_channels=in_channels)
        super().__init__(backbone, channels=fpn_channels)


class MobileNetV3SmallFPN(FPN):
    """Feature Pyramid Network that wraps a ``MobileNetV3Small`` backbone.

    Examples:
        ```
        >>> import torch
        >>> from celldetection import models
        >>> model = models.MobileNetV3SmallFPN(in_channels=3)
        >>> out: dict = model(torch.rand(1, 3, 256, 256))
        >>> for k, v in out.items():
        ...     print(k, v.shape)
        0 torch.Size([1, 256, 128, 128])
        1 torch.Size([1, 256, 64, 64])
        2 torch.Size([1, 256, 32, 32])
        3 torch.Size([1, 256, 16, 16])
        4 torch.Size([1, 256, 8, 8])
        pool torch.Size([1, 256, 4, 4])
        ```
    """

    def __init__(self, in_channels, fpn_channels=256, **kwargs):
        """Build the backbone and hand it to ``FPN``.

        Args:
            in_channels: Passed to ``MobileNetV3Small`` as ``in_channels``.
            fpn_channels: Passed to ``FPN`` as ``channels``.
            **kwargs: Extra keyword arguments forwarded to ``MobileNetV3Small``.
        """
        backbone = MobileNetV3Small(in_channels=in_channels, **kwargs)
        super().__init__(backbone, channels=fpn_channels)


class MobileNetV3LargeFPN(FPN):
    """Feature Pyramid Network that wraps a ``MobileNetV3Large`` backbone.

    Examples:
        ```
        >>> import torch
        >>> from celldetection import models
        >>> model = models.MobileNetV3LargeFPN(in_channels=3)
        >>> out: dict = model(torch.rand(1, 3, 256, 256))
        >>> for k, v in out.items():
        ...     print(k, v.shape)
        0 torch.Size([1, 256, 128, 128])
        1 torch.Size([1, 256, 64, 64])
        2 torch.Size([1, 256, 32, 32])
        3 torch.Size([1, 256, 16, 16])
        4 torch.Size([1, 256, 8, 8])
        pool torch.Size([1, 256, 4, 4])
        ```
    """

    def __init__(self, in_channels, fpn_channels=256, **kwargs):
        """Build the backbone and hand it to ``FPN``.

        Args:
            in_channels: Passed to ``MobileNetV3Large`` as ``in_channels``.
            fpn_channels: Passed to ``FPN`` as ``channels``.
            **kwargs: Extra keyword arguments forwarded to ``MobileNetV3Large``.
        """
        backbone = MobileNetV3Large(in_channels=in_channels, **kwargs)
        super().__init__(backbone, channels=fpn_channels)
88 changes: 88 additions & 0 deletions celldetection/models/mobilenetv3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import torch.nn as nn
from torchvision.models.mobilenetv3 import InvertedResidualConfig, InvertedResidual, _mobilenet_v3_conf
from torchvision.models.mobilenetv2 import ConvBNActivation
from typing import Any, Callable, List, Optional, Sequence
from functools import partial

__all__ = ['MobileNetV3Large', 'MobileNetV3Small']


def init_modules_(mod: nn.Module):
    """Re-initialise, in place, the parameters of all supported submodules of ``mod``.

    Every module yielded by ``mod.modules()`` is inspected:

    - ``nn.Conv2d``: weight gets Kaiming-normal init (``mode='fan_out'``), bias is zeroed.
    - ``nn.BatchNorm2d`` / ``nn.GroupNorm``: weight is set to ones, bias to zeros.
    - ``nn.Linear``: weight is drawn from a normal distribution (mean 0, std 0.01),
      bias is zeroed.

    Parameters that are ``None`` (e.g. layers created with ``bias=False`` or
    ``affine=False``) are skipped instead of raising. Other module types are left
    untouched.

    Args:
        mod: Root module whose submodules are initialised.
    """
    for m in mod.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            # With ``affine=False`` both parameters are None.
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            # With ``bias=False`` the bias parameter is None.
            if m.bias is not None:
                nn.init.zeros_(m.bias)


class MobileNetV3Base(nn.Sequential):
    """Adaptation of torchvision.models.mobilenetv3.MobileNetV3.

    The network is assembled as a sequence of stages, each an ``nn.Sequential``.
    The first stage begins with a 3x3, stride-2 convolution block; every
    inverted-residual config whose ``stride`` is greater than 1 opens a new stage.
    A final 1x1 convolution block is appended to the last stage.

    Attributes:
        out_channels: One entry per stage, holding that stage's number of output
            channels. The last entry is the width of the final 1x1 convolution
            (6 times the ``out_channels`` of the last config).
    """

    def __init__(
            self,
            in_channels,
            inverted_residual_setting: List[InvertedResidualConfig],
            block: Optional[Callable[..., nn.Module]] = None,
            norm_layer: Optional[Callable[..., nn.Module]] = None,
            **kwargs: Any
    ) -> None:
        """Build all stages and initialise their parameters.

        Args:
            in_channels: Number of input channels of the first convolution.
            inverted_residual_setting: Non-empty sequence of ``InvertedResidualConfig``
                objects describing the blocks.
            block: Callable that builds one block from ``(config, norm_layer)``.
                Defaults to ``InvertedResidual``.
            norm_layer: Callable that builds a normalisation layer. Defaults to
                ``nn.BatchNorm2d`` with ``eps=0.001`` and ``momentum=0.01``.
            **kwargs: Accepted for signature compatibility; not used by this constructor.

        Raises:
            ValueError: If ``inverted_residual_setting`` is empty.
            TypeError: If ``inverted_residual_setting`` is not a sequence of
                ``InvertedResidualConfig`` objects.
        """
        if not inverted_residual_setting:
            raise ValueError("The inverted_residual_setting should not be empty")
        if not (isinstance(inverted_residual_setting, Sequence) and
                all(isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting)):
            raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")

        if block is None:
            block = InvertedResidual
        if norm_layer is None:
            norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)

        layers: List[nn.Sequential] = [nn.Sequential()]

        # building first layer
        firstconv_output_channels = inverted_residual_setting[0].input_channels
        out_channels = [firstconv_output_channels]
        layers[-1].add_module(str(len(layers[-1])),
                              ConvBNActivation(in_channels, firstconv_output_channels, kernel_size=3, stride=2,
                                               norm_layer=norm_layer, activation_layer=nn.Hardswish))

        # building inverted residual blocks
        for cnf in inverted_residual_setting:
            if cnf.stride > 1:
                # A strided config opens a new stage with its own channel entry.
                layers.append(nn.Sequential())
                out_channels.append(cnf.out_channels)
            else:
                # Same stage: its output width is that of the most recent block.
                out_channels[-1] = cnf.out_channels
            layers[-1].add_module(str(len(layers[-1])), block(cnf, norm_layer))

        # building last several layers
        lastconv_input_channels = inverted_residual_setting[-1].out_channels
        lastconv_output_channels = 6 * lastconv_input_channels
        out_channels[-1] = lastconv_output_channels
        assert len(out_channels) == len(layers)
        layers[-1].add_module(str(len(layers[-1])),
                              ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
                                               norm_layer=norm_layer, activation_layer=nn.Hardswish))

        # Initialise the nn.Sequential exactly once, with the finished stages.
        super().__init__(*layers)
        self.out_channels = out_channels
        init_modules_(self)


class MobileNetV3Large(MobileNetV3Base):
    """``MobileNetV3Base`` configured with torchvision's ``'mobilenet_v3_large'`` settings."""

    def __init__(self, in_channels, width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False):
        """Look up the block configuration and build the network.

        Args:
            in_channels: Number of input channels, passed to ``MobileNetV3Base``.
            width_mult: Forwarded to ``_mobilenet_v3_conf``.
            reduced_tail: Forwarded to ``_mobilenet_v3_conf``.
            dilated: Forwarded to ``_mobilenet_v3_conf``.
        """
        conf = _mobilenet_v3_conf(
            'mobilenet_v3_large', width_mult=width_mult, reduced_tail=reduced_tail, dilated=dilated)
        # Only the first element of the helper's result is used as the block setting.
        super().__init__(in_channels=in_channels, inverted_residual_setting=conf[0])


class MobileNetV3Small(MobileNetV3Base):
    """``MobileNetV3Base`` configured with torchvision's ``'mobilenet_v3_small'`` settings."""

    def __init__(self, in_channels, width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False):
        """Look up the block configuration and build the network.

        Args:
            in_channels: Number of input channels, passed to ``MobileNetV3Base``.
            width_mult: Forwarded to ``_mobilenet_v3_conf``.
            reduced_tail: Forwarded to ``_mobilenet_v3_conf``.
            dilated: Forwarded to ``_mobilenet_v3_conf``.
        """
        conf = _mobilenet_v3_conf(
            'mobilenet_v3_small', width_mult=width_mult, reduced_tail=reduced_tail, dilated=dilated)
        # Only the first element of the helper's result is used as the block setting.
        super().__init__(in_channels=in_channels, inverted_residual_setting=conf[0])

0 comments on commit 1346561

Please sign in to comment.