# **A Simple Notebook Example: Training TinyLlama on the UltraMedical Preference Dataset with DPO**

In [None]:
import torch, platform, sys, os, textwrap

# Report the runtime stack up front so results can be tied to concrete versions/hardware.
print("Python:", sys.version)
print("PyTorch:", torch.__version__)
cuda_ok = torch.cuda.is_available()
print("CUDA available:", cuda_ok)
if cuda_ok:
    print("GPU:", torch.cuda.get_device_name(0))


In [None]:
!pip install -q "transformers>=4.43.0" "datasets>=2.20.0" "accelerate>=0.31.0" "trl>=0.9.4" peft



In [None]:
# Run configuration shared by the cells below.
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
PREF_DATASET = "TsinghuaC3I/UltraMedical-Preference"
OUTPUT_DIR = "./tinyllama-ultramed-dpo"

os.makedirs(OUTPUT_DIR, exist_ok=True)

for label, value in (
    ("Base model:", BASE_MODEL),
    ("Preference dataset:", PREF_DATASET),
    ("Output dir:", OUTPUT_DIR),
):
    print(label, value)



In [None]:
from datasets import load_dataset

ds = load_dataset(PREF_DATASET, split="train")

# Fail fast if the dataset schema changes: the extraction cell below reads these columns.
required_columns = {"prompt", "chosen", "rejected"}
missing_columns = required_columns - set(ds.column_names)
assert not missing_columns, f"{PREF_DATASET} is missing expected columns: {sorted(missing_columns)}"

print(ds)  # row count and feature schema
sample = ds[0]
sample


In [None]:
def extract_prompt_chosen_rejected(ex):
    """Flatten one preference record into plain-text prompt / chosen / rejected fields.

    `ex["chosen"]` and `ex["rejected"]` are iterated as sequences of turn dicts with
    "role" and "content" keys; the content of the LAST turn whose role is "assistant"
    is kept. If a conversation has no assistant turn, the empty string is returned
    for that side. `ex["prompt"]` is passed through unchanged.
    """

    def last_assistant_text(messages):
        # Walk every turn and remember the most recent assistant message.
        latest = ""
        for turn in messages:
            if turn["role"] == "assistant":
                latest = turn["content"]
        return latest

    return {
        "prompt": ex["prompt"],
        "chosen": last_assistant_text(ex["chosen"]),
        "rejected": last_assistant_text(ex["rejected"]),
    }

def _is_informative_pair(ex):
    """Return True when both responses are non-empty and actually differ.

    `extract_prompt_chosen_rejected` yields "" for a side with no assistant turn, and
    identical chosen/rejected texts carry no preference signal for DPO, so both are dropped.
    """
    chosen = ex["chosen"].strip()
    rejected = ex["rejected"].strip()
    return bool(chosen) and bool(rejected) and chosen != rejected


# remove_columns drops all original columns during the map; columns returned by the
# function (prompt/chosen/rejected) are written back, so only those three remain.
pairs_ds = ds.map(extract_prompt_chosen_rejected, remove_columns=ds.column_names)
processed_ds = pairs_ds.filter(_is_informative_pair)

print(f"Kept {len(processed_ds)} / {len(pairs_ds)} preference pairs after filtering.")
assert len(processed_ds) > 0, "no usable preference pairs after filtering"
processed_ds[0]


In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# BASE_MODEL comes from the configuration cell; it is intentionally not redefined here.

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
if tokenizer.pad_token is None:
    # Batches need a pad token; fall back to EOS when the tokenizer does not define one.
    tokenizer.pad_token = tokenizer.eos_token

# Load the trainable weights in float32. With `bf16=True` in DPOConfig the Trainer runs
# forward/backward under bf16 autocast while the optimizer updates fp32 parameters.
# Loading the weights directly in bfloat16 would make training pure-bf16: with
# learning_rate=1e-6, AdamW steps (~lr in size) are far smaller than the bf16 spacing
# of typical weight values (~2^-8 relative) and mostly round away to no change.
# NOTE(review): fp32 weights + AdamW state + a reference-model copy need roughly 20+ GB
# of GPU memory for a 1.1B model; use a LoRA adapter (peft) if that does not fit.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float32,
)
model.config.pad_token_id = tokenizer.pad_token_id

print("Model dtype:", next(model.parameters()).dtype)



