## 图可视化

In [None]:
import os
import torch
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx
from sklearn.decomposition import PCA
import numpy as np

# Load data: every .pt graph in input_dir is rendered to a PNG in output_dir.
input_dir = "train_data/train_data/"
output_dir = "train_data/graph_vis/"
os.makedirs(output_dir, exist_ok=True)


def _edge_label_lookup(data):
    """Map each undirected edge ``(min(u, v), max(u, v))`` to its label.

    ``to_networkx(..., to_undirected=True)`` merges reciprocal edges and
    yields edges in its own order, so ``data.edge_label`` cannot be zipped
    positionally with ``G.edges()``; labels are keyed by endpoint pair
    instead. Labels are paired with ``data.edge_label_index`` when the object
    has one (the layout produced by PyG's ``RandomLinkSplit``), otherwise
    with ``data.edge_index``.
    """
    label_index = getattr(data, "edge_label_index", None)
    if label_index is None:
        label_index = data.edge_index
    sources, targets = label_index.tolist()
    labels = data.edge_label.reshape(-1).tolist()
    lookup = {}
    for u, v, label in zip(sources, targets, labels):
        key = (min(u, v), max(u, v))
        # An undirected edge counts as positive if either direction is labelled 1.
        if lookup.get(key) != 1:
            lookup[key] = label
    return lookup


def _node_sizes(data, max_size=1000):
    """Return per-node marker sizes from the first PCA component of ``data.x``.

    The component is min-max scaled to approximately ``[0, max_size]``; the
    1e-5 term guards against a zero range when all nodes project identically.
    """
    features = data.x.detach().cpu().numpy()
    x_pca = PCA(n_components=1).fit_transform(features).ravel()
    # ndarray.ptp() was removed in NumPy 2.0; np.ptp() works on every version.
    return max_size * (x_pca - x_pca.min()) / (np.ptp(x_pca) + 1e-5)


def _draw_graph(data, area_name, save_path):
    """Lay out ``data`` with Kamada-Kawai and save the rendered figure.

    Edges labelled 1 are drawn thick purple; all others (including edges
    that have no label) are drawn thin grey.
    """
    G = to_networkx(data, to_undirected=True)
    pos = nx.kamada_kawai_layout(G)
    node_size = _node_sizes(data)

    # Look labels up per drawn edge so colours follow G.edges() order exactly.
    labels = _edge_label_lookup(data)
    positive = [labels.get((min(u, v), max(u, v))) == 1 for u, v in G.edges()]
    edge_colors = ['purple' if p else 'grey' for p in positive]
    edge_width = [5 if p else 2 for p in positive]

    fig = plt.figure(figsize=(15, 15))
    try:
        nx.draw(
            G, pos,
            node_size=node_size,
            node_color="#156082",
            edge_color=edge_colors,
            width=edge_width,
            with_labels=False,
            alpha=0.8,
        )
        plt.title(f"Graph: {area_name}")
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing fails part-way.
        plt.close(fig)


# Iterate over all .pt files (sorted so the run order is reproducible).
for filename in sorted(os.listdir(input_dir)):
    if not filename.endswith(".pt"):
        continue
    filepath = os.path.join(input_dir, filename)
    area_name = os.path.splitext(filename)[0]  # file name without extension

    # SECURITY: weights_only=False unpickles arbitrary objects; only load
    # .pt files that come from a trusted source.
    data = torch.load(filepath, weights_only=False)

    save_path = os.path.join(output_dir, f"{area_name}.png")
    _draw_graph(data, area_name, save_path)
    print(f"Saved: {save_path}")





## 特征可视化

In [None]:
import os
import geopandas as gpd
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

# Per-region input roots (geometry and processed statistics) and the shared
# folder that receives every figure.
base_geo_dir, base_attr_dir = 'data/organized_data_by_region', 'data/processed_stats'
output_dir = 'output_visualization'
os.makedirs(output_dir, exist_ok=True)  # no error if the folder already exists

def plot_node_attributes(region_name, node_gdf, node_attrs):
    """Plot every node attribute as its own scatter panel and save one figure.

    Each column of ``node_attrs`` other than ``id``/``geometry`` is joined onto
    the node geometries by ``id``. Marker size is min-max scaled to roughly
    20-220. Panels fill a 2-row grid, and the figure is saved as
    ``<output_dir>/<region_name>_nodes_all_attrs.png``.

    Args:
        region_name: Region label used in the output file name.
        node_gdf: GeoDataFrame of node geometries; must have an ``id`` column.
        node_attrs: DataFrame of node attributes; must have an ``id`` column.
    """
    node_attr_list = [col for col in node_attrs.columns if col not in ['id', 'geometry']]
    n_attr = len(node_attr_list)
    if n_attr == 0:
        # plt.subplots(ncols=0) would raise, and there is nothing to draw.
        print(f"⚠ 无节点属性可绘制，跳过: {region_name}")
        return
    nrows = 2
    ncols = int(np.ceil(n_attr / nrows))

    # squeeze=False always returns a 2-D array, so flatten() works for any ncols.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                             figsize=(6 * ncols, 6 * nrows), squeeze=False)
    axes = axes.flatten()

    # Palette: bright yellow first, then tab10, then the remaining Set3 colours.
    # `cm` was never imported here, and matplotlib.cm.get_cmap was removed in
    # Matplotlib 3.9; pyplot.get_cmap serves the same colormaps.
    tab_colors = list(plt.get_cmap('tab10').colors)
    set3_colors = list(plt.get_cmap('Set3').colors)
    bright_yellow = (1.0, 1.0, 0.6)
    # NOTE(review): Set3's yellow appears to be #ffffb3 ~ (1.0, 1.0, 0.70), which
    # atol=0.05 does not match, so this filter may remove nothing. Confirm
    # the intended tolerance.
    remaining_set3 = [c for c in set3_colors if not np.allclose(c, bright_yellow, atol=0.05)]
    color_list = [bright_yellow] + tab_colors + remaining_set3

    # Join only id + geometry. Otherwise an attribute that also exists as a
    # shapefile column would be renamed "<attr>_x"/"<attr>_y" by the merge,
    # and merged[attr] would raise KeyError.
    base = node_gdf[['id', node_gdf.geometry.name]]

    for i, attr in enumerate(node_attr_list):
        ax = axes[i]
        merged = base.merge(node_attrs[['id', attr]], on='id')
        values = merged[attr]

        # Marker size: min-max normalised, then mapped to [20, 220].
        size = 200 * (values - values.min()) / (values.max() - values.min() + 1e-6) + 20

        merged.plot(
            ax=ax,
            color=color_list[i % len(color_list)],
            markersize=size,
            alpha=0.8
        )
        ax.set_title(f"{attr}", fontsize=16)
        ax.axis('off')

        # Pad the extent by 10% on each side so points are not clipped at the edges.
        bounds = merged.total_bounds
        dx = (bounds[2] - bounds[0]) * 0.1
        dy = (bounds[3] - bounds[1]) * 0.1
        ax.set_xlim(bounds[0] - dx, bounds[2] + dx)
        ax.set_ylim(bounds[1] - dy, bounds[3] + dy)

    # Hide grid cells left over when n_attr is odd.
    for ax in axes[n_attr:]:
        ax.axis('off')

    save_path = os.path.join(output_dir, f"{region_name}_nodes_all_attrs.png")
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"✅ 节点属性总图保存: {save_path}")


