<a href="https://colab.research.google.com/github/Papa-Panda/Paper_reading/blob/main/Marco_o1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# https://github.com/AIDC-AI/Marco-o1
# learning how to do MCTS (Monte Carlo Tree Search)

In [2]:
import math
import random

import torch
import torch.nn as nn

# Define a mock reward model to score reasoning steps
class RewardModel(nn.Module):
    """Mock reward model: a single linear layer mapping an embedding to a score.

    Args:
        embedding_dim: Size of the last dimension of the input embeddings.
            Defaults to 10, the width used by the rest of this notebook.
    """

    def __init__(self, embedding_dim=10):
        super().__init__()
        # Input: reasoning embedding, output: one scalar reward per embedding.
        self.linear = nn.Linear(embedding_dim, 1)

    def forward(self, x):
        """Return the reward score(s) for ``x``.

        ``x`` must have ``embedding_dim`` features in its last dimension; the
        result keeps any leading dimensions and has a trailing dimension of 1.
        """
        return self.linear(x)

# Initialize the reward model. Its weights are whatever nn.Linear's default
# random initialisation produces; no training step appears in the visible code,
# so the scores it returns are arbitrary (this is a mock).
reward_model = RewardModel()

# Mock function to simulate reasoning step generation
def generate_reasoning_step(state, noise_scale=0.1):
    """
    Simulates a reasoning step by slightly modifying the current state.
    Each step is represented as a tensor embedding.

    Args:
        state: Tensor holding the current state. It is not modified.
        noise_scale: Multiplier applied to the standard-normal noise that is
            added to ``state``. Defaults to 0.1.

    Returns:
        A new tensor with the same shape as ``state``.
    """
    # randn_like draws noise matching the shape/dtype/device of `state`.
    return state + torch.randn_like(state) * noise_scale

# Monte Carlo Tree Node
class TreeNode:
    """One node of the search tree, wrapping a single reasoning state.

    Args:
        state: The state held by this node; stored as given, not copied.
        parent: The node this one hangs under, or None for a root.
    """

    def __init__(self, state, parent=None):
        self.state = state
        self.parent = parent
        self.children = []  # child nodes, in insertion order
        self.visits = 0     # incremented by MCTS.backpropagate
        self.value = 0      # running sum of backpropagated rewards

    def add_child(self, child):
        """Attach ``child`` under this node.

        The child's ``parent`` link is pointed at this node so that walking
        ``parent`` upwards from the child always reaches this node, even when
        the child was constructed without ``parent=``.
        """
        child.parent = self
        self.children.append(child)

# MCTS Algorithm
class MCTS:
    """Toy Monte Carlo Tree Search over tensor-valued reasoning states.

    Each simulation walks down the tree by UCB score until it reaches a leaf,
    expands that leaf, scores a random rollout started from the leaf with the
    reward model, and adds that reward to every node on the path to the root.

    Args:
        reward_model: Callable that takes a state tensor and returns a
            one-element tensor (``.item()`` is called on the result).
        exploration_weight: Multiplier on the UCB exploration term.
        num_children: How many child states are generated per expansion.
        rollout_depth: How many random steps a rollout takes before scoring.

    Raises:
        ValueError: If ``num_children < 1`` or ``rollout_depth < 0``.
    """

    def __init__(self, reward_model, exploration_weight=1.0,
                 num_children=3, rollout_depth=5):
        if num_children < 1:
            raise ValueError("num_children must be at least 1")
        if rollout_depth < 0:
            raise ValueError("rollout_depth must be non-negative")
        self.reward_model = reward_model
        self.exploration_weight = exploration_weight
        self.num_children = num_children
        self.rollout_depth = rollout_depth

    def select_node(self, node):
        """Select the child node with the highest UCB score.

        A child that has never been visited is returned immediately (the
        first such child in insertion order), so every child is tried once
        before visit statistics are compared. Ties between visited children
        go to the earliest child. Returns None when ``node`` has no children.
        """
        best_score = float('-inf')
        best_child = None
        # log(N + 1) is identical for every child, so compute it once.
        log_parent_visits = math.log(node.visits + 1.0)
        for child in node.children:
            if child.visits == 0:
                # Unvisited: treat the exploration bonus as infinite.
                return child
            # Upper Confidence Bound = mean reward + exploration bonus.
            exploit = child.value / child.visits
            explore = self.exploration_weight * math.sqrt(
                log_parent_visits / child.visits
            )
            score = exploit + explore
            if score > best_score:
                best_score = score
                best_child = child
        return best_child

    def expand_node(self, node):
        """Expand ``node`` by attaching ``num_children`` newly generated states."""
        for _ in range(self.num_children):
            new_state = generate_reasoning_step(node.state)
            child_node = TreeNode(new_state, parent=node)
            node.add_child(child_node)

    def simulate(self, node):
        """Run a random rollout from ``node`` and return its reward as a float.

        The rollout starts from a copy of ``node.state``, applies
        ``rollout_depth`` random reasoning steps, and scores only the final
        state with the reward model.
        """
        # Scoring only: no gradients are needed, so skip building a graph.
        with torch.no_grad():
            current_state = node.state.clone()
            for _ in range(self.rollout_depth):
                current_state = generate_reasoning_step(current_state)
            return self.reward_model(current_state).item()

    def backpropagate(self, node, reward):
        """Add ``reward`` and one visit to ``node`` and each of its ancestors."""
        while node is not None:
            node.visits += 1
            node.value += reward
            node = node.parent

    def search(self, root, num_simulations=50):
        """Perform MCTS starting from the root node.

        Args:
            root: Node to search from; the tree below it is grown in place.
            num_simulations: Number of select/expand/simulate/backpropagate
                rounds to run.

        Returns:
            The child of ``root`` with the most visits (the earliest one on
            a tie).

        Raises:
            ValueError: If ``root`` has no children once the simulations are
                done, e.g. ``num_simulations`` is 0 on a fresh root.
        """
        for _ in range(num_simulations):
            node = root
            # Selection: descend until a node without children is reached.
            while node.children:
                node = self.select_node(node)
            # Expansion
            self.expand_node(node)
            # Simulation: the rollout starts from the expanded node itself,
            # not from one of its new children.
            reward = self.simulate(node)
            # Backpropagation
            self.backpropagate(node, reward)
        if not root.children:
            raise ValueError(
                "search() needs at least one simulation to produce a child of root"
            )
        # Return the most-visited child.
        return max(root.children, key=lambda child: child.visits)

# Starting point of the search: a randomly drawn 10-dimensional embedding
initial_state = torch.randn(10)
root_node = TreeNode(state=initial_state)

# Wrap the mock reward model in a searcher and run it from the root
mcts = MCTS(reward_model=reward_model)
best_node = mcts.search(root=root_node)

# What gets printed is the state held by the node that search() returned
print("Best reasoning trajectory:", best_node.state)


Best reasoning trajectory: tensor([-1.3798,  1.3225,  0.6134,  0.0936, -0.7026, -0.9549, -0.4759,  0.3135,
         0.4071,  0.4607])
