In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import gym
import envs
from tqdm import tqdm

In [2]:
class Brain(keras.Model):
    """Two-layer fully connected network that maps an observation to action logits."""

    def __init__(self, action_dim=5,input_shape=(1, 8 * 8)):
        """Initialize the Agent's Brain model

        Args:
            action_dim (int): Number of actions; sets the width of the
                output layer.
            input_shape (tuple): Shape hint forwarded to the first Dense
                layer.
        """
        super().__init__()
        self.dense1 = layers.Dense(32, input_shape=input_shape, activation="relu")
        # No activation on the output layer: it emits raw (unnormalised) logits.
        self.logits = layers.Dense(action_dim)

    def call(self, inputs):
        """Forward pass: flatten the input to a single row and return logits.

        Args:
            inputs: Array-like or tensor holding one observation.

        Returns:
            Tensor of shape (1, action_dim) with one logit per action.
        """
        x = tf.convert_to_tensor(inputs)
        # Dense needs at least a 2-D input. Anything that is not already a
        # single-row batch (including a flat 1-D observation) is collapsed
        # into one row of features.
        if len(x.shape) < 2 or x.shape[0] != 1:
            x = tf.reshape(x, (1, -1))
        return self.logits(self.dense1(x))

    def process(self, observations):
        """Return action logits for `observations`.

        Delegates to `predict_on_batch`, which runs `call(inputs)`
        behind the scenes.
        """
        return self.predict_on_batch(observations)

In [3]:
class Agent(object):
    """Agent that samples discrete actions from a `Brain` network's logits."""

    def __init__(self, action_dim=5, input_shape=(1, 8 * 8)):
        """Agent with a neural-network brain powered policy

        Args:
            action_dim (int): Number of actions the brain scores.
            input_shape (tuple): Input shape forwarded to the `Brain` model.
        """
        self.brain = Brain(action_dim, input_shape)
        self.policy = self.policy_mlp

    def policy_mlp(self, observations):
        """Sample one action from the categorical distribution given by the brain.

        Args:
            observations: Array exposing `.reshape`; it is flattened into a
                single row before being fed to the brain.

        Returns:
            Integer tensor of shape (1,) holding the sampled action index.
        """
        observations = observations.reshape(1, -1)
        action_logits = self.brain.process(observations)
        # tf.random.categorical expects unnormalised log-probabilities, which
        # is exactly what the brain's linear output layer produces. Taking
        # tf.math.log of them first (as before) turns every negative logit
        # into NaN and corrupts the sampling.
        action = tf.random.categorical(action_logits, num_samples=1)
        return tf.squeeze(action, axis=1)

    def get_action(self, observations):
        """Return an action for `observations` using the currently bound policy."""
        return self.policy(observations)

    def learn(self, samples):
        """Placeholder for a training step; not implemented."""
        raise NotImplementedError

In [4]:
# Roll out a single episode of `env`, letting `agent` choose every action
def evaluate(agent, env, render=True):
    """Run one episode and report how it went.

    Args:
        agent: Object exposing `get_action(observation)`.
        env: Object exposing `reset()`, `step(action)` returning
            `(observation, reward, done, info)`, and `render()`.
        render (bool): When true, `env.render()` is called after every step.

    Returns:
        Tuple `(step_count, total_reward, done, info)`, where `done` and
        `info` come from the final `env.step` call.
    """
    observation = env.reset()
    total_reward = 0.0
    steps_taken = 0
    # The episode always takes at least one step; stop once the env says so.
    while True:
        chosen = agent.get_action(observation)
        observation, reward, finished, info = env.step(chosen)
        total_reward += reward
        steps_taken += 1
        if render:
            env.render()
        if finished:
            break
    return steps_taken, total_reward, finished, info

In [8]:
#main function
if __name__ == "__main__":
    # Build the environment and an agent sized to its action/observation spaces,
    # then run 10 evaluation episodes and print a summary line for each.
    env = gym.make("Gridworld-v0")
    try:
        agent = Agent(env.action_space.n,env.observation_space.shape)
        for episode in tqdm(range(10)):
            steps, episode_reward, done, info = evaluate(agent, env)
            print(f"EpReward:{episode_reward:.2f} steps:{steps} done:{done} info:{info}")
    finally:
        # Always release the environment, even if an episode raises.
        env.close()

 10%|████████▎                                                                          | 1/10 [00:00<00:07,  1.19it/s]

EpReward:-1.60 steps:100 done:True info:{'status': 'Max steps reached'}


 20%|████████████████▌                                                                  | 2/10 [00:01<00:06,  1.22it/s]

EpReward:-1.70 steps:100 done:True info:{'status': 'Max steps reached'}


 30%|████████████████████████▉                                                          | 3/10 [00:02<00:05,  1.23it/s]

EpReward:-1.00 steps:100 done:True info:{'status': 'Max steps reached'}


 40%|█████████████████████████████████▏                                                 | 4/10 [00:03<00:04,  1.24it/s]

EpReward:-1.60 steps:100 done:True info:{'status': 'Max steps reached'}


 50%|█████████████████████████████████████████▌                                         | 5/10 [00:04<00:04,  1.23it/s]

EpReward:-1.40 steps:100 done:True info:{'status': 'Max steps reached'}


 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:04<00:03,  1.22it/s]

EpReward:-2.00 steps:100 done:True info:{'status': 'Max steps reached'}


 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:05<00:02,  1.24it/s]

EpReward:-1.50 steps:100 done:True info:{'status': 'Max steps reached'}


 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:06<00:01,  1.23it/s]

EpReward:-1.80 steps:100 done:True info:{'status': 'Max steps reached'}


 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:07<00:00,  1.24it/s]

EpReward:-2.20 steps:100 done:True info:{'status': 'Max steps reached'}


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.23it/s]

EpReward:-1.30 steps:100 done:True info:{'status': 'Max steps reached'}



