In this tutorial, we'll demonstrate the Argoverse 2.0 map API, and visualize some of the map data.

In [85]:
from argparse import Namespace
from pathlib import Path

In [86]:
# path to where the logs live
dataroot = f"/data1/av2/train"

# unique log identifier
# log_id = "adcf7d18-0510-35b0-a2fa-b4cea13a6d76"
log_id = "00a6ffc1-6ce9-3bc3-a060-6006e9893a1a"

# Find a log that is in PAO : Palo Alto


### Function to plot (lat, lng) on Google Satellite Map
* gmplot

In [87]:
import gmplot

import os
import time
from selenium import webdriver


def plot_coordinates_on_map(latitude_list, longitude_list, output_file='map.html', color='red', save_png=True):
    # Set the center of the map based on the average of the provided coordinates
    center_lat = sum(latitude_list) / len(latitude_list)
    center_lng = sum(longitude_list) / len(longitude_list)

    # Create a Google Map Plotter object
    gmap = gmplot.GoogleMapPlotter(center_lat, center_lng, 13, map_type='satellite')  # Zoom level: 1=World, 20=Building

    # Plot the coordinates on the map
    gmap.scatter(latitude_list, longitude_list, color, size=1, marker=False)

    # Draw the map to an HTML file
    gmap.draw(output_file)

    print(f"Map saved to {output_file}")


def plot_route_and_correct(rough_lats, rough_lngs, correct_lats, correct_lngs, output_file='map.html', color='red', save_png=True):
    # Set the center of the map based on the average of the provided coordinates
    center_lat = sum(correct_lats) / len(correct_lats)
    center_lng = sum(correct_lngs) / len(correct_lngs)

    # Create a Google Map Plotter object
    gmap = gmplot.GoogleMapPlotter(center_lat, center_lng, 13, map_type='satellite')  # Zoom level: 1=World, 20=Building

    # Plot the rough routes
    gmap.scatter(rough_lats, rough_lngs, "green", size=1, marker=False)

    # Plot the coordinates on the map
    gmap.scatter(correct_lats, correct_lngs, color, size=1, marker=False)
    

    # Draw the map to an HTML file
    gmap.draw(output_file)

    print(f"Map saved to {output_file}")


    # if save_png:
    #     png_file = "test.png"
    #     # Use Selenium to take a screenshot of the HTML map
    #     options = webdriver.ChromeOptions()
    #     options.add_argument('--headless')  # Run Chrome in headless mode (no GUI)
    #     driver = webdriver.Chrome(options=options)
    #     driver.get(f'file://{os.path.abspath(output_file)}')
    #     time.sleep(2)  # Wait for the page to load (adjust as needed)
    #     driver.save_screenshot(png_file)
    #     driver.quit()


import matplotlib.pyplot as plt
def plot_and_save_plt(lats, lngs, color="red", png_file="map.png", title="Map in city coordinates", xlabel='Longitude', ylabel='Latitude'):
    # Create a scatter plot
    # plt.figure(figsize=(10, 8))  # Set the size of the plot
    plt.figure
    plt.scatter(lngs, lats, color=color, s=5)  # Plot the coordinates
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)  # Add grid lines
    # Save the plot as a PNG image
    plt.savefig(png_file)
    print(f"Map image saved to {png_file}")    

# # Example usage:
# if __name__ == "__main__":
#     # Example coordinates (New York City)
#     latitude_list = [40.7128, 40.7210, 40.7484]
#     longitude_list = [-74.0060, -73.9886, -73.9857]

#     plot_coordinates_on_map(latitude_list, longitude_list)

In [88]:

"""Unit tests on utilities for converting AV2 city coordinates to UTM or WGS84 coordinate systems."""

import numpy as np

import av2.geometry.utm as geo_utils
from av2.geometry.utm import CityName
from av2.utils.typing import NDArrayFloat
from pathlib import Path
from typing import List

import numpy as np

from av2.map.map_api import ArgoverseStaticMap, LaneSegment



### Converting City coordinates (x, y) to WGS84 (lat, lng)
1. Get LaneSegments from PaloAlto
2. Get center lines coordinates from 1 LaneSegment
3. Convert (x,y) to (lat, lng)
4. Plot in Google Map

