In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import gym
import random
from collections import deque
from tqdm import tqdm
import rl_utils

In [4]:
class QNet(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(QNet, self).__init__()
        self.fc1 = nn.Linear(state_dim, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, action_dim)
        
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        return x

In [5]:
class DQN:
    def __init__(self, state_dim, action_dim, target_update, device, gamma=0.99, lr=1e-3):
        self.device = device
        self.action_dim = action_dim
        self.gamma = gamma
        
        self.q_net = QNet(state_dim, action_dim).to(device)
        self.target_q_net = QNet(state_dim, action_dim).to(device)
        self.target_q_net.load_state_dict(self.q_net.state_dict())
        
        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)
        
        self.target_update = target_update #Target network update interval. If target_update = t,then when q_net is updated t times, target_q_net is updated once
        self.count = 0 #Counter.  record q_net update times

    def take_action(self, state, epsilon):
        # epsilon-greedy
        ############################
        # YOUR IMPLEMENTATION HERE #
        ############################

        if np.random.random() < epsilon:
            # 探索：随机选择动作
            action = np.random.randint(0, self.action_dim)
        else:
            # 利用：选择Q值最大的动作
            state = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device)
            with torch.no_grad():
                q_values = self.q_net(state)
                action = q_values.argmax().item()
        return action
         
    
    def update(self, replay_buffer, batch_size):
        states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)
        
        states = torch.tensor(states, dtype=torch.float32).to(self.device)
        actions = torch.tensor(actions, dtype=torch.int64).to(self.device)
        rewards = torch.tensor(rewards, dtype=torch.float32).to(self.device)
        next_states = torch.tensor(next_states, dtype=torch.float32).to(self.device)
        dones = torch.tensor(dones, dtype=torch.float32).to(self.device)
        
        ############################
        # YOUR IMPLEMENTATION HERE #
        ############################
        # 从当前策略网络中获取Q值
        q_values = self.q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        
        # 从目标网络中获取下一个状态的最大Q值
        with torch.no_grad():
            next_q_values = self.target_q_net(next_states).max(1)[0]
            # 计算目标Q值：r + γ * max_a' Q(s', a')，如果是终止状态则只有即时奖励
            target_q_values = rewards + self.gamma * next_q_values * (1 - dones)
        
        # 计算损失：均方误差
        loss = nn.MSELoss()(q_values, target_q_values)
        
        # 更新策略网络
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # 检查是否需要更新目标网络
        if self.count % self.target_update == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())
            
            
        self.count += 1
        return loss.item()
    

In [6]:
# One-hot
def one_hot(state, state_dim):
    vec = np.zeros(state_dim, dtype=np.float32)
    vec[state] = 1.0
    return vec

In [None]:
env_name = "Taxi-v2"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
env = gym.make(env_name)
env.seed(0)
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)


obs_space = env.observation_space.n
action_space = env.action_space.n

# parameters
total_episodes = 5000
episodes_per_iteration = 100
iterations = total_episodes // episodes_per_iteration

batch_size = 64
buffer_size = 20000
min_buffer_size = 5000
epsilon_start = 1.0
epsilon_end = 0.1
epsilon_decay = 0.999  
update_freq = 1
target_update = 10

replay_buffer = rl_utils.ReplayBuffer(buffer_size)
agent = DQN(obs_space, action_space, target_update, device, gamma=0.99, lr=0.0001)

return_list = []

epsilon = epsilon_start
total_steps = 0

# tqdm
for i_iter in range(iterations):
    with tqdm(range(episodes_per_iteration), desc=f"Iteration {i_iter}", ncols=100) as pbar:
        for i_episode_in_iter in pbar:
            i_episode = i_iter * episodes_per_iteration + i_episode_in_iter
            state = env.reset()
            state_vec = one_hot(state, obs_space)
            done = False
            episode_reward = 0
            episode_length = 0
            
            while not done:
                total_steps += 1
                episode_length += 1
                action = agent.take_action(state_vec, epsilon)
                next_state, reward, done, info = env.step(action)
                next_state_vec = one_hot(next_state, obs_space)
                
                replay_buffer.add(state_vec, action, reward, next_state_vec, done)
                state_vec = next_state_vec
                episode_reward += reward
                
                if replay_buffer.size() > min_buffer_size and total_steps % update_freq == 0:
                    loss = agent.update(replay_buffer, batch_size)
            
            epsilon = max(epsilon_end, epsilon * epsilon_decay)
            return_list.append(episode_reward)
    
    avg_return = np.mean(return_list[-episodes_per_iteration:])
    print(f"Episode: {(i_iter+1)*episodes_per_iteration}, Average Return: {avg_return:.2f}")


