In [1]:
import gymnasium as gym
import random
import numpy as np
import collections
from tqdm import tqdm
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import rl_utils

In [2]:
class ReplayBuffer:
    def __init__(self,capacity):
        self.buffer=collections.deque(maxlen=capacity)
    def add(self,state,action,reward,next_state,done):
        self.buffer.append((state,action,reward,next_state,done))
    def sample(self,batchsize):
        transitions=random.sample(self.buffer,batch_size)
        state,action,reward,next_state,done=zip(*transitions)
        return np.array(state),action,reward,np.array(next_state),done
    def size(self):
        return len(self.buffer)

In [3]:
class Qnet(torch.nn.Module):
    def __init__(self,state_dim,hidden_dim,action_dim):
        super(Qnet, self).__init__()
        self.fc1=torch.nn.Linear(state_dim,hidden_dim)
        self.relu=torch.nn.ReLU()
        self.fc2=torch.nn.Linear(hidden_dim,action_dim)
    def forward(self,x):
        x=self.fc2(self.relu(self.fc1(x)))
        return x

In [4]:
net=Qnet(1,1,1)
print(net)

Qnet(
  (fc1): Linear(in_features=1, out_features=1, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=1, out_features=1, bias=True)
)


In [5]:
class DQN:
    def __init__(self,state_dim,hidden_dim,action_dim,learning_rate,gamma,epsilon,target_update,device):
        self.state_dim=state_dim
        self.hidden_dim=hidden_dim
        self.action_dim=action_dim
        self.learning_rate=learning_rate
        self.gamma=gamma
        self.epsilon=epsilon
        self.target_update=target_update
        self.device=device
        self.q_net=Qnet(self.state_dim,self.hidden_dim,self.action_dim).to(self.device)
        self.target_q_net=Qnet(self.state_dim,self.hidden_dim,self.action_dim).to(self.device)
        self.optimizer=torch.optim.Adam(params=self.q_net.parameters(),lr=self.learning_rate)
        self.count=0
    def take_action(self,state):
        if np.random.random()<self.epsilon:
            action=np.random.randint(self.action_dim)
        else:
            state=torch.tensor([state],dtype=torch.float).to(self.device)
            action=self.q_net(state).argmax().item()    #.item()将单个tensor转为标量
        return action
    def update(self,transition_dict):
        states=torch.tensor(transition_dict['states'],dtype=torch.float,device=self.device)
        actions=torch.tensor(transition_dict['actions'],device=self.device).view(-1,1)
        next_states=torch.tensor(transition_dict['next_states'],dtype=torch.float,device=self.device)
        rewards=torch.tensor(transition_dict['rewards'],device=self.device).view(-1,1)
        dones=torch.tensor(transition_dict['dones'],dtype=torch.float,device=self.device).view(-1,1)
        q_values=self.q_net(states).gather(1,actions)
        max_next_values=self.target_q_net(next_states).max(1)[0].view(-1,1)
        q_targets=rewards+gamma*max_next_values*(1-dones)
        dqn_loss=torch.mean(F.mse_loss(q_values,q_targets))
        self.optimizer.zero_grad()
        dqn_loss.backward()
        self.optimizer.step()
        if self.count%self.target_update == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())
        self.count+=1

In [6]:
env=gym.make('CartPole-v1')
state_dim=env.observation_space.shape[0]
hidden_dim=128
action_dim=env.action_space.n
lr=2e-3
gamma=0.98
epsilon=0.01
target_update=1
num_episodes=500
device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
random.seed(666)
np.random.seed(666)
torch.manual_seed(666)
buffer_size=10000
minimal_size=500
batch_size=64
agent=DQN(state_dim,hidden_dim,action_dim,lr,gamma,epsilon,target_update,device)
replay_buffer=ReplayBuffer(buffer_size)
return_list=[]

