In [None]:
# matplotlib plots within notebook
%matplotlib inline

import matplotlib.pyplot as plt

import numpy as np

from datetime import datetime
from tqdm import tqdm

import os, sys, shutil
import pickle


from l5kit.configs import load_config_data
from l5kit.data import ChunkedDataset, LocalDataManager
from l5kit.dataset import EgoDataset, AgentDataset
from l5kit.rasterization import build_rasterizer
from l5kit.visualization import draw_trajectory, TARGET_POINTS_COLOR, PREDICTED_POINTS_COLOR, REFERENCE_TRAJ_COLOR
from l5kit.geometry import transform_points
import l5kit
print('Using l5kit version: '+l5kit.__version__)



# Custom libs
sys.path.insert(0, './LyftAgent_lib')
from LyftAgent_lib import train_support as lyl_ts
from LyftAgent_lib import topologies as lyl_nn

# Print Code Version
import git
def print_git_info(path, nombre):
    """Print branch, commit hash and locally modified files of a git repository.

    Args:
        path: Path handed to ``git.Repo`` to open the repository.
        nombre: Label printed in front of the version information.

    The function only prints; it returns ``None``.
    """
    repo = git.Repo(path)
    # `active_branch` raises TypeError when HEAD is detached (checked-out tag or
    # commit, typical on CI), so report that state instead of crashing the cell.
    if repo.head.is_detached:
        branch_name = '(detached HEAD)'
    else:
        branch_name = repo.active_branch.name
    print('Using: %s \t branch %s \t commit hash %s'%(nombre, branch_name, repo.head.object.hexsha))
    # Diff of the index against the working tree: tracked files with unstaged changes
    changed = [ item.a_path for item in repo.index.diff(None) ]
    if len(changed)>0:
        print('\t\t WARNING -- modified files:')
        print(changed)
print_git_info('.', 'LyftAgent_lib')


import platform
print("python: "+platform.python_version())


import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.backend as K
print('Using TensorFlow version: '+tf.__version__)
print('Using Keras version: '+keras.__version__)

from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())

# Training and Model Config

In [None]:
# Set the l5kit data-folder environment variable (empty string here;
# presumably makes dataset keys resolve relative to the working directory — TODO confirm)
os.environ["L5KIT_DATA_FOLDER"] = ""
# Load the training/model configuration from the YAML file
cfg = load_config_data("./AgentPrediction_config.yaml")
# Fill default values for entries missing from the config
cfg = lyl_ts.fill_defaults(cfg)

In [None]:
# Epoch to restart training from (0 means start over).
# When > 0, the models saved as 'epoch_<restart_epoch>' are loaded and the
# training loop continues at epoch restart_epoch + 1.
restart_epoch = 0

In [None]:
# Set loss functions and coupling list
# During training these are summed and weighted like: loss_function[0]*loss_couplings[0] + ... + loss_function[n]*loss_couplings[n]
# (both lists must therefore have the same length and order)
loss_function = [lyl_ts.L_loss_single2mult, lyl_ts.L2_loss]
loss_couplings = [1.0, 1.0]

### Process config parameters

In [None]:
# Unpack the configuration into notebook-level variables, one config section at a time.

# --- Raster: spatial size of the rasterized map -------------------------------------
_raster_cfg = cfg["raster_params"]
model_map_input_shape = (_raster_cfg["raster_size"][0], _raster_cfg["raster_size"][1])

# --- Model: image backbone, sequence lengths and recurrent units --------------------
_model_cfg = cfg["model_params"]
base_image_arch = _model_cfg["base_image_model"]
base_image_preprocess = _model_cfg["base_image_preprocess"]
num_future_frames = _model_cfg["future_num_frames"]
num_hist_frames = _model_cfg["history_num_frames"]
histEnc_recurrent_unit = _model_cfg["history_encoder_recurrent_unit"]
histEnc_recurrent_unit_num = _model_cfg["history_encoder_recurrent_units_number"]
pathDec_recurrent_unit = _model_cfg["path_generation_decoder_recurrent_unit"]
pathDec_recurrent_unit_num = _model_cfg["path_generation_decoder_recurrent_units_number"]

# --- Data loader ---------------------------------------------------------------------
gen_batch_size = cfg["train_data_loader"]["batch_size"]

