In [1]:
import tensorflow as tf
from tensorflow import keras 
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

In [2]:
# Load MNIST as two (images, labels) pairs: the first for training, the second for testing.
[(x_train, y_train), (x_test, y_test)] = mnist.load_data()

In [3]:
def _flatten_and_scale(images):
    """Flatten 28x28 images into 784-feature rows and scale integer pixels into [0, 1].

    Idempotent: if ``images`` is no longer an integer array (i.e. it was already
    scaled by a previous run of this cell) it is only reshaped, so re-running the
    cell does not divide by 255 a second time.

    Args:
        images: array of images; integer pixel values are assumed to be 0-255.

    Returns:
        float32 array of shape (num_images, 784) when the input held integer pixels,
        otherwise the input reshaped to (num_images, 784) unchanged.
    """
    flat = images.reshape(-1, 28 * 28)
    if flat.dtype.kind not in ("u", "i"):
        # Already converted to float by an earlier run; don't rescale.
        return flat
    return flat.astype("float32") / 255.0


x_train = _flatten_and_scale(x_train)
x_test = _flatten_and_scale(x_test)

In [14]:
def get_uncompiled_model():
    """Build, but do not compile, a small fully-connected digit classifier.

    Architecture: 784-feature input named "digits" -> two 64-unit ReLU Dense
    layers ("dense_1", "dense_2") -> 10-unit softmax output named "predictions".

    Returns:
        An uncompiled ``keras.Model``; call ``compile`` before training.
    """
    digits = keras.Input(shape=(784,), name="digits")
    hidden = digits
    for layer_name in ("dense_1", "dense_2"):
        hidden = layers.Dense(64, activation="relu", name=layer_name)(hidden)
    predictions = layers.Dense(10, activation="softmax", name="predictions")(hidden)
    return keras.Model(inputs=digits, outputs=predictions)


def get_compiled_model():
    """Return a fresh model from ``get_uncompiled_model()``, compiled for integer labels.

    Uses the RMSprop optimizer with sparse categorical cross-entropy, so training
    targets are class indices rather than one-hot vectors; sparse categorical
    accuracy is reported as the metric.
    """
    compile_kwargs = dict(
        optimizer="rmsprop",
        loss="sparse_categorical_crossentropy",
        metrics=["sparse_categorical_accuracy"],
    )
    model = get_uncompiled_model()
    model.compile(**compile_kwargs)
    return model

In [15]:

    # Train a freshly compiled model for 4 epochs, then report [loss, accuracy] on the test split.
    model = get_compiled_model()
    model.fit(x_train, y_train, epochs=4, batch_size=32, verbose=1)
    model.evaluate(x_test, y_test, verbose=1, batch_size=32)
 

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


[0.10688374191522598, 0.9696999788284302]

In [18]:
# Compile with per-output dicts. Dict keys must be the model's *output* names:
# get_uncompiled_model() has a single output layer named "predictions", while
# "dense_1" and "dense_2" are hidden layers and cannot carry a loss or metric.
# The labels (y_train / y_test) are integer class ids, so the sparse variants of
# cross-entropy and accuracy are used instead of the one-hot Categorical* ones.
# NOTE(review): the `legacy` optimizer namespace does not exist in Keras 3; switch to
# keras.optimizers.RMSprop(learning_rate=1e-3) when upgrading.
model.compile(
    optimizer=tf.keras.optimizers.legacy.RMSprop(1e-3),
    loss={
        "predictions": keras.losses.SparseCategoricalCrossentropy(),
    },
    metrics={
        "predictions": [keras.metrics.SparseCategoricalAccuracy()],
    },
)

In [16]:
def custom_mean_squared_error(y_true, y_pred):
    """Mean squared error reduced over the last axis (one value per sample).

    Args:
        y_true: target tensor (one-hot labels in this notebook).
        y_pred: prediction tensor, broadcast-compatible with ``y_true``.

    Returns:
        Tensor equal to ``mean((y_true - y_pred) ** 2)`` taken along axis -1.
    """
    residual = y_true - y_pred
    squared_residual = tf.square(residual)
    return tf.math.reduce_mean(squared_residual, axis=-1)


# Train a fresh model using custom_mean_squared_error as the loss.
model = get_uncompiled_model()
model.compile(loss=custom_mean_squared_error, optimizer=keras.optimizers.Adam())

# MSE compares the 10-way softmax output against a target vector, so the integer
# labels are one-hot encoded (depth 10) before fitting.
y_train_one_hot = tf.one_hot(y_train, depth=10)
model.fit(x_train, y_train_one_hot, epochs=1, batch_size=64)





<keras.callbacks.History at 0x1667addd0>