Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reshape #22

Merged
merged 4 commits into from
Dec 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "tensortrax"
version = "0.2.1"
version = "0.2.2"
description = "Math on (Hyper-Dual) Tensors with Trailing Axes"
readme = "README.md"
requires-python = ">=3.7"
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = tensortrax
version = 0.2.1
version = 0.2.2
author = Andreas Dutzler
author_email = a.dutzler@gmail.com
description = Math on (Hyper-Dual) Tensors with Trailing Axes
Expand Down
25 changes: 22 additions & 3 deletions tensortrax/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ def T(self):
def ravel(self, order="C"):
return ravel(self, order=order)

def reshape(self, *shape, order="C"):
return reshape(self, newshape=shape, order=order)

def __matmul__(self, B):
return matmul(self, B)

Expand Down Expand Up @@ -222,6 +225,22 @@ def ravel(A, order="C"):
return np.ravel(A, order=order)


def reshape(A, newshape, order="C"):
if isinstance(A, Tensor):
δtrax = δ(A).shape[len(A.shape) :]
Δtrax = Δ(A).shape[len(A.shape) :]
Δδtrax = Δδ(A).shape[len(A.shape) :]
return Tensor(
x=f(A).reshape(*newshape, *A.trax, order=order),
δx=δ(A).reshape(*newshape, *δtrax, order=order),
Δx=Δ(A).reshape(*newshape, *Δtrax, order=order),
Δδx=Δδ(A).reshape(*newshape, *Δδtrax, order=order),
ntrax=A.ntrax,
)
else:
return np.reshape(A, newshape=newshape, order=order)


def einsum3(subscripts, *operands):
"Einsum with three operands."
A, B, C = operands
Expand Down Expand Up @@ -387,7 +406,7 @@ def transpose(A):


def matmul(A, B):
ik = "abcdefghijklm"[13-len(A.shape):]
kj = "mnopqrstuvwxy"[: len(B.shape)]
ij = (ik + kj).replace("m", "")
ik = "ik"[2 - len(A.shape) :]
kj = "kj"[: len(B.shape)]
ij = (ik + kj).replace("k", "")
return einsum(f"{ik}...,{kj}...->{ij}...", A, B)
1 change: 1 addition & 0 deletions tensortrax/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@
matmul,
diagonal,
ravel,
reshape,
)
from . import _math_array as array
2 changes: 1 addition & 1 deletion tensortrax/math/_math_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np

from .._tensor import Tensor, ravel, einsum, matmul, f, δ, Δ, Δδ
from .._tensor import Tensor, ravel, reshape, einsum, matmul, f, δ, Δ, Δδ
from ._linalg import _linalg_array as array


Expand Down
8 changes: 8 additions & 0 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@ def test_math():

u[0] = t[0]

t.reshape(9)
t.reshape(3, 3)

tm.reshape(t, (9,))
tm.reshape(t, (3, 3))

tm.reshape(x, (3, 3, 100))


if __name__ == "__main__":
test_math()