In [None]:
from astropy import units as u
from astropy.io import ascii
import numpy as np
from scipy import integrate
from matplotlib import pyplot as plt
import csv

In [None]:
# equation for semimajor axis change under tidal decay and stellar mass loss. See Mustill & Villaver (2012)

def eom(t,x,star,Mpl):
    """Right-hand side da/dt for the planet's semimajor axis.

    Sum of tidal decay driven by the star's convective envelope and orbital
    expansion from stellar mass loss (Mustill & Villaver 2012). Quantities are
    in the code units set up by ``read_star`` (au, Msun, yr), so the mean
    motion is 2*pi*sqrt(M/a**3).

    Parameters
    ----------
    t : float
        Time. Stellar properties are linearly interpolated from the track in
        ``star`` (np.interp clamps to the end values outside ``star.t``).
    x : array_like
        Semimajor axis. solve_ivp passes a length-1 array; any shape is
        accepted and the derivative is returned elementwise with that shape.
    star : Star
        Track (t, Ms, Rs, Menv, Ls, mdot) and tidal parameters
        (etaf, cf, gammaf, fprime).
    Mpl : float
        Planet mass.

    Returns
    -------
    ndarray
        da/dt, same shape as ``x``.
    """
    a = np.asarray(x, dtype=float)
    Ms = np.interp(t,star.t,star.Ms)
    Rs = np.interp(t,star.t,star.Rs)
    Menv = np.interp(t,star.t,star.Menv)
    Ls = np.interp(t,star.t,star.Ls)
    mdot = np.interp(t,star.t,star.mdot)
    #mean motion
    n = 2*np.pi*np.sqrt(Ms/(a**3))
    #convective timescale
    tconv = (Menv*Rs*Rs/(star.etaf*Ls))**(1./3.)

    # Tidal factor: fprime, reduced by (pi/(n*cf*tconv))**gammaf when that
    # ratio is below 1. np.minimum keeps this elementwise, so array-valued
    # `a` works; a scalar `if freq > 1` fails for more than one element.
    freq = (np.pi/(n*star.cf*tconv))**star.gammaf
    f2s = star.fprime*np.minimum(freq, 1.0)

    merat = Menv/Ms
    mrat = Mpl/Ms

    adot_tide = -merat*(1+mrat)*mrat*(Rs/a)**7*Rs*2*f2s/(9*tconv)
    # orbital response to a change of the total mass
    adot_ml = -a*mdot/(Ms+Mpl)

    return adot_tide + adot_ml

class Star:
    """Stellar evolution track plus the parameters of the tidal prescription.

    The track arrays are sampled at the times ``t`` and interpolated in time
    by ``eom`` (``inside`` uses ``t`` and ``Rs``). The scalar parameters enter
    the tidal term computed in ``eom``.

    Parameters
    ----------
    t : array_like
        Sample times of the track.
    Ms, Rs, Menv, Ls : array_like
        Stellar mass, radius, envelope mass and luminosity at each time.
    mdot : array_like
        Used by ``eom`` as the rate of change of the stellar mass.
    etaf : float, optional
        Divisor in the convective timescale (Menv*Rs**2/(etaf*Ls))**(1/3).
    gammaf : float, optional
        Exponent applied to pi/(n*cf*tconv) when reducing the tidal factor.
    cf : float, optional
        Multiplier of the convective timescale in that ratio.
    fprime : float, optional
        Upper limit of the tidal factor.
    """

    def __init__(self,t,Ms,Rs,Menv,Ls,mdot,etaf=3,gammaf=2,cf=1,fprime=4.5):
        # evolutionary track, one entry per sample time
        self.t, self.Ms, self.Rs = t, Ms, Rs
        self.Menv, self.Ls, self.mdot = Menv, Ls, mdot
        # tidal-prescription parameters
        self.etaf, self.gammaf = etaf, gammaf
        self.cf, self.fprime = cf, fprime

