In [None]:
from __future__ import absolute_import, division, print_function

import base64
import imageio
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import pyvirtualdisplay
import simpy

import tensorflow as tf

from tf_agents.agents.dqn import dqn_agent
from tf_agents.agents.categorical_dqn import categorical_dqn_agent
from tf_agents.networks import q_network
from tf_agents.networks import categorical_q_network

from tf_agents.policies import policy_saver
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common
from tf_agents.trajectories import time_step as ts
from tf_agents.specs import tensor_spec
#from env.RideSimulator.Grid import Grid
import tf_agents


import os,sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from RideSimulator.taxi_sim import run_simulation


In [None]:
# Register the custom taxi environment with gym under the id 'taxi-v0'.
import gym

# Older gym releases raise gym.error.Error ("Cannot re-register id: ...") when
# this cell is executed a second time in the same kernel. Treat only that case
# as "already registered" so the cell can be re-run; anything else propagates.
try:
    gym.envs.register(
        id='taxi-v0',
        entry_point='env.taxi:TaxiEnv',
        max_episode_steps=1500,
        kwargs={'state_dict': None},
    )
except gym.error.Error as err:
    if 're-register' not in str(err):
        raise
    print('taxi-v0 is already registered; keeping the existing registration')

In [None]:
# Hyper-parameters shared by the cells below.

# Number of passes through the training loop (one collect + one train step each).
num_iterations = 30 # @param {type:"integer"}

# NOTE(review): initial_collect_steps is not referenced anywhere else in the visible notebook.
initial_collect_steps = 1000  # @param {type:"integer"}
# Passed to collect_data() as its episode count for each training iteration.
collect_steps_per_iteration = 1  # @param {type:"integer"}
# Capacity handed to TFUniformReplayBuffer(max_length=...).
replay_buffer_max_length = 100000  # @param {type:"integer"}

# Sample batch size used when the replay buffer is turned into a dataset.
batch_size = 64  # @param {type:"integer"}
# Learning rate of the Adam optimizer.
learning_rate = 1e-3  # @param {type:"number"}
# The training loss is printed every `log_interval` train steps.
log_interval = 10  # @param {type:"integer"}

# Number of simulated episodes averaged by each compute_avg_return() call.
num_eval_episodes = 2  # @param {type:"integer"}
# The policy is evaluated and saved every `eval_interval` train steps.
eval_interval = 5  # @param {type:"integer"}

In [None]:
# Load the registered taxi environment and wrap it for TF-Agents.
env_name = "taxi-v0"
# suite_gym.load looks the id up in gym's registry, so the registration cell
# above must have been run first.
env = suite_gym.load(env_name)

# TF wrapper around the Python environment; its specs are used to build the
# networks, agents and replay buffers in the following cells.
tf_env = tf_py_environment.TFPyEnvironment(env)
# First time step of the wrapped environment (kept for ad-hoc inspection).
reset = tf_env.reset()


In [None]:
# Plain DQN agent, its policies, a replay buffer and a policy saver.
# NOTE: the next cell (categorical DQN) rebinds `agent`, `eval_policy`,
# `collect_policy` and `replay_buffer`, and reuses `fc_layer_params`,
# `optimizer` and `train_step_counter` defined here.

# A single hidden fully-connected layer with 100 units.
fc_layer_params = (100,)


q_net = q_network.QNetwork(
    tf_env.observation_spec(),
    tf_env.action_spec(),
    fc_layer_params=fc_layer_params)

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

# Counter handed to the agent; the training loop reads it back as the step number.
train_step_counter = tf.Variable(0)

agent = dqn_agent.DqnAgent(
    tf_env.time_step_spec(),
    tf_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)

agent.initialize()


# Uniform-random policy over the same specs.
# NOTE(review): random_policy is not used by any other visible cell.
random_policy = random_tf_policy.RandomTFPolicy(tf_env.time_step_spec(),tf_env.action_spec())

# Policies exposed by the agent: one for evaluation, one for data collection.
eval_policy = agent.policy
collect_policy = agent.collect_policy

# Replay buffer laid out for the trajectories this agent collects.
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=tf_env.batch_size,
    max_length=replay_buffer_max_length)

# The saver is built from this cell's DQN eval_policy.
saver = policy_saver.PolicySaver(eval_policy, batch_size=None)


In [None]:
# Categorical DQN agent. This cell replaces the `agent`, `eval_policy`,
# `collect_policy` and `replay_buffer` created by the plain DQN cell, and it
# reuses that cell's `fc_layer_params`, `optimizer` and `train_step_counter`.
gamma = 0.99
num_atoms = 51  # @param {type:"integer"}
# NOTE(review): the results recorded near the bottom of this notebook are in
# the thousands; confirm that a [-20, 20] value support suits the reward scale.
min_q_value = -20  # @param {type:"integer"}
max_q_value = 20  # @param {type:"integer"}
# The dataset cell samples n_step_update + 1 consecutive steps per item.
n_step_update = 2  # @param {type:"integer"}
categorical_q_net = categorical_q_network.CategoricalQNetwork(
    tf_env.observation_spec(),
    tf_env.action_spec(),
    num_atoms=num_atoms,
    fc_layer_params=fc_layer_params)

agent = categorical_dqn_agent.CategoricalDqnAgent(
    tf_env.time_step_spec(),
    tf_env.action_spec(),
    categorical_q_network=categorical_q_net,
    optimizer=optimizer,
    min_q_value=min_q_value,
    max_q_value=max_q_value,
    n_step_update=n_step_update,
    td_errors_loss_fn=common.element_wise_squared_loss,
    gamma=gamma,
    train_step_counter=train_step_counter)
agent.initialize()

# Policies of the categorical agent (evaluation and data collection).
eval_policy = agent.policy
collect_policy = agent.collect_policy

# Replay buffer laid out for the categorical agent's trajectories.
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=tf_env.batch_size,
    max_length=replay_buffer_max_length)

# The saver created in the DQN cell wraps the DQN agent's policy. Rebuild it
# here so that saver.save(...) in the training loop exports the policy of the
# agent that is actually being trained.
saver = policy_saver.PolicySaver(eval_policy, batch_size=None)

In [None]:
# Create the training dataset and its iterator from the replay buffer.
# Each sampled item holds `batch_size` sequences of n_step_update + 1
# consecutive steps, i.e. shape [B x (n_step_update + 1) x ...].
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, 
    sample_batch_size=batch_size, 
    num_steps=n_step_update+1).prefetch(3)

# The training loop pulls one batch per iteration with next(iterator).
iterator = iter(dataset)
print(iterator)

In [None]:
"""
policy.action(reset)
#tf_env.time_step_spec()
print(reset)
#print(env.reset())
#print(ts.restart(tf.convert_to_tensor(np.array([0,0,0,0], dtype=np.int32), dtype=tf.float32)))
print(" ")
print(ts.TimeStep(tf.constant([0]), tf.constant([0.0]), tf.constant([1.0]),tf.convert_to_tensor(np.array([[0,0,0,0]], dtype=np.int32), dtype=tf.float32)))

#print(tensor_spec.to_array_spec(reset))
#encoder_func = tf_agents.utils.example_encoding.get_example_encoder(env.reset())
#encoder_func(env.reset())
"""

