筛选意图预测数据idx，场景：交叉路口

In [1]:
import argparse
import pickle
import copy
from argoverse.data_loading.argoverse_forecasting_loader import ArgoverseForecastingLoader
from argoverse.map_representation.map_api import ArgoverseMap
import matplotlib.pyplot as plt
from skimage.transform import rotate
from tqdm import tqdm
import numpy as np
import pandas as pd
import torch
from shapely.ops import unary_union
from shapely.geometry import Point
from shapely.geometry import Polygon

torch.cuda.set_device(0)

am = ArgoverseMap()

data_path = '/data/fyy/lanegcn/dataset/dataset/train/data'
avl = ArgoverseForecastingLoader(data_path)
avl.seq_list = sorted(avl.seq_list)

def get_traj_and_lane(idx): 
    city_name = copy.deepcopy(avl[idx].city)
    data_seq = copy.deepcopy(avl[idx].seq_df)   # (len, 6)
    timestamp = np.sort(np.unique(data_seq['TIMESTAMP'].values))
    
    mapping = dict()
    for i, ts in enumerate(timestamp):
        mapping[ts] = i
        
    # 某个场景下的所有轨迹，5s内的轨迹，(len, 2)
    trajs = np.concatenate((
            data_seq.X.to_numpy().reshape(-1, 1),
            data_seq.Y.to_numpy().reshape(-1, 1)), 1)

    steps = [mapping[x] for x in data_seq['TIMESTAMP'].values]
    steps = np.asarray(steps, np.int64)

    objs = data_seq.groupby(['TRACK_ID', 'OBJECT_TYPE']).groups
    keys = list(objs.keys())
    obj_type = [x[1] for x in keys]

    av_idx = obj_type.index('AV')  # av_index = 0，获取AV的索引
    idcs = objs[keys[av_idx]]

    av_traj = trajs[idcs]  # av_traj.shape = (50, 2)
    av_step = steps[idcs]  # av_step.shape = (50,)
    
    # 获取周围车辆的轨迹
    del keys[av_idx]
    ctx_trajs, ctx_steps = [], []  
    for key in keys:
        idcs = objs[key]
        ctx_trajs.append(trajs[idcs])
        ctx_steps.append(steps[idcs])
        
    data = dict()
    data['city'] = city_name
    data['trajs'] = [av_traj] + ctx_trajs
    data['steps'] = [av_step] + ctx_steps
    data['argo_id'] = int(avl.seq_list[idx].name[:-4])
    
    av_lane = am.get_lane_ids_in_xy_bbox(av_traj[0][0], av_traj[0][1], city_name, 5)
    return av_traj, av_lane, timestamp[0], city_name

# 读取PIT和MIA的十字路口的lane_id
intersection_data_path = [
    'intersection_data/PIT/intersection_PIT_id.pickle',
    'intersection_data/MIA/intersection_MIA_id.pickle'
]
intersection_data = {}
for path in intersection_data_path:
    with open(path, 'rb') as file:
        data = pickle.load(file)
    
    file_name = path.split('/')[-2]
    intersection_data[file_name] = data


# 统计意图预测的数据集，前20帧在停止先前，且至少15帧的轨迹在十字路口内（能够标注意图）
intention_data = []
for idx in tqdm(range(len(avl.seq_list))):
    av_traj, av_lane, _, city_name = get_traj_and_lane(idx)
    
    # 统计这条轨迹的周围路段
    traj_start, traj_end = av_traj[0], av_traj[-1]
    lane_nearest = am.get_lane_ids_in_xy_bbox(traj_start[0], traj_start[1], city_name, 5)
    lane_nearest += am.get_lane_ids_in_xy_bbox(traj_end[0], traj_end[1], city_name, 5)

    # 获取交叉路口的多边形
    intersection_polygon = []
    for lane_id in lane_nearest:
        if lane_id in intersection_data[city_name]:
            polygon = Polygon(am.get_lane_segment_polygon(lane_id, city_name))
            if polygon.is_valid:
                intersection_polygon.append(polygon)
    merged_polygon_intersection = unary_union(intersection_polygon) if intersection_polygon else None  
    
    # 获取直行道路的多边形
    straight_road_polygon = []
    for lane_id in lane_nearest:
        if lane_id in intersection_data[city_name]: continue
        polygon = Polygon(am.get_lane_segment_polygon(lane_id, city_name))
        if polygon.is_valid:
            straight_road_polygon.append(polygon)
    merged_polygon_straight_road = unary_union(straight_road_polygon) if straight_road_polygon else None
           
    # 统计轨迹在直行道路内的帧数       
    frames_inside_straight = 0
    for x, y in av_traj[:20]:
        point = Point(x, y)
        if merged_polygon_straight_road and merged_polygon_straight_road.contains(point):
            frames_inside_straight += 1
    
    # 统计轨迹在交叉路口内的帧数
    frames_inside_intersection = 0
    for x, y in av_traj[20:]:
        point = Point(x, y)
        if merged_polygon_intersection and merged_polygon_intersection.contains(point):
            frames_inside_intersection += 1

    if frames_inside_straight == 20 and frames_inside_intersection > 15:
        intention_data.append(idx)


  3%|▎         | 6399/205942 [05:08<2:40:32, 20.72it/s]


ParserError: Error tokenizing data. C error: Calling read(nbytes) on source failed. Try engine='python'.

In [2]:
file_name = '/data/fyy/new_prediction/argoverse/intersection_data/intention_train_av_idx.pkl'

# with open(file_name, 'wb') as file:
#     pickle.dump(intention_data, file)
    
with open(file_name, 'rb') as file:
    intention_data = pickle.load(file)

print("意图数据的数量", len(intention_data))

意图数据的数量 8540


提取对应idx轨迹, 生成测试集和训练集

In [3]:
def turn_direction(traj):
    # 计算第20帧的角度
    initial_vector = traj[0] - traj[3]
    initial_angle = np.arctan2(initial_vector[1], initial_vector[0])

    # 计算最后一帧的角度
    final_vector = traj[-5] - traj[-1]
    final_angle = np.arctan2(final_vector[1], final_vector[0])

    # 计算角度差异
    angle_threshold = np.pi / 12

    angle_diff = final_angle - initial_angle

    # 判断是否转向
    if abs(angle_diff) > angle_threshold:
        if angle_diff > 0:
            return [0, 1, 0]
        else:
            return [0, 0, 1]
    else:
        return [1, 0, 0]

In [6]:
traj_label = []
left, right, through = 0, 0, 0  # 左转 1056, 右转 922, 直行 6562(1000), 总数 8540
for idx in intention_data:
    av_traj, av_lane, _, city_name = get_traj_and_lane(idx)
    label = turn_direction(av_traj[20:])
    if label == [1, 0, 0] and through >= 1000:
        continue
    if label == [0, 1, 0]:
        left += 1
    elif label == [0, 0, 1]:
        right += 1
    elif label == [1, 0, 0]:
        through += 1
    traj_label.append([av_traj[:20], label])


In [7]:
print("左转", left)
print("右转", right)
print("直行", through)
print("总数", len(traj_label))

左转 1056
右转 922
直行 1000
总数 2978


In [8]:
file_name = '/data/fyy/new_prediction/argoverse/intersection_data/intention_train_av_traj_and_label.pkl'

with open(file_name, 'wb') as file:
    pickle.dump(traj_label, file)