<a href="https://colab.research.google.com/github/AbrahamKong/CMPE297-Multi_Task_Learning_and_Transfer_Learning/blob/main/Assignment_3_Multi_Task_Learning_and_Transfer_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MMoE Multi Task Learning Model

In [None]:
import random

import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import VarianceScaling
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import Callback
from tensorflow.keras import activations, initializers, regularizers, constraints
from tensorflow.keras.layers import Layer, InputSpec

from sklearn.metrics import roc_auc_score

In [None]:
from google.colab import drive
drive.mount("/content/drive/")

Mounted at /content/drive/


In [None]:
FILE_PATH = "/content/drive/MyDrive/SJSU/CMPE 297: Advanced Deep Learning/Assignment #3: Multi Task Learning and Transfer Learning/"

In [None]:
class MMoE(Layer):
    """
    Multi-gate Mixture-of-Experts model.
    """

    def __init__(self,
                 units,
                 num_experts,
                 num_tasks,
                 use_expert_bias=True,
                 use_gate_bias=True,
                 expert_activation='relu',
                 gate_activation='softmax',
                 expert_bias_initializer='zeros',
                 gate_bias_initializer='zeros',
                 expert_bias_regularizer=None,
                 gate_bias_regularizer=None,
                 expert_bias_constraint=None,
                 gate_bias_constraint=None,
                 expert_kernel_initializer='VarianceScaling',
                 gate_kernel_initializer='VarianceScaling',
                 expert_kernel_regularizer=None,
                 gate_kernel_regularizer=None,
                 expert_kernel_constraint=None,
                 gate_kernel_constraint=None,
                 activity_regularizer=None,
                 **kwargs):
        """
         Method for instantiating MMoE layer.
        :param units: Number of hidden units
        :param num_experts: Number of experts
        :param num_tasks: Number of tasks
        :param use_expert_bias: Boolean to indicate the usage of bias in the expert weights
        :param use_gate_bias: Boolean to indicate the usage of bias in the gate weights
        :param expert_activation: Activation function of the expert weights
        :param gate_activation: Activation function of the gate weights
        :param expert_bias_initializer: Initializer for the expert bias
        :param gate_bias_initializer: Initializer for the gate bias
        :param expert_bias_regularizer: Regularizer for the expert bias
        :param gate_bias_regularizer: Regularizer for the gate bias
        :param expert_bias_constraint: Constraint for the expert bias
        :param gate_bias_constraint: Constraint for the gate bias
        :param expert_kernel_initializer: Initializer for the expert weights
        :param gate_kernel_initializer: Initializer for the gate weights
        :param expert_kernel_regularizer: Regularizer for the expert weights
        :param gate_kernel_regularizer: Regularizer for the gate weights
        :param expert_kernel_constraint: Constraint for the expert weights
        :param gate_kernel_constraint: Constraint for the gate weights
        :param activity_regularizer: Regularizer for the activity
        :param kwargs: Additional keyword arguments for the Layer class
        """
        # Hidden nodes parameter
        self.units = units
        self.num_experts = num_experts
        self.num_tasks = num_tasks

        # Weight parameter
        self.expert_kernels = None
        self.gate_kernels = None
        self.expert_kernel_initializer = initializers.get(expert_kernel_initializer)
        self.gate_kernel_initializer = initializers.get(gate_kernel_initializer)
        self.expert_kernel_regularizer = regularizers.get(expert_kernel_regularizer)
        self.gate_kernel_regularizer = regularizers.get(gate_kernel_regularizer)
        self.expert_kernel_constraint = constraints.get(expert_kernel_constraint)
        self.gate_kernel_constraint = constraints.get(gate_kernel_constraint)

        # Activation parameter
        self.expert_activation = activations.get(expert_activation)
        self.gate_activation = activations.get(gate_activation)

        # Bias parameter
        self.expert_bias = None
        self.gate_bias = None
        self.use_expert_bias = use_expert_bias
        self.use_gate_bias = use_gate_bias
        self.expert_bias_initializer = initializers.get(expert_bias_initializer)
        self.gate_bias_initializer = initializers.get(gate_bias_initializer)
        self.expert_bias_regularizer = regularizers.get(expert_bias_regularizer)
        self.gate_bias_regularizer = regularizers.get(gate_bias_regularizer)
        self.expert_bias_constraint = constraints.get(expert_bias_constraint)
        self.gate_bias_constraint = constraints.get(gate_bias_constraint)

        # Activity parameter
        self.activity_regularizer = regularizers.get(activity_regularizer)

        # Keras parameter
        self.input_spec = InputSpec(min_ndim=2)
        self.supports_masking = True

        super(MMoE, self).__init__(**kwargs)

    def build(self, input_shape):
        """
        Method for creating the layer weights.
        :param input_shape: Keras tensor (future input to layer)
                            or list/tuple of Keras tensors to reference
                            for weight shape computations
        """
        assert input_shape is not None and len(input_shape) >= 2

        input_dimension = input_shape[-1]

        # Initialize expert weights (number of input features * number of units per expert * number of experts)
        self.expert_kernels = self.add_weight(
            name='expert_kernel',
            shape=(input_dimension, self.units, self.num_experts),
            initializer=self.expert_kernel_initializer,
            regularizer=self.expert_kernel_regularizer,
            constraint=self.expert_kernel_constraint,
        )

        # Initialize expert bias (number of units per expert * number of experts)
        if self.use_expert_bias:
            self.expert_bias = self.add_weight(
                name='expert_bias',
                shape=(self.units, self.num_experts),
                initializer=self.expert_bias_initializer,
                regularizer=self.expert_bias_regularizer,
                constraint=self.expert_bias_constraint,
            )

        # Initialize gate weights (number of input features * number of experts * number of tasks)
        self.gate_kernels = [self.add_weight(
            name='gate_kernel_task_{}'.format(i),
            shape=(input_dimension, self.num_experts),
            initializer=self.gate_kernel_initializer,
            regularizer=self.gate_kernel_regularizer,
            constraint=self.gate_kernel_constraint
        ) for i in range(self.num_tasks)]

        # Initialize gate bias (number of experts * number of tasks)
        if self.use_gate_bias:
            self.gate_bias = [self.add_weight(
                name='gate_bias_task_{}'.format(i),
                shape=(self.num_experts,),
                initializer=self.gate_bias_initializer,
                regularizer=self.gate_bias_regularizer,
                constraint=self.gate_bias_constraint
            ) for i in range(self.num_tasks)]

        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dimension})

        super(MMoE, self).build(input_shape)

    def call(self, inputs, **kwargs):
        """
        Method for the forward function of the layer.
        :param inputs: Input tensor
        :param kwargs: Additional keyword arguments for the base method
        :return: A tensor
        """
        gate_outputs = []
        final_outputs = []

        # f_{i}(x) = activation(W_{i} * x + b), where activation is ReLU according to the paper
        expert_outputs = tf.tensordot(a=inputs, b=self.expert_kernels, axes=1)
        # Add the bias term to the expert weights if necessary
        if self.use_expert_bias:
            expert_outputs = K.bias_add(x=expert_outputs, bias=self.expert_bias)
        expert_outputs = self.expert_activation(expert_outputs)

        # g^{k}(x) = activation(W_{gk} * x + b), where activation is softmax according to the paper
        for index, gate_kernel in enumerate(self.gate_kernels):
            gate_output = K.dot(x=inputs, y=gate_kernel)
            # Add the bias term to the gate weights if necessary
            if self.use_gate_bias:
                gate_output = K.bias_add(x=gate_output, bias=self.gate_bias[index])
            gate_output = self.gate_activation(gate_output)
            gate_outputs.append(gate_output)

        # f^{k}(x) = sum_{i=1}^{n}(g^{k}(x)_{i} * f_{i}(x))
        for gate_output in gate_outputs:
            expanded_gate_output = K.expand_dims(gate_output, axis=1)
            weighted_expert_output = expert_outputs * K.repeat_elements(expanded_gate_output, self.units, axis=1)
            final_outputs.append(K.sum(weighted_expert_output, axis=2))

        return final_outputs

    def compute_output_shape(self, input_shape):
        """
        Method for computing the output shape of the MMoE layer.
        :param input_shape: Shape tuple (tuple of integers)
        :return: List of input shape tuple where the size of the list is equal to the number of tasks
        """
        assert input_shape is not None and len(input_shape) >= 2

        output_shape = list(input_shape)
        output_shape[-1] = self.units
        output_shape = tuple(output_shape)

        return [output_shape for _ in range(self.num_tasks)]

    def get_config(self):
        """
        Method for returning the configuration of the MMoE layer.
        :return: Config dictionary
        """
        config = {
            'units': self.units,
            'num_experts': self.num_experts,
            'num_tasks': self.num_tasks,
            'use_expert_bias': self.use_expert_bias,
            'use_gate_bias': self.use_gate_bias,
            'expert_activation': activations.serialize(self.expert_activation),
            'gate_activation': activations.serialize(self.gate_activation),
            'expert_bias_initializer': initializers.serialize(self.expert_bias_initializer),
            'gate_bias_initializer': initializers.serialize(self.gate_bias_initializer),
            'expert_bias_regularizer': regularizers.serialize(self.expert_bias_regularizer),
            'gate_bias_regularizer': regularizers.serialize(self.gate_bias_regularizer),
            'expert_bias_constraint': constraints.serialize(self.expert_bias_constraint),
            'gate_bias_constraint': constraints.serialize(self.gate_bias_constraint),
            'expert_kernel_initializer': initializers.serialize(self.expert_kernel_initializer),
            'gate_kernel_initializer': initializers.serialize(self.gate_kernel_initializer),
            'expert_kernel_regularizer': regularizers.serialize(self.expert_kernel_regularizer),
            'gate_kernel_regularizer': regularizers.serialize(self.gate_kernel_regularizer),
            'expert_kernel_constraint': constraints.serialize(self.expert_kernel_constraint),
            'gate_kernel_constraint': constraints.serialize(self.gate_kernel_constraint),
            'activity_regularizer': regularizers.serialize(self.activity_regularizer)
        }
        base_config = super(MMoE, self).get_config()

        return dict(list(base_config.items()) + list(config.items()))

