<a href="https://colab.research.google.com/github/ImaginationX4/HybridZero/blob/main/Enhanced_BFS_Frozen_lake.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install gymnasium

Collecting gymnasium
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading gymnasium-1.0.0-py3-none-any.whl (958 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m958.1/958.1 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Installing collected packages: farama-notifications, gymnasium
Successfully installed farama-notifications-0.0.4 gymnasium-1.0.0


### 1. Network

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class Network(nn.Module):
    """Value network mapping an encoded FrozenLake state to a scalar value estimate.

    Only a value head is built; ``num_actions`` is accepted for interface
    compatibility but is not used.
    """

    def __init__(self, input_size=64, hidden_size=128, num_actions=4):
        super(Network, self).__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Integer states are one-hot encoded to this width so they always match the
        # first Linear layer (previously hard-coded to 64 regardless of input_size).
        self.input_size = input_size

        # Shared trunk
        self.shared_layers = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size)
        )

        # Value head
        self.value_head = nn.Sequential(
            nn.Linear(hidden_size, 64),
            nn.ReLU(),
            nn.LayerNorm(64),
            nn.Linear(64, 1)
        )

    def forward(self, state):
        """Estimate the value of one state or of a batch of encoded states.

        Args:
            state: an integer state index (one-hot encoded here to ``input_size``
                classes), a NumPy array, or a torch tensor already holding encoded
                features whose last dimension is ``input_size``.

        Returns:
            Tensor of shape ``(..., 1)``: ``(1,)`` for a single state,
            ``(batch_size, 1)`` for a batch.

        Raises:
            TypeError: if ``state`` is none of the supported types.
        """
        # Build inputs on the same device as the weights so a model moved with .to() keeps working.
        param_device = next(self.parameters()).device

        if isinstance(state, (int, np.integer)):
            index = torch.tensor(int(state), device=param_device)
            state = F.one_hot(index, num_classes=self.input_size).float()
        elif isinstance(state, np.ndarray):
            state = torch.as_tensor(state, dtype=torch.float32, device=param_device)
        elif isinstance(state, torch.Tensor):
            state = state.to(param_device)
        else:
            raise TypeError(f"Unsupported input type: {type(state)}")

        shared_features = self.shared_layers(state)
        return self.value_head(shared_features)

    def get_value(self, state):
        """Return the value of a single state as a Python float, without gradient tracking.

        Shortcut used by the best-first searches. ``state`` must describe exactly one
        state because the result is converted with ``.item()``.
        """
        with torch.no_grad():
            return self.forward(state).item()

    def save(self, filepath):
        """Write the model's ``state_dict`` to ``filepath``."""
        torch.save(self.state_dict(), filepath)
        print(f"Model saved to {filepath}")

    def load(self, filepath):
        """Load a ``state_dict`` from ``filepath``, mapping tensors to ``self.device``."""
        self.load_state_dict(torch.load(filepath, map_location=self.device))
        print(f"Model loaded from {filepath}")


### 2. Tree Node

In [None]:
import math
import numpy as np
from dataclasses import dataclass, field
from typing import List, Tuple, Dict, Optional
import gymnasium as gym

import math
import numpy as np
import torch
import torch.nn.functional as F
from dataclasses import dataclass
from typing import List, Optional, Dict
from queue import PriorityQueue

@dataclass(eq=False)
class Node:
    """Search-tree node used by the v1 search code.

    ``eq=False`` keeps identity equality and hashing. Nodes link to their parent
    and children, so the default field-wise ``==`` would walk the tree whenever
    two queue entries tie; identity semantics also let nodes live in sets/dicts.
    """
    action_taken: Optional[int]           # action taken in ``parent`` to reach this node; None for a root
    visit_count: int = 0
    state: int = 0
    parent: Optional['Node'] = None
    children: Optional[Dict[int, 'Node']] = None  # action -> child; None is replaced with a fresh dict
    done: bool = False
    has_children: bool = False

    def __post_init__(self):
        # Give every instance its own dict (a shared mutable default would leak between nodes).
        if self.children is None:
            self.children = {}

    def __lt__(self, other):
        # Nodes carry no ordering. Always returning False lets (priority, node) tuples
        # with equal priorities be compared inside PriorityQueue without a TypeError.
        return False

### 3. BFS

In [None]:
def bfs_search(start_state, depth_limit=11, value_fn=None, max_expansions=11):
    """Best-first search for a goal path on the deterministic 4x4 FrozenLake map.

    Frontier nodes are popped in order of descending ``value_fn(state)``. Moves that
    leave the board are not tried, states already on the current root-to-node path
    are not revisited, and moves into a hole are never taken.

    Args:
        start_state: state index to search from.
        depth_limit: nodes whose root-to-node path holds more than this many states
            (root included) are not expanded.
        value_fn: callable ``state -> float`` used as the priority (higher is explored
            first). Defaults to ``1 / (Manhattan distance to the goal + 1)``. The
            previous version called ``self.evaluate_state`` at module scope, which
            raised NameError.
        max_expansions: maximum number of nodes popped from the queue (the original
            hard-coded budget allowed 11 pops).

    Returns:
        ``(actions, 1)`` when a goal path is found, otherwise ``(None, 0)``.
    """
    grid = 4
    goal_state = grid * grid - 1

    if value_fn is None:
        def value_fn(state):
            row, col = divmod(state, grid)
            return 1.0 / (abs(row - (grid - 1)) + abs(col - (grid - 1)) + 1)

    env = gym.make('FrozenLake-v1', is_slippery=False)
    try:
        queue = PriorityQueue()
        # Sequence numbers break priority ties so Node objects are never compared.
        seq = 0
        queue.put((0, seq, Node(action_taken=None, state=start_state)))

        for _ in range(max_expansions):
            if queue.empty():
                break
            _, _, current_node = queue.get()

            # States on the root-to-node path (root included); used to forbid cycles.
            path_states = []
            node = current_node
            while node is not None:
                path_states.append(node.state)
                node = node.parent

            if len(path_states) > depth_limit:
                continue

            if current_node.state == goal_state:
                actions = []
                node = current_node
                while node.parent is not None:
                    actions.append(node.action_taken)
                    node = node.parent
                return actions[::-1], 1

            row, col = divmod(current_node.state, grid)
            valid_actions = []
            if col > 0: valid_actions.append(0)         # left
            if row < grid - 1: valid_actions.append(1)  # down
            if col < grid - 1: valid_actions.append(2)  # right
            if row > 0: valid_actions.append(3)         # up

            visited = set(path_states)
            for action in valid_actions:
                env.reset()  # gymnasium requires reset() before step()
                env.unwrapped.s = current_node.state
                next_state, reward, terminated, _, _ = env.step(action)
                # A terminating step with no reward fell into a hole: never plan through it.
                if terminated and reward == 0:
                    continue
                if next_state in visited:
                    continue
                next_node = Node(action_taken=action, state=next_state, parent=current_node)
                current_node.children[action] = next_node
                seq += 1
                queue.put((-value_fn(next_state), seq, next_node))
    finally:
        env.close()

    return None, 0

