In [1]:
import tensorflow as tf

In [2]:
keys = ['a', 'b', 'c']
lookup = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(
                keys=tf.constant(keys),
                values=tf.cast(tf.range(len(keys)), dtype=tf.int64),
                key_dtype=tf.string,
                value_dtype=tf.int64,
            ),
            default_value=-1,
        )

In [3]:
lookup.lookup(tf.constant([["a", "b"], ["b", "c"]]))

<tf.Tensor: shape=(2, 2), dtype=int64, numpy=
array([[0, 1],
       [1, 2]])>

In [60]:
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import numpy as np
import jax
import jax.numpy as jnp
import tensorflow as tf

In [61]:
import flax

In [70]:
import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState
import flax.linen as nn

class FeedForward(nn.Module):

    @nn.compact
    def __call__(self, inputs):
        x, y, z = inputs
        x = nn.Dense(10)(x)
        y = nn.Dense(10)(y)
        z = nn.Dense(10)(z) 

        c = jnp.concatenate([x, y, z], axis=-1)
        out = nn.Dense(10)(c)
        out = nn.relu(out)
        out = nn.Dense(1)(out)
        return out

# init model
rng = jax.random.PRNGKey(0)
input_shape = (1, 2)
model = FeedForward()
params = model.init(rng, [jnp.ones(input_shape), jnp.ones(input_shape), jnp.ones(input_shape)])

In [71]:
import optax
lr = 0.001
momentum = 0.9
tx = optax.sgd(lr, momentum)
state = TrainState.create(apply_fn=model.apply, params=params["params"], tx=tx)

In [72]:
state.params

