In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import tensorflow as tf

In [3]:
# Preprocessing recipe

class SimplePreprocessor(tf.Module):
    
    def __init__(self):
        self.norm = tf.keras.layers.Normalization()

    def fit(self, data):
        self.norm.adapt(data)

    @tf.function(input_signature=[tf.TensorSpec([None, 1], tf.float32)])
    def train_fn(self, examples):
        return {
            "normalized_features": self.norm(examples)
        }

    @tf.function(input_signature=[tf.TensorSpec([None, 1], tf.float32)])
    def serving_fn(self, examples):
        return {
            "normalized_features": self.norm(examples)
        }


p = SimplePreprocessor()
ds = tf.data.Dataset.range(100).batch(5, drop_remainder=True).map(lambda x: tf.cast(tf.reshape(x, [-1, 1]), tf.float32))
for x in ds.take(1):
    print(x.shape)
p.fit(ds)

tf.saved_model.save(p, "./models/simple_preprocessor", signatures={"serving_default": p.serving_fn, "train_default": p.train_fn})
loaded = tf.saved_model.load("./models/simple_preprocessor")
loaded.signatures["serving_default"]
norm_features = loaded.signatures["serving_default"](tf.constant([[49.5]]))["normalized_features"].numpy()[0][0]
print(f"norm_features: {norm_features:0.4f}")

(5, 1)
INFO:tensorflow:Assets written to: ./models/simple_preprocessor/assets
norm_features: 0.0000


# Preprocessing 

In [4]:
import input_pipeline as ip

### low level implementation with preprocessing

In [5]:

p = ip.Preprocessor()
ds = tf.data.TFRecordDataset(["data/example_gen/test.tfrecord"])

parsed_ds = ds.map(lambda x: tf.io.parse_single_example(x, ip.TRAIN_SCHEMA)).batch(16)
for x in parsed_ds.take(1):
    print({k: v.shape for k, v in x.items()})
p.fit(parsed_ds)

{'Amount': TensorShape([16]), 'Class': TensorShape([16]), 'V1': TensorShape([16]), 'V10': TensorShape([16]), 'V11': TensorShape([16]), 'V12': TensorShape([16]), 'V13': TensorShape([16]), 'V14': TensorShape([16]), 'V15': TensorShape([16]), 'V16': TensorShape([16]), 'V17': TensorShape([16]), 'V18': TensorShape([16]), 'V19': TensorShape([16]), 'V2': TensorShape([16]), 'V20': TensorShape([16]), 'V21': TensorShape([16]), 'V22': TensorShape([16]), 'V23': TensorShape([16]), 'V24': TensorShape([16]), 'V25': TensorShape([16]), 'V26': TensorShape([16]), 'V27': TensorShape([16]), 'V28': TensorShape([16]), 'V3': TensorShape([16]), 'V4': TensorShape([16]), 'V5': TensorShape([16]), 'V6': TensorShape([16]), 'V7': TensorShape([16]), 'V8': TensorShape([16]), 'V9': TensorShape([16])}


In [6]:
for x, y in parsed_ds.map(p.preprocessing_fn).take(1):
    print(x.shape)
    print(y.shape)

(16, 29)
(16,)


### Leverage build-in dataset capabilities

In [7]:
src = "data/example_gen/train.tfrecord"
ds = tf.data.experimental.make_batched_features_dataset(
        file_pattern=src,
        batch_size=64,
        features=ip.TRAIN_SCHEMA,
        label_key=ip.LABEL_KEY,
        reader=tf.data.TFRecordDataset,
        shuffle_buffer_size=10000,
        shuffle_seed=42,
        num_epochs=1,
        prefetch_buffer_size=1000,
        reader_num_threads=8,
        parser_num_threads=8,
        drop_final_batch=True,
    )

# fit the preprocessor
p.fit(ds.map(lambda x, _: x)) 

for x, y in ds.map(lambda x, y: (p.serving_fn(x), tf.reshape(y, (-1, 1)))).take(1).as_numpy_iterator():
    print(x.shape)
    print(y.shape)
    # print({k: v.shape for k, v in x.items()})
    # print(y.shape)


Instructions for updating:
Use `tf.data.Dataset.map(tf.io.parse_example(...))` instead.


