<a href="https://colab.research.google.com/github/ImaginationX4/HybridZero/blob/main/THTS_BFS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install gymnasium



In [None]:
import gymnasium as gym
import numpy as np
import torch
import math
from queue import PriorityQueue
from collections import deque
from dataclasses import dataclass, field
from typing import Optional

# Value network for CartPole: maps a continuous state vector to a scalar value.
class ValueNetwork(torch.nn.Module):
    """MLP (Linear -> LayerNorm -> ReLU, twice, then Linear) producing one value per state.

    Args:
        state_dim: length of the input state vector (defaults to 4, CartPole's observation size).
        hidden_size: width of both hidden layers.
    """

    def __init__(self, state_dim: int = 4, hidden_size: int = 256):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(state_dim, hidden_size),
            torch.nn.LayerNorm(hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.LayerNorm(hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, 1),
        )
        # Small uniform init of the output weights, presumably to keep the
        # initial value estimates small.
        torch.nn.init.uniform_(self.layers[-1].weight, -0.1, 0.1)

    def forward(self, x):
        """Map states of shape ``(..., state_dim)`` to values of shape ``(...)``."""
        return self.layers(x).squeeze(-1)

# Search-tree node for CartPole (carries a unique id).
@dataclass(eq=False)
class Node:
    """A node of the search tree.

    ``eq=False`` keeps the default identity-based ``__eq__``/``__hash__``, so
    nodes can be put in sets/dicts. A default ``@dataclass`` would set
    ``__hash__ = None`` and compare every field, including ndarray states and
    the recursive ``parent``/``children`` links.
    """
    state: np.ndarray
    action_taken: Optional[int] = None
    # parent/children are excluded from repr to avoid dumping the whole tree.
    parent: Optional['Node'] = field(default=None, repr=False)
    children: dict = field(default_factory=dict, repr=False)
    value: float = 0.0
    depth: int = 0
    done: bool = False
    reward: float = 0.0
    visit_count: int = 1
    _id: int = field(init=False)  # unique id (the object's id())

    def __post_init__(self):
        self._id = id(self)
        # Depth is derived from the parent whenever one is given.
        if self.parent is not None:
            self.depth = self.parent.depth + 1

