In [107]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [108]:
import torch.nn.functional as F

def dpo_loss(pi_logps, ref_logps, yw_idxs, yl_idxs, beta):
    """
    Direct Preference Optimization loss for a batch of preference pairs.

    pi_logps: policy logprobs, shape (B,)
    ref_logps: reference model logprobs, shape (B,)
    yw_idxs: preferred completion indices in [0, B-1], shape (T,)
    yl_idxs: dispreferred completion indices in [0, B-1], shape (T,)
    beta: temperature controlling strength of KL penalty

    Returns the per-pair losses, shape (T,); no reduction is applied.
    """
    def preference_margin(logps):
        # log p(preferred) - log p(dispreferred) for every pair
        return logps[yw_idxs] - logps[yl_idxs]

    # How much more strongly the policy prefers the chosen completion
    # than the reference model does.
    margin_gap = preference_margin(pi_logps) - preference_margin(ref_logps)
    return -F.logsigmoid(beta * margin_gap)

def dpo_loss_single(pi_logps, ref_logps, beta, pi_yl_logps=None, ref_yl_logps=None):
    """
    DPO loss when preferred / dispreferred logprobs are passed separately
    (no index tensors needed).

    pi_logps: policy logprobs of the preferred completions
    ref_logps: reference model logprobs of the preferred completions
    beta: temperature controlling strength of KL penalty
    pi_yl_logps: policy logprobs of the dispreferred completions,
        same shape as pi_logps (optional, see note)
    ref_yl_logps: reference model logprobs of the dispreferred completions,
        same shape as ref_logps (optional, see note)

    Returns the per-pair losses, same shape as the inputs.

    Note: the original three-argument form used the same tensor for the
    preferred and dispreferred side, so the margin was always 0 and the loss
    was the constant log(2). That behaviour is kept when the dispreferred
    arguments are omitted so existing calls do not break, but a meaningful
    loss requires passing both pi_yl_logps and ref_yl_logps.
    """
    pi_yw_logps, ref_yw_logps = pi_logps, ref_logps
    # Legacy fallback: reuse the preferred logprobs (degenerate, margin == 0).
    if pi_yl_logps is None:
        pi_yl_logps = pi_logps
    if ref_yl_logps is None:
        ref_yl_logps = ref_logps

    pi_logratios = pi_yw_logps - pi_yl_logps
    ref_logratios = ref_yw_logps - ref_yl_logps
    losses = -F.logsigmoid(beta * (pi_logratios - ref_logratios))
    return losses

In [109]:
# Load the tokenizer plus two copies of the same pretrained checkpoint:
# one to be fine-tuned (the policy) and one kept as the DPO reference.
# All three must come from the same checkpoint, so the name lives in one place.
MODEL_NAME = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
fine_tuning_model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)
reference_model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)



In [110]:
# The reference model is only used for scoring, never trained:
# turn off gradient tracking on each of its parameters.
for param in reference_model.parameters():
    param.requires_grad_(False)

In [111]:
# Toy preference data: three (prompt, preferred response, rejected response)
# examples, written row-wise so each pairing is visible at a glance.
preference_examples = [
    ('Hello my name is', 'Adam', 'Arnold'),
    ('The wheather is quite', 'humid', 'sunny'),
    ("This sentence is going", 'to be really long', 'to be short'),
]

# Split the rows back into the three parallel lists used by the later cells.
prompts, preffered_response, rejected_response = (
    list(column) for column in zip(*preference_examples)
)

In [112]:
# Number of tokens in each prompt on its own; used later to locate where the
# response starts inside the tokenized prompt+response sequence.
prompts_lengths = [
    tokenizer(text, return_tensors='pt')['input_ids'].size(-1)
    for text in prompts
]

In [113]:
# Build the full texts: each prompt followed by a single space and its response.
prompts_preferred = [
    f"{prompt} {chosen}" for prompt, chosen in zip(prompts, preffered_response)
]
prompts_rejected = [
    f"{prompt} {rejected}" for prompt, rejected in zip(prompts, rejected_response)
]

In [114]:
# Display the combined prompt + preferred-response strings.
prompts_preferred

['Hello my name is Adam',
 'The wheather is quite humid',
 'This sentence is going to be really long']

In [115]:
# Tokenize every full sequence separately (one encoding per example, so no
# padding is involved); each encoding holds PyTorch tensors.
tokenized_pp = [tokenizer(text, return_tensors='pt') for text in prompts_preferred]
tokenized_pr = [tokenizer(text, return_tensors='pt') for text in prompts_rejected]

