In [1]:
from all_imports import *

In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [10]:
from ptbds import *

In [8]:
from nets import *

In [4]:
tfms

([<fastai.transforms.Normalize object at 0x7fb2890f2588>, <ptbds.ChannelOrder1d object at 0x7fb2890f2710>],
 [<fastai.transforms.Normalize object at 0x7fb2890f2588>, <ptbds.ChannelOrder1d object at 0x7fb2890f26a0>])

In [5]:
class PreActBlock(nn.Module):
    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm1d(in_planes)
        self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_planes, planes, kernel_size=1, stride=stride, bias=False)
            )

        # SE layers
        self.fc1 = nn.Conv1d(planes, planes//16, kernel_size=1)
        self.fc2 = nn.Conv1d(planes//16, planes, kernel_size=1)

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))

        # Squeeze
        w = F.avg_pool1d(out, out.size(2))
        w = F.relu(self.fc1(w))
        w = F.sigmoid(self.fc2(w))
        # Excitation
        out = out * w

        out += shortcut
        return out

In [6]:
class  senet_small(nn.Module):
    def __init__(self, block, inc_list, inc_scale, num_blocks_list, stride_list, num_classes):
        super().__init__()
        self.num_blocks = len(num_blocks_list)
        inc_list1 = [o//inc_scale for o in inc_list]
        self.in_planes = inc_list1[0]
        self.conv1 = nn.Conv1d(15, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(self.in_planes)
        
        lyrs = []
        for inc, nb, strl in zip(inc_list1[1:], num_blocks_list, stride_list):
            lyrs.append(self._make_layer(block, inc, nb, strl))
            
        self.lyrs = nn.Sequential(*lyrs)
        self.linear = nn.Linear(inc_list1[-1], num_classes)
        
    
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes
        return nn.Sequential(*layers)
    
    def forward(self, inp):
#         import pdb; pdb.set_trace()
        out = F.relu(self.bn1(self.conv1(inp)))
        out = self.lyrs(out)
        out = F.adaptive_avg_pool1d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return F.log_softmax(out, dim=-1)

In [11]:
snt_mdl = senet_small(PreActBlock, 
                inc_list=[64, 64, 128, 256], 
                inc_scale = 4,
                num_blocks_list=[2, 3, 2], 
                stride_list=[1, 2, 2], 
                num_classes=2)

learn = ConvLearner.from_model_data(snt_mdl, data)

In [7]:
learn.fit(1e-1, 1, cycle_len=100, best_save_name='mlcnn_ecg1')

HBox(children=(IntProgress(value=0, description='Epoch'), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                  
    0      0.72717    0.698167   0.22619   
    1      0.68561    0.634217   0.833333                  
    2      0.645285   0.570218   0.833333                  
    3      0.610145   0.525978   0.833333                  
    4      0.581513   0.498811   0.833333                 
    5      0.552682   0.361611   0.869048                  
    6      0.513631   1.108742   0.833333                  
    7      0.479467   0.301442   0.857143                  
    8      0.452487   0.340058   0.833333                  
    9      0.426542   2.819387   0.297619                  
    10     0.467622   0.459335   0.833333                  
    11     0.478169   0.440313   0.833333                  
    12     0.482361   0.402384   0.845238                  
    13     0.45671    0.451397   0.857143                  
    14     0.419509   0.351873   0.821429                  
    15     0.383924   0.482709   0.880952                

[array([0.03576]), 0.9880952380952381]

In [9]:
get_scores(learn)

Ac 0.9880952380952381                        
f1 [0.96552 0.99281]
prec [0.93333 1.     ]
rec [1.      0.98571]


(array([[[ -9.2088 ,  -0.0001 ],
         [ -9.99127,  -0.00005],
         [-22.44883,   0.     ],
         [-12.88729,  -0.     ],
         [-17.23374,   0.     ],
         [ -9.91783,  -0.00005],
         [ -0.05526,  -2.92326],
         [ -9.26699,  -0.00009],
         [-11.78953,  -0.00001],
         [-22.85287,   0.     ],
         [ -0.00016,  -8.7154 ],
         [-11.3149 ,  -0.00001],
         [-20.18568,   0.     ],
         [-24.08165,   0.     ],
         [ -0.00207,  -6.18084],
         [ -5.30582,  -0.00498],
         [-14.75025,  -0.     ],
         [-13.39438,  -0.     ],
         [-12.92101,  -0.     ],
         [-14.02743,  -0.     ],
         [ -1.02202,  -0.44608],
         [-21.32597,   0.     ],
         [-23.44081,   0.     ],
         [ -0.3175 ,  -1.30184],
         [ -9.27776,  -0.00009],
         [ -9.46274,  -0.00008],
         [ -9.12377,  -0.00011],
         [ -6.29657,  -0.00184],
         [ -3.39858,  -0.03399],
         [-16.21363,   0.     ],
         [

In [12]:
learn.load('mlcnn_ecg1')

In [13]:
get_scores(learn)

Ac 0.9880952380952381                        
f1 [0.96552 0.99281]
prec [0.93333 1.     ]
rec [1.      0.98571]


(array([[[ -9.32023,  -0.00009],
         [-11.74212,  -0.00001],
         [-23.09288,   0.     ],
         [-10.61829,  -0.00002],
         [-13.78455,  -0.     ],
         [ -9.86735,  -0.00005],
         [ -0.16391,  -1.88929],
         [-15.75846,   0.     ],
         [-11.49766,  -0.00001],
         [-21.6105 ,   0.     ],
         [ -0.00017,  -8.66755],
         [-11.62249,  -0.00001],
         [-19.17711,   0.     ],
         [-21.3015 ,   0.     ],
         [ -0.00277,  -5.89099],
         [ -3.77016,  -0.02332],
         [-13.95063,  -0.     ],
         [-14.93376,  -0.     ],
         [-18.23184,   0.     ],
         [ -8.91646,  -0.00013],
         [ -3.97635,  -0.01893],
         [-20.59655,   0.     ],
         [-20.99894,   0.     ],
         [ -0.17756,  -1.81589],
         [ -9.11848,  -0.00011],
         [ -9.46274,  -0.00008],
         [ -8.23643,  -0.00026],
         [ -6.61817,  -0.00134],
         [ -3.06521,  -0.04777],
         [-17.60469,   0.     ],
         [