Remove duplicated tests
IvanYashchuk committed Nov 30, 2020
1 parent 262b786 commit 6b9b3f5
Showing 1 changed file with 0 additions and 106 deletions.
106 changes: 0 additions & 106 deletions test/test_linalg.py
@@ -4855,112 +4855,6 @@ def maybe_squeeze_result(l, r, result):
        for indices in itertools.product((True, False), repeat=2):
            verify_batched_matmul(*indices)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.double)
    def test_lu_solve_batched_non_contiguous(self, device, dtype):
        from numpy.linalg import solve
        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value

        A = random_fullrank_matrix_distinct_singular_value(2, 2, dtype=dtype, device='cpu')
        b = torch.randn(2, 2, 2, dtype=dtype, device='cpu')
        x_exp = torch.as_tensor(solve(A.permute(0, 2, 1).numpy(), b.permute(2, 1, 0).numpy())).to(device)
        A = A.to(device).permute(0, 2, 1)
        b = b.to(device).permute(2, 1, 0)
        assert not A.is_contiguous() and not b.is_contiguous(), "contiguous inputs"
        LU_data, LU_pivots = torch.lu(A)
        x = torch.lu_solve(b, LU_data, LU_pivots)
        self.assertEqual(x, x_exp)

    def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype):
        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value

        b = torch.randn(*b_dims, dtype=dtype, device=device)
        A = random_fullrank_matrix_distinct_singular_value(*A_dims, dtype=dtype, device=device)
        LU_data, LU_pivots, info = torch.lu(A, get_infos=True, pivot=pivot)
        self.assertEqual(info, torch.zeros_like(info))
        return b, A, LU_data, LU_pivots

    @skipCPUIfNoLapack
    @skipCUDAIfNoMagma
    @dtypes(torch.double)
    def test_lu_solve(self, device, dtype):
        def sub_test(pivot):
            for k, n in zip([2, 3, 5], [3, 5, 7]):
                b, A, LU_data, LU_pivots = self.lu_solve_test_helper((n,), (n, k), pivot, device, dtype)
                x = torch.lu_solve(b, LU_data, LU_pivots)
                self.assertLessEqual(b.dist(A.mm(x)), 1e-12)

        sub_test(True)
        if self.device_type == 'cuda':
            sub_test(False)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.double)
    def test_lu_solve_batched(self, device, dtype):
        def sub_test(pivot):
            def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
                b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, pivot, device, dtype)
                x_exp_list = []
                for i in range(b_dims[0]):
                    x_exp_list.append(torch.lu_solve(b[i], LU_data[i], LU_pivots[i]))
                x_exp = torch.stack(x_exp_list)  # Stacked output
                x_act = torch.lu_solve(b, LU_data, LU_pivots)  # Actual output
                self.assertEqual(x_exp, x_act)  # Equality check
                self.assertLessEqual(b.dist(torch.matmul(A, x_act)), 1e-12)  # Correctness check

            for batchsize in [1, 3, 4]:
                lu_solve_batch_test_helper((5, batchsize), (batchsize, 5, 10), pivot)

        # Tests tensors with 0 elements
        b = torch.randn(3, 0, 3, dtype=dtype, device=device)
        A = torch.randn(3, 0, 0, dtype=dtype, device=device)
        LU_data, LU_pivots = torch.lu(A)
        self.assertEqual(torch.empty_like(b), b.lu_solve(LU_data, LU_pivots))

        sub_test(True)
        if self.device_type == 'cuda':
            sub_test(False)

    @slowTest
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.double)
    def test_lu_solve_batched_many_batches(self, device, dtype):
        def run_test(A_dims, b_dims):
            b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
            x = torch.lu_solve(b, LU_data, LU_pivots)
            b_ = torch.matmul(A, x)
            self.assertEqual(b_, b.expand_as(b_))

        run_test((5, 65536), (65536, 5, 10))
        run_test((5, 262144), (262144, 5, 10))

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.double)
    def test_lu_solve_batched_broadcasting(self, device, dtype):
        from numpy.linalg import solve
        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value

        def run_test(A_dims, b_dims, pivot=True):
            A_matrix_size = A_dims[-1]
            A_batch_dims = A_dims[:-2]
            A = random_fullrank_matrix_distinct_singular_value(A_matrix_size, *A_batch_dims, dtype=dtype)
            b = torch.randn(*b_dims, dtype=dtype)
            x_exp = torch.as_tensor(solve(A.numpy(), b.numpy())).to(dtype=dtype, device=device)
            A, b = A.to(device), b.to(device)
            LU_data, LU_pivots = torch.lu(A, pivot=pivot)
            x = torch.lu_solve(b, LU_data, LU_pivots)
            self.assertEqual(x, x_exp)

        # test against numpy.linalg.solve
        run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6))  # no broadcasting
        run_test((2, 1, 3, 4, 4), (4, 6))  # broadcasting b
        run_test((4, 4), (2, 1, 3, 4, 2))  # broadcasting A
        run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5))  # broadcasting A & b

    @precisionOverride({torch.float32: 1e-5, torch.complex64: 1e-5})
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
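All of the removed tests exercise the same torch.lu / torch.lu_solve pair: factor a matrix once, then reuse the factorization to solve against one or more right-hand sides. The sketch below is not part of the diff; it is a minimal illustration of that pattern, and the 3x3 double-precision system and the tolerance are arbitrary choices made only for the example.

    import torch

    # Build a small, almost surely well-conditioned system A @ x = b.
    A = torch.randn(3, 3, dtype=torch.double)
    b = torch.randn(3, 2, dtype=torch.double)

    # Factor A once (LU with partial pivoting), then solve using the factors.
    LU_data, LU_pivots = torch.lu(A)
    x = torch.lu_solve(b, LU_data, LU_pivots)

    # The residual should be tiny for a well-conditioned random double system.
    assert torch.allclose(A @ x, b, atol=1e-10)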
