### Q-learning公式（off-policy版本）：

![](assets/200.jpg)

### 例子

![](assets/202.jpg)

### 上述例子代码实现

例子：
使用gym仿真库，gym官网: https://www.gymlibrary.dev/environments/toy_text/frozen_lake/

In [4]:
!pip install gym==0.15.4

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple


In [1]:
from collections import namedtuple
import numpy as np
from tqdm import tqdm
import gym
from gym import spaces
from gym.envs.registration import register

register(
    id='CustomGridWorld-v0',
    entry_point='__main__:CustomGridWorld',
)

class CustomGridWorld(gym.Env):
    def __init__(self, grid_size=(5, 5), goal_position=(3, 2), forbidden_grids=None, action_space=5):
        super(CustomGridWorld, self).__init__()
        # Grid size (rows, columns)
        self.grid_size = grid_size
        self.goal_position = goal_position
        # Define action space: up, right, down, left, unchanged (5 actions)
        self.action_space = spaces.Discrete(action_space)
        # Observation space: grid positions, represented as a flat space
        self.observation_space = spaces.Discrete(grid_size[0] * grid_size[1])
        # Initialize agent's starting position (top-left corner)
        self.state = (0, 0)
        self.done = False
        # Set the forbidden grids (if not specified, use a default list)
        if forbidden_grids is None:
            forbidden_grids = [(1, 1), (1, 2), (2, 2), (3, 1), (3, 3), (4, 1)]  # Given forbidden grids
        self.forbidden_grids = set(forbidden_grids)

    def reset(self):
        """Resets the environment to the initial state"""
        self.state = (0, 0)  # Reset the agent to the top-left corner
        self.done = False
        return self._get_observation()

    def step(self, action):
        """Executes one step in the environment"""
        if self.done:
            return self._get_observation(), 1, True, {}

        x, y = self.state
        # Define movement based on action
        if action == 0:  # Up
            new_x = max(0, x - 1)
            new_y = y
        elif action == 1:  # Right
            new_x = x
            new_y = min(self.grid_size[1] - 1, y + 1)
        elif action == 2:  # Down
            new_x = min(self.grid_size[0] - 1, x + 1)
            new_y = y
        elif action == 3:  # Left
            new_x = x
            new_y = max(0, y - 1)
        elif action == 4:  # Unchanged (stay in the same position)
            new_x = x
            new_y = y

        # Check if the new position is out of bounds
        if new_x < 0 or new_x >= self.grid_size[0] or new_y < 0 or new_y >= self.grid_size[1]:
            reward = -1  # Penalty for trying to go out of bounds
            self.state = (x, y)  # Keep the agent at the same position
        else:
            self.state = (new_x, new_y)
            # Check if the agent reached the goal
            if self.state == self.goal_position:
                reward = 1
                self.done = True
            # Check if the agent stepped into a forbidden grid
            elif self.state in self.forbidden_grids:
                reward = -1  # Penalty for entering a forbidden grid
            else:
                reward = 0  # No penalty for regular move

        return self._get_observation(), reward, self.done, {}

    def render(self, mode='human'):
        """Renders the environment (prints the grid)"""
        grid = np.full(self.grid_size, 'F', dtype=object)  # Default is frozen
        # Set goal, forbidden grids, and agent position
        grid[self.goal_position] = 'G'
        for f in self.forbidden_grids:
            grid[f] = 'H'  # H for hole (forbidden grid)
        # Print the grid with agent position
        grid[self.state] = 'A'
        for row in grid:
            print(' '.join(row))
        print()

    def _get_observation(self):
        """Returns the current state as a flat index"""
        return self.state[0] * self.grid_size[1] + self.state[1]

    def close(self):
        """Close the environment"""
        pass
    
    def vis_policy(self, q_table):
        self.render()
        action_maps = {0: '↑', 1: '→', 2: '↓', 3: '←', 4: '⊙'}
        policy = np.full(self.grid_size, '⊙', dtype=object)
        for row in range(self.grid_size[0]):
            for col in range(self.grid_size[1]):
                index = row * self.grid_size[0] + col
                action = q_table[index].argmax()
                policy[row, col] = action_maps[action]
        print(policy)

one_step_experience = namedtuple('one_step_experience', field_names=['current_observation', 'next_observation', 'action', 'reward'])

def solver():
    grid_size = (5, 5)
    goal_position = (3, 2)
    forbidden_grids = [(1, 1), (1, 2), (2, 2), (3, 1), (3, 3), (4, 1)]
    action_space = 5
    env = CustomGridWorld(grid_size=grid_size, goal_position=goal_position, forbidden_grids=forbidden_grids, action_space=action_space)
    env.render()
    
    n_episodes = 1000
    n_steps = 1000000
    policies_b = []
    for _ in tqdm(range(n_episodes), desc='generate episode'):
        policy_b = []
        current_observation = env.reset()
        for _ in range(n_steps):
            action = env.action_space.sample()
            next_observation, reward, _, _ = env.step(action)
            ose = one_step_experience(current_observation=current_observation, next_observation=next_observation, action=action, reward=reward)
            current_observation = next_observation
            policy_b.append(ose)
        policies_b.append(policy_b)
    
    q_table = np.zeros([grid_size[0] * grid_size[1], action_space], dtype=np.float32)
    alpha = 0.1
    gamma = 0.9
    for i_episode in tqdm(range(n_episodes), desc='update policy'):
        policy_b = policies_b[i_episode]
        for i_step in range(n_steps):
            ose = policy_b[i_step]
            current_observation = ose.current_observation
            next_observation = ose.next_observation
            action = ose.action
            reward = ose.reward
            q_table[current_observation, action] = q_table[current_observation, action] - alpha * (q_table[current_observation, action] - (reward + gamma * q_table[next_observation, :].max()))
        
    print(q_table)
    print(q_table.argmax(1))
    env.vis_policy(q_table)


if __name__ == "__main__":
    solver()

A F F F F
F H H F F
F F H F F
F H G H F
F H F F F



generate episode:   4%|██▌                                                         | 43/1000 [03:27<1:17:01,  4.83s/it]


KeyboardInterrupt: 

由于步数太多，直接展示运行结果

![](assets/result.jpg)