# --- Training schedule -----------------------------------------------------------------
_train_cfg = cfg["training_params"]
gen_lr_list = _train_cfg["gen_lr_list"]
gen_lr_lims = _train_cfg["gen_lr_lims"]
number_of_scenes = _train_cfg["number_of_scenes"]
frames_per_scene = _train_cfg["frames_per_scene"]
randomize_frames = _train_cfg["randomize_frames"]
randomize_scenes = _train_cfg["randomize_scenes"]
epochs_train = _train_cfg["epochs_train"]
teacher_force_list = _train_cfg["teacher_force_list"]
teacher_force_lims = _train_cfg["teacher_force_lims"]
use_teacher_force = _train_cfg["use_teacher_force"]
init_decoder_on_history = _train_cfg["init_decoder_on_history"]
retrain_inputs_image_model = _train_cfg["retrain_inputs_image_model"]
retrain_all_image_model = _train_cfg["retrain_all_image_model"]
future_steps_train_list = _train_cfg["future_steps_train_list"]
future_steps_train_lims = _train_cfg["future_steps_train_lims"]
use_modulate_future_steps = _train_cfg["use_modulate_future_steps"]

# --- Model: version and feature switches ---------------------------------------------
model_version = _model_cfg["version"]
use_fading = _model_cfg["use_fading"]
use_angle = _model_cfg["use_angle"]
increment_net = _model_cfg["increment_net"]
mruv_guiding = _model_cfg["mruv_guiding"]
mruv_model_trainable = _model_cfg["mruv_model_trainable"]


# Names under which the per-step losses are logged to TensorBoard.
# They are paired, in order, with the losses returned by the train step.
loss_names = ['Likelihood', 'MSE']
if mruv_model_trainable:
    # Extra entries used only when the mruv model is trained as well
    loss_names.append('mruv_V_Loss')
    loss_names.append('mruv_A_Loss')
    loss_names.append('mruv_Conf_Loss')

# The image encoder is treated as trainable when either retrain flag is set
train_imgModel = retrain_inputs_image_model or retrain_all_image_model

# Check model version and select the matching forward pass and output folder
isBaseModel = (model_version == 'Base')
if model_version == 'Base':
    forward_pass_use = lyl_nn.modelBaseline_forward_pass
    save_path = './output_Baseline'
elif model_version == 'V1':
    forward_pass_use = lyl_nn.modelV1_forward_pass
    save_path = './output_V1_likelihood'
elif model_version == 'V2':
    forward_pass_use = lyl_nn.modelV2_forward_pass
    save_path = './output_V2_noAttn_big_multiLoss_imgRetrain'
    # The output folder name encodes the mruv options in use
    if mruv_guiding:
        save_path += '_mruvGuided'
    if mruv_model_trainable:
        save_path += '_mruvModel'
    # V2 forces the history-encoder unit count in cfg to match the decoder's.
    # Only cfg is changed: histEnc_recurrent_unit_num, read earlier, keeps its old value.
    cfg["model_params"]["history_encoder_recurrent_units_number"] = cfg["model_params"]["path_generation_decoder_recurrent_units_number"]
else:
    # Fail here instead of with a NameError on forward_pass_use/save_path much later
    raise ValueError("Unknown model version '%s' (expected 'Base', 'V1' or 'V2')" % model_version)



# Get image preprocessing function (depends on image encoding base architecture).
# The config value is expected as '<keras.applications submodule>.<function>';
# None means the images are passed through untouched.
if base_image_preprocess is None:
    base_image_preprocess_fcn = lambda x: x
else:
    try:
        _preprocess_path = base_image_preprocess.split('.')
        base_image_preprocess_fcn = getattr(keras.applications, _preprocess_path[0])
        base_image_preprocess_fcn = getattr(base_image_preprocess_fcn, _preprocess_path[1])
    except (AttributeError, IndexError) as err:
        # AttributeError: unknown submodule/function (or non-string config value);
        # IndexError: no '.' in the configured name. Anything else is a real bug
        # and is no longer swallowed by a bare except.
        raise ValueError('Base image pre-processing not found. Requested function: %s'%base_image_preprocess) from err


# Dataset Loader

### Train Set

In [None]:
# Open the training zarr dataset named in the config and wrap it as an AgentDataset
dm = LocalDataManager()
dataset_path = dm.require(cfg["train_data_loader"]["key"])
zarr_dataset = ChunkedDataset(dataset_path)
zarr_dataset.open(cached=False)
print(zarr_dataset)
# Rasterizer built from the same config and data manager
rast = build_rasterizer(cfg, dm)
train_dataset = AgentDataset(cfg, zarr_dataset, rast, min_frame_future = 10)
# NOTE: this replaces the number_of_scenes value read from cfg["training_params"]
# with the number of scenes actually present in the loaded dataset.
number_of_scenes = len(train_dataset.dataset.scenes)

