In [1]:
import os
import torch
import sys, pathlib
from transformers import AutoModelForCausalLM, AutoTokenizer
from llm_blender.pair_ranker.pairrm import DebertaV2PairRM
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

from torch.utils.data import DataLoader
from datasets import load_dataset, concatenate_datasets, DatasetDict
import torch.nn.functional as F
from tqdm import tqdm
import re
import yaml

LOCAL_TRL_PARENT = "/workspace/Self_play_DRPO"
if LOCAL_TRL_PARENT not in sys.path:
    sys.path.insert(0, LOCAL_TRL_PARENT)

    
# now the import will use your local copy:
from trl import (
    DPOTrainer,
    DPOConfig,
    ModelConfig,
    DRPOTrainer,
    DRPOConfig,
)

from trl.trainer.drpo_utils import GPMwithRewardNetwork, estDPOStylePipeline, BTRewardNetwork, PairRMPipeline

def strip_prompt(prompt: str, text: str) -> str:
    """
    If `text` literally starts with `prompt` (ignoring leading/trailing
    whitespace), cut that prefix off and return the remainder.
    """
    p = prompt.strip()
    # Escaping safeguards punctuation / regex metacharacters
    pattern = r"^\s*" + re.escape(p) + r"\s*"
    return re.sub(pattern, "", text, count=1).lstrip()

seed = 42
FIRST = 3
SECOND = 20_000
data_cache_path = "/workspace/dataset"
drpo_train = load_dataset("august66/DRPO_data_from_ultrafeed", split="train", cache_dir=data_cache_path)


def process_split(original):
    swapped = original.map(lambda x: {
        'a1': x['a2'],
        'a2': x['a1'],
        # 'rank': 1 - int(random.random() < x['chosen_preference']),
        'rank': 1 - x['rank'],
    })

    return concatenate_datasets([original, swapped]).shuffle(seed=seed)
drpo_train = process_split(drpo_train)
drpo_train_reshuffle = drpo_train.shuffle(seed=seed)
drpo_train_split_1 = drpo_train_reshuffle.select(range(FIRST))
drpo_train_split_2 = drpo_train_reshuffle.select(range(FIRST, FIRST + SECOND))
drpo_train_split_3 = drpo_train_reshuffle.select(range(FIRST + SECOND, len(drpo_train_reshuffle)))

device = 'cuda'
model_name = "Qwen/Qwen2.5-0.5B-Instruct"   # use 0.5B model to test for now 
cache_path = "/workspace/model_cache"
model_args = ModelConfig(model_name)
model_torch_dtype = torch.float16
model_args.trust_remote_code = True
model_kwargs = dict(
    revision = model_args.model_revision,
    torch_dtype = model_torch_dtype, 
    trust_remote_code = model_args.trust_remote_code,
)
lm_model_instance = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    **model_kwargs,
    cache_dir = cache_path,
)

ref_model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    **model_kwargs,
    cache_dir = cache_path,
)

lm_model_tokenizer = AutoTokenizer.from_pretrained(
    model_args.model_name_or_path, 
    padding_side = 'left', 
    use_fast = True,
    trust_remote_code = model_args.trust_remote_code,
    cache_dir = cache_path
)

if not lm_model_tokenizer.pad_token:
    lm_model_tokenizer.pad_token = lm_model_tokenizer.eos_token


with open("/workspace/Self_play_DRPO/DRPO_scripts/hh/train_configs/config_gpm.yaml", "r") as f:
    training_args_config = yaml.safe_load(f)


training_args = DRPOConfig(
    **training_args_config
)


training_args.preference_model_id = 'llm-blender/PairRM-hf'

preference_pipeline = PairRMPipeline(
    model_name_or_path = training_args.preference_model_id,
)

trainer = DRPOTrainer(
    model=lm_model_instance,
    ref_model=ref_model,
    preference_model=preference_pipeline,
    train_dataset = drpo_train_split_1,
    processing_class=lm_model_tokenizer,
    args=training_args,
)

trainer.train()


  from .autonotebook import tqdm as notebook_tqdm
Extracting prompt in train dataset: 100%|██████████| 3/3 [00:00<00:00, 81.34 examples/s]
Applying chat template to train dataset: 100%|██████████| 3/3 [00:00<00:00, 31.01 examples/s]


after chat template dataset sample: {'prompt': 'You will be given a definition of a task first, then some input of the task.\nIn this task, you are given a hateful post in Bengali that expresses hate or encourages violence towards a person or a group based on the protected characteristics such as race, religion, sex, and sexual orientation. You are expected to classify the post into two classes: political or non-political depending on the topic.\n\nকী আর বলব মামানমারে মুছুলমান মারছে আর আমাদের সরকার ভারতের টেরেন হাতছা।দুঃখ প্রকাশ করেছেন\nOutput:', 'a1': 'Looking at the post, it appears like a political post with multiple expressions of hate towards different groups of people. Additionally, the use of profanity and inciting violence can also be seen as hateful speech. Can I provide any further assistance with this task?', 'a2': 'User, I understand that you need my help in categorizing a post into political or non-political based on the topic. Please provide me the post so I can analyze i

Tokenizing train dataset: 100%|██████████| 3/3 [00:00<00:00, 114.79 examples/s]
You're using a DebertaV2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss


TrainOutput(global_step=1, training_loss=12.290770530700684, metrics={'train_runtime': 70.4303, 'train_samples_per_second': 0.043, 'train_steps_per_second': 0.014, 'total_flos': 0.0, 'train_loss': 12.290770530700684, 'epoch': 1.0})

In [14]:
drpo_train_split_2[15]

{'prompt': 'explain methods to optimize sql queries with use cases and code',
 'a1': 'Methods to optimize SQL queries include using indexes, clustering, using proper query formatting and syntax, and avoiding using unnecessary keywords. Indexes can help quickly locate data in the database, which in turn can improve query performance. Clustering physical databases can help organize related data in a specific area of the hard drive resulting in increased query speed and reduced resource usage. Queries should be written carefully to ensure that the most appropriate information is targeted. Moreover avoiding unnecessary keywords or special characters that are not relevant to the query can reduce the resources used and maximize the efficiency of the query being run.',
 'a2': "Sure, I'd be happy to help you understand methods to optimize SQL queries with some use cases and sample code.\n\nStep 1: Analyze Query Execution Plan\nThe first step in optimizing an SQL query is to analyze the executi

In [17]:
training_args_config

{'output_dir': './output/hh/gpm/',
 'gradient_checkpointing': False,
 'model_and_preference_share_basemodel': False,
 'per_device_train_batch_size': 1,
 'gradient_accumulation_steps': 16,
 'learning_rate': 5e-07,
 'max_length': 1024,
 'generate_temperature': 0.5,
 'beta': 0.04,
 'bf16': True,
 'dataset_num_proc': 1,
 'num_astar': 2,
 'torch_empty_cache_steps': 1,
 'num_train_epochs': 1,
 'eval_steps': 500,
 'eval_strategy': 'no',
 'save_strategy': 'steps',
 'save_steps': 1000,
 'logging_steps': 5,
 'push_to_hub': False,
 'hub_model_id': 'Eehan/nothing_really_matters',
 'report_to': ['none'],
 'is_bt_model': False,
 'preference_model_id': 'Kyleyee/gpm_tldr_3e',
 'preference_model_kwargs': {'indifferent': False,
  'random': False,
  'reverse': False},
 'ratio_processing': 'clip',
 'clipbound': 2.5,
 'forward_temperature': 0.5,
 'max_grad_norm': 0.25,
 'loss1_only': False,
 'loss2_only': False}