In [1]:
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
import dtreeviz
import warnings
import json
import matplotlib.pyplot as plt
from pypinyin import pinyin, lazy_pinyin, Style

In [2]:
def word_to_py(word):
    """Return the upper-cased first-letter pinyin abbreviation of ``word``.

    ``pinyin(word, style=Style.FIRST_LETTER)`` is iterated item by item; the
    first entry of every item is converted with ``str(...).upper()`` and the
    pieces are concatenated in order.

    Parameters
    ----------
    word :
        Value handed to ``pinyin``.  # presumably a Chinese soil name - TODO confirm

    Returns
    -------
    str
        The concatenated abbreviation; ``""`` when ``pinyin`` yields no items.
    """
    segments = pinyin(word, style=Style.FIRST_LETTER)
    return "".join(str(candidates[0]).upper() for candidates in segments)

In [3]:
# Load the training data CSV into a pandas DataFrame.
# NOTE(review): absolute, machine-specific path - consider a configurable data directory.
data = pd.read_csv(r"F:\cache_data\zone_ana\ky\train_data\pca_soil_type_train_point.csv")

In [4]:
# Load the soil dictionary from a UTF-8 encoded JSON file.
# A later cell indexes it as soil_dict[<NEW_TZ value>]['土类'].
# NOTE(review): absolute, machine-specific path - consider a configurable data directory.
with open(r"D:\worker_code\Terrain_Test\data\soil_dict\soil_type_dict.json", "r", encoding="utf-8") as f:
    soil_dict = json.load(f)

In [5]:
# Rename two specific NEW_TZ values. Series.replace matches whole cell values
# here (no regex), and neither replacement produces the other's search key,
# so a single dict-based call is equivalent to applying them one after another.
data['NEW_TZ'] = data['NEW_TZ'].replace({
    "腐厚层硅黄壤": "腐厚层硅质黄壤",
    "腐厚层硅质黄壤土": "腐厚层硅质黄壤性土",
})

In [6]:
# Derive NEW_TL by looking up the '土类' entry of every NEW_TZ value in soil_dict.
# A NEW_TZ value missing from soil_dict raises KeyError.
data["NEW_TL"] = data["NEW_TZ"].map(lambda soil_name: soil_dict[soil_name]["土类"])

In [7]:
# Replace every NEW_TL value with the abbreviation returned by word_to_py
# (upper-cased first-letter pinyin).
data["NEW_TL"] = data["NEW_TL"].map(word_to_py)

In [8]:
# Target column and predictor columns.
# The order and spelling of the names are kept exactly as given (including
# 'slope_postion'); presumably they must match the CSV header - verify before renaming.
label = 'NEW_TL'
features = [
    'analyticalhillshading', 'aspect',
    'channelnetworkbaselevel', 'channelnetworkdistance',
    'convergenceindex', 'dem', 'dz', 'dl',
    # etp2022_* columns
    'etp2022_1', 'etp2022_10', 'etp2022_11', 'etp2022_12',
    'etp2022_2', 'etp2022_3', 'etp2022_4', 'etp2022_5',
    'etp2022_6', 'etp2022_7', 'etp2022_8', 'etp2022_9',
    'etp2022_mean',
    'evi', 'lat', 'lon', 'lswi', 'lsfactor',
    'mndwi', 'mrttf', 'mrvbf', 'ndvi', 'ndwi', 'night2022',
    'pc1', 'pc2', 'plancurvature',
    # pre2022_* columns
    'pre2022_1', 'pre2022_10', 'pre2022_11', 'pre2022_12',
    'pre2022_2', 'pre2022_3', 'pre2022_4', 'pre2022_5',
    'pre2022_6', 'pre2022_7', 'pre2022_8', 'pre2022_9',
    'pre2022_mean',
    'profilecurvature', 'relativeslopeposition',
    'savi', 'slope', 'slope_postion',
    # tmp2022_* columns
    'tmp2022_1', 'tmp2022_10', 'tmp2022_11', 'tmp2022_12',
    'tmp2022_2', 'tmp2022_3', 'tmp2022_4', 'tmp2022_5',
    'tmp2022_6', 'tmp2022_7', 'tmp2022_8', 'tmp2022_9',
    'tmp2022_mean',
    'topographicwetnessindex', 'totalcatchmentarea', 'valleydepth', 'vari',
]

