In [None]:
# Install the FutureWarning filter first so that it is already active while
# the libraries below are being imported.
from warnings import simplefilter 
simplefilter(action='ignore', category=FutureWarning)

# NOTE(review): presumably importing this package registers the custom
# environments with Gym (the env is later built by id via gym.make) — confirm.
from gym_reachability import gym_reachability  # Custom Gym env.
import gym
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import torch
from collections import namedtuple
import os

# Project-local agent class and its hyper-parameter container.
from KC_DQN.DDQN import DDQN
from KC_DQN.config import dqnConfig

# Timestamp taken once when this cell runs; it is embedded in the names of
# the model and figure output folders of this run.
import time
timestr = time.strftime("%Y-%m-%d-%H_%M_%S")

In [None]:
#== CONFIGURATION ==
# Run-level settings: environment id, compute device and training budget.
env_name = "lunar_lander_reachability-v0"
toEnd = True
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
maxSteps = 100
maxUpdates = 4000000

# One shared period (1/20 of the update budget) is handed to the epsilon,
# gamma and learning-rate settings below; half of it is the tick size used
# by the schedule preview.
updatePeriod = maxUpdates // 20
updatePeriodHalf = updatePeriod // 2

# Hyper-parameters collected in one mapping, then expanded into dqnConfig.
hyperParams = dict(
    DEVICE=device, ENV_NAME=env_name,
    MAX_UPDATES=maxUpdates, MAX_EP_STEPS=maxSteps,
    BATCH_SIZE=100, MEMORY_CAPACITY=10000,
    GAMMA=.9, GAMMA_PERIOD=updatePeriod, GAMMA_END=0.999999,
    EPS_PERIOD=updatePeriod, EPS_DECAY=0.6,
    LR_C=1e-3, LR_C_PERIOD=updatePeriod, LR_C_DECAY=0.8,
    MAX_MODEL=50,
)
CONFIG = dqnConfig(**hyperParams)

# Per-run output locations, keyed by the start timestamp. Only the figure
# folder is created here; the models folder is passed on to training.
modelsFolder = 'lunarLander/RA/{:s}'.format(timestr)
figureFolder = 'figure/lunarLander/RA/{:s}'.format(timestr)
for folder in (figureFolder,):
    os.makedirs(folder, exist_ok=True)

#== REPORT ==
# Preview the step-wise epsilon / learning-rate / gamma schedules implied by
# CONFIG, sampled once every `base` updates, and plot them side by side.
# NOTE(review): this re-implements the decay rules locally; it assumes the
# agent applies the same multiplicative updates — confirm against KC_DQN.
print(CONFIG.MAX_UPDATES, updatePeriod, CONFIG.MAX_EP_STEPS)
fig, ax = plt.subplots(1, 3, figsize=(12,2), sharex=True)
base = updatePeriodHalf
numUpdates = int(maxUpdates/base)

# Loop-invariant: number of preview ticks between two consecutive decays.
epsTicks = int(CONFIG.EPS_PERIOD/base)
lrTicks = int(CONFIG.LR_C_PERIOD/base)
gammaTicks = int(CONFIG.GAMMA_PERIOD/base)

eps = np.zeros(numUpdates)
lr = np.zeros(numUpdates)
gamma = np.zeros(numUpdates)
eps_tmp = CONFIG.EPSILON
lr_tmp = CONFIG.LR_C
gamma_tmp = CONFIG.GAMMA
for i in range(numUpdates):
    # Tick 0 holds the initial values; decays start from the first full period.
    if i != 0:
        if i % epsTicks == 0:
            # NOTE(review): 0.05 floor is hard-coded here — confirm it matches the agent.
            eps_tmp = max(eps_tmp*CONFIG.EPS_DECAY, 0.05)
        if i % lrTicks == 0:
            # NOTE(review): 1e-5 floor is hard-coded here — confirm it matches the agent.
            lr_tmp = max(lr_tmp*CONFIG.LR_C_DECAY, 1e-5)
        if i % gammaTicks == 0:
            # Shrink the gap (1 - gamma) geometrically, capped at the GAMMA_END
            # this run was configured with rather than a separate literal, so
            # the preview cannot drift away from the configuration.
            gamma_tmp = min(1-(1-gamma_tmp)*CONFIG.GAMMA_DECAY, CONFIG.GAMMA_END)
    eps[i] = eps_tmp
    lr[i] = lr_tmp
    gamma[i] = gamma_tmp

# All three panels share the same tick axis.
ticks = np.arange(numUpdates)
ax[0].plot(ticks, eps)
ax[1].plot(ticks, lr)
ax[2].plot(ticks, gamma)
ax[0].set_title(r'$\epsilon$')
ax[1].set_title(r'$\alpha$')
ax[2].set_title(r'$\gamma$')
ax[1].set_xlabel('#Updates (x{:d})'.format(base))
plt.show()

# Final previewed discount factor and learning rate.
print(gamma[-1], lr[-1])

In [None]:
# Build the environment in 'RA' mode on the selected device and read off the
# sizes the agent needs.
# NOTE(review): doneType is a fixed string here and does not follow the
# `toEnd` flag from the configuration cell — confirm that is intended.
env = gym.make(env_name, mode='RA', doneType='toEnd', device=device)

obsShape = env.observation_space.shape
s_dim = obsShape[0]
action_num = env.action_space.n

# One integer label per discrete action: 0 .. action_num-1.
action_list = np.arange(0, action_num)
print(s_dim, action_num)

In [None]:
#== AGENT ==
# Reporting/checkpoint cadence (in updates) and the value bounds forwarded
# to the training call.
reportPeriod = 100000
checkPeriod = 100000
vmin, vmax = -1, 1
print(reportPeriod)

agent = DDQN(s_dim, action_num, CONFIG, action_list, mode='RA', model='TanhTwo')

# Keyword arguments for the training run, gathered in one place.
learnOptions = dict(
    MAX_UPDATES=maxUpdates,
    MAX_EP_STEPS=CONFIG.MAX_EP_STEPS,
    addBias=False,
    warmupQ=False,
    toEnd=toEnd,
    vmin=vmin,
    vmax=vmax,
    showBool=False,
    reportPeriod=reportPeriod,
    checkPeriod=checkPeriod,
    outFolder=modelsFolder,
)

# learn() returns a pair; only the second element is kept.
_, trainProgress = agent.learn(env, **learnOptions)