# Use an existing Random Forest model to classify new BASE-9 data

This will likely be the main notebook for users (after creating the model in `create_model.ipynb`, if necessary).

This notebook performs the following tasks:
- reads in posterior data from BASE-9
- generates features from these data
- uses the model created in the previous notebook to classify the posterior sampling as "Good" vs. "Bad".  

In this example, we apply the model to data for NGC 6819, generated by Elizabeth Jeffery.

Most of the "heavy lifting" is done by the code in the `base9_ml_utils.py` file. See the comments and docstrings in that file for more details.

___
*Authors:* Justyce Watson, Aaron Geller\
*Date:* August 2025

## Import all functions from the `base9_ml_utils.py` file

In [1]:
# import functions from .py file
from base9_ml_utils import *

# The lines below are useful if you plan to make changes to the base9_ml_utils.py file.
# They will allow the notebook to refresh when you save changes to the .py file.
#
# %load_ext autoreload
# %autoreload 2


## Read in the `.res` files and create the features

The user should specify the data directory on their own computer. The code assumes that this directory contains one `.res` file for each star, with the filename containing the star ID. (If the filenames contain additional text before or after the star ID, specify it with the `file_prefix` and/or `file_suffix` arguments so that the code can extract the star ID from each filename correctly.)
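To illustrate how a star ID can be recovered from a filename using a prefix and suffix, here is a minimal sketch. The helper `star_id_from_filename` is hypothetical and is not necessarily how `create_features` does it. It assumes filenames of the form `<prefix><id><suffix>.res`; the example filename `gaia_2076377646922516096_sin2.res` is an assumption based on the `file_prefix='gaia_'` and `file_suffix='_sin2'` arguments used below.

```python
from pathlib import Path


def star_id_from_filename(filename, file_prefix="", file_suffix="", extension=".res"):
    """Hypothetical helper: strip the extension, prefix, and suffix to leave the star ID."""
    name = Path(filename).name  # drop any directory part
    if not name.endswith(extension):
        raise ValueError(f"{name!r} does not end with {extension!r}")
    stem = name[: len(name) - len(extension)]
    if not stem.startswith(file_prefix):
        raise ValueError(f"{name!r} does not start with prefix {file_prefix!r}")
    if not stem.endswith(file_suffix):
        raise ValueError(f"{name!r} does not end with suffix {file_suffix!r}")
    # An empty prefix or suffix has length 0, so the slice leaves that end untouched.
    return stem[len(file_prefix): len(stem) - len(file_suffix)]


# Assumed filename pattern for the NGC 6819 files
print(star_id_from_filename("gaia_2076377646922516096_sin2.res",
                            file_prefix="gaia_", file_suffix="_sin2"))
# → 2076377646922516096
```

Raising an error on a mismatched prefix or suffix (rather than silently returning the wrong text) makes a misconfigured `file_prefix` easy to spot.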

We will use the `create_features` function imported from `base9_ml_utils.py`.

In [2]:
# run this cell to see information about this function
create_features?

Signature:
create_features(
    directory,
    column=0,
    max_nfiles=inf,
    file_prefix='NGC_2682_',
    file_suffix='',
    ess_num_samples=10000,
    random_seed=42,
)
Docstring:
function that will calculate all the features needed for the ML model 
Note that the file names for the res files must contain the ids (and can include a prefix and suffix)

inputs:
- directory : (string) path to the data directory that contains the res files from BASE-9
- column : (string) column number to use from the res file to use to calculate features
- max_nfiles : (int) maximum number of files to use
- file_prefix : (string) prefix in the res file names before the id 
- file_suffix : (string) suffix in the res file names after the id
- ess_num_samples : (int) number of samples to use in ess normal distribution
- random_seed : (int) random seed used for calculating ess

outputs:
- pandas DataFrame with the calculated features (

In [3]:
# directory on your computer where the .res data files are stored
directory = 'data/NGC6819/ngc6819_single_resfiles' 

# create a DataFrame with features for each star using the 'create_features' function
ngc6819_statistic = create_features(directory, file_prefix='gaia_', file_suffix='_sin2')
 
# display the resulting DataFrame in the notebook
ngc6819_statistic


| | source_id | Width | Upper_bound | Lower_bound | Stdev | SnR | Dip_p | Dip_value | KS_value | KS_p | ESS |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2076377646922516096 | 1.310742 | 0.371344 | 0.17444 | 0.705594 | 12.625670 | 0.0 | 0.055795 | 0.202753 | 5.028740e-223 | 9808.068893 |
| 1 | 2076269108813963776 | 1.465552 | 0.371344 | 0.17444 | 0.688275 | 12.954116 | 0.0 | 0.026590 | 0.148421 | 3.345375e-119 | 9483.299078 |
| 2 | 2076395170383622016 | 1.609451 | 0.371344 | 0.17444 | 0.793903 | 11.146351 | 0.0 | 0.013034 | 0.155102 | 3.832169e-129 | 10075.385992 |
| 3 | 2076390192531116544 | 1.370839 | 0.371344 | 0.17444 | 0.657100 | 13.391178 | 0.0 | 0.009949 | 0.157269 | 6.046880e-135 | 9881.741093 |
| 4 | 2076479596566965376 | 1.188421 | 0.371344 | 0.17444 | 0.610059 | 14.706439 | 0.0 | 0.068436 | 0.188594 | 2.449605e-194 | 9817.439956 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1693 | 2076286593616203520 | 0.948054 | 0.371344 | 0.17444 | 0.563526 | 16.162504 | 0.0 | 0.012569 | 0.214150 | 4.636188e-255 | 9840.642432 |
| 1694 | 2076394109541237248 | 1.592687 | 0.371344 | 0.17444 | 0.727018 | 12.184470 | 0.0 | 0.016283 | 0.197510 | 2.859512e-213 | 9968.341079 |
| 1695 | 2076490213726294528 | 1.546252 | 0.371344 | 0.17444 | 0.728562 | 12.026299 | 0.0 | 0.017071 | 0.168514 | 5.394290e-155 | 9575.281437 |
| 1696 | 2076490041927622016 | 1.496974 | 0.371344 | 0.17444 | 0.702897 | 12.574082 | 0.0 | 0.016518 | 0.173834 | 1.285783e-167 | 10257.561780 |


## Load the saved model

We assume that you have a model generated as in the `create_model.ipynb` notebook. Here the model was saved as `my_model.pkl`. Note that to use a saved model, you will need to be working with the same version of scikit-learn that was used to create it (and possibly the same versions of other dependencies, such as scipy).
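One way to catch such mismatches before using the model is to record the relevant package versions when the model is saved and compare them when it is loaded. Below is a minimal sketch using only the standard library. The helper names, the package list (which assumes the model depends on scikit-learn, scipy, and numpy), and the `my_model_versions.json` file are all hypothetical; they are not part of `base9_ml_utils.py` or `create_model.ipynb`.

```python
import json
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("scikit-learn", "scipy", "numpy")):
    """Return {package name: installed version string, or None if not installed}."""
    versions = {}
    for pkg in packages:
        try:
            versions[pkg] = version(pkg)
        except PackageNotFoundError:
            versions[pkg] = None
    return versions


def version_mismatches(saved, current):
    """Return {package: (saved version, current version)} for every package that differs."""
    return {
        pkg: (saved_ver, current.get(pkg))
        for pkg, saved_ver in saved.items()
        if current.get(pkg) != saved_ver
    }


# Hypothetical workflow: write the versions next to the model when it is saved ...
with open("my_model_versions.json", "w") as f:
    json.dump(installed_versions(), f)

# ... and compare them when the model is loaded, possibly in another environment.
with open("my_model_versions.json") as f:
    saved = json.load(f)
for pkg, (old, new) in version_mismatches(saved, installed_versions()).items():
    print(f"Warning: {pkg} was {old} when the model was saved, but is {new} now")
```

Run in the same environment that saved the file, the loop prints nothing; any printed warning points to the package you may need to reinstall (or the environment in which to remake the model).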

In [4]:
# Read in the saved model (a scikit-learn Pipeline object)
pipe = load_model('my_model.pkl')

## Use the model to generate labels
Here we use the `make_preds` function imported from `base9_ml_utils.py`. This function takes the model we read in above and the data to be labeled; here, that is the NGC 6819 feature DataFrame we created above. Note that `make_preds` accepts either a pandas DataFrame (as we have here) or a numpy array. If passed a DataFrame, it ensures that the columns are in the correct order and converts it to a numpy array. If passed a numpy array, the user must ensure that the columns are already in the correct order.
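The column-ordering behavior described above can be sketched as follows. This is a hypothetical helper (`to_feature_array`), not the actual implementation of `make_preds`, and the column names in the usage example are placeholders: the real feature order is whatever the model was trained on.

```python
import numpy as np
import pandas as pd


def to_feature_array(data, feature_columns):
    """Return a 2D numpy array whose columns follow the order of feature_columns.

    Hypothetical helper for illustration; not the implementation of make_preds.
    """
    if isinstance(data, pd.DataFrame):
        missing = [col for col in feature_columns if col not in data.columns]
        if missing:
            raise KeyError(f"missing feature columns: {missing}")
        # Selecting by a list of names reorders the columns and drops any extras
        # (such as source_id) in one step.
        return data[list(feature_columns)].to_numpy()
    # A numpy array is used as-is, so its column order is the caller's
    # responsibility; only the number of columns can be checked.
    arr = np.asarray(data)
    if arr.ndim != 2 or arr.shape[1] != len(feature_columns):
        raise ValueError(f"expected shape (n, {len(feature_columns)}), got {arr.shape}")
    return arr


# Toy DataFrame whose columns are deliberately out of order
toy = pd.DataFrame({"source_id": [1, 2], "SnR": [12.6, 13.0], "Width": [1.31, 1.47]})
X = to_feature_array(toy, ["Width", "SnR"])
print(X.shape)  # (2, 2)
```

This is why passing the DataFrame is the safer option: column names carry the ordering information, whereas a bare numpy array cannot be checked for anything beyond its shape.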


*Note:* If you encounter an error here related to `scipy` (or `scikit-learn`), it may mean that the model read in above was created using a different version of that package. There may be other dependency issues that I am not currently aware of as well. If you see an error, remake the model on your machine, using the same environment in which you intend to run this notebook. If you need access to the data we used to create the model, please contact Aaron Geller.

In [5]:
# use the model to label the NGC 6819 data
y_pred_6819 = make_preds(pipe, ngc6819_statistic)

# print the output
y_pred_6819

array(['Bad', 'Bad', 'Bad', ..., 'Bad', 'Bad', 'Good'],
      shape=(1698,), dtype=object)

In [6]:
# As a quick sanity check, you may want to print the number of labels that are "Good" vs "Bad"
# (You would expect a non-zero amount of each label type.)

n_good = np.sum(y_pred_6819 == "Good")
n_bad = np.sum(y_pred_6819 == "Bad")
print(f'The model predicts {n_good} "Good" and {n_bad} "Bad" labels for these data.')

The model predicts 392 "Good" and 1306 "Bad" labels for these data.
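An alternative way to do the same sanity check is `np.unique` with `return_counts=True`, which counts every distinct label in one call. The toy `labels` array below stands in for `y_pred_6819`; it is illustrative data, not output from the model.

```python
import numpy as np

# Toy labels standing in for y_pred_6819 (an object array of "Good"/"Bad" strings)
labels = np.array(["Bad", "Bad", "Good", "Bad", "Good"], dtype=object)

values, counts = np.unique(labels, return_counts=True)
label_counts = dict(zip(values, counts.tolist()))
print(label_counts)  # {'Bad': 3, 'Good': 2}

# .get(..., 0) covers the case the sanity check is looking for: a label that never appears
fraction_good = label_counts.get("Good", 0) / len(labels)
print(f"Fraction labeled Good: {fraction_good:.2f}")
```

If either label is missing from `label_counts` entirely, that is a sign that something upstream (the features or the model) may need a closer look.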
