In [None]:
import os

# Hugging Face authentication.
# Do NOT hard-code a token in this cell: it would be saved inside the .ipynb
# and leak with the notebook. Assigning an empty string is harmful too,
# because it overwrites any HF_TOKEN already exported in the environment.
# huggingface_hub reads HF_TOKEN from the environment and, on Colab, from a
# notebook secret of the same name, so nothing needs to be assigned here.

In [None]:
!pip install -q transformers accelerate bitsandbytes datasets torch safetensors peft

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m46.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# Mount Google Drive: the fine-tuned adapter checkpoint loaded later lives under it.
from google.colab import drive

drive.mount("/content/drive")

# Project folder on the mounted Drive.
# NOTE(review): not referenced by the cells visible here -- load_models spells
# out the full adapter path itself.
model_path = "/content/drive/MyDrive/685-Project"

Mounted at /content/drive


In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from datasets import load_dataset
import json
import re
from typing import List, Dict, Tuple
from peft import PeftModel

# ============================================================================
# MODEL LOADING WITH 4-BIT QUANTIZATION
# ============================================================================

def load_models(
    base_model_name: str = "unsloth/meta-Llama-3.1-8B-Instruct",
    adapter_path: str = "/content/drive/MyDrive/685-Project/SFT-Decomposer/checkpoint-2138",
    trust_remote_code: bool = False,
):
    """Load the 4-bit quantized step-generator model with its PEFT adapter.

    Args:
        base_model_name: Hub id (or local path) of the base causal LM and tokenizer.
        adapter_path: Directory of the fine-tuned PEFT adapter checkpoint to
            attach on top of the base model.
        trust_remote_code: Forwarded to ``from_pretrained``. Off by default:
            it allows executing Python shipped with the model repo, and a
            model whose architecture transformers supports natively does not need it.

    Returns:
        ``(decomp_model, decomp_tokenizer)``: the adapter-wrapped model in eval
        mode and its tokenizer.
    """
    # NF4 4-bit weights with double quantization; matmuls computed in fp16.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )

    print(f"Loading {base_model_name} (Step Generator)...")
    decomp_tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    # generate() needs a pad id; reuse EOS when the tokenizer defines no pad token.
    if decomp_tokenizer.pad_token_id is None:
        decomp_tokenizer.pad_token = decomp_tokenizer.eos_token

    decomp_base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        quantization_config=quantization_config,
        device_map="auto",
        trust_remote_code=trust_remote_code,
    )

    decomp_model = PeftModel.from_pretrained(decomp_base_model, adapter_path)
    decomp_model.eval()  # inference only

    return decomp_model, decomp_tokenizer

# Load models
decomp_model, decomp_tokenizer = load_models()

Loading Llama 3.1 8B (Step Generator)...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/454 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/956 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Loading Llama 3.1 8B (Step Evaluator)...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
def generate_next_step(
    problem: str,
    steps_so_far: List[str],
    decomp_model,
    decomp_tokenizer,
    hint: str = None,
    rejected_step: str = None,
    max_new_tokens: int = 50,
) -> str:
    # -------- 1) Clean and normalize steps_so_far into plain step texts --------
    clean_steps = []
    for s in steps_so_far:
        if not s.strip():
            continue

        # Never include FINAL_ANSWER in the history
        if s.strip().upper().startswith("FINAL_ANSWER"):
            continue

        # Extract content after "STEP:" if present
        step_text = None
        for line in s.splitlines():
            line = line.strip()
            if line.upper().startswith("STEP:"):
                step_text = line.split("STEP:", 1)[1].strip()
                break
        if step_text is None:
            step_text = s.strip()

        clean_steps.append(step_text)

    # -------- 2) Build history string EXACTLY like in the training CSV --------
    if clean_steps:
        context = ""
        for i, h in enumerate(clean_steps, start=1):
            context += f"STEP {i}: {h}\n"
        context = context.rstrip()
    else:
        context = ""

    # -------- 3) Build the input text in the same pattern as training + retry info --------
    INPUT_TEXT = (
        f"Problem: {problem}\n\n"
        f"Steps completed so far:\n{context}"
    )

    # If retrying a specific (rejected) step, show it explicitly
    if rejected_step:
        INPUT_TEXT += f"\n\nLast attempted step (INCORRECT):\n{rejected_step}"

    if hint:
        INPUT_TEXT += (
            f"\n\nHint from evaluator: {hint}\n"
            "Read the hint carefully and rewrite ONLY that incorrect step."
            "Do not change the previous correct steps."
        )

    SYSTEM_PROMPT = """
You are a careful stepwise math solver. Read and understand the problem carefully.
When a last attempted step is marked INCORRECT with a hint, you must generate a corrected version of that step.
""".strip()

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": INPUT_TEXT},
    ]

    input_ids = decomp_tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(decomp_model.device)

    with torch.no_grad():
        outputs = decomp_model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.1,
            do_sample=False,
            pad_token_id=decomp_tokenizer.pad_token_id,
        )

    response = decomp_tokenizer.decode(
        outputs[0][input_ids.shape[1]:],
        skip_special_tokens=True,
    ).strip()

    return response

In [None]:
import re
from datasets import load_dataset

# Helper: extract gold final answer from GSM8K "answer" field
def extract_gsm8k_final_answer(answer_text: str) -> str | None:
    m = re.search(r"####\s*(.+)", answer_text)
    if not m:
        return None
    return m.group(1).strip()

# Helper: normalize numeric answers for comparison
def normalize_numeric_answer(ans: str | None) -> str | None:
    if ans is None:
        return None

    text = ans.replace(",", "").replace("$", "").strip()
    m = re.search(r"-?\d+(\.\d+)?", text)
    if not m:
        return text if text else None
    return m.group(0)

