diff --git a/Nbs/01_layers.ipynb b/Nbs/01_layers.ipynb index 3b95728..d52275f 100644 --- a/Nbs/01_layers.ipynb +++ b/Nbs/01_layers.ipynb @@ -2,261 +2,261 @@ "cells": [ { "cell_type": "markdown", + "metadata": {}, "source": [ "# Layers\n", "\n", "> Basic layers for constructor." - ], - "metadata": {} + ] }, { "cell_type": "code", "execution_count": 1, - "source": [ - "#hide\n", - "from nbdev.showdoc import show_doc\n", - "from IPython.display import Markdown, display" - ], - "outputs": [], "metadata": { "ExecuteTime": { "end_time": "2021-08-11T16:34:43.007417Z", "start_time": "2021-08-11T16:34:42.829843Z" } - } + }, + "outputs": [], + "source": [ + "#hide\n", + "from nbdev.showdoc import show_doc\n", + "from IPython.display import Markdown, display" + ] }, { "cell_type": "code", "execution_count": 2, - "source": [ - "#hide\n", - "import torch.nn as nn\n", - "import torch\n", - "from collections import OrderedDict" - ], - "outputs": [], "metadata": { "ExecuteTime": { "end_time": "2021-08-11T16:34:45.654722Z", "start_time": "2021-08-11T16:34:45.085125Z" } - } + }, + "outputs": [], + "source": [ + "#hide\n", + "import torch.nn as nn\n", + "import torch\n", + "from collections import OrderedDict" + ] }, { "cell_type": "code", "execution_count": 3, + "metadata": {}, + "outputs": [], "source": [ "# hide\n", "def print_doc(func_name):\n", " doc = show_doc(func_name, title_level=4, disp=False)\n", " display(Markdown(doc))" - ], - "outputs": [], - "metadata": {} + ] }, { "cell_type": "markdown", + "metadata": {}, "source": [ "## Flatten layer" - ], - "metadata": {} + ] }, { "cell_type": "code", "execution_count": 4, - "source": [ - "#hide\n", - "from model_constructor.layers import Flatten, noop, Noop" - ], - "outputs": [], "metadata": { "ExecuteTime": { "end_time": "2021-08-11T16:34:45.996818Z", "start_time": "2021-08-11T16:34:45.982654Z" } - } + }, + "outputs": [], + "source": [ + "#hide\n", + "from model_constructor.layers import Flatten, noop, Noop" + ] }, { "cell_type": 
"code", "execution_count": 5, - "source": [ - "#hide_input\n", - "# show_doc(Flatten, title_level=4)\n", - "# flatten_doc = show_doc(Flatten, title_level=4, disp=False)\n", - "# Markdown(flatten_doc)\n", - "\n", - "print_doc(Flatten)" - ], + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T16:34:54.998866Z", + "start_time": "2021-08-11T16:34:54.984142Z" + } + }, "outputs": [ { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/markdown": [ "

class Flatten[source]

\n", "\n", "> Flatten() :: `Module`\n", "\n", "flat x to vector" + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" } ], - "metadata": { - "ExecuteTime": { - "end_time": "2021-08-11T16:34:54.998866Z", - "start_time": "2021-08-11T16:34:54.984142Z" - } - } + "source": [ + "#hide_input\n", + "# show_doc(Flatten, title_level=4)\n", + "# flatten_doc = show_doc(Flatten, title_level=4, disp=False)\n", + "# Markdown(flatten_doc)\n", + "\n", + "print_doc(Flatten)" + ] }, { "cell_type": "markdown", + "metadata": {}, "source": [ "## Noop - dummy func and module." - ], - "metadata": {} + ] }, { "cell_type": "code", "execution_count": 6, - "source": [ - "#hide_input\n", - "# show_doc(noop)\n", - "# doc = show_doc(noop, disp=False)\n", - "# display(Markdown(doc))\n", - "print_doc(noop)" - ], + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:55.119562Z", + "start_time": "2021-08-11T15:44:55.103484Z" + } + }, "outputs": [ { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/markdown": [ "

noop[source]

\n", "\n", "> noop(**`x`**)\n", "\n", "Dummy func. Return input" + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" } ], - "metadata": { - "ExecuteTime": { - "end_time": "2021-08-11T15:44:55.119562Z", - "start_time": "2021-08-11T15:44:55.103484Z" - } - } + "source": [ + "#hide_input\n", + "# show_doc(noop)\n", + "# doc = show_doc(noop, disp=False)\n", + "# display(Markdown(doc))\n", + "print_doc(noop)" + ] }, { "cell_type": "code", "execution_count": 7, - "source": [ - "#hide_input\n", - "# show_doc(Noop, title_level=4)\n", - "# doc = show_doc(Noop, title_level=4, disp=False)\n", - "# Markdown(doc)\n", - "print_doc(Noop)" - ], + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T16:35:14.225068Z", + "start_time": "2021-08-11T16:35:14.208899Z" + } + }, "outputs": [ { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/markdown": [ "

class Noop[source]

\n", "\n", "> Noop() :: `Module`\n", "\n", "Dummy module" + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" } ], - "metadata": { - "ExecuteTime": { - "end_time": "2021-08-11T16:35:14.225068Z", - "start_time": "2021-08-11T16:35:14.208899Z" - } - } + "source": [ + "#hide_input\n", + "# show_doc(Noop, title_level=4)\n", + "# doc = show_doc(Noop, title_level=4, disp=False)\n", + "# Markdown(doc)\n", + "print_doc(Noop)" + ] }, { "cell_type": "markdown", + "metadata": {}, "source": [ "## ConvLayer - nn.module" - ], - "metadata": {} + ] }, { "cell_type": "code", "execution_count": 8, - "source": [ - "#hide\n", - "from model_constructor.layers import ConvLayer" - ], - "outputs": [], "metadata": { "ExecuteTime": { "end_time": "2021-08-11T16:35:21.435560Z", "start_time": "2021-08-11T16:35:21.430089Z" } - } + }, + "outputs": [], + "source": [ + "#hide\n", + "from model_constructor.layers import ConvLayer" + ] }, { "cell_type": "code", "execution_count": 9, - "source": [ - "# hide_input\n", - "# show_doc(ConvLayer, title_level=4)\n", - "# doc = show_doc(ConvLayer, title_level=4, disp=False)\n", - "# Markdown(doc)\n", - "print_doc(ConvLayer)" - ], + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T16:35:29.238261Z", + "start_time": "2021-08-11T16:35:29.229345Z" + } + }, "outputs": [ { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/markdown": [ "

class ConvLayer[source]

\n", "\n", "> ConvLayer(**`ni`**, **`nf`**, **`ks`**=*`3`*, **`stride`**=*`1`*, **`act`**=*`True`*, **`act_fn`**=*`ReLU(inplace=True)`*, **`bn_layer`**=*`True`*, **`bn_1st`**=*`True`*, **`zero_bn`**=*`False`*, **`padding`**=*`None`*, **`bias`**=*`False`*, **`groups`**=*`1`*, **\\*\\*`kwargs`**) :: `Sequential`\n", "\n", "Basic conv layers block" + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" } ], - "metadata": { - "ExecuteTime": { - "end_time": "2021-08-11T16:35:29.238261Z", - "start_time": "2021-08-11T16:35:29.229345Z" - } - } + "source": [ + "# hide_input\n", + "# show_doc(ConvLayer, title_level=4)\n", + "# doc = show_doc(ConvLayer, title_level=4, disp=False)\n", + "# Markdown(doc)\n", + "print_doc(ConvLayer)" + ] }, { "cell_type": "code", "execution_count": 10, - "source": [ - "#collapse_output\n", - "conv_layer = ConvLayer(32, 64, act=False)\n", - "conv_layer" - ], + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:56.242450Z", + "start_time": "2021-08-11T15:44:56.231392Z" + } + }, "outputs": [ { - "output_type": "execute_result", "data": { "text/plain": [ "ConvLayer(\n", @@ -265,28 +265,28 @@ ")" ] }, + "execution_count": 10, "metadata": {}, - "execution_count": 10 + "output_type": "execute_result" } ], - "metadata": { - "ExecuteTime": { - "end_time": "2021-08-11T15:44:56.242450Z", - "start_time": "2021-08-11T15:44:56.231392Z" - } - } + "source": [ + "#collapse_output\n", + "conv_layer = ConvLayer(32, 64, act=False)\n", + "conv_layer" + ] }, { "cell_type": "code", "execution_count": 11, - "source": [ - "#collapse_output\n", - "conv_layer = ConvLayer(32, 64, bn_layer=False)\n", - "conv_layer" - ], + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:56.462424Z", + "start_time": "2021-08-11T15:44:56.450423Z" + } + }, "outputs": [ { - "output_type": "execute_result", "data": { "text/plain": [ "ConvLayer(\n", @@ -295,28 +295,28 @@ ")" ] }, + "execution_count": 11, "metadata": {}, - 
"execution_count": 11 + "output_type": "execute_result" } ], - "metadata": { - "ExecuteTime": { - "end_time": "2021-08-11T15:44:56.462424Z", - "start_time": "2021-08-11T15:44:56.450423Z" - } - } + "source": [ + "#collapse_output\n", + "conv_layer = ConvLayer(32, 64, bn_layer=False)\n", + "conv_layer" + ] }, { "cell_type": "code", "execution_count": 12, - "source": [ - "#collapse_output\n", - "conv_layer = ConvLayer(32, 64, bn_1st=True)\n", - "conv_layer" - ], + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:56.667542Z", + "start_time": "2021-08-11T15:44:56.660962Z" + } + }, "outputs": [ { - "output_type": "execute_result", "data": { "text/plain": [ "ConvLayer(\n", @@ -326,28 +326,28 @@ ")" ] }, + "execution_count": 12, "metadata": {}, - "execution_count": 12 + "output_type": "execute_result" } ], - "metadata": { - "ExecuteTime": { - "end_time": "2021-08-11T15:44:56.667542Z", - "start_time": "2021-08-11T15:44:56.660962Z" - } - } + "source": [ + "#collapse_output\n", + "conv_layer = ConvLayer(32, 64, bn_1st=True)\n", + "conv_layer" + ] }, { "cell_type": "code", "execution_count": 13, - "source": [ - "#collapse_output\n", - "conv_layer = ConvLayer(32, 64, bn_1st=True, act_fn=nn.LeakyReLU())\n", - "conv_layer" - ], + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:56.905543Z", + "start_time": "2021-08-11T15:44:56.900817Z" + } + }, "outputs": [ { - "output_type": "execute_result", "data": { "text/plain": [ "ConvLayer(\n", @@ -357,20 +357,27 @@ ")" ] }, + "execution_count": 13, "metadata": {}, - "execution_count": 13 + "output_type": "execute_result" } ], - "metadata": { - "ExecuteTime": { - "end_time": "2021-08-11T15:44:56.905543Z", - "start_time": "2021-08-11T15:44:56.900817Z" - } - } - }, + "source": [ + "#collapse_output\n", + "conv_layer = ConvLayer(32, 64, bn_1st=True, act_fn=nn.LeakyReLU())\n", + "conv_layer" + ] + }, { "cell_type": "code", "execution_count": 14, + "metadata": { + "ExecuteTime": { + "end_time": 
"2021-08-11T15:44:57.154044Z", + "start_time": "2021-08-11T15:44:57.115266Z" + } + }, + "outputs": [], "source": [ "#hide\n", "bs = 8\n", @@ -378,317 +385,903 @@ "y = conv_layer(xb)\n", "y.shape\n", "assert y.shape == torch.Size([bs, 64, 32, 32])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## SimpleSelfAttention" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "SA module from mxresnet at fastai." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "# hide\n", + "from model_constructor.layers import conv1d, SimpleSelfAttention" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "

conv1d[source]

\n", + "\n", + "> conv1d(**`ni`**:`int`, **`no`**:`int`, **`ks`**:`int`=*`1`*, **`stride`**:`int`=*`1`*, **`padding`**:`int`=*`0`*, **`bias`**:`bool`=*`False`*)\n", + "\n", + "Create and initialize a `nn.Conv1d` layer with spectral normalization." + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } ], + "source": [ + "#hide_input\n", + "# doc = show_doc(conv1d, title_level=4, disp=False)\n", + "# Markdown(doc)\n", + "print_doc(conv1d)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "

class SimpleSelfAttention[source]

\n", + "\n", + "> SimpleSelfAttention(**`n_in`**:`int`, **`ks`**=*`1`*, **`sym`**=*`False`*) :: `Module`\n", + "\n", + "SimpleSelfAttention module. # noqa W291\n", + "Adapted from SelfAttention layer at \n", + "https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py \n", + "Inspired by https://arxiv.org/pdf/1805.08318.pdf " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# doc = show_doc(SimpleSelfAttention, title_level=4, disp=False)\n", + "# Markdown(doc)\n", + "print_doc(SimpleSelfAttention)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## SEModule" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, "outputs": [], + "source": [ + "#hide\n", + "from model_constructor.layers import SEModule" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "

class SEModule[source]

\n", + "\n", + "> SEModule(**`channels`**, **`reduction`**=*`16`*, **`rd_channels`**=*`None`*, **`rd_max`**=*`False`*, **`se_layer`**=*`Linear`*, **`act_fn`**=*`ReLU(inplace=True)`*, **`use_bias`**=*`True`*, **`gate`**=*`Sigmoid`*) :: `Module`\n", + "\n", + "se block" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# hide_input\n", + "# Markdown(show_doc(SEBlock, title_level=4))\n", + "print_doc(SEModule)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, "metadata": { "ExecuteTime": { - "end_time": "2021-08-11T15:44:57.154044Z", - "start_time": "2021-08-11T15:44:57.115266Z" + "end_time": "2021-08-11T15:44:58.343217Z", + "start_time": "2021-08-11T15:44:58.332669Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "SEModule(\n", + " (squeeze): AdaptiveAvgPool2d(output_size=1)\n", + " (excitation): Sequential(\n", + " (fc_reduce): Linear(in_features=128, out_features=8, bias=True)\n", + " (se_act): ReLU(inplace=True)\n", + " (fc_expand): Linear(in_features=8, out_features=128, bias=True)\n", + " (se_gate): Sigmoid()\n", + " )\n", + ")" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#collapse_output\n", + "se_block = SEModule(128)\n", + "se_block" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:58.570883Z", + "start_time": "2021-08-11T15:44:58.530689Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 128, 32, 32])\n" + ] + } + ], + "source": [ + "#hide\n", + "bs = 8\n", + "xb = torch.randn(bs, 128, 32, 32)\n", + "y = se_block(xb)\n", + "print(y.shape)\n", + "assert y.shape == torch.Size([bs, 128, 32, 32]), f\"size\"" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:58.343217Z", + "start_time": 
"2021-08-11T15:44:58.332669Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "SEModule(\n", + " (squeeze): AdaptiveAvgPool2d(output_size=1)\n", + " (excitation): Sequential(\n", + " (fc_reduce): Linear(in_features=128, out_features=4, bias=True)\n", + " (se_act): ReLU(inplace=True)\n", + " (fc_expand): Linear(in_features=4, out_features=128, bias=True)\n", + " (se_gate): Sigmoid()\n", + " )\n", + ")" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#collapse_output\n", + "se_block = SEModule(128, reduction=32)\n", + "se_block" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:58.570883Z", + "start_time": "2021-08-11T15:44:58.530689Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 128, 32, 32])\n" + ] } - } + ], + "source": [ + "#hide\n", + "bs = 8\n", + "xb = torch.randn(bs, 128, 32, 32)\n", + "y = se_block(xb)\n", + "print(y.shape)\n", + "assert y.shape == torch.Size([bs, 128, 32, 32]), f\"size\"" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:58.343217Z", + "start_time": "2021-08-11T15:44:58.332669Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "SEModule(\n", + " (squeeze): AdaptiveAvgPool2d(output_size=1)\n", + " (excitation): Sequential(\n", + " (fc_reduce): Linear(in_features=128, out_features=32, bias=True)\n", + " (se_act): ReLU(inplace=True)\n", + " (fc_expand): Linear(in_features=32, out_features=128, bias=True)\n", + " (se_gate): Sigmoid()\n", + " )\n", + ")" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#collapse_output\n", + "se_block = SEModule(128, rd_channels=32)\n", + "se_block" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "ExecuteTime": { + 
"end_time": "2021-08-11T15:44:58.570883Z", + "start_time": "2021-08-11T15:44:58.530689Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 128, 32, 32])\n" + ] + } + ], + "source": [ + "#hide\n", + "bs = 8\n", + "xb = torch.randn(bs, 128, 32, 32)\n", + "y = se_block(xb)\n", + "print(y.shape)\n", + "assert y.shape == torch.Size([bs, 128, 32, 32]), f\"size\"" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:58.343217Z", + "start_time": "2021-08-11T15:44:58.332669Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "SEModule(\n", + " (squeeze): AdaptiveAvgPool2d(output_size=1)\n", + " (excitation): Sequential(\n", + " (fc_reduce): Linear(in_features=128, out_features=32, bias=True)\n", + " (se_act): ReLU(inplace=True)\n", + " (fc_expand): Linear(in_features=32, out_features=128, bias=True)\n", + " (se_gate): Sigmoid()\n", + " )\n", + ")" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#collapse_output\n", + "se_block = SEModule(128, reduction=4, rd_channels=16, rd_max=True)\n", + "se_block" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:58.570883Z", + "start_time": "2021-08-11T15:44:58.530689Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 128, 32, 32])\n" + ] + } + ], + "source": [ + "#hide\n", + "bs = 8\n", + "xb = torch.randn(bs, 128, 32, 32)\n", + "y = se_block(xb)\n", + "print(y.shape)\n", + "assert y.shape == torch.Size([bs, 128, 32, 32]), f\"size\"" + ] }, { "cell_type": "markdown", + "metadata": {}, "source": [ - "## SimpleSelfAttention" + "## SEModuleConv" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "

class SEModuleConv[source]

\n", + "\n", + "> SEModuleConv(**`channels`**, **`reduction`**=*`16`*, **`rd_channels`**=*`None`*, **`rd_max`**=*`False`*, **`se_layer`**=*`Conv2d`*, **`act_fn`**=*`ReLU(inplace=True)`*, **`use_bias`**=*`True`*, **`gate`**=*`Sigmoid`*) :: `Module`\n", + "\n", + "se block with conv on excitation" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# hide_input\n", + "from model_constructor.layers import SEModuleConv\n", + "print_doc(SEModuleConv)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:58.343217Z", + "start_time": "2021-08-11T15:44:58.332669Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "SEModuleConv(\n", + " (squeeze): AdaptiveAvgPool2d(output_size=1)\n", + " (excitation): Sequential(\n", + " (conv_reduce): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1))\n", + " (se_act): ReLU(inplace=True)\n", + " (conv_expand): Conv2d(8, 128, kernel_size=(1, 1), stride=(1, 1))\n", + " (gate): Sigmoid()\n", + " )\n", + ")" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#collapse_output\n", + "se_block = SEModuleConv(128)\n", + "se_block" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:58.570883Z", + "start_time": "2021-08-11T15:44:58.530689Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 128, 32, 32])\n" + ] + } + ], + "source": [ + "#hide\n", + "bs = 8\n", + "xb = torch.randn(bs, 128, 32, 32)\n", + "y = se_block(xb)\n", + "print(y.shape)\n", + "assert y.shape == torch.Size([bs, 128, 32, 32]), f\"size\"" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:58.343217Z", + "start_time": "2021-08-11T15:44:58.332669Z" + } + }, + "outputs": [ + { 
+ "data": { + "text/plain": [ + "SEModuleConv(\n", + " (squeeze): AdaptiveAvgPool2d(output_size=1)\n", + " (excitation): Sequential(\n", + " (conv_reduce): Conv2d(128, 4, kernel_size=(1, 1), stride=(1, 1))\n", + " (se_act): ReLU(inplace=True)\n", + " (conv_expand): Conv2d(4, 128, kernel_size=(1, 1), stride=(1, 1))\n", + " (gate): Sigmoid()\n", + " )\n", + ")" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#collapse_output\n", + "se_block = SEModuleConv(128, reduction=32)\n", + "se_block" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:58.570883Z", + "start_time": "2021-08-11T15:44:58.530689Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 128, 32, 32])\n" + ] + } + ], + "source": [ + "#hide\n", + "bs = 8\n", + "xb = torch.randn(bs, 128, 32, 32)\n", + "y = se_block(xb)\n", + "print(y.shape)\n", + "assert y.shape == torch.Size([bs, 128, 32, 32]), f\"size\"" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:58.343217Z", + "start_time": "2021-08-11T15:44:58.332669Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "SEModuleConv(\n", + " (squeeze): AdaptiveAvgPool2d(output_size=1)\n", + " (excitation): Sequential(\n", + " (conv_reduce): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1))\n", + " (se_act): ReLU(inplace=True)\n", + " (conv_expand): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1))\n", + " (gate): Sigmoid()\n", + " )\n", + ")" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } ], - "metadata": {} - }, - { - "cell_type": "markdown", "source": [ - "SA module from mxresnet at fastai." 
- ], - "metadata": {} + "#collapse_output\n", + "se_block = SEModuleConv(128, rd_channels=32)\n", + "se_block" + ] }, { "cell_type": "code", - "execution_count": 15, - "source": [ - "# hide\n", - "from model_constructor.layers import conv1d, SimpleSelfAttention" + "execution_count": 40, + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:58.570883Z", + "start_time": "2021-08-11T15:44:58.530689Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 128, 32, 32])\n" + ] + } ], - "outputs": [], - "metadata": {} + "source": [ + "#hide\n", + "bs = 8\n", + "xb = torch.randn(bs, 128, 32, 32)\n", + "y = se_block(xb)\n", + "print(y.shape)\n", + "assert y.shape == torch.Size([bs, 128, 32, 32]), f\"size\"" + ] }, { "cell_type": "code", - "execution_count": 16, - "source": [ - "#hide_input\n", - "# doc = show_doc(conv1d, title_level=4, disp=False)\n", - "# Markdown(doc)\n", - "print_doc(conv1d)" - ], + "execution_count": 41, + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:58.343217Z", + "start_time": "2021-08-11T15:44:58.332669Z" + } + }, "outputs": [ { - "output_type": "display_data", "data": { "text/plain": [ - "" - ], - "text/markdown": [ - "

