In [1]:
!pip -q install -U transformers accelerate bitsandbytes sentencepiece

import re
import torch
from dataclasses import dataclass
from collections import Counter
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig



[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m20.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
# ---- Model / tokenizer setup -------------------------------------------------
MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"   # strong + usually fits on free-tier T4 in fp16
# If you want a "base" (less instruction-tuned) variant instead:
# MODEL_ID = "Qwen/Qwen2.5-3B"

USE_4BIT = False  # set True if you are hitting VRAM limits and colab runtime crashes
# Half precision on GPU, full precision when only a CPU is available.
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

# Quantisation config is only built when 4-bit loading is requested.
bnb_cfg = None
if USE_4BIT:
    bnb_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
# generate() is later called with pad_token_id=tokenizer.pad_token_id, so make
# sure one exists; fall back to the EOS token when the tokenizer defines none.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# `dtype` replaces the deprecated `torch_dtype` keyword (the library prints a
# deprecation warning for the old name). No explicit dtype is passed in 4-bit mode.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    dtype=(DTYPE if not USE_4BIT else None),
    quantization_config=bnb_cfg,
).eval()

# Device of the first parameter; tokenised prompts are moved here before generation.
DEVICE = next(model.parameters()).device
print("Device:", DEVICE, "| 4-bit:", USE_4BIT)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/661 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/3.97G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

Device: cuda:0 | 4-bit: False


## Self-Consistency core
### (sample m CoT paths -> parse answers -> majority vote)

In [3]:
# Case-insensitive marker that introduces the final answer; any whitespace
# directly after the marker (newlines included) is consumed by the match.
ANSWER_IS = re.compile(r"The answer is\s*", flags=re.IGNORECASE)

# Optionally negative integer or decimal number, e.g. "42", "-3", "3.75".
NUMBER = re.compile(r"-?\d+(?:\.\d+)?")

In [4]:
def trim_after_first_answer(text: str) -> str:
    """Cut a sampled completion off after its first final-answer statement.

    Args:
        text: Raw completion text.

    Returns:
        The stripped text, truncated as follows:
        * no "The answer is" marker: the whole text;
        * a line starting with "Q:" follows the marker: everything before it
          (drops a follow-up question the model kept generating);
        * otherwise: everything up to the end of the line that holds the answer.
    """
    m = ANSWER_IS.search(text)
    if m is None:
        return text.strip()

    next_question = text.find("\nQ:", m.start())
    if next_question != -1:
        return text[:next_question].strip()

    # ANSWER_IS also consumes whitespace after the marker, newlines included.
    # Look for the line end from m.end() so that an answer written on the line
    # after the marker ("The answer is\n42.") is kept instead of being cut off.
    line_end = text.find("\n", m.end())
    if line_end != -1:
        return text[:line_end].strip()

    return text.strip()

In [5]:
@dataclass(frozen=True)
class CoTExample:
    """One few-shot chain-of-thought demonstration that build_prompt renders into the prompt."""
    q: str  # question text, rendered as "Q: <q>"
    r: str  # worked reasoning, rendered right after "A:"
    a: str  # final answer, rendered as "The answer is <a>."

@dataclass(frozen=True)
class GenCfg:
    m: int = 40
    batch_size: int = 5
    temperature: float = 0.7
    top_k: int = 40
    top_p: float = 1.0
    max_new_tokens: int = 256
    seed: int | None = None

def build_prompt(examples: list[CoTExample], question: str) -> str:
    """Assemble the few-shot chain-of-thought prompt for `question`.

    Layout: one instruction line, a blank line, then for every example a
    "Q:" line, an "A:" line (reasoning followed by "The answer is <a>.") and a
    blank line, and finally the target question with an open "A:".
    """
    instruction = 'Solve step-by-step and end with exactly: "The answer is <final answer>."'
    shot_lines = [
        line
        for shot in examples
        for line in (f"Q: {shot.q}", f"A: {shot.r} The answer is {shot.a}.", "")
    ]
    return "\n".join([instruction, "", *shot_lines, f"Q: {question}", "A:"])


### We implement two answer parsers, one for arithmetic tasks and one for commonsense (free-text) tasks, as the original paper does

In [6]:
def parse_arithmetic_first(text: str) -> str | None:
    m = ANSWER_IS.search(text)
    if not m:
        return None
    tail = text[m.end():]
    n = NUMBER.search(tail.replace(",", ""))
    return n.group(0) if n else None

def parse_commonsense_first(text: str) -> str | None:
    m = ANSWER_IS.search(text)
    if not m:
        return None
    tail = text[m.end():]
    s = tail.strip().splitlines()[0].strip().rstrip(".").strip()
    return s or None

