Consider making the decoding classes subclasses of torch.nn.Module and registering them via add_module() #8436

Closed
galv opened this issue Feb 15, 2024 · 9 comments

@galv
Collaborator

galv commented Feb 15, 2024

Is your feature request related to a problem? Please describe.

@artbataev pointed out to me an issue with the following usage when the model is using the CUDA graph RNN-T decoder:

from nemo.collections.asr.models import ASRModel
import torch

asr_model = ASRModel.from_pretrained(model_name, map_location=torch.device("cuda:0"))
asr_model = asr_model.to(torch.device("cuda:1"))

Basically, if I initialize the CUDA decoder with instance-variable buffers (or parameters) that are torch.Tensors on device cuda:0, the to() method won't move them over to device cuda:1, because to() recurses only into members that are themselves torch.nn.Modules: https://github.com/pytorch/pytorch/blob/c3b4d78e175920141de210f44d292971d7c52ff0/torch/nn/modules/module.py#L572

In order to make that behavior work, I would need to make every class that transitively uses an instance of RNNTGreedyDecodeCudaGraph inherit from torch.nn.Module, so that to() acts properly. This seems like a lot of work, but it would allow this API call to work as expected. Otherwise, you get an error as Vladimir shows here: #8191 (comment)
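
To illustrate the problem, here is a minimal sketch (with made-up class names, not the actual NeMo classes) of how a plain attribute tensor escapes to():

import torch
import torch.nn as nn

# Stand-in for a non-Module decoding helper like RNNTGreedyDecodeCudaGraph.
class GreedyDecoderHelper:
    def __init__(self, device):
        self.scores = torch.zeros(16, device=device)  # untracked state tensor

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)
        self.decoding = GreedyDecoderHelper(torch.device("cuda:0"))

model = Model().to(torch.device("cuda:1"))     # requires two GPUs
print(model.linear.weight.device)    # cuda:1 -- moved, it's a registered parameter
print(model.decoding.scores.device)  # still cuda:0 -- to() never recursed into the helper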

.to() is being called transitively here:

instance = instance.to(map_location)
by setup_model(), which is invoked here:
if cfg.cuda is None:
    if torch.cuda.is_available():
        device = [0]  # use 0th CUDA device
        accelerator = 'gpu'
        map_location = torch.device('cuda:0')
    elif cfg.allow_mps and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        logging.warning(
            "MPS device (Apple Silicon M-series GPU) support is experimental."
            " Env variable `PYTORCH_ENABLE_MPS_FALLBACK=1` should be set in most cases to avoid failures."
        )
        device = [0]
        accelerator = 'mps'
        map_location = torch.device('mps')
    else:
        device = 1
        accelerator = 'cpu'
        map_location = torch.device('cpu')
else:
    device = [cfg.cuda]
    accelerator = 'gpu'
    map_location = torch.device(f'cuda:{cfg.cuda}')

logging.info(f"Inference will be done on device: {map_location}")

asr_model, model_name = setup_model(cfg, map_location)

in his transcribe_speech.py command-line invocation.

Basically, any code in NeMo that allocates a torch.Tensor that isn't tracked by PyTorch (i.e., wrapped in a Parameter or registered as a buffer inside a torch.nn.Module) will fail to be converted properly by to(). This could also cause subtle bugs when converting a model from float32 to bfloat16. @erastorgueva-nv I don't think this is what you might be seeing in your Canary debugging, but FYI.
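
For comparison, a minimal sketch (hypothetical class name, for illustration only) of state that is tracked: registering the tensor as a buffer on an nn.Module lets to() handle both device moves and dtype casts.

import torch
import torch.nn as nn

class TrackedDecoderState(nn.Module):  # hypothetical, not a NeMo class
    def __init__(self):
        super().__init__()
        # non-persistent buffer: follows to(), but is not saved in checkpoints
        self.register_buffer("scores", torch.zeros(16), persistent=False)

state = TrackedDecoderState()
state = state.to(torch.device("cuda:1"))  # buffer moves with the module (assumes a second GPU)
state = state.to(torch.bfloat16)          # dtype conversion applies too
print(state.scores.device, state.scores.dtype)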

I ultimately don't think this is a good idea since it seems like a lot of work. I think a better way to fix this is to avoid calling to() at all, and instead have users set the CUDA_VISIBLE_DEVICES environment variable appropriately before starting a process (and remove the cuda=1 config option in transcribe_speech.py). Note that using CUDA_VISIBLE_DEVICES to specify a single GPU and simply using torch.device("cuda"), instead of specifying a device index in code via torch.device("cuda", my_index), is what's recommended: https://pytorch.org/docs/stable/generated/torch.cuda.set_device.html#torch.cuda.set_device
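
For illustration, that workflow would look roughly like this (assuming the process only ever needs a single GPU):

# Shell: pin the process to one physical GPU before starting Python, e.g.
#   CUDA_VISIBLE_DEVICES=1 python transcribe_speech.py ...
# In code, never hard-code a device index; "cuda" resolves to the only visible GPU.
import torch

device = torch.device("cuda")
# asr_model = ASRModel.from_pretrained(model_name, map_location=device)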

Do we have any use for multiple cuda devices in a single process in NeMo?

galv added a commit to galv/NeMo that referenced this issue Feb 15, 2024
Initialize cuda tensors lazily on first call of __call__ instead of __init__.

We don't know what device is going to be used at construction time,
and we can't rely on torch.nn.Module.to() to work here. See here:
NVIDIA#8436

This fixes an error "Expected all tensors to be on the same device,
but found at least two devices" that happens when you call to() on your
torch.nn.Module after constructing it.

NVIDIA#8191 (comment)

Remove excess imports.

Check a few more error messages from nvrtc.

Signed-off-by: Daniel Galvez <dgalvez@nvidia.com>
@titu1994
Collaborator

There's a reason those decoding classes aren't modules. When RNN-T was being developed, it turned out there was a memory leak due to PyTorch graph tracking and the autoregressive calls to the decoder/joint. That's why .freeze() and .as_frozen() were developed in NeMo's core to deal with that.

I don't know if that reason is still valid in PyTorch after the inference_mode() decorator was added; it may be safe to make them modules again.

As you've mentioned, it is a lot of work to do this change, and I don't feel it is super high priority, but if someone wants to take a look that's fine.

To get the device inside the decoding classes, the preferred mechanism is usually next(module.parameters()).device, caching that value inside the inner forward calls.
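
A minimal sketch of that pattern (hypothetical class and attribute names, not the actual NeMo code):

import torch
import torch.nn as nn

class GreedyInfer:  # plain object, deliberately not an nn.Module
    def __init__(self, joint: nn.Module):
        self.joint = joint   # submodule owned by the decoding class
        self._device = None  # cached lazily on first call

    def __call__(self, encoder_output: torch.Tensor) -> torch.Tensor:
        if self._device is None:
            # derive the device from the owned module's parameters
            self._device = next(self.joint.parameters()).device
        # allocate working tensors on the cached device
        blank_scores = torch.zeros(encoder_output.shape[0], device=self._device)
        return blank_scores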

CUDA_VISIBLE_DEVICES environment variable appropriately before starting a process (and remove the cuda=1

This is not a good user experience and I want to avoid it. Users should not have to set any env flags just to use NeMo ASR.

Do we have any use for multiple cuda devices in a single process in NeMo?

Well, if you're training, yes, but that's handled by PTL. During inference we stick to a single GPU for now, but with larger models we may consider doing multi-GPU in the future.

@artbataev
Collaborator

I frequently use model.to("cuda:1") in notebooks and expect this to work (I have multiple GPUs). It is inconvenient to use CUDA_VISIBLE_DEVICES with notebooks.

I think that there are two possible approaches:

  • if the decoder wants to store anything on the GPU (or MPS), it should be responsible for transferring its own state to the device. This is because the *Infer classes are not nn.Module instances (though they hold the Joint/Prednet as submodules). In this case everything will work, but the code's author has to take care of it.
  • redesign the RNN-T to use nn.Modules in the following way (roughly sketched below)
    • DecodingWrapper is an nn.Module instance and owns the Joint and Prednet
    • decoding strategies inherit from DecodingWrapper (e.g., BatchedGreedyDecodingWrapper)
    • the core RNN-T model no longer has joint/prednet members; it reaches them only through DecodingWrapper subclasses
    • so everything inherits from nn.Module, there is no duplicated ownership of joint/prednet, and buffers are automatically moved to the appropriate device
    • changing the decoding strategy means swapping in a different DecodingWrapper subclass (transferring the Joint/Prednet from the current wrapper)

However, such a redesign will break the current checkpoints and will require a lot of work and testing.
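
For concreteness, a rough sketch of the proposed structure (hypothetical names and signatures, not a working NeMo implementation):

import torch
import torch.nn as nn

class DecodingWrapper(nn.Module):
    """Owns the joint and prediction network, so to() moves everything together."""
    def __init__(self, joint: nn.Module, prednet: nn.Module):
        super().__init__()
        self.joint = joint      # registered submodule
        self.prednet = prednet  # registered submodule

class BatchedGreedyDecodingWrapper(DecodingWrapper):
    def __init__(self, joint: nn.Module, prednet: nn.Module):
        super().__init__(joint, prednet)
        # decoding state as a non-persistent buffer: moved by to(), not checkpointed
        self.register_buffer("last_labels", torch.full((1,), -1, dtype=torch.long), persistent=False)

    def forward(self, encoder_output: torch.Tensor) -> torch.Tensor:
        # ... greedy decoding using self.joint and self.prednet ...
        return self.last_labels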

@titu1994
Collaborator

It's only a few lines to manually derive the device (and dtype) from the encoder, decoder, or joint's parameters; we shouldn't make such breaking changes for this.

@titu1994
Collaborator

On top of this, the fact that the parameter count would be double-counted for the decoder and joint with such a redesign is also bad.

@artbataev
Collaborator

I agree that, for now, such breaking changes are undesirable.

On top of this, yes the fact that parameter count would be double counted for decoder and joint with such a redesign is also bad
No, the aim of the concept is exactly to avoid parameter duplication.

Currently, the RNN-T model owns:

  • encoder
  • joint
  • prediction network
  • *Infer (not nn.Module), owns
    • joint (duplicate)
    • prediction network (duplicate)

Proposed, the RNN-T model owns:

  • encoder
  • *DecodingWrapper, owns
    • joint
    • prediction network

@titu1994
Collaborator

titu1994 commented Feb 16, 2024

That's just bad design: why merge the transcription and prediction networks when all the literature denotes them as separate modules? Also, you don't always call the prediction network and decoder with the same set of inputs (training time prepends a blank; at eval time decoding starts with blank as the first token for autoregressive decoding).

This is not a viable proposal, in my opinion.

@titu1994
Collaborator

I don't really get what the big issue is with the decoding framework not being a neural module. Its responsibility is not to act as a parameter-based operation on the NN as part of the forward pass; it's an agnostic layer that provides a stable interface (Hypothesis) to map the encoder/decoder/joint log-probs to text. It's a logical separation, not a module dependency.

I understand that certain issues can arise due to the current design; that's because we're using a more advanced design. However, a very simple solution to this exists, which Daniel has already implemented and which is trivial to do (and is also recommended by PyTorch, by the way: base new tensors on the device of the currently active tensors they will interact with).

I don't see why such a thing requires refactoring the decoding framework. If there's a bug we fix it; we don't scrap the entire thing and redo it all because of the "PyTorch design pattern" (which NeMo does not fully follow by design; we use the PTL design pattern instead).

@galv
Collaborator Author

galv commented Feb 16, 2024

My overall conclusion is that I have encountered an edge case, and the right approach is just to recreate the appropriate state tensors any time the device of the input tensors changes from what was being used before.
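
A minimal sketch of that approach (hypothetical names; the real implementation lives in the RNN-T greedy decoding code):

import torch

class LazyDecoderState:  # hypothetical stand-in for the CUDA-graph decoder state
    def __init__(self):
        self.device = None
        self.labels = None

    def _maybe_reinit(self, device: torch.device, batch_size: int):
        # (re)create state tensors lazily on first call, or when the input device changes
        if self.labels is None or self.device != device or self.labels.shape[0] != batch_size:
            self.device = device
            self.labels = torch.full((batch_size,), -1, dtype=torch.long, device=device)

    def __call__(self, encoder_output: torch.Tensor) -> torch.Tensor:
        self._maybe_reinit(encoder_output.device, encoder_output.shape[0])
        # ... run decoding using self.labels and other state on the right device ...
        return self.labels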

galv added a commit that referenced this issue Feb 26, 2024
* Speed up RNN-T greedy decoding with cuda graphs

This uses CUDA 12.3's conditional node support.

Initialize cuda tensors lazily on first call of __call__ instead of __init__.

We don't know what device is going to be used at construction time,
and we can't rely on torch.nn.Module.to() to work here. See here:
#8436

This fixes an error "Expected all tensors to be on the same device,
but found at least two devices" that happens when you call to() on your
torch.nn.Module after constructing it.

#8191 (comment)

Signed-off-by: Daniel Galvez <dgalvez@nvidia.com>
yaoyu-33 pushed a commit that referenced this issue Feb 26, 2024
@galv
Collaborator Author

galv commented Feb 28, 2024

Closing this. The better way is lazy initialization, given how NeMo is currently structured.

galv closed this as completed Feb 28, 2024
zpx01 pushed a commit to zpx01/NeMo that referenced this issue Mar 8, 2024
JRD971000 pushed a commit that referenced this issue Mar 15, 2024
pablo-garay pushed a commit that referenced this issue Mar 19, 2024
rohitrango pushed a commit to rohitrango/NeMo that referenced this issue Jun 25, 2024