# Twist.

> Create and tune models with Twist layers.

In [1]:
#hide
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict

from model_constructor.net import Net
from model_constructor.layers import ConvLayer, noop, act_fn, SimpleSelfAttention

## ConvTwist

In [2]:
class ConvTwist(nn.Module):
    '''Replacement for Conv2d (kernelsize 3x3).

    The layer always owns a 3x3, padding-1, bias-free ``nn.Conv2d`` (``self.conv``).
    Behaviour is controlled by class-level switches, which are read both at
    construction time and on every forward pass:

    * ``use_groups`` / ``groups_ch`` - when ``use_groups`` is true the conv is
      grouped with ``ni // groups_ch`` groups (at least one).
    * ``permute`` - when true, the grouped kernel is expanded to a dense kernel
      whose group blocks are shifted to other input-channel blocks
      (see ``full_kernel``) and applied with ``groups=1``.
    * ``twist`` - when true, two extra position-weighted convolutions built
      from learnable coefficients (``coeff_Ax``, ``coeff_Ay``) are added to
      the output.

    Args:
        ni: number of input channels.
        nf: number of output channels.
        ks, padding, bias, groups, init_max, **kvargs: accepted so the class can be
            called like ``nn.Conv2d``; they are not used by this implementation.
        stride: stride of every convolution performed by the layer.
        iters: number of refinement iterations in ``forward``; only has an
            effect when twist is on, ``ni == nf`` and ``stride == 1``.
    '''
    permute = True      # expand grouped kernels to a dense, block-permuted kernel
    twist = False       # add the coordinate-weighted "twist" convolutions
    use_groups = True   # derive the group count from groups_ch
    groups_ch = 8       # channels per group when use_groups is true

    def __init__(self, ni, nf,
                 ks=3, stride=1, padding=1, bias=False,
                 groups=1, iters=1, init_max=0.7, **kvargs):
        super().__init__()
        self.same = ni == nf and stride == 1
        # Fewer input channels than groups_ch would give 0 groups, which
        # Conv2d cannot build; fall back to a single group in that case.
        self.groups = max(1, ni // self.groups_ch) if self.use_groups else 1
        self.conv = nn.Conv2d(ni, nf, kernel_size=3, stride=stride, padding=1, bias=False, groups=self.groups)
        if self.twist:
            std = self.conv.weight.std().item()
            # The coefficients must have the same (out, in-per-group) layout as
            # the conv weight, so they follow the effective group count and
            # not the unused `groups` argument.
            coeff_shape = (nf, ni // self.groups)
            self.coeff_Ax = nn.Parameter(torch.empty(coeff_shape).normal_(0, std), requires_grad=True)
            self.coeff_Ay = nn.Parameter(torch.empty(coeff_shape).normal_(0, std), requires_grad=True)
        self.iters = iters
        self.stride = stride
        self.DD = self.derivatives()

    def derivatives(self):
        '''Return a dict of fixed derivative stencils.

        ``'x'`` and ``'y'`` are Sobel kernels scaled by 1/10, shaped (1, 1, 3, 3).
        ``'xx'``, ``'yy'`` and ``'xy'`` are 5x5 kernels obtained by convolving
        ``identity + D`` stencils with each other. The result is stored in
        ``self.DD``; no other method of this class reads it.
        '''
        I = torch.Tensor([[0, 0, 0], [0, 1, 0], [0, 0, 0]]).view(1, 1, 3, 3)   # noqa E741
        D_x = torch.Tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).view(1, 1, 3, 3) / 10
        D_y = torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).view(1, 1, 3, 3) / 10

        def convolution(K1, K2):
            # conv2d is a cross-correlation; flipping K2 turns it into a true convolution.
            return F.conv2d(K1, K2.flip(2).flip(3), padding=2)
        D_xx = convolution(I + D_x, I + D_x).view(5, 5)
        D_yy = convolution(I + D_y, I + D_y).view(5, 5)
        D_xy = convolution(I + D_x, I + D_y).view(5, 5)
        return {'x': D_x, 'y': D_y, 'xx': D_xx, 'yy': D_yy, 'xy': D_xy}

    def kernel(self, coeff_x, coeff_y):
        '''Build a 3x3 kernel per (out, in) pair as ``coeff_x * D_x + coeff_y * D_y``.

        ``coeff_x`` and ``coeff_y`` are 2-D tensors of the same shape; the
        result has that shape plus two trailing dimensions of size 3.
        '''
        # Create the stencils with the coefficients' dtype and device so the
        # product needs no type promotion or device transfer.
        D_x = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]],
                           dtype=coeff_x.dtype, device=coeff_x.device)
        D_y = torch.tensor([[1., 2., 1.], [0., 0., 0.], [-1., -2., -1.]],
                           dtype=coeff_x.dtype, device=coeff_x.device)
        return coeff_x[:, :, None, None] * D_x + coeff_y[:, :, None, None] * D_y

    def full_kernel(self, kernel):  # permuting the groups
        '''Expand a grouped kernel into a dense kernel with permuted group blocks.

        ``kernel`` has shape (out, in // groups, 3, 3). The result has shape
        (out, in, 3, 3) and is zero except for one block per group: inside
        every run of four consecutive groups the blocks are shifted
        cyclically (group 0 reads the inputs of group 3, group ``i`` those of
        group ``i - 1``). A trailing run shorter than four groups is shifted
        cyclically within itself, so any group count is supported.
        With a single group the kernel is returned unchanged.
        '''
        n = self.groups
        if n == 1:
            return kernel
        a, b, _, _ = kernel.size()
        a = a // n
        # Same dtype/device as the kernel so the dense kernel can be used in
        # place of the grouped one without casts.
        KK = torch.zeros((a * n, b * n, 3, 3), dtype=kernel.dtype, device=kernel.device)
        for i in range(n):
            base = i - i % 4                 # first group of the current run of four
            run = min(4, n - base)           # the last run may be shorter
            j = base + (i - base - 1) % run  # source input block for output group i
            KK[a * i:a * (i + 1), b * j:b * (j + 1)] = kernel[a * i:a * (i + 1)]
        return KK

    def _conv(self, inpt, kernel=None):
        '''Apply ``kernel`` (default: ``self.conv.weight``) with this layer's stride and padding 1.'''
        if kernel is None:
            kernel = self.conv.weight
        if not self.permute:
            return F.conv2d(inpt, kernel, padding=1, stride=self.stride, groups=self.groups)
        return F.conv2d(inpt, self.full_kernel(kernel), padding=1, stride=self.stride, groups=1)

    def symmetrize(self, conv_wt):
        '''Symmetrize each group block of ``conv_wt`` in place (only when ``self.same``).

        Each (n, n, 3, 3) block is replaced by the mean of itself and its
        transpose over the two channel dimensions.
        NOTE(review): the write goes through ``.data``, so autograd does not
        record the symmetrization - confirm this is intended.
        '''
        if not self.same:
            return
        n = conv_wt.size()[1]
        for i in range(self.groups):
            block = conv_wt[n * i:n * (i + 1)]
            conv_wt.data[n * i:n * (i + 1)] = (block + torch.transpose(block, 0, 1)) / 2

    def forward(self, inpt):
        '''Run the convolution, optionally adding the twist terms.

        Without twist this is the plain (optionally permuted) convolution.
        With twist the result is ``conv(x) + XX * conv_x(x) + YY * conv_y(x)``
        where XX / YY are column / row coordinates rescaled to start at -1 and
        ``conv_y`` uses the 90-degree rotation of the ``conv_x`` kernel. When
        ``self.same`` and ``iters > 1`` the update is applied ``iters`` times
        with step ``1 / iters`` and the input is subtracted at the end.
        '''
        out = self._conv(inpt)
        # Truthiness test, consistent with __init__, which only creates the
        # twist coefficients when `self.twist` is truthy.
        if not self.twist:
            return out
        _, _, h, w = out.size()
        grid = np.indices((1, 1, h, w))  # computed once, shared by XX and YY
        XX = torch.from_numpy(grid[3] * 2 / w - 1).type(out.dtype).to(out.device)
        YY = torch.from_numpy(grid[2] * 2 / h - 1).type(out.dtype).to(out.device)
        kernel_x = self.kernel(self.coeff_Ax, self.coeff_Ay)
        self.symmetrize(kernel_x)
        kernel_y = kernel_x.transpose(2, 3).flip(3)  # make conv_y a 90 degree rotation of conv_x
        out = out + XX * self._conv(inpt, kernel_x) + YY * self._conv(inpt, kernel_y)
        if self.same and self.iters > 1:
            out = inpt + out / self.iters
            for _ in range(self.iters - 1):
                out = out + (self._conv(out) + XX * self._conv(out, kernel_x)
                                             + YY * self._conv(out, kernel_y)) / self.iters  # noqa E727
            out = out - inpt
        return out

    def extra_repr(self):
        return f"twist: {self.twist}, permute: {self.permute}, same: {self.same}, groups: {self.groups}"

