Skip to content

Commit

Permalink
Fixes BSS algorithm for new API of numpy.linalg.solve in numpy 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
fakufaku committed Jun 18, 2024
1 parent 449f725 commit 9c5a82e
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pyroomacoustics/bss/auxiva.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def demix(Y, X, W):
)

WV = np.matmul(W_hat, V)
W[:, s, :] = np.conj(np.linalg.solve(WV, eyes[:, :, s]))
W[:, s, :] = np.conj(np.linalg.solve(WV, eyes[:, :, [s]]))[..., 0]

# normalize
denom = np.matmul(
Expand Down
4 changes: 2 additions & 2 deletions pyroomacoustics/bss/fastmnmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ def separate():

try:
tmp_FM = np.linalg.solve(
np.matmul(Q_FMM, V_FMM), np.eye(n_chan)[None, m]
)
np.matmul(Q_FMM, V_FMM), np.eye(n_chan)[None, :, [m]]
)[..., 0]
except np.linalg.LinAlgError:
# If Gaussian elimination fails due to a singlular matrix, we
# switch to the pseudo-inverse solution
Expand Down
4 changes: 2 additions & 2 deletions pyroomacoustics/bss/fastmnmf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@ def separate():
np.einsum("ftij, ft -> fij", XX_FTMM, 1 / Y_FTM[..., m]) / n_frames
)
tmp_FM = np.linalg.solve(
np.matmul(Q_FMM, V_FMM), np.eye(n_chan)[None, m]
)
np.matmul(Q_FMM, V_FMM), np.eye(n_chan)[None, :, [m]]
)[..., 0]
Q_FMM[:, m] = (
tmp_FM
/ np.sqrt(
Expand Down
2 changes: 1 addition & 1 deletion pyroomacoustics/bss/ilrma.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def demix(Y, X, W):
C = np.matmul((X * iR[s, :, None, :]), np.conj(X.swapaxes(1, 2))) / n_frames

WV = np.matmul(W, C)
W[:, s, :] = np.conj(np.linalg.solve(WV, eyes[:, :, s]))
W[:, s, :] = np.conj(np.linalg.solve(WV, eyes[:, :, [s]]))[..., 0]

# normalize
denom = np.matmul(
Expand Down
2 changes: 1 addition & 1 deletion pyroomacoustics/bss/sparseauxiva.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def demixsparse(Y, X, S, W):
W_H = np.conj(np.swapaxes(W, 1, 2))
WV = np.matmul(W_H, V[:, s, :, :])
rhs = I[None, :, s][[0] * WV.shape[0], :]
W[:, :, s] = np.linalg.solve(WV, rhs)
W[:, :, s] = np.linalg.solve(WV, rhs[..., None])[..., 0]

# normalize
P1 = np.conj(W[:, :, s])
Expand Down

0 comments on commit 9c5a82e

Please sign in to comment.