In [None]:
# default_exp mtl_model.mmoe
%load_ext autoreload
%autoreload 2
import os
# CUDA_VISIBLE_DEVICES=-1 hides every CUDA device, so TensorFlow runs this notebook on CPU only.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# MMoE


In [None]:
# export
from typing import Dict, Tuple

import tensorflow as tf
from m3tl.base_params import BaseParams
from m3tl.mtl_model.base import MTLBase
from m3tl.utils import get_phase


class MMoE(MTLBase):
    """Multi-gate Mixture-of-Experts (MMoE) multi-task model.

    A single bank of ``num_experts`` linear experts is shared by all problems.
    Each problem owns its own softmax gate (a ``Dense`` layer) that produces
    per-token weights over the experts. A problem's hidden features are the
    gate-weighted sum of the expert outputs.

    Values read from ``params``:
        num_experts: number of shared experts (default 8).
        num_experts_units: output width of every expert (default 128).
        problem_list: problems that each get their own gate.
    """

    def __init__(self, params: BaseParams, name: str):
        super().__init__(params, name)
        self.num_experts = self.params.get('num_experts', 8)
        self.num_experts_units = self.params.get('num_experts_units', 128)
        self.problem_list = self.params.problem_list
        # One softmax gate per problem: maps each token's hidden vector to a
        # distribution over the shared experts.
        self.gate_dict = {
            problem: tf.keras.layers.Dense(self.num_experts, activation='softmax')
            for problem in self.problem_list
        }

    def build(self, input_shape):
        """Create the shared expert weights.

        Args:
            input_shape: ``(features_shape, hidden_features_shape)``; the
                hidden-feature shapes must contain ``['all']['seq']``.
        """
        _, hidden_feature_input_shape = input_shape
        # The kernel is contracted against the last axis of the 'seq' hidden
        # features in ``call``, so its input width must come from 'seq'.
        seq_shape = hidden_feature_input_shape['all']['seq']
        # [hidden_size, num_experts_units, num_experts]
        self.experts_kernel = self.add_weight(
            name='experts_kernel',
            shape=(seq_shape[-1], self.num_experts_units, self.num_experts)
        )
        # Leading singleton dims broadcast over [batch_size, seq_len].
        self.experts_bias = self.add_weight(
            name='experts_bias',
            shape=(1, 1, self.num_experts_units, self.num_experts),
            initializer='zeros'
        )

    def call(self, inputs: Tuple[Dict[str, tf.Tensor], Dict[str, Dict[str, tf.Tensor]]]):
        """Gate the shared experts separately for every problem.

        Args:
            inputs: ``(features, hidden_features)`` from the body model.

        Returns:
            ``(features_per_problem, hidden_features_per_problem)``, both keyed
            by problem name. Each hidden-feature entry has ``'seq'`` of shape
            [problem_batch_size, seq_len, num_experts_units] and ``'pooled'``
            (the first token of ``'seq'``).
        """
        features, hidden_features = inputs
        all_features, all_hidden_features = self.extract_feature(
            'all', feature_dict=features, hidden_feature_dict=hidden_features)

        # [batch_size, seq_len, hidden_size]
        seq_hidden = all_hidden_features['seq']
        # Run every expert once on the whole batch:
        # [batch_size, seq_len, num_experts_units, num_experts]
        experts_outputs = tf.tensordot(seq_hidden, self.experts_kernel, axes=[2, 0]) + self.experts_bias

        experts_output_dict = {
            'pooled': experts_outputs[:, 0, :, :],
            'seq': experts_outputs
        }

        features_per_problem, hidden_features_per_problem = {}, {}
        for problem, gate_net in self.gate_dict.items():
            # Extract the per-problem view of the already-computed expert
            # outputs rather than re-running the experts per problem
            # (extract_feature is defined in MTLBase; presumably it keeps
            # only this problem's examples).
            features_per_problem[problem], problem_experts_output = self.extract_feature(
                extract_problem=problem, feature_dict=all_features, hidden_feature_dict=experts_output_dict
            )
            _, problem_hidden_features = self.extract_feature(
                extract_problem=problem, feature_dict=all_features, hidden_feature_dict=all_hidden_features
            )

            # [problem_batch_size, seq_len, num_experts]
            experts_weight = gate_net(problem_hidden_features['seq'])
            # [problem_batch_size, seq_len, 1, num_experts]; broadcasts over units
            experts_weight = tf.expand_dims(experts_weight, axis=2)
            # [problem_batch_size, seq_len, num_experts_units, num_experts]
            expert_output_per_problem = problem_experts_output['seq']

            # MMoE combines experts with a gate-weighted SUM (the softmax
            # weights already sum to 1); a mean would shrink the result by a
            # factor of num_experts.
            # [problem_batch_size, seq_len, num_experts_units]
            gated_experts_output = tf.reduce_sum(experts_weight * expert_output_per_problem, axis=-1)
            hidden_features_per_problem[problem] = {
                'pooled': gated_experts_output[:, 0, :],
                'seq': gated_experts_output
            }

        return features_per_problem, hidden_features_per_problem
        

