In [1]:
import os
import pprint
import tempfile

from typing import Dict, Text

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_recommenders as tfrs

In [2]:
train_filename = "../data/samples/train_transactions_rich.tfrecord"
train = tf.data.TFRecordDataset(train_filename)

test_filename = "../data/samples/test_transactions_rich.tfrecord"
test = tf.data.TFRecordDataset(test_filename)

feature_description = {
    'context_item_id': tf.io.FixedLenFeature([10], tf.int64, default_value=np.repeat(0, 10)),    
    'context_item_price': tf.io.FixedLenFeature([10], tf.float32, default_value=np.repeat(0, 10)),
    'context_item_discount': tf.io.FixedLenFeature([10], tf.float32, default_value=np.repeat(0, 10)),
    'context_item_description': tf.io.FixedLenFeature([10], tf.string, default_value=np.repeat('None', 10)), 
    'label_item_id': tf.io.FixedLenFeature([1], tf.int64, default_value=0),
}

def _parse_function(example_proto):
  return tf.io.parse_single_example(example_proto, feature_description)
def _map_function(x):
  return {
    "context_item_id": tf.strings.as_string(x["context_item_id"]),
    "context_item_description": x["context_item_description"],
    "context_item_price": x["context_item_price"],
    "context_item_discount": x["context_item_discount"],
    "label_item_id": tf.strings.as_string(x["label_item_id"])}
    
train_ds = train.map(_parse_function).map(_map_function)
test_ds = test.map(_parse_function).map(_map_function)
cached_train = train_ds.shuffle(10_000).batch(12800).cache()
cached_test = test_ds.batch(2560).cache()

In [3]:
items_filename = "../data/samples/items.tfrecord"
items_tf = tf.data.TFRecordDataset(items_filename)
item_feature_description = {
    'item_id': tf.io.FixedLenFeature([1], tf.int64, default_value=0),
    'item_fullprice' : tf.io.FixedLenFeature([1], tf.float32, default_value=0),
    'item_description': tf.io.FixedLenFeature([1], tf.string, default_value='None')}
def item_parse_function(example_proto):
  return tf.io.parse_single_example(example_proto, item_feature_description)

items_ds = items_tf.map(item_parse_function).map(lambda x: {
    "item_id": tf.strings.as_string(x["item_id"]),
    "item_fullprice": x["item_fullprice"],
})
item_ids = items_ds.map(lambda x: x["item_id"]).batch(1_000)
unique_item_ids = np.unique(np.concatenate(list(item_ids)))
item_prices = np.concatenate(list(items_ds.map(lambda x: x["item_fullprice"]).batch(1000)))

In [4]:
class ItemEmbeddingModel(tf.keras.Model):
  embedding_dimension = 32
  def __init__(self):
    super().__init__()

    self.item_embedding = tf.keras.Sequential([
      tf.keras.layers.StringLookup(vocabulary=unique_item_ids, mask_token=None, name='item_id_string_lookup'),
      tf.keras.layers.Embedding(len(unique_item_ids) + 1, self.embedding_dimension, name='item_id_embedding'),
      tf.keras.layers.GRU(self.embedding_dimension, name='item_id_rnn'),
    ], name='item_embedding')

    self.price_normalization = tf.keras.layers.Normalization(axis=-1)

    self.price_embedding = tf.keras.Sequential([
      self.price_normalization,
      tf.keras.layers.Embedding(len(item_prices)+1,output_dim=self.embedding_dimension, mask_zero=True),
      tf.keras.layers.GRU(self.embedding_dimension, name='item_price_rnn')
    ], name='price_embedding')

    self.price_normalization.adapt(item_prices)

  def call(self, features):
    return tf.concat([
        self.item_embedding(features["context_item_id"]),
        self.price_embedding(features["context_item_price"]),
    ], axis=1)

class DeepLayerModel(tf.keras.Model):
  """Model for encoding user queries."""

  def __init__(self, layer_sizes, embedding_model):
    """Model for encoding user queries.

    Args:
      layer_sizes:
        A list of integers where the i-th entry represents the number of units
        the i-th layer contains.
    """
    super().__init__()

    # We first use the user model for generating embeddings.
    self.embedding_model = embedding_model

    # Then construct the layers.
    self.dense_layers = tf.keras.Sequential()

    # Use the ReLU activation for all but the last layer.
    for layer_size in layer_sizes[:-1]:
      self.dense_layers.add(tf.keras.layers.Dense(layer_size, activation="relu"))

    # No activation for the last layer.
    for layer_size in layer_sizes[-1:]:
      self.dense_layers.add(tf.keras.layers.Dense(layer_size))

  def call(self, inputs):
    feature_embedding = self.embedding_model(inputs)
    return self.dense_layers(feature_embedding)