class BFS:
    """UCT-style tree search for CartPole-v1 guided by a learned value network.

    Despite the name, the search is not breadth-first. Each simulation:
    descends the tree with a UCT rule, expands one untried action, stores a
    one-step TD target for the value network, then backs visit counts and
    values up the selected path.

    Args:
        num_simulations: simulations run per call to ``bfs_search``.
        buffer_size: capacity of the ``(state, target)`` replay buffer.
        batch_size: minibatch size; training starts once the buffer holds at
            least this many samples.
        gamma: discount factor for the TD target.
        exploration_weight: coefficient of the UCT exploration term.
    """

    def __init__(self,
                 num_simulations: int = 10,
                 buffer_size: int = 10000,
                 batch_size: int = 64,
                 gamma: float = 0.99,
                 exploration_weight: float = 1.5):
        self.env = gym.make('CartPole-v1')
        self.num_simulations = num_simulations
        self.gamma = gamma
        self.exploration_weight = exploration_weight

        # Search tree; the root is reused only while it matches the queried state.
        self.root: Optional[Node] = None
        self.current_tree: Optional[Node] = None

        # Online value network, a periodically synced target copy, and training state.
        self.model = ValueNetwork()
        self.target_model = ValueNetwork()
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        self.loss_fn = torch.nn.MSELoss()
        self.replay_buffer = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.train_step_counter = 0
        self.target_update_interval = 20

    def bfs_search(self, current_state: np.ndarray) -> tuple[int, Node]:
        """Run ``num_simulations`` simulations from ``current_state``.

        The existing root is kept only if its state matches ``current_state``
        within ``_is_state_equal``'s tolerance; otherwise a fresh root is built.

        Returns:
            ``(best_action, root)``. ``best_action`` is the root child with the
            highest visit count; ``root`` is the search-tree root.
        """
        if self.root is None or not self._is_state_equal(self.root.state, current_state):
            self.root = Node(current_state)
        self.current_tree = self.root

        # Scratch env for one-step simulations. render_mode is left at its
        # default: 'no_render' is not a valid mode and triggered a warning.
        temp_env = gym.make('CartPole-v1')
        try:
            for _ in range(self.num_simulations):
                self._simulate(temp_env)
        finally:
            temp_env.close()

        best_action = max(self.root.children,
                          key=lambda a: self.root.children[a].visit_count)
        return best_action, self.root

    def _simulate(self, temp_env: gym.Env):
        """Run one selection -> expansion -> target collection -> backup pass."""
        node = self.current_tree
        # The root is part of the path so its visit count grows. _select_child
        # uses log(parent visits), which stayed at log(1) ~ 0 for the root
        # before, disabling exploration exactly where the action is chosen.
        path = [node]
        num_actions = self.env.action_space.n

        # Phase 1: selection -- descend through fully expanded, non-terminal nodes.
        while not node.done and len(node.children) == num_actions:
            action = self._select_child(node)
            node = node.children[action]
            path.append(node)

        # Phase 2: expansion -- add one random untried action as a new child.
        if not node.done and len(node.children) < num_actions:
            untried = [a for a in range(num_actions) if a not in node.children]
            action = int(np.random.choice(untried))

            # Put the scratch env into the node's state and take a single step.
            temp_env.reset()
            temp_env.unwrapped.state = node.state
            next_state, reward, terminated, _, _ = temp_env.step(action)

            with torch.no_grad():
                next_value = self.model(torch.FloatTensor(next_state)).item()

            child = Node(
                state=next_state,
                action_taken=action,
                parent=node,
                value=next_value,
                reward=reward,
                done=terminated,
            )
            node.children[action] = child
            node = child
            # The new leaf is deliberately not appended to `path`.

        # Phase 3: store a one-step TD target for the leaf's parent state.
        with torch.no_grad():
            target_value = self.target_model(torch.FloatTensor(node.state)).item()
        target = node.reward + (1 - node.done) * self.gamma * target_value
        if node.parent is not None:
            self._remember(node.parent.state, target)

        # Phase 4: backpropagation along the selected path.
        self._backpropagate(path, node.value)

        # Experience replay once enough samples are buffered.
        if len(self.replay_buffer) >= self.batch_size:
            self._replay()

    def _select_child(self, node: Node) -> int:
        """Return the child action maximizing ``value + c * sqrt(ln(N) / n)``."""
        total_visits = node.visit_count
        best_score = -float('inf')
        best_action = -1

        for action, child in node.children.items():
            exploit = child.value
            explore = self.exploration_weight * math.sqrt(
                math.log(total_visits + 1e-7) / (child.visit_count + 1e-7))
            score = exploit + explore
            if score > best_score:
                best_score = score
                best_action = action

        return best_action

    def _backpropagate1(self, path: list[Node], value: float):
        """Alternative backup (unused): incremental mean of a discounted return along ``path``."""
        for node in reversed(path):
            node.visit_count += 1
            node.value += (value - node.value) / node.visit_count  # incremental mean
            value = node.reward + self.gamma * value * (1 - node.done)

    def _backpropagate(self, path: list[Node], value: float):
        """Increment visit counts along ``path`` (leaf to root).

        Each node that has children gets its value set to the visit-weighted
        mean of its children's values. ``value`` is currently unused.
        """
        for node in reversed(path):
            node.visit_count += 1
            if node.children:
                total_visits = sum(child.visit_count for child in node.children.values())
                weighted_value = sum(
                    child.value * child.visit_count for child in node.children.values()
                ) / total_visits
                node.value = weighted_value

    def _is_state_equal(self, state1: np.ndarray, state2: np.ndarray) -> bool:
        """Compare continuous CartPole states with an absolute tolerance of 1e-3."""
        return np.allclose(state1, state2, atol=1e-3)

    def _remember(self, state: np.ndarray, target_value: float):
        """Append a ``(state, target_value)`` training pair to the replay buffer."""
        self.replay_buffer.append((state, target_value))

    def _replay(self):
        """Take one MSE gradient step on a random minibatch and sync the target net periodically."""
        batch = np.random.choice(len(self.replay_buffer), self.batch_size, replace=False)
        states, target_values = zip(*[self.replay_buffer[i] for i in batch])

        states = torch.FloatTensor(np.array(states))
        target_values = torch.FloatTensor(target_values)

        self.optimizer.zero_grad()
        pred_values = self.model(states)
        loss = self.loss_fn(pred_values, target_values)
        loss.backward()
        self.optimizer.step()

        self.train_step_counter += 1
        if self.train_step_counter % self.target_update_interval == 0:
            self.target_model.load_state_dict(self.model.state_dict())

# Training loop: play CartPole with the BFS search choosing every action.
def train_agent(num_episodes: int = 20, num_simulations: int = 100):
    """Play CartPole-v1 episodes, letting ``BFS`` pick each action.

    The value network is trained online inside ``BFS`` during search.

    Args:
        num_episodes: number of episodes to play.
        num_simulations: search simulations per action.

    Returns:
        List of undiscounted episode returns, one per episode.
    """
    env = gym.make('CartPole-v1')
    agent = BFS(num_simulations=num_simulations)
    episode_rewards = []

    try:
        for episode in range(num_episodes):
            state = env.reset()[0]
            total_reward = 0
            done = False

            while not done:
                action, _ = agent.bfs_search(state)
                next_state, reward, terminated, truncated, _ = env.step(action)
                # End the episode on truncation as well as termination; the
                # original ignored `truncated` and kept stepping past the
                # environment's time limit.
                done = terminated or truncated
                total_reward += reward
                state = next_state
                print(f"Episode {episode+1}, Reward: {total_reward}")

            print(f"Episode {episode+1}, Reward: {total_reward}")
            episode_rewards.append(total_reward)
    finally:
        env.close()

    return episode_rewards