In [None]:
# Training input pipeline: raw samples -> per-sample pre-processing -> fixed-size batches
tf_train_dataset = (
    lyl_ts.get_tf_dataset(
        train_dataset,
        num_hist_frames,
        model_map_input_shape,
        num_future_frames,
        randomize_frame=randomize_frames,
        randomize_scene=randomize_scenes,
        num_scenes=number_of_scenes,
        frames_per_scene=frames_per_scene,
    )
    # Per-sample pre-processing (image preprocessing, fading and angle options)
    .map(
        lambda sample: lyl_ts.tf_get_input_sample(
            sample,
            image_preprocess_fcn=base_image_preprocess_fcn,
            use_fading=use_fading,
            use_angle=use_angle,
        )
    )
    # Group samples into mini-batches
    .batch(batch_size=gen_batch_size)
)

### Validation Set

In [None]:
# Open the validation zarr dataset named in the config and wrap it as an AgentDataset
dm_val = LocalDataManager()
dataset_path_val = dm_val.require(cfg["val_data_loader"]["key"])
zarr_dataset_val = ChunkedDataset(dataset_path_val)
zarr_dataset_val.open(cached=False)
print(zarr_dataset_val)
# Separate rasterizer instance for the validation data manager
rast_val = build_rasterizer(cfg, dm_val)
validation_dataset = AgentDataset(cfg, zarr_dataset_val, rast_val, min_frame_future = 10)
# Number of scenes present in the validation dataset
number_of_scenes_val = len(validation_dataset.dataset.scenes)

In [None]:
# Validation input pipeline; frames and scenes are always randomized here
tf_validation_dataset = (
    lyl_ts.get_tf_dataset(
        validation_dataset,
        num_hist_frames,
        model_map_input_shape,
        num_future_frames,
        randomize_frame=True,
        randomize_scene=True,
        num_scenes=number_of_scenes_val,
        frames_per_scene=frames_per_scene,
    )
    # Same per-sample pre-processing as the training pipeline
    .map(
        lambda sample: lyl_ts.tf_get_input_sample(
            sample,
            image_preprocess_fcn=base_image_preprocess_fcn,
            use_fading=use_fading,
            use_angle=use_angle,
        )
    )
    # Group samples into mini-batches
    .batch(batch_size=gen_batch_size)
)

# Build Models

In [None]:
if restart_epoch>0:
    # Resume: load the models saved under the 'epoch_<restart_epoch>' tag
    model_list = lyl_ts.load_models(save_path, 'epoch_%d'%(restart_epoch), 
                         load_img_model = (retrain_inputs_image_model or retrain_all_image_model), 
                         isBaseModel = isBaseModel,
                         mruv_guiding = mruv_guiding,
                         mruv_model_trainable = mruv_model_trainable)

    ImageEncModel = model_list[0]
    HistEncModel = model_list[1]
    PathDecModel = model_list[2]
    mruv_model = model_list[3]

else:
    
    # ---------------------- IMAGE ENCODER --------------------------
    # Load pretrained image processing model
    try:
        base_model_builder = getattr(keras.applications, base_image_arch)
    except (AttributeError, TypeError) as err:
        # AttributeError: no such architecture in keras.applications;
        # TypeError: the configured name is not a string.
        raise ValueError('Base image processing model not found. Requested model: %s'%base_image_arch) from err


    base_img_model = base_model_builder(include_top=False, weights='imagenet')

    ImageEncModel = lyl_nn.imageEncodingModel(base_img_model, cfg)

    # ---------------------- PATH HISTORY ENCODER ------------------
    # The base model has no history encoder
    if not isBaseModel:
        HistEncModel = lyl_nn.pathEncodingModel(cfg)
    else:
        HistEncModel = None
    
    # ---------------------- MRUV MODEL ------------------------------
    # Either a trainable Keras model, a fixed function, or nothing at all
    if mruv_guiding:
        if mruv_model_trainable:
            mruv_model  = lyl_nn.mruvModel(cfg, ImageEncModel)
        else:
            mruv_model  = lyl_ts.get_velocity_and_acceleration
    else:
        mruv_model = None

    # ---------------------- PATH DECODER --------------------------
    if model_version == 'Base':
        PathDecModel = lyl_nn.pathDecoderModel_Baseline(cfg, ImageEncModel)
    elif model_version == 'V1':
        PathDecModel = lyl_nn.pathDecoderModel_V1(cfg, ImageEncModel, HistEncModel)
    elif model_version == 'V2':
        PathDecModel = lyl_nn.pathDecoderModel_V2(cfg, ImageEncModel, HistEncModel)
    else:
        # Without this branch PathDecModel would silently stay undefined
        raise ValueError("Unknown model version '%s' (expected 'Base', 'V1' or 'V2')" % model_version)

### Show model information

In [None]:
# plot_model writes its PNG into save_path; on a fresh run nothing earlier in
# this notebook is seen to create that folder, so make sure it exists first.
os.makedirs(save_path, exist_ok=True)

