In [3]:
from typing import *

import random

from IPython.display import display

# 定义环境
class GridWorld:
    """Rectangular grid environment with a start cell, a goal cell and obstacles.

    States are ``(row, col)`` tuples with ``0 <= row < rows`` and
    ``0 <= col < cols``. The agent's current position is kept in ``state``.
    """

    # (row offset, column offset) applied for each action code.
    _ACTION_DELTAS = {
        0: (-1, 0),  # up
        1: (1, 0),   # down
        2: (0, -1),  # left
        3: (0, 1),   # right
    }

    def __init__(
            self,
            rows: int,
            cols: int,
            start: Tuple[int, int],
            goal: Tuple[int, int],
            obstacles: List[Tuple[int, int]]
    ):
        """Store the grid layout and place the agent on ``start``.

        Args:
            rows: Number of rows in the grid.
            cols: Number of columns in the grid.
            start: Cell the agent occupies after construction and ``reset()``.
            goal: Cell that ends an episode when the agent stands on it.
            obstacles: Cells the agent cannot enter.
        """
        self.rows = rows
        self.cols = cols
        self.start = start
        self.goal = goal
        self.obstacles = obstacles
        self.state = start

    def reset(self) -> Tuple[int, int]:
        """Move the agent back to the start cell and return that cell."""
        self.state = self.start
        return self.state

    def move(self, action: int) -> Tuple[Tuple[int, int], float, bool]:
        """Apply one action and return ``(state, reward, done)``.

        A move that would leave the grid or enter an obstacle keeps the agent
        where it is; the step penalty still applies.

        Args:
            action: 0 = up, 1 = down, 2 = left, 3 = right.

        Returns:
            The resulting state, the reward (1 on the goal cell, otherwise
            -0.1) and whether the goal has been reached.

        Raises:
            ValueError: If ``action`` is not one of the four known codes.
        """
        try:
            d_row, d_col = self._ACTION_DELTAS[action]
        except KeyError:
            # Fail with a clear message instead of crashing later on a
            # ``None`` candidate state.
            raise ValueError(
                f"unknown action {action!r}; expected one of "
                f"{sorted(self._ACTION_DELTAS)}"
            ) from None

        next_state = (self.state[0] + d_row, self.state[1] + d_col)

        # Only commit the move when it stays on the grid and avoids obstacles.
        if (0 <= next_state[0] < self.rows and
                0 <= next_state[1] < self.cols and
                next_state not in self.obstacles):
            self.state = next_state

        if self.state == self.goal:
            return self.state, 1, True  # goal reached: reward 1
        return self.state, -0.1, False  # per-step penalty

# QLearning算法
class QLearningAgent:
    """Tabular Q-learning agent with an epsilon-greedy action policy.

    Q-values live in ``q_table``, a dict keyed by ``(state, action)``.
    Pairs that have never been updated are absent and read as 0.0.
    """

    def __init__(
            self,
            actions: List[int],
            goal: Sequence[int],
            alpha: Annotated[float, "学习率"] = 0.1,
            gamma: Annotated[float, '折扣因子'] = 0.9,
            epsilon: Annotated[float, '探索率'] = 0.1
    ):
        """Create an agent with an empty Q-table.

        Args:
            actions: Action codes the agent may choose from.
            goal: Coordinates of the terminal state; a list or a tuple.
            alpha: Learning rate.
            gamma: Discount factor.
            epsilon: Probability of picking a random action (exploration).
        """
        self.q_table: Dict[Tuple[Tuple[int, int], int], float] = {}
        self.actions = actions
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.goal = goal

    def get_q_value(self, state: Tuple[int, int], action: int) -> float:
        """Return the stored Q-value for ``(state, action)``, or 0.0 if unseen."""
        return self.q_table.get((state, action), 0.0)

    def choose_action(self, state: Tuple[int, int]) -> int:
        """Pick an action for ``state`` using the epsilon-greedy rule.

        With probability ``epsilon`` a uniformly random action is returned;
        otherwise the action with the highest Q-value is returned. Ties go to
        the action that appears first in ``actions``.
        """
        if random.uniform(0, 1) < self.epsilon:
            return random.choice(self.actions)  # explore

        q_values = [self.get_q_value(state, a) for a in self.actions]
        return self.actions[q_values.index(max(q_values))]  # exploit

    def update_q_value(
            self,
            state: Tuple[int, int],
            action: int,
            reward: float,
            next_state: Tuple[int, int]
    ) -> None:
        """Apply one Q-learning update for the observed transition.

        The entry for ``(state, action)`` moves towards the target by a
        fraction ``alpha``. The target is ``reward`` alone when ``next_state``
        is the goal, otherwise ``reward + gamma * max_a Q(next_state, a)``.
        """
        current_q = self.get_q_value(state, action)

        # Compare as tuples: states are tuples while ``goal`` may be given as
        # a list, and ``(4, 5) == [4, 5]`` is False in Python.
        if tuple(next_state) == tuple(self.goal):
            # Terminal transition: nothing to bootstrap from.
            target = reward
        else:
            best_next_q = max(self.get_q_value(next_state, a) for a in self.actions)
            target = reward + self.gamma * best_next_q

        self.q_table[(state, action)] = current_q + self.alpha * (target - current_q)

