In [7]:
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

This code loads the astronomy dataset from a CSV file, selects a set of features to use as input to a decision tree model, and splits the data into training and testing sets. It then trains a decision tree classifier on the training data and evaluates its performance on the test data by computing accuracy and a confusion matrix.

Finally, the code displays a heatmap of the confusion matrix using `sns.heatmap()`, and then visualizes the fitted decision tree using the `plot_tree()` function from scikit-learn; each figure is rendered with `plt.show()`.

Overall, this code demonstrates how to train a decision tree classifier on an astronomy dataset, evaluate its performance, and visualize the resulting tree and confusion matrix.

In [8]:
# Names of the model input columns and of the label column
target = 'class'
features = ['ra', 'dec', 'u', 'g', 'r', 'i', 'z']

# Read the CSV export and preview its first rows
csv_path = 'Datos/Skyserver_SQL2_27_2018 6_51_39 PM.csv'
data = pd.read_csv(csv_path)
print(data.head())

# Feature matrix (one column per entry in `features`) and label vector
X = data.loc[:, features]
y = data.loc[:, target]


          objid          ra       dec         u         g         r         i  \
0  1.237650e+18  183.531326  0.089693  19.47406  17.04240  15.94699  15.50342   
1  1.237650e+18  183.598370  0.135285  18.66280  17.21449  16.67637  16.48922   
2  1.237650e+18  183.680207  0.126185  19.38298  18.19169  17.47428  17.08732   
3  1.237650e+18  183.870529  0.049911  17.76536  16.60272  16.16116  15.98233   
4  1.237650e+18  183.883288  0.102557  17.55025  16.26342  16.43869  16.55492   

          z  run  rerun  camcol  field     specobjid   class  redshift  plate  \
0  15.22531  752    301       4    267  3.722360e+18    STAR -0.000009   3306   
1  16.39150  752    301       4    267  3.638140e+17    STAR -0.000055    323   
2  16.80125  752    301       4    268  3.232740e+17  GALAXY  0.123111    287   
3  15.90438  752    301       4    269  3.722370e+18    STAR -0.000111   3306   
4  16.61326  752    301       4    269  3.722370e+18    STAR  0.000590   3306   

     mjd  fiberid  
0  549

The data consists of 10,000 observations of space taken by the Sloan Digital Sky Survey (SDSS). Every observation is described by 17 feature columns and 1 class column, which identifies the object as a star (`STAR`), a galaxy (`GALAXY`), or a quasar (`QSO`).

In [9]:
# Hold out 30% of the rows for evaluation; the fixed seed keeps the split repeatable
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Train a decision tree classifier.
# scikit-learn permutes the candidate features at every split, so ties between
# equally good splits are broken at random. Seeding the estimator makes the
# fitted tree, and therefore the metrics below, reproducible across runs.
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)

# Make predictions on the held-out test set
y_pred = clf.predict(X_test)

# Compute accuracy and print the confusion matrix.
# Rows are the true classes and columns the predicted classes, both in the
# order given by `labels`.
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy}')
cm = confusion_matrix(y_test, y_pred, labels=['STAR', 'GALAXY', 'QSO'])
print('Confusion matrix:')
print(cm)

Accuracy: 0.896
Confusion matrix:
[[1095  107   15]
 [ 131 1348   18]
 [  21   20  245]]


In [None]:
# Plot confusion matrix.
# The tick labels must match the `labels` order that was used to build `cm`
# (rows = true class, columns = predicted class).
cm_labels = ['STAR', 'GALAXY', 'QSO']
sns.set()
sns.heatmap(cm, annot=True, fmt='g', cmap='Blues', xticklabels=cm_labels, yticklabels=cm_labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion matrix')
plt.show()

# Visualize decision tree.
# plot_tree expects class names in the same order as clf.classes_ (the sorted
# unique labels seen during fit), which is NOT the confusion-matrix order above.
# Taking the names from the fitted estimator keeps every node correctly labelled.
# A plain list is passed because some scikit-learn versions reject an ndarray here.
tree_class_names = [str(label) for label in clf.classes_]
plt.figure(figsize=(20,10))
plot_tree(clf, feature_names=features, class_names=tree_class_names, filled=True)
plt.show()