In [None]:
# Cell 1: Installations
# transformers is pinned to the version reported in the Unsloth banner printed
# by the model-loading cell (Transformers: 4.53.2).
!pip install -U transformers==4.53.2
# NOTE(review): accelerate and bitsandbytes are unpinned here and are installed
# again (with --no-deps) by the following cell — consider pinning both.
!pip install -U accelerate
!pip install -U bitsandbytes

# !pip install flash-attn==2.7.4.post1 \
#   --extra-index-url https://download.pytorch.org/whl/cu124 \
#   --no-build-isolation



In [None]:
# Unsloth and its companion packages. --no-deps stops pip from pulling in each
# package's own dependency set; presumably this protects the versions already
# installed above — TODO confirm. Only xformers is version-pinned here.
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton unsloth_zoo
!pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
!pip install --no-deps unsloth



In [None]:
import unsloth
from unsloth import FastModel
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
from huggingface_hub import hf_hub_download
import json
import re
import math

print("✅ All modules imported.")

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
✅ All modules imported.


In [None]:
# ===================================================================
# 1. DEFINE CUSTOM CLASSIFIER (Required for Phi-4)
# ===================================================================
class GPTSequenceClassifier(nn.Module):
    """Sequence classifier made of a causal-LM backbone and a linear head.

    The backbone is asked for its hidden states; the final layer's vector at
    the last sequence position is fed to a linear layer that produces one
    logit per label.

    Attributes:
        base: The wrapped backbone; must expose ``config.hidden_size`` and
            return an object with a ``hidden_states`` sequence.
        classifier: ``nn.Linear(hidden_size, num_labels)`` head.
        num_labels: Number of output classes.
    """

    def __init__(self, base_model, num_labels):
        super().__init__()
        self.base = base_model
        self.num_labels = num_labels
        self.classifier = nn.Linear(base_model.config.hidden_size, num_labels, bias=True)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the backbone and classify from the last position.

        Args:
            input_ids: Token ids forwarded to the backbone.
            attention_mask: Attention mask forwarded to the backbone.
            labels: Optional class indices; when given, a cross-entropy loss
                is computed against the logits.
            **kwargs: Extra keyword arguments passed through to the backbone.

        Returns:
            ``{"logits": ...}``, or ``{"loss": ..., "logits": ...}`` when
            ``labels`` is provided.
        """
        backbone_out = self.base(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **kwargs,
        )
        # Last layer, last sequence position. With right-padded batches this
        # position would be a pad token, so callers are expected to pad left.
        final_position_state = backbone_out.hidden_states[-1][:, -1, :]
        logits = self.classifier(final_position_state)

        if labels is None:
            return {"logits": logits}

        loss = nn.functional.cross_entropy(
            logits.view(-1, self.num_labels),
            labels.view(-1),
        )
        return {"loss": loss, "logits": logits}

In [None]:
# ===================================================================
# 2. LOAD MODELS AND TOKENIZERS
# ===================================================================
# Single global device string; also read by run_conceptual_check and
# run_computational_check when they move tokenized inputs.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# --- Model 1: Equation Extractor (Gemma-3 with Unsloth) ---
print("\nLoading Equation Extraction Model...")
extractor_adapter_repo = "arvindsuresh-math/gemma-3-1b-equation-line-extractor-aug-10"
base_gemma_model = "unsloth/gemma-3-1b-it-unsloth-bnb-4bit"

# Load the 4-bit base checkpoint through Unsloth, then attach the fine-tuned
# PEFT adapter on top of it.
gemma_model, gemma_tokenizer = FastModel.from_pretrained(
    model_name=base_gemma_model,
    max_seq_length=350,
    dtype=None,
    load_in_4bit=True,
)
gemma_model = PeftModel.from_pretrained(gemma_model, extractor_adapter_repo)
print("✅ Equation Extraction Model loaded.")


# --- Model 2: Conceptual Error Classifier (Phi-4) ---
print("\nLoading Conceptual Error Classifier Model...")
classifier_adapter_repo = "arvindsuresh-math/phi-4-error-binary-classifier"
base_phi_model = "microsoft/Phi-4-mini-instruct"

# Compute dtype for the 4-bit (NF4) quantized backbone; the same dtype is
# applied to the whole classifier further down.
DTYPE = torch.float16
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=DTYPE
    )
classifier_backbone_base = AutoModelForCausalLM.from_pretrained(
    base_phi_model,
    quantization_config=quantization_config,
    device_map="auto",
    trust_remote_code=True,
    )

classifier_tokenizer = AutoTokenizer.from_pretrained(
    base_phi_model,
    trust_remote_code=True
    )
# Left padding keeps the final sequence position a real token, which is the
# position GPTSequenceClassifier.forward pools from.
classifier_tokenizer.padding_side = "left"
if classifier_tokenizer.pad_token is None:
    classifier_tokenizer.pad_token = classifier_tokenizer.eos_token

# Backbone + PEFT adapter, wrapped with the 2-way linear head defined above.
classifier_backbone_peft = PeftModel.from_pretrained(
    classifier_backbone_base,
    classifier_adapter_repo
    )
classifier_model = GPTSequenceClassifier(classifier_backbone_peft, num_labels=2)

# Download and load the custom classifier head's state dictionary
# NOTE(review): torch.load unpickles the downloaded file; consider passing
# weights_only=True explicitly since the file comes from a remote repo.
classifier_head_path = hf_hub_download(repo_id=classifier_adapter_repo, filename="classifier_head.pth")
classifier_model.classifier.load_state_dict(torch.load(classifier_head_path, map_location=device))

# NOTE(review): the backbone was already placed by device_map="auto" and is
# 4-bit quantized; presumably these two calls are only needed for the freshly
# created linear head — confirm they do not disturb the quantized weights.
classifier_model.to(device)
classifier_model = classifier_model.to(torch.float16)

classifier_model.eval() # Set model to evaluation mode
print("✅ Conceptual Error Classifier Model loaded.")

Using device: cuda

Loading Equation Extraction Model...
==((====))==  Unsloth 2025.8.4: Fast Gemma3 patching. Transformers: 4.53.2.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Using float16 precision for gemma3 won't work! Using float32.




✅ Equation Extraction Model loaded.

Loading Conceptual Error Classifier Model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✅ Conceptual Error Classifier Model loaded.


In [None]:
# ===================================================================
# 3. PROMPTS
# ===================================================================

# Instruction text for the equation extractor. run_computational_check
# prepends it to every solution line, and the <eq>...</eq> response format it
# requests is what extract_equation_from_response parses.
EXTRACTOR_SYSTEM_PROMPT = \
"""[ROLE]
You are an expert at parsing mathematical solutions.

