## Necessary imports

In [15]:
import numpy as np

from dojo.tree import ClassificationTree
from sklearn.datasets import load_iris

## Loading data

In [16]:
iris = load_iris()
data = np.column_stack((iris["data"], iris["target"]))
np.random.shuffle(data)

X, y = data[:, :-1], data[:, -1]

train_size = int(150*0.7)
X_train, X_test, y_train, y_test = X[:train_size], X[train_size:], y[:train_size], y[train_size:]

## Building the model

In [17]:
tree = ClassificationTree(max_depth=3)
tree.fit(X_train, y_train)

In [18]:
tree

ClassificationTree(
    criterion='gini',
    max_depth=3,
    root=<dojo.tree.utils.structure.Node object at 0x10fa94ba8>,
)

## You can now easily visualize the tree itself with a simple call of a method!

In [19]:
tree.visualize()

 Is feature[3] >= 1.0?
 --> True:
   Is feature[3] >= 1.8?
   --> True:
     Prediction: 2.0
   --> False:
     Prediction: 1.0
 --> False:
   Prediction: 0.0


## Let's predict some probabilities

In [20]:
tree.predict_proba(X_test)

[{2.0: 0.10256410256410256, 1.0: 0.8974358974358975},
 {2.0: 0.10256410256410256, 1.0: 0.8974358974358975},
 {2.0: 1.0},
 {0.0: 1.0},
 {0.0: 1.0},
 {2.0: 0.10256410256410256, 1.0: 0.8974358974358975},
 {0.0: 1.0},
 {2.0: 0.10256410256410256, 1.0: 0.8974358974358975},
 {2.0: 0.10256410256410256, 1.0: 0.8974358974358975},
 {2.0: 0.10256410256410256, 1.0: 0.8974358974358975},
 {0.0: 1.0},
 {0.0: 1.0},
 {0.0: 1.0},
 {2.0: 0.10256410256410256, 1.0: 0.8974358974358975},
 {2.0: 1.0},
 {2.0: 1.0},
 {0.0: 1.0},
 {2.0: 0.10256410256410256, 1.0: 0.8974358974358975},
 {2.0: 1.0},
 {2.0: 0.10256410256410256, 1.0: 0.8974358974358975},
 {2.0: 1.0},
 {2.0: 1.0},
 {2.0: 1.0},
 {2.0: 0.10256410256410256, 1.0: 0.8974358974358975},
 {0.0: 1.0},
 {2.0: 0.10256410256410256, 1.0: 0.8974358974358975},
 {2.0: 0.10256410256410256, 1.0: 0.8974358974358975},
 {2.0: 1.0},
 {0.0: 1.0},
 {0.0: 1.0},
 {0.0: 1.0},
 {2.0: 1.0},
 {2.0: 1.0},
 {0.0: 1.0},
 {2.0: 1.0},
 {2.0: 1.0},
 {0.0: 1.0},
 {2.0: 0.10256410256410256,

## Evaluating the model

In [21]:
tree.evaluate(X_train, y_train)

Accuracy score: 0.9619047619047619
