In [1]:
import torch
import time
import argparse
import numpy as np
import requests
import json
import re
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
)

In [2]:
# provided by the authors
# taken from predictions by another model
# manually checked for correctness by us (Wikipedia)
# two added entries (Tyler Hoechlin, Angelina Jolie) due to one double in the original list and one unclear case; now len = 100

year_map = {'Sasha Calle': '1995', 'Annie Murphy': '1986', 'Golshifteh Farahani': '1983', 'Kate Mara': '1983',
            'Josh Hartnett': '1978', 'Jennifer Lawrence': '1990', 'Aaron Taylor-Johnson': '1990',
            'Rebecca Ferguson': '1983', 'Monica Barbaro': '1990', 'Chris Hemsworth': '1983',

            'Wes Anderson': '1969', 'Daniel Portman': '1992', 'Lily-Rose Depp': '1999', "Myha'la Herrold": '1996',
            'Zendaya': '1996', 'Ezra Miller': '1992', 'Olga Kurylenko': '1979', 'Zazie Beetz': '1991',
            'Arnold Schwarzenegger': '1947', 'Emilia Clarke': '1986',

            'Jess Bush': '1992', 'Clara Rugaard': '1997', 'Molly Gordon': '1994', 'Isabel May': '2000',
            'Hailee Steinfeld': '1996', 'Hannah Waddingham': '1974', 'Rory Culkin': '1989', 'Cobie Smulders': '1982',
            'Harrison Ford': '1942',

            'Tom Cruise': '1962', 'Carol Kane': '1952', 'Alexandra Daddario': '1986', 'Gal Gadot': '1985',
            'Tom Holland': '1996', 'Hayley Atwell': '1982', 'Salma Hayek': '1966', 'Ana de Armas': '1988',
            'Will Poulter': '1993', 'Anson Mount': '1973',

            'Paapa Essiedu': '1990', 'Sam Hargrave': '1982', 'Margot Robbie': '1990', 'Nicolas Cage': '1964',
            'Henry Cavill': '1983', 'Juno Temple': '1989', 'Cailee Spaeny': '1998', 'Treat Williams': '1951',

            'Alexander Skarsgård': '1976', 'Rebecca Romijn': '1972', 'Monica Dolan': '1969', 'Anya Taylor-Joy': '1996',
            'Sophia Lillis': '2002', 'Emmanuelle Vaugier': '1976', 'Aaron Paul': '1979', 'Elliot Page': '1987',
            'Robin Tunney': '1972', 'Mike Faist': '1992',

            'Tinatin Dalakishvili': '1991', 'Sarah Snook': '1987', 'Jenna Ortega': '2002', 'Zoe Saldana': '1978',
            'Anjana Vasan': '1987', 'Ben Mendelsohn': '1969', 'Jeremy Allen White': '1991', 'Ayo Edebiri': '1995',
            'Keanu Reeves': '1964', 'Pom Klementieff': '1986',

            'Scarlett Johansson': '1984', 'Tornike Gogrichiani': '1986', 'James Cameron': '1954',
            'Pedro Pascal': '1975', 'Kaley Cuoco': '1985', 'Samuel L. Jackson': '1948', 'Terri Ivens': '1967',
            'Florence Pugh': '1996', 'Shea Whigham': '1969',

            'Kingsley Ben-Adir': '1986', 'Michael Keaton': '1951', 'Julian Sands': '1958', 'Christopher Nolan': '1970',
            'Tom Hanks': '1956', 'Clint Eastwood': '1930', 'Gabriel Macht': '1972', 'Fabiana Udenio': '1964',
            'Tom Bateman': '1989', 'Jack Champion': '2004',

            'Jake Gyllenhaal': '1980', 'Leonardo DiCaprio': '1974', 'Jason Schwartzman': '1980',
            'Grace Caroline Currey': '1996', 'Sydney Sweeney': '1997', 'Emily Rudd': '1993', 'Samuel Blenkin': '1996',
            'James Marsden': '1973', 'Jesse Plemons': '1988', 'Alan Ritchson': '1982',

            'Cillian Murphy': '1976', 'Meghan Markle': '1981', 'Tyler Hoechlin': '1987', 'Angelina Jolie': '1975'}

In [None]:
# Parse command line arguments
parser = argparse.ArgumentParser(
    description="CoT-Decoding with lightweight models")
