### Installation

In [None]:
import os
import subprocess
import sys

def install_packages():
    """Install the pinned packages this notebook needs on Kaggle.

    Each requirement is installed with its own ``pip install`` call, run
    through the current interpreter (``sys.executable -m pip``). A failing
    install is reported and the loop continues with the remaining packages.

    Returns:
        list[str]: The requirement strings whose installation failed; empty
        when every package installed successfully.
    """
    packages = [
        "unsloth",
        "vllm==0.8.5.post1",
        "bitsandbytes",
        "accelerate",
        "xformers==0.0.29.post3",
        "peft",
        "trl==0.15.2",
        "triton",
        "cut_cross_entropy",
        "unsloth_zoo",
        "sentencepiece",
        "protobuf",
        "datasets>=3.4.1",
        "huggingface_hub",
        "hf_transfer",
        "transformers==4.51.3"
    ]

    failed = []
    for package in packages:
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        except subprocess.CalledProcessError as e:
            print(f"✗ Failed to install {package}: {e}")
            failed.append(package)
        else:
            print(f"✓ Installed {package}")
    return failed

# The flag file makes re-running this cell a no-op once the environment is ready.
# It is only written when *every* install succeeded; otherwise a partially
# installed environment would be marked as complete and never retried.
if not os.path.exists('/kaggle/working/packages_installed.flag'):
    failed_packages = install_packages()
    if failed_packages:
        print(f"✗ Not writing install flag; failed packages: {failed_packages}")
    else:
        with open('/kaggle/working/packages_installed.flag', 'w') as f:
            f.write('installed')

In [None]:
import re, requests

url = "https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt"
response = requests.get(url)
filtered = re.sub(r"(transformers|numpy|xformers)[^\n]*\n", "", response.text)

with open("vllm_requirements.txt", "w") as f:
    f.write(filtered)

!pip install -r vllm_requirements.txt

### Unsloth

Goal: To convert `Qwen3-4B-Base` into a reasoning model via GRPO by using OpenR1's Math dataset.

We first pre fine-tune the model to make GRPO skip trying to match formatting - this speeds GRPO up.

In [None]:
from unsloth import FastLanguageModel
import torch
# Sequence budget for the whole notebook: passed to the loader here and reused
# later to filter the SFT data and to derive the GRPO completion length.
max_seq_length = 4096 
# LoRA rank; reused below as max_lora_rank, r, and (doubled) lora_alpha.
lora_rank = 32

# Load the base model and its tokenizer through Unsloth.
# NOTE(review): load_in_4bit / fast_inference / gpu_memory_utilization are
# forwarded as-is; presumably 4-bit quantized weights, the vLLM-backed
# generation path used by `fast_generate` later, and the GPU memory share
# reserved for it — confirm against the Unsloth documentation.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-4B-Base",
    max_seq_length = max_seq_length,
    load_in_4bit = True, 
    fast_inference = True, 
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.5, 
)

# Wrap the model with LoRA adapters on the attention and MLP projection layers
# listed in target_modules. The fixed random_state matches the seed used by the
# trainers further down (3407).
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, 
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank*2, 
    use_gradient_checkpointing = "unsloth", 
    random_state = 3407,
)

### GRPO chat template
Since we're using a base model, we should set a chat template. You can make your own chat template as well!
1. DeepSeek uses `<think>` and `</think>`, but this is **not** necessary - you can customize it however you like!
2. A `system_prompt` is recommended to at least guide the model's responses.

In [None]:
# Marker strings that delimit the reasoning trace and the final answer.
# They are interpolated into the system prompt, the chat template, the SFT
# targets and the reward-function checks further down.
reasoning_start = "<start_working_out>"
reasoning_end   = "<end_working_out>"  
solution_start  = "<SOLUTION>"
solution_end    = "</SOLUTION>"

# Default system prompt; tells the model which markers to wrap its reasoning
# and its final answer in.
system_prompt = \
f"""You are given a math problem.

1. Carefully analyze the problem.
2. Show all your working out and reasoning steps clearly and concisely.
3. Ensure all reasoning is placed between {reasoning_start} and {reasoning_end}.
4. Finally, provide your final answer between {solution_start} and {solution_end}.

Stick strictly to this format."""
# Bare expression so the notebook displays the rendered prompt.
system_prompt

We create a simple chat template below. Notice `add_generation_prompt` includes prepending `<start_working_out>` to guide the model to start its reasoning process.

