In [1]:
import tensorflow as tf

In [6]:
def _bytes_feature(value):
    """Wrap one bytes value in a tf.train.Feature holding a BytesList.

    @value: a bytes/str value, or a tf.Tensor. A tensor is first converted
            with `.numpy()` so BytesList receives the raw value rather than
            the tensor object (this requires an eagerly evaluated tensor).
    @return: tf.train.Feature with a single-element bytes_list
    """
    raw = value.numpy() if isinstance(value, tf.Tensor) else value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))

def _int64_feature(value):
    """Wrap one bool / enum / int / uint value in a tf.train.Feature holding an Int64List.

    @value: a single integer-like value (wrapped in a one-element list)
    @return: tf.train.Feature with a single-element int64_list
    """
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

def parse_single_sample(fea, matrital, income):
    """Build a tf.train.Example for one census sample.

    Despite the name, this *builds* (serializes) an Example; it does not parse one.

    @fea: feature values (e.g. a list of floats). Serialized with
          tf.io.serialize_tensor and stored as bytes under key 'fea'.
    @matrital: integer label stored under key 'matrital'. The spelling is
               kept as-is because readers of the written records depend on it;
               presumably this means "marital" status -- confirm.
    @income: integer label stored under key 'income'
    @return: tf.train.Example
    """
    feature_map = {
        'fea': _bytes_feature(tf.io.serialize_tensor(fea)),
        'matrital': _int64_feature(matrital),
        'income': _int64_feature(income),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature_map))

"""
# save as tfrecord to disk, or use `tf.data.Dataset.from_tensor_slices` in memory
with tf.io.TFRecordWriter("data/census/tfrecords/train.tfrecords") as writer:
    with open("data/census/train_data.csv", "r") as f:
        for line in f.readlines():
            fields = line.strip().split(",")
            matrital, income, fea = int(fields[0]), int(fields[1]), [float(x) for x in fields[2:]]
            example = parse_single_sample(fea, matrital, income)
            writer.write(example.SerializeToString())
"""

'\n# save as tfrecord to disk, or use `tf.data.Dataset.from_tensor_slices` in memory\nwith tf.io.TFRecordWriter("data/census/tfrecords/train.tfrecords") as writer:\n    with open("data/census/train_data.csv", "r") as f:\n        for line in f.readlines():\n            fields = line.strip().split(",")\n            matrital, income, fea = int(fields[0]), int(fields[1]), [float(x) for x in fields[2:]]\n            example = parse_single_sample(fea, matrital, income)\n            writer.write(example.SerializeToString())\n'

In [32]:

class MMOEDense(object):
    """Multi-gate Mixture-of-Experts (MMoE) model with linear experts.

    Experts, gates and task towers are stored as stacked tf.Variables so that a
    single broadcasted tf.matmul evaluates all experts / tasks at once.
    No bias terms are used, and experts have no activation.
    """

    def __init__(self, nTasks, nExperts, inputDim, expertDim, hiddenDim):
        """
            @nTasks: number of tasks (one gate and one tower per task)
            @nExperts: number of shared experts
            @inputDim: size of the last axis of the input
            @expertDim: output size of each expert (= tower input size)
            @hiddenDim: hidden width of each task tower
        """
        self.nTasks = nTasks
        self.nExperts = nExperts
        self.inputDim = inputDim
        self.expertDim = expertDim
        self.hiddenDim = hiddenDim
        # experts: (nExperts, inputDim, expertDim)
        self.experts = self._truncatedNormalVariable((nExperts, inputDim, expertDim), "experts")
        # gates: (nTasks, inputDim, nExperts) -- softmax over the last axis gives per-task expert weights
        self.gates = self._truncatedNormalVariable((nTasks, inputDim, nExperts), "gates")
        # towers' hidden layer for each task: (nTasks, expertDim, hiddenDim)
        self.towers = self._truncatedNormalVariable((nTasks, expertDim, hiddenDim), "towers")
        # towers' output layer for each task: (nTasks, hiddenDim, 1)
        self.outs = self._truncatedNormalVariable((nTasks, hiddenDim, 1), "outs")

    @staticmethod
    def _truncatedNormalVariable(shape, name):
        """Create a float32 tf.Variable initialized from a truncated N(0, 1)."""
        # One fresh initializer per variable, as before: reusing a single
        # unseeded initializer instance may repeat values in some Keras versions.
        init = tf.initializers.truncated_normal(mean=0.0, stddev=1.0)
        return tf.Variable(init(shape=shape, dtype=tf.float32), name=name)

    @property
    def trainable_variables(self):
        """All model parameters, e.g. for tf.GradientTape / optimizer.apply_gradients."""
        return [self.experts, self.gates, self.towers, self.outs]

    def forward(self, input):
        """
            @input: (batch, inputDim) tensor or array-like; cast to float32
            @return: (batch, nTasks) tensor of per-task sigmoid outputs
        """
        x = tf.cast(tf.convert_to_tensor(input), tf.float32)
        # (batch, 1, 1, inputDim): axis 1 broadcasts over experts / tasks, axis 2 is the matmul row
        x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)
        # (batch, 1, 1, inputDim) X (nExperts, inputDim, expertDim) -> (batch, nExperts, 1, expertDim)
        #   -> transpose -> (batch, 1, nExperts, expertDim)
        expertsOut = tf.transpose(tf.matmul(x, self.experts), [0, 2, 1, 3])
        # (batch, 1, 1, inputDim) X (nTasks, inputDim, nExperts) -> (batch, nTasks, 1, nExperts),
        # softmax over the experts axis
        # TODO: mask for seq padding
        weights = tf.nn.softmax(tf.matmul(x, self.gates), axis=3)
        # (batch, nTasks, 1, nExperts) X (batch, 1, nExperts, expertDim) -> (batch, nTasks, 1, expertDim)
        towersIn = tf.matmul(weights, expertsOut)
        # (batch, nTasks, 1, expertDim) X (nTasks, expertDim, hiddenDim) -> (batch, nTasks, 1, hiddenDim)
        towersHidden = tf.nn.relu(tf.matmul(towersIn, self.towers))
        # (batch, nTasks, 1, hiddenDim) X (nTasks, hiddenDim, 1) -> (batch, nTasks, 1, 1) -> (batch, nTasks)
        return tf.squeeze(tf.nn.sigmoid(tf.matmul(towersHidden, self.outs)), axis=[2, 3])