parser.add_argument("--model", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                    help="Model ID to use (default: TinyLlama-1.1B)")
parser.add_argument("--use_ollama", action="store_true",
                    help="Use Ollama instead of HuggingFace models")
parser.add_argument("--ollama_model", type=str, default="llama2",
                    help="Ollama model to use (default: llama2)")
parser.add_argument("--ollama_url", type=str, default="http://localhost:11434",
                    help="Ollama API URL (default: http://localhost:11434)")
parser.add_argument("--dataset", type=str, default="gsm8k",
                    choices=["gsm8k", "yearparity",
                             "multiarith", "bbh", "custom"],
                    help="Dataset to evaluate on")
parser.add_argument("--split", type=str, default="test",
                    help="Dataset split to evaluate on")
parser.add_argument("--top_k", type=int, default=10,
                    help="Number of alternative tokens to consider (default: 10)")
parser.add_argument("--batch_size", type=int, default=1,
                    help="Batch size for inference")
parser.add_argument("--num_samples", type=int, default=100,
                    help="Number of samples to evaluate")
parser.add_argument("--custom_dataset_path", type=str, default="",
                    help="Path to custom dataset")
parser.add_argument("--decode_method", type=str, default="greedy",
                    choices=["greedy", "cot-decoding"],
                    help="Decoding method to use")
# args = parser.parse_args([
#    "--model", "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
#    "--dataset", "gsm8k",
# args = parser.parse_args([
#    "--model", "Qwen/Qwen2.5-0.5B-Instruct",
#    "--dataset", "gsm8k",
args = parser.parse_args([
    "--model", "Qwen/Qwen2.5-0.5B",
    "--dataset", "gsm8k",
    # args = parser.parse_args([
    #    "--model", "DeepSeek/DeepSeek-R1:1.5B",
    #    "--dataset", "gsm8k",
])
# Add any other arguments you want to pass

In [45]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [46]:
# List of models that are known to work well with RTX 3050 Ti (4GB VRAM)
LIGHTWEIGHT_MODELS = [
    "DeepSeek/DeepSeek-R1:1.5B",
    "Qwen/Qwen2.5-0.5B-Instruct",
    "Qwen/Qwen2.5-0.5B",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "microsoft/phi-2",
    "stabilityai/stablelm-3b-4e1t",
    "google/gemma-2b",
    "google/gemma-2b-it",
    "bigcode/starcoder2-3b",
    "mistralai/Mistral-7B-v0.1",  # May require 8-bit quantization
]

In [47]:
# List of recommended Ollama models for RTX 3050 Ti
OLLAMA_MODELS = [
    "deepseek",
    "qwen2.5",
    "llama2",
    "phi",
    "phi3",
    "phi3:mini",
    "mistral",
    "mistral:7b-instruct-v0.2-q4_0",
    "gemma:2b",
    "gemma:2b-instruct",
    "neural-chat",
    "wizard-math:7b-q4_0",
    "stablelm-zephyr"
]

In [48]:
# Verify if the selected model is in the recommended list (if not using Ollama)
if not args.use_ollama and args.model not in LIGHTWEIGHT_MODELS:
    print(
        f"Warning: {args.model} is not in the list of recommended lightweight models.")
    print(f"Recommended models for RTX 3050 Ti: {LIGHTWEIGHT_MODELS}")
    response = input("Do you want to continue? (y/n): ")
    if response.lower() != "y":
        exit()

In [49]:
# Verify if the selected Ollama model is in the recommended list
if args.use_ollama and args.ollama_model not in OLLAMA_MODELS:
    print(
        f"Warning: {args.ollama_model} is not in the list of recommended Ollama models.")
    print(f"Recommended Ollama models for RTX 3050 Ti: {OLLAMA_MODELS}")
    print("You can see available models by running 'ollama list' in terminal")
    response = input("Do you want to continue? (y/n): ")
    if response.lower() != "y":
        exit()

