In [8]:
from typing import Dict, Text

import numpy as np
import tensorflow as tf

import tensorflow_datasets as tfds
import tensorflow_recommenders as tfrs

In [17]:
# Ratings data.
ratings = tfds.load('movielens/100k-ratings', split="train")
# Features of all the available movies.
movies = tfds.load('movielens/100k-movies', split="train")

# Select the basic features.
ratings = ratings.map(lambda x: {
    "movie_title": x["movie_title"],
    "user_id": x["user_id"]
})
movies = movies.map(lambda x: x["movie_title"])

In [25]:
np.array(list(movies.as_numpy_iterator())).shape

(1682,)

In [6]:
list(ratings.map(lambda x: x["user_id"]).take(1))

[<tf.Tensor: shape=(), dtype=string, numpy=b'138'>]

In [3]:
user_ids_vocabulary = tf.keras.layers.StringLookup(mask_token=None)
user_ids_vocabulary.adapt(ratings.map(lambda x: x["user_id"]))

movie_titles_vocabulary = tf.keras.layers.StringLookup(mask_token=None)
movie_titles_vocabulary.adapt(movies)

In [4]:
class MovieLensModel(tfrs.Model):
  # We derive from a custom base class to help reduce boilerplate. Under the hood,
  # these are still plain Keras Models.

  def __init__(
      self,
      user_model: tf.keras.Model,
      movie_model: tf.keras.Model,
      task: tfrs.tasks.Retrieval):
    super().__init__()

    # Set up user and movie representations.
    self.user_model = user_model
    self.movie_model = movie_model

    # Set up a retrieval task.
    self.task = task

  def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:
    # Define how the loss is computed.

    user_embeddings = self.user_model(features["user_id"])
    movie_embeddings = self.movie_model(features["movie_title"])

    return self.task(user_embeddings, movie_embeddings)

In [5]:
# Define user and movie models.
user_model = tf.keras.Sequential([
    user_ids_vocabulary,
    tf.keras.layers.Embedding(user_ids_vocabulary.vocabulary_size(), 64)
])
movie_model = tf.keras.Sequential([
    movie_titles_vocabulary,
    tf.keras.layers.Embedding(movie_titles_vocabulary.vocabulary_size(), 64)
])

# Define your objectives.
task = tfrs.tasks.Retrieval(metrics=tfrs.metrics.FactorizedTopK(
    movies.batch(128).map(movie_model)
  )
)

In [6]:
# Create a retrieval model.
model = MovieLensModel(user_model, movie_model, task)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.5))

# Train for 3 epochs.
model.fit(ratings.batch(4096), epochs=3)

# Use brute-force search to set up retrieval using the trained representations.
index = tfrs.layers.factorized_top_k.BruteForce(model.user_model)
index.index_from_dataset(
    movies.batch(100).map(lambda title: (title, model.movie_model(title))))

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow_recommenders.layers.factorized_top_k.BruteForce at 0x2d724557790>

In [14]:
# Get some recommendations.
_, titles = index(np.array(["2"]), k=200)
print(f"Top 3 recommendations for user 2: {titles[0, :3]}")

Top 3 recommendations for user 2: [b'3 Ninjas: High Noon At Mega Mountain (1998)' b'Promesse, La (1996)'
 b'Once Upon a Time... When We Were Colored (1995)']


In [15]:
titles

<tf.Tensor: shape=(1, 200), dtype=string, numpy=
array([[b'3 Ninjas: High Noon At Mega Mountain (1998)',
        b'Promesse, La (1996)',
        b'Once Upon a Time... When We Were Colored (1995)',
        b'For the Moment (1994)', b"Antonia's Line (1995)",
        b"Marvin's Room (1996)", b'Shall We Dance? (1996)',
        b'FairyTale: A True Story (1997)',
        b'Cry, the Beloved Country (1995)',
        b'Sense and Sensibility (1995)', b"Boy's Life 2 (1997)",
        b'Postman, The (1997)', b'Deceiver (1997)', b'Deceiver (1997)',
        b'Family Thing, A (1996)', b'Ponette (1996)',
        b'Guantanamera (1994)', b'Mighty Aphrodite (1995)',
        b'Paradise Road (1997)', b'Secrets & Lies (1996)',
        b'Before and After (1996)', b'Bed of Roses (1996)',
        b'Kolya (1996)', b'To Gillian on Her 37th Birthday (1996)',
        b"Widows' Peak (1994)", b'Breakdown (1997)',
        b'Postino, Il (1994)', b'River Wild, The (1994)',
        b'Midnight in the Garden of Good and Ev