if __name__ == "__main__":
    train_agent()

  logger.warn(


Episode 1, Reward: 1.0
Episode 1, Reward: 2.0
Episode 1, Reward: 3.0
Episode 1, Reward: 4.0
Episode 1, Reward: 5.0
Episode 1, Reward: 6.0
Episode 1, Reward: 7.0
Episode 1, Reward: 8.0
Episode 1, Reward: 9.0
Episode 1, Reward: 10.0
Episode 1, Reward: 11.0
Episode 1, Reward: 12.0
Episode 1, Reward: 13.0
Episode 1, Reward: 14.0
Episode 1, Reward: 15.0
Episode 1, Reward: 16.0
Episode 1, Reward: 17.0
Episode 1, Reward: 18.0
Episode 1, Reward: 19.0
Episode 1, Reward: 20.0
Episode 1, Reward: 21.0
Episode 1, Reward: 22.0
Episode 1, Reward: 23.0
Episode 1, Reward: 24.0
Episode 1, Reward: 25.0
Episode 1, Reward: 26.0
Episode 1, Reward: 27.0
Episode 1, Reward: 28.0
Episode 1, Reward: 29.0
Episode 1, Reward: 30.0
Episode 1, Reward: 31.0
Episode 1, Reward: 32.0
Episode 1, Reward: 33.0
Episode 1, Reward: 34.0
Episode 1, Reward: 35.0
Episode 1, Reward: 36.0
Episode 1, Reward: 37.0
Episode 1, Reward: 38.0
Episode 1, Reward: 39.0
Episode 1, Reward: 40.0
Episode 1, Reward: 41.0
Episode 1, Reward: 42.0
E

KeyboardInterrupt: 

### 2. THTS-BFS: FrozenLake


In [None]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from collections import deque
from dataclasses import dataclass, field
from typing import Optional
class ValueNetwork(nn.Module):
    """Value network that consumes the raw state index as a single scalar feature.

    No one-hot encoding is applied: the index is cast to float and fed to a
    ``Linear(1, hidden_size)`` layer. ``state_size`` is stored but not used
    to build the layers.
    """

    def __init__(self, state_size=16, hidden_size=128):
        super().__init__()
        self.state_size = state_size
        self.hidden_size = hidden_size

        width = hidden_size
        self.layers = nn.Sequential(
            nn.Linear(1, width),   # scalar state index in
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 1),   # scalar value out
        )

        # Small uniform init for the output layer's weights.
        head = self.layers[-1]
        nn.init.uniform_(head.weight, -0.1, 0.1)

    def forward(self, x):
        """Map index tensors of shape ``(...)`` to values of shape ``(...)``."""
        features = x.float().unsqueeze(-1)   # append the single feature axis
        values = self.layers(features)
        return values.squeeze(-1)
class ValueNetwork1(torch.nn.Module):
    """MLP value network over one-hot encoded discrete states.

    Args:
        state_size: number of discrete states; also the one-hot width
            (16 for a 4x4 FrozenLake map).
        hidden_size: width of the two hidden layers.
    """

    def __init__(self, state_size=16, hidden_size=128):
        super().__init__()
        self.state_size = state_size
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(state_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, 1),
        )
        torch.nn.init.uniform_(self.layers[-1].weight, -0.1, 0.1)

    def forward(self, x):
        """Map state-index tensors of shape ``(...)`` to values of shape ``(...)``.

        Indices are cast to ``long`` (so float inputs are accepted) and one-hot
        encoded before the MLP.
        """
        x = F.one_hot(x.long(), num_classes=self.state_size).float()
        return self.layers(x).squeeze(-1)
@dataclass(eq=False)
class Node:
    """Search-tree node for the cost-minimizing FrozenLake search.

    ``eq=False`` keeps identity-based equality and hashing. The default
    dataclass ``__eq__`` compares every field recursively through
    ``parent``/``children``, which makes membership tests such as
    ``child in path`` slow and able to match distinct nodes.
    """
    state: int  # FrozenLake state index (0-15 on the 4x4 map)
    action_taken: Optional[int] = None
    # parent/children are excluded from repr to avoid dumping the whole tree.
    parent: Optional['Node'] = field(default=None, repr=False)
    children: dict = field(default_factory=dict, repr=False)
    cost: float = 0.0      # cost estimate (lower is better)
    done: bool = False
    visit_count: int = 1

    # Bookkeeping used only by UCTFrozenLake._backpropagate1.
    shortest_path_length: float = float('inf')
    best_action: Optional[int] = None
    is_safe_path: bool = False
    reached_goal: bool = False

    def __post_init__(self):
        self._id = id(self)

