In [None]:
import os

# Use vLLM's v0 engine. Set this before importing unsloth so that any vLLM
# import triggered by it (fast_inference=True below) already sees the setting.
os.environ["VLLM_USE_V1"] = "0"

from unsloth import FastLanguageModel
import torch

max_seq_length = 6144  # Can increase for longer reasoning traces
lora_rank = 8  # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "<MODEL>",
    fix_tokenizer = False,  # let it use the HF tokenizer directly
    max_seq_length = max_seq_length,
    load_in_4bit = False,  # False for LoRA 16bit
    fast_inference = True,  # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.6,  # Reduce if out of memory
)


In [None]:

# Wrap the base model with LoRA adapters on the attention (q/k/v/o) and
# MLP (gate/up/down) projection layers.
lora_target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",  # attention; remove these if out of memory
    "gate_proj", "up_proj", "down_proj",     # MLP
]

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,  # any positive int; 8, 16, 32, 64, 128 are typical
    target_modules = lora_target_modules,
    lora_alpha = lora_rank,  # alpha set equal to the rank
    use_gradient_checkpointing = "unsloth",  # Unsloth checkpointing, for long-context finetuning
    random_state = 3407,
)

<a name="Data"></a>
### Data Prep


In [None]:
import os
from getpass import getpass

from huggingface_hub import login
import wandb

# Read credentials from the environment, or prompt for them, instead of
# hard-coding them in the notebook where they end up in version control and
# shared copies.
wandb_key = os.environ.get("WANDB_API_KEY") or getpass("Weights & Biases API key: ")
hf_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")

wandb.login(key=wandb_key)
login(token=hf_token)

# Don't leave the secrets lying around in the kernel namespace.
del wandb_key, hf_token

In [None]:
import pandas as pd
from datasets import Dataset

df = pd.read_csv("data/qwq_sft_data.csv")

# Drop rows with a missing prompt or response. read_csv turns empty cells into
# NaN, which would become missing (None) message contents downstream. The chat
# template cannot concatenate a None user turn, and it routes an assistant
# message whose content is none to its tool-call branch instead of rendering a
# reply. reset_index keeps a plain RangeIndex, so Dataset.from_pandas does not
# add an extra index column.
n_rows_raw = len(df)
df = df.dropna(subset=["prompt", "response"]).reset_index(drop=True)
n_dropped = n_rows_raw - len(df)
if n_dropped:
    print(f"Dropped {n_dropped} of {n_rows_raw} rows with a missing prompt or response.")

def to_messages(row):
    """Turn one prompt/response row into a two-turn chat record.

    Returns a dict with a single "messages" key: a user turn holding
    ``row["prompt"]`` followed by an assistant turn holding ``row["response"]``.
    This is the conversation shape later passed to ``apply_chat_template``.
    """
    user_turn = {"role": "user", "content": row["prompt"]}
    assistant_turn = {"role": "assistant", "content": row["response"]}
    return {"messages": [user_turn, assistant_turn]}

# Convert the DataFrame to an HF Dataset, replacing the raw prompt/response
# columns with the chat-formatted "messages" column built by to_messages.
dataset = Dataset.from_pandas(df).map(
    to_messages,
    remove_columns=["prompt", "response"],
)


In [None]:
# Sanity check: every conversation should be exactly one user turn followed by
# one assistant turn (this also fails loudly on an empty dataset).
role_patterns = {tuple(m["role"] for m in conv) for conv in dataset["messages"]}
assert role_patterns == {("user", "assistant")}, role_patterns

dataset[0]

In [None]:
# Replace the tokenizer's chat template with a custom Jinja template. The
# marker strings (<｜User｜>, <｜Assistant｜>, <｜end▁of▁sentence｜>, tool
# markers) look like the DeepSeek-R1 format. They presumably match <MODEL>'s
# vocabulary; TODO confirm they exist as special tokens for this tokenizer.
#
# What the template below renders:
#   * bos_token first, then the content of the last system message (if any);
#   * user turn      -> '<｜User｜>' + content
#   * assistant turn -> '<｜Assistant｜>' + content + '<｜end▁of▁sentence｜>'
#     The content is emitted verbatim, so any reasoning text in the response
#     is kept in the training string.
#   * an assistant message whose content is none is rendered from its
#     'tool_calls'; 'tool' messages are wrapped in tool-output markers.
# The '<｜User｜>' / '<｜Assistant｜>' strings are also passed as
# instruction_part / response_part to train_on_responses_only later, so the
# two must stay in sync.
#
# NOTE(review): with add_generation_prompt=True this appends
# '<｜Assistant｜></think>\n', which is a *closing* think tag. Training calls use
# add_generation_prompt=False, so this only affects inference. Confirm that
# '</think>' rather than '<think>' is intended.
# NOTE(review): the template emits bos_token itself. If the rendered text is
# later tokenized with special tokens added, BOS may appear twice; check the
# decoded training sample further down.
tokenizer.chat_template = \
"{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{% for message in messages %}{% if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{% endif %}{% endfor %}{{bos_token}}{{ns.system_prompt}}{% for message in messages %}{% if message['role'] == 'user' %}{% set ns.is_tool = false %}{{'<｜User｜>' + message['content']}}{% endif %}{% if message['role'] == 'assistant' and message['content'] is none %}{% set ns.is_tool = false %}{% for tool in message['tool_calls']%}{% if not ns.is_first %}{{'<｜Assistant｜><｜tool▁calls▁begin｜><｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<｜tool▁call▁end｜>'}}{% set ns.is_first = true %}{% else %}{{'\\n' + '<｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<｜tool▁call▁end｜>'}}{{'<｜tool▁calls▁end｜><｜end▁of▁sentence｜>'}}{% endif %}{% endfor %}{% endif %}{% if message['role'] == 'assistant' and message['content'] is not none %}{% if ns.is_tool %}{{'<｜tool▁outputs▁end｜>' + message['content'] + '<｜end▁of▁sentence｜>'}}{% set ns.is_tool = false %}{% else %}{% set content = message['content'] %}{{'<｜Assistant｜>' + content + '<｜end▁of▁sentence｜>'}}{% endif %}{% endif %}{% if message['role'] == 'tool' %}{% set ns.is_tool = true %}{% if ns.is_output_first %}{{'<｜tool▁outputs▁begin｜><｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}{% set ns.is_output_first = false %}{% else %}{{'\\n<｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}{% endif %}{% endif %}{% endfor %}{% if ns.is_tool %}{{'<｜tool▁outputs▁end｜>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<｜Assistant｜></think>\\n'}}{% endif %}"

