In [1]:
import unsloth
from unsloth import FastLanguageModel
import torch
import textstat
from vllm import SamplingParams
import os
from trl import GRPOConfig, GRPOTrainer
import click
import re
from datasets import Dataset
import wandb
import numpy as np
import json
from datasets import load_dataset

# ---------- Constants ----------
# Prompt template for one word problem. `{q}` is the only placeholder and is
# filled with the question text through PROMPT.format(q=...). The closing
# instruction requests an "Answer: XXX" line, which is the shape the
# ANSWER_PATTERN regex in this notebook searches for.
PROMPT = """Solve the following math word problem.

{q}

Think step-by-step. Then, provide the final answer as a single integer in the format "Answer: XXX" with no extra formatting."""


# ---------- Utility Functions ----------
def make_dataset(difficulty_level, dir_path='outputs/gsm8k_platinum/accuracy_subset', subset='train'):
    """Load one difficulty bucket from JSONL and shape it into prompt/answer rows.

    The file read is ``{dir_path}/{difficulty_level}_{subset}.jsonl``. Each row
    must provide a ``question`` field (rendered into ``PROMPT``) and a
    ``parsed`` field (copied to ``answer``; the ``parsed`` column is then
    dropped).

    Args:
        difficulty_level: Value interpolated into the file name.
        dir_path: Directory that holds the JSONL files.
        subset: File-name suffix selecting which file to read. It only affects
            the path; ``split="train"`` is always passed to ``load_dataset``.

    Returns:
        The loaded dataset with added ``prompt`` and ``answer`` columns and
        without the ``parsed`` column.
    """
    jsonl_path = f'{dir_path}/{difficulty_level}_{subset}.jsonl'
    records = load_dataset("json", data_files=jsonl_path, split="train")

    def _with_prompt(row):
        # Render the question into the shared prompt template.
        return {'prompt': PROMPT.format(q=row['question'])}

    def _with_answer(row):
        # Expose the reference value under the name "answer".
        return {"answer": row["parsed"]}

    with_prompts = records.map(_with_prompt)
    with_answers = with_prompts.map(_with_answer)
    return with_answers.remove_columns("parsed")

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 06-29 18:20:48 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 06-29 18:20:48 [__init__.py:239] Automatically detected platform cuda.


In [2]:
def load_model(model_name, adapter=None, max_seq_length=2048, dtype=None, load_in_4bit=True):
    """Load a model/tokenizer pair through Unsloth and prepare it for inference.

    Args:
        model_name: Checkpoint identifier or path forwarded to
            ``FastLanguageModel.from_pretrained``.
        adapter: Optional adapter path; when given it is attached with
            ``model.load_adapter(adapter)`` after the base model is loaded.
        max_seq_length: Maximum sequence length forwarded to Unsloth
            (default 2048, the value this notebook previously hard-coded).
        dtype: Forwarded to Unsloth; ``None`` requests auto detection
            (Float16 for Tesla T4/V100, Bfloat16 for Ampere+).
        load_in_4bit: Use 4-bit quantization to reduce memory usage.

    Returns:
        tuple: ``(model, tokenizer)`` as returned by
        ``FastLanguageModel.from_pretrained``, with the model switched to
        inference mode via ``FastLanguageModel.for_inference``.
    """
    model_, tokenizer_ = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=load_in_4bit,
        # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
    )
    if adapter is not None:
        model_.load_adapter(adapter)
    # Enable Unsloth's inference path; the return value is not needed.
    FastLanguageModel.for_inference(model_)
    return model_, tokenizer_

In [12]:
def get_answer_batch(tokenizer, model, prompts, num_times_to_repeat: int = 8, apply_template: bool = False,
                     max_new_tokens: int = 250, temperature: float = 0.9):
    """
    Sample several completions for each prompt in one batched generate() call.

    Args:
        tokenizer: Tokenizer used to (optionally) apply the chat template,
            encode the batch and decode the generations.
        model: Model exposing ``generate(**inputs, **generation_kwargs)``.
        prompts: Iterable of prompt strings.
        num_times_to_repeat: Number of samples drawn per prompt (must be >= 1).
        apply_template: When True, each prompt is wrapped as a single user turn
            with ``tokenizer.apply_chat_template``; otherwise the raw prompt
            text is tokenized as-is.
        max_new_tokens: Generation budget per sample (default 250).
        temperature: Sampling temperature (default 0.9).

    Returns:
        list[list[str]]: ``result[i]`` holds the ``num_times_to_repeat`` decoded
        completions for ``prompts[i]`` (prompt tokens stripped). An empty
        ``prompts`` yields ``[]``.

    Raises:
        ValueError: If ``num_times_to_repeat`` is smaller than 1.
    """
    if num_times_to_repeat < 1:
        raise ValueError("num_times_to_repeat must be >= 1")

    prompts = list(prompts)
    if not prompts:
        return []

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "use_cache": True,
        "temperature": temperature,
        "top_k": None,
        "do_sample": True,
    }

    if apply_template:
        base_prompts = [
            tokenizer.apply_chat_template(
                [{'role': 'user', 'content': prompt}],
                tokenize=False, add_generation_prompt=True)
            for prompt in prompts
        ]
    else:
        base_prompts = prompts

    # Repeat each prompt contiguously ([p0, p0, ..., p1, p1, ...]). The regrouping
    # step below relies on this layout; tiling the list (prompts * n) would
    # interleave prompts and attribute samples to the wrong prompt.
    all_formatted_prompts = [
        formatted for formatted in base_prompts for _ in range(num_times_to_repeat)
    ]

    # Decoder-only generation should see the real tokens right before the newly
    # generated ones, so pad on the left for this call and restore the
    # tokenizer's previous setting afterwards.
    previous_padding_side = getattr(tokenizer, "padding_side", None)
    tokenizer.padding_side = "left"
    try:
        inputs = tokenizer(all_formatted_prompts, return_tensors="pt", padding=True, truncation=True)
    finally:
        if previous_padding_side is not None:
            tokenizer.padding_side = previous_padding_side

    # Follow the model's device when it exposes one; fall back to the
    # previously hard-coded "cuda".
    inputs = inputs.to(getattr(model, "device", "cuda"))

    with torch.no_grad():  # Disable gradient computation for inference
        outputs = model.generate(**inputs, **generation_kwargs)

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_length = inputs["input_ids"].shape[1]
    decoded_outputs = tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)

    # Regroup the flat list into one list of samples per prompt.
    return [
        decoded_outputs[start:start + num_times_to_repeat]
        for start in range(0, len(decoded_outputs), num_times_to_repeat)
    ]

