In [1]:
import os
import pprint
import tempfile
import datetime
from typing import Dict, Text

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_recommenders as tfrs
import requests
import itertools

In [2]:
# Location of the serialized validation transactions and a dataset that
# streams the raw records from that file.
validation_filename = "../data/samples/validation_transactions.tfrecord"
validation = tf.data.TFRecordDataset(filenames=validation_filename)

# Parsing schema used by _parse_function: four fixed-length context features
# of 10 entries each (zeros are substituted when a feature is absent from a
# record) and a single-element label id defaulting to 0.
feature_description = dict(
    context_item_id=tf.io.FixedLenFeature(
        shape=[10], dtype=tf.int64, default_value=np.repeat(0, 10)),
    context_item_quantity=tf.io.FixedLenFeature(
        shape=[10], dtype=tf.float32, default_value=np.repeat(0, 10)),
    context_item_price=tf.io.FixedLenFeature(
        shape=[10], dtype=tf.float32, default_value=np.repeat(0, 10)),
    context_item_department_id=tf.io.FixedLenFeature(
        shape=[10], dtype=tf.int64, default_value=np.repeat(0, 10)),
    label_item_id=tf.io.FixedLenFeature(
        shape=[1], dtype=tf.int64, default_value=0),
)

def _parse_function(example_proto):
  """Parse one serialized record into a dict of tensors.

  Args:
    example_proto: a single serialized example, as yielded by the
      TFRecord dataset.

  Returns:
    A dict keyed like ``feature_description`` whose values are the parsed
    tensors.
  """
  parsed = tf.io.parse_single_example(
      serialized=example_proto, features=feature_description)
  return parsed
def _map_function(x):
  """Convert a parsed example into the feature dict used by this notebook.

  Context and label item ids are rendered as strings, prices are cast to
  float32 and department ids to int32. ``context_item_quantity`` is not
  forwarded.

  The casts are spelled with ``tf.cast`` instead of the Python ``float()`` /
  ``int()`` builtins. The builtins only accept these 10-element tensors when
  AutoGraph rewrites them (as it does inside ``Dataset.map``); ``tf.cast``
  yields the same dtypes without depending on that rewrite, so the function
  also works when called directly on a parsed example.

  Args:
    x: dict of tensors as returned by ``_parse_function``.

  Returns:
    Dict with keys ``context_item_id``, ``context_item_price``,
    ``context_item_department_id`` and ``label_item_id``.
  """
  return {
    "context_item_id": tf.strings.as_string(x["context_item_id"]),
    # float32 / int32 are the dtypes AutoGraph's float()/int() produce.
    "context_item_price": tf.cast(x["context_item_price"], tf.float32),
    "context_item_department_id": tf.cast(x["context_item_department_id"], tf.int32),
    "label_item_id": tf.strings.as_string(x["label_item_id"])}

# Parse each record, convert it to model-ready features, then shuffle.
# Note: the first argument of Dataset.shuffle is the *buffer size*, not a
# seed, so it is passed by keyword here. A fixed seed is added so that the
# examples sampled from this dataset can be reproduced on a fresh kernel.
test_ds = (
    validation
    .map(_parse_function)
    .map(_map_function)
    .shuffle(buffer_size=42, seed=42)
)

In [3]:
from os import sep


# Sanity check: print every element of the three context features for a
# single example drawn from the test dataset.
for x in test_ds.take(1):
    for key in ('context_item_id', 'context_item_price', 'context_item_department_id'):
        tf.print(x[key], summarize=-1, sep=',')


["101639" "147393" "48364" "48364" "48364" "114590" "83058" "64752" "64752" "33554"]
[7.99 2.99 4.99 4.99 4.99 10.99 10.69 5.25 5.25 1.25]
[673 742 753 753 753 736 736 735 735 749]


In [4]:
# Restore the exported SavedModel from disk; the later cells call it
# directly on a dict of feature tensors.
path = '../data/rich_2features_model'
model = tf.saved_model.load(export_dir=path)

In [6]:
# Hand-built request: 10 context items with their department ids and prices,
# each reshaped to (1, 10, 1) before being passed to the model.
# NOTE(review): presumably (batch, context length, 1) -- confirm against the
# model's serving signature.
feature_shape = (1, 10, 1)
feature0 = [
    "81747", "22107", "133321", "74199", "13247",
    "23321", "15511", "23321", "11855", "99657",
]
feature1 = [753, 744, 744, 749, 692, 739, 739, 739, 673, 739]
feature2 = [16.99, 4.99, 22.99, 8.99, 1.99, 119.99, 45.99, 18.99, 1249.99, 1.29]

features = dict(
    context_item_id=tf.constant(value=feature0, shape=feature_shape),
    context_item_department_id=tf.constant(value=feature1, shape=feature_shape),
    context_item_price=tf.constant(value=feature2, shape=feature_shape),
)

# The model returns a pair: scores and the corresponding predicted item ids.
scores, predicted = model(features, training=False)
print(scores, predicted, sep='\n')

tf.Tensor(
[[113.49897  110.71968  110.61253  110.57574  110.34062  109.993996
  109.61926  109.560875 109.533104 109.42425 ]], shape=(1, 10), dtype=float32)
tf.Tensor(
[[[b'36797']
  [b'134743']
  [b'15636']
  [b'15510']
  [b'15532']
  [b'25199']
  [b'123778']
  [b'134805']
  [b'15638']
  [b'55504']]], shape=(1, 10, 1), dtype=string)


In [7]:
# Draw five examples and compare the model's predicted item ids with the
# ground-truth label of each example.
# The shuffle is seeded (42 is kept as the buffer size) so the sampled
# examples can be reproduced; the original call was unseeded.
shuffled = test_ds.shuffle(buffer_size=42, seed=42).take(5)

context_keys = ('context_item_id', 'context_item_department_id', 'context_item_price')

for x in shuffled:
    # Reshape every context feature to the shape the model is called with.
    # Inputs and outputs get distinct names; the original reused `foo` for
    # both the input dict and the flattened prediction list.
    model_inputs = {
        key: tf.reshape(x[key], shape=feature_shape, name=key)
        for key in context_keys
    }
    scores, predictions = model(model_inputs, training=False)
    # predictions[0] holds one single-element row per predicted item;
    # flatten those rows into a plain list of ids.
    predicted_ids = [item for row in predictions[0].numpy() for item in row]
    print(predicted_ids)
    print('Ground Truth:', x['label_item_id'].numpy(), sep='')
    print('------')

[b'117206', b'131481', b'122714', b'130596', b'54697', b'54700', b'65110', b'130616', b'32453', b'54660']
Ground Truth:[b'74158']
------
[b'60588', b'66135', b'138145', b'26434', b'62969', b'138143', b'141882', b'74293', b'125712', b'21765']
Ground Truth:[b'138143']
------
[b'51089', b'105637', b'113558', b'36329', b'51090', b'151992', b'132293', b'36326', b'29737', b'51092']
Ground Truth:[b'128701']
------
[b'24731', b'23752', b'73294', b'146130', b'59532', b'23754', b'84367', b'133970', b'63388', b'137629']
Ground Truth:[b'24731']
------
[b'60463', b'100954', b'142288', b'112477', b'112488', b'105297', b'133189', b'132624', b'138976', b'13152']
Ground Truth:[b'100952']
------
