# MXResNet.

> MXResNet model.

MXResNet model, as described in this [forum discussion](https://forums.fast.ai/t/how-we-beat-the-5-epoch-imagewoof-leaderboard-score-some-new-techniques-to-consider).

In [None]:
#hide
from fastcore.test import *
import torch
from typing import List

from functools import partial

# from model_constructor.activations import Mish
from model_constructor import ModelConstructor

## MXResNet constructor.

In [None]:
# MXResNet: a ModelConstructor with a three-conv stem and Mish activation;
# every other setting keeps the ModelConstructor defaults.
mxresnet = ModelConstructor(
    act_fn=torch.nn.Mish(),       # Mish in stem convs and residual blocks
    stem_sizes=[3, 32, 64, 64],   # stem channels: 3 -> 32 -> 64 -> 64
    name='MXResNet',
)

In [None]:
# Display the constructor's summary (channels, block settings, stem/body sizes, layers).
mxresnet

MXResNet constructor
  in_chans: 3, num_classes: 1000
  expansion: 1, groups: 1, dw: False, div_groups: None
  sa: False, se: False
  stem sizes: [3, 32, 64, 64], stride on 0
  body sizes [64, 128, 256, 512]
  layers: [2, 2, 2, 2]

In [None]:
# Per-stage output channels and number of blocks per stage, as one tuple.
(mxresnet.block_sizes, mxresnet.layers)

([64, 128, 256, 512], [2, 2, 2, 2])

In [None]:
#collapse_output
# Show the stem: three ConvBnAct layers (first one strided) followed by a max-pool.
mxresnet.stem

Sequential(
  (conv_0): ConvBnAct(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act_fn): Mish()
  )
  (conv_1): ConvBnAct(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act_fn): Mish()
  )
  (conv_2): ConvBnAct(
    (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act_fn): Mish()
  )
  (stem_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)

In [None]:
#collapse_output
# Show the body: a Sequential of stages (l_0, l_1, ...), each a Sequential of ResBlocks.
mxresnet.body

Sequential(
  (l_0): Sequential(
    (bl_0): ResBlock(
      (convs): Sequential(
        (conv_0): ConvBnAct(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act_fn): Mish()
        )
        (conv_1): ConvBnAct(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (act_fn): Mish()
    )
    (bl_1): ResBlock(
      (convs): Sequential(
        (conv_0): ConvBnAct(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act_fn): Mish()
        )
        (conv_1): ConvBnAct(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), str

In [None]:
#collapse_output
# NOTE(review): this cell is identical to the earlier `mxresnet.body` cell and
# produces the same output; consider removing it or showing something else.
mxresnet.body

Sequential(
  (l_0): Sequential(
    (bl_0): ResBlock(
      (convs): Sequential(
        (conv_0): ConvBnAct(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act_fn): Mish()
        )
        (conv_1): ConvBnAct(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (act_fn): Mish()
    )
    (bl_1): ResBlock(
      (convs): Sequential(
        (conv_0): ConvBnAct(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act_fn): Mish()
        )
        (conv_1): ConvBnAct(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), str

In [None]:
#collapse_output
# Show the head: global average pool, flatten, then a linear classifier.
mxresnet.head

Sequential(
  (pool): AdaptiveAvgPool2d(output_size=1)
  (flat): Flatten(start_dim=1, end_dim=-1)
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)

In [None]:
#hide
# Smoke test: build a model from the constructor, run a random batch through it
# and check we get one 1000-way output row per sample.
bs_test = 16
xb = torch.randn(bs_test, 3, 128, 128)
with torch.no_grad():  # only the output shape is checked; skip autograd bookkeeping
    y = mxresnet()(xb)
print(y.shape)
assert y.shape == torch.Size([bs_test, 1000]), (
    f"expected output shape {(bs_test, 1000)}, got {tuple(y.shape)}"
)

torch.Size([16, 1000])


## MXResNet constructors

Let's create constructor classes for MXResNet.

In [None]:
class MXResnet34(ModelConstructor):
    """Constructor preset for an MXResNet-34 style network.

    Overrides the ModelConstructor defaults with ``expansion=1``,
    ``[3, 4, 6, 3]`` blocks per stage, a ``[3, 32, 64, 64]`` stem and Mish
    activation. Any field can still be overridden at instantiation,
    e.g. ``MXResnet34(num_classes=10)``.
    """
    name: str = "MXResnet34"
    expansion: int = 1
    # Number of residual blocks in each of the four stages.
    layers: List[int] = [3, 4, 6, 3]
    # Stem channel progression: input 3 -> 32 -> 64 -> 64.
    stem_sizes: List[int] = [3, 32, 64, 64]
    # NOTE(review): one Mish instance is the shared class default; torch.nn.Mish
    # holds no parameters, so sharing it is presumably harmless.
    act_fn: torch.nn.Module = torch.nn.Mish()

`MXResnet50` inherits from `MXResnet34` and overrides only `name` and `expansion`.

In [None]:
class MXResnet50(MXResnet34):
    """MXResNet-50 preset: MXResnet34 with ``expansion=4``.

    ``layers``, ``stem_sizes`` and ``act_fn`` are inherited from MXResnet34.
    """
    name: str = "MXResnet50"
    # Block expansion factor (presumably selects bottleneck-style blocks, as in ResNet-50).
    expansion: int = 4

Now we can create a constructor from the class and change model parameters either at initialization or afterwards.

In [None]:
# Override a field at construction time; the other fields keep the class defaults.
mc = MXResnet34(num_classes=10)
mc

MXResnet34 constructor
  in_chans: 3, num_classes: 10
  expansion: 1, groups: 1, dw: False, div_groups: None
  sa: False, se: False
  stem sizes: [3, 32, 64, 64], stride on 0
  body sizes [64, 128, 256, 512]
  layers: [3, 4, 6, 3]

In [None]:
# MXResnet50 with all defaults (expansion 4, layers inherited from MXResnet34).
mc = MXResnet50()
mc

MXResnet50 constructor
  in_chans: 3, num_classes: 1000
  expansion: 4, groups: 1, dw: False, div_groups: None
  sa: False, se: False
  stem sizes: [3, 32, 64, 64], stride on 0
  body sizes [64, 128, 256, 512]
  layers: [3, 4, 6, 3]

To create a model, call the model constructor object.

In [None]:
# Calling the constructor builds the model; `mc` is the MXResnet50 constructor here.
model = mc()

model_constructor
by ayasyrev