<a href="https://colab.research.google.com/github/asangphukieo/delfi/blob/main/MTD_MMoE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import random

import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import VarianceScaling
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import Callback
from sklearn.metrics import roc_auc_score

import sys,os
# Remote files this notebook downloads into the working directory:
# mmoe.py (imported below), the two census-income archives, and one extra CSV.
resource_urls = [
    "https://raw.githubusercontent.com/drawbridge/keras-mmoe/master/mmoe.py",
    "https://raw.githubusercontent.com/drawbridge/keras-mmoe/master/data/census-income.data.gz",
    "https://raw.githubusercontent.com/drawbridge/keras-mmoe/master/data/census-income.test.gz",
    "https://raw.githubusercontent.com/asangphukieo/delfi/main/feature_ratioxdeepfrag.csv",
]
for resource_url in resource_urls:
    local_name = resource_url.rsplit("/", 1)[-1]
    # Skip files that are already present so re-running this cell neither
    # downloads them again nor leaves numbered duplicates behind.
    if os.path.exists(local_name):
        continue
    exit_status = os.system("wget " + resource_url)
    # os.system reports failure only through its return value; surface it
    # instead of silently continuing with a missing file.
    if exit_status != 0:
        print("WARNING: wget exited with status {} for {}".format(exit_status, resource_url))


from mmoe import MMoE

SEED = 1

In [None]:

# Seed the three random number generators used here (Python's `random`,
# NumPy and TensorFlow) with the same SEED for reproducibility.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Simple callback to print out ROC-AUC
class ROCCallback(Callback):
    """Keras callback that prints the ROC-AUC of every task after each epoch.

    Args:
        training_data: ``(X, Y)`` pair for the training split.
        validation_data: ``(X, Y)`` pair for the validation split.
        test_data: ``(X, Y)`` pair for the test split.

    In each pair, ``Y`` is indexed by output position: ``Y[i]`` holds the
    labels for the i-th entry of ``model.output_names``.
    """

    def __init__(self, training_data, validation_data, test_data):
        # Let the Keras base class initialise its own attributes first.
        super().__init__()
        self.train_X = training_data[0]
        self.train_Y = training_data[1]
        self.validation_X = validation_data[0]
        self.validation_Y = validation_data[1]
        self.test_X = test_data[0]
        self.test_Y = test_data[1]

    # The other hooks (on_train_begin/end, on_epoch_begin, on_batch_begin/end)
    # are inherited from Callback, so they are not overridden here with
    # do-nothing copies.

    def on_epoch_end(self, epoch, logs=None):
        """Predict on all three splits and print one ROC-AUC line per task.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Metric dictionary passed in by Keras (unused).
        """
        train_prediction = self.model.predict(self.train_X)
        validation_prediction = self.model.predict(self.validation_X)
        test_prediction = self.model.predict(self.test_X)

        # Iterate through each task and output their ROC-AUC across different datasets
        for task_index, output_name in enumerate(self.model.output_names):
            train_roc_auc = roc_auc_score(self.train_Y[task_index], train_prediction[task_index])
            validation_roc_auc = roc_auc_score(self.validation_Y[task_index], validation_prediction[task_index])
            test_roc_auc = roc_auc_score(self.test_Y[task_index], test_prediction[task_index])
            print(
                'ROC-AUC-{}-Train: {} ROC-AUC-{}-Validation: {} ROC-AUC-{}-Test: {}'.format(
                    output_name, round(train_roc_auc, 4),
                    output_name, round(validation_roc_auc, 4),
                    output_name, round(test_roc_auc, 4)
                )
            )



