Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions Algorithms.Tests/Numeric/ReluTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
using Algorithms.Numeric;
using NUnit.Framework;
using System;

namespace Algorithms.Tests.Numeric;

[TestFixture]
public static class ReluTests
{
// Tolerance for floating-point comparisons
private const double Tolerance = 1e-9;

// --- SCALAR TESTS (Relu.Compute(double)) ---

[TestCase(0.0, 0.0)]
[TestCase(1.0, 1.0)]
[TestCase(-1.0, 0.0)]
[TestCase(5.0, 5.0)]
[TestCase(-5.0, 0.0)]
public static void ReluFunction_Scalar_ReturnsCorrectValue(double input, double expected)
{
var result = Relu.Compute(input);
Assert.That(result, Is.EqualTo(expected).Within(Tolerance));
}

[Test]
public static void ReluFunction_Scalar_HandlesLimitsAndNaN()
{
// Positive infinity stays +Infinity, negative infinity becomes 0, NaN propagates
Assert.That(RelUComputePositiveInfinity(), Is.EqualTo(double.PositiveInfinity));
Assert.That(RelUComputeNegativeInfinity(), Is.EqualTo(0.0).Within(Tolerance));
Assert.That(RelUComputeNaN(), Is.NaN);

static double RelUComputePositiveInfinity() => Relu.Compute(double.PositiveInfinity);
static double RelUComputeNegativeInfinity() => Relu.Compute(double.NegativeInfinity);
static double RelUComputeNaN() => Relu.Compute(double.NaN);
}

[TestCase(100.0)]
[TestCase(0.0001)]
[TestCase(-100.0)]
public static void ReluFunction_Scalar_ResultIsNonNegative(double input)
{
var result = Relu.Compute(input);
Assert.That(result, Is.GreaterThanOrEqualTo(0.0));
}

// --- VECTOR TESTS (Relu.Compute(double[])) ---

[Test]
public static void ReluFunction_Vector_ReturnsCorrectValues()
{
var input = new[] { 0.0, 1.0, -2.0 };
var expected = new[] { 0.0, 1.0, 0.0 };

var result = Relu.Compute(input);

Assert.That(result, Is.EqualTo(expected).Within(Tolerance));
}

[Test]
public static void ReluFunction_Vector_HandlesLimitsAndNaN()
{
var input = new[] { double.PositiveInfinity, 0.0, double.NaN };
var result = Relu.Compute(input);

Assert.That(result.Length, Is.EqualTo(input.Length));
Assert.That(result[0], Is.EqualTo(double.PositiveInfinity));
Assert.That(result[1], Is.EqualTo(0.0).Within(Tolerance));
Assert.That(result[2], Is.NaN);
}

// --- EXCEPTION TESTS ---

[Test]
public static void ReluFunction_Vector_ThrowsOnNullInput()
{
double[]? input = null;
Assert.Throws<ArgumentNullException>(() => Relu.Compute(input!));
}

[Test]
public static void ReluFunction_Vector_ThrowsOnEmptyInput()
{
var input = Array.Empty<double>();
Assert.Throws<ArgumentException>(() => Relu.Compute(input));
}
}
46 changes: 46 additions & 0 deletions Algorithms/Numeric/Relu.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
namespace Algorithms.Numeric;

/// <summary>
/// Implementation of the Rectified Linear Unit (ReLU) function.
/// ReLU is defined as: ReLU(x) = max(0, x).
/// It is commonly used as an activation function in neural networks.
/// </summary>
public static class Relu
{
/// <summary>
/// Compute the Rectified Linear Unit (ReLU) for a single value.
/// </summary>
/// <param name="input">The input real number.</param>
/// <returns>The output real number (>= 0).</returns>
public static double Compute(double input)
{
return Math.Max(0.0, input);
}

/// <summary>
/// Compute the Rectified Linear Unit (ReLU) element-wise for a vector.
/// </summary>
/// <param name="input">The input vector of real numbers.</param>
/// <returns>The output vector where each element is max(0, input[i]).</returns>
public static double[] Compute(double[] input)
{
if (input is null)
{
throw new ArgumentNullException(nameof(input));
}

if (input.Length == 0)
{
throw new ArgumentException("Array is empty.");
}

var output = new double[input.Length];

for (var i = 0; i < input.Length; i++)
{
output[i] = Math.Max(0.0, input[i]);
}

return output;
}
}