In [None]:
import os.path
import os
import sys
import numpy as np
import scipy.spatial
import scipy.stats
import time
from itertools import groupby

from sklearn import mixture

import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle

from astroquery.gaia import Gaia
from astroquery.simbad import Simbad
from astropy.io.votable import parse_single_table
from astropy.io import ascii
from astropy.table import Table, vstack, unique
from astropy.coordinates import SkyCoord, Galactic
from astropy import units

In [None]:
# Global figure styling: inward ticks mirrored on the top/right spines, Times/STIX fonts
plt.rcParams.update({
    'xtick.direction': 'in',
    'xtick.top': True,
    'ytick.direction': 'in',
    'ytick.right': True,
    'font.family': "Times New Roman",
    'font.size': 12,
    'mathtext.fontset': "stix",
})

# enable astropy's matplotlib support for Quantity objects
from astropy.visualization import quantity_support
quantity_support()

First we need to import the Gaia data (DR2 or eDR3; either the full radial-velocity sample or, for speed, a smaller 80 pc sample centred on Sol), set up the coordinates, and add Sol (the Sun) to the table.

In [None]:
# What is our source? DR2 or eDR3? For the Solar neighbourhood (80 pc) use plain 'DR2'/'eDR3',
# as they're (much) quicker to load; the '*_all' variants are the full radial-velocity samples.
#source_cat = 'eDR3'
#source_cat = 'DR2'
source_cat = 'eDR3_all'
#source_cat = 'DR2_all'

# Validate the choice before doing any (slow) loading.  An explicit exception is used because a
# bare assert is stripped when Python runs with -O.
if source_cat not in ('eDR3', 'DR2', 'DR2_all', 'eDR3_all'):
    print('Specify correct DR')
    raise ValueError("unknown source_cat '{}'".format(source_cat))

# columns downloaded for the full samples; the RV column and its error are appended per release
_RV_COLUMNS = ("designation, source_id, ref_epoch, ra, ra_error, dec, dec_error, "
               "parallax, parallax_error, pmra, pmra_error, pmdec, pmdec_error, ra_dec_corr, "
               "ra_parallax_corr, ra_pmra_corr, ra_pmdec_corr, dec_parallax_corr, dec_pmra_corr, "
               "dec_pmdec_corr, parallax_pmra_corr, parallax_pmdec_corr, pmra_pmdec_corr, "
               "astrometric_gof_al, astrometric_excess_noise, astrometric_excess_noise_sig, "
               "phot_g_mean_flux, phot_g_mean_flux_error, phot_g_mean_mag, phot_bp_mean_flux, "
               "phot_bp_mean_flux_error, phot_bp_mean_mag, phot_rp_mean_flux, "
               "phot_rp_mean_flux_error, phot_rp_mean_mag, ")


def _ensure_downloaded(filename, query):
    """Run ``query`` on the Gaia archive and dump it to ``filename`` unless that file already exists.

    Returns True if a download was performed, False if the cached file was found.
    """
    if os.path.exists(filename):
        return False
    job = Gaia.launch_job_async(query, dump_to_file=True, output_format="votable",
                                output_file=filename)
    job.get_results()
    return True


def _load_80pc_rv_sample(filename, table, rv_col):
    """Load (downloading if necessary) all sources of ``table`` with parallax > 12.5 mas and an RV."""
    query = "select * from " + table + " where parallax > 12.5 and " + rv_col + " IS NOT NULL"
    _ensure_downloaded(filename, query)
    return parse_single_table(filename).to_table()


def _load_full_rv_sample(prefix, table, rv_col):
    """Load (downloading if necessary) every source of ``table`` with an RV, in six 60-degree RA slices."""
    chunks = []
    for i in range(6):
        filename = prefix + "{:1d}".format(i) + ".vot"
        query = ("select " + _RV_COLUMNS + rv_col + ", " + rv_col + "_error from " + table +
                 " where " + rv_col + " IS NOT NULL and ra >= " + str(60*i) +
                 " and ra < " + str(60*(i+1)))
        if _ensure_downloaded(filename, query):
            print("Downloaded table "+"{:1d}".format(i+1)+" of 6")
        chunks.append(parse_single_table(filename).to_table())
        print("Read table "+"{:1d}".format(i+1)+" of 6")
    # stack once at the end instead of re-copying the growing table on every iteration
    return vstack(chunks)


# get 80 pc RV sample
if source_cat == 'eDR3':
    data_all = _load_80pc_rv_sample("eDR3_RV_80pc.vot", "gaiaedr3.gaia_source", "dr2_radial_velocity")
elif source_cat == 'DR2':
    data_all = _load_80pc_rv_sample("DR2_RV_80pc.vot", "gaiadr2.gaia_source", "radial_velocity")
# get full RV sample
elif source_cat == 'DR2_all':
    data_all = _load_full_rv_sample("DR2_RV_all_", "gaiadr2.gaia_source", "radial_velocity")
else:  # 'eDR3_all'
    data_all = _load_full_rv_sample("eDR3_RV_all_", "gaiaedr3.gaia_source", "dr2_radial_velocity")

# The eDR3 table names its RV columns dr2_radial_velocity[_error]; use the DR2 names for every
# source so the rest of the notebook does not depend on the release.
for old_name, new_name in (('dr2_radial_velocity', 'radial_velocity'),
                           ('dr2_radial_velocity_error', 'radial_velocity_error')):
    if old_name in data_all.colnames:
        data_all.rename_column(old_name, new_name)


data_all.add_column(data_all['phot_bp_mean_mag'] - data_all['phot_rp_mean_mag'],name='BP_RP')      #colour BP_RP
data_all.add_column(data_all['phot_g_mean_mag']+5*np.log10(data_all['parallax']/100),name='M_G')   #absolute M_G

# tidy up some units
data_all['pmra'].unit = units.mas/units.yr
data_all['pmdec'].unit = units.mas/units.yr
data_all['radial_velocity'].unit = units.km/units.s

N_stars_all = len(data_all)

print('{:d} stars read'.format(N_stars_all))

In [None]:
# Keep only rows with a strictly positive parallax and refresh the global star count
data_all = data_all[data_all['parallax'] > 0]
N_stars_all = len(data_all)
print('{:d} stars with positive parallax'.format(N_stars_all))

In [None]:
# Build ICRS coordinates (distance = 1000/parallax pc) with proper motions and radial velocity,
# transform them to the Galactic frame and switch to a cartesian representation
coord = SkyCoord(data_all['ra'],data_all['dec'],distance=1000*units.pc/np.array(data_all['parallax']),
                 pm_ra_cosdec=data_all['pmra'],pm_dec=data_all['pmdec'],
                 radial_velocity=data_all['radial_velocity'],
                 frame='icrs').transform_to(Galactic)
coord.representation_type = 'cartesian'

# stuff for Mahalanobis distance
# (u,v,w) are the cartesian position components, (U,V,W) the cartesian velocity components
data_all.add_column(coord.u,name='u')
data_all.add_column(coord.v,name='v')
data_all.add_column(coord.w,name='w')
data_all.add_column(coord.U,name='U')
data_all.add_column(coord.V,name='V')
data_all.add_column(coord.W,name='W')
data_all.add_column(1000/data_all['parallax']*units.pc,name='d_Sol')

# add Sol at end of table (index -1)
# only the position, velocity, distance and designation columns are given values here
data_all.add_row({'u':0*units.pc,'v':0*units.pc,'w':0*units.pc,'d_Sol':0*units.pc,
                  'U':0*units.km/units.s,'V':0*units.km/units.s,'W':0*units.km/units.s,
                  'designation':'Sol'})

In [None]:
# colour and magnitude for Sol, from Casagrande+18 (DR2):
# (Sol is the last row of the table, hence index -1)
data_all[-1]['M_G'] = 4.67
data_all[-1]['BP_RP'] = 0.82

#Solar motion wrt LSR, from Schönrich+10:
U_Sol = 11.1 * units.km/units.s
V_Sol = 12.14 * units.km/units.s
W_Sol = 7.25 * units.km/units.s

# Shift the velocity components of every row (including Sol's) by the Solar motion.
# NOTE(review): the in-place += is not idempotent - re-running this cell applies the shift again.
data_all['U'] += U_Sol.value #???
data_all['V'] += V_Sol.value
data_all['W'] += W_Sol.value

In [None]:
#coord
# position is (u,v,w)
# velocity is (U,V,W)
# this won't be confusing at all...

Now we set up our target list, check if the target is in Gaia DRn, and make sure we handle Sol properly

In [None]:
# some global variables
# (read as module-level globals by the Target methods below)
d_query = 80 * units.pc #radius of sphere to query
N_thresh = 20 # k = N_thresh for k-NN calculation
rho_thr = 50 # if rescaled rho above this, cut from Gaussian mixture model and class as high-rho
N_models = 10 # number models for GMM
v_factor = 1.25 # check stars with v \in (v_target/v_factor,v_target*v_factor)
N_stars_min = 100 # min number of neighbours within d_query

#rough thin disc, thick disc, halo boundaries from Bensby+14
# presumably total velocities in km/s - TODO confirm against their use further down
v_thin = 50
v_thick_min = 70
v_thick_max = 180
v_halo = 200

In [None]:
# Make sure the table has a 'd_target' column (initialised to zero pc);
# Target.get_neighbours overwrites it with the distance to the current target
if not 'd_target' in data_all.colnames:
    data_all.add_column(np.zeros(len(data_all))*units.pc,name='d_target')

# magnitude-limited copies of the table: rows with M_G <= 9 and M_G <= 8
data_M_G_9 = data_all[np.where(data_all['M_G'] <= 9.0)]
data_M_G_8 = data_all[np.where(data_all['M_G'] <= 8.0)]

In [None]:
class Target:
    
    def __init__(self,name_short,gaia_id):
        """Create a target.

        Parameters
        ----------
        name_short : str
            Short name used in messages, plot labels and output paths.
        gaia_id : str or None
            Value matched against the 'designation' column; None if the target has no Gaia id.
        """

        self.gaia_id = gaia_id
        self.name_short = name_short
        self.N_sample = 0
        self.P_1comp = (np.nan,np.nan)
        self.P_1comp_v = (np.nan,np.nan)
        # distances() tests `self.pos_6D is None` to decide whether the 6D array still has to
        # be built, so the attribute must exist from the start
        self.pos_6D = None
        self.Cov = None
        return

    def get_neighbours(self,data_all=data_all):
        """Select all stars within d_query of the target and prepare the output location.

        Parameters
        ----------
        data_all : table, optional
            Table to search.  The default is the global table as it was bound when the class
            was defined, so pass the table explicitly if the global has been rebuilt since.

        Sets self.target (the target's row), self.data (rows within d_query), self.i_target
        (index of the target inside self.data), self.N_stars and self.folder (output file
        prefix); also writes the distance to the target into data_all['d_target'].
        If the target has no Gaia id or is not found, self.data is None and self.N_stars is 0.
        """

        # a 6D array cached by get_pos_6D belongs to an earlier selection; force a rebuild
        self.pos_6D = None
        self.Cov = None

        if self.gaia_id is None:
            print(self.name_short+": no Gaia id")
            self.data = None
            self.N_stars = 0
            return

        self.target = data_all[data_all['designation'] == self.gaia_id]

        if len(self.target) == 0:
            print(self.name_short+" not found")
            self.data = None
            self.N_stars = 0
            return

        # results/<name>/<source_cat>/ holds the files, each prefixed <name>_<source_cat>_
        name = self.name_short.replace(' ','')
        out_dir = 'results/' + name + '/' + source_cat + '/'
        os.makedirs(out_dir,exist_ok=True)
        self.folder = out_dir + name + '_' + source_cat + '_'

        self.u = self.target['u']
        self.v = self.target['v']
        self.w = self.target['w']

        d = np.sqrt((self.u-data_all['u'])**2 + 
                    (self.v-data_all['v'])**2 + 
                    (self.w-data_all['w'])**2) * units.pc

        if 'd_target' in data_all.colnames:
            data_all['d_target'] = d
        else:
            data_all.add_column(d,name='d_target')

        self.data = data_all[d <= d_query]
        self.i_target = np.where(self.gaia_id == self.data['designation'])[0][0]
        self.N_stars = len(self.data)
        # report against the table actually searched, which may be a subset of the full catalogue
        print(self.name_short+":     Sample: "+str(self.N_stars)+" of "+str(len(data_all))+" stars")

        return

    def distance_histograms(self):
        """Save histograms of the sample's distance to Sol and to the target.

        The left panel marks the target's own distance to Sol with a dashed line.
        Output: <folder>distance_histograms.pdf
        """
        fig, axes = plt.subplots(1,2,figsize=[10,4])
        ax_sol, ax_target = axes

        ax_sol.hist(self.data['d_Sol'])
        # dashed vertical marker spanning the current y-range at the target's distance
        marker_x = [self.target['d_Sol']]*2
        ax_sol.plot(marker_x,ax_sol.get_ylim(),'--k',label=self.name_short)
        ax_sol.set_xlabel('distance to Sol [pc]')
        ax_sol.set_ylabel('# stars')
        ax_sol.legend()

        ax_target.hist(self.data['d_target'])
        ax_target.set_xlabel('distance to '+self.name_short+' [pc]')
        ax_target.set_ylabel('# stars')

        fig.savefig(self.folder+'distance_histograms.pdf')
        plt.close(fig)

        return

    def distance_histograms_fine(self,M_G_lim=None):
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2,2,figsize=[10,10])
        if self.target['parallax'] > 0:
            min_bin = np.max([0,self.target['d_Sol']-d_query.value])
            max_bin = self.target['d_Sol']+d_query.value
        else:
            min_bin = 0.
            max_bin = d_query.value
        n_bins = int(np.floor(max_bin-min_bin)+1)
        d = np.array(np.linspace(min_bin,max_bin,n_bins)).reshape(n_bins)
        V = np.zeros(n_bins-1)
        V_Sol = np.zeros(n_bins-1)