In [None]:
    # Column names for the census-income files, taken from
    # https://www2.1010data.com/documentationcenter/prod/Tutorials/MachineLearningExamples/CensusIncomeDataSet.html
    column_names = [
        'age', 'class_worker', 'det_ind_code', 'det_occ_code', 'education', 'wage_per_hour',
        'hs_college', 'marital_stat', 'major_ind_code', 'major_occ_code', 'race', 'hisp_origin',
        'sex', 'union_member', 'unemp_reason', 'full_or_part_emp', 'capital_gains', 'capital_losses',
        'stock_dividends', 'tax_filer_stat', 'region_prev_res', 'state_prev_res', 'det_hh_fam_stat',
        'det_hh_summ', 'instance_weight', 'mig_chg_msa', 'mig_chg_reg', 'mig_move_reg', 'mig_same',
        'mig_prev_sunbelt', 'num_emp', 'fam_under_18', 'country_father', 'country_mother',
        'country_self', 'citizenship', 'own_or_self', 'vet_question', 'vet_benefits', 'weeks_worked',
        'year', 'income_50k',
    ]

    # Both archives are read with the same options: comma-separated, no header
    # row, no index column, and the column names listed above.
    csv_options = dict(delimiter=',', header=None, index_col=None, names=column_names)
    train_df = pd.read_csv('census-income.data.gz', **csv_options)
    other_df = pd.read_csv('census-income.test.gz', **csv_options)

In [None]:
# Display the raw training frame read from census-income.data.gz.
train_df

Unnamed: 0,age,class_worker,det_ind_code,det_occ_code,education,wage_per_hour,hs_college,marital_stat,major_ind_code,major_occ_code,...,country_father,country_mother,country_self,citizenship,own_or_self,vet_question,vet_benefits,weeks_worked,year,income_50k
0,73,Not in universe,0,0,High school graduate,0,Not in universe,Widowed,Not in universe or children,Not in universe,...,United-States,United-States,United-States,Native- Born in the United States,0,Not in universe,2,0,95,- 50000.
1,58,Self-employed-not incorporated,4,34,Some college but no degree,0,Not in universe,Divorced,Construction,Precision production craft & repair,...,United-States,United-States,United-States,Native- Born in the United States,0,Not in universe,2,52,94,- 50000.
2,18,Not in universe,0,0,10th grade,0,High school,Never married,Not in universe or children,Not in universe,...,Vietnam,Vietnam,Vietnam,Foreign born- Not a citizen of U S,0,Not in universe,2,0,95,- 50000.
3,9,Not in universe,0,0,Children,0,Not in universe,Never married,Not in universe or children,Not in universe,...,United-States,United-States,United-States,Native- Born in the United States,0,Not in universe,0,0,94,- 50000.
4,10,Not in universe,0,0,Children,0,Not in universe,Never married,Not in universe or children,Not in universe,...,United-States,United-States,United-States,Native- Born in the United States,0,Not in universe,0,0,94,- 50000.
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
199518,87,Not in universe,0,0,7th and 8th grade,0,Not in universe,Married-civilian spouse present,Not in universe or children,Not in universe,...,Canada,United-States,United-States,Native- Born in the United States,0,Not in universe,2,0,95,- 50000.
199519,65,Self-employed-incorporated,37,2,11th grade,0,Not in universe,Married-civilian spouse present,Business and repair services,Executive admin and managerial,...,United-States,United-States,United-States,Native- Born in the United States,0,Not in universe,2,52,94,- 50000.
199520,47,Not in universe,0,0,Some college but no degree,0,Not in universe,Married-civilian spouse present,Not in universe or children,Not in universe,...,Poland,Poland,Germany,Foreign born- U S citizen by naturalization,0,Not in universe,2,52,95,- 50000.
199521,16,Not in universe,0,0,10th grade,0,High school,Never married,Not in universe or children,Not in universe,...,United-States,United-States,United-States,Native- Born in the United States,0,Not in universe,2,0,95,- 50000.


In [None]:
# Columns used as prediction targets, one per task
# (first group of tasks according to the paper)
label_columns = [
    'income_50k',
    'marital_stat',
    'country_self',
]
label_columns

['income_50k', 'marital_stat', 'country_self']

