<a href="https://colab.research.google.com/github/Arpit1118/Post-Training-LLMs-with-RL/blob/main/LLM_using_RL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Core libraries
import torch
import random
import numpy as np

# Transformers from Hugging Face
from transformers import AutoTokenizer, AutoModelForCausalLM

# Evaluation and logging
import time
import traceback
import logging

# Set seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Using Qwen2.5 model
MODEL_NAME = "Qwen/Qwen2.5-0.5B"

# Load tokenizer and model
tokenizer =  AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
device = torch.device("cpu")
model.to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/681 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/988M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/138 [00:00<?, ?B/s]

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((896,), eps=1e-06)
    (rotary_emb): Qwen2RotaryEmbe

In [2]:
import sympy as sp

class MathSolver:
    def __init__(self, variable='x'):
        self.x = sp.Symbol(variable)

    def solve_equation(self, equation_str):
        """
        Solves an equation like 'x**2 - 4 = 0' or 'sin(x) = 0'
        Returns symbolic and numeric solutions
        """
        try:
            if '=' in equation_str:
                lhs, rhs = equation_str.split('=')
                expr = sp.sympify(lhs) - sp.sympify(rhs)
            else:
                expr = sp.sympify(equation_str)

            roots = sp.solve(expr, self.x)
            numeric = [sp.N(r) for r in roots]

            return {
                "success": True,
                "symbolic": roots,
                "numeric": numeric,
                "error": None
            }
        except Exception as e:
            return {
                "success": False,
                "symbolic": None,
                "numeric": None,
                "error": str(e)
            }

    def evaluate_expression(self, expr_str):
        """
        Evaluates a basic math expression like '2 + 3 * 4'
        """
        try:
            result = sp.sympify(expr_str).evalf()
            return {
                "success": True,
                "result": result,
                "error": None
            }
        except Exception as e:
            return {
                "success": False,
                "result": None,
                "error": str(e)
            }

# Example REPL
if __name__ == "__main__":
    solver = MathSolver()
    print("Math Solver Ready. Type 'exit' to quit.")
    while True:
        inp = input(">>> ")
        if inp.lower() in ['exit', 'quit']:
            break
        if '=' in inp:
            out = solver.solve_equation(inp)
            if out["success"]:
                print(f"Symbolic: {out['symbolic']}")
                print(f"Numeric: {out['numeric']}")
            else:
                print("Error:", out["error"])
        else:
            out = solver.evaluate_expression(inp)
            if out["success"]:
                print("Result:", out["result"])
            else:
                print("Error:", out["error"])


Math Solver Ready. Type 'exit' to quit.
>>> x**4 - 5*x**2 + 4 = 0
Symbolic: [-2, -1, 1, 2]
Numeric: [-2.00000000000000, -1.00000000000000, 1.00000000000000, 2.00000000000000]
>>> 2*cos(x)**2 - 1 = 0
Symbolic: [pi/4, 3*pi/4, 5*pi/4, 7*pi/4]
Numeric: [0.785398163397448, 2.35619449019234, 3.92699081698724, 5.49778714378214]
>>> diff(sin(x)*e**x, x)
Result: e**x*log(e)*sin(x) + e**x*cos(x)
>>> sqrt(x + 2) - x = 0
Symbolic: [2]
Numeric: [2.00000000000000]
>>> x**(1/3) + x**(1/2) - 4 = 0
Symbolic: [(-1/3 + 1/(9*(2*sqrt(78)/9 + 53/27)**(1/3)) + (2*sqrt(78)/9 + 53/27)**(1/3))**6]
Numeric: [5.16124243978033]
>>> log(x) + x = 3
Symbolic: [LambertW(exp(3))]
Numeric: [2.20794003156932]
>>> exit


In [3]:
import re
from functools import lru_cache

