# Celestial Object Classification

Training a decision tree classifier to sort celestial objects into types based on their orbital and physical parameters.

## Dataset Features
- **orbital_period**: Time to complete one orbit (years)
- **axial_tilt**: Angle between the object's rotation axis and the normal to its orbital plane (degrees)
- **mass**: Mass in Earth masses (Earth = 1)
- **type**: Target label, one of Planet, DwarfPlanet, or Asteroid
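The feature list above implies a fixed schema, so it can pay to check a freshly loaded CSV against it before training. The sketch below assumes the file uses exactly these column names and labels; `validate_schema` is a hypothetical helper, and the physical ranges it checks (positive period and mass, tilt between 0 and 180 degrees) are assumptions rather than rules stated by the dataset.

```python
import pandas as pd

# Column names and class labels as listed under "Dataset Features".
FEATURES = ["orbital_period", "axial_tilt", "mass"]
TARGET = "type"
ALLOWED_TYPES = {"Planet", "DwarfPlanet", "Asteroid"}


def validate_schema(df: pd.DataFrame) -> list[str]:
    """Return a list of problems found in df; an empty list means it looks usable."""
    missing = [c for c in FEATURES + [TARGET] if c not in df.columns]
    if missing:
        return [f"missing columns: {missing}"]

    non_numeric = [c for c in FEATURES if not pd.api.types.is_numeric_dtype(df[c])]
    if non_numeric:
        # The range checks below assume numeric data, so stop here.
        return [f"non-numeric feature columns: {non_numeric}"]

    problems = []
    unknown = set(df[TARGET].dropna().unique()) - ALLOWED_TYPES
    if unknown:
        problems.append(f"unexpected labels in {TARGET!r}: {sorted(map(str, unknown))}")
    # Assumed physical ranges; NaN values are not flagged by these comparisons.
    if (df["orbital_period"] <= 0).any():
        problems.append("non-positive orbital_period values")
    if (df["mass"] <= 0).any():
        problems.append("non-positive mass values")
    if ((df["axial_tilt"] < 0) | (df["axial_tilt"] > 180)).any():
        problems.append("axial_tilt outside [0, 180] degrees")
    return problems
```

In the notebook, one option is to call `validate_schema(data)` right after `pd.read_csv` and stop with an error if the returned list is non-empty.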

In [None]:
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
import pandas as pd
import joblib

# Load the dataset; the path is relative to the notebook's directory and the file must exist
data = pd.read_csv("../data/celestial_objects.csv")
X = data[["orbital_period", "axial_tilt", "mass"]]  # input features
y = data["type"]  # target labels

print(f"Dataset shape: {data.shape}")
print(f"\nObject types:\n{y.value_counts()}")
print(f"\nFeature statistics:\n{X.describe()}")
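The counts printed above are the first place to spot missing values and class imbalance, both of which affect how the tree trains. A minimal sketch, assuming that simply dropping incomplete rows is acceptable for this dataset (recent scikit-learn releases can handle NaN in decision trees, but older ones raise an error); `summarize_quality` is a hypothetical helper.

```python
import pandas as pd

FEATURES = ["orbital_period", "axial_tilt", "mass"]
TARGET = "type"


def summarize_quality(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.Series]:
    """Drop rows with a missing feature or label and report class proportions."""
    cols = FEATURES + [TARGET]
    n_missing = int(df[cols].isna().any(axis=1).sum())
    clean = df.dropna(subset=cols)
    print(f"Dropped {n_missing} of {len(df)} rows with missing values")
    # Proportions rather than raw counts make class imbalance easy to spot.
    proportions = clean[TARGET].value_counts(normalize=True)
    print(f"Class proportions:\n{proportions.round(3)}")
    return clean, proportions
```

After `data, class_share = summarize_quality(data)`, `X` and `y` would need to be rebuilt from the cleaned frame.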

## Training the Model

In [None]:
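# Hold out 20% of the rows for testing; stratify so both splits keep the class proportions
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
# Cap the depth at 5 to limit overfitting
model = DecisionTreeClassifier(random_state=42, max_depth=5)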
model.fit(X_train, y_train)

print(f"Training set size: {len(X_train)}")
print(f"Test set size: {len(X_test)}")
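`max_depth=5` is a fixed choice. One way to sanity-check it is to compare a few depths with stratified cross-validation on the training split only, so the test set stays untouched. The sketch below uses `make_classification` as a synthetic stand-in (3 features, 3 classes); in the notebook you would pass `X_train` and `y_train` instead. `score_depths` and the candidate depths are illustrative assumptions.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for X_train / y_train (3 numeric features, 3 classes).
X_demo, y_demo = make_classification(
    n_samples=300, n_features=3, n_informative=3, n_redundant=0,
    n_classes=3, random_state=0,
)


def score_depths(X, y, depths, n_splits=5, seed=42):
    """Mean stratified cross-validated accuracy for each candidate max_depth."""
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    results = {}
    for depth in depths:
        clf = DecisionTreeClassifier(max_depth=depth, random_state=seed)
        results[depth] = cross_val_score(clf, X, y, cv=cv, scoring="accuracy").mean()
    return results


scores = score_depths(X_demo, y_demo, depths=[2, 3, 5, 8, None])  # None = unlimited depth
for depth, acc in scores.items():
    print(f"max_depth={depth}: mean CV accuracy {acc:.3f}")
best_depth = max(scores, key=scores.get)
print(f"Best max_depth on this data: {best_depth}")
```

`GridSearchCV` would do the same search with less code; the explicit loop just makes each step visible.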

## Model Evaluation

In [None]:
# Compare training and test accuracy; a large gap suggests overfitting
train_accuracy = model.score(X_train, y_train)
test_accuracy = model.score(X_test, y_test)

print(f"Training Accuracy: {train_accuracy:.4f}")
print(f"Test Accuracy: {test_accuracy:.4f}")
print("\nFeature importance:")
for feature, importance in zip(X.columns, model.feature_importances_):
    print(f"  {feature}: {importance:.4f}")
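Overall accuracy can hide poor performance on a rare class (for example, a handful of planets among many asteroids). A sketch of per-class metrics, a confusion matrix, and a text dump of the learned rules with scikit-learn's `classification_report`, `confusion_matrix`, and `export_text`. The data is a synthetic, deliberately imbalanced stand-in; the class names are attached to random integer labels only so the output reads like the notebook's.

```python
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

FEATURES = ["orbital_period", "axial_tilt", "mass"]
CLASSES = ["Planet", "DwarfPlanet", "Asteroid"]

# Synthetic, imbalanced stand-in for the real data (not real measurements).
X_arr, y_int = make_classification(
    n_samples=400, n_features=3, n_informative=3, n_redundant=0,
    n_classes=3, weights=[0.15, 0.25, 0.6], random_state=1,
)
X_demo = pd.DataFrame(X_arr, columns=FEATURES)
y_demo = np.array(CLASSES)[y_int]

X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.2, random_state=42, stratify=y_demo
)
demo_model = DecisionTreeClassifier(max_depth=5, random_state=42).fit(X_tr, y_tr)
y_pred = demo_model.predict(X_te)

# Precision, recall and F1 per class expose weak minority classes that accuracy hides.
print(classification_report(y_te, y_pred, labels=CLASSES, zero_division=0))

# Rows are true classes, columns are predicted classes.
cm = confusion_matrix(y_te, y_pred, labels=CLASSES)
print(pd.DataFrame(cm, index=CLASSES, columns=CLASSES))

# Text dump of the learned split rules, truncated to the top three levels.
rules = export_text(demo_model, feature_names=FEATURES, max_depth=3)
print(rules)
```

With the notebook's own objects, the same calls would take `model`, `X_test`, `y_test`, and `list(X.columns)`.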

## Save the Model

In [None]:
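import os

# Save the trained model, creating the models/ directory if it does not exist yet
model_path = "../models/celestial_classifier_dt.pkl"
os.makedirs(os.path.dirname(model_path), exist_ok=True)
joblib.dump(model, model_path)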
print(f"Model saved to {model_path}")

# Reload the model and confirm it reproduces the test accuracy
loaded_model = joblib.load(model_path)
print(f"Model loaded successfully. Test accuracy: {loaded_model.score(X_test, y_test):.4f}")
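A pickled model does not record which columns it expects or which scikit-learn version produced it, and pickles are generally only reliable within the same library version. A minimal sketch that stores both alongside the model in a plain dict; this bundle layout and the `save_bundle` / `load_bundle` helpers are hypothetical choices, not a joblib convention, and the synthetic model stands in for the notebook's `model`. As with any pickle, only load files from trusted sources.

```python
import os
import tempfile

import joblib
import sklearn
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

FEATURES = ["orbital_period", "axial_tilt", "mass"]


def save_bundle(model, path, feature_names):
    """Persist the model with its feature order and the scikit-learn version used."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    bundle = {"model": model, "features": list(feature_names),
              "sklearn_version": sklearn.__version__}
    joblib.dump(bundle, path)


def load_bundle(path):
    """Load a bundle written by save_bundle, warning on a version mismatch."""
    bundle = joblib.load(path)
    if bundle["sklearn_version"] != sklearn.__version__:
        print(f"Warning: saved with scikit-learn {bundle['sklearn_version']}, "
              f"loading with {sklearn.__version__}")
    return bundle["model"], bundle["features"]


# Demo with a synthetic model in a temporary directory.
X_demo, y_demo = make_classification(
    n_samples=100, n_features=3, n_informative=3, n_redundant=0,
    n_classes=3, random_state=0,
)
demo_model = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_demo, y_demo)
with tempfile.TemporaryDirectory() as tmp:
    bundle_path = os.path.join(tmp, "models", "celestial_classifier_dt.pkl")
    save_bundle(demo_model, bundle_path, FEATURES)
    restored, features = load_bundle(bundle_path)

same = bool((restored.predict(X_demo) == demo_model.predict(X_demo)).all())
print(f"Features: {features}, identical predictions: {same}")
```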

## Make Predictions

Compare the model's predictions with the true labels for the first five test samples.

In [None]:
# Predict the first five test samples and compare them with the true labels
predictions = model.predict(X_test.iloc[:5])
actual = y_test.iloc[:5].values

print("Sample predictions:")
for pred, actual_val in zip(predictions, actual):
    status = "✓" if pred == actual_val else "✗"
    print(f"  {status} Predicted: {pred}, Actual: {actual_val}")