<a href="https://colab.research.google.com/github/ImaginationX4/HybridZero/blob/main/HBFS_Hybrid_Best_First_Search_with_AlphaZero.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install gymnasium

Collecting gymnasium
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading gymnasium-1.0.0-py3-none-any.whl (958 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m958.1/958.1 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Installing collected packages: farama-notifications, gymnasium
Successfully installed farama-notifications-0.0.4 gymnasium-1.0.0


###1.Netwrok

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SharedNetwork(nn.Module):
    def __init__(self, input_size=16, hidden_size=128, num_actions=4):
        super(SharedNetwork, self).__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # 共享层
        self.shared_layers = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size)
        )

        # Policy head
        self.policy_head = nn.Sequential(
            nn.Linear(hidden_size, 64),
            nn.ReLU(),
            nn.LayerNorm(64),
            nn.Linear(64, num_actions)
        )

        # Value head
        self.value_head = nn.Sequential(
            nn.Linear(hidden_size, 64),
            nn.ReLU(),
            nn.LayerNorm(64),
            nn.Linear(64, 1)

        )

    def forward(self, state):
        """
        参数:
            state: 游戏状态的one-hot编码 (batch_size, 16)
        返回:
            policy_logits: 动作概率的对数 (batch_size, 4)
            value: 状态价值估计 (batch_size, 1)
        """
        if not isinstance(state, torch.Tensor):
            state = torch.FloatTensor(state).to(self.device)

        # one-hot code
        if len(state.shape) == 1:
            state = F.one_hot(torch.tensor(state.argmax()), num_classes=16).float()

        shared_features = self.shared_layers(state)

        policy_logits = self.policy_head(shared_features)
        value = self.value_head(shared_features)

        return policy_logits, value

    def get_value(self, state):
        """
        short cut for BFS
        """
        with torch.no_grad():
            _, value = self.forward(state)
            return value.item()

    def save(self, filepath):
        torch.save(self.state_dict(), filepath)
        print(f"Model saved to {filepath}")

    def load(self, filepath):
        self.load_state_dict(torch.load(filepath, map_location=self.device))
        print(f"Model loaded from {filepath}")

# 测试代码
if __name__ == "__main__":
    # 创建网络实例
    net = SharedNetwork()

    # 测试单个状态
    state = torch.zeros(16)
    state[0] = 1  # 假设在起始位置
    policy, value = net(state)

    print("Policy logits:", policy)
    print("Value:", value)

    # 测试批量状态
    batch_states = torch.zeros(4, 16)  # 4个状态的batch
    batch_states[0][0] = 1
    batch_states[1][1] = 1
    batch_states[2][2] = 1
    batch_states[3][3] = 1

    batch_policy, batch_value = net(batch_states)
    print("\nBatch policy:", batch_policy.shape)
    print("Batch value:", batch_value.shape)

Policy logits: tensor([1.4125, 0.2342, 0.5599, 0.1235], grad_fn=<ViewBackward0>)
Value: tensor([0.6734], grad_fn=<ViewBackward0>)

Batch policy: torch.Size([4, 4])
Batch value: torch.Size([4, 1])


  state = F.one_hot(torch.tensor(state.argmax()), num_classes=16).float()


###2.MCTS

In [None]:
from dataclasses import dataclass, field
from typing import Dict, Optional
import numpy as np

@dataclass
class Node:
  prior: float      # P(s,a)
  action_taken: Optional[int]
  state: Optional[int]
  visit_count: int = 0
  value_sum: float = 0
  parent: Optional['Node'] = None
  children: Dict[int, 'Node'] = field(default_factory=dict)
  done: bool = False

In [None]:
import math
import numpy as np
from dataclasses import dataclass, field
from typing import List, Tuple, Dict, Optional
import gymnasium as gym

@dataclass
class Node:
    prior: float  # P(s,a)
    action_taken: Optional[int]  # 到达此节点采取的动作
    visit_count: int = 0  # N(s,a)
    value_sum: float = 0  # Q=value_sum/visit_count
    parent: Optional['Node'] = None
    children: Dict[int, 'Node'] = field(default_factory=dict)
    done = False
    has_children = False

    def __post_init__(self):
        if self.children is None:
            self.children = {}

    @property
    def Q_value(self) -> float:
        if self.visit_count == 0:
            return 0.0
        return self.value_sum / self.visit_count

class MCTS:
    def __init__(self, model, num_simulations: int = 100, epoch_of_training=0):
        self.model = model
        self.num_simulations = num_simulations
        # UCB参数，随着训练进行逐渐减小
        self.c_puct = max(1.0 * (1 - epoch_of_training/100), 0.1)
        # Dirichlet噪声参数，随训练进行逐渐减小
        self.epsilon = max(0.25 * (1 - epoch_of_training/100), 0.01)
        self.env = gym.make('FrozenLake-v1', is_slippery=False)

    def predict_next_state(self, current_state, action):

      # 克隆环境，以避免影响原始环境
      env_copy = gym.make('FrozenLake-v1', render_mode=None)
      env_copy.reset(seed=42)

      # 设置环境状态
      env_copy.unwrapped.s = current_state

      # 执行动作并获取下一个状态
      next_state, reward, terminated, truncated, info = env_copy.step(action)

      return next_state
    def search(self, root_state) -> np.ndarray:
        """执行MCTS搜索"""
        root = Node(prior=1.0, action_taken=None)

        for _ in range(self.num_simulations):
            node = root
            state = root_state
            search_path = [node]
            done = False

            # 1. Selection
            while node.has_children and not done:
              action, node = self.select_child(node)
              # 执行动作
              self.env.reset()
              self.env.unwrapped.s = state
              state, reward, terminated, truncated, _ = self.env.step(action)
              done = terminated or truncated
              node.done = done
              search_path.append(node)



            # 2. Expansion and Evaluation
            value = 0
            if not done:
                policy, value = self.evaluate_state(state)
                self.expand_node(node, policy)

            # 3. Backup
            self.backup(search_path, value)

        # 返回根节点的动作概率分布
        return root

    def select_child(self, node: Node) -> Tuple[int, Node]:
        """使用UCB公式选择最佳子节点"""
        best_score = -float('inf')
        best_action = -1
        best_child = None

        sqrt_total_count = math.sqrt(node.visit_count)

        for action, child in node.children.items():
            # UCB公式
            Q = child.Q_value
            U = self.c_puct * child.prior * sqrt_total_count / (1 + child.visit_count)
            ucb_score = Q + U

            if ucb_score > best_score:
                best_score = ucb_score
                best_action = action
                best_child = child

        return best_action, best_child

    def evaluate_state(self, state) -> Tuple[np.ndarray, float]:
        """使用神经网络评估状态"""
        # 转换状态为one-hot编码
        state_onehot = np.zeros(16)
        state_onehot[state] = 1

        with torch.no_grad():
            policy, value = self.model(state_onehot)
            policy = F.softmax(policy, dim=0)
            #return policy.cpu().numpy(), value.item()
        #+++++++++++++++++#
        new_policy = []

        with torch.no_grad():
          for a in range(4):
            state= self.predict_next_state(state,a)
            state_onehot = np.zeros(16)
            state_onehot[state] = 1
            _, value_p = self.model(state_onehot)
            new_policy.append(value_p.item())

        return np.array(new_policy)/np.sum(new_policy), value.item()
    def expand_node(self, node: Node, policy: np.ndarray):
        """扩展节点，添加所有可能的子节点"""
        # 添加Dirichlet噪声
        noise = np.random.dirichlet([0.3] * len(policy))
        policy = (1 - self.epsilon) * policy + self.epsilon * noise


        # 创建子节点
        for action, prob in enumerate(policy):
            child = Node(
                prior=prob,
                action_taken=action,
                parent=node
            )
            node.children[action] = child

        node.has_children = True

    def backup(self, search_path: List[Node], value: float):
        """反向传播更新节点统计信息"""
        for node in reversed(search_path):
            node.visit_count += 1
            node.value_sum += value

    def get_action_probs(self, root: Node, temperature=0.1) -> np.ndarray:
        """获取动作概率分布"""
        counts = np.array([child.visit_count for child in root.children.values()])
        probs = counts / np.sum(counts)

        return probs