# "Answer:" followed by an optionally negative integer. The negative lookahead
# refuses numbers that keep going after the captured digits ("1,234", "3.5"),
# so they are reported as unparseable instead of being silently cut to "1" / "3".
# A trailing sentence period ("Answer: 14.") is still accepted.
ANSWER_PATTERN = re.compile(r"Answer:\s*(-?\d+)(?![.,]?\d)")
def parse_llm_answer(text):
    """
    Extracts the final answer from the LLM output.
    Expects the format: "Answer: XXX" where XXX is an integer.
    The first occurrence that matches is used.

    Args:
        text (str): The output from the LLM.

    Returns:
        int or None: The extracted integer answer, or None if the text is not
        a string, contains no well-formed "Answer: <integer>", or the digits
        cannot be converted.
    """
    if not isinstance(text, str):
        return None
    match = ANSWER_PATTERN.search(text)
    if match is None:
        return None
    try:
        return int(match.group(1))
    except ValueError:
        # e.g. a digit run longer than the interpreter's int-conversion limit
        return None

In [4]:
# Load the Qwen3-4B bnb-4bit checkpoint and, because an adapter path is passed,
# attach the LoRA weights stored under models/gsm8k/.../lora (load_model calls
# load_adapter whenever `adapter` is not None).
model, tokenizer = load_model('unsloth/Qwen3-4B-unsloth-bnb-4bit', 'models/gsm8k/difficulty37_8gen_1k_qwen4b/lora')

==((====))==  Unsloth 2025.6.8: Fast Qwen3 patching. Transformers: 4.52.4. vLLM: 0.8.5.
   \\   /|    NVIDIA A10G. Num GPUs = 1. Max memory: 22.069 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post2. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


In [5]:
# Build the evaluation set from the default directory, i.e. the file
# outputs/gsm8k_platinum/accuracy_subset/37_test.jsonl, then display its size.
ds = make_dataset(37, subset='test')
len(ds)

14

In [8]:
# Smoke test: keep the first four prompts and draw one sample per prompt.
# apply_template is left at its default (False), so the raw prompt text is
# tokenized without the chat template.
# NOTE(review): this cell carries execution count 8 while get_answer_batch is
# shown with count 12 — the result may come from an earlier definition of the
# function; confirm with Restart Kernel -> Run All.
prompts = [x['prompt'] for x in ds][:4]
outputs = get_answer_batch(tokenizer, model, prompts, num_times_to_repeat=1)

  out = torch_matmul(X, W.t(), out = out)
  out = torch_matmul(X, W, out = out)


In [13]:
# Same four prompts, one sample each, with apply_template=False spelled out
# (raw prompt text, no chat template).
new_outputs = get_answer_batch(tokenizer, model, prompts, num_times_to_repeat=1, apply_template=False)

  out = torch_matmul(X, W.t(), out = out)
  out = torch_matmul(X, W, out = out)


In [14]:
# Inspect the raw completion text: first prompt, first (and only) sample.
print(new_outputs[0][0])

0.

Okay, let's see. The problem is about the glee club and the football team ordering and eating pizzas, and we need to find out how many pizzas are left. Alright, let's break it down step by step.

First, the glee club ordered 20 pizzas. They ate 70% of them. So, to find out how many they ate, I need to calculate 70% of 20. Let me write that down. 70% is the same as 0.70 in decimal. So 20 multiplied by 0.70. Let me do that calculation. 20 * 0.7 is 14. So they ate 14 pizzas. That means the number of pizzas left for the glee club is the total ordered minus the ones eaten. So 20 - 14 = 6. So they have 6 left.

Now, the football team ordered twice as many pizzas as the glee club. The glee club ordered 20, so twice that is 20 * 2 = 40. So the football team ordered 40 pizzas. Then they ate 80%


In [17]:
# Parse the first sample of every prompt; parse_llm_answer yields None when the
# completion contains no "Answer: <integer>" match.
for n in new_outputs:
    print(parse_llm_answer(n[0]))

None
5
70
22
