In [1]:
import gym
import agent as my_agent
import torch
import variable as v
from tqdm import tqdm
import utils
import matplotlib.pyplot as plt
from copy import deepcopy

In [2]:
# Environment every agent in this notebook is trained on.
ENV_ID = 'LunarLander-v2'
env = gym.make(ENV_ID)

# No training session yet; a later cell creates one when this is still unset.
train_session = None

In [3]:
# Seed reused by the agent config, the network handler, the replay buffer
# and the training session.
seed = 78

# Sizes read from the environment's spaces.
dim_state = env.observation_space.shape[0]  # first dimension of the observation shape
num_action = env.action_space.n             # `.n` of the action space

## Base Agent

In [4]:
# Network layout: state -> 64 units (relu) -> one output per action
# (final layer declared with activation 'None').
hidden_units = 64
nn_archi = [
    {'type': 'linear', 'in': dim_state, 'out': hidden_units, 'activation': 'relu'},
    {'type': 'linear', 'in': hidden_units, 'out': num_action, 'activation': 'None'},
]

# Sub-configurations, named individually so each group can be read on its own.
exploration_rate_cfg = {
    'er': 0.001,
    'max_er': 1,
    'min_er': 0.0,
    'decay_er': 0.01,
    'constant_er': True,
}
neural_network_handler_cfg = {
    'seed': seed,
    'discount_factor': 0.99,
    'nn_archi': nn_archi,
    'eval_train_delay': 300,
}
replay_buffer_cfg = {
    'buffer_size': 25000,
    'mini_batch_size': 64,
    'seed': seed,
}
optim_cfg = {'lr': 5e-4}
early_stop_cfg = {
    'skip_training_threshold': 180.0,
    'stop_training_threshold': 200.0,
    'episode_window_skip': 10,
    'episode_window_stop': 100,
}

# Baseline agent configuration; the parameter grid further down overrides
# individual entries of this dict to derive the compared agents.
base_agent_init = {
    'seed': seed,
    'policy_type': 'softmax',
    'temperature': 0.75,
    'exploration_rate': exploration_rate_cfg,
    'num_action': num_action,
    'max_position_init': 0.04,
    'max_position_reward_bonus': 0.0,
    'neural_network_handler': neural_network_handler_cfg,
    'replay_buffer': replay_buffer_cfg,
    'optim': optim_cfg,
    'early_stop': early_stop_cfg,
}

### Agents to compare

In [5]:
# Create the training session only when none exists yet, so that re-running
# this cell does not replace a session that already holds agents.
# The `None` sentinel is tested explicitly instead of relying on truthiness:
# a session object that evaluates as falsy (e.g. via `__len__`/`__bool__`)
# would otherwise be silently discarded and rebuilt.
if train_session is None:
    print('Train Session reset')
    train_session = utils.TrainSession({}, env, seed)

Train Session reset


In [6]:

# Quick look at the early-stopping values of the base configuration.
early_stop_settings = base_agent_init["early_stop"]
early_stop_settings.values()

dict_values([180.0, 200.0, 10, 100])

In [7]:
# Grid to explore: each early-stopping threshold is tried at its base value
# and at +inf (presumably a threshold that is never reached — TODO confirm).
no_threshold = float("inf")
tuned_parameters = {
    ('early_stop', 'skip_training_threshold'): [180.0, no_threshold],
    ('early_stop', 'stop_training_threshold'): [200.0, no_threshold],
}

# Register DQN agents derived from the base config and the grid above;
# the call returns the names of the agents it added.
last_added_agent_names = train_session.parameter_grid_append(
    my_agent.DQN,
    base_agent_init,
    tuned_parameters,
)

added_bullets = '\n- '.join(last_added_agent_names)
print(f"Agents added: \n- {added_bullets}")

Agents added: 
- early_stop_skip_training_threshold:180.0;early_stop_stop_training_threshold:200.0;
- early_stop_skip_training_threshold:180.0;early_stop_stop_training_threshold:inf;
- early_stop_skip_training_threshold:inf;early_stop_stop_training_threshold:200.0;
- early_stop_skip_training_threshold:inf;early_stop_stop_training_threshold:inf;


In [8]:
# List every agent currently registered in the session, one bullet per name.
all_agent_names = train_session.agents.keys()
print("All Agents: \n- " + '\n- '.join(all_agent_names))

All Agents: 
- early_stop_skip_training_threshold:180.0;early_stop_stop_training_threshold:200.0;
- early_stop_skip_training_threshold:180.0;early_stop_stop_training_threshold:400.0;
- early_stop_skip_training_threshold:400.0;early_stop_stop_training_threshold:200.0;
- early_stop_skip_training_threshold:400.0;early_stop_stop_training_threshold:400.0;


In [8]:
# Settings for the training run below.
graphical = False        # forwarded to train() as `graphical`
n_episode = 700          # forwarded to train() as `n_episode`
# NOTE(review): not passed to train() in the visible cells — confirm it is still needed.
t_max_per_episode = 300

# Agents to train and plot, identified by the names generated from the grid.
selected_agent_names = [
    'early_stop_skip_training_threshold:180.0;early_stop_stop_training_threshold:200.0;',
    'early_stop_skip_training_threshold:inf;early_stop_stop_training_threshold:inf;',
]

In [None]:
# Train only the selected agents. This is the slow cell of the notebook:
# the recorded run took several minutes per agent.
train_session.train(
    n_episode=n_episode,
    graphical=graphical,
    agent_subset=selected_agent_names,
)

100%|██████████| 700/700 [03:45<00:00,  3.10it/s]
 79%|███████▊  | 551/700 [05:14<03:13,  1.30s/it]

In [None]:
# Plot the results of the selected agents with window=50.
train_session.plot_results(
    agent_subset=selected_agent_names,
    window=50,
)

In [None]:
# Creates a 20x10 figure and plots `x` against itself.
# NOTE(review): `x` is not defined in any visible cell, so this raises
# NameError on a fresh kernel unless it is defined elsewhere — TODO confirm
# the intended data or remove this cell.
fig = plt.figure(figsize=(20, 10))
plt.plot(x, x)