In [None]:
import gym
import gym_mod_cartpole
from random import random, randint, uniform
from env.decom_lunar_lander import LunarLander as LunarLander_decom_reward
from copy import deepcopy
import os
import torch
import numpy as np
from tqdm import tqdm

from GVF_learner import GVF_learner
from memory.memory import ReplayBuffer_decom
from models.dqn_model import DQNModel

# Tensor constructors used throughout the notebook. Pick the CUDA variants when
# a GPU is usable and the CPU variants otherwise, so that building a tensor does
# not raise on a CPU-only machine.
if torch.cuda.is_available():
    FloatTensor = torch.cuda.FloatTensor
    LongTensor = torch.cuda.LongTensor
else:
    FloatTensor = torch.FloatTensor
    LongTensor = torch.LongTensor

# Writer = SummaryWriter(log_dir="CartPole_summary")

In [None]:
# Environment under training: the LunarLander variant imported as
# LunarLander_decom_reward.
env = LunarLander_decom_reward()
# ENV_NAME also names the folder where results are written.
# (Disabled alternative: ENV_NAME = 'CartPoleMod-v0')
ENV_NAME = 'LunarLander_decom'
# (Disabled override: env._max_episode_steps = 500)
# Maps each action name to its discrete action index (the position in the tuple).
ACTION_DICT = {
    name: index
    for index, name in enumerate(("NOOP", "LEFT", "MAIN", "RIGHT"))
}

In [None]:
# Set the folder and file used for saving results.
result_floder = ENV_NAME
result_file = ENV_NAME + "/results.txt"
# exist_ok=True makes the cell safe to re-run and avoids the check-then-create
# race of testing os.path.isdir() before os.mkdir().
os.makedirs(result_floder, exist_ok=True)

In [None]:
# Hyper-parameters read by DQN_agent.__init__ for the LunarLander run.
hyperparams_Lunarlander = dict(
    epsilon_decay_steps=200000,  # steps over which epsilon decays linearly
    final_epsilon=0.01,          # epsilon once the decay has finished
    batch_size=128,              # mini-batch size sampled from the replay buffer
    update_steps=3,              # train the model every `update_steps` steps
    memory_size=100000,          # size passed to the replay buffer
    beta=0.99,                   # discount factor of the Q-value target
    model_replace_freq=1,        # 1 selects the soft target update in DQN_agent.learn
    learning_rate=0.0001,
    decom_reward_len=8,          # length of the decomposed feature vector
    soft_tau=5e-4,               # tau passed to the soft target update
)

