In [1]:
import numpy as np
import polars as pl
import pandas as pd
import matplotlib.pyplot as plt

from astropy.coordinates import SkyCoord
import astropy.units as u
from galpy.orbit import Orbit

In [2]:
data = pl.read_csv(
    "../data/raw/plato_targets_LOPS2.csv",
    columns=[
        "ra",
        "dec",
        "parallax",
        "pmra",
        "pmdec",
        "radial_velocity",
    ],
).to_pandas()

In [3]:
coord = SkyCoord(
    ra=data["ra"].to_numpy() * u.degree,
    dec=data["dec"].to_numpy() * u.degree,
    distance=(1 / data["parallax"].to_numpy()) * u.pc,
    pm_ra_cosdec=data["pmra"].to_numpy() * u.mas / u.yr,
    pm_dec=data["pmdec"].to_numpy() * u.mas / u.yr,
    radial_velocity=data["radial_velocity"].to_numpy() * u.km / u.s,
)

In [4]:
o = Orbit(
    vxvv=[
        coord.ra,
        coord.dec,
        coord.distance,
        coord.pm_ra_cosdec,
        coord.pm_dec,
        coord.radial_velocity,
    ],
    radec=True,
)

u, v, w = o.U(), o.V(), o.W()

In [5]:
comp = {
    "thin disk": {
        "X": 0.94,
        "sigma_U": 35,
        "sigma_V": 20,
        "sigma_W": 16,
        "V_asym": -15,
    },
    "thick disk": {
        "X": 0.06,
        "sigma_U": 67,
        "sigma_V": 38,
        "sigma_W": 35,
        "V_asym": -46,
    },
    "halo": {
        "X": 0.0015,
        "sigma_U": 160,
        "sigma_V": 90,
        "sigma_W": 90,
        "V_asym": -220,
    },
}

In [6]:
from scipy.stats import norm


def P(U, V, W, parameter):
    for key in ["sigma_U", "sigma_V", "sigma_W", "V_asym"]:
        if not key in parameter.keys():
            raise ValueError(f"Parameter dict must contain {key!r}.")

    U_prob = norm.pdf(U, loc=0, scale=parameter["sigma_U"])
    V_prob = norm.pdf(V, loc=parameter["V_asym"], scale=parameter["sigma_V"])
    W_prob = norm.pdf(W, loc=0, scale=parameter["sigma_W"])
    return U_prob * V_prob * W_prob

In [7]:
def relative_prob(U, V, W, comp_1, comp_2):
    for parameter_dict in [comp_1, comp_2]:
        for key in ["X", "sigma_U", "sigma_V", "sigma_W", "V_asym"]:
            if not key in parameter_dict.keys():
                raise ValueError(
                    f"Parameter dict {parameter_dict!r} must contain {key!r}."
                )

    prob_1 = P(U, V, W, comp_1)
    prob_2 = P(U, V, W, comp_2)
    return comp_1["X"] * prob_1 / (comp_2["X"] * prob_2)

In [8]:
td_d = relative_prob(u, v, w, comp["thick disk"], comp["thin disk"])
td_h = relative_prob(u, v, w, comp["thick disk"], comp["halo"])

In [11]:
(td_h < 0.1).sum()

211

In [16]:
from plato.classify import relative_probability

ImportError: cannot import name 'relative_probability' from 'plato.classify' (/home/chris/Documents/Projects/plato/plato/classify.py)