position_ids related PPO bug #217

Closed
tianhao-nexusflow opened this issue Feb 21, 2024 · 2 comments

Comments

tianhao-nexusflow commented Feb 21, 2024

I'm working to reproduce results and align them with trlx. After careful code review, I've identified this issue:

Concern (link): precomputing log probabilities using only token_ids and attention_mask may not fully account for the model's use of positional embeddings. For a Llama model, the same token_ids produce different results with and without left padding, even when the attention_mask is passed in explicitly. A common strategy to correct this is to pass in the position_ids explicitly, as in link.

Potential impact: the same padding-related issue could also affect generation quality, although I need more time to look into the code carefully. Will it also affect the critic and reward modules?
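
For illustration, here is a minimal sketch of the strategy described above (the function name and argument shapes are illustrative, not OpenRLHF's actual code): derive position_ids from the attention mask before precomputing the per-token log probabilities, so that left padding does not shift the positions seen by the real tokens.

import torch

def logprobs_with_position_ids(model, input_ids, attention_mask):
    # Positions count only the real tokens, so a left-padded and a right-padded
    # copy of the same sequence assign identical position ids to its real tokens.
    position_ids = attention_mask.long().cumsum(-1) - 1

    logits = model(input_ids=input_ids,
                   attention_mask=attention_mask,
                   position_ids=position_ids).logits

    # Log-probability of each token given the preceding tokens.
    log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)
    return log_probs.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)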

hijkzzz (Collaborator) commented Feb 21, 2024

This issue is very interesting, and I will confirm it as soon as possible. If it holds, the bug comes from the Hugging Face side, and we will fix it.

@tianhao-nexusflow by the way, Hugging Face TRL also does not use position_ids:
https://github.com/huggingface/trl/blob/a46cd84a6405312837f0d0e56fd1cf4d45585770/trl/trainer/ppo_trainer.py#L920

hijkzzz added a commit that referenced this issue Feb 21, 2024

hijkzzz (Collaborator) commented Feb 21, 2024

Fixed the position_ids. However, small precision differences remain between left padding and right padding.

Llama-7B

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

modelname = "OpenLLMAI/Llama-2-7b-sft-model-ocra-500k"
model = AutoModelForCausalLM.from_pretrained(modelname).cuda()

# The same two sequences, once padded on the left with token id 2 ...
inputs = {'input_ids': torch.tensor([[    1,  7251,   727, 29901, 29871],
                                     [    2,     2,     1, 29871, 29896]]).cuda(),
          'attention_mask': torch.tensor([[1, 1, 1, 1, 1],
                                          [0, 0, 1, 1, 1]]).cuda()}
# ... and once padded on the right.
inputs2 = {'input_ids': torch.tensor([[    1,  7251,   727, 29901, 29871],
                                      [    1, 29871, 29896,     2,     2]]).cuda(),
           'attention_mask': torch.tensor([[1, 1, 1, 1, 1],
                                           [1, 1, 1, 0, 0]]).cuda()}

# baseline: no explicit position_ids
output = model(**inputs)
output2 = model(**inputs2)

# difference between the logits of the real (non-pad) tokens of the second sequence
output2.logits[1][:3] - output.logits[1][-3:]
# tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
#         [ 0.0020, -0.0010, -0.0001,  ...,  0.0006,  0.0013,  0.0007],
#         [ 0.0025,  0.0040, -0.0005,  ...,  0.0025,  0.0015,  0.0008]],
#        device='cuda:0', grad_fn=<SubBackward0>)

# fixed: derive position_ids from the attention mask, so positions count only
# the real tokens and left padding no longer shifts them
position_ids = inputs['attention_mask'].long().cumsum(-1) - 1
position_ids2 = inputs2['attention_mask'].long().cumsum(-1) - 1

output = model(**inputs, position_ids=position_ids)
output2 = model(**inputs2, position_ids=position_ids2)

output2.logits[1][:3] - output.logits[1][-3:]
# tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
#           0.0000e+00,  0.0000e+00],
#         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
#           0.0000e+00,  0.0000e+00],
#         [ 9.3555e-04, -8.1062e-06, -8.5831e-05,  ...,  7.9441e-04,
#           5.7936e-04,  4.6229e-04]], device='cuda:0', grad_fn=<SubBackward0>)
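
One small caveat on the snippet above (an observation, not part of the fix): with cumsum(-1) - 1, the padded positions of the left-padded batch end up with position id -1. Those positions are masked out of the attention, so the real tokens are unaffected, but the convention inside transformers (e.g. in prepare_inputs_for_generation) is to overwrite them with a dummy value:

# Assumption: `inputs` is the left-padded batch from the snippet above.
position_ids = inputs['attention_mask'].long().cumsum(-1) - 1
# Give the masked (padding) positions a harmless dummy position id.
position_ids.masked_fill_(inputs['attention_mask'] == 0, 1)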