In [None]:
SEED = 1

# Fix numpy seed for reproducibility
np.random.seed(SEED)

# Fix random seed for reproducibility
random.seed(SEED)

# Fix TensorFlow graph-level seed for reproducibility
tf.random.set_seed(SEED)

In [None]:
# Simple callback to print out ROC-AUC
class ROCCallback(Callback):
    def __init__(self, training_data, validation_data, test_data):
        self.train_X = training_data[0]
        self.train_Y = training_data[1]
        self.validation_X = validation_data[0]
        self.validation_Y = validation_data[1]
        self.test_X = test_data[0]
        self.test_Y = test_data[1]

    def on_train_begin(self, logs={}):
        return

    def on_train_end(self, logs={}):
        return

    def on_epoch_begin(self, epoch, logs={}):
        return

    def on_epoch_end(self, epoch, logs={}):
        train_prediction = self.model.predict(self.train_X)
        validation_prediction = self.model.predict(self.validation_X)
        test_prediction = self.model.predict(self.test_X)

        # Iterate through each task and output their ROC-AUC across different datasets
        for index, output_name in enumerate(self.model.output_names):
            train_roc_auc = roc_auc_score(self.train_Y[index], train_prediction[index])
            validation_roc_auc = roc_auc_score(self.validation_Y[index], validation_prediction[index])
            test_roc_auc = roc_auc_score(self.test_Y[index], test_prediction[index])
            print(
                'ROC-AUC-{}-Train: {} ROC-AUC-{}-Validation: {} ROC-AUC-{}-Test: {}'.format(
                    output_name, round(train_roc_auc, 4),
                    output_name, round(validation_roc_auc, 4),
                    output_name, round(test_roc_auc, 4)
                )
            )

        return

    def on_batch_begin(self, batch, logs={}):
        return

    def on_batch_end(self, batch, logs={}):
        return

