### REINFORCE：

![](assets/291.jpg)

### 例子

![](assets/256.jpg)

### 上述例子代码实现

例子：
本例基于 gym 仿真库的 `Env` 接口，自定义了一个 5×5 的网格世界环境（与 gym 自带的 FrozenLake 环境类似）。gym 官网：https://www.gymlibrary.dev/environments/toy_text/frozen_lake/

In [2]:
!pip install gym==0.15.4
!pip install numpy
!pip install torch
!pip install tqdm

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple


In [None]:
import math
import random
from copy import deepcopy
from collections import namedtuple
import numpy as np
from tqdm import tqdm
import gym
from gym import spaces
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

# Immutable record of one environment transition: (s, a, r, s').
# NOTE(review): not referenced by the code shown in this section.
one_step_experience = namedtuple(
    'one_step_experience',
    field_names=(
        'current_observation',
        'current_action',
        'reward',
        'next_observation',
    ),
)

class CustomGridWorld(gym.Env):
    """Deterministic grid world with forbidden cells and a single terminal goal.

    States are ``(row, col)`` cells. Observations are the flattened index
    ``row * n_cols + col``, where ``(n_rows, n_cols) = grid_size``.

    Actions: 0 = up, 1 = right, 2 = down, 3 = left, 4 = stay in place.

    Per-step reward:
        * ``tgt_grid_reward`` when the agent enters ``goal_position``
          (``done`` becomes True);
        * ``forbidden_grids_penalty`` when the agent enters a forbidden cell,
          or when the move would leave the grid (the agent then stays put);
        * 0 otherwise.
    """

    # (row delta, col delta) for every action id.
    _ACTION_DELTAS = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1), 4: (0, 0)}

    def __init__(self, grid_size=(5, 5), goal_position=(3, 2), forbidden_grids=None, action_space=5, forbidden_grids_penalty=-2, tgt_grid_reward=10):
        super().__init__()
        self.grid_size = grid_size
        # Store as a tuple so it compares equal to ``self.state`` and indexes a
        # single numpy cell (a list would trigger fancy indexing in ``render``).
        self.goal_position = tuple(goal_position)
        self.forbidden_grids_penalty = forbidden_grids_penalty
        self.tgt_grid_reward = tgt_grid_reward
        self.action_space = spaces.Discrete(action_space)
        self.observation_space = spaces.Discrete(grid_size[0] * grid_size[1])
        self.state = (0, 0)
        self.done = False
        if forbidden_grids is None:
            forbidden_grids = [(1, 1), (1, 2), (2, 2), (3, 1), (3, 3), (4, 1)]
        # Normalize cells to tuples so they are hashable and comparable to state.
        self.forbidden_grids = {tuple(cell) for cell in forbidden_grids}

    def _get_state(self, observation):
        """Convert a flattened observation index back to a ``(row, col)`` cell."""
        # Row-major layout: divide by the number of columns, matching
        # ``_get_observation``.
        return divmod(observation, self.grid_size[1])

    def reset(self, init_observation=0):
        """Place the agent at ``init_observation``, clear ``done``, return the observation."""
        self.state = self._get_state(init_observation)
        self.done = False
        return self._get_observation()

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        Raises:
            ValueError: if ``action`` is not an integer in 0-4.
        """
        try:
            d_row, d_col = self._ACTION_DELTAS[int(action)]
        except (KeyError, TypeError, ValueError):
            raise ValueError(f'Invalid action {action!r}; expected an integer in 0-4') from None

        row, col = self.state
        new_row, new_col = row + d_row, col + d_col
        n_rows, n_cols = self.grid_size

        if not (0 <= new_row < n_rows and 0 <= new_col < n_cols):
            # Moving off the grid: stay in place and pay the boundary penalty.
            reward = self.forbidden_grids_penalty
        else:
            self.state = (new_row, new_col)
            if self.state == self.goal_position:
                reward = self.tgt_grid_reward
                self.done = True
            elif self.state in self.forbidden_grids:
                reward = self.forbidden_grids_penalty
            else:
                reward = 0

        return self._get_observation(), reward, self.done, {}

    def render(self, mode='human'):
        """Print the grid: 'A' agent, 'G' goal, 'H' forbidden cell, 'F' free cell."""
        grid = np.full(self.grid_size, 'F', dtype=object)
        grid[self.goal_position] = 'G'
        for f in self.forbidden_grids:
            grid[f] = 'H'
        # The agent is drawn last, so it overrides whatever cell it stands on.
        grid[self.state] = 'A'
        for row in grid:
            print(' '.join(row))

    def _get_observation(self):
        """Flatten the current ``(row, col)`` state into a row-major index."""
        return self.state[0] * self.grid_size[1] + self.state[1]

    def close(self):
        pass

    def vis_policy(self, q_table):
        """Reset and render the env, then print the greedy policy of ``q_table``.

        ``q_table`` is indexed by flattened observation; the argmax of each row
        is drawn as an arrow (or '⊙' for "stay"), and the goal cell as 'G'.
        """
        self.reset()
        self.render()
        action_maps = {0: '↑', 1: '→', 2: '↓', 3: '←', 4: '⊙'}
        n_rows, n_cols = self.grid_size
        policy = np.full(self.grid_size, '⊙', dtype=object)
        for row in range(n_rows):
            for col in range(n_cols):
                if (row, col) == self.goal_position:
                    policy[row, col] = 'G'
                    continue
                # Row-major index, consistent with ``_get_observation``.
                index = row * n_cols + col
                policy[row, col] = action_maps[int(q_table[index].argmax())]
        print(policy)


class REINFORCEModel(nn.Module):
    """Policy network mapping discrete state indices to action probabilities.

    Layers, in order (all stored in ``self.reinforce``):
        ``Embedding(input_dim, hidden_layers[0])``,
        ``Linear + ReLU`` for each consecutive pair of ``hidden_layers``,
        ``Linear(hidden_layers[-1], output_dim)``, and a softmax over the last dim.
    """

    def __init__(self, input_dim: int, hidden_layers: list[int], output_dim: int):
        super().__init__()
        # Integer state ids are looked up in an embedding table.
        stack = [nn.Embedding(input_dim, hidden_layers[0])]

        # Hidden Linear/ReLU pairs between successive widths.
        for width_in, width_out in zip(hidden_layers[:-1], hidden_layers[1:]):
            stack += [nn.Linear(width_in, width_out, bias=True), nn.ReLU()]

        # Head: logits for every action, normalized into a distribution.
        stack += [nn.Linear(hidden_layers[-1], output_dim), nn.Softmax(dim=-1)]

        self.reinforce = nn.Sequential(*stack)

    def forward(self, x):
        """Return action probabilities for the state indices in ``x``."""
        return self.reinforce(x)


class REINFORCESolver:
    """Train a REINFORCE (Monte-Carlo policy-gradient) agent on ``CustomGridWorld``.

    Args:
        grid_size: grid shape forwarded to ``CustomGridWorld``.
        goal_position: goal cell forwarded to ``CustomGridWorld``.
        forbidden_grids: forbidden cells forwarded to ``CustomGridWorld``.
        action_space: number of discrete actions, used for both the environment
            and the policy network's output layer.
        hidden_layers: hidden widths of the ``REINFORCEModel`` policy.
        device: device the policy network is placed on.
        lr: Adam learning rate.
    """

    def __init__(self, grid_size: tuple, goal_position: tuple, forbidden_grids: list[tuple], action_space: int,
                 hidden_layers: list[int], device: torch.device, lr: float=1e-3):
        self.device = device
        self.grid_size = grid_size
        self.action_space = action_space
        # Forward action_space so the env and the policy head agree on it.
        self._init_env(grid_size, goal_position, forbidden_grids, action_space)
        self._init_model(hidden_layers)
        self._init_trainer(lr)

    def _init_env(self, grid_size: tuple, goal_position: tuple, forbidden_grids: list[tuple], action_space: int=5):
        """Create the environment and cache its number of observations."""
        self.env = CustomGridWorld(grid_size=grid_size, goal_position=goal_position, forbidden_grids=forbidden_grids, action_space=action_space)
        self.n_observations = self.env.observation_space.n

    def _init_model(self, hidden_layers: list[int]):
        """Build the policy network (state index -> action probabilities) on ``self.device``."""
        self.model = REINFORCEModel(self.n_observations, hidden_layers, self.action_space).to(self.device)

    def _init_trainer(self, lr):
        """Create the Adam optimizer over the policy parameters."""
        self.optimizer = optim.Adam(self.model.parameters(), lr)

    def _generate_episode(self, n_steps: int, random_start: bool=False):
        """Roll out one episode with the current stochastic policy.

        Args:
            n_steps: maximum number of environment steps.
            random_start: start from a uniformly random observation instead of 0.

        Returns:
            ``(log_probs, rewards)``. ``log_probs`` holds one tensor of shape
            ``[1]`` per step, still attached to the autograd graph; ``rewards``
            holds the per-step rewards returned by the environment.
        """
        start = random.randrange(self.n_observations) if random_start else 0
        observation = self.env.reset(start)
        episode_logprobs, episode_rewards = [], []

        for _ in range(n_steps):
            # Batch of one state -> probabilities of shape [1, action_space].
            state = torch.tensor([observation], dtype=torch.long, device=self.device)
            dist = torch.distributions.Categorical(probs=self.model(state))
            action = dist.sample()
            episode_logprobs.append(dist.log_prob(action))

            observation, reward, done, _ = self.env.step(action.item())
            episode_rewards.append(reward)
            if done:
                break

        return episode_logprobs, episode_rewards

    @staticmethod
    def _discounted_returns(rewards, gamma):
        """Return ``G_t = r_t + gamma * G_{t+1}`` for every step, as a float32 array."""
        returns = np.zeros(len(rewards), dtype=np.float32)
        running = 0.0
        for t in reversed(range(len(rewards))):
            running = rewards[t] + gamma * running
            returns[t] = running
        return returns

    def solve(self, n_steps: int, n_episodes: int, gamma: float, random_start: bool=True, vis_policy: bool=True,
              log_iters: int=10):
        """Run REINFORCE for ``n_episodes`` episodes, one gradient step per episode.

        Args:
            n_steps: maximum steps per episode.
            n_episodes: number of training episodes.
            gamma: discount factor.
            random_start: start each episode from a random observation.
            vis_policy: print the greedy policy after training.
            log_iters: refresh the progress-bar description every ``log_iters``
                episodes (values <= 0 disable it).
        """
        self.model.train()
        pbar = tqdm(range(n_episodes))
        for n_episode in pbar:
            episode_logprobs, episode_rewards = self._generate_episode(n_steps, random_start=random_start)
            if not episode_logprobs:
                # n_steps <= 0 yields an empty episode: nothing to learn from.
                continue

            returns = self._discounted_returns(episode_rewards, gamma)
            # Per-episode standardization of returns (variance-reduction heuristic).
            # A single-step episode has std 0, so its normalized return is 0.
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)
            returns_t = torch.as_tensor(returns, device=self.device)

            # Each log_prob has shape [1]; concatenate into [T].
            logprobs = torch.cat(episode_logprobs)
            episode_loss = -(logprobs * returns_t).mean()
            episode_rewards_mean = float(np.mean(episode_rewards))

            self.optimizer.zero_grad()
            episode_loss.backward()
            self.optimizer.step()

            if log_iters > 0 and n_episode % log_iters == 0:
                desc = f'Episode: {n_episode} / {n_episodes}, avg loss: {episode_loss.item():.3f}, avg rewards: {episode_rewards_mean:.3f}'
                pbar.set_description(desc)

        print("Training Done!")

        if vis_policy:
            print('Rendering final policy...')
            self.vis_policy()
        print('All done!')

    def create_fake_qtable(self):
        """Return a one-hot ``[n_observations, action_space]`` table of greedy actions.

        Each row has a 1 at the argmax of the policy's action probabilities for
        that observation, so it can be rendered like a Q-table. Leaves the model
        in eval mode.
        """
        self.model.eval()
        states = torch.arange(self.n_observations, dtype=torch.long, device=self.device)
        with torch.no_grad():
            optimal_actions = self.model(states).argmax(dim=1)
        fake_q_table = torch.zeros([self.n_observations, self.action_space], device=self.device)
        fake_q_table[states, optimal_actions] = 1
        return fake_q_table

    def vis_policy(self):
        """Print the greedy policy on the environment grid."""
        fake_q_table = self.create_fake_qtable()
        self.env.vis_policy(fake_q_table.cpu().numpy())


if __name__ == '__main__':
    # --- Environment layout ---
    grid_size = (5, 5)
    goal_position = (3, 2)
    forbidden_grids = [(1, 1), (1, 2), (2, 2), (3, 1), (3, 3), (4, 1)]
    action_space = 5

    # --- Policy network and optimizer ---
    hidden_layers = [100, 100]
    output_dim = action_space
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lr = 1e-5

    # --- Training schedule ---
    n_steps = 1000        # max steps per episode
    n_episodes = 10000
    log_iters = 10        # progress-bar refresh interval, in episodes
    gamma = 0.99
    random_start = False  # every episode starts from observation 0
    vis_policy = True

    solver = REINFORCESolver(
        grid_size=grid_size,
        goal_position=goal_position,
        forbidden_grids=forbidden_grids,
        action_space=action_space,
        hidden_layers=hidden_layers,
        device=device,
        lr=lr,
    )
    solver.solve(
        n_steps=n_steps,
        n_episodes=n_episodes,
        gamma=gamma,
        random_start=random_start,
        vis_policy=vis_policy,
    )

![](assets/result1.jpg)

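由于训练步数较多、运行耗时较长，这里直接展示运行结果。可以看到，有些位置学到的策略并不理想。一个可能的原因是：主程序中设置了 `random_start = False`，每个 episode 都从左上角 (0, 0) 出发，远离这条路径的状态很少被访问到，这些状态上的策略几乎得不到训练。可以尝试设置 `random_start = True` 后再观察效果。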