In [5]:
class RetrievalModel(tfrs.Model):
    def __init__(self):
        super().__init__()
        self._query_model = DeepLayerModel([128,64,32], ItemEmbeddingModel())
        self._candidate_model = tf.keras.Sequential([
            tf.keras.layers.StringLookup(vocabulary=unique_item_ids, mask_token=None, name='candidate_itemid_lookup'),
            tf.keras.layers.Embedding(len(unique_item_ids) + 1, 32, name='candidate_embedding_lookup'),
            ], name='candidate_model')
        metrics = tfrs.metrics.FactorizedTopK(candidates=items_tf.batch(128).map(self._candidate_model))
        self._task = tfrs.tasks.Retrieval(metrics=metrics)

    def compute_loss(self, features, training=False):
        item_history = {
            "context_item_id": features["context_item_id"],
            "context_item_price": features["context_item_price"]}   
        next_item_label = features["label_item_id"]

        query_embedding = self._query_model(item_history)       
        candidate_embedding = self._candidate_model(next_item_label)

        return self._task(query_embedding, candidate_embedding, compute_metrics=not training)

In [6]:
model = RetrievalModel()
model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1))

In [7]:
model.fit(cached_train, epochs=9)

Epoch 1/24
Epoch 2/24
Epoch 3/24
Epoch 4/24
Epoch 5/24
Epoch 6/24
Epoch 7/24
Epoch 8/24
Epoch 9/24
Epoch 10/24
Epoch 11/24
Epoch 12/24
Epoch 13/24
Epoch 14/24
Epoch 15/24
Epoch 16/24
Epoch 17/24
Epoch 18/24
Epoch 19/24
Epoch 20/24
Epoch 21/24
Epoch 22/24
Epoch 23/24
Epoch 24/24


<keras.callbacks.History at 0x180a0982a90>

In [8]:
model.evaluate(cached_test, return_dict=True)

     87/Unknown - 182s 2s/step - factorized_top_k/top_1_categorical_accuracy: 0.9119 - factorized_top_k/top_5_categorical_accuracy: 0.9119 - factorized_top_k/top_10_categorical_accuracy: 0.9119 - factorized_top_k/top_50_categorical_accuracy: 0.9119 - factorized_top_k/top_100_categorical_accuracy: 0.9119 - loss: 16907.4118 - regularization_loss: 0.0000e+00 - total_loss: 16907.4118

In [27]:
# Create a model that takes in raw query features, and
index = tfrs.layers.factorized_top_k.BruteForce(model._query_model)
# recommends movies out of the entire movies dataset.
items = items_ds.map(lambda x: x["item_id"])
index.index_from_dataset(
  tf.data.Dataset.zip((items.batch(100), items.batch(100).map(model._candidate_model)))
)
feature0 = ['139716','35287','142953','132041','','','','','','']
feature1 = [17.99, 14.99, 4.99, 14.50, 0.,0.,0.,0.,0.,0.]
features_dataset = { 'context_item_id' : tf.constant(feature0,shape=(1,10,1)), 'context_item_price' : tf.constant(feature1,shape=(1,10,1))}

_, predicted = index(features_dataset)
# tf.constant(['139716','35287','142953','132041','','','','','','',],shape=(1,10,1))
print(predicted)


array([[[b'139716'],
        [b'35287'],
        [b'142953'],
        [b'132041'],
        [b''],
        [b''],
        [b''],
        [b''],
        [b''],
        [b'']]], dtype=object)>, 'context_item_price': <tf.Tensor: shape=(1, 10, 1), dtype=float32, numpy=
array([[[17.99],
        [14.99],
        [ 4.99],
        [14.5 ],
        [ 0.  ],
        [ 0.  ],
        [ 0.  ],
        [ 0.  ],
        [ 0.  ],
        [ 0.  ]]], dtype=float32)>}. Consider rewriting this model with the Functional API.
tf.Tensor(
[[[b'60518']
  [b'60485']
  [b'56420']
  [b'158934']
  [b'158889']
  [b'9155']
  [b'66130']
  [b'119830']
  [b'29635']
  [b'148082']]], shape=(1, 10, 1), dtype=string)


In [None]:
path = '../data/retrieval_model'
tf.saved_model.save(index, path)

In [None]:
foo, titles = loaded(tf.constant(['139716','35287','142953','132041','','','','','','',],shape=(1,10,1)))

In [None]:
print (titles)