In [89]:
def create_argoverse_static_map(dataroot, log_id):
    # args = Namespace(**{"dataroot": Path(dataroot), "log_id": Path(log_id)})
    log_map_dirpath = Path(dataroot) / log_id / "map"
    avm = ArgoverseStaticMap.from_map_dir(log_map_dirpath, build_raster=False)
    return avm

def get_center_lines_from_lane_segments(avm):
    """Get center-lines from given LaneSegments
    
    Return:
        centerlines: NDArray of shape (N,2)
    """
    lane_segment_ids = avm.get_scenario_lane_segment_ids()
    # print(f"Number of lane segments in this log: {len(lane_segment_ids)}")

    centerlines = []
    for l_id in lane_segment_ids:
        centerline = avm.get_lane_segment_centerline(l_id) # shape: (10, 3)
        # Save html for this centerline
        # latlngs = test_convert_city_coords_to_wgs84_pao(centerline[:,:2])
        # lats, lngs = two_d_list_to_two_lists(latlngs)
        # plot_coordinates_on_map(lats, lngs, "centerline-"+str(l_id) +".html")
        centerlines.append(centerline)
        
    centerlines = np.vstack(centerlines)[:,:2]
    return centerlines

def two_d_list_to_two_lists(coordinates):
    """
    Convert [[x,y]] to [x1, x2, ...], [y1, y2, ...]
    """
    xs, ys = zip(*coordinates)
    return xs, ys    

def convert_city_coords_to_wgs84(points_city, city_name=CityName.PAO) -> None:
    """Convert city coordinates from 'city name' coordinates."""
    wgs84_coords = geo_utils.convert_city_coords_to_wgs84(
        points_city, city_name
    )
    lats, lngs = two_d_list_to_two_lists(wgs84_coords)
    return lats, lngs


def write_txt(lats, lngs, filename='latlngs.txt'):
    file = open(filename,'w')
    for lat, lng in zip(lats, lngs):
        file.write(str(lat) + " " + str(lng) +"\n")
    file.close()    

dataroot = f"/data1/av2/train"

# -------- Palo Alto log ------------- 6d3bfbc9-45dc-316e-a94c-a441371d0571 #
log_id = "6d3bfbc9-45dc-316e-a94c-a441371d0571" # log within Palo Alto (PAO)
avm = create_argoverse_static_map(dataroot, log_id)
centerlines = get_center_lines_from_lane_segments(avm)
lats, lngs = convert_city_coords_to_wgs84(centerlines, CityName.PAO)
write_txt(lats, lngs, filename='pao.txt')
plot_coordinates_on_map(lats, lngs, "test-pao.html")


# -------- Miami log ------------- 0b1b993a-68b3-3232-9afa-fc9942b5b79b#
# log_id = "0b1b993a-68b3-3232-9afa-fc9942b5b79b" 
log_id = "c049334b-5568-3ca0-9b28-0c09d00b7bb3"
avm = create_argoverse_static_map(dataroot, log_id)
centerlines = get_center_lines_from_lane_segments(avm)
lats, lngs = convert_city_coords_to_wgs84(centerlines, CityName.MIA)
write_txt(lats, lngs, filename='mia.txt')
plot_coordinates_on_map(lats, lngs, "test-mia.html")



"""
Writ a class: AV2HDMap
members:
- centerlines - city coordinates (x, y)
              - wgs84 coords: (lat, lng)

- lane_boundary - city coordinates (x, y)
              - wgs84 coords: (lat, lng)

- centerlines - city coordinates (x, y)
              - wgs84 coords: (lat, lng)

                            
"""

def extract_centerlines_from_log(log_id, cityname, frame='wgs84',render_html=False):
    """Extract centerlines from log
    Arguments:
        - log_id of this log data
        - cityname: CityName
    Returns
        - lats: list of latitudes of centerline waypoints
        - lngs: list of longitudes of centerline waypoints
    Additional Output
        - HTML file: (optional) visualize centerlines with Google Satellite Map
    
    """
    avm = create_argoverse_static_map(dataroot, log_id)
    centerlines = get_center_lines_from_lane_segments(avm)

    if cityname == "ATX":
        city_enum = CityName.ATX
    elif cityname =="PAO":
        city_enum = CityName.PAO
    elif cityname =="MIA":
        city_enum = CityName.MIA
    elif cityname =="DTW":
        city_enum = CityName.DTW    
    elif cityname =="PIT":
        city_enum = CityName.PIT
    elif cityname =="WDC":
        city_enum = CityName.WDC    
    else:
        print(f"ERROR! {cityname} NOT defined in av2.geometry.utm")  
        raise NotImplementedError

    if frame=='city':
        xs = centerlines[:,0]
        ys = centerlines[:,1]
        return xs, ys 
    elif frame=='wgs84':
        lats, lngs = convert_city_coords_to_wgs84(centerlines, city_enum)
        if render_html:
            plot_coordinates_on_map(lats, lngs, f"{log_id}.html")
        return lats, lngs
    # write_txt(lats, lngs, filename='mia.txt')


