In [1]:
# Shell escape: print the NVIDIA driver/GPU status for the machine running this kernel.
! nvidia-smi

Fri Mar 14 15:39:41 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.08             Driver Version: 535.161.08   CUDA Version: 12.4     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:4E:00.0 Off |                    0 |
| N/A   35C    P0              92W / 400W |      0MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
import re
import torch
from datasets import load_dataset,Dataset
from transformers import AutoTokenizer,AutoModelForCausalLM

from trl import GRPOTrainer,GRPOConfig


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# System message placed first in every training prompt (see get_logic_questions).
# It introduces the assistant in Chinese and English and asks for replies laid out as
# <think>...</think> followed by <answer>...</answer>, which the reward functions check.
system_prompt = """你的名字是Brench-AI，是由Brench创造出的深度推理AI助手,专注于各种推理问题的解答和分析，拥有强大的推理能力和分析能力以及反思能力，可以帮助用户解决各种推理性问题。
Your name is Brench-AI, a deep reasoning AI assistant created by Brench, focusing on the answer and analysis of various reasoning problems. You focus on the solution and analysis of various reasoning problems. At the same time, you have strong reasoning, analytical and reflective abilities, which can help users solve various reasoning problems.
Please respond reasoning question in the following format:
<think>
...
</think>
<answer>
...
</answer>
"""

# Template for a chain-of-thought reply with {think_content} and {answer_content}
# placeholders (fill with str.format).
# NOTE(review): not referenced by any other cell visible in this notebook — confirm it is still needed.
cot_format = """
<think>
{think_content}
</think>
<answer>
{answer_content}
</answer>
"""

# Dataset Preparation

## Answer Extraction

In [4]:
def extract_tag_answer(content: str) -> str:
    """Pull the text of the last ``<answer>`` section out of a model reply.

    Everything after the final ``<answer>`` tag is kept (the whole string when the
    tag is absent), then cut at the first ``</answer>`` that follows, and finally
    stripped of surrounding whitespace.
    """
    # rpartition yields the full string as the tail when the opening tag is missing.
    _, _, after_open = content.rpartition("<answer>")
    # partition yields the full string as the head when the closing tag is missing.
    inner, _, _ = after_open.partition("</answer>")
    return inner.strip()


In [5]:
def extract_data_answer(content:str) -> str:
    r"""Return the last ``\boxed...{...}`` span found in ``content``.

    The scan starts at the final occurrence of ``\boxed`` and walks forward,
    tracking brace depth; the span ends at the first ``}`` that brings the depth
    back to zero. Returns None when ``\boxed`` does not occur or when no such
    closing brace is found.
    """
    start = content.rfind(r"\boxed")
    if start < 0:
        return None

    depth = 0
    for pos in range(start, len(content)):
        ch = content[pos]
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                # Matching close brace reached: slice includes it.
                return content[start:pos + 1]

    # Ran off the end without the braces balancing.
    return None

def remove_boxed(s:str) -> str:
    r"""Return the text inside a ``\boxed{...}`` wrapper.

    Args:
        s: A string expected to start with ``\boxed{`` and end with ``}``.

    Returns:
        The characters between the opening ``\boxed{`` and the final ``}``,
        or None when ``s`` is not a string of that shape (including the empty
        string and non-string input).
    """
    left = r"\boxed{"
    # Validate explicitly rather than with ``assert``: assertions are removed when
    # Python runs with -O, which would let malformed input fall through to the slice.
    if not isinstance(s, str):
        return None
    if not (s.startswith(left) and s.endswith('}')):
        return None
    return s[len(left):-1]

def extract_boxed_answer(content:str) -> str:
    """Extract the content of the last boxed expression in ``content``.

    Chains ``extract_data_answer`` (locate the boxed span) and ``remove_boxed``
    (strip the wrapper). Returns None if either step yields None.
    """
    boxed_span = extract_data_answer(content)
    # remove_boxed already reports failure as None, so its result is passed through.
    return None if boxed_span is None else remove_boxed(boxed_span)


In [6]:

