In [1]:
!pip install -q transformers accelerate bitsandbytes

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.1/60.1 MB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import re
import ast
import math
from typing import Dict, Any, Optional

In [3]:
def load_models():
    """Load both the math solver and filter models with 4-bit quantization"""

    # 4-bit quantization config for memory efficiency
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    print("Loading Math Model (1.5B)...")
    math_model_id = "Qwen/Qwen2.5-Math-1.5B-Instruct"
    math_tokenizer = AutoTokenizer.from_pretrained(math_model_id)
    math_model = AutoModelForCausalLM.from_pretrained(
        math_model_id,
        device_map="auto",
        quantization_config=bnb_config
    )

    print("Loading Filter Model (1.5B)...")
    filter_model_id = "Qwen/Qwen2.5-1.5B-Instruct"
    filter_tokenizer = AutoTokenizer.from_pretrained(filter_model_id)
    filter_model = AutoModelForCausalLM.from_pretrained(
        filter_model_id,
        device_map="auto",
        quantization_config=bnb_config
    )

    print("✅ Both models loaded successfully!")
    return {
        'math_tokenizer': math_tokenizer,
        'math_model': math_model,
        'filter_tokenizer': filter_tokenizer,
        'filter_model': filter_model
    }

# Load the models (this will take a few minutes)
models = load_models()

Loading Math Model (1.5B)...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/656 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.09G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

Loading Filter Model (1.5B)...


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/660 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.09G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

✅ Both models loaded successfully!


In [16]:
# Cell 4: LLM Code Validator Class
class LLMCodeValidator:
    """Uses LLM to validate variable names and function calls in code"""

    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model
        self.cache = {}  # Cache validation results

    def validate_variable_name(self, name: str, context: str = "") -> bool:
        """Ask LLM if a variable name is safe and appropriate for math code"""

        # Check cache first
        cache_key = f"var_{name}_{context}"
        if cache_key in self.cache:
            return self.cache[cache_key]

        prompt = f"""Is this variable name safe and appropriate for mathematical Python code?

Variable name: {name}
Context: {context if context else "General mathematical computation"}

Answer with ONLY 'YES' or 'NO'. Consider:
- Is it a valid Python identifier?
- Is it descriptive for calculations (cost, price, total, count, etc.)?
- Does it make sense in the given context?
- Is it NOT a security risk (no __imports__, exec, eval, __builtins__)?

Common math/business variables are ALLOWED: cost, price, amount, total, count, etc.

Answer:"""

        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True).to("cuda")

        outputs = self.model.generate(
            **inputs,
            max_new_tokens=10,
            temperature=0.1,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id
        )

        input_length = inputs['input_ids'].shape[1]
        response = self.tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True).strip().upper()

        is_valid = "YES" in response[:10]
        self.cache[cache_key] = is_valid
        return is_valid

    def validate_function_call(self, func_name: str, context: str = "") -> bool:
        """Ask LLM if a function call is safe for math code"""

        # Check cache first
        cache_key = f"func_{func_name}_{context}"
        if cache_key in self.cache:
            return self.cache[cache_key]

        prompt = f"""Is this function call safe and appropriate for mathematical Python code?

Function: {func_name}()
Context: {context if context else "General mathematical computation"}

Answer with ONLY 'YES' or 'NO'. Consider:
- Is it a safe built-in function (abs, min, max, round, etc.)?
- Is it a math library function?
- Is it NOT dangerous (no exec, eval, open, import, __import__, etc.)?
- Does it make sense for mathematical calculations?

Answer:"""

        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True).to("cuda")

        outputs = self.model.generate(
            **inputs,
            max_new_tokens=10,
            temperature=0.1,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id
        )

        input_length = inputs['input_ids'].shape[1]
        response = self.tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True).strip().upper()

        is_valid = "YES" in response[:10]
        self.cache[cache_key] = is_valid
        return is_valid

    def get_allowed_names_from_code(self, code: str, question: str = "") -> set[str]:
        """Extract and validate all variable names from code using LLM"""

        prompt = f"""Extract all variable names from this Python code and validate them for mathematical computation.

Question context: {question if question else "Mathematical calculation"}

Code:
```python
{code}
```

List ONLY the variable names that are safe and appropriate for math code, separated by commas.
Include standard math variables (x, y, result, etc.) and domain-specific ones (speed, area, etc.).
DO NOT include dangerous names or Python keywords.

Safe variable names:"""

        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True).to("cuda")

        outputs = self.model.generate(
            **inputs,
            max_new_tokens=100,
            temperature=0.2,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id
        )

        input_length = inputs['input_ids'].shape[1]
        response = self.tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True).strip()

        # Parse the response to extract variable names
        names = set()
        for item in response.split(','):
            item = item.strip().strip('`').strip()
            if item and re.match(r'^[A-Za-z_][A-Za-z0-9_]*$', item):
                names.add(item)

        return names