# Extract 'lane_boundaries':
def get_lane_bounds_from_lane_segments(avm):
    """Get center-lines from given LaneSegments
    
    Return:
        centerlines: NDArray of shape (N,2)
    """
    lane_segment_ids = avm.get_scenario_lane_segment_ids()
    # print(f"Number of lane segments in this log: {len(lane_segment_ids)}")

    polygons = []
    for l_id in lane_segment_ids:
        polygon = avm.get_lane_segment_polygon(l_id) # shape: (10, 3)
        # Save html for this centerline
        # latlngs = test_convert_city_coords_to_wgs84_pao(centerline[:,:2])
        # lats, lngs = two_d_list_to_two_lists(latlngs)
        # plot_coordinates_on_map(lats, lngs, "centerline-"+str(l_id) +".html")
        polygons.append(polygon)
        
    polygons = np.vstack(polygons)[:,:2]
    return polygons

def extract_lane_polygons_from_log(log_id, cityname, frame='wgs84',render_html=False):
    """Extract centerlines from log
    Arguments:
        - log_id of this log data
        - cityname: CityName
    Returns
        - lats: list of latitudes of centerline waypoints
        - lngs: list of longitudes of centerline waypoints
    Additional Output
        - HTML file: (optional) visualize centerlines with Google Satellite Map
    
    """
    avm = create_argoverse_static_map(dataroot, log_id)
    centerlines = get_lane_bounds_from_lane_segments(avm)

    if cityname == "ATX":
        city_enum = CityName.ATX
    elif cityname =="PAO":
        city_enum = CityName.PAO
    elif cityname =="MIA":
        city_enum = CityName.MIA
    elif cityname =="DTW":
        city_enum = CityName.DTW    
    elif cityname =="PIT":
        city_enum = CityName.PIT
    elif cityname =="WDC":
        city_enum = CityName.WDC    
    else:
        print(f"ERROR! {cityname} NOT defined in av2.geometry.utm")  
        raise NotImplementedError

    if frame=='city':
        xs = centerlines[:,0]
        ys = centerlines[:,1]
        return xs, ys 
    elif frame=='wgs84':
        lats, lngs = convert_city_coords_to_wgs84(centerlines, city_enum)
        if render_html:
            plot_coordinates_on_map(lats, lngs, f"{log_id}.html")
        return lats, lngs

Map saved to test-pao.html
Map saved to test-mia.html


### Read `/data1/av2-datasets/train` and extract `log-ids`
1. Build hash table: **key**: `log-id` -> **Value**: `CityName`
2. Save 6 lists of `log-ids`


In [90]:
import json
import os
from pathlib import Path

def get_log_ids(directory):
    log_ids = []
    # Iterate over the contents of the directory
    for entry in os.listdir(directory):
        # Join the directory path with the entry name to get the full path
        full_path = os.path.join(directory, entry)
        # Check if the entry is a directory
        if os.path.isdir(full_path):
            log_ids.append(entry)
    return log_ids

# /data1/av2-datasets/train/00a6ffc1-6ce9-3bc3-a060-6006e9893a1a/map/0a8a4cfa-4902-3a76-8301-08698d6290a2_ground_height_surface____PIT.npy
def extract_cityname_from_log(dataroot, log_id):
    log_map_dirpath = Path(dataroot) / log_id / "map" 
    npy_files = list(log_map_dirpath.glob("*.npy"))
    if len(npy_files) >1:
        print(f"ERROR! log_id {log_id} contains multiple npy file. Can't extract CityName")
    for npy_file in npy_files:
        cityname = str(npy_file)[-7:-4]
    return cityname