#don't calculate first bin to avoid some numerical artefacts
        for i in range(n_bins-2):
            V[i+1] = (4*np.pi/3)*d[i+2]**3 - (4*np.pi/3)*d[i+1]**3
            if (d_query.value > self.target['d_Sol']) and (d[i] < d_query.value - self.target['d_Sol']):
                V_Sol[i+1] = 4*np.pi*d[i+1]**2*(d[i+2]-d[i+1])
            else:
                V_Sol[i+1] = 2*np.pi*d[i+1]**2 * (1 - (d[i+1]**2+self.target['d_Sol']**2-d_query.value**2)/
                                                  (2*d[i+1]*self.target['d_Sol']))*(d[i+2]-d[i+1])
        
        if M_G_lim is None:
            index = [True]*len(self.data)
            filesuf = ''
        else:
            index = np.where(self.data['M_G'] <= M_G_lim)
            filesuf = '_MGlim'+'{:04.1f}'.format(M_G_lim)
            
        N = ax1.hist(self.data['d_Sol'][index],bins=d)
        ax1.plot([self.target['d_Sol']]*2,ax1.get_ylim(),'--k',label=self.name_short)
        ax1.set_xlabel('distance to Sol [pc]')
        ax1.set_ylabel('# stars')
        ax1.legend()

        ax2.plot(d[1:],N[0]/V_Sol)
        ax2.plot([self.target['d_Sol']]*2,ax2.get_ylim(),'--k',label=self.name_short)
        ax2.set_xlabel('distance to Sol [pc]')
        ax2.set_ylabel('stellar density [pc^-3]')

        n_bins = int(np.floor(d_query.value)+1)
        d = np.linspace(0,80,n_bins)
        V = np.zeros(n_bins-1)
        for i in range(n_bins-1):
            V[i] = (4*np.pi/3)*d[i+1]**3 - (4*np.pi/3)*d[i]**3

        N = ax3.hist(self.data['d_target'][index],bins=d)
        ax3.set_xlabel('distance to '+self.name_short+' [pc]')
        ax3.set_ylabel('# stars')

        ax4.plot(d[1:],N[0]/V)
        ax4.set_xlabel('distance to '+self.name_short+' [pc]')
        ax4.set_ylabel('stellar density [pc$^{-3}$]')
        ax4.set_yscale('log')

        plt.savefig(self.folder+'distance_histograms_fine'+filesuf+'.pdf')
        plt.close()
        
        return

    def magnitude_histograms(self):
        """Save a histogram of apparent G magnitude with the target marked by a dashed line.

        Output: <folder>magnitude_histograms.pdf
        """
        fig, ax = plt.subplots(figsize=[5,4])

        ax.hist(self.data['phot_g_mean_mag'])
        target_mag = [self.target['phot_g_mean_mag']]*2
        ax.plot(target_mag,ax.get_ylim(),'--k',label=self.name_short)
        ax.set_xlabel('G magnitude')
        ax.set_ylabel('# stars')
        ax.legend()

        fig.savefig(self.folder+'magnitude_histograms.pdf')
        plt.close(fig)

        return

    def parallax_error_histograms(self):
        """Save three histograms of parallax/parallax_error at decreasing ranges.

        Each panel marks the target's value with a dashed line; only the first carries a legend.
        Output: <folder>parallax_error_histograms.pdf
        """
        fig, axes = plt.subplots(1,3,figsize=[15,4])

        sample_ratio = self.data['parallax']/self.data['parallax_error']
        target_ratio = self.target['parallax']/self.target['parallax_error']

        # (bins, x-limits) per panel; None keeps matplotlib's default
        panel_setup = ((None,None),
                       (np.linspace(0,1000,11),[0,600]),
                       (np.linspace(0,100,11),[0,60]))

        for panel, (ax, (bins, xlim)) in enumerate(zip(axes,panel_setup)):
            ax.hist(sample_ratio,bins=bins)
            ax.plot([target_ratio]*2,ax.get_ylim(),'--k',label=self.name_short)
            ax.set_xlabel('parallax over error')
            ax.set_ylabel('# stars')
            if panel == 0:
                ax.legend()
            if xlim is not None:
                ax.set_xlim(xlim)

        fig.savefig(self.folder+'parallax_error_histograms.pdf')
        plt.close(fig)

        return

    def distance_Gmag(self):
        """Save a scatter plot of apparent G magnitude against distance to Sol; target shown as a star.

        Output: <folder>distance_Gmag.pdf
        """
        fig, ax = plt.subplots(figsize=[5,4])

        ax.scatter(self.data['d_Sol'],self.data['phot_g_mean_mag'],alpha=0.1)
        ax.scatter(self.target['d_Sol'],self.target['phot_g_mean_mag'],
                   c='k',marker='*',label=self.name_short)
        ax.set_xlabel('distance to Sol [pc]')
        ax.set_ylabel('G mag')
        ax.legend()

        fig.savefig(self.folder+'distance_Gmag.pdf')
        plt.close(fig)

        return

    def distance_M_G(self):
        """Save a scatter plot of absolute magnitude M_G against distance to Sol; target shown as a star.

        Output: <folder>distance_M_G.pdf
        """
        fig, ax = plt.subplots(figsize=[5,4])

        ax.scatter(self.data['d_Sol'],self.data['M_G'],alpha=0.1)
        ax.scatter(self.target['d_Sol'],self.target['M_G'],
                   c='k',marker='*',label=self.name_short)
        ax.set_xlabel('distance to Sol [pc]')
        ax.set_ylabel('absolute $M_G$ [mag]')
        ax.legend()

        fig.savefig(self.folder+'distance_M_G.pdf')
        plt.close(fig)

        return

    def CMD(self):
        """Save a colour-magnitude diagram (BP-RP vs M_G, bright at the top); target shown as a star.

        Output: <folder>CMD.pdf
        """
        fig, ax = plt.subplots(figsize=[5,4])

        ax.scatter(self.data['BP_RP'],self.data['M_G'],alpha=0.1)
        ax.scatter(self.target['BP_RP'],self.target['M_G'],
                   c='k',marker='*',label=self.name_short)

        ax.set_xlabel('BP-RP')
        ax.set_ylabel('absolute $M_G$ [mag]')
        # magnitudes: smaller is brighter, so flip the axis
        ax.invert_yaxis()

        ax.legend()

        fig.savefig(self.folder+'CMD.pdf')
        plt.close(fig)

        return

    def CMD_hist(self):
        """Save the colour-magnitude diagram as a log-scaled 100x100 2D histogram; target shown as a star.

        Rows where BP_RP or M_G is NaN are left out of the histogram.
        Output: <folder>CMD_hist.pdf
        """
        fig, ax = plt.subplots(figsize=[5,4])

        colour = self.data['BP_RP']
        abs_mag = self.data['M_G']
        usable = np.logical_and(~ np.isnan(colour),~ np.isnan(abs_mag))

        counts, xedge, yedge, pcm = ax.hist2d(colour[usable],abs_mag[usable],
                                              bins=100,norm=mpl.colors.LogNorm())
        ax.scatter(self.target['BP_RP'],self.target['M_G'],
                   c='k',marker='*',label=self.name_short)

        ax.set_xlabel('BP-RP')
        ax.set_ylabel('absolute $M_G$ [mag]')
        ax.invert_yaxis()

        ax.legend()

        fig.colorbar(pcm,label='stars per bin')

        fig.savefig(self.folder+'CMD_hist.pdf')
        plt.close(fig)

        return

    def total_PM_histograms(self):
        """Save three histograms of the total proper motion at decreasing ranges (eDR3 sources only).

        Nothing is plotted for the DR2 source catalogues.  The 'pm' column is used when the
        table has it; the 'eDR3_all' query does not download it, so in that case the total
        proper motion is rebuilt as sqrt(pmra^2 + pmdec^2).
        Output: <folder>total_PM_histograms.pdf
        """
        if source_cat not in ('eDR3','eDR3_all'):
            return

        if 'pm' in self.data.colnames:
            pm_total = self.data['pm']
        else:
            pm_total = np.hypot(np.asarray(self.data['pmra'],dtype=float),
                                np.asarray(self.data['pmdec'],dtype=float))

        fig, axes = plt.subplots(1,3,figsize=[15,4])

        # (bins, x-limits) per panel; None keeps matplotlib's default
        panel_setup = ((None,None),
                       (np.linspace(0,1000,11),[0,1000]),
                       (np.linspace(0,100,11),[0,100]))

        for ax, (bins, xlim) in zip(axes,panel_setup):
            ax.hist(pm_total,bins=bins)
            ax.set_xlabel('total PM [mas/yr]')
            ax.set_ylabel('# stars')
            if xlim is not None:
                ax.set_xlim(xlim)

        plt.savefig(self.folder+'total_PM_histograms.pdf')
        plt.close()

        return

    def RV_histograms(self):
        """Save radial-velocity histograms (full range, and -100..100 km/s) with the target marked.

        Output: <folder>RV_histograms.pdf
        """
        fig, axes = plt.subplots(1,2,figsize=[10,4])
        ax_full, ax_zoom = axes

        sample_rv = self.data['radial_velocity']

        ax_full.hist(sample_rv)
        ax_full.plot([self.target['radial_velocity']]*2,ax_full.get_ylim(),'--k',
                     label=self.name_short)
        ax_full.set_xlabel('RV [km/s]')
        ax_full.set_ylabel('# stars')
        ax_full.legend()

        ax_zoom.hist(sample_rv,bins=np.linspace(-100,100,11))
        ax_zoom.plot([self.target['radial_velocity']]*2,ax_zoom.get_ylim(),'--k',
                     label=self.name_short)
        ax_zoom.set_xlabel('RV [km/s]')
        ax_zoom.set_ylabel('# stars')
        ax_zoom.set_xlim([-100,100])

        fig.savefig(self.folder+'RV_histograms.pdf')
        plt.close(fig)

        return

    def X_Y(self):
        """Save a plot of the sample's u and v position components (equal aspect); target shown as a star.

        Output: <folder>X_Y.pdf
        """
        fig, ax = plt.subplots(figsize=[5,4])

        ax.plot(self.data['u'],self.data['v'],'.',alpha=0.02)
        ax.scatter(self.target['u'],self.target['v'],
                   c='k',marker='*',label=self.name_short)
        ax.set_xlabel('X [pc]')
        ax.set_ylabel('Y [pc]')
        ax.axis('equal')
        ax.legend()

        fig.savefig(self.folder+'X_Y.pdf')
        plt.close(fig)

        return

    def Toomre(self):
        """Save a Toomre diagram: sqrt(V^2 + W^2) against U for the sample, target shown as a star.

        Semicircles of radius 100, 200, 300 and 400 (in the plotted velocity units) are drawn
        for reference.
        Output: <folder>Toomre.pdf
        """

        plt.figure(figsize=[5,4])

        # draw each reference circle only over |x| <= radius so no sqrt of a negative is taken
        for radius in (100,200,300,400):
            x = np.linspace(-radius,radius,1001)
            plt.plot(x,np.sqrt(np.clip(radius**2-x**2,0,None)),'k')

        plt.plot(self.data['U'],np.sqrt(self.data['V']**2+self.data['W']**2),'.',alpha=0.1)
        plt.scatter(self.target['U'],np.sqrt(self.target['V']**2+self.target['W']**2),c='k',marker='*',
                    label=self.name_short,zorder=9)

        plt.xlabel('$U$ [km/s]')
        # raw string and $...$ so that the square root is typeset instead of printed literally
        plt.ylabel(r'$\sqrt{V^2+W^2}$ [km/s]')
        plt.axis('equal')
        plt.xlim(np.min(self.data['U']),np.max(self.data['U']))
        plt.legend()

        plt.savefig(self.folder+'Toomre.pdf')
        plt.close()

        return

    def get_pos_6D(self):
        """Build the (6, N) array of (u,v,w,U,V,W) for the sample and its covariance matrix.

        Results are stored in self.pos_6D and self.Cov.
        """
        component_names = ('u','v','w','U','V','W')
        self.pos_6D = np.array([self.data[name] for name in component_names])

        # np.cov treats each row as one variable, giving the 6x6 covariance
        self.Cov = np.cov(self.pos_6D)

        return

    def distances(self,i):
        """Distances from sample star ``i`` to every star in the sample (itself included).

        Parameters
        ----------
        i : int
            Index of the reference star within self.data.

        Returns
        -------
        dict
            'D_M'    : Mahalanobis distance in (u,v,w,U,V,W), using the inverse of self.Cov
            'D_phys' : Euclidean distance in (u,v,w)
            'D_vel'  : Euclidean distance in (U,V,W)
            'D_u','D_v','D_w','D_U','D_V','D_W' : per-component differences (star minus reference)
        """

        # build the 6D array on first use; getattr covers instances where it was never set
        if getattr(self,'pos_6D',None) is None:
            self.get_pos_6D()

        # invert the covariance once and let scipy compute all N distances in one call,
        # instead of inverting the 6x6 matrix again for every star in a Python loop
        inv_cov = np.linalg.inv(self.Cov)
        D_M = scipy.spatial.distance.cdist(self.pos_6D.T,self.pos_6D[:,[i]].T,
                                           metric='mahalanobis',VI=inv_cov)[:,0]

        D_u = self.data['u'] - self.data['u'][i]
        D_v = self.data['v'] - self.data['v'][i]
        D_w = self.data['w'] - self.data['w'][i]
        D_U = self.data['U'] - self.data['U'][i]
        D_V = self.data['V'] - self.data['V'][i]
        D_W = self.data['W'] - self.data['W'][i]

        D_phys = np.sqrt(D_u**2 + D_v**2 + D_w**2)
        D_vel = np.sqrt(D_U**2 + D_V**2 + D_W**2)

        dist = {'D_M':D_M,'D_phys':D_phys,'D_vel':D_vel,'D_u':D_u,'D_v':D_v,'D_w':D_w,'D_U':D_U,'D_V':D_V,'D_W':D_W}

        return dist

    def get_dist_target(self):
        """Store in self.dist_target the result of self.distances() for the target star (index self.i_target)."""
        
        self.dist_target = self.distances(self.i_target)
        
        return
    
    def D_M_histograms(self):
        """Save three histograms of the Mahalanobis distance to the target.

        Panels: default binning, 0-5 in ten linear bins, and 0.1-100 in fifty log bins on
        log-log axes.
        Output: <folder>D_M_histograms.pdf
        """
        d_M = self.dist_target['D_M']
        x_label = 'Mahalanobis distance to '+self.name_short

        fig, axes = plt.subplots(1,3,figsize=[15,4])

        # None keeps matplotlib's default binning
        bin_choices = (None,np.linspace(0,5,11),np.logspace(-1,2,51))
        for ax, bins in zip(axes,bin_choices):
            ax.hist(d_M,bins=bins)
            ax.set_xlabel(x_label)
            ax.set_ylabel('# stars')

        log_ax = axes[2]
        log_ax.set_xscale('log')
        log_ax.set_yscale('log')

        fig.savefig(self.folder+'D_M_histograms.pdf')
        plt.close(fig)

        return

    def Delta_v_histograms(self):
        """Save histograms of the velocity difference from the target (default bins, and 0-100 in ten bins).

        Output: <folder>Delta_v_histograms.pdf
        """
        delta_v = self.dist_target['D_vel']
        x_label = 'Delta V from '+self.name_short+' [km/s]'

        fig, axes = plt.subplots(1,2,figsize=(10,4))

        # None keeps matplotlib's default binning
        for ax, bins in zip(axes,(None,np.linspace(0,100,11))):
            ax.hist(delta_v,bins=bins)
            ax.set_xlabel(x_label)
            ax.set_ylabel('# stars')

        fig.savefig(self.folder+'Delta_v_histograms.pdf')
        plt.close(fig)

        return

    def D_phys_D_M(self):
        """Save scatter plots of Mahalanobis distance against physical distance to the target.

        Three panels with progressively tighter y-ranges (full, 0-5, 0-1.5).
        Output: <folder>D_phys_D_M.pdf
        """
        d_phys = self.dist_target['D_phys']
        d_M = self.dist_target['D_M']

        fig, axes = plt.subplots(1,3,figsize=[15,4])

        # (marker alpha, y-limits) per panel; None leaves the axis autoscaled
        panel_setup = ((0.03,None),(0.01,[0,5]),(0.2,[0,1.5]))

        for ax, (alpha, ylim) in zip(axes,panel_setup):
            ax.scatter(d_phys,d_M,alpha=alpha)
            ax.set_xlabel('physical distance to '+self.name_short+' [pc]')
            ax.set_ylabel('Mahalanobis distance to '+self.name_short)
            if ylim is not None:
                ax.set_ylim(ylim)

        fig.savefig(self.folder+'D_phys_D_M.pdf')
        plt.close(fig)

        return

    def Delta_v_D_M(self):
        """Save scatter plots of Mahalanobis distance against velocity difference from the target.

        Three panels with progressively tighter axis ranges.
        Output: <folder>Delta_v_D_M.pdf
        """
        delta_v = self.dist_target['D_vel']
        d_M = self.dist_target['D_M']

        fig, axes = plt.subplots(1,3,figsize=[15,4])

        # (marker alpha, y-limits, x-limits) per panel; None leaves the axis autoscaled
        panel_setup = ((0.03,None,None),
                       (0.01,[0,5],[0,150]),
                       (0.2,[0,1.5],[0,40]))

        for ax, (alpha, ylim, xlim) in zip(axes,panel_setup):
            ax.scatter(delta_v,d_M,alpha=alpha)
            ax.set_xlabel('Delta v from '+self.name_short+' [km/s]')
            ax.set_ylabel('Mahalanobis distance to '+self.name_short)
            if ylim is not None:
                ax.set_ylim(ylim)
            if xlim is not None:
                ax.set_xlim(xlim)

        fig.savefig(self.folder+'Delta_v_D_M.pdf')
        plt.close(fig)

        return

    def D_phys_Delta_v(self):
        """Save scatter plots of velocity difference against physical distance to the target.

        Two panels: full range, and velocity difference limited to 0-150.
        Output: <folder>D_phys_Delta_v.pdf
        """
        d_phys = self.dist_target['D_phys']
        delta_v = self.dist_target['D_vel']

        fig, axes = plt.subplots(1,2,figsize=[10,4])

        # (marker alpha, y-limits) per panel; None leaves the axis autoscaled
        for ax, (alpha, ylim) in zip(axes,((0.03,None),(0.01,[0,150]))):
            ax.scatter(d_phys,delta_v,alpha=alpha)
            ax.set_xlabel('physical distance to '+self.name_short+' [pc]')
            ax.set_ylabel('Delta v from '+self.name_short+' [km/s]')
            if ylim is not None:
                ax.set_ylim(ylim)

        fig.savefig(self.folder+'D_phys_Delta_v.pdf')
        plt.close(fig)

        return

    def get_close(self,dist,j,N=20,dump_to_file=False):
        """Rank the sample by Mahalanobis distance and optionally write the N nearest stars to a file.

        Parameters
        ----------
        dist : dict
            Distance arrays as returned by distances() for star ``j``.
        j : int
            Index (within self.data) of the star the distances were measured from.
        N : int, optional
            Number of neighbours to list in the file (default 20).
        dump_to_file : bool, optional
            If True, write the table to <folder><N>_closest.txt.

        Returns
        -------
        ndarray
            Indices of all sample stars sorted by increasing 'D_M' (star ``j`` included).
        """
        closest = np.argsort(dist['D_M'])

        if dump_to_file:
            # name the file after the number of rows actually written (N=20 -> '20_closest.txt')
            filename = self.folder+'{:d}_closest.txt'.format(N)

            # Remove star j by index rather than assuming it sorts first: another star at
            # exactly the same 6D position would tie with it at D_M = 0.
            neighbours = closest[closest != j][:N]

            with open(filename,'w') as f:

                print('Star: '+self.data['designation'][j],file=f)
                print('(u,v,w) =' + (' {:8.3f}'*3).format(self.data['u'][j],self.data['v'][j],self.data['w'][j]) + 
                      '  [pc]',file=f)
                print('(U,V,W) =' + (' {:8.3f}'*3).format(self.data['U'][j],self.data['V'][j],self.data['W'][j]) + 
                      '  [km/s]',file=f)
                print('\n',file=f)
                print(("{:^6s} {:^29s}" + " {:>8s}"*9).format("id","Gaia id","D_M","D_phys",
                                                              "D_u","D_v","D_w","D_vel","D_U","D_V","D_W"),file=f)
                print(("{:^6s} {:^29s}" + " {:>8s}"*9).format("","","","pc","pc","pc","pc",
                                                              "km/s","km/s","km/s","km/s"),file=f)
                print("-"*90,file=f)
                for i in neighbours:
                    print(("{:06d} {:29s}"+" {:8.3f}"*9).format(i,self.data['designation'][i],dist['D_M'][i],
                                                                dist['D_phys'][i],
                                                                dist['D_u'][i],
                                                                dist['D_v'][i],
                                                                dist['D_w'][i],
                                                                dist['D_vel'][i],
                                                                dist['D_U'][i],
                                                                dist['D_V'][i],
                                                                dist['D_W'][i]),file=f)

        return closest

    def get_close_target(self,dump_to_file=False):
        """Store in self.closest_target the get_close() ranking for the target star.

        dump_to_file is passed straight through to get_close().
        """
        
        self.closest_target = self.get_close(self.dist_target,self.i_target,dump_to_file=dump_to_file)
        
        return

    def get_lt_40pc(self):
        """Store in self.lt_40pc the indices of sample stars closer than d_query/2 to the target.

        The target itself (matched by designation) is excluded.
        """
        within_half_radius = self.dist_target['D_phys'] < d_query/2
        not_the_target = self.data['designation'] != self.target['designation']

        self.lt_40pc = np.where(np.logical_and(within_half_radius,not_the_target))[0]

        return

    def set_seed(self):
        """Create self.rng, reusing the seed stored in <folder>seed if there is one.

        If the seed file exists and holds an integer, that seed is reused and
        self.restore_from_save is set True.  Otherwise a new seed (current time in ms) is
        generated, written to the file, and self.restore_from_save is set False.  A seed file
        that is empty or not an integer is treated as missing and overwritten.
        """

        self.seed_file = self.folder+'seed'

        stored_seed = None
        if os.path.exists(self.seed_file):
            # read-only access is enough here
            with open(self.seed_file,'r') as f:
                try:
                    stored_seed = int(f.read())
                except ValueError:
                    print(self.name_short+': seed file unreadable, generating a new seed')

        if stored_seed is not None:
            self.timestamp = stored_seed
            self.restore_from_save = True
        else:
    # use timestamp in ms
            self.timestamp = int(time.time() * 1000)
            with open(self.seed_file,'w') as f:
                f.write(str(self.timestamp))
            self.restore_from_save = False

        self.rng = np.random.default_rng(self.timestamp)

        return

    def get_random_sample(self,N_max=600):
        """Draw a random comparison sample (without replacement) from self.lt_40pc using self.rng.

        Parameters
        ----------
        N_max : int, optional
            Upper limit on the sample size (default 600); all of self.lt_40pc is used when it
            holds fewer stars.

        Sets self.N_sample, self.sample (indices into self.data), self.sample_v (speed
        sqrt(U^2+V^2+W^2) of each sampled star) and self.target_v (same for the target).
        """
        self.N_sample = min([N_max,len(self.lt_40pc)])

        self.sample = self.rng.choice(self.lt_40pc,self.N_sample,replace=False)

        self.sample_v = np.sqrt(self.data['U'][self.sample]**2 + self.data['V'][self.sample]**2 +
                                self.data['W'][self.sample]**2)
        self.target_v = np.sqrt(self.target['U']**2 + self.target['V']**2 + self.target['W']**2)

        return

    def get_sample_distances(self):
        """Compute (or restore) the N_thresh-th neighbour distance and rescaled density of the sample.

        If a seed was restored and <folder>densities.txt exists, the values are read back from
        that file (first data row = target, remaining rows = sample).  Otherwise distances()
        is evaluated for every sampled star and

            d_20     = Mahalanobis distance to the N_thresh-th nearest neighbour
            rho_20   = N_thresh * d_20**(-6)
            rho_20_t = rho_20 / median(rho_20)

        are stored, together with d_20_target and rho_20_target (rescaled by the same median).
        Stars with too few neighbours get d_20 = inf.
        """

        def kth_neighbour_distance(D_M):
            # Sorted distances start with the star itself (distance 0), so index N_thresh is
            # the N_thresh-th neighbour and needs at least N_thresh+1 entries.
            if len(D_M) > N_thresh:
                # partition puts the same value at index N_thresh as a full sort would
                return np.partition(D_M,N_thresh)[N_thresh]
            return np.inf

        self.savefile = self.folder+'densities.txt'
        if self.restore_from_save and os.path.exists(self.savefile):
    #restore from folder+'densities.txt'
            tmp = ascii.read(self.savefile,format='fixed_width_no_header',data_start=6,delimiter='|',
                             names=('Gaia id','rho','d','u','v','w','U','V','W'))
            self.rho_20_target = tmp[0]['rho']
            self.rho_20_t = tmp[1:]['rho']
            self.d_20_target = tmp[0]['d']
            self.d_20 = tmp[1:]['d']
            print(self.name_short+': restored sample from save')
        else:
            print(self.name_short+': generating MC sample')
            # keep only the one number needed per star rather than all nine distance arrays
            self.d_20 = np.array([kth_neighbour_distance(self.distances(self.sample[i])['D_M'])
                                  for i in range(self.N_sample)])
            self.rho_20 = N_thresh * self.d_20**(-6)
            self.rho_20_t = self.rho_20/np.median(self.rho_20)
            self.d_20_target = kth_neighbour_distance(self.dist_target['D_M'])
            self.rho_20_target = N_thresh * self.d_20_target**(-6) / np.median(self.rho_20)


        return

    def sample_D_M_rho_histograms(self):
        """Save histograms of the sample's d_20 and rescaled rho_20, each on linear and log bins.

        The target's own value is marked by a dashed line in every panel.
        Output: <folder>sample_D_M_rho_histograms.pdf
        """

        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2,2,figsize=(10,10))

        # raw strings: '\m' is not a valid Python escape, so a plain string literal triggers an
        # invalid-escape warning (the label text itself is unchanged)
        d_label = r'$D_\mathrm{M}$ to 20th nearest neighbour'
        rho_label = r'$\rho_{20}$ [rescaled]'

        ax1.hist(self.d_20)
        lim = ax1.get_ylim()
        ax1.plot([self.d_20_target,self.d_20_target],lim,'--k',label=self.name_short+' value')
        ax1.set_xlabel(d_label)
        ax1.set_ylabel('# stars')
        ax1.legend()

        ax2.hist(self.d_20,bins=np.logspace(np.log10(np.min(self.d_20)),np.log10(np.max(self.d_20)),26))
        lim = ax2.get_ylim()
        ax2.plot([self.d_20_target,self.d_20_target],lim,'--k',label=self.name_short+' value')
        ax2.set_xscale('log')
        ax2.set_xlabel(d_label)
        ax2.set_ylabel('# stars')

        ax3.hist(self.rho_20_t)
        lim = ax3.get_ylim()
        ax3.plot([self.rho_20_target,self.rho_20_target],lim,'--k',label=self.name_short+' value')
        ax3.set_xlabel(rho_label)
        ax3.set_ylabel('# stars')
        ax3.legend()

        ax4.hist(self.rho_20_t,bins=np.logspace(np.log10(np.min(self.rho_20_t)),np.log10(np.max(self.rho_20_t)),26))
        lim = ax4.get_ylim()
        ax4.plot([self.rho_20_target,self.rho_20_target],lim,'--k',label=self.name_short+' value')
        ax4.set_xscale('log')
        ax4.set_xlabel(rho_label)
        ax4.set_ylabel('# stars')

        plt.savefig(self.folder+'sample_D_M_rho_histograms.pdf')
        plt.close()

        return