In [None]:
# Hugging Face Hub authentication, done non-interactively.
# `!huggingface-cli login` asks for the token on stdin, which a `!` shell command cannot
# receive in many Jupyter frontends, so it can hang or fail under "Run All".
# huggingface_hub reads the HF_TOKEN environment variable automatically, so exporting
# HF_TOKEN before starting the kernel is enough when authentication is needed
# (e.g. gated models or pushing to the Hub). The base model and dataset are loaded in
# earlier cells, so they do not depend on this step.
if os.environ.get("HF_TOKEN"):
    print("HF_TOKEN is set; Hugging Face Hub requests will be authenticated.")
else:
    print("HF_TOKEN not set; continuing anonymously.")

In [None]:
from trl import DPOTrainer, DPOConfig

# OUTPUT_DIR comes from the configuration cell; keeping a single definition ensures the
# directory printed there is the one checkpoints and the final model are written to.

training_args = DPOConfig(
    output_dir=OUTPUT_DIR,

    # Optimization
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    learning_rate=1e-6,
    num_train_epochs=1,

    # Logging / checkpointing
    logging_steps=50,
    save_strategy="epoch",
    save_total_limit=3,
    report_to="none",

    # Precision and memory
    bf16=True,
    fp16=False,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},

    # DPO-specific
    beta=0.1,  # controls deviation from the reference model; higher = stays closer
    max_length=512,  # token budget for prompt + completion
    # Must be smaller than max_length: with max_prompt_length == max_length a long prompt
    # can consume the whole budget and leave the chosen/rejected completions truncated to
    # nothing, which gives DPO no preference signal for that pair.
    max_prompt_length=256,

    remove_unused_columns=False,
    seed=42,
)

# ref_model=None: for a non-PEFT model, TRL creates the frozen reference model from `model`.
dpo_trainer = DPOTrainer(
    model=model,
    ref_model=None,
    args=training_args,
    train_dataset=processed_ds,
    processing_class=tokenizer,
)





The training run below is only an example; the actual training results are presented in the report.

In [None]:
# Run DPO training; the returned TrainOutput is displayed as this cell's result.
train_output = dpo_trainer.train()
train_output



In [None]:
# Write the fine-tuned weights and the tokenizer into the same directory so both can be
# reloaded later with `from_pretrained(OUTPUT_DIR)`.
save_dir = OUTPUT_DIR
dpo_trainer.save_model(save_dir)
tokenizer.save_pretrained(save_dir)

print("DPO-tuned model saved to:", save_dir)


In [None]:
from transformers import pipeline

# Identical decoding settings for both models so the comparison is apples-to-apples.
# return_full_text=False makes the pipeline return only the newly generated answer,
# so the prompt is printed once instead of being echoed in every output.
GEN_KWARGS = dict(max_new_tokens=256, do_sample=False, return_full_text=False)


def load_generation_pipeline(model_path, dtype=torch.bfloat16):
    """Load a causal LM and its tokenizer from `model_path` as a text-generation pipeline.

    Args:
        model_path: Hub id or local directory readable by `from_pretrained`.
        dtype: Weight dtype for inference; bfloat16 matches the bf16 precision used
            during DPO training.

    Returns:
        A `transformers` text-generation pipeline placed with `device_map="auto"`.
    """
    pipe_tokenizer = AutoTokenizer.from_pretrained(model_path)
    pipe_model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=dtype,
        device_map="auto",
    )
    return pipeline("text-generation", model=pipe_model, tokenizer=pipe_tokenizer)


base_pipe = load_generation_pipeline(BASE_MODEL)
dpo_pipe = load_generation_pipeline(OUTPUT_DIR)

# NOTE: processed_ds is the training split, so this prompt was seen during DPO training;
# a held-out prompt would give a fairer comparison.
example = processed_ds[0]
question = example["prompt"]

# NOTE(review): training fed the raw dataset prompt string to TRL (no system line, no
# chat template), while this wrapper adds extra framing. Consider matching the training
# format (or applying tokenizer.apply_chat_template for both models) — confirm intent.
prompt = (
    "You are a helpful and precise medical assistant.\n\n"
    f"Question: {question}\n\nAnswer:"
)
print(prompt)

print("\n=== Base TinyLlama ===")
out_base = base_pipe(prompt, **GEN_KWARGS)
print(out_base[0]["generated_text"])

print("\n=== DPO-tuned TinyLlama ===")
out_dpo = dpo_pipe(prompt, **GEN_KWARGS)
print(out_dpo[0]["generated_text"])
