In [None]:
# !pip install -q pip3-autoremove
# !pip install -q torch torchvision torchaudio xformers --index-url https://download.pytorch.org/whl/cu124
# !pip install -q unsloth vllm

In [None]:
import os
import wandb
# Never hardcode the W&B API key in the notebook (it leaks via git and shared
# outputs). Reuse one already present in the environment; otherwise prompt for it
# without echoing it into the cell output.
if not os.environ.get("WANDB_API_KEY"):
    from getpass import getpass
    os.environ["WANDB_API_KEY"] = getpass("Weights & Biases API key: ")
os.environ["WANDB_PROJECT"] = "Coursework" 
# Per the transformers W&B integration docs, "checkpoint" uploads each saved
# checkpoint as a W&B artifact.
os.environ["WANDB_LOG_MODEL"] = "checkpoint"

In [None]:
from huggingface_hub import login
# Authenticate to the Hugging Face Hub (the GRPO config below pushes checkpoints
# with push_to_hub=True). Uses $HF_TOKEN when it is set so the notebook can run
# unattended; with no token this is the same interactive prompt as a bare login().
login(token=os.environ.get("HF_TOKEN"))

In [None]:
# Pin the process to GPU 0. This only takes effect if it runs before CUDA is
# initialised, which is why it precedes the torch/unsloth imports in the next cell
# (restart the kernel if CUDA was already initialised).
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
import unsloth
from unsloth import FastLanguageModel
import torch
from datasets import load_dataset
from trl import DPOTrainer, DPOConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    DataCollatorForSeq2Seq
)
from peft import LoraConfig
from transformers import BitsAndBytesConfig

# 4-bit NF4 quantization with nested ("double") quantization enabled.
# NOTE(review): no bnb_4bit_compute_dtype is set here; confirm which compute dtype
# Unsloth actually uses, and whether it honours this config or builds its own.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, 
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# Load the base instruct model and its tokenizer through Unsloth.
# fast_inference=True is presumably what allows use_vllm=True in the GRPO config
# later in the notebook (vllm is in the install cell) — confirm with the Unsloth docs.
# NOTE(review): max_seq_length is not passed, so Unsloth's default applies; confirm
# it is at least the 1024 tokens used for GRPO prompts + completions.
model, tokenizer = FastLanguageModel.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",
    quantization_config=bnb_config,
    fast_inference=True
)

# Wrap the model with LoRA adapters (rank 16, alpha 16, no dropout, biases frozen)
# on the q/k/v/o attention and gate/up/down MLP projection layers.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    # Trade recomputation for activation memory during backprop.
    use_gradient_checkpointing=True,
    # Seed handed to Unsloth's PEFT setup (presumably for adapter initialisation).
    random_state=42,
)

In [None]:
import re

# Tags the model is asked to emit; every reward function below parses completions
# using these exact strings.
reasoning_start, reasoning_end = "<reasoning>", "</reasoning>"
solution_start, solution_end = "<solution>", "</solution>"

# System message prepended to every GSM8K question (see get_gsm8k_questions).
SYSTEM_PROMPT = (
    "You are given a problem.\n"
    "Think about the problem and provide your working out.\n"
    f"Place it between {reasoning_start} and {reasoning_end}.\n"
    f"Then, provide your solution between {solution_start} and {solution_end}"
)

def extract_solution(text):
    """Return the stripped text between the last solution start tag and the next end tag.

    Args:
        text: a model completion.

    Returns:
        str: the extracted solution, or "" when either tag is missing anywhere
        in ``text``. If no end tag follows the last start tag, everything after
        that start tag is returned (stripped).
    """
    if solution_start not in text or solution_end not in text:
        return ""
    # rpartition keeps only what follows the *last* start tag.
    after_start = text.rpartition(solution_start)[2]
    inside, _, _ = after_start.partition(solution_end)
    return inside.strip()

def extract_hash_answer(text):
    """Return the GSM8K final answer: the stripped text after the first '####'.

    Args:
        text: a GSM8K ``answer`` field.

    Returns:
        str | None: the text between the first and second '####' (or to the end),
        stripped; None when the marker is absent.
    """
    pieces = text.split("####")
    return pieces[1].strip() if len(pieces) > 1 else None

def correctness_reward_func(prompts, completions, answer, **kwargs):
    """Give 2.0 to completions whose extracted solution exactly equals the reference.

    Matching is plain string equality on the output of ``extract_solution`` (which
    strips whitespace), so e.g. "$72" or "72.0" does not match "72". As a side
    effect, the first prompt/answer/completion of the batch is printed.

    Args:
        prompts: chat prompts; the last message of ``prompts[0]`` is printed.
        completions: one entry per completion; ``entry[0]["content"]`` is the text.
        answer: reference answers, paired with ``completions`` by position.
        **kwargs: other dataset columns forwarded by the trainer (unused).

    Returns:
        list[float]: 2.0 for an exact match, else 0.0.
    """
    texts = [entry[0]['content'] for entry in completions]
    question = prompts[0][-1]['content']
    predicted = [extract_solution(text) for text in texts]
    print('-'*20, f"Question:\n{question}", f"\nAnswer:\n{answer[0]}",
          f"\nResponse:\n{texts[0]}", f"\nExtracted:\n{predicted[0]}")
    rewards = []
    for guess, gold in zip(predicted, answer):
        rewards.append(2.0 if guess == gold else 0.0)
    return rewards