In [116]:
# Display the encodings (input_ids and attention_mask) of the preferred sequences.
tokenized_pp

[{'input_ids': tensor([[15496,   616,  1438,   318,  7244]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])},
 {'input_ids': tensor([[  464,   483,  1032,   318,  2407, 35441]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])},
 {'input_ids': tensor([[1212, 6827,  318, 1016,  284,  307, 1107,  890]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}]

In [117]:
# Forward passes: next-token logits (not yet log-probabilities) for every
# prompt+response sequence, one sequence at a time.
logits_pp = [fine_tuning_model(**pp).logits for pp in tokenized_pp]
logits_pr = [fine_tuning_model(**pr).logits for pr in tokenized_pr]

# The reference model is never trained, so run it without autograd
# bookkeeping. This holds even if the cell that freezes its parameters
# was skipped or run out of order.
with torch.no_grad():
    ref_logits_pp = [reference_model(**pp).logits for pp in tokenized_pp]
    ref_logits_pr = [reference_model(**pr).logits for pr in tokenized_pr]

In [118]:
# Turn logits into log-probabilities by normalising over the vocabulary
# (last) dimension, for both models and both kinds of sequence.
lgpr_pp = [logits.log_softmax(dim=-1) for logits in logits_pp]
lgpr_pr = [logits.log_softmax(dim=-1) for logits in logits_pr]
ref_lgpr_pp = [logits.log_softmax(dim=-1) for logits in ref_logits_pp]
ref_lgpr_pr = [logits.log_softmax(dim=-1) for logits in ref_logits_pr]

In [119]:
def _response_logprob_sum(log_probs, input_ids, prompt_len):
    """Sum the log-probabilities a model assigns to the response tokens.

    log_probs: log-softmax output for one sequence, shape (1, T, V)
    input_ids: the token ids that were fed to the model, shape (1, T)
    prompt_len: number of leading tokens that belong to the prompt (>= 1)

    Returns a 0-dim tensor: sum over response tokens of log p(token | prefix).
    """
    # Position t of the output predicts token t + 1, so the response tokens
    # input_ids[:, prompt_len:] are scored by positions prompt_len-1 .. T-2.
    # The final position predicts a token beyond the sequence and is dropped.
    response_ids = input_ids[:, prompt_len:].unsqueeze(-1)
    predictions = log_probs[:, prompt_len - 1:-1, :]
    return torch.gather(predictions, dim=2, index=response_ids).squeeze(-1).sum()


# Policy model: summed log-probability of each preferred / rejected response.
lgpr_pp_sums = [
    _response_logprob_sum(log_probs, encoding['input_ids'], prompt_len)
    for log_probs, encoding, prompt_len in zip(lgpr_pp, tokenized_pp, prompts_lengths)
]
lgpr_pr_sums = [
    _response_logprob_sum(log_probs, encoding['input_ids'], prompt_len)
    for log_probs, encoding, prompt_len in zip(lgpr_pr, tokenized_pr, prompts_lengths)
]

# Reference model: same quantity, read from the reference log-probabilities
# (ref_lgpr_*), not from the policy's.
ref_lgpr_pp_sums = [
    _response_logprob_sum(log_probs, encoding['input_ids'], prompt_len)
    for log_probs, encoding, prompt_len in zip(ref_lgpr_pp, tokenized_pp, prompts_lengths)
]
ref_lgpr_pr_sums = [
    _response_logprob_sum(log_probs, encoding['input_ids'], prompt_len)
    for log_probs, encoding, prompt_len in zip(ref_lgpr_pr, tokenized_pr, prompts_lengths)
]

In [122]:
# Inspect the individual per-token log-probabilities that the policy assigns
# to the preferred response of the last example (one value per response token).
# NOTE(review): the slice starting at prompts_lengths[-1]-1 keeps one more
# position than the index has tokens; gather only reads as many positions as
# the index provides, so the trailing position appears to be ignored — confirm.
torch.gather(lgpr_pp[-1][:, prompts_lengths[-1]-1:, :], dim=2, index=tokenizer(" " + preffered_response[-1], return_tensors='pt')['input_ids'].unsqueeze(2)).squeeze(2)

tensor([[-0.0769, -1.2717, -5.1082, -3.6530]], grad_fn=<SqueezeBackward1>)

In [121]:
# Display the policy's summed log-probability for each rejected response.
lgpr_pr_sums

[tensor(-7.9237, grad_fn=<SumBackward0>),
 tensor(-11.0358, grad_fn=<SumBackward0>),
 tensor(-8.0154, grad_fn=<SumBackward0>)]