Instructions for updating:
Use `tf.data.Dataset.map(tf.io.parse_example(...))` instead.


(64, 29)
(64, 1)


### get datasets and preprocessing

In [8]:
p = ip.Preprocessor()
assert not p.norm.is_adapted
train_ds, val_ds = ip.get_datasets(
    preprocessor=p,
    train_src="data/example_gen/train.tfrecord",
    val_src="data/example_gen/test.tfrecord",
)

# validate the preprocessor is adapted to the training data
assert p.norm.is_adapted

In [9]:
for x, y  in train_ds:
    print(x.shape)
    print(y.shape)
    break

(64, 29)
(64, 1)


In [10]:
# model building with flax, clu

In [11]:
import flax
from flax.training import train_state
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
import orbax
import clu

In [12]:
class CreditCardFraudModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x

def init_model(rng, input_shape):
    model = CreditCardFraudModel()
    params = model.init(rng, jnp.ones(input_shape, jnp.float32))
    return model, params

def create_train_state(rng, input_shape, learning_rate=1e-3):
    """Creates initial `TrainState`."""
    model, params = init_model(rng, input_shape)
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx
    )

batch_size = 64
n_features = 29
input_shape = (batch_size, n_features)
rng = jax.random.PRNGKey(42)
state = create_train_state(rng, input_shape)
    

In [13]:
train_ds, val_ds = ip.get_datasets(
    preprocessor=p,
    train_src="data/example_gen/train.tfrecord",
    val_src="data/example_gen/test.tfrecord",
)
for x, y in train_ds:
    print(x.shape)
    print(y.shape)
    # check if y is all 0s
    if jnp.all(y == 0):
        continue

    print(x)
    break


(64, 29)
(64, 1)
[[ 1.5086735  -1.2670166  -1.2634001  ... -0.0185016  -0.06934007
   1.3445339 ]
 [ 2.7657025  -0.8945842   0.03694475 ...  0.16129413  0.11783955
  -0.9416372 ]
 [-3.7643168  -3.4835062  -0.24297953 ... -0.5262222   0.3545389
   1.5926738 ]
 ...
 [ 3.9149642   0.85186887  0.39583114 ... -0.5226201   0.03370051
  -0.59020585]
 [ 1.7115073   0.9279888  -0.43129742 ... -0.14537239 -0.29199368
   1.0793302 ]
 [ 0.65398043  0.8988843  -0.40111452 ... -0.45063296 -0.16948257
   0.9866318 ]]


In [14]:
from clu import metrics

@flax.struct.dataclass
class Precision(metrics.Metric):
  """Computes the precision from model outputs `logits` and `labels`."""

  true_positives: jnp.array
  pred_positives: jnp.array

  @classmethod
  def from_model_output(cls, *, logits: jnp.array, labels: jnp.array,
                        **_) -> metrics.Metric:
    assert logits.shape[-1] == 2, "Expected binary logits."
    preds = logits.argmax(axis=-1)
    return cls(
        true_positives=((preds == 1) & (labels == 1)).sum(),
        pred_positives=(preds == 1).sum(),
    )

  def merge(self, other: metrics.Metric) -> metrics.Metric:
    # Note that for precision we cannot average metric values because the
    # denominator of the metric value is pred_positives and not every batch of
    # examples has the same number of pred_positives (as opposed to e.g.
    # accuracy where every batch has the same number of)
    return type(self)(
        true_positives=self.true_positives + other.true_positives,
        pred_positives=self.pred_positives + other.pred_positives,
    )

  def compute(self):
    return self.true_positives / self.pred_positives

  def empty(self):
    return type(self)(true_positives=0, pred_positives=0)

  @classmethod
  def empty(cls) -> "Precision":
    return cls(true_positives=0, pred_positives=0)

@flax.struct.dataclass  # <-- required for JAX transformations
class MetricCollection(metrics.Collection):
  loss : metrics.Average.from_output('loss')
  accuracy : metrics.Accuracy
  precision: Precision

metric_collection = MetricCollection.empty()

