In [None]:
!pip install unsloth vllm  
!pip install triton==3.1.0  
!pip install -U pynvml
!pip install wandb

In [None]:

from transformers import AutoTokenizer

# Load the tokenizer that belongs to the `google/gemma-3-4b-it` checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="google/gemma-3-4b-it",
)

In [None]:
import os

import wandb

# Never paste the API key into the notebook. The key is taken from the
# WANDB_API_KEY environment variable when it is set; when it is not, `key=None`
# leaves the credential lookup to wandb itself.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
import sys

# Add the project checkout to the import path so that project-local modules
# can be imported by the cells below.
# NOTE(review): machine-specific absolute path - adjust when running elsewhere.
PROJECT_ROOT = '/root/workspace/Align-CodeGemma'
if PROJECT_ROOT not in sys.path:  # keeps the cell idempotent when it is re-run
    sys.path.append(PROJECT_ROOT)

In [None]:
from trl import GRPOConfig, GRPOTrainer, get_peft_config, ModelConfig

# NOTE: this cell uses `run_tests_and_reward`, `reward_based_on_jax_usage`,
# `train_dataset` and `test_dataset`, which are defined in cells further down
# the notebook. Run those cells first, otherwise this cell raises NameError.

# Description of the policy model. In this cell it is only used for its model
# id and to build the PEFT configuration through `get_peft_config`.
model_config = ModelConfig(
    model_name_or_path="google/gemma-3-4b-it",
    torch_dtype="bfloat16",
    attn_implementation="flash_attention_2",
    use_peft=True,
    load_in_4bit=False,
)

# Hyperparameters
training_args = GRPOConfig(
    # Do not write checkpoints into a local "google/gemma-3-4b-it" directory:
    # a local path with the same name as the Hub model id is picked up as a
    # local checkpoint by later `from_pretrained("google/gemma-3-4b-it")` calls.
    output_dir="outputs/gemma-3-4b-it-grpo",
    learning_rate=5e-7,
    lr_scheduler_type="cosine",
    logging_steps=10,
    max_steps=100,
    # TRL requires the effective batch size to be divisible by
    # `num_generations`; with a batch size of 1 and 2 generations a
    # single-process run is rejected, so the batch size matches num_generations.
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    bf16=True,
    # GRPO specific parameters
    # NOTE(review): the formatted prompts carry a long instruction preamble;
    # confirm that 256 tokens is enough for them.
    max_prompt_length=256,
    max_completion_length=1024,  # max length of the generated solution
    num_generations=2,  # completions sampled per prompt
    beta=0.001,
)

# The model is passed by id, so the trainer loads it itself; the LoRA setup
# comes from `model_config` via `get_peft_config`.
trainer = GRPOTrainer(
    model=model_config.model_name_or_path,
    reward_funcs=[run_tests_and_reward, reward_based_on_jax_usage],
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    peft_config=get_peft_config(model_config),
)

In [1]:
import json
from datasets import load_dataset, Dataset, DatasetDict
from prompt_template import format_instruction
import pandas as pd
def load_and_format_json(json_path):
    """Build a prompt-only Hugging Face dataset from a JSON file.

    The file at ``json_path`` is read as UTF-8 JSON and iterated entry by
    entry. Each entry's ``"instruction"`` value is passed through
    ``format_instruction`` and stored under the single column ``"prompt"``;
    every other field of the entry is dropped.

    Args:
        json_path: Path to the JSON file to read.

    Returns:
        A ``datasets.Dataset`` created from a pandas DataFrame of the
        formatted prompts.
    """
    with open(json_path, "r", encoding="utf-8") as handle:
        records = json.load(handle)

    prompt_rows = []
    for record in records:
        prompt_rows.append({"prompt": format_instruction(record["instruction"])})

    return Dataset.from_pandas(pd.DataFrame(prompt_rows))
import os

# Directory holding the JSON splits. The default is the original author's local
# checkout; set ALIGN_CODEGEMMA_DATA_DIR to run the notebook on another machine
# without editing this cell.
DATA_DIR = os.environ.get(
    "ALIGN_CODEGEMMA_DATA_DIR",
    "/Users/pavankumartaddi/Desktop/Align-CodeGemma/datas",
)
train_dataset = load_and_format_json(os.path.join(DATA_DIR, "train_meta.json"))
test_dataset = load_and_format_json(os.path.join(DATA_DIR, "test_meta.json"))

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Display the first formatted training example as a quick sanity check.
train_dataset[0]