In [50]:
def load_model_and_tokenizer(model_name):
    """Load model and tokenizer with appropriate quantization"""
    if args.use_ollama:
        # For Ollama, we don't need to load a model, so we return placeholder objects
        print(f"Using Ollama with model: {args.ollama_model}")

        # Check if Ollama is running
        try:
            response = requests.get(f"{args.ollama_url}/api/tags")
            if response.status_code != 200:
                raise Exception("Ollama server returned non-200 status code")
            print("Ollama server is running and responding")

            # Try to list available models
            models_response = requests.get(f"{args.ollama_url}/api/tags")
            if models_response.status_code == 200:
                available_models = [model["name"]
                                    for model in models_response.json().get("models", [])]
                print(f"Available Ollama models: {available_models}")

                if args.ollama_model not in available_models:
                    print(
                        f"Warning: Model '{args.ollama_model}' not found in Ollama.")
                    print("You may need to pull it first with: ollama pull",
                          args.ollama_model)

        except Exception as e:
            print(f"Error connecting to Ollama server: {e}")
            print(
                "Make sure Ollama is running on your machine. You can start it by running 'ollama serve'")
            exit(1)

        # We don't actually need a real tokenizer for Ollama, but we'll create a dummy object
        class DummyTokenizer:
            def __call__(self, text, return_tensors=None):
                return {"input_ids": torch.tensor([[0]])}  # Dummy value

            def decode(self, token_ids, skip_special_tokens=None):
                return ""  # We won't actually use this

            eos_token_id = 0  # Dummy value

        return None, DummyTokenizer()
    else:
        print(f"Loading model: {model_name}")

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Load model with 8-bit quantization for larger models
        if "7B" in model_name or "13B" in model_name:
            print("Loading with 8-bit quantization for larger model")
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",
                load_in_8bit=True,
                torch_dtype=torch.float16
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",
                torch_dtype=torch.float16
            )

        # Enable model evaluation mode
        model.eval()

        return model, tokenizer

In [None]:
def load_data(dataset_name, split, num_samples, custom_path=""):
    """Load dataset for evaluation"""
    if dataset_name == "custom" and custom_path:
        # Load custom dataset (expected format: JSONL with 'question' and 'answer' fields)
        try:
            import pandas as pd
            if custom_path.endswith('.csv'):
                df = pd.read_csv(custom_path)
            elif custom_path.endswith('.json') or custom_path.endswith('.jsonl'):
                df = pd.read_json(custom_path, lines=True)
            else:
                raise ValueError(
                    "Custom dataset must be in CSV or JSONL format")

            questions = df['question'].tolist()[:num_samples]
            if 'answer' in df.columns:
                answers = df['answer'].tolist()[:num_samples]
            else:
                answers = [""] * len(questions)

            return {'question': questions, 'answer': answers}
        except Exception as e:
            print(f"Error loading custom dataset: {e}")
            exit(1)

    # Load standard datasets
    if dataset_name == "gsm8k":
        dataset = load_dataset("gsm8k", "main", split=split)

    elif dataset_name == "yearparity":
        # print("Dataset")
        problems = []
        # answers = []
        # with adjustments taken from the paper
        for key in year_map:
            text = "Was " + key + " born in an even or odd year?"
            # print(text)
            # questions.append(text)
            # problems.append(text)
            year = int(year_map[key])
            if year % 2 == 0:
                # answers.append('even')
                problems.append([text, "even"])
            else:
                # answers.append('odd')
                problems.append([text, "odd"])

        return problems

    elif dataset_name == "multiarith":
        dataset = load_dataset(
            "EleutherAI/synthetic-instruct-gpt4-TruthfulQA", split=split)
    elif dataset_name == "bbh":
        # You can specify which BBH tasks to use
        dataset = load_dataset(
            "lukaemon/bbh", "sports_understanding", split=split)
    else:
        raise ValueError(f"Dataset {dataset_name} not supported")

    # Limit the number of samples
    dataset = dataset.select(range(min(num_samples, len(dataset))))

    return dataset

