diff --git a/src/main/java/com/thealgorithms/maths/ChebyshevIteration.java b/src/main/java/com/thealgorithms/maths/ChebyshevIteration.java new file mode 100644 index 000000000000..bc30f1ba6e7e --- /dev/null +++ b/src/main/java/com/thealgorithms/maths/ChebyshevIteration.java @@ -0,0 +1,181 @@ +package com.thealgorithms.maths; + +/** + * In numerical analysis, Chebyshev iteration is an iterative method for solving + * systems of linear equations Ax = b. It is designed for systems where the + * matrix A is symmetric positive-definite (SPD). + * + *

+ * This method is a "polynomial acceleration" method, meaning it finds the + * optimal polynomial to apply to the residual to accelerate convergence. + * + *

+ * It requires knowledge of the bounds of the eigenvalues of the matrix A: + * m(A) (smallest eigenvalue) and M(A) (largest eigenvalue). + * + *

+ * Wikipedia: https://en.wikipedia.org/wiki/Chebyshev_iteration + * + * @author Mitrajit Ghorui(KeyKyrios) + */ +public final class ChebyshevIteration { + + private ChebyshevIteration() { + } + + /** + * Solves the linear system Ax = b using the Chebyshev iteration method. + * + *

+ * NOTE: The matrix A *must* be symmetric positive-definite (SPD) for this + * algorithm to converge. + * + * @param a The matrix A (must be square, SPD). + * @param b The vector b. + * @param x0 The initial guess vector. + * @param minEigenvalue The smallest eigenvalue of A (m(A)). + * @param maxEigenvalue The largest eigenvalue of A (M(A)). + * @param maxIterations The maximum number of iterations to perform. + * @param tolerance The desired tolerance for the residual norm. + * @return The solution vector x. + * @throws IllegalArgumentException if matrix/vector dimensions are + * incompatible, + * if maxIterations <= 0, or if eigenvalues are invalid (e.g., minEigenvalue + * <= 0, maxEigenvalue <= minEigenvalue). + */ + public static double[] solve(double[][] a, double[] b, double[] x0, double minEigenvalue, double maxEigenvalue, int maxIterations, double tolerance) { + validateInputs(a, b, x0, minEigenvalue, maxEigenvalue, maxIterations, tolerance); + + int n = b.length; + double[] x = x0.clone(); + double[] r = vectorSubtract(b, matrixVectorMultiply(a, x)); + double[] p = new double[n]; + + double d = (maxEigenvalue + minEigenvalue) / 2.0; + double c = (maxEigenvalue - minEigenvalue) / 2.0; + + double alpha = 0.0; + double alphaPrev = 0.0; + + for (int k = 0; k < maxIterations; k++) { + double residualNorm = vectorNorm(r); + if (residualNorm < tolerance) { + return x; // Solution converged + } + + if (k == 0) { + alpha = 1.0 / d; + System.arraycopy(r, 0, p, 0, n); // p = r + } else { + double beta = c * alphaPrev / 2.0 * (c * alphaPrev / 2.0); + alpha = 1.0 / (d - beta / alphaPrev); + double[] pUpdate = scalarMultiply(beta / alphaPrev, p); + p = vectorAdd(r, pUpdate); // p = r + (beta / alphaPrev) * p + } + + double[] xUpdate = scalarMultiply(alpha, p); + x = vectorAdd(x, xUpdate); // x = x + alpha * p + + // Recompute residual for accuracy + r = vectorSubtract(b, matrixVectorMultiply(a, x)); + alphaPrev = alpha; + } + + return x; // Return best guess after maxIterations + } + + /** + * Validates the inputs for the Chebyshev solver. + */ + private static void validateInputs(double[][] a, double[] b, double[] x0, double minEigenvalue, double maxEigenvalue, int maxIterations, double tolerance) { + int n = a.length; + if (n == 0) { + throw new IllegalArgumentException("Matrix A cannot be empty."); + } + if (n != a[0].length) { + throw new IllegalArgumentException("Matrix A must be square."); + } + if (n != b.length) { + throw new IllegalArgumentException("Matrix A and vector b dimensions do not match."); + } + if (n != x0.length) { + throw new IllegalArgumentException("Matrix A and vector x0 dimensions do not match."); + } + if (minEigenvalue <= 0) { + throw new IllegalArgumentException("Smallest eigenvalue must be positive (matrix must be positive-definite)."); + } + if (maxEigenvalue <= minEigenvalue) { + throw new IllegalArgumentException("Max eigenvalue must be strictly greater than min eigenvalue."); + } + if (maxIterations <= 0) { + throw new IllegalArgumentException("Max iterations must be positive."); + } + if (tolerance <= 0) { + throw new IllegalArgumentException("Tolerance must be positive."); + } + } + + // --- Vector/Matrix Helper Methods --- + /** + * Computes the product of a matrix A and a vector v (Av). + */ + private static double[] matrixVectorMultiply(double[][] a, double[] v) { + int n = a.length; + double[] result = new double[n]; + for (int i = 0; i < n; i++) { + double sum = 0; + for (int j = 0; j < n; j++) { + sum += a[i][j] * v[j]; + } + result[i] = sum; + } + return result; + } + + /** + * Computes the subtraction of two vectors (v1 - v2). + */ + private static double[] vectorSubtract(double[] v1, double[] v2) { + int n = v1.length; + double[] result = new double[n]; + for (int i = 0; i < n; i++) { + result[i] = v1[i] - v2[i]; + } + return result; + } + + /** + * Computes the addition of two vectors (v1 + v2). + */ + private static double[] vectorAdd(double[] v1, double[] v2) { + int n = v1.length; + double[] result = new double[n]; + for (int i = 0; i < n; i++) { + result[i] = v1[i] + v2[i]; + } + return result; + } + + /** + * Computes the product of a scalar and a vector (s * v). + */ + private static double[] scalarMultiply(double scalar, double[] v) { + int n = v.length; + double[] result = new double[n]; + for (int i = 0; i < n; i++) { + result[i] = scalar * v[i]; + } + return result; + } + + /** + * Computes the L2 norm (Euclidean norm) of a vector. + */ + private static double vectorNorm(double[] v) { + double sumOfSquares = 0; + for (double val : v) { + sumOfSquares += val * val; + } + return Math.sqrt(sumOfSquares); + } +} diff --git a/src/test/java/com/thealgorithms/maths/ChebyshevIterationTest.java b/src/test/java/com/thealgorithms/maths/ChebyshevIterationTest.java new file mode 100644 index 000000000000..d5cf83818fa4 --- /dev/null +++ b/src/test/java/com/thealgorithms/maths/ChebyshevIterationTest.java @@ -0,0 +1,105 @@ +package com.thealgorithms.maths; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import org.junit.jupiter.api.Test; + +public class ChebyshevIterationTest { + + @Test + public void testSolveSimple2x2Diagonal() { + double[][] a = {{2, 0}, {0, 1}}; + double[] b = {2, 2}; + double[] x0 = {0, 0}; + double minEig = 1.0; + double maxEig = 2.0; + int maxIter = 50; + double tol = 1e-9; + double[] expected = {1.0, 2.0}; + + double[] result = ChebyshevIteration.solve(a, b, x0, minEig, maxEig, maxIter, tol); + assertArrayEquals(expected, result, 1e-9); + } + + @Test + public void testSolve2x2Symmetric() { + double[][] a = {{4, 1}, {1, 3}}; + double[] b = {1, 2}; + double[] x0 = {0, 0}; + double minEig = (7.0 - Math.sqrt(5.0)) / 2.0; + double maxEig = (7.0 + Math.sqrt(5.0)) / 2.0; + int maxIter = 100; + double tol = 1e-10; + double[] expected = {1.0 / 11.0, 7.0 / 11.0}; + + double[] result = ChebyshevIteration.solve(a, b, x0, minEig, maxEig, maxIter, tol); + assertArrayEquals(expected, result, 1e-9); + } + + @Test + public void testAlreadyAtSolution() { + double[][] a = {{2, 0}, {0, 1}}; + double[] b = {2, 2}; + double[] x0 = {1, 2}; + double minEig = 1.0; + double maxEig = 2.0; + int maxIter = 10; + double tol = 1e-5; + double[] expected = {1.0, 2.0}; + + double[] result = ChebyshevIteration.solve(a, b, x0, minEig, maxEig, maxIter, tol); + assertArrayEquals(expected, result, 0.0); + } + + @Test + public void testMismatchedDimensionsAB() { + double[][] a = {{1, 0}, {0, 1}}; + double[] b = {1}; + double[] x0 = {0, 0}; + assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 1, 2, 10, 1e-5)); + } + + @Test + public void testMismatchedDimensionsAX() { + double[][] a = {{1, 0}, {0, 1}}; + double[] b = {1, 1}; + double[] x0 = {0}; + assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 1, 2, 10, 1e-5)); + } + + @Test + public void testNonSquareMatrix() { + double[][] a = {{1, 0, 0}, {0, 1, 0}}; + double[] b = {1, 1}; + double[] x0 = {0, 0}; + assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 1, 2, 10, 1e-5)); + } + + @Test + public void testInvalidEigenvalues() { + double[][] a = {{1, 0}, {0, 1}}; + double[] b = {1, 1}; + double[] x0 = {0, 0}; + assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 2, 1, 10, 1e-5)); + assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 1, 1, 10, 1e-5)); + } + + @Test + public void testNonPositiveDefinite() { + double[][] a = {{1, 0}, {0, 1}}; + double[] b = {1, 1}; + double[] x0 = {0, 0}; + assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 0, 1, 10, 1e-5)); + assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, -1, 1, 10, 1e-5)); + } + + @Test + public void testInvalidIterationCount() { + double[][] a = {{1, 0}, {0, 1}}; + double[] b = {1, 1}; + double[] x0 = {0, 0}; + assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 1, 2, 0, 1e-5)); + assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 1, 2, -1, 1e-5)); + } +}