{'Dense_0': {'kernel': Array([[-1.0669333 , -0.8827529 , -0.00570564,  0.603013  ,  1.0019107 ,
           0.28586274, -0.15108383, -0.7755998 , -0.7439601 , -0.6873577 ],
         [-0.9427239 ,  1.2475272 ,  0.50599295,  1.0079471 ,  0.46328905,
          -0.08936425, -0.39633456, -1.4424019 , -0.48588607,  0.4325949 ]],      dtype=float32),
  'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)},
 'Dense_1': {'kernel': Array([[-0.9376454 ,  1.3456881 ,  0.12843367,  0.04694894,  0.38935086,
           0.16207337, -1.2485627 , -0.05399594,  0.6877025 , -0.91342294],
         [ 0.39089563,  0.30140245, -0.18343109, -1.0258522 ,  0.26916406,
           0.2980016 , -0.3358007 , -1.2467374 , -0.6431401 ,  1.0915588 ]],      dtype=float32),
  'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)},
 'Dense_2': {'kernel': Array([[ 0.6666576 , -1.0246173 , -0.05735719,  0.41652977,  1.2951125 ,
          -0.07839004,  0.50164974, -0.95135075, -0.56714505,  0

In [73]:
state.apply_fn

<bound method Module.apply of FeedForward()>

In [75]:
import tensorflow as tf

signature_dict = {
  "x": tf.TensorSpec(shape=(2,), dtype=tf.float32), 
  "y": tf.TensorSpec(shape=(2,), dtype=tf.float32),
  "z": tf.TensorSpec(shape=(2,), dtype=tf.float32)
}
@tf.function(input_signature=[signature_dict])
def f(inputs):
  return [inputs["x"], inputs["y"], inputs["z"]]

# Construct a JaxModule where JAX->TF conversion happens.
jax_module = JaxModule({"params": state.params}, state.apply_fn)
# Export the JaxModule along with one or more serving configs.
export_mgr = ExportManager(
  jax_module, [
    ServingConfig(
      'serving_default',
      # input_signature=[tf.TensorSpec(shape=(10), dtype=tf.float32)],
      tf_preprocessor=f,
      # tf_postprocessor=example1_postprocess
    ),
])
output_dir='/tmp/example1_output_dir'
export_mgr.save(output_dir)

INFO:tensorflow:Assets written to: /tmp/example1_output_dir/assets


INFO:tensorflow:Assets written to: /tmp/example1_output_dir/assets


In [76]:
jnp.ones(input_shape)

Array([[1., 1.]], dtype=float32)

In [11]:

batch_size = 32
inputs = tf.random.normal([10,], dtype=tf.float32)
batch = [inputs] * batch_size

loaded_model = tf.saved_model.load(output_dir)
loaded_model_outputs = loaded_model({"x": inputs, "y": inputs})
print("loaded model output: ", loaded_model_outputs)

loaded model output:  tf.Tensor([-0.12772207], shape=(1,), dtype=float32)


2023-11-29 19:46:38.041351: E ./tensorflow/compiler/xla/stream_executor/stream_executor_internal.h:124] SetPriority unimplemented for this stream.


In [77]:

for lookup_name in ["album", "artist", "track"]:
    lookup = tf.keras.layers.StringLookup()
    lookup.adapt(keys)
    dst = f"/tmp/vocab/lookup/{lookup_name}"
    if not tf.io.gfile.exists(dst):
        tf.io.gfile.makedirs(dst)
    lookup.save_assets(f"/tmp/vocab/lookup/{lookup_name}")

tf.keras.layers.StringLookup().load_assets("/tmp/vocab/lookup/album/")





In [178]:
def create_file_lookup(filename):
    with tf.init_scope():
        initializer = tf.lookup.TextFileInitializer(
            filename,
            key_dtype=tf.string, 
            key_index=tf.lookup.TextFileIndex.WHOLE_LINE, 
            value_dtype=tf.int64, 
            value_index=tf.lookup.TextFileIndex.LINE_NUMBER,
            value_index_offset=1, # starting from 1
        )
        table = tf.lookup.StaticHashTable(initializer, 0)
        
    return table

In [191]:
# inputs = tf.keras.Input(shape=(1,), dtype=tf.string)
t = create_file_lookup("/tmp/vocab/lookup/album/vocabulary.txt")
ll = tf.keras.layers.Lambda(lambda x: t.lookup(x))
ll.build((None, 1))
m = tf.keras.Model()
m.lookup = ll

In [192]:
t = create_file_lookup("/tmp/vocab/lookup/album/vocabulary.txt")

# .lookup(tf.constant(["a", "b", "c"]))

In [None]:
ll = tf.keras.Sequential([
    tf.keras.layers.StringLookup(
        vocabulary=t,
        mask_token=None,
        num_oov_indices=0,
        output_mode="int",
    )])

In [241]:
m = tf.keras.Sequential([tf.keras.layers.StringLookup(
    vocabulary=["a", "b", "c"],
    mask_token=None,
    num_oov_indices=0,
    output_mode="int",
)])

m.compile()

m.predict(["a"])



array([0])

In [250]:
inputs = tf.keras.Input(shape=(1,), dtype=tf.string)

t = create_file_lookup("/tmp/vocab/lookup/album/vocabulary.txt")
t.lookup(tf.constant(["a", "b", "c"]))
l = tf.keras.layers.StringLookup(
    vocabulary=["a", "b", "c"],
    mask_token=None,
    num_oov_indices=0,
    output_mode="int",
)(inputs)




m = tf.keras.Model(
    inputs=inputs,
    outputs=l)

In [258]:
m = tf.keras.Model()
m.lookup = l

In [259]:

tf.saved_model.save(m, "/tmp/vocab/lookup/album/")





INFO:tensorflow:Assets written to: /tmp/vocab/lookup/album/assets


INFO:tensorflow:Assets written to: /tmp/vocab/lookup/album/assets


In [252]:
loaded = tf.saved_model.load("/tmp/vocab/lookup/album/")

In [256]:
loaded(tf.constant([["a"]]))

<tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[0]])>

In [243]:
tf.keras.Sequential([])

<keras.src.engine.sequential.Sequential at 0x2cc4eff10>

In [260]:

def get_vocab(src):
    lookup_layer = tf.keras.layers.StringLookup()
    lookup_layer.load_assets(src)

    return lookup_layer.get_vocabulary()


    
def make_servable(model, src, dst):
    vocab = get_vocab(src)

    lookup_table = tf.lookup.StaticVocabularyTable(
        tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(vocab),
            values=tf.cast(tf.range(len(vocab)), dtype=tf.int64),
            key_dtype=tf.string,
            value_dtype=tf.int64,
        ),
        num_oov_buckets=1,
    )

    model.lookup = lookup_table


signature_dict = {
"album": tf.TensorSpec(shape=(2), dtype=tf.string), 
"artist": tf.TensorSpec(shape=(2), dtype=tf.string),
"track": tf.TensorSpec(shape=(2), dtype=tf.string)
}
@tf.function(input_signature=[signature_dict])
def preprocessing_fn(inputs):
    album = inputs["album"]
    # artist = inputs["artist"]
    # track = inputs["track"]

    # tables = {}

    with tf.init_scope():
        src = f"/tmp/vocab/lookup/album/"
        lookup_layer = tf.keras.layers.StringLookup()
        lookup_layer.load_assets(src)
    # lookup_table = tf.lookup.StaticVocabularyTable(
    #     tf.lookup.KeyValueTensorInitializer(
    #         keys=tf.constant(["a", "b", "c"]),
    #         values=tf.cast(tf.range(len(["a", "b", "c"])), dtype=tf.int64),
    #         key_dtype=tf.string,
    #         value_dtype=tf.int64,
    #     ),
    #     num_oov_buckets=1,
    # )
        # src = f"/tmp/vocab/lookup/album/"
        # lookup_table = create_file_lookup(src)
        # initializer = tf.lookup.TextFileInitializer(
        #     src,
        #     key_dtype=tf.string, 
        #     key_index=tf.lookup.TextFileIndex.WHOLE_LINE, 
        #     value_dtype=tf.int64, 
        #     value_index=tf.lookup.TextFileIndex.LINE_NUMBER,
        #     value_index_offset=1, # starting from 1
        # )
        # lookup_table = tf.lookup.StaticHashTable(initializer, 0)

        # lookup_table = tf.keras.layers.StringLookup()
        # lookup_table.load_assets("/tmp/vocab/lookup/album/")
        # resolve issue with uninitialized variables
        # for lookup_name in ["album", "artist", "track"]:

            
        #     lookup_layer = tf.keras.layers.StringLookup()
        #     lookup_layer.load_assets(src)
        #     vocab = lookup_layer.get_vocabulary()

        #     # vocab = get_vocab(f"/tmp/vocab/lookup/{lookup_name}/")

        #     lookup_table = tf.lookup.StaticVocabularyTable(
        #         tf.lookup.KeyValueTensorInitializer(
        #             keys=tf.constant(vocab),
        #             values=tf.cast(tf.range(len(vocab)), dtype=tf.int64),
        #             key_dtype=tf.string,
        #             value_dtype=tf.int64,
        #         ),
        #         num_oov_buckets=1,
        #     )

            # tables[lookup_name] = get_lookup_table(f"/tmp/vocab/lookup/{lookup_name}/")
    out = lookup_layer(album)
    out = [out] * 3

    return out
    
    # return [
    #     lookup_table.lookup(album),
    #     # tables["artist"].lookup(artist),
    #     # tables["track"].lookup(track),
    # ] * 3
    

In [279]:
class Preprocessing(tf.Module):

  def __init__(self, src):
    self.lookup = tf.lookup.StaticVocabularyTable(
        tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(["a", "b", "c"]),
            values=tf.cast(tf.range(len(["a", "b", "c"])), dtype=tf.int64),
            key_dtype=tf.string,
            value_dtype=tf.int64,
        ),
        num_oov_buckets=1,
    )
  def __call__(self, x):
    return self.lookup.lookup(x)

    

model = Preprocessing("/tmp/vocab/lookup/album/")
model(tf.constant(["a", "b", "c"]))


signature_dict = {
"album": tf.TensorSpec(shape=(2), dtype=tf.string), 
"artist": tf.TensorSpec(shape=(2), dtype=tf.string),
"track": tf.TensorSpec(shape=(2), dtype=tf.string)
}
@tf.function(input_signature=[signature_dict])
def example1_preprocess(inputs):  # Optional: preprocessor in TF.
  album = inputs["album"]
  out = model(album)
  # with tf.init_scope():
  #   src = f"/tmp/vocab/lookup/album/"
  #   lookup_layer = tf.keras.layers.StringLookup()
  #   lookup_layer.load_assets(src)
  # norm_inputs = tf.nest.map_structure(lambda x: lookup(x), album)
  return [out] * 3

In [280]:
preprocessing_fn({"album": tf.constant(["a", "b"]), "artist": tf.constant(["a", "b"]), "track": tf.constant(["a", "b"])})

[<tf.Tensor: shape=(2,), dtype=int64, numpy=array([3, 2])>,
 <tf.Tensor: shape=(2,), dtype=int64, numpy=array([3, 2])>,
 <tf.Tensor: shape=(2,), dtype=int64, numpy=array([3, 2])>]

In [291]:
@tf.keras.utils.register_keras_serializable(name='lookup')
class Lookup(tf.keras.layers.Layer):
    def __init__(self, vocab_lookup_layer, **kwargs):
        super(Lookup, self).__init__(**kwargs)
        # save the constructor parameters for get_config() to work properly
        self.vocab_lookup_layer = vocab_lookup_layer
    
    def call(self, x, training=False):
        # split the string on spaces, and make it a rectangular tensor
        inputs = x["album"]
        tokens = self.vocab_lookup_layer.lookup(inputs)
        return [tokens] * 3

    def get_config(self):
        config = super().get_config()
        # save constructor args
        config['vocab_lookup_layer'] = self.vocab_lookup_layer
        return config

In [292]:
tf.lookup.StaticVocabularyTable(
        tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(["a", "b", "c"]),
            values=tf.cast(tf.range(len(["a", "b", "c"])), dtype=tf.int64),
            key_dtype=tf.string,
            value_dtype=tf.int64,
        ),
        num_oov_buckets=1,
    )
l = Lookup(t)


In [309]:

signature_dict = {
"album": tf.TensorSpec(shape=(2), dtype=tf.string), 
"artist": tf.TensorSpec(shape=(2), dtype=tf.string),
"track": tf.TensorSpec(shape=(2), dtype=tf.string)
}
@tf.function(input_signature=[signature_dict])
def example1_preprocess(inputs):  # Optional: preprocessor in TF.
  album = inputs["album"]
  # out = model(album)
  with tf.init_scope():
    t = tf.lookup.StaticVocabularyTable(
        tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(["a", "b", "c"]),
            values=tf.cast(tf.range(len(["a", "b", "c"])), dtype=tf.int64),
            key_dtype=tf.string,
            value_dtype=tf.int64,
        ),
        num_oov_buckets=1,
    )
    l = Lookup(t)
  out = tf.cast(l({"album": album}), tf.int32)
  #   src = f"/tmp/vocab/lookup/album/"
  #   lookup_layer = tf.keras.layers.StringLookup()
  #   lookup_layer.load_assets(src)
  # norm_inputs = tf.nest.map_structure(lambda x: lookup(x), album)
  return [out] * 3

