In [None]:
import os

from astropy.time import Time
import astropy.coordinates as coord
import astropy.units as u

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tqdm.notebook import tqdm

from twobody import TwoBodyKeplerElements, KeplerOrbit
from twobody import (eccentric_anomaly_from_mean_anomaly, 
                     true_anomaly_from_eccentric_anomaly)

In [None]:
def true_anomaly(orbit, time):
    """Return the true anomaly of ``orbit`` evaluated at ``time``.

    The mean anomaly is formed from the elapsed TCB time since ``orbit.t0``,
    the period ``orbit.P`` and the reference anomaly ``orbit.M0``; it is then
    converted to an eccentric anomaly and finally to a true anomaly with the
    ``twobody`` helper functions.

    Parameters
    ----------
    orbit
        Object exposing ``t0``, ``P``, ``M0`` and ``e`` attributes (in this
        notebook, a ``twobody.KeplerOrbit``).
    time
        Time object exposing a ``.tcb`` attribute (in this notebook, an
        ``astropy.time.Time``).

    Returns
    -------
    Whatever ``true_anomaly_from_eccentric_anomaly`` returns for the computed
    eccentric anomaly and ``orbit.e``.
    """
    # Let dimensionless numbers and angles mix, so the unitless ratio
    # (t - t0) / P can be combined with the angle M0.
    with u.set_enabled_equivalencies(u.dimensionless_angles()):
        elapsed = time.tcb - orbit.t0.tcb
        mean_anom = (2*np.pi * elapsed / orbit.P - orbit.M0).to(u.radian)

    ecc_anom = eccentric_anomaly_from_mean_anomaly(mean_anom, orbit.e)
    return true_anomaly_from_eccentric_anomaly(ecc_anom, orbit.e)

In [None]:
def hb_model(S, i, omega, f, R, a):
    """Evaluate ``S * (1 - 3 sin^2(i) sin^2(f - omega)) / (R / a)^3``.

    Parameters
    ----------
    S
        Overall multiplicative scale of the signal.
    i
        Angle whose sine weights the modulation (called with the orbital
        inclination in this notebook).
    omega
        Angle subtracted from ``f`` (called with the argument of pericenter
        in this notebook).
    f
        Angle, scalar or array (called with the true anomaly in this
        notebook).
    R, a
        Quantities whose ratio ``R / a`` is cubed in the denominator (called
        with the instantaneous separation and ``elem.a`` in this notebook).

    Returns
    -------
    The model value, broadcast over the inputs.
    """
    sin_incl = np.sin(i)
    sin_phase = np.sin(f - omega)

    angular_term = 1 - 3*sin_incl**2 * sin_phase**2
    separation_term = (R / a) ** 3
    return S * angular_term / separation_term

In [None]:
# Orbital period used for every model in this notebook.
P = 20 * u.day
# Default eccentricity. NOTE: the plotting loop below uses ``e`` as its loop
# variable, so this value is rebound once that cell has run.
e = 0.5
# Overall scale passed to ``hb_model``.
S = 1.

# Fixed reference epoch (instead of ``Time.now()``) so that a fresh
# "Restart & Run All" always builds exactly the same time grid.
epoch = Time(59000., format='mjd')
# One full period sampled on 8192 points.  The offsets carry explicit day
# units: adding bare floats to a ``Time`` relies on an implicit "days"
# assumption, and ``P.value`` would silently be wrong if ``P`` were ever
# expressed in another unit.
t = epoch + np.linspace(0, P.to_value(u.day), 8192) * u.day

In [None]:
# Orbital phase in [-0.5, 0.5), measured from the first sample of ``t``.  It
# depends only on the time grid and the period, so it (and the permutation
# that sorts it) is computed once rather than once per panel.
phase = ((t.mjd - t.mjd.min()) / P.to_value(u.day) + 0.5) % 1 - 0.5
order = phase.argsort()
phase = phase[order]

# One figure per eccentricity.  NOTE: the loop variable rebinds the
# notebook-level ``e``.
for e in [0.3, 0.5, 0.7]:
    fig, axes = plt.subplots(6, 6, figsize=(16, 16),
                             sharex=True, sharey=True,
                             constrained_layout=True)

    # Rows scan omega, columns scan the inclination; size each grid from the
    # matching axes dimension so a non-square layout still works.
    omegas = np.linspace(-90, 90, axes.shape[0]) * u.deg
    incls = np.linspace(6, 90, axes.shape[1]) * u.deg
    for row, omega in enumerate(omegas):
        for col, incl in enumerate(incls):
            ax = axes[row, col]

            elem = TwoBodyKeplerElements(P=P, e=e,
                                         m1=1.*u.Msun, m2=0.25*u.Msun,
                                         omega=omega, i=incl,
                                         t0=epoch)
            orbit1 = KeplerOrbit(elem.primary)
            orbit2 = KeplerOrbit(elem.secondary)

            x1 = orbit1.reference_plane(t)
            x2 = orbit2.reference_plane(t)

            # Norm of the difference between the two bodies' positions.
            R = (x1.data.without_differentials() - x2.data.without_differentials()).norm()
            a = elem.a
            f = true_anomaly(orbit1, t)

            # Model curve, reordered to match the sorted phase.
            y = hb_model(S, elem.i, elem.omega, f, R, a)[order]

            # Draw the curve together with copies shifted by one period on
            # either side so it is continuous across the phase window.
            for shift in (0, -1, 1):
                ax.plot(phase + shift, y, marker='', ls='-', lw=2, color='k')
            ax.axhline(0, marker='', zorder=-100,
                       color='tab:blue', alpha=0.2)

            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)

            # Raw f-strings keep the LaTeX backslashes intact; the newline is
            # joined in separately so no invalid ``\c`` escape is needed.
            label = (rf'$\omega = {omega.value:.1f}^\circ$' + '\n' +
                     rf'$i = {incl.value:.0f}^\circ$')
            # Axes-fraction coordinates pin the label to the lower-left
            # corner regardless of the final (shared) data limits.
            ax.text(0.05, 0.05, label, transform=ax.transAxes,
                    ha='left', va='bottom', fontsize=12)

    # The x-axis is shared, so setting the limits on one panel sets them all.
    axes[0, 0].set_xlim(-0.75, 0.75)

    fig.suptitle(f'$e={e:.1f}$', fontsize=16)
    fig.set_facecolor('w')