# Stellar Classification

**Author:** Esteban Duran

**Description:** Classifying cosmic objects is a fundamental problem in
astronomy. In this project I use machine learning to build a model that
reliably classifies each object as a galaxy, quasar, or star, using the
[Stellar Classification Dataset - SDSS17](https://www.kaggle.com/datasets/fedesoriano/stellar-classification-dataset-sdss17)
collected by the [Sloan Digital Sky Survey (SDSS)](https://www.sdss.org).
The dataset contains 100,000 observations, each consisting of 17 feature
columns and 1 class column. The algorithms compared are:

- Logistic Regression (baseline)
- Decision Tree
- Random Forest
- Neural Network

I will compare these algorithms against each other and select the one that
performs best.
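
The comparison described above can be sketched with scikit-learn. This is a minimal illustration, not the notebook's actual experiment: synthetic three-class data from `make_classification` stands in for the SDSS table, macro-averaged F1 is an assumed selection metric, and the neural network is left out to keep the sketch short.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the real data: 3 classes, 8 numeric features.
X, y = make_classification(
    n_samples=600, n_features=8, n_informative=5, n_classes=3, random_state=42
)
X_train, X_valid, y_train, y_valid = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

# Candidate models; hyperparameters are illustrative defaults, not tuned values.
candidates = {
    "logistic_regression": LogisticRegression(max_iter=1000),
    "decision_tree": DecisionTreeClassifier(random_state=42),
    "random_forest": RandomForestClassifier(n_estimators=50, random_state=42),
}

# Fit each model and score it on the held-out split.
scores = {}
for name, model in candidates.items():
    model.fit(X_train, y_train)
    scores[name] = f1_score(y_valid, model.predict(X_valid), average="macro")

best_name = max(scores, key=scores.get)
print(scores, best_name)
```

Scoring every candidate on the same held-out split keeps the comparison fair; the logistic regression score serves as the baseline the other models need to beat.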

The dataset columns are described below:

1. obj_ID = Object Identifier, the unique value that identifies the object in the image catalog used by the CAS
2. alpha = Right Ascension angle (at J2000 epoch)
3. delta = Declination angle (at J2000 epoch)
4. u = Ultraviolet filter in the photometric system
5. g = Green filter in the photometric system
6. r = Red filter in the photometric system
7. i = Near Infrared filter in the photometric system
8. z = Infrared filter in the photometric system
9. run_ID = Run Number used to identify the specific scan
10. rerun_ID = Rerun Number to specify how the image was processed
11. cam_col = Camera column to identify the scanline within the run
12. field_ID = Field number to identify each field
13. spec_obj_ID = Unique ID used for optical spectroscopic objects (this means that 2 different observations with the same spec_obj_ID must share the output class)
14. class = object class (galaxy, star or quasar object)
15. redshift = redshift value based on the increase in wavelength
16. plate = plate ID, identifies each plate in SDSS
17. MJD = Modified Julian Date, used to indicate when a given piece of SDSS data was taken
18. fiber_ID = fiber ID that identifies the fiber that pointed the light at the focal plane in each observation
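
Two of the column definitions above correspond to simple formulas, sketched here with the standard library only. The function names are hypothetical helpers, not part of the dataset or of any SDSS tooling. The sketch assumes the usual conventions: Modified Julian Date counts days from 1858-11-17, and redshift is the fractional increase in wavelength relative to the rest wavelength.

```python
from datetime import date, timedelta

# Day zero of the Modified Julian Date scale.
MJD_EPOCH = date(1858, 11, 17)


def mjd_to_date(mjd: int) -> date:
    """Convert a whole-day Modified Julian Date to a calendar date."""
    return MJD_EPOCH + timedelta(days=int(mjd))


def redshift(observed_nm: float, rest_nm: float) -> float:
    """Fractional increase in wavelength: (observed - rest) / rest."""
    return (observed_nm - rest_nm) / rest_nm


# MJD value taken from the first row of the dataset preview.
print(mjd_to_date(56354).isoformat())

# A line emitted at 500 nm and observed at 750 nm has a redshift of 0.5.
print(redshift(750.0, 500.0))
```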

**Dataset:** [Stellar Classification Dataset - SDSS17](https://www.kaggle.com/datasets/fedesoriano/stellar-classification-dataset-sdss17)

## Setup

In [1]:
# Import all the modules we will need
import fastbook
from fastbook import *
from fastai.vision.all import *
from fastcore.all import *
from fastai.tabular.all import *

from imblearn.over_sampling import SMOTE

from utils.draw import draw_tree
from utils import parks_ranger

import seaborn as sns

from numpy import random
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier

import torch, numpy as np, pandas as pd

# Set print options to conform to the notebook width
np.set_printoptions(linewidth=140)
torch.set_printoptions(linewidth=140, sci_mode=False, edgeitems=7)
pd.set_option("display.width", 140)

# Set the seaborn theme
sns.set_theme(style="whitegrid")

# Workaround: patch DisplayHandle.update so that some of the training output
# renders correctly. A future release of Jupyter for VS Code should fix this.
# https://github.com/microsoft/vscode-jupyter/pull/13442#issuecomment-1541584881
from IPython.display import clear_output, DisplayHandle


def update_patch(self, obj):
    clear_output(wait=True)
    self.display(obj)


DisplayHandle.update = update_patch

random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7f42ce601d30>

## Load Data

In [2]:
dataset_name = "fedesoriano/stellar-classification-dataset-sdss17"
dataset_path = URLs.path(dataset_name)

dataset_path

Path('/root/.fastai/archive/stellar-classification-dataset-sdss17')

In [3]:
Path.BASE_PATH = dataset_path

In [4]:
# Download the dataset to a hidden folder and extract it from kaggle
if not dataset_path.exists() or not any(Path(dataset_path).iterdir()):
    import kaggle

    dataset_path.mkdir(parents=True, exist_ok=True)
    kaggle.api.dataset_download_cli(dataset_name, path=dataset_path, unzip=True)

dataset_path.ls()

(#1) [Path('star_classification.csv')]
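
The guard in the cell above triggers a download only when the target folder is missing or empty. The sketch below isolates that condition in a hypothetical helper, `needs_download`, and exercises it against temporary folders so that it runs without Kaggle credentials.

```python
import tempfile
from pathlib import Path


def needs_download(folder: Path) -> bool:
    """Return True when the folder is missing or contains no entries."""
    return not folder.exists() or not any(folder.iterdir())


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)

    missing = root / "not_created"  # never created on disk

    empty = root / "empty"  # exists but has no files
    empty.mkdir()

    filled = root / "filled"  # exists and holds one placeholder file
    filled.mkdir()
    (filled / "star_classification.csv").write_text("placeholder")

    results = (needs_download(missing), needs_download(empty), needs_download(filled))

print(results)
```

Checking for an empty folder as well as a missing one means an interrupted earlier run, which may leave the directory behind without the CSV, still triggers a fresh download.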

In [5]:
stellar_df = pd.read_csv(dataset_path / "star_classification.csv", low_memory=False)
stellar_df.head()

Unnamed: 0,obj_ID,alpha,delta,u,g,r,i,z,run_ID,rerun_ID,cam_col,field_ID,spec_obj_ID,class,redshift,plate,MJD,fiber_ID
0,1.237661e+18,135.689107,32.494632,23.87882,22.2753,20.39501,19.16573,18.79371,3606,301,2,79,6.543777e+18,GALAXY,0.634794,5812,56354,171
1,1.237665e+18,144.826101,31.274185,24.77759,22.83188,22.58444,21.16812,21.61427,4518,301,5,119,1.176014e+19,GALAXY,0.779136,10445,58158,427
2,1.237661e+18,142.18879,35.582444,25.26307,22.66389,20.60976,19.34857,18.94827,3606,301,2,120,5.1522e+18,GALAXY,0.644195,4576,55592,299
3,1.237663e+18,338.741038,-0.402828,22.13682,23.77656,21.61162,20.50454,19.2501,4192,301,3,214,1.030107e+19,GALAXY,0.932346,9149,58039,775
4,1.23768e+18,345.282593,21.183866,19.43718,17.58028,16.49747,15.97711,15.54461,8102,301,3,137,6.891865e+18,GALAXY,0.116123,6121,56187,842


In [11]:
# cat_names = ['class']
cont_names = [
    "obj_ID", "alpha", "delta", "u", "g", "r", "i", "z", "run_ID", "rerun_ID",
    "cam_col", "field_ID", "spec_obj_ID", "redshift", "plate", "MJD", "fiber_ID",
]
procs = [Categorify, FillMissing, Normalize]
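
The cell above passes every non-target column as a continuous feature, including catalog identifiers such as `obj_ID` and `spec_obj_ID`. As a hedged alternative, the sketch below separates identifier-like columns from physical measurements. Which columns count as identifiers is my reading of the feature descriptions, not something the notebook states, and the names `id_like` and `measurement_names` are hypothetical.

```python
# Every non-target column, in the order used by the notebook.
cont_names = [
    "obj_ID", "alpha", "delta", "u", "g", "r", "i", "z", "run_ID", "rerun_ID",
    "cam_col", "field_ID", "spec_obj_ID", "redshift", "plate", "MJD", "fiber_ID",
]

# Assumption: these columns describe bookkeeping (catalog IDs, scan, plate,
# fiber, observation date) rather than properties of the observed object.
id_like = {
    "obj_ID", "run_ID", "rerun_ID", "cam_col", "field_ID",
    "spec_obj_ID", "plate", "MJD", "fiber_ID",
}

# Keep sky position, the five photometric bands, and redshift.
measurement_names = [name for name in cont_names if name not in id_like]
print(measurement_names)
```

Whether dropping these columns helps is an empirical question; training once with `cont_names` and once with `measurement_names` would show how much the model relies on bookkeeping fields.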

In [12]:
dls = TabularDataLoaders.from_df(stellar_df, dataset_path, procs=procs, cont_names=cont_names, 
                                 y_names="class", bs=64)

In [13]:
dls.show_batch()

Unnamed: 0,obj_ID,alpha,delta,u,g,r,i,z,run_ID,rerun_ID,cam_col,field_ID,spec_obj_ID,redshift,plate,MJD,fiber_ID,class
0,1.23767e+18,18.754521,7.514645,21.728161,21.56535,20.276541,19.4487,19.07457,5640.999943,301.0,4.0,121.000002,5.12739e+18,0.5930179,4554.000022,56192.999992,151.000005,GALAXY
1,1.237653e+18,6.104825,-9.766508,18.66202,17.547741,17.047541,16.834419,16.73419,1739.999955,301.0,4.0,61.999997,2.152859e+18,-0.0002373091,1912.000101,53293.000041,501.999998,STAR
2,1.237662e+18,229.628801,39.058807,17.260349,15.60575,15.04369,14.86419,14.81529,3926.000004,301.0,6.0,21.000002,3.277565e+18,2.617916e-09,2910.999959,54631.000054,255.999997,GALAXY
3,1.237668e+18,190.13089,21.582037,25.93203,22.281719,20.73608,19.830811,19.523491,5193.999984,301.0,2.0,495.000003,6.738563e+18,0.4942179,5985.000014,56089.000013,189.0,GALAXY
4,1.237679e+18,328.287571,10.622795,25.3353,22.64823,20.898939,19.843399,19.4874,7776.999921,301.0,5.0,115.000001,4.610723e+18,0.581978,4094.999969,55497.000003,593.000007,GALAXY
5,1.237662e+18,230.385649,33.21944,24.371849,21.653219,20.571131,20.212099,19.618521,3926.999976,301.0,4.0,19.000005,1.560503e+18,0.0001722927,1385.9999,53115.999917,20.999993,STAR
6,1.237658e+18,172.501084,51.583419,22.357691,21.48061,20.12575,19.337761,18.721201,2830.000034,301.0,3.0,368.000006,7.53364e+18,0.4905013,6691.000048,56412.999992,887.000012,GALAXY
7,1.237679e+18,349.374606,1.260161,19.34223,18.138651,17.632681,17.24931,17.05443,7716.999953,301.0,1.0,257.999998,4.302156e+17,0.0743958,381.999898,51815.999807,443.0,GALAXY
8,1.237666e+18,18.140198,14.365049,24.60453,21.991779,20.2034,19.01996,18.436279,4828.999994,301.0,3.0,35.999998,5.784995e+18,0.5777519,5138.0,55830.000007,441.0,GALAXY
9,1.237668e+18,190.654296,18.722687,19.892811,18.21567,17.3619,16.982691,16.681431,5313.999978,301.0,1.0,91.0,2.943181e+18,0.05706288,2614.000073,54481.0,288.000003,GALAXY
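
Several integer-valued columns in the batch above come back slightly off, for example 5640.999943 for `run_ID`. A likely explanation is that `Normalize` standardizes each continuous column and the displayed batch is decoded back from 32-bit floats. The sketch below reproduces that round trip with NumPy. The formula `(x - mean) / std` and the float32 cast are assumptions about what happens internally, and the five inputs are the `run_ID` values from the first rows shown, rounded to integers.

```python
import numpy as np

# run_ID values from the first five displayed rows, rounded to integers.
run_id = np.array([5641, 1740, 3926, 5194, 7777], dtype=np.float64)

# Standardize the column, then store it at float32 precision.
mean, std = run_id.mean(), run_id.std()
normalized = ((run_id - mean) / std).astype(np.float32)

# Decode back to the original scale, as a display routine would.
decoded = normalized.astype(np.float64) * std + mean

# Largest absolute difference introduced by the round trip.
max_error = float(np.abs(decoded - run_id).max())
print(decoded, max_error)
```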


In [18]:
learn = tabular_learner(dls, metrics=accuracy)
learn.fit(4)

epoch,train_loss,valid_loss,accuracy,time
0,0.161135,0.185651,0.9342,00:14
1,0.130977,0.120452,0.96365,00:11
2,0.11898,0.120772,0.962,00:11
3,0.124708,0.118997,0.96165,00:10
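
The accuracy column can be read as a count of correct predictions. Assuming the validation set holds 20,000 rows (20% of 100,000, which I believe is the default random split when no validation indices are passed), each reported accuracy maps to a whole number of correct rows. Accuracy can also hide weak performance on a rare class, so the sketch adds a plain-Python macro F1 for contrast. The label lists are toy data, not model output.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_f1(y_true, y_pred):
    """Unweighted mean of the per-class F1 scores."""
    labels = sorted(set(y_true) | set(y_pred))
    per_class = []
    for label in labels:
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
        fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        per_class.append(2 * tp / denom if denom else 0.0)
    return sum(per_class) / len(per_class)


# Accuracies from the training table; n_valid is an assumption (see above).
reported = [0.9342, 0.96365, 0.962, 0.96165]
n_valid = 20_000
correct_counts = [round(acc * n_valid) for acc in reported]
print(correct_counts)

# Toy labels: the single rare-class example is misclassified.
y_true = ["GALAXY"] * 6 + ["STAR"] * 3 + ["QSO"]
y_pred = ["GALAXY"] * 6 + ["STAR"] * 3 + ["GALAXY"]
print(accuracy(y_true, y_pred), macro_f1(y_true, y_pred))
```

In the toy example accuracy is 0.9 while macro F1 is far lower, because the missed rare class contributes an F1 of zero to the average.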


In [16]:
??dls

Type:        TabularDataLoaders
String form: <fastai.tabular.data.TabularDataLoaders object at 0x7f42bbc17910>
Length:      2
File:        ~/.conda/envs/stellar/lib/python3.10/site-packages/fastai/tabular/data.py
Source:
class TabularDataLoaders(DataLoaders):
    "Basic wrapper around several `DataLoader`s with factory methods for tabular data"
    @classmethod
    @delegates(Tabular.dataloaders, but=["dl_type", "dl_kwargs"])
    def from_df(cls,
        df:pd.DataFrame,