# base text classification model: https://www.analyticsvidhya.com/blog/2020/01/first-text-classification-in-pytorch/
# tte model:
https://github.com/ragulpr/wtte-rnn/blob/master/examples/keras/standalone_simple_example.ipynb
https://stackoverflow.com/questions/50196212/what-is-the-state-of-the-art-way-of-doing-regression-with-probability-in-pytorch

In [1]:
import os
import pandas as pd
import numpy as np
from collections import defaultdict
from tqdm.auto import tqdm
import sys
sys.path.insert(0, '../scripts')
from map_traffic_lights_data import master_intersection_idx_2_tl_signal_indices
# early stopping source: https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py
from pytorchtools import EarlyStopping
from datetime import timedelta
from typing import Dict, List
import torch 
from torch.autograd import Variable
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch import nn, optim
from torch.distributions import Weibull
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True

HBox(children=(FloatProgress(value=0.0, description='Computing lane adjacency lists', max=528.0, style=Progres…




HBox(children=(FloatProgress(value=0.0, description='Computing lane adjacency lists', max=7977.0, style=Progre…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Lane blocked sets..', layout=Layout(wid…




In [2]:
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [3]:
TRAIN_INPUT_PATHS = ['../input/tl_events_df_val.hdf5']
VAL_INPUT_PATH = '../input/tl_events_df_sample.hdf5'
HIST_LEN_FRAMES = 100

joined_train_name = ''

In [4]:
tl_events_df_trn = pd.concat([pd.read_hdf(path, key='data') for path in TRAIN_INPUT_PATHS])
tl_events_df_val = pd.read_hdf(VAL_INPUT_PATH, key='data')

if 'continuous_time' not in tl_events_df_trn.columns:
    tl_events_df_trn['continuous_time'] = ((tl_events_df_trn['timestamp'].diff(1) < timedelta(seconds=0.31)) &  # observed 0.3 sec jumps (assuming it's between consec. scenes)
                                           (tl_events_df_trn['master_intersection_idx'].shift(1) == tl_events_df_trn['master_intersection_idx']))
if 'continuous_time' not in tl_events_df_val.columns:
    tl_events_df_val['continuous_time'] = ((tl_events_df_val['timestamp'].diff(1) < timedelta(seconds=0.31)) &
                                       (tl_events_df_val['master_intersection_idx'].shift(1) == tl_events_df_val['master_intersection_idx']))


def compute_last_valid_idx_for_seq(tl_events_df):
    tl_events_df['valid_hist_len'] = -1
    for row_i in tqdm(range(len(tl_events_df)), desc='Last valid...'):
        last_valid_idx = row_i
        while (row_i - last_valid_idx + 1 < HIST_LEN_FRAMES and 
               tl_events_df['continuous_time'].iloc[last_valid_idx]):
            last_valid_idx -= 1
        tl_events_df['valid_hist_len'].iloc[row_i] = row_i - last_valid_idx
        
if 'valid_hist_len' not in tl_events_df_trn.columns:
    compute_last_valid_idx_for_seq(tl_events_df_trn)
    if len(TRAIN_INPUT_PATHS) == 1:
        tl_events_df_trn.to_hdf(TRAIN_INPUT_PATHS[0], key='data')
    elif joined_train_name != '':
        tl_events_df_trn.to_hdf(os.path.join('../input', joined_train_name), key='data')
    else:
        print('Warn! Not storing the precomputed results!')
if 'valid_hist_len' not in tl_events_df_val.columns:
    compute_last_valid_idx_for_seq(tl_events_df_val)
    


HBox(children=(FloatProgress(value=0.0, description='Last valid...', max=1494704.0, style=ProgressStyle(descri…


Warn! Not storing the precomputed results!


HBox(children=(FloatProgress(value=0.0, description='Last valid...', max=8897.0, style=ProgressStyle(descripti…




In [5]:
intersection_2_train_vocab = defaultdict(dict)
intersection_2_term_freq = defaultdict(dict)

for intersection_idx in tqdm(tl_events_df_trn['master_intersection_idx'].unique()):
    intersection_related_inputs = tl_events_df_trn.loc[tl_events_df_trn['master_intersection_idx'] == intersection_idx,
                                                       'rnn_inputs_raw'].values
    for rnn_input_raw in intersection_related_inputs:
        for token, _ in rnn_input_raw:
            if token not in intersection_2_train_vocab[intersection_idx]:
                intersection_2_train_vocab[intersection_idx][token] = len(intersection_2_train_vocab[intersection_idx])
                intersection_2_term_freq[intersection_idx][token] = 1
            else:
                intersection_2_term_freq[intersection_idx][token] += 1

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




In [6]:
tl_events_df_trn.head()

Unnamed: 0,master_intersection_idx,timestamp,rnn_inputs_raw,tl_signal_classes,time_to_tl_change,continuous_time,last_valid_idx,valid_hist_len
542589,0,2019-11-21 11:28:36.903368-08:00,"[(NTTe, (0, 1, 0, 0))]","{8: 0, 9: 0}","{8: 5.01, 9: 5.01}",False,0,0
542590,0,2019-11-21 11:28:37.003465-08:00,"[(NTTe, (0, 1, 0, 0))]","{8: 0, 9: 0}","{8: 5.01, 9: 5.01}",True,0,1
542591,0,2019-11-21 11:28:37.103479-08:00,"[(NTTe, (0, 1, 0, 0))]","{8: 0, 9: 0}","{8: 5.01, 9: 5.01}",True,0,2
542592,0,2019-11-21 11:28:37.203406-08:00,"[(NTTe, (0, 1, 0, 0)), (/ggb, (0, 1, 0, 0))]","{8: 0, 9: 0}","{8: 5.01, 9: 5.01}",True,0,3
542593,0,2019-11-21 11:28:37.303268-08:00,"[(NTTe, (0, 1, 0, 0))]","{8: 0, 9: 0}","{8: 5.01, 9: 5.01}",True,0,4


In [7]:
tl_events_df_trn['tl_signal_classes'].map(len).value_counts()

7    282925
5    259391
1    241756
6    174061
2    145207
8    139277
4    123622
3     80357
0     24539
9     23569
Name: tl_signal_classes, dtype: int64

In [8]:
for intersection, vocab in intersection_2_train_vocab.items():
    print(f'intersection {intersection}, vocab size: {len(vocab)}')
    print(f'20 the least frequent terms: {sorted(intersection_2_term_freq[intersection].items(), key=lambda x: x[1])[:20]}')

intersection 0, vocab size: 61
20 the least frequent terms: [('H+dt_Qvqr_lane1', 1), ('H+dt_Qvqr_dPy4_lane1', 5), ('bAUT', 9), ('Pdk+_lane5', 10), ('bAUT_fxmM', 10), ('udk+_yPdC', 11), ('PrF8_lane37', 14), ('Pdk+_TPdC', 23), ('lane37_urF8', 42), ('H+dt_dPy4', 76), ('H+dt', 83), ('H+dt_lane1', 84), ('mFtj', 84), ('lane5', 101), ('0OdC_wck+', 124), ('Pdk+_lane5_udk+', 189), ('Qcp7_wqF8', 239), ('nwfo', 270), ('mVvp', 369), ('urF8', 515)]
intersection 1, vocab size: 89
20 the least frequent terms: [('9dFW_vb5l', 1), ('BE65_U3MV_gE65', 1), ('QLqW_pxWe_qpj8', 1), ('qpj8_vb5l', 2), ('U3MV_lane3', 2), ('l0b4', 2), ('NQvE_pxWe', 2), ('jHNg_v+ie', 2), ('ceFW_vLqW', 3), ('QLqW_lane25', 3), ('7je8_j19R_lane23', 3), ('CINg', 3), ('BE65_U3MV', 3), ('U3MV', 3), ('lane25_vb5l', 3), ('12MV_iD65', 4), ('9dFW_NQvE', 5), ('vb5l', 5), ('BE65_U3MV_lane3', 5), ('DQuy_cje8_j19R', 6)]
intersection 2, vocab size: 32
20 the least frequent terms: [('/Q/Y_cja6', 1), ('/Q/Y_xFWI', 1), ('XmBT', 4), ('/Q/Y_eR/Y_xFWI

In [9]:
master_intersection_idx_2_tl_signal_indices

defaultdict(list,
            {0: [0, 8, 9, 56, 57, 58, 59],
             3: [1, 14, 22, 28, 37, 38, 39],
             8: [2, 11, 32, 34, 60],
             1: [3, 12, 20, 24, 30, 33, 54],
             6: [4, 6, 41, 42, 43],
             4: [5, 13, 23, 31, 49, 50, 51, 61, 62],
             7: [7, 15, 44, 45, 46, 48, 52, 53],
             2: [10, 21, 36, 47, 55],
             9: [16, 18, 27, 40],
             5: [17, 19, 25, 26, 29, 35]})

In [10]:
class IntersectionModel(nn.Module):
    
    def __init__(self, vocab_size,
                 intersection_tl_signals,
                 embedding_dim=256, 
                 hidden_dim=64, 
                 n_layers=1, 
                 bidirectional=True, 
                 dropout=0,
                 device='cuda:0'
                ):
        
        #Constructor
        super().__init__()          
        
        #embedding layer
        self.embedding = nn.Embedding(vocab_size + 2, embedding_dim) # including PAD and UNKNOWN tokens
        
        #lstm layer
        self.bidirectional = bidirectional
        self.lstm = nn.LSTM(embedding_dim + 5, 
                           hidden_dim, 
                           num_layers=n_layers, 
                           bidirectional=bidirectional, 
                           dropout=dropout,
                           batch_first=True)
        
        # tl color classifier (0 -> red, 1 -> green)
        lstm_hidden_dim = hidden_dim * 2 if self.bidirectional else hidden_dim
        
        for tl_idx in intersection_tl_signals:
            setattr(self, f'fc_color_{tl_idx}', nn.Linear(lstm_hidden_dim, 1))
            getattr(self, f'fc_color_{tl_idx}').bias.data.fill_(0.0)
            # Weibull params
            setattr(self, f'fc_tte_k_{tl_idx}', nn.Linear(lstm_hidden_dim, 1))
            getattr(self, f'fc_tte_k_{tl_idx}').bias.data.fill_(3.0)
            setattr(self, f'fc_tte_lambda_{tl_idx}', nn.Linear(lstm_hidden_dim, 1))
            getattr(self, f'fc_tte_lambda_{tl_idx}').bias.data.fill_(2.0)
             
        # attempt to address instability, might be related to the instability reported https://github.com/ragulpr/wtte-rnn/blob/master/examples/keras/standalone_simple_example.ipynb
        def clip_and_replace_explosures(grad):
            grad[torch.logical_or(torch.isnan(grad), torch.isinf(grad))] = torch.tensor(0.0).to(device)
            grad = torch.clamp(grad, -0.25, 0.25)
            return grad
        
        for param in self.parameters():
            if param.requires_grad:
                param.register_hook(clip_and_replace_explosures)
                
        self.intersection_tl_signals = intersection_tl_signals
        
        # classifier activation function
        self.color_act = nn.Sigmoid()
        
        # Weibull params activation function
        self.param_act = nn.Softplus()
        
    def forward(self, tokens_seq, token_type_ohe, token_timesteps, seq_lengths):
        
        #tokens_seq = [batch size,sent_length]
        embedded = self.embedding(tokens_seq)
        #embedded = [batch size, sent_len, emb dim]       
        
        # adding token type ohe and timestep
        embedded = torch.cat((embedded, token_type_ohe, torch.unsqueeze(token_timesteps, 2)), dim=2)
      
        #packed sequence
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, seq_lengths, batch_first=True, enforce_sorted = False)
        
        packed_output, (hidden, cell) = self.lstm(packed_embedded)
        #hidden shape = [num layers * num directions, batch size, hid dim]
        
        #concat the final forward and backward hidden state
        if self.bidirectional:
            hidden = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1) 
        else:
            hidden = hidden[-1,:,:]
            
        #hidden = [batch size, hid dim * num directions]

        #Final activation function
        tl_2_color_class = {tl_id: self.color_act(getattr(self, f'fc_color_{tl_id}')(hidden)) 
                            for tl_id in self.intersection_tl_signals}
        
        tl_2_tte_distr = dict()                    
        for tl_id in self.intersection_tl_signals:
            weibull_k = self.param_act(getattr(self, f'fc_tte_k_{tl_id}')(hidden))
            weibull_lambda = self.param_act(getattr(self, f'fc_tte_lambda_{tl_id}')(hidden))
            tl_2_tte_distr[tl_id] = Weibull(weibull_lambda, weibull_k)
        
        return tl_2_color_class, tl_2_tte_distr
    
    
# """
# class IntersectionModel(nn.Module):
    
#     def __init__(self, vocab_size,
#                  intersection_tl_signals,
#                  is_cuda_available,
#                  embedding_dim=256, 
#                  hidden_dim=64, 
#                  n_layers=1, 
#                  bidirectional=True, 
#                  dropout=0):
        
#         #Constructor
#         super().__init__()          
        
#         #embedding layer
#         self.embedding = nn.Embedding(vocab_size + 2, embedding_dim) # including PAD and UNKNOWN tokens
        
#         #lstm layer
#         self.bidirectional = bidirectional
        
        
#         # tl color classifier (0 -> red, 1 -> green)
#         lstm_hidden_dim = hidden_dim * 2 if self.bidirectional else hidden_dim
#         for tl_idx in intersection_tl_signals:
#             setattr(self, f'lstm_{tl_idx}', nn.LSTM(embedding_dim + 5, 
#                                                    hidden_dim, 
#                                                    num_layers=n_layers, 
#                                                    bidirectional=bidirectional, 
#                                                    dropout=dropout,
#                                                    batch_first=True))
#             setattr(self, f'fc_color_{tl_idx}', nn.Linear(lstm_hidden_dim, 1))
#             getattr(self, f'fc_color_{tl_idx}').bias.data.fill_(0.0)
#             # Weibull params
#             setattr(self, f'fc_tte_k_{tl_idx}', nn.Linear(lstm_hidden_dim, 1))
#             setattr(self, f'fc_tte_lambda_{tl_idx}', nn.Linear(lstm_hidden_dim, 1))
#         self.intersection_tl_signals = intersection_tl_signals
        
#         # classifier activation function
#         self.color_act = nn.Sigmoid()
        
#         # Weibull params activation function
#         self.param_act = nn.Softplus()
        
#     def forward(self, tokens_seq, token_type_ohe, token_timesteps, seq_lengths):
        
#         #tokens_seq = [batch size,sent_length]
#         embedded = self.embedding(tokens_seq)
#         #embedded = [batch size, sent_len, emb dim]       
        
#         # adding token type ohe and timestep
#         embedded = torch.cat((embedded, token_type_ohe, torch.unsqueeze(token_timesteps, 2)), dim=2)
      
#         #packed sequence
#         packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, seq_lengths, batch_first=True, enforce_sorted = False)
        
        
#         tl_2_color_class = dict()
#         tl_2_tte_distr = dict()                    
#         for tl_id in self.intersection_tl_signals:
#             _, (hidden, _) = getattr(self, f'lstm_{tl_id}')(packed_embedded)
#             #hidden = [batch size, num layers * num directions,hid dim]
#             #cell = [batch size, num layers * num directions,hid dim]

#             #concat the final forward and backward hidden state
#             if self.bidirectional:
#                 hidden = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1) 

#             #hidden = [batch size, hid dim * num directions]

#             #Final activation function
#             tl_2_color_class[tl_id] = self.color_act(getattr(self, f'fc_color_{tl_id}')(hidden))
        
#             weibull_k = self.param_act(getattr(self, f'fc_tte_k_{tl_id}')(hidden))
#             weibull_lambda = self.param_act(getattr(self, f'fc_tte_lambda_{tl_id}')(hidden))
#             tl_2_tte_distr[tl_id] = Weibull(weibull_lambda, weibull_k)
        
#         return tl_2_color_class, tl_2_tte_distr
# """

In [11]:
class IntersectionDataset(Dataset):
    def __init__(self, 
                 tl_events_df: pd.DataFrame,
                 valid_indices: np.array,
                 train_vocab: Dict,
                 term_freq: Dict,
                 tl_signal_indices: List,
                 history_len_records: int = HIST_LEN_FRAMES,
                 min_freq=5,
                 min_idx=20
                ):
        self.tl_events_df = tl_events_df
        max_events_per_timestamp = self.tl_events_df['rnn_inputs_raw'].map(len).max()
        self.history_events_max = history_len_records * max_events_per_timestamp
        self.history_len_records = history_len_records
        self.valid_indices = valid_indices
        self.vocab_term_2_idx = train_vocab
        self.vocab_term_2_freq = term_freq
        self.min_freq = min_freq
        self.UNKNOWN_TOKEN_IDX = len(self.vocab_term_2_idx)
        self.PAD_TOKEN_IDX = len(self.vocab_term_2_idx) + 1
        self.tl_signal_indices = tl_signal_indices
        
    def __len__(self):
        return len(self.valid_indices)
    
    def __getitem__(self, index: int):
        row_i = self.valid_indices[index]                        
        valid_hist_len = self.tl_events_df['valid_hist_len'].iloc[row_i]
        raw_inputs_hist = self.tl_events_df['rnn_inputs_raw'].iloc[row_i - valid_hist_len + 1:row_i + 1]
        
        tokens_list, token_type_ohe_list, token_timesteps_list = [], [], []
        timestap = valid_hist_len
        for timestep_events in raw_inputs_hist:
            for token, token_type_ohe in timestep_events:
                # zero-max normalization
                token_timesteps_list.append(timestap/self.history_len_records)
                token_idx = self.vocab_term_2_idx[token] if token in self.vocab_term_2_idx and self.vocab_term_2_freq[token] >= self.min_freq else self.UNKNOWN_TOKEN_IDX
                tokens_list.append(token_idx)
                token_type_ohe_list.append(token_type_ohe)               
            timestap -= 1
        
        seq_len = len(tokens_list)
        tokens_np = np.array(tokens_list)
        token_type_ohe_np = np.array(token_type_ohe_list)
        token_timesteps_np = np.array(token_timesteps_list)
        
        # padding
        padding_len = self.history_events_max - seq_len
        tokens_np = np.concatenate((tokens_np, self.PAD_TOKEN_IDX*np.ones(padding_len))).astype(np.int) # shouldn't get to PAD_TOKEN_IDX, but just in case
        token_type_ohe_np = np.concatenate((token_type_ohe_np, np.zeros((padding_len, 4)))).astype(np.float32)
        token_timesteps_np = np.concatenate((token_timesteps_np, np.zeros(padding_len))).astype(np.float32)
        
        known_true_classes = self.tl_events_df['tl_signal_classes'].iloc[row_i]
        known_tte = self.tl_events_df['time_to_tl_change'].iloc[row_i]
        
        all_true_classes = {tl_id: np.float32(known_true_classes[tl_id]) if tl_id in known_true_classes else np.float32(0) for tl_id in self.tl_signal_indices}
        all_tte = {tl_id: np.float32(known_tte[tl_id]) if tl_id in known_tte else np.float32(99.0) for tl_id in self.tl_signal_indices}
        classes_availabilities = {tl_id: np.float32(tl_id in known_true_classes) for tl_id in self.tl_signal_indices}
        tte_availabilities = {tl_id: np.float32(tl_id in known_tte) for tl_id in self.tl_signal_indices}
        return tokens_np, token_type_ohe_np, token_timesteps_np, seq_len, all_true_classes, all_tte, classes_availabilities, tte_availabilities

In [12]:
def get_valid_indices(tl_events_df, history_len_records=HIST_LEN_FRAMES):
    # pd rolling accepts numbers only, we need to process series of size-2 tuples (is_non_empty, is_time_continuous)
    # let's encode is_non_empty, is_time_continuous as the 1st and the 2nd bit of int

    is_non_empty_bit = 0
    is_time_continuous_bit = 1

    def encode_len_continuity(records_len, is_continuous):
        res_int = 0
        if records_len >= 1:
            res_int += 1 << is_non_empty_bit
        if is_continuous:
            res_int += 1 << is_time_continuous_bit
        return res_int

    def decode_len_continuity(num):
        is_non_empty = bool(num & (1 << is_non_empty_bit))
        is_continuous = bool(num & (1 << is_time_continuous_bit))
        return is_non_empty, is_continuous

    is_nonempty___is_continuious_series = ((tl_events_df['rnn_inputs_raw'].map(lambda x: [len(x)]) + 
                                           tl_events_df['continuous_time'].map(lambda x: [x]))
                                           .map(lambda x: encode_len_continuity(*x))
                                           .astype(np.int))

    def is_nonempty_input_present(hist):
        # returns 1 when there's a non-empy input
        for i in range(len(hist) -1, -1, -1):
            is_non_empty, is_time_continuous = decode_len_continuity(int(hist.iloc[i]))
            if not is_time_continuous:
                return 0
            if is_non_empty:
                return 1
        return 0

    is_nonempty_input_present_series = is_nonempty___is_continuious_series.rolling(history_len_records - 1).agg(is_nonempty_input_present)
    
    valid_rows_bool = (tl_events_df['tl_signal_classes'].map(lambda x: len(x) > 0) &
                       (is_nonempty_input_present_series == 1))
    return np.arange(len(tl_events_df))[valid_rows_bool]


def get_dataloader(tl_events_df, intersection_idx, shuffle=True, batch_size=1024, num_workers=12):
    tl_events_df_intersection = tl_events_df[tl_events_df['master_intersection_idx'] == intersection_idx]
    valid_indices_intersection = get_valid_indices(tl_events_df_intersection)
    dataset = IntersectionDataset(tl_events_df_intersection,
                                  valid_indices_intersection,
                                  intersection_2_train_vocab[intersection_idx],
                                  intersection_2_term_freq[intersection_idx],
                                      master_intersection_idx_2_tl_signal_indices[intersection_idx]
                                 )
    dataloader = DataLoader(dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers)
    return dataloader

def train(dataloader_trn, dataloader_val, intersection_model, device, optimizer, lr_scheduler, early_stopping, 
          binary_crossentropy=nn.BCELoss(reduction="none"), epoch_max=15, clip_value=5):# ==== TRAIN LOOP
    for epoch in range(epoch_max):
        progress_bar = tqdm((dataloader_trn), desc=f'Epoch {epoch}')
        losses_train = []
        losses_lob_prob_train = []
        losses_bce_train = []

        for batch in progress_bar:
            tokens, token_type_ohe, token_timesteps, seq_len, all_true_classes, all_tte, classes_availabilities, tte_availabilities = batch
            # moving to GPU if available
            tokens, token_type_ohe, token_timesteps, seq_len = tokens.to(device), token_type_ohe.to(device), token_timesteps.to(device), seq_len.to(device)
            all_true_classes = {tl_i: vals.to(device) for tl_i, vals in all_true_classes.items()}
            all_tte = {tl_i: vals.to(device) for tl_i, vals in all_tte.items()}
            classes_availabilities = {tl_i: vals.to(device) for tl_i, vals in classes_availabilities.items()}
            tte_availabilities = {tl_i: vals.to(device) for tl_i, vals in tte_availabilities.items()}
            intersection_model.train()
            torch.set_grad_enabled(True)
            tl_2_color_class, tl_2_tte_distr = intersection_model(tokens, token_type_ohe, token_timesteps, seq_len)

            loss_bce = torch.tensor([0.0]).to(device)
            loss_bce_terms_count = torch.tensor([0.0]).to(device)
            for tl_id, pred_color_classes in tl_2_color_class.items():
                true_color_classes = all_true_classes[tl_id]
                bce_loss_tl = binary_crossentropy(torch.squeeze(pred_color_classes), true_color_classes)*classes_availabilities[tl_id]
                loss_bce += bce_loss_tl.sum()
                loss_bce_terms_count += classes_availabilities[tl_id].sum()
            if loss_bce_terms_count > 0:
                loss_bce /= loss_bce_terms_count

            loss_tte_log_prob = torch.tensor([0.0]).to(device)
            loss_tte_log_prob_terms_count = torch.tensor([0.0]).to(device)
            for tl_id, tte_distr in tl_2_tte_distr.items():
                true_ttes = torch.unsqueeze(all_tte[tl_id], -1)
                log_prob_all = torch.squeeze(tte_distr.log_prob(true_ttes))*tte_availabilities[tl_id]
                log_prob_all[torch.logical_or(torch.isnan(log_prob_all), torch.isinf(log_prob_all))] = torch.tensor(0.0).to(device)
                loss_tte_log_prob -= log_prob_all.sum()
                loss_tte_log_prob_terms_count += tte_availabilities[tl_id].sum()
            if loss_tte_log_prob_terms_count > 0:
                loss_tte_log_prob /= loss_tte_log_prob_terms_count 

            loss = loss_bce + loss_tte_log_prob

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(intersection_model.parameters(), clip_value)
            optimizer.step()

            losses_train.append(loss.item())
            losses_lob_prob_train.append(loss_tte_log_prob.item())
            losses_bce_train.append(loss_bce.item())
            progress_bar.set_description(f"Ep. {epoch}, loss: {loss.item():.2f} (bce: {loss_bce.item():.2f}, log prob: {loss_tte_log_prob.item():.3f})")
        print(f"Avg train loss: {np.mean(losses_train):.5f} (bce: {np.mean(losses_bce_train):.5f}, log prob: {np.mean(losses_lob_prob_train):.5f})")

        intersection_model.eval()
        losses_val_all = []
        losses_val_lob_prob_train = []
        losses_val_bce_train = []
        for batch in tqdm(dataloader_val, desc='Validation..'):
            tokens, token_type_ohe, token_timesteps, seq_len, all_true_classes, all_tte, classes_availabilities, tte_availabilities = batch
            tokens, token_type_ohe, token_timesteps, seq_len = tokens.to(device), token_type_ohe.to(device), token_timesteps.to(device), seq_len.to(device)
            all_true_classes = {tl_i: vals.to(device) for tl_i, vals in all_true_classes.items()}
            all_tte = {tl_i: vals.to(device) for tl_i, vals in all_tte.items()}
            classes_availabilities = {tl_i: vals.to(device) for tl_i, vals in classes_availabilities.items()}
            tte_availabilities = {tl_i: vals.to(device) for tl_i, vals in tte_availabilities.items()}
            intersection_model.train()
            torch.set_grad_enabled(True)
            tl_2_color_class, tl_2_tte_distr = intersection_model(tokens, token_type_ohe, token_timesteps, seq_len)

            loss_bce = torch.tensor([0.0]).to(device)
            loss_bce_terms_count = torch.tensor([0.0]).to(device)
            for tl_id, pred_color_classes in tl_2_color_class.items():
                true_color_classes = all_true_classes[tl_id]
                bce_loss_tl = binary_crossentropy(torch.squeeze(pred_color_classes), true_color_classes)*classes_availabilities[tl_id]
                loss_bce += bce_loss_tl.sum()
                loss_bce_terms_count += classes_availabilities[tl_id].sum()
            if loss_bce_terms_count:
                loss_bce /= loss_bce_terms_count

            loss_tte_log_prob = torch.tensor([0.0]).to(device)
            loss_tte_log_prob_terms_count = torch.tensor([0.0]).to(device)
            for tl_id, tte_distr in tl_2_tte_distr.items():
                true_ttes = torch.unsqueeze(all_tte[tl_id], -1)
                log_prob_all = torch.squeeze(tte_distr.log_prob(true_ttes))*tte_availabilities[tl_id]
                loss_tte_log_prob -= log_prob_all.sum()
                loss_tte_log_prob_terms_count += tte_availabilities[tl_id].sum()
            if loss_tte_log_prob_terms_count:
                loss_tte_log_prob /= loss_tte_log_prob_terms_count

            loss = loss_bce + loss_tte_log_prob/4 # to have similar scale of values

            losses_val_all.append(loss.item())
            losses_val_lob_prob_train.append(loss_tte_log_prob.item())
            losses_val_bce_train.append(loss_bce.item())

        loss_val_mean = np.mean(losses_val_all)
        print(f'Val loss: {loss_val_mean: .5f} (bce: {np.mean(losses_val_bce_train):.5f}, log prob: {np.mean(losses_val_lob_prob_train): .5f})')
        lr_scheduler.step(loss_val_mean)
        early_stopping(loss_val_mean, intersection_model)        
        if early_stopping.early_stop:
            break
            
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [13]:
def train_intersection_model(intersection_idx, device, lr=3e-4, embedding_dim=64, hidden_dim=64, n_layers=2, 
                             bidirectional=True, dropout=0.2, epoch_max=15):
    dataloader_trn = get_dataloader(tl_events_df_trn, intersection_idx)
    dataloader_val = get_dataloader(tl_events_df_val, intersection_idx, shuffle=False) 
    intersection_model = IntersectionModel(vocab_size=len(intersection_2_train_vocab[intersection_idx]),
                                           intersection_tl_signals=master_intersection_idx_2_tl_signal_indices[intersection_idx],
                                           embedding_dim=embedding_dim, 
                                             hidden_dim=hidden_dim, 
                                             n_layers=n_layers, 
                                             bidirectional=bidirectional, 
                                             dropout=dropout
                                          ).to(device)
    optimizer = optim.Adam(intersection_model.parameters(), lr=lr)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=3)
    early_stopping = EarlyStopping(patience=7, verbose=True, path=f'intersection_{intersection_idx}_combined_loss_checkpoint.pt')
    train(dataloader_trn, dataloader_val, intersection_model, device, optimizer, lr_scheduler, early_stopping, epoch_max=epoch_max)

In [18]:
for intersection_idx in master_intersection_idx_2_tl_signal_indices.keys():
    print(f'Intersection {intersection_idx}..')
    train_intersection_model(intersection_idx, device, epoch_max=20)

Intersection 0..


HBox(children=(FloatProgress(value=0.0, description='Epoch 0', max=261.0, style=ProgressStyle(description_widt…


Avg train loss: 2.74029 (bce: 0.48843, log prob: 2.25185)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=3.0, style=ProgressStyle(description_w…


Val loss:  0.89715 (bce: 0.46262, log prob:  1.73809)
Validation loss decreased (inf --> 0.897147).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=261.0, style=ProgressStyle(description_widt…


Avg train loss: 1.36033 (bce: 0.34041, log prob: 1.01993)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=3.0, style=ProgressStyle(description_w…


Val loss:  0.82824 (bce: 0.34346, log prob:  1.93911)
Validation loss decreased (0.897147 --> 0.828237).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=261.0, style=ProgressStyle(description_widt…


Avg train loss: 1.04480 (bce: 0.24725, log prob: 0.79754)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=3.0, style=ProgressStyle(description_w…


Val loss:  0.84085 (bce: 0.29775, log prob:  2.17238)
EarlyStopping counter: 1 out of 10


HBox(children=(FloatProgress(value=0.0, description='Epoch 3', max=261.0, style=ProgressStyle(description_widt…


Avg train loss: 0.89798 (bce: 0.22005, log prob: 0.67794)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=3.0, style=ProgressStyle(description_w…


Val loss:  0.92642 (bce: 0.32419, log prob:  2.40893)
EarlyStopping counter: 2 out of 10


HBox(children=(FloatProgress(value=0.0, description='Epoch 4', max=261.0, style=ProgressStyle(description_widt…


Avg train loss: 0.86945 (bce: 0.22483, log prob: 0.64462)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=3.0, style=ProgressStyle(description_w…


Val loss:  0.91059 (bce: 0.28631, log prob:  2.49714)
EarlyStopping counter: 3 out of 10
Intersection 3..


HBox(children=(FloatProgress(value=0.0, description='Epoch 0', max=12.0, style=ProgressStyle(description_width…


Avg train loss: 3.53168 (bce: 0.68043, log prob: 2.85125)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=1.0, style=ProgressStyle(description_w…


Val loss:  0.00000 (bce: 0.00000, log prob:  0.00000)
Validation loss decreased (inf --> 0.000000).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=12.0, style=ProgressStyle(description_width…


Avg train loss: 3.17498 (bce: 0.65722, log prob: 2.51776)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=1.0, style=ProgressStyle(description_w…


Val loss:  0.00000 (bce: 0.00000, log prob:  0.00000)
Validation loss decreased (0.000000 --> 0.000000).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=12.0, style=ProgressStyle(description_width…


Avg train loss: 2.50000 (bce: 0.61556, log prob: 1.88444)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=1.0, style=ProgressStyle(description_w…


Val loss:  0.00000 (bce: 0.00000, log prob:  0.00000)
Validation loss decreased (0.000000 --> 0.000000).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 3', max=12.0, style=ProgressStyle(description_width…


Avg train loss: 1.61506 (bce: 0.52069, log prob: 1.09437)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=1.0, style=ProgressStyle(description_w…


Val loss:  0.00000 (bce: 0.00000, log prob:  0.00000)
Validation loss decreased (0.000000 --> 0.000000).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 4', max=12.0, style=ProgressStyle(description_width…


Avg train loss: 1.11540 (bce: 0.37782, log prob: 0.73758)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=1.0, style=ProgressStyle(description_w…


Val loss:  0.00000 (bce: 0.00000, log prob:  0.00000)
Epoch     5: reducing learning rate of group 0 to 3.0000e-05.
Validation loss decreased (0.000000 --> 0.000000).  Saving model ...
Intersection 8..


HBox(children=(FloatProgress(value=0.0, description='Epoch 0', max=16.0, style=ProgressStyle(description_width…


Avg train loss: 3.90095 (bce: 0.68485, log prob: 3.21610)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=1.0, style=ProgressStyle(description_w…


Val loss:  1.11473 (bce: 0.67982, log prob:  1.73964)
Validation loss decreased (inf --> 1.114726).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=16.0, style=ProgressStyle(description_width…


Avg train loss: 3.25422 (bce: 0.65953, log prob: 2.59469)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=1.0, style=ProgressStyle(description_w…


Val loss:  0.93637 (bce: 0.63357, log prob:  1.21120)
Validation loss decreased (1.114726 --> 0.936369).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=16.0, style=ProgressStyle(description_width…


Avg train loss: 2.26123 (bce: 0.59321, log prob: 1.66803)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=1.0, style=ProgressStyle(description_w…


Val loss:  0.74631 (bce: 0.48673, log prob:  1.03832)
Validation loss decreased (0.936369 --> 0.746307).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 3', max=16.0, style=ProgressStyle(description_width…


Avg train loss: 1.69081 (bce: 0.47623, log prob: 1.21458)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=1.0, style=ProgressStyle(description_w…


Val loss:  0.62125 (bce: 0.36596, log prob:  1.02116)
Validation loss decreased (0.746307 --> 0.621255).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 4', max=16.0, style=ProgressStyle(description_width…


Avg train loss: 1.41673 (bce: 0.37845, log prob: 1.03829)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=1.0, style=ProgressStyle(description_w…


Val loss:  0.53405 (bce: 0.29442, log prob:  0.95856)
Validation loss decreased (0.621255 --> 0.534055).  Saving model ...
Intersection 1..


HBox(children=(FloatProgress(value=0.0, description='Epoch 0', max=207.0, style=ProgressStyle(description_widt…


Avg train loss: 3.06190 (bce: 0.52343, log prob: 2.53847)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=2.0, style=ProgressStyle(description_w…


Val loss:  0.99316 (bce: 0.51536, log prob:  1.91117)
Validation loss decreased (inf --> 0.993155).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=207.0, style=ProgressStyle(description_widt…


Avg train loss: 1.53320 (bce: 0.38975, log prob: 1.14345)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=2.0, style=ProgressStyle(description_w…


Val loss:  0.80085 (bce: 0.33777, log prob:  1.85234)
Validation loss decreased (0.993155 --> 0.800853).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=207.0, style=ProgressStyle(description_widt…


Avg train loss: 1.31885 (bce: 0.33546, log prob: 0.98338)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=2.0, style=ProgressStyle(description_w…


Val loss:  0.80099 (bce: 0.31699, log prob:  1.93599)
EarlyStopping counter: 1 out of 10


HBox(children=(FloatProgress(value=0.0, description='Epoch 3', max=207.0, style=ProgressStyle(description_widt…


Avg train loss: 1.20800 (bce: 0.30479, log prob: 0.90321)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=2.0, style=ProgressStyle(description_w…


Val loss:  0.87455 (bce: 0.36534, log prob:  2.03685)
EarlyStopping counter: 2 out of 10


HBox(children=(FloatProgress(value=0.0, description='Epoch 4', max=207.0, style=ProgressStyle(description_widt…


Avg train loss: 1.10324 (bce: 0.25076, log prob: 0.85248)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=2.0, style=ProgressStyle(description_w…


Val loss:  0.84897 (bce: 0.31177, log prob:  2.14881)
EarlyStopping counter: 3 out of 10
Intersection 6..


HBox(children=(FloatProgress(value=0.0, description='Epoch 0', max=181.0, style=ProgressStyle(description_widt…


Avg train loss: 3.10046 (bce: 0.46640, log prob: 2.63406)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=1.0, style=ProgressStyle(description_w…


Val loss:  0.72876 (bce: 0.32484, log prob:  1.61565)
Validation loss decreased (inf --> 0.728756).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=181.0, style=ProgressStyle(description_widt…


Avg train loss: 1.58367 (bce: 0.37222, log prob: 1.21145)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=1.0, style=ProgressStyle(description_w…


Val loss:  0.66225 (bce: 0.24889, log prob:  1.65341)
Validation loss decreased (0.728756 --> 0.662247).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=181.0, style=ProgressStyle(description_widt…


Avg train loss: 1.24447 (bce: 0.28873, log prob: 0.95574)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=1.0, style=ProgressStyle(description_w…


Val loss:  0.61879 (bce: 0.16157, log prob:  1.82890)
Validation loss decreased (0.662247 --> 0.618795).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 3', max=181.0, style=ProgressStyle(description_widt…


Avg train loss: 1.02238 (bce: 0.20413, log prob: 0.81825)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=1.0, style=ProgressStyle(description_w…


Val loss:  0.61216 (bce: 0.12626, log prob:  1.94362)
Validation loss decreased (0.618795 --> 0.612162).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 4', max=181.0, style=ProgressStyle(description_widt…


Avg train loss: 0.88012 (bce: 0.14853, log prob: 0.73159)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=1.0, style=ProgressStyle(description_w…


Val loss:  0.66308 (bce: 0.13421, log prob:  2.11549)
EarlyStopping counter: 1 out of 10
Intersection 4..


HBox(children=(FloatProgress(value=0.0, description='Epoch 0', max=260.0, style=ProgressStyle(description_widt…


Avg train loss: 2.97351 (bce: 0.50227, log prob: 2.47123)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=1.0, style=ProgressStyle(description_w…


Val loss:  0.91270 (bce: 0.56241, log prob:  1.40115)
Validation loss decreased (inf --> 0.912703).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=260.0, style=ProgressStyle(description_widt…


Avg train loss: 1.41727 (bce: 0.34041, log prob: 1.07686)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=1.0, style=ProgressStyle(description_w…


Val loss:  0.81437 (bce: 0.48060, log prob:  1.33508)
Validation loss decreased (0.912703 --> 0.814367).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=260.0, style=ProgressStyle(description_widt…


Avg train loss: 1.16395 (bce: 0.24808, log prob: 0.91587)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=1.0, style=ProgressStyle(description_w…


Val loss:  0.74185 (bce: 0.40660, log prob:  1.34099)
Validation loss decreased (0.814367 --> 0.741851).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 3', max=260.0, style=ProgressStyle(description_widt…


Avg train loss: 1.01827 (bce: 0.19276, log prob: 0.82551)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=1.0, style=ProgressStyle(description_w…


Val loss:  0.68746 (bce: 0.34525, log prob:  1.36887)
Validation loss decreased (0.741851 --> 0.687465).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 4', max=260.0, style=ProgressStyle(description_widt…


Avg train loss: 0.92859 (bce: 0.16519, log prob: 0.76340)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=1.0, style=ProgressStyle(description_w…


Val loss:  0.68245 (bce: 0.33160, log prob:  1.40344)
Validation loss decreased (0.687465 --> 0.682455).  Saving model ...
Intersection 7..


HBox(children=(FloatProgress(value=0.0, description='Epoch 0', max=265.0, style=ProgressStyle(description_widt…


Avg train loss: 2.84229 (bce: 0.50428, log prob: 2.33802)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=2.0, style=ProgressStyle(description_w…


Val loss:  0.81505 (bce: 0.42739, log prob:  1.55064)
Validation loss decreased (inf --> 0.815050).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=265.0, style=ProgressStyle(description_widt…


Avg train loss: 1.37905 (bce: 0.35385, log prob: 1.02519)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=2.0, style=ProgressStyle(description_w…


Val loss:  0.69011 (bce: 0.31786, log prob:  1.48900)
Validation loss decreased (0.815050 --> 0.690115).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=265.0, style=ProgressStyle(description_widt…


Avg train loss: 1.09731 (bce: 0.25223, log prob: 0.84508)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=2.0, style=ProgressStyle(description_w…


Val loss:  0.59611 (bce: 0.22417, log prob:  1.48779)
Validation loss decreased (0.690115 --> 0.596114).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 3', max=265.0, style=ProgressStyle(description_widt…


Avg train loss: 0.93829 (bce: 0.21404, log prob: 0.72425)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=2.0, style=ProgressStyle(description_w…


Val loss:  0.52416 (bce: 0.15613, log prob:  1.47210)
Validation loss decreased (0.596114 --> 0.524157).  Saving model ...


HBox(children=(FloatProgress(value=0.0, description='Epoch 4', max=265.0, style=ProgressStyle(description_widt…


Avg train loss: 0.85253 (bce: 0.20457, log prob: 0.64796)


HBox(children=(FloatProgress(value=0.0, description='Validation..', max=2.0, style=ProgressStyle(description_w…


Val loss:  0.52388 (bce: 0.14974, log prob:  1.49658)
Validation loss decreased (0.524157 --> 0.523881).  Saving model ...
Intersection 2..


HBox(children=(FloatProgress(value=0.0, description='Epoch 0', max=113.0, style=ProgressStyle(description_widt…




KeyboardInterrupt: 