In [1]:
%load_ext autoreload
%autoreload 2

import sys, os, pickle, time, warnings
import numpy as np, pandas as pd, scipy, scipy.stats as stats, tqdm, h5py
from copy import deepcopy as copy

# Plotting modules
import matplotlib, corner
from pylab import cm
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

In [2]:
# Plotting modules
font = {'family' : 'serif', 'weight' : 'normal',
        'size'   : 16}
legend = {'fontsize': 16}
matplotlib.rc('font', **font)
matplotlib.rc('legend', **legend)
plt.rc('text', usetex=True)
plt.rc('axes', labelsize=16)
plt.rc('xtick',labelsize=16)
plt.rc('ytick',labelsize=16)
plt.rc('legend',fontsize=16)

cm_default = plt.rcParams['axes.prop_cycle'].by_key()['color']

corner_kwargs = {'max_n_ticks':3, 'title_kwargs':{"fontsize": 16}, 'label_kwargs':{'fontsize':16}}

In [3]:
sys.path.extend(['../utilities/', '../models/'])
import samplers, disk_cone_plcut as dcp, plotting, transformations
from transformations import func_inv_jac, func_labels, label_dict
import functions

In [4]:
import disk_halo_mstogap as dh_msto

In [None]:
pi_mu, pi_err, abs_sin_lat, m_mu, log_pi_err, hz=1., alpha1=-1., alpha2=-1., alpha3=-1.,
                                Mto=4., Mms=8., Mms1=9., Mms2=7., fD=0.5, Mx=10., R0=8.27, degree=21):

# Test Disk Likelihood gradient

In [55]:
size = 1

In [65]:
kwargs = {'hz':1., 'alpha1':-1., 'alpha2':-1., 'alpha3':-1.,
          'Mto':4., 'Mms':8., 'Mms1':9., 'Mms2':7., 
          'fD':0.5, 'Mx':10., 'R0':8.27, 'degree':21}

In [80]:
sample = {}; true_pars={}; latent_pars={}; cmpt=0;
filename = '/data/asfe2/Projects/mwtrace_data/mockmodel/sample.h'
with h5py.File(filename, 'r') as hf:
    subset = (hf['sample']['M'][...]>5)
    subsample  = np.sort(np.random.choice(np.arange(np.sum(subset)), size=size, replace=False))
    for key in hf['sample'].keys():
        sample[key]=hf['sample'][key][...][subset][subsample]
    # Get true parameters
    for key in hf['true_pars'].keys():
        if not key in np.arange(3).astype(str):
            true_pars[key]=hf['true_pars'][key][...]
        else:
            true_pars[int(key)]={}
            for par in hf['true_pars'][key].keys():
                true_pars[int(key)][par]=hf['true_pars'][key][par][...]
for j in range(3): true_pars[j]['w']*=size
    
args = (sample['parallax_obs'], sample['parallax_error'], sample['sinb'], \
    sample['m_obs'], np.log(sample['parallax_error']))
sample

{'M': array([8.88118812]),
 'cmpt': array([2]),
 'l': array([1.03700749]),
 'm': array([23.32981651]),
 'm_err': array([0.20543416]),
 'm_obs': array([23.69441515]),
 'parallax_error': array([0.30245281]),
 'parallax_obs': array([0.0331892]),
 's': array([7.75756953]),
 'sinb': array([0.92705025]),
 'source_id': array([53961])}

In [67]:
dh_msto.log_expmodel_perr(*args, **kwargs), dh_msto.log_expmodel_perr_grad(*args, **kwargs, grad=True)

(array([-4.51199479]), (array([-4.51197578]), array([[-1.55221875e+00],
         [ 2.63338940e+00],
         [-2.00000000e+00],
         [            nan],
         [ 6.45475307e-16],
         [-2.16582940e-08]])))

In [81]:
kwargs = {'Mms':8., 'Mms1':9., 'Mms2':7., 
          'Mx':10., 'R0':8.27, 'degree':21}
model = lambda x: dh_msto.log_expmodel_perr_grad(*args, hz=x[0], alpha3=x[1], fD=x[2], alpha1=x[4], alpha2=x[5],
                                                Mto=x[3],  **kwargs, grad=True)[0]
grad = lambda x: dh_msto.log_expmodel_perr_grad(*args, hz=x[0], alpha3=x[1], fD=x[2], alpha1=x[4], alpha2=x[5],
                                                Mto=x[3],  **kwargs, grad=True)[1][:,0]

In [120]:
x0 = np.array([1,-0.5,0.5,4,-1,-1.1])
np.vstack(( scipy.optimize.approx_fprime(x0, model, 1e-10), grad(x0) )).T

array([[ 3.05194092e+00,  3.05193506e+00],
       [ 0.00000000e+00, -4.06759432e-31],
       [ 2.00000017e+00,  2.00000000e+00],
       [ 4.97379915e-04,             nan],
       [-3.56816798e-01, -3.56844716e-01],
       [-1.56852309e-02, -1.56866449e-02]])

In [121]:
%timeit grad(x0)

14.3 ms ± 220 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [36]:
x0 = np.array([1,-0.5,0.5,4,-1,-1])
np.vstack(( scipy.optimize.approx_fprime(x0, model, 1e-12), grad(x0) )).T

array([[ 0.59729999, -1.75979219],
       [ 0.        , -0.08927907],
       [ 0.        ,  1.79244446],
       [ 0.        ,         nan],
       [-2.48645549, -2.4520428 ],
       [ 0.        , -2.59676287]])

In [21]:
x0 = np.array([1,-0.5,0.5,4,-1,-1])
np.vstack(( scipy.optimize.approx_fprime(x0, model, 1e-6), grad(x0) )).T

