In [1]:
from nuscenes.nuscenes import NuScenes
from nuscenes.map_expansion.map_api import NuScenesMap
from pyquaternion import Quaternion
from ultralytics import YOLO
import json
from tqdm import tqdm
from utils_drivelm import get_option, action_map

In [2]:
# Module-level nuScenes handle; get_map_instance_from_frame and get_ego_pose
# read their records through it.
nusc = NuScenes(
    version='v1.0-trainval',
    dataroot='/data2/common/xuanyang/nuscenes',
    verbose=True,
)

Loading NuScenes tables for version v1.0-trainval...


23 category,
8 attribute,
4 visibility,
64386 instance,
12 sensor,
10200 calibrated_sensor,
2631083 ego_pose,
68 log,
850 scene,
34149 sample,
2631083 sample_data,
1166187 sample_annotation,
4 map,
Done loading in 39.013 seconds.
Reverse indexing ...
Done reverse indexing in 11.4 seconds.


In [3]:
# One NuScenesMap per location name handled by get_map_instance_from_frame.
# The variables on the left are bound positionally, so their order must match
# the order of the location names below.
(
    map_singapore_onenorth,
    map_singapore_hollandvillage,
    map_boston_seaport,
    map_singapore_queenstown,
) = (
    NuScenesMap(dataroot='/data2/common/xuanyang/nuscenes', map_name=location)
    for location in (
        'singapore-onenorth',
        'singapore-hollandvillage',
        'boston-seaport',
        'singapore-queenstown',
    )
)

In [4]:
def get_map_instance_from_frame(scene_token):
    """Return the preloaded map object for the location a scene was logged at.

    The scene record gives the log token, and the log record's 'location'
    field selects one of the four module-level map instances.

    Args:
        scene_token: token of a record in the nuScenes 'scene' table.

    Returns:
        The module-level map instance matching the log's location.

    Raises:
        ValueError: if the location is not one of the four known names.
    """
    scene_record = nusc.get('scene', scene_token)
    location = nusc.get('log', scene_record['log_token'])['location']

    if location == 'singapore-onenorth':
        return map_singapore_onenorth
    if location == 'singapore-hollandvillage':
        return map_singapore_hollandvillage
    if location == 'boston-seaport':
        return map_boston_seaport
    if location == 'singapore-queenstown':
        return map_singapore_queenstown
    raise ValueError('Unsupported map name')


def get_ego_pose(frame_token):
    """Return the ego pose attached to a sample's CAM_FRONT data record.

    Args:
        frame_token: token of a record in the nuScenes 'sample' table.

    Returns:
        A ``(translation, rotation)`` tuple taken from the 'ego_pose' record
        referenced by the sample's CAM_FRONT sample_data entry.
    """
    front_cam_token = nusc.get('sample', frame_token)['data']['CAM_FRONT']
    pose_token = nusc.get('sample_data', front_cam_token)['ego_pose_token']
    pose = nusc.get('ego_pose', pose_token)
    return pose['translation'], pose['rotation']


def search_lane(map_instance, lane_token):
    """Look up a lane record by token.

    Args:
        map_instance: object whose ``lane`` attribute is an iterable of
            dict-like records, each carrying a 'token' key.
        lane_token: token to search for.

    Returns:
        The first record in ``map_instance.lane`` whose 'token' equals
        ``lane_token``, or None when no record matches.
    """
    return next(
        (record for record in map_instance.lane if record['token'] == lane_token),
        None,
    )


def get_nearby_lane_types(map_instance, scene_token, frame_token):
    """Collect map context around the ego position of a sample.

    Args:
        map_instance: map object providing ``layers_on_point`` and
            ``get_closest_lane``.
        scene_token: accepted for the callers' convenience; not used here.
        frame_token: sample token passed on to ``get_ego_pose``.

    Returns:
        ``(ego_x, ego_y, road_on_point, lane_info)`` where ``road_on_point``
        is the result of ``layers_on_point`` at the ego position and
        ``lane_info`` is the lane record found by ``search_lane`` for the
        closest lane within a radius of 3 (None when it is not found).
    """
    position, _rotation = get_ego_pose(frame_token)
    # The translation must unpack into exactly three components; only x and y
    # are used for the map queries.
    ego_x, ego_y, _ego_z = position
    layers_here = map_instance.layers_on_point(ego_x, ego_y)
    nearest_lane_token = map_instance.get_closest_lane(ego_x, ego_y, radius=3)
    return ego_x, ego_y, layers_here, search_lane(map_instance, nearest_lane_token)


def get_node_info(map_instance, node_token):
    """Look up a node record by token.

    Args:
        map_instance: object whose ``node`` attribute is an iterable of
            dict-like records, each carrying a 'token' key.
        node_token: token to search for.

    Returns:
        The first record in ``map_instance.node`` whose 'token' equals
        ``node_token``, or None when no record matches.
    """
    matches = (record for record in map_instance.node if record['token'] == node_token)
    return next(matches, None)


def distance_cal(x1, y1, x2, y2):
    """Return the Euclidean distance between points (x1, y1) and (x2, y2)."""
    delta_x = x1 - x2
    delta_y = y1 - y2
    squared_sum = delta_x**2 + delta_y**2
    return squared_sum**0.5