In [None]:
# Jinja chat template, built from adjacent string literals:
#   * if the first message is a system message, emit its content followed by
#     eos_token; otherwise emit the default system prompt followed by eos_token;
#   * user messages are emitted verbatim, assistant messages are followed by
#     eos_token; any other role is skipped;
#   * with add_generation_prompt, the reasoning-start marker is appended.
# '{system_prompt}' and '{reasoning_start}' are plain placeholders here; they
# are filled in by the str.replace calls below, not by Jinja.
chat_template = \
    "{% if messages[0]['role'] == 'system' %}"\
        "{{ messages[0]['content'] + eos_token }}"\
        "{% set loop_messages = messages[1:] %}"\
    "{% else %}"\
        "{{ '{system_prompt}' + eos_token }}"\
        "{% set loop_messages = messages %}"\
    "{% endif %}"\
    "{% for message in loop_messages %}"\
        "{% if message['role'] == 'user' %}"\
            "{{ message['content'] }}"\
        "{% elif message['role'] == 'assistant' %}"\
            "{{ message['content'] + eos_token }}"\
        "{% endif %}"\
    "{% endfor %}"\
    "{% if add_generation_prompt %}{{ '{reasoning_start}' }}"\
    "{% endif %}"

# Substitute the placeholders with the actual strings, keeping the surrounding
# single quotes so they remain Jinja string literals.
# NOTE(review): the text is pasted inside a single-quoted literal, so a system
# prompt containing a single quote would presumably break the template — verify
# before changing the prompt.
chat_template = chat_template\
    .replace("'{system_prompt}'",   f"'{system_prompt}'")\
    .replace("'{reasoning_start}'", f"'{reasoning_start}'")
tokenizer.chat_template = chat_template

In [None]:
# Render a short multi-turn conversation as text (tokenize = False) to eyeball
# the template output; no system message is supplied, so the template's default
# system-prompt branch is exercised.
tokenizer.apply_chat_template([
    {"role" : "user", "content" : "What is 1+1?"},
    {"role" : "assistant", "content" : f"{reasoning_start}I think it's 2.{reasoning_end}{solution_start}2{solution_end}"},
    {"role" : "user", "content" : "What is 2+2?"},
], tokenize = False, add_generation_prompt = True)

### Pre fine-tuning for formatting
We now use a subset of NVIDIA's [Open Math Reasoning dataset](https://huggingface.co/datasets/nvidia/OpenMathReasoning) which was filtered to only include high quality DeepSeek R1 traces.

In [None]:
from datasets import load_dataset
import pandas as pd
import numpy as np

# Pull the "cot" split as a DataFrame and keep only the columns used for SFT.
kept_columns = ["expected_answer", "problem", "generated_solution"]
dataset = load_dataset("unsloth/OpenMathReasoning-mini", split = "cot").to_pandas()[kept_columns]

# Keep the rows whose expected answer parses as a number; anything that cannot
# be coerced becomes NaN and is dropped.
answer_as_number = pd.to_numeric(dataset["expected_answer"], errors = "coerce")
is_number = answer_as_number.notnull()
numeric_rows = np.where(is_number)[0]
dataset = dataset.iloc[numeric_rows]

dataset

In [None]:
def format_dataset(x):
    """Turn one dataset row into a system/user/assistant message list.

    Args:
        x: Row-like mapping with "expected_answer", "problem" and
            "generated_solution" string fields.

    Returns:
        list[dict]: Three chat messages. The assistant message is the
        reasoning trace (with the <think> tags removed and surrounding
        whitespace stripped) wrapped in the reasoning markers, followed by
        the expected answer wrapped in the solution markers.
    """
    answer = x["expected_answer"]
    question = x["problem"]

    # Remove the original think tags; the notebook's own markers replace them.
    trace = x["generated_solution"]
    for tag in ("<think>", "</think>"):
        trace = trace.replace(tag, "")
    trace = trace.strip()

    assistant_reply = (
        reasoning_start + trace + reasoning_end
        + solution_start + answer + solution_end
    )

    messages = []
    for role, content in (
        ("system", system_prompt),
        ("user", question),
        ("assistant", assistant_reply),
    ):
        messages.append({"role" : role, "content" : content})
    return messages

# Build the chat-message list for every row (axis = 1 passes each row to format_dataset).
dataset["Messages"] = dataset.apply(format_dataset, axis = 1)

In [None]:
# Preview the first remaining example as rendered text. Use positional access:
# the frame was filtered without resetting its index, so the label 0 is not
# guaranteed to exist and dataset["Messages"][0] could raise KeyError.
tokenizer.apply_chat_template(dataset["Messages"].iloc[0], tokenize = False)

