# Multi-task recommenders

**Learning Objectives**
  1. Training a model which focuses on ratings.
  2. Training a model which focuses on retrieval.
  3. Training a joint model that assigns positive weights to both the rating and retrieval tasks.


## Introduction
In the basic retrieval notebook we built a retrieval system using movie watches as positive interaction signals.

In many applications, however, there are multiple rich sources of feedback to draw upon. For example, an e-commerce site may record user visits to product pages (abundant, but relatively low signal), image clicks, adding to cart, and, finally, purchases. It may even record post-purchase signals such as reviews and returns.

Integrating all these different forms of feedback is critical to building systems that users love to use, and that do not optimize for any one metric at the expense of overall performance.

In addition, building a joint model for multiple tasks may produce better results than building a number of task-specific models. This is especially true where some data is abundant (for example, clicks), and some data is sparse (purchases, returns, manual reviews). In those scenarios, a joint model may be able to use representations learned from the abundant task to improve its predictions on the sparse task via a phenomenon known as [transfer learning](https://en.wikipedia.org/wiki/Transfer_learning). For example, [this paper](https://openreview.net/pdf?id=SJxPVcSonN) shows that a model predicting explicit user ratings from sparse user surveys can be substantially improved by adding an auxiliary task that uses abundant click log data.

In this Jupyter notebook, we are going to build a multi-objective recommender for MovieLens, using both implicit (movie watches) and explicit (ratings) signals.

Each learning objective corresponds to a __#TODO__ in the [student lab notebook](https://github.com/GoogleCloudPlatform/training-data-analyst/blob/master/courses/machine_learning/deepdive2/recommendation_systems/labs/multitask.ipynb). Try to complete that notebook before reviewing this solution notebook.


## Imports


Let's first get our imports out of the way.


In [2]:
# Installing the necessary libraries.
!pip install -q tensorflow-recommenders
!pip install -q --upgrade tensorflow-datasets

**NOTE: Please ignore any incompatibility warnings and errors and re-run the above cell before proceeding.**


In [3]:
# Importing the necessary modules
import os
import pprint
import tempfile

from typing import Dict, Text

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

In [4]:
import tensorflow_recommenders as tfrs

## Preparing the dataset

We're going to use the MovieLens 100K dataset.

In [5]:
ratings = tfds.load('movielens/100k-ratings', split="train")
movies = tfds.load('movielens/100k-movies', split="train")

# Select the basic features.
ratings = ratings.map(lambda x: {
    "movie_title": x["movie_title"],
    "user_id": x["user_id"],
    "user_rating": x["user_rating"],
})
movies = movies.map(lambda x: x["movie_title"])

We then repeat our preparations for building vocabularies and splitting the data into a train and a test set:

In [6]:
# Randomly shuffle data and split between train and test.
tf.random.set_seed(42)
shuffled = ratings.shuffle(100_000, seed=42, reshuffle_each_iteration=False)

train = shuffled.take(80_000)
test = shuffled.skip(80_000).take(20_000)

movie_titles = movies.batch(1_000)
user_ids = ratings.batch(1_000_000).map(lambda x: x["user_id"])

unique_movie_titles = np.unique(np.concatenate(list(movie_titles)))
unique_user_ids = np.unique(np.concatenate(list(user_ids)))

## A multi-task model

There are two critical parts to multi-task recommenders:

1. They optimize for two or more objectives, and so have two or more losses.
2. They share variables between the tasks, allowing for transfer learning.

In this Jupyter notebook, we will define our models as before, but instead of having a single task, we will have two tasks: one that predicts ratings, and one that predicts movie watches.

The user and movie models are as before:

```python
user_model = tf.keras.Sequential([
  tf.keras.layers.experimental.preprocessing.StringLookup(
      vocabulary=unique_user_ids, mask_token=None),
  # We add 1 to account for the unknown token.
  tf.keras.layers.Embedding(len(unique_user_ids) + 1, embedding_dimension)
])

movie_model = tf.keras.Sequential([
  tf.keras.layers.experimental.preprocessing.StringLookup(
      vocabulary=unique_movie_titles, mask_token=None),
  tf.keras.layers.Embedding(len(unique_movie_titles) + 1, embedding_dimension)
])
```
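The `+ 1` in the embedding sizes above is easier to see with a small stand-in for the lookup layer. The following is a minimal pure-Python sketch, not the Keras implementation: it assumes that index 0 is reserved for unknown (out-of-vocabulary) tokens and that known tokens are numbered from 1. The helper names (`build_lookup`, `lookup`) and the example IDs are hypothetical.

```python
from typing import Dict, List


def build_lookup(vocabulary: List[str]) -> Dict[str, int]:
    """Map each known token to an index starting at 1 (0 is kept for unknowns)."""
    return {token: index + 1 for index, token in enumerate(vocabulary)}


def lookup(table: Dict[str, int], token: str) -> int:
    """Return the token's index, or 0 if the token is not in the vocabulary."""
    return table.get(token, 0)


# Hypothetical user IDs, standing in for unique_user_ids.
vocabulary = ["user_a", "user_b", "user_c"]
table = build_lookup(vocabulary)

# "user_z" was never seen, so it falls into the unknown slot.
indices = [lookup(table, token) for token in ["user_b", "user_z"]]

# One embedding row per known token, plus one row for the unknown slot.
num_embedding_rows = len(vocabulary) + 1

print(indices)
print(num_embedding_rows)
```

Under these assumptions every index the lookup can emit is a valid row of the embedding table, which is why the table needs one more row than the vocabulary has entries.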

However, now we will have two tasks. The first is the rating task:

```python
tfrs.tasks.Ranking(
    loss=tf.keras.losses.MeanSquaredError(),
    metrics=[tf.keras.metrics.RootMeanSquaredError()],
)
```

Its goal is to predict the ratings as accurately as possible.
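To make the loss and the metric of the rating task concrete, here is a minimal pure-Python sketch of mean squared error and root mean squared error. It only illustrates the arithmetic and is not the Keras implementation; the function names and the toy ratings are made up for this example.

```python
import math
from typing import Sequence


def mean_squared_error(labels: Sequence[float], predictions: Sequence[float]) -> float:
    """Average of the squared differences between labels and predictions."""
    if len(labels) != len(predictions):
        raise ValueError("labels and predictions must have the same length")
    squared_errors = [(y - p) ** 2 for y, p in zip(labels, predictions)]
    return sum(squared_errors) / len(squared_errors)


def root_mean_squared_error(labels: Sequence[float], predictions: Sequence[float]) -> float:
    """Square root of the mean squared error, in the same units as the ratings."""
    return math.sqrt(mean_squared_error(labels, predictions))


# Toy ratings on a 1-5 scale (hypothetical values).
labels = [4.0, 3.0, 5.0]
predictions = [3.5, 3.0, 4.0]

mse = mean_squared_error(labels, predictions)
rmse = root_mean_squared_error(labels, predictions)
print(f"MSE: {mse:.3f}, RMSE: {rmse:.3f}")
```

Because RMSE is expressed in rating units, an RMSE of about 1 means predictions are off by roughly one star on average in the root-mean-square sense.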

The second is the retrieval task:

```python
tfrs.tasks.Retrieval(
    metrics=tfrs.metrics.FactorizedTopK(
        # The candidates must be movie embeddings, so map titles through movie_model.
        candidates=movies.batch(128).map(movie_model)
    )
)
```

As before, this task's goal is to predict which movies the user will or will not watch.
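The top-K accuracies reported for this task can be understood with a small stand-alone example. The following is a minimal pure-Python sketch of a top-K accuracy computed from query and candidate embeddings. It assumes dot-product scoring and is not the TFRS implementation; the helper names (`dot`, `top_k_hit`, `top_k_accuracy`) and the toy embeddings are hypothetical.

```python
from typing import List, Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Dot product of two equally sized vectors."""
    return sum(x * y for x, y in zip(a, b))


def top_k_hit(query: Sequence[float], candidates: List[Sequence[float]],
              true_index: int, k: int) -> bool:
    """True if the true candidate is among the k highest-scoring candidates."""
    scores = [dot(query, candidate) for candidate in candidates]
    ranked = sorted(range(len(candidates)), key=lambda i: scores[i], reverse=True)
    return true_index in ranked[:k]


def top_k_accuracy(queries: List[Sequence[float]], candidates: List[Sequence[float]],
                   true_indices: List[int], k: int) -> float:
    """Fraction of queries whose true candidate is ranked in the top k."""
    hits = [top_k_hit(q, candidates, t, k) for q, t in zip(queries, true_indices)]
    return sum(hits) / len(hits)


# Toy 2-dimensional embeddings: three movies and two users.
candidates = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
queries = [[2.0, 0.1], [0.1, 2.0]]
true_indices = [0, 1]  # the movie each user actually watched

top_1 = top_k_accuracy(queries, candidates, true_indices, k=1)
top_2 = top_k_accuracy(queries, candidates, true_indices, k=2)
print(top_1, top_2)
```

In this toy setup the third movie scores highest for both users, so neither watched movie is ranked first, but both appear within the top two.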

### Putting it together

We put it all together in a model class.

The new component here is that, since we have two tasks and two losses, we need to decide how important each loss is. We can do this by giving each of the losses a weight, and treating these weights as hyperparameters. If we assign a large loss weight to the rating task, our model is going to focus on predicting ratings (but still use some information from the retrieval task); if we assign a large loss weight to the retrieval task, it will focus on retrieval instead.
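The weighting itself is just a weighted sum of the two per-task losses. As a minimal sketch with made-up loss values (the function name `combine_losses` and the numbers are illustrative only), the three configurations explored in this notebook look like this:

```python
def combine_losses(rating_loss: float, retrieval_loss: float,
                   rating_weight: float, retrieval_weight: float) -> float:
    """Weighted sum of the two task losses."""
    return rating_weight * rating_loss + retrieval_weight * retrieval_loss


# Made-up per-batch loss values, for illustration only.
rating_loss = 1.5
retrieval_loss = 40.0

rating_only = combine_losses(rating_loss, retrieval_loss, 1.0, 0.0)
retrieval_only = combine_losses(rating_loss, retrieval_loss, 0.0, 1.0)
joint = combine_losses(rating_loss, retrieval_loss, 1.0, 1.0)

print(rating_only, retrieval_only, joint)
```

A weight of zero removes that task's loss from the total, so the task contributes no gradient; the metrics for that task are still computed, but nothing is trained to improve them.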

In [7]:
class MovielensModel(tfrs.models.Model):

  def __init__(self, rating_weight: float, retrieval_weight: float) -> None:
    # We take the loss weights in the constructor: this allows us to instantiate
    # several model objects with different loss weights.

    super().__init__()

    embedding_dimension = 32

    # User and movie models.
    self.movie_model: tf.keras.layers.Layer = tf.keras.Sequential([
      tf.keras.layers.experimental.preprocessing.StringLookup(
        vocabulary=unique_movie_titles, mask_token=None),
      tf.keras.layers.Embedding(len(unique_movie_titles) + 1, embedding_dimension)
    ])
    self.user_model: tf.keras.layers.Layer = tf.keras.Sequential([
      tf.keras.layers.experimental.preprocessing.StringLookup(
        vocabulary=unique_user_ids, mask_token=None),
      tf.keras.layers.Embedding(len(unique_user_ids) + 1, embedding_dimension)
    ])

    # A small model to take in user and movie embeddings and predict ratings.
    # We can make this as complicated as we want as long as we output a scalar
    # as our prediction.
    self.rating_model = tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation="relu"),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(1),
    ])

    # The tasks.
    self.rating_task: tf.keras.layers.Layer = tfrs.tasks.Ranking(
        loss=tf.keras.losses.MeanSquaredError(),
        metrics=[tf.keras.metrics.RootMeanSquaredError()],
    )
    self.retrieval_task: tf.keras.layers.Layer = tfrs.tasks.Retrieval(
        metrics=tfrs.metrics.FactorizedTopK(
            candidates=movies.batch(128).map(self.movie_model)
        )
    )

    # The loss weights.
    self.rating_weight = rating_weight
    self.retrieval_weight = retrieval_weight

  def call(self, features: Dict[Text, tf.Tensor]) -> tf.Tensor:
    # We pick out the user features and pass them into the user model.
    user_embeddings = self.user_model(features["user_id"])
    # And pick out the movie features and pass them into the movie model.
    movie_embeddings = self.movie_model(features["movie_title"])
    
    return (
        user_embeddings,
        movie_embeddings,
        # We apply the multi-layered rating model to a concatenation of
        # user and movie embeddings.
        self.rating_model(
            tf.concat([user_embeddings, movie_embeddings], axis=1)
        ),
    )

  def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:

    ratings = features.pop("user_rating")

    user_embeddings, movie_embeddings, rating_predictions = self(features)

    # We compute the loss for each task.
    rating_loss = self.rating_task(
        labels=ratings,
        predictions=rating_predictions,
    )
    retrieval_loss = self.retrieval_task(user_embeddings, movie_embeddings)

    # And combine them using the loss weights.
    return (self.rating_weight * rating_loss
            + self.retrieval_weight * retrieval_loss)

### Rating-specialized model

Depending on the weights we assign, the model will encode a different balance of the tasks. Let's start with a model that only considers ratings.

In [8]:
# Here, configuring the model with losses and metrics.
# TODO 1: Here is your code.
model = MovielensModel(rating_weight=1.0, retrieval_weight=0.0)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))

In [9]:
cached_train = train.shuffle(100_000).batch(8192).cache()
cached_test = test.batch(4096).cache()

In [10]:
# Training the ratings model.
model.fit(cached_train, epochs=3)
metrics = model.evaluate(cached_test, return_dict=True)

print(f"Retrieval top-100 accuracy: {metrics['factorized_top_k/top_100_categorical_accuracy']:.3f}.")
print(f"Ranking RMSE: {metrics['root_mean_squared_error']:.3f}.")

Epoch 1/3
 1/10 [==>...........................] - ETA: 0s - root_mean_squared_error: 3.6922 - factorized_top_k/top_1_categorical_accuracy: 4.8828e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0031 - factorized_top_k/top_10_categorical_accuracy: 0.0060 - factorized_top_k/top_50_categorical_accuracy: 0.0294 - factorized_top_k/top_100_categorical_accuracy: 0.0575 - loss: 13.6320 - regularization_loss: 0.0000e+00 - total_loss: 13.6320
 2/10 [=====>........................] - ETA: 1s - root_mean_squared_error: 3.2709 - factorized_top_k/top_1_categorical_accuracy: 5.4932e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0031 - factorized_top_k/top_10_categorical_accuracy: 0.0058 - factorized_top_k/top_50_categorical_accuracy: 0.0298 - factorized_top_k/top_100_categorical_accuracy: 0.0587 - loss: 10.6989 - regularization_loss: 0.0000e+00 - total_loss: 10.6989
Epoch 2/3
 1/10 [==>...........................] - ETA: 0s - root_mean_squared_error: 1.2016 - factorized_top_k/top_1_categorical_accuracy: 3.6621e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0027 - factorized_top_k/top_10_categorical_accuracy: 0.0066 - factorized_top_k/top_50_categorical_accuracy: 0.0291 - factorized_top_k/top_100_categorical_accuracy: 0.0580 - loss: 1.4439 - regularization_loss: 0.0000e+00 - total_loss: 1.4439
 2/10 [=====>........................] - ETA: 0s - root_mean_squared_error: 1.1857 - factorized_top_k/top_1_categorical_accuracy: 3.6621e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0029 - factorized_top_k/top_10_categorical_accuracy: 0.0060 - factorized_top_k/top_50_categorical_accuracy: 0.0295 - factorized_top_k/top_100_categorical_accuracy: 0.0590 - loss: 1.4059 - regularization_loss: 0.0000e+00 - total_loss: 1.4059
Epoch 3/3
 1/10 [==>...........................] - ETA: 0s - root_mean_squared_error: 1.1343 - factorized_top_k/top_1_categorical_accuracy: 4.8828e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0028 - factorized_top_k/top_10_categorical_accuracy: 0.0062 - factorized_top_k/top_50_categorical_accuracy: 0.0289 - factorized_top_k/top_100_categorical_accuracy: 0.0587 - loss: 1.2867 - regularization_loss: 0.0000e+00 - total_loss: 1.2867
 2/10 [=====>........................] - ETA: 1s - root_mean_squared_error: 1.1330 - factorized_top_k/top_1_categorical_accuracy: 4.2725e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0028 - factorized_top_k/top_10_categorical_accuracy: 0.0058 - factorized_top_k/top_50_categorical_accuracy: 0.0295 - factorized_top_k/top_100_categorical_accuracy: 0.0594 - loss: 1.2838 - regularization_loss: 0.0000e+00 - total_loss: 1.2838
1/5 [=====>........................] - ETA: 0s - root_mean_squared_error: 1.1173 - factorized_top_k/top_1_categorical_accuracy: 7.3242e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0032 - factorized_top_k/top_10_categorical_accuracy: 0.0042 - factorized_top_k/top_50_categorical_accuracy: 0.0271 - factorized_top_k/top_100_categorical_accuracy: 0.0603 - loss: 1.2483 - regularization_loss: 0.0000e+00 - total_loss: 1.2483

Retrieval top-100 accuracy: 0.060.
Ranking RMSE: 1.113.


The model does OK on predicting ratings (with an RMSE of around 1.11), but performs poorly at predicting which movies will be watched: its top-100 accuracy (0.060) is almost 4 times lower than that of a model trained solely to predict watches (0.235, as the next section shows).

### Retrieval-specialized model

Let's now try a model that focuses on retrieval only.

In [11]:
# Here, configuring the model with losses and metrics.
# TODO 2: Here is your code.
model = MovielensModel(rating_weight=0.0, retrieval_weight=1.0)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))

In [12]:
# Training the retrieval model.
model.fit(cached_train, epochs=3)
metrics = model.evaluate(cached_test, return_dict=True)

print(f"Retrieval top-100 accuracy: {metrics['factorized_top_k/top_100_categorical_accuracy']:.3f}.")
print(f"Ranking RMSE: {metrics['root_mean_squared_error']:.3f}.")

Epoch 1/3
 1/10 [==>...........................] - ETA: 0s - root_mean_squared_error: 3.6975 - factorized_top_k/top_1_categorical_accuracy: 3.6621e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0032 - factorized_top_k/top_10_categorical_accuracy: 0.0061 - factorized_top_k/top_50_categorical_accuracy: 0.0288 - factorized_top_k/top_100_categorical_accuracy: 0.0581 - loss: 73817.5703 - regularization_loss: 0.0000e+00 - total_loss: 73817.5703
 2/10 [=====>........................] - ETA: 1s - root_mean_squared_error: 3.7038 - factorized_top_k/top_1_categorical_accuracy: 5.4932e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0046 - factorized_top_k/top_10_categorical_accuracy: 0.0087 - factorized_top_k/top_50_categorical_accuracy: 0.0372 - factorized_top_k/top_100_categorical_accuracy: 0.0731 - loss: 73818.4492 - regularization_loss: 0.0000e+00 - total_loss: 73818.4492
Epoch 2/3
 1/10 [==>...........................] - ETA: 0s - root_mean_squared_error: 3.7272 - factorized_top_k/top_1_categorical_accuracy: 0.0024 - factorized_top_k/top_5_categorical_accuracy: 0.0156 - factorized_top_k/top_10_categorical_accuracy: 0.0354 - factorized_top_k/top_50_categorical_accuracy: 0.1642 - factorized_top_k/top_100_categorical_accuracy: 0.2831 - loss: 71781.0859 - regularization_loss: 0.0000e+00 - total_loss: 71781.0859
 2/10 [=====>........................] - ETA: 1s - root_mean_squared_error: 3.7329 - factorized_top_k/top_1_categorical_accuracy: 0.0026 - factorized_top_k/top_5_categorical_accuracy: 0.0167 - factorized_top_k/top_10_categorical_accuracy: 0.0360 - factorized_top_k/top_50_categorical_accuracy: 0.1637 - factorized_top_k/top_100_categorical_accuracy: 0.2853 - loss: 71630.4102 - regularization_loss: 0.0000e+00 - total_loss: 71630.4102
Epoch 3/3
 1/10 [==>...........................] - ETA: 0s - root_mean_squared_error: 3.7464 - factorized_top_k/top_1_categorical_accuracy: 0.0028 - factorized_top_k/top_5_categorical_accuracy: 0.0216 - factorized_top_k/top_10_categorical_accuracy: 0.0432 - factorized_top_k/top_50_categorical_accuracy: 0.1786 - factorized_top_k/top_100_categorical_accuracy: 0.3026 - loss: 70042.7031 - regularization_loss: 0.0000e+00 - total_loss: 70042.7031
 2/10 [=====>........................] - ETA: 1s - root_mean_squared_error: 3.7518 - factorized_top_k/top_1_categorical_accuracy: 0.0026 - factorized_top_k/top_5_categorical_accuracy: 0.0205 - factorized_top_k/top_10_categorical_accuracy: 0.0430 - factorized_top_k/top_50_categorical_accuracy: 0.1790 - factorized_top_k/top_100_categorical_accuracy: 0.3050 - loss: 69991.5781 - regularization_loss: 0.0000e+00 - total_loss: 69991.5781
1/5 [=====>........................] - ETA: 0s - root_mean_squared_error: 3.7755 - factorized_top_k/top_1_categorical_accuracy: 7.3242e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0115 - factorized_top_k/top_10_categorical_accuracy: 0.0254 - factorized_top_k/top_50_categorical_accuracy: 0.1272 - factorized_top_k/top_100_categorical_accuracy: 0.2383 - loss: 32458.2539 - regularization_loss: 0.0000e+00 - total_loss: 32458.2539

Retrieval top-100 accuracy: 0.235.
Ranking RMSE: 3.773.


We get the opposite result: a model that does well on retrieval (top-100 accuracy of 0.235), but poorly on predicting ratings (RMSE of 3.773).

### Joint model

Let's now train a model that assigns positive weights to both tasks.

In [13]:
# Here, configuring the model with losses and metrics.
# TODO 3: Here is your code.
model = MovielensModel(rating_weight=1.0, retrieval_weight=1.0)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))

In [14]:
# Training the joint model.
model.fit(cached_train, epochs=3)
metrics = model.evaluate(cached_test, return_dict=True)

print(f"Retrieval top-100 accuracy: {metrics['factorized_top_k/top_100_categorical_accuracy']:.3f}.")
print(f"Ranking RMSE: {metrics['root_mean_squared_error']:.3f}.")

Epoch 1/3
 1/10 [==>...........................] - ETA: 0s - root_mean_squared_error: 3.6615 - factorized_top_k/top_1_categorical_accuracy: 2.4414e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0028 - factorized_top_k/top_10_categorical_accuracy: 0.0060 - factorized_top_k/top_50_categorical_accuracy: 0.0320 - factorized_top_k/top_100_categorical_accuracy: 0.0636 - loss: 73830.4453 - regularization_loss: 0.0000e+00 - total_loss: 73830.4453
 2/10 [=====>........................] - ETA: 0s - root_mean_squared_error: 3.1084 - factorized_top_k/top_1_categorical_accuracy: 7.9346e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0049 - factorized_top_k/top_10_categorical_accuracy: 0.0092 - factorized_top_k/top_50_categorical_accuracy: 0.0403 - factorized_top_k/top_100_categorical_accuracy: 0.0735 - loss: 73829.0508 - regularization_loss: 0.0000e+00 - total_loss: 73829.0508
Epoch 2/3
 1/10 [==>...........................] - ETA: 0s - root_mean_squared_error: 1.2498 - factorized_top_k/top_1_categorical_accuracy: 0.0023 - factorized_top_k/top_5_categorical_accuracy: 0.0160 - factorized_top_k/top_10_categorical_accuracy: 0.0364 - factorized_top_k/top_50_categorical_accuracy: 0.1609 - factorized_top_k/top_100_categorical_accuracy: 0.2773 - loss: 71774.2812 - regularization_loss: 0.0000e+00 - total_loss: 71774.2812
 2/10 [=====>........................] - ETA: 0s - root_mean_squared_error: 1.2114 - factorized_top_k/top_1_categorical_accuracy: 0.0023 - factorized_top_k/top_5_categorical_accuracy: 0.0168 - factorized_top_k/top_10_categorical_accuracy: 0.0350 - factorized_top_k/top_50_categorical_accuracy: 0.1556 - factorized_top_k/top_100_categorical_accuracy: 0.2766 - loss: 71638.0938 - regularization_loss: 0.0000e+00 - total_loss: 71638.0938
Epoch 3/3
 1/10 [==>...........................] - ETA: 0s - root_mean_squared_error: 1.1662 - factorized_top_k/top_1_categorical_accuracy: 0.0024 - factorized_top_k/top_5_categorical_accuracy: 0.0187 - factorized_top_k/top_10_categorical_accuracy: 0.0428 - factorized_top_k/top_50_categorical_accuracy: 0.1776 - factorized_top_k/top_100_categorical_accuracy: 0.3024 - loss: 70021.1250 - regularization_loss: 0.0000e+00 - total_loss: 70021.1250
 2/10 [=====>........................] - ETA: 1s - root_mean_squared_error: 1.1380 - factorized_top_k/top_1_categorical_accuracy: 0.0030 - factorized_top_k/top_5_categorical_accuracy: 0.0196 - factorized_top_k/top_10_categorical_accuracy: 0.0419 - factorized_top_k/top_50_categorical_accuracy: 0.1753 - factorized_top_k/top_100_categorical_accuracy: 0.3033 - loss: 69972.3867 - regularization_loss: 0.0000e+00 - total_loss: 69972.3867
1/5 [=====>........................] - ETA: 0s - root_mean_squared_error: 1.1393 - factorized_top_k/top_1_categorical_accuracy: 4.8828e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0083 - factorized_top_k/top_10_categorical_accuracy: 0.0222 - factorized_top_k/top_50_categorical_accuracy: 0.1243 - factorized_top_k/top_100_categorical_accuracy: 0.2341 - loss: 32468.5820 - regularization_loss: 0.0000e+00 - total_loss: 32468.5820

Retrieval top-100 accuracy: 0.234.
Ranking RMSE: 1.131.


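The result is a model that performs roughly as well on both tasks as each specialized model: its top-100 accuracy (0.234) is close to the retrieval-specialized model's 0.235, and its RMSE (1.131) is close to the rating-specialized model's 1.113.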
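When choosing loss weights, it helps to keep the relative scale of the two losses in mind. In the logs above, the first-batch loss is about 13.6 for the rating-only run but about 73,818 for the retrieval-only run. One plausible explanation, stated here as an assumption rather than something this notebook confirms, is that the retrieval loss is an in-batch softmax cross-entropy summed over the batch. Under that assumption, an untrained model that scores all in-batch candidates roughly equally would incur a loss of about `batch_size * ln(batch_size)`, which the hypothetical helper below computes for the training batch size of 8192:

```python
import math


def uniform_in_batch_softmax_loss(batch_size: int) -> float:
    """Summed cross-entropy when every in-batch candidate gets equal probability.

    Each example then has loss ln(batch_size); summing over the batch gives
    batch_size * ln(batch_size). This is an illustrative assumption about the
    retrieval loss, not a statement of the library's implementation.
    """
    return batch_size * math.log(batch_size)


initial_estimate = uniform_in_batch_softmax_loss(8192)
print(f"{initial_estimate:.1f}")
```

The estimate lands close to the logged first-batch retrieval loss, which is consistent with the assumption. If it holds, equal weights of 1.0 do not mean equal influence: the summed retrieval loss is several orders of magnitude larger than the mean squared error, so treating the weights as hyperparameters to tune matters.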

While the results here do not show a clear accuracy benefit from a joint model in this case, multi-task learning is in general an extremely useful tool. We can expect better results when we can transfer knowledge from a data-abundant task (such as clicks) to a closely related data-sparse task (such as purchases).