def int_reward_func(completions, **kwargs):
    """Give 0.5 when the extracted solution consists only of digits.

    Uses ``str.isdigit``, so empty strings, negatives ("-3") and decimals ("4.5")
    score 0.0.

    Args:
        completions: one entry per completion; ``entry[0]['content']`` is the text.
        **kwargs: other dataset columns forwarded by the trainer (unused).

    Returns:
        list[float]: 0.5 or 0.0 per completion.
    """
    rewards = []
    for entry in completions:
        candidate = extract_solution(entry[0]['content'])
        rewards.append(0.5 if candidate.isdigit() else 0.0)
    return rewards

def strict_format_reward_func(completions, **kwargs):
    """Give 0.5 only when the whole completion is a reasoning block then a solution block.

    The pattern is anchored with ``^``/``$``, so any text before the reasoning
    tag or after the closing solution tag fails. Only whitespace may separate
    the two blocks, although ``$`` still tolerates a single trailing newline.

    Args:
        completions: one entry per completion; ``entry[0]["content"]`` is the text.
        **kwargs: other dataset columns forwarded by the trainer (unused).

    Returns:
        list[float]: 0.5 or 0.0 per completion.
    """
    pattern = rf"^{re.escape(reasoning_start)}[\s\S]*?{re.escape(reasoning_end)}\s*{re.escape(solution_start)}[\s\S]*?{re.escape(solution_end)}$"

    def score(text):
        return 0.5 if re.search(pattern, text, re.DOTALL) else 0.0

    return [score(entry[0]["content"]) for entry in completions]

def soft_format_reward_func(completions, **kwargs):
    """Give 0.5 when a reasoning block followed by a solution block appears anywhere.

    This is the lenient counterpart of ``strict_format_reward_func``. It searches
    for the same reasoning-then-solution tag sequence, but without ``^``/``$``
    anchors, so leading or trailing text is tolerated.

    Args:
        completions: one entry per completion; ``entry[0]["content"]`` is the text.
        **kwargs: other dataset columns forwarded by the trainer (unused).

    Returns:
        list[float]: 0.5 for completions that contain the pattern, else 0.0.
    """
    # The signature previously ended in ")]:", a SyntaxError that prevented the
    # whole reward-function cell from running.
    pattern = rf"{re.escape(reasoning_start)}[\s\S]*?{re.escape(reasoning_end)}\s*{re.escape(solution_start)}[\s\S]*?{re.escape(solution_end)}"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.search(pattern, r, re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def reasoning_length_reward_func(completions, **kwargs):
    """Reward longer reasoning sections, saturating at 100 characters.

    The reasoning text is the stripped span between the first reasoning start
    tag and the following end tag. It is "" if either tag is missing. With
    ``n`` = its length in characters:
      * n >= 100  -> 0.125
      * n <  100  -> n * 0.5 / 100 - 0.5, which ranges from -0.5 up to just
        below 0.

    Args:
        completions: one entry per completion; ``entry[0]["content"]`` is the text.
        **kwargs: other dataset columns forwarded by the trainer (unused).

    Returns:
        list[float]: one reward per completion.
    """
    min_meaningful_length = 100  # characters of reasoning treated as "enough"

    rewards = []
    for entry in completions:
        resp = entry[0]["content"]
        has_reasoning = reasoning_start in resp and reasoning_end in resp
        reason_text = (
            resp.split(reasoning_start)[1].split(reasoning_end)[0].strip()
            if has_reasoning
            else ""
        )
        n = len(reason_text)
        rewards.append(
            0.125 if n >= min_meaningful_length
            else n * 0.5 / min_meaningful_length - 0.5
        )
    return rewards


def count_xml(text):
    """Score how well ``text`` uses the reasoning/solution tags.

    Each of the four tags earns 0.125. An end tag only counts when its start
    tag is present too. Penalties of 0.001 per character apply in two cases:
      * The text after the last reasoning end tag is non-blank and does not
        start (after stripping) with the solution start tag. The penalty then
        covers that whole tail, including any solution block.
      * The text after the last solution end tag is non-blank.

    Args:
        text: a model completion.

    Returns:
        float: the score; 0.5 when all four tags are present with no penalties.
    """
    def tail_after(tag):
        # Everything after the *last* occurrence of ``tag``.
        return text.rpartition(tag)[2]

    total = 0.0
    if reasoning_start in text:
        total += 0.125
        if reasoning_end in text:
            total += 0.125
            tail = tail_after(reasoning_end)
            stripped = tail.strip()
            if stripped and not stripped.startswith(solution_start):
                total -= len(tail) * 0.001

    if solution_start in text:
        total += 0.125
        if solution_end in text:
            total += 0.125
            tail = tail_after(solution_end)
            if tail.strip():
                total -= len(tail) * 0.001
    return total

def xmlcount_reward_func(completions, **kwargs):
    """Apply ``count_xml`` to every completion's text.

    Args:
        completions: one entry per completion; ``entry[0]["content"]`` is the text.
        **kwargs: other dataset columns forwarded by the trainer (unused).

    Returns:
        list: one ``count_xml`` score per completion.
    """
    return list(map(count_xml, (entry[0]["content"] for entry in completions)))

In [None]:
from datasets import load_dataset, Dataset

def get_gsm8k_questions(split = "train"):
    """Load one GSM8K ("main" config) split and add chat prompts plus final answers.

    Each row gains a ``'prompt'`` list (SYSTEM_PROMPT as the system message, the
    question as the user message). Its ``'answer'`` is replaced by the text after
    '####' (None if the marker is missing).

    Args:
        split: dataset split name, e.g. "train" or "test".

    Returns:
        The mapped split returned by ``Dataset.map``.
    """
    def to_example(row):
        return {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': row['question']},
            ],
            'answer': extract_hash_answer(row['answer']),
        }

    gsm8k = load_dataset('openai/gsm8k', 'main')
    return gsm8k[split].map(to_example)