### 4. Enhanced BFS

#### 4.1 Version 1.0

In [None]:
class EnhancedBFS:
    """Value-guided best-first search on the deterministic 8x8 FrozenLake map (v1.0).

    Each simulation runs a budgeted best-first search from the root whose frontier is
    ordered by ``network.get_value``. The action path it finds (or, when none is
    found, a single greedy step toward the goal by Manhattan distance) is replayed
    into a visit-count tree, and ``search`` returns the most visited root action.
    """

    GRID = 8                       # side length of the 8x8 map
    GOAL_STATE = GRID * GRID - 1   # bottom-right cell, (7, 7)
    MAX_GOAL_PATH_STATES = 16      # accept a goal path only if it holds at most this many states
    EXPANSION_BUDGET = 101         # queue pops per bfs_search (original: counter=100, loop while >= 0)

    def __init__(self, network, num_simulations: int = 100):
        self.network = network
        self.num_simulations = num_simulations
        self._env = None  # created lazily, then shared by every method instead of one env per call

    def _get_env(self):
        """Return the shared non-slippery 8x8 environment, creating it on first use."""
        if self._env is None:
            self._env = gym.make('FrozenLake-v1', map_name="8x8", is_slippery=False)
        return self._env

    def _step_from(self, state, action):
        """Apply ``action`` from ``state``; return ``(next_state, reward, terminated)``."""
        env = self._get_env()
        env.reset()  # gymnasium requires reset() before step()
        env.unwrapped.s = state
        next_state, reward, terminated, _, _ = env.step(action)
        return next_state, reward, terminated

    def _valid_actions(self, state):
        """Actions that stay on the board (0=left, 1=down, 2=right, 3=up)."""
        row, col = divmod(state, self.GRID)
        actions = []
        if col > 0: actions.append(0)
        if row < self.GRID - 1: actions.append(1)
        if col < self.GRID - 1: actions.append(2)
        if row > 0: actions.append(3)
        return actions

    @staticmethod
    def _path_states(node):
        """States from ``node`` back up to the root (inclusive), node first."""
        states = []
        while node is not None:
            states.append(node.state)
            node = node.parent
        return states

    @staticmethod
    def _path_actions(node):
        """Actions leading from the root to ``node``, in execution order."""
        actions = []
        while node.parent is not None:
            actions.append(node.action_taken)
            node = node.parent
        return actions[::-1]

    @classmethod
    def _manhattan_score(cls, state):
        """``1 / (Manhattan distance to the goal + 1)``: closer to the goal scores higher."""
        row, col = divmod(state, cls.GRID)
        goal_row, goal_col = divmod(cls.GOAL_STATE, cls.GRID)
        return 1.0 / (abs(row - goal_row) + abs(col - goal_col) + 1)

    def bfs_search(self, start_state):
        """Network-guided best-first search for a short goal path from ``start_state``.

        States already on the current path are not revisited, and moves into holes
        are skipped.

        Returns:
            ``(actions, 1)`` for a goal path holding at most ``MAX_GOAL_PATH_STATES``
            states, otherwise ``(None, 0)`` once the queue or the budget is exhausted.
        """
        queue = PriorityQueue()
        seq = 0  # tie-breaker so Node objects never need to be compared
        queue.put((0, seq, Node(action_taken=None, state=start_state)))

        for _ in range(self.EXPANSION_BUDGET):
            if queue.empty():
                break
            _, _, current_node = queue.get()
            path_states = self._path_states(current_node)

            if current_node.state == self.GOAL_STATE and len(path_states) <= self.MAX_GOAL_PATH_STATES:
                return self._path_actions(current_node), 1

            visited = set(path_states)
            for action in self._valid_actions(current_node.state):
                next_state, reward, terminated = self._step_from(current_node.state, action)
                if terminated and reward == 0:
                    continue  # hole: the episode would end here with no reward
                if next_state in visited:
                    continue
                next_node = Node(action_taken=action, state=next_state, parent=current_node)
                current_node.children[action] = next_node
                seq += 1
                queue.put((-self.network.get_value(next_state), seq, next_node))

        return None, 0

    def predict_next_state(self, current_state, action):
        """Return the state reached by taking ``action`` from ``current_state``.

        Uses the same non-slippery dynamics as the search. The previous version
        created a default (slippery) environment, so the predicted move could slide
        sideways.
        """
        next_state, _, _ = self._step_from(current_state, action)
        return next_state

    def search(self, root_state):
        """Run ``num_simulations`` searches and return the most visited root action.

        Returns None when nothing was expanded (e.g. ``num_simulations <= 0``).
        """
        root = Node(action_taken=None, state=root_state)

        for _ in range(self.num_simulations):
            best_path, _ = self.bfs_search(root.state)
            if best_path is None:
                # No path within budget: take one greedy step toward the goal (7, 7).
                scores = [self._manhattan_score(self.predict_next_state(root.state, a)) for a in range(4)]
                greedy_action = int(np.argmax(scores))
                print('I find nothing!!!!!!!!!!!!!!!', [greedy_action])
                self.expand_path(root, [greedy_action])
            else:
                print('I find something!!!!!!!!!!!!!!!', best_path)
                self.expand_path(root, best_path)

        # max() keeps the first child on ties, matching a strict ">" scan in insertion order.
        best_action, _ = max(root.children.items(),
                             key=lambda item: item[1].visit_count,
                             default=(None, None))
        return best_action

    def expand_path(self, start_node: Node, action_list: List[int]) -> Tuple[bool, Node]:
        """Replay ``action_list`` from ``start_node``, creating missing children.

        Every node reached along the way has its ``visit_count`` incremented.

        Returns:
            ``(True, final_node)``: the whole list is always applied; ``final_node`` is
            the last node reached (``start_node`` itself for an empty list).
        """
        current_node = start_node
        for action in action_list:
            if action in current_node.children:
                current_node = current_node.children[action]
            else:
                next_state, _, _ = self._step_from(current_node.state, action)
                new_node = Node(action_taken=action, state=next_state, parent=current_node)
                current_node.has_children = True
                current_node.children[action] = new_node
                current_node = new_node
            current_node.visit_count += 1
        return True, current_node

#### 4.2 Version 2.0 (simplified)

In [None]:
from typing import List, Tuple, Optional
from queue import PriorityQueue
import gymnasium as gym
import numpy as np
from dataclasses import dataclass, field
from functools import lru_cache

