In [1]:
import numpy as np
import tensorflow as tf

In [53]:

def _inception_module(input_tensor, stride=1, activation='linear', use_bottleneck=True, kernel_size=40, bottleneck_size=32, nb_filters=32):
    """Apply one InceptionTime-style module to ``input_tensor``.

    Three parallel Conv1D branches with kernel widths ``kernel_size``,
    ``kernel_size // 2`` and ``kernel_size // 4`` (40, 20, 10 by default) read from an
    optional 1x1 bottleneck. A fourth branch max-pools the *raw* input and projects it
    with a 1x1 Conv1D. The four branch outputs are concatenated on axis 2, batch
    normalized, and passed through ReLU.

    Layer creation order matters (auto-generated names and HDF5 weight loading by
    topology depend on it), so layers are built in a fixed sequence:
    bottleneck, conv branches, max-pool, pool projection, concat, BN, ReLU.

    Args:
        input_tensor: Keras tensor fed to Conv1D, so presumably (batch, steps, channels).
        stride: Stride of the three wide conv branches and of the max-pool.
        activation: Activation of every Conv1D in the module (not of the final ReLU).
        use_bottleneck: Whether to insert the 1x1 bottleneck. It is skipped anyway when
            the input has a single channel.
        kernel_size: Width of the widest conv branch; each further branch halves it.
        bottleneck_size: Filters in the bottleneck conv.
        nb_filters: Filters in each of the four branches.

    Returns:
        Keras tensor with ``4 * nb_filters`` channels.
    """
    layers = tf.keras.layers

    # `and` short-circuits, so the channel count is only inspected when a bottleneck
    # was requested.
    wants_bottleneck = use_bottleneck and int(input_tensor.shape[-1]) > 1
    if wants_bottleneck:
        branch_input = layers.Conv1D(filters=bottleneck_size, kernel_size=1, padding='same',
                                     activation=activation, use_bias=False)(input_tensor)
    else:
        branch_input = input_tensor

    # Multi-scale branches: kernel widths halve at each step.
    branches = [
        layers.Conv1D(filters=nb_filters, kernel_size=kernel_size // (2 ** scale),
                      strides=stride, padding='same', activation=activation,
                      use_bias=False)(branch_input)
        for scale in range(3)
    ]

    # The pooling branch deliberately reads the un-bottlenecked input.
    pooled = layers.MaxPool1D(pool_size=3, strides=stride, padding='same')(input_tensor)
    pool_projection = layers.Conv1D(filters=nb_filters, kernel_size=1, padding='same',
                                    activation=activation, use_bias=False)(pooled)
    branches.append(pool_projection)

    merged = layers.Concatenate(axis=2)(branches)
    normalized = layers.BatchNormalization()(merged)
    return layers.Activation(activation='relu')(normalized)

def _shortcut_layer(input_tensor, out_tensor):
    """Add a projected residual connection from ``input_tensor`` onto ``out_tensor``.

    ``input_tensor`` is projected with a bias-free 1x1 Conv1D to the channel count of
    ``out_tensor``, batch normalized, summed with ``out_tensor``, and passed through ReLU.
    The Conv1D does not stride, so both tensors must already share their time length.

    Returns:
        The ReLU-activated sum, shaped like ``out_tensor``.
    """
    target_channels = int(out_tensor.shape[-1])

    projected = tf.keras.layers.Conv1D(filters=target_channels, kernel_size=1,
                                       padding='same', use_bias=False)(input_tensor)
    projected = tf.keras.layers.BatchNormalization()(projected)

    summed = tf.keras.layers.Add()([projected, out_tensor])
    return tf.keras.layers.Activation('relu')(summed)

def _inception_backbone(sig_len, n_features, depth=10, use_residual=True):
    """Build the InceptionTime-style feature extractor shared by all model builders.

    Stacks ``depth`` calls to ``_inception_module``. When ``use_residual`` is True, each
    third module (d = 2, 5, 8, ...) is followed by ``_shortcut_layer``, joining it to the
    previous residual point (initially the input). The final features are globally
    average-pooled over the time axis.

    Keep the layer creation order stable: auto-generated layer names and HDF5 weight
    loading by topology (e.g. ``pretrained_model.h5``) depend on it.

    Args:
        sig_len: Number of time steps per example.
        n_features: Number of input channels.
        depth: Number of inception modules to stack.
        use_residual: Whether to add the residual shortcut connections.

    Returns:
        ``(input_layer, gap_layer)``: the Keras ``Input`` tensor and the pooled features.
    """
    input_layer = tf.keras.layers.Input(shape=(sig_len, n_features))

    x = input_layer
    input_res = input_layer
    for d in range(depth):
        x = _inception_module(x)
        if use_residual and d % 3 == 2:
            x = _shortcut_layer(input_res, x)
            input_res = x

    gap_layer = tf.keras.layers.GlobalAveragePooling1D()(x)
    return input_layer, gap_layer