In [5]:
# Cell 5: AST Validation with LLM
def _validate_ast_with_llm(node: ast.AST, validator: LLMCodeValidator, question: str = "") -> None:
    """Validate AST nodes using LLM instead of hardcoded lists"""

    for n in ast.walk(node):
        # Block dangerous constructs completely
        if isinstance(n, (ast.Import, ast.ImportFrom, ast.Global, ast.Nonlocal, ast.Lambda,
                          ast.ClassDef, ast.FunctionDef, ast.While, ast.For, ast.With,
                          ast.AsyncFunctionDef, ast.Try, ast.Raise)):
            raise ValueError(f"Disallowed AST node: {n.__class__.__name__}")

        # Validate function calls using LLM
        if isinstance(n, ast.Call):
            func = n.func
            if isinstance(func, ast.Attribute):
                # Allow math.* functions
                if not (isinstance(func.value, ast.Name) and func.value.id == "math"):
                    raise ValueError("Only calls to math.* are allowed.")
            elif isinstance(func, ast.Name):
                # Use LLM to validate function calls
                if not validator.validate_function_call(func.id, question):
                    raise ValueError(f"LLM rejected function call: {func.id}")

        # Validate variable names using LLM
        if isinstance(n, ast.Name):
            if not validator.validate_variable_name(n.id, question):
                raise ValueError(f"LLM rejected variable name: {n.id}")

In [6]:
def sanitize_math_code(code: str) -> str:
    """
    Convert human-style math expressions into valid Python.
    Fixes:
    - ^ replaced with ** for powers
    - Implicit multiplication (5x → 5*x, 2(x+1) → 2*(x+1))
    - Parenthesis adjacency )x → )*x, )( → )*(
    """

    # Remove markdown code fences
    code = re.sub(r"^```(?:python)?\n", "", code)
    code = re.sub(r"\n```$", "", code)

    # Replace ^ with ** for powers
    code = re.sub(r'([A-Za-z0-9_\)\]\}])\s*\^\s*([A-Za-z0-9_\(\[\{])', r'\1**\2', code)

    # Insert multiplication between number and variable/parenthesis (5x → 5*x, 2(x+1) → 2*(x+1))
    code = re.sub(r'(?<![\w\*\.])(\d+)\s*([A-Za-z_\(])', r'\1*\2', code)

    # Insert multiplication between closing parenthesis and next variable/parenthesis: )x → )*x, )( → )*(
    code = re.sub(r'\)\s*(?=[A-Za-z_\(])', r')*', code)

    # Clean up multiple stars (e.g., ***, ****)
    code = re.sub(r'\*\*+', '**', code)

    return code

In [7]:
def safe_exec_python_with_llm(
    code: str,
    validator: "LLMCodeValidator",
    question: str = "",
    provided_globals: Optional[Dict[str, Any]] = None
) -> Any:
    """
    Safely execute Python code generated by an LLM after validating it.
    Includes math sanitization and AST-level validation.
    """

    # --- Step 1: Clean markdown code fences if present ---
    code = re.sub(r"^```(?:python)?\n", "", code)
    code = re.sub(r"\n```$", "", code)

    # --- Step 2: Sanitize math syntax ---
    cleaned_code = sanitize_math_code(code)

    # Optional: Debugging logs (recommended for dev phase)
    print("========== ORIGINAL LLM CODE ==========")
    print(code)
    print("========== SANITIZED PYTHON CODE ==========")
    print(cleaned_code)

    # --- Step 3: Parse into AST (safe syntax validation) ---
    try:
        parsed = ast.parse(cleaned_code, mode="exec")
    except SyntaxError as e:
        raise SyntaxError(f"Sanitized code still invalid: {e}\nCode:\n{cleaned_code}")

    # --- Step 4: Run custom AST validator (your existing LLM validation) ---
    _validate_ast_with_llm(parsed, validator, question)

    # --- Step 5: Prepare safe execution environment ---
    exec_env = {
        "math": math,
        "__builtins__": {
            "abs": abs, "min": min, "max": max, "round": round,
            "range": range, "len": len, "sum": sum, "pow": pow
        }
    }

    if provided_globals:
        exec_env.update(provided_globals)

    # --- Step 6: Compile and execute ---
    try:
        compiled = compile(parsed, filename="<llm>", mode="exec")
        exec(compiled, exec_env)
    except Exception as e:
        raise RuntimeError(f"Error during code execution: {e}\nCode:\n{cleaned_code}")

    # --- Step 7: Return result variable if present ---
    return exec_env.get("result", None)

