In [None]:
import pandas as pd
import joblib
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report

In [None]:
# --- File locations (relative paths) ---------------------------------------
PATH = '../data/data.csv'
PATH_OVERSAMPLED = '../data/data_oversampled.csv'
MODEL_PATH = '../decision_tree.pkl'

# --- Cell markers -----------------------------------------------------------
# Presumably the integer encoding of one board cell in the feature columns
# -- TODO confirm against the dataset.
# NOTE(review): a later cell rebinds `X` to the feature DataFrame
# (`X = df[feature_cols]`), so the marker value 1 is not reachable under this
# name after that cell has run.
X, O, BLANK = 1, -1, 0

# --- Outcome labels ---------------------------------------------------------
# Presumably the values stored in the 'category' label column -- TODO confirm.
O_WIN, DRAW, ONGOING, X_WIN = range(4)

In [None]:
# Read the dataset from the CSV file at PATH into a DataFrame.
df = pd.read_csv(PATH)

# Preview a random handful of rows as the cell output.
# - random_state is fixed so the preview is identical on every
#   "Restart Kernel -> Run All" (an unseeded sample changes each run).
# - n is capped at len(df) because DataFrame.sample raises ValueError when
#   asked for more rows than exist (replace=False is the default).
df.sample(min(10, len(df)), random_state=42)

In [None]:
# Feature column labels are the strings "0" through "8".
feature_cols = list(map(str, range(9)))

# Label vector: the 'category' column.
y = df['category']

# Feature matrix: the nine columns listed above.
# NOTE(review): this rebinds `X`, which the configuration cell defines as the
# integer constant 1; after this cell `X` refers to the DataFrame only.
X = df[feature_cols]

In [None]:
# Stage 1: hold out 40% of the rows (X_temp / y_temp); the remaining 60%
# become the training set. Stratified on the labels.
X_train, X_temp, y_train, y_temp = train_test_split(
    X,
    y,
    stratify=y,
    test_size=0.4,
    random_state=42,
)

# Stage 2: split the held-out rows in half into validation and test sets
# (roughly 20% of the full data each), again stratified on the labels.
X_val, X_test, y_val, y_test = train_test_split(
    X_temp,
    y_temp,
    stratify=y_temp,
    test_size=0.5,
    random_state=42,
)

In [None]:
# Decision tree used for the validation/test evaluation in the cells below.
dt = DecisionTreeClassifier(
    random_state=42,        # fixed seed for repeatable results
    max_depth=10,           # cap on tree depth
    min_samples_split=2,    # minimum samples required to split a node
    criterion='gini',       # split quality measured by Gini impurity
)

In [None]:
# Fit on the DataFrame / Series directly instead of their .to_numpy() arrays.
# The evaluation cells call dt.predict() with DataFrames (X_val, X_test); an
# estimator fitted on bare arrays has no recorded feature names, so each of
# those calls makes scikit-learn warn "X has feature names, but
# DecisionTreeClassifier was fitted without feature names". Fitting on the
# DataFrame records the column names and keeps fit and predict consistent
# (the final cell already fits this way). The training values are unchanged.
dt.fit(X_train, y_train)

In [None]:
# Predict on the validation split and score the predictions.
y_pred_val = dt.predict(X_val)

# Precision, recall and F1 use average='weighted'.
val_accuracy = accuracy_score(y_val, y_pred_val)
val_precision = precision_score(y_val, y_pred_val, average='weighted')
val_recall = recall_score(y_val, y_pred_val, average='weighted')
val_f1 = f1_score(y_val, y_pred_val, average='weighted')

print("Validation Accuracy :", val_accuracy)
print("Validation Precision:", val_precision)
print("Validation Recall   :", val_recall)
print("Validation F1-score :", val_f1)

In [None]:
# Per-class report for the validation split, formatted with 4 decimal places.
val_report = classification_report(y_val, y_pred_val, digits=4)
print("\nClassification Report (Validation):\n", val_report)

In [None]:
# Render the fitted tree's decision rules as plain text, labelling each split
# with the corresponding entry of feature_cols.
tree_rules = export_text(dt, feature_names=feature_cols)
print("\nÁrvore de Decisão (texto):\n", tree_rules)

In [None]:
# Predict on the held-out test split and score the predictions.
y_pred_test = dt.predict(X_test)

# Precision, recall and F1 use average='weighted'.
test_accuracy = accuracy_score(y_test, y_pred_test)
test_precision = precision_score(y_test, y_pred_test, average='weighted')
test_recall = recall_score(y_test, y_pred_test, average='weighted')
test_f1 = f1_score(y_test, y_pred_test, average='weighted')

print("Test  Accuracy :", test_accuracy)
print("Test  Precision:", test_precision)
print("Test  Recall   :", test_recall)
print("Test  F1-score :", test_f1)

In [None]:
# Refit a fresh tree on the training split and persist it to MODEL_PATH.
# NOTE(review): this tree is built with random_state=42 only, so it is not
# the max_depth=10 model scored in the validation/test cells above; those
# metrics do not describe the saved artifact. Confirm that saving a model
# without the depth cap is intended.
dt = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

joblib.dump(dt, MODEL_PATH)
print(f'Modelo salvo em {MODEL_PATH}')
