In [1]:
import numpy as np 
import matplotlib.pyplot as plt 
import seaborn as sns
sns.set_theme()

import torch 
import torch.nn as nn 
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.set_default_dtype(torch.float64)

%matplotlib inline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

Using cpu device


In [4]:
from bernstein_torch import bernstein_coeff_order10_new

# Generating P matrix
t_fin = 6.0
num = 60
tot_time = torch.linspace(0, t_fin, num)
tot_time_copy = tot_time.reshape(num, 1)
P, Pdot, Pddot = bernstein_coeff_order10_new(10, tot_time_copy[0], tot_time_copy[-1], tot_time_copy)
print(P.shape, Pdot.shape, Pddot.shape)

torch.Size([60, 11]) torch.Size([60, 11]) torch.Size([60, 11])


In [6]:
def project_to_frenet_frame(traj, ref_line):
    distance_to_ref = torch.cdist(traj[:, :, :2], ref_line[:, :, :2])
    k = torch.argmin(distance_to_ref, dim=-1).view(-1, traj.shape[1], 1).expand(-1, -1, 3)
    ref_points = torch.gather(ref_line, 1, k)
    x_r, y_r, theta_r = ref_points[:, :, 0], ref_points[:, :, 1], ref_points[:, :, 2] 
    x, y = traj[:, :, 0], traj[:, :, 1]
    s = 0.1 * (k[:, :, 0] - 200)
    l = torch.sign((y-y_r)*torch.cos(theta_r)-(x-x_r)*torch.sin(theta_r)) * torch.sqrt(torch.square(x-x_r)+torch.square(y-y_r))
    sl = torch.stack([s, l], dim=-1)
    return sl

In [7]:
def extract_tracks(static_map_path, scenario_path):
    static_map = ArgoverseStaticMap.from_json(Path(static_map_path))
    tracks_df = pd.read_parquet(scenario_path)

    scenario_id = tracks_df["scenario_id"][0]
    focal_track_id = tracks_df["focal_track_id"][0]
    city_name = tracks_df["city"][0]
    map_id = tracks_df["map_id"][0] if "map_id" in tracks_df.columns else None
    slice_id = tracks_df["slice_id"][0] if "slice_id" in tracks_df.columns else None

    # Interpolate scenario timestamps based on the saved start and end timestamps
    timestamps_ns = np.linspace(
        tracks_df["start_timestamp"][0], tracks_df["end_timestamp"][0], num=tracks_df["num_timestamps"][0]
    )


    tracks: List[Track] = []
    for track_id, track_df in tracks_df.groupby("track_id"):

        observed_states: List[bool] = track_df.loc[:, "observed"].values.tolist()
        object_type: ObjectType = ObjectType(track_df["object_type"].iloc[0])
        object_category: TrackCategory = TrackCategory(track_df["object_category"].iloc[0])
        timesteps: List[int] = track_df.loc[:, "timestep"].values.tolist()
        positions: List[Tuple[float, float]] = list(
            zip(
                track_df.loc[:, "position_x"].values.tolist(),
                track_df.loc[:, "position_y"].values.tolist(),
            )
        )
        headings: List[float] = track_df.loc[:, "heading"].values.tolist()
        velocities: List[Tuple[float, float]] = list(
            zip(
                track_df.loc[:, "velocity_x"].values.tolist(),
                track_df.loc[:, "velocity_y"].values.tolist(),
            )
        )

        object_states: List[ObjectState] = []
        for idx in range(len(timesteps)):
            object_states.append(
                ObjectState(
                    observed=observed_states[idx],
                    timestep=timesteps[idx],
                    position=positions[idx],
                    heading=headings[idx],
                    velocity=velocities[idx],
                )
            )

        tracks.append(
            Track(track_id=track_id, object_states=object_states, object_type=object_type, category=object_category)
        )

    return tracks

