In [1]:
import tensorflow as tf

In [2]:
class LayerCanCreateStates(tf.keras.layers.Layer):
    """Wrap a layer so that calling it returns an `(output, states)` pair.

    Args:
        layer: the wrapped Keras layer; applied to the inputs in `call`.
        create_states: if truthy, `call` builds a fresh states list
            `[zeros, ones]`, each matching the wrapped layer's output;
            otherwise the returned states are `None`.
        **kwargs: forwarded to `tf.keras.layers.Layer.__init__`.
    """

    def __init__(self,
                 layer,
                 create_states,
                 **kwargs):
        super().__init__(**kwargs)
        self.layer = layer
        self.create_states = create_states

    def call(self, inputs):
        """Apply the wrapped layer and optionally create initial states.

        Returns:
            `(output, states)` where `output` is `self.layer(inputs)` and
            `states` is `[tf.zeros_like(output), tf.ones_like(output)]` when
            `create_states` is set, else `None`.
        """
        output = self.layer(inputs)

        if self.create_states:
            # `*_like` matches both the (dynamic) shape and the dtype of
            # `output`; there is no `tf.shapes` op in TensorFlow.
            states = [tf.zeros_like(output), tf.ones_like(output)]
        else:
            states = None

        return output, states
    
class SingleLayerPerceptron(tf.keras.Model):
    """Model built around one state-returning layer; only its output is exposed.

    Args:
        output_layer_with_states: a layer whose call returns an
            `(output, states)` pair, such as `LayerCanCreateStates`.
        **kwargs: forwarded to `tf.keras.Model.__init__`.
    """

    def __init__(self, output_layer_with_states, **kwargs):
        super().__init__(**kwargs)
        self.output_layer_with_states = output_layer_with_states

    def call(self, inputs):
        # The layer returns a pair; the states half is not used by this model.
        prediction, _unused_states = self.output_layer_with_states(inputs)
        return prediction

In [3]:
# Fix TensorFlow's global seed so the sampled inputs are reproducible under
# Restart Kernel -> Run All.
SEED = 42
tf.random.set_seed(SEED)

# 8 samples with a single feature each, drawn from a standard normal.
inputs = tf.random.normal([8, 1])
inputs

<tf.Tensor: shape=(8, 1), dtype=float32, numpy=
array([[ 0.05510553],
       [ 0.57184964],
       [-0.9665641 ],
       [-0.27213472],
       [ 1.1208997 ],
       [-2.276863  ],
       [ 0.80812895],
       [ 0.7111126 ]], dtype=float32)>

In [4]:
# Noise-free linear regression targets: y = 2x + 1.
outputs = 2 * inputs + 1
outputs

<tf.Tensor: shape=(8, 1), dtype=float32, numpy=
array([[ 1.110211  ],
       [ 2.1436992 ],
       [-0.93312824],
       [ 0.45573056],
       [ 3.2417994 ],
       [-3.5537262 ],
       [ 2.616258  ],
       [ 2.4222252 ]], dtype=float32)>

In [5]:
# One Dense(1) unit wrapped so that it never builds states.
model = SingleLayerPerceptron(
    LayerCanCreateStates(tf.keras.layers.Dense(1), create_states=False)
)

# Predictions of the untrained model (freshly initialised weights).
model.predict(inputs)

array([[-0.06077731],
       [-0.63070774],
       [ 1.0660485 ],
       [ 0.3001444 ],
       [-1.2362691 ],
       [ 2.511211  ],
       [-0.8913063 ],
       [-0.78430444]], dtype=float32)

In [6]:
# Mean-squared-error regression optimised with Adam.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.002),
    loss='mse',
)

# Train silently for many epochs on the 8 toy samples.
history = model.fit(inputs, outputs, epochs=5000, verbose=0)

# After training, the predictions should closely track `outputs`.
model.predict(inputs)

array([[ 1.11021   ],
       [ 2.1436942 ],
       [-0.93312156],
       [ 0.45573193],
       [ 3.24179   ],
       [-3.5537097 ],
       [ 2.616251  ],
       [ 2.422219  ]], dtype=float32)

