diff --git a/docs_input/api/linalg/other/solve.rst b/docs_input/api/linalg/other/solve.rst index 7d2838324..b70619e09 100644 --- a/docs_input/api/linalg/other/solve.rst +++ b/docs_input/api/linalg/other/solve.rst @@ -20,3 +20,54 @@ For sparse ``A``, ``solve`` preserves the legacy sparse solve layout where right-hand sides are stacked by row and the operation solves ``A X^T = B^T``. Sparse direct solve support depends on the sparse format and may require cuDSS; please see :ref:`sparse_tensor_api`. + +Examples +~~~~~~~~ + +Dense vector right-hand side + +.. literalinclude:: ../../../../test/00_solver/SolveDense.cu + :language: cpp + :start-after: example-begin solve-test-1 + :end-before: example-end solve-test-1 + :dedent: + +Dense matrix right-hand side + +.. literalinclude:: ../../../../test/00_solver/SolveDense.cu + :language: cpp + :start-after: example-begin solve-test-2 + :end-before: example-end solve-test-2 + :dedent: + +In an expression + +.. literalinclude:: ../../../../test/00_solver/SolveDense.cu + :language: cpp + :start-after: example-begin solve-test-3 + :end-before: example-end solve-test-3 + :dedent: + +Batched vector right-hand side + +.. literalinclude:: ../../../../test/00_solver/SolveDense.cu + :language: cpp + :start-after: example-begin solve-test-4 + :end-before: example-end solve-test-4 + :dedent: + +Batched matrix right-hand side + +.. literalinclude:: ../../../../test/00_solver/SolveDense.cu + :language: cpp + :start-after: example-begin solve-test-5 + :end-before: example-end solve-test-5 + :dedent: + +In-place right-hand side + +.. literalinclude:: ../../../../test/00_solver/SolveDense.cu + :language: cpp + :start-after: example-begin solve-test-6 + :end-before: example-end solve-test-6 + :dedent: diff --git a/test/00_solver/SolveDense.cu b/test/00_solver/SolveDense.cu index ec058a685..5123a4ba1 100644 --- a/test/00_solver/SolveDense.cu +++ b/test/00_solver/SolveDense.cu @@ -87,6 +87,7 @@ TYPED_TEST(DenseSolveTestFloatTypes, SolveVectorRHS) this->pb->NumpyToTensorView(B, "B"); // example-begin solve-test-1 + // Solve A(n x n) X = B(n) for vector X(n). (X = solve(A, B)).run(this->exec); // example-end solve-test-1 this->exec.sync(); @@ -111,7 +112,10 @@ TYPED_TEST(DenseSolveTestFloatTypes, SolveMatrixRHS) this->pb->NumpyToTensorView(A, "A"); this->pb->NumpyToTensorView(B, "B"); + // example-begin solve-test-2 + // Solve A(n x n) X = B(n x nrhs) for matrix X(n x nrhs). (X = solve(A, B)).run(this->exec); + // example-end solve-test-2 this->exec.sync(); MATX_TEST_ASSERT_COMPARE(this->pb, X, "X", this->thresh); @@ -135,7 +139,10 @@ TYPED_TEST(DenseSolveTestFloatTypes, SolveInExpression) this->pb->NumpyToTensorView(B, "B"); this->pb->NumpyToTensorView(C, "C"); + // example-begin solve-test-3 + // solve(A, B) may be used in a larger expression. (X = solve(A, B) * C).run(this->exec); + // example-end solve-test-3 this->exec.sync(); MATX_TEST_ASSERT_COMPARE(this->pb, X, "X", this->thresh); @@ -158,7 +165,10 @@ TYPED_TEST(DenseSolveTestFloatTypes, SolveBatchedVectorRHS) this->pb->NumpyToTensorView(A, "A"); this->pb->NumpyToTensorView(B, "B"); + // example-begin solve-test-4 + // Solve a batch of independent systems with vector right-hand sides. (X = solve(A, B)).run(this->exec); + // example-end solve-test-4 this->exec.sync(); MATX_TEST_ASSERT_COMPARE(this->pb, X, "X", this->thresh); @@ -182,7 +192,10 @@ TYPED_TEST(DenseSolveTestFloatTypes, SolveBatchedMatrixRHS) this->pb->NumpyToTensorView(A, "A"); this->pb->NumpyToTensorView(B, "B"); + // example-begin solve-test-5 + // Solve a batch of independent systems with matrix right-hand sides. (X = solve(A, B)).run(this->exec); + // example-end solve-test-5 this->exec.sync(); MATX_TEST_ASSERT_COMPARE(this->pb, X, "X", this->thresh); @@ -204,7 +217,10 @@ TYPED_TEST(DenseSolveTestFloatTypes, SolveInPlaceRHS) this->pb->NumpyToTensorView(A, "A"); this->pb->NumpyToTensorView(B, "B"); + // example-begin solve-test-6 + // The right-hand side can be overwritten with the solution. (B = solve(A, B)).run(this->exec); + // example-end solve-test-6 this->exec.sync(); MATX_TEST_ASSERT_COMPARE(this->pb, B, "X", this->thresh);