dataroot = "/data1/av2/train"  # Replace with the path to your directory
json_file = "log-cityname.json"

"""
    ATX = "ATX"  # Austin, Texas
    DTW = "DTW"  # Detroit, Michigan
    MIA = "MIA"  # Miami, Florida
    PAO = "PAO"  # Palo Alto, California
    PIT = "PIT"  # Pittsburgh, PA
    WDC = "WDC"  # Washington, DC

"""

log_to_cityname = {}
city_logids = {} # key: str -> val: [str] list of log ids

log_ids = get_log_ids(dataroot)
for log_id in log_ids:
    cityname = extract_cityname_from_log(dataroot, log_id)
    # print(f"    Map {log_id} -> CityName {cityname}")
    log_to_cityname[log_id] = cityname
    if cityname in city_logids:
        city_logids[cityname].append(log_id)
    else:
        city_logids[cityname] = [log_id]
    
    # if cityname == "MIA":
    #     mia_log_ids.append(log_id)
print(f"Number of cities: {len(city_logids)}")
for key, value in city_logids.items():
    print(f"city: {key} number of log ids {len(value)}")

Number of cities: 6
city: MIA number of log ids 254
city: PIT number of log ids 247
city: ATX number of log ids 25
city: DTW number of log ids 75
city: WDC number of log ids 87
city: PAO number of log ids 12


In [91]:
cityname = "MIA"

# Take the first 10 logs
city_lats = []
city_lngs = []
centerline_xs = []
centerline_ys = []
city_logids[cityname] = sorted(city_logids[cityname])
for idx, log_id in enumerate(city_logids[cityname]):
    # if idx <=2:
    #     continue
    if idx >= 5:
        break
    # print(f"log_id {log_id}")
    # print(f"Get centerline for log id: {log_id}")
    lats, lngs = extract_centerlines_from_log(log_id, cityname, frame='wgs84')
    city_lats.extend(lats)
    city_lngs.extend(lngs)
    xs, ys = extract_centerlines_from_log(log_id, cityname, frame='city')
    centerline_xs.extend(xs)
    centerline_ys.extend(ys)

plot_and_save_plt(centerline_ys, centerline_xs, "r", f"{cityname}-{len(city_logids[cityname])}.png", f"{cityname} (City coordinates)", "x", "y")

# Generate a full HTML map with all cityname lats/lngs
plot_coordinates_on_map(city_lats, city_lngs, f"{cityname}-full.html")

Map image saved to MIA-254.png
Map saved to MIA-full.html


In [92]:

cityname = "MIA"

# Take the first 10 logs
city_lats = []
city_lngs = []
polygon_xs = []
polygon_ys = []
city_logids[cityname] = sorted(city_logids[cityname])
for idx, log_id in enumerate(city_logids[cityname]):
    # if idx <=2:
        # continue
    if idx >= 5:
        break
    # print(f"log_id {log_id}")
    # print(f"Get centerline for log id: {log_id}")
    lats, lngs = extract_lane_polygons_from_log(log_id, cityname, frame='wgs84')
    city_lats.extend(lats)
    city_lngs.extend(lngs)
    xs, ys = extract_lane_polygons_from_log(log_id, cityname, frame='city')
    polygon_xs.extend(xs)
    polygon_ys.extend(ys)

plot_and_save_plt(polygon_ys, polygon_xs, "black", f"{cityname}-{len(city_logids[cityname])}-polygon.png", f"{cityname} lane polygons (City coordinates)", "x", "y")

# Generate a full HTML map with all cityname lats/lngs
plot_coordinates_on_map(city_lats, city_lngs, f"{cityname}-polygon.html")


Map image saved to MIA-254-polygon.png
Map saved to MIA-polygon.html


In [93]:
# Plot centerlines and lane polygons
# %matplotlib qt
# import matplotlib.pyplot as plt

png_file=f'{cityname}-log-overlap.png'
plt.figure(f"{cityname}")
plt.scatter(polygon_xs, polygon_ys, 3, "black")
plt.scatter(centerline_xs, centerline_ys, 3, "red")
# plt.scatter(rough_xs, rough_ys, 3, "green")
plt.xlabel('x(m)')
plt.ylabel('y(m)')
plt.title(f'{cityname} log (City coordinate)')
plt.grid(True)  # Add grid lines
# Save the plot as a PNG image
# plt.show()
plt.savefig(png_file)
print(f"Map image saved to {png_file}")   

