In [1]:
import sys
sys.path.append("path/to/utils")
import utils.congestion_track
import utils.data_loader
import utils.visualization
import utils.process_func
import utils.data_save

import json
import ast
import re
from datetime import datetime

from langchain.tools import Tool
from langchain.chat_models import ChatOpenAI
from langchain.agents import initialize_agent, AgentType


import pandas as pd
import geopandas as gpd
from shapely import wkt
from shapely.geometry import Point, Polygon, LineString, MultiLineString, MultiPoint, GeometryCollection
import statistics

import importlib
# Re-import each project helper module so that edits made on disk after the
# first import are picked up without restarting the kernel.
importlib.reload(utils.congestion_track)  # reload the edited module
importlib.reload(utils.data_loader)  # reload the edited module
importlib.reload(utils.visualization)  # reload the edited module
importlib.reload(utils.process_func)  # reload the edited module
importlib.reload(utils.data_save)  # reload the edited module



<module 'utils.data_save' from 'D:\\myJupyter\\paper2\\application\\utils\\data_save.py'>

In [None]:
from py2neo import Node, Relationship, Graph, NodeMatcher, RelationshipMatcher, Subgraph

# Shared Neo4j handle used by every query tool below: Bolt endpoint on
# localhost, database named 'trkg'. The credentials here are placeholders.
g = Graph('bolt://localhost:7687',auth=('neo4j','your password'), name = 'trkg')

