Skip to content

Commit

Permalink
Merge pull request #1 from jhonsnow456/patch1-README
Browse files Browse the repository at this point in the history
Patch1 readme
  • Loading branch information
jhonsnow456 committed Nov 4, 2021
2 parents e9414fa + 11be825 commit 6102d8d
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 5 deletions.
6 changes: 3 additions & 3 deletions README.md
Expand Up @@ -138,7 +138,7 @@ To test pull requests, use

## Regenerate Experimental <span class="title-ref">LaTeX</span> Parser/Lexer

The parser and lexer generated with the [ANTLR4](http://antlr4.org)
The parser and lexer were generated with the [ANTLR4](http://antlr4.org)
toolchain in `sympy/parsing/latex/_antlr` and checked into the repo.
Presently, most users should not need to regenerate these files, but
if you plan to work on this feature, you will need the `antlr4`
Expand Down Expand Up @@ -197,7 +197,7 @@ on Stack Overflow using the [sympy](https://stackoverflow.com/questions/tagged/s

SymPy was started by Ondřej Čertík in 2005, he wrote some code during
the summer, then he wrote some more code during summer 2006. In February
2007, Fabian Pedregosa joined the project and helped fixed many things,
2007, Fabian Pedregosa joined the project and helped fix many things,
contributed documentation, and made it alive again. 5 students (Mateusz
Paprocki, Brian Jorgensen, Jason Gedge, Robert Schwarz, and Chris Wu)
improved SymPy incredibly during summer 2007 as part of the Google
Expand Down Expand Up @@ -270,7 +270,7 @@ A BibTeX entry for LaTeX users is
```

SymPy is BSD licensed, so you are free to use it whatever you like, be
it academic, commercial, creating forks or derivatives, as long as you
it academic, commercial, creating forks, or derivatives, as long as you
copy the BSD statement if you redistribute it (see the LICENSE file for
details). That said, although not required by the SymPy license, if it
is convenient for you, please cite SymPy when using it in your work and
Expand Down
44 changes: 43 additions & 1 deletion sympy/tensor/array/expressions/conv_array_to_matrix.py
Expand Up @@ -4,6 +4,7 @@
from functools import singledispatch
from itertools import accumulate

from sympy import MatMul, Basic
from sympy.assumptions.ask import (Q, ask)
from sympy.core.mul import Mul
from sympy.core.singleton import S
Expand Down Expand Up @@ -129,6 +130,43 @@ def _support_function_tp1_recognize(contraction_indices, args):
return editor.to_array_contraction()


def _find_trivial_matrices_rewrite(expr: ArrayTensorProduct):
# If there are matrices of trivial shape in the tensor product (i.e. shape
# (1, 1)), try to check if there is a suitable non-trivial MatMul where the
# expression can be inserted.

# For example, if "a" has shape (1, 1) and "b" has shape (k, 1), the
# expressions "ArrayTensorProduct(a, b*b.T)" can be rewritten as
# "b*a*b.T"

trivial_matrices = []
pos: Optional[int] = None
first: Optional[MatrixExpr] = None
second: Optional[MatrixExpr] = None
removed: List[int] = []
counter: int = 0
args: List[Optional[Basic]] = [i for i in expr.args]
for i, arg in enumerate(expr.args):
if isinstance(arg, MatrixExpr):
if arg.shape == (1, 1):
trivial_matrices.append(arg)
args[i] = None
removed.extend([counter, counter+1])
elif pos is None and isinstance(arg, MatMul):
margs = arg.args
for j, e in enumerate(margs):
if isinstance(e, MatrixExpr) and e.shape[1] == 1:
pos = i
first = MatMul.fromiter(margs[:j+1])
second = MatMul.fromiter(margs[j+1:])
break
counter += get_rank(arg)
if pos is None:
return expr, []
args[pos] = (first*MatMul.fromiter(i for i in trivial_matrices)*second).doit()
return ArrayTensorProduct(*[i for i in args if i is not None]), removed


@singledispatch
def _array2matrix(expr):
return expr
Expand Down Expand Up @@ -369,7 +407,11 @@ def _(expr: ArrayTensorProduct):
else:
newargs.append(arg)
pending = None
return _a2m_tensor_product(*newargs), sorted(removed)
newexpr, newremoved = _a2m_tensor_product(*newargs), sorted(removed)
if isinstance(newexpr, ArrayTensorProduct):
newexpr, newremoved2 = _find_trivial_matrices_rewrite(newexpr)
newremoved = _combine_removed(-1, newremoved, newremoved2)
return newexpr, newremoved


@_remove_trivial_dims.register(ArrayAdd) # type: ignore
Expand Down
Expand Up @@ -14,7 +14,7 @@
from sympy.matrices import Trace, MatMul, Transpose
from sympy.tensor.array.expressions.array_expressions import ZeroArray, OneArray, \
ArrayTensorProduct, ArrayAdd, PermuteDims, ArrayDiagonal, \
ArrayContraction, ArrayElement
ArrayContraction, ArrayElement, ArraySymbol
from sympy.testing.pytest import raises


Expand Down Expand Up @@ -393,6 +393,19 @@ def test_arrayexpr_convert_array_to_matrix_remove_trivial_dims():
assert ret == PermuteDims(ArrayContraction(ArrayTensorProduct(A, B, C, M), (3, 4)), [0, 2, 3, 4, 5, 1])
assert removed == []

# Trivial matrices are sometimes inserted into MatMul expressions:

cg = ArrayTensorProduct(b*b.T, a.T*a)
ret, removed = _remove_trivial_dims(cg)
assert ret == b*a.T*a*b.T
assert removed == [2, 3]

Xs = ArraySymbol("X", (3, 2, k))
cg = ArrayTensorProduct(M, Xs, b.T*c, a*a.T, b*b.T, c.T*d)
ret, removed = _remove_trivial_dims(cg)
assert ret == ArrayTensorProduct(M, Xs, a*b.T*c*c.T*d*a.T, b*b.T)
assert removed == [5, 6, 11, 12]


def test_arrayexpr_convert_array_to_matrix_diag2contraction_diagmatrix():
cg = ArrayDiagonal(ArrayTensorProduct(M, a), (1, 2))
Expand Down

0 comments on commit 6102d8d

Please sign in to comment.