torch.save(agent.q_net.state_dict(), "dqn_taxi.pth")

  result = entry_point.load(False)
Iteration 0: 100%|████████████████████████████████████████████████| 100/100 [01:34<00:00,  1.06it/s]


Episode: 100, Average Return: -757.83


Iteration 1: 100%|████████████████████████████████████████████████| 100/100 [01:45<00:00,  1.06s/it]


Episode: 200, Average Return: -558.22


Iteration 2: 100%|████████████████████████████████████████████████| 100/100 [00:48<00:00,  2.06it/s]


Episode: 300, Average Return: -218.72


Iteration 3: 100%|████████████████████████████████████████████████| 100/100 [00:31<00:00,  3.18it/s]


Episode: 400, Average Return: -113.95


Iteration 4: 100%|████████████████████████████████████████████████| 100/100 [00:27<00:00,  3.64it/s]


Episode: 500, Average Return: -92.80


Iteration 5: 100%|████████████████████████████████████████████████| 100/100 [00:23<00:00,  4.31it/s]


Episode: 600, Average Return: -65.02


Iteration 6: 100%|████████████████████████████████████████████████| 100/100 [00:18<00:00,  5.49it/s]


Episode: 700, Average Return: -42.08


Iteration 7: 100%|████████████████████████████████████████████████| 100/100 [00:18<00:00,  5.32it/s]


Episode: 800, Average Return: -40.24


Iteration 8: 100%|████████████████████████████████████████████████| 100/100 [00:15<00:00,  6.35it/s]


Episode: 900, Average Return: -28.12


Iteration 9: 100%|████████████████████████████████████████████████| 100/100 [00:14<00:00,  6.78it/s]


Episode: 1000, Average Return: -22.64


Iteration 10: 100%|███████████████████████████████████████████████| 100/100 [00:15<00:00,  6.65it/s]


Episode: 1100, Average Return: -19.51


Iteration 11: 100%|███████████████████████████████████████████████| 100/100 [00:13<00:00,  7.57it/s]


Episode: 1200, Average Return: -13.82


Iteration 12: 100%|███████████████████████████████████████████████| 100/100 [00:12<00:00,  7.83it/s]


Episode: 1300, Average Return: -10.74


Iteration 13: 100%|███████████████████████████████████████████████| 100/100 [00:13<00:00,  7.60it/s]


Episode: 1400, Average Return: -8.69


Iteration 14: 100%|███████████████████████████████████████████████| 100/100 [00:11<00:00,  8.38it/s]


Episode: 1500, Average Return: -6.02


Iteration 15: 100%|███████████████████████████████████████████████| 100/100 [00:10<00:00,  9.10it/s]


Episode: 1600, Average Return: -4.01


Iteration 16: 100%|███████████████████████████████████████████████| 100/100 [00:11<00:00,  8.73it/s]


Episode: 1700, Average Return: -3.96


Iteration 17: 100%|███████████████████████████████████████████████| 100/100 [00:11<00:00,  8.81it/s]


Episode: 1800, Average Return: -0.21


Iteration 18: 100%|███████████████████████████████████████████████| 100/100 [00:11<00:00,  8.89it/s]


Episode: 1900, Average Return: -1.90


Iteration 19: 100%|███████████████████████████████████████████████| 100/100 [00:10<00:00,  9.16it/s]


Episode: 2000, Average Return: 0.82


Iteration 20: 100%|███████████████████████████████████████████████| 100/100 [00:10<00:00,  9.28it/s]


Episode: 2100, Average Return: 1.44


Iteration 21: 100%|███████████████████████████████████████████████| 100/100 [00:10<00:00,  9.53it/s]


Episode: 2200, Average Return: 2.73