conv1d[source]

\n", - "\n", - "> conv1d(**`ni`**:`int`, **`no`**:`int`, **`ks`**:`int`=*`1`*, **`stride`**:`int`=*`1`*, **`padding`**:`int`=*`0`*, **`bias`**:`bool`=*`False`*)\n", - "\n", - "Create and initialize a `nn.Conv1d` layer with spectral normalization." + "SEModuleConv(\n", + " (squeeze): AdaptiveAvgPool2d(output_size=1)\n", + " (excitation): Sequential(\n", + " (conv_reduce): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1))\n", + " (se_act): ReLU(inplace=True)\n", + " (conv_expand): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1))\n", + " (gate): Sigmoid()\n", + " )\n", + ")" ] }, - "metadata": {} + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" } ], - "metadata": {} + "source": [ + "#collapse_output\n", + "se_block = SEModuleConv(128, reduction=4, rd_channels=16, rd_max=True)\n", + "se_block" + ] }, { "cell_type": "code", - "execution_count": 17, - "source": [ - "# doc = show_doc(SimpleSelfAttention, title_level=4, disp=False)\n", - "# Markdown(doc)\n", - "print_doc(SimpleSelfAttention)" - ], + "execution_count": 42, + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:58.570883Z", + "start_time": "2021-08-11T15:44:58.530689Z" + } + }, "outputs": [ { - "output_type": "display_data", - "data": { - "text/plain": [ - "" - ], - "text/markdown": [ - "

class SimpleSelfAttention[source]

\n", - "\n", - "> SimpleSelfAttention(**`n_in`**:`int`, **`ks`**=*`1`*, **`sym`**=*`False`*) :: `Module`\n", - "\n", - "SimpleSelfAttention module. \n", - "Adapted from SelfAttention layer at \n", - "https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py \n", - "Inspired by https://arxiv.org/pdf/1805.08318.pdf " - ] - }, - "metadata": {} + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 128, 32, 32])\n" + ] } ], - "metadata": {} + "source": [ + "#hide\n", + "bs = 8\n", + "xb = torch.randn(bs, 128, 32, 32)\n", + "y = se_block(xb)\n", + "print(y.shape)\n", + "assert y.shape == torch.Size([bs, 128, 32, 32]), f\"size\"" + ] }, { "cell_type": "markdown", + "metadata": {}, "source": [ - "## SE Block" - ], - "metadata": {} + "## SEBlock" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First version of SE block, left for compatibility." + ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 43, + "metadata": {}, + "outputs": [], "source": [ "#hide\n", "from model_constructor.layers import SEBlock" - ], - "outputs": [], - "metadata": {} + ] }, { "cell_type": "code", - "execution_count": 26, - "source": [ - "# hide_input\n", - "# Markdown(show_doc(SEBlock, title_level=4))\n", - "print_doc(SEBlock)" - ], + "execution_count": 44, + "metadata": {}, "outputs": [ { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/markdown": [ "

class SEBlock[source]

\n", "\n", "> SEBlock(**`c`**, **`r`**=*`16`*) :: `Module`\n", "\n", "se block" + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" } ], - "metadata": {} + "source": [ + "# hide_input\n", + "# Markdown(show_doc(SEBlock, title_level=4))\n", + "print_doc(SEBlock)" + ] }, { "cell_type": "code", - "execution_count": 21, - "source": [ - "#collapse_output\n", - "se_block = SEBlock(128)\n", - "se_block" - ], + "execution_count": 45, + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:58.343217Z", + "start_time": "2021-08-11T15:44:58.332669Z" + } + }, "outputs": [ { - "output_type": "execute_result", "data": { "text/plain": [ "SEBlock(\n", " (squeeze): AdaptiveAvgPool2d(output_size=1)\n", " (excitation): Sequential(\n", - " (fc_reduce): Linear(in_features=128, out_features=8, bias=False)\n", + " (fc_reduce): Linear(in_features=128, out_features=8, bias=True)\n", " (se_act): ReLU(inplace=True)\n", - " (fc_expand): Linear(in_features=8, out_features=128, bias=False)\n", + " (fc_expand): Linear(in_features=8, out_features=128, bias=True)\n", " (sigmoid): Sigmoid()\n", " )\n", ")" ] }, + "execution_count": 45, "metadata": {}, - "execution_count": 21 + "output_type": "execute_result" } ], - "metadata": { - "ExecuteTime": { - "end_time": "2021-08-11T15:44:58.343217Z", - "start_time": "2021-08-11T15:44:58.332669Z" - } - } + "source": [ + "#collapse_output\n", + "se_block = SEBlock(128)\n", + "se_block" + ] }, { "cell_type": "code", - "execution_count": 22, - "source": [ - "#hide\n", - "bs = 8\n", - "xb = torch.randn(bs, 128, 32, 32)\n", - "y = se_block(xb)\n", - "print(y.shape)\n", - "assert y.shape == torch.Size([bs, 128, 32, 32]), f\"size\"" - ], + "execution_count": 46, + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:58.570883Z", + "start_time": "2021-08-11T15:44:58.530689Z" + } + }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "torch.Size([8, 
128, 32, 32])\n" ] } ], - "metadata": { - "ExecuteTime": { - "end_time": "2021-08-11T15:44:58.570883Z", - "start_time": "2021-08-11T15:44:58.530689Z" - } - } + "source": [ + "#hide\n", + "bs = 8\n", + "xb = torch.randn(bs, 128, 32, 32)\n", + "y = se_block(xb)\n", + "print(y.shape)\n", + "assert y.shape == torch.Size([bs, 128, 32, 32]), f\"size\"" + ] }, { "cell_type": "markdown", + "metadata": {}, "source": [ "## SEBlockConv" - ], - "metadata": {} + ] }, { - "cell_type": "code", - "execution_count": 23, + "cell_type": "markdown", + "metadata": {}, "source": [ - "# hide_input\n", - "from model_constructor.layers import SEBlockConv\n", - "print_doc(SEBlockConv)" - ], + "First version of SEBlockConv, left for compatibility." + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, "outputs": [ { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/markdown": [ - "