In [9]:
# Keep only the label and feature columns. The explicit copy makes result_df
# an independent frame, so column assignments performed on it later neither
# raise SettingWithCopyWarning nor depend on `data`.
result_df = data[[label] + features].copy()

In [12]:
def plot_decision_tree(data, label_column, max_depth=None, save_path=None):
    """
    Fit a decision tree on ``data`` and build a dtreeviz visualisation of it.

    Parameters
    ----------
    data : pd.DataFrame
        Frame holding the feature columns and the label column. It is copied
        first, so the caller's frame is left untouched.
    label_column : str
        Name of the target column. It is always label-encoded, whatever its
        dtype, so class names are available for numeric labels as well.
    max_depth : int, optional
        Passed to ``DecisionTreeClassifier(max_depth=...)``. Defaults to None.
    save_path : str, optional
        When given, the SVG text of the rendered tree is written to this path
        (UTF-8). Defaults to None (nothing is written).

    Returns
    -------
    object
        The object returned by ``dtreeviz.model(...)`` for the fitted tree.

    Raises
    ------
    KeyError
        If ``label_column`` is not a column of ``data``.
    """
    if label_column not in data.columns:
        raise KeyError(f"label column {label_column!r} not found in data")

    # Work on a copy so the caller's frame is never modified.
    data = data.copy()

    # Encode the target on its own. It must not go through the median imputer
    # below (that would turn the integer class codes into floats), and its
    # encoder must exist even when the label is already numeric.
    label_encoder = LabelEncoder()
    y = pd.Series(
        label_encoder.fit_transform(data[label_column]),
        index=data.index,
        name=label_column,
    )
    X = data.drop(columns=[label_column])

    # Encode categorical feature columns as integer codes.
    for column in X.select_dtypes(include=['object', 'category']).columns:
        X[column] = LabelEncoder().fit_transform(X[column])

    # Fill missing values in numeric feature columns with the column median.
    numeric_columns = X.select_dtypes(include=['float64', 'int64']).columns
    if len(numeric_columns) > 0:
        num_imputer = SimpleImputer(strategy='median')
        X[numeric_columns] = num_imputer.fit_transform(X[numeric_columns])

    # Train the decision tree.
    dt_model = DecisionTreeClassifier(random_state=42, max_depth=max_depth)
    dt_model.fit(X, y)

    # Build the visualisation; class names follow the encoder's class order,
    # which is the order of the integer codes in ``y``.
    viz = dtreeviz.model(
        dt_model,
        X,
        y,
        target_name=label_column,
        feature_names=list(X.columns),
        class_names=[str(name) for name in label_encoder.classes_],
    )

    # Optionally write the rendered tree to an SVG file.
    if save_path:
        rendered = viz.view()
        with open(save_path, 'w', encoding='utf-8') as f:
            f.write(rendered.svg())

    return viz


In [None]:
# Usage example: fit the tree on result_df and write the SVG to disk.
# NOTE(review): absolute, machine-specific output path.
viz = plot_decision_tree(
    data=result_df,
    label_column='NEW_TL',
    save_path=r'F:\cache_data\zone_ana\ky\tree_view\decision_tree.svg',
    #max_depth=5
)
# viz.view()  # display the visualisation result

In [None]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.impute import SimpleImputer
from sklearn.tree import DecisionTreeClassifier
import plotly.graph_objects as go
import networkx as nx
import warnings
from plotly.offline import plot
import plotly.io as pio

# Ignore "Glyph ... missing from current font." warnings.
# NOTE(review): presumably emitted by matplotlib for CJK text - confirm.
warnings.filterwarnings("ignore", message="Glyph .* missing from current font.")

