# Chain-of-Thought Reasoning with Self-Consistency
Below is a small demo illustrating the power of self-consistency in chain-of-thought reasoning. I compare greedy decoding (`temperature` = 0) with self-consistency and record the results.

In [None]:
from collections import Counter
from datasets import load_dataset
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import process_map
import os, re, ollama, matplotlib.pyplot as plt

## Define constants & utilities
- Model: Llama 3.1 (8B parameters), self-hosted via Ollama.
- Dataset: [GSM8K](https://huggingface.co/datasets/openai/gsm8k/viewer/main/train?row=7294&views%5B%5D=main_train). I use only the first 100 examples of the training split to keep the evaluations fast.

In [None]:
# Ollama model tag used for every generation request.
MODEL_NAME = "llama3.1:8b"
# Half the logical CPUs, but at least 1. os.cpu_count() returns None when the
# count cannot be determined, so fall back to 1 instead of crashing on None // 2.
NUM_WORKERS = max((os.cpu_count() or 1) // 2, 1)
# Evaluate only the first N training examples to keep runs short.
NUM_EXAMPLES = 100

dataset = load_dataset("gsm8k", "main", cache_dir="./data")
# Slicing a split yields a dict of column -> list of values; the evaluation
# code below reads dataset["question"] and dataset["answer"] as parallel lists.
dataset = dataset["train"][:NUM_EXAMPLES]

In [None]:
def generate_response(prompt: str, *, temperature: float) -> str:
    """Send *prompt* to the Ollama model and return the generated text.

    Args:
        prompt: Full prompt text passed to ``ollama.generate``.
        temperature: Sampling temperature (0.0 for greedy decoding).

    Returns:
        The model's ``response`` text. If anything raises while building the
        request, calling Ollama, or reading the result, the error is printed
        and an empty string is returned so a single failed request does not
        abort an evaluation run.
    """
    try:
        request_options = {"temperature": temperature, "num_thread": NUM_WORKERS}
        result = ollama.generate(MODEL_NAME, prompt, options=request_options)
        return result.response
    except Exception as err:  # best-effort: report and continue
        print(f"Error generating response: {err}")
        return ''


def extract_actual_answer(answer: str) -> "str | None":
    """Return the reference answer from a GSM8K solution string.

    Looks for the first ``####`` marker and returns the rest of that line
    (whitespace after the marker is skipped), stripped.

    Args:
        answer: The full reference solution text.

    Returns:
        The answer text following ``####``, or ``None`` if no marker exists.
    """
    match = re.search(r'####\s*([^\n]+)', answer)
    return match.group(1).strip() if match else None


def extract_llm_answer(text: str) -> "str | None":
    """Heuristically pull the final answer out of a model response.

    ``<think>...</think>`` sections are removed first. The remaining non-blank
    lines are scanned from last to first, and the first match wins:

    * a line that is itself a plain number (digits with at most one ``.``)
      is returned as-is;
    * a line containing ``"Answer"`` yields its first whitespace-separated
      token that is such a plain number.

    If no line matches, the last non-blank line is returned unchanged.

    Args:
        text: Raw model output.

    Returns:
        The extracted answer text, or ``None`` when nothing but whitespace is
        left (e.g. the empty string returned after a failed generation), so
        callers can skip the sample instead of hitting an ``IndexError``.
    """
    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    lines = [line.strip() for line in text.strip().split("\n") if line.strip()]
    if not lines:
        return None

    def is_plain_number(token: str) -> bool:
        # Digits with at most one decimal point, e.g. "18", "3.5" or "18.".
        return token.replace('.', '', 1).isdigit()

    # The final answer is assumed to be near the end, so scan backwards.
    for line in reversed(lines):
        if is_plain_number(line):
            return line
        if "Answer" in line:
            for part in line.split():
                if is_plain_number(part):
                    return part

    # No recognizable number anywhere: fall back to the last non-blank line.
    return lines[-1]


# Integers and decimals, optionally negative and with thousands separators,
# e.g. "42", "-3", "1,000", "2.50".
NUMBER_PATTERN = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def build_cot_prompt(question: str) -> str:
    """Return the zero-shot chain-of-thought prompt used for every question."""
    return f"Question: {question}\nLet's think step by step."


def normalize_answer(text: str) -> str:
    """Reduce an answer string to a canonical form for comparison and voting.

    The last number in *text* is taken as the answer. Thousands separators are
    removed and trailing fractional zeros are trimmed, so ``"$1,000.50"``
    becomes ``"1000.5"`` and ``"18.00"`` becomes ``"18"``. Text containing no
    number is returned stripped, so it can still be compared exactly.
    """
    numbers = NUMBER_PATTERN.findall(text)
    if not numbers:
        return text.strip()
    value = numbers[-1].replace(",", "")
    if "." in value:
        value = value.rstrip("0").rstrip(".")
    return value


def answers_match(expected: str, predicted: str) -> bool:
    """Return True if both answers normalize to the same canonical value.

    Replaces the previous substring test, under which e.g. "4" matched "14".
    """
    return normalize_answer(expected) == normalize_answer(predicted)


def evaluate_greedy_decoding() -> float:
    """Return the accuracy of single-sample greedy decoding (temperature 0).

    Each question gets one chain-of-thought response. Its extracted answer is
    compared with the reference via ``answers_match``. Questions skipped for a
    missing reference or LLM answer stay in the denominator, so they count as
    incorrect.

    Returns:
        Fraction of questions answered correctly (0.0 for an empty dataset).
    """
    questions, references = dataset["question"], dataset["answer"]
    total = len(questions)
    if total == 0:
        return 0.0

    correct = 0
    pairs = zip(questions, references)
    for i, (question, reference) in enumerate(tqdm(pairs, total=total, desc="Baseline")):
        answer = extract_actual_answer(reference)
        if answer is None:
            print(f"Skipping question {i} due to missing answer.")
            continue

        response = generate_response(build_cot_prompt(question), temperature=0.0)
        llm_answer = extract_llm_answer(response)
        if llm_answer is None:
            print(f"Skipping question {i} due to missing LLM answer.")
            continue

        if answers_match(answer, llm_answer):
            correct += 1

    return correct / total


def evaluate_self_consistency(samples: int, *, temperature: float = 0.5) -> float:
    """Return the accuracy of self-consistency decoding with majority voting.

    For each question, *samples* responses are drawn at *temperature*. Their
    answers are normalized with ``normalize_answer`` so that equivalent
    spellings ("18", "18.", "$18") vote together. The most common answer is
    compared with the normalized reference. On a tie, ``Counter.most_common``
    keeps first-encountered order, so the earliest-sampled answer wins.

    Questions with no reference answer, or with no usable sample, count as
    incorrect (they stay in the denominator).

    Args:
        samples: Number of sampled reasoning paths per question (>= 1).
        temperature: Sampling temperature for each path.

    Returns:
        Fraction of questions answered correctly (0.0 for an empty dataset).

    Raises:
        ValueError: If *samples* is less than 1.
    """
    if samples < 1:
        raise ValueError(f"samples must be at least 1, got {samples}")

    questions, references = dataset["question"], dataset["answer"]
    total = len(questions)
    if total == 0:
        return 0.0

    correct = 0
    pairs = zip(questions, references)
    progress = tqdm(pairs, total=total, desc=f"Self-Consistency (n={samples})")
    for i, (question, reference) in enumerate(progress):
        answer = extract_actual_answer(reference)
        if answer is None:
            print(f"Skipping question {i} due to missing answer.")
            continue

        prompt = build_cot_prompt(question)  # identical for every sample
        votes = Counter()
        for _ in range(samples):
            response = generate_response(prompt, temperature=temperature)
            llm_answer = extract_llm_answer(response)
            if llm_answer is None:
                print(f"Skipping a sample for question {i} due to missing LLM answer.")
                continue
            votes[normalize_answer(llm_answer)] += 1

        if not votes:
            # Every sample failed; previously this crashed on most_common(1)[0].
            print(f"No usable samples for question {i}; counting it as incorrect.")
            continue

        majority, _ = votes.most_common(1)[0]
        if majority == normalize_answer(answer):
            correct += 1

    return correct / total


def compute_accuracy(count: int) -> tuple[int, float]:
    """Evaluate the dataset using *count* samples per question.

    A count of 1 runs the greedy-decoding baseline. Any other count runs
    self-consistency with that many sampled reasoning paths. The accuracy is
    printed before returning.

    Returns:
        ``(count, accuracy)``, so results gathered from a parallel map can be
        matched back to their sample count.
    """
    accuracy = (
        evaluate_greedy_decoding()
        if count == 1
        else evaluate_self_consistency(samples=count)
    )
    print(f"Samples: {count}, Accuracy: {accuracy:.2f}")
    return count, accuracy


## Run the Evaluation
The code below evaluates the LLM's accuracy on the dataset for several sample counts, running the evaluations in parallel.

With self-consistency, the model answers each question by generating multiple diverse outputs (i.e., reasoning paths). Each output solves the question in its own way and ends with a final answer, and the most common final answer among them is chosen as the LLM's answer to the question.

With greedy decoding, by contrast, the model always picks the highest-probability token, producing a single reasoning path and a single final answer; this is generally less accurate than self-consistency. In the results below, a sample count of 1 corresponds to this greedy baseline.

In [None]:
from tqdm.contrib.concurrent import thread_map

# 1 = greedy-decoding baseline; larger values = self-consistency sample counts.
sample_counts = [1, 3, 5, 10]

# Run the evaluations concurrently. Each one spends most of its time waiting on
# ollama.generate requests, so threads are enough to overlap them. Threads also
# avoid process_map's need to pickle functions defined in this notebook, which
# fails under the "spawn" start method (the default on macOS and Windows).
results = thread_map(
    compute_accuracy,
    sample_counts,
    max_workers=NUM_WORKERS,
    desc="Evaluations",
)

# compute_accuracy returns (count, accuracy) pairs; split them into two series.
counts, accuracies = zip(*results)

# Plotting
plt.figure()
plt.plot(counts, accuracies, marker='o')
plt.title("Accuracy vs. Number of Samples")
plt.xlabel("Number of Samples")
plt.ylabel("Accuracy")
plt.xticks(counts)
plt.grid(True)
plt.savefig('./graph.jpg')
plt.show()