In [1]:
import itertools

from collections import deque

import random
import gym
import torch
from torch.utils.data import DataLoader
import numpy as np

from AI_agents.Environments.gym_problem import GymProblem
from AI_agents.Search.best_first_search import a_star

from IL.dataset import ImitationLearningDataset
from IL.evaluation import evaluate_policy
from IL.ipython_vis import animate_policy
from IL.model import MLP
from IL.training import train_torch_classifier_sgd
import AI_agents.Search.utils as utils


# random seed (reused below for data collection and both policy evaluations)
seed = 42

# initialize env. NOTE(review): `.env` takes the environment out of the wrapper that
# gym.make returns (presumably removing the TimeLimit step cap) — confirm for the
# installed gym version.
env = gym.make("Taxi-v3").env
env.reset()

locs = env.unwrapped.locs  # environment locations
PASSENGER_IN_TAXI = 4  # passenger idx when in taxi

In [2]:
class TaxiMonteCarloPolicy:
    """Uniform-random policy over the actions that GymProblem reports as applicable.

    Its rollouts are the data the Monte Carlo table is estimated from. Despite the
    name, the policy itself does no planning and no value estimation.

    Args:
        rng: object exposing ``choice`` used to sample actions. Defaults to the
            global ``random`` module, so ``random.seed`` controls it. Pass e.g.
            ``random.Random(seed)`` for an independently seeded stream.
    """

    def __init__(self, rng=None):
        # Not used by __call__; kept so code that reads ``cur_plan`` keeps working.
        self.cur_plan = deque()
        self.rng = random if rng is None else rng

    def __call__(self, obs):
        # NOTE(review): the problem is built from the notebook-global ``env`` and its
        # current internal state, while ``obs`` only seeds the search node. During
        # rollouts obs is presumably equal to env.unwrapped.s — confirm.
        taxi_prob = GymProblem(env, env.unwrapped.s)
        node = utils.Node(utils.State(obs, False), None, None, 0)
        actions = list(taxi_prob.get_applicable_actions(node))
        return self.rng.choice(actions)
    
# Behaviour policy used to generate rollouts (uniform-random over applicable Taxi actions).
# NOTE(review): the "helicopter" name looks like a leftover; renaming requires updating
# every later cell that references `helicopter_policy`.
helicopter_policy = TaxiMonteCarloPolicy()

In [3]:
# This code will run forever until it is interrupted
#animate_policy(env, helicopter_policy)

In [4]:
class Trajectory:
    """Record of one episode as three parallel lists: observations, actions and rewards."""

    def __init__(self, observations=None, actions=None, rewards=None):
        # a falsy argument (None or an empty sequence) is replaced by a fresh list
        self.observations = observations if observations else []
        self.actions = actions if actions else []
        self.rewards = rewards if rewards else []

    def add_step(self, observation, action, reward):
        """Append one (observation, action, reward) step to the episode."""
        columns = (
            (self.observations, observation),
            (self.actions, action),
            (self.rewards, reward),
        )
        for column, value in columns:
            column.append(value)

    def __str__(self):
        # rewards are deliberately left out of the textual form
        pairs = list(zip(self.observations, self.actions))
        return f'trajectory: {pairs}'

    def __repr__(self):
        return str(self)

In [5]:
def get_trajectory(policy, max_trajectory_length=float('inf'), environment=None):
    """Roll out one episode of ``policy`` and record it as a Trajectory.

    Args:
        policy: callable mapping an observation to an action.
        max_trajectory_length: maximum number of steps to take. The rollout also stops
            as soon as the environment reports ``done``. A value below 1 yields an
            empty trajectory; previously one step was taken regardless.
        environment: environment to roll out in. Defaults to the notebook-global
            ``env``. Note that policies such as TaxiMonteCarloPolicy read the global
            ``env`` themselves.

    Returns:
        Trajectory whose i-th step holds the observation the policy acted on, the
        action it chose, and the reward returned by ``step`` for that action.
    """
    if environment is None:
        environment = env

    trajectory = Trajectory()
    obs = environment.reset()

    # limit num actions for incomplete policies (default: unbounded)
    steps_taken = 0
    while steps_taken < max_trajectory_length:
        action = policy(obs)
        next_obs, reward, done, _info = environment.step(action)
        # record the observation the action was chosen from, not the resulting one
        trajectory.add_step(obs, action, reward)
        obs = next_obs
        steps_taken += 1
        if done:
            break

    return trajectory

# Sample one rollout of the random policy and display it (the bare name on the last line
# is the cell's output). No length limit is passed, so the rollout runs until `done`.
trajectory = get_trajectory(helicopter_policy)
trajectory