array([[-2.40207012e+00, -1.76098115e+00],
       [ 0.00000000e+00, -8.91171950e-02],
       [ 0.00000000e+00,  1.79300079e+00],
       [ 0.00000000e+00,             nan],
       [-2.48654675e+00, -2.45002985e+00],
       [ 4.44089210e-10, -2.58722967e+00]])

In [236]:
x0 = np.array([1,-0.5,0.5,4,-1,-1])
np.vstack(( scipy.optimize.approx_fprime(x0, model, 1e-5), grad(x0) )).T

array([[  2.28322041,   2.27883744],
       [ -1.45405762,  -1.45775994],
       [ -1.90662084,  -1.90660267],
       [  0.97108056,          nan],
       [  0.04801714,  -0.01253216],
       [  0.02143161, -11.97345119]])

In [180]:
np.exp(model(x0))

0.0056116752797461065

In [170]:
def log_expmodel_perr_grad(pi_mu, pi_err, abs_sin_lat, m_mu, log_pi_err, hz=1., alpha1=-1., alpha2=-1., alpha3=-1.,
                                Mto=4., Mms=8., Mms1=9., Mms2=7., fD=0.5, Mx=10., R0=8.27, degree=21, 
                               n=3, grad=False):

    beta = abs_sin_lat/hz

    # Absolute magnitude not known
    Mag_bounds = [-np.inf, Mto, Mms2, Mms1, Mx]

    ii=2
    a = np.exp((Mag_bounds[ii  ]+10-m_mu)*np.log(10)/5)
    b = np.exp((Mag_bounds[ii+1]+10-m_mu)*np.log(10)/5)
    
    p_model = np.zeros((4, len(pi_mu)))
    # Gauss - Hermite Quadrature
    args = (beta, n*np.ones(len(pi_mu)), pi_mu, pi_err, a, b)
    p_mode = functions.get_fooroots_ridder_hm(dh_msto.expmodel_perr_logit_grad, a=a+1e-15, b=b, args=args)
    curve = dh_msto.expmodel_perr_d2logIJ_dp2(p_mode, *args[:-2], transform='logit_ab', a=a, b=b) / \
                                functions.jac(p_mode, transform='logit_ab', a=a, b=b)**2
    z_mode = functions.trans(p_mode, transform='logit_ab', a=a, b=b)
    sigma = 1/np.sqrt(-curve)
    p_integral = functions.integrate_gh_gap(dh_msto.expmodel_perr_integrand, z_mode, sigma, args[:-2], transform='logit_ab', a=a, b=b, degree=10)
    p_model[ii] = p_integral

    grad_lambda = np.zeros((pi_mu.shape[0], 6)) + np.nan
    # n
    dp_model_dn = np.zeros((4, len(pi_mu)))
    
    p_integral = np.zeros(len(pi_mu))
    # Gauss - Hermite Quadrature
    args = (beta, n*np.ones(len(pi_mu)), pi_mu, pi_err, a, b)
    p_mode = functions.get_fooroots_ridder_hm(dh_msto.expmodel_perr_logit_grad_dn, a=a+1e-15, b=b, args=args)
    curve = dh_msto.expmodel_perr_d2logIJ_dp2_dn(p_mode, *args[:-2], transform='logit_ab', a=a, b=b) / \
                                functions.jac(p_mode, transform='logit_ab', a=a, b=b)**2
    z_mode = functions.trans(p_mode, transform='logit_ab', a=a, b=b)
    sigma = 1/np.sqrt(-curve)
    p_integral = functions.integrate_gh_gap(dh_msto.expmodel_perr_integrand_dn, z_mode, sigma, args[:-2], transform='logit_ab', a=a, b=b, degree=10)
    dp_model_dn[ii] = p_integral
    
    print(dh_msto.expmodel_perr_logit_grad_dn(0.96800718+0.01, args))
    print(a, b, p_mode)
    print(curve)

    return p_model[ii], dp_model_dn[ii]

In [171]:
kwargs = {'hz':1., 'alpha1':-1., 'alpha2':-1., 'alpha3':-1.,
          'Mto':4., 'Mms':8., 'Mms1':9., 'Mms2':7., 
          'fD':0.5, 'Mx':10., 'R0':8.27, 'degree':21}
model = lambda x: log_expmodel_perr_grad(*args, n=x, **kwargs, grad=False)[0][0]
grad = lambda x: log_expmodel_perr_grad(*args, n=x, **kwargs, grad=True)[1][0]

In [173]:
x0 = 1.
scipy.optimize.approx_fprime(np.array([1.]), model, 1e-5), grad(x0)

[ -2.78796639  18.13924131 253.9270348 ]
[0.81950516 0.19239234 0.03317259] [2.05850389 0.48326772 0.08332579] [0.96800718 0.37562137 0.07815277]
[-22.10512494  -0.83222317  -0.91466445]
[ -2.78796471  18.13923751 253.92702653]
[0.81950516 0.19239234 0.03317259] [2.05850389 0.48326772 0.08332579] [0.96800719 0.37562153 0.07815278]
[-22.10513583  -0.83222316  -0.91466447]
[ -2.78796639  18.13924131 253.9270348 ]
[0.81950516 0.19239234 0.03317259] [2.05850389 0.48326772 0.08332579] [0.96800718 0.37562137 0.07815277]
[-22.10512494  -0.83222317  -0.91466445]


(array([0.00509683]), 0.0048680988406353595)

In [154]:
args

(array([1.06975907, 0.01723011, 0.11352768]),
 array([0.06528505, 0.34330224, 0.05635067]),
 array([0.98141136, 0.97770244, 0.92938784]),
 array([17.43224154, 20.57906107, 24.39610278]),
 array([-2.72899214, -1.06914405, -2.87616112]))