# Gaussian mixture stuff adapted from https://www.astroml.org/book_figures/chapter4/fig_GMM_1D.html

    def gauss(self):
        """Fit Gaussian mixtures to log10(self.rho_20_t) and classify the sample.

        Mixtures with 1..N_models components are fitted to a clipped subset of
        the sample (within 2 standard deviations of the mean and not above
        log10(rho_thr)). For each fitted model the pdf on a 101-point grid
        (self.x_rho), the AIC and the BIC are stored in self.pdf, self.AIC and
        self.BIC; entries for models that were not fitted stay None / NaN.

        When a two-component model was fitted, its components are ordered by
        mean and the responsibility of the higher-mean component is stored per
        star in self.P_high and for the target in self.P_target (both forced to
        1 above rho_thr). self.is_high / self.is_ind / self.is_low split the
        sample at P_high = 0.84 and 0.16, and self.P_1comp holds the
        scipy.stats.kstest result of the clipped sample against the
        one-component fit.

        When no two-component model is available, those attributes are set to
        NaN instead.
        """
        self.log_rho = np.log10(self.rho_20_t).reshape(-1,1)
        self.log_rho_target = np.log10(self.rho_20_target)

        # Training subset: drop 2-sigma outliers and everything above the threshold.
        clean = np.logical_and(np.abs(self.log_rho - np.mean(self.log_rho)) <= 2*np.std(self.log_rho),
                               self.log_rho <= np.log10(rho_thr))
        train = self.log_rho[clean].reshape(-1,1)

        models = [None] * N_models
        self.x_rho = np.linspace(np.min(self.log_rho),np.max(self.log_rho),101).reshape(-1,1)
        self.pdf = [None] * N_models
        self.AIC = np.zeros(N_models) * np.nan
        self.BIC = np.zeros(N_models) * np.nan

        # Never ask for more components than there are training points.
        self.max_comp_rho = min([N_models,np.sum(clean)])

        seed = self.timestamp%(int(2**32))
        for i in range(self.max_comp_rho):
            models[i] = mixture.GaussianMixture(n_components=i+1,random_state=seed).fit(train)
            self.pdf[i] = np.exp(models[i].score_samples(self.x_rho)).reshape(-1,1)
            self.AIC[i] = models[i].aic(train)
            self.BIC[i] = models[i].bic(train)

        # len() check: with N_models < 2 there is no models[1] to look at.
        if len(models) > 1 and models[1] is not None:
            # Column 0 = lower-mean component, column 1 = higher-mean component.
            order = np.argsort(models[1].means_[:,0])

            responsibilities_smooth = (models[1].predict_proba(self.x_rho.reshape(-1, 1)))[:,order]
            self.pdf_individual = responsibilities_smooth * self.pdf[1]
            responsibilities_data = (models[1].predict_proba(self.log_rho.reshape(-1, 1)))[:,order]
            responsibilities_target = (models[1].predict_proba(self.log_rho_target.reshape(-1, 1)))[:,order]

            self.P_high = np.array(responsibilities_data[:,1]/(responsibilities_data[:,0]+responsibilities_data[:,1]))
            self.P_high[self.rho_20_t > rho_thr] = 1 # above the threshold: automatically in the high population
            self.P_target = np.array(responsibilities_target[:,1]/(responsibilities_target[:,0]+
                                                                   responsibilities_target[:,1]))
            if self.rho_20_target > rho_thr:
                self.P_target = 1.0

            self.is_high = self.P_high > 0.84
            self.is_low = self.P_high < 0.16
            self.is_ind = np.logical_and(self.P_high <= 0.84,self.P_high >= 0.16)

            # scipy's 'norm' takes (loc, scale) with scale = standard deviation,
            # whereas covariances_ holds the variance, hence the square root.
            self.P_1comp = scipy.stats.kstest(train[:,0],'norm',
                                              args=(models[0].means_[0,0],
                                                    np.sqrt(models[0].covariances_[0,0,0])))
        else:
            self.P_high = np.nan
            self.P_target = np.nan
            self.is_high = np.nan
            self.is_ind = np.nan
            self.is_low = np.nan
            self.P_1comp = np.nan


        return

    def plot_gaussian_mixture(self,N_comps_to_plot=4):
        """Save a three-panel summary of the density Gaussian-mixture fits.

        Panel 1: pdfs of the first N_comps_to_plot fitted models over the
        histogram of self.log_rho, labelled with AIC/BIC differences from the
        best fitted model. Panel 2: the one- and two-component pdfs plus the
        two individual components. Panel 3: self.P_high against self.log_rho.
        The target's log density is marked with a dashed vertical line.

        Nothing is drawn when self.P_target is NaN, which is what gauss()
        stores when no two-component model could be fitted.

        Parameters
        ----------
        N_comps_to_plot : int or None
            Maximum number of models shown in panel 1; None means 4.
        """
        if N_comps_to_plot is None:
            N_comps_to_plot = 4

        # `x != np.nan` is True for every x (including NaN), so NaN has to be
        # detected with np.isnan. P_target may be a float or a 1-element array.
        if np.all(np.isnan(self.P_target)):
            return

        fig, (ax1, ax2, ax3) = plt.subplots(1,3,figsize=[15,5])

        hist_bins = np.linspace(np.min(self.log_rho),np.max(self.log_rho),26)
        # AIC/BIC are NaN for models that were not fitted; ignore those entries.
        best_aic = np.nanmin(self.AIC)
        best_bic = np.nanmin(self.BIC)

        for i in range(min([self.max_comp_rho,N_comps_to_plot])):
            label = 'N={:2d}, ΔAIC = {:>8.2f},ΔBIC = {:>8.2f}'.format(i+1,self.AIC[i]-best_aic,
                                                                      self.BIC[i]-best_bic)
            ax1.plot(self.x_rho,self.pdf[i],label=label)
        ax1.hist(self.log_rho,density=True,color='k',alpha=0.5,bins=hist_bins)
        ax1.plot([self.log_rho_target]*2,ax1.get_ylim(),'--k')
        ax1.legend(loc='upper left',fontsize='small')
        ax1.set_xlabel(r'$\log_{10} \rho$ [rescaled]')
        ax1.set_ylabel('pdf')

        ax2.plot(self.x_rho,self.pdf[0],label='N=1')
        ax2.plot(self.x_rho,self.pdf[1],label='N=2')
        ax2.plot(self.x_rho,self.pdf_individual[:,0],'b:',label='N=2 components')
        ax2.plot(self.x_rho,self.pdf_individual[:,1],'b:')
        ax2.hist(self.log_rho,density=True,color='k',alpha=0.5,bins=hist_bins)
        ax2.plot([self.log_rho_target]*2,ax2.get_ylim(),'--k',label=self.name_short)
        ax2.set_xlabel(r'$\log_{10} \rho$ [rescaled]')
        ax2.set_ylabel('pdf')
        ax2.legend(fontsize='small')

        ax3.plot(self.log_rho,self.P_high,'o')
        ax3.plot([self.log_rho_target]*2,ax3.get_ylim(),'--k',label=self.name_short)
        ax3.set_xlabel(r'$\log_{10} \rho$ [rescaled]')
        ax3.set_ylabel(r'$P_\mathrm{high}$')

        plt.savefig(self.folder+'gaussian_mixture.pdf')
        plt.close()

        return

    def gauss_v(self):
        """Fit Gaussian mixtures to log10(self.sample_v) and classify the sample.

        Mixtures with 1..N_models components are fitted to the subset of the
        sample within 2 standard deviations of the mean. For each fitted model
        the pdf on a 101-point grid (self.x_vel), the AIC and the BIC are
        stored in self.pdf_v, self.AIC_v and self.BIC_v; entries for models
        that were not fitted stay None / NaN.

        When a two-component model was fitted, its components are ordered by
        mean and the responsibility of the higher-mean component is stored per
        star in self.P_high_v and for the target in self.P_target_v.
        self.is_high_v / self.is_ind_v / self.is_low_v split the sample at
        P_high_v = 0.84 and 0.16, and self.P_1comp_v holds the
        scipy.stats.kstest result of the clipped sample against the
        one-component fit.

        When no two-component model is available, those attributes are set to
        NaN instead.
        """
        self.log_v = np.log10(self.sample_v).reshape(-1,1)
        self.log_v_target = np.log10(self.target_v)

        # Training subset: drop 2-sigma outliers.
        clean = np.abs(self.log_v - np.mean(self.log_v)) <= 2*np.std(self.log_v)
        train = self.log_v[clean].reshape(-1,1)

        models = [None] * N_models
        self.x_vel = np.linspace(np.min(self.log_v),np.max(self.log_v),101).reshape(-1,1)
        self.pdf_v = [None] * N_models
        self.AIC_v = np.zeros(N_models) * np.nan
        self.BIC_v = np.zeros(N_models) * np.nan

        # Never ask for more components than there are training points.
        self.max_comp_v = min([N_models,np.sum(clean)])

        seed = self.timestamp%(int(2**32))
        for i in range(self.max_comp_v):
            models[i] = mixture.GaussianMixture(n_components=i+1,random_state=seed).fit(train)
            self.pdf_v[i] = np.exp(models[i].score_samples(self.x_vel)).reshape(-1,1)
            self.AIC_v[i] = models[i].aic(train)
            self.BIC_v[i] = models[i].bic(train)

        # len() check: with N_models < 2 there is no models[1] to look at.
        if len(models) > 1 and models[1] is not None:

            # Column 0 = lower-mean component, column 1 = higher-mean component.
            order = np.argsort(models[1].means_[:,0])

            responsibilities_smooth = (models[1].predict_proba(self.x_vel.reshape(-1, 1)))[:,order]
            self.pdf_individual_v = responsibilities_smooth * self.pdf_v[1]
            responsibilities_data = (models[1].predict_proba(self.log_v.reshape(-1, 1)))[:,order]
            responsibilities_target = (models[1].predict_proba(self.log_v_target.reshape(-1, 1)))[:,order]

            self.P_high_v = np.array(responsibilities_data[:,1]/(responsibilities_data[:,0]+responsibilities_data[:,1]))
            self.P_target_v = np.array(responsibilities_target[:,1]/(responsibilities_target[:,0]+responsibilities_target[:,1]))

            self.is_high_v = self.P_high_v > 0.84
            self.is_low_v = self.P_high_v < 0.16
            self.is_ind_v = np.logical_and(self.P_high_v <= 0.84,self.P_high_v >= 0.16)

            # scipy's 'norm' takes (loc, scale) with scale = standard deviation,
            # whereas covariances_ holds the variance, hence the square root.
            self.P_1comp_v = scipy.stats.kstest(train[:,0],'norm',
                                                args=(models[0].means_[0,0],
                                                      np.sqrt(models[0].covariances_[0,0,0])))

        else:
            self.P_high_v = np.nan
            self.P_target_v = np.nan
            self.is_high_v = np.nan
            self.is_ind_v = np.nan
            self.is_low_v = np.nan
            self.P_1comp_v = np.nan

        return

    def plot_gaussian_mixture_v(self,N_comps_to_plot=4):
        """Save a three-panel summary of the velocity Gaussian-mixture fits.

        Panel 1: pdfs of the first N_comps_to_plot fitted models over the
        histogram of self.log_v, labelled with AIC/BIC differences from the
        best fitted model. Panel 2: the one- and two-component pdfs plus the
        two individual components. Panel 3: self.P_high_v against self.log_v.
        The target's log speed is marked with a dashed vertical line.

        Nothing is drawn when self.P_target_v is NaN, which is what gauss_v()
        stores when no two-component model could be fitted.

        Parameters
        ----------
        N_comps_to_plot : int or None
            Maximum number of models shown in panel 1; None means 4.
        """
        if N_comps_to_plot is None:
            N_comps_to_plot = 4

        # Test the velocity result (not the density one), and use np.isnan:
        # `x != np.nan` is True for every x, including NaN.
        if np.all(np.isnan(self.P_target_v)):
            return

        fig, (ax1, ax2, ax3) = plt.subplots(1,3,figsize=[15,5])

        hist_bins = np.linspace(np.min(self.log_v),np.max(self.log_v),26)
        # AIC_v/BIC_v are NaN for models that were not fitted; ignore those entries.
        best_aic = np.nanmin(self.AIC_v)
        best_bic = np.nanmin(self.BIC_v)

        # The number of fitted velocity models is max_comp_v, not max_comp_rho.
        for i in range(min([self.max_comp_v,N_comps_to_plot])):
            label = 'N={:2d}, ΔAIC = {:>8.2f}, ΔBIC = {:>8.2f}'.format(i+1,self.AIC_v[i]-best_aic,
                                                                       self.BIC_v[i]-best_bic)
            ax1.plot(self.x_vel,self.pdf_v[i],label=label)
        ax1.hist(self.log_v,density=True,color='k',alpha=0.5,bins=hist_bins)
        ax1.plot([self.log_v_target]*2,ax1.get_ylim(),'--k')
        ax1.legend(loc='upper left',fontsize='small')
        ax1.set_xlabel(r'$\log_{10} |\mathbf{v}|$ [km/s]')
        ax1.set_ylabel('pdf')

        ax2.plot(self.x_vel,self.pdf_v[0],label='N=1')
        ax2.plot(self.x_vel,self.pdf_v[1],label='N=2')
        ax2.plot(self.x_vel,self.pdf_individual_v[:,0],'b:',label='N=2 components')
        ax2.plot(self.x_vel,self.pdf_individual_v[:,1],'b:')
        ax2.hist(self.log_v,density=True,color='k',alpha=0.5,bins=hist_bins)
        ax2.plot([self.log_v_target]*2,ax2.get_ylim(),'--k',label=self.name_short)
        ax2.set_xlabel(r'$\log_{10} |\mathbf{v}|$ [km/s]')
        ax2.set_ylabel('pdf')
        ax2.legend(fontsize='small')

        ax3.plot(self.log_v,self.P_high_v,'o')
        ax3.plot([self.log_v_target]*2,ax3.get_ylim(),'--k',label=self.name_short)
        ax3.set_xlabel(r'$\log_{10} |\mathbf{v}|$ [km/s]')
        ax3.set_ylabel(r'$P_\mathrm{high}$')

        plt.savefig(self.folder+'gaussian_mixture_v.pdf')
        plt.close()

        return

    def sample_D_phys_D_M_rho(self):
        """Save a 2x3 figure of self.d_20 and self.rho_20_t against D_phys.

        The x-axis is self.dist_target['D_phys'] restricted to self.sample.
        Top row shows self.d_20, bottom row self.rho_20_t. Within a row the
        three panels are: full range (with legend), fixed y-limits, and log
        y-axis. Points are coloured by the is_high / is_ind / is_low masks; a
        running median over 20 bin edges and a dashed line at the target's
        value are overplotted. Output goes to
        self.folder+'sample_D_phys_D_M_rho.pdf'.
        """
        dist = self.dist_target['D_phys'][self.sample]
        n_bins = 20

        edges = np.linspace(dist.min(),dist.max(), n_bins)
        step = edges[1]-edges[0]
        bin_of = np.digitize(dist,edges)

        def binned_median(values):
            # One median per digitize index 0..n_bins-1 (index 0 has no members).
            return [np.median(values[bin_of==k]) for k in range(n_bins)]

        classes = (('r',self.is_high,'P_high>0.84'),
                   ('k',self.is_ind,'0.16<=P_high<=0.84'),
                   ('b',self.is_low,'P_high<0.16'))

        def draw(ax, y, y_target, medians, ylabel, variant, zoom_ylim):
            for colour, mask, text in classes:
                ax.scatter(dist[mask],y[mask],c=colour,alpha=0.2,label=text)
            if variant == 'zoom':
                ax.set_ylim(zoom_ylim)
            elif variant == 'log':
                ax.set_yscale('log')
            if variant == 'full':
                ax.plot(edges-step/2,medians,c='k',label='running median')
                ax.plot([0,40],[y_target,y_target],'k--',label=self.name_short+' value')
            else:
                ax.plot(edges-step/2,medians,c='k')
                ax.plot([0,40],[y_target,y_target],'k--')
            ax.set_xlabel('distance to '+self.name_short+' [pc]')
            ax.set_ylabel(ylabel)
            if variant == 'full':
                ax.legend()

        medians_d = binned_median(self.d_20)

        fig, axes = plt.subplots(2,3,figsize=[15,8])

        for ax, variant in zip(axes[0], ('full','zoom','log')):
            draw(ax, self.d_20, self.d_20_target, medians_d,
                 r'$D_\mathrm{M}$ to 20th nearest neighbour', variant, [0.4,1.1])

        medians_rho = binned_median(self.rho_20_t)

        for ax, variant in zip(axes[1], ('full','zoom','log')):
            draw(ax, self.rho_20_t, self.rho_20_target, medians_rho,
                 r'$\rho_{20}$ [rescaled]', variant, [0.,10.])

        plt.savefig(self.folder+'sample_D_phys_D_M_rho.pdf')
        plt.close()

        return

    def sample_D_phys_D_M_rho_1panel(self):
        """Save a single-panel plot of self.rho_20_t against D_phys (log y-axis).

        self.get_dist_target() is called first if self.dist_target is None.
        Points are coloured by the is_high / is_ind / is_low masks; a running
        median over 20 bin edges and a dashed line at self.rho_20_target are
        overplotted. Output goes to
        self.folder+'sample_D_phys_D_M_rho_1panel.pdf'.
        """
        if self.dist_target is None:
            self.get_dist_target()
        dist = self.dist_target['D_phys'][self.sample]

        n_bins = 20

        edges = np.linspace(dist.min(),dist.max(), n_bins)
        step = edges[1]-edges[0]
        bin_of = np.digitize(dist,edges)

        ax = plt.figure(figsize=[5,4]).gca()

        rho = self.rho_20_t
        # One median per digitize index 0..n_bins-1 (index 0 has no members).
        medians = [np.median(rho[bin_of==k]) for k in range(n_bins)]

        classes = (('r',self.is_high,r'$P_\mathrm{high}>0.84$'),
                   ('k',self.is_ind,r'$0.16\leq P_\mathrm{high}\leq0.84$'),
                   ('b',self.is_low,r'$P_\mathrm{high}<0.16$'))
        for colour, mask, text in classes:
            ax.scatter(dist[mask],rho[mask],c=colour,alpha=0.2,label=text)

        ax.set_yscale('log')
        ax.plot(edges-step/2,medians,c='k',label='Running median')
        ax.plot([0,40],[self.rho_20_target,self.rho_20_target],'k--',label=self.name_short+' value')
        ax.set_xlabel('distance to '+self.name_short+' [pc]')
        ax.set_ylabel(r'$\rho_{20}$ [rescaled]')
        ax.legend()

        plt.savefig(self.folder+'sample_D_phys_D_M_rho_1panel.pdf')
        plt.close()

        return

    def sample_Delta_v_D_M_rho(self):
        """Save a 2x3 figure of self.d_20 and self.rho_20_t against D_vel.

        The x-axis is self.dist_target['D_vel'] restricted to self.sample.
        Top row shows self.d_20, bottom row self.rho_20_t. Within a row the
        three panels are: full range (with legend), fixed x/y limits, and
        log-log axes. Points are coloured by the is_high / is_ind / is_low
        masks; a running median over 40 bin edges and a dashed line at the
        target's value are overplotted. Output goes to
        self.folder+'sample_Delta_v_D_M_rho.pdf'.
        """
        dv = self.dist_target['D_vel'][self.sample]
        n_bins = 40

        edges = np.linspace(dv.min(),dv.max(), n_bins)
        step = edges[1]-edges[0]
        bin_of = np.digitize(dv,edges)

        def binned_median(values):
            # One median per digitize index 0..n_bins-1 (index 0 has no members).
            return [np.median(values[bin_of==k]) for k in range(n_bins)]

        classes = (('r',self.is_high,'P_high>0.84'),
                   ('k',self.is_ind,'0.16<=P_high<=0.84'),
                   ('b',self.is_low,'P_high<0.16'))

        def draw(ax, y, y_target, medians, ylabel, variant, zoom_ylim):
            for colour, mask, text in classes:
                ax.scatter(dv[mask],y[mask],c=colour,alpha=0.2,label=text)
            if variant == 'zoom':
                ax.set_xlim([0,100])
                ax.set_ylim(zoom_ylim)
            elif variant == 'log':
                ax.set_xscale('log')
                ax.set_yscale('log')
            if variant == 'full':
                ax.plot(edges-step/2,medians,c='k',label='running median')
                ax.plot([0,300],[y_target,y_target],'k--',label=self.name_short+' value')
            elif variant == 'zoom':
                ax.plot(edges-step/2,medians,c='k')
                ax.plot([0,100],[y_target,y_target],'k--')
            else:
                ax.plot(edges-step/2,medians,c='k')
                ax.plot([0,300],[y_target,y_target],'k--')
            ax.set_xlabel(r'$|\Delta v|$ from '+self.name_short+' [km/s]')
            ax.set_ylabel(ylabel)
            if variant == 'full':
                ax.legend()

        medians_d = binned_median(self.d_20)

        fig, axes = plt.subplots(2,3,figsize=[15,8])

        for ax, variant in zip(axes[0], ('full','zoom','log')):
            draw(ax, self.d_20, self.d_20_target, medians_d,
                 r'$D_\mathrm{M}$ to 20th nearest neighbour', variant, [0.4,1.1])

        medians_rho = binned_median(self.rho_20_t)

        for ax, variant in zip(axes[1], ('full','zoom','log')):
            draw(ax, self.rho_20_t, self.rho_20_target, medians_rho,
                 r'$\rho_{20}$ [rescaled]', variant, [0,10])

        plt.savefig(self.folder+'sample_Delta_v_D_M_rho.pdf')
        plt.close()

        return

    def sample_abs_v_D_M_rho(self):
        """Save a 2x3 figure of self.d_20 and self.rho_20_t against self.sample_v.

        Top row shows self.d_20, bottom row self.rho_20_t. Within a row the
        three panels are: full range (with legend), fixed x/y limits, and
        log-log axes. Points are coloured by the is_high / is_ind / is_low
        masks and the target is drawn as a star coloured by self.P_target. A
        running median over 40 bin edges and a dashed line at the target's
        value are overplotted. Green bands delimited by the module-level
        v_thin, v_thick_min, v_thick_max and v_halo are drawn behind the data;
        the fixed-limit panels have no halo band and their thick band runs to
        the right axis edge. Output goes to
        self.folder+'sample_abs_v_D_M_rho.pdf'.
        """
        speed = self.sample_v
        speed_target = self.target_v
        n_bins = 40

        # Colour of the target star; the three conditions are mutually exclusive.
        # NOTE(review): for a NaN P_target none of them is true and star_colour
        # is left unbound.
        p_target = self.P_target
        if p_target >= 0.16 and p_target <= 0.84:
            star_colour = 'k'
        if p_target > 0.84:
            star_colour = 'r'
        if p_target < 0.16:
            star_colour = 'b'

        edges = np.linspace(speed.min(),speed.max(), n_bins)
        step = edges[1]-edges[0]
        bin_of = np.digitize(speed,edges)

        def binned_median(values):
            # One median per digitize index 0..n_bins-1 (index 0 has no members).
            return [np.median(values[bin_of==k]) for k in range(n_bins)]

        classes = (('r',self.is_high,'P_high>0.84'),
                   ('k',self.is_ind,'0.16<=P_high<=0.84'),
                   ('b',self.is_low,'P_high<0.16'))

        def draw(ax, y, y_target, medians, ylabel, variant, zoom_ylim):
            for colour, mask, text in classes:
                ax.scatter(speed[mask],y[mask],c=colour,alpha=0.2,label=text)
            ax.scatter([speed_target],[y_target],c=star_colour,edgecolor='yellow',marker='*',
                       label=self.name_short,zorder=9)

            if variant == 'zoom':
                ax.set_xlim([0,100])
                ax.set_ylim(zoom_ylim)
            elif variant == 'log':
                ax.set_xscale('log')
                ax.set_yscale('log')

            if variant == 'full':
                ax.plot(edges-step/2,medians,c='k',label='running median')
                ax.plot([0,300],[y_target,y_target],'k--',label=self.name_short+' value')
            elif variant == 'zoom':
                ax.plot(edges-step/2,medians,c='k')
                ax.plot([0,100],[y_target,y_target],'k--')
            else:
                ax.plot(edges-step/2,medians,c='k')
                ax.plot([0,300],[y_target,y_target],'k--')
            ax.set_xlabel(r'$\sqrt{U^2 + V^2 + W^2}$ [km/s]')
            ax.set_ylabel(ylabel)

            # Background bands use the axis limits as they stand after plotting.
            x_lo, x_hi = ax.get_xlim()
            x_lo = np.max((x_lo,0))
            y_lo, y_hi = ax.get_ylim()
            bands = [Rectangle((x_lo,y_lo),v_thin-x_lo,y_hi-y_lo,facecolor='g',alpha = 0.3,zorder=-10)]
            if variant == 'zoom':
                bands.append(Rectangle((v_thick_min,y_lo),x_hi-v_thick_min,y_hi-y_lo,
                                       facecolor='g',alpha=0.2,zorder=-10))
            else:
                bands.append(Rectangle((v_thick_min,y_lo),v_thick_max-v_thick_min,y_hi-y_lo,
                                       facecolor='g',alpha=0.2,zorder=-10))
                bands.append(Rectangle((v_halo,y_lo),x_hi-v_halo,y_hi-y_lo,
                                       facecolor='g',alpha=0.1,zorder=-10))
            for band in bands:
                ax.add_patch(band)

            if variant == 'full':
                ax.legend()

        medians_d = binned_median(self.d_20)

        fig, axes = plt.subplots(2,3,figsize=[15,8])

        for ax, variant in zip(axes[0], ('full','zoom','log')):
            draw(ax, self.d_20, self.d_20_target, medians_d,
                 r'$D_\mathrm{M}$ to 20th nearest neighbour', variant, [0.4,1.1])

        medians_rho = binned_median(self.rho_20_t)

        for ax, variant in zip(axes[1], ('full','zoom','log')):
            draw(ax, self.rho_20_t, self.rho_20_target, medians_rho,
                 r'$\rho_{20}$ [rescaled]', variant, [0,10])

        plt.savefig(self.folder+'sample_abs_v_D_M_rho.pdf')
        plt.close()

        return

    def sample_abs_v_D_M_rho_1panel(self):
        """Save a single log-log panel of self.rho_20_t against self.sample_v.

        Points are coloured by the is_high / is_ind / is_low masks and the
        target is drawn as a star coloured by self.P_target. A running median
        over 40 bin edges and a dashed line at self.rho_20_target are
        overplotted. Green bands delimited by the module-level v_thin,
        v_thick_min, v_thick_max and v_halo are drawn behind the data and
        labelled 'thin disc', 'thick disc' and 'halo'. Output goes to
        self.folder+'sample_abs_v_D_M_rho_1panel.pdf'.
        """
        speed = self.sample_v
        speed_target = self.target_v
        n_bins = 40

        # Colour of the target star; the three conditions are mutually exclusive.
        # NOTE(review): for a NaN P_target none of them is true and star_colour
        # is left unbound.
        p_target = self.P_target
        if p_target >= 0.16 and p_target <= 0.84:
            star_colour = 'k'
        if p_target > 0.84:
            star_colour = 'r'
        if p_target < 0.16:
            star_colour = 'b'

        edges = np.linspace(speed.min(),speed.max(), n_bins)
        step = edges[1]-edges[0]
        bin_of = np.digitize(speed,edges)

        fig = plt.figure(figsize=[5,4])
        ax = fig.gca()

        rho = self.rho_20_t
        # One median per digitize index 0..n_bins-1 (index 0 has no members).
        medians = [np.median(rho[bin_of==k]) for k in range(n_bins)]

        classes = (('r',self.is_high,r'$P_\mathrm{high}>0.84$'),
                   ('k',self.is_ind,r'$0.16\leq P_\mathrm{high}\leq0.84$'),
                   ('b',self.is_low,r'$P_\mathrm{high}<0.16$'))
        for colour, mask, text in classes:
            ax.scatter(speed[mask],rho[mask],c=colour,alpha=0.2,label=text)
        ax.scatter([speed_target],[self.rho_20_target],c=star_colour,edgecolor='yellow',marker='*',
                   label=self.name_short,zorder=9)
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.plot(edges-step/2,medians,c='k',label='Running median')
        ax.plot([0,300],[self.rho_20_target,self.rho_20_target],'k--',label=self.name_short+' value')
        ax.set_xlabel(r'$\sqrt{U^2 + V^2 + W^2}$ [km/s]')
        ax.set_ylabel(r'$\rho_{20}$ [rescaled]')

        # Background bands use the axis limits as they stand after plotting.
        x_lo, x_hi = ax.get_xlim()
        x_lo = np.max((x_lo,0))
        y_lo, y_hi = ax.get_ylim()
        bands = (
            Rectangle((x_lo,y_lo),v_thin-x_lo,y_hi-y_lo,facecolor='g',alpha = 0.3,zorder=-10),
            Rectangle((v_thick_min,y_lo),v_thick_max-v_thick_min,y_hi-y_lo,facecolor='g',alpha=0.2,zorder=-10),
            Rectangle((v_halo,y_lo),x_hi-v_halo,y_hi-y_lo,facecolor='g',alpha=0.1,zorder=-10),
        )
        for band in bands:
            ax.add_patch(band)
        for x_text, caption in ((x_lo+5,'thin disc'),(v_thick_min,'thick disc'),(v_halo,'halo')):
            ax.text(x_text,y_hi/4,caption,color='g')

        ax.legend(fontsize='small')

        plt.savefig(self.folder+'sample_abs_v_D_M_rho_1panel.pdf')
        plt.close()

        return

    def UVW_rho(self):
        """Save a 1x3 log-log figure of self.rho_20_t against |U|, |V| and |W|.

        Each panel takes the absolute value of one of the 'U', 'V', 'W' columns
        of self.data (restricted to self.sample) as x. Points are coloured by
        the is_high / is_ind / is_low masks and the target is drawn as a star
        coloured by self.P_target. A running median over 40 bin edges and a
        dashed line at self.rho_20_target are overplotted; only the first panel
        gets a legend. Output goes to self.folder+'_sample_UVW_rho.pdf'.
        """
        fig, axes = plt.subplots(1,3,figsize=[15,4])

        # Colour of the target star; the three conditions are mutually exclusive.
        # NOTE(review): for a NaN P_target none of them is true and star_colour
        # is left unbound.
        p_target = self.P_target
        if p_target >= 0.16 and p_target <= 0.84:
            star_colour = 'k'
        if p_target > 0.84:
            star_colour = 'r'
        if p_target < 0.16:
            star_colour = 'b'

        classes = (('r',self.is_high,'P_high>0.84'),
                   ('k',self.is_ind,'0.16<=P_high<=0.84'),
                   ('b',self.is_low,'P_high<0.16'))
        n_bins = 40

        for panel, (ax, component) in enumerate(zip(axes, ('U','V','W'))):
            speed = np.abs(self.data[component][self.sample])
            speed_target = np.abs(self.target[component])
            rho = self.rho_20_t

            edges = np.linspace(speed.min(),speed.max(), n_bins)
            step = edges[1]-edges[0]
            bin_of = np.digitize(speed,edges)
            # One median per digitize index 0..n_bins-1 (index 0 has no members).
            medians = [np.median(rho[bin_of==k]) for k in range(n_bins)]

            for colour, mask, text in classes:
                ax.scatter(speed[mask],rho[mask],c=colour,alpha=0.2,label=text)
            ax.scatter([speed_target],[self.rho_20_target],c=star_colour,edgecolor='yellow',marker='*',
                       label=self.name_short,zorder=9)
            ax.plot(edges-step/2,medians,c='k',label='running median')
            ax.plot([0,300],[self.rho_20_target,self.rho_20_target],'k--',label=self.name_short+' value')
            ax.set_xlabel('|'+component+'| [km/s]')
            ax.set_ylabel('rho_20 [rescaled]')
            ax.set_xscale('log')
            ax.set_yscale('log')
            if panel == 0:
                ax.legend()

        # NOTE(review): unlike the other plots, this file name starts with an
        # underscore -- confirm that this is intended.
        plt.savefig(self.folder+'_sample_UVW_rho.pdf')
        plt.close()

        return


    def sample_position_D_M_rho(self):
        """Save a three-panel u-v position map of the comparison sample.

        Panel 1 colours the sample by log10 of ``self.d_20``, panel 2 by log10 of
        ``self.rho_20_t`` and panel 3 by the ``is_high``/``is_ind``/``is_low`` masks.
        The target is overplotted as a star on every panel. The figure is written to
        ``self.folder + 'sample_position_D_M_rho.pdf'`` and then closed.
        """
        fig, (ax_dist, ax_rho, ax_class) = plt.subplots(1,3,figsize=[15,4])

        # the colour of the target marker follows the target's own P_high class
        if self.P_target < 0.16:
            target_colour = 'b'
        elif self.P_target > 0.84:
            target_colour = 'r'
        elif 0.16 <= self.P_target <= 0.84:
            target_colour = 'k'

        u_sample = self.data['u'][self.sample]
        v_sample = self.data['v'][self.sample]
        u_target = self.target['u']
        v_target = self.target['v']

        # the two colour-mapped panels only differ in the plotted quantity and colourbar label
        colour_panels = ((ax_dist,self.d_20,self.d_20_target,'log D_M to 20th nearest neighbour'),
                         (ax_rho,self.rho_20_t,self.rho_20_target,'log rho_20 [rescaled]'))
        for ax, sample_val, target_val, cbar_label in colour_panels:
            points = ax.scatter(u_sample,v_sample,c=np.log10(sample_val),alpha=0.5)
            ax.scatter(u_target,v_target,c=np.log10(target_val),edgecolor='k',marker='*')
            ax.axis('equal')
            ax.set_xlabel('u [pc]')
            ax.set_ylabel('v [pc]')
            cbar = fig.colorbar(points,ax=ax)
            cbar.ax.set_ylabel(cbar_label)

        # third panel: one scatter per P_high class
        classes = ((self.is_high,'r','P_high>0.84'),
                   (self.is_ind,'k','0.16<=P_high<=0.84'),
                   (self.is_low,'b','P_high<0.16'))
        for mask, colour, label in classes:
            ax_class.scatter(u_sample[mask],v_sample[mask],c=colour,label=label,alpha=0.5)
        ax_class.scatter(u_target,v_target,c=target_colour,edgecolor='yellow',marker='*')
        ax_class.axis('equal')
        ax_class.set_xlabel('u [pc]')
        ax_class.set_ylabel('v [pc]')
        ax_class.legend()

        plt.savefig(self.folder+'sample_position_D_M_rho.pdf')
        plt.close()

        return

    def sample_Toomre(self):
        """Save a three-panel Toomre diagram (U against sqrt(V^2+W^2)) of the comparison sample.

        The first two panels colour the sample by log10 of ``self.d_20`` and of
        ``self.rho_20_t``; the third splits it by the ``is_low``/``is_ind``/``is_high`` masks.
        Semicircles of radius 100, 200, 300 and 400 (in axis units) are drawn on every panel
        and the target is overplotted as a star. The figure is written to
        ``self.folder + 'sample_Toomre.pdf'`` and then closed.
        """
        fig, (ax1, ax2, ax3) = plt.subplots(1,3,figsize=[15,4])

        # the colour of the target marker follows the target's own P_high class
        if self.P_target < 0.16:
            c_t = 'b'
        elif self.P_target > 0.84:
            c_t = 'r'
        else:
            # 0.16 <= P_target <= 0.84; also reached for a NaN P_target so that c_t is always defined
            c_t = 'k'

        U_s = self.data['U'][self.sample]
        VW_s = np.sqrt(self.data['V'][self.sample]**2+self.data['W'][self.sample]**2)
        U_t = self.target['U']
        VW_t = np.sqrt(self.target['V']**2+self.target['W']**2)

        # raw string: '\s' is not a valid Python escape; the closing '$' is needed for mathtext
        y_label = r'$\sqrt{V^2+W^2}$ [km/s]'

        x = np.linspace(-400,400,1001)
        # outside |x| <= radius the square root is NaN, which matplotlib simply does not draw;
        # silence the corresponding "invalid value" warning
        with np.errstate(invalid='ignore'):
            circles = [np.sqrt(radius**2-x**2) for radius in (100,200,300,400)]

        def draw_circles(ax):
            # reference curves of constant total velocity
            for arc in circles:
                ax.plot(x,arc,'k')

        def finish_panel(ax):
            # labels and limits shared by the three panels
            ax.set_xlabel('$U$ [km/s]')
            ax.set_ylabel(y_label)
            ax.set_xlim(np.min(U_s)-10,np.max(U_s)+10)
            ax.set_ylim([0,np.max(np.abs(ax.get_xlim()))+50])
            ax.set_aspect('equal')

        colour_panels = ((ax1,self.d_20,self.d_20_target,
                          r'$\log_{10} D_\mathrm{M}$ to 20th nearest neighbour'),
                         (ax2,self.rho_20_t,self.rho_20_target,
                          r'$\log_{10} \rho_{20}$ [rescaled]'))
        for ax, sample_val, target_val, cbar_label in colour_panels:
            draw_circles(ax)
            points = ax.scatter(U_s,VW_s,c=np.log10(sample_val),alpha=0.5)
            ax.scatter(U_t,VW_t,c=np.log10(target_val),edgecolor='yellow',marker='*')
            finish_panel(ax)
            cbar = plt.colorbar(points, ax=ax)
            cbar.ax.set_ylabel(cbar_label)

        draw_circles(ax3)
        classes = ((self.is_low,'b',r'$P_\mathrm{high}<0.16$'),
                   (self.is_ind,'k',r'$0.16\leq P_\mathrm{high}\leq0.84$'),
                   (self.is_high,'r',r'$P_\mathrm{high}>0.84$'))
        for mask, colour, label in classes:
            ax3.scatter(U_s[mask],VW_s[mask],c=colour,alpha=0.2,label=label)
        ax3.scatter(U_t,VW_t,c=c_t,edgecolor='yellow',marker='*')
        finish_panel(ax3)
        ax3.legend()

        plt.savefig(self.folder+'sample_Toomre.pdf')
        plt.close()

        return

    def detrend_v(self):
        """Fit and subtract the trend of log10(rho_20) with log10(|v|) over the comparison sample.

        A degree-4 polynomial is fitted to ``log10(self.rho_20_t)`` as a function of
        ``log10(self.sample_v)``, using only stars with ``rho_20_t < 50``.

        Sets
        ----
        self.fit : numpy.polynomial.Polynomial or None
            The fitted trend, or ``None`` if the fit could not be made.
        self.residuals : ndarray
            ``log10(rho_20_t)`` minus the trend for every sample star (all NaN on failure,
            with length ``self.N_sample``).
        self.residuals_t : float
            The same quantity for the target (NaN on failure).
        """
        try:
            degree = 4
            x = np.log10(self.sample_v)
            x_t = np.log10(self.target_v)

            log_rho = np.log10(self.rho_20_t)
            log_rho_t = np.log10(self.rho_20_target)

            # ignore rho_t > 50 in the fit, so not biased by clusters
            in_fit = self.rho_20_t < 50
            self.fit = np.polynomial.Polynomial.fit(x[in_fit],log_rho[in_fit],degree)
            self.residuals = log_rho - self.fit(x)

            self.residuals_t = log_rho_t - self.fit(x_t)
        except Exception:
            # best effort: a target whose trend cannot be fitted gets NaN residuals.
            # Exception (not a bare except) so that KeyboardInterrupt still stops the run,
            # and fit is reset so that no stale or missing attribute is left behind.
            self.fit = None
            self.residuals = np.full(self.N_sample,np.nan)
            self.residuals_t = np.nan

        return

    def plot_trend(self):
        """Save a plot of log10(rho_20) against log10(|v|) with the fitted trend overplotted.

        The sample is split by the ``is_high``/``is_ind``/``is_low`` masks, the target is drawn
        as a star, and three green bands delimited by the module-level velocities ``v_thin``,
        ``v_thick_min``, ``v_thick_max`` and ``v_halo`` are labelled 'thin disc', 'thick disc'
        and 'halo'. The figure is written to ``self.folder + 'trend.pdf'``.

        Nothing is plotted if ``self.fit`` is missing or ``None``.
        """
        if getattr(self,'fit',None) is None:
            print('No velocity trend available for '+str(self.name_short)+': trend plot skipped')
            return

        fig = plt.figure(figsize=[5,4])
        try:
            # the colour of the target marker follows the target's own P_high class
            if self.P_target < 0.16:
                c_t = 'b'
            elif self.P_target > 0.84:
                c_t = 'r'
            else:
                # 0.16 <= P_target <= 0.84; also reached for a NaN P_target
                c_t = 'k'

            x = np.log10(self.sample_v)
            x_t = np.log10(self.target_v)
            log_rho = np.log10(self.rho_20_t)
            log_rho_t = np.log10(self.rho_20_target)

            xfit,yfit = self.fit.linspace(101,[np.min(x),np.max(x)])

            # raw strings: '\m' and '\l' are not valid Python escape sequences
            plt.plot(xfit,yfit,label='quartic trend',c='k')
            plt.scatter(x[self.is_high],log_rho[self.is_high],c='r',alpha=0.2,label=r'$P_\mathrm{high}>0.84$')
            plt.scatter(x[self.is_ind],log_rho[self.is_ind],c='k',alpha=0.2,
                        label=r'$0.16\leq P_\mathrm{high}\leq0.84$')
            plt.scatter(x[self.is_low],log_rho[self.is_low],c='b',alpha=0.2,label=r'$P_\mathrm{high}<0.16$')
            plt.scatter(x_t,log_rho_t,c=c_t,edgecolor='yellow',marker='*',zorder=9,label=self.name_short)
            plt.xlabel(r'$\log_{10} |\mathbf{v}|$ [km/s]')
            plt.ylabel(r'$\log_{10} \rho_{20}$')

            # shade the velocity bands over the full height of the axes, behind the data
            ax1 = plt.gca()
            xmin = np.max((ax1.get_xlim()[0],0))
            xmax = ax1.get_xlim()[1]
            ymin, ymax = ax1.get_ylim()
            y_text = ymin+(ymax-ymin)*0.9
            log_thin = np.log10(v_thin)
            log_thick_min = np.log10(v_thick_min)
            log_thick_max = np.log10(v_thick_max)
            log_halo = np.log10(v_halo)
            # (left edge, width, opacity) of the thin-disc, thick-disc and halo bands
            bands = ((xmin,log_thin-xmin,0.3),
                     (log_thick_min,log_thick_max-log_thick_min,0.2),
                     (log_halo,xmax-log_halo,0.1))
            for left, width, alpha in bands:
                ax1.add_patch(Rectangle((left,ymin),width,ymax-ymin,facecolor='g',alpha=alpha,zorder=-10))
            ax1.text(xmin+(xmax-xmin)*0.1,y_text,'thin disc',color='g')
            ax1.text(log_thick_min,y_text,'thick disc',color='g')
            ax1.text(log_halo,y_text,'halo',color='g')
            ax1.legend(fontsize='small',loc='lower left')

            plt.savefig(self.folder+'trend.pdf')
        finally:
            # always release the figure, even if plotting or saving failed
            plt.close(fig)
        
    def plot_residuals(self):
        """Save a two-panel plot of the density trend and of the detrended residuals.

        Left: log10(rho_20) against log10(|v|) with the fitted trend. Right: ``self.residuals``
        against log10(|v|). Both panels split the sample by the ``is_high``/``is_ind``/``is_low``
        masks, mark the target with a star and shade three bands delimited by the module-level
        velocities ``v_thin``, ``v_thick_min``, ``v_thick_max`` and ``v_halo``.
        The figure is written to ``self.folder + 'residuals.pdf'``.

        Plotting is best-effort: a failure is reported on stdout and does not raise.
        """
        fig = None
        try:
            fig, (ax1, ax2) = plt.subplots(1,2,figsize=[10,4])

            # the colour of the target marker follows the target's own P_high class
            if self.P_target < 0.16:
                c_t = 'b'
            elif self.P_target > 0.84:
                c_t = 'r'
            else:
                # 0.16 <= P_target <= 0.84; also reached for a NaN P_target
                c_t = 'k'

            x = np.log10(self.sample_v)
            x_t = np.log10(self.target_v)
            log_rho = np.log10(self.rho_20_t)
            log_rho_t = np.log10(self.rho_20_target)

            x_label = r'$\log_{10} |\mathbf{v}|$ [km/s]'

            def shade_bands(ax):
                # shade and label the velocity bands over the current extent of ax, behind the data
                xmin = np.max((ax.get_xlim()[0],0))
                xmax = ax.get_xlim()[1]
                ymin, ymax = ax.get_ylim()
                y_text = ymin+(ymax-ymin)*0.9
                log_thin = np.log10(v_thin)
                log_thick_min = np.log10(v_thick_min)
                log_thick_max = np.log10(v_thick_max)
                log_halo = np.log10(v_halo)
                # (left edge, width, opacity) of the thin-disc, thick-disc and halo bands
                bands = ((xmin,log_thin-xmin,0.3),
                         (log_thick_min,log_thick_max-log_thick_min,0.2),
                         (log_halo,xmax-log_halo,0.1))
                for left, width, alpha in bands:
                    ax.add_patch(Rectangle((left,ymin),width,ymax-ymin,facecolor='g',alpha=alpha,
                                           zorder=-10))
                ax.text(xmin+(xmax-xmin)*0.1,y_text,'thin disc',color='g')
                ax.text(log_thick_min,y_text,'thick disc',color='g')
                ax.text(log_halo,y_text,'halo',color='g')

            xfit,yfit = self.fit.linspace(101,[np.min(x),np.max(x)])

            # raw strings: '\m' and '\l' are not valid Python escape sequences
            ax1.plot(xfit,yfit,label='quartic trend',c='k')
            ax1.scatter(x[self.is_high],log_rho[self.is_high],c='r',alpha=0.2,label=r'$P_\mathrm{high}>0.84$')
            ax1.scatter(x[self.is_ind],log_rho[self.is_ind],c='k',alpha=0.2,
                        label=r'$0.16\leq P_\mathrm{high}\leq0.84$')
            ax1.scatter(x[self.is_low],log_rho[self.is_low],c='b',alpha=0.2,label=r'$P_\mathrm{high}<0.16$')
            ax1.scatter(x_t,log_rho_t,c=c_t,edgecolor='yellow',marker='*',zorder=9,label=self.name_short)
            ax1.set_xlabel(x_label)
            ax1.set_ylabel(r'$\log_{10} \rho_{20}$')
            shade_bands(ax1)
            ax1.legend(fontsize='small')

            ax2.scatter(x[self.is_high],self.residuals[self.is_high],c='r',alpha=0.2)
            ax2.scatter(x[self.is_ind],self.residuals[self.is_ind],c='k',alpha=0.2)
            ax2.scatter(x[self.is_low],self.residuals[self.is_low],c='b',alpha=0.2)
            ax2.scatter(x_t,self.residuals_t,c=c_t,edgecolor='yellow',marker='*',label=self.name_short)
            ax2.set_xlabel(x_label)
            ax2.set_ylabel('residuals')
            shade_bands(ax2)
            ax2.legend(fontsize='small',loc='lower left')

            plt.savefig(self.folder+'residuals.pdf')

        except Exception as err:
            # keep the run going (e.g. when no trend could be fitted) but say what went wrong
            print('Residuals plot skipped for '+str(getattr(self,'name_short','target'))+': '+repr(err))
        finally:
            # release the figure whether or not plotting succeeded
            if fig is not None:
                plt.close(fig)

        return

    def get_ranks(self):
        """Rank the target against its comparison sample.

        Sets
        ----
        self.N_sim : number of sample stars whose ``sample_v`` lies strictly between
            ``target_v/v_factor`` and ``target_v*v_factor`` (``v_factor`` is module-level).
        self.rank, self.rank_sim, self.rank_detrended : 1-based rank of the target, i.e. one plus
            the number of sample stars with a larger ``rho_20_t`` (whole sample, similar-speed
            subset) or a larger residual.
        self.rank_all, self.rank_detrended_all : 0-based position of every sample star when the
            sample is ordered by decreasing ``rho_20_t`` / decreasing residual.
        """
        speed = self.sample_v
        speed_target = self.target_v

        below_upper = speed < speed_target*v_factor
        above_lower = speed > speed_target/v_factor
        similar = np.logical_and(below_upper,above_lower)
        self.N_sim = sum(similar)

        denser = self.rho_20_t > self.rho_20_target
        self.rank = sum(denser) + 1
        self.rank_sim = sum(denser[similar]) + 1
        self.rank_detrended = sum(self.residuals > self.residuals_t) + 1

        positions = np.arange(len(self.rho_20_t))
        for values, attribute in ((self.rho_20_t,'rank_all'),
                                  (self.residuals,'rank_detrended_all')):
            # invert the descending sort order: ranks[k] is where star k ends up in that order
            descending = np.argsort(values)[::-1]
            ranks = np.empty_like(descending)
            ranks[descending] = positions
            setattr(self,attribute,ranks)

        return

    def plot_ranks(self):
        """Save a scatter plot comparing the two rankings of the comparison sample.

        ``rank_all`` and ``rank_detrended_all`` are each divided by ``N_sample`` and plotted
        against one another; the figure is written to ``self.folder + 'ranks.pdf'``.
        """
        density_rank = self.rank_all/self.N_sample
        residual_rank = self.rank_detrended_all/self.N_sample

        fig, ax = plt.subplots(figsize=[5,4])
        ax.scatter(density_rank,residual_rank)
        ax.set_xlabel('density fractional rank')
        ax.set_ylabel('residuals fractional rank')
        ax.set_title('Neighbours of '+self.name_short)
        plt.savefig(self.folder+'ranks.pdf')
        plt.close()

        return

    def write_densities(self):
        """Print the target's ranks and write them, with a per-star table, to a text file.

        The three rank summaries go to stdout and to ``self.folder + 'densities.txt'``.
        The file additionally holds one row for the target followed by one row per sample
        star, giving designation, rho, D and the u, v, w, U, V, W columns of ``self.data``.
        """
        n_ranked = self.N_sample+1
        # (template, values) pairs; formatted once for stdout and once for the file
        summary = (
            ("{:<20s} ranks {:>4d} of {:>4d} stars in decreasing density",
             (self.name_short,self.rank,n_ranked)),
            ("{:<20s} ranks {:>4d} of {:>4d} stars in decreasing residuals",
             (self.name_short,self.rank_detrended,n_ranked)),
            ("{:<20s} ranks {:>4d} of {:>4d} stars with |v| within {:5f}",
             (self.name_short,self.rank_sim,self.N_sim+1,v_factor)),
        )
        for template, values in summary:
            print(template.format(*values))

        header_fmt = "{:^30s}|{:^9s}|{:^9s}"+"|{:^10s}"*6
        row_fmt = "{:<30s}|{:^9.3e}|{:^9.3e}"+"|{:>10.3e}"*6
        columns = ("u","v","w","U","V","W")

        def table_row(index,rho,dist):
            # one formatted table line for row `index` of self.data
            phase_space = [self.data[name][index] for name in columns]
            return row_fmt.format(self.data["designation"][index],rho,dist,*phase_space)

        with open(self.folder+'densities.txt','w') as f:
            print("Saving...")
            for template, values in summary:
                print(template.format(*values),file=f)
            print("\n",file=f)
            print(header_fmt.format("Gaia id","rho","D",*columns),file=f)
            print(header_fmt.format("","","","[pc]","[pc]","[pc]","[km/s]","[km/s]","[km/s]"),file=f)
            print('-'*120,file=f)
            # the target first, then every star of the comparison sample
            print(table_row(self.i_target,self.rho_20_target,self.d_20_target),file=f)
            for i in range(self.N_sample):
                print(table_row(self.sample[i],self.rho_20_t[i],self.d_20[i]),file=f)

        return
    
    def free_mem(self):
        """Drop the references to the intermediate arrays that are no longer needed.

        ``dist_target``, ``pos_6D`` and ``closest_target`` are set to ``None``.
        """
        for attribute in ('dist_target','pos_6D','closest_target'):
            setattr(self,attribute,None)

