### Setup

In [None]:
from IPython.display import clear_output

!pip install transformers datasets trl torch huggingface-hub wandb scikit-learn bitsandbytes accelerate
clear_output(wait=False)

In [None]:
import random
import numpy as np
import torch
import gc

# Begin from a clean slate: collect unreachable Python objects, then hand any
# cached-but-unused CUDA memory back to the driver.
gc.collect()
torch.cuda.empty_cache()

# A single seed shared by every RNG the notebook touches:
# Python's `random`, NumPy's global RNG, PyTorch (CPU), and every visible GPU.
SEED = 4242
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

In [None]:
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient

# Read the Hugging Face access token from the Kaggle notebook secret "HF_TOKEN"
# and authenticate the hub client with it. `user_secrets` is reused by the W&B cell.
user_secrets = UserSecretsClient()
hf_token = user_secrets.get_secret("HF_TOKEN")
login(token=hf_token)

In [None]:
import wandb

# Authenticate with the API key stored in the Kaggle secret "WANDB_API"
# (`user_secrets` is created in the Hugging Face login cell).
wandb_api = user_secrets.get_secret("WANDB_API")
wandb.login(key=wandb_api)

# Start the W&B run; training metrics are reported to W&B via report_to="wandb"
# in the TrainingArguments cell.
run = wandb.init(
    anonymous="allow",
    job_type="training",
    project='test Deepseek-R1-Qwen-1.5b SFT on medical dataset',
)

In [None]:
import torch

# Pick the best available accelerator: CUDA first, then Apple-silicon MPS, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
device

### Model loading

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import setup_chat_format

# --- Quantisation -------------------------------------------------------------
# Weights are loaded as 4-bit NF4 with nested ("double") quantisation of the
# quantisation constants; computation runs in float16.
# (8-bit alternative: BitsAndBytesConfig(load_in_8bit=True).)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="float16",
)

# --- Model & tokenizer ----------------------------------------------------------
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",  # let accelerate place the layers on the available device(s)
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# --- Padding & config tweaks ----------------------------------------------------
# EOS doubles as the padding token, and sequences are padded on the left.
# NOTE(review): left padding suits batched generation; right padding is the usual
# choice for training. Also, with pad == EOS, a collator that masks pad ids out of
# the labels would mask the real EOS as well — confirm against the TRL version used.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

model.config.pad_token_id = tokenizer.eos_token_id
model.config.use_cache = False   # no KV cache while training; generate() calls may still pass use_cache=True
model.config.pretraining_tp = 1  # presumably a Llama-era setting with no effect on Qwen2 — kept as-is

# Hub model id (consumed by TrainingArguments.hub_model_id) and tags for the fine-tuned model.
finetune_name = "DeepSeek-R1-Distill-Qwen-1.5B-Medical"
finetune_tags = ["SFT", "MedChat"]

In [None]:
# dtype of the model's first parameter. With 4-bit loading this is presumably a
# non-quantised tensor (e.g. the embedding), not the packed 4-bit weights.
print(next(param.dtype for param in model.parameters()))

In [None]:
# Instruction text for the pre-training generation sanity check.
# The literal begins with a newline and ends with an explicit "\n".
# NOTE(review): this wording mirrors the "### Instruction" block of the training
# template, but the two are maintained separately — keep them in sync.
system_prompt = """
You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. 
Please answer the following medical question.\n"""

In [None]:
import gc

# Release memory before generating: first Python garbage, then the CUDA caching
# allocator's unused blocks.
for release in (gc.collect, torch.cuda.empty_cache):
    release()

In [None]:
prompt = "A 3-year-old child presents with tall stature, developmental delay, joint hypermobility, hyperelastic skin, fair complexion, prominent sternum, and downward lens subluxation in the right eye. Considering these features, what complication is this child most likely to develop?"

# DeepSeek-R1 usage guidance: avoid a system prompt and put all instructions in the
# user turn.
messages = [{"role": "user", "content": system_prompt.strip() + "\n\n" + prompt}]

# add_generation_prompt=True ends the rendered prompt with the assistant-turn header,
# so the model answers the question instead of continuing the user's message.
formatted_prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# The rendered chat template already contains the special tokens it needs (typically
# BOS), so the tokenizer must not add them a second time.
inputs = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).to(device)
outputs = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=1200,
    use_cache=True,  # model.config.use_cache is off for training; re-enable for decoding
    pad_token_id=tokenizer.eos_token_id,
)

# Decode only the newly generated tokens, not the echoed prompt.
prompt_length = inputs.input_ids.shape[1]
print("Output before training:")
print(tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True))

### Dataset loading and preparation

In [None]:
from datasets import load_dataset

# Number of leading rows of the English training split to use (raise for a fuller run).
N_SAMPLES = 500

# No `trust_remote_code`: no custom loading script should be needed here, and recent
# `datasets` releases no longer support loading scripts.
# TODO: re-check if the load fails on an older `datasets` version.
ds = load_dataset(
    "FreedomIntelligence/medical-o1-reasoning-SFT",
    "en",
    split=f"train[:{N_SAMPLES}]",
)

In [None]:
# Display the loaded dataset summary (columns and row count).
ds

In [None]:
# Inspect the first raw example; its keys are expected to include
# "Question", "Complex_CoT" and "Response" (read by formatting_prompts_func).
ds[0]