In [None]:
def extract_lanes(static_map_path, scenario_path, debug=False, sample_len=100):
    static_map = ArgoverseStaticMap.from_json(Path(static_map_path))
    tracks_df = pd.read_parquet(scenario_path)

    tracks = extract_tracks(static_map_path, scenario_path)

    agx, agy = None, None

    for track in tracks:
        if track.category == TrackCategory.FOCAL_TRACK:
            states = track.object_states
            agx = [state.position[0] for state in states]
            agy = [state.position[1] for state in states]
            if debug: plt.scatter(agx[:50], agy[:50], color="blue", zorder=10)
            if debug: plt.scatter(agx[50:], agy[50:], color="orange", zorder=10)
            traj = np.array([agx, agy])
            traj = traj.reshape(traj.shape[1], traj.shape[0])
            pass
        elif track.category == TrackCategory.SCORED_TRACK:
            states = track.object_states
            otx = [state.position[0] for state in states]
            oty = [state.position[1] for state in states]
            # plt.scatter(otx[:50], oty[:50], color="red")
            # plt.scatter(otx[50:], oty[50:], color="grey")
            pass    

    # getting the backbone
    backbones = []
    for tm in range(0, 50):
        min_ind = get_nearest_centerline(agx, agy, static_map, tm=tm)
        cx = get_centerline(static_map, min_ind)
        lane_seg = static_map.vector_lane_segments[min_ind]
        if tm == 0:
            backbones.append([min_ind])
            continue
        last_lane = backbones[-1][-1]
        if min_ind == last_lane:
            continue
        fg = True
        for backbone in backbones:
            if min_ind in backbone:
                fg = False
            if min_ind in static_map.vector_lane_segments[backbone[-1]].successors:
                backbone.append(min_ind)
                fg = False
        if fg:
            backbones.append([min_ind])
        # plt.plot(cx[:, 0], cx[:, 1])

    MAX_DIS = 300

    hold_lanes = []
    import copy
    
    qry_pt = [agx[50], agy[50]]

    def BFS(queue, lane_id, visited, hold, static_map, qry_pt=[]):
        while len(queue):
            [cur, distance, hold] = queue.pop(0)
            if distance > MAX_DIS:
                hold_lanes.append(hold)
                continue
            successors = static_map.vector_lane_segments[cur].successors
            done = False
            for successor in successors:
                if visited.get(successor) == None:
                    # not visited
                    try:
                        static_map.vector_lane_segments[successor]
                        hold_ = copy.deepcopy(hold)
                        hold_.append(successor)
                        cx = get_centerline(static_map, successor)
                        distance1 = get_distance(qry_pt, cx[-1])
                        distance2 = get_distance(qry_pt, cx[0])
                        queue.append([successor, max(distance1, distance2), hold_])
                        visited[successor] = True
                    except:
                        done = True
            if done:
                hold_lanes.append(hold)
                pass
            if len(successors) == 0:
                # end of the line
                hold_lanes.append(hold)
                pass

    # hold
    visited = {}
    for lane in backbone:
        queue = []
        if visited.get(lane) == None:
            visited[lane] = True
            cx = get_centerline(static_map, lane)
            distance1 = get_distance(qry_pt, cx[-1])
            distance2 = get_distance(qry_pt, cx[0])
            hold = []
            hold.append(lane)
            queue.append([lane, max(distance1, distance2), hold])
            BFS(queue, lane, visited, hold, static_map, qry_pt)
            cnt = 0
            for lanes in hold_lanes:
                color = "#"+''.join([random.choice('0123456789ABCDEF') for j in range(6)])
                for lane_ in lanes:
                    # if cnt == 1: print("BAZINGA", lane_)
                    cx = get_centerline(static_map, lane_)
                    if debug: plt.plot(cx[:, 0], cx[:, 1], color=color)
                cnt = cnt + 1
                if debug: plt.scatter(agx[:50], agy[:50], color="blue", zorder=10)
                if debug: plt.scatter(agx[50:], agy[50:], color="orange", zorder=10)                    
                # if debug: plt.axis('equal')
                if debug: plt.xlim([np.min(agx) - 100, np.max(agx) + 100])
                if debug: plt.ylim([np.min(agy) - 150, np.max(agy) + 150])                
                if debug: plt.show()
                if debug: plt.clf()
    hold_array = copy.deepcopy(hold_lanes)

    hold_lanes = []
    # left change
    visited = {}
    for lane_hold in backbone:
        queue = []
        lane = static_map.vector_lane_segments[lane_hold].left_neighbor_id
        if visited.get(lane) == None:
            try:
                # to check if available in lst
                static_map.vector_lane_segments[lane].successors
            except:
                continue
            visited[lane] = True
            cx = get_centerline(static_map, lane)
            distance1 = get_distance(qry_pt, cx[-1])
            distance2 = get_distance(qry_pt, cx[0])            
            hold = []
            hold.append(lane)
            queue.append([lane, max(distance1, distance2), hold])
            BFS(queue, lane, visited, hold, static_map, qry_pt)
            for lanes in hold_lanes:
                for lane_ in lanes:
                    cx = get_centerline(static_map, lane_)
                    # if debug: plt.plot(cx[:, 0], cx[:, 1], color="blue")
    left_change_array = copy.deepcopy(hold_lanes)

    hold_lanes = []                
    # left change
    visited = {}
    for lane_hold in backbone:
        queue = []
        lane = static_map.vector_lane_segments[lane_hold].right_neighbor_id
        if visited.get(lane) == None:
            try:
                # to check if available in lst
                static_map.vector_lane_segments[lane].successors
            except:
                continue
            visited[lane] = True
            cx = get_centerline(static_map, lane)
            distance1 = get_distance(qry_pt, cx[-1])
            distance2 = get_distance(qry_pt, cx[0])                        
            hold = []
            hold.append(lane)
            queue.append([lane, max(distance1, distance2), hold])
            BFS(queue, lane, visited, hold, static_map, qry_pt)
            for lanes in hold_lanes:
                for lane_ in lanes:
                    cx = get_centerline(static_map, lane_)
                    # if debug: plt.plot(cx[:, 0], cx[:, 1], color="yellow")
    right_change_array = copy.deepcopy(hold_lanes)
    if debug: plt.xlim([np.min(agx) - 50, np.max(agx) + 50])
    if debug: plt.ylim([np.min(agy) - 50, np.max(agy) + 50])
    lanes = []
    cnt = 0
    for lane in hold_array:
        lane_information = []
        for lane_id in lane:
            # if cnt == 1: print("LAUWA LAUWA", lane_id)
            wx = [waypt.x for waypt in static_map.vector_lane_segments[lane_id].left_lane_boundary.waypoints]
            wy = [waypt.y for waypt in static_map.vector_lane_segments[lane_id].left_lane_boundary.waypoints]
            wwx = [waypt.x for waypt in static_map.vector_lane_segments[lane_id].right_lane_boundary.waypoints]
            wwy = [waypt.y for waypt in static_map.vector_lane_segments[lane_id].right_lane_boundary.waypoints]
            left_boundary = np.dstack((wx, wy))[0]
            right_boundary = np.dstack((wwx, wwy))[0]
            left_boundary = interp_arc(t=100, points=left_boundary)
            right_boundary = interp_arc(t=100, points=right_boundary)
            cx = get_centerline(static_map, lane_id)
            cx = interp_arc(t=sample_len, points=cx)
            is_intersection = static_map.vector_lane_segments[lane_id].is_intersection
            lane_type = static_map.vector_lane_segments[lane_id].lane_type.name
            left_mark_type = static_map.vector_lane_segments[lane_id].left_mark_type.name
            right_mark_type = static_map.vector_lane_segments[lane_id].right_mark_type.name
            for ind in range(sample_len):
                lane_information.append([is_intersection, lane_type, cx[ind],
                                        left_boundary[ind][0], left_boundary[ind][1], left_mark_type, 
                                         right_boundary[ind][0], right_boundary[ind][1], right_mark_type, cx[ind][0],  cx[ind][1]])
                pass
        lanes.append(lane_information)
        cnt = cnt + 1
    for lane in left_change_array:
        lane_information = []
        for lane_id in lane:
            wx = [waypt.x for waypt in static_map.vector_lane_segments[lane_id].left_lane_boundary.waypoints]
            wy = [waypt.y for waypt in static_map.vector_lane_segments[lane_id].left_lane_boundary.waypoints]
            wwx = [waypt.x for waypt in static_map.vector_lane_segments[lane_id].right_lane_boundary.waypoints]
            wwy = [waypt.y for waypt in static_map.vector_lane_segments[lane_id].right_lane_boundary.waypoints]
            left_boundary = np.dstack((wx, wy))[0]
            right_boundary = np.dstack((wwx, wwy))[0]
            left_boundary = interp_arc(t=100, points=left_boundary)
            right_boundary = interp_arc(t=100, points=right_boundary)
            cx = get_centerline(static_map, lane_id)
            cx = interp_arc(t=sample_len, points=cx)
            is_intersection = static_map.vector_lane_segments[lane_id].is_intersection
            lane_type = static_map.vector_lane_segments[lane_id].lane_type.name
            left_mark_type = static_map.vector_lane_segments[lane_id].left_mark_type.name
            right_mark_type = static_map.vector_lane_segments[lane_id].right_mark_type.name
            for ind in range(sample_len):
                lane_information.append([is_intersection, lane_type, cx[ind],
                                        left_boundary[ind][0], left_boundary[ind][1], left_mark_type, 
                                         right_boundary[ind][0], right_boundary[ind][1], right_mark_type, cx[ind][0],  cx[ind][1]])
                pass
        lanes.append(lane_information)
    for lane in right_change_array:
        lane_information = []
        for lane_id in lane:
            wx = [waypt.x for waypt in static_map.vector_lane_segments[lane_id].left_lane_boundary.waypoints]
            wy = [waypt.y for waypt in static_map.vector_lane_segments[lane_id].left_lane_boundary.waypoints]
            wwx = [waypt.x for waypt in static_map.vector_lane_segments[lane_id].right_lane_boundary.waypoints]
            wwy = [waypt.y for waypt in static_map.vector_lane_segments[lane_id].right_lane_boundary.waypoints]
            left_boundary = np.dstack((wx, wy))[0]
            right_boundary = np.dstack((wwx, wwy))[0]
            left_boundary = interp_arc(t=100, points=left_boundary)
            right_boundary = interp_arc(t=100, points=right_boundary)
            cx = get_centerline(static_map, lane_id)
            cx = interp_arc(t=sample_len, points=cx)
            is_intersection = static_map.vector_lane_segments[lane_id].is_intersection
            lane_type = static_map.vector_lane_segments[lane_id].lane_type.name
            left_mark_type = static_map.vector_lane_segments[lane_id].left_mark_type.name
            right_mark_type = static_map.vector_lane_segments[lane_id].right_mark_type.name
            for ind in range(sample_len):
                lane_information.append([is_intersection, lane_type, cx[ind],
                                        left_boundary[ind][0], left_boundary[ind][1], left_mark_type, 
                                         right_boundary[ind][0], right_boundary[ind][1], right_mark_type, cx[ind][0],  cx[ind][1]])
                pass
        lanes.append(lane_information)
    return lanes

In [6]:
from av2.map.map_api import ArgoverseStaticMap
from av2.datasets.motion_forecasting.data_schema import ArgoverseScenario, ObjectState, ObjectType, Track, TrackCategory
from av2.geometry.polyline_utils import convert_lane_boundaries_to_polygon, get_polyline_length, interp_polyline_by_fixed_waypt_interval
from av2.geometry.interpolate import compute_midpoint_line, interp_arc
from shapely.geometry import LineString, Point, Polygon

MAX_DIS = 200

def get_centerline(static_map, min_ind):
    wx = [waypt.x for waypt in static_map.vector_lane_segments[min_ind].left_lane_boundary.waypoints]
    wy = [waypt.y for waypt in static_map.vector_lane_segments[min_ind].left_lane_boundary.waypoints]
    wwx = [waypt.x for waypt in static_map.vector_lane_segments[min_ind].right_lane_boundary.waypoints]
    wwy = [waypt.y for waypt in static_map.vector_lane_segments[min_ind].right_lane_boundary.waypoints]
    cx, _ = compute_midpoint_line(np.dstack((wx, wy))[0], np.dstack((wwx, wwy))[0])
    return cx

def get_distance(p1, p2):
    return np.linalg.norm(np.array(p1) - np.array(p2))

def get_nearest_centerline(agx, agy, static_map, tm = 90):
    min_ind = 0
    min_dist = 1e11
    
    lane_ids = list(static_map.vector_lane_segments.keys())
    for lane_id in lane_ids:
        wx = [waypt.x for waypt in static_map.vector_lane_segments[lane_id].left_lane_boundary.waypoints]
        wy = [waypt.y for waypt in static_map.vector_lane_segments[lane_id].left_lane_boundary.waypoints]

        wwx = [waypt.x for waypt in static_map.vector_lane_segments[lane_id].right_lane_boundary.waypoints]
        wwy = [waypt.y for waypt in static_map.vector_lane_segments[lane_id].right_lane_boundary.waypoints]
        # poly = convert_lane_boundaries_to_polygon(np.dstack((wx, wy))[0], np.dstack((wwx, wwy))[0])
        cx, _ = compute_midpoint_line(np.dstack((wx, wy))[0], np.dstack((wwx, wwy))[0])

        centerline = interp_arc(100, cx)
        # plt.plot(centerline[:, 0], centerline[:, 1], "--")
        pos = np.array([agx[tm], agy[tm]])
        dist = np.linalg.norm(centerline - pos, axis=1).min()

        if dist < min_dist:
            min_dist = dist
            min_ind = lane_id
    return min_ind

def give_lane_color(name):
    string = ""
    if "YELLOW" in name:
        string += "y"
    elif "WHITE" in name:
        string += "w"
    else:
        string = "grey"
    if "DASH" in name:
        string += "--"
    return string

def get_backbone(agx, agy, static_map_path, end_tm=40):
    static_map = ArgoverseStaticMap.from_json(Path(static_map_path))
    backbones = []
    for tm in range(0, end_tm):
        min_ind = get_nearest_centerline(agx, agy, static_map, tm=tm)
        cx = get_centerline(static_map, min_ind)
        lane_seg = static_map.vector_lane_segments[min_ind]
        if tm == 0:
            backbones.append([min_ind])
            continue
        last_lane = backbones[-1][-1]
        if min_ind == last_lane:
            continue
        fg = True
        for backbone in backbones:
            if min_ind in backbone:
                fg = False
            if min_ind in static_map.vector_lane_segments[backbone[-1]].successors:
                backbone.append(min_ind)
                fg = False
        if fg:
            backbones.append([min_ind])
    for backbone in backbones:
        import random
        # color = "#"+''.join([random.choice('0123456789ABCDEF') for j in range(6)])
        for min_ind in backbone:
            cx = get_centerline(static_map, min_ind)
            # plt.plot(cx[:, 0], cx[:, 1], color=color)

    max_backbone = None
    max_len = -1e11
    ag_traj, _ = interp_polyline_by_fixed_waypt_interval(np.dstack((agx, agy))[0], 0.5)
    for backbone in backbones:
        total = 0
        color = "#"+''.join([random.choice('0123456789ABCDEF') for j in range(6)])
        cxs = []
        cys = []
        for lane in backbone:
            cx = get_centerline(static_map, lane)
            cxs.append(cx[:, 0])
            cys.append(cx[:, 1])
            # plt.plot(cx[:, 0], cx[:, 1], color=color, marker='o')
        lane_seq_polygon = centerline_to_polygon(np.dstack((cxs, cys))[0])
        total = 0
        for xy in ag_traj:
            point_in_polygon_score = Polygon(lane_seq_polygon).contains(Point(xy))
            total += point_in_polygon_score
        if total > max_len:
            max_len = total
            max_backbone = backbone
    # plt.plot(agx, agy, "b")            
    # plt.xlim([np.min(agx) - 100, np.max(agx) + 100])
    # plt.ylim([np.min(agy) - 150, np.max(agy) + 150])                    
    # plt.show()
    return max_backbone


def BFS(queue, lane_id, visited, hold, static_map, qry_pt=[], hold_lanes = []):
    while len(queue):
        [cur, distance, hold] = queue.pop(0)
        if distance > MAX_DIS:
            hold_lanes.append(hold)
            continue
        successors = static_map.vector_lane_segments[cur].successors
        done = False
        for successor in successors:
            if visited.get(successor) == None:
                # not visited
                try:
                    static_map.vector_lane_segments[successor]
                    import copy
                    hold_ = copy.deepcopy(hold)
                    hold_.append(successor)
                    cx = get_centerline(static_map, successor)
                    distance1 = get_distance(qry_pt, cx[-1])
                    distance2 = get_distance(qry_pt, cx[0])
                    queue.append([successor, max(distance1, distance2), hold_])
                    visited[successor] = True
                except:
                    done = True
        if done:
            hold_lanes.append(hold)
            pass
        if len(successors) == 0:
            # end of the line
            hold_lanes.append(hold)
            pass
    return hold_lanes

In [143]:
def project_to_frenet_frame(traj, oracle_centerline):
    traj, oracle_centerline = traj.clone().detach(), oracle_centerline.clone().detach()
    distance_to_ref = torch.cdist(torch.tensor(traj)[:, :, :2], torch.tensor(oracle_centerline))#.reshape(1, oracle_centerline.shape[0], oracle_centerline.shape[1]))[:, :, :2])
    k = torch.argmin(distance_to_ref, dim=-1).view(-1, traj.shape[1], 1).expand(-1, -1, 2)
    ref_line = oracle_centerline.clone()#torch.tensor(oracle_centerline.reshape(1, oracle_centerline.shape[0], oracle_centerline.shape[1]))
    _ref_points = torch.gather(ref_line, 1, k - torch.ones_like(k))
    ref_points = torch.gather(ref_line, 1, k)
    ref_points_ = torch.gather(ref_line, 1, k + torch.ones_like(k))
    norm = torch.linalg.norm(ref_points_ - _ref_points, dim=2)
    _ref_points_3d = torch.zeros((ref_points.shape[0], ref_points.shape[1], 3)); _ref_points_3d[:, :, 0] = _ref_points[:, :, 0]; _ref_points_3d[:, :, 1] = _ref_points[:, :, 1]
    ref_points_3d = torch.zeros((ref_points.shape[0], ref_points.shape[1], 3)); ref_points_3d[:, :, 0] = ref_points[:, :, 0]; ref_points_3d[:, :, 1] = ref_points[:, :, 1]
    ref_points__3d = torch.zeros((ref_points.shape[0], ref_points.shape[1], 3)); ref_points__3d[:, :, 0] = ref_points_[:, :, 0]; ref_points__3d[:, :, 1] = ref_points_[:, :, 1]
    traj_3d = torch.zeros((traj.shape[0], traj.shape[1], 3)); traj_3d[:, :, 0] = torch.tensor(traj[:, :, 0]); traj_3d[:, :, 1] = torch.tensor(traj[:, :, 1])

    prp = torch.linalg.cross(ref_points__3d - traj_3d, ref_points__3d - _ref_points_3d, dim=2)[:, :, 2]
    norm = torch.linalg.norm(ref_points_ - _ref_points, dim=2)
    d = torch.div(prp, norm)
    s = 0.5 * torch.argmin(distance_to_ref, dim=-1)
    return torch.stack([s, d], dim=-1)

In [189]:
debug = False
from av2.geometry.interpolate import compute_midpoint_line, interp_arc
from av2.geometry.polyline_utils import centerline_to_polygon
from shapely.geometry import LineString, Point, Polygon
from spline import Spline2D
from pathlib import Path
import pandas as pd

def rotate(gt_x, gt_y,theta):
    gt_x_x = [ (gt_x[k] * np.cos(theta) - gt_y[k] * np.sin(theta))  for k in range(len(gt_x))]
    gt_y_y = [ (gt_x[k] * np.sin(theta) + gt_y[k] * np.cos(theta))  for k in range(len(gt_x))]
    gt_x = gt_x_x
    gt_y = gt_y_y
    return gt_x, gt_y