# Work on a copy: the encoding and imputation below assign columns in place,
# and result_df must stay untouched so this cell can be re-run safely.
data = result_df.copy()

# Encode text columns to numerical values.
label_encoders = {}
for column in data.select_dtypes(include=['object', 'category']).columns:
    le = LabelEncoder()
    data[column] = le.fit_transform(data[column])
    label_encoders[column] = le

# Make sure the target has an encoder even when it was not picked up above
# (e.g. it is already numeric); class_names below depends on it.
if 'NEW_TL' not in label_encoders:
    le = LabelEncoder()
    data['NEW_TL'] = le.fit_transform(data['NEW_TL'])
    label_encoders['NEW_TL'] = le

# Impute missing values in numerical feature columns with the median.
# The target is left out so its integer class codes are not cast to float.
num_imputer = SimpleImputer(strategy='median')
numeric_columns = (
    data.drop(columns=['NEW_TL'])
    .select_dtypes(include=['float64', 'int64'])
    .columns
)
if len(numeric_columns) > 0:
    data[numeric_columns] = num_imputer.fit_transform(data[numeric_columns])

# Split data into features and target variable.
X = data.drop(columns=['NEW_TL'])
y = data['NEW_TL']

# Train the decision tree model.
dt_model = DecisionTreeClassifier(random_state=42)
dt_model.fit(X, y)

# Extract the fitted tree's structure arrays (one entry per node).
n_nodes = dt_model.tree_.node_count
children_left = dt_model.tree_.children_left
children_right = dt_model.tree_.children_right
feature = dt_model.tree_.feature
threshold = dt_model.tree_.threshold
value = dt_model.tree_.value

# Names used when building node labels.
feature_names = X.columns.tolist()
class_names = [str(cls) for cls in label_encoders['NEW_TL'].classes_]

def create_node_label(node, feature, threshold, value):
    """Build the display label of one tree node.

    Parameters
    ----------
    node : int
        Index of the node in the tree arrays.
    feature : sequence of int
        Per-node split feature index; ``-2`` marks a leaf.
    threshold : sequence of float
        Per-node split threshold.
    value : array-like
        Per-node class counts; ``value[node][0]`` is the vector that is
        arg-maxed for leaves.

    Returns
    -------
    str
        ``"Class: <name>"`` for a leaf (name taken from the module-level
        ``class_names``), otherwise ``"<feature name> <= <threshold>"`` with
        the threshold formatted to two decimals (name taken from the
        module-level ``feature_names``).
    """
    split_feature = feature[node]
    if split_feature == -2:  # leaf node
        winner = np.argmax(value[node][0])
        return f"Class: {class_names[winner]}"
    return f"{feature_names[split_feature]} <= {threshold[node]:.2f}"

# One label per tree node, indexed by node id.
node_labels = [
    create_node_label(node_id, feature, threshold, value)
    for node_id in range(n_nodes)
]

# Create edges: every node whose left and right child ids differ contributes
# (parent, left child) followed by (parent, right child).
edges = [
    (parent, child)
    for parent in range(n_nodes)
    if children_left[parent] != children_right[parent]
    for child in (children_left[parent], children_right[parent])
]

# Function to compute node depths
def compute_node_depths(n_nodes, children_left, children_right):
    """Return the depth of every node, with node 0 as the root at depth 0.

    Parameters
    ----------
    n_nodes : int
        Number of nodes in the tree.
    children_left, children_right : sequence of int
        Per-node child ids. A node whose two entries are equal is treated as
        a leaf and is not descended into.

    Returns
    -------
    numpy.ndarray
        ``int64`` array of length ``n_nodes``; entries of nodes that are not
        reachable from node 0 stay 0.
    """
    depths = np.zeros(n_nodes, dtype=np.int64)
    pending = [(0, 0)]  # (node id, depth) pairs still to visit
    while pending:
        current, level = pending.pop()
        depths[current] = level
        left, right = children_left[current], children_right[current]
        if left == right:
            continue  # leaf: nothing below it
        pending.extend(((left, level + 1), (right, level + 1)))
    return depths

