In [1]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    precision_score,
    recall_score,
)
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier, plot_tree

from get_data import get_balanced_exoplanet_data

# Silence every warning category from this point on. Note that this also hides
# deprecation and other diagnostic warnings, not just cosmetic noise.
warnings.simplefilter("ignore")

In [None]:
# Load the train/test split from the project-local helper.
# NOTE(review): presumably the classes are already balanced, as the helper's
# name suggests — confirm in get_data.py.
x_train, x_test, y_train, y_test = get_balanced_exoplanet_data()

# Parameters
rng_seed = 2023
depth = 5
# Not read anywhere in this cell; presumably consumed by a later cell — verify
# before removing.
boosting_rounds = 100

# Depth-limited decision tree with a fixed seed so repeated runs give the same
# tree. class_weight="balanced" asks scikit-learn to reweight classes inversely
# to their frequency in y_train.
model = DecisionTreeClassifier(
    max_depth=depth,
    random_state=rng_seed,
    class_weight="balanced",
)
model.fit(x_train, y_train)
pred = model.predict(x_test)

# Confusion matrix (rows: true labels, columns: predicted labels), drawn on an
# explicit Axes rather than through pyplot's implicit "current figure" state.
cm = confusion_matrix(y_test, pred)
fig, ax = plt.subplots()
sns.heatmap(
    cm,
    annot=True,
    cmap="Blues",
    cbar=False,
    fmt=".0f",
    ax=ax,
)
ax.set_xlabel("Predicted label")
ax.set_ylabel("True label")
ax.set_title(f"Decision tree (max_depth={depth}): confusion matrix")

# Metrics, labelled so the three numbers are distinguishable in the output.
# NOTE(review): precision_score/recall_score are called with their defaults,
# which assume a binary target — confirm y_test is binary.
print("Precision:", precision_score(y_test, pred))
print("Recall:   ", recall_score(y_test, pred))
print("Accuracy: ", accuracy_score(y_test, pred))