In [None]:
#|default_exp models.XceptionTime

# XceptionTime

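This is an unofficial PyTorch implementation by Ignacio Oguiza - oguiza@timeseriesAI.co
based on:

Rahimian, E., Zabihi, S., Atashzar, S. F., Asif, A., & Mohammadi, A. (2019). 
<span style="color:dodgerblue">**XceptionTime: A Novel Deep Architecture based on Depthwise Separable Convolutions for Hand Gesture Classification**</span>.
arXiv preprint arXiv:1911.03803.

and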

Fawaz, H. I., Lucas, B., Forestier, G., Pelletier, C., Schmidt, D. F., Weber, J. & Petitjean, F. (2019). 
<span style="color:dodgerblue">**InceptionTime: Finding AlexNet for Time Series Classification**</span>.
arXiv preprint arXiv:1909.04939. 

Official InceptionTime tensorflow implementation:
https://github.com/hfawaz/InceptionTime

In [None]:
#|export
from tsai.imports import *
from tsai.models.layers import *
from tsai.models.utils import *

In [None]:
#|export
# This is an unofficial PyTorch implementation developed by Ignacio Oguiza - oguiza@timeseriesAI.co based on:
# Rahimian, E., Zabihi, S., Atashzar, S. F., Asif, A., & Mohammadi, A. (2019). 
# XceptionTime: A Novel Deep Architecture based on Depthwise Separable Convolutions for Hand Gesture Classification. arXiv preprint arXiv:1911.03803.
# and 
# Fawaz, H. I., Lucas, B., Forestier, G., Pelletier, C., Schmidt, D. F., Weber, J., ... & Petitjean, F. (2019). 
# InceptionTime: Finding AlexNet for Time Series Classification. arXiv preprint arXiv:1909.04939.
# Official InceptionTime tensorflow implementation: https://github.com/hfawaz/InceptionTime
    
class XceptionModule(Module):
    """Multi-scale building block: three separable convolutions plus a max-pool branch.

    ni: number of input channels.
    nf: number of filters produced by each of the four branches.
    ks: largest kernel size; the three separable convs use ks, ks//2 and ks//4, each made odd.
    bottleneck: if truthy, a bias-free 1x1 conv maps `ni` -> `nf` before the separable convs.
    """
    def __init__(self, ni, nf, ks=40, bottleneck=True):
        # Derive the three kernel sizes; an even size is lowered by one so every kernel is odd.
        kernel_sizes = []
        for scale in range(3):
            k = ks // (2 ** scale)
            kernel_sizes.append(k - 1 if k % 2 == 0 else k)

        # Without a bottleneck the separable convs read the raw input, so they take `ni` channels.
        if bottleneck:
            self.bottleneck = Conv1d(ni, nf, 1, bias=False)
            sep_in = nf
        else:
            self.bottleneck = noop
            sep_in = ni

        self.convs = nn.ModuleList([SeparableConv1d(sep_in, nf, k, bias=False) for k in kernel_sizes])
        # Pooling branch always starts from the un-bottlenecked input, hence `ni` input channels.
        self.maxconvpool = nn.Sequential(nn.MaxPool1d(3, stride=1, padding=1), Conv1d(ni, nf, 1, bias=False))
        self.concat = Concat()

    def forward(self, x):
        """Run all four branches on `x` and return their outputs combined by `Concat`."""
        squeezed = self.bottleneck(x)
        branches = [conv(squeezed) for conv in self.convs]
        branches.append(self.maxconvpool(x))
        return self.concat(branches)

    
@delegates(XceptionModule.__init__)
class XceptionBlock(Module):
    """Four stacked `XceptionModule`s with an optional residual connection around each pair.

    ni: number of channels of the input tensor.
    nf: base number of filters; module `i` is built with `nf * 2**i` filters per branch.
    residual: if truthy, after the 2nd and 4th module a projected copy of the pair's input
        is added to the activation and passed through a ReLU.
    **kwargs: forwarded unchanged to every `XceptionModule`.
    """
    def __init__(self, ni, nf, residual=True, **kwargs):
        self.residual = residual
        self.xception, self.shortcut = nn.ModuleList(), nn.ModuleList()
        for i in range(4):
            n_out = nf * 2 ** i
            # Each module combines four branches of its filter count, so module i - 1 hands
            # 4 * (n_out / 2) == 2 * n_out channels to module i.
            n_in = ni if i == 0 else n_out * 2
            if i % 2 == 0:
                # First module of a pair: remember the width of the tensor the shortcut will carry.
                res_in = n_in
            elif self.residual:
                # Second module of a pair emits 4 * n_out channels, so the shortcut has to map
                # res_in -> 4 * n_out. The previous `BN1d` branch was selected when ni == nf, i.e.
                # exactly when the widths differ (ni vs 8 * nf), which made `Add` fail in forward.
                # Always projecting keeps every previously working configuration identical.
                # The shortcut is created before its module, as before, to keep init order stable.
                self.shortcut.append(ConvBlock(res_in, n_out * 4, 1, act=None))
            self.xception.append(XceptionModule(n_in, n_out, **kwargs))
        self.add = Add()
        self.act = nn.ReLU()

    def forward(self, x):
        """Apply the four modules in order, closing a residual connection after every second one."""
        res = x
        for i in range(4):
            x = self.xception[i](x)
            if self.residual and i % 2 == 1:
                # Add the projected pair input, activate, and restart the residual from here.
                res = x = self.act(self.add(x, self.shortcut[i // 2](res)))
        return x
    
    
@delegates(XceptionBlock.__init__)
class XceptionTime(Module):
    """XceptionTime model: an `XceptionBlock` backbone followed by a convolutional head.

    c_in: number of input channels passed to the backbone.
    c_out: number of output channels of the last head convolution.
    nf: base number of filters of the backbone.
    nb_filters: alternative spelling of `nf`; it is only the fallback argument of `ifnone`,
        so it is presumably used only when `nf` is None.
    adaptive_size: target length of the adaptive average pooling at the start of the head.
    **kwargs: forwarded to `XceptionBlock`.
    """
    def __init__(self, c_in, c_out, nf=16, nb_filters=None, adaptive_size=50, **kwargs):
        nf = ifnone(nf, nb_filters)
        self.block = XceptionBlock(c_in, nf, **kwargs)
        # Width the head expects from the backbone (last module: 8 * nf filters x 4 branches).
        self.head_nf = nf * 32
        half, quarter = self.head_nf // 2, self.head_nf // 4
        # The head narrows the channels in three 1x1 conv steps between two pooling layers.
        head_layers = [
            nn.AdaptiveAvgPool1d(adaptive_size),
            ConvBlock(self.head_nf, half, 1),
            ConvBlock(half, quarter, 1),
            ConvBlock(quarter, c_out, 1),
            GAP1d(1),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the head's output for the backbone features of `x`."""
        return self.head(self.block(x))

In [None]:
bs = 16
c_in = 3  # number of input variables; not called `vars` so the builtin `vars()` stays usable
seq_len = 12
c_out = 6
xb = torch.rand(bs, c_in, seq_len)
# The output must be (bs, c_out) for the default model and with each architectural switch turned off.
test_eq(XceptionTime(c_in, c_out)(xb).shape, [bs, c_out])
test_eq(XceptionTime(c_in, c_out, bottleneck=False)(xb).shape, [bs, c_out])
test_eq(XceptionTime(c_in, c_out, residual=False)(xb).shape, [bs, c_out])
# Pin the parameter count of the default architecture to catch unintended structural changes.
test_eq(count_parameters(XceptionTime(3, 2)), 399540)

In [None]:
# Build a small model (c_in=2, c_out=3) and check its batch-norm layers and biases.
m = XceptionTime(2,3)
# Five batch-norm layers are expected: one in each of the two shortcuts and three in the head.
test_eq(check_weight(m, is_bn)[0].sum(), 5) # 2 shortcut + 3 bn
# No conv layer should report a bias — presumably because the convs are created with bias=False;
# TODO confirm against `check_bias` / `ConvBlock`.
test_eq(len(check_bias(m, is_conv)[0]), 0)
# The only biases left should be those of the five batch-norm layers counted above.
test_eq(len(check_bias(m)[0]), 5) # 2 shortcut + 3 bn

In [None]:
# Display the module tree of a model built with c_in=3 and c_out=2 (the repr is the cell output).
XceptionTime(3, 2)

XceptionTime(
  (block): XceptionBlock(
    (xception): ModuleList(
      (0): XceptionModule(
        (bottleneck): Conv1d(3, 16, kernel_size=(1,), stride=(1,), bias=False)
        (convs): ModuleList(
          (0): SeparableConv1d(
            (depthwise_conv): Conv1d(16, 16, kernel_size=(39,), stride=(1,), padding=(19,), groups=16, bias=False)
            (pointwise_conv): Conv1d(16, 16, kernel_size=(1,), stride=(1,), bias=False)
          )
          (1): SeparableConv1d(
            (depthwise_conv): Conv1d(16, 16, kernel_size=(19,), stride=(1,), padding=(9,), groups=16, bias=False)
            (pointwise_conv): Conv1d(16, 16, kernel_size=(1,), stride=(1,), bias=False)
          )
          (2): SeparableConv1d(
            (depthwise_conv): Conv1d(16, 16, kernel_size=(9,), stride=(1,), padding=(4,), groups=16, bias=False)
            (pointwise_conv): Conv1d(16, 16, kernel_size=(1,), stride=(1,), bias=False)
          )
        )
        (maxconvpool): Sequential(
          (0):

In [None]:
#|eval: false
#|hide
# Resolve this notebook's name from the current namespace, then hand it to the script generator.
from tsai.export import get_nb_name
nb_name = get_nb_name(locals())
from tsai.imports import create_scripts
create_scripts(nb_name)

<IPython.core.display.Javascript object>

/Users/nacho/notebooks/tsai/nbs/106_models.XceptionTime.ipynb saved at 2022-11-09 13:06:02
Correct notebook to script conversion! 😃
Wednesday 09/11/22 13:06:04 CET
