In [1]:
import tensorflow as tf
from tensorflow import feature_column as fc
from tensorflow.keras.layers import Layer, Dense, LayerNormalization, Dropout, Embedding, Conv1D, InputSpec
from tensorflow.keras import initializers, regularizers, constraints, activations

## 0. 准备工作 (Preparation)

In [2]:
# Toy feature definitions used to exercise the layers below.
# One numeric feature plus two id features that are hashed into 10 buckets
# and then mapped to 8-dimensional embeddings.
nums = fc.numeric_column('nums', dtype=tf.float32)
seq = fc.categorical_column_with_hash_bucket('seq', hash_bucket_size=10, dtype=tf.int64)
target = fc.categorical_column_with_hash_bucket('target', hash_bucket_size=10, dtype=tf.int64)
seq_col = fc.embedding_column(seq, dimension=8)
target_col = fc.embedding_column(target, dimension=8)
columns = [seq_col, target_col, nums]
# A batch of 3 examples keyed by feature name.
features={
    # Up to 2 ids per example; the third example only has one id
    # (position [2, 1] is absent from the sparse indices).
    "seq": tf.sparse.SparseTensor(
        indices=[[0, 0], [0, 1], [1, 0], [1, 1], [2, 0]],
        values=[1100, 1101, 1102, 1101, 1103],
        dense_shape=[3, 2]),
    # Exactly one id per example.
    "target": tf.sparse.SparseTensor(
        indices=[[0, 0],[1,0],[2,0]],
        values=[1102,1103,1100],
        dense_shape=[3, 1]),
    # One float per example.
    "nums": tf.convert_to_tensor([0.1,0.2,0.3]) 

}
# Display the dense view of 'seq'; missing positions are shown as 0.
tf.sparse.to_dense(features['seq'])

<tf.Tensor: shape=(3, 2), dtype=int32, numpy=
array([[1100, 1101],
       [1102, 1101],
       [1103,    0]], dtype=int32)>

In [3]:
# Turn the feature dict into a single dense tensor using the feature columns
# declared above (two 8-dim embeddings + one numeric value per example).
input_layer = tf.keras.layers.DenseFeatures(columns, name='features_input_layer')
net = input_layer(features)
# NOTE(review): leftover experiment kept for reference; `sequence_inputs` is not
# defined in the visible cells -> tf.concat(sequence_inputs.values(), axis=-1)
# Display the resulting tensor.
net

<tf.Tensor: shape=(3, 17), dtype=float32, numpy=
array([[ 0.1       ,  0.00786236,  0.08885731,  0.42254514, -0.1863629 ,
         0.18143918, -0.3677284 , -0.03922845, -0.035368  , -0.06283351,
        -0.33125448,  0.07279188, -0.48076993,  0.06331951,  0.3147942 ,
        -0.38336986, -0.16305678],
       [ 0.2       ,  0.04921697, -0.04056343,  0.40612042,  0.1093993 ,
        -0.10839443, -0.13684794, -0.06545366,  0.28763458,  0.02592784,
         0.20621896, -0.07896478,  0.49330452,  0.16031346, -0.6737616 ,
         0.3979211 , -0.130782  ],
       [ 0.3       ,  0.2581591 , -0.4934947 , -0.02980364, -0.15312529,
        -0.1492082 , -0.11314184,  0.39354733,  0.07831115,  0.01220254,
        -0.0779298 , -0.6079622 ,  0.2142255 , -0.31685454,  0.0331269 ,
        -0.28162146, -0.3388205 ]], dtype=float32)>

## 1. MMoE Layer

