WIP: Make LLaMA torch.compile compatible #103

Closed
wants to merge 24 commits

Conversation

awaelchli (Contributor) commented Apr 6, 2023

Builds on top of #100
An attempt to address #62 by introducing a switch between the complex and non-complex variants of the RoPE implementation:

  • complex: used to stay numerically in parity with Meta's model and checkpoints (inference)
  • non-complex: used for training and finetuning

Drawback: the switch makes the code less readable and introduces indirection. The reader may struggle to understand the nuance here. If this goes too much against our minimalistic principles, we should drop the idea. Suggestions welcome!
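For illustration, here is a minimal sketch of what such a switch could look like (hypothetical function and argument names, not the actual code in this PR):

    import torch

    def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor, use_complex: bool = False) -> torch.Tensor:
        # x: (..., head_size); rope_cache: (..., head_size // 2, 2) holding (cos, sin) pairs
        if use_complex:
            # complex path: pack adjacent pairs into complex numbers and rotate by multiplication
            xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
            rc = torch.view_as_complex(rope_cache.float())
            out = torch.view_as_real(xc * rc).flatten(-2)
        else:
            # real-valued path: the same rotation written with explicit cos/sin,
            # avoiding the complex dtypes that torch.compile could not trace at the time
            x1, x2 = x[..., 0::2], x[..., 1::2]
            cos, sin = rope_cache[..., 0], rope_cache[..., 1]
            out = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1).flatten(-2)
        return out.type_as(x)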

Adapter benchmarks:

  • no compile + dynamic padding: 90-160 ms
  • no compile + full padding: 352 ms
  • torch.compile + full padding: 118 ms

lantiga (Collaborator) commented Apr 6, 2023

I'm not that worried about the added conditionals; they look localized to the RoPE call.
Maybe we should just create a rope.py file that contains everything RoPE-related, so model.py is less cluttered.

ipoletaev commented

"~1 min vs. 120 ms"

Something is surely off. Compilation almost never hurts, if only because it optimizes the graph, making it smaller, which makes CUDA allocations faster, etc.

I'm super curious about what is not working here.

For me, in FSDP mode (without activation checkpointing, which doesn't work with compilation), torch.compile() decreases step time by ~30-40% for the 7B model.
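For reference, a minimal sketch of that setup, under the assumption that a distributed process group is already initialized and no activation checkpointing is applied:

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def shard_and_compile(model: torch.nn.Module) -> torch.nn.Module:
        # shard parameters, gradients and optimizer state across ranks;
        # activation checkpointing is deliberately left out, since it did not
        # compose with torch.compile at the time of this discussion
        sharded = FSDP(model)
        # compile the sharded module; compilation is triggered lazily on the first forward pass
        return torch.compile(sharded)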

lantiga (Collaborator) commented Apr 20, 2023

Needs to be updated after #174 lands.
The slowdown seems to be related to the use of dynamic shapes. We'll resume this shortly.

t-vi (Contributor) commented Apr 20, 2023

For dynamic shapes and quantization, I did ad-hoc padding in #173; if we fix this here, I should also update #173.

lantiga (Collaborator) commented May 1, 2023

What's the intention with this one?

awaelchli (Contributor, Author) commented

I'm closing the PR because adding torch.compile to our scripts is not worth it at the moment. In summary:

  • generate.py: the first iterations are slow and negate what we win with torch.compile.
  • train.py: torch.compile does not work with activation checkpointing in FSDP.
  • finetune_adapter.py: requires making the input shapes constant. By doing so (with padding) we increase the overall input length and iteration time over the baseline. Adding torch.compile on top compensates for this, but we gain almost nothing compared to the current dynamic sequence-length version (see the padding sketch after this list).
  • finetune_lora.py: does not compile (I will investigate separately), and has the same issue as finetune_adapter.py.
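For context, a rough sketch of the kind of fixed-shape padding mentioned for finetune_adapter.py (a hypothetical helper, assuming right-padding with a pad token id up to a fixed block size so that input shapes stay constant for torch.compile):

    import torch

    def pad_right_to_block_size(input_ids: torch.Tensor, block_size: int, pad_id: int = 0) -> torch.Tensor:
        # input_ids: (batch, seq_len) with seq_len <= block_size; padding every batch
        # to the same block_size keeps shapes constant so torch.compile does not
        # recompile, at the cost of also computing over the padded positions
        batch, seq_len = input_ids.shape
        padded = torch.full((batch, block_size), pad_id, dtype=input_ids.dtype, device=input_ids.device)
        padded[:, :seq_len] = input_ids
        return padded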

awaelchli closed this May 4, 2023
ezyang commented May 9, 2023

@awaelchli did you try dynamic=True at any point? That would remove the need for padding. In our experiments with a different variant of llama, generation ends up speeding up 1.3x after compilation. (Perhaps the problem is that we need to cache compilation products across process invocations; this is something we know about.)
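For clarity, dynamic=True refers to the torch.compile flag that marks input shapes as dynamic up front, roughly:

    import torch

    model = torch.nn.Linear(4096, 4096)            # stand-in module, just for illustration
    compiled = torch.compile(model, dynamic=True)  # treat shapes as dynamic to avoid a recompile per sequence length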

awaelchli (Contributor, Author) commented

@ezyang Thanks for your suggestion. Yes, I did try it. For the generate.py inference script, it took a very long time to compile the iterations even with dynamic=True (378.39 sec total, 0.13 tokens/sec), which is roughly 100x slower than without compilation.

And for the finetune_adapter.py script with dynamic=True I got a strange error that I didn't understand:

Traceback (most recent call last):
  File "/home/adrian/repositories/lightning-llama/finetune_adapter.py", line 258, in <module>
    CLI(main)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/jsonargparse/cli.py", line 82, in CLI
    return _run_component(component, cfg_init)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/jsonargparse/cli.py", line 138, in _run_component
    return component(**cfg)
  File "/home/adrian/repositories/lightning-llama/finetune_adapter.py", line 101, in main
    train(fabric, model, optimizer, train_data, val_data, out_dir)
  File "/home/adrian/repositories/lightning-llama/finetune_adapter.py", line 132, in train
    logits = model(input_ids)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/lightning/fabric/wrappers.py", line 109, in forward
    output = self._forward_module(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 286, in _fn
    return fn(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 439, in catch_errors
    return callback(frame, cache_size, hooks, frame_state)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 519, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 122, in _fn
    return fn(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 355, in _convert_frame_assert
    return _compile(
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 177, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 425, in _compile
    out_code = transform_code_object(code, transform)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
    transformations(instructions, code_options)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 410, in transform
    tracer.run()
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2010, in run
    super().run()
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 703, in run
    and self.step()
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 663, in step
    getattr(self, inst.opname)(inst)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2098, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 736, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 813, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 177, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 872, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 868, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 108, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/__init__.py", line 1531, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 590, in compile_fx
    return compile_fx(
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 700, in compile_fx
    return aot_autograd(
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 55, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 3334, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 177, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2975, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1911, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2082, in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2485, in aot_dispatch_autograd
    fw_module, bw_module = aot_config.partition_fn(
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 671, in partition_fn
    return min_cut_rematerialization_partition(
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/torch/_functorch/partitioners.py", line 572, in min_cut_rematerialization_partition
    cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink")
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/networkx/algorithms/flow/maxflow.py", line 450, in minimum_cut
    R = flow_func(flowG, _s, _t, capacity=capacity, value_only=True, **kwargs)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/networkx/algorithms/flow/preflowpush.py", line 421, in preflow_push
    R = preflow_push_impl(G, s, t, capacity, residual, global_relabel_freq, value_only)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/networkx/algorithms/flow/preflowpush.py", line 41, in preflow_push_impl
    detect_unboundedness(R, s, t)
  File "/home/adrian/.conda/envs/lit-llama/lib/python3.10/site-packages/networkx/algorithms/flow/utils.py", line 166, in detect_unboundedness
    raise nx.NetworkXUnbounded(
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
NetworkXUnbounded: Infinite capacity path, flow unbounded above.


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

(using torch nightly)

If it's helpful, I could produce a minimized version for further debugging.

ezyang commented May 16, 2023

I'm happy to run the lit-llama repo, but I was unable to get the OpenLLAMA weight download instructions working:

│ /data/users/ezyang/a/pytorch/torch/utils/_contextlib.py:115 in decorate_context                  │         
│                                                                                                  │         
│   112 │   @functools.wraps(func)                                                                 │         
│   113 │   def decorate_context(*args, **kwargs):                                                 │         
│   114 │   │   with ctx_factory():                                                                │         
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │         
│   116 │                                                                                          │         
│   117 │   return decorate_context                                                                │         
│   118                                                                                            │         
│                                                                                                  │         
│ /data/users/ezyang/a/lit-llama/scripts/convert_hf_checkpoint.py:34 in convert_hf_checkpoint      │         
│                                                                                                  │
│    31 │   output_dir.mkdir(parents=True, exist_ok=True)                                          │         
│    32 │                                                                                          │         
│    33 │   # the tokenizer is the same for all model sizes, so we store it in the parent dir      │         
│ ❱  34 │   shutil.copy(ckpt_dir / "tokenizer.model", output_dir.parent)                           │         
│    35 │                                                                                          │
│    36 │   dt = getattr(torch, dtype, None)                                                       │         
│    37 │   if not isinstance(dt, torch.dtype):                                                    │         
│                                                                                                  │         
│ /home/ezyang/local/a/pytorch-env/lib/python3.9/shutil.py:427 in copy                             │         
│                                                                                                  │         
│    424 │   """                                                                                   │         
│    425 │   if os.path.isdir(dst):                                                                │       
│    426 │   │   dst = os.path.join(dst, os.path.basename(src))                                    │         
│ ❱  427 │   copyfile(src, dst, follow_symlinks=follow_symlinks)                                   │         
│    428 │   copymode(src, dst, follow_symlinks=follow_symlinks)                                   │         
│    429 │   return dst                                                                            │         
│    430                                                                                           │         
│                                                                                                  │
│ /home/ezyang/local/a/pytorch-env/lib/python3.9/shutil.py:264 in copyfile                         │
│                                                                                                  │
│    261 │   if not follow_symlinks and _islink(src):                                              │         
│    262 │   │   os.symlink(os.readlink(src), dst)                                                 │         
│    263 │   else:                                                                                 │         
│ ❱  264 │   │   with open(src, 'rb') as fsrc:                                                     │         
│    265 │   │   │   try:                                                                          │         
│    266 │   │   │   │   with open(dst, 'wb') as fdst:                                             │         
│    267 │   │   │   │   │   # macOS                                                               │         
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯         
FileNotFoundError: [Errno 2] No such file or directory:                                             
'checkpoints/open-llama/7B/open_llama_7b_preview_300bt_transformers_weights/7B/tokenizer.model'  

If you could try running with TORCHDYNAMO_REPRO_AFTER="dynamo", this may or may not generate a minified repro; it's often hard to say. Alternatively, fix the model setup instructions (or tell me to go get the Meta weights, I can probably scrounge them up somewhere ;)
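For reference, the same repro flag can also be set from Python before torch is imported (equivalent to exporting the environment variable); a minimal sketch:

    import os

    # ask TorchDynamo to try to produce a minified repro of the failing frame after the dynamo stage;
    # set it before importing torch so the dynamo config picks it up
    os.environ["TORCHDYNAMO_REPRO_AFTER"] = "dynamo"

    import torch  # imported after setting the flag on purpose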

ezyang commented Jun 6, 2023