def get_logic_questions(
    data_files: str = "./data/R1-Zero-GRPO-750/data/train-00000-of-00001.parquet",
) -> Dataset:
    """Load the parquet training split and add chat prompts and reference answers.

    Args:
        data_files: Parquet file passed to ``load_dataset``. Defaults to the path
            this notebook previously hard-coded, so existing calls are unchanged.

    Returns:
        The 'train' split with two columns added to every row:
        ``prompt`` — a two-message chat (system prompt, then the row's ``problem``
        as the user turn), and ``answer`` — the result of ``extract_boxed_answer``
        on the row's ``solution`` (None when nothing can be extracted).
    """
    def to_training_example(row):
        # Returned keys are merged into the row by Dataset.map.
        return {
            'prompt': [
                {'role': 'system', 'content': system_prompt},
                {'role': 'user', 'content': row['problem']}
            ],
            'answer': extract_boxed_answer(row['solution'])
        }

    data = load_dataset("parquet", data_files=data_files)['train']
    return data.map(to_training_example)

# Load Dataset and Mapping

In [7]:
# Build the training dataset: loads the parquet file and adds the 'prompt' and 'answer' columns.
dataset = get_logic_questions()

Map: 100%|██████████| 750/750 [00:00<00:00, 7853.82 examples/s]


In [8]:
# Display the dataset summary (feature names and row count).
dataset

Dataset({
    features: ['problem', 'level', 'solution', 'type', 'prompt', 'answer'],
    num_rows: 750
})

In [9]:
# Spot-check one mapped row to confirm the added 'prompt' and 'answer' fields look right.
dataset[99]