Map image saved to MIA-log-overlap.png


Now, given a trajectory, how do we know which **log HD map** should we take?

# Get a rough route using Google Route API
* Input: a start (lat, lng) and a goal (lat, lng)
*  Output: a list of waypoints (lat, lng) from start to goal.
###  Blocks required:
1. Convert WGS84 to **City** coordinate
2. Given start & goal coordinates, generate a rough route

In [94]:
from typing import Dict, Final, List, Optional, Tuple, Union
from av2.utils.typing import NDArrayBool, NDArrayByte, NDArrayFloat, NDArrayInt
from av2.geometry.utm import convert_gps_to_utm, CITY_ORIGIN_LATLONG_DICT

# Convert (lat, lng) to city coordinates
def convert_wgs84_points_to_city_coords(
    points_wgs84: Union[NDArrayFloat, NDArrayInt], city_name: CityName
) -> NDArrayFloat:
    """Convert WGS84 coordinates to city coordinates.

    Args:
        points_wgs84: Array of shape (N,2), representing points in the WGS84 coordinate system, as (latitude, longitude).
        city_name: Name of city, where query points are located.

    Returns:
        2d points in city coordinates, as (N,2) array.
    """
    latitude, longitude = CITY_ORIGIN_LATLONG_DICT[city_name]
    # Get (easting, northing) of origin.
    origin_utm = convert_gps_to_utm(
        latitude=latitude, longitude=longitude, city_name=city_name
    )
     
    points_city = np.zeros_like(points_wgs84)
    for i, (lat, long) in enumerate(points_wgs84):
      point_utm = convert_gps_to_utm(
          latitude=lat, longitude=long, city_name=city_name
      )
      points_city[i] = np.asarray(point_utm) - np.asarray(origin_utm)

    return points_city


In [95]:
from av2.utils.get_rough_route import get_rough_route
start = ("25.80402823314048", "-80.19421488790299")
goal = ("25.80282667040719", "-80.19514344778023")
rough_lats, rough_lngs = get_rough_route(start, goal, api_key='AIzaSyBmYtO7rXCbqG02eEzLWb2FgexIve6FmvU')
plot_coordinates_on_map(rough_lats, rough_lngs, "rough-route.html", color='green')
plot_and_save_plt(rough_lats, rough_lngs, color="g", png_file="rough-route.png", xlabel='x(m)', ylabel='y(m)')

 Sending POST request...
Getting POST response
write response.json to ./route_response.json
    Instruction: Head west on NE 29th St toward N Miami Ave
    Instruction: Turn left onto N Miami Ave
Destination will be on the right
Distance: 0.1 mi
Number of spline points: 220
Average distance per point: 1.018181818181818 meter/point
Map saved to rough-route.html
Map image saved to rough-route.png


In [96]:
# Get rough route 'steps' 
from av2.utils.get_rough_route import Step, Node, Instr
from av2.utils.get_rough_route import get_rough_route_steps
def get_latlngs_from_steps(steps):
    lats = []
    lngs = []
    for step in steps:
        lats.extend([step.start.lat, step.goal.lat])
        lngs.extend([step.start.lng, step.goal.lng])
    return lats, lngs

start = ("25.80402823314048", "-80.19421488790299")
goal = ("25.80282667040719", "-80.19514344778023")
steps = get_rough_route_steps(start, goal, api_key='AIzaSyBmYtO7rXCbqG02eEzLWb2FgexIve6FmvU')
rough_lats, rough_lngs = get_latlngs_from_steps(steps)
print(f"len rought_lats: {len(rough_lats)}")
print(f"len rought_lngs: {len(rough_lngs)}")
plot_coordinates_on_map(rough_lats, rough_lngs, "rough-route.html", color='green')
plot_and_save_plt(rough_lats, rough_lngs, color="g", png_file="rough-route.png", xlabel='x(m)', ylabel='y(m)')

 Sending POST request...
Getting POST response
write response.json to ./route_response.json
    Instruction: Head west on NE 29th St toward N Miami Ave
    Instruction: Turn left onto N Miami Ave
Destination will be on the right
Number of steps: 2
step: instruction Instr.GO_STRAIGHT
    start node: 25.8040013, -80.1942138
    goal node: 25.8039789, -80.19507320000001
