diff --git a/twobody/bary_trends.py b/twobody/bary_trends.py index 324bd49..7b5867e 100644 --- a/twobody/bary_trends.py +++ b/twobody/bary_trends.py @@ -59,7 +59,7 @@ def __call__(self, t): return np.zeros_like(t).view(np.ndarray) if self.t0 is not None: - t = (t - self.t0).tcb.jd + t = (t - self.t0).jd else: t = t.tcb.mjd diff --git a/twobody/orbit.py b/twobody/orbit.py index 6c712d8..522ea09 100644 --- a/twobody/orbit.py +++ b/twobody/orbit.py @@ -102,7 +102,8 @@ def __init__(self, elements=None, elements_type='kepler', raise TypeError("'elements' must be an instance of an " "OrbitalElements subclass.") - if barycenter is not None and not isinstance(barycenter, Barycenter): + if barycenter is not None and not isinstance(barycenter, (RVTrend, + Barycenter)): raise TypeError("barycenter must be a twobody.Barycenter instance") self.elements = elements diff --git a/twobody/tests/test_bary_trends.py b/twobody/tests/test_bary_trends.py index 41da339..f106d0c 100644 --- a/twobody/tests/test_bary_trends.py +++ b/twobody/tests/test_bary_trends.py @@ -35,4 +35,7 @@ def test_polynomial(): trend = PolynomialRVTrend([10.*u.km/u.s, 1.*u.km/u.s/u.day], t0=Time(55555., format='mjd')) res = trend(t.tcb.mjd) - print(res) + + t = (Time(55555., format='mjd') + + np.sort(np.random.uniform(15., 23., 128))*u.day) + res = trend(t) diff --git a/twobody/tests/test_orbit.py b/twobody/tests/test_orbit.py index 782025c..5685623 100644 --- a/twobody/tests/test_orbit.py +++ b/twobody/tests/test_orbit.py @@ -15,9 +15,9 @@ # Package from ..orbit import KeplerOrbit from ..barycenter import Barycenter +from ..bary_trends import PolynomialRVTrend - -def get_random_orbit(rnd=None, barycenter=None): +def get_random_orbit_pars(rnd=None, barycenter=None): if rnd is None: rnd = np.random.RandomState() @@ -27,15 +27,11 @@ def get_random_orbit(rnd=None, barycenter=None): i = rnd.uniform(0, 180)*u.deg n = 2*np.pi/P a = K * np.sqrt(1-e**2) / (n * np.sin(i)) - return KeplerOrbit(P=P, a=a, e=e, i=i, - omega=rnd.uniform(0, 360)*u.deg, - Omega=rnd.uniform(0, 360)*u.deg, - M0=rnd.uniform(0, 360)*u.deg, - barycenter=barycenter) - - -def test_init(): - pass + return dict(P=P, a=a, e=e, i=i, + omega=rnd.uniform(0, 360)*u.deg, + Omega=rnd.uniform(0, 360)*u.deg, + M0=rnd.uniform(0, 360)*u.deg, + barycenter=barycenter) def test_radial_velocity(): @@ -65,7 +61,7 @@ def test_radial_velocity(): # Even when barycenter is moving, rv1 and rv2 should be the same for n in range(n_iter): - orb = get_random_orbit(rnd, bc2) + orb = KeplerOrbit(**get_random_orbit_pars(rnd, bc2)) rv1 = orb.radial_velocity(times) _rp = orb.reference_plane(times) @@ -75,7 +71,7 @@ def test_radial_velocity(): # All RV's should be equivalent for n in range(n_iter): - orb = get_random_orbit(rnd, bc1) + orb = KeplerOrbit(**get_random_orbit_pars(rnd, bc1)) rv1 = orb.radial_velocity(times) _rp = orb.reference_plane(times) @@ -87,6 +83,39 @@ def test_radial_velocity(): assert quantity_allclose(rv1, rv3, rtol=0, atol=1*u.km/u.s) +def test_velocity_trend(): + rnd = np.random.RandomState(seed=42) + + coeffs = [0*u.km/u.s, + 0*u.km/u.s/u.day, + 1e-2*u.km/u.s/u.day**2] + trend = PolynomialRVTrend(coeffs, + t0=Time('J2018.0')) + pars = get_random_orbit_pars(rnd, trend) + pars['a'] = 0*u.au + orbit = KeplerOrbit(**pars) + + t = Time('J2018.0') + np.linspace(-100, 100, 256)*u.day + rv = orbit.radial_velocity(t) + assert quantity_allclose(rv[0], rv[-1]) + assert quantity_allclose(rv[0], 100.*u.km/u.s) + + # OK THIS SUCKS (copy pasta) + coeffs = [0*u.km/u.s, + 1*u.km/u.s/u.day, + 0*u.km/u.s/u.day**2] + trend = PolynomialRVTrend(coeffs, + t0=Time('J2018.0')) + pars = get_random_orbit_pars(rnd, trend) + pars['a'] = 0*u.au + orbit = KeplerOrbit(**pars) + + t = Time('J2018.0') + np.linspace(-100, 100, 256)*u.day + rv = orbit.radial_velocity(t) + assert quantity_allclose(rv[0], -rv[-1]) + assert quantity_allclose(rv[0], -100.*u.km/u.s) + + @pytest.mark.skipif(not HAS_MPL, reason="matplotlib not installed") def test_plotting(): orbit = KeplerOrbit(P=100*u.day, a=1.*u.au, e=0.1,