Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
adrn committed Dec 12, 2018
1 parent ae2051f commit 5fc4dd4
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 16 deletions.
2 changes: 1 addition & 1 deletion twobody/bary_trends.py
Expand Up @@ -59,7 +59,7 @@ def __call__(self, t):
return np.zeros_like(t).view(np.ndarray)

if self.t0 is not None:
t = (t - self.t0).tcb.jd
t = (t - self.t0).jd
else:
t = t.tcb.mjd

Expand Down
3 changes: 2 additions & 1 deletion twobody/orbit.py
Expand Up @@ -102,7 +102,8 @@ def __init__(self, elements=None, elements_type='kepler',
raise TypeError("'elements' must be an instance of an "
"OrbitalElements subclass.")

if barycenter is not None and not isinstance(barycenter, Barycenter):
if barycenter is not None and not isinstance(barycenter, (RVTrend,
Barycenter)):
raise TypeError("barycenter must be a twobody.Barycenter instance")

self.elements = elements
Expand Down
5 changes: 4 additions & 1 deletion twobody/tests/test_bary_trends.py
Expand Up @@ -35,4 +35,7 @@ def test_polynomial():
trend = PolynomialRVTrend([10.*u.km/u.s, 1.*u.km/u.s/u.day],
t0=Time(55555., format='mjd'))
res = trend(t.tcb.mjd)
print(res)

t = (Time(55555., format='mjd') +
np.sort(np.random.uniform(15., 23., 128))*u.day)
res = trend(t)
55 changes: 42 additions & 13 deletions twobody/tests/test_orbit.py
Expand Up @@ -15,9 +15,9 @@
# Package
from ..orbit import KeplerOrbit
from ..barycenter import Barycenter
from ..bary_trends import PolynomialRVTrend


def get_random_orbit(rnd=None, barycenter=None):
def get_random_orbit_pars(rnd=None, barycenter=None):
if rnd is None:
rnd = np.random.RandomState()

Expand All @@ -27,15 +27,11 @@ def get_random_orbit(rnd=None, barycenter=None):
i = rnd.uniform(0, 180)*u.deg
n = 2*np.pi/P
a = K * np.sqrt(1-e**2) / (n * np.sin(i))
return KeplerOrbit(P=P, a=a, e=e, i=i,
omega=rnd.uniform(0, 360)*u.deg,
Omega=rnd.uniform(0, 360)*u.deg,
M0=rnd.uniform(0, 360)*u.deg,
barycenter=barycenter)


def test_init():
pass
return dict(P=P, a=a, e=e, i=i,
omega=rnd.uniform(0, 360)*u.deg,
Omega=rnd.uniform(0, 360)*u.deg,
M0=rnd.uniform(0, 360)*u.deg,
barycenter=barycenter)


def test_radial_velocity():
Expand Down Expand Up @@ -65,7 +61,7 @@ def test_radial_velocity():

# Even when barycenter is moving, rv1 and rv2 should be the same
for n in range(n_iter):
orb = get_random_orbit(rnd, bc2)
orb = KeplerOrbit(**get_random_orbit_pars(rnd, bc2))

rv1 = orb.radial_velocity(times)
_rp = orb.reference_plane(times)
Expand All @@ -75,7 +71,7 @@ def test_radial_velocity():

# All RV's should be equivalent
for n in range(n_iter):
orb = get_random_orbit(rnd, bc1)
orb = KeplerOrbit(**get_random_orbit_pars(rnd, bc1))

rv1 = orb.radial_velocity(times)
_rp = orb.reference_plane(times)
Expand All @@ -87,6 +83,39 @@ def test_radial_velocity():
assert quantity_allclose(rv1, rv3, rtol=0, atol=1*u.km/u.s)


def test_velocity_trend():
rnd = np.random.RandomState(seed=42)

coeffs = [0*u.km/u.s,
0*u.km/u.s/u.day,
1e-2*u.km/u.s/u.day**2]
trend = PolynomialRVTrend(coeffs,
t0=Time('J2018.0'))
pars = get_random_orbit_pars(rnd, trend)
pars['a'] = 0*u.au
orbit = KeplerOrbit(**pars)

t = Time('J2018.0') + np.linspace(-100, 100, 256)*u.day
rv = orbit.radial_velocity(t)
assert quantity_allclose(rv[0], rv[-1])
assert quantity_allclose(rv[0], 100.*u.km/u.s)

# OK THIS SUCKS (copy pasta)
coeffs = [0*u.km/u.s,
1*u.km/u.s/u.day,
0*u.km/u.s/u.day**2]
trend = PolynomialRVTrend(coeffs,
t0=Time('J2018.0'))
pars = get_random_orbit_pars(rnd, trend)
pars['a'] = 0*u.au
orbit = KeplerOrbit(**pars)

t = Time('J2018.0') + np.linspace(-100, 100, 256)*u.day
rv = orbit.radial_velocity(t)
assert quantity_allclose(rv[0], -rv[-1])
assert quantity_allclose(rv[0], -100.*u.km/u.s)


@pytest.mark.skipif(not HAS_MPL, reason="matplotlib not installed")
def test_plotting():
orbit = KeplerOrbit(P=100*u.day, a=1.*u.au, e=0.1,
Expand Down

0 comments on commit 5fc4dd4

Please sign in to comment.