In [1]:
import os
import torch
import sys, pathlib
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification,DebertaV2ForSequenceClassification, GPTNeoXForCausalLM
from llm_blender.pair_ranker.pairrm import DebertaV2PairRM
from transformers import DataCollatorWithPadding
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

from torch.utils.data import DataLoader
from datasets import load_dataset, concatenate_datasets, DatasetDict, Dataset
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import re
import yaml

# Directory placed at the front of sys.path so that the `trl` import below is
# resolved from the copy under it instead of any installed package.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
LOCAL_TRL_PARENT = "/workspace/Self_play_DRPO"
# Guarded so that re-running the cell does not insert the same entry twice.
if LOCAL_TRL_PARENT not in sys.path:
    sys.path.insert(0, LOCAL_TRL_PARENT)

    
# now the import will use your local copy:
from trl import (
    DPOTrainer,
    DPOConfig,
    ModelConfig,
    DRPOTrainer,
    DRPOConfig,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
from trl.trainer.utils import pad, truncate_right, selective_log_softmax
from trl.data_utils import apply_chat_template

# --- Notebook configuration -------------------------------------------------
# Dataset cache directory (not referenced by the cells visible in this notebook).
data_cache_path = "/workspace/dataset"
# Model cache directory. Note that load_model() has its own default with the
# same value and does not read this global.
model_cache_path = '/workspace/model_cache'
# Hub id of the reference policy (not referenced by the cells visible here).
ref_policy_path = "Qwen/Qwen2.5-1.5B-Instruct" 
#target_policy_path = "Qwen/Qwen2.5-1.5B-Instruct" 
# Hub id of the reward model loaded with load_model(..., task='reward').
reward_model_path = 'Kyleyee/Qwen2.5-1.5B-reward-hh-retrain'
# Hub id of the dataset passed to load_dataset().
ds_path = 'august66/hh_qwen2.5_1.5b_with_bias'


def load_model(model_path, task = 'generation', model_type = 'decoder', model_cache_path =  '/workspace/model_cache'):
    """
    Load a Hugging Face model together with its tokenizer.

    Args:
        model_path: Hub id or local path, handed to ``ModelConfig``.
        task: 'generation' loads ``AutoModelForCausalLM``; 'reward' loads
            ``AutoModelForSequenceClassification``.
        model_type: 'decoder' makes the tokenizer pad and truncate on the left;
            any other value uses the right side.
        model_cache_path: ``cache_dir`` used for the model and tokenizer files.

    Returns:
        Tuple ``(model_instance, model_tokenizer)`` — model first, tokenizer second.

    Raises:
        ValueError: if ``task`` is not 'generation' or 'reward'.
    """
    # Fail fast: an unknown task used to fall through both branches and only
    # surfaced later as an UnboundLocalError on `model_instance`.
    if task not in ('generation', 'reward'):
        raise ValueError(
            f"Unsupported task {task!r}; expected 'generation' or 'reward'."
        )

    model_args = ModelConfig(model_path)
    # Weights are always loaded in bfloat16, regardless of ModelConfig.torch_dtype.
    model_torch_dtype = torch.bfloat16
    model_kwargs = dict(
        revision = model_args.model_revision,
        torch_dtype = model_torch_dtype,
        trust_remote_code = model_args.trust_remote_code,
    )

    side = 'left' if model_type == 'decoder' else 'right'

    model_cls = AutoModelForCausalLM if task == 'generation' else AutoModelForSequenceClassification
    model_instance = model_cls.from_pretrained(
        model_args.model_name_or_path,
        **model_kwargs,
        cache_dir = model_cache_path,
    )

    model_tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        padding_side = side,
        truncation_side = side,
        use_fast = True,
        trust_remote_code = model_args.trust_remote_code,
        cache_dir = model_cache_path
    )

    # Make sure pad/eos are defined on both the tokenizer and the model config,
    # borrowing one from the other when only one of them exists.
    if model_tokenizer.pad_token is None:
        model_tokenizer.pad_token = model_tokenizer.eos_token

    if getattr(model_instance.config, "pad_token_id", None) is None:
        model_instance.config.pad_token_id = model_tokenizer.pad_token_id

    if model_tokenizer.eos_token is None:
        model_tokenizer.eos_token = model_tokenizer.pad_token

    if getattr(model_instance.config, "eos_token_id", None) is None:
        model_instance.config.eos_token_id = model_tokenizer.eos_token_id

    return model_instance, model_tokenizer


