In [None]:
!pip install tensorflow_datasets

In [None]:
import os

import mlflow
import mlflow.keras
import mlflow.pyfunc
import mlflow.sklearn
import numpy as np
import plotly.express as px
import tensorflow as tf
import tensorflow_datasets as tfds
from mlflow.models.signature import infer_signature

In [None]:
# Load the MNIST train/test splits as (image, label) pairs together with the dataset metadata.
(ds_train, ds_test), ds_info = tfds.load(
    name='mnist',
    split=['train', 'test'],
    as_supervised=True,  # yield (image, label) tuples rather than feature dicts
    shuffle_files=True,
    with_info=True,      # also return the DatasetInfo (used for the train split size)
)

In [None]:
def normalize_img(image, label):
    """Cast an image to `float32` and divide by 255; the label is passed through unchanged.

    For `uint8` input this maps pixel values from 0..255 into [0, 1].
    """
    scaled = tf.cast(image, tf.float32) / 255.0
    return scaled, label

# Training pipeline: normalize -> cache -> shuffle over the whole split -> batch -> prefetch.
# NOTE: this rebinds ds_train, so re-running this cell without first re-running the
# tfds.load cell applies the map/batch steps a second time.
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(128)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [None]:
# Evaluation pipeline: no shuffling. Caching happens after batching because the
# batches are identical from epoch to epoch.
ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(128)
    .cache()
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [None]:
# The tracking server can be overridden via MLFLOW_TRACKING_URI without editing the notebook;
# the fallback is the local server this notebook was written against.
MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:4040")
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
mlflow.set_experiment("tf-mnist-sample")

In [None]:
# Close any run left active by an earlier, partially executed pass of the notebook.
# Otherwise start_run() raises because a run is already active. end_run() does nothing
# when no run is active.
mlflow.end_run()
mlflow.start_run()

In [None]:
SEED = 42
LEARNING_RATE = 0.001
EPOCHS = 6
HIDDEN_UNITS = 128

# Seed TF's global RNG so that weight initialisation is repeatable across runs.
tf.random.set_seed(SEED)

model = tf.keras.models.Sequential([
    # The pipelines yield images with a trailing channel axis of size 1, so declare
    # (28, 28, 1). Declaring (28, 28) mismatches the data's rank.
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(HIDDEN_UNITS, activation='relu'),
    # Raw logits for the 10 classes; the loss applies softmax itself (from_logits=True).
    tf.keras.layers.Dense(10),
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Record the hyper-parameters on the MLflow run so that runs can be compared.
mlflow.log_params({
    "seed": SEED,
    "learning_rate": LEARNING_RATE,
    "epochs": EPOCHS,
    "hidden_units": HIDDEN_UNITS,
})

history = model.fit(
    ds_train,
    epochs=EPOCHS,
    validation_data=ds_test,
)

# Log the per-epoch training/validation curves recorded by Keras as MLflow metrics.
for step, epoch in enumerate(history.epoch):
    mlflow.log_metrics(
        {name: float(values[step]) for name, values in history.history.items()},
        step=epoch,
    )

In [None]:
# Take the first test batch; compare the true labels with the model's argmax predictions.
images, labels = next(iter(ds_test.take(1)))
pred = np.argmax(model.predict(images), axis=1)
# Drop the size-1 channel axis so that individual images can be plotted as 2-D arrays.
img = np.array(images[:, :, :, 0])
print(labels)
print(pred)

In [None]:
# Render the pixel intensities in grayscale (not plotly's default colour scale), with a title.
px.imshow(img[1], color_continuous_scale="gray", title="First test batch, image #1")

In [None]:
# Example request payload for the wrapper: the first two 2-D test images under key "images".
model_input = {"images": img[:2]}

In [None]:
class mnistModel(mlflow.pyfunc.PythonModel):
    """Wrapper that rebuilds a Keras ``Sequential`` model from its config and weights.

    Only ``config`` and ``weights`` need to survive serialization. The live Keras model in
    ``self.model`` is excluded from the pickled state and rebuilt lazily on the next
    ``predict`` call.
    """

    def __init__(self, config, weights):
        """
        Args:
            config: result of ``Sequential.get_config()`` for the trained model.
            weights: result of ``Model.get_weights()`` for the trained model.
        """
        self.config = config
        self.weights = weights
        self.init_model()

    def __getstate__(self):
        # Pickle everything except the live Keras model; predict() rebuilds it from
        # config/weights. Copying the dict leaves the in-memory instance usable.
        state = self.__dict__.copy()
        state["model"] = None
        return state

    def init_model(self):
        """(Re)build ``self.model`` from the stored config and weights."""
        self.model = tf.keras.Sequential.from_config(self.config)
        self.model.set_weights(self.weights)

    def predict(self, data):
        """Classify a batch of digit images.

        NOTE(review): this takes only the input, not ``(context, model_input)`` as
        ``PythonModel.predict`` normally does. That is presumably why the notebook logs the
        model through the sklearn flavor. Confirm before switching to mlflow.pyfunc.log_model.

        Args:
            data: mapping whose ``"images"`` entry is an array-like of 2-D images
                (e.g. shape (n, 28, 28)); a trailing channel axis is added here.

        Returns:
            dict with key ``"numbers"``: the argmax class index for each image.
        """
        # getattr: also covers instances restored from a state that lacks "model".
        if getattr(self, "model", None) is None:
            self.init_model()

        input_img = np.expand_dims(np.array(data["images"]), axis=-1)
        values = self.model.predict(input_img)

        return {"numbers": np.argmax(values, axis=1)}
    
# Wrap the trained Keras model by value (architecture config + weight arrays).
pymodel = mnistModel(config=model.get_config(), weights=model.get_weights())

In [None]:
# Sanity check: the wrapper must reproduce the Keras model's predictions for the same two
# test images (model_input holds the first two images of the batch behind `pred`).
wrapper_prediction = pymodel.predict(model_input)
np.testing.assert_array_equal(wrapper_prediction["numbers"], pred[:2])
wrapper_prediction

In [None]:
# Derive the MLflow input/output schema from an example request and the wrapper's response.
sample_output = pymodel.predict(model_input)
signature = infer_signature(model_input, sample_output)

In [None]:
# Clear the live Keras model so that only config + weights are serialized;
# pymodel.predict() rebuilds it lazily afterwards.
pymodel.model = None
# NOTE(review): a pyfunc.PythonModel is logged through the *sklearn* flavor here. That flavor
# presumably pickles the object and calls its one-argument predict(data) at inference time.
# mlflow.pyfunc.log_model(python_model=...) is the documented route, but it expects
# predict(context, model_input). Confirm before switching.
mlflow.sklearn.log_model(
    pymodel,
    artifact_path="mnistModel",
    signature=signature,
    input_example=model_input,
    registered_model_name="tf-mnist-sample",
)

In [None]:
# Close the run opened earlier in the notebook, marking it as successfully finished.
mlflow.end_run(status="FINISHED")