<a href="https://colab.research.google.com/github/ImaginationX4/HybridZero-/blob/main/AlphaZero_CartPole.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy

## 1. Network

In [None]:
class Network(nn.Module):
    """Policy/value network used by the CartPole AlphaZero loop.

    Two independent MLP trunks read the same input:

    * ``self.net``   -> ``output_size`` raw policy logits (no softmax here;
      callers apply softmax / log_softmax themselves).
    * ``self.v_net`` -> one unbounded scalar value estimate.

    Args:
        input_size: length of the observation vector.
        hidden_size: width of the legacy ``fc1``/``fc2``/``value_head`` layers
            only; the active trunks are hard-coded to 256 units.
        output_size: number of discrete actions.
    """

    def __init__(self, input_size=4, hidden_size=64, output_size=2):
        super(Network, self).__init__()
        # Preferred device at construction time; used as ``map_location`` by
        # ``load``. ``forward`` follows the parameters' actual device instead.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Legacy layers: not used by ``forward``, but kept (in the original
        # registration order) so existing checkpoints still load with the
        # strict ``load_state_dict`` in ``load``.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.value_head = nn.Linear(hidden_size, 1)
        # Policy trunk -> logits.
        self.net = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, output_size),
        )
        # Value trunk -> single unbounded scalar (no output activation).
        self.v_net = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        """Return ``(policy_logits, value)`` for a state or batch of states.

        Args:
            x: array-like or tensor whose last dimension is ``input_size``;
                any float/int dtype and any device are accepted.

        Returns:
            ``(policy_logits, value)`` with trailing sizes ``output_size`` and 1.
        """
        # Follow the weights' device (not the construction-time default) so a
        # model that was never ``.to(...)``-ed still works on a CUDA host, and
        # coerce to float32 so float64 arrays/tensors don't crash nn.Linear.
        param_device = next(self.parameters()).device
        x = torch.as_tensor(x, dtype=torch.float32, device=param_device)
        policy = self.net(x)
        value = self.v_net(x)
        return policy, value

    def save(self, filepath):
        """Write ``state_dict`` to ``filepath``."""
        torch.save(self.state_dict(), filepath)
        print(f"Model saved to {filepath}")

    def load(self, filepath):
        """Load a ``state_dict`` previously written by ``save``."""
        self.load_state_dict(torch.load(filepath, map_location=self.device))
        print(f"Model loaded from {filepath}")

In [None]:
class Net_w(nn.Module):
    """Standalone policy MLP (LayerNorm variant) returning scaled logits.

    Args:
        input_size: length of the observation vector.
        hidden_size: accepted for signature compatibility; the trunk width is
            hard-coded to 256.
        output_size: number of actions (logit dimension).
        temperature: divisor applied to the logits to sharpen (<1) or smooth
            (>1) a subsequent softmax; the default 1 leaves logits unchanged.
    """

    def __init__(self, input_size=4, hidden_size=64, output_size=2, temperature=1):
        super(Net_w, self).__init__()
        self.temperature = temperature
        self.net = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.LayerNorm(256),  # normalisation layer to help training
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, output_size),
        )

    def forward(self, x):
        """Return ``logits / temperature`` for a state or batch of states."""
        # Match the weights' device and dtype so float64 input or a
        # CPU-vs-CUDA mismatch does not crash nn.Linear.
        param_device = next(self.parameters()).device
        x = torch.as_tensor(x, dtype=torch.float32, device=param_device)
        logits = self.net(x)
        return logits / self.temperature

In [None]:
!pip install gymnasium