In [None]:
ConvTwist(64,64)

ConvTwist(
  twist: False, permute: True, same: True, groups: 8
  (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)
)

In [None]:
ConvTwist.twist, ConvTwist.permute

(False, True)

In [None]:
ConvTwist.use_groups, ConvTwist.groups_ch

(True, 8)

In [None]:
ConvTwist(64,64)

ConvTwist(
  twist: False, permute: True, same: True, groups: 8
  (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)
)

In [None]:
# `twist` and `permute` are class attributes: assigning them on the class
# changes every ConvTwist created from here on (later cells keep permute=False).
ConvTwist.twist = True
ConvTwist.permute = False
ConvTwist(64,64)

ConvTwist(
  twist: True, permute: False, same: True, groups: 8
  (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)
)

## ConvLayerTwist

In [None]:
class ConvLayerTwist(ConvLayer):  # replace Conv2d by Twist
    """ConvLayer whose `Conv2d` class attribute is ConvTwist instead of torch's Conv2d.

    Presumably ConvLayer builds its convolution through `self.Conv2d`; verify
    against model_constructor.layers.ConvLayer.
    """
    Conv2d = ConvTwist

In [None]:
ConvLayerTwist(64,64, stride=1)

ConvLayerTwist(
  (conv): ConvTwist(
    twist: True, permute: False, same: True, groups: 8
    (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)
  )
  (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act_fn): ReLU(inplace=True)
)