@jax.jit
def train_step(state: train_state.TrainState, x, y, train_metrics):
  """Train for a single step."""
  def loss_fn(params):
    logits = state.apply_fn(
        params,
      x=x)
    loss = optax.sigmoid_binary_cross_entropy(logits, y)
    loss = jnp.mean(loss)
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  logits = jnp.concatenate([1 - logits, logits], axis=-1)

  return state, train_metrics.merge(MetricCollection.single_from_model_output(
    loss=loss,
    labels=y.squeeze(),
    logits=logits,
))
  

@jax.jit
def eval_step(state: train_state.TrainState, x, y, eval_metrics):
  """Evaluates `state` on `x` and `y`."""
  logits = state.apply_fn(
      state.params,
      x=x)
  loss = optax.sigmoid_binary_cross_entropy(logits, y)
  loss = jnp.mean(loss)

  logits = jnp.concatenate([1 - logits, logits], axis=-1)

  

  return eval_metrics.merge(MetricCollection.single_from_model_output(
    loss=loss,
    labels=y.squeeze(),
    logits=logits,
))
  
  

  

In [15]:
from clu import periodic_actions

class TensorboardCallback(periodic_actions.PeriodicCallback):
    
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @staticmethod
    def write_metrics(step: int, t:float, *, writer, train_metrics, eval_metrics):
        writer.write_scalars(step, {f"train/{k}": v for k, v in train_metrics.compute().items()})
        writer.write_scalars(step, {f"eval/{k}": v for k, v in eval_metrics.compute().items()})
        
class ReportProgress(periodic_actions.ReportProgress):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, step, t = None, **kwargs):
        return super().__call__(step, t)
    
    def _apply(self, step, t, **kwargs):
        super()._apply(step, t)

class Profile(periodic_actions.Profile):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, step, t = None, **kwargs):
        return super().__call__(step, t)
    
    def _apply(self, step, t, **kwargs):
        super()._apply(step, t)
    

In [16]:
from clu import metric_writers
from clu import periodic_actions
from absl import logging

logging.set_verbosity(logging.WARNING)
# create random 1 and 0 labels with shape (batch_size, 1)
batch_x = jax.random.normal(rng, shape=(batch_size, n_features))
batch_y = jax.random.randint(rng, shape=(batch_size, 1), minval=0, maxval=2)


logdir = "./logs"
n_epochs = 10
n_batches_per_epoch = 1000
total_steps = n_epochs * n_batches_per_epoch
writer = metric_writers.create_default_writer(logdir)


hooks = [
    # Outputs progress via metric writer (in this case logs & TensorBoard).
    ReportProgress(
        num_train_steps=total_steps,
        every_steps=n_batches_per_epoch, writer=writer),
    Profile(logdir=logdir),
    TensorboardCallback(callback_fn=TensorboardCallback.write_metrics, every_steps=n_batches_per_epoch),
]

state = create_train_state(rng, input_shape)
train_metrics = metric_collection.empty()
eval_metrics = metric_collection.empty()


for step in range(total_steps):
    state, train_metrics = train_step(state, batch_x, batch_y, train_metrics)
    eval_metrics = eval_step(state, batch_x, batch_y, eval_metrics)
    for hook in hooks:
        hook(step, writer=writer, train_metrics=train_metrics, eval_metrics=eval_metrics)
    
    if not step % n_batches_per_epoch:
        train_metrics = metric_collection.empty()
        eval_metrics = metric_collection.empty()
        
# for epoch in range(10):
#     train_metrics = metric_collection.empty()
#     eval_metrics = metric_collection.empty()
#     for b_update in range(1000):
#         state, train_metrics = train_step(state, batch_x, batch_y, train_metrics)
#         eval_metrics = eval_step(state, batch_x, batch_y, eval_metrics)
        # show wrong predictions
print(f"step {step}")
print(train_metrics.compute(), eval_metrics.compute())




step 9999
{'loss': Array(1.07910274e-07, dtype=float32), 'accuracy': Array(1., dtype=float32), 'precision': Array(1., dtype=float32)} {'loss': Array(1.0784902e-07, dtype=float32), 'accuracy': Array(1., dtype=float32), 'precision': Array(1., dtype=float32)}


In [17]:
%load_ext tensorboard
%tensorboard --logdir=./logs

In [18]:

# path absolute path to "./models"
from pathlib import Path

