In [1]:
import os
import torch
import sys, pathlib
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification,DebertaV2ForSequenceClassification, GPTNeoXForCausalLM
from llm_blender.pair_ranker.pairrm import DebertaV2PairRM
from transformers import DataCollatorWithPadding
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

from torch.utils.data import DataLoader
from datasets import load_dataset, concatenate_datasets, DatasetDict, Dataset
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import re
import yaml

LOCAL_TRL_PARENT = "/workspace/Self_play_DRPO"
if LOCAL_TRL_PARENT not in sys.path:
    sys.path.insert(0, LOCAL_TRL_PARENT)

    
# now the import will use your local copy:
from trl import (
    DPOTrainer,
    DPOConfig,
    ModelConfig,
    DRPOTrainer,
    DRPOConfig,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
from trl.trainer.utils import pad, truncate_right, selective_log_softmax
from trl.data_utils import apply_chat_template

data_cache_path = "/workspace/dataset"
model_cache_path = '/workspace/model_cache'
ref_policy_path = "Qwen/Qwen2.5-1.5B-Instruct" 
#target_policy_path = "Qwen/Qwen2.5-1.5B-Instruct" 
target_policy_400_path = 'august66/hh_qwen1.5_drpo_400'
target_policy_800_path = 'august66/hh_qwen1.5_drpo_800'
dpo_policy_path = 'august66/hh_qwen_1.5b_dpo_model_2'
#reward_model_path = 'cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr'
ds_path = 'august66/drpo_hh_qwen2.5_1.5b'


def load_model(model_path, task = 'generation', model_type = 'decoder', model_cache_path =  '/workspace/model_cache'):

    model_args = ModelConfig(model_path)
    #model_torch_dtype = (model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype))
    model_torch_dtype = torch.bfloat16 
    model_kwargs = dict(
    revision = model_args.model_revision,
    torch_dtype = model_torch_dtype, 
    trust_remote_code = model_args.trust_remote_code,
    )

    padding_side = 'left' if model_type == 'decoder' else 'right'
    truncation_side = 'left' if model_type == 'decoder' else 'right'

    if task == 'generation':
        model_instance = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            **model_kwargs,
            cache_dir = model_cache_path,
        )

    elif task == 'reward':
        model_instance = AutoModelForSequenceClassification.from_pretrained(
            model_args.model_name_or_path,
            **model_kwargs,
            cache_dir = model_cache_path,
        )
    

    model_tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, 
        padding_side = padding_side, 
        truncation_side = truncation_side,
        use_fast = True,
        trust_remote_code = model_args.trust_remote_code,
        cache_dir = model_cache_path
    )

    if model_tokenizer.pad_token is None:
        model_tokenizer.pad_token = model_tokenizer.eos_token

    if getattr(model_instance.config, "pad_token_id", None) is None:
        model_instance.config.pad_token_id = model_tokenizer.pad_token_id

    if model_tokenizer.eos_token is None:
        model_tokenizer.eos_token = model_tokenizer.pad_token  

    if getattr(model_instance.config, "eos_token_id", None) is None:
        model_instance.config.eos_token_id = model_tokenizer.eos_token_id

    return model_instance, model_tokenizer


def generate_responses(
    model_instance,
    model_tokenizer,
    prompts,                      # single chat or list of chats (each: list[{role,content}])
    *,
    temperature=1.0,
    max_new_tokens=256,
    n_responses_total=1000,       # total per prompt
    responses_per_call=16,        # per subcall per prompt
    batch_size=4,                 # prompts per batch
    device='cuda',
    max_length=1024,
):
    # tokenizer prefs
    model_tokenizer.padding_side = "left"
    if hasattr(model_tokenizer, "truncation_side"):
        model_tokenizer.truncation_side = "left"
    if model_tokenizer.pad_token_id is None and model_tokenizer.eos_token_id is not None:
        model_tokenizer.pad_token_id = model_tokenizer.eos_token_id

    gen_kwargs = {
        "do_sample": bool(temperature > 0),
        "temperature": float(temperature),
        "top_k": 50,
        "top_p": 0.9,
        "max_new_tokens": max_new_tokens,
        "eos_token_id": model_tokenizer.eos_token_id,
        "pad_token_id": model_tokenizer.pad_token_id,
        "use_cache": True,
        "return_dict_in_generate": False,
        "output_scores": False,
        "output_attentions": False,
        "output_hidden_states": False,
    }

    # normalize prompts input
    if isinstance(prompts, str):
        prompts = [prompts]
    elif (isinstance(prompts, list) and prompts
          and isinstance(prompts[0], dict) and "role" in prompts[0]):
        prompts = [prompts]

    model_instance.to(device).eval()

    prompts_out, completions_out = [], []

    for s in range(0, len(prompts), batch_size):
        batch_prompts_raw = prompts[s:s+batch_size]

        # render once per batch
        batch_rendered = [
            model_tokenizer.apply_chat_template(p, add_generation_prompt=True, tokenize=False)
            for p in batch_prompts_raw
        ]
        enc = model_tokenizer(
            batch_rendered, padding=True, truncation=True,
            max_length=max_length, return_tensors="pt"
        ).to(device)
        T_in = enc["input_ids"].size(1)

        # bucket completions per prompt index within this batch
        per_prompt_bucket = [[] for _ in range(len(batch_prompts_raw))]

        remaining = n_responses_total
        while remaining > 0:
            cur_n = min(responses_per_call, remaining)

            with torch.no_grad():
                out_ids = model_instance.generate(
                    **enc, num_return_sequences=cur_n, **gen_kwargs
                )  # [B*cur_n, T_in + gen_len]

            comp_ids = out_ids[:, T_in:]
            decoded = model_tokenizer.batch_decode(comp_ids, skip_special_tokens=True)

            # outputs are stacked: p0 x cur_n, p1 x cur_n, ...
            total = comp_ids.size(0)  # B*cur_n
            for k in range(total):
                base_i = k // cur_n  # which prompt in this subcall
                per_prompt_bucket[base_i].append(decoded[k])

            # free between subcalls
            del out_ids, comp_ids, decoded
            if device == 'cuda':
                torch.cuda.empty_cache()

            remaining -= cur_n

        # now FLATTEN in prompt order so each prompt's 1..N completions are contiguous
        for i, completions in enumerate(per_prompt_bucket):
            for c in completions:
                prompts_out.append(batch_prompts_raw[i])  # original chat object
                completions_out.append(c)

        del enc
        if device == 'cuda':
            torch.cuda.empty_cache()

    return Dataset.from_dict({"prompt": prompts_out, "completion": completions_out})