In [7]:
class MMoELayer(Layer):
    """Multi-gate Mixture-of-Experts (MMoE) layer.

    Every expert is a stack of ``Dense`` layers applied to the same input.
    For each task a separate gate (a linear ``Dense`` projection followed by
    ``gate_activation``) produces one weight per expert, and the task output
    is the gate-weighted sum of the expert outputs.

    Args:
        experts_network: iterable of ints, the number of units of each Dense
            layer inside one expert (e.g. ``[32, 16]``). All experts share
            this architecture but not their weights.
        num_experts: number of experts (must be >= 1).
        num_tasks: number of tasks, i.e. number of gates / returned tensors.
        expert_activation: activation used by every expert Dense layer.
        gate_activation: activation that turns gate logits into expert
            weights (softmax by default).
        **kwargs: forwarded to ``tf.keras.layers.Layer``.

    Call arguments:
        inputs: tensor of rank >= 2 whose last axis holds the features.

    Returns:
        A list with ``num_tasks`` tensors. Each has the shape of ``inputs``
        with the last axis replaced by the width of the last expert layer.

    Raises:
        ValueError: if ``num_experts`` is smaller than 1.
    """

    def __init__(self,
                 experts_network,
                 num_experts,
                 num_tasks,
                 expert_activation='relu',
                 gate_activation='softmax',
                 **kwargs):
        # Initialise the Keras base class first so that its own attribute
        # setup cannot overwrite anything assigned below (notably input_spec).
        super(MMoELayer, self).__init__(**kwargs)

        if num_experts < 1:
            raise ValueError(
                'num_experts must be >= 1, got {}'.format(num_experts))

        self.experts_network = list(experts_network)
        self.num_experts = num_experts
        self.num_tasks = num_tasks
        self.expert_activation = activations.get(expert_activation)
        self.gate_activation = activations.get(gate_activation)
        # Keras parameter: reject inputs that have no feature axis.
        self.input_spec = InputSpec(min_ndim=2)

    @staticmethod
    def _make_dense(units, activation, name):
        """Create one Dense sub-layer with the initializer shared by the layer."""
        return Dense(
            units,
            activation=activation,
            use_bias=True,
            kernel_initializer=tf.keras.initializers.VarianceScaling(distribution='uniform'),
            name=name)

    def build(self, input_shape):
        """Create the expert towers and the per-task gates."""
        assert input_shape is not None and len(input_shape) >= 2

        # Expert towers: sub-layers are stored as attributes named
        # 'expert_<i>_dnn_<j>' so Keras tracks their weights.
        for expert_index in range(self.num_experts):
            for dnn_index, nodes in enumerate(self.experts_network):
                layer_name = 'expert_{}_dnn_{}'.format(expert_index, dnn_index)
                setattr(self, layer_name,
                        self._make_dense(nodes, self.expert_activation, layer_name))

        # Gates: one linear projection to `num_experts` logits per task.
        # The projection is deliberately left without an activation; the gate
        # activation is applied in call(). Putting a ReLU here would clamp
        # negative logits before the softmax and distort the expert weights.
        for task_index in range(self.num_tasks):
            gate_name = 'task_gate_{}'.format(task_index)
            setattr(self, gate_name,
                    self._make_dense(self.num_experts, None, gate_name))

        super(MMoELayer, self).build(input_shape)

    def call(self, inputs):
        """Run all experts and mix them once per task."""
        expert_outputs = []
        for expert_index in range(self.num_experts):
            hidden = inputs
            for dnn_index in range(len(self.experts_network)):
                hidden = getattr(self, 'expert_{}_dnn_{}'.format(expert_index, dnn_index))(hidden)
            expert_outputs.append(hidden)
        # (..., num_experts, expert_dim). Stacking on the second-to-last axis
        # equals axis=1 for rank-2 inputs and stays aligned with the gate
        # weights for higher-rank inputs.
        experts = tf.stack(expert_outputs, axis=-2)

        mmoe_outputs = []
        for task_index in range(self.num_tasks):
            # Compute the gate: (..., num_experts) weights over the experts.
            gate_logits = getattr(self, 'task_gate_{}'.format(task_index))(inputs)
            gate_weights = tf.expand_dims(self.gate_activation(gate_logits), axis=-1)
            # Weighted sum over the expert axis -> (..., expert_dim).
            mmoe_outputs.append(tf.reduce_sum(gate_weights * experts, axis=-2))

        return mmoe_outputs

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(MMoELayer, self).get_config()
        config.update({
            'experts_network': list(self.experts_network),
            'num_experts': self.num_experts,
            'num_tasks': self.num_tasks,
            'expert_activation': activations.serialize(self.expert_activation),
            'gate_activation': activations.serialize(self.gate_activation),
        })
        return config

In [8]:
# Instantiate the layer: 5 experts, each a 32 -> 16 unit Dense stack, mixed for 3 tasks.
mmoe = MMoELayer(experts_network=[32,16],num_experts=5,num_tasks=3)
mmoe

<__main__.MMoELayer at 0x7f300000e290>

In [9]:
# Apply the layer to the dense feature tensor; the result is a list with one tensor per task.
mmoe(net)

[<tf.Tensor: shape=(3, 16), dtype=float32, numpy=
 array([[0.03877088, 0.11113396, 0.0298532 , 0.02345212, 0.05685867,
         0.08997826, 0.03222549, 0.07594021, 0.06223984, 0.        ,
         0.03214021, 0.15673123, 0.06252179, 0.01419564, 0.03323525,
         0.12479531],
        [0.06881708, 0.14380828, 0.09435926, 0.06355689, 0.10793914,
         0.07909707, 0.06030644, 0.0463797 , 0.01762604, 0.07618837,
         0.02163212, 0.08080394, 0.10184368, 0.04214523, 0.07344208,
         0.03026609],
        [0.13196073, 0.0122222 , 0.17413385, 0.04110646, 0.09258364,
         0.03169873, 0.04262413, 0.07298442, 0.09293503, 0.11024777,
         0.11972588, 0.14457038, 0.0254414 , 0.02090997, 0.0270231 ,
         0.01111503]], dtype=float32)>,
 <tf.Tensor: shape=(3, 16), dtype=float32, numpy=
 array([[0.04048721, 0.11957923, 0.02920262, 0.02308273, 0.06261122,
         0.08966283, 0.03612928, 0.08911026, 0.07114395, 0.        ,
         0.02929593, 0.17600541, 0.07427946, 0.01417303, 