class UCTFrozenLake:
    """Cost-minimizing UCT search on deterministic (non-slippery) 4x4 FrozenLake.

    Each simulation:
    - descends the tree with ``_select_child``;
    - expands every untried action of the reached leaf, using ``self.model``
      for the children's initial costs;
    - evaluates the leaf (goal, hole, or bootstrapped from ``self.target_model``);
    - stores that evaluation as a training target for the leaf's parent state;
    - refreshes node costs along the selected path.

    Args:
        num_simulations: simulations per call to ``search``.
        buffer_size: replay buffer capacity.
        batch_size: minibatch size; training starts once the buffer holds at
            least this many samples.
        gamma: discount applied to the bootstrapped leaf cost.
        exploration_weight: coefficient of the UCT exploration term.
    """

    def __init__(self,
                 num_simulations: int = 1,
                 buffer_size: int = 1000,
                 batch_size: int = 64,
                 gamma: float = 0.99,
                 exploration_weight: float = 1.5):
        self.num_simulations = num_simulations
        self.gamma = gamma
        self.exploration_weight = exploration_weight
        self.env = gym.make('FrozenLake-v1', is_slippery=False)

        # Online cost model plus a periodically synced target copy.
        self.model = ValueNetwork(state_size=16)
        self.target_model = ValueNetwork(state_size=16)
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        self.loss_fn = torch.nn.MSELoss()
        self.replay_buffer = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.train_step_counter = 0
        self.target_update_interval = 20

        # Search tree (rebuilt from scratch on every call to ``search``).
        self.root: Optional[Node] = None

    def search(self, current_state: int):
        """Search from ``current_state`` and return ``(best_action, root)``.

        Prints a state analysis first (see ``observe_state``). ``best_action``
        is the root child with the highest visit count.
        """
        self.observe_state(current_state)
        self.root = Node(state=current_state)
        temp_env = gym.make('FrozenLake-v1', is_slippery=False)
        try:
            for _ in range(self.num_simulations):
                self._simulate(temp_env, self.root)
        finally:
            # Release the scratch env; previously one leaked per call.
            temp_env.close()
        # Pick the most visited root action.
        best_action = max(self.root.children.keys(),
                          key=lambda a: self.root.children[a].visit_count)
        return best_action, self.root

    def _simulate(self, temp_env: gym.Env, root: Node):
        """Run one selection -> expansion -> evaluation -> backup pass from ``root``."""
        # The root is part of the path so its visit count grows. _select_child
        # reads it as N in log(N), which otherwise stays at log(1) ~ 0 and
        # disables the exploration term at the root.
        path = [root]
        node = root
        done = root.done

        # Selection: follow UCT choices down to a leaf or a terminal node.
        while not done and node.children:
            action = self._select_child(node)
            node = node.children[action]
            path.append(node)
            done = node.done

        # Expansion: create every missing child of the leaf at once.
        if not done:
            for action in range(self.env.action_space.n):
                if action in node.children:
                    continue
                temp_env.reset()
                temp_env.unwrapped.s = node.state  # put the scratch env into this state
                next_state, reward, terminated, truncated, _ = temp_env.step(action)

                # Initial cost prediction; no_grad avoids building an unused autograd graph.
                with torch.no_grad():
                    state_tensor = torch.tensor([next_state], dtype=torch.float32)
                    initial_cost = self.model(state_tensor).item()

                node.children[action] = Node(
                    state=next_state,
                    action_taken=action,
                    parent=node,
                    cost=initial_cost,
                    done=terminated or truncated,
                )

        # Evaluation of the node reached by selection (state 15 is the goal).
        if node.done and node.state == 15:
            final_cost = 0.0
        elif node.done:
            # Hole. NOTE(review): a 1e6 regression target dwarfs every other
            # target in the MSE loss; _simulate1 uses 10.0 for the same case.
            final_cost = 1000000.0
        else:
            with torch.no_grad():
                state_tensor = torch.tensor([node.state], dtype=torch.float32)
                final_cost = 1 + self.gamma * self.target_model(state_tensor).item()

        # Backup. Note that _backpropagate ignores final_cost and re-derives
        # costs from the children.
        self._backpropagate(path, final_cost)

        # Store the evaluation as a training target for the parent state.
        if node.parent is not None:
            self._remember(node.parent.state, final_cost)

        # Experience replay once enough samples are buffered.
        if len(self.replay_buffer) >= self.batch_size:
            self._replay()

    def _simulate1(self, temp_env: gym.Env, root: Node):
        """Alternative simulation (not called by ``search``).

        Expands the leaf's untried actions one at a time in random order, uses
        10.0 as the hole cost, and runs four replay steps per simulation.
        """
        path = [root]
        node = root
        done = root.done

        # Phase 1: selection
        while not done and node.children:
            action = self._select_child(node)
            child = node.children[action]
            if any(child is visited for visited in path):
                # Cycle guard. The original `continue` re-ran selection on the
                # same node and could loop forever; stop descending instead.
                break
            node = child
            path.append(node)
            done = node.done

        # Phase 2: expansion
        while not node.done and len(node.children) < self.env.action_space.n:
            action = np.random.choice([a for a in range(4) if a not in node.children])
            temp_env.reset()
            temp_env.unwrapped.s = node.state  # put the scratch env into this state
            next_state, reward, terminated, truncated, _ = temp_env.step(action)

            # Initial cost prediction from the model.
            with torch.no_grad():
                state_tensor = torch.tensor([next_state], dtype=torch.float32)
                initial_cost = self.model(state_tensor).item()

            node.children[action] = Node(
                state=next_state,
                action_taken=action,
                parent=node,
                cost=initial_cost,
                done=terminated or truncated,
            )

        # Phase 3: evaluation & backup
        if node.done and node.state == 15:
            final_cost = 0.0
        elif node.done:
            final_cost = 10.0
        else:
            with torch.no_grad():
                state_tensor = torch.tensor([node.state], dtype=torch.float32)
                final_cost = 1 + self.gamma * self.target_model(state_tensor).item()

        self._backpropagate(path, final_cost)

        # Store the evaluation as a training target for the parent state.
        if node.parent is not None:
            self._remember(node.parent.state, final_cost)

        # Experience replay
        if len(self.replay_buffer) >= self.batch_size:
            for _ in range(4):
                self._replay()

    def _get_valid_actions(self, state, env_size=4):
        """Return the actions that do not bump into the grid border from ``state``."""
        row, col = state // env_size, state % env_size
        valid_actions = []
        if col > 0: valid_actions.append(0)             # left
        if row < env_size - 1: valid_actions.append(1)  # down
        if col < env_size - 1: valid_actions.append(2)  # right
        if row > 0: valid_actions.append(3)             # up
        return valid_actions

    def _select_child(self, node: Node) -> int:
        """Pick the valid action minimizing ``cost - c * sqrt(ln(N) / n)``.

        While the node has fewer than 25 visits, a uniformly random child is
        returned instead with probability 0.2.
        """
        total_visits = node.visit_count
        best_score = float('inf')  # searching for the minimum
        best_action = -1

        for action in self._get_valid_actions(node.state):
            child = node.children[action]
            exploit = child.cost
            explore = self.exploration_weight * math.sqrt(
                math.log(total_visits + 1e-7) / (child.visit_count + 1e-7))
            score = exploit - explore  # lower cost is better, so subtract the bonus
            if score < best_score:
                best_score = score
                best_action = action

        if total_visits < 25:
            if np.random.random() < 0.2:
                return np.random.choice(list(node.children.keys()))
        return best_action

    def _backpropagate(self, path: list[Node], value: float):
        """Increment visit counts along ``path`` (leaf to root).

        Each node that has children gets its cost set to 0.8 times the
        visit-weighted mean of its children's costs. ``value`` is currently unused.
        """
        for node in reversed(path):
            node.visit_count += 1
            if node.children:
                total_visits = sum(child.visit_count for child in node.children.values())
                weighted_cost = sum(
                    child.cost * child.visit_count for child in node.children.values()
                ) / total_visits
                node.cost = 0.8 * weighted_cost

    def _backpropagate1(self, path: list[Node], terminated: bool, reached_goal: bool):
        """Alternative backup (unused) tracking the shortest known distance to the goal."""
        if not path:
            return
        # Case 1: the path ends at the goal.
        if reached_goal:
            for i, node in enumerate(reversed(path)):
                # node is path[-(i+1)], which is i steps from the goal at path[-1].
                # (The original used path_length - i, inverting root and leaf.)
                current_length = i

                # Only update if we found a shorter path.
                if current_length < node.shortest_path_length:
                    node.shortest_path_length = current_length
                    node.is_safe_path = True
                    node.reached_goal = True

                    # Record the action leading to the next node on the path.
                    # The leaf (i == 0) has no successor.
                    if i > 0:
                        node.best_action = path[-i].action_taken

        # Case 2: the path ends in a hole.
        elif terminated:
            for node in path:
                node.is_safe_path = False

        # Case 3: regular state -- inherit from the best child.
        else:
            node = path[-1]
            if node.children:
                best_child = min(node.children.values(),
                                 key=lambda n: n.shortest_path_length)

                # Update if the child has a path to the goal.
                if best_child.reached_goal:
                    new_length = best_child.shortest_path_length + 1
                    if new_length < node.shortest_path_length:
                        node.shortest_path_length = new_length
                        node.best_action = best_child.action_taken
                        node.is_safe_path = best_child.is_safe_path
                        node.reached_goal = True

    def observe_uct_decision(self, node: Node, action: int, total_visits: int):
        """Print the UCT score breakdown for ``action`` at ``node`` and return the score."""
        child = node.children[action]
        exploit = child.cost
        explore = self.exploration_weight * math.sqrt(math.log(total_visits + 1e-7) / (child.visit_count + 1e-7))
        score = exploit - explore

        print(f"\n=== UCT Decision Analysis for State {node.state} Action {action} ===")
        print(f"Child State: {child.state}")
        print(f"Visit Count: {child.visit_count}")
        print(f"Cost: {child.cost:.4f}")
        print(f"Exploit Term: {exploit:.4f}")
        print(f"Explore Term: {explore:.4f}")
        print(f"Final UCT Score: {score:.4f}")
        return score

    def observe_state(self, state: int):
        """Print ``state``, its Manhattan distance to the goal, and each action's border-clipped next state."""
        print(f"\n=== State Analysis ===")
        print(f"Current State: {state}")
        # Manhattan distance to the goal cell (3, 3).
        row, col = state // 4, state % 4
        manhattan_dist = abs(row - 3) + abs(col - 3)
        print(f"Manhattan Distance to Goal: {manhattan_dist}")
        # Next state of each action on the 4x4 grid (holes are not considered).
        possible_actions = []
        for action in range(4):
            if action == 0:  # left
                next_col = max(0, col - 1)
                next_state = row * 4 + next_col
            elif action == 1:  # down
                next_row = min(3, row + 1)
                next_state = next_row * 4 + col
            elif action == 2:  # right
                next_col = min(3, col + 1)
                next_state = row * 4 + next_col
            elif action == 3:  # up
                next_row = max(0, row - 1)
                next_state = next_row * 4 + col
            possible_actions.append((action, next_state))
        print("Possible Actions:")
        for action, next_state in possible_actions:
            print(f"Action {action} -> State {next_state}")

    def _remember(self, state: int, target_cost: float):
        """Append a ``(state, target_cost)`` training pair to the replay buffer."""
        self.replay_buffer.append((state, target_cost))

    def _replay(self):
        """Take one MSE gradient step on a random minibatch and sync the target net periodically."""
        batch = np.random.choice(len(self.replay_buffer), self.batch_size, replace=False)
        states, target_costs = zip(*[self.replay_buffer[i] for i in batch])

        # Keep states as shape (B,); the model adds its own feature axis. The
        # previous extra unsqueeze produced (B, 1) predictions against (B,)
        # targets, which MSELoss broadcast to a (B, B) comparison.
        states = torch.tensor(states, dtype=torch.float32)
        target_costs = torch.tensor(target_costs, dtype=torch.float32)

        self.optimizer.zero_grad()
        pred_costs = self.model(states)
        loss = self.loss_fn(pred_costs, target_costs)
        loss.backward()
        self.optimizer.step()

        self.train_step_counter += 1
        if self.train_step_counter % self.target_update_interval == 0:
            self.target_model.load_state_dict(self.model.state_dict())

