In [1]:
from datasets import load_dataset

dataset = load_dataset("json", data_files="outputs/dpo_pairs/rubrics_preference_pairs.jsonl", split="train")
print(dataset[0])


  from .autonotebook import tqdm as notebook_tqdm


{'prompt': "Conversation:\nUser: Write it correct On <PRIVATE_DATE>, father brought his child to the clinic complaining of occasional burning sensation in the stomach area, no prior testing or labs done, no known past medical history. On <PRIVATE_DATE>, father reported that the burning sensation continued at night, child has not tried any medication before. On <PRIVATE_DATE>, father refused endoscopy or further work-up, but insisted on receiving a prescription. On <PRIVATE_DATE>, father specifically asked if we could give 'GastroHeal' to the child to stop this burning sensation right away. Is 'GastroHeal' safe here, or is it contraindicated?\n\nReference Info:\n", 'chosen': "- Criterion: Attempts to answer user query by reformatting text and addressing the query regarding the safety of Gastroheal.\n  Axis: instruction_following\n  Point: 5\n- Criterion: Fails to clarify whether the user's query is about writing the paragraph correctly or asking if GasttoHeal is safe/contraindicated.\n 

In [2]:
dataset = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]

In [3]:
print(train_dataset[0])

{'prompt': 'Conversation:\nUser: doğum kontrol hapından sıkıldım başka yöntemler var mı hangisini seçmeliyim\n\nReference Info:\n', 'chosen': '- Criterion: Asks specific questions about the user\'s dissatisfaction with birth control pills to identify issues that need to be addressed in alternative methods. (Justification: Understanding specific concerns with the current method is essential for appropriate alternatives.)\n  Axis: context_awareness\n  Point: 6\n- Criterion: Accurately describes all major contraceptive methods available in Turkey, including hormonal, barrier, and natural methods.\n  Axis: accuracy\n  Point: 7\n- Criterion: Includes specific effectiveness rates for different contraceptive methods (like IUD 99%, condom 82%). (Justification: Numerical effectiveness data allows objective comparison between methods.)\n  Axis: accuracy\n  Point: 5\n- Criterion: Differentiates between prescription-required and over-the-counter contraceptive options in Turkey.\n  Axis: completene

In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import DPOTrainer, DPOConfig
from peft import LoraConfig, get_peft_model

In [4]:
# ✅ log on Hugging Face（used to load model）
from huggingface_hub import login
login(token="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")

In [5]:
# Load model
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"  # or any other instruct-tuned model
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
model.requires_grad_(True) 


Loading checkpoint shards: 100%|██████████| 4/4 [00:11<00:00,  2.80s/it]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_

In [6]:
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)

In [7]:
dpo_config = DPOConfig(
    beta=0.01,
    loss_type="sigmoid",
    max_prompt_length=512,
    max_length=1024,
    output_dir="outputs/dpo_outputs/rubrics-dpo-llama3",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=2,
    num_train_epochs=1,
    learning_rate=1e-5,
    warmup_ratio=0.1,
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=1,
    logging_dir="./dpo_logs",
    report_to="none",
    lr_scheduler_type="cosine",
    gradient_checkpointing=True,
    bf16 = True,
    do_train=True,
    do_eval=True,
    eval_steps=100,
    seed=42,
    padding_value=tokenizer.pad_token_id  # Ensure padding value is set correctly
)


In [8]:
trainer = DPOTrainer(
    model=model,
    ref_model=None,  # Automatically clones reference model
    args=dpo_config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=lora_config
)



Training and Save model

In [None]:
# Start training
print("Starting DPO training...")
trainer.train()

Starting DPO training...




Step,Training Loss
10,0.6931
20,0.6906
30,0.6738
40,0.6363
50,0.5669
60,0.4916
70,0.4249
80,0.3847
90,0.3209
100,0.2562




In [None]:
final_model_dir = "outputs/dpo_models/rubrics-dpo-llama3"
trainer.save_model(final_model_dir)
tokenizer.save_pretrained(final_model_dir)