@dataclass
class Node:
    """Search-tree node used by the v2.0 search (one node per expanded state/action)."""
    state: int                                    # environment state index
    action_taken: Optional[int] = None            # action taken in ``parent`` to reach this node; None for a root
    parent: Optional['Node'] = None               # None for the search root
    children: dict = field(default_factory=dict)  # action -> child Node; a fresh dict per instance
    visit_count: int = 0                          # incremented each time a path is replayed through this node
    value: float = 0.0                            # not read or written by the v2.0 search shown here

class EnhancedBFS:
    """Heuristic best-first planner for deterministic FrozenLake (v2.0, simplified).

    ``bfs_search`` runs a best-first search ordered by a Manhattan-distance heuristic,
    never stepping into holes, and returns the action path to the goal. ``search``
    replays that path into a visit-count tree and returns the most visited root action.
    """

    def __init__(self, network, num_simulations: int = 100, env_size: int = 8):
        self.network = network  # stored for interface compatibility; this version does not query it
        self.num_simulations = num_simulations
        self.env_size = env_size
        self.env = self._create_env()
        self.goal_state = env_size * env_size - 1

    def _create_env(self) -> gym.Env:
        """Factory for the non-slippery (deterministic) FrozenLake environment."""
        return gym.make('FrozenLake-v1',
                        map_name=f"{self.env_size}x{self.env_size}",
                        is_slippery=False,
                        render_mode=None)

    @staticmethod
    def _get_valid_actions(state: int, env_size: int) -> List[int]:
        """Actions that keep the agent on the board (0=left, 1=down, 2=right, 3=up)."""
        row, col = state // env_size, state % env_size
        valid_actions = []
        if col > 0: valid_actions.append(0)             # left
        if row < env_size - 1: valid_actions.append(1)  # down
        if col < env_size - 1: valid_actions.append(2)  # right
        if row > 0: valid_actions.append(3)             # up
        return valid_actions

    @staticmethod
    def _get_manhattan_distance(state: int, goal_state: int, env_size: int) -> float:
        """Manhattan distance between two cells of an ``env_size`` x ``env_size`` grid."""
        current_row, current_col = state // env_size, state % env_size
        goal_row, goal_col = goal_state // env_size, goal_state % env_size
        return abs(current_row - goal_row) + abs(current_col - goal_col)

    def _calculate_heuristic(self, state: int) -> float:
        """Heuristic ``1 / (manhattan_distance + 1)``: 1.0 at the goal, smaller farther away.

        Deliberately not wrapped in ``functools.lru_cache``. Caching a method keys the
        cache on ``self`` and keeps every planner instance alive, while the
        computation is only a few integer operations.
        """
        manhattan_dist = self._get_manhattan_distance(state, self.goal_state, self.env_size)
        return 1.0 / (manhattan_dist + 1)

    def _is_hole(self, state: int) -> bool:
        """True if ``state`` is a hole ('H') on the environment's map."""
        row, col = divmod(state, self.env_size)
        return self.env.unwrapped.desc[row][col] == b'H'

    def bfs_search(self, start_state: int) -> Tuple[Optional[List[int]], int]:
        """Heuristic-guided best-first search from ``start_state`` to the goal.

        Returns:
            ``(actions, 1)`` when the goal is reachable without entering a hole,
            otherwise ``(None, 0)``.
        """
        visited = set()
        queue = PriorityQueue()
        # An increasing sequence number breaks priority ties in FIFO order, so results
        # don't depend on object addresses (as id() did).
        seq = 0
        queue.put((-self._calculate_heuristic(start_state), seq, Node(state=start_state)))

        while not queue.empty():
            _, _, current_node = queue.get()
            if current_node.state in visited:
                continue
            visited.add(current_node.state)

            if current_node.state == self.goal_state:
                return self._get_action_path(current_node), 1

            for action in self._get_valid_actions(current_node.state, self.env_size):
                next_state = self._predict_next_state(current_node.state, action)
                # Entering a hole ends the episode with no reward: never plan through one.
                if next_state in visited or self._is_hole(next_state):
                    continue
                seq += 1
                next_node = Node(state=next_state, action_taken=action, parent=current_node)
                queue.put((-self._calculate_heuristic(next_state), seq, next_node))

        return None, 0

    def search(self, root_state: int) -> int:
        """Return the action to take from ``root_state``.

        The environment and heuristic are deterministic, so every simulation would
        find the same path. The search therefore runs once, and the path is replayed
        ``num_simulations`` times into the visit-count tree.
        """
        root = Node(state=root_state)
        path, _ = self.bfs_search(root_state)
        if not path:
            # No hole-free route, or already at the goal: take one greedy heuristic step.
            path = [self._get_best_heuristic_action(root_state)]

        for _ in range(self.num_simulations):
            self._expand_path(root, path)

        if not root.children:  # num_simulations <= 0: nothing was expanded
            return path[0]
        return self._select_best_action(root)

    def _get_best_heuristic_action(self, state: int) -> int:
        """Greedy one-step choice: the valid action whose successor scores highest.

        Successors that are holes score -1.0, below any heuristic value (which lies
        in (0, 1]), so a hole is chosen only when every move leads into one. Ties go
        to the larger action index.
        """
        action_values = []
        for action in self._get_valid_actions(state, self.env_size):
            next_state = self._predict_next_state(state, action)
            value = -1.0 if self._is_hole(next_state) else self._calculate_heuristic(next_state)
            action_values.append((value, action))
        return max(action_values)[1]

    def _predict_next_state(self, state: int, action: int) -> int:
        """Return the state reached by ``action`` from ``state`` (deterministic env)."""
        self.env.reset()  # gymnasium requires reset() before step()
        self.env.unwrapped.s = state
        next_state, _, _, _, _ = self.env.step(action)
        return next_state

    def _expand_path(self, start_node: Node, action_list: List[int]) -> None:
        """Replay ``action_list`` from ``start_node``, creating missing children and counting visits."""
        current_node = start_node
        for action in action_list:
            if action not in current_node.children:
                next_state = self._predict_next_state(current_node.state, action)
                current_node.children[action] = Node(
                    state=next_state,
                    action_taken=action,
                    parent=current_node
                )
            current_node = current_node.children[action]
            current_node.visit_count += 1

    @staticmethod
    def _get_action_path(node: Node) -> List[int]:
        """Actions leading from the root to ``node``, in execution order."""
        path = []
        current = node
        while current.parent:
            path.append(current.action_taken)
            current = current.parent
        return list(reversed(path))

    @staticmethod
    def _select_best_action(root: Node) -> int:
        """Root action whose child has the highest visit count (raises ValueError if none)."""
        return max(root.children.items(),
                   key=lambda x: x[1].visit_count)[0]

