In [None]:
#|default_exp models.XceptionTimePlus

# XceptionTimePlus

This is an unofficial PyTorch implementation by Ignacio Oguiza (oguiza@timeseriesAI.co), based on:

Fawaz, H. I., Lucas, B., Forestier, G., Pelletier, C., Schmidt, D. F., Weber, J. & Petitjean, F. (2019). 
<span style="color:dodgerblue">**InceptionTime: Finding AlexNet for Time Series Classification**</span>.
arXiv preprint arXiv:1909.04939. 

Official InceptionTime tensorflow implementation:
https://github.com/hfawaz/InceptionTime

In [None]:
#|export
from tsai.imports import *
from tsai.models.layers import *
from tsai.models.utils import *

In [None]:
#|export
# This is an unofficial PyTorch implementation developed by Ignacio Oguiza - oguiza@timeseriesAI.co, based on:
# Rahimian, E., Zabihi, S., Atashzar, S. F., Asif, A., & Mohammadi, A. (2019). 
# XceptionTime: A Novel Deep Architecture based on Depthwise Separable Convolutions for Hand Gesture Classification. arXiv preprint arXiv:1911.03803.
# and 
# Fawaz, H. I., Lucas, B., Forestier, G., Pelletier, C., Schmidt, D. F., Weber, J., ... & Petitjean, F. (2019). 
# InceptionTime: Finding AlexNet for Time Series Classification. arXiv preprint arXiv:1909.04939.
# Official InceptionTime tensorflow implementation: https://github.com/hfawaz/InceptionTime