def train_frozenlake(num_episodes: int = 1, num_simulations: int = 50):
    """Play non-slippery FrozenLake episodes with the cost-minimizing UCT agent.

    Each step prints the new state and the root children's visit counts.
    After each episode the total cost is printed: the number of steps taken,
    reset to 0 if the goal (state 15) was reached.

    Args:
        num_episodes: number of episodes to play.
        num_simulations: search simulations per action.

    Returns:
        List of per-episode total costs.
    """
    env = gym.make('FrozenLake-v1', is_slippery=False)
    agent = UCTFrozenLake(num_simulations=num_simulations)
    episode_costs = []

    try:
        for episode in range(num_episodes):
            state = env.reset()[0]
            total_cost = 0
            done = False

            while not done:
                action, node = agent.search(state)
                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                total_cost += 1  # each step costs 1
                state = next_state
                print(f"Episode {episode+1}, state: {next_state}")
                print('children',[children.visit_count for children in node.children.values()])

                if terminated and next_state == 15:
                    total_cost = 0  # reaching the goal resets the total cost to 0

            # total_cost was previously computed and then discarded.
            print(f"Episode {episode+1}, total cost: {total_cost}")
            episode_costs.append(total_cost)
    finally:
        env.close()

    return episode_costs


if __name__ == "__main__":
    train_frozenlake()