In [None]:
# run this cell if you just want Sol
#targets = ['Sol']
#dr2id = ['Sol']

In [None]:
include_Sol = True

#planets_table = ascii.read('PS_2021.02.24_07.17.03.csv')
#planets_table = ascii.read('PS_2021.03.08_07.22.24.csv')
planets_table = ascii.read('PS_2021.03.11_04.56.35.csv')

# Host names that Simbad does not resolve as written in the planet table:
#  - "Qatar-n" -> "Qatar n"
#  - Praesepe: "Prnnnn" isn't a catalogue in Simbad, and I can't find where it is published;
#    exoplanet.eu identifies Pr0201 as BD+20 2184 but gives no id for Pr0211...
#  - HIP 65 A is just HIP 65 in Simbad
hostname_fixes = [("Qatar-","Qatar "),
                  ("Pr0201","BD+20 2184"),
                  ("HIP 65 A","HIP 65")]
for old_name, new_name in hostname_fixes:
    planets_table['hostname'] = [t.replace(old_name,new_name) for t in planets_table['hostname']]

# one entry per host; sorted so that the order of targets is the same on every run
targets = sorted(set(p['hostname'] for p in planets_table))

restored_from_file = False
dr2id = []   # stays empty (and defined) when there are no targets at all
if len(targets) > 0:
    try:
        Simbad.reset_votable_fields()
        Simbad.add_votable_fields('typed_id')
        Simbad.add_votable_fields('ids')
        result = Simbad.query_objects(targets)
        dr2id = []
        for name in targets:
            # rows are matched on the name that was queried, not on their position in result
            match = [r for r in result if r['TYPED_ID'] == name]
            gaia_ids = []
            if len(match) == 1:
                gaia_ids = [s for s in match[0]['IDS'].split('|') if 'Gaia DR2' in s]
            # exactly one entry per target: missing or ambiguous identifications become None
            dr2id.append(gaia_ids[0] if len(gaia_ids) == 1 else None)
    except Exception as err:
#if Simbad is down/inaccessible, try restoring ids from an old log file
        logfile = 'log/xmatch_ids_1616414978.3437288.txt'
        print('Access Simbad failed: restoring from '+logfile)
        print('  ('+repr(err)+')')
        xmatch = ascii.read(logfile,delimiter=',',data_start=1,format='csv')
        restored_from_file = True
        dr2id = []
        dr3id = []
        for t in targets:
            ind = np.where(xmatch['target'] == t)
            if len(ind[0]) == 0:
                # target absent from the log file: no identification
                dr2id.append(None)
                dr3id.append(None)
                continue
            dr2id.append(xmatch['dr2id'][ind][0])
            dr3id.append(xmatch['dr3id'][ind][0])
        # missing ids were written to the log as the text 'None'
        dr2id = [None if d is None or d == 'None' else d for d in dr2id]
        dr3id = [None if d is None or d == 'None' else d for d in dr3id]

if include_Sol:
    if len(targets) == 0 or targets[-1] != 'Sol':
        targets.append('Sol')
        dr2id.append('Sol')

for i in range(len(targets)):
    print(targets[i],dr2id[i])


In [None]:
if not restored_from_file:
    # numeric DR2 source ids of all identified hosts: the 'Gaia DR2 ' prefix is 9 characters
    # long, and Sol (which has no Gaia id) is skipped wherever it sits in the list
    dr2table = Table([[int(d[9:]) for d in dr2id if d is not None and d != 'Sol']],names=['dr2_source_id'])

    xmatch_query = ("SELECT gaia.dr2_source_id, gaia.dr3_source_id "
                    "FROM gaiaedr3.dr2_neighbourhood AS gaia "
                    "INNER JOIN tap_upload.table "
                    "ON gaia.dr2_source_id = tap_upload.table.dr2_source_id")

    # The ids are uploaded in chunks of at most 2000 rows.
    # NOTE(review): presumably chosen because of a row limit on synchronous jobs -- confirm;
    # a DR2 id with several DR3 neighbours returns several rows.
    upload_chunk = 2000
    xmatch_parts = []
    for start in range(0,len(dr2table),upload_chunk):
        job = Gaia.launch_job(xmatch_query,
                              upload_resource=dr2table[start:start+upload_chunk],upload_table_name='table')
        xmatch_parts.append(job.get_results())

    if len(xmatch_parts) == 0:
        # nothing to look up: empty tables with the expected columns
        tmp = Table(names=['dr2_source_id','dr3_source_id'],dtype=['int64','int64'])
    else:
        tmp = xmatch_parts[0]
    # tmp holds the first chunk and tmp2 everything else (empty if one chunk was enough)
    tmp2 = vstack(xmatch_parts[1:]) if len(xmatch_parts) > 1 else tmp[:0]


