In [None]:
import os

# Hugging Face Token
os.environ["HF_TOKEN"] = ""

In [2]:
!pip install -q transformers accelerate bitsandbytes datasets torch safetensors peft

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
from google.colab import drive
drive.mount('/content/drive')

model_path = "/content/drive/MyDrive/685-Project"

Mounted at /content/drive


In [4]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from datasets import load_dataset
import json
import re
from typing import List, Dict, Tuple
from peft import PeftModel

# ============================================================================
# MODEL LOADING WITH 4-BIT QUANTIZATION
# ============================================================================

def load_models():

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4"
    )

    print("Loading Llama 3.1 8B (Step Generator)...")
    decomp_tokenizer = AutoTokenizer.from_pretrained("unsloth/meta-Llama-3.1-8B-Instruct")
    decomp_model = AutoModelForCausalLM.from_pretrained(
        "unsloth/meta-Llama-3.1-8B-Instruct",
        quantization_config=quantization_config,
        device_map="auto",
        trust_remote_code=True
    )


    print("Loading Llama 3.1 8B (Step Evaluator)...")
    eval_tokenizer = AutoTokenizer.from_pretrained("unsloth/meta-Llama-3.1-8B-Instruct")
    eval_model = AutoModelForCausalLM.from_pretrained(
        "unsloth/meta-Llama-3.1-8B-Instruct",
        quantization_config=quantization_config,
        device_map="auto",
        trust_remote_code=True
    )

    # Some chat models don't have pad_token set; make it eos to avoid warnings
    if decomp_tokenizer.pad_token_id is None:
        decomp_tokenizer.pad_token_id = decomp_tokenizer.eos_token_id
    if eval_tokenizer.pad_token_id is None:
        eval_tokenizer.pad_token_id = eval_tokenizer.eos_token_id

    return decomp_model, decomp_tokenizer, eval_model, eval_tokenizer

# Load models
decomp_model, decomp_tokenizer, eval_model, eval_tokenizer = load_models()

Loading Llama 3.1 8B (Step Generator)...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/454 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/956 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Loading Llama 3.1 8B (Step Evaluator)...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:

def generate_next_step(
    problem: str,
    steps_so_far: List[str],
    decomp_model,
    decomp_tokenizer,
    hint: str = None,
    max_new_tokens: int = 50
) -> str:

    # -------- 1) Clean and normalize steps_so_far into plain step texts --------
    clean_steps = []
    for s in steps_so_far:
        if not s.strip():
            continue

        # Once we see FINAL_ANSWER, the loop stops
        if s.strip().upper().startswith("FINAL_ANSWER"):
            continue

        # Extract content after "STEP:" if present
        step_text = None
        for line_raw in s.splitlines():
            line = line_raw.strip()
            if not line: # Skip empty lines
                continue

            if line.upper().startswith("STEP:"):
                # Use re.split with IGNORECASE to handle variations like "step:", "Step:", "STEP:"
                parts = re.split(r"STEP:", line, 1, re.IGNORECASE)
                if len(parts) > 1:
                    step_text = parts[1].strip()
                    break # Found the first step, break from inner loop
                else:
                    # Fallback for malformed "STEP:" lines (e.g., just "STEP")
                    step_text = line.strip()
                    break
        if step_text is None:
            # fallback: use the raw string if no "STEP:" prefix was found after stripping
            step_text = s.strip()

        clean_steps.append(step_text)

    # -------- 2) Build history string EXACTLY like in the training CSV --------
    if clean_steps:
        context = ""
        for i, h in enumerate(clean_steps, start=1):
            context += f"STEP {i}: {h}\n"
        context = context.rstrip()
    else:
        context = ""

    # -------- 3) Build the input text in the same pattern as training --------
    INPUT_TEXT = (
        f"Problem: {problem}\n\n"
        f"Steps completed so far:\n{context}"
    )

    # If hint is not None append it in context
    if hint:
        INPUT_TEXT += f"\n\nThis last step is wrong. \nYou are given a hint '{hint}'.\nRead the hint carefully and regenerate the last step."

    # -------- 4) Call the model --------
    SYSTEM_PROMPT = """
You are a careful math solver collaborating with other agents.
Your job is to produce the NEXT step in the solution, following STRICT formatting rules.

HARD RULES (you MUST follow these):
- You ALWAYS output exactly ONE step per turn.
- You ALWAYS start your answer with a line that begins with:  STEP: or FINAL_ANSWER:
- If the current step DOES NOT yet fully answer the question, you output ONLY:
    STEP: <single clear step>
- If the current step DOES fully answer the question (the problem is solved),
    you MUST output ONE line:
    FINAL_ANSWER: <single number or very short expression>

YOU MUST NEVER:
- Omit FINAL_ANSWER when the problem is solved.
- Add any extra text outside the STEP: and FINAL_ANSWER: lines.


You MUST treat "When the problem is solved" as:
- The moment when you can compute the final numeric/short answer to the question.
- At that moment, you MUST write FINAL_ANSWER: ... on a separate line.

Look at these examples to understand for to decompose problems into steps:

Example 1:

Problem: Sarah has 12 apples. She buys 8 more and then gives 5 to her friend. How many apples does she have now?

Step 1: Sarah starts with 12 apples.
Step 2: She buys 8 more apples, so now she has 12 + 8 = 20 apples.
STEP 3: Sarah then gives 5 apples to her friend, so she has 20 - 5 = 15 apples left.
FINAL_ANSWER: 15

Example 2:

Problem: A box contains 5 red balls and 7 blue balls. How many balls are in the box?

Step 1: Box has 5 red balls.
Step 2: Box has 7 blue balls.
STEP 3: Box has 5 red balls + 7 blue balls = 12 balls.
FINAL_ANSWER: 12

Example 3:

Problem: A factory produces 24 boxes per hour. Each box contains 6 items. How many items are produced in 3 hours?

Step 1: The factory produces 24 boxes per hour.
Step 2: In 3 hours, the factory produces 24 × 3 = 72 boxes.
STEP 3: Since each box has 6 items, total items = 72 × 6 = 432.
FINAL_ANSWER: 432


Example 4:

Problem: A baker has 96 cookies and packs them equally into boxes of 12. How many boxes can she fill?

Step 1: The baker has 96 cookies.
STEP 2: Number of boxes = 96 ÷ 12 = 8.
FINAL_ANSWER: 8


END OF EXAMPLES

Summary of the rules you MUST follow:
- EXACTLY one step per response.
- Start with: STEP: ... or FINAL_ANSWER: ...
- Start with FINAL_ANSWER: ... when the problem's question can be completely answered now.
- NEVER omit FINAL_ANSWER when the problem is solved.
- NEVER write anything outside STEP: and FINAL_ANSWER: lines.
""".strip()


    messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": INPUT_TEXT},
  ]

    input_ids = decomp_tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(decomp_model.device)

    with torch.no_grad():
        outputs = decomp_model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.1,
            do_sample=False,
            pad_token_id=decomp_tokenizer.pad_token_id,
        )

    response = decomp_tokenizer.decode(
        outputs[0][input_ids.shape[1]:],
        skip_special_tokens=True
    ).strip()

    return response

