diff --git a/src/model_constructor/model_constructor.py b/src/model_constructor/model_constructor.py
index 52064af..eee2ff5 100644
--- a/src/model_constructor/model_constructor.py
+++ b/src/model_constructor/model_constructor.py
@@ -310,10 +310,14 @@ def print_cfg(self):
         )
 
 
-xresnet34 = ModelConstructor.from_cfg(
-    CfgMC(name="xresnet34", expansion=1, layers=[3, 4, 6, 3])
-)
+@dataclass
+class XResNet34(ModelConstructor):
+    name: str = "xresnet34"
+    layers: list[int] = field(default_factory=lambda: [3, 4, 6, 3])
+
 
-xresnet50 = ModelConstructor.from_cfg(
-    CfgMC(name="xresnet34", expansion=4, layers=[3, 4, 6, 3])
-)
+@dataclass
+class XResNet50(ModelConstructor):
+    name: str = "xresnet50"
+    expansion: int = 4
+    layers: list[int] = field(default_factory=lambda: [3, 4, 6, 3])
diff --git a/src/model_constructor/yaresnet.py b/src/model_constructor/yaresnet.py
index c33cca5..4ce026a 100644
--- a/src/model_constructor/yaresnet.py
+++ b/src/model_constructor/yaresnet.py
@@ -1,15 +1,20 @@
 # YaResBlock - former NewResBlock.
 # Yet another ResNet.
-import torch.nn as nn
-from functools import partial
 from collections import OrderedDict
-from .layers import ConvBnAct
-from .net import Net
+from typing import Union
+
+import torch.nn as nn
 from torch.nn import Mish
 
+from .layers import ConvBnAct
+from .model_constructor import CfgMC, ModelConstructor
 
 
-__all__ = ['YaResBlock', 'yaresnet_parameters', 'yaresnet34', 'yaresnet50']
+__all__ = [
+    'YaResBlock',
+    'yaresnet34',
+    'yaresnet50',
+]
 
 
 act_fn = nn.ReLU(inplace=True)
@@ -18,16 +23,29 @@ class YaResBlock(nn.Module):
     '''YaResBlock. Reduce by pool instead of stride 2'''
 
-    def __init__(self, expansion, in_channels, mid_channels, stride=1,
-                 conv_layer=ConvBnAct, act_fn=act_fn, zero_bn=True, bn_1st=True,
-                 groups=1, dw=False, div_groups=None,
-                 pool=None,
-                 se=None, sa=None,
-                 ):
+    def __init__(
+        self,
+        expansion: int,
+        in_channels: int,
+        mid_channels: int,
+        stride: int = 1,
+        conv_layer=ConvBnAct,
+        act_fn: nn.Module = act_fn,
+        zero_bn: bool = True,
+        bn_1st: bool = True,
+        groups: int = 1,
+        dw: bool = False,
+        div_groups: Union[None, int] = None,
+        pool: Union[nn.Module, None] = None,
+        se: Union[nn.Module, None] = None,
+        sa: Union[nn.Module, None] = None,
+    ):
         super().__init__()
+        # pool defined at ModelConstructor.
         out_channels, in_channels = mid_channels * expansion, in_channels * expansion
         if div_groups is not None:  # check if groups != 1 and div_groups
             groups = int(mid_channels / div_groups)
+
         if stride != 1:
             if pool is None:
                 self.reduce = conv_layer(in_channels, in_channels, 1, stride=2)
@@ -36,23 +54,69 @@ def __init__(self, expansion, in_channels, mid_channels, stride=1,
                 self.reduce = pool
         else:
             self.reduce = None
-        layers = [("conv_0", conv_layer(in_channels, mid_channels, 3, stride=1,
-                                        act_fn=act_fn, bn_1st=bn_1st, groups=in_channels if dw else groups)),
-                  ("conv_1", conv_layer(mid_channels, out_channels, 3, zero_bn=zero_bn,
-                                        act_fn=False, bn_1st=bn_1st, groups=mid_channels if dw else groups))
-                  ] if expansion == 1 else [
-                  ("conv_0", conv_layer(in_channels, mid_channels, 1, act_fn=act_fn, bn_1st=bn_1st)),
-                  ("conv_1", conv_layer(mid_channels, mid_channels, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st,
-                                        groups=mid_channels if dw else groups)),
-                  ("conv_2", conv_layer(
-                      mid_channels, out_channels, 1, zero_bn=zero_bn, act_fn=False, bn_1st=bn_1st))
-                  ]
+        if expansion == 1:
+            layers = [
+                ("conv_0", conv_layer(
+                    in_channels,
+                    mid_channels,
+                    3,
+                    stride=1,
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
+                    groups=in_channels if dw else groups,
+                ),),
+                ("conv_1", conv_layer(
+                    mid_channels,
+                    out_channels,
+                    3,
+                    zero_bn=zero_bn,
+                    act_fn=False,
+                    bn_1st=bn_1st,
+                    groups=mid_channels if dw else groups,
+                ),),
+            ]
+        else:
+            layers = [
+                ("conv_0", conv_layer(
+                    in_channels,
+                    mid_channels,
+                    1,
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
+                ),),
+                ("conv_1", conv_layer(
+                    mid_channels,
+                    mid_channels,
+                    3,
+                    stride=1,
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
+                    groups=mid_channels if dw else groups,
+                ),),
+                ("conv_2", conv_layer(
+                    mid_channels,
+                    out_channels,
+                    1,
+                    zero_bn=zero_bn,
+                    act_fn=False,
+                    bn_1st=bn_1st,
+                ),),  # noqa E501
+            ]
         if se:
-            layers.append(('se', se(out_channels)))
+            layers.append(("se", se(out_channels)))
         if sa:
-            layers.append(('sa', sa(out_channels)))
+            layers.append(("sa", sa(out_channels)))
         self.convs = nn.Sequential(OrderedDict(layers))
-        self.id_conv = None if in_channels == out_channels else conv_layer(in_channels, out_channels, 1, act_fn=False)
+        if in_channels != out_channels:
+            self.id_conv = conv_layer(
+                in_channels,
+                out_channels,
+                1,
+                stride=1,
+                act_fn=False,
+            )
+        else:
+            self.id_conv = None
         self.merge = act_fn
 
     def forward(self, x):
@@ -62,6 +126,21 @@ def forward(self, x):
         return self.merge(self.convs(x) + identity)
 
 
-yaresnet_parameters = {'block': YaResBlock, 'stem_sizes': [3, 32, 64, 64], 'act_fn': Mish(), 'stem_stride_on': 1}
-yaresnet34 = partial(Net, name='YaResnet34', expansion=1, layers=[3, 4, 6, 3], **yaresnet_parameters)
-yaresnet50 = partial(Net, name='YaResnet50', expansion=4, layers=[3, 4, 6, 3], **yaresnet_parameters)
+yaresnet34 = ModelConstructor.from_cfg(
+    CfgMC(
+        name='YaResnet34',
+        block=YaResBlock,
+        expansion=1,
+        layers=[3, 4, 6, 3],
+        act_fn=Mish(),
+    )
+)
+yaresnet50 = ModelConstructor.from_cfg(
+    CfgMC(
+        name='YaResnet50',
+        block=YaResBlock,
+        act_fn=Mish(),
+        expansion=4,
+        layers=[3, 4, 6, 3],
+    )
+)
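
Usage sketch, not part of the patch: module paths are inferred from the src/ layout, and it is assumed that a `ModelConstructor` instance (the new `XResNet34`/`XResNet50` dataclasses as well as the objects returned by `ModelConstructor.from_cfg`) builds the `nn.Module` when it is called.

    from model_constructor.model_constructor import XResNet34
    from model_constructor.yaresnet import yaresnet34

    mc = XResNet34()         # dataclass constructor carrying the xresnet34 defaults
    model = mc()             # assumption: calling the constructor instance builds the nn.Module
    ya_model = yaresnet34()  # per the patch, yaresnet34 is the object returned by ModelConstructor.from_cfg

One practical difference between the two styles: the dataclass subclasses let callers override any field at construction time (e.g. `XResNet34(layers=[2, 2, 2, 2])`, assuming `ModelConstructor` exposes these fields), while the `from_cfg` form fixes the settings inside a `CfgMC` value.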