In [4]:
from catboost import CatBoostClassifier
import pandas as pd

In [6]:
df = pd.read_csv('star_classification.csv')

In [7]:
df.head()

Unnamed: 0,obj_ID,alpha,delta,u,g,r,i,z,run_ID,rerun_ID,cam_col,field_ID,spec_obj_ID,class,redshift,plate,MJD,fiber_ID
0,1.237661e+18,135.689107,32.494632,23.87882,22.2753,20.39501,19.16573,18.79371,3606,301,2,79,6.543777e+18,GALAXY,0.634794,5812,56354,171
1,1.237665e+18,144.826101,31.274185,24.77759,22.83188,22.58444,21.16812,21.61427,4518,301,5,119,1.176014e+19,GALAXY,0.779136,10445,58158,427
2,1.237661e+18,142.18879,35.582444,25.26307,22.66389,20.60976,19.34857,18.94827,3606,301,2,120,5.1522e+18,GALAXY,0.644195,4576,55592,299
3,1.237663e+18,338.741038,-0.402828,22.13682,23.77656,21.61162,20.50454,19.2501,4192,301,3,214,1.030107e+19,GALAXY,0.932346,9149,58039,775
4,1.23768e+18,345.282593,21.183866,19.43718,17.58028,16.49747,15.97711,15.54461,8102,301,3,137,6.891865e+18,GALAXY,0.116123,6121,56187,842


In [23]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 100000 entries, 0 to 99999
Data columns (total 18 columns):
 #   Column       Non-Null Count   Dtype  
---  ------       --------------   -----  
 0   obj_ID       100000 non-null  float64
 1   alpha        100000 non-null  float64
 2   delta        100000 non-null  float64
 3   u            100000 non-null  float64
 4   g            100000 non-null  float64
 5   r            100000 non-null  float64
 6   i            100000 non-null  float64
 7   z            100000 non-null  float64
 8   run_ID       100000 non-null  int64  
 9   rerun_ID     100000 non-null  int64  
 10  cam_col      100000 non-null  int64  
 11  field_ID     100000 non-null  int64  
 12  spec_obj_ID  100000 non-null  float64
 13  class        100000 non-null  object 
 14  redshift     100000 non-null  float64
 15  plate        100000 non-null  int64  
 16  MJD          100000 non-null  int64  
 17  fiber_ID     100000 non-null  int64  
dtypes: float64(10), int64(7),

## Background

This dataset contains over 100,000 observations separated into three class types: Galaxies, Quasars, and Stars. In addition to the class,
each observation has 17 other defining features, several of which are different identification types that were irrelevant for the purposes
of this analysis. For completeness, I will list them all where now:

* obj_ID = Object Identifier, the unique value that identifies the object in the image catalog used by the CAS.

* alpha = Right Ascension angle (at J2000 epoch). This is the angle (in degrees) between the Vernal Equinox 
and the desired point on the celestial sphere.

* delta = Declination angle (at J2000 epoch). The is the angle (in degrees) between the celestial equator and
the desired point on the celestial sphere.

* u = Ultraviolet filter in the photometric system. 3543 Angstroms (354.3 nm)

* g = Green filter in the photometric system. 4770 Angstroms (477.0 nm)

* r = Red filter in the photometric system. 6231 Angstroms (623.1 nm)

* i = Near Infrared filter in the photometric system. 7625 Angstroms (762.5 nm)

* z = Infrared filter in the photometric system. 9134 Angstroms (913.4 nm)

* run_ID = Run Number used to identify the specific scan.

* rereun_ID = Rerun Number to specify how the image was processed.

* cam_col = Camera column to identify the scanline within the run.

* field_ID = Field number to identify each field.

* spec_obj_ID = Unique ID used for optical spectroscopic objects (this means that 2 different observations with the same spec_obj_ID must share the output class).

* class = Object class (galaxy, star, or quasar object).

* redshift = Redshift value based on the increase in wavelength. The more red shifted light is, the further it has traveled from its
point of origin

* plate = Plate ID, identifies each plate in SDSS.

* MJD = Modified Julian Date, used to indicate when a given piece of SDSS data was taken.

* fiber_ID = Fiber ID that identifies the fiber that pointed the light at the focal plane in each observation.

As explained earlier, the ID attributes were removed from the table before I began looking for trends.


In [26]:
df.describe()