In [7]:
def majority_vote(parsed: list[str | None]) -> tuple[str | None, dict[str, int]]:
    xs = [x for x in parsed if x is not None]
    if not xs:
        return None, {}
    c = Counter(xs)
    best = max(c.values())
    tied = {k for k, v in c.items() if v == best}
    for x in xs:
        if x in tied:
            return x, dict(c)
    return c.most_common(1)[0][0], dict(c)

In [8]:
def generate_many(prompt: str, cfg: GenCfg) -> list[str]:
    """Sample `cfg.m` completions of `prompt` from the module-level `model`.

    Sampling runs in chunks of at most `cfg.batch_size` sequences per
    `model.generate` call.

    Args:
        prompt: Full prompt text; it is tokenised once and reused for every chunk.
        cfg: Sampling settings (sample count, batch size, temperature, top-k,
            top-p, max new tokens, optional seed).

    Returns:
        The decoded completions (special tokens skipped, whitespace stripped),
        one per sampled sequence.
    """
    # Seed once per call: two calls with the same seed start from the same RNG state.
    if cfg.seed is not None:
        torch.manual_seed(cfg.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(cfg.seed)

    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    # Prompt length in tokens; used below to slice the prompt off each output row.
    input_len = inputs.input_ids.shape[1]

    texts: list[str] = []
    remaining = cfg.m

    # NOTE: assumes cfg.batch_size >= 1; with a non-positive value `remaining`
    # would never decrease.
    while remaining > 0:
        b = min(cfg.batch_size, remaining)
        with torch.inference_mode():
            out = model.generate(
                **inputs,
                do_sample=True,
                temperature=cfg.temperature,
                top_k=cfg.top_k,
                top_p=cfg.top_p,
                num_return_sequences=b,
                max_new_tokens=cfg.max_new_tokens,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Drop the first `input_len` positions of every row (presumably the echoed
        # prompt, assuming generate() returns prompt + continuation) and move the
        # rest to the CPU before decoding.
        new_tokens = out[:, input_len:].detach().cpu()
        texts.extend(tokenizer.batch_decode(new_tokens, skip_special_tokens=True))

        # Release this chunk's tensors before sampling the next one.
        del out, new_tokens
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        remaining -= b

    return [t.strip() for t in texts]

In [9]:
def self_consistency(
    question: str,
    *,
    examples: list[CoTExample],
    parser,
    cfg: GenCfg,
) -> tuple[str | None, dict[str, int], list[str], list[str | None]]:
    """Answer `question` by sampling several reasoning paths and voting.

    Args:
        question: The question to answer.
        examples: Few-shot demonstrations placed in front of the question.
        parser: Callable mapping one trimmed completion to an answer string,
            or None when no answer can be extracted.
        cfg: Sampling settings passed on to `generate_many`.

    Returns:
        (final, counts, comps, parsed): the majority answer, the votes per
        answer, the trimmed completions, and the parsed answer per completion.
    """
    raw_completions = generate_many(build_prompt(examples, question), cfg)
    comps = list(map(trim_after_first_answer, raw_completions))
    parsed = list(map(parser, comps))
    final, counts = majority_vote(parsed)
    return final, counts, comps, parsed

### Run tests


### Sample few-shot CoT examples for each task

In [10]:
# Few-shot demonstrations for arithmetic questions; every answer `a` is a bare number.
ARITH_EX = [
    CoTExample(
        q="There are 3 cars in the lot. 2 more arrive. How many cars are there now?",
        r="3 + 2 = 5.",
        a="5",
    ),
    CoTExample(
        q="A shop sells pens for £2 each. You buy 4 pens and pay with £10. How much change do you get?",
        r="4 * 2 = 8. 10 - 8 = 2.",
        a="2",
    ),
]

# Few-shot demonstrations for commonsense questions; answers are short lowercase phrases.
COMMON_EX = [
    CoTExample(
        q="If it is raining outside, what should you take to stay dry?",
        r="An umbrella keeps rain off you.",
        a="umbrella",
    ),
    CoTExample(
        q="Why do people wear sunscreen at the beach?",
        r="Sunscreen protects skin from UV rays.",
        a="to protect their skin from uv rays",
    ),
]

# Task name -> few-shot examples for that task.
EXAMPLES_BY_TASK = {
    "arithmetic": ARITH_EX,
    "commonsense": COMMON_EX,
}


## Test Prompts

In [11]:
# Choose the task category ("arithmetic" or "commonsense") and the question to ask.
# The category decides which few-shot examples and which answer parser are used.
TASK = "arithmetic"

question = "How many minutes are in 3.75 hours?"

In [13]:
# config
# A fixed seed makes the sampled reasoning paths (and therefore the vote counts)
# repeatable across runs; use seed=None to draw fresh samples on every run.
cfg = GenCfg(m=40, temperature=0.7, seed=0)

In [14]:
# Resolve the few-shot examples and the answer parser for the selected task.
# A lookup (rather than an if/else) makes a mistyped TASK fail loudly instead of
# silently being treated as "commonsense".
PARSERS_BY_TASK = {
    "arithmetic": parse_arithmetic_first,
    "commonsense": parse_commonsense_first,
}
if TASK not in EXAMPLES_BY_TASK or TASK not in PARSERS_BY_TASK:
    raise ValueError(f"Unknown TASK {TASK!r}; expected one of {sorted(PARSERS_BY_TASK)}")
examples = EXAMPLES_BY_TASK[TASK]
parser = PARSERS_BY_TASK[TASK]

final, counts, comps, parsed = self_consistency(
    question,
    examples=examples,
    parser=parser,
    cfg=cfg,
)

# Rank answers by vote count (descending), then alphabetically, and report how
# far the leader is ahead of the runner-up.
items = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
top = items[0] if items else (None, 0)
second = items[1] if len(items) > 1 else (None, 0)
margin = top[1] - second[1]

print(f"TASK: {TASK}")
print(f"Q: {question}")
print(f"FINAL: {final} | top_count={top[1]} second={second[1]} margin={margin}")
print("Top votes:", items[:10])

# show 3 paths total + 2 winner paths
print("\n--- first 3 samples ---")
for i in range(min(3, len(comps))):
    print(f"\n[{i+1}] parsed={parsed[i]}\n{comps[i]}")

print("\n--- 2 winner samples ---")
shown = 0
if final is None:
    # Without this guard every unparsed sample (parsed == None) would match
    # `final` and be printed as a "winner".
    print("(no sample produced a parseable answer)")
else:
    for c, p in zip(comps, parsed):
        if p == final:
            shown += 1
            print(f"\n[w{shown}] parsed={p}\n{c}")
            if shown == 2:
                break


TASK: arithmetic
Q: How many minutes are in 3.75 hours?
FINAL: 225 | top_count=38 second=1 margin=37
Top votes: [('225', 38), ('270', 1)]

--- first 3 samples ---

[1] parsed=225
60 * 3.75 = 225. The answer is 225.

[2] parsed=225
60 * 3.75 = 225. The answer is 225.

[3] parsed=225
1 hour = 60 minutes
3.75 hours = 3.75 * 60 = 225 minutes The answer is 225.

--- 2 winner samples ---

[w1] parsed=225
60 * 3.75 = 225. The answer is 225.

[w2] parsed=225
60 * 3.75 = 225. The answer is 225.


In [15]:
# =========================
# QUICK SWITCH: arithmetic tests (pick ONE line, set TASK="arithmetic")
# Uncomment exactly one `question` line, run this cell, then re-run the cell
# that calls self_consistency(...) to evaluate the new question.
# =========================
TASK = "arithmetic"
question = "An item costs £80. It is discounted by 15%, then an extra £7 is taken off. What is the final price?"
# question = "A car travels 315 km in 4.5 hours. What is its average speed in km/h?"
# question = "You have 2.4 liters of juice. You pour out 0.35 liters, then add 0.9 liters. How many liters do you have now?"
# question = "A box has 96 chocolates. You give away 3/8 of them. How many chocolates remain?"
# question = "How many minutes are in 3.75 hours?"
# question = "A worker earns £14/hour for 6 hours and £21/hour for 2 overtime hours. What is the total pay?"
# question = "The average of 5 numbers is 12. The sum of 4 of them is 38. What is the fifth number?"
# question = "A store has 2,400 items. It sells 12.5% of them. How many items remain?"
# question = "If you round 19.95 to the nearest whole number, what do you get?"


In [16]:
# =========================
# QUICK SWITCH: commonsense tests (pick ONE line, set TASK="commonsense")
# Uncomment exactly one `question` line, run this cell, then re-run the cell
# that calls self_consistency(...) to evaluate the new question.
# NOTE: this cell overwrites TASK and question set by any earlier cell.
# =========================
TASK = "commonsense"
question = "What tool do you typically use to tighten a screw?"
# question = "Where would you store ice cream to keep it from melting?"
# question = "If you want to cut a piece of cardboard, what tool is most appropriate?"
# question = "Why do people wear a seatbelt in a car?"
# question = "What do you use to write on a whiteboard?"
# question = "If your phone battery is low, what should you connect it to?"
# question = "What appliance do you use to heat up leftover soup quickly?"
# question = "If you spill water on the floor, what should you use to clean it up?"
# question = "What do you use to unlock a standard door lock?"
# question = "If it’s very bright outside, what might you wear to protect your eyes?"
