In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from torch.nn.functional import log_softmax, logsigmoid
import tqdm
# 参数
model_name = "princeton-nlp/Llama-3-Base-8B-SFT"  # 可换成其他支持 causal LM 的模型
dataset_name = "argilla/dpo-mix-7k"
beta = 0.1
device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. 加载 tokenizer 和模型
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16 if "cuda" in device else torch.float32).to(device)
model.eval()

# 2. 加载偏好数据集（仅用前几个样本做演示）
# dataset = load_dataset(dataset_name, split="train[:64]")  # 用全部数据可改成 "train"
dataset = load_dataset(dataset_name, split="train")  # 用全部数据可改成 "train"

# 3. 定义 log probability 计算函数
def compute_logps(texts):
    # padding to longest
    encodings = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
    input_ids = encodings.input_ids.to(device)
    attention_mask = encodings.attention_mask.to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # shape: (B, T, V)
        log_probs = log_softmax(logits, dim=-1)

    # shift tokens and gather correct token logprobs
    shift_input_ids = input_ids[:, 1:]
    shift_log_probs = log_probs[:, :-1, :]
    shift_mask = attention_mask[:, 1:]

    # 获取对应 token 的 log probability
    token_logprobs = torch.gather(shift_log_probs, dim=2, index=shift_input_ids.unsqueeze(-1)).squeeze(-1)
    seq_logprobs = (token_logprobs * shift_mask).sum(dim=1)  # 按 mask 取 sum

    return seq_logprobs  # shape: (B,)

# 4. 批量计算 DPO loss
batch_size = 8
all_losses = []



for i in tqdm.tqdm(range(0, len(dataset), batch_size), desc='handling'):
    batch = dataset[i:i + batch_size]
    chosen_texts = [sample[-1]['content'] for sample in batch['chosen']]
    rejected_texts = [sample[-1]['content'] for sample in batch['rejected']]
    # print(rejected_texts)
    chosen_logps = compute_logps(chosen_texts)
    rejected_logps = compute_logps(rejected_texts)

    logits_diff = beta * (chosen_logps - rejected_logps)
    losses = -logsigmoid(logits_diff)
    all_losses.append(losses.cpu())

# 5. 汇总 loss
all_losses = torch.cat(all_losses)
print("Mean DPO loss:", all_losses.mean().item())
print("Per-sample losses:", all_losses.tolist())

torch.save(all_losses, "agrilla_dpo_loss.pt")


  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 13.76it/s]
handling: 100%|██████████| 844/844 [13:32<00:00,  1.04it/s]

Mean DPO loss: 9.8125
Per-sample losses: [10.1875, 0.02978515625, 36.5, 2.066371962428093e-09, 2.481541837659083e-22, 3.625, 0.01214599609375, 21.5, 26.375, 15.125, 7.1875, 67.0, 1.996755599975586e-06, 27.75, 32.0, 6.139278411865234e-06, 20.25, 1.8775463104248047e-06, 8.754432201385498e-08, 4.125, 3.25, 50.5, 0.01507568359375, 0.00014019012451171875, 0.01007080078125, 12.0625, 7.188646122813225e-09, 4.7222086809427244e-20, 24.5, 0.000278472900390625, 49.5, 1.9563750449481093e-27, 3.46451997756958e-07, 22.375, 17.875, 0.0, 30.75, 0.059326171875, 8.0, 0.0037078857421875, 6.716706573930585e-22, 0.7578125, 0.048583984375, 6.6875, 29.0, 3.725290298461914e-06, 5.816113682013476e-24, 4.05634636990726e-10, 55.5, 16.625, 10.8125, 13.0, 20.25, 4.59375, 0.00299072265625, 6.03125, 36.5, 12.1875, 1.3589129821411916e-13, 26.375, 34.0, 1.2304311611726287e-23, 6.314393452555578e-16, 1.7139067942650854e-15, 0.01007080078125, 8.754432201385498e-08, 0.005584716796875, 0.0026397705078125, 13.8125, 0.37109




## Arigilla dataset

In [5]:
from datasets import load_dataset, Dataset, DatasetDict


raw_dataset = load_dataset("argilla/dpo-mix-7k")['train']
test_dataset = load_dataset("argilla/dpo-mix-7k")['test']

train_prompts = [sample[0]['content'] for sample in raw_dataset['chosen']]
test_prompts = [sample[0]['content'] for sample in test_dataset['chosen']]

raw_dataset = raw_dataset.add_column("prompt", train_prompts)
test_dataset = test_dataset.add_column("prompt", test_prompts)

# 重命名列
raw_dataset = raw_dataset.rename_column("chosen_rating", "score_chosen")
raw_dataset = raw_dataset.rename_column("rejected_rating", "score_rejected")
test_dataset = test_dataset.rename_column("chosen_rating", "score_chosen")
test_dataset = test_dataset.rename_column("rejected_rating", "score_rejected")

score_diff = [chosen - rejected for chosen, rejected in zip(raw_dataset['score_chosen'], raw_dataset['score_rejected'])]

raw_dataset = raw_dataset.add_column("score_diff", score_diff)

score_diff_sorted_dataset = raw_dataset.sort(f"score_diff", reverse=True)



score_diff_sorted_dataset_new = score_diff_sorted_dataset.remove_columns([column_name for column_name in score_diff_sorted_dataset.column_names if column_name not in test_dataset.column_names ])

score_diff_dataset = DatasetDict({
    'train': score_diff_sorted_dataset_new,
    'test': test_dataset
})
score_diff_dataset.push_to_hub("jlpang888/arigilla_sorted_score_diff")

Creating parquet from Arrow format: 100%|██████████| 7/7 [00:00<00:00, 23.58ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:02<00:00,  2.27s/it]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 30.24ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:01<00:00,  1.31s/it]


