In [1]:
import os

os.environ['HF_HOME'] = 'E:/Models/hf_cache'
os.environ['HUGGINGFACE_HUB_CACHE'] = 'E:/Models/hf_cache'

In [2]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
# ----------------------------------------------------------------------------
# 1. Imports and Setup
# ----------------------------------------------------------------------------
import os
import json
import re
import itertools
import numpy as np
import argparse
from functools import partial
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from datasets import load_dataset

# Global model and tokenizer variables
model = None
tokenizer = None
model_name = "google/gemma-2-2b-it"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
# ----------------------------------------------------------------------------
# 2. Model Interaction Logic for Gemma
# ----------------------------------------------------------------------------

def load_model():
    """Loads the Gemma model and tokenizer into global variables."""
    global model, tokenizer, device
    print(f"Using model: {model_name} on device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=nf4_config,
        device_map="auto"
    )
    print(f"Model '{model_name}' loaded successfully.")

def gemma_call(prompt: str, n: int = 1, stop: list = None, temperature: float = 0.7):
    """
    A wrapper function to call the Gemma model.
    """
    if model is None:
        load_model()
        
    outputs = []
    # Set max_new_tokens for a single reasoning step
    max_new_tokens = 512
    
    stop_sequences = stop if stop else [tokenizer.eos_token]

    for _ in range(n):
        chat = [{"role": "user", "content": prompt}]
        formatted_prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer.encode(formatted_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
        
        # Ensure do_sample is True for temperature to have an effect
        do_sample_flag = True if temperature > 0 else False
        
        with torch.no_grad():
            output_ids = model.generate(
                inputs,
                max_new_tokens=max_new_tokens,
                temperature=max(temperature, 0.01), # Temp must be > 0 for sampling
                do_sample=do_sample_flag,
                pad_token_id=tokenizer.eos_token_id
            )
        
        response = tokenizer.decode(output_ids[0][len(inputs[0]):], skip_special_tokens=True)
        
        # Manually truncate at the first stop sequence found
        for seq in stop_sequences:
            if seq in response:
                response = response.split(seq)[0]
        outputs.append(response.strip())
        
    return outputs

# The global `gpt` variable used by the core algorithm
gpt = gemma_call

In [6]:
# ----------------------------------------------------------------------------
# 3. Task Definition for AQuA
# ----------------------------------------------------------------------------

class AQuATask:
    """
    Defines the AQuA task, including prompt formatting, data loading, and output testing.
    """
    def __init__(self, file_path='aqua_train.jsonl'):
        self.data = self.load_data(file_path)
        self.value_cache = {}
        # Define stop sequences for each reasoning step if necessary
        self.steps = 4 # A typical problem might take 3-4 main steps
        self.stops = ['\n'] * self.steps

        # Prompts adapted for AQuA
        self.cot_prompt = '''Solve the following algebraic word problem by thinking step-by-step.

Question: A zoo has 37 ostriches and 12 giraffes. The zoo curator buys some more ostriches, and now the ratio of ostriches to giraffes is 7:2. How many ostriches did the curator buy?
Options: (A) 12, (B) 15, (C) 20, (D) 45, (E) 7
Rationale:
Let 'o' be the initial number of ostriches and 'g' be the number of giraffes. o = 37, g = 12.
Let 'x' be the number of new ostriches bought.
The new number of ostriches is 37 + x.
The new ratio of ostriches to giraffes is (37 + x) / 12.
This ratio is given as 7:2, which is 7/2.
So, (37 + x) / 12 = 7/2.
To solve for x, multiply both sides by 12: 37 + x = (7/2) * 12.
37 + x = 7 * 6 = 42.
x = 42 - 37 = 5. The problem asks for the number of ostriches bought, which is x. Wait, I made a mistake in my calculation.
Let me re-check: 37 + x = 42. x = 42 - 37 = 5.  Checking the options, 5 is not there. Let me re-read. Ah, the question might have a typo and meant 5, but let's re-calculate. (37 + x) / 12 = 7/2.  37+x = 42. x=5.  Let's assume there is a mistake in the options and choose the closest one or re-evaluate.
Let's re-read the problem carefully. Okay, maybe I made a mistake somewhere.
Ratio is 7:2. Giraffes are 12. To maintain the ratio, the number of ostriches must be a number 'O' such that O/12 = 7/2.
O = (7/2) * 12 = 7 * 6 = 42.
The new total number of ostriches is 42.
The initial number was 37.
The number of ostriches bought is 42 - 37 = 5.
Wait, 5 is still not in the options. Let me check the problem statement again. Perhaps there is a misunderstanding.
Let's assume the question meant the number of giraffes was different. No, that's fixed.
Let's check my math again. (37+x)/12 = 7/2 -> 2(37+x) = 7*12 -> 74 + 2x = 84 -> 2x = 10 -> x=5. Still 5.
Okay, let's look at the options. Maybe one of the options leads to the ratio.
(A) 37+12=49. 49/12 is not 7/2.
(B) 37+15=52. 52/12 is not 7/2.
(C) 37+20=57. 57/12 is not 7/2.
(E) 37+7=44. 44/12 is not 7/2.
There seems to be an issue with this problem's options. Let me assume my calculation is correct and the intended answer was 5. I will pick the closest option, E=7.  *This example is intentionally difficult to show recovery*.
Let's re-try with a clean problem.
Question: John is 20 years older than his son. In 10 years, John will be twice as old as his son. How old is his son now?
Options: (A) 5, (B) 10, (C) 15, (D) 20, (E) 25
Rationale:
Let John's current age be J and his son's current age be S.
From the problem, J = S + 20.
In 10 years, John's age will be J + 10 and his son's age will be S + 10.
In 10 years, John will be twice as old as his son: J + 10 = 2 * (S + 10).
Now we have a system of two equations. Substitute the first equation into the second one.
(S + 20) + 10 = 2 * (S + 10).
S + 30 = 2S + 20.
Subtract S from both sides: 30 = S + 20.
Subtract 20 from both sides: S = 10.
The son's current age is 10. This matches option B.
Answer: The final answer is B.

Question: {question}
Options: {options}
Rationale:'''

        self.propose_prompt = '''Given a math problem and a partial solution, propose several valid next steps to continue solving the problem.

Question: A store sells shirts for $15 and pants for $25. John bought 3 shirts and some pants. If he paid a total of $120, how many pairs of pants did he buy?
Options: (A) 1, (B) 2, (C) 3, (D) 4, (E) 5
Partial Rationale:
1. Let 's' be the number of shirts and 'p' be the number of pants.
2. We are given s = 3.
3. The cost of one shirt is $15 and one pair of pants is $25.

Possible next steps:
- Calculate the total cost of the shirts: Cost_shirts = 3 * $15 = $45.
- Set up the total cost equation: (3 * 15) + (p * 25) = 120.
- Find the remaining money spent on pants: Money_for_pants = 120 - (3 * 15).
- Define the variables for the costs: Let C_s = 15 and C_p = 25.

Question: {question}
Options: {options}
Partial Rationale:
{partial_solution}

Possible next steps:
'''
        self.value_prompt = '''Evaluate the provided partial rationale for the given math problem. Is the reasoning so far on a correct and promising path to the solution? Respond with only one word: sure, likely, or impossible.

Question: A train travels at 60 km/h. How far does it travel in 3.5 hours?
Partial Rationale:
1. Distance = Speed / Time.
2. So, Distance = 60 / 3.5.
Evaluation: impossible

Reason: The formula used is incorrect. Distance is Speed * Time. The reasoning has already made a fatal error.
---
Question: The sum of two numbers is 25. Their difference is 5. What are the numbers?
Partial Rationale:
1. Let the two numbers be x and y.
2. x + y = 25.
Evaluation: likely

Reason: This is a correct first step, but not yet a full solution. It's on a promising path but could still go wrong.
---
Question: A farmer has 150 eggs. He sells them in cartons of 12. How many full cartons can he sell?
Partial Rationale:
1. Total eggs = 150.
2. Eggs per carton = 12.
3. Number of cartons = Total eggs / Eggs per carton.
4. Number of cartons = 150 / 12.
Evaluation: sure

Reason: The reasoning is direct, correct, and complete. The next step is a simple calculation that will yield the final answer.
---
Question: {question}
Partial Rationale:
{partial_solution}
Evaluation:'''
    
    def load_data(self, file_path):
        if not os.path.exists(file_path):
            print(f"Error: Data file not found at {file_path}")
            print("Please create a dummy 'aqua_train.jsonl' file to run this script.")
            return []
        with open(file_path, 'r') as f:
            return [json.loads(line) for line in f]

    def __len__(self):
        return len(self.data)

    def get_input(self, idx: int) -> str:
        """Returns the formatted question and options."""
        item = self.data[idx]
        question = item['question']
        options = '\n'.join(item['options'])
        return f"Question: {question}\nOptions: {options}"

    def test_output(self, idx: int, output: str):
        """Tests if the output contains the correct letter answer."""
        correct_answer = self.data[idx]['correct']
        # Find the last capital letter (A-E) mentioned in the output
        match = re.findall(r'[A-E]', output)
        if match:
            predicted_answer = match[-1]
            return predicted_answer == correct_answer
        return False

    def cot_prompt_wrap(self, x: str, y: str = '') -> str:
        question, options = x.split('Options: ')
        question_text = question.replace('Question: ', '').strip()
        options_text = options.strip()
        return self.cot_prompt.format(question=question_text, options=options_text) + y

    def propose_prompt_wrap(self, x: str, y: str) -> str:
        # Split the input into question and options parts
        question_part, options_part = x.split('Options: ')
        
        # Clean up the text
        question_text = question_part.replace('Question: ', '').strip()
        options_text = options_part.strip() # Capture the options
        
        # Format the partial solution nicely
        partial_solution_str = '\n'.join(f'{i+1}. {line}' for i, line in enumerate(y.strip().split('\n')) if line)
        
        # **THE FIX**: Add `options=options_text` to the .format() call
        return self.propose_prompt.format(
            question=question_text, 
            options=options_text, 
            partial_solution=partial_solution_str
        )

    def value_prompt_wrap(self, x: str, y: str) -> str:
        question, _ = x.split('Options: ')
        question_text = question.replace('Question: ', '').strip()
        return self.value_prompt.format(question=question_text, partial_solution=y)

    def value_outputs_unwrap(self, x: str, y: str, outputs: list) -> float:
        """Parses 'sure'/'likely'/'impossible' into a numerical value."""
        if not outputs: return 0.0
        output = outputs[0].strip().lower()
        if 'sure' in output:
            return 1.0
        elif 'likely' in output:
            return 0.5
        elif 'impossible' in output:
            return 0.0
        return 0.1 # Default for malformed output

In [7]:
# ----------------------------------------------------------------------------
# 4. Core Tree of Thought (ToT) Algorithm
# ----------------------------------------------------------------------------

def get_value(task, x, y, n_evaluate_sample, cache_value=True):
    value_prompt = task.value_prompt_wrap(x, y)
    if cache_value and value_prompt in task.value_cache:
        return task.value_cache[value_prompt]
    value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)
    value = task.value_outputs_unwrap(x, y, value_outputs)
    if cache_value:
        task.value_cache[value_prompt] = value
    return value

