Q-learning

In [1]:
import numpy as np
from myenv import MyEnv
import math

class CartPoleSolver():
    """Tabular Q-learning agent for an environment with a two-value observation.

    Every observation ``(pole_angle, pole_v)`` is discretised into
    ``interval_num`` bins per component, which yields ``interval_num ** 2``
    discrete states. The Q-table stores one row per discrete state and one
    column per action.
    """

    # Number of discrete actions; the Q-table has one column per action.
    N_ACTIONS = 3

    def __init__(self, gamma=1.0, epsilon=1.0, alpha=0.01, episodes=10, batch_size=1000, interval_num=6, env=None):
        """Create the solver and a randomly initialised Q-table.

        Args:
            gamma: Discount factor applied to the best next-state Q-value.
            epsilon: Initial exploration probability of the epsilon-greedy
                policy (see ``decide_action`` for how it decays).
            alpha: Learning rate of the Q-table update.
            episodes: Number of episodes executed by ``solve``.
            batch_size: Number of environment steps executed by one ``run``.
            interval_num: Number of bins used for each observation component.
            env: Optional environment object exposing ``reset``, ``render``
                and ``step``. When omitted a fresh ``MyEnv`` is created.
        """
        # Injecting the environment keeps the class usable without MyEnv (e.g. in tests).
        self.env = MyEnv() if env is None else env
        self.gamma = gamma  # discount factor
        self.epsilon = epsilon  # initial exploration probability
        self.alpha = alpha  # learning rate
        self.episodes = episodes  # number of episodes run by solve()
        self.batch_size = batch_size  # number of steps per episode in run()
        self.interval_num = interval_num  # bins per observation component

        # Keep only the interior bin edges: np.digitize then maps anything
        # below the first edge to bin 0 and anything above the last edge to
        # bin interval_num - 1, so out-of-range values are clamped.
        self.pa_bin = np.linspace(-math.pi, math.pi, interval_num + 1)[1:-1]
        self.pv_bin = np.linspace(-math.pi * 15, math.pi * 15, interval_num + 1)[1:-1]

        # Two binned components -> interval_num ** 2 reachable states.
        self.q_table = np.random.uniform(low=0, high=1, size=(interval_num ** 2, self.N_ACTIONS))

    def get_state_index(self, observation):
        """Map a ``(pole_angle, pole_v)`` observation to its Q-table row index.

        Returns:
            An ``int`` in ``[0, interval_num ** 2 - 1]``.
        """
        pole_angle, pole_v = observation

        angle_bin = np.digitize(pole_angle, bins=self.pa_bin)
        velocity_bin = np.digitize(pole_v, bins=self.pv_bin)

        # Row-major encoding: the angle bin is the high-order "digit".
        return int(angle_bin * self.interval_num + velocity_bin)

    def update_Q_table(self, observation, action, reward, next_observation):
        """Apply one Q-learning update for the observed transition.

        Q(s, a) <- Q(s, a) + alpha * (reward + gamma * max_a' Q(s', a') - Q(s, a))
        """
        state_index = self.get_state_index(observation)
        next_state_index = self.get_state_index(next_observation)

        best_next_q = np.max(self.q_table[next_state_index])
        td_target = reward + self.gamma * best_next_q
        self.q_table[state_index, action] += self.alpha * (td_target - self.q_table[state_index, action])

    def decide_action(self, observation, episode=0):
        """Choose an action with an epsilon-greedy policy.

        The exploration probability is ``self.epsilon / (episode + 1)``: for
        ``episode == 0`` it equals ``self.epsilon`` and it shrinks as the
        episode index grows, so later episodes increasingly exploit the
        learned Q-values.

        Args:
            observation: Current ``(pole_angle, pole_v)`` observation.
            episode: Zero-based index of the current episode.

        Returns:
            The chosen action as an ``int`` in ``[0, number of Q-table columns)``.
        """
        state = self.get_state_index(observation)
        exploration_prob = self.epsilon / (episode + 1)

        # np.random.uniform(0, 1) draws from [0, 1), so an exploration
        # probability of 1.0 always explores and 0.0 always exploits.
        if exploration_prob <= np.random.uniform(0, 1):
            return int(np.argmax(self.q_table[state]))
        return int(np.random.choice(self.q_table.shape[1]))

    def run(self, episode=0, render=True, verbose=True):
        """Run one episode of ``self.batch_size`` steps, learning after every step.

        The ``done`` flag and the info value returned by ``env.step`` are
        ignored, so the episode always runs the full number of steps.

        Args:
            episode: Zero-based episode index forwarded to ``decide_action``.
            render: Call ``env.render()`` before every step when true.
            verbose: Print the current observation before every step when true.
        """
        observation = self.env.reset()
        for _ in range(self.batch_size):
            if render:
                self.env.render()
            if verbose:
                print(observation)
            action = self.decide_action(observation, episode)
            next_observation, reward, _, _ = self.env.step(action)
            self.update_Q_table(observation, action, reward, next_observation)
            observation = next_observation

    def solve(self, render=True, verbose=True):
        """Run ``self.episodes`` episodes, passing each episode index to ``run``.

        Args:
            render: Forwarded to ``run``.
            verbose: Forwarded to ``run``.
        """
        for episode in range(self.episodes):
            self.run(episode=episode, render=render, verbose=verbose)

pygame 2.1.2 (SDL 2.0.16, Python 3.10.4)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Build a solver with the default hyperparameters and run a single episode;
# run() renders the environment and prints the observation at every step.
a = CartPoleSolver()
a.run()

[3.1415927 0.       ]
[ 3.1415927  -0.44309726]
[ 3.139377   -0.88265187]
[ 3.134964  -1.3173779]
[ 3.128377   -0.85981566]
[ 3.124078 -1.288199]
[ 3.117637  -1.7106072]
[ 3.109084 -2.125818]
[ 3.0984547 -2.5326376]
[ 3.0857916 -2.0437088]
[ 3.0755732 -1.9942838]
[ 3.0656016 -2.3823004]
[ 3.0536902 -1.87512  ]
[ 3.0443146 -2.2511468]
[ 3.033059  -2.1755319]
[ 3.0221813 -2.0938795]
[ 3.0117118 -2.4495664]
[ 2.999464 -2.35315 ]
[ 2.987698  -2.2503057]
[ 2.9764466 -2.1413803]
[ 2.9657397 -2.4698331]
[ 2.9533906 -1.903206 ]
[ 2.9438746 -1.3339044]
[ 2.937205  -1.2067114]
[ 2.9311714 -1.5197558]
[ 2.9235728 -1.826795 ]
[ 2.9144387 -2.1269743]
[ 2.9038038 -1.9763719]
[ 2.893922  -2.2639315]
[ 2.8826025 -2.1004047]
[ 2.8721004 -1.4885871]
[ 2.8646574 -0.8756475]
[ 2.860279  -0.2633569]
[ 2.8589623 -0.5396595]
[2.856264   0.07319189]
[2.85663    0.23958269]
[ 2.857828   -0.03866245]
[ 2.8576345  -0.31536505]
[2.856058   0.29644927]
[2.8575401  0.01807523]
[ 2.8576305  -0.25891703]
[ 2.8563359 

KeyboardInterrupt: 