In [None]:
import xml.etree.ElementTree as ET
import numpy as np
import networkx as nx
import plotly.graph_objects as go
import math

# 定义 ReferenceLine 类，用于保存参考线信息
class ReferenceLine:
    """Holds the sampled reference line of one road.

    Every attribute starts as its own empty list and is filled in later by the
    reference-line parser.
    """

    def __init__(self):
        # sampled_points: (x, y) sample points along the reference line
        # headings:       local tangent angle (radians) at each sample point
        # geometries:     raw <geometry> parameter dicts in order (s, x, y, hdg,
        #                 length, type, plus arc/spiral specific parameters)
        # s_values:       the s value of each sample point
        self.sampled_points, self.headings, self.geometries, self.s_values = [], [], [], []
        
# 定义 Lane 类
class Lane:
    """A single lane's sampled centre line plus its lane-level links.

    Args:
        lane_id: Lane number; stored as a string.
        lane_type: Side of the road the lane is on, "left" or "right".
        sampled_points: List of (x, y) points used to draw the lane centre line.
        in_range: Sensing-range flag; defaults to False.
    """

    def __init__(self, lane_id, lane_type, sampled_points, in_range=False):
        self.lane_id, self.lane_type = str(lane_id), lane_type
        self.sampled_points = sampled_points
        self.in_range = in_range
        # Linked lanes as lists of (road_id, lane_id) tuples.
        self.predecessor, self.successor = [], []
        
