Bug for BertPrompt series code? #15

Closed
tangzhy opened this issue Dec 11, 2021 · 5 comments


tangzhy commented Dec 11, 2021

Hi, I noticed that the BertPrompt model does not use the [CLS] token for the linear classification head. I'll try to explain it with the following code and toy inputs: say input_ids has shape [8, 32] and pre_seq_len is 3, so inputs_embeds ends up with shape [8, 35, 768]. I comment the shapes of the key variables in the code and state my concern.

class BertPromptForSequenceClassification(BertPreTrainedModel):
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, output_attentions=None,
                output_hidden_states=None, return_dict=None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size = input_ids.shape[0]
        raw_embedding = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
        )
        prompts = self.get_prompt(batch_size=batch_size)
        # soft prompts are prepended, so [CLS] is pushed to position pre_seq_len
        inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
        # inputs_embeds shape: [8, 35, 768] ([8, 3, 768] prompts + [8, 32, 768] token embeddings)
        prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.bert(
            # input_ids,
            attention_mask=attention_mask,
            # token_type_ids=token_type_ids,
            # position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            # past_key_values=past_key_values,
        )
        # The BERT pooler always takes the first token of the sequence, so the
        # token actually fed to the classifier here is the first soft prompt, not [CLS]!
        pooled_output = outputs[1]

I wonder: is P-tuning v2 compared against soft prompt tuning as a baseline? If so, the token used by the classification head for the latter is not [CLS] but the first soft prompt token.

Is that expected?
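
For reference, the standard BertPooler in transformers does roughly the following (paraphrased, not copied verbatim from the library source), which is why the first soft-prompt position rather than [CLS] ends up in the classifier here:

class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # pool by taking the hidden state of the first token in the sequence
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output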

Xiao9905 (Member) commented:

Hi @tangzhy, are you asking about the code for P-tuning v2 or for the baseline, P-tuning v1?

tangzhy (Author) commented Dec 11, 2021

Hi @Xiao9905, I'm asking about the code for P-tuning v2. I came across the code of BertPromptForSequenceClassification and got confused, hence the question above.

Xiao9905 (Member) commented Dec 11, 2021

@tangzhy, I think the piece of code you are looking at is not for P-tuning v2 but for the baseline method, P-tuning v1. In P-tuning v2 we adopt an implementation trick (#9) of using past_key_values, which does not affect the position of [CLS] in the output hidden-states tensor.
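
To make the trick concrete, here is a minimal sketch (not the repository's exact code) of how passing a learned prefix through past_key_values leaves the input sequence, and therefore position 0 of the output hidden states, starting with [CLS]. It assumes a transformers version whose BertModel still accepts tuple-style past_key_values, and it uses a randomly initialized prefix in place of the learned prefix encoder:

import torch
from transformers import BertModel, BertTokenizer

pre_seq_len = 3
bert = BertModel.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

enc = tokenizer(["a toy sentence"], return_tensors="pt")
batch_size = enc["input_ids"].shape[0]
cfg = bert.config
head_dim = cfg.hidden_size // cfg.num_attention_heads

# one (key, value) pair per layer, each of shape
# [batch_size, num_heads, pre_seq_len, head_dim]; random here, learned in the real model
past_key_values = tuple(
    (
        torch.randn(batch_size, cfg.num_attention_heads, pre_seq_len, head_dim),
        torch.randn(batch_size, cfg.num_attention_heads, pre_seq_len, head_dim),
    )
    for _ in range(cfg.num_hidden_layers)
)

# the attention mask must cover the prefix as well as the real tokens
prefix_mask = torch.ones(batch_size, pre_seq_len)
attention_mask = torch.cat((prefix_mask, enc["attention_mask"]), dim=1)

outputs = bert(
    input_ids=enc["input_ids"],  # real tokens only, still starting with [CLS]
    attention_mask=attention_mask,
    past_key_values=past_key_values,
)
# output length equals the real input length; index 0 is still [CLS]
cls_hidden = outputs.last_hidden_state[:, 0]

Because the prefix only ever enters attention as extra keys and values, the pooled output still pools on [CLS].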

tangzhy (Author) commented Dec 11, 2021

@Xiao9905 Yes, I think the trick used in BertPrefixForSequenceClassification is quite clever :) Thanks for this amazing work. My question is about the code of BertPromptForSequenceClassification: is the problem I described a real bug, or is it intended? Or perhaps I missed some other detail.

Xiao9905 mentioned this issue Dec 13, 2021
Xiao9905 (Member) commented:

@tangzhy Hi, thank you for your report. We have fixed the problem in our released code.
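
For anyone hitting the same issue, a minimal sketch of one way to pool on [CLS] in the prompt model (which may differ from the exact fix in the released code) is to drop the prompt positions from the sequence output and apply the pooler manually; self.dropout and self.classifier are assumed to be the usual sequence-classification head:

sequence_output = outputs[0]                                # [8, 35, 768]
sequence_output = sequence_output[:, self.pre_seq_len:, :]  # drop the 3 prompt positions -> [8, 32, 768]
first_token_tensor = sequence_output[:, 0]                  # hidden state of the real [CLS] token
pooled_output = self.bert.pooler.dense(first_token_tensor)
pooled_output = self.bert.pooler.activation(pooled_output)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)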

tangzhy closed this as completed Dec 14, 2021