ImageEncModel.summary()
tf.keras.utils.plot_model(ImageEncModel, show_shapes=True, show_layer_names=True, 
                              to_file=os.path.join(save_path,'ImageEncModel.png'))

In [None]:
# Summarise and draw the history encoder (there is none for the base model)
if not isBaseModel:
    HistEncModel.summary()
    tf.keras.utils.plot_model(
        HistEncModel,
        to_file=os.path.join(save_path, 'HistEncModel.png'),
        show_shapes=True,
        show_layer_names=True,
    )

In [None]:
# Only a trainable mruv model is a Keras model that can be summarised and drawn
if mruv_guiding and mruv_model_trainable:
    mruv_model.summary()
    tf.keras.utils.plot_model(
        mruv_model,
        to_file=os.path.join(save_path, 'mruv_model.png'),
        show_shapes=True,
        show_layer_names=True,
    )

In [None]:
# Summarise and draw the path decoder
PathDecModel.summary()
tf.keras.utils.plot_model(
    PathDecModel,
    to_file=os.path.join(save_path, 'PathDecModel.png'),
    show_shapes=True,
    show_layer_names=True,
)


# Training

### Setup TensorBoard

In [None]:
# TensorBoard output folder; a run that starts over (restart_epoch == 0) wipes old logs
logdir = os.path.join(save_path, 'TensorBoard_outs')
if restart_epoch == 0 and os.path.exists(logdir):
    shutil.rmtree(logdir)

# One summary writer per split, each in its own sub-folder
tb_writer = tf.summary.create_file_writer(os.path.join(logdir, 'train'), name='train')
tb_writer_val = tf.summary.create_file_writer(os.path.join(logdir, 'validation'), name='validation')

# When True, the first training step is traced and its graph exported (very verbose output)
get_graph = False


### Setup optimizer

In [None]:
# `learning_rate` is the supported argument name; `lr` is a deprecated alias that
# newer Keras versions reject. The optimizer is also handed to lyl_ts.update_lr
# every epoch, so 1.0 is presumably only a placeholder — TODO confirm.
optimizer_gen = keras.optimizers.Adam(learning_rate=1.0, beta_1=0.9, beta_2=0.99, epsilon=1.0e-8)
train_genL2_loss = keras.metrics.Mean(name='train_gen_L2_loss')

# Load optimizer state
if restart_epoch>0:
    optimizer_states_dir = os.path.join(save_path,'OptimizerStates')
    if os.path.exists(os.path.join(optimizer_states_dir, 'epoch_%d.npy'%(restart_epoch))):
        # Get list of trainable variables.
        # Copy into a fresh list so the in-place `+=` below can never modify a
        # list owned by the model.
        train_vars = list(ImageEncModel.trainable_variables)
        if not isBaseModel:
            train_vars += HistEncModel.trainable_variables
        train_vars += PathDecModel.trainable_variables
        if mruv_model_trainable:
            train_vars += mruv_model.trainable_variables
                    
        lyl_ts.load_optimizer_state(optimizer_states_dir, 'epoch_%d'%(restart_epoch), optimizer_gen, train_vars)
    else:
        print('Warning - Optimizer state not loaded, using blank optimizer.')

    



### Train!

In [None]:

def _log_variable_norms(scope_name, grads, model, step):
    """Log the L2 norm of each trainable variable of `model` as a TensorBoard scalar.

    `grads` only bounds the iteration (one entry per variable), exactly like the
    inline loops this helper replaces. Must be called with a default summary
    writer active.
    """
    # NOTE(review): the scopes are named "Gradient_*" but the value logged is the
    # norm of the variable, not of its gradient — confirm which one is intended.
    with tf.name_scope(scope_name):
        for _grad, var in zip(grads, model.trainable_variables):
            tf.summary.scalar(var.name+'_norm', np.linalg.norm(var.numpy()), step=step)


def _log_weight_histograms(scope_name, model, step):
    """Log a histogram of every trainable variable of `model` under `scope_name`.

    Values are read from the variables themselves. Zipping `model.get_weights()`
    (which also returns non-trainable weights) with `model.trainable_variables`
    can pair a variable name with the values of a different weight.
    """
    with tf.name_scope(scope_name):
        for var in model.trainable_variables:
            tf.summary.histogram(var.name, var.numpy(), step=step)


# ----------------------------- Set initial parameters -------------------------------
teacher_force_base = 1.0
teacher_force = teacher_force_base
if restart_epoch > 0:
    epoch_ini = restart_epoch+1
else:
    epoch_ini = 1
futureSteps_infer = num_future_frames
h_idx_all = list()
# Number of steps per epoch to run (does not change between epochs)
steps_per_epoch = int(np.floor(number_of_scenes*frames_per_scene/gen_batch_size))
# ------------------------------------------------------------------------------------


