In [None]:
import os
import torch
import sys, pathlib
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification,DebertaV2ForSequenceClassification, GPTNeoXForCausalLM
from llm_blender.pair_ranker.pairrm import DebertaV2PairRM
from transformers import DataCollatorWithPadding
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

from torch.utils.data import DataLoader
from datasets import load_dataset, concatenate_datasets, DatasetDict
import torch.nn.functional as F
from tqdm import tqdm
import re
import yaml

LOCAL_TRL_PARENT = "/workspace/Self_play_DRPO"
if LOCAL_TRL_PARENT not in sys.path:
    sys.path.insert(0, LOCAL_TRL_PARENT)

    
# now the import will use your local copy:
from trl import (
    DPOTrainer,
    DPOConfig,
    ModelConfig,
    DRPOTrainer,
    DRPOConfig,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
from trl.data_utils import apply_chat_template

data_cache_path = "/workspace/dataset"
model_cache_path = '/workspace/model_cache'
ref_policy_path = "Qwen/Qwen2.5-1.5B-Instruct" 
#target_policy_path = "Qwen/Qwen2.5-1.5B-Instruct" 
target_policy_500 = '/workspace/Self_play_DRPO/self_play_drpo_code/output/hh/0/checkpoint-600'
target_policy_1000 = '/workspace/Self_play_DRPO/self_play_drpo_code/output/hh/0/checkpoint-1000'
dpo_policy_path = 'august66/ultrafeedback_20k_qwen_1.5b_dpo_model'
#reward_model_path = 'cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr'
ds_path = 'august66/drpo_hh_qwen2.5_1.5b'


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_model(model_path, task = 'generation', model_type = 'decoder', model_cache_path =  '/workspace/model_cache'):

    model_args = ModelConfig(model_path)
    model_torch_dtype = (model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype))
    model_kwargs = dict(
    revision = model_args.model_revision,
    torch_dtype = model_torch_dtype, 
    trust_remote_code = model_args.trust_remote_code,
    )

    padding_side = 'left' if model_type == 'decoder' else 'right'
    truncation_side = 'left' if model_type == 'decoder' else 'right'

    if task == 'generation':
        model_instance = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            **model_kwargs,
            cache_dir = model_cache_path,
        )

    elif task == 'reward':
        model_instance = AutoModelForSequenceClassification.from_pretrained(
            model_args.model_name_or_path,
            **model_kwargs,
            cache_dir = model_cache_path,
        )
    

    model_tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, 
        padding_side = padding_side, 
        truncation_side = truncation_side,
        use_fast = True,
        trust_remote_code = model_args.trust_remote_code,
        cache_dir = model_cache_path
    )

    if model_tokenizer.pad_token is None:
        model_tokenizer.pad_token = model_tokenizer.eos_token

    if getattr(model_instance.config, "pad_token_id", None) is None:
        model_instance.config.pad_token_id = model_tokenizer.pad_token_id

    if model_tokenizer.eos_token is None:
        model_tokenizer.eos_token = model_tokenizer.pad_token  

    if getattr(model_instance.config, "eos_token_id", None) is None:
        model_instance.config.eos_token_id = model_tokenizer.eos_token_id

    return model_instance, model_tokenizer


In [80]:
ref_model_instance, ref_model_tokenizer = load_model(ref_policy_path, task = 'generation')
target_model_instance, target_model_tokenizer = load_model(target_policy_path, task = 'generation')
dpo_policy_instance, dpo_policy_tokenizer = load_model(dpo_polict_path, task = 'generation')
#reward_model_instance, reward_model_tokenizer = load_model(reward_model_path, task ='reward')

ds = load_dataset(ds_path, cache_dir=data_cache_path, split = 'train')
prompts = ds['prompt']

Fetching 2 files: 100%|██████████| 2/2 [00:51<00:00, 25.88s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 40.94it/s]