path = Path("./models/checkpoints/")
model_dir = path.absolute()

In [19]:
from typing import Tuple
import orbax.checkpoint as ocp

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow as tf
from flax.training.train_state import TrainState
from orbax.export import ExportManager, JaxModule, ServingConfig
import datetime
import os


def create_manager(model_dir):
    options = ocp.CheckpointManagerOptions(
        max_to_keep=3,
        save_interval_steps=2,
        create=True
    )

    mngr = ocp.CheckpointManager(
        model_dir,
        ocp.PyTreeCheckpointer(),
        options=options
    )
    return mngr

def restore_or_create_state(mngr, rng, input_shape, reinit=False):
    if mngr.latest_step() is None or reinit:
        return create_train_state(rng, input_shape)
    target = {
        "model": create_train_state(rng, input_shape)
    }
    restored_state = mngr.restore(mngr.latest_step(), items=target)["model"]
    return restored_state
   

def to_saved_model(state, preprocessing_fn, output_dir, etr=None, model_name="creditcard"):
    # Construct a JaxModule where JAX->TF conversion happens.

    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    jax_module = JaxModule(
        state.params, 
        state.apply_fn, 
        trainable=False,
        jit_compile=False,
        jax2tf_kwargs={
            "enable_xla": False
        },
        input_polymorphic_shape='(b, ...)')
    # Export the JaxModule along with one or more serving configs.
    export_mgr = ExportManager(
        jax_module,
        [
            ServingConfig(
                "serving_default",
                tf_preprocessor=preprocessing_fn,
                # tf_postprocessor=exampe1_postprocess
                extra_trackable_resources=etr,
            ),
        ],
    )
    export_mgr.save(os.path.join(output_dir, model_name, timestamp))
 

In [20]:

p = ip.Preprocessor()
train_ds, val_ds = ip.get_datasets(
    preprocessor=p,
    train_src="data/example_gen/train.tfrecord",
    val_src="data/example_gen/test.tfrecord",
)

writer = metric_writers.create_default_writer(logdir)
hooks = [
    # Outputs progress via metric writer (in this case logs & TensorBoard).
    ReportProgress(
        num_train_steps=total_steps,
        every_steps=n_batches_per_epoch, writer=writer),
    Profile(logdir=logdir),
    TensorboardCallback(callback_fn=TensorboardCallback.write_metrics, every_steps=n_batches_per_epoch),
]

n_train_steps = 1000
n_eval_staps = 100

mngr = create_manager(model_dir)
state = restore_or_create_state(mngr, rng, input_shape)
saved_model_dir = "./models/saved_model"

n_steps_taken = 0
for epoch in range(10):
    train_metrics = metric_collection.empty()
    eval_metrics = metric_collection.empty()
    for step in range(n_train_steps):
        x, y = next(train_ds)
        state, train_metrics = train_step(state, x, y, train_metrics)

    for step in range(n_eval_staps):
        x, y = next(val_ds)
        eval_metrics = eval_step(state, x, y, eval_metrics)
        for hook in hooks:
            hook(n_steps_taken, writer=writer, train_metrics=train_metrics, eval_metrics=eval_metrics)

        mngr.save(step, {"model": state})
        n_steps_taken += 1


print(
    train_metrics.compute(),
    eval_metrics.compute(), 
)
restored_state = restore_or_create_state(mngr, rng, input_shape)

eval_m = metric_collection.empty()
for _ in range(100):
    x, y = next(val_ds)
    eval_m = eval_step(restored_state, x, y, eval_m)

print(eval_m.compute())

to_saved_model(restored_state, p.serving_fn, saved_model_dir, etr={"preprocessor": p.norm})




ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/stefruinard/Documents/personal/projects/202312_probabilistic_deep_learning/augmented_deep_learning/venv/lib/python3.9/site-packages/clu/periodic_actions.py", line 359, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/stefruinard/Documents/personal/projects/202312_probabilistic_deep_learning/augmented_deep_learning/venv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/stefruinard/Documents/personal/projects/202312_probabilistic_deep_learning/augmented_deep_learning/venv/lib/python3.9/site-packages/jax/_src/profiler.py", line 118, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.