CommitInfo(commit_url='https://huggingface.co/datasets/jlpang888/arigilla_sorted_score_diff/commit/3a6cfe34bcaf03b43a18b37ca2d36203dca61cfd', commit_message='Upload dataset', commit_description='', oid='3a6cfe34bcaf03b43a18b37ca2d36203dca61cfd', pr_url=None, pr_revision=None, pr_num=None)

In [7]:
from datasets import load_dataset, Dataset, DatasetDict
import torch

raw_dataset = load_dataset("argilla/dpo-mix-7k")['train']
test_dataset = load_dataset("argilla/dpo-mix-7k")['test']

train_prompts = [sample[0]['content'] for sample in raw_dataset['chosen']]
test_prompts = [sample[0]['content'] for sample in test_dataset['chosen']]

raw_dataset = raw_dataset.add_column("prompt", train_prompts)
test_dataset = test_dataset.add_column("prompt", test_prompts)

dpo_losses = torch.load("agrilla_dpo_loss.pt").tolist()

raw_dataset = raw_dataset.add_column("llama_order", dpo_losses)

score_diff_sorted_dataset = raw_dataset.sort(f"llama_order", reverse=False)


score_diff_sorted_dataset_new = score_diff_sorted_dataset.remove_columns([column_name for column_name in score_diff_sorted_dataset.column_names if column_name not in test_dataset.column_names ])

score_diff_dataset = DatasetDict({
    'train': score_diff_sorted_dataset_new,
    'test': test_dataset
})
score_diff_dataset.push_to_hub("jlpang888/arigilla_sorted_llama")

Creating parquet from Arrow format: 100%|██████████| 7/7 [00:00<00:00, 23.80ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:02<00:00,  2.47s/it]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 73.98ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  1.07it/s]


CommitInfo(commit_url='https://huggingface.co/datasets/jlpang888/arigilla_sorted_llama/commit/44c9d76ac2d4a337b245406c668b888a0c627209', commit_message='Upload dataset', commit_description='', oid='44c9d76ac2d4a337b245406c668b888a0c627209', pr_url=None, pr_revision=None, pr_num=None)

In [8]:
from datasets import load_dataset, Dataset, DatasetDict
import torch

raw_dataset = load_dataset("argilla/dpo-mix-7k")['train']
test_dataset = load_dataset("argilla/dpo-mix-7k")['test']

train_prompts = [sample[0]['content'] for sample in raw_dataset['chosen']]
test_prompts = [sample[0]['content'] for sample in test_dataset['chosen']]

raw_dataset = raw_dataset.add_column("prompt", train_prompts)
test_dataset = test_dataset.add_column("prompt", test_prompts)


score_diff_sorted_dataset_new = score_diff_sorted_dataset.remove_columns([column_name for column_name in score_diff_sorted_dataset.column_names if column_name not in test_dataset.column_names ])

score_diff_dataset = DatasetDict({
    'train': score_diff_sorted_dataset_new,
    'test': test_dataset
})
score_diff_dataset.push_to_hub("jlpang888/arigilla_mix_7k")

Creating parquet from Arrow format: 100%|██████████| 7/7 [00:00<00:00, 23.75ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:02<00:00,  2.48s/it]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 31.14ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  1.11it/s]


CommitInfo(commit_url='https://huggingface.co/datasets/jlpang888/arigilla_mix_7k/commit/9dd8a6b0ab2695432a20451b5f8dcf9f8bec8c58', commit_message='Upload dataset', commit_description='', oid='9dd8a6b0ab2695432a20451b5f8dcf9f8bec8c58', pr_url=None, pr_revision=None, pr_num=None)