In [None]:
def data_preparation():
    # The column names are from
    # https://www2.1010data.com/documentationcenter/prod/Tutorials/MachineLearningExamples/CensusIncomeDataSet.html
    column_names = ['age', 'class_worker', 'det_ind_code', 'det_occ_code', 'education', 'wage_per_hour', 'hs_college',
                    'marital_stat', 'major_ind_code', 'major_occ_code', 'race', 'hisp_origin', 'sex', 'union_member',
                    'unemp_reason', 'full_or_part_emp', 'capital_gains', 'capital_losses', 'stock_dividends',
                    'tax_filer_stat', 'region_prev_res', 'state_prev_res', 'det_hh_fam_stat', 'det_hh_summ',
                    'instance_weight', 'mig_chg_msa', 'mig_chg_reg', 'mig_move_reg', 'mig_same', 'mig_prev_sunbelt',
                    'num_emp', 'fam_under_18', 'country_father', 'country_mother', 'country_self', 'citizenship',
                    'own_or_self', 'vet_question', 'vet_benefits', 'weeks_worked', 'year', 'income_50k']

    # Load the dataset in Pandas
    train_df = pd.read_csv(
        FILE_PATH + 'census-income.data.gz',
        delimiter=',',
        header=None,
        index_col=None,
        names=column_names
    )
    other_df = pd.read_csv(
        FILE_PATH + 'census-income.test.gz',
        delimiter=',',
        header=None,
        index_col=None,
        names=column_names
    )

    # First group of tasks according to the paper
    label_columns = ['income_50k', 'marital_stat']

    # One-hot encoding categorical columns
    categorical_columns = ['class_worker', 'det_ind_code', 'det_occ_code', 'education', 'hs_college', 'major_ind_code',
                           'major_occ_code', 'race', 'hisp_origin', 'sex', 'union_member', 'unemp_reason',
                           'full_or_part_emp', 'tax_filer_stat', 'region_prev_res', 'state_prev_res', 'det_hh_fam_stat',
                           'det_hh_summ', 'mig_chg_msa', 'mig_chg_reg', 'mig_move_reg', 'mig_same', 'mig_prev_sunbelt',
                           'fam_under_18', 'country_father', 'country_mother', 'country_self', 'citizenship',
                           'vet_question']
    train_raw_labels = train_df[label_columns]
    other_raw_labels = other_df[label_columns]
    transformed_train = pd.get_dummies(train_df.drop(label_columns, axis=1), columns=categorical_columns)
    transformed_other = pd.get_dummies(other_df.drop(label_columns, axis=1), columns=categorical_columns)

    # Filling the missing column in the other set
    transformed_other['det_hh_fam_stat_ Grandchild <18 ever marr not in subfamily'] = 0

    # One-hot encoding categorical labels
    train_income = to_categorical((train_raw_labels.income_50k == ' 50000+.').astype(int), num_classes=2)
    train_marital = to_categorical((train_raw_labels.marital_stat == ' Never married').astype(int), num_classes=2)
    other_income = to_categorical((other_raw_labels.income_50k == ' 50000+.').astype(int), num_classes=2)
    other_marital = to_categorical((other_raw_labels.marital_stat == ' Never married').astype(int), num_classes=2)

    dict_outputs = {
        'income': train_income.shape[1],
        'marital': train_marital.shape[1]
    }
    dict_train_labels = {
        'income': train_income,
        'marital': train_marital
    }
    dict_other_labels = {
        'income': other_income,
        'marital': other_marital
    }
    output_info = [(dict_outputs[key], key) for key in sorted(dict_outputs.keys())]

    # Split the other dataset into 1:1 validation to test according to the paper
    validation_indices = transformed_other.sample(frac=0.5, replace=False, random_state=SEED).index
    test_indices = list(set(transformed_other.index) - set(validation_indices))
    validation_data = transformed_other.iloc[validation_indices]
    validation_label = [dict_other_labels[key][validation_indices] for key in sorted(dict_other_labels.keys())]
    test_data = transformed_other.iloc[test_indices]
    test_label = [dict_other_labels[key][test_indices] for key in sorted(dict_other_labels.keys())]
    train_data = transformed_train
    train_label = [dict_train_labels[key] for key in sorted(dict_train_labels.keys())]

    return train_data, train_label, validation_data, validation_label, test_data, test_label, output_info