class LLMMathWrapper:
    def __init__(self, model, tokenizer, math_solver):
        self.model = model
        self.tokenizer = tokenizer
        self.math_solver = math_solver

    def is_math_prompt(self, prompt):
        prompt = prompt.strip()

        # Removed punctuation for better word analysis
        cleaned = re.sub(r'[^\w\s]', '', prompt.lower())

        # Quick pass for math symbols or math keywords
        if re.search(r'[\d\+\-\*/\^=()]', prompt):
            math_keywords = {"solve", "evaluate", "simplify", "integrate", "differentiate", "factor"}
            if any(kw in cleaned for kw in math_keywords):
                return True

            words = cleaned.split()
            if len(words) <= 5 or sum(w.isdigit() or w in {'x', 'y'} for w in words) >= len(words) // 2:
                return True

        return False

    @lru_cache(maxsize=128)  # Optional: speeds up repeated LLM queries
    def generate_with_llm(self, prompt):
        inputs = self.tokenizer.encode(prompt, return_tensors="pt")
        outputs = self.model.generate(inputs, max_length=100, num_return_sequences=1)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def run(self, prompt):
        if self.is_math_prompt(prompt):
            if '=' in prompt:
                result = self.math_solver.solve_equation(prompt)
                if result["success"]:
                    return f"Symbolic: {result['symbolic']}\nNumeric: {result['numeric']}"
                else:
                    return f"Math Error: {result['error']}"
            else:
                result = self.math_solver.evaluate_expression(prompt)
                if result["success"]:
                    return f"Result: {result['result']}"
                else:
                    return f"Math Error: {result['error']}"
        else:
            return self.generate_with_llm(prompt)


In [4]:
wrapper = LLMMathWrapper(model, tokenizer, solver)
print(wrapper.run("Explain how photosynthesis works."))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Both `max_new_tokens` (=2048) and `max_length`(=100) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Explain how photosynthesis works. Photosynthesis is the process by which plants, algae, and some bacteria convert light energy into chemical energy in the form of glucose. The process occurs in the chloroplasts of plant cells, where light energy is absorbed by the chlorophyll molecules in the thylakoid membranes. The absorbed light energy is used to split water molecules into hydrogen and oxygen, which are released into the atmosphere as oxygen. The remaining hydrogen is used to produce glucose, which is the primary source of energy for the plant. The process of photosynthesis is a complex and energy-intensive process that requires a lot of energy to produce glucose, which is why plants are often referred to as the "sun" of the ecosystem.


In [5]:
print(wrapper.run("If a bottle costs $3 and you buy 4, how much total?"))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Both `max_new_tokens` (=2048) and `max_length`(=100) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


If a bottle costs $3 and you buy 4, how much total? To determine the total cost of buying 4 bottles of soda, we need to know the cost of one bottle. However, since the problem does not specify the cost of one bottle, we will assume that the cost of one bottle is $3. Here is the step-by-step reasoning:

1. Identify the cost of one bottle of soda.
   - The cost of one bottle is $3.

2. Determine the number of bottles being bought.
   - You are buying 4 bottles.

3. Calculate the total cost by multiplying the cost of one bottle by the number of bottles.
   - Total cost = Cost of one bottle × Number of bottles
   - Total cost = $3 × 4
   - Total cost = $12

Therefore, the total cost of buying 4 bottles of soda is \boxed{12}.


In [6]:
print(wrapper.run("What is the capital of France"))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Both `max_new_tokens` (=2048) and `max_length`(=100) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


What is the capital of France?
The capital of France is Paris.


In [7]:
print(wrapper.run("sqrt(x) + sqrt(x + 1) = 2"))

Symbolic: [9/16]
Numeric: [0.562500000000000]


In [8]:
print(wrapper.run("(x - sqrt(3))*(x + sqrt(3)) = 0"))

Symbolic: [-sqrt(3), sqrt(3)]
Numeric: [-1.73205080756888, 1.73205080756888]


In [9]:
print(wrapper.run("x^2 - 4 = 0"))

Symbolic: [-2, 2]
Numeric: [-2.00000000000000, 2.00000000000000]


