<a href="https://colab.research.google.com/github/SreeyaSrikanth/RL-Prompt-Compression/blob/main/phi3_GRPO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
import os
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"

In [None]:
%%capture
import os
!pip install --upgrade -qqq uv
try: import numpy, PIL; get_numpy = f"numpy=={numpy.__version__}"; get_pil = f"pillow=={PIL.__version__}"
except: get_numpy = "numpy"; get_pil = "pillow"
try: import subprocess; is_t4 = "Tesla T4" in str(subprocess.check_output(["nvidia-smi"]))
except: is_t4 = False
get_vllm, get_triton = ("vllm==0.9.2", "triton==3.2.0") if is_t4 else ("vllm==0.10.2", "triton")
!uv pip install -qqq --upgrade \
    unsloth {get_vllm} {get_numpy} {get_pil} torchvision bitsandbytes xformers
!uv pip install -qqq {get_triton}
!uv pip install transformers==4.56.2
!uv pip install --no-deps trl==0.22.2
!uv pip install -qqq sentence-transformers

In [None]:
# Keep unsloth first: its startup banner reports that it patches the libraries loaded after it.
from unsloth import FastModel

import re

import diskcache
import numpy
import torch
from datasets import load_dataset, Dataset
from sentence_transformers import SentenceTransformer, CrossEncoder
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, TextStreamer
from trl import SFTTrainer

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.




INFO 10-27 15:31:51 [__init__.py:244] Automatically detected platform cuda.
ERROR 10-27 15:32:02 [fa_utils.py:57] Cannot use FA version 2 is not supported due to FA2 is only supported on devices with compute capability >= 8
🦥 Unsloth Zoo will now patch everything to make training faster!


In [None]:
# Choose the compute device once per kernel session; an existing `device` is left untouched.
if "device" not in globals():
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Device not set, defaulting to: {device}")

Device not set, defaulting to: cuda


Load up `Phi-3 4k Instruct`, and set parameters

In [None]:
## CELL: Load Model and Tokenizer (Phi-3 Mini 4-bit)

from unsloth import FastLanguageModel # Use FastLanguageModel for Phi-3
import torch

# Settings reused by later cells (the SFT trainer reads max_seq_length).
max_seq_length = 2048 # Match Fine_Tuning.ipynb
dtype = None # Auto detection for Phi-3

# Keyword arguments for the loader, gathered in one place so they are easy to scan.
phi3_load_options = dict(
    model_name = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit", # Switched back to Phi-3
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = True,  # Enable 4-bit quantization like Fine_Tuning.ipynb
    # token = "hf_...", # Add your Hugging Face token if needed
)

# Load the pre-quantized Phi-3 Mini checkpoint together with its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(**phi3_load_options)
print("Model and Tokenizer Loaded (Phi-3 Mini 4-bit).")

==((====))==  Unsloth 2025.10.10: Fast Mistral patching. Transformers: 4.56.2. vLLM: 0.9.2.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.3.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Model and Tokenizer Loaded (Phi-3 Mini 4-bit).


Add LoRA adapters so we only need to update a small amount of parameters

In [None]:
## CELL: Add LoRA Adapters (Matching Fine_Tuning.ipynb)

# Projection layers that receive LoRA adapters, listed explicitly as in Fine_Tuning.ipynb.
attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
mlp_projections = ["gate_proj", "up_proj", "down_proj"]

# NOTE(review): this rebinds `model` from itself, so re-running the cell presumably wraps an
# already-adapted model again - reload the base model before re-running.
model = FastLanguageModel.get_peft_model(
    model,
    # Adapter size (matched to Fine_Tuning.ipynb)
    r = 64,
    lora_alpha = 128,
    target_modules = attention_projections + mlp_projections,
    # Regularisation / bias handling (matched)
    lora_dropout = 0,
    bias = "none",
    # Use Unsloth's optimized gradient checkpointing
    use_gradient_checkpointing = "unsloth",
    random_state = 3407, # Matched
    use_rslora = False,
    loftq_config = None,
)
print("LoRA Adapters Added (Matching Fine_Tuning.ipynb Rank/Alpha/Modules).")
model.print_trainable_parameters() # Optional: See trainable parameters

Unsloth 2025.10.10 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


LoRA Adapters Added (Matching Fine_Tuning.ipynb Rank/Alpha/Modules).
trainable params: 119,537,664 || all params: 3,940,617,216 || trainable%: 3.0335


Prepare the GSM8K dataset

In [None]:
# Load the "train" split of the "main" configuration of openai/gsm8k.
dataset = load_dataset("openai/gsm8k", "main", split = "train")
dataset

Dataset({
    features: ['question', 'answer'],
    num_rows: 7473
})

In [None]:
# Keep only the first third of the rows (2,491 of 7,473 in the recorded run).
# NOTE(review): `dataset` is rebound from itself, so re-running this cell without
# reloading the data shrinks it to a third again.
dataset = dataset.select(range(len(dataset) // 3))
dataset

Dataset({
    features: ['question', 'answer'],
    num_rows: 2491
})

In [None]:
# Show the first row: the question is what gets compressed, the answer holds the
# reference reasoning followed by the "#### <number>" final answer.
print("New dataset example:")
print(f"Original Prompt: {dataset[0]['question']}")
print(f"Original Output: {dataset[0]['answer']}")

New dataset example:
Original Prompt: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
Original Output: Natalia sold 48/2 = <<48/2=24>>24 clips in May.
Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.
#### 72


In [None]:
def extract_hash_answer(text):
    """Return the last number found in ``text`` as a float, or ``None``.

    If ``text`` contains a GSM8K-style ``####`` marker, only the part after the
    last marker is searched; otherwise the whole string is searched.

    Args:
        text: Answer text, e.g. ``"... #### 72"``. ``None`` or ``""`` yields ``None``.

    Returns:
        The last number in the searched text as a ``float`` (thousands separators
        are removed, so ``"1,234.5"`` becomes ``1234.5``), or ``None`` when no
        number is present.
    """
    if not text:
        return None

    _, marker, tail = text.rpartition("####")
    text_to_search = tail.strip() if marker else text

    # A number is either comma-grouped ("1,234") or a plain digit run ("1200"), with an
    # optional fractional part; a bare fraction (".5") is also accepted. The grouped form
    # requires at least one ",ddd" group - without that requirement a plain run such as
    # "1200" would be matched three digits at a time ("120" then "0").
    number_pattern = r'[-+]?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?|[-+]?\.\d+'
    numbers = re.findall(number_pattern, text_to_search)
    if not numbers:
        return None
    # Every match is an optional sign, digits, commas and at most one dot, so float()
    # always succeeds once the commas are removed.
    return float(numbers[-1].replace(',', ''))
extract_hash_answer(dataset[0]["answer"])

72.0

In [None]:
def extract_reasoning(text):
    """Return the reasoning part of an answer: the text before the first ``####``.

    When the marker is absent the whole text is returned. Surrounding whitespace
    is stripped in both cases, so an empty input gives an empty string.
    """
    reasoning, _marker, _final_answer = text.partition("####")
    return reasoning.strip()
extract_reasoning(dataset[0]["answer"])

'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.'

In [None]:
# Instruction placed in the "system" message of every chat prompt built in this
# notebook (the GRPO prompts, the SFT warmup samples and the test generations).
system_prompt = (
    "You are an expert prompt compressor. Your task is to rewrite the given prompt "
    "to be as short as possible while ensuring that a large language model "
    "can still generate the same, high-quality response. Retain all key constraints "
    "and entities from the original prompt."
)
print(system_prompt)

You are an expert prompt compressor. Your task is to rewrite the given prompt to be as short as possible while ensuring that a large language model can still generate the same, high-quality response. Retain all key constraints and entities from the original prompt.


In [None]:
def _to_grpo_example(example):
    """Build the extra columns for one GSM8K row.

    Adds the chat-style ``prompt`` (system instruction + question), copies of the
    original question and answer, and the numeric answer parsed from the
    reference solution.
    """
    question = example["question"]
    answer = example["answer"]
    chat_prompt = [
        {"role": "system", "content": system_prompt},
        {"role": "user",   "content": question},
    ]
    return {
        "prompt" : chat_prompt,
        "original_prompt": question,
        "original_output": answer,
        "ground_truth_answer": extract_hash_answer(answer),
    }

dataset = dataset.map(_to_grpo_example)
dataset[0]

Map:   0%|          | 0/2491 [00:00<?, ? examples/s]

{'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
 'answer': 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72',
 'prompt': [{'content': 'You are an expert prompt compressor. Your task is to rewrite the given prompt to be as short as possible while ensuring that a large language model can still generate the same, high-quality response. Retain all key constraints and entities from the original prompt.',
   'role': 'system'},
  {'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
   'role': 'user'}],
 'original_prompt': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?

Create SFT Dataset and SFT Train the model

In [None]:
from datasets import Dataset

# Your provided dataset
warmup_samples = [
    {"question": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. In June, she sold 5 fewer clips than in May. How many clips did Natalia sell in total?", "compressed_question": "Natalia sold 48 clips in April, half as many in May, and 5 fewer than May in June. Total clips sold?"},
    {"question": "A bakery made 100 donuts in the morning. They sold 3/4 of them. In the afternoon, they made another 50 donuts. How many donuts did the bakery have at the end of the day?", "compressed_question": "A bakery had 100 donuts, sold 3/4, then made 50 more. How many donuts are left?"},
    {"question": "John is reading a book that is 450 pages long. He read 1/3 of the book on Monday and 100 pages on Tuesday. How many pages does he have left to read?", "compressed_question": "A 450-page book: John read 1/3 on Monday and 100 pages on Tuesday. How many pages are left?"},
    {"question": "A rectangular garden is 15 meters long and 8 meters wide. A fence is built around it. If the fencing material costs $12 per meter, what is the total cost of the fence?", "compressed_question": "A 15m by 8m rectangular garden needs a fence. Fencing costs $12/meter. What's the total cost?"},
    {"question": "Sarah has a budget of $200 for shopping. She buys a pair of shoes for $75 and a dress for $50. She then finds a handbag that is 20% off its original price of $60. Can she afford the handbag?", "compressed_question": "Sarah's budget is $200. She buys $75 shoes and a $50 dress. Can she afford a $60 handbag with a 20% discount?"},
    {"question": "A train travels at a speed of 80 kilometers per hour. It leaves station A at 9:00 AM and is scheduled to arrive at station B at 1:30 PM. How far is station B from station A?", "compressed_question": "A train travels at 80 km/h from 9:00 AM to 1:30 PM. What is the distance traveled?"},
    {"question": "A recipe for a cake requires 250 grams of flour, 150 grams of sugar, and 100 grams of butter. If you want to make 3 cakes, how much of each ingredient do you need in total?", "compressed_question": "A cake needs 250g flour, 150g sugar, and 100g butter. How much of each ingredient for 3 cakes?"},
    {"question": "Mark is saving for a new bike that costs $500. He already has $150 saved. If he saves $25 every week, how many weeks will it take him to save enough money for the bike?", "compressed_question": "A bike costs $500. Mark has $150 and saves $25 weekly. How many weeks until he can afford it?"},
    {"question": "There are 30 students in a class. 2/5 of them are boys. On a particular day, 1/3 of the boys are absent. How many boys are present in the class on that day?", "compressed_question": "A class has 30 students, 2/5 are boys. If 1/3 of the boys are absent, how many are present?"},
    {"question": "A water tank has a capacity of 5000 liters. It is currently 60% full. If water is being added to the tank at a rate of 100 liters per minute, how long will it take to fill the tank completely?", "compressed_question": "A 5000L tank is 60% full. If filled at 100L/min, how long until it's full?"},
    {"question": "A bookstore is having a sale where all books are 15% off. If a book originally costs $20, what is the sale price? If you pay with a $50 bill, how much change do you get?", "compressed_question": "A $20 book is 15% off. What is the sale price and change from $50?"},
    {"question": "A farmer has 120 chickens and cows in total. The number of chickens is three times the number of cows. How many chickens and how many cows does the farmer have?", "compressed_question": "A farmer has 120 chickens and cows. There are three times as many chickens as cows. How many of each?"},
    {"question": "A car's fuel tank holds 50 liters of gasoline. The car consumes 8 liters of gasoline per 100 kilometers. If the tank is full, how far can the car travel before it runs out of fuel?", "compressed_question": "A car has a 50L fuel tank and consumes 8L/100km. What is the car's maximum range on a full tank?"},
    {"question": "The sum of three consecutive integers is 147. What are the three integers?", "compressed_question": "The sum of three consecutive integers is 147. Find the integers."},
    {"question": "A company's profit was $500,000 in 2022. In 2023, the profit increased by 12%. What was the profit in 2023?", "compressed_question": "A company's profit was $500,000. It increased by 12% the next year. What was the new profit?"},
    {"question": "A library has 2500 books. 40% are fiction, 30% are non-fiction, and the rest are reference books. How many reference books are there in the library?", "compressed_question": "A library with 2500 books has 40% fiction and 30% non-fiction. How many are reference books?"},
    {"question": "A pizza is cut into 8 equal slices. Tom eats 3 slices, and Jane eats 2 slices. What fraction of the pizza is left?", "compressed_question": "A pizza has 8 slices. Tom eats 3 and Jane eats 2. What fraction remains?"},
    {"question": "A swimming pool is 25 meters long, 10 meters wide, and 2 meters deep. What is the volume of the pool in cubic meters?", "compressed_question": "A swimming pool is 25m long, 10m wide, and 2m deep. Calculate its volume."},
    {"question": "An airplane flies at an altitude of 35,000 feet. A submarine is at a depth of 1,500 feet below sea level. What is the vertical distance between the airplane and the submarine?", "compressed_question": "What is the vertical distance between an airplane at 35,000 feet altitude and a submarine 1,500 feet deep?"},
    {"question": "A factory produces 600 widgets per day. Due to a machine breakdown, the production is reduced by 25%. How many widgets are produced on that day?", "compressed_question": "A factory that produces 600 widgets per day has a 25% reduction in output. How many widgets are made?"},
    {"question": "If a shirt costs $45 and is on sale for 30% off, how much does it cost after the discount?", "compressed_question": "What is the price of a $45 shirt after a 30% discount?"},
    {"question": "A movie starts at 6:45 PM and lasts for 2 hours and 20 minutes. What time does the movie end?", "compressed_question": "A movie starts at 6:45 PM and runs for 2 hours 20 minutes. When does it end?"},
    {"question": "A garden has 12 rows of tomato plants. Each row has 8 plants. If each plant produces 5 tomatoes, what is the total number of tomatoes produced?", "compressed_question": "A garden has 12 rows with 8 tomato plants each. Each plant yields 5 tomatoes. What is the total tomato yield?"},
    {"question": "David weighs 80 kg. He goes on a diet and loses 15% of his weight. What is his new weight?", "compressed_question": "David, who weighs 80 kg, loses 15% of his weight. What is his new weight?"},
    {"question": "A car rental company charges $50 per day plus $0.20 per mile. If you rent a car for 3 days and drive 200 miles, what is the total rental cost?", "compressed_question": "Car rental costs $50/day plus $0.20/mile. What's the total cost for 3 days and 200 miles?"}
]

# System prompt needs to be defined (assuming it was in a previous cell)
# Example:
# system_prompt = (
#     "You are an expert prompt compressor..."
# )

# Convert list to Hugging Face Dataset
sft_dataset = Dataset.from_list(warmup_samples)

# Function to format data into chat template
def create_chat_prompt(example):
    """Render one warmup sample as a single chat-formatted training string.

    The conversation is: system instruction, the original question as the user
    turn, and the hand-written compressed question as the assistant turn (the
    target the model should learn to produce). Relies on the module-level
    ``tokenizer`` and ``system_prompt``.

    Returns:
        ``{"text": <conversation rendered with the tokenizer's chat template>}``
    """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": example["question"]},
        {"role": "assistant", "content": example["compressed_question"]},
    ]
    # tokenize=False returns the formatted string; no generation prompt is appended
    # because the assistant turn is already part of the conversation.
    rendered = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": rendered}

# Apply the chat formatting to every warmup sample
sft_dataset_formatted = sft_dataset.map(create_chat_prompt)

print(f"--- Example of formatted SFT data ---\n{sft_dataset_formatted[0]['text']}")
print("\nDataset Prepared and Formatted.")

Map:   0%|          | 0/25 [00:00<?, ? examples/s]

--- Example of formatted SFT data ---
<|system|>
You are an expert prompt compressor. Your task is to rewrite the given prompt to be as short as possible while ensuring that a large language model can still generate the same, high-quality response. Retain all key constraints and entities from the original prompt.<|end|>
<|user|>
Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. In June, she sold 5 fewer clips than in May. How many clips did Natalia sell in total?<|end|>
<|assistant|>
Natalia sold 48 clips in April, half as many in May, and 5 fewer than May in June. Total clips sold?<|end|>
<|endoftext|>

Dataset Prepared and Formatted.


In [None]:
## CELL: Configure Training Arguments & SFTTrainer (Matching Fine_Tuning.ipynb)

from transformers import TrainingArguments
from trl import SFTTrainer
import torch # Ensure torch is imported

# Decide the mixed-precision mode once: bf16 where the GPU reports support, fp16 otherwise.
use_bf16 = torch.cuda.is_bf16_supported()

# --- Training arguments (values mirror Fine_Tuning.ipynb unless noted) ---
sft_training_args = TrainingArguments(
    output_dir="outputs_phi3_sft",
    seed=3407,
    report_to="none",
    # Batching: 2 samples per device x 4 accumulation steps = effective batch size 8.
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=10,                   # Kept at 10 for the small (25-example) dataset
    # Optimizer and learning-rate schedule
    optim="adamw_8bit",
    learning_rate=2e-4,
    weight_decay=0.01,
    lr_scheduler_type="linear",
    warmup_steps=10,
    # Precision
    fp16=not use_bf16,
    bf16=use_bf16,
    # Logging and checkpointing
    logging_steps=25,
    save_strategy="epoch",
    save_total_limit=2,
    # Data loading
    remove_unused_columns=True,
    dataloader_pin_memory=False,
)
print("Training Arguments Configured (Matched to Fine_Tuning.ipynb).")

# --- Trainer over the formatted warmup dataset ---
sft_trainer = SFTTrainer(
    model=model,                           # Phi-3 with the LoRA adapters attached
    tokenizer=tokenizer,
    args=sft_training_args,
    train_dataset=sft_dataset_formatted,
    dataset_text_field="text",             # Column produced by create_chat_prompt
    dataset_num_proc=2,
    max_seq_length=max_seq_length,
    packing=False,                         # One conversation per training sequence
)
print("SFTTrainer Initialized.")

# --- Run the warmup fine-tuning ---
print("\n--- Starting SFT (Warmup) ---")
sft_trainer.train()
print("--- SFT (Warmup) Complete ---")

Training Arguments Configured (Matched to Fine_Tuning.ipynb).


Unsloth: Tokenizing ["text"] (num_proc=6):   0%|          | 0/25 [00:00<?, ? examples/s]

SFTTrainer Initialized.

--- Starting SFT (Warmup) ---


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 25 | Num Epochs = 10 | Total steps = 40
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 119,537,664 of 3,940,617,216 (3.03% trained)


Step,Training Loss
25,0.4897


--- SFT (Warmup) Complete ---


In [None]:
## CELL: Testing the SFT Model

from transformers import TextStreamer

# Enable inference mode for the model (important after training)
FastLanguageModel.for_inference(model)

print("\n--- Testing SFT-warmed-up Model for Prompt Compression ---")
test_prompt = "A restaurant sold 80 pizzas on Friday. On Saturday, they sold 110 pizzas. On Sunday, they sold 130 pizzas. What is the average number of pizzas sold per day over the weekend?"

# Same conversation layout as the training data, minus the assistant turn.
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user",   "content": test_prompt},
]

# return_dict=True yields both input_ids and attention_mask. Passing the mask to
# generate() avoids the "attention mask is not set" warning, which is triggered here
# because the pad token is set to the EOS token.
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize = True,
    add_generation_prompt = True, # Signal for generation
    return_dict = True,
    return_tensors = "pt",
).to(model.device)

print(f"\nOriginal Question for Test: {test_prompt}\n")
print("Generated Compressed Question:")

# Generate the compressed prompt; the streamer prints tokens as they are produced.
_ = model.generate(
    **inputs,
    max_new_tokens = 128, # Max length for compressed output
    use_cache = True,
    do_sample = True,     # Required for temperature / top_p to take effect
    temperature = 0.7,    # You can adjust generation parameters
    top_p = 0.9,
    streamer = TextStreamer(tokenizer, skip_prompt = True),
    pad_token_id = tokenizer.eos_token_id, # Set pad token ID
)

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



--- Testing SFT-warmed-up Model for Prompt Compression ---

Original Question for Test: A restaurant sold 80 pizzas on Friday. On Saturday, they sold 110 pizzas. On Sunday, they sold 130 pizzas. What is the average number of pizzas sold per day over the weekend?

Generated Compressed Question:
A restaurant sold 80 pizzas on Friday, 110 on Saturday, and 130 on Sunday. What is the average number of pizzas sold per day over the weekend?<|end|>


Load up TinyLlama Evaluator, Similarity and Reasoning Embedder for Reward Function

In [None]:
# Identifiers of the three frozen models the reward function relies on.
evaluator_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
cross_encoder_model_name = 'cross-encoder/ms-marco-MiniLM-L-6-v2'
embedder_model_name = "all-MiniLM-L6-v2"

# 1. Frozen evaluator LLM (TinyLlama): answers the original and the compressed prompts.
# NOTE(review): the weights are requested in bfloat16 although the Unsloth banner recorded
# in this notebook reports "Bfloat16 = FALSE" for the T4 - confirm this dtype is intended.
print(f"Loading frozen LLM: {evaluator_name}")
evaluator_tokenizer = AutoTokenizer.from_pretrained(evaluator_name)
evaluator_model = AutoModelForCausalLM.from_pretrained(
    evaluator_name,
    dtype=torch.bfloat16,
    device_map="cuda:0", # Assumes a single GPU
).eval()  # evaluation mode; eval() returns the model itself
print("Frozen LLM loaded.")

# 2. Cross-encoder scoring (original prompt, compressed prompt) pairs.
print(f"Loading cross encoder model: {cross_encoder_model_name}")
cross_encoder = CrossEncoder(cross_encoder_model_name, device="cuda:0")
print("cross encoder model loaded.")

# 3. Sentence embedder used to compare the evaluator's reasoning texts.
print(f"Loading embedder model: {embedder_model_name}")
embedder = SentenceTransformer(embedder_model_name, device="cuda:0")
print("Embedder loaded.")

Loading frozen LLM: TinyLlama/TinyLlama-1.1B-Chat-v1.0
Frozen LLM loaded.
Loading cross encoder model: cross-encoder/ms-marco-MiniLM-L-6-v2
cross encoder model loaded.
Loading embedder model: all-MiniLM-L6-v2
Embedder loaded.


Create the reward function

In [None]:
# Open the on-disk cache only once per kernel session; an existing `cache` is reused.
if "cache" not in globals():
    print("Initializing diskcache...")
    cache = diskcache.Cache('./evaluator_cache_tinyllama')
    print("Diskcache initialized.")

In [None]:
@cache.memoize()
def get_evaluator_output_with_reasoning(prompt, max_new_tokens=256):
    """Ask the frozen evaluator model to solve ``prompt`` step by step.

    The problem is wrapped in a user turn that requests worked reasoning and a
    final answer after ``####``. Results are memoized through ``cache.memoize()``.

    Args:
        prompt: The math problem text (original or compressed).
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated continuation only (prompt tokens removed), decoded without
        special tokens and stripped of surrounding whitespace.
    """
    chat_prompt = (
        "<|user|>\n"
        "Think step-by-step to solve the following math problem. Show your work.\n"
        "Problem: {problem}\n\n"
        "Provide your final answer after ####."
        "</s>\n<|assistant|>"
    ).format(problem=prompt)

    encoded = evaluator_tokenizer(chat_prompt, return_tensors="pt").to(evaluator_model.device)
    prompt_token_count = encoded['input_ids'].shape[1]

    with torch.no_grad():
        generated = evaluator_model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            pad_token_id=evaluator_tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; keep only the newly generated tokens.
    completion_ids = generated[0][prompt_token_count:]
    return evaluator_tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

In [None]:
# Default weight of each reward component in the final reward (the weights sum to 1.0).
DEFAULT_REWARD_WEIGHTS = {
    'r_correct': 0.5,   # evaluator's answer on the compressed prompt matches the ground truth
    'r_reason': 0.2,    # similarity of the evaluator's reasoning (original vs. compressed)
    'r_comp': 0.15,     # word-level compression ratio
    'r_sem': 0.15,      # cross-encoder score between original and compressed prompt
}


def _clip_finite(value, low, high):
    """Clamp ``value`` to ``[low, high]`` as a plain float, mapping NaN to 0.0 first."""
    if numpy.isnan(value):
        value = 0.0
    return float(numpy.clip(value, low, high))


def compute_reward_metrics(
    orig_prompt, comp_prompt, evaluator_out_original, evaluator_out_comp, ground_truth_answer,
    weights=None,
    ):
    """Score one compressed prompt against its original.

    Args:
        orig_prompt: The original (uncompressed) question.
        comp_prompt: The compressed question produced by the policy model.
        evaluator_out_original: Evaluator output for ``orig_prompt``.
        evaluator_out_comp: Evaluator output for ``comp_prompt``.
        ground_truth_answer: Reference numeric answer, or ``None`` if unknown.
        weights: Optional mapping from component name (``'r_correct'``,
            ``'r_reason'``, ``'r_comp'``, ``'r_sem'``) to its weight. Defaults to
            ``DEFAULT_REWARD_WEIGHTS``. Every key must name a computed component.

    Returns:
        dict with ``r_sem``, ``r_comp``, ``r_correct``, ``r_reason`` and the
        weighted sum ``final_reward``, all as plain floats.
    """
    if weights is None:
        weights = DEFAULT_REWARD_WEIGHTS
    metrics = {}

    # 1. Semantic similarity (r_sem): cross-encoder score for the prompt pair, clipped to [0, 1].
    # NOTE(review): ms-marco cross-encoders may return unbounded relevance logits rather than
    # probabilities, in which case this clip saturates at 0 or 1 - confirm which activation
    # CrossEncoder.predict applies for this model. (A bi-encoder cosine similarity between the
    # two prompts would be an alternative.)
    r_sem = 0.0
    if orig_prompt and comp_prompt:
        try:
            cross_enc_score = cross_encoder.predict([(orig_prompt, comp_prompt)])
            r_sem = float(cross_enc_score[0])
        except Exception as e:
            # Best effort: a scoring failure gives no semantic reward instead of aborting training.
            print(f"CrossEncoder prediction failed: {e}") # Debugging
            r_sem = 0.0
    metrics['r_sem'] = _clip_finite(r_sem, 0.0, 1.0)

    # 2. Compression ratio (r_comp): fraction of whitespace-separated words removed,
    # floored at -1.0 for outputs more than twice as long as the original.
    # NOTE(review): an empty completion scores the maximum 1.0 here.
    orig_len = len((orig_prompt or "").split())
    comp_len = len((comp_prompt or "").split())
    r_comp = 1.0 - (comp_len / max(1, orig_len))
    metrics['r_comp'] = max(r_comp, -1.0)

    # 3. Correctness (r_correct): 1.0 when the evaluator's parsed answer on the compressed
    # prompt is within 0.01 of the ground truth, otherwise 0.0.
    parsed_evaluator_answer = extract_hash_answer(evaluator_out_comp)
    r_correct = 0.0
    if parsed_evaluator_answer is not None and ground_truth_answer is not None:
        if abs(parsed_evaluator_answer - ground_truth_answer) < 1e-2:
            r_correct = 1.0
    metrics['r_correct'] = r_correct

    # 4. Reasoning consistency (r_reason): cosine similarity between embeddings of the
    # evaluator's reasoning for the original and for the compressed prompt.
    reasoning_orig = extract_reasoning(evaluator_out_original)
    reasoning_comp = extract_reasoning(evaluator_out_comp)
    r_reason = 0.0
    if reasoning_orig and reasoning_comp:
        with torch.no_grad():
            emb_reason_orig = embedder.encode([reasoning_orig], normalize_embeddings=True, device=device)
            emb_reason_comp = embedder.encode([reasoning_comp], normalize_embeddings=True, device=device)
        r_reason = float(cosine_similarity(emb_reason_orig, emb_reason_comp)[0][0])
    metrics['r_reason'] = _clip_finite(r_reason, -1.0, 1.0)

    # 5. Final weighted reward.
    final_reward = sum(weight * metrics[name] for name, weight in weights.items())
    metrics['final_reward'] = float(final_reward if not numpy.isnan(final_reward) else 0.0)

    return metrics

print("compute_reward_metrics function defined with new final reward weighting.")

Updated compute_all_reward_metrics_v4 function defined with new final reward weighting.


In [None]:
def _completion_texts(completion):
    """Return the text(s) carried by one completion as a list of strings.

    A plain string is returned as a one-element list; a chat-style completion
    (list of message dicts) yields the ``"content"`` of each message.
    """
    if isinstance(completion, str):
        return [completion]
    return [message["content"] for message in completion]


def final_weighted_reward_func(prompts, completions, **kwargs):
    """Reward function for ``GRPOTrainer``: one weighted reward per completion.

    Args:
        prompts: One prompt per completion - either a list of chat messages (the
            last message is taken as the original question) or a plain string.
        completions: The generated compressed prompts, index-aligned with
            ``prompts``. Each entry is a plain string or a list of message dicts;
            if a list holds several messages their rewards are averaged.
        **kwargs: Extra dataset columns. ``ground_truth_answer`` (one value per
            completion) is used when present; without it the correctness
            component is 0 for every completion.

    Returns:
        A list of floats with the same length as ``completions``.
    """
    ground_truth_answers = kwargs.get("ground_truth_answer")
    rewards = []

    for i, completion in enumerate(completions):
        prompt = prompts[i]
        # Chat-style prompts are lists of message dicts; the question is the last message.
        original_prompt = prompt[-1]["content"] if isinstance(prompt, list) else prompt
        ground_truth_answer = ground_truth_answers[i] if ground_truth_answers is not None else None

        # Evaluator output (with reasoning) for the original prompt.
        evaluator_out_original = get_evaluator_output_with_reasoning(original_prompt)

        generation_rewards = []
        for comp_prompt in _completion_texts(completion):
            # Evaluator output (with reasoning) for the compressed prompt.
            evaluator_out_comp = get_evaluator_output_with_reasoning(comp_prompt)
            metrics = compute_reward_metrics(
                original_prompt, comp_prompt, evaluator_out_original, evaluator_out_comp, ground_truth_answer
            )
            generation_rewards.append(metrics['final_reward'])

        # An empty completion (no messages) earns no reward.
        if generation_rewards:
            rewards.append(float(sum(generation_rewards) / len(generation_rewards)))
        else:
            rewards.append(0.0)

    return rewards

Set up GRPO Trainer and configurations

In [None]:
# Length limits handed to GRPOConfig below (presumably counted in tokens - TODO confirm).
max_prompt_length = 512  # Max length of original prompt + system prompt
max_completion_length = 256 # Max length of the *compressed* prompt

from trl import GRPOConfig, GRPOTrainer
# GRPO training configuration.
training_args = GRPOConfig(
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_torch_fused",
    logging_steps = 1,
    # NOTE: the Unsloth message recorded for this cell says a batch size of 1 is changed
    # to `num_generations` (4), so the effective per-device batch size is 4.
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 4,
    num_generations = 4,
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length,
    #num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 25, # Keep this low for testing
    # NOTE(review): save_steps (50) exceeds max_steps (25), so presumably no intermediate
    # checkpoint is written during this short run - confirm that is intended.
    save_steps = 50,
    max_grad_norm = 0.1,
    report_to = "none", # Stays as "none"
    output_dir = "outputs",
)

Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 4


And let's run the trainer! While it runs, the training cell prints a table of rewards like the sample below. The goal is to see the `reward` column increase!

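You might have to wait 150 to 200 steps for any action. You'll probably get 0 reward for the first 100 steps. Please be patient! Note that this notebook caps training at `max_steps = 25` for testing, so raise that limit (or switch to `num_train_epochs`) for a run long enough to show this.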

| Step | Training Loss | reward    | reward_std | completion_length | kl       |
|------|---------------|-----------|------------|-------------------|----------|
| 1    | 0.000000      | 0.125000  | 0.000000   | 200.000000        | 0.000000 |
| 2    | 0.000000      | 0.072375  | 0.248112   | 200.000000        | 0.000000 |
| 3    | 0.000000      | -0.079000 | 0.163776   | 182.500000        | 0.000005 |


In [None]:
# Build the GRPO trainer: the LoRA-adapted Phi-3 is the policy, and the single reward
# function scores each generated compression of a prompt from `dataset`.
# NOTE: the run recorded in this notebook ended with a CUDA out-of-memory error on the T4.
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [final_weighted_reward_func],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 2,491 | Num Epochs = 1 | Total steps = 120
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 4 x 1) = 16
 "-____-"     Trainable parameters = 119,537,664 of 3,940,617,216 (3.03% trained)


Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,kl,rewards / final_weighted_reward_func / mean,rewards / final_weighted_reward_func / std


OutOfMemoryError: CUDA out of memory. Tried to allocate 124.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 50.12 MiB is free. Process 2424 has 14.69 GiB memory in use. Of the allocated memory 13.64 GiB is allocated by PyTorch, and 917.40 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Test the model with a sample input

In [None]:
# Sample conversation: system instruction plus one question to compress.
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user",   "content": (
        "If a shirt costs $45 and is on sale for 30% off, how much does it cost after the discount?"
    )},
]

# Render the conversation as text, ending with the generation prompt for the assistant turn.
text = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True,
    tokenize = False,
)
from transformers import TextStreamer
_ = model.generate(
    **tokenizer(text, return_tensors = "pt").to("cuda"),
    max_new_tokens = max_completion_length, # Same limit as used during GRPO training
    do_sample = True, # Required for temperature / top_p / top_k to take effect
    temperature = 1.0, top_p = 0.95, top_k = 64,
    streamer = TextStreamer(tokenizer, skip_prompt = True),
)

Save the final model as LoRA adapters: use Hugging Face's `push_to_hub` for an online save or `save_pretrained` for a local save.

In [None]:
import pandas as pd

# The trainer state keeps one dict per logging event.
log_history = trainer.state.log_history

# Turn the list of dicts into a DataFrame, dropping the last entry.
# NOTE(review): the last entry is assumed to hold only end-of-run summary stats; if
# training stopped early (e.g. on an error) it may be a regular step log - verify.
metrics_df = pd.DataFrame(log_history[:-1])

# Print the full table
print("--- Training Metrics History ---")
print(metrics_df.to_string()) # .to_string() helps display all rows/columns

# The final summary metrics are also available from the value returned by train(),
# if it was captured:
# train_result = trainer.train()
# print("\n--- Final Training Summary ---")
# print(train_result.metrics)

In [None]:
import os
import shutil
from google.colab import files

# Target directory for the adapters and the archive that is built from it.
output_directory = "grpo_phi3_adapters" # Choose a directory name
zip_filename = f"{output_directory}.zip"

# --- 1. Save the adapters and the tokenizer side by side in the Colab filesystem ---
print(f"Saving model adapters and tokenizer to ./{output_directory}...")
for artifact in (model, tokenizer):
    artifact.save_pretrained(output_directory)
print("Saving complete.")

# --- 2. Pack the directory contents into <output_directory>.zip ---
print(f"Zipping the directory into {zip_filename}...")
shutil.make_archive(output_directory, 'zip', output_directory)
print("Zipping complete.")

# --- 3. Hand the archive to the browser for download (Colab only) ---
print(f"Starting download of {zip_filename}. Please wait...")
files.download(zip_filename)
print("Download initiated.")

In [None]:
# Local save of the trained model and its tokenizer. The directory is named after the
# model used in this notebook (Phi-3) rather than the "gemma-3" template leftover.
model.save_pretrained("phi-3")  # Local saving
tokenizer.save_pretrained("phi-3")
# model.push_to_hub("HF_ACCOUNT/phi-3", token = "...") # Online saving
# tokenizer.push_to_hub("HF_ACCOUNT/phi-3", token = "...") # Online saving