In [21]:
class LayerCanTakeStates(tf.keras.layers.Layer):
    """Wrap a layer so it accepts a `states` value and forwards it unchanged.

    Args:
        layer: the wrapped Keras layer; applied to the inputs in `call`.
        **kwargs: forwarded to `tf.keras.layers.Layer.__init__`.
    """

    def __init__(self, layer, **kwargs):
        super().__init__(**kwargs)
        self.layer = layer

    def call(self, inputs, states=None):
        """Return `(self.layer(inputs), states)`; `states` is passed through as-is."""
        return self.layer(inputs), states
    
class IdentityLayer(tf.keras.layers.Layer):
    """A layer that returns its input unchanged.

    Its only purpose is to attach an `activity_regularizer` to a tensor. Keras
    applies a layer's activity regularizer to that layer's output; in tf.keras
    2.x the resulting penalty is also divided by the batch size.

    Args:
        activity_regularizer: regularizer applied to the (unchanged) output.
        **kwargs: forwarded to `tf.keras.layers.Layer.__init__`.
    """

    def __init__(self,
                 activity_regularizer,
                 **kwargs):
        # The base class stores the regularizer itself, so there is no need to
        # assign `self.activity_regularizer` again here.
        super().__init__(activity_regularizer=activity_regularizer, **kwargs)

    def call(self, inputs):
        return inputs

class MultiLayerPerceptron(tf.keras.Model):
    """Chain of state-passing layers followed by an identity regularization layer.

    The first layer is called with the inputs only and must return an
    `(output, states)` pair (e.g. `LayerCanCreateStates`). Every following
    layer is called as `layer(output, states)` and must return the same kind of
    pair (e.g. `LayerCanTakeStates`). The final states are discarded, and the
    last output passes through `IdentityLayer` so an activity regularizer can be
    applied to it.

    Args:
        layers_list: non-empty sequence of the layers described above.
        output_activity_regularizer: regularizer attached to the model output.
            Defaults to `tf.keras.regularizers.L1(0.)`, i.e. a zero penalty.
        **kwargs: forwarded to `tf.keras.Model.__init__`.

    Raises:
        ValueError: if `layers_list` is empty.
    """

    def __init__(self,
                 layers_list,
                 output_activity_regularizer=None,
                 **kwargs):
        # Fail early: `call` needs a first layer to create the states.
        if not layers_list:
            raise ValueError("layers_list must contain at least one layer.")
        super().__init__(**kwargs)
        self.layers_list = layers_list
        if output_activity_regularizer is None:
            output_activity_regularizer = tf.keras.regularizers.L1(0.)
        # Attribute name kept as-is so existing checkpoints still match.
        self.I_layer = IdentityLayer(output_activity_regularizer)

    def call(self, inputs):
        hidden, states = self.layers_list[0](inputs)
        for a_layer in self.layers_list[1:]:
            hidden, states = a_layer(hidden, states)
        return self.I_layer(hidden)

In [22]:
# Three chained Dense(1) units. The first one creates no states (so `None` is
# passed along) and the next two accept and forward them. Dense defaults to no
# activation, so the whole stack is still a linear map of the input.
model = MultiLayerPerceptron(
    [
        LayerCanCreateStates(tf.keras.layers.Dense(1), create_states=False),
        LayerCanTakeStates(tf.keras.layers.Dense(1)),
        LayerCanTakeStates(tf.keras.layers.Dense(1)),
    ]
)

# Predictions of the untrained model.
model.predict(inputs)

array([[ 0.03944154],
       [ 0.40929893],
       [-0.6918141 ],
       [-0.19477926],
       [ 0.8022792 ],
       [-1.629655  ],
       [ 0.57841486],
       [ 0.5089758 ]], dtype=float32)

In [23]:
# Same training recipe as for the single-layer model: MSE loss with Adam.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.002),
    loss='mse',
)

# Train silently for many epochs on the 8 toy samples.
history = model.fit(inputs, outputs, epochs=5000, verbose=0)

# After training, the predictions should closely track `outputs`.
model.predict(inputs)

array([[ 1.110211  ],
       [ 2.1436992 ],
       [-0.93312824],
       [ 0.45573053],
       [ 3.2417994 ],
       [-3.5537257 ],
       [ 2.6162577 ],
       [ 2.422225  ]], dtype=float32)