{'loss': Array(0.0021319, dtype=float32), 'accuracy': Array(0.99951565, dtype=float32), 'precision': Array(0.9056604, dtype=float32)} {'loss': Array(0.00121001, dtype=float32), 'accuracy': Array(0.99953127, dtype=float32), 'precision': Array(1., dtype=float32)}
{'loss': Array(0.00489452, dtype=float32), 'accuracy': Array(0.99921876, dtype=float32), 'precision': Array(1., dtype=float32)}
INFO:tensorflow:Assets written to: ./models/saved_model/creditcard/20231220133420/assets


INFO:tensorflow:Assets written to: ./models/saved_model/creditcard/20231220133420/assets


In [21]:
raw_ds = tf.data.TFRecordDataset(["data/example_gen/test.tfrecord"])
parsed_ds = raw_ds.map(lambda x: tf.io.parse_single_example(x, ip.SERVING_SCHEMA)).batch(16)
batch = next(iter(parsed_ds))

In [22]:
saved_model_dir

'./models/saved_model'

In [23]:
# get last credit card model in saved_model_dir/models
import os
import glob
import tensorflow as tf

creditcard_model_dir = os.path.join(saved_model_dir, "creditcard")
versions = [int(v) for v in os.listdir(creditcard_model_dir)]
latest_version = max(versions)
latest_model = os.path.join(creditcard_model_dir, str(latest_version))


In [24]:

predictions = tf.saved_model.load(latest_model).signatures["serving_default"](**batch)["output_0"]
print(predictions)

tf.Tensor(
[[ -9.401173 ]
 [ -8.420352 ]
 [ -8.5709095]
 [ -9.2276325]
 [ -7.2797523]
 [ -8.968629 ]
 [ -9.036577 ]
 [ -8.611546 ]
 [ -8.864519 ]
 [ -6.95437  ]
 [ -7.7919803]
 [ -9.988268 ]
 [ -9.603817 ]
 [-11.05954  ]
 [-11.539463 ]
 [-10.359598 ]], shape=(16, 1), dtype=float32)


In [25]:
batch["Amount"].numpy().reshape(-1, 1).tolist()

[[50.0],
 [14.949999809265137],
 [7.699999809265137],
 [6.989999771118164],
 [460.7099914550781],
 [68.0],
 [56.310001373291016],
 [30.520000457763672],
 [19.989999771118164],
 [40.22999954223633],
 [10.0],
 [9.989999771118164],
 [505.92999267578125],
 [33.369998931884766],
 [126.0],
 [385.9800109863281]]

In [26]:
serving_batch = []
parsed_ds = raw_ds.map(lambda x: tf.io.parse_single_example(x, ip.SERVING_SCHEMA)).batch(16)
serving_batch = [{"Amount": x["Amount"].numpy().reshape(-1, 1).tolist()} for x in parsed_ds.take(1)]
serving_batch


[{'Amount': [[50.0],
   [14.949999809265137],
   [7.699999809265137],
   [6.989999771118164],
   [460.7099914550781],
   [68.0],
   [56.310001373291016],
   [30.520000457763672],
   [19.989999771118164],
   [40.22999954223633],
   [10.0],
   [9.989999771118164],
   [505.92999267578125],
   [33.369998931884766],
   [126.0],
   [385.9800109863281]]}]

In [32]:
import json
data = json.dumps(
    {
        "signature_name": "serving_default",
        "instances": [
            
            {k: v.numpy().tolist()[0] for k, v in batch.items()},
            {k: v.numpy().tolist()[0] for k, v in batch.items()}
        ]
    }
)
model_name = "creditcard"
url_sig = f"http://localhost:8501/v1/models/{model_name}:predict"

In [33]:
import requests
import json

headers = {"content-type": "application/json"}
json_response = requests.post(url_sig, data=data, headers=headers)
print(json.loads(json_response.text))

{'predictions': [[-9.40117264], [-9.40117264]]}


In [34]:
import json
data = json.dumps(
    {
        "signature_name": "serving_default",
        "inputs": {k: v.numpy().tolist() for k, v in batch.items()},
    }
)
model_name = "creditcard"
url_sig = f"http://localhost:8501/v1/models/{model_name}:predict"