=== State Analysis ===
Current State: 0
Manhattan Distance to Goal: 6
Possible Actions:
Action 0 -> State 0
Action 1 -> State 4
Action 2 -> State 1
Action 3 -> State 0
Episode 1, state: 1
children [2, 3, 40, 8]

=== State Analysis ===
Current State: 1
Manhattan Distance to Goal: 5
Possible Actions:
Action 0 -> State 0
Action 1 -> State 5
Action 2 -> State 2
Action 3 -> State 1
Episode 1, state: 5
children [17, 31, 3, 2]


In [None]:


def _backpropagate(self, path: list[Node], value: float):
    """Add ``value`` to every node's cost on ``path`` and bump its visit count.

    Appears intended as a drop-in replacement for a ``UCTFrozenLake._backpropagate``
    method (hence the unused ``self``). It accumulates raw cost instead of
    averaging over children. Nodes are updated independently, so the
    traversal order does not affect the result.
    """
    for visited in path:
        visited.visit_count = visited.visit_count + 1
        visited.cost = visited.cost + value  # accumulate cost directly, no averaging

In [None]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from collections import deque
from dataclasses import dataclass, field
from typing import Optional
class ValueNetwork(torch.nn.Module):
    """MLP value network over one-hot encoded discrete states.

    Args:
        state_size: number of discrete states; also the one-hot width
            (16 for a 4x4 FrozenLake map).
        hidden_size: width of the two hidden layers.
    """

    def __init__(self, state_size=16, hidden_size=128):
        super().__init__()
        self.state_size = state_size
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(state_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, 1),
        )
        torch.nn.init.uniform_(self.layers[-1].weight, -0.1, 0.1)

    def forward(self, x):
        """Map state-index tensors of shape ``(...)`` to values of shape ``(...)``.

        Indices are cast to ``long`` (so float inputs are accepted) and one-hot
        encoded before the MLP.
        """
        x = F.one_hot(x.long(), num_classes=self.state_size).float()
        return self.layers(x).squeeze(-1)