Unnamed: 0,obj_ID,alpha,delta,u,g,r,i,z,run_ID,rerun_ID,cam_col,field_ID,spec_obj_ID,redshift,plate,MJD,fiber_ID
count,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0
mean,1.237665e+18,177.629117,24.135305,21.980468,20.531387,19.645762,19.084854,18.66881,4481.36606,301.0,3.51161,186.13052,5.783882e+18,0.576661,5137.00966,55588.6475,449.31274
std,8438560000000.0,96.502241,19.644665,31.769291,31.750292,1.85476,1.757895,31.728152,1964.764593,0.0,1.586912,149.011073,3.324016e+18,0.730707,2952.303351,1808.484233,272.498404
min,1.237646e+18,0.005528,-18.785328,-9999.0,-9999.0,9.82207,9.469903,-9999.0,109.0,301.0,1.0,11.0,2.995191e+17,-0.009971,266.0,51608.0,1.0
25%,1.237659e+18,127.518222,5.146771,20.352353,18.96523,18.135828,17.732285,17.460677,3187.0,301.0,2.0,82.0,2.844138e+18,0.054517,2526.0,54234.0,221.0
50%,1.237663e+18,180.9007,23.645922,22.179135,21.099835,20.12529,19.405145,19.004595,4188.0,301.0,4.0,146.0,5.614883e+18,0.424173,4987.0,55868.5,433.0
75%,1.237668e+18,233.895005,39.90155,23.68744,22.123767,21.044785,20.396495,19.92112,5326.0,301.0,5.0,241.0,8.332144e+18,0.704154,7400.25,56777.0,645.0
max,1.237681e+18,359.99981,83.000519,32.78139,31.60224,29.57186,32.14147,29.38374,8162.0,301.0,6.0,989.0,1.412694e+19,7.011245,12547.0,58932.0,1000.0


In [46]:
# уничтожаем так как одно и тоже число
X = df.drop(['rerun_ID'], axis=1)

In [47]:
from sklearn.model_selection import train_test_split

y = df['class']
X_train, X_test, y_train, y_test = train_test_split(X, y, 
                                                    train_size=0.75, 
                                                    random_state=42)

In [48]:
# X_train.head()
X_test2 = X_test.drop('class', axis=1)
X_test2.columns

Index(['obj_ID', 'alpha', 'delta', 'u', 'g', 'r', 'i', 'z', 'run_ID',
       'cam_col', 'field_ID', 'spec_obj_ID', 'redshift', 'plate', 'MJD',
       'fiber_ID'],
      dtype='object')

In [None]:
from catboost import FeaturesData



In [50]:
# model = CatBoostClassifier( loss_function='MultiLogloss',
#                             eval_metric='HammingLoss',
#                            random_seed=21,
#                             iterations=500,
#                             class_names=['GALAXY', 'QSO', 'STAR'])
# # Fit model
train_label = ['obj_ID', 'alpha', 'delta', 'u', 'g', 'r', 'i', 'z', 'run_ID',  
               'cam_col', 'field_ID', 'spec_obj_ID', 'redshift', 'plate', 'MJD', 'fiber_ID']
cat_features = [ 'run_ID',  'cam_col', 'field_ID', 'plate', 'MJD', 'fiber_ID']

# Get predicted probabilities for each class
#preds_proba = model.predict_proba(eval_data)
# Get predicted RawFormulaVal
#preds_raw = model.predict(eval_data, prediction_type='RawFormulaVal')

In [54]:
model = CatBoostClassifier(cat_features=['class'] ) # классификатор
model.fit(X_train, y_train) # обучение классификатора

Learning rate set to 0.098617
0:	learn: 0.9122489	total: 147ms	remaining: 2m 26s
1:	learn: 0.7732497	total: 242ms	remaining: 2m
2:	learn: 0.6639896	total: 308ms	remaining: 1m 42s
3:	learn: 0.5759337	total: 383ms	remaining: 1m 35s
4:	learn: 0.5031165	total: 466ms	remaining: 1m 32s
5:	learn: 0.4419333	total: 554ms	remaining: 1m 31s
6:	learn: 0.3899735	total: 648ms	remaining: 1m 31s
7:	learn: 0.3454694	total: 697ms	remaining: 1m 26s
8:	learn: 0.3067920	total: 789ms	remaining: 1m 26s
9:	learn: 0.2732069	total: 837ms	remaining: 1m 22s
10:	learn: 0.2437343	total: 861ms	remaining: 1m 17s
11:	learn: 0.2179067	total: 907ms	remaining: 1m 14s
12:	learn: 0.1951059	total: 968ms	remaining: 1m 13s
13:	learn: 0.1749122	total: 1s	remaining: 1m 10s
14:	learn: 0.1570169	total: 1.04s	remaining: 1m 8s
15:	learn: 0.1410932	total: 1.08s	remaining: 1m 6s
16:	learn: 0.1269432	total: 1.13s	remaining: 1m 5s
17:	learn: 0.1142942	total: 1.18s	remaining: 1m 4s
18:	learn: 0.1029872	total: 1.23s	remaining: 1m 3s
19:	

