In [None]:
from ast import Mod
import math
from typing import Optional, Union
import numpy as np
import pygame
from pygame import gfxdraw
import gym
from gym import spaces, logger
from gym.utils import seeding

class MyEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
    """
    Pendulum environment with a reduced state.

    The state contains only the pole angle theta, wrapped to [-pi, pi], and the
    angular velocity theta_v. The velocity is nominally in [-15pi, 15pi] (the
    range used by the quantized reward); the declared observation space allows
    [-30pi, 30pi], and the velocity itself is never clipped.
    The three discrete actions map to the input voltages -3 V, 0 V and +3 V
    (see ``umap``).
    """

    def __init__(self):
        # Model parameters used by the dynamics in ``step``.
        self.g = 9.81
        self.m = 0.055
        self.l = 0.042
        self.J = 1.91e-4   # inertia: the summed torque is divided by J
        self.b = 3e-6      # viscous friction coefficient (multiplies theta_v)
        self.K = 0.0536    # motor constant (appears in back-EMF K**2/R and drive K/R terms)
        self.R = 9.5       # motor resistance
        self.T_s = 0.005   # explicit-Euler integration step
        self.umap = {0: -3, 1: 0, 2: 3}  # discrete action -> input voltage u
        # Weights of the cost -(Q_rew1*theta**2 + Q_rew2*theta_v**2 + R_rew*u**2).
        self.R_rew = 1
        self.Q_rew1 = 5
        self.Q_rew2 = 0.1

        self.a_threshold = 15 * math.pi
        high = np.array(
            [
                math.pi,
                self.a_threshold * 2,
            ],
            dtype=np.float32,
        )
        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)

        # Rendering resources, created lazily by ``render``.
        self.screen = None
        self.clock = None
        self.isopen = True
        self.state = None  # (theta, theta_v); set by ``reset``

    def step(self, action, method=0):
        """Advance the dynamics by one step of length ``T_s``.

        Args:
            action: Discrete action index (0, 1 or 2), mapped to a voltage by ``umap``.
            method: Reward mode. ``0`` (default) evaluates the quadratic cost on
                the exact new state. A positive integer ``n`` first quantizes
                theta into ``n`` equal bins over [-pi, pi] and theta_v into
                ``n`` bins over [-15pi, 15pi], then evaluates the cost at the
                bin centres (values outside a range use the outermost bin).

        Returns:
            ``(observation, reward, done, info)``: a float32 array
            ``[theta, theta_v]``, the scalar reward, ``False`` (episodes never
            terminate here), and an empty dict.
        """
        theta, theta_v = self.state
        u = self.umap[action]

        # Angular acceleration: gravity torque - viscous friction - back-EMF + motor drive.
        theta_acc = 1 / self.J * (
            self.m * self.g * self.l * math.sin(theta)
            - self.b * theta_v
            - self.K ** 2 / self.R * theta_v
            + self.K / self.R * u
        )
        # Explicit Euler: the angle update uses the velocity from before this step.
        theta_new = theta + self.T_s * theta_v
        theta_v_new = theta_v + self.T_s * theta_acc
        # Wrap into [-pi, pi] for an overshoot of any size. This equals the single
        # +/-2pi correction for in-range steps and keeps exactly +/-pi unchanged
        # (ties round to the even multiple 0).
        theta_new = math.remainder(theta_new, math.pi * 2)
        self.state = (theta_new, theta_v_new)
        done = False
        if method == 0:
            reward = -self.Q_rew1 * theta_new ** 2 - self.Q_rew2 * theta_v_new ** 2 - self.R_rew * u ** 2
        else:
            theta_unit = math.pi * 2 / method
            theta_v_unit = math.pi * 30 / method
            # floor (not int(), which truncates toward zero) plus a clamp to
            # [0, method-1]: theta == pi, e.g. the reset state, would otherwise
            # fall into a phantom bin just past the range.
            theta_idx = min(max(math.floor((theta_new + math.pi) / theta_unit), 0), method - 1)
            theta_v_idx = min(max(math.floor((theta_v_new + math.pi * 15) / theta_v_unit), 0), method - 1)
            theta_r = (theta_idx + 1 / 2) * theta_unit - math.pi
            theta_v_r = (theta_v_idx + 1 / 2) * theta_v_unit - math.pi * 15
            reward = -self.Q_rew1 * theta_r ** 2 - self.Q_rew2 * theta_v_r ** 2 - self.R_rew * u ** 2

        return np.array(self.state, dtype=np.float32), reward, done, {}

    def test(self):
        pass

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        """Reset to theta = pi, theta_v = 0 and return the observation.

        The gravity term ``m*g*l*sin(theta)`` pushes theta away from 0, so
        theta = 0 (the reward's target) is the unstable equilibrium and
        theta = pi is the stable one the episode starts from.
        When ``return_info`` is true, an empty info dict is returned as well.
        """
        super().reset(seed=seed)
        self.state = (math.pi, 0)
        if not return_info:
            return np.array(self.state, dtype=np.float32)
        else:
            return np.array(self.state, dtype=np.float32), {}

    def render(self):
        """Draw the current pole angle in a pygame window (created on first call).

        Returns ``None`` without drawing if ``reset`` has not been called yet.
        """
        screen_width = 600
        screen_height = 400

        world_width = 4.8
        scale = screen_width / world_width
        polewidth = 10.0
        polelen = scale
        cartwidth = 50.0
        cartheight = 30.0

        if self.state is None:
            return None

        x = self.state

        if self.screen is None:
            pygame.init()
            pygame.display.init()
            self.screen = pygame.display.set_mode((screen_width, screen_height))
        if self.clock is None:
            self.clock = pygame.time.Clock()

        self.surf = pygame.Surface((screen_width, screen_height))
        self.surf.fill((255, 255, 255))

        # Fixed base block centred on the pivot.
        l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
        axleoffset = cartheight / 4.0
        cartx = screen_width / 2.0
        carty = 150
        cart_coords = [(l, b), (l, t), (r, t), (r, b)]
        cart_coords = [(c[0] + cartx, c[1] + carty) for c in cart_coords]
        gfxdraw.aapolygon(self.surf, cart_coords, (0, 0, 0))
        gfxdraw.filled_polygon(self.surf, cart_coords, (0, 0, 0))

        l, r, t, b = (
            -polewidth / 2,
            polewidth / 2,
            polelen - polewidth / 2,
            -polewidth / 2,
        )

        # Pole rectangle rotated by -theta about the axle.
        pole_coords = []
        for coord in [(l, b), (l, t), (r, t), (r, b)]:
            coord = pygame.math.Vector2(coord).rotate_rad(-x[0])
            coord = (coord[0] + cartx, coord[1] + carty + axleoffset)
            pole_coords.append(coord)
        gfxdraw.aapolygon(self.surf, pole_coords, (202, 152, 101))
        gfxdraw.filled_polygon(self.surf, pole_coords, (202, 152, 101))

        gfxdraw.aacircle(
            self.surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )
        gfxdraw.filled_circle(
            self.surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )

        gfxdraw.hline(self.surf, 0, screen_width, carty, (0, 0, 0))
        # Flip vertically so that +y points up on screen.
        self.surf = pygame.transform.flip(self.surf, False, True)
        self.screen.blit(self.surf, (0, 0))

        pygame.event.pump()
        self.clock.tick(50)
        pygame.display.flip()

    def close(self):
        """Shut down the pygame window, if ``render`` opened one."""
        if self.screen is not None:
            pygame.display.quit()
            pygame.quit()
            self.isopen = False
            # Drop the dead display/clock so a later render() re-initialises
            # pygame instead of blitting onto a display that has been quit.
            self.screen = None
            self.clock = None