#run_simulation(policy)
#ts.termination(np.array([1,2,3,4], dtype=np.int32), reward=0.0)
#ts.transition(np.array([1,2,3,4], dtype=np.int32), reward=0.0, discount=1.0)

In [None]:
# Build a fixed ("static") set of recorded episodes for evaluation purposes.

class AcceptPolicy:
  """Baseline policy that accepts every request, whatever the observation."""

  def __init__(self):
    print("init")

  def action(self, obs):
    """Return action 1 as a one-element tensor; `obs` is ignored."""
    always_accept = tf.constant([1])
    return always_accept


acceptPol = AcceptPolicy()

# One simulation run driven by the always-accept baseline. The result is
# iterated by evaluatePolicy() as a sequence of per-driver state lists.
eval_env = run_simulation(acceptPol)
#print(eval_env)

In [None]:
# Evaluate a policy against a pre-generated, static set of recorded episodes.
def evaluatePolicy(policy, eval_env):
    """Replay recorded states through `policy` and print the reward it keeps.

    Args:
        policy: object whose ``action(time_step)`` result can be indexed with
            ``[0]`` to obtain a tensor; a value of 1 counts as "accept".
        eval_env: iterable of per-driver state lists; every state is a mapping
            with "observation" and "reward" entries.

    For each state the recorded reward is counted only when the policy answers
    1, otherwise 0 is counted. Per-step, per-driver and total rewards are
    printed; nothing is returned.
    """
    episode_reward = 0
    for state_list in eval_env:
        driver_reward = 0

        for recorded in state_list:
            # MID-type time step (step_type 1, discount 1.0) for the recorded state.
            state_tf = ts.TimeStep(
                tf.constant([1]),
                tf.constant(recorded["reward"], dtype=tf.float32),
                tf.constant([1.0]),
                tf.convert_to_tensor(
                    np.array([recorded["observation"]], dtype=np.float32),
                    dtype=tf.float32))
            action = policy.action(state_tf)
            accepted = action[0].numpy() == 1
            reward = recorded["reward"] if accepted else 0
            print (reward)
            driver_reward += reward

        episode_reward += driver_reward
        print("driver reward ", driver_reward)
    print("total reward ", episode_reward)

evaluatePolicy(acceptPol, eval_env)

In [None]:
# Compute the average return of a policy over several simulated episodes.
def compute_avg_return(policy, num_episodes=10):
    """Average per-driver reward of `policy` over `num_episodes` simulations.

    Each episode calls ``run_simulation(policy)``, sums the recorded "reward"
    of every state for every driver, and divides by the number of drivers.
    The result is the mean of those per-episode values.

    Args:
        policy: policy object handed to ``run_simulation``.
        num_episodes: number of simulation runs to average; must be positive.

    Returns:
        The average return as a NumPy float32 value.

    Raises:
        ValueError: if ``num_episodes`` is not positive.
    """
    if num_episodes <= 0:
        raise ValueError("num_episodes must be a positive integer")

    # Accumulate in a float32 tensor from the start so that `.numpy()` below
    # also works when a run records no states at all.
    total_reward = tf.constant(0.0, dtype=tf.float32)

    for _ in range(num_episodes):
        # Run one episode of simulation and record states.
        state_lists = run_simulation(policy)
        if len(state_lists) == 0:
            # No drivers were recorded: there is nothing to average, so the
            # episode contributes zero instead of dividing by zero.
            print("compute_avg_return: episode produced no driver states")
            continue

        episode_reward = tf.constant(0.0, dtype=tf.float32)
        for state_list in state_lists:
            driver_reward = tf.constant(0.0, dtype=tf.float32)
            # Only the recorded reward is needed here, so no TimeStep is built.
            for recorded in state_list:
                driver_reward += tf.constant(recorded["reward"], dtype=tf.float32)
            episode_reward += driver_reward

        # Take the average reward over all drivers in the episode.
        total_reward += episode_reward / len(state_lists)

    avg_return = total_reward / num_episodes
    print(avg_return)
    return avg_return.numpy()


In [None]:
# Collect trajectories from simulation runs into the replay buffer.

def collect_data(num_iterations, policy, replay_buffer):
    """Run the simulator and store every recorded transition.

    Args:
        num_iterations: number of simulation episodes to run.
        policy: policy handed to ``run_simulation``.
        replay_buffer: buffer that receives one trajectory per transition
            through ``add_batch``.

    Every recorded state of a driver becomes a TimeStep: the first one is a
    FIRST step (step_type 0), the last one a LAST step (step_type 2, discount
    0.0), all others MID steps (step_type 1). Apart from the first step, the
    reward attached to a step is the reward recorded for the previous state,
    i.e. the outcome of the action that led to it.
    """
    for _ in range(num_iterations):
        # Run one episode of simulation and record states.
        state_lists = run_simulation(policy)
        print("driver count : ", len(state_lists))

        for state_list in state_lists:
            final_idx = len(state_list) - 1
            states = []
            actions = []

            for idx, recorded in enumerate(state_list):
                if idx == 0:
                    step_type = tf.constant([0])
                    # Fixed reward on the first step.
                    # NOTE(review): presumably a placeholder, since a transition
                    # normally takes its reward from the following step - confirm.
                    reward = tf.constant([3.0])
                    discount = tf.constant([1.0])
                else:
                    is_final = idx >= final_idx
                    step_type = tf.constant([2] if is_final else [1])
                    # Reward of the already completed (previous) action.
                    reward = tf.constant(state_list[idx - 1]["reward"], dtype=tf.float32)
                    discount = tf.constant([0.0] if is_final else [1.0])

                observation = tf.convert_to_tensor(
                    np.array([recorded["observation"]], dtype=np.float32),
                    dtype=tf.float32)
                states.append(ts.TimeStep(step_type, reward, discount, observation))
                actions.append(recorded["action"])

            # Pair every step with its successor and add the resulting
            # trajectory to the replay buffer.
            for present_state, action, next_state in zip(states, actions, states[1:]):
                traj = trajectory.from_transition(present_state, action, next_state)
                replay_buffer.add_batch(traj)
#collect_data(num_iterations, policy, replay_buffer)

In [None]:
# Train the agent.
import time

# Wall-clock timing for the whole cell. The earlier `%%time` sat inside a
# try/except: a cell magic is only valid as the first line of a cell, so it
# never timed anything and the bare `except` merely hid the resulting error.
train_start = time.perf_counter()

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)

# Reset the train step
agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_policy, num_eval_episodes)
print(' Average Return = {0}'.format( avg_return))
returns = [avg_return]
# Iterations abandoned because the simulation raised IndexError.
lost_iterations = 0
for _ in range(num_iterations):
    try:
        # Collect a few steps using collect_policy and save to the replay buffer.
        collect_data(collect_steps_per_iteration, collect_policy, replay_buffer)

        # Sample a batch of data from the buffer and update the agent's network.
        experience, unused_info = next(iterator)
        train_loss = agent.train(experience)

        step = agent.train_step_counter.numpy()

        if step % log_interval == 0:
            print('step = {0}: loss = {1}'.format(step, train_loss))

        if step % eval_interval == 0:
            avg_return = compute_avg_return(eval_policy, num_eval_episodes)
            print('step = {0}: Average Return = {1}'.format(step, avg_return))
            returns.append(avg_return)
            print("evaluation")
            saver.save('policy_%d' % step)

    except IndexError:
        # Only IndexError is treated as a recoverable simulator failure; any
        # other exception still stops the cell.
        lost_iterations += 1
        print("skipping iteration due to driver error")

