# 纯工程 - 多输入多输出

该例子只是一个纯工程示例：样本数据全部随机生成，仅用于演示多输入多输出模型的`可能`写法，训练结果本身不具有实际意义。

## 1. 多输入多输出神经网络结构图

<img src="./images/muti-input-output.png" />

<img src="./images/muti-input-output-2.png" />

## 2. 导入包

In [None]:
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd

# Short module-level handles for the tf.keras sub-packages and classes that
# the notebook's code cells refer to.
keras = tf.keras

Input = keras.Input
TensorBoard = keras.callbacks.TensorBoard
activations = keras.activations
layers = keras.layers
losses = keras.losses
models = keras.models
optimizers = keras.optimizers
utils = keras.utils

# Vocabulary sizes: one for the text-snippet input, one for the question input.
text_vocabulary_size = question_vocabulary_size = 1000

# Number of samples to generate, and the length of each input sequence.
num_samples = 1000
maxlen = 50

# Number of income classes.
num_income_group = 10

## 3. 准备多输入层

In [None]:
# Text branch: int32 token ids -> 64-dimensional embedding -> 32-unit LSTM.
text_input = Input(name='texts_input', dtype='int32', shape=(None,))
text_embedding_layer = layers.Embedding(
    input_dim=text_vocabulary_size,
    output_dim=64,
    name="Embedding_for_texts",
)
embedded_text = text_embedding_layer(text_input)
encoded_text = layers.LSTM(units=32, name="LSTM_for_text")(embedded_text)

# Question branch: int32 token ids -> 64-dimensional embedding -> 64-unit LSTM.
question_input = Input(name='questions_input', dtype='int32', shape=(None,))
question_embedding_layer = layers.Embedding(
    input_dim=question_vocabulary_size,
    output_dim=64,
    name="Embedding_for_questions",
)
embedded_question = question_embedding_layer(question_input)
encoded_question = layers.LSTM(units=64, name="LSTM_for_questions")(embedded_question)

# Join the encoded text and the encoded question along the last axis.
concatenated = layers.Concatenate(axis=-1)([encoded_text, encoded_question])

## 4. 准备多输出层

In [None]:
# Binary classification head: one sigmoid unit on the concatenated features.
gender_layer = layers.Dense(units=1, activation=activations.sigmoid, name='gender')
gender_prediction = gender_layer(concatenated)

# Multi-class head: a 128-unit ReLU hidden layer followed by a softmax over
# the num_income_group classes.
income_hidden_layer = layers.Dense(
    units=128, activation=activations.relu, name='income_hidden')
income_hidden_prediction = income_hidden_layer(concatenated)

income_layer = layers.Dense(
    units=num_income_group, activation=activations.softmax, name='income')
income_prediction = income_layer(income_hidden_prediction)

# Regression head: a single unit with no activation argument.
age_layer = layers.Dense(units=1, name='age')
age_prediction = age_layer(concatenated)

## 5. 构建模型

In [None]:
# Assemble the functional model: two inputs feeding three outputs.
# The outputs are listed in the order age, income, gender.
model_inputs = [text_input, question_input]
model_outputs = [age_prediction, income_prediction, gender_prediction]
model = models.Model(inputs=model_inputs, outputs=model_outputs)

# Print the layer-by-layer overview of the assembled model.
model.summary()

## 6. 编译模型

In [None]:
# Every output reports the same three metrics, so that the history contains
# '<output>_accuracy', '<output>_mae' and '<output>_mse' for each head.
# They are passed as a dict keyed by output-layer name rather than as one flat
# list: the flat list had exactly as many entries as the model has outputs,
# and the Keras docs also allow a list to mean "one entry per output", which
# makes the flat form ambiguous for a multi-output model.
per_output_metrics = ['accuracy', 'mae', 'mse']

model.compile(
    optimizer=optimizers.Adam(),  # (comment kept only for the auto-formatter)
    # One loss per output, keyed by output-layer name.
    loss={
        'age': losses.mse,
        'income': losses.categorical_crossentropy,
        'gender': losses.binary_crossentropy
    },
    # Contribution of each output's loss to the total loss; see
    # https://tensorflow.google.cn/api_docs/python/tf/keras/Model#compile
    loss_weights={
        'age': 0.25,
        'income': 1.,
        'gender': 10.
    },
    metrics={
        'age': list(per_output_metrics),
        'income': list(per_output_metrics),
        'gender': list(per_output_metrics)
    }
)

## 7. 准备模拟数据

In [None]:
# Build random stand-in data for training.

# Inputs (x_train): integer ids drawn from [1, vocabulary size), one row of
# maxlen ids per sample.
sequence_shape = (num_samples, maxlen)
texts = np.random.randint(low=1, high=text_vocabulary_size, size=sequence_shape)
questions = np.random.randint(low=1, high=question_vocabulary_size, size=sequence_shape)