Collecting gymnasium
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading gymnasium-1.0.0-py3-none-any.whl (958 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m958.1/958.1 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Installing collected packages: farama-notifications, gymnasium
Successfully installed farama-notifications-0.0.4 gymnasium-1.0.0


### How to set the environment state


In [None]:
import gymnasium as gym
import numpy as np

# Demo: CartPole's internal state can be read and overwritten through
# `env.unwrapped.state`; the MCTS below relies on this to simulate from
# arbitrary states.
# Create the environment
env = gym.make('CartPole-v1')
a,_=env.reset()
print('observation',a)
# 1. Save the current internal state
current_state = env.unwrapped.state.copy()  # [x, x_dot, theta, theta_dot]
print('current_state',current_state)
# 2. Overwrite it with a specific state
# e.g. cart centred and at rest, pole slightly tilted
desired_state = np.array([0.0,  # cart position (x)
                         0.0,  # cart velocity (x_dot)
                         0.1,  # pole angle (theta) - slightly tilted
                         0.0]) # pole angular velocity (theta_dot)

env.unwrapped.state = desired_state

# 3. Take an action starting from this state
observation, reward, terminated, truncated, info = env.step(1)  # 1: push right

# 4. Restore the previously saved state if needed
env.unwrapped.state = current_state.copy()

observation [ 0.00363835 -0.01579056  0.0161241  -0.00074579]
current_state [ 0.00363835 -0.01579056  0.0161241  -0.00074579]


In [None]:
# Smoke test: run an untrained Network on one CartPole observation.
# `device` (defined here) is reused as a global by later cells.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
a= Network().to(device)
env = gym.make('CartPole-v1')
observations,_ = env.reset()
print('observations',observations)
observations = torch.FloatTensor(observations)

policy,value = a(observations)
# `policy` holds raw logits: Network.forward applies no softmax.
print(policy.cpu().detach().numpy())
print('value',value)

observations [-0.00716547  0.04397519 -0.00126024  0.00752538]
[ 0.37035656 -0.3788243 ]
value tensor([0.7007], grad_fn=<ViewBackward0>)


## 2. MCTS

In [None]:
import math
import numpy as np
import torch
from dataclasses import dataclass, field
from typing import List, Tuple, Dict, Optional



@dataclass
class Node:

    prior: float  #  P(s,a)
    action_taken: Optional[int]
    visit_count: int = 0  # N(s,a)
    value_sum: float = 0  # Q=value_sum/visit_count
    parent: Optional['Node'] = None
    children: Dict[int, 'Node'] = field(default_factory=dict)
    done = False
    has_children = False

    def __post_init__(self):
        if self.children is None:
            self.children = {}
    @property
    def Q_value(self) -> float:

        if self.visit_count == 0:
            return 0.0
        return self.value_sum / self.visit_count


class MCTS:
    """Monte-Carlo Tree Search for CartPole using a real simulator env.

    Every simulation resets a ``CartPole-v1`` env, overwrites its internal
    state with ``root_state`` and replays the tree's actions. Rewards and
    terminations therefore come from the true dynamics; the network only
    supplies priors (policy) and leaf value estimates.

    Args:
        model: callable mapping a single state tensor to
            ``(policy_logits, value)``.
        num_simulations: number of simulations per ``search`` call.
        epoch_of_training: anneals exploration. ``c_puct`` decreases linearly
            from 10 and is floored at 1. The Dirichlet-noise weight
            ``epsilon`` decreases linearly from 0.25 and reaches 0 at epoch 10.
        discount: per-step discount used by ``backup`` (default 0.99).
        value_decay: factor applied to a node's ``value_sum`` before a new
            return is added in ``backup`` (default 0.95). NOTE: ``Q_value``
            divides by ``visit_count``, so any value < 1 shrinks Q toward 0
            for heavily visited nodes; 1.0 gives a plain running mean.
    """

    def __init__(self, model, num_simulations: int = 100, epoch_of_training=0,
                 discount: float = 0.99, value_decay: float = 0.95):
        self.model = model
        self.num_simulations = num_simulations
        self.c_puct = max(10 * (1 - epoch_of_training / 10), 1)  # UCB exploration constant
        self.epsilon = max(0.25 * (1 - epoch_of_training / 10), 0)  # Dirichlet noise weight
        self.discount = discount
        self.value_decay = value_decay

    def search(self, root_state) -> 'Node':
        """Run ``num_simulations`` simulations from ``root_state``.

        ``root_state`` is written directly into ``env.unwrapped.state``. It
        must therefore be CartPole's raw ``[x, x_dot, theta, theta_dot]``
        vector, which is also what CartPole returns as an observation.

        Returns:
            The root ``Node``; its children's visit counts form the policy.
        """
        root = Node(prior=1.0, action_taken=None)
        # One simulator for the whole search. Previously a new env was built
        # for every simulation and none of them was ever closed.
        env = gym.make('CartPole-v1')
        try:
            for _ in range(self.num_simulations):
                node = root
                env.reset()  # resets episode bookkeeping (e.g. the time limit)
                state = root_state
                env.unwrapped.state = state
                search_path = [node]
                rewards = []

                # 1. Selection: descend by UCB, stepping the simulator.
                while node.has_children and not node.done:
                    action, node = self.select_child(node)
                    observation, reward, terminated, truncated, _ = env.step(action)
                    rewards.append(reward)
                    node.done = terminated or truncated
                    state = observation
                    search_path.append(node)

                # 2. Expansion and evaluation (terminal leaves keep value 0).
                value = 0
                if not node.done:
                    policy, value = self.evaluate_state(state, node.done)
                    self.expand_node(node, policy)

                # 3. Backup.
                self.backup(search_path, value, rewards)
        finally:
            env.close()
        return root

    def select_child(self, node: 'Node') -> Tuple[int, 'Node']:
        """Return ``(action, child)`` with the highest UCB score.

        score = Q(child) + prior(child) * pb_c, where
        pb_c = (log((N + c + 1) / c) + c) * sqrt(N) / (n_child + 1)
        and c = ``self.c_puct``, N = ``node.visit_count``. This is the
        MuZero-style formula with ``c_puct`` used for both c_base and c_init.
        """
        c_puct = self.c_puct
        # Depends only on the parent, so compute it once for all children.
        pb_c_base = np.log((node.visit_count + c_puct + 1) / c_puct) + c_puct
        sqrt_parent = np.sqrt(node.visit_count)

        best_score = -float('inf')
        best_action = -1
        best_child = None
        for action, child in node.children.items():
            pb_c = pb_c_base * (sqrt_parent / (child.visit_count + 1))
            ucb_score = child.Q_value + pb_c * child.prior
            if ucb_score > best_score:
                best_score = ucb_score
                best_action = action
                best_child = child
        return best_action, best_child

    def evaluate_state(self, state: np.ndarray, done) -> Tuple[np.ndarray, float]:
        """Return ``(action_probs, value)`` for a single unbatched state.

        A ``done`` state short-circuits to ``(zeros(2), -1.0)`` without
        calling the model.
        """
        if done:
            return np.zeros(2), -1.0
        with torch.no_grad():
            policy_logits, value = self.model(torch.FloatTensor(state))
            # dim=0: the state is a single unbatched vector.
            policy = F.softmax(policy_logits, dim=0)
        return policy.cpu().numpy(), value.item()

    def expand_node(self, node: 'Node', policy: np.ndarray):
        """Create one child per action with Dirichlet-noised priors.

        NOTE(review): noise is mixed in at every expansion, not only at the
        root as in AlphaZero. This behaviour is kept; confirm it is intended.
        """
        noise = np.random.dirichlet([0.3] * len(policy))
        policy = (1 - self.epsilon) * policy + self.epsilon * noise
        for action, prob in enumerate(policy):
            node.children[action] = Node(prior=prob, action_taken=action, parent=node)
        # Only flag expansion if at least one child exists (an empty policy
        # previously left the flag unset, too).
        if node.children:
            node.has_children = True

    def backup(self, search_path: List['Node'], value: float, rewards: List[float]):
        """Propagate one simulation's discounted return from leaf to root.

        ``search_path[i]`` (i >= 1) is credited with ``rewards[i - 1]`` (the
        reward for the step into it) plus a 0.01 survival bonus; the root
        gets no reward of its own:

            G <- (r + 0.01) + discount * G    (starting from G = value)

        Each node then gets ``visit_count += 1`` and
        ``value_sum = value_decay * value_sum + G``.

        The last reward is replaced by -1 only if the leaf is terminal and the
        path is shorter than 500 steps (a failure, not a time-limit
        truncation). Previously it was overwritten for every leaf, including
        non-terminal ones. ``rewards`` is not modified.
        """
        rewards = list(rewards)
        if rewards and search_path[-1].done and len(rewards) < 500:
            rewards[-1] = -1

        survival_bonus = 0.01
        n_rewards = len(rewards)
        G = value  # estimate for the state after the last step
        for offset, node in enumerate(reversed(search_path)):
            if offset < n_rewards:
                r = rewards[n_rewards - 1 - offset] + survival_bonus
                G = r + self.discount * G
            node.visit_count += 1
            node.value_sum = self.value_decay * node.value_sum + G

    def backup1(self, search_path: List['Node'], value: float):
        """Alternative backup (unused here): add ``value * 0.99**depth`` from the leaf up."""
        for depth, node in enumerate(reversed(search_path)):
            node.visit_count += 1
            node.value_sum += value * 0.99 ** depth

    def get_action_probs(self, root: 'Node', temperature: float = 0.0) -> np.ndarray:
        """Turn root visit counts into an action distribution.

        ``temperature == 0`` returns a one-hot float vector on the most-visited
        action (previously an int array). Otherwise returns
        ``counts ** (1 / temperature)`` normalised to sum to 1.
        """
        counts = np.array([child.visit_count for child in root.children.values()],
                          dtype=np.float64)
        if temperature == 0:
            probs = np.zeros_like(counts)
            probs[np.argmax(counts)] = 1.0
            return probs
        counts = counts ** (1.0 / temperature)
        return counts / np.sum(counts)

In [None]:
# Run one 100-simulation MCTS search from a fresh observation and inspect
# the root. epoch_of_training=100 floors exploration in MCTS.__init__:
# c_puct = 1 and Dirichlet-noise weight epsilon = 0.
model = Network().to(device)

mcts = MCTS(model,100,100)
observations,_ = env.reset()
#print(observations)
root=mcts.search(observations)
print('root value',root.Q_value)
print('root visit', root.visit_count)
print('root children left visit',root.children[0])
print('root children right',root.children[1])

root value 1.2455006518617993
root visit 100
root children left visit Node(prior=0.567622721195221, action_taken=0, visit_count=84, value_sum=187.68598482206252, parent=Node(prior=1.0, action_taken=None, visit_count=100, value_sum=124.55006518617994, parent=None, children={0: ..., 1: Node(prior=0.4323772192001343, action_taken=1, visit_count=15, value_sum=19.36107507304724, parent=..., children={0: Node(prior=0.5361707210540771, action_taken=0, visit_count=13, value_sum=11.237434869016916, parent=..., children={0: Node(prior=0.5694295763969421, action_taken=0, visit_count=11, value_sum=4.198807600505871, parent=..., children={0: Node(prior=0.47154173254966736, action_taken=0, visit_count=3, value_sum=-3.2907117472940683, parent=..., children={0: Node(prior=0.3857184052467346, action_taken=0, visit_count=1, value_sum=-1.7807540452480315, parent=..., children={0: Node(prior=0.36542534828186035, action_taken=0, visit_count=0, value_sum=0, parent=..., children={}), 1: Node(prior=0.63457459

## 3. Replay Buffer

In [None]:
import numpy as np
from dataclasses import dataclass
from typing import List, Optional, Dict
from collections import deque

class GameHistory:
    """Parallel per-step lists of collected transitions.

    Index ``i`` of every list refers to the same time step, so all lists must
    always be grown and cleared together.
    """

    def __init__(self):
        self.observations = []   # observation (env state) per step
        self.actions = []        # action taken per step
        self.rewards = []        # immediate reward per step
        self.mcts_policies = []  # MCTS action distribution per step
        self.values = []         # value target per step

    def store(self, observation, action, reward, action_probs, value):
        """Append one transition to every list."""
        self.observations.append(observation)
        self.actions.append(action)
        self.rewards.append(reward)
        self.mcts_policies.append(action_probs)
        self.values.append(value)

    def clear(self):
        """Empty every list. ``values`` used to be skipped, desyncing the lists."""
        for column in (self.observations, self.actions, self.rewards,
                       self.mcts_policies, self.values):
            column.clear()




In [None]:
class ReplayBuffer:
    """FIFO transition store with uniform sampling (with replacement).

    Transitions live in a ``GameHistory``'s parallel lists. Once ``capacity``
    transitions are held, each ``store`` evicts the oldest one.

    Args:
        batch_size: batch size used by ``sample_batch`` when none is given.
        minimum_size: stored for callers; not used by this class.
        capacity: maximum number of stored transitions.
    """

    def __init__(self, batch_size, minimum_size, capacity=500):
        self.batch_size = batch_size
        self.minimum_size = minimum_size
        self.capacity = capacity
        self.game_history = GameHistory()

    def _columns(self):
        """Return the five parallel per-step lists, in storage order."""
        history = self.game_history
        return (history.observations, history.actions, history.rewards,
                history.mcts_policies, history.values)

    def store(self, observation, action, reward, action_probs, value):
        """Append one transition, evicting the oldest one when full."""
        if len(self.game_history.observations) >= self.capacity:
            for column in self._columns():
                column.pop(0)
        self.game_history.store(observation, action, reward, action_probs, value)

    def sample_batch(self, batch_size=None):
        """Sample ``batch_size`` transitions uniformly with replacement.

        Args:
            batch_size: number of samples. ``None`` (the new default) uses the
                configured ``self.batch_size``; the old hard default of 32
                made that setting dead.

        Returns:
            ``(observations, actions, rewards, policies, values)`` tensors on
            CUDA if available, else CPU. ``actions`` is int64; the rest are
            float32.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if batch_size is None:
            batch_size = self.batch_size
        history = self.game_history
        indices = np.random.choice(len(history.observations), batch_size)

        def gather(column, dtype):
            # Stack into one ndarray first: torch.tensor() on a list of
            # ndarrays is very slow and emits a UserWarning.
            stacked = np.asarray([column[i] for i in indices], dtype=dtype)
            return torch.as_tensor(stacked).to(device)

        return (gather(history.observations, np.float32),
                gather(history.actions, np.int64),
                gather(history.rewards, np.float32),
                gather(history.mcts_policies, np.float32),
                gather(history.values, np.float32))

    def __len__(self):
        return len(self.game_history.observations)

    def clear(self):
        """Remove all transitions from every column (including ``values``)."""
        for column in self._columns():
            column.clear()

## 4. Self-Play

In [None]:
def self_play(replay_buffer, model, epoch_of_training=0):
    """Play one CartPole episode with MCTS and store its transitions.

    Each step runs a fresh ``MCTS`` search from the current observation. It
    plays the most-visited root action (greedy; exploration comes from the
    search itself) and keeps the visit distribution as the policy target.
    After the episode, discounted returns (gamma = 0.99) are stored as value
    targets.

    Args:
        replay_buffer: receives ``(state, action, reward, action_probs, G)``
            per step via ``store``.
        model: network used by MCTS; switched to eval mode.
        epoch_of_training: forwarded to ``MCTS`` to anneal exploration.

    Returns:
        The summed (undiscounted) episode reward.
    """
    env = gym.make('CartPole-v1')
    model.eval()
    episode = 0
    try:
        observation, _ = env.reset()
        done = False
        episode_reward = 0

        trajectory = []
        while not done:
            mcts = MCTS(model, epoch_of_training=epoch_of_training)
            root = mcts.search(observation)
            action_probs = mcts.get_action_probs(root, temperature=1)
            action = np.argmax(action_probs)

            next_observation, reward, terminated, truncated, info = env.step(action)
            episode_reward += reward
            done = terminated or truncated

            trajectory.append({
                'state': observation,
                'action': action,
                'reward': reward,
                'action_probs': action_probs,
            })
            observation = next_observation

        # Discounted returns, built back-to-front then reversed
        # (list.insert(0, ...) made this quadratic).
        returns = []
        G = 0
        gamma = 0.99
        for t in reversed(trajectory):
            G = t['reward'] + gamma * G
            returns.append(G)
        returns.reverse()

        for t, G in zip(trajectory, returns):
            replay_buffer.store(t['state'], t['action'], t['reward'],
                                t['action_probs'], G)

        print('Self-play')
        print('Episode {}: episode_reward = {}'.format(episode, episode_reward))
        return episode_reward
    finally:
        env.close()  # previously leaked

In [None]:
# Collect one self-play episode with an untrained model into a fresh buffer.
model = Network().to(device)
#model.load('best_model.pth')
replay_buffer_test=ReplayBuffer(batch_size=32, minimum_size=100)
self_play(replay_buffer_test, model)

Self-play
Episode 0: episode_reward = 15.0


15.0

In [None]:
# Compare stored targets (returns, MCTS policies) with the untrained
# network's predictions on 10 sampled transitions.
observations, actions, rewards, policies, values = replay_buffer_test.sample_batch(10)
policy_logits, pred_values = model(observations)
print('values',values)
print('pred_values',pred_values.squeeze())
print('policies',policies)
# Raw logits, not probabilities (no softmax applied here).
print('policy_logits',policy_logits)

values tensor([ 8.6483,  6.7935, 10.4662,  7.7255,  4.9010, 11.3615,  5.8520,  8.6483,
         9.5618,  8.6483])
pred_values tensor([-0.0732, -0.0706, -0.0375, -0.0339, -0.4599, -0.0564, -0.3290, -0.0732,
        -0.0515, -0.0732], grad_fn=<SqueezeBackward0>)
policies tensor([[0.4949, 0.5051],
        [0.5859, 0.4141],
        [0.5051, 0.4949],
        [0.6566, 0.3434],
        [0.5657, 0.4343],
        [0.4949, 0.5051],
        [0.7475, 0.2525],
        [0.4949, 0.5051],
        [0.6364, 0.3636],
        [0.4949, 0.5051]])
policy_logits tensor([[ 0.1690, -0.7911],
        [ 0.1853, -0.7749],
        [ 0.6352,  0.2304],
        [ 0.0768, -0.2188],
        [ 0.3572, -1.3648],
        [ 0.0677, -0.2579],
        [ 0.2771, -1.2215],
        [ 0.1690, -0.7911],
        [ 0.0685, -0.2479],
        [ 0.1690, -0.7911]], grad_fn=<AddmmBackward0>)


  observations = torch.tensor([self.game_history.observations[i] for i in indices], dtype=torch.float32)


## 5. Train


In [None]:
def train(model, optimizer, replay_buffer, num_epochs, save_path='best_model.pth'):
    """Alternate MCTS self-play with supervised updates of ``model``.

    Each epoch:
      1. plays one self-play episode into ``replay_buffer``;
      2. takes 50 gradient steps (100 if that episode scored > 450) on
         batches sampled from the buffer;
      3. every 3rd epoch, evaluates with ``evaluate`` and saves a checkpoint
         to ``save_path`` when the mean evaluation reward improves.

    Loss per step = MSE(value) + cross-entropy(MCTS policy, network policy)
    (summed over the batch) - 0.01 * mean policy entropy (exploration bonus).

    Returns:
        ``model`` with its final-epoch weights. The best checkpoint lives on
        disk at ``save_path`` and is not reloaded.
    """
    best_reward = 0
    # Loss of the last update before the best checkpoint was saved. Defined up
    # front so the final report works even when num_epochs == 0 (the original
    # raised NameError and reported the last loss, not the best one).
    best_loss = float('nan')

    for epoch in range(num_epochs):
        print(f"Epoch {epoch}/{num_epochs}")
        # 1. Data collection.
        episode_reward = self_play(replay_buffer, model, epoch)
        epoch_training = 100 if episode_reward > 450 else 50

        # 2. Gradient updates.
        model.train()
        for _ in range(epoch_training):
            observations, actions, rewards, policies, values = replay_buffer.sample_batch()
            optimizer.zero_grad()
            policy_logits, pred_values = model(observations)

            # squeeze(-1) (not squeeze()) keeps a batch of 1 as shape (1,).
            value_loss = F.mse_loss(pred_values.squeeze(-1), values)
            # log_softmax instead of log(softmax): avoids log(0) = -inf and the
            # resulting 0 * -inf = nan once a probability underflows. Summed
            # over the batch (not averaged) to keep the original loss scale.
            log_probs = F.log_softmax(policy_logits, dim=1)
            policy_loss = -torch.sum(policies * log_probs)
            policy_entropy = -(log_probs.exp() * log_probs).sum(dim=1).mean()
            # Entropy *bonus*: subtract it so higher entropy lowers the loss.
            # Adding it, as before, penalised exploration.
            total_loss = value_loss + policy_loss - 0.01 * policy_entropy

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
        last_loss = total_loss.item()

        # 3. Periodic evaluation and checkpointing.
        if epoch % 3 == 0:
            eval_reward = evaluate(model)
            print(f"Epoch {epoch}: Evaluation reward = {eval_reward:.2f}")
            if eval_reward > best_reward:
                best_reward = eval_reward
                best_loss = last_loss
                model.save(save_path)
                print(f"Epoch {epoch}: New best model saved with loss: {best_loss:.4f}")

        print(f"Epoch {epoch}: "
              f"Total Loss = {last_loss:.4f}, "
              f"Policy Loss = {policy_loss.item():.4f}, "
              f"Value Loss = {value_loss.item():.4f} "
              )

    # The old message claimed the best model had been "loaded"; it never was.
    print(f"Training completed. Best evaluation reward = {best_reward:.2f} "
          f"(loss at save: {best_loss:.4f}); returning the final-epoch model.")
    return model

### Evaluate

In [None]:
def evaluate(model, num_episodes=2):
    """Evaluate ``model`` on CartPole-v1 with MCTS-guided play.

    Each step runs ``MCTS(model).search`` from the current observation and
    plays the most-visited root action.

    NOTE(review): ``MCTS(model)`` uses ``epoch_of_training=0``, so the search
    still mixes Dirichlet noise (epsilon = 0.25) into priors during
    evaluation. Confirm this is intended.

    Args:
        model: network under evaluation (switched to eval mode).
        num_episodes: number of evaluation episodes.

    Returns:
        Mean episode reward over ``num_episodes``.
    """
    model.eval()
    rewards = []
    env = gym.make('CartPole-v1')
    try:
        for episode in range(num_episodes):
            obs, _ = env.reset()
            episode_reward = 0
            done = False

            while not done:
                with torch.no_grad():  # search only; no gradients needed
                    mcts = MCTS(model)
                    root = mcts.search(obs)
                    # temperature=0 -> one-hot on the most-visited action
                    action = np.argmax(mcts.get_action_probs(root))
                obs, reward, terminated, truncated, _ = env.step(action)
                episode_reward += reward
                done = terminated or truncated

            rewards.append(episode_reward)
            print('++++++Evaluation++++++')
            print(f"Evaluation episode {episode}: Reward = {episode_reward}")
    finally:
        env.close()  # previously leaked

    return np.mean(rewards)

## 6. Let's Start Training


In [None]:
import torch.optim as optim
import torch.nn.functional as F
# Hyper-parameters: number of self-play/train epochs and Adam learning rate.
num_epochs = 10
lr= 0.001
model = Network().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr)


# Buffer holds at most 500 transitions (ReplayBuffer's default capacity).
replay_buffer = ReplayBuffer(batch_size=32, minimum_size=100)


# Training; `train` returns the final-epoch model (best checkpoint is on disk).
model = train(model, optimizer, replay_buffer, num_epochs)

Epoch 0/10
Self-play
Episode 0: episode_reward = 14.0
++++++Evaluation++++++
Evaluation episode 0: Reward = 137.0
++++++Evaluation++++++
Evaluation episode 1: Reward = 113.0
Epoch 0: Evaluation reward = 125.00
Model saved to best_model.pth
Epoch 0: New best model saved with loss: 24.2939
Epoch 0: Total Loss = 24.2939, Policy Loss = 21.7788, Value Loss = 2.5083 
Epoch 1/10
Self-play
Episode 0: episode_reward = 165.0
Epoch 1: Total Loss = 1003.1149, Policy Loss = 22.0261, Value Loss = 981.0819 
Epoch 2/10
Self-play
Episode 0: episode_reward = 115.0
Epoch 2: Total Loss = 206.3548, Policy Loss = 22.1152, Value Loss = 184.2327 
Epoch 3/10
Self-play
Episode 0: episode_reward = 112.0
++++++Evaluation++++++
Evaluation episode 0: Reward = 116.0
++++++Evaluation++++++
Evaluation episode 1: Reward = 146.0
Epoch 3: Evaluation reward = 131.00
Model saved to best_model.pth
Epoch 3: New best model saved with loss: 26.6919
Epoch 3: Total Loss = 26.6919, Policy Loss = 22.0592, Value Loss = 4.6258 
Epoc

## 7. Test Thoughts

In [None]:
# After training: compare value/policy targets with the model's predictions.
observations, actions, rewards, policies, values = replay_buffer.sample_batch(10)
policy_logits, pred_values = model(observations)
# NOTE: despite the name, `policy_logits` holds softmax probabilities from
# here on, so it can be compared directly with `policies`.
policy_logits = F.softmax(policy_logits, dim=1)

print('values',values)
print('pred_values',pred_values.squeeze())
print('policies',policies)
print('policy_logits',policy_logits)

values tensor([47.4403, 42.4645, 97.3435, 92.6692, 31.0551, 93.0285, 97.8705, 29.6552,
        33.7718, 73.9915])
pred_values tensor([49.4657, 61.3272, 71.8766, 90.1167, 30.4016, 91.6647, 86.4833, 28.7126,
        33.6460, 73.7598], grad_fn=<SqueezeBackward0>)
policies tensor([[0.6263, 0.3737],
        [0.3636, 0.6364],
        [0.7576, 0.2424],
        [0.3535, 0.6465],
        [0.4949, 0.5051],
        [0.3838, 0.6162],
        [0.7475, 0.2525],
        [0.4848, 0.5152],
        [0.5657, 0.4343],
        [0.1919, 0.8081]])
policy_logits tensor([[0.5609, 0.4391],
        [0.5036, 0.4964],
        [0.4927, 0.5073],
        [0.4337, 0.5663],
        [0.4975, 0.5025],
        [0.4831, 0.5169],
        [0.4655, 0.5345],
        [0.4579, 0.5421],
        [0.5901, 0.4099],
        [0.3668, 0.6332]], grad_fn=<SoftmaxBackward0>)


In [None]:
# Add one more self-play episode to the training buffer.
self_play(replay_buffer, model)

In [None]:

class Net_w(nn.Module):
    """Standalone policy MLP (BatchNorm variant) returning scaled logits.

    This redefinition replaces the earlier LayerNorm ``Net_w``. BatchNorm1d
    needs batched 2-D input, and more than one sample per batch in train mode.

    Args:
        input_size: length of the observation vector.
        hidden_size: accepted for signature compatibility; the trunk width is
            hard-coded to 256.
        output_size: number of actions (logit dimension).
        temperature: divisor applied to the logits; the default 1 leaves them
            unchanged.
    """

    def __init__(self, input_size=4, hidden_size=64, output_size=2, temperature=1):
        super(Net_w, self).__init__()
        self.temperature = temperature
        self.net = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),  # BatchNorm layer to help training
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, output_size),
        )

    def forward(self, x):
        """Return ``logits / temperature`` for a batch of states."""
        # Match the weights' device and dtype: batches from sample_batch live
        # on CUDA when available, while a fresh Net_w() is on CPU.
        param_device = next(self.parameters()).device
        x = torch.as_tensor(x, dtype=torch.float32, device=param_device)
        logits = self.net(x)
        return logits / self.temperature

In [None]:
# Experiment: fit the standalone policy net `Net_w` to MCTS policy targets
# sampled from the replay buffer.
observations, actions, rewards, policies, values = replay_buffer.sample_batch(10)

# sample_batch puts tensors on `device`; move the net there too, otherwise
# a CUDA runtime fails with a device mismatch.
net_policy = Net_w().to(device)

print(net_policy(observations))
optimizer_policy = torch.optim.Adam(net_policy.parameters(), lr = 0.0001)

# Overfit a single fixed batch for 100 steps.
observations, actions, rewards, policies, values = replay_buffer.sample_batch()
for i in range(100):
    net_policy.train()
    optimizer_policy.zero_grad()
    logits = net_policy(observations)
    probs = F.softmax(logits, dim=1)
    # Cross-entropy vs. MCTS targets (summed over the batch); log_softmax
    # avoids the log(0) -> nan that torch.log(probs) can produce.
    policy_net_loss = -torch.sum(policies * F.log_softmax(logits, dim=1))
    print('policy_net_loss',policy_net_loss.item())
    policy_net_loss.backward()

    torch.nn.utils.clip_grad_norm_(net_policy.parameters(), max_norm=1.0)
    optimizer_policy.step()

# Despite the name, `policy_logits` holds softmax probabilities for
# comparison with `policies`.
policy_logits = F.softmax(net_policy(observations), dim=1)

print('policies',policies)
print('policy_logits',policy_logits)