def get_divider_type(ego_x, ego_y, map_instance, divider_segment_info):
    """Return the divider segment entry whose node is nearest to the ego point.

    Despite the name, the whole segment entry is returned rather than a type;
    the caller reads 'segment_type' from it.

    Args:
        ego_x: x coordinate of the ego position.
        ego_y: y coordinate of the ego position.
        map_instance: map object passed on to ``get_node_info``.
        divider_segment_info: iterable of entries, each with a 'node_token'.

    Returns:
        The entry with the smallest node-to-ego distance (the first one on
        ties), or None when the iterable is empty or no node is closer than
        100000000.
    """
    closest_segment, closest_distance = None, 100000000
    for segment in divider_segment_info:
        node_record = get_node_info(map_instance, segment['node_token'])
        separation = distance_cal(ego_x, ego_y, node_record['x'], node_record['y'])
        # Strict comparison keeps the earliest entry when distances tie.
        if separation < closest_distance:
            closest_segment, closest_distance = segment, separation
    return closest_segment


def _build_detect_info(yolo, questions, conversation):
    """Build the {'image_id', 'classes', 'action'} record for one conversation entry."""
    sample_id = conversation['id']
    id_parts = sample_id.split('_')
    scene_id, frame_id = id_parts[0], id_parts[1]

    # Detector predicates: class names found in the first three images
    # (cam_front, cam_front_right, cam_front_left per the original note).
    predicates = set()
    for img_path in conversation['image'][:3]:
        for detection in yolo(img_path, verbose=False):
            for box in detection.boxes:
                predicates.add(yolo.names[int(box.cls)])

    # Map predicates: layers under the ego position and nearest lane dividers.
    map_instance = get_map_instance_from_frame(scene_id)
    ego_x, ego_y, road_on_point, lane_info = get_nearby_lane_types(map_instance, scene_id, frame_id)
    if road_on_point['ped_crossing'] != '':
        predicates.add('pedestrianCrossing')
    if road_on_point['stop_line'] != '':
        predicates.add('stopLine')
    if lane_info:
        divider_sides = (
            ('left_lane_divider_segments', '_LEFT'),
            ('right_lane_divider_segments', '_RIGHT'),
        )
        for segments_key, suffix in divider_sides:
            segments = lane_info[segments_key]
            if segments:
                nearest = get_divider_type(ego_x, ego_y, map_instance, segments)
                predicates.add(nearest['segment_type'] + suffix)

    # Action predicate: derived from the first "behavior" QA pair of the frame.
    behavior_qa = questions[scene_id]["key_frames"][frame_id]["QA"]["behavior"][0]
    option = get_option(behavior_qa["Q"], behavior_qa["A"])
    action_list = action_map(option)

    return {
        'image_id': sample_id,
        # Sorted so the saved file does not depend on set iteration order.
        'classes': sorted(predicates),
        'action': action_list,
    }


def condition_predicate_extractor(conv_path, question_path, detect_info_save_path, yolo_weights='best.pt'):
    """Extract condition and action predicates for every conversation entry and save them as JSON.

    For each entry the detector is run on its first three images, map-derived
    predicates are added for the ego position (pedestrian crossing, stop line,
    nearest left/right lane divider type), and the answer of the first
    "behavior" QA pair is mapped to an action list.

    Args:
        conv_path: path to a JSON list of conversation entries. Each entry is
            read for 'id' (split on '_' into scene token and frame token) and
            'image' (a list of image paths).
        question_path: path to a JSON object indexed as
            ``[scene]["key_frames"][frame]["QA"]["behavior"][0]`` with "Q" and
            "A" keys.
        detect_info_save_path: path of the JSON file to write. It receives a
            list of ``{'image_id', 'classes', 'action'}`` objects.
        yolo_weights: weights file handed to ``YOLO``; defaults to 'best.pt'.

    Returns:
        None. The results are only written to ``detect_info_save_path``.
    """
    yolo = YOLO(yolo_weights)
    with open(conv_path, 'r') as f:
        conv = json.load(f)
    with open(question_path, 'r') as f:
        questions = json.load(f)

    all_detect_info = []
    # The output is opened before the loop so an unwritable path fails
    # immediately, and it is written exactly once instead of being rewritten
    # in full after every sample. The finally clause still saves whatever was
    # collected if a later sample raises.
    with open(detect_info_save_path, 'w') as out_file:
        try:
            for conversation in tqdm(conv):
                all_detect_info.append(_build_detect_info(yolo, questions, conversation))
        finally:
            json.dump(all_detect_info, out_file)
            

        
    

In [5]:
# Training-split paths, kept for reference; swap them with the active group
# below to regenerate the training file.
# conv_path = 'DriveLM_process/conversation_drivelm_train.json'
# question_path = 'DriveLM_process/train_eval.json'
# save_path = 'process_data_drivelm/train/train_detected_classes.json'

# Active configuration: validation-split inputs, written under test/.
conv_path = 'DriveLM_process/conversation_drivelm_val.json'
question_path = 'DriveLM_process/v1_1_val_nus_q_only.json'
save_path = 'process_data_drivelm/test/test_detected_classes.json'

condition_predicate_extractor(
    conv_path=conv_path,
    question_path=question_path,
    detect_info_save_path=save_path,
)

  ckpt = torch.load(file, map_location="cpu")
  0%|          | 0/799 [00:00<?, ?it/s]

100%|██████████| 799/799 [03:19<00:00,  4.01it/s]