@dataclass(eq=False)
class Node:
    """Search-tree node for the value-maximizing FrozenLake search.

    ``eq=False`` keeps identity-based equality and hashing. The default
    dataclass ``__eq__`` would compare every field recursively through
    ``parent``/``children`` and would make nodes unhashable.
    """
    state: int  # FrozenLake state index (0-15 on the 4x4 map)
    action_taken: Optional[int] = None
    # parent/children are excluded from repr to avoid dumping the whole tree.
    parent: Optional['Node'] = field(default=None, repr=False)
    children: dict = field(default_factory=dict, repr=False)
    value: float = 0.0
    done: bool = False
    visit_count: int = 1

    def __post_init__(self):
        self._id = id(self)

class UCTFrozenLake:
    """Value-maximizing UCT search on deterministic (non-slippery) 4x4 FrozenLake.

    Each simulation:
    - descends the tree with ``_select_child``;
    - expands all untried actions of the reached leaf (in random order),
      using ``self.model`` for initial child values;
    - evaluates the leaf: +1 at the goal (state 15), -1 in a hole, otherwise
      ``gamma * target_model(state)``;
    - stores that value as a training target for the leaf's parent state;
    - refreshes values along the selected path.

    Args:
        num_simulations: simulations per call to ``search``.
        buffer_size: replay buffer capacity.
        batch_size: minibatch size; training starts once the buffer holds at
            least this many samples.
        gamma: discount applied to the bootstrapped leaf value.
        exploration_weight: coefficient of the UCT exploration term.
    """

    def __init__(self,
                 num_simulations: int = 100,
                 buffer_size: int = 10000,
                 batch_size: int = 64,
                 gamma: float = 0.99,
                 exploration_weight: float = 1.414):
        self.num_simulations = num_simulations
        self.gamma = gamma
        self.exploration_weight = exploration_weight
        self.env = gym.make('FrozenLake-v1', is_slippery=False)

        # Online value model plus a periodically synced target copy.
        self.model = ValueNetwork(state_size=16)
        self.target_model = ValueNetwork(state_size=16)
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        self.loss_fn = torch.nn.MSELoss()
        self.replay_buffer = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.train_step_counter = 0
        self.target_update_interval = 10

        # Search tree (rebuilt from scratch on every call to ``search``).
        self.root: Optional[Node] = None

    def search(self, current_state: int):
        """Search from ``current_state`` and return ``(best_action, root)``.

        ``best_action`` is the root child with the highest visit count.
        """
        self.root = Node(state=current_state)
        temp_env = gym.make('FrozenLake-v1', is_slippery=False)
        try:
            for _ in range(self.num_simulations):
                self._simulate(temp_env, self.root)
        finally:
            # Release the scratch env; previously one leaked per call.
            temp_env.close()
        # Pick the most visited root action.
        best_action = max(self.root.children.keys(),
                          key=lambda a: self.root.children[a].visit_count)
        return best_action, self.root

    def _simulate(self, temp_env: gym.Env, root: Node):
        """Run one selection -> expansion -> evaluation -> backup pass from ``root``."""
        # The root is part of the path so its visit count grows. _select_child
        # reads it as N in log(N), which otherwise stays at log(1) ~ 0 and
        # disables the exploration term at the root.
        path = [root]
        node = root
        done = root.done

        # Phase 1: selection
        while not done and node.children:
            action = self._select_child(node)
            node = node.children[action]
            path.append(node)
            done = node.done

        # Phase 2: expansion -- add the missing children one at a time in random order.
        while not node.done and len(node.children) < self.env.action_space.n:
            action = int(np.random.choice([a for a in range(4) if a not in node.children]))
            temp_env.reset()
            temp_env.unwrapped.s = node.state  # put the scratch env into this state
            next_state, reward, terminated, truncated, _ = temp_env.step(action)

            # Initial value prediction; no_grad avoids building an unused autograd graph.
            with torch.no_grad():
                state_tensor = torch.tensor([next_state], dtype=torch.float32)
                initial_value = self.model(state_tensor).item()

            node.children[action] = Node(
                state=next_state,
                action_taken=action,
                parent=node,
                value=initial_value,
                done=terminated or truncated,
            )

        # Phase 3: evaluation. A single name is used in every branch; the
        # original set `final_value` for the goal but then read `final_cost`,
        # raising UnboundLocalError whenever a goal leaf was evaluated.
        if node.done and node.state == 15:
            final_value = 1.0
        elif node.done:
            final_value = -1.0
        else:
            with torch.no_grad():
                state_tensor = torch.tensor([node.state], dtype=torch.float32)
                final_value = self.gamma * self.target_model(state_tensor).item()

        # Backup. Note that _backpropagate ignores final_value and re-derives
        # values from the children.
        self._backpropagate(path, final_value)

        # Store the evaluation as a training target for the parent state.
        if node.parent is not None:
            self._remember(node.parent.state, final_value)

        # Experience replay once enough samples are buffered.
        if len(self.replay_buffer) >= self.batch_size:
            self._replay()

    def _select_child(self, node: Node) -> int:
        """Return the child action maximizing ``value + c * sqrt(ln(N) / n)``."""
        total_visits = node.visit_count
        best_score = -float('inf')
        best_action = -1

        for action, child in node.children.items():
            exploit = child.value
            explore = self.exploration_weight * math.sqrt(
                math.log(total_visits + 1e-7) / (child.visit_count + 1e-7))
            score = exploit + explore
            if score > best_score:
                best_score = score
                best_action = action

        return best_action

    def _backpropagate(self, path: list[Node], value: float):
        """Increment visit counts along ``path`` (leaf to root).

        Each node that has children gets its value set to the visit-weighted
        mean of its children's values. ``value`` is currently unused.
        """
        for node in reversed(path):
            node.visit_count += 1
            if node.children:
                total_visits = sum(child.visit_count for child in node.children.values())
                weighted_value = sum(
                    child.value * child.visit_count for child in node.children.values()
                ) / total_visits
                node.value = weighted_value

    def _remember(self, state: int, target_cost: float):
        """Append a ``(state, target)`` training pair to the replay buffer."""
        self.replay_buffer.append((state, target_cost))

    def _replay(self):
        """Take one MSE gradient step on a random minibatch and sync the target net periodically."""
        batch = np.random.choice(len(self.replay_buffer), self.batch_size, replace=False)
        states, target_costs = zip(*[self.replay_buffer[i] for i in batch])

        # Both tensors have shape (B,); the one-hot model returns (B,) for (B,) input.
        states = torch.tensor(states, dtype=torch.float32)
        target_costs = torch.tensor(target_costs, dtype=torch.float32)

        self.optimizer.zero_grad()
        pred_costs = self.model(states)
        loss = self.loss_fn(pred_costs, target_costs)
        loss.backward()
        self.optimizer.step()

        self.train_step_counter += 1
        if self.train_step_counter % self.target_update_interval == 0:
            self.target_model.load_state_dict(self.model.state_dict())

