#### Imports

In [1]:
import pyAgrum as gum
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import sys, os


from meanForDTinBN import *
from classesSizes.digraph import *


from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn.preprocessing import LabelEncoder

['/usr/lib/python310.zip', '/usr/lib/python3.10', '/usr/lib/python3.10/lib-dynload', '', '/home/echu/.local/lib/python3.10/site-packages', '/usr/local/lib/python3.10/dist-packages', '/usr/lib/python3/dist-packages', '/home/echu/tesis/pasantia-BICC/asvFormula/bayesianNetworks/..']


### Create a dataset from a Bayesian network

In [26]:
def decisionTreeFromDataset(dataset : pd.DataFrame, target_feature, maximum_depth=3):
    """Fit a decision tree classifier that predicts ``target_feature``.

    Side effect: ``dataset`` is modified in place, because its categorical
    columns are label-encoded by ``encodeCategoricalColumns`` before fitting.

    Args:
        dataset: Table holding both the feature columns and the target column.
        target_feature: Name of the column to predict.
        maximum_depth: ``max_depth`` handed to the classifier.

    Returns:
        The fitted ``DecisionTreeClassifier``. Only the 80% training part of a
        fixed-seed 80/20 split is used for fitting; the held-out part is
        discarded.
    """
    encodeCategoricalColumns(dataset)
    features = dataset.drop(columns=target_feature)
    labels = dataset[target_feature]

    # Same split as always (seed 42); the test portion is not needed here.
    train_features, _, train_labels, _ = train_test_split(
        features, labels, test_size=0.2, random_state=42
    )

    classifier = tree.DecisionTreeClassifier(max_depth=maximum_depth, random_state=42)
    classifier.fit(train_features, train_labels)
    return classifier

def encodeCategoricalColumns(dataset):
    """Label-encode the categorical columns of ``dataset`` in place.

    Every column whose dtype is ``object``, ``category`` or ``bool`` is
    overwritten with the output of ``LabelEncoder.fit_transform``. A single
    encoder object is refit for each column, so the per-column mappings are
    not kept. Nothing is returned.
    """
    encoder = LabelEncoder()
    to_encode = dataset.select_dtypes(include=['object', 'category', 'bool'])
    for name in to_encode.columns:
        dataset[name] = encoder.fit_transform(dataset[name])

def datasetFromBayesianNetwork(biffBNFilename, n):
    """Sample a dataset from a Bayesian network stored on disk.

    Args:
        biffBNFilename: Path of the network file passed to ``gum.loadBN``.
        n: Number of samples requested from the generator.

    Returns:
        The drawn samples as returned by ``BNDatabaseGenerator.to_pandas()``.
        The generator is put in random-variable-order mode before sampling.
    """
    network = gum.loadBN(biffBNFilename)
    generator = gum.BNDatabaseGenerator(network)
    generator.setRandomVarOrder()
    generator.drawSamples(n)
    samples = generator.to_pandas()
    return samples

# Draw 10000 samples from the Bayesian network stored in "cancer.bif".
cancerData = datasetFromBayesianNetwork("cancer.bif", 10000)
#print(cancerData.head())

# Fit a depth-2 decision tree predicting 'Cancer'. Note that
# decisionTreeFromDataset label-encodes the categorical columns of
# cancerData in place before fitting.

dtCancer = decisionTreeFromDataset(cancerData, 'Cancer', 2)
# Keep a handle on the classifier's underlying fitted tree object.
dtCancerTree = dtCancer.tree_
#tree.plot_tree(dtCancer, feature_names=cancerData.columns)
# NOTE(review): with plot_tree commented out, nothing is drawn in this cell
# before show() is called.
plt.show()


### Get the underlying structure of the tree, converted to a NetworkX graph

In [34]:
def obtainTreeStructure(decisionTree : tree.DecisionTreeClassifier, featureNames : list[str]) -> nx.DiGraph:
    """Convert a fitted decision tree into a labelled ``networkx.DiGraph``.

    Nodes are keyed by the tree's node ids (plain ``int``; the root is 0).
    Split nodes get ``label`` set to the name of the feature they test, and
    their two outgoing edges are labelled ``"X[i] <= t"`` / ``"X[i] > t"``.
    Leaf nodes get ``label`` set to ``"Leaf <id>: <value>"`` with the node's
    value array rounded to 3 decimals.

    Args:
        decisionTree: A fitted classifier; its ``tree_`` arrays are read.
        featureNames: Names of the feature columns, indexed exactly like the
            columns the classifier was trained on (the target column must not
            be included).

    Returns:
        The directed graph describing the tree.
    """
    fitted_tree = decisionTree.tree_
    children_left = fitted_tree.children_left
    children_right = fitted_tree.children_right
    feature = fitted_tree.feature
    threshold = fitted_tree.threshold
    values = fitted_tree.value

    G = nx.DiGraph()
    # Depth-first walk starting at the root node (id 0).
    stack = [0]

    while stack:
        node_id = stack.pop()
        # Plain ints keep the graph keys independent of numpy scalar types.
        left = int(children_left[node_id])
        right = int(children_right[node_id])

        # A split node has two distinct children; a leaf has the same
        # sentinel stored on both sides.
        if left != right:
            # Only split nodes have a real feature index. Leaves hold a
            # negative placeholder, which must not be used to index
            # featureNames (it would pick an unrelated name or raise).
            G.add_node(node_id, label=f"{featureNames[feature[node_id]]}")

            G.add_edge(node_id, left, label=f"X[{feature[node_id]}] <= {threshold[node_id]:.2f}")
            G.add_edge(node_id, right, label=f"X[{feature[node_id]}] > {threshold[node_id]:.2f}")

            stack.append(left)
            stack.append(right)
        else:
            G.add_node(node_id, label=f"Leaf {node_id}: {np.around(values[node_id], 3)}")
    return G

def drawDecisionTree(decisionTree : nx.DiGraph):
    """Plot a tree graph top-down, showing node and edge ``label`` attributes.

    Nodes without a ``label`` attribute are captioned with ``str(node)``.
    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    plt.figure(figsize=(12, 8))
    # Graphviz "dot" produces a hierarchical (tree-like) layout.
    layout = nx.nx_agraph.graphviz_layout(decisionTree, prog="dot")

    node_captions = {}
    for node in decisionTree.nodes():
        node_captions[node] = decisionTree.nodes[node].get('label', str(node))
    edge_captions = nx.get_edge_attributes(decisionTree, 'label')

    nx.draw(decisionTree, layout, with_labels=True, labels=node_captions)
    nx.draw_networkx_edge_labels(decisionTree, layout, edge_labels=edge_captions)
    plt.show()

# The classifier was fitted on the dataset WITHOUT the 'Cancer' target column,
# and tree_.feature indexes those training columns. Pass only the feature
# columns so that every split node is labelled with the right name.
networkTree = obtainTreeStructure(dtCancer, cancerData.columns.drop('Cancer'))
#drawDecisionTree(networkTree)

#graph.nodes[node_id]['label']
