### A2C（Advantage Actor-Critic）

![](assets/305.jpg)

### 例子

![](assets/256.jpg)

### 上述例子代码实现

例子：
基于 gym 仿真库的 `gym.Env` 接口自定义了一个 5×5 网格世界环境（思路与 gym 自带的 FrozenLake 环境类似），FrozenLake 文档：https://www.gymlibrary.dev/environments/toy_text/frozen_lake/

In [2]:
!pip install gym==0.15.4
!pip install numpy
!pip install torch
!pip install tqdm

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple


In [None]:
import math
import random
from copy import deepcopy
from collections import namedtuple
import numpy as np
from tqdm import tqdm
import gym
from gym import spaces
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F


# A single environment transition (s_t, a_t, r_{t+1}, s_{t+1}).
one_step_experience = namedtuple(
    'one_step_experience',
    ['current_observation', 'current_action', 'reward', 'next_observation'],
)


class CustomGridWorld(gym.Env):
    """Deterministic grid-world environment using the classic gym step/reset API.

    Observations are flattened cell indices ``row * cols + col``; ``self.state``
    holds the corresponding ``(row, col)`` tuple.

    Actions: 0 = up, 1 = right, 2 = down, 3 = left, 4 = stay. Any other action
    value also leaves the agent in place.

    Rewards for a step:
      * moving onto ``goal_position`` -> ``tgt_grid_reward`` and ``done`` becomes True;
      * moving onto a cell in ``forbidden_grids`` -> ``forbidden_grids_penalty``
        (the cell is entered and the episode continues);
      * trying to move off the grid -> ``wall_penalty`` and the agent stays put;
      * any other move (including "stay") -> ``step_penalty``.
    """

    # (d_row, d_col) for every known action id.
    _MOVES = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1), 4: (0, 0)}

    def __init__(self, grid_size=(5, 5), goal_position=(3, 2), forbidden_grids=None, action_space=5,
                 forbidden_grids_penalty=-2, tgt_grid_reward=10, step_penalty=-1, wall_penalty=None):
        """Create the environment.

        Args:
            grid_size: ``(rows, cols)`` of the grid.
            goal_position: ``(row, col)`` of the terminal goal cell.
            forbidden_grids: iterable of ``(row, col)`` penalised cells; a fixed
                default layout is used when None.
            action_space: number of discrete actions exposed via ``self.action_space``.
            forbidden_grids_penalty: reward for entering a forbidden cell.
            tgt_grid_reward: reward for reaching the goal.
            step_penalty: reward for an ordinary step.
            wall_penalty: reward for bumping into the grid boundary; defaults to
                ``2 * step_penalty``.
        """
        super(CustomGridWorld, self).__init__()
        self.grid_size = grid_size  # (rows, cols)
        self.goal_position = goal_position
        self.forbidden_grids_penalty = forbidden_grids_penalty
        self.tgt_grid_reward = tgt_grid_reward
        self.step_penalty = step_penalty
        # Bumping into the boundary costs more than an ordinary step by default.
        self.wall_penalty = step_penalty * 2 if wall_penalty is None else wall_penalty
        self.action_space = spaces.Discrete(action_space)
        self.observation_space = spaces.Discrete(grid_size[0] * grid_size[1])
        self.state = (0, 0)
        self.done = False
        if forbidden_grids is None:
            forbidden_grids = [(1, 1), (1, 2), (2, 2), (3, 1), (3, 3), (4, 1)]
        self.forbidden_grids = set(forbidden_grids)

    def _get_state(self, observation):
        """Convert a flattened observation index into a ``(row, col)`` tuple."""
        cols = self.grid_size[1]
        return (observation // cols, observation % cols)

    def reset(self, init_observation=0):
        """Place the agent on cell ``init_observation`` and return that observation.

        Raises:
            ValueError: if ``init_observation`` is not a valid cell index.
        """
        n_cells = self.grid_size[0] * self.grid_size[1]
        if not 0 <= init_observation < n_cells:
            raise ValueError(f'init_observation must be in [0, {n_cells}), got {init_observation}')
        self.state = self._get_state(init_observation)
        self.done = False
        return self._get_observation()

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, done, info)``."""
        x, y = self.state
        dx, dy = self._MOVES.get(action, (0, 0))  # unknown actions behave like "stay"
        new_x, new_y = x + dx, y + dy
        rows, cols = self.grid_size

        if not (0 <= new_x < rows and 0 <= new_y < cols):
            # The move would leave the grid: the agent stays where it is and pays
            # the wall penalty. The position is deliberately NOT clamped first,
            # otherwise this branch could never be reached.
            reward = self.wall_penalty
        else:
            self.state = (new_x, new_y)
            if self.state == self.goal_position:
                reward = self.tgt_grid_reward
                self.done = True
            elif self.state in self.forbidden_grids:
                reward = self.forbidden_grids_penalty
            else:
                reward = self.step_penalty

        return self._get_observation(), reward, self.done, {}

    def render(self, mode='human'):
        """Print the map: 'G' = goal, 'H' = forbidden cell, '.' = free cell."""
        grid = np.full(self.grid_size, '.', dtype=object)
        grid[self.goal_position] = 'G'
        for f in self.forbidden_grids:
            grid[f] = 'H'
        for row in grid:
            print(' '.join(row))

    def _get_observation(self):
        """Return the flattened index of the current ``(row, col)`` state."""
        return self.state[0] * self.grid_size[1] + self.state[1]

    def close(self):
        pass

    def vis_policy(self, q_table):
        """Print the greedy action (argmax of ``q_table[obs]``) for every cell as an arrow."""
        rows, cols = self.grid_size
        action_maps = {0: '↑', 1: '→', 2: '↓', 3: '←', 4: '⊙'}
        policy = np.full(self.grid_size, '⊙', dtype=object)
        for row in range(rows):
            for col in range(cols):
                index = row * cols + col
                if (row, col) == self.goal_position:
                    policy[row, col] = 'G'
                else:
                    action = int(np.argmax(q_table[index]))
                    policy[row, col] = action_maps.get(action, '?')
        print(policy)


class ActorModel(nn.Module):
    """Policy network: observation index -> probability distribution over actions.

    Architecture: embedding lookup, then a stack of Linear+ReLU layers whose
    widths are given by ``hidden_layers``, then a Linear head followed by a
    softmax over the last dimension.
    """

    def __init__(self, input_dim: int, embed_dim: int, hidden_layers: list[int], output_dim: int):
        super().__init__()
        self.embed = nn.Embedding(input_dim, embed_dim)

        widths = [embed_dim] + list(hidden_layers)
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        self.net = nn.Sequential(*stack)

        self.out = nn.Linear(widths[-1], output_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return action probabilities for a batch of observation indices.

        ``x`` may be 1-D ``(B,)`` or 2-D ``(B, 1)``; in the 2-D case the
        singleton dimension is squeezed out after the embedding lookup.
        """
        features = self.embed(x)
        if x.dim() == 2:
            features = features.squeeze(1)
        return self.softmax(self.out(self.net(features)))