def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
    values = []
    local_value_cache = {}
    for y in ys:  # each partial output
        if y in local_value_cache:  # avoid duplicate candidates
            value = 0
        else:
            value = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value)
            local_value_cache[y] = value
        values.append(value)
    return values

def get_proposals(task, x, y):
    propose_prompt = task.propose_prompt_wrap(x, y)
    proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\n')
    # Append each proposal as a new step to the current thought process
    return [y + prop + '\n' for prop in proposals if prop]

def get_samples(task, x, y, n_generate_sample, stop):
    prompt = task.cot_prompt_wrap(x, y)
    samples = gpt(prompt, n=n_generate_sample, stop=stop)
    return [y + _ for _ in samples]

def solve(args, task, idx, to_print=True):
    global gpt
    gpt = partial(gemma_call, temperature=args.temperature)
    x = task.get_input(idx)
    ys = ['']  # current output candidates, start with empty string
    infos = []
    
    for step in range(task.steps):
        # Generation
        if args.method_generate == 'propose':
            new_ys = [get_proposals(task, x, y) for y in ys]
        else: # 'sample'
            new_ys = [get_samples(task, x, y, args.n_generate_sample, stop=task.stops[step]) for y in ys]

        new_ys = list(itertools.chain(*new_ys))
        ids = list(range(len(new_ys)))
        if not ids: break # stop if no new thoughts are generated
        
        # Evaluation
        values = get_values(task, x, new_ys, args.n_evaluate_sample)
        
        # Selection
        if args.method_select == 'greedy':
            select_ids = sorted(ids, key=lambda i: values[i], reverse=True)[:args.n_select_sample]
        elif args.method_select == 'sample':
            ps = np.array(values) / sum(values) if sum(values) > 0 else np.ones(len(values)) / len(values)
            select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
        
        select_new_ys = [new_ys[select_id] for select_id in select_ids]

        # Log
        if to_print:
            sorted_new_ys, sorted_values = zip(*sorted(zip(new_ys, values), key=lambda x: x[1], reverse=True))
            print(f'-- Step {step+1} --\n'
                  f'-- Top Valued Thoughts --: \n{sorted_new_ys[0] if sorted_new_ys else "None"}\n'
                  f'-- Values --: {[round(v, 2) for v in sorted_values]}\n'
                  f'-- Selected Thoughts --: {select_new_ys}\n' + '-'*20)

        ys = select_new_ys

    if to_print:
        print(f"Final candidate solutions: {ys}")
    return ys, {'steps': infos}

