Issues using BART model for inference #6

Open
AADeLucia opened this issue Mar 10, 2022 · 1 comment

Comments

@AADeLucia

I am trying to use scripts/prep.sh and scripts/inference.py to load the BART checkpoint /reddit_vanilla_actual/checkpoint_best.pt for inference. I have run into a number of issues, mostly related to package versions and the extended 2048 source positions.

Environment:

pytorch                   1.7.1           py3.8_cuda10.2.89_cudnn7.6.5_0    pytorch

I first tried installing fairseq from source to access the examples module, but then saw that this repo ships its own copy of fairseq, so I installed that version according to the instructions here:

cd fairseq
pip install --editable ./
python setup.py build develop

I binarized val.source and val.target (a rough sketch of that step follows the command below) and am running inference as follows:

python scripts/inference.py /home/aadelucia/ConvoSumm/checkpoints/reddit_vanilla_actual checkpoint_best.pt /home/aadelucia/ConvoSumm/alexandra_test/data_processed /home/aadelucia/ConvoSumm/alexandra_test/data/val.source /home/aadelucia/ConvoSumm/alexandra_test/inference_output.txt 4 1 80 120 1 2048 ./misc/encoder.json ./misc/vocab.bpe
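
For reference, the binarization step was along the lines of the standard fairseq BPE-plus-preprocess recipe for BART fine-tuning. This is only a sketch: the exact paths and --workers value are illustrative, and the repo's scripts/prep.sh may do things differently.

# BPE-encode the raw text with the same encoder/vocab passed to inference.py
python -m examples.roberta.multiprocessing_bpe_encoder \
    --encoder-json ./misc/encoder.json \
    --vocab-bpe ./misc/vocab.bpe \
    --inputs val.source \
    --outputs val.bpe.source \
    --workers 8 --keep-empty
python -m examples.roberta.multiprocessing_bpe_encoder \
    --encoder-json ./misc/encoder.json \
    --vocab-bpe ./misc/vocab.bpe \
    --inputs val.target \
    --outputs val.bpe.target \
    --workers 8 --keep-empty

# Binarize into the folder given to inference.py as data_name_or_path,
# reusing BART's dict.txt for both source and target
fairseq-preprocess \
    --source-lang source --target-lang target \
    --validpref val.bpe \
    --destdir data_processed \
    --srcdict dict.txt --tgtdict dict.txt \
    --workers 8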

And I get the following error:

Traceback (most recent call last):
  File "scripts/inference.py", line 42, in <module>
    hypotheses_batch = bart.sample(slines, beam=beam, lenpen=lenpen, min_len=min_len, no_repeat_ngram_size=3)
  File "/home/aadelucia/ConvoSumm/code/fairseq/fairseq/hub_utils.py", line 132, in sample
    batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
  File "/home/aadelucia/ConvoSumm/code/fairseq/fairseq/models/bart/hub_interface.py", line 108, in generate
    return super().generate(
  File "/home/aadelucia/ConvoSumm/code/fairseq/fairseq/hub_utils.py", line 171, in generate
    for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
  File "/home/aadelucia/ConvoSumm/code/fairseq/fairseq/hub_utils.py", line 258, in _build_batches
    batch_iterator = self.task.get_batch_iterator(
  File "/home/aadelucia/ConvoSumm/code/fairseq/fairseq/tasks/fairseq_task.py", line 244, in get_batch_iterator
    batch_sampler = dataset.batch_by_size(
  File "/home/aadelucia/ConvoSumm/code/fairseq/fairseq/data/fairseq_dataset.py", line 145, in batch_by_size
    return data_utils.batch_by_size(
  File "/home/aadelucia/ConvoSumm/code/fairseq/fairseq/data/data_utils.py", line 337, in batch_by_size
    return batch_by_size_vec(
  File "fairseq/data/data_utils_fast.pyx", line 20, in fairseq.data.data_utils_fast.batch_by_size_vec
  File "fairseq/data/data_utils_fast.pyx", line 27, in fairseq.data.data_utils_fast.batch_by_size_vec
AssertionError: Sentences lengths should not exceed max_tokens=1024
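
From the traceback, the assertion is raised by fairseq's batching code, which requires every tokenized input to fit within max_tokens, and max_tokens here is apparently still at BART's default of 1024 rather than the extended 2048. Paraphrased (not the verbatim Cython source), the check in data_utils_fast.pyx amounts to:

# every example must fit in a single batch of max_tokens
assert max(num_tokens_vec) <= max_tokens, (
    f"Sentences lengths should not exceed max_tokens={max_tokens}"
)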

Am I using the wrong version of a package? Is there something extra needed for this to work?

@AADeLucia
Author

Never mind, it works when I pass max_tokens=max_source_positions in scripts/inference.py:

bart = BARTModel.from_pretrained(
    model_dir,
    checkpoint_file=model_file,
    data_name_or_path=bin_folder,
    gpt2_encoder_json=encoder_file,
    gpt2_vocab_bpe=vocab_file,
    max_source_positions=max_source_positions,
    max_tokens=max_source_positions
)
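
If it helps anyone else: as far as I can tell, from_pretrained forwards extra keyword arguments into the loaded config, so setting max_tokens raises the batching cap that tripped the assertion above (my understanding, not verified across fairseq versions). Generation then runs exactly as in the traceback, e.g.:

# assuming slines, beam, lenpen, and min_len are defined as in scripts/inference.py
bart.cuda()
bart.eval()
hypotheses_batch = bart.sample(
    slines, beam=beam, lenpen=lenpen, min_len=min_len, no_repeat_ngram_size=3
)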
