In [1]:
import sys

sys.argv = ['']
sys.path.append("../..")
from src.grid_world import GridWorld
from examples.agent import Agent
import torch.nn as nn
from torch.optim import Optimizer
from torch import optim
import random
import numpy as np
import collections
from tqdm import tqdm
import torch
import torch.nn.functional as F

## Q Network

输入：状态（二维，归一化）与动作（由 (dx, dy) 转换得到的 one-hot 编码，维度为动作个数）拼接而成的向量

输出：Q值

In [2]:
class QNet(nn.Module):
    """Three-layer fully connected network that scores (state, action) rows.

    Every input row has width ``states_dim + actions_num``; the network
    returns one scalar Q-value per row.
    """

    def __init__(self, states_dim, actions_num, hidden_dim=128):
        """Build the linear layers.

        :param states_dim: width of the state part of the input (the
            normalized state)
        :param actions_num: width of the action part of the input (the
            one-hot encoded action)
        :param hidden_dim: width of both hidden layers
        """
        super().__init__()
        input_width = states_dim + actions_num
        self.fc1 = nn.Linear(input_width, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Map a batch of concatenated state-action rows to Q-values.

        :param x: tensor whose last dimension is ``states_dim + actions_num``
        :return: tensor with the same leading dimensions and a last
            dimension of 1
        """
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

## 经验回放缓冲区


In [3]:
class ReplayBuffer:
    """Bounded FIFO store of SARSA transitions.

    Each entry is the tuple
    ``(state, action, reward, next_state, next_action, done)``. The backing
    deque is created with ``maxlen=capacity``, so once it is full every new
    entry pushes out the oldest one.
    """

    def __init__(self, capacity):
        """:param capacity: maximum number of transitions kept at once"""
        self.buffer = collections.deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, next_action, done):
        """Append a single transition to the end of the buffer."""
        transition = (state, action, reward, next_state, next_action, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions with ``random.sample``.

        :param batch_size: number of transitions to draw
        :return: tuple of six numpy arrays, in the order states, actions,
            rewards, next_states, next_actions, dones; row ``i`` of every
            array comes from the same transition
        """
        picked = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into six per-field columns.
        s, a, r, s_next, a_next, d = zip(*picked)
        return tuple(np.array(column) for column in (s, a, r, s_next, a_next, d))

    def size(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

    def clear(self):
        """Remove every stored transition."""
        self.buffer.clear()

### 测试经验回放缓冲区

In [4]:
# Test cell
# Smoke test for ReplayBuffer: store two transitions, draw one at random,
# then empty the buffer and confirm it reports size 0.
buffer = ReplayBuffer(capacity=100)
# Argument order: state, action, reward, next_state, next_action, done.
buffer.add((0,0), (0, 1), 10, (0,1), (0, 1), False)
buffer.add((0,1), (1, 1), 10, (1,1), (1, 1), False)
# sample(1) returns six numpy arrays, each holding one row; which of the two
# transitions is drawn is random.
print(buffer.sample(1))
buffer.clear()
# Prints 0 once the buffer has been cleared.
print(buffer.size())

(array([[0, 1]]), array([[1, 1]]), array([10]), array([[1, 1]]), array([[1, 1]]), array([False]))
0


## Sarsa Value Agent

In [5]:
class SarsaValue(Agent):
    """SARSA agent whose action values are approximated by a ``QNet``.

    Each iteration of ``run`` collects one episode with the current policy,
    fits the network to the one-step SARSA targets of that episode, and then
    refreshes the tabular ``Q``, ``policy`` and ``V`` arrays from the network.
    """

    def __init__(self,
                 env, epsilon=0.1, gamma=0.99, alpha=0.1,
                 num_episodes=10, episode_length=1024,
                 batch_size=256, state_dim=2, action_dim=2,
                 num_epochs=10
                 ):
        """Set up the network, optimizer, value tables and replay buffer.

        :param env: environment; this class reads ``env_size``,
            ``num_actions``, ``action_space`` and ``start_state`` and calls
            ``reset()`` and ``step(action)`` on it
        :param epsilon: exploration rate of the epsilon-greedy policy
        :param gamma: discount factor used in the TD target
        :param alpha: learning rate passed to the Adam optimizer
        :param num_episodes: number of episodes collected by ``run``
        :param episode_length: maximum number of steps per episode
        :param batch_size: number of transitions sampled per gradient step
            (capped at the number of transitions in the buffer)
        :param state_dim: width of the state part of the network input;
            ``state_action_to_tensor`` normalizes columns 0 and 1 only
        :param action_dim: stored on the instance; not otherwise used here
        :param num_epochs: number of gradient steps taken per episode
        """
        super().__init__(
            env=env,
            epsilon=epsilon,
            gamma=gamma,
            num_episodes=num_episodes,
            episode_length=episode_length
        )
        self.env = env
        self.epsilon = epsilon
        self.gamma = gamma
        self.num_episodes = num_episodes
        self.episode_length = episode_length
        self.alpha = alpha

        self.x_col = int(env.env_size[0])
        self.y_row = int(env.env_size[1])
        self.num_actions = env.num_actions
        self.action_space = env.action_space
        self.state_dim = state_dim
        self.action_dim = action_dim

        # The network consumes the state concatenated with a one-hot action,
        # hence actions_num is the number of actions rather than action_dim.
        self.QNet = QNet(states_dim=state_dim, actions_num=self.num_actions, hidden_dim=128)
        self.optimizer = optim.Adam(self.QNet.parameters(), lr=self.alpha)
        self.loss_fn = nn.MSELoss()
        self.num_epochs = num_epochs

        # Tabular views refreshed from the network after every episode.
        self.V = np.zeros((self.x_col, self.y_row))
        self.Q = np.zeros((self.x_col, self.y_row, self.num_actions))
        # Start from the uniform random policy.
        self.policy = np.ones((self.x_col, self.y_row, self.num_actions)) / self.num_actions

        self.batch_size = batch_size
        self.buffer = ReplayBuffer(capacity=episode_length*num_episodes)

    def take_action(self, state):
        """Sample an action for ``state`` from the current policy table.

        :param state: ``(x, y)`` pair used directly as an index into
            ``self.policy``
        :return: the entry of ``self.action_space`` that was drawn
        """
        x, y = state  # state = (x, y)
        probs = self.policy[x, y]
        action_idx = np.random.choice(np.arange(self.num_actions), p=probs)
        return self.action_space[action_idx]

    def generate_episode(self):
        """Roll out one episode and store its SARSA tuples in the buffer.

        The environment is reset and the rollout starts from
        ``env.start_state``. It ends after ``episode_length`` steps or as
        soon as ``env.step`` reports ``done``.

        :return: the replay buffer the transitions were written to
        """
        self.env.reset()
        s = self.env.start_state
        a = self.take_action(s)
        for _ in range(self.episode_length):
            s_next, reward, done, _info = self.env.step(a)
            # SARSA is on-policy: the next action is chosen now and stored
            # with the transition so the TD target can use it.
            a_next = self.take_action(s_next)
            self.buffer.add(s, a, reward, s_next, a_next, done)
            s = s_next
            a = a_next
            if done:
                break
        return self.buffer

    def action2onehot(self, actions):
        """Convert actions to one-hot rows.

        :param actions: iterable of actions, each equal (as a tuple) to an
            entry of ``self.action_space``
        :return: array of shape ``(len(actions), num_actions)`` with a 1 at
            the position of each action within ``self.action_space``
        """
        index_of = {action: idx for idx, action in enumerate(self.action_space)}
        indices = [index_of[tuple(action)] for action in actions]
        return np.eye(self.num_actions)[indices]

    def state_action_to_tensor(self, states, actions):
        """Build network input rows from states and actions.

        States are scaled by the largest grid index on each axis and actions
        are one-hot encoded; both are then concatenated column-wise.

        :param states: sequence of ``(x, y)`` pairs
        :param actions: sequence of actions, one per state
        :return: float tensor of shape ``(len(states), 2 + num_actions)``
            that can be passed straight to ``self.QNet``
        """
        states = np.array(states, dtype=np.float32)
        # A grid that is one cell wide has a largest index of 0; clamp the
        # divisor to 1 so the coordinate stays 0 instead of becoming NaN.
        states[:, 0] = states[:, 0] / max(self.x_col - 1, 1)
        states[:, 1] = states[:, 1] / max(self.y_row - 1, 1)
        states_tensor = torch.tensor(states, dtype=torch.float)

        actions_onehot = self.action2onehot(actions)   # (B, num_actions)
        actions_tensor = torch.tensor(actions_onehot, dtype=torch.float)

        return torch.cat([states_tensor, actions_tensor], dim=1)

    def update_action_value(self):
        """Evaluate the network on every (state, action) pair into ``self.Q``.

        All pairs are scored in a single forward pass. ``self.Q[x, y, i]``
        holds the value of ``self.action_space[i]`` in state ``(x, y)``, the
        same indexing ``take_action`` uses for ``self.policy``.

        :return: ``self.Q`` (updated in place)
        """
        grid = [(x, y) for x in range(self.x_col) for y in range(self.y_row)]
        # Rows are ordered by x, then y, then position in action_space, so
        # the flat network output reshapes directly onto self.Q.
        states = [cell for cell in grid for _ in self.action_space]
        actions = [action for _ in grid for action in self.action_space]
        if not states:
            return self.Q

        with torch.no_grad():
            q_values = self.QNet(self.state_action_to_tensor(states, actions))
        self.Q[:] = q_values.numpy().reshape(self.x_col, self.y_row, self.num_actions)
        return self.Q

    def update_policy(self):
        """Make ``self.policy`` epsilon-greedy with respect to ``self.Q``.

        The greedy action of each state gets probability
        ``1 - epsilon + epsilon / num_actions``; every other action gets
        ``epsilon / num_actions``. Ties are broken in favour of the lowest
        action index.

        :return: ``self.policy`` (updated in place)
        """
        explore_prob = self.epsilon / self.num_actions
        greedy_prob = 1 - self.epsilon + explore_prob
        best_actions = np.argmax(self.Q, axis=2)
        self.policy[:] = explore_prob
        np.put_along_axis(self.policy, best_actions[..., np.newaxis], greedy_prob, axis=2)
        return self.policy

    def update_state_value(self):
        """Set ``self.V`` to the policy-weighted sum of ``self.Q`` per state.

        :return: ``self.V`` (updated in place)
        """
        self.V[:] = np.sum(self.policy * self.Q, axis=2)
        return self.V

    def update_QNet(self):
        """Fit the network to SARSA targets drawn from the replay buffer.

        Runs ``num_epochs`` gradient steps. Each step samples
        ``batch_size`` transitions (or the whole buffer when it holds fewer)
        and minimizes the MSE between ``Q(s, a)`` and
        ``r + gamma * Q(s', a') * (1 - done)``.

        :return: mean loss over the gradient steps, or ``0.0`` when there is
            nothing to train on
        """
        buffer_size = self.buffer.size()
        if buffer_size == 0 or self.num_epochs <= 0:
            return 0.0
        # Never request more transitions than are stored, and at least one.
        batch_size = max(1, min(self.batch_size, buffer_size))

        loss_total = 0.0
        for _ in range(self.num_epochs):
            states, actions, rewards, next_states, next_actions, dones = self.buffer.sample(batch_size)

            # Convert the sampled columns to tensors.
            states_actions = self.state_action_to_tensor(states, actions)
            rewards_tensor = torch.tensor(rewards, dtype=torch.float).view(-1, 1)
            next_states_actions = self.state_action_to_tensor(next_states, next_actions)
            dones_tensor = torch.tensor(dones, dtype=torch.float).view(-1, 1)

            # Current estimate Q(s, a).
            q_values = self.QNet(states_actions)

            # TD target; no gradient flows through the bootstrap term, and
            # terminal transitions do not bootstrap.
            with torch.no_grad():
                q_next_values = self.QNet(next_states_actions)
                td_target = rewards_tensor + self.gamma * q_next_values * (1 - dones_tensor)

            loss = self.loss_fn(q_values, td_target)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            loss_total += loss.item()

        return loss_total / self.num_epochs

    def run(self):
        """Train for ``num_episodes`` episodes.

        Every episode clears the buffer, collects a fresh episode, updates
        the network, and refreshes ``Q``, ``policy`` and ``V``. The mean loss
        is printed every 10 episodes.
        """
        loss = 0
        for episode in range(self.num_episodes):
            # Keep training on-policy: drop data from earlier policies.
            self.buffer.clear()
            # Collect one episode into the buffer.
            self.generate_episode()
            # Several mini-batch gradient steps on that episode.
            loss += self.update_QNet()
            # Refresh the tabular Q estimate from the network.
            self.update_action_value()
            # Make the policy epsilon-greedy w.r.t. the new Q.
            self.update_policy()
            # Refresh the state values.
            self.update_state_value()
            if (episode + 1) % 10 == 0:
                print(f"Episode {episode + 1} finished, loss: {loss/10:.4f}")
                loss = 0




In [7]:
# Seed every random number generator the training loop draws from so that a
# Restart & Run All reproduces the same result: `random` drives replay-buffer
# sampling, `np.random` drives action selection, and torch drives the
# network's weight initialisation.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Sparse reward: nothing per step, 100 on reaching the target.
env = GridWorld()
env.reward_step = 0
env.reward_target = 100

agent = SarsaValue(env, epsilon=0.1, gamma=0.99, num_episodes=1000, episode_length=1000, batch_size=512, num_epochs=100, alpha=1e-3)
agent.run()
# agent.render_static()

print("Final Policy:")
print(agent.get_policy())

Episode 10 finished.
Episode 20 finished.
Episode 30 finished.
Episode 40 finished.
Episode 50 finished.
Episode 60 finished.
Episode 70 finished.
Episode 80 finished.
Episode 90 finished.
Episode 100 finished.
Final Policy:
[[[0.02 0.92 0.02 0.02 0.02]
  [0.2  0.2  0.2  0.2  0.2 ]
  [0.2  0.2  0.2  0.2  0.2 ]
  [0.2  0.2  0.2  0.2  0.2 ]
  [0.2  0.2  0.2  0.2  0.2 ]
  [0.2  0.2  0.2  0.2  0.2 ]
  [0.2  0.2  0.2  0.2  0.2 ]]

 [[0.02 0.92 0.02 0.02 0.02]
  [0.2  0.2  0.2  0.2  0.2 ]
  [0.2  0.2  0.2  0.2  0.2 ]
  [0.2  0.2  0.2  0.2  0.2 ]
  [0.2  0.2  0.2  0.2  0.2 ]
  [0.2  0.2  0.2  0.2  0.2 ]
  [0.02 0.92 0.02 0.02 0.02]]

 [[0.02 0.92 0.02 0.02 0.02]
  [0.2  0.2  0.2  0.2  0.2 ]
  [0.02 0.92 0.02 0.02 0.02]
  [0.02 0.02 0.92 0.02 0.02]
  [0.2  0.2  0.2  0.2  0.2 ]
  [0.2  0.2  0.2  0.2  0.2 ]
  [0.02 0.92 0.02 0.02 0.02]]

 [[0.02 0.92 0.02 0.02 0.02]
  [0.2  0.2  0.2  0.2  0.2 ]
  [0.02 0.92 0.02 0.02 0.02]
  [0.2  0.2  0.2  0.2  0.2 ]
  [0.2  0.2  0.2  0.2  0.2 ]
  [0.2  0.2  0.