In [279]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import gym

# Imports and helper function for rendering recorded frames as an animation
from JSAnimation.IPython_display import display_animation
from matplotlib import animation
from IPython.display import display

def display_frames_as_gif(frames):
    """Animate a sequence of image frames, save it, and show it inline.

    Args:
        frames: Non-empty sequence of image arrays; the first frame's
            ``shape[0]`` and ``shape[1]`` set the figure height and width.

    Side effects:
        Writes ``movie_cartpole.mp4`` to the current working directory and
        displays the looping animation in the notebook output.
    """
    dots_per_inch = 72
    first_frame = frames[0]
    rows, cols = first_frame.shape[0], first_frame.shape[1]

    # Size the figure so that one image pixel maps to one figure dot.
    plt.figure(figsize=(cols / 72.0, rows / 72.0), dpi=dots_per_inch)
    image = plt.imshow(first_frame)
    plt.axis('off')

    def animate(i):
        # Swap the displayed pixels instead of redrawing the whole figure.
        image.set_data(frames[i])

    frame_count = len(frames)
    anim = animation.FuncAnimation(
        plt.gcf(), animate, frames=frame_count, interval=50
    )

    anim.save('movie_cartpole.mp4')
    widget = display_animation(anim, default_mode='loop')
    display(widget)

# Configuration constants
Env='CartPole-v1'  # Gym environment id passed to gym.make
num_discrete=6  # number of bins each observation component is discretized into
Gamma=0.99  # discount factor applied to the best next-state Q-value
eta=0.5  # learning rate scaling each Q-table correction
max_step=1000  # upper bound on the number of steps taken within one episode
num_episode=10000  # upper bound on the number of training episodes

In [280]:
class Agent:
    """The acting entity: a thin facade that delegates everything to a Brain."""

    def __init__(self, num_states, num_actions):
        """Create the underlying Brain.

        Args:
            num_states: Number of observation components.
            num_actions: Number of selectable actions.
        """
        self.brain = Brain(num_states, num_actions)

    def update_Q_function(self, observation, action, reward, observation_next):
        """Hand one observed transition to the brain's Q-table update."""
        self.brain.update_Q_table(
            observation,
            action,
            reward,
            observation_next,
        )

    def get_action(self, observation, step):
        """Return the action the brain picks for ``observation``.

        ``step`` is forwarded unchanged as the second argument of
        ``Brain.decide_action``.
        """
        return self.brain.decide_action(observation, step)