In [None]:
def formatting_prompts_func(examples):
    """Render each conversation in ``examples["messages"]`` to a training string.

    Every conversation is passed through the tokenizer's chat template
    untokenized and without a trailing generation prompt. Returns the rendered
    strings as a list, in input order.
    """
    rendered = []
    for conversation in examples["messages"]:
        rendered.append(
            tokenizer.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=False,
            )
        )
    return rendered

# print(formatting_prompts_func(dataset))
# dataset = dataset.map(formatting_prompts_func, batched=True)


In [None]:
output_dir = "<OUTPUT_DIR>"

from trl import SFTTrainer, SFTConfig

# `max_seq_length` is intentionally not redefined here. It comes from the
# model-loading cell, so the trainer's sequence length always matches the
# length the model was loaded with.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    eval_dataset = None, # Can set up evaluation!
    max_seq_length = max_seq_length,
    args = SFTConfig(
        # NOTE(review): the dataset built above has no "text" column (text is
        # produced by formatting_func), so this is presumably unused. Confirm
        # for the installed TRL/Unsloth version before removing it.
        dataset_text_field = "text",
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4, # Use GA to mimic batch size (2 x 4 = 8 per device)
        warmup_steps = 5,
        num_train_epochs = 25, # 25 full passes over the data; use max_steps for a quick smoke test
        save_strategy = "epoch", # one checkpoint per epoch under output_dir
        learning_rate = 2e-4, # Reduce to 2e-5 for long training runs
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        report_to = "wandb",
        dataset_num_proc = 2,
        output_dir = output_dir,
    ),
    formatting_func = formatting_prompts_func, # renders "messages" via the chat template
)

In [None]:
from unsloth.chat_templates import train_on_responses_only

# Restrict the loss to assistant responses. These markers must be identical to
# the strings the custom chat template emits for user and assistant turns.
USER_MARKER = "<｜User｜>"
ASSISTANT_MARKER = "<｜Assistant｜>"

trainer = train_on_responses_only(
    trainer,
    instruction_part = USER_MARKER,
    response_part = ASSISTANT_MARKER,
)

In [None]:
# Decode one tokenized training example. Clamp the index so the cell also works
# on datasets with fewer than 101 rows.
sample_idx = min(100, len(trainer.train_dataset) - 1)
tokenizer.decode(trainer.train_dataset[sample_idx]["input_ids"])

In [None]:
# Show which tokens are trained on. Masked label positions (-100) are mapped
# to the pad token so they can be decoded, then blanked out. The index is
# clamped so small datasets (< 101 rows) don't raise IndexError.
sample_idx = min(100, len(trainer.train_dataset) - 1)
sample_labels = trainer.train_dataset[sample_idx]["labels"]
tokenizer.decode(
    [tokenizer.pad_token_id if x == -100 else x for x in sample_labels]
).replace(tokenizer.pad_token, " ")

In [None]:
# @title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

In [None]:
from contextlib import redirect_stdout, redirect_stderr
import os

# Checkpoint directory to resume from, or None to start a fresh run
# (None is also trainer.train's default).
resume_from_checkpoint = None

# Create the log destination if it does not exist yet; opening a file inside a
# missing directory would fail before training starts.
os.makedirs(output_dir, exist_ok=True)

# Append when resuming so the earlier run's logs are not overwritten.
log_mode = "a" if resume_from_checkpoint else "w"
log_path = os.path.join(output_dir, "training.log")
err_path = os.path.join(output_dir, "training.err")

# Send training stdout/stderr to files instead of the notebook output.
with open(log_path, log_mode) as out, open(err_path, log_mode) as err:
    with redirect_stdout(out), redirect_stderr(err):
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)

In [None]:
import os

# Local directory for the trained LoRA adapter.
lora_save_dir = os.path.join(output_dir, "lora_adapter")
# Hugging Face Hub repo id ("user/name"); leave empty to skip the upload.
# An empty string is not a valid repo id, so it is never passed to push_to_hub.
hub_repo_id = ""

# Save locally first so the adapter is kept even if the upload fails.
model.save_lora(lora_save_dir)

if hub_repo_id:
    model.push_to_hub(hub_repo_id)
    tokenizer.push_to_hub(hub_repo_id)
else:
    print(f"hub_repo_id is empty; skipped Hub upload (adapter saved to {lora_save_dir}).")