In [None]:
#|default_exp models.XResNet1dPlus

# XResNet1dPlus

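> This is a modified version of fastai's XResNet model ([source on GitHub](https://github.com/fastai/fastai/blob/master/fastai/vision/models/xresnet.py)), adapted to 1D (time series) inputs.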

In [None]:
#|export
from tsai.imports import *
from tsai.models.layers import *
from tsai.models.utils import *

In [None]:
#|export
class XResNet1dPlus(nn.Sequential):
    """XResNet for 1D inputs built from `ResBlock1dPlus`-style residual blocks.

    The model is an `nn.Sequential` with two named children:
    `backbone` (conv stem -> max-pool -> residual stages) and `head`
    (by default: global average pool -> flatten -> dropout -> linear).
    """
    @delegates(ResBlock1dPlus)
    def __init__(self, block=ResBlock1dPlus, expansion=4, layers=(3,4,6,3), fc_dropout=0.0, c_in=3, c_out=None, n_out=1000, seq_len=None, stem_szs=(32,32,64),
                 widen=1.0, sa=False, act_cls=defaults.activation, ks=3, stride=2, coord=False, custom_head=None, block_szs_base=(64,128,256,512), **kwargs):
        """Build the backbone and head.

        Args:
            block: residual block class, called as
                `block(expansion, ni, nf, coord=..., stride=..., sa=..., act_cls=..., ks=..., **kwargs)`.
            expansion: channel multiplier passed to every block; `head_nf` is the last stage width times it.
            layers: number of blocks in each residual stage.
            fc_dropout: dropout probability before the final linear layer of the default head.
            c_in: number of input channels.
            c_out: number of outputs; takes precedence over `n_out` when truthy.
            n_out: number of outputs used when `c_out` is falsy.
            seq_len: only forwarded to a callable `custom_head`.
            stem_szs: output channels of the stem convs (one conv per entry).
            widen: multiplier applied to the stage widths.
            sa: if True, self-attention is requested in the last block of stage `len(layers)-4`.
            act_cls: activation class used in the stem and in every block.
            ks: kernel size for stem, max-pool and blocks; must be odd.
            stride: stride of the first stem conv, of the max-pool and of the first
                block of every stage after the first.
            coord: forwarded to the stem convs and to every block.
            custom_head: either an `nn.Module` used as is, or a callable
                `(head_nf, n_out, seq_len) -> nn.Module`.
            block_szs_base: base stage widths; stages beyond the 4th use half of the last entry.
            **kwargs: forwarded to every `block`.

        Raises:
            ValueError: if `ks` is even.
        """
        if ks % 2 == 0: raise ValueError('kernel size has to be odd!')
        store_attr('block,expansion,act_cls,ks')
        n_out = c_out or n_out # added for compatibility

        # Stem: one conv per entry of `stem_szs`; only the first one is strided.
        stem_szs = [c_in, *stem_szs]
        stem = [ConvBlock(stem_szs[i], stem_szs[i+1], ks=ks, coord=coord, stride=stride if i==0 else 1,
                          act=act_cls)
                for i in range(len(stem_szs)-1)]

        # Stage widths: the base sizes, plus half the last base size for every stage beyond the 4th.
        extra_szs = [int(block_szs_base[-1]/2)] * (len(layers)-4)
        block_szs = [int(o*widen) for o in (list(block_szs_base) + extra_szs)]
        # The first entry is the input width of stage 0. It was hard-coded as `64//expansion`, which
        # matches the default stem output of 64 channels; deriving it from the stem keeps it consistent
        # when `stem_szs` is customised (assumes `block` scales `ni` by `expansion`, as that pairing implies).
        block_szs = [stem_szs[-1]//expansion] + block_szs
        # Keep only the widths of the stages actually built, so `head_nf` matches the backbone output
        # even when fewer than 4 stages are requested.
        block_szs = block_szs[:len(layers)+1]
        blocks    = self._make_blocks(layers, block_szs, sa, coord, stride, **kwargs)
        backbone = nn.Sequential(*stem, MaxPool(ks=ks, stride=stride, padding=ks//2, ndim=1), *blocks)

        self.head_nf = block_szs[-1]*expansion
        if custom_head is None:
            head = nn.Sequential(AdaptiveAvgPool(sz=1, ndim=1), Flatten(), nn.Dropout(fc_dropout), nn.Linear(self.head_nf, n_out))
        elif isinstance(custom_head, nn.Module):
            head = custom_head
        else:
            head = custom_head(self.head_nf, n_out, seq_len)
        super().__init__(OrderedDict([('backbone', backbone), ('head', head)]))
        self._init_cnn(self)

    def _make_blocks(self, layers, block_szs, sa, coord, stride, **kwargs):
        "Build one residual stage per entry of `layers`; stage `i` maps `block_szs[i]` -> `block_szs[i+1]`."
        return [self._make_layer(ni=block_szs[i], nf=block_szs[i+1], blocks=l, coord=coord,
                                 stride=1 if i==0 else stride, sa=sa and i==len(layers)-4, **kwargs)
                for i,l in enumerate(layers)]

    def _make_layer(self, ni, nf, blocks, coord, stride, sa, **kwargs):
        "Build a stage of `blocks` residual blocks; only the first block changes width/stride, only the last may use `sa`."
        return nn.Sequential(
            *[self.block(self.expansion, ni if i==0 else nf, nf, coord=coord, stride=stride if i==0 else 1,
                      sa=sa and i==(blocks-1), act_cls=self.act_cls, ks=self.ks, **kwargs)
              for i in range(blocks)])

    def _init_cnn(self, m):
        "Recursively zero the biases of `m` and its children, and Kaiming-normal init conv/linear weights."
        # Inspect `m` (the module being visited), not `self`: checking `self` meant that
        # no sub-layer was ever initialised despite the recursion.
        bias = getattr(m, 'bias', None)
        if isinstance(bias, torch.Tensor): nn.init.constant_(bias, 0)
        if isinstance(m, (nn.Conv1d,nn.Conv2d,nn.Conv3d,nn.Linear)): nn.init.kaiming_normal_(m.weight)
        for l in m.children(): self._init_cnn(l)

In [None]:
#|export
def _xresnetplus(expansion, layers, c_in, c_out, seq_len=None, **kwargs):
    "Shared builder behind the `xresnet1d*plus` factories: an `XResNet1dPlus` made of `ResBlock1dPlus` blocks."
    io_kwargs = dict(c_in=c_in, c_out=c_out, seq_len=seq_len)
    return XResNet1dPlus(ResBlock1dPlus, expansion, layers, **io_kwargs, **kwargs)

In [None]:
#|export
# Extra keyword arguments are forwarded (via `_xresnetplus`) to `XResNet1dPlus` and from there to
# `ResBlock1dPlus`, so the advertised signature delegates to `ResBlock1dPlus` -- the block actually
# built -- rather than to `ResBlock`. `act` is passed on as `act_cls`.
@delegates(ResBlock1dPlus)
def xresnet1d18plus (c_in, c_out, seq_len=None, act=nn.ReLU, **kwargs):
    "XResNet1dPlus-18: expansion 1, stages [2, 2, 2, 2]."
    return _xresnetplus(1, [2, 2,  2, 2], c_in, c_out, seq_len=seq_len, act_cls=act, **kwargs)
@delegates(ResBlock1dPlus)
def xresnet1d34plus (c_in, c_out, seq_len=None, act=nn.ReLU, **kwargs):
    "XResNet1dPlus-34: expansion 1, stages [3, 4, 6, 3]."
    return _xresnetplus(1, [3, 4,  6, 3], c_in, c_out, seq_len=seq_len, act_cls=act, **kwargs)
@delegates(ResBlock1dPlus)
def xresnet1d50plus (c_in, c_out, seq_len=None, act=nn.ReLU, **kwargs):
    "XResNet1dPlus-50: expansion 4, stages [3, 4, 6, 3]."
    return _xresnetplus(4, [3, 4,  6, 3], c_in, c_out, seq_len=seq_len, act_cls=act, **kwargs)
@delegates(ResBlock1dPlus)
def xresnet1d101plus (c_in, c_out, seq_len=None, act=nn.ReLU, **kwargs):
    "XResNet1dPlus-101: expansion 4, stages [3, 4, 23, 3]."
    return _xresnetplus(4, [3, 4, 23, 3], c_in, c_out, seq_len=seq_len, act_cls=act, **kwargs)
@delegates(ResBlock1dPlus)
def xresnet1d152plus (c_in, c_out, seq_len=None, act=nn.ReLU, **kwargs):
    "XResNet1dPlus-152: expansion 4, stages [3, 8, 36, 3]."
    return _xresnetplus(4, [3, 8, 36, 3], c_in, c_out, seq_len=seq_len, act_cls=act, **kwargs)
@delegates(ResBlock1dPlus)
def xresnet1d18_deepplus (c_in, c_out, seq_len=None, act=nn.ReLU, **kwargs):
    "Deep XResNet1dPlus-18: expansion 1, stages [2, 2, 2, 2, 1, 1]."
    return _xresnetplus(1, [2,2,2,2,1,1], c_in, c_out, seq_len=seq_len, act_cls=act, **kwargs)
@delegates(ResBlock1dPlus)
def xresnet1d34_deepplus (c_in, c_out, seq_len=None, act=nn.ReLU, **kwargs):
    "Deep XResNet1dPlus-34: expansion 1, stages [3, 4, 6, 3, 1, 1]."
    return _xresnetplus(1, [3,4,6,3,1,1], c_in, c_out, seq_len=seq_len, act_cls=act, **kwargs)
@delegates(ResBlock1dPlus)
def xresnet1d50_deepplus (c_in, c_out, seq_len=None, act=nn.ReLU, **kwargs):
    "Deep XResNet1dPlus-50: expansion 4, stages [3, 4, 6, 3, 1, 1]."
    return _xresnetplus(4, [3,4,6,3,1,1], c_in, c_out, seq_len=seq_len, act_cls=act, **kwargs)
@delegates(ResBlock1dPlus)
def xresnet1d18_deeperplus (c_in, c_out, seq_len=None, act=nn.ReLU, **kwargs):
    "Deeper XResNet1dPlus-18: expansion 1, stages [2, 2, 1, 1, 1, 1, 1, 1]."
    return _xresnetplus(1, [2,2,1,1,1,1,1,1], c_in, c_out, seq_len=seq_len, act_cls=act, **kwargs)
@delegates(ResBlock1dPlus)
def xresnet1d34_deeperplus (c_in, c_out, seq_len=None, act=nn.ReLU, **kwargs):
    "Deeper XResNet1dPlus-34: expansion 1, stages [3, 4, 6, 3, 1, 1, 1, 1]."
    return _xresnetplus(1, [3,4,6,3,1,1,1,1], c_in, c_out, seq_len=seq_len, act_cls=act, **kwargs)
@delegates(ResBlock1dPlus)
def xresnet1d50_deeperplus (c_in, c_out, seq_len=None, act=nn.ReLU, **kwargs):
    "Deeper XResNet1dPlus-50: expansion 4, stages [3, 4, 6, 3, 1, 1, 1, 1]."
    return _xresnetplus(4, [3,4,6,3,1,1,1,1], c_in, c_out, seq_len=seq_len, act_cls=act, **kwargs)

In [None]:
# Smoke test: a batch of 32 series with 3 channels and 50 steps -> one row of 2 outputs per sample.
net = xresnet1d18plus(3, 2, coord=True)
x = torch.rand(32, 3, 50)
out = net(x)
test_eq(out.shape, (32, 2))
out

block <class 'tsai.models.layers.ResBlock1dPlus'> expansion 1 layers [2, 2, 2, 2]


TensorBase([[ 0.1829,  0.3597],
            [ 0.0274, -0.1443],
            [ 0.0240, -0.2374],
            [-0.1323, -0.6574],
            [ 0.1481, -0.1438],
            [ 0.2410, -0.1225],
            [-0.1186, -0.1978],
            [-0.0640, -0.4547],
            [-0.0229, -0.3214],
            [ 0.2336, -0.4466],
            [-0.1843, -0.0934],
            [-0.0416,  0.1997],
            [-0.0109, -0.0253],
            [ 0.3014, -0.2193],
            [ 0.0966,  0.0602],
            [ 0.2364,  0.2209],
            [-0.1437, -0.1476],
            [ 0.0070, -0.2900],
            [ 0.2807,  0.4797],
            [-0.2386, -0.1563],
            [ 0.1620, -0.2285],
            [ 0.0479, -0.2348],
            [ 0.1573, -0.4420],
            [-0.5469,  0.1512],
            [ 0.0243, -0.1806],
            [ 0.3396,  0.1434],
            [ 0.0666, -0.1644],
            [ 0.3286, -0.5637],
            [ 0.0993, -0.6281],
            [-0.1068, -0.0763],
            [-0.2713,  0.1946],
        

In [None]:
# Every architecture must map a (bs, c_in, seq_len) batch to (bs, c_out), with self-attention,
# Mish activation and coord-conv all enabled.
bs, c_in, seq_len = 2, 4, 32
c_out = 2
x = torch.rand(bs, c_in, seq_len)
archs = [
    xresnet1d18plus, xresnet1d34plus, xresnet1d50plus,
    xresnet1d18_deepplus, xresnet1d34_deepplus, xresnet1d50_deepplus, xresnet1d18_deeperplus,
    xresnet1d34_deeperplus, xresnet1d50_deeperplus,
    # Long test (slow) -- uncomment to also check the deepest variants:
    # xresnet1d101plus, xresnet1d152plus,
]
for i, arch in enumerate(archs):
    print(i, arch.__name__)
    test_eq(arch(c_in, c_out, sa=True, act=Mish, coord=True)(x).shape, (bs, c_out))

0 xresnet1d18plus
block <class 'tsai.models.layers.ResBlock1dPlus'> expansion 1 layers [2, 2, 2, 2]
1 xresnet1d34plus
block <class 'tsai.models.layers.ResBlock1dPlus'> expansion 1 layers [3, 4, 6, 3]
2 xresnet1d50plus
block <class 'tsai.models.layers.ResBlock1dPlus'> expansion 4 layers [3, 4, 6, 3]
3 xresnet1d18_deepplus
block <class 'tsai.models.layers.ResBlock1dPlus'> expansion 1 layers [2, 2, 2, 2, 1, 1]
4 xresnet1d34_deepplus
block <class 'tsai.models.layers.ResBlock1dPlus'> expansion 1 layers [3, 4, 6, 3, 1, 1]
5 xresnet1d50_deepplus
block <class 'tsai.models.layers.ResBlock1dPlus'> expansion 4 layers [3, 4, 6, 3, 1, 1]
6 xresnet1d18_deeperplus
block <class 'tsai.models.layers.ResBlock1dPlus'> expansion 1 layers [2, 2, 1, 1, 1, 1, 1, 1]
7 xresnet1d34_deeperplus
block <class 'tsai.models.layers.ResBlock1dPlus'> expansion 1 layers [3, 4, 6, 3, 1, 1, 1, 1]
8 xresnet1d50_deeperplus
block <class 'tsai.models.layers.ResBlock1dPlus'> expansion 4 layers [3, 4, 6, 3, 1, 1, 1, 1]


In [None]:
# Regression check on xresnet1d34plus (stages [3, 4, 6, 3]): pin the number of BatchNorm layers
# and the summed first statistic returned by `check_weight` over them.
# TODO(review): confirm the meaning of 22 -- presumably the BN layers whose weights are not
# zero-initialised at the end of each residual branch.
m = xresnet1d34plus(4, 2, act=Mish)
bn_layers = get_layers(m, is_bn)
test_eq(len(bn_layers), 38)
bn_weight_stats = check_weight(m, is_bn)
test_eq(bn_weight_stats[0].sum(), 22)

block <class 'tsai.models.layers.ResBlock1dPlus'> expansion 1 layers [3, 4, 6, 3]


In [None]:
#|eval: false
#|hide
# Convert this notebook into its library module (tooling only; skipped during test runs).
from tsai.export import get_nb_name
nb_name = get_nb_name(locals())

from tsai.imports import create_scripts
create_scripts(nb_name)

<IPython.core.display.Javascript object>

/Users/nacho/notebooks/tsai/nbs/059_models.XResNet1dPlus.ipynb saved at 2023-03-26 16:07:44
Correct notebook to script conversion! 😃
Sunday 26/03/23 16:07:46 CEST
