In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import time

import jax
import tensorflow as tf
from clu import metric_writers

import input_pipeline as ip
import metrics
import serving
import train
from configs import default as cfgs

In [None]:
# Preprocessing recipe

class SimplePreprocessor(tf.Module):
    """Toy preprocessing module built around a Keras ``Normalization`` layer.

    The layer is adapted through ``fit`` and then applied by two traced entry
    points, ``train_fn`` and ``serving_fn``. Both accept a float32 tensor of
    shape ``[None, 1]`` and return a dict with one ``"normalized_features"``
    entry.
    """

    def __init__(self, name=None):
        """Create the module and its (not yet adapted) normalization layer.

        Args:
            name: Optional module name, forwarded to ``tf.Module``.
        """
        # tf.Module initialises its `name` / `name_scope` state in its own
        # constructor, so an overriding __init__ has to chain to it.
        super().__init__(name=name)
        self.norm = tf.keras.layers.Normalization()

    def fit(self, data):
        """Adapt the normalization layer to ``data``.

        Args:
            data: Input passed straight to ``Normalization.adapt``.
        """
        self.norm.adapt(data)

    @tf.function(input_signature=[tf.TensorSpec([None, 1], tf.float32)])
    def train_fn(self, examples):
        """Normalize ``examples`` for the training path.

        Args:
            examples: float32 tensor of shape ``[None, 1]``.

        Returns:
            Dict mapping ``"normalized_features"`` to the layer output.
        """
        return {
            "normalized_features": self.norm(examples)
        }

    @tf.function(input_signature=[tf.TensorSpec([None, 1], tf.float32)])
    def serving_fn(self, examples):
        """Normalize ``examples`` for the serving path.

        Args:
            examples: float32 tensor of shape ``[None, 1]``.

        Returns:
            Dict mapping ``"normalized_features"`` to the layer output.
        """
        return {
            "normalized_features": self.norm(examples)
        }


# Fit the toy preprocessor on the integers 0..99, fed as float32 batches of
# shape [5, 1], then round-trip it through a SavedModel.
SIMPLE_PREPROCESSOR_DIR = "./artifacts/models/simple_preprocessor"

# Descriptive names here avoid clobbering the `p` / `x` pipeline objects that
# later cells define.
simple_preprocessor = SimplePreprocessor()
ds = (
    tf.data.Dataset.range(100)
    .batch(5, drop_remainder=True)
    .map(lambda x: tf.cast(tf.reshape(x, [-1, 1]), tf.float32))
)
for sample_batch in ds.take(1):
    print(sample_batch.shape)
simple_preprocessor.fit(ds)

# Export both traced entry points under explicit signature names.
tf.saved_model.save(
    simple_preprocessor,
    SIMPLE_PREPROCESSOR_DIR,
    signatures={
        "serving_default": simple_preprocessor.serving_fn,
        "train_default": simple_preprocessor.train_fn,
    },
)
loaded = tf.saved_model.load(SIMPLE_PREPROCESSOR_DIR)
serving_signature = loaded.signatures["serving_default"]

# 49.5 is the mean of 0..99, so the normalized value is expected to be close to 0.
norm_features = serving_signature(tf.constant([[49.5]]))["normalized_features"].numpy()[0][0]
print(f"norm_features: {norm_features:0.4f}")

# full pipeline

In [95]:
# Build the default config, the project preprocessor and the two datasets.
cfg = cfgs.get_config()
p = ip.Preprocessor()
datasets = ip.get_datasets(preprocessor=p, train_src=cfg.train_src, val_src=cfg.val_src)
train_ds, val_ds = datasets

# Pull one training batch; `x.shape` is passed to the state constructor in the
# training cells further down.
first_batch = next(iter(train_ds))
x, y = first_batch
print(f"x: {x.shape}\ny: {y.shape}\n")

In [None]:

# Set up the RNG key, metric collection, writer and hooks, then restore or
# create the train state and run a first pass of the loop.
rng = jax.random.PRNGKey(cfg.seed)
metric_collection = metrics.MetricCollection.empty()
writer = metric_writers.create_default_writer(cfg.logdir)

steps_per_epoch = cfg.n_steps_per_epoch

# Outputs progress via metric writer (in this case logs & TensorBoard).
progress_hook = metrics.ReportProgress(
    num_train_steps=steps_per_epoch * cfg.n_epochs,
    every_steps=steps_per_epoch,
    writer=writer,
)
profile_hook = metrics.Profile(logdir=cfg.logdir)
tensorboard_hook = metrics.TensorboardCallback(
    callback_fn=metrics.TensorboardCallback.write_metrics,
    every_steps=steps_per_epoch,
)
hooks = [progress_hook, profile_hook, tensorboard_hook]

mngr = train.create_manager(cfg.checkpoint_dir)
state = train.restore_or_create_state(mngr, rng, x.shape)

# The step counter starts at zero and is threaded through every run_loop call.
global_step = 0
first_pass = train.run_loop(
    state=state,
    train_ds=train_ds,
    val_ds=val_ds,
    n_train_steps=steps_per_epoch,
    n_eval_steps=steps_per_epoch // 10,
    global_step=global_step,
    hooks=hooks,
    writer=writer,
    mngr=mngr,
    metric_collection=metric_collection,
)
state, train_metrics, eval_metrics, global_step = first_pass

In [None]:
# Run cfg.n_epochs passes of the loop, printing the computed metrics after each.
# NOTE(review): the setup cell above already ran one pass, so the total becomes
# n_epochs + 1 while ReportProgress was sized for n_epochs passes; confirm
# this is intended.
for _ in range(cfg.n_epochs):
    state, train_metrics, eval_metrics, global_step = train.run_loop(
        state=state,
        train_ds=train_ds,
        val_ds=val_ds,
        n_train_steps=cfg.n_steps_per_epoch,
        n_eval_steps=cfg.n_steps_per_epoch // 10,
        global_step=global_step,
        hooks=hooks,
        writer=writer,
        mngr=mngr,
        metric_collection=metric_collection
    )

    print(f"train_metrics: {train_metrics.compute()}")
    print(f"eval_metrics: {eval_metrics.compute()}")

# The writer is never flushed elsewhere in this notebook; flush it so pending
# metric records are written out before the export step.
writer.flush()

# Fetch state through the checkpoint manager and export it together with the
# preprocessor's serving function.
best_model_state = train.restore_or_create_state(mngr, rng, x.shape)
train.to_saved_model(
    best_model_state, p.serving_fn, cfg.model_serving_dir, etr={"preprocessor": p.norm}
)
    
    

In [None]:
# Call the packaged train_and_eval entry point with a freshly built default config.
e2e_cfg = cfgs.get_config()
train.train_and_eval(e2e_cfg)

In [99]:

# Point TF Serving at <model_serving_dir>/<model_name>.
served_model_path = os.path.join(cfg.model_serving_dir, cfg.model_name)
serving.run_tf_serving(served_model_path)
# give some time for the server to start
startup_grace_seconds = 5
time.sleep(startup_grace_seconds)



In [None]:
# Send the sample raw batch shipped with the serving module and print the reply.
request_kwargs = {
    "model_name": cfg.model_name,
    "batch": serving.raw_batch,
}
preds = serving.predict_with_docker(**request_kwargs)
print(preds)
    