<a href="https://colab.research.google.com/github/Arpit1118/Post-Training-LLMs-with-RL/blob/main/LLM_using_RL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Core libraries
import torch
import random
import numpy as np

# Transformers from Hugging Face
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Evaluation and logging
import time
import traceback
import logging

# Set seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Using GPT-2 model (medium or large)
MODEL_NAME = "gpt2-medium"  # or "gpt2-large"

# Load tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)
model.eval()


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1024)
    (wpe): Embedding(1024, 1024)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-23): 24 x GPT2Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=3072, nx=1024)
          (c_proj): Conv1D(nf=1024, nx=1024)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=4096, nx=1024)
          (c_proj): Conv1D(nf=1024, nx=4096)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1024, out_features=50257, bias=False)
)

In [2]:
import sympy as sp

class MathSolver:
    def __init__(self, variable='x'):
        self.x = sp.Symbol(variable)

    def solve_equation(self, equation_str):
        """
        Solves an equation like 'x**2 - 4 = 0' or 'sin(x) = 0'
        Returns symbolic and numeric solutions
        """
        try:
            if '=' in equation_str:
                lhs, rhs = equation_str.split('=')
                expr = sp.sympify(lhs) - sp.sympify(rhs)
            else:
                expr = sp.sympify(equation_str)

            roots = sp.solve(expr, self.x)
            numeric = [sp.N(r) for r in roots]

            return {
                "success": True,
                "symbolic": roots,
                "numeric": numeric,
                "error": None
            }
        except Exception as e:
            return {
                "success": False,
                "symbolic": None,
                "numeric": None,
                "error": str(e)
            }

    def evaluate_expression(self, expr_str):
        """
        Evaluates a basic math expression like '2 + 3 * 4'
        """
        try:
            result = sp.sympify(expr_str).evalf()
            return {
                "success": True,
                "result": result,
                "error": None
            }
        except Exception as e:
            return {
                "success": False,
                "result": None,
                "error": str(e)
            }

# Example REPL
if __name__ == "__main__":
    solver = MathSolver()
    print("Math Solver Ready. Type 'exit' to quit.")
    while True:
        inp = input(">>> ")
        if inp.lower() in ['exit', 'quit']:
            break
        if '=' in inp:
            out = solver.solve_equation(inp)
            if out["success"]:
                print(f"Symbolic: {out['symbolic']}")
                print(f"Numeric: {out['numeric']}")
            else:
                print("Error:", out["error"])
        else:
            out = solver.evaluate_expression(inp)
            if out["success"]:
                print("Result:", out["result"])
            else:
                print("Error:", out["error"])


Math Solver Ready. Type 'exit' to quit.
>>> x**3 - 6*x**2 + 11*x - 6 = 0
Symbolic: [1, 2, 3]
Numeric: [1.00000000000000, 2.00000000000000, 3.00000000000000]
>>> tan(x) - sqrt(3) = 0
Symbolic: [pi/3]
Numeric: [1.04719755119660]
>>> diff(sin(x)*cos(x), x)
Result: -sin(x)**2 + cos(x)**2
>>> integrate(x*exp(x), x)
Result: (x - 1.0)*exp(x)
>>> log(x**2 + 1) = 1
Symbolic: [-sqrt(-1 + E), sqrt(-1 + E)]
Numeric: [-1.31083249443209, 1.31083249443209]
>>> (x + 1)**2 = x**2 + 2*x + 1
Symbolic: []
Numeric: []
>>> sqrt(x) + sqrt(x + 1) = 2
Symbolic: [9/16]
Numeric: [0.562500000000000]
>>> (x - sqrt(3))*(x + sqrt(3)) = 0
Symbolic: [-sqrt(3), sqrt(3)]
Numeric: [-1.73205080756888, 1.73205080756888]
>>> quit


In [3]:
import re

def is_math_prompt(self, prompt):
    # Strip to remove accidental spacing
    prompt = prompt.strip().lower()

    # If it's very short and numeric, it's likely a math query
    if prompt.isdigit():
        return True

    # If it looks like a math expression with operators or equal signs
    math_pattern = r"[0-9xX\+\-\*/\^=()]"
    if re.search(math_pattern, prompt) and not re.search(r'[a-z]{3,}', prompt):  # ignore full sentences
        return True

    # If it ends with a question mark but has math symbols, still likely not math
    if '?' in prompt and not re.search(r'\d|\+|\-|\*|/', prompt):
        return False

    return False


In [4]:
class GPT2MathWrapper:
    def __init__(self, model, tokenizer, math_solver):
        self.model = model
        self.tokenizer = tokenizer
        self.math_solver = math_solver

    def is_math_prompt(self, prompt):
        import re
        prompt = prompt.strip().lower()
        if prompt.isdigit():
            return True
        math_pattern = r"[0-9xX\+\-\*/\^=()]"
        if re.search(math_pattern, prompt) and not re.search(r'[a-z]{3,}', prompt):
            return True
        if '?' in prompt and not re.search(r'\d|\+|\-|\*|/', prompt):
            return False
        return False

    def run(self, prompt):
        if self.is_math_prompt(prompt):
            if '=' in prompt:
                result = self.math_solver.solve_equation(prompt)
                if result["success"]:
                    return f"Symbolic: {result['symbolic']}\nNumeric: {result['numeric']}"
                else:
                    return f"Math Error: {result['error']}"
            else:
                result = self.math_solver.evaluate_expression(prompt)
                if result["success"]:
                    return f"Result: {result['result']}"
                else:
                    return f"Math Error: {result['error']}"
        else:
            # Use GPT-2 for general text
            inputs = self.tokenizer.encode(prompt, return_tensors="pt")
            outputs = self.model.generate(inputs, max_length=100, num_return_sequences=1)
            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