In [None]:
ConvLayer.Conv2d

torch.nn.modules.conv.Conv2d

In [None]:
ConvLayerTwist.Conv2d

__main__.ConvTwist

In [None]:
conv_layer = ConvLayerTwist(32, 64)
conv_layer

ConvLayerTwist(
  (conv): ConvTwist(
    twist: True, permute: False, same: False, groups: 4
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False)
  )
  (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act_fn): ReLU(inplace=True)
)

In [None]:
# Switch the class-level twist flag back off; layers built below are plain
# grouped 3x3 convolutions (permute is still False from the earlier cell).
ConvTwist.twist = False
conv_layer = ConvLayerTwist(32, 64)
conv_layer

ConvLayerTwist(
  (conv): ConvTwist(
    twist: False, permute: False, same: False, groups: 4
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False)
  )
  (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act_fn): ReLU(inplace=True)
)

In [None]:
conv_layer = ConvLayerTwist(32, 64, act=False)
conv_layer

ConvLayerTwist(
  (conv): ConvTwist(
    twist: False, permute: False, same: False, groups: 4
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False)
  )
  (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [None]:
conv_layer = ConvLayerTwist(32, 64, bn_layer=False)
conv_layer

ConvLayerTwist(
  (conv): ConvTwist(
    twist: False, permute: False, same: False, groups: 4
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False)
  )
  (act_fn): ReLU(inplace=True)
)