def plot_edge_attributes(region_name, edge_gdf, edge_attrs):
    """Plot every edge attribute as a coloured line map in one row of panels.

    Each column of ``edge_attrs`` other than ``id``/``geometry``/``u``/``v`` is
    joined onto the edge geometries by ``id``. It is drawn with the
    'Spectral' colormap and a colour bar. The figure is saved as
    ``<output_dir>/<region_name>_edges_all_attrs.png``.

    Args:
        region_name: Region label used in the output file name.
        edge_gdf: GeoDataFrame of edge geometries; must have an ``id`` column.
        edge_attrs: DataFrame of edge attributes; must have an ``id`` column.
    """
    edge_attr_list = [col for col in edge_attrs.columns if col not in ['id', 'geometry', 'u', 'v']]
    n_attr = len(edge_attr_list)
    if n_attr == 0:
        # plt.subplots(ncols=0) would raise, and there is nothing to draw.
        print(f"⚠ 无边属性可绘制，跳过: {region_name}")
        return

    # squeeze=False keeps a 2-D array even for a single panel; take the one row.
    fig, axes = plt.subplots(nrows=1, ncols=n_attr, figsize=(5 * n_attr, 6), squeeze=False)
    axes = axes[0]

    # Join only id + geometry. Otherwise an attribute that also exists as a
    # shapefile column would be renamed "<attr>_x"/"<attr>_y" by the merge,
    # and column=attr would fail.
    base = edge_gdf[['id', edge_gdf.geometry.name]]

    for ax, attr in zip(axes, edge_attr_list):
        merged = base.merge(edge_attrs[['id', attr]], on='id')

        merged.plot(
            ax=ax,
            column=attr,
            cmap='Spectral',
            legend=True,
            linewidth=1.5,
            alpha=0.8,
            legend_kwds={'shrink': 0.5}
        )
        ax.set_title(f"{attr}", fontsize=16)
        ax.axis('off')

    save_path = os.path.join(output_dir, f"{region_name}_edges_all_attrs.png")
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"✅ 边属性总图保存: {save_path}")

def process_region(region_name):
    """Load one region's geometry and attribute tables, then plot both.

    Looks for the node/edge shapefiles under ``base_geo_dir/<region>`` and
    the processed node/edge CSVs under ``base_attr_dir/<region>``. If any of
    the four files is missing, the region is skipped with a warning. Any
    table without an ``id`` column gets one copied from its row index, so
    the plotting helpers can join geometry to attributes.
    """
    geo_root = os.path.join(base_geo_dir, region_name)
    attr_root = os.path.join(base_attr_dir, region_name)

    node_shp = os.path.join(geo_root, f"{region_name}_nodes.shp")
    edge_shp = os.path.join(geo_root, f"{region_name}_edges.shp")
    node_attr_csv = os.path.join(attr_root, f"{region_name}_node_merged_processed.csv")
    edge_attr_csv = os.path.join(attr_root, f"{region_name}_edge_stats_processed.csv")

    required = (node_shp, edge_shp, node_attr_csv, edge_attr_csv)
    if not all(map(os.path.exists, required)):
        print(f"⚠ 缺失文件，跳过: {region_name}")
        return

    print(f"🚀 正在处理: {region_name}")

    node_gdf, edge_gdf = gpd.read_file(node_shp), gpd.read_file(edge_shp)
    node_attrs, edge_attrs = pd.read_csv(node_attr_csv), pd.read_csv(edge_attr_csv)

    # Fall back to the row index as the join key wherever no `id` column exists.
    for table in (node_attrs, edge_attrs, node_gdf, edge_gdf):
        if 'id' not in table.columns:
            table['id'] = table.index

    plot_node_attributes(region_name, node_gdf, node_attrs)
    plot_edge_attributes(region_name, edge_gdf, edge_attrs)

# Batch run: every sub-directory of base_attr_dir is treated as a region.
for region_name in os.listdir(base_attr_dir):
    if not os.path.isdir(os.path.join(base_attr_dir, region_name)):
        continue
    process_region(region_name)


## 测试集可视化

In [None]:
import geopandas as gpd
import pandas as pd
import matplotlib.pyplot as plt
import os

# ========== Step 1: input paths ==========
prediction_file = r'data/pre_CSV/Shaoxing_Cangqian_binary_predictions.csv'
node_shp_path = r'data/road_network_gen/Shaoxing_Cangqian/Shaoxing_Cangqian_nodes.shp'
edge_shp_path = r'data/road_network_gen/Shaoxing_Cangqian/Shaoxing_Cangqian_edges.shp'