In [None]:
# Feature columns that get one-hot encoded later on.
categorical_columns = [
    'class_worker', 'det_ind_code', 'det_occ_code', 'education',
    'hs_college', 'major_ind_code', 'major_occ_code', 'race',
    'hisp_origin', 'sex', 'union_member', 'unemp_reason',
    'full_or_part_emp', 'tax_filer_stat', 'region_prev_res', 'state_prev_res',
    'det_hh_fam_stat', 'det_hh_summ', 'mig_chg_msa', 'mig_chg_reg',
    'mig_move_reg', 'mig_same', 'mig_prev_sunbelt', 'fam_under_18',
    'country_father', 'country_mother', 'citizenship', 'vet_question',
]
# Keep the untouched label columns of both frames for the encoding steps below.
train_raw_labels, other_raw_labels = train_df[label_columns], other_df[label_columns]
train_raw_labels

Unnamed: 0,income_50k,marital_stat,country_self
0,- 50000.,Widowed,United-States
1,- 50000.,Divorced,United-States
2,- 50000.,Never married,Vietnam
3,- 50000.,Never married,United-States
4,- 50000.,Never married,United-States
...,...,...,...
199518,- 50000.,Married-civilian spouse present,United-States
199519,- 50000.,Married-civilian spouse present,United-States
199520,- 50000.,Married-civilian spouse present,Germany
199521,- 50000.,Never married,United-States


In [None]:
# Remove the label columns from each frame, then one-hot encode its
# categorical feature columns.
transformed_train, transformed_other = (
    pd.get_dummies(raw_frame.drop(columns=label_columns), columns=categorical_columns)
    for raw_frame in (train_df, other_df)
)


In [None]:
# Display the one-hot encoded features of the other (validation/test) set.
transformed_other

Unnamed: 0,age,wage_per_hour,capital_gains,capital_losses,stock_dividends,instance_weight,num_emp,own_or_self,vet_benefits,weeks_worked,...,country_mother_ Vietnam,country_mother_ Yugoslavia,citizenship_ Foreign born- Not a citizen of U S,citizenship_ Foreign born- U S citizen by naturalization,citizenship_ Native- Born abroad of American Parent(s),citizenship_ Native- Born in Puerto Rico or U S Outlying,citizenship_ Native- Born in the United States,vet_question_ No,vet_question_ Not in universe,vet_question_ Yes
0,38,0,0,0,0,1032.38,4,0,2,12,...,0,0,1,0,0,0,0,0,1,0
1,44,0,0,0,2500,1462.33,1,0,2,26,...,0,0,0,0,0,0,1,0,1,0
2,2,0,0,0,0,1601.75,0,0,0,0,...,0,0,0,0,0,0,1,0,1,0
3,35,0,0,0,0,1866.88,5,2,2,52,...,0,0,0,0,0,0,1,0,1,0
4,49,0,0,0,0,1394.54,4,0,2,50,...,0,0,0,0,0,0,1,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
99757,14,0,0,0,0,1708.85,0,0,0,0,...,0,0,0,0,0,0,1,0,1,0
99758,61,0,0,0,0,2511.11,4,0,2,52,...,0,0,0,0,0,0,1,0,1,0
99759,24,0,0,0,0,2083.76,2,0,2,52,...,0,0,0,1,0,0,0,0,1,0
99760,30,0,0,0,0,1680.06,5,0,2,52,...,0,0,0,0,0,0,1,0,1,0


In [None]:
# Align the other set's columns with the training set's.
# The two frames are one-hot encoded separately, so the other set lacks the
# 'det_hh_fam_stat_ Grandchild <18 ever marr not in subfamily' indicator.
# Simply assigning it would append it as the LAST column, leaving the other
# set in a different column order than the training frame, while the model
# consumes features by position. Reindexing on the training columns
# zero-fills any missing indicator, drops any column the training frame does
# not have, and restores the training column order.
transformed_other = transformed_other.reindex(columns=transformed_train.columns, fill_value=0)


