In [1]:
# default_exp models.ResNet

# ResNet

> This is an unofficial PyTorch implementation by Ignacio Oguiza - oguiza@gmail.com based on:
* Wang, Z., Yan, W., & Oates, T. (2017, May). Time series classification from scratch with deep neural networks: A strong baseline. In 2017 international joint conference on neural networks (IJCNN) (pp. 1578-1585). IEEE.
* Fawaz, H. I., Forestier, G., Weber, J., Idoumghar, L., & Muller, P. A. (2019). Deep learning for time series classification: a review. Data Mining and Knowledge Discovery, 33(4), 917-963.
* Official ResNet TensorFlow implementation: https://github.com/hfawaz/dl-4-tsc
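* 👀 The kernel size of the first convolution in each residual block has been changed from 8 (as in the reference implementation) to 7. I believe 8 is a bug, since even kernel sizes are rarely used in practice (an odd size also allows symmetric 'same' padding).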

In [2]:
#export
from tsai.imports import *
from tsai.models.layers import *

In [3]:
#export
class ResBlock(Module):
    """Residual block of three 'same'-padded 1-D convolutions plus a shortcut.

    Layout (after Wang et al., 2017): conv(ks[0]) + ReLU -> conv(ks[1]) + ReLU
    -> conv(ks[2]) with no activation, summed with the (optionally projected)
    input, then a single ReLU. Each conv is built with tsai's `Conv1d`; whether
    it also includes normalization is defined there, not here.

    Args:
        ni: number of input channels.
        nf: number of output channels.
        ks: kernel sizes of the three convolutions (indexable, first three used).
    """
    def __init__(self, ni, nf, ks=(7, 5, 3)):
        self.conv1 = Conv1d(ni, nf, ks[0], padding='same', act_fn='relu')
        self.conv2 = Conv1d(nf, nf, ks[1], padding='same', act_fn='relu')
        # No activation on the last conv: in the reference design the ReLU is
        # applied only after the residual sum (see forward).
        self.conv3 = Conv1d(nf, nf, ks[2], padding='same', act_fn=False)

        # Expand channels with a 1x1 conv when ni != nf so the sum is well-defined.
        self.shortcut = noop if ni == nf else Conv1d(ni, nf, ks=1, act_fn=False)
        self.act_fn = nn.ReLU()

    def forward(self, x):
        res = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # Out-of-place add: an in-place `+=` on conv3's output could invalidate a
        # tensor saved for backward if Conv1d's last layer keeps its output
        # (as ReLU does), raising an autograd error during training.
        x = x + self.shortcut(res)
        return self.act_fn(x)
    
class ResNet(Module):
    """ResNet for time series classification (Wang et al., 2017).

    Three residual blocks with `nf`, `2 * nf` and `2 * nf` channels, global
    average pooling over the time axis, and a linear classifier.

    Args:
        c_in: number of input channels (variables) of the series.
        c_out: number of outputs (classes).
        nf: number of filters in the first block; later blocks use `2 * nf`.
            Default 64.
        ks: kernel sizes of the three convolutions inside every block.
            Default (7, 5, 3); the reference implementation uses 8 for the
            first one (only usable if Conv1d's 'same' padding supports even
            kernels — confirm before relying on it).
    """
    def __init__(self, c_in, c_out, nf=64, ks=(7, 5, 3)):
        self.block1 = ResBlock(c_in, nf, ks=ks)
        self.block2 = ResBlock(nf, nf * 2, ks=ks)
        self.block3 = ResBlock(nf * 2, nf * 2, ks=ks)
        # Average over the time axis: (bs, 2*nf, seq_len) -> (bs, 2*nf, 1).
        self.gap = nn.AdaptiveAvgPool1d(1)
        # Drop the trailing length-1 axis -> (bs, 2*nf).
        self.squeeze = Squeeze(-1)
        self.fc = nn.Linear(nf * 2, c_out)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.squeeze(self.gap(x))
        return self.fc(x)

In [7]:
torch.manual_seed(0)  # make xb and the randomly initialised models below reproducible
xb = torch.rand(16, 3, 128)  # (batch, channels, seq_len)
test_eq(ResNet(3, 2)(xb).shape, [xb.shape[0], 2])

In [15]:
xb[0, :, :10]  # first 10 time steps of every channel of sample 0, instead of the full 3x128 dump

In [16]:
logits = ResNet(3, 5)(xb)
test_eq(logits.shape, [xb.shape[0], 5])
logits[:3]  # show a few rows rather than the whole batch