In [None]:
# default_exp models.ResNet

# ResNet

> This is an unofficial PyTorch implementation by Ignacio Oguiza - oguiza@gmail.com based on:
* Wang, Z., Yan, W., & Oates, T. (2017, May). Time series classification from scratch with deep neural networks: A strong baseline. In 2017 international joint conference on neural networks (IJCNN) (pp. 1578-1585). IEEE.
* Fawaz, H. I., Forestier, G., Weber, J., Idoumghar, L., & Muller, P. A. (2019). Deep learning for time series classification: a review. Data Mining and Knowledge Discovery, 33(4), 917-963.
* Official ResNet TensorFlow implementation: https://github.com/hfawaz/dl-4-tsc
* 👀 The kernel size of the first convolution in each residual block has been changed from 8 to 7. I believe 8 is a bug, since even kernel sizes are rarely used in practice.

In [None]:
#export
from tsai.imports import *
from tsai.models.layers import *

In [None]:
#export
class ResBlock(Module):
    """Residual block: three 'same'-padded convolutions plus a shortcut, summed and then ReLU'd.

    Args:
        ni: number of input channels.
        nf: number of output channels of every convolution in the block.
        ks: kernel sizes of the three convolutions, in order (any indexable of length >= 3).
    """
    def __init__(self, ni, nf, ks=(7, 5, 3)):
        self.conv1 = Conv1d(ni, nf, ks[0], padding='same', act_fn='relu')
        self.conv2 = Conv1d(nf, nf, ks[1], padding='same', act_fn='relu')
        # No activation on the last conv: the ReLU is applied once, after the residual sum
        # (conv -> BN -> add -> ReLU, as in the reference implementation).
        self.conv3 = Conv1d(nf, nf, ks[2], padding='same', act_fn=False)

        # expand channels for the sum if necessary (1x1 projection, no activation)
        self.shortcut = noop if ni == nf else Conv1d(ni, nf, ks=1, act_fn=False)
        self.act_fn = nn.ReLU()

    def forward(self, x):
        res = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # Out-of-place add: an in-place `+=` would modify a tensor autograd may have saved
        # for the backward pass of the preceding layer.
        x = x + self.shortcut(res)
        return self.act_fn(x)
    
class ResNet(Module):
    """Time-series ResNet: three residual blocks, global average pooling over the last axis,
    and a linear head.

    Args:
        c_in: number of input channels.
        c_out: number of output units (e.g. classes) of the final linear layer.
        nf: number of filters of the first block; the second and third blocks use `nf * 2`.
        ks: kernel sizes of the three convolutions used inside every residual block.
    """
    def __init__(self, c_in, c_out, nf=64, ks=(7, 5, 3)):
        self.block1 = ResBlock(c_in, nf, ks=ks)
        self.block2 = ResBlock(nf, nf * 2, ks=ks)
        self.block3 = ResBlock(nf * 2, nf * 2, ks=ks)
        self.gap = nn.AdaptiveAvgPool1d(1)
        self.squeeze = Squeeze(-1)
        self.fc = nn.Linear(nf * 2, c_out)

    def forward(self, x):
        # assumes x is (batch, c_in, seq_len), the nn.Conv1d layout
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        # (batch, nf*2, seq_len) -> (batch, nf*2, 1) -> (batch, nf*2)
        x = self.squeeze(self.gap(x))
        return self.fc(x)

In [None]:
# Smoke test: a batch of 16 series with 3 channels and 128 steps maps to 16 x 2 logits.
xb = torch.rand(16, 3, 128)
model = ResNet(3, 2)  # build once so the checked model is the one displayed below
test_eq(model(xb).shape, [xb.shape[0], 2])
model

ResNet(
  (block1): ResBlock(
    (conv1): Sequential(
      (0): ConvSP1d(
        (conv): Conv1d(3, 64, kernel_size=(7,), stride=(1,))
      )
      (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (conv2): Sequential(
      (0): ConvSP1d(
        (conv): Conv1d(64, 64, kernel_size=(5,), stride=(1,))
      )
      (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (conv3): Sequential(
      (0): ConvSP1d(
        (conv): Conv1d(64, 64, kernel_size=(3,), stride=(1,))
      )
      (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (shortcut): Sequential(
      (0): ConvSP1d(
        (conv): Conv1d(3, 64, kernel_size=(1,), stride=(1,))
      )
      (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (act_fn): ReLU()
  )
  (block2): ResBlock(
    (conv1): Seq

In [None]:
#hide
# Export step (hidden from the docs): presumably regenerates the library modules from this
# notebook's `#export` cells — TODO confirm against nbdev/tsai's `create_scripts`.
# `beep` is then called with its result, presumably as a completion signal.
out = create_scripts()
beep(out)