In [None]:
def self_play_mcts(replay_buffer, model, epoch_of_training=0):
   env = gym.make('FrozenLake-v1', is_slippery=False)
   model.eval()

   observation,_ = env.reset()
   done = False
   episode_reward = 0

   trajectory = []
   while not done:
       print('observation',observation)
       mcts = MCTS(model)
       root = mcts.search(observation)
       action_probs = mcts.get_action_probs(root, temperature=1)
       action = np.argmax(action_probs)
       print('action',action)
       #print('action_probs',action_probs)
       next_observation, reward, terminated, truncated, info = env.step(action)
       print('next_observation',next_observation)

       episode_reward += reward
       done = terminated or truncated

       trajectory.append({
           'state': observation,
           'action': action,
           'action_probs': action_probs
       })
       observation = next_observation

   # Store episode with decaying rewards
   for idx, t in enumerate(trajectory):
       decay_reward = episode_reward * (0.95 ** (len(trajectory) - idx - 1))
       replay_buffer.store(
           t['state'],
           t['action'],
           t['action_probs'],
           decay_reward
       )

   print('Episode reward:', episode_reward)
   return episode_reward

In [None]:
import gymnasium as gym
import torch
import numpy as np

# 初始化组件
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SharedNetwork().to(device)
replay_buffer = ReplayBuffer(batch_size=32, minimum_size=100)
reward = self_play_mcts(replay_buffer, model)



observation 0


  state = F.one_hot(torch.tensor(state.argmax()), num_classes=16).float()


action 2
next_observation 1
observation 1
action 1
next_observation 5
Episode reward: 0.0


####2.1 test test


In [None]:
import torch
import numpy as np
import gymnasium as gym

def test_mcts():
    # 1. 初始化环境和网络
    env = gym.make('FrozenLake-v1', is_slippery=False)
    shared_net = SharedNetwork()
    mcts = MCTS(shared_net, num_simulations=100)

    # 2. 运行一个完整的episode
    state, _ = env.reset()
    done = False
    total_reward = 0
    path = []

    print("Starting MCTS test...")
    print(f"Initial state: {state}")

    while not done:
        # 将状态转换为one-hot编码
        state_onehot = np.zeros(16)
        state_onehot[state] = 1

        # 执行MCTS搜索
        root = mcts.search(state)
        action_probs = mcts.get_action_probs(root, temperature=0.1)

        # 选择动作
        action = np.argmax(action_probs)

        # 打印当前状态和选择的动作
        action_names = ['LEFT', 'DOWN', 'RIGHT', 'UP']
        print(f"\nCurrent state: {state}")
        print(f"Chosen action: {action_names[action]}")
        print(f"Action probabilities: {action_probs}")

        # 查看根节点的统计信息
        print("\nRoot node statistics:")
        for a, child in root.children.items():
            print(f"Action {action_names[a]}: Visits = {child.visit_count}, "
                  f"Q-value = {child.Q_value:.3f}")

        # 执行动作
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        total_reward += reward

        # 记录路径
        path.append((state, action))
        state = next_state

        if done:
            print(f"\nEpisode finished!")
            print(f"Final state: {state}")
            print(f"Total reward: {total_reward}")
            if reward == 1:
                print("Successfully reached the goal!")
            else:
                print("Failed - either fell in a hole or exceeded steps")

    # 打印完整路径
    print("\nComplete path:")
    for step, (s, a) in enumerate(path):
        print(f"Step {step}: State {s} -> Action {action_names[a]}")

    return total_reward, path

if __name__ == "__main__":
    # 设置随机种子以保证可重复性
    np.random.seed(42)
    torch.manual_seed(42)

    # 运行测试
    reward, path = test_mcts()

    # 可以多次运行测试来查看稳定性
    print("\nRunning multiple episodes to check stability...")
    rewards = []
    for i in range(5):
        reward, _ = test_mcts()
        rewards.append(reward)

    print(f"\nResults over 5 episodes:")
    print(f"Average reward: {np.mean(rewards):.2f}")
    print(f"Rewards: {rewards}")

Starting MCTS test...
Initial state: 0

Current state: 0
Chosen action: DOWN
Action probabilities: [2.41369954e-03 9.95171090e-01 1.51123394e-06 2.41369954e-03]

Root node statistics:
Action LEFT: Visits = 23, Q-value = 0.264
Action DOWN: Visits = 42, Q-value = 0.157
Action RIGHT: Visits = 11, Q-value = 0.203
Action UP: Visits = 23, Q-value = 0.265

Current state: 4
Chosen action: UP
Action probabilities: [1.11014813e-02 1.08412903e-05 5.31066349e-09 9.88887672e-01]

Root node statistics:
Action LEFT: Visits = 30, Q-value = 0.220
Action DOWN: Visits = 15, Q-value = 0.084
Action RIGHT: Visits = 7, Q-value = 0.000
Action UP: Visits = 47, Q-value = 0.234

Current state: 0
Chosen action: LEFT
Action probabilities: [7.58328683e-01 2.41534761e-01 9.09395568e-05 4.56163571e-05]

Root node statistics:
Action LEFT: Visits = 37, Q-value = 0.257
Action DOWN: Visits = 33, Q-value = 0.184
Action RIGHT: Visits = 15, Q-value = 0.243
Action UP: Visits = 14, Q-value = 0.235


  state = F.one_hot(torch.tensor(state.argmax()), num_classes=16).float()


