In [2]:
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

# To Avoid GPU errors
physical_devices = tf.config.list_physical_devices("GPU")
tf.config.experimental.set_memory_growth(physical_devices[0], True)

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, 28 * 28).astype("float32") / 255.0
x_test = x_test.reshape(-1, 28 * 28).astype("float32") / 255.0

In [3]:
model1 = keras.Sequential([layers.Dense(64, activation="relu"), layers.Dense(10)])

inputs = keras.Input(784)
x = layers.Dense(64, activation="relu")(inputs)
outputs = layers.Dense(10)(x)
model2 = keras.Model(inputs=inputs, outputs=outputs)

In [4]:
class MyModel(keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.dense1 = layers.Dense(64, activation="relu")
        self.dense2 = layers.Dense(10)

    def call(self, input_tensor):
        x = tf.nn.relu(self.dense1(input_tensor))
        return self.dense2(x)


# SavedModel format or HDF5 format
model3 = MyModel()
# model = keras.models.load_model('saved_model/')
# model.load_weights('checkpoint_folder/')


In [9]:
model = model1
model.load_weights('saved_model/')

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7f3a6b00fac0>

In [10]:
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(),
    metrics=["accuracy"],
)

In [None]:
model = keras.models.load_model('compl_model/')

In [11]:
model.fit(x_train, y_train, batch_size=32, epochs=2, verbose=2)
model.evaluate(x_test, y_test, batch_size=32, verbose=2)

Epoch 1/2
1875/1875 - 8s - loss: 0.1070 - accuracy: 0.9675 - 8s/epoch - 4ms/step
Epoch 2/2
1875/1875 - 7s - loss: 0.0827 - accuracy: 0.9751 - 7s/epoch - 4ms/step
313/313 - 1s - loss: 0.0991 - accuracy: 0.9696 - 638ms/epoch - 2ms/step


[0.09914275258779526, 0.9696000218391418]

In [12]:
model.save_weights('saved_model/')
model.save('compl_model/')

INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(784, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f3a24b35660>, 139888758548608), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(784, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f3a24b35660>, 139888758548608), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f3a24b35c60>, 139888291853872), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f3a24b35c60>, 139888291853872), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 10), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f3a24b88cd0>, 139888292179488), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 10), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f3a24b88cd0>, 139888292179488), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(10,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f3a24b89c60>, 139888696751280), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(10,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f3a24b89c60>, 139888696751280), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(784, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f3a24b35660>, 139888758548608), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(784, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f3a24b35660>, 139888758548608), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f3a24b35c60>, 139888291853872), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f3a24b35c60>, 139888291853872), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 10), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f3a24b88cd0>, 139888292179488), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 10), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f3a24b88cd0>, 139888292179488), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(10,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f3a24b89c60>, 139888696751280), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(10,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f3a24b89c60>, 139888696751280), {}).


INFO:tensorflow:Assets written to: compl_model/assets


INFO:tensorflow:Assets written to: compl_model/assets