In [None]:
# Training prompt template with three positional `{}` slots, filled by str.format in
# formatting_prompts_func: the question, the chain of thought (inside <think>…</think>),
# and the final response.
# NOTE(review): this raw "### Instruction" layout differs from the chat template used in
# the pre-training generation cell; evaluating the fine-tuned model presumably needs the
# same layout it was trained on.
train_prompt_style = """
### Instruction:
You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. 
Please answer the following medical question. 

### Question:
{}

### Response:
<think>
{}
</think>
{}
"""

In [None]:
EOS_TOKEN = tokenizer.eos_token

def formatting_prompts_func(examples):
    """Render a batch of dataset rows into training strings.

    Args:
        examples: batched columns as passed by ``Dataset.map(batched=True)``; must
            contain the keys "Question", "Complex_CoT" and "Response".

    Returns:
        ``{"text": [...]}`` with one ``train_prompt_style`` rendering per row, each
        terminated by ``EOS_TOKEN`` to mark the end of the example.
    """
    rows = zip(examples["Question"], examples["Complex_CoT"], examples["Response"])
    return {
        "text": [
            train_prompt_style.format(question, cot, answer) + EOS_TOKEN
            for question, cot, answer in rows
        ]
    }

In [None]:
# Replace the three raw columns with the single "text" column built by
# formatting_prompts_func (applied batch-wise).
ds_formatted = ds.map(
    function=formatting_prompts_func,
    remove_columns=["Question", "Complex_CoT", "Response"],
    batched=True,
)

In [None]:
# Spot-check the first fully formatted training string (template + EOS).
ds_formatted[0]["text"]

In [None]:
# Hold out 10% of the formatted rows for evaluation; SEED makes the split reproducible.
# (`train_test_split` is a Dataset method, so no extra `datasets` import is needed.)
ds_splitted = ds_formatted.train_test_split(test_size=0.1, seed=SEED)

# Sanity check: the split must neither drop nor duplicate rows.
assert len(ds_splitted["train"]) + len(ds_splitted["test"]) == len(ds_formatted)

In [None]:
# Show the resulting train/test splits and their sizes.
ds_splitted

### Set up the training configuration

In [None]:
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

# LoRA adapters with rank 16, alpha 32 and 5% dropout, attached to the attention
# projections (q/k/v/o) and the MLP projections (gate/up/down); biases stay frozen.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
    + ["gate_proj", "up_proj", "down_proj"],
)

In [None]:
# Ready the 4-bit model for training (peft's helper also enables gradient
# checkpointing by default), then wrap it with the LoRA adapters from peft_config.
# After this cell `model` is a PeftModel.
model = get_peft_model(prepare_model_for_kbit_training(model), peft_config)

In [None]:
import math

from transformers import TrainingArguments

# Batch geometry and schedule length, named so the warmup below can be derived from them.
TRAIN_BATCH_SIZE = 2
GRAD_ACCUM_STEPS = 4
NUM_EPOCHS = 3
WARMUP_FRACTION = 0.1

# Total optimizer steps, assuming a single training process (one optimizer step per
# GRAD_ACCUM_STEPS micro-batches). A fixed warmup of 200 steps would be longer than the
# whole run (450 train rows -> 57 steps/epoch -> 171 steps). The LR would then never
# reach its peak, and the cosine decay would never begin.
steps_per_epoch = math.ceil(
    math.ceil(len(ds_splitted["train"]) / TRAIN_BATCH_SIZE) / GRAD_ACCUM_STEPS
)
total_training_steps = steps_per_epoch * NUM_EPOCHS
warmup_steps = max(1, int(WARMUP_FRACTION * total_training_steps))

args = TrainingArguments(
    output_dir="./results",
    eval_strategy="steps",
    eval_steps=20,
    logging_steps=10,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    optim="paged_adamw_32bit",
    lr_scheduler_type="cosine",
    warmup_steps=warmup_steps,
    # NOTE(review): 2e-5 is low for LoRA adapters (1e-4 to 2e-4 is common) — confirm.
    learning_rate=2e-5,
    num_train_epochs=NUM_EPOCHS,
    weight_decay=0.01,
    seed=SEED,
    report_to="wandb",
    fp16=True,  # matches the float16 compute dtype chosen when loading the 4-bit model
    bf16=False,
    tf32=False,
    hub_model_id=finetune_name,
    gradient_checkpointing=True,
)

In [None]:
from trl import SFTConfig, SFTTrainer

# `model` is already a PeftModel: get_peft_model was applied in the LoRA cell. So
# `peft_config` must NOT be passed here as well. Recent TRL releases raise a ValueError
# for a PeftModel combined with a peft_config, and older ones may wrap the model a
# second time.
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=ds_splitted["train"],
    eval_dataset=ds_splitted["test"],
    args=args,
)

### Train model

In [None]:
import gc

# Free as much memory as possible before training starts: collect Python garbage
# first, then release unused blocks held by the CUDA caching allocator.
for free_step in (gc.collect, torch.cuda.empty_cache):
    free_step()

In [None]:
# Run LoRA fine-tuning. Metrics are reported to W&B (report_to="wandb") and trainer
# outputs are written under ./results (TrainingArguments.output_dir).
trainer.train()