Seq2Seq Example with GPU Support #60

Closed
bigabig opened this issue Jun 25, 2019 · 9 comments
Labels: bug, priority: high

Comments

@bigabig

bigabig commented Jun 25, 2019

Hello,

How can I run the Seq2Seq example with my GPU?

I already modified the training data to use the cuda device as well as the model:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 

train_data = tx.data.PairedTextData(hparams=config_data.train, device=device)
val_data = tx.data.PairedTextData(hparams=config_data.val, device=device)
test_data = tx.data.PairedTextData(hparams=config_data.test, device=device)

model = Seq2SeqAttn(train_data)
model.to(device)
@huzecong
Collaborator

Hi, thank you for your interest in texar-pytorch! Your training settings seem correct to me. Did you encounter any errors during training? If so, could you describe them (e.g. a stack trace or error messages)?

@huzecong added the "question" label Jun 25, 2019
@bigabig
Author

bigabig commented Jun 25, 2019

Hey,

this is the error:

(env) C:\Development\Git\texar-pytorch\examples\seq2seq_attn>python seq2seq_attn.py --config_model config_model --config_data config_toy_copy
Traceback (most recent call last):
  File "seq2seq_attn.py", line 190, in <module>
    main()
  File "seq2seq_attn.py", line 176, in main
    _train_epoch()
  File "seq2seq_attn.py", line 142, in _train_epoch
    loss = model(batch, mode="train")
  File "C:\Development\pytorch\env\lib\site-packages\torch\nn\modules\module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "seq2seq_attn.py", line 94, in forward
    sequence_length=batch['target_length'] - 1)
  File "C:\Development\pytorch\env\lib\site-packages\torch\nn\modules\module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "c:\development\git\texar-pytorch\texar\modules\decoders\rnn_decoders.py", line 623, in forward
    max_decoding_length, impute_finished)
  File "c:\development\git\texar-pytorch\texar\modules\decoders\decoder_base.py", line 385, in dynamic_decode
    decoder_finished) = self.step(helper, time, step_inputs, state)
  File "c:\development\git\texar-pytorch\texar\modules\decoders\rnn_decoders.py", line 494, in step
    inputs, state, self.memory, self.memory_sequence_length)
  File "C:\Development\pytorch\env\lib\site-packages\torch\nn\modules\module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "c:\development\git\texar-pytorch\texar\core\cell_wrappers.py", line 706, in forward
    memory_sequence_length=memory_sequence_length)
  File "c:\development\git\texar-pytorch\texar\core\attention_mechanism.py", line 917, in compute_attention
    attention = attention_layer(torch.cat((cell_output, context), dim=1))
  File "C:\Development\pytorch\env\lib\site-packages\torch\nn\modules\module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Development\pytorch\env\lib\site-packages\torch\nn\modules\linear.py", line 92, in forward
    return F.linear(input, self.weight, self.bias)
  File "C:\Development\pytorch\env\lib\site-packages\torch\nn\functional.py", line 1408, in linear
    output = input.matmul(weight.t())
RuntimeError: Expected object of backend CUDA but got backend CPU for argument #2 'mat2'

So somewhere I am missing a ".cuda()" call. I looked into the code a bit: the linear layer (torch.nn.Linear) initializes a weight matrix that is not a CUDA tensor. I am not sure whether this is the actual problem, though, as I have only just started with PyTorch.

self.weight = Parameter(torch.Tensor(out_features, in_features))

@bigabig
Author

bigabig commented Jun 25, 2019

When I change the weight initialization in torch.nn.Linear like this (though I don't think that's a good idea anyway):

self.weight = Parameter(torch.Tensor(out_features, in_features).cuda())

I get the following error after 312 steps:

Traceback (most recent call last):
  File "seq2seq_attn.py", line 190, in <module>
    main()
  File "seq2seq_attn.py", line 178, in main
    val_bleu = _eval_epoch('val')
  File "seq2seq_attn.py", line 159, in _eval_epoch
    infer_outputs = model(batch, mode="infer")
  File "C:\Development\pytorch\env\lib\site-packages\torch\nn\modules\module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "seq2seq_attn.py", line 115, in forward
    max_decoding_length=60)
  File "C:\Development\pytorch\env\lib\site-packages\torch\nn\modules\module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "c:\development\git\texar-pytorch\texar\modules\decoders\rnn_decoders.py", line 623, in forward
    max_decoding_length, impute_finished)
  File "c:\development\git\texar-pytorch\texar\modules\decoders\decoder_base.py", line 385, in dynamic_decode
    decoder_finished) = self.step(helper, time, step_inputs, state)
  File "c:\development\git\texar-pytorch\texar\modules\decoders\rnn_decoders.py", line 502, in step
    sample_ids=sample_ids)
  File "c:\development\git\texar-pytorch\texar\modules\decoders\decoder_helpers.py", line 324, in next_inputs
    finished = (sample_ids == self._end_token)
RuntimeError: Expected object of backend CUDA but got backend CPU for argument #2 'other'

@huzecong
Collaborator

I see. This is indeed a bug in texar-pytorch. In lines 539-543 of texar.core.cell_wrappers, the created nn.Linear layers are stored in a tuple, so they're not recognized by PyTorch. The correct implementation should store the layers in an nn.ModuleList.

We're going to fix this ASAP. In the meantime, you can change these lines into:

self._attention_layers = nn.ModuleList(
    nn.Linear(attention_mechanisms[i].encoder_output_size +
              cell.hidden_size,
              attention_layer_sizes[i],
              False) for i in range(len(attention_layer_sizes)))

to work around the issue.
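
For anyone wondering why the container type matters, here is a minimal sketch of the effect. The two classes are hypothetical stand-ins, not the texar code: layers held in a plain tuple are not registered as sub-modules, so `parameters()` and `model.to(device)` never see them, while layers held in an `nn.ModuleList` are registered.

```python
import torch
from torch import nn


class TupleLayers(nn.Module):
    """Hypothetical module that keeps its sub-layers in a plain tuple."""

    def __init__(self):
        super().__init__()
        # A tuple is an ordinary Python attribute, so PyTorch does not
        # register the layers inside it as sub-modules.
        self.layers = tuple(nn.Linear(4, 4, bias=False) for _ in range(2))


class ListLayers(nn.Module):
    """Hypothetical module that keeps its sub-layers in an nn.ModuleList."""

    def __init__(self):
        super().__init__()
        # nn.ModuleList registers every layer, so .to(device), .parameters()
        # and state_dict() all include them.
        self.layers = nn.ModuleList(
            nn.Linear(4, 4, bias=False) for _ in range(2))


tuple_model = TupleLayers()
list_model = ListLayers()

# Count the parameters PyTorch knows about in each model.
num_tuple_params = len(list(tuple_model.parameters()))
num_list_params = len(list(list_model.parameters()))
print(num_tuple_params, num_list_params)
```

Since the tuple version exposes no parameters, calling `model.to(device)` leaves its weights on the CPU, which would match the "got backend CPU" error in the first traceback.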

@gpengzhi Can you fix this issue soon?

@huzecong added the "bug" label and removed the "question" label Jun 25, 2019
@huzecong
Collaborator

... and the error after 312 steps is a separate problem. It occurred during validation, where the model creates an "infer_greedy" helper. The end_token passed into the helper should be on the same device (i.e. the GPU) as the other tensors, but it wasn't.

To fix this, change line 107 of seq2seq_attn.py to:

end_token=start_tokens.new_full((1,), self.eos_token_id))

@bigabig
Author

bigabig commented Jun 25, 2019

Hey, so I tried

            helper_infer = self.decoder.create_helper(
                decoding_strategy="infer_greedy",
                embedding=self.target_embedder,
                start_tokens=start_tokens,
                end_token=start_tokens.new_full((1,), self.eos_token_id))

and I got this error:

step=312, loss=0.6623
Traceback (most recent call last):
  File "seq2seq_attn.py", line 190, in <module>
    main()
  File "seq2seq_attn.py", line 178, in main
    val_bleu = _eval_epoch('val')
  File "seq2seq_attn.py", line 159, in _eval_epoch
    infer_outputs = model(batch, mode="infer")
  File "C:\Development\pytorch\env\lib\site-packages\torch\nn\modules\module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "seq2seq_attn.py", line 109, in forward
    end_token=start_tokens.new_full((1,), self.eos_token_id))
  File "c:\development\git\texar-pytorch\texar\modules\decoders\decoder_base.py", line 292, in create_helper
    embedding, start_tokens, end_token)
  File "c:\development\git\texar-pytorch\texar\modules\decoders\decoder_helpers.py", line 300, in __init__
    super().__init__(start_tokens, end_token)
  File "c:\development\git\texar-pytorch\texar\modules\decoders\decoder_helpers.py", line 248, in __init__
    raise ValueError("end_token must be a scalar")
ValueError: end_token must be a scalar

@gpengzhi
Collaborator

gpengzhi commented Jun 25, 2019

Thank you for your feedback. You could try end_token=start_tokens.new_tensor(self.eos_token_id) instead. start_tokens.new_full((1,), self.eos_token_id) returns a 1-dimensional tensor rather than the 0-dimensional (scalar) tensor the helper expects.
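
A small sketch of the difference between the two calls. The values of `start_tokens` and `eos_token_id` are made up for illustration. Both methods create the new tensor with the dtype and device of `start_tokens`, but only `new_tensor` with a Python scalar produces a 0-dimensional result.

```python
import torch

# Hypothetical stand-ins for the values used in seq2seq_attn.py.
start_tokens = torch.tensor([1, 1, 1])
eos_token_id = 2

# new_full with size (1,) yields a 1-dimensional tensor with one element.
one_dim = start_tokens.new_full((1,), eos_token_id)

# new_tensor with a Python scalar yields a 0-dimensional (scalar) tensor.
scalar = start_tokens.new_tensor(eos_token_id)

print(one_dim.dim(), tuple(one_dim.shape))
print(scalar.dim(), tuple(scalar.shape))
```

Because the new tensor follows `start_tokens`, it should end up on the GPU whenever the batch is on the GPU, so no explicit `.cuda()` call is needed.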

@bigabig
Author

bigabig commented Jun 25, 2019

Thanks, this seems to fix this issue. However, now I get the following error:

step=312, loss=2.8005
Traceback (most recent call last):
  File "seq2seq_attn.py", line 190, in <module>
    main()
  File "seq2seq_attn.py", line 178, in main
    val_bleu = _eval_epoch('val')
  File "seq2seq_attn.py", line 165, in _eval_epoch
    ids=output_ids, vocab=val_data.target_vocab)
  File "c:\development\git\texar-pytorch\texar\data\vocabulary.py", line 312, in map_ids_to_strs
    tokens = vocab.map_ids_to_tokens_py(ids)
  File "c:\development\git\texar-pytorch\texar\data\vocabulary.py", line 166, in map_ids_to_tokens_py
    return dict_lookup(self.id_to_token_map_py, ids, self.unk_token)
  File "c:\development\git\texar-pytorch\texar\utils\utils.py", line 654, in dict_lookup
    return np.vectorize(lambda x: dict_.get(x, default))(keys)  # type: ignore
  File "C:\Development\pytorch\env\lib\site-packages\numpy\lib\function_base.py", line 2091, in __call__
    return self._vectorize_call(func=func, args=vargs)
  File "C:\Development\pytorch\env\lib\site-packages\numpy\lib\function_base.py", line 2161, in _vectorize_call
    ufunc, otypes = self._get_ufunc_and_otypes(func=func, args=args)
  File "C:\Development\pytorch\env\lib\site-packages\numpy\lib\function_base.py", line 2115, in _get_ufunc_and_otypes
    args = [asarray(arg) for arg in args]
  File "C:\Development\pytorch\env\lib\site-packages\numpy\lib\function_base.py", line 2115, in <listcomp>
    args = [asarray(arg) for arg in args]
  File "C:\Development\pytorch\env\lib\site-packages\numpy\core\numeric.py", line 538, in asarray
    return array(a, dtype, copy=False, order=order)
  File "C:\Development\pytorch\env\lib\site-packages\torch\tensor.py", line 458, in __array__
    return self.numpy()
TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

@gpengzhi
Collaborator

Thank you for the feedback. This error occurs when a tensor stored on the GPU (CUDA) is converted to a NumPy array. Use output_ids = infer_outputs.sample_id.cpu() to copy the tensor back to the CPU first.
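
As a rough illustration, the tensor below is a made-up stand-in for `infer_outputs.sample_id`. Calling `.cpu()` before the NumPy conversion works whether the tensor lives on the GPU or is already on the CPU, so the same line can stay in place for both setups.

```python
import numpy as np
import torch

# Use the GPU if one is available; the snippet also runs on a CPU-only machine.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for infer_outputs.sample_id.
sample_id = torch.tensor([[4, 5, 6], [7, 8, 9]], device=device)

# .cpu() copies a CUDA tensor to host memory (and returns the tensor itself
# if it is already on the CPU), after which .numpy() is allowed.
ids = sample_id.cpu().numpy()
print(type(ids), ids.shape)
```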

@huzecong added the "priority: high" label Jun 25, 2019
3 participants