print('training finished in {0:.1f}s, {1} iteration(s) skipped'.format(
    time.perf_counter() - train_start, lost_iterations))

In [None]:
# Visualize training progress.
# Derive the x axis from the number of recorded returns instead of from
# num_iterations: when training iterations were skipped, fewer evaluations
# ran and a range built from num_iterations no longer matches len(returns),
# which makes plt.plot raise. returns[0] is the pre-training evaluation
# (step 0); every later entry was recorded at the next multiple of
# eval_interval.
iterations = [k * eval_interval for k in range(len(returns))]
plt.plot(iterations, returns)
plt.ylabel('Average Return')
plt.xlabel('Iterations')
#plt.ylim(top=50000)

In [None]:
# Score the agent's current eval_policy on the static episodes recorded with
# the always-accept baseline (eval_env).
#run_simulation(eval_policy)
evaluatePolicy(eval_policy, eval_env)

In [None]:
# Evaluate a policy loaded from disk against a pre-generated static environment.
def evaluateSavedPolicy(policy, policy_state, eval_env):
    """Replay recorded states through a saved policy and print kept rewards.

    Same procedure as evaluatePolicy, except that `policy_state` is passed as
    the second argument of every ``policy.action`` call.

    Args:
        policy: loaded policy exposing ``action(time_step, policy_state)``.
        policy_state: state object forwarded unchanged on every call.
        eval_env: iterable of per-driver state lists; every state is a mapping
            with "observation" and "reward" entries.

    Per-step, per-driver and total rewards are printed; nothing is returned.
    """
    episode_reward = 0
    for state_list in eval_env:
        driver_reward = 0

        for recorded in state_list:
            state_tf = ts.TimeStep(
                tf.constant([1]),
                tf.constant(recorded["reward"], dtype=tf.float32),
                tf.constant([1.0]),
                tf.convert_to_tensor(
                    np.array([recorded["observation"]], dtype=np.float32),
                    dtype=tf.float32))
            action = policy.action(state_tf, policy_state)
            # The recorded reward only counts when the policy answers 1.
            accepted = action[0].numpy() == 1
            reward = recorded["reward"] if accepted else 0
            print (reward)
            driver_reward += reward

        episode_reward += driver_reward
        print("driver reward ", driver_reward)
    print("total reward ", episode_reward)


In [None]:
# Load a previously saved policy and query it with one hand-written observation.
# NOTE(review): the training loop saves to 'policy_<step>' in the working
# directory, while this loads from 'pol/policy_10' - confirm the saved model
# was moved there.
saved_policy = tf.compat.v2.saved_model.load('pol/policy_10')
# NOTE(review): the initial state is requested for batch_size=3 although the
# time step built below holds a single observation - confirm this is intended.
policy_state = saved_policy.get_initial_state(batch_size=3)
"""time_step = ...
while True:
  policy_step = saved_policy.action(time_step, policy_state)
  policy_state = policy_step.state
  time_step = f(policy_step.action)
"""
# Four-value observation; the dataset evaluation cell feeds the policy the
# columns distance_to_pickup, trip_distance, day_time, accepted_trip_count.
# NOTE(review): presumably the same order applies here - confirm.
observations = [8, 10, 0, 35]
#observation_ts = ts.transition(np.array(observations, dtype=np.float32), reward=0.0, discount=1.0)
# MID-type time step (step_type 1) with zero reward and discount 1.0.
observation_ts = ts.TimeStep(tf.constant([1]), tf.constant([0.0]), tf.constant([1.0]),
                                tf.convert_to_tensor(np.array([observations], dtype=np.float32), dtype=tf.float32))
action = saved_policy.action(observation_ts, policy_state)
print(action)

In [None]:
# Average return of the loaded policy over num_eval_episodes simulation runs.
# This rebinds the global avg_return that the training cell also assigns.
avg_return = compute_avg_return(saved_policy, num_eval_episodes)

In [None]:
# The saved-policy evaluation is disabled; this scores the in-memory
# eval_policy on the static episodes instead.
#evaluateSavedPolicy(saved_policy, policy_state, eval_env)
evaluatePolicy(eval_policy, eval_env)

In [16]:
# Compare the policy's decisions with the actions logged in the pickme dataset.
import pandas as pd
week_6 = pd.read_csv("Eval_data.csv")

# Observation columns, in the order the policy receives them, and the logged label.
feature_cols = ['distance_to_pickup', 'trip_distance', 'day_time', 'accepted_trip_count']
# Evaluate at most 5000 rows, never more than the file contains.
num = min(5000, len(week_6))

# Pull the rows out once as arrays instead of slicing the frame inside the loop.
eval_rows = week_6.head(num)
eval_observations = eval_rows[feature_cols].to_numpy(dtype=np.float32)
logged_actions = eval_rows['action'].to_numpy()

tot = 0         # decisions that match the logged action
tot_accept = 0  # decisions equal to 1 (accept); starts at 0 so the rate is not inflated
for i in range(num):
    # MID-type time step with zero reward and discount 1.0 around one observation row.
    observation_ts = ts.TimeStep(tf.constant([1]), tf.constant([0.0]), tf.constant([1.0]), tf.convert_to_tensor(eval_observations[i:i + 1], dtype=tf.float32))
    policy_step = eval_policy.action(observation_ts)
    # Kept from the original cell: rebinds the global policy_state.
    policy_state = policy_step.state
    chosen_action = policy_step.action.numpy()[0]
    if chosen_action == 1:
        tot_accept += 1
    if chosen_action == logged_actions[i]:
        tot += 1

if num:
    print(f'Accuracy: {tot/num * 100}%')
    print(f'accept freq: {tot_accept/num * 100}%')
else:
    print('Eval_data.csv contains no rows to evaluate')