In [330]:
lookup = tf.lookup.StaticVocabularyTable(
            tf.lookup.KeyValueTensorInitializer(
                keys=tf.constant(["a", "b", "c"]),
                values=tf.cast(tf.range(len(["a", "b", "c"])), dtype=tf.int64),
                key_dtype=tf.string,
                value_dtype=tf.int64,
            ),
            num_oov_buckets=1,
        )

In [331]:
lookup.lookup(tf.constant(["a", "b", "c"]))

<tf.Tensor: shape=(3,), dtype=int64, numpy=array([0, 1, 2])>

In [327]:
class Preprocessing(tf.Module):
    def __init__(self):
        self.lookup = tf.lookup.StaticVocabularyTable(
            tf.lookup.KeyValueTensorInitializer(
                keys=tf.constant(["a", "b", "c"]),
                values=tf.cast(tf.range(len(["a", "b", "c"])), dtype=tf.int64),
                key_dtype=tf.string,
                value_dtype=tf.int64,
            ),
            num_oov_buckets=1,
        )
    
    @tf.function(input_signature=[signature_dict])
    def __call__(self, inputs):
        album = inputs["album"]
        out = tf.cast(self.lookup.lookup(album), tf.int32)
        return [out] * 3

In [348]:
t = tf.lookup.StaticVocabularyTable(
      tf.lookup.KeyValueTensorInitializer(
          keys=tf.constant(["a", "b", "c"]),
          values=tf.cast(tf.range(len(["a", "b", "c"])), dtype=tf.int64),
          key_dtype=tf.string,
          value_dtype=tf.int64,
      ),
      num_oov_buckets=1,
  )