In [None]:
# Load the data
train_data, train_label, validation_data, validation_label, test_data, test_label, output_info = data_preparation()
num_features = train_data.shape[1]

In [None]:
print('Training data shape = {}'.format(train_data.shape))
print('Validation data shape = {}'.format(validation_data.shape))
print('Test data shape = {}'.format(test_data.shape))

Training data shape = (199523, 499)
Validation data shape = (49881, 499)
Test data shape = (49881, 499)


In [None]:
# Set up the input layer
input_layer = Input(shape=(num_features,))

In [None]:
# Set up MMoE layer
mmoe_layers = MMoE(
  units=4,
  num_experts=8,
  num_tasks=2
)(input_layer)

output_layers = []

In [None]:
# Build tower layer from MMoE layer
for index, task_layer in enumerate(mmoe_layers):
  tower_layer = Dense(
    units=8,
    activation='relu',
    kernel_initializer=VarianceScaling())(task_layer)
  output_layer = Dense(
    units=output_info[index][0],
    name=output_info[index][1],
    activation='softmax',
    kernel_initializer=VarianceScaling())(tower_layer)
  output_layers.append(output_layer)

In [None]:
# Compile model
model = Model(inputs=[input_layer], outputs=output_layers)
adam_optimizer = Adam()
model.compile(
  loss={'income': 'binary_crossentropy', 'marital': 'binary_crossentropy'},
  optimizer=adam_optimizer,
  metrics=['accuracy']
)

