<a href="https://colab.research.google.com/github/Muzhi1920/awesome-models/blob/main/07-%E5%A4%9A%E7%9B%AE%E6%A0%87%E7%BB%93%E6%9E%84/02_PLE_Layer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
from tensorflow import feature_column as fc
from tensorflow.keras import initializers, regularizers, constraints, activations
from tensorflow.keras.layers import Layer, InputSpec, Dense

## 0. 准备工作 (Setup: feature columns and a toy input batch)

In [None]:
# Toy feature columns: one numeric scalar plus two hashed integer-id columns,
# each id column embedded into a small dense vector.
HASH_BUCKET_SIZE = 10
EMBEDDING_DIM = 8

nums = fc.numeric_column('nums', dtype=tf.float32)
seq = fc.categorical_column_with_hash_bucket('seq', hash_bucket_size=HASH_BUCKET_SIZE, dtype=tf.int64)
target = fc.categorical_column_with_hash_bucket('target', hash_bucket_size=HASH_BUCKET_SIZE, dtype=tf.int64)
seq_col, target_col = (fc.embedding_column(cat_col, dimension=EMBEDDING_DIM) for cat_col in (seq, target))
columns = [seq_col, target_col, nums]

# A batch of 3 examples: `seq` holds up to 2 ids per example, `target` exactly one.
seq_ids = tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0], [1, 1], [2, 0]],
    values=[1100, 1101, 1102, 1101, 1103],
    dense_shape=[3, 2])
target_ids = tf.sparse.SparseTensor(
    indices=[[0, 0], [1, 0], [2, 0]],
    values=[1102, 1103, 1100],
    dense_shape=[3, 1])

features = {
    'seq': seq_ids,
    'target': target_ids,
    'nums': tf.convert_to_tensor([0.1, 0.2, 0.3]),
}
tf.sparse.to_dense(features['seq'])

<tf.Tensor: shape=(3, 2), dtype=int32, numpy=
array([[1100, 1101],
       [1102, 1101],
       [1103,    0]], dtype=int32)>

In [None]:
# DenseFeatures turns the raw `features` dict into a single dense tensor of
# width 17 = 1 numeric value + 8-dim `seq` embedding + 8-dim `target` embedding.
input_layer = tf.keras.layers.DenseFeatures(columns, name='features_input_layer')
net = input_layer(features)
net

<tf.Tensor: shape=(3, 17), dtype=float32, numpy=
array([[ 0.1       ,  0.368285  ,  0.01953911,  0.329111  , -0.28448862,
         0.00381372, -0.1794903 , -0.06729466,  0.10163559, -0.36710843,
        -0.47717634,  0.2626422 ,  0.40373984,  0.48506886,  0.4386921 ,
         0.2912761 , -0.08068093],
       [ 0.2       ,  0.18373841,  0.3408668 ,  0.35096297,  0.10208823,
        -0.0336121 , -0.11717728,  0.24010646,  0.35366982, -0.17428383,
         0.18445659, -0.61224234,  0.32126307,  0.35839644, -0.2370223 ,
         0.05115302,  0.3337443 ],
       [ 0.3       ,  0.14535068, -0.12395273,  0.1549992 , -0.56996447,
         0.22735709, -0.59369177, -0.2510585 , -0.1254943 ,  0.09329653,
         0.0244523 , -0.5612847 ,  0.29976568, -0.03937777, -0.51279557,
         0.08720471, -0.23819971]], dtype=float32)>

## 1. PLE Layer (Progressive Layered Extraction)

