In [1]:
# Copyright 2023 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ================================

# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.

<img src="https://developer.download.nvidia.com/notebooks/dlsw-notebooks/merlin_models_03-exploring-different-models/nvidia_logo.png" style="width: 90px; float: right;">

# Multi-Task Learning for Ranking

This notebook is created using the latest stable [merlin-tensorflow](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/merlin/containers/merlin-tensorflow/tags) container. 
    
In industry, it is common to need to score the likelihood of different user events on items, e.g., clicking, liking, sharing, commenting, or following the author. Instead of spending computational resources to train and deploy a separate model for each task, Multi-Task Learning (MTL) techniques have become popular for training a single model that is able to predict multiple targets.

In this example, we demonstrate how to build and train ranking models with multiple targets. We introduce the building blocks Merlin Models provide for MTL support and also MTL-specific architectures designed for accuracy improvement of many different tasks: **MMoE**, **CGC** and **PLE**.

This notebook uses synthetic data based on the schema of a dataset released publicly by Tencent in the [TenRec paper](https://arxiv.org/abs/2210.10629). The dataset suits multi-task learning because it provides multiple targets (types of user-item events).

### Learning objectives
- Getting to know the building blocks Merlin provides for MTL
- Training different deep learning-based ranking models with multi-task learning using Merlin Models

In [2]:
import os
import tensorflow as tf

#os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
os.environ["TF_MEMORY_ALLOCATION"] = "0.9"

import merlin.models.tf as mm
from merlin.schema.tags import Tags

2023-01-10 21:34:55.025153: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-01-10 21:34:57.441575: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-01-10 21:34:59.594193: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 29249 MB memory:  -> device: 0, name: Quadro GV100, pci bus id: 0000:15:00.0, compute capability: 7.0
2023-01-10 21:34:59.595309: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created

## Generating data

Here we generate a synthetic dataset based on the schema of the `tenrec-video` dataset. The original dataset was released publicly by Tencent in the [TenRec paper](https://arxiv.org/abs/2210.10629) and suits multi-task learning because it provides multiple targets (types of user-item events).  
Note: To make the synthetic data more realistic, our data generator takes into account the original cardinalities of the categorical features and the dependency of user features on the user id and of item features on the item id.

In [3]:
import os
from merlin.datasets.synthetic import generate_data

NUM_ROWS = os.environ.get("NUM_ROWS", 100_000)

train_ds, valid_ds = generate_data("tenrec-video", int(NUM_ROWS), set_sizes=(0.8, 0.2))
schema = train_ds.schema



By inspecting the column tags in the dataset schema, we can see that there are a number of user features (`user_id`, `gender`, `age`) and item features (`item_id`, `video_category`). There are also four binary classification targets (`click`, `follow`, `like`, and `share`) and one regression target (`watching_times`).

In [4]:
schema

Unnamed: 0,name,tags,dtype,is_list,is_ragged,properties.max_size,properties.freq_threshold,properties.cat_path,properties.embedding_sizes.dimension,properties.embedding_sizes.cardinality,properties.num_buckets,properties.start_index,properties.domain.min,properties.domain.max,properties.domain.name
0,user_id,"(Tags.CATEGORICAL, Tags.USER_ID, Tags.USER, Ta...",int32,False,False,0.0,0.0,.//categories/unique.user_id.parquet,512.0,2633851.0,,1.0,0.0,2633851.0,user_id
1,item_id,"(Tags.ITEM, Tags.CATEGORICAL, Tags.ID, Tags.IT...",int32,False,False,0.0,0.0,.//categories/unique.item_id.parquet,512.0,179280.0,,1.0,0.0,179280.0,item_id
2,video_category,"(Tags.ITEM, Tags.CATEGORICAL)",int32,False,False,0.0,0.0,.//categories/unique.video_category.parquet,16.0,5.0,,1.0,0.0,5.0,video_category
3,gender,"(Tags.CATEGORICAL, Tags.USER)",int32,False,False,0.0,0.0,.//categories/unique.gender.parquet,16.0,5.0,,1.0,0.0,5.0,gender
4,age,"(Tags.CATEGORICAL, Tags.USER)",int32,False,False,0.0,0.0,.//categories/unique.age.parquet,16.0,10.0,,1.0,0.0,10.0,age
5,click,"(Tags.BINARY_CLASSIFICATION, Tags.BINARY, Tags...",int8,False,False,,,,,,,,,,
6,follow,"(Tags.BINARY_CLASSIFICATION, Tags.BINARY, Tags...",int8,False,False,,,,,,,,,,
7,like,"(Tags.BINARY_CLASSIFICATION, Tags.BINARY, Tags...",int8,False,False,,,,,,,,,,
8,share,"(Tags.BINARY_CLASSIFICATION, Tags.BINARY, Tags...",int8,False,False,,,,,,,,,,
9,watching_times,"(Tags.TARGET, Tags.REGRESSION)",int16,False,False,,,,,,,,0.0,5.0,watching_times


In [5]:
# Printing first rows of the generated dataframe
train_ds.to_ddf().head()

Unnamed: 0,user_id,gender,age,item_id,video_category,click,follow,like,share,watching_times
0,26,1,1,11,1,1,1,1,1,1
1,25,1,1,54,1,0,0,0,1,1
2,68,1,1,11,1,0,1,0,0,3
3,45,1,1,26,1,0,1,0,0,2
4,4,1,1,12,1,1,0,0,0,4


## Building and training MTL models

In [6]:
BATCH_SIZE = 4 * 1024

The simplest way to build a model with Merlin Models is to use the `InputBlockV2` and `OutputBlock` building blocks, which infer the input features and target columns from the schema.
The `InputBlockV2` creates the embedding layers for categorical features and concatenates all features. The `OutputBlock` creates a `ModelOutput` head for each target depending on the task type (tagged in the column schema), e.g. `RegressionOutput()` for regression, `BinaryOutput()` for binary classification, `CategoricalOutput` for multi-class classification.  
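
To make the per-target dispatch concrete, here is a toy sketch in plain Python. It is not the Merlin Models implementation: the helper `choose_head` and the string tags are hypothetical stand-ins for the `Tags.*` values that the library reads from the column schema.

```python
# Illustrative only: a toy dispatch from task-type tags to output head names.
# `choose_head` and the string tags below are hypothetical; Merlin Models
# inspects Tags.* values stored in the column schema instead.
def choose_head(tags):
    if "binary_classification" in tags:
        return "BinaryOutput"
    if "regression" in tags:
        return "RegressionOutput"
    if "multi_class_classification" in tags:
        return "CategoricalOutput"
    raise ValueError(f"No known task type in tags: {tags}")


# A hand-written stand-in for the target columns of a schema.
targets = {
    "click": {"target", "binary_classification"},
    "like": {"target", "binary_classification"},
    "watching_times": {"target", "regression"},
}

# One head per target column, chosen from the column's tags.
heads = {name: choose_head(tags) for name, tags in targets.items()}
print(heads)
```

The point of the sketch is only that the head type follows from how each target column is tagged, so no per-task wiring is needed in user code.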

You can inspect below a multi-task learning model created for this dataset with just four lines of code.

In [7]:
model = mm.Model(
    mm.InputBlockV2(schema),
    mm.MLPBlock([32,16]),
    mm.OutputBlock(schema)
)

model

Model(
  (blocks): _TupleWrapper((ParallelBlock(
    (_aggregation): ConcatFeatures()
    (parallel_layers): Dict(
      (categorical): ParallelBlock(
        (parallel_layers): Dict(
          (user_id): EmbeddingTable(
            (features): Dict(
              (user_id): ColumnSchema(name='user_id', tags={<Tags.CATEGORICAL: 'categorical'>, <Tags.USER_ID: 'user_id'>, <Tags.USER: 'user'>, <Tags.ID: 'id'>}, properties={'max_size': 0.0, 'freq_threshold': 0.0, 'cat_path': './/categories/unique.user_id.parquet', 'embedding_sizes': {'dimension': 512.0, 'cardinality': 2633851.0}, 'num_buckets': None, 'start_index': 1.0, 'domain': {'min': 0, 'max': 2633851, 'name': 'user_id'}}, dtype=dtype('int32'), is_list=False, is_ragged=False)
            )
            (table): Embedding()
          )
          (item_id): EmbeddingTable(
            (features): Dict(
              (item_id): ColumnSchema(name='item_id', tags={<Tags.ITEM: 'item'>, <Tags.CATEGORICAL: 'categorical'>, <Tags.ID: 'id'>, <Tags

*Note*: If you want to build a model for just a subset of the target features, you can either remove the unwanted columns from the schema with `schema.without(["like", "follow", "share"])`

or replace `mm.OutputBlock(schema)` with a `ParallelBlock` containing only the desired targets:
```python
mm.ParallelBlock(
  mm.BinaryOutput("click"), mm.RegressionOutput("watching_times")
)
```

### Training and evaluating MTL models

In [8]:
model.compile(optimizer="adam", run_eagerly=False)
model.fit(train_ds, batch_size=BATCH_SIZE)





<keras.callbacks.History at 0x7f594b26f520>

By inspecting the metrics returned by model evaluation, we can see that each target gets default metrics specific to its task type: `precision`, `recall`, `binary_accuracy` and `auc` for binary classification, and `root_mean_squared_error` for regression.  
Each task has its own loss (e.g. `click/binary_output_loss`, `watching_times/regression_output_loss`), and `loss` is the sum of all task losses.

In [9]:
model.evaluate(valid_ds, batch_size=BATCH_SIZE, return_dict=True)



{'loss': 8.325882911682129,
 'click/binary_output_loss': 0.6932552456855774,
 'follow/binary_output_loss': 0.6932500600814819,
 'like/binary_output_loss': 0.6930577158927917,
 'share/binary_output_loss': 0.6981387734413147,
 'watching_times/regression_output_loss': 5.548181056976318,
 'click/binary_output/precision': 0.4777131676673889,
 'click/binary_output/recall': 0.09886693954467773,
 'click/binary_output/binary_accuracy': 0.49674999713897705,
 'click/binary_output/auc': 0.4955274164676666,
 'follow/binary_output/precision': 0.4951406717300415,
 'follow/binary_output/recall': 0.1938326060771942,
 'follow/binary_output/binary_accuracy': 0.49869999289512634,
 'follow/binary_output/auc': 0.49518218636512756,
 'like/binary_output/precision': 0.4950670897960663,
 'like/binary_output/recall': 0.2529999017715454,
 'like/binary_output/binary_accuracy': 0.5016499757766724,
 'like/binary_output/auc': 0.5072351694107056,
 'share/binary_output/precision': 0.0,
 'share/binary_output/recall': 0.
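
As a quick sanity check of the statement that `loss` is the sum of the task losses, the snippet below adds up the five per-task values copied from the evaluation output above and compares the result with the reported total. The tolerance is an arbitrary choice of this sketch to absorb float32 rounding.

```python
import math

# Per-task losses copied from the evaluation output above.
task_losses = {
    "click/binary_output_loss": 0.6932552456855774,
    "follow/binary_output_loss": 0.6932500600814819,
    "like/binary_output_loss": 0.6930577158927917,
    "share/binary_output_loss": 0.6981387734413147,
    "watching_times/regression_output_loss": 5.548181056976318,
}
reported_loss = 8.325882911682129

# With no loss weights set, the total is the plain sum of the task losses.
total = sum(task_losses.values())
matches = math.isclose(total, reported_loss, abs_tol=1e-5)
print(total, matches)
```

Note how the regression loss dominates the total here, which is one motivation for weighting the task losses.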

### Setting loss weights

You can balance the contribution of each task loss to the final `loss` by setting `loss_weights`.

In [10]:
loss_weights = {
    "click/binary_output": 5.0,
    "like/binary_output": 4.0,
    "share/binary_output": 3.0,
    "follow/binary_output": 2.0,
    "watching_times/regression_output": 1.0,
}


model.compile(optimizer="adam", run_eagerly=False, loss_weights=loss_weights)
model.fit(train_ds, batch_size=BATCH_SIZE)



<keras.callbacks.History at 0x7f594b3d3e20>
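
The effect of `loss_weights` is a weighted sum of the task losses. The sketch below applies the weights from the cell above to the per-task losses from the earlier (unweighted) evaluation output. Those loss values are used purely as illustrative inputs; they are not the losses of the model trained with weights.

```python
# Weights from the cell above, keyed by output name.
loss_weights = {
    "click/binary_output": 5.0,
    "like/binary_output": 4.0,
    "share/binary_output": 3.0,
    "follow/binary_output": 2.0,
    "watching_times/regression_output": 1.0,
}

# Illustrative inputs: per-task losses taken from the earlier evaluation
# of the model trained WITHOUT loss weights.
task_losses = {
    "click/binary_output_loss": 0.6932552456855774,
    "follow/binary_output_loss": 0.6932500600814819,
    "like/binary_output_loss": 0.6930577158927917,
    "share/binary_output_loss": 0.6981387734413147,
    "watching_times/regression_output_loss": 5.548181056976318,
}

# Total loss = sum over tasks of weight * task loss.
weighted_total = sum(
    weight * task_losses[f"{name}_loss"] for name, weight in loss_weights.items()
)
unweighted_total = sum(task_losses.values())
print(weighted_total, unweighted_total)
```

With these weights the four classification losses together contribute more to the total than the regression loss, whereas in the unweighted sum the regression loss dominates.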

### Setting task-specific class / sample weights

Keras supports setting `class_weight` and `sample_weight` for **single-task models** in `model.fit()`.  

The `class_weight` allows weighting the classes of a categorical/binary target in the loss, so that model training can pay more attention to samples from an under-represented class.  

The `sample_weight` allows weighting individual data samples so that they count more or less toward the loss during training. If `weighted_metrics` is provided in `model.compile()`, then those metrics will also be weighted by `sample_weight` during training and testing.

Merlin Models provides building blocks for **task-specific class and sample weights** with the `ColumnBasedSampleWeight` block. Here are some examples for different use cases.

#### 1. Setting class weights per task
Here we create an MTL model to predict the `click` and `like` targets. We give negative events (0s) a weight of 1.0 and positive events (1s) a higher weight. As `like` is a rarer (sparser) event than `click`, we use a higher weight for its positive examples.

In [11]:
output_block = mm.ParallelBlock(
  mm.BinaryOutput("click",
                  post=mm.ColumnBasedSampleWeight(
                        binary_class_weights=(1.0, 5.0), 
                  )), 
  mm.BinaryOutput("like",
                  post=mm.ColumnBasedSampleWeight(
                        binary_class_weights=(1.0, 20.0), 
                  ))
)

In [12]:
model = mm.Model(
    mm.InputBlockV2(schema),
    mm.MLPBlock([32,16]),
    output_block
)

model.compile(optimizer="adam", run_eagerly=False)
model.fit(train_ds, batch_size=BATCH_SIZE)



<keras.callbacks.History at 0x7f594b57bf10>
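
Conceptually, binary class weights turn into one weight per sample, picked by the sample's label. The sketch below shows that mapping in plain Python. The helper `class_weights_to_sample_weights` is hypothetical and is not how `ColumnBasedSampleWeight` is implemented; the `(negative, positive)` tuple order follows the description above.

```python
def class_weights_to_sample_weights(labels, binary_class_weights):
    """Map each binary label to its class weight.

    `binary_class_weights` is assumed to be (negative_weight, positive_weight).
    """
    negative_weight, positive_weight = binary_class_weights
    return [positive_weight if y == 1 else negative_weight for y in labels]


# Toy labels for the sparse `like` target.
like_labels = [0, 0, 1, 0, 1]
like_sample_weights = class_weights_to_sample_weights(like_labels, (1.0, 20.0))
print(like_sample_weights)
```

Each positive `like` example then counts twenty times as much as a negative one in the `like` loss, while the `click` head uses its own, smaller ratio.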

#### 2. Using other target / feature as weight per task
Another use case is using a feature or another target as the sample weight. As a didactic example, we use the `age` feature to weight the `click` loss and the `watching_times` target column to weight the `like` loss, so that examples with higher values in those columns contribute more to the final loss.

In [13]:
output_block = mm.ParallelBlock(
  mm.BinaryOutput("click",
                  post=mm.ColumnBasedSampleWeight("age")
                 ), 
  mm.BinaryOutput("like",
                  post=mm.ColumnBasedSampleWeight("watching_times")
                 )
)

In [14]:
model = mm.Model(
    mm.InputBlockV2(schema),
    mm.MLPBlock([32,16]),
    output_block
)

model.compile(optimizer="adam", run_eagerly=False)
model.fit(train_ds, batch_size=BATCH_SIZE)



<keras.callbacks.History at 0x7f59486d4910>

#### 3. Using another binary target as sample space
In some cases, a target is conditioned on another binary target. For example, there might be an event dependency in the system the user is interacting with, so that the user can only `like` or `share` if a `click` event happened first. As the more specific events are usually much less frequent than `click`, they are sparser and thus suffer more from class imbalance during training. In such cases, since a positive event (e.g., `like=1`) is only possible if `click=1`, we can use `click=1` as the sample space for training `like`, i.e., a sample is only considered for the `like` loss if `click=1`. Here is how you can set such a sample space dependency among targets.  

In [15]:
output_block = mm.ParallelBlock(
  mm.BinaryOutput("click"), 
  mm.BinaryOutput("like", post=mm.ColumnBasedSampleWeight("click"))
)

In such cases you might want to compute metrics for `like` considering only its sample space, rather than the entire space. The **`weighted_metrics`** can be used for that, as regular metrics are not influenced by sample weights.  
We also demonstrate below how to override the default **`metrics`** per task. Metrics can be either Keras-like metrics or string aliases supported by Merlin Models (e.g., "auc", "precision", "recall", "binary_accuracy", "rmse", "mse").

In [16]:
model = mm.Model(
    mm.InputBlockV2(schema),
    mm.MLPBlock([32,16]),
    output_block
)

metrics = {
    "click/binary_output": [tf.keras.metrics.AUC(name="auc", num_thresholds=200)],
    "like/binary_output": ["auc"],
}

weighted_metrics = {
    "click/binary_output": ["auc"],
    "like/binary_output": ["auc"],
}

model.compile(optimizer="adam", run_eagerly=False, metrics=metrics, weighted_metrics=weighted_metrics)
model.fit(train_ds, batch_size=BATCH_SIZE)



<keras.callbacks.History at 0x7f59481e9e20>

Notice that when `weighted_metrics` are set, we get the specified metrics prefixed by `weighted_`. The regular metric for `like` (`auc`) differs from the weighted one (`weighted_auc`) because the latter is affected by sample weights, i.e., computed only for samples where `click=1`.

In [17]:
model.evaluate(valid_ds, batch_size=BATCH_SIZE, return_dict=True)



{'loss': 1.038915991783142,
 'click/binary_output_loss': 0.6932337880134583,
 'like/binary_output_loss': 0.34568217396736145,
 'click/binary_output/auc': 0.49904555082321167,
 'click/binary_output/weighted_auc': 0.49904555082321167,
 'like/binary_output/auc': 0.5014902949333191,
 'like/binary_output/weighted_auc': 0.5000181198120117,
 'regularization_loss': 0.0,
 'loss_batch': 1.0394577980041504}
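
To illustrate what using `click` as the sample space means, the following plain-Python sketch computes a binary cross-entropy for `like` twice: once over all samples and once with `click` as the sample weight, so that rows with `click=0` drop out. The toy data and helper names are hypothetical, and normalizing by the sum of the weights is an assumption of this sketch; frameworks may normalize weighted losses differently.

```python
import math


def binary_cross_entropy(y_true, y_pred, eps=1e-7):
    # Clip the prediction to avoid log(0).
    p = min(max(y_pred, eps), 1.0 - eps)
    return -(y_true * math.log(p) + (1 - y_true) * math.log(1.0 - p))


def weighted_mean_loss(y_true, y_pred, sample_weight):
    # Assumption of this sketch: normalize by the sum of the weights.
    total = sum(
        w * binary_cross_entropy(y, p)
        for y, p, w in zip(y_true, y_pred, sample_weight)
    )
    weight_sum = sum(sample_weight)
    return total / weight_sum if weight_sum > 0 else 0.0


# Toy batch: `like` can only be 1 where `click` is 1.
click = [1, 0, 1, 0]
like = [1, 0, 0, 0]
like_pred = [0.8, 0.9, 0.3, 0.9]

# Rows with click=0 get weight 0 and are ignored by the `like` loss.
loss_in_sample_space = weighted_mean_loss(like, like_pred, sample_weight=click)
loss_all_rows = weighted_mean_loss(like, like_pred, sample_weight=[1, 1, 1, 1])
print(loss_in_sample_space, loss_all_rows)
```

In this toy batch the two poorly predicted rows both have `click=0`, so they affect only the unweighted value. The same reasoning explains why a weighted metric and its regular counterpart can differ.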

## Multi-task learning architectures

In this section we describe different architectures for multi-task learning, which are summarized in the following illustration. The blue shapes are shared by all tasks, and the other colored shapes are task-specific. We explain each of those architectures in the next sub-sections.

<img src="../images/mtl_architectures.png"  width="90%">

Image adapted from: [Progressive Layered Extraction (PLE): A Novel Multi-Task
Learning (MTL) Model for Personalized Recommendations](https://dl.acm.org/doi/10.1145/3383313.3412236)

### Hard parameter sharing

The examples above used **hard parameter sharing**, where all tasks share the MLP layers at the bottom, and each task has its own single-layer tower that projects the shared-bottom output to a single neuron per task (for binary classification / regression tasks).  
We can specify more powerful task towers, so that tasks have more freedom to learn different things, with either of the following examples:

In [18]:
output_block = mm.OutputBlock(schema, task_blocks=mm.MLPBlock([32]))

or...

In [19]:
output_block = mm.ParallelBlock(
  mm.BinaryOutput("click", pre=mm.MLPBlock([64])), 
  mm.BinaryOutput("like",  pre=mm.MLPBlock([32]))
)

In [20]:
model = mm.Model(
    mm.InputBlockV2(schema),
    mm.MLPBlock([32,16]),
    output_block
)
model.compile(optimizer="adam", run_eagerly=False)
model.fit(train_ds, batch_size=BATCH_SIZE)



<keras.callbacks.History at 0x7f5941ac81c0>

### MMoE architecture

The [**Multi-gate Mixture-of-Experts (MMoE)**](https://dl.acm.org/doi/pdf/10.1145/3219819.3220007) architecture was introduced in 2018 and is one of the most popular models for multi-task learning on tabular data. It builds on the earlier one-gate **Mixture of Experts (MoE)**, which proposed having different sub-networks (experts) project the inputs independently and then having a gate compute a weighted average of the expert outputs to form a shared representation used by all tasks. The MMoE architecture goes a step further and gives each task an independent gate, so that each task can choose how best to combine the expert outputs. You can find more details in the [MMoE paper](https://dl.acm.org/doi/pdf/10.1145/3219819.3220007).

The MMoE architecture can be created for your dataset with just a few lines of code!

In [21]:
inputs = mm.InputBlockV2(schema)
output_block = mm.OutputBlock(schema, task_blocks=mm.MLPBlock([16]))
mmoe = mm.MMOEBlock(
    output_block,
    expert_block=mm.MLPBlock([16]),
    num_experts=4,
    gate_block=mm.MLPBlock([16]),
)
model = mm.Model(inputs, mmoe, output_block)
print(model)

Model(
  (blocks): _TupleWrapper((ParallelBlock(
    (_aggregation): ConcatFeatures()
    (parallel_layers): Dict(
      (categorical): ParallelBlock(
        (parallel_layers): Dict(
          (user_id): EmbeddingTable(
            (features): Dict(
              (user_id): ColumnSchema(name='user_id', tags={<Tags.CATEGORICAL: 'categorical'>, <Tags.USER_ID: 'user_id'>, <Tags.USER: 'user'>, <Tags.ID: 'id'>}, properties={'max_size': 0.0, 'freq_threshold': 0.0, 'cat_path': './/categories/unique.user_id.parquet', 'embedding_sizes': {'dimension': 512.0, 'cardinality': 2633851.0}, 'num_buckets': None, 'start_index': 1.0, 'domain': {'min': 0, 'max': 2633851, 'name': 'user_id'}}, dtype=dtype('int32'), is_list=False, is_ragged=False)
            )
            (table): Embedding()
          )
          (item_id): EmbeddingTable(
            (features): Dict(
              (item_id): ColumnSchema(name='item_id', tags={<Tags.ITEM: 'item'>, <Tags.CATEGORICAL: 'categorical'>, <Tags.ID: 'id'>, <Tags

In [22]:
model.compile(optimizer="adam", run_eagerly=False)
model.fit(train_ds, batch_size=BATCH_SIZE)



<keras.callbacks.History at 0x7f59414de040>
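
The per-task gating idea behind MMoE can be sketched in a few lines of plain Python. This is a simplified illustration, not the `MMOEBlock` implementation: the expert outputs and gate logits are hard-coded toy values, whereas in a real model both are produced by learned layers applied to the inputs.

```python
import math


def softmax(logits):
    # Subtract the max for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mix_experts(expert_outputs, gate_logits):
    """Weighted average of expert outputs using softmax gate weights."""
    weights = softmax(gate_logits)
    dim = len(expert_outputs[0])
    return [
        sum(w * out[d] for w, out in zip(weights, expert_outputs))
        for d in range(dim)
    ]


# Toy outputs of three experts for one example (2-dimensional each).
expert_outputs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# One gate per task: each task weights the same experts differently.
gate_logits_per_task = {
    "click": [2.0, 0.0, 0.0],
    "like": [0.0, 2.0, 0.0],
}

tower_inputs = {
    task: mix_experts(expert_outputs, logits)
    for task, logits in gate_logits_per_task.items()
}
print(tower_inputs)
```

Both tasks read from the same experts, but the `click` gate favors the first expert while the `like` gate favors the second, so each task tower receives a different representation.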

### CGC and PLE architectures

The **PLE** architecture was introduced in 2020 in this [paper](https://dl.acm.org/doi/10.1145/3383313.3412236). The authors observed that architectures like **MMoE** exhibit a "seesaw" phenomenon, where improving the accuracy of one task hurts the accuracy of other tasks.  
So instead of having all tasks share all the experts, they proposed allowing for both task-specific experts and shared experts, in an architecture they named **Customized Gate Control (CGC) Model**, for which we provide a building block.   
Notice that `CGCBlock` has separate arguments for `num_task_experts` and `num_shared_experts`.

In [23]:
inputs = mm.InputBlockV2(schema)
output_block = mm.OutputBlock(schema, task_blocks=mm.MLPBlock([16]))

cgc = mm.CGCBlock(
    output_block,
    expert_block=mm.MLPBlock([16]),
    num_task_experts=2,
    num_shared_experts=3,
    schema=schema,
)
model = mm.Model(inputs, cgc, output_block)

In [24]:
model.compile(optimizer="adam", run_eagerly=False)
model.fit(train_ds, batch_size=BATCH_SIZE)



<keras.callbacks.History at 0x7f59409ba310>
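
To visualize which experts each task gate can draw from in CGC, here is a small bookkeeping sketch in plain Python. It follows the paper's description, where each task gate mixes that task's own experts with the shared ones. The helper `cgc_expert_routing` and the expert names are hypothetical, and treating `num_task_experts` as a per-task count is an assumption of this sketch.

```python
def cgc_expert_routing(task_names, num_task_experts, num_shared_experts):
    """List, for each task, the experts its gate would combine."""
    shared = [f"shared_expert_{i}" for i in range(num_shared_experts)]
    routing = {}
    for task in task_names:
        # Task-specific experts are visible only to their own task's gate.
        own = [f"{task}_expert_{i}" for i in range(num_task_experts)]
        routing[task] = own + shared
    return routing


tasks = ["click", "follow", "like", "share", "watching_times"]
routing = cgc_expert_routing(tasks, num_task_experts=2, num_shared_experts=3)

experts_per_gate = {task: len(experts) for task, experts in routing.items()}
total_experts = len({e for experts in routing.values() for e in experts})
print(experts_per_gate, total_experts)
```

Under these assumptions, each gate mixes five experts (two of its own plus three shared), and no task can read another task's private experts, which is the mechanism intended to reduce the seesaw effect.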

Furthermore, the [paper](https://dl.acm.org/doi/10.1145/3383313.3412236) authors proposed stacking multiple **CGC** models on top of each other to form a multi-level MTL model, which they called **Progressive Layered Extraction (PLE)**. The `PLEBlock` introduces the `num_layers` argument, which controls the number of levels.   

You can see how easy it is to use such a state-of-the-art MTL model with just a few lines of code with Merlin Models!

In [25]:
inputs = mm.InputBlockV2(schema)
output_block = mm.OutputBlock(schema, task_blocks=mm.MLPBlock([16]))

ple = mm.PLEBlock(
    num_layers=2,
    outputs=output_block,
    expert_block=mm.MLPBlock([16]),
    num_task_experts=2,
    num_shared_experts=1,
)
model = mm.Model(inputs, ple, output_block)

In [26]:
model.compile(optimizer="adam", run_eagerly=False)
model.fit(train_ds, batch_size=BATCH_SIZE)



<keras.callbacks.History at 0x7f5927bce7f0>

In [27]:
metrics_results = model.evaluate(valid_ds, batch_size=BATCH_SIZE, return_dict=True)
metrics_results



{'loss': 8.085714340209961,
 'click/binary_output_loss': 0.693152666091919,
 'follow/binary_output_loss': 0.693204402923584,
 'like/binary_output_loss': 0.6931298971176147,
 'share/binary_output_loss': 0.6931481957435608,
 'watching_times/regression_output_loss': 5.313078880310059,
 'click/binary_output/precision': 0.5066480040550232,
 'click/binary_output/recall': 0.0725960060954094,
 'click/binary_output/binary_accuracy': 0.5023000240325928,
 'click/binary_output/auc': 0.49784404039382935,
 'follow/binary_output/precision': 0.49932488799095154,
 'follow/binary_output/recall': 0.9996996521949768,
 'follow/binary_output/binary_accuracy': 0.49924999475479126,
 'follow/binary_output/auc': 0.495095431804657,
 'like/binary_output/precision': 0.49295252561569214,
 'like/binary_output/recall': 0.13401229679584503,
 'like/binary_output/binary_accuracy': 0.5022500157356262,
 'like/binary_output/auc': 0.5016313195228577,
 'share/binary_output/precision': 0.500723659992218,
 'share/binary_output