
H100 Transformer Engine implementation #249

Open
SinanAkkoyun opened this issue May 9, 2023 · 25 comments

Comments

@SinanAkkoyun

Hello!
As I asked on Discord, here is the issue for implementing NVIDIA's Transformer Engine on compute capability 9.0 hardware (H100 GPUs).

I would really love to see and help with implementing that!
Thank you very much 😊

@carmocca
Contributor

carmocca commented May 9, 2023

Linked issue in the lightning repo: Lightning-AI/pytorch-lightning#17172

@carmocca
Contributor

carmocca commented May 9, 2023

@SinanAkkoyun Do you have access to H100s? If so, would you like to try out the PR Lightning-AI/pytorch-lightning#17597?
It adds support to Fabric via L.Fabric(precision="8-mixed").

You can install it by running

pip install -U https://github.com/Lightning-AI/lightning/archive/refs/heads/carmocca/transformer-engine.zip

We would love to get your feedback.
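For reference, a minimal sketch of how the flag is meant to be used (the TinyMLP below is a made-up placeholder, not lit-llama):

    import lightning as L
    import torch

    # Hypothetical toy model; its nn.Linear layers are the candidates for FP8 execution.
    class TinyMLP(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc1 = torch.nn.Linear(4096, 4096)
            self.fc2 = torch.nn.Linear(4096, 4096)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.fc2(torch.nn.functional.gelu(self.fc1(x)))

    fabric = L.Fabric(devices=1, precision="8-mixed")
    model = fabric.setup_module(TinyMLP())  # the precision plugin from the PR is applied here

    x = torch.randn(8, 4096, device=fabric.device)
    with torch.no_grad():
        y = model(x)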

@vgoklani

vgoklani commented May 10, 2023

Wouldn't you need to replace all your nn.Modules with te.Module equivalents if you switch to TransformerEngine?

In particular, the te modules also provide fused components (e.g. LayerNormLinear, LayerNormMLP):

https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html

You would also need to update your training code:

https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/advanced_optimizations.html#Multi-GPU-training

The speed-up is definitely there; we see 2-3x using FP8.
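For reference, using Transformer Engine by hand (outside of Fabric) looks roughly like the sketch below, based on the docs linked above; the layer size and recipe settings are just example values:

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    # Drop-in replacement for torch.nn.Linear; fused variants such as
    # te.LayerNormLinear and te.LayerNormMLP exist as well.
    layer = te.Linear(4096, 4096, bias=True).cuda()

    fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)
    x = torch.randn(16, 4096, device="cuda")

    # FP8 is only used inside the autocast region; outside of it the layer
    # runs in the usual higher precision.
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        y = layer(x)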

@carmocca
Contributor

Yes. Automatic replacement of the layers is missing, but it's something that we want to do too.

The parallelism customization would be left to the user, though.

@SinanAkkoyun
Author

@carmocca Hello! Thank you very much for the info!

I currently have access to an H100 cloud GPU, although if I shut it down I might not get my hands on it again, so any quick help would be much appreciated :)

pip install -U https://github.com/Lightning-AI/lightning/archive/refs/heads/carmocca/transformer-engine.zip
I ran this installation, but I still only get around 40 tokens/second. What else do I need to change?

@carmocca
Contributor

@SinanAkkoyun Did you run generate.py? How many tokens/sec do you get without it? How does the generation look?

Can you print(fabric.strategy.precision) to make sure it's using FP8 precision?
We might also need to change this line:

dtype = torch.bfloat16 if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32

@SinanAkkoyun
Author

SinanAkkoyun commented May 22, 2023

@carmocca Sure! Thank you so much for the reply!
Sadly, my email only just notified me of your comment.

I just ran generate.py and only get 30 tokens/second with the H100.

@SinanAkkoyun
Author

@carmocca
So I ran the print (sorry, I somehow forgot to paste it in earlier):
<lightning.fabric.plugins.precision.precision.Precision object at 0x7f5a4025b850>

Does this mean it's not automatically set?

These are some stats:

nvcc --version:
Build cuda_11.8.r11.8/compiler.31833905_0

nvidia-smi:
NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0

@carmocca
Contributor

It is not automatically set, try this:

    fabric = L.Fabric(devices=1, precision="8-mixed")
    dtype = None

@SinanAkkoyun
Author

@carmocca
I made the suggested modifications:

(base) ubuntu@209-20-158-170:~/lit-llama$ python generate.py --prompt "Hello, my name is"
Traceback (most recent call last):
  File "/home/ubuntu/lit-llama/generate.py", line 179, in <module>
    CLI(main)
  File "/home/ubuntu/miniconda3/lib/python3.10/site-packages/jsonargparse/cli.py", line 85, in CLI
    return _run_component(component, cfg_init)
  File "/home/ubuntu/miniconda3/lib/python3.10/site-packages/jsonargparse/cli.py", line 147, in _run_component
    return component(**cfg)
  File "/home/ubuntu/lit-llama/generate.py", line 119, in main
    fabric = L.Fabric(devices=1, precision="8-mixed")
  File "/home/ubuntu/miniconda3/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 103, in __init__
    self._connector = _Connector(
  File "/home/ubuntu/miniconda3/lib/python3.10/site-packages/lightning/fabric/connector.py", line 163, in __init__
    self.precision = self._check_and_init_precision()
  File "/home/ubuntu/miniconda3/lib/python3.10/site-packages/lightning/fabric/connector.py", line 453, in _check_and_init_precision
    return Fp8TransformerEnginePrecision()
  File "/home/ubuntu/miniconda3/lib/python3.10/site-packages/lightning/fabric/plugins/precision/fp8_transformer_engine.py", line 57, in __init__
    raise ModuleNotFoundError(str(_TRANSFORMER_ENGINE_AVAILABLE))
ModuleNotFoundError: DistributionNotFound: The 'transformer_engine' distribution was not found and is required by the application. HINT: Try running `pip install -U 'transformer_engine'`

I received the error above. What else do I need to install? I already installed your zip; does it have anything to do with that? I am sorry for all the minor questions, but I want to be careful with the implementation.

@carmocca
Contributor

@SinanAkkoyun
Author

Thank you very much, I am in the process of installing it

@carmocca
Contributor

Actually, based on NVIDIA/TransformerEngine#242 (comment), it seems like we can keep the weights in fp16 or bf16 during inference, meaning we don't need to set dtype = None.
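In other words, keep something like the original dtype selection (a sketch):

    import lightning as L
    import torch

    fabric = L.Fabric(devices=1, precision="8-mixed")
    # keep the original weight dtype selection instead of forcing dtype = None
    dtype = torch.bfloat16 if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32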

@SinanAkkoyun
Author

Thank you so much! I had/have trouble reinstalling the right CUDA and cuDNN versions; it should be done in a couple of minutes.

Thanks, I will leave the original dtype code.

@SinanAkkoyun
Author

SinanAkkoyun commented May 22, 2023

TransformerEngine is still installing (dependency hell: wrong CUDA/PyTorch versions, etc.; now compiling PyTorch with CUDA 12.1 from scratch).
As soon as I get it up and running, I will comment.

@SinanAkkoyun
Author

SinanAkkoyun commented May 23, 2023

@carmocca

NVRM version: NVIDIA UNIX x86_64 Kernel Module 525.105.17

Normal generate.py:

root@e7bbd97bc0d2:/app/lit-llama# python3 generate.py --prompt "Hello, my name is"
Loading model ...
Time to load model: 19.90 seconds.
Global seed set to 1234
Hello, my name is TJ.
I am a stay at home dad with 3 kids. I work part time for my church as a Custodian and I also tutor online. I have a Masters degree in Human Resource Management and I’
Time for inference 1: 1.25 sec total, 40.02 tokens/sec
Memory used: 13.57 GB

So, after lots of trial and error, I finally set up an NVIDIA dev Docker container with CUDA 12.1, installed PyTorch (CUDA 12.1), and ran the modified script (gen.py), which produced this output:

root@e7bbd97bc0d2:/app/lit-llama# python3 gen.py --prompt "Hello, my name is"
<lightning.fabric.plugins.precision.fp8_transformer_engine.Fp8TransformerEnginePrecision object at 0x7efbb27bfe80>
cuda
Loading model ...
Time to load model: 24.39 seconds.
Traceback (most recent call last):
  File "gen.py", line 179, in <module>
    CLI(main)
  File "/usr/local/lib/python3.8/dist-packages/jsonargparse/cli.py", line 85, in CLI
    return _run_component(component, cfg_init)
  File "/usr/local/lib/python3.8/dist-packages/jsonargparse/cli.py", line 147, in _run_component
    return component(**cfg)
  File "gen.py", line 139, in main
    model = fabric.setup_module(model)
  File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/fabric.py", line 254, in setup_module
    module = self._precision.convert_module(module)
  File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/plugins/precision/fp8_transformer_engine.py", line 82, in convert_module
    _convert_layers(module)
  File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/plugins/precision/fp8_transformer_engine.py", line 116, in _convert_layers
    _convert_layers(module)
  File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/plugins/precision/fp8_transformer_engine.py", line 116, in _convert_layers
    _convert_layers(module)
  File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/plugins/precision/fp8_transformer_engine.py", line 116, in _convert_layers
    _convert_layers(module)
  [Previous line repeated 989 more times]
  File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/plugins/precision/fp8_transformer_engine.py", line 90, in _convert_layers
    for name, child in module.named_children():
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 2206, in named_children
    memo = set()
RecursionError: maximum recursion depth exceeded while calling a Python object

I am trying to fix this, but I have not made any progress so far...

@SinanAkkoyun
Author

SinanAkkoyun commented May 23, 2023

I am still running the server; how could I debug this?

@SinanAkkoyun
Author

This is what GPT-4 suggested; I don't know if it makes sense:
fp8_transformer_engine.py:

def _convert_layers(module: torch.nn.Module) -> None:
    import transformer_engine.pytorch as te

    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear):
            if child.in_features % 16 != 0 or child.out_features % 16 != 0:
                # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#FP8-autocasting
                rank_zero_warn(
                    "Support for FP8 in the linear layers with `precision='8-mixed'` is currently limited to tensors"
                    f" with shapes where both dimensions are divisible by 16. The layer {name!r} does not fit this"
                    " criteria. You might want to add padding to your inputs."
                )
                continue
            has_bias = child.bias is not None
            replacement = te.Linear(child.in_features, child.out_features, bias=has_bias)
            replacement.weight.data = child.weight.data.clone()
            if has_bias:
                replacement.bias.data = child.bias.data.clone()
            log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent")
            module.__setattr__(name, replacement)
        elif isinstance(child, torch.nn.LayerNorm):
            replacement = te.LayerNorm(child.normalized_shape[0], eps=child.eps)
            replacement.weight.data = child.weight.data.clone()
            replacement.bias.data = child.bias.data.clone()
            log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent")
            module.__setattr__(name, replacement)
        else:
            _convert_layers(child)  # Recurse on the child, not the parent

After applying this modification, the following error occurs:

root@d395b61d47d7:/app/lit-llama# python3 gen.py --prompt "Hello, my name is"
Loading model ...
<lightning.fabric.plugins.precision.fp8_transformer_engine.Fp8TransformerEnginePrecision object at 0x7f2df85656a0>
cuda
Time to load model: 22.48 seconds.
Global seed set to 1234
Traceback (most recent call last):
  File "gen.py", line 179, in <module>
    CLI(main)
  File "/usr/local/lib/python3.8/dist-packages/jsonargparse/cli.py", line 85, in CLI
    return _run_component(component, cfg_init)
  File "/usr/local/lib/python3.8/dist-packages/jsonargparse/cli.py", line 147, in _run_component
    return component(**cfg)
  File "gen.py", line 154, in main
    y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "gen.py", line 67, in generate
    logits = model(x, max_seq_length, input_pos)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/wrappers.py", line 109, in forward
    output = self._forward_module(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/lit-llama/lit_llama/model.py", line 114, in forward
    x, self.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i])
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/lit-llama/lit_llama/model.py", line 159, in forward
    h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/lit-llama/lit_llama/model.py", line 191, in forward
    q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformer_engine/pytorch/module.py", line 2267, in forward
    with self.prepare_forward(inp, is_first_microbatch) as inp:
  File "/usr/lib/python3.8/contextlib.py", line 113, in __enter__
    return next(self.gen)
  File "/usr/local/lib/python3.8/dist-packages/transformer_engine/pytorch/module.py", line 593, in prepare_forward
    self.set_activation_dtype(inp)
  File "/usr/local/lib/python3.8/dist-packages/transformer_engine/pytorch/module.py", line 484, in set_activation_dtype
    assert all(
AssertionError: Data type for activations and buffers must match when outside of autocasted region

@SinanAkkoyun
Author

I tried to resolve this; here are my attempts and findings:
When changing the dtype to float32, it throws "AssertionError: Input and weight dimensions are not compatible for FP8 execution."
However, when setting the dtype to float16, it still throws this error:

File "/usr/local/lib/python3.8/dist-packages/transformer_engine/pytorch/module.py", line 484, in set_activation_dtype
    assert all(
AssertionError: Data type for activations and buffers must match when outside of autocasted region

When autocasting to fp32, I get this:

File "/usr/local/lib/python3.8/dist-packages/transformer_engine/pytorch/module.py", line 1674, in forward
    assert (
AssertionError: Input and weight dimensions are not compatible for FP8 execution.

@carmocca
Contributor

Your fp16 vs fp32 issues might be caused by this: https://github.com/NVIDIA/TransformerEngine/blob/stable/transformer_engine/pytorch/module.py#L3464-L3467
There might be a conflict between that logic and this logic: https://github.com/Lightning-AI/lit-llama/blob/main/generate.py#L129-L131

if (
    self.dtype is not None
    and func in torch.utils._device._device_constructors()
    and kwargs.get("dtype") is None
):
    kwargs["dtype"] = self.dtype

So maybe the easiest thing to start with is to remove the EmptyInitOnDevice context manager for now, even if it takes longer to load.

BTW, thank you for your efforts, this is really useful.

@SinanAkkoyun
Author

Thank you very much for helping me out!!!

I tried to remove the EmptyInitOnDevice context manager, but I think I did not succeed. Here is the code:

    """
          ``"gptq.int4"``: GPTQ 4-bit mode.
    """
    assert checkpoint_path.is_file(), checkpoint_path
    assert tokenizer_path.is_file(), tokenizer_path

    fabric = L.Fabric(devices=1, precision="8-mixed")
    dtype = torch.bfloat16 if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32

    print("Loading model ...", file=sys.stderr)

    print(fabric.strategy.precision)
    print(fabric.device.type)

    t0 = time.time()
    with lazy_load(checkpoint_path) as checkpoint:
        name = llama_model_lookup(checkpoint)

        # Removed the context manager here
        model = LLaMA.from_name(name)
        model.load_state_dict(checkpoint)
    print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)

    model.eval()
    model = fabric.setup_module(model)

    tokenizer = Tokenizer(tokenizer_path)
    encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
    prompt_length = encoded.size(0)

    L.seed_everything(1234)
    for i in range(num_samples):
        t0 = time.perf_counter()
        y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
        t = time.perf_counter() - t0

        model.reset_cache()
        print(tokenizer.decode(y))
        tokens_generated = y.size(0) - prompt_length
        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)

    if fabric.device.type == "cuda":
        print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)

if __name__ == "__main__":

I still get the error:

root@d395b61d47d7:/app/lit-llama# python3 gen.py --prompt "Hello, my name is"
Loading model ...
<lightning.fabric.plugins.precision.fp8_transformer_engine.Fp8TransformerEnginePrecision object at 0x7f075c468fa0>
cuda
Time to load model: 18.31 seconds.
Global seed set to 1234
DEBUG: IDX TYPE:
torch.int32
DEBUG: x dtype:
torch.float32
Traceback (most recent call last):
  File "gen.py", line 179, in <module>
    CLI(main)
  File "/usr/local/lib/python3.8/dist-packages/jsonargparse/cli.py", line 85, in CLI
    return _run_component(component, cfg_init)
  File "/usr/local/lib/python3.8/dist-packages/jsonargparse/cli.py", line 147, in _run_component
    return component(**cfg)
  File "gen.py", line 154, in main
    y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "gen.py", line 67, in generate
    logits = model(x, max_seq_length, input_pos)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/wrappers.py", line 109, in forward
    output = self._forward_module(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/lit-llama/lit_llama/model.py", line 119, in forward
    x, self.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i])
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/lit-llama/lit_llama/model.py", line 164, in forward
    h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/lit-llama/lit_llama/model.py", line 200, in forward
    q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformer_engine/pytorch/module.py", line 2306, in forward
    out = linear_fn(*args)
  File "/usr/local/lib/python3.8/dist-packages/transformer_engine/pytorch/module.py", line 1671, in forward
    assert (
AssertionError: Input and weight dimensions are not compatible for FP8 execution.

@carmocca
Contributor

The assertion you are hitting gets raised under two conditions: https://github.com/NVIDIA/TransformerEngine/blob/stable/transformer_engine/pytorch/module.py#L1675

The 7B config will have c_attn.shape == (4096, 3*4096) (ref) and that's divisible by (8, 16) as required by https://github.com/NVIDIA/TransformerEngine/blob/stable/transformer_engine/pytorch/utils.py#L186

So it must be inputmat that has the issue: https://github.com/NVIDIA/TransformerEngine/blob/stable/transformer_engine/pytorch/module.py#L1673. The input of c_attn would be a B x T x 4096 tensor, and that line converts it to T x 4096 since B == 1 during generation. This suggests that B x T needs to be a multiple of 8. So we would need to pad the input.
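A rough sketch of the kind of padding I mean (illustration only; pad_to_multiple is a made-up helper, and input_pos, the rope cache, and the mask would need matching adjustments):

    import torch

    def pad_to_multiple(x, multiple=8, pad_id=0):
        """Pad a (B, T) token tensor along T so that B * T is divisible by `multiple`.

        Returns the padded tensor and the original T so the extra logits can be dropped.
        """
        b, t = x.shape
        remainder = (b * t) % multiple
        if remainder == 0:
            return x, t
        pad = torch.full((b, multiple - remainder), pad_id, dtype=x.dtype, device=x.device)
        return torch.cat([x, pad], dim=1), t

    # inside the generation loop, roughly:
    # x, orig_t = pad_to_multiple(x)
    # logits = model(x, max_seq_length, input_pos)  # input_pos etc. need the same padded length
    # logits = logits[:, :orig_t]                   # ignore logits for the padded positions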

@SinanAkkoyun
Author

SinanAkkoyun commented May 23, 2023

Thank you so so much!!!

I first tried to pad it like this:

    # generate max_new_tokens tokens
    for _ in range(max_new_tokens):
        x = idx.index_select(0, input_pos).view(1, -1)

        # new: padding code
        original_length = x.size(1)
        new_length = ((original_length - 1) // 8 + 1) * 8  # round up to the nearest multiple of 8
        if original_length != new_length:
            padding = torch.zeros(x.size(0), new_length - original_length, dtype=x.dtype, device=x.device)
            x = torch.cat([x, padding], dim=1)

        # forward
        logits = model(x, max_seq_length, input_pos)
        logits = logits[0, -1] / temperature

Which resulted in:

Traceback (most recent call last):
  File "gen.py", line 185, in <module>
    CLI(main)
  File "/usr/local/lib/python3.8/dist-packages/jsonargparse/cli.py", line 85, in CLI
    return _run_component(component, cfg_init)
  File "/usr/local/lib/python3.8/dist-packages/jsonargparse/cli.py", line 147, in _run_component
    return component(**cfg)
  File "gen.py", line 160, in main
    y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "gen.py", line 75, in generate
    logits = model(x, max_seq_length, input_pos)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/wrappers.py", line 109, in forward
    output = self._forward_module(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/lit-llama/lit_llama/model.py", line 120, in forward
    x, self.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i])
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/lit-llama/lit_llama/model.py", line 165, in forward
    h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/lit-llama/lit_llama/model.py", line 208, in forward
    q = apply_rope(q, rope)
  File "/app/lit-llama/lit_llama/model.py", line 317, in apply_rope
    rope_cache = rope_cache.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
RuntimeError: shape '[1, 8, 1, 64, 2]' is invalid for input of size 768

Then I tried supplying exactly 8 tokens to see if the shapes match (because I failed at the padding). The first LLaMA forward pass now gets through all of the CausalSelfAttention blocks, but it crashes on the second forward pass:

root@d395b61d47d7:/app/lit-llama# python3 gen.py --prompt "Hello Hello Hello Hello Hello Hello Hello"
Loading model ...
<lightning.fabric.plugins.precision.fp8_transformer_engine.Fp8TransformerEnginePrecision object at 0x7f21b0f43a90>
cuda
Time to load model: 62.42 seconds.
Length of encoded prompt: 8
Size of encoded prompt: 8
Global seed set to 1234
DEBUG: IDX TYPE: torch.int32
DEBUG: x dtype: torch.float32
DEBUG: x shape before split: torch.Size([1, 8, 4096])
DEBUG: q, k, v shape: torch.Size([1, 8, 4096]), torch.Size([1, 8, 4096]), torch.Size([1, 8, 4096])
[... the same three DEBUG lines repeat for each of the remaining 31 transformer blocks ...]
DEBUG: IDX TYPE: torch.int32
DEBUG: x dtype: torch.float32
DEBUG: x shape before split: torch.Size([1, 1, 4096])
Traceback (most recent call last):
  File "gen.py", line 189, in <module>
    CLI(main)
  File "/usr/local/lib/python3.8/dist-packages/jsonargparse/cli.py", line 85, in CLI
    return _run_component(component, cfg_init)
  File "/usr/local/lib/python3.8/dist-packages/jsonargparse/cli.py", line 147, in _run_component
    return component(**cfg)
  File "gen.py", line 164, in main
    y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "gen.py", line 67, in generate
    logits = model(x, max_seq_length, input_pos)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/wrappers.py", line 109, in forward
    output = self._forward_module(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/lit-llama/lit_llama/model.py", line 118, in forward
    x, self.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i])
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/lit-llama/lit_llama/model.py", line 163, in forward
    h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/lit-llama/lit_llama/model.py", line 199, in forward
    q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformer_engine/pytorch/module.py", line 2306, in forward
    out = linear_fn(*args)
  File "/usr/local/lib/python3.8/dist-packages/transformer_engine/pytorch/module.py", line 1671, in forward
    assert (
AssertionError: Input and weight dimensions are not compatible for FP8 execution.

I really do not know what to do now; I have tried many things over the last few hours. I would greatly appreciate it if you could implement the padding the way you imagined it, and I would be happy to build on that and test it out. (I will keep the cloud GPU running, so if you find the time, I would be glad to test it soon.)

@carmocca
Contributor

@28Smiles That seems like a completely separate issue from H100 support. Can you open a different issue?

@carmocca
Contributor

@28Smiles Our inference scripts do not support batch size > 1 at the moment. #188 tracks this