In [None]:
# One-hot (2-class) labels for the income and country_self tasks.
# The raw string values carry a leading space (e.g. ' United-States'), so
# every comparison string must include it as well.
train_income = to_categorical((train_raw_labels.income_50k == ' 50000+.').astype(int), num_classes=2)
train_country_self = to_categorical((train_raw_labels.country_self == ' United-States').astype(int), num_classes=2)

other_income = to_categorical((other_raw_labels.income_50k == ' 50000+.').astype(int), num_classes=2)
# Same comparison string as for the training labels: without the leading space
# nothing matches and every row of the other set ends up in class 0.
other_country_self = to_categorical((other_raw_labels.country_self == ' United-States').astype(int), num_classes=2)


In [None]:
# List the distinct raw country_self values in the training labels.
set(train_raw_labels.country_self)

{' ?',
 ' Cambodia',
 ' Canada',
 ' China',
 ' Columbia',
 ' Cuba',
 ' Dominican-Republic',
 ' Ecuador',
 ' El-Salvador',
 ' England',
 ' France',
 ' Germany',
 ' Greece',
 ' Guatemala',
 ' Haiti',
 ' Holand-Netherlands',
 ' Honduras',
 ' Hong Kong',
 ' Hungary',
 ' India',
 ' Iran',
 ' Ireland',
 ' Italy',
 ' Jamaica',
 ' Japan',
 ' Laos',
 ' Mexico',
 ' Nicaragua',
 ' Outlying-U S (Guam USVI etc)',
 ' Panama',
 ' Peru',
 ' Philippines',
 ' Poland',
 ' Portugal',
 ' Puerto-Rico',
 ' Scotland',
 ' South Korea',
 ' Taiwan',
 ' Thailand',
 ' Trinadad&Tobago',
 ' United-States',
 ' Vietnam',
 ' Yugoslavia'}

In [None]:
# Give every distinct marital status in the training labels an integer id,
# assigned in sorted order of the status strings.
uniq_marital = sorted(set(train_raw_labels.marital_stat))
index = {class_id: status for class_id, status in enumerate(uniq_marital)}
index[0]

' Divorced'

In [None]:
def text2index(text,dict_target):
  """Return the first key of ``dict_target`` whose value equals ``text``.

  Keys are checked in the mapping's iteration order; None is returned when
  no value matches.
  """
  return next((key for key in dict_target if dict_target[key] == text), None)
# Example: look up the id assigned to ' Never married'.
text2index(" Never married",index)

4

In [None]:
def index2text(number,dict_target):
    """Return the value stored under key ``number`` in ``dict_target``."""
    looked_up = dict_target[number]
    return looked_up

# Example: the status text stored under id 0.
index2text(0,index)


' Divorced'

In [None]:
# One-hot encoding categorical labels.
# The class count comes from the id -> status mapping instead of a hard-coded
# 7, so it always matches the ids that text2index can return.
num_marital_classes = len(index)
train_marital = to_categorical(train_raw_labels.marital_stat.apply(text2index, dict_target=index),num_classes=num_marital_classes)
other_marital = to_categorical(other_raw_labels.marital_stat.apply(text2index, dict_target=index),num_classes=num_marital_classes)

In [None]:
# Display the one-hot country_self labels of the training set.
train_country_self

array([[0., 1.],
       [0., 1.],
       [1., 0.],
       ...,
       [1., 0.],
       [0., 1.],
       [1., 0.]], dtype=float32)

In [None]:
# Display the one-hot marital labels of the other (validation/test) set.
other_marital

