In [15]:
from typing import Optional, Tuple
import os
import wfdb
import numpy as np
from sklearn.preprocessing import LabelEncoder


def _consecutive_runs(symbols):
    """ It groups a sequence of symbols into runs of identical consecutive symbols.

    :param symbols:     The sequence of annotation symbols.

    :return:            A list of tuples (first, last, symbol) where first is the index of the first occurrence of
                        symbol in the run and last is the index of its last consecutive occurrence (inclusive).
    """
    runs = []
    total = len(symbols)
    first = 0
    while first < total:
        symbol = symbols[first]
        last = first
        while last + 1 < total and symbols[last + 1] == symbol:
            last += 1
        runs.append((first, last, symbol))
        first = last + 1
    return runs


def _fit_to_length(segment, fixed_length):
    """ It truncates a 1-D segment, or pads it with zeros to the right, so that it has fixed_length samples.

    :param segment:         The 1-D array to adjust.
    :param fixed_length:    The target number of samples.

    :return:                The adjusted segment (the same object if it already has the requested length).
    """
    current_length = len(segment)
    if current_length > fixed_length:  # truncate
        return segment[:fixed_length]
    if current_length < fixed_length:  # pad with zeros to the right (this also covers empty segments)
        return np.pad(segment, (0, fixed_length - current_length), mode='constant', constant_values=0)
    return segment