In [None]:
conv_layer = ConvLayerTwist(32, 64, bn_1st=True)
conv_layer

ConvLayerTwist(
  (conv): ConvTwist(
    twist: False, permute: False, same: False, groups: 4
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False)
  )
  (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act_fn): ReLU(inplace=True)
)

In [None]:
conv_layer = ConvLayerTwist(32, 64, bn_1st=True, act_fn=nn.LeakyReLU())
conv_layer

ConvLayerTwist(
  (conv): ConvTwist(
    twist: False, permute: False, same: False, groups: 4
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False)
  )
  (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act_fn): LeakyReLU(negative_slope=0.01)
)

In [None]:
conv_layer = ConvLayerTwist(32, 64, ks=1)
conv_layer

ConvLayerTwist(
  (conv): ConvTwist(
    twist: False, permute: False, same: False, groups: 4
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False)
  )
  (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act_fn): ReLU(inplace=True)
)

In [None]:
conv_layer = ConvLayerTwist(32, 64, ks=1, stride=2)
conv_layer

ConvLayerTwist(
  (conv): ConvTwist(
    twist: False, permute: False, same: False, groups: 4
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=4, bias=False)
  )
  (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act_fn): ReLU(inplace=True)
)

In [None]:
conv_layer = ConvLayerTwist(32, 64, stride=2)
conv_layer

ConvLayerTwist(
  (conv): ConvTwist(
    twist: False, permute: False, same: False, groups: 4
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=4, bias=False)
  )
  (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act_fn): ReLU(inplace=True)
)

In [None]:
# groups_ch is the number of channels per group: with 32 input channels,
# lowering it from 8 to 4 raises the group count from 4 to 8.
# The change is class-wide and stays in effect for the rest of the notebook.
ConvTwist.groups_ch = 4
conv_layer = ConvLayerTwist(32, 64, stride=2)
conv_layer

ConvLayerTwist(
  (conv): ConvTwist(
    twist: False, permute: False, same: False, groups: 8
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=8, bias=False)
  )
  (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act_fn): ReLU(inplace=True)
)

## NewResBlockTwist