In [None]:
class PLELayer(Layer):
    """Progressive Layered Extraction (PLE) multi-task layer.

    The layer stacks ``multi_level`` extraction levels. Each level owns:

    * per task, ``tower['spec_expert']`` task-specific experts;
    * ``num_share_experts`` experts shared by all tasks.

    Every expert is an MLP of Dense layers whose widths are ``experts_network``.
    Each task has a gate that mixes that task's specific experts with the
    shared experts. On every level except the last, an extra shared gate mixes
    *all* experts of the level; its result is the shared input of the next level.

    ``call`` returns a list with one tensor per task whose last dimension is
    ``experts_network[-1]``.
    """

    def __init__(self,
                 tower_task_specs,
                 experts_network,
                 num_share_experts,
                 num_tasks,
                 multi_level,
                 use_expert_bias=True,
                 use_gate_bias=True,
                 expert_activation='relu',
                 gate_activation='softmax',
                 expert_bias_initializer='zeros',
                 gate_bias_initializer='zeros',
                 expert_bias_regularizer=None,
                 gate_bias_regularizer=None,
                 expert_bias_constraint=None,
                 gate_bias_constraint=None,
                 expert_kernel_initializer='VarianceScaling',
                 gate_kernel_initializer='VarianceScaling',
                 expert_kernel_regularizer=None,
                 gate_kernel_regularizer=None,
                 expert_kernel_constraint=None,
                 gate_kernel_constraint=None,
                 activity_regularizer=None,
                 **kwargs):
        """
        Method for instantiating the PLE layer.

        :param tower_task_specs: One dict per task. Each must contain 'spec_expert', the number
                                 of task-specific experts for that task. Other keys (e.g. 'task')
                                 are kept but unused.
        :param experts_network: Units of each Dense layer inside an expert. The last entry is the
                                output width of every expert and of every task output.
        :param num_share_experts: Number of experts shared by all tasks on each level
        :param num_tasks: Number of tasks; must equal len(tower_task_specs)
        :param multi_level: Number of extraction levels in the PLE structure
        :param use_expert_bias: Boolean to indicate the usage of bias in the expert Dense layers
        :param use_gate_bias: Boolean to indicate the usage of bias in the gate Dense layers
        :param expert_activation: Activation of every Dense layer inside the experts
        :param gate_activation: Activation of the gate Dense layers (softmax by default)
        :param expert_bias_initializer: Initializer for the expert bias
        :param gate_bias_initializer: Initializer for the gate bias
        :param expert_bias_regularizer: Regularizer for the expert bias
        :param gate_bias_regularizer: Regularizer for the gate bias
        :param expert_bias_constraint: Constraint for the expert bias
        :param gate_bias_constraint: Constraint for the gate bias
        :param expert_kernel_initializer: Initializer for the expert weights
        :param gate_kernel_initializer: Initializer for the gate weights
        :param expert_kernel_regularizer: Regularizer for the expert weights
        :param gate_kernel_regularizer: Regularizer for the gate weights
        :param expert_kernel_constraint: Constraint for the expert weights
        :param gate_kernel_constraint: Constraint for the gate weights
        :param activity_regularizer: Regularizer applied to the layer outputs
        :param kwargs: Additional keyword arguments for the Layer class
        """
        assert experts_network is not None and len(experts_network) > 0
        assert num_share_experts is not None and num_share_experts > 0
        assert num_tasks is not None and num_tasks > 0
        assert multi_level is not None and multi_level > 0
        assert tower_task_specs is not None and len(tower_task_specs) == num_tasks, \
            'tower_task_specs must contain exactly one spec per task'

        # Initialize the base Layer first. Keras' Layer.__init__ (re)initializes
        # activity_regularizer, input_spec and supports_masking, so values assigned
        # before it would be silently discarded.
        super(PLELayer, self).__init__(**kwargs)

        # Hidden nodes parameter
        self.multi_level = multi_level
        self.num_tasks = num_tasks
        self.tower_task_specs = tower_task_specs
        self.num_share_experts = num_share_experts
        self.experts_network = experts_network

        # Weight parameter
        self.expert_kernel_initializer = initializers.get(expert_kernel_initializer)
        self.gate_kernel_initializer = initializers.get(gate_kernel_initializer)
        self.expert_kernel_regularizer = regularizers.get(expert_kernel_regularizer)
        self.gate_kernel_regularizer = regularizers.get(gate_kernel_regularizer)
        self.expert_kernel_constraint = constraints.get(expert_kernel_constraint)
        self.gate_kernel_constraint = constraints.get(gate_kernel_constraint)

        # Activation parameter
        self.expert_activation = activations.get(expert_activation)
        self.gate_activation = activations.get(gate_activation)

        # Bias parameter
        self.use_expert_bias = use_expert_bias
        self.use_gate_bias = use_gate_bias
        self.expert_bias_initializer = initializers.get(expert_bias_initializer)
        self.gate_bias_initializer = initializers.get(gate_bias_initializer)
        self.expert_bias_regularizer = regularizers.get(expert_bias_regularizer)
        self.gate_bias_regularizer = regularizers.get(gate_bias_regularizer)
        self.expert_bias_constraint = constraints.get(expert_bias_constraint)
        self.gate_bias_constraint = constraints.get(gate_bias_constraint)

        # Activity parameter
        self.activity_regularizer = regularizers.get(activity_regularizer)

        # Keras parameter
        self.input_spec = InputSpec(min_ndim=2)
        self.supports_masking = True

    @staticmethod
    def _fresh(initializer):
        """Return a new, independent instance of a Keras initializer (or None)."""
        if initializer is None:
            return None
        # Recent Keras versions return identical values when one unseeded initializer
        # instance is reused across layers, and warn about it. Cloning gives every
        # expert/gate its own instance so experts do not start out identical.
        return initializers.get(initializers.serialize(initializer))

    def _expert_dense(self, units, name):
        """Create one Dense layer of an expert MLP from the expert_* settings."""
        return Dense(units,
                     activation=self.expert_activation,
                     use_bias=self.use_expert_bias,
                     kernel_initializer=self._fresh(self.expert_kernel_initializer),
                     kernel_regularizer=self.expert_kernel_regularizer,
                     kernel_constraint=self.expert_kernel_constraint,
                     bias_initializer=self._fresh(self.expert_bias_initializer),
                     bias_regularizer=self.expert_bias_regularizer,
                     bias_constraint=self.expert_bias_constraint,
                     name=name)

    def _gate_dense(self, units, name):
        """Create one gate Dense layer (one output per gated expert) from the gate_* settings."""
        return Dense(units,
                     activation=self.gate_activation,
                     use_bias=self.use_gate_bias,
                     kernel_initializer=self._fresh(self.gate_kernel_initializer),
                     kernel_regularizer=self.gate_kernel_regularizer,
                     kernel_constraint=self.gate_kernel_constraint,
                     bias_initializer=self._fresh(self.gate_bias_initializer),
                     bias_regularizer=self.gate_bias_regularizer,
                     bias_constraint=self.gate_bias_constraint,
                     name=name)

    def _task_expert_slices(self):
        """Return, per task, the [start, stop) range of its specific experts in a level's flat expert list."""
        slices = []
        start = 0
        for tower in self.tower_task_specs:
            stop = start + tower['spec_expert']
            slices.append((start, stop))
            start = stop
        return slices

    @staticmethod
    def _select_input(level, cur_inputs, position):
        """Level 0 feeds the raw input everywhere; deeper levels use the previous level's output at `position`."""
        return cur_inputs if level == 0 else cur_inputs[position]

    def build(self, input_shape):
        """
        Method for creating the layer's expert and gate sub-layers.

        :param input_shape: Keras tensor (future input to layer)
                            or list/tuple of Keras tensors to reference
                            for weight shape computations
        """
        assert input_shape is not None and len(input_shape) >= 2

        input_dimension = input_shape[-1]
        total_spec_experts = sum(tower['spec_expert'] for tower in self.tower_task_specs)

        # Expert MLPs for every level: task-specific experts first, then shared experts.
        for level in range(self.multi_level):
            for k, tower in enumerate(self.tower_task_specs):
                for s in range(tower['spec_expert']):
                    for index, num_node in enumerate(self.experts_network):
                        name = 'multi_extraction_level_{}_task_{}_spec_{}_index_{}'.format(level, k, s, index)
                        setattr(self, name, self._expert_dense(num_node, name))
            for sh in range(self.num_share_experts):
                for index, num_node in enumerate(self.experts_network):
                    name = 'multi_extraction_level_{}_share_{}_index_{}'.format(level, sh, index)
                    setattr(self, name, self._expert_dense(num_node, name))

        # Gates for every level. Each task gate weighs (its spec experts + shared experts).
        # Every level but the last also has a shared gate over all experts of the level.
        for level in range(self.multi_level):
            for k, tower in enumerate(self.tower_task_specs):
                name = 'gate_level_{}_task_{}'.format(level, k)
                setattr(self, name, self._gate_dense(tower['spec_expert'] + self.num_share_experts, name))
            if level != self.multi_level - 1:
                name = 'gate_level_{}_share'.format(level)
                setattr(self, name, self._gate_dense(total_spec_experts + self.num_share_experts, name))

        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dimension})

        super(PLELayer, self).build(input_shape)

    def call(self, inputs):
        """
        Method for the forward function of the layer.

        :param inputs: Input tensor
        :return: List of num_tasks tensors, each with last dimension experts_network[-1]
        """
        level_inputs = inputs
        for level in range(self.multi_level):
            experts_outputs = self.multi_level_kernel_layer(level, level_inputs)
            grouped_experts = self.group_multi_extraction_expert(level, experts_outputs)
            gate_outputs = self.multi_level_extraction_gate(level, level_inputs)
            # This level's outputs become the next level's inputs (tasks first, shared last).
            level_inputs = self.multi_level_extraction_output_emb(level, gate_outputs, grouped_experts)
        return level_inputs

    def compute_output_shape(self, input_shape):
        """
        Method for computing the output shape of the PLE layer.

        :param input_shape: Shape tuple (tuple of integers)
        :return: List of TensorShape, one per task, equal to input_shape with the
                 last dimension replaced by experts_network[-1]
        """
        assert input_shape is not None and len(input_shape) >= 2

        output_shape = list(input_shape)
        output_shape[-1] = self.experts_network[-1]

        return [tf.TensorShape(output_shape) for _ in range(self.num_tasks)]

    def get_config(self):
        """
        Method for returning the configuration of the PLE layer.

        The keys mirror the constructor arguments so that from_config can rebuild the layer.

        :return: Config dictionary
        """
        cur_config = {
            # Convert Keras' tracking wrappers back to plain containers for serialization.
            'tower_task_specs': [dict(spec) for spec in self.tower_task_specs],
            'experts_network': list(self.experts_network),
            'num_share_experts': self.num_share_experts,
            'num_tasks': self.num_tasks,
            'multi_level': self.multi_level,
            'use_expert_bias': self.use_expert_bias,
            'use_gate_bias': self.use_gate_bias,
            'expert_activation': activations.serialize(self.expert_activation),
            'gate_activation': activations.serialize(self.gate_activation),
            'expert_bias_initializer': initializers.serialize(self.expert_bias_initializer),
            'gate_bias_initializer': initializers.serialize(self.gate_bias_initializer),
            'expert_bias_regularizer': regularizers.serialize(self.expert_bias_regularizer),
            'gate_bias_regularizer': regularizers.serialize(self.gate_bias_regularizer),
            'expert_bias_constraint': constraints.serialize(self.expert_bias_constraint),
            'gate_bias_constraint': constraints.serialize(self.gate_bias_constraint),
            'expert_kernel_initializer': initializers.serialize(self.expert_kernel_initializer),
            'gate_kernel_initializer': initializers.serialize(self.gate_kernel_initializer),
            'expert_kernel_regularizer': regularizers.serialize(self.expert_kernel_regularizer),
            'gate_kernel_regularizer': regularizers.serialize(self.gate_kernel_regularizer),
            'expert_kernel_constraint': constraints.serialize(self.expert_kernel_constraint),
            'gate_kernel_constraint': constraints.serialize(self.gate_kernel_constraint),
            'activity_regularizer': regularizers.serialize(self.activity_regularizer)
        }
        config = super(PLELayer, self).get_config()
        config.update(cur_config)
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

    def multi_level_kernel_layer(self, level, cur_inputs):
        """
        Run every expert MLP of one level.

        :param level: Level index
        :param cur_inputs: Raw input tensor (level 0) or the previous level's output list
                           (task outputs first, shared output last)
        :return: Flat list of expert outputs, each with a trailing axis of size 1. Task-specific
                 experts come first in task order, followed by the shared experts.
        """
        experts_outputs = []
        for k, tower in enumerate(self.tower_task_specs):
            task_input = self._select_input(level, cur_inputs, k)
            for s in range(tower['spec_expert']):
                hidden = task_input
                for index in range(len(self.experts_network)):
                    hidden = getattr(self, 'multi_extraction_level_{}_task_{}_spec_{}_index_{}'.format(level, k, s, index))(hidden)
                experts_outputs.append(tf.expand_dims(hidden, axis=-1))

        share_input = self._select_input(level, cur_inputs, -1)
        for sh in range(self.num_share_experts):
            hidden = share_input
            for index in range(len(self.experts_network)):
                hidden = getattr(self, 'multi_extraction_level_{}_share_{}_index_{}'.format(level, sh, index))(hidden)
            experts_outputs.append(tf.expand_dims(hidden, axis=-1))
        return experts_outputs

    def group_multi_extraction_expert(self, level, expert_output_list):
        """
        Concatenate expert outputs into the groups that each gate weighs.

        :param level: Level index
        :param expert_output_list: Output of multi_level_kernel_layer for this level
        :return: One tensor per task holding (its spec experts + shared experts) stacked on the
                 last axis. On every level but the last, a further tensor with all experts is appended.
        """
        multi_level_extraction_outputs = []
        share_expert_list = expert_output_list[-self.num_share_experts:]
        # 1. one group per task: its own specific experts followed by the shared experts
        for index, (start, stop) in enumerate(self._task_expert_slices()):
            task_expert = list(expert_output_list[start:stop]) + list(share_expert_list)
            task_expert = tf.concat(values=task_expert, axis=-1, name='extraction_level_{}_concat_{}'.format(level, index))
            multi_level_extraction_outputs.append(task_expert)
        # the last level only produces task groups; no shared group is needed
        if level == self.multi_level - 1:
            return multi_level_extraction_outputs

        # 2. one shared group with every expert of the level
        share_experts = tf.concat(values=expert_output_list, axis=-1, name='multi_level_extraction_share_concat_{}'.format(level))
        multi_level_extraction_outputs.append(share_experts)
        return multi_level_extraction_outputs

    def multi_level_extraction_gate(self, level, cur_inputs):
        """
        Compute the gate weights of one level.

        :param level: Level index
        :param cur_inputs: Raw input tensor (level 0) or the previous level's output list
        :return: One gate tensor per task. On every level but the last, the shared gate's
                 tensor is appended.
        """
        gate_outputs = []
        # 1. task gates
        for k in range(len(self.tower_task_specs)):
            task_input = self._select_input(level, cur_inputs, k)
            gate_outputs.append(getattr(self, 'gate_level_{}_task_{}'.format(level, k))(task_input))
        # 2. shared gate, except on the last level
        if level != self.multi_level - 1:
            share_input = self._select_input(level, cur_inputs, -1)
            gate_outputs.append(getattr(self, 'gate_level_{}_share'.format(level))(share_input))
        return gate_outputs

    def multi_level_extraction_output_emb(self, level, input_extraction_gate_list, input_extraction_expert_list):
        """
        Weight each expert group by its gate and sum over the experts.

        :param level: Level index (unused; kept for interface compatibility)
        :param input_extraction_gate_list: Gate tensors of shape (..., n_experts)
        :param input_extraction_expert_list: Expert groups of shape (..., units, n_experts)
        :return: List of tensors of shape (..., units), aligned with the input lists
        """
        outputs = []
        for gate_output, task_expert in zip(input_extraction_gate_list, input_extraction_expert_list):
            # Insert a `units` axis into the gate so it broadcasts over every unit of the experts.
            gate_weights = tf.expand_dims(gate_output, axis=-2)
            outputs.append(tf.reduce_sum(task_expert * gate_weights, axis=-1))
        return outputs


In [None]:
# Three tasks (A, B, C), each with one task-specific expert, plus two experts shared by all tasks.
tower_task_specs = [{'task': task_name, 'spec_expert': 1} for task_name in ('A', 'B', 'C')]

# Every expert is a 2-layer MLP; its last width (64) is also the width of each task output.
experts_network = [64, 64]

num_share_experts = 2

# A single extraction level, so no intermediate shared gate is built.
multi_level = 1

ple_layer = PLELayer(
    tower_task_specs=tower_task_specs,
    experts_network=experts_network,
    num_share_experts=num_share_experts,
    multi_level=multi_level,
    num_tasks=len(tower_task_specs),
)
ple_layer

<__main__.PLELayer at 0x7fcd2ef15190>

In [None]:
# Forward the dense feature batch through PLE: the result is a list with one tensor per task.
task_outputs = ple_layer(net)
task_outputs

[<tf.Tensor: shape=(3, 64), dtype=float32, numpy=
 array([[0.00633524, 0.25465465, 0.05535328, 0.02153273, 0.12176438,
         0.02236252, 0.07568981, 0.24306658, 0.1214754 , 0.25624943,
         0.05656353, 0.12924038, 0.08294437, 0.08587804, 0.20312986,
         0.07619958, 0.02773911, 0.09143917, 0.16418178, 0.18385105,
         0.1264493 , 0.14429104, 0.03289412, 0.10242222, 0.04664102,
         0.04787858, 0.        , 0.09632555, 0.10831477, 0.078369  ,
         0.04134453, 0.08651172, 0.10960472, 0.0883157 , 0.09873402,
         0.05162107, 0.02600574, 0.        , 0.0045758 , 0.16982089,
         0.09039186, 0.14455765, 0.13831814, 0.03501253, 0.0490548 ,
         0.06739961, 0.10657761, 0.14929184, 0.09703999, 0.07625906,
         0.04855544, 0.        , 0.10651408, 0.08478894, 0.08451825,
         0.20067537, 0.        , 0.05383517, 0.00217588, 0.        ,
         0.00964267, 0.        , 0.        , 0.24365865],
        [0.1083691 , 0.21220192, 0.02622489, 0.08850792, 0.07212