# 定义 Road 类
class Road:
    """A parsed road: road-level links, driving lanes and reference line.

    Args:
        road_id: Road identifier; stored as a string.
        predecessor: Road-level predecessor, e.g. (elementType, elementId), or None.
        successor: Road-level successor, e.g. (elementType, elementId, contactPoint), or None.
        junction: Junction id; "-1" means the road is not inside a junction. Stored as a string.
        length: Road length (float).
        on_route: Flag intended for global route planning; defaults to False.
    """

    def __init__(self, road_id, predecessor, successor, junction, length, on_route=False):
        self.road_id = str(road_id)
        self.predecessor = predecessor
        self.successor = successor
        self.junction = str(junction)
        self.on_route = on_route
        # Lane objects (only driving lanes are parsed into this list).
        self.lanes = []
        self.length = length
        # Mean of the lanes' middle sample points; (0, 0) until compute_midpoint() runs.
        self.midpoint = (0, 0)
        self.reference_line = ReferenceLine()

    def compute_midpoint(self):
        """Set self.midpoint to the mean of each non-empty lane's middle sample point.

        Falls back to (0, 0) when no lane has sample points.
        """
        middles = [
            lane.sampled_points[len(lane.sampled_points) // 2]
            for lane in self.lanes
            if lane.sampled_points
        ]
        self.midpoint = tuple(np.mean(np.array(middles), axis=0)) if middles else (0, 0)
            
def composite_simpson(f, a, b, n):
    """Approximate the integral of ``f`` over [a, b] with composite Simpson's rule.

    Args:
        f: Vectorised callable; it is called once with a NumPy array of the
            n + 1 sample abscissae and must return an array of the same length.
        a: Lower integration bound.
        b: Upper integration bound (a == b yields 0).
        n: Number of sub-intervals. Simpson's rule needs an even count of at
            least 2, so an odd n is bumped up by one and n < 2 is raised to 2.

    Returns:
        The approximate integral.
    """
    # Simpson's rule needs at least two sub-intervals; n <= 0 previously
    # crashed with ZeroDivisionError / a negative linspace size.
    n = max(n, 2)
    if n % 2 == 1:
        n += 1
    h = (b - a) / n
    x_vals = np.linspace(a, b, n + 1)
    y_vals = f(x_vals)
    # Weights 1, 4, 2, 4, ..., 2, 4, 1: odd interior nodes get 4, even interior nodes get 2.
    S = y_vals[0] + y_vals[-1] + 4 * np.sum(y_vals[1:-1:2]) + 2 * np.sum(y_vals[2:-2:2])
    return h / 3 * S

def offset_poly(s, s_offset, a, b, c, d):
    """Evaluate the cubic a + b*ds + c*ds**2 + d*ds**3 where ds = s - s_offset.

    Works element-wise when ``s`` is a NumPy array.
    """
    ds = s - s_offset
    return a + b*ds + c*ds**2 + d*ds**3

def _parse_geometry(g):
    """Convert one <geometry> element into a dict with s, x, y, hdg, length, type
    (plus 'k' for arcs, 'curvStart'/'curvEnd' for spirals)."""
    seg = {key: float(g.get(key, '0')) for key in ('s', 'x', 'y', 'hdg', 'length')}
    arc_tag = g.find('arc')
    spiral_tag = g.find('spiral')
    if arc_tag is not None:
        seg['type'] = 'arc'
        seg['k'] = float(arc_tag.get('curvature'))
    elif spiral_tag is not None:
        seg['type'] = 'spiral'
        seg['curvStart'] = float(spiral_tag.get('curvStart'))
        seg['curvEnd'] = float(spiral_tag.get('curvEnd'))
    else:
        # Anything that is neither <arc> nor <spiral> is treated as a straight line.
        seg['type'] = 'line'
    return seg


def _eval_geometry(seg, ds):
    """Return (x, y, heading) at local distance ds from the start of segment seg."""
    x0, y0, hdg0 = seg['x'], seg['y'], seg['hdg']
    if seg['type'] == 'arc':
        k = seg['k']
        if abs(k) > 1e-8:
            x_val = x0 + (np.sin(hdg0 + k*ds) - np.sin(hdg0)) / k
            y_val = y0 - (np.cos(hdg0 + k*ds) - np.cos(hdg0)) / k
        else:
            # Near-zero curvature: treat as a straight line to avoid dividing by ~0.
            x_val = x0 + ds*np.cos(hdg0)
            y_val = y0 + ds*np.sin(hdg0)
        return x_val, y_val, hdg0 + k*ds
    if seg['type'] == 'spiral':
        curv_start = seg['curvStart']
        length = seg['length']
        # Curvature changes linearly from curvStart to curvEnd, so the heading is
        # quadratic in u. A zero-length spiral has no rate (avoids ZeroDivisionError).
        half_rate = 0.5*(seg['curvEnd'] - curv_start)/length if length > 0 else 0.0

        def theta(u):
            return hdg0 + curv_start*u + half_rate * u**2

        # Position is the integral of (cos theta, sin theta) over [0, ds], done numerically.
        I_x = composite_simpson(lambda u: np.cos(theta(u)), 0, ds, 50)
        I_y = composite_simpson(lambda u: np.sin(theta(u)), 0, ds, 50)
        return x0 + I_x, y0 + I_y, theta(ds)
    # 'line'
    return x0 + ds*np.cos(hdg0), y0 + ds*np.sin(hdg0), hdg0


def compute_reference_line(road_elem, ref_line):
    """Sample a road's reference line from its <planView> geometries.

    The <geometry> children are sorted by s and converted to parameter dicts.
    The road is then sampled at roughly 0.1 spacing in s (at least 2 samples);
    each sample is evaluated on the geometry segment containing it.

    Results are written to ``ref_line`` (all four lists are reset on every call):
        sampled_points: [(x, y), ...] global sample points
        headings:       local tangent angle (radians) at each sample
        s_values:       s of each sample (same length as sampled_points)
        geometries:     ordered geometry dicts (s, x, y, hdg, length, type,
                        plus arc/spiral parameters)

    When there is no <planView>, or it has no <geometry>, a straight line along
    +x of the road's ``length`` attribute is used instead.

    Args:
        road_elem: The <road> XML element.
        ref_line: Object with the attributes listed above (e.g. ReferenceLine).
    """
    # Reset everything, including s_values, so repeated calls don't accumulate.
    ref_line.sampled_points = []
    ref_line.headings = []
    ref_line.geometries = []
    ref_line.s_values = []

    plan_view = road_elem.find('planView')
    geoms = plan_view.findall('geometry') if plan_view is not None else []
    if not geoms:
        # Default straight reference line. s_values must be filled too, otherwise
        # lane offsets (evaluated over s_values) produce empty lanes.
        length = float(road_elem.get('length', '0'))
        num = max(int(length/0.1), 2)
        for s in np.linspace(0, length, num=num):
            ref_line.sampled_points.append((s, 0))
            ref_line.headings.append(0.0)
            ref_line.s_values.append(s)
        ref_line.geometries.append({'s': 0, 'x': 0, 'y': 0, 'hdg': 0.0, 'length': length, 'type': 'line'})
        return

    geoms.sort(key=lambda g: float(g.get('s', '0')))
    ref_line.geometries = [_parse_geometry(g) for g in geoms]

    last = ref_line.geometries[-1]
    road_length = last['s'] + last['length']
    num = max(int(road_length/0.1), 2)
    for s in np.linspace(0, road_length, num=num):
        # First segment whose closed interval [s0, s0 + length] contains s.
        seg = next((g for g in ref_line.geometries if g['s'] <= s <= g['s'] + g['length']), None)
        if seg is None:
            # No covering segment: clamp to the end of the last segment.
            seg = last
            ds = seg['length']
        else:
            ds = s - seg['s']
        x_val, y_val, hdg_val = _eval_geometry(seg, ds)
        ref_line.sampled_points.append((x_val, y_val))
        ref_line.headings.append(hdg_val)
        ref_line.s_values.append(s)
        
# Parse the driving lanes of a road's <lanes> element and store them in road_obj.lanes (nothing is returned).
def _lane_width_params(lane_elem):
    """Return (sOffset, a, b, c, d) of the lane's first <width> record, or all zeros if absent."""
    width_elem = lane_elem.find('width')
    if width_elem is None:
        return 0.0, 0.0, 0.0, 0.0, 0.0
    return tuple(float(width_elem.get(key, '0')) for key in ('sOffset', 'a', 'b', 'c', 'd'))


def _junction_lane_links(junction_dict, junction_id, road_id, lane_id):
    """Collect (connectingRoad, to-lane) pairs for lane_id of road_id inside one junction.

    Every <connection> whose incomingRoad is road_id contributes its first
    <laneLink> whose 'from' equals lane_id. Unknown junction ids yield [].
    """
    junction = junction_dict.get(junction_id)
    if junction is None:
        return []
    links = []
    for conn in junction.findall('connection'):
        if conn.get('incomingRoad') != road_id:
            continue
        for link in conn.findall('laneLink'):
            if link.get('from') == lane_id:
                links.append((conn.get('connectingRoad'), link.get('to')))
                break
    return links


def _lane_neighbours(lane_elem, road_link, link_tag, junction_dict, road_id):
    """Resolve one side ('predecessor' or 'successor') of a lane's links.

    road_link is the road-level link tuple for that same side (or None).
    If it points to a road, the lane's own <link>/<link_tag> id is paired with
    that road id. If it points to a junction, the junction's connections are
    searched instead.
    """
    if road_link is None:
        return []
    element_type, element_id = road_link[0], road_link[1]
    if element_type == 'road':
        link_elem = lane_elem.find('link')
        if link_elem is None:
            return []
        target = link_elem.find(link_tag)
        return [] if target is None else [(element_id, target.get('id'))]
    if element_type == 'junction':
        return _junction_lane_links(junction_dict, element_id, road_id, lane_elem.get('id'))
    return []


def _collect_side_lanes(side_elem, road_obj, junction_dict, descending):
    """Read every <lane> of a <left>/<right> element, sorted by numeric id.

    Each entry is (lane_id, width_params, lane_type, predecessors, successors).
    """
    if side_elem is None:
        return []
    infos = []
    for lane in side_elem.findall('lane'):
        infos.append((
            int(lane.get('id')),
            _lane_width_params(lane),
            lane.get('type'),
            _lane_neighbours(lane, road_obj.predecessor, 'predecessor', junction_dict, road_obj.road_id),
            _lane_neighbours(lane, road_obj.successor, 'successor', junction_dict, road_obj.road_id),
        ))
    infos.sort(key=lambda info: info[0], reverse=descending)
    return infos


def _build_side_lanes(lane_infos, side, s_arr, global_offset, baseline, baseline_headings):
    """Create Lane objects for the driving lanes of one side, stacking all lane widths outward."""
    lanes = []
    cum_width = np.zeros_like(s_arr)
    for lane_id, width_params, lane_type, lane_pred, lane_succ in lane_infos:
        w_current = offset_poly(s_arr, *width_params)
        if lane_type == "driving":
            if side == "left":
                current_offset = global_offset + cum_width + w_current/2.0
            else:
                current_offset = global_offset - (cum_width + w_current/2.0)
            sampled_points = []
            for (x_ref, y_ref), local_hdg, offset_val in zip(baseline, baseline_headings, current_offset):
                # Shift the reference point along the left-pointing normal (-sin, cos).
                sampled_points.append((x_ref - offset_val * np.sin(local_hdg),
                                       y_ref + offset_val * np.cos(local_hdg)))
            lane_obj = Lane(lane_id, side, sampled_points, in_range=False)
            lane_obj.predecessor = lane_pred
            lane_obj.successor = lane_succ
            lanes.append(lane_obj)
        # Non-driving lanes (shoulder, median, ...) still take up width.
        cum_width = cum_width + w_current
    return lanes


def parse_driving_lanes(road_elem, junction_dict, road_obj):
    """Parse a road's <lanes> and store its driving lanes in ``road_obj.lanes``.

    The widths of all lanes (including non-"driving" types such as shoulder or
    median) are accumulated. Lane objects are created only for type "driving".

    Centre-line offsets are evaluated over road_obj.reference_line.s_values and
    applied to its sampled_points / headings:
        left:  offset(s) = global_offset(s) + cum_width_left(s)  + w_current(s)/2
        right: offset(s) = global_offset(s) - (cum_width_right(s) + w_current(s)/2)
    where global_offset comes from the first <laneOffset> and w_current from the
    lane's first <width> record. Only the first <laneSection> is used.

    Lane links:
      * If the road-level predecessor/successor is a road, the lane's own
        <link><predecessor|successor id=..>
        is paired with that road id.
      * If it is a junction, that junction's <connection>s with
        incomingRoad == this road are searched for a <laneLink from=lane id>.
        The result is (connectingRoad, to).
        NOTE(review): connections are matched by incomingRoad only (no
        contactPoint check), so a road whose both ends are in the same junction
        gets identical predecessor and successor lists.
      * A missing road link (None) or an unknown junction id yields no links.

    Args:
        road_elem: The <road> XML element.
        junction_dict: Mapping junction id -> <junction> element.
        road_obj: Road whose reference_line has already been computed.
    """
    ref = road_obj.reference_line
    s_arr = np.array(ref.s_values)
    road_obj.lanes = []

    lanes_elem = road_elem.find('lanes')
    if lanes_elem is None:
        return

    lane_offset_elem = lanes_elem.find('laneOffset')
    if lane_offset_elem is not None:
        offset_params = tuple(float(lane_offset_elem.get(key, '0')) for key in ('s', 'a', 'b', 'c', 'd'))
    else:
        offset_params = (0.0, 0.0, 0.0, 0.0, 0.0)
    global_offset = offset_poly(s_arr, *offset_params)

    lane_section = lanes_elem.find('laneSection')
    if lane_section is None:
        return

    # Left lane ids are positive (ascending outward), right lane ids negative (descending outward).
    left_infos = _collect_side_lanes(lane_section.find('left'), road_obj, junction_dict, descending=False)
    right_infos = _collect_side_lanes(lane_section.find('right'), road_obj, junction_dict, descending=True)

    road_obj.lanes = (
        _build_side_lanes(left_infos, "left", s_arr, global_offset, ref.sampled_points, ref.headings)
        + _build_side_lanes(right_infos, "right", s_arr, global_offset, ref.sampled_points, ref.headings)
    )


# 修改后的 parse_oxdr_all 将参考线提取集成到 Road 类中
def parse_oxdr_all(file_path):
    """Parse an OpenDRIVE (.xodr) file into Road objects keyed by road id.

    For each <road> element:
      1. read the road-level <link> predecessor (elementType, elementId) and
         successor (elementType, elementId, contactPoint);
      2. sample the reference line;
      3. parse driving lanes, taking lane links from the file itself;
      4. compute the road midpoint.

    Args:
        file_path: Path to the .xodr file.

    Returns:
        dict mapping road id (str) to Road.
    """
    root = ET.parse(file_path).getroot()
    # Junction elements by id; used to resolve lane links through junctions.
    junction_dict = {junc.get('id'): junc for junc in root.findall('.//junction')}

    roads = {}
    for road_elem in root.findall('.//road'):
        road_id = road_elem.get('id')
        length = float(road_elem.get('length', '0'))

        predecessor = successor = None
        link = road_elem.find('link')
        if link is not None:
            pred_elem = link.find('predecessor')
            succ_elem = link.find('successor')
            if pred_elem is not None:
                predecessor = (pred_elem.get('elementType'), pred_elem.get('elementId'))
            if succ_elem is not None:
                successor = (succ_elem.get('elementType'), succ_elem.get('elementId'),
                             succ_elem.get('contactPoint'))

        road_obj = Road(road_id, predecessor, successor, road_elem.get('junction', "-1"),
                        length, on_route=False)
        compute_reference_line(road_elem, road_obj.reference_line)
        parse_driving_lanes(road_elem, junction_dict, road_obj)
        road_obj.compute_midpoint()
        roads[road_id] = road_obj
    return roads


def extract_lanes_in_range(roads, current_pos, sensing_range):
    """Flag lanes, and their roads, that come within ``sensing_range`` of ``current_pos``.

    The flags are recomputed from scratch on every call, so lanes that left the
    range since a previous call are cleared:
      * lane.in_range is True iff at least one sampled point lies within the
        (inclusive) Euclidean radius;
      * road.on_route is True iff any of the road's lanes is in range.

    Args:
        roads: Mapping of road id -> Road.
        current_pos: (x, y) centre of the sensing circle.
        sensing_range: Sensing radius, in the same units as the sample points.
    """
    cx, cy = current_pos
    for road in roads.values():
        road_in_range = False
        for lane in road.lanes:
            pts = np.asarray(lane.sampled_points, dtype=float)
            hit = bool(pts.size) and bool(
                np.any(np.hypot(pts[:, 0] - cx, pts[:, 1] - cy) <= sensing_range)
            )
            # Assign (not just set True) so stale flags from an earlier position are reset.
            lane.in_range = hit
            road_in_range = road_in_range or hit
        road.on_route = road_in_range

def generate_circle_points(center, radius, num_points=50):
    """Return x and y coordinate lists tracing a circle, first point repeated at the end.

    Args:
        center: (x, y) circle centre.
        radius: Circle radius.
        num_points: Number of samples over [0, 2*pi], both ends included.

    Returns:
        (xs, ys) as plain Python lists.
    """
    cx, cy = center
    angles = np.linspace(0, 2*np.pi, num_points)
    xs = cx + radius * np.cos(angles)
    ys = cy + radius * np.sin(angles)
    return xs.tolist(), ys.tolist()

def visualize_lanes(roads, current_pos, sensing_range):
    """Plot all driving lanes, the current position and the sensing circle with Plotly.

    Lane colours: blue when lane.in_range is set, orange (#f27a0d) when the
    lane's road is in a junction, grey otherwise. Legend names are
    "R<road>L<lane>", with "J<junction>" appended for junction roads.
    The figure is displayed with fig.show().
    """
    fig = go.Figure()
    for road in roads.values():
        in_junction = road.junction != "-1"
        suffix = f'J{road.junction}' if in_junction else ''
        for lane in road.lanes:
            if not lane.sampled_points:
                continue
            if lane.in_range:
                color = 'blue'
            elif in_junction:
                color = '#f27a0d'
            else:
                color = 'grey'
            xs, ys = zip(*lane.sampled_points)
            fig.add_trace(go.Scatter(
                x=xs, y=ys,
                mode='lines',
                line=dict(color=color, width=2),
                name=f'R{road.road_id}L{lane.lane_id}{suffix}',
            ))

    cx, cy = current_pos
    fig.add_trace(go.Scatter(
        x=[cx], y=[cy],
        mode='markers',
        marker=dict(color='red', size=12, symbol='x'),
        name='Current Position',
    ))
    ring_x, ring_y = generate_circle_points(current_pos, sensing_range)
    fig.add_trace(go.Scatter(
        x=ring_x, y=ring_y,
        mode='lines',
        line=dict(color='red', dash='dash'),
        name='Sensing Range',
    ))
    fig.update_layout(
        title='Driving Lanes Visualization',
        xaxis_title='X',
        yaxis_title='Y',
        legend_title='Legend',
        xaxis=dict(scaleanchor="y", scaleratio=1),
        template="plotly_white",
        width=700,
        height=500,
    )
    fig.show()
    
def build_topology_graph_lanes(roads):
    """Build a lane-level directed topology graph.

    Nodes: one per lane that has sample points, with id "<road_id>_lane_<lane_id>".
    Node attributes are pos (the lane's middle sample point), road_id, lane_id,
    lane_type and junction (of the owning road).

    Edges: for each (road_id, lane_id) tuple in lane.predecessor, an edge
    predecessor -> lane; for each tuple in lane.successor, an edge
    lane -> successor. An edge is added only when the target lane is a node and
    the lane's road has a predecessor / successor link respectively.
    All edges carry weight=0.

    Args:
        roads: Mapping of road id -> Road.

    Returns:
        networkx.DiGraph
    """
    G = nx.DiGraph()
    # (road_id, lane_id) -> node id; built in the same pass that adds the nodes.
    lane_node_dict = {}
    for road in roads.values():
        for lane in road.lanes:
            if not lane.sampled_points:
                continue
            node_id = f"{road.road_id}_lane_{lane.lane_id}"
            pos = lane.sampled_points[len(lane.sampled_points) // 2]
            G.add_node(node_id, pos=pos, road_id=road.road_id, lane_id=lane.lane_id,
                       lane_type=lane.lane_type, junction=road.junction)
            lane_node_dict[(road.road_id, lane.lane_id)] = node_id

    for road in roads.values():
        for lane in road.lanes:
            if not lane.sampled_points:
                continue
            current_node = lane_node_dict.get((road.road_id, lane.lane_id))
            if current_node is None:
                continue

            if road.predecessor is not None:
                for pred in lane.predecessor or []:
                    pred_node = lane_node_dict.get(pred)
                    if pred_node is not None:
                        G.add_edge(pred_node, current_node, weight=0)

            if road.successor is not None:
                for succ in lane.successor or []:
                    succ_node = lane_node_dict.get(succ)
                    if succ_node is not None:
                        G.add_edge(current_node, succ_node, weight=0)
    return G

def build_topology_graph_roads(roads):
    """Build a road-level topology graph, collapsing each junction into one node.

    Steps:
      1. Add a node per road that has at least one lane, keyed by road_id.
         Attributes are pos (Road.midpoint), junction, road_id and
         is_junction_road. Add edges from road-level predecessor/successor
         links whose elementType is "road" and whose target is a node.
         Edge weight is 0.
      2. Replace every group of roads sharing a junction id (!= "-1") with one
         node "Junction_<id>". Its pos is the mean of the members' positions,
         and it carries junction and is_junction_road=True. Every edge is
         re-targeted to the group nodes. Edges inside one group are dropped, and
         edges between two different junction groups link the two group nodes.

    Args:
        roads: Mapping of road id -> Road.

    Returns:
        networkx.DiGraph: the simplified graph.
    """
    G = nx.DiGraph()
    for road in roads.values():
        # Only roads with driving lanes become nodes.
        if not road.lanes:
            continue
        G.add_node(road.road_id,
                   pos=road.midpoint,
                   junction=road.junction,
                   road_id=road.road_id,
                   is_junction_road=(road.junction != "-1"))

    for road in roads.values():
        if road.road_id not in G:
            continue
        if road.predecessor is not None:
            etype, eid = road.predecessor
            if etype == "road" and eid in G:
                G.add_edge(eid, road.road_id, weight=0)
        if road.successor is not None:
            etype, eid, _contact = road.successor
            if etype == "road" and eid in G:
                G.add_edge(road.road_id, eid, weight=0)

    # Group junction roads and map each original node to its representative node.
    groups = {}
    group_of = {}
    representative = {}
    for node, data in G.nodes(data=True):
        junc = data.get("junction", "-1")
        if junc == "-1":
            representative[node] = node
        else:
            groups.setdefault(junc, []).append(node)
            group_of[node] = junc
            representative[node] = f"Junction_{junc}"

    simplified_G = nx.DiGraph()
    for node, data in G.nodes(data=True):
        if node not in group_of:
            simplified_G.add_node(node, **data)
    for junc, nodes in groups.items():
        positions = [np.array(G.nodes[n]['pos']) for n in nodes if 'pos' in G.nodes[n]]
        center = tuple(np.mean(positions, axis=0)) if positions else (0, 0)
        simplified_G.add_node(f"Junction_{junc}", pos=center, junction=junc, is_junction_road=True)

    for u, v, d in G.edges(data=True):
        # Drop edges internal to one junction group.
        if u in group_of and group_of.get(v) == group_of[u]:
            continue
        # Mapping both ends keeps every edge endpoint a node with a 'pos'
        # (previously cross-junction edges pointed at raw, position-less nodes).
        simplified_G.add_edge(representative[u], representative[v], **d)
    return simplified_G

def visualize_topology_combined(roads, detailed=True):
    """Draw a topology graph with Plotly.

    detailed=True draws the lane-level graph: one node per lane at its middle
    sample point, edges from lane predecessor/successor links, and labels
    "<road_id>_<lane_id>".
    detailed=False draws the road-level graph: nodes at Road.midpoint, with
    junction roads aggregated, labelled by node id.
    Nodes whose junction attribute is not "-1" (including aggregated junction
    nodes) are orange; all others are sky blue. The figure is displayed with
    fig.show().
    """
    graph = build_topology_graph_lanes(roads) if detailed else build_topology_graph_roads(roads)
    positions = nx.get_node_attributes(graph, 'pos')

    # Edge polyline; None separates consecutive segments.
    edge_x, edge_y = [], []
    for src, dst in graph.edges():
        (x0, y0), (x1, y1) = positions[src], positions[dst]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=1, color='#888'),
        hoverinfo='none',
        mode='lines',
    )

    node_x, node_y, node_text, node_color = [], [], [], []
    for node, attrs in graph.nodes(data=True):
        x, y = positions[node]
        node_x.append(x)
        node_y.append(y)
        node_text.append(f"{attrs['road_id']}_{attrs['lane_id']}" if detailed else str(node))
        node_color.append('skyblue' if attrs.get('junction', "-1") == "-1" else 'orange')

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode='markers+text',
        text=node_text,
        textposition="top center",
        hoverinfo='text',
        marker=dict(
            showscale=False,
            color=node_color,
            size=20,
            line_width=2,
        ),
    )

    title_text = 'Topology Graph' if detailed else 'Topology Graph (Simplify)'
    layout = go.Layout(
        title=dict(text=title_text, font=dict(size=16)),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        xaxis=dict(scaleanchor="y", scaleratio=1, showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(scaleanchor="x", scaleratio=1, showgrid=False, zeroline=False, showticklabels=False),
        width=700,
        height=500,
    )
    fig = go.Figure(data=[edge_trace, node_trace], layout=layout)
    fig.show()

In [None]:
#%%
# ## Example usage
#
# Parse an OpenDRIVE (.xodr) map from disk with parse_oxdr_all, mark the lanes
# near a given position, and plot the lanes and topology graphs.
#
# To try another map, swap file_path for one of the commented-out alternatives.

# Map file to parse (alternative maps kept for convenience).
# file_path = "./road_map/parking1.xodr"
# file_path = "./road_map/parking2.xodr"
# file_path = "./road_map/parking3.xodr"
file_path = "./road_map/Town10HD_Opt.xodr"
roads = parse_oxdr_all(file_path)

# Current position and sensing range
current_position = (0.0, 0.0)  # (x, y) in map coordinates; here the map origin
sensing_range = 30.0  # sensing radius, in map units (presumably metres)

# Flag lanes with at least one sample point within sensing_range of current_position
extract_lanes_in_range(roads, current_position, sensing_range)

# Plot lanes, current position and sensing circle
visualize_lanes(roads, current_position, sensing_range)


# Optional: plot the lane-level and the simplified road-level topology graphs
visualize_topology_combined(roads, detailed=True)
visualize_topology_combined(roads, detailed=False)