In [None]:
# Training split; each row has 'prompt' (chat messages) and 'answer' (text after '####', or None).
dataset = get_gsm8k_questions("train")

In [None]:
from trl import GRPOConfig, GRPOTrainer

max_prompt_length = 256
max_seq_length = 1024

training_args = GRPOConfig(
    # Rollouts are generated with vLLM; the model was loaded with fast_inference=True.
    use_vllm=True,
    beta=0.05,
    learning_rate = 5e-6,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    logging_steps = 1,
    # TRL requires the effective batch to be divisible by num_generations (24 / 8 = 3 prompts per step).
    per_device_train_batch_size = 24,
    gradient_accumulation_steps = 1, 
    num_generations = 8,
    max_prompt_length = max_prompt_length,
    # 1024 - 256 = 768 tokens of generated completion.
    max_completion_length = max_seq_length - max_prompt_length,
    num_train_epochs = 1, 
    max_grad_norm = 0.1,
    report_to = "wandb",
    output_dir = "Qwen2.5-0.5B-Instruct-GRPO",
    gradient_checkpointing=True,
    save_strategy='steps',
    save_steps=250,
    push_to_hub=True,
    hub_model_id='theevolutionisnear/Qwen2.5-0.5B-Instruct-GRPO',
    hub_strategy='checkpoint',
    # NOTE(review): hub_token is typed as a string; True appears to rely on
    # huggingface_hub treating it as "use the token cached by login()" — confirm.
    hub_token=True,
    # Mixed precision must agree with the dtype Unsloth loaded the model in. With
    # dtype left unset, Unsloth picks bfloat16 on GPUs that support it and float16
    # otherwise, and its patched trainer raises on an fp16/bf16 mismatch. The old
    # hard-coded fp16=True therefore only worked on GPUs without bf16 support.
    fp16=not unsloth.is_bf16_supported(),
    bf16=unsloth.is_bf16_supported(),
)

In [None]:
# Take the project from the WANDB_PROJECT environment variable (set in the setup
# cell) so wandb.init and the env-driven Trainer integration cannot drift apart;
# "Coursework" is the fallback if the variable is unset.
WANDB_PROJECT = os.environ.get("WANDB_PROJECT", "Coursework")
WANDB_ENTITY = "animavestra888-independent"
# Same string as GRPOConfig.output_dir, so the run name matches the checkpoint folder.
WANDB_NAME = "Qwen2.5-0.5B-Instruct-GRPO"

In [None]:
# Start the W&B run explicitly so it uses the entity and run name above.
wandb.init(
    name=WANDB_NAME,
    entity=WANDB_ENTITY,
    project=WANDB_PROJECT,
    # Per the wandb Settings docs, init_timeout is in seconds.
    settings=wandb.Settings(init_timeout=300),
)

In [None]:
# from huggingface_hub import snapshot_download

# root_dir = snapshot_download(
#     repo_id="theevolutionisnear/Qwen2.5-0.5B-Instruct-GRPO",
#     revision="0ffec180501348783f125b8558475e61c17809c7",)

# ckpt_dir = f"{root_dir}/last-checkpoint"

In [None]:
# _torch_load = torch.load

# def _load_with_full_pickle(*args, **kwargs):
#     kwargs["weights_only"] = False

#     return _torch_load(*args, **kwargs)
    
# torch.load = _load_with_full_pickle 

# wandb.init(project="Coursework",
#            id="6enu8o1w",
#            resume="must")

In [None]:
# Every completion is scored by each of these reward functions (TRL's
# GRPOTrainer combines them into one reward).
grpo_reward_funcs = [
    xmlcount_reward_func,
    soft_format_reward_func,
    strict_format_reward_func,
    int_reward_func,
    correctness_reward_func,
    reasoning_length_reward_func,
]

trainer = GRPOTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=dataset,
    reward_funcs=grpo_reward_funcs,
)
trainer.train()
# To resume, run the commented-out snapshot/checkpoint cells above (they define ckpt_dir) and use:
#trainer.train(resume_from_checkpoint=ckpt_dir)