# 1. Формулировка задачи

| Характеристика | Описание |
| --- | --- |
| **Суть задачи** | Агент получает матрицу смежности взвешенного графа, а также начальную и конечную вершины. Требуется найти кратчайший путь между ними. |
| **Регулировка сложности** | Параметр `difficulty` (от 1 до 10) определяет гиперпараметры графа: количество вершин (например, `N = difficulty * 2 + 3`) и плотность рёбер. |
| **Верификация (Verifier)** | Классический алгоритм Дейкстры вычисляет эталонный кратчайший путь и его стоимость. Ответ агента разбирается (парсится) как маршрут; затем проверяется, что такой маршрут в принципе существует в графе, и сумма весов его рёбер сравнивается с минимальной стоимостью. |

# 2. Импорты

In [38]:
# Start from /content and delete any previous checkout so the repository is
# always cloned fresh.
%cd /content
!rm -rf HW-2_env
!git clone https://github.com/TebelevGt/HW-2_env.git
# Work from the agent sub-project: relative paths used later (e.g.
# "data/train_v1.pkl") and, presumably, the `envs` package import resolve
# from this directory.
%cd HW-2_env/rl-shortest-path-agent

/content
Cloning into 'HW-2_env'...
remote: Enumerating objects: 123, done.[K
remote: Counting objects: 100% (123/123), done.[K
remote: Compressing objects: 100% (62/62), done.[K
remote: Total 123 (delta 37), reused 121 (delta 35), pack-reused 0 (from 0)[K
Receiving objects: 100% (123/123), 1.74 MiB | 10.64 MiB/s, done.
Resolving deltas: 100% (37/37), done.
/content/HW-2_env/rl-shortest-path-agent


In [39]:
from envs import PathEnv, PathVerifier
import re

from trl import GRPOConfig, GRPOTrainer
from unsloth import FastLanguageModel, is_bfloat16_supported
import torch

In [None]:
# 1. Initialise the shortest-path environment.
env = PathEnv()

# 2. Generate two sample tasks. The difficulty lives in one variable so the
# comment cannot drift from the actual call. The task table gives, as an
# example, N = difficulty * 2 + 3 vertices (23 for difficulty 10).
# NOTE(review): confirm the vertex formula against PathEnv.generate.
demo_difficulty = 10
print("=== ГЕНЕРАЦИЯ ДАННЫХ ===")
data_samples = env.generate(num_of_questions=2, difficulty=demo_difficulty)

for i, data in enumerate(data_samples):
    print(f"\n--- Задача {i + 1} ---")
    print("Вопрос (Промпт):\n", data.question)
    print("\nЭталонный путь (ответ Дейкстры):", data.answer)
    print("Оптимальная стоимость:", data.metadata['optimal_cost'])

# 3. Sanity-check the verifier on the first task.
print("\n=== ТЕСТИРОВАНИЕ VERIFIER ===")
sample = data_samples[0]
correct_path = sample.answer

# Simulate an ideal agent reply built from the reference answer.
# NOTE(review): this reply uses <think> tags, while SYSTEM_PROMPT asks for
# <reasoning>; presumably env.verify only reads the <answer> tag. Confirm.
good_llm_reply = f"<think>\nEasy path!\n</think>\n<answer>{correct_path}</answer>"
is_correct = env.verify(sample, good_llm_reply)
print(f"Проверка верного ответа: {'✅ ПРИНЯТО' if is_correct else '❌ ПРОВАЛЕНО'}")

# Simulate a hallucinated path (vertex ids presumably out of range for this
# graph). The expected outcome is rejection, so only that case is annotated
# as correct; an acceptance is flagged as a verifier error.
bad_llm_reply = "<think>\nI am confused\n</think>\n<answer>0, 99, 999</answer>"
is_wrong = env.verify(sample, bad_llm_reply)
if is_wrong:
    bad_verdict = "✅ ПРИНЯТО (ошибка верификатора!)"
else:
    bad_verdict = "❌ ПРОВАЛЕНО (так и должно быть)"
print(f"Проверка неверного ответа: {bad_verdict}")

In [None]:
# Visualise the second generated task. This is kept as the cell's last
# expression, so the notebook also displays whatever env.visualize returns.
env.visualize(data_samples[1])

In [None]:

# Token budget per sequence; raise it to allow longer reasoning traces.
max_seq_length = 1024
# LoRA rank: higher ranks learn more but train slower.
lora_rank = 64

# NOTE(review): Unsloth asks to be imported before trl/transformers so its
# patches take effect; the import cell above imports trl first. Confirm.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Qwen2.5-1.5B-Instruct",
    max_seq_length=max_seq_length,
    load_in_4bit=True,           # set False for 16-bit LoRA
    fast_inference=True,         # generate with vLLM
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.9,  # lower this on out-of-memory errors
)

# Wrap the base model with LoRA adapters on the attention (q/k/v/o) and
# MLP (gate/up/down) projections; lora_alpha == r gives a scaling of 1.
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,  # any positive int; 8, 16, 32, 64 or 128 are typical
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention: drop first on OOM
        "gate_proj", "up_proj", "down_proj",     # MLP
    ],
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",  # Unsloth checkpointing for long contexts
    random_state=3407,
)

In [47]:
# Prompt constants and reward helpers.
# The GSM8K loader in this cell is not used by the training cell, which loads
# ShortestPathDataset from a pickle instead.

# System prompt asking the model to put its chain of thought inside
# <reasoning> tags and the final answer inside <answer> tags; the format
# reward functions below score exactly this layout. It is used by
# get_gsm8k_questions. NOTE(review): confirm that ShortestPathDataset prompts
# request the same tag layout.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# str.format template for a completion in the expected layout, with
# {reasoning} and {answer} placeholders.
# NOTE(review): not referenced anywhere in the visible notebook.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    """Return the stripped text between the last ``<answer>`` and the next ``</answer>``.

    If no opening tag is present, scanning starts at the beginning of *text*;
    if no closing tag follows, everything after the opening tag is kept.
    """
    _, _, after_open = text.rpartition("<answer>")
    inner, _, _ = after_open.partition("</answer>")
    return inner.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

def get_gsm8k_questions(split="train") -> "Dataset":
    """Load a GSM8K split and convert each row into a chat prompt for GRPO.

    Every row gains a ``prompt`` (system message with SYSTEM_PROMPT plus the
    user question) and an ``answer`` taken from the text after ``####`` in
    the reference solution (``None`` if the marker is missing).

    Args:
        split: name of the dataset split to load, e.g. ``"train"``.

    Returns:
        The mapped dataset.

    Note:
        The return annotation is a string because ``Dataset`` is never
        imported in this notebook. An unquoted name is evaluated when the
        ``def`` statement runs and raises ``NameError``.
        NOTE(review): ``load_dataset`` is not imported here either. Import it
        (presumably ``from datasets import load_dataset``) before calling.
    """
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore


def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Score 2.0 per completion whose extracted ``<answer>`` equals its reference, else 0.0.

    Comparison is exact string equality with the matching entry of *answer*.
    As a side effect, it prints the first prompt's last message, the first
    reference, the first raw completion and its extracted answer, to monitor
    training.
    """
    texts = [item[0]['content'] for item in completions]
    question = prompts[0][-1]['content']
    extracted = list(map(extract_xml_answer, texts))
    print('-'*20, f"Question:\n{question}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{texts[0]}", f"\nExtracted:\n{extracted[0]}")
    return [2.0 if got == expected else 0.0 for got, expected in zip(extracted, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 when the ``<answer>`` payload has the shape of a vertex path.

    Accepted payloads are a single non-negative integer (``"7"``) or a
    comma-separated list of them (``"0, 3, 5"``), the same shape as the
    answers passed to ``env.verify`` in the demo cell. A plain
    ``str.isdigit`` check would reject every multi-vertex path, so this shape
    reward would never fire for the shortest-path task. Only ASCII digits
    count, so every accepted token is parseable by ``int``.
    """
    path_like = re.compile(r"\d+(?:\s*,\s*\d+)*", re.ASCII)
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if path_like.fullmatch(r) else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 for completions that follow the exact tag layout of SYSTEM_PROMPT.

    The completion must start with the ``<reasoning>`` line and put each tag
    on its own line, with the reasoning and answer bodies between them and a
    newline after ``</answer>``. ``re.DOTALL`` lets ``.`` cross newlines, so
    multi-line reasoning is accepted; without it only single-line reasoning
    and answers could ever match, which a real chain of thought almost never
    is.
    """
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 for completions that start with a reasoning block followed by an answer block.

    Unlike the strict variant, any whitespace may separate the blocks and no
    per-tag newlines are required. The completion must still begin with
    ``<reasoning>``, because ``re.match`` anchors at the start. ``re.DOTALL``
    lets the block bodies span multiple lines.
    """
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    """Return partial format credit for *text*: 0.125 per expected tag line that appears exactly once.

    The four checked fragments are the opening reasoning tag line, the
    closing reasoning tag line, the opening answer tag line and the closing
    answer tag. Each of the two answer-tag checks also deducts 0.001 per
    character that follows the closing ``</answer>`` tag; a single trailing
    newline is free. A well-formed completion therefore scores exactly 0.5.
    """
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        # Penalise only genuine trailing text. Without this guard, split()
        # returns the whole completion when the separator is absent (e.g. no
        # newline after </answer>), and the penalty would scale with its
        # full length.
        if "\n</answer>\n" in text:
            count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        # Allow one trailing newline, but never turn "nothing after the tag"
        # into a bonus.
        count -= max(len(text.split("\n</answer>")[-1]) - 1, 0)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Score every completion in the batch with ``count_xml`` applied to its text."""
    texts = [item[0]["content"] for item in completions]
    return list(map(count_xml, texts))

In [None]:
from trl import GRPOConfig, GRPOTrainer

# GRPO hyper-parameters; completions are sampled through vLLM.
training_args = GRPOConfig(
    use_vllm=True,
    # Optimiser: 8-bit AdamW, cosine schedule, 10% warm-up.
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="adamw_8bit",
    logging_steps=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,  # 4 gives smoother updates
    num_generations=8,              # completions per prompt; lower on OOM
    # NOTE(review): 256 prompt tokens may be too few for an adjacency matrix
    # of a large graph (long prompts get truncated); confirm against the
    # dataset. Prompt + completion must also fit into max_seq_length.
    max_prompt_length=256,
    max_completion_length=200,
    # Train for a fixed number of steps (instead of num_train_epochs).
    max_steps=250,
    save_steps=250,
    max_grad_norm=0.1,
    report_to="none",  # e.g. "wandb" for Weights & Biases
    output_dir="outputs",
)

In [44]:
from envs.dataset import ShortestPathDataset
import os

# Pickled training split written by the create_benchmark_datasets script,
# which saves to "data/train_v1.pkl". Keep this path in sync with it; it is
# relative to the current working directory.
pickle_path = "data/train_v1.pkl"

# Load the saved split; it is passed to GRPOTrainer as train_dataset.
# NOTE(review): if .load unpickles the file, only load files you generated
# yourself, because unpickling can execute arbitrary code.
train_dataset = ShortestPathDataset.load(pickle_path)

In [None]:
# The trainer sums the reward functions for every sampled completion:
# three format-shaping terms, one answer-shape term and the correctness term.
# NOTE(review): correctness_reward_func uses exact string matching against
# the dataset's "answer" column. An alternative path with the same optimal
# cost would then score 0, so a cost-based check (as in env.verify) may fit
# this task better. Confirm the dataset columns first.
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()