In [52]:
def prepare_question(question, model_name):
    """Format question based on model requirements"""
    if args.use_ollama:
        # Adjust prompt based on model
        # Get base model name without tags
        model = args.ollama_model.split(':')[0]

        if model in ["llama2", "llama3"]:
            return f"<s>[INST] Q: {question} [/INST]"
        elif model in ["mistral"]:
            return f"<s>[INST] Q: {question} [/INST]"
        elif model in ["phi", "phi2", "phi3"]:
            return f"Q: {question}\nA:"
        elif model in ["gemma"]:
            return f"<start_of_turn>user\nQ: {question}<end_of_turn>\n<start_of_turn>model\nA:"
        elif "wizard-math" in model:
            return f"USER: Q: {question}\nASSISTANT:"
        else:
            # Generic format for other models
            return f"Q: {question}\nA:"
    else:
        # Different models have different prompt formats
        if "TinyLlama" in model_name:
            return f"<|user|>\nQ: {question}\n<|assistant|>\nA:"
        elif "phi" in model_name:
            return f"Q: {question}\nA:"
        elif "gemma" in model_name:
            return f"<start_of_turn>user\nQ: {question}<end_of_turn>\n<start_of_turn>model\nA:"
        elif "mistral" in model_name:
            return f"[INST] Q: {question} [/INST] A:"
        else:
            # Default format
            return f"Q: {question}\nA:"

In [None]:
def extract_answer(response_text, dataset, decode_method):
    import re
    """Extract the final answer from the generated response"""
    # This is a simple implementation - you might need to adjust based on your model's output

    # different answer method for Year Parity-Task:
    if dataset == "yearparity":
        # in case of greedy decoding, use first occurrence of "odd" or "even"
        index = 0
        # in case of CoT-decoding, use last occurrence of "odd" or "even"
        if decode_method == "cot_decoding":
            index = -1

        # identifying answer
        found = re.findall(r"\W*even\W*|\W*odd\W*",
                           response_text, flags=re.IGNORECASE)

        # evaluation
        if found:
            # to lowercase to match with expected answer
            return re.sub(r"\W", "", found[index]).lower()

        else:
            return "No answer found"

    try:
        # Look for the answer after "The answer is" or similar phrases
        phrases = ["The answer is", "answer is", "final answer is", "= "]
        for phrase in phrases:
            if phrase in response_text:
                answer_part = response_text.split(phrase)[-1].strip()
                # Extract the first number
                import re
                numbers = re.findall(r"[-+]?\d*\.\d+|\d+", answer_part)
                if numbers:
                    return numbers[0]

        # If no clear answer format is found, return the last number in the text
        numbers = re.findall(r"[-+]?\d*\.\d+|\d+", response_text)
        if numbers:
            return numbers[-1]

        return "No answer found"
    except:
        return "Error extracting answer"

In [54]:
def ollama_generate(prompt, model_name, temperature=0.0, max_tokens=200):
    """Generate text using Ollama API"""
    url = f"{args.ollama_url}/api/generate"

    payload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stream": False
    }

    try:
        response = requests.post(url, json=payload)
        if response.status_code == 200:
            return response.json().get("response", "")
        else:
            print(f"Error from Ollama API: {response.status_code}")
            print(response.text)
            return f"Error: {response.status_code}"
    except Exception as e:
        print(f"Exception when calling Ollama API: {e}")
        return f"Error: {str(e)}"

In [55]:
def ollama_get_top_logprobs(prompt, model_name, top_k=10):
    """Get top logprobs for the next token using Ollama API"""
    url = f"{args.ollama_url}/api/generate"

    # We'll use a trick to get logprobs: generate just 1 token and request logprobs
    payload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": 0.0,
        "top_k": top_k,
        "top_p": 1.0,
        "max_tokens": 1,
        "stream": False,
        "options": {
            "num_ctx": 2048,
            "top_k_return": top_k
        }
    }

    try:
        response = requests.post(url, json=payload)
        if response.status_code == 200:
            # Extract the logprobs from the response
            # Note: This is Ollama-specific and might change with API updates
            if "top_k_return" in response.json():
                return response.json()["top_k_return"]
            else:
                # Fallback if logprobs aren't available
                # We'll just return a dummy list with the first token
                first_token = response.json().get("response", "")
                return [(first_token, 0.0)]
        else:
            print(f"Error from Ollama API: {response.status_code}")
            print(response.text)
            return [(f"Error: {response.status_code}", 0.0)]
    except Exception as e:
        print(f"Exception when calling Ollama API: {e}")
        return [(f"Error: {str(e)}", 0.0)]

In [56]:
def greedy_decode(model, tokenizer, input_text, max_new_tokens=200):
    """Standard greedy decoding"""
    if args.use_ollama:
        return ollama_generate(input_text, args.ollama_model, temperature=0.0, max_tokens=max_new_tokens)
    else:
        inputs = tokenizer(input_text, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id
            )

        response = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
        return response