{'prompt': "\nYou are an expert AI assistant specializing in generating highly efficient, well-structured, and optimized code using JAX.  \nFollow these principles:\n1. **Prioritize efficiency**: Use the most optimal algorithms, minimize computational overhead, and leverage JAX's just-in-time (JIT) compilation and automatic differentiation capabilities for performance optimization. Use JAX primitives wherever applicable for better performance, such as `jax.jit`, `jax.grad`, `jax.vmap`, and `jax.pmap`.\n2. **Leverage JAX and standard libraries**: Use JAX's powerful vectorized operations (`jax.numpy`,jax.lax), automatic differentiation (`jax.grad`), and other built-in JAX functions to avoid unnecessary custom implementations. Take full advantage of JAX primitives for parallelism, batching.\n3. **Verify with test cases**: Always include test cases with assertions, including edge cases, to validate the correctness of the solution and to ensure robustness. When relevant, leverage JAX's prim

In [None]:
import re
from openai.types import Completion
from execserver.code_exec_reqs import exec_test_batched
import utils
from typing import List, Dict, Any
from openai.types.chat import ChatCompletion  

import re
from typing import List

import re

def run_tests_and_reward(
    prompts: List[str],
    completions: List[List[dict]],
    timeout: int = 60,
    tests: str = "",
    timeout_on_client: bool = False,
    server: str = "http://localhost:8000",
    **kwargs,
) -> List[float]:
    """Reward completions by running the Python code they contain.

    For every completion the ``python`` code blocks are extracted with
    ``utils.find_code_blocks``, joined with blank lines, and sent in one batch
    to the execution server through ``exec_test_batched``. Completions without
    any code block are never sent to the server and receive a reward of 0.0.

    Args:
        prompts: Prompts the completions were generated for (not used here).
        completions: One entry per generated completion. Both formats a GRPO
            trainer can pass are accepted: a plain string (standard dataset
            format) or a list of message dicts whose first element holds the
            text under ``"content"`` (conversational format).
        timeout: Forwarded to ``exec_test_batched``.
        tests: Currently unused; kept for interface compatibility.
        timeout_on_client: Forwarded to ``exec_test_batched``.
        server: Base URL of the code execution server.
        **kwargs: Extra keyword arguments supplied by the trainer (for example
            additional dataset columns); ignored.

    Returns:
        A list with one reward per completion, in input order. Entries for
        executed code are whatever ``exec_test_batched`` returned for them
        (assumed to be one numeric result per submitted code, in order -
        TODO confirm); entries without code are 0.0.
    """
    import utils  # project helper that provides `find_code_blocks`

    rewards = [0.0] * len(completions)
    runnable_indices = []
    runnable_codes = []

    for index, completion in enumerate(completions):
        # Standard-format datasets yield plain strings, conversational ones a
        # list of message dicts.
        if isinstance(completion, str):
            content = completion
        elif completion and isinstance(completion[0], dict):
            content = completion[0].get("content") or ""
        else:
            content = ""

        python_code_blocks = (
            utils.find_code_blocks(content, tag="python") if content else []
        )
        if not python_code_blocks:
            continue  # nothing to execute: the reward stays at 0.0

        # Join the blocks and undo escaped triple quotes.
        full_code = "\n\n".join(python_code_blocks).replace('\\"""', '"""')
        runnable_indices.append(index)
        runnable_codes.append(full_code)

    if not runnable_codes:
        return rewards

    results = exec_test_batched(
        server, runnable_codes,
        timeout=timeout,
        timeout_on_client=timeout_on_client,
    )
    # Put each result back at the position of the completion it belongs to.
    for index, result in zip(runnable_indices, results):
        rewards[index] = result
    return rewards

    
    

import re
from typing import List



def reward_based_on_jax_usage(
    prompts: List[str],
    completions: List[List[dict]],
    **kwargs,
) -> List[float]:
    """Score each completion by the JAX usage found in its Python code blocks.

    Args:
        prompts: Prompts the completions were generated for (not used here).
        completions: One entry per generated completion. Both formats a GRPO
            trainer can pass are accepted: a plain string (standard dataset
            format) or a list of message dicts whose first element holds the
            text under ``"content"`` (conversational format).
        **kwargs: Extra keyword arguments supplied by the trainer (for example
            additional dataset columns); ignored.

    Returns:
        One score per completion, in input order, as returned by
        ``utils.count_jax_usage``. An empty ``completions`` list yields ``[]``.
    """
    scores = []
    for completion in completions:
        if isinstance(completion, str):
            content = completion
        elif completion and isinstance(completion[0], dict):
            content = completion[0].get("content", "")
        else:
            content = ""

        # `count_jax_usage` receives the result of `find_code_blocks` as-is
        # (presumably a list of code-block strings - TODO confirm).
        code_blocks = utils.find_code_blocks(content, tag="python")
        scores.append(utils.count_jax_usage(code_blocks))
    return scores


In [None]:
from unsloth import is_bfloat16_supported  # used below for the precision flags
from trl import GRPOConfig, GRPOTrainer

# NOTE(review): `model` and `format_reward_func` are not defined in any cell
# visible in this notebook; they must exist in the kernel before this cell runs.

# Configure GRPO training parameters.
# This configuration sets up the training hyperparameters, optimization settings, and inference acceleration via vLLM.
training_args = GRPOConfig(
    use_vllm = True,                     # Enable vLLM to accelerate inference during training.
    learning_rate = 5e-6,                # Set the learning rate for the optimizer.
    adam_beta1 = 0.9,                    # First beta parameter for the AdamW optimizer.
    adam_beta2 = 0.99,                   # Second beta parameter for the AdamW optimizer.
    weight_decay = 0.1,                  # Weight decay to regularize the model and prevent overfitting.
    warmup_ratio = 0.1,                  # Fraction of steps used for learning rate warmup.
    lr_scheduler_type = "cosine",        # Use cosine annealing for the learning rate scheduler.
    optim = "adamw_8bit",                # Use 8-bit AdamW optimizer for memory efficiency.
    logging_steps = 1,                   # Log training information every step.
    bf16 = is_bfloat16_supported(),      # Use bfloat16 precision if supported by the GPU.
    fp16 = not is_bfloat16_supported(),  # Otherwise, fall back to fp16 precision.
    per_device_train_batch_size = 1,     # Batch size per device during training.
    gradient_accumulation_steps = 1,     # Accumulate gradients over this many steps (increase for smoother training if needed).
    num_generations = 8,                 # Number of generations per prompt (reduce if memory issues occur).
    max_prompt_length = 256,             # Maximum prompt length. NOTE(review): the formatted prompts are long; confirm 256 is enough.
    max_completion_length = 4096,        # Maximum length for the generated completion.
    num_train_epochs = 1,                # Has no effect while max_steps is positive (max_steps takes precedence).
    max_steps = 250,                     # Maximum number of training steps.
    save_steps = 250,                    # Save the model checkpoint every specified number of steps.
    max_grad_norm = 0.1,                 # Maximum gradient norm for gradient clipping.
    report_to = "wandb",                 # Report training metrics to Weights & Biases.
    output_dir = "/root/workspace/Align-CodeGemma/src/outputs",  # Directory to save the training outputs and checkpoints.
)

# Instantiate the GRPO trainer with the model, tokenizer, reward functions, and datasets.
trainer = GRPOTrainer(
    model = model,                       # The language model to be trained.
    processing_class = tokenizer,        # The tokenizer used to preprocess the data.
    reward_funcs = [
        run_tests_and_reward,            # Reward obtained by executing the generated code.
        format_reward_func,              # Format reward (defined outside the visible cells).
        reward_based_on_jax_usage,       # Reward based on the JAX usage in the generated code.
    ],
    args = training_args,                # GRPO training configuration.
    train_dataset = train_dataset,       # Prompts used for training.
    eval_dataset = test_dataset,         # Prompts used for evaluation.
)

# Begin training using the GRPO algorithm.
trainer.train()

# Save the LoRA-adapted model for later use.
model.save_lora("grpo_saved_lora")

In [None]:
from vllm import SamplingParams

# Render a single user turn into the model's chat prompt (as text, not token
# ids), ending with the generation prompt.
text = tokenizer.apply_chat_template(
    [{"role": "user", "content": "How many r's are in strawberry?"}],
    tokenize=False,
    add_generation_prompt=True,
)

# Sampling settings used for this generation.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=1024)

# Baseline run: no LoRA adapter is attached to the request.
generations = model.fast_generate(
    [text],
    sampling_params=sampling_params,
    lora_request=None,
)
output = generations[0].outputs[0].text

# Show the generated text.
print(output)

In [None]:
from vllm import SamplingParams

# NOTE(review): `SYSTEM_PROMPT` is not defined in any cell visible in this
# notebook; it must exist in the kernel before this cell runs.
# Render a system + user conversation into the model's chat prompt (as text,
# not token ids), ending with the generation prompt.
text = tokenizer.apply_chat_template(
    [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": "How many r's are in strawberry?"},
    ],
    tokenize=False,
    add_generation_prompt=True,
)

# Sampling settings used for this generation.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=1024)

# Same question as the baseline run, this time with the saved LoRA adapter.
adapter_request = model.load_lora("grpo_saved_lora")
generations = model.fast_generate(
    text,
    sampling_params=sampling_params,
    lora_request=adapter_request,
)
output = generations[0].outputs[0].text

# Show the generated text.
print(output)