diff --git a/Nbs/00_ModelConstructor.ipynb b/Nbs/00_ModelConstructor.ipynb index 5a7914f..22477ba 100644 --- a/Nbs/00_ModelConstructor.ipynb +++ b/Nbs/00_ModelConstructor.ipynb @@ -31,18 +31,6 @@ "import torch.nn as nn" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# hide\n", - "def print_doc(func_name):\n", - " doc = show_doc(func_name, title_level=4, disp=False)\n", - " display(Markdown(doc))" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -60,16 +48,6 @@ "from model_constructor.model_constructor import ResBlock" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#hide_input\n", - "# print_doc(ResBlock)" - ] - }, { "cell_type": "code", "execution_count": null, @@ -217,6 +195,7 @@ } ], "source": [ + "# collapse_output\n", "block = ResBlock(2,64,64,act_fn=nn.LeakyReLU(), bn_1st=False)\n", "block" ] @@ -262,7 +241,8 @@ } ], "source": [ - "block = ResBlock(2, 32, 64, dw=True)\n", + "# collapse_output\n", + "lock = ResBlock(2, 32, 64, dw=True)\n", "block" ] }, @@ -317,6 +297,7 @@ } ], "source": [ + "# collapse_output\n", "block = ResBlock(2, 32, 64, stride=2, dw=True, pool=pool)\n", "block" ] @@ -357,9 +338,9 @@ " (se): SEModule(\n", " (squeeze): AdaptiveAvgPool2d(output_size=1)\n", " (excitation): Sequential(\n", - " (fc_reduce): Linear(in_features=128, out_features=8, bias=True)\n", + " (reduce): Linear(in_features=128, out_features=8, bias=True)\n", " (se_act): ReLU(inplace=True)\n", - " (fc_expand): Linear(in_features=8, out_features=128, bias=True)\n", + " (expand): Linear(in_features=8, out_features=128, bias=True)\n", " (se_gate): Sigmoid()\n", " )\n", " )\n", @@ -413,9 +394,9 @@ " (se): SEModule(\n", " (squeeze): AdaptiveAvgPool2d(output_size=1)\n", " (excitation): Sequential(\n", - " (fc_reduce): Linear(in_features=128, out_features=8, bias=True)\n", + " (reduce): Linear(in_features=128, out_features=8, bias=True)\n", " (se_act): ReLU(inplace=True)\n", - " (fc_expand): Linear(in_features=8, out_features=128, bias=True)\n", + " (expand): Linear(in_features=8, out_features=128, bias=True)\n", " (se_gate): Sigmoid()\n", " )\n", " )\n", @@ -449,7 +430,15 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Model Constructor." + "## Stem, Body, Layer, Head" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Helper functions to create stem, body and head of model from config. \n", + "Returns `nn.Sequential`. " ] }, { @@ -458,18 +447,85 @@ "metadata": {}, "outputs": [], "source": [ - "#hide\n", - "from model_constructor import ModelConstructor" + "from model_constructor.model_constructor import CfgMC, make_stem, make_body, make_layer, make_head\n", + "from rich import print" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
CfgMC(\n",
+       "    name='MC',\n",
+       "    in_chans=3,\n",
+       "    num_classes=1000,\n",
+       "    block=<class 'model_constructor.model_constructor.ResBlock'>,\n",
+       "    conv_layer=<class 'model_constructor.layers.ConvBnAct'>,\n",
+       "    block_sizes=[64, 128, 256, 512],\n",
+       "    layers=[2, 2, 2, 2],\n",
+       "    norm=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>,\n",
+       "    act_fn=ReLU(inplace=True),\n",
+       "    pool=AvgPool2d(kernel_size=2, stride=2, padding=0),\n",
+       "    expansion=1,\n",
+       "    groups=1,\n",
+       "    dw=False,\n",
+       "    div_groups=None,\n",
+       "    sa=False,\n",
+       "    se=False,\n",
+       "    bn_1st=True,\n",
+       "    zero_bn=True,\n",
+       "    stem_stride_on=0,\n",
+       "    stem_sizes=[32, 32, 64],\n",
+       "    stem_pool=MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),\n",
+       "    stem_bn_end=False\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mCfgMC\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mname\u001b[0m=\u001b[32m'MC'\u001b[0m,\n", + " \u001b[33min_chans\u001b[0m=\u001b[1;36m3\u001b[0m,\n", + " \u001b[33mnum_classes\u001b[0m=\u001b[1;36m1000\u001b[0m,\n", + " \u001b[33mblock\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mclass\u001b[0m\u001b[39m \u001b[0m\u001b[32m'model_constructor.model_constructor.ResBlock'\u001b[0m\u001b[1m>\u001b[0m,\n", + " \u001b[33mconv_layer\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mclass\u001b[0m\u001b[39m \u001b[0m\u001b[32m'model_constructor.layers.ConvBnAct'\u001b[0m\u001b[1m>\u001b[0m,\n", + " \u001b[33mblock_sizes\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;36m64\u001b[0m, \u001b[1;36m128\u001b[0m, \u001b[1;36m256\u001b[0m, \u001b[1;36m512\u001b[0m\u001b[1m]\u001b[0m,\n", + " \u001b[33mlayers\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;36m2\u001b[0m, \u001b[1;36m2\u001b[0m, \u001b[1;36m2\u001b[0m, \u001b[1;36m2\u001b[0m\u001b[1m]\u001b[0m,\n", + " \u001b[33mnorm\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mclass\u001b[0m\u001b[39m \u001b[0m\u001b[32m'torch.nn.modules.batchnorm.BatchNorm2d'\u001b[0m\u001b[1m>\u001b[0m,\n", + " \u001b[33mact_fn\u001b[0m=\u001b[1;35mReLU\u001b[0m\u001b[1m(\u001b[0m\u001b[33minplace\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", + " \u001b[33mpool\u001b[0m=\u001b[1;35mAvgPool2d\u001b[0m\u001b[1m(\u001b[0m\u001b[33mkernel_size\u001b[0m=\u001b[1;36m2\u001b[0m, \u001b[33mstride\u001b[0m=\u001b[1;36m2\u001b[0m, \u001b[33mpadding\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m,\n", + " \u001b[33mexpansion\u001b[0m=\u001b[1;36m1\u001b[0m,\n", + " \u001b[33mgroups\u001b[0m=\u001b[1;36m1\u001b[0m,\n", + " \u001b[33mdw\u001b[0m=\u001b[3;91mFalse\u001b[0m,\n", + " \u001b[33mdiv_groups\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33msa\u001b[0m=\u001b[3;91mFalse\u001b[0m,\n", + " \u001b[33mse\u001b[0m=\u001b[3;91mFalse\u001b[0m,\n", + " \u001b[33mbn_1st\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33mzero_bn\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33mstem_stride_on\u001b[0m=\u001b[1;36m0\u001b[0m,\n", + " \u001b[33mstem_sizes\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;36m32\u001b[0m, \u001b[1;36m32\u001b[0m, \u001b[1;36m64\u001b[0m\u001b[1m]\u001b[0m,\n", + " \u001b[33mstem_pool\u001b[0m=\u001b[1;35mMaxPool2d\u001b[0m\u001b[1m(\u001b[0m\u001b[33mkernel_size\u001b[0m=\u001b[1;36m3\u001b[0m, \u001b[33mstride\u001b[0m=\u001b[1;36m2\u001b[0m, \u001b[33mpadding\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mdilation\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mceil_mode\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m,\n", + " \u001b[33mstem_bn_end\u001b[0m=\u001b[3;91mFalse\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "cfg = CfgMC()\n", + "print(cfg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, "source": [ - "#hide_input\n", - "# print_doc(ModelConstructor)" + "### Stem" ] }, { @@ -480,13 +536,19 @@ { "data": { "text/plain": [ - "MC constructor\n", - " in_chans: 3, num_classes: 1000\n", - " expansion: 1, groups: 1, dw: False, div_groups: None\n", - " sa: False, se: False\n", - " stem sizes: [3, 32, 32, 64], stride on 0\n", - " body sizes [64, 128, 256, 512]\n", - " layers: [2, 2, 2, 2]" + "Sequential(\n", + " (conv_0): ConvBnAct(\n", + " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (act_fn): ReLU(inplace=True)\n", + " )\n", + " (conv_1): ConvBnAct(\n", + " (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (act_fn): ReLU(inplace=True)\n", + " )\n", + " (stem_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", + ")" ] }, "execution_count": null, @@ -495,8 +557,293 @@ } ], "source": [ - "mc = ModelConstructor()\n", - "mc" + "# collapse_output\n", + "stem = make_stem(cfg)\n", + "stem" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Layer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`make_layer` need `layer_num` argument - number of layer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Sequential(\n", + " (bl_0): ResBlock(\n", + " (convs): Sequential(\n", + " (conv_0): ConvBnAct(\n", + " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (act_fn): ReLU(inplace=True)\n", + " )\n", + " (conv_1): ConvBnAct(\n", + " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (act_fn): ReLU(inplace=True)\n", + " )\n", + " (bl_1): ResBlock(\n", + " (convs): Sequential(\n", + " (conv_0): ConvBnAct(\n", + " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (act_fn): ReLU(inplace=True)\n", + " )\n", + " (conv_1): ConvBnAct(\n", + " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (act_fn): ReLU(inplace=True)\n", + " )\n", + ")" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# collapse_output\n", + "layer = make_layer(cfg, layer_num=0)\n", + "layer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Body" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`make_body` needs `cfg._make_layer` initialized. As default - `make_layer`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Sequential(\n", + " (l_0): Sequential(\n", + " (bl_0): ResBlock(\n", + " (convs): Sequential(\n", + " (conv_0): ConvBnAct(\n", + " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (act_fn): ReLU(inplace=True)\n", + " )\n", + " (conv_1): ConvBnAct(\n", + " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (act_fn): ReLU(inplace=True)\n", + " )\n", + " (bl_1): ResBlock(\n", + " (convs): Sequential(\n", + " (conv_0): ConvBnAct(\n", + " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (act_fn): ReLU(inplace=True)\n", + " )\n", + " (conv_1): ConvBnAct(\n", + " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (act_fn): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (l_1): Sequential(\n", + " (bl_0): ResBlock(\n", + " (convs): Sequential(\n", + " (conv_0): ConvBnAct(\n", + " (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (act_fn): ReLU(inplace=True)\n", + " )\n", + " (conv_1): ConvBnAct(\n", + " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (id_conv): Sequential(\n", + " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", + " (id_conv): ConvBnAct(\n", + " (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (act_fn): ReLU(inplace=True)\n", + " )\n", + " (bl_1): ResBlock(\n", + " (convs): Sequential(\n", + " (conv_0): ConvBnAct(\n", + " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (act_fn): ReLU(inplace=True)\n", + " )\n", + " (conv_1): ConvBnAct(\n", + " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (act_fn): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (l_2): Sequential(\n", + " (bl_0): ResBlock(\n", + " (convs): Sequential(\n", + " (conv_0): ConvBnAct(\n", + " (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (act_fn): ReLU(inplace=True)\n", + " )\n", + " (conv_1): ConvBnAct(\n", + " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (id_conv): Sequential(\n", + " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", + " (id_conv): ConvBnAct(\n", + " (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (act_fn): ReLU(inplace=True)\n", + " )\n", + " (bl_1): ResBlock(\n", + " (convs): Sequential(\n", + " (conv_0): ConvBnAct(\n", + " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (act_fn): ReLU(inplace=True)\n", + " )\n", + " (conv_1): ConvBnAct(\n", + " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (act_fn): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (l_3): Sequential(\n", + " (bl_0): ResBlock(\n", + " (convs): Sequential(\n", + " (conv_0): ConvBnAct(\n", + " (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (act_fn): ReLU(inplace=True)\n", + " )\n", + " (conv_1): ConvBnAct(\n", + " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (id_conv): Sequential(\n", + " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", + " (id_conv): ConvBnAct(\n", + " (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (act_fn): ReLU(inplace=True)\n", + " )\n", + " (bl_1): ResBlock(\n", + " (convs): Sequential(\n", + " (conv_0): ConvBnAct(\n", + " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (act_fn): ReLU(inplace=True)\n", + " )\n", + " (conv_1): ConvBnAct(\n", + " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (act_fn): ReLU(inplace=True)\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# collapse_output\n", + "cfg._make_layer = make_layer\n", + "body = make_body(cfg)\n", + "body" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Head" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Sequential(\n", + " (pool): AdaptiveAvgPool2d(output_size=1)\n", + " (flat): Flatten(start_dim=1, end_dim=-1)\n", + " (fc): Linear(in_features=512, out_features=1000, bias=True)\n", + ")" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# collapse_output\n", + "head = make_head(cfg)\n", + "head" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model Constructor." ] }, { @@ -506,7 +853,31 @@ "outputs": [], "source": [ "#hide\n", - "model = mc()" + "from model_constructor import ModelConstructor" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MC constructor\n", + " in_chans: 3, num_classes: 1000\n", + " expansion: 1, groups: 1, dw: False, div_groups: None\n", + " sa: False, se: False\n", + " stem sizes: [3, 32, 32, 64], stride on 0\n", + " body sizes [64, 128, 256, 512]\n", + " layers: [2, 2, 2, 2]\n" + ] + } + ], + "source": [ + "mc = ModelConstructor()\n", + "mc.print_cfg()" ] }, { @@ -543,7 +914,7 @@ } ], "source": [ - "#collapse_output\n", + "# collapse_output\n", "mc.stem" ] }, @@ -581,7 +952,7 @@ } ], "source": [ - "#collapse_output\n", + "# collapse_output\n", "mc.stem_stride_on = 1\n", "mc.stem" ] diff --git a/Nbs/index.ipynb b/Nbs/index.ipynb index b12ba14..d9b0095 100644 --- a/Nbs/index.ipynb +++ b/Nbs/index.ipynb @@ -92,6 +92,44 @@ "mc = ModelConstructor()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Check base parameters with `print_cfg` method:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MC constructor\n", + " in_chans: 3, num_classes: 1000\n", + " expansion: 1, groups: 1, dw: False, div_groups: None\n", + " sa: False, se: False\n", + " stem sizes: [3, 32, 32, 64], stride on 0\n", + " body sizes [64, 128, 256, 512]\n", + " layers: [2, 2, 2, 2]\n" + ] + } + ], + "source": [ + "mc.print_cfg()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "ModelConstructor based on dataclass. Repr will show all parameters. \n", + "Better look at it with `rich.print` " + ] + }, { "cell_type": "code", "execution_count": null, @@ -99,23 +137,67 @@ "outputs": [ { "data": { + "text/html": [ + "
ModelConstructor(\n",
+       "    name='MC',\n",
+       "    in_chans=3,\n",
+       "    num_classes=1000,\n",
+       "    block=<class 'model_constructor.model_constructor.ResBlock'>,\n",
+       "    conv_layer=<class 'model_constructor.layers.ConvBnAct'>,\n",
+       "    block_sizes=[64, 128, 256, 512],\n",
+       "    layers=[2, 2, 2, 2],\n",
+       "    norm=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>,\n",
+       "    act_fn=ReLU(inplace=True),\n",
+       "    pool=AvgPool2d(kernel_size=2, stride=2, padding=0),\n",
+       "    expansion=1,\n",
+       "    groups=1,\n",
+       "    dw=False,\n",
+       "    div_groups=None,\n",
+       "    sa=False,\n",
+       "    se=False,\n",
+       "    bn_1st=True,\n",
+       "    zero_bn=True,\n",
+       "    stem_stride_on=0,\n",
+       "    stem_sizes=[3, 32, 32, 64],\n",
+       "    stem_pool=MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),\n",
+       "    stem_bn_end=False\n",
+       ")\n",
+       "
\n" + ], "text/plain": [ - "MC constructor\n", - " in_chans: 3, num_classes: 1000\n", - " expansion: 1, groups: 1, dw: False, div_groups: None\n", - " sa: False, se: False\n", - " stem sizes: [3, 32, 32, 64], stride on 0\n", - " body sizes [64, 128, 256, 512]\n", - " layers: [2, 2, 2, 2]" + "\u001b[1;35mModelConstructor\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mname\u001b[0m=\u001b[32m'MC'\u001b[0m,\n", + " \u001b[33min_chans\u001b[0m=\u001b[1;36m3\u001b[0m,\n", + " \u001b[33mnum_classes\u001b[0m=\u001b[1;36m1000\u001b[0m,\n", + " \u001b[33mblock\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mclass\u001b[0m\u001b[39m \u001b[0m\u001b[32m'model_constructor.model_constructor.ResBlock'\u001b[0m\u001b[1m>\u001b[0m,\n", + " \u001b[33mconv_layer\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mclass\u001b[0m\u001b[39m \u001b[0m\u001b[32m'model_constructor.layers.ConvBnAct'\u001b[0m\u001b[1m>\u001b[0m,\n", + " \u001b[33mblock_sizes\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;36m64\u001b[0m, \u001b[1;36m128\u001b[0m, \u001b[1;36m256\u001b[0m, \u001b[1;36m512\u001b[0m\u001b[1m]\u001b[0m,\n", + " \u001b[33mlayers\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;36m2\u001b[0m, \u001b[1;36m2\u001b[0m, \u001b[1;36m2\u001b[0m, \u001b[1;36m2\u001b[0m\u001b[1m]\u001b[0m,\n", + " \u001b[33mnorm\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mclass\u001b[0m\u001b[39m \u001b[0m\u001b[32m'torch.nn.modules.batchnorm.BatchNorm2d'\u001b[0m\u001b[1m>\u001b[0m,\n", + " \u001b[33mact_fn\u001b[0m=\u001b[1;35mReLU\u001b[0m\u001b[1m(\u001b[0m\u001b[33minplace\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", + " \u001b[33mpool\u001b[0m=\u001b[1;35mAvgPool2d\u001b[0m\u001b[1m(\u001b[0m\u001b[33mkernel_size\u001b[0m=\u001b[1;36m2\u001b[0m, \u001b[33mstride\u001b[0m=\u001b[1;36m2\u001b[0m, \u001b[33mpadding\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m,\n", + " \u001b[33mexpansion\u001b[0m=\u001b[1;36m1\u001b[0m,\n", + " \u001b[33mgroups\u001b[0m=\u001b[1;36m1\u001b[0m,\n", + " \u001b[33mdw\u001b[0m=\u001b[3;91mFalse\u001b[0m,\n", + " \u001b[33mdiv_groups\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33msa\u001b[0m=\u001b[3;91mFalse\u001b[0m,\n", + " \u001b[33mse\u001b[0m=\u001b[3;91mFalse\u001b[0m,\n", + " \u001b[33mbn_1st\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33mzero_bn\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33mstem_stride_on\u001b[0m=\u001b[1;36m0\u001b[0m,\n", + " \u001b[33mstem_sizes\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;36m3\u001b[0m, \u001b[1;36m32\u001b[0m, \u001b[1;36m32\u001b[0m, \u001b[1;36m64\u001b[0m\u001b[1m]\u001b[0m,\n", + " \u001b[33mstem_pool\u001b[0m=\u001b[1;35mMaxPool2d\u001b[0m\u001b[1m(\u001b[0m\u001b[33mkernel_size\u001b[0m=\u001b[1;36m3\u001b[0m, \u001b[33mstride\u001b[0m=\u001b[1;36m2\u001b[0m, \u001b[33mpadding\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mdilation\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mceil_mode\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m,\n", + " \u001b[33mstem_bn_end\u001b[0m=\u001b[3;91mFalse\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" ] }, - "execution_count": null, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "mc" + "from rich import print\n", + "print(mc)" ] }, { @@ -337,7 +419,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now we can look at model body and if we call constructor - we have pytorch model!" + "Now we can look at model parts - stem, body, head. " ] }, { @@ -701,6 +783,130 @@ "mc.body" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create constructor from config." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Alternative we can create config first and than create constructor from it. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from model_constructor import CfgMC" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
CfgMC(\n",
+       "    name='MC',\n",
+       "    in_chans=3,\n",
+       "    num_classes=1000,\n",
+       "    block=<class 'model_constructor.model_constructor.ResBlock'>,\n",
+       "    conv_layer=<class 'model_constructor.layers.ConvBnAct'>,\n",
+       "    block_sizes=[64, 128, 256, 512],\n",
+       "    layers=[2, 2, 2, 2],\n",
+       "    norm=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>,\n",
+       "    act_fn=ReLU(inplace=True),\n",
+       "    pool=AvgPool2d(kernel_size=2, stride=2, padding=0),\n",
+       "    expansion=1,\n",
+       "    groups=1,\n",
+       "    dw=False,\n",
+       "    div_groups=None,\n",
+       "    sa=False,\n",
+       "    se=False,\n",
+       "    bn_1st=True,\n",
+       "    zero_bn=True,\n",
+       "    stem_stride_on=0,\n",
+       "    stem_sizes=[32, 32, 64],\n",
+       "    stem_pool=MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),\n",
+       "    stem_bn_end=False\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mCfgMC\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mname\u001b[0m=\u001b[32m'MC'\u001b[0m,\n", + " \u001b[33min_chans\u001b[0m=\u001b[1;36m3\u001b[0m,\n", + " \u001b[33mnum_classes\u001b[0m=\u001b[1;36m1000\u001b[0m,\n", + " \u001b[33mblock\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mclass\u001b[0m\u001b[39m \u001b[0m\u001b[32m'model_constructor.model_constructor.ResBlock'\u001b[0m\u001b[1m>\u001b[0m,\n", + " \u001b[33mconv_layer\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mclass\u001b[0m\u001b[39m \u001b[0m\u001b[32m'model_constructor.layers.ConvBnAct'\u001b[0m\u001b[1m>\u001b[0m,\n", + " \u001b[33mblock_sizes\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;36m64\u001b[0m, \u001b[1;36m128\u001b[0m, \u001b[1;36m256\u001b[0m, \u001b[1;36m512\u001b[0m\u001b[1m]\u001b[0m,\n", + " \u001b[33mlayers\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;36m2\u001b[0m, \u001b[1;36m2\u001b[0m, \u001b[1;36m2\u001b[0m, \u001b[1;36m2\u001b[0m\u001b[1m]\u001b[0m,\n", + " \u001b[33mnorm\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mclass\u001b[0m\u001b[39m \u001b[0m\u001b[32m'torch.nn.modules.batchnorm.BatchNorm2d'\u001b[0m\u001b[1m>\u001b[0m,\n", + " \u001b[33mact_fn\u001b[0m=\u001b[1;35mReLU\u001b[0m\u001b[1m(\u001b[0m\u001b[33minplace\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", + " \u001b[33mpool\u001b[0m=\u001b[1;35mAvgPool2d\u001b[0m\u001b[1m(\u001b[0m\u001b[33mkernel_size\u001b[0m=\u001b[1;36m2\u001b[0m, \u001b[33mstride\u001b[0m=\u001b[1;36m2\u001b[0m, \u001b[33mpadding\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m,\n", + " \u001b[33mexpansion\u001b[0m=\u001b[1;36m1\u001b[0m,\n", + " \u001b[33mgroups\u001b[0m=\u001b[1;36m1\u001b[0m,\n", + " \u001b[33mdw\u001b[0m=\u001b[3;91mFalse\u001b[0m,\n", + " \u001b[33mdiv_groups\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33msa\u001b[0m=\u001b[3;91mFalse\u001b[0m,\n", + " \u001b[33mse\u001b[0m=\u001b[3;91mFalse\u001b[0m,\n", + " \u001b[33mbn_1st\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33mzero_bn\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33mstem_stride_on\u001b[0m=\u001b[1;36m0\u001b[0m,\n", + " \u001b[33mstem_sizes\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;36m32\u001b[0m, \u001b[1;36m32\u001b[0m, \u001b[1;36m64\u001b[0m\u001b[1m]\u001b[0m,\n", + " \u001b[33mstem_pool\u001b[0m=\u001b[1;35mMaxPool2d\u001b[0m\u001b[1m(\u001b[0m\u001b[33mkernel_size\u001b[0m=\u001b[1;36m3\u001b[0m, \u001b[33mstride\u001b[0m=\u001b[1;36m2\u001b[0m, \u001b[33mpadding\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mdilation\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mceil_mode\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m,\n", + " \u001b[33mstem_bn_end\u001b[0m=\u001b[3;91mFalse\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "cfg = CfgMC()\n", + "print(cfg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can create constructor from config:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MC constructor\n", + " in_chans: 3, num_classes: 1000\n", + " expansion: 1, groups: 1, dw: False, div_groups: None\n", + " sa: False, se: False\n", + " stem sizes: [3, 32, 32, 64], stride on 0\n", + " body sizes [64, 128, 256, 512]\n", + " layers: [2, 2, 2, 2]\n" + ] + } + ], + "source": [ + "mc = ModelConstructor.from_cfg(cfg)\n", + "mc.print_cfg()" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -786,13 +992,7 @@ { "data": { "text/plain": [ - "MxResNet constructor\n", - " in_chans: 3, num_classes: 1000\n", - " expansion: 1, groups: 1, dw: False, div_groups: None\n", - " sa: False, se: False\n", - " stem sizes: [3, 32, 64, 64], stride on 0\n", - " body sizes [64, 128, 256, 512]\n", - " layers: [2, 2, 2, 2]" + "ModelConstructor(name='MxResNet', in_chans=3, num_classes=1000, block=, conv_layer=, block_sizes=[64, 128, 256, 512], layers=[2, 2, 2, 2], norm=, act_fn=Mish(), pool=AvgPool2d(kernel_size=2, stride=2, padding=0), expansion=1, groups=1, dw=False, div_groups=None, sa=False, se=False, bn_1st=True, zero_bn=True, stem_stride_on=0, stem_sizes=[3, 32, 64, 64], stem_pool=MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False), stem_bn_end=False)" ] }, "execution_count": null, @@ -804,6 +1004,13 @@ "mc" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here is model: " + ] + }, { "cell_type": "code", "execution_count": null, @@ -997,7 +1204,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### MXResNet50" + "## MXResNet50" ] }, { @@ -1034,23 +1241,66 @@ "outputs": [ { "data": { + "text/html": [ + "
ModelConstructor(\n",
+       "    name='mxresnet50',\n",
+       "    in_chans=3,\n",
+       "    num_classes=1000,\n",
+       "    block=<class 'model_constructor.model_constructor.ResBlock'>,\n",
+       "    conv_layer=<class 'model_constructor.layers.ConvBnAct'>,\n",
+       "    block_sizes=[64, 128, 256, 512],\n",
+       "    layers=[3, 4, 6, 3],\n",
+       "    norm=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>,\n",
+       "    act_fn=Mish(),\n",
+       "    pool=AvgPool2d(kernel_size=2, stride=2, padding=0),\n",
+       "    expansion=4,\n",
+       "    groups=1,\n",
+       "    dw=False,\n",
+       "    div_groups=None,\n",
+       "    sa=False,\n",
+       "    se=False,\n",
+       "    bn_1st=True,\n",
+       "    zero_bn=True,\n",
+       "    stem_stride_on=0,\n",
+       "    stem_sizes=[3, 32, 64, 64],\n",
+       "    stem_pool=MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),\n",
+       "    stem_bn_end=False\n",
+       ")\n",
+       "
\n" + ], "text/plain": [ - "mxresnet50 constructor\n", - " in_chans: 3, num_classes: 1000\n", - " expansion: 4, groups: 1, dw: False, div_groups: None\n", - " sa: False, se: False\n", - " stem sizes: [3, 32, 64, 64], stride on 0\n", - " body sizes [64, 128, 256, 512]\n", - " layers: [3, 4, 6, 3]" + "\u001b[1;35mModelConstructor\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mname\u001b[0m=\u001b[32m'mxresnet50'\u001b[0m,\n", + " \u001b[33min_chans\u001b[0m=\u001b[1;36m3\u001b[0m,\n", + " \u001b[33mnum_classes\u001b[0m=\u001b[1;36m1000\u001b[0m,\n", + " \u001b[33mblock\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mclass\u001b[0m\u001b[39m \u001b[0m\u001b[32m'model_constructor.model_constructor.ResBlock'\u001b[0m\u001b[1m>\u001b[0m,\n", + " \u001b[33mconv_layer\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mclass\u001b[0m\u001b[39m \u001b[0m\u001b[32m'model_constructor.layers.ConvBnAct'\u001b[0m\u001b[1m>\u001b[0m,\n", + " \u001b[33mblock_sizes\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;36m64\u001b[0m, \u001b[1;36m128\u001b[0m, \u001b[1;36m256\u001b[0m, \u001b[1;36m512\u001b[0m\u001b[1m]\u001b[0m,\n", + " \u001b[33mlayers\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;36m3\u001b[0m, \u001b[1;36m4\u001b[0m, \u001b[1;36m6\u001b[0m, \u001b[1;36m3\u001b[0m\u001b[1m]\u001b[0m,\n", + " \u001b[33mnorm\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mclass\u001b[0m\u001b[39m \u001b[0m\u001b[32m'torch.nn.modules.batchnorm.BatchNorm2d'\u001b[0m\u001b[1m>\u001b[0m,\n", + " \u001b[33mact_fn\u001b[0m=\u001b[1;35mMish\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m,\n", + " \u001b[33mpool\u001b[0m=\u001b[1;35mAvgPool2d\u001b[0m\u001b[1m(\u001b[0m\u001b[33mkernel_size\u001b[0m=\u001b[1;36m2\u001b[0m, \u001b[33mstride\u001b[0m=\u001b[1;36m2\u001b[0m, \u001b[33mpadding\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m,\n", + " \u001b[33mexpansion\u001b[0m=\u001b[1;36m4\u001b[0m,\n", + " \u001b[33mgroups\u001b[0m=\u001b[1;36m1\u001b[0m,\n", + " \u001b[33mdw\u001b[0m=\u001b[3;91mFalse\u001b[0m,\n", + " \u001b[33mdiv_groups\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33msa\u001b[0m=\u001b[3;91mFalse\u001b[0m,\n", + " \u001b[33mse\u001b[0m=\u001b[3;91mFalse\u001b[0m,\n", + " \u001b[33mbn_1st\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33mzero_bn\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33mstem_stride_on\u001b[0m=\u001b[1;36m0\u001b[0m,\n", + " \u001b[33mstem_sizes\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;36m3\u001b[0m, \u001b[1;36m32\u001b[0m, \u001b[1;36m64\u001b[0m, \u001b[1;36m64\u001b[0m\u001b[1m]\u001b[0m,\n", + " \u001b[33mstem_pool\u001b[0m=\u001b[1;35mMaxPool2d\u001b[0m\u001b[1m(\u001b[0m\u001b[33mkernel_size\u001b[0m=\u001b[1;36m3\u001b[0m, \u001b[33mstride\u001b[0m=\u001b[1;36m2\u001b[0m, \u001b[33mpadding\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mdilation\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mceil_mode\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m,\n", + " \u001b[33mstem_bn_end\u001b[0m=\u001b[3;91mFalse\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" ] }, - "execution_count": null, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "mc" + "print(mc)" ] }, { @@ -1140,6 +1390,25 @@ "model = mc()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Or create with config:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mc = ModelConstructor.from_cfg(\n", + " CfgMC(name=\"MxResNet\", act_fn=Mish(), layers=[3,4,6,3], expansion=4, stem_sizes=[32,64,64])\n", + ")\n", + "model = mc()" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -1185,26 +1454,23 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "YaResNet constructor\n", - " in_chans: 3, num_classes: 1000\n", - " expansion: 4, groups: 1, dw: False, div_groups: None\n", - " sa: False, se: False\n", - " stem sizes: [3, 32, 64, 64], stride on 0\n", - " body sizes [64, 128, 256, 512]\n", - " layers: [3, 4, 6, 3]" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "YaResNet constructor\n", + " in_chans: 3, num_classes: 1000\n", + " expansion: 4, groups: 1, dw: False, div_groups: None\n", + " sa: False, se: False\n", + " stem sizes: [3, 32, 64, 64], stride on 0\n", + " body sizes [64, 128, 256, 512]\n", + " layers: [3, 4, 6, 3]\n" + ] } ], "source": [ "#collapse_output\n", "mc.name = 'YaResNet'\n", - "mc" + "mc.print_cfg()" ] }, { @@ -1257,64 +1523,6 @@ "#collapse_output\n", "mc.body.l_1.bl_0" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## First version" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "First version, it deprecated, but still here for compatibility." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from model_constructor.net import Net" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mc = Net()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Net constructor\n", - " c_in: 3, c_out: 1000\n", - " expansion: 1, groups: 1, dw: False, div_groups: None\n", - " sa: False, se: False\n", - " stem sizes: [3, 32, 32, 64], stride on 0\n", - " body sizes [64, 128, 256, 512]\n", - " layers: [2, 2, 2, 2]" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "mc" - ] } ], "metadata": { diff --git a/README.md b/README.md index 0838606..971afa4 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # model_constructor -> Constructor to create pytorch model. +> Constructor to create pytorch model. ## Install @@ -16,21 +16,22 @@ First import constructor class, then create model constructor object. Now you can change every part of model. + ```python from model_constructor import ModelConstructor ``` -```python -mc = ModelConstructor() -``` ```python -mc +mc = ModelConstructor() ``` +Check base parameters with `print_cfg` method: - +```python +mc.print_cfg() +``` MC constructor in_chans: 3, num_classes: 1000 expansion: 1, groups: 1, dw: False, div_groups: None @@ -40,20 +41,52 @@ mc layers: [2, 2, 2, 2] +ModelConstructor based on dataclass. Repr will show all parameters. +Better look at it with `rich.print` -Now we have model constructor, default setting as xresnet18. And we can get model after call it. ```python -#collapse_output -model = mc() -model +from rich import print +print(mc) ``` -
- Output details ... - +
ModelConstructor(
+    name='MC',
+    in_chans=3,
+    num_classes=1000,
+    block=<class 'model_constructor.model_constructor.ResBlock'>,
+    conv_layer=<class 'model_constructor.layers.ConvBnAct'>,
+    block_sizes=[64, 128, 256, 512],
+    layers=[2, 2, 2, 2],
+    norm=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>,
+    act_fn=ReLU(inplace=True),
+    pool=AvgPool2d(kernel_size=2, stride=2, padding=0),
+    expansion=1,
+    groups=1,
+    dw=False,
+    div_groups=None,
+    sa=False,
+    se=False,
+    bn_1st=True,
+    zero_bn=True,
+    stem_stride_on=0,
+    stem_sizes=[3, 32, 32, 64],
+    stem_pool=MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),
+    stem_bn_end=False
+)
+
+ + + +Now we have model constructor, default setting as xresnet18. And we can get model after call it. + +```python + +model = mc() +model +``` Sequential( MC (stem): Sequential( @@ -226,28 +259,22 @@ model -
- If you want to change model, just change constructor parameters. Lets create xresnet50. + ```python mc.expansion = 4 mc.layers = [3,4,6,3] ``` -Now we can look at model body and if we call constructor - we have pytorch model! +Now we can look at model parts - stem, body, head. + ```python -#collapse_output + mc.body ``` -
- Output details ... - - - - Sequential( (l_0): Sequential( (bl_0): ResBlock( @@ -592,68 +619,118 @@ mc.body -
+## Create constructor from config. + +Alternative we can create config first and than create constructor from it. + + +```python +from model_constructor import CfgMC +``` + + +```python +cfg = CfgMC() +print(cfg) +``` + + +
CfgMC(
+    name='MC',
+    in_chans=3,
+    num_classes=1000,
+    block=<class 'model_constructor.model_constructor.ResBlock'>,
+    conv_layer=<class 'model_constructor.layers.ConvBnAct'>,
+    block_sizes=[64, 128, 256, 512],
+    layers=[2, 2, 2, 2],
+    norm=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>,
+    act_fn=ReLU(inplace=True),
+    pool=AvgPool2d(kernel_size=2, stride=2, padding=0),
+    expansion=1,
+    groups=1,
+    dw=False,
+    div_groups=None,
+    sa=False,
+    se=False,
+    bn_1st=True,
+    zero_bn=True,
+    stem_stride_on=0,
+    stem_sizes=[32, 32, 64],
+    stem_pool=MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),
+    stem_bn_end=False
+)
+
+ + + +Now we can create constructor from config: + + +```python +mc = ModelConstructor.from_cfg(cfg) +mc.print_cfg() +``` + MC constructor + in_chans: 3, num_classes: 1000 + expansion: 1, groups: 1, dw: False, div_groups: None + sa: False, se: False + stem sizes: [3, 32, 32, 64], stride on 0 + body sizes [64, 128, 256, 512] + layers: [2, 2, 2, 2] + ## More modification. Main purpose of this module - fast and easy modify model. -And here is the link to more modification to beat Imagenette leaderboard with add MaxBlurPool and modification to ResBlock https://github.com/ayasyrev/imagenette_experiments/blob/master/ResnetTrick_create_model_fit.ipynb +And here is the link to more modification to beat Imagenette leaderboard with add MaxBlurPool and modification to ResBlock [notebook](https://github.com/ayasyrev/imagenette_experiments/blob/master/ResnetTrick_create_model_fit.ipynb) -But now lets create model as mxresnet50 from fastai forums tread https://forums.fast.ai/t/how-we-beat-the-5-epoch-imagewoof-leaderboard-score-some-new-techniques-to-consider +But now lets create model as mxresnet50 from [fastai forums tread](https://forums.fast.ai/t/how-we-beat-the-5-epoch-imagewoof-leaderboard-score-some-new-techniques-to-consider) Lets create mxresnet constructor. + ```python mc = ModelConstructor(name='MxResNet') ``` Then lets modify stem. + ```python mc.stem_sizes = [3,32,64,64] ``` Now lets change activation function to Mish. -Here is link to forum discussion https://forums.fast.ai/t/meet-mish-new-activation-function-possible-successor-to-relu +Here is link to [forum discussion](https://forums.fast.ai/t/meet-mish-new-activation-function-possible-successor-to-relu) We'v got Mish is in model_constructor.activations, but from pytorch 1.9 take it from torch: + ```python # from model_constructor.activations import Mish from torch.nn import Mish ``` + ```python mc.act_fn = Mish() ``` + ```python mc ``` + ModelConstructor(name='MxResNet', in_chans=3, num_classes=1000, block=, conv_layer=, block_sizes=[64, 128, 256, 512], layers=[2, 2, 2, 2], norm=, act_fn=Mish(), pool=AvgPool2d(kernel_size=2, stride=2, padding=0), expansion=1, groups=1, dw=False, div_groups=None, sa=False, se=False, bn_1st=True, zero_bn=True, stem_stride_on=0, stem_sizes=[3, 32, 64, 64], stem_pool=MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False), stem_bn_end=False) - - MxResNet constructor - in_chans: 3, num_classes: 1000 - expansion: 1, groups: 1, dw: False, div_groups: None - sa: False, se: False - stem sizes: [3, 32, 64, 64], stride on 0 - body sizes [64, 128, 256, 512] - layers: [2, 2, 2, 2] - +Here is model: ```python -#collapse_output + mc() ``` -
- Output details ... - - - - Sequential( MxResNet (stem): Sequential( @@ -826,12 +903,11 @@ mc() -
- -### MXResNet50 +## MXResNet50 Now lets make MxResNet50 + ```python mc.expansion = 4 mc.layers = [3,4,6,3] @@ -842,33 +918,45 @@ Now we have mxresnet50 constructor. We can inspect every parts of it. And after call it we got model. + ```python -mc +print(mc) ``` +
ModelConstructor(
+    name='mxresnet50',
+    in_chans=3,
+    num_classes=1000,
+    block=<class 'model_constructor.model_constructor.ResBlock'>,
+    conv_layer=<class 'model_constructor.layers.ConvBnAct'>,
+    block_sizes=[64, 128, 256, 512],
+    layers=[3, 4, 6, 3],
+    norm=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>,
+    act_fn=Mish(),
+    pool=AvgPool2d(kernel_size=2, stride=2, padding=0),
+    expansion=4,
+    groups=1,
+    dw=False,
+    div_groups=None,
+    sa=False,
+    se=False,
+    bn_1st=True,
+    zero_bn=True,
+    stem_stride_on=0,
+    stem_sizes=[3, 32, 64, 64],
+    stem_pool=MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),
+    stem_bn_end=False
+)
+
- mxresnet50 constructor - in_chans: 3, num_classes: 1000 - expansion: 4, groups: 1, dw: False, div_groups: None - sa: False, se: False - stem sizes: [3, 32, 64, 64], stride on 0 - body sizes [64, 128, 256, 512] - layers: [3, 4, 6, 3] - ```python -#collapse_output + mc.stem.conv_1 ``` -
- Output details ... - - - - ConvBnAct( (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) @@ -877,18 +965,11 @@ mc.stem.conv_1 -
```python -#collapse_output + mc.body.l_0.bl_0 ``` -
- Output details ... - - - - ResBlock( (convs): Sequential( (conv_0): ConvBnAct( @@ -917,40 +998,46 @@ mc.body.l_0.bl_0 -
- We can get model direct way: + ```python mc = ModelConstructor(name="MxResNet", act_fn=Mish(), layers=[3,4,6,3], expansion=4, stem_sizes=[32,64,64]) model = mc() ``` +Or create with config: + + +```python +mc = ModelConstructor.from_cfg( + CfgMC(name="MxResNet", act_fn=Mish(), layers=[3,4,6,3], expansion=4, stem_sizes=[32,64,64]) +) +model = mc() +``` + ## YaResNet Now lets change Resblock to YaResBlock (Yet another ResNet, former NewResBlock) is in lib from version 0.1.0 + ```python from model_constructor.yaresnet import YaResBlock ``` + ```python mc.block = YaResBlock ``` That all. Now we have YaResNet constructor + ```python -#collapse_output + mc.name = 'YaResNet' -mc +mc.print_cfg() ``` -
- Output details ... - - - - YaResNet constructor in_chans: 3, num_classes: 1000 expansion: 4, groups: 1, dw: False, div_groups: None @@ -960,21 +1047,13 @@ mc layers: [3, 4, 6, 3] - -
- Let see what we have. + ```python -#collapse_output + mc.body.l_1.bl_0 ``` -
- Output details ... - - - - YaResBlock( (reduce): AvgPool2d(kernel_size=2, stride=2, padding=0) (convs): Sequential( @@ -1001,34 +1080,3 @@ mc.body.l_1.bl_0 ) - -
- -## First version - -First version, it deprecated, but still here for compatibility. - -```python -from model_constructor.net import Net -``` - -```python -mc = Net() -``` - -```python -mc -``` - - - - - Net constructor - c_in: 3, c_out: 1000 - expansion: 1, groups: 1, dw: False, div_groups: None - sa: False, se: False - stem sizes: [3, 32, 32, 64], stride on 0 - body sizes [64, 128, 256, 512] - layers: [2, 2, 2, 2] - - diff --git a/docs/00_ModelConstructor.md b/docs/00_ModelConstructor.md index 6deb7c4..8c013e7 100644 --- a/docs/00_ModelConstructor.md +++ b/docs/00_ModelConstructor.md @@ -85,10 +85,11 @@ block ```python + block = ResBlock(2,64,64,act_fn=nn.LeakyReLU(), bn_1st=False) block ``` -???+ done "output" +??? done "output"
ResBlock(
       (convs): Sequential(
         (conv_0): ConvBnAct(
@@ -113,10 +114,11 @@ block
 
 
 ```python
-block = ResBlock(2, 32, 64, dw=True)
+
+lock = ResBlock(2, 32, 64, dw=True)
 block
 ```
-???+ done "output"  
+??? done "output"  
     
ResBlock(
       (convs): Sequential(
         (conv_0): ConvBnAct(
@@ -152,10 +154,11 @@ pool = nn.AvgPool2d(2, ceil_mode=True)
 
 
 ```python
+
 block = ResBlock(2, 32, 64, stride=2, dw=True, pool=pool)
 block
 ```
-???+ done "output"  
+??? done "output"  
     
ResBlock(
       (convs): Sequential(
         (conv_0): ConvBnAct(
@@ -216,9 +219,9 @@ block
         (se): SEModule(
           (squeeze): AdaptiveAvgPool2d(output_size=1)
           (excitation): Sequential(
-            (fc_reduce): Linear(in_features=128, out_features=8, bias=True)
+            (reduce): Linear(in_features=128, out_features=8, bias=True)
             (se_act): ReLU(inplace=True)
-            (fc_expand): Linear(in_features=8, out_features=128, bias=True)
+            (expand): Linear(in_features=8, out_features=128, bias=True)
             (se_gate): Sigmoid()
           )
         )
@@ -261,9 +264,9 @@ block
         (se): SEModule(
           (squeeze): AdaptiveAvgPool2d(output_size=1)
           (excitation): Sequential(
-            (fc_reduce): Linear(in_features=128, out_features=8, bias=True)
+            (reduce): Linear(in_features=128, out_features=8, bias=True)
             (se_act): ReLU(inplace=True)
-            (fc_expand): Linear(in_features=8, out_features=128, bias=True)
+            (expand): Linear(in_features=8, out_features=128, bias=True)
             (se_gate): Sigmoid()
           )
         )
@@ -283,12 +286,302 @@ block
 
 
 
+## Stem, Body, Layer, Head
+
+Helper functions to create stem, body and head of model from config.  
+Returns `nn.Sequential`.  
+
+
+```python
+from model_constructor.model_constructor import CfgMC, make_stem, make_body, make_layer, make_head
+from rich import print
+```
+
+
+```python
+cfg = CfgMC()
+print(cfg)
+```
+
+
+
CfgMC(
+    name='MC',
+    in_chans=3,
+    num_classes=1000,
+    block=<class 'model_constructor.model_constructor.ResBlock'>,
+    conv_layer=<class 'model_constructor.layers.ConvBnAct'>,
+    block_sizes=[64, 128, 256, 512],
+    layers=[2, 2, 2, 2],
+    norm=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>,
+    act_fn=ReLU(inplace=True),
+    pool=AvgPool2d(kernel_size=2, stride=2, padding=0),
+    expansion=1,
+    groups=1,
+    dw=False,
+    div_groups=None,
+    sa=False,
+    se=False,
+    bn_1st=True,
+    zero_bn=True,
+    stem_stride_on=0,
+    stem_sizes=[32, 32, 64],
+    stem_pool=MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),
+    stem_bn_end=False
+)
+
+ + + +### Stem + + +```python + +stem = make_stem(cfg) +stem +``` +??? done "output" +
Sequential(
+      (conv_0): ConvBnAct(
+        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
+        (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+        (act_fn): ReLU(inplace=True)
+      )
+      (conv_1): ConvBnAct(
+        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+        (act_fn): ReLU(inplace=True)
+      )
+      (stem_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
+    )
+
+
+
+### Layer
+
+`make_layer` need `layer_num` argument - number of layer.
+
+
+```python
+
+layer = make_layer(cfg, layer_num=0)
+layer
+```
+??? done "output"  
+    
Sequential(
+      (bl_0): ResBlock(
+        (convs): Sequential(
+          (conv_0): ConvBnAct(
+            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+            (act_fn): ReLU(inplace=True)
+          )
+          (conv_1): ConvBnAct(
+            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+          )
+        )
+        (act_fn): ReLU(inplace=True)
+      )
+      (bl_1): ResBlock(
+        (convs): Sequential(
+          (conv_0): ConvBnAct(
+            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+            (act_fn): ReLU(inplace=True)
+          )
+          (conv_1): ConvBnAct(
+            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+          )
+        )
+        (act_fn): ReLU(inplace=True)
+      )
+    )
+
+
+
+### Body
+
+`make_body` needs `cfg._make_layer` initialized. As default - `make_layer`.
+
+
+```python
+
+cfg._make_layer = make_layer
+body = make_body(cfg)
+body
+```
+??? done "output"  
+    
Sequential(
+      (l_0): Sequential(
+        (bl_0): ResBlock(
+          (convs): Sequential(
+            (conv_0): ConvBnAct(
+              (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+              (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+              (act_fn): ReLU(inplace=True)
+            )
+            (conv_1): ConvBnAct(
+              (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+              (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+            )
+          )
+          (act_fn): ReLU(inplace=True)
+        )
+        (bl_1): ResBlock(
+          (convs): Sequential(
+            (conv_0): ConvBnAct(
+              (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+              (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+              (act_fn): ReLU(inplace=True)
+            )
+            (conv_1): ConvBnAct(
+              (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+              (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+            )
+          )
+          (act_fn): ReLU(inplace=True)
+        )
+      )
+      (l_1): Sequential(
+        (bl_0): ResBlock(
+          (convs): Sequential(
+            (conv_0): ConvBnAct(
+              (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
+              (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+              (act_fn): ReLU(inplace=True)
+            )
+            (conv_1): ConvBnAct(
+              (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+              (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+            )
+          )
+          (id_conv): Sequential(
+            (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
+            (id_conv): ConvBnAct(
+              (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
+              (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+            )
+          )
+          (act_fn): ReLU(inplace=True)
+        )
+        (bl_1): ResBlock(
+          (convs): Sequential(
+            (conv_0): ConvBnAct(
+              (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+              (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+              (act_fn): ReLU(inplace=True)
+            )
+            (conv_1): ConvBnAct(
+              (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+              (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+            )
+          )
+          (act_fn): ReLU(inplace=True)
+        )
+      )
+      (l_2): Sequential(
+        (bl_0): ResBlock(
+          (convs): Sequential(
+            (conv_0): ConvBnAct(
+              (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
+              (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+              (act_fn): ReLU(inplace=True)
+            )
+            (conv_1): ConvBnAct(
+              (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+              (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+            )
+          )
+          (id_conv): Sequential(
+            (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
+            (id_conv): ConvBnAct(
+              (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
+              (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+            )
+          )
+          (act_fn): ReLU(inplace=True)
+        )
+        (bl_1): ResBlock(
+          (convs): Sequential(
+            (conv_0): ConvBnAct(
+              (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+              (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+              (act_fn): ReLU(inplace=True)
+            )
+            (conv_1): ConvBnAct(
+              (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+              (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+            )
+          )
+          (act_fn): ReLU(inplace=True)
+        )
+      )
+      (l_3): Sequential(
+        (bl_0): ResBlock(
+          (convs): Sequential(
+            (conv_0): ConvBnAct(
+              (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
+              (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+              (act_fn): ReLU(inplace=True)
+            )
+            (conv_1): ConvBnAct(
+              (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+              (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+            )
+          )
+          (id_conv): Sequential(
+            (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
+            (id_conv): ConvBnAct(
+              (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
+              (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+            )
+          )
+          (act_fn): ReLU(inplace=True)
+        )
+        (bl_1): ResBlock(
+          (convs): Sequential(
+            (conv_0): ConvBnAct(
+              (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+              (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+              (act_fn): ReLU(inplace=True)
+            )
+            (conv_1): ConvBnAct(
+              (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+              (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+            )
+          )
+          (act_fn): ReLU(inplace=True)
+        )
+      )
+    )
+
+
+
+## Head
+
+
+```python
+
+head = make_head(cfg)
+head
+```
+??? done "output"  
+    
Sequential(
+      (pool): AdaptiveAvgPool2d(output_size=1)
+      (flat): Flatten(start_dim=1, end_dim=-1)
+      (fc): Linear(in_features=512, out_features=1000, bias=True)
+    )
+
+
+
 ## Model Constructor.
 
 
 ```python
 mc  = ModelConstructor()
-mc
+mc.print_cfg()
 ```
 ???+ done "output"  
     
MC constructor
@@ -301,7 +594,6 @@ mc
 
 
 
-
 ```python
 
 mc.stem
diff --git a/docs/index.md b/docs/index.md
index 8906e80..e0ec490 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -26,9 +26,11 @@ from model_constructor import ModelConstructor
 mc = ModelConstructor()
 ```
 
+Check base parameters with `print_cfg` method:
+
 
 ```python
-mc
+mc.print_cfg()
 ```
 ???+ done "output"  
     
MC constructor
@@ -40,6 +42,43 @@ mc
       layers: [2, 2, 2, 2]
 
 
+ModelConstructor based on dataclass. Repr will show all parameters.  
+Better look at it with `rich.print`  
+
+
+```python
+from rich import print
+print(mc)
+```
+
+
+
ModelConstructor(
+    name='MC',
+    in_chans=3,
+    num_classes=1000,
+    block=<class 'model_constructor.model_constructor.ResBlock'>,
+    conv_layer=<class 'model_constructor.layers.ConvBnAct'>,
+    block_sizes=[64, 128, 256, 512],
+    layers=[2, 2, 2, 2],
+    norm=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>,
+    act_fn=ReLU(inplace=True),
+    pool=AvgPool2d(kernel_size=2, stride=2, padding=0),
+    expansion=1,
+    groups=1,
+    dw=False,
+    div_groups=None,
+    sa=False,
+    se=False,
+    bn_1st=True,
+    zero_bn=True,
+    stem_stride_on=0,
+    stem_sizes=[3, 32, 32, 64],
+    stem_pool=MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),
+    stem_bn_end=False
+)
+
+ + Now we have model constructor, default setting as xresnet18. And we can get model after call it. @@ -231,7 +270,7 @@ mc.expansion = 4 mc.layers = [3,4,6,3] ``` -Now we can look at model body and if we call constructor - we have pytorch model! +Now we can look at model parts - stem, body, head. ```python @@ -583,6 +622,67 @@ mc.body +## Create constructor from config. + +Alternative we can create config first and than create constructor from it. + + +```python +from model_constructor import CfgMC +``` + + +```python +cfg = CfgMC() +print(cfg) +``` + + +
CfgMC(
+    name='MC',
+    in_chans=3,
+    num_classes=1000,
+    block=<class 'model_constructor.model_constructor.ResBlock'>,
+    conv_layer=<class 'model_constructor.layers.ConvBnAct'>,
+    block_sizes=[64, 128, 256, 512],
+    layers=[2, 2, 2, 2],
+    norm=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>,
+    act_fn=ReLU(inplace=True),
+    pool=AvgPool2d(kernel_size=2, stride=2, padding=0),
+    expansion=1,
+    groups=1,
+    dw=False,
+    div_groups=None,
+    sa=False,
+    se=False,
+    bn_1st=True,
+    zero_bn=True,
+    stem_stride_on=0,
+    stem_sizes=[32, 32, 64],
+    stem_pool=MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),
+    stem_bn_end=False
+)
+
+ + + +Now we can create constructor from config: + + +```python +mc = ModelConstructor.from_cfg(cfg) +mc.print_cfg() +``` +???+ done "output" +
MC constructor
+      in_chans: 3, num_classes: 1000
+      expansion: 1, groups: 1, dw: False, div_groups: None
+      sa: False, se: False
+      stem sizes: [3, 32, 32, 64], stride on 0
+      body sizes [64, 128, 256, 512]
+      layers: [2, 2, 2, 2]
+
+
 ## More modification.
 
 Main purpose of this module - fast and easy modify model.
@@ -625,15 +725,11 @@ mc.act_fn = Mish()
 mc
 ```
 ???+ done "output"  
-    
MxResNet constructor
-      in_chans: 3, num_classes: 1000
-      expansion: 1, groups: 1, dw: False, div_groups: None
-      sa: False, se: False
-      stem sizes: [3, 32, 64, 64], stride on 0
-      body sizes [64, 128, 256, 512]
-      layers: [2, 2, 2, 2]
+    
ModelConstructor(name='MxResNet', in_chans=3, num_classes=1000, block=, conv_layer=, block_sizes=[64, 128, 256, 512], layers=[2, 2, 2, 2], norm=, act_fn=Mish(), pool=AvgPool2d(kernel_size=2, stride=2, padding=0), expansion=1, groups=1, dw=False, div_groups=None, sa=False, se=False, bn_1st=True, zero_bn=True, stem_stride_on=0, stem_sizes=[3, 32, 64, 64], stem_pool=MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False), stem_bn_end=False)
+
 
 
+Here is model:  
 
 
 ```python
@@ -813,7 +909,7 @@ mc()
 
 
 
-### MXResNet50
+## MXResNet50
 
 Now lets make MxResNet50
 
@@ -830,16 +926,35 @@ And after call it we got model.
 
 
 ```python
-mc
+print(mc)
 ```
-???+ done "output"  
-    
mxresnet50 constructor
-      in_chans: 3, num_classes: 1000
-      expansion: 4, groups: 1, dw: False, div_groups: None
-      sa: False, se: False
-      stem sizes: [3, 32, 64, 64], stride on 0
-      body sizes [64, 128, 256, 512]
-      layers: [3, 4, 6, 3]
+
+
+
ModelConstructor(
+    name='mxresnet50',
+    in_chans=3,
+    num_classes=1000,
+    block=<class 'model_constructor.model_constructor.ResBlock'>,
+    conv_layer=<class 'model_constructor.layers.ConvBnAct'>,
+    block_sizes=[64, 128, 256, 512],
+    layers=[3, 4, 6, 3],
+    norm=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>,
+    act_fn=Mish(),
+    pool=AvgPool2d(kernel_size=2, stride=2, padding=0),
+    expansion=4,
+    groups=1,
+    dw=False,
+    div_groups=None,
+    sa=False,
+    se=False,
+    bn_1st=True,
+    zero_bn=True,
+    stem_stride_on=0,
+    stem_sizes=[3, 32, 64, 64],
+    stem_pool=MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),
+    stem_bn_end=False
+)
+
@@ -899,6 +1014,16 @@ mc = ModelConstructor(name="MxResNet", act_fn=Mish(), layers=[3,4,6,3], expansio model = mc() ``` +Or create with config: + + +```python +mc = ModelConstructor.from_cfg( + CfgMC(name="MxResNet", act_fn=Mish(), layers=[3,4,6,3], expansion=4, stem_sizes=[32,64,64]) +) +model = mc() +``` + ## YaResNet Now lets change Resblock to YaResBlock (Yet another ResNet, former NewResBlock) is in lib from version 0.1.0 @@ -919,7 +1044,7 @@ That all. Now we have YaResNet constructor ```python mc.name = 'YaResNet' -mc +mc.print_cfg() ``` ??? done "output"
YaResNet constructor
@@ -931,7 +1056,6 @@ mc
       layers: [3, 4, 6, 3]
 
 
-
 Let see what we have.
 
 
@@ -966,32 +1090,3 @@ mc.body.l_1.bl_0
     )
 
 
-
-## First version
-
-First version, it deprecated, but still here for compatibility.
-
-
-```python
-from model_constructor.net import Net
-```
-
-
-```python
-mc = Net()
-```
-
-
-```python
-mc
-```
-???+ done "output"  
-    
Net constructor
-      c_in: 3, c_out: 1000
-      expansion: 1, groups: 1, dw: False, div_groups: None
-      sa: False, se: False
-      stem sizes: [3, 32, 32, 64], stride on 0
-      body sizes [64, 128, 256, 512]
-      layers: [2, 2, 2, 2]
-
-
diff --git a/src/model_constructor/__init__.py b/src/model_constructor/__init__.py
index c93a30b..d720178 100644
--- a/src/model_constructor/__init__.py
+++ b/src/model_constructor/__init__.py
@@ -1,4 +1,4 @@
 from model_constructor.convmixer import ConvMixer  # noqa F401
-from model_constructor.model_constructor import ModelConstructor, ResBlock # noqa F401
+from model_constructor.model_constructor import ModelConstructor, ResBlock, CfgMC # noqa F401
                                                  
 from model_constructor.version import __version__  # noqa F401
diff --git a/src/model_constructor/model_constructor.py b/src/model_constructor/model_constructor.py
index 09922e5..52064af 100644
--- a/src/model_constructor/model_constructor.py
+++ b/src/model_constructor/model_constructor.py
@@ -1,6 +1,8 @@
+from dataclasses import dataclass, field, asdict
+
 from collections import OrderedDict
-from functools import partial
-from typing import Callable, List, Type, Union
+# from functools import partial
+from typing import Callable, List, Optional, Type, Union
 
 import torch.nn as nn
 
@@ -12,24 +14,14 @@
     "act_fn",
     "ResBlock",
     "ModelConstructor",
-    "xresnet34",
-    "xresnet50",
+    # "xresnet34",
+    # "xresnet50",
 ]
 
 
 act_fn = nn.ReLU(inplace=True)
 
 
-def init_cnn(module: nn.Module):
-    "Init module - kaiming_normal for Conv2d and 0 for biases."
-    if getattr(module, "bias", None) is not None:
-        nn.init.constant_(module.bias, 0)  # type: ignore
-    if isinstance(module, (nn.Conv2d, nn.Linear)):
-        nn.init.kaiming_normal_(module.weight)
-    for layer in module.children():
-        init_cnn(layer)
-
-
 class ResBlock(nn.Module):
     """Universal Resnet block. Basic block if expansion is 1, otherwise is Bottleneck."""
 
@@ -130,10 +122,55 @@ def forward(self, x):
         return self.act_fn(self.convs(x) + identity)
 
 
-def _make_stem(self):
-    stem = [
+@dataclass
+class CfgMC:
+    """Model constructor Config. As default - xresnet18"""
+
+    name: str = "MC"
+    in_chans: int = 3
+    num_classes: int = 1000
+    block: Type[nn.Module] = ResBlock
+    conv_layer: Type[nn.Module] = ConvBnAct
+    block_sizes: List[int] = field(default_factory=lambda: [64, 128, 256, 512])
+    layers: List[int] = field(default_factory=lambda: [2, 2, 2, 2])
+    norm: Type[nn.Module] = nn.BatchNorm2d
+    act_fn: nn.Module = nn.ReLU(inplace=True)
+    pool: nn.Module = nn.AvgPool2d(2, ceil_mode=True)
+    expansion: int = 1
+    groups: int = 1
+    dw: bool = False
+    div_groups: Union[int, None] = None
+    sa: Union[bool, int, Type[nn.Module]] = False
+    se: Union[bool, int, Type[nn.Module]] = False
+    se_module = None
+    se_reduction = None
+    bn_1st: bool = True
+    zero_bn: bool = True
+    stem_stride_on: int = 0
+    stem_sizes: List[int] = field(default_factory=lambda: [32, 32, 64])
+    stem_pool: Union[nn.Module, None] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # type: ignore
+    stem_bn_end: bool = False
+    _init_cnn: Optional[Callable[[nn.Module], None]] = field(repr=False, default=None)
+    _make_stem: Optional[Callable] = field(repr=False, default=None)
+    _make_layer: Optional[Callable] = field(repr=False, default=None)
+    _make_body: Optional[Callable] = field(repr=False, default=None)
+    _make_head: Optional[Callable] = field(repr=False, default=None)
+
+
+def init_cnn(module: nn.Module):
+    "Init module - kaiming_normal for Conv2d and 0 for biases."
+    if getattr(module, "bias", None) is not None:
+        nn.init.constant_(module.bias, 0)  # type: ignore
+    if isinstance(module, (nn.Conv2d, nn.Linear)):
+        nn.init.kaiming_normal_(module.weight)
+    for layer in module.children():
+        init_cnn(layer)
+
+
+def make_stem(self: CfgMC) -> nn.Sequential:
+    stem: List[tuple[str, nn.Module]] = [
         (f"conv_{i}", self.conv_layer(
-            self.stem_sizes[i],
+            self.stem_sizes[i],  # type: ignore
             self.stem_sizes[i + 1],
             stride=2 if i == self.stem_stride_on else 1,
             bn_layer=(not self.stem_bn_end)
@@ -147,39 +184,38 @@ def _make_stem(self):
     if self.stem_pool:
         stem.append(("stem_pool", self.stem_pool))
     if self.stem_bn_end:
-        stem.append(("norm", self.norm(self.stem_sizes[-1])))
+        stem.append(("norm", self.norm(self.stem_sizes[-1])))  # type: ignore
     return nn.Sequential(OrderedDict(stem))
 
 
-def _make_layer(self, layer_num: int) -> nn.Module:
+def make_layer(cfg: CfgMC, layer_num: int) -> nn.Sequential:
     #  expansion, in_channels, out_channels, blocks, stride, sa):
     # if no pool on stem - stride = 2 for first layer block in body
-    stride = 1 if self.stem_pool and layer_num == 0 else 2
-    num_blocks = self.layers[layer_num]
+    stride = 1 if cfg.stem_pool and layer_num == 0 else 2
+    num_blocks = cfg.layers[layer_num]
+    block_chs = [cfg.stem_sizes[-1] // cfg.expansion] + cfg.block_sizes
     return nn.Sequential(
         OrderedDict(
             [
                 (
                     f"bl_{block_num}",
-                    self.block(
-                        self.expansion,
-                        self.block_sizes[layer_num]
-                        if block_num == 0
-                        else self.block_sizes[layer_num + 1],
-                        self.block_sizes[layer_num + 1],
+                    cfg.block(
+                        cfg.expansion,  # type: ignore
+                        block_chs[layer_num] if block_num == 0 else block_chs[layer_num + 1],
+                        block_chs[layer_num + 1],
                         stride if block_num == 0 else 1,
-                        sa=self.sa
+                        sa=cfg.sa
                         if (block_num == num_blocks - 1) and layer_num == 0
                         else None,
-                        conv_layer=self.conv_layer,
-                        act_fn=self.act_fn,
-                        pool=self.pool,
-                        zero_bn=self.zero_bn,
-                        bn_1st=self.bn_1st,
-                        groups=self.groups,
-                        div_groups=self.div_groups,
-                        dw=self.dw,
-                        se=self.se,
+                        conv_layer=cfg.conv_layer,
+                        act_fn=cfg.act_fn,
+                        pool=cfg.pool,
+                        zero_bn=cfg.zero_bn,
+                        bn_1st=cfg.bn_1st,
+                        groups=cfg.groups,
+                        div_groups=cfg.div_groups,
+                        dw=cfg.dw,
+                        se=cfg.se,
                     ),
                 )
                 for block_num in range(num_blocks)
@@ -188,160 +224,96 @@ def _make_layer(self, layer_num: int) -> nn.Module:
     )
 
 
-def _make_body(self):
+def make_body(cfg: CfgMC) -> nn.Sequential:
     return nn.Sequential(
         OrderedDict(
             [
                 (
                     f"l_{layer_num}",
-                    self._make_layer(self, layer_num)
+                    cfg._make_layer(cfg, layer_num)  # type: ignore
                 )
-                for layer_num in range(len(self.layers))
+                for layer_num in range(len(cfg.layers))
             ]
         )
     )
 
 
-def _make_head(self):
+def make_head(cfg: CfgMC) -> nn.Sequential:
     head = [
         ("pool", nn.AdaptiveAvgPool2d(1)),
         ("flat", nn.Flatten()),
-        ("fc", nn.Linear(self.block_sizes[-1] * self.expansion, self.num_classes)),
+        ("fc", nn.Linear(cfg.block_sizes[-1] * cfg.expansion, cfg.num_classes)),
     ]
     return nn.Sequential(OrderedDict(head))
 
 
-class ModelConstructor:
+@dataclass
+class ModelConstructor(CfgMC):
     """Model constructor. As default - xresnet18"""
 
-    def __init__(
-        self,
-        name: str = "MC",
-        in_chans: int = 3,
-        num_classes: int = 1000,
-        block=ResBlock,
-        conv_layer=ConvBnAct,
-        block_sizes: List[int] = [64, 128, 256, 512],
-        layers: List[int] = [2, 2, 2, 2],
-        norm: Type[nn.Module] = nn.BatchNorm2d,
-        act_fn: nn.Module = nn.ReLU(inplace=True),
-        pool: nn.Module = nn.AvgPool2d(2, ceil_mode=True),
-        expansion: int = 1,
-        groups: int = 1,
-        dw: bool = False,
-        div_groups: Union[int, None] = None,
-        sa: Union[bool, int, Type[nn.Module]] = False,
-        se: Union[bool, int, Type[nn.Module]] = False,
-        se_module=None,
-        se_reduction=None,
-        bn_1st: bool = True,
-        zero_bn: bool = True,
-        stem_stride_on: int = 0,
-        stem_sizes: List[int] = [32, 32, 64],
-        stem_pool: Union[Type[nn.Module], None] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1),  # type: ignore
-        stem_bn_end: bool = False,
-        _init_cnn: Callable = init_cnn,
-        _make_stem: Callable = _make_stem,
-        _make_layer: Callable = _make_layer,
-        _make_body: Callable = _make_body,
-        _make_head: Callable = _make_head,
-    ):
-        super().__init__()
-        # se can be bool, int (0, 1) or nn.Module
-        # se_module - deprecated. Leaved for warning and checks.
-        # if stem_pool is False - no pool at stem
-
-        self.name = name
-        self.in_chans = in_chans
-        self.num_classes = num_classes
-        self.block = block
-        self.conv_layer = conv_layer
-        self._block_sizes = block_sizes
-        self.layers = layers
-        self.norm = norm
-        self.act_fn = act_fn
-        self.pool = pool
-        self.expansion = expansion
-        self.groups = groups
-        self.dw = dw
-        self.div_groups = div_groups
-        # se_module
-        # se_reduction
-        self.bn_1st = bn_1st
-        self.zero_bn = zero_bn
-        self.stem_stride_on = stem_stride_on
-        self.stem_pool = stem_pool
-        self.stem_bn_end = stem_bn_end
-        self._init_cnn = _init_cnn
-        self._make_stem = _make_stem
-        self._make_layer = _make_layer
-        self._make_body = _make_body
-        self._make_head = _make_head
-
-        # params = locals()
-        # del params['self']
-        # self.__dict__ = params
-
-        # self._block_sizes = params['block_sizes']
-        self.stem_sizes = stem_sizes
+    def __post_init__(self):
+        if self._init_cnn is None:
+            self._init_cnn = init_cnn
+        if self._make_stem is None:
+            self._make_stem = make_stem
+        if self._make_layer is None:
+            self._make_layer = make_layer
+        if self._make_body is None:
+            self._make_body = make_body
+        if self._make_head is None:
+            self._make_head = make_head
+
         if self.stem_sizes[0] != self.in_chans:
             self.stem_sizes = [self.in_chans] + self.stem_sizes
-        self.se = se
-        if self.se:
-            if type(self.se) in (bool, int):  # if se=1 or se=True
-                self.se = SEModule
-            else:
-                self.se = se  # TODO add check issubclass or isinstance of nn.Module
-        self.sa = sa
-        if self.sa:  # if sa=1 or sa=True
-            if type(self.sa) in (bool, int):
-                self.sa = SimpleSelfAttention  # default: ks=1, sym=sym
-            else:
-                self.sa = sa
-        if se_module or se_reduction:  # pragma: no cover
+        if self.se and isinstance(self.se, (bool, int)):  # if se=1 or se=True
+            self.se = SEModule
+        if self.sa and isinstance(self.sa, (bool, int)):  # if sa=1 or sa=True
+            self.sa = SimpleSelfAttention  # default: ks=1, sym=sym
+        if self.se_module or self.se_reduction:  # pragma: no cover
             print(
                 "Deprecated. Pass se_module as se argument, se_reduction as arg to se."
             )  # add deprecation warning.
 
-    @property
-    def block_sizes(self):
-        return [self.stem_sizes[-1] // self.expansion] + self._block_sizes
-
     @property
     def stem(self):
-        return self._make_stem(self)
+        return self._make_stem(self)  # type: ignore
 
     @property
     def head(self):
-        return self._make_head(self)
+        return self._make_head(self)  # type: ignore
 
     @property
     def body(self):
-        return self._make_body(self)
+        return self._make_body(self)  # type: ignore
+
+    @classmethod
+    def from_cfg(cls, cfg: CfgMC):
+        return cls(**asdict(cfg))
 
     def __call__(self):
         model = nn.Sequential(
             OrderedDict([("stem", self.stem), ("body", self.body), ("head", self.head)])
         )
-        self._init_cnn(model)
+        self._init_cnn(model)  # type: ignore
         model.extra_repr = lambda: f"{self.name}"
         return model
 
-    def __repr__(self):
-        return (
+    def print_cfg(self):
+        print(
             f"{self.name} constructor\n"
             f"  in_chans: {self.in_chans}, num_classes: {self.num_classes}\n"
             f"  expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}, div_groups: {self.div_groups}\n"
             f"  sa: {self.sa}, se: {self.se}\n"
             f"  stem sizes: {self.stem_sizes}, stride on {self.stem_stride_on}\n"
-            f"  body sizes {self._block_sizes}\n"
+            f"  body sizes {self.block_sizes}\n"
             f"  layers: {self.layers}"
         )
 
 
-xresnet34 = partial(
-    ModelConstructor, name="xresnet34", expansion=1, layers=[3, 4, 6, 3]
+xresnet34 = ModelConstructor.from_cfg(
+    CfgMC(name="xresnet34", expansion=1, layers=[3, 4, 6, 3])
 )
-xresnet50 = partial(
-    ModelConstructor, name="xresnet34", expansion=4, layers=[3, 4, 6, 3]
+
+xresnet50 = ModelConstructor.from_cfg(
+    CfgMC(name="xresnet34", expansion=4, layers=[3, 4, 6, 3])
 )
diff --git a/tests/test_mc.py b/tests/test_mc.py
index e92ccdc..60d739a 100644
--- a/tests/test_mc.py
+++ b/tests/test_mc.py
@@ -11,11 +11,15 @@ def test_MC():
     """test ModelConstructor"""
     img_size = 16
     mc = ModelConstructor()
-    assert "MC constructor" in str(mc)
+    assert "name='MC'" in str(mc)
     model = mc()
     xb = torch.randn(bs_test, 3, img_size, img_size)
     pred = model(xb)
     assert pred.shape == torch.Size([bs_test, 1000])
+    mc.expansion = 2
+    model = mc()
+    pred = model(xb)
+    assert pred.shape == torch.Size([bs_test, 1000])
     num_classes = 10
     mc.num_classes = num_classes
     mc.se = SEModule