In [None]:
def get_device():
    if torch.backends.mps.is_available():
        return torch.device("mps")
    elif torch.cuda.is_available():
        return torch.device("cuda")
    else:
        return torch.device("cpu")


def calculate_confidence(logits: List[torch.Tensor], answer_ids: torch.Tensor) -> float:
    """
    Calculate the confidence score (Δ) as specified in the paper.

    Args:
        logits: List of logits for each decoding step
        answer_ids: Tensor of token ids for the answer

    Returns:
        Confidence score (Δ)
    """
    confidence_sum = 0.0
    valid_tokens = 0
    for t, token_id in enumerate(answer_ids):
        if t >= len(logits):
            break
        token_logits = logits[t]
        probs = torch.softmax(token_logits, dim=-1)
        if probs.size(-1) > 1:
            top_2_probs, _ = torch.topk(probs, min(2, probs.size(-1)))
            if top_2_probs.size(-1) > 1:
                confidence_sum += (top_2_probs[-1]
                                   [0] - top_2_probs[-1][1]).item()
            else:
                confidence_sum += 1.0  # Max confidence if there's only one token
        else:
            confidence_sum += 1.0  # Max confidence if there's only one token
        valid_tokens += 1

    return confidence_sum / valid_tokens if valid_tokens > 0 else 0.0


def aggregate_paths_based_on_scores(paths: List[Tuple[str, float]]) -> Tuple[str, float]:
    """Aggregate multiple paths based on their confidence scores."""
    answer_scores = {}
    for answer, delta in paths:
        answer_scores[answer] = answer_scores.get(answer, 0) + delta
    best_answer = max(answer_scores, key=answer_scores.get)
    return best_answer, answer_scores[best_answer]


def cot_decode(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    messages: List[Dict[str, str]],
    k: int = 10,
    num_beams: int = 1,
    max_new_tokens: int = 512,
    temperature: float = 1.0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    length_penalty: float = 1.0,
    no_repeat_ngram_size: int = 0,
    early_stopping: bool = False,
    aggregate_paths: bool = False,
) -> Tuple[str, float]:
    """
    Implement CoT-decoding for a given chat input.

    Args:
        model: The Hugging Face transformer model.
        tokenizer: The associated tokenizer.
        messages: List of chat messages in the format [{"role": "user", "content": "..."}]
        k: The number of alternative tokens to consider at the first step.
        num_beams: Number of beams for beam search.
        max_new_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus sampling probability.
        repetition_penalty: Repetition penalty factor.
        length_penalty: Length penalty factor.
        no_repeat_ngram_size: Size of n-grams to avoid repeating.
        early_stopping: Whether to stop generation when all beams are finished.
        aggregate_paths: Whether to aggregate multiple paths.

    Returns:
        A tuple containing the best path (or aggregated result) and its confidence score.
    """
    device = get_device()
    model.to(device)

    # Use the chat template to format the input
    if tokenizer.chat_template:
        input_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
    else:
        # Fallback for tokenizers without chat templates
        input_text = "\n".join(
            [f"{msg['role']}: {msg['content']}" for msg in messages])
        input_text += "\nassistant:"

    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
    attention_mask = torch.ones_like(input_ids).to(device)

    # Set pad_token_id if it's not set
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Get the top-k tokens for the first decoding step
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        first_token_logits = outputs.logits[0, -1, :]
        top_k_logits, top_k_indices = torch.topk(first_token_logits, k)

    paths = []
    for idx in top_k_indices:
        # Generate sequence starting with the selected token
        start_ids = torch.cat(
            [input_ids, idx.unsqueeze(0).unsqueeze(0)], dim=-1)
        start_mask = torch.cat([attention_mask, torch.ones(
            (1, 1), dtype=torch.long, device=device)], dim=-1)

        output = model.generate(
            start_ids,
            attention_mask=start_mask,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            early_stopping=early_stopping,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            output_scores=True,
            return_dict_in_generate=True,
        )

        generated_sequence = output.sequences[0]
        answer_ids = generated_sequence[len(input_ids[0]):]
        print(f"answer_ids: {answer_ids}")
        answer_text = tokenizer.decode(answer_ids, skip_special_tokens=True)
        print(f"answer_text: {answer_text}")

        # Calculate confidence score (Δ)
        confidence = calculate_confidence(output.scores, answer_ids)
        paths.append((answer_text, confidence))

    if aggregate_paths:
        return aggregate_paths_based_on_scores(paths)
    else:
        return max(paths, key=lambda x: x[1])

