In [None]:
!pip install unsloth vllm  
!pip install triton==3.1.0  
!pip install -U pynvml

In [None]:
from unsloth import FastLanguageModel, PatchFastRL

# Patch FastLanguageModel with Unsloth's GRPO-specific modifications.
# NOTE(review): the patch is applied here, before `trl` is imported in a later cell;
# presumably that ordering is required for it to take effect — confirm against the Unsloth docs.
PatchFastRL("GRPO", FastLanguageModel)

from unsloth import is_bfloat16_supported
import torch

# Set maximum sequence length and LoRA rank (controls the adaptation complexity).
max_seq_length = 1024  # Increase if you need longer reasoning traces.
lora_rank = 64         # Larger rank can improve performance but may slow down training.

# Load the CodeGemma 7B instruct checkpoint (google/codegemma-7b-it) in 4-bit mode for
# reduced memory usage, with vLLM-backed fast inference enabled.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "google/codegemma-7b-it",
    max_seq_length = max_seq_length,
    load_in_4bit = True,           # Set to False if using LoRA in 16-bit precision.
    fast_inference = True,         # Enable vLLM for faster inference.
    max_lora_rank = lora_rank,     # Kept equal to the rank passed to get_peft_model below.
    gpu_memory_utilization = 0.5,  # Lower this value if you hit out-of-memory errors.
)

# Wrap the model with PEFT (Parameter-Efficient Fine-Tuning) using LoRA.
# `model` is rebound to the LoRA-wrapped model returned by get_peft_model.
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,           # Use a rank greater than 0; common choices include 8, 16, 32, 64, or 128.
    lora_alpha = lora_rank,  # Set equal to the rank here. A higher lora_alpha gives the LoRA layers more
                             # influence on the model's output; a lower value reduces that influence.
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],                                       # Attention (QKVO) and MLP projections; you can remove QKVO if memory is limited.
    use_gradient_checkpointing = "unsloth",  # Enable gradient checkpointing for long context finetuning.
    random_state = 3407,                     # Fixed random seed for reproducibility.
)

In [15]:
import json
from datasets import load_dataset, Dataset, DatasetDict
from prompt_template import format_instruction
import pandas as pd
def load_and_format_json(json_path):
    """Build a Hugging Face ``Dataset`` of prompts from a JSON file.

    The file is read as UTF-8 JSON and iterated entry by entry. Each entry's
    ``"instruction"`` value is passed through ``format_instruction`` and stored
    under a single ``"prompt"`` column.

    Args:
        json_path: Path to the JSON file to read.

    Returns:
        The result of ``Dataset.from_pandas`` applied to the one-column frame.

    Raises:
        KeyError: If an entry has no ``"instruction"`` key.
    """
    with open(json_path, "r", encoding="utf-8") as handle:
        entries = json.load(handle)

    prompt_rows = []
    for entry in entries:
        prompt_rows.append({"prompt": format_instruction(entry["instruction"])})

    prompt_frame = pd.DataFrame(prompt_rows)
    return Dataset.from_pandas(prompt_frame)
import os

# Location of the prompt JSON. The default is the original author's local path; set the
# ALIGN_CODEGEMMA_DATASET environment variable to run the notebook on another machine.
DATASET_JSON_PATH = os.environ.get(
    "ALIGN_CODEGEMMA_DATASET",
    "/Users/pavankumartaddi/Desktop/Align-CodeGemma/outputs/test_meta.json",
)
dataset = load_and_format_json(DATASET_JSON_PATH)

In [14]:
from typing import List
import re
from openai.types import Completion
from execserver.code_exec_reqs import run_coverage_batched
from utils import JAX_LAX_OPERATIONS,JAX_LIBRARIES,JAX_PRIMITIVES,count_jax_usage
def run_tests_and_reward(completions: List[Completion], timeout=60, tests="", timeout_on_client=False) -> List[int]:
    server = "http://localhost:8000"
    codes = []
    for completion in completions:
        for choice in completion["choices"]:
            codes.append(choice["text"])
    coverage_results = run_coverage_batched(server, codes, tests, timeout, timeout_on_client)
    rewards = [1 if result and result > 0 else 0 for result in coverage_results]
    return rewards
def format_reward_func(completions, **kwargs):
    pattern = r"^<response>\s*<think>.*?</think>\s*<code>.*?</code>\s*<test>.*?</test>\s*</response>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content, re.DOTALL) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
def reward_based_on_jax_usage(completions: List[Completion]) -> List[float]:
    codes = []
    for completion in completions:
        for choice in completion["choices"]:
            codes.append(choice["text"])
    max_possible_score = len(JAX_LIBRARIES) + len(JAX_PRIMITIVES) + len(JAX_LAX_OPERATIONS)
    rewards = [
        count_jax_usage(code) / max_possible_score if max_possible_score > 0 else 0.0
        for code in codes
    ]
    return rewards