[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m

Current state: 4
Chosen action: UP
Action probabilities: [4.84405121e-06 1.33625446e-10 4.96030844e-13 9.99995156e-01]

Root node statistics:
Action LEFT: Visits = 20, Q-value = 0.149
Action DOWN: Visits = 7, Q-value = 0.065
Action RIGHT: Visits = 4, Q-value = 0.000
Action UP: Visits = 68, Q-value = 0.214

Current state: 0
Chosen action: DOWN
Action probabilities: [3.24747216e-06 9.99286984e-01 6.25660478e-04 8.41079339e-05]

Root node statistics:
Action LEFT: Visits = 13, Q-value = 0.176
Action DOWN: Visits = 46, Q-value = 0.164
Action RIGHT: Visits = 22, Q-value = 0.229
Action UP: Visits = 18, Q-value = 0.226

Current state: 4
Chosen action: UP
Action probabilities: [8.12189999e-04 5.43817925e-07 4.13637504e-12 9.99187266e-01]

Root node statistics:
Action LEFT: Visits = 27, Q-value = 0.126
Action DOWN: Visits = 13, Q-value = 0.106
Action RIGHT: Visits = 4, Q-value = 0.000
Action UP: Visits = 55, Q-value = 0.231

Current state: 0
Chosen actio

###3.BFS

In [None]:
from queue import PriorityQueue
import numpy as np
import torch
import gymnasium as gym

class PathNode:
    def __init__(self, state, action=None, parent=None, is_terminated=False, is_hole=False):
        self.state = state        # 当前状态
        self.action = action      # 到达此状态的动作
        self.parent = parent      # 前一个节点
        self.children = []        # 后续可能的路径
        self.is_terminated = is_terminated
        self.is_hole = is_hole    # 是否是陷阱

    def get_path(self):
        """回溯完整路径"""
        path = []
        current = self
        while current:
            path.append((current.state, current.action))
            current = current.parent
        return list(reversed(path))

    def add_child(self, child_node):
        self.children.append(child_node)

    def __lt__(self, other):
        return False

class PathFinder:
    def __init__(self, shared_network, grid_size=4):
        self.grid_size = grid_size
        self.goal = grid_size * grid_size - 1  # 最后一个格子
        self.env = gym.make('FrozenLake-v1', is_slippery=False)
        self.shared_network = shared_network

    def get_valid_actions(self, state):
        """获取有效动作"""
        row = state // self.grid_size    # 行号
        col = state % self.grid_size     # 列号
        actions = []

        # 检查四个方向的有效性
        if col > 0: actions.append(0)    # 左
        if row < 3: actions.append(1)    # 下
        if col < 3: actions.append(2)    # 右
        if row > 0: actions.append(3)    # 上

        return actions


    def get_state_value(self, state):
        """使用神经网络获取状态价值"""
        state_onehot = torch.zeros(16)
        state_onehot[state] = 1
        return self.shared_network.get_value(state_onehot)

    def best_first_search(self, start_state, max_path_length=20):
        """使用启发式搜索寻找最佳路径"""
        frontier = PriorityQueue()

        # 优先级是神经网络预测的价值和曼哈顿距离的组合
        value = self.get_state_value(start_state)
        initial_priority = -value

        root = PathNode(start_state)
        frontier.put((initial_priority, start_state, root))

        # 记录不同类型的路径
        all_paths = {
            'success': [],  # 成功到达目标的路径
            'hole': [],    # 掉进洞里的路径
            'timeout': []  # 超过最大长度的路径
        }

        def get_path_states(node):
            """获取路径上的所有状态"""
            states = set()
            current = node
            while current:
                states.add(current.state)
                current = current.parent
            return states

        def get_path_length(node):
            """获取路径长度"""
            length = 0
            current = node
            while current.parent:
                length += 1
                current = current.parent
            return length

        while not frontier.empty():
            priority, _, current_node = frontier.get()
            current_path_states = get_path_states(current_node)
            current_length = get_path_length(current_node)

            # 检查是否超过最大路径长度
            if current_length >= max_path_length:
                all_paths['timeout'].append(current_node.get_path())
                continue

            # 检查当前状态
            if current_node.is_hole:
                all_paths['hole'].append(current_node.get_path())
                continue

            if current_node.state == self.goal:
                all_paths['success'].append(current_node.get_path())


            # 展开当前节点
            for action in self.get_valid_actions(current_node.state):
                self.env.reset()
                self.env.unwrapped.s = current_node.state
                next_state, reward, done, truncated, _ = self.env.step(action)

                is_hole = done and reward == 0

                if next_state not in current_path_states:
                    new_node = PathNode(
                        state=next_state,
                        action=action,
                        parent=current_node,
                        is_hole=is_hole,
                    )
                    current_node.add_child(new_node)

                    # 使用神经网络预测的价值作为启发
                    value = self.get_state_value(next_state)

                    new_priority = -value

                    frontier.put((new_priority, next_state, new_node))

        return all_paths, root

    def print_path(self, path):
        """可视化路径"""
        if not path:
            print("No path found!")
            return

        actions_map = {0: "←", 1: "↓", 2: "→", 3: "↑"}
        for state, action in path:
            if action is not None:
                print(f"State: {state}, Action: {actions_map[action]}")
            else:
                print(f"Start State: {state}")

####3.1 test

In [None]:
import torch
import gymnasium as gym
import numpy as np
import time


def test_pathfinder():
    # 1. 初始化环境和网络
    env = gym.make('FrozenLake-v1', is_slippery=False)
    shared_network = SharedNetwork(input_size=16, hidden_size=128, num_actions=4)
    pathfinder = PathFinder(shared_network)

    # 2. 从起始状态开始搜索
    start_state = 0  # FrozenLake的起始位置
    print(f"\nStarting search from state {start_state}")

    # 3. 记录开始时间
    start_time = time.time()

    # 4. 执行搜索
    all_paths, root = pathfinder.best_first_search(start_state)

    # 5. 计算搜索时间
    end_time = time.time()
    search_time = end_time - start_time

    # 6. 打印搜索结果
    print(f"\nSearch completed in {search_time:.4f} seconds")

    # 打印成功的路径
    for i in all_paths['success']:
        print("\nSuccessful path found!")
        print("Path details:")
        pathfinder.print_path(i)
        path_length = len(i) - 1  # 减去起始状态
        print(f"Path length: {path_length}")
    else:
        print("\nNo successful path found")

    # 打印统计信息
    print("\nSearch statistics:")
    print(f"Number of success paths: {len(all_paths['success'])}")
    print(f"Number of hole paths: {len(all_paths['hole'])}")
    print(f"Number of timeout paths: {len(all_paths['timeout'])}")

    # 7. 验证路径的可行性
    if all_paths['success']:
        print("\nVerifying path...")
        env.reset()
        path = all_paths['success'][0]
        total_reward = 0

        for state, action in path[1:]:  # 跳过起始状态
            if action is not None:
                next_state, reward, done, truncated, _ = env.step(action)
                total_reward += reward
                if done:
                    if reward == 1:
                        print("Successfully reached the goal!")
                    else:
                        print("Failed - fell into a hole!")
                    break

        print(f"Total reward: {total_reward}")

if __name__ == "__main__":
    test_pathfinder()

    # 额外测试：测试不同起始状态
    env = gym.make('FrozenLake-v1', is_slippery=False)
    shared_network = SharedNetwork(input_size=16, hidden_size=128, num_actions=4)
    pathfinder = PathFinder(shared_network)


    state =0

    all_paths, _ = pathfinder.best_first_search(state)



Starting search from state 0

Search completed in 0.0328 seconds

Successful path found!
Path details:
Start State: 0
State: 1, Action: →
State: 2, Action: →
State: 6, Action: ↓
State: 10, Action: ↓
State: 14, Action: ↓
State: 15, Action: →
Path length: 6

Successful path found!
Path details:
Start State: 0
State: 1, Action: →
State: 2, Action: →
State: 6, Action: ↓
State: 10, Action: ↓
State: 9, Action: ←
State: 13, Action: ↓
State: 14, Action: →
State: 15, Action: →
Path length: 8

Successful path found!
Path details:
Start State: 0
State: 4, Action: ↓
State: 8, Action: ↓
State: 9, Action: →
State: 13, Action: ↓
State: 14, Action: →
State: 15, Action: →
Path length: 6

Successful path found!
Path details:
Start State: 0
State: 4, Action: ↓
State: 8, Action: ↓
State: 9, Action: →
State: 10, Action: →
State: 14, Action: ↓
State: 15, Action: →
Path length: 6

No successful path found

Search statistics:
Number of success paths: 4
Number of hole paths: 28
Number of timeout paths: 0

Ver

  state = F.one_hot(torch.tensor(state.argmax()), num_classes=16).float()


###4.HYBRID


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Network(nn.Module):
    def __init__(self, input_size=16, hidden_size=128, num_actions=4):
        super(Network, self).__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # 共享层
        self.shared_layers = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size)
        )

        # Policy head
        self.policy_head = nn.Sequential(
            nn.Linear(hidden_size, 64),
            nn.ReLU(),
            nn.LayerNorm(64),
            nn.Linear(64, num_actions)
        )

        # Value head
        self.value_head = nn.Sequential(
            nn.Linear(hidden_size, 64),
            nn.ReLU(),
            nn.LayerNorm(64),
            nn.Linear(64, 1)

        )

    def forward(self, state):
        """
        参数:
            state: 游戏状态的one-hot编码 (batch_size, 16)
        返回:
            policy_logits: 动作概率的对数 (batch_size, 4)
            value: 状态价值估计 (batch_size, 1)
        """
        shared_features = self.shared_layers(state)

        policy_logits = self.policy_head(shared_features)
        value = self.value_head(shared_features)

        return policy_logits, value

    def get_value(self, state):
        """
        short cut for BFS
        """
        with torch.no_grad():
            _, value = self.forward(state)
            return value.item()

    def save(self, filepath):
        torch.save(self.state_dict(), filepath)
        print(f"Model saved to {filepath}")

    def load(self, filepath):
        self.load_state_dict(torch.load(filepath, map_location=self.device))
        print(f"Model loaded from {filepath}")


