

error from generate_from_string method for BART #17

Open
timhartill opened this issue Mar 31, 2021 · 7 comments


Comments

@timhartill

Hi, I'm attempting to run the BART model example as given in the readme:

```python
import torch
from transformers import BartTokenizer
from bart import MyBart

base_model = "facebook/bart-large"
unifiedqa_path = "unifiedQA-uncased/best-model.pt" # path to the downloaded checkpoint

tokenizer = BartTokenizer.from_pretrained(base_model)
model = MyBart.from_pretrained(base_model, state_dict=torch.load(unifiedqa_path))
model.eval()

x = model.generate_from_string("Which is best conductor? \\n (A) iron (B) feather", tokenizer=tokenizer)
```

The `.from_pretrained` line executes fine, but the `.generate_from_string(..)` line fails with:

```
TypeError: forward() got an unexpected keyword argument 'past_key_values'
```

I tried using the `run_model(..)` method from the main GitHub page and it gives exactly the same error.

Any idea what might be causing this and how to fix it?

I am using Python 3.8.5 with transformers 4.4.2 and PyTorch 1.7.1.

@timhartill changed the title from "error from generatefromstring method for BART" to "error from generate_from_string method for BART" on Mar 31, 2021
@danyaljj
Contributor

danyaljj commented Apr 1, 2021

Thoughts, @shmsw25?

@timhartill
Author

After a bit more digging, it appears that the arguments to the `forward` method in `modeling_bart.py` in transformers 4.4.2 are rather different from the arguments passed to the `forward` method in the UnifiedQA `bart.py`. I think I may need to update `bart.py` to match the latest `modeling_bart.py` to make this work. If I manage to do so, would you like a copy of the updated version?
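A quick way to confirm this kind of signature mismatch, before editing `bart.py`, is to ask Python which keywords a `forward()` accepts. The sketch below uses the standard-library `inspect` module; `accepts_kwarg`, `LegacyForward` and `CurrentForward` are hypothetical names, and the two classes are stand-ins for an old-style and a 4.x-style signature, not the real `MyBart` or transformers classes (their parameter names are assumptions). With the real code installed, the same helper could be called as `accepts_kwarg(MyBart.forward, "past_key_values")`.

```python
import inspect


def accepts_kwarg(fn, name):
    """Return True if fn can be called with the keyword argument `name`."""
    params = inspect.signature(fn).parameters.values()
    # A **kwargs catch-all accepts any keyword name.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return True
    # Otherwise the name must match a parameter that may be passed by keyword.
    keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    return any(p.name == name and p.kind in keyword_kinds for p in params)


class LegacyForward:
    # Hypothetical old-style signature: no `past_key_values` parameter.
    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None,
                decoder_cached_states=None, use_cache=False):
        return "legacy"


class CurrentForward:
    # Hypothetical 4.x-style signature that accepts the cache keyword.
    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None,
                past_key_values=None, use_cache=None, **kwargs):
        return "current"


print(accepts_kwarg(LegacyForward.forward, "past_key_values"))   # False
print(accepts_kwarg(CurrentForward.forward, "past_key_values"))  # True
```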

@danyaljj
Contributor

danyaljj commented Apr 3, 2021

Sure, thank you! 🙏

@timhartill
Author

I forked your code and updated `bart.py` and `run.py`. I've run it a few times and it seems to work. I've generally marked my changes with comments starting with `#TJH`.

You can access at: https://github.com/timhartill/unifiedqa-tjh

@danyaljj
Contributor

danyaljj commented Apr 4, 2021

Appreciate it! Will look into your changes.

@tshrjn

tshrjn commented Apr 6, 2021

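transformers 4.x brought breaking changes: among other things, `generate()` now passes the decoder cache to the model's `forward()` as `past_key_values`, a keyword that the `forward()` of the derived `MyBart` class does not accept.

But it shouldn't be an issue if you use HF's model class instead of the derived one.

Here is an example of how generation would look: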

```python
import torch
from transformers import BartTokenizer, BartForConditionalGeneration

base_model = "facebook/bart-large"
unifiedqa_path = "unifiedQA-uncased/best-model.pt" # path to the downloaded checkpoint

tokenizer = BartTokenizer.from_pretrained(base_model)
model = BartForConditionalGeneration.from_pretrained(base_model, state_dict=torch.load(unifiedqa_path))
model.eval()


def generate_text(text, model, tokenizer):
    inputs = tokenizer([text], max_length=512, truncation=True, return_tensors='pt')

    output_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=50, early_stopping=True)
    return ' '.join([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in output_ids])


text = "Which is best conductor? \\n (A) iron (B) feather"
print('\n', text)
print("generated_text:", generate_text(text, model, tokenizer))


text = "What is the sum of 3 and 5? \\n (A) 8 (B) 3 (C) 5 (D) 10"
print('\n', text)
print("generated_text:", generate_text(text, model, tokenizer))


text = "What is 42 ? \\n 42 is the answer to life, the universe and everything"
print('\n', text)
print("generated_text:", generate_text(text, model, tokenizer))
```
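The inputs above separate the question from its options or context with the literal two-character sequence `\n` (a backslash followed by `n`), which is why the Python strings contain `\\n` rather than a real newline. A small helper that builds such strings might look like the sketch below; `format_unifiedqa_input` is a hypothetical name, the `(A)`, `(B)`, ... labelling simply mirrors the examples in this thread, and placing options before context when both are given is an assumption of this sketch.

```python
import string


def format_unifiedqa_input(question, options=None, context=None):
    """Build a UnifiedQA-style input string (hypothetical helper).

    Segments are joined with ' \\n ', i.e. a literal backslash followed by n,
    not a newline character.
    """
    parts = [question]
    if options:
        # Label options (A), (B), ... as in the examples above.
        labels = string.ascii_uppercase
        parts.append(" ".join(f"({labels[i]}) {opt}" for i, opt in enumerate(options)))
    if context:
        # Assumption: context goes after the options when both are present.
        parts.append(context)
    return " \\n ".join(parts)


print(format_unifiedqa_input("Which is best conductor?", ["iron", "feather"]))
# Which is best conductor? \n (A) iron (B) feather
```

The printed `\n` is two characters, matching the `text` strings passed to `generate_text` above.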

Or one could also use HF's pipelines as follows:

```python
# Using Pipeline
from transformers import pipeline

text = "What is 42 ? \\n 42 is the answer to life, the universe and everything"
print('\n', text)

text2text_generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
print(text2text_generator(text))
```
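For anyone who would rather keep the derived `MyBart` class, one general-purpose workaround is a thin adapter that renames the keyword that 4.x `generate()` passes into the name the older `forward()` expects. A minimal sketch in plain Python follows; `rename_kwargs`, `LegacyBart` and `PatchedBart` are hypothetical, and `decoder_cached_states` is an assumed old parameter name that should be checked against the actual `bart.py`. Renaming alone does not guarantee that the cached tensors have the structure the old code expects, so updating `bart.py` (as in the fork linked earlier) or using `BartForConditionalGeneration` directly remains the more reliable route.

```python
import functools


def rename_kwargs(mapping):
    """Decorator: rename keyword arguments (new name -> old name) before the call."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for new_name, old_name in mapping.items():
                if new_name in kwargs:
                    # Move the value under the name the old signature understands.
                    kwargs[old_name] = kwargs.pop(new_name)
            return fn(*args, **kwargs)
        return wrapper
    return decorator


class LegacyBart:
    # Hypothetical stand-in for an older forward() without `past_key_values`.
    def forward(self, input_ids, decoder_cached_states=None, use_cache=False):
        return {"input_ids": input_ids, "cache": decoder_cached_states, "use_cache": use_cache}


class PatchedBart(LegacyBart):
    # Accept the 4.x keyword and pass it on under the legacy name.
    forward = rename_kwargs({"past_key_values": "decoder_cached_states"})(LegacyBart.forward)


out = PatchedBart().forward(input_ids=[0], past_key_values="cache", use_cache=True)
print(out["cache"], out["use_cache"])  # cache True
```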

@shmsw25
Contributor

shmsw25 commented Apr 7, 2021

Hi @timhartill and @tshrjn,
it looks like the error is coming from the discrepancy in HF versions. The code was written for an older version of HF transformers; please see the README. @tshrjn's solution looks like a good workaround for running inference with a newer version. However, if you want to run finetuning, I recommend following the version in the README, as finetuning with a newer version is not guaranteed to reproduce the results in the paper.

@danyaljj I was thinking that keeping the version in the repo as it is would be better, since the HF library will keep being updated and it would not be easy to update the code every time while guaranteeing that the numbers in the paper are reproduced. Alternatively, we could update only the inference code and add a note that finetuning has only been tested with the version in the README. What do you think?
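A note like "finetuning is only tested with the version in the README" could also be enforced at runtime. A minimal sketch, assuming Python 3.8+ for the standard-library `importlib.metadata`; `installed_version` and `warn_if_unpinned` are hypothetical helpers, and `REQUIRED_TRANSFORMERS` is a placeholder, not the README's actual pin. It warns rather than raises, so inference with a newer transformers, as in @tshrjn's workaround, still works.

```python
import warnings
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def warn_if_unpinned(package, required, installed=None):
    """Return True if the installed version equals `required`; otherwise warn and return False."""
    if installed is None:
        installed = installed_version(package)
    if installed == required:
        return True
    warnings.warn(
        f"{package} {installed or '(not installed)'} found, but finetuning was only tested "
        f"with {package}=={required}; results may not reproduce the paper.",
        RuntimeWarning,
    )
    return False


# Placeholder: copy the exact version pinned in the repository README here.
REQUIRED_TRANSFORMERS = "x.y.z"

if __name__ == "__main__":
    warn_if_unpinned("transformers", REQUIRED_TRANSFORMERS)
```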
