In [None]:
import glob
from os import path
import re

from astropy.table import QTable
from astropy.constants import G
import astropy.units as u
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
from scipy.integrate import simps
from tqdm import tqdm

from twoface.mass import period_at_surface, stellar_radius

import mesa_reader as mr

from helpers import (MESAHelper, MatchedSimulatedSample, 
                     compute_dlne, solve_final_ea)

### Load the MESA models

In [None]:
# Directory containing the MESA model outputs read by helpers.MESAHelper.
MESA_DIR = '../mesa/'
mesa = MESAHelper(MESA_DIR)

### Load the data

In [None]:
# Unimodal high-K binaries from the twoface catalog paper; strings are read as
# str rather than bytes.
unimodal = QTable.read(
    '../../twoface/paper/1-catalog/tables/highK-unimodal.fits',
    character_as_bytes=False,
)

# Keep only rows with clean_flag == 0, then only those with log g > 2.
clean_uni = unimodal[unimodal['clean_flag'] == 0]
high_logg = clean_uni[clean_uni['LOGG'] > 2]

In [None]:
# Median primary mass, median minimum companion mass (both NaN-aware), and
# median log g of the high-log g sample. The > -999 cut presumably drops a
# missing-value sentinel; it is redundant here given the LOGG > 2 selection.
median_M1 = np.nanmedian(high_logg['M1'])
median_M2_min = np.nanmedian(high_logg['M2_min'])
logg_vals = high_logg['LOGG']
median_logg = np.median(logg_vals[logg_vals > -999])
(median_M1, median_M2_min, median_logg)

### Plot the stellar evolution curves

In [None]:
# Each track's color encodes its mass key on a 0.8-3 scale (reversed rainbow).
cmap = plt.get_cmap('rainbow_r')
norm = mpl.colors.Normalize(vmin=0.8, vmax=3)

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
for mass in sorted(mesa.h.keys()):
    track = mesa.h[mass]
    # Drop the part of the history before the minimum-radius point.
    start = track.log_R.argmin()
    ax.plot(track.Teff[start:], track.log_g[start:],
            marker='', linewidth=2, alpha=0.8,
            color=cmap(norm(float(mass))),
            label=r'${0}\,{{\rm M}}_\odot$'.format(mass))

# Both axes are reversed (Kiel-diagram orientation): Teff grows to the left,
# log g grows downward.
ax.set_ylim(5, -0.1)
ax.set_xlim(13000, 3000)
ax.legend(loc='upper left', fontsize=12, borderaxespad=1.)

ax.set_xlabel(r'$T_{\rm eff}$ [K]')
ax.set_ylabel(r'$\log g$')

fig.tight_layout()
fig.savefig('../paper/figures/mesa.pdf')

## Generate a simulated population of binaries

In [None]:
# Simulated-sample generator built from the observed log g, M1 and minimum M2
# (presumably drawing simulated binaries matched to these distributions;
# see helpers.MatchedSimulatedSample). Seeded for reproducibility.
s = MatchedSimulatedSample(
    logg=high_logg['LOGG'],
    M1=high_logg['M1'],
    M2=high_logg['M2_min'],
    mesa_helper=mesa,
    seed=42,
)

In [None]:
# Number of simulated binaries to draw.
N_SIMULATED = 1024
t = s.generate_sample(size=N_SIMULATED)

In [None]:
# Compare the simulated sample `t` against the observed high-log g sample in
# four panels: P-e, M1, mass ratio q = M2/M1, and log g. In every panel the
# simulated data is drawn first and the observed data second.
fig, axes = plt.subplots(1, 4, figsize=(15, 5))

# `density=True` replaces the `normed` argument that Matplotlib removed.
hist_style = dict(alpha=0.5, density=True)

# Observed rows usable in each histogram. np.histogram with bins='auto' raises
# on NaN input, so q needs BOTH masses to be finite, not just M2_min.
finite_M1 = np.isfinite(high_logg['M1'])
finite_q = finite_M1 & np.isfinite(high_logg['M2_min'])

ax = axes[0]
ax.scatter(t['P'], t['e'], marker='.', label='simulated')
ax.scatter(high_logg['P'], high_logg['e'], marker='.', label='observed')
ax.set_xscale('log')
ax.set_ylim(0, 1)
ax.set_xlabel('$P$')
ax.set_ylabel('$e$')
ax.legend(loc='best')

axes[1].hist(t['M1_orig'], bins='auto', **hist_style)
axes[1].hist(high_logg['M1'][finite_M1], bins='auto', **hist_style)
axes[1].set_xlabel('$M_1$')