#### 4.3 Version 3.0 (find the action in the BFS tree)

In [None]:
from typing import List, Tuple, Optional
from queue import PriorityQueue
import gymnasium as gym
import numpy as np
from dataclasses import dataclass, field
from functools import lru_cache

@dataclass
class Node:
  """Search-tree node used by the v3.0 search."""
  state: int                                    # environment state index
  action_taken: Optional[int] = None            # action taken in ``parent`` to reach this node; None for a root
  parent: Optional['Node'] = None               # None for the search root
  children: dict = field(default_factory=dict)  # action -> child Node; a fresh dict per instance
  visit_count: int = 0                          # incremented when bfs_search pops and expands this node
  value: float = 0.0                            # not read or written by the v3.0 search

class EnhancedBFS:
    """Heuristic best-first planner that chooses its move from the search tree (v3.0).

    ``bfs_search`` grows a tree from the root with a small expansion budget, ordered
    by a Manhattan-distance heuristic and never stepping into holes. ``search`` scores
    each root child's subtree with ``_evaluate_subtree`` and returns the best action.
    """

    def __init__(self, network, num_simulations: int = 100, env_size: int = 8,
                 expansion_budget: int = 10):
        self.network = network                    # stored for interface compatibility; not queried here
        self.num_simulations = num_simulations    # stored for interface compatibility; not read by v3.0
        self.env_size = env_size
        self.expansion_budget = expansion_budget  # max queue pops per bfs_search (previously hard-coded 10)
        self.env = self._create_env()
        self.goal_state = env_size * env_size - 1

    def _create_env(self) -> gym.Env:
        """Factory for the non-slippery (deterministic) FrozenLake environment."""
        return gym.make('FrozenLake-v1',
                        map_name=f"{self.env_size}x{self.env_size}",
                        is_slippery=False,
                        render_mode=None)

    @staticmethod
    def _get_valid_actions(state: int, env_size: int) -> List[int]:
        """Actions that keep the agent on the board (0=left, 1=down, 2=right, 3=up)."""
        row, col = state // env_size, state % env_size
        valid_actions = []
        if col > 0: valid_actions.append(0)             # left
        if row < env_size - 1: valid_actions.append(1)  # down
        if col < env_size - 1: valid_actions.append(2)  # right
        if row > 0: valid_actions.append(3)             # up
        return valid_actions

    @staticmethod
    def _get_manhattan_distance(state: int, goal_state: int, env_size: int) -> float:
        """Manhattan distance between two cells of an ``env_size`` x ``env_size`` grid."""
        current_row, current_col = state // env_size, state % env_size
        goal_row, goal_col = goal_state // env_size, goal_state % env_size
        return abs(current_row - goal_row) + abs(current_col - goal_col)

    def _calculate_heuristic(self, state: int) -> float:
        """Heuristic ``1 / (manhattan_distance + 1)``: 1.0 at the goal, smaller farther away.

        Not memoized with ``functools.lru_cache``: caching a method keys the cache on
        ``self`` and keeps planner instances alive, and the computation is trivial.
        """
        manhattan_dist = self._get_manhattan_distance(state, self.goal_state, self.env_size)
        return 1.0 / (manhattan_dist + 1)

    def _is_hole(self, state: int) -> bool:
        """True if ``state`` is a hole ('H') on the environment's map."""
        row, col = divmod(state, self.env_size)
        return self.env.unwrapped.desc[row][col] == b'H'

    def _predict_next_state(self, state: int, action: int) -> int:
        """Return the state reached by ``action`` from ``state`` (deterministic env)."""
        self.env.reset()  # gymnasium requires reset() before step()
        self.env.unwrapped.s = state
        next_state, _, _, _, _ = self.env.step(action)
        return next_state

    @staticmethod
    def _get_action_path(node: Node) -> List[int]:
        """Actions leading from the root to ``node``, in execution order."""
        path = []
        current = node
        while current.parent:
            path.append(current.action_taken)
            current = current.parent
        return list(reversed(path))

    def bfs_search(self, start_state: int) -> Tuple[Optional[List[int]], int, Node]:
        """Grow a best-first search tree from ``start_state``.

        At most ``expansion_budget`` nodes are popped. Moves into holes are never
        added to the tree.

        Returns:
            ``(actions_to_goal, 1, root)`` if the goal was popped, else ``(None, 0, root)``.
        """
        visited = set()
        queue = PriorityQueue()
        root_node = Node(state=start_state)
        seq = 0  # FIFO tie-breaker: reproducible, unlike id()
        queue.put((-self._calculate_heuristic(start_state), seq, root_node))
        goal_node = None

        for _ in range(self.expansion_budget):
            if queue.empty():
                break
            _, _, current_node = queue.get()

            if current_node.state in visited:
                continue
            visited.add(current_node.state)
            current_node.visit_count += 1  # count how often this node was expanded

            if current_node.state == self.goal_state:
                goal_node = current_node
                break

            for action in self._get_valid_actions(current_node.state, self.env_size):
                next_state = self._predict_next_state(current_node.state, action)
                # Entering a hole ends the episode with no reward: keep it out of the tree.
                if next_state in visited or self._is_hole(next_state):
                    continue
                next_node = Node(state=next_state, action_taken=action, parent=current_node)
                current_node.children[action] = next_node  # build the tree
                seq += 1
                queue.put((-self._calculate_heuristic(next_state), seq, next_node))

        if goal_node is not None:
            return self._get_action_path(goal_node), 1, root_node
        return None, 0, root_node

    def get_best_action_from_tree(self, root_node: Node) -> int:
        """Pick the root action whose subtree scores highest.

        When the root has no children (e.g. it is the goal, or every move leads into
        a hole), falls back to a greedy one-step heuristic choice instead of returning None.
        """
        best_action = None
        best_value = float('-inf')

        for action, child in root_node.children.items():
            value = self._evaluate_subtree(child)
            if value > best_value:
                best_value = value
                best_action = action

        if best_action is None:
            return self._get_best_heuristic_action(root_node.state)
        return best_action

    def _get_best_heuristic_action(self, state: int) -> int:
        """Greedy one-step choice: the valid action whose successor scores highest.

        Hole successors score -1.0, below any heuristic value in (0, 1]. Ties go to
        the larger action index.
        """
        action_values = []
        for action in self._get_valid_actions(state, self.env_size):
            next_state = self._predict_next_state(state, action)
            value = -1.0 if self._is_hole(next_state) else self._calculate_heuristic(next_state)
            action_values.append((value, action))
        return max(action_values)[1]

    def _evaluate_subtree(self, node: Node) -> float:
        """Score a subtree: +inf if it is the goal, else visits + heuristic + 0.5 * best child score."""
        if node.state == self.goal_state:
            return float('inf')

        # Combine several signals; the weights can be tuned.
        visit_value = node.visit_count  # how often this direction was expanded
        heuristic_value = self._calculate_heuristic(node.state)
        children_value = max(self._evaluate_subtree(child) for child in node.children.values()) if node.children else 0
        return visit_value + heuristic_value + 0.5 * children_value

    def search(self, root_state: int) -> int:
        """Build one search tree from ``root_state`` and return the chosen action."""
        _, _, root_node = self.bfs_search(root_state)
        return self.get_best_action_from_tree(root_node)

