
<h1 align="center">Qwen on GRPO for Graphs</h1>

---

<h1 align="center">Training a Small Graph Extractor with RL on Distilled R1 Data with Reasoning Traces</h1>

This notebook builds on [this Colab notebook](https://colab.research.google.com/drive/1bfhs1FMLW3FGa8ydvkOZyBNxLYOu0Hev?usp=sharing), which in turn extends the [GRPO demo](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb) by [Will Brown](https://x.com/willccbb) that trains Llama-1B on the GSM8K math dataset.

While exploring, we came across [Daniel Han Chen's work at Unsloth](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen2.5_(3B)-GRPO.ipynb#scrollTo=SDKIhhvN6lAF). Unsloth's LoRA-based optimizations are what make training a 3B base model feasible for the GPU-poor.

So far we have only shown our initial work, which:
- distills graph-extraction data with chain-of-thought (CoT) reasoning traces from DeepSeek R1;
- runs RL on the Qwen2.5-3B base model to obtain a model that is cheap to run yet achieves decent one-shot graph extraction by reasoning first and then extracting.



## Setting up the models.

First, install vLLM. Note that you will need to restart the session afterwards.

In [None]:
#%pip install vllm

Then install TRL (pinned to 0.14.0) and PEFT. The order matters: we hit a TRL bug when vLLM was installed afterwards.

In [None]:
#%pip install trl==0.14.0
#%pip install peft

## Using Unsloth to cope with limited GPU resources

# Introduction to Triton

[Triton](https://github.com/openai/triton) is an open-source programming language and compiler designed specifically for deep learning workloads on GPUs. It enables easier development of efficient custom GPU kernels, which are critical for optimizing deep learning operations.

In this notebook, we pin Triton 3.1.0 for compatibility with the Unsloth library (see unslothai/unsloth issue #1604), which we use to fine-tune large language models efficiently. Unsloth's custom GPU kernels are written in Triton and provide the low-level optimizations needed for memory-efficient training with Low-Rank Adaptation (LoRA) and GRPO (Group Relative Policy Optimization) on consumer GPUs.

By leveraging Triton, we can achieve better performance when training on limited GPU resources, making it possible to fine-tune a 3B parameter model like Qwen2.5 even without enterprise-grade hardware.

## Key Benefits for Our Workflow

- **Memory Optimization**: Triton helps reduce GPU memory usage during training, which is crucial when fine-tuning large models on consumer hardware
- **Computational Efficiency**: Enables faster matrix operations needed for transformer model training
- **Compatibility**: Version 3.1.0 specifically addresses compatibility issues with Unsloth's optimized GRPO implementation
- **Parallel Processing**: Improves throughput for batch processing during reinforcement learning

By installing Triton before other packages, we ensure the optimized computation stack is in place before setting up our training pipeline with Unsloth and vLLM.

In [None]:
%pip install triton==3.1.0 # this is needed due to https://github.com/unslothai/unsloth/issues/1604

Collecting triton==3.1.0
  Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.3 kB)
Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (209.5 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 209.5/209.5 MB 11.6 MB/s eta 0:00:00
Installing collected packages: triton
Successfully installed triton-3.1.0


> Next, install Unsloth (with vLLM) and let it patch TRL's GRPO trainer.

In [None]:
%%capture
# Skip restarting message in Colab
import sys; modules = list(sys.modules.keys())
for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None

%pip install unsloth vllm
%pip install --upgrade pillow

In [None]:
from unsloth import FastLanguageModel, PatchFastRL

PatchFastRL("GRPO", FastLanguageModel)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 02-18 14:56:19 __init__.py:190] Automatically detected platform cuda.


# Walkthrough of the Qwen fine-tuning code

The code below shows how to use the Unsloth library to load and fine-tune the Qwen2.5-3B-Instruct model efficiently. Here is a detailed breakdown:

## Imports and configuration


In [None]:
from unsloth import is_bfloat16_supported
import torch
max_seq_length = 12000 # Can be increased to support longer reasoning traces
lora_rank = 64 # Higher rank = more capacity, but slower training



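- `is_bfloat16_supported`: checks whether the hardware supports bfloat16, a 16-bit floating-point format that keeps float32's exponent range and is widely used for efficient training
- `max_seq_length`: set to 12000 so the model can handle long inputs and long reasoning chains
- `lora_rank`: the rank used by Low-Rank Adaptation (LoRA); it trades off the expressiveness of the fine-tuning against training speed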
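To make the bfloat16 bullet concrete, here is a small pure-Python sketch (our own illustration, not Unsloth code). bfloat16 keeps the top 16 bits of a float32 value: the sign, all 8 exponent bits, and 7 mantissa bits. It therefore has float32's range but only about 2 to 3 significant decimal digits, whereas float16 is more precise but overflows above 65504. The sketch converts by truncation for simplicity; hardware conversions usually round to nearest even instead.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round-trip x through bfloat16 by truncating its float32 encoding (illustrative only)."""
    # Reinterpret the float32 encoding of x as a 32-bit unsigned integer.
    (bits,) = struct.unpack(">I", struct.pack(">f", x))
    # bfloat16 = sign bit + 8 exponent bits + 7 mantissa bits = the upper 16 bits.
    (y,) = struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))
    return y


def fits_in_float16(x: float) -> bool:
    """float16 has only 5 exponent bits, so its largest finite value is 65504."""
    try:
        struct.pack(">e", x)  # 'e' is IEEE 754 half precision
        return True
    except OverflowError:
        return False


print(to_bfloat16(3.14159))  # precision is lost in the third decimal place
print(to_bfloat16(1e38))     # still finite: same exponent range as float32
print(fits_in_float16(1e5))  # beyond float16's range
```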

## Loading the model


In [None]:
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "Qwen/Qwen2.5-3B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # Set to False for 16-bit LoRA
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.3, # Reduce if out of memory
)



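- Loads the Qwen model through Unsloth's `FastLanguageModel` rather than the standard Hugging Face interface
- `load_in_4bit=True`: enables 4-bit quantization, which significantly reduces GPU memory usage
- `fast_inference=True`: uses vLLM to speed up inference (generation)
- `gpu_memory_utilization=0.3`: limits the vLLM engine to about 30% of total GPU memory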
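For intuition about what `load_in_4bit=True` buys, the sketch below quantizes one block of weights to 4-bit integers with a per-block absmax scale. This is a simplified symmetric int4 scheme chosen for illustration; bitsandbytes, which Unsloth uses for 4-bit loading here, typically uses the NF4 data type instead, so treat this as a rough model of the idea rather than the library's exact behavior. The memory helper likewise ignores the stored scales and any layers kept in higher precision.

```python
def quantize_block_int4(values: list[float]) -> tuple[list[int], float]:
    """Symmetric absmax quantization of one block to integers in [-7, 7]."""
    absmax = max(abs(v) for v in values) or 1.0  # avoid dividing by zero for an all-zero block
    scale = absmax / 7
    codes = [max(-7, min(7, round(v / scale))) for v in values]
    return codes, scale


def dequantize_block_int4(codes: list[int], scale: float) -> list[float]:
    """Map 4-bit codes back to approximate float values."""
    return [c * scale for c in codes]


def weight_memory_gib(n_params: float, bits_per_param: float) -> float:
    """Approximate weight storage, ignoring scales and unquantized layers."""
    return n_params * bits_per_param / 8 / 1024**3


block = [0.1, -0.5, 0.7, 0.0]
codes, scale = quantize_block_int4(block)
print(codes)
print(dequantize_block_int4(codes, scale))  # each value is off by at most scale / 2
# Going from 16-bit to 4-bit weights cuts weight storage by about 4x.
print(weight_memory_gib(3e9, 16), weight_memory_gib(3e9, 4))
```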

## PEFT configuration


In [None]:
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Suggested values: 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enables long-context fine-tuning
    random_state = 3407,
)



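- Applies parameter-efficient fine-tuning (PEFT) with LoRA
- `target_modules`: the specific layers LoRA is applied to
  - `q_proj`, `k_proj`, `v_proj`, `o_proj`: the attention projection layers
  - `gate_proj`, `up_proj`, `down_proj`: the MLP layers
- `lora_alpha = lora_rank`: LoRA scales its update by `lora_alpha / r`, so this setting keeps the scaling factor at 1
- `use_gradient_checkpointing="unsloth"`: enables Unsloth's gradient checkpointing, which lowers memory requirements
- `random_state`: fixes the random seed for reproducibility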
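As a rough guide to how `r` and `target_modules` translate into trainable parameters: LoRA adds two matrices, A (r × d_in) and B (d_out × r), to each targeted linear layer, i.e. r·(d_in + d_out) parameters, and multiplies the update by `lora_alpha / r`. The layer shapes below are our assumption of the Qwen2.5-3B configuration (hidden size 2048, MLP size 11008, 2 KV heads of dimension 128); the 36 layers match the Unsloth log further down. Check the shapes against the model's `config.json` before relying on the total.

```python
# Assumed Qwen2.5-3B shapes (d_in, d_out) per targeted projection; verify against config.json.
HIDDEN = 2048
INTERMEDIATE = 11008
KV_DIM = 2 * 128  # num_key_value_heads * head_dim
NUM_LAYERS = 36

TARGET_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params_per_layer(rank: int, shapes: dict[str, tuple[int, int]]) -> int:
    """A is (rank x d_in) and B is (d_out x rank), so each module adds rank * (d_in + d_out)."""
    return sum(rank * (d_in + d_out) for d_in, d_out in shapes.values())


def lora_scaling(lora_alpha: float, rank: int) -> float:
    """The LoRA update B @ A is multiplied by lora_alpha / rank."""
    return lora_alpha / rank


rank = 64
total = NUM_LAYERS * lora_params_per_layer(rank, TARGET_SHAPES)
print(f"trainable LoRA parameters at r={rank}: {total:,}")
print("scaling with lora_alpha = lora_rank:", lora_scaling(rank, rank))
```

Because the count is linear in `r`, choosing `r = 16` would give a quarter of this total under the same assumptions.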

This code takes full advantage of Unsloth's optimizations to fine-tune a large model efficiently, even with limited compute.

In [None]:
from unsloth import is_bfloat16_supported
import torch
max_seq_length = 12000 # Can increase for longer reasoning traces
lora_rank = 64 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "Qwen/Qwen2.5-3B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.3, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

==((====))==  Unsloth 2025.2.12: Fast Qwen2 patching. Transformers: 4.49.0.
   \\   /|    GPU: NVIDIA A100-SXM4-80GB. Max memory: 79.254 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.1+cu121. CUDA: 8.0. CUDA Toolkit: 12.1. Triton: 3.1.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.28.post3. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/qwen2.5-3b-instruct-unsloth-bnb-4bit with actual GPU utilization = 29.84%
Unsloth: Your GPU has CUDA compute capability 8.0 with VRAM = 79.25 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 12000. Num Sequences = 288.
Unsloth: vLLM's KV Cache can use up to 21.23 GB. Also swap space = 6 GB.
INFO 02-18 14:56:38 config.py:542] This model supports multiple tasks: {'score', 'reward', 'embed', 'classify', 'generate'}. Defaulting to 'generate'.
Unsloth: vLLM Bitsandbytes config using k

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]




INFO 02-18 14:56:43 model_runner.py:1115] Loading model weights took 2.2160 GB
INFO 02-18 14:56:43 punica_selector.py:18] Using PunicaWrapperGPU.
INFO 02-18 14:56:47 worker.py:267] Memory profiling takes 2.97 seconds
INFO 02-18 14:56:47 worker.py:267] the current vLLM instance can use total_gpu_memory (79.25GiB) x gpu_memory_utilization (0.30) = 23.65GiB
INFO 02-18 14:56:47 worker.py:267] model weights take 2.22GiB; non_torch_memory takes 0.09GiB; PyTorch activation peak memory takes 1.61GiB; the rest of the memory reserved for KV Cache is 19.73GiB.
INFO 02-18 14:56:47 executor_base.py:110] # CUDA blocks: 35923, # CPU blocks: 10922
INFO 02-18 14:56:47 executor_base.py:115] Maximum concurrency for 12000 tokens per request: 47.90x
INFO 02-18 14:56:52 model_runner.py:1434] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory erro

Capturing CUDA graph shapes: 100%|██████████| 39/39 [01:02<00:00,  1.60s/it]

INFO 02-18 14:57:54 model_runner.py:1562] Graph capturing finished in 62 secs, took 0.92 GiB
INFO 02-18 14:57:54 llm_engine.py:431] init engine (profile, create kv cache, warmup model) took 70.81 seconds



Unsloth 2025.2.12 patched 36 layers with 36 QKV layers, 36 O layers and 36 MLP layers.
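The memory figures in the vLLM log above can be reproduced by hand, which helps when tuning `gpu_memory_utilization` and `max_seq_length`. The log reports an actual utilization of 29.84%, which is why 79.25 GiB × 0.30 is printed as 23.65 GiB. The sketch below takes the budget numbers from the log and assumes a 16-token KV-cache block (vLLM's default block size), bf16 cache entries, and the Qwen2.5-3B attention shape (36 layers, 2 KV heads of dimension 128).

```python
GIB = 1024**3


def kv_cache_budget_gib(total_gib, utilization, weights_gib, non_torch_gib, activation_gib):
    """Memory left for the KV cache after weights, non-torch memory and the activation peak."""
    return total_gib * utilization - weights_gib - non_torch_gib - activation_gib


def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, dtype_bytes=2):
    """Each layer stores one key and one value vector per KV head for every token."""
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


def max_concurrency(kv_gib, bytes_per_token, max_seq_len, block_size=16):
    """How many full-length requests fit in the cache at once."""
    num_tokens = int(kv_gib * GIB // bytes_per_token)
    num_blocks = num_tokens // block_size  # the cache is allocated in whole blocks
    return num_blocks * block_size / max_seq_len


# From the log: 79.25 GiB GPU, 29.84% utilization, 2.22 + 0.09 + 1.61 GiB already in use.
budget = kv_cache_budget_gib(79.25, 0.2984, 2.22, 0.09, 1.61)
per_token = kv_bytes_per_token(num_layers=36, num_kv_heads=2, head_dim=128)
print(f"KV cache budget ~ {budget:.2f} GiB, {per_token} bytes per token")
print(f"max concurrency at 12000 tokens ~ {max_concurrency(19.73, per_token, 12000):.1f}x")
```

Raising `max_seq_length` with the same budget lowers the number of full-length sequences that fit concurrently in proportion.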


## Defining the RL rewards

With the model loaded, we can set up the RL training set and the reward functions.

First, we import the training libraries; the general prompt structure (with the reasoning tags) is defined together with the dataset below.

In [None]:
import re
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer

from tenacity import retry, stop_after_attempt, wait_exponential


Next, we load the dataset we generated and filter out bad cases, i.e. rows flagged with `bad_reasoning` or `bad_extraction`.

In [None]:
import pandas as pd
import json
import re


df = pd.read_csv("polished_rl_training_data.csv")
dataset_df = df[(df['bad_reasoning'] == False) & (df['bad_extraction'] == False)]
data = dataset_df.to_dict(orient='records')


def extract_json_from_answer(api_output: str) -> str:
    """
    Extract the JSON content inside <answer> ... </answer> tags.
    Returns an empty string if not found or if JSON is invalid.
    """
    match = re.search(r"<answer>(.*?)</answer>", api_output, re.DOTALL)
    if match:
        json_str = match.group(1).strip()
        # Validate JSON before returning
        try:
            json.loads(json_str)  # Test if parseable
            return json_str
        except json.JSONDecodeError:
            return ""
    return ""


dataset = []


SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

PROMPT_PREFIX = """\
You are an expert in information extraction and knowledge graph creation. Your task is to:
1. Extract key entities and their relationships from the given text
2. Create a structured knowledge graph
"""

for row in data:
    # Use the prompt column from the CSV as the user prompt.
    user_prompt = PROMPT_PREFIX + row["prompt"]
    # Use the thinking_process and api_output columns to build the ground-truth answer.
    reasoning = row["thinking_process"]
    final_answer = extract_json_from_answer(row["api_output"])
    # Skip rows whose <answer> block is missing or not valid JSON: with an empty
    # ground truth, graph_correctness_reward_func would always return 0.0.
    if not final_answer:
        continue

    # Format the ground-truth answer using the XML_COT_FORMAT.
    ground_truth = XML_COT_FORMAT.format(reasoning=reasoning, answer=final_answer)

    # Create the training example with a list of messages.
    dataset.append({
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt}
        ],
        "answer": ground_truth
    })

We reuse some of the formatting-related reward functions from the GRPO demo.


# A guide to building the RL reward functions

## Reward mechanisms

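- Strict format reward (`strict_format_reward_func`)
  - Reward: 0.5 for matching the strict format exactly
  - Condition: the response must exactly match the pattern `<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n`
- Soft format reward (`soft_format_reward_func`)
  - Reward: 0.5 for meeting the basic format requirements
  - Condition: the response only needs the `<reasoning>...</reasoning><answer>...</answer>` structure; whitespace requirements are looser
- Cumulative component reward (`count_xml` / `xmlcount_reward_func`), where each tag counts only if it appears exactly once
  - Partial reward 1: correct `<reasoning>\n` opening tag, +0.125
  - Partial reward 2: correct `\n</reasoning>\n` closing tag, +0.125
  - Partial reward 3: correct `\n<answer>\n` opening tag, +0.125
  - Partial reward 4: correct `\n</answer>` closing tag, +0.125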


## Basic principles of reward functions

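The reward function is the core of reinforcement learning: it defines what counts as "good" or "bad" model behavior and steers optimization in the desired direction. The code in this notebook illustrates three different reward design strategies: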

## 1. Binary reward strategy



In [None]:
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets .*? span multi-line reasoning and JSON answers
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]



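**Characteristics**:
- Based on exact matching against a strict regular expression
- The response gets either the full reward (0.5) or zero
- Pro: a clear learning signal
- Con: no intermediate states, which can make training unstable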
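A detail worth checking in this pattern: unless `re.DOTALL` is passed, `.` does not match newlines, so the pattern only accepts single-line reasoning and single-line answers. Multi-line graph JSON, as produced in this task, would then always score 0.0. The sketch below (our own check on toy completions) scores the same strings with and without the flag.

```python
import re

STRICT_PATTERN = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"


def strict_format_score(text: str, dotall: bool) -> float:
    """Binary format reward: 0.5 for an exact structural match, else 0.0."""
    flags = re.DOTALL if dotall else 0
    return 0.5 if re.match(STRICT_PATTERN, text, flags) else 0.0


single_line = "<reasoning>\nfind entities\n</reasoning>\n<answer>\n{}\n</answer>\n"
multi_line = (
    "<reasoning>\nStep 1: find entities.\nStep 2: link them.\n</reasoning>\n"
    '<answer>\n{\n  "nodes": [],\n  "edges": []\n}\n</answer>\n'
)

for name, text in [("single-line", single_line), ("multi-line", multi_line)]:
    print(name, strict_format_score(text, dotall=False), strict_format_score(text, dotall=True))
```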

## 2. Lenient matching strategy



In [None]:
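def soft_format_reward_func(completions, **kwargs) -> list[float]:
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    # Same binary reward as above, but with a looser match condition
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]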



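**Characteristics**:
- Uses a looser matching condition
- Allows more flexible whitespace between the tags
- Suited to cases where the format must be kept but some variation is acceptable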
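To see how much looser the lenient pattern is, the sketch below scores a few toy completions with both patterns (using `re.DOTALL` so multi-line content is allowed). Note that `re.match` still anchors the soft pattern at the start of the string, so any preamble before `<reasoning>` fails both checks.

```python
import re

STRICT = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
SOFT = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"


def score(pattern: str, text: str) -> float:
    # Both format rewards are binary: 0.5 on a match, 0.0 otherwise.
    return 0.5 if re.match(pattern, text, re.DOTALL) else 0.0


samples = {
    "exact format": "<reasoning>\nr\n</reasoning>\n<answer>\na\n</answer>\n",
    "no newlines inside tags": "<reasoning>r</reasoning><answer>a</answer>",
    "blank line between blocks": "<reasoning>\nr\n</reasoning>\n\n<answer>\na\n</answer>\n",
    "preamble before tags": "Sure!\n<reasoning>r</reasoning><answer>a</answer>",
}

for name, text in samples.items():
    print(f"{name:28s} strict={score(STRICT, text)} soft={score(SOFT, text)}")
```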

## 3. Cumulative component reward strategy



In [None]:
def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    # ...each further correct component adds a partial reward (full version in the reward-functions cell)
    return count



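**Characteristics**:
- Breaks the reward down into several small components
- Gives partial reward for each correct format element
- Includes a penalty that subtracts score for extra text after the closing `</answer>` tag
- Provides a finer-grained learning signal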
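The cumulative reward is easiest to understand by scoring a few variants. The function below copies the notebook's `count_xml` logic unchanged so the block runs on its own: each tag that appears exactly once earns 0.125, and text after the closing `</answer>` is penalized at 0.001 per character.

```python
def count_xml(text: str) -> float:
    # Same logic as the notebook's count_xml.
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count


clean = "<reasoning>\nr\n</reasoning>\n<answer>\na\n</answer>\n"
trailing = clean + "extra chatter"
missing_close = "<reasoning>\nr\n</reasoning>\n<answer>\na"

print(count_xml(clean), count_xml(trailing), count_xml(missing_close))
```

Note the last case: with no `</answer>` at all, `split` returns the whole text, so the trailing-text penalty is charged on every character of the completion rather than only on text after the answer.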

## Key principles for building effective reward functions

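1. **Signal clarity**: the reward should point clearly at the target behavior
2. **Gradedness**: consider continuous rather than binary rewards so the model gets a direction for improvement
3. **Decomposing complex goals**: break complex objectives into measurable sub-goals
4. **Penalties**: add negative rewards where appropriate to discourage undesirable behavior (e.g. the penalty for extra text in `count_xml`)
5. **Normalization**: keep reward values in a reasonable range (e.g. each format reward in this code stays at or below 0.5)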
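Gradedness matters especially for GRPO, which normalizes rewards within the group of completions sampled for the same prompt: each completion's advantage is its reward minus the group mean, divided by the group standard deviation. If every completion in a group gets the same score (for example 0.0 from a binary reward that never fires), all advantages are zero and that prompt contributes no learning signal. The sketch below illustrates this; the epsilon and the choice of sample standard deviation are ours, and TRL's implementation may differ in such details.

```python
from statistics import mean, stdev


def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """GRPO-style advantages: standardize rewards within one prompt's group of completions."""
    mu = mean(rewards)
    sigma = stdev(rewards)  # sample standard deviation of the group
    return [(r - mu) / (sigma + eps) for r in rewards]


binary_all_fail = [0.0, 0.0, 0.0, 0.0]  # e.g. the strict format never matched
graded = [0.375, 0.5, 0.25, 0.474]      # partial-credit rewards still differ

print(group_relative_advantages(binary_all_fail))  # all zeros: no signal for this prompt
print([round(a, 2) for a in group_relative_advantages(graded)])
```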

## Practical recommendations

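- Combine several reward functions (as the example code presumably does in the end)
- Balance exploration and exploitation (rewards that are too strict hinder exploration)
- Guard against reward hacking (the model may find shortcuts to a high reward without achieving the real goal)
- Consider incorporating human feedback (RLHF) or a reference model (as in DPO)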

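These examples show practical ways to build format-constraint reward functions for language model training; adjust and extend them to your specific needs.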

In [None]:
# Reward functions

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 if the completion exactly matches the strict reasoning/answer format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets .*? span multi-line reasoning and JSON answers
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 if the completion loosely follows the reasoning/answer format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    """0.125 per tag that appears exactly once, minus a small penalty for text after </answer>."""
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Apply count_xml to the text of each completion."""
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

In [None]:
import os
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

True

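Within the full RL training loop, the code below is the core reward-computation step:

1. The model generates candidate responses
2. The reward functions score them (this code)
3. The model policy is updated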
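To make step 2 concrete: TRL's `GRPOTrainer` calls each reward function with the keyword arguments `prompts` and `completions`, plus every extra dataset column (here `answer`) as a list aligned with the completions, and combines the per-function scores by summing them (newer TRL releases also accept optional per-function weights). The helper below is a simplified, hypothetical re-creation of that step for testing reward functions locally with stubs; it is not TRL's code.

```python
from typing import Callable

RewardFn = Callable[..., list[float]]


def combined_rewards(reward_funcs: list[RewardFn], prompts, completions, **columns) -> list[float]:
    """Hypothetical helper: call every reward function GRPOTrainer-style and sum the scores."""
    totals = [0.0] * len(completions)
    for fn in reward_funcs:
        scores = fn(prompts=prompts, completions=completions, **columns)
        totals = [t + s for t, s in zip(totals, scores)]
    return totals


# Stub reward functions with the same signature style as the notebook's reward functions.
def has_answer_tag(completions, **kwargs) -> list[float]:
    return [0.5 if "<answer>" in c[0]["content"] else 0.0 for c in completions]


def exact_answer(completions, answer, **kwargs) -> list[float]:
    return [1.0 if c[0]["content"] == a else 0.0 for c, a in zip(completions, answer)]


prompts = [[{"role": "user", "content": "Extract a graph."}]] * 2
completions = [
    [{"role": "assistant", "content": "<answer>{}</answer>"}],
    [{"role": "assistant", "content": "no tags"}],
]
answer = ["<answer>{}</answer>", "<answer>{}</answer>"]

print(combined_rewards([has_answer_tag, exact_answer], prompts, completions, answer=answer))
```

Running the notebook's own reward functions through a harness like this is a cheap way to catch argument mix-ups, such as passing `prompts` where `completions` is expected.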

In [None]:
import re
import json
import os
import pytest


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    reraise=True
)
def call_llm_with_retry(client, messages, model="gpt-4o-mini", **kwargs):
    """Wrapper for LLM API calls with retry logic"""
    try:
        completion = client.chat.completions.create(
            model=model,
            messages=messages,
            **kwargs
        )
        return completion
    except Exception as e:
        print(f"[debug] LLM API call failed: {str(e)}")
        raise


def extract_tag_content(text: str, tag: str) -> str:
    """
    Extracts content from the given text that is wrapped in <tag> ... </tag> tags.
    Returns an empty string if not found.
    """
    pattern = rf"<{tag}>(.*?)</{tag}>"
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else ""


def extract_json_from_answer(api_output: str) -> str:
    """
    Extract the JSON content inside <answer> ... </answer> tags.
    Returns an empty string if not found or if JSON is invalid.
    """
    match = re.search(r"<answer>(.*?)</answer>", api_output, re.DOTALL)
    if match:
        json_str = match.group(1).strip()
        # Validate JSON before returning
        try:
            json.loads(json_str)  # Test if parseable
            return json_str
        except json.JSONDecodeError:
            return ""
    return ""


def llm_compare_extracted_graph(extracted_graph: str, ground_truth: str) -> float:
    """
    Uses a remote LLM to compare the extracted graph and the ground truth.
    Returns 0.0 if either input is invalid JSON.
    """
    try:
        # Validate JSON first
        comp_json = json.loads(extracted_graph)
        truth_json = json.loads(ground_truth)
        print(f"[debug] Valid JSON detected for both graphs")
        print(f"[debug] Completion nodes: {len(comp_json.get('nodes', []))}")
        print(f"[debug] Completion edges: {len(comp_json.get('edges', []))}")

    except json.JSONDecodeError as e:
        print(f"[debug] JSON validation failed: {e}")
        return 0.0

    prompt = (
        "You are evaluating knowledge graph extraction results. Compare the extracted graph with the ground truth graph.\n\n"
        "Scoring Guidelines:\n"
        "2.0 - Perfect or near-perfect match with valid structure\n"
        "1.5 - Good match with valid structure and reasonable relationships\n"
        "1.0 - Valid structure but some relationship issues\n"
        "0.5 - Valid nodes but relationship issues\n"
        "0.0 - Invalid structure or completely wrong\n\n"

        "Key Points to Check:\n"
        "1. Is the JSON structure valid? (nodes and edges arrays)\n"
        "2. Do nodes have valid id, label, and type?\n"
        "3. Do edges have valid source, target, and relation?\n"
        "4. Are the relationships meaningful?\n\n"

        f"Extracted Graph (evaluate this):\n{extracted_graph}\n\n"
        f"Ground Truth Graph:\n{ground_truth}\n\n"

        "Output ONLY a number (2.0, 1.5, 1.0, 0.5, or 0.0) based on the quality of the extracted graph."
    )

    try:
        if os.getenv("AZURE_OPENAI_API_KEY"):
            from openai import AzureOpenAI
            client = AzureOpenAI(
                api_key=os.getenv("AZURE_OPENAI_API_KEY"),
                api_version="2024-02-15-preview",
                azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT") or "",
            )
            model = "gpt-4o-mini"  # or your Azure deployment name
        else:
            from openai import OpenAI
            client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            model = "gpt-4o-mini"

        print(f"[debug] Calling LLM with model: {model}")

        completion = call_llm_with_retry(
            client,
            messages=[{"role": "user", "content": prompt}],
            model=model,
            temperature=0.1,
            max_tokens=10
        )

        score_text = completion.choices[0].message.content.strip()
        print(f"[debug] Raw LLM response: '{score_text}'")

        # More robust score parsing
        score_match = re.findall(r"[0-2]\.?[05]?", score_text)
        if score_match:
            score = float(score_match[0])
            print(f"[debug] Successfully parsed score: {score}")
            return score
        else:
            print(f"[debug] No valid score found in response: '{score_text}'")
            return 0.0

    except Exception as e:
        print(f"[debug] LLM call failed with error type {type(e)}: {str(e)}")
        # Add basic fallback scoring
        return basic_graph_score(extracted_graph, ground_truth)


def basic_graph_score(completion_graph: str, ground_truth: str) -> float:
    """Fallback scoring when LLM fails"""
    try:
        comp_json = json.loads(completion_graph)
        truth_json = json.loads(ground_truth)

        score = 0.0

        # Check basic structure
        if 'nodes' in comp_json and 'edges' in comp_json:
            score += 0.5

        # Check node structure
        if comp_json['nodes'] and all('id' in n and 'label' in n and 'type' in n for n in comp_json['nodes']):
            score += 0.5

        # Check edge structure
        if comp_json['edges'] and all('source' in e and 'target' in e and 'relation' in e for e in comp_json['edges']):
            score += 0.5

        print(f"[debug] Fallback scoring result: {score}")
        return min(score, 2.0)

    except Exception as e:
        print(f"[debug] Fallback scoring failed: {str(e)}")
        return 0.0


def graph_correctness_reward_func(completions, answer, **kwargs) -> list[float]:
    """
    Evaluates graph correctness against ground truth graph from answer.
    Returns a reward score between 0 and 2 for each example.
    """
    rewards = []
    for comp, ans in zip(completions, answer):
        print("[debug] calling graph_correctness_reward_func")
        # Extract graph JSON from completion
        completion_graph = extract_tag_content(comp[0]["content"], "answer")
        print(f"[debug] gpt completion_graph:{completion_graph}")
        # Extract graph JSON from answer
        ground_truth_graph = extract_tag_content(ans, "answer")

        if not completion_graph or not ground_truth_graph:
            content = comp[0]["content"]
            print(f"[debug] empty completion:\n\n----\n{content}\n\n----\n")
            rewards.append(0.0)
        else:
            score = llm_compare_extracted_graph(completion_graph, ground_truth_graph)
            rewards.append(score)
    print(f"rewards: {rewards}")
    return rewards

def reasoning_reward_func(completions, answer, **kwargs) -> list[float]:
    """
    Evaluates reasoning quality against ground truth reasoning from answer.
    Returns a reward score between 0 and 1 for each example.
    """
    rewards = []
    for comp, ans in zip(completions, answer):
        print("calling reasoning_reward_func")
        # Extract reasoning from completion
        completion_reasoning = extract_tag_content(comp[0]["content"], "reasoning")
        # Extract reasoning from answer
        ground_truth_reasoning = extract_tag_content(ans, "reasoning")

        if not completion_reasoning:
            rewards.append(0.0)
        else:
            score = llm_reasoning_score(completion_reasoning, ground_truth_reasoning)
            rewards.append(score)
    print(f"rewards: {rewards}")
    return rewards


def llm_reasoning_score(reasoning: str, ground_truth: str) -> float:
    """
    Uses a remote LLM to evaluate reasoning quality against ground truth.
    """
    prompt = (
        "You are evaluating the quality of reasoning in knowledge graph extraction. "
        "Compare the reasoning against the ground truth and score based on these criteria:\n\n"

        "Scoring Guide:\n"
        "1.0: Excellent - Systematic analysis that correctly identifies all elements from ground truth\n"
        "0.8: Good - Correct analysis but missing minor details from ground truth\n"
        "0.5: Basic - Identifies main elements but missing significant details\n"
        "0.2: Poor - Very incomplete or vague compared to ground truth\n"
        "0.0: Invalid - Empty, irrelevant, or contradicts ground truth\n\n"

        "Ground Truth:\n" + ground_truth + "\n\n"

        "Reasoning to evaluate:\n" + reasoning + "\n\n"

        "Output only a single number (1.0, 0.8, 0.5, 0.2, or 0.0) based on how well the reasoning matches the ground truth."
    )
    try:
        if os.getenv("AZURE_OPENAI_API_KEY"):
            from openai import AzureOpenAI

            client = AzureOpenAI(
                api_key=os.getenv("AZURE_OPENAI_API_KEY"),
                api_version="2024-02-15-preview",
                azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT") or "",
            )
        else:
            from openai import OpenAI
            client = OpenAI(
                api_key=os.getenv("OPENAI_API_KEY")
            )

        completion = client.chat.completions.create(
            model="gpt-4o-mini",  # adjust as needed
            messages=[{"role": "user", "content": prompt}],
        )
        score_text = completion.choices[0].message.content
        if score_text:
            score_text = score_text.strip()
            match = re.search(r"([\d\.]+)", score_text)
            if match:
                return float(match.group(1))
        return 0.0  # Return 0.0 if no valid score found in response
    except Exception as e:
        print("Remote LLM call for reasoning evaluation failed:", e)
        return 0.0


#### TEST CASES of graph correctness reward function ####

TEST_CASES_GRAPH_CORRECTNESS = [
    {
        "name": "Perfect match",
        "completions": [[{
            "role": "assistant",
            "content": "<answer>{\"nodes\": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}, {\"id\": 2, \"label\": \"Tech Inc\", \"type\": \"Organization\"}], \"edges\": [{\"source\": 1, \"target\": 2, \"relation\": \"acquired\"}]}</answer>"
        }]],
        "answer": ['<answer>{"nodes": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}, {\"id\": 2, \"label\": \"Tech Inc\", \"type\": \"Organization\"}], "edges": [{\"source\": 1, \"target\": 2, \"relation\": \"acquired\"}]}</answer>'],
        "expected_score": 2.0
    },
    {
        "name": "Very close match (extra valid node)",
        "completions": [[{
            "role": "assistant",
            "content": "<answer>{\"nodes\": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}, {\"id\": 2, \"label\": \"Tech Inc\", \"type\": \"Organization\"}, {\"id\": 3, \"label\": \"John Doe\", \"type\": \"Person\"}], \"edges\": [{\"source\": 1, \"target\": 2, \"relation\": \"acquired\"}]}</answer>"
        }]],
        "answer": ['<answer>{"nodes": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}, {\"id\": 2, \"label\": \"Tech Inc\", \"type\": \"Organization\"}], "edges": [{\"source\": 1, \"target\": 2, \"relation\": \"acquired\"}]}</answer>'],
        "expected_score": 1.5
    },
    {
        "name": "Partial match (different relationship)",
        "completions": [[{
            "role": "assistant",
            "content": "<answer>{\"nodes\": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}, {\"id\": 2, \"label\": \"Tech Inc\", \"type\": \"Organization\"}], \"edges\": [{\"source\": 1, \"target\": 2, \"relation\": \"owns\"}]}</answer>"
        }]],
        "answer": ['<answer>{"nodes": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}, {\"id\": 2, \"label\": \"Tech Inc\", \"type\": \"Organization\"}], "edges": [{\"source\": 1, \"target\": 2, \"relation\": \"acquired\"}]}</answer>'],
        "expected_score": 1.0
    },
    {
        "name": "Poor match (wrong direction)",
        "completions": [[{
            "role": "assistant",
            "content": "<answer>{\"nodes\": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}, {\"id\": 2, \"label\": \"Tech Inc\", \"type\": \"Organization\"}], \"edges\": [{\"source\": 2, \"target\": 1, \"relation\": \"acquired\"}]}</answer>"
        }]],
        "answer": ['<answer>{"nodes": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}, {\"id\": 2, \"label\": \"Tech Inc\", \"type\": \"Organization\"}], "edges": [{\"source\": 1, \"target\": 2, \"relation\": \"acquired\"}]}</answer>'],
        "expected_score": 0.5
    },
    {
        "name": "Malformed JSON",
        "completions": [[{
            "role": "assistant",
            "content": "<answer>{\"nodes\": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\" \"edges\": []}</answer>"
        }]],
        "answer": ['<answer>{"nodes": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}], "edges": []}</answer>'],
        "expected_score": 0.0
    },
    {
        "name": "No answer tag",
        "completions": [[{
            "role": "assistant",
            "content": "Some text without proper tags"
        }]],
        "answer": ['<answer>{"nodes": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}], "edges": []}</answer>'],
        "expected_score": 0.0
    },
    {
        "name": "Empty vs non-empty",
        "completions": [[{
            "role": "assistant",
            "content": "<answer>{\"nodes\": [], \"edges\": []}</answer>"
        }]],
        "answer": ['<answer>{"nodes": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}, {\"id\": 2, \"label\": \"Tech Inc\", \"type\": \"Organization\"}], "edges": [{\"source\": 1, \"target\": 2, \"relation\": \"acquired\"}]}</answer>'],
        "expected_score": 0.0
    },
    {
        "name": "Contradicting node type",
        "completions": [[{
            "role": "assistant",
            "content": "<answer>{\"nodes\": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Person\"}, {\"id\": 2, \"label\": \"Tech Inc\", \"type\": \"Organization\"}], \"edges\": [{\"source\": 1, \"target\": 2, \"relation\": \"acquired\"}]}</answer>"
        }]],
        "answer": ['<answer>{"nodes": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}, {\"id\": 2, \"label\": \"Tech Inc\", \"type\": \"Organization\"}], "edges": [{\"source\": 1, \"target\": 2, \"relation\": \"acquired\"}]}</answer>'],
        "expected_score": 0.0
    },
    {
        "name": "Multiple errors (direction + relation)",
        "completions": [[{
            "role": "assistant",
            "content": "<answer>{\"nodes\": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}, {\"id\": 2, \"label\": \"Tech Inc\", \"type\": \"Organization\"}], \"edges\": [{\"source\": 2, \"target\": 1, \"relation\": \"owns\"}]}</answer>"
        }]],
        "answer": ['<answer>{"nodes": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}, {\"id\": 2, \"label\": \"Tech Inc\", \"type\": \"Organization\"}], "edges": [{\"source\": 1, \"target\": 2, \"relation\": \"acquired\"}]}</answer>'],
        "expected_score": 0.0,
        "error_types": ["wrong_direction", "wrong_relation"]
    },
    {
        "name": "Single relationship type error",
        "completions": [[{
            "role": "assistant",
            "content": "<answer>{\"nodes\": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}, {\"id\": 2, \"label\": \"Tech Inc\", \"type\": \"Organization\"}], \"edges\": [{\"source\": 1, \"target\": 2, \"relation\": \"owns\"}]}</answer>"
        }]],
        "answer": ['<answer>{"nodes": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}, {\"id\": 2, \"label\": \"Tech Inc\", \"type\": \"Organization\"}], "edges": [{\"source\": 1, \"target\": 2, \"relation\": \"acquired\"}]}</answer>'],
        "expected_score": 1.0
    },
    {
        "name": "Missing half nodes",
        "completions": [[{
            "role": "assistant",
            "content": "<answer>{\"nodes\": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}], \"edges\": []}</answer>"
        }]],
        "answer": ['<answer>{"nodes": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}, {\"id\": 2, \"label\": \"Tech Inc\", \"type\": \"Organization\"}], "edges": [{\"source\": 1, \"target\": 2, \"relation\": \"acquired\"}]}</answer>'],
        "expected_score": 0.0
    }
]


@pytest.mark.parametrize("test_case", TEST_CASES_GRAPH_CORRECTNESS, ids=[case["name"] for case in TEST_CASES_GRAPH_CORRECTNESS])
def test_graph_correctness_scores(test_case):
    """Test graph correctness reward function with various cases."""
    completions = test_case["completions"]
    answer = test_case["answer"]

    # Score the completions (not the prompts) against the reference answer
    scores = graph_correctness_reward_func(completions, answer)
    assert len(scores) == 1, "Expected single score"
    assert scores[0] == pytest.approx(test_case["expected_score"], abs=0.1)


def test_extract_json_from_answer():
    """Test JSON extraction from answer tags."""
    # Valid JSON
    valid_input = "<answer>{\"test\": true}</answer>"
    assert extract_json_from_answer(valid_input) == "{\"test\": true}"

    # Invalid JSON
    invalid_input = "<answer>{invalid json}</answer>"
    assert extract_json_from_answer(invalid_input) == ""

    # No tags
    no_tags = "{\"test\": true}"
    assert extract_json_from_answer(no_tags) == ""

    # Empty tags
    empty_tags = "<answer></answer>"
    assert extract_json_from_answer(empty_tags) == ""


def test_llm_compare_extracted_graph():
    """Test LLM comparison function."""
    # Valid JSON comparison
    valid_json1 = '{"nodes": [], "edges": []}'
    valid_json2 = '{"nodes": [], "edges": []}'
    assert llm_compare_extracted_graph(valid_json1, valid_json2) >= 0.0
    assert llm_compare_extracted_graph(valid_json1, valid_json2) <= 2.0

    # Invalid JSON handling
    invalid_json = '{invalid}'
    assert llm_compare_extracted_graph(invalid_json, valid_json1) == 0.0
    assert llm_compare_extracted_graph(valid_json1, invalid_json) == 0.0


@pytest.mark.parametrize("score", [0.0, 0.5, 1.0, 1.5, 2.0])
def test_score_ranges(score):
    """Test that scores fall within valid ranges."""
    assert score >= 0.0 and score <= 2.0, "Score must be between 0 and 2"
    assert score % 0.5 == 0, "Score must be in increments of 0.5"


TEST_CASES_REASONING = [
    {
        "name": "Perfect reasoning (1.0)",
        "completions": [[{
            "role": "assistant",
            "content": "<reasoning>I'll analyze the text systematically:\n\n1. First, identify key entities:\n- Acme Corp (Organization)\n- Tech Inc (Organization)\n\n2. Analyze relationships:\n- The text indicates an acquisition relationship\n- Acme Corp is the acquirer\n- Tech Inc is the acquired company\n\n3. Structure verification:\n- Both entities are organizations (correct type)\n- Relationship direction is from acquirer to acquired\n- 'acquired' is the appropriate relationship type\n\nThis forms a clear acquisition relationship between two organizations.</reasoning><answer>{...}</answer>"
        }]],
        "answer": [
            "<reasoning>I'll analyze the text systematically:\n\n1. Entity identification:\n- Acme Corp (Organization)\n- Tech Inc (Organization)\n\n2. Relationship analysis:\n- Acquisition relationship identified\n- Acme Corp is acquirer\n- Tech Inc is acquired\n\n3. Verification:\n- Entity types confirmed\n- Relationship direction verified\n- Relationship type confirmed\n\nThis represents a clear business acquisition.</reasoning><answer>{\"nodes\": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}, {\"id\": 2, \"label\": \"Tech Inc\", \"type\": \"Organization\"}], \"edges\": [{\"source\": 1, \"target\": 2, \"relation\": \"acquired\"}]}</answer>"
        ],
        "expected_score": 1.0
    },
    {
        "name": "Good reasoning with minor omissions (0.8)",
        "completions": [[{
            "role": "assistant",
            "content": "<reasoning>Analyzing the text:\n1. Entity identification:\n- Acme Corp and Tech Inc are organizations\n\n2. Relationship analysis:\n- Acme Corp acquired Tech Inc\n\nThe graph should represent this acquisition.</reasoning><answer>{...}</answer>"
        }]],
        "answer": [
            "<reasoning>1. Entities:\n- Acme Corp (Organization)\n- Tech Inc (Organization)\n\n2. Relationships:\n- Acquisition relationship\n- Direction: Acme Corp → Tech Inc\n\n3. Verification:\n- Types and relationships confirmed\n- Transaction details noted</reasoning><answer>{\"nodes\": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}, {\"id\": 2, \"label\": \"Tech Inc\", \"type\": \"Organization\"}], \"edges\": [{\"source\": 1, \"target\": 2, \"relation\": \"acquired\"}]}</answer>"
        ],
        "expected_score": 0.8
    },
    {
        "name": "Basic reasoning (0.5)",
        "completions": [[{
            "role": "assistant",
            "content": "<reasoning>Found that Acme Corp bought Tech Inc. They are both companies.</reasoning><answer>{...}</answer>"
        }]],
        "answer": [
            "<reasoning>1. Entities:\n- Acme Corp (Organization)\n- Tech Inc (Organization)\n\n2. Relationship:\n- Acquisition between companies\n- Full ownership transfer involved\n- Transaction completed last quarter</reasoning><answer>{\"nodes\": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}, {\"id\": 2, \"label\": \"Tech Inc\", \"type\": \"Organization\"}], \"edges\": [{\"source\": 1, \"target\": 2, \"relation\": \"acquired\"}]}</answer>"
        ],
        "expected_score": 0.5
    },
    {
        "name": "Poor reasoning (0.2)",
        "completions": [[{
            "role": "assistant",
            "content": "<reasoning>Some business activity happened between companies.</reasoning><answer>{...}</answer>"
        }]],
        "answer": [
            "<reasoning>1. Entities:\n- Acme Corp (Organization)\n- Tech Inc (Organization)\n\n2. Relationship:\n- Business acquisition occurred\n- Clear direction of acquisition</reasoning><answer>{\"nodes\": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}, {\"id\": 2, \"label\": \"Tech Inc\", \"type\": \"Organization\"}], \"edges\": [{\"source\": 1, \"target\": 2, \"relation\": \"acquired\"}]}</answer>"
        ],
        "expected_score": 0.2
    },
    {
        "name": "Incorrect reasoning (0.0)",
        "completions": [[{
            "role": "assistant",
            "content": "<reasoning>Tech Inc has acquired Acme Corp.</reasoning><answer>{...}</answer>"
        }]],
        "answer": [
            "<reasoning>1. Entities:\n- Acme Corp (Organization)\n- Tech Inc (Organization)\n\n2. Relationship:\n- Acme Corp acquired Tech Inc\n- Direction is important</reasoning><answer>{\"nodes\": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}, {\"id\": 2, \"label\": \"Tech Inc\", \"type\": \"Organization\"}], \"edges\": [{\"source\": 1, \"target\": 2, \"relation\": \"acquired\"}]}</answer>"
        ],
        "expected_score": 0.0
    },
    {
        "name": "Missing reasoning (0.0)",
        "completions": [[{
            "role": "assistant",
            "content": "<answer>{...}</answer>"
        }]],
        "answer": [
            "<reasoning>1. Entities:\n- Acme Corp (Organization)\n- Tech Inc (Organization)\n\n2. Relationship:\n- Clear acquisition relationship\n- Acme Corp is acquirer</reasoning><answer>{\"nodes\": [{\"id\": 1, \"label\": \"Acme Corp\", \"type\": \"Organization\"}, {\"id\": 2, \"label\": \"Tech Inc\", \"type\": \"Organization\"}], \"edges\": [{\"source\": 1, \"target\": 2, \"relation\": \"acquired\"}]}</answer>"
        ],
        "expected_score": 0.0
    }
]


def test_reasoning_scores():
    """Test reasoning reward function with various cases."""
    print("\nTesting reasoning scores:")
    for test_case in TEST_CASES_REASONING:
        print(f"\nTest case: {test_case['name']}")
        scores = reasoning_reward_func(test_case["completions"], test_case["answer"])
        print(f"Expected score: {test_case['expected_score']}")
        print(f"Actual score: {scores[0]}")
        assert abs(scores[0] - test_case['expected_score']) <= 0.3, \
            f"Score {scores[0]} too far from expected {test_case['expected_score']}"


def run_all_tests():
    """Run all tests directly without pytest"""
    print("Running all tests...\n")

    # Run graph correctness tests
    print("Testing graph correctness scores:")
    for test_case in TEST_CASES_GRAPH_CORRECTNESS:
        print(f"\nTest case: {test_case['name']}")
        completions = test_case["completions"]
        answer = test_case["answer"]

        scores = graph_correctness_reward_func(completions, answer)
        print(f"Expected score: {test_case['expected_score']}")
        print(f"Actual score: {scores[0]}")

    # Run JSON extraction tests
    print("\nTesting JSON extraction:")
    test_extract_json_from_answer()
    print("JSON extraction tests completed")

    # Run LLM compare tests
    print("\nTesting remote LLM compare:")
    test_llm_compare_extracted_graph()
    print("Remote LLM compare tests completed")

    # Run score range tests
    print("\nTesting score ranges:")
    for score in [0.0, 0.5, 1.0, 1.5, 2.0]:
        test_score_ranges(score)
    print("Score range tests completed")

    # reasoning tests
    test_reasoning_scores()

# Uncomment to run the tests (useful when tuning the LLM-based reward functions)
#run_all_tests()
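The tests above pin down the contract of `extract_json_from_answer`: return the raw JSON text between `<answer>` tags, or an empty string when the tags are missing, empty, or hold invalid JSON. The notebook's own implementation is not shown in this section; `extract_json_from_answer_sketch` below is a hypothetical re-implementation of that contract, which may help if you port the reward functions elsewhere.

```python
import json
import re

# Hypothetical helper: a re-implementation of the contract checked by
# test_extract_json_from_answer. It is not the notebook's own function.
ANSWER_PATTERN = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def extract_json_from_answer_sketch(text: str) -> str:
    """Return the JSON text inside <answer>...</answer>, or "" if it is missing or invalid."""
    match = ANSWER_PATTERN.search(text)
    if match is None:
        return ""  # no <answer> tags at all
    candidate = match.group(1).strip()
    if not candidate:
        return ""  # tags present but empty
    try:
        json.loads(candidate)  # validate only; the original text is returned
    except json.JSONDecodeError:
        return ""
    return candidate


print(extract_json_from_answer_sketch('<answer>{"test": true}</answer>'))  # {"test": true}
print(repr(extract_json_from_answer_sketch("<answer>{invalid json}</answer>")))  # ''
```

Returning the original string rather than a re-serialized `json.dumps` keeps the output byte-identical to what the model produced, which is what the test's equality check expects.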


## Train Without Unsloth

We now set the training arguments:

In [None]:
# model_name = "Qwen/Qwen2.5-3b-Instruct"

# output_dir="outputs/Qwen-3b-GRPO-graph-extraction"
# run_name="Qwen-GRPO-graph-reasoning-extraction"


# training_args = GRPOConfig(
#     output_dir=output_dir,
#     run_name=run_name,
#     learning_rate=5e-6,
#     adam_beta1=0.9,
#     adam_beta2=0.99,
#     weight_decay=0.1,
#     warmup_ratio=0.1,
#     lr_scheduler_type='cosine',
#     logging_steps=1,
#     bf16=True,
#     per_device_train_batch_size=2,
#     gradient_accumulation_steps=4,
#     num_generations=2,
#     max_prompt_length=2040,
#     max_completion_length=2040,
#     num_train_epochs=1,
#     save_steps=100,
#     max_grad_norm=0.1,
#     log_on_each_node=False,
#     use_vllm=True,
#     vllm_gpu_memory_utilization=0.5,
#     vllm_device="cuda:1",
#     # vllm_max_model_len=7000,  # needs trl >= 0.15.0
#     # log_level="debug",
#     report_to="none",  # disable Weights & Biases logging
# )

# model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     torch_dtype=torch.bfloat16,
#     device_map=None
# ).to("cuda")

# tokenizer = AutoTokenizer.from_pretrained(model_name)
# tokenizer.pad_token = tokenizer.eos_token

## Train With Unsloth


In [None]:
from trl import GRPOConfig, GRPOTrainer
from unsloth import is_bfloat16_supported
training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference!
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    logging_steps = 1,
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    per_device_train_batch_size = 1, # Unsloth raises this to a multiple of num_generations (here 2)
    gradient_accumulation_steps = 4, # Increase to 4 for smoother training
    num_generations = 2, # Decrease if out of memory
    max_prompt_length=9000,
    max_completion_length=9100,
    num_train_epochs = 1, # Ignored here because max_steps is set
    max_steps = 250, # Takes precedence over num_train_epochs
    save_steps = 100,
    max_grad_norm = 0.1,
    #report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)

Unsloth: We know expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 2
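The Unsloth banner printed when training starts reports a total batch size of 8 and 18 epochs for 250 steps over 59 examples. The arithmetic below reproduces those numbers. It rests on two assumptions of ours that this notebook does not confirm: the sampler repeats each prompt `num_generations` times, and the epoch count follows the usual Hugging Face `Trainer` rule of ceil(max_steps / update steps per epoch).

```python
import math

# Values from the GRPOConfig above, after Unsloth bumps the batch size to 2
num_examples = 59  # "Dataset size: 59" in the training cell output
per_device_train_batch_size = 2  # completions per device per forward pass
gradient_accumulation_steps = 4
num_generations = 2  # completions sampled per prompt (the GRPO group size)
num_gpus = 1
max_steps = 250

# Unsloth's check: every micro-batch must hold whole groups of completions
assert per_device_train_batch_size % num_generations == 0

# Completions that contribute to one optimizer step
total_batch = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
unique_prompts_per_step = total_batch // num_generations

# Assumption: each prompt is repeated num_generations times by the sampler
samples_per_epoch = num_examples * num_generations
batches_per_epoch = math.ceil(samples_per_epoch / (per_device_train_batch_size * num_gpus))
update_steps_per_epoch = max(batches_per_epoch // gradient_accumulation_steps, 1)
num_epochs = math.ceil(max_steps / update_steps_per_epoch)

print(total_batch, unique_prompts_per_step, update_steps_per_epoch, num_epochs)  # 8 4 14 18
```

This also shows why `num_train_epochs = 1` has no effect in this run: `max_steps = 250` takes precedence, so the 59 examples are cycled through about 18 times.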


In [None]:
%pip install wandb -qU
import wandb

In [None]:
wandb.login()

## Train

And launch the actual training:

In [None]:
from tqdm.auto import tqdm
import time

print("Dataset size:", len(dataset))
print("Starting training...")

# use peft at your own risk; not working for me with multi-GPU training
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        # int_reward_func,
        # correctness_reward_func,
        graph_correctness_reward_func,
        reasoning_reward_func,
        ],
    args=training_args,
    train_dataset=dataset,
    #peft_config=peft_config
)
trainer.train()

# try:
#     start_time = time.time()
#     trainer.train()
#     print(f"Training completed in {time.time() - start_time:.2f} seconds")
# except Exception as e:
#     print(f"Error during training: {str(e)}")
#     raise

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Dataset size: 59
Starting training...


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 59 | Num Epochs = 18
O^O/ \_/ \    Batch size per device = 2 | Gradient Accumulation steps = 4
\        /    Total batch size = 8 | Total steps = 250
 "-____-"     Number of trainable parameters = 119,734,272
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


[debug] calling graph_correctness_reward_func
[debug] gpt completion_graph:{"nodes": [
    {"id": 1, "label": "Cresomycin", "type": "Antibiotic"},
    {"id": 2, "label": "Red light", "type": "Light therapy"},
    {"id": 3, "label": "Omalizumab", "type": "Drug"},
    {"id": 4, "label": "Donor heart", "type": "Medical procedure"},
    {"id": 5, "label": "Headgear", "type": "Medical device"},
    {"id": 6, "label": "PFAS", "type": "Chemical"},
    {"id": 7, "label": "Water scarcity", "type": "Environmental impact"},
    {"id": 8, "label": "Human papillomavirus (HPV)", "type": "Infection"},
    {"id": 9, "label": "ATM overturning circulation (AMOC)", "type": "Climate phenomenon"},
    {"id": 10, "label": "H5N1", "type": "Virus"},
    {"id": 11, "label": "Amazon rainforest", "type": "Ecosystem"},
    {"id": 12, "label": "Chlormequat", "type": "Chemical"},
    {"id": 13, "label": "Niacin", "type": "Vitamin"},
    {"id": 14, "label": "Puberty blockers", "type": "Medical treatment"},
    {"id"

| Step | Training Loss | reward | reward_std | completion_length | kl | rewards / xmlcount_reward_func | rewards / soft_format_reward_func | rewards / strict_format_reward_func | rewards / graph_correctness_reward_func | rewards / reasoning_reward_func |
|---|---|---|---|---|---|---|---|---|---|---|
| 1 | -0.0 | -2.40875 | 2.135109 | 1308.75 | 0.0 | -3.50875 | 0.0 | 0.0 | 0.375 | 0.725 |
| 2 | -0.0 | -3.3415 | 1.513562 | 1606.75 | 0.0 | -4.3415 | 0.0 | 0.0 | 0.3125 | 0.6875 |
| 3 | 0.0 | -1.389625 | 0.636926 | 789.75 | 0.000517 | -2.302125 | 0.0 | 0.0 | 0.25 | 0.6625 |
| 4 | 0.0 | -1.7365 | 1.25264 | 1211.75 | 0.000455 | -2.849 | 0.0 | 0.0 | 0.3125 | 0.8 |
| 5 | 0.0 | 0.00225 | 0.403404 | 399.5 | 0.000614 | -0.87275 | 0.0 | 0.0 | 0.375 | 0.5 |
| 6 | 0.0 | -0.37925 | 0.923128 | 641.375 | 0.000727 | -1.67925 | 0.0 | 0.0 | 0.5 | 0.8 |
| 7 | 3046679552.0 | -2.118 | 1.287288 | 1252.375 | 76166987776.00044 | -3.218 | 0.0 | 0.0 | 0.25 | 0.85 |
| 8 | 0.0 | -1.6665 | 1.185818 | 946.75 | 0.000525 | -2.554 | 0.0 | 0.0 | 0.25 | 0.6375 |
| 9 | 0.0 | -0.953375 | 0.798147 | 795.25 | 0.000472 | -2.178375 | 0.0 | 0.0 | 0.375 | 0.85 |
| 10 | 0.0 | -1.0485 | 0.497096 | 839.25 | 0.000567 | -2.2235 | 0.0 | 0.0 | 0.4375 | 0.7375 |
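The logged `reward` column appears to be the plain sum of the five `rewards / …` columns; step 1 of the training log is checked below. GRPO then turns summed rewards into advantages by normalizing each completion's reward against the other completions sampled for the same prompt. `group_relative_advantages` is a minimal sketch of that idea, modeled on the mean/std-per-group formula with a small epsilon that TRL uses; it is not TRL's code, and the rewards in the usage line are made up.

```python
import statistics

# Per-function rewards logged at step 1 of the training table above
step1_rewards = {
    "xmlcount_reward_func": -3.50875,
    "soft_format_reward_func": 0.0,
    "strict_format_reward_func": 0.0,
    "graph_correctness_reward_func": 0.375,
    "reasoning_reward_func": 0.725,
}
step1_total = sum(step1_rewards.values())
print(round(step1_total, 5))  # -2.40875, the logged "reward" for step 1


def group_relative_advantages(rewards, group_size, eps=1e-4):
    """GRPO-style sketch: normalize each reward by its own group's mean and std."""
    advantages = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mean = statistics.fmean(group)
        # Sample standard deviation; a group of identical rewards gives std = 0
        std = statistics.stdev(group) if len(group) > 1 else 0.0
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages


# Made-up rewards for two prompts with num_generations = 2 completions each
advantages = group_relative_advantages([-3.0, -1.0, 0.5, 0.5], group_size=2)
print([round(a, 3) for a in advantages])  # [-0.707, 0.707, 0.0, 0.0]
```

The second group shows a practical consequence of `num_generations = 2`: whenever both completions for a prompt receive the same total reward, their advantages are zero and that prompt contributes no reward signal to the update.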


Streaming output truncated; only the last 5000 lines are shown.
[debug] gpt completion_graph:{"nodes":[{"id":1,"label":"Cresomycin","type":"antibiotic"},{"id":2,"label":"Red light therapy","type":"therapy"},{"id":3,"label":"Omalizumab","type":"molecule"},{"id":4,"label":"Donor heart","type":"transplant"},{"id":5,"label":"Headgear for gamma stimulation","type":"device"},{"id":6,"label":"PFAS","type":"chemical substance"},{"id":7,"label":"Water scarcity","type":"environmental issue"},{"id":8,"label":"HPV infection","type":"infection"},{"id":9,"label":"ATM","type":"phenomenon"},{"id":10,"label":"H5N1 bird flu virus","type":"virus"},{"id":11,"label":"Amazon rainforest","type":"ecosystem"},{"id":12,"label":"Chlormequat","type":"chemical substance"},{"id":13,"label":"Niacin","type":"vitamin"},{"id":14,"label":"Gender dysphoria","type":"disease"},{"id":15,"label":"Puberty blockers","type":"treatment"},{"id":16,"label":"Cross-sex hormones","type":"hormone"}],"edges":[{"source":1,"target":1,"relation":"discovered"},{

TrainOutput(global_step=250, training_loss=12186718.208415123, metrics={'train_runtime': 40722.1695, 'train_samples_per_second': 0.049, 'train_steps_per_second': 0.006, 'total_flos': 0.0, 'train_loss': 12186718.208415123})
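The reported `training_loss` of about 1.2e7 looks alarming, but it is essentially one outlier. Step 7 of the training log shows a loss of 3,046,679,552 together with a KL of about 7.6e10, while the other steps visible in the log show a loss of 0.0. Dividing that single spike by the 250 steps reproduces the reported average almost exactly. As far as we can tell, in TRL's GRPO loss of this era the advantage term averages to zero within each group, so the logged loss is roughly β times the logged KL. Assuming the default β = 0.04 (our assumption; the config above does not set `beta`), the spike is almost exactly β × KL, which points to a KL blow-up rather than a reward problem.

```python
# Numbers copied from the training log and TrainOutput above
spike_loss = 3046679552.0          # "Training Loss" at step 7
spike_kl = 76166987776.00044       # "kl" at step 7
reported_train_loss = 12186718.208415123
total_steps = 250

# If every other step contributes ~0 loss, the mean is just spike / steps
avg_from_spike = spike_loss / total_steps
print(avg_from_spike)  # 12186718.208

# Assumption: KL coefficient beta = 0.04 (TRL's GRPOConfig default at the time)
beta = 0.04
print(beta * spike_kl / spike_loss)  # very close to 1, i.e. loss ~ beta * KL
```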

In [None]:
import os

from dotenv import load_dotenv

# Load the Hugging Face token from a .env file
load_dotenv()
hf_token = os.getenv("HF_TOKEN")

# Merge the LoRA weights into 16-bit weights, save locally, then push to the Hub
model.save_pretrained_merged("qwen2.5-3b-graph-extraction", tokenizer, save_method="merged_16bit")
model.push_to_hub_merged("weygu/qwen2.5-3b-graph-extraction", tokenizer, save_method="merged_16bit", token=hf_token)


Unsloth: Kaggle/Colab has limited disk space. We need to delete the downloaded
model which will save 4-16GB of disk space, allowing you to save on Kaggle/Colab.
Unsloth: Will remove a cached repo with size 2.4G


Unsloth: Merging 4bit and LoRA weights to 16bit...
Unsloth: Will use up to 102.11 out of 167.06 RAM for saving.
Unsloth: Saving model... This might take 5 minutes ...


100%|██████████| 36/36 [00:00<00:00, 95.84it/s]


Unsloth: Saving tokenizer... Done.
Done.


Unsloth: You are pushing to hub, but you passed your HF username = weygu.
We shall truncate weygu/qwen2.5-3b-graph-extraction to qwen2.5-3b-graph-extraction


Unsloth: Merging 4bit and LoRA weights to 16bit...
Unsloth: Will use up to 102.06 out of 167.06 RAM for saving.
Unsloth: Saving model... This might take 5 minutes ...


100%|██████████| 36/36 [00:00<00:00, 116.33it/s]


Unsloth: Saving tokenizer...

 Done.


Done.
Saved merged model to https://huggingface.co/weygu/qwen2.5-3b-graph-extraction


## The trained model

Let's try out our model!



### Base model without RL

In [None]:
prompt = """\nYou are an expert in information extraction and knowledge graph creation. Your task is to:
1. Extract key entities and their relationships from the given text, no more than 10 entities.
2. Create a structured knowledge graph

Article Summary:
## 2022 in science

> URL Source: https://en.wikipedia.org/wiki/2022_in_science

This is the article about all science highlights in 2022.

Chunk Text:

15 June -Astronomers identify J1144 as the fastest-growing black hole of the last nine billion years, consuming matter equivalent to one Earth every second, as well as being the most luminous quasi-stellar object of that period.[255][256][257]
Researchers report Lac-Phe as the most significantly induced circulating metabolite in two animal models of exercise which – including via chronic administration – reduces food intake and suppresses obesity.[258][259]
20 June - A study suggests global food miles CO2 emissions are 3.5–7.5 times higher than previously estimated, with transport accounting for about 19% of total food-system emissions,[260][261] albeit shifting towards plant-based diets remains substantially more important.[262]
Researchers demonstrate an MRI-ML-based approach that can diagnose early Alzheimer's disease with high accuracy and may help identify unknown related changes in the brain.[263][264]
21 June – The inability to stand on one leg for 10 seconds in mid to later life is linked to a near-doubling in the risk of death from any cause within the next 10 years.[265][266]
22 June - A study concludes that the spread of breast cancer accelerates during sleep.[267][268]
Agilicious, an open-source and open-hardware versatile standardized quadrotor drone, currently tailored toward agility, is released.[269][270]
The world's first quantum computer integrated circuit is demonstrated.[271][272]
23 June - The largest known bacterium, and an organism that has encapsulated DNA despite being identified as a prokaryote and not an eukaryote, with an average length of 10 mm, T. magnifica is reported.[273][274]
A review shows prevalence of long COVID conditions – like mood symptoms, fatigue and sleep disorders – in people age 0–18 years appears to be at ~25% overall.[275][276]
Two studies about aging-related characteristics of long-lived animals like turtles are published, identifying potentially causal protective traits and suggesting many of the species have "slow or negligible senescence" (or aging).[277][278][279]
Researchers report the controlled growth of diverse foods in the dark via solar energy and electrocatalysis-based artificial photosynthesis as a potential way to increase energy efficiency of food production and reduce its environmental impacts.[280][281]

Instructions: Analyze the above text, explain your reasoning inside <reasoning> tags, and output a structured knowledge graph in JSON format inside <answer> tags. The JSON should have two keys: 'nodes' and 'edges'. For example:
{"nodes": [{"id": 1, "label": "EntityName", "type": "EntityType"}, ...],
"edges": [{"source": 1, "target": 2, "relation": "RELATION_TYPE"}, ...]}
Ensure your output strictly follows this format.
"""

text = tokenizer.apply_chat_template([
    {"role": "user", "content": prompt}
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams

sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 9192,
)
output_base = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text

print(output_base)


Processed prompts: 100%|██████████| 1/1 [00:04<00:00,  4.79s/it, est. speed input: 162.14 toks/s, output: 91.72 toks/s]

<answer>
{
  "nodes": [
    { "id": 1, "label": "J1144", "type": "Astronomical Object" },
    { "id": 2, "label": "Earth", "type": "Planetary Body" },
    { "id": 3, "Lac-Phe", "Metabolite" },
    { "id": 4, "Alzheimer's Disease", "Disease" },
    { "id": 5, "Agilicious", "Drone" },
    { "id": 6, "Quantum Computer", "Device" },
    { "id": 7, "T. magnifica", "Microorganism" },
    { "id": 8, "Long COVID Conditions", "Health Condition" },
    { "id": 9, "Turtles", "Species" },
    { "id": 10, "Solar Energy", "Energy Source" }
  ],
  "edges": [
    { "source": 1, "target": 2, "relation": "Black hole consuming matter" },
    { "source": 3, "target": 1, "relation": "Induces circulating metabolite" },
    { "source": 4, "target": 1, "relation": "Diagnose disease" },
    { "source": 5, "target": 3, "relation": "Developed for agility" },
    { "source": 6, "target": 3, "relation": "Demonstrated" },
    { "source": 7, "target": 6, "relation": "Reported" },
    { "source": 8, "target": 7, "rel
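The base model's output above is not valid JSON: entries such as `{ "id": 3, "Lac-Phe", "Metabolite" }` drop the `label` and `type` keys. A small validator makes such format failures easy to spot when comparing models. `validate_graph_answer` below is a hypothetical helper written for illustration, not one of the notebook's reward functions. It checks the schema requested in the prompt: an `<answer>` block holding JSON with `nodes` (each with `id`, `label`, `type`) and `edges` (each with `source`, `target`, `relation`) whose endpoints refer to declared node ids.

```python
import json
import re

# Hypothetical helper for illustration; not one of the notebook's reward functions
REQUIRED_NODE_KEYS = {"id", "label", "type"}
REQUIRED_EDGE_KEYS = {"source", "target", "relation"}


def validate_graph_answer(text: str) -> list:
    """Return a list of format problems in a completion; an empty list means it passed."""
    match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    if match is None:
        return ["missing <answer>...</answer> block"]
    try:
        graph = json.loads(match.group(1))
    except json.JSONDecodeError as exc:
        return [f"invalid JSON: {exc.msg}"]
    if not isinstance(graph, dict) or not {"nodes", "edges"}.issubset(graph):
        return ["top level must be an object with 'nodes' and 'edges'"]

    problems = []
    node_ids = set()
    for i, node in enumerate(graph["nodes"]):
        if not isinstance(node, dict) or not REQUIRED_NODE_KEYS.issubset(node):
            problems.append(f"node {i} lacks one of {sorted(REQUIRED_NODE_KEYS)}")
            continue
        node_ids.add(node["id"])
    for i, edge in enumerate(graph["edges"]):
        if not isinstance(edge, dict) or not REQUIRED_EDGE_KEYS.issubset(edge):
            problems.append(f"edge {i} lacks one of {sorted(REQUIRED_EDGE_KEYS)}")
        elif edge["source"] not in node_ids or edge["target"] not in node_ids:
            problems.append(f"edge {i} refers to an undeclared node id")
    return problems


good = '<answer>{"nodes": [{"id": 1, "label": "J1144", "type": "Black Hole"}], "edges": []}</answer>'
broken = '<answer>{"nodes": [{"id": 3, "Lac-Phe", "Metabolite"}], "edges": []}</answer>'
print(validate_graph_answer(good))    # []
print(validate_graph_answer(broken))  # one "invalid JSON: ..." entry
```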




### RL Model

In [None]:
model.save_lora("grpo_saved_lora")


In [None]:
text = tokenizer.apply_chat_template([
    {"role" : "system", "content" : SYSTEM_PROMPT},
    {"role" : "user", "content" : prompt},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams

sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 9192,
)
output_rl = model.fast_generate(
    text,
    sampling_params = sampling_params,
    lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text

answer = extract_tag_content(output_rl, "answer")
reasoning = extract_tag_content(output_rl, "reasoning")

print("reasoning:\n")
print(reasoning)

print("answer:\n")
print(answer)


Processed prompts: 100%|██████████| 1/1 [00:05<00:00,  5.58s/it, est. speed input: 140.43 toks/s, output: 76.86 toks/s]

reasoning:

The text contains various scientific developments and breakthroughs in 2022, which can be categorized into different entities. These include astronomical discoveries, research on metabolic compounds, food systems, health conditions, technology, and biological findings. The key entities extracted should be related to these categories.
answer:

{"nodes":[{"id":1,"label":"J1144","type":"Astronomical Object"},{"id":2,"label":"Lac-Phe","type":"Metabolite"},{"id":3,"label":"Global Food Miles CO2 Emissions","type":"Environmental Factor"},{"id":4,"label":"Alzheimer's Disease","type":"Disease"},{"id":5,"label":"Stand on One Leg","type":"Physical Test"},{"id":6,"label":"Quantum Computer Circuit","type":"Technology"},{"id":7,"label":"T. magnifica","type":"Organism"},{"id":8,"label":"Long COVID Conditions","type":"Health Condition"},{"id":9,"label":"Aging-Related Characteristics","type":"Aging Characteristics"},{"id":10,"label":"Diverse Foods","type":"Food"},{"id":11,"label":"Solar Ene




## Rough Evaluation

Next, let's roughly evaluate the outcome!

For comparison, we also ran the same task with gpt-4o-mini; its result is shown here:

```json
{ "nodes": [ {"id": 1, "label": "J1144", "type": "Black Hole"}, {"id": 2, "label": "Lac-Phe", "type": "Metabolite"}, {"id": 3, "label": "Food miles CO2 emissions", "type": "Research Finding"}, {"id": 4, "label": "MRI-ML-based approach", "type": "Method"}, {"id": 5, "label": "Alzheimer's disease", "type": "Disease"}, {"id": 6, "label": "Breast cancer", "type": "Disease"}, {"id": 7, "label": "Agilicious", "type": "Drone"}, {"id": 8, "label": "Quantum computer integrated circuit", "type": "Technology"}, {"id": 9, "label": "T. magnifica", "type": "Bacterium"}, {"id": 10, "label": "Long COVID conditions", "type": "Health Condition"} ], "edges": [ {"source": 1, "target": 5, "relation": "is related to"}, {"source": 2, "target": 5, "relation": "induces"}, {"source": 3, "target": 5, "relation": "affects"}, {"source": 4, "target": 5, "relation": "diagnoses"}, {"source": 6, "target": 5, "relation": "is related to"}, {"source": 7, "target": 6, "relation": "is a technological development for"}, {"source": 8, "target": 6, "relation": "is a technological milestone for"}, {"source": 9, "target": 5, "relation": "is a research subject of"}, {"source": 10, "target": 5, "relation": "is a health condition linked to"} ] }
```

We then used a stronger instruct LLM (gpt-4o) as the judge. Its verdict suggests that, on this example, our 3B model (the great Qwen2.5 plus LoRA GRPO RL on only 59 DeepSeek R1-distilled examples) is a cheap, reasoning-based graph extractor that beats gpt-4o-mini :).

I know this is not a rigorous evaluation; it is just a quick experiment exploring a path toward an extremely cheap expert model for graph indexing (GraphRAG, knowledge distillation for retrieval).



### 1. Accuracy (Correctness of Nodes and Edges)
- **Model A (base: qwen2.5-3b):**
  - Correctly identifies key scientific entities but has some incorrect edges (e.g., an edge from the metabolite Lac-Phe to the black hole J1144).
- **Model B (our RL model):**
  - More accurate in distinguishing different scientific topics but has some speculative relationships (e.g., standing on one leg linked to quantum computers).
- **Model C (GPT-4o-mini):**
  - Includes well-defined entities but introduces **incorrect general relationships** (e.g., black hole related to Alzheimer’s disease, Lac-Phe inducing Alzheimer’s).

### 2. Completeness (Coverage of Relevant Information)
- **Model A (base: qwen2.5-3b):**
  - Misses important topics like **food miles CO2 emissions** and **artificial photosynthesis**.
- **Model B (our RL model):**
  - Covers the broadest range of topics, including **food miles CO2 emissions** and **aging-related characteristics**.
- **Model C (GPT-4o-mini):**
  - More balanced but **lacks artificial photosynthesis** and **aging-related characteristics**.

### 3. Coherence (Logical Consistency of Edges)
- **Model A (base: qwen2.5-3b):**
  - Some edges link **unrelated** entities.
- **Model B (our RL model):**
  - Edges are **slightly better** but still contain **weak logical connections**.
- **Model C (GPT-4o-mini):**
  - Most edges **overgeneralize relationships** (e.g., "is related to" for multiple diseases and unrelated concepts).

### 4. Relevance (Focus on Key Scientific Highlights)
- **Model A (base: qwen2.5-3b):**
  - Covers fewer topics and **misses some key advancements**.
- **Model B (our RL model):**
  - Covers the **widest range** but has **some speculative relationships**.
- **Model C (GPT-4o-mini):**
  - Focuses heavily on **disease-related links** while **missing technology and energy advancements**.
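The observation that Model C focuses heavily on disease-related links can be checked mechanically. The sketch below copies the gpt-4o-mini graph from the Rough Evaluation section (node labels and edges only) and counts how many edges point at each node, plus how often the generic relation "is related to" appears.

```python
from collections import Counter

# gpt-4o-mini graph from the "Rough Evaluation" section (node id -> label)
labels = {
    1: "J1144", 2: "Lac-Phe", 3: "Food miles CO2 emissions",
    4: "MRI-ML-based approach", 5: "Alzheimer's disease", 6: "Breast cancer",
    7: "Agilicious", 8: "Quantum computer integrated circuit",
    9: "T. magnifica", 10: "Long COVID conditions",
}
# (source, target, relation) triples copied from the same graph
edges = [
    (1, 5, "is related to"), (2, 5, "induces"), (3, 5, "affects"),
    (4, 5, "diagnoses"), (6, 5, "is related to"),
    (7, 6, "is a technological development for"),
    (8, 6, "is a technological milestone for"),
    (9, 5, "is a research subject of"), (10, 5, "is a health condition linked to"),
]

# Count incoming edges per node label
in_degree = Counter(labels[target] for _, target, _ in edges)
generic_count = sum(rel == "is related to" for _, _, rel in edges)

print(in_degree.most_common(2))  # [("Alzheimer's disease", 7), ('Breast cancer', 2)]
print(generic_count)  # 2
```

Seven of the nine edges point at Alzheimer's disease, yet in the source chunk only the MRI-ML-based approach is actually connected to it.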

---

## Final Verdict
| **Criterion**     | **Model A (qwen2.5-3b)** | **Model B (qwen2.5-3b-GRPO)** | **Model C (gpt-4o-mini)** |
|------------------|------------|------------|------------|
| **Accuracy**      | ⭐⭐☆☆☆ | ⭐⭐⭐☆☆ | ⭐⭐☆☆☆ |
| **Completeness**  | ⭐⭐☆☆☆ | ⭐⭐⭐⭐☆ | ⭐⭐⭐☆☆ |
| **Coherence**     | ⭐⭐☆☆☆ | ⭐⭐⭐☆☆ | ⭐⭐☆☆☆ |
| **Relevance**     | ⭐⭐⭐☆☆ | ⭐⭐⭐⭐☆ | ⭐⭐⭐☆☆ |

- **Model B is the best overall**, as it covers more scientific advancements and includes key concepts.
- **Model C introduces general but weak relationships**, making it less reliable.
- **Model A is the most conservative**, but it lacks completeness.
- **None of the models achieve perfect relationships**, and **edge refinement is needed** to improve logical connections.

### **Recommendation**
- **Model B is the best starting point**, but edges need **improvement in logical structure**.
- **Hybrid approach**: Combine **Model B's completeness** with **Model A's cautious approach** to improve accuracy.
- **Improve edge extraction** using **semantic similarity** or **causal inference techniques** to avoid incorrect connections.
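Before reaching for semantic similarity or causal inference, simple rule-based cleanup already catches some failure modes visible in this notebook: the training debug output contains a self-loop (`"source":1,"target":1`), and gpt-4o-mini leans on the generic relation "is related to". `clean_edges` below is a hypothetical sketch of such a filter. It is a rule-based stand-in, not the semantic-similarity approach recommended above, and the `GENERIC_RELATIONS` set is an illustrative assumption.

```python
# Hypothetical post-processing filter; GENERIC_RELATIONS is an illustrative choice
GENERIC_RELATIONS = {"is related to", "related to", "associated with"}


def clean_edges(graph: dict) -> dict:
    """Drop self-loops, dangling, duplicate, and overly generic edges.

    Assumes every node already has an "id" (e.g. it passed a format check).
    """
    node_ids = {node["id"] for node in graph.get("nodes", [])}
    seen = set()
    kept = []
    for edge in graph.get("edges", []):
        src, dst = edge.get("source"), edge.get("target")
        relation = str(edge.get("relation", "")).strip().lower()
        key = (src, dst, relation)
        if src == dst:  # self-loop, like the source 1 -> target 1 edge in the debug output
            continue
        if src not in node_ids or dst not in node_ids:  # endpoint never declared
            continue
        if relation in GENERIC_RELATIONS or key in seen:  # vague or duplicate
            continue
        seen.add(key)
        kept.append(edge)
    return {"nodes": graph.get("nodes", []), "edges": kept}


graph = {
    "nodes": [
        {"id": 1, "label": "Acme Corp", "type": "Organization"},
        {"id": 2, "label": "Tech Inc", "type": "Organization"},
    ],
    "edges": [
        {"source": 1, "target": 2, "relation": "acquired"},
        {"source": 1, "target": 2, "relation": "acquired"},       # duplicate
        {"source": 1, "target": 1, "relation": "acquired"},       # self-loop
        {"source": 2, "target": 1, "relation": "is related to"},  # generic
        {"source": 1, "target": 3, "relation": "owns"},           # dangling
    ],
}
print(clean_edges(graph)["edges"])  # [{'source': 1, 'target': 2, 'relation': 'acquired'}]
```

Such a filter could run after extraction, or its checks could be folded into a reward function so that RL penalizes these patterns directly.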
