In [None]:
# SDPO Training Script with Reference Model Configuration
# This script demonstrates how to configure SDPO with different reference model strategies

import time, os
# time.sleep(60*60*1)

# =============================================================================
# SDPO Configuration: Reference Model Selection
# =============================================================================

# Training configuration
base_model_path = './models/model_3b'  # Policy model (3B parameters)

# Reference Model Selection (Critical for SDPO effectiveness)
# Choose one of the following strategies:

# Strategy 1: Same-size DPO-aligned reference (computationally efficient)
# ref_model_path = './models/model_3b-dpo-baseline'

# Strategy 2: Larger reference model (better performance, higher cost)
ref_model_path = './models/model_33b-dpo-baseline'  # 33B reference for 3B policy

# Strategy 3: Use base model as reference (baseline comparison)
# ref_model_path = base_model_path

# =============================================================================
# SDPO-specific Parameters
# =============================================================================

loss_type = 'sdpo'          # Enable Selective DPO
threshold = 0.6             # SDPO token-selection threshold: keeps the top
                            # (1 - threshold) share of tokens (40% at 0.6)
lambda_sdpo = 10            # SDPO regularization parameter

# Standard DPO parameters
lr = 2e-7                   # Learning rate
beta = 0.01                 # DPO beta parameter

# Hardware configuration
work_num = 4                # Passed to `deepspeed --num_gpus` at launch
gpu_pool = ''               # Not referenced in this cell

# Data configuration
train_file = '/data/Skywork-Reward-Preference-80K-v0.2'

# Output configuration
# The run name embeds the live beta / lambda values so it cannot drift out of
# sync with them (it previously read "beta002" while beta was 0.01). Dots are
# dropped to follow the existing "beta002" convention: 0.01 -> "beta001".
node_name = f'model3b-dpo0117-distill-33b-sdpo-beta{str(beta).replace(".", "")}-lambda{lambda_sdpo}'
PROJ_PATH_BOLE = './'       # Not referenced in this cell
OUTPUT_DIR = os.path.join('./DPO_output', node_name)

# Passed as --gradient_accumulation_steps; despite the name it is not a batch
# size. With per-device batch 1 on `work_num` GPUs, the effective global batch
# is work_num * BATCH_SIZE (160 with the current values), not 1280.
# NOTE(review): the extra `* 8` looks like it assumes 8 GPUs per worker;
# confirm the intended global batch before changing this.
BATCH_SIZE = 1280 // (work_num * 8)

# =============================================================================
# Training Script Construction
# =============================================================================

# DeepSpeed ZeRO-3 config; machine-specific absolute path, adjust per host.
ds_config_path = '/code/dongzhijin/trl/trl-main/ds_config/ds_config_zero3_dzj.json'

# Epoch count for the run. The command used to pass --num_train_epochs twice
# (1, then 2). argparse-style parsers keep the last occurrence, so the run
# trained for 2 epochs while the first flag suggested 1. A single flag now
# states the value that actually took effect.
num_train_epochs = 2

# One list entry per CLI token; a flag and its value share a line.
train_args = [
    'examples/scripts/dpo.py',
    '--deepspeed', ds_config_path,
    '--dataset_name', train_file,
    '--model_name_or_path', base_model_path,
    '--ref_model_name_or_path', ref_model_path,
    '--learning_rate', lr,
    '--num_train_epochs', num_train_epochs,
    '--per_device_train_batch_size', 1,
    '--per_device_eval_batch_size', 1,
    '--gradient_accumulation_steps', BATCH_SIZE,
    '--gradient_checkpointing',
    '--logging_steps', 2,
    '--eval_strategy', 'no',
    '--eval_steps', 4,  # only consulted if --eval_strategy is switched to 'steps'
    '--save_steps', 20,
    '--save_total_limit', 100,
    '--output_dir', OUTPUT_DIR,
    '--no_remove_unused_columns',
    '--max_length', 4096,
    '--warmup_steps', 30,
    '--loss_type', loss_type,
    '--beta', beta,
    '--bf16', 'True',
    '--threshold', threshold,
    '--lambda_sdpo', lambda_sdpo,
]
# str() matches f-string formatting of these values (e.g. 2e-7 -> '2e-07').
train_script = ' '.join(map(str, train_args))

# Echo the effective configuration into the notebook output before launching.
# The kept-token share uses `:.0%` (rounds) instead of int(... * 100), which
# truncated float error: e.g. threshold 0.8 gave 19.999... and printed "19%".
print(f"""
=============================================================================
SDPO Training Configuration Summary
=============================================================================
Policy Model:     {base_model_path}
Reference Model:  {ref_model_path}
Loss Type:        {loss_type}
SDPO Threshold:   {threshold} (keeps top {1 - threshold:.0%} important tokens)
Beta:             {beta}
Lambda SDPO:      {lambda_sdpo}
Learning Rate:    {lr}
Output Dir:       {OUTPUT_DIR}
=============================================================================
""")

# Launch training
# IPython shell escape (notebook-only syntax, not plain Python): IPython
# substitutes {work_num} and {train_script} from the notebook namespace before
# handing the line to the shell.
# NOTE(review): `!` does not raise when the command fails; IPython records the
# exit status in `_exit_code` (confirm for your IPython version). Check it if
# later cells depend on training having succeeded.
! deepspeed --num_gpus {work_num} {train_script}