Let's truncate the pre fine-tuning dataset to `max_seq_length*0.3` since we don't want too long reasoning traces.

Note this might take 2 minutes!

In [None]:
# "N" = length of the chat-template output for each example (apply_chat_template
# is called with its default tokenization here, so this counts tokens rather
# than characters — NOTE(review): confirm the default for this tokenizer).
dataset["N"] = dataset["Messages"].apply(lambda x: len(tokenizer.apply_chat_template(x)))

# Keep only examples at or below 30% of max_seq_length; .copy() detaches the
# filtered frame so later column assignments act on an independent object.
dataset = dataset.loc[dataset["N"] <= max_seq_length*0.30].copy()
dataset.shape

We then tokenize the messages and convert it to a Hugging Face compatible dataset format:

In [None]:
from datasets import Dataset

# Render every conversation to a single training string in the "text" column,
# then convert the DataFrame into a Hugging Face Dataset for the SFT trainer.
dataset["text"] = tokenizer.apply_chat_template(dataset["Messages"].values.tolist(), tokenize = False)
dataset = Dataset.from_pandas(dataset)
dataset

Let's now pre fine-tune the model so it follows our custom GRPO formatting!

In [None]:
from trl import SFTTrainer, SFTConfig
# Supervised pre-fine-tuning on the rendered "text" column so the model learns
# the custom reasoning/solution markers before GRPO starts.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    args = SFTConfig(
        dataset_text_field = "text",  # column created from the chat template
        per_device_train_batch_size = 4,
        gradient_accumulation_steps = 4, 
        warmup_steps = 5,
        num_train_epochs = 2, 
        learning_rate = 2e-4,
        logging_steps = 5,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        report_to = "none",  # no experiment tracker for this stage
    ),
)

In [None]:
# Run the supervised pre-fine-tuning configured in the SFTTrainer above.
trainer.train()

In [None]:
# Build a prompt from the first example's system + user messages only
# ([:2] drops the assistant answer) and let the model complete it.
text = tokenizer.apply_chat_template(
    dataset[0]["Messages"][:2],
    tokenize = False,
    add_generation_prompt = True, 
)

# Stream the generation to the cell output; skip_prompt = False echoes the
# prompt as well. The return value is discarded on purpose.
# NOTE(review): temperature = 0 is presumably meant as greedy decoding — confirm
# how the installed transformers version treats it.
from transformers import TextStreamer
_ = model.generate(
    **tokenizer(text, return_tensors = "pt").to("cuda"),
    temperature = 0,
    max_new_tokens = 1024,
    streamer = TextStreamer(tokenizer, skip_prompt = False),
)

Yes it did follow the formatting! Great! Let's remove some items before the GRPO step

In [None]:
# Drop the SFT dataset and reclaim memory before GRPO.
# Order matters: run the garbage collector first so that objects still held
# alive by reference cycles are destroyed, and only then ask PyTorch to return
# its now-unused cached CUDA blocks.
import gc
del dataset
gc.collect()
torch.cuda.empty_cache()

# Let's begin with GRPO training

In [None]:
import wandb
import os
from datasets import load_dataset
import re
import numpy as np
from transformers import AutoTokenizer
from vllm import SamplingParams
from trl import GRPOConfig, GRPOTrainer

In [None]:
from kaggle_secrets import UserSecretsClient
# Read the W&B API key from Kaggle secrets and expose it through the
# environment variable that wandb reads.
user_secrets = UserSecretsClient()
wandb_api_key = user_secrets.get_secret("wandb_api_key")
os.environ["WANDB_API_KEY"] = wandb_api_key

# The config recorded with the run mirrors the values actually used by the
# GRPO trainer (GRPOConfig / SamplingParams) and the model loaded earlier, so
# the W&B run page does not advertise hyperparameters that were never applied.
wandb.init(
    project="grpo-math-reaoning_hard",
    name="hard_labels_final",
    config={
        "reward_type": "hard_labels",
        "model_name": "unsloth/Qwen3-4B-Base",
        "dataset": "yobro4619/math-reasoning-dataset",
        "learning_rate": 5e-6,
        "batch_size": 4,
        "gradient_accumulation_steps": 4,
        "num_generations": 4,
        "num_train_epochs": 2,
        "temperature": 0.7,
        "min_p": 0.1,
    },
    tags=["hard_labels", "grpo", "math_reasoning", "baseline"],
    notes="Baseline experiment using hard reward labels for mathematical reasoning"
)