# 主程序
def main(episodes: int = 1) -> Dict[Tuple[Tuple[int, int], int], float]:
    """Train a Q-learning agent on a fixed 5x6 grid and return its Q-table.

    Each episode starts from the environment's start cell and runs until the
    environment reports that the goal was reached.

    Args:
        episodes: Number of training episodes to run.

    Returns:
        The agent's Q-table, keyed by ``(state, action)``.
    """
    rows = 5
    cols = 6  # any number of columns is supported
    start = (0, 0)
    goal = (4, 5)
    obstacles = [(1, 1), (2, 1), (3, 4), (4, 3), (0, 3)]

    env = GridWorld(rows, cols, start, goal, obstacles)
    # Hand the agent the same goal tuple the environment uses, so that its
    # terminal-state comparison against tuple states can succeed.
    agent = QLearningAgent(actions=[0, 1, 2, 3], goal=goal)  # up, down, left, right

    for episode in range(1, 1 + episodes):
        state = env.reset()
        print(f"第{episode}轮开始！")
        done = False
        while not done:
            action = agent.choose_action(state)
            next_state, reward, done = env.move(action)
            agent.update_q_value(state, action, reward, next_state)
            state = next_state

    return agent.q_table

    
# Run the training loop and render the Q-table it returns as the cell output.
display(main())

第1轮开始！


{((0, 0), 0): -0.0199,
 ((0, 0), 1): -0.028000000000000004,
 ((1, 0), 0): -0.02962,
 ((0, 0), 2): -0.0199,
 ((0, 0), 3): -0.028810000000000002,
 ((0, 1), 0): -0.0199,
 ((0, 1), 1): -0.0199,
 ((0, 1), 2): -0.021520000000000004,
 ((1, 0), 1): -0.028000000000000004,
 ((2, 0), 0): -0.029701000000000005,
 ((1, 0), 2): -0.02962,
 ((1, 0), 3): -0.0199,
 ((0, 1), 3): -0.028810000000000002,
 ((0, 2), 0): -0.0199,
 ((0, 2), 1): -0.028000000000000004,
 ((1, 2), 0): -0.0199,
 ((0, 2), 2): -0.020791000000000004,
 ((0, 2), 3): -0.0199,
 ((1, 2), 1): -0.028000000000000004,
 ((2, 2), 0): -0.0199,
 ((1, 2), 2): -0.0199,
 ((1, 2), 3): -0.019000000000000003,
 ((1, 3), 0): -0.0199,
 ((1, 3), 1): -0.019000000000000003,
 ((2, 3), 0): -0.0199,
 ((1, 3), 2): -0.021520000000000004,
 ((2, 0), 1): -0.010000000000000002,
 ((3, 0), 0): -0.010000000000000002,
 ((2, 0), 2): -0.010000000000000002,
 ((2, 0), 3): -0.010000000000000002,
 ((2, 2), 1): -0.019000000000000003,
 ((3, 2), 3): -0.010000000000000002,
 ((3, 3), 