trajectory: [(323, 3), (323, 3), (323, 4), (323, 3), (323, 1), (223, 5), (223, 0), (323, 2), (343, 1), (243, 0), (343, 4), (343, 2), (343, 3), (323, 0), (423, 1), (323, 5), (323, 3), (323, 5), (323, 1), (223, 2), (243, 0), (343, 1), (243, 1), (143, 0), (243, 4), (243, 0), (343, 5), (343, 0), (443, 0), (443, 2), (443, 2), (443, 2), (443, 0), (443, 4), (443, 1), (343, 5), (343, 2), (343, 0), (443, 0), (443, 4), (443, 5), (443, 4), (443, 2), (443, 3), (423, 2), (443, 2), (443, 0), (443, 0), (443, 2), (443, 3), (423, 0), (423, 5), (423, 5), (423, 4), (423, 0), (423, 4), (423, 3), (423, 0), (423, 3), (423, 3), (423, 4), (423, 0), (423, 4), (423, 0), (423, 5), (423, 5), (423, 4), (423, 3), (423, 3), (423, 4), (423, 5), (423, 4), (423, 5), (423, 3), (423, 5), (423, 1), (323, 1), (223, 4), (223, 2), (243, 3), (223, 2), (243, 1), (143, 5), (143, 5), (143, 5), (143, 0), (243, 5), (243, 2), (263, 4), (263, 5), (263, 1), (163, 2), (183, 3), (163, 1), (63, 1), (63, 1), (63, 4), (63, 3), (43, 3), (4

In [6]:
def collect_data(policy, num_trajectories, max_trajectory_length=float('inf')):
    """Roll out ``policy`` ``num_trajectories`` times and return the Trajectory objects in order.

    ``max_trajectory_length`` is forwarded unchanged to ``get_trajectory`` for every rollout.
    """
    return [
        get_trajectory(policy, max_trajectory_length)
        for _ in range(num_trajectories)
    ]

# Get the same trajectories every time. This requires seeding BOTH sources of randomness:
# the environment's RNG (used by env.reset()) and Python's `random` module, which
# TaxiMonteCarloPolicy uses (random.choice) to pick actions. Seeding only the env does
# not make the collected data reproducible.
env.seed(seed)
random.seed(seed)

raw_data = collect_data(helicopter_policy, num_trajectories=1000)

In [7]:
from collections import defaultdict

def build_decision_dict(raw_data):
    """Estimate a greedy state -> action table from recorded trajectories (every-visit Monte Carlo).

    Method:
    - Every (state, action) occurrence in every trajectory records the undiscounted
      return from that step to the end of its trajectory.
    - A pair's estimated value is the mean of its recorded returns.
    - Each state is mapped to the action with the highest estimate. Ties go to the
      action recorded first during the backward pass.

    Args:
        raw_data: iterable of trajectories exposing parallel ``observations``,
            ``actions`` and ``rewards`` sequences (e.g. ``Trajectory`` objects).

    Returns:
        A plain ``dict`` mapping every state seen in ``raw_data`` to its greedy action.
        Unseen states are absent, so looking one up raises ``KeyError``. Previously a
        nested defaultdict was returned, which silently produced an empty dict
        instead of an action.
    """
    returns_by_state = defaultdict(lambda: defaultdict(list))
    for trajectory in raw_data:
        steps = list(zip(trajectory.observations, trajectory.actions, trajectory.rewards))
        reward_sum = 0
        # walk backwards so reward_sum is the return-to-go from each step
        for state, action, reward in reversed(steps):
            reward_sum += reward
            returns_by_state[state][action].append(reward_sum)

    decision_dict = {}
    for state, returns_by_action in returns_by_state.items():
        mean_returns = {action: np.mean(returns) for action, returns in returns_by_action.items()}
        decision_dict[state] = max(mean_returns, key=mean_returns.get)
    return decision_dict
    

In [8]:
class MCCPolicy:
    """Deterministic policy that returns the tabulated action for an observation.

    Args:
        state_action_map: mapping from observation to action, e.g. the output of
            ``build_decision_dict``.
        fallback_policy: optional callable ``obs -> action`` used for observations
            missing from ``state_action_map``. When omitted, an unseen observation
            raises ``KeyError``.
    """

    def __init__(self, state_action_map, fallback_policy=None):
        self.state_action_map = state_action_map
        self.fallback_policy = fallback_policy

    def __call__(self, obs):
        # Check membership first: indexing a defaultdict-backed table would insert and
        # return a default value (not an action) for an unseen observation.
        if obs in self.state_action_map:
            return self.state_action_map[obs]
        if self.fallback_policy is not None:
            return self.fallback_policy(obs)
        raise KeyError(f'no action recorded for observation {obs!r}')

# Greedy policy read off the Monte Carlo state -> action table estimated from the
# random-policy rollouts in `raw_data` (no model or preprocessing is involved).
policy = MCCPolicy(build_decision_dict(raw_data))

In [9]:
# Baseline: evaluate the uniform-random behaviour policy.
total_reward, mean_reward = evaluate_policy(
    env, helicopter_policy, num_episodes=10000, seed=seed,
)
print(
    'Monte Carlo Policy',
    '---------',
    f'total reward over all episodes: {total_reward}',
    f'mean reward per episode:        {mean_reward}',
    sep='\n',
)

  0%|          | 0/10000 [00:00<?, ?it/s]

Monte Carlo Policy
---------
total reward over all episodes: -983647
mean reward per episode:        -98.3647


In [10]:
# Evaluate the greedy Monte Carlo control policy with the same seed and episode count as
# the baseline. NOTE(review): in the recorded run it scored below the random baseline.
total_reward, mean_reward = evaluate_policy(
    env, policy, num_episodes=10000, seed=seed,
)
print(
    'Monte Carlo Control Policy',
    '-----------------',
    f'total reward over all episodes: {total_reward}',
    f'mean reward per episode:        {mean_reward}',
    sep='\n',
)

  0%|          | 0/10000 [00:00<?, ?it/s]

Monte Carlo Control Policy
-----------------
total reward over all episodes: -1445084
mean reward per episode:        -144.5084


In [None]:
# Animate the learned `policy` in `env`. This code will run forever until it is
# interrupted, so it is kept as the last cell: "Run All" produces every result above first.
animate_policy(env, policy)

+---------+
|[34;1mR[0m: | : :G|
| : | : : |
| : : :[43m [0m: |
| | : | : |
|Y| : |[35mB[0m: |
+---------+
  (North)