In [7]:
for i in range(10):
    with tqdm(total=int(num_episodes/10),desc='第%d轮'%i) as pbar:
        for i_episode in range(int(num_episodes/10)):
            episode_return=0
            state=env.reset()[0]
            done=False
            while not done:
                action=agent.take_action(state)
                next_state,reward,done,_,_=env.step(action)
                replay_buffer.add(state,action,reward,next_state,done)
                state=next_state
                episode_return+=reward
                if replay_buffer.size()>minimal_size:
                    b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
                    transition_dict={
                        "states":b_s,
                        "actions":b_a,
                        "rewards":b_r,
                        "next_states":b_ns,
                        "dones":b_d
                    }
                    agent.update(transition_dict)
            return_list.append(episode_return)
            if(i_episode+1)%10==0:
                pbar.set_postfix({
                    'episode':
                    '%d' % (num_episodes / 10 * i + i_episode + 1),
                    'return':
                    '%.3f' % np.mean(return_list[-10:])
                })
            pbar.update(1)
                    

  state=torch.tensor([state],dtype=torch.float).to(self.device)
第0轮: 100%|█████████████████████████████████████| 50/50 [00:01<00:00, 46.97it/s, episode=50, return=9.400]
第1轮: 100%|███████████████████████████████████| 50/50 [00:03<00:00, 13.90it/s, episode=100, return=18.700]
第2轮: 100%|███████████████████████████████████| 50/50 [00:03<00:00, 15.11it/s, episode=150, return=54.800]
第3轮: 100%|██████████████████████████████████| 50/50 [00:37<00:00,  1.33it/s, episode=200, return=331.300]
第4轮: 100%|██████████████████████████████████| 50/50 [00:40<00:00,  1.22it/s, episode=250, return=344.500]
第5轮: 100%|██████████████████████████████████| 50/50 [00:39<00:00,  1.26it/s, episode=300, return=305.500]
第6轮: 100%|██████████████████████████████████| 50/50 [00:35<00:00,  1.40it/s, episode=350, return=307.200]
第7轮: 100%|██████████████████████████████████| 50/50 [00:48<00:00,  1.03it/s, episode=400, return=300.100]
第8轮: 100%|██████████████████████████████████| 50/50 [00:42<00:00,  1.16it/s, episode=450

In [None]:
episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('DQN on {}'.format(env_name))
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('DQN on {}'.format(env_name))
plt.show()

In [6]:
lr = 2e-3
num_episodes = 500
hidden_dim = 128
gamma = 0.98
epsilon = 0.01
target_update = 10
buffer_size = 10000
minimal_size = 500
batch_size = 64
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")

env_name = 'CartPole-v1'
env = gym.make(env_name)
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
replay_buffer = ReplayBuffer(buffer_size)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
state_dim,action_dim
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon,
            target_update, device)

return_list = []
for i in range(10):
    with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
        for i_episode in range(int(num_episodes / 10)):
            episode_return = 0
            state = env.reset()[0]
            done = False
            while not done:
                
                action = agent.take_action(state)
                next_state, reward, done, _,_ = env.step(action)
                replay_buffer.add(state, action, reward, next_state, done)
                state = next_state
                
                episode_return += reward
                # 当buffer数据的数量超过一定值后,才进行Q网络训练
                if replay_buffer.size() > minimal_size:
                    b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
                    transition_dict = {
                        'states': b_s,
                        'actions': b_a,
                        'next_states': b_ns,
                        'rewards': b_r,
                        'dones': b_d
                    }
                    agent.update(transition_dict)
            return_list.append(episode_return)
            if (i_episode + 1) % 10 == 0:
                pbar.set_postfix({
                    'episode':
                    '%d' % (num_episodes / 10 * i + i_episode + 1),
                    'return':
                    '%.3f' % np.mean(return_list[-10:])
                })
            pbar.update(1)

  state=torch.tensor([state],dtype=torch.float).to(self.device)
Iteration 0: 100%|███████████████████████████████| 50/50 [00:01<00:00, 48.87it/s, episode=50, return=9.400]
Iteration 1: 100%|██████████████████████████████| 50/50 [00:01<00:00, 49.69it/s, episode=100, return=9.900]
Iteration 2: 100%|█████████████████████████████| 50/50 [00:02<00:00, 22.26it/s, episode=150, return=38.200]
Iteration 3: 100%|████████████████████████████| 50/50 [00:12<00:00,  4.14it/s, episode=200, return=148.500]
Iteration 4:   4%|██▎                                                       | 2/50 [00:00<00:15,  3.03it/s]


KeyboardInterrupt: 

In [8]:
n=torch.rand(1,2)
n

tensor([[0.9274, 0.3905]])

In [11]:
n.max(1)[1].item()

0