# ========== Step 2: load data ==========
# Prediction table plus the node/edge layers. reset_index(drop=True) gives
# both layers a clean 0..n-1 index.
pred_df = pd.read_csv(prediction_file)
node_gdf, edge_gdf = (
    gpd.read_file(path).reset_index(drop=True)
    for path in (node_shp_path, edge_shp_path)
)
# Building footprints, reprojected onto the edge layer's CRS so they overlay.
buildings = gpd.read_file(
    'data/road_network_cache/Shaoxing_Cangqian/Shaoxing_Cangqian_buildings.shp'
).to_crs(edge_gdf.crs)
# ========== Step 3: add u, v endpoint fields to edge_gdf for matching ==========
# Edge endpoints are matched to nodes by coordinate (rounded to 6 decimals),
# so u/v become positional row indices into node_gdf. This does not rely on
# edge_gdf and node_gdf sharing any row order. If the edge layer already
# carries usable source/target columns, those could be used directly instead.
# NOTE(review): assumes the prediction CSV numbers nodes by the same row order
# as node_gdf; confirm.

# Lookup table: rounded (x, y) -> row position in node_gdf (last one wins on duplicates).
node_coords = node_gdf.geometry.apply(lambda pt: (round(pt.x, 6), round(pt.y, 6)))
coord_to_index = dict(zip(node_coords, range(len(node_coords))))

def get_uv_from_line(line):
    """Return the node positions of a line's first and last vertices.

    Each endpoint is rounded to 6 decimals and looked up in the module-level
    ``coord_to_index``. An endpoint with no matching node yields -1.

    Returns:
        pd.Series with keys ``'u'`` (start node) and ``'v'`` (end node).
    """
    def lookup(xy):
        key = (round(xy[0], 6), round(xy[1], 6))
        return coord_to_index.get(key, -1)

    vertices = line.coords
    return pd.Series({'u': lookup(vertices[0]), 'v': lookup(vertices[-1])})

edge_gdf[['u', 'v']] = edge_gdf.geometry.apply(get_uv_from_line)

# ========== Step 4: merge prediction results ==========
# Key every edge by its sorted endpoint pair so that (u, v) and (v, u) match.
# Built with zip instead of row-wise DataFrame.apply. This is much faster,
# keeps the node ids' own dtype (a mixed int/float row would upcast them to
# float), and also works when a table is empty.
pred_df['key'] = [tuple(sorted(pair)) for pair in zip(pred_df['source_node'], pred_df['target_node'])]
edge_gdf['key'] = [tuple(sorted(pair)) for pair in zip(edge_gdf['u'], edge_gdf['v'])]

# If the CSV lists an undirected edge in both directions, a plain left merge
# would duplicate that edge's geometry and could draw it twice with
# conflicting colours. Keep one row per key, preferring the highest value.
pred_unique = (
    pred_df[['key', 'prediction_probability']]
    .sort_values('prediction_probability', ascending=False)
    .drop_duplicates('key')
)
merged = edge_gdf.merge(pred_unique, on='key', how='left')

# ========== Step 5: visualisation ==========
fig, ax = plt.subplots(figsize=(10, 10))

# NaN (no matching prediction) compares unequal to 1, so it falls into the
# "0 or Missing" group.
is_positive = merged['prediction_probability'] == 1

# Edges predicted as 1: thick black lines.
merged[is_positive].plot(ax=ax, color='black', linewidth=2, label='Predicted: 1')

# Edges predicted as 0 or with no prediction: thin white lines.
# NOTE(review): white may be invisible on the default white background; confirm intended.
merged[~is_positive].plot(ax=ax, color='white', linewidth=1, label='Predicted: 0 or Missing')

# Nodes in blue, then building footprints drawn on top as a translucent grey layer.
node_gdf.plot(ax=ax, color='blue', markersize=30)
buildings.plot(ax=ax, color="grey", alpha=0.5)

ax.set_axis_off()
ax.set_title("Edge Prediction Visualization from GNN", fontsize=14)
ax.legend()
plt.tight_layout()
plt.savefig('predicted_edge_visualization.png', dpi=300)
plt.show()
