In [20]:
!pip install unsloth vllm  
!pip install triton==3.1.0  
!pip install -U pynvml
!pip install wandb

Collecting unsloth
  Downloading unsloth-2025.3.19-py3-none-any.whl.metadata (46 kB)
Collecting unsloth_zoo>=2025.3.17 (from unsloth)
  Downloading unsloth_zoo-2025.3.17-py3-none-any.whl.metadata (8.0 kB)
Collecting xformers>=0.0.27.post2 (from unsloth)
  Downloading xformers-0.0.29.post3.tar.gz (8.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.5/8.5 MB[0m [31m23.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting bitsandbytes (from unsloth)
  Downloading bitsandbytes-0.42.0-py3-none-any.whl.metadata (9.9 kB)
Collecting tyro (from unsloth)
  Downloading tyro-0.9.18-py3-none-any.whl.metadata (9.2 kB)
Collecting trl!=0.15.0,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,<=0.15.2,>=0.7.9 (from unsloth)
  Downloading trl-0.15.2-py3-none-any.whl.metadata (11 kB)
Collecting peft!=0.11.0,>=0.7.1 (from unsloth)
  Downloading peft-0.15.1-py3-none-any.whl.metadata (13 kB)
Collecting protobuf<4.0.0 (from unsloth)
  Do

In [None]:
import os

import wandb

# Do not store an API key in the notebook. Export WANDB_API_KEY before starting
# the kernel; when it is unset (or blank) `key=None` is passed, which leaves
# credential discovery to wandb itself.
wandb.login(key = os.environ.get("WANDB_API_KEY") or None)

In [None]:
import sys

# Project checkout added to the import path; presumably where the local modules
# imported by later cells (prompt_template, utils, execserver) live.
PROJECT_ROOT = '/root/workspace/Align-CodeGemma'

# Guarded so that re-running this cell does not append duplicate entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
from unsloth import FastLanguageModel, PatchFastRL

# Apply Unsloth's GRPO patch to FastLanguageModel. This runs before the model is
# loaded below.
PatchFastRL("GRPO", FastLanguageModel)

from unsloth import is_bfloat16_supported
import torch

# Maximum sequence length passed to the loader, and the LoRA rank shared by the
# loader (max_lora_rank) and the adapter (r, lora_alpha).
max_seq_length = 8192 # Increase if you need longer reasoning traces.
lora_rank = 64         # Larger rank can improve performance but may slow down training.

# Load the instruction-tuned CodeGemma 7B model in 4-bit mode for reduced memory usage
# and enable fast inference with vLLM.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "google/codegemma-7b-it",
    max_seq_length = max_seq_length,
    load_in_4bit = True,           # Set to False if using LoRA in 16-bit precision.
    fast_inference = True,         # Enable vLLM for faster inference.
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.5,  # Adjust GPU memory usage to avoid out-of-memory errors.
)

# Wrap the model with PEFT (Parameter-Efficient Fine-Tuning) using LoRA.
# Note that `model` is rebound here: after this call it refers to the LoRA-wrapped model.
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,           # Use a rank greater than 0; common choices include 8, 16, 32, 64, or 128.
    lora_alpha = lora_rank,  # A higher lora_alpha gives the LoRA layers more influence on the model's output;
                             # a lower value reduces this influence. Here it is set equal to the rank.
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],                                       # Specify target modules; you can remove QKVO if memory is limited.
    use_gradient_checkpointing = "unsloth",  # Enable gradient checkpointing for long context finetuning.
    random_state = 3407,                     # Set a random seed for reproducibility.
)

In [1]:
import json
from datasets import load_dataset, Dataset, DatasetDict
from prompt_template import format_instruction
import pandas as pd
def load_and_format_json(json_path):
    """Build a single-column prompt dataset from a JSON file.

    The file at ``json_path`` is parsed with ``json.load``. For every entry, its
    ``"instruction"`` value is passed through ``format_instruction`` and stored
    under the ``"prompt"`` key. The rows are collected into a pandas DataFrame
    and converted with ``Dataset.from_pandas``.

    Returns:
        The resulting ``Dataset``.
    """
    with open(json_path, "r", encoding="utf-8") as handle:
        records = json.load(handle)

    rows = []
    for record in records:
        rows.append({"prompt": format_instruction(record["instruction"])})

    frame = pd.DataFrame(rows)
    return Dataset.from_pandas(frame)
import os
from pathlib import Path

# Directory that holds train_meta.json and test_meta.json. Set the
# ALIGN_CODEGEMMA_DATA_DIR environment variable to point at the data on machines
# where the repository lives elsewhere; when unset, the original location is used.
DATA_DIR = Path(os.environ.get(
    "ALIGN_CODEGEMMA_DATA_DIR",
    "/Users/pavankumartaddi/Desktop/Align-CodeGemma/datas",
))

