In [2]:
import os
import pprint
import tempfile
import datetime
from typing import Dict, Text

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_recommenders as tfrs
import requests
import itertools

In [12]:
# Serialized validation examples, read lazily as raw TFRecord byte strings.
validation_filename = "../data/samples/validation_transactions.tfrecord"
validation = tf.data.TFRecordDataset(filenames=validation_filename)

# Parsing schema for one serialized example: five fixed-length context
# features of 10 entries each, plus a single label item id. Missing features
# fall back to the given defaults (zeros, or the string 'none').
feature_description = dict(
    context_item_id=tf.io.FixedLenFeature(
        shape=[10], dtype=tf.int64, default_value=np.repeat(0, 10)),
    context_item_quantity=tf.io.FixedLenFeature(
        shape=[10], dtype=tf.float32, default_value=np.repeat(0, 10)),
    context_item_price=tf.io.FixedLenFeature(
        shape=[10], dtype=tf.float32, default_value=np.repeat(0, 10)),
    context_item_department_id=tf.io.FixedLenFeature(
        shape=[10], dtype=tf.int64, default_value=np.repeat(0, 10)),
    context_brand_code=tf.io.FixedLenFeature(
        shape=[10], dtype=tf.string, default_value=['none'] * 10),
    label_item_id=tf.io.FixedLenFeature(
        shape=[1], dtype=tf.int64, default_value=0),
)

def _parse_function(example_proto):
  """Decode one serialized example into a dict of tensors.

  Args:
    example_proto: a single serialized example, as yielded by the
      TFRecord dataset.

  Returns:
    The dict produced by `tf.io.parse_single_example` using the module-level
    `feature_description` schema.
  """
  parsed = tf.io.parse_single_example(
      serialized=example_proto, features=feature_description)
  return parsed
def _map_function(x):
  """Convert a parsed example into the feature dict used for prediction.

  Only the keys listed below are returned; any other key present in `x` is
  dropped. Item ids (context and label) are converted to strings, and the
  parsed `context_brand_code` is exposed as `context_item_brand_code`.

  Args:
    x: dict of tensors for one parsed example; must contain the keys
      `context_item_id`, `context_item_price`, `context_item_department_id`,
      `context_brand_code` and `label_item_id`.

  Returns:
    A dict with keys `context_item_id`, `context_item_price`,
    `context_item_department_id`, `context_item_brand_code` and
    `label_item_id`.
  """
  return {
    "context_item_id": tf.strings.as_string(x["context_item_id"]),
    # Explicit cast instead of the Python builtin float(): float() on a
    # multi-element tensor only works when AutoGraph rewrites it (as it does
    # inside Dataset.map) and fails if this function is called eagerly.
    "context_item_price": tf.cast(x["context_item_price"], tf.float32),
    "context_item_department_id": x["context_item_department_id"],
    "context_item_brand_code" : x["context_brand_code"],
    "label_item_id": tf.strings.as_string(x["label_item_id"])}

# Parse each record, convert it to model-ready features, then shuffle.
# Note: the argument to shuffle() is the shuffle *buffer size* (42 elements),
# not a random seed, so the order is not reproducible between runs.
test_ds = (
    validation
    .map(_parse_function)
    .map(_map_function)
    .shuffle(42)
)

In [4]:
# Directory of the exported SavedModel used for all predictions below.
path = '../data/rich_features_2'
model = tf.saved_model.load(export_dir=path)

In [16]:
# Draw a single example from the (re-shuffled) validation set and print every
# element of each feature.
features = test_ds.shuffle(42).take(1).get_single_element()
tf.print(features, summarize=-1)

# Features passed to the model in this cell, with the dtype each is built as.
input_dtypes = {
  'context_item_id': tf.string,
  'context_item_department_id': tf.int64,
  'context_item_brand_code': tf.string,
}

# Each 10-entry context feature is given shape (1, 10, 1) before being
# handed to the model.
features_dataset = {
  key: tf.constant(features[key], shape=(1,10,1), name=key, dtype=dtype)
  for key, dtype in input_dtypes.items()
}

scores, predicted = model(features_dataset, training=False)
print(scores)
print(predicted)

{'context_item_brand_code': ["2288V" "1178V" "1178V" "1178V" "1178V" "2288V" "2288V" "2288V" "2755V" "1178V"],
 'context_item_department_id': [736 735 735 735 735 736 736 736 753 735],
 'context_item_id': ["67829" "9415" "80951" "80951" "80951" "67829" "108318" "108318" "56667" "80951"],
 'context_item_price': [3.49 4.59 2.19 2.19 2.19 3.49 9.99 9.99 11.99 2.19],
 'label_item_id': ["80951"]}
tf.Tensor(
[[13.095102 12.707839 12.293754 11.799453 11.689468 11.430653 11.405487
  11.249676 11.185574 10.955563 10.873921 10.853102]], shape=(1, 12), dtype=float32)
tf.Tensor(
[[[b'94219']
  [b'80951']
  [b'108318']
  [b'150318']
  [b'25107']
  [b'80950']
  [b'81852']
  [b'80958']
  [b'153138']
  [b'9415']
  [b'83828']
  [b'81005']]], shape=(1, 12, 1), dtype=string)


In [None]:
# Shape given to every context feature before calling the model. Defined here
# so the cell runs on a fresh kernel; (1, 10, 1) matches the shape used for
# the single-example prediction earlier in this notebook.
feature_shape = (1, 10, 1)

# Five examples drawn from the (re-shuffled) validation set.
shuffled = test_ds.shuffle(42).take(5)

for x in shuffled:
    # NOTE(review): this cell feeds 'context_item_price', whereas the
    # single-example cell feeds 'context_item_brand_code' -- confirm which
    # inputs the loaded model's signature actually expects.
    model_inputs = {
        key: tf.reshape(x[key], shape=feature_shape, name=key)
        for key in (
            'context_item_id',
            'context_item_department_id',
            'context_item_price',
        )
    }
    scores, predictions = model(model_inputs, training=False)
    # Flatten the predictions for the single batch entry into a plain list
    # of item ids.
    top_items = list(itertools.chain.from_iterable(predictions[0].numpy()))
    print(top_items)
    print('Ground Truth:', end='')
    print(x['label_item_id'].numpy())
    print('------')

[b'117206', b'131481', b'122714', b'130596', b'54697', b'54700', b'65110', b'130616', b'32453', b'54660']
Ground Truth:[b'74158']
------
[b'60588', b'66135', b'138145', b'26434', b'62969', b'138143', b'141882', b'74293', b'125712', b'21765']
Ground Truth:[b'138143']
------
[b'51089', b'105637', b'113558', b'36329', b'51090', b'151992', b'132293', b'36326', b'29737', b'51092']
Ground Truth:[b'128701']
------
[b'24731', b'23752', b'73294', b'146130', b'59532', b'23754', b'84367', b'133970', b'63388', b'137629']
Ground Truth:[b'24731']
------
[b'60463', b'100954', b'142288', b'112477', b'112488', b'105297', b'133189', b'132624', b'138976', b'13152']
Ground Truth:[b'100952']
------