In [None]:
if not restored_from_file:
    xmatch = vstack([tmp,tmp2])
    duplicates = [x for x, y in groupby(sorted(xmatch['dr2_source_id'])) if len(list(y)) > 1]

    dr2_list = list(xmatch['dr2_source_id'])
    dr3_list = list(xmatch['dr3_source_id'])

    # maximum allowed difference between the catalogue G magnitude and the planet-table value
    mag_thresh = 1.0

    # Sol, if already appended to dr2id, has no Gaia id and gets its dr3id entry further down
    n_hosts = len(dr2id)
    if n_hosts > 0 and dr2id[-1] == 'Sol':
        n_hosts -= 1
    dr3id = [None] * n_hosts

    # Locate all candidate EDR3 sources in data_all in a single pass over the catalogue,
    # instead of comparing the whole source_id column against every host in turn
    source_ids = np.asarray(data_all['source_id'])
    candidate_rows = np.where(np.isin(source_ids,np.asarray(xmatch['dr3_source_id'])))[0]
    rows_by_id = {}
    for row in candidate_rows:
        rows_by_id.setdefault(int(source_ids[row]),[]).append(int(row))

    for i in range(n_hosts):
        if dr2id[i] is None:
            continue
        match_3id = xmatch[xmatch['dr2_source_id'] == int(dr2id[i][9:])]['dr3_source_id']
        # rows of data_all holding any of the EDR3 neighbours of this DR2 source
        match = np.array(sorted({row for sid in match_3id for row in rows_by_id.get(int(sid),())}),
                         dtype=int)
        if len(match) == 0:
            continue
        # check the G mags
        this_mag = data_all['phot_g_mean_mag'][match]
        table_mag = np.mean(planets_table[planets_table['hostname'] == targets[i]]['sy_gaiamag'])
        # a missing (masked) magnitude never counts as a match
        mag_match = np.asarray(np.ma.filled(np.abs(this_mag - table_mag) <= mag_thresh,False),dtype=bool)
        n_mag_match = np.sum(mag_match)
        if n_mag_match == 1:
            # take the row that passed the magnitude test, not simply the first candidate
            dr3id[i] = 'Gaia EDR3 '+str(data_all['source_id'][match[mag_match][0]])
        elif n_mag_match > 1:
            # ambiguous: report it and leave the host without an EDR3 id
            print(i,dr2id[i],len(match_3id),len(match))
            print(this_mag,table_mag,mag_match)


if include_Sol:
    if len(dr3id) == 0 or dr3id[-1] != 'Sol':
        dr3id.append('Sol')

logdir = 'log'
os.makedirs(logdir,exist_ok=True)


# keep a record of the identifications; the Simbad cell can restore from such a file
file = logdir+'/xmatch_ids_'+str(time.time())+'.txt'
with open(file,'w') as f:
    print('target,    dr2id,    dr3id',file=f)
    for i in range(len(targets)):
        print(targets[i],',',dr2id[i],',',dr3id[i],file=f)

In [None]:
n_targets = len(targets)

stars = []

# Methods of Target that are called, in this order, on every target with enough neighbours.
# Entries that only make plots are commented out; uncomment one to run it at that point.
analysis_steps = (
    #'distance_histograms',
    #'distance_histograms_fine',
    #'magnitude_histograms',
    #'parallax_error_histograms',
    #'distance_Gmag',
    #'distance_M_G',
    #'CMD',
    #'CMD_hist',
    #'RV_histograms',
    #'X_Y',
    #'Toomre',
    'get_pos_6D',
    'get_dist_target',
    #'D_M_histograms',
    #'Delta_v_histograms',
    #'D_phys_D_M',
    #'Delta_v_D_M',
    #'D_phys_Delta_v',
    'get_close_target',
    'get_lt_40pc',
    'set_seed',
    'get_random_sample',
    'get_sample_distances',
    #'sample_D_M_rho_histograms',
    'gauss',
    #'plot_gaussian_mixture',
    'gauss_v',
    #'plot_gaussian_mixture_v',
    #'sample_D_phys_D_M_rho',
    #'sample_D_phys_D_M_rho_1panel',
    #'sample_Delta_v_D_M_rho',
    #'sample_abs_v_D_M_rho',
    #'sample_abs_v_D_M_rho_1panel',
    #'UVW_rho',
    #'sample_position_D_M_rho',
    #'sample_Toomre',
    'detrend_v',
    #'plot_residuals',
    'get_ranks',
    'write_densities',
    'free_mem',
)

for i in range(n_targets):
    print('{} of {}'.format(i,n_targets))
    target = Target(targets[i],dr3id[i])
    target.get_neighbours()
    # targets without data or with too few neighbours are stored without being analysed
    if target.data is not None and target.N_stars >= N_stars_min:
        for step in analysis_steps:
            getattr(target,step)()
    stars.append(target)

In [None]:
# based on matplotlib scatter–histogram example
def scatter_hist(x, y, ax, ax_histx, ax_histy, minx, maxx, miny, maxy, col, label, x_c_scale=1, y_c_scale=1):
    """Add one sample to a scatter plot with marginal histograms (after the matplotlib example).

    Parameters
    ----------
    x, y : array-like
        Coordinates of the points.
    ax, ax_histx, ax_histy : matplotlib axes
        Main scatter axes and the axes of the x (top) and y (right) marginals.
    minx, maxx, miny, maxy : float
        Range covered by the histogram bins (extended by one coarse bin width at the top).
    col : colour used for the points, histograms and cumulative curves.
    label : legend label of the points.
    x_c_scale, y_c_scale : float, optional
        Height to which the cumulative distributions are scaled, so that they can share an
        axis with the density histograms.

    Each marginal shows a filled and an outlined density histogram (bin width 0.05 in x and
    0.1 in y) and a cumulative distribution evaluated on 1000 fine bins.
    """
    # no labels
    ax_histx.tick_params(axis="x", labelbottom=False)
    ax_histy.tick_params(axis="y", labelleft=False)

    # the scatter plot:
    ax.plot(x, y, '.', alpha=1.0, c=col, label=label, markersize=1)

    xbinwidth = 0.05
    ybinwidth = 0.1
    xbins = np.arange(minx, maxx + xbinwidth, xbinwidth)
    ybins = np.arange(miny, maxy + ybinwidth, ybinwidth)
    xbinsfine = np.linspace(minx, maxx + xbinwidth, 1001)
    ybinsfine = np.linspace(miny, maxy + ybinwidth, 1001)

    def cumulative(values, edges, scale):
        # Cumulative distribution at the bin edges, rising from 0 to `scale`.
        # The density is multiplied by the actual widths of the fine bins, so that the
        # curve reaches `scale` when all values lie inside the binned range.
        density = np.histogram(values, edges, density=True)[0]
        return np.concatenate(([0.0], np.cumsum(density * np.diff(edges)))) * scale

    #have to generate the cumulative histograms separately in order to rescale them
    x_c = cumulative(x, xbinsfine, x_c_scale)
    y_c = cumulative(y, ybinsfine, y_c_scale)

    ax_histx.hist(x, density=True, alpha=0.5, bins=xbins, color=col)
    ax_histx.hist(x, density=True, alpha=1.0, bins=xbins, color=col, histtype='step')
    ax_histx.step(xbinsfine,x_c,color=col)
    ax_histy.hist(y, orientation='horizontal', density = True, alpha=0.5, bins=ybins, color=col)
    ax_histy.hist(y, orientation='horizontal', density = True, alpha=1.0, bins=ybins, color=col,histtype='step')
    ax_histy.step(y_c,ybinsfine,color=col)


class Sample:
    """A named group of stars together with the moments of their target velocities.

    Every element of ``stars`` must provide ``target['U']``, ``target['V']`` and
    ``target['W']``; NaN entries are ignored in all averages.
    """

    def __init__(self,name,stars,plot_col='k'):
        """Store the name, the stars and the plotting colour, then compute the moments."""
        self.name = name
        self.stars = stars
        self.plot_col = plot_col
        self.get_moments()

    def get_moments(self):
        """Compute the mean velocity and several measures of the velocity dispersion.

        Sets ``mean_u``/``mean_v``/``mean_w`` (means of U, V, W), ``D_uu``/``D_vv``/``D_ww``
        (mean squared deviations from those means), ``sigma2`` (their sum), ``sig_HS`` (mean
        modulus of the velocity relative to the mean), ``RMS`` (root mean square of the
        velocity modulus, not mean-subtracted) and ``median_abs_v`` (median modulus of the
        velocity relative to the mean).
        """
        U = [s.target['U'] for s in self.stars]
        V = [s.target['V'] for s in self.stars]
        W = [s.target['W'] for s in self.stars]

        self.mean_u = np.nanmean(U)
        self.mean_v = np.nanmean(V)
        self.mean_w = np.nanmean(W)

        # velocity components relative to the sample mean
        dU = [u - self.mean_u for u in U]
        dV = [v - self.mean_v for v in V]
        dW = [w - self.mean_w for w in W]

        self.D_uu = np.nanmean([d**2 for d in dU])
        self.D_vv = np.nanmean([d**2 for d in dV])
        self.D_ww = np.nanmean([d**2 for d in dW])
        self.sigma2 = self.D_uu + self.D_vv + self.D_ww

        speed_about_mean = [np.sqrt(du**2 + dv**2 + dw**2) for du, dv, dw in zip(dU,dV,dW)]
# Hamer & Schlaufman definition
        self.sig_HS = np.nanmean(speed_about_mean)
        self.RMS = np.sqrt(np.nanmean([u**2 + v**2 + w**2 for u, v, w in zip(U,V,W)]))
        self.median_abs_v = np.nanmedian(speed_about_mean)