def get_reward_batch(
    model_instance, 
    model_tokenizer, 
    inputs,                    # list[(prompt, response)]
    device='cuda', 
    batch_size=32
):
    model_instance.to(device)
    all_rewards = []

    for start in tqdm(range(0, len(inputs), batch_size)):
        batch_inputs = inputs[start:start+batch_size]

        # concat prompt+response text
        responses = [prompt + response for (prompt, response) in batch_inputs]

        encoded_inputs = model_tokenizer(
            responses,
            padding='max_length',
            truncation=True,
            max_length=1024,
            return_tensors='pt'
        ).to(device)

        with torch.no_grad():
            outputs = model_instance(**encoded_inputs)

        # Typical reward heads return shape [B] or [B,1] in outputs.logits
        rewards = outputs.logits.squeeze(-1).detach().cpu().tolist()
        if isinstance(rewards, list):
            all_rewards.extend(rewards)
        else:
            all_rewards.append(rewards)

        del encoded_inputs, outputs
        if device == 'cuda':
            torch.cuda.empty_cache()

    return all_rewards

@torch.no_grad()
def forward_batch(
    model,
    tokenizer,
    batch,                         # dict with batch['prompt'], batch['completion']
    *,
    temperature=1.0,
    max_length=1024,
    add_generation_prompt=True,
    device="cuda",
):
    """
    Single batched pass. Returns tensor [B] with sum log-prob of each completion
    conditioned on its prompt.
    """
    prompts      = batch["prompt"]
    completions  = [str(c) for c in batch["completion"]]
    if len(completions) == 0:
        return torch.empty(0)

    # Render chat prompts where needed
    def render(p):
        if isinstance(p, list) and p and isinstance(p[0], dict) and "role" in p[0]:
            return tokenizer.apply_chat_template(p, add_generation_prompt=add_generation_prompt, tokenize=False)
        return str(p)
    prompts_text = [render(p) for p in prompts]

    # Temporarily force LEFT padding so (T_p_max - 1) is a real token
    saved_pad_side = getattr(tokenizer, "padding_side", "right")
    tokenizer.padding_side = "left"

    model = model.to(device).eval()

    # Tokenize prompt & completion (batched)
    enc_p = tokenizer(prompts_text, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(device)
    enc_c = tokenizer(completions,  padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(device)

    # Concat (teacher forcing)
    input_ids = torch.cat([enc_p["input_ids"], enc_c["input_ids"]], dim=1)
    attn_mask = torch.cat([enc_p["attention_mask"], enc_c["attention_mask"]], dim=1)

    # Forward once
    logits = model(input_ids=input_ids, attention_mask=attn_mask).logits  # [B, T, V]
    if temperature and temperature > 0:
        logits = logits / (temperature + 1e-7)

    # Align to completion window (next-token prediction)
    T_p = enc_p["input_ids"].size(1)
    T_c = enc_c["input_ids"].size(1)
    start = T_p - 1
    end   = min(start + T_c, logits.size(1))
    T_eff = end - start

    win_logits = logits[:, start:end, :]                         # [B, T_eff, V]
    tgt_ids = enc_c["input_ids"][:, :T_eff]                      # [B, T_eff]
    tgt_msk = enc_c["attention_mask"][:, :T_eff].to(torch.bool)  # [B, T_eff]

    # Your original scoring helper
    logps  = selective_log_softmax(win_logits, tgt_ids)          # [B, T_eff]
    scores = (logps * tgt_msk).sum(dim=1)                        # [B]

    # Restore tokenizer padding side
    tokenizer.padding_side = saved_pad_side
    
    del win_logits, tgt_ids, tgt_msk, logits
    return scores.detach().cpu()


@torch.no_grad()
def get_expected_kl_batch(
    ref_model,
    tokenizer,                     # same tokenizer for both models
    target_model,
    batch,                         # dict with batch['prompt'], batch['completion']
    *,
    max_length=1024,
    add_generation_prompt=True,
    device="cuda",
):
    """
    Single batched pass. Returns python list [B] with KL per example over completion tokens.
    """
    prompts     = batch["prompt"]
    completions = [str(c) for c in batch["completion"]]
    if len(completions) == 0:
        return []

    # Render prompts
    def render(p):
        if isinstance(p, list) and p and isinstance(p[0], dict) and "role" in p[0]:
            return tokenizer.apply_chat_template(p, add_generation_prompt=add_generation_prompt, tokenize=False)
        return str(p)
    prompts_text = [render(p) for p in prompts]

    # LEFT padding
    saved_pad_side = getattr(tokenizer, "padding_side", "right")
    tokenizer.padding_side = "left"

    ref_model    = ref_model.to(device).eval()
    target_model = target_model.to(device).eval()

    # Tokenize
    enc_p = tokenizer(prompts_text, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(device)
    enc_c = tokenizer(completions,  padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(device)

    # Concat
    input_ids = torch.cat([enc_p["input_ids"], enc_c["input_ids"]], dim=1)
    attn_mask = torch.cat([enc_p["attention_mask"], enc_c["attention_mask"]], dim=1)

    # Forwards (batched)
    ref_logits    = ref_model(input_ids=input_ids,    attention_mask=attn_mask).logits
    target_logits = target_model(input_ids=input_ids, attention_mask=attn_mask).logits

    # Slice to completion region (next-token prediction)
    T_p = enc_p["input_ids"].size(1)
    T_c = enc_c["input_ids"].size(1)
    start = T_p - 1
    end   = min(start + T_c, target_logits.size(1))
    T_eff = end - start

    ref_logp = F.log_softmax(ref_logits[:,    start:end, :], dim=-1)   # [B, T_eff, V]
    tgt_logp = F.log_softmax(target_logits[:, start:end, :], dim=-1)   # [B, T_eff, V]
    ref_p    = ref_logp.exp()

    mask = enc_c["attention_mask"][:, :T_eff].to(ref_logp.dtype)       # [B, T_eff]

    per_tok_kl = (ref_p * (ref_logp - tgt_logp)).sum(dim=-1)         # [B, T_eff]
    kl_sum = (per_tok_kl * mask).sum(dim=-1)                            # [B]

    tokenizer.padding_side = saved_pad_side
    
    del per_tok_kl, ref_logp, tgt_logp, ref_p, ref_logits, target_logits
    return kl_sum.detach().cpu()



def calculate_bias(batch, ref_model, ref_tok, dpo_model, dpo_tok, n_responses_total=10):
    """
    batch['prompt'] is a list of B prompts; we generate n_responses_total completions per prompt,
    score them in batch on GPU, then average per-prompt to get one bias per prompt.
    Returns: {'bias': [B-length list]}
    """
    prompts = batch["prompt"]                   # list of length B

    # 1) Generate completions for all prompts in this batch (B * n rows)
    new_ds = generate_responses(
        ref_model, ref_tok, prompts,
        max_new_tokens=512,
        n_responses_total=n_responses_total,
        batch_size=4,            # generation sub-batch (tune for VRAM)
    )

    # 2) KL(target||ref) on completions only (length = B*n)
    kl_ref = get_expected_kl_batch(ref_model, ref_tok, dpo_model, new_ds)
    kl_ref = torch.as_tensor(kl_ref, dtype=torch.float32)
    # shape to [B, n] assuming generate_responses groups per prompt
    B = len(prompts)
    n = n_responses_total
    kl_ref = kl_ref.view(B, n)
    mean_kl_ref = (-0.1 * kl_ref.mean(dim=1))   # [B]
    print (kl_ref)
    print (mean_kl_ref)

    # 3) Rewards r = 0.1 * (logp_dpo - logp_ref) for each (prompt, completion)
    s_dpo = forward_batch(dpo_model, dpo_tok, new_ds, temperature=1.0)
    s_ref = forward_batch(ref_model, ref_tok, new_ds, temperature=1.0)
    s_dpo = torch.as_tensor(s_dpo, dtype=torch.float32)
    s_ref = torch.as_tensor(s_ref, dtype=torch.float32)
    r = 0.1 * (s_dpo - s_ref)                   # [B*n]
    r = r.view(B, n)
    mean_r_ref = r.mean(dim=1)                  # [B]
    print (r)
    print (mean_r_ref)

    # 4) Bias per prompt
    bias = (mean_kl_ref - mean_r_ref).tolist()  # list length B

    return {"bias": bias}

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
ref_instance, ref_tok = load_model(ref_policy_path, task = 'generation')
dpo_instance, dpo_tok = load_model(dpo_policy_path, task = 'generation')
ds = load_dataset(ds_path, cache_dir=data_cache_path, split = 'train')
prompts = ds['prompt']

In [3]:
test_ds = ds.select(range(100))
a = test_ds.map(
    calculate_bias,
    batched = True, 
    batch_size = 4,
    fn_kwargs = dict(
        ref_model = ref_instance, 
        ref_tok = ref_tok, 
        dpo_model = dpo_instance, 
        dpo_tok = dpo_tok,
        n_responses_total = 10
    ),
    load_from_cache_file = False
)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

tensor([[103.5000,  92.0000, 150.0000, 114.0000, 118.5000,  94.5000,  75.0000,
          89.5000, 140.0000,  79.0000],
        [ 55.7500,  55.5000, 102.5000,  89.0000,  76.0000, 104.5000,  91.5000,
          80.0000, 101.0000, 145.0000],
        [ 84.0000,  65.0000,  85.0000,  60.5000,  48.7500,  59.7500, 134.0000,
          69.5000,  66.0000,  68.5000],
        [106.0000, 105.5000, 101.0000,  60.0000, 139.0000,  99.5000,  79.0000,
         103.0000, 142.0000,  70.0000]])
tensor([-10.5600,  -9.0075,  -7.4100, -10.0500])


Map:   4%|▍         | 4/100 [00:31<12:31,  7.83s/ examples]

tensor([[ 28.5000,  31.5000,  33.7500,  71.0000,  33.0000,  26.8750,  54.7500,
          28.5000,  32.0000, 122.5000],
        [107.0000,  70.5000,  74.0000,  73.5000,  86.5000,  85.5000,  87.0000,
         144.0000, 106.0000,  49.7500],
        [ 38.5000,  36.7500,  59.7500,  38.0000,  32.0000,  39.5000,  29.5000,
          39.7500,  32.0000,  39.5000],
        [135.0000,  93.5000,  90.5000, 141.0000,  86.5000, 115.5000,  60.2500,
         100.5000,  59.2500,  87.0000]])
tensor([-4.6238, -8.8375, -3.8525, -9.6900])


Map:   8%|▊         | 8/100 [00:49<08:58,  5.85s/ examples]

tensor([[ 80.0000, 136.0000, 101.5000,  45.2500,  49.5000,  60.2500, 107.5000,
         109.0000, 101.0000,  82.5000],
        [ 86.5000, 115.0000,  77.0000, 168.0000, 119.0000, 130.0000,  53.0000,
         109.5000,  68.0000,  74.0000],
        [ 41.7500,  34.2500,  29.3750,  42.7500,  56.0000,  34.7500,  40.7500,
          53.2500,  60.7500,  34.0000],
        [ 74.0000,  59.7500, 130.0000,  82.5000,  72.0000, 114.0000,  59.5000,
         103.0000,  85.0000, 105.0000]])
tensor([ -8.7250, -10.0000,  -4.2763,  -8.8475])


Map:  12%|█▏        | 12/100 [01:10<08:15,  5.63s/ examples]

tensor([[ 49.5000,  52.0000,  28.7500,  55.0000,  56.7500,  53.0000,  56.7500,
          81.0000,  53.7500,  45.0000],
        [ 73.5000,  93.0000,  57.2500,  77.0000,  61.2500,  94.0000,  44.5000,
          91.5000,  82.0000,  88.0000],
        [ 35.0000, 126.5000,  47.2500, 137.0000,  55.5000,  96.5000, 130.0000,
         169.0000, 109.0000, 115.5000],
        [ 74.5000,  95.5000,  58.5000,  77.0000,  59.7500,  26.0000,  58.5000,
          96.5000, 103.0000,  70.5000]])
tensor([ -5.3150,  -7.6200, -10.2125,  -7.1975])


Map:  16%|█▌        | 16/100 [01:41<08:59,  6.42s/ examples]

tensor([[ 69.0000,  92.5000,  75.5000, 123.0000,  71.5000,  90.0000,  70.0000,
          70.5000, 116.5000,  93.0000],
        [ 49.5000,  88.0000,  40.7500,  36.7500,  79.0000,  59.2500,  69.5000,
          76.5000,  78.0000,  23.0000],
        [ 43.5000,  46.7500,  97.0000, 127.0000,  71.5000,  53.5000,  36.5000,
          47.0000,  47.7500,  42.0000],
        [ 60.5000, 121.5000,  98.5000, 119.0000,  83.0000, 160.0000, 149.0000,
         135.0000,  99.5000,  70.5000]])
tensor([ -8.7150,  -6.0025,  -6.1250, -10.9650])


Map:  20%|██        | 20/100 [02:11<09:08,  6.86s/ examples]

tensor([[ 70.5000,  94.5000,  91.5000,  87.5000, 112.0000,  71.5000, 105.0000,
          70.0000,  80.5000,  68.5000],
        [ 37.7500,  94.5000,  57.2500,  36.7500,  61.2500,  35.7500,  51.7500,
          74.0000,  94.5000,  87.0000],
        [116.0000,  77.5000, 116.0000,  63.7500,  75.5000, 125.5000, 108.0000,
         167.0000,  66.0000,  55.5000],
        [ 97.5000, 111.5000, 106.5000, 140.0000, 117.0000, 105.0000, 146.0000,
          77.5000, 150.0000, 125.5000]])
tensor([ -8.5150,  -6.3050,  -9.7075, -11.7650])


Map:  24%|██▍       | 24/100 [02:40<08:48,  6.96s/ examples]

tensor([[ 78.5000, 110.0000, 128.0000,  84.5000,  79.5000, 106.0000, 118.5000,
         104.5000, 105.0000, 102.0000],
        [ 21.6250,  22.3750,  29.5000,  13.7500,  22.0000,  20.2500,  22.3750,
          22.3750,  13.8750,  12.6875],
        [104.5000, 112.5000,  97.5000, 144.0000, 144.0000, 158.0000,  87.5000,
         110.5000, 110.5000,  81.5000],
        [ 93.0000,  90.0000,  78.5000,  64.0000,  85.5000, 110.0000,  28.5000,
          72.5000,  42.0000,  69.5000]])
tensor([-10.1650,  -2.0081, -11.5050,  -7.3350])


Map:  28%|██▊       | 28/100 [03:03<07:52,  6.56s/ examples]

tensor([[ 28.0000,  20.5000,  53.0000,  37.5000,  35.2500,  37.2500,  38.2500,
          21.0000,  48.7500,  35.0000],
        [ 92.5000, 102.0000,  74.5000,  67.0000,  99.0000,  56.2500,  65.0000,
          79.0000,  73.0000,  65.0000],
        [ 75.0000,  89.5000,  68.5000,  50.0000,  59.5000,  88.0000,  93.0000,
          58.7500,  99.5000,  64.5000],
        [103.5000,  66.5000,  42.2500,  64.0000,  75.5000,  78.0000,  36.0000,
          57.0000,  63.0000,  45.2500]])
tensor([-3.5450, -7.7325, -7.4625, -6.3100])


Map:  32%|███▏      | 32/100 [03:35<07:59,  7.05s/ examples]

tensor([[ 30.3750,  56.5000,  77.5000,  48.0000,  57.7500,  45.7500,  52.7500,
          69.5000,  38.2500, 100.5000],
        [ 31.3750,  30.5000,  32.0000,  24.2500,  15.6250,  63.5000,  24.0000,
          23.1250,  39.2500,  25.2500],
        [ 75.5000,  90.0000,  58.5000,  70.5000,  67.0000,  84.5000,  71.5000,
          67.5000,  51.2500,  99.5000],
        [ 54.5000,  63.5000,  92.5000,  78.5000,  80.0000,  55.7500,  68.0000,
          57.0000,  80.5000,  78.0000]])
tensor([-5.7688, -3.0888, -7.3575, -7.0825])


Map:  36%|███▌      | 36/100 [03:55<06:47,  6.36s/ examples]

tensor([[ 78.5000,  76.0000,  61.5000, 116.0000, 100.5000,  84.5000,  81.0000,
          71.5000,  71.5000,  87.5000],
        [ 56.2500,  58.0000,  21.6250,  89.5000,  50.7500,  94.0000,  74.0000,
          21.7500,  46.5000,  49.0000],
        [132.0000,  94.5000,  57.5000,  65.5000,  98.0000,  57.5000,  83.0000,
         140.0000,  89.0000, 132.0000],
        [107.0000,  88.5000,  93.0000,  61.7500, 105.5000,  95.5000,  59.5000,
          89.0000,  99.0000,  95.0000]])
tensor([-8.2850, -5.6137, -9.4900, -8.9375])


Map:  40%|████      | 40/100 [04:27<06:56,  6.94s/ examples]

tensor([[121.5000, 165.0000, 123.0000, 123.5000, 140.0000, 115.0000, 114.0000,
         125.5000, 152.0000, 103.5000],
        [134.0000,  93.5000,  55.0000, 106.0000,  87.5000,  99.0000,  81.0000,
         118.5000,  64.0000,  95.5000],
        [111.0000,  94.5000, 132.0000, 181.0000,  66.5000,  88.0000,  70.0000,
         172.0000,  79.5000, 104.5000],
        [ 95.0000,  60.2500,  55.7500,  57.0000,  37.5000,  90.0000,  68.0000,
          85.5000,  54.5000,  47.5000]])
tensor([-12.8300,  -9.3400, -10.9900,  -6.5100])


Map:  44%|████▍     | 44/100 [04:57<06:35,  7.07s/ examples]

tensor([[ 56.7500,  54.7500,  39.7500,  96.5000,  58.5000,  63.5000,  62.5000,
          87.0000,  64.5000,  47.0000],
        [106.0000, 121.5000, 137.0000, 124.0000,  68.5000, 150.0000, 120.5000,
         111.5000,  96.5000,  85.5000],
        [ 61.0000,  99.0000,  96.5000,  75.0000,  63.7500, 118.0000,  74.0000,
          50.5000,  72.0000, 147.0000],
        [ 31.3750,  43.0000,  85.5000,  64.0000,  44.2500,  79.5000,  40.0000,
          58.2500,  84.5000,  35.5000]])
tensor([ -6.3075, -11.2100,  -8.5675,  -5.6588])


Map:  48%|████▊     | 48/100 [05:22<05:54,  6.81s/ examples]

tensor([[ 98.0000,  65.5000,  39.0000, 126.5000, 105.0000,  86.0000,  52.5000,
         124.0000,  42.7500, 118.0000],
        [ 61.5000, 103.5000,  50.5000,  83.0000,  65.5000,  24.2500,  84.5000,
          73.5000, 124.5000,  85.0000],
        [ 35.0000, 178.0000,  53.0000, 112.0000,  24.6250, 106.0000,  70.5000,
          96.0000, 243.0000,  59.7500],
        [ 49.5000,  61.2500,  23.3750, 132.0000, 156.0000,  96.0000,  44.0000,
         103.0000,  68.5000,  63.5000]])
tensor([-8.5725, -7.5575, -9.7788, -7.9713])


Map:  52%|█████▏    | 52/100 [05:50<05:30,  6.89s/ examples]

tensor([[ 87.5000,  78.5000,  80.0000,  82.5000,  53.0000,  90.5000,  69.5000,
          96.5000,  33.7500,  91.0000],
        [ 78.0000,  37.2500,  34.2500, 103.5000,  46.7500,  56.5000,  77.5000,
          74.0000,  49.7500,  65.0000],
        [ 42.5000,  37.2500,  42.0000,  51.7500,  35.5000,  65.0000,  42.2500,
          58.7500,  55.7500,  49.0000],
        [ 36.5000,  29.7500, 153.0000,  79.0000, 129.0000,  82.5000,  35.2500,
          36.0000,  63.0000,  35.7500]])
tensor([-7.6275, -6.2250, -4.7975, -6.7975])


Map:  56%|█████▌    | 56/100 [06:16<04:59,  6.80s/ examples]

tensor([[ 84.0000,  96.0000, 139.0000,  52.7500,  81.0000, 127.0000, 153.0000,
          89.5000, 107.0000, 124.5000],
        [ 65.0000,  98.0000,  43.7500,  60.5000,  89.5000,  85.5000,  63.7500,
         112.0000,  79.5000,  73.5000],
        [ 99.5000, 139.0000, 126.0000,  47.7500,  85.5000,  63.2500,  71.5000,
          86.0000,  66.5000,  48.0000],
        [ 87.0000,  64.5000,  87.0000, 103.5000,  98.5000, 101.5000,  54.2500,
         103.5000,  59.0000,  70.5000]])
tensor([-10.5375,  -7.7100,  -8.3300,  -8.2925])


Map:  60%|██████    | 60/100 [06:44<04:33,  6.85s/ examples]

tensor([[147.0000, 161.0000,  82.0000, 114.0000, 115.5000,  65.5000, 142.0000,
          47.2500,  76.5000, 168.0000],
        [ 83.0000,  66.5000,  52.7500,  60.0000,  64.5000,  55.0000,  82.5000,
          81.0000, 132.0000, 115.5000],
        [ 69.0000,  87.5000,  58.2500,  91.0000,  75.5000, 141.0000,  65.5000,
          57.2500,  85.0000,  84.5000],
        [ 77.0000,  65.5000,  70.5000,  89.0000,  59.7500,  99.0000,  75.0000,
          83.0000,  85.5000, 103.5000]])
tensor([-11.1875,  -7.9275,  -8.1450,  -8.0775])


Map:  64%|██████▍   | 64/100 [07:14<04:12,  7.01s/ examples]

tensor([[ 89.5000, 125.5000,  66.0000, 119.0000, 106.0000,  91.0000,  65.5000,
         107.5000,  93.0000,  59.0000],
        [ 68.0000,  54.0000,  84.5000,  36.7500,  27.1250,  53.5000,  55.5000,
          41.7500,  50.5000,  43.7500],
        [ 44.5000,  87.0000, 102.0000,  49.2500,  39.5000,  91.5000,  68.0000,
          49.2500,  83.0000,  48.2500],
        [ 63.2500,  54.5000,  71.5000,  64.5000,  61.0000,  55.7500,  79.5000,
          95.0000,  75.5000, 124.0000]])
tensor([-9.2200, -5.1537, -6.6225, -7.4450])


Map:  68%|██████▊   | 68/100 [07:36<03:29,  6.54s/ examples]

tensor([[108.0000, 165.0000, 109.5000, 116.5000, 125.5000,  99.5000, 150.0000,
         106.0000,  72.0000,  75.0000],
        [ 74.5000, 188.0000,  83.0000,  65.0000, 131.0000,  71.0000,  55.5000,
         198.0000,  88.5000,  39.2500],
        [ 41.0000,  83.0000,  78.0000,  49.5000,  54.7500,  65.0000,  69.0000,
          84.0000,  52.7500,  34.5000],
        [ 91.5000,  57.5000,  95.5000,  61.7500, 174.0000, 114.5000, 104.0000,
         149.0000,  67.0000,  57.0000]])
tensor([-11.2700,  -9.9375,  -6.1150,  -9.7175])


Map:  72%|███████▏  | 72/100 [08:05<03:08,  6.75s/ examples]

tensor([[ 91.5000,  45.5000, 112.0000, 131.0000, 101.5000,  53.2500, 109.5000,
          74.5000,  50.5000,  86.0000],
        [ 92.0000,  71.5000,  55.2500, 108.0000,  86.5000,  89.0000, 140.0000,
          56.5000,  54.2500,  58.2500],
        [ 79.5000,  72.5000, 106.0000, 108.0000,  41.7500,  52.7500,  67.5000,
          63.2500, 140.0000,  69.0000],
        [103.0000,  73.0000,  89.5000,  64.0000,  56.2500,  50.5000, 126.5000,
         132.0000, 112.5000,  60.2500]])
tensor([-8.5525, -8.1125, -8.0025, -8.6750])


Map:  76%|███████▌  | 76/100 [08:30<02:38,  6.60s/ examples]

tensor([[ 84.5000,  66.0000,  54.0000,  43.5000, 104.5000,  98.0000,  98.0000,
          41.7500,  77.5000, 103.0000],
        [112.5000,  87.5000, 141.0000, 131.0000,  64.5000,  51.5000, 130.0000,
         122.0000,  96.5000, 158.0000],
        [274.0000,  85.5000,  78.0000,  76.5000,  68.5000,  51.0000,  45.5000,
          62.7500,  43.7500,  47.7500],
        [ 38.0000,   8.6250,   8.6250,  26.5000,  26.7500,   9.9375,  37.2500,
          26.7500,   9.9375,  11.0625]])
tensor([ -7.7075, -10.9450,  -8.3325,  -2.0344])


Map:  80%|████████  | 80/100 [08:57<02:14,  6.71s/ examples]

tensor([[141.0000, 123.5000, 120.5000, 131.0000, 132.0000, 113.5000, 117.5000,
         160.0000, 126.0000, 114.5000],
        [ 90.5000,  11.1250,  66.0000,  58.7500,  52.0000,  58.0000, 101.0000,
          57.7500,  50.5000,  48.0000],
        [ 38.5000,  34.5000,  39.5000,  39.5000,  36.7500,  40.2500,  40.2500,
          27.6250,  39.5000,  36.7500],
        [121.5000, 104.5000, 125.0000, 121.0000,  85.0000,  53.5000,  96.5000,
          85.0000, 112.5000,  70.0000]])
tensor([-12.7950,  -5.9363,  -3.7313,  -9.7450])


Map:  84%|████████▍ | 84/100 [09:30<01:54,  7.15s/ examples]

tensor([[ 57.5000,  55.2500,  72.5000, 107.5000, 105.0000,  73.0000, 135.0000,
          68.0000,  42.7500,  45.2500],
        [ 64.5000,  64.5000,  56.0000,  64.5000,  65.0000,  51.2500,  76.5000,
          41.0000,  41.0000,  78.5000],
        [ 20.1250,  25.5000,  20.1250,  25.5000,  25.5000,  27.0000,  22.1250,
          31.8750,  30.0000,  31.8750],
        [ 66.5000,  74.0000,  77.0000,  70.5000,  75.5000, 123.5000,  78.0000,
          80.0000,  45.2500, 122.5000]])
tensor([-7.6175, -6.0275, -2.5963, -8.1275])


Map:  88%|████████▊ | 88/100 [09:58<01:25,  7.09s/ examples]

tensor([[  7.9688,   7.9688,   7.9688,  12.1250,   5.0625,   7.9688,   7.9688,
          14.2500,  13.3750,  14.7500],
        [ 57.0000,  79.5000,  84.0000,  49.2500,  74.0000, 118.5000,  47.0000,
         139.0000,  98.5000,  51.2500],
        [ 32.2500,  30.1250,  30.1250,  41.2500,  30.1250,  32.2500,  30.1250,
          30.1250,  32.2500,  41.2500],
        [101.0000,  67.0000,  50.0000,  87.5000,  70.0000,  95.5000,  60.2500,
          52.7500,  48.5000, 159.0000]])
tensor([-0.9941, -7.9800, -3.2987, -7.9150])


Map:  92%|█████████▏| 92/100 [10:17<00:51,  6.39s/ examples]

tensor([[ 43.0000,  48.0000,  82.5000,  49.7500,  70.5000,  87.5000,  30.7500,
          78.5000,  86.0000,  63.5000],
        [ 66.0000,  86.5000,  80.5000,  58.5000, 133.0000,  58.5000,  55.2500,
          93.5000,  71.5000,  68.0000],
        [ 24.7500,  18.7500,  27.2500,  27.8750,  30.2500,  24.8750,  22.0000,
          28.7500,  13.7500,  23.3750],
        [108.0000, 149.0000, 113.5000, 108.0000, 116.5000, 114.0000,  84.0000,
         118.5000,  93.5000,  72.0000]])
tensor([ -6.4000,  -7.7125,  -2.4162, -10.7700])


Map:  96%|█████████▌| 96/100 [10:38<00:24,  6.05s/ examples]

tensor([[ 77.5000,  70.0000,  98.0000, 132.0000,  76.5000,  70.0000,  55.7500,
         127.5000, 149.0000,  47.5000],
        [ 84.5000,  81.0000, 131.0000,  45.7500,  66.0000,  94.0000, 110.0000,
          75.0000, 143.0000, 115.5000],
        [ 99.0000,  76.0000,  92.5000,  80.0000,  67.0000,  53.7500, 122.0000,
         106.5000,  43.2500, 150.0000],
        [ 69.0000, 102.5000,  73.0000,  99.5000,  89.0000,  50.5000,  67.0000,
         179.0000, 103.5000,  57.2500]])
tensor([-9.0375, -9.4575, -8.9000, -8.9025])


Map: 100%|██████████| 100/100 [10:56<00:00,  6.57s/ examples]


In [13]:
a[1]

{'prompt': [{'content': 'How do I teach kids to meditate?', 'role': 'user'},
  {'content': 'Great question! That’s a really useful skill to cultivate, it can bring peace, calm, and happiness. I’m glad you want to teach your kids about it.',
   'role': 'assistant'},
  {'content': 'All right, so how do we start?', 'role': 'user'},
  {'content': 'Well, we can get started with just being silent. You can tell the kids it’s okay if they just sit there quietly for a few minutes without thinking of anything.',
   'role': 'assistant'},
  {'content': 'any other ideas? they are fidgeting', 'role': 'user'}],
 'a1': [{'content': "Sure! Another idea is to use guided imagery meditation. This involves visualizing a peaceful place or scene in your mind's eye. For example, imagine yourself on a beach with the sound of waves crashing gently in the background.\n\nAnother technique is mindfulness meditation, where you focus on your breath as it flows through your body. As you breathe, notice any thoughts t

In [None]:
p = prompts[0]

In [15]:
new_ds = generate_responses(
    ref_instance, ref_tok, p,
    max_new_tokens=512,
    n_responses_total=20,
    batch_size=4,            # generation sub-batch (tune for VRAM)
)

In [16]:
-0.1 * get_expected_kl_batch(ref_instance, ref_tok, dpo_instance, new_ds).mean()

tensor(-8.3750, dtype=torch.bfloat16)

In [17]:
0.1*(forward_batch(dpo_instance, dpo_tok, new_ds) - forward_batch(ref_instance, ref_tok, new_ds)).mean()

tensor(-5.6562, dtype=torch.bfloat16)

In [5]:
a['bias']

Column([-1.6599998474121094, -3.5874991416931152, -2.859999656677246, -3.2300004959106445, -1.7362501621246338])

In [6]:
a.push_to_hub('august66/hh_qwen2.5_1.5b_with_bias_test_100', split = 'train')

Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 202.46ba/s]
Processing Files (0 / 0)                : |          |  0.00B /  0.00B            
[A
Processing Files (0 / 1)                :  14%|█▍        | 21.6kB /  155kB,  108kB/s  
[A
Processing Files (1 / 1)                : 100%|██████████|  155kB /  155kB,  259kB/s  
[A
Processing Files (1 / 1)                : 100%|██████████|  155kB /  155kB,  194kB/s  
New Data Upload                         : 100%|██████████|  134kB /  134kB,  167kB/s  
                                        : 100%|██████████|  155kB /  155kB            
Uploading the dataset shards: 100%|██████████| 1/1 [00:01<00:00,  1.38s/ shards]


CommitInfo(commit_url='https://huggingface.co/datasets/august66/hh_qwen2.5_1.5b_with_bias_test_100/commit/5e0bcd79c4f9c9ea6e77df025fb4a198b9f021d4', commit_message='Upload dataset', commit_description='', oid='5e0bcd79c4f9c9ea6e77df025fb4a198b9f021d4', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/august66/hh_qwen2.5_1.5b_with_bias_test_100', endpoint='https://huggingface.co', repo_type='dataset', repo_id='august66/hh_qwen2.5_1.5b_with_bias_test_100'), pr_revision=None, pr_num=None)

In [None]:
import math
import torch
from datasets import load_dataset, Dataset

# --- assume your helpers are already defined above in the notebook ---
# load_model, generate_responses, forward_batch, get_expected_kl_batch, calculate_bias
# ref_policy_path, dpo_policy_path, ds_path, model_cache_path

# ---------------------- load models/tokenizers ----------------------
ref_instance, ref_tok = load_model(ref_policy_path, task="generation", model_type="decoder", model_cache_path=model_cache_path)
dpo_instance, dpo_tok = load_model(dpo_policy_path, task="generation", model_type="decoder", model_cache_path=model_cache_path)

# ---------------------- load dataset & pick 5 prompts ----------------------
def _first_available_split(dset_dict):
    for key in ["train", "validation", "test"]:
        if key in dset_dict:
            return dset_dict[key]
    # fall back to the first split
    return dset_dict[list(dset_dict.keys())[0]]

ds_all = load_dataset(ds_path)
ds = _first_available_split(ds_all)
prompts5 = ds["prompt"][:5]  # works whether prompt is str or chat list

# ---------------------- small batching wrappers to avoid OOM ----------------------
def chunked_forward_scores(model, tok, pairs_ds: Dataset, chunk=16, temperature=1.0, device="cuda"):
    out = []
    N = len(pairs_ds)
    for i in range(0, N, chunk):
        sub = pairs_ds[i : i + chunk]
        sub_dict = {"prompt": sub["prompt"], "completion": sub["completion"]}
        sc = forward_batch(model, tok, sub_dict, temperature=temperature, device=device)  # CPU tensor
        out.append(sc)
        del sc
        if device == "cuda":
            torch.cuda.empty_cache()
    return torch.cat(out) if len(out) > 1 else out[0]

def chunked_expected_kl(ref_model, tok, tgt_model, pairs_ds: Dataset, chunk=16, device="cuda"):
    out = []
    N = len(pairs_ds)
    for i in range(0, N, chunk):
        sub = pairs_ds[i : i + chunk]
        sub_dict = {"prompt": sub["prompt"], "completion": sub["completion"]}
        kl = get_expected_kl_batch(ref_model, tok, tgt_model, sub_dict, device=device)  # CPU tensor
        out.append(kl)
        if device == "cuda":
            torch.cuda.empty_cache()
    return torch.cat(out) if len(out) > 1 else out[0]

def mse(a: torch.Tensor, b: torch.Tensor) -> float:
    return float(((a - b) ** 2).mean().item())

# ---------------------- main loop ----------------------
TEMPERATURE_BIAS = 1.0         # used inside your helper defaults
TEMPERATURE_EVAL = 1.0         # for scoring
MAX_NEW_TOKENS = 512
N_BIAS = 10
N_EVAL = 500
CHUNK = 16                     # safe GPU chunk size
SCALE = 0.1                    # your 0.1 factor for r and mu_ref

results = []

for pi, p in enumerate(prompts5, start=1):
    print(f"\n==== Prompt {pi}/5 ====")

    # --- 1) estimate bias with 10 samples (from ref) ---
    bias_dict = calculate_bias({"prompt": [p]}, ref_instance, ref_tok, dpo_instance, dpo_tok, n_responses_total=N_BIAS)
    bias = float(bias_dict["bias"][0])
    print(f"bias (N={N_BIAS}): {bias:.6f}")

    # --- 2) sample another 500 responses (from ref) for evaluation ---
    eval_pairs = generate_responses(
        ref_instance, ref_tok, [p],
        temperature=TEMPERATURE_BIAS, max_new_tokens=MAX_NEW_TOKENS,
        n_responses_total=N_EVAL, responses_per_call=16, batch_size=4,
    )  # Dataset with columns: prompt, completion

    # --- 3) compute per-sample mu_ref = -0.1 * KL(ref || dpo) ---
    kl_vec = chunked_expected_kl(ref_instance, ref_tok, dpo_instance, eval_pairs, chunk=CHUNK, device="cuda")  # tensor [N_EVAL]
    mu_ref = (-SCALE * kl_vec).to(torch.float32)  # [N_EVAL]

    # --- 4) compute per-sample r = 0.1 * (logp_dpo - logp_ref) ---
    s_dpo = chunked_forward_scores(dpo_instance, dpo_tok, eval_pairs, chunk=CHUNK, temperature=TEMPERATURE_EVAL, device="cuda")  # [N_EVAL]
    s_ref = chunked_forward_scores(ref_instance, ref_tok, eval_pairs, chunk=CHUNK, temperature=TEMPERATURE_EVAL, device="cuda")  # [N_EVAL]
    r = (SCALE * (s_dpo - s_ref)).to(torch.float32)  # [N_EVAL]

    # --- 5) MSEs: raw vs. bias-corrected ---
    mse_raw = mse(r, mu_ref)
    mu_ref_bias_corrected = mu_ref - bias
    mse_corrected = mse(r, mu_ref_bias_corrected)
    improvement = (mse_raw - mse_corrected) / max(1e-12, mse_raw)

    print(f"MSE(r, mu_ref)             : {mse_raw:.6f}")
    print(f"MSE(r, mu_ref - bias)      : {mse_corrected:.6f}")
    print(f"Relative improvement       : {100.0*improvement:.2f}%")

    results.append({
        "prompt_index": pi-1,
        "bias": bias,
        "mse_raw": mse_raw,
        "mse_corrected": mse_corrected,
        "improvement": improvement,
    })

# Optional: results as a small table
try:
    import pandas as pd
    df = pd.DataFrame(results)
    print("\nSummary:")
    print(df.to_string(index=False))
except Exception:
    pass

Generating train split: 100%|██████████| 43835/43835 [00:00<00:00, 290013.61 examples/s]



==== Prompt 1/5 ====
tensor([[113.0000, 210.0000,  67.5000,  94.0000, 122.0000, 100.5000,  93.5000,
          89.5000, 107.5000,  83.0000]])
tensor([-10.8050])
tensor([[-11.2000, -16.4000,  -7.2000,  -4.0000, -10.4000,  -6.4000,  -6.0000,
          -5.8000,  -8.6000,  -7.8000]])
tensor([-8.3800])
bias (N=10): -2.425000
MSE(r, mu_ref)             : 10.924215
MSE(r, mu_ref - bias)      : 5.695763
Relative improvement       : 47.86%

==== Prompt 2/5 ====
tensor([[135.0000, 121.0000,  85.5000, 116.0000, 129.0000, 107.0000, 134.0000,
         110.0000, 145.0000,  90.5000]])
tensor([-11.7300])
tensor([[-12.8000, -10.8000,  -5.0000,  -4.6000, -10.2000, -10.4000, -11.4000,
          -4.8000,  -7.6000,  -4.2000]])
tensor([-8.1800])
bias (N=10): -3.550000
MSE(r, mu_ref)             : 15.560519
MSE(r, mu_ref - bias)      : 5.289621
Relative improvement       : 66.01%

==== Prompt 3/5 ====
tensor([[ 79.5000,  81.0000,  93.0000,  74.5000,  52.0000, 106.0000,  72.0000,
          57.2500,  59.2500, 