# DecisionTree

Below is the data set.

In [9]:
import plotly.graph_objs as go
from decision_tree import DecisionTree
import pandas as pd
import random
import numpy as np

def loadXY(data, sample):
    """Gather feature vectors and labels for the requested rows.

    Args:
        data: DataFrame whose '类别' column holds the class label; every other
            column is treated as a feature.
        sample: iterable of row labels (index values of ``data``) to extract.

    Returns:
        ``(X, Y)`` where ``X`` is a list of 1-D numpy arrays (one per sampled row,
        feature columns in frame order) and ``Y`` is the list of matching labels.
    """
    features = data.drop(['类别'], axis=1)
    labels = data['类别']
    X = [np.array(features.loc[row]) for row in sample]
    Y = [labels[row] for row in sample]
    return X, Y

# Load the raw loan table; the first CSV column becomes the row index.
data = pd.read_csv('loan.csv', header=0, index_col=0)

# Render the raw (not yet encoded) data as a table for inspection.
raw_table = go.Table(
    header=dict(values=data.keys()),
    cells=dict(values=[series for _, series in data.items()]),
)
go.Figure(data=[raw_table]).show()

# Label-encode every column (including the '类别' target) into integer codes.
# tag[col] maps each code back to its original value so that results can be
# shown in the original vocabulary.
tag = {}
for col in data:
    # dict.fromkeys keeps first-appearance order, so the code assignment is
    # reproducible across runs. Iterating a set of strings depends on hash
    # randomization, so it is not.
    mp = {value: code for code, value in enumerate(dict.fromkeys(data[col]))}
    tag[col] = {code: value for value, code in mp.items()}
    data[col] = data[col].map(mp)

# Randomly pick TRAIN_SIZE rows for training and hold out the rest for testing.
TRAIN_SIZE = 10
RANDOM_SEED = None  # set to an int for a reproducible split
if RANDOM_SEED is not None:
    random.seed(RANDOM_SEED)

size = len(data)
# Sample from the actual row labels instead of assuming the index is
# 1..len(data). Clamp the sample size so small files don't raise ValueError.
train = random.sample(list(data.index), min(TRAIN_SIZE, size))
train_set = set(train)
# Keep test rows in index order; set iteration order is not guaranteed.
test = [i for i in data.index if i not in train_set]

trainX, trainY = loadXY(data, train)
testX, testY = loadXY(data, test)

# Fit a C4.5 decision tree on the training split.
tree = DecisionTree(trainX, trainY, 0, mode='C45')

# For each split, print a header followed by "prediction truth" pairs.
for split_name, split_X, split_Y in (("train", trainX, trainY), ("test", testX, testY)):
    print(split_name)
    for x, y in zip(split_X, split_Y):
        print(tree.Predict(x), y)

train
0 0
1 1
0 0
1 1
0 0
1 0
1 1
1 1
1 0
1 0
test
1 1
1 1
0 0
1 0
0 0


## Visualization

In [10]:
import igraph
from decision_tree import Node
from igraph import EdgeSeq

# Number of vertices added so far. dfs uses it as the next vertex index,
# because igraph numbers vertices in insertion order.
nr_vert = 0
def dfs(rt:Node, p, edge, g:igraph.Graph):
    """Add the subtree rooted at ``rt`` to ``g`` in pre-order.

    Internal nodes are labelled with the name of their split feature. Leaves are
    labelled with their class, mapped back to the original (pre-encoding) value
    when it is known. Each edge is labelled with the original feature value that
    leads to the child.

    Args:
        rt: root of the subtree to add.
        p: vertex index of ``rt``'s parent in ``g``, or None for the tree root.
        edge: name for the parent -> ``rt`` edge (unused when ``p`` is None).
        g: graph that receives the new vertices and edges (mutated in place).
    """
    global nr_vert
    rt_idx = nr_vert
    nr_vert += 1
    # rt.feature is assumed to index the feature vectors built by loadXY, which
    # drop the '类别' label column. Look the name up among the feature columns
    # only: indexing data.keys() directly is off by one for any feature stored
    # after '类别'.
    feature_names = data.columns.drop('类别')
    is_leaf = not rt.child
    if is_leaf:
        # Leaves hold an encoded class; show the original class value when known.
        name = tag['类别'].get(rt.label, rt.label)
    else:
        name = feature_names[rt.feature]
    g.add_vertex(name=name)
    if p is not None:
        g.add_edge(p, rt_idx, name=edge)
    if not is_leaf:
        split_col = feature_names[rt.feature]
        for val, c in rt.child.items():
            # Edge label: decode the child's branch value via the encoding table.
            dfs(c, rt_idx, tag[split_col][val], g)

# Build the tree graph, then lay it out as a tree (igraph's "rt" layout).
g = igraph.Graph()
dfs(tree.root, None, None, g)
# dfs adds the decision-tree root first, so it is vertex 0. Pin it as the layout
# root; without it igraph picks a root itself, which need not be vertex 0.
lay = g.layout("rt", root=[0])

n_vertices = g.vcount()
Y = [lay[k][1] for k in range(n_vertices)]
M = max(Y)  # used to flip y so the root is drawn at the top
Xn = [lay[k][0] for k in range(n_vertices)]
Yn = [M - lay[k][1] for k in range(n_vertices)]

# Each edge becomes a [start, midpoint, end, None] run. The None breaks the
# polyline between edges, and the edge label is attached to the midpoint only.
Xe = []
Ye = []
label_e = []
for edge in g.es:
    src, dst = edge.tuple
    x0, y0 = lay[src]
    x1, y1 = lay[dst]
    y0, y1 = M - y0, M - y1
    Xe += [x0, (x0 + x1) / 2, x1, None]
    Ye += [y0, (y0 + y1) / 2, y1, None]
    label_e += [None, edge['name'], None, None]
label = g.vs['name']

fig = go.Figure()
fig.add_trace(go.Scatter(
    x=Xe, y=Ye,
    mode="lines+text",
    line=dict(color='rgb(150,150,150)', width=1),
    text=label_e
))
fig.add_trace(go.Scatter(
    x=Xn,
    y=Yn,
    mode="markers+text",
    marker=dict(
        symbol='circle',
        size=18,
        color='#6175c1'),
    text=label,
    textposition="bottom center"
))
fig.show()