m = tf.Module()
m.t = t

signature_dict = {
"album": tf.TensorSpec(shape=(2), dtype=tf.string), 
"artist": tf.TensorSpec(shape=(2), dtype=tf.string),
"track": tf.TensorSpec(shape=(2), dtype=tf.string)
}
@tf.function(input_signature=[signature_dict])
def example1_preprocess(inputs):  # Optional: preprocessor in TF.
  album = inputs["album"]
  # out = model(album)
  
  out = tf.cast(m.t.lookup(album), tf.int32)
  #   src = f"/tmp/vocab/lookup/album/"
  #   lookup_layer = tf.keras.layers.StringLookup()
  #   lookup_layer.load_assets(src)
  # norm_inputs = tf.nest.map_structure(lambda x: lookup(x), album)
  return [out] * 3

m.serving = example1_preprocess

In [349]:
jax_module = JaxModule({"params": state.params}, state.apply_fn)
# Export the JaxModule along with one or more serving configs.
export_mgr = ExportManager(
  jax_module, [
    ServingConfig(
      'serving_default',
      # input_signature=[
      #   {
      #     "album": tf.TensorSpec(shape=(2), dtype=tf.string), 
      #     "artist": tf.TensorSpec(shape=(2), dtype=tf.string),
      #     "track": tf.TensorSpec(shape=(2), dtype=tf.string)
      #   }
      # ],
      tf_preprocessor=m.serving,
      # tf_postprocessor=example1_postprocess
    ),
])
output_dir='/tmp/example1_output_dir'
export_mgr.save(output_dir)

AssertionError: Tried to export a function which references an 'untracked' resource. TensorFlow objects (e.g. tf.Variable) captured by functions must be 'tracked' by assigning them to an attribute of a tracked object or assigned to an attribute of the main object directly. See the information below:
	Function name = b'__inference_signature_wrapper_inference_fn_20160'
	Captured Tensor = <ResourceHandle(name="19975", device="/job:localhost/replica:0/task:0/device:CPU:0", container="localhost", type="tensorflow::lookup::LookupInterface", dtype and shapes : "[  ]")>
	Trackable referencing this tensor = <tensorflow.python.ops.lookup_ops.HashTable object at 0x151723820>
	Internal Tensor = Tensor("20134:0", shape=(), dtype=resource)