In [29]:
from astropy.io import fits
from astropy.table import Table
import matplotlib.pyplot as plt
import numpy as np
from astroquery.gaia import Gaia
import logging
from tqdm import tqdm
from astropy.table import vstack
logging.basicConfig(level=logging.INFO)

In [10]:
# Import Apogee Data sets
apogee_data_file_NN = '../data/Apogee_DR17_vac_NN/apogee_astroNN-DR17.fits'
apogee_data_file_starhorse = '../data/Apogee_DR17_vac_starhorse/APOGEE_DR17_EDR3_STARHORSE_v2.fits'
apogee_data_file_allstar = '../data/Apogee_DR17_Allstar/allStar-dr17-synspec_rev1.fits'

#Importing data from file
with fits.open(apogee_data_file_allstar) as file:
    apogee_data = file[1].data
    # Cleaning the data set using mask statments for initial inputs
    # Filter for Main Red stars
    mrs_filter = apogee_data['EXTRATARG']==0

    # Filter bad star data
    bs_filter = apogee_data['ASPCAPFLAG'] != 'STAR_BAD'
    prog_filter = apogee_data['PROGRAMNAME'] != 'magclouds'
    rg_filter = apogee_data['LOGG'] < 3.0

    # Filter for valid element abundances
    # FE/ H
    # Filter flags
    fe_h_flag_filter = apogee_data['FE_H_FLAG'] == 0
    # Filter errors
    fe_h_err_filter = apogee_data['FE_H_ERR'] < 0.1
    # Combined filter
    fe_h_filter = fe_h_flag_filter & fe_h_err_filter

    # AL/FE
    # Filter flags
    al_fe_flag_filter = apogee_data['AL_FE_FLAG'] == 0
    # Filter errors
    al_fe_err_filter = apogee_data['AL_FE_ERR'] < 0.1
    # Combined filter
    al_fe_filter = al_fe_flag_filter & al_fe_err_filter

    # CE/FE
    # Filter flags
    ce_fe_flag_filter = apogee_data['CE_FE_FLAG'] == 0
    # Filter errors
    ce_fe_err_filter = apogee_data['CE_FE_ERR'] < 0.15
    # Combined filter
    ce_fe_filter = ce_fe_flag_filter & ce_fe_err_filter


    # Generating the Mg/Mn filter
    # filter flags  
    mg_fe_flag_filter = apogee_data['MG_FE_FLAG'] == 0
    mn_fe_flag_filter = apogee_data['MN_FE_FLAG'] == 0
    mg_mn_flag_filter = mg_fe_flag_filter & mn_fe_flag_filter

    # filter errors
    MG_MN_ERR = np.sqrt(apogee_data['MG_FE_ERR']**2 + apogee_data['MN_FE_ERR']**2)
    mg_mn_err_filter = MG_MN_ERR < 0.1

    # Combined filter
    mg_mn_filter = mg_mn_flag_filter & mg_mn_err_filter

    # Generating Alpha/Fe filter
    # filter flags
    o_fe_flag_filter = apogee_data['O_FE_FLAG'] == 0
    mg_fe_flag_filter = apogee_data['MG_FE_FLAG'] == 0
    si_fe_flag_filter = apogee_data['SI_FE_FLAG'] == 0
    ca_fe_flag_filter = apogee_data['CA_FE_FLAG'] == 0
    ti_fe_flag_filter = apogee_data['TI_FE_FLAG'] == 0
    alpha_fe_flag_filter = o_fe_flag_filter & mg_fe_flag_filter & si_fe_flag_filter & ca_fe_flag_filter & ti_fe_flag_filter

    # filter errors
    alpha_fe_flag_filter = apogee_data['ALPHA_M_ERR'] < 0.1 

    # Combined filter
    alpha_fe_filter = alpha_fe_flag_filter & alpha_fe_flag_filter


    # All Main Red Stars
    apogee_data_red = apogee_data[mrs_filter]
    # All stars remaining based on APOGEE filters 
    # Note this is currently missign alpha/Fe and Mg/Mn filters
    apogee_data_filtered = apogee_data[mrs_filter & bs_filter & prog_filter & rg_filter & fe_h_filter & al_fe_filter & ce_fe_filter & mg_mn_filter & alpha_fe_filter]


In [11]:
print(f'Number of stars in the red giant sample: {len(apogee_data_red)}')
print(f'Number of stars in the (APOGEE) filtered sample: {len(apogee_data_filtered)}')

Number of stars in the red giant sample: 372458
Number of stars in the (APOGEE) filtered sample: 164040


In [12]:
# Extract GAIA ID fron remaining stars
gaia_ids =  np.array(apogee_data_filtered['GAIAEDR3_SOURCE_ID'])

# Set size for SQL query and split up GAIA IDs
query_size = 750
indiv_queries = np.array_split(gaia_ids, np.ceil(len(gaia_ids) / query_size))

# Empty list to store the results of each query
list_query_results = []
# Track missing GAIA IDs
missing_ids_set = set() 

