In [2]:
import os
from pathlib import Path

from src.json_models.src.model_generator import ModelGenerator

In [15]:
# Project checkout root; set CLASSIFICATION_PIPELINE_ROOT to run this notebook on another machine
# instead of editing a hardcoded absolute path.
PROJECT_ROOT = Path(os.environ.get("CLASSIFICATION_PIPELINE_ROOT",
                                   "/home/andrewheschl/PycharmProjects/ClassificationPipeline"))
MODEL_JSON = PROJECT_ROOT / "models" / "mobilenets" / "polys" / "xmodules" / "xmodule_poly_1_3_3_5.json"
gen = ModelGenerator(json_path=str(MODEL_JSON))

In [16]:
model = gen.get_model()
# Keyword summary the generator reports for logging (backbone / conv / order / mode).
log_kwargs = gen.get_log_kwargs()
print(log_kwargs)

{'backbone': 'Mobilenetv2', 'conv': 'Poly with X[3, 5]', 'order': '1, 3', 'mode': 'sum'}


In [17]:
import torch
data = torch.ones((1, 3, 320, 320))
# Shape smoke test. The model contains BatchNorm2d layers, and a train-mode forward pass would
# update their running statistics, so run in eval mode without autograd, then restore the
# model's original mode so later cells see it unchanged.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        out = model(data)
finally:
    model.train(was_training)
out.shape

torch.Size([1, 7])

In [18]:
# Display the full module tree of the generated model (its repr; the output is long).
model