In [None]:
# Print out model architecture summary
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 499)]        0           []                               
                                                                                                  
 m_mo_e (MMoE)                  [(None, 4),          24000       ['input_1[0][0]']                
                                 (None, 4)]                                                       
                                                                                                  
 dense (Dense)                  (None, 8)            40          ['m_mo_e[0][0]']                 
                                                                                                  
 dense_1 (Dense)                (None, 8)            40          ['m_mo_e[0][1]']             

In [None]:
# Train the model
model.fit(
  x=train_data,
  y=train_label,
  validation_data=(validation_data, validation_label),
  callbacks=[
    ROCCallback(
      training_data=(train_data, train_label),
      validation_data=(validation_data, validation_label),
      test_data=(test_data, test_label)
    )
  ],
  epochs=100
)

Epoch 1/100
ROC-AUC-income-Train: 0.9083 ROC-AUC-income-Validation: 0.9089 ROC-AUC-income-Test: 0.9084
ROC-AUC-marital-Train: 0.9671 ROC-AUC-marital-Validation: 0.9537 ROC-AUC-marital-Test: 0.9552
Epoch 2/100
ROC-AUC-income-Train: 0.9172 ROC-AUC-income-Validation: 0.9156 ROC-AUC-income-Test: 0.9156
ROC-AUC-marital-Train: 0.9868 ROC-AUC-marital-Validation: 0.9718 ROC-AUC-marital-Test: 0.9714
Epoch 3/100
ROC-AUC-income-Train: 0.9179 ROC-AUC-income-Validation: 0.9159 ROC-AUC-income-Test: 0.9155
ROC-AUC-marital-Train: 0.9897 ROC-AUC-marital-Validation: 0.9747 ROC-AUC-marital-Test: 0.9738
Epoch 4/100
ROC-AUC-income-Train: 0.914 ROC-AUC-income-Validation: 0.9152 ROC-AUC-income-Test: 0.916
ROC-AUC-marital-Train: 0.9902 ROC-AUC-marital-Validation: 0.9744 ROC-AUC-marital-Test: 0.9739
Epoch 5/100
ROC-AUC-income-Train: 0.9276 ROC-AUC-income-Validation: 0.9254 ROC-AUC-income-Test: 0.9237
ROC-AUC-marital-Train: 0.9895 ROC-AUC-marital-Validation: 0.9711 ROC-AUC-marital-Test: 0.9709
Epoch 6/100
ROC-A

<keras.callbacks.History at 0x7fc72c912d90>