In [None]:
import numpy as np
# from myenv import MyEnv
import math
from tqdm import tqdm
%matplotlib inline
import matplotlib.pyplot as plt

class CartPoleSolver:
    """Tabular Q-learning agent for ``MyEnv``.

    The continuous state (theta, theta_v) is discretised into
    ``interval_num`` x ``interval_num`` cells. ``q_table`` has one row per cell
    (row index ``angle_bin * interval_num + velocity_bin``) and one column per
    discrete action.
    """

    def __init__(self, gamma=0.98, epsilon=0.99, alpha=0.25, episodes=1000, batch_size=20000, interval_num=100, error=1e-2):
        self.env = MyEnv()
        self.gamma = gamma  # discount factor
        self.epsilon = epsilon  # epsilon-greedy parameter; also the per-episode decay factor in solve()
        self.alpha = alpha  # learning rate
        self.episodes = episodes  # maximum number of training episodes
        self.batch_size = batch_size  # environment steps per training episode
        self.interval_num = interval_num  # number of bins per continuous state variable
        self.error = error  # convergence threshold on the per-episode Q-table change
        self.error_table = []  # per-episode Q-table change, appended by solve()

        # Interior bin edges: np.digitize then yields 0..interval_num-1, with
        # out-of-range values falling into the outermost bins.
        self.pa_bin = np.linspace(-math.pi, math.pi, interval_num + 1)[1:-1]
        self.pv_bin = np.linspace(-math.pi * 15, math.pi * 15, interval_num + 1)[1:-1]

        self.q_table = np.zeros((interval_num ** 2, 3), dtype=np.float64)
        # Visit marker used only by earlier (disabled) experiments; kept for compatibility.
        self.trail = np.zeros((interval_num ** 2, 3), dtype=np.float64)

    def get_state_index(self, observation):
        """Map ``(angle, angular_velocity)`` to its row in ``q_table``."""
        pole_angle, pole_v = observation

        state_index = 0
        state_index += np.digitize(pole_angle, bins=self.pa_bin) * self.interval_num
        state_index += np.digitize(pole_v, bins=self.pv_bin)

        return state_index

    def update_Q_table(self, observation, action, reward, next_observation):
        """Apply one Q-learning update: target = reward + gamma * max_a Q(s', a)."""
        state_index = self.get_state_index(observation)
        next_state_index = self.get_state_index(next_observation)
        max_next = self.q_table[next_state_index].max()
        q_target = reward + self.gamma * max_next
        self.q_table[state_index, action] = self.q_table[state_index, action] + self.alpha * (q_target - self.q_table[state_index, action])

    def decide_action(self, observation, epsilon):
        """Epsilon-greedy choice: a uniformly random action with probability ``epsilon``."""
        state = self.get_state_index(observation)

        if epsilon <= np.random.uniform(0, 1):
            action = np.argmax(self.q_table[state])
        else:
            action = np.random.choice(3)

        return action

    def run(self, epsilon=0, quiet=True):
        """Run one episode from ``env.reset()``.

        With ``quiet=True``, take ``batch_size`` steps and update the Q-table
        after each. With ``quiet=False``, render 1000 steps of the current
        policy without learning.
        """
        observation = self.env.reset()
        if not quiet:
            for t in range(1000):
                self.env.render()
                action = self.decide_action(observation, epsilon)
                next_observation, reward, _, _ = self.env.step(action, 0)
                observation = next_observation
            return
        for t in range(self.batch_size):
            action = self.decide_action(observation, epsilon)
            next_observation, reward, _, _ = self.env.step(action, 0)
            self.update_Q_table(observation, action, reward, next_observation)
            observation = next_observation

    def compute_error(self, a, b):
        """Return the Frobenius (2-)norm of ``a - b``.

        Uses plain ndarrays: ``np.mat`` was removed in NumPy 2.0, and
        ``np.linalg.norm`` of a 2-D ndarray is already the Frobenius norm.
        """
        return np.linalg.norm(np.asarray(a) - np.asarray(b))

    def solve(self):
        """Train for up to ``episodes`` episodes, stopping once the Q-table change is below ``error``.

        Epsilon is multiplied by ``self.epsilon`` before every episode.
        """
        epsilon = self.epsilon
        self.index = 0
        error = float('nan')  # reported if no episode runs (episodes == 0)
        for _ in tqdm(range(self.episodes)):
            self.index += 1
            q_table = self.q_table.copy()
            epsilon = epsilon * self.epsilon
            self.run(epsilon)
            error = self.compute_error(q_table, self.q_table)
            self.error_table.append(error)
            if error < self.error:
                print('经过%d次迭代后收敛'%self.index)
                break
        else:
            # Only reached when the loop was not broken by convergence.
            print('到达最大迭代次数，此时变化值为%f'%error)

    def get_Q_table(self):
        """Print every Q-value, grouped by action, with its (angle bin, velocity bin)."""
        for i in range(3):
            print('action: ', i)
            for j in range(self.interval_num ** 2):
                a, b = divmod(j, self.interval_num)
                print('angel: ', a, ', angel_v: ', b)
                print(self.q_table[j, i])

    def plot_error(self):
        """Plot the recorded per-episode Q-table change."""
        # Derive x from the recorded data so lengths always match, even after
        # several solve() calls (error_table accumulates, index does not).
        y = self.error_table
        x = range(len(y))
        plt.title("Diff vs. Iteration Plot")
        plt.plot(x, y, label="Train_Loss_list")
        plt.xlabel("iteration")
        plt.ylabel("diff")
        plt.show()

    def plot_Q_table(self):
        """Plot one 3-D Q-value surface per action over the (theta, theta_v) grid.

        Grid coordinates are the left edges of the discretisation bins. Values
        are read straight from ``q_table`` rather than re-binning float32 grid
        points, which can land one bin off when rounding drops a point just
        below a bin edge.
        """
        n = self.interval_num
        x = np.linspace(-math.pi, math.pi, n + 1)[:-1]
        y = np.linspace(-15 * math.pi, 15 * math.pi, n + 1)[:-1]
        X, Y = np.meshgrid(x, y)

        fig = plt.figure(figsize=(18, 6), facecolor='w')
        for a in range(3):
            # Rows of q_table are angle_bin * n + velocity_bin, so reshape gives
            # [angle, velocity]; transpose to meshgrid's [velocity, angle] layout.
            Z = self.q_table[:, a].reshape(n, n).T
            ax = fig.add_subplot(1, 3, a + 1, projection='3d')
            plt.title("Q table of action %s" % self.env.umap[a])
            ax.set_xlabel('theta[rad]')
            ax.set_ylabel('theta_v[rad/s]')
            ax.set_zlabel('Q_value')
            ax.plot_surface(X, Y, Z, cmap='rainbow')

        plt.show()

    def _greedy_action_grid(self):
        """Return the greedy action per state as an (interval_num, interval_num) int grid.

        Rows are angle bins and columns are velocity bins, matching
        ``get_state_index``. ``q_table`` is left unmodified.
        """
        return np.argmax(self.q_table, axis=1).reshape(self.interval_num, -1)

    def plot_action_of_state(self):
        """Show the greedy action of every discretised state as a heat map."""
        # Computed into a fresh array: the previous version wrote the argmax
        # results into a view of q_table[:, 0], corrupting the learned values.
        action = self._greedy_action_grid()
        plt.matshow(action, cmap='rainbow')
        plt.colorbar()
        plt.show()
    
