Delete mmseg and change docstring
thinksoso committed Feb 8, 2022
1 parent 53c9529 commit d6b961c
Showing 1 changed file with 114 additions and 161 deletions.
275 changes: 114 additions & 161 deletions flowvision/models/poolformer.py
@@ -17,25 +17,6 @@
from .helpers import to_2tuple


try:
from mmseg.models.builder import BACKBONES as seg_BACKBONES
from mmseg.utils import get_root_logger
from mmcv.runner import _load_checkpoint
has_mmseg = True
except ImportError:
print("If for semantic segmentation, please install mmsegmentation first")
has_mmseg = False

try:
from mmdet.models.builder import BACKBONES as det_BACKBONES
from mmdet.utils import get_root_logger
from mmcv.runner import _load_checkpoint
has_mmdet = True
except ImportError:
print("If for detection, please install mmdetection first")
has_mmdet = False


model_urls = {
"poolformer_s12": "https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s12.pth.tar",
"poolformer_s24": "https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s24.pth.tar",
@@ -45,28 +26,19 @@
}


def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .95, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'classifier': 'head',
**kwargs
}


default_cfgs = {
'poolformer_s': _cfg(crop_pct=0.9),
'poolformer_m': _cfg(crop_pct=0.95),
}


class PatchEmbed(nn.Module):
"""
Patch Embedding that is implemented by a layer of conv.
Input: tensor in shape [B, C, H, W]
Output: tensor in shape [B, C, H/stride, W/stride]
Args:
patch_size (int): kernel size. Default: 16
stride (int): stride in conv. Default: 16
padding (int): controls the amount of padding applied to the input. Default: 0
in_chans (int): number of input channels. Default: 3
embed_dim (int): number of output channels. Default: 768
norm_layer: normalization layer applied after the projection. Default: None
"""
def __init__(self, patch_size=16, stride=16, padding=0,
in_chans=3, embed_dim=768, norm_layer=None):
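Editor's sketch (not part of the commit): a minimal check of the shape contract stated in the PatchEmbed docstring, assuming the (folded) forward simply applies the projection conv plus the optional norm. The stem-style values below are illustrative, not taken from this diff.

import oneflow as flow

# Hypothetical configuration for illustration only; PoolFormer's own stem builds
# this module from in_patch_size / in_stride / in_pad.
patch_embed = PatchEmbed(patch_size=7, stride=4, padding=2, in_chans=3, embed_dim=64)
x = flow.randn(1, 3, 224, 224)   # [B, C, H, W]
y = patch_embed(x)
print(y.shape)                   # expected [1, 64, 56, 56]: H and W shrink by the stride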
Expand All @@ -88,17 +60,21 @@ class LayerNormChannel(nn.Module):
"""
LayerNorm only for Channel Dimension.
Input: tensor in shape [B, C, H, W]
Args:
num_channels (int): Number of input channels.
eps (float): a small value added to the denominator for numerical stability. Default: 1e-05.
"""
def __init__(self, num_channels, eps=1e-05):
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.weight = nn.Parameter(flow.ones(num_channels))
self.bias = nn.Parameter(flow.zeros(num_channels))
self.eps = eps

def forward(self, x):
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = (x - u) / flow.sqrt(s + self.eps)
x = self.weight.unsqueeze(-1).unsqueeze(-1) * x \
+ self.bias.unsqueeze(-1).unsqueeze(-1)
return x
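Editor's sketch (not part of the commit): a quick check of what the forward above computes, i.e. normalization over the channel dimension at each spatial position while keeping the [B, C, H, W] layout.

import oneflow as flow

norm = LayerNormChannel(num_channels=64)
x = flow.randn(2, 64, 56, 56)
y = norm(x)
print(y.shape)                     # [2, 64, 56, 56]: layout is unchanged
print(y.mean(dim=1).abs().max())   # close to 0: channels are zero-centered at every position
# (weight starts at ones and bias at zeros, so the affine step is initially a no-op)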
@@ -192,9 +168,9 @@ def __init__(self, dim, pool_size=3, mlp_ratio=4.,
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(
layer_scale_init_value * torch.ones((dim)), requires_grad=True)
layer_scale_init_value * flow.ones((dim)), requires_grad=True)
self.layer_scale_2 = nn.Parameter(
layer_scale_init_value * torch.ones((dim)), requires_grad=True)
layer_scale_init_value * flow.ones((dim)), requires_grad=True)

def forward(self, x):
if self.use_layer_scale:
@@ -236,21 +212,22 @@ def basic_blocks(dim, index, layers,


class PoolFormer(nn.Module):
"""
PoolFormer, the main class of our model
--layers: [x,x,x,x], number of blocks for the 4 stages
--embed_dims, --mlp_ratios, --pool_size: the embedding dims, mlp ratios and
pooling size for the 4 stages
--downsamples: flags to apply downsampling or not
--norm_layer, --act_layer: define the types of normalizaiotn and activation
--num_classes: number of classes for the image classification
--in_patch_size, --in_stride, --in_pad: specify the patch embedding
for the input image
--down_patch_size --down_stride --down_pad:
specify the downsample (patch embed.)
--fork_faat: whetehr output features of the 4 stages, for dense prediction
--init_cfg,--pretrained:
for mmdetection and mmsegmentation to load pretrianfed weights
""" PoolFormer
The OneFlow impl of: `MetaFormer Is Actually What You Need for Vision` -
https://arxiv.org/abs/2111.11418
Args:
--layers: [x,x,x,x], number of blocks for the 4 stages
--embed_dims: The embedding dims
--mlp_ratios: The mlp ratios
--pool_size: Pooling size
--downsamples: Flags to apply downsampling or not
--norm_layer, --act_layer: Define the types of normalization and activation
--num_classes: Number of classes for the image classification
--in_patch_size, --in_stride, --in_pad: Specify the patch embedding for the input image
--down_patch_size, --down_stride, --down_pad: Specify the downsampling (patch embed.)
--fork_feat: Whether to output features of the 4 stages, for dense prediction
--init_cfg, --pretrained: For mmdetection and mmsegmentation to load pretrained weights
"""
def __init__(self, layers, embed_dims=None,
mlp_ratios=None, downsamples=None,
@@ -308,7 +285,7 @@ def __init__(self, layers, embed_dims=None,
if i_emb == 0 and os.environ.get('FORK_LAST3', None):
# TODO: more elegant way
"""For RetinaNet, `start_level=1`. The first norm layer will not used.
cmd: `FORK_LAST3=1 python -m torch.distributed.launch ...`
cmd: `FORK_LAST3=1 python -m flow.distributed.launch ...`
"""
layer = nn.Identity()
else:
@@ -379,8 +356,6 @@ def forward(self, x):

def _create_poolformer(arch, pretrained=False, progress=True, **model_kwargs):
model = PoolFormer(**model_kwargs)
default_cfgs_name = arch.split("_")[0] + arch.split("_")[1][1]
model.default_cfg = default_cfgs[default_cfgs_name]
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
model.load_state_dict(state_dict)
@@ -390,11 +365,22 @@ def _create_poolformer(arch, pretrained=False, progress=True, **model_kwargs):
@ModelCreator.register_model
def poolformer_s12(pretrained=False, progress=True, **kwargs):
"""
PoolFormer-S12 model, Params: 12M
--layers: [x,x,x,x], numbers of layers for the four stages
--embed_dims, --mlp_ratios:
embedding dims and mlp ratios for the four stages
--downsamples: flags to apply downsampling or not in four blocks
Constructs the PoolFormer model.
.. note::
PoolFormer-S12 model, Params: 12M. From `"MetaFormer is Actually What You Need for Vision" <https://arxiv.org/abs/2111.11418>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> poolformer_s12 = flowvision.models.poolformer_s12(pretrained=False, progress=True)
"""
model_kwargs = dict(
layers = [2, 2, 6, 2],
@@ -410,7 +396,22 @@ def poolformer_s12(pretrained=False, progress=True, **kwargs):
@ModelCreator.register_model
def poolformer_s24(pretrained=False, progress=True, **kwargs):
"""
PoolFormer-S24 model, Params: 21M
Constructs the PoolFormer model.
.. note::
PoolFormer-S24 model, Params: 21M. From `"MetaFormer is Actually What You Need for Vision" <https://arxiv.org/abs/2111.11418>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> poolformer_s24 = flowvision.models.poolformer_s24(pretrained=False, progress=True)
"""
model_kwargs = dict(
layers = [4, 4, 12, 4],
@@ -426,7 +427,22 @@ def poolformer_s24(pretrained=False, progress=True, **kwargs):
@ModelCreator.register_model
def poolformer_s36(pretrained=False, progress=True, **kwargs):
"""
PoolFormer-S36 model, Params: 31M
Constructs the PoolFormer model.
.. note::
PoolFormer-S36 model, Params: 31M. From `"MetaFormer is Actually What You Need for Vision" <https://arxiv.org/abs/2111.11418>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> poolformer_s36 = flowvision.models.poolformer_s36(pretrained=False, progress=True)
"""
model_kwargs = dict(
layers = [6, 6, 18, 6],
@@ -441,7 +457,22 @@ def poolformer_s36(pretrained=False, progress=True, **kwargs):
@ModelCreator.register_model
def poolformer_m36(pretrained=False, progress=True, **kwargs):
"""
PoolFormer-M36 model, Params: 56M
Constructs the PoolFormer model.
.. note::
PoolFormer-M36 model, Params: 56M. From `"MetaFormer is Actually What You Need for Vision" <https://arxiv.org/abs/2111.11418>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> poolformer_m36 = flowvision.models.poolformer_m36(pretrained=False, progress=True)
"""
model_kwargs = dict(
layers = [6, 6, 18, 6],
@@ -457,7 +488,22 @@ def poolformer_m36(pretrained=False, progress=True, **kwargs):
@ModelCreator.register_model
def poolformer_m48(pretrained=False, progress=True, **kwargs):
"""
PoolFormer-M48 model, Params: 73M
Constructs the PoolFormer model.
.. note::
PoolFormer-M48 model, Params: 73M. From `"MetaFormer is Actually What You Need for Vision" <https://arxiv.org/abs/2111.11418>`_.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> poolformer_m48 = flowvision.models.poolformer_m48(pretrained=False, progress=True)
"""
model_kwargs = dict(
layers = [8, 8, 24, 8],
@@ -468,96 +514,3 @@ def poolformer_m48(pretrained=False, progress=True, **kwargs):
return _create_poolformer(
"poolformer_m48", pretrained=pretrained, progress=progress, **model_kwargs
)
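Editor's note: the mmseg/mmdet backbone wrappers removed below set fork_feat=True internally. With this commit they are gone, but the same multi-stage features can still be obtained by instantiating PoolFormer directly. A hedged sketch under the S12 configuration (not part of the commit; the shapes in the comment assume a 224x224 input and the embed_dims shown):

import oneflow as flow

backbone = PoolFormer(
    layers=[2, 2, 6, 2],                    # PoolFormer-S12 depths
    embed_dims=[64, 128, 320, 512],
    mlp_ratios=[4, 4, 4, 4],
    downsamples=[True, True, True, True],
    fork_feat=True,                          # return the 4 stage feature maps instead of logits
)
feats = backbone(flow.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)  # roughly [1, 64, 56, 56], [1, 128, 28, 28], [1, 320, 14, 14], [1, 512, 7, 7]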

if has_mmseg and has_mmdet:
"""
The following models are for dense prediction based on
mmdetection and mmsegmentation
"""
@seg_BACKBONES.register_module()
@det_BACKBONES.register_module()
class poolformer_s12_feat(PoolFormer):
"""
PoolFormer-S12 model, Params: 12M
"""
def __init__(self, **kwargs):
layers = [2, 2, 6, 2]
embed_dims = [64, 128, 320, 512]
mlp_ratios = [4, 4, 4, 4]
downsamples = [True, True, True, True]
super().__init__(
layers, embed_dims=embed_dims,
mlp_ratios=mlp_ratios, downsamples=downsamples,
fork_feat=True,
**kwargs)

@seg_BACKBONES.register_module()
@det_BACKBONES.register_module()
class poolformer_s24_feat(PoolFormer):
"""
PoolFormer-S24 model, Params: 21M
"""
def __init__(self, **kwargs):
layers = [4, 4, 12, 4]
embed_dims = [64, 128, 320, 512]
mlp_ratios = [4, 4, 4, 4]
downsamples = [True, True, True, True]
super().__init__(
layers, embed_dims=embed_dims,
mlp_ratios=mlp_ratios, downsamples=downsamples,
fork_feat=True,
**kwargs)

@seg_BACKBONES.register_module()
@det_BACKBONES.register_module()
class poolformer_s36_feat(PoolFormer):
"""
PoolFormer-S36 model, Params: 31M
"""
def __init__(self, **kwargs):
layers = [6, 6, 18, 6]
embed_dims = [64, 128, 320, 512]
mlp_ratios = [4, 4, 4, 4]
downsamples = [True, True, True, True]
super().__init__(
layers, embed_dims=embed_dims,
mlp_ratios=mlp_ratios, downsamples=downsamples,
layer_scale_init_value=1e-6,
fork_feat=True,
**kwargs)

@seg_BACKBONES.register_module()
@det_BACKBONES.register_module()
class poolformer_m36_feat(PoolFormer):
"""
PoolFormer-S36 model, Params: 56M
"""
def __init__(self, **kwargs):
layers = [6, 6, 18, 6]
embed_dims = [96, 192, 384, 768]
mlp_ratios = [4, 4, 4, 4]
downsamples = [True, True, True, True]
super().__init__(
layers, embed_dims=embed_dims,
mlp_ratios=mlp_ratios, downsamples=downsamples,
layer_scale_init_value=1e-6,
fork_feat=True,
**kwargs)

@seg_BACKBONES.register_module()
@det_BACKBONES.register_module()
class poolformer_m48_feat(PoolFormer):
"""
PoolFormer-M48 model, Params: 73M
"""
def __init__(self, **kwargs):
layers = [8, 8, 24, 8]
embed_dims = [96, 192, 384, 768]
mlp_ratios = [4, 4, 4, 4]
downsamples = [True, True, True, True]
super().__init__(
layers, embed_dims=embed_dims,
mlp_ratios=mlp_ratios, downsamples=downsamples,
layer_scale_init_value=1e-6,
fork_feat=True,
**kwargs)
