In [13]:
import tensorflow as tf
import numpy as np
import random

In [14]:
#Params
kld_reg= 1
adl_reg=1

fdim=16
zdim=16
sigma=1.3
past_length=8
future_length=12
data_scale=170
enc_past_size=(past_length*2,512,256,fdim)
enc_dest_size=(2,8,16,fdim)
enc_latent_size=(2*fdim,8,50,2*zdim)
dec_size=(fdim + zdim,1024,512,1024,2)
predictor_size=(2*fdim,1024,512,256,2*(future_length-1))
learning_rate=0.0003

In [15]:
def loadData(file_path: str):
  npz = np.load(file_path, allow_pickle=True)
  return npz['observations'], npz['obs_speed'], npz['targets'], npz[
      'target_speed'], npz['mean'], npz['std']

In [16]:
class Dense(tf.Module):
  def __init__(self, input_dim, output_size, name=None):
    super(Dense, self).__init__(name=name)
    self.w = tf.Variable(tf.random.uniform([input_dim, output_size],-(1.0/input_dim)**0.5,(1.0/input_dim)**0.5 ),name='w',dtype=tf.float32)
    self.b = tf.Variable(tf.random.uniform([output_size],-(1.0/input_dim)**0.5,(1.0/input_dim)**0.5 ), name='b',dtype=tf.float32)
  def __call__(self, x):
    x = tf.constant(x,dtype=tf.float32)
    y = tf.matmul(x, self.w) + self.b
    return tf.nn.relu(y)

class FullyConnectedNeuralNet(tf.Module):
  def __init__(self,sizes, name=None):
    super(FullyConnectedNeuralNet, self).__init__(name=name)
    self.layers = []
    with self.name_scope:
      for i in range(len(sizes)-1):
        self.layers.append(Dense(input_dim=sizes[i], output_size=sizes[i+1]))
  @tf.Module.with_name_scope
  def __call__(self, x):
    for layer in self.layers:
      x = layer(x)
    return x

In [17]:
class MainModel(tf.Module):
  def __init__(self,name=None):
    super(MainModel, self).__init__(name=name)

    self.zdim = zdim
    self.sigma = sigma

    self.pastEncoder = FullyConnectedNeuralNet(enc_past_size)

    self.destEncoder = FullyConnectedNeuralNet(enc_dest_size)

    self.latentDistributionEncoder = FullyConnectedNeuralNet(enc_latent_size)

    self.latentDistributionDecoder = FullyConnectedNeuralNet(dec_size)

    self.predictorNetwork = FullyConnectedNeuralNet(predictor_size)

  def forward(self, x, dest = []):

    if len(dest):
        self.training=True
    else:
        self.training=False
        
    # encode
    traj_past_ftr = self.pastEncoder(x)
    #print(f"ftraj max {ftraj.numpy().max()}")
    if not self.training:
        z = tf.random.normal((x.shape[0], self.zdim),0,self.sigma)

    else:
        dest_ftr = self.destEncoder(dest)
        #print(f"dest_features Max {dest_features.numpy().max()}")

        concat_ftr = tf.concat((traj_past_ftr, dest_ftr), axis = 1)
        latent =  self.latentDistributionEncoder(concat_ftr)
        mu = latent[:, 0:self.zdim] # 2-d array
        logvar = latent[:, self.zdim:] # 2-d array

        var = tf.math.exp(logvar*0.5)
        #print(f"var {var}")
        eps = tf.random.normal(var.shape)
        z = eps*var + mu
        #print(f"z -> {z}")


    latentDistributionDecoder_input = tf.concat((traj_past_ftr, z), axis = 1)
    generated_dest = self.latentDistributionDecoder(latentDistributionDecoder_input)
    
    if self.training:
        generated_dest_ftr = self.destEncoder(generated_dest)
        prediction_ftr = tf.concat((traj_past_ftr, generated_dest_ftr), axis = 1)
        pred_future = self.predictorNetwork(prediction_ftr)
        
        return (generated_dest, mu, logvar, pred_future)
    else:
        return generated_dest

  def predict(self, past, generated_dest):
        
    traj_past_ftr = self.pastEncoder(past)
    generated_dest_ftr = self.destEncoder(generated_dest)
    prediction_ftr = tf.concat((traj_past_ftr, generated_dest_ftr), axis = 1)
    future_traj = self.predictorNetwork(prediction_ftr)
    return future_traj

In [18]:
def calculate_loss(dest, dest_rec, mean, log_var, future, future_rec):
    
    rcl = tf.math.reduce_mean(tf.keras.metrics.mean_squared_error(dest, dest_rec))
    adl = tf.math.reduce_mean(tf.keras.metrics.mean_squared_error(future, future_rec))

    kld = -0.5 * tf.math.reduce_sum(1 + log_var - mean**2 - tf.math.exp(log_var))

    return rcl, kld, adl

In [19]:
def next_batch(X,batchSize):
    start = random.randint(0, len(X)-batchSize)
    return X[start:start+batchSize]

In [20]:
def train(trajx,model,optimizer):
    train_loss = 0
    total_rcl, total_kld, total_adl = 0, 0, 0
    
    traj = trajx - trajx[:, :1, :]
    traj *= data_scale

    x = traj[:, :past_length, :]
    y = traj[:, past_length:, :]

    x = x.reshape(-1, x.shape[1]*x.shape[2]) # (x,y,x,y ... )
    dest = y[:, -1, :]
    future = y[:, :-1, :].reshape(y.shape[0],-1)
                
    #x.astype(np.float64)
    #print(f"dest-> {trajx.shape}")

    with tf.GradientTape() as tape:
        x=tf.constant(x,dtype=tf.float32)
        tape.watch(x)
        dest_rec, mu, var, future_rec = model.forward(x, dest=dest)
        #print(f"dest_recon {dest_recon}")
        #print(f"mu {mu}")
        #print(f"var {var}")
        #print(f"interpolated_future {interpolated_future}")
                    
        rcl, kld, adl = calculate_loss(dest, dest_rec, mu, var, future, future_rec)
                    
        loss = rcl + kld * kld_reg + adl * adl_reg
        grad_sub = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grad_sub, model.trainable_variables))
                    
        #print(f"total Loss {loss}")
        #print(f"rcl Loss {rcl}")
        #print(f"kld Loss {kld}")
        #print(f"adl Loss {adl}")
    train_loss+=loss
    total_rcl+=rcl
    total_kld+=kld
    total_adl+=adl
    return train_loss, total_rcl, total_kld, total_adl
                

