In [1]:
import tensorflow as tf

In [2]:
ssim_loss_weight = 0.085
l1_loss_weight = 0.01
edge_loss_weight = 0.09
loss_metric = tf.keras.metrics.Mean(name="loss")


def calculate_loss(target, pred):
    # Edges
    dy_true, dx_true = tf.image.image_gradients(target)
    dy_pred, dx_pred = tf.image.image_gradients(pred)
    weights_x = tf.exp(tf.reduce_mean(tf.abs(dx_true)))
    weights_y = tf.exp(tf.reduce_mean(tf.abs(dy_true)))

    # Depth smoothness
    smoothness_x = dx_pred * weights_x
    smoothness_y = dy_pred * weights_y

    depth_smoothness_loss = tf.reduce_mean(abs(smoothness_x)) + tf.reduce_mean(
        abs(smoothness_y)
    )

    # Structural similarity (SSIM) index
    ssim_loss = tf.reduce_mean(
        1
        - tf.image.ssim(
            target, pred, max_val=WIDTH, filter_size=7, k1=0.01 ** 2, k2=0.03 ** 2
        )
    )
    # Point-wise depth
    l1_loss = tf.reduce_mean(tf.abs(target - pred))

    loss = (
        (ssim_loss_weight * ssim_loss)
        + (l1_loss_weight * l1_loss)
        + (edge_loss_weight * depth_smoothness_loss)
    )

    return loss


@property
def metrics():
    return [loss_metric]


def train_step(self, batch_data):
    input, target = batch_data
    with tf.GradientTape() as tape:
        pred = self(input, training=True)
        loss = self.calculate_loss(target, pred)

    gradients = tape.gradient(loss, self.trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
    self.loss_metric.update_state(loss)
    return {
        "loss": self.loss_metric.result(),
    }


def test_step(self, batch_data):
    input, target = batch_data

    pred = self(input, training=False)
    loss = self.calculate_loss(target, pred)

    self.loss_metric.update_state(loss)
    return {
        "loss": self.loss_metric.result(),
    }

In [3]:
model = tf.keras.models.load_model('modelo.h5')

model.summary()

ValueError: Unknown loss function: calculate_loss. Please ensure this object is passed to the `custom_objects` argument. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.