# Train the Q-learning agent, show one greedy roll-out, then plot the results.
a = CartPoleSolver()
a.solve()
a.run(quiet=False)
# Close the pygame window opened by the roll-out: nothing pumps its events
# after run() returns, so it would otherwise stay open and unresponsive.
a.env.close()
a.plot_error()
a.plot_Q_table()
a.plot_action_of_state()

In [None]:
import numpy as np
# from myenv import MyEnv
import math
from tqdm import tqdm
%matplotlib inline
import matplotlib.pyplot as plt

class CartPoleSolver:
    """Tabular SARSA agent for ``MyEnv``.

    The continuous state (theta, theta_v) is discretised into
    ``interval_num`` x ``interval_num`` cells. ``q_table`` has one row per cell
    (row index ``angle_bin * interval_num + velocity_bin``) and one column per
    discrete action.
    """

    def __init__(self, gamma=0.98, epsilon=0.99, alpha=0.25, episodes=1000, batch_size=20000, interval_num=100, error=1e-2):
        self.env = MyEnv()
        self.gamma = gamma  # discount factor
        self.epsilon = epsilon  # epsilon-greedy parameter; also the per-episode decay factor in solve()
        self.alpha = alpha  # learning rate
        self.episodes = episodes  # maximum number of training episodes
        self.batch_size = batch_size  # environment steps per training episode
        self.interval_num = interval_num  # number of bins per continuous state variable
        self.error = error  # convergence threshold on the per-episode Q-table change
        self.error_table = []  # per-episode Q-table change, appended by solve()

        # Interior bin edges: np.digitize then yields 0..interval_num-1, with
        # out-of-range values falling into the outermost bins.
        self.pa_bin = np.linspace(-math.pi, math.pi, interval_num + 1)[1:-1]
        self.pv_bin = np.linspace(-math.pi * 15, math.pi * 15, interval_num + 1)[1:-1]

        self.q_table = np.zeros((interval_num ** 2, 3), dtype=np.float64)
        # Visit marker used only by earlier (disabled) experiments; kept for compatibility.
        self.trail = np.zeros((interval_num ** 2, 3), dtype=np.float64)

    def get_state_index(self, observation):
        """Map ``(angle, angular_velocity)`` to its row in ``q_table``."""
        pole_angle, pole_v = observation

        state_index = 0
        state_index += np.digitize(pole_angle, bins=self.pa_bin) * self.interval_num
        state_index += np.digitize(pole_v, bins=self.pv_bin)

        return state_index

    def update_Q_table(self, observation, action, reward, next_observation, epsilon, next_action=None):
        """Apply one SARSA update and return the next action a'.

        The target is ``reward + gamma * Q(s', a')``. Here a' is ``next_action``
        if given; otherwise it is drawn epsilon-greedily at ``next_observation``.
        Returning a' lets the caller execute the same action the target used,
        which is what makes the update on-policy SARSA.
        """
        state_index = self.get_state_index(observation)
        next_state_index = self.get_state_index(next_observation)
        if next_action is None:
            next_action = self.decide_action(next_observation, epsilon)
        q_target = reward + self.gamma * self.q_table[next_state_index, next_action]
        self.q_table[state_index, action] = self.q_table[state_index, action] + self.alpha * (q_target - self.q_table[state_index, action])
        return next_action

    def decide_action(self, observation, epsilon):
        """Epsilon-greedy choice: a uniformly random action with probability ``epsilon``."""
        state = self.get_state_index(observation)

        if epsilon <= np.random.uniform(0, 1):
            action = np.argmax(self.q_table[state])
        else:
            action = np.random.choice(3)

        return action

    def run(self, epsilon=0, quiet=True):
        """Run one episode from ``env.reset()``.

        With ``quiet=True``, take ``batch_size`` SARSA steps. With
        ``quiet=False``, render 1000 steps of the current policy without learning.
        """
        observation = self.env.reset()
        if not quiet:
            for t in range(1000):
                self.env.render()
                action = self.decide_action(observation, epsilon)
                next_observation, reward, _, _ = self.env.step(action, 0)
                observation = next_observation
            return
        # The action bootstrapped in each update is the action executed next.
        # Drawing it independently (as before) decoupled target and behaviour.
        action = self.decide_action(observation, epsilon)
        for t in range(self.batch_size):
            next_observation, reward, _, _ = self.env.step(action, 0)
            action = self.update_Q_table(observation, action, reward, next_observation, epsilon)
            observation = next_observation

    def compute_error(self, a, b):
        """Return the Frobenius (2-)norm of ``a - b``.

        Uses plain ndarrays: ``np.mat`` was removed in NumPy 2.0, and
        ``np.linalg.norm`` of a 2-D ndarray is already the Frobenius norm.
        """
        return np.linalg.norm(np.asarray(a) - np.asarray(b))

    def solve(self):
        """Train for up to ``episodes`` episodes, stopping once the Q-table change is below ``error``.

        Epsilon is multiplied by ``self.epsilon`` before every episode.
        """
        epsilon = self.epsilon
        self.index = 0
        error = float('nan')  # reported if no episode runs (episodes == 0)
        for _ in tqdm(range(self.episodes)):
            self.index += 1
            q_table = self.q_table.copy()
            epsilon = epsilon * self.epsilon
            self.run(epsilon)
            error = self.compute_error(q_table, self.q_table)
            self.error_table.append(error)
            if error < self.error:
                print('经过%d次迭代后收敛'%self.index)
                break
        else:
            # Only reached when the loop was not broken by convergence.
            print('到达最大迭代次数，此时变化值为%f'%error)

    def get_Q_table(self):
        """Print every Q-value, grouped by action, with its (angle bin, velocity bin)."""
        for i in range(3):
            print('action: ', i)
            for j in range(self.interval_num ** 2):
                a, b = divmod(j, self.interval_num)
                print('angel: ', a, ', angel_v: ', b)
                print(self.q_table[j, i])

    def plot_error(self):
        """Plot the recorded per-episode Q-table change."""
        # Derive x from the recorded data so lengths always match, even after
        # several solve() calls (error_table accumulates, index does not).
        y = self.error_table
        x = range(len(y))
        plt.title("Diff vs. Iteration Plot")
        plt.plot(x, y, label="Train_Loss_list")
        plt.xlabel("iteration")
        plt.ylabel("diff")
        plt.show()

    def plot_Q_table(self):
        """Plot one 3-D Q-value surface per action over the (theta, theta_v) grid.

        Grid coordinates are the left edges of the discretisation bins. Values
        are read straight from ``q_table`` rather than re-binning float32 grid
        points, which can land one bin off when rounding drops a point just
        below a bin edge.
        """
        n = self.interval_num
        x = np.linspace(-math.pi, math.pi, n + 1)[:-1]
        y = np.linspace(-15 * math.pi, 15 * math.pi, n + 1)[:-1]
        X, Y = np.meshgrid(x, y)

        fig = plt.figure(figsize=(18, 6), facecolor='w')
        for a in range(3):
            # Rows of q_table are angle_bin * n + velocity_bin, so reshape gives
            # [angle, velocity]; transpose to meshgrid's [velocity, angle] layout.
            Z = self.q_table[:, a].reshape(n, n).T
            ax = fig.add_subplot(1, 3, a + 1, projection='3d')
            plt.title("Q table of action %s" % self.env.umap[a])
            ax.set_xlabel('theta[rad]')
            ax.set_ylabel('theta_v[rad/s]')
            ax.set_zlabel('Q_value')
            ax.plot_surface(X, Y, Z, cmap='rainbow')

        plt.show()

    def _greedy_action_grid(self):
        """Return the greedy action per state as an (interval_num, interval_num) int grid.

        Rows are angle bins and columns are velocity bins, matching
        ``get_state_index``. ``q_table`` is left unmodified.
        """
        return np.argmax(self.q_table, axis=1).reshape(self.interval_num, -1)

    def plot_action_of_state(self):
        """Show the greedy action of every discretised state as a heat map."""
        # Computed into a fresh array: the previous version wrote the argmax
        # results into a view of q_table[:, 0], corrupting the learned values.
        action = self._greedy_action_grid()
        plt.matshow(action, cmap='rainbow')
        plt.colorbar()
        plt.show()
    
# Train the SARSA agent, show one greedy roll-out, then plot the results.
a = CartPoleSolver()
a.solve()
a.run(quiet=False)
# Close the pygame window opened by the roll-out: nothing pumps its events
# after run() returns, so it would otherwise stay open and unresponsive.
a.env.close()
a.plot_error()
a.plot_Q_table()
a.plot_action_of_state()