# Train a model to schedule events using GRPO

➡️ For a complete walk-through, read [this blog post](PUTLKINK).

Once trained, the model should be able to solve problems like the following.

**Example input**

Task: create an optimized schedule based on the given events. Maximize the total weighted duration of the events.
*(For the detailed prompt, see below).*

Events:
- Event A (01:27 - 01:42)
- Event B (01:15 - 02:30)
- Event C (15:43 - 17:43)

Priorities:
- Event B

**Example output**

```xml
<think>A detailed reasoning</think>
<schedule>
<event>
<name>Event B</name>
<start>01:15</start>
<end>02:30</end>
</event>
<event>
<name>Event C</name>
<start>15:43</start>
<end>17:43</end>
</event>
</schedule>
```
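
To see why this schedule is preferred, here is a minimal sketch that scores the two non-overlapping schedules that include Event C. It assumes the weighting stated in the detailed prompt (priority events count double); the helper names `to_minutes` and `weighted_duration` are illustrative and not part of the notebook.

```python
def to_minutes(hhmm):
    """Convert an "HH:MM" string to minutes since midnight."""
    hours, minutes = map(int, hhmm.split(":"))
    return hours * 60 + minutes


def weighted_duration(schedule, priorities):
    """Sum event durations in minutes, counting priority events twice."""
    total = 0
    for name, start, end in schedule:
        duration = to_minutes(end) - to_minutes(start)
        total += 2 * duration if name in priorities else duration
    return total


events = {
    "Event A": ("01:27", "01:42"),
    "Event B": ("01:15", "02:30"),
    "Event C": ("15:43", "17:43"),
}
priorities = {"Event B"}

# Events A and B overlap, so a valid schedule keeps only one of them.
keep_b = [(name, *events[name]) for name in ("Event B", "Event C")]
keep_a = [(name, *events[name]) for name in ("Event A", "Event C")]

score_b = weighted_duration(keep_b, priorities)  # 2 * 75 + 120 = 270
score_a = weighted_duration(keep_a, priorities)  # 15 + 120 = 135
print(score_b, score_a)
```

Keeping Event B yields the higher weighted total, which is why Event A is dropped in the example output.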

In [None]:
! pip install --upgrade pip
! pip install "unsloth==2025.3.19" vllm wandb
! pip uninstall -y typing_extensions &&  pip install typing_extensions==4.11.0

In [1]:
! rm -rf outputs completion_samples

## Load the original model

In [2]:
from unsloth import FastLanguageModel

max_seq_length = 2048  # Increase for longer reasoning traces
lora_rank = 32  # Higher rank = more trainable parameters, but slower training

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="Qwen/Qwen2.5-Coder-7B-Instruct",
    max_seq_length=max_seq_length,
    load_in_4bit=True,  # Set to False for 16-bit LoRA
    fast_inference=True,  # Enable vLLM fast inference
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.85,  # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,  # Any positive integer; suggested values: 8, 16, 32, 64, 128
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],  # Remove QKVO if out of memory
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",  # Enable long context finetuning
    random_state=3407,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 04-04 12:31:12 [__init__.py:239] Automatically detected platform cuda.
==((====))==  Unsloth 2025.3.19: Fast Qwen2 patching. Transformers: 4.50.3. vLLM: 0.8.2.
   \\   /|    NVIDIA RTX A6000. Num GPUs = 1. Max memory: 47.438 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/qwen2.5-coder-7b-instruct-bnb-4bit with actual GPU utilization = 84.46%
Unsloth: Your GPU has CUDA compute capability 8.6 with VRAM = 47.44 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 2048. Num Sequences = 320.
Unsloth: vLLM's KV Cache can use up to 34.09

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]

INFO 04-04 12:31:32 [punica_selector.py:18] Using PunicaWrapperGPU.
INFO 04-04 12:31:32 [model_runner.py:1146] Model loading took 5.3638 GB and 5.440341 seconds
INFO 04-04 12:31:34 [worker.py:267] Memory profiling takes 2.03 seconds
INFO 04-04 12:31:34 [worker.py:267] the current vLLM instance can use total_gpu_memory (47.44GiB) x gpu_memory_utilization (0.84) = 40.07GiB
INFO 04-04 12:31:34 [worker.py:267] model weights take 5.36GiB; non_torch_memory takes 0.06GiB; PyTorch activation peak memory takes 1.76GiB; the rest of the memory reserved for KV Cache is 32.89GiB.
INFO 04-04 12:31:35 [executor_base.py:111] # cuda blocks: 38487, # CPU blocks: 7021
INFO 04-04 12:31:35 [executor_base.py:116] Maximum concurrency for 2048 tokens per request: 300.68x
INFO 04-04 12:31:40 [model_runner.py:1442] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. I

Capturing CUDA graph shapes: 100%|██████████| 43/43 [00:36<00:00,  1.19it/s]

INFO 04-04 12:32:16 [model_runner.py:1570] Graph capturing finished in 36 secs, took 0.85 GiB
INFO 04-04 12:32:16 [llm_engine.py:447] init engine (profile, create kv cache, warmup model) took 44.60 seconds



Unsloth 2025.3.19 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.


## Dataset preparation

In [3]:
import datasets

SYSTEM_PROMPT = """You are a precise event scheduler.
1. First, reason through the problem inside <think> and </think> tags. Here you can create drafts, compare alternatives, and check for mistakes.
2. When confident, output the final schedule inside <schedule> and </schedule> tags. Your schedule must strictly follow the rules provided by the user."""

USER_PROMPT = """Task: create an optimized schedule based on the given events.

Rules:
- The schedule MUST be in strict chronological order. Do NOT place priority events earlier unless their actual start time is earlier.
- Event start and end times are ABSOLUTE. NEVER change, shorten, adjust, or split them.
- Priority events (weight = 2) carry more weight than normal events (weight = 1), but they MUST still respect chronological order.
- Maximize the sum of weighted event durations.
- No overlaps allowed. In conflicts, include the event with the higher weighted time.
- Some events may be excluded if needed to meet these rules.


You must use this format:  

<think>...</think>
<schedule>
<event>
<name>...</name>
<start>...</start>
<end>...</end>
</event>
...
</schedule>

---

"""

ds = datasets.load_dataset("anakin87/events-scheduling", split="train")
ds

ds = ds.map(
    lambda x: {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": USER_PROMPT + x["prompt"]},
        ]
    }
)


ds[0]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

{'events': [['Backstage tour VIP pass', '01:06', '01:21'],
  ['Sum 41 concert', '14:57', '15:27'],
  ['Korean BBQ food trucks', '16:21', '17:51'],
  ['Fireworks', '08:48', '10:48'],
  ['Jam session', '14:31', '15:31']],
 'priority_events': ['Sum 41 concert'],
 'optimal_score': 285,
 'prompt': [{'content': 'You are a highly precise event scheduler. Your goal is to create an optimized schedule following strict constraints.\n1. First, generate a detailed reasoning process inside <think> and </think> tags.  \n2. Then, generate the schedule inside <schedule> and </schedule> tags.',
   'role': 'system'},
  {'content': "Task: create an optimized schedule based on the given events.\n\nRules:\n- The schedule MUST be in chronological order.\n- Event start and end times are ABSOLUTE. Do NOT change, shorten, adjust, or split them.\n- Priority events (weight = 2) are more important than normal events (weight = 1).\n- Maximize the sum of weighted event durations.\n- No two events can overlap. If two

## Reward functions

We use three reward functions, worth up to 100 points in total:
1. Format reward: checks that the output follows the required format (10 points).
2. Sorted events reward: checks that the events are listed in chronological order (20 points).
3. Score reward: the ratio between the schedule's total weighted duration and the optimal score (computed with dynamic programming), scaled to 70 points.
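
The optimal score mentioned in point 3 can be computed as a weighted interval scheduling problem. The sketch below is one possible dynamic-programming formulation, not necessarily the code used to build the dataset; the function names are illustrative, and it assumes that events which merely touch (one ends exactly when another starts) are not allowed together.

```python
from bisect import bisect_left


def to_minutes(hhmm):
    """Convert an "HH:MM" string to minutes since midnight."""
    hours, minutes = map(int, hhmm.split(":"))
    return hours * 60 + minutes


def optimal_score(events, priority_events):
    """Best total weighted duration over all non-overlapping subsets."""
    # Build (start, end, weight) tuples and sort them by end time.
    jobs = sorted(
        (
            (to_minutes(start), to_minutes(end), 2 if name in priority_events else 1)
            for name, start, end in events
        ),
        key=lambda job: job[1],
    )
    ends = [end for _, end, _ in jobs]

    # best[i] = best score achievable using only the first i jobs.
    best = [0] * (len(jobs) + 1)
    for i, (start, end, weight) in enumerate(jobs, start=1):
        value = weight * (end - start)
        # Count the earlier jobs that end strictly before this one starts.
        compatible = bisect_left(ends, start, 0, i - 1)
        best[i] = max(best[i - 1], best[compatible] + value)
    return best[-1]


# First dataset row, as printed in the output above.
events = [
    ["Backstage tour VIP pass", "01:06", "01:21"],
    ["Sum 41 concert", "14:57", "15:27"],
    ["Korean BBQ food trucks", "16:21", "17:51"],
    ["Fireworks", "08:48", "10:48"],
    ["Jam session", "14:31", "15:31"],
]
priority_events = ["Sum 41 concert"]
print(optimal_score(events, priority_events))  # 285
```

For this row the result agrees with the `optimal_score` of 285 stored in the dataset.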

In [4]:
import os
import random
import re
from datetime import datetime


def minutes_to_time(minutes):
    """Convert minutes since midnight to time string.

    Args:
        minutes (int): Minutes since midnight

    Returns:
        str: Time string in "HH:MM" format
    """
    return f"{minutes // 60:02d}:{minutes % 60:02d}"


def time_to_minutes(time_str):
    """Convert time string to minutes since midnight.

    Args:
        time_str (str): Time string in "HH:MM" format

    Returns:
        int: Minutes since midnight
    """
    hours, mins = map(int, time_str.split(":"))
    return hours * 60 + mins


overall_pattern = r"<think>.+</think>.*<schedule>.*(<event>.*<name>.+</name>.*<start>\d{2}:\d{2}</start>.*<end>\d{2}:\d{2}</end>.*</event>)+.*</schedule>"
overall_regex = re.compile(overall_pattern, re.DOTALL)

capture_pattern = r"""
    <event>\s*
        <name>([^<]+)</name>\s*
        <start>(\d{2}:\d{2})</start>\s*
        <end>(\d{2}:\d{2})</end>\s*
    </event>
"""

capture_regex = re.compile(capture_pattern, re.VERBOSE)


def get_events(content):
    """Extract event information from XML-like content.

    Args:
        content (str): XML-like string containing event data

    Returns:
        list: List of tuples (name, start_time, end_time)
    """
    return [
        (match.group(1), match.group(2), match.group(3))
        for match in capture_regex.finditer(content)
    ]


def format_reward(prompts, completions, **kwargs):
    """Reward (0 or 10): 10 if the completion matches the expected overall format."""
    responses = [completion[0]["content"] for completion in completions]

    return [10.0 if overall_regex.match(response) else 0.0 for response in responses]


def score_reward(
    prompts, completions, events, priority_events, optimal_score, **kwargs
):
    """Reward (0 to 70): weighted duration of the schedule relative to the optimal score."""
    scores = []
    responses = [completion[0]["content"] for completion in completions]

    for content, valid_events, priorities, opt_score in zip(
        responses, events, priority_events, optimal_score
    ):
        scheduled_events = get_events(content)

        # Get valid scheduled events
        existing_events = {
            ev for ev in scheduled_events if [ev[0], ev[1], ev[2]] in valid_events
        }

        # Penalize nonexistent events or fewer than 2 events (not a valid schedule)
        if len(existing_events) < len(scheduled_events) or len(existing_events) < 2:
            scores.append(0.0)
            continue

        # Convert to minutes
        existing_events_minutes = [
            (ev[0], time_to_minutes(ev[1]), time_to_minutes(ev[2]))
            for ev in existing_events
        ]

        # Drop both events of every overlapping pair, to penalize overlaps.
        # The comparisons are inclusive, so events that merely touch also count.
        overlapping_events = set()
        for j in range(len(existing_events_minutes)):
            for k in range(j + 1, len(existing_events_minutes)):
                if (
                    existing_events_minutes[j][1] <= existing_events_minutes[k][2]
                    and existing_events_minutes[j][2] >= existing_events_minutes[k][1]
                ):
                    overlapping_events.add(existing_events_minutes[j])
                    overlapping_events.add(existing_events_minutes[k])

        existing_events_minutes = [
            ev for ev in existing_events_minutes if ev not in overlapping_events
        ]

        # Calculate score
        score = sum(
            2 * (ev[2] - ev[1]) if ev[0] in priorities else ev[2] - ev[1]
            for ev in existing_events_minutes
        )

        scores.append((score / opt_score) * 70)

    # Log samples
    if any(score > 0 for score in scores) and random.random() < 0.10:
        os.makedirs("completion_samples", exist_ok=True)
        log_file = os.path.join("completion_samples", "completion_samples.txt")
        with open(log_file, "a") as f:
            f.write("\n\n==============\n")
            f.write(f"\n{datetime.now().time()}\n")
            f.write(f"{prompts[0]}\n")
            f.write(f"{scores}\n")
            f.write(f"{completions}")

    return scores


def sorted_events_reward(completions, **kwargs):
    """Reward (0 or 20): 20 if at least two events are listed with strictly increasing start times."""
    scores = []
    responses = [completion[0]["content"] for completion in completions]

    for response in responses:
        scheduled_events = get_events(response)

        # not a valid schedule: should be discarded
        if len(scheduled_events) < 2:
            scores.append(0.0)
            continue

        scheduled_events_minutes = [
            (ev[0], time_to_minutes(ev[1]), time_to_minutes(ev[2]))
            for ev in scheduled_events
        ]

        if all(
            scheduled_events_minutes[i][1] < scheduled_events_minutes[i + 1][1]
            for i in range(len(scheduled_events_minutes) - 1)
        ):
            scores.append(20.0)
        else:
            scores.append(0.0)

    return scores

## Training configuration

In [None]:
# Measure the longest tokenized prompt, as a reference for max_prompt_length
tokenized_prompts = [
    tokenizer.apply_chat_template(prompt, tokenize=True, add_generation_prompt=True)
    for prompt in ds["prompt"]
]
exact_max_prompt_length = max(len(tokens) for tokens in tokenized_prompts)

In [5]:
max_prompt_length = 448  # Should be >= exact_max_prompt_length to avoid truncating prompts

new_model_id = "anakin87/qwen-scheduler-7b-grpo"


from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
    learning_rate=8e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.01,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    logging_steps=1,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    num_generations=8,  # Decrease if out of memory
    max_prompt_length=max_prompt_length,
    max_completion_length=max_seq_length - max_prompt_length,
    max_grad_norm=0.1,
    report_to="wandb",
    output_dir="outputs",
    overwrite_output_dir=True,
    push_to_hub=True,
    hub_model_id=new_model_id,
    hub_strategy="every_save",
    save_strategy="steps",
    save_steps=50,
    save_total_limit=1,
    num_train_epochs=3,
)
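
As a rough illustration of what `num_generations=8` implies: GRPO samples a group of completions for each prompt and scores every completion relative to the rest of its group. The sketch below shows that normalization in plain Python. It is a simplification, not TRL's actual code; the `eps` value, the use of the sample standard deviation, and the reward values are assumptions made for the example.

```python
from statistics import mean, stdev


def group_advantages(rewards, eps=1e-4):
    """Normalize the rewards of one prompt's completions within their group."""
    mu = mean(rewards)
    sigma = stdev(rewards) if len(rewards) > 1 else 0.0
    return [(reward - mu) / (sigma + eps) for reward in rewards]


# Hypothetical summed rewards (format + sorted + score) for 8 completions.
rewards = [0.0, 10.0, 30.0, 30.0, 65.0, 100.0, 10.0, 0.0]
advantages = group_advantages(rewards)

for reward, advantage in zip(rewards, advantages):
    print(f"reward={reward:6.1f}  advantage={advantage:+.2f}")
```

Completions that beat their group's average get a positive advantage and are reinforced; if all completions in a group receive the same reward, every advantage is zero and that prompt contributes no learning signal.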

In [None]:
import wandb

wandb.init(project="GRPO-reboost")

## Training!

In [None]:
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        format_reward,
        sorted_events_reward,
        score_reward,
    ],
    args=training_args,
    train_dataset=ds,
)
trainer.train()

[Weights & Biases report](https://wandb.ai/stefanofiorucci/GRPO-reboost/reports/Qwen-Scheduler-GRPO--VmlldzoxMjI1MTA4MA?accessToken=p9whiiwc1ourpt1ae5hcs84un0ri117ty84m3c56kogvkm5drp5hnk9tanvlvrsn)