class SEBlockConv[source]

\n", + "

class SEBlockConv[source]

\n", "\n", "> SEBlockConv(**`c`**, **`r`**=*`16`*) :: `Module`\n", "\n", "se block with conv on excitation" + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" } ], - "metadata": {} + "source": [ + "# hide_input\n", + "from model_constructor.layers import SEBlockConv\n", + "print_doc(SEBlockConv)" + ] }, { "cell_type": "code", - "execution_count": 24, - "source": [ - "#collapse_output\n", - "se_block = SEBlockConv(128)\n", - "se_block" - ], + "execution_count": 48, + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:59.189301Z", + "start_time": "2021-08-11T15:44:59.177022Z" + } + }, "outputs": [ { - "output_type": "execute_result", "data": { "text/plain": [ "SEBlockConv(\n", " (squeeze): AdaptiveAvgPool2d(output_size=1)\n", " (excitation): Sequential(\n", - " (conv_reduce): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (conv_reduce): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1))\n", " (se_act): ReLU(inplace=True)\n", - " (conv_expand): Conv2d(8, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (conv_expand): Conv2d(8, 128, kernel_size=(1, 1), stride=(1, 1))\n", " (sigmoid): Sigmoid()\n", " )\n", ")" ] }, + "execution_count": 48, "metadata": {}, - "execution_count": 24 + "output_type": "execute_result" } ], - "metadata": { - "ExecuteTime": { - "end_time": "2021-08-11T15:44:59.189301Z", - "start_time": "2021-08-11T15:44:59.177022Z" - } - } + "source": [ + "#collapse_output\n", + "se_block = SEBlockConv(128)\n", + "se_block" + ] }, { "cell_type": "code", - "execution_count": 25, - "source": [ - "#hide\n", - "bs = 8\n", - "xb = torch.randn(bs, 128, 32, 32)\n", - "y = se_block(xb)\n", - "print(y.shape)\n", - "assert y.shape == torch.Size([bs, 128, 32, 32]), f\"size\"" - ], + "execution_count": 49, + "metadata": { + "ExecuteTime": { + "end_time": "2021-08-11T15:44:59.406923Z", + "start_time": "2021-08-11T15:44:59.362600Z" + } + }, "outputs": [ { - "output_type": 
"stream", "name": "stdout", + "output_type": "stream", "text": [ "torch.Size([8, 128, 32, 32])\n" ] } ], - "metadata": { - "ExecuteTime": { - "end_time": "2021-08-11T15:44:59.406923Z", - "start_time": "2021-08-11T15:44:59.362600Z" - } - } + "source": [ + "#hide\n", + "bs = 8\n", + "xb = torch.randn(bs, 128, 32, 32)\n", + "y = se_block(xb)\n", + "print(y.shape)\n", + "assert y.shape == torch.Size([bs, 128, 32, 32]), f\"size\"" + ] }, { "cell_type": "markdown", + "metadata": {}, "source": [ "## end\n", "model_constructor\n", "by ayasyrev" - ], - "metadata": {} + ] } ], "metadata": { + "interpreter": { + "hash": "460c8d17e5de1304fcc9388854d8b1e7fdd10d3c58b2d7b68fabbdff2124405d" + }, "kernelspec": { - "name": "python3", - "display_name": "Python 3.9.6 64-bit ('mc_dev': conda)" + "display_name": "Python 3.9.6 64-bit ('mc_dev': conda)", + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -700,7 +1293,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.6" + "version": "3.8.10" }, "toc": { "base_numbering": 1, @@ -719,11 +1312,8 @@ }, "toc_section_display": true, "toc_window_display": true - }, - "interpreter": { - "hash": "460c8d17e5de1304fcc9388854d8b1e7fdd10d3c58b2d7b68fabbdff2124405d" } }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/docs/layers.html b/docs/layers.html index 2e183a7..8062319 100644 --- a/docs/layers.html +++ b/docs/layers.html @@ -375,7 +375,7 @@

conv1d

class SimpleSelfAttention[source]

SimpleSelfAttention(n_in:int, ks=1, sym=False) :: Module

-

SimpleSelfAttention module.
+

SimpleSelfAttention module. Adapted from SelfAttention layer at
https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py
Inspired by https://arxiv.org/pdf/1805.08318.pdf

@@ -392,7 +392,444 @@

class SimpleS
-

SE Block

+

SEModule

+
+
+
+ {% raw %} + +
+ +
+
+ +
+ + +
+

class SEModule[source]

SEModule(channels, reduction=16, rd_channels=None, rd_max=False, se_layer=Linear, act_fn=ReLU(inplace=True), use_bias=True, gate=Sigmoid) :: Module

+
+

se block

+ +
+ +
+ +
+
+ +
+ {% endraw %} + + {% raw %} + +
+
+ +
+
+
se_block = SEModule(128)
+se_block
+
+ +
+
+
+
+ + + +
+
+ +
+ + + +
+
SEModule(
+  (squeeze): AdaptiveAvgPool2d(output_size=1)
+  (excitation): Sequential(
+    (fc_reduce): Linear(in_features=128, out_features=8, bias=True)
+    (se_act): ReLU(inplace=True)
+    (fc_expand): Linear(in_features=8, out_features=128, bias=True)
+    (se_gate): Sigmoid()
+  )
+)
+
+ +
+ +
+
+ +
+
+ {% endraw %} + + {% raw %} + +
+
+ +
+
+
se_block = SEModule(128, reduction=32)
+se_block
+
+ +
+
+
+
+ + + +
+
+ +
+ + + +
+
SEModule(
+  (squeeze): AdaptiveAvgPool2d(output_size=1)
+  (excitation): Sequential(
+    (fc_reduce): Linear(in_features=128, out_features=4, bias=True)
+    (se_act): ReLU(inplace=True)
+    (fc_expand): Linear(in_features=4, out_features=128, bias=True)
+    (se_gate): Sigmoid()
+  )
+)
+
+ +
+ +
+
+ +
+
+ {% endraw %} + + {% raw %} + +
+
+ +
+
+
se_block = SEModule(128, rd_channels=32)
+se_block
+
+ +
+
+
+
+ + + +
+
+ +
+ + + +
+
SEModule(
+  (squeeze): AdaptiveAvgPool2d(output_size=1)
+  (excitation): Sequential(
+    (fc_reduce): Linear(in_features=128, out_features=32, bias=True)
+    (se_act): ReLU(inplace=True)
+    (fc_expand): Linear(in_features=32, out_features=128, bias=True)
+    (se_gate): Sigmoid()
+  )
+)
+
+ +
+ +
+
+ +
+
+ {% endraw %} + + {% raw %} + +
+
+ +
+
+
se_block = SEModule(128, reduction=4, rd_channels=16, rd_max=True)
+se_block
+
+ +
+
+
+
+ + + +
+
+ +
+ + + +
+
SEModule(
+  (squeeze): AdaptiveAvgPool2d(output_size=1)
+  (excitation): Sequential(
+    (fc_reduce): Linear(in_features=128, out_features=32, bias=True)
+    (se_act): ReLU(inplace=True)
+    (fc_expand): Linear(in_features=32, out_features=128, bias=True)
+    (se_gate): Sigmoid()
+  )
+)
+
+ +
+ +
+
+ +
+
+ {% endraw %} + +
+
+

SEModuleConv

+
+
+
+ {% raw %} + +
+ +
+
+ +
+ + +
+

class SEModuleConv[source]

SEModuleConv(channels, reduction=16, rd_channels=None, rd_max=False, se_layer=Conv2d, act_fn=ReLU(inplace=True), use_bias=True, gate=Sigmoid) :: Module

+
+

se block with conv on excitation

+ +
+ +
+ +
+
+ +
+ {% endraw %} + + {% raw %} + +
+
+ +
+
+
se_block = SEModuleConv(128)
+se_block
+
+ +
+
+
+
+ + + +
+
+ +
+ + + +
+
SEModuleConv(
+  (squeeze): AdaptiveAvgPool2d(output_size=1)
+  (excitation): Sequential(
+    (conv_reduce): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1))
+    (se_act): ReLU(inplace=True)
+    (conv_expand): Conv2d(8, 128, kernel_size=(1, 1), stride=(1, 1))
+    (gate): Sigmoid()
+  )
+)
+
+ +
+ +
+
+ +
+
+ {% endraw %} + + {% raw %} + +
+
+ +
+
+
se_block = SEModuleConv(128, reduction=32)
+se_block
+
+ +
+
+
+
+ + + +
+
+ +
+ + + +
+
SEModuleConv(
+  (squeeze): AdaptiveAvgPool2d(output_size=1)
+  (excitation): Sequential(
+    (conv_reduce): Conv2d(128, 4, kernel_size=(1, 1), stride=(1, 1))
+    (se_act): ReLU(inplace=True)
+    (conv_expand): Conv2d(4, 128, kernel_size=(1, 1), stride=(1, 1))
+    (gate): Sigmoid()
+  )
+)
+
+ +
+ +
+
+ +
+
+ {% endraw %} + + {% raw %} + +
+
+ +
+
+
se_block = SEModuleConv(128, rd_channels=32)
+se_block
+
+ +
+
+
+
+ + + +
+
+ +
+ + + +
+
SEModuleConv(
+  (squeeze): AdaptiveAvgPool2d(output_size=1)
+  (excitation): Sequential(
+    (conv_reduce): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1))
+    (se_act): ReLU(inplace=True)
+    (conv_expand): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1))
+    (gate): Sigmoid()
+  )
+)
+
+ +
+ +
+
+ +
+
+ {% endraw %} + + {% raw %} + +
+
+ +
+
+
se_block = SEModuleConv(128, reduction=4, rd_channels=16, rd_max=True)
+se_block
+
+ +
+
+
+
+ + + +
+
+ +
+ + + +
+
SEModuleConv(
+  (squeeze): AdaptiveAvgPool2d(output_size=1)
+  (excitation): Sequential(
+    (conv_reduce): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1))
+    (se_act): ReLU(inplace=True)
+    (conv_expand): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1))
+    (gate): Sigmoid()
+  )
+)
+
+ +
+ +
+
+ +
+
+ {% endraw %} + +
+
+