def base_model(sig_len, n_features, depth=10, use_residual=True):
    """Return an *uncompiled* backbone model with a single sigmoid output unit.

    Used as the architecture that ``pretrained_model.h5`` is loaded into.
    """
    input_layer, gap_layer = _inception_backbone(sig_len, n_features, depth, use_residual)
    output = tf.keras.layers.Dense(1, activation='sigmoid')(gap_layer)
    model = tf.keras.models.Model(inputs=input_layer, outputs=output)
    return model


def build_murmur_model(sig_len, n_features, depth=10, use_residual=True, learning_rate=0.001):
    """Return a compiled 3-class softmax model whose output layer is named ``murmur_output``.

    Compiled with categorical cross-entropy (targets are presumably one-hot), Adam at
    ``learning_rate``, and categorical-accuracy + ROC-AUC metrics.
    """
    input_layer, gap_layer = _inception_backbone(sig_len, n_features, depth, use_residual)
    murmur_output = tf.keras.layers.Dense(3, activation='softmax', name="murmur_output")(gap_layer)

    model = tf.keras.models.Model(inputs=input_layer, outputs=murmur_output)
    model.compile(loss="categorical_crossentropy",
                  optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                  metrics=[tf.keras.metrics.CategoricalAccuracy(),
                           tf.keras.metrics.AUC(curve='ROC')])
    return model


def build_clinical_model(sig_len, n_features, depth=10, use_residual=True, learning_rate=0.001):
    """Return a compiled single-unit sigmoid model whose output layer is named ``clinical_output``.

    Compiled with binary cross-entropy, Adam at ``learning_rate``, and binary-accuracy +
    ROC-AUC metrics.
    """
    input_layer, gap_layer = _inception_backbone(sig_len, n_features, depth, use_residual)
    clinical_output = tf.keras.layers.Dense(1, activation='sigmoid', name="clinical_output")(gap_layer)

    model = tf.keras.models.Model(inputs=input_layer, outputs=clinical_output)
    model.compile(loss="binary_crossentropy",
                  optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                  metrics=[tf.keras.metrics.BinaryAccuracy(),
                           tf.keras.metrics.AUC(curve='ROC')])
    return model


def get_lead_index(patient_metadata):
    """Return, for each location code, the index of its first recording line.

    Lines of ``patient_metadata`` whose first space-separated token is one of
    ``AV``, ``PV``, ``TV`` or ``MV`` (presumably heart-valve auscultation sites) are
    counted as recordings, numbered 0, 1, 2, ... in file order. Other lines are ignored.
    For each distinct code, the number of its *first* recording is kept, in order of
    first appearance.

    Args:
        patient_metadata: Multi-line header text.

    Returns:
        1-D integer numpy array of recording indices (empty, but still integer-typed,
        when no recording lines are present).
    """
    valid_locations = {"AV", "PV", "TV", "MV"}
    seen = set()
    first_indices = []
    recording_idx = 0
    for line in patient_metadata.splitlines():
        location = line.split(" ")[0]
        if location not in valid_locations:
            continue
        if location not in seen:
            seen.add(location)
            first_indices.append(recording_idx)
        recording_idx += 1
    # dtype=int so an empty result is still a valid index array (np.asarray([]) is float64).
    return np.asarray(first_indices, dtype=int)

def scheduler(epoch, lr, *, milestones=(10, 15, 20), factor=0.1):
    """Step learning-rate schedule with the ``(epoch, lr)`` signature.

    This matches ``tf.keras.callbacks.LearningRateScheduler``.

    Returns ``lr * factor`` when ``epoch`` is one of ``milestones``, otherwise ``lr``
    unchanged. When the incoming ``lr`` is the optimizer's current rate, the decays
    compound. With the defaults, that is 10x lower from epoch 10, 100x from 15, and
    1000x from 20.

    Args:
        epoch: Current epoch index.
        lr: Current learning rate.
        milestones: Epochs at which to apply the decay (keyword-only).
        factor: Multiplicative decay applied at each milestone (keyword-only).
    """
    if epoch in milestones:
        return lr * factor
    return lr

def get_murmur_locations(data):
    """Return the value of the ``#Murmur locations:`` header line in ``data``.

    The value is the text between the first and second ``": "`` on that line. If several
    such lines exist, the last one that carries a value wins. A header line with no
    ``": "`` separator (e.g. ``"#Murmur locations:"``) is treated as having no value.

    Args:
        data: Multi-line header text.

    Returns:
        The murmur-locations string.

    Raises:
        ValueError: If no ``#Murmur locations:`` line with a value is present.
    """
    murmur_location = None
    for line in data.split('\n'):
        if not line.startswith('#Murmur locations:'):
            continue
        parts = line.split(': ')
        # Explicit check instead of a bare `except:` that would also swallow
        # KeyboardInterrupt and any unrelated error.
        if len(parts) > 1:
            murmur_location = parts[1]
    if murmur_location is None:
        raise ValueError('No outcome available. Is your code trying to load labels from the hidden data?')
    return murmur_location

