In [9]:
import tensorflow as tf

In [6]:
def _bytes_feature(value):
    """Wrap a single bytes-like value in a ``tf.train.Feature``.

    Args:
        value: ``bytes``, ``str``, or an eager ``tf.Tensor`` whose ``.numpy()``
            yields bytes (e.g. the result of ``tf.io.serialize_tensor``).

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds exactly one entry.
    """
    if isinstance(value, tf.Tensor):
        # BytesList cannot unpack a tensor, so pull the raw value out first.
        value = value.numpy()
    if isinstance(value, str):
        # The protobuf bytes field rejects Python 3 text; store it as UTF-8.
        value = value.encode("utf-8")
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
  """Wrap one integer-like value (bool / enum / int / uint) in a ``tf.train.Feature``."""
  # Build the single-element Int64List first, then hand it to the Feature.
  int_list = tf.train.Int64List(value=[value])
  feature = tf.train.Feature(int64_list=int_list)
  return feature

def parse_single_sample(fea, matrital, income):
  """Pack one sample into a ``tf.train.Example`` (builds, does not parse).

  ``fea`` is serialized with ``tf.io.serialize_tensor`` and stored as a bytes
  feature under the key ``'fea'``; ``matrital`` and ``income`` are each stored
  as a single int64 feature under the key of the same name.

  Returns:
    The assembled ``tf.train.Example`` (not yet serialized to a string).
  """
  serialized_fea = tf.io.serialize_tensor(fea)
  features = tf.train.Features(feature={
      'fea': _bytes_feature(serialized_fea),
      'matrital': _int64_feature(matrital),
      'income': _int64_feature(income),
  })
  return tf.train.Example(features=features)

# NOTE: the TFRecord export below is wrapped in a bare triple-quoted string,
# so it is only a string expression and none of it executes; evaluating the
# cell just echoes the text back. Remove the surrounding quotes to actually
# read the CSV and write the TFRecord file.
"""
# save as tfrecord to disk, or use `tf.data.Dataset.from_tensor_slices` in memory
with tf.io.TFRecordWriter("data/census/tfrecords/train.tfrecords") as writer:
    with open("data/census/train_data.csv", "r") as f:
        for line in f.readlines():
            fields = line.strip().split(",")
            matrital, income, fea = int(fields[0]), int(fields[1]), [float(x) for x in fields[2:]]
            example = parse_single_sample(fea, matrital, income)
            writer.write(example.SerializeToString())
"""

'\n# save as tfrecord to disk, or use `tf.data.Dataset.from_tensor_slices` in memory\nwith tf.io.TFRecordWriter("data/census/tfrecords/train.tfrecords") as writer:\n    with open("data/census/train_data.csv", "r") as f:\n        for line in f.readlines():\n            fields = line.strip().split(",")\n            matrital, income, fea = int(fields[0]), int(fields[1]), [float(x) for x in fields[2:]]\n            example = parse_single_sample(fea, matrital, income)\n            writer.write(example.SerializeToString())\n'

In [12]:

class MMOEDense(object):
    """Dense multi-gate mixture-of-experts (MMoE) layer.

    Holds ``nExperts`` linear experts and one softmax gate per task. Both
    weight tensors are drawn from a truncated normal (mean 0, stddev 1) and
    there are no bias terms or activations.
    """

    def __init__(self, nTasks, nExperts, inputDim, expertDim):
        """Create the expert and gate weights.

        Args:
            nTasks: number of tasks, i.e. number of gates.
            nExperts: number of experts shared by all tasks.
            inputDim: size of the last axis of the input.
            expertDim: output size of every expert.
        """
        self.nTasks = nTasks
        self.nExperts = nExperts
        self.inputDim = inputDim
        self.expertDim = expertDim
        # (nExperts, inputDim, expertDim)
        expertInit = tf.initializers.truncated_normal(mean=0.0, stddev=1.0)
        self.experts = tf.Variable(expertInit(shape=(nExperts, inputDim, expertDim), dtype=tf.float32), name="experts")
        # (nTasks, inputDim, nExperts)
        gateInit = tf.initializers.truncated_normal(mean=0.0, stddev=1.0)
        self.gates = tf.Variable(gateInit(shape=(nTasks, inputDim, nExperts), dtype=tf.float32), name="gates")

    def forward(self, input):
        """Mix the expert outputs separately for every task.

        Args:
            input: float tensor of shape (batch, inputDim).

        Returns:
            Tensor of shape (batch, nTasks, expertDim); slice ``[:, k, :]`` is
            the gate-weighted sum of all expert outputs for task ``k``.
        """
        # (batch, 1, 1, inputDim): axis 1 broadcasts over experts / tasks,
        # axis 2 is the row dimension consumed by matmul.
        expanded = tf.expand_dims(tf.expand_dims(input, axis=1), axis=1)

        # (batch, 1, 1, inputDim) x (nExperts, inputDim, expertDim)
        #   -> (batch, nExperts, 1, expertDim) -> (batch, nExperts, expertDim)
        expertsOut = tf.squeeze(tf.matmul(expanded, self.experts), axis=2)

        # (batch, 1, 1, inputDim) x (nTasks, inputDim, nExperts)
        #   -> (batch, nTasks, 1, nExperts) -> (batch, nTasks, nExperts),
        # normalised over the expert axis so each task's weights sum to 1.
        weights = tf.nn.softmax(tf.squeeze(tf.matmul(expanded, self.gates), axis=2), axis=-1)

        # (batch, nTasks, nExperts) x (batch, nExperts, expertDim)
        #   -> (batch, nTasks, expertDim)
        return tf.matmul(weights, expertsOut)

# Toy sizes for a quick shape check of the expert / gate projections.
bs, inputDim, expertDim = 4, 2, 2
nExperts, nTasks = 3, 2

model = MMOEDense(nTasks, nExperts, inputDim, expertDim)

# NOTE: binding the name `input` at module level hides the builtin `input()`
# for the rest of the session; the name is kept as-is here.
input = tf.random.normal((bs, inputDim))
# Insert two singleton axes -> (batch, 1, 1, inputDim): axis 1 broadcasts
# across experts / tasks, axis 2 is the row dimension consumed by matmul.
input = input[:, tf.newaxis, tf.newaxis, :]

# (batch, 1, 1, inputDim) x (nExperts, inputDim, expertDim)
#   -> (batch, nExperts, 1, expertDim) -> (batch, nExperts, expertDim)
expertsOut = tf.matmul(input, model.experts)
expertsOut = tf.squeeze(expertsOut, axis=2)
print(expertsOut)

# (batch, 1, 1, inputDim) x (nTasks, inputDim, nExperts)
#   -> (batch, nTasks, 1, nExperts) -> (batch, nTasks, nExperts)
# TODO: mask for seq padding
weights = tf.matmul(input, model.gates)
weights = tf.squeeze(weights, axis=2)
weights = tf.nn.softmax(weights, axis=2)
print(weights)

tf.Tensor(
[[[-3.3729243  -2.3023763 ]
  [ 2.4164271   1.9118183 ]
  [-0.77461445 -1.6674687 ]]

 [[ 0.87476665 -1.2033155 ]
  [ 0.803498   -0.15558821]
  [ 1.332193    0.7990842 ]]

 [[ 0.07360959 -2.5701346 ]
  [ 2.0287952   0.45346907]
  [ 1.6634114   0.56998444]]

 [[ 0.26174334 -0.46593297]
  [ 0.32452798 -0.02654492]
  [ 0.46514326  0.2606592 ]]], shape=(4, 3, 2), dtype=float32)
tf.Tensor(
[[[0.0249604  0.21176472 0.7632749 ]
  [0.05003832 0.6352448  0.31471694]]

 [[0.5585442  0.32667807 0.11477774]
  [0.27513042 0.33741978 0.38744983]]

 [[0.43152305 0.42289156 0.14558537]
  [0.13354826 0.4437411  0.4227106 ]]

 [[0.40807486 0.34794104 0.24398407]
  [0.30479762 0.3408607  0.3543417 ]]], shape=(4, 2, 3), dtype=float32)
