In [1]:
import tensorflow as tf

2023-04-08 10:30:24.845137: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [47]:
def _bytes_feature(value):
    """Wrap a single bytes value (or a scalar string tensor) in a tf.train.Feature."""
    # Eager tensors are unwrapped to their raw numpy value before being stored.
    raw = value.numpy() if isinstance(value, tf.Tensor) else value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))

def _int64_feature(value):
    """Wrap a single bool / enum / int / uint in an int64_list tf.train.Feature."""
    int64List = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64List)

def buildSample(fea, matrital, income):
    """Build a tf.train.Example holding one census record.

    Args:
        fea: feature vector (list / array / tensor of floats). It is stored as a
            serialized float32 tensor so that `parseSample` can decode it with
            `out_type=tf.float32`.
        matrital: marital label (int). The misspelled parameter name is kept
            for backward compatibility with keyword callers.
        income: income label (int).

    Returns:
        tf.train.Example with keys 'fea' (bytes), 'marital' and 'income' (int64).
    """
    # serialize_tensor preserves the input dtype; cast explicitly so that e.g. a
    # float64 numpy array does not fail later when parsed back as float32.
    feaTensor = tf.cast(fea, tf.float32)
    data = {
        'fea': _bytes_feature(tf.io.serialize_tensor(feaTensor)),
        'marital': _int64_feature(matrital),
        'income': _int64_feature(income),
    }
    return tf.train.Example(features=tf.train.Features(feature=data))

def parseLine(line):
    """Parse one CSV line "marital,income,f1,f2,..." into (marital, income, [f1, f2, ...])."""
    cols = line.strip().split(",")
    marital = int(cols[0])
    income = int(cols[1])
    return marital, income, list(map(float, cols[2:]))

import os

# Save as TFRecord to disk (alternatively, build the dataset in memory with
# `tf.data.Dataset.from_tensor_slices`).
train_data_path = "data/census/tfrecords/train.tfrecords"
test_data_path = "data/census/tfrecords/test.tfrecords"


def writeTFRecords(csvPath, tfrecordPath):
    """Convert a census CSV ("marital,income,f1,f2,...") into a TFRecord file.

    Each non-blank line is parsed with `parseLine` and serialized with
    `buildSample`. The CSV is streamed line by line instead of being loaded
    whole with `readlines()`.
    """
    # TFRecordWriter does not create missing parent directories.
    outDir = os.path.dirname(tfrecordPath)
    if outDir:
        os.makedirs(outDir, exist_ok=True)
    with tf.io.TFRecordWriter(tfrecordPath) as writer, open(csvPath, "r") as f:
        for line in f:
            if not line.strip():
                continue  # skip blank lines (e.g. a trailing newline at EOF)
            marital, income, fea = parseLine(line)
            writer.write(buildSample(fea, marital, income).SerializeToString())


# Only build the TFRecords when they are missing, so a fresh "Restart & Run All"
# works without rewriting files that already exist.
for csvPath, tfrecordPath in [
    ("data/census/train_data.csv", train_data_path),
    ("data/census/test_data.csv", test_data_path),
]:
    if not os.path.exists(tfrecordPath):
        writeTFRecords(csvPath, tfrecordPath)