In [None]:
# Smoke test: plan one move from state 45 of the 8x8 map with the v3.0 planner.
# `env` is not used by the planner; EnhancedBFS creates its own environment in __init__.
env = gym.make('FrozenLake-v1',map_name="8x8", is_slippery=False)#map_name="8x8",
# Network comes from section 1; the v3.0 EnhancedBFS stores it but never queries it.
value_net = Network()
# num_simulations is stored but not read by the v3.0 search.
bfs = EnhancedBFS(value_net, num_simulations=10)
action = bfs.search(45)  # 0=left, 1=down, 2=right, 3=up
print(action)

1


#### 4.4 Version 4.0 (NN with imitation learning)

In [None]:
import torch
import torch.nn as nn

class ValueNetwork(nn.Module):
    """MLP mapping a one-hot encoded FrozenLake state to a value in (0, 1)."""

    def __init__(self, env_size):
        super().__init__()
        self.state_size = env_size * env_size  # length of the one-hot encoding

        self.network = nn.Sequential(
            nn.Linear(self.state_size, 256),
            nn.ReLU(),
            nn.LayerNorm(256),  # normalize the hidden activations
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, state):
        """Score one state index (output shape ``(1,)``) or a batch of indices (``(batch, 1)``)."""
        return self.network(self.state_to_tensor(state))

    def state_to_tensor(self, state):
        """One-hot encode ``state`` on the module's device.

        An integer gives a vector of length ``state_size``. A list, array or tensor of
        indices gives a ``(len(state), state_size)`` matrix, so several states can be
        scored in one forward pass.
        """
        device = next(self.parameters()).device
        indices = torch.as_tensor(state, dtype=torch.long, device=device)
        if indices.dim() == 0:
            x = torch.zeros(self.state_size, device=device)
            x[indices] = 1.0
            return x
        x = torch.zeros(*indices.shape, self.state_size, device=device)
        return x.scatter_(-1, indices.unsqueeze(-1), 1.0)
class ExpertDataCollector:
    """Builds ``(state, target_value)`` pairs for imitation-training a value network.

    Targets are ``1 / (steps_to_goal + 1)``: 1.0 at the goal and shrinking with
    distance. Hole states get no entry.
    """

    def __init__(self, env_size):
        self.env_size = env_size
        self.expert_data = []  # (state, value) tuples appended by the collect_* methods
        # The environment supplies the map (desc) and the transition function.
        self.env = gym.make('FrozenLake-v1',
                            map_name=f"{env_size}x{env_size}",
                            is_slippery=False)

    def _is_valid_state(self, state: int) -> bool:
        """Return True unless ``state`` is a hole ('H') on the map."""
        desc = self.env.unwrapped.desc.flatten()
        return desc[state] != b'H'

    def collect_from_bfs(self):
        """Append exact shortest-path targets for every non-hole state.

        Runs a breadth-first search from each state over non-hole cells, counting
        the steps to the goal. States from which the goal is unreachable get no entry.
        """
        # Local import: the notebook otherwise only imports deque in a later cell.
        from collections import deque

        goal = self.env_size ** 2 - 1
        for state in range(self.env_size ** 2):
            if not self._is_valid_state(state):
                continue
            # gymnasium's order-enforcing wrapper refuses step() before reset().
            self.env.reset()
            queue = deque([(state, 0)])
            visited = {state}  # the start itself is already explored
            while queue:
                s, steps = queue.popleft()
                if s == goal:
                    self.expert_data.append((state, 1.0 / (steps + 1)))
                    break
                for action in range(4):
                    self.env.unwrapped.s = s
                    next_state, _, _, _, _ = self.env.step(action)
                    if next_state not in visited and self._is_valid_state(next_state):
                        visited.add(next_state)
                        queue.append((next_state, steps + 1))

    def collect_from_manhattan(self):
        """Append Manhattan-distance targets ``1 / (distance + 1)`` for every non-hole state."""
        goal_state = self.env_size * self.env_size - 1
        for state in range(self.env_size * self.env_size):
            if self._is_valid_state(state):
                steps = self._manhattan_distance(state, goal_state)
                value = 1.0 / (steps + 1)  # normalized value
                self.expert_data.append((state, value))

    def _manhattan_distance(self, state, goal_state):
        """Manhattan distance between two cells of the ``env_size`` x ``env_size`` grid."""
        state_row, state_col = state // self.env_size, state % self.env_size
        goal_row, goal_col = goal_state // self.env_size, goal_state % self.env_size
        return abs(state_row - goal_row) + abs(state_col - goal_col)