In [35]:
import requests
import json

headers = {"content-type": "application/json"}
json_response = requests.post(url_sig, data=data, headers=headers)
print(json.loads(json_response.text))

{'outputs': [[-9.40117264], [-8.42035294], [-8.5709095], [-9.22763252], [-7.27975225], [-8.96863], [-9.03657722], [-8.61154556], [-8.86451912], [-6.9543705], [-7.79198074], [-9.9882679], [-9.60381699], [-11.0595407], [-11.539463], [-10.3595982]]}


In [36]:
-

SyntaxError: invalid syntax (476313318.py, line 1)

In [None]:
!pwd

In [None]:
docker run  -p 8501:8501 --name creditcard --xla_cpu_compilation_enabled --mount type=bind,source=/Users/stefruinard/Documents/personal/projects/202312_probabilistic_deep_learning/augmented_deep_learning/models/saved_model/creditcard,target=/models/creditcard -e MODEL_NAME=creditcard -t emacski/tensorflow-serving 

In [None]:
NAME = "flax-recommender-system"

from typing import Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow as tf
from flax.training.train_state import TrainState
from orbax.export import ExportManager, JaxModule, ServingConfig



# p = ip.Preprocessor()

def save(state, preprocessing_fn, output_dir, etr=None):
    # Construct a JaxModule where JAX->TF conversion happens.
    jax_module = JaxModule(state.params, state.apply_fn)
    # Export the JaxModule along with one or more serving configs.
    export_mgr = ExportManager(
        jax_module,
        [
            ServingConfig(
                "serving_default",
                tf_preprocessor=preprocessing_fn,
                # tf_postprocessor=exampe1_postprocess
                extra_trackable_resources=etr,
            ),
        ],
    )
    export_mgr.save(output_dir)


def load(output_dir):
    loaded_model = tf.saved_model.load(output_dir)
    return loaded_model


def inference(loaded_model, inputs):
    loaded_model_outputs = loaded_model(inputs)
    return loaded_model_outputs


In [None]:
save(state, p.serving_fn, "./models/flax_recommender_system")

In [None]:
loaded_model = load("./models/flax_recommender_system")

In [None]:
example = {"inputs":{
            "V1": tf.constant([1], dtype=tf.float32),
            "V2": tf.constant([1], dtype=tf.float32),
            "V3": tf.constant([1], dtype=tf.float32),
            "V4": tf.constant([1], dtype=tf.float32),
            "V5": tf.constant([1], dtype=tf.float32),
            "V6": tf.constant([1], dtype=tf.float32),
            "V7": tf.constant([1], dtype=tf.float32),
            "V8": tf.constant([1], dtype=tf.float32),
            "V9": tf.constant([1], dtype=tf.float32),
            "V10": tf.constant([1], dtype=tf.float32),
            "V11": tf.constant([1], dtype=tf.float32),
            "V12": tf.constant([1], dtype=tf.float32),
            "V13": tf.constant([1], dtype=tf.float32),
            "V14": tf.constant([1], dtype=tf.float32),
            "V15": tf.constant([1], dtype=tf.float32),
            "V16": tf.constant([1], dtype=tf.float32),
            "V17": tf.constant([1], dtype=tf.float32),
            "V18": tf.constant([1], dtype=tf.float32),
            "V19": tf.constant([1], dtype=tf.float32),
            "V20": tf.constant([1], dtype=tf.float32),
            "V21": tf.constant([1], dtype=tf.float32),
            "V22": tf.constant([1], dtype=tf.float32),
            "V23": tf.constant([1], dtype=tf.float32),
            "V24": tf.constant([1], dtype=tf.float32),
            "V25": tf.constant([1], dtype=tf.float32),
            "V26": tf.constant([1], dtype=tf.float32),
            "V27": tf.constant([1], dtype=tf.float32),
            "V28": tf.constant([1], dtype=tf.float32),
            "Amount": tf.constant([1], dtype=tf.float32),
        }}


# create batch of examples
batch_size = 64
n_features = 29