In [None]:
file = logdir+'/flags'+str(time.time())+'.txt'

# Largest orbital size of a "hot" giant planet, compared below with pl_orbsmax.
# NOTE(review): presumably in au, with pl_orbper in days and st_mass in solar masses -- confirm
HJacut = 0.2


def _as_flag(value):
    """Convert a possibly masked truth value into a plain boolean; masked counts as False."""
    return np.ma.masked_array(value,dtype='bool').filled(False)


with open(file,'w') as f:
    print('name,   good_mass,   hasHJ,    hasCJ,    P_1comp,  P_1compv',file=f)
    for s in stars:

        s.planets = planets_table[planets_table['hostname'] == s.name_short]
        st_mass = s.planets['st_mass']
        st_age = s.planets['st_age']

        # A host passes if at least one table row puts it inside the range. Both bounds are
        # tested on the same row, so that rows lying on either side of the range do not pass.
        s.Ms_in_0720 = _as_flag(np.logical_and(st_mass >= 0.7,st_mass <= 2.0).any())
        s.t_in_1045 = _as_flag(np.logical_and(st_age >= 1.0,st_age <= 4.5).any())

        is_giant = s.planets['pl_bmasse'] >= 50
        a_listed = s.planets['pl_orbsmax']
        # orbit size derived from the period and the stellar mass
        a_from_period = ((s.planets['pl_orbper']/365.25)**2 * st_mass)**(1/3)

        # hot: either estimate of the orbit size is within the cut; cold: both are beyond it
        s.HJ = _as_flag(np.logical_and(is_giant,
                                       np.logical_or(a_listed <= HJacut,
                                                     a_from_period <= HJacut)).any())
        s.CJ = _as_flag(np.logical_and(is_giant,
                                       np.logical_and(a_listed > HJacut,
                                                      a_from_period > HJacut)).any())
        print(s.name_short,',',s.Ms_in_0720,',',s.HJ,',',s.CJ,',',s.P_1comp[1],',',s.P_1comp_v[1],file=f)


In [None]:
try:
    from distutils.util import strtobool
except ImportError:
    # distutils is no longer part of the standard library from Python 3.12 onwards
    def strtobool(val):
        """Return 1 for a true-like string and 0 for a false-like one; raise ValueError otherwise."""
        val = val.lower()
        if val in ('y','yes','t','true','on','1'):
            return 1
        if val in ('n','no','f','false','off','0'):
            return 0
        raise ValueError("invalid truth value {!r}".format(val))


def compare_flags(file):
    """Report every star whose flags differ from those stored in a previous flags file.

    Parameters
    ----------
    file : str
        Path of a comma-separated flags file with the columns name, good_mass, hasHJ, hasCJ,
        P_1comp and P_1compv.

    The mass, HJ and CJ flags are compared directly; the two probabilities are compared on
    which side of 0.05 they fall. Differences, and stars absent from the file, are printed.
    """
    comparison = ascii.read(file,delimiter=',',format='csv',data_start=1)

    def stored_flag(value):
        # stored flags are text such as 'True'; tolerate surrounding blanks and real booleans
        return strtobool(str(value).strip())

    for s in stars:
        match = comparison[comparison['name'] == s.name_short]
        if len(match) == 0:
            print(s.name_short+': not present in '+str(file))
            continue
        if stored_flag(match['good_mass'][0]) != s.Ms_in_0720:
            print(s.name_short+': mass flag differs')
        if stored_flag(match['hasHJ'][0]) != s.HJ:
            print(s.name_short+': HJ flag differs')
        if stored_flag(match['hasCJ'][0]) != s.CJ:
            print(s.name_short+': CJ flag differs')
        if (match['P_1comp'][0] > 0.05) != (s.P_1comp[1] > 0.05):
            print(s.name_short+': P_1comp flag differs')
        if (match['P_1compv'][0] > 0.05) != (s.P_1comp_v[1] > 0.05):
            print(s.name_short+': P_1comp_v flag differs')



comp_file = 'log/flags1616429974.3911638.txt'
compare_flags(comp_file)


In [None]:
def _members(mask,pool):
    """Return the elements of pool whose entry in mask is true."""
    return [s for s, keep in zip(pool,mask) if keep]


# quality of the comparison sample of each star
enough_sample = [s.N_sample >= 400 for s in stars]
not_1comp = [s.P_1comp[1] < 0.05 for s in stars]
sample_good = np.logical_and(enough_sample,not_1comp)

# host-star mass and age flags
Ms_good = [s.Ms_in_0720 for s in stars]
t_good = [s.t_in_1045 for s in stars]

# hot-Jupiter hosts, without and with the age cut
HJ = [s.HJ for s in stars]
HJ_true = np.logical_and.reduce([HJ,Ms_good,sample_good])
WHJ_true = np.logical_and(HJ_true,t_good)
HJs = Sample('HJs',_members(HJ_true,stars),plot_col='r')
WHJs = Sample('WHJs',_members(WHJ_true,stars),plot_col='r')

# cold-Jupiter hosts, without and with the age cut
CJ = [s.CJ for s in stars]
CJ_true = np.logical_and.reduce([CJ,Ms_good,sample_good])
WCJ_true = np.logical_and(CJ_true,t_good)
CJs = Sample('CJs',_members(CJ_true,stars),plot_col='k')
WCJs = Sample('WCJs',_members(WCJ_true,stars),plot_col='k')

# all hosts passing the mass and age cuts, excluding / including one-component samples
WAll_true = np.logical_and.reduce([Ms_good,t_good,sample_good])
WAll = Sample('WAll',_members(WAll_true,stars),plot_col='b')

WInc_true = np.logical_and.reduce([Ms_good,t_good,enough_sample])
WInc = Sample('WIcl1comp',_members(WInc_true,stars),plot_col='b')

# every star with a good sample, leaving out the last entry of stars
All = Sample('All',_members(sample_good,stars[:-1]), plot_col='g')

# one line per sample in each of the three summaries; labels are padded to a common width
_reported = (('HJs',HJs),('WHJs',WHJs),('CJs',CJs),('WCJs',WCJs),('WAll',WAll),('WInc',WInc))
_label_width = 13

for _tag, _smp in _reported:
    print((_tag+' mean V:').ljust(_label_width),_smp.mean_u,_smp.mean_v,_smp.mean_w)
print()
for _tag, _smp in _reported:
    print((_tag+' disp:').ljust(_label_width),np.sqrt(_smp.D_uu),np.sqrt(_smp.D_vv),np.sqrt(_smp.D_ww))
print()
for _tag, _smp in _reported:
    print((_tag+' sigma:').ljust(_label_width),np.sqrt(_smp.sigma2),_smp.sig_HS,_smp.RMS,_smp.median_abs_v)


In [None]:
def logv_residuals(samples,x_c_scale=1.0,y_c_scale=1.0,ksy=0.1):
    """Plot per-star residuals against log10 speed, with marginal histograms.

    Every star of every sample contributes one point: ``log10(star.target_v)``
    on the x axis and ``star.residuals_t`` on the y axis, drawn through
    ``scatter_hist`` in the sample's ``plot_col``.  When exactly two samples are
    given, two-sample KS p-values for the speeds and for the residuals are
    written onto the panel.  The speed ranges delimited by the notebook-level
    ``v_thin``, ``v_thick_min``, ``v_thick_max`` and ``v_halo`` are shaded and
    labelled 'thin disc', 'thick disc' and 'halo'.  The figure is written to
    ``<source_cat>_<sample names>_logv_residuals.pdf`` and then closed.

    Parameters
    ----------
    samples : sequence
        Objects exposing ``stars``, ``name`` and ``plot_col``; every star must
        expose ``target_v`` and ``residuals_t``.  Must not be empty.
    x_c_scale, y_c_scale : float, optional
        Forwarded unchanged to ``scatter_hist``.
    ksy : float, optional
        Height of the lower KS annotation as a fraction of the residual range;
        the upper annotation sits 0.1 higher.
    """

    def _ks_pvalue(a, b):
        # The axis limits below are NaN-aware, so NaNs can occur in the data;
        # ks_2samp does not ignore them, so drop them before testing.
        a = a[~np.isnan(a)]
        b = b[~np.isnan(b)]
        if a.size == 0 or b.size == 0:
            return np.nan
        return scipy.stats.ks_2samp(a, b)[1]

    fig = plt.figure(figsize=(5,5))

    # Scatter panel with one marginal histogram above it and one to its right.
    left, width = 0.1, 0.65
    bottom, height = 0.1, 0.65
    spacing = 0.005
    ax = fig.add_axes([left, bottom, width, height])
    ax_histx = fig.add_axes([left, bottom + height + spacing, width, 0.2], sharex=ax)
    ax_histy = fig.add_axes([left + width + spacing, bottom, 0.2, height], sharey=ax)

    # One flat array per sample for each plotted quantity.
    v = [np.array([star.target_v for star in sample.stars]).flatten() for sample in samples]
    res = [np.array([star.residuals_t for star in sample.stars]).flatten() for sample in samples]

    # Common plotting range over all samples, ignoring NaNs.
    log_v_all = np.log10(np.concatenate(v))
    res_all = np.concatenate(res)
    minx = np.nanmin(log_v_all)
    maxx = np.nanmax(log_v_all)
    miny = np.nanmin(res_all)
    maxy = np.nanmax(res_all)

    for sample, v_s, res_s in zip(samples, v, res):
        scatter_hist(np.log10(v_s), res_s, ax, ax_histx, ax_histy, minx, maxx, miny, maxy, sample.plot_col,
                     label=sample.name+' $N={:3d}$'.format(len(sample.stars)),
                     x_c_scale=x_c_scale, y_c_scale=y_c_scale)

    if len(samples) == 2:
        p_res = _ks_pvalue(res[0], res[1])
        p_v = _ks_pvalue(v[0], v[1])
        text_x = minx+(maxx-minx)*0.1
        ax.text(text_x,miny+(maxy-miny)*(ksy+0.1),r'$p_\mathrm{{ks,vel}}={:6.2e}$'.format(p_v))
        ax.text(text_x,miny+(maxy-miny)*ksy,r'$p_\mathrm{{ks,residuals}}={:6.2e}$'.format(p_res))

    ax.set_xlabel(r'$\log_{{10}} |\mathbf{v}|$ [km s$^{-1}$]')
    ax.set_ylabel('residuals')

    # Shade the three speed bands across the full height of the panel, using the
    # limits the axes ended up with after plotting (left edge clamped at log10 v = 0).
    x_lo, xmax = ax.get_xlim()
    xmin = np.max((x_lo,0))
    ymin, ymax = ax.get_ylim()
    band_height = ymax-ymin
    ax.add_patch(Rectangle((xmin,ymin),np.log10(v_thin)-xmin,band_height,facecolor='g',alpha = 0.3))
    ax.add_patch(Rectangle((np.log10(v_thick_min),ymin),np.log10(v_thick_max)-np.log10(v_thick_min),
                           band_height,facecolor='g',alpha=0.2))
    ax.add_patch(Rectangle((np.log10(v_halo),ymin),xmax-np.log10(v_halo),band_height,facecolor='g',alpha=0.1))
    label_y = ymin+band_height*0.7
    ax.text(np.log10(v_thin)-0.6,label_y,'thin disc',color='g')
    ax.text(np.log10(v_thick_min),label_y,'thick disc',color='g')
    ax.text(np.log10(v_halo),label_y,'halo',color='g')

    ax.legend()

    plot_samples = ''.join(s.name + '_' for s in samples)

    # Save and close this figure by handle rather than relying on pyplot's "current figure".
    fig.savefig(source_cat+'_'+plot_samples+'logv_residuals.pdf',bbox_inches='tight')
    plt.close(fig)

def logv_logrho(samples,x_c_scale=1.0,y_c_scale=1.0,ksy=0.1):
    """Plot per-star log10 density against log10 speed, with marginal histograms.

    Every star of every sample contributes one point: ``log10(star.target_v)``
    on the x axis and ``log10(star.rho_20_target)`` on the y axis, drawn through
    ``scatter_hist`` in the sample's ``plot_col``.  When exactly two samples are
    given, two-sample KS p-values for the speeds and for the densities are
    written onto the panel.  The speed ranges delimited by the notebook-level
    ``v_thin``, ``v_thick_min``, ``v_thick_max`` and ``v_halo`` are shaded and
    labelled 'thin disc', 'thick disc' and 'halo'.  The figure is written to
    ``<source_cat>_<sample names>_logv_log_rho.pdf`` and then closed.

    Parameters
    ----------
    samples : sequence
        Objects exposing ``stars``, ``name`` and ``plot_col``; every star must
        expose ``target_v`` and ``rho_20_target``.  Must not be empty.
    x_c_scale, y_c_scale : float, optional
        Forwarded unchanged to ``scatter_hist``.
    ksy : float, optional
        Height of the lower KS annotation as a fraction of the log-density
        range; the upper annotation sits 0.1 higher.
    """

    def _ks_pvalue(a, b):
        # The axis limits below are NaN-aware, so NaNs can occur in the data;
        # ks_2samp does not ignore them, so drop them before testing.
        a = a[~np.isnan(a)]
        b = b[~np.isnan(b)]
        if a.size == 0 or b.size == 0:
            return np.nan
        return scipy.stats.ks_2samp(a, b)[1]

    fig = plt.figure(figsize=(5,5))

    # Scatter panel with one marginal histogram above it and one to its right.
    left, width = 0.1, 0.65
    bottom, height = 0.1, 0.65
    spacing = 0.005
    ax = fig.add_axes([left, bottom, width, height])
    ax_histx = fig.add_axes([left, bottom + height + spacing, width, 0.2], sharex=ax)
    ax_histy = fig.add_axes([left + width + spacing, bottom, 0.2, height], sharey=ax)

    # One flat array per sample for each plotted quantity.
    v = [np.array([star.target_v for star in sample.stars]).flatten() for sample in samples]
    rho = [np.array([star.rho_20_target for star in sample.stars]).flatten() for sample in samples]

    # Common plotting range over all samples, ignoring NaNs.
    log_v_all = np.log10(np.concatenate(v))
    log_rho_all = np.log10(np.concatenate(rho))
    minx = np.nanmin(log_v_all)
    maxx = np.nanmax(log_v_all)
    miny = np.nanmin(log_rho_all)
    maxy = np.nanmax(log_rho_all)

    for sample, v_s, rho_s in zip(samples, v, rho):
        scatter_hist(np.log10(v_s), np.log10(rho_s), ax, ax_histx, ax_histy, minx, maxx, miny, maxy,
                     sample.plot_col,
                     label=sample.name+' $N={:3d}$'.format(len(sample.stars)),
                     x_c_scale=x_c_scale, y_c_scale=y_c_scale)

    if len(samples) == 2:
        # The test is run on rho itself, as before; the KS statistic is unchanged by the
        # monotonic log10 transform used for plotting.
        p_rho = _ks_pvalue(rho[0], rho[1])
        p_v = _ks_pvalue(v[0], v[1])
        text_x = minx+(maxx-minx)*0.1
        ax.text(text_x,miny+(maxy-miny)*(ksy+0.1),r'$p_\mathrm{{ks,vel}}={:6.2e}$'.format(p_v))
        ax.text(text_x,miny+(maxy-miny)*ksy,r'$p_\mathrm{{ks,rho}}={:6.2e}$'.format(p_rho))

    ax.set_xlabel(r'$\log_{{10}} |\mathbf{v}|$ [km s$^{-1}$]')
    ax.set_ylabel(r'$\log_{{10}} \rho$')

    # Shade the three speed bands across the full height of the panel, using the
    # limits the axes ended up with after plotting (left edge clamped at log10 v = 0).
    x_lo, xmax = ax.get_xlim()
    xmin = np.max((x_lo,0))
    ymin, ymax = ax.get_ylim()
    band_height = ymax-ymin
    ax.add_patch(Rectangle((xmin,ymin),np.log10(v_thin)-xmin,band_height,facecolor='g',alpha = 0.3))
    ax.add_patch(Rectangle((np.log10(v_thick_min),ymin),np.log10(v_thick_max)-np.log10(v_thick_min),
                           band_height,facecolor='g',alpha=0.2))
    ax.add_patch(Rectangle((np.log10(v_halo),ymin),xmax-np.log10(v_halo),band_height,facecolor='g',alpha=0.1))
    label_y = ymin+band_height*0.7
    ax.text(np.log10(v_thin)-0.6,label_y,'thin disc',color='g')
    ax.text(np.log10(v_thick_min),label_y,'thick disc',color='g')
    ax.text(np.log10(v_halo),label_y,'halo',color='g')

    ax.legend()

    plot_samples = ''.join(s.name + '_' for s in samples)

    # Save and close this figure by handle rather than relying on pyplot's "current figure".
    fig.savefig(source_cat+'_'+plot_samples+'logv_log_rho.pdf',bbox_inches='tight')
    plt.close(fig)
    
# Produce both diagnostic planes (residuals and log-density versus log speed) for the
# HJs/CJs pair and for the WHJs/WCJs pair.  Each call passes exactly two samples, so the
# KS p-value annotations are drawn; ksy=0.4 raises them on the first plot only.
logv_residuals((HJs,CJs),x_c_scale=2,y_c_scale=1.5,ksy=0.4)
logv_logrho((HJs,CJs),x_c_scale=2,y_c_scale=1.5)
logv_residuals((WHJs,WCJs),x_c_scale=2,y_c_scale=2)
logv_logrho((WHJs,WCJs),x_c_scale=2,y_c_scale=2)