Iteration 22: 100%|███████████████████████████████████████████████| 100/100 [00:10<00:00,  9.72it/s]


Episode: 2300, Average Return: 4.05


Iteration 23: 100%|███████████████████████████████████████████████| 100/100 [00:10<00:00,  9.74it/s]


Episode: 2400, Average Return: 3.05


Iteration 24: 100%|███████████████████████████████████████████████| 100/100 [00:09<00:00, 10.11it/s]


Episode: 2500, Average Return: 3.41


Iteration 25: 100%|███████████████████████████████████████████████| 100/100 [00:09<00:00, 10.07it/s]


Episode: 2600, Average Return: 3.60


Iteration 26: 100%|███████████████████████████████████████████████| 100/100 [00:09<00:00, 10.11it/s]


Episode: 2700, Average Return: 4.13


Iteration 27: 100%|███████████████████████████████████████████████| 100/100 [00:10<00:00,  9.67it/s]


Episode: 2800, Average Return: 3.21


Iteration 28: 100%|███████████████████████████████████████████████| 100/100 [00:10<00:00,  9.53it/s]


Episode: 2900, Average Return: 2.52


Iteration 29: 100%|███████████████████████████████████████████████| 100/100 [00:09<00:00, 10.00it/s]


Episode: 3000, Average Return: 3.37


Iteration 30: 100%|███████████████████████████████████████████████| 100/100 [00:10<00:00,  9.62it/s]


Episode: 3100, Average Return: 2.83


Iteration 31: 100%|███████████████████████████████████████████████| 100/100 [00:10<00:00,  9.52it/s]


Episode: 3200, Average Return: 2.67


Iteration 32: 100%|███████████████████████████████████████████████| 100/100 [00:10<00:00,  9.35it/s]


Episode: 3300, Average Return: 2.08


Iteration 33: 100%|███████████████████████████████████████████████| 100/100 [00:10<00:00,  9.68it/s]


Episode: 3400, Average Return: 3.46


Iteration 34: 100%|███████████████████████████████████████████████| 100/100 [00:10<00:00,  9.54it/s]


Episode: 3500, Average Return: 2.94


Iteration 35: 100%|███████████████████████████████████████████████| 100/100 [00:09<00:00, 10.00it/s]


Episode: 3600, Average Return: 3.72


Iteration 36: 100%|███████████████████████████████████████████████| 100/100 [00:11<00:00,  9.06it/s]


Episode: 3700, Average Return: 3.41


Iteration 37: 100%|███████████████████████████████████████████████| 100/100 [00:10<00:00,  9.60it/s]


Episode: 3800, Average Return: 4.04


Iteration 38: 100%|███████████████████████████████████████████████| 100/100 [00:10<00:00,  9.24it/s]


Episode: 3900, Average Return: 3.61


Iteration 39: 100%|███████████████████████████████████████████████| 100/100 [00:10<00:00,  9.30it/s]


Episode: 4000, Average Return: 4.51


Iteration 40: 100%|███████████████████████████████████████████████| 100/100 [00:11<00:00,  8.85it/s]


Episode: 4100, Average Return: 2.77


Iteration 41: 100%|███████████████████████████████████████████████| 100/100 [00:11<00:00,  8.97it/s]


Episode: 4200, Average Return: 2.89


Iteration 42: 100%|███████████████████████████████████████████████| 100/100 [00:11<00:00,  8.96it/s]


Episode: 4300, Average Return: 3.43


Iteration 43: 100%|███████████████████████████████████████████████| 100/100 [00:11<00:00,  8.87it/s]


Episode: 4400, Average Return: 3.58


Iteration 44: 100%|███████████████████████████████████████████████| 100/100 [00:10<00:00,  9.31it/s]


Episode: 4500, Average Return: 2.80


Iteration 45:  27%|████████████▉                                   | 27/100 [00:02<00:06, 10.65it/s]

In [None]:
fig, axes = plt.subplots(1, 1, figsize=(5, 5))

# rewards
mv_return = rl_utils.moving_average(return_list, 9)
axes.plot(mv_return)
axes.set_title("Episode Rewards")
axes.set_xlabel("Episode")
axes.set_ylabel("Reward")
axes.grid(True)

plt.tight_layout()
plt.show()