def generate_responses(
    model_instance,
    model_tokenizer,
    prompts,                      # single chat or list of chats (each: list[{role,content}])
    *,
    temperature=1.0,
    max_new_tokens=256,
    n_responses_total=1000,       # total per prompt
    responses_per_call=16,        # per subcall per prompt
    batch_size=4,                 # prompts per batch
    device='cuda',
    max_length=1024,
):
    """
    Generate ``n_responses_total`` completions for every prompt.

    Args:
        model_instance: causal LM exposing ``generate``; moved to ``device`` and
            put in eval mode.
        model_tokenizer: tokenizer of the model. Its ``padding_side`` /
            ``truncation_side`` are set to "left" (and stay that way).
        prompts: a single chat (list of ``{"role", "content"}`` dicts), a list of
            chats, a single string, or a list of strings. Chats are rendered
            with the tokenizer's chat template (generation prompt appended);
            strings are used verbatim, matching how ``forward_batch`` renders
            non-chat prompts.
        temperature: ``> 0`` samples with top-k 50 / top-p 0.9; ``<= 0`` decodes
            greedily.
        max_new_tokens: generation budget per completion.
        n_responses_total: completions produced per prompt.
        responses_per_call: completions requested per ``generate`` call per
            prompt (memory knob). Must be >= 1.
        batch_size: prompts per tokenized batch. Must be >= 1.
        device: device the model and inputs are moved to.
        max_length: truncation length for the rendered prompt.

    Returns:
        ``Dataset`` with columns ``prompt`` (the original prompt object) and
        ``completion`` (decoded text). Rows follow prompt order and all
        completions of one prompt are contiguous.

    Raises:
        ValueError: if ``responses_per_call`` or ``batch_size`` is < 1.
    """
    # responses_per_call == 0 would make `cur_n` 0 and spin forever below.
    if responses_per_call < 1:
        raise ValueError(f"responses_per_call must be >= 1, got {responses_per_call}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # tokenizer prefs: left padding keeps the prompt end aligned at column T_in
    model_tokenizer.padding_side = "left"
    if hasattr(model_tokenizer, "truncation_side"):
        model_tokenizer.truncation_side = "left"
    if model_tokenizer.pad_token_id is None and model_tokenizer.eos_token_id is not None:
        model_tokenizer.pad_token_id = model_tokenizer.eos_token_id

    do_sample = bool(temperature > 0)
    gen_kwargs = {
        "do_sample": do_sample,
        "max_new_tokens": max_new_tokens,
        "eos_token_id": model_tokenizer.eos_token_id,
        "pad_token_id": model_tokenizer.pad_token_id,
        "use_cache": True,
        "return_dict_in_generate": False,
        "output_scores": False,
        "output_attentions": False,
        "output_hidden_states": False,
    }
    if do_sample:
        # Sampling-only knobs; they are meaningless for greedy decoding and a
        # non-positive temperature is not a valid sampling temperature.
        gen_kwargs.update(temperature=float(temperature), top_k=50, top_p=0.9)

    # Greedy decoding is deterministic: decode once per prompt and replicate the
    # text, rather than asking generate() for several identical sequences.
    per_call = responses_per_call if do_sample else max(n_responses_total, 1)

    def _is_chat(obj):
        return bool(
            isinstance(obj, list) and obj
            and isinstance(obj[0], dict) and "role" in obj[0]
        )

    # normalize prompts input to a list of prompts
    if isinstance(prompts, str) or _is_chat(prompts):
        prompts = [prompts]

    model_instance.to(device).eval()
    on_cuda = str(device).startswith("cuda")   # also covers 'cuda:0' / torch.device

    prompts_out, completions_out = [], []

    for s in range(0, len(prompts), batch_size):
        batch_prompts_raw = prompts[s:s+batch_size]

        # render once per batch; plain strings bypass the chat template
        batch_rendered = [
            p if isinstance(p, str)
            else model_tokenizer.apply_chat_template(p, add_generation_prompt=True, tokenize=False)
            for p in batch_prompts_raw
        ]
        enc = model_tokenizer(
            batch_rendered, padding=True, truncation=True,
            max_length=max_length, return_tensors="pt"
        ).to(device)
        T_in = enc["input_ids"].size(1)

        # bucket completions per prompt index within this batch
        per_prompt_bucket = [[] for _ in range(len(batch_prompts_raw))]

        remaining = n_responses_total
        while remaining > 0:
            cur_n = min(per_call, remaining)
            n_seq = cur_n if do_sample else 1

            with torch.no_grad():
                out_ids = model_instance.generate(
                    **enc, num_return_sequences=n_seq, **gen_kwargs
                )  # [B*n_seq, T_in + gen_len]

            # everything after the (left-padded) prompt is the completion
            comp_ids = out_ids[:, T_in:]
            decoded = model_tokenizer.batch_decode(comp_ids, skip_special_tokens=True)
            if n_seq != cur_n:
                # greedy path: expand each prompt's single output to cur_n copies
                decoded = [text for text in decoded for _ in range(cur_n)]

            # outputs are stacked: p0 x cur_n, p1 x cur_n, ...
            for k, text in enumerate(decoded):
                per_prompt_bucket[k // cur_n].append(text)

            # free between subcalls
            del out_ids, comp_ids, decoded
            if on_cuda:
                torch.cuda.empty_cache()

            remaining -= cur_n

        # now FLATTEN in prompt order so each prompt's 1..N completions are contiguous
        for i, completions in enumerate(per_prompt_bucket):
            for c in completions:
                prompts_out.append(batch_prompts_raw[i])  # original prompt object
                completions_out.append(c)

        del enc
        if on_cuda:
            torch.cuda.empty_cache()

    return Dataset.from_dict({"prompt": prompts_out, "completion": completions_out})

def _render_reward_text(reward_tok, prompt, answer):
    """Render one (prompt, answer) pair into the text scored by the reward model."""
    def _is_chat(obj):
        return isinstance(obj, list) and obj and isinstance(obj[0], dict) and "role" in obj[0]

    if _is_chat(prompt):
        # Chat prompt: append the answer as assistant turn(s) and apply the template.
        if _is_chat(answer):
            chat = prompt + answer
        else:
            chat = prompt + [{"role": "assistant", "content": str(answer)}]
        return reward_tok.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=False
        )
    # fallback: plain strings
    return f"{str(prompt)}\n{str(answer)}"


@torch.no_grad()
def compute_rewards(
    prompts,
    a1s,
    a2s,
    reward_tok,
    reward_model,
    *,
    device=None,
    max_length=1024,
):
    """
    Score every (prompt, a1) and (prompt, a2) pair with the reward model.

    Args:
        prompts: chat-format prompts (list of {"role", "content"} dicts) or strings.
        a1s, a2s: candidate answers, chat-format or strings; same length as ``prompts``.
        reward_tok: tokenizer of the reward model (needs a chat template for
            chat-format prompts).
        reward_model: sequence-classification model; moved to ``device`` and
            put in eval mode.
        device: target device; defaults to "cuda" when available, else "cpu".
        max_length: truncation length for the rendered texts.

    Returns:
        ``(reward_a1, reward_a2)`` — two lists of Python floats, one entry per prompt.
        Both are empty for empty input.

    Raises:
        ValueError: if ``prompts``, ``a1s`` and ``a2s`` differ in length
            (zip would otherwise truncate silently and misalign the scores).
    """
    n = len(prompts)
    if len(a1s) != n or len(a2s) != n:
        raise ValueError(
            f"prompts, a1s and a2s must have equal length, got {n}, {len(a1s)}, {len(a2s)}"
        )
    if n == 0:
        return [], []

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    reward_model.to(device)
    reward_model.eval()

    # All a1 texts first, then all a2 texts, so one forward pass scores both.
    texts = [_render_reward_text(reward_tok, p, a) for p, a in zip(prompts, a1s)]
    texts += [_render_reward_text(reward_tok, p, a) for p, a in zip(prompts, a2s)]

    inputs = reward_tok(
        texts, return_tensors="pt", padding=True, truncation=True, max_length=max_length
    ).to(device)

    logits = reward_model(**inputs).logits  # [2n, 1] or [2n, K] or [2n]
    if logits.ndim == 2 and logits.size(-1) == 1:
        scores = logits.squeeze(-1)                  # [2n]
    elif logits.ndim == 2 and logits.size(-1) > 1:
        scores = logits[:, -1]                       # common: last logit as reward
    else:
        scores = logits
    scores = scores.detach().float().cpu().tolist()

    return scores[:n], scores[n:]  # reward_a1, reward_a2




@torch.no_grad()
def forward_batch(
    model,
    tokenizer,
    batch,                         # dict with batch['prompt'], batch['completion']
    *,
    temperature=1.0,
    max_length=1024,
    add_generation_prompt=True,
    device="cuda",
):
    """
    Single batched pass. Returns a CPU float32 tensor [B] with the summed
    log-probability of each completion conditioned on its prompt.

    Args:
        model: causal LM; moved to ``device`` and put in eval mode.
        tokenizer: tokenizer of ``model``. Its ``padding_side`` is changed while
            tokenizing and always restored afterwards.
        batch: mapping with ``"prompt"`` (chat-format lists or strings) and
            ``"completion"`` (converted with ``str``).
        temperature: logits are divided by ``temperature + 1e-7`` when it is > 0.
        max_length: truncation length applied to prompts and completions separately.
        add_generation_prompt: forwarded to ``apply_chat_template`` for chat prompts.
        device: device used for the forward pass.

    Returns:
        Tensor of shape [B]; an empty tensor when the batch has no completions.
    """
    prompts      = batch["prompt"]
    completions  = [str(c) for c in batch["completion"]]
    if len(completions) == 0:
        return torch.empty(0)

    # Render chat prompts where needed
    def render(p):
        if isinstance(p, list) and p and isinstance(p[0], dict) and "role" in p[0]:
            return tokenizer.apply_chat_template(p, add_generation_prompt=add_generation_prompt, tokenize=False)
        return str(p)
    prompts_text = [render(p) for p in prompts]

    model = model.to(device).eval()

    # Prompts are LEFT padded and completions RIGHT padded so that, after
    # concatenation, every row reads [pad.., prompt, completion, pad..]: the
    # real tokens are contiguous and position T_p - 1 is each row's last prompt
    # token. Left-padding the completions as well would put pad tokens between
    # prompt and completion and predict the first completion token from a pad.
    saved_pad_side = getattr(tokenizer, "padding_side", "right")
    try:
        tokenizer.padding_side = "left"
        enc_p = tokenizer(prompts_text, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(device)
        tokenizer.padding_side = "right"
        # No special tokens here: a BOS inserted mid-sequence would be scored
        # as if it were part of the completion.
        enc_c = tokenizer(completions, padding=True, truncation=True, max_length=max_length,
                          add_special_tokens=False, return_tensors="pt").to(device)
    finally:
        # Restore even if tokenization raises.
        tokenizer.padding_side = saved_pad_side

    # Concat (teacher forcing)
    input_ids = torch.cat([enc_p["input_ids"], enc_c["input_ids"]], dim=1)
    attn_mask = torch.cat([enc_p["attention_mask"], enc_c["attention_mask"]], dim=1)

    # Forward once
    logits = model(input_ids=input_ids, attention_mask=attn_mask).logits  # [B, T, V]
    if temperature and temperature > 0:
        logits = logits / (temperature + 1e-7)

    # Align to completion window (next-token prediction): the logits at
    # position t predict token t + 1, so the window starts one step early.
    T_p = enc_p["input_ids"].size(1)
    T_c = enc_c["input_ids"].size(1)
    start = T_p - 1
    end   = min(start + T_c, logits.size(1))
    T_eff = end - start

    win_logits = logits[:, start:end, :]                         # [B, T_eff, V]
    tgt_ids = enc_c["input_ids"][:, :T_eff]                      # [B, T_eff]
    tgt_msk = enc_c["attention_mask"][:, :T_eff].to(torch.bool)  # [B, T_eff]

    logps  = selective_log_softmax(win_logits, tgt_ids)          # [B, T_eff]
    # Accumulate in float32: a half-precision result would round away most of
    # the signal once the sum reaches the hundreds.
    scores = (logps.float() * tgt_msk).sum(dim=1)                # [B]

    del win_logits, tgt_ids, tgt_msk, logits
    return scores.detach().cpu()


@torch.no_grad()
def get_expected_kl_batch(
    ref_model,
    tokenizer,                     # same tokenizer for both models
    target_model,
    batch,                         # dict with batch['prompt'], batch['completion']
    *,
    max_length=1024,
    add_generation_prompt=True,
    device="cuda",
):
    """
    Single batched pass. For each example, sums over the completion positions
    the token-level KL(ref || target) = sum_v p_ref(v) * (log p_ref(v) - log p_target(v)).

    Args:
        ref_model: model whose distribution weights the KL; moved to ``device``, eval mode.
        tokenizer: tokenizer shared by both models. Its ``padding_side`` is
            changed while tokenizing and always restored afterwards.
        target_model: the other model; moved to ``device``, eval mode.
        batch: mapping with ``"prompt"`` (chat-format lists or strings) and
            ``"completion"`` (converted with ``str``).
        max_length: truncation length applied to prompts and completions separately.
        add_generation_prompt: forwarded to ``apply_chat_template`` for chat prompts.
        device: device used for the forward passes.

    Returns:
        CPU float32 tensor of shape [B]. For an empty batch an empty Python
        list is returned instead (kept for backward compatibility).
    """
    prompts     = batch["prompt"]
    completions = [str(c) for c in batch["completion"]]
    if len(completions) == 0:
        return []

    # Render prompts
    def render(p):
        if isinstance(p, list) and p and isinstance(p[0], dict) and "role" in p[0]:
            return tokenizer.apply_chat_template(p, add_generation_prompt=add_generation_prompt, tokenize=False)
        return str(p)
    prompts_text = [render(p) for p in prompts]

    ref_model    = ref_model.to(device).eval()
    target_model = target_model.to(device).eval()

    # Prompts are LEFT padded and completions RIGHT padded so that every row of
    # the concatenation reads [pad.., prompt, completion, pad..]. Left-padding
    # the completions too would put pad tokens between prompt and completion
    # and condition the first completion position on a pad token.
    saved_pad_side = getattr(tokenizer, "padding_side", "right")
    try:
        tokenizer.padding_side = "left"
        enc_p = tokenizer(prompts_text, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(device)
        tokenizer.padding_side = "right"
        # No special tokens: a BOS in the middle of the sequence would be
        # treated as a completion position.
        enc_c = tokenizer(completions, padding=True, truncation=True, max_length=max_length,
                          add_special_tokens=False, return_tensors="pt").to(device)
    finally:
        # Restore even if tokenization raises.
        tokenizer.padding_side = saved_pad_side

    # Concat
    input_ids = torch.cat([enc_p["input_ids"], enc_c["input_ids"]], dim=1)
    attn_mask = torch.cat([enc_p["attention_mask"], enc_c["attention_mask"]], dim=1)

    # Forwards (batched)
    ref_logits    = ref_model(input_ids=input_ids,    attention_mask=attn_mask).logits
    target_logits = target_model(input_ids=input_ids, attention_mask=attn_mask).logits

    # Slice to completion region (next-token prediction): logits at position t
    # describe token t + 1, hence the window starts at the last prompt token.
    T_p = enc_p["input_ids"].size(1)
    T_c = enc_c["input_ids"].size(1)
    start = T_p - 1
    end   = min(start + T_c, target_logits.size(1))
    T_eff = end - start

    ref_logp = F.log_softmax(ref_logits[:,    start:end, :], dim=-1)   # [B, T_eff, V]
    tgt_logp = F.log_softmax(target_logits[:, start:end, :], dim=-1)   # [B, T_eff, V]
    ref_p    = ref_logp.exp()

    mask = enc_c["attention_mask"][:, :T_eff].to(torch.float32)        # [B, T_eff]

    per_tok_kl = (ref_p * (ref_logp - tgt_logp)).sum(dim=-1)           # [B, T_eff]
    # Sum over tokens in float32 so the per-example total is not rounded to
    # the (possibly half-precision) model dtype.
    kl_sum = (per_tok_kl.float() * mask).sum(dim=-1)                   # [B]

    del per_tok_kl, ref_logp, tgt_logp, ref_p, ref_logits, target_logits
    return kl_sum.detach().cpu()



def calculate_bias(batch, ref_model, ref_tok, dpo_model, dpo_tok, n_responses_total=10, beta=0.1, verbose=False):
    """
    Estimate one bias value per prompt from completions sampled with the
    reference model.

    For every prompt, ``n_responses_total`` completions are generated with
    ``ref_model``; the bias is

        bias = -beta * mean(KL) - mean(beta * (logp_dpo - logp_ref))

    where KL is the value returned by ``get_expected_kl_batch(ref_model, ...,
    dpo_model, ...)`` (reference distribution first) and the log-probs come
    from ``forward_batch``. Means are taken over the completions of a prompt.

    Args:
        batch: mapping whose ``"prompt"`` entry is a list of B prompts.
        ref_model, ref_tok: reference policy and its tokenizer (also used for
            generation and for the KL pass).
        dpo_model, dpo_tok: target policy and its tokenizer.
        n_responses_total: completions sampled per prompt.
        beta: scale applied to both the KL term and the log-ratio reward
            (previously hard-coded as 0.1).
        verbose: print the intermediate tensors (debug aid; off by default).

    Returns:
        ``{"bias": [...]}`` with one float per prompt (empty list for an empty batch).
    """
    prompts = batch["prompt"]                   # list of length B
    B = len(prompts)
    n = n_responses_total
    if B == 0:
        return {"bias": []}

    # 1) Generate completions for all prompts in this batch (B * n rows)
    new_ds = generate_responses(
        ref_model, ref_tok, prompts,
        max_new_tokens=512,
        n_responses_total=n,
        batch_size=4,            # generation sub-batch (tune for VRAM)
    )

    # 2) KL term on completions only (length = B*n), reshaped to [B, n].
    #    Relies on generate_responses keeping each prompt's completions contiguous.
    kl_ref = get_expected_kl_batch(ref_model, ref_tok, dpo_model, new_ds)
    kl_ref = torch.as_tensor(kl_ref, dtype=torch.float32).view(B, n)
    mean_kl_ref = -beta * kl_ref.mean(dim=1)    # [B]

    # 3) Rewards r = beta * (logp_dpo - logp_ref) for each (prompt, completion)
    s_dpo = torch.as_tensor(forward_batch(dpo_model, dpo_tok, new_ds, temperature=1.0), dtype=torch.float32)
    s_ref = torch.as_tensor(forward_batch(ref_model, ref_tok, new_ds, temperature=1.0), dtype=torch.float32)
    r = (beta * (s_dpo - s_ref)).view(B, n)     # [B, n]
    mean_r_ref = r.mean(dim=1)                  # [B]

    if verbose:
        print (kl_ref)
        print (mean_kl_ref)
        print (r)
        print (mean_r_ref)

    # 4) Bias per prompt
    bias = (mean_kl_ref - mean_r_ref).tolist()  # list length B

    return {"bias": bias}

import math
def map_rank_and_reassign_chat(
    batch,
    reward_tok,
    reward_model
):
    """
    Batched function for ``ds.map``:
      - computes reward_a1/reward_a2 via compute_rewards(...)
      - computes the Bradley-Terry probability sigmoid(reward_a1 - reward_a2)
      - reassigns preferred -> a1, dispreferred -> a2 (ties keep their order)

    Args:
        batch: mapping with list columns ``"prompt"``, ``"a1"`` and ``"a2"``.
        reward_tok: tokenizer of the reward model.
        reward_model: reward model passed on to ``compute_rewards``.

    Returns:
        Dict with the reordered ``"a1"`` / ``"a2"`` columns and ``"bt_prob_a1_pref"``.

    NOTE(review): ``bt_prob_a1_pref`` is computed from the rewards of the
    *incoming* a1/a2, i.e. before any swap. For swapped rows it is therefore
    < 0.5 even though the returned a1 is the preferred answer — confirm that
    this is the intended meaning of the column.
    """
    prompts = batch["prompt"]
    a1s = batch["a1"]
    a2s = batch["a2"]

    reward_a1, reward_a2 = compute_rewards(
        prompts, a1s, a2s,
        reward_tok=reward_tok,
        reward_model=reward_model
    )

    def _bt_prob(margin):
        # sigmoid(margin); math.exp overflows for margins below about -709,
        # where the probability is 0.0 to double precision anyway.
        try:
            return 1.0 / (1.0 + math.exp(-margin))
        except OverflowError:
            return 0.0

    margins = [s1 - s2 for s1, s2 in zip(reward_a1, reward_a2)]
    bt_prob_a1_pref = [_bt_prob(d) for d in margins]

    new_a1, new_a2 = [], []
    for a1, a2, d in zip(a1s, a2s, margins):
        if d < 0:
            # a2 scored higher: swap so the preferred answer ends up in a1
            new_a1.append(a2); new_a2.append(a1)
        else:
            # a1 preferred, or a tie: keep the original order
            new_a1.append(a1); new_a2.append(a2)

    out = {
        "a1": new_a1,
        "a2": new_a2,
        "bt_prob_a1_pref": bt_prob_a1_pref
    }
    return out

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# load_model returns (model, tokenizer): `reward_model` is the model and,
# despite its name, `reward_model_instance` holds the TOKENIZER. The later
# ds.map call passes it as `reward_tok`, so the wiring is correct even though
# the name is misleading.
reward_model, reward_model_instance = load_model(reward_model_path, task = 'reward')
# Load every split of the dataset and merge them into one Dataset.
ds_dict = load_dataset(ds_path)
ds = concatenate_datasets(list(ds_dict.values()))   

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.10it/s]


In [3]:
# Score both answers of every row with the reward model, put the preferred
# answer in `a1`, and add the `bt_prob_a1_pref` column (batches of 8 rows).
# `reward_model_instance` is the tokenizer returned by load_model.
ds_bt_pref = ds.map(
    map_rank_and_reassign_chat,
    batched=True,
    batch_size=8,
    fn_kwargs={
        "reward_tok": reward_model_instance,
        "reward_model": reward_model,
    },
)

Map: 100%|██████████| 18000/18000 [08:30<00:00, 35.26 examples/s]


In [4]:
# Upload the re-ranked dataset to the Hugging Face Hub under this repo id.
# NOTE(review): presumably requires Hub credentials with write access in the
# environment — confirm before re-running.
ds_bt_pref.push_to_hub('august66/hh_qwen2.5_1.5b_with_bias_bt_pref')

Creating parquet from Arrow format: 100%|██████████| 18/18 [00:00<00:00, 76.64ba/s]
Processing Files (0 / 0)                : |          |  0.00B /  0.00B            
[A
Processing Files (0 / 1)                :  14%|█▍        | 3.83MB / 26.7MB, 9.57MB/s  
Processing Files (0 / 1)                :  22%|██▏       | 5.93MB / 26.7MB, 9.87MB/s  
Processing Files (0 / 1)                :  38%|███▊      | 10.1MB / 26.7MB, 12.7MB/s  
Processing Files (0 / 1)                :  67%|██████▋   | 18.0MB / 26.7MB, 18.0MB/s  
Processing Files (0 / 1)                :  99%|█████████▉| 26.4MB / 26.7MB, 22.0MB/s  
[A
[A
Processing Files (1 / 1)                : 100%|██████████| 26.7MB / 26.7MB, 14.8MB/s  
[A
Processing Files (1 / 1)                : 100%|██████████| 26.7MB / 26.7MB, 13.3MB/s  
New Data Upload                         : 100%|██████████| 23.4MB / 23.4MB, 11.7MB/s  
                                        : 100%|██████████| 26.7MB / 26.7MB            
Uploading the dataset shards: 100%

CommitInfo(commit_url='https://huggingface.co/datasets/august66/hh_qwen2.5_1.5b_with_bias_bt_pref/commit/7cae78f267bc96bfa601b6d7bf092344a54d55dc', commit_message='Upload dataset', commit_description='', oid='7cae78f267bc96bfa601b6d7bf092344a54d55dc', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/august66/hh_qwen2.5_1.5b_with_bias_bt_pref', endpoint='https://huggingface.co', repo_type='dataset', repo_id='august66/hh_qwen2.5_1.5b_with_bias_bt_pref'), pr_revision=None, pr_num=None)