diff --git a/tests/test_conjugate_gradient.py b/tests/test_conjugate_gradient.py index 239dbe3ecd..64efe3b168 100644 --- a/tests/test_conjugate_gradient.py +++ b/tests/test_conjugate_gradient.py @@ -19,6 +19,7 @@ class TestConjugateGradient(unittest.TestCase): + def test_real_valued_inverse(self): """Test ConjugateGradient with real-valued input: when the input is real value, the output should be the inverse of the matrix.""" diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py index 945da657bf..903f9bd2ca 100644 --- a/tests/test_sure_loss.py +++ b/tests/test_sure_loss.py @@ -19,6 +19,7 @@ class TestSURELoss(unittest.TestCase): + def test_real_value(self): """Test SURELoss with real-valued input: when the input is real value, the loss should be 0.0.""" sure_loss_real = SURELoss(perturb_noise=torch.zeros(2, 1, 128, 128), eps=0.1)