<a href="https://colab.research.google.com/github/MengruiLIU/Masterclass/blob/master/nb/Llama3.1_(8B)-GRPO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    !pip install --no-deps unsloth vllm

In [3]:
from unsloth import FastLanguageModel
import torch
max_seq_length = 1024 # Can increase for longer reasoning traces
lora_rank = 32 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.6, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 03-22 12:57:17 [__init__.py:256] Automatically detected platform cuda.
==((====))==  Unsloth 2025.3.18: Fast Llama patching. Transformers: 4.49.0. vLLM: 0.8.1.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/meta-llama-3.1-8b-instruct-unsloth-bnb-4bit with actual GPU utilization = 59.43%
Unsloth: Your GPU has CUDA compute capability 7.5 with VRAM = 14.74 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 1024. Num Sequences = 160.
Unsloth: vLLM's KV Cache can use up to 2.5

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 03-22 12:58:28 [punica_selector.py:18] Using PunicaWrapperGPU.
INFO 03-22 12:58:28 [model_runner.py:1146] Model loading took 5.7736 GB and 41.423693 seconds
INFO 03-22 12:58:35 [worker.py:267] Memory profiling takes 6.11 seconds
INFO 03-22 12:58:35 [worker.py:267] the current vLLM instance can use total_gpu_memory (14.74GiB) x gpu_memory_utilization (0.59) = 8.76GiB
INFO 03-22 12:58:35 [worker.py:267] model weights take 5.77GiB; non_torch_memory takes 0.03GiB; PyTorch activation peak memory takes 0.74GiB; the rest of the memory reserved for KV Cache is 2.22GiB.
INFO 03-22 12:58:35 [executor_base.py:111] # cuda blocks: 1134, # CPU blocks: 1024
INFO 03-22 12:58:35 [executor_base.py:116] Maximum concurrency for 1024 tokens per request: 17.72x
INFO 03-22 12:58:37 [model_runner.py:1442] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If o

Capturing CUDA graph shapes: 100%|██████████| 23/23 [00:41<00:00,  1.82s/it]

INFO 03-22 12:59:19 [model_runner.py:1570] Graph capturing finished in 42 secs, took 0.53 GiB
INFO 03-22 12:59:19 [llm_engine.py:447] init engine (profile, create kv cache, warmup model) took 50.73 seconds



Unsloth 2025.3.18 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


In [4]:
import re
from datasets import load_dataset, Dataset

# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('csv', data_files='/content/rootcause.csv', encoding='gbk')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.5 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

In [5]:
max_prompt_length = 256

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 6, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 250,
    save_steps = 250,
    max_grad_norm = 0.1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)

Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 6


In [None]:
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 2,011 | Num Epochs = 1 | Total steps = 250
O^O/ \_/ \    Batch size per device = 6 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (6 x 1 x 1) = 6
 "-____-"     Trainable parameters = 83,886,080/8,000,000,000 (1.05% trained)


-------------------- Question:
国网辽宁省电力有限公司沈阳市辽中区供电公司沈阳/66kV.航辽/66kV.#1变重过载成因是什么 
Answer:
结论：夏季高温空调多及冬季供暖用户多负荷快速增长导致的主变重载。

原因及解决方案：属于夏季高温带来的降温及冬季供暖设备负荷自然增长，可以考虑通过明显变联网工程解决此问题。 
Response:
</reasoning>可能原因包括:

1.**负荷突然增加**：由于受到某些原因（如节日期间冬 HVAC季用电量增加）、生产线临时加工加班等，过载状况明显。

2.**工况异常**：检修或接线操作不当，实际负荷超过设计值，正常工况难以应对。

3.**线路故障**：由于线路树枝、电缆变性等问题，接地电阻增大，导致线路耐受性下降，容易产生过载现象

4.**变压器降转'," 维保人员忽略过载测试后不及时调圧。

5.**其他原因**：如母线等设备故障、遮蔽、合格参比调試不全等

</answer> 
Extracted:
</reasoning>可能原因包括:

1.**负荷突然增加**：由于受到某些原因（如节日期间冬 HVAC季用电量增加）、生产线临时加工加班等，过载状况明显。

2.**工况异常**：检修或接线操作不当，实际负荷超过设计值，正常工况难以应对。

3.**线路故障**：由于线路树枝、电缆变性等问题，接地电阻增大，导致线路耐受性下降，容易产生过载现象

4.**变压器降转'," 维保人员忽略过载测试后不及时调圧。

5.**其他原因**：如母线等设备故障、遮蔽、合格参比调試不全等


