In [None]:
import logging
import os
import warnings

# Must be set in the kernel process itself. The previous `!VLLM_CONFIGURE_LOGGING=0`
# ran in a throwaway subshell and never reached this process.
# NOTE(review): this is placed before `from src import *` on the assumption
# that src imports vLLM, which reads the variable at import time - confirm.
os.environ["VLLM_CONFIGURE_LOGGING"] = "0"

import torch
import torch._dynamo

# Import your modules
from src import *

# Silence Python warnings for the rest of the session.
warnings.simplefilter("ignore")

# Let TorchDynamo fall back instead of raising on compilation errors.
torch._dynamo.config.suppress_errors = True

# Send DEBUG-and-above log records to myfile.txt.
logging.basicConfig(filename="myfile.txt", level=logging.DEBUG)


In [None]:
def builder(args):
    """Seed torch and construct the objects used by the evaluation cells.

    Args:
        args: Namespace providing ``seed``, ``model``, ``topk`` and
            ``max_new_tokens``. The whole namespace is also handed to
            ``GetPaths``.

    Returns:
        A ``(generator, cot_decoding, tokenizer)`` tuple: the ``GetPaths``
        instance, the ``Decoder`` instance and the tokenizer loaded for
        ``args.model``.
    """
    # Seed first so that everything constructed below sees the seeded RNG.
    torch.manual_seed(args.seed)

    model_tokenizer = AutoTokenizer.from_pretrained(args.model)

    path_generator = GetPaths(
        args,
        topk=args.topk,
        max_new_tokens=args.max_new_tokens,
    )
    path_scorer = Decoder(tokenizer=model_tokenizer)

    return path_generator, path_scorer, model_tokenizer



In [None]:
import re
import numpy as np
from datasets import load_dataset
from types import SimpleNamespace
from tqdm import tqdm


# Evaluation configuration, collected on a single namespace object.
args = SimpleNamespace()

# Reproducibility and model selection.
args.seed = 42
args.model = "google/gemma-2-2b"  # alternatives: "meta-llama/Llama-3.2-1B", "Qwen/Qwen2.5-1.5B", ...
args.dtype = "float16"
args.max_model_len = 512

# Generation budget.
args.topk = 10              # Number of paths to generate.
args.max_new_tokens = 512   # Maximum tokens to generate.

# Answer extraction: regex for arithmetic answers.
# NOTE(review): `\.?` also captures a sentence-final period, so "is 9." yields "9.".
args.pattern = r'-?\d+\.?\d*'

# Per-sample fields; the evaluation loop overwrites these for each sample.
args.demo_prompt = ""
args.groundtruth = ""

# When True, each sample's result is printed.
args.verbose = False

# Define decoding strategies to test.
# Each entry is (label, kwargs). The evaluation loop unpacks the kwargs into
# generator.continue_decoding(...), so for sampling modes temperature and
# top_p are passed through that call.
#
# The three suites used to be assigned to `strategies` one after another, so
# only the last one ever ran. Each now has its own name and `strategies`
# selects the one to run.

# Example suite: vary temperature at fixed top_p.
# NOTE(review): "topp_sampling_0.9" and "topp_sampling_0.7" also appear in
# TOP_P_SWEEP with different parameters; the labels are kept unchanged here.
TEMPERATURE_SWEEP = [
    ("topp_sampling_1", {"mode": "sampling", "temperature": 1.0, "top_p": 0.9}),
    ("topp_sampling_0.9", {"mode": "sampling", "temperature": 0.9, "top_p": 0.9}),
    ("topp_sampling_0.7", {"mode": "sampling", "temperature": 0.7, "top_p": 0.9}),
    ("topp_sampling_0.5", {"mode": "sampling", "temperature": 0.5, "top_p": 0.9}),
]

