
Support memory_format on to() #144

Closed
Tracked by #179
tfogal opened this issue Apr 9, 2024 · 1 comment · Fixed by #157
Labels
enhancement New feature or request MegatronImagen Needed to support NeMo's MegatronImagen model (text to image generation) nemo Issues needed to support NVIDIA NeMo models.

Comments

@tfogal
Collaborator

tfogal commented Apr 9, 2024

🚀 Feature

The MegatronImagen model in NeMo calls to(memory_format=something), and thunder does not support that keyword yet.

Ideally, this would work:

$ git diff .
diff --git a/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py b/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py
index 4fa6cd230..2cf7a8ffa 100644
--- a/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py
+++ b/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py
@@ -31,6 +31,7 @@ from nemo.collections.nlp.modules.common.megatron.module import Float16Module
 from nemo.collections.nlp.parts.utils_funcs import get_last_rank
 from nemo.core.classes.common import Serialization
 from nemo.utils import logging
+import thunder
 
 try:
     from apex import amp
@@ -190,6 +191,7 @@ class MegatronImagen(MegatronBaseModel):
         self.megatron_amp_O2 = cfg.get('megatron_amp_O2', False)
 
         self.model = self.model_provider_func()
+        self.model = thunder.jit(self.model)
 
         if self.trainer.precision in ['bf16', 'bf16-mixed']:
             self.autocast_dtype = torch.bfloat16

Motivation

I'm trying to evaluate NeMo models in thunder and expand thunder's model support. Megatron-based models appear to be widely used.

Alternatives

I wonder if we could temporarily just accept the keyword without actually doing anything about it. I imagine that would be very slow, but it might allow us to get models like this one into thunder more easily.
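A minimal sketch of that stopgap in plain Python: a wrapper that accepts `memory_format`, warns that it is being ignored, and forwards every other argument to the wrapped `to`. `ignore_memory_format` and `fake_to` are hypothetical names used only for illustration; they are not thunder APIs. Because `memory_format` only affects the physical layout, dropping it should leave the tensor's values unchanged, although kernels that expect the requested layout may then run slower.

```python
import warnings
from typing import Any, Callable


def ignore_memory_format(to_fn: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical shim: accept memory_format, warn, and drop it."""

    def wrapper(*args: Any, memory_format: Any = None, **kwargs: Any) -> Any:
        if memory_format is not None:
            # Layout request is silently discarded; only a warning is emitted.
            warnings.warn(f"memory_format={memory_format!r} is ignored", stacklevel=2)
        return to_fn(*args, **kwargs)

    return wrapper


def fake_to(x: Any, dtype: Any = None) -> Any:
    """Hypothetical stand-in for a to() that does not know memory_format."""
    return (x, dtype)


lenient_to = ignore_memory_format(fake_to)
```

Calling `lenient_to("x", dtype="float32", memory_format="channels_last")` returns `("x", "float32")` and emits one warning instead of raising on the unknown keyword.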

I'll start trying to convert smaller parts of the model next.

Additional context

Model in question:

https://github.com/NVIDIA/NeMo/blob/23baa48e441ecb6cc6b49c23bf8cfc076db38bdc/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py#L175

I think the to() call that is failing for me is actually this line:
https://github.com/NVIDIA/NeMo/blob/23baa48e441ecb6cc6b49c23bf8cfc076db38bdc/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py#L135

Model test:
log.txt

@tfogal tfogal added enhancement New feature or request help wanted Extra attention is needed triage review labels Apr 9, 2024
@jjsjann123 jjsjann123 self-assigned this Apr 10, 2024
@jjsjann123
Collaborator

I think we just didn't add the memory_format argument to to() in our thunder/torch/__init__.py.

We should be able to map it easily to the stride_order prim.
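A sketch of such a mapping, assuming `memory_format` arrives as one of the `torch.memory_format` member names and that a stride order is written as a permutation of dims from outermost (largest stride) to innermost (stride 1). That convention is an assumption and may differ from the one thunder's stride_order prim actually uses; `memory_format_to_order` and `strides_for_order` are hypothetical helpers, not thunder code. For a dense NCHW tensor of shape (2, 3, 4, 5), `channels_last` yields the NHWC strides (60, 1, 15, 3), the same strides PyTorch reports for `x.to(memory_format=torch.channels_last).stride()`.

```python
from typing import Optional, Sequence, Tuple

# Illustrative assumption, not thunder's implementation: each entry lists
# logical dims from outermost (largest stride) to innermost (stride 1).
_STRIDE_ORDERS = {
    "channels_last": (0, 2, 3, 1),        # NCHW stored as NHWC
    "channels_last_3d": (0, 2, 3, 4, 1),  # NCDHW stored as NDHWC
}


def memory_format_to_order(memory_format: str, ndim: int) -> Optional[Tuple[int, ...]]:
    """Return an outer-to-inner dim order, or None to keep the input's layout."""
    if memory_format == "preserve_format":
        return None
    if memory_format == "contiguous_format":
        return tuple(range(ndim))  # plain row-major order
    order = _STRIDE_ORDERS.get(memory_format)
    if order is None:
        raise ValueError(f"unsupported memory_format: {memory_format}")
    if len(order) != ndim:
        raise ValueError(f"{memory_format} requires a {len(order)}-D tensor, got {ndim}-D")
    return order


def strides_for_order(shape: Sequence[int], order: Sequence[int]) -> Tuple[int, ...]:
    """Dense strides for `shape` when dims are laid out in memory in `order`."""
    strides = [0] * len(shape)
    step = 1
    for dim in reversed(order):  # walk from the innermost dim outward
        strides[dim] = step
        step *= shape[dim]
    return tuple(strides)
```

Keeping the mapping separate from the stride computation means new formats only need a new table entry, while `preserve_format` can short-circuit to the input tensor's existing layout.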

I'll take a stab at this one.

@tfogal tfogal added nemo Issues needed to support NVIDIA NeMo models. MegatronImagen Needed to support NeMo's MegatronImagen model (text to image generation) and removed triage review help wanted Extra attention is needed labels Apr 12, 2024
@t-vi t-vi closed this as completed in #157 Apr 13, 2024