In [4]:
import tensorflow as tf
from tensorflow.keras import layers

In [None]:
class CustomActivation(layers.Layer):
    """Trainable piecewise-linear activation.

    Computes, element-wise::

        f(x) = x                 where x >= 0
        f(x) = alpha * x + beta  otherwise

    ``alpha`` and ``beta`` are single trainable weights of shape ``(1,)``,
    shared by every element and broadcast against ``inputs``.

    Args:
        alpha_initializer: Initializer for ``alpha`` (anything accepted by
            ``tf.keras.initializers.get``). Defaults to ``'ones'``.
        beta_initializer: Initializer for ``beta``. Defaults to ``'zeros'``.
            With both defaults the layer starts out as the identity.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, alpha_initializer='ones', beta_initializer='zeros', **kwargs):
        super().__init__(**kwargs)
        self.alpha_initializer = tf.keras.initializers.get(alpha_initializer)
        self.beta_initializer = tf.keras.initializers.get(beta_initializer)

    def build(self, input_shape):
        """Create the trainable ``alpha`` and ``beta`` weights (shape ``(1,)`` each)."""
        self.alpha = self.add_weight(
            name='alpha',
            shape=(1,),
            initializer=self.alpha_initializer,
            trainable=True
        )
        self.beta = self.add_weight(
            name='beta',
            shape=(1,),
            initializer=self.beta_initializer,
            trainable=True
        )
        super().build(input_shape)

    def call(self, inputs):
        """Apply the activation element-wise to ``inputs``."""
        # tf.where evaluates both branches and picks one per element.
        negative_branch = self.alpha * inputs + self.beta
        return tf.where(tf.greater_equal(inputs, 0), inputs, negative_branch)

    def get_config(self):
        """Return the constructor arguments so Keras can re-create this layer.

        Without this override, saving a model that contains the layer and
        reloading it cannot restore non-default initializers. The initializers
        are serialized so the config stays plain data.
        """
        config = super().get_config()
        config.update({
            'alpha_initializer': tf.keras.initializers.serialize(self.alpha_initializer),
            'beta_initializer': tf.keras.initializers.serialize(self.beta_initializer),
        })
        return config


In [3]:
inputs = tf.constant([[-1.0, 0.0, 1.0], [2.0, -3.0, 4.0]], dtype=tf.float32)
custom_activation = CustomActivation(alpha_initializer='ones', beta_initializer='zeros')
outputs = custom_activation(inputs)
# With alpha=1 and beta=0 the negative branch is 1 * x + 0 == x, so the layer
# must return its input unchanged. Fail loudly if that invariant is broken.
tf.debugging.assert_equal(outputs, inputs, message='identity-initialized CustomActivation changed its input')
print(outputs.numpy())

[[-1.  0.  1.]
 [ 2. -3.  4.]]