In [8]:
# ----------------------------------------------------------------------------
# 5. Main Execution Block
# ----------------------------------------------------------------------------

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='aqua', help='Task to run (currently only "aqua" is supported)')
    parser.add_argument('--task_start_index', type=int, default=0)
    parser.add_argument('--task_end_index', type=int, default=1)
    parser.add_argument('--data_path', type=str, default='aqua_train.jsonl')
    parser.add_argument('--backend', type=str, default=model_name)
    parser.add_argument('--temperature', type=float, default=0.7)
    
    parser.add_argument('--method_generate', type=str, default='propose', choices=['propose', 'sample'])
    parser.add_argument('--method_evaluate', type=str, default='value', choices=['value', 'vote'])
    parser.add_argument('--method_select', type=str, default='greedy', choices=['greedy', 'sample'])
    
    parser.add_argument('--n_generate_sample', type=int, default=1, help='Number of samples for sample-based generation')
    parser.add_argument('--n_evaluate_sample', type=int, default=3, help='Number of samples for value evaluation')
    parser.add_argument('--n_select_sample', type=int, default=5, help='Beam width for selection')
    args = parser.parse_args()
    
    print(f"Running experiment with args: {args}")

    if args.task == 'aqua':
        task = AQuATask(args.data_path)
    else:
        raise ValueError(f"Unknown task: {args.task}")

    if not task.data:
        print("Exiting due to empty dataset.")
        return

    correct_count = 0
    total_count = 0

    for i in range(args.task_start_index, min(args.task_end_index, len(task))):
        print(f"\n\n----- Solving Problem #{i} -----")
        print(task.get_input(i))
        print("---------------------------------")
        
        # Run the ToT solver
        final_ys, _ = solve(args, task, i)
        
        # Check the output
        is_correct = False
        if final_ys:
            # Use the first (highest-valued) candidate solution for evaluation
            best_solution = final_ys[0]
            is_correct = task.test_output(i, best_solution)
            print(f"\n----- Final Answer for Problem #{i} -----")
            print(best_solution)
            print("---------------------------------------")
            print(f"Correctness: {'CORRECT' if is_correct else 'INCORRECT'}")
        else:
            print("No solution was generated.")

        if is_correct:
            correct_count += 1
        total_count += 1
        
        # Calculate and print running accuracy
        accuracy = (correct_count / total_count) * 100
        print(f"Running Accuracy: {accuracy:.2f}% ({correct_count}/{total_count})")

    print("\n\n----- Experiment Summary -----")
    if total_count > 0:
        final_accuracy = (correct_count / total_count) * 100
        print(f"Final Accuracy: {final_accuracy:.2f}%")
        print(f"Total Correct: {correct_count}")
        print(f"Total Attempted: {total_count}")
    else:
        print("No problems were attempted.")
    print("----------------------------")

