# Chandra-ML

This repository contains the data tables and routines for classifying sources in the **Chandra Source Catalog 2.0 (CSC-2.0)**.

The **Data** section describes the data tables and the routines for extracting the required data from them in the proper format.

The **Model Training and Validation** section describes how we apply the LightGBM classification model to this data, along with the routines developed for it.

The last section, **Application**, shows how the model is applied to the unclassified sources.

#### Requirements
```
astropy
astroquery
pandas
numpy
scikit-learn
imbalanced-learn
lightgbm
```

#### Important Imports

```python
import numpy as np
import pandas as pd
%load_ext autoreload
%autoreload 2
```



# Data

### Directory Details

All the data for this work is in the _data_ folder.

<small>
Due to GitHub's size constraints, the data are not included in this repository. They are available as a zipped file on Google Drive (*drive link*). Download and extract the folder into this directory, and do not change the file names.
</small>



#### _data_ folder structure


```
├── data
│   ├── classified
│   │   ├── AGN.csv
│   │   ├── CV.csv
│   │   ├── HMXB.csv
│   │   ├── LMXB.csv
│   │   ├── PULSAR.csv
│   │   ├── STAR.csv
│   │   ├── TRAIN_SRC.csv
│   │   ├── ULX.csv
│   │   └── YSO.csv
│   ├── mw_cat
│   │   ├── 2mass_v2.csv
│   │   ├── chandra_filtered_sources.csv
│   │   ├── gaia.csv
│   │   ├── galex_combined.csv
│   │   ├── MIPS.csv
│   │   ├── sdss.csv
│   │   └── wise_combined.csv
│   ├── new_src_data
│   │   └── new_sources.csv
│   ├── source_info
│   │   └── all_csc_source_info.csv
│   └── training_data
│       ├── id_frame.csv
│       ├── imputed
│       │   ├── x_phot_minmax_10iter_rfimpimp.csv
│       │   ├── x_phot_minmax_constimp.csv
│       │   ├── x_phot_minmax_forestimp.csv
│       │   ├── x_phot_minmax_knnimp.csv
│       │   ├── x_phot_minmax_meanimp.csv
│       │   └── x_phot_minmax_modeimp.csv
│       ├── train_data_minmax.csv
│       └── x_phot_minmax.csv

```



* *classified* : Contains the data tables for all the sources identified with LightGBM in this work. Each table contains the class-membership probabilities (CMPs) along with the MW data for the sources.

* *mw_cat* : Multi-wavelength (MW) catalogs for all the sources, using CSC names as the identifiers.

* *new_src_data* : Normalized data for all the unclassified sources. This data table is used in this work to produce the classification tables and the CMPs available in the _classified_ folder.

* *source_info* : The data table in this folder contains the necessary information (quality flags, position, observation info) for all the sources in CSC-2.0.

* *training_data* : Contains the data tables of the sources cross-matched and identified in the various classes. The data in this folder are normalized and were used to train the model in this work. All the imputed data are in the _imputed_ folder.
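
The file names in _training_data_ (`x_phot_minmax_meanimp.csv`, `x_phot_minmax_knnimp.csv`, ...) suggest min-max scaling followed by one of several imputation strategies. The sketch below illustrates that idea in plain pandas under those assumptions: the helper `minmax_then_impute`, the toy columns, and the order "scale, then impute" are illustrative, not the exact pipeline that produced these files.

```python
import numpy as np
import pandas as pd


def minmax_then_impute(features: pd.DataFrame) -> pd.DataFrame:
    """Scale each column to [0, 1], then fill missing values with the column mean.

    Hypothetical helper: the normalization/imputation actually used to build
    data/training_data/ may differ.
    """
    col_min = features.min()                          # NaNs are skipped by default
    col_max = features.max()
    span = (col_max - col_min).replace(0, 1.0)        # constant columns map to 0 instead of dividing by zero
    scaled = (features - col_min) / span
    # Mean imputation (cf. *_meanimp.csv); mode, kNN or random-forest
    # imputation would replace this single step.
    return scaled.fillna(scaled.mean())


demo = pd.DataFrame(
    {"W1": [10.0, 12.0, np.nan, 14.0], "FUV": [np.nan, 20.0, 22.0, 24.0]},
    index=["src_a", "src_b", "src_c", "src_d"],
)
print(minmax_then_impute(demo))
```

Scaling before imputing keeps the filled values on the same [0, 1] scale as the observed ones.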

#### General Data Retrieval

We start with the list of all the sources in CSC-2.0.
The data table **all_source_info.csv** in the folder _data/source_info_ contains the information for all the sources. With basic pandas knowledge, one can select the sources of interest directly from this CSV file, but the routine **get_source_list** in the _choices_ module derives the list easily from a dictionary of flags.

```python
from choices import get_source_list

flags = {
    'conf_flag': 0,
    'streak_src_flag': 0,
    'extent_flag': 0,
    'pileup_flag': 0,
}
sources = get_source_list(flags)
sources
```

```
2CXO J003935.9-732725
2CXO J003936.7-731249
2CXO J004028.7-731106
2CXO J004506.3-730056
2CXO J004659.0-731918
...
2CXO J220613.7-495727
2CXO J220614.6-500951
2CXO J220618.4-500554
2CXO J220626.0-500126
2CXO J220642.7-495916
```
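
For readers who prefer plain pandas, the same kind of flag-based selection can be done with a boolean mask. This is a minimal sketch: `select_by_flags` is a hypothetical stand-in for `get_source_list`, and the toy table only assumes that the flag columns are named as in the output of `get_source_info`.

```python
import pandas as pd


def select_by_flags(info: pd.DataFrame, flags: dict) -> pd.Index:
    """Return the names of sources whose columns match every value in `flags`.

    Hypothetical helper, not the implementation in the choices module.
    """
    mask = pd.Series(True, index=info.index)
    for column, value in flags.items():
        mask &= info[column] == value   # keep only rows that pass this flag
    return info.index[mask]


info = pd.DataFrame(
    {"conf_flag": [0, 1, 0], "streak_src_flag": [0, 0, 0],
     "extent_flag": [0, 0, 1], "pileup_flag": [0, 0, 0]},
    index=pd.Index(["src_a", "src_b", "src_c"], name="name"),
)
flags = {"conf_flag": 0, "streak_src_flag": 0, "extent_flag": 0, "pileup_flag": 0}
print(list(select_by_flags(info, flags)))  # ['src_a']
```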


Now let's extract the information for 100 of these sources using the function **get_source_info**.
<small>Note: this function is also based on the file _all_source_info.csv_. The point of having a separate function is that, at any stage of working with a data table of N sources, we can always pull out the information about those sources.</small>

```python
from choices import get_source_info

source_info = get_source_info(sources.sample(100))
source_info
```

| name | ra | dec | gal_l | gal_b | err_ellipse_r0 | err_ellipse_r1 | err_ellipse_ang | conf_flag | extent_flag | sat_src_flag | var_flag | pileup_flag | streak_src_flag | significance | acis_time | hrc_time |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 2CXO J002046.2-705659 | 5.192620 | -70.949804 | 306.523591 | -45.963984 | 0.956380 | 0.804594 | 44.584318 | 0.0 | 0 | 0.0 | 0.0 | 0.0 | 0 | 3.588235 | 22768.771028 | |
| 2CXO J101020.3-124106 | 152.584628 | -12.685272 | 253.231259 | 34.215853 | 0.718688 | 0.717721 | 79.403259 | 0.0 | 0 | 0.0 | 0.0 | 0.0 | 0 | 9.390751 | 51195.761895 | |
| 2CXO J165409.8-020219 | 253.540954 | -2.038749 | 16.632152 | 24.790265 | 2.002898 | 1.185102 | 104.896836 | 0.0 | 0 | 0.0 | 0.0 | 0.0 | 0 | 4.722222 | 20772.938789 | |
| 2CXO J100301.2+021934 | 150.755395 | 2.326167 | 237.190828 | 42.708642 | 1.161515 | 1.079687 | 75.542480 | 0.0 | 0 | 0.0 | 0.0 | 0.0 | 0 | 2.666667 | 158905.462446 | |
| 2CXO J054244.6-404855 | 85.686088 | -40.815288 | 246.499422 | -29.796437 | 3.891118 | 2.875106 | 178.855720 | 0.0 | 0 | 0.0 | 0.0 | 0.0 | 0 | 3.875000 | 50405.925987 | |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 2CXO J180325.0-295432 | 270.854168 | -29.909153 | 1.117724 | -3.835987 | 0.805650 | 0.779075 | 6.289155 | 0.0 | 0 | 0.0 | 1.0 | 0.0 | 0 | 3.875000 | 104822.403707 | |
| 2CXO J181647.4-162129 | 274.197777 | -16.358283 | 14.477496 | 0.010426 | 0.853008 | 0.796700 | 174.514643 | 0.0 | 0 | 0.0 | 0.0 | 0.0 | 0 | 2.052632 | 18463.958710 | |
| 2CXO J173555.2+570356 | 263.980184 | 57.065711 | 85.279886 | 32.634090 | 4.897011 | 3.619107 | 102.905240 | 0.0 | 0 | 0.0 | 0.0 | 0.0 | 0 | 2.111111 | 10337.817452 | |
| 2CXO J134902.7+035728 | 207.261495 | 3.957986 | 336.132887 | 63.054502 | 1.876542 | 1.016800 | 133.998688 | 0.0 | 0 | 0.0 | 0.0 | 0.0 | 0 | 9.082235 | 11005.007980 | |


Now we will get the raw data for these sources using the function **get_raw_data**. 

```python
from choices import get_raw_data

src_data = get_raw_data(source_info)
src_data
```

```
name,ra_x,dec_x,gal_l,gal_b,err_ellipse_r0,err_ellipse_r1,err_ellipse_ang,conf_flag,extent_flag,sat_src_flag,var_flag_x,pileup_flag,streak_src_flag,significance_x,acis_time,hrc_time,ra_y,dec_y,significance_y,gal_l2,gal_b2,likelihood,var_flag_y,var_inter_hard_flag,b-csc,h-csc,m-csc,s-csc,u-csc,hard_hm,hard_hs,hard_ms,var_intra_index_b,var_intra_prob_b,ks_intra_prob_b,kp_intra_prob_b,var_inter_index_b,var_inter_prob_b,var_inter_sigma_b,u-sdss,g-sdss,r-sdss,i-sdss,z-sdss,24_microns_(MIPS),J,H,K,W1,W2,W3,W4,FUV,NUV,G,Bp,Rp,Bp-R,G-J,G-W2,Bp-H,Bp-W3,Rp-K,J-H,J-W1,W1-W2,u-g,g-r,r-z,i-z,u-z
2CXO J002046.2-705659,5.192620,-70.949804,306.523591,-45.963984,0.956380,0.804594,44.584318,0.0,0,0.0,0.0,0.0,0,3.588235,22768.771028,,5.192620,-70.949804,3.588235,306.523591,-45.963984,127.049495,0,0,5.067160e-15,3.284658e-15,1.163564e-15,9.246993e-16,,0.206121,0.154903,-0.059963,2.0,0.898197,0.731318,0.860139,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
2CXO J101020.3-124106,152.584628,-12.685272,253.231259,34.215853,0.718688,0.717721,79.403259,0.0,0,0.0,0.0,0.0,0,9.390751,51195.761895,,152.584628,-12.685272,9.390751,253.231259,34.215853,625.852569,0,1,2.397524e-14,1.604152e-14,4.974342e-15,2.974315e-15,,0.113679,0.123673,0.008745,0.0,0.394664,0.806962,0.768943,3.0,0.513402,1.688237e-06,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
2CXO J165409.8-020219,253.540954,-2.038749,16.632152,24.790265,2.002898,1.185102,104.896836,0.0,0,0.0,0.0,0.0,0,4.722222,20772.938789,,253.540954,-2.038749,4.722222,16.632152,24.790265,72.716409,0,0,2.547270e-14,1.946782e-14,3.889532e-15,3.761376e-15,,0.176140,0.016240,-0.167395,0.0,0.174089,0.516371,0.157145,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
2CXO J100301.2+021934,150.755395,2.326167,237.190828,42.708642,1.161515,1.079687,75.542480,0.0,0,0.0,0.0,0.0,0,2.666667,158905.462446,,150.755395,2.326167,2.666667,237.190828,42.708642,81.028272,0,1,2.186402e-15,1.675384e-15,5.853234e-16,0.000000e+00,0.0,0.149906,0.433479,0.151156,0.0,0.444484,0.612317,0.474779,0.0,0.456854,7.386110e-08,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
2CXO J054244.6-404855,85.686088,-40.815288,246.499422,-29.796437,3.891118,2.875106,178.855720,0.0,0,0.0,0.0,0.0,0,3.875000,50405.925987,,85.686088,-40.815288,3.875000,246.499422,-29.796437,26.563678,0,0,5.791593e-15,3.559558e-15,1.471395e-15,9.266752e-16,,-0.108682,-0.039975,0.074953,0.0,0.358657,0.302306,0.460818,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2CXO J180325.0-295432,270.854168,-29.909153,1.117724,-3.835987,0.805650,0.779075,6.289155,0.0,0,0.0,1.0,0.0,0,3.875000,104822.403707,,270.854168,-29.909153,3.875000,1.117724,-3.835987,44.912474,1,0,1.286669e-15,0.000000e+00,5.750358e-16,6.346875e-16,,-0.463460,-0.730793,-0.371018,6.0,0.958639,0.998059,0.993966,5.0,0.716285,3.706027e-08,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
2CXO J181647.4-162129,274.197777,-16.358283,14.477496,0.010426,0.853008,0.796700,174.514643,0.0,0,0.0,0.0,0.0,0,2.052632,18463.958710,,274.197777,-16.358283,2.052632,14.477496,0.010426,26.362438,0,0,1.862584e-14,1.672257e-14,0.000000e+00,0.000000e+00,0.0,0.999375,0.999375,-0.999375,0.0,0.474449,0.587416,0.283705,0.0,0.448714,1.900174e-07,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
2CXO J173555.2+570356,263.980184,57.065711,85.279886,32.634090,4.897011,3.619107,102.905240,0.0,0,0.0,0.0,0.0,0,2.111111,10337.817452,,263.980184,57.065711,2.111111,85.279886,32.634090,16.895662,0,0,0.000000e+00,,0.000000e+00,4.433949e-15,,-0.999375,-0.999375,-0.999375,0.0,0.416042,0.496184,0.243796,,,,19.743,17.738,16.831,16.426,16.143,,,,,,,,,,,,,,,,,,,,,,,2.005,0.907,0.688,0.283,3.600
2CXO J134902.7+035728,207.261495,3.957986,336.132887,63.054502,1.876542,1.016800,133.998688,0.0,0,0.0,0.0,0.0,0,9.082235,11005.007980,,207.261495,3.957986,9.082235,336.132887,63.054502,357.519121,0,0,1.027781e-13,5.846297e-14,1.823370e-14,2.798401e-14,0.0,0.126171,-0.294816,-0.403498,0.0,0.147051,0.218394,0.527038,,,,21.452,21.149,20.913,20.646,20.894,,,,,,,,,,,,,,,,,,,,,,,0.303,0.236,0.019,-0.248,0.558
```


Note that the above function retrieves the data from the various MW catalogs we have collected, together with the selected features obtained from NED (the NASA/IPAC Extragalactic Database).
<small>Note: NED contains information about each source from various other catalogs, as well as other properties not included in this work.</small>
The NED data for these sources can be obtained directly with the function **get_NED_data**, which creates separate CSV files in the folder _data/NED_data_.

An example is given below for 10 sources from the above table.

```python
from choices import get_NED_data

# No need to provide the entire dataframe; an empty dataframe with the names as the index will work.
get_NED_data(src_data.sample(10)[[]])
```



> All the data used in this work, with the selected flag filters and features, are stored in the following files and will be used in the following sections:
* Training data : *data/training_data/train_data_minmax.csv*
* Unclassified sources : *data/new_src_data/new_sources.csv*

# Model Training and Validation

### Training Data

All the data tables in the folder **training_data** contain 13882 sources and 57 features. These are all the sources matched within a cross-match radius of 10 arcsec.

The information about each source (including its source catalog, its name in the parent catalog, and the cross-match offset between the parent catalog and CSC) is given in the data table *id_frame.csv*.

Using *id_frame.csv*, the training data satisfying the required constraints can be retrieved.

We used the training data with the following constraints:

* `conf_flag` = 0
* `streak_src_flag` = 0
* `extent_flag` = 0
* `pileup_flag` = 0
* `offset` $\leq 1$

These constraints give 7703 sources, and we used 41 of the 57 features.
The training data used in this work is given in *train_data_minmax.csv*.
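
A minimal sketch of applying such constraints with pandas: join the cross-match offset onto the training table by source name and keep rows that pass both the flag filter and the offset cut. The toy tables are hypothetical, and which file actually holds the flag and `offset` columns is an assumption made for illustration.

```python
import pandas as pd

# Toy stand-ins for id_frame.csv and the training table (real ones have many more columns).
id_frame = pd.DataFrame(
    {"offset": [0.4, 2.5, 1.0, 0.2]},
    index=pd.Index(["src_a", "src_b", "src_c", "src_d"], name="name"),
)
train = pd.DataFrame(
    {"class": ["AGN", "STAR", "YSO", "CV"], "conf_flag": [0, 0, 0, 1]},
    index=pd.Index(["src_a", "src_b", "src_c", "src_d"], name="name"),
)

# Attach the cross-match offset to each training source by name,
# then keep the sources passing both the flag filter and the offset cut.
merged = train.join(id_frame, how="inner")
selected = merged[(merged["conf_flag"] == 0) & (merged["offset"] <= 1)]
print(selected.index.tolist())  # ['src_a', 'src_c']
```

Note that the cut is inclusive, so a source with an offset of exactly 1 is kept.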

```python
data_train = pd.read_csv('data/training_data/train_data_minmax.csv', index_col='name')
data_train
```

### Import _make_model_ class
The class _make_model_ takes the training data and a classification model (any scikit-learn-compatible classifier). It can be used to validate the model with the cumulative K-fold cross-validation (CCV) method, and to train and save the classifier for application to the test data.

```python
from utilities_v2 import make_model
```

### Build the Model: _make_model_ class

_make_model_ takes the following components:
*   model_name : user-defined name of the model (any string)
*   train_data : training features as a pandas DataFrame
*   label : class labels for the training data (list or pandas Series)
*   classifier : the classifier model
*   oversampler : an oversampling object such as imbalanced-learn's _SMOTE_, or `None` for no oversampling

#### Data
The class _make_model_ takes the training features as a pandas DataFrame and the training labels as a pandas Series.

```python
# Separate the features from the class label
x = data_train.drop(columns=['class'])
y = data_train['class']
```

#### Classifier

Next, we choose a classifier. The first example uses scikit-learn's _RandomForestClassifier_.

Users can supply their own classifier to the _make_model_ object; the only requirement is that the classifier implements a _fit_ method. (This is rarely a concern, since scikit-learn classifiers all implement _fit_.) To be used in the Application section, the classifier must also implement _predict_ and _predict_proba_.

<small>Note: the model parameters given here were obtained after hyper-parameter tuning.</small>

##### Random Forest classifier

```python
from sklearn.ensemble import RandomForestClassifier

# Random forest classifier (an alternative to the LightGBM model below)
clf = RandomForestClassifier(n_estimators=400, max_depth=30, random_state=np.random.randint(0, 999999))
```

##### LightGBM classifier

```python
import lightgbm as lgb


def calc_weight(gamma, y):
    """Return a class-weight dict with w_c = exp(gamma * N / N_c), where N is
    the number of training sources and N_c the number of sources in class c."""
    l = len(y)
    cl_weight = {}
    cl_dict = y.value_counts().to_dict()
    for cl, val in cl_dict.items():
        cl_weight[cl] = np.exp((l / val) * gamma)
    return cl_weight
```
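
To see how `calc_weight` behaves, the sketch below evaluates the same formula, w_c = exp(γ N / N_c), on a small toy label set (the labels are illustrative, not the real training set). Because N / N_c sits inside the exponential, the weight of a rare class grows much faster than inversely with its size, and γ controls how aggressive this boost is.

```python
import numpy as np
import pandas as pd

gamma = 0.07
y_toy = pd.Series(["AGN"] * 6 + ["STAR"] * 3 + ["ULX"])  # toy labels, not the real training set
counts = y_toy.value_counts()                          # N_c for each class
weights = np.exp(gamma * len(y_toy) / counts)          # w_c = exp(gamma * N / N_c)
print(weights.round(3).to_dict())
```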

```python
gamma = 0.07
cl_weight = calc_weight(gamma, y)
clf = lgb.LGBMClassifier(
    n_estimators=100,
    class_weight=cl_weight,
    objective='multiclass',
    sparse=True,
    is_unbalance=True,
    metric=['auc_mu'],
    verbosity=0,
    random_state=42,
    num_class=len(np.unique(y)),
    force_col_wise=True,
)
```

#### Oversampler

```python
from imblearn.over_sampling import SMOTE

oversampler = SMOTE(k_neighbors=4)
```
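
The idea behind SMOTE is to create synthetic minority-class samples on the line segments joining a minority sample to one of its k nearest minority neighbours. The from-scratch sketch below illustrates that interpolation step only; `smote_sketch` is a hypothetical helper written for this README, not imbalanced-learn's implementation, which offers many more options.

```python
import numpy as np


def smote_sketch(x_min: np.ndarray, n_new: int, k: int = 4, seed: int = 0) -> np.ndarray:
    """Generate n_new synthetic minority samples, SMOTE-style (illustration only)."""
    rng = np.random.default_rng(seed)
    # Pairwise Euclidean distances within the minority class.
    dists = np.linalg.norm(x_min[:, None, :] - x_min[None, :, :], axis=-1)
    np.fill_diagonal(dists, np.inf)                 # a point is not its own neighbour
    neighbours = np.argsort(dists, axis=1)[:, :k]   # indices of the k nearest neighbours
    base = rng.integers(0, len(x_min), size=n_new)  # random minority "anchor" samples
    chosen = neighbours[base, rng.integers(0, k, size=n_new)]
    gap = rng.random((n_new, 1))                    # interpolation factor in [0, 1)
    return x_min[base] + gap * (x_min[chosen] - x_min[base])


rng = np.random.default_rng(1)
x_majority = rng.normal(0.0, 1.0, size=(50, 2))
x_minority = rng.normal(3.0, 0.5, size=(8, 2))
x_synth = smote_sketch(x_minority, n_new=len(x_majority) - len(x_minority), k=4)
print(x_synth.shape)
```

With `k=4` this mirrors the `k_neighbors=4` setting above; adding the 42 synthetic rows brings the toy minority class up to the size of the majority class.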

#### Put everything together 

```python
# rf_model = make_model(model_name='test_model', classifier=clf, oversampler=oversampler, train_data=x, label=y)
lgb_model = make_model(model_name='lgb_model', classifier=clf, oversampler=None, train_data=x, label=y)
```

### Validate the Model

The _make_model_ object implements a *validate* function, which performs cumulative K-fold cross-validation of the supplied model on the given data.

```python
lgb_model.train()
```

```python
lgb_model.validate(save_predictions=True, multiprocessing=False, k_fold=20)
```
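
One common reading of "cumulative" K-fold cross-validation is to pool the held-out predictions from all K folds and compute the scores once on the pooled set, so that every source is predicted exactly once by a model that never saw it. The sketch below shows that scheme with scikit-learn on synthetic data. It is an assumption about what *validate* does internally, not its actual code.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import StratifiedKFold, cross_val_predict

# Synthetic 3-class data standing in for the real training table.
rng = np.random.default_rng(0)
x_demo = np.vstack([rng.normal(loc, 1.0, size=(40, 4)) for loc in (0.0, 2.0, 4.0)])
y_demo = np.repeat(["AGN", "STAR", "YSO"], 40)

# Each sample is predicted once, by the model trained on the other K-1 folds.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
y_pred = cross_val_predict(RandomForestClassifier(n_estimators=50, random_state=0), x_demo, y_demo, cv=cv)

# Pool the out-of-fold predictions of all folds and score them together.
labels = ["AGN", "STAR", "YSO"]
cm = confusion_matrix(y_demo, y_pred, labels=labels)
print(labels)
print(cm)
```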

Let us look at the validation results.

The validation results are stored in the attribute _validation_score_ of the _make_model_ object.

```python
from IPython.display import display

# Print the validation results
print("Confusion Matrix: ")
print(lgb_model.validation_score['class_labels'])
print(lgb_model.validation_score['confusion_matrix'])
print("Overall Scores: ")
print(lgb_model.validation_score['overall_scores'])
print("Class-Wise scores: ")
display(lgb_model.validation_score['class_wise_scores'] * 100)
```

### Train the model

The validation function above can be rerun with different classifier parameters until the results meet the user's requirements. Once the results are satisfactory, call the _train_ function of the _make_model_ object, which trains and stores the supplied classifier. Unlike cross-validation, where each model is trained on only a fraction of the data, here the classifier is trained on the entire dataset.

```python
lgb_model.train()
```

### Save the Model

Next, we use the _save_ function of the _make_model_ object to save the classifier along with the validation scores and the predictions on the training data.

```python
lgb_model.save('models/lightGBM-example.joblib')
```
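
The `.joblib` extension and the `joblib.load` call in the Application section suggest that the model is serialized with joblib; that is an assumption about _save_. The sketch below shows a joblib round trip with a small scikit-learn classifier. If _save_ does not create the `models/` folder itself, create it first, since `joblib.dump` does not create missing directories.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier

x_demo = np.random.default_rng(0).normal(size=(30, 3))
y_demo = np.array([0, 1, 2] * 10)
clf_demo = RandomForestClassifier(n_estimators=10, random_state=0).fit(x_demo, y_demo)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "models", "example.joblib")
    os.makedirs(os.path.dirname(path), exist_ok=True)  # joblib.dump does not create folders
    joblib.dump(clf_demo, path)
    restored = joblib.load(path)

# The restored model gives identical predictions.
same = np.array_equal(clf_demo.predict(x_demo), restored.predict(x_demo))
print(same)
```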

# Application

### Load Data: Unidentified sources

```python
all_new = pd.read_csv('data/new_src_data/new_sources.csv', index_col='name')
```

### Load Saved Model

```python
import joblib

# make_model must be importable so that joblib can unpickle the saved object
from utilities_v2 import make_model

lgb_model = joblib.load('models/lightGBM-example.joblib')
lgb_model
```

### Predict

```python
clf = lgb_model.clf
```

```python
def get_pred_table(u):
    """Return, for each source in `u`, the predicted class, its probability,
    the margin over the runner-up class, and the class-membership probabilities."""
    pred_prob = clf.predict_proba(u)
    # One probability column per class, e.g. prob_AGN, prob_STAR, ...
    pred_prob_df = pd.DataFrame(pred_prob, columns=[f'prob_{el}' for el in clf.classes_], index=u.index.to_list())
    u_df = pd.DataFrame({
        'name': u.index.to_list(),
        'class': clf.predict(u),
        'prob': np.amax(pred_prob, axis=1),
        # difference between the highest and the second-highest class probability
        'prob_margin': [el[-1] - el[-2] for el in np.sort(pred_prob, axis=1)],
    }).set_index('name')
    u_df = pd.merge(u_df, pred_prob_df, left_index=True, right_index=True)
    u_df.index.name = 'name'
    return u_df
```
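
The `prob` and `prob_margin` columns of `get_pred_table` can be checked by hand on a toy probability matrix (the values below are illustrative, not model output). A small margin means the two most likely classes are nearly tied, so the assigned class is less certain.

```python
import numpy as np
import pandas as pd

classes = np.array(["AGN", "STAR", "YSO"])   # toy class labels
pred_prob = np.array([[0.70, 0.20, 0.10],     # toy predict_proba output
                      [0.40, 0.35, 0.25]])

top_two = np.sort(pred_prob, axis=1)[:, -2:]  # columns: second-highest, highest
summary = pd.DataFrame({
    "class": classes[pred_prob.argmax(axis=1)],
    "prob": top_two[:, 1],
    "prob_margin": top_two[:, 1] - top_two[:, 0],
}, index=["src_a", "src_b"])
print(summary)
```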

```python
# u_df_var = get_pred_table(variable_src)
u_df = get_pred_table(all_new)
u_df
```