In [None]:
class NewResBlockTwist(nn.Module):
    """Residual block that downsamples first and uses ConvLayerTwist as its 3x3 bottleneck conv.

    For ``stride != 1`` the input is passed through ``pool`` before both the
    main branch and the shortcut. With ``expansion == 1`` the main branch is
    two 3x3 ``conv_layer`` layers; otherwise it is 1x1 -> ConvLayerTwist 3x3 -> 1x1.
    ``sa`` appends a SimpleSelfAttention layer (``sym`` is forwarded to it).
    The shortcut is a 1x1 ``conv_layer`` when the channel counts differ.
    Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, expansion, ni, nh, stride=1,
                 conv_layer=ConvLayer, act_fn=act_fn, bn_1st=True,
                 pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False, zero_bn=True, **kvargs):
        super().__init__()
        nf = nh * expansion
        ni = ni * expansion
        # Spatial reduction is done once, before the two branches split.
        self.reduce = pool if stride != 1 else noop
        if expansion == 1:
            branch = [
                ("conv_0", conv_layer(ni, nh, 3, act_fn=act_fn, bn_1st=bn_1st)),
                ("conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, bn_1st=bn_1st)),
            ]
        else:
            branch = [
                ("conv_0", conv_layer(ni, nh, 1, act_fn=act_fn, bn_1st=bn_1st)),
                ("conv_1_twist", ConvLayerTwist(nh, nh, 3, act_fn=act_fn, bn_1st=bn_1st)),
                ("conv_2", conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, bn_1st=bn_1st)),
            ]
        if sa:
            branch.append(('sa', SimpleSelfAttention(nf, ks=1, sym=sym)))
        self.convs = nn.Sequential(OrderedDict(branch))
        # Project the shortcut only when input and output widths differ.
        if ni == nf:
            self.idconv = noop
        else:
            self.idconv = conv_layer(ni, nf, 1, act=False, bn_1st=bn_1st)
        self.merge = act_fn

    def forward(self, x):
        reduced = self.reduce(x)
        main = self.convs(reduced)
        shortcut = self.idconv(reduced)
        return self.merge(main + shortcut)

In [None]:
#collapse_output
bl = NewResBlockTwist(4,64,64,sa=True)
bl

NewResBlockTwist(
  (convs): Sequential(
    (conv_0): ConvLayer(
      (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act_fn): ReLU(inplace=True)
    )
    (conv_1_twist): ConvLayerTwist(
      (conv): ConvTwist(
        twist: False, permute: False, same: True, groups: 16
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
      )
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act_fn): ReLU(inplace=True)
    )
    (conv_2): ConvLayer(
      (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (sa): SimpleSelfAttention(
      (conv): Conv1d(256, 256, kernel_size=(1,), stride=(1,), bias=False)
    )
  )
  (merge): ReLU(inplace=True)
)

In [None]:
#hide
# Stride-1 block: output shape must equal the input shape.
bs_test = 16
xb = torch.randn(bs_test, 256, 32, 32)
y = bl(xb)
print(y.shape)
assert y.shape == torch.Size([bs_test, 256, 32, 32]), f"unexpected output shape {y.shape}"

torch.Size([16, 256, 32, 32])


In [None]:
#collapse_output
bl = NewResBlockTwist(4,64,64,stride=2)
bl

NewResBlockTwist(
  (reduce): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (convs): Sequential(
    (conv_0): ConvLayer(
      (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act_fn): ReLU(inplace=True)
    )
    (conv_1_twist): ConvLayerTwist(
      (conv): ConvTwist(
        twist: False, permute: False, same: True, groups: 16
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
      )
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act_fn): ReLU(inplace=True)
    )
    (conv_2): ConvLayer(
      (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (merge): ReLU(inplace=True)
)

In [None]:
#hide
# Stride-2 block: channels unchanged, spatial size halved.
bs_test = 16
xb = torch.randn(bs_test, 256, 32, 32)
y = bl(xb)
print(y.shape)
assert y.shape == torch.Size([bs_test, 256, 16, 16]), f"unexpected output shape {y.shape}"

torch.Size([16, 256, 16, 16])


In [None]:
#collapse_output
bl = NewResBlockTwist(4,64,128,stride=2)
bl

NewResBlockTwist(
  (reduce): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (convs): Sequential(
    (conv_0): ConvLayer(
      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act_fn): ReLU(inplace=True)
    )
    (conv_1_twist): ConvLayerTwist(
      (conv): ConvTwist(
        twist: False, permute: False, same: True, groups: 32
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      )
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act_fn): ReLU(inplace=True)
    )
    (conv_2): ConvLayer(
      (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (idconv): ConvLayer(
    (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), 

In [None]:
#hide
# Stride-2 block that also widens: 256 -> 512 channels, spatial size halved.
bs_test = 16
xb = torch.randn(bs_test, 256, 32, 32)
y = bl(xb)
print(y.shape)
assert y.shape == torch.Size([bs_test, 512, 16, 16]), f"unexpected output shape {y.shape}"

torch.Size([16, 512, 16, 16])


In [None]:
#hide
bl = NewResBlockTwist(1,64,64,sa=True)
bl

NewResBlockTwist(
  (convs): Sequential(
    (conv_0): ConvLayer(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act_fn): ReLU(inplace=True)
    )
    (conv_1): ConvLayer(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (sa): SimpleSelfAttention(
      (conv): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
    )
  )
  (merge): ReLU(inplace=True)
)

In [None]:
#hide
# expansion=1 block: output shape must equal the input shape.
bs_test = 16
xb = torch.randn(bs_test, 64, 32, 32)
y = bl(xb)
print(y.shape)
assert y.shape == torch.Size([bs_test, 64, 32, 32]), f"unexpected output shape {y.shape}"

torch.Size([16, 64, 32, 32])


In [None]:
#collapse_output
bl = NewResBlockTwist(4,64,128,stride=2,act_fn=nn.LeakyReLU(), bn_1st=False)
bl

NewResBlockTwist(
  (reduce): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (convs): Sequential(
    (conv_0): ConvLayer(
      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (act_fn): LeakyReLU(negative_slope=0.01)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv_1_twist): ConvLayerTwist(
      (conv): ConvTwist(
        twist: False, permute: False, same: True, groups: 32
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      )
      (act_fn): LeakyReLU(negative_slope=0.01)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv_2): ConvLayer(
      (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (idconv): ConvLayer(
    (conv): Conv2d(256, 512, kernel_size

In [None]:
#hide
# Custom activation / bn order must not change the output shape.
bs_test = 16
xb = torch.randn(bs_test, 256, 32, 32)
y = bl(xb)
print(y.shape)
assert y.shape == torch.Size([bs_test, 512, 16, 16]), f"unexpected output shape {y.shape}"

torch.Size([16, 512, 16, 16])


## ResBlockTwist

In [None]:
class ResBlockTwist(nn.Module):
    """Residual block whose 3x3 bottleneck conv is a ConvLayerTwist and carries the stride.

    With ``expansion == 1`` the main branch is two 3x3 ``conv_layer`` layers
    (the first one strided); otherwise it is 1x1 -> strided ConvLayerTwist 3x3 -> 1x1.
    ``sa`` appends a SimpleSelfAttention layer (``sym`` is forwarded to it).
    The shortcut applies ``pool`` when ``stride != 1`` and a 1x1 ``conv_layer``
    when the channel counts differ. Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, expansion, ni, nh, stride=1,
                 conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True, bn_1st=True,
                 pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False, **kvargs):
        super().__init__()
        nf = nh * expansion
        ni = ni * expansion
        if expansion == 1:
            branch = [
                ("conv_0", conv_layer(ni, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st)),
                ("conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, bn_1st=bn_1st)),
            ]
        else:
            branch = [
                ("conv_0", conv_layer(ni, nh, 1, act_fn=act_fn, bn_1st=bn_1st)),
                ("conv_1_twist", ConvLayerTwist(nh, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st)),
                ("conv_2", conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, bn_1st=bn_1st)),
            ]
        if sa:
            branch.append(('sa', SimpleSelfAttention(nf, ks=1, sym=sym)))
        self.convs = nn.Sequential(OrderedDict(branch))
        # The shortcut has to match the main branch in resolution and width.
        self.pool = pool if stride != 1 else noop
        if ni == nf:
            self.idconv = noop
        else:
            self.idconv = conv_layer(ni, nf, 1, act=False)
        self.act_fn = act_fn

    def forward(self, x):
        main = self.convs(x)
        shortcut = self.idconv(self.pool(x))
        return self.act_fn(main + shortcut)

