# Multiple Inputs and Outputs
To explore multiple inputs and outputs in the Keras Functional API, let's try to predict how many retweets and likes a news headline will receive on Twitter.

The main input to the model is the headline itself, encoded as a sequence of integer word ids. The model also has an additional input that receives extra data about the headline, such as the time of day it was posted.
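For concreteness, here is a minimal NumPy sketch of what one batch of the two inputs could look like. The toy vocabulary, the `encode_headline` helper and the choice of extra features are hypothetical; only the shapes (100 word ids per headline, ids below 10,000, and 5 extra values per headline) come from the model defined below.

```python
import numpy as np

MAX_LEN = 100       # matches Input(shape=(100,)) for main_input
VOCAB_SIZE = 10000  # matches Embedding(input_dim=10000)

def encode_headline(text, vocab):
    """Hypothetical helper: turn a headline into MAX_LEN integer word ids.

    Id 0 is reserved for padding; words missing from `vocab` are dropped.
    Longer headlines keep their first MAX_LEN words; shorter ones are
    padded with zeros at the front.
    """
    ids = [vocab[w] for w in text.lower().split() if w in vocab][:MAX_LEN]
    return np.array([0] * (MAX_LEN - len(ids)) + ids, dtype="int32")

# Hypothetical toy vocabulary; a real one would be built from training headlines.
vocab = {"stocks": 1, "rally": 2, "after": 3, "rate": 4, "cut": 5}

headlines = np.stack([
    encode_headline("Stocks rally after rate cut", vocab),
    encode_headline("Rate cut", vocab),
])

# Hypothetical additional features: 5 floats per headline; only the
# posting hour (0-23), scaled to [0, 1], is filled in here.
extra = np.zeros((2, 5), dtype="float32")
extra[:, 0] = np.array([9, 18]) / 23.0

print(headlines.shape, extra.shape)  # (2, 100) (2, 5)
```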

In [1]:
import keras
from keras.layers import Input, Embedding, LSTM, Dense
from keras.models import Model
import numpy as np

np.random.seed(23)

# Note that we can name any layer by passing it a "name" argument.
main_input = Input(shape=(100,), dtype='int32', name='main_input')

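# This embedding layer will encode the input sequence into a sequence of dense 512-dimensional vectors.
# input_dim=10000 is the vocabulary size, so every word id must be in [0, 10000).
x = Embedding(output_dim=512, input_dim=10000, input_length=100)(main_input)

# The LSTM reads the 100 vectors in order and returns a single 32-dimensional vector.
lstm_out = LSTM(32)(x)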
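Under the hood, an embedding layer is a lookup table with one row per word id. The NumPy sketch below traces the shapes through that lookup, using random values in place of the learned weights (an assumption for illustration only): a `(12, 100)` batch of ids becomes `(12, 100, 512)`. `LSTM(32)` then reads those 100 vectors in order and, because it does not return sequences by default, outputs only its final state with shape `(12, 32)`.

```python
import numpy as np

rng = np.random.default_rng(23)
VOCAB_SIZE, EMBED_DIM, MAX_LEN, BATCH = 10000, 512, 100, 12

# Stand-in for the Embedding layer's weights: one 512-dim row per word id.
# Random values replace the learned weights, for illustration only.
embedding_table = rng.normal(scale=0.05, size=(VOCAB_SIZE, EMBED_DIM)).astype("float32")

# A batch of integer-encoded headlines, like the data fed to main_input.
tokens = rng.integers(0, VOCAB_SIZE, size=(BATCH, MAX_LEN))

# Integer-array indexing selects one row per id: (12, 100) -> (12, 100, 512).
embedded = embedding_table[tokens]
print(embedded.shape)  # (12, 100, 512)
```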


In [2]:
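# Let's define the additional output and input.
# The additional (auxiliary) output sits directly on the LSTM, which helps the
# LSTM and Embedding layers train smoothly even though the main loss is higher.
addtl_output = Dense(1, activation='sigmoid', name='addtl_output')(lstm_out)
addtl_input = Input(shape=(5,), name='addtl_input')

# Concatenate the 32-dim LSTM output with the 5 additional features.
x = keras.layers.concatenate([lstm_out, addtl_input])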

# Let's add a couple of fully connected hidden layers
x = Dense(64, activation='relu')(x)
x = Dense(64, activation='relu')(x)

# And finally we add the main logistic regression layer
main_output = Dense(1, activation='sigmoid', name='main_output')(x)
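As a sanity check on the architecture, the trainable parameters can be counted by hand. The sketch below assumes the standard Keras layer definitions (a `Dense` layer has a weight matrix plus a bias; an `LSTM` has four gates, each with input weights, recurrent weights and a bias). If those hold, the total should match what `model.summary()` reports for this model; the `Input` layers and the concatenation add no parameters.

```python
def dense_params(n_in, n_out):
    # weight matrix + one bias per output unit
    return n_in * n_out + n_out

def lstm_params(n_in, units):
    # four gates, each with input weights, recurrent weights and a bias
    return 4 * (units * (n_in + units) + units)

embedding = 10000 * 512                 # Embedding(input_dim=10000, output_dim=512)
lstm = lstm_params(512, 32)             # LSTM(32) on 512-dim vectors
addtl_out = dense_params(32, 1)         # addtl_output on the LSTM output
hidden1 = dense_params(32 + 5, 64)      # LSTM output concatenated with 5 features
hidden2 = dense_params(64, 64)
main_out = dense_params(64, 1)

total = embedding + lstm + addtl_out + hidden1 + hidden2 + main_out
print(total)  # 5196450
```

Almost all of the parameters sit in the embedding table, which is typical for text models with a large vocabulary.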

In [3]:
model = Model(inputs=[main_input, addtl_input], outputs=[main_output, addtl_output])
# Both outputs use binary cross-entropy; the additional output's loss is weighted by 0.2.
model.compile(optimizer='rmsprop', loss='binary_crossentropy',
              loss_weights=[1., 0.2])
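With `loss_weights=[1., 0.2]`, Keras minimizes a weighted sum of the per-output losses, matched to outputs in the order of `outputs=[main_output, addtl_output]`. The NumPy sketch below reproduces that arithmetic with a hand-written binary cross-entropy. The labels and predictions are made up, and the clipping epsilon of 1e-7 is an assumption chosen to mirror Keras' default fuzz factor.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    # Clip predictions so log() never sees exactly 0 or 1.
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

loss_weights = [1.0, 0.2]  # same order as outputs=[main_output, addtl_output]

# Made-up labels and sigmoid outputs for two samples.
y_main, p_main = np.array([1.0, 0.0]), np.array([0.9, 0.2])
y_aux, p_aux = np.array([1.0, 0.0]), np.array([0.5, 0.5])

main_loss = binary_crossentropy(y_main, p_main)
aux_loss = binary_crossentropy(y_aux, p_aux)

# The quantity the optimizer minimizes: a weighted sum of the two losses.
total_loss = loss_weights[0] * main_loss + loss_weights[1] * aux_loss
print(round(float(total_loss), 4))  # 0.3029
```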


Let's simulate some random data to test this out.

In [6]:
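# 12 headlines of 100 integer word ids, all below the vocabulary size (10000)
headline_data = np.random.randint(0, 10000, size=(12, 100))
# 5 additional features per headline
additional_data = np.random.randn(12, 5)
# Both outputs use a sigmoid with binary cross-entropy, so the labels must be 0 or 1
headline_labels = np.random.randint(0, 2, size=(12, 1))
additional_labels = np.random.randint(0, 2, size=(12, 1))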
model.fit([headline_data, additional_data], [headline_labels, additional_labels], epochs=10, batch_size=32)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x1d6434f1c48>

In [7]:
model.predict({'main_input': headline_data, 'addtl_input': additional_data})

[array([[9.9862969e-01],
        [5.1328790e-01],
        [5.2889287e-03],
        [2.6221490e-01],
        [0.0000000e+00],
        [2.9802322e-08],
        [0.0000000e+00],
        [0.0000000e+00],
        [6.7845100e-01],
        [0.0000000e+00],
        [2.0861626e-07],
        [0.0000000e+00]], dtype=float32), array([[0.32905042],
        [0.11721784],
        [0.11697972],
        [0.04524976],
        [0.00246167],
        [0.00297764],
        [0.00334498],
        [0.00152788],
        [0.22368217],
        [0.00263366],
        [0.0028528 ],
        [0.00243452]], dtype=float32)]
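`predict` returns one array per output, in the order given to `Model(outputs=[main_output, addtl_output])`, and each array has shape `(12, 1)`. The sketch below copies the first four rows of each array printed above and unpacks them. Reading the sigmoid scores as probabilities and thresholding them at 0.5 is an illustrative assumption, not part of the original model.

```python
import numpy as np

# First four rows of each output array printed above, copied for illustration.
predictions = [
    np.array([[9.9862969e-01], [5.1328790e-01], [5.2889287e-03], [2.6221490e-01]], dtype=np.float32),
    np.array([[0.32905042], [0.11721784], [0.11697972], [0.04524976]], dtype=np.float32),
]

# Unpack in the same order as outputs=[main_output, addtl_output].
main_pred, addtl_pred = predictions

# Assumed decision rule: a score of at least 0.5 counts as a positive prediction.
main_labels = (main_pred[:, 0] >= 0.5).astype(int)
print(main_labels)  # [1 1 0 0]
```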