In [1]:
import os
import pprint
import tempfile
import datetime
from typing import Dict, Text

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_recommenders as tfrs
import requests
import itertools

In [2]:
validation_filename = "../data/samples/validation_transactions.tfrecord"
validation = tf.data.TFRecordDataset(validation_filename)

# Number of context items stored per example; every context_* feature is a
# fixed-length vector of this size.
CONTEXT_LENGTH = 10

# Parsing spec for the serialized records. Missing context features default to
# all-zero vectors whose dtype matches the declared feature dtype (float32
# features get float32 zeros rather than relying on int -> float conversion).
feature_description = {
    'context_item_id': tf.io.FixedLenFeature(
        [CONTEXT_LENGTH], tf.int64, default_value=np.zeros(CONTEXT_LENGTH, dtype=np.int64)),
    'context_item_quantity': tf.io.FixedLenFeature(
        [CONTEXT_LENGTH], tf.float32, default_value=np.zeros(CONTEXT_LENGTH, dtype=np.float32)),
    'context_item_price': tf.io.FixedLenFeature(
        [CONTEXT_LENGTH], tf.float32, default_value=np.zeros(CONTEXT_LENGTH, dtype=np.float32)),
    'context_item_department_id': tf.io.FixedLenFeature(
        [CONTEXT_LENGTH], tf.int64, default_value=np.zeros(CONTEXT_LENGTH, dtype=np.int64)),
    # Single target item per example.
    'label_item_id': tf.io.FixedLenFeature([1], tf.int64, default_value=0),
}

def _parse_function(example_proto):
  """Decode one serialized example into a dict of dense tensors per `feature_description`."""
  parsed_features = tf.io.parse_single_example(
      serialized=example_proto, features=feature_description)
  return parsed_features
def _map_function(x):
  """Convert a parsed example into the feature dict used for evaluation.

  - `context_item_id`, `label_item_id`: int64 ids rendered as strings.
  - `context_item_price`: cast to float32.
  - `context_item_department_id`: cast from int64 to int32.
  `context_item_quantity` is parsed upstream but not forwarded here.

  The casts are written with `tf.cast` instead of the Python `float()` /
  `int()` builtins. On tensors those builtins only work inside AutoGraph,
  which rewrites them into exactly these casts. Outside AutoGraph they fail
  on these length-10 tensors.
  """
  return {
      "context_item_id": tf.strings.as_string(x["context_item_id"]),
      "context_item_price": tf.cast(x["context_item_price"], tf.float32),
      # int32 (not int64) is what the previous int() conversion produced.
      "context_item_department_id": tf.cast(x["context_item_department_id"], tf.int32),
      "label_item_id": tf.strings.as_string(x["label_item_id"]),
  }

test_ds = (
    validation
    .map(_parse_function)  # serialized record -> dict of dense tensors
    .map(_map_function)    # dict -> string ids / cast numeric features
)

In [3]:
# Exported SavedModel directory; judging by its name it consumes two context
# features (presumably item id and department id — TODO confirm).
path = '../data/seq_model_2features'
model = tf.saved_model.load(export_dir=path)

In [4]:
# Hand-written context of 10 items: item ids (as strings) and their department ids.
feature0 = ["81747", "22107", "133321", "74199", "13247", "23321", "15511", "23321", "11855", "99657"]
feature1 = [753, 744, 744, 749, 692, 739, 739, 739, 673, 739]

# Each feature becomes a batch of one sequence with a trailing unit axis: shape (1, 10, 1).
features = {
    'context_item_id': tf.constant([[[item_id] for item_id in feature0]]),
    'context_item_department_id': tf.constant([[[dept_id] for dept_id in feature1]]),
}

scores, predicted = model(features, training=False)
print(scores)
print(predicted)

tf.Tensor(
[[13.113109 12.863562 12.779797 12.628791 12.378885 12.347994 12.319192
  12.247762 12.201349 12.161086 12.156644 12.153313]], shape=(1, 12), dtype=float32)
tf.Tensor(
[[[b'82196']
  [b'119292']
  [b'82034']
  [b'143007']
  [b'108984']
  [b'135453']
  [b'23234']
  [b'157353']
  [b'33041']
  [b'15595']
  [b'84455']
  [b'30467']]], shape=(1, 12, 1), dtype=string)


In [5]:
# Draw a small, reproducible sample of validation examples and print the
# model's predictions next to the ground-truth label.
# `Dataset.shuffle`'s first positional argument is the buffer size, not a seed.
# Both are passed explicitly so the sample comes from a wide window and is the
# same on every run.
SAMPLE_SEED = 42
SHUFFLE_BUFFER_SIZE = 1_000
N_SAMPLES = 5

shuffled = (
    test_ds
    .shuffle(buffer_size=SHUFFLE_BUFFER_SIZE, seed=SAMPLE_SEED, reshuffle_each_iteration=False)
    .take(N_SAMPLES)
)


def _to_model_inputs(example):
    """Reshape one example's context features into a batch of one: (1, seq_len, 1)."""
    return {
        'context_item_id': tf.reshape(example['context_item_id'], (1, -1, 1)),
        'context_item_department_id': tf.reshape(example['context_item_department_id'], (1, -1, 1)),
    }


for x in shuffled:
    model_inputs = _to_model_inputs(x)
    tf.print(model(model_inputs, training=False), summarize=-1)
    print('Ground Truth:')
    print(x['label_item_id'].numpy())
    print('------')

([[9.94413185 9.70933723 9.62080193 9.59712219 9.26239 9.09691334 9.09018421 9.06620884 9.00914764 8.98163414 8.89393806 8.86738873]], [[["132953"]
  ["59696"]
  ["128816"]
  ["153471"]
  ["80624"]
  ["132941"]
  ["119716"]
  ["154973"]
  ["154686"]
  ["78958"]
  ["107924"]
  ["132955"]]])
Ground Truth:
[b'36022']
------
([[10.5630436 9.43732548 8.9095211 8.88101673 8.76085281 8.74254608 8.59654808 8.42867374 8.38398 8.16695 8.16144943 8.11431694]], [[["24731"]
  ["23752"]
  ["23754"]
  ["62264"]
  ["112473"]
  ["79003"]
  ["133970"]
  ["133968"]
  ["122984"]
  ["150315"]
  ["73294"]
  ["9366"]]])
Ground Truth:
[b'24731']
------
([[11.6351089 11.5016623 10.6225872 10.3384609 10.2645159 9.92010498 9.88556385 9.7572 9.72596741 9.70106506 9.48393631 9.46371651]], [[["25569"]
  ["25567"]
  ["143333"]
  ["158934"]
  ["157621"]
  ["43768"]
  ["135326"]
  ["122571"]
  ["123683"]
  ["15671"]
  ["153987"]
  ["113983"]]])
Ground Truth:
[b'50130']
------
([[8.8512125 8.36244202 8.32926 8.29277134