step: instruction Instr.TURN_LEFT
    start node: 25.8039789, -80.19507320000001
    goal node: 25.8028281, -80.1950884
route distance: 0.1 mi
len rought_lats: 4
len rought_lngs: 4
Map saved to rough-route.html
Map image saved to rough-route.png


In [97]:
# A list of step: steps
# For each step -> convert 'start' & end node to city xy
# Find nearby lane segment to start & end



%matplotlib qt
import matplotlib.pyplot as plt

# Put these lane segment as "A* goal nodes"
# If 3 lane segments -> plan 2 path
# 2 steps -> 2 path (3 lane segment)
def calc_ls_position(avmap, ls_id):
    """Calculate mean position for a lane segment    """
    ls_centerline = avmap.get_lane_segment_centerline(ls_id)
    xs = ls_centerline[:,0]
    ys = ls_centerline[:,1]
    ls_pos = (np.mean(xs), np.mean(ys))
    # print(f"ls {ls_id} pos {ls_pos}")
    return ls_pos

class Node:
    def __init__(self, lsid, position, parent_node=None):
        self.lsid = lsid           # lane segment id
        self.position = position   # mean position (x, y)
        self.parent = parent_node  # pointer to parent_node
        self.g = 0 
        self.h = 0
        self.f = 0
    
    def __eq__(self, other):
        return self.position == other.position
    
    def __lt__(self, other):
        return self.f < other.f
        
    def __hash__(self):
        return hash(self.position)        


# ((lat, lng), (lat, lng))
log_id="07e4fccb-eb2d-31e5-bbcb-6550d0860f64"
avm = create_argoverse_static_map(dataroot, log_id)

search_radius = 0.5
nodes = []

for i, step in enumerate(steps):

    lane_polygon_xs = []
    lane_polygon_ys = []
    centerline_xs = []
    centerline_ys = []


    start_latlng = np.array((step.start.lat, step.start.lng))
    goal_latlng = np.array((step.goal.lat, step.goal.lng))
    query_latlng = np.stack((start_latlng, goal_latlng), axis=0)
    print(f"query_latlng.shape {query_latlng.shape}")
    points_xy = convert_wgs84_points_to_city_coords(query_latlng, CityName.MIA)
    print(f"City coordinates: points_xy.shape {points_xy.shape}")
    start_x, goal_x = points_xy[:, 0]
    start_y, goal_y = points_xy[:, 1]   
    print(f"    start x,y {start_x}, {start_y}")
    print(f"    goal x,y {goal_x}, {goal_y}") 
    # Find nearby lane segment:
    query_st = np.array((start_x, start_y))
    lss = avm.get_nearby_lane_segments(query_st, search_radius_m=search_radius)
    print(f"    # ls near start: {len(lss)}")
    # TODO: hack hear. Deal with this later
    # If more than 1 nearby lane segments: pick the first one (lss[0])
    if len(lss)>0:
        ls = lss[0]
        ls_pos = calc_ls_position(avm, ls.id)
        nodes.append(Node(ls.id, ls_pos))
        print(f"    Add node: {ls.id} pos: {ls_pos}")
    else:
        print(f"    Error! No lane segments near {query_st} within {search_radius} meter")
    # for ls in lss:
    #     lane_polygon_xs.extend(ls.polygon_boundary[:,0])
    #     lane_polygon_ys.extend(ls.polygon_boundary[:,1])
    #     centerline = avm.get_lane_segment_centerline(ls.id)
    #     centerline_xs.extend(centerline[:,0])
    #     centerline_ys.extend(centerline[:,1])         

    """Only add the goal node for the last step"""
    if i == len(steps)-1:
        query_gl = np.array((goal_x, goal_y))
        lss = avm.get_nearby_lane_segments(query_gl, search_radius_m=search_radius)    
        print(f"    # ls near goal: {len(lss)}")

        if len(lss)>0:
            ls = lss[0]
            ls_pos = calc_ls_position(avm, ls.id)
            nodes.append(Node(ls.id, ls_pos))
            print(f"    Add node: {ls.id} pos: {ls_pos}")
        else:
            print(f"    Error! No lane segments near {query_gl} within {search_radius} meter")

        # for ls in lss:
        #     lane_polygon_xs.extend(ls.polygon_boundary[:,0])
        #     lane_polygon_ys.extend(ls.polygon_boundary[:,1])   
        #     centerline = avm.get_lane_segment_centerline(ls.id)
        #     centerline_xs.extend(centerline[:,0])
        #     centerline_ys.extend(centerline[:,1]) 
        #     print(f"    ls id: {ls.id}")
            # print(f"    predecessors: {ls.predecessors}")
            # print(f"    successors: {ls.successors}")

    # Plot:
    png_file='mia-log-overlap.png'
    plt.figure('Mia city')
    plt.scatter(lane_polygon_xs, lane_polygon_ys, 3, "black")
    plt.scatter(centerline_xs, centerline_ys, 3, "red")
    plt.scatter(start_x, start_y, 3, "green")
    plt.scatter(goal_x, goal_y, 3, "green")
    plt.xlabel('x(m)')
    plt.ylabel('y(m)')
    plt.title('Mia log (City coordinate)')
    plt.grid(True)  # Add grid lines    