# Example suite: vary top_p at fixed temperature.
TOP_P_SWEEP = [
    ("topp_sampling_0.9", {"mode": "sampling", "temperature": 1.0, "top_p": 0.9}),
    ("topp_sampling_0.8", {"mode": "sampling", "temperature": 1.0, "top_p": 0.8}),
    ("topp_sampling_0.7", {"mode": "sampling", "temperature": 1.0, "top_p": 0.7}),
    ("topp_sampling_0.6", {"mode": "sampling", "temperature": 1.0, "top_p": 0.6}),
]

# Example suite: one entry per decoding variant.
ALL_VARIANTS = [
    ("greedy", {"mode": "greedy"}),
    ("beam", {"mode": "beam", "num_beams": 5}),
    ("topk_sampling", {"mode": "sampling", "temperature": 1.0, "top_p": 1.0}),
    ("topp_sampling", {"mode": "sampling", "temperature": 1.0, "top_p": 0.9}),
]

# Suite evaluated by the cells below; swap in TEMPERATURE_SWEEP or TOP_P_SWEEP
# to run a different experiment. ALL_VARIANTS is the list that was in effect
# before, so the default behaviour is unchanged.
strategies = list(ALL_VARIANTS)

# Build the path generator, the CoT decoding helper and the tokenizer
# from the shared configuration.
generator, cot_decoding, tokenizer = builder(args)

# Load the "train" split of the ChilleD/MultiArith dataset.
dataset = load_dataset(path="ChilleD/MultiArith", split="train")



In [None]:
# Evaluate each decoding strategy and record per-strategy accuracy in save.txt.

def answers_match(predicted, expected):
    """Return True when two answer strings denote the same answer.

    Exact string equality is tried first. Otherwise both strings are parsed
    as floats so that numerically equal spellings such as "9.", "9.0" and "9"
    compare equal. This matters because args.pattern can capture a
    sentence-final period ("9." from "... is 9."), which a plain string
    comparison would score as wrong.

    Args:
        predicted: Answer string extracted from the model output.
        expected: Groundtruth answer string.

    Returns:
        bool: True if the strings are identical or numerically equal.
        Non-numeric strings (e.g. the "Did not respond explicitly"
        placeholder) only match by exact equality.
    """
    if predicted == expected:
        return True
    try:
        return float(predicted) == float(expected)
    except (TypeError, ValueError):
        return False


with open('save.txt', 'w') as f:

    for strategy_name, params in strategies:
        print(f"\nEvaluating strategy: {strategy_name}")
        print(f"\nEvaluating strategy: {strategy_name}", file=f)
        correct_count = 0
        total = 0

        # A strategy is greedy if it carries the "greedy" label (original
        # behaviour) or if its configured mode is "greedy", so renaming the
        # label does not silently change which path is taken.
        is_greedy = strategy_name == "greedy" or params.get("mode") == "greedy"

        for sample in tqdm(dataset, desc=f"Evaluating {strategy_name}"):
            question = sample["question"]
            groundtruth = str(sample["final_ans"]).strip()

            # Update args for current sample.
            args.demo_prompt = question
            args.groundtruth = groundtruth

            # Initial CoT generation.
            # NOTE(review): this step does not use `params`, so it is repeated
            # for every strategy; consider caching per question if
            # search_cots is deterministic - confirm.
            topk_tokens, outputs = generator.search_cots(args.demo_prompt)
            cot_paths = cot_decoding.get_score(args.demo_prompt, topk_tokens, outputs, groundtruth)
            best_reasoning = cot_paths.best_path.reasoning_text

            if is_greedy:
                # Greedy: use the best initial path as-is.
                final_reasoning = best_reasoning
            else:
                # Continue decoding from the formatted prompt plus the best
                # initial reasoning, using this strategy's parameters.
                partial_prompt = generator.format_prompt(question) + best_reasoning
                continued_outputs = generator.continue_decoding([partial_prompt], **params)
                # Append the continued text to the initial best reasoning.
                continued_text = continued_outputs[0].outputs[0].text
                final_reasoning = best_reasoning + continued_text

            # Extract answer: the last regex match in the reasoning text.
            extracted = re.findall(args.pattern, final_reasoning)
            if extracted:
                best_answer = extracted[-1].strip()
            else:
                best_answer = "Did not respond explicitly"

            # Check if the extracted answer matches the groundtruth
            # (numerically when both parse as numbers).
            is_correct = answers_match(best_answer, groundtruth)
            if is_correct:
                correct_count += 1
            total += 1

            # Verbose output for each sample.
            if args.verbose:
                print(f"\nSample {total}:")
                print(f"Question: {question}")
                print(f"Groundtruth: {groundtruth}")
                print(f"Strategy: {strategy_name}")
                print(f"Predicted Answer: {best_answer}")
                print(f"Correct: {is_correct}\n")

        # Calculate and print the average accuracy for the strategy.
        accuracy = correct_count / total if total > 0 else 0
        print(f"Average Accuracy for strategy {strategy_name}: {accuracy:.4f}\n")
        print(f"Average Accuracy for strategy {strategy_name}: {accuracy:.4f}\n", file=f)
        # Persist finished results immediately so an interrupted run keeps them.
        f.flush()


