diff --git a/docs/.gitignore b/docs/.gitignore deleted file mode 100644 index 57510a2..0000000 --- a/docs/.gitignore +++ /dev/null @@ -1 +0,0 @@ -_site/ diff --git a/docs/Gemfile b/docs/Gemfile index f09053c..dd0dc0e 100644 --- a/docs/Gemfile +++ b/docs/Gemfile @@ -2,6 +2,7 @@ source "https://rubygems.org" gem 'github-pages', group: :jekyll_plugins - # Added at 2019-11-25 10:11:40 -0800 by jhoward: -gem "jekyll", "~> 3.7" +gem "nokogiri", "< 1.10.9" +gem "jekyll", ">= 3.7" +gem "kramdown", ">= 2.3.0" \ No newline at end of file diff --git a/docs/Gemfile.lock b/docs/Gemfile.lock index 182e13e..4bd9ce6 100644 --- a/docs/Gemfile.lock +++ b/docs/Gemfile.lock @@ -1,12 +1,13 @@ GEM remote: https://rubygems.org/ specs: - activesupport (4.2.11.1) - i18n (~> 0.7) + activesupport (6.0.4) + concurrent-ruby (~> 1.0, >= 1.0.2) + i18n (>= 0.7, < 2) minitest (~> 5.1) - thread_safe (~> 0.3, >= 0.3.4) tzinfo (~> 1.1) - addressable (2.7.0) + zeitwerk (~> 2.2, >= 2.2.2) + addressable (2.8.0) public_suffix (>= 2.0.2, < 5.0) coffee-script (2.4.1) coffee-script-source @@ -15,93 +16,109 @@ GEM colorator (1.1.0) commonmarker (0.17.13) ruby-enum (~> 0.5) - concurrent-ruby (1.1.5) - dnsruby (1.61.3) - addressable (~> 2.5) - em-websocket (0.5.1) + concurrent-ruby (1.1.9) + dnsruby (1.61.7) + simpleidn (~> 0.1) + em-websocket (0.5.2) eventmachine (>= 0.12.9) http_parser.rb (~> 0.6.0) - ethon (0.12.0) - ffi (>= 1.3.0) + ethon (0.14.0) + ffi (>= 1.15.0) eventmachine (1.2.7) - execjs (2.7.0) - faraday (0.17.0) + execjs (2.8.1) + faraday (1.6.0) + faraday-em_http (~> 1.0) + faraday-em_synchrony (~> 1.0) + faraday-excon (~> 1.1) + faraday-httpclient (~> 1.0.1) + faraday-net_http (~> 1.0) + faraday-net_http_persistent (~> 1.1) + faraday-patron (~> 1.0) + faraday-rack (~> 1.0) multipart-post (>= 1.2, < 3) - ffi (1.11.3) + ruby2_keywords (>= 0.0.4) + faraday-em_http (1.0.0) + faraday-em_synchrony (1.0.0) + faraday-excon (1.1.0) + faraday-httpclient (1.0.1) + faraday-net_http (1.0.1) + faraday-net_http_persistent (1.2.0) + faraday-patron (1.0.0) + faraday-rack (1.0.0) + ffi (1.15.3) forwardable-extended (2.6.0) gemoji (3.0.1) - github-pages (202) - activesupport (= 4.2.11.1) - github-pages-health-check (= 1.16.1) - jekyll (= 3.8.5) - jekyll-avatar (= 0.6.0) + github-pages (218) + github-pages-health-check (= 1.17.2) + jekyll (= 3.9.0) + jekyll-avatar (= 0.7.0) jekyll-coffeescript (= 1.1.1) jekyll-commonmark-ghpages (= 0.1.6) jekyll-default-layout (= 0.1.4) - jekyll-feed (= 0.11.0) + jekyll-feed (= 0.15.1) jekyll-gist (= 1.5.0) - jekyll-github-metadata (= 2.12.1) - jekyll-mentions (= 1.4.1) - jekyll-optional-front-matter (= 0.3.0) + jekyll-github-metadata (= 2.13.0) + jekyll-mentions (= 1.6.0) + jekyll-optional-front-matter (= 0.3.2) jekyll-paginate (= 1.1.0) - jekyll-readme-index (= 0.2.0) - jekyll-redirect-from (= 0.14.0) - jekyll-relative-links (= 0.6.0) - jekyll-remote-theme (= 0.4.0) + jekyll-readme-index (= 0.3.0) + jekyll-redirect-from (= 0.16.0) + jekyll-relative-links (= 0.6.1) + jekyll-remote-theme (= 0.4.3) jekyll-sass-converter (= 1.5.2) - jekyll-seo-tag (= 2.5.0) - jekyll-sitemap (= 1.2.0) - jekyll-swiss (= 0.4.0) - jekyll-theme-architect (= 0.1.1) - jekyll-theme-cayman (= 0.1.1) - jekyll-theme-dinky (= 0.1.1) - jekyll-theme-hacker (= 0.1.1) - jekyll-theme-leap-day (= 0.1.1) - jekyll-theme-merlot (= 0.1.1) - jekyll-theme-midnight (= 0.1.1) - jekyll-theme-minimal (= 0.1.1) - jekyll-theme-modernist (= 0.1.1) - jekyll-theme-primer (= 0.5.3) - jekyll-theme-slate (= 0.1.1) - jekyll-theme-tactile (= 0.1.1) - jekyll-theme-time-machine (= 0.1.1) - jekyll-titles-from-headings (= 0.5.1) - jemoji (= 0.10.2) - kramdown (>= 2.3.0) - liquid (= 4.0.0) - listen (= 3.1.5) + jekyll-seo-tag (= 2.7.1) + jekyll-sitemap (= 1.4.0) + jekyll-swiss (= 1.0.0) + jekyll-theme-architect (= 0.2.0) + jekyll-theme-cayman (= 0.2.0) + jekyll-theme-dinky (= 0.2.0) + jekyll-theme-hacker (= 0.2.0) + jekyll-theme-leap-day (= 0.2.0) + jekyll-theme-merlot (= 0.2.0) + jekyll-theme-midnight (= 0.2.0) + jekyll-theme-minimal (= 0.2.0) + jekyll-theme-modernist (= 0.2.0) + jekyll-theme-primer (= 0.6.0) + jekyll-theme-slate (= 0.2.0) + jekyll-theme-tactile (= 0.2.0) + jekyll-theme-time-machine (= 0.2.0) + jekyll-titles-from-headings (= 0.5.3) + jemoji (= 0.12.0) + kramdown (= 2.3.1) + kramdown-parser-gfm (= 1.1.0) + liquid (= 4.0.3) mercenary (~> 0.3) - minima (= 2.5.0) - nokogiri (>= 1.11.4, < 2.0) - rouge (= 3.11.0) + minima (= 2.5.1) + nokogiri (>= 1.10.4, < 2.0) + rouge (= 3.26.0) terminal-table (~> 1.4) - github-pages-health-check (1.16.1) + github-pages-health-check (1.17.2) addressable (~> 2.3) dnsruby (~> 1.60) octokit (~> 4.0) - public_suffix (~> 3.0) + public_suffix (>= 2.0.2, < 5.0) typhoeus (~> 1.3) - html-pipeline (2.12.2) + html-pipeline (2.14.0) activesupport (>= 2) - nokogiri (>= 1.11.4) + nokogiri (>= 1.4) http_parser.rb (0.6.0) i18n (0.9.5) concurrent-ruby (~> 1.0) - jekyll (3.8.5) + jekyll (3.9.0) addressable (~> 2.4) colorator (~> 1.0) em-websocket (~> 0.5) i18n (~> 0.7) jekyll-sass-converter (~> 1.0) jekyll-watch (~> 2.0) - kramdown (>= 2.3.0) + kramdown (>= 1.17, < 3) liquid (~> 4.0) mercenary (~> 0.3.3) pathutil (~> 0.9) rouge (>= 1.7, < 4) safe_yaml (~> 1.0) - jekyll-avatar (0.6.0) - jekyll (~> 3.0) + jekyll-avatar (0.7.0) + jekyll (>= 3.0, < 5.0) jekyll-coffeescript (1.1.1) coffee-script (~> 2.2) coffee-script-source (~> 1.11.1) @@ -114,113 +131,118 @@ GEM rouge (>= 2.0, < 4.0) jekyll-default-layout (0.1.4) jekyll (~> 3.0) - jekyll-feed (0.11.0) - jekyll (~> 3.3) + jekyll-feed (0.15.1) + jekyll (>= 3.7, < 5.0) jekyll-gist (1.5.0) octokit (~> 4.2) - jekyll-github-metadata (2.12.1) - jekyll (~> 3.4) + jekyll-github-metadata (2.13.0) + jekyll (>= 3.4, < 5.0) octokit (~> 4.0, != 4.4.0) - jekyll-mentions (1.4.1) + jekyll-mentions (1.6.0) html-pipeline (~> 2.3) - jekyll (~> 3.0) - jekyll-optional-front-matter (0.3.0) - jekyll (~> 3.0) + jekyll (>= 3.7, < 5.0) + jekyll-optional-front-matter (0.3.2) + jekyll (>= 3.0, < 5.0) jekyll-paginate (1.1.0) - jekyll-readme-index (0.2.0) - jekyll (~> 3.0) - jekyll-redirect-from (0.14.0) - jekyll (~> 3.3) - jekyll-relative-links (0.6.0) - jekyll (~> 3.3) - jekyll-remote-theme (0.4.0) + jekyll-readme-index (0.3.0) + jekyll (>= 3.0, < 5.0) + jekyll-redirect-from (0.16.0) + jekyll (>= 3.3, < 5.0) + jekyll-relative-links (0.6.1) + jekyll (>= 3.3, < 5.0) + jekyll-remote-theme (0.4.3) addressable (~> 2.0) - jekyll (~> 3.5) - rubyzip (>= 1.2.1, < 3.0) + jekyll (>= 3.5, < 5.0) + jekyll-sass-converter (>= 1.0, <= 3.0.0, != 2.0.0) + rubyzip (>= 1.3.0, < 3.0) jekyll-sass-converter (1.5.2) sass (~> 3.4) - jekyll-seo-tag (2.5.0) - jekyll (~> 3.3) - jekyll-sitemap (1.2.0) - jekyll (~> 3.3) - jekyll-swiss (0.4.0) - jekyll-theme-architect (0.1.1) - jekyll (~> 3.5) + jekyll-seo-tag (2.7.1) + jekyll (>= 3.8, < 5.0) + jekyll-sitemap (1.4.0) + jekyll (>= 3.7, < 5.0) + jekyll-swiss (1.0.0) + jekyll-theme-architect (0.2.0) + jekyll (> 3.5, < 5.0) jekyll-seo-tag (~> 2.0) - jekyll-theme-cayman (0.1.1) - jekyll (~> 3.5) + jekyll-theme-cayman (0.2.0) + jekyll (> 3.5, < 5.0) jekyll-seo-tag (~> 2.0) - jekyll-theme-dinky (0.1.1) - jekyll (~> 3.5) + jekyll-theme-dinky (0.2.0) + jekyll (> 3.5, < 5.0) jekyll-seo-tag (~> 2.0) - jekyll-theme-hacker (0.1.1) - jekyll (~> 3.5) + jekyll-theme-hacker (0.2.0) + jekyll (> 3.5, < 5.0) jekyll-seo-tag (~> 2.0) - jekyll-theme-leap-day (0.1.1) - jekyll (~> 3.5) + jekyll-theme-leap-day (0.2.0) + jekyll (> 3.5, < 5.0) jekyll-seo-tag (~> 2.0) - jekyll-theme-merlot (0.1.1) - jekyll (~> 3.5) + jekyll-theme-merlot (0.2.0) + jekyll (> 3.5, < 5.0) jekyll-seo-tag (~> 2.0) - jekyll-theme-midnight (0.1.1) - jekyll (~> 3.5) + jekyll-theme-midnight (0.2.0) + jekyll (> 3.5, < 5.0) jekyll-seo-tag (~> 2.0) - jekyll-theme-minimal (0.1.1) - jekyll (~> 3.5) + jekyll-theme-minimal (0.2.0) + jekyll (> 3.5, < 5.0) jekyll-seo-tag (~> 2.0) - jekyll-theme-modernist (0.1.1) - jekyll (~> 3.5) + jekyll-theme-modernist (0.2.0) + jekyll (> 3.5, < 5.0) jekyll-seo-tag (~> 2.0) - jekyll-theme-primer (0.5.3) - jekyll (~> 3.5) + jekyll-theme-primer (0.6.0) + jekyll (> 3.5, < 5.0) jekyll-github-metadata (~> 2.9) jekyll-seo-tag (~> 2.0) - jekyll-theme-slate (0.1.1) - jekyll (~> 3.5) + jekyll-theme-slate (0.2.0) + jekyll (> 3.5, < 5.0) jekyll-seo-tag (~> 2.0) - jekyll-theme-tactile (0.1.1) - jekyll (~> 3.5) + jekyll-theme-tactile (0.2.0) + jekyll (> 3.5, < 5.0) jekyll-seo-tag (~> 2.0) - jekyll-theme-time-machine (0.1.1) - jekyll (~> 3.5) + jekyll-theme-time-machine (0.2.0) + jekyll (> 3.5, < 5.0) jekyll-seo-tag (~> 2.0) - jekyll-titles-from-headings (0.5.1) - jekyll (~> 3.3) + jekyll-titles-from-headings (0.5.3) + jekyll (>= 3.3, < 5.0) jekyll-watch (2.2.1) listen (~> 3.0) - jemoji (0.10.2) + jemoji (0.12.0) gemoji (~> 3.0) html-pipeline (~> 2.2) - jekyll (~> 3.0) - kramdown (>= 2.3.0) - liquid (4.0.0) - listen (3.1.5) - rb-fsevent (~> 0.9, >= 0.9.4) - rb-inotify (~> 0.9, >= 0.9.7) - ruby_dep (~> 1.2) + jekyll (>= 3.0, < 5.0) + kramdown (2.3.1) + rexml + kramdown-parser-gfm (1.1.0) + kramdown (~> 2.0) + liquid (4.0.3) + listen (3.6.0) + rb-fsevent (~> 0.10, >= 0.10.3) + rb-inotify (~> 0.9, >= 0.9.10) mercenary (0.3.6) mini_portile2 (2.4.0) - minima (2.5.0) - jekyll (~> 3.5) + minima (2.5.1) + jekyll (>= 3.5, < 5.0) jekyll-feed (~> 0.9) jekyll-seo-tag (~> 2.1) - minitest (5.13.0) + minitest (5.14.4) multipart-post (2.1.1) nokogiri (1.10.8) mini_portile2 (~> 2.4.0) - octokit (4.14.0) + octokit (4.21.0) + faraday (>= 0.9) sawyer (~> 0.8.0, >= 0.5.3) pathutil (0.16.2) forwardable-extended (~> 2.6) - public_suffix (3.1.1) - rb-fsevent (0.10.3) - rb-inotify (0.10.0) + public_suffix (4.0.6) + rb-fsevent (0.11.0) + rb-inotify (0.10.1) ffi (~> 1.0) - rouge (3.11.0) - ruby-enum (0.7.2) + rexml (3.2.5) + rouge (3.26.0) + ruby-enum (0.9.0) i18n - ruby_dep (1.5.0) - rubyzip (2.0.0) + ruby2_keywords (0.0.5) + rubyzip (2.3.2) safe_yaml (1.0.5) sass (3.7.4) sass-listen (~> 4.0.0) @@ -228,23 +250,31 @@ GEM rb-fsevent (~> 0.9, >= 0.9.4) rb-inotify (~> 0.9, >= 0.9.7) sawyer (0.8.2) - addressable (>= 2.8.0) + addressable (>= 2.3.5) faraday (> 0.8, < 2.0) + simpleidn (0.2.1) + unf (~> 0.1.4) terminal-table (1.8.0) unicode-display_width (~> 1.1, >= 1.1.1) thread_safe (0.3.6) - typhoeus (1.3.1) + typhoeus (1.4.0) ethon (>= 0.9.0) - tzinfo (1.2.5) + tzinfo (1.2.9) thread_safe (~> 0.1) - unicode-display_width (1.6.0) + unf (0.1.4) + unf_ext + unf_ext (0.0.7.7) + unicode-display_width (1.7.0) + zeitwerk (2.4.2) PLATFORMS - ruby + x86_64-linux DEPENDENCIES github-pages - jekyll (~> 3.7) + jekyll (>= 3.7) + kramdown (>= 2.3.0) + nokogiri (< 1.10.9) BUNDLED WITH - 2.0.2 + 2.2.25 diff --git a/docs/Twist.html b/docs/Twist.html index 72d9f3f..2a0e82a 100644 --- a/docs/Twist.html +++ b/docs/Twist.html @@ -858,6 +858,9 @@
class NewResBloc
Submodules assigned in this way will be registered, and will have their
parameters converted too when you call :meth:to, etc.
:ivar training: Boolean represents whether this module is in training or + evaluation mode. +:vartype training: bool
@@ -1090,6 +1093,9 @@class ResBlockTwist
Submodules assigned in this way will be registered, and will have their
parameters converted too when you call :meth:to, etc.
:ivar training: Boolean represents whether this module is in training or + evaluation mode. +:vartype training: bool
diff --git a/docs/_config.yml b/docs/_config.yml deleted file mode 100644 index 74610e8..0000000 --- a/docs/_config.yml +++ /dev/null @@ -1,64 +0,0 @@ -repository: ayasyrev/model_constructor -output: web -topnav_title: model_constructor -site_title: model_constructor -company_name: Andrei Yasyrev -description: Constructor for pytorch models. -# Set to false to disable KaTeX math -use_math: true -# Add Google analytics id if you have one and want to use it here -google_analytics: -# See http://nbdev.fast.ai/search for help with adding Search -google_search: - -host: 127.0.0.1 -# the preview server used. Leave as is. -port: 4000 -# the port where the preview is rendered. - -exclude: - - .idea/ - - .gitignore - - vendor - -exclude: [vendor] - -highlighter: rouge -markdown: kramdown -kramdown: - input: GFM - auto_ids: true - hard_wrap: false - syntax_highlighter: rouge - -collections: - tooltips: - output: false - -defaults: - - - scope: - path: "" - type: "pages" - values: - layout: "page" - comments: true - search: true - sidebar: home_sidebar - topnav: topnav - - - scope: - path: "" - type: "tooltips" - values: - layout: "page" - comments: true - search: true - tooltip: true - -sidebars: -- home_sidebar -permalink: pretty - -theme: jekyll-theme-cayman -baseurl: /model_constructor/ \ No newline at end of file diff --git a/docs/_data/alerts.yml b/docs/_data/alerts.yml deleted file mode 100644 index 157e162..0000000 --- a/docs/_data/alerts.yml +++ /dev/null @@ -1,15 +0,0 @@ -tip: 'MXResNet model, disscussed at https://forums.fast.ai/t/how-we-beat-the-5-epoch-imagewoof-leaderboard-score-some-new-techniques-to-consider
+ +mxresnet = Net(stem_sizes = [3, 32, 64, 64], name='MXResNet', act_fn=Mish())
+mxresnet
+mxresnet.block_sizes, mxresnet.layers
+mxresnet.stem
+mxresnet.body
+mxresnet.body
+mxresnet.head
+model = mxresnet50(c_out=10)
+model
+model.c_out, model.layers
+from nbdev import docs, showdoc, show_doc
+ResBlock(1,64,64,sa=True)
+ResBlock(4,64,64,sa=True, dw=True)
+ResBlock(4,64,64,sa=True, groups=4)
+ResBlock(2,64,64,act_fn=nn.LeakyReLU(), bn_1st=False)
+ResBlock(2, 64, 64, sa=True, se=True)
+NewResBlock now is YaResBlock - Yet Another ResNet Block! It is now at model_constructor.yaresnet. +Here i left old name for compatibility with existing Notebooks.
+ +NewResBlock(4, 64, 128, dw=1)
+model = Net()
+model
+model._block_sizes
+model.block_sizes
+model._block_sizes = [128, 256, 512, 1024]
+model
+model.block_sizes
+model.block_sizes
+model = Net()
+model.stem
+model.stem_stride_on = 1
+model.stem
+model.bn_1st = False
+model.act_fn =nn.LeakyReLU(inplace=True)
+model.sa = True
+model.se = True
+model.body.l_1
+model.block = NewResBlock
+model.expansion = 4
+m = model()
+m.body
+# %nbdev_export
+# # me = sys.modules[__name__]
+# # for n,e,l in [[ 18 , 1, [2,2,2 ,2] ],
+# # [ 34 , 1, [3,4,6 ,3] ],
+# # [ 50 , 4, [3,4,6 ,3] ],
+# # [ 101, 4, [3,4,23,3] ],
+# # [ 152, 4, [3,8,36,3] ],]:
+# # name = f'net{n}'
+# # setattr(me, name, partial(Net, expansion=e, layers=l, name=name))
+# net34 = partial(Net, expansion=1, layers=[3, 4, 6, 3], name='xresnet34')
+# net50 = partial(Net, expansion=4, layers=[3, 4, 6, 3], name='xresnet50')
+
+
+class XResNet34(Net):
+ def __init__(self):
+ super().__init__()
+ self.layers = [3,4,6,3]
+model = XResNet34()
+model
+from dataclasses import field
+@dataclass
+class ConfXresnet34:
+ name: str = 'xresnet34'
+ layers: list = field(default_factory=lambda : [3,4,6,3])
+asdict(ConfXresnet34())
+@dataclass
+class Xres50(ConfXresnet34):
+ name = 'xresnet50'
+ expansion: int = 4
+Xres50()
+asdict(Xres50())
+model = Net(asdict(Xres50()))
+model
+ConvTwist(64,64)
+ConvTwist.twist, ConvTwist.permute
+ConvTwist.use_groups, ConvTwist.groups_ch
+ConvTwist(64,64)
+ConvTwist.twist = True
+ConvTwist.permute = False
+ConvTwist(64,64)
+ConvLayerTwist(64,64, stride=1)
+ConvLayer.Conv2d
+ConvLayerTwist.Conv2d
+conv_layer = ConvLayerTwist(32, 64)
+conv_layer
+ConvTwist.twist = False
+conv_layer = ConvLayerTwist(32, 64)
+conv_layer
+conv_layer = ConvLayerTwist(32, 64, act=False)
+conv_layer
+conv_layer = ConvLayerTwist(32, 64, bn_layer=False)
+conv_layer
+conv_layer = ConvLayerTwist(32, 64, bn_1st=True)
+conv_layer
+conv_layer = ConvLayerTwist(32, 64, bn_1st=True, act_fn=nn.LeakyReLU())
+conv_layer
+conv_layer = ConvLayerTwist(32, 64, ks=1)
+conv_layer
+conv_layer = ConvLayerTwist(32, 64, ks=1, stride=2)
+conv_layer
+conv_layer = ConvLayerTwist(32, 64, stride=2)
+conv_layer
+ConvTwist.groups_ch = 4
+conv_layer = ConvLayerTwist(32, 64, stride=2)
+conv_layer
+bl = NewResBlockTwist(4,64,64,sa=True)
+bl
+bl = NewResBlockTwist(4,64,64,stride=2)
+bl
+bl = NewResBlockTwist(4,64,128,stride=2)
+bl
+bl = ResBlockTwist(4,64,64,sa=True)
+bl
+bl = ResBlockTwist(4,64,64,stride=2)
+bl
+bl = ResBlockTwist(4,64,128,stride=2)
+bl
+model = Net(expansion=4, layers=[3,4,6,3])
+model.block = NewResBlockTwist
+model.body
+model.block = ResBlockTwist
+m = model()
+m.stem
+m.head
+m.body.l_0
+m.body.l_1
+m.body.l_2
+m.body.l_3
+yaresnet = Net(block=YaResBlock, stem_sizes = [3, 32, 64, 64], name='YaResNet')
+yaresnet
+yaresnet.block_sizes, yaresnet.layers
+yaresnet.stem
+yaresnet.body
+yaresnet.head
+# bs_test = 16
+# xb = torch.randn(bs_test, 3, 128, 128)
+# y = yaresnet()(xb)
+# print(y.shape)
+# assert y.shape == torch.Size([bs_test, 1000]), f"size"
+Lots of experiments showed that it worth trying Mish activation function.
+ +yaresnet.act_fn = Mish()
+yaresnet()
+model = yaresnet50(c_out=10)
+model
+model.c_out, model.layers
+Activation functions, forked from https://github.com/rwightman/pytorch-image-models/timm/models/layers/activations.py
+ +Mish: Self Regularized
+Non-Monotonic Activation Function
+https://github.com/digantamisra98/Mish
+fastai forum discussion https://forums.fast.ai/t/meet-mish-new-activation-function-possible-successor-to-relu
class Stem
class Stem
class Stem
class Stem
class Stem
class Stem
class Stem
- {% raw %}
-DownsampleBlock
-class BasicBlock[source]+
BasicBlock(ni,nf,expansion=1,stride=1,zero_bn=False,conv_layer=ConvLayer,act_fn=ReLU(inplace=True),downsample_block=DownsampleBlock, **kwargs) ::Module
classBasicBlock[source]
BasicBlock(ni,nf,expansion=1,stride=1,zero_bn=False,conv_layer=ConvLayer,act_fn=ReLU(inplace=True),downsample_block=DownsampleBlock, **kwargs) ::ModuleBasic block (simplified) as in pytorch resnet
@@ -462,16 +436,16 @@
classBasicBlock
class BasicBlock
-class Bottleneck[source]+
Bottleneck(ni,nh,expansion=4,stride=1,zero_bn=False,conv_layer=ConvLayer,act_fn=ReLU(inplace=True),downsample_block=DownsampleBlock, **kwargs) ::Module
classBottleneck[source]
Bottleneck(ni,nh,expansion=4,stride=1,zero_bn=False,conv_layer=ConvLayer,act_fn=ReLU(inplace=True),downsample_block=DownsampleBlock, **kwargs) ::ModuleBottlneck block for resnet models
@@ -494,16 +468,16 @@
classBottleneck
class Bottleneck
-class BasicLayer[source]+
BasicLayer(block,blocks,ni,nf,expansion,stride,sa=False, **kwargs) ::Sequential
classBasicLayer[source]
BasicLayer(block,blocks,ni,nf,expansion,stride,sa=False, **kwargs) ::SequentialLayer from blocks
@@ -526,7 +500,14 @@
classBasicLayer
class Body[source]+
Body(block,body_in=64,body_out=512,bodylayer=BasicLayer,expansion=1,layer_szs=[64, 128, 256],blocks=[2, 2, 2, 2],sa=False, **kwargs) ::Sequential
classBody[source]
Body(block,body_in=64,body_out=512,bodylayer=BasicLayer,expansion=1,layer_szs=[64, 128, 256],blocks=[2, 2, 2, 2],sa=False, **kwargs) ::SequentialConstructor for body
@@ -564,9 +538,16 @@
classBody+ +
class Body
- {% raw %}
-class Head
- {% raw %}
-init_model
-class Net[source]+
Net(stem=Stem,body=Body,block=BasicBlock,sa=False,layer_szs=[64, 128, 256],blocks=[2, 2, 2, 2],head=Head,c_in=3,num_classes=1000,body_in=64,body_out=512,expansion=1,init_fn=init_model, **kwargs) ::Sequential
classNet[source]
Net(stem=Stem,body=Body,block=BasicBlock,sa=False,layer_szs=[64, 128, 256],blocks=[2, 2, 2, 2],head=Head,c_in=3,num_classes=1000,body_in=64,body_out=512,expansion=1,init_fn=init_model, **kwargs) ::SequentialConstructor for model
@@ -939,9 +913,16 @@
classNet+ +
class Net
class Net
class Netpip install model-constructor
Or install from repo:
+ +pip install git+https://github.com/ayasyrev/model_constructor.git
First import constructor class, then create model constructor oject.
+Now you can change every part of model.
+ +from model_constructor.net import *
+model = Net()
+model
+Now we have model consructor, default setting as xresnet18. And we can get model after call it.
+ +model.c_in
+model.c_out
+model.stem_sizes
+model.layers
+model.expansion
+model()
+If you want to change model, just change constructor parameters.
+Lets create xresnet50.
model.expansion = 4
+model.layers = [3,4,6,3]
+Now we can look at model body and if we call constructor - we have pytorch model!
+ +model.body
+Main purpose of this module - fast and easy modify model. +And here is the link to more modification to beat Imagenette leaderboard with add MaxBlurPool and modification to ResBlock https://github.com/ayasyrev/imagenette_experiments/blob/master/ResnetTrick_create_model_fit.ipynb
+But now lets create model as mxresnet50 from fastai forums tread https://forums.fast.ai/t/how-we-beat-the-5-epoch-imagewoof-leaderboard-score-some-new-techniques-to-consider
+ +Lets create mxresnet constructor.
+ +model = Net(name='MxResNet')
+Then lets modify stem.
+ +model.stem_sizes = [3,32,64,64]
+Now lets change activation function to Mish.
+Here is link to forum disscussion https://forums.fast.ai/t/meet-mish-new-activation-function-possible-successor-to-relu
+Mish is in model_constructor.activations
from model_constructor.activations import Mish
+model.act_fn = Mish()
+model
+model()
+Now lets make MxResNet50
+ +model.expansion = 4
+model.layers = [3,4,6,3]
+model.name = 'mxresnet50'
+Now we have mxresnet50 constructor.
+We can inspect every parts of it.
+And after call it we got model.
model
+model.stem.conv_1
+model.body.l_0.bl_0
+Now lets change Resblock to YaResBlock (Yet another ResNet, former NewResBlock) is in lib from version 0.1.0
+ +from model_constructor.yaresnet import YaResBlock
+model.block = YaResBlock
+That all. Now we have YaResNet constructor
+ +model.name = 'YaResNet'
+model
+Let see what we have.
+ +model.body.l_1.bl_0
+conv_layer = ConvLayer(32, 64)
+conv_layer
+conv_layer = ConvLayer(32, 64, act=False)
+conv_layer
+conv_layer = ConvLayer(32, 64, bn_layer=False)
+conv_layer
+conv_layer = ConvLayer(32, 64, bn_1st=True)
+conv_layer
+conv_layer = ConvLayer(32, 64, bn_1st=True, act_fn=nn.LeakyReLU())
+conv_layer
+se_block = SEBlock(128)
+se_block
+se_block = SEBlockConv(128)
+se_block
+body = Body(Bottleneck, expansion=4)
-body
+body = Body(XResBlock, expansion=4)
+body
resnet18[source]+
resnet18(**kwargs)
xresnet18[source]-
xresnet18(**kwargs)Constructs a ResNet-18 model.
+Constructs a xresnet-18 model.
resnet18
@@ -386,9 +439,9 @@ resnet18
-resnet34[source]+
resnet34(**kwargs)
xresnet34[source]-
xresnet34(**kwargs)Constructs a ResNet-34 model.
+Constructs axresnet-34 model.
resnet34
@@ -411,9 +464,9 @@ resnet34
-resnet50[source]+
resnet50(**kwargs)
xresnet50[source]-
xresnet50(**kwargs)Constructs a ResNet-18 model.
+Constructs axresnet-34 model.
resnet50
+
+resnet18()
+xresnet18()
resnet50
Net(
(stem): Stem(
- sizes: [3, 64]
- (conv0): ConvLayer(
- (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
+ sizes: [3, 32, 32, 64]
+ (conv_0): ConvLayer(
+ (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
+ (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+ (act_fn): ReLU(inplace=True)
+ )
+ (conv_1): ConvLayer(
+ (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+ (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+ (act_fn): ReLU(inplace=True)
+ )
+ (conv_2): ConvLayer(
+ (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_fn): ReLU(inplace=True)
)
@@ -460,8 +530,8 @@ resnet50resnet50resnet50resnet50resnet50resnet50resnet50resnet50resnet50resnet50
-model = resnet34()
-model
+xresnet34()
@@ -652,9 +727,19 @@ resnet50
Net(
(stem): Stem(
- sizes: [3, 64]
- (conv0): ConvLayer(
- (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
+ sizes: [3, 32, 32, 64]
+ (conv_0): ConvLayer(
+ (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
+ (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+ (act_fn): ReLU(inplace=True)
+ )
+ (conv_1): ConvLayer(
+ (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+ (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+ (act_fn): ReLU(inplace=True)
+ )
+ (conv_2): ConvLayer(
+ (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_fn): ReLU(inplace=True)
)
@@ -663,8 +748,8 @@ resnet50resnet50resnet50resnet50resnet50resnet50resnet50resnet50resnet50resnet50resnet50resnet50resnet50resnet50resnet50resnet50resnet50resnet50
-xb = torch.randn(16, 3, 128, 128)
-y = model(xb)
-y.shape
+model = xresnet50()
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -991,26 +1115,9 @@ resnet50
-
-
-
-
-model = resnet50()
-
-
-
-
-
-
-
- {% endraw %}
- {% raw %}
+
@@ -1035,366 +1142,373 @@ resnet50resnet50
-xb = torch.randn(16, 3, 128, 128)
+model.head
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+xb = torch.randn(8, 3, 128, 128)
y = model(xb)
y.shape
@@ -1432,7 +1583,7 @@ resnet50
-torch.Size([16, 1000])
+torch.Size([8, 1000])
@@ -1441,11 +1592,11 @@ resnet50
diff --git a/docs/activations.html b/docs/activations.html
index 74137ea..488bdf8 100644
--- a/docs/activations.html
+++ b/docs/activations.html
@@ -141,7 +141,7 @@ MishJit
@@ -225,7 +228,7 @@ MishJitMe - memory-efficient.
-ScriptFunction object at 0x7f60a962c950>[source]
ScriptFunction object at 0x7f60a962c950>()
+ScriptFunction object at 0x7f691998d310>[source]
ScriptFunction object at 0x7f691998d310>()
@@ -249,7 +252,7 @@ Scrip
@@ -273,7 +276,7 @@ Scrip