In [21]:
def test(trajx, model, best_of_n = 1):
    
    traj = trajx - trajx[:, :1, :]
    traj *= data_scale

    x = traj[:, :past_length, :]
    y = traj[:, past_length:, :]

    x = x.reshape(-1, x.shape[1]*x.shape[2])

    dest = y[:, -1, :]
    
    destination_errors = []
    dectination_recs = []
    
    for _ in range(best_of_n):
        x=tf.constant(x,dtype=tf.float32)
        dest_rec = model.forward(x)
        dectination_recs.append(np.array(dest_rec))

        error = np.linalg.norm(dest_rec - dest, axis = 1)
        destination_errors.append(error)

    destination_errors = np.array(destination_errors)
    dectination_recs = np.array(dectination_recs)
    # average error
    avg_dest_error = np.mean(destination_errors)

    indices = np.argmin(destination_errors, axis = 0)

    best_dest = dectination_recs[indices,np.arange(x.shape[0]),  :]

    # taking the minimum error out of all guess
    dest_error = np.mean(np.min(destination_errors, axis = 0))

    future_dest = model.predict(x, best_dest)
    # final overall prediction
    predicted_future = np.concatenate((future_dest, best_dest), axis = 1)
    predicted_future = np.reshape(predicted_future, (-1, future_length, 2))
    # ADE error
    overall_error = np.mean(np.linalg.norm(y - predicted_future, axis = 2))

    overall_error /= data_scale
    dest_error /= data_scale
    avg_dest_error /= data_scale
    #print('Test time error in destination best: {:0.3f} and mean: {:0.3f}'.format(dest_error, avg_dest_error))
    #print('Test time error overall (ADE) best: {:0.3f}'.format(overall_error))

    return overall_error, dest_error, avg_dest_error

            

In [22]:
def run_train():
    observations, _, targets, _, _, _ = loadData('./data/eth/eth_train.npz')
    train_dataset = np.concatenate([observations, targets], axis=1)
    observations, _, targets, _, _, _ = loadData('./data/eth/eth_test.npz')
    test_dataset = np.concatenate([observations, targets], axis=1)
    
    epochs = 8
    batchSize=100
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    model=MainModel()
    N=20
    best_test_loss = 50 # start saving after this threshold
    best_endpoint_loss = 50
    for epo in range(epochs):
            #print(f"Epoch : {epo+1}")
            for it in range(int(len(train_dataset)/batchSize)):
                trajx_train = next_batch(train_dataset,batchSize)
                trajx_test = next_batch(test_dataset,len(test_dataset))
                
                train_loss, rcl, kld, adl = train(trajx_train, model,optimizer)
                test_loss, final_point_loss_best, final_point_loss_avg = test(trajx_test, model, best_of_n = N)
                
                if best_test_loss > test_loss:
                    print("Epoch: ", epo+1)
                    print('################## BEST PERFORMANCE {:0.2f} ########'.format(test_loss))
                    best_test_loss = test_loss
                """"
                if best_test_loss < 10.25:
                    save_path = './content/trained.pt'
                """

                if final_point_loss_best < best_endpoint_loss:
                    best_endpoint_loss = final_point_loss_best

                print("Train Loss", train_loss)
                print("RCL", rcl)
                print("KLD", kld)
                print("ADL", adl)
                print("Test ADE", test_loss)
                print("Test Average FDE (Across  all samples)", final_point_loss_avg)
                print("Test Min FDE", final_point_loss_best)
                print("Test Best ADE Loss So Far (N = {})".format(N), best_test_loss)
                print("Test Best Min FDE (N = {})".format(N), best_endpoint_loss)

In [23]:
run_train()

dest-> (30307, 20, 2)
Epoch:  1
################## BEST PERFORMANCE 4.98 ########
Train Loss tf.Tensor(127799520.0, shape=(), dtype=float32)
RCL tf.Tensor(422538.7, shape=(), dtype=float32)
KLD tf.Tensor(127160770.0, shape=(), dtype=float32)
ADL tf.Tensor(216217.94, shape=(), dtype=float32)
Test ADE 4.983596108225233
Test Average FDE (Across  all samples) 6.510464298023897
Test Min FDE 6.510459271599265
Test Best ADE Loss So Far (N = 20) 4.983596108225233
Test Best Min FDE (N = 20) 6.510459271599265
dest-> (30307, 20, 2)
Epoch:  1
################## BEST PERFORMANCE 4.98 ########
Train Loss tf.Tensor(16692468.0, shape=(), dtype=float32)
RCL tf.Tensor(422502.12, shape=(), dtype=float32)
KLD tf.Tensor(16053777.0, shape=(), dtype=float32)
ADL tf.Tensor(216189.02, shape=(), dtype=float32)
Test ADE 4.982743053827522
Test Average FDE (Across  all samples) 6.510466452205883
Test Min FDE 6.510454963235294
Test Best ADE Loss So Far (N = 20) 4.982743053827522
Test Best Min FDE (N = 20) 6.5104549

KeyboardInterrupt: 