# Run for the given number of epochs
for epoch in range(epoch_ini, epochs_train+1):

    # Using the "train" TensorBoard file
    with tb_writer.as_default():

        # Step index for this epoch
        idx_gen_train = 0
        # Global step index (grows indefinitely with the epochs)
        this_step_num = idx_gen_train+(steps_per_epoch*(epoch-1))

        # ----------------------------- Set future steps to train on -------------------------
        if use_modulate_future_steps:
            # Modulate number of future steps to train on
            futureSteps_infer = lyl_ts.get_future_steps_train(future_steps_train_list, 
                                                            future_steps_train_lims, 
                                                            epoch, 
                                                            futureSteps_infer)
        with tf.name_scope("Train_params"):
            tf.summary.scalar('obj_steps', futureSteps_infer, step=this_step_num)

        # ------------------------------------------------------------------------------------
    
        # ----------------------------- Set teacher force value ------------------------------
        if not isBaseModel:
            if use_teacher_force:
                # Modulate teacher force
                teacher_force = lyl_ts.get_teacher_force_weight(teacher_force_list, 
                                                        teacher_force_lims, 
                                                        epoch, 
                                                        teacher_force, 
                                                        linearize=True)
                if teacher_force<=0.0:
                    # Note: this switches the notebook-level flag off for all later epochs
                    use_teacher_force = False
                    print('Teacher force deactivated.')

            with tf.name_scope("Train_params"):
                tf.summary.scalar('teacher_force', teacher_force, step=this_step_num)
        # ------------------------------------------------------------------------------------


        # ----------------------------- Set learning rate value ------------------------------
        # Update learning rate
        this_lr = lyl_ts.update_lr(gen_lr_list, gen_lr_lims, epoch, optimizer_gen)
        with tf.name_scope("Train_params"):
            tf.summary.scalar('learning_rate', this_lr, step=this_step_num)
        
        # ------------------------------------------------------------------------------------


        # Flush stdout to avoid tqdm overlap
        sys.stdout.flush()
                                
        # Running sum of the per-step loss, used for the epoch mean shown in the progress bar
        train_genL2_loss_aux = 0
        

        # List of hard examples
        # This is a list of those examples where the error was 10x larger than the running mean
        h_example_idxs = list()
        h_example_loss = list()
        h_example_mean = list()
        # Index of every sample seen in this epoch
        example_log = list()
        
        # Set tqdm progress bar
        train_dataset_prog_bar = tqdm(tf_train_dataset, total=steps_per_epoch)
        # Iterate over the training dataset once
        # (Each complete iteration will yield a different, random, frame set)
        for (thisSampleMapComp, thisSampeHistPath, thisSampeTargetPath, 
            thisHistAvail, thisTargetAvail, 
            thisTimeStamp, thisTrackID, thisRasterFromAgent, thisWorldFromAgent, thisCentroid, 
            thisSampleIdx) in train_dataset_prog_bar:

            # Update step num
            this_step_num = idx_gen_train+(steps_per_epoch*(epoch-1))

            # If the batch size of this mini-batch is not the same as the 
            # expected by the net, the training loop is finished (we lose only a few random frames)
            if thisSampleMapComp.shape[0] < gen_batch_size:
                break

            # If the graph is being recorded, start the tracing
            if get_graph:
                tf.summary.trace_on(graph=True, profiler=True)
                       

            # Run the correct train step depending on the model
            if isBaseModel:
                step_losses, gradients_out  = lyl_ts.generator_train_step_Base(thisSampleMapComp,
                                                                        thisSampeTargetPath, 
                                                                        thisTargetAvail,
                                                                        ImageEncModel,
                                                                        PathDecModel, 
                                                                        optimizer_gen, 
                                                                        train_genL2_loss,
                                                                        forward_pass_use,
                                                                        loss_use = loss_function,
                                                                        gradient_clip_value = 10.0)
            else:
                # Clear recurrent state before processing a new batch
                PathDecModel.reset_states()
                HistEncModel.reset_states()
                
                step_losses, gradients_out = lyl_ts.generator_train_step(thisSampleMapComp,
                                                                        thisSampeHistPath, 
                                                                        thisSampeTargetPath, 
                                                                        thisHistAvail, 
                                                                        thisTargetAvail,
                                                                        ImageEncModel,
                                                                        HistEncModel, 
                                                                        PathDecModel, 
                                                                        optimizer_gen, 
                                                                        train_genL2_loss,
                                                                        tf.constant(tf.zeros(PathDecModel.inputs[-1].shape)),
                                                                        forward_pass_use,
                                                                        loss_use = loss_function,
                                                                        loss_couplings = loss_couplings,
                                                                        stepsInfer = futureSteps_infer,
                                                                        use_teacher_force=use_teacher_force,
                                                                        teacher_force_weight = tf.constant(teacher_force, 
                                                                                                                dtype=tf.float32),
                                                                        gradient_clip_value = 10.0,
                                                                        stop_gradient_on_prediction = False,
                                                                        increment_net = increment_net,
                                                                        mruv_guiding = mruv_guiding,
                                                                        mruv_model = mruv_model, 
                                                                        mruv_model_trainable = mruv_model_trainable)

            # If the graph was recorded, save it and stop the tracing
            if get_graph:
                tf.summary.trace_export(
                    name="train_step_trace",
                    step=0,
                    profiler_outdir=logdir)
                tf.summary.trace_off()
            # The graph is traced at most once (first step only)
            get_graph = False


            # Analyze train step output
            thisL2 = train_genL2_loss.result()
            if np.isnan(thisL2):
                raise Exception('Bad potato... step %d'%idx_gen_train)
                
            # -------------------------- Update progress bar --------------------------------------
            # Get step compound loss
            train_genL2_loss_aux += thisL2
            train_genL2_loss.reset_states()
            # Mean loss over the steps of this epoch so far
            print_gen_L2 = train_genL2_loss_aux/(idx_gen_train+1)

            # Update progress bar
            msg_string = '(Epoch %d/%d) Gen. Loss: %.2f (last %.2f) '%(epoch,
                                                                            epochs_train,
                                                                            print_gen_L2, 
                                                                            thisL2)
            if use_teacher_force and not isBaseModel:
                msg_string += '(t.f. = %.2f)'%teacher_force
            train_dataset_prog_bar.set_description(msg_string)

            # ------------------------------------------------------------------------------------




            # ------------------------------- Log the hard examples data -------------------------
            # Per-sample compound loss: coupling-weighted sum of the individual losses
            comp_loss = 0
            for step_loss, k in zip(step_losses, loss_couplings):
                comp_loss += k*step_loss.numpy() 
            for thisLoss, thisIDX in zip(comp_loss, thisSampleIdx):
                example_log.append(thisIDX)
                if (thisLoss > 10.0*print_gen_L2):
                    h_example_idxs.append(thisIDX)
                    h_example_loss.append(thisLoss)
                    h_example_mean.append(print_gen_L2)
                                
            # ------------------------------------------------------------------------------------





            # ------------------------------ Save info to tensorboard ----------------------------
            with tf.name_scope("Loss_metrics"):
                for this_name, this_loss in zip(loss_names, step_losses):
                    tf.summary.scalar(this_name, np.mean(this_loss.numpy()), step=this_step_num)

            # gradients_out holds one entry per trained model, in this order:
            # image encoder (only if trained), history encoder (non-base), decoder, mruv model
            grad_out_count = 0
            if train_imgModel:
                _log_variable_norms("Gradient_ImgEncModel", gradients_out[grad_out_count],
                                    ImageEncModel, this_step_num)
                grad_out_count += 1
            if not isBaseModel:
                _log_variable_norms("Gradient_HistEncModel", gradients_out[grad_out_count],
                                    HistEncModel, this_step_num)
                grad_out_count += 1
            _log_variable_norms("Gradient_PathDecModel", gradients_out[grad_out_count],
                                PathDecModel, this_step_num)
            grad_out_count += 1
            if mruv_model_trainable:
                _log_variable_norms("Gradient_mruvModel", gradients_out[grad_out_count],
                                    mruv_model, this_step_num)
                grad_out_count += 1

            # Weight histograms are expensive, so they are written every 100 global steps
            if this_step_num % 100 == 0:
                if train_imgModel:
                    _log_weight_histograms("ImageEncodingModel", ImageEncModel, this_step_num)
                if not isBaseModel:
                    _log_weight_histograms("HistoryEncodingModel", HistEncModel, this_step_num)
                _log_weight_histograms("PathDecoderModel", PathDecModel, this_step_num)
                if mruv_model_trainable:
                    # Own scope: these are weights, not gradients
                    _log_weight_histograms("mruvModel", mruv_model, this_step_num)
                tb_writer.flush()


            # ------------------------------------------------------------------------------------

            # Update this epoch train step
            idx_gen_train += 1

            
        # --------------------------- Save Models -------------------------------------------
        # The image encoder is saved per epoch only when it is being trained;
        # otherwise a single copy is written after the first epoch.
        if retrain_inputs_image_model or retrain_all_image_model:
            lyl_ts.save_model(ImageEncModel, os.path.join(save_path,'ImageEncModel'), 'epoch_%d'%epoch, use_keras=True)
        elif epoch == 1:
            lyl_ts.save_model(ImageEncModel, os.path.join(save_path,'ImageEncModel'), 'all_epochs', use_keras=True)
        if not isBaseModel:
            lyl_ts.save_model(HistEncModel, os.path.join(save_path,'HistEncModel'), 'epoch_%d'%epoch, use_keras=True)
        lyl_ts.save_model(PathDecModel, os.path.join(save_path,'PathDecModel'), 'epoch_%d'%epoch, use_keras=True)
        if mruv_model_trainable:
            lyl_ts.save_model(mruv_model, os.path.join(save_path,'mruvModel'), 'epoch_%d'%epoch, use_keras=True)
        lyl_ts.save_optimizer_state(optimizer_gen, os.path.join(save_path,'OptimizerStates'), 'epoch_%d'%epoch)

        # ------------------------------------------------------------------------------------

        # -------------------------- Save list of hard samples -------------------------------
        h_example_idxs = np.array(h_example_idxs)
        h_example_loss = np.array(h_example_loss)
        h_example_mean = np.array(h_example_mean)
        # Context managers guarantee the files are closed even if a write fails
        with open(os.path.join(save_path,'epoch_%d_h_idx.txt'%epoch),"w") as file_out:
            for idx, loss, media in zip(h_example_idxs, h_example_loss, h_example_mean):
                file_out.write('%d : \t %g (%g)\n'%(idx, loss, media))
        with open(os.path.join(save_path,'epoch_%d_sample_log_idx.txt'%epoch),"w") as file_out:
            for idx in example_log:
                file_out.write('%d, '%idx)
        # ------------------------------------------------------------------------------------



    # Perform validation using the "validation" TensorBoard file
    with tb_writer_val.as_default():

        # Arguments shared by both kinds of validation run
        validate_kwargs = dict(increment_net = increment_net,
                               mruv_guiding = mruv_guiding,
                               mruv_model = mruv_model,
                               mruv_model_trainable = mruv_model_trainable,
                               all_metrics = True,
                               base_model = isBaseModel)
        # Once every ten epochs, a full validation step
        # (it is not really a complete run, because the dataset is built with a given "frames_per_scene").
        # Every other epoch validates on 100 steps with the current number of future steps.
        if (epoch%10 != 0):
            validate_kwargs.update(steps_validate = 100, stepsInfer = futureSteps_infer)

        out_metrics = lyl_ts.validate_model(tf_validation_dataset, 
                                ImageEncModel, HistEncModel, PathDecModel, forward_pass_use, 
                                **validate_kwargs)
        
        
        # NOTE(review): index 0 is logged as 'MSE' and index 1 as 'Likelihood', the
        # reverse of the training loss order — confirm against validate_model's return order.
        with tf.name_scope("Loss_metrics"):
            tf.summary.scalar('MSE', out_metrics[0], step=this_step_num)
            tf.summary.scalar('Likelihood', out_metrics[1], step=this_step_num)
            tf.summary.scalar('TD(0)', out_metrics[2][0], step=this_step_num)
            tf.summary.scalar('TD(10)', out_metrics[2][10], step=this_step_num)
            tf.summary.scalar('TD(25)', out_metrics[2][25], step=this_step_num)
            tf.summary.scalar('TD(T)', out_metrics[2][-1], step=this_step_num)
            tf.summary.scalar('TD(mean)', np.mean(out_metrics[2]), step=this_step_num)

        tb_writer_val.flush()

        

    # Keep the unique hard-example indexes of this epoch and print how many there were
    print('')   
    h_idx_all.append(np.unique(h_example_idxs))
    print(np.unique(h_example_idxs).shape[0])

               

