From e444c683cfddb5567813f6f6286e7c4e6388ed31 Mon Sep 17 00:00:00 2001 From: ayasyrev Date: Sat, 22 Oct 2022 20:09:56 +0300 Subject: [PATCH 1/6] layers typing --- model_constructor/layers.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/model_constructor/layers.py b/model_constructor/layers.py index 0a0f885..ddbb4fe 100644 --- a/model_constructor/layers.py +++ b/model_constructor/layers.py @@ -1,6 +1,7 @@ +from typing import List, Optional import torch.nn as nn import torch -from torch.nn.utils import spectral_norm +from torch.nn.utils.spectral_norm import spectral_norm from collections import OrderedDict @@ -39,16 +40,28 @@ class ConvBnAct(nn.Sequential): convolution_module = nn.Conv2d # can be changed in models like twist. batchnorm_module = nn.BatchNorm2d - def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, - padding=None, bias=False, groups=1, - act_fn=act_fn, pre_act=False, - bn_layer=True, bn_1st=True, zero_bn=False, - ): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: Optional[int] = None, + bias: bool = False, + groups: int = 1, + act_fn: Optional[nn.Module] = act_fn, + pre_act: bool = False, + bn_layer: bool = True, + bn_1st: bool = True, + zero_bn: bool = False, + ): if padding is None: padding = kernel_size // 2 - layers = [('conv', self.convolution_module(in_channels, out_channels, kernel_size, stride=stride, - padding=padding, bias=bias, groups=groups))] # if no bn - bias True? + layers: List[tuple[str, nn.Module]] = [ + ('conv', self.convolution_module( + in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias, groups=groups)) + ] # if no bn - bias True? if bn_layer: bn = self.batchnorm_module(out_channels) nn.init.constant_(bn.weight, 0. if zero_bn else 1.) @@ -133,7 +146,7 @@ def forward(self, x): return o.view(*size).contiguous() -class SEBlock(nn.Module): # todo: deprecation worning. +class SEBlock(nn.Module): # todo: deprecation warning. "se block" se_layer = nn.Linear act_fn = nn.ReLU(inplace=True) @@ -157,7 +170,7 @@ def forward(self, x): return x * y.expand_as(x) -class SEBlockConv(nn.Module): # todo: deprecation worning. +class SEBlockConv(nn.Module): # todo: deprecation warning. 
"se block with conv on excitation" se_layer = nn.Conv2d act_fn = nn.ReLU(inplace=True) From d6918726790263471c9b818070ac7ed32c2d70cb Mon Sep 17 00:00:00 2001 From: ayasyrev Date: Sat, 22 Oct 2022 23:44:11 +0300 Subject: [PATCH 2/6] typing --- model_constructor/layers.py | 10 ++-- model_constructor/model_constructor.py | 83 ++++++++++++++++---------- 2 files changed, 57 insertions(+), 36 deletions(-) diff --git a/model_constructor/layers.py b/model_constructor/layers.py index ddbb4fe..2eaa7d9 100644 --- a/model_constructor/layers.py +++ b/model_constructor/layers.py @@ -1,9 +1,9 @@ -from typing import List, Optional -import torch.nn as nn -import torch -from torch.nn.utils.spectral_norm import spectral_norm from collections import OrderedDict +from typing import List, Optional, Union +import torch +import torch.nn as nn +from torch.nn.utils.spectral_norm import spectral_norm __all__ = ['Flatten', 'noop', 'Noop', 'ConvLayer', 'act_fn', 'conv1d', 'SimpleSelfAttention', 'SEBlock', 'SEBlockConv'] @@ -49,7 +49,7 @@ def __init__( padding: Optional[int] = None, bias: bool = False, groups: int = 1, - act_fn: Optional[nn.Module] = act_fn, + act_fn: Union[nn.Module, bool] = act_fn, pre_act: bool = False, bn_layer: bool = True, bn_1st: bool = True, diff --git a/model_constructor/model_constructor.py b/model_constructor/model_constructor.py index 29e7705..70d264f 100644 --- a/model_constructor/model_constructor.py +++ b/model_constructor/model_constructor.py @@ -1,6 +1,6 @@ from collections import OrderedDict from functools import partial -from typing import Callable, Union +from typing import Callable, List, Sequence, Union import torch.nn as nn @@ -24,14 +24,25 @@ def init_cnn(module: nn.Module): class ResBlock(nn.Module): - '''Resnet block''' - - def __init__(self, expansion, in_channels, mid_channels, stride=1, - conv_layer=ConvBnAct, act_fn=act_fn, zero_bn=True, bn_1st=True, - groups=1, dw=False, div_groups=None, - pool=None, - se=None, sa=None - ): + '''Universal Resnet block. Basic block if expansion is 1, otherwise is Bottleneck.''' + + def __init__( + self, + expansion: int, + in_channels: int, + mid_channels: int, + stride: int = 1, + conv_layer: Union[nn.Module, nn.Sequential] = ConvBnAct, + act_fn: nn.Module = act_fn, + zero_bn: bool = True, + bn_1st: bool = True, + groups: int = 1, + dw: bool = False, + div_groups: Union[None, int] = None, + pool: Union[nn.Module, None] = None, + se: Union[nn.Module, None] = None, + sa: Union[nn.Module, None] = None, + ): super().__init__() # pool defined at ModelConstructor. out_channels, in_channels = mid_channels * expansion, in_channels * expansion @@ -124,28 +135,38 @@ def _make_head(self): class ModelConstructor(): """Model constructor. 
As default - xresnet18""" - def __init__(self, name='MC', in_chans=3, num_classes=1000, - block=ResBlock, conv_layer=ConvBnAct, - block_sizes=[64, 128, 256, 512], layers=[2, 2, 2, 2], - norm=nn.BatchNorm2d, - act_fn=nn.ReLU(inplace=True), - pool=nn.AvgPool2d(2, ceil_mode=True), - expansion=1, groups=1, dw=False, div_groups=None, - sa: Union[bool, int, Callable] = False, - se: Union[bool, int, Callable] = False, - se_module=None, se_reduction=None, - bn_1st=True, - zero_bn=True, - stem_stride_on=0, - stem_sizes=[32, 32, 64], - stem_pool=nn.MaxPool2d(kernel_size=3, stride=2, padding=1), - stem_bn_end=False, - _init_cnn=init_cnn, - _make_stem=_make_stem, - _make_layer=_make_layer, - _make_body=_make_body, - _make_head=_make_head, - ): + def __init__( + self, + name: str = 'MC', + in_chans: int = 3, + num_classes: int = 1000, + block=ResBlock, + conv_layer=ConvBnAct, + block_sizes: List[int] = [64, 128, 256, 512], + layers: List[int] = [2, 2, 2, 2], + norm: nn.Module = nn.BatchNorm2d, + act_fn: nn.Module = nn.ReLU(inplace=True), + pool: nn.Module = nn.AvgPool2d(2, ceil_mode=True), + expansion: int = 1, + groups: int = 1, + dw: bool = False, + div_groups=None, + sa: Union[bool, int, Callable] = False, + se: Union[bool, int, Callable] = False, + se_module=None, + se_reduction=None, + bn_1st=True, + zero_bn=True, + stem_stride_on=0, + stem_sizes=[32, 32, 64], + stem_pool=nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + stem_bn_end=False, + _init_cnn=init_cnn, + _make_stem=_make_stem, + _make_layer=_make_layer, + _make_body=_make_body, + _make_head=_make_head, + ): super().__init__() # se can be bool, int (0, 1) or nn.Module # se_module - deprecated. Leaved for warning and checks. From 8d975953fd2ec6e814f2f862ff6914884cf9a323 Mon Sep 17 00:00:00 2001 From: ayasyrev Date: Mon, 24 Oct 2022 10:21:34 +0300 Subject: [PATCH 3/6] typing MC --- model_constructor/model_constructor.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/model_constructor/model_constructor.py b/model_constructor/model_constructor.py index 70d264f..375521f 100644 --- a/model_constructor/model_constructor.py +++ b/model_constructor/model_constructor.py @@ -150,22 +150,22 @@ def __init__( expansion: int = 1, groups: int = 1, dw: bool = False, - div_groups=None, + div_groups: Union[int, None]=None, sa: Union[bool, int, Callable] = False, se: Union[bool, int, Callable] = False, se_module=None, se_reduction=None, - bn_1st=True, - zero_bn=True, - stem_stride_on=0, - stem_sizes=[32, 32, 64], - stem_pool=nn.MaxPool2d(kernel_size=3, stride=2, padding=1), - stem_bn_end=False, - _init_cnn=init_cnn, - _make_stem=_make_stem, - _make_layer=_make_layer, - _make_body=_make_body, - _make_head=_make_head, + bn_1st: bool = True, + zero_bn: bool = True, + stem_stride_on: int = 0, + stem_sizes: List[int] = [32, 32, 64], + stem_pool: Union[nn.Module, None] =nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + stem_bn_end: bool = False, + _init_cnn: Callable = init_cnn, + _make_stem: Callable = _make_stem, + _make_layer: Callable = _make_layer, + _make_body: Callable = _make_body, + _make_head: Callable = _make_head, ): super().__init__() # se can be bool, int (0, 1) or nn.Module From 8a8eb35436e97ae60025814f94206b8cbd34fbf1 Mon Sep 17 00:00:00 2001 From: ayasyrev Date: Mon, 24 Oct 2022 12:45:56 +0300 Subject: [PATCH 4/6] mc init --- model_constructor/layers.py | 6 +-- model_constructor/model_constructor.py | 63 +++++++++++++++++++------- 2 files changed, 50 insertions(+), 19 deletions(-) diff --git 
a/model_constructor/layers.py b/model_constructor/layers.py index 2eaa7d9..88ea54b 100644 --- a/model_constructor/layers.py +++ b/model_constructor/layers.py @@ -66,7 +66,7 @@ def __init__( bn = self.batchnorm_module(out_channels) nn.init.constant_(bn.weight, 0. if zero_bn else 1.) layers.append(('bn', bn)) - if act_fn: + if isinstance(act_fn, nn.Module): # act_fn either nn.Module or False if pre_act: act_position = 0 elif not bn_1st: @@ -111,7 +111,7 @@ def conv1d(ni: int, no: int, ks: int = 1, stride: int = 1, padding: int = 0, bia conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias) nn.init.kaiming_normal_(conv.weight) if bias: - conv.bias.data.zero_() + conv.bias.data.zero_() # type: ignore return spectral_norm(conv) @@ -125,7 +125,7 @@ class SimpleSelfAttention(nn.Module): def __init__(self, n_in: int, ks=1, sym=False, use_bias=False): super().__init__() self.conv = conv1d(n_in, n_in, ks, padding=ks // 2, bias=use_bias) - self.gamma = nn.Parameter(torch.tensor([0.])) + self.gamma = torch.nn.Parameter(torch.tensor([0.])) # type: ignore self.sym = sym self.n_in = n_in diff --git a/model_constructor/model_constructor.py b/model_constructor/model_constructor.py index 375521f..8d76e55 100644 --- a/model_constructor/model_constructor.py +++ b/model_constructor/model_constructor.py @@ -1,6 +1,6 @@ from collections import OrderedDict from functools import partial -from typing import Callable, List, Sequence, Union +from typing import Callable, List, Type, Union import torch.nn as nn @@ -16,7 +16,7 @@ def init_cnn(module: nn.Module): "Init module - kaiming_normal for Conv2d and 0 for biases." if getattr(module, 'bias', None) is not None: - nn.init.constant_(module.bias, 0) + nn.init.constant_(module.bias, 0) # type: ignore if isinstance(module, (nn.Conv2d, nn.Linear)): nn.init.kaiming_normal_(module.weight) for layer in module.children(): @@ -32,7 +32,7 @@ def __init__( in_channels: int, mid_channels: int, stride: int = 1, - conv_layer: Union[nn.Module, nn.Sequential] = ConvBnAct, + conv_layer=ConvBnAct, act_fn: nn.Module = act_fn, zero_bn: bool = True, bn_1st: bool = True, @@ -49,7 +49,7 @@ def __init__( if div_groups is not None: # check if groups != 1 and div_groups groups = int(mid_channels / div_groups) if expansion == 1: - layers = [("conv_0", conv_layer(in_channels, mid_channels, 3, stride=stride, + layers = [("conv_0", conv_layer(in_channels, mid_channels, 3, stride=stride, # type: ignore act_fn=act_fn, bn_1st=bn_1st, groups=in_channels if dw else groups)), ("conv_1", conv_layer(mid_channels, out_channels, 3, zero_bn=zero_bn, act_fn=False, bn_1st=bn_1st, groups=mid_channels if dw else groups)) @@ -99,7 +99,8 @@ def _make_stem(self): def _make_layer(self, layer_num: int) -> nn.Module: # expansion, in_channels, out_channels, blocks, stride, sa): - stride = 1 if self.stem_pool and layer_num == 0 else 2 # if no pool on stem - stride = 2 for first layer block in body + # if no pool on stem - stride = 2 for first layer block in body + stride = 1 if self.stem_pool and layer_num == 0 else 2 num_blocks = self.layers[layer_num] return nn.Sequential(OrderedDict([ (f"bl_{block_num}", self.block( @@ -144,22 +145,22 @@ def __init__( conv_layer=ConvBnAct, block_sizes: List[int] = [64, 128, 256, 512], layers: List[int] = [2, 2, 2, 2], - norm: nn.Module = nn.BatchNorm2d, + norm: Type[nn.Module] = nn.BatchNorm2d, act_fn: nn.Module = nn.ReLU(inplace=True), pool: nn.Module = nn.AvgPool2d(2, ceil_mode=True), expansion: int = 1, groups: int = 1, dw: bool = False, - div_groups: Union[int, 
None]=None, - sa: Union[bool, int, Callable] = False, - se: Union[bool, int, Callable] = False, + div_groups: Union[int, None] = None, + sa: Union[bool, int, Type[nn.Module]] = False, + se: Union[bool, int, Type[nn.Module]] = False, se_module=None, se_reduction=None, bn_1st: bool = True, zero_bn: bool = True, stem_stride_on: int = 0, stem_sizes: List[int] = [32, 32, 64], - stem_pool: Union[nn.Module, None] =nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + stem_pool: Union[Type[nn.Module], None] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1), # type: ignore stem_bn_end: bool = False, _init_cnn: Callable = init_cnn, _make_stem: Callable = _make_stem, @@ -172,24 +173,54 @@ def __init__( # se_module - deprecated. Leaved for warning and checks. # if stem_pool is False - no pool at stem - params = locals() - del params['self'] - self.__dict__ = params - - self._block_sizes = params['block_sizes'] + self.name = name + self.in_chans = in_chans + self.num_classes = num_classes + self.block = block + self.conv_layer = conv_layer + self._block_sizes = block_sizes + self.layers = layers + self.norm = norm + self.act_fn = act_fn + self.pool = pool + self.expansion = expansion + self.groups = groups + self.dw = dw + self.div_groups = div_groups + # se_module + # se_reduction + self.bn_1st = bn_1st + self.zero_bn = zero_bn + self.stem_stride_on = stem_stride_on + self.stem_pool = stem_pool + self.stem_bn_end = stem_bn_end + self._init_cnn = _init_cnn + self._make_stem = _make_stem + self._make_layer = _make_layer + self._make_body = _make_body + self._make_head = _make_head + + # params = locals() + # del params['self'] + # self.__dict__ = params + + # self._block_sizes = params['block_sizes'] + self.stem_sizes = stem_sizes if self.stem_sizes[0] != self.in_chans: self.stem_sizes = [self.in_chans] + self.stem_sizes + self.se = se if self.se: if type(self.se) in (bool, int): # if se=1 or se=True self.se = SEModule else: self.se = se # TODO add check issubclass or isinstance of nn.Module + self.sa = sa if self.sa: # if sa=1 or sa=True if type(self.sa) in (bool, int): self.sa = SimpleSelfAttention # default: ks=1, sym=sym else: self.sa = sa - if self.se_module or se_reduction: # pragma: no cover + if se_module or se_reduction: # pragma: no cover print("Deprecated. Pass se_module as se argument, se_reduction as arg to se.") # add deprecation warning. @property From 76ae20dcdff2f47e32b761aecc5b69cb447bb826 Mon Sep 17 00:00:00 2001 From: ayasyrev Date: Mon, 24 Oct 2022 13:11:06 +0300 Subject: [PATCH 5/6] black mc --- model_constructor/model_constructor.py | 215 +++++++++++++++++-------- 1 file changed, 150 insertions(+), 65 deletions(-) diff --git a/model_constructor/model_constructor.py b/model_constructor/model_constructor.py index 8d76e55..09922e5 100644 --- a/model_constructor/model_constructor.py +++ b/model_constructor/model_constructor.py @@ -7,7 +7,14 @@ from .layers import ConvBnAct, SEModule, SimpleSelfAttention -__all__ = ['init_cnn', 'act_fn', 'ResBlock', 'ModelConstructor', 'xresnet34', 'xresnet50'] +__all__ = [ + "init_cnn", + "act_fn", + "ResBlock", + "ModelConstructor", + "xresnet34", + "xresnet50", +] act_fn = nn.ReLU(inplace=True) @@ -15,7 +22,7 @@ def init_cnn(module: nn.Module): "Init module - kaiming_normal for Conv2d and 0 for biases." 
- if getattr(module, 'bias', None) is not None: + if getattr(module, "bias", None) is not None: nn.init.constant_(module.bias, 0) # type: ignore if isinstance(module, (nn.Conv2d, nn.Linear)): nn.init.kaiming_normal_(module.weight) @@ -24,7 +31,7 @@ def init_cnn(module: nn.Module): class ResBlock(nn.Module): - '''Universal Resnet block. Basic block if expansion is 1, otherwise is Bottleneck.''' + """Universal Resnet block. Basic block if expansion is 1, otherwise is Bottleneck.""" def __init__( self, @@ -49,21 +56,57 @@ def __init__( if div_groups is not None: # check if groups != 1 and div_groups groups = int(mid_channels / div_groups) if expansion == 1: - layers = [("conv_0", conv_layer(in_channels, mid_channels, 3, stride=stride, # type: ignore - act_fn=act_fn, bn_1st=bn_1st, groups=in_channels if dw else groups)), - ("conv_1", conv_layer(mid_channels, out_channels, 3, zero_bn=zero_bn, - act_fn=False, bn_1st=bn_1st, groups=mid_channels if dw else groups)) - ] + layers = [ + ("conv_0", conv_layer( + in_channels, + mid_channels, + 3, + stride=stride, # type: ignore + act_fn=act_fn, + bn_1st=bn_1st, + groups=in_channels if dw else groups, + ),), + ("conv_1", conv_layer( + mid_channels, + out_channels, + 3, + zero_bn=zero_bn, + act_fn=False, + bn_1st=bn_1st, + groups=mid_channels if dw else groups, + ),), + ] else: - layers = [("conv_0", conv_layer(in_channels, mid_channels, 1, act_fn=act_fn, bn_1st=bn_1st)), - ("conv_1", conv_layer(mid_channels, mid_channels, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st, - groups=mid_channels if dw else groups)), - ("conv_2", conv_layer(mid_channels, out_channels, 1, zero_bn=zero_bn, act_fn=False, bn_1st=bn_1st)) # noqa E501 - ] + layers = [ + ("conv_0", conv_layer( + in_channels, + mid_channels, + 1, + act_fn=act_fn, + bn_1st=bn_1st, + ),), + ("conv_1", conv_layer( + mid_channels, + mid_channels, + 3, + stride=stride, + act_fn=act_fn, + bn_1st=bn_1st, + groups=mid_channels if dw else groups, + ),), + ("conv_2", conv_layer( + mid_channels, + out_channels, + 1, + zero_bn=zero_bn, + act_fn=False, + bn_1st=bn_1st, + ),), # noqa E501 + ] if se: - layers.append(('se', se(out_channels))) + layers.append(("se", se(out_channels))) if sa: - layers.append(('sa', sa(out_channels))) + layers.append(("sa", sa(out_channels))) self.convs = nn.Sequential(OrderedDict(layers)) if stride != 1 or in_channels != out_channels: id_layers = [] @@ -71,9 +114,12 @@ def __init__( id_layers.append(("pool", pool)) if in_channels != out_channels or (stride != 1 and pool is None): id_layers += [("id_conv", conv_layer( - in_channels, out_channels, 1, + in_channels, + out_channels, + 1, stride=1 if pool else stride, - act_fn=False))] + act_fn=False, + ),)] self.id_conv = nn.Sequential(OrderedDict(id_layers)) else: self.id_conv = None @@ -85,15 +131,23 @@ def forward(self, x): def _make_stem(self): - stem = [(f"conv_{i}", self.conv_layer(self.stem_sizes[i], self.stem_sizes[i + 1], - stride=2 if i == self.stem_stride_on else 1, - bn_layer=(not self.stem_bn_end) if i == (len(self.stem_sizes) - 2) else True, - act_fn=self.act_fn, bn_1st=self.bn_1st)) - for i in range(len(self.stem_sizes) - 1)] + stem = [ + (f"conv_{i}", self.conv_layer( + self.stem_sizes[i], + self.stem_sizes[i + 1], + stride=2 if i == self.stem_stride_on else 1, + bn_layer=(not self.stem_bn_end) + if i == (len(self.stem_sizes) - 2) + else True, + act_fn=self.act_fn, + bn_1st=self.bn_1st, + ),) + for i in range(len(self.stem_sizes) - 1) + ] if self.stem_pool: - stem.append(('stem_pool', self.stem_pool)) + 
stem.append(("stem_pool", self.stem_pool)) if self.stem_bn_end: - stem.append(('norm', self.norm(self.stem_sizes[-1]))) + stem.append(("norm", self.norm(self.stem_sizes[-1]))) return nn.Sequential(OrderedDict(stem)) @@ -102,43 +156,67 @@ def _make_layer(self, layer_num: int) -> nn.Module: # if no pool on stem - stride = 2 for first layer block in body stride = 1 if self.stem_pool and layer_num == 0 else 2 num_blocks = self.layers[layer_num] - return nn.Sequential(OrderedDict([ - (f"bl_{block_num}", self.block( - self.expansion, - self.block_sizes[layer_num] if block_num == 0 else self.block_sizes[layer_num + 1], - self.block_sizes[layer_num + 1], - stride if block_num == 0 else 1, - sa=self.sa if (block_num == num_blocks - 1) and layer_num == 0 else None, - conv_layer=self.conv_layer, - act_fn=self.act_fn, - pool=self.pool, - zero_bn=self.zero_bn, bn_1st=self.bn_1st, - groups=self.groups, div_groups=self.div_groups, - dw=self.dw, se=self.se - )) - for block_num in range(num_blocks) - ])) + return nn.Sequential( + OrderedDict( + [ + ( + f"bl_{block_num}", + self.block( + self.expansion, + self.block_sizes[layer_num] + if block_num == 0 + else self.block_sizes[layer_num + 1], + self.block_sizes[layer_num + 1], + stride if block_num == 0 else 1, + sa=self.sa + if (block_num == num_blocks - 1) and layer_num == 0 + else None, + conv_layer=self.conv_layer, + act_fn=self.act_fn, + pool=self.pool, + zero_bn=self.zero_bn, + bn_1st=self.bn_1st, + groups=self.groups, + div_groups=self.div_groups, + dw=self.dw, + se=self.se, + ), + ) + for block_num in range(num_blocks) + ] + ) + ) def _make_body(self): - return nn.Sequential(OrderedDict([ - (f"l_{layer_num}", self._make_layer(self, layer_num)) - for layer_num in range(len(self.layers)) - ])) + return nn.Sequential( + OrderedDict( + [ + ( + f"l_{layer_num}", + self._make_layer(self, layer_num) + ) + for layer_num in range(len(self.layers)) + ] + ) + ) def _make_head(self): - head = [('pool', nn.AdaptiveAvgPool2d(1)), - ('flat', nn.Flatten()), - ('fc', nn.Linear(self.block_sizes[-1] * self.expansion, self.num_classes))] + head = [ + ("pool", nn.AdaptiveAvgPool2d(1)), + ("flat", nn.Flatten()), + ("fc", nn.Linear(self.block_sizes[-1] * self.expansion, self.num_classes)), + ] return nn.Sequential(OrderedDict(head)) -class ModelConstructor(): +class ModelConstructor: """Model constructor. As default - xresnet18""" + def __init__( self, - name: str = 'MC', + name: str = "MC", in_chans: int = 3, num_classes: int = 1000, block=ResBlock, @@ -221,7 +299,9 @@ def __init__( else: self.sa = sa if se_module or se_reduction: # pragma: no cover - print("Deprecated. Pass se_module as se argument, se_reduction as arg to se.") # add deprecation warning. + print( + "Deprecated. Pass se_module as se argument, se_reduction as arg to se." + ) # add deprecation warning. 
@property def block_sizes(self): @@ -240,23 +320,28 @@ def body(self): return self._make_body(self) def __call__(self): - model = nn.Sequential(OrderedDict([ - ('stem', self.stem), - ('body', self.body), - ('head', self.head)])) + model = nn.Sequential( + OrderedDict([("stem", self.stem), ("body", self.body), ("head", self.head)]) + ) self._init_cnn(model) model.extra_repr = lambda: f"{self.name}" return model def __repr__(self): - return (f"{self.name} constructor\n" - f" in_chans: {self.in_chans}, num_classes: {self.num_classes}\n" - f" expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}, div_groups: {self.div_groups}\n" - f" sa: {self.sa}, se: {self.se}\n" - f" stem sizes: {self.stem_sizes}, stride on {self.stem_stride_on}\n" - f" body sizes {self._block_sizes}\n" - f" layers: {self.layers}") - - -xresnet34 = partial(ModelConstructor, name='xresnet34', expansion=1, layers=[3, 4, 6, 3]) -xresnet50 = partial(ModelConstructor, name='xresnet34', expansion=4, layers=[3, 4, 6, 3]) + return ( + f"{self.name} constructor\n" + f" in_chans: {self.in_chans}, num_classes: {self.num_classes}\n" + f" expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}, div_groups: {self.div_groups}\n" + f" sa: {self.sa}, se: {self.se}\n" + f" stem sizes: {self.stem_sizes}, stride on {self.stem_stride_on}\n" + f" body sizes {self._block_sizes}\n" + f" layers: {self.layers}" + ) + + +xresnet34 = partial( + ModelConstructor, name="xresnet34", expansion=1, layers=[3, 4, 6, 3] +) +xresnet50 = partial( + ModelConstructor, name="xresnet34", expansion=4, layers=[3, 4, 6, 3] +) From d8af738d98787b7699870a968bc6316287ab278c Mon Sep 17 00:00:00 2001 From: ayasyrev Date: Mon, 24 Oct 2022 13:13:07 +0300 Subject: [PATCH 6/6] black layers --- model_constructor/layers.py | 210 +++++++++++++++++++++++------------- 1 file changed, 136 insertions(+), 74 deletions(-) diff --git a/model_constructor/layers.py b/model_constructor/layers.py index 88ea54b..314f782 100644 --- a/model_constructor/layers.py +++ b/model_constructor/layers.py @@ -5,12 +5,22 @@ import torch.nn as nn from torch.nn.utils.spectral_norm import spectral_norm -__all__ = ['Flatten', 'noop', 'Noop', 'ConvLayer', 'act_fn', - 'conv1d', 'SimpleSelfAttention', 'SEBlock', 'SEBlockConv'] +__all__ = [ + "Flatten", + "noop", + "Noop", + "ConvLayer", + "act_fn", + "conv1d", + "SimpleSelfAttention", + "SEBlock", + "SEBlockConv", +] class Flatten(nn.Module): - '''flat x to vector''' + """flat x to vector""" + def __init__(self): super().__init__() @@ -19,12 +29,13 @@ def forward(self, x): def noop(x): - '''Dummy func. Return input''' + """Dummy func. Return input""" return x class Noop(nn.Module): - '''Dummy module''' + """Dummy module""" + def __init__(self): super().__init__() @@ -37,6 +48,7 @@ def forward(self, x): class ConvBnAct(nn.Sequential): """Basic Conv + Bn + Act block""" + convolution_module = nn.Conv2d # can be changed in models like twist. batchnorm_module = nn.BatchNorm2d @@ -59,13 +71,23 @@ def __init__( if padding is None: padding = kernel_size // 2 layers: List[tuple[str, nn.Module]] = [ - ('conv', self.convolution_module( - in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias, groups=groups)) + ( + "conv", + self.convolution_module( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + groups=groups, + ), + ) ] # if no bn - bias True? if bn_layer: bn = self.batchnorm_module(out_channels) - nn.init.constant_(bn.weight, 0. if zero_bn else 1.) 
- layers.append(('bn', bn)) + nn.init.constant_(bn.weight, 0.0 if zero_bn else 1.0) + layers.append(("bn", bn)) if isinstance(act_fn, nn.Module): # act_fn either nn.Module or False if pre_act: act_position = 0 @@ -73,40 +95,62 @@ def __init__( act_position = 1 else: act_position = len(layers) - layers.insert(act_position, ('act_fn', act_fn)) + layers.insert(act_position, ("act_fn", act_fn)) super().__init__(OrderedDict(layers)) # NOTE First version. Leaved for backwards compatibility with old blocks, models. class ConvLayer(nn.Sequential): """Basic conv layers block""" + Conv2d = nn.Conv2d - def __init__(self, ni, nf, ks=3, stride=1, - act=True, act_fn=act_fn, - bn_layer=True, bn_1st=True, zero_bn=False, - padding=None, bias=False, groups=1, **kwargs): + def __init__( + self, + ni, + nf, + ks=3, + stride=1, + act=True, + act_fn=act_fn, + bn_layer=True, + bn_1st=True, + zero_bn=False, + padding=None, + bias=False, + groups=1, + **kwargs + ): if padding is None: padding = ks // 2 - layers = [('conv', self.Conv2d(ni, nf, ks, stride=stride, - padding=padding, bias=bias, groups=groups))] - act_bn = [('act_fn', act_fn)] if act else [] + layers = [ + ( + "conv", + self.Conv2d( + ni, nf, ks, stride=stride, padding=padding, bias=bias, groups=groups + ), + ) + ] + act_bn = [("act_fn", act_fn)] if act else [] if bn_layer: bn = nn.BatchNorm2d(nf) - nn.init.constant_(bn.weight, 0. if zero_bn else 1.) - act_bn += [('bn', bn)] + nn.init.constant_(bn.weight, 0.0 if zero_bn else 1.0) + act_bn += [("bn", bn)] if bn_1st: act_bn.reverse() layers += act_bn super().__init__(OrderedDict(layers)) + # Cell # SA module from mxresnet at fastai. todo - add persons!!! # Unmodified from https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py -def conv1d(ni: int, no: int, ks: int = 1, stride: int = 1, padding: int = 0, bias: bool = False): +def conv1d( + ni: int, no: int, ks: int = 1, stride: int = 1, padding: int = 0, bias: bool = False +): "Create and initialize a `nn.Conv1d` layer with spectral normalization." conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias) nn.init.kaiming_normal_(conv.weight) @@ -116,16 +160,16 @@ def conv1d(ni: int, no: int, ks: int = 1, stride: int = 1, padding: int = 0, bia class SimpleSelfAttention(nn.Module): - '''SimpleSelfAttention module. # noqa W291 - Adapted from SelfAttention layer at - https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py - Inspired by https://arxiv.org/pdf/1805.08318.pdf - ''' + """SimpleSelfAttention module. 
# noqa W291 + Adapted from SelfAttention layer at + https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py + Inspired by https://arxiv.org/pdf/1805.08318.pdf + """ def __init__(self, n_in: int, ks=1, sym=False, use_bias=False): super().__init__() self.conv = conv1d(n_in, n_in, ks, padding=ks // 2, bias=use_bias) - self.gamma = torch.nn.Parameter(torch.tensor([0.])) # type: ignore + self.gamma = torch.nn.Parameter(torch.tensor([0.0])) # type: ignore self.sym = sym self.n_in = n_in @@ -136,12 +180,14 @@ def forward(self, x): c = (c + c.t()) / 2 self.conv.weight = c.view(self.n_in, self.n_in, 1) size = x.size() - x = x.view(*size[:2], -1) # (C,N) + x = x.view(*size[:2], -1) # (C,N) # changed the order of multiplication to avoid O(N^2) complexity # (x*xT)*(W*x) instead of (x*(xT*(W*x))) - convx = self.conv(x) # (C,C) * (C,N) = (C,N) => O(NC^2) - xxT = torch.bmm(x, x.permute(0, 2, 1).contiguous()) # (C,N) * (N,C) = (C,C) => O(NC^2) - o = torch.bmm(xxT, convx) # (C,C) * (C,N) = (C,N) => O(NC^2) + convx = self.conv(x) # (C,C) * (C,N) = (C,N) => O(NC^2) + xxT = torch.bmm( + x, x.permute(0, 2, 1).contiguous() + ) # (C,N) * (N,C) = (C,C) => O(NC^2) + o = torch.bmm(xxT, convx) # (C,C) * (C,N) = (C,N) => O(NC^2) o = self.gamma * o + x return o.view(*size).contiguous() @@ -157,11 +203,15 @@ def __init__(self, c, r=16): ch = max(c // r, 1) self.squeeze = nn.AdaptiveAvgPool2d(1) self.excitation = nn.Sequential( - OrderedDict([('fc_reduce', self.se_layer(c, ch, bias=self.use_bias)), - ('se_act', self.act_fn), - ('fc_expand', self.se_layer(ch, c, bias=self.use_bias)), - ('sigmoid', nn.Sigmoid()) - ])) + OrderedDict( + [ + ("fc_reduce", self.se_layer(c, ch, bias=self.use_bias)), + ("se_act", self.act_fn), + ("fc_expand", self.se_layer(ch, c, bias=self.use_bias)), + ("sigmoid", nn.Sigmoid()), + ] + ) + ) def forward(self, x): bs, c, _, _ = x.shape @@ -178,16 +228,19 @@ class SEBlockConv(nn.Module): # todo: deprecation warning. 
def __init__(self, c, r=16): super().__init__() -# c_in = math.ceil(c//r/8)*8 + # c_in = math.ceil(c//r/8)*8 c_in = max(c // r, 1) self.squeeze = nn.AdaptiveAvgPool2d(1) self.excitation = nn.Sequential( - OrderedDict([ - ('conv_reduce', self.se_layer(c, c_in, 1, bias=self.use_bias)), - ('se_act', self.act_fn), - ('conv_expand', self.se_layer(c_in, c, 1, bias=self.use_bias)), - ('sigmoid', nn.Sigmoid()) - ])) + OrderedDict( + [ + ("conv_reduce", self.se_layer(c, c_in, 1, bias=self.use_bias)), + ("se_act", self.act_fn), + ("conv_expand", self.se_layer(c_in, c, 1, bias=self.use_bias)), + ("sigmoid", nn.Sigmoid()), + ] + ) + ) def forward(self, x): y = self.squeeze(x) @@ -198,16 +251,17 @@ def forward(self, x): class SEModule(nn.Module): "se block" - def __init__(self, - channels, - reduction=16, - rd_channels=None, - rd_max=False, - se_layer=nn.Linear, - act_fn=nn.ReLU(inplace=True), - use_bias=True, - gate=nn.Sigmoid - ): + def __init__( + self, + channels, + reduction=16, + rd_channels=None, + rd_max=False, + se_layer=nn.Linear, + act_fn=nn.ReLU(inplace=True), + use_bias=True, + gate=nn.Sigmoid, + ): super().__init__() reducted = max(channels // reduction, 1) # preserve zero-element tensors if rd_channels is None: @@ -217,11 +271,15 @@ def __init__(self, rd_channels = max(rd_channels, reducted) self.squeeze = nn.AdaptiveAvgPool2d(1) self.excitation = nn.Sequential( - OrderedDict([('reduce', se_layer(channels, rd_channels, bias=use_bias)), - ('se_act', act_fn), - ('expand', se_layer(rd_channels, channels, bias=use_bias)), - ('se_gate', gate()) - ])) + OrderedDict( + [ + ("reduce", se_layer(channels, rd_channels, bias=use_bias)), + ("se_act", act_fn), + ("expand", se_layer(rd_channels, channels, bias=use_bias)), + ("se_gate", gate()), + ] + ) + ) def forward(self, x): bs, c, _, _ = x.shape @@ -233,18 +291,19 @@ def forward(self, x): class SEModuleConv(nn.Module): "se block with conv on excitation" - def __init__(self, - channels, - reduction=16, - rd_channels=None, - rd_max=False, - se_layer=nn.Conv2d, - act_fn=nn.ReLU(inplace=True), - use_bias=True, - gate=nn.Sigmoid - ): + def __init__( + self, + channels, + reduction=16, + rd_channels=None, + rd_max=False, + se_layer=nn.Conv2d, + act_fn=nn.ReLU(inplace=True), + use_bias=True, + gate=nn.Sigmoid, + ): super().__init__() -# rd_channels = math.ceil(channels//reduction/8)*8 + # rd_channels = math.ceil(channels//reduction/8)*8 reducted = max(channels // reduction, 1) # preserve zero-element tensors if rd_channels is None: rd_channels = reducted @@ -253,12 +312,15 @@ def __init__(self, rd_channels = max(rd_channels, reducted) self.squeeze = nn.AdaptiveAvgPool2d(1) self.excitation = nn.Sequential( - OrderedDict([ - ('reduce', se_layer(channels, rd_channels, 1, bias=use_bias)), - ('se_act', act_fn), - ('expand', se_layer(rd_channels, channels, 1, bias=use_bias)), - ('gate', gate()) - ])) + OrderedDict( + [ + ("reduce", se_layer(channels, rd_channels, 1, bias=use_bias)), + ("se_act", act_fn), + ("expand", se_layer(rd_channels, channels, 1, bias=use_bias)), + ("gate", gate()), + ] + ) + ) def forward(self, x): y = self.squeeze(x)
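
A minimal usage sketch of the API as it stands after this patch series, in Python. The module paths, class names, and default arguments are taken from the diffs above; the channel sizes, image size, and the LeakyReLU activation are illustrative assumptions, not part of the patches.

    import torch
    import torch.nn as nn

    from model_constructor.layers import ConvBnAct
    from model_constructor.model_constructor import ModelConstructor, xresnet34

    # ConvBnAct builds conv -> bn -> act by default (bn_1st=True).
    # act_fn accepts an nn.Module or False; zero_bn=True zero-inits the BatchNorm weight,
    # which the ResBlock uses on the last conv of the residual branch.
    layer = ConvBnAct(3, 32, kernel_size=3, stride=2, act_fn=nn.LeakyReLU(inplace=True))
    print(layer)

    # ModelConstructor with defaults describes an xresnet18-like model;
    # se=True / sa=True swap in SEModule and SimpleSelfAttention.
    mc = ModelConstructor(num_classes=10, se=True, sa=True)
    model = mc()  # calling the constructor instance returns nn.Sequential(stem, body, head)
    out = model(torch.randn(2, 3, 128, 128))
    print(out.shape)  # expected: torch.Size([2, 10])

    # Presets from the last patch are partials over ModelConstructor,
    # so one call creates the constructor and a second call builds the module.
    model34 = xresnet34()()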