# added bn and relu (not in XceptionTimeModule)


    
class XceptionModulePlus(Module):
    """Multi-branch convolutional module with an optional bottleneck and a max-pool branch.

    The input is optionally projected by a 1x1 bottleneck convolution and then fed to one
    convolution per kernel size in `kss` (separable by default). A further branch applies
    max pooling followed by a 1x1 convolution to the *original* input. All branch outputs
    are combined by `Concat` and, when `norm_act=True`, passed through a norm/activation stage.

    Args:
        ni: number of input channels.
        nf: number of filters produced by each branch.
        ks: base kernel size; when `kss` is None the kernel sizes are `ks // 2**i` for i in 0..2.
        kss: explicit list of kernel sizes (one convolution branch per entry).
        bottleneck: if True, apply a 1x1 convolution (ni -> nf) before the kernel branches.
        coord: forwarded to every `Conv`.
        separable: forwarded to the kernel-branch `Conv` layers.
        norm, zero_norm: forwarded to `Norm` for the output normalization layer.
        bn_1st: if True the norm layer is placed before the activation, otherwise after it.
        act: activation class (or None to skip the activation).
        act_kwargs: keyword arguments used to instantiate `act`.
        norm_act: if False (default) the concatenated output is returned unchanged.
    """
    def __init__(self, ni, nf, ks=40, kss=None, bottleneck=True, coord=False, separable=True, norm='Batch', zero_norm=False, bn_1st=True,
                 act=nn.ReLU, act_kwargs={}, norm_act=False):
        if kss is None: kss = [ks // (2**i) for i in range(3)]
        # Ensure odd kss for padding='same'. Clamp at 1 because small `ks` values yield 0 after the
        # integer division, which the even->odd adjustment would otherwise turn into -1.
        kss = [max(1, ksi if ksi % 2 != 0 else ksi - 1) for ksi in kss]
        self.bottleneck = Conv(ni, nf, 1, coord=coord, bias=False) if bottleneck else noop
        self.convs = nn.ModuleList()
        # Kernel branches read the bottleneck output when it exists, otherwise the raw input.
        for i in range(len(kss)): self.convs.append(Conv(nf if bottleneck else ni, nf, kss[i], coord=coord, separable=separable, bias=False))
        self.mp_conv = nn.Sequential(*[nn.MaxPool1d(3, stride=1, padding=1), Conv(ni, nf, 1, coord=coord, bias=False)])
        self.concat = Concat()
        _norm_act = []
        if act is not None: _norm_act.append(act(**act_kwargs))
        # One group of nf channels per kernel branch plus one for the max-pool branch
        # (nf * 4 with the default three kernel sizes); assumes Concat joins along the channel axis.
        _norm_act.append(Norm(nf * (len(kss) + 1), norm=norm, zero_norm=zero_norm))
        if bn_1st: _norm_act.reverse()
        self.norm_act = noop if not norm_act else _norm_act[0] if act is None else nn.Sequential(*_norm_act)

    def forward(self, x):
        input_tensor = x  # the max-pool branch bypasses the bottleneck
        x = self.bottleneck(x)
        x = self.concat([l(x) for l in self.convs] + [self.mp_conv(input_tensor)])
        return self.norm_act(x)
    

@delegates(XceptionModulePlus.__init__)
class XceptionBlockPlus(Module):
    """Backbone made of four `XceptionModulePlus` layers with optional residual connections.

    Module `i` uses `nf * 2**i` filters per branch and therefore emits four times that many
    channels, so the block output has `nf * 32` channels. When `residual=True`, a shortcut
    is added after the 2nd and 4th modules: the tensor that entered the pair of modules is
    projected by a 1x1 `ConvBlock`, added to the pair's output and passed through `act`.

    Args:
        ni: number of input channels.
        nf: base number of filters per branch (doubled at every module).
        residual: whether to add the shortcut connections.
        coord, norm: forwarded to the modules and to the shortcut `ConvBlock`s.
        zero_norm: forwarded only to the modules that close a residual pair.
        act: activation class applied after each residual addition; also forwarded to the
            modules that close a residual pair.
        act_kwargs: keyword arguments used to instantiate `act`.
        **kwargs: extra arguments forwarded to `XceptionModulePlus`.
    """
    def __init__(self, ni, nf, residual=True, coord=False, norm='Batch', zero_norm=False, act=nn.ReLU, act_kwargs={}, **kwargs):
        self.residual = residual
        self.xception, self.shortcut, self.act = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
        n_in = ni
        for i in range(4):
            n_out = nf * 2 ** i
            closes_pair = self.residual and i % 2 == 1
            # Remember how many channels the skip tensor has when a pair of modules starts.
            if i % 2 == 0: skip_in = n_in
            if closes_pair:
                # Always project with a 1x1 ConvBlock: the skip tensor has `skip_in` channels while
                # the main path has `n_out * 4`, so a norm-only shortcut would not line up with it.
                # Built before the module itself to keep the layer creation order unchanged.
                self.shortcut.append(ConvBlock(skip_in, n_out * 4, 1, coord=coord, bias=False, norm=norm, act=None))
                self.act.append(act(**act_kwargs))
            self.xception.append(XceptionModulePlus(n_in, n_out, coord=coord,
                                                    norm=norm, zero_norm=zero_norm if closes_pair else False,
                                                    act=act if closes_pair else None, act_kwargs=act_kwargs, **kwargs))
            n_in = n_out * 4  # each module emits 4 groups of n_out channels
        self.add = Add()

    def forward(self, x):
        res = x
        for i in range(4):
            x = self.xception[i](x)
            # After every second module: add the projected skip tensor, activate, and restart the skip.
            if self.residual and (i + 1) % 2 == 0: res = x = self.act[i//2](self.add(x, self.shortcut[i//2](res)))
        return x
    
    
@delegates(XceptionBlockPlus.__init__)
class XceptionTimePlus(nn.Sequential):
    """XceptionTime-style model: an `XceptionBlockPlus` backbone followed by a head.

    The resulting `nn.Sequential` has two named children, `backbone` and `head`.

    Args:
        c_in: number of input channels.
        c_out: number of outputs produced by the default head.
        seq_len: stored on the model and passed to a callable `custom_head`.
        nf: base number of filters of the backbone.
        nb_filters: fallback used only when `nf` is None.
        coord, norm: forwarded to the backbone and to the default head's `ConvBlock`s.
        concat_pool: if True (and `adaptive_size` is set) use `AdaptiveConcatPool1d`, which
            doubles the channel count expected by the head.
        adaptive_size: output size of the first pooling layer of the default head; a falsy
            value disables that pooling step.
        custom_head: either an `nn.Module` instance used as-is, or a callable invoked as
            `custom_head(head_nf, c_out, seq_len)`. When None the default head is used.
        **kwargs: extra arguments forwarded to `XceptionBlockPlus`.
    """
    def __init__(self, c_in, c_out, seq_len=None, nf=16, nb_filters=None, coord=False, norm='Batch', concat_pool=False, adaptive_size=50,
                 custom_head=None, **kwargs):

        nf = ifnone(nf, nb_filters)
        # Backbone
        backbone = XceptionBlockPlus(c_in, nf, coord=coord, norm=norm, **kwargs)

        # Head
        if adaptive_size:
            gap1 = AdaptiveConcatPool1d(adaptive_size) if concat_pool else nn.AdaptiveAvgPool1d(adaptive_size)
        else:
            # nn.Sequential only accepts nn.Module children, so use a pass-through module here
            gap1 = nn.Identity()
        mult = 2 if adaptive_size and concat_pool else 1
        conv1x1_1 = ConvBlock(nf * 32 * mult, nf * 16 * mult, 1, coord=coord, norm=norm)
        conv1x1_2 = ConvBlock(nf * 16 * mult, nf * 8 * mult, 1, coord=coord, norm=norm)
        conv1x1_3 = ConvBlock(nf * 8 * mult, c_out, 1, coord=coord, norm=norm)
        gap2 = GAP1d(1)
        self.head_nf = nf * 32 * mult
        self.seq_len = seq_len
        if custom_head is not None:
            if isinstance(custom_head, nn.Module): head = custom_head
            else: head = custom_head(self.head_nf, c_out, seq_len)
        else: head = nn.Sequential(gap1, conv1x1_1, conv1x1_2, conv1x1_3, gap2)

        super().__init__(OrderedDict([('backbone', backbone), ('head', head)]))

In [None]:
# Shared fixture for the cells below: batch size, number of variables, sequence length, number of outputs.
bs, vars, seq_len, c_out = 16, 3, 12, 2
xb = torch.rand((bs, vars, seq_len))

In [None]:
# Every configuration must map the (bs, vars, seq_len) batch to a (bs, c_out) output.
for model_kwargs in ({}, dict(nf=32), dict(bottleneck=False), dict(residual=False), dict(coord=True), dict(concat_pool=True)):
    model = XceptionTimePlus(vars, c_out, **model_kwargs)
    test_eq(model(xb).shape, [bs, c_out])

# Pin the parameter count of the default architecture.
n_params = count_parameters(XceptionTimePlus(3, 2))
test_eq(n_params, 399540)

In [None]:
# Default model: check the norm-layer weights and that no convolution carries a bias.
m = XceptionTimePlus(2, 3)
bn_weights = check_weight(m, is_bn)[0]
test_eq(bn_weights.sum(), 5)
conv_biases = check_bias(m, is_conv)[0]
test_eq(len(conv_biases), 0)

# zero_norm on its own leaves the total unchanged ...
m = XceptionTimePlus(2, 3, zero_norm=True)
bn_weights = check_weight(m, is_bn)[0]
test_eq(bn_weights.sum(), 5)

# ... while enabling norm_act as well changes it to 7.
m = XceptionTimePlus(2, 3, zero_norm=True, norm_act=True)
bn_weights = check_weight(m, is_bn)[0]
test_eq(bn_weights.sum(), 7)

In [None]:
# (model kwargs, expected AddCoords1d layers, expected nn.Conv1d layers)
coord_cases = [
    (dict(coord=True), 25, 37),
    (dict(bottleneck=False, coord=True), 21, 33),
]
for model_kwargs, n_coord_layers, n_conv_layers in coord_cases:
    m = XceptionTimePlus(2, 3, **model_kwargs)
    test_eq(len(get_layers(m, cond=is_layer(AddCoords1d))), n_coord_layers)
    test_eq(len(get_layers(m, cond=is_layer(nn.Conv1d))), n_conv_layers)

In [None]:
# A callable custom head must still produce (bs, c_out) outputs.
m = XceptionTimePlus(vars, c_out, seq_len=seq_len, custom_head=mlp_head)
out = m(xb)
test_eq(out.shape, [bs, c_out])

In [None]:
# Display the module tree of a model built with coord=True.
XceptionTimePlus(vars, c_out, coord=True)

XceptionTimePlus(
  (backbone): XceptionBlockPlus(
    (xception): ModuleList(
      (0): XceptionModulePlus(
        (bottleneck): ConvBlock(
          (0): AddCoords1d()
          (1): Conv1d(4, 16, kernel_size=(1,), stride=(1,), bias=False)
        )
        (convs): ModuleList(
          (0): ConvBlock(
            (0): AddCoords1d()
            (1): SeparableConv1d(
              (depthwise_conv): Conv1d(17, 17, kernel_size=(39,), stride=(1,), padding=(19,), groups=17, bias=False)
              (pointwise_conv): Conv1d(17, 16, kernel_size=(1,), stride=(1,), bias=False)
            )
          )
          (1): ConvBlock(
            (0): AddCoords1d()
            (1): SeparableConv1d(
              (depthwise_conv): Conv1d(17, 17, kernel_size=(19,), stride=(1,), padding=(9,), groups=17, bias=False)
              (pointwise_conv): Conv1d(17, 16, kernel_size=(1,), stride=(1,), bias=False)
            )
          )
          (2): ConvBlock(
            (0): AddCoords1d()
            (

In [None]:
#|eval: false
#|hide
# Resolve this notebook's name, then run the notebook-to-script export for it.
from tsai.export import get_nb_name
nb_name = get_nb_name(locals())
from tsai.imports import create_scripts
create_scripts(nb_name)

<IPython.core.display.Javascript object>

/Users/nacho/notebooks/tsai/nbs/045_models.XceptionTimePlus.ipynb saved at 2023-03-19 14:28:28
Correct notebook to script conversion! 😃
Sunday 19/03/23 14:28:31 CET