query_latlng.shape (2, 2)
City coordinates: points_xy.shape (2, 2)
    start x,y 216.20579440135043, 3265.7748384848237
    goal x,y 130.06667729606852, 3262.766919947695
    # ls near start: 2
    Add node: 84886930 pos: (227.7, 3267.2925000000005)
query_latlng.shape (2, 2)
City coordinates: points_xy.shape (2, 2)
    start x,y 130.06667729606852, 3262.766919947695
    goal x,y 129.32228921551723, 3135.3093665046617
    # ls near start: 1
    Add node: 84896820 pos: (130.26106615780097, 3257.7883530177683)
    # ls near goal: 2
    Add node: 84887523 pos: (127.80250000000001, 3137.56)


### A* find paths


In [98]:
import heapq

"""
A* function
Input: 
    - nodes: a list of nodes to be visit
             for every pair (i, i+1):
                nodes[i] is start
                nodes[i+1] is goal node
Output:
    - lane_segment_ids: a list of [lane_segment_id]
            later we extract and connect centerline of each lane segment

"""

def astar(avmap, start_node, goal_node):
    """
    avmap: ArgoverseStaticMap
    start_node: start node
    goal_node : goal node

    """
    open_list = []
    closed_set = set()
    
    
    heapq.heappush(open_list, start_node)
    
    while open_list:
        current_node = heapq.heappop(open_list)
        
        if current_node == goal_node:
            path = [] # a list of node.lsid
            while current_node:
                path.append(current_node.lsid)
                current_node = current_node.parent
            return path[::-1]
        
        closed_set.add(current_node)
        
        for neighbor in get_neighbors(avmap, current_node):
            if neighbor in closed_set:
                continue
            
            neighbor.g = current_node.g + 1 # 1 -> distance from cur to neighbor
            neighbor.h = heuristic(neighbor, goal_node)
            neighbor.f = neighbor.g + neighbor.h
            
            if neighbor not in open_list:
                heapq.heappush(open_list, neighbor)
            else:
                for open_node in open_list:
                    if open_node == neighbor and open_node.g > neighbor.g:
                        open_node.g = neighbor.g
                        open_node.parent = neighbor.parent

# TODO: Instead of 4 direction, find neighbors from ls.successors/right_neighbor_id/left_neighbor
def get_neighbors(avmap, node):
    neighbors = []

    suc_ids = avmap.get_lane_segment_successor_ids(node.lsid)
    left_id = avmap.get_lane_segment_left_neighbor_id(node.lsid)
    right_id = avmap.get_lane_segment_right_neighbor_id(node.lsid) 

    # ls.successors
    for successor_id in suc_ids:
        pos = calc_ls_position(avmap, successor_id)
        neighbors.append(Node(successor_id, pos, node))

    # left neighbor
    if left_id:
        pos = calc_ls_position(avmap, left_id)
        neighbors.append(Node(left_id, pos, node))

    # right neighbor
    if right_id:
        pos = calc_ls_position(avmap, right_id)
        neighbors.append(Node(right_id, pos, node))                
    
    return neighbors

# TODO: Distance metric might need to be adjusted
def heuristic(node, end):
    return abs(node.position[0] - end.position[0]) + abs(node.position[1] - end.position[1])