def read_star(file):
    """Read a stellar evolution track and convert it to code units (au, Msun, yr).

    The file is read as space-delimited CSV with columns
    Time, Teff, logL, Ms, Rs, Me0, Mee, Md.

    - A new column ``Ls`` is added, equal to 10**logL multiplied by one solar
      luminosity expressed in au**2 Msun / yr**3. This assumes logL is
      log10(L/Lsun).
    - ``Rs`` is multiplied by one solar radius in au. This assumes the file
      gives radii in solar radii.

    Parameters
    ----------
    file : str or path-like
        Path of the track file.

    Returns
    -------
    astropy.table.Table
        The track, with ``Rs`` rescaled and ``Ls`` appended.
    """
    colnames = ('Time','Teff','logL','Ms','Rs','Me0','Mee','Md')

    # one solar luminosity / solar radius in code units
    lum_code = (1*u.Lsun).decompose().to(u.au**2 * u.Msun / u.yr**3).value
    rad_code = (1*u.Rsun).to(u.au).value

    track = ascii.read(file, format='csv', names=colnames, delimiter=' ')
    track.add_column(10**track['logL'] * lum_code, name='Ls')
    track['Rs'] = track['Rs'] * rad_code
    return track

def inside(t,x,star,mpl):
    """solve_ivp event: semimajor axis minus the interpolated stellar radius.

    The value crosses zero when the orbit meets the stellar surface.
    ``inside.terminal = True`` stops the integration there, and
    ``inside.direction = -1`` triggers only on a positive-to-negative crossing.

    Parameters
    ----------
    t : float
        Time at which ``star.Rs`` is interpolated.
    x : array_like
        State vector from solve_ivp (length 1); a plain scalar also works.
    star : Star
        Track providing ``t`` and ``Rs``.
    mpl : float
        Unused; present because solve_ivp passes the same ``args`` to events
        as to ``eom``.

    Returns
    -------
    float
        a - Rs(t), as a scalar. solve_ivp refines the event time with a scalar
        root finder, and returning a length-1 array relies on NumPy's
        deprecated implicit conversion of ndim>0 arrays to floats.
    """
    a = np.ravel(x)[0]
    return a - np.interp(t,star.t,star.Rs)

inside.terminal = True
inside.direction = -1

In [None]:
# Read the 1.5 Msun AGB stellar model of Vassiliadis & Wood (1993).

data = read_star('agb1p5.dat')

# Me0 is used as the envelope mass. Md is negated to give mdot: presumably the
# file lists a positive mass-loss rate (confirm against the model file).
star = Star(t=data['Time'], Ms=data['Ms'], Rs=data['Rs'], Menv=data['Me0'],
            Ls=data['Ls'], mdot=-data['Md'])

In [None]:
# Do a single test integration and plot the result.

a0 = np.array([2.5])                 # initial semimajor axis [au]
mpl = 3e-6                           # planet mass [Msun]
times = np.linspace(0, 1e6, 1001)    # output times [yr]

In [None]:
# Integration tolerances; the grid runs below reuse them.
rtol = 1e-12
atol = 1e-12

# Stops early if the `inside` event fires (orbit reaches the stellar radius).
sol = integrate.solve_ivp(
    eom,
    (times[0], times[-1]),
    a0,
    method='DOP853',
    t_eval=times,
    args=(star, mpl),
    rtol=rtol,
    atol=atol,
    events=inside,
)

In [None]:
# Planet semimajor axis against the stellar radius for the test integration.
fig, ax = plt.subplots()

ax.plot(sol.t, sol.y[0], label='planet semimajor axis')
ax.plot(star.t, star.Rs, label='stellar radius')
ax.set_xlim(times[0], times[-1])
ax.set_ylim(0, 10)
ax.set_xlabel('time [yr]')
ax.set_ylabel('distance [au]')
ax.legend()
plt.show()

In [None]:
# Display the solve_ivp result of the test integration (status, message, event times, ...).
sol

In [None]:
# Coarse grid: na initial semimajor axes (linear, 1-10 au) and nm planet
# masses (logarithmic, 1 Earth mass to 13 Jupiter masses, in Msun).

na = 301
nm = 31

agrid = np.linspace(1, 10, na)

log_mlo = np.log10((1*u.Mearth/u.Msun).decompose()).value
log_mhi = np.log10((13*u.Mjup/u.Msun).decompose()).value
mgrid = np.logspace(log_mlo, log_mhi, nm)

In [None]:
# Coarse-grid results file. If it exists, resume after the last finished mass;
# otherwise create it with a header row.
file = 'results_1p5_coarse.csv'
tiny = 1e-12  # absolute tolerance for comparing values read back from the CSV
try:
    results = ascii.read(file,format='csv')
except FileNotFoundError:
    start = 0
    with open(file,'w',newline = '') as csvfile:
        writer = csv.writer(csvfile,delimiter=',')
        writer.writerow(['M [M_sol]','a_i [au]','a_f [au]','status'])