In [6]:
def evaluate_step(
    problem: str,
    steps_so_far: List[str],
    current_step: str,
    eval_model,
    eval_tokenizer,
) -> Tuple[Dict[str, float], str]:

    SYSTEM_PROMPT = """
You are a strict math-step evaluator. You ONLY judge the CURRENT step,
given the problem and the previous steps.

You must output a single integer score from 0 to 3:

3 = Good step:
    - Mathematically correct.
    - Logically follows from previous steps.
    - Clearly helps move toward the final answer.

2 = Almost good:
    - Mostly correct and helpful.
    - Might have a small issue (slight imprecision or missing minor detail),
      but the step is still usable.

1 = Weak / partially correct:
    - Contains some relevant or correct ideas,
      but has an important mistake or is missing a key part.

0 = Bad step:
    - Wrong math, wrong logic, irrelevant, or seriously misleading.
    - Should NOT be used as part of the final solution.
    - Same content as the previous step

You MUST recompute any numeric calculation in the current step yourself.
If the arithmetic is wrong, the score should usually be 0 or 1.

Look at these examples to understand how to score steps:

EXAMPLE 1:

Problem: Sarah has 12 apples. She buys 8 more and then gives 5 to her friend. How many apples does she have now?

Previous steps:
Step 1: Sarah starts with 12 apples.
Step 2: She buys 8 more apples, so now she has 12 + 8 = 20 apples.

Current step to evaluate:
STEP: Sarah then gives 5 apples to her friend, so she has 20 - 5 = 15 apples left.

Score: 3
Feedback: Good step; it correctly computes the final number of apples.

EXAMPLE 2:

Problem: A box contains 5 red balls and 7 blue balls. How many balls are in the box?

Previous steps:
Step 1: There are 5 red balls.

Current step to evaluate:
STEP: There are 7 blue balls.

Score: 2
Feedback: Correct and relevant but incomplete because it does not compute the total yet.

EXAMPLE 3:

Problem: A box contains 5 red balls and 7 blue balls. How many balls are in the box?

Previous steps:
Step 1: There are 5 red balls.
Step 2: There are 7 blue balls.

Current step to evaluate:
STEP: There are 7 balls total.

Score: 0
Feedback: Incorrect total; it ignores the 5 red balls.

EXAMPLE 4:

Problem: A car travels at 60 mph for 2 hours. How far does it travel?

Previous steps:
(none)

Current step to evaluate:
STEP: The car is moving fast on the road.

Score: 1
Feedback: Vague and missing the numeric calculation of distance.

Remember:
Output EXACTLY TWO lines in this format:

Score: <0 or 1 or 2 or 3>
Feedback: <one short sentence about why the step is good or bad>
""".strip()

    # Build context
    context = f"Problem: {problem}\n\n"
    if steps_so_far:
        context += "Previous steps:\n"
        for i, step in enumerate(steps_so_far, 1):
            # Pretty-print previous steps
            s = step.strip()
            if s.upper().startswith("STEP:"):
                s = s.split("STEP:", 1)[1].strip()
            context += f"Step {i}: {s}\n"
        context += "\n"

    context += f"Current step to evaluate:\n{current_step}\n\n"

    instance_instruction = (
        context
        + "Now give your evaluation.\n"
        + "Remember: only these two lines, nothing more and nothing less:\n"
        + "Score: <0 or 1 or 2 or 3>\n"
        + "Feedback: <one sentence>\n"
    )

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": instance_instruction},
    ]

    input_ids = eval_tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(eval_model.device)

    with torch.no_grad():
        outputs = eval_model.generate(
            input_ids,
            max_new_tokens=64,
            temperature=0.0,
            do_sample=False,
            pad_token_id=eval_tokenizer.pad_token_id,
        )

    response = eval_tokenizer.decode(
        outputs[0][input_ids.shape[1]:],
        skip_special_tokens=True,
    ).strip()

    # ---- Parse score + feedback ----
    raw_score = None
    feedback = ""

    for line in response.splitlines():
        if ":" not in line:
            continue
        key, value = line.split(":", 1)
        key = key.strip().lower()
        value = value.strip()

        if key == "score":
            # grab the first integer in the line
            m = re.search(r"\d+", value)
            if m:
                raw_score = int(m.group(0))
        elif key == "feedback":
            feedback = value

    # Fallbacks
    if raw_score is None:
        raw_score = 1  # conservative mediocre default
    if not feedback:
        feedback = "Step quality is unclear; treating as mediocre."

    scores = {
        "Raw_Score": float(raw_score),
    }

    return scores, feedback

In [7]:
# ============================================================================
# COLLABORATIVE SOLVER
# ============================================================================