In [None]:
# Initialise the chat model shared by all tools and by the agent.
# temperature=0 is set explicitly; api_key and base_url are placeholders.
llm = ChatOpenAI(model="gpt-4o", 
    api_key="your api",
    base_url="your url",
    temperature=0)

  llm = ChatOpenAI(model="gpt-4o",


In [4]:
# 提取目标类型
def extract_key_information(task_description):
    """Ask the LLM which supported target type best matches the user's question.

    Args:
        task_description: The user's natural-language question.

    Returns:
        The value stored under "target_type" when the parsed LLM reply is a
        dict ([] if that key is absent), the parsed reply itself when it is a
        list, an {"error": ...} dict for any other parsed type, or
        {"target_type": "未知"} when parsing raises json.JSONDecodeError.
    """
    target_types = ["道路查询", "区域拥堵状况", "道路拥堵状况"]
    prompt = f"""
        用户的问题是：{task_description}
        请从以下目标类型中必须选择最符合的一个：{target_types}
        确定后，只返回：{{"target_type": ...}}。
        """
    llm_reply = llm.predict(prompt)

    try:
        parsed = utils.process_func.parse_query_results(llm_reply)
    except json.JSONDecodeError:
        return {"target_type": "未知"}

    # Dispatch on the shape of the parsed reply.
    if isinstance(parsed, dict):
        return parsed.get("target_type", [])
    if isinstance(parsed, list):
        return parsed
    return {"error": "query_results 不是有效的格式"}


# POI查询
def query_poi_fulltext_index(task_description):
    """Extract a place name from the question and look it up in the POI fulltext index.

    The LLM is asked for a concrete location; that text is then searched in
    the Neo4j fulltext index "poi_fulltext_index" and the ten best-scoring
    nodes are returned.

    Args:
        task_description: The user's natural-language question.

    Returns:
        {"location": ..., "poi_results": [...]} where each result row has the
        keys poi_id, poi_name and score. "poi_results" is empty when no usable
        location was extracted. An {"error": ...} dict is returned when the
        LLM reply cannot be parsed.
    """
    response = llm.predict(
        f"""
        用户的问题是：{task_description}
        请提取问题中具体的地点信息，必须是具体的建筑、商场、地铁站、景点、或街道名称，不要使用'上海'。只返回： {{"location": ...}}。
        """
    )

    try:
        query_results = utils.process_func.parse_query_results(response)
    except json.JSONDecodeError:
        return {"error": "解析地点信息失败"}

    if isinstance(query_results, list):
        location = query_results
    elif isinstance(query_results, dict):
        location = query_results.get("location", "无地点信息")
    else:
        return {"error": "query_results 不是有效的格式"}

    # Nothing specific to search for: skip the database round trip. An empty
    # or missing location is treated the same as the explicit sentinels.
    if not location or location == "无地点信息" or location == "上海":
        return {"location": location, "poi_results": []}

    # The search text originates from LLM output, so it is passed as a query
    # parameter rather than spliced into the Cypher string: a stray quote can
    # no longer break (or inject into) the statement.
    search_text = location if isinstance(location, str) else str(location)
    query = """
    CALL db.index.fulltext.queryNodes("poi_fulltext_index", $search_text)
    YIELD node, score
    RETURN node.name AS poi_id, node.poi_name AS poi_name, score
    ORDER BY score DESC
    LIMIT 10
    """
    result = g.run(query, search_text=search_text).data()
    return {"location": location, "poi_results": result}

# 查询 POI 的最佳匹配
def find_best_match(query_results):
    """Pick the POI that best matches the user's keyword.

    The candidate list is shown to the LLM, which is asked to return one POI
    as JSON. If the reply cannot be used, the candidate with the highest
    "score" is returned instead.

    Args:
        query_results: Either a list of POI rows or a dict with the keys
            "location" and "poi_results" (raw or serialized; it is passed
            through utils.process_func.parse_query_results first).

    Returns:
        A dict describing the chosen POI, or an {"error": ...} dict when the
        input has an unsupported shape or contains no candidates.
    """
    query_results = utils.process_func.parse_query_results(query_results)

    if isinstance(query_results, list):
        keyword = "未知关键词"
        poi_results = query_results  # the list itself is the candidate set
    elif isinstance(query_results, dict):
        keyword = query_results.get("location", "未知地点")
        poi_results = query_results.get("poi_results", [])
    else:
        return {"error": "query_results 不是有效的格式"}

    if not poi_results:
        return {"error": "未找到相关 POI 信息"}

    response = llm.predict(
        f"""
        用户的关键词是：{keyword}
        下面是查询到的 POI 结果：
        {json.dumps(poi_results, ensure_ascii=False)}

        请综合 `score` 分数和 POI 与问题的相关性，选出最优 POI。
        只返回 JSON 格式：{{"poi_id": ..., "poi_name": ..., "score": ...}}
        """
    )

    # Chat models frequently wrap JSON in code fences or a sentence. Try the
    # raw reply first, then the outermost {...} span, before giving up on the
    # model's choice.
    best_poi = None
    try:
        best_poi = json.loads(response)
    except json.JSONDecodeError:
        embedded = re.search(r"\{.*\}", response, re.DOTALL)
        if embedded:
            try:
                best_poi = json.loads(embedded.group(0))
            except json.JSONDecodeError:
                best_poi = None

    if isinstance(best_poi, dict) and "poi_id" in best_poi and "poi_name" in best_poi:
        return best_poi

    # Fallback: highest score wins. If the rows carry no comparable score,
    # return the first candidate instead of raising.
    try:
        return max(poi_results, key=lambda poi: poi["score"])
    except (KeyError, TypeError):
        return poi_results[0]

# 查询 honeycomb 关系
def query_honeycomb_contains(best_poi):
    """Find the honeycomb cell(s) that contain the given POI.

    Args:
        best_poi: A dict with a "poi_id" key, or a serialized form of it (the
            value is passed through utils.process_func.parse_query_results).

    Returns:
        The rows returned by Neo4j (each row has the key "n.name"), or an
        {"error": ...} dict when the input is unusable or the query fails.
    """
    query_results = utils.process_func.parse_query_results(best_poi)

    if isinstance(query_results, list):
        poi_id = query_results
    elif isinstance(query_results, dict):
        poi_id = query_results.get("poi_id")
    else:
        return {"error": "query_results 不是有效的格式"}

    if poi_id is None:
        # Previously the placeholder text was sent to the database as query
        # text and only surfaced as a Cypher error.
        return {"error": "未知poi_id"}

    # The id used to be spliced into the query unquoted, so numeric text was
    # compared as a number. Keep that comparison type while moving the value
    # into a query parameter.
    if isinstance(poi_id, str):
        text = poi_id.strip()
        try:
            poi_id = int(text)
        except ValueError:
            try:
                poi_id = float(text)
            except ValueError:
                poi_id = text

    try:
        query = """
        MATCH (n:honeycomb)-[r:contains]->(m:POI)
        WHERE m.name = $poi_id
        RETURN n.name
        """
        return g.run(query, poi_id=poi_id).data()
    except Exception as e:
        return {"error": str(e)}

# 通过 honeycomb 查询相关道路
def query_honeycomb_within_and_neighbors_roads(honeycomb_id, depth=2):
    """Collect the roads inside a honeycomb cell and inside its neighbouring cells.

    Args:
        honeycomb_id: Identifier of the target honeycomb (compared against the
            node's `name` property). Numeric text is converted to a number.
        depth: Maximum number of `adjacency` hops to follow when collecting
            neighbouring cells (default 2). Must be an integer >= 1.

    Returns:
        A list of dicts with the keys honeycomb_id, road_name and road_id,
        de-duplicated on road_id (first occurrence kept); an empty list when
        no road is found; or an {"error": ...} dict if anything fails.
    """
    try:
        # `depth` ends up inside the variable-length pattern, which cannot be
        # a query parameter, so it must be a validated integer.
        depth = int(depth)
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")

        # The id used to be spliced in unquoted (i.e. compared as a number);
        # keep that comparison type while passing it as a parameter.
        if isinstance(honeycomb_id, str):
            text = honeycomb_id.strip()
            try:
                honeycomb_id = int(text)
            except ValueError:
                try:
                    honeycomb_id = float(text)
                except ValueError:
                    honeycomb_id = text

        # Roads lying within the cell itself.
        query_self = """
        MATCH (n:honeycomb)<-[:within]-(r:road)
        WHERE n.name = $honeycomb_id
        RETURN n.name AS honeycomb_id, r.roadname AS road_name, r.name AS road_id
        """
        self_roads = g.run(query_self, honeycomb_id=honeycomb_id).data()

        # Roads lying within any cell reachable in 1..depth adjacency hops.
        query_neighbors = f"""
        MATCH (n:honeycomb)-[:adjacency*1..{depth}]-(neighbor:honeycomb)
        WHERE n.name = $honeycomb_id
        MATCH (neighbor)<-[:within]-(r:road)
        RETURN neighbor.name AS honeycomb_id, r.roadname AS road_name, r.name AS road_id
        """
        neighbor_roads = g.run(query_neighbors, honeycomb_id=honeycomb_id).data()

        final_result = self_roads + neighbor_roads
        if not final_result:
            # An empty frame has no 'road_id' column, so drop_duplicates would
            # raise and the empty result would be reported as an error.
            return []

        df = pd.DataFrame(final_result)
        df = df.drop_duplicates(subset=["road_id"])  # one row per road
        return df.to_dict(orient="records")

    except Exception as e:
        return {"error": str(e)}

# 提取时间范围
def extract_time_range(task_description):
    """Ask the LLM to extract the time range mentioned in the user's question.

    Args:
        task_description: The user's natural-language question.

    Returns:
        A dict parsed from the LLM reply, expected to carry "start_time" and
        "end_time". When the reply contains no JSON object, or the object
        cannot be decoded, both keys are set to the sentinel "无时间范围".
    """
    no_range = {"start_time": "无时间范围", "end_time": "无时间范围"}

    prompt = f"""
    你是一个专业的信息提取助手。
    用户的问题是：{task_description}

    请从用户的问题中提取时间范围，包括 'start_time' 和 'end_time'。
    时间范围的格式为 仅时间（如 17:00:00），或带日期（如 2025-03-24 17:00:00）。
    如果未明确指定时间范围，则返回 {{"start_time": "无时间范围", "end_time": "无时间范围"}}。

    只需要返回 JSON 格式：
    {{"start_time": "<开始时间>", "end_time": "<结束时间>"}}
    """

    response = llm.predict(prompt)

    # The reply may carry text around the JSON; take the outermost {...} span.
    match = re.search(r"\{.*\}", response, re.DOTALL)
    if not match:
        # Previously this raised ValueError, which the JSONDecodeError handler
        # did not cover; a reply without JSON now means "no time range", the
        # same as an undecodable one.
        return no_range

    try:
        time_info = json.loads(match.group(0))
    except json.JSONDecodeError:
        return no_range

    return time_info

# 筛选符合条件的 batches
def filter_batches_by_time(task_description):
    """Return the ids of the batches that fall inside the time range the user asked for.

    The time range is obtained from extract_time_range(); the batch table is
    loaded from a fixed CSV path. If both bounds include a date the filter
    uses the 'start_datetime' / 'end_datetime' columns, otherwise it uses the
    'start_time_only' / 'end_time_only' columns.

    Args:
        task_description (str): The user's question, including the time range.

    Returns:
        list: Values of the 'batch' column for the matching rows; an empty
        list when no time range was given or when any step raises.
    """
    try:
        time_info = extract_time_range(task_description)
        start_time_str = time_info.get("start_time", "无时间范围")
        end_time_str = time_info.get("end_time", "无时间范围")

        # Without both bounds there is nothing to filter on.
        if "无时间范围" in (start_time_str, end_time_str):
            print("未指定有效的时间范围，无法筛选批次。")
            return []

        batch_data = utils.data_loader.CSV_load(r'F:\paper2\result\batch_time(..].csv')

        has_date = (
            utils.process_func.is_datetime_with_date(start_time_str)
            and utils.process_func.is_datetime_with_date(end_time_str)
        )

        if has_date:
            # Full timestamps: compare against the datetime columns.
            stamp_format = "%Y-%m-%d %H:%M:%S"
            start_time = datetime.strptime(start_time_str, stamp_format)
            if end_time_str.endswith("00:00:00"):
                # An end bound of midnight is read as the end of that day.
                end_day = end_time_str.split()[0]
                end_time = datetime.strptime(f"{end_day} 23:59:59", stamp_format)
            else:
                end_time = datetime.strptime(end_time_str, stamp_format)

            starts_in_range = batch_data['start_datetime'] >= start_time
            ends_in_range = batch_data['end_datetime'] <= end_time
            filtered_batches = batch_data[starts_in_range & ends_in_range]
        else:
            # Clock times only: compare against the time-of-day columns.
            clock_format = "%H:%M:%S"
            start_time = datetime.strptime(start_time_str, clock_format).time()
            end_time = datetime.strptime(end_time_str, clock_format).time()
            if end_time == datetime.strptime("00:00:00", clock_format).time():
                # An end bound of midnight is read as the end of the day.
                end_time = datetime.strptime("23:59:59", clock_format).time()

            # Optional experiment hooks (disabled in the original): before the
            # comparison, batch_data can be narrowed on 'start_date_only' to
            # exclude one date, keep only dates up to a cutoff, or keep a
            # date interval.

            starts_in_range = batch_data['start_time_only'] >= start_time
            ends_in_range = batch_data['end_time_only'] <= end_time
            if start_time > end_time:
                # The window wraps past midnight, so either side qualifies.
                filtered_batches = batch_data[starts_in_range | ends_in_range]
            else:
                filtered_batches = batch_data[starts_in_range & ends_in_range]

        print(f"✅ [DEBUG] start_time: {start_time}, end_time: {end_time}")
        print(f"✅ [DEBUG] filtered_batches:\n{filtered_batches}")

        return filtered_batches['batch'].to_list()
    except Exception as e:
        print(f"❌ [ERROR] 筛选批次时出错：{e}")
        return []
    
def query_trajectory_points_by_honeycomb(args):
    """Summarise grid-level congestion for a set of honeycomb cells and batches.

    Trajectory points are fetched from Neo4j for every requested honeycomb,
    grouped into segments by utils.process_func.calculate_df_summary, and the
    congestion levels of the segments are merged per (honeycomb, batch). The
    most frequent level of each group is kept, the result is plotted, and the
    LLM is asked to phrase it as a natural-language summary.

    Args:
        args: A dict (or a serialized dict; it is passed through
            utils.process_func.parse_query_results) with the keys
            "honeycomb_ids_ls", "filtered_batches_ls", "road_ids_ls" and
            "poi_ids_ls".

    Returns:
        The LLM's summary text on success; an {"error": ...} dict when `args`
        does not parse to a dict; an empty DataFrame when no honeycomb id is
        given, when the query returns nothing, or when any step raises.
    """
    # Imported here because the file header does not import Counter.
    from collections import Counter

    print(f"args：{args}")
    try:
        # Batch time table and the CRS used for the geometries.
        file_path = r'F:\paper2\result\batch_time(..].csv'
        batch_data_df = utils.data_loader.CSV_load(file_path)
        crs_path = r"F:/paper2/result/roadcrs.txt"
        roadcrs = utils.data_loader.CRS_load(crs_path)

        query_results = utils.process_func.parse_query_results(args)

        if isinstance(query_results, dict):
            honeycomb_ids_ls = query_results.get("honeycomb_ids_ls", [])
            filtered_batches_ls = query_results.get("filtered_batches_ls", [])
            road_ids_ls = query_results.get("road_ids_ls", [])
            poi_ids_ls = query_results.get("poi_ids_ls", [])
        else:
            return {"error": "query_results 不是有效的格式"}
        print(f"honeycomb_ids_ls：{honeycomb_ids_ls}")
        print(f"filtered_batches_ls：{filtered_batches_ls}")
        print(f"road_ids_ls：{road_ids_ls}")
        print(f"poi_ids_ls：{poi_ids_ls}")

        # Nothing to query without a honeycomb id.
        if not honeycomb_ids_ls:
            return pd.DataFrame(columns=['batch', 'grid_congestion_level', 'datetime_range'])

        # Render the batch ids as a quoted, comma-separated Cypher list body.
        # NOTE(review): ids and batches are spliced into the query text;
        # consider query parameters if these values can be untrusted.
        filtered_batches_str = ', '.join(f"'{batch}'" for batch in filtered_batches_ls)
        all_results = []

        for honeycomb_name in honeycomb_ids_ls:
            query = f"""
            MATCH (tp:trajectory_point)
            WHERE tp.batch IN [{filtered_batches_str}]
            AND tp.grid_congestion_level IS NOT NULL
            AND tp.honeycomb_name = {honeycomb_name}
            RETURN DISTINCT tp.batch AS batch,
                            tp.grid_congestion_level AS grid_congestion_level,
                            tp.datetime AS datetime, tp.geometry AS geometry,
                            tp.vehicle AS vehicle, tp.passenger AS passenger_state, tp.honeycomb_name AS honeycomb_name
            """

            result = g.run(query).data()
            all_results.extend(result)

        if not all_results:
            return pd.DataFrame(columns=['batch', 'grid_congestion_level', 'datetime_range'])

        df = pd.DataFrame(all_results)

        # Keep only the time of day (HH:MM:SS) for sorting and display.
        if 'datetime' in df.columns:
            df['datetime'] = pd.to_datetime(df['datetime'], errors='coerce').dt.strftime('%H:%M:%S')
        # Order points per vehicle and batch by time, then number the rows so
        # that segments can be sliced by position later.
        df_sorted = df.sort_values(by=['vehicle', 'batch', 'datetime'])
        df_sorted['index_helper'] = range(len(df_sorted))
        df_summary = utils.process_func.calculate_df_summary(df_sorted)

        # One entry per (honeycomb_id, batch) with the collected levels.
        honeycomb_list = []
        if not df_summary.empty:
            # WKT text -> Shapely geometries -> GeoDataFrame in the road CRS.
            df_sorted['geometry'] = df_sorted['geometry'].apply(wkt.loads)
            df_sorted = gpd.GeoDataFrame(df_sorted, geometry='geometry', crs=roadcrs)
            # Index by the running row number so min_name/max_name can slice.
            df_sorted.set_index('index_helper', inplace=True)

            for index, row in df_summary.iterrows():
                batch = row['batch']
                grid_congestion_level = row['grid_congestion_level']

                min_name = row['min_name']
                max_name = row['max_name']
                # Label-based slice: both ends are inclusive.
                between_rows = df_sorted.loc[min_name:max_name]

                # Path length of the segment, in CRS units.
                line = LineString(between_rows['geometry'].tolist())
                total_length = line.length

                # A segment of length 10 or less is treated as a parked
                # vehicle and skipped (metres per the original note — assumes
                # a projected CRS, TODO confirm).
                if total_length <= 10:
                    continue

                # The segment's honeycomb is taken from its first point.
                first_row = between_rows.iloc[0]
                honeycomb_id = first_row['honeycomb_name']

                # Replace the batch id with its "start-end" time label when
                # the batch is present in the batch table.
                batch_row = batch_data_df[batch_data_df['batch'] == batch]
                if not batch_row.empty:
                    start_datetime = batch_row.iloc[0]['start_time_only']
                    end_datetime = batch_row.iloc[0]['end_time_only']
                    batch = f"{start_datetime}-{end_datetime}"

                # Merge into an existing (honeycomb_id, batch) entry if any.
                found = False
                for item in honeycomb_list:
                    if item["honeycomb_id"] == honeycomb_id and item["batch"] == batch:
                        item["grid_congestion_level"].append(grid_congestion_level)
                        found = True
                        break

                if not found:
                    honeycomb_list.append({
                        "honeycomb_id": honeycomb_id,
                        "batch": batch,
                        "grid_congestion_level": [grid_congestion_level]
                    })

        # Optional export hook (disabled in the original):
        # utils.data_save.save_honeycomb_list_to_file(honeycomb_list, output_path=..., file_type="xlsx")

        # Tie-break order between equally frequent levels: lower value wins.
        priority = {
            "smooth": 0,
            "light_congestion": 1,
            "congestion": 2,
            "severe_congestion": 3
        }

        # Collapse each entry's list of levels to its most frequent value.
        # statistics.mode() stopped raising on ties in Python 3.8, so the
        # frequency table is built directly and ties are resolved by priority.
        for item in honeycomb_list:
            grid_levels = item["grid_congestion_level"]

            if grid_levels:
                counter = Counter(grid_levels)
                max_freq = max(counter.values())
                candidates = [k for k, v in counter.items() if v == max_freq]
                item["grid_congestion_level"] = min(
                    candidates, key=lambda x: priority.get(x, float('inf'))
                )
            else:
                item["grid_congestion_level"] = None

        print("\nFinal structured data:\n")
        print(honeycomb_list)
        response = llm.predict(f"""
        Here is a list of congestion data:

        {json.dumps(honeycomb_list, indent=4, default=str)}

        Please format this into a natural language summary, describing the congestion situation by road name (if available) and honeycomb ID, including congestion levels and time periods.
        The response should be structured like this:

        'The congestion situation around the [poi_name] in Shanghai from [start_time] to [end_time] is as follows:
        - On Grid [honeycomb_id], there was congestion from [time_range], and it was smooth otherwise.
        - ...'

        Ensure clarity and readability, and present it as the Final Answer.
        """)
        plot_type = 'grid congestion'

        utils.visualization.plot_traffic_congestion(honeycomb_ids_ls, road_ids_ls, poi_ids_ls, honeycomb_list, roadcrs, g, plot_type)

        return response
    except Exception as e:
        print(f"❌ [ERROR] 执行时出错：{e}")
        return pd.DataFrame(columns=['batch', 'grid_congestion_level', 'datetime_range'])


def query_trajectory_points_by_batch(args):
    """Summarise road-level congestion for a set of honeycomb cells and batches.

    The (vehicle, batch) pairs seen in the requested honeycombs are collected
    first. For each pair the full trajectory is fetched, split into segments
    by utils.congestion_track.calculate_fid_time_summary, and each segment is
    matched to a road path (trkg_matching_begin / trkg_matching_forward). The
    matched length divided by the segment duration is classified into a
    congestion level. Levels are merged per (honeycomb, batch, path), the most
    frequent level is kept, the result is plotted, and the LLM is asked to
    phrase it as a natural-language summary.

    Args:
        args: A dict (or a serialized dict; it is passed through
            utils.process_func.parse_query_results) with the keys
            "honeycomb_ids_ls", "filtered_batches_ls", "road_ids_ls" and
            "poi_ids_ls".

    Returns:
        The LLM's summary text on success; an {"error": ...} dict when `args`
        does not parse to a dict; an empty DataFrame when no honeycomb id is
        given, when the first query returns nothing, or when any step raises.
    """
    # Imported here because the file header does not import Counter.
    from collections import Counter

    print(f"args：{args}")
    try:
        # Batch time table, CRS, and the cached honeycomb/road lookups.
        file_path = r'F:\paper2\result\batch_time(..].csv'
        batch_data_df = utils.data_loader.CSV_load(file_path)
        crs_path = r"F:/paper2/result/roadcrs.txt"
        roadcrs = utils.data_loader.CRS_load(crs_path)
        pkl_path_1 = r"F:/paper2/result/honeycomb_cache.pkl"
        honeycomb_cache = utils.data_loader.pkl_load(pkl_path_1)
        pkl_path_2 = r"F:/paper2/result/honeycomb_paths_dict.pkl"
        honeycomb_paths_dict = utils.data_loader.pkl_load(pkl_path_2)

        query_results = utils.process_func.parse_query_results(args)

        if isinstance(query_results, dict):
            honeycomb_ids_ls = query_results.get("honeycomb_ids_ls", [])
            filtered_batches_ls = query_results.get("filtered_batches_ls", [])
            road_ids_ls = query_results.get("road_ids_ls", [])
            poi_ids_ls = query_results.get("poi_ids_ls", [])
        else:
            return {"error": "query_results 不是有效的格式"}
        print(f"honeycomb_ids_ls：{honeycomb_ids_ls}")
        print(f"filtered_batches_ls：{filtered_batches_ls}")
        print(f"road_ids_ls：{road_ids_ls}")
        print(f"poi_ids_ls：{poi_ids_ls}")

        # Nothing to query without a honeycomb id.
        if not honeycomb_ids_ls:
            return pd.DataFrame(columns=['batch', 'grid_congestion_level', 'datetime_range'])

        # Render the batch ids as a quoted, comma-separated Cypher list body.
        # NOTE(review): ids, batches and vehicles are spliced into the query
        # text; consider query parameters if these values can be untrusted.
        filtered_batches_str = ', '.join(f"'{batch}'" for batch in filtered_batches_ls)
        all_results = []

        for honeycomb_name in honeycomb_ids_ls:
            query = f"""
            MATCH (tp:trajectory_point)
            WHERE tp.batch IN [{filtered_batches_str}]
            AND tp.grid_congestion_level IS NOT NULL
            AND tp.honeycomb_name = {honeycomb_name}
            RETURN tp.batch AS batch, tp.vehicle AS vehicle, tp.honeycomb_name AS honeycomb_name
            """

            result = g.run(query).data()
            all_results.extend(result)

        # No points at all. The original built this frame but did not return
        # it, so execution fell through to a KeyError that the outer handler
        # reported as an error.
        if not all_results:
            return pd.DataFrame(columns=['batch', 'grid_congestion_level', 'datetime_range'])

        df = pd.DataFrame(all_results)

        # One row per (batch, vehicle) pair that touched the requested cells.
        vehicle_batch_df = df.drop_duplicates(subset=['batch', 'vehicle']).reset_index(drop=True)

        # One entry per (honeycomb_id, batch, shortest_path) with its levels.
        road_list = []
        for _, row in vehicle_batch_df.iterrows():
            vehicle = row['vehicle']
            batch = row['batch']
            cypher = f"""
                MATCH (tp:trajectory_point)-[:located_in]->(hc:honeycomb)
                WHERE tp.vehicle = {vehicle} AND tp.batch = '{batch}'
                AND (tp)-[:next]-()
                RETURN tp.batch AS batch, tp.geometry AS geometry, tp.name AS trajectory_name, tp.vehicle AS vehicle, tp.datetime AS datetime,
                       hc.name AS honeycomb_name
                ORDER BY tp.name
            """
            result = g.run(cypher).data()
            if not result:
                continue

            node_df = pd.DataFrame(result)
            node_df['datetime'] = pd.to_datetime(node_df['datetime'], format="%Y-%m-%d %H:%M:%S")

            # At least two points are needed for a congestion estimate.
            if len(node_df) <= 1:
                continue

            fid_time_summary = utils.congestion_track.calculate_fid_time_summary(node_df)
            if fid_time_summary.empty:
                continue

            # WKT text -> Shapely geometries -> GeoDataFrame in the road CRS,
            # indexed by trajectory point name so segments can be sliced.
            node_df['geometry'] = node_df['geometry'].apply(wkt.loads)
            node_df = gpd.GeoDataFrame(node_df, geometry='geometry', crs=roadcrs)
            node_df.set_index('trajectory_name', inplace=True)

            # Matching state carried from one segment to the next.
            last_point_road = None
            prev_honeycomb_name = None
            # Display label for this batch; resolved on first use so the raw
            # batch id is never overwritten.
            batch_label = None

            for index, segment in fid_time_summary.iterrows():
                honeycomb_name = segment['honeycomb_name']
                duration = segment['duration']
                # A zero duration would make the speed undefined.
                if duration == 0:
                    continue
                min_name = segment['min_name']
                max_name = segment['max_name']
                # Label-based slice: both ends are inclusive.
                between_rows = node_df.loc[min_name:max_name]

                # Path length of the segment, in CRS units.
                line = LineString(between_rows['geometry'].tolist())
                total_length = line.length

                # Length <= 10 is treated as a parked vehicle and >= 600 as
                # bad data; both are skipped (metres per the original note —
                # assumes a projected CRS, TODO confirm).
                if total_length <= 10 or total_length >= 600:
                    continue

                road_gdf = utils.congestion_track.get_road_gdf(honeycomb_name, roadcrs, honeycomb_cache)
                if road_gdf is None:
                    continue
                paths = honeycomb_paths_dict.get(honeycomb_name)

                # Adjacency to the previously processed cell (False for the
                # first one).
                if prev_honeycomb_name is None:
                    is_adjacent = False
                else:
                    is_adjacent = utils.congestion_track.get_honeycomb_adjacency(g, prev_honeycomb_name, honeycomb_name)

                if last_point_road is None:
                    # No previous road to continue from: start a fresh match.
                    shortest_path, last_point_road, projected_lines = utils.congestion_track.trkg_matching_begin(paths, road_gdf, between_rows)
                    if shortest_path is None:
                        prev_honeycomb_name = honeycomb_name
                        continue
                else:
                    # Try to continue from the previous road; fall back to a
                    # fresh match if that fails.
                    shortest_path, last_point_road, projected_lines = utils.congestion_track.trkg_matching_forward(
                        g, paths, road_gdf, between_rows, last_point_road,
                        is_adjacent, prev_honeycomb_name, honeycomb_name)
                    if shortest_path is None:
                        shortest_path, last_point_road, projected_lines = utils.congestion_track.trkg_matching_begin(paths, road_gdf, between_rows)
                        if shortest_path is None:
                            prev_honeycomb_name = honeycomb_name
                            continue

                # Average speed along the matched path -> congestion level.
                projected_lines_length = projected_lines.length
                v_acc = projected_lines_length / duration
                road_congestion_level = utils.congestion_track.classify_congestion_road(v_acc)

                # Replace the batch id with its "start-end" time label when
                # the batch is present in the batch table.
                if batch_label is None:
                    batch_label = batch
                    batch_row = batch_data_df[batch_data_df['batch'] == batch]
                    if not batch_row.empty:
                        start_datetime = batch_row.iloc[0]['start_time_only']
                        end_datetime = batch_row.iloc[0]['end_time_only']
                        batch_label = f"{start_datetime}-{end_datetime}"

                # Merge into an existing (honeycomb, batch, path) entry if any.
                found = False
                for item in road_list:
                    if (item["honeycomb_id"] == honeycomb_name
                            and item["batch"] == batch_label
                            and item["shortest_path"] == shortest_path):
                        item["road_congestion_level"].append(road_congestion_level)
                        found = True
                        break

                if not found:
                    road_list.append({
                        "honeycomb_id": honeycomb_name,
                        "batch": batch_label,
                        "shortest_path": shortest_path,
                        "road_congestion_level": [road_congestion_level]
                    })

                prev_honeycomb_name = honeycomb_name

        # Trajectories also pass through other cells; keep the requested ones.
        road_list = [road for road in road_list if road["honeycomb_id"] in honeycomb_ids_ls]

        # Tie-break order between equally frequent levels: lower value wins.
        priority = {
            "smooth": 0,
            "light_congestion": 1,
            "congestion": 2,
            "severe_congestion": 3
        }

        # Collapse each entry's list of levels to its most frequent value.
        # statistics.mode() stopped raising on ties in Python 3.8, so the
        # frequency table is built directly and ties are resolved by priority.
        for item in road_list:
            road_levels = item["road_congestion_level"]

            if road_levels:
                counter = Counter(road_levels)
                max_freq = max(counter.values())
                candidates = [k for k, v in counter.items() if v == max_freq]
                item["road_congestion_level"] = min(
                    candidates, key=lambda x: priority.get(x, float('inf'))
                )
            else:
                item["road_congestion_level"] = None

        print("\nFinal structured data:\n")
        print(road_list)
        response = llm.predict(f"""
        Here is a list of congestion data:

        {json.dumps(road_list, indent=4, default=str)}

        Please format this into a natural language summary, describing the congestion situation by road name (if available) and honeycomb ID, including congestion levels and time periods.
        The response should be structured like this:

        'The congestion situation around the [poi_name] in Shanghai from [start_time] to [end_time] is as follows:
        - In Grid [honeycomb_id], roads [shortest_path] experienced [road_congestion_level] congestion from [batch]
        - If a road segment was congested across multiple time periods, summarize its overall trend.
        - Highlight any severe congestion events.
        - If a road was generally smooth outside congestion periods, mention it.
        - ...'

        Ensure clarity and readability, and present it as the Final Answer.
        """)
        plot_type = 'road congestion'

        utils.visualization.plot_traffic_congestion(honeycomb_ids_ls, road_ids_ls, poi_ids_ls, road_list, roadcrs, g, plot_type)

        return response

    except Exception as e:
        print(f"❌ [ERROR] 执行时出错：{e}")
        return pd.DataFrame(columns=['batch', 'grid_congestion_level', 'datetime_range'])




# Build the LangChain agent: register each helper above as a Tool, then create
# a zero-shot ReAct agent over them.
# The doubled braces in the descriptions are kept as written.
# NOTE(review): presumably they escape literal braces for prompt templating — confirm.
tools = [
    Tool(name="extract_key_information", func=extract_key_information, description="提取用户查询的目标类型"),
    Tool(name="query_poi_fulltext_index", func=query_poi_fulltext_index, description="查询 POI 信息"),
    Tool(name="find_best_match", func=find_best_match, description="从 POI 结果中找到最佳匹配"),
    Tool(name="query_honeycomb_contains", func=query_honeycomb_contains, description="查询 POI 所属 honeycomb,使用{{'poi_id': ...}}作为输入"),
    Tool(name="query_honeycomb_within_and_neighbors_roads", func=query_honeycomb_within_and_neighbors_roads, description="查询 honeycomb 关联的道路 以及 honeycomb 关联邻域的道路"),
    Tool(name="filter_batches_by_time", func=filter_batches_by_time, description="用户问题中如果提到时间范围，筛选符合条件的批次数据"),
    # Grid-level ("区域拥堵状况") tool; the description tells the agent how to
    # assemble the dict argument from earlier tool outputs.
    Tool(name="query_trajectory_points_by_honeycomb", func=query_trajectory_points_by_honeycomb, 
         description="""
         该函数的输入需要根据道路查询结果提取'honeycomb_id'的唯一值列表放入'honeycomb_ids_ls','road_id'的唯一值列表放入'road_ids_ls';
         根据 Tool-find_best_match 的结果提取'poi_id'的唯一值列表放入'poi_ids_ls';
         以及根据 Tool-filter_batches_by_time 的查询结果生成filtered_batches_ls;
         使用字典表示。该函数用于处理"区域拥堵状况"。"""),
    # Road-level ("道路拥堵状况") tool; same argument layout as the one above.
    Tool(name="query_trajectory_points_by_batch", func=query_trajectory_points_by_batch, 
         description="""
         该函数的输入需要根据道路查询结果提取'honeycomb_id'的完整唯一值列表放入'honeycomb_ids_ls','road_id'的完整唯一值列表放入'road_ids_ls';
         根据 Tool-find_best_match 的结果提取'poi_id'的完整唯一值列表放入'poi_ids_ls';
         以及根据 Tool-filter_batches_by_time 的查询结果生成完整列表filtered_batches_ls;
         使用字典表示。该函数用于处理"道路拥堵状况"。""")
]

# NOTE(review): `system_message` is passed as an extra keyword argument; it is
# not visible from this file whether the ZERO_SHOT_REACT_DESCRIPTION agent
# actually consumes it — confirm against the installed LangChain version.
agent = initialize_agent(
    tools, 
    llm, 
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, 
    verbose=True,
    handle_parsing_errors=True,
    system_message=f"""
    你是一个智能地理查询专家，接收用户的自然语言提问并使用工具链进行多步推理，最终得出结论。
    
    请遵循以下原则：
    (1). 推理开始前，必须先输出整体思路（Thought），描述你将如何分步骤回答问题。
       格式应如下：
       Thought: To answer the question about the..., I need to follow these steps:
       1. ...
       2. ...
       3. ...
       ...
       Let's start with the first step.
    (2). 始终优先使用 extract_key_information 工具，以识别用户问题的目标类型，再调用其他工具。
    (3). 最后的回答请以 Final Answer: 开头，引用最后一个 tool 返回的结果，不需要你自行总结或重复内容。
"""
)



  agent = initialize_agent(


In [5]:
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler

# Step 1: create the CallbackManager with a single StdOutCallbackHandler.
callback_manager = CallbackManager(handlers=[StdOutCallbackHandler()])

def run_query(task_description):
    """
    Run the LangChain agent on one user question.

    A chain-start event named "query_agent" is emitted on the module-level
    ``callback_manager`` before the agent runs. The run is then closed with
    the matching end (or error) event so handlers never see a start without
    a corresponding finish.

    Args:
        task_description (str): The user's natural-language query.
    Returns:
        str: The agent's reply.
    Raises:
        Exception: Anything raised by ``agent.run`` is reported to the
            callback handlers and then re-raised unchanged.
    """
    inputs = {"task_description": task_description}

    # Keep the run manager returned by on_chain_start so the run can be closed.
    run_manager = callback_manager.on_chain_start(
        {"name": "query_agent"}, inputs
    )

    try:
        result = agent.run(task_description)
    except Exception as exc:
        # Report the failure to the handlers, then let the caller see it.
        run_manager.on_chain_error(exc)
        raise

    run_manager.on_chain_end({"output": result})
    return result

In [6]:
# Note: name output fields with the XXX_id pattern (e.g. poi_id, road_id) so the LLM can pick up the parameter convention.

In [11]:
# Demo entry point: send one sample question through run_query and print the reply.
if __name__ == "__main__":
    # Sample question (in Chinese): which road segments around the
    # Mercedes-Benz Arena in Shanghai are congested between 7:20 PM and 8:20 PM?
    user_task = "上海市Mercedes-Benz Arena周围的道路在下午7时20分到8时20分会在哪些路段拥堵?"
    result = run_query(user_task)
    print("回复：", result)

Error in StdOutCallbackHandler.on_chain_start callback: AttributeError("'NoneType' object has no attribute 'get'")




[1m> Entering new query_agent chain...[0m
[32;1m[1;3mTo answer this question, I need to first determine the type of POI being queried and the relevant time for the congestion analysis.

Action: extract_key_information
Action Input: 上海市Mercedes-Benz Arena周围的道路在下午7时20分到8时20分会在哪些路段拥堵?[0m
✅ [DEBUG] JSON 格式解析成功

Observation: [36;1m[1;3m道路拥堵状况[0m
Thought:[32;1m[1;3mThe task is to find out which road segments around the Mercedes-Benz Arena in Shanghai will be congested between 7:20 PM and 8:20 PM. First, I need to query the POI fulltext index for Mercedes-Benz Arena.

Action: query_poi_fulltext_index
Action Input: Mercedes-Benz Arena, Shanghai[0m
✅ [DEBUG] JSON 格式解析成功

Observation: [33;1m[1;3m{'location': 'Mercedes-Benz Arena', 'poi_results': [{'poi_id': 129312, 'poi_name': 'Mercedes-Benz Arena', 'score': 21.770397186279297}, {'poi_id': 933110, 'poi_name': 'Mercedes-Benz', 'score': 14.85614013671875}, {'poi_id': 925797, 'poi_name': 'Mercedes-Benz', 'score': 14.85614013671875}, 

Map for batch: 19:25:00-19:30:00


Map for batch: 19:30:00-19:35:00


Map for batch: 19:35:00-19:40:00


Map for batch: 19:40:00-19:45:00


Map for batch: 19:45:00-19:50:00


Map for batch: 19:50:00-19:55:00


Map for batch: 19:55:00-20:00:00


Map for batch: 20:00:00-20:05:00


Map for batch: 20:05:00-20:10:00


Map for batch: 20:10:00-20:15:00


Map for batch: 20:15:00-20:20:00



Observation: [33;1m[1;3mThe congestion situation around the area monitored in Shanghai from 19:20:00 to 20:20:00 is as follows:

- In Grid 92672:
  - Roads 61208 experienced smooth conditions from 19:20 to 19:25, and again from 19:30 to 19:40, 19:45 to 19:50, and 20:00 to 20:05.
  - Roads 61204 and 61206 showed smooth conditions from 20:05 to 20:10, while experiencing severe congestion from 19:25 to 19:30 and 19:40 to 19:45.
  - Generally, roads were smooth outside peak congestion periods.
- In Grid 92673:
  - Roads 61220, 61209, and 61215 experienced severe congestion in multiple batches. Specifically, 19:25 to 19:30, 19:45 to 19:50, 19:55 to 20:00, and 20:00 to 20:05.
  - Roads 61209, 61215, 61213, and 61219 showed recurring periods of severe congestion across the reporting times.
  - Other roads in this grid showed alternating periods of smooth and light congestion conditions.
- In Grid 93578:
  - Roads 63176 and 63177 had mixed conditions, experiencing mostly smooth conditions, 