In [None]:
from datasets import load_dataset
from unsloth import FastLanguageModel
from trl import SFTTrainer, SFTConfig

In [None]:
# --- Run configuration ---
# Hub identifiers and the output directory referenced by the cells below.

OUTPUT_DIR = "medQwen3-reasoning-8B"  # handed to SFTConfig(output_dir=...)
SFT_DATASET = "FreedomIntelligence/medical-o1-reasoning-SFT"  # handed to load_dataset
MODEL = "unsloth/Qwen3-8B"  # handed to FastLanguageModel.from_pretrained

In [None]:
# Fetch the SFT corpus; "en" selects the dataset configuration to load.
dataset = load_dataset(
    path=SFT_DATASET,
    name="en",
)

In [None]:
# Select the "train" split of the loaded dataset for fine-tuning.
train_dataset = dataset["train"]

In [None]:
# Echo the dataset object so its repr appears as the cell output.
train_dataset

In [None]:
def preprocess_dataset(example):
    """Convert one raw record into a two-turn chat conversation.

    Reads the "Question", "Complex_CoT" and "Response" fields of ``example``.
    The question becomes the user turn; the assistant turn wraps the chain of
    thought in ``<think>`` tags followed by the response in ``<answer>`` tags.

    Returns:
        A dict with a single "messages" key holding the user and assistant
        message dicts (each with "role" and "content").
    """
    reasoning = example["Complex_CoT"]
    final_answer = example["Response"]

    # The assistant content intentionally starts with a newline, exactly as the
    # training data has been formatted so far.
    assistant_response = (
        f"\n<think>\n{reasoning}\n</think>\n"
        f"\n<answer>\n{final_answer}\n</answer>"
    )

    user_turn = {"role": "user", "content": example["Question"]}
    assistant_turn = {"role": "assistant", "content": assistant_response}
    return {"messages": [user_turn, assistant_turn]}

In [None]:
# Apply preprocess_dataset to every example, adding the "messages" conversation it returns.
train_dataset = train_dataset.map(preprocess_dataset)

In [None]:
# Drop the raw source columns; their content now lives in "messages".
# NOTE(review): presumably fails if re-run after the columns are already gone — run once per kernel.
train_dataset = train_dataset.remove_columns(["Question", "Complex_CoT", "Response"])

In [None]:
# Echo the dataset again to confirm the column changes made above.
train_dataset

In [None]:
## Load the base model and attach LoRA adapters

# Load the checkpoint named by MODEL, together with its tokenizer, through Unsloth.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = MODEL,
    # Sequence-length limit chosen at load time; keep the trainer's max_seq_length in sync with it.
    max_seq_length = 2048,
    load_in_4bit = True, 
    # NOTE(review): fast_inference / gpu_memory_utilization look like inference-engine
    # settings; the visible cells only train — confirm they are needed here.
    fast_inference = True,
    # Same value as the r passed to get_peft_model below.
    max_lora_rank = 16,
    gpu_memory_utilization = 0.7,
)

# Wrap the model with LoRA adapters on the q/k/v/o projection modules listed in target_modules.
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_alpha = 32, # 2 * r; original author's note: "*2 speeds up training"
    use_gradient_checkpointing = "unsloth", 
    random_state = 42,
)

In [None]:
# Jinja chat template applied to the "messages" conversations.
#
# Rendering rules:
#   * user turn      -> header + content + blank line
#   * assistant turn -> header + content + eos_token
#   * any other role is skipped (unchanged from the earlier template)
#   * add_generation_prompt=True appends an open assistant header so the model
#     can be prompted at inference time; the earlier template ignored the flag.
#
# Every tag uses explicit "-" whitespace control, so the output consists solely
# of the emitted expressions and does not depend on the renderer's
# trim_blocks / lstrip_blocks settings. With add_generation_prompt=False the
# rendered text is the same as before.
chat_template = """\
{%- for message in messages -%}
    {%- if message['role'] == 'user' -%}
        {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '\n\n' -}}
    {%- elif message['role'] == 'assistant' -%}
        {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + eos_token -}}
    {%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
    {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
{%- endif -%}
"""

In [None]:
# Install the custom template on the tokenizer so later apply_chat_template calls use it.
tokenizer.chat_template = chat_template

In [None]:
# Sanity check: render the first conversation to plain text (no tokenization, no generation prompt).
rendered = tokenizer.apply_chat_template(train_dataset[0]["messages"], tokenize=False, add_generation_prompt=False)

In [None]:
# print() rather than a bare expression so newlines in the rendered sample show as line breaks.
print(rendered)

In [None]:
# Render every conversation through the chat template and store the resulting
# string in a "text" field.
train_dataset = train_dataset.map(
    lambda row: dict(
        text=tokenizer.apply_chat_template(row["messages"], tokenize=False),
    )
)

In [None]:
# The "text" column is already rendered with the chat template, so the
# formatting function handed to SFTTrainer only passes that string through.
# (Original author's note: a formatting_func is still needed when packing=True.)

def formatting_func(example):
    """Return the pre-rendered training string stored under ``example["text"]``."""
    rendered_text = example["text"]
    return rendered_text

In [None]:
### SFT Trainer

# Supervised fine-tuning on the pre-rendered "text" column.
# Per-device effective batch: 4 sequences x 16 accumulation steps = 64 per optimizer step.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_dataset,
    # Must agree with the max_seq_length given to FastLanguageModel.from_pretrained (2048).
    # The previous value of 8192 exceeded the length the model was loaded with.
    # If longer samples are required, raise both values together.
    max_seq_length = 2048,
    # Both formatting_func and dataset_text_field point at the same "text" column.
    formatting_func = formatting_func,
    args = SFTConfig(
        dataset_text_field="text",
        per_device_train_batch_size = 4,
        gradient_accumulation_steps = 16,
        num_train_epochs = 3, 
        warmup_steps = 100,
        learning_rate = 2e-5,
        logging_steps = 25,
        optim = "paged_adamw_8bit",
        weight_decay = 0.01,
        # Checkpoint every 750 steps, keeping only the two most recent.
        save_strategy="steps",
        save_steps=750,
        save_total_limit=2,
        lr_scheduler_type = "linear",
        seed = 42,
        output_dir = OUTPUT_DIR,
        report_to = "wandb",
        packing=False,
    )
)

In [None]:
# Run fine-tuning with the trainer configured above; checkpoints are written under OUTPUT_DIR.
trainer.train()