refs: https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation-batched.ipynb
https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/master/decode_beam.py

In [1]:
import os
import pandas as pd
import numpy as np
from collections import defaultdict
from tqdm.auto import tqdm
import sys
sys.path.insert(0, '../scripts')
from map_traffic_lights_data import * #master_intersection_idx_2_tl_signal_indices, get_lane_point_coordinates, get_lane_len, lane_id_2_idx#get_lane_center_line
# early stopping source: https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py
from pytorchtools import EarlyStopping
from datetime import timedelta
from typing import Dict, List
import torch 
from torch.autograd import Variable
from torch import Tensor
from torch.nn import functional
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch import nn, optim
import heapq
from dataclasses import dataclass, field
import time
import gc
from torchviz import make_dot
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True

HBox(children=(FloatProgress(value=0.0, description='Computing lane adjacency lists', max=528.0, style=Progres…




HBox(children=(FloatProgress(value=0.0, description='Computing lane adjacency lists', max=7977.0, style=Progre…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Lane blocked sets..', layout=Layout(wid…




In [2]:
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [3]:
MAP_SEGMENT_I = 4
# TRAIN_INPUT_PATHS = [f'../input/agent_lane_seq_df_validate_0.hdf5']#validate_0.hdf5']
# VAL_INPUT_PATHS = [f'../input/agent_lane_seq_df_sample_0.hdf5']
TRAIN_INPUT_PATHS = [f'../input/agent_lane_seq_df_validate_0_{MAP_SEGMENT_I}_preproc.hdf5']#validate_0.hdf5']
VAL_INPUT_PATHS = [f'../input/agent_lane_seq_df_sample_0_{MAP_SEGMENT_I}_preproc.hdf5']
VOCAB_PATH = f'../input/agent_lane_seq_df_validate_0_{MAP_SEGMENT_I}_preproc_vocab.pkl'
HIST_LEN_FRAMES = 100
FUTURE_LEN_FRAMES = 50
MIN_REQUIRED_TOKEN_FREQ = 4
AGENT_SPEED_MAX = 15.
AGENT_YAW_MAX = 2*np.pi

with open(os.path.join(SEGMENTS_OUTPUT_PATH, 'map_segment_2_lanes.pkl'), 'rb') as f:
    map_segment_2_lanes = pickle.load(f)
MAP_SEGMENT_2_X_MIN = [float('inf') for _ in range(len(map_segment_2_lanes))]
MAP_SEGMENT_2_X_MAX = [-float('inf') for _ in range(len(map_segment_2_lanes))]
MAP_SEGMENT_2_Y_MIN = [float('inf') for _ in range(len(map_segment_2_lanes))]
MAP_SEGMENT_2_Y_MAX = [-float('inf') for _ in range(len(map_segment_2_lanes))]
for map_segment in range(len(map_segment_2_lanes)):
    for lane_id in map_segment_2_lanes[map_segment]:
        lane_center = get_lane_center_line(lane_id)
        x_min, x_max = lane_center[:, 0].min(), lane_center[:, 0].max()
        y_min, y_max = lane_center[:, 1].min(), lane_center[:, 1].max()
        if x_min < MAP_SEGMENT_2_X_MIN[map_segment]:
            MAP_SEGMENT_2_X_MIN[map_segment] = x_min
        if y_min < MAP_SEGMENT_2_Y_MIN[map_segment]:
            MAP_SEGMENT_2_Y_MIN[map_segment] = y_min
        if x_max > MAP_SEGMENT_2_X_MAX[map_segment]:
            MAP_SEGMENT_2_X_MAX[map_segment] = x_max
        if y_max > MAP_SEGMENT_2_Y_MAX[map_segment]:
            MAP_SEGMENT_2_Y_MAX[map_segment] = y_max
MAX_DIST_DIFF_M = 10

trn_data_need_preprocessing = False
val_data_need_preprocessing = False
store_preprocessed_results = False

trn_output_name = ''
val_output_name = ''

In [4]:
# agent_lane_df_trn = pd.concat([pd.read_hdf(path, key='data') for path in TRAIN_INPUT_PATHS])
# agent_lane_df_val = pd.concat([pd.read_hdf(path, key='data') for path in VAL_INPUT_PATHS])

agent_lane_df_trn = pd.read_hdf(TRAIN_INPUT_PATHS[0], key='data')
agent_lane_df_val = pd.read_hdf(VAL_INPUT_PATHS[0], key='data')


def filter_map_segment(agent_lane_df, map_segment_i=MAP_SEGMENT_I):
    agent_scene_2_last_map_segment = agent_lane_df.groupby(['agent_track_id', 'scene_idx'])['map_segment_group'].nth(-1)
    map_segment_selection_ = agent_scene_2_last_map_segment[agent_scene_2_last_map_segment == map_segment_i].reset_index()
    track_id__scene_idx__set = set((map_segment_selection_['agent_track_id'].map(lambda x: [x]) +
                                    map_segment_selection_['scene_idx'].map(lambda x: [x])).map(tuple).values)
    track_id__scene_idx__series = (agent_lane_df['agent_track_id'].map(lambda x: [x]) +
                                   agent_lane_df['scene_idx'].map(lambda x: [x])).map(tuple)
    agent_lane_df = agent_lane_df[track_id__scene_idx__series.map(lambda x: x in track_id__scene_idx__set)]
    return agent_lane_df

# def filter_map_segment(agent_lane_df, map_segment_i=MAP_SEGMENT_I):
#     agent_lane_df = agent_lane_df[agent_lane_df['map_segment_group'] == map_segment_i]
#     return agent_lane_df

if trn_data_need_preprocessing:
    agent_lane_df_trn = filter_map_segment(agent_lane_df_trn)[['lane_id', 'map_segment_group', 'lane_point_i', 
                            'agent_speed', 'agent_yaw', 'agent_centroid_shift', 'agent_track_id', 'scene_idx', 'timestamp']]

if val_data_need_preprocessing:
    agent_lane_df_val = filter_map_segment(agent_lane_df_val)[['lane_id', 'map_segment_group', 'lane_point_i',
                        'agent_speed', 'agent_yaw', 'agent_centroid_shift', 'agent_track_id', 'scene_idx', 'timestamp']]

In [5]:
# pd.set_option('display.max_rows', 1000)
# agent_lane_df_trn.head(500)

In [6]:
# temp_ = (agent_lane_df_trn.loc[agent_lane_df_trn['the_same_agent_prev'], 'timestamp']
#          .diff(1)
#          .map(lambda x: x < timedelta(seconds=0.11))
#          .value_counts())
# temp_/temp_.sum()

## True     0.984327
## False    0.015673

In [7]:
# _temp = agent_lane_df_trn.groupby(['agent_track_id', 'scene_idx'])['map_segment_group'].nunique().value_counts()
# _temp/_temp.sum()

# # 1    0.967818
# # 2    0.032179
# # 3    0.000002

In [8]:
def compute_valid_hist_seq_len(agent_lane_df):
    valid_hist_len_list = []
    is_the_same_agent_prev = agent_lane_df['the_same_agent_prev'].values
    for row_i in tqdm(range(len(agent_lane_df)), desc='Last hist valid...'):
        last_valid_idx = row_i
        while (is_the_same_agent_prev[last_valid_idx] and
               row_i - last_valid_idx + 1 < HIST_LEN_FRAMES):
            last_valid_idx -= 1
        valid_hist_len_list.append(row_i - last_valid_idx)
    agent_lane_df['valid_hist_len'] = valid_hist_len_list 
     

def compute_valid_future_seq_len(agent_lane_df):
    is_the_same_agent_next = agent_lane_df['the_same_agent_next'].values
    valid_future_len_list = []
    for row_i in tqdm(range(len(agent_lane_df)), desc='Last future valid...'):
        last_valid_idx = row_i + 1
        if not is_the_same_agent_next[row_i]:
            valid_future_len_list.append(0)
            continue
        while (last_valid_idx + 1 < len(is_the_same_agent_next) and
               is_the_same_agent_next[last_valid_idx] and
               last_valid_idx - row_i < FUTURE_LEN_FRAMES):
            last_valid_idx += 1
        valid_future_len_list.append(last_valid_idx - row_i)
    agent_lane_df['valid_future_len'] = valid_future_len_list
    
if trn_data_need_preprocessing:
    lanes_trn = agent_lane_df_trn['lane_id'].unique()
    lane_2_count = agent_lane_df_trn.groupby(['lane_id'])['lane_id'].count()
    infrequent_lane_points = set(lane_2_count[lane_2_count < MIN_REQUIRED_TOKEN_FREQ].index)
    map_segment_2_train_vocab = defaultdict(dict)
    map_segment_2_vocab_i_2_lane = defaultdict(list)

    for lane_id in lanes_trn:
        if lane_id not in infrequent_lane_points:
            map_segment_2_train_vocab[MAP_SEGMENT_I][lane_id] = len(map_segment_2_train_vocab[MAP_SEGMENT_I])
            map_segment_2_vocab_i_2_lane[MAP_SEGMENT_I].append(lane_id)
    del infrequent_lane_points
else:
    with open(VOCAB_PATH, 'rb') as f:
        map_segment_2_train_vocab = pickle.load(f)

# mapping lane_point_id to vocab idx
def compute_vocab_indices(agent_lane_df, map_segment_idx=MAP_SEGMENT_I):
    def get_vocab_idx(lane_id):
        if lane_id in map_segment_2_train_vocab[map_segment_idx]:
            return map_segment_2_train_vocab[map_segment_idx][lane_id]
        # unknown token
        return len(map_segment_2_train_vocab[map_segment_idx])
    agent_lane_df['lane_vocab_idx'] = agent_lane_df['lane_id'].map(lambda lane_id: get_vocab_idx(lane_id))
                                                                              
    
# computing true agent coordinates
def compute_agent_coord(agent_lane_df):
    agent_lane_df['agent_coord'] = ((agent_lane_df['lane_id'].map(lambda x: [x]) +
                                     agent_lane_df['lane_point_i'].map(lambda x: [x])).map(lambda x: get_lane_point_coordinates(*x)) +
                                    agent_lane_df['agent_centroid_shift'])
    
    
def get_dist_to_lane_end_in_sampled_points(lane_id, lane_point_idx):
    lane_len_points = get_lane_len(lane_id)
    return lane_len_points - lane_point_idx - 1


def compute_agent_lane_features(agent_lane_df):
    agent_lane_df['points_to_lane_end'] = (agent_lane_df['lane_id'].map(lambda x: [x]) +
                                     agent_lane_df['lane_point_i'].map(lambda x: [x])).map(lambda x: get_dist_to_lane_end_in_sampled_points(*x))


def estimate_agent_speed(agent_lane_df):
    coord__coord_next__speed__is_the_same_series = (agent_lane_df['agent_coord'].map(lambda x: [x]) + 
                                                    agent_lane_df['agent_coord'].shift(-1).map(lambda x: [x]) + 
                                                    agent_lane_df['agent_speed'].map(lambda x: [x]) + 
                                                    agent_lane_df['the_same_agent_next'].map(lambda x: [x]))

    def derive_speed_abs(coords, coords_next, speed, is_next_the_same):
        if is_next_the_same:
            coords_diff = coords_next - coords
            return np.hypot(*coords_diff)*10
        return speed
    
    agent_lane_df['agent_speed_derived'] = coord__coord_next__speed__is_the_same_series.map(lambda x: derive_speed_abs(*x))
    

def preprocess_agent_lane_df(agent_lane_df):
#     agent_lane_df.sort_values(by=['scene_idx', 'agent_track_id', 'timestamp'], inplace=True)
    agent_lane_df['the_same_agent_prev'] = ((agent_lane_df['agent_track_id'].shift(1) == agent_lane_df['agent_track_id']) & 
                                           (agent_lane_df['scene_idx'].shift(1) == agent_lane_df['scene_idx']) & 
                                           (agent_lane_df['timestamp'].diff(1) < timedelta(seconds=0.11))) # &
#                                            (agent_lane_df['map_segment_group'].shift(1) == agent_lane_df['map_segment_group']))
    agent_lane_df['the_same_agent_next'] = ((agent_lane_df['agent_track_id'].shift(-1) == agent_lane_df['agent_track_id']) & 
                                           (agent_lane_df['scene_idx'].shift(-1) == agent_lane_df['scene_idx']) & 
                                           (agent_lane_df['timestamp'].diff(1).shift(-1) < timedelta(seconds=0.11)))# &
#                                            (agent_lane_df['map_segment_group'].shift(-1) == agent_lane_df['map_segment_group']))
    
    compute_valid_hist_seq_len(agent_lane_df)
    compute_valid_future_seq_len(agent_lane_df)
    compute_vocab_indices(agent_lane_df)
    compute_agent_coord(agent_lane_df)
    compute_agent_lane_features(agent_lane_df)
#     estimate_agent_speed(agent_lane_df)
    
    
def store_precomputed_df(agent_lane_df, input_paths, output_name, store_vocab=False):
    if len(TRAIN_INPUT_PATHS) == 1:
        output_path = os.path.splitext(input_paths[0])[0]+ f'_{MAP_SEGMENT_I}_preproc.hdf5'
    elif output_name != '':
        output_path = os.path.join('../input', f'{output_name}.hdf5')
    else:
        time_str = time.strftime('%Y%m%d_%H%M%S')
        output_path = os.path.join('../input', f'_{"_".join([os.path.splitext(x)[0] for x in input_paths])}_{MAP_SEGMENT_I}_preprocessed_{time_str}.hdf5')
    print('output_path', output_path)
    agent_lane_df.to_hdf(output_path, key='data')
    if store_vocab:
        vocab_path = os.path.splitext(output_path)[0] + '_vocab.pkl'
        with open(vocab_path, 'wb') as f:
            pickle.dump(map_segment_2_train_vocab, f)
                                   
                                   
if trn_data_need_preprocessing:
    preprocess_agent_lane_df(agent_lane_df_trn)
    if store_preprocessed_results:
        store_precomputed_df(agent_lane_df_trn, TRAIN_INPUT_PATHS, trn_output_name, store_vocab=True)

if val_data_need_preprocessing:
    preprocess_agent_lane_df(agent_lane_df_val)
    if store_preprocessed_results:
        store_precomputed_df(agent_lane_df_val, VAL_INPUT_PATHS, val_output_name)

In [9]:
# check_df = agent_lane_df_trn[(agent_lane_df_trn['agent_track_id'] == 18) & (agent_lane_df_trn['scene_idx'] == 7170)]
# check_df['next_coord'] = check_df[['agent_coord', 'agent_yaw', 'agent_speed']].apply(lambda x: [x['agent_coord'][0] + np.cos(x['agent_yaw'])*x['agent_speed']/10,
#                                                                         x['agent_coord'][1] + np.sin(x['agent_yaw'])*x['agent_speed']/10], axis=1)

In [10]:
# check_df

In [11]:
# agent_lane_df_trn[(agent_lane_df_trn['agent_track_id'] == 18) & (agent_lane_df_trn['scene_idx'] == 7170)]

In [12]:
agent_lane_df_trn['points_to_lane_end'].max(), agent_lane_df_trn['points_to_lane_end'].quantile(0.99)

(147, 24.0)

In [13]:
from torch.autograd import Variable

class LaneSeqModel(nn.Module):

    def __init__(self, vocab,
                 embedding_dim=16,
                 hidden_dim_lane_lstm=64,
                 hidden_dim_speed_lstm=64,
                 hidden_dim_coord_lstm=64,
                 n_layers_lane_lstm=1,
                 n_layers_speed_lstm=1,
                 n_layers_coord_lstm=1,
                 bidirectional_speed_lstm=False,
                 map_segment=4,
                 dropout=0.2,
                 max_agent_points_to_lane_end=agent_lane_df_trn['points_to_lane_end'].max(),
                 device='cuda:0',
                 speed_max=AGENT_SPEED_MAX,
                 yaw_max=AGENT_YAW_MAX,
                 input_hist_max_len=HIST_LEN_FRAMES,
                 prediction_horizon_steps=FUTURE_LEN_FRAMES,
                 max_dist_diff_m=MAX_DIST_DIFF_M,
                 agent_speed_max=AGENT_SPEED_MAX,
                 agent_yaw_max=AGENT_YAW_MAX,
                 beam_width=3):

        self.device = device
        self.vocab_size = len(vocab)

        # Constructor
        super().__init__()

        # embedding layer
        self.lane_embedding = nn.Embedding(self.vocab_size + 2,
                                           embedding_dim)  # including PAD and UNKNOWN tokens, even though PAD must be redundant

        # lstm for lanes seq
        self.lstm_lanes = nn.LSTM(embedding_dim + 3,
                            hidden_dim_lane_lstm,
                            num_layers=n_layers_lane_lstm,
                            bidirectional=False,
                            dropout=dropout,
                            batch_first=True)

        # fc layer to estimate next lane probs
        self.fc_next_lane_logits = nn.Linear(hidden_dim_lane_lstm, self.vocab_size + 2)
        self.fc_next_lane_logits.bias.data.fill_(0.0)
        # fc layer to estimate next dist to lane end
        self.fc_next_dist_to_lane_end = nn.Linear(hidden_dim_lane_lstm, 1)
        
        # lane point coordinates
        lane_point_coordinates = torch.zeros((self.vocab_size + 2, max_agent_points_to_lane_end, 2)).to(device)
        for lane_id, vocab_i in vocab.items():
            for points_to_lane_end in range(max_agent_points_to_lane_end):
                lane_len = get_lane_len(lane_id)
                lane_point_i = max(0, lane_len - 1 - points_to_lane_end)
                point_coords = get_lane_point_coordinates(lane_id, lane_point_i)
                lane_point_coordinates[vocab_i, points_to_lane_end, 0] = point_coords[0]
                lane_point_coordinates[vocab_i, points_to_lane_end, 1] = point_coords[1]
#         lane_point_coordinates_np = np.zeros((self.vocab_size + 2, max_agent_points_to_lane_end, 2))
#         for lane_id, vocab_i in vocab.items():
#             for points_to_lane_end in range(max_agent_points_to_lane_end):
#                 lane_len = get_lane_len(lane_id)
#                 lane_point_i = max(0, lane_len - 1 - points_to_lane_end)
#                 lane_point_coordinates_np[vocab_i, points_to_lane_end] = get_lane_point_coordinates(lane_id, lane_point_i)
#         lane_point_coordinates = Variable(torch.Tensor(lane_point_coordinates_np), requires_grad=False).to(device)
        self.lane_point_x_coord = torch.squeeze(lane_point_coordinates[:, :, 0])
        self.lane_point_y_coord = torch.squeeze(lane_point_coordinates[:, :, 1])

        # lane_point probs activation
        self.softmax_act = nn.Softmax(dim=-1)
        
        # lstm for speed/angle seq
        self.lstm_speed = nn.LSTM(2,
                            hidden_dim_speed_lstm,
                            num_layers=n_layers_speed_lstm,
                            bidirectional=bidirectional_speed_lstm,
                            dropout=dropout,
                            batch_first=True)
        
        # fc layer to estimate next speed abs val
        self.fc_next_speed_yaw = nn.Linear(hidden_dim_speed_lstm * (1 + int(bidirectional_speed_lstm)), 2)
        self.bidirectional_speed_lstm = bidirectional_speed_lstm
        
        # lstm for true coord
        self.lstm_coord = nn.LSTM(7,
                            hidden_dim_coord_lstm,
                            num_layers=n_layers_coord_lstm,
                            bidirectional=False,
                            dropout=dropout,
                            batch_first=True)
        
#         # fc layer to enable model to ignore lane-based coord estimation (e.g. for cars far from lane)
#         self.fc_is_lane_coord_relevant = nn.Linear(6, 1)
#         self.fc_is_lane_coord_relevant.bias.data.fill_(-1)
#         self.fc_is_const_accel_coord_relevant = nn.Linear(6, 1)
#         self.fc_is_const_accel_coord_relevant.bias.data.fill_(0)
#         self.fc_is_const_turn_coord_relevant = nn.Linear(6, 1)
#         self.fc_is_const_turn_coord_relevant.bias.data.fill_(0)
#         self.sigmoid_act = nn.Sigmoid()
        
        # fc to estimate offsets from fixed lane_point coordinates
        # TODO: consider weight decay for combination of coord estimations: https://discuss.pytorch.org/t/simple-l2-regularization/139/2
#         self.fc_coord = nn.Linear(hidden_dim_coord_lstm + hidden_dim_lane_lstm, 2)
#         fc_head_num_features = (hidden_dim_coord_lstm + 
#                                 hidden_dim_lane_lstm + 
#                                 hidden_dim_speed_lstm * (1 + int(bidirectional_speed_lstm)))
        fc_head_num_features = hidden_dim_coord_lstm
        self.bn_1 = nn.BatchNorm1d(fc_head_num_features)
        self.fc_coord_hidden = nn.Linear(fc_head_num_features, 32) # + 
#                                          hidden_dim_lane_lstm + 
#                                          hidden_dim_speed_lstm * (1 + int(bidirectional_speed_lstm)), 64)
        self.fc_coord_hidden_act = nn.ReLU()
        self.bn_2 = nn.BatchNorm1d(32)
        self.fc_coord = nn.Linear(32, 2)
#         self.fc_coord = nn.Linear(6 + 
#                                   hidden_dim_lane_lstm + 
#                                   hidden_dim_speed_lstm * (1 + int(bidirectional_speed_lstm)), 2)

        self.prediction_horizon_steps = prediction_horizon_steps
        self.input_hist_max_len = input_hist_max_len
        
        # normalization consts
        self.max_dist_diff_m = max_dist_diff_m
        self.x_coord_min = MAP_SEGMENT_2_X_MIN[map_segment]
        self.x_coord_max = MAP_SEGMENT_2_X_MAX[map_segment]
        self.y_coord_min = MAP_SEGMENT_2_Y_MIN[map_segment]
        self.y_coord_max = MAP_SEGMENT_2_Y_MAX[map_segment]
        self.fc_coord.bias.data = torch.Tensor([self.x_coord_min + (self.x_coord_max-self.x_coord_min)/2,
                                                   self.y_coord_min + (self.y_coord_max-self.y_coord_min)/2])
        self.speed_max = speed_max
        self.yaw_max = yaw_max
        self.max_agent_points_to_lane_end = max_agent_points_to_lane_end
        self.agent_speed_max = agent_speed_max
        self.agent_yaw_max = agent_yaw_max

        # params for beam search
        self.beam_width = beam_width
        
    def normalize_coordinates(self, coord_seq):
        # coord seq shape = [batch_size, max_seq_len, 2]
        max_min_range_x = (self.x_coord_max - self.x_coord_min)*torch.ones((coord_seq.shape[0], coord_seq.shape[1], 1))
        max_min_range_y = (self.y_coord_max - self.y_coord_min)*torch.ones((coord_seq.shape[0], coord_seq.shape[1], 1))
        max_min_range = torch.cat((max_min_range_x, max_min_range_y), dim=-1).to(self.device)
        
        min_x = self.x_coord_min*torch.ones((coord_seq.shape[0], coord_seq.shape[1], 1))
        min_y = self.y_coord_min*torch.ones((coord_seq.shape[0], coord_seq.shape[1], 1))
        min_vals = torch.cat((min_x, min_y), dim=-1).to(self.device)
        
        return (coord_seq - min_vals) / max_min_range
    
    def normalize_speed_yaw(self, speed_yaw_seq):
        # coord seq shape = [batch_size, max_seq_len, 2]
        max_speed_tensor = self.speed_max * torch.ones((speed_yaw_seq.shape[0], speed_yaw_seq.shape[1], 1))
        max_yaw_tensor = self.yaw_max * torch.ones((speed_yaw_seq.shape[0], speed_yaw_seq.shape[1], 1))
        max_vals_tensor = torch.cat((max_speed_tensor, max_yaw_tensor), dim=-1).to(self.device)
        return speed_yaw_seq / max_vals_tensor
    
#     def mask_unreliable_coord_estimates():
        
        

    def forward(self, lanes_seq, points_to_lane_end_seq, speed_seq, yaw_seq, coord_seq, seq_lengths, 
                beam=False, return_prediction_on_input=False, use_coord_selfpredictions=False):
        # inputs are normalized in the dataset getter, except for coord_seq
        
        ############## lane seq ################
        # lane_points_seq shape: [batch size, max_lane_seq_len]
        lane_embedded = self.lane_embedding(lanes_seq)
        # lane_embedded shape: [batch size, max_lane_seq_len, emb dim]

        # adding centroid_shift_seq, speed_seq, yaw_seq
        speed_yaw_input_batch = self.normalize_speed_yaw(torch.cat((torch.unsqueeze(speed_seq, 2),
                                                                    torch.unsqueeze(yaw_seq, 2)), dim=2))
        lanes_input_batch = torch.cat((lane_embedded,
                                       torch.unsqueeze(points_to_lane_end_seq, 2)/self.max_agent_points_to_lane_end,
                                       speed_yaw_input_batch), dim=2)
        # packed sequence
        packed_embedded_lanes_input_seq = nn.utils.rnn.pack_padded_sequence(lanes_input_batch, 
                                                                            seq_lengths, batch_first=True,
                                                                            enforce_sorted=False)

        lstm_output_lanes_, (hidden_lanes, cell_lanes) = self.lstm_lanes(packed_embedded_lanes_input_seq)
        # hidden shape = [num layers, batch size, lanes hid dim]
        # output shape = [seq_len, batch size, lanes hid dim]
        
        
        lstm_output_lanes, _ = nn.utils.rnn.pad_packed_sequence(lstm_output_lanes_, 
                                                         batch_first=True, 
                                                         total_length=lanes_input_batch.shape[1])
        # predictions of next lane and dist to lane end (in sampled points count) are trained with teacher on input seq
        input_pred_next_lane = self.softmax_act(self.fc_next_lane_logits(lstm_output_lanes))
        # input_pred_next_lane shape = [batch size, vocab size (all lanes + unknown&pad)]
        input_pred_next_dist_to_lane_end = self.fc_next_dist_to_lane_end(lstm_output_lanes)
        
        ############## speed seq ################
        # packed sequence
        packed_embedded_speed_input_seq = nn.utils.rnn.pack_padded_sequence(speed_yaw_input_batch, 
                                                                            seq_lengths, batch_first=True,
                                                                            enforce_sorted=False)
        lstm_output_speed_, (hidden_speed, cell_speed) = self.lstm_speed(packed_embedded_speed_input_seq)
        # predictions on input seq (with teacher)
        lstm_output_speed, _ = nn.utils.rnn.pad_packed_sequence(lstm_output_speed_, 
                                                             batch_first=True, 
                                                             total_length=lanes_input_batch.shape[1])
        input_pred_next_speed_yaw = self.fc_next_speed_yaw(lstm_output_speed)
        
        ############## coord seq ################
        def combine_coord_estimates(coord_current, pred_next_speed_yaw, 
                                    pred_next_lane, pred_next_dist_to_lane_end,
                                    speed_yaw_current, small_angle_threshold=0.04
                                   ):
            # estimating delta_x, delta_y based on assumption of constant x/y acceleration
            # with teacher for input
            
            delta_x = (speed_yaw_current[:, :, 0]*torch.cos(speed_yaw_current[:, :, 1]) + 
                       pred_next_speed_yaw[:, :, 0]*torch.cos(pred_next_speed_yaw[:, :, 1]))/20
            delta_y = (speed_yaw_current[:, :, 0]*torch.sin(speed_yaw_current[:, :, 1]) + 
                       pred_next_speed_yaw[:, :, 0]*torch.sin(pred_next_speed_yaw[:, :, 1]))/20
            coord_estimation_const_accel = coord_current + torch.cat((torch.unsqueeze(delta_x, 2),
                                                                      torch.unsqueeze(delta_y, 2)), dim=2)
            

            point_x_coord_weighted = torch.matmul(pred_next_lane, self.lane_point_x_coord)
            point_y_coord_weighted = torch.matmul(pred_next_lane, self.lane_point_y_coord)
            
            pred_next_dist_to_lane_end_ceil = torch.clamp(torch.ceil(pred_next_dist_to_lane_end).long(), 0, self.max_agent_points_to_lane_end - 1)
            pred_next_dist_to_lane_end_floor = torch.clamp(torch.floor(pred_next_dist_to_lane_end).long(), 0, self.max_agent_points_to_lane_end - 1)
            pred_next_dist_to_lane_end_diff_to_ceil = pred_next_dist_to_lane_end - pred_next_dist_to_lane_end_ceil

            
            x_coord_weighted_ceil = torch.gather(point_x_coord_weighted, 
                                                 dim=2, 
                                                 index=pred_next_dist_to_lane_end_ceil)
            x_coord_weighted_floor = torch.gather(point_x_coord_weighted, 
                                                 dim=2, 
                                                 index=pred_next_dist_to_lane_end_floor)
            y_coord_weighted_ceil = torch.gather(point_y_coord_weighted, 
                                                 dim=2, 
                                                 index=pred_next_dist_to_lane_end_ceil)
            y_coord_weighted_floor = torch.gather(point_y_coord_weighted, 
                                                 dim=2, 
                                                 index=pred_next_dist_to_lane_end_floor)
            x_coord_weighted = (x_coord_weighted_ceil * (1 - pred_next_dist_to_lane_end_diff_to_ceil) + 
                                x_coord_weighted_floor * pred_next_dist_to_lane_end_diff_to_ceil)
            y_coord_weighted = (y_coord_weighted_ceil * (1 - pred_next_dist_to_lane_end_diff_to_ceil) + 
                                y_coord_weighted_floor * pred_next_dist_to_lane_end_diff_to_ceil)
            
            coord_estimation_lane_pred = torch.cat((x_coord_weighted, 
                                                    y_coord_weighted), dim=-1)
            
#             if coord_current.shape[1] == 1:
#                 print('coord_estimation_lane_pred', coord_estimation_lane_pred[0], 
#                       'coord_current', coord_current[0], 'delta_x', delta_x[0], 'delta_y', delta_y[0], 
#                       'speed x', speed_yaw_current[0, :, 0], 'yaw cos', torch.cos(speed_yaw_current[0, :, 1]))
            

            # estimating coord based on a single circle segment assumption
            delta_alpha = pred_next_speed_yaw[:, :, 1] - speed_yaw_current[:, :, 1]
            r = (speed_yaw_current[:, :, 0] + pred_next_speed_yaw[:, :, 0]) / (20*delta_alpha)
            delta_x_turn = r * torch.cos(delta_alpha)
            delta_y_turn = r * torch.sin(delta_alpha)
            coord_estimation_const_turn = coord_current + torch.cat((torch.unsqueeze(delta_x_turn, 2),
                                                                     torch.unsqueeze(delta_y_turn, 2)), dim=2)
            
            min_x = self.x_coord_min*torch.ones((coord_estimation_const_turn.shape[0], coord_estimation_const_turn.shape[1], 1))
            min_y = self.y_coord_min*torch.ones((coord_estimation_const_turn.shape[0], coord_estimation_const_turn.shape[1], 1))
            min_vals = torch.cat((min_x, min_y), dim=-1).to(self.device)
            no_turn_datapoints_bool = (delta_alpha < small_angle_threshold).unsqueeze(dim=-1).repeat(1, 1, 2)
            coord_estimation_const_turn = torch.where(no_turn_datapoints_bool, 
                                                      min_vals,
                                                      coord_estimation_const_turn
                                                     )
            const_turn_estimations_present = (delta_alpha >= small_angle_threshold).unsqueeze(dim=-1).float()
            coord_estimation_input_batch = torch.cat((self.normalize_coordinates(coord_estimation_const_accel),
                                                      self.normalize_coordinates(coord_estimation_lane_pred),
                                                      self.normalize_coordinates(coord_estimation_const_turn),
                                                      const_turn_estimations_present
                                                     ), dim=-1)
            
            coord_estimation_from_speed = (coord_estimation_const_accel + torch.where(no_turn_datapoints_bool, 
                                                                                      torch.zeros_like(coord_estimation_const_turn),
                                                                                      coord_estimation_const_turn
                                                     ))/(torch.ones_like(coord_estimation_const_accel) + const_turn_estimations_present)
#             print('coord_estimation_input_batch', coord_estimation_input_batch[0])
#             print('coord_estimation_const_accel', torch.sum(torch.isnan(coord_estimation_const_accel)), torch.sum(torch.isinf(coord_estimation_const_accel)), 
#                   'coord_estimation_const_turn', torch.sum(torch.isnan(coord_estimation_const_turn)), torch.sum(torch.isinf(coord_estimation_const_turn)), 
#                   'coord_estimation_input_batch', torch.sum(torch.isnan(coord_estimation_input_batch)), torch.sum(torch.isinf(coord_estimation_input_batch)))
            
            # storing differences between lane-based estimation and vector-speed approaches
#             print(coord_estimation_const_accel)
#             print('-'*5)
#             print(coord_estimation_lane_pred)
#             print('='*20)
#             if len(coord_current.shape) == 3:
#                 print(torch.sqrt(torch.square(coord_current - coord_estimation_const_accel)).mean())
#                 print(torch.sum(torch.logical_not(no_turn_datapoints_bool)))
#                 print(torch.sqrt(torch.square(coord_current[torch.logical_not(no_turn_datapoints_bool)] - 
#                                               coord_estimation_const_turn[torch.logical_not(no_turn_datapoints_bool)])).mean())
#             lane_based_coord_diff_to_const_accel = torch.sqrt(torch.square(coord_estimation_lane_pred - coord_estimation_const_accel))
#             lane_based_coord_diff_to_const_turn = torch.sqrt(torch.square(coord_estimation_lane_pred - coord_estimation_const_turn))
#             const_accel_coord_diff_to_const_turn = torch.sqrt(torch.square(coord_estimation_const_accel - coord_estimation_const_turn))
#             coord_estimation_discrepancies = torch.cat((lane_based_coord_diff_to_const_accel,
#                                                       lane_based_coord_diff_to_const_turn,
#                                                       const_accel_coord_diff_to_const_turn), dim=-1)/self.max_dist_diff_m
            
            return coord_estimation_input_batch, coord_estimation_lane_pred, coord_estimation_from_speed #torch.cat((coord_estimation_input_batch, coord_estimation_discrepancies), dim=-1)
        
        coord_estimation_input_batch, input_coord_estimation_lane_pred, input_coord_estimation_from_speed = combine_coord_estimates(coord_seq, input_pred_next_speed_yaw, 
                                                                               input_pred_next_lane, 
                                                                               input_pred_next_dist_to_lane_end, 
                                                                               speed_yaw_input_batch)
        
        packed_embedded_coord_estimation_input_seq = nn.utils.rnn.pack_padded_sequence(coord_estimation_input_batch, 
                                                                                       seq_lengths, batch_first=True,
                                                                                       enforce_sorted=False)
        lstm_output_coord_, (hidden_coord, cell_coord) = self.lstm_coord(packed_embedded_coord_estimation_input_seq)
        def predict_coord(coord_estimation_input):
            permute_for_bn = len(coord_estimation_input.shape) == 3
            if permute_for_bn:
                x = self.bn_1(coord_estimation_input.permute(0, 2, 1)).permute(0, 2, 1)
            else:
                x = self.bn_1(coord_estimation_input)
            x = self.fc_coord_hidden_act(self.fc_coord_hidden(x))
            if permute_for_bn:
                x = self.bn_2(x.permute(0, 2, 1)).permute(0, 2, 1)
            else:
                x = self.bn_2(x)
            return self.fc_coord(x)
        
        if return_prediction_on_input:
            # predictions with teacher on input seq
            lstm_output_coord, _ = nn.utils.rnn.pad_packed_sequence(lstm_output_coord_, 
                                                                 batch_first=True, 
                                                                 total_length=lanes_input_batch.shape[1])
#             input_pred_next_coord = self.fc_coord(torch.cat((lstm_output_coord, lstm_output_lanes), dim=-1))
#             print('final input', torch.cat((coord_estimation_input_batch, 
#                                                              lstm_output_lanes,
#                                                              lstm_output_speed
#                                                             ), dim=-1)[0])
#             input_pred_next_coord = self.fc_coord(self.fc_coord_hidden(torch.cat((coord_estimation_input_batch, 
#                                                              lstm_output_lanes,
#                                                              lstm_output_speed
#                                                             ), dim=-1)))

            
            input_pred_next_coord = predict_coord(lstm_output_coord)
                
#             input_pred_next_coord = predict_coord(torch.cat((lstm_output_coord, 
#                                                              lstm_output_lanes,
#                                                              lstm_output_speed
#                                                             ), dim=-1)) # coord_estimation_input_batch)
        
        ############## storing last vals for consequent pred ##############
        speed_yaw_current = torch.gather(speed_yaw_input_batch,
                                      dim=1,
                                      index=seq_lengths.unsqueeze(dim=-1).unsqueeze(dim=-1).repeat(1, 1, 2) - 1)
        coord_current = torch.gather(coord_seq,
                                  dim=1,
                                  index=seq_lengths.unsqueeze(dim=-1).unsqueeze(dim=-1).repeat(1, 1, 2) - 1)
        
        coord_estimation_from_speed_current = torch.gather(input_coord_estimation_from_speed,
                                  dim=1,
                                  index=seq_lengths.unsqueeze(dim=-1).unsqueeze(dim=-1).repeat(1, 1, 2) - 1)
#         print('coord_estimation_from_speed_current', coord_estimation_from_speed_current, 'coord_current', coord_current)
        
#         print('nans coord_current before loop', torch.sum(torch.isnan(coord_current)))
        
        coord_estimation_lane_pred = torch.gather(input_coord_estimation_lane_pred,
                                  dim=1,
                                  index=seq_lengths.unsqueeze(dim=-1).unsqueeze(dim=-1).repeat(1, 1, 2) - 1)
        coord_estimation_from_speed = torch.gather(input_coord_estimation_from_speed,
                                  dim=1,
                                  index=seq_lengths.unsqueeze(dim=-1).unsqueeze(dim=-1).repeat(1, 1, 2) - 1)
        coord_estimation_input = torch.gather(coord_estimation_input_batch,
                                  dim=1,
                                  index=seq_lengths.unsqueeze(dim=-1).unsqueeze(dim=-1).repeat(1, 1, 
                                                                                               coord_estimation_input_batch.shape[-1]) - 1)
        
        ############## future predictions ################
        batch_size = lanes_seq.size(0)
        pred_lane_prob = torch.zeros((batch_size, self.prediction_horizon_steps, self.vocab_size + 2)).to(
            device)
        pred_points_to_lane_end_seq = torch.zeros((batch_size, self.prediction_horizon_steps)).to(
            device)
        pred_speed = torch.zeros((batch_size, self.prediction_horizon_steps)).to(device)
        pred_yaw = torch.zeros((batch_size, self.prediction_horizon_steps)).to(device)
        pred_coordinates = torch.zeros((batch_size, self.prediction_horizon_steps, 2)).to(
            device)
        pred_coordinates_from_lanes = torch.zeros((batch_size, self.prediction_horizon_steps, 2)).to(
            device)
        pred_coordinates_from_speed = torch.zeros((batch_size, self.prediction_horizon_steps, 2)).to(
            device)

        for prediction_step in range(self.prediction_horizon_steps):            
            ############## lane seq ################
            hidden_lanes_last = hidden_lanes[-1, :, :]
            pred_next_lane = self.softmax_act(self.fc_next_lane_logits(hidden_lanes_last))
            pred_next_dist_to_lane_end = self.fc_next_dist_to_lane_end(hidden_lanes_last)
            
            ############## speed seq ################
            if self.bidirectional_speed_lstm:
                hidden_speed_last = torch.cat((hidden_speed[-2,:,:], hidden_speed[-1,:,:]), dim = 1) 
            else:
                hidden_speed_last = hidden_speed[-1,:,:]
            pred_next_speed_yaw = self.fc_next_speed_yaw(hidden_speed_last)
            
            ############## coord seq ################
            hidden_coord_last = hidden_coord[-1, :, :]
#             pred_next_coord = self.fc_coord(torch.cat((hidden_coord_last, hidden_lanes_last), dim=-1))
#             print(coord_estimation.shape, hidden_lanes_last.shape, hidden_speed_last.shape)
#             pred_next_coord = predict_coord(torch.cat((hidden_coord_last, 
#                                                        hidden_lanes_last,
#                                                        hidden_speed_last
#                                                             ), dim=-1)) #coord_estimation_input)
            pred_next_coord = predict_coord(hidden_coord_last)
            
            

#             pred_next_coord = self.fc_coord(self.fc_coord_hidden(torch.cat((coord_estimation, hidden_lanes_last, hidden_speed_last), dim=-1)))

            
            ###################################################
            ############## storing predictions ################
            pred_lane_prob[:, prediction_step, :] = pred_next_lane
            pred_points_to_lane_end_seq[:, prediction_step:prediction_step + 1] = pred_next_dist_to_lane_end
            pred_speed[:, prediction_step] = pred_next_speed_yaw[:, 0]
            pred_yaw[:, prediction_step] = pred_next_speed_yaw[:, 1]
            pred_coordinates[:, prediction_step, :] = torch.squeeze(pred_next_coord) 
            pred_coordinates_from_lanes[:, prediction_step, :] = torch.squeeze(coord_estimation_lane_pred)
            pred_coordinates_from_speed[:, prediction_step, :] = torch.squeeze(coord_estimation_from_speed)
            
            ##########################################
            ########## next lstm step ################
            
            # lane lstm
            pred_next_lane = torch.unsqueeze(pred_next_lane, dim=1)
            lanes_pred = torch.argmax(pred_next_lane, dim=-1)
            embedded_lanes_pred = self.lane_embedding(lanes_pred)
            pred_next_speed_yaw = torch.unsqueeze(pred_next_speed_yaw, dim=1)
            pred_next_speed_yaw_scaled = self.normalize_speed_yaw(pred_next_speed_yaw)
            pred_next_dist_to_lane_end = torch.unsqueeze(pred_next_dist_to_lane_end, dim=1)
            lanes_lstm_input = torch.cat((embedded_lanes_pred,
                                          pred_next_dist_to_lane_end/self.max_agent_points_to_lane_end,
                                          pred_next_speed_yaw_scaled), dim=2)
            _, (hidden_lanes, cell_lanes) = self.lstm_lanes(lanes_lstm_input, 
                                                            (hidden_lanes, cell_lanes))
            
            # speed lstm
            _, (hidden_speed, cell_speed) = self.lstm_speed(pred_next_speed_yaw_scaled, 
                                                            (hidden_speed, cell_speed))
            
            # coord lstm
            (coord_estimation_input, 
             coord_estimation_lane_pred, 
             coord_estimation_from_speed) = combine_coord_estimates(coord_current if use_coord_selfpredictions else coord_estimation_from_speed_current, 
                                                                    pred_next_speed_yaw,
                                                       pred_next_lane, pred_next_dist_to_lane_end,
                                                       speed_yaw_current)
            _, (hidden_coord, cell_coord) = self.lstm_coord(coord_estimation_input, 
                                                            (hidden_coord, cell_coord))
            
            # storing last vals for the next lstm step
            coord_current = torch.unsqueeze(pred_next_coord, dim=1)
            coord_estimation_from_speed_current = coord_estimation_from_speed
            speed_yaw_current = pred_next_speed_yaw

        if return_prediction_on_input:
            return (input_pred_next_lane, input_pred_next_dist_to_lane_end, input_pred_next_speed_yaw, 
                    input_pred_next_coord, input_coord_estimation_lane_pred, input_coord_estimation_from_speed,
                    pred_lane_prob, pred_points_to_lane_end_seq, pred_speed, pred_yaw, 
                    pred_coordinates, pred_coordinates_from_lanes, pred_coordinates_from_speed)
        return pred_lane_prob, pred_points_to_lane_end_seq, pred_speed, pred_yaw, pred_coordinates, pred_coordinates_from_lanes, pred_coordinates_from_speed

In [14]:
lane_seq_model = LaneSeqModel(map_segment_2_train_vocab[MAP_SEGMENT_I])

In [15]:
print(lane_seq_model)

LaneSeqModel(
  (lane_embedding): Embedding(259, 16)
  (lstm_lanes): LSTM(19, 64, batch_first=True, dropout=0.2)
  (fc_next_lane_logits): Linear(in_features=64, out_features=259, bias=True)
  (fc_next_dist_to_lane_end): Linear(in_features=64, out_features=1, bias=True)
  (softmax_act): Softmax(dim=-1)
  (lstm_speed): LSTM(2, 64, batch_first=True, dropout=0.2)
  (fc_next_speed_yaw): Linear(in_features=64, out_features=2, bias=True)
  (lstm_coord): LSTM(7, 64, batch_first=True, dropout=0.2)
  (bn_1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc_coord_hidden): Linear(in_features=64, out_features=32, bias=True)
  (fc_coord_hidden_act): ReLU()
  (bn_2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc_coord): Linear(in_features=32, out_features=2, bias=True)
)


In [16]:
class LaneSeqDataset(Dataset):
    def __init__(self, 
                 agent_lane_df: pd.DataFrame,
                 valid_indices: np.array,
                 train_vocab_size: int,
                 history_len: int = HIST_LEN_FRAMES,
                 future_len: int = FUTURE_LEN_FRAMES
                ):
        self.history_len = history_len
        self.future_len = future_len
        self.valid_indices = valid_indices
        self.PAD_TOKEN_IDX = train_vocab_size + 1
        self.valid_hist_len = agent_lane_df['valid_hist_len'].values
        self.valid_future_len = agent_lane_df['valid_future_len'].values
        self.lane_vocab_idx = agent_lane_df['lane_vocab_idx'].values
        self.agent_speed = agent_lane_df['agent_speed'].values
        self.agent_yaw = agent_lane_df['agent_yaw'].values
        self.agent_coord = agent_lane_df['agent_coord'].values
        self.points_to_lane_end = agent_lane_df['points_to_lane_end'].values
        
    def __len__(self):
        return len(self.valid_indices)
    
    def __getitem__(self, index: int):
        row_i = self.valid_indices[index]                        
        valid_hist_len = self.valid_hist_len[row_i] 
        
        lanes_seq_ = self.lane_vocab_idx[row_i - valid_hist_len + 1:row_i + 1]
        speed_seq_ = self.agent_speed[row_i - valid_hist_len + 1:row_i + 1]
        yaw_seq_ = self.agent_yaw[row_i - valid_hist_len + 1:row_i + 1]
        points_to_lane_end_ = self.points_to_lane_end[row_i - valid_hist_len + 1:row_i + 1]
        agent_coord_ = self.agent_coord[row_i - valid_hist_len + 1:row_i + 1]
                
        # padding
        padding_len = self.history_len - valid_hist_len
        lanes_seq = np.concatenate((lanes_seq_, self.PAD_TOKEN_IDX*np.ones(padding_len))).astype(np.int) # shouldn't get to PAD_TOKEN_IDX, just in case
        speed_seq = np.concatenate((speed_seq_, np.zeros(padding_len))).astype(np.float32)
        yaw_seq = np.concatenate((yaw_seq_, np.zeros(padding_len))).astype(np.float32)
        points_to_lane_end = np.concatenate((points_to_lane_end_, np.zeros(padding_len))).astype(np.float32)
        agent_coord = np.concatenate((np.vstack(agent_coord_), np.zeros((padding_len, 2)))).astype(np.float32)
        hist_availabilities = np.concatenate((np.ones(valid_hist_len), np.zeros(padding_len))).astype(np.float32)

        
        # targets  
        gt_coord_ = self.agent_coord[row_i + 1:row_i + self.valid_future_len[row_i] + 1]
        valid_future_len = len(gt_coord_)
        future_padding_len = self.future_len - valid_future_len
        gt_coord = np.concatenate((np.vstack(gt_coord_), np.zeros((future_padding_len, 2)))).astype(np.float32)
        gt_coord_availabilities = np.concatenate((np.ones(valid_future_len), np.zeros(future_padding_len))).astype(np.float32)
        gt_lanes_seq_ = self.lane_vocab_idx[row_i + 1:row_i + valid_future_len + 1]
        gt_lanes_seq = np.concatenate((gt_lanes_seq_, self.PAD_TOKEN_IDX*np.ones(future_padding_len))).astype(np.int)
        gt_speed_ = self.agent_speed[row_i + 1:row_i + valid_future_len + 1]
        gt_speed = np.concatenate((gt_speed_, np.zeros(future_padding_len))).astype(np.float32)
        gt_yaw_ = self.agent_yaw[row_i + 1:row_i + valid_future_len + 1]
        gt_yaw = np.concatenate((gt_yaw_, np.zeros(future_padding_len))).astype(np.float32)  
        gt_points_to_lane_end_ = self.points_to_lane_end[row_i + 1:row_i + self.valid_future_len[row_i] + 1]
        gt_points_to_lane_end = np.concatenate((gt_points_to_lane_end_, np.zeros(future_padding_len))).astype(np.float32)  
        
        return (lanes_seq, points_to_lane_end, speed_seq, yaw_seq, agent_coord, valid_hist_len, hist_availabilities, 
                gt_lanes_seq, gt_points_to_lane_end, valid_future_len, gt_coord, gt_speed, gt_yaw, gt_coord_availabilities)

In [17]:
def get_valid_indices(agent_lane_df, min_hist_len=10): # min_hist_len in addition to the current timestamp

    has_enough_hist = agent_lane_df['the_same_agent_prev'].rolling(min_hist_len).sum() == min_hist_len
    
    valid_rows_bool = (agent_lane_df['the_same_agent_next'] & has_enough_hist)
    return np.arange(len(agent_lane_df))[valid_rows_bool]


def get_dataloader(agent_lane_df, map_segment_idx, shuffle=True, batch_size=1024, num_workers=7):
    agent_lane_df_map_segment = agent_lane_df[agent_lane_df['map_segment_group'] == map_segment_idx]
    valid_indices_map_segment = get_valid_indices(agent_lane_df_map_segment)
    required_columns = ['valid_hist_len', 'valid_future_len', 'lane_vocab_idx', 
                        'agent_speed', 'agent_yaw', 'agent_coord', 'points_to_lane_end']
    dataset = LaneSeqDataset(agent_lane_df_map_segment[required_columns],
                             valid_indices_map_segment,
                             len(map_segment_2_train_vocab[map_segment_idx]))
    dataloader = DataLoader(dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers)
    return dataloader
          
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [18]:
# map_segment_idx = 4
# agent_lane_df = agent_lane_df_val
# agent_lane_df_map_segment = agent_lane_df[agent_lane_df['map_segment_group'] == map_segment_idx]
# valid_indices_map_segment = get_valid_indices(agent_lane_df_map_segment)
# required_columns = ['valid_hist_len', 'valid_future_len', 'lane_vocab_idx', 
#                     'agent_speed', 'agent_yaw', 'agent_coord', 'points_to_lane_end']
# dataset = LaneSeqDataset(agent_lane_df_map_segment[required_columns],
#                          valid_indices_map_segment,
#                          len(map_segment_2_train_vocab[map_segment_idx]))

In [19]:
#https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/masked_cross_entropy.py
import torch
from torch.nn import functional
from torch.autograd import Variable

def sequence_mask(sequence_length, max_len=None):
    if max_len is None:
        max_len = sequence_length.data.max()
    batch_size = sequence_length.size(0)
    seq_range = torch.range(0, max_len - 1).long()
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_range_expand = Variable(seq_range_expand)
    if sequence_length.is_cuda:
        seq_range_expand = seq_range_expand.cuda()
    seq_length_expand = (sequence_length.unsqueeze(1)
                         .expand_as(seq_range_expand))
    return seq_range_expand < seq_length_expand


def masked_cross_entropy(probs, target, length):
#     length = Variable(torch.LongTensor(length)).cuda()

    """
    Args:
        logits: A Variable containing a FloatTensor of size
            (batch, max_len, num_classes) which contains the
            unnormalized probability for each class.
        target: A Variable containing a LongTensor of size
            (batch, max_len) which contains the index of the true
            class for each corresponding step.
        length: A Variable containing a LongTensor of size (batch,)
            which contains the length of each data in a batch.
    Returns:
        loss: An average loss value masked by the length.
    """

    # logits_flat: (batch * max_len, num_classes)
#     logits_flat = logits.view(-1, logits.size(-1))
    # log_probs_flat: (batch * max_len, num_classes)
    log_probs_flat = torch.log(probs).view(-1, probs.size(-1))#functional.log_softmax(logits_flat)
    # target_flat: (batch * max_len, 1)
    target_flat = target.view(-1, 1)
    # losses_flat: (batch * max_len, 1)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    # losses: (batch, max_len)
    losses = losses_flat.view(*target.size())
    # mask: (batch, max_len)
    mask = sequence_mask(sequence_length=length, max_len=target.size(1))
    losses = losses * mask.float()
    loss = losses.sum() / length.float().sum()
    return loss

In [22]:
map_segment_idx = MAP_SEGMENT_I
dataloader_trn = get_dataloader(agent_lane_df_trn, map_segment_idx=map_segment_idx, 
                                batch_size=700, num_workers=16)
dataloader_val = get_dataloader(agent_lane_df_val, map_segment_idx=map_segment_idx, shuffle=False, batch_size=256)
lane_seq_model = LaneSeqModel(map_segment_2_train_vocab[map_segment_idx], 
                              embedding_dim=128,
                              hidden_dim_lane_lstm=64,
                             hidden_dim_speed_lstm=64,
                             hidden_dim_coord_lstm=64,
                              bidirectional_speed_lstm=False,
                             n_layers_lane_lstm=2,
                             n_layers_speed_lstm=2,
                             n_layers_coord_lstm=1,
                             map_segment=4,
                             dropout=0,
                              device=device).to(device)

#https://discuss.pytorch.org/t/how-the-pytorch-freeze-network-in-some-layers-only-the-rest-of-the-training/7088/15?u=raman_samusevich
def dfs_freeze(model, names_to_exclude=set()):
    for name, child in model.named_children():
        if name not in names_to_exclude:
            for param in child.parameters():
                param.requires_grad = False
            dfs_freeze(child)
            
def dfs_unfreeze(model, names_to_exclude=set()):
    for name, child in model.named_children():
        if name not in names_to_exclude:
            for param in child.parameters():
                param.requires_grad = True
            dfs_unfreeze(child)

In [None]:
# def train(lane_seq_model, dataloader_trn, dataloader_val, epochs_max=5):
epochs_max=5
progress_bar = tqdm(dataloader_trn)
iterator = iter(progress_bar)
epoch = 0
early_stopping = EarlyStopping(patience=7, verbose=True, path=f'map_segment_{map_segment_idx}_combined_loss_checkpoint.pt')

mse=nn.MSELoss(reduction="none")
clip_value=2

losses_train = []
losses_ce = []
losses_neg_loglik = []
losses_mse_points_to_lane_end = []
losses_mse_speed = []
losses_mse_yaw = []

############# helping funcs ###################
###############################################
def process_next_train_iters(losses_weights, optimizer, iters_count=None, epochs_max=None, use_coord_selfpredictions=False):
    global progress_bar, iterator, losses_train, losses_ce, losses_neg_loglik, losses_mse_points_to_lane_end, losses_mse_speed, losses_mse_yaw, epoch
    assert (iters_count is None) != (epochs_max is None)
    if epochs_max is not None:
        iter_incr = 0
        iters_count = 2
    else:
        iter_incr = 1
    i = 0
    while i < iters_count:
        i += iter_incr
        batch = next(iterator, None)
        if batch is None:
            print('Epoch end')
            epoch += 1
            print(f"Ep. {epoch}, Avg train loss: {np.mean(losses_train):.5f} (ce: {np.mean(losses_ce):.5f}, neg_loglik: {np.mean(losses_neg_loglik):.1f})\nloss_mse_points_to_lane_end: {np.mean(losses_mse_points_to_lane_end):.0f}, mse_speed: {np.mean(losses_mse_speed):.2f}, mse_yaw: {np.mean(losses_mse_yaw):.2f}")
            eval_on_val()
            if early_stopping.early_stop:
                break            
            if epochs_max is not None and epoch >= epochs_max:
                break
            progress_bar = tqdm(dataloader_trn)
            iterator = iter(progress_bar)
            batch = next(iterator, None)
            losses_train = []
            losses_ce = []
            losses_neg_loglik = []
            losses_mse_points_to_lane_end = []
            losses_mse_speed = []
            losses_mse_yaw = []

        # moving to GPU if available
        (lanes_seq, points_to_lane_end, speed_seq, yaw_seq, agent_coord, valid_hist_len, hist_availabilities,
         gt_lanes_seq, gt_points_to_lane_end, gt_valid_future_len, gt_coord, gt_speed, 
         gt_yaw, gt_coord_availabilities) = [x.to(device) for x in batch]
#         print('input nans', torch.sum(torch.isnan(agent_coord)))
        lane_seq_model.train()
        torch.set_grad_enabled(True)
        (input_pred_next_lane_prob, input_pred_next_dist_to_lane_end, 
         input_pred_next_speed_yaw, 
         input_pred_next_coord, input_coord_estimation_lane_pred, input_pred_coordinates_from_speed,
         pred_lane_prob, pred_points_to_lane_end_seq, 
         pred_speed, pred_yaw, 
         pred_coordinates, pred_coordinates_from_lanes, pred_coordinates_from_speed) = lane_seq_model(lanes_seq, points_to_lane_end, speed_seq,
                                                                  yaw_seq, agent_coord, valid_hist_len, return_prediction_on_input=True,
                                                                                                      use_coord_selfpredictions=use_coord_selfpredictions)

        loss_ce_with_teacher = masked_cross_entropy(input_pred_next_lane_prob, lanes_seq, valid_hist_len)
        loss_ce = masked_cross_entropy(pred_lane_prob, gt_lanes_seq, gt_valid_future_len)

        loss_neg_loglik_with_teacher = (0.5*mse(input_pred_next_coord, agent_coord).sum(dim=-1)*hist_availabilities).sum()/hist_availabilities.sum()
        loss_neg_loglik = (0.5*mse(pred_coordinates, gt_coord).sum(dim=-1)*gt_coord_availabilities).sum()/gt_coord_availabilities.sum()

        loss_neg_loglik_with_teacher_based_on_lanes = (0.5*mse(input_coord_estimation_lane_pred, agent_coord).sum(dim=-1)*hist_availabilities).sum()/hist_availabilities.sum()
        loss_neg_loglik_based_on_lanes = (0.5*mse(pred_coordinates_from_lanes, gt_coord).sum(dim=-1)*gt_coord_availabilities).sum()/gt_coord_availabilities.sum()

        loss_neg_loglik_with_teacher_from_speed = (0.5*mse(input_pred_coordinates_from_speed, agent_coord).sum(dim=-1)*hist_availabilities).sum()/hist_availabilities.sum()
        loss_neg_loglik_from_speed = (0.5*mse(pred_coordinates_from_speed, gt_coord).sum(dim=-1)*gt_coord_availabilities).sum()/gt_coord_availabilities.sum()

        loss_mse_speed_with_teacher = (mse(input_pred_next_speed_yaw[:, :, 0], speed_seq)*hist_availabilities).sum()/hist_availabilities.sum()
        loss_mse_speed = (mse(pred_speed, gt_speed)*gt_coord_availabilities).sum()/gt_coord_availabilities.sum()

        loss_mse_yaw_with_teacher = (mse(input_pred_next_speed_yaw[:, :, 1], yaw_seq)*hist_availabilities).sum()/hist_availabilities.sum()
        loss_mse_yaw = (mse(pred_yaw, gt_yaw)*gt_coord_availabilities).sum()/gt_coord_availabilities.sum()

        loss_mse_points_to_lane_end_teacher = (mse(input_pred_next_dist_to_lane_end.squeeze(), points_to_lane_end)*hist_availabilities).sum()/hist_availabilities.sum()
        loss_mse_points_to_lane_end = (mse(pred_points_to_lane_end_seq, gt_points_to_lane_end)*gt_coord_availabilities).sum()/gt_coord_availabilities.sum()

#         loss = (loss_ce_with_teacher + loss_ce + (loss_mse_speed_with_teacher + 
#                 loss_mse_speed)/5 + loss_mse_yaw_with_teacher +  loss_mse_yaw + 
#                 (loss_mse_points_to_lane_end_teacher +  loss_mse_points_to_lane_end)/20)
#             + 
#                     loss_neg_loglik_with_teacher_based_on_lanes, 
#                            loss_neg_loglik_based_on_lanes, 
#                            loss_neg_loglik_with_teacher_from_speed
        losses_list = [loss_ce_with_teacher, 
                       loss_ce,
                       loss_mse_speed_with_teacher, 
                       loss_mse_speed,
                       loss_mse_yaw_with_teacher, 
                       loss_mse_yaw,
                       loss_mse_points_to_lane_end_teacher, 
                       loss_mse_points_to_lane_end, 
                       loss_neg_loglik_with_teacher_based_on_lanes, 
                       loss_neg_loglik_based_on_lanes, 
                       loss_neg_loglik_with_teacher_from_speed,
                       loss_neg_loglik_from_speed, 
                       loss_neg_loglik_with_teacher,
                       loss_neg_loglik]
#             print('losses_list', losses_list)
        loss = sum([loss_ * weight_ 
                    for loss_, weight_ in zip(losses_list, losses_weights) 
                    if weight_ > 0
                   ])
    
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
#             print('clip_value', clip_value)
        nn.utils.clip_grad_norm_(lane_seq_model.parameters(), clip_value)
        optimizer.step()

        losses_train.append(loss.item())
        losses_ce.append(loss_ce.item())
        losses_neg_loglik.append(loss_neg_loglik.item())            
        losses_mse_points_to_lane_end.append(loss_mse_points_to_lane_end.item())
        losses_mse_speed.append(loss_mse_speed.item())
        losses_mse_yaw.append(loss_mse_yaw.item())
        progress_bar.set_description(f"E{epoch},{loss.item():.2f} ({loss_ce.item():.2f},{loss_neg_loglik.item():.1f}/{loss_neg_loglik_based_on_lanes.item():.0f}|{loss_neg_loglik_from_speed.item(): .0f}|{loss_neg_loglik_with_teacher_from_speed.item():.0f}){loss_mse_points_to_lane_end.item():.0f},{loss_mse_speed.item():.2f},{loss_mse_yaw.item():.2f}")


def eval_on_val():
    lane_seq_model.eval()

    losses_ce = []
    losses_neg_loglik = []
    losses_neg_loglik_based_on_lanes = []
    losses_neg_loglik_from_speed = []
    losses_mse_points_to_lane_end = []
    losses_mse_speed = []
    losses_mse_yaw = []
    for batch in tqdm(dataloader_val):

        # moving to GPU if available
        (lanes_seq, points_to_lane_end, speed_seq, yaw_seq, agent_coord, valid_hist_len, hist_availabilities,
         gt_lanes_seq, gt_points_to_lane_end, gt_valid_future_len, gt_coord, gt_speed, 
         gt_yaw, gt_coord_availabilities) = [x.to(device) for x in batch]

        (pred_lane_prob, pred_points_to_lane_end_seq, 
         pred_speed, pred_yaw, 
         pred_coordinates, pred_coordinates_from_lanes, pred_coordinates_from_speed) = lane_seq_model(lanes_seq, points_to_lane_end, speed_seq,
                                                                  yaw_seq, agent_coord, valid_hist_len, return_prediction_on_input=False)

        loss_ce = masked_cross_entropy(pred_lane_prob, gt_lanes_seq, gt_valid_future_len)
        loss_neg_loglik = (0.5*mse(pred_coordinates, gt_coord).sum(dim=-1)*gt_coord_availabilities).sum()/gt_coord_availabilities.sum()
        loss_neg_loglik_based_on_lanes = (0.5*mse(pred_coordinates_from_lanes, gt_coord).sum(dim=-1)*gt_coord_availabilities).sum()/gt_coord_availabilities.sum()
        loss_neg_loglik_from_speed = (0.5*mse(pred_coordinates_from_speed, gt_coord).sum(dim=-1)*gt_coord_availabilities).sum()/gt_coord_availabilities.sum()
        loss_mse_speed = (mse(pred_speed, gt_speed)*gt_coord_availabilities).sum()/gt_coord_availabilities.sum()
        loss_mse_yaw = (mse(pred_yaw, gt_yaw)*gt_coord_availabilities).sum()/gt_coord_availabilities.sum()
        loss_mse_points_to_lane_end = (mse(pred_points_to_lane_end_seq, gt_points_to_lane_end)*gt_coord_availabilities).sum()/gt_coord_availabilities.sum()

        losses_ce.append(loss_ce.item())
        losses_neg_loglik.append(loss_neg_loglik.item()) 
        losses_neg_loglik_based_on_lanes.append(loss_neg_loglik_based_on_lanes.item()) 
        losses_neg_loglik_from_speed.append(loss_neg_loglik_from_speed.item()) 
        losses_mse_points_to_lane_end.append(loss_mse_points_to_lane_end.item())
        losses_mse_speed.append(loss_mse_speed.item())
        losses_mse_yaw.append(loss_mse_yaw.item())
        progress_bar.set_description(f"({loss_ce.item():.2f},{loss_neg_loglik.item():.1f}/{loss_neg_loglik_based_on_lanes.item():.0f}|{loss_neg_loglik_from_speed.item(): .0f}){loss_mse_points_to_lane_end.item():.0f},{loss_mse_speed.item():.2f},{loss_mse_yaw.item():.2f}")
    print(f"Ep. {epoch}, Avg val losses. CE: {np.mean(losses_ce):.5f}, neg_loglik: {np.mean(losses_neg_loglik):.1f}), neg_loglik lanes: {np.mean(losses_neg_loglik_based_on_lanes):.1f}), neg_loglik speed: {np.mean(losses_neg_loglik_from_speed):.1f})\nloss_mse_points_to_lane_end: {np.mean(losses_mse_points_to_lane_end):.0f}, mse_speed: {np.mean(losses_mse_speed):.2f}, mse_yaw: {np.mean(losses_mse_yaw):.2f}")
    lr_scheduler.step(np.mean(losses_neg_loglik))
    early_stopping(np.mean(losses_neg_loglik), lane_seq_model) 

#############################################
############# learning schedule #############
blitzinit_loss_weights = torch.ones(14)
blitzinit_loss_weights[-6:] = 0
# speed
blitzinit_loss_weights[2:4] /= 5
# lane to end
blitzinit_loss_weights[6:8] /= 20
blitzinit_batches_count = 150
blitzinit_lr = 3e-2

init_no_coord_loss_weights = torch.ones(14)
init_no_coord_loss_weights[-6:] = 0
init_no_coord_batches_count = 100
init_no_coord_lr = 1e-2

aggressive_coord_init_loss_weights = torch.zeros(14)
aggressive_coord_init_loss_weights[-2:] = 1
aggressive_coord_init_batches_count = 150
aggressive_coord_init_lr = 3e-2

final_loss_weights = torch.ones(14)/30
final_loss_weights[:2] = 5
final_loss_weights[-6:-4] /= 100
final_loss_weights[-2:] = 1
lr_after_init = 3e-4


# blitzinit
optimizer = optim.Adam(lane_seq_model.parameters(), lr=blitzinit_lr)
process_next_train_iters(blitzinit_loss_weights, optimizer, iters_count=blitzinit_batches_count)
print('Blitzinit finished')

# further init without coord optimization
optimizer = optim.Adam(lane_seq_model.parameters(), lr=init_no_coord_lr)
process_next_train_iters(init_no_coord_loss_weights, optimizer, iters_count=init_no_coord_batches_count)
print('Init without coord-opt finished')

# aggressive coord-layers init
dfs_freeze(lane_seq_model, {'fc_coord', 'bn_1', 'fc_coord_hidden', 'bn_2', 'lstm_coord'})   
optimizer = optim.Adam(filter(lambda p: p.requires_grad, lane_seq_model.parameters()), lr=aggressive_coord_init_lr)
process_next_train_iters(aggressive_coord_init_loss_weights, optimizer, iters_count=aggressive_coord_init_batches_count)
print('Aggressive coord-related init finished')

dfs_unfreeze(lane_seq_model)

def clip_and_replace_explosures(grad):
    grad[torch.logical_or(torch.isnan(grad), torch.isinf(grad))] = torch.tensor(0.0).to(device)
#     grad = torch.clamp(grad, -0.25, 0.25)
    return grad

for param in lane_seq_model.parameters():
    if param.requires_grad:
        param.register_hook(clip_and_replace_explosures)
                
optimizer = optim.Adam(lane_seq_model.parameters(), lr=lr_after_init)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=3)
eval_on_val()
process_next_train_iters(final_loss_weights, optimizer, epochs_max=epochs_max, use_coord_selfpredictions=True)
    
# train(lane_seq_model, dataloader_trn, dataloader_val)

HBox(children=(FloatProgress(value=0.0, max=1506.0), HTML(value='')))

Blitzinit finished
Init without coord-opt finished
Aggressive coord-related init finished


HBox(children=(FloatProgress(value=0.0, max=14.0), HTML(value='')))


Ep. 0, Avg val losses. CE: 0.59440, neg_loglik: 30.4), neg_loglik lanes: 83.9), neg_loglik speed: 35.3)
loss_mse_points_to_lane_end: 10, mse_speed: 1.60, mse_yaw: 0.61
Validation loss decreased (inf --> 30.377873).  Saving model ...

















Epoch end
Ep. 1, Avg train loss: 816.74895 (ce: 0.90981, neg_loglik: 3667.6)
loss_mse_points_to_lane_end: 31, mse_speed: 2.57, mse_yaw: 0.53


HBox(children=(FloatProgress(value=0.0, max=14.0), HTML(value='')))


Ep. 1, Avg val losses. CE: 0.78720, neg_loglik: 20.1), neg_loglik lanes: 73.5), neg_loglik speed: 35.3)
loss_mse_points_to_lane_end: 32, mse_speed: 1.60, mse_yaw: 0.61
Validation loss decreased (30.377873 --> 20.123980).  Saving model ...


HBox(children=(FloatProgress(value=0.0, max=1506.0), HTML(value='')))