def find_nodes_route(avmap, nodes):
    ls_ids = []
    for i, node in enumerate(nodes):
        print(f"Find path from node {i} -> {i+1}")
        if i < len(nodes)-1:
            start_node = nodes[i]
            goal_node = nodes[i+1]
            ls_ids.extend(astar(avmap, start_node, goal_node))
    print(f"lane segments ids: {ls_ids}")
    return ls_ids

def connect_ls_centerlines(avmap, ls_ids):
    """ Given a list of lane segment ids,
        Connect the lane segment centerlines
    
        Returns:
            - route_centerlines: a Numpy array of shape (N, 2): (x, y)
    """
    route_centerlines = []
    for ls_id in ls_ids:
        centerline = avmap.get_lane_segment_centerline(ls_id)
        route_centerlines.extend(centerline[:,:2])
    print(f"Number of waypoints in route: {len(route_centerlines)}")

    return route_centerlines

# ------------ Below is the execution ----------------- #
print(f"Number of nodes: {len(nodes)}")
for node in nodes:
    print(f"    node id: {node.lsid}")
ls_ids = find_nodes_route(avm, nodes)
route_centerlines = connect_ls_centerlines(avm, ls_ids)

# for waypoint in route_centerlines:
#     print(f"type waypoint: {type(waypoint)}")
#     print(f"shape {waypoint.shape}")
route_xs, route_ys = zip(*route_centerlines)


Number of nodes: 3
    node id: 84886930
    node id: 84896820
    node id: 84887523
Find path from node 0 -> 1
Find path from node 1 -> 2
Find path from node 2 -> 3
lane segments ids: [84886930, 84887124, 84887125, 84896656, 84900001, 84900078, 84896820, 84896820, 84896798, 84887862, 84947362, 84888497, 84888651, 84888648, 84887523]
Number of waypoints in route: 150


In [103]:
plot_and_save_plt(route_ys, route_xs, "green", "corrected-route-citycoord.png", xlabel='x(m)', ylabel='y(m)')

Map image saved to corrected-route-citycoord.png


### Convert corrected route back to WGS84 (lat,lng) and superimpose it with Google Map

In [100]:
route_lats = []
route_lngs = []
for waypoint in route_centerlines:
    # print(type(waypoint))
    # print(waypoint.shape)
    waypoint = waypoint.reshape(-1,2)
    lat, lng = convert_city_coords_to_wgs84(waypoint, CityName.MIA)
    route_lats.append(lat[0])
    route_lngs.append(lng[0])
# print(route_lats)
plot_coordinates_on_map(route_lats, route_lngs, "route-mia.html")

plot_route_and_correct(rough_lats, rough_lngs, route_lats, route_lngs, "route-ba-mia.html")

Map saved to route-mia.html
Map saved to route-ba-mia.html


In [101]:
# Convert from WGS84 to City coordinatr
print(np.asarray(rough_lats).shape)
points_wgs84 = np.stack((np.asarray(rough_lats), np.asarray(rough_lngs)), axis=1)
points_mia = convert_wgs84_points_to_city_coords(points_wgs84, CityName.MIA)
print(f"City coordinates: points_mia.shape {points_mia.shape}")
rough_xs = points_mia[:, 0]
rough_ys = points_mia[:, 1]
plot_and_save_plt(rough_ys, rough_xs, "green", "rough-route-citycoord.png", xlabel='x(m)', ylabel='y(m)')

(4,)
City coordinates: points_mia.shape (4, 2)
Map image saved to rough-route-citycoord.png


In [102]:
# Function:
# rough route: -> rough_xs, rough_yx
# lane boundaries: lane_bound_xs, lane_bound_ys
# centerlines: centerline_xs, centerline_ys
# Plot all in a single figure

%matplotlib qt
import matplotlib.pyplot as plt

png_file='mia-log-overlap.png'
plt.figure('Mia city')
plt.scatter(polygon_xs, polygon_ys, 3, "black")
plt.scatter(centerline_xs, centerline_ys, 3, "red")
plt.scatter(rough_xs, rough_ys, 3, "green")
plt.xlabel('x(m)')
plt.ylabel('y(m)')
plt.title('Mia log (City coordinate)')
plt.grid(True)  # Add grid lines
# Save the plot as a PNG image
# plt.show()
plt.savefig(png_file)
print(f"Map image saved to {png_file}")   

Map image saved to mia-log-overlap.png