ModelBuilder(
  (self_modules): ModuleList(
    (0): MobileNetV2(
      (stem_conv): Sequential(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace=True)
      )
      (layers): Sequential(
        (0): InvertedBlock(
          (layers): Sequential(
            (0): PolyWrapper(
              (branches): ModuleList(
                (0-1): 2 x PolyBlock(
                  (conv): XModule(
                    (branches): ModuleList(
                      (0): Sequential(
                        (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
                        (1): Conv2d(16, 16, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), groups=16)
                        (2): Conv2d(16, 16, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), groups=16)
                      )
                      (1): Sequential(
             

In [57]:
import torch.nn as nn


class DWSeperable(nn.Module):
    """Depthwise-separable 3x3 convolution block.

    Layer order: depthwise 3x3 conv -> BatchNorm2d -> ReLU6 -> pointwise 1x1 conv.

    Args:
        in_channels: channels of the input; the depthwise conv keeps this count.
        out_channels: channels produced by the pointwise conv.
        stride: stride of the depthwise conv (the pointwise conv always uses stride 1).
        **kwargs: accepted and ignored.

    NOTE(review): the depthwise conv's bias is immediately followed by BatchNorm2d, which
    absorbs any per-channel constant shift, so that bias is likely redundant.
    """

    def __init__(self, in_channels, out_channels, stride, **kwargs):
        super().__init__()
        depthwise = nn.Conv2d(
            in_channels, in_channels,
            kernel_size=3, stride=stride, padding=1,
            groups=in_channels, bias=True,
        )
        pointwise = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=1, stride=1, padding=0,
            bias=True,
        )
        self.net = nn.Sequential(
            depthwise,
            nn.BatchNorm2d(in_channels),
            nn.ReLU6(inplace=True),
            pointwise,
        )

    def forward(self, x):
        """Apply the depthwise stage, then the pointwise projection, to ``x``."""
        return self.net(x)

In [71]:
class XModule(nn.Module):
    """Multi-branch depthwise convolution block with a 1x1 projection.

    Each branch factorises a k x k depthwise convolution into a (1, k) depthwise conv
    followed by a (k, 1) depthwise conv (``groups=in_channels``), optionally with a norm
    layer between them. Branch outputs are concatenated on dim 1 and projected to
    ``out_channels`` by LeakyReLU + a 1x1 convolution.

    Args:
        in_channels: channels of the input; every branch preserves this count.
        out_channels: channels produced by the final 1x1 projection. Must be divisible
            by the number of branches.
        dilations: per-branch dilation; defaults to 1 for every branch.
        kernel_sizes: per-branch odd kernel size. If omitted, ``kernel_size`` must be
            passed as a keyword argument and a single branch is built.
        mode: stored on ``self.mode``. NOTE(review): ``forward`` always concatenates the
            branches, whatever this value is.
        stride: stride applied by the first conv of every branch.
        apply_norm: insert a norm layer between the two convs of each branch.
        gated: stored on ``self.gated``; not used by this block.
        norm_op: ``'batch'`` (BatchNorm2d) or ``'instance'`` (InstanceNorm2d); only
            consulted when ``apply_norm`` is True.
        **kwargs: may carry ``kernel_size``; anything else is ignored.

    Raises:
        KeyError: if neither ``kernel_sizes`` nor ``kernel_size`` is given.
        ValueError: if ``apply_norm`` is True and ``norm_op`` is not recognised.
        AssertionError: on mismatched branch lists, indivisible ``out_channels``, or an
            even kernel size.
    """

    # Supported values of ``norm_op`` and the 2D layer each one selects.
    _NORM_OPS = {'batch': nn.BatchNorm2d, 'instance': nn.InstanceNorm2d}

    def __init__(self, in_channels, out_channels, dilations=None, kernel_sizes=None, mode='concat', stride=1,
                 apply_norm: bool = False, gated=False, norm_op: str = 'batch', **kwargs):
        super(XModule, self).__init__()
        if kernel_sizes is None:
            # Single-branch shorthand; raises KeyError if 'kernel_size' was not passed either.
            kernel_sizes = [kwargs['kernel_size']]
        if dilations is None:
            dilations = [1] * len(kernel_sizes)

        assert len(dilations) == len(kernel_sizes)
        assert out_channels % len(dilations) == 0, f"Got out channels: {out_channels}"

        self.mode = mode
        self.gated = gated
        self.apply_norm = apply_norm
        if apply_norm:
            if norm_op not in self._NORM_OPS:
                raise ValueError(f"Unsupported norm_op {norm_op!r}; expected one of {sorted(self._NORM_OPS)}")
            self.norm_op = self._NORM_OPS[norm_op]
        else:
            self.norm_op = None

        self.branches = nn.ModuleList()
        for d, k in zip(dilations, kernel_sizes):
            assert (k - 1) % 2 == 0, "kernel sizes must be odd numbers"
            self.branches.append(self._get_2d_branch(d, k, in_channels, out_channels, stride))

        # The concatenated branch outputs have in_channels * n_branches channels.
        self.pw = nn.Sequential(
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=in_channels * len(dilations), out_channels=out_channels, kernel_size=1),
        )

    def _get_2d_branch(self, d: int, k: int, in_channels: int, out_channels: int, stride: int) -> nn.Sequential:
        """Build one branch: (1, k) then (k, 1) depthwise conv, with an optional norm in between.

        ``out_channels`` is kept for signature compatibility. Every layer in the branch
        works on ``in_channels`` channels, and only ``self.pw`` changes the width.
        """
        pad = (k - 1) // 2 * d  # keeps spatial size for a dilated odd kernel at stride 1
        layers = [
            nn.Conv2d(in_channels, in_channels,
                      kernel_size=(1, k), dilation=d, padding=(0, pad),
                      groups=in_channels, stride=(stride, stride)),
        ]
        if self.apply_norm:
            # The norm normalises the depthwise conv's output, which has in_channels channels.
            layers.append(self.norm_op(num_features=in_channels))
        layers.append(nn.Conv2d(in_channels, in_channels,
                                kernel_size=(k, 1), dilation=d, padding=(pad, 0),
                                groups=in_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run every branch on ``x``, concatenate on dim 1 (channels), and project with ``self.pw``."""
        return self.pw(torch.cat([branch(x) for branch in self.branches], dim=1))

In [72]:
def count_params(module: nn.Module) -> int:
    """Total number of elements across all parameters of ``module`` (buffers excluded)."""
    return sum(p.numel() for p in module.parameters())


# The same 2 -> 128 channel mapping built both ways, for a parameter-count comparison.
xmod = XModule(2, 128, [1], [3])  # NOTE(review): original annotation "5" — meaning unclear
dw = DWSeperable(2, out_channels=128, stride=1)  # NOTE(review): original annotation "3" — meaning unclear

all_params_x = count_params(xmod)
all_params_dw = count_params(dw)

In [73]:
# Module tree of the XModule instance built in the parameter-count cell.
print(xmod)

XModule(
  (branches): ModuleList(
    (0): Sequential(
      (0): Conv2d(2, 2, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), groups=2)
      (1): Conv2d(2, 2, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), groups=2)
    )
  )
  (pw): Sequential(
    (0): LeakyReLU(negative_slope=0.01)
    (1): Conv2d(2, 128, kernel_size=(1, 1), stride=(1, 1))
  )
)


In [74]:
# Module tree of the DWSeperable instance, for comparison with xmod.
# NOTE(review): the saved output shows bias=False, but the current DWSeperable definition
# passes bias=True to both convs, so that output is stale. Re-run to refresh it.
print(dw)

DWSeperable(
  (net): Sequential(
    (0): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2, bias=False)
    (1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU6(inplace=True)
    (3): Conv2d(2, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
)


In [83]:
# Label each count: bare numbers are easy to mis-attribute (dw was defined after xmod).
# NOTE(review): a saved DWSeperable count of 278 corresponds to a bias=False version;
# with the current bias=True definition the count is 408. Re-run before comparing.
print(f"DWSeperable parameters: {all_params_dw}")
print(f"XModule parameters: {all_params_x}")

278
400