# Single sample tests

### Process a validation sample

In [None]:
# Take a sample batch: the first batch yielded by the validation dataset
(valSampleMapComp, valSampeHistPath, valSampeTargetPath,
 valHistAvail, valTargetAvail,
 valTimeStamp, valTrackID, valRasterFromAgent, valWorldFromAgent, valCentroid,
 valSampleIdx) = next(iter(tf_validation_dataset))

# Process it
PathDecModel.reset_states()
if not isBaseModel:
    # There is no history encoder (HistEncModel is None) for the base model
    HistEncModel.reset_states()
# NOTE(review): this call uses the history-based forward-pass signature; confirm
# it is also valid for lyl_nn.modelBaseline_forward_pass before using 'Base' here.
valPredPath = forward_pass_use(valSampleMapComp,
                               valSampeHistPath,
                               valHistAvail,
                               # Predict as many steps as the target path holds, so that
                               # prediction and target line up in the loss below
                               num_future_frames,
                               ImageEncModel, HistEncModel, PathDecModel,
                               use_teacher_force=False, target_path=valSampeTargetPath,
                               increment_net = increment_net,
                               mruv_guiding = mruv_guiding,
                               mruv_model = mruv_model,
                               mruv_model_trainable = mruv_model_trainable)

valPredPath = valPredPath.numpy()
valLoss = lyl_ts.L_loss_single2mult(valPredPath, valSampeTargetPath, valTargetAvail)
valLoss = valLoss.numpy()