# Targets (y_train):
# 1. age is a regression target: integers in [16, 40), shape (num_samples,).
age_targets = np.random.randint(low=16, high=40, size=num_samples)

# 2. income is a multi-class target. Of the two usual ways to vectorise such
#    labels, one-hot encoding is used here, giving shape
#    (num_samples, num_income_group).
income_class_ids = np.random.randint(low=0, high=num_income_group, size=num_samples)
income_targets = utils.to_categorical(income_class_ids, num_income_group)

# 3. gender is a binary target: 0 or 1, shape (num_samples,).
gender_targets = np.random.randint(low=0, high=2, size=num_samples)

## 8. 训练模型

In [None]:
# Directory handed to the TensorBoard callback.
log_dir = '../results/muti-input-output-layers/'

# Inputs and targets are dicts keyed by the names given to the Input layers
# and to the output layers respectively.
train_inputs = {
    'texts_input': texts,
    'questions_input': questions,
}
train_targets = {
    'age': age_targets,
    'income': income_targets,
    'gender': gender_targets,
}

history = model.fit(
    x=train_inputs,
    y=train_targets,
    batch_size=64,
    epochs=10,
    validation_split=0.2,
    callbacks=[TensorBoard(log_dir=log_dir)],
)

## 9. 使用matplotlib显示结果

### 9.1 整体损失

In [None]:
    # Put the recorded training history into a DataFrame and add the epoch
    # numbers as a column so they can serve as the x axis.
    hist = pd.DataFrame(history.history)
    hist['epoch'] = history.epoch

    # Overall loss: training curve first, then the validation curve.
    plt.figure(figsize=(15, 2))
    for curve in ('loss', 'val_loss'):
        plt.plot(hist['epoch'], hist[curve], label=curve)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

### 9.2 年龄推测（回归问题）

回归问题不存在精度（accuracy）的概念，因此第一张图中的精度总是0，没有参考意义；MAE 和 MSE 才是回归问题的主要参考指标。

In [None]:
# ------------------------------------- Age -------------------------------------
# One figure per recorded quantity of the 'age' output. Each figure overlays
# the training curve with its 'val_'-prefixed validation counterpart.
# The statements start at column 0, like the header comment: indented
# statements after a column-0 comment are an IndentationError.
age_panels = [
    ('age_accuracy', 'Age Accuracy'),
    ('age_loss', 'Age Loss'),
    ('age_mae', 'Age Mean Absolute Error'),
    ('age_mse', 'Age Mean Squared Error'),
]

for metric, ylabel in age_panels:
    plt.figure(figsize=(15, 2))
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.plot(hist['epoch'], hist[metric], label=metric)
    plt.plot(hist['epoch'], hist['val_' + metric], label='val_' + metric)
    plt.legend()

plt.show()

### 9.3 收入区间推测（多分类问题）

因为标签是随机数，并且总类别数为10，所以这种情况下 acc 约等于10%；MAE 和 MSE 不具有参考价值。

In [None]:
   
    # ------------------------------------- income -------------------------------------
    # One figure per recorded quantity of the 'income' output; each overlays
    # the training curve with its 'val_'-prefixed validation counterpart.
    income_panels = (
        ('income_accuracy', 'Income accuracy'),
        ('income_loss', 'Income Loss'),
        ('income_mae', 'Income Mean Absolute Error'),
        ('income_mse', 'Income Mean Squared Error'),
    )

    for column, axis_title in income_panels:
        val_column = 'val_' + column
        plt.figure(figsize=(15, 2))
        plt.xlabel('Epoch')
        plt.ylabel(axis_title)
        plt.plot(hist['epoch'], hist[column], label=column)
        plt.plot(hist['epoch'], hist[val_column], label=val_column)
        plt.legend()

    plt.show()

### 9.4 性别推测（二分类问题）

标签为随机数时，二分类问题的 acc 约等于50%；MAE 和 MSE 不具有参考价值。

In [None]:
  # ------------------------------------- gender -------------------------------------
  # One figure per recorded quantity of the 'gender' output; each overlays
  # the training curve with its 'val_'-prefixed validation counterpart.
  # Every line of this cell shares the same leading indent as this header, so
  # the cell dedents cleanly as a unit.
  gender_panels = [
      ('gender_accuracy', 'Gender Accuracy'),
      ('gender_loss', 'Gender Loss'),
      ('gender_mae', 'Gender Mean Absolute Error'),
      ('gender_mse', 'Gender Mean Squared Error'),
  ]

  for metric, ylabel in gender_panels:
      plt.figure(figsize=(15, 2))
      plt.xlabel('Epoch')
      plt.ylabel(ylabel)
      plt.plot(hist['epoch'], hist[metric], label=metric)
      plt.plot(hist['epoch'], hist['val_' + metric], label='val_' + metric)
      plt.legend()

  plt.show()