batch_inputs = {
    "V1": tf.random.normal((batch_size, )),
    "V2": tf.random.normal((batch_size, )),
    "V3": tf.random.normal((batch_size, )),
    "V4": tf.random.normal((batch_size, )),
    "V5": tf.random.normal((batch_size, )),
    "V6": tf.random.normal((batch_size, )),
    "V7": tf.random.normal((batch_size, )),
    "V8": tf.random.normal((batch_size, )),
    "V9": tf.random.normal((batch_size, )),
    "V10": tf.random.normal((batch_size,)),
    "V11": tf.random.normal((batch_size,)),
    "V12": tf.random.normal((batch_size,)),
    "V13": tf.random.normal((batch_size,)),
    "V14": tf.random.normal((batch_size,)),
    "V15": tf.random.normal((batch_size,)),
    "V16": tf.random.normal((batch_size,)),
    "V17": tf.random.normal((batch_size,)),
    "V18": tf.random.normal((batch_size,)),
    "V19": tf.random.normal((batch_size,)),
    "V20": tf.random.normal((batch_size,)),
    "V21": tf.random.normal((batch_size,)),
    "V22": tf.random.normal((batch_size,)),
    "V23": tf.random.normal((batch_size,)),
    "V24": tf.random.normal((batch_size,)),
    "V25": tf.random.normal((batch_size,)),
    "V26": tf.random.normal((batch_size,)),
    "V27": tf.random.normal((batch_size,)),
    "V28": tf.random.normal((batch_size,)),
    # amount is random int
    "Amount": tf.random.uniform((batch_size,), minval=0, maxval=1000, dtype=tf.float32),
}

In [None]:
loaded_model.signatures["serving_default"](**example["inputs"])

In [None]:
ds = tf.data.Dataset.range(10).batch(2).take(3).repeat(2).as_numpy_iterator()
for _ in range(2):
    
    for x in ds:
        print(x)

In [None]:
# Copyright 2023 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Export a MNIST JAX model.

python flax_mnist_main.py --output_dir=<OUTPUT_DIR>
"""
from absl import app
from absl import flags
from absl import logging

import flax.linen as nn
import jax
import jax.numpy as jnp
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf


batch_size = None
current_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
model_path = f"./models/saved_model/creditcard/{current_time}"



class JaxMnist(nn.Module):
  """Mnist model."""

  @nn.compact
  def __call__(self, x):
    """See base class."""
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x


def export_mnist() -> None:
  """Exports a Mnist JAX model."""

  # Initialize the model.
  model = JaxMnist()
  params = model.init(jax.random.PRNGKey(123), jnp.ones((1, 28, 28, 1)))

  # Wrap the model params and function into a JaxModule.
  jax_module = JaxModule(
      params,
      model.apply,
      trainable=False,
    jit_compile=False,
       jax2tf_kwargs={
          "enable_xla": False
      },
      input_polymorphic_shape='(b, ...)' if batch_size is None else None)

  # Specify the serving configuration and export the model.
  em = ExportManager(jax_module, [
      ServingConfig(
          'serving_default',
          input_signature=[
              tf.TensorSpec([batch_size, 28, 28, 1], tf.float32, name='inputs')
          ],
          tf_postprocessor=lambda x: dict(outputs=x)),
  ])
  # Save the model.
  logging.info('Exporting the model to %s.', model_path)
  em.save(model_path)

  # Test that the saved model could be loaded and run.
  logging.info('Loading the model from %s.', model_path)
  loaded = tf.saved_model.load(model_path)
  logging.info('Loaded the model from %s.', model_path)

  inputs = jnp.ones((batch_size or 1, 28, 28, 1))
  savedmodel_output = loaded.signatures['serving_default'](inputs=inputs)
  jax_output = model.apply(params, inputs)

  logging.info('Savemodel output: %s, JAX output: %s', savedmodel_output,
               jax_output)



In [None]:
export_mnist()

In [None]:

url = "http://localhost:8501/v1/models/creditcard:predict"
image = jnp.ones((1, 28, 28, 1))
data = json.dumps({"signature_name": "serving_default", "instances": image.tolist()})
data

In [None]:
requests.post(url, data=data).json()

In [None]:
docker run -t --rm -p 8501:8501 --mount type=bind,source=/tmp/model_name/,target=/models/model_name/ -e MODEL_NAME=model_name emacski/tensorflow-serving:latest-linux_arm64