# Custom x-axes logged by the reward functions.
wandb.define_metric("training/step")
wandb.define_metric("training/epoch")

# Plot every custom metric family against the reward-function step counter.
wandb.define_metric("accuracy/*", step_metric="training/step")
wandb.define_metric("rewards/*", step_metric="training/step")
wandb.define_metric("errors/*", step_metric="training/step")
wandb.define_metric("convergence/*", step_metric="training/step")
wandb.define_metric("train/*", step_metric="training/step")

# Run-summary values: best (max) accuracy/reward and lowest loss.
wandb.define_metric("accuracy/exact_answer_accuracy", summary="max")
wandb.define_metric("accuracy/total_correct_rate", summary="max")
wandb.define_metric("rewards/avg_total_reward", summary="max")
wandb.define_metric("train/loss", summary="min")

### Data Loading
<a name="Data"></a>

I'm using the [AM-DeepSeek-R1-Distilled](https://huggingface.co/datasets/yobro4619/math-reasoning-dataset) dataset hosted on Hugging Face. I have taken a [small subset](https://huggingface.co/datasets/yobro4619/math-reasoning-dataset) of it due to time and compute restrictions.

In [None]:
from datasets import load_dataset
# Load the train split of the GRPO dataset; `dataset` is rebound here (the SFT
# dataset of the same name was deleted earlier).
dataset = load_dataset("yobro4619/math-reasoning-dataset", split = "train")
dataset

In [None]:
def extract_hash_answer(text):
    """Return the reference solution unchanged.

    The name is kept for the call site in the dataset mapping step; no
    "####"-style splitting is performed, so the "solution" field is used
    verbatim as the ground-truth answer.
    """
    return text
# Show the reference answer of the first example.
extract_hash_answer(dataset[0]["solution"])

In [None]:
# Replace each row's raw "prompt" string with a system + user message list and
# add an "answer" column holding the reference solution; the reward functions
# below take that column as their `answer` argument.
dataset = dataset.map(lambda x: {
    "prompt" : [
        {"role": "system", "content": system_prompt},
        {"role": "user",   "content": x["prompt"]},
    ],
    "answer": extract_hash_answer(x["solution"]),
})

REGEX PATTERNS FOR ANSWER EXTRACTION

In [None]:
# Re-declare the two markers used in the patterns below (same values as the
# constants defined with the system prompt).
reasoning_end = r"<end_working_out>"
solution_start = r"<SOLUTION>"
# Closing tag, optional whitespace, then an optional literal EOS token.
solution_end_regex = r"</SOLUTION>[\s]{0,}" + \
    "(?:" + re.escape(tokenizer.eos_token) + ")?"

# Captures (non-greedily) whatever sits between <SOLUTION> and </SOLUTION>.
# Note that this pattern does not look for the reasoning-end marker.
match_format = re.compile(
    rf"{solution_start}(.+?){solution_end_regex}"
    rf"[\s]{{0,}}",
    flags=re.MULTILINE | re.DOTALL
)

# Captures the first \frac{...}{...} that appears after <SOLUTION>.
# Not referenced by the reward functions visible in this notebook.
match_fractions = re.compile(
    solution_start + r".*?(\\frac\{[^}]+\}\{[^}]+\})",
    flags=re.MULTILINE | re.DOTALL
)

# Captures the first run of digits / dots / commas (with an optional leading
# minus) after <SOLUTION>, provided no \frac follows the tag and the run is not
# immediately followed by "}". Used by check_numbers.
match_numbers = re.compile(
    solution_start + r"(?!.*\\frac).*?[\s]{0,}([-]?[\d\.\,]{1,})(?!\})",
    flags=re.MULTILINE | re.DOTALL
)

# Captures a single capital letter followed by whitespace or the closing tag.
# Not referenced by the reward functions visible in this notebook.
match_choice = re.compile(
    solution_start + r".*?[\s]{0,}([A-Z])(?:\s|</SOLUTION>)",
    flags=re.MULTILINE | re.DOTALL
)

In [None]:
# Quick check of match_format on a hand-written completion; findall returns
# the captured text between the solution tags (including the newlines).
match_format.findall(
    "Let me think!<end_working_out>"\
    f"<SOLUTION>\n2\n</SOLUTION>",
)