class CriticModel(nn.Module):
    """Value network: observation index -> scalar state-value estimate V(s).

    Embedding lookup, a Linear+ReLU stack described by ``hidden_layers``,
    and a single-unit Linear head whose trailing dimension is squeezed away.
    """

    def __init__(self, input_dim: int, embed_dim: int, hidden_layers: list[int]):
        super().__init__()
        self.embed = nn.Embedding(input_dim, embed_dim)

        widths = [embed_dim] + list(hidden_layers)
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        self.net = nn.Sequential(*stack)

        # One output unit: V(s).
        self.out = nn.Linear(widths[-1], 1)

    def forward(self, x):
        """Return one value per observation index.

        ``x`` may be ``(B,)`` or ``(B, 1)``; either way the result has shape ``(B,)``.
        """
        features = self.embed(x)
        if x.dim() == 2:
            features = features.squeeze(1)
        return self.out(self.net(features)).squeeze(-1)


class A2CSolver:
    """Advantage actor-critic (A2C) agent for :class:`CustomGridWorld`.

    The actor maps an observation index to action probabilities and the critic
    maps it to a state-value estimate V(s). Every training iteration rolls out
    one episode with the current stochastic policy, then takes one Adam step on
    each network using the whole episode as a batch.
    """

    def __init__(self, grid_size: tuple, goal_position: tuple, forbidden_grids: list[tuple], action_space: int,
                 hidden_layers: list[int], device: torch.device, actor_lr: float = 1e-3, critic_lr: float = 1e-3,
                 embed_dim: int = 32, forbidden_grids_penalty: int = -2, tgt_grid_reward: int = 10, step_penalty: int = -1):
        self.device = device
        self.grid_size = grid_size
        self.action_space = action_space
        self._init_env(grid_size, goal_position, forbidden_grids, action_space,
                       forbidden_grids_penalty, tgt_grid_reward, step_penalty)
        self._init_model(hidden_layers, embed_dim)
        self._init_trainer(actor_lr, critic_lr)

    def _init_env(self, grid_size: tuple, goal_position: tuple, forbidden_grids: list[tuple], action_space: int = 5,
                  forbidden_grids_penalty: int = -2, tgt_grid_reward: int = 10, step_penalty: int = -1):
        """Build the grid world and cache the number of observations."""
        self.env = CustomGridWorld(grid_size=grid_size, goal_position=goal_position, forbidden_grids=forbidden_grids,
                                   action_space=action_space, forbidden_grids_penalty=forbidden_grids_penalty,
                                   tgt_grid_reward=tgt_grid_reward, step_penalty=step_penalty)
        self.n_observations = self.env.observation_space.n

    def _init_model(self, hidden_layers: list[int], embed_dim: int):
        """Create the actor and critic networks on ``self.device``."""
        self.actor = ActorModel(self.n_observations, embed_dim, hidden_layers, self.action_space).to(self.device)
        self.critic = CriticModel(self.n_observations, embed_dim, hidden_layers).to(self.device)

    def _init_trainer(self, actor_lr: float, critic_lr: float):
        """Create one Adam optimizer per network."""
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)

    def _generate_episode(self, n_steps: int, random_start: bool = False):
        """Roll out one episode with the current stochastic policy.

        Actions are only *sampled* here (no autograd graph is kept); log-probs
        and entropies are recomputed in one batched forward pass in ``solve``.
        This gives the same gradient because the parameters do not change during
        the rollout.

        Args:
            n_steps: maximum number of environment steps.
            random_start: start on a uniformly random cell instead of cell 0.

        Returns:
            ``(transitions, terminated)``. ``transitions`` is a list of
            ``one_step_experience`` records. ``terminated`` is True when the
            environment reported ``done``, and False when the step budget ran out.
        """
        start = random.randrange(self.n_observations) if random_start else 0
        observation = self.env.reset(start)
        transitions = []

        with torch.no_grad():
            for _ in range(n_steps):
                state = torch.tensor([observation], dtype=torch.long, device=self.device)
                dist = torch.distributions.Categorical(self.actor(state))
                action = int(dist.sample().item())

                next_observation, reward, done, _ = self.env.step(action)
                transitions.append(one_step_experience(observation, action, reward, next_observation))

                observation = next_observation
                if done:
                    return transitions, True

        return transitions, False

    @staticmethod
    def _discounted_returns(rewards, gamma: float, bootstrap_value: float = 0.0):
        """Return ``[G_0, ..., G_{T-1}]`` with ``G_t = r_t + gamma * G_{t+1}`` and ``G_T = bootstrap_value``."""
        returns = []
        running = bootstrap_value
        for reward in reversed(rewards):
            running = reward + gamma * running
            returns.append(running)
        returns.reverse()
        return returns

    def solve(self, n_steps: int, n_episodes: int, gamma: float, random_start: bool = False, vis_policy: bool = True,
              log_iters: int = 100, entropy_coef: float = 0.01):
        """Train actor and critic for ``n_episodes`` episodes.

        Args:
            n_steps: step budget per episode.
            n_episodes: number of training episodes (one update per episode).
            gamma: discount factor.
            random_start: sample the start cell uniformly for every episode.
            vis_policy: print the learned policy and its analysis after training.
            log_iters: update the progress-bar description every this many episodes.
            entropy_coef: weight of the policy-entropy bonus. It is subtracted
                from the actor loss, so higher values encourage exploration.
        """
        self.actor.train()
        self.critic.train()
        pbar = tqdm(range(n_episodes))
        for n_episode in pbar:
            transitions, terminated = self._generate_episode(n_steps, random_start)
            if not transitions:  # only possible when n_steps <= 0
                continue

            # ====== Advantage Actor-Critic (A2C) ======
            states = torch.tensor([t.current_observation for t in transitions], dtype=torch.long, device=self.device)
            actions = torch.tensor([t.current_action for t in transitions], dtype=torch.long, device=self.device)
            rewards = [t.reward for t in transitions]

            # Discounted returns. If the episode was cut off by the step budget
            # (not by reaching the goal), bootstrap the missing tail with V(s_T).
            bootstrap_value = 0.0
            if not terminated:
                with torch.no_grad():
                    last_state = torch.tensor([transitions[-1].next_observation], dtype=torch.long, device=self.device)
                    bootstrap_value = self.critic(last_state).item()
            returns = torch.tensor(self._discounted_returns(rewards, gamma, bootstrap_value),
                                   dtype=torch.float32, device=self.device)

            # Batched forward passes. The critic returns shape (T,), which matches
            # `returns`. A (T, 1) tensor here would silently broadcast the
            # advantage to a (T, T) matrix.
            values = self.critic(states)
            advantages = returns - values.detach()

            dist = torch.distributions.Categorical(self.actor(states))
            log_probs = dist.log_prob(actions)
            entropy = dist.entropy()

            # Critic: MSE between V(s_t) and the observed return G_t.
            critic_loss = F.mse_loss(values, returns)
            # Actor: policy gradient weighted by the advantage, minus an entropy
            # bonus (real policy entropy, so it actually has a gradient).
            actor_loss = -(log_probs * advantages).mean() - entropy_coef * entropy.mean()

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            if (n_episode + 1) % log_iters == 0:
                avg_reward = float(np.mean(rewards))
                pbar.set_description(
                    f'Episode {n_episode + 1}/{n_episodes}, critic_loss: {critic_loss.item():.4f}, actor_loss: {actor_loss.item():.4f}, avg_reward: {avg_reward:.3f}'
                )

        print("Training Done!")
        if vis_policy:
            print('Rendering final policy...')
            self.vis_policy()
        print('All done!')

    def create_fake_qtable(self):
        """Return an ``(n_observations, action_space)`` array of action probabilities.

        The actor has no Q-values; its probabilities are used in their place so
        that the environment's argmax-based visualiser can display the policy.
        """
        self.actor.eval()
        fake_q_table = np.zeros([self.n_observations, self.action_space], dtype=float)
        with torch.no_grad():
            states = torch.arange(self.n_observations, dtype=torch.long, device=self.device)
            action_probs = []
            batch = 256
            for i in range(0, len(states), batch):
                b = states[i:i + batch].unsqueeze(1)
                action_probs.append(self.actor(b).cpu().numpy())
            fake_q_table[:] = np.vstack(action_probs)
        return fake_q_table

    def vis_policy(self):
        """Print the map, the greedy policy per cell, and the strategy analysis."""
        self.env.render()
        fake_q_table = self.create_fake_qtable()
        self.env.vis_policy(fake_q_table)
        self.analyze_strategy()

    def analyze_strategy(self, max_steps: int = 200, stuck_threshold: int = 10):
        """Roll out the greedy (argmax) policy from every cell and print statistics.

        For each start cell the policy runs for at most ``max_steps`` steps.

        * Success: the rollout ends on the goal.
        * Stuck: the observation stays unchanged for more than
          ``stuck_threshold`` consecutive steps.
        * Forbidden hit: the rollout steps onto at least one forbidden cell.

        All three rates are fractions of the number of start cells.
        """
        print("\n" + "="*50)
        print("STRATEGY ANALYSIS")
        print("="*50)

        self.actor.eval()
        forbidden = self.env.forbidden_grids
        success_count = 0
        forbidden_hits = 0
        total_steps = 0
        stuck_positions = 0

        for start in range(self.n_observations):
            obs = self.env.reset(start)
            path = [obs]  # observations only
            actions = []
            steps = 0
            success = False
            hit_forbidden = False
            stuck = False
            stuck_count = 0

            for _ in range(max_steps):
                state_tensor = torch.tensor([obs], dtype=torch.long, device=self.device)
                with torch.no_grad():
                    action = self.actor(state_tensor).argmax(dim=1).item()

                old_obs = obs
                obs, reward, done, _ = self.env.step(action)
                path.append(obs)
                actions.append(action)
                steps += 1

                # Detect forbidden cells by position, not by comparing reward
                # values (another penalty could equal the forbidden one).
                if self.env.state in forbidden:
                    hit_forbidden = True

                if done:
                    success = self.env.state == self.env.goal_position
                    break

                # Count consecutive non-moves; give up once the agent is stuck.
                stuck_count = stuck_count + 1 if obs == old_obs else 0
                if stuck_count > stuck_threshold:
                    stuck = True
                    break

            total_steps += steps
            success_count += success
            forbidden_hits += hit_forbidden
            stuck_positions += stuck

            # Print detailed analysis for the first few starting positions.
            if start < 5:
                start_pos = self.env._get_state(start)
                print(f"Start {start} ({start_pos}): Steps={steps}, Success={success}, Forbidden={hit_forbidden}, Stuck={stuck}")
                print(f"  Path: {path[:15]}...")
                print(f"  Actions: {actions[:15]}...")

                # Show if path goes through forbidden areas.
                forbidden_in_path = [coord for coord in map(self.env._get_state, path) if coord in forbidden]
                if forbidden_in_path:
                    print(f"  Forbidden areas visited: {forbidden_in_path}")

        success_rate = success_count / self.n_observations
        forbidden_rate = forbidden_hits / self.n_observations
        stuck_rate = stuck_positions / self.n_observations
        avg_steps = total_steps / self.n_observations if total_steps > 0 else max_steps

        print(f"\nOverall Performance:")
        print(f"  Success Rate: {success_rate:.3f} ({success_count}/{self.n_observations})")
        print(f"  Forbidden Hit Rate: {forbidden_rate:.3f} ({forbidden_hits}/{self.n_observations})")
        print(f"  Stuck Rate: {stuck_rate:.3f} ({stuck_positions}/{self.n_observations})")
        print(f"  Average Steps: {avg_steps:.1f}")
        print(f"  Goal Position: {self.env.goal_position}")
        print(f"  Forbidden Areas: {list(self.env.forbidden_grids)}")

        # Environment difficulty analysis
        print(f"\nEnvironment Difficulty:")
        total_forbidden = len(self.env.forbidden_grids)
        total_cells = self.env.grid_size[0] * self.env.grid_size[1]
        forbidden_ratio = total_forbidden / total_cells
        print(f"  Grid Size: {self.env.grid_size}")
        print(f"  Forbidden Ratio: {forbidden_ratio:.3f} ({total_forbidden}/{total_cells})")
        print(f"  Target Reward: {self.env.tgt_grid_reward}")
        print(f"  Forbidden Penalty: {self.env.forbidden_grids_penalty}")
        print("="*50)


