# In [5]:
import tensorflow as tf

class Weighted_Linear(tf.keras.losses.Loss):
    """Squared error on channel 0, weighted linearly by the true value.

    Despite the name, the per-element penalty is quadratic in the error::

        (alpha * y_true[..., 0] + 1) * (y_true[..., 0] - y_pred[..., 0]) ** 2

    Only the weight is linear: it equals 1 where the target is 0 and grows
    with slope ``alpha`` as the target increases, so samples with larger
    targets contribute more. Only channel 0 of the last axis is used; any
    other channels of ``y_true`` / ``y_pred`` are ignored.

    Args:
        alpha: Slope of the per-element weight. ``alpha=0`` reduces the loss
            to plain squared error on channel 0.
        name: Name of this loss instance.
        reduction: Keras reduction applied to the per-sample losses returned
            by :meth:`call`.
    """

    # NOTE(review): ``tf.keras.losses.Reduction`` exists in tf.keras (Keras 2);
    # Keras 3 uses string reductions instead — confirm the installed version.
    def __init__(self, alpha=1.0, name="Weighted_Linear",
                 reduction=tf.keras.losses.Reduction.AUTO):
        super().__init__(name=name, reduction=reduction)
        self.alpha = alpha

    def call(self, y_true, y_pred):
        """Return the weighted squared error for each sample.

        Args:
            y_true: Target tensor; channel 0 of the last axis is used.
            y_pred: Prediction tensor with the same leading shape as
                ``y_true``; channel 0 of the last axis is used.

        Returns:
            A tensor with the last axis reduced away (shape
            ``y_true.shape[:-1]``), i.e. one loss value per sample. The Keras
            base class then applies ``sample_weight`` and ``reduction``.
        """
        y_pred = tf.convert_to_tensor(y_pred)
        # Align dtypes so integer / float64 targets don't break the arithmetic
        # against (typically float32) predictions.
        y_true = tf.cast(y_true, y_pred.dtype)

        # Slicing with 0:1 (rather than 0) keeps a trailing axis of size 1.
        y_true_rain = y_true[..., 0:1]
        y_pred_rain = y_pred[..., 0:1]

        error = y_true_rain - y_pred_rain
        # Per-element weight, linear in the target: 1 at y_true == 0.
        # (Deliberately not called `sample_weight` — that name belongs to the
        # separate per-sample weights Keras passes to Loss.__call__.)
        target_weight = self.alpha * y_true_rain + 1.0

        weighted_error = target_weight * tf.square(error)
        # Reduce the size-1 channel axis only. Reducing axis 0 here (the batch)
        # would collapse samples inside call(), breaking per-sample
        # sample_weight and making reduction=NONE return a single value.
        return tf.reduce_mean(weighted_error, axis=-1)

    def get_config(self):
        """Return the serializable config, including ``alpha``."""
        config = super().get_config()
        config.update({"alpha": self.alpha})
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild an instance from the dict produced by :meth:`get_config`."""
        return cls(**config)