In [None]:
#collapse_output
bl = ResBlockTwist(4,64,64,sa=True)
bl

ResBlockTwist(
  (convs): Sequential(
    (conv_0): ConvLayer(
      (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act_fn): ReLU(inplace=True)
    )
    (conv_1_twist): ConvLayerTwist(
      (conv): ConvTwist(
        twist: False, permute: False, same: True, groups: 16
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
      )
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act_fn): ReLU(inplace=True)
    )
    (conv_2): ConvLayer(
      (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (sa): SimpleSelfAttention(
      (conv): Conv1d(256, 256, kernel_size=(1,), stride=(1,), bias=False)
    )
  )
  (act_fn): ReLU(inplace=True)
)

In [None]:
#hide
# Stride-1 block: output shape must equal the input shape.
bs_test = 16
xb = torch.randn(bs_test, 256, 32, 32)
y = bl(xb)
print(y.shape)
assert y.shape == torch.Size([bs_test, 256, 32, 32]), f"unexpected output shape {y.shape}"

torch.Size([16, 256, 32, 32])


In [None]:
#collapse_output
bl = ResBlockTwist(4,64,64,stride=2)
bl

ResBlockTwist(
  (convs): Sequential(
    (conv_0): ConvLayer(
      (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act_fn): ReLU(inplace=True)
    )
    (conv_1_twist): ConvLayerTwist(
      (conv): ConvTwist(
        twist: False, permute: False, same: False, groups: 16
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
      )
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act_fn): ReLU(inplace=True)
    )
    (conv_2): ConvLayer(
      (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (act_fn): ReLU(inplace=True)
)

In [None]:
#hide
# Stride-2 block: channels unchanged, spatial size halved.
bs_test = 16
xb = torch.randn(bs_test, 256, 32, 32)
y = bl(xb)
print(y.shape)
assert y.shape == torch.Size([bs_test, 256, 16, 16]), f"unexpected output shape {y.shape}"

torch.Size([16, 256, 16, 16])


In [None]:
#collapse_output
bl = ResBlockTwist(4,64,128,stride=2)
bl

ResBlockTwist(
  (convs): Sequential(
    (conv_0): ConvLayer(
      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act_fn): ReLU(inplace=True)
    )
    (conv_1_twist): ConvLayerTwist(
      (conv): ConvTwist(
        twist: False, permute: False, same: False, groups: 32
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
      )
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act_fn): ReLU(inplace=True)
    )
    (conv_2): ConvLayer(
      (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (idconv): ConvLayer(
    (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias

In [None]:
#hide
# Stride-2 block that also widens: 256 -> 512 channels, spatial size halved.
bs_test = 16
xb = torch.randn(bs_test, 256, 32, 32)
y = bl(xb)
print(y.shape)
assert y.shape == torch.Size([bs_test, 512, 16, 16]), f"unexpected output shape {y.shape}"

torch.Size([16, 512, 16, 16])


## Model

In [None]:
model  = Net(expansion=4, layers=[3,4,6,3])

In [None]:
model.block = NewResBlockTwist

In [None]:
#collapse_output
model.body

Sequential(
  (l_0): Sequential(
    (bl_0): NewResBlockTwist(
      (convs): Sequential(
        (conv_0): ConvLayer(
          (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act_fn): ReLU(inplace=True)
        )
        (conv_1_twist): ConvLayerTwist(
          (conv): ConvTwist(
            twist: False, permute: False, same: True, groups: 16
            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
          )
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act_fn): ReLU(inplace=True)
        )
        (conv_2): ConvLayer(
          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (idconv): ConvLayer(
    

In [None]:
#hide
# The stem maps a 3x128x128 image to 64 channels at 32x32.
bs_test = 16
xb = torch.randn(bs_test, 3, 128, 128)
y = model.stem(xb)
print(y.shape)
assert y.shape == torch.Size([bs_test, 64, 32, 32]), f"unexpected output shape {y.shape}"

torch.Size([16, 64, 32, 32])


In [None]:
#hide
# First body stage: 64 -> 256 channels, resolution unchanged.
bs_test = 16
xb = torch.randn(bs_test, 64, 32, 32)
y = model.body.l_0(xb)
print(y.shape)
assert y.shape == torch.Size([bs_test, 256, 32, 32]), f"unexpected output shape {y.shape}"

torch.Size([16, 256, 32, 32])


In [None]:
model.block = ResBlockTwist

In [None]:
#hide
# Same check after switching model.block: 64 -> 256 channels, resolution unchanged.
bs_test = 16
xb = torch.randn(bs_test, 64, 32, 32)
y = model.body.l_0(xb)
print(y.shape)
assert y.shape == torch.Size([bs_test, 256, 32, 32]), f"unexpected output shape {y.shape}"

torch.Size([16, 256, 32, 32])


In [None]:
m = model()

In [None]:
#collapse_output
m

Sequential(
  model Net
  (stem): Sequential(
    (conv_0): ConvLayer(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act_fn): ReLU(inplace=True)
    )
    (conv_1): ConvLayer(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act_fn): ReLU(inplace=True)
    )
    (conv_2): ConvLayer(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act_fn): ReLU(inplace=True)
    )
    (stem_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (body): Sequential(
    (l_0): Sequential(
      (bl_0): ResBlockTwist(
        (convs): Sequential(
      

In [None]:
#hide
# End-to-end check: a batch of 3x128x128 images must produce one row of
# 1000 outputs per sample.
bs_test = 16
xb = torch.randn(bs_test, 3, 128, 128)
y = m(xb)
print(y.shape)
assert y.shape == torch.Size([bs_test, 1000]), f"size expected {bs_test}, 1000"

torch.Size([16, 1000])


In [None]:
#collapse_output
m.stem

Sequential(
  (conv_0): ConvLayer(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act_fn): ReLU(inplace=True)
  )
  (conv_1): ConvLayer(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act_fn): ReLU(inplace=True)
  )
  (conv_2): ConvLayer(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act_fn): ReLU(inplace=True)
  )
  (stem_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)

In [None]:
#collapse_output
m.head

Sequential(
  (pool): AdaptiveAvgPool2d(output_size=1)
  (flat): Flatten()
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

In [None]:
#collapse_output
m.body.l_0

Sequential(
  (bl_0): ResBlockTwist(
    (convs): Sequential(
      (conv_0): ConvLayer(
        (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act_fn): ReLU(inplace=True)
      )
      (conv_1_twist): ConvLayerTwist(
        (conv): ConvTwist(
          twist: False, permute: False, same: True, groups: 16
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
        )
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act_fn): ReLU(inplace=True)
      )
      (conv_2): ConvLayer(
        (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (idconv): ConvLayer(
      (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=

In [None]:
#collapse_output
m.body.l_1

Sequential(
  (bl_0): ResBlockTwist(
    (convs): Sequential(
      (conv_0): ConvLayer(
        (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act_fn): ReLU(inplace=True)
      )
      (conv_1_twist): ConvLayerTwist(
        (conv): ConvTwist(
          twist: False, permute: False, same: False, groups: 32
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
        )
        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act_fn): ReLU(inplace=True)
      )
      (conv_2): ConvLayer(
        (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (idconv): ConvLayer(
    

In [None]:
#collapse_output
m.body.l_2

Sequential(
  (bl_0): ResBlockTwist(
    (convs): Sequential(
      (conv_0): ConvLayer(
        (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act_fn): ReLU(inplace=True)
      )
      (conv_1_twist): ConvLayerTwist(
        (conv): ConvTwist(
          twist: False, permute: False, same: False, groups: 64
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64, bias=False)
        )
        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act_fn): ReLU(inplace=True)
      )
      (conv_2): ConvLayer(
        (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (idconv): ConvLayer(
  

In [None]:
#collapse_output
m.body.l_3

Sequential(
  (bl_0): ResBlockTwist(
    (convs): Sequential(
      (conv_0): ConvLayer(
        (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act_fn): ReLU(inplace=True)
      )
      (conv_1_twist): ConvLayerTwist(
        (conv): ConvTwist(
          twist: False, permute: False, same: False, groups: 128
          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=128, bias=False)
        )
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act_fn): ReLU(inplace=True)
      )
      (conv_2): ConvLayer(
        (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (idconv): ConvLayer(

# end
model_constructor
by ayasyrev