In [None]:
class DQN_agent(object):
    """DQN agent whose network scores one (state, one-hot action) pair at a time.

    The network input is the state vector concatenated with a one-hot action
    encoding and its output has length 1, so ranking all actions for a state
    means feeding one row per action (see ``concat_state_action`` with
    ``is_full_action=True``).
    """

    def __init__(self, env, hyper_params, action_space = len(ACTION_DICT)):
        """Create the models, the replay buffer and the training counters.

        Args:
            env: Environment object. This class calls ``env.reset()``,
                ``env.step(action)`` (unpacked as a 5-tuple) and reads
                ``env._max_episode_steps``.
            hyper_params: Dict providing 'beta', 'final_epsilon',
                'epsilon_decay_steps', 'soft_tau', 'decom_reward_len',
                'learning_rate', 'memory_size', 'batch_size', 'update_steps'
                and 'model_replace_freq'.
            action_space: Number of discrete actions; defaults to the size of
                ``ACTION_DICT``.
        """
        self.env = env
        self.max_episode_steps = env._max_episode_steps

        # beta: discount factor of the Q-value target.
        # initial_epsilon: exploration rate when `steps` is 0.
        # final_epsilon: exploration rate once `steps` reaches `epsilon_decay_steps`.
        # epsilon_decay_steps: number of steps over which epsilon decays linearly.
        # soft_tau: tau passed to the soft target-network update.
        self.beta = hyper_params['beta']
        self.initial_epsilon = 1
        self.final_epsilon = hyper_params['final_epsilon']
        self.epsilon_decay_steps = hyper_params['epsilon_decay_steps']
        self.soft_tau = hyper_params['soft_tau']
        # Current exploration rate; defined here so it can be read before the
        # first call to explore_or_exploit_policy.
        self.epsilon = self.initial_epsilon

        # episode: number of training episodes completed so far.
        # steps: incremented once per environment step taken in `learn`.
        # best_reward: best average evaluation reward seen so far.
        # action_space: number of discrete actions.
        self.episode = 0
        self.steps = 0
        self.best_reward = -float("inf")
        self.action_space = action_space

        # The network input is state + one-hot action; its output is a single value.
        state = env.reset()
        self.state_len = len(state)
        input_len = self.state_len + action_space
        output_len = 1
        self.decom_reward_len = hyper_params["decom_reward_len"]

        self.action_vector = self.get_action_vector()
        # eval_model: used to choose actions and trained by update_batch.
        # target_model: used to evaluate next states when building the target.
        self.eval_model = DQNModel(input_len, output_len, learning_rate = hyper_params['learning_rate'])
        self.target_model = DQNModel(input_len, output_len)

        # memory: stores and samples experience for replay.
        self.memory = ReplayBuffer_decom(hyper_params['memory_size'])

        # batch_size: mini-batch size for training.
        # update_steps: train every `update_steps` environment steps.
        # model_replace_freq: how often (in steps) target_model is refreshed.
        self.batch_size = hyper_params['batch_size']
        self.update_steps = hyper_params['update_steps']
        self.model_replace_freq = hyper_params['model_replace_freq']

    def linear_decrease(self, initial_value, final_value, curr_steps, final_decay_steps):
        """Linearly interpolate from ``initial_value`` to ``final_value``.

        The interpolation fraction is ``curr_steps / final_decay_steps`` capped
        at 1. A non-positive ``final_decay_steps`` means "no decay period", so
        ``final_value`` is returned directly instead of dividing by zero.
        """
        if final_decay_steps <= 0:
            return final_value
        decay_rate = min(curr_steps / final_decay_steps, 1)
        return initial_value - (initial_value - final_value) * decay_rate

    def get_action_vector(self):
        """Return an identity matrix whose row ``i`` is the one-hot code of action ``i``."""
        return FloatTensor(np.eye(self.action_space))

    def concat_state_action(self, states, actions = None, is_full_action = False):
        """Concatenate states with one-hot actions along dimension 1.

        Args:
            states: Batch of states. With ``is_full_action=True`` each state is
                repeated once per action; otherwise it must be a tensor.
            actions: One-hot action tensor, required when ``is_full_action`` is
                False and ignored otherwise.
            is_full_action: When True, pair every state with every action, so
                the result has ``len(states) * action_space`` rows ordered
                state-major (all actions of state 0, then state 1, ...).

        Returns:
            Tensor with the state columns followed by the action columns.
        """
        if is_full_action:
            com_state = FloatTensor(states).repeat((1, self.action_space)).view((-1, self.state_len))
            actions = self.action_vector.repeat((len(states), 1))
        else:
            com_state = states.clone()
            actions = actions.clone()
        state_action = torch.cat((com_state, actions), 1)
        return state_action

    def explore_or_exploit_policy(self, state):
        """Epsilon-greedy action selection; also refreshes ``self.epsilon``."""
        p = uniform(0, 1)
        # Epsilon decays linearly with the global step counter.
        epsilon = self.linear_decrease(self.initial_epsilon,
                                       self.final_epsilon,
                                       self.steps,
                                       self.epsilon_decay_steps)
        self.epsilon = epsilon

        if p < epsilon:
            return randint(0, self.action_space - 1)
        return self.greedy_policy(state)[0]

    def greedy_policy(self, state):
        """Score every action for a single state and pick the best one.

        Returns:
            Tuple ``(best_action, q_value, feature_vector)`` where
            ``best_action`` is a Python int and the other two come from
            ``eval_model.predict_batch`` for that action.
        """
        state_ft = FloatTensor(state).view(-1, self.state_len)
        state_action = self.concat_state_action(state_ft, is_full_action = True)
        feature_vectors, q_values = self.eval_model.predict_batch(state_action)
        # One row per action, so the max over dimension 0 selects the action.
        q_v, best_action = q_values.max(0)
        return best_action.item(), q_v, feature_vectors[best_action.item()]

    def update_batch(self):
        """Fit ``eval_model`` on one sampled mini-batch using a one-step TD target.

        Does nothing until the buffer holds ``batch_size`` items, and otherwise
        only runs on steps that are a multiple of ``update_steps``.
        """
        if len(self.memory) < self.batch_size or self.steps % self.update_steps != 0:
            return

        batch = self.memory.sample(self.batch_size)

        # The 2nd and 6th sampled fields are not needed for the Q update.
        (states_actions, _, reward, next_states,
         is_terminal, _) = batch

        # NOTE(review): states_actions is passed on exactly as the buffer
        # returns it — presumably already in a form predict_batch accepts.
        next_states = FloatTensor(next_states)
        terminal = FloatTensor([1 if t else 0 for t in is_terminal])
        reward = FloatTensor(reward)

        # Q(s, a) for the stored pairs.
        _, q_values = self.eval_model.predict_batch(states_actions)
        # max_a' Q_target(s', a'): score every action of every next state,
        # then reshape to one row per next state.
        next_state_actions = self.concat_state_action(next_states, is_full_action = True)
        _, q_next = self.target_model.predict_batch(next_state_actions)
        q_next = q_next.view((-1, self.action_space))
        q_max, _ = q_next.detach().max(1)

        # Terminal transitions contribute no bootstrapped value.
        q_max = (1 - terminal) * q_max
        q_target = reward + self.beta * q_max
        q_target = q_target.unsqueeze(1)

        self.eval_model.fit(q_values, q_target)

    def learn_and_evaluate(self, training_episodes, test_interval):
        """Alternate ``test_interval`` training episodes with one evaluation.

        Runs ``training_episodes // test_interval`` rounds, so a remainder of
        episodes that does not fill a whole round is not trained.

        Returns:
            List with the average evaluation reward of each round.
        """
        test_number = training_episodes // test_interval
        all_results = []

        for i in range(test_number):
            # learn
            self.learn(test_interval)
            # evaluate
            avg_reward = self.evaluate((i + 1) * test_interval)
            all_results.append(avg_reward)

        return all_results

    def get_features_decom(self, state, next_state, done):
        """Threshold-based +1/-1 features for a 4-element (CartPole-style) state.

        ``next_state`` is unpacked into exactly four values, so any other state
        length raises ValueError. Indices 0-7 of the result are written, so
        ``decom_reward_len`` must be at least 8. ``state`` and ``done`` are
        unused. The training loop in ``learn`` does not call this method; it
        uses the features returned by ``env.step`` instead.

        Returns:
            Array of length ``decom_reward_len`` filled with 1, where an entry
            is set to -1 when the matching quantity crosses its threshold.
        """
        threshold_x = 1
        threshold_c_v = 1
        threshold_angle = 0.07
        threshold_p_v = 0.7

        features_decom = np.ones(self.decom_reward_len)
        next_cart_position, next_cart_velocity, next_pole_angle, next_pole_velocity = next_state

        if threshold_x < next_cart_position:
            features_decom[0] = -1
        if -threshold_x > next_cart_position:
            features_decom[1] = -1

        if threshold_c_v < next_cart_velocity:
            features_decom[2] = -1
        if -threshold_c_v > next_cart_velocity:
            features_decom[3] = -1

        if threshold_angle < next_pole_angle:
            features_decom[4] = -1
        if -threshold_angle > next_pole_angle:
            features_decom[5] = -1

        if threshold_p_v < next_pole_velocity:
            features_decom[6] = -1
        if -threshold_p_v > next_pole_velocity:
            features_decom[7] = -1
        return features_decom

    def learn(self, test_interval):
        """Run ``test_interval`` training episodes with epsilon-greedy exploration.

        Every step is stored in the replay buffer, followed by a (possibly
        skipped) model update and a target-network refresh.
        """
        for _ in tqdm(range(test_interval), desc="Training"):
            state = self.env.reset()
            done = False
            steps = 0

            while steps < self.max_episode_steps and not done:
                steps += 1
                self.steps += 1

                action = self.explore_or_exploit_policy(state)
                next_state, reward, done, _, features_decom = self.env.step(action)

                action_vector = np.zeros(self.action_space)
                action_vector[action] = 1

                state_action = np.concatenate((state.copy(), action_vector.copy()), axis=0)
                # A `done` on the last allowed step is stored as non-terminal.
                is_terminal = steps < self.max_episode_steps and done
                self.memory.add(state_action, -1, reward, next_state, is_terminal, features_decom)
                self.update_batch()

                if self.steps % self.model_replace_freq == 0:
                    # A frequency of 1 selects the soft update; any other
                    # frequency copies the weights outright.
                    if self.model_replace_freq == 1:
                        self.target_model.replace_soft(self.eval_model, tau = self.soft_tau)
                    else:
                        self.target_model.replace(self.eval_model)
                state = next_state

            # Keep the documented training-episode counter up to date.
            self.episode += 1

    def evaluate(self, episode_num, trials = 10):
        """Run ``trials`` greedy episodes and return their average total reward.

        Prints the average, and saves the model and replay buffer whenever the
        average is at least as good as the best seen so far.

        Args:
            episode_num: Number of training episodes completed; currently
                unused (it was only needed by the disabled TensorBoard logging).
            trials: Number of evaluation episodes.
        """
        total_reward = 0
        for _ in tqdm(range(trials), desc="Evaluating"):
            state = self.env.reset()
            done = False
            steps = 0

            while steps < self.max_episode_steps and not done:
                steps += 1
                action = self.greedy_policy(state)[0]
                state, reward, done, _, _ = self.env.step(action)
                total_reward += reward

        avg_reward = total_reward / trials
        print(avg_reward)
        if avg_reward >= self.best_reward:
            self.best_reward = avg_reward
            self.save_model()
            print("save")

        return avg_reward

    def save_model(self):
        """Save ``eval_model`` and the replay buffer under the module-level ``result_floder``."""
        self.eval_model.save(result_floder + '/best_model.pt')
        self.memory.save(result_floder)

    def load_model(self):
        """Load ``eval_model`` and the replay buffer from the module-level ``result_floder``."""
        self.eval_model.load(result_floder + '/best_model.pt')
        self.memory.load(result_floder)