# 2. 初始化主要组件


In [None]:
import math
import numpy as np
import torch
import torch.nn.functional as F
from dataclasses import dataclass
from typing import List, Optional, Dict
from queue import PriorityQueue

@dataclass
class Node:
    prior: float
    action_taken: int
    visit_count: int = 0
    value_sum: float = 0
    state: int = 0
    parent: Optional['Node'] = None
    children: Dict[int, 'Node'] = None
    done = False
    has_children = False

    def __post_init__(self):
        if self.children is None:
            self.children = {}

    @property
    def Q_value(self) -> float:
        if self.visit_count == 0:
            return 0.0
        return self.value_sum / self.visit_count
    def __lt__(self, other):
      # 任意返回 False，因为我们只关心优先级的比较
      return False

class HybridMCTS:
    def __init__(self, model, num_simulations: int = 100, local_search_depth: int = 2):
        self.model = model
        self.num_simulations = num_simulations
        self.c_puct = 2.0
        self.local_search_depth = local_search_depth
        self.epsilon = 0.4


    def bfs_search(self, start_state, depth_limit=3):
        """使用真实环境进行局部BFS搜索"""
        env = gym.make('FrozenLake-v1', is_slippery=False)
        path_node = Node(prior=1.0, action_taken=None,state = start_state)
        #print('start_state',start_state)

        queue = PriorityQueue()
        queue.put((0, path_node))

        def get_path_states(node):
          # 获取当前路径上的所有状态
          states = set()
          current = node
          while current:
              states.add(current.state)
              current = current.parent
              #print('记录的',current.state)
          return states
        def get_action_path(node):
          # 回溯完整路径
          path_action = []
          current = node
          while current.parent:
            path_action.append(current.action_taken)
            #print('current.state',current.state)
            current = current.parent
          return list(reversed(path_action))
        def get_path(node):
          # 回溯完整路径
          path = []
          current = node
          while current:
            path.append(current.state)
            current = current.parent
          return path
        while not queue.empty():

          priority, current_node = queue.get()
          visited = get_path_states(current_node)
          # 如果达到深度限制
          if len(get_path(current_node)) > depth_limit:
              continue
          #reach to gaol
          if current_node.state == 15:
            return get_action_path(current_node), 1
          # 获取当前状态下的有效动作
          row = current_node.state // 4
          col = current_node.state % 4
          valid_actions = []
          if col > 0: valid_actions.append(0)    # 左
          if row < 3: valid_actions.append(1)    # 下
          if col < 3: valid_actions.append(2)    # 右
          if row > 0: valid_actions.append(3)    # 上

          for action in valid_actions:
              # 使用环境模拟动作
              env.reset()
              env.unwrapped.s = current_node.state
              #assert env.unwrapped.s == current_state, f"State mismatch: expected {current_state}, got {env.unwrapped.s}"
              next_state, reward, terminated, truncated, _ = env.step(action)
              done = terminated or truncated
              # 如果是合法的下一个状态
              if next_state not in visited:

                  # 使用价值网络评估优先级
                  _, value = self.evaluate_state(next_state)
                  next_node = Node(prior=1.0,
                            action_taken=action,
                            state = next_state,
                            parent=current_node)
                  current_node.children[action] = next_node
                  #print('value',value)
                  queue.put((-value, next_node))

        return None, 0

    def search(self, root_state):
      root = Node(prior=1.0, action_taken=None)

      for times in range(self.num_simulations):
        node = root
        env = gym.make('FrozenLake-v1', is_slippery=False)
        env.reset()
        state = root_state
        env.unwrapped.s = state  # 设置环境状态

        search_path = [node]
        value = 0
        switched_to_bfs = False

        #BFS
        if self.should_use_bfs(state):
          best_path, path_value = self.bfs_search(state, self.local_search_depth)
          if best_path:
            first_action = best_path[0]
            policy = np.ones(4) * 0.01
            policy[first_action] = 0.7
            self.expand_node(node, policy)
            search_path.append(node.children[first_action])
            value = path_value * 0.98 ** (len(best_path)-1)
            switched_to_bfs = True
        #print('Alphazero')
        if not switched_to_bfs:
          while node.has_children and not node.done:
            action, node = self.select_child(node, state)
            observation, reward, terminated, truncated, info = env.step(action)
            node.done = terminated or truncated
            state = observation
            search_path.append(node)

            # 在selection过程中检查是否可以切换到BFS
            if not node.done and self.should_use_bfs(state):
              best_path, path_value = self.bfs_search(state, self.local_search_depth)
              if best_path:
                # 找到BFS路径，记录value并结束selection
                value = path_value * 0.98 ** (len(best_path))
                switched_to_bfs = True
                break

          # 如果没有切换到BFS且没有结束，进行常规expansion
          if not switched_to_bfs and not node.done:
              policy, value = self.evaluate_state(state)
              self.expand_node(node, policy)
        '''# 1. Selection
        while node.has_children and not node.done:
          action, node = self.select_child(node,state)
          #print('action',action)
          observation, reward, terminated, truncated, info = env.step(action)
          #print('observation',observation)
          node.done = terminated or truncated
          state = observation
          search_path.append(node)


        # 2. Expansion and Evaluation

        if not node.done:
          policy, value = self.evaluate_state(state)
          self.expand_node(node, policy)'''

        # 3. Backup
        self.backup(search_path, value)

      return root

    def should_use_bfs(self, state):
        """判断是否应该使用BFS"""
        # 可以基于以下因素:
        # 1. 到目标的估计距离
        # 2. 状态的复杂度
        # 3. 当前的计算资源
        distance_to_goal = self.estimate_distance(state)
        #print('distance_to_goal',distance_to_goal)
        return distance_to_goal <= self.local_search_depth

    def estimate_distance(self, state):
      """估计当前状态到目标的曼哈顿距离"""
      # FrozenLake 环境中，目标在右下角 (3,3)，状态编号为 15
      current_row = state // 4  # 当前行
      current_col = state % 4   # 当前列
      goal_row = 3  # 目标行
      goal_col = 3  # 目标列

      manhattan_dist = abs(current_row - goal_row) + abs(current_col - goal_col)
      return manhattan_dist

    def select_child(self, node, state):
        def is_valid_move(state=0):
          row = state // 4
          col = state % 4
          valid_actions = []
          if col > 0: valid_actions.append(0)    # 左
          if row < 3: valid_actions.append(1)    # 下
          if col < 3: valid_actions.append(2)    # 右
          if row > 0: valid_actions.append(3)
          return valid_actions
        best_score = -float('inf')
        best_action = -1
        best_child = None

        sqrt_total_count = math.sqrt(node.visit_count)

        for action, child in node.children.items():
          if action not in is_valid_move(state):
            continue
          ucb_score = child.Q_value + self.c_puct * child.prior * sqrt_total_count / (1 + child.visit_count)
          if ucb_score > best_score:
            best_score = ucb_score
            best_action = action
            best_child = child


        return best_action, best_child

    def expand_node(self, node: Node, policy: np.ndarray):

      noise = np.random.dirichlet([0.3] * len(policy))
      #policy = 0.75 * policy + 0.25 * noise


      policy = (1 - self.epsilon) * policy + self.epsilon * noise

      for action, prob in enumerate(policy):

        child = Node(
            prior=prob,
            action_taken=action,
            parent=node
        )

        node.children[action] = child

        node.has_children = True

    def evaluate_state(self, state):
        """使用神经网络评估状态"""
        with torch.no_grad():
          if isinstance(state, (int, np.integer)):
            # 如果是整数状态，直接转one-hot
            state_tensor = F.one_hot(torch.tensor(state), num_classes=16).float()
          else:
              # 如果是其他格式，先转换为tensor
              state_tensor = torch.FloatTensor(state)
          policy, value = self.model(state_tensor)
          return F.softmax(policy, dim=-1), value.item()



    def backup1(self, search_path: List[Node], value: float):
      """带长度惩罚的回传更新"""
      # 基础参数
      gamma = 0.95  # 折扣因子
      base_length = 8  # 基准长度
      length_penalty = 0.1  # 长度惩罚系数
      path_length = len(search_path)
      # 计算长度惩罚
      if path_length > base_length:
          # 超过基准长度的部分施加惩罚
          length_multiplier = 1.0 - length_penalty * ((path_length - base_length) / base_length)
          # 确保惩罚不会太过严重
          length_multiplier = max(0.2, length_multiplier)
          current_value = value * length_multiplier
      else:
          current_value = value

      # 从后往前更新
      for idx, node in enumerate(reversed(search_path)):
          # 基础更新
          node.visit_count += 1

          # 使用折扣更新价值
          discounted_value = current_value * (gamma ** idx)
          node.value_sum += discounted_value
          #print('node.value_sum',node.value_sum)
    def backup(self, search_path: List[Node], value: float):
        """回传更新"""
        gamma = 0.98
        current_value = value
        for node in reversed(search_path):
            node.visit_count += 1
            node.value_sum += current_value
            current_value *= gamma


    def get_action_probs(self, root: Node, temperature: float = 1.0):
        """获取动作概率"""
        counts = np.array([child.visit_count for child in root.children.values()])
        if temperature == 0:
            probs = np.zeros_like(counts)
            probs[np.argmax(counts)] = 1
            return probs
        else:
            counts = counts ** (1.0 / temperature)
            probs = counts / np.sum(counts)
            return probs