class ImitationBFS:
    """Best-first planner guided by an imitation-trained value network (v4.0).

    On construction, ``ValueNetwork`` is fitted to the ``1 / (manhattan + 1)``
    targets from ``ExpertDataCollector.collect_from_manhattan``. Its scaled output
    then orders a budgeted best-first search that never steps into holes, and the
    move is chosen by scoring the subtrees under the root.
    """

    def __init__(self, env_size: int = 8, train_epochs: int = 40, expansion_budget: int = 50):
        self.env_size = env_size
        self.env = gym.make('FrozenLake-v1',
                            map_name=f"{env_size}x{env_size}",
                            is_slippery=False)  # deterministic environment
        self.goal_state = env_size * env_size - 1
        self.train_epochs = train_epochs          # previously hard-coded 40
        self.expansion_budget = expansion_budget  # max queue pops per bfs_search (previously 50)

        # Network plus the source of its imitation targets
        self.value_net = ValueNetwork(env_size)
        self.expert_collector = ExpertDataCollector(env_size)

        # Train immediately so the planner is usable right after construction.
        self._train_from_expert()

    def _get_valid_actions(self, state: int) -> List[int]:
        """Actions that keep the agent on the board (0=left, 1=down, 2=right, 3=up)."""
        valid_actions = []
        row = state // self.env_size
        col = state % self.env_size
        if col > 0: valid_actions.append(0)                  # left
        if row < self.env_size - 1: valid_actions.append(1)  # down
        if col < self.env_size - 1: valid_actions.append(2)  # right
        if row > 0: valid_actions.append(3)                  # up
        return valid_actions

    def _is_hole(self, state: int) -> bool:
        """True if ``state`` is a hole ('H') on the environment's map."""
        row, col = divmod(state, self.env_size)
        return self.env.unwrapped.desc[row][col] == b'H'

    def _step(self, state: int, action: int) -> int:
        """Return the state reached by ``action`` from ``state``.

        Resets first: gymnasium's order-enforcing wrapper rejects step() before reset().
        """
        self.env.reset()
        self.env.unwrapped.s = state
        next_state, _, _, _, _ = self.env.step(action)
        return next_state

    def _train_from_expert(self):
        """Fit ``value_net`` to the Manhattan targets with per-sample Adam updates."""
        self.expert_collector.collect_from_manhattan()

        optimizer = torch.optim.Adam(self.value_net.parameters())
        loss_fn = nn.MSELoss()  # built once instead of once per sample
        for _ in range(self.train_epochs):
            for state, target_value in self.expert_collector.expert_data:
                optimizer.zero_grad()
                predicted_value = self.value_net(state)
                loss = loss_fn(predicted_value, torch.tensor([target_value]))
                loss.backward()
                optimizer.step()

    def _calculate_heuristic(self, state: int) -> float:
        """Network-predicted value of ``state``, scaled by 0.7 (no gradient tracking).

        The 0.7 factor is kept from the original weighting, where a Manhattan term
        (weight 0.3) was also considered but left disabled.
        """
        with torch.no_grad():
            success_prob = self.value_net(state).item()
        return success_prob * 0.7

    @staticmethod
    def _get_action_path(node: Node) -> List[int]:
        """Actions leading from the root to ``node``, in execution order."""
        path = []
        current = node
        while current.parent:
            path.append(current.action_taken)
            current = current.parent
        return list(reversed(path))

    def bfs_search(self, start_state: int) -> Tuple[Optional[List[int]], int, Node]:
        """Grow a network-guided best-first tree from ``start_state``.

        At most ``expansion_budget`` nodes are popped, and moves into holes are never
        added to the tree.

        Returns:
            ``(actions_to_goal, 1, root)`` if the goal was popped, else ``(None, 0, root)``.
        """
        visited = set()
        queue = PriorityQueue()
        root_node = Node(state=start_state)
        seq = 0  # FIFO tie-breaker: reproducible, unlike id()
        queue.put((-self._calculate_heuristic(start_state), seq, root_node))
        goal_node = None

        for _ in range(self.expansion_budget):
            if queue.empty():
                break
            _, _, current_node = queue.get()

            if current_node.state in visited:
                continue
            visited.add(current_node.state)
            current_node.visit_count += 1

            if current_node.state == self.goal_state:
                goal_node = current_node
                break

            for action in self._get_valid_actions(current_node.state):
                next_state = self._step(current_node.state, action)
                # Entering a hole ends the episode with no reward: keep it out of the tree.
                if next_state in visited or self._is_hole(next_state):
                    continue
                next_node = Node(state=next_state, action_taken=action, parent=current_node)
                current_node.children[action] = next_node
                seq += 1
                queue.put((-self._calculate_heuristic(next_state), seq, next_node))

        if goal_node is not None:
            return self._get_action_path(goal_node), 1, root_node
        return None, 0, root_node

    def get_best_action_from_tree(self, root_node: Node) -> int:
        """Pick the root action whose subtree scores highest, else a greedy heuristic action."""
        def _evaluate_subtree(node: Node) -> float:
            # +inf at the goal; otherwise a weighted mix of visits, heuristic and best child.
            if node.state == self.goal_state:
                return float('inf')
            visit_value = node.visit_count
            heuristic_value = self._calculate_heuristic(node.state)
            children_value = max(_evaluate_subtree(child) for child in node.children.values()) if node.children else 0
            return visit_value * 0.6 + heuristic_value * 0.3 + children_value * 0.1

        best_action = None
        best_value = float('-inf')
        for action, child in root_node.children.items():
            value = _evaluate_subtree(child)
            if value > best_value:
                best_value = value
                best_action = action

        # Fallback when the tree has no children (e.g. the root is the goal).
        if best_action is None:
            valid_actions = self._get_valid_actions(root_node.state)
            return self._get_best_heuristic_action(root_node.state, valid_actions)
        return best_action

    def _get_best_heuristic_action(self, state: int, valid_actions: List[int]) -> int:
        """Greedy one-step choice among ``valid_actions`` by the heuristic of the successor.

        Hole successors score -1.0, below any network-based heuristic (which lies in
        (0, 0.7)), so a hole is chosen only when every move leads into one.
        """
        best_action = None
        best_value = -float('inf')
        for action in valid_actions:
            next_state = self._step(state, action)
            value = -1.0 if self._is_hole(next_state) else self._calculate_heuristic(next_state)
            if value > best_value:
                best_value = value
                best_action = action
        return best_action if best_action is not None else valid_actions[0]

    def search(self, root_state: int) -> int:
        """Build one search tree from ``root_state`` and return the chosen action."""
        _, _, root_node = self.bfs_search(root_state)
        return self.get_best_action_from_tree(root_node)

In [None]:
# Smoke test: plan one move from the start state (0) of the 8x8 map with the v4.0 planner.
# `env` and `value_net` are not used by ImitationBFS, which builds and trains its own
# environment and ValueNetwork in __init__.
env = gym.make('FrozenLake-v1',map_name="8x8", is_slippery=False)#map_name="8x8",
value_net = ValueNetwork(8)
bfs = ImitationBFS()  # network training runs inside the constructor
action = bfs.search(0)  # 0=left, 1=down, 2=right, 3=up
print(action)

#### 4.5 Version 5.0 (NN with double-network learning)

In [None]:
from typing import List, Tuple, Optional
from queue import PriorityQueue
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
from dataclasses import dataclass, field
from functools import lru_cache

# 神经网络模型定义
class HeuristicNetwork(nn.Module):
    """Embed a discrete state index and score it with a small MLP.

    The final Sigmoid keeps the heuristic in (0, 1). ``forward`` expects a
    LongTensor of state indices and returns shape ``(*state.shape, 1)``.
    """

    def __init__(self, env_size):
        super(HeuristicNetwork, self).__init__()
        num_states = env_size ** 2
        embed_dim = 32
        self.embedding = nn.Embedding(num_states, embed_dim)
        layers = [
            nn.Linear(embed_dim, 64), nn.ReLU(),
            nn.Linear(64, 32), nn.ReLU(),
            nn.Linear(32, 1), nn.Sigmoid(),  # squash the score into (0, 1)
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, state):
        return self.fc(self.embedding(state))