In [44]:
class MMOEDense(object):
    """Multi-gate Mixture-of-Experts with single dense-layer experts.

    Experts, gates and towers are each stored as one stacked weight tensor, so
    every stage is a single batched `tf.matmul` rather than a Python loop.

    Args:
        nTasks: number of tasks (one softmax gate and one tower per task).
        nExperts: number of shared experts.
        inputDim: width of each input feature vector.
        expertDim: output width of each expert.
        hiddenDim: hidden width of each task tower.
        lr: Adam learning rate.
        tasksWeights: optional per-task loss weights (length nTasks).
            Defaults to all ones.
        initStddev: stddev of the truncated-normal initializer used for every
            weight tensor (default 1.0).
    """

    def __init__(self, nTasks, nExperts, inputDim, expertDim, hiddenDim, lr=0.01,
                 tasksWeights=None, initStddev=1.0):
        self.nTasks = nTasks
        self.nExperts = nExperts
        self.inputDim = inputDim
        self.expertDim = expertDim

        def newVariable(shape, name):
            # Build a fresh initializer per variable: reusing one unseeded Keras
            # initializer instance may return identical values on newer versions.
            init = tf.initializers.truncated_normal(mean=0.0, stddev=initStddev)
            return tf.Variable(init(shape=shape, dtype=tf.float32), name=name)

        # Experts' ops are independent, so they are batched instead of looped.
        # experts, (nExperts, inputDim, expertDim)
        self.experts = newVariable((nExperts, inputDim, expertDim), "experts")
        # gates, (nTasks, inputDim, nExperts)
        self.gates = newVariable((nTasks, inputDim, nExperts), "gates")
        # towers' mlp for each task, (nTasks, expertDim, hiddenDim)
        self.towers = newVariable((nTasks, expertDim, hiddenDim), "towers")
        # towers out, (nTasks, hiddenDim, 1)
        self.outs = newVariable((nTasks, hiddenDim, 1), "outs")

        # target loss weights, (1, nTasks)
        if tasksWeights is None:
            self.tasksWeights = tf.ones((1, nTasks), dtype=tf.float32)
        else:
            self.tasksWeights = tf.reshape(tf.cast(tasksWeights, tf.float32), (1, -1))
        assert self.tasksWeights.shape[1] == nTasks

        self.trainableWeights = [
            self.experts,
            self.gates,
            self.towers,
            self.outs
        ]
        self.opt = tf.optimizers.Adam(learning_rate=lr)

    def __call__(self, input, labels=None):
        """Run the MMoE forward pass.

            @input: (batch, inputDim) float32 features
            @labels: optional (batch, nTasks) float32 targets

        Returns:
            the scalar mean task-weighted sigmoid cross-entropy when `labels` is
            given (training); otherwise per-task sigmoid probabilities, shaped
            (batch, nTasks) (inference).
        """
        # (batch, 1, 1, inputDim): axis 1 broadcasts over experts/tasks, axis 2 is the matmul row
        x = tf.expand_dims(tf.expand_dims(input, axis=1), axis=1)

        # (batch, 1, 1, inputDim) X (nExperts, inputDim, expertDim) -> (batch, nExperts, 1, expertDim) -> (batch, 1, nExperts, expertDim)
        expertsOut = tf.transpose(tf.matmul(x, self.experts), [0, 2, 1, 3])

        # (batch, 1, 1, inputDim) X (nTasks, inputDim, nExperts) -> (batch, nTasks, 1, nExperts)
        # TODO: mask for seq padding
        weights = tf.nn.softmax(tf.matmul(x, self.gates), 3)
        # (batch, nTasks, 1, nExperts) X (batch, 1, nExperts, expertDim) -> (batch, nTasks, 1, expertDim)
        towersIn = tf.matmul(weights, expertsOut)
        # (batch, nTasks, 1, expertDim) X (nTasks, expertDim, hiddenDim) -> (batch, nTasks, 1, hiddenDim)
        towersHidden = tf.nn.relu(tf.matmul(towersIn, self.towers))
        # (batch, nTasks, 1, hiddenDim) X (nTasks, hiddenDim, 1) -> (batch, nTasks, 1, 1) -> (batch, nTasks)
        outs = tf.squeeze(tf.matmul(towersHidden, self.outs), axis=[2, 3])

        if labels is None:
            # infer
            return tf.nn.sigmoid(outs)

        # train: per-task losses, (batch, nTasks)
        losses = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=outs)
        # loss fusion: (batch, nTasks) * (1, nTasks), summed over tasks -> (batch,)
        losses = tf.reduce_sum(losses * self.tasksWeights, axis=1)
        return tf.reduce_mean(losses)

# Model hyper-parameters.
inputDim = 499  # assumed to equal the length of each serialized `fea` vector -- confirm against the data
expertDim = 256
nExperts = 3
nTasks = 2  # marital + income
hiddenDim = 128

# Seed TF's global RNG so weight initialization is repeatable across
# "Restart & Run All".
SEED = 42
tf.random.set_seed(SEED)

model = MMOEDense(nTasks, nExperts, inputDim, expertDim, hiddenDim)