In [49]:
def generate_response(model_instance, model_tokenizer, prompts, 
                      temperature=0.0, max_new_tokens=256, 
                      n_responses=2, batch_size=8, device='cuda'):
    generation_kwargs = {
        "top_k": 50,
        "top_p": 0.9,
        "temperature": temperature,
        "do_sample": True if temperature > 0 else False,
        "eos_token_id": model_tokenizer.eos_token_id
    }

    if isinstance(prompts, str):
        prompts = [prompts]

    model_instance.to(device)
    all_prompt_responses = []

    # loop over batches
    for start in tqdm(range(0, len(prompts), batch_size)):
        
        batch_prompts = prompts[start:start+batch_size]
        batch_prompts = [model_tokenizer.apply_chat_template(
            p,
            add_generation_prompt = True,
            tokenize = False
        ) for p in batch_prompts]
        print (batch_prompts)
        encoded_inputs = model_tokenizer(
            batch_prompts, 
            padding=True,
            truncation=True,
            max_length=1024,
            return_tensors="pt"
        ).to(device)


        prompt_lengths = encoded_inputs.attention_mask.sum(-1).tolist()

        outputs = model_instance.generate(
            **encoded_inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=n_responses,
            **generation_kwargs
        )

        generated_tokens = []
        for i in range(len(batch_prompts)):
            prompt_length = prompt_lengths[i]
            for j in range(n_responses):
                idx = i * n_responses + j
                generated_tokens.append(outputs[idx][prompt_length:])

        decoded_responses = model_tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )

        for i, prompt in enumerate(batch_prompts):
            for j in range(n_responses):
                idx = i * n_responses + j
                all_prompt_responses.append([prompt, decoded_responses[idx]])

        del encoded_inputs, outputs, generated_tokens, decoded_responses
        if device == 'cuda':
            torch.cuda.empty_cache()

    return all_prompt_responses



def get_reward_batch(
    model_instance, 
    model_tokenizer, 
    inputs,                    # list[(prompt, response)]
    device='cuda', 
    batch_size=32
):
    model_instance.to(device)
    all_rewards = []

    for start in tqdm(range(0, len(inputs), batch_size)):
        batch_inputs = inputs[start:start+batch_size]

        # concat prompt+response text
        responses = [prompt + response for (prompt, response) in batch_inputs]

        encoded_inputs = model_tokenizer(
            responses,
            padding='max_length',
            truncation=True,
            max_length=1024,
            return_tensors='pt'
        ).to(device)

        with torch.no_grad():
            outputs = model_instance(**encoded_inputs)

        # Typical reward heads return shape [B] or [B,1] in outputs.logits
        rewards = outputs.logits.squeeze(-1).detach().cpu().tolist()
        if isinstance(rewards, list):
            all_rewards.extend(rewards)
        else:
            all_rewards.append(rewards)

        del encoded_inputs, outputs
        if device == 'cuda':
            torch.cuda.empty_cache()

    return all_rewards





@torch.no_grad()
def get_expected_kl(
    ref_model,
    tokenizer,                 # same tokenizer for both models
    target_model,
    inputs,                    # list[[prompt, response], ...]
    device='cuda',
    batch_size=8,
    max_length=1024,
    response_only=False,        # if True, only score response tokens
):
    ref_model.to(device).eval()
    target_model.to(device).eval()

    seq_diff_all = []

    for start in tqdm(range(0, len(inputs), batch_size)):
        batch_pairs = inputs[start:start+batch_size]
        prompts  = [p for p, r in batch_pairs]
        full_txt = [p + r for p, r in batch_pairs]   # concatenate exactly as specified

        # Encode full sequences (prompt+response)
        enc_full = tokenizer(
            full_txt,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        ).to(device)

        # Encode prompts alone to get per-sample prompt token counts
        enc_prompt = tokenizer(
            prompts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        ).to(device)

        input_ids = enc_full["input_ids"]        # [B, T]
        attn_mask = enc_full["attention_mask"]   # [B, T]
        prompt_len = enc_prompt["attention_mask"].sum(dim=1).to(torch.long)  # [B]

        # Forward both models
        ref_logits    = ref_model(input_ids=input_ids,    attention_mask=attn_mask).logits   # [B, T, V]
        target_logits = target_model(input_ids=input_ids, attention_mask=attn_mask).logits   # [B, T, V]

        # Teacher-forced next-token logprobs
        ref_logp = F.log_softmax(ref_logits[:, :-1, :], dim=-1)      # [B, T-1, V]
        tgt_logp = F.log_softmax(target_logits[:, :-1, :], dim=-1)   # [B, T-1, V]
        tgt_p = tgt_logp.exp()
        labels   = input_ids[:, 1:]                                   # [B, T-1]
        mask     = attn_mask[:, 1:].to(ref_logp.dtype).clone()        # [B, T-1]

        # If we only want the response, drop prompt contributions per sample.
        if response_only:
            # After shifting, tokens with position < (prompt_len - 1) belong to the prompt
            cuts = (prompt_len - 1).clamp_min(0)                      # [B]
            # Zero out up to 'cut' for each row
            # Note: sequences are padded; mask already 0 on pad positions
            for i, cut in enumerate(cuts.tolist()):
                if cut > 0:
                    cut = min(cut, mask.shape[1])  # guard against extreme lengths
                    mask[i, :cut] = 0.0


        per_token = (tgt_p * (tgt_logp - ref_logp)).sum(dim = -1)
        diff = (per_token*mask).sum(dim = -1)

        seq_diff_all.append(diff.detach().cpu())

        # cleanup
        del enc_full, enc_prompt, ref_logits, target_logits, ref_logp, tgt_logp, tgt_p, labels, mask
        if device == 'cuda':
            torch.cuda.empty_cache()

    diff_list = torch.cat(seq_diff_all).tolist()

    
    return diff_list



    