In [None]:
from trl import GRPOConfig, GRPOTrainer

# Configure GRPO training parameters.
# This configuration sets up the training hyperparameters, optimization settings, and inference acceleration via vLLM.
training_args = GRPOConfig(
    use_vllm = True,                     # Enable vLLM to accelerate inference during training.
    learning_rate = 5e-6,                # Learning rate for the optimizer.
    adam_beta1 = 0.9,                    # First beta parameter for the AdamW optimizer.
    adam_beta2 = 0.99,                   # Second beta parameter for the AdamW optimizer.
    weight_decay = 0.1,                  # Weight decay to regularize the model and prevent overfitting.
    warmup_ratio = 0.1,                  # Fraction of steps used for learning rate warmup.
    lr_scheduler_type = "cosine",        # Use cosine annealing for the learning rate scheduler.
    optim = "adamw_8bit",                # Use 8-bit AdamW optimizer for memory efficiency.
    logging_steps = 1,                   # Log training information every step.
    bf16 = is_bfloat16_supported(),      # Use bfloat16 precision if supported by the GPU.
    fp16 = not is_bfloat16_supported(),  # Otherwise fall back to fp16; exactly one of bf16/fp16 is enabled.
    per_device_train_batch_size = 1,     # Batch size per device during training.
    gradient_accumulation_steps = 1,     # Accumulate gradients over this many steps (increase for smoother training if needed).
    num_generations = 8,                 # Number of completions sampled per prompt (reduce if memory issues occur).
    max_prompt_length = 256,             # Maximum length for the input prompt.
    max_completion_length = 200,         # Maximum length for the generated completion.
    # num_train_epochs = 1,               # Uncomment this line to run training for one epoch.
    max_steps = 250,                     # Maximum number of training steps.
    save_steps = 250,                    # Checkpoint interval; equal to max_steps here.
    max_grad_norm = 0.1,                 # Maximum gradient norm for gradient clipping.
    report_to = "none",                  # Disable reporting to external services like WandB.
    output_dir = "outputs",              # Directory to save the training outputs and checkpoints.
)

# Instantiate the GRPO trainer with the model, tokenizer, reward functions, and training dataset.
trainer = GRPOTrainer(
    model = model,                       # The language model to be trained.
    processing_class = tokenizer,        # The tokenizer used to preprocess the data.
    reward_funcs = [
        xmlcount_reward_func,            # Reward function based on XML tag counts.
        soft_format_reward_func,         # Reward function checking for soft adherence to XML formatting.
        strict_format_reward_func,       # Reward function checking for strict XML formatting.
        int_reward_func,                 # Reward function that provides rewards for integer outputs.
        correctness_reward_func,         # Reward function evaluating the correctness of the answer.
    ],
    args = training_args,                # GRPO training configuration.
    train_dataset = dataset,             # The training dataset containing prompts and expected answers.
)

# Begin training using the GRPO algorithm.
trainer.train()

# Save the LoRA-adapted model for later use.
model.save_lora("grpo_saved_lora")

In [None]:
from vllm import SamplingParams

# Render a single user turn through the tokenizer's chat template, leaving the
# result as text and appending the generation prompt.
text = tokenizer.apply_chat_template(
    [{"role" : "user", "content" : "How many r's are in strawberry?"}],
    tokenize = False,
    add_generation_prompt = True,
)

# Sampling settings shared by the generation call below.
sampling_params = SamplingParams(temperature = 0.8, top_p = 0.95, max_tokens = 1024)

# Baseline generation: no LoRA adapter is attached (lora_request = None).
# Keep only the text of the first candidate for the single prompt.
output = model.fast_generate(
    [text], sampling_params = sampling_params, lora_request = None
)[0].outputs[0].text

# Show the baseline answer.
print(output)

In [None]:
from vllm import SamplingParams

# Build the prompt the same way the training prompts are built in load_and_format_json:
# run the question through format_instruction. A plain-string prompt is used as-is; a list
# of chat messages is rendered with the tokenizer's chat template. No separate "system"
# message is added: Gemma-family chat templates reject the system role.
prompt = format_instruction("How many r's are in strawberry?")
if isinstance(prompt, str):
    text = prompt
else:
    text = tokenizer.apply_chat_template(prompt, tokenize = False, add_generation_prompt = True)

# Set sampling parameters for controlled text generation.
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)

# Generate a response using the saved LoRA adapter.
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = model.load_lora("grpo_saved_lora"),  # Load the saved LoRA adapter.
)[0].outputs[0].text

# Print the generated response.
print(output)