In [317]:
import numpy as np
import gym
import matplotlib.pyplot as plt
%matplotlib inline

In [130]:

class tree():
    def __init__(
        self,
        max_depth: float,
        i_depth: float,
        minimum_sample_leaf: float,
        y_val: np.array,
        x_val: np.array,
        is_terminal: bool,
    ):
        self.max_depth = max_depth
        self.minimum_sample_leaf = minimum_sample_leaf
        self.i_depth = i_depth
        self.is_terminal = is_terminal

        self.y_val = y_val
        self.x_val = x_val
        self.best_feature = None
        self.best_feature_value = None

        self.l_tree = None
        self.r_tree = None

    def obj_fun(self, l_values, r_values):
        l_mean = np.mean(l_values)
        r_mean = np.mean(r_values)

        l_len = len(l_values)
        r_len = len(r_values)

        l_mse =np.sum(((l_values - l_mean)**2)**0.5)
        r_mse =np.sum(((r_values - r_mean)**2)**0.5)

        mse_weighted =  (l_mse * l_len + r_mse * r_len)/(l_len + r_len)

        return mse_weighted
        


    def fit(self, ):
        best_score = None
        for i in range(len(self.x_val[0,:])):

            for j in list(set(self.x_val[:,i])):

                left_ind = self.x_val[:,i] < j

                y_left = self.y_val[left_ind]
                y_right = self.y_val[~left_ind]

                if self.best_feature is None:
                    self.best_feature = i
                    self.best_feature_value = j
                    best_score = self.obj_fun(y_left, y_right)
                    
                else:
                    new_score = self.obj_fun(y_left, y_right)
                    if new_score < best_score:
                        self.best_feature = i
                        self.best_feature_value = j
                        best_score = new_score
                        


        if self.max_depth >= self.i_depth:
            # 찾은 최적값으로 좌우 할당
            left_ind = self.x_val[:,self.best_feature] < self.best_feature_value
            
            y_left = self.y_val[left_ind]
            y_right = self.y_val[~left_ind]

            x_left = self.x_val[left_ind]
            x_right = self.x_val[~left_ind]

            if len(y_left) > self.minimum_sample_leaf:
                self.l_tree = tree(max_depth = self.max_depth,i_depth = self.i_depth + 1, minimum_sample_leaf = self.minimum_sample_leaf, x_val = x_left, y_val = y_left, is_terminal=False)
                self.l_tree.fit()
            else:
                self.l_tree = tree(max_depth = self.max_depth,i_depth = self.i_depth + 1, minimum_sample_leaf = self.minimum_sample_leaf, x_val = x_left, y_val = y_left, is_terminal=True)

            if len(y_right) > self.minimum_sample_leaf:
                self.r_tree = tree(max_depth = self.max_depth,i_depth = self.i_depth + 1, minimum_sample_leaf = self.minimum_sample_leaf, x_val = x_right, y_val = y_right, is_terminal=False)
                self.r_tree.fit()
            else:
                self.r_tree = tree(max_depth = self.max_depth,i_depth = self.i_depth + 1, minimum_sample_leaf = self.minimum_sample_leaf, x_val = x_right, y_val = y_right, is_terminal=True)
        
        else:
            self.is_terminal = True


    def i_pred(self,x_data):
        if len(x_data) != len(self.x_val[0,:]):
            raise Exception(f'입력된 자료의 차원이 {len(x_data)} 입니다. 학습된 자료의 차원 {len(self.x_val[0,:])}과 일치시켜야 합니다.') 
        
        if self.is_terminal:
            pred = np.mean(self.y_val)
            return pred
        
        if x_data[self.best_feature] < self.best_feature_value:
            if self.l_tree.is_terminal:
                pred = np.mean(self.l_tree.y_val)
            else:
                pred = self.l_tree.i_pred(x_data)
        else:
            if self.r_tree.is_terminal:
                pred = np.mean(self.r_tree.y_val)
            else:
                pred = self.r_tree.i_pred(x_data)

        return pred

    def prediction(self, x_arr):
        
        results = []
        for i in range(len(x_arr)):
            i_val = x_arr[i,:]
            
            result = self.i_pred(i_val)
            results.append(result)

        return results
    
    def get_tree_structure(self):
        def get_info_dic(i_tree):
            result = {
            'best_feature': i_tree.best_feature,
            'best_feature_value': i_tree.best_feature_value,
            'terminal': i_tree.is_terminal,
            'depth': i_tree.i_depth,
            }
            
            if i_tree.l_tree is not None:
                result['l_tree'] = get_info_dic(i_tree.l_tree)
            
            if i_tree.r_tree is not None:
                result['r_tree'] = get_info_dic(i_tree.r_tree)

            return result
        
        info = get_info_dic(self)
        return info


