In [37]:
from sklearn.tree import DecisionTreeClassifier
from dataclasses import dataclass
from typing import List, Optional, Tuple
from dataclasses import asdict
import json

@dataclass
class TreeNode:
    """One node of a fitted decision tree, flattened for JSON export.

    Internal (split) nodes carry ``feature_name``, ``threshold`` and the ids of
    their two children; leaf nodes carry ``class_label`` instead and have all
    split-related fields set to ``None``.

    Field order matters: it is the positional constructor order and the key
    order of the dicts produced by ``dataclasses.asdict`` (used when saving
    to JSON), so do not reorder fields.
    """
    # Identifier of the node. extract_tree_structure uses the node's index in
    # sklearn's tree arrays, so ids run from 0 to node_count - 1.
    node_id: int
    
    # Name of the feature tested at this node. `None` for leaf nodes.
    feature_name: Optional[str]
    
    # Threshold compared against `feature_name` to split the samples at this
    # node. `None` for leaf nodes.
    threshold: Optional[float]
    
    # node_id of the left child. `None` for leaf nodes.
    left_child: Optional[int]
    
    # node_id of the right child. `None` for leaf nodes.
    right_child: Optional[int]
    
    # `True` for leaf nodes, `False` for internal (split) nodes.
    is_leaf: bool
    
    # Class predicted by a leaf node (the majority class among its samples).
    # `None` for internal nodes.
    class_label: Optional[str]
    
    # Number of training samples that reached this node.
    samples: int

def _class_display_names(classes, target_names) -> List[str]:
    """
    Map every learned class of a fitted classifier to a display name.

    The result is aligned with ``classes`` (and therefore with the columns of
    ``tree_.value``), so ``result[i]`` is the name for class column ``i``.

    Resolution rules, in order:
    1. If every class is an integer in ``range(len(target_names))`` (integer-coded
       targets such as ``load_iris().target``), class ``c`` is named
       ``target_names[c]``. This stays correct even when a class never appeared
       in training and is therefore missing from ``classes``.
    2. Otherwise, a class whose value already appears in ``target_names`` (e.g.
       string labels produced by ``pd.cut(..., labels=...)``) keeps its own
       value, because ``classes`` is in sklearn's sorted order and would not
       line up positionally with ``target_names``.
    3. Otherwise the name at the same position in ``target_names`` is used
       (the historical behavior), falling back to the class value itself.
    """
    import numbers

    names = list(target_names)
    index_coded = all(
        isinstance(c, numbers.Integral) and 0 <= c < len(names) for c in classes
    )
    if index_coded:
        return [str(names[int(c)]) for c in classes]

    resolved = []
    for position, learned in enumerate(classes):
        if learned in names:
            resolved.append(str(learned))
        elif position < len(names):
            resolved.append(str(names[position]))
        else:
            resolved.append(str(learned))
    return resolved


def extract_tree_structure(tree_classifier: DecisionTreeClassifier, feature_names: List[str], target_names: List[str]) -> List[TreeNode]:
    """
    Extract node information from a trained DecisionTreeClassifier.

    Nodes are returned in the order of sklearn's internal node arrays, so
    ``nodes[i].node_id == i`` and child ids can be used as list indices.

    Parameters:
    -----------
    tree_classifier : DecisionTreeClassifier
        A trained, single-output sklearn DecisionTreeClassifier
    feature_names : List[str]
        Feature names, in the same column order as the training data
    target_names : List[str]
        Human-readable class names. For integer-coded targets ``0..k-1``,
        ``target_names[c]`` names class ``c``. For string targets, a label
        that already appears in ``target_names`` is used as-is; otherwise
        names are matched positionally against ``tree_classifier.classes_``.

    Returns:
    --------
    List[TreeNode]
        List of TreeNode objects containing the tree structure

    Raises:
    -------
    ValueError
        If the classifier was trained on more than one output.
    """
    if getattr(tree_classifier, "n_outputs_", 1) != 1:
        raise ValueError("extract_tree_structure supports single-output classifiers only")

    tree = tree_classifier.tree_
    # tree.value columns follow tree_classifier.classes_, so the argmax of a
    # leaf's value row indexes classes_ (and this list), NOT target_names.
    class_names = _class_display_names(tree_classifier.classes_, target_names)
    nodes = []

    for node_id in range(tree.node_count):
        left_child = int(tree.children_left[node_id])
        samples = int(tree.n_node_samples[node_id])

        if left_child == -1:
            # Leaf node: report the majority class of the samples that reached it.
            node = TreeNode(
                node_id=node_id,
                feature_name=None,
                threshold=None,
                left_child=None,
                right_child=None,
                is_leaf=True,
                class_label=class_names[int(tree.value[node_id].argmax())],
                samples=samples,
            )
        else:
            # Internal node: record the split feature/threshold and both children.
            node = TreeNode(
                node_id=node_id,
                feature_name=feature_names[int(tree.feature[node_id])],
                threshold=float(tree.threshold[node_id]),
                left_child=left_child,
                right_child=int(tree.children_right[node_id]),
                is_leaf=False,
                class_label=None,
                samples=samples,
            )

        nodes.append(node)

    return nodes

In [38]:
# Libraries for loading the toy dataset, splitting it and fitting the tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Iris: feature matrix in iris.data, class labels in iris.target
iris = load_iris()
X, y = iris.data, iris.target

# Hold out 30% of the rows for testing; the fixed seed keeps the split reproducible
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

# Build and fit the classifier in one step (fit returns the fitted estimator)
clf = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)


In [39]:

def save_tree_to_json(nodes, filename: str, indent: int = 4):
    """
    Write a list of TreeNode dataclasses to ``filename`` as a JSON array.

    Each node becomes one JSON object via ``dataclasses.asdict``, so the keys
    follow the TreeNode field order. An existing file is overwritten.

    Parameters:
    -----------
    nodes : List[TreeNode]
        Nodes to serialize, e.g. the output of extract_tree_structure
    filename : str
        Destination path of the JSON file
    indent : int
        Indentation width passed to json.dump (default 4)
    """
    # Convert everything first so a conversion error never leaves a partial file
    serializable = list(map(asdict, nodes))

    with open(filename, 'w') as handle:
        json.dump(serializable, handle, indent=indent)

In [40]:
iris_nodes = extract_tree_structure(clf, iris.feature_names, iris.target_names)
save_tree_to_json(iris_nodes, filename="test.json")

The code works for the tree trained on the Iris dataset; let's try a bigger one.

In [41]:
import pandas as pd

In [42]:
# Load the larger dataset from the working directory
DATASET_PATH = "test_dataset.csv"
df = pd.read_csv(DATASET_PATH)

In [43]:
# Column labels of the raw frame
df.keys()

Index(['y_', 'ID', 'engine_age', 'length', 'power', 'month', 'landing',
       'weight', 'value', 'value_cpi', 'price', 'y_month', 'year', 'patch',
       'dist', 'patch_area', 'weight_lym', 'weight_lm', 'val_lm', 'val_lym',
       'nao_index', 'surf_temp'],
      dtype='object')

In [44]:
# Print every column name followed by numpy's (truncated) preview of its values
for column_name in df.columns:
    print(column_name, ":", df[column_name].values)

y_ : [1. 1. 0. ... 1. 0. 1.]
ID : [1993001257 1993005128 1996007882 ... 2007039109 2001015121 2008043528]
engine_age : [10. 26. 32. ...  3.  9.  8.]
length : [10.5  21.3  12.13 ...  7.   14.16 35.49]
power : [ 367.  970.  190. ...  144.  291. 1000.]
month : [ 44 131 170 ... 124 121 198]
landing : ['RISØR' 'BÅTSFJORD' 'ØKSNES' ... 'BØMLO' 'RØST' 'MÅSØY']
weight : [5.0000e+00 1.9000e+01 0.0000e+00 ... 1.5000e+00 0.0000e+00 9.4625e+04]
value : [9.5700000e+01 5.6459000e+02 0.0000000e+00 ... 1.9170000e+01 0.0000000e+00
 5.0858288e+05]
value_cpi : [1.1962000e+02 6.1169000e+02 0.0000000e+00 ... 2.0680000e+01 0.0000000e+00
 4.8996424e+05]
price : [19.14       29.71526316  0.         ... 12.78        0.
  5.37472   ]
y_month : [ 8 11  2 ...  4  1  6]
year : [ 3. 10. 14. ... 10. 10. 16.]
patch : ['09-16' '03-03' '04-27' ... '08-15' '07-26' '03-07']
dist : [ 13.6  59.6 286.7 ...  10.6 457.5 220.2]
patch_area : [3229 2122 4596 ... 1418 5377 2768]
weight_lym : [0. 0. 0. ... 1. 0. 0.]
weight_lm : [1

In [45]:
# Columns removed before modelling:
#   - non-numerical data: landing, patch
#   - value_cpi: `value` adjusted for inflation (check dataset page for more info);
#     since `value` becomes the target, keeping it would leak the answer
#   - price: appears to be value / weight (first row: 95.7 / 5 = 19.14), also a leak
#   - other features that are not relevant or whose meaning is not known
DROPPED_COLUMNS = [
    "landing", "patch", "value_cpi", "y_", "ID", "dist", "patch_area",
    "weight_lym", "weight_lm", "val_lm", "val_lym", "nao_index", "price",
]
# Reassign instead of inplace=True so the cell can be re-run: errors="ignore"
# skips columns that a previous run already removed instead of raising KeyError.
df = df.drop(columns=DROPPED_COLUMNS, errors="ignore")

In [46]:
# Columns left after dropping the unused ones
df.keys()

Index(['engine_age', 'length', 'power', 'month', 'weight', 'value', 'y_month',
       'year', 'surf_temp'],
      dtype='object')

In [48]:
# Model inputs: every remaining column except the target `value`
FEATURE_COLUMNS = ['engine_age', 'length', 'power', 'month', 'weight', 'y_month',
                   'year', 'surf_temp']
X = df[FEATURE_COLUMNS]
y = df["value"]

In [57]:
# Make the target categorical: one label per bin, from lowest to highest value
labels = [
    "Poor Session", "Below Average", "Average Session", "Above Average",
    "Good Session", "Great Session", "Excellent Session",
    "Outstanding", "Legendary", "Epic"
]
# Split `value` into 10 equal-width bins with the labels above. Binning from
# df["value"] (not from `y`) keeps the cell re-runnable: re-binning `y` would
# try to pd.cut the already-categorical labels from a previous run.
# NOTE(review): `value` looks heavily skewed (many zeros, some values in the
# hundreds of thousands), so equal-width bins may put most rows in the first
# bin and leave others empty; consider pd.qcut if balanced classes are wanted.
y = pd.cut(df["value"], bins=10, labels=labels)

In [58]:
# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Create a Decision Tree Classifier
clf2 = DecisionTreeClassifier(random_state=42)

# Train the classifier
clf2 = clf.fit(X_train, y_train)

In [61]:
save_tree_to_json(extract_tree_structure(clf, feature_names=['engine_age', 'length', 'power', 'month', 'weight', 'y_month',
       'year', 'surf_temp'], target_names=labels), filename="fishingTree.json")