In [None]:
def extract_answer(response_text):
    """Pull the final answer out of a model response.

    Only the text inside the first <SOLUTION>...</SOLUTION> pair (as matched
    by ``match_format``) is considered. Within it, the first hit of the
    following checks wins:

    1. a LaTeX fraction ``\\frac{..}{..}``,
    2. a standalone number (digits, dots, commas, optional leading minus),
    3. a standalone capital letter (multiple-choice label),
    4. otherwise the whole stripped content.

    Returns:
        str: The extracted answer, or "" when no solution tags are found.
    """
    tagged = match_format.findall(response_text)
    if not tagged:
        return ""

    # First solution block, with any EOS token text removed.
    inner = tagged[0].strip()
    inner = re.sub(re.escape(tokenizer.eos_token), "", inner).strip()

    # 1. Fractions are checked first so their digits are not taken as a number.
    fractions = re.findall(r'(\\frac\{[^}]+\}\{[^}]+\})', inner)
    if fractions:
        return fractions[0].strip()

    # 2. Standalone numbers: each match is a tuple with one group per
    #    alternative, so take the first non-empty group overall.
    number_groups = re.findall(
        r'^\s*([-]?[\d\.\,]+)\s*$|(?:^|\s)([-]?[\d\.\,]+)(?:\s|$)', inner
    )
    first_number = next(
        (group for groups in number_groups for group in groups if group),
        None,
    )
    if first_number is not None:
        return first_number.strip()

    # 3. Single capital letter surrounded by whitespace or string boundaries.
    choices = re.findall(r'(?:^|\s)([A-Z])(?:\s|$)', inner)
    if choices:
        return choices[0].strip()

    # 4. Nothing recognisable: hand back the full content.
    return inner

In [None]:
class ConvergenceMonitor:
    """Patience-based plateau detector for accuracy, with W&B logging.

    An update counts as an improvement when the accuracy exceeds the best
    value seen so far by more than ``min_delta``. After ``patience``
    consecutive non-improving updates the monitor is flagged as converged
    (the flag is never cleared again).
    """

    def __init__(self, patience=10, min_delta=0.001):
        self.patience, self.min_delta = patience, min_delta
        self.best_accuracy, self.best_step = 0.0, 0
        self.patience_counter = 0
        self.accuracy_history, self.loss_history = [], []
        self.converged = False

    def update(self, accuracy, loss, step):
        """Record one observation, log convergence metrics, return the converged flag."""
        self.accuracy_history.append(accuracy)
        self.loss_history.append(loss)

        improved = accuracy > self.best_accuracy + self.min_delta
        if improved:
            self.best_accuracy, self.best_step = accuracy, step
            self.patience_counter = 0
        else:
            self.patience_counter += 1

        if self.patience_counter >= self.patience:
            self.converged = True

        # Compare the mean of the last (up to) 5 accuracies with the mean of the
        # window before it; until a full earlier window exists, the recent mean
        # itself is reported. The history is never empty at this point.
        history = self.accuracy_history
        window = min(5, len(history))
        recent_avg = np.mean(history[-window:])
        if len(history) >= 2 * window:
            older_avg = np.mean(history[-2 * window:-window])
        else:
            older_avg = 0
        improvement_rate = (recent_avg - older_avg) if older_avg > 0 else recent_avg

        wandb.log({
            "convergence/best_accuracy": self.best_accuracy,
            "convergence/best_step": self.best_step,
            "convergence/patience_counter": self.patience_counter,
            "convergence/steps_since_improvement": step - self.best_step,
            "convergence/improvement_rate": improvement_rate,
            "convergence/converged": 1 if self.converged else 0,
            "training/step": step
        })

        return self.converged

# Shared monitor updated from check_answer on every reward call.
convergence_monitor = ConvergenceMonitor(patience=15, min_delta=0.01)

# Module-level counters shared by the reward functions:
PRINTED_TIMES = 0       # number of check_numbers calls so far
PRINT_EVERY_STEPS = 2   # an example is printed/logged every this many calls
current_step = 0        # incremented by match_format_exactly; used as "training/step"

REWARD FUNCTIONS WITH WANDB LOGGING

