In [None]:
import unsloth
import torch
from unsloth import FastLanguageModel  
from transformers import AutoTokenizer
from peft import LoraConfig, get_peft_model


# Checkpoint that training continues from. Every later cell works on the
# `model` / `tokenizer` pair loaded here.
ckpt = "checkpoint-41000" # add the path to your own checkpoint 


# Load the model and its tokenizer from `ckpt` through Unsloth, requesting a
# 1024-token maximum sequence length and disabling 4-bit loading.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=ckpt, 
    max_seq_length=1024,
    load_in_4bit=False,
)


In [None]:
# Wrap the loaded model with LoRA adapters. `model` is rebound to the wrapped
# model, so re-running this cell without re-running the loading cell passes an
# already-wrapped model back in.
model = FastLanguageModel.get_peft_model(
    model,
    # NOTE(review): the finetune_* switches are the layer-selection options of
    # Unsloth's FastModel / FastVisionModel API -- confirm that
    # FastLanguageModel.get_peft_model honours them for this checkpoint.
    finetune_vision_layers     = False,
    finetune_language_layers   = True,  
    finetune_attention_modules = True,  
    finetune_mlp_modules       = True, 

    r = 8,           # LoRA rank
    lora_alpha = 8,  # LoRA alpha, set equal to the rank here
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,  # fixed seed passed to Unsloth for the adapter setup
)

In [None]:
import json
from datasets import Dataset

# System message placed first in every prompt built by get_legal_qa_dataset
# (which strips the surrounding newlines before use). It asks for the same
# <reasoning>/<answer> layout that the reward functions search for.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

def get_legal_qa_dataset(json_path: str) -> Dataset:
    """Load QA records from a JSON file and turn them into a chat-style dataset.

    Args:
        json_path: Path to a UTF-8 JSON file. The loaded value is iterated and
            each item is indexed with the keys ``"question"``, ``"answer"`` and
            ``"chain_of_thought"``.

    Returns:
        A ``Dataset`` with one row per record and two columns:
        ``prompt`` (a system message holding the stripped ``SYSTEM_PROMPT``
        followed by a user message holding the question) and ``answer`` (the
        reference reasoning and answer wrapped in ``<reasoning>`` /
        ``<answer>`` tags).

    Raises:
        KeyError: If a record lacks one of the three expected keys.
    """

    def to_example(record: dict) -> dict:
        # Read the fields in a fixed order so a malformed record fails on the
        # same key regardless of how the record is laid out.
        question = record["question"]
        answer = record["answer"]
        reasoning = record["chain_of_thought"]

        system_message = {"role": "system", "content": SYSTEM_PROMPT.strip()}
        user_message = {"role": "user", "content": question}
        reference = f"<reasoning>\n{reasoning.strip()}\n</reasoning>\n<answer>\n{answer.strip()}\n</answer>"

        return {"prompt": [system_message, user_message], "answer": reference}

    with open(json_path, "r", encoding="utf-8") as handle:
        records = json.load(handle)

    return Dataset.from_list([to_example(record) for record in records])


# Training set for the GRPO run: each row has a chat-style "prompt" and a
# tagged reference "answer" (see get_legal_qa_dataset).
dataset = get_legal_qa_dataset("taskBased_input.json")


In [None]:
from transformers import TrainingArguments

# Token budget for one training sample. `max_seq_length` should stay equal to
# the value given to FastLanguageModel.from_pretrained when the model was loaded;
# whatever the prompt does not use is left for the generated completion.
max_prompt_length = 512
max_seq_length = 1024

from trl import GRPOConfig

training_args = GRPOConfig(
    report_to="wandb",
    output_dir="./results",

    # Batch layout. TRL requires the effective batch size to be divisible by
    # num_generations (the number of completions sampled per prompt).
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    num_generations=2,
    num_train_epochs=3,

    # Apply the token budget declared above. Previously these two variables
    # were defined but never handed to the config, so the library defaults were
    # used instead of the intended 512-token prompt / 512-token completion split.
    max_prompt_length=max_prompt_length,
    max_completion_length=max_seq_length - max_prompt_length,

    # Optimisation schedule.
    learning_rate=5e-5,
    weight_decay=0.1,
    warmup_ratio=0.06,
    lr_scheduler_type="cosine",
    optim="adamw_bnb_8bit",
    max_grad_norm=1.0,

    # Logging and checkpointing.
    logging_dir="./logs",
    logging_steps=20,
    save_strategy="epoch",

    # Train in bfloat16 rather than float16.
    fp16=False,
    bf16=True,
)





In [None]:
import re

import wandb 
from sentence_transformers import SentenceTransformer, util


# Sentence-embedding model used by simmilarity_reward_func to embed generated
# and reference texts; created once here instead of on every reward call.
model_embedder = SentenceTransformer("intfloat/multilingual-e5-large")