In [8]:
# Cell 7: Code Fixing Utilities
def fix_result_assignment(code: str) -> str:
    """Fix code that assigns to variables other than 'result'"""
    if 'result =' in code:
        return code

    result_variables = [
        'compound_interest', 'simple_interest', 'interest',
        'total_cost', 'total', 'final_amount', 'amount',
        'area', 'volume', 'perimeter', 'circumference',
        'answer', 'solution', 'final_answer',
        'time', 'distance', 'speed', 'velocity',
        'profit', 'loss', 'percentage',
        'average', 'mean', 'sum_val'
    ]

    lines = code.split('\n')
    modified_lines = []

    for line in lines:
        line_modified = line
        for var in result_variables:
            if re.match(rf'^{var}\s*=', line.strip()):
                line_modified = re.sub(rf'^{var}\s*=', 'result =', line.strip())
                break
        modified_lines.append(line_modified)

    final_code = '\n'.join(modified_lines)

    if 'result =' not in final_code:
        lines = final_code.split('\n')
        for i in range(len(lines) - 1, -1, -1):
            line = lines[i].strip()
            if '=' in line and not line.startswith('#'):
                var_name = line.split('=')[0].strip()
                final_code += f'\nresult = {var_name}'
                break

    return final_code

In [9]:
# Cell 8: Code Extraction Utilities
def extract_python_code(text: str) -> str:
    """Extract Python code from model response"""
    python_match = re.search(r'```python\s*(.*?)\s*```', text, re.DOTALL)
    if python_match:
        code = python_match.group(1).strip()
        code = re.sub(r'\nprint\([^)]*\)\s*', '\n', code)
        code = re.sub(r'^print\([^)]*\)\s*', '', code)
        code = fix_result_assignment(code)
        return code.strip()

    lines = text.split('\n')
    python_lines = []

    for line in lines:
        stripped = line.strip()
        if not stripped or stripped.startswith('**') or stripped.startswith('##'):
            continue
        if re.match(r'^\d+\.\s', stripped) or stripped.startswith('- '):
            continue
        if any(phrase in stripped.lower() for phrase in ['final answer', 'therefore', 'the answer is']):
            break

        is_python_line = any([
            stripped.startswith('#'),
            '=' in stripped and not stripped.startswith('*'),
            re.match(r'^[a-zA-Z_]\w*\s*=', stripped),
            'math.' in stripped,
        ])

        if is_python_line and not stripped.startswith('print('):
            python_lines.append(line)

    code = '\n'.join(python_lines).strip()
    return fix_result_assignment(code)

def extract_units_from_question(question: str) -> Optional[str]:
    """Extract units from the question"""
    unit_patterns = [
        r'(\w+/\w+)',
        r'(dollars?|USD|\$)',
        r'(hours?|minutes?|seconds?)',
        r'(meters?|kilometers?|feet|inches?)',
        r'(grams?|kilograms?|pounds?)',
        r'(percent|%)',
        r'(degrees?|°[CF]?)',
        r'(years?|months?|days?)',
    ]

    for pattern in unit_patterns:
        match = re.search(pattern, question, re.IGNORECASE)
        if match:
            return match.group(1)
    return None


In [10]:
# Cell 9: Question Filtering
def filter_question(question: str, filter_tokenizer, filter_model) -> str:
    """Clean and extract only relevant math details from the question"""
    prompt = f"""Extract only the essential mathematical information from this question, removing unnecessary story details. Keep numbers, units, and the core question.

Original: {question}

Cleaned math problem:"""

    inputs = filter_tokenizer(prompt, return_tensors="pt", padding=True).to("cuda")

    outputs = filter_model.generate(
        **inputs,
        max_new_tokens=150,
        temperature=0.2,
        do_sample=True,
        pad_token_id=filter_tokenizer.eos_token_id
    )

    input_length = inputs['input_ids'].shape[1]
    cleaned = filter_tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)

    return cleaned.strip()

