# PPO for Tree Reorganization ? 

## Run the es service and make sure you can do the searching

In [6]:
import requests
import numpy as np

def query_knowledge_base(question, k = 2):
    web = "http://localhost:1439"
    data = {
        "query": question,
        "k": k
    }
    for i in range(3):
        try:
            r = requests.get(web, json=data)
            if r.status_code != 200:
                raise Exception(r.text)
            contexts = r.json()
            # return only the 'text' from each of the contexts
            contexts = [context['text'] for context in contexts]
            return contexts
        except Exception as e:
            print(e)

query_knowledge_base("What is the capital of Germany")

HTTPConnectionPool(host='localhost', port=1439): Max retries exceeded with url: / (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x1052d3050>: Failed to establish a new connection: [Errno 61] Connection refused'))
HTTPConnectionPool(host='localhost', port=1439): Max retries exceeded with url: / (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x1052d3440>: Failed to establish a new connection: [Errno 61] Connection refused'))
HTTPConnectionPool(host='localhost', port=1439): Max retries exceeded with url: / (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x104f12780>: Failed to establish a new connection: [Errno 61] Connection refused'))


## Run the together service and make sure you can prompt the model

In [7]:
import os
import json

class ProviderReq:
    def __init__(self, url, cache_path="./cache.jsonl", model=None):
        self.url = url
        self.cache = {}
        self.cache_path = cache_path
        self.model = model  # Default model is set to None, but can be overridden in child classes
        if os.path.exists(self.cache_path):
            with open(self.cache_path, "r") as f:
                for line in f:
                    datum = json.loads(line.strip())
                    self.cache[tuple(datum["input"])] = datum["response"]

    def req2provider(self, prompt, temperature=0, max_tokens=None, stop=None, logprobs=1, use_cache=True):
        assert isinstance(prompt, str)
        input = (prompt, self.model, max_tokens, stop, logprobs)
        if use_cache and temperature == 0 and input in self.cache:
            return self.cache[input], True

        # Retry logic
        for i in range(3):
            try:
                response = self.make_request(prompt, self.model, temperature, max_tokens, stop, logprobs)
                if response.status_code != 200:
                    raise Exception(response.text)
                break
            except Exception as e:
                err_msg = str(e)
                print(e)
                if "reduce your prompt" in err_msg:  # this is because the input string is too long
                    return ['too long'], False

        try:
            response_json = response.json()
            response = self.parse_response(response_json)
        except:
            return ['error'], False

        # Cache the result if temperature is 0
        if temperature == 0:
            self.cache_result(input, response)

        return response, True

    def make_request(self, prompt, model, temperature, max_tokens, stop, logprobs):
        """To be implemented in the subclass"""
        raise NotImplementedError("Subclasses should implement this method")

    def parse_response(self, response_json):
        """To be implemented in the subclass"""
        raise NotImplementedError("Subclasses should implement this method")

    def cache_result(self, input, response):
        """Cache the result if it's not already cached"""
        if input not in self.cache:
            self.cache[input] = response
            with open(self.cache_path, "a") as f:
                f.write("%s\n" % json.dumps({"input": input, "response": response}))

class TogetherReq(ProviderReq):
    # meta-llama/Llama-Vision-Free is free but doesnt support log_probs which are necessary for probtree
    # meta-llama/Meta-Llama-3-8B-Instruct-Turbo : bad answers but provide log_prob
    # meta-llama/Llama-3.2-3B-Instruct-Turbo ? bad answers tpp
    # meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo ? very nice but no log_prob !! 
    # meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo ? niceeee and log probs !! :) 
    # meta-llama/Meta-Llama-3-70B-Instruct-Lite ?
    def __init__(self, cache_path="./cache.jsonl", model="meta-llama/Meta-Llama-3-70B-Instruct-Lite"):
        super().__init__(url="http://127.0.0.1:10001/api/together/completion", cache_path=cache_path, model=model)

    def make_request(self, prompt, model, temperature, max_tokens, stop, logprobs):
        return requests.post(self.url, json={
            "model": model,
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ],
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stop": stop,
            "logprobs": logprobs,
        })

    def parse_response(self, response_json):
        return response_json['choices']

In [8]:
from dotenv import load_dotenv
load_dotenv()

key_pool = os.getenv('TOGETHER_API_KEY').split(',')
NUM_WORKERS = len(key_pool)  # Match the number of workers to the number of clients

reqor = TogetherReq()
result, tag = reqor.req2provider("Hello, I am from Egypt. Guess my name ?", max_tokens = None, stop = None)
print(result[0])

{'finish_reason': 'eos', 'index': 0, 'logprobs': {'token_ids': [3923, 264, 27387, 8815, 2268, 81122, 1131, 13126, 499, 2351, 505, 15212, 11, 358, 3358, 1935, 264, 8545, 8101, 2195, 11787, 499, 8530, 7086, 330, 83705, 3690, 1, 477, 330, 6219, 81, 44969, 4314, 527, 1633, 4279, 5144, 304, 15212, 11, 719, 358, 1436, 387, 12756, 1022, 2268, 5618, 3041, 757, 264, 13310, 477, 3371, 757, 422, 358, 2846, 12660, 3345, 0, 128009], 'token_logprobs': [-0.26171875, -0.0024719238, -0.19726562, -0.63671875, -0.0013656616, -1.109375, -0.0028076172, -0.51171875, -0.23144531, -0.006713867, 0, -2.3841858e-07, -4.720688e-05, -0.0008621216, -0.47460938, -6.4373016e-06, -0.00037956238, -0.16992188, -0.008361816, -0.5859375, -0.5859375, -0.00020313263, -0.5546875, -0.15820312, -0.83203125, -0.72265625, -0.20117188, -0.13183594, -0.0007019043, -0.11669922, -0.84375, -0.10205078, -0.11621094, -0.8125, -0.3125, -0.859375, -0.029785156, -0.11669922, 0, -1.3113022e-05, -0.020996094, -0.28320312, -0.007080078, -0.5

## Implementing the environment that will be used in Reinforcement Learning

In [4]:
question = "The winner of the the London Riders Championship in 1953 scored how many points in the 1952 Individual Speedway World Championship?"
kb_result = query_knowledge_base(question)
x = {"node": question, "kb_facts": kb_result, "is_answer": False, "children": [], "prob": 1.0}

prompt = f"""
        Given the current reasoning step: {json.dumps(x['node'])}, what is the next logical step? 
        If you are confident that the next step directly answers the question, provide the final answer **only** in the following format:
        "Answer: <answer>"
        **Do not include any additional text, explanations, or reasoning. Only return the final answer in the specified format.**
        Otherwise, provide **up to 3 independent next steps** that could be explored to answer the question. Each step should be provided **only** in the following format:
        "Next step: <next reasoning step 1>"
        "Next step: <next reasoning step 2>"
        **Do not include any additional text or explanations. Only return the response in the specified format.**
        """
print("prompt for expansion: ", prompt)
result, tag = reqor.req2provider(prompt, max_tokens = None, stop = None)

prompt for expansion:  
        Given the current reasoning step: "The winner of the the London Riders Championship in 1953 scored how many points in the 1952 Individual Speedway World Championship?", what is the next logical step? 
        If you are confident that the next step directly answers the question, provide the final answer **only** in the following format:
        "Answer: <answer>"
        **Do not include any additional text, explanations, or reasoning. Only return the final answer in the specified format.**
        Otherwise, provide **up to 3 independent next steps** that could be explored to answer the question. Each step should be provided **only** in the following format:
        "Next step: <next reasoning step 1>"
        "Next step: <next reasoning step 2>"
        **Do not include any additional text or explanations. Only return the response in the specified format.**
        


In [11]:
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from gymnasium.utils.env_checker import check_env

class ReasoningTreeEnv(gym.Env):
    def __init__(self, question):
        """
        Initialize the environment with a question and a knowledge base.
        """
        super(ReasoningTreeEnv, self).__init__()

        # Define the observation and action spaces
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(250,), dtype=np.float32
        )
        self.action_space = spaces.Discrete(4)  # 4 actions: expand, prune, reorganize, terminate

        self.question = question
        self.tree = self._initialize_tree()  # Start with an initial node
        self.current_state = self.tree
        self.llm = TogetherReq()
        self.current_iteration = 1

    def _get_observation(self):
        """
        Get the current observation (state) of the environment as a fixed-size vector.
        """
        # Step 1: Flatten the current node and its children into a fixed-size vector
        observation = np.zeros(250, dtype=np.float32)  # Increase size to 250

        # Step 2: Add the current node's embedding
        current_node_embedding = self._embed_text(self.current_state["step"])
        observation[:50] = current_node_embedding  # First 50 dimensions for the current node

        # Step 3: Add the children nodes' embeddings and probabilities
        for i, child in enumerate(self.current_state["children"][:4]):  # Limit to 4 children (50 * 4 = 200)
            child_embedding = self._embed_text(child["step"])
            observation[50 + i * 50 : 50 + (i + 1) * 50] = child_embedding  # 50 dimensions per child
            observation[250 - 5 + i] = child["prob"]  # Store probabilities in the last 5 dimensions

        # Step 4: Add metadata (e.g., depth, number of children)
        observation[245] = self._compute_depth(self.current_state)  # Depth of the current node
        observation[246] = len(self.current_state["children"])  # Number of children

        return observation

    def _embed_text(self, text):
        """
        Convert text into a fixed-size embedding (e.g., using a pre-trained model).
        """
        # Example: Use a simple embedding (replace with a real embedding model)
        embedding = np.zeros(50, dtype=np.float32)
        embedding[:min(len(text), 50)] = [ord(c) for c in text[:50]]  # Truncate to 50 characters
        return embedding

    def _initialize_tree(self):
        """
        Initialize the reasoning tree with the question as the root node.
        """
        # Query the knowledge base for relevant facts
        kb_result = query_knowledge_base(self.question)
        return {"step": self.question, "kb_facts": kb_result, "is_answer": False, "children": [], "prob": 1.0}

    def step(self, action):
        """
        Execute an action and return the new observation, reward, done flag, and info.
        """
        if self._is_done():
            # If the episode is already done, return the current state and reward
            print(f"Episode is already done. No further actions can be taken.")
            return self._get_observation(), self._compute_reward(), True, False, {}
    
        if action == 0:  # expand
            print("step chosen: expanding")
            self._expand_tree()
        elif action == 1:  # prune
            print("step chosen: pruning")
            self._prune_tree()
        elif action == 2:  # reorganize
            print("step chosen: reorganizing")
            self._reorganize_tree()
        elif action == 3:  # terminate
            print(f"step chosen: terminating - reward: {self._compute_reward()}, done: True, truncated: False")
            return self._get_observation(), self._compute_reward(), True, False, {}

        reward = self._compute_reward()
        done = self._is_done()
        truncated = False  # Gymnasium requires a "truncated" flag (e.g., for time limits)
        print(f"reward: {reward}, done: {done}, truncated: {truncated}, iteration: {self.current_iteration}")
        print(f"tree now before next iteration ends: \n{json.dumps(self.tree, indent=2)}")

        self.current_iteration += 1
        return self._get_observation(), reward, done, truncated, {}
    
    def _expand_tree(self):
        """
        Expand the tree by using the LLM and knowledge base to generate a new reasoning step.
        """

        # Generate a new reasoning step using the LLM and KB result
        # Note: 
            # For now we are kind of making it as a linear chain as a proof of concept, since handling many possible steps at once
            # and moving back and forth between them is not easy to implement.
        print("tree before expansion: \n", json.dumps(self.tree, indent=2))
        prompt = f"""
        Given the current reasoning tree: {json.dumps(self.tree)}, determine the next logical step or provide the final answer if it has been found.
        If you have already found or reached the answer, provide the final answer **only** in the following format:
        "Answer: <answer>"
        **Do not include any additional text, explanations, or reasoning. Only return the final answer in the specified format.**
        Otherwise, provide the next reasoning step **only** in the following format:
        "Next step: <next reasoning step>"
        **Do not include any additional text or explanations. Only return the response in the specified format.**
        """
        print("prompt for expansion: \n", prompt)
        # prompt = f"""
        # Given the current reasoning step: {self.current_state['step']}, and the following facts: {self.current_state['kb_facts']}, what is the next logical step? 
        # If you have enough information to directly answer the question, provide the final answer **only** in the following format:
        # "Answer: <answer>"
        # Otherwise, provide the next reasoning step **only** in the following format:
        # "Next step: <next reasoning step>"
        # **Important: If the next step contains the final answer to the question, you must use the "Answer:" format.**
        # **Do not include any additional text or explanations. Only return the response in the specified format.**
        # """
        result, tag = self.llm.req2provider(prompt, max_tokens = None, stop = None)
        new_step = result[0]["message"]["content"]

        # Check if the new step contains an answer
        if new_step.lower().startswith("answer:"):
            # Extract the answer and mark the node as an answer node
            answer = new_step[len("Answer:"):].strip()
            is_answer = True
        else:
            # Continue the reasoning process
            answer = new_step
            is_answer = False

        # Add the new step to the tree
        log_probs = np.array(result[0]["logprobs"]["token_logprobs"])
        seq_prob = round(np.exp(np.mean(log_probs)), 2)  # Convert to probability
        # Query the knowledge base for relevant facts
        new_kb_result = query_knowledge_base(new_step)
        new_node = {"step": new_step, "kb_facts": new_kb_result, "children": [], "prob": seq_prob, "is_answer": is_answer}
        self._find_deepest_node(self.tree)["children"].append(new_node)
        # Update current_state to the new node
        self.current_state = new_node
        print(f"result got from the llm regarding expansion:\n{json.dumps(self.tree, indent=2)}")

    def _prune_tree(self):
        """
        Prune low-probability branches from the tree.
        """
        def normalize_log_probs(log_probs):
            """
            Normalize log probabilities to probabilities using softmax.
            """
            if len(log_probs) == 0:  # Check if the list is empty
                return np.array([])  # Return an empty array
            
            probs = np.exp(log_probs - np.max(log_probs))  # Subtract max for numerical stability
            probs = probs / np.sum(probs)  # Normalize to sum to 1
            return probs
        
        # Normalize probabilities of branches
        branch_probs = normalize_log_probs([child["prob"] for child in self.current_state["children"]])
        
        # Prune branches with low normalized probabilities (if there are any)
        if len(branch_probs) > 0:  # Only prune if there are branches to prune
            self.current_state["children"] = [
                child for child, prob in zip(self.current_state["children"], branch_probs) if prob > 0.5  # Example threshold
            ]

    def _reorganize_tree(self):
        """
        Reorganize the tree structure using the LLM.
        """
        # Notes: 
            # should we let llm remove nodes ? maybe dangerous

        # prompt = f"""
        # Given the current reasoning tree: {json.dumps(self.tree)}, reorganize it to improve the clarity and logical flow of the reasoning process. 
        # Focus on:
        # 1. Removing redundant or low-probability branches.
        # 2. Removing irrelevant or incorrect facts from the `kb_facts` of each node.
        # 3. Grouping related reasoning steps.
        # 4. Ensuring the tree leads to a clear and concise answer.
        # 5. **Do not add any new nodes or children **; only remove or reorder existing nodes. **If you add new nodes, the reorganization will be invalid.**
        # 6. **Do not modify any existing attributes** such as `step`, `is_answer`, `prob`, or `children`. Only modify `kb_facts` to improve clarity and relevance.
        # Return **only** the reorganized tree in the same JSON format given. Do not include any additional text or explanations.
        # """

        prompt = f"""
        Given the current reasoning tree: {json.dumps(self.tree)}, reorganize it to improve the clarity and logical flow of the reasoning process. 
        Focus on:
        1. Removing irrelevant or redundant facts from the `kb_facts` of each node.
        2. Rephrasing `step` for clarity **only if the meaning of the question remains unchanged**.
        3. Grouping related reasoning steps.
        4. Reordering nodes to improve logical flow.

        Constraints:
        1. **Do not add any new nodes or children**
        2. **Do not modify critical attributes** such as `is_answer`, `prob`, or `children`.
        3. **Preserve the core question**: Ensure the `step` (question) remains unchanged in meaning.

        Return **only** the reorganized tree in the same JSON format given. Do not include any additional text or explanations.
        """
        
        print("prompt for reorganization:\n", prompt)
        result, tag = self.llm.req2provider(prompt, max_tokens = None, stop = None)
        reorganized_tree_str = result[0]["message"]["content"]
        print(f"result got from the llm regarding reorganization:\n{reorganized_tree_str}")

        try:
            # Replace single quotes with double quotes for valid JSON keys
            # reorganized_tree_str = reorganized_tree_str.replace("'", '"')
            # Replace Python boolean values with JSON boolean values
            reorganized_tree_str = reorganized_tree_str.replace("True", "true").replace("False", "false")
            reorganized_tree = json.loads(reorganized_tree_str)
        except json.JSONDecodeError as e:
            print(f"Failed to parse reorganized tree: {e}. Keeping the current tree.")
            return

        # Validate the reorganized tree
        if not self._validate_tree(reorganized_tree):
            print("Invalid reorganization suggestion. Keeping the current tree.")
            return
        
        # Detect and assign probabilities to new nodes
        self._assign_probabilities_to_new_nodes(reorganized_tree, result)
        
        # Evaluate the impact of reorganization
        print("Computing reward before reorganization...")
        original_reward = self._compute_reward()
        self.tree = reorganized_tree
        self.current_state = self._find_deepest_node(self.tree)
        print("Computing reward after reorganization...")
        new_reward = self._compute_reward()

        print(f"Tree now after reorganizing ends: \n{json.dumps(self.tree, indent=2)}")

        if new_reward > original_reward:
            print("Reorganization improved the tree.")
        else:
            print("Reorganization did not improve the tree.")
    
    def _validate_tree(self, tree):
        """
        Validate the structure of the reorganized tree recursively.
        """
        # if "step" not in tree or "children" not in tree or "prob" not in tree:
            # removed prob from the check for now, since prob got from llm may not be coherent with the tree's prob, and we can add it later by mean of the probs of response
        if "children" not in tree:
            tree["children"] = []
        
        if "step" not in tree:
            return False
        if not isinstance(tree["children"], list):
            return False
        for child in tree["children"]:
            if not self._validate_tree(child):
                return False
        return True
    
    def _assign_probabilities_to_new_nodes(self, reorganized_tree, llm_result):
        """
        Detect new nodes in the reorganized tree and assign probabilities to them.
        """
        # Get the normalized probability from the LLM's response
        log_probs = np.array(llm_result[0]["logprobs"]["token_logprobs"])
        seq_prob = round(np.exp(np.mean(log_probs)), 2)  # Convert to probability

        # Traverse the reorganized tree and assign probabilities to new nodes
        def _traverse_and_assign(node):
            if "prob" not in node:
                # If the node is new (missing prob), assign the normalized probability
                node["prob"] = seq_prob
            if "children" in node:
                for child in node["children"]:
                    _traverse_and_assign(child)

        _traverse_and_assign(reorganized_tree)

    def _find_deepest_node(self, node):
        """
        Recursively find the deepest node in the tree.
        """
        if not node.get("children"):
            # If no children, return the current node
            return node
        else:
            # Recursively find the deepest node in the last child
            return self._find_deepest_node(node["children"][-1])
    
    def _compute_reward(self):
        """
        Compute the reward based on the quality of the reasoning process.
        """
        # Example: Reward for reaching a correct answer or penalize for unnecessary steps
        return 1.0 if self._is_correct() else -0.1

    def _is_correct(self):
        """
        Check if the current state leads to a correct answer.
        """
        # Step 1: Extract the answer from the tree
        answer = self._extract_answer(self.tree)

        # Step 2: Compare the extracted answer to the ground truth
        ground_truth = "14"  # Example ground truth for "What is the capital of France?"
        return ground_truth.lower() in answer.lower()

    def _extract_answer(self, node):
        """
        Extract the answer from the tree by traversing nodes marked as answers.
        """
        if node.get("is_answer", False):  # Check if the node is marked as an answer
            return node["step"]

        # Recursively traverse the tree to find the answer
        for child in node["children"]:
            answer = self._extract_answer(child)
            if answer:
                print(f"Answer found: {answer}")
                return answer

        return ""  # No answer found


    def _compute_depth(self, node):
        """
        Compute the depth of the tree starting from the given node.
        """
        if not node["children"]:
            return 1
        return 1 + max(self._compute_depth(child) for child in node["children"])

    def _is_done(self):
        """
        Check if the reasoning process is complete.
        """
        # Example: Terminate if the tree reaches a certain depth or a correct answer is found
        max_depth = 10  # Maximum depth of the tree
        # max_iterations = 5  # Maximum Iterations to find the answer since depth can increase and decrease
        current_depth = self._compute_depth(self.tree)

        # Check termination conditions
        if current_depth >= max_depth:
            print("Maximum depth reached!")
            return True
        # if self.current_iteration >= max_iterations:
        #     print("Maximum iterations reached!")
        #     return True
        if self._is_correct():
            print("Correct answer found!")
            return True
        if not self.current_state["children"] and self.current_state.get("is_answer", False):
            # Terminate only if it's a leaf node AND contains the correct answer
            print("Answer found at a leaf node!")
            return True

        return False

    def reset(self, seed=None, options=None):
        """
        Reset the environment to the initial state and return the initial observation.
        """
        print("resetting the env")
        super().reset(seed=seed)  # Required by Gymnasium
        self.tree = self._initialize_tree()
        self.current_state = self.tree
        return self._get_observation(), {}

In [12]:
from gymnasium.utils.env_checker import check_env  # For Gymnasium
# from gym.utils.env_checker import check_env  # For OpenAI Gym

# Create an instance of the base environment
# env = ReasoningTreeEnv(question="What is the capital of France?")
env = ReasoningTreeEnv(question="The winner of the the London Riders Championship in 1953 scored how many points in the 1952 Individual Speedway World Championship?")

# Validate the environment using the environment checker
check_env(env)

  logger.warn(
  logger.warn(


resetting the env
resetting the env
resetting the env
resetting the env
resetting the env
resetting the env
resetting the env
resetting the env
step chosen: expanding
tree before expansion: 
 {
  "step": "The winner of the the London Riders Championship in 1953 scored how many points in the 1952 Individual Speedway World Championship?",
  "kb_facts": [
    "There are annual world championship events in the sport of motorcycle speedway for individual riders - the Speedway Grand Prix - and for national teams - the Speedway World Cup. Each has a counterpart for riders under 21: the Speedway World Under 21 Championship and the Team Speedway Junior World Championship. A pairs event, the Speedway World Pairs Championship, ran until 1993.",
    "Jack Ellis Young (31 January 1925 in Adelaide, South Australia \u2013 28 August 1987 in Adelaide) was a Motorcycle speedway rider who won the Speedway World Championship in 1951 and 1952. He also won the London Riders' Championship 1953 and 1954 and w

  logger.warn(


In [13]:
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement
from stable_baselines3.common.monitor import Monitor
import os

# Define the environment
def make_env(question):
    """
    Create and wrap the environment with Monitor for logging.
    """
    env = ReasoningTreeEnv(question)
    env = Monitor(env)  # Wrap the environment with Monitor for logging
    return env

# Create the vectorized environment
question = "The winner of the the London Riders Championship in 1953 scored how many points in the 1952 Individual Speedway World Championship?"
env = DummyVecEnv([lambda: make_env(question)])

# Define the PPO model
model = PPO(
    "MlpPolicy",  # Policy network (MLP for vector observations)
    env,          # Environment
    verbose=1,    # Print training logs
    tensorboard_log="./ppo_reasoning_tree_tensorboard/",  # TensorBoard logging
    learning_rate=3e-4,  # Learning rate
    n_steps=2048,        # Number of steps per update
    batch_size=64,       # Batch size
    n_epochs=1,         # Number of epochs per update
    gamma=0.99,          # Discount factor
    gae_lambda=0.95,     # GAE (Generalized Advantage Estimation) lambda
    clip_range=0.2,      # PPO clip range
    ent_coef=0.01,       # Entropy coefficient (encourages exploration)
)

early_stopping_callback = StopTrainingOnNoModelImprovement(
    max_no_improvement_evals=10,  # Stop if no improvement after 10 evaluations
    min_evals=5,  # Minimum number of evaluations before stopping
    verbose=1,
)

# Create a callback for evaluation
eval_callback = EvalCallback(
    env,  # Environment to evaluate on
    best_model_save_path="./ppo_reasoning_tree_best_model/",  # Save the best model
    log_path="./ppo_reasoning_tree_eval_logs/",  # Log evaluation results
    eval_freq=1000,  # Evaluate every 1000 steps
    deterministic=True,  # Use deterministic actions for evaluation
    render=False,  # Do not render the environment during evaluation
    callback_on_new_best=early_stopping_callback,  # Use early stopping
)

# Train the model
total_timesteps = 20  # Total number of training steps
model.learn(
    total_timesteps=total_timesteps,
    callback=eval_callback,  # Use the evaluation callback
    tb_log_name="ppo_reasoning_tree",  # TensorBoard log name
)

# Save the final model
model.save("ppo_reasoning_tree_final_model")

# Close the environment
env.close()

Using cpu device
resetting the env
Logging to ./ppo_reasoning_tree_tensorboard/ppo_reasoning_tree_43
step chosen: expanding
tree before expansion: 
 {
  "step": "The winner of the the London Riders Championship in 1953 scored how many points in the 1952 Individual Speedway World Championship?",
  "kb_facts": [
    "There are annual world championship events in the sport of motorcycle speedway for individual riders - the Speedway Grand Prix - and for national teams - the Speedway World Cup. Each has a counterpart for riders under 21: the Speedway World Under 21 Championship and the Team Speedway Junior World Championship. A pairs event, the Speedway World Pairs Championship, ran until 1993.",
    "Jack Ellis Young (31 January 1925 in Adelaide, South Australia \u2013 28 August 1987 in Adelaide) was a Motorcycle speedway rider who won the Speedway World Championship in 1951 and 1952. He also won the London Riders' Championship 1953 and 1954 and was a nine time South Australian Champion be

KeyboardInterrupt: 

In [None]:
for episode in range(1000):  # Number of episodes
    state = env.reset()
    done = False
    total_reward = 0

    while not done:
        action, _ = model.predict(state)  # Select action using the trained policy
        state, reward, done = env.step(action)
        total_reward += reward

    print(f"Episode {episode + 1}, Total Reward: {total_reward}")

In [None]:
test_env = DummyVecEnv([lambda: ReasoningTreeEnv(question="What is the capital of Germany?")])
state = test_env.reset()
done = False

while not done:
    action, _ = model.predict(state)
    state, reward, done = test_env.step(action)

print(f"Final State: {state}, Reward: {reward}")

## Trying array of thoughts for easier

In [70]:
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from gymnasium.utils.env_checker import check_env

class ReasoningArrayEnv(gym.Env):
    def __init__(self, question):
        """
        Initialize the environment with a question and an empty array of thoughts.
        """
        super(ReasoningArrayEnv, self).__init__()

        # Define the observation and action spaces
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(250,), dtype=np.float32
        )
        self.action_space = spaces.Discrete(4)  # 4 actions: expand, prune, reorganize, terminate

        self.question = question
        self.thoughts = self._initialize_array()  # Array of structured thoughts
        print(self.thoughts)
        self.llm = TogetherReq()
        self.current_iteration = 1

    def _initialize_array(self):
        """
        Initialize the reasoning aree with the question as the first thought.
        """
        # Query the knowledge base for relevant facts
        kb_result = query_knowledge_base(self.question)
        return [{"step": self.question, "kb_facts": kb_result, "prob": 1.0}]

    def _get_observation(self):
        """
        Get the current observation (state) of the environment as a fixed-size vector.
        """
        # Step 1: Initialize the observation vector
        observation = np.zeros(250, dtype=np.float32)

        # Step 2: Embed each thought and concatenate into the observation vector
        for i, thought in enumerate(self.thoughts[:5]):  # Limit to 5 thoughts (50 * 5 = 250)
            thought_embedding = self._embed_thought(thought)
            observation[i * 50 : (i + 1) * 50] = thought_embedding

        return observation

    def _embed_thought(self, thought):
        """
        Convert a thought (dictionary) into a fixed-size embedding.
        """
        # Example: Use a simple embedding (replace with a real embedding model)
        embedding = np.zeros(50, dtype=np.float32)

        # Embed the step text
        step_embedding = self._embed_text(thought["step"])
        embedding[:25] = step_embedding[:25]  # First 25 dimensions for the step

        # Embed the kb_facts (concatenate all facts into a single string)
        kb_facts_str = " ".join(thought["kb_facts"])
        kb_facts_embedding = self._embed_text(kb_facts_str)
        embedding[25:50] = kb_facts_embedding[:25]  # Next 25 dimensions for the kb_facts

        return embedding

    def _embed_text(self, text):
        """
        Convert text into a fixed-size embedding (e.g., using a pre-trained model).
        """
        # Example: Use a simple embedding (replace with a real embedding model)
        embedding = np.zeros(25, dtype=np.float32)
        embedding[:min(len(text), 25)] = [ord(c) for c in text[:25]]  # Truncate to 25 characters
        return embedding

    def step(self, action):
        """
        Execute an action and return the new observation, reward, done flag, and info.
        """
        if self._is_done():
            # If the episode is already done, return the current state and reward
            print(f"Episode is already done. No further actions can be taken.")
            return self._get_observation(), self._compute_reward(), True, False, {}

        if action == 0:  # expand
            print("step chosen: expanding")
            self._expand_array()
        elif action == 1:  # prune
            print("step chosen: pruning")
            self._prune_array()
        elif action == 2:  # reorganize
            print("step chosen: reorganizing")
            self._reorganize_array()
        elif action == 3:  # terminate
            print(f"step chosen: terminating - reward: {self._compute_reward()}, done: True, truncated: False")
            return self._get_observation(), self._compute_reward(), True, False, {}

        reward = self._compute_reward()
        done = self._is_done()
        truncated = False  # Gymnasium requires a "truncated" flag (e.g., for time limits)
        print(f"reward: {reward}, done: {done}, truncated: {truncated}, iteration: {self.current_iteration}")
        print(f"thoughts now before next iteration ends: \n{self.thoughts}")

        self.current_iteration += 1
        return self._get_observation(), reward, done, truncated, {}

    def _expand_array(self):
        """
        Expand the array of thoughts by generating a new reasoning step.
        """
        # Generate a new reasoning step using the LLM
        # prompt = f"""
        # Given the current reasoning steps: {[thought['step'] for thought in self.thoughts]}, determine the next logical step or provide the final answer if it has been found.
        
        # If you have already found or reached the answer to the original question in first thought, provide the final answer **only** in the following format:
        # "Answer: <answer>"
        # **Do not include any additional text, explanations, or reasoning. Only return the final answer in the specified format.**
        
        # Otherwise, provide the next reasoning step **only** in the following format:
        # "Next step: <next reasoning step>"
        # **Do not include any additional text or explanations. Only return the response in the specified format.**
        # """
        prompt = f"""
        Given the original question: "{self.question}", and the current reasoning steps: {[thought['step'] for thought in self.thoughts]}, determine if the answer is already present in the reasoning steps.

        If the answer is present, provide it **only** in the following format:
        "Answer: <answer>"
        **Do not include any additional text, explanations, or reasoning. Only return the final answer in the specified format.**

        If the answer is not present, provide the next logical reasoning step **only** in the following format:
        "Next step: <next reasoning step>"
        **Do not include any additional text or explanations. Only return the response in the specified format.**
        """

        print("prompt for expansion: \n", prompt)
        result, tag = self.llm.req2provider(prompt, max_tokens=None, stop=None)
        new_step = result[0]["message"]["content"]

         # Add the new step to the tree
        log_probs = np.array(result[0]["logprobs"]["token_logprobs"])
        seq_prob = round(np.exp(np.mean(log_probs)), 2)  # Convert to probability

        # Check if the new step contains an answer
        if new_step.lower().startswith("answer:"):
            # Extract the answer and mark it as the final step
            answer = new_step[len("Answer:"):].strip()
            new_thought = {
                "step": f"Answer: {answer}",
                "prob": seq_prob,
                "kb_facts": []  # No additional facts needed for the answer
            }
        else:
            # Add the new reasoning step to the array
            new_thought = {
                "step": new_step,
                "prob": seq_prob,
                "kb_facts": query_knowledge_base(new_step)  # Query KB for relevant facts
            }

        self.thoughts.append(new_thought)
        print(f"Array now after expansion ends: \n{json.dumps(self.thoughts, indent=2)}")


    def _prune_array(self):
        """
        Prune low-probability or irrelevant steps from the array.
        """
        # Example: Remove steps with low probability
        if len(self.thoughts) > 0:
            self.thoughts = [thought for thought in self.thoughts if thought["prob"] > 0.5]  # Example threshold

    def _reorganize_array(self):
        """
        Reorganize the array of thoughts using the LLM.
        """
        # Convert the array to a JSON string with double quotes
        prompt = f"""
        Given the original question: "{self.question}", and the current reasoning steps: {json.dumps(self.thoughts, ensure_ascii=False)}, reorganize them to improve the clarity and logical flow of the reasoning process. 
        Focus on:
        1. Removing irrelevant or redundant facts from the `kb_facts` of each step.
        2. Correcting the `step` if it contradicts the `kb_facts` by rephrasing it to align with the facts.
        3. Rephrasing `step` for clarity **only if the meaning remains unchanged**.
        4. Grouping related reasoning steps.
        5. Reordering steps to improve logical flow.

        Constraints:
        1. **Do not add any new steps**; only remove or reorder existing steps.
        2. **Do not modify critical attributes** such as `prob` or `kb_facts` unless necessary for clarity.
        3. **Preserve the core question**: Ensure the `step` (question) remains unchanged in meaning.

        Return **only** the reorganized array of steps in the same JSON format given. Do not include any additional text or explanations.
        """
        
        print("prompt for reorganization:\n", prompt)
        result, tag = self.llm.req2provider(prompt, max_tokens=None, stop=None)
        reorganized_array_str = result[0]["message"]["content"]
        print(f"result got from the llm regarding reorganization:\n{reorganized_array_str}")

        try:
            # Replace single quotes with double quotes for valid JSON keys
            # reorganized_array_str = reorganized_array_str.replace("'", '"')
            # Replace Python boolean values with JSON boolean values
            reorganized_array_str = reorganized_array_str.replace("True", "true").replace("False", "false")
            reorganized_array = json.loads(reorganized_array_str)
        except json.JSONDecodeError as e:
            print(f"Failed to parse reorganized array: {e}. Keeping the current array.")
            return

        # Validate the reorganized array
        if not self._validate_array(reorganized_array):
            print("Invalid reorganization suggestion. Keeping the current array.")
            return
        
        # Detect and assign probabilities to new steps
        self._assign_probabilities_to_new_steps(reorganized_array, result)
        
        # Evaluate the impact of reorganization
        print("Computing reward before reorganization...")
        original_reward = self._compute_reward()
        self.thoughts = reorganized_array
        print("Computing reward after reorganization...")
        new_reward = self._compute_reward()

        print(f"Array now after reorganizing ends: \n{json.dumps(self.thoughts, indent=2)}")

        if new_reward > original_reward:
            print("Reorganization improved the array.")
        else:
            print("Reorganization did not improve the array.")

    def _validate_array(self, array):
        """
        Validate the structure of the reorganized array.
        """
        if not isinstance(array, list):
            return False
        for step in array:
            if not isinstance(step, dict):
                return False
            if "step" not in step or "prob" not in step or "kb_facts" not in step:
                return False
        return True

    def _assign_probabilities_to_new_steps(self, reorganized_array, llm_result):
        """
        Assign probabilities to new steps in the reorganized array.
        """
        # Get the normalized probability from the LLM's response
        log_probs = np.array(llm_result[0]["logprobs"]["token_logprobs"])
        seq_prob = round(np.exp(np.mean(log_probs)), 2)  # Convert to probability

        # Assign probabilities to new steps
        for step in reorganized_array:
            if "prob" not in step:
                step["prob"] = seq_prob

    def _compute_reward(self):
        """
        Compute the reward based on the quality of the reasoning process.
        """
        # Example: Reward for reaching a correct answer or penalize for unnecessary steps
        return 1.0 if self._is_correct() else -0.1
    
    def _is_correct(self):
        """
        Check if the current array of thoughts contains the correct answer.
        """
        # Step 1: Extract the answer from the array
        answer = self._extract_answer()

        # Step 2: Compare the extracted answer to the ground truth
        ground_truth = "14"  # Example ground truth for "What is the capital of France?"
        return ground_truth.lower() in answer.lower()

    def _extract_answer(self):
        """
        Extract the answer from the array of thoughts.
        """
        for thought in self.thoughts:
            if thought["step"].lower().startswith("answer:"):
                return thought["step"][len("Answer:"):].strip()
        return ""  # No answer found

    def _is_done(self):
        """
        Check if the reasoning process is complete.
        """
        # Example: Terminate if the array contains the correct answer or reaches a maximum length
        max_steps = 10  # Maximum number of steps
        if len(self.thoughts) >= max_steps:
            print("Maximum steps reached!")
            return True
        if self._is_correct():
            print("Correct answer found!")
            return True
        return False

    def reset(self, seed=None, options=None):
        """
        Reset the environment to the initial state and return the initial observation.
        """
        print("resetting the env")
        super().reset(seed=seed)  # Required by Gymnasium
        self.thoughts = self._initialize_array()  # Reset the array of thoughts
        return self._get_observation(), {}

In [71]:
from gymnasium.utils.env_checker import check_env  # For Gymnasium
# from gym.utils.env_checker import check_env  # For OpenAI Gym

# Create an instance of the base environment
# env = ReasoningTreeEnv(question="What is the capital of France?")
env = ReasoningArrayEnv(question="The winner of the the London Riders Championship in 1953 scored how many points in the 1952 Individual Speedway World Championship?")

# Validate the environment using the environment checker
check_env(env)

[{'step': 'The winner of the the London Riders Championship in 1953 scored how many points in the 1952 Individual Speedway World Championship?', 'kb_facts': ['There are annual world championship events in the sport of motorcycle speedway for individual riders - the Speedway Grand Prix - and for national teams - the Speedway World Cup. Each has a counterpart for riders under 21: the Speedway World Under 21 Championship and the Team Speedway Junior World Championship. A pairs event, the Speedway World Pairs Championship, ran until 1993.', "Jack Ellis Young (31 January 1925 in Adelaide, South Australia – 28 August 1987 in Adelaide) was a Motorcycle speedway rider who won the Speedway World Championship in 1951 and 1952. He also won the London Riders' Championship 1953 and 1954 and was a nine time South Australian Champion between 1948 and 1964."], 'prob': 1.0}]
resetting the env
resetting the env
resetting the env
resetting the env
resetting the env
resetting the env
resetting the env
res

In [72]:

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement
from stable_baselines3.common.monitor import Monitor
import os

# Define the environment
def make_env(question):
    """
    Create and wrap the environment with Monitor for logging.
    """
    env = ReasoningArrayEnv(question)
    env = Monitor(env)  # Wrap the environment with Monitor for logging
    return env

# Create the vectorized environment
question = "The winner of the the London Riders Championship in 1953 scored how many points in the 1952 Individual Speedway World Championship?"
env = DummyVecEnv([lambda: make_env(question)])

# Define the PPO model
model = PPO(
    "MlpPolicy",  # Policy network (MLP for vector observations)
    env,          # Environment
    verbose=1,    # Print training logs
    tensorboard_log="./ppo_reasoning_tree_tensorboard/",  # TensorBoard logging
    learning_rate=3e-4,  # Learning rate
    n_steps=2048,        # Number of steps per update
    batch_size=64,       # Batch size
    n_epochs=1,         # Number of epochs per update
    gamma=0.99,          # Discount factor
    gae_lambda=0.95,     # GAE (Generalized Advantage Estimation) lambda
    clip_range=0.2,      # PPO clip range
    ent_coef=0.01,       # Entropy coefficient (encourages exploration)
)

early_stopping_callback = StopTrainingOnNoModelImprovement(
    max_no_improvement_evals=10,  # Stop if no improvement after 10 evaluations
    min_evals=5,  # Minimum number of evaluations before stopping
    verbose=1,
)

# Create a callback for evaluation
eval_callback = EvalCallback(
    env,  # Environment to evaluate on
    best_model_save_path="./ppo_reasoning_tree_best_model/",  # Save the best model
    log_path="./ppo_reasoning_tree_eval_logs/",  # Log evaluation results
    eval_freq=1000,  # Evaluate every 1000 steps
    deterministic=True,  # Use deterministic actions for evaluation
    render=False,  # Do not render the environment during evaluation
    callback_on_new_best=early_stopping_callback,  # Use early stopping
)

# Train the model
total_timesteps = 20  # Total number of training steps
model.learn(
    total_timesteps=total_timesteps,
    callback=eval_callback,  # Use the evaluation callback
    tb_log_name="ppo_reasoning_tree",  # TensorBoard log name
)

# Save the final model
model.save("ppo_reasoning_tree_final_model")

# Close the environment
env.close()

[{'step': 'The winner of the the London Riders Championship in 1953 scored how many points in the 1952 Individual Speedway World Championship?', 'kb_facts': ['There are annual world championship events in the sport of motorcycle speedway for individual riders - the Speedway Grand Prix - and for national teams - the Speedway World Cup. Each has a counterpart for riders under 21: the Speedway World Under 21 Championship and the Team Speedway Junior World Championship. A pairs event, the Speedway World Pairs Championship, ran until 1993.', "Jack Ellis Young (31 January 1925 in Adelaide, South Australia – 28 August 1987 in Adelaide) was a Motorcycle speedway rider who won the Speedway World Championship in 1951 and 1952. He also won the London Riders' Championship 1953 and 1954 and was a nine time South Australian Champion between 1948 and 1964."], 'prob': 1.0}]
Using cpu device
resetting the env
Logging to ./ppo_reasoning_tree_tensorboard/ppo_reasoning_tree_56
step chosen: pruning
reward: