
Recursion error in transformer module with NeMo Stable Diffusion #461

Open
athitten opened this issue May 26, 2024 · 6 comments · May be fixed by #626
Labels
bug (Something isn't working), high priority, nemo (Issues needed to support NVIDIA NeMo models), triage review

Comments

@athitten

athitten commented May 26, 2024

🐛 Bug

NeMo's Stable Diffusion uses CLIPTextModel from HuggingFace transformers. Using thunder.jit with the CLIPTextModel is causing a RecursionError.

To Reproduce

Steps to reproduce the behavior:

  1. Add the following lines to transformers/models/clip/modeling_clip.py (inside the CLIPTextTransformer class, in the location where transformers is installed in your container), together with import thunder at the top of the file:
        # wrap the CLIPTextTransformer submodules with thunder.jit
        self.embeddings = thunder.jit(self.embeddings)
        self.encoder = thunder.jit(self.encoder)
        self.final_layer_norm = thunder.jit(self.final_layer_norm)
  2. Run NeMo Stable Diffusion with the command below:
python examples/multimodal/text_to_image/stable_diffusion/sd_train.py trainer.precision=16 trainer.num_nodes=1 trainer.devices=1 ++exp_manager.max_time_per_run=00:00:03:00 trainer.max_steps=20 model.micro_batch_size=1 model.global_batch_size=1 model.data.synthetic_data=True exp_manager.exp_dir=/workspace/TestData/multimodal/stable_diffusion_train model.inductor=False model.cond_stage_config._target_=nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder ++model.cond_stage_config.version=openai/clip-vit-large-patch14 ++model.cond_stage_config.max_length=77 ~model.cond_stage_config.restore_from_path ~model.cond_stage_config.freeze ~model.cond_stage_config.layer model.unet_config.from_pretrained=null model.first_stage_config.from_pretrained=null model.unet_config.use_flash_attention=False model.unet_config.attention_resolutions=\[1\] model.unet_config.channel_mult=\[1\]

Partial stack trace below:

  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/software/lightning-thunder/thunder/core/module.py", line 49, in forward
    res = self._forward_fn(*args, **kwargs)
  File "/workspace/software/lightning-thunder/thunder/__init__.py", line 617, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/workspace/software/lightning-thunder/thunder/__init__.py", line 202, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/workspace/software/lightning-thunder/thunder/__init__.py", line 540, in get_computation_and_inputs
    autocast(computation_trc.python_callable(), dtype=autocast_thunder_dtype), *inps
  File "/workspace/software/lightning-thunder/thunder/core/trace.py", line 437, in python_callable
    python_str = self.python(**kwargs)
  File "/workspace/software/lightning-thunder/thunder/core/trace.py", line 318, in python
    import_ctx, call_ctx, object_ctx = self._gather_ctxs()
  File "/workspace/software/lightning-thunder/thunder/core/trace.py", line 281, in _gather_ctxs
    bsym_import_ctx, bsym_call_ctx, bsym_object_ctx = bsym.gather_ctxs()
  File "/workspace/software/lightning-thunder/thunder/core/symbol.py", line 580, in gather_ctxs
    return self.import_ctx(), self._get_call_ctx(), self.object_ctx()
  File "/workspace/software/lightning-thunder/thunder/core/symbol.py", line 520, in import_ctx
    self._out_printables, self._arg_printables, self._kwarg_printables  # type: ignore
  File "/workspace/software/lightning-thunder/thunder/core/symbol.py", line 472, in _out_printables
    return codeutils.to_printable(trace, self.output, import_ctx=self._import_ctx, object_ctx=self._object_ctx)
  File "/workspace/software/lightning-thunder/thunder/core/codeutils.py", line 128, in to_printable
    printables.append(to_printable(trace, f, import_ctx=import_ctx, object_ctx=object_ctx))
  File "/workspace/software/lightning-thunder/thunder/core/codeutils.py", line 128, in to_printable
    printables.append(to_printable(trace, f, import_ctx=import_ctx, object_ctx=object_ctx))
  File "/workspace/software/lightning-thunder/thunder/core/codeutils.py", line 128, in to_printable
    printables.append(to_printable(trace, f, import_ctx=import_ctx, object_ctx=object_ctx))
  [Previous line repeated 2899 more times]
  File "/workspace/software/lightning-thunder/thunder/core/codeutils.py", line 123, in to_printable
    if is_collection(x):
  File "/workspace/software/lightning-thunder/thunder/core/baseutils.py", line 153, in is_collection
    return isinstance(x, collections.abc.Collection) and not isinstance(x, (str, torch.Tensor, np.ndarray))
  File "/usr/lib/python3.10/abc.py", line 117, in __instancecheck__
    def __instancecheck__(cls, instance):
  File "/root/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_trace_dispatch_regular.py", line 469, in __call__
    return None if event == 'call' else NO_FTRACE
RecursionError: maximum recursion depth exceeded in comparison

CC: @tfogal

cc @apaz-cli @tfogal

athitten added the bug (Something isn't working) label May 26, 2024
riccardofelluga added the nemo (Issues needed to support NVIDIA NeMo models) label May 27, 2024
@athitten
Author

athitten commented May 30, 2024

FYI: I just figured out that self.encoder consists of an nn.ModuleList iterated in a for loop (shown below), which probably caused the recursion error.
self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])

Adding thunder.jit to the individual modules of the for loop instead of the entire nn.ModuleList fixed the RecursionError.
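
For reference, here is a minimal, self-contained sketch of that workaround with a toy encoder standing in for CLIPEncoder (the layer type and sizes are made up for illustration): each layer in the nn.ModuleList is jitted individually, and the Python for loop stays outside of thunder.

import torch
import torch.nn as nn
import thunder


class ToyEncoder(nn.Module):
    def __init__(self, hidden=16, num_layers=3):
        super().__init__()
        # Jit each layer on its own rather than jitting the module that owns
        # the ModuleList; the for loop in forward() remains plain Python.
        self.layers = nn.ModuleList(
            [thunder.jit(nn.Linear(hidden, hidden)) for _ in range(num_layers)]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


enc = ToyEncoder()
print(enc(torch.randn(2, 16)).shape)  # torch.Size([2, 16])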

@tfogal
Collaborator

tfogal commented May 30, 2024

Thanks @athitten! That's really helpful.

Tagging triage review. Triage team, beyond the obvious "add support for control flow", I'm curious what our options are here.

@t-vi
Collaborator

t-vi commented May 30, 2024

Staring down the traceback (rather than running it myself), it does not look like the modules themselves are the problem (litgpt also uses a for loop over an nn.ModuleList). Rather, it looks as if we have a trace that fails to print itself because of some reference cycle (which might be caused by the interpreter erroneously inserting it into the trace).

@k223kim
Contributor

k223kim commented Jun 11, 2024

Hi Team ⚡️, I am currently working on this issue and would like to share how I reproduced the same error (as a reference for anyone else working on it). It is quite similar to the code shown above, just smaller :)

  1. Clone Hugging Face transformers and install it from source:
git clone https://github.com/huggingface/transformers.git
cd transformers
git checkout tags/v4.41.2
pip install -e .
  2. As mentioned above, add this to the CLIPTextTransformer class:
        self.embeddings = thunder.jit(self.embeddings)
        self.encoder = thunder.jit(self.encoder)
        self.final_layer_norm = thunder.jit(self.final_layer_norm)
  3. Run the following script (it is taken from the Hugging Face repo):
from transformers import CLIPTokenizer, CLIPTextModel

model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
pooled_output = outputs.pooler_output  # pooled (EOS token) states
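
Not part of the original repro, but possibly a convenient alternative to patching the installed sources (a sketch; it assumes the jitted submodules can simply be reassigned on the loaded model, with model.text_model being the CLIPTextTransformer instance):

import thunder
from transformers import CLIPTokenizer, CLIPTextModel

model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

# Reassign the jitted submodules on the loaded model instead of editing
# modeling_clip.py; model.text_model is the CLIPTextTransformer instance.
text_model = model.text_model
text_model.embeddings = thunder.jit(text_model.embeddings)
text_model.encoder = thunder.jit(text_model.encoder)
text_model.final_layer_norm = thunder.jit(text_model.final_layer_norm)

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
outputs = model(**inputs)  # expected to hit the same RecursionError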

(cc. @t-vi )

@t-vi
Collaborator

t-vi commented Jun 18, 2024

@k223kim debugged this further: the infinite recursion comes from to_printable assuming that tree_flatten will "simplify" its input, when in reality it returns the original input as one of the flattened objects for BaseModelOutput (a dataclass from transformers, see https://github.com/huggingface/transformers/blob/9ba9369a2557e53a01378199a9839ec6e82d8bc7/src/transformers/modeling_outputs.py#L24-L47).

def to_printable(
    trace: Optional,
    x: Any,
    *,
    import_ctx: Optional[dict] = None,
    object_ctx: Optional[dict] = None,
) -> Printable:
    # Short-circuits if x is a Proxy
    if isinstance(x, ProxyInterface):
        return x

    if is_collection(x):
        flat, spec = tree_flatten(x)

        printables = []
        for f in flat:
            printables.append(to_printable(trace, f, import_ctx=import_ctx, object_ctx=object_ctx))

        printable = tree_unflatten(printables, spec)
        return printable
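
A minimal, self-contained illustration of that pattern (the class name is made up and torch.utils._pytree stands in for thunder's flatten, so this is not thunder's actual code path): a container type the pytree registry cannot open is returned unchanged as its own single leaf, so a printer that recurses over the flattened elements never makes progress.

from collections import OrderedDict

from torch.utils import _pytree as pytree


class UnregisteredOutput(OrderedDict):
    # Stands in for a ModelOutput-style container the pytree registry does not know.
    pass


out = UnregisteredOutput(last_hidden_state=1.0)

# The object is a collections.abc.Collection, so is_collection(out) is True,
# yet tree_flatten treats it as a single leaf: the "flattened" result still
# contains the original object itself.
flat, spec = pytree.tree_flatten(out)
print(flat[0] is out)  # True

# A routine shaped like to_printable above would therefore recurse on `out`
# again and again until Python raises RecursionError.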

@t-vi
Collaborator

t-vi commented Jun 19, 2024

A more minimal repro that can be used to create a test alongside a fix:

import transformers
import torch
import thunder

def fn(x):
    return transformers.modeling_outputs.BaseModelOutput(x)

jfn = thunder.jit(fn)

x = torch.randn(5, 5)

print(jfn(x))

k223kim linked a pull request Jun 20, 2024 that will close this issue