
In [1]:
# Core libraries
import torch
import random
import numpy as np

# Transformers from Hugging Face
from transformers import AutoTokenizer, AutoModelForCausalLM

# Evaluation and logging
import time
import traceback
import logging

# Set seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Using Qwen2.5 model
MODEL_NAME = "Qwen/Qwen2.5-0.5B"

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
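
The seeding lines in the cell above can be gathered into one helper so that every later cell resets the same generators. The following is a minimal sketch, not part of the original notebook: the name `seed_everything` is hypothetical, and NumPy and PyTorch are seeded only if they can be imported.

```python
import random


def seed_everything(seed=42):
    """Seed every random number generator that is importable here."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy is not installed, so there is nothing to seed
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable
    except ImportError:
        pass  # PyTorch is not installed, so there is nothing to seed


# Re-seeding with the same value reproduces the same draws.
seed_everything(42)
first = [random.random() for _ in range(3)]
seed_everything(42)
second = [random.random() for _ in range(3)]
print(first == second)
```

Seeding only affects code paths that draw random numbers, for example sampling-based generation. It does not change greedy decoding.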


In [2]:
import sympy as sp

class MathSolver:
    def __init__(self, variable='x'):
        self.x = sp.Symbol(variable)

    def solve_equation(self, equation_str):
        """
        Solve an equation such as 'x**2 - 4 = 0' or 'sin(x) = 0'.
        Return a dict with the symbolic roots and their numeric values.
        """
        try:
            if '=' in equation_str:
                lhs, rhs = equation_str.split('=')
                expr = sp.sympify(lhs) - sp.sympify(rhs)
            else:
                expr = sp.sympify(equation_str)

            roots = sp.solve(expr, self.x)
            numeric = [sp.N(r) for r in roots]

            return {
                "success": True,
                "symbolic": roots,
                "numeric": numeric,
                "error": None
            }
        except Exception as e:
            return {
                "success": False,
                "symbolic": None,
                "numeric": None,
                "error": str(e)
            }

    def evaluate_expression(self, expr_str):
        """
        Evaluates a basic math expression like '2 + 3 * 4'
        """
        try:
            result = sp.sympify(expr_str).evalf()
            return {
                "success": True,
                "result": result,
                "error": None
            }
        except Exception as e:
            return {
                "success": False,
                "result": None,
                "error": str(e)
            }

# Example REPL (the `solver` created here is reused by later cells)
if __name__ == "__main__":
    solver = MathSolver()
    print("Math Solver Ready. Type 'exit' to quit.")
    while True:
        inp = input(">>> ")
        if inp.lower() in ['exit', 'quit']:
            break
        if '=' in inp:
            out = solver.solve_equation(inp)
            if out["success"]:
                print(f"Symbolic: {out['symbolic']}")
                print(f"Numeric: {out['numeric']}")
            else:
                print("Error:", out["error"])
        else:
            out = solver.evaluate_expression(inp)
            if out["success"]:
                print("Result:", out["result"])
            else:
                print("Error:", out["error"])


Math Solver Ready. Type 'exit' to quit.
>>> exit
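
In `solve_equation`, `lhs, rhs = equation_str.split('=')` fails with a generic unpacking error when the input contains more than one `=` (for example `x == 2`), and that message is what ends up in the `error` field. The sketch below shows one way to validate the string first. The helper name `split_equation` and its error text are hypothetical and are not part of the notebook.

```python
def split_equation(equation_str):
    """Return (lhs, rhs) strings; a bare expression is read as 'expr = 0'."""
    if "=" not in equation_str:
        return equation_str.strip(), "0"
    # partition() splits at the first '=' only, so extra '=' signs stay in rhs
    lhs, _, rhs = equation_str.partition("=")
    lhs, rhs = lhs.strip(), rhs.strip()
    if not lhs or not rhs or "=" in rhs:
        raise ValueError(
            f"expected exactly one '=' with text on both sides: {equation_str!r}"
        )
    return lhs, rhs


print(split_equation("x**2 - 4 = 0"))
print(split_equation("sin(x)"))
```

The two returned strings could then be passed to `sp.sympify` as in the class above. Note that `sympify` evaluates its input, so it should only be given strings from a trusted source.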


In [3]:
import re
from functools import lru_cache

class LLMMathWrapper:
    def __init__(self, model, tokenizer, math_solver):
        self.model = model
        self.tokenizer = tokenizer
        self.math_solver = math_solver

    def is_math_prompt(self, prompt):
        prompt = prompt.strip()

        # Strip punctuation so the keyword and word checks see plain words
        cleaned = re.sub(r'[^\w\s]', '', prompt.lower())

        # Only prompts containing a digit or math symbol are candidates;
        # among those, look for math keywords or a short, symbol-heavy shape
        if re.search(r'[\d\+\-\*/\^=()]', prompt):
            math_keywords = {"solve", "evaluate", "simplify", "integrate", "differentiate", "factor"}
            if any(kw in cleaned for kw in math_keywords):
                return True

            words = cleaned.split()
            if len(words) <= 5 or sum(w.isdigit() or w in {'x', 'y'} for w in words) >= len(words) // 2:
                return True

        return False

    @lru_cache(maxsize=128)  # Optional: caches by (self, prompt); only safe when generation is deterministic
    def generate_with_llm(self, prompt):
        inputs = self.tokenizer.encode(prompt, return_tensors="pt")
        outputs = self.model.generate(inputs, max_length=100, num_return_sequences=1)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def run(self, prompt):
        if self.is_math_prompt(prompt):
            if '=' in prompt:
                result = self.math_solver.solve_equation(prompt)
                if result["success"]:
                    return f"Symbolic: {result['symbolic']}\nNumeric: {result['numeric']}"
                else:
                    return f"Math Error: {result['error']}"
            else:
                result = self.math_solver.evaluate_expression(prompt)
                if result["success"]:
                    return f"Result: {result['result']}"
                else:
                    return f"Math Error: {result['error']}"
        else:
            return self.generate_with_llm(prompt)
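
The routing heuristic in `is_math_prompt` is easier to reason about when it is probed with a few inputs. The sketch below copies the logic into a standalone function and adds a hypothetical `whole_words` flag, which is not in the notebook, to show how substring keyword matching differs from whole-word matching.

```python
import re

MATH_KEYWORDS = {"solve", "evaluate", "simplify", "integrate", "differentiate", "factor"}
MATH_CHARS = re.compile(r"[\d+\-*/^=()]")  # same character set as the class above


def is_math_prompt(prompt, whole_words=False):
    """Standalone copy of the heuristic; whole_words=True is a stricter variant."""
    prompt = prompt.strip()
    cleaned = re.sub(r"[^\w\s]", "", prompt.lower())
    if not MATH_CHARS.search(prompt):
        return False
    words = cleaned.split()
    if whole_words:
        has_keyword = any(w in MATH_KEYWORDS for w in words)
    else:
        # Substring test, as in the notebook: "factor" also matches "factory"
        has_keyword = any(kw in cleaned for kw in MATH_KEYWORDS)
    if has_keyword:
        return True
    symbolic = sum(w.isdigit() or w in {"x", "y"} for w in words)
    return len(words) <= 5 or symbolic >= len(words) // 2


probes = [
    "What is the capital of India",
    "sqrt(x) + sqrt(x + 1) = 2",
    "I have 2 cats",
    "The factory opened in 1990 and still employs over two thousand people",
]
for p in probes:
    print(is_math_prompt(p), is_math_prompt(p, whole_words=True), repr(p))
```

Short prompts that contain any digit, such as "I have 2 cats", are routed to the math solver under both variants because of the `len(words) <= 5` shortcut, so the wrapper would return a `Math Error` for them rather than calling the LLM.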


In [4]:
wrapper = LLMMathWrapper(model, tokenizer, solver)
print(wrapper.run("What is the capital of India"))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Both `max_new_tokens` (=2048) and `max_length`(=100) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


What is the capital of India?
What is the capital of India?
Do those questions have the same meaning?
Pick from:
 (A). no
 (B). yes
(B).
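
The warnings in the output above point at two things in `generate_with_llm`: `generate` receives only `input_ids` (no attention mask or pad token id), and `max_length=100` is overridden because the model's generation config already sets `max_new_tokens` to 2048, as the log states. The sketch below shows one way to pass these explicitly. The function name `generate_continuation` is hypothetical, and returning only the newly generated tokens is a design choice that differs from the notebook, which decodes the prompt together with the continuation.

```python
def generate_continuation(model, tokenizer, prompt, max_new_tokens=100):
    """Generate text for `prompt` and return only the newly generated part."""
    # Calling the tokenizer (instead of .encode) returns input_ids and attention_mask
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        **inputs,                             # forwards input_ids and attention_mask
        max_new_tokens=max_new_tokens,        # explicit budget instead of max_length
        pad_token_id=tokenizer.eos_token_id,  # set explicitly rather than by fallback
    )
    prompt_len = len(inputs["input_ids"][0])
    # Slice off the prompt tokens so the question is not echoed back
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
```

With the objects loaded in the first cell, this would be called as `generate_continuation(model, tokenizer, "What is the capital of India")`. The generated text may differ from the output recorded above, since the token budget is different.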


In [5]:
print(wrapper.run("What is the capital of France"))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Both `max_new_tokens` (=2048) and `max_length`(=100) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


What is the capital of France?
The capital of France is Paris.
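
`generate_with_llm` is decorated with `@lru_cache` at class level, which means the cache is keyed on `(self, prompt)` and is shared by all instances of the class. A common alternative is to build the cache inside `__init__` so that each instance owns its own. The sketch below illustrates that pattern with a stand-in generator. `CachedGenerator` and `fake_generate` are hypothetical names, and no model is involved.

```python
from functools import lru_cache


class CachedGenerator:
    """Hypothetical wrapper that memoizes a text-generation callable per instance."""

    def __init__(self, generate_fn, maxsize=128):
        # Wrapping here, not at class level, gives every instance its own cache
        self.generate = lru_cache(maxsize=maxsize)(generate_fn)


calls = []


def fake_generate(prompt):
    """Stand-in for an expensive LLM call; records each real invocation."""
    calls.append(prompt)
    return prompt.upper()


gen = CachedGenerator(fake_generate)
gen.generate("what is the capital of france")
gen.generate("what is the capital of france")  # served from the cache
print(len(calls), gen.generate.cache_info())
```

Caching is only appropriate when generation is deterministic, as with greedy decoding. With sampling enabled, a cache would keep returning the first sample for a given prompt.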


In [6]:
print(wrapper.run("sqrt(x) + sqrt(x + 1) = 2"))

Symbolic: [9/16]
Numeric: [0.562500000000000]


In [7]:
print(wrapper.run("(x - sqrt(3))*(x + sqrt(3)) = 0"))

Symbolic: [-sqrt(3), sqrt(3)]
Numeric: [-1.73205080756888, 1.73205080756888]
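
Both results can be checked by substituting the roots back into the equations. For the first one, squaring sqrt(x + 1) = 2 - sqrt(x) gives 4*sqrt(x) = 3, so x = 9/16, which matches the solver output. The sketch below does the numeric check with the standard library only. The function names are hypothetical.

```python
import math


def residual_sqrt_eq(x):
    """Left side minus right side of sqrt(x) + sqrt(x + 1) = 2."""
    return math.sqrt(x) + math.sqrt(x + 1) - 2


def residual_product_eq(x):
    """Left side of (x - sqrt(3)) * (x + sqrt(3)) = 0."""
    return (x - math.sqrt(3)) * (x + math.sqrt(3))


# A root is confirmed when the residual is zero up to floating-point error
checks = {
    "9/16": residual_sqrt_eq(9 / 16),
    "-sqrt(3)": residual_product_eq(-math.sqrt(3)),
    "sqrt(3)": residual_product_eq(math.sqrt(3)),
}
for name, value in checks.items():
    print(name, math.isclose(value, 0.0, abs_tol=1e-12))
```

A substitution check like this is useful for radical equations in particular, because squaring both sides can introduce extraneous roots.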