def get_segments(data: wfdb.Record,
                 annotations: wfdb.Annotation,
                 labels: np.ndarray,
                 left_offset: int = 99,
                 right_offset: int = 160,
                 fixed_length: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
    """ It generates the segments of uninterrupted sequences of arrhythmia beats that belong to one of the
    arrhythmia groups in labels.

    :param data:            The arrhythmia signal as a wfdb Record class. Only data.p_signal is read.
    :param annotations:     The set of annotations as a wfdb Annotation class. Only annotations.symbol and
                            annotations.sample are read.
    :param labels:          The set of valid labels for the different segments. Segments with different labels are discarded
    :param left_offset:     The number of instances at the left of the first R peak of the segment. Default to 99
    :param right_offset:    The number of instances at the right of the last R peak of the segment. Default to 160
    :param fixed_length:    Should the segments have a fixed length? If fixed_length is a number, then the segments will
                            have the specified length. If the segment length is greater than fixed_length, it is truncated
                            or padded with zeros otherwise. Default to None.

    :return:                A tuple that contains the data and the associated labels. Data has a shape of (N, T, V)
                            where N is the number of segments (or instances), V is the number of variables (1 in this case)
                            and T is the number of timesteps of each segment. Labels are numerically encoded with a
                            LabelEncoder fitted on the :parameter labels param, so the numeric codes follow the sorted
                            order of labels, not the order in which they are passed.
                            If no run of annotations has a valid label, empty arrays of shape (0, T, 1) and (0,) are
                            returned (T is fixed_length, or 0 when fixed_length is None).

    :raises ValueError:     If fixed_length is None and the extracted segments have different lengths, since they
                            cannot be stacked into a single array.
    """
    signal_length = len(data.p_signal)

    # For each run of consecutive identical symbols, get the X segment if the symbol is a valid label.
    segments = []
    classes = []
    for first, last, symbol in _consecutive_runs(annotations.symbol):
        if symbol not in labels:
            continue
        classes.append(symbol)

        # The window starts left_offset samples before the first annotated sample of the run ...
        init = max(annotations.sample[first] - left_offset, 0)

        # ... and stops right_offset samples after the last one. `end` is an exclusive bound.
        # NOTE: when clamped, `end` is set to the last valid index, so the very last sample of the record is
        # never included. This is kept as is to preserve the extracted data.
        end = annotations.sample[last] + right_offset
        if end >= signal_length:
            end = signal_length - 1
        # An inverted window yields an empty segment (as range(init, end) would); this guard also keeps a
        # negative `end` from being interpreted as "count from the end" by the slice below.
        end = max(end, init)

        # Get the samples of the segment. p_signal is a 2D array indexed as [sample, column]; column index 1
        # (the second column) is the one that is read.
        new_segment = np.array(data.p_signal[init:end, 1], dtype='float32')

        # truncate or pad with zeros the segment if necessary
        if fixed_length is not None:
            new_segment = _fit_to_length(new_segment, fixed_length)

        segments.append(new_segment)

    if not segments:
        # np.stack cannot handle an empty list: return empty arrays that can still be concatenated with the
        # results of other records.
        timesteps = fixed_length if fixed_length is not None else 0
        return (np.empty((0, timesteps, 1), dtype='float32'), np.empty((0,), dtype=np.int64))

    result = np.stack(segments, axis=0)
    result = np.reshape(result, (result.shape[0], result.shape[1], 1))  # shape[0] segments with 1 variable, with shape[1] timestamps each
    classes = np.array(classes, dtype=str)

    # Encode labels: from string to numeric.
    label_encoder = LabelEncoder()
    label_encoder.fit(labels)
    classes = label_encoder.transform(classes)

    return (result, classes)



def read_MIT_BIH(path: str,
                 labels: np.ndarray = np.array(['N','L','R','A','V']),
                 left_offset: int = 99,
                 right_offset: int = 160,
                 fixed_length: Optional[int] = 1000) -> Tuple[np.ndarray, np.ndarray]:
    """ It reads the MIT-BIH Arrhythmia X with the specified default configuration of the work presented at:
    Oh, Shu Lih, et al. "Automated diagnosis of arrhythmia using combination of CNN and LSTM techniques with
    variable length heart beats." Computers in biology and medicine 102 (2018): 278-287.

    :param path:                The path of the directory where the X files are stored. Note: The X and annotations
                                files must have the same name, but different extension (annotations must have .atr extension).
                                A trailing path separator is optional.
    :param labels:              The labels of the different types of arrhythmia to be employed
    :param left_offset:         The number of instances at the left of the first R peak of the segment. Defaults to 99
    :param right_offset:        The number of instances at the right of the last R peak of the segment. Defaults to 160
    :param fixed_length:        If different to None, the segment will have the specified number of instances. Note that
                                if the segment length > fixed_length it will be truncated or padded with zeros otherwise.

    :return:                    A tuple that contains the data and the associated labels as an ndarray. Data has a shape of (N, T, V)
                                where N is the number of segments (or instances), V is the number of variables (1 in this case)
                                and T is the number of timesteps of each segment.  Labels are numerically encoded according to the
                                value passed in the :parameter labels param. Records are processed in sorted name order.

    :raises ValueError:         If the directory does not contain any .dat file.
    """
    print("reading data...")

    # Every '.dat' file is a record; the record name is the file name without its extension.
    # os.listdir returns entries in arbitrary order, so they are sorted to make the row order reproducible.
    record_names = sorted(file[:-4] for file in os.listdir(path) if file.endswith('.dat'))
    if not record_names:
        raise ValueError("No .dat record files found in directory: " + repr(path))

    segments = []
    classes = []
    for name in record_names:
        # os.path.join works whether or not `path` ends with a separator (plain concatenation did not).
        record_path = os.path.join(path, name)
        data = wfdb.rdrecord(record_path)
        annotation = wfdb.rdann(record_path, 'atr')

        s, clazz = get_segments(data=data,
                                annotations=annotation,
                                labels=labels,
                                left_offset=left_offset,
                                right_offset=right_offset,
                                fixed_length=fixed_length)

        segments.append(s)
        classes.append(clazz)

    # Concatenate the per-record results along the first (segment) axis.
    segments = np.vstack(segments)
    classes = np.concatenate(classes)
    print("done.")

    return (segments, classes)

# Read the data
# NOTE(review): the name `dir` shadows the built-in dir(); it is kept because later cells may rely on it.
dir = "../physionet.org/files/mitdb/1.0.0/"

X, y = read_MIT_BIH(dir)

# Show the shape of the input data. In total there are 16499 time series
# of 1 variable with 1000 time steps each.
# Each time series has exactly one associated value (class), which identifies the type of arrhythmia.
print("X shape: ", X.shape)
print("y shape: ", y.shape)

reading data...
done.
X shape:  (16499, 1000, 1)
y shape:  (16499,)


In [47]:
%load_ext autoreload
%autoreload 2

# Temporary workaround so that SADL can be imported from the parent directory when it is not installed.
# If SADL is installed, the sys.path.append line below is not needed.
import sys
import os
import inspect
import keras
from inspect import signature
from inspect import signature
import numpy as np
sys.path.append(os.path.abspath(os.path.join(os.path.dirname("__file__"), '..')))
from SADL.time_series.algorithms import tsfedl
from TSFEDL.models_keras import OhShuLih


# Input tensor matching the per-series shape of X: (timesteps, variables).
input_layer = keras.Input(shape=(X.shape[1], X.shape[2]))

# Keyword arguments forwarded to the detector.
# NOTE(review): "hello" does not look like a real parameter; presumably it is here to check how
# unknown keys are handled. Confirm and remove if it is a leftover.
kwargs = dict(
    algorithm_="ohshulih",
    include_top=False,
    label_parser=True,
    input_tensor=input_layer,
    hello=5,
)
model1 = tsfedl.TsfedlAnomalyDetection(**kwargs)

# Print the parameters reported by the model.
print(model1.get_params())



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
{'label_parser': True, 'algorithm_': 'OhShuLih', 'include_top': False, 'weights': [<tf.Variable 'conv1d_111/kernel:0' shape=(20, 1, 3) dtype=float32, numpy=
array([[[-0.03382894, -0.23153992, -0.04951137]],

       [[ 0.23038143,  0.1741544 , -0.23614272]],

       [[ 0.17115462, -0.04498816, -0.02411483]],

       [[ 0.16918471, -0.11182962, -0.01625282]],

       [[-0.16344331,  0.16488951, -0.25568414]],

       [[-0.20773864, -0.2701233 ,  0.05473956]],

       [[-0.2427876 , -0.01873803, -0.06627446]],

       [[ 0.14307576,  0.26460987,  0.22212988]],

       [[-0.14187951,  0.2649634 , -0.25059977]],

       [[-0.07035583, -0.24914289,  0.09550479]],

       [[-0.21800587,  0.06808323, -0.082471  ]],

       [[ 0.07254773, -0.00672832,  0.05980778]],

       [[ 0.22880137,  0.26418096,  0.12697741]],

       [[-0.17699036, -0.13283086, -0.0548362 ]],

       [[ 0.11246127, -0.24275808,  0.223