## Train the Lunar Lander DQN agent
Train the DQN policy and generate the feature dataset.

In [None]:
# Train for 10000 episodes in total, evaluating after every 100 episodes.
training_episodes = 10000
test_interval = 100
agent = DQN_agent(env=env, hyper_params=hyperparams_Lunarlander)
# `result` holds the average evaluation reward of each round.
result = agent.learn_and_evaluate(
    training_episodes=training_episodes,
    test_interval=test_interval,
)

## GVF learner
Train the GVF model based on the dataset and policy produced above.

In [None]:
# Settings handed to GVF_learner. The keys contain spaces, so the dict is built
# from (key, value) pairs.
CP_GVF_PARAMETERS = dict([
    ("batch size", 64),                  # update batch size
    ("learning rate", 0.0001),
    ("feature num", 8),                  # number/length of features
    ("state length", 8),
    ("discount factor", [0.99] * 8),     # one discount per feature
    ("action space", 4),
    ('model_replace_freq', 1),
    ('soft_tau', 0.5),
])
# Fresh agent whose weights and replay buffer are then loaded from the saved
# best model; it supplies both the policy and the dataset for the GVF learner.
dqn = DQN_agent(env=env, hyper_params=hyperparams_Lunarlander)
dqn.load_model()
def policy(state_actions):
    """Return, for each state, the index of the action the loaded DQN rates highest.

    ``state_actions`` is passed unchanged to ``dqn.eval_model.predict_batch``.
    The second output is reshaped to rows of ``dqn.action_space`` values and the
    index of each row's maximum is returned.
    """
    _features, q_flat = dqn.eval_model.predict_batch(state_actions)
    q_per_state = q_flat.view((-1, dqn.action_space))
    # max(1) yields (values, indices); only the indices are the policy output.
    return q_per_state.detach().max(1)[1]

