Merged

Cfg #62

461 changes: 416 additions & 45 deletions Nbs/00_ModelConstructor.ipynb

Large diffs are not rendered by default.

412 changes: 310 additions & 102 deletions Nbs/index.ipynb

Large diffs are not rendered by default.

290 changes: 169 additions & 121 deletions README.md

Large diffs are not rendered by default.

312 changes: 302 additions & 10 deletions docs/00_ModelConstructor.md

Large diffs are not rendered by default.

195 changes: 145 additions & 50 deletions docs/index.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/model_constructor/__init__.py
@@ -1,4 +1,4 @@
from model_constructor.convmixer import ConvMixer # noqa F401
from model_constructor.model_constructor import ModelConstructor, ResBlock # noqa F401
from model_constructor.model_constructor import ModelConstructor, ResBlock, CfgMC # noqa F401

from model_constructor.version import __version__ # noqa F401
254 changes: 113 additions & 141 deletions src/model_constructor/model_constructor.py
@@ -1,6 +1,8 @@
from dataclasses import dataclass, field, asdict

from collections import OrderedDict
from functools import partial
from typing import Callable, List, Type, Union
# from functools import partial
from typing import Callable, List, Optional, Type, Union

import torch.nn as nn

@@ -12,24 +14,14 @@
"act_fn",
"ResBlock",
"ModelConstructor",
"xresnet34",
"xresnet50",
# "xresnet34",
# "xresnet50",
]


act_fn = nn.ReLU(inplace=True)


def init_cnn(module: nn.Module):
"Init module - kaiming_normal for Conv2d and 0 for biases."
if getattr(module, "bias", None) is not None:
nn.init.constant_(module.bias, 0) # type: ignore
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.kaiming_normal_(module.weight)
for layer in module.children():
init_cnn(layer)


class ResBlock(nn.Module):
"""Universal Resnet block. Basic block if expansion is 1, otherwise is Bottleneck."""

@@ -130,10 +122,55 @@ def forward(self, x):
return self.act_fn(self.convs(x) + identity)


def _make_stem(self):
stem = [
@dataclass
class CfgMC:
"""Model constructor Config. As default - xresnet18"""

name: str = "MC"
in_chans: int = 3
num_classes: int = 1000
block: Type[nn.Module] = ResBlock
conv_layer: Type[nn.Module] = ConvBnAct
block_sizes: List[int] = field(default_factory=lambda: [64, 128, 256, 512])
layers: List[int] = field(default_factory=lambda: [2, 2, 2, 2])
norm: Type[nn.Module] = nn.BatchNorm2d
act_fn: nn.Module = nn.ReLU(inplace=True)
pool: nn.Module = nn.AvgPool2d(2, ceil_mode=True)
expansion: int = 1
groups: int = 1
dw: bool = False
div_groups: Union[int, None] = None
sa: Union[bool, int, Type[nn.Module]] = False
se: Union[bool, int, Type[nn.Module]] = False
se_module = None
se_reduction = None
bn_1st: bool = True
zero_bn: bool = True
stem_stride_on: int = 0
stem_sizes: List[int] = field(default_factory=lambda: [32, 32, 64])
stem_pool: Union[nn.Module, None] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # type: ignore
stem_bn_end: bool = False
_init_cnn: Optional[Callable[[nn.Module], None]] = field(repr=False, default=None)
_make_stem: Optional[Callable] = field(repr=False, default=None)
_make_layer: Optional[Callable] = field(repr=False, default=None)
_make_body: Optional[Callable] = field(repr=False, default=None)
_make_head: Optional[Callable] = field(repr=False, default=None)


def init_cnn(module: nn.Module):
"Init module - kaiming_normal for Conv2d and 0 for biases."
if getattr(module, "bias", None) is not None:
nn.init.constant_(module.bias, 0) # type: ignore
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.kaiming_normal_(module.weight)
for layer in module.children():
init_cnn(layer)


def make_stem(self: CfgMC) -> nn.Sequential:
stem: List[tuple[str, nn.Module]] = [
(f"conv_{i}", self.conv_layer(
self.stem_sizes[i],
self.stem_sizes[i], # type: ignore
self.stem_sizes[i + 1],
stride=2 if i == self.stem_stride_on else 1,
bn_layer=(not self.stem_bn_end)
@@ -147,39 +184,38 @@ def _make_stem(self):
if self.stem_pool:
stem.append(("stem_pool", self.stem_pool))
if self.stem_bn_end:
stem.append(("norm", self.norm(self.stem_sizes[-1])))
stem.append(("norm", self.norm(self.stem_sizes[-1]))) # type: ignore
return nn.Sequential(OrderedDict(stem))


def _make_layer(self, layer_num: int) -> nn.Module:
def make_layer(cfg: CfgMC, layer_num: int) -> nn.Sequential:
# expansion, in_channels, out_channels, blocks, stride, sa):
# if no pool on stem - stride = 2 for first layer block in body
stride = 1 if self.stem_pool and layer_num == 0 else 2
num_blocks = self.layers[layer_num]
stride = 1 if cfg.stem_pool and layer_num == 0 else 2
num_blocks = cfg.layers[layer_num]
block_chs = [cfg.stem_sizes[-1] // cfg.expansion] + cfg.block_sizes
return nn.Sequential(
OrderedDict(
[
(
f"bl_{block_num}",
self.block(
self.expansion,
self.block_sizes[layer_num]
if block_num == 0
else self.block_sizes[layer_num + 1],
self.block_sizes[layer_num + 1],
cfg.block(
cfg.expansion, # type: ignore
block_chs[layer_num] if block_num == 0 else block_chs[layer_num + 1],
block_chs[layer_num + 1],
stride if block_num == 0 else 1,
sa=self.sa
sa=cfg.sa
if (block_num == num_blocks - 1) and layer_num == 0
else None,
conv_layer=self.conv_layer,
act_fn=self.act_fn,
pool=self.pool,
zero_bn=self.zero_bn,
bn_1st=self.bn_1st,
groups=self.groups,
div_groups=self.div_groups,
dw=self.dw,
se=self.se,
conv_layer=cfg.conv_layer,
act_fn=cfg.act_fn,
pool=cfg.pool,
zero_bn=cfg.zero_bn,
bn_1st=cfg.bn_1st,
groups=cfg.groups,
div_groups=cfg.div_groups,
dw=cfg.dw,
se=cfg.se,
),
)
for block_num in range(num_blocks)
@@ -188,160 +224,96 @@ def _make_layer(self, layer_num: int) -> nn.Module:
)


def _make_body(self):
def make_body(cfg: CfgMC) -> nn.Sequential:
return nn.Sequential(
OrderedDict(
[
(
f"l_{layer_num}",
self._make_layer(self, layer_num)
cfg._make_layer(cfg, layer_num) # type: ignore
)
for layer_num in range(len(self.layers))
for layer_num in range(len(cfg.layers))
]
)
)


def _make_head(self):
def make_head(cfg: CfgMC) -> nn.Sequential:
head = [
("pool", nn.AdaptiveAvgPool2d(1)),
("flat", nn.Flatten()),
("fc", nn.Linear(self.block_sizes[-1] * self.expansion, self.num_classes)),
("fc", nn.Linear(cfg.block_sizes[-1] * cfg.expansion, cfg.num_classes)),
]
return nn.Sequential(OrderedDict(head))


class ModelConstructor:
@dataclass
class ModelConstructor(CfgMC):
"""Model constructor. As default - xresnet18"""

def __init__(
self,
name: str = "MC",
in_chans: int = 3,
num_classes: int = 1000,
block=ResBlock,
conv_layer=ConvBnAct,
block_sizes: List[int] = [64, 128, 256, 512],
layers: List[int] = [2, 2, 2, 2],
norm: Type[nn.Module] = nn.BatchNorm2d,
act_fn: nn.Module = nn.ReLU(inplace=True),
pool: nn.Module = nn.AvgPool2d(2, ceil_mode=True),
expansion: int = 1,
groups: int = 1,
dw: bool = False,
div_groups: Union[int, None] = None,
sa: Union[bool, int, Type[nn.Module]] = False,
se: Union[bool, int, Type[nn.Module]] = False,
se_module=None,
se_reduction=None,
bn_1st: bool = True,
zero_bn: bool = True,
stem_stride_on: int = 0,
stem_sizes: List[int] = [32, 32, 64],
stem_pool: Union[Type[nn.Module], None] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1), # type: ignore
stem_bn_end: bool = False,
_init_cnn: Callable = init_cnn,
_make_stem: Callable = _make_stem,
_make_layer: Callable = _make_layer,
_make_body: Callable = _make_body,
_make_head: Callable = _make_head,
):
super().__init__()
# se can be bool, int (0, 1) or nn.Module
# se_module - deprecated. Leaved for warning and checks.
# if stem_pool is False - no pool at stem

self.name = name
self.in_chans = in_chans
self.num_classes = num_classes
self.block = block
self.conv_layer = conv_layer
self._block_sizes = block_sizes
self.layers = layers
self.norm = norm
self.act_fn = act_fn
self.pool = pool
self.expansion = expansion
self.groups = groups
self.dw = dw
self.div_groups = div_groups
# se_module
# se_reduction
self.bn_1st = bn_1st
self.zero_bn = zero_bn
self.stem_stride_on = stem_stride_on
self.stem_pool = stem_pool
self.stem_bn_end = stem_bn_end
self._init_cnn = _init_cnn
self._make_stem = _make_stem
self._make_layer = _make_layer
self._make_body = _make_body
self._make_head = _make_head

# params = locals()
# del params['self']
# self.__dict__ = params

# self._block_sizes = params['block_sizes']
self.stem_sizes = stem_sizes
def __post_init__(self):
if self._init_cnn is None:
self._init_cnn = init_cnn
if self._make_stem is None:
self._make_stem = make_stem
if self._make_layer is None:
self._make_layer = make_layer
if self._make_body is None:
self._make_body = make_body
if self._make_head is None:
self._make_head = make_head

if self.stem_sizes[0] != self.in_chans:
self.stem_sizes = [self.in_chans] + self.stem_sizes
self.se = se
if self.se:
if type(self.se) in (bool, int): # if se=1 or se=True
self.se = SEModule
else:
self.se = se # TODO add check issubclass or isinstance of nn.Module
self.sa = sa
if self.sa: # if sa=1 or sa=True
if type(self.sa) in (bool, int):
self.sa = SimpleSelfAttention # default: ks=1, sym=sym
else:
self.sa = sa
if se_module or se_reduction: # pragma: no cover
if self.se and isinstance(self.se, (bool, int)): # if se=1 or se=True
self.se = SEModule
if self.sa and isinstance(self.sa, (bool, int)): # if sa=1 or sa=True
self.sa = SimpleSelfAttention # default: ks=1, sym=sym
if self.se_module or self.se_reduction: # pragma: no cover
print(
"Deprecated. Pass se_module as se argument, se_reduction as arg to se."
) # add deprecation warning.

@property
def block_sizes(self):
return [self.stem_sizes[-1] // self.expansion] + self._block_sizes

@property
def stem(self):
return self._make_stem(self)
return self._make_stem(self) # type: ignore

@property
def head(self):
return self._make_head(self)
return self._make_head(self) # type: ignore

@property
def body(self):
return self._make_body(self)
return self._make_body(self) # type: ignore

@classmethod
def from_cfg(cls, cfg: CfgMC):
return cls(**asdict(cfg))

def __call__(self):
model = nn.Sequential(
OrderedDict([("stem", self.stem), ("body", self.body), ("head", self.head)])
)
self._init_cnn(model)
self._init_cnn(model) # type: ignore
model.extra_repr = lambda: f"{self.name}"
return model

def __repr__(self):
return (
def print_cfg(self):
print(
f"{self.name} constructor\n"
f" in_chans: {self.in_chans}, num_classes: {self.num_classes}\n"
f" expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}, div_groups: {self.div_groups}\n"
f" sa: {self.sa}, se: {self.se}\n"
f" stem sizes: {self.stem_sizes}, stride on {self.stem_stride_on}\n"
f" body sizes {self._block_sizes}\n"
f" body sizes {self.block_sizes}\n"
f" layers: {self.layers}"
)


xresnet34 = partial(
ModelConstructor, name="xresnet34", expansion=1, layers=[3, 4, 6, 3]
xresnet34 = ModelConstructor.from_cfg(
CfgMC(name="xresnet34", expansion=1, layers=[3, 4, 6, 3])
)
xresnet50 = partial(
ModelConstructor, name="xresnet34", expansion=4, layers=[3, 4, 6, 3]

xresnet50 = ModelConstructor.from_cfg(
CfgMC(name="xresnet34", expansion=4, layers=[3, 4, 6, 3])
)
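
Taken together, this refactor replaces the long __init__ with the CfgMC dataclass plus module-level builder functions (make_stem, make_layer, make_body, make_head), with ModelConstructor.from_cfg bridging the two. A minimal usage sketch of the new API, assuming the defaults shown in the diff above (batch size and image size here are arbitrary):

import torch
from model_constructor import ModelConstructor, CfgMC

cfg = CfgMC(name="xresnet34", expansion=1, layers=[3, 4, 6, 3])  # plain config, no model yet
mc = ModelConstructor.from_cfg(cfg)  # constructor built from the dataclass fields
mc.print_cfg()                       # human-readable summary (replaces the old __repr__)
model = mc()                         # nn.Sequential: stem -> body -> head
pred = model(torch.randn(2, 3, 16, 16))
print(pred.shape)                    # expected: torch.Size([2, 1000])
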
6 changes: 5 additions & 1 deletion tests/test_mc.py
@@ -11,11 +11,15 @@ def test_MC():
"""test ModelConstructor"""
img_size = 16
mc = ModelConstructor()
assert "MC constructor" in str(mc)
assert "name='MC'" in str(mc)
model = mc()
xb = torch.randn(bs_test, 3, img_size, img_size)
pred = model(xb)
assert pred.shape == torch.Size([bs_test, 1000])
mc.expansion = 2
model = mc()
pred = model(xb)
assert pred.shape == torch.Size([bs_test, 1000])
num_classes = 10
mc.num_classes = num_classes
mc.se = SEModule
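
The changed assertion ("name='MC'" instead of "MC constructor") reflects a side effect of the dataclass refactor: repr(mc) is now the auto-generated dataclass repr, while the old multi-line summary moved to print_cfg(). A short sketch of the difference, under the default config:

from model_constructor import ModelConstructor

mc = ModelConstructor()
print(repr(mc))   # dataclass repr, e.g. ModelConstructor(name='MC', in_chans=3, num_classes=1000, ...)
mc.print_cfg()    # old-style summary: "MC constructor", in_chans/num_classes, sa/se, stem and body sizes, layers
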