SEBlock

+
+
+
+
+
+

First version of SE block, left for compatibility.

+
@@ -450,9 +887,9 @@

class SEBlockSEBlock( (squeeze): AdaptiveAvgPool2d(output_size=1) (excitation): Sequential( - (fc_reduce): Linear(in_features=128, out_features=8, bias=False) + (fc_reduce): Linear(in_features=128, out_features=8, bias=True) (se_act): ReLU(inplace=True) - (fc_expand): Linear(in_features=8, out_features=128, bias=False) + (fc_expand): Linear(in_features=8, out_features=128, bias=True) (sigmoid): Sigmoid() ) ) @@ -470,6 +907,13 @@

class SEBlock + +
+
+

First version of SEBlockConv, left for compatibility.

+
@@ -484,7 +928,7 @@

SEBlockConv
-

class SEBlockConv[source]

SEBlockConv(c, r=16) :: Module

+

class SEBlockConv[source]

SEBlockConv(c, r=16) :: Module

se block with conv on excitation

@@ -527,9 +971,9 @@

class SEBlockConvSEBlockConv( (squeeze): AdaptiveAvgPool2d(output_size=1) (excitation): Sequential( - (conv_reduce): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False) + (conv_reduce): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1)) (se_act): ReLU(inplace=True) - (conv_expand): Conv2d(8, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) + (conv_expand): Conv2d(8, 128, kernel_size=(1, 1), stride=(1, 1)) (sigmoid): Sigmoid() ) ) diff --git a/model_constructor/layers.py b/model_constructor/layers.py index 1ef3497..f9799dc 100644 --- a/model_constructor/layers.py +++ b/model_constructor/layers.py @@ -157,13 +157,20 @@ class SEModule(nn.Module): def __init__(self, channels, reduction=16, + rd_channels=None, + rd_max=False, se_layer=nn.Linear, act_fn=nn.ReLU(inplace=True), # ? obj or class? use_bias=True, gate=nn.Sigmoid ): super().__init__() - rd_channels = channels // reduction + reducted = channels // reduction + if rd_channels is None: + rd_channels = reducted + else: + if rd_max: + rd_channels = max(rd_channels, reducted) self.squeeze = nn.AdaptiveAvgPool2d(1) self.excitation = nn.Sequential( OrderedDict([('fc_reduce', se_layer(channels, rd_channels, bias=use_bias)), @@ -185,6 +192,8 @@ class SEModuleConv(nn.Module): def __init__(self, channels, reduction=16, + rd_channels=None, + rd_max=False, se_layer=nn.Conv2d, act_fn=nn.ReLU(inplace=True), use_bias=True, @@ -192,7 +201,12 @@ def __init__(self, ): super().__init__() # rd_channels = math.ceil(channels//reduction/8)*8 - rd_channels = channels // reduction + reducted = channels // reduction + if rd_channels is None: + rd_channels = reducted + else: + if rd_max: + rd_channels = max(rd_channels, reducted) self.squeeze = nn.AdaptiveAvgPool2d(1) self.excitation = nn.Sequential( OrderedDict([