def get_dataset():
    """Return the loaded agent's replay-buffer contents as nested Python lists.

    Reads ``dqn.memory._storage`` directly, prints its first 10 entries as a
    sanity check, and converts it through NumPy's ``tolist()``.
    """
    dataset = dqn.memory._storage
    print(dataset[:10])
    try:
        return np.array(dataset).tolist()
    except ValueError:
        # NumPy >= 1.24 raises ValueError instead of silently building an
        # object array when the nested entries have differing shapes (e.g. a
        # transition mixing vectors and scalars). Requesting dtype=object
        # restores the object-array result older versions produced.
        return np.array(dataset, dtype=object).tolist()


In [None]:
# Replay-buffer contents of the loaded DQN agent, as nested Python lists.
dataset = get_dataset()
# The GVF learner gets its settings, the dataset, and `policy`, which maps
# state-action batches to the DQN's greedy action indices.
gvf_learner = GVF_learner(CP_GVF_PARAMETERS, dataset, policy)
# NOTE(review): the meaning of (1000, 10) is defined by
# GVF_learner.learn_and_eval, which is not visible here — presumably a training
# amount and an evaluation interval; confirm against GVF_learner.
gvf_learner.learn_and_eval(1000, 10)

In [None]:
# 
# test_data = torch.tensor([[0.015588092617690563, -0.0004975938936695457, 4.084892424316422e-08, 6.387576578781307e-10, 0.0024009563494473696, -9.596109151743804e-08, 1, 1, 0, 0, 1, 0]]).cuda()
# print(test_data.size())
# r = gvf_learner.eval_model(test_data)
# print(r)