Step,Training Loss,reward,reward_std,completion_length,kl,rewards / xmlcount_reward_func,rewards / soft_format_reward_func,rewards / strict_format_reward_func,rewards / int_reward_func,rewards / correctness_reward_func
1,0.0,-0.089,1.520131,349.166687,0.0,-0.589,0.0,0.0,0.5,0.0
2,0.0,0.5155,0.037967,399.5,0.0,0.0155,0.0,0.0,0.5,0.0
3,0.0,0.521,0.051439,378.833344,9e-06,0.021,0.0,0.0,0.5,0.0
4,0.0,0.5,0.0,529.0,7e-06,0.0,0.0,0.0,0.5,0.0
5,0.0,0.5,0.0,281.833344,1e-05,0.0,0.0,0.0,0.5,0.0
6,0.0,0.5,0.0,387.666687,9e-06,0.0,0.0,0.0,0.5,0.0
7,0.0,0.5,0.0,503.333344,7e-06,0.0,0.0,0.0,0.5,0.0
8,0.0,0.5,0.0,255.5,8e-06,0.0,0.0,0.0,0.5,0.0
9,0.0,0.5,0.0,470.0,7e-06,0.0,0.0,0.0,0.5,0.0
10,0.0,-0.1555,1.605641,486.333344,7e-06,-0.6555,0.0,0.0,0.5,0.0


-------------------- Question:
国网福建省电力有限公司古田县供电公司鹤大线重过载成因是什么 
Answer:
特殊转供电运行方式造成重过载，优化特殊运行方式，缩短特殊运行方式时间 
Response:
我找不到关于国网福建省电力有限公司古田县供电公司鹤大线重过载成因的信息，但是可以告诉你关于电网的常见过载原因：

1. **负荷过大**：如果线路负荷超过设计负荷能力，尤其是在夏季或冬季高峰期，电网负荷会显著增加，导致线路更容易过载。
2. **设备老化**：高温、腐蚀或磨损等因素会导致电力设备老化，使其电压着度不稳，电容器湿度不正常，提高了线路的感应值，因此易过载。
3. **变压器失效**：变压器失压、摇摆、调压会导致电压波动，影响电网的稳定，使得线路更容易过载。
4. **线路故障**：线路的故障，如断路、水泥烧毁、树杈接地等，也会影响电压下降为线路过载。
5. **环境气候**：极端高温、干旱、高温、风力等因素都可能导致电网质量下降，从而导致线路过载。
6. **负荷构成**：负荷偏向于电力大型客户或设备 خر男性，尤其是尖峰负荷，导致线路过载。
7. **电力互相**：恶劣的气象，局域外设站设定风力改进等 非常严重影响区域电网的稳定，按换在线路过载。

这只是常见的原因。对于鹤大线重过载的具体原因，你可以向国网福建省电力有限公司古田县供电公司咨询他们的专家，他们会为你提供更加准确的答案。 
Extracted:
我找不到关于国网福建省电力有限公司古田县供电公司鹤大线重过载成因的信息，但是可以告诉你关于电网的常见过载原因：

1. **负荷过大**：如果线路负荷超过设计负荷能力，尤其是在夏季或冬季高峰期，电网负荷会显著增加，导致线路更容易过载。
2. **设备老化**：高温、腐蚀或磨损等因素会导致电力设备老化，使其电压着度不稳，电容器湿度不正常，提高了线路的感应值，因此易过载。
3. **变压器失效**：变压器失压、摇摆、调压会导致电压波动，影响电网的稳定，使得线路更容易过载。
4. **线路故障**：线路的故障，如断路、水泥烧毁、树杈接地等，也会影响电压下降为线路过载。
5. **环境气候**：极端高温、干旱、高温、风力等因素都可能导致电网质量下降，从而导致线路过载。
6. **负荷构成**：负荷偏向于电力大型客户或设备 خر男

In [None]:
text = tokenizer.apply_chat_template([
    {"role" : "user", "content" : "Calculate pi."},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text

output

In [None]:
model.save_lora("grpo_saved_lora")

In [None]:
text = tokenizer.apply_chat_template([
    {"role" : "system", "content" : SYSTEM_PROMPT},
    {"role" : "user", "content" : "Calculate pi."},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    text,
    sampling_params = sampling_params,
    lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text

output

In [None]:
# Merge to 16bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_16bit", token = "")

# Merge to 4bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_4bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_4bit", token = "")

# Just LoRA adapters
if False: model.save_pretrained_merged("model", tokenizer, save_method = "lora",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "lora", token = "")

In [None]:
# Save to 8bit Q8_0
if False: model.save_pretrained_gguf("model", tokenizer,)
# Remember to go to https://huggingface.co/settings/tokens for a token!
# And change hf to your username!
if False: model.push_to_hub_gguf("hf/model", tokenizer, token = "")

# Save to 16bit GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "f16")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "f16", token = "")

# Save to q4_k_m GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "q4_k_m")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "q4_k_m", token = "")

# Save to multiple GGUF options - much faster if you want multiple!
if False:
    model.push_to_hub_gguf(
        "hf/model", # Change hf to your username!
        tokenizer,
        quantization_method = ["q4_k_m", "q8_0", "q5_k_m",],
        token = "",
    )