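#### 测试 DPO 流程：数据批次构造、策略/参考模型对数概率与偏好损失计算（Test the DPO pipeline: batch construction, policy/reference log-probs, and preference-loss computation）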

In [1]:
import torch

# --- Hardware ---------------------------------------------------------------
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# --- Checkpoints --------------------------------------------------------------
name_or_path = "../../train-result/2024-06-19/checkpoint-5000"
# The tokenizer is loaded from the same checkpoint directory as the model.
tokenizer_name_or_path = name_or_path
# NOTE(review): passed to tools.common_utils.get_local_dir(). If that helper
# iterates over a list of candidate prefixes (as the plural name suggests), a
# bare string is iterated character by character — confirm whether this
# should be [".cache"].
local_dirs = ".cache"

# --- Run settings -------------------------------------------------------------
n_epochs, batch_size = 1, 1
rank = 0  # a non-zero rank makes the batch iterator run silently
trainer = "BasicTrainer"  # selects device_map='balanced' when loading models
policy_dtype = "float16"  # name of a torch dtype, resolved with getattr(torch, ...)

In [2]:
import transformers
from tools.common_utils import get_local_dir

# Resolve the cache directory first, then load the (possibly custom) tokenizer
# shipped with the checkpoint; trust_remote_code allows its bundled code to run.
tokenizer_cache_dir = get_local_dir(local_dirs)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    tokenizer_name_or_path,
    cache_dir=tokenizer_cache_dir,
    trust_remote_code=True,
)

  from .autonotebook import tqdm as notebook_tqdm
Setting eos_token is not supported, use the default one.
Setting pad_token is not supported, use the default one.
Setting unk_token is not supported, use the default one.


In [None]:
from dpo_trainers import get_batch_iterator, concatenated_inputs

# Options describing the local preference data and tokenization limits.
data_iterator_kwargs = {
    "data_dir": "data",
    "data_files": "train_err.csv",
    "tokenizer": tokenizer,
    "shuffle": True,
    "max_length": 2048,
    "max_prompt_length": 1024,
}
train_iterator = get_batch_iterator(
    **data_iterator_kwargs,
    split='train',
    n_epochs=n_epochs,
    batch_size=batch_size,
    silent=rank != 0,
)

# Pull at most two batches and print their concatenated form. Afterwards
# `batch` / `concatenated_batch` hold the last batch pulled; later cells use
# them as the example input. concatenated_inputs presumably stacks chosen and
# rejected rows (later cells slice chosen first, then rejected) — confirm.
idx = 0
for idx, batch in enumerate(train_iterator, start=1):
    concatenated_batch = concatenated_inputs(batch)
    print(concatenated_batch)
    if idx == 2:
        break

In [3]:
from tools.common_utils import disable_dropout

# Load the policy model. With BasicTrainer the weights are placed by
# `device_map='balanced'`; otherwise the model is loaded with default
# placement and moved to `device` explicitly.
model_kwargs = {'device_map': 'balanced'} if trainer == 'BasicTrainer' else {}

# Resolve the dtype name from the config cell (e.g. "float16") to a torch
# dtype. Guarded so re-running this cell does not evaluate
# getattr(torch, torch.float16), which raises TypeError.
if isinstance(policy_dtype, str):
    policy_dtype = getattr(torch, policy_dtype)

policy = transformers.AutoModelForCausalLM.from_pretrained(
    name_or_path,
    cache_dir=get_local_dir(local_dirs),
    low_cpu_mem_usage=True,
    torch_dtype=policy_dtype,
    trust_remote_code=True,
    **model_kwargs)

# A model loaded with a device_map has already been dispatched by accelerate.
# Calling `.to()` on it is discouraged (it warns, collapses a multi-device
# layout onto one device, and raises if any module was offloaded), so only
# move the model when no device_map was requested.
if not model_kwargs:
    policy.to(device)

# Project helper; presumably zeroes dropout so repeated forward passes give
# identical log-probs — confirm in tools.common_utils.
disable_dropout(policy)

Loading checkpoint shards: 100%|██████████| 6/6 [05:54<00:00, 59.16s/it] 


In [11]:
# Print the next (up to) two raw, un-concatenated batches from the shared
# iterator. If get_batch_iterator returns a one-shot generator, this resumes
# where the earlier loop stopped rather than showing the first batches.
#
# The loop variable is deliberately NOT `batch`: the log-prob cells pair
# `batch` with `concatenated_batch`, and overwriting `batch` here would
# silently desynchronise the two.
idx = 0
for idx, preview_batch in enumerate(train_iterator, start=1):
    print(preview_batch)
    if idx == 2:
        break

{'prompt': ['\n\nHuman: 你现在是一个金融领域专家，你需要通过编排api来得到用户query的答案，输出json格式。query是：我想知道如果我用所有的流动资产买了蓝英股票，然后在今日的最高价时全部卖出，我能赚多少钱 \n query中提到的产品标准名可能是：嘉实物流产业股票型证券投资基金C类、嘉实物流产业股票型证券投资基金A类、民生加银核心资产股票型证券投资基金C类、民生加银核心资产股票型证券投资基金A类、工银瑞信物流产业股票型证券投资基金C类、工银瑞信物流产业股票型证券投资基金A类、鑫元核心资产股票型发起式证券投资基金C类、鑫元核心资产股票型发起式证券投资基金A类、招商移动互联网产业股票型证券投资基金C类、招商移动互联网产业股票型证券投资基金A类、国联金如意双利一年持有期债券型集合资产管理计划C类、国联金如意双利一年持有期债券型集合资产管理计划A类、国联金如意3个月滚动持有债券型集合资产管理计划C类、国联金如意3个月滚动持有债券型集合资产管理计划A类、鹏华环保产业股票型证券投资基金、鹏华养老产业股票型证券投资基金、长信创新驱动股票型证券投资基金、融通产业趋势股票型证券投资基金、招商全球资源股票型证券投资基金、建信优势动力股票型证券投资基金、广发科技动力股票型证券投资基金、嘉实新兴产业股票型证券投资基金、嘉实事件驱动股票型证券投资基金、南方国策动力股票型证券投资基金、南方产业活力股票型证券投资基金、南方产业智选股票型证券投资基金、信澳蓝筹精选股票型证券投资基金、新华精选低波动股票型证券投资基金、国联安科技动力股票型证券投资基金、南方天元新产业股票型证券投资基金、信澳新能源产业股票型证券投资基金、中银新动力股票型证券投资基金C类、中银新动力股票型证券投资基金A类、金鹰信息产业股票型证券投资基金C类、金鹰信息产业股票型证券投资基金A类、西部利得事件驱动股票型证券投资基金、融通产业趋势先锋股票型证券投资基金、英大国企改革主题股票型证券投资基金、汇添富民营新动力股票型证券投资基金、汇安趋势动力股票型证券投资基金C类、汇安趋势动力股票型证券投资基金A类、摩根研究驱动股票型证券投资基金C类、摩根研究驱动股票型证券投资基金A类、摩根大盘蓝筹股票型证券投资基金C类、摩根大盘蓝筹股票型证券投资基金A类、招商蓝筹精选股票型证券投资基金C类、招商蓝筹精选股票型证

In [5]:
# Policy forward pass over the concatenated chosen+rejected batch. There is
# intentionally no torch.no_grad() here, so the logits keep their autograd
# graph. They are upcast to float32 before scoring.
concat_input_ids = concatenated_batch['concatenated_input_ids'].to(device)
concat_attention_mask = concatenated_batch['concatenated_attention_mask'].to(device)
all_logits = policy(concat_input_ids, attention_mask=concat_attention_mask).logits.to(torch.float32)

In [6]:
from dpo_trainers import _get_batch_logps

# Per-sequence policy log-probs (average_log_prob=False — presumably summed
# over label tokens). The first n_chosen rows are the chosen responses, the
# remaining rows the rejected ones.
n_chosen = batch['chosen_input_ids'].shape[0]
all_logps = _get_batch_logps(
    all_logits,
    concatenated_batch['concatenated_labels'].to(device),
    average_log_prob=False,
)
chosen_logps, rejected_logps = all_logps[:n_chosen], all_logps[n_chosen:]

for title, values in (("chosen_logps", chosen_logps), ("rejected_logps", rejected_logps)):
    print(f"--------{title}--------")
    print(values)

--------chosen_logps--------
tensor([-6.8947], device='cuda:0', grad_fn=<SliceBackward0>)
--------rejected_logps--------
tensor([-4.4050], device='cuda:0', grad_fn=<SliceBackward0>)


In [8]:
# Load the reference model. It comes from the same checkpoint (`name_or_path`)
# as the policy, so before any policy update its log-probs match the policy's.
reference_model_dtype = "float16"
reference_model_dtype = getattr(torch, reference_model_dtype)

reference_model = transformers.AutoModelForCausalLM.from_pretrained(
    name_or_path,
    cache_dir=get_local_dir(local_dirs),
    low_cpu_mem_usage=True,
    torch_dtype=reference_model_dtype,
    trust_remote_code=True,
    offload_buffers=True,
    **model_kwargs)

# With a device_map (BasicTrainer) accelerate has already placed, and possibly
# offloaded, the weights. `.to()` on such a model raises when anything was
# offloaded and otherwise collapses the layout onto one device. Only move it
# when no device_map was used.
if not model_kwargs:
    reference_model.to(device)

# Project helper; presumably zeroes dropout so reference log-probs are
# deterministic — confirm in tools.common_utils.
disable_dropout(reference_model)

Loading checkpoint shards: 100%|██████████| 6/6 [07:41<00:00, 76.96s/it] 


In [9]:
# Score the same concatenated batch with the reference model. The forward
# pass runs under torch.no_grad(), so these log-probs carry no autograd graph
# (unlike the policy's).
n_chosen = batch['chosen_input_ids'].shape[0]
with torch.no_grad():
    ref_input_ids = concatenated_batch['concatenated_input_ids'].to(device)
    ref_attention_mask = concatenated_batch['concatenated_attention_mask'].to(device)
    ref_logits = reference_model(ref_input_ids, attention_mask=ref_attention_mask).logits.to(torch.float32)
    ref_logps = _get_batch_logps(
        ref_logits,
        concatenated_batch['concatenated_labels'].to(device),
        average_log_prob=False,
    )

# First n_chosen rows are the chosen responses, the rest are rejected.
ref_chosen_logps, ref_rejected_logps = ref_logps[:n_chosen], ref_logps[n_chosen:]

for title, values in (("ref_chosen_logps", ref_chosen_logps), ("ref_rejected_logps", ref_rejected_logps)):
    print(f"--------{title}--------")
    print(values)

--------ref_chosen_logps--------
tensor([-6.8947], device='cuda:0')
--------ref_rejected_logps--------
tensor([-4.4050], device='cuda:0')


In [10]:
from dpo_trainers import preference_loss

# Hyper-parameters forwarded unchanged to preference_loss (DPO, not IPO).
beta = 0.9
reference_free = False
label_smoothing = 0

loss_kwargs = dict(beta=beta, reference_free=reference_free, label_smoothing=label_smoothing, ipo=False)

# Policy and reference were loaded from the same checkpoint, so their log-probs
# are identical. The recorded result is therefore rewards of 0 and a loss of
# 0.6931 (= ln 2), which is the expected sanity-check value.
losses, chosen_rewards, rejected_rewards = preference_loss(
    chosen_logps, rejected_logps,
    ref_chosen_logps, ref_rejected_logps,
    **loss_kwargs,
)

for title, value in (
    ("losses", losses),
    ("chosen_rewards", chosen_rewards),
    ("rejected_rewards", rejected_rewards),
):
    print(f"--------{title}--------")
    print(value)

--------losses--------
tensor([0.6931], device='cuda:0', grad_fn=<SubBackward0>)
--------chosen_rewards--------
tensor([0.], device='cuda:0')
--------rejected_rewards--------
tensor([0.], device='cuda:0')