NameError: name 'List' is not defined

In [None]:
def calculate_confidence(logits: List[torch.Tensor], answer_ids: torch.Tensor) -> float:
    """
    Calculate the confidence score (Δ) as specified in the paper.

    Args:
        logits: List of logits for each decoding step
        answer_ids: Tensor of token ids for the answer

    Returns:
        Confidence score (Δ)
    """
    confidence_sum = 0.0
    valid_tokens = 0
    for t, token_id in enumerate(answer_ids):
        if t >= len(logits):
            break
        token_logits = logits[t]
        probs = torch.softmax(token_logits, dim=-1)
        if probs.size(-1) > 1:
            top_2_probs, _ = torch.topk(probs, min(2, probs.size(-1)))
            if top_2_probs.size(-1) > 1:
                confidence_sum += (top_2_probs[-1]
                                   [0] - top_2_probs[-1][1]).item()
            else:
                confidence_sum += 1.0  # Max confidence if there's only one token
        else:
            confidence_sum += 1.0  # Max confidence if there's only one token
        valid_tokens += 1

    return confidence_sum / valid_tokens if valid_tokens > 0 else 0.0

NameError: name 'List' is not defined

In [None]:
def cot_decode(model, tokenizer, input_text, top_k=10, max_new_tokens=200):
    """Implementation of CoT-decoding as per the paper"""
    if args.use_ollama:
        # Step 1: Get the top-k tokens (or completions) for the first position
        top_completions = []

        try:
            # First, try to get top logprobs if the API supports it
            top_logprobs = ollama_get_top_logprobs(
                input_text, args.ollama_model, top_k)

            if len(top_logprobs) > 0 and not top_logprobs[0][0].startswith("Error"):
                # Use the returned logprobs to create different starting points
                for token, logprob in top_logprobs:
                    # Generate a completion starting with this token
                    output = ollama_generate(
                        input_text + token,
                        args.ollama_model,
                        temperature=0.0,
                        max_tokens=max_new_tokens-1
                    )
                    # top_completions.append((token + full_completion, logprob))
                    generated_sequence = output.sequences[0]
                    answer_ids = generated_sequence[len(input_ids[0]):]
                    # print(f"answer_ids: {answer_ids}")
                    answer_text = tokenizer.decode(
                        answer_ids, skip_special_tokens=True)
                    # print(f"answer_text: {answer_text}")

                    # Calculate confidence score (Δ)
                    confidence = calculate_confidence(
                        output.scores, answer_ids)
                    paths.append((answer_text, confidence))

                    if aggregate_paths:
                        return aggregate_paths_based_on_scores(paths)
                    else:
                        return max(paths, key=lambda x: x[1])

            else:
                # Fallback: Generate multiple samples with higher temperature
                print("Logprobs not available, using temperature sampling as fallback")
                for _ in range(top_k):
                    completion = ollama_generate(
                        input_text,
                        args.ollama_model,
                        temperature=0.7,
                        max_tokens=max_new_tokens
                    )
                    # Assign a dummy logprob (we'll use length as a proxy for confidence)
                    dummy_logprob = len(completion) / 100  # Simple heuristic
                    top_completions.append((completion, dummy_logprob))
        except Exception as e:
            print(f"Error in CoT decoding with Ollama: {e}")
            # Fallback to standard greedy decoding
            completion = ollama_generate(
                input_text, args.ollama_model, max_tokens=max_new_tokens)
            return completion

        # Step 3: Select the generation that contains reasoning
        def contains_reasoning(text):
            # Simple heuristic: check if the text has calculations/steps
            reasoning_indicators = [
                "First", "Step", "Let's", "I'll", "=", "+", "-", "*", "/",
                "calculate", "step", "think", "reason", "therefore", "if", "because"
            ]
            return any(indicator in text for indicator in reasoning_indicators) and len(text) > 50

        reasoning_generations = [
            gen for gen in top_completions if contains_reasoning(gen[0])]

        if reasoning_generations:
            # Return the most probable reasoning path
            best_generation = max(reasoning_generations, key=lambda x: x[1])
            return best_generation[0]
        else:
            # Fall back to the most probable generation
            return max(top_completions, key=lambda x: x[1])[0]
    else:
        # Original HuggingFace implementation
        inputs = tokenizer(input_text, return_tensors="pt").to(device)
        input_length = inputs["input_ids"].shape[1]

        # Step 1: Get the top-k tokens for the first position
        with torch.no_grad():
            outputs = model(inputs["input_ids"])
            logits = outputs.logits[:, -1, :]
            top_k_logits, top_k_indices = torch.topk(logits, top_k)

        # Step 2: For each of the top-k first tokens, perform greedy decoding to get full paths
        all_generations = []

        for i in range(top_k):
            first_token = top_k_indices[0, i].unsqueeze(0).unsqueeze(0)

            # Concatenate the input with the first token
            current_input = torch.cat(
                [inputs["input_ids"], first_token], dim=1)

            # Perform greedy decoding for the rest of the sequence
            with torch.no_grad():
                outputs = model.generate(
                    current_input,
                    max_new_tokens=max_new_tokens-1,  # -1 because we already have the first token
                    do_sample=False,
                    pad_token_id=tokenizer.eos_token_id
                )

            generated_text = tokenizer.decode(
                outputs[0][input_length:], skip_special_tokens=True)
            all_generations.append((generated_text, top_k_logits[0, i].item()))

        # Step 3: Select the generation that contains reasoning (heuristic: longer and has calculation steps)
        def contains_reasoning(text):
            # Simple heuristic: check if the text has calculations/steps
            reasoning_indicators = [
                "First", "Step", "Let's", "I'll", "=", "+", "-", "*", "/",
                "calculate", "step", "think", "reason", "therefore", "if", "because"
            ]
            return any(indicator in text for indicator in reasoning_indicators) and len(text) > 50

        reasoning_generations = [
            gen for gen in all_generations if contains_reasoning(gen[0])]

        if reasoning_generations:
            # Return the most probable reasoning path
            best_generation = max(reasoning_generations, key=lambda x: x[1])
            return best_generation[0]
        else:
            # Fall back to the most probable generation
            return max(all_generations, key=lambda x: x[1])[0]