array([[0., 0., 1., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       [0., 0., 0., ..., 1., 0., 0.],
       ...,
       [0., 0., 1., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.]], dtype=float32)

In [None]:
# One-hot label arrays per task, for the training set and for the other set.
dict_train_labels = {'income': train_income, 'marital': train_marital, 'country_self': train_country_self}
dict_other_labels = {'income': other_income, 'marital': other_marital, 'country_self': other_country_self}

# Width of each task's one-hot encoding, read off the training labels.
dict_outputs = {task: labels.shape[1] for task, labels in dict_train_labels.items()}

# (width, task name) pairs, sorted by task name.
output_info = [(dict_outputs[task], task) for task in sorted(dict_outputs)]

In [None]:
# Display the (width, task name) pairs, sorted by task name.
output_info

[(2, 'country_self'), (2, 'income'), (7, 'marital')]

In [None]:
 # Inspect one encoded row (position 3) of the other set.
 transformed_other.iloc[3]

age                                                           35.0
wage_per_hour                                                  0.0
capital_gains                                                  0.0
capital_losses                                                 0.0
stock_dividends                                                0.0
                                                              ... 
citizenship_ Native- Born in the United States                 1.0
vet_question_ No                                               0.0
vet_question_ Not in universe                                  1.0
vet_question_ Yes                                              0.0
det_hh_fam_stat_ Grandchild <18 ever marr not in subfamily     0.0
Name: 3, Length: 456, dtype: float64

In [None]:
# (Premature display of `validation_indices` removed: the variable is first
# assigned by the split cell further down, so evaluating it here raises a
# NameError on a fresh kernel. It is displayed right after the split.)

Int64Index([68068, 63683, 67714, 26690, 29516, 66480, 81648, 62932, 59304,
            26071,
            ...
            83661,  8309, 11829, 42298, 98726, 49251, 15354, 61545, 84639,
            92186],
           dtype='int64', length=49881)

In [None]:

# Split the other dataset into 1:1 validation to test according to the paper.
# NOTE(review): the sampled index labels are reused below as row positions
# (.iloc and NumPy indexing); this presumes transformed_other still carries
# the default 0..n-1 integer index - confirm if the loading code changes.
validation_indices = transformed_other.sample(frac=0.5, replace=False, random_state=SEED).index
# Sorted so the test rows come out in a defined order (a set guarantees none).
test_indices = sorted(set(transformed_other.index) - set(validation_indices))
# Sanity check: the two halves together account for every row.
assert len(validation_indices) + len(test_indices) == len(transformed_other)
validation_data = transformed_other.iloc[validation_indices]
validation_label = [dict_other_labels[key][validation_indices] for key in sorted(dict_other_labels.keys())]
test_data = transformed_other.iloc[test_indices]
test_label = [dict_other_labels[key][test_indices] for key in sorted(dict_other_labels.keys())]
# Training uses the full training frame; labels are listed in sorted task order.
train_data = transformed_train
train_label = [dict_train_labels[key] for key in sorted(dict_train_labels.keys())]


In [None]:
# Display the index labels sampled for the validation half.
validation_indices

Int64Index([68068, 63683, 67714, 26690, 29516, 66480, 81648, 62932, 59304,
            26071,
            ...
            83661,  8309, 11829, 42298, 98726, 49251, 15354, 61545, 84639,
            92186],
           dtype='int64', length=49881)

In [None]:
# Display the encoded training features passed to the model.
train_data

Unnamed: 0,age,wage_per_hour,capital_gains,capital_losses,stock_dividends,instance_weight,num_emp,own_or_self,vet_benefits,weeks_worked,...,country_mother_ Vietnam,country_mother_ Yugoslavia,citizenship_ Foreign born- Not a citizen of U S,citizenship_ Foreign born- U S citizen by naturalization,citizenship_ Native- Born abroad of American Parent(s),citizenship_ Native- Born in Puerto Rico or U S Outlying,citizenship_ Native- Born in the United States,vet_question_ No,vet_question_ Not in universe,vet_question_ Yes
0,73,0,0,0,0,1700.09,0,0,2,0,...,0,0,0,0,0,0,1,0,1,0
1,58,0,0,0,0,1053.55,1,0,2,52,...,0,0,0,0,0,0,1,0,1,0
2,18,0,0,0,0,991.95,0,0,2,0,...,1,0,1,0,0,0,0,0,1,0
3,9,0,0,0,0,1758.14,0,0,0,0,...,0,0,0,0,0,0,1,0,1,0
4,10,0,0,0,0,1069.16,0,0,0,0,...,0,0,0,0,0,0,1,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
199518,87,0,0,0,0,955.27,0,0,2,0,...,0,0,0,0,0,0,1,0,1,0
199519,65,0,6418,0,9,687.19,1,0,2,52,...,0,0,0,0,0,0,1,0,1,0
199520,47,0,0,0,157,1923.03,6,0,2,52,...,0,0,0,1,0,0,0,0,1,0
199521,16,0,0,0,0,4664.87,0,0,2,0,...,0,0,0,0,0,0,1,0,1,0


In [None]:
    # Number of input features = number of encoded columns
    num_features = train_data.shape[1]

    print('Training data shape = {}'.format(train_data.shape))
    print('Validation data shape = {}'.format(validation_data.shape))
    print('Test data shape = {}'.format(test_data.shape))

    # Set up the input layer
    input_layer = Input(shape=(num_features,))

    # Set up MMoE layer; the task count follows output_info so it cannot
    # drift from the number of output towers built from the same list.
    mmoe_layers = MMoE(
        units=4,
        num_experts=8,
        num_tasks=len(output_info)
    )(input_layer)

    # Filled with one output layer per task when the towers are built
    output_layers = []


Training data shape = (199523, 456)
Validation data shape = (49881, 456)
Test data shape = (49881, 456)


In [None]:
# Display the number of input features (columns of train_data).
num_features

456

In [None]:
    # Build tower layer from MMoE layer.
    # Start from an empty list so re-running this cell does not append a
    # second set of towers to those created by an earlier run.
    output_layers = []
    # The loop variable is deliberately not named `index`: that name holds the
    # marital-status id -> text mapping used by the label-encoding cells.
    for task_index, task_layer in enumerate(mmoe_layers):
        task_units, task_name = output_info[task_index]
        tower_layer = Dense(
            units=8,
            activation='relu',
            kernel_initializer=VarianceScaling())(task_layer)
        output_layer = Dense(
            units=task_units,
            name=task_name,
            activation='softmax',
            kernel_initializer=VarianceScaling())(tower_layer)
        output_layers.append(output_layer)

    # Compile model
    model = Model(inputs=[input_layer], outputs=output_layers)
    adam_optimizer = Adam()
    model.compile(
        loss={
            'income': 'binary_crossentropy',
            # 'marital' is a multi-class softmax over one-hot labels, which
            # calls for categorical rather than binary cross-entropy.
            'marital': 'categorical_crossentropy',
            'country_self': 'binary_crossentropy',
        },
        optimizer=adam_optimizer,
        metrics=['accuracy']
    )

In [None]:
    # Print out model architecture summary (layers, output shapes, parameter counts)
    model.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 456)]        0           []                               
                                                                                                  
 m_mo_e_1 (MMoE)                [(None, 4),          25592       ['input_2[0][0]']                
                                 (None, 4),                                                       
                                 (None, 4)]                                                       
                                                                                                  
 dense_2 (Dense)                (None, 8)            40          ['m_mo_e_1[0][0]']               
                                                                                            

In [None]:
    # Fit on the training split for 100 epochs, scoring the validation split
    # after each one.
    model.fit(
        train_data,
        train_label,
        epochs=100,
        validation_data=(validation_data, validation_label),
    )

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

<keras.callbacks.History at 0x7faa50bb9f50>

In [None]:
# Evaluate the trained model on the test split.
model.evaluate(test_data, test_label)



[2.725534200668335,
 2.2160089015960693,
 0.24446551501750946,
 0.26507940888404846,
 0.0,
 0.9421623349189758,
 0.7347687482833862]