In [5]:
import tensorflow as tf

class HyperXActivationTF(tf.keras.layers.Layer):
    """Keras layer computing ``x * tanh(k * x)`` element-wise, hardened against non-finite inputs.

    NaN inputs are replaced by 0 and +/-Inf inputs by +/-``large_value`` before the
    activation is applied; any NaN remaining in the result is replaced by 0.

    Args:
        k: multiplier applied to the input inside ``tanh``.
        large_value: magnitude substituted for +/-Inf inputs. It is clamped to the
            largest finite value of the input dtype so the substitute cannot itself
            overflow to Inf in narrow dtypes (1e10 exceeds the float16 maximum of
            65504, e.g. under a ``mixed_float16`` policy).
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, k=1.0, large_value=1e10, **kwargs):
        super(HyperXActivationTF, self).__init__(**kwargs)
        self.k = k
        self.large_value = large_value

    def call(self, inputs):
        # Replace NaN values with 0.
        inputs = tf.where(tf.math.is_nan(inputs), tf.zeros_like(inputs), inputs)
        # Replace +/-Inf with a large finite value of the same sign, never exceeding
        # what the input dtype can represent (otherwise float16 turns it back into Inf).
        big = min(float(self.large_value), float(tf.as_dtype(inputs.dtype).max))
        inputs = tf.where(tf.math.is_inf(inputs), tf.sign(inputs) * big, inputs)
        # Apply the activation: x * tanh(k * x).
        result = inputs * tf.tanh(self.k * inputs)
        # Inputs are finite here, so NaN can only arise from e.g. a non-finite k
        # (inf * 0); map it to 0 to keep the output NaN-free.
        return tf.where(tf.math.is_nan(result), tf.zeros_like(result), result)

    def get_config(self):
        """Return the constructor arguments so the layer survives model save/load."""
        config = super(HyperXActivationTF, self).get_config()
        config.update({"k": self.k, "large_value": self.large_value})
        return config


In [6]:
import torch
import torch.nn as nn

class HyperXActivationTorch(nn.Module):
    """Element-wise ``x * tanh(k * x)`` activation, hardened against non-finite inputs.

    NaN inputs become 0 and +/-Inf inputs become +/-``large_value`` before the
    activation is applied; any NaN remaining in the result is replaced by 0.

    Args:
        k: multiplier applied to the input inside ``tanh``.
        large_value: magnitude substituted for +/-Inf inputs. For floating inputs it
            is clamped to ``torch.finfo(x.dtype).max`` so the substitute cannot
            overflow back to Inf (1e10 is ``inf`` in float16).
    """

    def __init__(self, k=1.0, large_value=1e10):
        super(HyperXActivationTorch, self).__init__()
        self.k = k
        self.large_value = large_value

    def forward(self, x):
        big = float(self.large_value)
        if x.is_floating_point():
            # Never exceed the largest finite value representable in x's dtype.
            big = min(big, torch.finfo(x.dtype).max)
        # NaN -> 0, +Inf -> +big, -Inf -> -big in a single pass.
        x = torch.nan_to_num(x, nan=0.0, posinf=big, neginf=-big)
        # Apply the activation: x * tanh(k * x).
        result = x * torch.tanh(self.k * x)
        # x is finite here, so NaN can only come from e.g. a non-finite k (inf * 0).
        return torch.where(torch.isnan(result), torch.zeros_like(result), result)

    def extra_repr(self):
        # Shown by print(module) / repr(module).
        return f"k={self.k}, large_value={self.large_value}"


In [7]:
import numpy as np

# TensorFlow Test
def test_tensorflow_hyperx():
    """Check HyperXActivationTF (k=1) on +/-Inf, 0 and NaN inputs.

    Expected: +/-Inf -> +/-1e10 -> 1e10 (x * tanh(x) is even), 0 -> 0, NaN -> 0 -> 0.
    """
    activation = HyperXActivationTF(k=1.0)
    x = tf.constant([[float("inf"), -float("inf")], [0.0, float("nan")]])
    result = activation(x)
    print("TensorFlow Input:\n", x.numpy())
    print("TensorFlow Output:\n", result.numpy())
    assert not tf.reduce_any(tf.math.is_nan(result)), "NaN values in TensorFlow output!"
    assert not tf.reduce_any(tf.math.is_inf(result)), "Inf values in TensorFlow output!"
    # Finiteness alone would accept any finite output; pin the expected values too.
    np.testing.assert_allclose(result.numpy(), [[1e10, 1e10], [0.0, 0.0]])

# PyTorch Test
def test_pytorch_hyperx():
    """Check HyperXActivationTorch (k=1) on +/-Inf, 0 and NaN inputs.

    Expected: +/-Inf -> +/-1e10 -> 1e10 (x * tanh(x) is even), 0 -> 0, NaN -> 0 -> 0.
    """
    activation = HyperXActivationTorch(k=1.0)
    x = torch.tensor([[float("inf"), -float("inf")], [0.0, float("nan")]])
    result = activation(x)
    print("PyTorch Input:\n", x)
    print("PyTorch Output:\n", result)
    assert not torch.any(torch.isnan(result)), "NaN values in PyTorch output!"
    assert not torch.any(torch.isinf(result)), "Inf values in PyTorch output!"
    # Finiteness alone would accept any finite output; pin the expected values too.
    np.testing.assert_allclose(result.detach().numpy(), [[1e10, 1e10], [0.0, 0.0]])

# Run Tests
# Execute the framework checks in order; the first failing assertion stops the cell.
for run_check in (test_tensorflow_hyperx, test_pytorch_hyperx):
    run_check()


TensorFlow Input:
 [[ inf -inf]
 [  0.  nan]]
TensorFlow Output:
 [[1.e+10 1.e+10]
 [0.e+00 0.e+00]]
PyTorch Input:
 tensor([[inf, -inf],
        [0., nan]])
PyTorch Output:
 tensor([[1.0000e+10, 1.0000e+10],
        [0.0000e+00, 0.0000e+00]])
