In [None]:
import tensorflow as tf
import numpy as np
from typing import Union

class GlucoseSpecificLoss(tf.keras.losses.Loss):
    """
    TensorFlow/Keras implementation of SE and gMSE loss functions from:
    "A Glucose-Specific Metric to Assess Predictors and Identify Models"

    Parameters follow the paper's specifications:
    - T_L = 85 mg/dL, T_H = 155 mg/dL
    - alpha_L = 1.5, alpha_H = 1.0
    - beta_L = 30, beta_H = 25
    - gamma_L = 30, gamma_H = 25

    Used as a Keras loss (``model.compile(loss=GlucoseSpecificLoss())``),
    ``call`` returns one gSE value per sample (mean over the last axis) and
    the Keras base class applies ``reduction`` and ``sample_weight``.
    Use :meth:`gmse` for the scalar gMSE over all elements.

    Inputs are cast to the floating dtype of the predictions (float32 if the
    predictions are not floating point), so float64 inputs are supported.
    """

    def __init__(self, name="glucose_specific_loss", **kwargs):
        super().__init__(name=name, **kwargs)

        # Threshold parameters
        self.T_L = tf.constant(85.0, dtype=tf.float32)  # mg/dL
        self.T_H = tf.constant(155.0, dtype=tf.float32)  # mg/dL

        # Penalty amplitude parameters
        self.alpha_L = tf.constant(1.5, dtype=tf.float32)
        self.alpha_H = tf.constant(1.0, dtype=tf.float32)

        # Transition parameters for glucose thresholds
        self.beta_L = tf.constant(30.0, dtype=tf.float32)
        self.beta_H = tf.constant(25.0, dtype=tf.float32)

        # Transition parameters for estimation errors
        self.gamma_L = tf.constant(30.0, dtype=tf.float32)
        self.gamma_H = tf.constant(25.0, dtype=tf.float32)

    @staticmethod
    def _align_dtypes(g, g_hat):
        """Return ``(g, g_hat)`` as tensors sharing ``g_hat``'s float dtype.

        Mirrors the Keras convention of casting ``y_true`` to ``y_pred``'s
        dtype; non-floating predictions are promoted to float32.
        """
        g_hat = tf.convert_to_tensor(g_hat)
        if not g_hat.dtype.is_floating:
            g_hat = tf.cast(g_hat, tf.float32)
        return tf.cast(g, g_hat.dtype), g_hat

    def _sigmoid_like(self, x: tf.Tensor, a: tf.Tensor, epsilon: tf.Tensor,
                      direction: str = "right") -> tf.Tensor:
        """
        Smooth piecewise-quartic step function as defined in Appendix A.

        Args:
            x: Input values.
            a: Threshold value (scalar or broadcastable against ``x``).
            epsilon: Transition duration (width of the smooth ramp).
            direction: "right" for sigma_{x>=a,eps} (0 for x <= a, rising to
                1 at x = a + eps) or "left" for sigma-bar_{x<=a,eps} (1 for
                x <= a - eps, falling to 0 at x = a).

        Returns:
            Tensor with the broadcast shape of ``x`` and ``a`` and the dtype
            of ``x``, with values in [0, 1].

        Raises:
            ValueError: If ``direction`` is not "right" or "left".
        """
        if direction not in ("right", "left"):
            raise ValueError(
                f"direction must be 'right' or 'left', got {direction!r}")

        x = tf.convert_to_tensor(x)
        # Cast parameters to x's dtype so float64 inputs do not clash with
        # the float32 constants stored on the instance.
        a = tf.cast(a, x.dtype)
        epsilon = tf.cast(epsilon, x.dtype)
        half = epsilon / 2.0

        if direction == "right":
            # sigma_{x>=a,eps}(x)
            xi = (2.0 / epsilon) * (x - a - half)
            rising_lower = (x > a) & (x <= a + half)
            rising_upper = (x > a + half) & (x <= a + epsilon)
            saturated = x > a + epsilon

            result = tf.zeros_like(x)  # covers x <= a
            result = tf.where(rising_lower,
                              -0.5 * tf.pow(xi, 4) - tf.pow(xi, 3) + xi + 0.5,
                              result)
            result = tf.where(rising_upper,
                              0.5 * tf.pow(xi, 4) - tf.pow(xi, 3) + xi + 0.5,
                              result)
            result = tf.where(saturated, tf.ones_like(result), result)
        else:
            # sigma-bar_{x<=a,eps}(x)
            xi_bar = -(2.0 / epsilon) * (x - a + half)
            saturated = x <= a - epsilon
            falling_upper = (x > a - epsilon) & (x <= a - half)
            falling_lower = (x > a - half) & (x <= a)

            result = tf.zeros_like(x)  # covers x > a
            result = tf.where(saturated, tf.ones_like(result), result)
            result = tf.where(falling_upper,
                              0.5 * tf.pow(xi_bar, 4) - tf.pow(xi_bar, 3) + xi_bar + 0.5,
                              result)
            result = tf.where(falling_lower,
                              -0.5 * tf.pow(xi_bar, 4) - tf.pow(xi_bar, 3) + xi_bar + 0.5,
                              result)

        return result

    def penalty_function(self, g: tf.Tensor, g_hat: tf.Tensor) -> tf.Tensor:
        """
        Penalty function Pen(g, g_hat) as defined in equation (2).

        Args:
            g: True glucose values.
            g_hat: Predicted glucose values.

        Returns:
            Element-wise penalty values (always >= 1: 1 plus two non-negative
            terms), in ``g_hat``'s float dtype.
        """
        g, g_hat = self._align_dtypes(g, g_hat)
        dtype = g_hat.dtype
        alpha_L = tf.cast(self.alpha_L, dtype)
        alpha_H = tf.cast(self.alpha_H, dtype)

        # Term for hypoglycemia overestimation (Zone D1): g below T_L and
        # g_hat above g, both smoothed.
        term_L = (alpha_L
                  * self._sigmoid_like(g, self.T_L, self.beta_L, "left")
                  * self._sigmoid_like(g_hat, g, self.gamma_L, "right"))

        # Term for hyperglycemia underestimation (Zone D2): g above T_H and
        # g_hat below g, both smoothed.
        term_H = (alpha_H
                  * self._sigmoid_like(g, self.T_H, self.beta_H, "right")
                  * self._sigmoid_like(g_hat, g, self.gamma_H, "left"))

        return 1.0 + term_L + term_H

    def se(self, g: tf.Tensor, g_hat: tf.Tensor) -> tf.Tensor:
        """
        Squared Error (SE) loss.

        Args:
            g: True glucose values.
            g_hat: Predicted glucose values.

        Returns:
            Element-wise squared error values.
        """
        g, g_hat = self._align_dtypes(g, g_hat)
        return tf.square(g - g_hat)

    def gse(self, g: tf.Tensor, g_hat: tf.Tensor) -> tf.Tensor:
        """
        Glucose-Specific Squared Error (gSE) loss.

        Args:
            g: True glucose values.
            g_hat: Predicted glucose values.

        Returns:
            Element-wise SE multiplied by the clinical penalty.
        """
        return self.se(g, g_hat) * self.penalty_function(g, g_hat)

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Per-sample gSE for the Keras ``Loss`` machinery.

        Keras expects ``call`` to return per-sample losses; the base class
        then applies ``reduction`` and ``sample_weight``. With the default
        reduction the final value equals the gMSE over all elements.

        Args:
            y_true: True glucose values.
            y_pred: Predicted glucose values.

        Returns:
            gSE averaged over the last axis (scalar inputs pass through).
        """
        gse_loss = self.gse(y_true, y_pred)
        if gse_loss.shape.rank == 0:
            return gse_loss
        return tf.reduce_mean(gse_loss, axis=-1)

    def mse(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Mean Squared Error (MSE) over all elements.

        Args:
            y_true: True glucose values.
            y_pred: Predicted glucose values.

        Returns:
            Scalar mean squared error.
        """
        return tf.reduce_mean(self.se(y_true, y_pred))

    def gmse(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Glucose-Specific Mean Squared Error (gMSE) over all elements.

        Args:
            y_true: True glucose values.
            y_pred: Predicted glucose values.

        Returns:
            Scalar glucose-specific mean squared error.
        """
        return tf.reduce_mean(self.gse(y_true, y_pred))

# Functional API version for direct use
def clinical_penalty_loss(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
    """
    Compute the element-wise clinical penalty Pen(y_true, y_pred).

    Despite the name, this returns the penalty multiplier (1 plus the
    hypo/hyper zone terms) rather than a squared-error loss; multiply it by
    the squared error to obtain gSE.

    Args:
        y_true: True glucose values
        y_pred: Predicted glucose values

    Returns:
        Penalty values, one per element
    """
    return GlucoseSpecificLoss().penalty_function(y_true, y_pred)

def gmse_loss(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
    """
    Stateless gMSE loss suitable for ``model.compile(loss=...)``.

    Builds a default-configured GlucoseSpecificLoss and delegates to its
    ``gmse`` method.

    Args:
        y_true: True glucose values
        y_pred: Predicted glucose values

    Returns:
        gMSE loss value
    """
    return GlucoseSpecificLoss().gmse(y_true, y_pred)


def se_loss(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
    """
    Stateless element-wise squared error.

    Delegates to ``GlucoseSpecificLoss.se``; no reduction is applied.

    Args:
        y_true: True glucose values
        y_pred: Predicted glucose values

    Returns:
        SE loss value for each element
    """
    return GlucoseSpecificLoss().se(y_true, y_pred)


# Example usage with Keras model
def create_sample_model(input_shape):
    """
    Create a small fully connected regressor for glucose prediction.

    Args:
        input_shape: Shape of a single input sample, excluding the batch
            dimension (e.g. ``(n_features,)``).

    Returns:
        An uncompiled ``tf.keras.Sequential`` model ending in one linear unit.
    """
    return tf.keras.Sequential([
        # Declare the input explicitly: passing `input_shape=` to a layer
        # is deprecated in Keras 3 and emits a warning.
        tf.keras.Input(shape=input_shape),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(32, activation='relu'),
        tf.keras.layers.Dense(1, activation='linear'),  # Glucose prediction
    ])


# Example usage
if __name__ == "__main__":
    # Seed NumPy (synthetic data) and TensorFlow (weight init, dropout,
    # shuffling) so the demo is reproducible.
    np.random.seed(42)
    tf.random.set_seed(42)
    n_samples = 1000
    n_features = 5

    # Simulated features (could be past glucose values, insulin doses, etc.)
    X = np.random.normal(0, 1, (n_samples, n_features))
    # Simulated glucose values (50-300 mg/dL range)
    y = 100 + 50 * np.sin(np.arange(n_samples) * 0.1) + np.random.normal(0, 15, n_samples)
    y = np.clip(y, 50, 300)  # Clip to physiological range

    # Create model
    model = create_sample_model((n_features,))

    # Compile with gMSE loss
    model.compile(
        optimizer='adam',
        loss=gmse_loss,  # Using functional version
        metrics=['mse']
    )

    print("Model compiled with gMSE loss")
    # summary() prints itself and returns None; wrapping it in print()
    # would only emit a stray "None".
    model.summary()

    # Test the loss functions directly
    print("\nTesting loss functions:")

    # Create some test scenarios
    g_true_test = tf.constant([50.0, 120.0, 250.0], dtype=tf.float32)  # hypo, eu, hyper
    g_pred_test = tf.constant([80.0, 150.0, 180.0], dtype=tf.float32)  # over, error, under

    loss_fn = GlucoseSpecificLoss()

    # Calculate different losses
    mse_vals = loss_fn.mse(g_true_test, g_pred_test)
    gmse_vals = loss_fn.gmse(g_true_test, g_pred_test)
    penalty_vals = loss_fn.penalty_function(g_true_test, g_pred_test)

    print(f"True values: {g_true_test.numpy()}")
    print(f"Pred values: {g_pred_test.numpy()}")
    print(f"Penalties: {penalty_vals.numpy()}")
    print(f"MSE: {mse_vals.numpy():.2f}")
    print(f"gMSE: {gmse_vals.numpy():.2f}")
    print(f"Ratio (gMSE/MSE): {gmse_vals.numpy()/mse_vals.numpy():.2f}")

    # Demonstrate training (optional)
    print("\nDemonstrating training...")

    # Split data
    split_idx = int(0.8 * n_samples)
    X_train, X_test = X[:split_idx], X[split_idx:]
    y_train, y_test = y[:split_idx], y[split_idx:]

    # Train with gMSE
    history = model.fit(
        X_train, y_train,
        epochs=10,
        batch_size=32,
        validation_data=(X_test, y_test),
        verbose=1
    )

    # Compare with MSE model
    print("\nComparing with MSE model:")
    model_mse = create_sample_model((n_features,))
    model_mse.compile(
        optimizer='adam',
        loss='mse',
        metrics=['mse']
    )

    history_mse = model_mse.fit(
        X_train, y_train,
        epochs=10,
        batch_size=32,
        validation_data=(X_test, y_test),
        verbose=0
    )

    print("Training completed!")
    # These two losses are on different scales (gMSE >= MSE because the
    # penalty is >= 1), so they are not directly comparable.
    print(f"Final gMSE validation loss: {history.history['val_loss'][-1]:.2f}")
    print(f"Final MSE validation loss: {history_mse.history['val_loss'][-1]:.2f}")

    # Like-for-like comparison: score both models with both metrics on the
    # held-out split.
    y_test_tf = tf.constant(y_test, dtype=tf.float32)
    for label, trained in (("gMSE-trained", model), ("MSE-trained", model_mse)):
        y_hat = tf.constant(trained.predict(X_test, verbose=0).reshape(-1),
                            dtype=tf.float32)
        print(f"{label} model on validation data: "
              f"MSE={loss_fn.mse(y_test_tf, y_hat).numpy():.2f}, "
              f"gMSE={loss_fn.gmse(y_test_tf, y_hat).numpy():.2f}")

Model compiled with gMSE loss
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 64)                384       
                                                                 
 dropout (Dropout)           (None, 64)                0         
                                                                 
 dense_1 (Dense)             (None, 32)                2080      
                                                                 
 dense_2 (Dense)             (None, 1)                 33        
                                                                 
Total params: 2,497
Trainable params: 2,497
Non-trainable params: 0
_________________________________________________________________
None

Testing loss functions:
True values: [ 50. 120. 250.]
Pred values: [ 80. 150. 180.]
Penalties: [2.5 1.  2. ]
MSE: 2233.33
gMSE: 4316.67
Ratio (gMSE/MS