<a href="https://colab.research.google.com/github/ImaginationX4/HybridZero-/blob/main/alphazero_style_BFS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install gymnasium

Collecting gymnasium
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading gymnasium-1.0.0-py3-none-any.whl (958 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m958.1/958.1 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Installing collected packages: farama-notifications, gymnasium
Successfully installed farama-notifications-0.0.4 gymnasium-1.0.0


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
class Network(nn.Module):
  """Three-layer MLP mapping a flat input vector to ``output_size`` values.

  The module selects CUDA when available (CPU otherwise), stores that choice
  in ``self.device`` and keeps its own parameters on that device.

  Args:
    input_size: Number of input features (size of the last input dimension).
    output_size: Number of values produced per input vector.
  """

  def __init__(self, input_size=16, output_size=1):
    super(Network, self).__init__()
    self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    self.net = nn.Sequential(
        nn.Linear(input_size, 256),
        nn.ReLU(),
        nn.LayerNorm(256),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.LayerNorm(256),
        nn.Linear(256, output_size)
    )

    # forward() sends every input to self.device, so the weights have to live
    # there as well; otherwise a CUDA machine fails with a device mismatch.
    self.to(self.device)

  def forward(self, x):
    """Run the MLP on ``x``.

    Args:
      x: A tensor, or anything ``torch.as_tensor`` can convert (list, tuple,
        NumPy array); non-tensor inputs are converted to float32.

    Returns:
      The output tensor of the final linear layer, on ``self.device``.
    """
    if not isinstance(x, torch.Tensor):
      x = torch.as_tensor(x, dtype=torch.float32)

    x = x.to(self.device)
    return self.net(x)

  def save(self, filepath):
    """Write the model's ``state_dict`` to ``filepath``."""
    torch.save(self.state_dict(), filepath)
    print(f"Model saved to {filepath}")

  def load(self, filepath):
    """Load a ``state_dict`` from ``filepath`` onto ``self.device``.

    NOTE: ``torch.load`` unpickles the file; only load checkpoints that come
    from a trusted source.
    """
    self.load_state_dict(torch.load(filepath, map_location=self.device))
    print(f"Model loaded from {filepath}")

In [None]:
import gymnasium as gym
import numpy as np
class PathNode:
  """A node of the search tree: one state plus a link back to its predecessor."""

  def __init__(self, state, action=None, parent=None, is_terminated=False, is_hole=False):
    self.state = state                  # state represented by this node
    self.action = action                # action that led to this state
    self.parent = parent                # predecessor node (None for the root)
    self.children = []                  # successor nodes added via add_child
    self.is_terminated = is_terminated  # terminal-state flag (success or failure alike)
    self.is_hole = is_hole

  def get_path(self):
    """Return the ``(state, action)`` pairs from the root down to this node."""
    steps = []
    node = self
    # Walk up through the parent links, then flip the result into root-first order.
    while node:
      steps.append((node.state, node.action))
      node = node.parent
    steps.reverse()
    return steps

  def add_child(self, child_node):
    """Register ``child_node`` as a successor of this node."""
    self.children.append(child_node)

  def __lt__(self, other):
    # Nodes have no ordering of their own; always answering False lets tuples
    # that contain a node still be compared when their earlier fields tie.
    return False


In [None]:
from queue import PriorityQueue
import numpy as np

def action_to_chinese(actions):
  """Map integer action ids (0 left, 1 down, 2 right, 3 up) to Chinese labels.

  Returns a list with one label per entry of ``actions``; an id outside 0-3
  raises ``KeyError``.
  """
  labels = {0: "左", 1: "下", 2: "右", 3: "上"}
  named = []
  for action_id in actions:
    named.append(labels[action_id])
  return named
class PathFinder:
  """Best-first path search on a deterministic (non-slippery) FrozenLake grid.

  States are integer cell indices in row-major order
  (``state == row * grid_size + col``); actions are 0=left, 1=down, 2=right,
  3=up. The frontier is ordered by the output of a ``Network`` instance.

  Args:
    grid_size: Side length of the square grid. Must be a size for which
      FrozenLake provides a named map (4 or 8).

  Raises:
    ValueError: If ``grid_size`` has no matching FrozenLake map.
  """

  # Named maps available in FrozenLake-v1, keyed by side length.
  _MAP_NAMES = {4: "4x4", 8: "8x8"}

  def __init__(self, grid_size=4):
    if grid_size not in self._MAP_NAMES:
      raise ValueError(
          f"grid_size must be one of {sorted(self._MAP_NAMES)}, got {grid_size!r}")
    self.grid_size = grid_size
    # The goal is the last cell (bottom-right corner) of the grid.
    self.goal = grid_size * grid_size - 1
    self.env = gym.make('FrozenLake-v1', map_name=self._MAP_NAMES[grid_size],
                        is_slippery=False)
    # The network consumes a one-hot state vector (see net_value).
    self.net = Network(input_size=grid_size * grid_size)
    # Network.forward moves inputs to net.device; keep the weights there too.
    self.net.to(self.net.device)

  def get_valid_actions(self, state):
    """Return the actions that do not leave the grid from ``state``.

    The result is ordered left (0), down (1), right (2), up (3).
    """
    last = self.grid_size - 1
    row, col = divmod(state, self.grid_size)
    actions = []

    if col > 0: actions.append(0)     # can move left
    if row < last: actions.append(1)  # can move down
    if col < last: actions.append(2)  # can move right
    if row > 0: actions.append(3)     # can move up

    return actions

  def manhattan_distance(self, state, goal_state):
    """Return the Manhattan distance between two cell indices."""
    current_row, current_col = divmod(state, self.grid_size)
    goal_row, goal_col = divmod(goal_state, self.grid_size)
    return abs(current_row - goal_row) + abs(current_col - goal_col)

  def net_value(self, state):
    """Return the network's scalar output for ``state`` as a Python float."""
    state_one_hot = torch.zeros(self.grid_size * self.grid_size)
    state_one_hot[state] = 1
    # Pure inference: there is no reason to record an autograd graph here.
    with torch.no_grad():
      value = self.net(state_one_hot)
    return value.item()

  @staticmethod
  def _trace_back(node):
    """Return (set of states on the root-to-node path, number of steps in it)."""
    states = set()
    steps = -1  # the root itself contributes no step
    while node:
      states.add(node.state)
      steps += 1
      node = node.parent
    return states, steps

  def best_first_search(self, start_state, max_path_length=10):
    """Search from ``start_state`` towards ``self.goal``.

    Nodes are expanded in ascending priority order. The root's priority is its
    Manhattan distance to the goal; every other node uses ``net_value``. The
    search stops at the first goal node popped from the frontier.

    Args:
      start_state: Cell index to start from.
      max_path_length: Nodes this many steps from the root are not expanded.

    Returns:
      ``(all_paths, root)`` where ``all_paths`` maps ``'success'``, ``'hole'``
      and ``'timeout'`` to lists of paths (as produced by
      ``PathNode.get_path``) and ``root`` is the root of the search tree.
    """
    frontier = PriorityQueue()
    root = PathNode(start_state)
    # Entries are (priority, state, node); state and PathNode.__lt__ break ties.
    frontier.put((self.manhattan_distance(start_state, self.goal), start_state, root))
    all_paths = {
        'success': [],  # paths that reached the goal
        'hole': [],     # paths that fell into a hole
        'timeout': []   # paths cut off at the maximum length
    }

    while not frontier.empty():
      _, _, current_node = frontier.get()

      # Classify terminal nodes before applying the depth cutoff; otherwise a
      # goal reached in exactly max_path_length steps is reported as a timeout.
      if current_node.is_hole:
        all_paths['hole'].append(current_node.get_path())
        continue

      if current_node.state == self.goal:
        all_paths['success'].append(current_node.get_path())
        break

      path_states, path_length = self._trace_back(current_node)
      if path_length >= max_path_length:
        all_paths['timeout'].append(current_node.get_path())
        continue

      # Expand the current node.
      for action in self.get_valid_actions(current_node.state):
        # Simulate one step from current_node.state by overwriting the
        # environment's state attribute `s` after a reset.
        self.env.reset()
        self.env.unwrapped.s = current_node.state
        next_state, reward, terminated, _, _ = self.env.step(action)
        # An episode that ends without reward is treated as a hole.
        is_hole = terminated and reward == 0

        # Never revisit a state already on this path (no back-and-forth moves).
        if next_state in path_states:
          continue

        new_node = PathNode(
            state=next_state,
            action=action,
            parent=current_node,
            is_hole=is_hole,
        )
        current_node.add_child(new_node)
        # PriorityQueue pops the smallest entry first, so states with LOWER
        # network output are expanded first.
        # NOTE(review): if the network is meant as a "higher is better" value
        # estimate, this priority should be negated - confirm.
        frontier.put((self.net_value(next_state), next_state, new_node))

    return all_paths, root

In [None]:
import time

# Time one complete run: building the PathFinder, creating the environment and
# searching from the reset state. perf_counter() is a monotonic, high-resolution
# clock intended for measuring intervals, unlike the wall clock.
start_time = time.perf_counter()
a=PathFinder()
print(a.get_valid_actions(0))
env=gym.make('FrozenLake-v1', is_slippery=False)
state,_=env.reset()
all_paths,root=a.best_first_search(state)
end_time = time.perf_counter()
execution_time = end_time - start_time
print(f"执行时间: {execution_time:.4f} 秒")


[1, 2]
执行时间: 0.0226 秒


In [None]:
import time

# Second timing run with a freshly built PathFinder and environment.
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# intervals, unlike the wall clock.
start_time = time.perf_counter()
a=PathFinder()
print(a.get_valid_actions(0))
env=gym.make('FrozenLake-v1', is_slippery=False)
state,_=env.reset()
all_paths,root=a.best_first_search(state)
end_time = time.perf_counter()
execution_time = end_time - start_time
print(f"执行时间: {execution_time:.4f} 秒")

[1, 2]
执行时间: 0.0133 秒


In [None]:
# Show the first move of the first successful path.
# PathNode.get_path() yields (state, action) pairs, where `action` is the move
# that led into `state`; entry 0 is the root, whose action is None.
success_paths = all_paths['success']
if success_paths and len(success_paths[0]) > 1:
    state, action = success_paths[0][1]
    print(action, state)
else:
    # The search can finish without reaching the goal, so guard the indexing.
    print("No successful path with at least one step was found.")

1 2