# Loop through each smaller set of GAIA IDs for the SQL query
for i, query in enumerate(tqdm(indiv_queries, desc="Processing Queries")):
    # Convert the chunk to a comma-separated string for SQL syntax
    gaia_id_list = ", ".join(query.astype(str))
    
    # Define the query
    distance_query = f"""
    SELECT source_id, r_med_geo, r_lo_geo, r_hi_geo, r_med_photogeo, r_lo_photogeo, r_hi_photogeo
    FROM external.gaiaedr3_distance
    WHERE source_id IN ({gaia_id_list});
    """
    # Run the query with SQL
    job = Gaia.launch_job(distance_query)
    results = job.get_results()

    # Store missing IDs 
    query_ids = set(query)  
    returned_ids = set(results['source_id'])  
    missing_ids_set.update(query_ids - returned_ids)

    # Append the results to the list
    list_query_results.append(results)

# Combine all results into a single table
all_query_results = vstack(list_query_results)

# Store missing IDs in array
missing_gaia_ids = np.array(list(missing_ids_set))


Processing Queries: 100%|██████████| 219/219 [02:51<00:00,  1.28it/s]


In [13]:
# Determine IDs which do not return SQL information
unique, counts = np.unique(gaia_ids[np.isin(gaia_ids, missing_gaia_ids)], return_counts=True)
print(f'Number of unique missing GAIA IDs: {len(unique)}')
print(f'Number of items in Database: {np.sum(counts)}')
print(f'Number of GAIA IDs given 0: {counts[0]}')

# Remove stars with missing GAIA Data
missing_ids_position = np.isin(gaia_ids, missing_gaia_ids)
apogee_data_filtered_2 = Table(apogee_data_filtered[~missing_ids_position])


Number of unique missing GAIA IDs: 246
Number of items in Database: 821
Number of GAIA IDs given 0: 576


In [30]:
# Remove stars with missing GAIA Data
missing_ids_position = np.isin(gaia_ids, missing_gaia_ids)
apogee_data_filtered_2 = Table(apogee_data_filtered[~missing_ids_position])

In [71]:
# Sort tables by GAIA ID's
all_query_results.sort('source_id')
apogee_data_filtered_2.sort('GAIAEDR3_SOURCE_ID')

# Check if the GAIA ID's match beofre merging
assert np.array_equal(all_query_results['source_id'], apogee_data_filtered_2['GAIAEDR3_SOURCE_ID']), "Mismatch in GAIA IDs!"
print("All GAIA IDs match")

# Calculate symmetrized distance uncertainties 
all_query_results['r_sym_uncert_geo'] = (all_query_results['r_hi_geo'] - all_query_results['r_lo_geo']) / 2
all_query_results['r_sym_uncert_photogeo'] = (all_query_results['r_hi_photogeo'] - all_query_results['r_lo_photogeo']) / 2

# Merge the tables - using the phot geometric distance 
apogee_data_filtered_2['r_sym_uncert_photogeo'] = all_query_results['r_sym_uncert_photogeo']
apogee_data_filtered_2['r_med_photogeo'] = all_query_results['r_med_photogeo']

# Merge the tables - using the geometric only distance
# apogee_data_filtered_2['r_sym_uncert_geo'] = all_query_results['r_sym_uncert_geo']
# apogee_data_filtered_2['r_med_geo'] = all_query_results['r_med_geo']

# Add Mg/MN characteristic
apogee_data_filtered_2['MG_MN'] = apogee_data_filtered_2['MG_FE'] - apogee_data_filtered_2['MN_FE']
apogee_data_filtered_2['MG_MN_ERR'] = np.sqrt(apogee_data_filtered_2['MG_FE_ERR']**2 + apogee_data_filtered_2['MN_FE_ERR']**2)


All GAIA IDs match


In [7]:
# # Print the results
# print(results)


# # Filter for eccentricity
# ecc_filter = apogee_data['ECCENTRICITY'] > 0.85
# # Filter for orbital apocenter
# apo_filter = apogee_data['APOCENTER'] > 5
# # Filter for distance error
# dist_err_filter = apogee_data['DIST_ERR'] < 1.5
# # Filter for orbital energy
# energy_filter = apogee_data['ENERGY'] < 0




# # Plot the HR diagram
# plt.figure(figsize=(10, 8))
# scatter = plt.scatter(filtered_teff, filtered_logg, c=filtered_fe_h, cmap='viridis', s=10, alpha=0.7)
# plt.colorbar(scatter, label='[Fe/H] (Metallicity)')

# # Reverse x-axis (hotter stars on the left)
# plt.gca().invert_xaxis()

# # Label axes
# plt.xlabel('Effective Temperature (K)', fontsize=14)
# plt.ylabel('Surface Gravity (log g)', fontsize=14)
# plt.title('Hertzsprung-Russell Diagram (APOGEE Data)', fontsize=16)
# plt.grid(True)

# plt.show()





# # Extract relevant columns
# teff = apogee_data['TEFF']
# logg = apogee_data['LOGG']
# bp_rp = apogee_data['bp_rp']

# # Apply conditions for red stars (example thresholds)
# red_star_mask = (teff < 5000) & (logg < 3) & (bp_rp > 1.0)

# # Filter the data
# red_stars = apogee_data[red_star_mask]
# print(f'Number of red stars: {len(red_stars)}')