train_dataset = load_and_format_json(DATA_DIR / "train_meta.json")
test_dataset = load_and_format_json(DATA_DIR / "test_meta.json")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Display the first formatted training example to sanity-check the prompt text.
train_dataset[0]

{'prompt': "\nYou are an expert AI assistant specializing in generating highly efficient, well-structured, and optimized code using JAX.  \nFollow these principles:\n1. **Prioritize efficiency**: Use the most optimal algorithms, minimize computational overhead, and leverage JAX's just-in-time (JIT) compilation and automatic differentiation capabilities for performance optimization. Use JAX primitives wherever applicable for better performance, such as `jax.jit`, `jax.grad`, `jax.vmap`, and `jax.pmap`.\n2. **Leverage JAX and standard libraries**: Use JAX's powerful vectorized operations (`jax.numpy`,jax.lax), automatic differentiation (`jax.grad`), and other built-in JAX functions to avoid unnecessary custom implementations. Take full advantage of JAX primitives for parallelism, batching.\n3. **Verify with test cases**: Always include test cases with assertions, including edge cases, to validate the correctness of the solution and to ensure robustness. When relevant, leverage JAX's prim

In [None]:
import re
from openai.types import Completion
from execserver.code_exec_reqs import exec_test_batched
from utils import JAX_LAX_OPERATIONS,JAX_LIBRARIES,JAX_PRIMITIVES,count_jax_usage
from typing import List, Dict, Any
from openai.types.chat import ChatCompletion  

import re
from typing import List

def run_tests_and_reward(
    prompts: List[str],
    completions: list,
    timeout: int = 60,
    tests: str = "",
    timeout_on_client: bool = False,
    **kwargs
) -> List[int]:
    """Send each completion's text to the execution server and return its results.

    Args:
        prompts: Prompts that produced the completions (not used here).
        completions: One entry per generated sample. Each entry may be a plain
            string (what the trainer supplies when the dataset prompts are plain
            strings, as they are in this notebook) or a list of message dicts
            whose first element carries the text under ``"content"``.
        timeout: Forwarded to ``exec_test_batched``.
        tests: Accepted for interface compatibility; not used here.
        timeout_on_client: Forwarded to ``exec_test_batched``.
        **kwargs: Extra per-sample columns the trainer may pass; ignored.

    Returns:
        Whatever ``exec_test_batched`` returns for the submitted batch.
    """
    server = "http://localhost:8000"

    codes = []
    for completion in completions:
        # Accept both completion formats. Indexing a plain string as if it were
        # a message list would silently treat every completion as empty.
        if isinstance(completion, str):
            text = completion
        elif completion and isinstance(completion[0], dict):
            text = completion[0].get("content") or ""
        else:
            text = ""

        code = text.strip()
        # The whole completion is treated as code (plus any tests it contains).
        # Empty completions keep the 0.0 placeholder so the batch stays aligned
        # with `completions`.
        # NOTE(review): the placeholder is a float, not a string - confirm that
        # exec_test_batched handles it as intended.
        codes.append(code if code else 0.0)

    # Run the whole batch through the execution server.
    return exec_test_batched(
        server, codes,
        timeout=timeout,
        timeout_on_client=timeout_on_client
    )
    

import re
from typing import List

def format_reward_func(
    prompts: List[str],
    completions: List[List[dict]],
    **kwargs
) -> List[float]:
    # Regex pattern checks for "think", "code", and "test" in the content, allowing for any text between them
    pattern = re.compile(
        r"\bthink\b.*\bcode\b.*\btest\b.*",  # Ensure these keywords appear in the content
        re.DOTALL
    )

    rewards = []
    for completion in completions:
        # Extract and strip the content to remove leading/trailing whitespace
        content = completion[0]["content"] if completion and "content" in completion[0] else ""
        
        # Check if the stripped content matches the pattern and assign reward
        rewards.append(1.0 if pattern.fullmatch(content) else 0.0)
    
    return rewards


def reward_based_on_jax_usage(
    prompts: List[str],
    completions: list,
    **kwargs
) -> List[float]:
    """Score each completion with ``count_jax_usage``.

    Args:
        prompts: Prompts that produced the completions (not used here).
        completions: One entry per generated sample. Each entry may be a plain
            string or a list of message dicts whose first element carries the
            text under ``"content"``.
        **kwargs: Extra per-sample columns the trainer may pass; ignored.

    Returns:
        The value of ``count_jax_usage`` for each completion's stripped text, in
        order; an empty list when there are no completions.
    """
    codes = []
    for completion in completions:
        # Accept both completion formats. Indexing a plain string as if it were
        # a message list would silently treat every completion as empty.
        if isinstance(completion, str):
            text = completion
        elif completion and isinstance(completion[0], dict):
            text = completion[0].get("content") or ""
        else:
            text = ""
        codes.append(text.strip())

    # count_jax_usage is imported from the project's utils module.
    return [count_jax_usage(code) for code in codes]


In [None]:
from trl import GRPOConfig, GRPOTrainer

# Configure GRPO training parameters.
# This configuration sets up the training hyperparameters, optimization settings, and inference acceleration via vLLM.
training_args = GRPOConfig(
    use_vllm = True,                     # Enable vLLM to accelerate inference during training.
    learning_rate = 5e-6,                # Set the learning rate for the optimizer.
    adam_beta1 = 0.9,                    # First beta parameter for the AdamW optimizer.
    adam_beta2 = 0.99,                   # Second beta parameter for the AdamW optimizer.
    weight_decay = 0.1,                  # Weight decay to regularize the model and prevent overfitting.
    warmup_ratio = 0.1,                  # Fraction of steps used for learning rate warmup.
    lr_scheduler_type = "cosine",        # Use cosine annealing for the learning rate scheduler.
    optim = "adamw_8bit",                # Use 8-bit AdamW optimizer for memory efficiency.
    logging_steps = 1,                   # Log training information every step.
    bf16 = is_bfloat16_supported(),      # Use bfloat16 precision if supported by the GPU.
    fp16 = not is_bfloat16_supported(),  # Otherwise, fall back to fp16 precision.
    per_device_train_batch_size = 1,     # Batch size per device during training.
    gradient_accumulation_steps = 1,     # Accumulate gradients over this many steps (increase for smoother training if needed).
    num_generations = 8,                 # Number of generations per prompt (reduce if memory issues occur).
    max_prompt_length = 256,             # Maximum length for the input prompt. NOTE(review): the formatted prompts shown earlier
                                         # in this notebook look far longer than this - confirm they are not being truncated.
    max_completion_length = 4096,         # Maximum length for the generated completion.
    num_train_epochs = 1,               # Number of training epochs. NOTE(review): max_steps is also set below and presumably
                                        # takes precedence - confirm which limit is intended.
    max_steps = 250,                     # Maximum number of training steps.
    save_steps = 250,                    # Save the model checkpoint every specified number of steps.
    max_grad_norm = 0.1,                 # Maximum gradient norm for gradient clipping.
    report_to = "wandb",                  # Report training metrics to Weights & Biases.
    output_dir = "/root/workspace/Align-CodeGemma/src/outputs",              # Directory to save the training outputs and checkpoints.
)

# Instantiate the GRPO trainer with the model, tokenizer, reward functions, and datasets.
trainer = GRPOTrainer(
    model = model,                       # The language model to be trained.
    processing_class = tokenizer,        # The tokenizer used to preprocess the data.
    reward_funcs = [
         run_tests_and_reward,         # Sends each completion to the code execution server.
         format_reward_func,           # Checks for the "think" / "code" / "test" keywords.
        reward_based_on_jax_usage     # Scores each completion with count_jax_usage.
    ],
    args = training_args,                # GRPO training configuration.
    train_dataset = train_dataset,       # Prompts loaded from train_meta.json (single "prompt" column).
    eval_dataset = test_dataset           # Prompts loaded from test_meta.json.
)

# Begin training using the GRPO algorithm.
trainer.train()

# Save the LoRA-adapted model for later use.
model.save_lora("grpo_saved_lora")

In [None]:
# Baseline check: format a single user question with the tokenizer's chat template
# (returned as text rather than token ids, with the generation prompt appended).
text = tokenizer.apply_chat_template([
    {"role" : "user", "content" : "How many r's are in strawberry?"},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams

# Set the sampling parameters for text generation.
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)

# Generate a response from the model without applying any LoRA adapter, then take
# the text of the first output for the single prompt.
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,  # No LoRA adapter is used here.
)[0].outputs[0].text

# Print the generated output.
print(output)

In [None]:
from vllm import SamplingParams

# Build the prompt the same way the training prompts were built: the "prompt"
# column given to the trainer is the plain string returned by format_instruction,
# which already embeds the assistant instructions, so no separate system prompt
# is needed here.
text = format_instruction("How many r's are in strawberry?")

# Set sampling parameters for controlled text generation.
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)

# Generate a response using the saved LoRA adapter, then take the text of the
# first output for the single prompt.
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = model.load_lora("grpo_saved_lora"),  # Load the saved LoRA adapter.
)[0].outputs[0].text

# Print the generated response.
print(output)