In [10]:
def reward_func(prompts, completions, completion_ids=None, **kwargs):
    """
    Compute reward for each (prompt, completion) pair.
    Returns a list of floats of the same length as completions.
    """
    rewards = []
    math_solver = MathSolver()

    for prompt, output in zip(prompts, completions):
        # Basic check if it's a math prompt
        is_math = '=' in prompt or any(op in prompt for op in ['+', '-', '*', '/', '^'])

        if is_math:
            # Try solving it
            try:
                if '=' in prompt:
                    math_result = math_solver.solve_equation(prompt)
                    if math_result["success"]:
                        # If numeric answer is in output, reward it
                        correct_answer = str(math_result["numeric"][0])
                        if correct_answer in output:
                            rewards.append(1.0)
                        else:
                            rewards.append(-1.0)
                    else:
                        rewards.append(-0.5)
                else:
                    eval_result = math_solver.evaluate_expression(prompt)
                    if eval_result["success"] and str(eval_result["result"]) in output:
                        rewards.append(1.0)
                    else:
                        rewards.append(-1.0)
            except Exception:
                rewards.append(-1.0)
        else:
            # For non-math prompts, reward any decent length response
            rewards.append(0.5 if len(output.strip()) > 10 else -0.5)

    return rewards


In [11]:
train_dataset = [
    {"prompt": "Solve x^2 - 4 = 0"},
    {"prompt": "Tell me a joke"}
]


In [12]:
!pip install trl

Collecting trl
  Downloading trl-0.23.0-py3-none-any.whl.metadata (11 kB)
Downloading trl-0.23.0-py3-none-any.whl (564 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m564.7/564.7 kB[0m [31m42.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: trl
Successfully installed trl-0.23.0


In [14]:
from trl import GRPOConfig, GRPOTrainer

USE_GPU = False  # Assuming CPU-only

config = GRPOConfig(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    num_generations=4,
    num_train_epochs=1,
    learning_rate=5e-6,
    logging_steps=2,
    no_cuda=not USE_GPU,
    report_to = "none"
)

# Instantiate GRPOTrainer
grpo_trainer = GRPOTrainer(
    model=model,
    args=config,
    reward_funcs=reward_func,
    train_dataset=train_dataset
)

grpo_trainer.train()


  ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


Step,Training Loss


TrainOutput(global_step=1, training_loss=0.0, metrics={'train_runtime': 893.512, 'train_samples_per_second': 0.002, 'train_steps_per_second': 0.001, 'total_flos': 0.0, 'train_loss': 0.0})

In [15]:
test_set = [
    {"prompt": "x^2 - 4 = 0"},
    {"prompt": "2 + 2 * 3?"},
    {"prompt": "Tell me a joke"},
    {"prompt": "Explain how photosynthesis works."}
]


In [16]:
def evaluate_model(model, tokenizer, test_set):
    model.eval()  # Put model in evaluation mode

    wrapper = LLMMathWrapper(model, tokenizer, MathSolver())
    prompts = [item["prompt"] for item in test_set]

    completions = []
    for prompt in prompts:
        try:
            completions.append(wrapper.run(prompt))
        except Exception as e:
            completions.append(f"[Error in generation: {e}]")

    rewards = reward_func(prompts, completions)

    return list(zip(prompts, completions, rewards))


In [17]:
print("Evaluating model AFTER GRPO training...")
results = evaluate_model(model, tokenizer, test_set)

for prompt, output, reward in results:
    print("="*30)
    print(f"Prompt: {prompt}")
    print(f"Output: {output}")
    print(f"Reward: {reward}")


Both `max_new_tokens` (=2048) and `max_length`(=100) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Evaluating model AFTER GRPO training...


Both `max_new_tokens` (=2048) and `max_length`(=100) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Prompt: x^2 - 4 = 0
Output: Symbolic: [-2, 2]
Numeric: [-2.00000000000000, 2.00000000000000]
Reward: 1.0
Prompt: 2 + 2 * 3?
Output: Math Error: Sympify of expression 'could not parse '2 + 2 * 3?'' failed, because of exception being raised:
SyntaxError: invalid syntax (<string>, line 1)
Reward: -1.0
Prompt: Tell me a joke
Output: Tell me a joke about a cat.

Why did the cat go to the vet?

Because it was feeling a bit "catatonic"!
Reward: 0.5
Prompt: Explain how photosynthesis works.
Output: Explain how photosynthesis works. Photosynthesis is the process by which plants, algae, and some bacteria convert light energy into chemical energy in the form of glucose. The process occurs in the chloroplasts of plant cells, where light energy is absorbed by the chlorophyll molecules in the thylakoid membranes. The absorbed light energy is used to split water molecules into hydrogen and oxygen, which are released into the atmosphere as oxygen. The remaining hydrogen is used to produce glucose, whi