print('Mean likelihood ot this batch: %0.2f'%np.mean(valLoss))


### Plot input, predicted and expected path

In [None]:
# Pick a random sample of the batch. The bound is the size of the batch actually
# drawn, which can be smaller than gen_batch_size, so the index is always valid.
idx_sample = np.random.randint(valSampleMapComp.shape[0])
# Number of available target points of that sample
num_targets = np.sum(valTargetAvail.numpy(), axis=1)[idx_sample].astype(np.int32)

print('Sample num.: %d (%d target points)'%(valSampleIdx[idx_sample], num_targets))
print('Timestamp: %d \t TrackID %d'%(valTimeStamp[idx_sample], valTrackID[idx_sample]))

In [None]:
HISTORY_TRAJ_COLOR = (0, 255, 0)

# Agent -> raster transform of the selected sample. The whole matrix is needed:
# slicing its rows with `:num_targets` truncates it when there are fewer than 3 targets.
sample_raster_from_agent = valRasterFromAgent[idx_sample].numpy()

plt.figure(dpi=300)

# Panel 1: first three map channels, rescaled to 0-255, with the trajectories drawn on top
plt.subplot(1,3,1)
img_aux = (np.copy(valSampleMapComp[idx_sample,:,:,:3].numpy()))
img_aux = (((img_aux-np.min(img_aux))/(np.max(img_aux)-np.min(img_aux)))*255.0).astype(np.int32)
# Ground-truth future path (available points only)
draw_trajectory(img_aux,
                transform_points(valSampeTargetPath[idx_sample,:num_targets,:2].numpy(), 
                                 sample_raster_from_agent),
                TARGET_POINTS_COLOR)