In [11]:
# Cell 10: Math Solver with Explanation
def solve_with_explanation(question: str, math_tokenizer, math_model) -> Dict[str, Any]:
    """Solve math problem with detailed step-by-step explanation"""
    input_text = f"""Solve this math problem step by step and provide Python code.

Problem: {question}

Provide:
1. A clear plan
2. Step-by-step explanation
3. Python code that calculates the answer (store result in 'result' variable)

Solution:"""

    inputs = math_tokenizer(input_text, return_tensors="pt", padding=True).to("cuda")

    outputs = math_model.generate(
        **inputs,
        max_new_tokens=400,
        do_sample=True,
        temperature=0.1,
        pad_token_id=math_tokenizer.eos_token_id
    )

    input_length = inputs['input_ids'].shape[1]
    response = math_tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)

    python_code = extract_python_code(response)

    steps = []
    for line in response.split('\n'):
        if re.match(r'^\d+\.', line.strip()) or line.strip().startswith('-'):
            steps.append(line.strip())

    if not steps:
        steps = ["Calculate using the given values", "Apply mathematical operations", "Compute final result"]

    return {
        "python_code": python_code,
        "plan": "Solve mathematically step by step",
        "steps": steps[:5],
        "units": extract_units_from_question(question),
        "precision": 2,
        "raw_text": response,
        "explanation": response
    }

In [12]:
# Cell 11: Main Math Solver Pipeline
class MathSolverPipeline:
    def __init__(self, models_dict):
        self.math_tokenizer = models_dict['math_tokenizer']
        self.math_model = models_dict['math_model']
        self.filter_tokenizer = models_dict['filter_tokenizer']
        self.filter_model = models_dict['filter_model']

        # Initialize LLM validator
        self.validator = LLMCodeValidator(self.filter_tokenizer, self.filter_model)
        print("✅ LLM Code Validator initialized!")

    def solve(self, question: str) -> Dict[str, Any]:
        """Complete pipeline: filter → solve → validate → execute"""
        print("\n🔹 Original Question:", question)

        # Step 1: Filter question
        cleaned_question = filter_question(question, self.filter_tokenizer, self.filter_model)
        print("\n✨ Cleaned Question:", cleaned_question)

        # Step 2: Solve with explanation
        try:
            result_dict = solve_with_explanation(cleaned_question, self.math_tokenizer, self.math_model)

            # Step 3: Execute code with LLM validation
            python_code = result_dict.get("python_code", "").strip()
            print("\n🔍 Validating code with LLM...")

            result = safe_exec_python_with_llm(python_code, self.validator, cleaned_question)

            # Apply precision
            precision = result_dict.get("precision", 2)
            if isinstance(precision, int):
                try:
                    result = round(float(result), precision)
                except Exception:
                    pass

            result_dict["result"] = result

            # Display results
            print("\n📋 PLAN:", result_dict["plan"])
            print("\n📝 STEPS:")
            for i, step in enumerate(result_dict["steps"], 1):
                print(f"   {i}. {step}")
            print("\n💻 PYTHON CODE:")
            print(result_dict["python_code"])
            print(f"\n✅ FINAL ANSWER: {result} {result_dict['units'] or ''}")

            return result_dict

        except Exception as e:
            print(f"\n❌ Error: {e}")
            raise


In [None]:
# Cell 12: Initialize Models and Solver
models = load_models()
solver = MathSolverPipeline(models)
print("\n" + "="*60)
print("✅ Math Solver with LLM Validation Ready!")
print("="*60)
question = input("Enter your math question: ")
solver.solve(question)

Loading Math Model (1.5B)...
Loading Filter Model (1.5B)...
✅ Both models loaded successfully!
✅ LLM Code Validator initialized!

✅ Math Solver with LLM Validation Ready!
Enter your math question: Ravi, who loves cricket and has a pet parrot named Coco, went to a stationery shop on a rainy Thursday afternoon. He bought 8 pencils that each cost ₹5 and 3 notebooks that each cost ₹20. The shopkeeper, who has been running the shop for 15 years, gave him a discount of ₹10. How much did Ravi pay in total?

🔹 Original Question: Ravi, who loves cricket and has a pet parrot named Coco, went to a stationery shop on a rainy Thursday afternoon. He bought 8 pencils that each cost ₹5 and 3 notebooks that each cost ₹20. The shopkeeper, who has been running the shop for 15 years, gave him a discount of ₹10. How much did Ravi pay in total?

✨ Cleaned Question: Ravi bought 8 pencils at ₹5 each and 3 notebooks at ₹20 each. After applying a discount of ₹10, how much did he pay in total? To solve this, cal