In [281]:
class Brain:
    """Tabular Q-learning core: holds the Q-table and discretizes observations.

    Reads the module-level settings ``num_discrete``, ``eta`` and ``Gamma``.
    """

    def __init__(self, num_states, num_actions):
        """Allocate a Q-table filled with uniform random values.

        Args:
            num_states: Number of observation components; the table gets
                ``num_discrete ** num_states`` rows.
            num_actions: Number of actions; one table column per action.
        """
        self.num_actions = num_actions
        table_shape = (num_discrete ** num_states, num_actions)
        self.q_table = np.random.uniform(low=0, high=1, size=table_shape)

    # Bin edges used for discretization
    def discrete(self, min, max, num):
        """Return the interior edges of ``num`` equal-width bins over [min, max]."""
        edges = np.linspace(min, max, num + 1)
        # Both endpoints are dropped, so values outside the range land in
        # the outermost bins when passed to np.digitize.
        return edges[1:-1]

    def discrete_state(self, observation):
        """Map a 4-component observation to a single Q-table row index.

        Each component is binned into ``num_discrete`` levels and the four
        bin numbers are combined as digits of a base-``num_discrete`` number
        (first component is the least significant digit).
        """
        cart_pos, cart_v, pole_angle, pole_v = observation
        # (value, lower bound, upper bound) for every observation component
        components = (
            (cart_pos, -2.4, 2.4),
            (cart_v, -3.0, 3.0),
            (pole_angle, -0.5, 0.5),
            (pole_v, -2.0, 2.0),
        )

        state = 0
        for position, (value, low, high) in enumerate(components):
            edges = self.discrete(low, high, num_discrete)
            digit = np.digitize(value, bins=edges)
            state += digit * (num_discrete ** position)
        return state

    def update_Q_table(self, observation, action, reward, observation_next):
        """Apply one Q-learning update for the given transition.

        The entry for (state, action) moves by ``eta`` times the difference
        between ``reward + Gamma * max_a Q(next_state, a)`` and its old value.
        """
        state = self.discrete_state(observation)
        state_next = self.discrete_state(observation_next)

        best_next = max(self.q_table[state_next][:])
        td_error = reward + Gamma * best_next - self.q_table[state, action]
        self.q_table[state, action] = self.q_table[state, action] + eta * td_error

    def decide_action(self, observation, episode):
        """Choose an action epsilon-greedily.

        The exploration probability is ``0.5 / (episode + 1)``, so it shrinks
        as ``episode`` grows.
        """
        state = self.discrete_state(observation)
        epsilon = (0.5) * (1 / (episode + 1))

        # Author's note: training kept failing with random.uniform(0, 1);
        # switching to np.random.rand() made it work.
        draw = np.random.rand()
        if epsilon <= draw:
            # Exploit: best-valued action for this state.
            return np.argmax(self.q_table[state][:])
        # Explore: uniformly random action.
        return np.random.choice(self.num_actions)

In [282]:
class Environment:
    """Owns the Gym environment and the agent, and drives the training loop."""

    def __init__(self):
        """Create the environment named by ``Env`` and an agent sized to it."""
        # rgb_array rendering is requested so that render() output can be
        # collected as frames for the animation.
        self.env = gym.make(Env, render_mode="rgb_array")
        num_states = self.env.observation_space.shape[0]
        num_actions = self.env.action_space.n

        self.agent = Agent(num_states, num_actions)

    def run(self):
        """Train the agent, then record and display one successful episode.

        Reward scheme (the environment's own reward is ignored):
            * -1 when an episode ends before step index 200,
            * +1 when it ends at step index 200 or later,
            *  0 on every non-final step.

        After 15 consecutive +1 episodes the loop enters a recording phase:
        each following episode is rendered frame by frame, and the first one
        that ends at step index >= 300 is shown as an animation, which stops
        the run.
        """
        complete_episodes = 0  # consecutive episodes that earned +1
        is_final = False       # True once the recording phase has started

        for episode in range(num_episode):
            observation, _ = self.env.reset()
            # Start every episode with an empty buffer so that a recorded
            # attempt which fails the >= 300 gate below does not leak its
            # frames into the animation of a later attempt.
            frames = []

            for step in range(max_step):
                if is_final:
                    frames.append(self.env.render())

                # The episode index (not the step) drives epsilon decay.
                action = self.agent.get_action(observation, episode)

                observation_next, _, terminated, truncated, _ = self.env.step(action)
                # The episode is over on failure (terminated) and also when
                # the environment's time limit is hit (truncated); step()
                # must not be called again until reset() in either case.
                done = terminated or truncated

                if done:
                    if step < 200:
                        reward = -1
                        complete_episodes = 0
                    else:
                        reward = 1
                        complete_episodes += 1
                else:
                    reward = 0

                self.agent.update_Q_function(observation, action, reward, observation_next)

                observation = observation_next

                if done:
                    print('{0} Episode: Finished after {1} steps'.format(episode, step + 1))
                    break

            if is_final and step >= 300:
                display_frames_as_gif(frames)
                break
            if complete_episodes >= 15:
                print('連續成功15次!')
                is_final = True

In [None]:
# Main: build the environment/agent pair and run training until the success
# animation has been displayed or num_episode episodes have elapsed.
cartpole_env=Environment()
cartpole_env.run()