# Custom Dataset Loader 
class TrajDataset(Dataset):
    """Expert Trajectory Dataset."""
    def __init__(self, dataset_path, source_path, debug=False):
        
        self.dataset_path = dataset_path
        self.source_path = source_path
        self.files = os.listdir(DATASET_PATH)
        self.debug = debug
        
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        # idx = 11363
        debug = self.debug
        tracks = np.load(os.path.join(self.dataset_path, self.files[idx], f"{self.files[idx]}.npy"), allow_pickle=True)
        # [object_state.position[0], object_state.position[1], object_state.heading, object_state.velocity[0], object_state.velocity[1], timestamps_ns[ind], tck[track.category], vmp[track.object_type]
        focal_track = None
        file = self.files[idx]
        for track in tracks: 
            if track[0][-2] == 3: focal_track = track

        agx = np.array(focal_track)[10:50, 0] # position-x
        agy = np.array(focal_track)[10:50, 1] # position-y
        ox = np.array(focal_track)[50:, 0] # position-x
        oy = np.array(focal_track)[50:, 1] # position-y        
        
        # obtain centerlines
        static_map_path = os.path.join(self.source_path, file,  f"log_map_archive_{file}.json")
        static_map = ArgoverseStaticMap.from_json(Path(static_map_path))
        scenario_path = os.path.join(self.source_path, file, f"scenario_{file}.parquet")
        backbone = get_backbone(agx, agy, static_map_path)
        for lane in backbone:
            queue = []
            cx = get_centerline(static_map, lane)
        if debug: plt.axis('equal')
        visited = {}
        hold_lanes = []
        qry_pt = [agx[39], agy[39]]
        for lane in backbone:
            queue = []
            cx = get_centerline(static_map, lane)
            if visited.get(lane) == None:
                visited[lane] = True
                cx = get_centerline(static_map, lane)
                distance1 = get_distance(qry_pt, cx[-1])
                distance2 = get_distance(qry_pt, cx[0])
                hold = []
                hold.append(lane)
                queue.append([lane, max(distance1, distance2), hold])
                hold_lanes = BFS(queue, lane, visited, hold, static_map, qry_pt, hold_lanes)

        tx = np.hstack((agx, ox))
        ty = np.hstack((agy, oy))
        traj = np.dstack((tx, ty))[0]
        oracle_centerline = None
        max_len = -1e11
        # print(idx)
        cxs = []
        cys = []
        for lane in backbone:
            cx = get_centerline(static_map, lane)
            cxs.append(cx[:, 0])
            cys.append(cx[:, 1])            
        oracle_ids = None
        for hold_lane in hold_lanes:
            cxs = []
            cys = []
            oracle = []
            for lane in hold_lane:
                cx = get_centerline(static_map, lane)
                cxs.append(cx[:, 0])
                cys.append(cx[:, 1])
                oracle.append(lane)
            lane_seq_polygon = centerline_to_polygon(np.dstack((cxs, cys))[0])
            total = 0
            for xy in traj:
                point_in_polygon_score = Polygon(lane_seq_polygon).contains(Point(xy))
                total += point_in_polygon_score
            if total > max_len:
                max_len = total
                oracle_ids = oracle
                oracle_centerline = np.dstack((cxs, cys))[0]
        cxs = []
        cys = []
        for lane in oracle_ids:
            cx = get_centerline(static_map, lane)
            for pos in cx:            
                cxs.append(pos[0])
                cys.append(pos[1])     
        oracle_centerline = np.dstack((cxs, cys))[0]
        # print(idx)
        if debug:
            plt.plot(agx, agy, 'b', zorder=4)
            plt.plot(ox, oy, 'r', zorder=4)
            plt.plot(oracle_centerline[:, 0], oracle_centerline[:, 1], "ko")
            plt.title("before transformation")
            plt.axis('equal')
            plt.show()        
        
        # lanes = extract_lanes(static_map_path, scenario_path, debug=False, sample_len=100)
        offsetx = agx[-1]
        offsety = agy[-1]
        theta = np.arctan2(agy[-1] - agy[-2], agx[-1] - agx[-2])
        agx, agy = rotate(agx - offsetx, agy - offsety, -theta)
        ox, oy = rotate(ox - offsetx, oy - offsety, -theta)
        oracle_centerline_x, oracle_centerline_y = rotate(np.array(cxs) - offsetx, np.array(cys) - offsety, -theta)
        oracle_centerline = np.dstack((oracle_centerline_x, oracle_centerline_y))[0]
        last_diff = oracle_centerline[-1] - oracle_centerline[-2]
        first_diff = oracle_centerline[0] - oracle_centerline[1]
        last_diff /= (2 * np.linalg.norm(last_diff))
        first_diff /= (2 * np.linalg.norm(first_diff))
        num = 100
        last = np.linspace(0.5, num, num * 2).reshape(num * 2, 1) @ last_diff.reshape(1, 2) + oracle_centerline[-1]
        first = oracle_centerline[0] + np.linspace(num, 0.5, num * 2).reshape(num * 2, 1) @ first_diff.reshape(1, 2)
        long_spline = np.concatenate((first, oracle_centerline, last), axis=0)
        oracle_centerline_, _ = interp_polyline_by_fixed_waypt_interval(long_spline, 0.5)
        if debug:
            plt.plot(oracle_centerline_[:, 0], oracle_centerline_[:, 1], "ko")
            plt.plot(first[:, 0], first[:, 1], "ro")
            plt.plot(last[:, 0], last[:, 1], "bo")
            plt.plot(agx, agy, 'b', zorder=4)
            plt.plot(ox, oy, 'r', zorder=4)        
            # plt.plot(long_spline[:, 0], long_spline[:, 1], "ko")
            plt.title("after transformation")
            plt.axis('equal')
            plt.show()
        
        traj = np.dstack((agx, agy))
        oracle_centerline, _ = interp_polyline_by_fixed_waypt_interval(np.dstack((oracle_centerline_x, oracle_centerline_y))[0], 0.5)
        # oracle_centerline = np.dstack((oracle_centerline_x, oracle_centerline_y))
        # print(len(oracle_centerline))
        total_pad = 1000
        pad_left = (total_pad - len(oracle_centerline))//2
        pad_right = total_pad - pad_left - len(oracle_centerline)
        agent_traj = np.dstack((np.hstack((agx, ox)), np.hstack((agy, oy))))
        last_diff = oracle_centerline[-1] - oracle_centerline[-2]
        first_diff = oracle_centerline[0] - oracle_centerline[1]
        last_diff /= (2 * np.linalg.norm(last_diff))
        first_diff /= (2 * np.linalg.norm(first_diff))
        num = 100
        last = np.linspace(0.5, pad_right//2 + 0.5 * (pad_right % 2), pad_right).reshape(pad_right * 2, 1) @ last_diff.reshape(1, 2) + oracle_centerline[-1]
        first = oracle_centerline[0] + np.linspace(pad_left, 0.5, pad_left * 2).reshape(pad_left * 2, 1) @ first_diff.reshape(1, 2)
        long_spline = np.concatenate((first, oracle_centerline, last), axis=0)
        oracle_centerline_, _ = interp_polyline_by_fixed_waypt_interval(long_spline, 0.5)
        oracle_centerline_plot = interp_arc(t=agent_traj.shape[1], points=oracle_centerline_[5:-10])
        total_traj = np.concatenate((oracle_centerline_plot.reshape(1, 100, 2), agent_traj), axis=1)
        oracle_centerline_, _ = interp_polyline_by_fixed_waypt_interval(long_spline, 0.5)        
        # inp = []
        # out = []
        # print(total_traj.shape, oracle_centerline_.shape)
        print(len(last), len(first), len(oracle_centerline))
        return torch.tensor(total_traj[0]).double(), torch.tensor(oracle_centerline_).double()
        # return torch.tensor(inp).double(), torch.tensor(out).double(), torch.tensor(traj_inp).double(), torch.tensor(traj_out).double(), torch.tensor(b_inp).double(), torch.tensor(c).double(), torch.tensor(traj_c).double()

In [199]:
# Load the dataset
# train_data = np.load("../datasets/toy/train_data.npy", mmap_mode="c")
DATASET_PATH =  "/mnt/e/datasets/argoverse/parsed"
source_path =  "/mnt/e/datasets/argoverse/val"
sample_len = 100
import os
files = os.listdir(DATASET_PATH)
train_dataset = TrajDataset(DATASET_PATH, source_path)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=False, num_workers=0)
for _, data in enumerate(train_loader):
    if _ < 8: continue
    # inp, out, past, future, inputs, c, traj_c = data
    total_traj, oracle_centerline_ = data        

    frenet_coordinates = project_to_frenet_frame(total_traj, oracle_centerline_)
    # print(frenet_coordinates.shape, "OOOHHHHHHHHHHHHHHH")
    frenet_coordinates = frenet_coordinates.detach()
    frenet_coordinates = frenet_coordinates - frenet_coordinates[:, 139:140,:]
    # print(oracle_centerline.shape, traj.shape, k.shape, distance_to_ref.shape, ref_line.shape)    
    vels = frenet_coordinates[:, 1:] - frenet_coordinates[:, :-1]
    
    # print(frenet_coordinates[:, 139:140,:].shape)

    inp = torch.zeros((train_loader.batch_size, 200), requires_grad=False)
    out = torch.zeros((train_loader.batch_size, 120), requires_grad=False)
    inp[:, 0::5] = frenet_coordinates[:, 100:140, 0]
    inp[:, 1::5] = frenet_coordinates[:, 100:140, 1]
    inp[:, 2::5] = vels[:, 99:139, 0]
    inp[:, 3::5] = vels[:, 99:139, 1]
    inp[:, 4::5] = 0
    out[:, :60] = frenet_coordinates[:, 140:, 0]
    out[:, 60:] = frenet_coordinates[:, 140:, 1]
    
    inp = torch.tensor(inp)
    out = torch.tensor(out)
    
    # for num in range(train_loader.batch_size):
    #     plt.plot(frenet_coordinates[num, :100, 0], frenet_coordinates[num, :100, 1], 'ko', zorder=4)        
    #     plt.plot(frenet_coordinates[num, 100:140, 0], frenet_coordinates[num, 100:140, 1], 'bo', zorder=5)
    #     plt.plot(frenet_coordinates[num, 140:, 0], frenet_coordinates[num, 140:, 1], 'ro', zorder=5)
    #     # plt.plot(oss, od, 'r', zorder=4)
    #     # plt.plot(cs, cd, "ko")
    #     plt.title("after torch frenet projection")
    #     plt.axis('equal')
    #     plt.show()

    pass
    # break
# Using PyTorch Dataloader
# train_loader = DataLoader(train_dataset, batch_size=2, shuffle=False, num_workers=0)

RuntimeError: stack expects each tensor to be equal size, but got [1000, 2] at entry 0 and [999, 2] at entry 7

In [None]:
from beta_cvae_aug_ddn import Encoder, Decoder, Beta_cVAE, BatchOpt_DDN, DeclarativeLayer
from gru_cvae_aug_ddn import GRU_cVAE, DecoderGRU, EncoderGRU

# DDN
num_batch = train_loader.batch_size
node = BatchOpt_DDN(P, Pdot, Pddot, num_batch)
opt_layer = DeclarativeLayer(node)

# Beta-cVAE Inputs
enc_inp_dim = 200
enc_out_dim = 120
dec_inp_dim = enc_inp_dim
dec_out_dim = 8
hidden_dim = 1024 * 2
z_dim = 2

# inp_mean, inp_std = 5.1077423, 20.914295
inp_mean, inp_std = 5.1077423, 10.914295

encoder = Encoder(enc_inp_dim, enc_out_dim, hidden_dim, z_dim)
decoder = Decoder(dec_inp_dim, dec_out_dim, hidden_dim, z_dim)
model = Beta_cVAE(encoder, decoder, opt_layer, inp_mean, inp_std).to(device)

