In [None]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import matplotlib

from ipywidgets import interactive

In [27]:
class q_crit():

    def __init__(
        self,
        q_min = 0.01,
        q_max = 8,
        steps = 10000
    ):
        self.qs = jnp.linspace(q_min, q_max, steps)
        self.betas = jnp.linspace(0,1,steps).reshape(-1,1)
        self.zeta_RLs = self.zeta_RL()
        
    def zeta_RL(self):
        # part 1
        part1 =  -2. * (1. - self.betas*self.qs - (1. - self.betas) * (self.qs + 1./2.) * (self.qs / (self.qs + 1.) )  )

        #part 2
        A = (self.qs**(1./3.)) / 3.
        B = 2./(self.qs**(1./3.))
        C = (  1.2 * self.qs**(1./3.) +  1./(1. + self.qs**(1./3)) ) / (0.6 * self.qs**(2./3.) + jnp.log(1 + self.qs**(1./3.) ) )
        part2 =  A * (B - C)

        #part 3
        part3 =  1. + self.betas*self.qs
        return part1 + part2 * part3

    def __call__(
            self,
            zeta_eff, 
            beta):
        beta_idx = jnp.abs(self.betas.reshape(-1) - beta).argmin()
        idx = jnp.abs(self.zeta_RLs[beta_idx] - zeta_eff).argmin()
        return self.qs[idx]

In [4]:
def _q_bh(q_zams, beta_a = 0.5, fsn_a = 0.2, fsn_b = 0.2, fcore = 0.34):
    return (1 - fsn_a) / (1 - fsn_b) * (q_zams + beta_a * (1 - fcore))
    

In [5]:
def q_bh(q_zams, q_zams_min, q_zams_max):

    q_min = _q_bh(q_zams_min)
    q_max = _q_bh(q_zams_max)
    q = _q_bh(q_zams)
    
    return jnp.where(jnp.less(q, q_min), 0, jnp.where(jnp.greater(q, q_max), 0, q))


In [6]:
q_zams_min = 1 / q_crit_1
q_zams_max = fcore * (1 - fsn_a) * q_crit_2 - beta_a * (1 - fcore)
    

NameError: name 'q_crit_1' is not defined

In [7]:
class q_bh(object):

    def __init__(self,
                 fsn_a = 0.2,
                 fsn_b = 0.2,
                 fcore = 0.34
                 ):
        self.fsn_a = fsn_a
        self.fsn_b = fsn_b
        self.fcore = fcore
        self.q_crit = q_crit()
        
    def _q_bh(self, q_zams, beta_a):
        return ((1 - self.fsn_b) / (1 - self.fsn_a)) * (q_zams + beta_a * (1 - self.fcore))

    def __call__(self, q_zams, zeta_eff, beta_a):

        q_zams_min = 1 / self.q_crit(zeta_eff, beta_a)
        q_zams_max = self.q_crit(zeta_eff, 0.) * self.fcore * (1 - self.fsn_a) - beta_a * (1 - self.fcore)

        q_min = _q_bh(q_zams_min)
        q_max_1 = _q_bh(q_zams_max)
        q_max_2 = _q_bh(1.)
        q_max = jnp.min(jnp.array([q_max_1,q_max_2]))
        q = _q_bh(q_zams)
        print(beta_a * (1 - self.fcore))
    
        sel = (q > q_min) * (q < q_max)
        t = q * (q > q_min) * (q < q_max)
        return t[sel], q_zams[sel]


In [48]:
def p_qobs_from_zams(q, beta, fcore, fsn_a, fsn_b, q_zams_min, q_zams_max):

    def q_zams_uniform_pdf(qzams, q_zams_min, q_zams_max):
        qmax = jnp.where(jnp.greater(q_zams_max, 1), 1, q_zams_max)
        pdf = 1 / (qmax - q_zams_min)
        return jnp.where(jnp.less(qzams, q_zams_min) | jnp.greater(qzams, q_zams_max) | jnp.greater(qzams, 1), 0, pdf)

    def q_to_qzams(q):
        return q * (1-fsn_a)/(1-fsn_b) - beta * (1-fcore)

    qzams_A = q_to_qzams(q)
    qzams_B = q_to_qzams(1/q)

    qA_pdf = q_zams_uniform_pdf(qzams_A, q_zams_min, q_zams_max)
    qB_pdf = q_zams_uniform_pdf(qzams_B, q_zams_min, q_zams_max)

    dqzams_dq_A = (1-fsn_a)/(1-fsn_b)
    dqzams_dq_B = (1/q**2)*(1-fsn_a)/(1-fsn_b)

    pdf = qA_pdf * dqzams_dq_A + qB_pdf * dqzams_dq_B

    qq = jnp.linspace(q_zams_min, q_zams_max,1000)
    Dqzams_dq_B = (1/qq**2)*(1-fsn_a)/(1-fsn_b)
    norm_pdf = q_zams_uniform_pdf(qq, q_zams_min, q_zams_max) * dqzams_dq_A + q_zams_uniform_pdf(1/qq, q_zams_min, q_zams_max) * Dqzams_dq_B
    norm = jnp.trapezoid(norm_pdf, qq)

    return pdf / norm


