Skip to content

Commit

Permalink
testing max log likelihood
Browse files Browse the repository at this point in the history
  • Loading branch information
CDonnerer committed Jul 8, 2021
1 parent cb5097c commit bec6322
Showing 1 changed file with 29 additions and 11 deletions.
40 changes: 29 additions & 11 deletions tests/distributions/test_log_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,20 @@ def lognormal():
return LogNormal()


def test_target_validation(lognormal):
valid_target = np.array([0, 1, 4, 5, 10])
lognormal.check_target(valid_target)


@pytest.mark.parametrize(
"invalid_target",
[np.array([-0.1, 1.2]), pd.Series([-1.1, 0.4, 2.3])],
)
def test_target_validation_raises(lognormal, invalid_target):
with pytest.raises(ValueError):
lognormal.check_target(invalid_target)


@pytest.mark.parametrize(
"y, params, natural_gradient, expected_grad",
[
Expand All @@ -35,15 +49,19 @@ def test_gradient_calculation(lognormal, y, params, natural_gradient, expected_g
np.testing.assert_array_equal(grad, expected_grad)


def test_target_validation(lognormal):
valid_target = np.array([0, 1, 4, 5, 10])
lognormal.check_target(valid_target)

def test_loss(lognormal):
loss_name, loss_value = lognormal.loss(
y=np.array(
[
0,
]
),
params=np.array(
[
[1, 0],
]
),
)

@pytest.mark.parametrize(
"invalid_target",
[np.array([-0.1, 1.2]), pd.Series([-1.1, 0.4, 2.3])],
)
def test_target_validation_raises(lognormal, invalid_target):
with pytest.raises(ValueError):
lognormal.check_target(invalid_target)
assert loss_name == "LogNormalError"
assert loss_value == np.inf

0 comments on commit bec6322

Please sign in to comment.