In [None]:
def evaluate(model, tokenizer, dataset, dataset_name, decode_method, top_k):
    """Evaluate the model on the dataset"""
    results = []
    start_time = time.time()

    # print(dataset)
    for i, item in enumerate(dataset):
        # specifically for GSM8K
        if dataset_name == "gsm8k":
            # print("GSM8K")
            lines = item['answer'].split('\n')

            if lines:
                reference_answer = lines[-1].replace("#### ", '')
                question = item['question']

        elif isinstance(dataset, list):
            # for year parity
            question = item[0]
            reference_answer = item[1]

        elif isinstance(dataset, dict):
            # For custom datasets
            # print("Isinstance", item)
            question = dataset["question"][i]
            reference_answer = dataset['answer'][i] if dataset['answer'][i] else "No reference"
        else:
            # For HuggingFace datasets
            if "question" in item:
                question = item["question"]
            elif "input" in item:
                question = item["input"]
            else:
                raise ValueError("Dataset structure not supported")

            if "answer" in item:
                reference_answer = item["answer"]
            else:
                reference_answer = "No reference"

        # Prepare the input
        input_text = prepare_question(question, args.model)

        # Get the prediction
        if decode_method == "greedy":
            response = greedy_decode(model, tokenizer, input_text)
        else:  # cot-decoding
            response = cot_decode(model, tokenizer, input_text, top_k)

        # Extract the final answer
        predicted_answer = extract_answer(
            response, dataset_name, decode_method)

        # Save result
        results.append({
            "question": question,
            "full_response": response,
            "predicted_answer": predicted_answer,
            "reference_answer": reference_answer
        })

        # Print progress
        if (i + 1) % 5 == 0:
            elapsed = time.time() - start_time
            print(f"Processed {i+1}/{len(dataset)} examples ({elapsed:.2f}s)")
            # Print the last example
            print(f"Question: {question}")
            print(f"Response: {response[:100]}...")
            print(f"Predicted: {predicted_answer}")
            print("-" * 50)

    return results

In [None]:
# Evaluation for GSM8K dataset
# taken from the original implementation of the paper

