In [1]:
import tensorflow as tf

2023-04-06 22:57:33.391845: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [12]:
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, tf.Tensor): # if value ist tensor
        value = value.numpy() # get value of tensor
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def buildSample(fea, matrital, income):
    data = {
        'fea': _bytes_feature(tf.io.serialize_tensor(fea)),
        'marital': _int64_feature(matrital),
        'income': _int64_feature(income)
    }
    example = tf.train.Example(features=tf.train.Features(feature=data))

    return example

def parseLine(line):
    fields = line.strip().split(",")
    return int(fields[0]), int(fields[1]), [float(x) for x in fields[2:]]

# save as tfrecord to disk, or use `tf.data.Dataset.from_tensor_slices` in memory
train_data_path = "data/census/tfrecords/train.tfrecords"
test_data_path = "data/census/tfrecords/test.tfrecords"
with tf.io.TFRecordWriter(train_data_path) as writer:
    with open("data/census/train_data.csv", "r") as f:
        for line in f.readlines():
            marital, income, fea = parseLine(line)
            example = buildSample(fea, marital, income)
            writer.write(example.SerializeToString())

with tf.io.TFRecordWriter(train_data_path) as writer:
    with open("data/census/test_data.csv", "r") as f:
        for line in f.readlines():
            marital, income, fea = parseLine(line)
            example = buildSample(fea, marital, income)
            writer.write(example.SerializeToString())

In [13]:
class MMOEDense(object):
    def __init__(self, nTasks, nExperts, inputDim, expertDim, hiddenDim):
        self.nTasks = nTasks
        self.nExperts = nExperts
        self.inputDim = inputDim
        self.expertDim = expertDim
        # experts (nExperts, inputDim, expertDim)
        expertInit = tf.initializers.truncated_normal(mean=0.0, stddev=1.0)
        self.experts = tf.Variable(expertInit(shape=(nExperts, inputDim, expertDim), dtype=tf.float32), name="experts")
        # gates, (nTasks, inputDim, nExpert)
        gateInit = tf.initializers.truncated_normal(mean=0.0, stddev=1.0)
        self.gates = tf.Variable(gateInit(shape=(nTasks, inputDim, nExperts), dtype=tf.float32), name="gates")
        # towers' mlp for each task, (nTasks, expertDim, hiddenDim)
        towersInit = tf.initializers.truncated_normal(mean=0.0, stddev=1.0)
        self.towers = tf.Variable(towersInit(shape=(nTasks, expertDim, hiddenDim), dtype=tf.float32), name="towers")
        # towers out
        towersOut = tf.initializers.truncated_normal(mean=0.0, stddev=1.0)
        self.outs = tf.Variable(towersOut(shape=(nTasks, hiddenDim, 1), dtype=tf.float32), name="outs")
        # target loss weights, (1, nTasks)
        self.tasksWeights = tf.constant([[1.0, 1.0]])
        assert self.tasksWeights.shape[1] == nTasks

    def __call__(self, input, labels=None):
        """
            @input: (batch, inputDim)
            @labels: (batch, nTasks)
        """
        input = tf.random.normal((bs, inputDim))
        # (batch, 1, 1, inputDim), second dim is for expert broadcast, third dim is for matmul
        input = tf.expand_dims(tf.expand_dims(input, axis=1), axis=1)

        # (batch, 1, 1, inputDim) X (nExperts, inputDim, expertDim) -> (batch, nExperts, 1, expertDim) -> (batch, 1, nExperts, expertDim)
        expertsOut = tf.transpose(tf.matmul(input, model.experts), [0, 2, 1, 3])

        # (batch, 1, 1, inputDim) X (nTasks, inputDim, nExperts) -> (batch, nTasks, 1, nExperts)
        # TODO: mask for seq padding
        weights = tf.nn.softmax(tf.matmul(input, model.gates), 3)
        # (batch, nTasks, 1, nExperts) X (batch, 1, nExperts, expertDim) -> (batch, nTasks, 1, expertDim)
        towersIn = tf.matmul(weights, expertsOut)
        # (batch, nTasks, 1, expertDim) X (nTasks, expertDim, hiddenDim) -> (batch, nTasks, 1, hiddenDim)
        towersHidden = tf.nn.relu(tf.matmul(towersIn, model.towers))
        # (batch, nTasks, 1, hiddenDim) X (nTasks, hiddenDim, 1) -> (batch, nTasks, 1, 1) -> (batch, nTasks)
        outs = tf.squeeze(tf.matmul(towersHidden, model.outs), axis=[2, 3])

        if labels is not None:
            # train
            # (batch, nTasks)
            losses = tf.nn.sigmoid_cross_entropy_with_logits(labels, outs)
            # loss fusion, (1, nTasks) X (batch, nTasks, 1) -> (batch, 1, 1) -> (batch, )
            losses = tf.squeeze(tf.matmul(model.tasksWeights, tf.expand_dims(losses, axis=2)), axis=[1, 2])
            losses = tf.reduce_mean(losses)
            return losses
        else:
            # infer
            return tf.nn.sigmoid(outs)

bs = 4
inputDim = 2
expertDim = 2
nExperts = 3
nTasks = 2
hiddenDim = 8

model = MMOEDense(nTasks, nExperts, inputDim, expertDim, hiddenDim)

In [20]:
def parseSample(sample):
    #use the same structure as above; it's kinda an outline of the structure we now want to create
    data = {
      'fea':tf.io.FixedLenFeature([], tf.string),
      'marital' : tf.io.FixedLenFeature([], tf.int64),
      'income': tf.io.FixedLenFeature([], tf.int64),
    }

    sample = tf.io.parse_single_example(sample, data)

    fea = tf.io.parse_tensor(sample["fea"], out_type=tf.float32)
    marital = sample['marital']
    income = sample['income']

    return fea, marital, income

train_batch_size = 4

trainData = tf.data.TFRecordDataset(train_data_path).map(parseSample).batch(train_batch_size)

for step, (fea, marital, income) in enumerate(trainData):
    labels = tf.cast(tf.transpose(tf.stack([marital, income]), [1, 0]), tf.float32)
    losses = model(fea, labels)


tf.Tensor(1.0289341, shape=(), dtype=float32)


<tf.Tensor: shape=(4, 2), dtype=int64, numpy=
array([[0, 0],
       [1, 0],
       [1, 0],
       [0, 0]])>