q_bins = np.linspace(0.01, 1.5, 10)
axes[2].hist(t['M2'] / t['M1_orig'], bins=q_bins, **hist_style)
axes[2].hist(high_logg['M2_min'][finite_q] / high_logg['M1'][finite_q],
             bins='auto', **hist_style)
axes[2].set_xlabel('$q$')

axes[3].hist(t['logg'], bins='auto', **hist_style)
axes[3].hist(high_logg['LOGG'], bins='auto', **hist_style)
axes[3].set_xlabel(r'$\log g$')

fig.tight_layout()

In [None]:
# Overlay simulated M1_orig, observed M1 (finite values only) and simulated M1.
# The simulated-M1 histogram reuses the bin edges of the M1_orig histogram.
# `density=True` replaces the `normed` argument that Matplotlib removed.
fig, ax = plt.subplots(1, 1)
_, m1_bins, _ = ax.hist(t['M1_orig'].value, bins='auto', density=True,
                        alpha=0.5, label='simulated (M1_orig)')
ax.hist(high_logg['M1'][np.isfinite(high_logg['M1'])], bins='auto',
        density=True, alpha=0.5, label='observed')
ax.hist(t['M1'], bins=m1_bins, density=True, alpha=0.5,
        label='simulated (M1)')
ax.set_xlabel('$M_1$')
ax.legend(loc='best');

## Simulate circularization

In [None]:
# Per-binary change in ln(e) from helpers.compute_dlne. The semi-major axis
# `a` is passed in but not updated in this cell.
dlnes = [compute_dlne(row['logg'], M1=row['M1'], M2=row['M2'], a=row['a'],
                      mesa_helper=mesa)
         for row in tqdm(t)]
t['dlne'] = dlnes

# Final eccentricity from ln(e_f) = ln(e) + dlne.
t['e_f'] = np.exp(np.log(t['e']) + t['dlne'])

# Disabled: would set e_f to NaN where the periastron distance a(1 - e_f) is
# smaller than R1 (presumably the primary's radius — confirm in helpers).
# mask = (t['a'] * (1 - t['e_f'])) < t['R1']
# t['e_f'][mask] = np.nan

In [None]:
# Final eccentricity against the (negated) change in ln(e). Since
# e_f = e * exp(dlne), larger -dlne means stronger damping of e.
fig, ax = plt.subplots(1, 1, figsize=(6, 5))
ax.scatter(-t['dlne'], t['e_f'])
ax.set_xscale('log')
# Reversed limits: the most strongly damped systems sit on the left.
ax.set_xlim(1E6, 1E-12)
ax.set_xlabel(r'$-\Delta \ln e$')
ax.set_ylabel(r'$e_f$');