if __name__ == '__main__':
    # ---- environment layout ----
    grid_size = (5, 5)
    goal_position = (4, 4)  # bottom-right corner
    forbidden_grids = [(1, 1), (1, 2), (2, 2), (3, 1), (3, 3), (4, 1)]
    action_space = 5  # up, right, down, left, stay

    # ---- reward shaping ----
    forbidden_grids_penalty = -5
    tgt_grid_reward = 100  # large terminal reward so reaching the goal dominates step costs
    step_penalty = -0.2

    # ---- networks and optimisation ----
    hidden_layers = [128, 128]
    embed_dim = 32
    actor_lr = 2e-4
    critic_lr = 1e-4
    device = torch.device('cpu')

    solver = A2CSolver(
        grid_size=grid_size,
        goal_position=goal_position,
        forbidden_grids=forbidden_grids,
        action_space=action_space,
        hidden_layers=hidden_layers,
        device=device,
        actor_lr=actor_lr,
        critic_lr=critic_lr,
        embed_dim=embed_dim,
        forbidden_grids_penalty=forbidden_grids_penalty,
        tgt_grid_reward=tgt_grid_reward,
        step_penalty=step_penalty,
    )

    # ---- training schedule ----
    n_steps = 500
    n_episodes = 8000
    log_iters = 100
    gamma = 0.99
    random_start = True  # start each episode on a random cell
    vis_policy = True

    solver.solve(
        n_steps=n_steps,
        n_episodes=n_episodes,
        gamma=gamma,
        random_start=random_start,
        vis_policy=vis_policy,
        log_iters=log_iters,
    )


![](assets/result2.jpg)