In [None]:
def bfs_search(start_state, depth_limit=3):
        """使用真实环境进行局部BFS搜索"""
        env = gym.make('FrozenLake-v1', is_slippery=False)
        path_node = Node(prior=1.0, action_taken=None,state = start_state)
        #print('start_state',start_state)

        queue = PriorityQueue()
        queue.put((0, path_node))

        def get_path_states(node):
          # 获取当前路径上的所有状态
          states = set()
          current = node
          while current:
              states.add(current.state)
              current = current.parent
              #print('记录的',current.state)
          return states
        def get_action_path(node):
          # 回溯完整路径
          path_action = []
          current = node
          while current.parent:
            path_action.append(current.action_taken)
            #print('current.state',current.state)
            current = current.parent
          return list(reversed(path_action))
        def get_path(node):
          # 回溯完整路径
          path = []
          current = node
          while current:
            path.append(current.state)
            current = current.parent
          return path
        while not queue.empty():

          priority, current_node = queue.get()
          visited = get_path_states(current_node)

          #reach to gaol
          if current_node.state == 15:
            return get_action_path(current_node), 1
          if len(get_path(current_node)) >= depth_limit:
              continue
          # 获取当前状态下的有效动作
          row = current_node.state // 4
          col = current_node.state % 4
          valid_actions = []
          if col > 0: valid_actions.append(0)    # 左
          if row < 3: valid_actions.append(1)    # 下
          if col < 3: valid_actions.append(2)    # 右
          if row > 0: valid_actions.append(3)    # 上

          for action in valid_actions:
              # 使用环境模拟动作

              env.reset()
              env.unwrapped.s = current_node.state
              print('current_node.state',current_node.state)
              print('action',action)
              #assert env.unwrapped.s == current_state, f"State mismatch: expected {current_state}, got {env.unwrapped.s}"
              next_state, reward, terminated, truncated, _ = env.step(action)
              done = terminated or truncated
              print('next_state',next_state)
              # 如果是合法的下一个状态
              if next_state not in visited:
                  print('next,next_state',next_state)
                  net=Network()
                  with torch.no_grad():
                    if isinstance(next_state, (int, np.integer)):
                      # 如果是整数状态，直接转one-hot
                      state_tensor = F.one_hot(torch.tensor(next_state), num_classes=16).float()
                    else:
                        # 如果是其他格式，先转换为tensor
                        state_tensor = torch.FloatTensor(next_state)
                  value=net.get_value(state_tensor)
                  print('value',value)
                  # 使用价值网络评估优先级

                  next_node = Node(prior=1.0,
                            action_taken=action,
                            state = next_state,
                            parent=current_node)
                  current_node.children[action] = next_node
                  #print('value',value)
                  queue.put((-value, next_node))

        return None, 0
bfs_search(10, depth_limit=4)

current_node.state 10
action 0
next_state 9
next,next_state 9
value 0.4209800064563751
current_node.state 10
action 1
next_state 14
next,next_state 14
value -0.3975231647491455
current_node.state 10
action 2
next_state 11
next,next_state 11
value -0.4413963556289673
current_node.state 10
action 3
next_state 6
next,next_state 6
value 0.038615211844444275
current_node.state 9
action 0
next_state 8
next,next_state 8
value -0.4255306124687195
current_node.state 9
action 1
next_state 13
next,next_state 13
value 0.39316731691360474
current_node.state 9
action 2
next_state 10
current_node.state 9
action 3
next_state 5
next,next_state 5
value -0.27908843755722046
current_node.state 13
action 0
next_state 12
next,next_state 12
value 0.614520788192749
current_node.state 13
action 2
next_state 14
next,next_state 14
value 0.6405401825904846
current_node.state 13
action 3
next_state 9
current_node.state 6
action 0
next_state 5
next,next_state 5
value -0.08646199852228165
current_node.state 6
action