In [None]:
def compute_greedy_accuracy(generator, args, dataset):
    """Measure accuracy of plain greedy decoding over a dataset.

    For each sample the question is formatted with ``generator.format_prompt``,
    one completion is generated with temperature 0, and the last match of
    ``args.pattern`` in the generated text is compared with the groundtruth.

    Args:
        generator: Object exposing ``format_prompt(question)`` and
            ``model.generate(prompt, sampling_params, use_tqdm=False)``.
        args: Namespace providing ``pattern`` and ``max_new_tokens``. If it
            has a truthy ``verbose`` attribute, each sample's result is printed.
        dataset: Iterable of mappings with ``"question"`` and ``"final_ans"``.

    Returns:
        float: Fraction of samples answered correctly (0.0 for an empty
        dataset). The value is also printed.
    """
    # Sampling parameters do not depend on the sample, so build them once.
    sampling_params = SamplingParams(
        n=1,
        temperature=0,  # Greedy decoding
        top_p=1,
        max_tokens=args.max_new_tokens,
        logprobs=2,
    )
    verbose = getattr(args, "verbose", False)

    correct_count = 0
    total = 0

    for sample in dataset:
        question = sample["question"]
        groundtruth = str(sample["final_ans"]).strip()

        # Format the prompt using the generator's formatting method.
        formatted_prompt = generator.format_prompt(question)

        # Generate a single output and take its text.
        outputs = generator.model.generate(formatted_prompt, sampling_params, use_tqdm=False)
        generated_text = outputs[0].outputs[0].text

        # Extract the answer candidate using regex; use the last one found.
        answer_candidates = re.findall(args.pattern, generated_text)
        predicted_answer = answer_candidates[-1].strip() if answer_candidates else ""

        # Compare numerically where possible: the pattern can capture a
        # trailing period ("9." from "... is 9."), which plain string
        # equality would score as wrong.
        is_correct = _same_answer(predicted_answer, groundtruth)
        if is_correct:
            correct_count += 1
        total += 1

        if verbose:
            print(f"Question: {question}")
            print(f"Groundtruth: {groundtruth}")
            print(f"Predicted: {predicted_answer}")
            print(f"Correct: {is_correct}\n")

    accuracy = correct_count / total if total > 0 else 0.0
    print(f"Greedy Decoding Accuracy: {accuracy:.4f}")
    return accuracy


def _same_answer(predicted, expected):
    """Return True if two answer strings are identical or numerically equal."""
    if predicted == expected:
        return True
    try:
        return float(predicted) == float(expected)
    except (TypeError, ValueError):
        # At least one side is not a number (e.g. an empty prediction).
        return False

In [None]:

# Baseline: plain greedy decoding accuracy over the same dataset.
compute_greedy_accuracy(generator=generator, args=args, dataset=dataset)
