In [10]:
# import sys

# import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import solve_ivp


class Planet:
    """A planet on a Keplerian (elliptic) orbit.

    Time is measured in units of the orbital period. The true anomaly is
    tabulated over one period by numerically integrating its equation of
    motion, and positions are computed from the orbital elements.
    """

    def __init__(self, e, a, omega, Omega, i, r_p, first_periastron=0, split=1000):
        """Store the orbital elements and precompute the true-anomaly table.

        Args:
            e: Orbital eccentricity; must satisfy ``0 <= e < 1``.
            a: Semi-major axis (positions are returned in the same unit).
            omega: Argument of periapsis, in radians.
            Omega: Longitude of the ascending node, in radians.
            i: Inclination, in radians.
            r_p: Planet radius (stored; not used by the methods of this class).
            first_periastron: Time of a periastron passage, in units of the
                orbital period; it defines the time origin.
            split: Number of intervals in the grid used to tabulate the true
                anomaly over one period.

        Raises:
            ValueError: If ``e`` is outside ``[0, 1)`` or ``split < 1``.
        """
        self.e = e  # eccentricity
        self.a = a  # semi-major axis
        self.omega = omega  # argument of periapsis
        self.Omega = Omega  # longitude of ascending node
        self.i = i  # inclination
        self.r_p = r_p  # planet radius
        # Time origin (a periastron passage), in units of the orbital period.
        # TODO: consider keeping this as a datetime object.
        self.first_periastron = first_periastron
        self.split = split
        self.alphas = self.alpha_wrt_time(e=self.e, split=self.split,
                                          first_periastron=self.first_periastron)

    def der_alpha(self, t, alpha, e):
        """Return d(alpha)/dt for true anomaly ``alpha`` and eccentricity ``e``.

        With time in units of the orbital period this is
        ``2*pi * (1 + e*cos(alpha))**2 / (1 - e**2)**1.5``.
        ``t`` is unused but is required by the ``solve_ivp`` callback signature.
        """
        return (2 * np.pi / (1 - e * e)**1.5) * (1 + e * np.cos(alpha))**2

    def alpha_wrt_time(self, e=0.0, split=1000, first_periastron=0.0,
                       rtol=1e-9, atol=1e-12):
        """Tabulate the true anomaly over one period and return a lookup function.

        Args:
            e: Orbital eccentricity; must satisfy ``0 <= e < 1``.
            split: Number of intervals in the time grid on ``[0, 1]``.
            first_periastron: Time at which the true anomaly is zero.
            rtol, atol: Tolerances passed to ``solve_ivp``.

        Returns:
            A function ``alphas(time)`` that maps a time (scalar or array-like,
            in units of the orbital period) to the true anomaly in radians.
            Time is wrapped modulo one period relative to ``first_periastron``,
            and values are linearly interpolated between grid points.

        Raises:
            ValueError: If ``e`` is outside ``[0, 1)`` or ``split < 1``.
            RuntimeError: If the numerical integration fails.
        """
        # The ODE has a singular/NaN prefactor for e >= 1 (non-elliptic orbits).
        if not 0.0 <= e < 1.0:
            raise ValueError(f"eccentricity must satisfy 0 <= e < 1, got {e!r}")
        if split < 1:
            raise ValueError(f"split must be a positive integer, got {split!r}")

        t_grid = np.linspace(0.0, 1.0, split + 1)
        sol = solve_ivp(self.der_alpha, (0.0, 1.0), np.array([0.0]),
                        t_eval=t_grid, args=(e,), rtol=rtol, atol=atol)
        if not sol.success:
            raise RuntimeError(f"true-anomaly integration failed: {sol.message}")
        alpha_array = sol.y[0]

        def alphas(time):
            # Phase within the current orbit, in [0, 1]. Float modulo can yield
            # exactly 1.0 for tiny negative inputs; interp handles that endpoint.
            phase = np.mod(np.asarray(time, dtype=float) - first_periastron, 1.0)
            return np.interp(phase, t_grid, alpha_array)

        return alphas

    def getOrbitalElements(self):
        """Return the Keplerian orbital elements as a dictionary.

        Everything stored on the planet except its radius and time origin.
        """
        return {'Eccentricity': self.e,
                'Semi Major Axis': self.a,
                'Argument of Periapsis': self.omega,
                'Longitude of Ascending Node': self.Omega,
                'Inclination': self.i}

    def getNu_from_time(self, time):
        """Return the true anomaly (radians) at ``time`` (units of the orbital period).

        Accepts a scalar or array-like; array input gives an array of the same shape.
        """
        return self.alphas(time)

    def getPosition_from_nu(self, nu):
        """Return the planet's position for true anomaly ``nu`` (radians).

        Returns:
            ``np.ndarray`` of shape ``(3,)`` for scalar ``nu``, or
            ``(3, *nu.shape)`` for array ``nu``, in the length unit of ``a``.
        """
        u = self.omega + nu  # argument of latitude
        cos_u, sin_u = np.cos(u), np.sin(u)
        cos_i = np.cos(self.i)
        cos_O, sin_O = np.cos(self.Omega), np.sin(self.Omega)
        # NOTE(review): compared with the textbook orbital-to-reference rotation,
        # this frame has x = -y_std and y = x_std; presumably a deliberate
        # sky-plane convention -- confirm.
        unit_vector = np.array([
            -cos_i * cos_O * sin_u - sin_O * cos_u,
            cos_O * cos_u - cos_i * sin_O * sin_u,
            np.sin(self.i) * sin_u,
        ])

        # Orbit equation: distance from the focus at true anomaly nu.
        r = self.a * (1 - self.e**2) / (1 + self.e * np.cos(nu))
        return r * unit_vector

    def getPosition(self, time):
        """Return the planet's position at ``time`` (units of the orbital period)."""
        return self.getPosition_from_nu(self.getNu_from_time(time))


In [11]:
class System:
    """A host star together with the planets orbiting it."""

    def __init__(self, star: dict, planet_list, sort=True):
        """Create a planetary system.

        Args:
            star: Dictionary describing the host star (only stored here).
            planet_list: Iterable of planet objects. Each must expose a
                semi-major-axis attribute ``a`` (used for sorting) and a
                ``getPosition(time)`` method.
            sort: If true, store the planets ordered by increasing ``a``;
                otherwise keep the given order. Either way a new list is stored.
        """
        self.star = star
        if sort:
            # Planets are objects with an ``a`` attribute, not dictionaries.
            self.planet_list = sorted(planet_list, key=lambda planet: planet.a)
        else:
            self.planet_list = list(planet_list)

    def getPosition(self, index, time):
        """Return the position of planet ``index`` (in ``planet_list`` order) at ``time``."""
        return self.planet_list[index].getPosition(time)

    def plot(self, model='Quadratic', normalise=False):
        """Plot the transit curve.

        Not implemented yet. ``model`` presumably names a limb-darkening law and
        ``normalise`` whether to normalise the flux -- TODO confirm.
        """
        raise NotImplementedError("System.plot is not implemented yet")

    def visualise(self):  # TODO: decide which arguments are needed
        """Plot actual images from Earth's vantage point (not implemented yet)."""
        raise NotImplementedError("System.visualise is not implemented yet")