([1, 2], 1)

In [None]:
import gymnasium as gym
import torch
import torch.nn.functional as F
import numpy as np
def analyze_mcts_search(model, num_simulations=100):
    """分析MCTS搜索过程中的访问统计"""
    env = gym.make('FrozenLake-v1', is_slippery=False)
    mcts = HybridMCTS(model, num_simulations=num_simulations)
    action_names = ['LEFT', 'DOWN', 'RIGHT', 'UP']

    state, _ = env.reset()
    state = 9

    print(f"\nInitial state: {state}")

    # 执行MCTS搜索
    root = mcts.search(state)

    # 分析根节点的子节点
    print("\nRoot node analysis:")
    print(f"Total visits to root: {root.visit_count}")
    print(f"Root value: {root.Q_value:.4f}")
    print("\nChildren statistics:")
    print("Action | Visits | Q-Value | Prior | UCB Score")
    print("-" * 50)

    # 计算UCB分数用于比较
    sqrt_total_count = math.sqrt(root.visit_count)

    for action, child in root.children.items():
        ucb_score = child.Q_value + mcts.c_puct * child.prior * sqrt_total_count / (1 + child.visit_count)
        print(f"{action_names[action]:<6} | {child.visit_count:>6} | {child.Q_value:>7.4f} | {child.prior:>5.3f} | {ucb_score:>9.4f}")

    # 显示访问最多的前3条路径
    print("\nTop visited paths:")
    def get_most_visited_path(node, depth=0, max_depth=7):
        if depth >= max_depth or not node.children:
            if not node:
              print('no children in this node')
            return []

        most_visited_child = max(node.children.items(),
                               key=lambda x: x[1].visit_count)
        action, child = most_visited_child
        return [(action, child.visit_count, child.Q_value)] + get_most_visited_path(child, depth + 1)

    path = get_most_visited_path(root)
    for depth, (action, visits, value) in enumerate(path):
        print(f"Depth {depth}: Action={action_names[action]}, Visits={visits}, Value={value:.4f}")

    return root

# 使用示例
model = Network()
#model.load_state_dict(torch.load('best_model.pth'))  # 如果有保存的模型
model.eval()

# 分析MCTS搜索
root = analyze_mcts_search(model, num_simulations=100)
print(root.children)


Initial state: 9

Root node analysis:
Total visits to root: 100
Root value: 0.6705

Children statistics:
Action | Visits | Q-Value | Prior | UCB Score
--------------------------------------------------
LEFT   |      8 | -0.0284 | 0.379 |    0.8133
DOWN   |     18 |  0.7065 | 0.160 |    0.8745
RIGHT  |     68 |  0.8203 | 0.222 |    0.8848
UP     |      5 |  0.0000 | 0.239 |    0.7977