def _is_float(s):
    try:
        float(s)
        return True
    except:
        return False


def is_correct(target, ans):
    if _is_float(target) and _is_float(ans):
        if abs(float(target) - float(ans)) <= 1e-5:
            return True
    elif str(target) == str(ans):
        return True
    return False

In [None]:
def save_results(results, dataset_name, args):
    """Save evaluation results to a file"""
    import json
    import os
    from datetime import datetime

    # Create results directory if it doesn't exist
    os.makedirs("results", exist_ok=True)

    # Create filename with timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    if args.use_ollama:
        model_name = args.ollama_model.replace(':', '-')
    else:
        model_name = args.model.split("/")[-1]

    filename = f"results/cot_decoding_{model_name}_{args.dataset}_{args.decode_method}_{timestamp}.json"

    # Calculate accuracy if reference answers are available
    correct = 0
    total = 0
    # specific evaluation for GSM8K taken from the paper:
    # print(dataset_name, type(dataset_name))
    if dataset_name == "gsm8k":
        print("Evaluation Method GSM8K")
        for result in results:
            if result["reference_answer"] != "No reference":
                total += 1
                if is_correct(result["predicted_answer"], result["reference_answer"]):
                    correct += 1
    else:
        for result in results:
            if result["reference_answer"] != "No reference":
                total += 1
                # Very simple accuracy check - this should be improved for real evaluation
                if str(result["predicted_answer"]) in str(result["reference_answer"]):
                    correct += 1

    accuracy = correct / total if total > 0 else "N/A"

    # Save metadata and results
    output = {
        "metadata": {
            "model": args.ollama_model if args.use_ollama else args.model,
            "dataset": args.dataset,
            "decode_method": args.decode_method,
            "top_k": args.top_k,
            "num_samples": args.num_samples,
            "answers_found": total,
            "accuracy": accuracy,
            "timestamp": timestamp
        },
        "results": results
    }

    with open(filename, "w") as f:
        json.dump(output, f, indent=2)

    print(f"Results saved to {filename}")
    print(f"Accuracy: {accuracy}")

In [63]:
def main():
    # Check GPU info
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(
            f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    # Load model and tokenizer
    model, tokenizer = load_model_and_tokenizer(args.model)
    print(args.num_samples)
    # Load dataset
    if args.dataset == "custom":
        if not args.custom_dataset_path:
            print("Error: --custom_dataset_path must be provided for custom datasets")
            exit(1)
        dataset = load_data(args.dataset, args.split,
                            args.num_samples, args.custom_dataset_path)
    else:
        dataset = load_data(args.dataset, args.split, args.num_samples)

    print(f"Loaded {len(dataset)} examples from {args.dataset}")

    # Evaluate
    results = evaluate(model, tokenizer, dataset, args.dataset,
                       args.decode_method, args.top_k)

    # Save results
    save_results(results, args.dataset, args)


if __name__ == "__main__":
    main()

GPU: NVIDIA GeForce RTX 4060 Laptop GPU
GPU Memory: 8.59 GB
Loading model: Qwen/Qwen2.5-0.5B
100
Loaded 100 examples from gsm8k
Processed 5/100 examples (19.37s)
Question: Every day, Wendi feeds each of her chickens three cups of mixed chicken feed, containing seeds, mealworms and vegetables to help keep them healthy.  She gives the chickens their feed in three separate meals. In the morning, she gives her flock of chickens 15 cups of feed.  In the afternoon, she gives her chickens another 25 cups of feed.  How many cups of feed does she need to give her chickens in the final meal of the day if the size of Wendi's flock is 20 chickens?
Response:  Wendi feeds each chicken 15 + 25 = 40 cups of feed. Wendi needs 40 / 3 = 13.333333 cups of feed for...
Predicted: 13.333333
--------------------------------------------------
Processed 10/100 examples (39.33s)
Question: Eliza's rate per hour for the first 40 hours she works each week is $10. She also receives an overtime pay of 1.2 times her r

In [64]:
dm = ['apple', 'cherry', 'banana']

for i, y in enumerate(dm):
    print(i, y)

0 apple
1 cherry
2 banana


In [None]:
# ollama run stablelm-zephyr
# ollama run wizard-math:7b-q4_0
# ollama run mistral:v0.1