[ 4.368 13.28  16.     0.   ]
1
[ 8.2  26.44 14.    4.  ]
1
[ 0.788 18.44  15.     5.   ]
1
[ 2.08 42.56 17.    7.  ]
1
[ 8.292 21.28  18.     9.   ]
1
[ 0.728  4.04  18.    10.   ]
0
[ 5.644 12.56  19.    13.   ]
0
[ 4.18 21.28  8.   20.  ]
1
[4.000e-03 1.532e+01 1.100e+01 2.100e+01]
1
[ 2.048 55.4   12.    22.   ]
1
[  3.648 174.08   14.     23.   ]
1
[ 6.072 18.48  14.    23.   ]
1
[ 3.344 21.32  17.    25.   ]
1
[8.00e-03 8.68e+00 1.80e+01 2.50e+01]
0
[ 1.38 17.96 18.   25.  ]
1
[ 0.196 64.36  11.    32.   ]
1
[ 1.392 41.72  15.    34.   ]
1
[ 5.28 94.6  15.   35.  ]
1
[ 6.332 25.24  15.    35.   ]
1
[ 1.424 68.6   15.    35.   ]
1
[ 2.044 81.68  16.    35.   ]
1
[13.588 46.28  17.    36.   ]
1
[ 6.124 54.68  18.    37.   ]
1
[ 3.48 12.48  8.   42.  ]
0
[ 0.   69.08  9.   45.  ]
1
[ 5.464 34.32  11.    49.   ]
1
[ 2.72 10.92 15.   53.  ]
0
[ 7.752 29.08  16.    55.   ]
1
[ 3.452 21.76  17.    57.   ]
0
[ 0.   61.88 18.   58.  ]
1
[ 6.26 11.64  9.   63.  ]
0
[ 0.5  54.76 10.   64.  

0
[ 4.084 69.88  15.    88.   ]
1
[ 3.344 19.28  16.    89.   ]
0
[ 3.288 23.12  17.    92.   ]
0
[ 2.228 17.92  17.    93.   ]
0
[ 2.328 45.28  18.    94.   ]
1
[4.00e-02 2.66e+01 1.80e+01 9.50e+01]
0
[ 1.268 25.04  18.    96.   ]
0
[ 6.956  8.44  18.    96.   ]
0
[ 1.584 10.    18.    96.   ]
0
[6.000e-02 2.124e+01 1.800e+01 9.600e+01]
0
[  5.18  23.6   12.   106.  ]
0
[  7.084  62.68   13.    108.   ]
1
[  2.792  25.64   14.    109.   ]
0
[  4.34  84.    16.   110.  ]
1
[  4.736  22.68   20.    115.   ]
0
[ 4.18 24.2   7.    0.  ]
1
[0.444 3.04  7.    1.   ]
0
[ 2.268 13.32  10.     6.   ]
1
[1.600e-02 5.888e+01 7.000e+00 1.600e+01]
1
[ 1.592 44.84   8.    17.   ]
1
[ 8.408 92.88   9.    20.   ]
1
[  3.016 121.8    11.     25.   ]
1
[ 1.036 42.04  12.    28.   ]
1
[ 2.524 24.32  13.    29.   ]
1
[ 3.844 27.56   7.    30.   ]
1
[ 2.044 31.44   8.    32.   ]
1
[ 1.608 47.68  10.    37.   ]
1
[ 2.288 25.76  15.    42.   ]
1
[15.064 16.56   7.    44.   ]
0
[ 5.94 11.32  8.   47.  ]
0
[ 

1
[ 3.076 56.16  15.    28.   ]
1
[ 2.544 85.44  15.    28.   ]
1
[ 2.408 38.24  17.    29.   ]
1
[ 1.344 15.92  17.    29.   ]
0
[ 3.512 14.12  17.    29.   ]
0
[ 3.552 60.24  19.    30.   ]
1
[ 0.576 34.88  19.    30.   ]
1
[ 3.724  6.32  19.    30.   ]
0
[ 4.544 41.68  19.    30.   ]
1
[ 0.972 26.52  19.    30.   ]
1
[ 4.584  8.72  19.    30.   ]
0
[ 3.616 23.2   19.    30.   ]
1
[ 0.744 34.36  19.    30.   ]
1
[11.108 17.4   19.    30.   ]
0
[ 7.948 14.16  19.    30.   ]
0
[ 3.756  9.44  19.    30.   ]
0
[13.76 14.96 20.   32.  ]
0
[ 7.704 32.4   21.    33.   ]
1
[  5.532 261.12   17.     39.   ]
1
[ 5.624 13.08  17.    39.   ]
0
[ 4.712 13.24  18.    40.   ]
0
[ 4.18 27.76 18.   40.  ]
1
[ 2.956 19.28  18.    40.   ]
0
[ 3.232 64.    18.    40.   ]
1
[ 2.764 76.16  21.    42.   ]
1
[ 0.392 16.12  22.    44.   ]
0
[ 2.992 20.92  22.    44.   ]
0
[ 6.268 17.48  22.    44.   ]
0
[ 0.136 18.84  22.    44.   ]
0
[ 8.108 13.16  23.    44.   ]
0
[ 0.368 11.92  23.    44.   ]
0
[ 2.56  4.

[ 0.364 20.36  19.    12.   ]
1
[ 5.696 10.28  19.    13.   ]
0
[ 2.676 18.68  20.    14.   ]
1
[ 5.34 53.36 22.   16.  ]
1
[ 2.812  4.24  18.    18.   ]
0
[ 0.764 13.68  18.    19.   ]
0
[ 4.276 46.24  20.    20.   ]
1
[ 0.836  6.44  21.    23.   ]
0
[21.26 19.08 21.   24.  ]
0
[ 0.088 50.28  22.    27.   ]
1
[ 3.968 21.8    8.    30.   ]
1
[ 3.864 33.    12.    33.   ]
1
[ 2.74 82.16 13.   34.  ]
1
[ 0.176  3.64  14.    35.   ]
0
[ 0.928  7.68  15.    35.   ]
0
[ 1.388 19.56  19.    42.   ]
0
[ 3.352 11.32  19.    43.   ]
0
[16.768 20.68  21.    46.   ]
0
[ 0.528 36.48  22.    47.   ]
1
[ 0.728 45.36   7.    49.   ]
1
[ 8.936 47.76  18.    58.   ]
1
[ 2.772 15.24  19.    59.   ]
0
[ 1.48 11.28 19.   60.  ]
0
[2.800e-02 7.524e+01 1.900e+01 6.100e+01]
1
[ 5.412  5.36   8.    63.   ]
0
[ 3.008 18.64  10.    66.   ]
0
[15.52 13.44 16.   70.  ]
0
[ 1.184 19.12  18.    72.   ]
0
[ 6.   53.12 19.   73.  ]
1
[ 3.324 22.88   9.    75.   ]
0
[ 5.268 28.88  11.    77.   ]
0
[ 2.672 32.76  14.  

0
[  0.284  61.08   18.    105.   ]
1
[  3.796   6.8    19.    106.   ]
0
[  4.48  20.28  20.   107.  ]
0
[  1.896   8.32   20.    108.   ]
0
[ 7.056  5.8   13.     5.   ]
0
[ 2.08 30.44 15.    6.  ]
1
[ 5.892 52.52  16.     7.   ]
1
[ 4.   70.56 16.    7.  ]
1
[ 6.508 13.12  15.    13.   ]
0
[ 0.992 54.16  17.    17.   ]
1
[ 1.412 38.16  18.    18.   ]
1
[ 6.168 18.36  18.    19.   ]
0
[17.576 24.64  22.    24.   ]
1
[ 3.708 28.4   11.    30.   ]
1
[ 0.272 20.08  16.    36.   ]
1
[ 1.6  11.28 18.   38.  ]
0
[ 1.492 53.68  18.    39.   ]
1
[ 3.812 23.96  18.    39.   ]
1
[ 2.304 23.36  19.    40.   ]
1
[ 2.244 13.24  21.    42.   ]
0
[ 6.66 33.68  7.   44.  ]
1
[ 4.928 26.72   8.    48.   ]
1
[ 3.192 27.12  10.    51.   ]
1
[ 1.164 16.52  13.    54.   ]
0
[ 4.464  6.4   13.    55.   ]
0
[ 3.688  8.8   13.    56.   ]
0
[ 5.476  7.72  13.    56.   ]
0
[ 0.608 33.2   13.    57.   ]
1
[ 6.972 16.4   15.    58.   ]
0
[  8.78 114.12  16.    60.  ]
1
[ 1.32 60.48  9.   67.  ]
1
[ 4.556 16.76 

1
[ 7.692 32.64   8.     4.   ]
1
[ 15.716 103.32   16.     14.   ]
1
[ 8.192  8.2   16.    14.   ]
0
[ 1.56 25.92 18.   18.  ]
1
[ 1.648 14.88  18.    18.   ]
0
[ 2.836 50.16  18.    18.   ]
1
[  0.772 149.6    16.      7.   ]
1
[ 0.6   4.28  9.   10.  ]
0
[ 3.152 49.04  18.    12.   ]
1
[ 8.028 70.8   18.    12.   ]
1
[ 5.4  8.4 14.  13. ]
0
[0.488 4.8   8.    0.   ]
1
[ 8.648 24.72   8.     0.   ]
1
[11.784  7.84  10.     5.   ]
0
[ 1.064 55.44  11.     8.   ]
1
[ 8.844 20.92  14.    11.   ]
1
[ 1.204 13.36  15.    12.   ]
0
[ 3.416 41.52  16.    13.   ]
1
[ 3.72 60.12 16.   14.  ]
1
[ 7.68 27.36 20.   19.  ]
1
[ 5.088  0.28  21.    19.   ]
0
[ 4.432 37.92  23.    22.   ]
1
[ 0.54  6.2   2.   26.  ]
0
[ 3.352 42.04   6.    26.   ]
1
[1.6000e-02 2.0368e+02 6.0000e+00 2.6000e+01]
1
[ 4.808 18.6    6.    26.   ]
1
[ 4.392 67.44   8.    28.   ]
1
[ 1.396  8.28   8.    29.   ]
0
[ 1.184 97.32   8.    29.   ]
1
[ 6.212 48.6    8.    29.   ]
1
[ 1.22 23.12  9.   30.  ]
1
[ 1.664 28.76  12.

1
[ 3.664 37.96  12.     6.   ]
1
[ 5.268 78.72  12.     6.   ]
1
[ 3.26 31.28  7.    6.  ]
1
[8.064 7.92  8.    7.   ]
0
[ 3.396 36.     8.     7.   ]
1
[ 4.456 13.84   8.     7.   ]
1
[ 2.928 11.28   8.     7.   ]
1
[ 6.416 26.56   9.     7.   ]
1
[ 1.704 13.    10.     8.   ]
1
[ 0.96 21.28 12.    9.  ]
1
[ 0.28  4.72 12.    9.  ]
0
[ 1.728 10.28  15.     9.   ]
0
[ 0.032 13.16  15.     9.   ]
1
[ 3.112 15.8   15.     9.   ]
1
[ 3.532  2.16  15.     9.   ]
0
[20.292 13.    15.     9.   ]
0
[ 1.72 12.6  15.    9.  ]
1
[ 0.764 61.    15.     9.   ]
1
[ 4.968 31.68  17.     9.   ]
1
[ 5.124 61.08  18.    10.   ]
1
[ 9.968 41.68  19.    11.   ]
1
[8.000e-03 3.076e+02 7.000e+00 1.100e+01]
1
[ 0.032 13.64   7.    11.   ]
1
[ 7.596 13.12   7.    11.   ]
0
[ 1.68 46.48  7.   11.  ]
1
[ 3.104  3.72  15.    15.   ]
0
[ 0.14 84.44 17.   16.  ]
1
[ 2.732 50.24  17.    17.   ]
1
[ 0.94  2.52 18.   18.  ]
0
[ 1.192 29.68   9.    19.   ]
1
[  4.336 135.44    9.     19.   ]
1
[ 1.076 14.2   10.    

1
[ 1.628 93.16  14.    15.   ]
1
[  7.216 172.84   14.     15.   ]
1
[ 1.748 15.76  14.    15.   ]
1
[ 1.796 22.24  15.    15.   ]
1
[8.000e-03 1.152e+01 1.600e+01 1.700e+01]
0
[ 3.288 38.6   11.    21.   ]
1
[ 2.976 10.76  10.     1.   ]
1
[ 3.2  16.64 11.    3.  ]
1
[ 6.864 35.48  21.     4.   ]
1
[ 3.64  8.36 19.    5.  ]
0
[ 1.752 34.08  14.     6.   ]
1
[ 1.896 10.08  19.     1.   ]
1
[ 4.74 17.92 18.    3.  ]
1
[ 4.452 47.    19.     4.   ]
1
[ 1.608 27.68  19.     5.   ]
1
[ 0.884  6.28  19.    12.   ]
0
[ 2.664 77.12  20.    15.   ]
1
[ 3.284  5.64  21.    17.   ]
0
[ 4.848 16.2   22.    19.   ]
0
[ 5.524 18.76  21.    25.   ]
0
[36.732 63.04  21.    27.   ]
1
[13.568 72.36  22.    28.   ]
1
[ 5.148 21.32   0.    30.   ]
0
[ 7.356 48.    22.    36.   ]
1
[ 5.556 21.12  23.    38.   ]
0
[ 2.   28.92 19.   39.  ]
1
[ 5.156 27.32  21.    40.   ]
1
[ 2.048 11.92  21.    41.   ]
0
[ 7.152  4.44  22.    43.   ]
0
[ 4.124  4.08  23.    45.   ]
0
[ 3.104 10.48  16.     2.   ]
1
[ 19.7

0
[ 4.748  4.28  18.     4.   ]
0
[ 3.304 34.72  20.     4.   ]
1
[ 2.808 10.08  20.     4.   ]
1
[ 4.816 82.64  21.     6.   ]
1
[ 6.896  3.76  22.     9.   ]
0
[ 2.152 27.    15.    10.   ]
1
[ 3.892 45.36  16.    10.   ]
1
[ 1.096 27.64  18.    11.   ]
1
[10.188 12.72  18.    11.   ]
0
[21.932 30.8   18.    11.   ]
1
[32.772  7.36  18.    11.   ]
0
[ 1.28  8.92 18.   11.  ]
0
[ 3.82  7.92 18.   11.  ]
0
[ 2.484  3.4   18.    11.   ]
0
[ 4.572 47.12  20.    13.   ]
1
[ 7.292  8.04  21.    13.   ]
0
[ 2.596 30.88  21.    13.   ]
1
[ 2.644 25.36  23.    15.   ]
1
[ 1.78 30.84 16.   16.  ]
1
[13.   67.16 16.   16.  ]
1
[ 5.032 38.12  16.    17.   ]
1
[ 4.532 71.72  17.    18.   ]
1
[ 2.636 42.88  20.    18.   ]
1
[ 7.268  8.44  20.    18.   ]
0
[ 0.804 50.32  21.    19.   ]
1
[ 2.988 24.    22.    20.   ]
1
[ 3.4  14.04 22.   20.  ]
0
[ 1.74 29.88 17.   24.  ]
1
[ 6.76  8.44 18.   25.  ]
0
[10.232 42.36  18.    25.   ]
1
[ 3.624 81.24  18.    25.   ]
1
[ 3.504 21.6   19.    26.   ]
1
[ 

0
[ 3.944 12.04  15.    50.   ]
0
[ 6.12  6.44 16.   52.  ]
0
[ 7.032 16.32  16.    53.   ]
0
[ 1.192 27.12  17.    55.   ]
1
[ 0.072 27.44  18.    59.   ]
1
[ 8.1  3.2 18.  60. ]
0
[ 3.376 27.88  19.    61.   ]
1
[ 6.392 12.32  19.    61.   ]
0
[ 1.544  8.32  19.    62.   ]
0
[ 8.94 11.44 19.   62.  ]
0
[ 1.024 19.88  20.    64.   ]
0
[ 8.6  53.68 20.   64.  ]
1
[ 1.76 10.68 10.   64.  ]
0
[ 1.752 28.04  11.    64.   ]
1
[ 7.556 72.88   0.     0.   ]
1
[ 3.2  41.32  0.    1.  ]
1
[ 1.988 13.24   2.     4.   ]
1
[ 1.38 39.08 12.    4.  ]
1
[ 1.724 24.88  13.     5.   ]
1
[ 0.888 62.72  14.     6.   ]
1
[ 6.264  9.92  14.     6.   ]
0
[ 8.248 14.12  14.     6.   ]
1
[ 4.168 64.    15.     7.   ]
1
[ 4.988 35.84  17.     8.   ]
1
[ 3.684 12.68  17.     9.   ]
1
[ 2.46 29.32 11.   17.  ]
1
[ 7.916 28.16  12.    19.   ]
1
[ 0.752 15.    12.    20.   ]
1
[ 3.016 58.84  13.    21.   ]
1
[ 5.568 24.48  14.    23.   ]
1
[ 2.56 17.44 15.   26.  ]
1
[ 1.212 38.16  16.    27.   ]
1
[ 1.04 19.44 1

0
[1.60e-02 4.52e+00 7.00e+00 4.80e+01]
0
[ 5.844 71.28  10.    51.   ]
1
[ 5.428 17.2   12.    55.   ]
0
[ 1.184 45.    14.    57.   ]
1
[ 2.972 26.44  17.    58.   ]
1
[ 1.844 11.52  17.    58.   ]
0
[ 0.864 16.64  17.    59.   ]
0
[ 2.244 24.76  15.    67.   ]
0
[ 5.428 13.28  18.    75.   ]
0
[ 2.772 16.2   19.    76.   ]
0
[ 2.292 28.2   20.    77.   ]
1
[ 0.872  8.76  21.    80.   ]
0
[ 8.684 16.12  23.    83.   ]
0
[ 4.156 10.4   12.     4.   ]
1
[ 5.768 18.28  16.     5.   ]
1
[ 3.996 40.84  17.     6.   ]
1
[ 1.236 38.     6.    11.   ]
1
[ 0.164 29.76   6.    11.   ]
1
[ 8.008  9.28   7.    12.   ]
0
[ 2.448 44.64   7.    12.   ]
1
[ 3.788 22.44   7.    12.   ]
1
[ 4.468 80.88   9.    13.   ]
1
[ 9.592 12.72  13.    19.   ]
0
[ 0.052 32.08  17.    21.   ]
1
[ 0.996 33.44  18.    22.   ]
1
[ 2.28 40.24 19.   24.  ]
1
[ 1.604 22.92  20.    24.   ]
1
[ 1.256 14.68  20.    24.   ]
0
[ 3.564 10.96  20.    25.   ]
0
[ 1.244 16.36  20.    25.   ]
0
[ 8.068 29.32  20.    25.   ]
1
[ 

1
[10.136 27.16  16.     7.   ]
1
[14.652 21.44  16.     7.   ]
1
[ 2.632 41.28  16.    16.   ]
1
[ 0.684 19.92  16.    18.   ]
1
[ 1.84  1.32 16.   18.  ]
0
[ 0.828 12.64  17.    18.   ]
0
[ 2.   31.88 18.   20.  ]
1
[ 3.08 52.92 18.   20.  ]
1
[ 1.648  6.    18.    20.   ]
0
[ 5.872 41.32  18.    20.   ]
1
[ 5.204 45.76   8.    22.   ]
1
[ 3.216 40.88  10.    24.   ]
1
[ 6.444 12.32  20.    25.   ]
0
[ 2.696  3.08  20.    25.   ]
0
[ 0. 10. 20. 26.]
0
[1.6000e-02 7.4868e+02 2.0000e+01 2.6000e+01]
1
[ 2.712 38.44  21.    27.   ]
1
[ 1.596 25.08  21.    28.   ]
1
[19.324 28.6   21.    30.   ]
1
[ 0.556 33.36   6.    33.   ]
1
[ 1.96 17.12  6.   33.  ]
0
[ 4.016 28.12   7.    34.   ]
1
[ 2.808  5.52  20.    36.   ]
0
[ 3.016 23.44  19.    39.   ]
1
[ 6.628 22.32  20.    39.   ]
0
[ 0.82 39.72 20.   39.  ]
1
[ 2.048 20.28  20.    40.   ]
0
[ 2.144 27.96   8.    42.   ]
1
[ 9.648 33.96  12.    43.   ]
1
[ 6.32 18.92 19.   47.  ]
0
[ 2.844 11.48  19.    47.   ]
0
[ 3.62 25.4  20.   48.  ]


0
[ 4.656 69.72  13.    31.   ]
1
[ 3.452 57.12  14.    32.   ]
1
[17.02 27.72 15.   34.  ]
0
[14.46 69.4  16.   35.  ]
1
[ 5.6  40.36 16.   35.  ]
1
[10.112 18.32  16.    35.   ]
0
[ 9.776 23.84  16.    35.   ]
1
[ 7.316 65.08  16.    35.   ]
1
[ 2.212 15.6   16.    35.   ]
0
[18.544 42.24  16.    35.   ]
1
[ 4.932 18.08  18.    37.   ]
0
[ 0.124 16.68   5.     2.   ]
1
[3.864 4.84  6.    3.   ]
0
[ 0.324 12.92   7.     7.   ]
1
[ 3.632 29.8    9.    14.   ]
1
[ 2.004 31.32  10.    17.   ]
1
[ 0.672 17.04  12.    18.   ]
1
[ 0.032 19.16  13.    20.   ]
1
[ 2.3  10.36 14.   22.  ]
0
[12.356 36.44   8.    27.   ]
1
[ 3.784  6.    13.    32.   ]
0
[ 3.076 11.44  11.    38.   ]
0
[ 0.98 16.2  12.   43.  ]
0
[ 3.42 56.52 15.   45.  ]
1
[ 0.812 26.8   10.    50.   ]
1
[ 4.504 34.48  12.    54.   ]
1
[ 4.664 20.04  12.    55.   ]
0
[ 0.308 16.36  15.    58.   ]
0
[5.972 9.88  7.    0.   ]
1
[2.828 8.56  7.    0.   ]
1
[ 0.028 12.84   7.     0.   ]
1
[ 4.476 16.16   8.     0.   ]
1
[ 7.652 48

1
[ 0.556 18.28  16.    38.   ]
0
[ 3.664  5.4   17.    39.   ]
0
[13.232  5.36  20.    43.   ]
0
[ 0.   25.32 20.   43.  ]
1
[ 3.028  6.32  21.    45.   ]
0
[ 3.648 15.08  10.    49.   ]
0
[ 0.64 12.12 14.   57.  ]
0
[ 0.98 12.68 14.   58.  ]
0
[18.744 28.36  15.    59.   ]
0
[10.436 15.64  23.    74.   ]
0
[ 0.836 25.6    8.    76.   ]
0
[ 2.856 46.36  10.    79.   ]
1
[ 5.552 37.44  17.    88.   ]
1
[11.684 74.52  17.    89.   ]
1
[ 8.116 37.12  17.    89.   ]
1
[ 4.652  4.44  18.    90.   ]
0
[ 0.988  4.96  21.    92.   ]
0
[ 2.52  6.72 21.   93.  ]
0
[23.732 52.4   21.    95.   ]
1
[ 6.664 10.96  10.    98.   ]
0
[ 33.852  50.84   12.    103.   ]
0
[  7.332  41.04   13.    106.   ]
0
[  8.36  21.68  14.   109.  ]
0
[  0.   20.4  14.  111. ]
0
[  0.376  38.08   15.    113.   ]
0
[  8.396  16.44   17.    117.   ]
0
[  1.6   34.48  18.   121.  ]
0
[  4.54  27.88  20.   128.  ]
0
[  3.768   3.12   20.    129.   ]
0
[  1.48  16.92  20.   129.  ]
0
[  2.316   4.36   20.    130.   ]
0
[ 

0
[  1.036  34.52   14.    139.   ]
0
[ 10.584  18.16   17.    141.   ]
0
[  3.732  71.2    21.    143.   ]
1
[ 0.52 90.76 14.    8.  ]
1
[ 0.66 12.64 18.   15.  ]
0
[ 3.116 25.16  19.    16.   ]
1
[ 8.124 74.6   19.    17.   ]
1
[ 2.564  1.56  11.    21.   ]
0
[ 5.928 69.32  13.    23.   ]
1
[ 2.848 76.96  15.    24.   ]
1
[ 1.02 17.84 12.    0.  ]
1
[ 0.728 33.68  14.    11.   ]
1
[ 5.4  52.52 15.   12.  ]
1
[ 2.92  4.52 23.   21.  ]
0
[ 0.616 35.72  23.    22.   ]
1
[ 6.136 10.2   23.    46.   ]
0
[ 2.324 18.8    8.    49.   ]
0
[ 0.804 20.76   9.    50.   ]
0
[ 4.248 52.4    9.    51.   ]
1
[ 4.64 12.96 10.   52.  ]
0
[  4.9 174.8  11.   55. ]
1
[ 2.32 15.88 12.   56.  ]
0
[ 1.764 40.8   10.    59.   ]
1
[ 3.332 29.44  14.    67.   ]
1
[ 2.296 23.96  16.    72.   ]
0
[ 5.952 22.84   9.    74.   ]
0
[ 6.248  4.96   9.    76.   ]
0
[ 4.6  10.04 11.   81.  ]
0
[13.38 14.24 13.   86.  ]
0
[ 4.168 40.36   5.     1.   ]
1
[ 6.896 28.56   5.     2.   ]
1
[ 3.284 16.8    6.     5.   ]
1
[ 

0
[  5.732  15.32   18.    109.   ]
0
[  4.388  44.64   20.    111.   ]
1
[  0.18  18.2   20.   112.  ]
0
[  2.088  26.48   20.    113.   ]
0
[ 9.556 29.24  16.     2.   ]
1
[  2.972 157.8    16.      3.   ]
1
[ 3.572 28.08  17.     4.   ]
1
[20.228 37.96  10.     6.   ]
1
[ 0.936 36.2   18.     7.   ]
1
[ 0.968 14.88   7.     8.   ]
1
[ 3.172 14.     7.    10.   ]
1
[ 2.436 31.08   8.    12.   ]
1
[ 0.348 22.44   8.    12.   ]
1
[15.22 51.08 15.   15.  ]
1
[ 5.676 15.48  15.    16.   ]
0
[ 2.66 10.4  13.    3.  ]
1
[ 0.024  8.24  19.     7.   ]
0
[4.0e-03 2.8e+01 1.1e+01 1.0e+01]
1
[ 8.524 22.84  15.    11.   ]
1
[ 3.72 96.48 11.   15.  ]
1
[ 0.724 43.44  11.    15.   ]
1
[ 7.404 21.76  11.    15.   ]
1
[14.984 30.68  14.    19.   ]
1
[ 0.556  4.12  17.    25.   ]
0
[15.312 61.32  18.    27.   ]
1
[ 4.244 29.44  19.    28.   ]
1
[ 1.944 14.6   19.    28.   ]
0
[ 2.64 26.36 20.   29.  ]
1
[ 2.024 22.52  14.     0.   ]
1
[ 0.2 68.4 15.   2. ]
1
[ 0.74 21.   17.    4.  ]
1
[ 4.448  8.4  

[ 3.8   8.92  7.   38.  ]
0
[ 4.352 25.96   9.    40.   ]
1
[2.000e-02 3.108e+01 1.600e+01 4.600e+01]
1
[ 8.6 66.8 16.  47. ]
1
[ 8.492 32.    18.    57.   ]
1
[ 3.196  9.64  19.    59.   ]
0
[ 0.688 60.96  19.    60.   ]
1
[10.168 40.88  14.     1.   ]
1
[ 0.  38.4  5.   1. ]
1
[7.592 7.88  6.    2.   ]
1
[ 3.004 15.04   6.     2.   ]
1
[ 2.516 16.4   12.     2.   ]
1
[  4.376 122.04   13.      4.   ]
1
[ 0.888  9.2   14.     5.   ]
1
[10.532 43.08  22.     8.   ]
1
[ 2.808 11.4   23.     8.   ]
1
[ 3.952 32.72   0.     9.   ]
1
[8.02 8.56 0.   9.  ]
0
[14.216  7.56  11.    19.   ]
0
[ 3.16  8.4  14.   22.  ]
0
[ 1.1  17.24 17.   27.  ]
1
[ 0.884 39.04  17.    36.   ]
1
[ 4.508 36.72  18.    37.   ]
1
[ 0.052 28.24  18.    38.   ]
1
[ 1.8  17.72 22.   40.  ]
0
[ 1.392 24.12  10.    44.   ]
1
[ 2.616 32.76  14.    51.   ]
1
[ 1.516  3.64  19.    54.   ]
0
[ 3.088 19.36  23.    56.   ]
0
[ 1.848 24.8    1.    59.   ]
0
[ 0.492  9.48  22.    63.   ]
0
[ 5.368 29.28  22.    63.   ]
1
[ 3.

1
[11.832 11.44  19.    23.   ]
0
[ 1.528 31.6   19.    24.   ]
1
[ 8.06  6.76  9.   27.  ]
0
[ 4.42 89.   10.   28.  ]
1
[ 2.404 24.6   17.    36.   ]
1
[ 3.688 13.52  19.    42.   ]
0
[13.16 75.2  20.   43.  ]
1
[ 0.988 16.64  20.    43.   ]
0
[ 4.38  3.72 20.   45.  ]
0
[11.924  4.88  21.    48.   ]
0
[ 0.   11.52 21.   49.  ]
0
[12.876 10.88  21.    50.   ]
0
[13.392  6.2   21.    52.   ]
0
[ 3.928  8.36  22.    53.   ]
0
[5.20e-02 5.96e+00 2.30e+01 5.60e+01]
0
[ 0.132 49.4   23.    56.   ]
1
[ 0.248  1.68  23.    57.   ]
0
[ 2.188 49.88  12.     1.   ]
1
[ 0.252 21.92  20.     2.   ]
1
[ 2.86 19.88 21.    3.  ]
1
[ 0.124 24.48  21.     4.   ]
1
[ 8.344 51.84  22.     5.   ]
1
[ 2.564 18.    22.     6.   ]
1
[ 2.164  6.12  23.     9.   ]
0
[ 3.868 15.52  11.     9.   ]
1
[ 4.452 29.36  19.    19.   ]
1
[ 0.46  3.6  19.   20.  ]
0
[ 0.   20.08 20.   22.  ]
1
[ 2.264  9.88  21.    24.   ]
0
[ 3.988 45.6   22.    25.   ]
1
[10.96 66.36 14.   32.  ]
1
[ 4.788 13.08  15.    34.   ]
0
[1

In [None]:
"""

























reward results - 
random policy - around 9.5k
learned policy - 14k
always accept policy - 19.4k
"""

##############################################################################################














In [None]:
# startup simulation

def simpy_episode(rewards, steps, time_step, tf_env, policy):
    """Run one ride simulation on a freshly built SimPy environment.

    Every call creates a new ``simpy.Environment`` and ``Grid``, creates the
    drivers on that grid, and hands everything to ``run_simulation`` together
    with the caller's ``rewards``, ``steps``, ``time_step``, ``tf_env`` and
    ``policy``, which are passed through unchanged. Once the simulation
    returns, the trip count of each driver and the overall total are printed.
    Nothing is returned.

    NOTE(review): ``Grid`` and ``create_drivers`` must already be in scope when
    this runs; the ``Grid`` import in the notebook's import cell is commented
    out -- confirm they are defined by an earlier cell.
    """
    # Simulation settings, fixed for every call.
    trip_count, run_time = 8000, 10000
    driver_count = 1
    time_multiplier = 50

    # Grid geometry. Alternative dimensions (3809 x 2622) are kept for reference:
    # grid_width, grid_height = 3809, 2622
    grid_width, grid_height = 60, 40
    interval = 20
    hex_area = 2.6

    sim_env = simpy.Environment()
    map_grid = Grid(
        env=sim_env,
        width=grid_width,
        height=grid_height,
        interval=interval,
        num_drivers=driver_count,
        hex_area=hex_area,
    )

    # Read in the same order as before: taxi spots, then drivers, then pools.
    taxi_spots = map_grid.taxi_spots
    driver_list = create_drivers(sim_env, driver_count, map_grid)
    driver_pools = map_grid.driver_pools

    run_simulation(
        trip_count, run_time, driver_count, time_multiplier,
        map_grid, taxi_spots, driver_list, driver_pools, sim_env,
        rewards, steps, time_step, tf_env, policy,
    )

    # Report completed trips per driver and in total.
    t_count = 0
    for dr in driver_list:
        d_t_count = dr.total_trip_count
        t_count += d_t_count
        print(f"{dr.id} completed {d_t_count}")

    print(f"Total trip count: {t_count}")

In [None]:
# Scratch check: draw one random int32 action in [0, 2), then overwrite it with 2.
var = tf.random.uniform([1], 0, 2, dtype=tf.int32)
# Tensors are immutable, so `var[0] = 2` raises TypeError (no item assignment).
# tensor_scatter_nd_update returns a new tensor with element 0 replaced instead.
var = tf.tensor_scatter_nd_update(
    var, indices=[[0]], updates=tf.constant([2], dtype=var.dtype)
)
print (var)

In [None]:
#simple episode run - attempt 1

# Start from a freshly reset environment; `rewards` and `steps` collect the
# per-episode totals that are summarised after the loop.
time_step = tf_env.reset()
rewards = []
steps = []
num_episodes = 5

for _ in range(num_episodes):
    # Per-episode counters. They were previously incremented without ever being
    # initialised, which raised NameError.
    episode_reward = 0
    episode_steps = 0

    # Run one full SimPy simulation. The `steps` list is passed here; the
    # original passed the undefined name `step`.
    # NOTE(review): if run_simulation already appends to `rewards`/`steps`, the
    # appends below double-count -- confirm against RideSimulator.taxi_sim.
    simpy_episode(rewards, steps, time_step, tf_env, policy)

    # Take one random action (0 or 1) in the TF environment and record it.
    action = tf.random.uniform([1], 0, 2, dtype=tf.int32)
    time_step = tf_env.step(action)
    episode_steps += 1
    episode_reward += time_step.reward.numpy()

    # Indentation fixed: these lines used a 2-space dedent that matched no
    # enclosing block, so the cell could not even be parsed.
    rewards.append(episode_reward)
    steps.append(episode_steps)
    time_step = tf_env.reset()

num_steps = np.sum(steps)
avg_length = np.mean(steps)
avg_reward = np.mean(rewards)

In [None]:
#simple episode run - attempt 2

# `rewards` and `steps` collect the per-episode totals summarised after the loop.
rewards = []
steps = []
num_episodes = 5

for _ in range(num_episodes):
    # Reset at the top of every iteration so each simulation is handed an
    # initial time step.
    time_step = tf_env.reset()

    # Per-episode counters. They were previously incremented without ever being
    # initialised, which raised NameError.
    episode_reward = 0
    episode_steps = 0

    # Run one full SimPy simulation. The `steps` list is passed here; the
    # original passed the undefined name `step`.
    # NOTE(review): if run_simulation already appends to `rewards`/`steps`, the
    # appends below double-count -- confirm against RideSimulator.taxi_sim.
    simpy_episode(rewards, steps, time_step, tf_env, policy)

    # Take one random action (0 or 1) in the TF environment and record it.
    action = tf.random.uniform([1], 0, 2, dtype=tf.int32)
    time_step = tf_env.step(action)
    episode_steps += 1
    episode_reward += time_step.reward.numpy()

    # Indentation fixed: these lines used a 2-space dedent that matched no
    # enclosing block, so the cell could not even be parsed.
    rewards.append(episode_reward)
    steps.append(episode_steps)
    # Kept from the original so the environment is left reset after the last
    # episode; within the loop it duplicates the reset at the top.
    time_step = tf_env.reset()

num_steps = np.sum(steps)
avg_length = np.mean(steps)
avg_reward = np.mean(rewards)

In [None]:
# Simple episode run template.
# Reference loop for running `num_episodes` episodes with uniformly random
# actions (0 or 1) and reporting step/reward statistics. It is disabled by being
# wrapped in a string literal, so none of the code below is executed; as the
# last expression of the cell, the string itself is what the cell evaluates to.
"""
time_step = tf_env.reset()
rewards = []
steps = []
num_episodes = 5

for _ in range(num_episodes):
  episode_reward = 0
  episode_steps = 0
  while not time_step.is_last():
    action = tf.random.uniform([1], 0, 2, dtype=tf.int32)
    time_step = tf_env.step(action)
    episode_steps += 1
    episode_reward += time_step.reward.numpy()
  rewards.append(episode_reward)
  steps.append(episode_steps)
  time_step = tf_env.reset()

num_steps = np.sum(steps)
avg_length = np.mean(steps)
avg_reward = np.mean(rewards)

print('num_episodes:', num_episodes, 'num_steps:', num_steps)
print('avg_length', avg_length, 'avg_reward:', avg_reward)
"""