def pad_array(data):
    """Right-pad a sequence of 1-D sequences with zeros into one 2-D float array.

    Args:
        data: Sized, indexable collection of sized rows (lengths may differ).

    Returns:
        ``np.ndarray`` of shape ``(len(data), longest_row_length)`` filled with zeros,
        with row ``j`` holding ``data[j]`` left-aligned. An empty ``data`` gives shape
        ``(0, 0)``.
    """
    width = max((len(row) for row in data), default=0)
    padded = np.zeros((len(data), width))
    for row_idx, row in enumerate(data):
        padded[row_idx, :len(row)] = row
    return padded

In [54]:
# Input geometry of the pretrained backbone. It presumably must match the shape that
# pretrained_model.h5 was trained with; the summary below shows (None, 6451, 1).
SIG_LEN = 32256 // 5
N_FEATURES = 1
model = base_model(SIG_LEN, N_FEATURES, depth=10, use_residual=True)

In [55]:
# Load the pretrained weights into the freshly built backbone. The architecture built
# above presumably has to match the checkpoint layer-for-layer.
model.load_weights(filepath="pretrained_model.h5")

In [56]:
#model = build_clinical_model(32256//5,2, depth=10, use_residual=True)

In [57]:
# New 3-unit softmax head named "murmur_output", attached to the second-to-last layer of
# the pretrained model. In base_model that is the GlobalAveragePooling1D output that fed
# the original single-unit sigmoid head.
x = tf.keras.layers.Dense(
    units=3,
    activation="softmax",
    name="murmur_output",
)(model.layers[-2].output)

In [58]:
# Rewire the pretrained graph: same input tensor and shared backbone layers (and their
# loaded weights), but ending in the new murmur head. `model.input` is the documented
# accessor and does not depend on the InputLayer sitting at index 0 of `model.layers`.
model2 = tf.keras.Model(inputs=model.input, outputs=[x])
model2.summary()

Model: "model_16"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_12 (InputLayer)          [(None, 6451, 1)]    0           []                               
                                                                                                  
 max_pooling1d_80 (MaxPooling1D  (None, 6451, 1)     0           ['input_12[0][0]']               
 )                                                                                                
                                                                                                  
 conv1d_423 (Conv1D)            (None, 6451, 32)     1280        ['input_12[0][0]']               
                                                                                                  
 conv1d_424 (Conv1D)            (None, 6451, 32)     640         ['input_12[0][0]']        

                                                                                                  
 max_pooling1d_83 (MaxPooling1D  (None, 6451, 128)   0           ['activation_107[0][0]']         
 )                                                                                                
                                                                                                  
 conv1d_439 (Conv1D)            (None, 6451, 32)     40960       ['conv1d_438[0][0]']             
                                                                                                  
 conv1d_440 (Conv1D)            (None, 6451, 32)     20480       ['conv1d_438[0][0]']             
                                                                                                  
 conv1d_441 (Conv1D)            (None, 6451, 32)     10240       ['conv1d_438[0][0]']             
                                                                                                  
 conv1d_44

 conv1d_455 (Conv1D)            (None, 6451, 32)     40960       ['conv1d_454[0][0]']             
                                                                                                  
 conv1d_456 (Conv1D)            (None, 6451, 32)     20480       ['conv1d_454[0][0]']             
                                                                                                  
 conv1d_457 (Conv1D)            (None, 6451, 32)     10240       ['conv1d_454[0][0]']             
                                                                                                  
 conv1d_458 (Conv1D)            (None, 6451, 32)     4096        ['max_pooling1d_86[0][0]']       
                                                                                                  
 concatenate_86 (Concatenate)   (None, 6451, 128)    0           ['conv1d_455[0][0]',             
                                                                  'conv1d_456[0][0]',             
          

 conv1d_473 (Conv1D)            (None, 6451, 32)     10240       ['conv1d_470[0][0]']             
                                                                                                  
 conv1d_474 (Conv1D)            (None, 6451, 32)     4096        ['max_pooling1d_89[0][0]']       
                                                                                                  
 concatenate_89 (Concatenate)   (None, 6451, 128)    0           ['conv1d_471[0][0]',             
                                                                  'conv1d_472[0][0]',             
                                                                  'conv1d_473[0][0]',             
                                                                  'conv1d_474[0][0]']             
                                                                                                  
 batch_normalization_116 (Batch  (None, 6451, 128)   512         ['concatenate_89[0][0]']         
 Normaliza