In [277]:
class q_boosting_model():
    def __init__(
        self,
        max_tree_n: float,
        
        minimum_sample_leaf: float,
        max_depth: float,
        action_list: list,
        learning_rate: float,

    ):
        self.action_list = action_list
        self.max_tree_n = max_tree_n
        self.learning_rate = learning_rate
        self.now_tree_n = 0

        self.minimum_sample_leaf = minimum_sample_leaf
        self.max_depth = max_depth

        self.children = []
    
    def boost_prediction(self, input_data):
        if len(self.children) == 0:
            raise Exception('do model fit')
        else:
            results = [list(np.array(q_tree.prediction(input_data)) * self.learning_rate) for q_tree in self.children]
            final_result = np.sum(results,axis = 0)
            return final_result
        
    def boost_fit(self, x_data, y_data):
        if self.now_tree_n < self.max_tree_n:

            if len(self.children) == 0:
                i_tree = tree(max_depth = self.max_depth,i_depth = 0, minimum_sample_leaf = self.minimum_sample_leaf, x_val = x_data, y_val = y_data, is_terminal=False)
                i_tree.fit()
                self.children.append(i_tree)
                self.now_tree_n += 1

            else:
                results = self.boost_prediction(x_data)
                residuals = y_data - results

                i_tree = tree(max_depth = self.max_depth,i_depth = 0, minimum_sample_leaf = self.minimum_sample_leaf, x_val = x_data, y_val = residuals, is_terminal=False)
                i_tree.fit()
                self.children.append(i_tree)
                self.now_tree_n += 1

        else:
            raise Exception('max_tree_n 초과')

        # tree(max_depth = 5,i_depth = 0, minimum_sample_leaf = 2, x_val = x, y_val = y, is_terminal=False)

    def get_action_boost(self, x_data,):

        best_action = None
        best_q_val = None

        if len(self.children) != 0:
            for i in self.action_list:
                input = np.concatenate([x_data, np.array([[i]])], axis = 1)
                q_val = self.boost_prediction(input)

                if best_q_val is None:
                    best_q_val = q_val
                    best_action = i
                else:
                    if q_val > best_q_val:
                        best_q_val = q_val
                        best_action = i
        else:
            best_action = self.get_action_random()
        
        return best_action
    
    def get_action_random(self,):

        action = np.random.choice(self.action_list)
        return action

            

In [278]:
# ### 사용 예시 ###
# qbm = q_boosting_model(max_tree_n = 1000, minimum_sample_leaf = 2, max_depth = 10, action_list= [0,1], learning_rate = 0.01)

# x = np.array(
#     [[1,2,3],[4,5,6],[7,8,9],[10,11,12],[13,14,15],[16,17,18],[19,20,21]]
# )

# y = np.array(
#     [1,2,3,4,5,6,7]
# )

# for i in range(5):
#     qbm.boost_fit(x,y)
# print([child.prediction(x) for child in qbm.children])
# print(qbm.boost_prediction(x))
# print(qbm.get_action_boost(np.array([[1,2]]), is_random=False ,epsilon= 1))

In [295]:
qbm = q_boosting_model(max_tree_n = 1000, minimum_sample_leaf = 5, max_depth = 10, action_list= [0,1], learning_rate = 0.01)
env = gym.make('CartPole-v1')

s, info = env.reset()
# env.render()

s_list = []
ns_list = []
a_list = []
r_list = []
state_val = []

cum_r = 0

traj_len = 0
n_episode = 100
eps = 1
eps_grad = 0
eps_spike = 0
eps_gradual = eps

while True:
    if np.random.random() > eps:
        a = qbm.get_action_boost(s.reshape(-1,4),)
    else:
        a = qbm.get_action_random()

    ns, r, done, info, prob = env.step(a)
    cum_r += r

    s_list.append(s.tolist())
    ns_list.append(ns.tolist())
    a_list.append([a])
    r_list.append([r])

    traj_len += 1
    
    if done:
        eps *= (1 - eps_grad)
        state_v_i = []
        for i in range(traj_len):
            state_v_i.append(i)
        state_v_i.reverse()
        state_val.extend(state_v_i)

        s, info = env.reset()
        traj_len = 0
        n_episode -= 1

        if cum_r > 180:
            print("학습종료")
            break
        else:
        # boost fitting
            input = np.concatenate([s_list[-traj_len:-1],a_list[-traj_len:-1]], axis = 1)
            labels = np.array(state_val[-traj_len:-1])
            qbm.boost_fit(input,labels)
        print(cum_r, eps)
        cum_r = 0

        if n_episode % 5 == 0: 
            eps_gradual *= (1 - 0.03)
            eps = eps_spike
           
        else:
            eps = eps_gradual
            

        if n_episode == 0: break

    else:
        s = ns

32.0 1
24.0 1
28.0 1
22.0 1
36.0 1
20.0 0
12.0 0.97
28.0 0.97
29.0 0.97
56.0 0.97
28.0 0
13.0 0.9409
32.0 0.9409
18.0 0.9409
29.0 0.9409
45.0 0
13.0 0.912673
13.0 0.912673
13.0 0.912673
16.0 0.912673
35.0 0
44.0 0.8852928099999999
41.0 0.8852928099999999
19.0 0.8852928099999999
44.0 0.8852928099999999
학습종료


In [319]:
### 성능 테스트
s, info = env.reset()
# render = lambda : plt.imshow(env.render(mode='rgb_array'))
# render()

test_epi_n = 20

cum_r_list = []
cum_r = 0

while True:
    a = qbm.get_action_boost(s.reshape(-1,4),)
    ns, r, done, info, prob = env.step(a)
    cum_r += r
    
    if done:
        test_epi_n -= 1
        s, info = env.reset()
        print(cum_r)
        cum_r_list.append(cum_r)
        cum_r = 0

        if test_epi_n == 0: break
    
    else:
        # print(s)
        s = ns


946.0
861.0
1125.0
1599.0
1658.0
1433.0
478.0
173.0
5309.0
830.0
379.0
373.0
1275.0
75.0
41.0
476.0
51.0
108.0
1181.0
1223.0