In [None]:
def match_format_exactly(completions, **kwargs):
    """Reward 3.0 for each completion that contains a full solution block.

    Also advances the global ``current_step`` counter (once per call) and logs
    format statistics to W&B.

    Args:
        completions: One list per completion whose first element is a dict
            with a "content" string.
        **kwargs: Ignored; accepted for the reward-function interface.

    Returns:
        list: 3.0 where ``match_format`` finds a match, otherwise 0.
    """
    global current_step
    current_step += 1

    # One boolean per completion: does the solution pattern occur anywhere?
    hits = [
        match_format.search(completion[0]["content"]) is not None
        for completion in completions
    ]
    scores = [3.0 if hit else 0 for hit in hits]
    correct_formats = sum(hits)

    wandb.log({
        "rewards/format_match_rate": correct_formats / len(completions),
        "rewards/avg_format_score": np.mean(scores),
        "rewards/format_score_std": np.std(scores),
        "rewards/total_format_matches": correct_formats,
        "training/step": current_step
    })

    return scores

def match_format_approximately(completions, **kwargs):
    """Give partial credit for each format marker that appears exactly once.

    For each of the reasoning-end, solution-start and solution-end markers a
    completion earns +0.5 when the marker occurs exactly once and -1.0
    otherwise. Per-marker hit rates and score statistics are logged to W&B.

    Args:
        completions: One list per completion whose first element is a dict
            with a "content" string.
        **kwargs: Ignored; accepted for the reward-function interface.

    Returns:
        list: One score per completion, between -3.0 and 1.5.
    """
    # (W&B metric name, marker) pairs, in the order they are scored and logged.
    marker_checks = (
        ("rewards/reasoning_end_accuracy", reasoning_end),
        ("rewards/solution_start_accuracy", solution_start),
        ("rewards/solution_end_accuracy", solution_end),
    )
    hit_counts = {metric: 0 for metric, _ in marker_checks}

    scores = []
    for completion in completions:
        response = completion[0]["content"]
        score = 0
        for metric, marker in marker_checks:
            if response.count(marker) == 1:
                score += 0.5
                hit_counts[metric] += 1
            else:
                score -= 1.0
        scores.append(score)

    total_completions = len(completions)
    payload = {
        metric: hit_counts[metric] / total_completions
        for metric, _ in marker_checks
    }
    payload["rewards/avg_approximate_format_score"] = np.mean(scores)
    payload["rewards/approximate_format_score_std"] = np.std(scores)
    payload["training/step"] = current_step
    wandb.log(payload)

    return scores

def check_answer(prompts, completions, answer, **kwargs):
    """Score each completion's extracted answer against the reference answer.

    The answer is the text captured by ``match_format`` between the solution
    tags. Scores:

    * -2.0  no solution block found
    * +5.0  exact string match
    * +3.5  match after stripping surrounding whitespace
    * +2.0  numeric ratio guess/true within [0.9, 1.1]
    * +1.5  numeric ratio within [0.8, 1.2]
    * -2.5  numeric but outside those ranges
    * -4.5  guess or reference is not numeric, or the reference is zero

    Accuracy/error rates are logged to W&B and the shared convergence monitor
    is updated with the exact-match accuracy.

    Args:
        prompts: Prompts for the batch (not used for scoring).
        completions: One list per completion whose first element is a dict
            with a "content" string.
        answer: Reference answers, aligned with ``completions``.
        **kwargs: Ignored; accepted for the reward-function interface.

    Returns:
        list: One score per completion.
    """
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = [
        guess.group(1)
        if (guess := match_format.search(r)) is not None else None
        for r in responses
    ]

    scores = []
    exact_matches = 0
    approximate_matches = 0
    close_matches = 0
    extraction_failures = 0
    wrong_answers = 0

    for guess, true_answer in zip(extracted_responses, answer):
        score = 0
        if guess is None:
            scores.append(-2.0)
            extraction_failures += 1
            continue

        if guess == true_answer:
            score += 5.0
            exact_matches += 1
        elif guess.strip() == true_answer.strip():
            score += 3.5
            approximate_matches += 1
        else:
            try:
                ratio = float(guess) / float(true_answer)
                if ratio >= 0.9 and ratio <= 1.1:
                    score += 2.0
                    close_matches += 1
                elif ratio >= 0.8 and ratio <= 1.2:
                    score += 1.5
                    close_matches += 1
                else:
                    score -= 2.5
                    wrong_answers += 1
            # Only the expected failures are handled (non-numeric text, a zero
            # reference); a bare except would also swallow KeyboardInterrupt
            # and make the training loop impossible to stop cleanly.
            except (ValueError, TypeError, ZeroDivisionError):
                score -= 4.5
                wrong_answers += 1
        scores.append(score)

    total_completions = len(completions)
    exact_accuracy = exact_matches / total_completions
    total_correct_rate = (exact_matches + approximate_matches + close_matches) / total_completions

    wandb.log({
        "accuracy/exact_answer_accuracy": exact_accuracy,
        "accuracy/approximate_answer_accuracy": approximate_matches / total_completions,
        "accuracy/close_answer_accuracy": close_matches / total_completions,
        "accuracy/total_correct_rate": total_correct_rate,
        "errors/extraction_failure_rate": extraction_failures / total_completions,
        "errors/wrong_answer_rate": wrong_answers / total_completions,
        "rewards/avg_answer_score": np.mean(scores),
        "rewards/answer_score_std": np.std(scores),
        "rewards/max_answer_score": np.max(scores),
        "rewards/min_answer_score": np.min(scores),
        "training/step": current_step
    })

    # The negated mean reward stands in for a loss value in the monitor's history.
    avg_loss = -np.mean(scores)
    is_converged = convergence_monitor.update(exact_accuracy, avg_loss, current_step)

    # Log the convergence point only once; a function attribute is the marker.
    if is_converged and not hasattr(check_answer, 'convergence_logged'):
        wandb.log({
            "convergence/convergence_step": current_step,
            "convergence/final_accuracy": exact_accuracy,
            "training/step": current_step
        })
        check_answer.convergence_logged = True
        print(f"🎯 Training converged at step {current_step} with accuracy {exact_accuracy:.4f}")

    return scores

In [None]:
def check_numbers(prompts, completions, answer, **kwargs):
    """Score completions by exact numeric equality with the reference answer.

    The number is taken from the first ``match_numbers`` hit in each
    completion. Scores:

    * -2.5  no number could be extracted
    * +3.5  extracted number equals the reference (commas in the guess are ignored)
    * -1.5  both are numeric but differ
    *  0    guess or reference cannot be converted to float

    Every ``PRINT_EVERY_STEPS`` calls the first example of the batch is
    printed and logged to W&B as a table.

    Args:
        prompts: Prompts for the batch; the last message of the first prompt
            is shown as the example question.
        completions: One list per completion whose first element is a dict
            with a "content" string.
        answer: Reference answers, aligned with ``completions``.
        **kwargs: Ignored; accepted for the reward-function interface.

    Returns:
        list: One score per completion.
    """
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = [
        guess.group(1)
        if (guess := match_numbers.search(r)) is not None else None
        for r in responses
    ]

    scores = []
    numerical_matches = 0
    numerical_extraction_failures = 0

    global PRINTED_TIMES
    global PRINT_EVERY_STEPS

    if PRINTED_TIMES % PRINT_EVERY_STEPS == 0:
        columns = ["Question", "True Answer", "Model Response", "Extracted Answer"]
        data = [[
            question[:100] + "..." if len(question) > 100 else question,
            answer[0],
            responses[0][:150] + "..." if len(responses[0]) > 150 else responses[0],
            str(extracted_responses[0]) if extracted_responses[0] else "None"
        ]]

        example_table = wandb.Table(columns=columns, data=data)
        wandb.log({
            "examples/model_outputs": example_table,
            "examples/step": PRINTED_TIMES,
            "training/step": current_step
        })

        print('*'*50)
        print(f"Step {current_step} - Example {PRINTED_TIMES}")
        print(f"Question: {question[:100]}...")
        print(f"True Answer: {answer[0]}")
        print(f"Model Response: {responses[0][:100]}...")
        print(f"Extracted: {extracted_responses[0]}")
        print('*'*50)

    PRINTED_TIMES += 1

    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(-2.5)
            numerical_extraction_failures += 1
            continue

        try:
            true_answer_float = float(true_answer.strip())
            guess_float = float(guess.strip().replace(",", ""))
        # Neutral score when either side is not a plain number (or the
        # reference is not a string). Catching only these errors keeps
        # KeyboardInterrupt and genuine bugs visible, unlike a bare except.
        except (ValueError, TypeError, AttributeError):
            scores.append(0)
            continue

        if guess_float == true_answer_float:
            scores.append(3.5)
            numerical_matches += 1
        else:
            scores.append(-1.5)

    total_completions = len(completions)
    wandb.log({
        "accuracy/numerical_exact_accuracy": numerical_matches / total_completions,
        "errors/numerical_extraction_failure_rate": numerical_extraction_failures / total_completions,
        "rewards/avg_numerical_score": np.mean(scores),
        "rewards/numerical_score_std": np.std(scores),
        "training/step": current_step
    })

    return scores

<a name="Train"></a>
### Train the model

Now set up GRPO Trainer and all configurations!

