RuntimeError: gather(): Expected dtype int64 for index, in beam_search/beam_search.py, line 26, in fn #82

Closed
linhuixiao opened this issue Jul 22, 2022 · 1 comment


Meshed-Memory Transformer Evaluation
Evaluation: 0%|
Evaluation: 0%| | 0/500 [00:00<?, ?it/s]

Traceback (most recent call last):
  File "test.py", line 78, in <module>
    scores = predict_captions(model, dict_dataloader_test, text_field)
  File "test.py", line 26, in predict_captions
    out, _ = model.beam_search(images, 20, text_field.vocab.stoi[''], 5, out_size=1)
  File "/home/lhxiao/pcl_experiment_202203/meshed-memory-transformer/models/captioning_model.py", line 70, in beam_search
    return bs.apply(visual, out_size, return_probs, **kwargs)
  File "/home/lhxiao/pcl_experiment_202203/meshed-memory-transformer/models/beam_search/beam_search.py", line 71, in apply
    visual, outputs = self.iter(t, visual, outputs, return_probs, **kwargs)
  File "/home/lhxiao/pcl_experiment_202203/meshed-memory-transformer/models/beam_search/beam_search.py", line 121, in iter
    self.model.apply_to_states(self._expand_state(selected_beam, cur_beam_size))
  File "/home/lhxiao/pcl_experiment_202203/meshed-memory-transformer/models/containers.py", line 30, in apply_to_states
    self._buffers[name] = fn(self._buffers[name])
  File "/home/lhxiao/pcl_experiment_202203/meshed-memory-transformer/models/beam_search/beam_search.py", line 26, in fn
    s = torch.gather(s.view(*([self.b_s, cur_beam_size] + shape[1:])), 1,
RuntimeError: gather(): Expected dtype int64 for index


linhuixiao commented Jul 22, 2022

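I have solved this bug.

The cause is the division in models/beam_search/beam_search.py, line 118. On recent PyTorch versions, / on integer tensors performs true division and returns a float tensor, so selected_beam is no longer int64 and torch.gather rejects it as an index. Please replace the division with an explicit floor division: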

        # selected_beam = selected_idx / candidate_logprob.shape[-1]
        selected_beam = torch.div(selected_idx, candidate_logprob.shape[-1], rounding_mode="floor")  
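
To illustrate why floor division is the right operation here, below is a minimal plain-Python sketch of the same index arithmetic. It is an analogue only, not the repository's code: the toy sizes and the helper name split_flat_index are made up for illustration, and Python's // stands in for torch.div(..., rounding_mode="floor"). The sketch assumes the selected index is a flat index over all (beam, word) candidates, so dividing by the vocabulary size recovers the beam it came from.

```python
# Plain-Python analogue of the beam-search index arithmetic (no PyTorch needed).
# Assumed toy sizes for illustration: 3 beams, a vocabulary of 5 words.
beam_size, vocab_size = 3, 5


def split_flat_index(flat_idx, vocab_size):
    """Map a flat index over (beam, word) candidates back to its two parts."""
    beam = flat_idx // vocab_size  # floor division keeps an integer
    word = flat_idx % vocab_size   # remainder is the position inside the vocabulary
    return beam, word


flat_idx = 7                       # one of beam_size * vocab_size = 15 candidates
true_div = flat_idx / vocab_size   # true division yields a float, unusable as an index
beam, word = split_flat_index(flat_idx, vocab_size)

print(true_div, beam, word)
```

The float produced by true division is the plain-Python counterpart of the float tensor that torch.gather refuses, while the floor-divided value stays an integer and can be used as an index.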