else:
    if len(results) == 0:
        # Header only: no mass was finished before the previous run stopped.
        start = 0
    else:
        # The grid loop appends a mass only after all of its semimajor axes
        # are integrated, so the last mass in the file is complete. Resume at
        # the next grid index instead of re-running (and duplicating) it.
        mdone = results['M [M_sol]'][-1]
        start = np.min(np.where(np.abs(mdone-mgrid) <= tiny)[0]) + 1

In [None]:
# Set to True to run the coarse (nm x na) grid of integrations and its plots;
# when False those cells are skipped.
run_main = False

In [None]:
if run_main:
    # Final semimajor axis and solve_ivp status for every (mass, initial a)
    # pair. Status codes: 0 = reached the end time, 1 = stopped by the
    # `inside` event, -1 = integration failed.
    af = np.zeros((nm,na))
    status = np.zeros((nm,na))

    for i in range(nm):
        print('Mass {:8e} ({:4d} of {:4d})'.format(mgrid[i],i+1,nm))
        if i < start:
            print('Already run...')
            continue
        for j in range(na):
            sol = integrate.solve_ivp(eom,(times[0],times[-1]),[agrid[j]],method='DOP853',t_eval=times,
                                      args=[star,mgrid[i]],rtol=rtol,atol=atol,events=inside)
            af[i,j] = sol.y[0][-1]
            status[i,j] = sol.status
            print('SMA {:8e} ({:4d} of {:4d});   status: {:3f}'.format(agrid[j],j+1,na,status[i,j]))

        # Save this mass's results in one append (one open, one writer), so
        # an interrupted run can resume from the next mass.
        with open(file,'a',newline = '') as csvfile:
            writer = csv.writer(csvfile,delimiter=',')
            writer.writerows([mgrid[i],agrid[j],af[i,j],status[i,j]] for j in range(na))

In [None]:
if run_main:
    # Final vs initial semimajor axis, one series per planet mass. Only runs
    # that reached the end time (status 0) keep their final a; the others are
    # plotted at 0. np.where is used because af*(1-status) gives 2*af when
    # status is -1 (failed integration).
    af_survived = np.where(status == 0, af, 0.0)
    plt.figure()
    plt.plot(agrid,np.transpose(af_survived),'.')
    plt.xlabel('initial a [au]')
    plt.ylabel('final a [au]')
    plt.savefig('af.pdf')

In [None]:
if run_main:
    # Final semimajor axis over the (initial a, planet mass) grid. Runs that
    # did not reach the end time (status != 0) are shown as 0; np.where
    # avoids the 2*af that af*(1-status) gives for status -1.
    af_survived = np.where(status == 0, af, 0.0)
    plt.figure()
    plt.contourf(agrid,mgrid,af_survived)
    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel('initial a [au]')
    plt.ylabel('planet mass [MSol]')
    plt.colorbar(label = 'final a [au]')
    plt.savefig('af_contour.pdf')

In [None]:
# More detail around the boundary where the run status changes (e.g. planets
# stopped by the `inside` event vs. those reaching the end time).
# First reload the coarse-grid results; `file` still names the coarse CSV here.

results = ascii.read(file,format='csv')

In [None]:
# Locate status changes for each planet mass.
# Rows are grouped by mass and, within a mass, ordered by increasing initial a
# (the order the grid loop writes them). A boundary is therefore a pair of
# consecutive rows with the same mass but a different status; record the
# initial semimajor axes of both rows.
# Computed column-wise. The name `status` is not reused, so the 2-D status
# array used by the grid cells above is left intact.
mass_col = np.asarray(results['M [M_sol]'])
ai_col = np.asarray(results['a_i [au]'])
status_col = np.asarray(results['status'])

same_mass = np.abs(np.diff(mass_col)) < tiny
status_changed = np.abs(np.diff(status_col)) >= tiny
upper = np.flatnonzero(same_mass & status_changed) + 1  # row just after each change

mass = mass_col[upper]
boundary = np.column_stack((ai_col[upper - 1], ai_col[upper]))

In [None]:
# Coarse-grid boundary: the two initial a values bracketing each status change.
fig, ax = plt.subplots()
ax.loglog(mass, boundary, '.')
ax.set_xlabel('planet mass [MSol]')
ax.set_ylabel('initial a bracketing status change [au]')
plt.show()