@dataclass
class Node:
    """Search-tree node used by the v5.0 neural-guided search."""
    state: int                                    # environment state index
    action_taken: Optional[int] = None            # action taken in ``parent`` to reach this node; None for a root
    parent: Optional['Node'] = None               # None for the search root
    children: dict = field(default_factory=dict)  # action -> child Node; a fresh dict per instance
    visit_count: int = 0                          # NOTE(review): update site not visible in this chunk — confirm usage
    value: float = 0.0                            # NOTE(review): usage not visible in this chunk — confirm

class NeuralEnhancedBFS:
    """Best-first tree search on FrozenLake guided by a learned heuristic.

    A ``HeuristicNetwork`` scores states; the search always expands the
    highest-scoring frontier node first. While searching, (state, target)
    pairs are pushed into a replay buffer and the online network is trained
    after every expansion. Training targets come from a separate target
    network that is periodically synced with the online network
    ("double network" setup).

    Args:
        env_size: side length of the square FrozenLake map (e.g. 8 for "8x8").
        num_simulations: currently unused; kept for interface compatibility.
        buffer_size: maximum number of (state, target) pairs kept for replay.
        batch_size: number of pairs sampled per training step.
    """

    def __init__(self, env_size: int = 8, num_simulations: int = 100,
                 buffer_size: int = 10000, batch_size: int = 32):
        self.env_size = env_size
        self.env = self._create_env()
        # Last row-major index, i.e. the bottom-right cell.
        self.goal_state = env_size**2 - 1

        # Online network (trained) and target network (source of bootstrap targets).
        self.model = HeuristicNetwork(env_size)
        self.target_model = HeuristicNetwork(env_size)
        # Start both networks from identical weights; otherwise the targets used
        # before the first periodic sync come from an unrelated random network.
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        self.loss_fn = nn.MSELoss()

        # Experience replay buffer of (state, target) pairs.
        self.replay_buffer = deque(maxlen=buffer_size)
        self.batch_size = batch_size

        # Copy online weights into the target network every N training steps.
        self.target_update_interval = 15
        self.train_step_counter = 0

    def _create_env(self) -> gym.Env:
        """Create the deterministic (non-slippery) FrozenLake env used for look-ahead."""
        return gym.make('FrozenLake-v1',
                        map_name=f"{self.env_size}x{self.env_size}",
                        is_slippery=False,
                        render_mode=None)

    def _get_valid_actions(self, state: int, env_size: int) -> List[int]:
        """Return the actions that do not push the agent into the map border.

        Action codes: 0 = left, 1 = down, 2 = right, 3 = up. ``state`` is a
        row-major cell index on an ``env_size`` x ``env_size`` grid.
        """
        row, col = divmod(state, env_size)
        valid_actions = []
        if col > 0:
            valid_actions.append(0)  # left
        if row < env_size - 1:
            valid_actions.append(1)  # down
        if col < env_size - 1:
            valid_actions.append(2)  # right
        if row > 0:
            valid_actions.append(3)  # up
        return valid_actions

    def _get_action_path(self, node: Node) -> List[int]:
        """Return the actions leading from the tree root to ``node`` (root first)."""
        path = []
        current = node
        while current.parent is not None:
            path.append(current.action_taken)
            current = current.parent
        path.reverse()
        return path

    def _calculate_heuristic(self, state: int) -> float:
        """Score ``state`` with the online network (no gradient tracking)."""
        with torch.no_grad():
            state_tensor = torch.LongTensor([state])
            return self.model(state_tensor).item()

    def _update_network(self, states, targets):
        """Run one MSE gradient step of the online network toward ``targets``.

        Every ``target_update_interval`` steps the online weights are copied
        into the target network.
        """
        states = torch.LongTensor(states)
        targets = torch.FloatTensor(targets)

        # squeeze(-1) drops only the trailing output dim: (batch, 1) -> (batch,).
        # A bare squeeze() would also drop the batch dim when batch_size == 1,
        # producing a 0-d prediction that no longer matches the (1,) targets.
        predictions = self.model(states).squeeze(-1)
        loss = self.loss_fn(predictions, targets)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Periodically sync the target network.
        self.train_step_counter += 1
        if self.train_step_counter % self.target_update_interval == 0:
            self.target_model.load_state_dict(self.model.state_dict())

    def _remember(self, state, target):
        """Store a (state, target) training pair in the replay buffer."""
        self.replay_buffer.append((state, target))

    def _replay(self):
        """Train on a random minibatch once the buffer holds ``batch_size`` pairs.

        Does nothing while the buffer is too small or if ``batch_size`` is not
        positive (sampling an empty batch would break the unpacking below).
        """
        if self.batch_size <= 0 or len(self.replay_buffer) < self.batch_size:
            return

        batch = np.random.choice(len(self.replay_buffer), self.batch_size, replace=False)
        states, targets = zip(*[self.replay_buffer[i] for i in batch])

        self._update_network(states, targets)

    def _get_bootstrap_target(self, state):
        """Score ``state`` with the target network, for use as a training target."""
        with torch.no_grad():
            state_tensor = torch.LongTensor([state])
            return self.target_model(state_tensor).item()

    def bfs_search(self, start_state: int) -> Tuple[Optional[List[int]], int, Node]:
        """Best-first search from ``start_state`` toward the goal.

        Despite the name this is not breadth-first: the frontier is a priority
        queue ordered by the online network's score (highest first), and each
        state is expanded at most once. Training pairs are recorded during the
        search and one replay step is run after every expansion.

        Returns:
            ``(path, found, root)`` where ``path`` is the action list from the
            start to the goal (or None), ``found`` is 1 if the goal was reached
            else 0, and ``root`` is the search-tree root node.
        """
        visited = set()
        queue = PriorityQueue()
        root_node = Node(state=start_state)
        # PriorityQueue pops the smallest item, so scores are negated; id() breaks
        # ties so Node objects themselves are never compared.
        queue.put((-self._calculate_heuristic(start_state), id(root_node), root_node))
        goal_node = None

        while not queue.empty():
            _, _, current_node = queue.get()

            if current_node.state in visited:
                continue
            visited.add(current_node.state)
            current_node.visit_count += 1

            # Bootstrap: pull the parent's score toward the target network's
            # score of this child.
            if current_node.parent is not None:
                target = self._get_bootstrap_target(current_node.state)
                self._remember(current_node.parent.state, target)

            if current_node.state == self.goal_state:
                goal_node = current_node
                # Goal states get the maximal target.
                self._remember(current_node.state, 1.0)
                break

            for action in self._get_valid_actions(current_node.state, self.env_size):
                # Teleport the look-ahead env to the current state and try the action.
                self.env.reset()
                self.env.unwrapped.s = current_node.state
                next_state, _, terminated, _, _ = self.env.step(action)

                if terminated and next_state != self.goal_state:
                    # Terminal non-goal state (a hole in FrozenLake): record it
                    # as a failure and prune it from the tree.
                    self._remember(next_state, 0.0)
                    continue

                if next_state not in visited:
                    next_node = Node(
                        state=next_state,
                        action_taken=action,
                        parent=current_node,
                        value=self._calculate_heuristic(next_state),
                    )
                    current_node.children[action] = next_node
                    queue.put((-next_node.value, id(next_node), next_node))

            # One experience-replay training step per expansion.
            self._replay()

        if goal_node is None:
            return None, 0, root_node

        # Reward every node on the goal path (the root itself is not updated).
        current = goal_node
        while current.parent is not None:
            current.value += 1.0
            current = current.parent

        return self._get_action_path(goal_node), 1, root_node

    def get_best_action_from_tree(self, root_node: Node) -> int:
        """Pick the root action whose subtree scores highest (``_evaluate_subtree``).

        If the root has no children (e.g. every move from it was pruned as a
        hole), fall back to a one-step greedy heuristic choice instead of
        returning None, which would break callers that execute the action.
        """
        best_action = None
        best_value = float('-inf')

        for action, child in root_node.children.items():
            value = self._evaluate_subtree(child)
            if value > best_value:
                best_value = value
                best_action = action

        if best_action is None:
            best_action = self._get_best_heuristic_action(root_node.state)
        return best_action

    def _get_best_heuristic_action(self, state: int,
                                   valid_actions: Optional[List[int]] = None) -> Optional[int]:
        """Greedy one-step fallback: the action whose successor scores highest.

        Successors are obtained with ``_simulate_step`` and scored by the
        online network; ties go to the earliest action. Returns None only when
        ``state`` has no valid action at all.
        """
        if valid_actions is None:
            valid_actions = self._get_valid_actions(state, self.env_size)
        if not valid_actions:
            return None
        return max(valid_actions,
                   key=lambda a: self._calculate_heuristic(self._simulate_step(state, a)))

    def _evaluate_subtree(self, node: Node) -> float:
        """Score a subtree.

        Returns +inf if the subtree contains the goal state; otherwise
        ``visit_count + value + 0.5 * (best child score, or 0 for a leaf)``.
        """
        if node.state == self.goal_state:
            return float('inf')

        visit_value = node.visit_count   # incremented once when bfs_search expands the node
        heuristic_value = node.value     # network score, plus 1.0 if on a found goal path
        children_value = max(
            (self._evaluate_subtree(child) for child in node.children.values()),
            default=0,
        )
        # Ad-hoc weights; tune as needed.
        return visit_value + heuristic_value + 0.5 * children_value

    def search(self, root_state: int) -> int:
        """Plan from ``root_state`` and return the first action to take."""
        _, _, root_node = self.bfs_search(root_state)
        return self.get_best_action_from_tree(root_node)

    def _simulate_step(self, state, action):
        """Return the successor of ``state`` under ``action`` using the look-ahead env."""
        self.env.reset()
        self.env.unwrapped.s = state
        next_state, _, _, _, _ = self.env.step(action)
        return next_state