In [46]:
def parseSample(sample):
    """Decode one serialized tf.train.Example back into its three fields.

    Returns:
        (fea, marital, income): `fea` is the float32 tensor recovered from the
        serialized 'fea' bytes; `marital` and `income` are int64 scalars.
    """
    # Feature spec mirroring the keys written by buildSample.
    featureSpec = {
        'fea': tf.io.FixedLenFeature([], tf.string),
        'marital': tf.io.FixedLenFeature([], tf.int64),
        'income': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(sample, featureSpec)
    return (
        tf.io.parse_tensor(parsed["fea"], out_type=tf.float32),
        parsed['marital'],
        parsed['income'],
    )

train_batch_size = 256

trainData = tf.data.TFRecordDataset(train_data_path).map(parseSample).batch(train_batch_size)

EPOCHS = 30

# Per-column normalization of the raw census features. NOTE: its gamma/beta are
# not in model.trainableWeights, so only its moving statistics change here.
inputBN = tf.keras.layers.BatchNormalization(axis=1)
for epoch in range(EPOCHS):
    epochTotalLoss = 0.0
    for step, (fea, marital, income) in enumerate(trainData):
        # (batch, nTasks) float labels: column 0 = marital, column 1 = income
        labels = tf.cast(tf.stack([marital, income], axis=1), tf.float32)
        # batch norm for census features (training=True normalizes with batch
        # statistics and updates the moving averages used at inference)
        fea = inputBN(fea, training=True)

        with tf.GradientTape() as tape:
            losses = model(fea, labels)
        # Differentiate and apply outside the tape context so these ops are not
        # themselves recorded on the tape.
        grads = tape.gradient(losses, model.trainableWeights)
        model.opt.apply_gradients(zip(grads, model.trainableWeights))

        # Accumulate a Python float rather than chaining tensor additions.
        epochTotalLoss += float(losses)
        epochAvgLoss = epochTotalLoss / (step + 1)

        if (step + 1) % 10 == 0:
            print("| epoch: {:03d} | step: {:06d} | epoch avg loss: {:.4f}".format(epoch, step + 1, epochAvgLoss))

| epoch: 000 | step: 000010 | epoch avg loss: 55.0350
| epoch: 000 | step: 000020 | epoch avg loss: 55.3614
| epoch: 000 | step: 000030 | epoch avg loss: 51.9114
| epoch: 000 | step: 000040 | epoch avg loss: 48.8507
| epoch: 000 | step: 000050 | epoch avg loss: 47.8450
| epoch: 000 | step: 000060 | epoch avg loss: 46.1185
| epoch: 000 | step: 000070 | epoch avg loss: 43.9611
| epoch: 000 | step: 000080 | epoch avg loss: 42.0298
| epoch: 000 | step: 000090 | epoch avg loss: 40.5806
| epoch: 000 | step: 000100 | epoch avg loss: 39.0644
| epoch: 000 | step: 000110 | epoch avg loss: 38.7356
| epoch: 000 | step: 000120 | epoch avg loss: 39.5902
| epoch: 000 | step: 000130 | epoch avg loss: 40.6469
| epoch: 000 | step: 000140 | epoch avg loss: 41.6594
| epoch: 000 | step: 000150 | epoch avg loss: 42.0776
| epoch: 000 | step: 000160 | epoch avg loss: 42.4517
| epoch: 000 | step: 000170 | epoch avg loss: 43.0969
| epoch: 000 | step: 000180 | epoch avg loss: 43.4966
| epoch: 000 | step: 000190 

In [48]:
from sklearn.metrics import roc_auc_score

test_batch_size = 64

testData = tf.data.TFRecordDataset(test_data_path).map(parseSample).batch(test_batch_size)

# Collect whole batches and concatenate once, instead of appending one scalar
# tensor per example.
predBatches, labelBatches = [], []
for fea, marital, income in testData:
    labelBatches.append(tf.cast(tf.stack([marital, income], axis=1), tf.float32))
    # batch norm for census features, using the moving statistics from training
    fea = inputBN(fea, training=False)
    predBatches.append(model(fea))  # sigmoid probabilities, (batch, nTasks)

preds = tf.concat(predBatches, axis=0).numpy()
labels = tf.concat(labelBatches, axis=0).numpy()

# Per-task columns (names kept from the original; the "Logits" arrays hold
# sigmoid probabilities, which give the same AUC as raw logits).
maritalLogits, maritalLabels = preds[:, 0], labels[:, 0]
incomeLogits, incomeLabels = preds[:, 1], labels[:, 1]

maritalAuc = roc_auc_score(maritalLabels, maritalLogits)
incomeAuc = roc_auc_score(incomeLabels, incomeLogits)

print(maritalAuc, incomeAuc)

0.9908872195532458 0.9595909108248275
