  `CartPole` 这个环境很容易只在某一个位置学到最优解，但小车移动到其他位置之后，就不会保持平衡了。
  
  所以这里把全部 episode 分成 10 个大块，再限制每个 episode 的最长时间，这样在其他位置也能搜到最优解。
  
  最开始的时候，我没卡时间，参数也没设置对，Return 图的方差特别大。

In [1]:
import gymnasium as gym
import torch
from ReplayBuffer import ReplayBuffer
from collections import defaultdict
from DQNAgent import DQNAgent
from tqdm import tqdm
import numpy as np
from lib.utils.draw import draw_line

In [2]:
# ---- Hyperparameters ---------------------------------------------------
# The first group is forwarded to DQNAgent; the second configures the replay buffer.
lr, hidden_dim = 1e-3, 128
gamma, epsilon = 0.98, 0.01
target_replace = 10  # presumably the target-network sync interval -- confirm in DQNAgent

buffer_size = 10000  # capacity handed to ReplayBuffer
minimal_size = 500   # agent updates start only once len(memory) exceeds this
batch_size = 64      # number of transitions requested per memory.sample() call

SEED = 0

# ---- Environment and global RNG seeding --------------------------------
env = gym.make('CartPole-v1')
torch.manual_seed(SEED)
np.random.seed(SEED)

# ---- Replay buffer, agent and statistics container ---------------------
memory = ReplayBuffer(buffer_size)
state_dim, action_dim = env.observation_space.shape[0], env.action_space.n

agent = DQNAgent(
    state_dim,
    hidden_dim,
    action_dim,
    lr,
    gamma,
    epsilon,
    target_replace,
    SEED,
)

# Maps a statistic name to a list with one entry appended per episode.
stats = defaultdict(list)

In [3]:
# Training: `num_episodes` episodes split into `num_part` equally sized parts.
# Each part gets its own progress bar, and the within-part episode counter is
# what is handed to agent.get_action(), so that counter restarts every part.
env = gym.make('CartPole-v1')
num_episodes = 500
num_part = 10
episodes_per_part = num_episodes // num_part

# One environment reset seed per training episode.
seed_list = np.random.randint(num_episodes * 10, size=num_episodes)
env_seed_list = [int(s) for s in seed_list]

for i in range(num_part):
    with tqdm(total=episodes_per_part, desc='Iteration %d' % i) as pbar:
        for i_episode in range(episodes_per_part):
            # Index of this episode across all parts. Indexing the seed list
            # with `i_episode` alone would reuse the first `episodes_per_part`
            # seeds in every part and leave the rest of the list unused.
            global_episode = i * episodes_per_part + i_episode

            state = np.array(env.reset(seed=env_seed_list[global_episode])[0])
            done = False
            stats_rewards = 0
            stats_steps = 0
            while not done:
                action = agent.get_action(state, i_episode, num_episodes / num_part)
                # env.step() returns a 5-tuple:
                # (observation, reward, terminated, truncated, info)
                next_state, reward, terminated, truncated, _ = env.step(action)
                # Only `terminated` is stored as the done flag (same value the
                # buffer received before); presumably the agent uses it to
                # decide whether to bootstrap from next_state -- confirm in
                # DQNAgent.update.
                memory.add(state, action, reward, next_state, terminated)
                state = next_state
                # The episode must also stop on truncation (time limit);
                # otherwise a policy that never fails keeps this loop running
                # indefinitely.
                done = terminated or truncated

                # stats update
                stats_rewards += reward
                stats_steps += 1

                # Learn only once the buffer holds more than minimal_size items.
                if len(memory) > minimal_size:
                    b_s, b_a, b_r, b_ns, b_d = memory.sample(batch_size)
                    transition_dict = {
                        'states': b_s,
                        'actions': b_a,
                        'rewards': b_r,
                        'next_states': b_ns,
                        'dones': b_d,
                    }
                    agent.update(transition_dict)

            stats['rewards'].append(stats_rewards)
            stats['steps'].append(stats_steps)

            # Refresh the progress-bar postfix every 10 episodes with the
            # global episode number and the mean return of the last 10 episodes.
            if (i_episode + 1) % 10 == 0:
                pbar.set_postfix({
                    'episode': '%d' % (global_episode + 1),
                    'return': '%.3f' % np.mean(stats['rewards'][-10:]),
                })
            pbar.update(1)



In [None]:
# Visualise the training statistics collected in `stats` (per-episode
# 'rewards' and 'steps' lists). draw_line is imported from lib.utils.draw;
# presumably it renders them as line plots -- confirm in that module.
draw_line(stats)

In [None]:
# Evaluation: run the trained agent for `test_episode` rendered episodes and
# record each episode's total reward and step count in STATS.
env = gym.make('CartPole-v1', render_mode='human')

test_episode = 10
STATS = defaultdict(list)
try:
    for _ in range(test_episode):
        stats_rewards = 0
        stats_steps = 0
        state = env.reset()[0]
        done = False
        while not done:
            action = agent.predict_action(state)
            # env.step() returns (observation, reward, terminated, truncated, info)
            state, reward, terminated, truncated, _ = env.step(action)
            # Stop on truncation (time limit) as well as termination; a policy
            # that never fails would otherwise keep this loop running forever.
            done = terminated or truncated

            # stats update
            stats_rewards += reward
            stats_steps += 1

        STATS['rewards'].append(stats_rewards)
        STATS['steps'].append(stats_steps)
finally:
    # Release the render window even if an episode raises.
    env.close()

draw_line(STATS)