#### 4.5 Test

In [None]:
import gymnasium as gym
import torch
import numpy as np
from collections import defaultdict
from queue import PriorityQueue
from typing import List, Tuple, Dict, Optional, Set
from dataclasses import dataclass

def test_enhanced_bfs(num_episodes: int = 1, max_steps: int = 100):
    """Play FrozenLake with NeuralEnhancedBFS as a re-planning agent and report rewards.

    At every step the planner re-runs its search from the current state and
    the chosen action is executed in a separate evaluation environment whose
    map size matches the planner's.

    Args:
        num_episodes: number of episodes to play.
        max_steps: per-episode step budget.
    """
    bfs = NeuralEnhancedBFS()
    # Evaluation env; derive the map from the planner so both see the same grid.
    env = gym.make('FrozenLake-v1',
                   map_name=f"{bfs.env_size}x{bfs.env_size}",
                   is_slippery=False)

    total_reward = 0

    print("\n开始测试Enhanced BFS...")

    try:
        for episode in range(num_episodes):
            state, _ = env.reset()
            done = False
            episode_reward = 0
            steps = 0

            print(f"\n回合 {episode + 1}:")
            print(f"起始状态: {state}")

            while not done and steps < max_steps:
                action = bfs.search(state)
                if action is None:
                    # The planner produced no action; env.step(None) would crash.
                    print(f"Steps {steps}: 在状态 {state} 没有可选动作，提前结束")
                    break
                print(f"Steps {steps}: 在状态 {state} 选择动作 {action}")

                state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                episode_reward += reward
                steps += 1

                print(f"-> 新状态: {state}, 奖励: {reward}")

                if done:
                    if reward > 0:
                        print("成功到达目标！")
                    else:
                        print("失败（掉入陷阱或超时）")

            if not done and steps >= max_steps:
                # Previously the loop ended silently when the step budget ran out.
                print(f"达到最大步数 {max_steps}，回合未完成")

            total_reward += episode_reward
            print(f"回合 {episode + 1} 结束 - 总步数: {steps}, 总奖励: {episode_reward}")
    finally:
        env.close()

    average_reward = total_reward / num_episodes if num_episodes > 0 else 0.0
    print(f"\n测试完成 - 平均奖励: {average_reward}")

# In a notebook kernel __name__ is "__main__", so running this cell executes the test immediately.
if __name__ == "__main__":
    test_enhanced_bfs()


开始测试Enhanced BFS...

回合 1:
起始状态: 0
(visit_value heuristic_value children_value
0 0.4753377139568329 0
(visit_value heuristic_value children_value
0 0.4753377139568329 0
(visit_value heuristic_value children_value
0 0.4917571544647217 0
(visit_value heuristic_value children_value
0 0.4926496744155884 0
(visit_value heuristic_value children_value
0 0.4926496744155884 0
(visit_value heuristic_value children_value
0 0.498965859413147 0
(visit_value heuristic_value children_value
0 0.4926496744155884 0
(visit_value heuristic_value children_value
0 0.4917571544647217 0
(visit_value heuristic_value children_value
1 0.5086429119110107 0.4926496744155884
(visit_value heuristic_value children_value
1 0.5093880891799927 1.754967749118805
(visit_value heuristic_value children_value
0 0.4926496744155884 0
(visit_value heuristic_value children_value
1 0.5162313580513 2.386871963739395
(visit_value heuristic_value children_value
1 0.5082707405090332 2.7096673399209976
(visit_value heuristic_value ch