162:	learn: 0.0000415	total: 9.12s	remaining: 46.8s
163:	learn: 0.0000405	total: 9.16s	remaining: 46.7s
164:	learn: 0.0000397	total: 9.2s	remaining: 46.6s
165:	learn: 0.0000388	total: 9.24s	remaining: 46.4s
166:	learn: 0.0000380	total: 9.28s	remaining: 46.3s
167:	learn: 0.0000380	total: 9.32s	remaining: 46.2s
168:	learn: 0.0000374	total: 9.37s	remaining: 46.1s
169:	learn: 0.0000369	total: 9.41s	remaining: 46s
170:	learn: 0.0000361	total: 9.45s	remaining: 45.8s
171:	learn: 0.0000361	total: 9.48s	remaining: 45.7s
172:	learn: 0.0000361	total: 9.54s	remaining: 45.6s
173:	learn: 0.0000354	total: 9.58s	remaining: 45.5s
174:	learn: 0.0000347	total: 9.62s	remaining: 45.4s
175:	learn: 0.0000346	total: 9.67s	remaining: 45.3s
176:	learn: 0.0000346	total: 9.71s	remaining: 45.1s
177:	learn: 0.0000339	total: 9.75s	remaining: 45s
178:	learn: 0.0000334	total: 9.79s	remaining: 44.9s
179:	learn: 0.0000331	total: 9.84s	remaining: 44.8s
180:	learn: 0.0000326	total: 9.88s	remaining: 44.7s
181:	learn: 0.000

321:	learn: 0.0000134	total: 16.3s	remaining: 34.2s
322:	learn: 0.0000133	total: 16.3s	remaining: 34.2s
323:	learn: 0.0000133	total: 16.3s	remaining: 34.1s
324:	learn: 0.0000132	total: 16.4s	remaining: 34s
325:	learn: 0.0000132	total: 16.4s	remaining: 33.9s
326:	learn: 0.0000132	total: 16.5s	remaining: 33.9s
327:	learn: 0.0000131	total: 16.5s	remaining: 33.8s
328:	learn: 0.0000130	total: 16.5s	remaining: 33.7s
329:	learn: 0.0000129	total: 16.6s	remaining: 33.7s
330:	learn: 0.0000129	total: 16.6s	remaining: 33.6s
331:	learn: 0.0000128	total: 16.7s	remaining: 33.6s
332:	learn: 0.0000127	total: 16.7s	remaining: 33.5s
333:	learn: 0.0000127	total: 16.8s	remaining: 33.4s
334:	learn: 0.0000126	total: 16.8s	remaining: 33.4s
335:	learn: 0.0000125	total: 16.9s	remaining: 33.3s
336:	learn: 0.0000125	total: 16.9s	remaining: 33.3s
337:	learn: 0.0000124	total: 16.9s	remaining: 33.2s
338:	learn: 0.0000123	total: 17s	remaining: 33.1s
339:	learn: 0.0000123	total: 17s	remaining: 33.1s
340:	learn: 0.0000

483:	learn: 0.0000083	total: 23.2s	remaining: 24.7s
484:	learn: 0.0000083	total: 23.2s	remaining: 24.6s
485:	learn: 0.0000083	total: 23.3s	remaining: 24.6s
486:	learn: 0.0000083	total: 23.3s	remaining: 24.5s
487:	learn: 0.0000083	total: 23.3s	remaining: 24.5s
488:	learn: 0.0000083	total: 23.4s	remaining: 24.4s
489:	learn: 0.0000082	total: 23.4s	remaining: 24.4s
490:	learn: 0.0000082	total: 23.5s	remaining: 24.3s
491:	learn: 0.0000082	total: 23.5s	remaining: 24.3s
492:	learn: 0.0000082	total: 23.5s	remaining: 24.2s
493:	learn: 0.0000082	total: 23.6s	remaining: 24.2s
494:	learn: 0.0000082	total: 23.6s	remaining: 24.1s
495:	learn: 0.0000081	total: 23.7s	remaining: 24s
496:	learn: 0.0000081	total: 23.7s	remaining: 24s
497:	learn: 0.0000081	total: 23.8s	remaining: 23.9s
498:	learn: 0.0000081	total: 23.8s	remaining: 23.9s
499:	learn: 0.0000080	total: 23.8s	remaining: 23.8s
500:	learn: 0.0000080	total: 23.9s	remaining: 23.8s
501:	learn: 0.0000080	total: 23.9s	remaining: 23.7s
502:	learn: 0.00

