Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Mar 19, 2022
1 parent 3a75902 commit e057897
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions src/xarray_einstats/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def test_matmul_dims2(self, matrices):

def test_matmul_dims3(self):
rng = np.random.default_rng(3)
da = xr.DataArray(rng.normal(size=(2,3,5,7)), dims=["m", "n", "l", "p"])
da = xr.DataArray(rng.normal(size=(2, 3, 5, 7)), dims=["m", "n", "l", "p"])
db = da.rename(m="m_bis")
out = matmul(da, db, dims=("m", "n", "m_bis"))
assert out.shape == (5, 7, 2, 2)
Expand All @@ -162,15 +162,15 @@ def test_matmul_dims3(self):

def test_matmul_dims3_rename(self):
rng = np.random.default_rng(3)
da = xr.DataArray(rng.normal(size=(2,3,5,7)), dims=["m", "n", "l", "p"])
da = xr.DataArray(rng.normal(size=(2, 3, 5, 7)), dims=["m", "n", "l", "p"])
out = matmul(da, da, dims=("m", "n", "m"))
assert out.shape == (5, 7, 2, 2)
assert_dims_not_in_da(out, ["n"])
assert_dims_in_da(out, ("m", "m2", "l", "p"))

def test_matmul_dims22(self):
rng = np.random.default_rng(3)
da = xr.DataArray(rng.normal(size=(2,3,5,7)), dims=["m", "n", "l", "p"])
da = xr.DataArray(rng.normal(size=(2, 3, 5, 7)), dims=["m", "n", "l", "p"])
db = da.rename(n="n_bis", m="n")
out = matmul(da, db, dims=(("m", "n"), ("n_bis", "n")))
assert out.shape == (5, 7, 2, 2)
Expand All @@ -179,14 +179,13 @@ def test_matmul_dims22(self):

def test_matmul_dims22_rename(self):
rng = np.random.default_rng(3)
da = xr.DataArray(rng.normal(size=(2,3,5,7)), dims=["m", "n", "l", "p"])
da = xr.DataArray(rng.normal(size=(2, 3, 5, 7)), dims=["m", "n", "l", "p"])
db = da.rename(n="n_bis")
out = matmul(da, db, dims=(("m", "n"), ("n_bis", "m")))
assert out.shape == (5, 7, 2, 2)
assert_dims_not_in_da(out, ["n_bis", "n"])
assert_dims_in_da(out, ["m", "m2", "l", "p"])


def test_inv_matmul(self, matrices):
aux = inv(matrices, dims=("dim", "dim2"))
out = matmul(matrices, aux, dims=("dim", "dim2"))
Expand Down

0 comments on commit e057897

Please sign in to comment.