encoder = EncoderGRU(enc_inp_dim, enc_out_dim, hidden_dim, z_dim, batch_size=num_batch)
decoder = DecoderGRU(dec_inp_dim, 2, hidden_dim, z_dim, batch_size=num_batch)
model_gru = GRU_cVAE(encoder, decoder, opt_layer, inp_mean, inp_std).to(device)

P_ = torch.block_diag(P, P).to(device)
Pdot_ = Pdot.to(device)
Pddot_ = Pddot.to(device)

In [None]:
epochs = 60
step, beta = 0, 3.5
optimizer = optim.AdamW(model.parameters(), lr = 1e-4, weight_decay=6e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 12, gamma = 0.1)

avg_train_loss, avg_rcl_loss, avg_kl_loss, avg_aug_loss = [], [], [], []
for epoch in range(epochs):
    
    # Train Loop
    losses_train, kl_losses, rcl_losses, aug_losses = [], [], [], []
    model.train()
    from tqdm import tqdm
    it = 0
    for total_traj, oracle_centerline_ in tqdm(train_loader):        
        frenet_coordinates = project_to_frenet_frame(total_traj, oracle_centerline_)
        # print(frenet_coordinates.shape, "OOOHHHHHHHHHHHHHHH")
        frenet_coordinates = frenet_coordinates.detach()
        frenet_coordinates = frenet_coordinates - frenet_coordinates[:, 139:140,:]
        # print(oracle_centerline.shape, traj.shape, k.shape, distance_to_ref.shape, ref_line.shape)    
        vels = frenet_coordinates[:, 1:] - frenet_coordinates[:, :-1]

        # print(frenet_coordinates[:, 139:140,:].shape)

        inp = torch.zeros((train_loader.batch_size, 200), requires_grad=False)
        out = torch.zeros((train_loader.batch_size, 120), requires_grad=False)
        inp[:, 0::5] = frenet_coordinates[:, 100:140, 0]
        inp[:, 1::5] = frenet_coordinates[:, 100:140, 1]
        inp[:, 2::5] = vels[:, 99:139, 0]
        inp[:, 3::5] = vels[:, 99:139, 1]
        inp[:, 4::5] = 0
        out[:, :60] = frenet_coordinates[:, 140:, 0]
        out[:, 60:] = frenet_coordinates[:, 140:, 1]

        # inp = torch.tensor(inp, requires_grad=True)
        # out = torch.tensor(out, requires_grad=True)
        
        it = it + 1
        inp = inp.to(device)
        out = out.to(device)
        

        traj_gt = out#torch.cat((out[:, :50], out[:, 50:]))
        # z = torch.cat([inp, out], dim = 1)
		# Sample from z -> Reparameterized 
		# z = self._sample_z(mean, std)
		
		# Decode y
		# y_star = self._decoder(z, inp_norm, init_state_ego, y_ub, y_lb)
        
        # Ego vehicle states
        initial_state_ego = inp[:, 2:6].clone()
        initial_state_ego[:, 2:4] = initial_state_ego[:, 0:2]
        initial_state_ego[:, 0:2] = 0
        
        mean, std = model._encoder(inp, traj_gt)
        z = model._sample_z(mean, std)
        y_star = model._decoder(z, inp, initial_state_ego, 0, 0) 
        traj_sol = (P_ @ y_star.T).T 

        # Remember to add the Aug Loss
        KL_loss, RCL_loss, loss, _ = model.forward(inp, traj_gt, initial_state_ego, P_, Pdot_, Pddot_, beta, step)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        losses_train.append(loss.detach().cpu().numpy()) 
        rcl_losses.append(RCL_loss.detach().cpu().numpy())
        kl_losses.append(KL_loss.detach().cpu().numpy())
        
        if it % 5 ==1:
            ags = inp[0, 0::5].detach()
            agd = inp[0, 1::5].detach()
            oss = out[0, :60].detach()
            od = out[0, 60:].detach()
            pss = traj_sol[0, :60].detach()
            pd = traj_sol[0, 60:].detach()
            cs = frenet_coordinates[0, :100, 0].detach()
            cd = frenet_coordinates[0, :100, 1].detach()
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15,5))
            # ax1.plot(ags, agd);
            ax1.plot(oss, od);
            ax1.plot(pss, pd);
            # ax1.plot(cs, cd, "k--");
            # ax1.axis('equal')
            ax2.axis('equal')
            plt.axis('equal')
            plt.show()
            plt.clf()
            from IPython.display import clear_output
            import time
            time.sleep(4)
            clear_output(wait=True)
        
        # aug_losses.append(Aug.detach().cpu().numpy())
        if it % 5 == 1:
            print(f"Epoch: {epoch + 1}, Train Loss: {np.average(losses_train):.3f}, RCL: {np.average(rcl_losses):.3f}, KL: {np.average(kl_losses):.3f}") #, Aug: {np.average(aug_losses):.3f}")

    if epoch % 1 == 0:
        print(f"Epoch: {epoch + 1}, Train Loss: {np.average(losses_train):.3f}, RCL: {np.average(rcl_losses):.3f}, KL: {np.average(kl_losses):.3f}") #, Aug: {np.average(aug_losses):.3f}")

    step += 1.0
    scheduler.step()
    avg_train_loss.append(np.average(losses_train)), avg_rcl_loss.append(np.average(rcl_losses)), \
    avg_kl_loss.append(np.average(kl_losses)) #, avg_aug_loss.append(np.average(aug_losses))

In [None]:
x = np.linspace(0, 10, 20)
y = 2  * x
plt.scatter(x, y)
point = np.array(y) + 4
plt.scatter(x, point)


traj = np.dstack((x, point))
spline = np.dstack((x, y))[0]
print(spline.shape)