642:	learn: 0.0000063	total: 31.3s	remaining: 17.4s
643:	learn: 0.0000062	total: 31.4s	remaining: 17.4s
644:	learn: 0.0000062	total: 31.4s	remaining: 17.3s
645:	learn: 0.0000062	total: 31.5s	remaining: 17.2s
646:	learn: 0.0000062	total: 31.5s	remaining: 17.2s
647:	learn: 0.0000062	total: 31.6s	remaining: 17.1s
648:	learn: 0.0000062	total: 31.6s	remaining: 17.1s
649:	learn: 0.0000062	total: 31.6s	remaining: 17s
650:	learn: 0.0000062	total: 31.7s	remaining: 17s
651:	learn: 0.0000062	total: 31.7s	remaining: 16.9s
652:	learn: 0.0000062	total: 31.8s	remaining: 16.9s
653:	learn: 0.0000061	total: 31.8s	remaining: 16.8s
654:	learn: 0.0000061	total: 31.9s	remaining: 16.8s
655:	learn: 0.0000061	total: 31.9s	remaining: 16.7s
656:	learn: 0.0000061	total: 32s	remaining: 16.7s
657:	learn: 0.0000061	total: 32s	remaining: 16.6s
658:	learn: 0.0000061	total: 32.1s	remaining: 16.6s
659:	learn: 0.0000061	total: 32.2s	remaining: 16.6s
660:	learn: 0.0000060	total: 32.2s	remaining: 16.5s
661:	learn: 0.000006

804:	learn: 0.0000048	total: 38.8s	remaining: 9.39s
805:	learn: 0.0000048	total: 38.8s	remaining: 9.34s
806:	learn: 0.0000048	total: 38.8s	remaining: 9.29s
807:	learn: 0.0000048	total: 38.9s	remaining: 9.23s
808:	learn: 0.0000048	total: 38.9s	remaining: 9.19s
809:	learn: 0.0000048	total: 38.9s	remaining: 9.13s
810:	learn: 0.0000048	total: 39s	remaining: 9.09s
811:	learn: 0.0000048	total: 39s	remaining: 9.04s
812:	learn: 0.0000048	total: 39.1s	remaining: 8.99s
813:	learn: 0.0000048	total: 39.1s	remaining: 8.94s
814:	learn: 0.0000048	total: 39.2s	remaining: 8.89s
815:	learn: 0.0000048	total: 39.2s	remaining: 8.84s
816:	learn: 0.0000048	total: 39.3s	remaining: 8.79s
817:	learn: 0.0000048	total: 39.3s	remaining: 8.75s
818:	learn: 0.0000048	total: 39.4s	remaining: 8.7s
819:	learn: 0.0000047	total: 39.4s	remaining: 8.65s
820:	learn: 0.0000047	total: 39.4s	remaining: 8.6s
821:	learn: 0.0000047	total: 39.5s	remaining: 8.55s
822:	learn: 0.0000047	total: 39.5s	remaining: 8.5s
823:	learn: 0.00000

965:	learn: 0.0000040	total: 45.9s	remaining: 1.62s
966:	learn: 0.0000040	total: 46s	remaining: 1.57s
967:	learn: 0.0000040	total: 46s	remaining: 1.52s
968:	learn: 0.0000040	total: 46.1s	remaining: 1.47s
969:	learn: 0.0000040	total: 46.1s	remaining: 1.43s
970:	learn: 0.0000040	total: 46.2s	remaining: 1.38s
971:	learn: 0.0000039	total: 46.2s	remaining: 1.33s
972:	learn: 0.0000039	total: 46.2s	remaining: 1.28s
973:	learn: 0.0000039	total: 46.3s	remaining: 1.24s
974:	learn: 0.0000039	total: 46.3s	remaining: 1.19s
975:	learn: 0.0000039	total: 46.4s	remaining: 1.14s
976:	learn: 0.0000039	total: 46.4s	remaining: 1.09s
977:	learn: 0.0000039	total: 46.5s	remaining: 1.04s
978:	learn: 0.0000039	total: 46.5s	remaining: 998ms
979:	learn: 0.0000039	total: 46.6s	remaining: 950ms
980:	learn: 0.0000039	total: 46.6s	remaining: 903ms
981:	learn: 0.0000039	total: 46.7s	remaining: 855ms
982:	learn: 0.0000039	total: 46.7s	remaining: 808ms
983:	learn: 0.0000039	total: 46.8s	remaining: 760ms
984:	learn: 0.00