def solve_problem_collaboratively(
    problem: str,
    decomp_model,
    decomp_tokenizer,
    eval_model,
    eval_tokenizer,
    max_steps: int = 10,
    score_threshold: float = 2.0,
    max_retries: int = 2
) -> Dict:

    steps: List[str] = []
    all_evaluations: List[Dict] = []
    final_answer = None

    print(f"\n{'='*80}")
    print(f"PROBLEM: {problem}")
    print(f"{'='*80}\n")

    for step_idx in range(1, max_steps + 1):
        for retry in range(max_retries + 1):
            if retry == 0:
                print(f"\n--- Generating Step {step_idx} ---")
                hint = None
            else:
                print(f"\n--- Re-Generating Step {step_idx} (retry {retry}) ---")
                # use last evaluation's feedback as hint
                hint = all_evaluations[-1].get("feedback", None) if all_evaluations else None

            # Generate step from decomposer
            current_step = generate_next_step(
                problem,
                steps,
                decomp_model,
                decomp_tokenizer,
                hint=hint
            )

            print(current_step)
            print()

            # If the model outputs FINAL_ANSWER directly, accept and stop
            if current_step.strip().upper().startswith("FINAL_ANSWER"):
                match = re.search(r"FINAL_ANSWER\s*:\s*(.+)", current_step, re.IGNORECASE)
                if match:
                    final_answer = match.group(1).strip()
                steps.append(current_step)
                print(f"✓ Final answer reached: {final_answer}")
                # Break out of both loops
                return {
                    "problem": problem,
                    "final_steps": steps,
                    "all_evaluations": all_evaluations,
                    "num_steps": len(steps),
                    "final_answer": final_answer,
                }

            # --- Evaluate current step ---
            print("--- Evaluating Step ---")
            scores, feedback = evaluate_step(
                problem,
                steps,
                current_step,
                eval_model,
                eval_tokenizer
            )

            raw_score = scores.get("Raw_Score", None)
            if raw_score is None:
                raw_score = 1.0  # conservative default if evaluator glitches

            # Clamp to [0,3] to be safe
            raw_score = max(0.0, min(float(raw_score), 3.0))

            print(f"Raw score: {raw_score:.1f}/3")
            print(f"Feedback: {feedback}")

            evaluation = {
                "step_number": step_idx,
                "step": current_step,
                "scores": scores,
                "feedback": feedback,
                "retry_attempt": retry,
            }
            all_evaluations.append(evaluation)

            # --- Accept / Retry logic based on Raw_Score ---
            if raw_score >= score_threshold:
                # Accept step
                steps.append(current_step)
                print("✓ Step accepted!")

                # Check for FINAL_ANSWER in an accepted step
                match = re.search(r"FINAL_ANSWER\s*:\s*(.+)", current_step, re.IGNORECASE)
                if match:
                    final_answer = match.group(1).strip()
                    print(f"\n✓ Final answer reached: {final_answer}")
                    return {
                        "problem": problem,
                        "final_steps": steps,
                        "all_evaluations": all_evaluations,
                        "num_steps": len(steps),
                        "final_answer": final_answer,
                    }

                # Break out of retry loop, move to next logical step
                break

            else:
                # raw_score < threshold → bad/weak step
                if retry >= max_retries:
                    print("⚠ Max retries reached, accepting step and continuing...")
                    steps.append(current_step)
                    # Move on to next logical step
                    break
                else:
                    print(f"✗ Step needs improvement (retry {retry + 1}/{max_retries})")
                    # retry loop continues

        # end inner retry loop

    # If we exit the outer loop without hitting FINAL_ANSWER
    print("\n⚠ Reached max_steps without a FINAL_ANSWER.")
    return {
        "problem": problem,
        "final_steps": steps,
        "all_evaluations": all_evaluations,
        "num_steps": len(steps),
        "final_answer": final_answer,
    }

In [8]:
# ============================================================================
# MAIN EXECUTION
# ============================================================================

# Test on a single problem first
test_problem = (
    "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning "
    "and bakes muffins for her friends every day with four. She sells the remainder "
    "at the farmers' market daily for $2 per fresh duck egg. How much in dollars "
    "does she make every day at the farmers' market?"
)

result = solve_problem_collaboratively(
    test_problem,
    decomp_model,
    decomp_tokenizer,
    eval_model,
    eval_tokenizer
)

print("\n\n" + "="*80)
print("FINAL SOLUTION SUMMARY")
print("="*80)

for i, step in enumerate(result['final_steps'], 1):
    print(f"\nStep {i}: {step}")

print(f"\nParsed Final Answer: {result['final_answer']}")

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



PROBLEM: Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?


--- Generating Step 1 ---
STEP: Janet lays 16 eggs per day, eats 3, and bakes 4, so she has 16 - 3 - 4 = 9 eggs left.

--- Evaluating Step ---
Raw score: 3.0/3
Feedback: Good step; it correctly computes the number of eggs left after Janet's daily activities.
✓ Step accepted!

--- Generating Step 2 ---
STEP 2: Janet sells the 9 eggs she has left at $2 per egg, so she makes 9 × $2 = $18.

--- Evaluating Step ---
Raw score: 3.0/3
Feedback: Good step; it correctly computes the daily earnings from selling eggs.
✓ Step accepted!