{'problem': 'If the system of equations  \\begin{align*}\n3x+y&=a,\\\\\n2x+5y&=2a,\n\\end{align*} has a solution $(x,y)$ when $x=2$, compute $a$.',
 'level': 'Level 3',
 'solution': 'Substituting in $x=2$, we obtain the equations\n\n\\begin{align*}\ny+6&=a,\\\\\n5y+4&=2a.\n\\end{align*}\n\nMultiplying the first equation by $5$ and subtracting it from the second equation, we find\n\n$$-26=-3a\\Rightarrow a=\\boxed{\\frac{26}{3}}.$$',
 'type': 'Algebra',
 'prompt': [{'content': '你的名字是Brench-AI，是由Brench创造出的深度推理AI助手,专注于各种推理问题的解答和分析，拥有强大的推理能力和分析能力以及反思能力，可以帮助用户解决各种推理性问题。\nYour name is Brench-AI, a deep reasoning AI assistant created by Brench, focusing on the answer and analysis of various reasoning problems. You focus on the solution and analysis of various reasoning problems. At the same time, you have strong reasoning, analytical and reflective abilities, which can help users solve various reasoning problems.\nPlease respond reasoning question in the following format:\n<think>\n...\n</thi

# Reward Function Definition

In [11]:
def correctness_reward_func(prompts,completions,answer,**kwargs) -> list[float]:
    """Reward completions whose extracted answer equals the reference answer.

    Args:
        prompts: Batch of chat prompts; only the last message of the first prompt
            is read, for logging.
        completions: Batch of completions; each is indexed as
            ``completion[0]['content']`` to get the reply text.
        answer: Reference answers, paired with the completions by position.
        **kwargs: Extra arguments supplied by the caller; ignored.

    Returns:
        One score per completion: 2.0 when the text returned by
        ``extract_tag_answer`` is exactly equal (``==``) to the reference answer,
        otherwise 0.0.
    """
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_tag_answer(r) for r in responses]

    # Log the first sample of the batch so training progress can be eyeballed.
    query = prompts[0][-1]['content']
    print('-'*20, f"Question:\n{query}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")

    return [2.0 if res == ans else 0.0 for res, ans in zip(extracted_responses, answer)]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that follow the exact newline-delimited tag layout.

    The reply must be, from its first character: ``<think>``, newline, reasoning,
    newline, ``</think>``, newline, ``<answer>``, newline, answer, newline,
    ``</answer>``, newline.

    Args:
        completions: Batch of completions; each is indexed as
            ``completion[0]["content"]`` to get the reply text.
        **kwargs: Extra arguments supplied by the caller; ignored.

    Returns:
        One score per completion: 0.5 on a match, otherwise 0.0.
    """
    pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets ``.`` span newlines; without it any multi-line reasoning or
    # answer body could never match and this reward would always be 0.0.
    matches = [re.match(pattern, r, re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that loosely follow the think/answer tag layout.

    The reply must begin with ``<think>...</think>``, followed by optional
    whitespace and ``<answer>...</answer>``; anything may come after. Newlines
    around the tags are not required.

    Args:
        completions: Batch of completions; each is indexed as
            ``completion[0]["content"]`` to get the reply text.
        **kwargs: Extra arguments supplied by the caller; ignored.

    Returns:
        One score per completion: 0.5 on a match, otherwise 0.0.
    """
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets ``.`` span newlines; without it a reply with line breaks
    # inside the tags could never match. re.match still anchors at the start.
    matches = [re.match(pattern, r, re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_cot(content) -> float:
    """Score how closely ``content`` follows the think/answer tag layout.

    Each of the four markers (``<think>\\n``, ``\\n</think>\\n``, ``\\n<answer>\\n``,
    ``\\n</answer>``) earns 0.125 when it occurs exactly once. The two answer
    markers also subtract 0.001 per character of text found after the last
    closing answer tag (the second check allows one free trailing character).

    NOTE(review): when the closing-tag separator is absent, ``split`` /
    ``rpartition`` hand back the whole string, so the first penalty then scales
    with the full length of ``content`` — confirm this is intended.
    """
    tag_bonus = 0.125
    per_char_penalty = 0.001
    score = 0.0

    # Opening and closing think markers: plain bonus, no penalty.
    for marker in ("<think>\n", "\n</think>\n"):
        if content.count(marker) == 1:
            score += tag_bonus

    if content.count("\n<answer>\n") == 1:
        trailing = content.rpartition("\n</answer>\n")[2]
        score = score + tag_bonus - len(trailing) * per_char_penalty

    if content.count("\n</answer>") == 1:
        trailing = content.rpartition("\n</answer>")[2]
        score = score + tag_bonus - (len(trailing) - 1) * per_char_penalty

    return score

def cotcount_reward_func(completions, **kwargs) -> list[float]:
    """Apply ``count_cot`` to the reply text of every completion in the batch.

    Each completion is indexed as ``completion[0]["content"]``; extra keyword
    arguments are ignored. Returns one score per completion, in order.
    """
    reply_texts = map(lambda completion: completion[0]["content"], completions)
    # Materialise all texts before scoring, matching the two-pass original.
    return list(map(count_cot, list(reply_texts)))

In [23]:
# Local directory holding the base checkpoint that will be fine-tuned.
model_path = "./models/Qwen2.5-1.5B-Instruct"

# Output directory (also used by trainer.save_model later) and the run label given to GRPOConfig.
output_dir = "./outputs/Qwen2.5-1.5B-R1-GRPO-DEMO"
run_name = "Qwen2.5-1.5B-R1-GRPO-DEMO-TEST"

# GRPO training configuration.
# NOTE(review): no random seed is set anywhere in the visible cells — runs are
# presumably not reproducible as-is; confirm whether a seed should be passed here.
training_args = GRPOConfig(
    output_dir=output_dir,
    run_name = run_name,
    learning_rate=5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,  # matches the torch.bfloat16 dtype the model is loaded with below
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    num_generations=4,
    max_prompt_length=1024,
    max_completion_length=8192,
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    log_on_each_node=False,
    use_vllm=False,
    # NOTE(review): use_vllm is False, so the two vllm_* settings below presumably
    # have no effect — confirm against the installed TRL version.
    vllm_gpu_memory_utilization=.3,
    vllm_device="cuda:0",
    report_to="wandb",
    
)

# Load the base model in bfloat16 without a device map, then move it to the CUDA device.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map=None
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_path)
# Use the end-of-sequence token as the padding token.
# NOTE(review): this overrides any pad token the checkpoint's tokenizer already
# defines — confirm the override is intended.
tokenizer.pad_token = tokenizer.eos_token

In [27]:
# NOTE(review): this import sits mid-notebook; the other imports live in the first code cell.
import wandb
# Start a Weights & Biases run; training_args sets report_to="wandb".
wandb.init(project="Qwen2-R1-ZERO-GRPO-TEST") 
# Assemble the GRPO trainer with the model, tokenizer, dataset and the four
# reward functions defined earlier in this notebook.
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        cotcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        correctness_reward_func],
    args=training_args,
    train_dataset=dataset,
    # Disabled: no peft_config is defined in the visible cells.
    #peft_config=peft_config
)
# Run training with the configuration above.
trainer.train()

# Write the trained model to output_dir.
trainer.save_model(output_dir)



Error in callback <bound method _WandbInit._resume_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7f81ee6cb520>> (for pre_run_cell), with arguments args (<ExecutionInfo object at 7f83537dde70, raw_cell="import wandb
wandb.init(project="Qwen2-R1-ZERO-GRP.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://icube%2Bicube/mnt/bn/brench-volume-lq1/graduation_design/GRPO-R1-Training-RL/grpo_test.ipynb#X22sdnNjb2RlLXJlbW90ZQ%3D%3D>,),kwargs {}:


TypeError: _WandbInit._resume_backend() takes 1 positional argument but 2 were given

0,1
train/completion_length,▁
train/epoch,▁
train/global_step,▁
train/grad_norm,▁
train/kl,▁
train/learning_rate,▁
train/loss,▁
train/reward,▁
train/reward_std,▁
train/rewards/correctness_reward_func,▁

0,1
train/completion_length,424.0
train/epoch,0.00133
train/global_step,1.0
train/grad_norm,0.0
train/kl,0.0
train/learning_rate,0.0
train/loss,0.0
train/reward,0.0
train/reward_std,0.0
train/rewards/correctness_reward_func,0.0


wandb: ⭐️ View project at https://ml.bytedance.net/experiment/tracking/detail?Id=project_20250314_3653a3fe
wandb: 🚀 View run at https://ml.bytedance.net/experiment/tracking/detail?Id=project_20250314_3653a3fe&selectedTrial=run_20250314_2db14e9c


Detected kernel version 5.4.143, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


-------------------- Question:
Find the greatest integer value of $b$ for which the expression $\frac{9x^3+4x^2+11x+7}{x^2+bx+8}$ has a domain of all real numbers. 
Answer:
5 
Response:
To have a domain of all real numbers, the denominator must not equal zero for any real number $x$. Therefore, we need to find the values of $b$ such that the quadratic equation $x^2 + bx + 8 = 0$ has no real solutions. 

This is achieved when the discriminant $\Delta$ of the quadratic equation $x^2 + bx + 8 = 0$ is less than zero. The discriminant $\Delta$ is given by $b^2 - 4ac$, where $a = 1$, $b = b$, and $c = 8$. So, 

$$
\Delta = b^2 - 4 \cdot 1 \cdot 8 = b^2 - 32.
$$

For the quadratic equation to have no real solutions, we need $\Delta < 0$, which gives us:

$$
b^2 - 32 < 0 \Rightarrow b^2 < 32.
$$

Taking the square root of both sides, we get $|b| < \sqrt{32}$. This means $-\sqrt{32} < b < \sqrt{32}$. Since $b$ is an integer, we take the integer values within this range. Because $\sqrt{32} \appr

Step,Training Loss
1,0.0
2,0.0
3,0.0
4,0.0
5,0.0
6,0.0
7,0.0
8,0.0
9,0.0
10,0.0


-------------------- Question:
For which positive integer values of $k$ does $kx^2+20x+k=0$ have rational solutions? Express your answers separated by commas and in increasing order. 
Answer:
6, 8\text{, and }10 
Response:
To determine for which positive integer values of \( k \) the quadratic equation \( kx^2 + 20x + k = 0 \) has rational solutions, we start by using the discriminant of the quadratic equation. The discriminant \(\Delta\) for a quadratic equation \( ax^2 + bx + c = 0 \) is given by \(\Delta = b^2 - 4ac\). For the quadratic equation \( kx^2 + 20x + k = 0 \), the coefficients are \( a = k \), \( b = 20 \), and \( c = k \). The discriminant is:

\[
\Delta = 20^2 - 4 \cdot k \cdot k = 400 - 4k^2
\]

For the quadratic equation to have rational solutions, the discriminant must be a perfect square. Therefore, we need:

\[
400 - 4k^2 = m^2
\]

for some integer \( m \). Rearranging this equation gives:

\[
400 - m^2 = 4k^2 \implies (20 - m)(20 + m) = 4k^2
\]

Since \( k^2 \) is