# Predicted future path, cut to the same number of points
draw_trajectory(img_aux,
                transform_points(valPredPath[idx_sample,:num_targets,:2], 
                                 sample_raster_from_agent),
                PREDICTED_POINTS_COLOR)

# History path
draw_trajectory(img_aux,
                transform_points(valSampeHistPath[idx_sample,:,:2].numpy(), 
                                 sample_raster_from_agent),
                HISTORY_TRAJ_COLOR)

# `origin` only accepts 'upper' or 'lower'; the misspelt 'botom' is not a valid value
plt.imshow(img_aux, origin='lower')
plt.axis('off')

# Panels 2 and 3: map channels 3 and 4 shown on their own
plt.subplot(1,3,2)
plt.imshow(valSampleMapComp[idx_sample,:,:,3].numpy().astype(np.int32), origin='lower')
plt.axis('off')
plt.subplot(1,3,3)
plt.imshow(valSampleMapComp[idx_sample,:,:,4].numpy().astype(np.int32), origin='lower')
plt.axis('off')
plt.tight_layout()
plt.show()


### Error plots

In [None]:

# Available part of the selected sample, as plain NumPy arrays.
# Converting once keeps every plot call on ndarrays instead of mixing tensors and arrays.
target_points = valSampeTargetPath.numpy()[idx_sample,:num_targets]
pred_points = valPredPath[idx_sample,:num_targets]

# Figure 1: predicted and ground-truth coordinate along the steps, one panel per axis
plt.figure(dpi = 150)
for coord_idx, coord_name in enumerate(('X', 'Y')):
    plt.subplot(1,2,coord_idx+1)
    plt.plot(pred_points[:,coord_idx])
    plt.plot(target_points[:,coord_idx])
    plt.ylabel('%s position'%coord_name)
    plt.xlabel('steps')
    plt.legend(['predicted','ground truth'])
plt.tight_layout()
plt.show()

# Figure 2: predicted against ground-truth coordinate; the dash-dot identity line
# marks a perfect prediction
plt.figure(dpi = 150)
for coord_idx, coord_name in enumerate(('X', 'Y')):
    plt.subplot(1,2,coord_idx+1)
    plt.plot(target_points[:,coord_idx], pred_points[:,coord_idx], '*-')
    plt.plot(target_points[:,coord_idx], target_points[:,coord_idx], '-.')
    plt.ylabel('%s position predicted'%coord_name)
    plt.xlabel('%s position ground truth'%coord_name)
plt.tight_layout()
plt.show()


In [None]:
# Development helper: re-import the custom libraries after editing them on disk,
# without restarting the kernel. Not needed for a clean top-to-bottom run.
from importlib import reload
reload(lyl_ts)
reload(lyl_nn)