def extract_parts(text: str) -> tuple[str, str]:
    """Pull the reasoning and answer sections out of a tagged response.

    Args:
        text: Text that may contain ``<reasoning>...</reasoning>`` and
            ``<answer>...</answer>`` sections.

    Returns:
        ``(reasoning, answer)``: the whitespace-stripped content of the first
        occurrence of each section, or ``""`` for a section that is absent.
    """
    sections = []
    for pattern in (r"<reasoning>(.*?)</reasoning>", r"<answer>(.*?)</answer>"):
        # DOTALL lets a section span several lines; the non-greedy group stops
        # at the first closing tag.
        found = re.search(pattern, text, re.DOTALL)
        sections.append("" if found is None else found.group(1).strip())
    return tuple(sections)

def format_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward completions that follow the ``<reasoning>``/``<answer>`` layout.

    The first parameter must be spelled ``prompts``: GRPOTrainer invokes reward
    functions with keyword arguments (``prompts=...``, ``completions=...``, plus
    one keyword per extra dataset column), so any other spelling leaves the
    parameter unfilled and the call fails.

    Args:
        prompts: Prompts of the batch (unused here).
        completions: One entry per sample; each entry is a list of chat
            messages whose first element holds the generated text under
            ``"content"``.
        answer: Reference answers of the batch (unused here; accepted so the
            signature matches the other reward function).
        **kwargs: Any further keyword arguments supplied by the trainer.

    Returns:
        One score per completion: ``1.0`` when a reasoning section followed by
        an answer section (optionally separated by whitespace) is found
        anywhere in the text, ``-0.5`` otherwise.
    """
    format_pattern = re.compile(r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>", re.DOTALL)
    return [
        1.0 if format_pattern.search(completion[0]["content"]) else -0.5
        for completion in completions
    ]


def simmilarity_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward completions by embedding similarity to the reference answer.

    The reasoning and answer sections of each completion are compared with the
    matching sections of its reference using cosine similarity of
    ``model_embedder`` embeddings. The reward is
    ``(0.6 * reasoning_similarity + 0.4 * answer_similarity) * 2.0``.

    The first parameter must be spelled ``prompts`` because GRPOTrainer calls
    reward functions with keyword arguments.

    Args:
        prompts: Prompts of the batch (unused here).
        completions: One entry per sample; each entry is a list of chat
            messages whose first element holds the generated text under
            ``"content"``.
        answer: Reference texts, one per completion, in the same tagged layout.
        **kwargs: Any further keyword arguments supplied by the trainer.

    Returns:
        One reward per completion, in the order of ``completions``.
    """
    reasoning_weight, answer_weight, reward_scale = 0.6, 0.4, 2.0

    responses = [completion[0]["content"] for completion in completions]
    # NOTE(review): a completion without a section yields "" from extract_parts
    # and is still embedded and scored against the reference instead of getting
    # a fixed penalty -- confirm this is intended.
    reasoning_r, answer_r = zip(*[extract_parts(r) for r in responses])
    reasoning_a, answer_a = zip(*[extract_parts(a) for a in answer])

    def embed(texts):
        # Both sides are embedded with the same "query: " prefix.
        return model_embedder.encode([f"query: {t}" for t in texts], convert_to_tensor=True)

    # Row i of the similarity matrix compares completion i with every
    # reference; only the diagonal (completion i vs. its own reference) is kept.
    sim_reasoning = util.pytorch_cos_sim(embed(reasoning_r), embed(reasoning_a)).diag()
    sim_answer = util.pytorch_cos_sim(embed(answer_r), embed(answer_a)).diag()

    reward_tensor = (reasoning_weight * sim_reasoning + answer_weight * sim_answer) * reward_scale
    rewards = [float(value) for value in reward_tensor]

    # Log only when a wandb run is active; wandb.log raises otherwise (for
    # example when the function is called before wandb.init or on a process
    # without a run).
    if wandb.run is not None:
        # The sample standard deviation of a single value is NaN; report 0.0.
        reward_std = float(reward_tensor.std()) if reward_tensor.numel() > 1 else 0.0
        wandb.log({
            "reward/reasoning_cosine_mean": float(sim_reasoning.mean()),
            "reward/answer_cosine_mean": float(sim_answer.mean()),
            "reward/final_reward_mean": float(reward_tensor.mean()),
            "reward/final_reward_std": reward_std,
        })

    return rewards




 

In [None]:
from trl import GRPOTrainer, GRPOConfig
import wandb
import re
from sentence_transformers import SentenceTransformer, util

# Start a fresh wandb run for this training session (no resume; re-initialise
# if a run already exists in this kernel).
wandb.init(
    project="taskBasedFT",
    resume=False,
    reinit=True,
)

# Fill in missing special-token ids. The checks use `is None` instead of `or`
# because 0 is a valid token id and must not be treated as "missing".
# eos is resolved before pad so that pad can fall back to the final eos id.
if tokenizer.eos_token_id is None:
    tokenizer.eos_token_id = tokenizer.sep_token_id if tokenizer.sep_token_id is not None else 2
if tokenizer.bos_token_id is None:
    tokenizer.bos_token_id = tokenizer.cls_token_id if tokenizer.cls_token_id is not None else 1
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# GRPO trainer over the LoRA-wrapped model. The two reward functions are passed
# together: one scores the output format, the other the embedding similarity to
# the reference answer.
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        format_reward_func,
        simmilarity_reward_func,
    ],
    args=training_args,
    train_dataset=dataset,
)

trainer.train()