def train_frozenlake(num_episodes: int = 10, num_simulations: int = 50):
    """Play non-slippery FrozenLake episodes with the value-maximizing UCT agent.

    Each step prints the state before and after acting. After each episode
    the total cost is printed: the number of steps taken, reset to 0 if the
    goal (state 15) was reached.

    Args:
        num_episodes: number of episodes to play.
        num_simulations: search simulations per action.

    Returns:
        List of per-episode total costs.
    """
    env = gym.make('FrozenLake-v1', is_slippery=False)
    agent = UCTFrozenLake(num_simulations=num_simulations)
    episode_costs = []

    try:
        for episode in range(num_episodes):
            state = env.reset()[0]
            total_cost = 0
            done = False

            while not done:
                action, node = agent.search(state)
                print(f"Episode {episode+1}, state: {state}")
                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                total_cost += 1  # each step costs 1
                state = next_state
                print(f"Episode {episode+1}, state: {next_state}")

                if terminated and next_state == 15:
                    total_cost = 0  # reaching the goal resets the total cost to 0

            # total_cost was previously computed and then discarded.
            print(f"Episode {episode+1}, total cost: {total_cost}")
            episode_costs.append(total_cost)
    finally:
        env.close()

    return episode_costs


if __name__ == "__main__":
    train_frozenlake()

Episode 1, state: 0
Episode 1, state: 1
Episode 1, state: 1
Episode 1, state: 5
Episode 2, state: 0
Episode 2, state: 4
Episode 2, state: 4
Episode 2, state: 5
Episode 3, state: 0
Episode 3, state: 0
Episode 3, state: 0
Episode 3, state: 0
Episode 3, state: 0
Episode 3, state: 0
Episode 3, state: 0
Episode 3, state: 0
Episode 3, state: 0
Episode 3, state: 0
Episode 3, state: 0
Episode 3, state: 0
Episode 3, state: 0
Episode 3, state: 0
Episode 3, state: 0
Episode 3, state: 0
Episode 3, state: 0
Episode 3, state: 0
Episode 3, state: 0
Episode 3, state: 0
Episode 3, state: 0
Episode 3, state: 1
Episode 3, state: 1
Episode 3, state: 2
Episode 3, state: 2
Episode 3, state: 3
Episode 3, state: 3
Episode 3, state: 3
Episode 3, state: 3
Episode 3, state: 3
Episode 3, state: 3
Episode 3, state: 3
Episode 3, state: 3
Episode 3, state: 2
Episode 3, state: 2
Episode 3, state: 2
Episode 3, state: 2
Episode 3, state: 3
Episode 3, state: 3
Episode 3, state: 3
Episode 3, state: 3
Episode 3, state: 3


KeyboardInterrupt: 