[TASK]
You are given a single line from a mathematical solution. Your task is to extract the calculation from this line.

**This is a literal transcription task. Follow these rules with extreme precision:**
- **RULE 1: Transcribe EXACTLY.** Do not correct mathematical errors. If a line implies `2+2=5`, your output for that line must be `2+2=5`.
- **RULE 2: Isolate the Equation.** Your output must contain ONLY the equation, with no surrounding text, units, or currency symbols. Always use `*` for multiplication.

[RESPONSE FORMAT]
Your response must ONLY contain the extracted equation, wrapped in <eq> and </eq> tags.
If the line contains no calculation, respond with empty tags: <eq></eq>.
"""

# Prefix of the classifier input assembled in run_conceptual_check, followed
# there by the "### Problem:" and "### Answer:" sections.
CLASSIFIER_SYSTEM_PROMPT = \
"""You are a mathematics tutor.
You will be given a math word problem and a solution written by a student.
Carefully analyze the problem and solution LINE-BY-LINE and determine whether there are any errors in the solution."""

In [None]:
# ===================================================================
# 3. HELPERS
# ===================================================================

# --- Helper Functions ---
def extract_equation_from_response(response: str) -> str | None:
    """Extracts content from between <eq> and </eq> tags."""
    match = re.search(r'<eq>(.*?)</eq>', response, re.DOTALL)
    return match.group(1) if match else None

def sanitize_equation_string(expression: str) -> str:
    """Reduce an extracted equation to characters that plain arithmetic uses.

    Steps, in order:
      1. Drop one stray bracket when the string starts with an unmatched
         ``(`` or ends with an unmatched ``)``.
      2. Remove whitespace and thousands-separator commas.
      3. Normalise operator glyphs: ``×`` -> ``*``, ``÷`` -> ``/``,
         Unicode minus -> ``-``.
      4. Treat ``x``/``X`` as multiplication only when it sits between two
         operands (digit, ``.`` or bracket). An ``x`` inside a word such as
         "boxes" is not an operator; rewriting it produced ``3**4`` from
         ``3 boxes * 4`` and therefore a false error report.
      5. Remove ``/unit`` suffixes such as ``/hour``.
      6. Rewrite ``N%`` as ``(N/100)`` so percentages keep their value instead
         of silently becoming ``N``.
      7. Strip every remaining character that is not a digit, ``.``, a
         bracket, an arithmetic operator or ``=``.

    Args:
        expression: Raw equation text (typically the content of an ``<eq>`` tag).

    Returns:
        The cleaned equation string, or ``""`` when ``expression`` is not a
        string.
    """
    if not isinstance(expression, str):
        return ""

    # Repair a single unbalanced bracket at either end of the expression.
    if expression.count('(') > expression.count(')') and expression.startswith('('):
        expression = expression[1:]
    elif expression.count(')') > expression.count('(') and expression.endswith(')'):
        expression = expression[:-1]

    # Whitespace and digit-grouping commas carry no arithmetic meaning.
    sanitized = re.sub(r'\s+', '', expression).replace(',', '')

    # Map typographic operators onto their ASCII equivalents.
    sanitized = sanitized.replace('×', '*').replace('÷', '/').replace('\u2212', '-')

    # "x" is a multiplication sign only between operands (e.g. 2x3, (1+2)x4).
    sanitized = re.sub(r'(?<=[\d.)])[xX](?=[\d.(+\-])', '*', sanitized)

    # Drop unit denominators such as "/hour".
    sanitized = re.sub(r'/([a-zA-Z]+)', '', sanitized)

    # Keep the numeric meaning of percentages.
    sanitized = re.sub(r'(\d+(?:\.\d+)?|\.\d+)%', r'(\1/100)', sanitized)

    # Whatever is left that is not arithmetic (words, currency symbols) goes.
    sanitized = re.sub(r'[^\d.()+\-*/=]', '', sanitized)
    return sanitized

def evaluate_equations(eq_dict: dict, sol_dict: dict):
    """Find the first extracted equation whose two sides disagree.

    Each equation is sanitised with ``sanitize_equation_string``, split on the
    first ``=``, and both sides are evaluated as arithmetic. Sides are treated
    as equal when they agree within 1% relative tolerance, or within 1e-9
    absolute tolerance so that a correct zero result with float round-off
    (e.g. ``0.1+0.2-0.3=0``) is not reported as an error.

    Args:
        eq_dict: Maps a line key to the raw equation string extracted for
            that line. Entries that are empty or lack ``=`` are ignored.
        sol_dict: Maps the same line keys to the original solution text; used
            only to fill ``line_text`` in the result.

    Returns:
        ``{"error": False}`` when no mismatch is found; otherwise a dict with
        ``error=True`` and the keys ``line_key``, ``line_text``,
        ``original_flawed_calc``, ``sanitized_lhs``, ``original_rhs`` and
        ``correct_rhs`` describing the first mismatching line.
    """
    for key, eq_str in eq_dict.items():
        if not eq_str or "=" not in eq_str:
            continue
        try:
            sanitized_eq = sanitize_equation_string(eq_str)

            if not sanitized_eq or "=" not in sanitized_eq:
                continue

            lhs, rhs_str = sanitized_eq.split('=', 1)

            if not lhs or not rhs_str:
                continue

            # SECURITY NOTE: eval() is only reached after sanitisation has
            # reduced the text to digits, '.', brackets and + - * / =, and
            # builtins are disabled. It can still be made slow by huge power
            # expressions ("9**9**9"), so do not expose it to untrusted input.
            lhs_val = eval(lhs, {"__builtins__": None}, {})
            rhs_val = eval(rhs_str, {"__builtins__": None}, {})

            # abs_tol is required for comparisons against zero: rel_tol alone
            # never accepts a tiny non-zero round-off as equal to 0.
            if not math.isclose(lhs_val, rhs_val, rel_tol=1e-2, abs_tol=1e-9):
                correct_rhs_val = round(lhs_val, 4)
                correct_rhs_str = f"{correct_rhs_val:.4f}".rstrip('0').rstrip('.')

                # Return a more detailed dictionary for better explanations
                return {
                    "error": True,
                    "line_key": key,
                    "line_text": sol_dict.get(key, "N/A"),
                    "original_flawed_calc": eq_str, # The raw model output
                    "sanitized_lhs": lhs,           # The clean left side
                    "original_rhs": rhs_str,        # The clean right side
                    "correct_rhs": correct_rhs_str, # The correct right side
                }
        except Exception:
            # Deliberate best-effort: anything that is not valid arithmetic
            # (syntax errors, division by zero, non-numeric results) is
            # treated as "nothing to check" rather than as an error.
            continue

    return {"error": False}

In [None]:
# ===================================================================
# 4. PIPELINE COMPONENTS
# ===================================================================

def run_conceptual_check(question: str, solution: str, model, tokenizer) -> dict:
    """STAGE 1: classify a solution as 'correct' or 'flawed'.

    The classifier prompt is the system text followed by the problem and the
    student's answer, truncated to 512 tokens. Class index 0 is read as
    "correct" and index 1 as "flawed"; the solution is labelled flawed when
    the flawed probability exceeds 0.5.

    Args:
        question: The word problem text.
        solution: The student's full solution text.
        model: Callable classifier returning a mapping with a "logits" entry.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        ``{"prediction": "correct" | "flawed",
           "probabilities": {"correct": float, "flawed": float}}``
    """
    prompt = f"{CLASSIFIER_SYSTEM_PROMPT}\n\n### Problem:\n{question}\n\n### Answer:\n{solution}"
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    encoded = encoded.to(device)

    # No autograd bookkeeping and no KV cache: this is a single forward pass.
    with torch.inference_mode():
        raw_logits = model(**encoded, use_cache=False)["logits"]
        # Softmax in float32 regardless of the model's compute dtype.
        class_probs = torch.softmax(raw_logits.to(torch.float32), dim=-1).squeeze().tolist()

    p_correct = class_probs[0]
    p_flawed = class_probs[1]

    if p_flawed > 0.5:
        label = "flawed"
    else:
        label = "correct"

    return {
        "prediction": label,
        "probabilities": {"correct": p_correct, "flawed": p_flawed},
    }


def run_computational_check(solution: str, model, tokenizer, batch_size: int = 32) -> dict:
    """STAGE 2: check the arithmetic of a solution line by line.

    The solution is split into non-empty lines (any line containing
    "FINAL ANSWER:" is ignored). Each line is sent to the extractor model,
    which is prompted to return the line's equation inside ``<eq>`` tags.
    Every extracted equation is sanitised and both sides are compared within
    1% relative tolerance (1e-9 absolute, so exact-zero results with float
    round-off are accepted).

    Args:
        solution: Full solution text, one step per line.
        model: Generative extractor model exposing ``generate``.
        tokenizer: Tokenizer for ``model``; its ``padding_side`` is set to
            "left" as a side effect.
        batch_size: Maximum number of lines sent to ``model.generate`` at
            once. Values below 1 are treated as 1.

    Returns:
        ``{"error": False}`` when every checkable line is consistent, else
        ``{"error": True, "line_text": <offending line>,
           "correct_calc": "<lhs> = <value>"}`` for the first mismatch.
    """
    lines = [
        line.strip()
        for line in solution.strip().split('\n')
        if line.strip() and "FINAL ANSWER:" not in line.upper()
    ]
    if not lines:
        return {"error": False}

    # One chat-formatted prompt per solution line.
    # NOTE(review): the templated text is tokenized again below with default
    # special-token handling; confirm this does not add a second BOS token
    # relative to how the adapter was trained.
    prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": f"{EXTRACTOR_SYSTEM_PROMPT}\n\n### Solution Line:\n{line}"}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for line in lines
    ]

    # Decoder-only generation needs left padding so that new tokens directly
    # follow each prompt.
    tokenizer.padding_side = "left"
    chunk_size = max(1, int(batch_size))

    for start in range(0, len(prompts), chunk_size):
        chunk = prompts[start:start + chunk_size]
        inputs = tokenizer(chunk, return_tensors="pt", padding=True).to(device)
        outputs = model.generate(**inputs, max_new_tokens=64, use_cache=True, pad_token_id=tokenizer.pad_token_id)
        # Keep only the newly generated tokens (everything after the prompt).
        decoded_outputs = tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)

        for offset, raw_output in enumerate(decoded_outputs):
            equation = extract_equation_from_response(raw_output)
            if not equation or "=" not in equation:
                continue

            try:
                sanitized_eq = sanitize_equation_string(equation)
                if "=" not in sanitized_eq:
                    continue

                lhs, rhs_str = sanitized_eq.split('=', 1)

                # SECURITY NOTE: eval() only sees sanitised text (digits, '.',
                # brackets, + - * / =) with builtins disabled; it is still not
                # suitable for untrusted input.
                lhs_val = eval(lhs, {"__builtins__": None}, {})

                # A plain number is the common case; fall back to arithmetic
                # evaluation so an expression on the right-hand side is
                # checked instead of being silently skipped.
                try:
                    rhs_val = float(rhs_str)
                except ValueError:
                    rhs_val = eval(rhs_str, {"__builtins__": None}, {})

                if not math.isclose(lhs_val, rhs_val, rel_tol=1e-2, abs_tol=1e-9):
                    return {
                        "error": True,
                        "line_text": lines[start + offset],
                        "correct_calc": f"{lhs} = {round(lhs_val, 4)}"
                    }
            except Exception:
                # Best-effort: lines whose equation cannot be evaluated are
                # skipped rather than reported.
                continue

    return {"error": False}


def analyze_solution(question: str, solution: str):
    """Run both pipeline stages on one solution and report a final verdict.

    The conceptual classifier and the computational check are both executed.
    A computational error, when found, takes precedence; otherwise the
    classifier's prediction decides between 'correct' and 'conceptual_error'.
    Progress and the verdict are printed.

    Args:
        question: The word problem text.
        solution: The student's full solution text.

    Returns:
        ``{"classification": str, "explanation": str}`` where classification
        is 'computational_error', 'conceptual_error' or 'correct'.
    """
    print("--- Running Analysis ---")

    # Stage 1: conceptual classifier.
    stage_one = run_conceptual_check(question, solution, classifier_model, classifier_tokenizer)
    predicted_label = stage_one['prediction']
    confidence = stage_one['probabilities'][predicted_label]
    print(f"Stage 1 (Conceptual): Model predicts solution is likely '{predicted_label}' (Confidence: {confidence:.2%}).")
    print("Now, let's check the calculations...")

    # Stage 2: line-by-line arithmetic check.
    stage_two = run_computational_check(solution, gemma_model, gemma_tokenizer)

    # Verdict: arithmetic mistakes win over the classifier's opinion.
    if stage_two["error"]:
        classification = "computational_error"
        explanation = (
            f"A calculation error was found.\n"
            f"On the line: \"{stage_two['line_text']}\"\n"
            f"The correct calculation should be: {stage_two['correct_calc']}"
        )
    elif predicted_label == 'correct':
        classification = 'correct'
        explanation = "All calculations are correct and the overall logic appears to be sound."
    else:
        # Classifier said 'flawed'; expose it under the user-facing label.
        classification = 'conceptual_error'
        explanation = "All calculations are correct, but there appears to be a conceptual error in the logic or setup of the solution."

    final_verdict = {"classification": classification, "explanation": explanation}

    print("\n--- Final Verdict ---")
    print(f"Classification: {final_verdict['classification']}")
    print(f"Explanation: {final_verdict['explanation']}")
    return final_verdict

In [None]:
import pandas as pd

def prepare_test_samples(df: pd.DataFrame) -> pd.DataFrame:
    """Expand each raw test row into one correct and one flawed sample.

    Every input row yields two consecutive output rows:
      * the 'correct_answer', labelled "correct";
      * the 'wrong_answer', labelled "conceptual_error" when the row's
        'error_type' equals "concep" and "computational_error" otherwise.

    Args:
        df: Frame with columns 'question', 'correct_answer', 'wrong_answer'
            and 'error_type'.

    Returns:
        A new frame with columns 'question', 'answer' and 'error_type',
        containing two rows per input row.
    """
    records = []
    for _, source_row in df.iterrows():
        problem = source_row["question"]

        if source_row["error_type"] == "concep":
            flawed_label = "conceptual_error"
        else:
            flawed_label = "computational_error"

        records.extend([
            {"question": problem, "answer": source_row["correct_answer"], "error_type": "correct"},
            {"question": problem, "answer": source_row["wrong_answer"], "error_type": flawed_label},
        ])
    return pd.DataFrame(records)

In [None]:
# Raw test set consumed by prepare_test_samples, which reads the columns
# 'question', 'correct_answer', 'wrong_answer' and 'error_type'.
# NOTE(review): hard-coded absolute path; the CSV must already exist at this
# location before the cell is run.
df = pd.read_csv("/content/final-test-with-wrong-answers.csv")

In [None]:
# 1. Prepare the full set of test samples (will have 2N rows)
test_df = prepare_test_samples(df)

# 2. Create a balanced sample DataFrame with `num_samples_per_type` items of
#    each error type (3 types x 40 = 120 rows with the value below).
num_samples_per_type = 40

# Ensure error types are strings for grouping
test_df['error_type'] = test_df['error_type'].astype(str)

# Fixed random_state keeps the sample reproducible. sample(n=...) without
# replacement requires every group to contain at least n rows.
sampled_df = test_df.groupby('error_type').sample(n=num_samples_per_type, random_state=42).reset_index(drop=True)

print(f"Prepared a balanced test set with {len(sampled_df)} samples total.")
print("="*80)

Prepared a balanced test set with 120 samples total.


In [None]:
# Show full cell contents (questions and answers are long) instead of the
# default truncated column width; these options stay in effect for the rest
# of the session.
pd.set_option('display.max_colwidth', None)
pd.set_option('display.width', 0)

# Preview the prepared (unsampled) frame.
test_df.head()

Unnamed: 0,question,answer,error_type
0,A community garden is rectangular and measures 15 meters in length and 10 meters in width. The gardeners decide to dedicate 20% of the garden's area to growing tomatoes. How many square meters are dedicated to growing tomatoes?,The total area of the garden is 15 * 10 = 150 square meters.\nThe area for tomatoes is 20% of the total area: 150 * 0.20 = 30 square meters.\nFINAL ANSWER: 30,correct
1,A community garden is rectangular and measures 15 meters in length and 10 meters in width. The gardeners decide to dedicate 20% of the garden's area to growing tomatoes. How many square meters are dedicated to growing tomatoes?,The total area of the garden is 2*(15 + 10) = 50 square meters.\nThe area for tomatoes is 20% of the total area: 50 * 0.20 = 10 square meters.\nFINAL ANSWER: 10,conceptual_error
2,"Sarah is building a model rocket. The body of the rocket requires a tube that is 30 cm long. She has a pipe that is 1.2 meters long. After cutting the piece for the rocket body, how many centimeters of pipe does she have left?",Convert the pipe length to centimeters: 1.2 * 100 = 120 cm.\nSubtract the length of the rocket body to find the remaining pipe: 120 - 30 = 90 cm.\nFINAL ANSWER: 90,correct
3,"Sarah is building a model rocket. The body of the rocket requires a tube that is 30 cm long. She has a pipe that is 1.2 meters long. After cutting the piece for the rocket body, how many centimeters of pipe does she have left?",Convert the pipe length to centimeters: 1.2 * 100 = 120 cm.\nSubtract the length of the rocket body to find the remaining pipe: 120 - 30 = 80 cm.\nFINAL ANSWER: 80,computational_error
4,"A bakery sells muffins for $3 each and cookies for $2 each. On Monday, they sold 45 muffins and 60 cookies. On Tuesday, they sold 30 muffins and 75 cookies. What was the total revenue from muffins and cookies over both days?",Monday's muffin revenue was 45 * 3 = $135.\nMonday's cookie revenue was 60 * 2 = $120.\nTuesday's muffin revenue was 30 * 3 = $90.\nTuesday's cookie revenue was 75 * 2 = $150.\nThe total revenue over both days is 135 + 120 + 90 + 150 = $495.\nFINAL ANSWER: 495,correct


In [None]:
import time

# 3. Run the pipeline on the sampled data.
# Every verdict is kept so the confusion matrix and accuracy below are
# computed from the actual run instead of being tallied by hand from the log.
pipeline_results = []
for index, row in sampled_df.iterrows():
    print(f"\n--- TESTING SAMPLE #{index+1} (True Label: {row['error_type']}) ---")
    print(f"Question: {row['question']}")
    print(f"Solution:\n{row['answer']}")
    print(f"True classification:\n{row['error_type']}")
    print("-" * 50)
    print(f"Beginning pipeline now...")
    # perf_counter is a monotonic clock intended for measuring elapsed time.
    start_time = time.perf_counter()
    verdict = analyze_solution(row['question'], row['answer'])
    elapsed_seconds = time.perf_counter() - start_time
    print("-"*50)
    print(f"Time taken: {elapsed_seconds:.2f} seconds")

    pipeline_results.append({
        "sample": index + 1,
        "true_label": row['error_type'],
        "predicted_label": verdict["classification"],
        "seconds": elapsed_seconds,
    })

    print("\n" + "="*80)

# 4. Summarise: rows are true labels, columns are pipeline classifications.
results_df = pd.DataFrame(pipeline_results)
if not results_df.empty:
    confusion_matrix = pd.crosstab(
        results_df["true_label"],
        results_df["predicted_label"],
        rownames=["True label"],
        colnames=["Predicted"],
    )
    accuracy = (results_df["true_label"] == results_df["predicted_label"]).mean()
    print(confusion_matrix)
    print(f"\nOverall accuracy: {accuracy:.1%} ({len(results_df)} samples)")


--- TESTING SAMPLE #1 (True Label: computational_error) ---
Question: A train travels at a constant speed of 80 kilometers per hour. How long will it take to travel a distance of 360 kilometers?
Solution:
The time taken is the distance divided by the speed: 360 / 80 = 4 hours.
FINAL ANSWER: 4
True classification:
computational_error
--------------------------------------------------
Beginning pipeline now...
--- Running Analysis ---
Stage 1 (Conceptual): Model predicts solution is likely 'correct' (Confidence: 99.46%).
Now, let's check the calculations...

--- Final Verdict ---
Classification: computational_error
Explanation: A calculation error was found.
On the line: "The time taken is the distance divided by the speed: 360 / 80 = 4 hours."
The correct calculation should be: 360/80 = 4.5
--------------------------------------------------
Time taken: 2.60 seconds


--- TESTING SAMPLE #2 (True Label: computational_error) ---
Question: An author writes a book. The publisher pays a roya

Of course. Analyzing the results in a confusion matrix is the perfect way to get a clear, quantitative understanding of your pipeline's performance across the different error types.

I have gone through the 120 test samples you provided and tallied the results.

### 3x3 Confusion Matrix

The rows represent the **True Label** for the solution, and the columns represent the **Final Classification** predicted by your pipeline.

|                    | **Predicted: Correct** | **Predicted: Computational Error** | **Predicted: Conceptual Error** |
| :----------------- | :--------------------: | :--------------------------------: | :-----------------------------: |
| **True: Correct**  |           35           |                 3                  |                2                |
| **True: Comp. Err**|           1            |                 38                 |                1                |
| **True: Concep. Err**|           2            |                 8                  |               30                |

---

### Analysis and Key Observations

Overall, the pipeline is performing very well, with an accuracy of **(35 + 38 + 30) / 120 = 103 / 120 = 85.8%**. More importantly, the *types* of errors it makes are highly systematic and reveal clear insights into the behavior of each stage.

#### 1. Stage 2 (Computational Check) is Near-Perfect
The diagonal value for `Computational Error` is **38 out of 40**. This is an outstanding result. Your fine-tuned Gemma-3 model and the evaluation logic are extremely effective at their specific task. The two misclassifications were:
*   **1 case predicted as `Correct`**: (Sample #10) A very complex multi-step problem where the computational error was subtle.
*   **1 case predicted as `Conceptual`**: (Sample #7) The line `15:20, which simplifies to 3:5` contains no standard operators, so the computational check passed, and the conceptual model correctly identified it as flawed. This is reasonable behavior.

#### 2. The PEMDAS Bug is the Main Cause of "Correct -> Flawed" Errors
This is the most critical takeaway. The pipeline incorrectly classified 5 correct solutions.
*   **3 were classified as `Computational Error`**: These are Samples #105, #108, and #110. As we diagnosed, this is **100% due to the PEMDAS bug** where the pipeline incorrectly evaluates expressions with parentheses. **Fixing the code will move these 3 samples into the "Correct/Correct" box.**
*   **2 were classified as `Conceptual Error`**: These are Samples #85 and #99. These are false positives from the Phi-4 model. It appears to be overly cautious on very long, multi-step problems, even when they are logically sound.

#### 3. The Pipeline Hierarchy is Working as Designed
This is the most interesting observation. The single largest source of error is the **8 cases** where a solution with a `True Label` of `Conceptual Error` was classified as a `Computational Error`.
*   **Why this happens:** This is not a failure but a feature of your pipeline's design. In all 8 of these cases (e.g., Sample #42, #43, #53), the student's conceptual mistake led them to write a line that was *also* arithmetically incorrect.
*   **Example:** In Sample #43, the conceptual error is adding `80 + 50`. The student then makes a computational error by writing `80 + 50 = 110`. Your pipeline correctly identifies the objective computational mistake (`80+50` is `130`, not `110`) and stops, reporting the most definitive error it found.

#### 4. The Conceptual Classifier is Good but Imperfect
*   **False Negatives:** The classifier missed 2 conceptual errors and labeled them as `Correct` (Samples #41 and #64). These were subtle errors (using 6 days in a week, subtracting in the wrong order) that the model failed to catch.
*   **Overall Performance:** Out of the 40 conceptual error cases, the pipeline as a whole flagged **38** of them as flawed in some way (30 as conceptual by the Phi-4 classifier, 8 as computational by the arithmetic check). This is a strong result.

### Summary & Actionable Insights

1.  **Your #1 Priority is Fixing the PEMDAS Bug:** This single code change will correct 3 out of the 5 errors made on `Correct` samples, instantly boosting your pipeline's accuracy and reliability.
2.  **The Two-Stage Logic is Sound:** The pipeline runs the conceptual check on every solution and then lets the highly accurate computational check override that verdict whenever it finds an objective arithmetic error. The error hierarchy is behaving as expected.
3.  **The Next Step for Improvement is Data:** To reduce the remaining errors (the 2 false positives and 2 false negatives from the conceptual model), the best path forward is to add these specific failure cases to the fine-tuning dataset for the Phi-4 classifier and retrain it.