# Compute the depth of every tree node (root = 0); the plotting code below
# uses it for marker colours and hover text.
node_depths = compute_node_depths(n_nodes, children_left, children_right)

# Create node coordinates
def hierarchy_pos(G, root, width=1., vert_gap = 0.2, vert_loc = 0, xcenter = 0.5):
    """Lay out the nodes reachable from ``root`` as a top-down tree.

    Parameters
    ----------
    G : graph
        Object exposing ``neighbors(node)``. Unless it is an ``nx.DiGraph``,
        the node a visit arrived from is removed from the neighbour list so
        the walk only moves away from the root.
    root : hashable
        Node placed at ``(xcenter, vert_loc)``.
    width : float
        Horizontal span shared by the children of ``root``; each child
        receives an equal slice, which it shares among its own children.
    vert_gap : float
        Amount subtracted from the y coordinate per level.
    vert_loc : float
        y coordinate of ``root``.
    xcenter : float
        x coordinate of ``root``.

    Returns
    -------
    dict
        ``{node: (x, y)}``, filled in pre-order.
    """
    layout = {}

    def _place(node, span, y_level, x_mid, came_from):
        # Record this node, then spread its children evenly across `span`.
        layout[node] = (x_mid, y_level)
        kids = list(G.neighbors(node))
        if came_from is not None and not isinstance(G, nx.DiGraph):
            kids.remove(came_from)
        if not kids:
            return
        step = span / len(kids)
        cursor = x_mid - span/2 - step/2
        for kid in kids:
            cursor += step
            _place(kid, step, y_level - vert_gap, cursor, node)

    _place(root, width, vert_loc, xcenter, None)
    return layout

# Build the tree graph. The root is registered explicitly so that a tree
# consisting of a single node (no edges) can still be laid out.
G = nx.Graph()
G.add_node(0)
G.add_edges_from(edges)
pos = hierarchy_pos(G, 0)

# Edge coordinates: each segment is (start, end, None); the None breaks the
# line so consecutive edges are not joined.
edge_x, edge_y = [], []
for parent, child in edges:
    x0, y0 = pos[parent]
    x1, y1 = pos[child]
    edge_x.extend([x0, x1, None])
    edge_y.extend([y0, y1, None])

edge_trace = go.Scatter(
    x=edge_x, y=edge_y, line=dict(width=0.5, color='#888'), hoverinfo='none', mode='lines')

# Node attributes are all built in the SAME order (G.nodes()), so every
# marker gets the position, colour and text of its own node id.
node_x, node_y, node_color, node_text = [], [], [], []
for node in G.nodes():
    x, y = pos[node]
    node_x.append(x)
    node_y.append(y)
    node_color.append(int(node_depths[node]))
    node_text.append(f'节点 {node}<br>{node_labels[node]}<br>深度: {node_depths[node]}')

node_trace = go.Scatter(
    x=node_x, y=node_y, text=node_text, mode='markers+text', textposition="top center",
    hoverinfo='text', marker=dict(
        showscale=True,
        colorscale='YlGnBu',
        reversescale=True,
        color=node_color,
        size=10,
        colorbar=dict(
            thickness=15,
            # Nested title form; replaces the deprecated `titleside` attribute.
            title=dict(text='节点深度', side='right'),
            xanchor='left'
        ),
        line=dict(width=2)))

fig = go.Figure(
    data=[edge_trace, node_trace],
    layout=go.Layout(
        # Nested title form; replaces the deprecated `titlefont_size` attribute.
        title=dict(text='<br>决策树可视化', font=dict(size=16)),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        annotations=[dict(
            text="基于用户数据的决策树",
            showarrow=False,
            xref="paper", yref="paper",
            x=0.005, y=-0.002)],
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    ),
)
# Save as an interactive HTML file (written relative to the working directory).
plot(fig, filename='decision_tree_interactive.html', auto_open=False)
# Display the visualization
fig.show()