In [None]:
# Split the sequence budget: 512 tokens for the prompt, the remainder of
# max_seq_length for the generated completion.
max_prompt_length = 512
max_completion_length = max_seq_length - max_prompt_length

In [None]:
# Sampling settings handed to the trainer for generation. Generation stops at
# the EOS token and keeps it in the output (include_stop_str_in_output=True);
# the solution regex defined earlier allows for that optional trailing EOS.
vllm_sampling_params = SamplingParams(
    min_p=0.1,
    top_p=1.0,
    top_k=-1,
    seed=3407,
    stop=[tokenizer.eos_token],
    include_stop_str_in_output=True,
)

# Training configuration with WandB integration
training_args = GRPOConfig(
    vllm_sampling_params=vllm_sampling_params,
    temperature=0.7,
    learning_rate=5e-6,
    weight_decay=0.01,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="adamw_8bit",
    logging_steps=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_generations=4,  # completions sampled per prompt
    max_prompt_length=max_prompt_length,
    max_completion_length=max_completion_length,
    num_train_epochs=2,
    save_steps=6,
    report_to="wandb",  
    output_dir="outputs_hard_labels",
    run_name="hard_labels_baseline",
)

# The reward functions are listed with match_format_exactly first: it is the
# one that increments the shared `current_step` counter the others log against.
# NOTE(review): this relies on the trainer calling reward_funcs in list order —
# confirm for the pinned TRL version.
trainer = GRPOTrainer(
    model=model,  
    processing_class=tokenizer,  
    reward_funcs=[
        match_format_exactly,
        match_format_approximately,
        check_answer,  
        check_numbers,
    ],
    args=training_args,
    train_dataset=dataset,
)

In [None]:
# Run GRPO training with the trainer configured above.
trainer.train()

In [None]:
# Save the trained LoRA adapter under the Kaggle working directory; the
# inference cell below loads it again from this same path.
model.save_lora("/kaggle/working/grpo_saved_lora")

<a name="Inference"></a>
### Inference

Now let's try the model we just trained! First, let's try the model without the GRPO-trained LoRA adapter:

In [None]:
# Baseline generation: the raw question is passed without the chat template
# and with lora_request = None, i.e. no saved adapter is requested.
text = "What is the value of 2+2?"

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.7,
    top_k = 50,
    max_tokens = 1024,
)
# Take the text of the first candidate for the single prompt.
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text

output

In [None]:
from safetensors import safe_open

# Sanity-check the saved adapter: every LoRA tensor must contain at least one
# non-zero value. The file is opened via the same absolute directory that
# save_lora wrote to, so the check does not depend on the working directory.
with safe_open("/kaggle/working/grpo_saved_lora/adapter_model.safetensors", framework = "pt") as f:
    # Verify both A and B are non zero
    for key in f.keys():
        tensor = f.get_tensor(key)
        # Fraction of zero entries in [0, 1]; exactly 1.0 means the tensor is
        # all zeros. (Comparing this fraction with numel() would never fail
        # for tensors with more than one element.)
        zero_fraction = (tensor == 0).sum().item() / tensor.numel()
        assert zero_fraction < 1.0, f"LoRA tensor {key} is entirely zero"

print("LoRA adapters verified successfully!")

Now we load the LoRA and test:

In [None]:
# Generation with the trained adapter: this time the prompt goes through the
# chat template (system prompt + user question + generation prompt).
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user",   "content": "Find $\\csc 330^\\circ."},
]

text = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True, 
    tokenize = False,
)
from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 1.0,
    top_k = 50,
    max_tokens = 2048,
)
# Load the saved adapter and pass it as the LoRA request for this generation;
# keep the text of the first candidate.
output = model.fast_generate(
    text,
    sampling_params = sampling_params,
    lora_request = model.load_lora("/kaggle/working/grpo_saved_lora"),
)[0].outputs[0].text

output

In [None]:
# Target repository on the Hugging Face Hub; the token comes from Kaggle
# secrets via the `user_secrets` client created in the W&B setup cell.
HF_MODEL_NAME = "yobro4619/hard_labels_final"
HF_TOKEN = user_secrets.get_secret("hf_token")

print("Saving and pushing LoRA adapters...")
# Save locally, then upload, both with save_method="lora".
# NOTE(review): presumably this stores only the adapter weights rather than a
# merged model — confirm against the Unsloth saving documentation.
model.save_pretrained_merged("local_model_lora", tokenizer, save_method="lora")
model.push_to_hub_merged(HF_MODEL_NAME, tokenizer, save_method="lora", token=HF_TOKEN)