<catboost.core.CatBoostClassifier at 0x7f8c5df304c0>

In [55]:
# Get predicted classes
predict = model.predict(X_test)
predict.head()


array([['GALAXY'],
       ['STAR'],
       ['STAR'],
       ...,
       ['STAR'],
       ['GALAXY'],
       ['GALAXY']], dtype=object)

In [62]:
p = pd.DataFrame(predict)
p.head()

Unnamed: 0,0
0,GALAXY
1,STAR
2,STAR
3,STAR
4,STAR


In [59]:
X_test.head()


Unnamed: 0,obj_ID,alpha,delta,u,g,r,i,z,run_ID,cam_col,field_ID,spec_obj_ID,class,redshift,plate,MJD,fiber_ID
75721,1.237679e+18,16.95689,3.64613,23.33542,21.95143,20.48149,19.603,19.13094,7712,6,442,4.855017e+18,GALAXY,0.506237,4312,55511,495
80184,1.237662e+18,240.06324,6.134131,17.86033,16.79228,16.43001,16.30923,16.25873,3894,1,243,2.448928e+18,STAR,0.000345,2175,54612,348
19864,1.237679e+18,30.887222,1.18871,18.18911,16.89469,16.42161,16.24627,16.18549,7717,1,536,8.255357e+18,STAR,4e-06,7332,56683,943
76699,1.237668e+18,247.594401,10.88778,24.99961,21.71203,21.47148,21.30532,21.29109,5323,1,134,4.577999e+18,STAR,-0.000291,4066,55444,326
92991,1.237679e+18,18.896451,-5.26133,23.76648,21.79737,20.69543,20.23403,19.97464,7881,3,148,8.910472e+18,STAR,-0.000136,7914,57331,363


In [61]:
X.iloc[[75721,80184,19864,76699,92991]]

Unnamed: 0,obj_ID,alpha,delta,u,g,r,i,z,run_ID,cam_col,field_ID,spec_obj_ID,class,redshift,plate,MJD,fiber_ID
75721,1.237679e+18,16.95689,3.64613,23.33542,21.95143,20.48149,19.603,19.13094,7712,6,442,4.855017e+18,GALAXY,0.506237,4312,55511,495
80184,1.237662e+18,240.06324,6.134131,17.86033,16.79228,16.43001,16.30923,16.25873,3894,1,243,2.448928e+18,STAR,0.000345,2175,54612,348
19864,1.237679e+18,30.887222,1.18871,18.18911,16.89469,16.42161,16.24627,16.18549,7717,1,536,8.255357e+18,STAR,4e-06,7332,56683,943
76699,1.237668e+18,247.594401,10.88778,24.99961,21.71203,21.47148,21.30532,21.29109,5323,1,134,4.577999e+18,STAR,-0.000291,4066,55444,326
92991,1.237679e+18,18.896451,-5.26133,23.76648,21.79737,20.69543,20.23403,19.97464,7881,3,148,8.910472e+18,STAR,-0.000136,7914,57331,363


In [71]:
from catboost.utils import eval_metric



In [91]:

#res = pd.DataFrame(y_test).join( pd.DataFrame(predict))
test_pool = Pool( y_test, predict)
metric = eval_metric(y_test, predict, 'Precision')
print(metric)

KeyError: 0

In [67]:
for metric in ('Precision', 'Recall', 'F1'):
    print(metric)
    values = eval_metric(y_test, predict, metric)
    for model, value in zip(clf.classes_, values):
        print(f'class={cls}: {value:.4f}')
    print()

Precision


KeyError: 0

In [92]:
model.save_model('astra.cbm')

In [94]:
X_test.to_csv('astra.csv')

In [38]:
ff = model.get_feature_importance()
ff

array([ 0.27624387,  0.27778713,  0.1622283 ,  0.10247422,  0.1581494 ,
        0.25616265,  0.21896545,  0.29039967,  0.27757768,  0.35941504,
        0.16894315, 91.3056414 ,  5.55423445,  0.15600785,  0.17002827,
        0.26574147])

In [39]:
nn = model.get_feature_names()
nn

AttributeError: 'CatBoostClassifier' object has no attribute 'get_feature_names'