Skip to content

Commit

Permalink
feat: Added tridiagonal_solve in tensorflow frontend (ivy-llc#23279)
Browse files Browse the repository at this point in the history
Co-authored-by: NripeshN <nripesh14@gmail.com>
  • Loading branch information
2 people authored and Kacper-W-Kozdon committed Dec 4, 2023
1 parent 8a5020b commit 48df5d3
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 0 deletions.
49 changes: 49 additions & 0 deletions ivy/functional/frontends/tensorflow/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,3 +458,52 @@ def tensorsolve(a, b, axes):
@to_ivy_arrays_and_back
def trace(x, name=None):
return ivy.trace(x, axis1=-2, axis2=-1)


@to_ivy_arrays_and_back
@with_supported_dtypes(
{
"2.13.0 and below": (
"float32",
"float64",
"complex64",
"complex128",
)
},
"tensorflow",
)
def tridiagonal_solve(
diagonals,
rhs,
diagonals_format="compact",
transpose_rhs=False,
conjugate_rhs=False,
name=None,
partial_pivoting=True,
perturb_singular=False,
):
if transpose_rhs is True:
rhs_copy = ivy.matrix_transpose(rhs)
if conjugate_rhs is True:
rhs_copy = ivy.conj(rhs)
if not transpose_rhs and not conjugate_rhs:
rhs_copy = ivy.array(rhs)

if diagonals_format == "matrix":
return ivy.solve(diagonals, rhs_copy)
elif diagonals_format in ["sequence", "compact"]:
diagonals = ivy.array(diagonals)
dim = diagonals[0].shape[0]
diagonals[[0, -1], [-1, 0]] = 0
dummy_idx = [0, 0]
indices = ivy.array([
[(i, i + 1) for i in range(dim - 1)] + [dummy_idx],
[(i, i) for i in range(dim)],
[dummy_idx] + [(i + 1, i) for i in range(dim - 1)],
])
constructed_matrix = ivy.scatter_nd(
indices, diagonals, shape=ivy.array([dim, dim])
)
return ivy.solve(constructed_matrix, rhs_copy)
else:
raise "Unexpected diagonals_format"
106 changes: 106 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_tensorflow/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,41 @@ def _get_second_matrix(draw):
)


@st.composite
def _get_tridiagonal_dtype_matrix_format(draw):
input_dtype_strategy = st.shared(
st.sampled_from(draw(helpers.get_dtypes("float_and_complex"))),
key="shared_dtype",
)
input_dtype = draw(input_dtype_strategy)
shared_size = draw(
st.shared(helpers.ints(min_value=2, max_value=4), key="shared_size")
)
diagonals_format = draw(st.sampled_from(["compact", "sequence", "matrix"]))
if diagonals_format == "matrix":
matrix = draw(
helpers.array_values(
dtype=input_dtype,
shape=tuple([shared_size, shared_size]),
min_value=2,
max_value=5,
).filter(tridiagonal_matrix_filter)
)
elif diagonals_format in ["compact", "sequence"]:
matrix = draw(
helpers.array_values(
dtype=input_dtype,
shape=tuple([3, shared_size]),
min_value=2,
max_value=5,
).filter(tridiagonal_compact_filter)
)
if diagonals_format == "sequence":
matrix = list(matrix)

return input_dtype, matrix, diagonals_format


# --- Main --- #
# ------------ #

Expand Down Expand Up @@ -1207,3 +1242,74 @@ def test_tensorflow_trace(
fn_tree=fn_tree,
x=x[0],
)


# tridiagonal_solve
@handle_frontend_test(
fn_tree="tensorflow.linalg.tridiagonal_solve",
x=_get_tridiagonal_dtype_matrix_format(),
y=_get_second_matrix(),
transpose_rhs=st.just(False),
conjugate_rhs=st.booleans(),
)
def test_tensorflow_tridiagonal_solve(
*,
x,
y,
transpose_rhs,
conjugate_rhs,
frontend,
backend_fw,
test_flags,
fn_tree,
on_device,
):
input_dtype1, x1, diagonals_format = x
input_dtype2, x2 = y
helpers.test_frontend_function(
input_dtypes=[input_dtype1, input_dtype2],
backend_to_test=backend_fw,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
rtol=1e-3,
atol=1e-3,
diagonals=x1,
rhs=x2,
diagonals_format=diagonals_format,
transpose_rhs=transpose_rhs,
conjugate_rhs=conjugate_rhs,
)


def tridiagonal_compact_filter(x):
diagonals = ivy.array(x)
dim = diagonals[0].shape[0]
diagonals[[0, -1], [-1, 0]] = 0
dummy_idx = [0, 0]
indices = ivy.array([
[(i, i + 1) for i in range(dim - 1)] + [dummy_idx],
[(i, i) for i in range(dim)],
[dummy_idx] + [(i + 1, i) for i in range(dim - 1)],
])
matrix = ivy.scatter_nd(
indices, diagonals, ivy.array([dim, dim]), reduction="replace"
)
return tridiagonal_matrix_filter(matrix)


def tridiagonal_matrix_filter(x):
dim = x.shape[0]
if ivy.abs(ivy.det(x)) < 1e-3:
return False
for i in range(dim):
for j in range(dim):
cell = x[i][j]
if i == j or i == j - 1 or i == j + 1:
if cell == 0:
return False
else:
if cell != 0:
return False
return True

0 comments on commit 48df5d3

Please sign in to comment.