In [None]:
def plot_simulated_P_e(t, e_col='e_f', P_col='P'):
    """Plot initial and evolved period-eccentricity distributions of simulated binaries.

    Three panels share the eccentricity (y) axis:

    0. initial ``t['P']`` vs. ``t['e']``;
    1. final ``t[P_col]`` vs. ``t[e_col]``;
    2. final ``t[P_col] / P_surf`` vs. ``t[e_col]``. Here ``P_surf`` is
       ``period_at_surface(t['M1'], t['logg'], t[e_col], t['M2'])``.

    Markers are colored by ``t['logg']`` on a 2-4 color scale.

    Parameters
    ----------
    t : table-like
        Must provide columns 'P', 'e', 'M1', 'M2', 'logg', ``e_col`` and ``P_col``.
    e_col : str, optional
        Name of the column holding the final eccentricity.
    P_col : str, optional
        Name of the column holding the final period (labelled in days).

    Returns
    -------
    matplotlib.figure.Figure
        The three-panel figure.
    """
    fig, axes = plt.subplots(1, 3, figsize=(14, 5.5), sharey=True)

    cmap = plt.get_cmap('inferno')
    style = dict(marker='o', edgecolor='#555555', linewidth=0.5,
                 alpha=0.5, vmin=2, vmax=4, cmap=cmap,
                 s=30, c=t['logg'], rasterized=True)

    P_surf = period_at_surface(t['M1'], t['logg'], t[e_col], t['M2'])

    # Actually plot the markers
    cs = axes[0].scatter(t['P'], t['e'], **style)
    axes[1].scatter(t[P_col], t[e_col], **style)
    axes[2].scatter(t[P_col]/P_surf, t[e_col], **style)

    # Label all the things
    axes[0].set_ylabel(r'$e$')
    axes[0].set_xlabel(r'$P$ [day]')
    axes[1].set_xlabel(r'$P$ [day]')
    axes[2].set_xlabel(r'$P/P_{\rm surface}$')

    axes[0].set_title('initial')
    axes[1].set_title('final')
    axes[2].set_title('final')

    # Scales, lims, ticks:
    for ax in axes:
        ax.set_xscale('log')

    axes[0].set_ylim(-0.05, 1)

    # Major ticks at each decade 0.1 ... 1e4; minor ticks at 1-10x every decade
    # except the last. They are computed once and applied to each axis
    # explicitly, rather than read back from whichever axis a loop variable
    # last pointed at.
    major_ticks = 10**np.arange(-1, 4+0.1)
    minor_ticks = np.concatenate([x*np.arange(1, 10+1) for x in major_ticks[:-1]])
    for ax, xmax in zip(axes, (1E4, 1E4, 1.5E3)):
        ax.xaxis.set_ticks(major_ticks)
        ax.xaxis.set_ticks(minor_ticks, minor=True)
        ax.set_xlim(8E-1, xmax)

    # Colorbar, in a manually placed axes to the right of the panels.
    cax = fig.add_axes([0.865, 0.165, 0.02, 0.615])
    cb = fig.colorbar(cs, cax=cax)
    # NOTE(review): the colorbar is vertical, so these x-axis settings likely
    # have no visible effect; kept to preserve the original output.
    cb.ax.xaxis.set_ticks_position('top')
    cb.ax.xaxis.set_label_position('top')
    cb.set_label(r'$\log g$', labelpad=10)
    cb.solids.set_rasterized(True)
    cb.solids.set_edgecolor('face')
    cb.set_ticks(np.arange(2, 4+0.1, 0.5))
    cb.ax.invert_yaxis()

    fig.tight_layout()
    fig.subplots_adjust(top=0.78, right=0.85, wspace=0.1)
    fig.set_facecolor('w')

    fig.suptitle(r'${\bf Simulated\,\,binaries}$', y=0.94, x=0.45, fontsize=26)

    return fig

In [None]:
fig = plot_simulated_P_e(t)
# The scatter markers are already rasterized inside plot_simulated_P_e
# (rasterized=True in its marker style). savefig has no `rasterized` option,
# and recent Matplotlib rejects unknown savefig keywords. `dpi` sets the
# output resolution.
fig.savefig('../paper/figures/simulated.pdf', dpi=250)

---

## Solving $de/dt$ and $da/dt$ simultaneously

In [None]:
# Work on a copy so the joint e/a solution below can overwrite 'e_f' without
# touching the e-only results stored in `t`.
t2 = t.copy()

In [None]:
# Evolve e and a together for every simulated binary using
# helpers.solve_final_ea, which returns a (final e, final a) pair per row.
final_states = [solve_final_ea(row['e'], row['a'], row['logg'],
                               row['M1'], row['M2'], mesa)
                for row in tqdm(t2)]

efs = np.array([ef for ef, _ in final_states])
afs = u.Quantity([af for _, af in final_states]).to(u.au)
t2['e_f'] = efs
t2['a_f'] = afs

# Final period from Kepler's third law: P = 2*pi*sqrt(a^3 / (G (M1 + M2))).
t2['P_f'] = 2*np.pi * np.sqrt(t2['a_f']**3 / (G * (t2['M1'] + t2['M2']))).to(u.day)

In [None]:
# Same figure as above, but with the final period from the joint e/a solution.
fig = plot_simulated_P_e(t2, e_col='e_f', P_col='P_f')

Conclusion: evolving $a$ and $e$ together gives almost the same final $P$–$e$ distribution as the simpler case above, where only $e$ is evolved and $a$ (hence $P$) is held fixed.

In [None]:
# Initial period vs. final period from the joint e/a solution. Points on the
# dashed 1:1 line had their period unchanged.
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.scatter(t['P'], t2['P_f'])
ax.plot([1E0, 1E4], [1E0, 1E4], color='k', linestyle='--', linewidth=1)
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlim(1E0, 1E4)
ax.set_ylim(1E0, 1E4)
ax.set_xlabel(r'initial $P$ [day]')
ax.set_ylabel(r'final $P$ (joint $a$, $e$) [day]');

In [None]:
# Final eccentricity from the e-only calculation (t) vs. the joint e/a
# solution (t2). The dashed line marks perfect agreement.
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.scatter(t['e_f'], t2['e_f'])
ax.plot([0, 1], [0, 1], color='k', linestyle='--', linewidth=1)
ax.set_xlabel(r'$e_f$ ($e$ only)')
ax.set_ylabel(r'$e_f$ (joint $a$, $e$)');