Skip to content

Commit

Permalink
Added missing imports
Browse files Browse the repository at this point in the history
  • Loading branch information
astrofrog committed Feb 19, 2016
1 parent 2896039 commit 13453ea
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions astropy/modeling/tests/test_quantities_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@
from __future__ import (absolute_import, unicode_literals, division,
print_function)

import numpy as np

from ..models import Gaussian1D
from ... import units as u
from ...units import UnitsError
from ...tests.helper import pytest, assert_quantity_allclose
from ...utils import NumpyRNGContext
from .. import fitting


# Fitting should be as intuitive as possible to the user. Essentially, models
Expand All @@ -29,7 +36,7 @@ def test_fitting_simple():
y = y * u.Jy

# Fit the data using a Gaussian with units
g_init = models.Gaussian1D(amplitude=1. * u.mJy, mean=3 * u.cm, stddev=2 * u.mm)
g_init = Gaussian1D(amplitude=1. * u.mJy, mean=3 * u.cm, stddev=2 * u.mm)
fit_g = fitting.LevMarLSQFitter()
g = fit_g(g_init, x, y)

Expand All @@ -45,16 +52,16 @@ def test_fitting_missing_data_units():
"""
Raise an error if the model has units but the data doesn't
"""
g_init = models.Gaussian1D(amplitude=1. * u.mJy, mean=3 * u.cm, stddev=2 * u.mm)
g_init = Gaussian1D(amplitude=1. * u.mJy, mean=3 * u.cm, stddev=2 * u.mm)
fit_g = fitting.LevMarLSQFitter()

with pytest.raises(UnitsError) as exc:
g = fit_g(g_init, [1, 2, 3], [4, 5, 6])
fit_g(g_init, [1, 2, 3], [4, 5, 6])
assert exc.value.args[0] == ("Units of input 'x', (dimensionless), does not "
"match required units for model input, cm (length)")

with pytest.raises(UnitsError) as exc:
g = fit_g(g_init, [1, 2, 3] * u.m, [4, 5, 6])
fit_g(g_init, [1, 2, 3] * u.m, [4, 5, 6])
assert exc.value.args[0] == ("Units of input 'y', (dimensionless), does not "
"match required units for model output, Jy")

Expand All @@ -67,20 +74,20 @@ def test_fitting_missing_model_units():

# TODO: determine whether this breaks backward-compatibility.

g_init = models.Gaussian1D(amplitude=1., mean=3, stddev=2)
g_init = Gaussian1D(amplitude=1., mean=3, stddev=2)
fit_g = fitting.LevMarLSQFitter()

with pytest.raises(UnitsError) as exc:
g = fit_g(g_init, [1, 2, 3] * u.m, [4, 5, 6] * u.Jy)
fit_g(g_init, [1, 2, 3] * u.m, [4, 5, 6] * u.Jy)
assert exc.value.args[0] == ("Units of input 'x', m (length), does not "
"match required units for model input, "
"(dimensionless)")

g_init = models.Gaussian1D(amplitude=1., mean=3 * u.m, stddev=2 * u.m)
g_init = Gaussian1D(amplitude=1., mean=3 * u.m, stddev=2 * u.m)
fit_g = fitting.LevMarLSQFitter()

with pytest.raises(UnitsError) as exc:
g = fit_g(g_init, [1, 2, 3] * u.m, [4, 5, 6] * u.Jy)
fit_g(g_init, [1, 2, 3] * u.m, [4, 5, 6] * u.Jy)
assert exc.value.args[0] == ("Units of input 'y', Jy, does not "
"match required units for model output, "
"(dimensionless)")
Expand All @@ -92,11 +99,11 @@ def test_fitting_incompatible_units():
Raise an error if the data and model have incompatible units
"""

g_init = models.Gaussian1D(amplitude=1. * u.Jy, mean=3 * u.m, stddev=2 * u.cm)
g_init = Gaussian1D(amplitude=1. * u.Jy, mean=3 * u.m, stddev=2 * u.cm)
fit_g = fitting.LevMarLSQFitter()

with pytest.raises(UnitsError) as exc:
g = fit_g(g_init, [1, 2, 3] * u.Hz, [4, 5, 6] * u.Jy)
fit_g(g_init, [1, 2, 3] * u.Hz, [4, 5, 6] * u.Jy)
assert exc.value.args[0] == ("Units of input 'x', Hz (frequency), does not "
"match required units for model input, "
"m (length)")
Expand All @@ -110,7 +117,7 @@ def test_fitting_with_equivalencies():

# A simple test with the spectral equivalency

g_init = models.Gaussian1D(amplitude=1. * u.Jy, mean=3 * u.m, stddev=2 * u.cm)
g_init = Gaussian1D(amplitude=1. * u.Jy, mean=3 * u.m, stddev=2 * u.cm)
g_init.input_equivalencies = u.spectral()

fit_g = fitting.LevMarLSQFitter()
Expand Down

0 comments on commit 13453ea

Please sign in to comment.