Skip to content

Commit

Permalink
WIP: Refactor propagators into classes
Browse files Browse the repository at this point in the history
Fix poliastrogh-921.

Still missing proper sampling support and incompatibilities with cowell.
  • Loading branch information
astrojuanlu committed Jul 3, 2022
1 parent b2d6e11 commit b3be7d1
Show file tree
Hide file tree
Showing 20 changed files with 651 additions and 674 deletions.
29 changes: 3 additions & 26 deletions src/poliastro/core/propagation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
"""Low level propagation algorithms."""

import numpy as np
from numba import njit as jit

from poliastro.core.propagation.base import func_twobody
from poliastro.core.propagation.cowell import cowell
from poliastro.core.propagation.danby import danby, danby_coe
from poliastro.core.propagation.farnocchia import farnocchia, farnocchia_coe
from poliastro.core.propagation.gooding import gooding, gooding_coe
Expand All @@ -13,6 +12,7 @@
from poliastro.core.propagation.vallado import vallado

__all__ = [
"cowell",
"func_twobody",
"farnocchia_coe",
"farnocchia",
Expand All @@ -30,26 +30,3 @@
"recseries_coe",
"recseries",
]


@jit
def func_twobody(t0, u_, k):
"""Differential equation for the initial value two body problem.
This function follows Cowell's formulation.
Parameters
----------
t0 : float
Time.
u_ : numpy.ndarray
Six component state vector [x, y, z, vx, vy, vz] (km, km/s).
k : float
Standard gravitational parameter.
"""
x, y, z, vx, vy, vz = u_
r3 = (x**2 + y**2 + z**2) ** 1.5

du = np.array([vx, vy, vz, -k * x / r3, -k * y / r3, -k * z / r3])
return du
23 changes: 23 additions & 0 deletions src/poliastro/core/propagation/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import numpy as np
from numba import njit as jit


@jit
def func_twobody(t0, u_, k):
"""Differential equation for the initial value two body problem.
Parameters
----------
t0 : float
Time.
u_ : numpy.ndarray
Six component state vector [x, y, z, vx, vy, vz] (km, km/s).
k : float
Standard gravitational parameter.
"""
x, y, z, vx, vy, vz = u_
r3 = (x**2 + y**2 + z**2) ** 1.5

du = np.array([vx, vy, vz, -k * x / r3, -k * y / r3, -k * z / r3])
return du
48 changes: 48 additions & 0 deletions src/poliastro/core/propagation/cowell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import numpy as np

from poliastro._math.ivp import DOP853, solve_ivp
from poliastro.core.propagation.base import func_twobody


def cowell(k, r, v, tofs, rtol=1e-11, *, events=None, f=func_twobody):
x, y, z = r
vx, vy, vz = v

u0 = np.array([x, y, z, vx, vy, vz])

result = solve_ivp(
f,
(0, max(tofs)),
u0,
args=(k,),
rtol=rtol,
atol=1e-12,
method=DOP853,
dense_output=True,
events=events,
)
if not result.success:
raise RuntimeError("Integration failed")

if events is not None:
# Collect only the terminal events
terminal_events = [event for event in events if event.terminal]

# If there are no terminal events, then the last time of integration is the
# greatest one from the original array of propagation times
if not terminal_events:
last_t = max(tofs)
else:
# Filter the event which triggered first
last_t = min(event.last_t for event in terminal_events)
tofs = [tof for tof in tofs if tof < last_t] + [last_t]

rrs = []
vvs = []
for i in range(len(tofs)):
t = tofs[i]
y = result.sol(t)
rrs.append(y[:3])
vvs.append(y[3:])

return rrs, vvs
24 changes: 8 additions & 16 deletions src/poliastro/twobody/orbit/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
t_p,
)
from poliastro.twobody.orbit.creation import OrbitCreationMixin
from poliastro.twobody.propagation import farnocchia, propagate
from poliastro.twobody.propagation import FarnocchiaPropagator, propagate
from poliastro.twobody.sampling import sample_closed, sample_open
from poliastro.twobody.states import BaseState
from poliastro.util import norm, wrap_angle
Expand Down Expand Up @@ -410,7 +410,7 @@ def __str__(self):
def __repr__(self):
return self.__str__()

def propagate(self, value, method=farnocchia, rtol=1e-10, **kwargs):
def propagate(self, value, method=FarnocchiaPropagator()):
"""Propagates an orbit a specified time.
If value is true anomaly, propagate orbit to this anomaly and return the result.
Expand All @@ -420,12 +420,8 @@ def propagate(self, value, method=farnocchia, rtol=1e-10, **kwargs):
----------
value : ~astropy.units.Quantity, ~astropy.time.Time, ~astropy.time.TimeDelta
Scalar time to propagate.
rtol : float, optional
Relative tolerance for the propagation algorithm, default to 1e-10.
method : function, optional
Method used for propagation
**kwargs
parameters used in perturbation models
Method used for propagation, default to farnocchia.
Returns
-------
Expand All @@ -441,18 +437,14 @@ def propagate(self, value, method=farnocchia, rtol=1e-10, **kwargs):
# Works for both Quantity and TimeDelta objects
time_of_flight = time.TimeDelta(value)

cartesian = propagate(
self, time_of_flight, method=method, rtol=rtol, **kwargs
new_state = propagate(
self._state,
time_of_flight,
method=method,
)
new_epoch = self.epoch + time_of_flight

return self.from_vectors(
self.attractor,
cartesian[0].xyz,
cartesian[0].differentials["s"].d_xyz,
new_epoch,
plane=self.plane,
)
return self.__class__(new_state, new_epoch)

@u.quantity_input(value=u.rad)
def time_to_anomaly(self, value):
Expand Down

0 comments on commit b3be7d1

Please sign in to comment.