In [1]:
from datasets import load_dataset

dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en")

In [2]:
dataset['train'][0]

{'Question': 'Given the symptoms of sudden weakness in the left arm and leg, recent long-distance travel, and the presence of swollen and tender right lower leg, what specific cardiac abnormality is most likely to be found upon further evaluation that could explain these findings?',
 'Complex_CoT': "Okay, let's see what's going on here. We've got sudden weakness in the person's left arm and leg - and that screams something neuro-related, maybe a stroke?\n\nBut wait, there's more. The right lower leg is swollen and tender, which is like waving a big flag for deep vein thrombosis, especially after a long flight or sitting around a lot.\n\nSo, now I'm thinking, how could a clot in the leg end up causing issues like weakness or stroke symptoms?\n\nOh, right! There's this thing called a paradoxical embolism. It can happen if there's some kind of short circuit in the heart - like a hole that shouldn't be there.\n\nLet's put this together: if a blood clot from the leg somehow travels to the l

In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from peft import LoraConfig, get_peft_model, TaskType

# Load Qwen3-1.7B with eager attention; attn_implementation is the public
# equivalent of setting config._attn_implementation by hand.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-1.7B", attn_implementation="eager")

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-1.7B", use_fast=True)

target_modules = ["q_proj","v_proj","k_proj","o_proj","gate_proj","up_proj","down_proj"]

peft_config = LoraConfig(
    target_modules=target_modules,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)

model = get_peft_model(model, peft_config)

model.print_trainable_parameters()


trainable params: 17,432,576 || all params: 1,738,007,552 || trainable%: 1.0030
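The trainable-parameter count above can be reproduced by hand. LoRA attaches two matrices to each targeted linear layer, A (r × d_in) and B (d_out × r), so each module adds r·(d_in + d_out) parameters. The sketch below assumes Qwen3-1.7B's layer shapes as recalled from its `config.json` (hidden_size 2048, intermediate_size 6144, 28 decoder layers, 16 query heads and 8 key/value heads with head_dim 128); these values are assumptions, not something the notebook prints. With r=16 they reproduce the reported 17,432,576 exactly.

```python
# Assumed Qwen3-1.7B dimensions (verify against the model's config.json).
HIDDEN = 2048
INTERMEDIATE = 6144
NUM_LAYERS = 28
NUM_Q_HEADS = 16
NUM_KV_HEADS = 8
HEAD_DIM = 128
RANK = 16  # r in LoraConfig

# (in_features, out_features) of every target module in one decoder layer.
module_shapes = {
    "q_proj": (HIDDEN, NUM_Q_HEADS * HEAD_DIM),
    "k_proj": (HIDDEN, NUM_KV_HEADS * HEAD_DIM),
    "v_proj": (HIDDEN, NUM_KV_HEADS * HEAD_DIM),
    "o_proj": (NUM_Q_HEADS * HEAD_DIM, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """A is r x d_in, B is d_out x r; bias="none" adds no extra parameters."""
    return r * d_in + d_out * r


per_layer = sum(lora_params(d_in, d_out, RANK) for d_in, d_out in module_shapes.values())
trainable = per_layer * NUM_LAYERS
all_params = 1_738_007_552  # "all params" as printed by print_trainable_parameters()

print(f"per layer:  {per_layer:,}")
print(f"trainable:  {trainable:,}")
print(f"trainable%: {100 * trainable / all_params:.4f}")
```

Because every term is linear in r, halving the rank to 8 would halve the trainable count.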


In [4]:
def preprocess_fn(example):
    return {
        "messages": [
            {"role": "user", "content": example["Question"]},
            {
                "role": "assistant",
                "content": f"<think>{example['Complex_CoT']}</think>{example['Response']}",
            },
        ]
    }

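def tokenize_fn(examples):
    # Despite the name, this only renders each conversation to a string with the
    # model's chat template; SFTTrainer tokenizes the resulting "text" column later.
    convos = examples["messages"]
    texts = [
        tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
        for convo in convos
    ]
    return {"text": texts}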

dataset = dataset.map(
    preprocess_fn,
    remove_columns=dataset["train"].column_names,
    num_proc=4,
)

tokenized_dataset = dataset.map(
    tokenize_fn,
    batched=True,
    num_proc=4,
    remove_columns=["messages"]
)


In [5]:
print(tokenized_dataset['train'][100]['text'])

<|im_start|>user
A 25-year-old woman presents to the ED with a diffuse, erythematous rash, nausea, vomiting, and fever for 2 days. Physical examination reveals a soaked tampon in her vagina, and blood cultures are negative, suggesting toxic shock syndrome. Which specific molecule on T cells does the toxin most likely bind to?<|im_end|>
<|im_start|>assistant
<think>
Alright, here's a situation with a 25-year-old woman who showed up in the emergency department. She's got this widespread red rash, feeling nauseous, she's vomiting, and running a fever for two days. Something's not quite right here, and it all starts connecting to the idea of toxic shock syndrome. Oh, and there's a crucial detail: they found a soaked tampon during her exam.

Okay, let's dig into what's happening in toxic shock syndrome. It's a bit of a nightmare because it's associated with these things called superantigens. These are basically like the rogue agents of the bacterial world, and they're mostly coming from bug
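The printed sample shows the ChatML layout produced by Qwen's chat template. The sketch below is a simplified, hypothetical re-implementation (`render_chatml` is not a library function, and the toy record is made up) meant only to make the structure of the `text` column explicit. The real template applied by `tokenizer.apply_chat_template` is authoritative; as the output above suggests, it also reformats the think block, which the sketch only approximates.

```python
def preprocess_fn(example: dict) -> dict:
    # Same mapping as the notebook: question -> user turn; chain of thought
    # wrapped in <think> tags followed by the final answer -> assistant turn.
    return {
        "messages": [
            {"role": "user", "content": example["Question"]},
            {
                "role": "assistant",
                "content": f"<think>{example['Complex_CoT']}</think>{example['Response']}",
            },
        ]
    }


def render_chatml(messages: list) -> str:
    """Hypothetical, simplified ChatML renderer (not Qwen's actual Jinja template)."""
    parts = []
    for msg in messages:
        content = msg["content"]
        if msg["role"] == "assistant" and "</think>" in content:
            # Approximate the newline layout visible in the printed sample.
            reasoning, answer = content.split("</think>", 1)
            reasoning = reasoning.replace("<think>", "", 1).strip()
            content = f"<think>\n{reasoning}\n</think>\n\n{answer.lstrip()}"
        parts.append(f"<|im_start|>{msg['role']}\n{content}<|im_end|>\n")
    return "".join(parts)


# Made-up toy record with the dataset's three field names.
toy = {"Question": "What is 2 + 2?", "Complex_CoT": "Add two and two.", "Response": "4"}
text = render_chatml(preprocess_fn(toy)["messages"])
print(text)
```

Keeping the reasoning inside `<think>` tags in the assistant turn is what teaches the model to emit its chain of thought before the final answer.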

In [7]:
from trl import SFTConfig, SFTTrainer

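sft_config = SFTConfig(
    output_dir="./qwen3-medical-sft",
    dataset_text_field="text",      # column produced by tokenize_fn
    max_seq_length=512,             # longer rows are truncated; newer TRL releases rename this max_length
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,  # effective batch of 8 sequences per device
    learning_rate=2e-4,
    num_train_epochs=1,
    warmup_ratio=0.03,
    optim="adamw_torch",
    lr_scheduler_type="cosine",
    fp16=True,
    save_steps=200,
    save_total_limit=3,
    logging_steps=50,
    # The string "none" disables W&B and other integrations; in transformers 4.x,
    # report_to=None falls back to "all" installed integrations.
    report_to="none",
)

trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,  # older TRL releases took tokenizer=tokenizer
    train_dataset=tokenized_dataset["train"],
    args=sft_config,
)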

trainer.train()

Error in callback <bound method _WandbInit._pre_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x7f62f738a440>> (for pre_run_cell), with arguments args (<ExecutionInfo object at 7f62bb3b21d0, raw_cell="from trl import SFTConfig, SFTTrainer

sft_config .." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2Blearn-g5-xl/home/ubuntu/oracle/finetune_learn/medical-reasoning.ipynb#W4sdnNjb2RlLXJlbW90ZQ%3D%3D>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe


BrokenPipeError: [Errno 32] Broken pipe

Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x7f62f738a440>> (for post_run_cell), with arguments args (<ExecutionResult object at 7f62bb3b2ec0, execution_count=7 error_before_exec=None error_in_exec=[Errno 32] Broken pipe info=<ExecutionInfo object at 7f62bb3b21d0, raw_cell="from trl import SFTConfig, SFTTrainer

sft_config .." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2Blearn-g5-xl/home/ubuntu/oracle/finetune_learn/medical-reasoning.ipynb#W4sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe
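With the settings above, the optimizer-step budget can be worked out before launching. The sketch assumes a single GPU, all 19,704 training rows kept (the count reported when the notebook mapped the dataset, with no packing or filtering), and that total steps and warmup follow transformers' rules: dataloader batches divided by `gradient_accumulation_steps`, and warmup = ceil(warmup_ratio × total steps). `lr_at` is a hypothetical helper that mirrors the shape of linear warmup followed by half-cosine decay; it is not the library's scheduler object.

```python
import math

NUM_EXAMPLES = 19_704   # train-split size shown by the notebook's dataset.map progress bar
PER_DEVICE_BATCH = 2
GRAD_ACCUM = 4
NUM_GPUS = 1            # assumption: single-GPU run
EPOCHS = 1
WARMUP_RATIO = 0.03
PEAK_LR = 2e-4
SAVE_STEPS = 200

batches_per_epoch = math.ceil(NUM_EXAMPLES / (PER_DEVICE_BATCH * NUM_GPUS))
steps_per_epoch = batches_per_epoch // GRAD_ACCUM
total_steps = steps_per_epoch * EPOCHS
warmup_steps = math.ceil(total_steps * WARMUP_RATIO)
effective_batch = PER_DEVICE_BATCH * GRAD_ACCUM * NUM_GPUS


def lr_at(step: int) -> float:
    """Hypothetical helper: linear warmup to PEAK_LR, then half-cosine decay to 0."""
    if step < warmup_steps:
        return PEAK_LR * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return PEAK_LR * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


print("effective batch:", effective_batch)
print("optimizer steps:", total_steps, "warmup steps:", warmup_steps)
print("step-based checkpoints written:", total_steps // SAVE_STEPS)
for step in (0, warmup_steps, total_steps // 2, total_steps):
    print(step, f"{lr_at(step):.2e}")
```

Since nothing is packed, each optimizer step sees at most 8 × 512 = 4,096 tokens under this configuration.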