In [1]:
# define stacked model from multiple member input models
def _rename_layer(layer, new_name):
    """Set a layer's name on Keras builds with or without a writable ``name``."""
    try:
        layer.name = new_name
    except AttributeError:
        # some Keras releases (e.g. tf.keras 2.x) expose `name` as a read-only
        # property backed by the private `_name` attribute
        layer._name = new_name


def define_stacked_model(members, n_classes=3, hidden_units=10, graph_file='model_graph.png'):
    """Build and compile a stacked ensemble on top of existing member models.

    Every layer of every member is frozen and renamed with a per-member
    prefix, the member outputs are merged, and a small dense head
    (relu hidden layer followed by a softmax output layer) is put on top.

    Args:
        members: non-empty sequence of Keras models exposing ``layers``,
            ``input`` and ``output``. The members are modified in place
            (layers frozen and renamed).
        n_classes: width of the softmax output layer (default 3).
        hidden_units: width of the hidden layer of the head (default 10).
        graph_file: file the model graph is plotted to; pass ``None`` to
            skip plotting.

    Returns:
        The compiled ensemble model, with one input per member.

    Raises:
        ValueError: if ``members`` is empty.
    """
    members = list(members)
    if not members:
        raise ValueError('define_stacked_model needs at least one member model')

    # update all layers in all models to not be trainable
    for index, member in enumerate(members, start=1):
        prefix = 'ensemble_%d_' % index
        for layer in member.layers:
            # make not trainable
            layer.trainable = False
            # rename to avoid 'unique layer name' issue; names that already
            # carry the prefix are left alone so that re-running the cell on
            # the same members does not stack prefixes
            if not layer.name.startswith(prefix):
                _rename_layer(layer, prefix + layer.name)

    # define multi-headed input
    ensemble_visible = [member.input for member in members]

    # merge the output from each member; with a single member there is
    # nothing to concatenate, so its output is used directly
    ensemble_outputs = [member.output for member in members]
    if len(ensemble_outputs) == 1:
        merge = ensemble_outputs[0]
    else:
        merge = concatenate(ensemble_outputs)

    hidden = Dense(hidden_units, activation='relu')(merge)
    output = Dense(n_classes, activation='softmax')(hidden)
    ensemble = Model(inputs=ensemble_visible, outputs=output)

    # plot graph of ensemble
    if graph_file is not None:
        plot_model(ensemble, show_shapes=True, to_file=graph_file)

    # compile
    ensemble.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

    return ensemble


In [2]:
# fit a stacked model
def fit_stacked_model(model, inputX, inputy, epochs=300, verbose=0):
    """Fit a stacked model, feeding the same features to every input head.

    Args:
        model: compiled multi-input Keras model (as built by
            ``define_stacked_model``).
        inputX: feature data; the same object is passed to each input.
        inputy: class labels, one-hot encoded with ``to_categorical``
            before fitting.
        epochs: number of training epochs (default 300).
        verbose: verbosity forwarded to ``model.fit`` (default 0).
    """
    # prepare input data: one copy of the reference per input head.
    # `model.inputs` is always a list, whereas `model.input` is a bare tensor
    # for a single-input model and cannot be counted with len()
    X = [inputX] * len(model.inputs)

    # encode output data
    # NOTE(review): without num_classes, to_categorical infers the width from
    # the largest label present — confirm inputy contains every class
    inputy_enc = to_categorical(inputy)

    # fit model
    model.fit(X, inputy_enc, epochs=epochs, verbose=verbose)

In [3]:
# make a prediction with a stacked model
def predict_stacked_model(model, inputX):
    """Predict with a stacked model, feeding the same features to every input head.

    Args:
        model: multi-input Keras model (as built by ``define_stacked_model``).
        inputX: feature data; the same object is passed to each input.

    Returns:
        Whatever ``model.predict`` returns for the replicated input list.
    """
    # prepare input data: one reference per input head.
    # `model.inputs` is always a list, whereas `model.input` is a bare tensor
    # for a single-input model and cannot be counted with len()
    X = [inputX] * len(model.inputs)

    # make prediction
    return model.predict(X, verbose=0)

In [None]:
# load the saved member models
# NOTE(review): `l` is handed straight to load_all_models; what its values
# (including -1) identify is not visible here — confirm against that helper
l = [-1, 2, 3, 4, 5]
st = load_all_models(l)
print('Loaded %d models' % len(st))

# build the stacked ensemble on top of the loaded members
stacked_model = define_stacked_model(st)

# train the ensemble
# NOTE(review): the model is fitted on X_test/Y_test and scored on the same
# data below, so the reported accuracy is likely optimistic — consider a
# separate hold-out split
fit_stacked_model(stacked_model, X_test, Y_test)

# predict, take the argmax along axis 1, and score against Y_test
yhat = argmax(predict_stacked_model(stacked_model, X_test), axis=1)
acc = accuracy_score(Y_test, yhat)
print('Stacked Test Accuracy: %.3f' % acc)