Copyright: [Drawbridge, Inc](https://github.com/drawbridge/keras-mmoe)

# Transfer Learning using head2toe

In [2]:
# !pip install ml-collections

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ml-collections
  Downloading ml_collections-0.1.1.tar.gz (77 kB)
[K     |████████████████████████████████| 77 kB 2.9 MB/s 
Building wheels for collected packages: ml-collections
  Building wheel for ml-collections (setup.py) ... [?25l[?25hdone
  Created wheel for ml-collections: filename=ml_collections-0.1.1-py3-none-any.whl size=94524 sha256=d00a16b741607cc602edce2c64f6992a8cb2c19ad86c9cb0adfeae9deccc9db3
  Stored in directory: /root/.cache/pip/wheels/b7/da/64/33c926a1b10ff19791081b705879561b715a8341a856a3bbd2
Successfully built ml-collections
Installing collected packages: ml-collections
Successfully installed ml-collections-0.1.1


In [3]:
import absl.testing.parameterized as parameterized
import tensorflow as tf
import tensorflow.compat.v2 as tf2
import re
from ml_collections import ConfigDict

In [4]:
def get_config(config_string):
  train_batch_size = 128
  eval_batch_size = 50
  config = ConfigDict({
      'dataset':
          'data.caltech101',
      'eval_mode':
          'valid',
      'is_vtab_5fold_valid':
          True,
      'seed':
          8,
      'max_num_gpus':
          1,
      'learning':
          ConfigDict({
              'optimizer': 'adam',  #  adadelta, adadelta_adaptive, sgd
              'learning_rate': 0.1,
              'grad_clip_value': -1.,  # Applied if positive.
              'l1_regularizer': 0.,
              'l2_regularizer': 0.,
              'group_lrp_regularizer_coef': 0.,
              'group_lrp_regularizer_r': 2.,
              'group_lrp_regularizer_p': 1.,
              'group_lrp_is_embedding': False,
              'training_steps': 500,
              'data_fraction': 1.,
              'cached_eval': True,
              'use_cosine_decay': True,
              'train_batch_size': train_batch_size,
              'eval_batch_size': eval_batch_size,
              'finetune_backbones': False,
              'finetune_lr_multiplier': 1.,
              'finetune_steps_multiplier': 1.,
              # ('', 'unit_vector', 'per_feature')
              'feature_normalization': 'unit_vector',
              # nohidden, random_100, random_1000, trainable_100, trainable_1000
              'output_head_type': 'nohidden',
              'output_head_zeroinit': False,
              'log_freq': 50,
          }),
      'model_name':
          'Finetune'
  })

  config.backbone = get_backbone_config(config_string)
  print(f'Config backbone: {config.backbone}')
  return config

In [5]:
def get_backbone_config(config_string):
  """Gets backbone configuration according to the key given."""
  # Example patterns:
  # imagenetr50, imagenetr50_2x
  pattern = r'^([A-Za-z0-9]+)?_?(\d+)?x?'
  searched = re.search(pattern, config_string)
  if not searched:
    raise ValueError(f'Unrecognized config_string: {config_string}')
  added_backbone, n_repeat = searched.groups()
  print(f'Split config: {added_backbone}, {n_repeat}')
  processed_names = []
  processed_handles = []
  processed_signatures = []
  processed_output_keys = []
  input_sizes = tuple()

  if added_backbone in SINGLE_MODELS:
    n_repeat = int(n_repeat) if n_repeat else 1
    processed_names += [added_backbone] * n_repeat
    handle, size = SINGLE_MODELS[added_backbone]
    if isinstance(handle, list):
      processed_handles = handle * n_repeat
      processed_handles = processed_handles[:n_repeat]
    else:
      processed_handles += [handle] * n_repeat

    if 'vit' in added_backbone:
      processed_signatures += ['serving_default'] * n_repeat
      processed_output_keys += ['pre_logits'] * n_repeat
    else:
      processed_signatures += ['representation'] * n_repeat
      processed_output_keys += ['pre_logits'] * n_repeat
    input_sizes += (size,) * n_repeat
  else:
    raise ValueError(f'added_backbone:{added_backbone} is not recognized')

  return ConfigDict({
      'names': processed_names,
      'handles': processed_handles,
      'signatures': processed_signatures,
      'output_keys': processed_output_keys,
      'input_sizes': input_sizes,
      'include_input': False,
      'additional_features': '',
      'additional_features_pool_size': 0,
      'cls_token_pool': 'normal',
      # If target size is provided, pool size is ignored.
      'additional_features_target_size': 0,
      'additional_features_multi_target_sizes': '',
  })

In [6]:
SINGLE_MODELS = {
    'imagenetr50': ('checkpoints/imagenetr50/', 240),
    'imagenetvitB16': ('checkpoints/imagenetvitB16/', 224)
}

In [7]:
def get_config_fs(config_string):
  config = finetune.get_config(config_string)
  config['model_name'] = 'FinetuneFS'
  new_learning_config = ConfigDict({
      'feature_selection':
          ConfigDict({
              # Following types exist: 'connectivity_mask',
              # 'connectivity_l1', 'random', 'none', 'variance' and
              # 'sklearn_x' where x in
              # [chi2, f_classif, mutual_info_classif, trees]
              'type': 'none',
              'fs_dataset': '',
              'is_overwrite': False,
              'average_over_k': 1,
              'keep_fraction': 0.1,
              'keep_fraction_offset': 0,
              'mean_interpolation_coef': 0.,
              'learning_config_overwrite':
                  ConfigDict({
                      'group_lrp_regularizer_coef': 1e-4,
                      'finetune_backbones': False,
                  })
          }),
  })
  config['learning'].update(new_learning_config)
  print(f'Config backbone: {config.backbone}')
  return config

In [8]:
def _filter_to_k_shot(dataset, num_classes, k):
  """Filters k-shot subset from a dataset."""
  # !!! IMPORTANT: the dataset should *not* be shuffled. !!!
  # Make sure that `shuffle_buffer_size=1` in the call to
  # `dloader.get_tf_data`.

  # Indices of included examples in the k-shot balanced dataset.
  keep_example = []
  # Keep track of the number of examples per class included in
  # `keep_example`.
  class_counts = np.zeros([num_classes], dtype=np.int32)
  for _, label in dataset.as_numpy_iterator():
    # If there are less than `k` examples of class `label` in `example_indices`,
    # keep this example and update the class counts.
    keep = class_counts[label] < k
    keep_example.append(keep)
    if keep:
      class_counts[label] += 1
    # When there are `k` examples for each class included in `keep_example`,
    # stop searching.
    if (class_counts == k).all():
      break

  dataset = tf.data.Dataset.zip((
      tf.data.Dataset.from_tensor_slices(keep_example),
      dataset
  )).filter(lambda keep, _: keep).map(lambda _, example: example).cache()

  return dataset

In [9]:
def create_vtab_dataset_balanced(dataset, image_size, batch_size,
                                 data_fraction):
  """Creates a VTAB input_fn to be used by `tf.Estimator`.

  Deterministic balanced sampling from vtab datasets.

  Args:
    dataset: str, VTAB task to evaluate on.
    image_size: int
    batch_size: int
    data_fraction: float, used to calculate n_shots

  Returns:
    input_fn, input function to be passed to `tf.Estimator`.
  """
  dloader = data_loader.get_dataset_instance(
      {'dataset': dataset, 'data_dir': None})
  num_classes = dloader.get_num_classes()
  n_shots = max(int(1000 * data_fraction / num_classes), 1)
  logging.info('n_shots: %d', n_shots)
  def _dict_to_tuple(batch):
    return batch['image'], batch['label']
  dataset = dloader.get_tf_data(
      split_name='trainval',
      batch_size=batch_size,
      preprocess_fn=functools.partial(
          data_loader.preprocess_fn,
          input_range=(-1.0, 1.0),
          size=image_size),
      epochs=0,
      drop_remainder=False,
      for_eval=False,
      shuffle_buffer_size=1,
      prefetch=1,
      train_examples=None,
  ).unbatch().map(_dict_to_tuple)
  filtered_dataset = _filter_to_k_shot(dataset, num_classes, n_shots)
  return filtered_dataset.shuffle(1000).batch(batch_size)

In [10]:
def create_vtab_dataset(dataset, image_size, batch_size, mode,
                        eval_mode='test', valid_fold_id=4):
  """Creates a VTAB input_fn to be used by `tf.Estimator`.

  Note: There is one episode/VTAB dataset.

  Args:
    dataset: str, VTAB task to evaluate on.
    image_size: int
    batch_size: int
    mode: str in {'train', 'eval'}, whether to build the input function for
      training or evaluation.
    eval_mode: str in {'valid', 'test'}, whether to build the input functions
      for validation or test runs.
    valid_fold_id: int, 0 <= valid_fold_id < 5, valid_fold_id=4 corresponds to
      the default value in VTAB.

  Returns:
    input_fn, input function to be passed to `tf.Estimator`.
  """
  assert 0 <= valid_fold_id < 5
  dloader = data_loader.get_dataset_instance(
      {'dataset': dataset, 'data_dir': None})
  if mode not in ('train', 'eval'):
    raise ValueError("mode should be 'train' or 'eval'")
  is_training = mode == 'train'

  def _dict_to_tuple(batch):
    return batch['image'], batch['label']
  if eval_mode == 'test':
    split_name = 'train800val200' if is_training else 'test'
  elif eval_mode == 'valid':
    val_start, val_end = valid_fold_id * 200, (valid_fold_id + 1) * 200
    if is_training:
      split_name = f'train[:{val_start}]+train[{val_end}:1000]'
    else:
      split_name = f'train[{val_start}:{val_end}]'
    logging.info('Using split_name: %s', split_name)

    if split_name not in dloader._tfds_splits:
      dloader._tfds_splits[split_name] = split_name
      dloader._num_samples_splits[split_name] = 800 if is_training else 200
  else:
    raise ValueError(f'eval_mode: {eval_mode} invalid')

  return dloader.get_tf_data(
      split_name=split_name,
      batch_size=batch_size,
      preprocess_fn=functools.partial(
          data_loader.preprocess_fn,
          input_range=(-1.0, 1.0),
          size=image_size),
      epochs=0,
      drop_remainder=False,
      for_eval=not is_training,
      # Our training data has at most 1000 samples, therefore a shuffle buffer
      # size of 1000 is sufficient.
      shuffle_buffer_size=1000,
      prefetch=1,
      train_examples=None,
  ).map(_dict_to_tuple)


In [11]:
class InputPipelineTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.parameters(
      (84, 2, 'train', 'test'),
      (84, 2, 'eval', 'test'),
      (84, 1000, 'train', 'test'),
      (84, 1000, 'eval', 'test'),
      (240, 2, 'train', 'test'),
      (240, 2, 'train', 'valid'),
      (240, 2, 'eval', 'valid'),
  )
  def test_vtab_pipeline(self, image_size, batch_size, mode, eval_mode):
    data_source = 'data.caltech101'
    dataset = create_vtab_dataset(
        dataset=data_source, mode=mode, image_size=image_size,
        batch_size=batch_size, eval_mode=eval_mode)
    if batch_size <= 1000:
      x, y = next(iter(dataset))
      self.assertAllEqual(x.shape, [batch_size, image_size, image_size, 3])
      self.assertAllEqual(y.shape, [batch_size])
    if batch_size == 1000 and mode == 'train':
      # Full batch.
      self.assertLen(list(iter(dataset)), 1)

In [12]:
class FinetuneTest(parameterized.TestCase, tf2.test.TestCase):

  # @parameterized.named_parameters(('r50', 'imagenetr50'),
  #                                 ('vitb16', 'imagenetvitB16'))
  def test_evaluate(self, model_name):
    """Tests whether the model runs with no-error using dummy inputs."""
    self.config = get_config(model_name)
    self.sur_model = Finetune(self.config)
    dataset = tf2.data.Dataset.from_tensor_slices(
        (tf2.random.uniform([4, 240, 240, 3]),
         tf2.random.uniform([4,], maxval=2, dtype=tf2.int32))).batch(2)
    results = self.sur_model.evaluate(self.config.learning, dataset, dataset)
    print(results)

In [13]:
class FinetuneFSTest(parameterized.TestCase, tf2.test.TestCase):

  @parameterized.named_parameters(('r50', 'imagenetr50'),
                                  ('vitb16', 'imagenetvitB16'))
  def test_evaluate(self, model_name):
    """Tests whether the model runs with no-error using dummy inputs."""
    self.config = get_config_fs(model_name)
    self.sur_model = FinetuneFS(self.config)
    dataset = tf2.data.Dataset.from_tensor_slices(
        (tf2.random.uniform([4, 240, 240, 3]),
         tf2.random.uniform([4,], maxval=2, dtype=tf2.int32))).batch(2)
    results = self.sur_model.evaluate(self.config.learning, dataset, dataset)
    print(results)

In [None]:
ft = FinetuneTest(('r50', 'imagenetr50'), ('vitb16', 'imagenetvitB16'))

copyright: [2022 Head2Toe Authors](https://github.com/google-research/head2toe)

# Reference:

1. [Keras-MMoE](https://github.com/drawbridge/keras-mmoe)
2. [Text classification with Switch Transformer](https://keras.io/examples/nlp/text_classification_with_switch_transformer/)
3. [mtlearn](https://github.com/AmazaspShumik/mtlearn)
4. [The Sparsely Gated Mixture of Experts Layer for PyTorch](https://github.com/davidmrau/mixture-of-experts)
5. [LibMTL](https://github.com/median-research-group/LibMTL)
6. [Head2Toe: Utilizing Intermediate Representations for Better OOD Generalization](https://github.com/google-research/head2toe)
7. [Multi-task learning with Multi-gate Mixture-of-experts](https://towardsdatascience.com/multi-task-learning-with-multi-gate-mixture-of-experts-b46efac3268)
8. [Towards Out-Of-Distribution Generalization: A Survey](https://arxiv.org/abs/2108.13624)
9. [Head2Toe: Utilizing Intermediate Representations for Better Transfer Learning](https://paperswithcode.com/paper/head2toe-utilizing-intermediate-1/review/)
10. [Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts](https://www.youtube.com/watch?v=Dweg47Tswxw&t=12s&ab_channel=KDD2018video)