oracle_centerline = spline
last_diff = oracle_centerline[-1] - oracle_centerline[-2]
first_diff = oracle_centerline[0] - oracle_centerline[1]
last_diff /= (2 * np.linalg.norm(last_diff))
first_diff /= (2 * np.linalg.norm(first_diff))
num = 100
last = np.linspace(0.5, num, num * 2).reshape(num * 2, 1) @ last_diff.reshape(1, 2) + oracle_centerline[-1]
first = oracle_centerline[0] + np.linspace(num, 0.5, num * 2).reshape(num * 2, 1) @ first_diff.reshape(1, 2)
long_spline = np.concatenate((first, oracle_centerline, last), axis=0)
print(get_polyline_length(long_spline)//0.5)
print(len(oracle_centerline_))
oracle_centerline_, _ = interp_polyline_by_fixed_waypt_interval(long_spline, 0.5)

frenet_coordinates = project_to_frenet_frame(traj, oracle_centerline_)

plt.plot(spline[:, 0], spline[:, 1])
plt.show()
print(frenet_coordinates.shape)

plt.plot(frenet_coordinates[0, :, 0], frenet_coordinates[0, :, 1])
plt.show()

In [None]:
# Sanity Check
for batch_num, (datas) in enumerate(train_loader):
    inp, out = datas
    
    print(inp.shape, out.shape)
        
    num = 2
    
    # # Sanity Check
    # x_obs = inp[num].flatten()[5::5]
    # y_obs = inp[num].flatten()[6::5]
     
    # th = np.linspace(0, 2 * np.pi, 100)
    plt.figure(1)
    
    a_obs, b_obs = 5.8, 3.2
    
    # for i in range(0, 10):
    #     x_ell = x_obs[i] + a_obs * np.cos(th)
    #     y_ell = y_obs[i] + b_obs * np.sin(th)
    #     plt.plot(x_ell, y_ell, '-k', linewidth=1.0)

    plt.axis('equal')
        
    traj_gt = out

    # Ego vehicle states
    initial_state_ego = inp[:, 2:6]
    initial_state_ego[:, 2] = inp[:, 2::5][:, -1]
    initial_state_ego[:, 3] = inp[:, 3::5][:, -1]
    initial_state_ego[:, 0:2] = 0

    mean, std = model._encoder(inp, traj_gt)
    z = model._sample_z(mean, std)
    y_star = model._decoder(z, inp, initial_state_ego, 0, 0) 
    traj_sol = (P @ y_star.T).T

    ccx = traj_sol[num].flatten()[0:60]
    ccy = traj_sol[num].flatten()[60:]
    
    cx = out[num].flatten()[0:60]
    cy = out[num].flatten()[60:]
    
    x_gt =  cx
    y_gt =  cy

    x_gt_ =  ccx
    y_gt_ =  ccy    
    
    x_obs = inp[0::5]
    y_obs = inp[1::5]
    
    ag_x = inp[num].flatten()[0::5]
    ag_y = inp[num].flatten()[1::5]    

    # plt.plot(ag_x.numpy(), ag_y.numpy(), label="Observed", color="blue")    
    plt.plot(x_gt.numpy(), y_gt.numpy(), label="Ground Truth", color="red")
    plt.plot(x_gt_.detach().numpy(), y_gt_.detach().numpy(), label="Predicted", color="orange")
    plt.legend()
    plt.savefig(f"argoverse_figs/{batch_num}_{num}.png")
    plt.clf()
    # plt.show()
    
    # break

In [None]:
torch.save(model.state_dict(), './Weights/cvae_aug_mse.pth')

In [None]:
debug = True
from av2.geometry.interpolate import compute_midpoint_line, interp_arc
from av2.geometry.polyline_utils import centerline_to_polygon
from shapely.geometry import LineString, Point, Polygon
from spline import Spline2D
# Custom Dataset Loader 
class TrajDataset(Dataset):
    """Expert Trajectory Dataset."""
    def __init__(self, data):
        
        # Inputs
        self.inp = data[:, 0:55]
        
        # Outputs
        self.out = data[:, 55:]
        
    def __len__(self):
        return len(self.inp)
    
    def __getitem__(self, idx):
        
        # Inputs
        inp = self.inp[idx]
        
        # Outputs
        out = self.out[idx]
                 
        return torch.tensor(inp).double(), torch.tensor(out).double()

def rotate(gt_x, gt_y,theta):
    gt_x_x = [ (gt_x[k] * np.cos(theta) - gt_y[k] * np.sin(theta))  for k in range(len(gt_x))]
    gt_y_y = [ (gt_x[k] * np.sin(theta) + gt_y[k] * np.cos(theta))  for k in range(len(gt_x))]
    gt_x = gt_x_x
    gt_y = gt_y_y
    return gt_x, gt_y
    
# Load the dataset
# train_data = np.load("../datasets/toy/train_data.npy", mmap_mode="c")
DATASET_PATH =  "/mnt/e/datasets/argoverse/parsed"
import os
files = os.listdir(DATASET_PATH)

# Custom Dataset Loader 
class TrajDataset(Dataset):
    """Expert Trajectory Dataset."""
    def __init__(self, dataset_path):
        
        self.dataset_path = dataset_path
        self.files = os.listdir(DATASET_PATH)
        
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        arr = np.load(os.path.join(self.dataset_path, self.files[idx], f"{self.files[idx]}.npy"), allow_pickle=True)
        map_info = np.load(os.path.join(self.dataset_path, self.files[idx], f"lanes_{self.files[idx]}.npy"), allow_pickle=True)
        agent, obstacle_tracks = None, []
        # [object_state.position[0], object_state.position[1], object_state.heading, object_state.velocity[0], object_state.velocity[1], timestamps_ns[ind], tck[track.category], vmp[track.object_type]
        # 5s of input of agent and 4 obstacles, 5s of output
        inp = np.zeros((200))
        out = np.zeros((120))
        if debug: fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
        for tracks in arr:
            if tracks[0][-2] == 3:
                # FOCAL TRACK
                agent = tracks
                last_obs = np.array(tracks)[-1]
                last_obs_x = last_obs[0]
                last_obs_y = last_obs[1]
                last_obs_vel_x = last_obs[3]
                last_obs_vel_y = last_obs[4]
                for i in range(np.array(tracks).shape[0], 110):
                    timestep = (i - np.array(tracks).shape[0] + 1)
                    tracks.append([last_obs_x + last_obs_vel_x * timestep, last_obs_y + last_obs_vel_y * timestep, last_obs[2], last_obs_vel_x, last_obs_vel_y, i, last_obs[-2], last_obs[-1]])
                agx = np.array(tracks)[10:50, 0] # position-x
                agy = np.array(tracks)[10:50, 1] # position-y
                offsetx = agx[-1]
                offsety = agy[-1]
                theta = np.arctan2(agy[-1] - agy[-2], agx[-1] - agx[-2])
                ox = np.array(tracks)[50:, 0] # position-x
                oy = np.array(tracks)[50:, 1] # position-x
                if debug: ax1.plot(agx - offsetx, agy - offsety, color="blue", linewidth=3, zorder=3)
                if debug: ax1.plot(ox - offsetx, oy - offsety, color="orange", linewidth=3, zorder=3)
                agx, agy = rotate(agx - offsetx, agy - offsety, -theta)
                ox, oy = rotate(ox - offsetx, oy - offsety, -theta)
                vx = np.array(tracks)[10:50, 3] # velocity-x
                vy = np.array(tracks)[10:50, 4] # velocity-y
                vx, vy = rotate(vx, vy, -theta)
                
                hold_lanes = np.array(map_info[0])
                
                max_ind = 0
                max_val = -1e11
                colors = ["r", "g", "b", "o", "p", "r", "g", "b", "o", "p", "r", "g", "b", "o", "p", "r", "g", "b", "o", "p"]
                for lane_num in range(len(hold_lanes)//100):
                    wx = hold_lanes[(lane_num * 100):((lane_num + 1)*100), 3].astype(float)
                    wy = hold_lanes[(lane_num * 100):((lane_num + 1)*100), 4].astype(float)
                    wwx = hold_lanes[(lane_num * 100):((lane_num + 1)*100), 6].astype(float)
                    wwy = hold_lanes[(lane_num * 100):((lane_num + 1)*100), 7].astype(float)
                    # print(np.array(c_lane).shape, "len of c_lane")
                    c = np.array(hold_lanes[(lane_num * 100):((lane_num + 1)*100), 2])
                    cx = []
                    cy = []
                    for pt in c:
                        cx.append(pt[0])
                        cy.append(pt[1])   
                    c = np.dstack((cx, cy))[0]
                    print(type(c), c.shape)
                    lane_seq_polygon = centerline_to_polygon(c)
                    point_in_polygon_score = 0
                    for xy in np.dstack((agx, agy))[0]:
                        point_in_polygon_score += Polygon(lane_seq_polygon).contains(Point(xy))
                        if point_in_polygon_score > max_val:
                            max_val = point_in_polygon_score
                            max_ind = lane_num
                    if debug: ax1.plot(cx - offsetx, cy - offsety, f"{colors[lane_num]}--", linewidth=2)
                    cx, cy = rotate(cx - offsetx, cy - offsety, -theta)
                    if debug: ax2.plot(cx, cy, "g--", linewidth=2)
                            

                hold_lanes = np.array(map_info[0])
                lane_num = max_ind
                wx = hold_lanes[(lane_num * 100):((lane_num + 1)*100), 3].astype(float)
                wy = hold_lanes[(lane_num * 100):((lane_num + 1)*100), 4].astype(float)
                wwx = hold_lanes[(lane_num * 100):((lane_num + 1)*100), 6].astype(float)
                wwy = hold_lanes[(lane_num * 100):((lane_num + 1)*100), 7].astype(float)
                # print(np.array(c_lane).shape, "len of c_lane")
                c = np.array(hold_lanes[(lane_num * 100):((lane_num + 1)*100), 2])
                cx = []
                cy = []
                for pt in c:
                    cx.append(pt[0])
                    cy.append(pt[1])   
                c = np.dstack((cx, cy))[0]
                            
                            
                # cx = [:, 0]
                # cy = np.array(hold_lanes[(lane_num * 100):((lane_num + 1)*100), 2])[:, 1]
                # print(len(cx), len(cy), "CHAGGA")
                # plt.plot(wx, wy, "b", linewidth=3)
                # cx = compute_midpoint_line(np.dstack((wx, wy))[0], np.dstack((wwx, wwy))[0])
                # print(np.array(cx).shape, "shape")
                c = np.dstack((cx, cy))
                ag = np.dstack((agx, agy))
                o = np.dstack((ox, oy))

                spline = Spline2D(c[0][:, 0], c[0][:, 1])
                ags = []
                agd = []
                oss = []
                od = []                
                cs = []
                cd = []
                for pos in ag[0]:
                    s, d = spline.calc_frenet_position(pos[0], pos[1])
                    ags.append(s)
                    agd.append(d)
                
                for pos in o[0]:
                    s, d = spline.calc_frenet_position(pos[0], pos[1])
                    oss.append(s)
                    od.append(d)

                for pos in c[0]:
                    s, d = spline.calc_frenet_position(pos[0], pos[1])
                    cs.append(s)
                    cd.append(d)

                offset_s, offset_d = ags[-1], agd[-1]
                ags = np.array(ags) - offset_s
                agd = np.array(agd) - offset_d
                oss = np.array(oss) - offset_s
                od = np.array(od) - offset_d
                cs = np.array(cs) - offset_s
                cd = np.array(cd) - offset_d     
                    
                vs = (np.array(ags)[1:] - np.array(ags)[:-1])/0.1
                vd = (np.array(agd)[1:] - np.array(agd)[:-1])/0.1
                vs_init = (np.array(ags)[-1] - np.array(oss)[0])/0.1
                vd_init = (np.array(agd)[-1] - np.array(od)[0])/0.1                    
                
                if debug: ax3.plot(ags, agd, color="blue")
                if debug: ax3.plot(oss, od, color="orange")
                if debug: ax3.plot(cs, cd, color="green")                

                ref_line = torch.tensor(c)
                traj = torch.tensor(ag)
                
                # project_to_frenet_frame(traj, ref_line)
                if debug: ax1.axis('equal')
                if debug: ax2.axis('equal')
                if debug: ax3.axis('equal')
                
                
                inp[0::5] = ags
                inp[1::5] = agd
                inp[2::5] = vs_init
                inp[3::5] = vd_init
                inp[7::5] = vs
                inp[8::5] = vd
                inp[4::5] = np.array(tracks)[10:50, 2] - theta # heading
                out[:60] = oss
                out[60:] = od
                c = np.dstack((cs, cd))[0]
                traj_inp = np.hstack((agx, agy)).flatten()
                traj_out = np.hstack((ox, oy)).flatten()
                traj_c = np.hstack((cx, cy)).flatten()
                if debug: ax2.plot(agx, agy, color="blue", linewidth=3, zorder=3)
                if debug: ax2.plot(ox, oy, color="orange", linewidth=3, zorder=3)                
                b_inp = np.array([agx[-1], agy[-1], vx[-1], vy[-1], 0, 0])
        return torch.tensor(inp).double(), torch.tensor(out).double(), torch.tensor(traj_inp).double(), torch.tensor(traj_out).double(), torch.tensor(b_inp).double(), torch.tensor(c).double(), torch.tensor(traj_c).double()

train_dataset = TrajDataset(DATASET_PATH)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)
for _, data in enumerate(train_loader):
    inp, out, past, future, inputs, c, traj_c = data
    ags = inp[0, 0::5].detach()
    agd = inp[0, 1::5].detach()
    oss = out[0, :60].detach()
    od = out[0, 60:].detach()
    cs = c[0,: ,0].detach()
    cd = c[0,: ,1].detach()    
    agx = past[0, :40].detach()
    agy = past[0, 40:].detach()
    ox = future[0, :60].detach()
    oy = future[0, 60:].detach()    
    cx = traj_c[0,:100].detach()
    cy = traj_c[0,100:].detach()
    # print(traj_c.shape, c.shape, past.shape, future.shape)
    # print(c.shape)
    # fig, (ax2, ax1) = plt.subplots(1, 2, figsize=(15,5))
    # ax1.plot(ags, agd);
    # ax1.plot(oss, od);
    # ax1.plot(cs, cd, "g--");
    # ax1.axis('equal')
    # ax1.set_title("Frenet Frame")
    # ax2.plot(agx, agy);
    # ax2.plot(ox, oy);
    # ax2.plot(cx, cy, "g--");
    # ax2.axis('equal')
    # ax2.set_title("Global Frame")
    plt.show()
    # plt.savefig(f"debug/{_}.png")
    plt.clf()
    from IPython.display import clear_output
    import time
    time.sleep(3)
    clear_output(wait=True)
    # break
    pass
# Using PyTorch Dataloader
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)

