diff --git a/src/xarray_einstats/tests/test_linalg.py b/src/xarray_einstats/tests/test_linalg.py index 5428c33..2c68597 100644 --- a/src/xarray_einstats/tests/test_linalg.py +++ b/src/xarray_einstats/tests/test_linalg.py @@ -153,7 +153,7 @@ def test_matmul_dims2(self, matrices): def test_matmul_dims3(self): rng = np.random.default_rng(3) - da = xr.DataArray(rng.normal(size=(2,3,5,7)), dims=["m", "n", "l", "p"]) + da = xr.DataArray(rng.normal(size=(2, 3, 5, 7)), dims=["m", "n", "l", "p"]) db = da.rename(m="m_bis") out = matmul(da, db, dims=("m", "n", "m_bis")) assert out.shape == (5, 7, 2, 2) @@ -162,7 +162,7 @@ def test_matmul_dims3(self): def test_matmul_dims3_rename(self): rng = np.random.default_rng(3) - da = xr.DataArray(rng.normal(size=(2,3,5,7)), dims=["m", "n", "l", "p"]) + da = xr.DataArray(rng.normal(size=(2, 3, 5, 7)), dims=["m", "n", "l", "p"]) out = matmul(da, da, dims=("m", "n", "m")) assert out.shape == (5, 7, 2, 2) assert_dims_not_in_da(out, ["n"]) @@ -170,7 +170,7 @@ def test_matmul_dims3_rename(self): def test_matmul_dims22(self): rng = np.random.default_rng(3) - da = xr.DataArray(rng.normal(size=(2,3,5,7)), dims=["m", "n", "l", "p"]) + da = xr.DataArray(rng.normal(size=(2, 3, 5, 7)), dims=["m", "n", "l", "p"]) db = da.rename(n="n_bis", m="n") out = matmul(da, db, dims=(("m", "n"), ("n_bis", "n"))) assert out.shape == (5, 7, 2, 2) @@ -179,14 +179,13 @@ def test_matmul_dims22(self): def test_matmul_dims22_rename(self): rng = np.random.default_rng(3) - da = xr.DataArray(rng.normal(size=(2,3,5,7)), dims=["m", "n", "l", "p"]) + da = xr.DataArray(rng.normal(size=(2, 3, 5, 7)), dims=["m", "n", "l", "p"]) db = da.rename(n="n_bis") out = matmul(da, db, dims=(("m", "n"), ("n_bis", "m"))) assert out.shape == (5, 7, 2, 2) assert_dims_not_in_da(out, ["n_bis", "n"]) assert_dims_in_da(out, ["m", "m2", "l", "p"]) - def test_inv_matmul(self, matrices): aux = inv(matrices, dims=("dim", "dim2")) out = matmul(matrices, aux, dims=("dim", "dim2"))