--- Generating Step 3 ---
STEP 3: Janet makes $18 every day at the farmers' market.

--- Evaluating Step ---
Raw score: 3.0/3
Feedback: Good step; it correctly computes the dail

In [None]:
import re
from datasets import load_dataset

# Helper: extract gold final answer from GSM8K "answer" field
def extract_gsm8k_final_answer(answer_text: str) -> str | None:
    m = re.search(r"####\s*(.+)", answer_text)
    if not m:
        return None
    return m.group(1).strip()

# Helper: normalize numeric answers for comparison
def normalize_numeric_answer(ans: str | None) -> str | None:
    if ans is None:
        return None

    text = ans.replace(",", "").replace("$", "").strip()
    m = re.search(r"-?\d+(\.\d+)?", text)
    if not m:
        return text if text else None
    return m.group(0)


# ============================================================================
# GSM8K TESTING
# ============================================================================

def test_on_gsm8k(
    decomp_model,
    decomp_tokenizer,
    eval_model,
    eval_tokenizer,
    num_samples: int = 5
):
    print("Loading GSM8K dataset...")
    # Use whatever slice you like; this is your current choice
    dataset = load_dataset("gsm8k", "main", split=f"test[:{num_samples}]")

    results = []
    num_correct = 0
    num_total = 0

    for idx in range(min(num_samples, len(dataset))):
        problem = dataset[idx]["question"]
        ground_truth = dataset[idx]["answer"]

        result = solve_problem_collaboratively(
            problem,
            decomp_model,
            decomp_tokenizer,
            eval_model,
            eval_tokenizer
        )

        result["ground_truth"] = ground_truth
        results.append(result)

        model_final = result.get("final_answer", None)
        gold_final = extract_gsm8k_final_answer(ground_truth)

        norm_model = normalize_numeric_answer(model_final)
        norm_gold = normalize_numeric_answer(gold_final)

        is_correct = (norm_model is not None
                      and norm_gold is not None
                      and norm_model == norm_gold)

        num_total += 1
        if is_correct:
            num_correct += 1

        print(f"\n{'='*80}")
        print(f"Completed problem {idx+1}/{num_samples}")
        print(f"Model final answer: {model_final}  (normalized: {norm_model})")
        print(f"Ground truth: {gold_final}        (normalized: {norm_gold})")
        print(f"Correct? {is_correct}")
        print(f"{'='*80}\n")

        print("\n\n" + "="*80)
        print("FINAL SOLUTION SUMMARY")
        print("="*80)
        for j, step in enumerate(result["final_steps"], 1):
            print(f"\nStep {j}: {step}")

    # ---- After all problems, compute and print accuracy ----
    if num_total > 0:
        accuracy = num_correct / num_total
    else:
        accuracy = 0.0

    print("\n" + "="*80)
    print(f"Evaluated {num_total} GSM8K problems.")
    print(f"Correct: {num_correct}")
    print(f"Accuracy: {accuracy * 100:.2f}%")
    print("="*80 + "\n")

In [10]:
test_on_gsm8k(
    decomp_model,
    decomp_tokenizer,
    eval_model,
    eval_tokenizer,
    num_samples=50
)

Loading GSM8K dataset...


README.md: 0.00B [00:00, ?B/s]

main/train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

main/test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m

Completed problem 22/50
Model final answer: None  (normalized: None)
Ground truth: 14        (normalized: 14)
Correct? False



FINAL SOLUTION SUMMARY

Step 1: STEP: Since Samantha is now 31, and Raymond was born 6 years before Samantha, Raymond's current age is 31 + 6 = 37 years.

STEP: Raymond had a son at the age of 23, so the son

Step 2: STEP: Raymond had a son at the age of 23, so the time elapsed since the son's birth is 37 - 23 = 14 years.

Step 3: STEP: Since Raymond had a son at the age of 23, and Samantha is now 31, we need to find the difference in their ages to determine how many years ago the son was born.
STEP: Raymond's current age is 31

Step 4: STEP 4: The difference in their ages is 31 - 23 = 8 years, but we need to find the difference between Samantha's current age and Raymond's age when the son was born, which is 31 - (37 -

Step 5: STEP 5: The difference in their ages is 31 - 23 = 8 years, but we ne