In [59]:
temperature = [1.0]
total = 10
step = 10
prompt = ds['prompt'][10000]
if not isinstance(prompt, list) or not isinstance(prompt[0], list):
    prompt = [prompt]
logp_diff_dict = {}
for t in temperature:
    outputs = []
    for i in tqdm(range(0, total, step)):
        output = generate_response(target_model_instance, target_model_tokenizer, prompt, n_responses=step, temperature = t)
        outputs.extend(output)
    logp_diff = get_expected_kl(
            ref_model_instance, ref_model_tokenizer, target_model_instance, outputs, response_only = True
    )
    logp_diff_dict[t] = logp_diff
    

  0%|          | 0/1 [00:00<?, ?it/s]

['<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhy does water become solid when it is boiling?<|im_end|>\n<|im_start|>assistant\n']


100%|██████████| 1/1 [00:11<00:00, 11.76s/it]
100%|██████████| 1/1 [00:11<00:00, 11.76s/it]
100%|██████████| 2/2 [00:01<00:00,  1.19it/s]


In [89]:
def _calculate_kl_divergence(target_model, ref_model, prompt_ids, prompt_attention_mask, completion_ids, completion_attention_mask, temperature=1.0):

    target_model.to('cuda')
    ref_model.to('cuda')
    max_length = 1024
    # Get the number of tokens to truncate from prompt
    num_tokens_to_truncate = max(prompt_ids.size(1) + completion_ids.size(1) - max_length, 0)

    # Truncate left to avoid oom
    prompt_ids = prompt_ids[:, num_tokens_to_truncate:]
    prompt_attention_mask = prompt_attention_mask[:, num_tokens_to_truncate:]

    # Concat the prompt and completion
    prompt_completion_ids = torch.cat((prompt_ids, completion_ids), dim=1)
    prompt_completion_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1)

    # Get the logps of the completions from the model
    logits_target = target_model(prompt_completion_ids, attention_mask=prompt_completion_mask).logits
    logits_ref = ref_model(prompt_completion_ids, attention_mask=prompt_completion_mask).logits

    prompt_length = prompt_ids.size(1)
    logits_completion_target = logits_target[:, max(0, prompt_length - 1) : -1, :]
    logits_completion_ref = logits_ref[:, max(0, prompt_length - 1) : -1, :]

    if completion_ids.size(1) > logits_target.size(1):
        completion_ids = completion_ids[:, : logits_target.size(1)]

    mask = completion_attention_mask[:, 1:].to(torch.float32).clone()
    logp_completion_target = F.log_softmax(logits_completion_target / (temperature + 1e-7), dim=-1)
    logp_completion_ref = F.log_softmax(logits_completion_ref / (temperature + 1e-7), dim=-1)
    p_target= logp_completion_target.exp()

    kl_divergence = (p_target * (logp_completion_target - logp_completion_ref) * mask.unsqueeze(-1)).sum(-1).sum(-1)

    return kl_divergence

In [63]:
prompt = outputs[0][0]
completion = outputs[0][1]

In [None]:
encoded_prompts = target_model_tokenizer(
            prompt,
            padding='max_length',
            max_length = 1024,
            return_tensors='pt'
).to('cuda')
encoded_completions = target_model_tokenizer(
            completion,
            padding='max_length',
            max_length = 1024, 
            return_tensors='pt'
).to('cuda')
prompt_id = encoded_prompts['input_ids']
prompt_attention_mask = encoded_prompts['attention_mask']
completion_id = encoded_completions['input_ids']    
completion_attention_mask = encoded_completions['attention_mask']

In [90]:
_calculate_kl_divergence(target_model_instance, dpo_policy_instance, prompt_id, prompt_attention_mask, completion_id, completion_attention_mask, temperature=1.0)

tensor([13.4464], device='cuda:0', grad_fn=<SumBackward1>)