In [None]:
def solve_problem_decomposer_only(
    problem: str,
    decomp_model,
    decomp_tokenizer,
    max_steps: int = 10,
) -> Dict:
    """Solve ``problem`` step by step with the decomposer alone (no evaluator).

    The function calls ``generate_next_step`` repeatedly and appends every raw
    output to the history. It stops when an output starts with ``FINAL_ANSWER``
    (case-insensitive) or when ``max_steps`` outputs have been generated.

    Returns:
        dict with keys
        - ``problem``: the input problem,
        - ``final_steps``: every raw model output, including the FINAL_ANSWER one,
        - ``all_evaluations``: always ``[]`` here (presumably kept so the result
          has the same keys as an evaluator-based solver -- confirm),
        - ``num_steps``: ``len(final_steps)``,
        - ``final_answer``: text after ``FINAL_ANSWER:``, or ``None`` if no
          parsable final answer was produced.
    """
    steps: List[str] = []
    final_answer = None
    reached_final = False  # True once the model emits a FINAL_ANSWER line, parsable or not

    print(f"\n{'='*80}")
    print(f"PROBLEM: {problem}")
    print(f"{'='*80}\n")

    for step_idx in range(1, max_steps + 1):
        print(f"\n--- Generating Step {step_idx} ---")

        # No hint (since no evaluator)
        current_step = generate_next_step(
            problem,
            steps,
            decomp_model,
            decomp_tokenizer,
            hint=None
        )

        print(current_step)
        print()

        # Every output, final or not, becomes part of the recorded solution.
        steps.append(current_step)

        # If the model outputs FINAL_ANSWER directly, accept and stop
        if current_step.strip().upper().startswith("FINAL_ANSWER"):
            reached_final = True
            match = re.search(r"FINAL_ANSWER\s*:\s*(.+)", current_step, re.IGNORECASE)
            if match:
                final_answer = match.group(1).strip()
                print(f"✓ Final answer reached: {final_answer}")
            else:
                print("⚠ FINAL_ANSWER emitted but no value could be parsed.")
            break

    # Only warn about the step budget when the loop actually ran out of steps.
    if not reached_final:
        print("\n⚠ Reached max_steps without a FINAL_ANSWER.")

    return {
        "problem": problem,
        "final_steps": steps,
        "all_evaluations": [],
        "num_steps": len(steps),
        "final_answer": final_answer,
    }

In [None]:
def test_on_gsm8k_decomp_only(
    decomp_model,
    decomp_tokenizer,
    num_samples: int = 5
):
    print("Loading GSM8K dataset...")
    dataset = load_dataset("gsm8k", "main", split=f"test[:{num_samples}]")

    results = []
    num_correct = 0
    num_total = 0

    for idx in range(min(num_samples, len(dataset))):
        problem = dataset[idx]["question"]
        ground_truth = dataset[idx]["answer"]

        result = solve_problem_decomposer_only(
            problem,
            decomp_model,
            decomp_tokenizer,
        )

        result["ground_truth"] = ground_truth
        results.append(result)

        model_final = result.get("final_answer", None)
        gold_final = extract_gsm8k_final_answer(ground_truth)

        norm_model = normalize_numeric_answer(model_final)
        norm_gold = normalize_numeric_answer(gold_final)

        is_correct = (norm_model is not None
                      and norm_gold is not None
                      and norm_model == norm_gold)

        num_total += 1
        if is_correct:
            num_correct += 1

        print(f"\n{'='*80}")
        print(f"Completed problem {idx+1}/{num_samples}")
        print(f"Model final answer: {model_final}  (normalized: {norm_model})")
        print(f"Ground truth: {gold_final}        (normalized: {norm_gold})")
        print(f"Correct? {is_correct}")
        print(f"{'='*80}\n")

        print("\n\n" + "="*80)
        print("FINAL SOLUTION SUMMARY")
        print("="*80)
        for j, step in enumerate(result["final_steps"], 1):
            print(f"\nStep {j}: {step}")

    # ---- After all problems, compute and print accuracy ----
    if num_total > 0:
        accuracy = num_correct / num_total
    else:
        accuracy = 0.0

    print("\n" + "="*80)
    print(f"Evaluated {num_total} GSM8K problems.")
    print(f"Correct: {num_correct}")
    print(f"Accuracy: {accuracy * 100:.2f}%")
    print("="*80 + "\n")

In [None]:
# Bind the return value so later cells can inspect it and the notebook does
# not echo it as this cell's Out[].
gsm8k_eval = test_on_gsm8k_decomp_only(
    decomp_model,
    decomp_tokenizer,
    num_samples=50,
)

Loading GSM8K dataset...

PROBLEM: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?


--- Generating Step 1 ---
STEP: First find the number of eggs Janet uses for breakfast and baking: 3 eggs + 4 eggs = <<3+4=7>>7 eggs


--- Generating Step 2 ---
STEP: Then subtract that number from the number of eggs laid to find the number of eggs sold: 16 eggs - 7 eggs = <<16-7=9>>9 eggs


--- Generating Step 3 ---
STEP: Then multiply the number of eggs sold by the price per egg to find Janet's daily earnings: 9 eggs * $2/egg = $<<9*2=18>>18


--- Generating Step 4 ---
FINAL_ANSWER: 18

✓ Final answer reached: 18

Completed problem 1/50
Model final answer: 18  (normalized: 18)
Ground truth: 18        (normalized: 18)
Correct? True



FINAL SOLUTION SUMMARY

Step 1: STEP: 