Top visited paths:
Depth 0: Action=RIGHT, Visits=68, Value=0.8203
Depth 1: Action=DOWN, Visits=57, Value=0.9800
{0: Node(prior=tensor(0.3787, dtype=torch.float64), action_taken=0, visit_count=8, value_sum=-0.22704900028705596, state=0, parent=Node(prior=1.0, action_taken=None, visit_count=100, value_sum=67.04786851206273, state=0, parent=None, children={...}), children={0: Node(prior=tensor(0.2679, dtype=torch.float64), action_taken=0, visit_count=0, value_sum=0, state=0, parent=..., children={}), 1: Node(prior=tensor(0.2946, dtype=torch.float64), action_taken=1, visit_count=3, value_sum=0, state=0, paren

###5.Replay_buffer

In [None]:
class GameHistory:
   def __init__(self):
       self.observations = []
       self.actions = []
       self.mcts_policies = []
       self.values = []

   def store(self, observation, action, action_probs, value):
       self.observations.append(observation)
       self.actions.append(action)
       self.mcts_policies.append(action_probs)
       self.values.append(value)

   def clear(self):
       self.observations.clear()
       self.actions.clear()
       self.mcts_policies.clear()
       self.values.clear()

class ReplayBuffer:
   def __init__(self, batch_size, minimum_size, capacity=500):
       self.batch_size = batch_size
       self.minimum_size = minimum_size
       self.capacity = capacity
       self.game_history = GameHistory()

   def store(self, observation, action, action_probs, value):
       if len(self.game_history.observations) >= self.capacity:
           self.game_history.observations.pop(0)
           self.game_history.actions.pop(0)
           self.game_history.mcts_policies.pop(0)
           self.game_history.values.pop(0)

       self.game_history.store(observation, action, action_probs, value)

   def sample_batch(self, batch_size=32):
       device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
       if batch_size is None:
           batch_size = self.batch_size

       indices = np.random.choice(len(self.game_history.observations), batch_size)

       observations = torch.tensor([self.game_history.observations[i] for i in indices], dtype=torch.float32)
       actions = torch.tensor([self.game_history.actions[i] for i in indices], dtype=torch.long)
       policies = torch.tensor([self.game_history.mcts_policies[i] for i in indices], dtype=torch.float32)
       values = torch.tensor([self.game_history.values[i] for i in indices], dtype=torch.float32)

       return observations.to(device), actions.to(device), policies.to(device), values.to(device)

   def __len__(self):
       return len(self.game_history.observations)

   def clear(self):
       self.game_history.clear()


### 6.Self_play

In [None]:
def self_play(replay_buffer, model, epoch_of_training=0):
   env = gym.make('FrozenLake-v1', is_slippery=False)
   model.eval()

   observation,_ = env.reset()
   done = False
   episode_reward = 0

   trajectory = []
   while not done:
       print('observation',observation)
       mcts = HybridMCTS(model)
       root = mcts.search(observation)
       action_probs = mcts.get_action_probs(root, temperature=1)
       action = np.argmax(action_probs)
       print('action',action)
       #print('action_probs',action_probs)
       next_observation, reward, terminated, truncated, info = env.step(action)
       print('next_observation',next_observation)

       episode_reward += reward
       done = terminated or truncated

       trajectory.append({
           'state': observation,
           'action': action,
           'action_probs': action_probs
       })
       observation = next_observation

   # Store episode with decaying rewards
   for idx, t in enumerate(trajectory):
       decay_reward = episode_reward * (0.95 ** (len(trajectory) - idx - 1))
       replay_buffer.store(
           t['state'],
           t['action'],
           t['action_probs'],
           decay_reward
       )

   print('Episode reward:', episode_reward)
   return episode_reward

In [None]:

import gymnasium as gym
import torch
import numpy as np

# 初始化组件
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Network().to(device)
replay_buffer = ReplayBuffer(batch_size=32, minimum_size=100)

# 运行一次self_play
reward = self_play(replay_buffer, model)

# 检查存储的数据

observations, actions, policies, values = replay_buffer.sample_batch(10)
print("Sample from buffer:")
print("States shape:", observations.shape)
print("Actions:", actions)
print("Policies shape:", policies.shape)
print("Values:", values)

observation 0
action 1
next_observation 4
observation 4
action 3
next_observation 0
observation 0
action 2
next_observation 1
observation 1
action 2
next_observation 2
observation 2
action 1
next_observation 6
observation 6
action 1
next_observation 10
observation 10
action 1
next_observation 14
observation 14
action 2
next_observation 15
Episode reward: 1.0
Sample from buffer:
States shape: torch.Size([10])
Actions: tensor([3, 1, 2, 2, 1, 2, 1, 2, 2, 1])
Policies shape: torch.Size([10, 4])
Values: tensor([0.7351, 0.6983, 0.8145, 1.0000, 0.6983, 0.8145, 0.9025, 0.8145, 0.8145,
        0.9500])


  policies = torch.tensor([self.game_history.mcts_policies[i] for i in indices], dtype=torch.float32)


###7.Training


In [None]:
def train(num_epochs,lr):
  # 创建环境

  model = Network()
  replay_buffer = ReplayBuffer(batch_size=32, minimum_size=200)
  optimizer = torch.optim.Adam(model.parameters(), lr)
  #  first self to fill some data into replay buffer

  get_one_time_success = 0
  while get_one_time_success == 0 :
    get_one_time_success=self_play(replay_buffer, model)

  for epoch in range(num_epochs):
    #self play
    print('epoch',epoch)
    self_play(replay_buffer, model)
    #some setting
    epoch_value_loss = 0
    epoch_policy_loss = 0
    num_batches = 4
    for _ in range(num_batches):
      # simple the data to get target value
      observations, actions, policies, values = replay_buffer.sample_batch(32)
      optimizer.zero_grad()
      # feed observations to model
      state_tensor =  torch.tensor(observations, dtype=torch.long)
      one_hot_state = F.one_hot(state_tensor, num_classes=16)
      one_hot_state = one_hot_state.float()
      model.train()
      policy_logits, pred_values = model(one_hot_state)
      # get the loss
      value_loss = F.mse_loss(pred_values.squeeze(), values)
      probs = F.softmax(policy_logits, dim=1)
      policy_loss = -torch.sum(policies * torch.log(probs))
      total_loss = value_loss + policy_loss
      # optimize
      total_loss.backward()
      torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
      optimizer.step()
      epoch_value_loss += value_loss.item()
      epoch_policy_loss += policy_loss.item()
    # 打印训练信息
    avg_value_loss = epoch_value_loss / num_batches
    avg_policy_loss = epoch_policy_loss / num_batches
    print(f"Value Loss: {avg_value_loss:.4f}, Policy Loss: {avg_policy_loss:.4f}")

  return model

In [None]:
model=train(10, lr=0.001)


observation 0
action 2
next_observation 1
observation 1
action 1
next_observation 5
Episode reward: 0.0
observation 0
action 1
next_observation 4
observation 4
action 1
next_observation 8
observation 8
action 2
next_observation 9
observation 9
action 1
next_observation 13
observation 13
action 2
next_observation 14
observation 14
action 2
next_observation 15
Episode reward: 1.0
epoch 0
observation 0
action 2
next_observation 1
observation 1
action 1
next_observation 5
Episode reward: 0.0
Value Loss: 0.3721, Policy Loss: 29.4700
epoch 1
observation 0


  state_tensor =  torch.tensor(observations, dtype=torch.long)


action 1
next_observation 4
observation 4
action 1
next_observation 8
observation 8
action 2
next_observation 9
observation 9
action 1
next_observation 13
observation 13
action 2
next_observation 14
observation 14
action 2
next_observation 15
Episode reward: 1.0
Value Loss: 0.0828, Policy Loss: 19.0769
epoch 2
observation 0
action 1
next_observation 4
observation 4
action 1
next_observation 8
observation 8
action 2
next_observation 9
observation 9
action 1
next_observation 13
observation 13
action 2
next_observation 14
observation 14
action 2
next_observation 15
Episode reward: 1.0
Value Loss: 0.0597, Policy Loss: 19.6469
epoch 3
observation 0
action 1
next_observation 4
observation 4
action 1
next_observation 8
observation 8
action 2
next_observation 9
observation 9
action 1
next_observation 13
observation 13
action 2
next_observation 14
observation 14
action 2
next_observation 15
Episode reward: 1.0
Value Loss: 0.0321, Policy Loss: 17.8768
epoch 4
observation 0
action 1
next_observat

In [None]:
# 初始化组件
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
replay_buffer = ReplayBuffer(batch_size=32, minimum_size=100)

# 运行一次self_play
reward = self_play(replay_buffer, model)

observation 0
action 2
next_observation 1
observation 1
action 2
next_observation 2
observation 2
action 1
next_observation 6
observation 6
action 1
next_observation 10
observation 10
action 1
next_observation 14
observation 14
action 2
next_observation 15
Episode reward: 1.0


In [None]:
def find_path_with_mcts(start_state=0, num_simulations=100, max_steps=50, model=Network()):
    """使用MCTS直接寻找从起点到终点的路径"""
    env = gym.make('FrozenLake-v1', is_slippery=False)

    mcts = HybridMCTS(model, num_simulations=num_simulations)
    action_names = ['LEFT', 'DOWN', 'RIGHT', 'UP']

    # 重置环境到起始状态
    env.reset()
    env.unwrapped.s = start_state
    state = start_state

    path = []
    total_reward = 0
    steps = 0

    print("\nFrozenLake地图:")
    desc = env.unwrapped.desc
    for row in desc:
        print("".join([c.decode('utf-8') for c in row]))

    print(f"\n开始寻路，起始状态: {start_state}")

    while steps < max_steps:
        # 执行MCTS搜索
        root = mcts.search(state)

        # 获取最佳动作
        action_probs = mcts.get_action_probs(root, temperature=1)
        action = np.argmax(action_probs)

        # 记录访问统计
        print(f"\n步骤 {steps + 1}:")
        print(f"当前状态: {state}")
        print(f"选择动作: {action_names[action]}")
        print("各动作访问次数:")
        for act, child in root.children.items():
            print(f"{action_names[act]}: {child.visit_count} 访问, Q值: {child.Q_value:.4f}")

        # 执行动作
        next_state, reward, terminated, truncated, _ = env.step(action)
        print(f"next_state: {next_state}")
        done = terminated or truncated

        # 更新路径信息
        path.append((state, action))
        total_reward += reward
        state = next_state
        steps += 1

        # 检查是否到达目标
        if done:
            path.append((state, None))  # 添加最终状态
            if reward > 0:
                print(f"\n成功找到路径! 总步数: {steps}")
                break
            else:
                print(f"\n失败! 掉入陷阱. 总步数: {steps}")
                path = None
                break

    # 可视化最终路径
    if path:
        print("\n找到的路径:")
        for i, (s, a) in enumerate(path):
            if a is not None:
                print(f"Step {i + 1}: State {s} -> Action {action_names[a]}")
            else:
                print(f"Final State: {s}")

        # 在网格上显示路径
        grid = [['.' for _ in range(4)] for _ in range(4)]
        for i, row in enumerate(env.unwrapped.desc):
            for j, cell in enumerate(row):
                grid[i][j] = cell.decode('utf-8')

        path_states = [p[0] for p in path]
        for state in path_states[:-1]:  # 除了最后一个状态
            row, col = state // 4, state % 4
            if grid[row][col] == 'F':  # 只在安全的格子上标记
                grid[row][col] = '*'

        print("\n路径可视化 (* 表示路径):")
        for row in grid:
            print(" ".join(row))

    return path, total_reward

# 使用示例
#model = Network()

path, reward = find_path_with_mcts(0, 100,10, model)
if path:
    print(f"\n总奖励: {reward}")


FrozenLake地图:
SFFF
FHFH
FFFH
HFFG

开始寻路，起始状态: 0

步骤 1:
当前状态: 0
选择动作: DOWN
各动作访问次数:
LEFT: 0 访问, Q值: 0.0000
DOWN: 80 访问, Q值: 0.7896
RIGHT: 19 访问, Q值: 0.5358
UP: 0 访问, Q值: 0.0000
next_state: 4

步骤 2:
当前状态: 4
选择动作: DOWN
各动作访问次数:
LEFT: 0 访问, Q值: 0.0000
DOWN: 92 访问, Q值: 0.8954
RIGHT: 1 访问, Q值: 0.0000
UP: 6 访问, Q值: 0.7989
next_state: 8

步骤 3:
当前状态: 8
选择动作: RIGHT
各动作访问次数:
LEFT: 0 访问, Q值: 0.0000
DOWN: 5 访问, Q值: 0.0000
RIGHT: 82 访问, Q值: 0.8970
UP: 12 访问, Q值: 0.8182
next_state: 9

步骤 4:
当前状态: 9
选择动作: DOWN
各动作访问次数:
LEFT: 17 访问, Q值: 0.7829
DOWN: 74 访问, Q值: 0.8930
RIGHT: 7 访问, Q值: 0.7725
UP: 1 访问, Q值: 0.0000
next_state: 13

步骤 5:
当前状态: 13
选择动作: RIGHT
各动作访问次数:
LEFT: 0 访问, Q值: 0.0000
DOWN: 0 访问, Q值: 0.0000
RIGHT: 93 访问, Q值: 0.9800
UP: 6 访问, Q值: 0.7928
next_state: 14

步骤 6:
当前状态: 14
选择动作: RIGHT
各动作访问次数:
LEFT: 0 访问, Q值: 0.0000
DOWN: 0 访问, Q值: 0.0000
RIGHT: 1 访问, Q值: 1.0000
UP: 0 访问, Q值: 0.0000
next_state: 15

成功找到路径! 总步数: 6

找到的路径:
Step 1: State 0 -> Action DOWN
Step 2: State 4 -> Action DOWN
Step 3: S

###8.some ideas

In [None]:
def train(num_epochs, lr, batch_size=32):
    model = Network()
    replay_buffer = ReplayBuffer(batch_size=batch_size, minimum_size=200)
    optimizer = torch.optim.Adam(model.parameters(), lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)

    best_episode_reward = 0

    # 初始填充经验池
    print("Filling replay buffer with initial experiences...")
    for i in range(4):
        episode_reward = self_play(replay_buffer, model)
        print(f"Initial episode {i+1}: reward = {episode_reward}")

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        # 1. Self-play阶段
        episode_reward = self_play(replay_buffer, model)
        print(f"Episode reward: {episode_reward}")

        # 2. 训练阶段
        epoch_value_loss = 0
        epoch_policy_loss = 0
        num_batches = 4  # 每个epoch训练多个batch

        model.train()
        for _ in range(num_batches):
            # 采样batch数据
            observations, actions, policies, values = replay_buffer.sample_batch(batch_size)

            # 转换状态表示
            state_tensor = torch.tensor(observations, dtype=torch.long)  # 确保是LongTensor
            one_hot_state = F.one_hot(state_tensor, num_classes=16).float()

            # 前向传播
            optimizer.zero_grad()
            policy_logits, pred_values = model(one_hot_state)

            # 计算损失
            value_loss = F.mse_loss(pred_values.squeeze(), values)
            probs = F.softmax(policy_logits, dim=1)
            policy_loss = -torch.sum(policies * torch.log_softmax(policy_logits, dim=1))

            # 添加L2正则化和熵正则化
            l2_reg = 0.001
            l2_loss = 0
            for param in model.parameters():
                l2_loss += torch.norm(param)

            policy_entropy = -(probs * torch.log(probs + 1e-10)).sum(1).mean()
            entropy_reg = 0.01

            total_loss = value_loss + policy_loss + l2_reg * l2_loss - entropy_reg * policy_entropy

            # 反向传播
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # 记录损失
            epoch_value_loss += value_loss.item()
            epoch_policy_loss += policy_loss.item()

        # 学习率调整
        scheduler.step()

        # 打印训练信息
        avg_value_loss = epoch_value_loss / num_batches
        avg_policy_loss = epoch_policy_loss / num_batches
        print(f"Value Loss: {avg_value_loss:.4f}, Policy Loss: {avg_policy_loss:.4f}")

        # 保存最佳模型
        if episode_reward > best_episode_reward:
            best_episode_reward = episode_reward
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"New best model saved with reward: {best_episode_reward}")

        # 每隔几个epoch进行评估
        if (epoch + 1) % 5 == 0:
            model.eval()
            eval_reward = evaluate(model)
            print(f"Evaluation reward: {eval_reward}")

    return model

def evaluate(model, num_episodes=5):
    """评估模型性能"""
    model.eval()
    env = gym.make('FrozenLake-v1', is_slippery=False)
    total_rewards = 0

    for episode in range(num_episodes):
        state, _ = env.reset()
        done = False
        episode_reward = 0

        while not done:
            with torch.no_grad():
                state_tensor = F.one_hot(torch.tensor(state), num_classes=16).float()
                policy_logits, _ = model(state_tensor)
                action = torch.argmax(policy_logits).item()

            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            episode_reward += reward
            state = next_state

        total_rewards += episode_reward

    return total_rewards / num_episodes

In [None]:
def setup_hybrid_mcts():
    # 创建环境
    env = gym.make('FrozenLake-v1', is_slippery=False)

    # 创建神经网络模型
    model = Network()

    # 创建HybridMCTS实例
    mcts = HybridMCTS(model=model, num_simulations=100, local_search_depth=3)

    # 添加环境到MCTS实例
    mcts.env = env

    # 设置其他必要的参数
    mcts.epsilon = 0.25  # Dirichlet噪声参数

    return env, model, mcts

# 3. 实现搜索和动作选择的主循环
def play_episode(env, mcts, model, render=False):
    state, _ = env.reset()
    done = False
    total_reward = 0

    while not done:
        # 将当前状态转换为one-hot编码
        state_tensor = F.one_hot(torch.tensor(state), num_classes=16).float()

        # 执行MCTS搜索
        root = mcts.search(state)

        # 获取动作概率
        action_probs = mcts.get_action_probs(root, temperature=1.0)

        # 选择动作
        action = np.argmax(action_probs)

        # 执行动作
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated


        total_reward += reward
        state = next_state

    return total_reward

# 4. 测试代码
def test_hybrid_mcts():
    env, model, mcts = setup_hybrid_mcts()

    # 运行几个测试episode
    for episode in range(5):
        reward = play_episode(env, mcts, model, render=True)
        print(f"Episode {episode}: Total Reward = {reward}")

    env.close()

if __name__ == "__main__":
    test_hybrid_mcts()

Episode 0: Total Reward = 0.0
Episode 1: Total Reward = 0
Episode 2: Total Reward = 0.0
Episode 3: Total Reward = 0
Episode 4: Total Reward = 0.0