In [None]:
debug = False
from av2.geometry.interpolate import compute_midpoint_line, interp_arc
from av2.geometry.polyline_utils import centerline_to_polygon
from shapely.geometry import LineString, Point, Polygon
from spline import Spline2D
from pathlib import Path
import pandas as pd

def rotate(gt_x, gt_y,theta):
    gt_x_x = [ (gt_x[k] * np.cos(theta) - gt_y[k] * np.sin(theta))  for k in range(len(gt_x))]
    gt_y_y = [ (gt_x[k] * np.sin(theta) + gt_y[k] * np.cos(theta))  for k in range(len(gt_x))]
    gt_x = gt_x_x
    gt_y = gt_y_y
    return gt_x, gt_y

# Custom Dataset Loader 
class TrajDataset(Dataset):
    """Expert Trajectory Dataset."""
    def __init__(self, dataset_path, source_path):
        
        self.dataset_path = dataset_path
        self.source_path = source_path
        self.files = os.listdir(DATASET_PATH)
        
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        arr = np.load(os.path.join(self.dataset_path, self.files[idx], f"{self.files[idx]}.npy"), allow_pickle=True)
        map_info = np.load(os.path.join(self.dataset_path, self.files[idx], f"lanes_{self.files[idx]}.npy"), allow_pickle=True)
        agent, obstacle_tracks = None, []
        # [object_state.position[0], object_state.position[1], object_state.heading, object_state.velocity[0], object_state.velocity[1], timestamps_ns[ind], tck[track.category], vmp[track.object_type]
        # 5s of input of agent and 4 obstacles, 5s of output
        inp = np.zeros((200))
        out = np.zeros((120))
        if debug: fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
        file = self.files[idx]
        static_map_path = os.path.join(self.source_path, file,  f"log_map_archive_{file}.json")
        scenario_path = os.path.join(self.source_path, file, f"scenario_{file}.parquet")
        
        # lanes = extract_lanes(static_map_path, scenario_path, debug=False, sample_len=100)
        focal_track = None
        for tracks in arr:
            if tracks[0][-2] == 3:
                # FOCAL TRACK
                agent = tracks
                focal_track = tracks
        tracks = focal_track
        last_obs = np.array(tracks)[-1]
        last_obs_x = last_obs[0]
        last_obs_y = last_obs[1]
        last_obs_vel_x = last_obs[3]
        last_obs_vel_y = last_obs[4]
        for i in range(np.array(tracks).shape[0], 110):
            timestep = (i - np.array(tracks).shape[0] + 1)
            tracks.append([last_obs_x + last_obs_vel_x * timestep, last_obs_y + last_obs_vel_y * timestep, last_obs[2], last_obs_vel_x, last_obs_vel_y, i, last_obs[-2], last_obs[-1]])
        agx = np.array(tracks)[10:50, 0] # position-x
        agy = np.array(tracks)[10:50, 1] # position-y
        
        offsetx = agx[-1]
        offsety = agy[-1]
        theta = np.arctan2(agy[-1] - agy[-2], agx[-1] - agx[-2])
        ox = np.array(tracks)[50:, 0] # position-x
        oy = np.array(tracks)[50:, 1] # position-x
        import copy
        aagx = copy.deepcopy(agx)
        aagy = copy.deepcopy(agy)
        oox = copy.deepcopy(ox)
        ooy = copy.deepcopy(oy)                
        
        if debug: ax1.plot(agx - offsetx, agy - offsety, color="blue", linewidth=3, zorder=3)
        if debug: ax1.plot(ox - offsetx, oy - offsety, color="orange", linewidth=3, zorder=3)
        agx, agy = rotate(agx - offsetx, agy - offsety, -theta)
        ox, oy = rotate(ox - offsetx, oy - offsety, -theta)
        vx = np.array(tracks)[10:50, 3] # velocity-x
        vy = np.array(tracks)[10:50, 4] # velocity-y
        vx, vy = rotate(vx, vy, -theta)

        # hold_lanes = np.array(map_info[0])
        # getting the backbone
        static_map = ArgoverseStaticMap.from_json(Path(static_map_path))
        backbones = []
        for tm in range(0, 40):
            min_ind = get_nearest_centerline(agx, agy, static_map, tm=tm)
            cx = get_centerline(static_map, min_ind)
            lane_seg = static_map.vector_lane_segments[min_ind]
            if tm == 0:
                backbones.append([min_ind])
                continue
            last_lane = backbones[-1][-1]
            if min_ind == last_lane:
                continue
            fg = True
            for backbone in backbones:
                if min_ind in backbone:
                    fg = False
                if min_ind in static_map.vector_lane_segments[backbone[-1]].successors:
                    backbone.append(min_ind)
                    fg = False
            if fg:
                backbones.append([min_ind])
            # plt.plot(cx[:, 0], cx[:, 1])

        for backbone in backbones:
            import random
            # color = "#"+''.join([random.choice('0123456789ABCDEF') for j in range(6)])
            for min_ind in backbone:
                cx = get_centerline(static_map, min_ind)
                # plt.plot(cx[:, 0], cx[:, 1], color=color)

        max_backbone = None
        max_len = -1e11
        for backbone in backbones:
            total = 0
            for lane in backbone:
                cx = get_centerline(static_map, lane)
                total += get_polyline_length(cx)
            if total > max_len:
                max_len = total
                max_backbone = backbone

        MAX_DIS = 300

        hold_lanes = []
        import copy

        qry_pt = [agx[39], agy[39]]

        def BFS(queue, lane_id, visited, hold, static_map):
            while len(queue):
                [cur, distance, hold] = queue.pop(0)
                if distance > MAX_DIS:
                    hold_lanes.append(hold)
                    continue
                successors = static_map.vector_lane_segments[cur].successors
                done = False
                for successor in successors:
                    if visited.get(successor) == None:
                        # not visited
                        try:
                            static_map.vector_lane_segments[successor]
                            hold_ = copy.deepcopy(hold)
                            hold_.append(successor)
                            cx = get_centerline(static_map, successor)
                            distance1 = get_distance(qry_pt, cx[-1])
                            distance2 = get_distance(qry_pt, cx[0])
                            queue.append([successor, max(distance1, distance2), hold_])
                            visited[successor] = True
                        except:
                            done = True
                if done:
                    hold_lanes.append(hold)
                    pass
                if len(successors) == 0:
                    # end of the line
                    hold_lanes.append(hold)
                    pass

        # hold
        visited = {}
        for lane in backbone:
            queue = []
            if visited.get(lane) == None:
                visited[lane] = True
                cx = get_centerline(static_map, lane)
                distance1 = get_distance(qry_pt, cx[-1])
                distance2 = get_distance(qry_pt, cx[0])
                hold = []
                hold.append(lane)
                queue.append([lane, max(distance1, distance2), hold])
                BFS(queue, lane, visited, hold, static_map)
                cnt = 0
                for lanes in hold_lanes:
                    color = "#"+''.join([random.choice('0123456789ABCDEF') for j in range(6)])
                    for lane_ in lanes:
                        # if cnt == 1: print("BAZINGA", lane_)
                        cx = get_centerline(static_map, lane_)
                        # if debug: plt.plot(cx[:, 0], cx[:, 1], color=color)
                    cnt = cnt + 1
                    # if debug: plt.scatter(agx[:50], agy[:50], color="blue", zorder=10)
                    # if debug: plt.scatter(agx[50:], agy[50:], color="orange", zorder=10)                    
                    # # if debug: plt.axis('equal')
                    # if debug: plt.xlim([np.min(agx) - 100, np.max(agx) + 100])
                    # if debug: plt.ylim([np.min(agy) - 150, np.max(agy) + 150])                
                    # if debug: plt.show()
                    # if debug: plt.clf()

        hold_array = hold_lanes
        lanes = []
        for lane in hold_array:
            lane_information = []
            for lane_id in lane:
                # if cnt == 1: print("LAUWA LAUWA", lane_id)
                wx = [waypt.x for waypt in static_map.vector_lane_segments[lane_id].left_lane_boundary.waypoints]
                wy = [waypt.y for waypt in static_map.vector_lane_segments[lane_id].left_lane_boundary.waypoints]
                wwx = [waypt.x for waypt in static_map.vector_lane_segments[lane_id].right_lane_boundary.waypoints]
                wwy = [waypt.y for waypt in static_map.vector_lane_segments[lane_id].right_lane_boundary.waypoints]
                left_boundary = np.dstack((wx, wy))[0]
                right_boundary = np.dstack((wwx, wwy))[0]
                left_boundary = interp_arc(t=100, points=left_boundary)
                right_boundary = interp_arc(t=100, points=right_boundary)
                cx = get_centerline(static_map, lane_id)
                cx = interp_arc(t=sample_len, points=cx)
                is_intersection = static_map.vector_lane_segments[lane_id].is_intersection
                lane_type = static_map.vector_lane_segments[lane_id].lane_type.name
                left_mark_type = static_map.vector_lane_segments[lane_id].left_mark_type.name
                right_mark_type = static_map.vector_lane_segments[lane_id].right_mark_type.name
                for ind in range(sample_len):
                    lane_information.append([is_intersection, lane_type, cx[ind],
                                            left_boundary[ind][0], left_boundary[ind][1], left_mark_type, 
                                             right_boundary[ind][0], right_boundary[ind][1], right_mark_type, cx[ind][0],  cx[ind][1]])
                    pass
            lanes.append(lane_information)


        max_ind = 0
        max_val = -1e11
        hold_lanes = np.array(lanes[0])
        colors = ["r", "g", "r", "y", "b", "r", "g", "r", "y", "b", "r", "g", "r", "y", "b"]
        for lane_num in range(len(hold_lanes)//100):
            wx = hold_lanes[(lane_num * 100):((lane_num + 1)*100), 3].astype(float)
            wy = hold_lanes[(lane_num * 100):((lane_num + 1)*100), 4].astype(float)
            wwx = hold_lanes[(lane_num * 100):((lane_num + 1)*100), 6].astype(float)
            wwy = hold_lanes[(lane_num * 100):((lane_num + 1)*100), 7].astype(float)
            ccx = hold_lanes[(lane_num * 100):((lane_num + 1)*100), 9].astype(float)
            ccy = hold_lanes[(lane_num * 100):((lane_num + 1)*100), 10].astype(float)
            plt.plot(ccx, ccy, color="green")
            # print(np.array(c_lane).shape, "len of c_lane")
            c = np.array(hold_lanes[(lane_num * 100):((lane_num + 1)*100), 2])
            cx = []
            cy = []
            for pt in c:
                cx.append(pt[0])
                cy.append(pt[1])   
            c = np.dstack((cx, cy))[0]
            print(type(c), c.shape)
            lane_seq_polygon = centerline_to_polygon(c)
            point_in_polygon_score = 0
            for xy in np.dstack((agx, agy))[0]:
                point_in_polygon_score += Polygon(lane_seq_polygon).contains(Point(xy))
                if point_in_polygon_score > max_val:
                    max_val = point_in_polygon_score
                    max_ind = lane_num
            if debug: ax1.plot(ccx - offsetx, ccy - offsety, f"{colors[lane_num]}--", linewidth=2)
            cx, cy = rotate(ccx - offsetx, ccy - offsety, -theta)
            if debug: ax2.plot(cx, cy, "g--", linewidth=2)
        plt.plot(aagx, aagy, color="blue")
        plt.plot(oox, ooy, color="orange")
        plt.show()
        hold_lanes = np.array(map_info[0])
        lane_num = 0
        wx = hold_lanes[(lane_num * 100):((lane_num + 1)*100), 3].astype(float)
        wy = hold_lanes[(lane_num * 100):((lane_num + 1)*100), 4].astype(float)
        wwx = hold_lanes[(lane_num * 100):((lane_num + 1)*100), 6].astype(float)
        wwy = hold_lanes[(lane_num * 100):((lane_num + 1)*100), 7].astype(float)
        # print(np.array(c_lane).shape, "len of c_lane")
        c = np.array(hold_lanes[(lane_num * 100):((lane_num + 1)*100), 2])
        cx = []
        cy = []
        for pt in c:
            cx.append(pt[0])
            cy.append(pt[1])   
        c = np.dstack((cx, cy))[0]


        # cx = [:, 0]
        # cy = np.array(hold_lanes[(lane_num * 100):((lane_num + 1)*100), 2])[:, 1]
        # print(len(cx), len(cy), "CHAGGA")
        # plt.plot(wx, wy, "b", linewidth=3)
        # cx = compute_midpoint_line(np.dstack((wx, wy))[0], np.dstack((wwx, wwy))[0])
        # print(np.array(cx).shape, "shape")
        c = np.dstack((cx, cy))
        ag = np.dstack((agx, agy))
        o = np.dstack((ox, oy))

        spline = Spline2D(c[0][:, 0], c[0][:, 1])
        ags = []
        agd = []
        oss = []
        od = []                
        cs = []
        cd = []
        for pos in ag[0]:
            s, d = spline.calc_frenet_position(pos[0], pos[1])
            ags.append(s)
            agd.append(d)

        for pos in o[0]:
            s, d = spline.calc_frenet_position(pos[0], pos[1])
            oss.append(s)
            od.append(d)

        for pos in c[0]:
            s, d = spline.calc_frenet_position(pos[0], pos[1])
            cs.append(s)
            cd.append(d)

        offset_s, offset_d = ags[-1], agd[-1]
        ags = np.array(ags) - offset_s
        agd = np.array(agd) - offset_d
        oss = np.array(oss) - offset_s
        od = np.array(od) - offset_d
        cs = np.array(cs) - offset_s
        cd = np.array(cd) - offset_d     

        vs = (np.array(ags)[1:] - np.array(ags)[:-1])/0.1
        vd = (np.array(agd)[1:] - np.array(agd)[:-1])/0.1
        vs_init = (np.array(ags)[-1] - np.array(oss)[0])/0.1
        vd_init = (np.array(agd)[-1] - np.array(od)[0])/0.1                    

        if debug: ax3.plot(ags, agd, color="blue")
        if debug: ax3.plot(oss, od, color="orange")
        if debug: ax3.plot(cs, cd, color="green")                

        ref_line = torch.tensor(c)
        traj = torch.tensor(ag)

        # project_to_frenet_frame(traj, ref_line)
        if debug: ax1.axis('equal')
        if debug: ax2.axis('equal')
        if debug: ax3.axis('equal')


        inp[0::5] = ags
        inp[1::5] = agd
        inp[2::5] = vs_init
        inp[3::5] = vd_init
        inp[7::5] = vs
        inp[8::5] = vd
        inp[4::5] = np.array(tracks)[10:50, 2] - theta # heading
        out[:60] = oss
        out[60:] = od
        c = np.dstack((cs, cd))[0]
        traj_inp = np.hstack((agx, agy)).flatten()
        traj_out = np.hstack((ox, oy)).flatten()
        traj_c = np.hstack((cx, cy)).flatten()
        if debug: ax2.plot(agx, agy, color="blue", linewidth=3, zorder=3)
        if debug: ax2.plot(ox, oy, color="orange", linewidth=3, zorder=3)                
        b_inp = np.array([agx[-1], agy[-1], vx[-1], vy[-1], 0, 0])
        return torch.tensor(inp).double(), torch.tensor(out).double(), torch.tensor(traj_inp).double(), torch.tensor(traj_out).double(), torch.tensor(b_inp).double(), torch.tensor(c).double(), torch.tensor(traj_c).double()