In [18]:
import tensorflow as tf
import numpy as np
from tensorflow import keras as ks, float32

class DomainNormalization(ks.layers.Layer):
    # Normalizes a feature map according to the procedure presented by
    # Zhang et.al. in "Domain-invariant stereo matching networks".
    # TODO: as a function
    def __init__(self, regularizer_weight=0.0004, *args, **kwargs):
        super(DomainNormalization, self).__init__(*args, **kwargs)
        self.regularizer_weight = regularizer_weight

    def build(self, input_shape):
        channels = input_shape[-1]

        self.scale = self.add_weight(name="scale", shape=[1, 1, 1, channels],
                                     dtype='float32',
                                     initializer=tf.ones_initializer(),
                                     trainable=True)
        self.bias = self.add_weight(name="bias", shape=[1, 1, 1, channels],
                                    dtype='float32',
                                    initializer=tf.zeros_initializer(),
                                    trainable=True)

        # Add regularization loss on the scale factor
        regularizer = tf.keras.regularizers.L2(self.regularizer_weight)
        self.add_loss(regularizer(self.scale))

    def call(self, f_map):
        mean = tf.math.reduce_mean(f_map, axis=[1, 2], keepdims=True,
                                   name=None)
        var = tf.math.reduce_variance(f_map, axis=[1, 2], keepdims=True,
                                      name=None)
        normed = tf.math.l2_normalize((f_map - mean) / (var + 1e-12), axis=-1)
        print(self.scale.shape)
        print(self.bias.shape)
        return self.scale * normed + self.bias

    def get_config(self):
        config = super().get_config()
        config.update({
            "regularizer_weight": self.regularizer_weight,
        })
        return config

regularizer_weight=0.0004
x_in= ks.Input(shape=(384, 384, 16,), dtype=float32)  # data image
x = DomainNormalization(regularizer_weight=regularizer_weight)(x_in)
model = ks.Model(inputs=[x_in], outputs=[x])
model.save(
    "DomainNormalization.h5",
    save_format="h5",
)
# model.summary()

def domain_normalization_as_a_function(f_map, regularizer_weight):
    # Normalizes a feature map according to the procedure presented by
    # Zhang et.al. in "Domain-invariant stereo matching networks".
    # Also see DomainNormalization in former implem of M4Depth.
    # ks.layers.Layer.add_loss(regularizer(scale))

    mean = tf.math.reduce_mean(f_map, axis=[1, 2], keepdims=True, name=None)
    var = tf.math.reduce_variance(f_map, axis=[1, 2], keepdims=True, name=None)
    normed = tf.math.l2_normalize((f_map - mean) / (var + 1e-12), axis=-1)
    print('normed', normed.shape)
    channels = f_map.shape[-1]
    to_ret =  ks.layers.Dense([1, 1, 1, channels],
                              use_bias=True,
                              kernel_initializer=tf.ones_initializer(),
                              bias_initializer=tf.zeros_initializer(),
                              kernel_regularizer=tf.keras.regularizers.L2(regularizer_weight))\
        (normed)
    return to_ret

print(x_in.shape)
x2 = domain_normalization_as_a_function(x_in, regularizer_weight)
print(x2.shape)
model2 = ks.Model(inputs=[x_in], outputs=[x2])
model2.save(
    "domain_normalization_as_a_function.h5",
    save_format="h5",
)

(1, 1, 1, 16)
(1, 1, 1, 16)
(None, 384, 384, 16)
normed (None, 384, 384, 16)


TypeError: int() argument must be a string, a bytes-like object or a number, not 'list'