In [None]:
def plot_trends(samples):
    """Overlay the fitted trend of every star in ``samples`` as a faint line.

    For each star, ``star.fit.linspace`` is evaluated at 101 points spanning the
    range of ``log10(star.sample_v)`` and the resulting curve is drawn with
    alpha 0.01 in the owning sample's ``plot_col``, so that only regions where
    many curves overlap become visible.  The figure is written to
    ``<source_cat>_<sample names>_all_trends.pdf`` and then closed.

    Parameters
    ----------
    samples : iterable
        Objects exposing ``stars``, ``name`` and ``plot_col``; every star must
        expose ``sample_v`` and ``fit``.
    """
    fig, ax = plt.subplots(figsize=[5,4])

    for sample in samples:
        for star in sample.stars:
            log_v = np.log10(star.sample_v)

            xfit, yfit = star.fit.linspace(101, [np.min(log_v), np.max(log_v)])

            ax.plot(xfit, yfit, alpha=0.01, color=sample.plot_col)

    # Raw strings: '\l' and '\m' are invalid escape sequences in a normal string literal.
    ax.set_xlabel(r'$\log_{{10}}|\mathbf{v}|$ [km/s]')
    ax.set_ylabel(r'$\log_{{10}}\rho$')

    plot_samples = ''.join(s.name + '_' for s in samples)

    fig.savefig(source_cat+'_'+plot_samples+'all_trends.pdf')
    plt.close(fig)
    
# Draw the overlaid per-star trend lines for the 'All' sample only.
plot_trends([All])


In [None]:

def plot_age_velocity(sample):
    """Scatter each star's speed against its catalogued age, split by the HJ flag.

    The age is the ``st_age`` value in the first row of ``star.planets`` and the
    speed is ``star.target_v``.  Stars with a truthy ``HJ`` attribute are drawn
    in red ('HJ host'), all others in black ('not HJ host').  The figure is
    written to ``<source_cat>_<sample.name>_age_velocity.pdf`` and then closed.

    Parameters
    ----------
    sample : object
        Exposes ``stars`` and ``name``; every star must expose ``planets``,
        ``target_v`` and ``HJ``.
    """
    fig, ax = plt.subplots(figsize=(5,4))

    age = np.array([s.planets['st_age'][0] for s in sample.stars])
    speed = np.array([s.target_v for s in sample.stars])

    # Build the host mask once, explicitly as booleans, so it is always used as a mask
    # (never as integer positions) and is not recomputed for every indexing operation.
    is_hj = np.array([bool(s.HJ) for s in sample.stars], dtype=bool)

    ax.plot(age[is_hj],speed[is_hj],'.',color='r',alpha=0.5,label='HJ host')
    ax.plot(age[~is_hj],speed[~is_hj],'.',color='k',alpha=0.5,label='not HJ host')

    ax.set_xlabel('Age [Gyr]')
    # Raw string: '\m' is an invalid escape sequence in a normal string literal.
    ax.set_ylabel(r'$|\mathbf{v}|$ [km/s]')

    ax.legend()

    fig.savefig(source_cat+'_'+sample.name+'_age_velocity.pdf')
    plt.close(fig)
    
# Age-versus-speed scatter for the 'All' sample.
plot_age_velocity(All)


In [None]:
def high_and_low(sample):
    """Count stars with high and low ``P_target`` / ``P_target_v`` and print a summary.

    A star counts as "high" when the probability is above 0.84, "low" when it is
    below 0.16 and "intermediate" when it lies strictly between the two.  NaN
    values and values exactly equal to a threshold fall into no category.

    Two lines are printed.  The first holds the sample name followed by the
    high / low / intermediate counts for ``P_target``, a ``'|'`` separator, and
    the same counts for ``P_target_v``; each count appears twice to keep the
    layout this function has always produced (it used to print a vectorised and
    a loop-based tally of the same condition).  The second line holds the high
    and low fractions of the categorised stars for each quantity, or NaN when no
    star was categorised (previously a ``ZeroDivisionError``).

    Parameters
    ----------
    sample : object
        Exposes ``name`` and ``stars``; every star must expose ``P_target`` and
        ``P_target_v``.

    Returns
    -------
    tuple of int
        ``(n_high, n_low, n_high_v, n_low_v)``.
    """

    def _tally(attr):
        """Return (n_high, n_low, n_intermediate) for the star attribute ``attr``."""
        n_high = n_low = n_mid = 0
        for s in sample.stars:
            p = getattr(s, attr)
            # The three conditions are mutually exclusive; NaN satisfies none of them.
            if p > 0.84:
                n_high += 1
            elif p < 0.16:
                n_low += 1
            elif p > 0.16 and p < 0.84:
                n_mid += 1
        return n_high, n_low, n_mid

    def _fraction(count, total):
        """``count/total``, or NaN when nothing was categorised."""
        return count/total if total else float('nan')

    high, low, ind = _tally('P_target')
    highv, lowv, indv = _tally('P_target_v')

    total = high + ind + low
    totalv = highv + indv + lowv

    print(sample.name,high,high,low,low,ind,ind,'|',highv,highv,lowv,lowv,indv,indv)
    print(_fraction(high,total),_fraction(low,total),'|',_fraction(highv,totalv),_fraction(lowv,totalv))

    print()

    return (high,low,highv,lowv)

# Tally (and print) the high/low probability counts for four samples; each result is the
# (high, low, high_v, low_v) tuple returned by high_and_low.
WHJ_numbers = high_and_low(WHJs)
WCJ_numbers = high_and_low(WCJs)
HJ_numbers = high_and_low(HJs)
WAll_numbers = high_and_low(WAll)

In [None]:
def plot_planets(sample):
    """Four-panel semimajor-axis versus mass diagram of the planets in ``sample``.

    Panels (log-log, shared axes): planets of stars with ``P_target`` < 0.16
    (top left), ``P_target`` > 0.84 (top right), ``P_target_v`` > 0.84 (bottom
    left) and ``P_target_v`` < 0.16 (bottom right).  Each panel shows the
    selected planets as points over a Gaussian-smoothed density of the same
    selection, the number of selected stars, and the ratio of stars hosting a
    "HJ" (a < 0.2, mass > 50) to stars hosting a "CJ" (a > 0.2, mass > 50);
    a planet at exactly a = 0.2 counts as neither.  Grey lines mark the
    mass = 50 and a = 0.2 boundaries.  The figure is written to
    ``<source_cat>_<sample.name>_planets_a_M_smooth0.10.pdf`` and then closed.

    For the point selection a star's ``P_target_v`` is discarded when
    ``P_1comp_v[1]`` exceeds 0.05, and both probabilities are discarded for
    planets lacking a finite semimajor axis or mass.  Diagnostic lines are
    printed for such planets ('Bad pl data'), for second and later planets of
    low-``P_target_v`` stars ('Duplicate'), for the host counts, and for the
    smoothing lengths.

    Parameters
    ----------
    sample : object
        Exposes ``stars`` and ``name``; every star must expose ``planets``
        (rows with ``pl_orbsmax``, ``pl_orbper``, ``st_mass``, ``pl_bmasse``),
        ``P_target``, ``P_target_v``, ``P_1comp_v`` and ``name_short``.
    """

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2,2,figsize=(10,8),sharex=True,sharey=True)

    size=18

    def semimajor_axis(p):
        """Semimajor axis of planet row ``p``, with a Kepler's-third-law fallback."""
        a = p['pl_orbsmax']
        if not np.isfinite(a):
            # a^3 = M * P^2 with P in years, M in solar masses and a in au (assumes the
            # catalogue gives the period in days and the stellar mass in solar masses).
            # The mass must multiply P^2; dividing by it mis-scales a for M != 1.
            a = ((p['pl_orbper']/365.25)**2*p['st_mass'])**(1/3)
        return a

    # Per-planet quantities (x, y, ph, phv) and per-star summaries (phs, phvs, ishj, iscj).
    x = []
    y = []
    ph = []
    phs = []
    phv = []
    phvs = []
    ishj = []
    iscj = []
    for s in sample.stars:
        count = 0
        hj = []
        cj = []
        for p in s.planets:
            count += 1
            x.append(semimajor_axis(p))
            y.append(p['pl_bmasse'])
            hj.append(bool(x[-1] < 0.2 and y[-1] > 50))
            cj.append(bool(x[-1] > 0.2 and y[-1] > 50))
            ph.append(s.P_target)
            phv.append(s.P_target_v)
            if s.P_1comp_v[1] > 0.05:
                phv[-1] = np.nan
            if (not np.isfinite(x[-1]) or not np.isfinite(y[-1])):
                ph[-1] = np.nan
                phv[-1] = np.nan
                print('Bad pl data: ',s.name_short)
            if phv[-1] < 0.16 and count >1:
                # print duplicates
                print('Duplicate: ',s.name_short)
        ishj.append(any(hj))
        iscj.append(any(cj))

        # Per-star value = minimum over that star's planets.  A star without planets gets
        # NaN: slicing with [-0:] would otherwise pick up every earlier star's planets.
        phs.append(np.nanmin(ph[-count:]) if count else np.nan)
        phvs.append(np.nanmin(phv[-count:]) if count else np.nan)

    print('HJs: ',sum(ishj))
    print('CJs: ',sum(iscj))

    x = np.array(x)
    y = np.array(y)
    phs = np.array(phs)
    phvs = np.array(phvs)
    ph = np.array(ph)
    phv = np.array(phv).flatten()

    # (axes, per-planet selection, per-star selection, point colour, title template)
    panels = (
        (ax1, ph<0.16, phs<0.16, 'b', r'$P_{{\mathrm{{high,}}\rho}}<0.16$, $N_\star={:3d}$'),
        (ax2, ph>0.84, phs>0.84, 'r', r'$P_{{\mathrm{{high,}}\rho}}>0.84$, $N_\star={:3d}$'),
        (ax3, phv>0.84, phvs>0.84, 'b', r'$P_\mathrm{{high,v}}>0.84$, $N_\star={:3d}$'),
        (ax4, phv<0.16, phvs<0.16, 'r', r'$P_\mathrm{{high,v}}<0.16$, $N_\star={:3d}$'),
    )
    for ax, planet_sel, star_sel, colour, title in panels:
        ax.scatter(x[planet_sel],y[planet_sel],c=colour,zorder=10)
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.text(0.015,4000,title.format(sum(star_sel)),fontsize=size)
        ax.text(0.5,10,'HJ:CJ = {:2d}:{:2d}'.format(sum(np.logical_and(star_sel,ishj)),
                                                    sum(np.logical_and(star_sel,iscj))),fontsize=size)

    # Region labels are only written on the first panel.
    ax1.text(0.015,55,'HJs',fontsize=size)
    ax1.text(5,55,'CJs',fontsize=size)

    for ax in (ax1,ax2,ax3,ax4):
        ax.tick_params(axis='both',labelsize=size-2)
    plt.subplots_adjust(hspace=.0)
    plt.subplots_adjust(wspace=.0)

    # Invisible full-figure axes carrying the shared axis labels.
    fig.add_subplot(111, frameon=False)
    # hide tick and tick label of the big axis
    plt.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)

    plt.xlabel('semimajor axis [au]',fontsize=size)
    # Raw string: '\m' and '\o' are invalid escape sequences in a normal string literal.
    plt.ylabel(r'mass [$\mathrm{{M}}_\oplus$]',fontsize=size)

    # Limits are read before the guide lines and contours are added, so both are sized
    # to the extent of the scattered data.
    xlim = ax1.get_xlim()
    ylim = ax1.get_ylim()
    for ax in (ax1,ax2,ax3,ax4):
        ax.plot([xlim[0],xlim[1]],[50,50],c='grey',zorder=5)
        ax.plot([0.2,0.2],[50,ylim[1]],c='grey',zorder=5)

    nx = 101
    ny = 101
    xgrid = np.linspace(np.log10(xlim[0]),np.log10(xlim[1]),nx)
    ygrid = np.linspace(np.log10(ylim[0]),np.log10(ylim[1]),ny)
    meshx, meshy = np.meshgrid(xgrid,ygrid)

    # Smoothing lengths: 10% of the plotted range in each log coordinate.
    xsmooth = (xgrid[-1]-xgrid[0])*0.10
    ysmooth = (ygrid[-1]-ygrid[0])*0.10
    print(xsmooth,ysmooth)

    def kernel(meshx,meshy,x,y,xsmooth,ysmooth):
        return np.exp(-0.5*((meshx-x)/(xsmooth))**2 - 0.5*((meshy-y)/(ysmooth))**2)

    def smoothed_density(star_selected):
        """Sum of Gaussian kernels over the planets of stars accepted by ``star_selected``."""
        # Shaped like the mesh, i.e. (ny, nx); a (nx, ny) accumulator only works when nx == ny.
        density = np.zeros_like(meshx)
        for s in sample.stars:
            if not star_selected(s):
                continue
            for p in s.planets:
                a = semimajor_axis(p)
                if np.isfinite(a) and np.isfinite(p['pl_bmasse']):
                    density += kernel(meshx,meshy,np.log10(a),np.log10(p['pl_bmasse']),xsmooth,ysmooth)
        return density

    # NOTE(review): the two velocity panels select on the raw P_target_v here, whereas the
    # scattered points above drop stars with P_1comp_v[1] > 0.05 -- confirm this is intended.
    contours = (
        (ax1, lambda s: s.P_target < 0.16, 'Blues'),
        (ax2, lambda s: s.P_target > 0.84, 'Reds'),
        (ax3, lambda s: s.P_target_v > 0.84, 'Blues'),
        (ax4, lambda s: s.P_target_v < 0.16, 'Reds'),
    )
    for ax, star_selected, cmap in contours:
        ax.contourf(10**meshx,10**meshy,smoothed_density(star_selected),zorder=-10,cmap=cmap)

    plt.savefig(source_cat+'_'+sample.name+'_planets_a_M_smooth0.10.pdf',bbox_inches='tight')
    plt.close()
    
# Four-panel planet diagram for the WAll sample.
plot_planets(WAll)

In [None]:
def plot_ranks(sample):
    """Scatter residual fractional rank against density fractional rank for ``sample``.

    Each star's ``rank`` and ``rank_detrended`` are divided by ``N_sample + 1``.
    Points are black with an opacity equal to that star's ``N_sample + 1``
    relative to the largest value in the sample.  The figure is written to
    ``<source_cat>_<sample.name>_ranks_all.pdf`` and then closed.
    """
    plt.figure(figsize=[5,4])

    stars = sample.stars
    denom = np.array([star.N_sample + 1 for star in stars])

    # RGBA rows: RGB stay zero (black), only the alpha channel varies.
    point_colours = np.zeros((len(stars),4))
    point_colours[:,3] = denom/np.nanmax(denom)

    density_frac = np.array([star.rank for star in stars])/denom
    residual_frac = np.array([star.rank_detrended for star in stars])/denom

    plt.scatter(density_frac,residual_frac,c=point_colours)
    plt.xlabel('density fractional rank')
    plt.ylabel('residuals fractional rank')
    plt.title(sample.name)

    out_name = source_cat+'_'+sample.name+'_ranks_all.pdf'
    plt.savefig(out_name)
    plt.close()
    
# Rank-versus-rank scatter for the HJs sample.
plot_ranks(HJs)

In [None]:
def plot_residuals_ranks(sample,control):
    """Compare residual fractional ranks of ``sample`` with rank-matched control values.

    For every star in ``sample`` the density and residual fractional ranks are
    ``rank/(N_sample+1)`` and ``rank_detrended/(N_sample+1)``.  For ``control``
    the fractional ranks are ``rank_all/N_sample`` and
    ``rank_detrended_all/N_sample``.  Each sample star is paired with the
    control entry at the corresponding position in the control's density-rank
    ordering, and that entry's *residual* fractional rank is taken as the
    matched value.  Differential (10 bins) and cumulative (1000 bins)
    histograms of both sets are drawn, the two-sample KS p-value is annotated,
    and the figure is written to
    ``<source_cat>_<sample.name>_residuals_ranks.pdf`` and then closed.

    NOTE(review): the matching used to read ``np.sort(f_rank_c)[match]``, which
    returned density ranks and left the control's residual ranks unused, and
    the legend labels of the two data sets were swapped.  Both are corrected
    here on the assumption that the variable names reflect the intent --
    please confirm.

    Parameters
    ----------
    sample : object
        Exposes ``stars`` and ``name``; every star must expose ``N_sample``,
        ``rank`` and ``rank_detrended``.
    control : object
        Exposes ``N_sample``, ``rank_all``, ``rank_detrended_all`` and
        ``name_short``; the two ``*_all`` arrays are assumed to be aligned
        element by element.
    """
    fig = plt.figure(figsize=[5,4])
    N_sample = np.array([s.N_sample + 1 for s in sample.stars])
    f_rank_s = np.array([s.rank for s in sample.stars])/N_sample
    f_res_s = np.array([s.rank_detrended for s in sample.stars])/N_sample

    f_rank_c = np.asarray(control.rank_all)/control.N_sample
    f_res_c = np.asarray(control.rank_detrended_all)/control.N_sample

    # Position of each sample star within the control's density-rank ordering, kept
    # inside the valid index range so a fractional rank near 1 cannot run off the end.
    order = np.argsort(f_rank_c)
    match = (f_rank_s * control.N_sample+1).astype(int)
    match = np.clip(match, 0, len(order)-1)
    # Residual rank of the control entry holding that density rank.
    f_res_match = f_res_c[order][match]

    bins1 = np.linspace(0,1,11)
    bins2 = np.linspace(0,1,1001)
    plt.hist(f_res_s,density=True,bins=bins1,label=sample.name,alpha=0.5)
    plt.hist(f_res_s,cumulative=True,histtype='step',density=True,bins=bins2,label='(cumulative)')
    plt.hist(f_res_match,density=True,bins=bins1,label='Neighbours of '+control.name_short,alpha=0.5)
    plt.hist(f_res_match,cumulative=True,histtype='step',density=True,bins=bins2,label='cumulative')
    dKS, pKS = scipy.stats.ks_2samp(f_res_s,f_res_match)
    # Raw string: '\m' is an invalid escape sequence in a normal string literal.
    plt.text(0.6,0.25,r'$p_\mathrm{{KS}}$ = {:4f}'.format(pKS))

    plt.xlabel('residuals fractional rank')
    plt.ylabel('Normalised count')
    plt.legend()
    fig.savefig(source_cat+'_'+sample.name+'_residuals_ranks.pdf')
    plt.close(fig)
    
# Compare the HJs sample against the final entry of `stars`, used here as the control.
plot_residuals_ranks(HJs,stars[-1])