# Smoke test: run the MMoE computation step by step on a small random batch.
SEED = 42
tf.random.set_seed(SEED)  # seed TensorFlow's global RNG so random ops here are repeatable

bs = 4
inputDim = 2
expertDim = 2
nExperts = 3
nTasks = 2
hiddenDim = 8

model = MMOEDense(nTasks, nExperts, inputDim, expertDim, hiddenDim)

# NOTE: `input` shadows the Python builtin; the name is kept because a later cell reads it.
input = tf.random.normal((bs, inputDim))
# (batch, 1, 1, inputDim), second dim is for expert broadcast, third dim is for matmul
input = tf.expand_dims(tf.expand_dims(input, axis=1), axis=1)

# (batch, 1, 1, inputDim) X (nExperts, inputDim, expertDim) -> (batch, nExperts, 1, expertDim) -> (batch, 1, nExperts, expertDim)
expertsOut = tf.transpose(tf.matmul(input, model.experts), [0, 2, 1, 3])

# (batch, 1, 1, inputDim) X (nTasks, inputDim, nExperts) -> (batch, nTasks, 1, nExperts), softmax over experts
# TODO: mask for seq padding
weights = tf.nn.softmax(tf.matmul(input, model.gates), 3)

# (batch, nTasks, 1, nExperts) X (batch, 1, nExperts, expertDim) -> (batch, nTasks, 1, expertDim)
towersIn = tf.matmul(weights, expertsOut)
# (batch, nTasks, 1, expertDim) X (nTasks, expertDim, hiddenDim) -> (batch, nTasks, 1, hiddenDim)
towersHidden = tf.nn.relu(tf.matmul(towersIn, model.towers))
# (batch, nTasks, 1, hiddenDim) X (nTasks, hiddenDim, 1) -> (batch, nTasks, 1, 1) -> (batch, nTasks)
outs = tf.squeeze(tf.nn.sigmoid(tf.matmul(towersHidden, model.outs)), axis=[2, 3])

assert outs.shape == (bs, nTasks), outs.shape
outs




tf.Tensor(
[[0.19165388 0.49798033]
 [0.47662297 0.49534288]
 [0.2825856  0.2111551 ]
 [0.21685195 0.25583968]], shape=(4, 2), dtype=float32)


In [14]:
# Raw (pre-softmax) gate logits: (batch, 1, 1, inputDim) X (nTasks, inputDim, nExperts) -> (batch, nTasks, 1, nExperts)
tf.linalg.matmul(a=input, b=model.gates)

<tf.Tensor: shape=(4, 2, 1, 3), dtype=float32, numpy=
array([[[[-0.26902193, -1.2600311 ,  1.825197  ]],

        [[ 0.44608107, -0.28010795, -2.5216727 ]]],


       [[[ 0.77840555, -0.40248802,  0.9354823 ]],

        [[ 1.9515262 ,  0.6383407 , -2.6946619 ]]],


       [[[ 0.39436162, -0.47272447,  0.88672954]],

        [[ 1.2039841 ,  0.3119706 , -2.0286016 ]]],


       [[[-0.30976683, -0.28421488,  0.31012106]],

        [[-0.42071053, -0.27292392, -0.02437009]]]], dtype=float32)>