In [10]:
# Create a namespace object to simulate argparse
class Args:
    task = 'aqua'
    task_start_index = 0
    task_end_index = 10  # Set how many problems you want to test
    data_path = 'E:/Data/AQuA Dataset/train.json'
    backend = model_name
    temperature = 0.7
    method_generate = 'propose'
    method_evaluate = 'value'
    method_select = 'greedy'
    n_generate_sample = 1
    n_evaluate_sample = 3
    n_select_sample = 5 # This is your beam width (b)

args = Args()

In [11]:
print(f"Running experiment with args: {vars(args)}")

if args.task == 'aqua':
    task = AQuATask(args.data_path)
else:
    raise ValueError(f"Unknown task: {args.task}")

if not task.data:
    print("Exiting due to empty dataset. Make sure 'aqua_train.jsonl' exists.")
else:
    correct_count = 0
    total_count = 0

    for i in range(args.task_start_index, min(args.task_end_index, len(task))):
        print(f"\n\n----- Solving Problem #{i} -----")
        print(task.get_input(i))
        print("---------------------------------")

        # Run the ToT solver
        final_ys, _ = solve(args, task, i)

        # Check the output
        is_correct = False
        if final_ys:
            # Use the first (highest-valued) candidate solution for evaluation
            best_solution = final_ys[0]
            is_correct = task.test_output(i, best_solution)
            print(f"\n----- Final Answer for Problem #{i} -----")
            print(best_solution)
            print("---------------------------------------")
            print(f"Correctness: {'CORRECT' if is_correct else 'INCORRECT'}")
        else:
            print("No solution was generated.")

        if is_correct:
            correct_count += 1
        total_count += 1

        # Calculate and print running accuracy
        accuracy = (correct_count / total_count) * 100
        print(f"Running Accuracy: {accuracy:.2f}% ({correct_count}/{total_count})")

    print("\n\n----- Experiment Summary -----")
    if total_count > 0:
        final_accuracy = (correct_count / total_count) * 100
        print(f"Final Accuracy: {final_accuracy:.2f}%")
        print(f"Total Correct: {correct_count}")
        print(f"Total Attempted: {total_count}")
    else:
        print("No problems were attempted.")
    print("----------------------------")

Running experiment with args: {}


----- Solving Problem #0 -----
Question: Two friends plan to walk along a 43-km trail, starting at opposite ends of the trail at the same time. If Friend P's rate is 15% faster than Friend Q's, how many kilometers will Friend P have walked when they pass each other?
Options: A)21
B)21.5
C)22
D)22.5
E)23
---------------------------------
Using model: google/gemma-2-2b-it on device: cuda




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Model 'google/gemma-2-2b-it' loaded successfully.
-- Step 1 --
-- Top Valued Thoughts --: 
**1. Define variables and relationships:**

-- Values --: [0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
-- Selected Thoughts --: ['**1. Define variables and relationships:**\n', '**5.  Substitute and solve:**\n', 'Here are some possible next steps to solve the question, building upon the partial rationale provided:\n', '   *  Let `x` be the speed of Friend Q.\n', '   *  Let `y` be the speed of Friend P.\n']
--------------------
-- Step 2 --
-- Top Valued Thoughts --: 
**1. Define variables and relationships:**
    *   Let 't' be the time in hours they walk before passing each other.

-- Values --: [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 