def p_qbh_from_zams(q, beta, fcore, fsn_a, fsn_b, q_zams_min, q_zams_max):

    def q_zams_uniform_pdf(qzams, q_zams_min, q_zams_max):
        qmax = jnp.where(jnp.greater(q_zams_max, 1), 1, q_zams_max)
        pdf = 1 / (qmax - q_zams_min)
        return jnp.where(jnp.less(qzams, q_zams_min) | jnp.greater(qzams, q_zams_max) | jnp.greater(qzams, 1), 0, pdf)

    def q_to_qzams(q):
        return q * (1-fsn_a)/(1-fsn_b) - beta * (1-fcore)

    qzams_A = q_to_qzams(q)
    # qzams_B = q_to_qzams(1/q)

    qA_pdf = q_zams_uniform_pdf(qzams_A, q_zams_min, q_zams_max)
    # qB_pdf = q_zams_uniform_pdf(qzams_B, q_zams_min, q_zams_max)

    dqzams_dq_A = (1-fsn_a)/(1-fsn_b)
    # dqzams_dq_B = (1/q**2)*(1-fsn_a)/(1-fsn_b)

    pdf = qA_pdf * dqzams_dq_A # + qB_pdf * dqzams_dq_B

    qq = jnp.linspace(q_zams_min, q_zams_max,1000)
    # Dqzams_dq_B = (1/qq**2)*(1-fsn_a)/(1-fsn_b)
    norm_pdf = q_zams_uniform_pdf(qq, q_zams_min, q_zams_max) * dqzams_dq_A # + q_zams_uniform_pdf(1/qq, q_zams_min, q_zams_max) * Dqzams_dq_B
    norm = jnp.trapezoid(norm_pdf, qq)

    return pdf / norm

In [49]:
from ipywidgets import widgets

In [91]:
qobs_s = jnp.linspace(0.01, 1, 1000)
qcrit = q_crit()

def plot_p_qobs(fsn_a=0.2, fsn_b=0.2, fcore=0.34, β=0.5, ζ=7):
    qmin = 1/qcrit(ζ, β)
    qmax = fcore * (1 - fsn_a) * qcrit(ζ, 0) - β * (1 - fcore)
    ps = p_qobs_from_zams(qobs_s, β, fcore, fsn_a, fsn_b, qmin, qmax)
    plt.plot(qobs_s, ps)
    norm = jnp.trapezoid(ps, qobs_s)
    plt.ylabel(r'$p(q_\text{obs})$')
    plt.xlabel(r'$q_\text{obs}$')
    _, yhigh = plt.ylim()
    plt.text(0, yhigh*.95, f"{norm:.2f}")

qbh_s = jnp.linspace(0.01, 2, 1000)


def plot_p_qbh(fsn_a=0.2, fsn_b=0.2, fcore=0.34, β=0.5, ζ=7):
    qmin = 1/qcrit(ζ, β)
    qmax = fcore * (1 - fsn_a) * qcrit(ζ, 0) - β * (1 - fcore)
    ps = p_qbh_from_zams(qbh_s, β, fcore, fsn_a, fsn_b, qmin, qmax)
    plt.plot(qbh_s, ps)
    norm = jnp.trapezoid(ps, qbh_s)
    plt.ylabel(r'$p(q_\text{BH})$')
    plt.xlabel(r'$q_\text{BH}$')
    plt.axvline(1, ls='--', color='k')
    _, yhigh = plt.ylim()
    plt.text(0, yhigh*.95, f"{norm:.2f}")



def plot_both(fsn_a=0.2, fsn_b=0.2, fcore=0.34, β=0.5, ζ=7):
    plt.figure(figsize=(12, 4))
    qmin = 1/qcrit(ζ, β)
    qmax = fcore * (1 - fsn_a) * qcrit(ζ, 0) - β * (1 - fcore)
    p_qobs = p_qobs_from_zams(qobs_s, β, fcore, fsn_a, fsn_b, qmin, qmax)
    p_qbhs = p_qbh_from_zams(qbh_s, β, fcore, fsn_a, fsn_b, qmin, qmax)
    plt.plot(qobs_s, p_qobs)
    plt.plot(qbh_s, p_qbhs)
    plt.axvline(1, ls='--', color='k')
    qbh_norm = jnp.trapezoid(p_qbhs, qbh_s)
    qobs_norm = jnp.trapezoid(p_qobs, qobs_s)
    plt.ylabel(r'$p(q)$')
    plt.xlabel(r'$q$')
    _, yhigh = plt.ylim()
    plt.text(0, yhigh*.95, f"{qbh_norm:.4f}")
    plt.text(0, yhigh*.9, f"{qobs_norm:.4f}")


In [81]:
i_plot = interactive(
    plot_p_qobs,
    fsn_a=(0, 1, 0.01),
    fsn_b=(0, 1, 0.01),
    fcore=(0, 1, 0.01),
    β=(0, 1, .01),
    ζ=(1, 12, .1)
)
output = i_plot.children[-1]
output.layout.height = '380px'
i_plot

interactive(children=(FloatSlider(value=0.2, description='fsn_a', max=1.0, step=0.01), FloatSlider(value=0.2, …

In [82]:
i_plot = interactive(
    plot_p_qbh,
    fsn_a=(0, 1, 0.01),
    fsn_b=(0, 1, 0.01),
    fcore=(0, 1, 0.01),
    β=(0, 1, .01),
    ζ=(1, 12, .1)
)
output = i_plot.children[-1]
output.layout.height = '350px'
i_plot

interactive(children=(FloatSlider(value=0.2, description='fsn_a', max=1.0, step=0.01), FloatSlider(value=0.2, …

In [None]:
i_plot = interactive(
    plot_both,
    fsn_a=(0, 1, 0.01),
    fsn_b=(0, 1, 0.01),
    fcore=(0, 1, 0.01),
    β=(0, 1, .01),
    ζ=(1, 12, .1)
)
output = i_plot.children[-1]
output.layout.height = '350px'
i_plot

interactive(children=(FloatSlider(value=0.2, description='fsn_a', max=1.0, step=0.01), FloatSlider(value=0.2, …