In [None]:
# Fine grid: one row per boundary found above, with na initial semimajor axes
# spanning [lower a, upper a + 2*da] (see the fine loop). da equals the
# coarse-grid spacing (10 - 1) / (301 - 1) au.
nm, na, da = len(mass), 91, 0.03
af = np.zeros((nm, na))
status = np.zeros_like(af)
times = np.linspace(0, 1e6, 1001)

In [None]:
# Fine-grid results file. If it exists, resume after the masses already done;
# otherwise create it with a header row.
#
# The fine loop iterates over `mass` (not `mgrid`). It appends exactly na-1
# rows per entry, because afine[0] is a coarse-grid point and is skipped.
# The number of finished entries is therefore len(results) // (na - 1). This
# also handles `mass` containing the same value more than once, where a
# lookup by value would be ambiguous. A header-only file gives start = 0.
# Assumes na has not changed since the file was written.
file = 'results_1p5_fine.csv'
tiny = 1e-12  # absolute tolerance for comparing values read back from the CSV
try:
    results = ascii.read(file,format='csv')
except FileNotFoundError:
    start = 0
    with open(file,'w',newline = '') as csvfile:
        writer = csv.writer(csvfile,delimiter=',')
        writer.writerow(['M [M_sol]','a_i [au]','a_f [au]','status'])
else:
    start = len(results) // (na - 1)

In [None]:
# Set to True to run the fine-grid integrations around each boundary;
# when False the fine loop is skipped.
run_fine = False

In [None]:
if run_fine:
    for i in range(nm):
        m_i = mass[i]
        print('Mass {:8e} ({:4d} of {:4d})'.format(m_i,i+1,nm))
        if i < start:
            print('Already run...')
            continue
        # afine[0] is the lower coarse-grid point, already integrated there,
        # so integration and output both start at index 1.
        afine = np.linspace(boundary[i,0],boundary[i,1]+2*da,na)
        for j in range(1, na):
            sol = integrate.solve_ivp(eom,(times[0],times[-1]),[afine[j]],method='DOP853',t_eval=times,
                                      args=[star,m_i],rtol=rtol,atol=atol,events=inside)
            af[i,j] = sol.y[0][-1]
            status[i,j] = sol.status
            print('SMA {:8e} ({:4d} of {:4d});   status: {:3f}'.format(afine[j],j+1,na,status[i,j]))

        # append this mass's na-1 rows in one go
        with open(file,'a',newline = '') as csvfile:
            writer = csv.writer(csvfile,delimiter=',')
            for j in range(1, na):
                writer.writerow([m_i,afine[j],af[i,j],status[i,j]])

In [None]:
# Reload the fine-grid results and locate each status change again.
results = ascii.read(file,format='csv')

tiny = 1e-12

# Rows for one fine-grid entry share a mass and increase in initial a, so a
# boundary is a pair of consecutive same-mass rows whose status differs;
# record both initial semimajor axes. Computed column-wise. The name `status`
# is not reused, so the 2-D status array of the fine loop is left intact.
# NOTE(review): if `mass` repeats a value, the two fine-grid blocks for it are
# adjacent and share a mass, so a status difference across their join is
# also reported as a boundary. Confirm whether that case occurs.
mass_col = np.asarray(results['M [M_sol]'])
ai_col = np.asarray(results['a_i [au]'])
status_col = np.asarray(results['status'])

same_mass = np.abs(np.diff(mass_col)) < tiny
status_changed = np.abs(np.diff(status_col)) >= tiny
upper = np.flatnonzero(same_mass & status_changed) + 1  # row just after each change

mass = mass_col[upper]
boundary = np.column_stack((ai_col[upper - 1], ai_col[upper]))

In [None]:
# Refined boundary: the two fine-grid initial a values bracketing each status change.
fig, ax = plt.subplots()

ax.loglog(mass, boundary, '.')
ax.set_xlabel('planet mass [MSol]')
ax.set_ylabel('initial a bracketing status change [au]')

plt.show()

In [None]:
# Final vs initial semimajor axis for the fine-grid runs. Runs that did not
# reach the end time (status != 0) are plotted at 0; np.where is used because
# a_f*(1-status) gives 2*a_f for status -1 (failed integration).
plt.figure()
survived = np.where(results['status'] == 0, results['a_f [au]'], 0.0)
plt.plot(results['a_i [au]'],survived,'.')
plt.xlabel('initial a [au]')
plt.ylabel('final a [au]')
plt.show()