In [None]:
# hide
from m3tl.test_base import TestBase
from m3tl.special_tokens import TRAIN, EVAL, PREDICT

# Shared test fixture: TestBase supplies the params (its log output shows the
# fake weibo_* problems being registered).
tb = TestBase()
# One MMoE instance, built from the fixture's params and reused by later cells.
mmoe = MMoE(name='test_mmoe', params=tb.params)


2021-06-06 00:08:22.007 | INFO     | m3tl.base_params:register_multiple_problems:476 - Adding new problem weibo_fake_ner, problem type: seq_tag
2021-06-06 00:08:22.008 | INFO     | m3tl.base_params:register_multiple_problems:476 - Adding new problem weibo_fake_multi_cls, problem type: multi_cls
2021-06-06 00:08:22.009 | INFO     | m3tl.base_params:register_multiple_problems:476 - Adding new problem weibo_fake_cls, problem type: cls
2021-06-06 00:08:22.009 | INFO     | m3tl.base_params:register_multiple_problems:476 - Adding new problem weibo_masklm, problem type: masklm
2021-06-06 00:08:22.009 | INFO     | m3tl.base_params:register_multiple_problems:476 - Adding new problem weibo_fake_regression, problem type: regression
2021-06-06 00:08:22.010 | INFO     | m3tl.base_params:register_multiple_problems:476 - Adding new problem weibo_fake_vector_fit, problem type: vector_fit
2021-06-06 00:08:22.011 | INFO     | m3tl.base_params:register_multiple_problems:476 - Adding new problem weibo_pre

In [None]:
# hide
# features, hidden_features = tb.get_one_batch_body_model_output()
# mmoe = MMoE(params=tb.params, name='test_mmoe')
# features_per_problem, hidden_features_per_problem = mmoe((features, hidden_features), TRAIN)

In [None]:
# hide
# Smoke-test the same MMoE instance under every phase, in TRAIN -> EVAL -> PREDICT order.
for _phase in (TRAIN, EVAL, PREDICT):
    tb.test_mtl_model(mtl_model=mmoe, include_top=False, mode=_phase)

2021-06-06 00:08:28.222 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmpqvmqxfcx/weibo_fake_cls_weibo_fake_ner_weibo_fake_regression_weibo_fake_vector_fit/train_00000.tfrecord
2021-06-06 00:08:28.384 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmpqvmqxfcx/weibo_fake_cls_weibo_fake_ner_weibo_fake_regression_weibo_fake_vector_fit/eval_00000.tfrecord
2021-06-06 00:08:28.439 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmpqvmqxfcx/weibo_fake_multi_cls/train_00000.tfrecord
2021-06-06 00:08:28.486 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmpqvmqxfcx/weibo_fake_multi_cls/eval_00000.tfrecord
2021-06-06 00:08:28.578 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmpqvmqxfcx/weibo_masklm/train_00000.tfrecord
2021-06-06 00:08:28.627 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmpqvmqxfcx/weibo_masklm/eval_00000.tfrecord
2021-06-06 00:08:28.692 | DEBUG    | m3tl.r

938, 0.19856867, 0.20290618, 0.1969807 ],
         [0.19851914, 0.20053253, 0.20000899, 0.20155519, 0.19938414],
         [0.19841829, 0.20379424, 0.19912775, 0.20094569, 0.19771394],
         [0.19755088, 0.20352331, 0.19960462, 0.20161344, 0.19770773],
         [0.19965973, 0.20355953, 0.19870031, 0.20076627, 0.19731414],
         [0.19917299, 0.20364495, 0.19787578, 0.20028768, 0.19901864],
         [0.19916093, 0.2036096 , 0.19698974, 0.20146157, 0.19877811],
         [0.20121635, 0.2021869 , 0.1968712 , 0.20115578, 0.19856976],
         [0.19932503, 0.20335594, 0.19703469, 0.20196989, 0.19831455],
         [0.20464312, 0.19796327, 0.1953807 , 0.20584385, 0.19616906],
         [0.20757926, 0.1974735 , 0.19958356, 0.2029948 , 0.19236887],
         [0.20556863, 0.19863138, 0.19653077, 0.20520416, 0.19406508],
         [0.20250417, 0.2012576 , 0.19577971, 0.20254788, 0.19791062],
         [0.20522092, 0.19764833, 0.19517438, 0.20537657, 0.19657977]],
 
        [[0.19836774, 0.20244278