
Slow inference performance for large Llama models compared to naive MP #66

Open
sgsdxzy opened this issue Apr 6, 2023 · 26 comments

sgsdxzy commented Apr 6, 2023

The inference speed of naive model parallel is much better than tensor parallel:

Setup: Llama-30b on 2080Ti 22G x4
Naive: 31.64s
4-way TP, main branch: 177.78s
4-way TP, llama branch: 102.22s

The code for naive inference:

import torch
import time
import os
import json
import tensor_parallel
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import accelerate
from transformers.utils.bitsandbytes import replace_8bit_linear
from accelerate.hooks import remove_hook_from_module

model_name = 'models/llama-30b'

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.half, device_map="balanced")

torch.cuda.empty_cache()
model = model.eval()
with torch.no_grad():
    batch = tokenizer(
        "DeepSpeed first included offloading capabilities with ZeRO-Offload, a system ",
        return_tensors="pt"
    )
    batch = {k: v.cuda(0) for k, v in batch.items()}
    print("Start")
    t0 = time.time()
    generated = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=200)
    t1 = time.time()
    print(f"Output generated in {(t1-t0):.2f} seconds")
    print(tokenizer.decode(generated[0]))

The code for TP:

import torch
import time
import os
import json
import tensor_parallel
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import accelerate
from transformers.utils.bitsandbytes import replace_8bit_linear
from accelerate.hooks import remove_hook_from_module

model_name = 'models/llama-30b'

tokenizer = AutoTokenizer.from_pretrained(model_name)
with accelerate.init_empty_weights():
    model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(model_name)).half()
    model = tensor_parallel.TensorParallelPreTrainedModel(model)

device_map = tensor_parallel.infer_sharded_device_map(model) # <- The model is on meta device but we can still deduce
                                                #    the target devices for each weight using this helper function
# Get the list of checkpoint shards
with open(f"{model_name}/pytorch_model.bin.index.json", "r") as index_file:
    shard_filenames = set(json.load(index_file)["weight_map"].values())

for shard_filename in sorted(shard_filenames):
    # Download a shard
    shard_path = f"{model_name}/{shard_filename}"
    print(shard_path)
    
    # Convert model shard
    converted_state_dict = tensor_parallel.convert_state_dict( # <- tensor_parallel helper function.
        torch.load(shard_path),                   #    Creates a tensor_parallel checkpoint from a normal one
        model.tensor_parallel_config,
        world_size=4,
        for_pretrained=True,
    )    
    torch.save(converted_state_dict, "/tmp/shard.bin")
    del converted_state_dict
        
    # Dispatch the shard
    accelerate.load_checkpoint_in_model(
        model,
        checkpoint="/tmp/shard.bin",
        device_map=device_map,
    )

torch.cuda.empty_cache()
model = model.eval()
with torch.no_grad():
    batch = tokenizer(
        "DeepSpeed first included offloading capabilities with ZeRO-Offload, a system ",
        return_tensors="pt"
    )
    batch = {k: v.cuda(0) for k, v in batch.items()}
    print("Start")
    t0 = time.time()
    generated = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=200)
    t1 = time.time()
    print(f"Output generated in {(t1-t0):.2f} seconds")
    print(tokenizer.decode(generated[0]))

BlackSamorez (Owner) commented Apr 6, 2023

Hi @sgsdxzy!
I tried to reproduce this issue in a T4 x2 Kaggle notebook (sadly I don't own 2080Ti 22G x4) and here's what I got:

  • Naive: 43.13 seconds
  • 2-way TP, llama branch: 29.01 seconds

That's not quite a 2x speedup, but it gets better with larger batches.

About your case: if you're sure those numbers are valid, maybe it's somehow connected to the fact that you're using 4 cards. What's the data bandwidth between them? Are all 4 cards getting enough PCI-E lanes?
In this case tensor_parallel uses raw communication primitives from torch.cuda.nccl, so it's strange that they would be this slow.
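
For a rough cross-check, raw copy bandwidth between two cards can also be measured from PyTorch directly. This is only a sanity check of plain device-to-device copies, not the NCCL path tensor_parallel actually uses, and the device strings below are placeholders for whichever pair you want to test:

import time
import torch

def copy_bandwidth_gb_s(src="cuda:0", dst="cuda:1", mb=256, iters=20):
    x = torch.empty(mb * 1024 * 1024, dtype=torch.uint8, device=src)
    _ = x.to(dst)  # warm-up: allocates on dst and enables peer access if available
    torch.cuda.synchronize(src)
    torch.cuda.synchronize(dst)
    t0 = time.time()
    for _ in range(iters):
        _ = x.to(dst, non_blocking=True)
    torch.cuda.synchronize(src)
    torch.cuda.synchronize(dst)
    return (mb * iters / 1024) / (time.time() - t0)  # GB/s

print(copy_bandwidth_gb_s("cuda:0", "cuda:1"))  # e.g. an NVLink-connected pair
print(copy_bandwidth_gb_s("cuda:0", "cuda:2"))  # e.g. a PCIe-only pair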

sgsdxzy (Author) commented Apr 6, 2023

@BlackSamorez I can confirm that 2-card TP provides a small speedup over 2-card MP. The 4 cards are all running at PCIe 3.0 x16 on an X99. Here's my P2P connectivity test (I have NVLink bridges between [0,1] and between [2,3]):

P2P Connectivity Matrix
     D\D     0     1     2     3
     0       1     1     0     0
     1       1     1     0     0
     2       0     0     1     1
     3       0     0     1     1
Unidirectional P2P=Disabled Bandwidth Matrix (GB/s)
   D\D     0      1      2      3
     0 541.72   5.76   5.85   5.87
     1   5.76 542.96   5.82   5.87
     2   5.95   5.94 537.09   5.79
     3   5.89   5.93   5.81 533.16
Unidirectional P2P=Enabled Bandwidth (P2P Writes) Matrix (GB/s)
   D\D     0      1      2      3
     0 531.46  47.09   6.00   5.95
     1  47.11 536.05   5.97   5.95
     2   5.87   5.96 532.47  47.09
     3   5.92   5.90  47.10 532.53
Bidirectional P2P=Disabled Bandwidth Matrix (GB/s)
   D\D     0      1      2      3
     0 533.29   6.11   8.62   8.59
     1   6.12 535.29   8.58   8.57
     2   8.60   8.52 534.05   6.12
     3   8.56   8.57   6.10 534.13
Bidirectional P2P=Enabled Bandwidth Matrix (GB/s)
   D\D     0      1      2      3
     0 533.55  94.10   8.61   8.59
     1  94.13 534.78   8.56   8.59
     2   8.55   8.60 534.17  94.15
     3   8.62   8.59  94.16 533.62
P2P=Disabled Latency Matrix (us)
   GPU     0      1      2      3
     0   1.34  12.44  12.30  12.44
     1  12.44   1.38  21.21  12.68
     2  12.53  12.61   1.33  12.44
     3  12.38  12.30  12.68   1.33

   CPU     0      1      2      3
     0   2.05   5.85   5.74   5.82
     1   5.82   1.95   5.80   5.77
     2   5.63   5.66   1.99   5.58
     3   5.75   5.72   5.67   1.97
P2P=Enabled Latency (P2P Writes) Matrix (us)
   GPU     0      1      2      3
     0   1.33   1.88  12.30  12.45
     1   1.88   1.38  21.18  12.54
     2  12.53  12.53   1.33   1.85
     3  12.38  21.12   1.85   1.33

   CPU     0      1      2      3
     0   2.02   1.63   5.85   5.91
     1   1.64   1.99   5.75   5.91
     2   5.71   5.69   1.99   1.64
     3   6.01   5.80   1.74   2.12

I think Kaggle T4s don't use NVLink, so that's not the problem here, and I don't think 4 cards would suddenly hit a communication bottleneck and drastically reduce performance. I think it's more of a misconfiguration or a bug. Where would you suggest I look?

BlackSamorez (Owner) commented:

@sgsdxzy Thanks!
Could you verify that correct communication functions are being used?
You should be hitting:

  • https://github.com/BlackSamorez/tensor_parallel/blob/main/src/tensor_parallel/cross_device_ops.py#L95
  • https://github.com/BlackSamorez/tensor_parallel/blob/main/src/tensor_parallel/cross_device_ops.py#L77

during forward passes.

Also could you please benchmark tensor_parallel on ["cuda:0", "cuda:1"] (nvlink) and ["cuda:0", "cuda:2"] (no nvlink)?
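
For reference, a minimal way to run that two-device comparison is to pass an explicit device list when wrapping the model (same pattern as the snippets later in this thread; the model path here is a placeholder):

import torch
import tensor_parallel
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "models/llama-7b"  # placeholder

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.half)

# Restrict tensor_parallel to a specific pair of cards, e.g. the NVLink pair ...
model_tp = tensor_parallel.TensorParallelPreTrainedModel(model, ["cuda:0", "cuda:1"])
# ... or the PCIe-only pair: tensor_parallel.TensorParallelPreTrainedModel(model, ["cuda:0", "cuda:2"])

Then time the same generate() call as in the scripts above for each pair.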

sgsdxzy (Author) commented Apr 7, 2023

@BlackSamorez Here are the results:

| Model setup | llama-7b 1gpu | llama-7b 8bit 1gpu | llama-7b 2gpu+nvlink | llama-7b 8bit 2gpu+nvlink | llama-7b 2gpu w/o nvlink | llama-7b 8bit 2gpu w/o nvlink |
| --- | --- | --- | --- | --- | --- | --- |
| Naive time (s) | 10.44 | 37.42 | 11.45 | 37.99 | 12.38 | 38.92 |
| Naive memory per gpu (GB) | 14 | 8.3 | 7.7 | 4.7 | 7.7 | 4.7 |
| TP time (s) | - | - | 27.85 | 28.23 | 27.66 | 27.66 |
| TP memory per gpu (GB) | - | - | 7.7 | 7.7 | 7.7 | 7.7 |

So the problems here:

  1. TP only provides a speed gain over naive MP in 8-bit, and is drastically slower in fp16. The fp16 and int8 TP times are identical, which is also suspicious.
  2. Loading in 8-bit does not save VRAM for TP, which can be considered another bug.
  3. NVLink does not affect the result.
  4. I am using the main branch, as the llama branch gives me the following error in 8-bit (it works fine for fp16, reducing 28s to 17s):
Traceback (most recent call last):
  File "/home/sgsdxzy/Programs/text-generation-webui/tp_test.py", line 68, in <module>
    generated = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=200)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/transformers/generation/utils.py", line 1437, in generate
    return self.greedy_search(
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/transformers/generation/utils.py", line 2248, in greedy_search
    outputs = self(
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/tensor_parallel/pretrained_model.py", line 88, in forward
    return self.wrapped_model(*args, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/tensor_parallel/tensor_parallel.py", line 130, in forward
    return parallel_apply(self.module_shards, inputs, kwargs_tup, self.devices)[self.output_device_index]
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
    output.reraise()
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/_utils.py", line 644, in reraise
    raise exception
AttributeError: Caught AttributeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 687, in forward
    outputs = self.model(
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 577, in forward
    layer_outputs = decoder_layer(
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 292, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/tensor_parallel/slicer_wrapper.py", line 390, in forward
    output = self.tp_wrapped_module(*args, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 196, in forward
    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 242, in forward
    out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 488, in matmul
    return MatMul8bitLt.apply(A, B, out, bias, state)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 320, in forward
    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/bitsandbytes/functional.py", line 1698, in transform
    prev_device = pre_call(A.device)
AttributeError: 'NoneType' object has no attribute 'device'

The updated script for reference:

import torch
import time
import argparse
import json
import tensor_parallel
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, LlamaTokenizer
import accelerate
from transformers.utils.bitsandbytes import replace_8bit_linear
from accelerate.hooks import remove_hook_from_module

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str)
parser.add_argument('--int8', action='store_true')
parser.add_argument('--mp', type=int)
args = parser.parse_args()

tokenizer = LlamaTokenizer.from_pretrained(args.model)

if args.mp <= 1:
    model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.half, load_in_8bit=args.int8, device_map="balanced")
else:
    with accelerate.init_empty_weights():
        model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(args.model)).half()
        model = tensor_parallel.TensorParallelPreTrainedModel(model)
        if args.int8:
            model = replace_8bit_linear(model)
            model.is_loaded_in_8bit = True

    device_map = tensor_parallel.infer_sharded_device_map(model) # <- The model is on meta device but we can still deduce
                                                                 #    the target devices for each weight using this helper function

    # Get the list of checkpoint shards
    with open(f"{args.model}/pytorch_model.bin.index.json", "r") as index_file:
        shard_filenames = set(json.load(index_file)["weight_map"].values())

    for shard_filename in sorted(shard_filenames):
        # Download a shard
        shard_path = f"{args.model}/{shard_filename}"
        print(shard_path)
        
        # Convert model shard
        converted_state_dict = tensor_parallel.convert_state_dict( # <- tensor_parallel helper function.
            torch.load(shard_path),                                #    Creates a tensor_parallel checkpoint from a normal one
            model.tensor_parallel_config,
            world_size=args.mp,
            for_pretrained=True,
        )    
        torch.save(converted_state_dict, "/tmp/shard.bin")
        del converted_state_dict
            
        # Dispatch the shard
        accelerate.load_checkpoint_in_model(
            model,
            checkpoint="/tmp/shard.bin",
            device_map=device_map,
        )

torch.cuda.empty_cache()
model = model.eval()
with torch.no_grad():
    batch = tokenizer(
        "DeepSpeed first included offloading capabilities with ZeRO-Offload, a system ",
        return_tensors="pt"
    )
    batch = {k: v.cuda(0) for k, v in batch.items()}
    print("Start")
    t0 = time.time()
    generated = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=200)
    t1 = time.time()
    print(f"Output generated in {(t1-t0):.2f} seconds")
    print(tokenizer.decode(generated[0]))

sgsdxzy (Author) commented Apr 7, 2023

@BlackSamorez here are the results for OPT-6.7B; they're almost the same as for llama-7b.

| Model setup | OPT-6.7B 1gpu | OPT-6.7B 8bit 1gpu | OPT-6.7B 2gpu+nvlink | OPT-6.7B 8bit 2gpu+nvlink |
| --- | --- | --- | --- | --- |
| Naive time (s) | 10.16 | 39.86 | 9.94 | 40.08 |
| Naive memory per gpu (GB) | 13.6 | 7.6 | 7.6 | 4.6 |
| TP time (s) | - | - | 23.64 | 23.81 |
| TP memory per gpu (GB) | - | - | 7.6 | 7.6 |

Are you testing in int8 or fp16? Can you test on any cards other than dual T4s? And I don't think I have a GPU communication problem, as the TP provided by deepspeed-inference does boost performance for me on OPT (llama is not well supported yet): 2-card fp16 is 65% faster than 1-card fp16, see oobabooga/text-generation-webui#561 (comment)

sgsdxzy (Author) commented Apr 7, 2023

@sgsdxzy Thanks! Could you verify that correct communication functions are being used? You should be hitting:

* https://github.com/BlackSamorez/tensor_parallel/blob/main/src/tensor_parallel/cross_device_ops.py#L95

* https://github.com/BlackSamorez/tensor_parallel/blob/main/src/tensor_parallel/cross_device_ops.py#L77

during forward passes.

Also could you please benchmark tensor_parallel on ["cuda:0", "cuda:1"] (nvlink) and ["cuda:0", "cuda:2"] (no nvlink)?

I find that NCCLAllGatherFunction is called, but not NCCLAllReduceFunction.
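
For anyone wanting to reproduce this check, profiling a single forward pass and looking for NCCL kernels in the trace is one option (a sketch; exact kernel names vary with the PyTorch/NCCL version, so treat the "nccl" substring match as approximate):

import torch
from torch.profiler import profile, ProfilerActivity

with torch.no_grad(), profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    model(batch["input_ids"])  # `model` and `batch` as in the scripts above

for evt in prof.key_averages():
    if "nccl" in evt.key.lower():  # e.g. AllGather / AllReduce kernels
        print(evt.key, evt.count, evt.cuda_time_total)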

BlackSamorez (Owner) commented:

@sgsdxzy Hi!
Firstly, about int8: you need the latest accelerate (e.g. the main branch from GitHub) to dispatch int8 models with load_checkpoint_in_model. Otherwise int8 layers are not quantized and behave exactly like fp16.
About everything else: I'll need some time to test it. It could be due to a lot of reasons, including bugs in communications or tensor cores not kicking in for tensor_parallel.
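
A quick way to see whether the int8 path is actually active after dispatch is to count how many of the replaced linear layers ended up with int8 weights; a sketch, assuming replace_8bit_linear installed bitsandbytes Linear8bitLt modules (as the traceback above suggests):

import torch
import bitsandbytes as bnb

int8_layers, fp16_layers = 0, 0
for name, module in model.named_modules():
    if isinstance(module, bnb.nn.Linear8bitLt):
        # After a correct int8 dispatch the stored weight should be int8, not fp16
        if module.weight.dtype == torch.int8:
            int8_layers += 1
        else:
            fp16_layers += 1
print(f"quantized (int8) layers: {int8_layers}, still fp16: {fp16_layers}")

If everything shows up as fp16, the layers were never actually quantized and the memory/speed numbers will match the fp16 run.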

sgsdxzy (Author) commented Apr 7, 2023

@BlackSamorez I upgraded accelerate to git+https://github.com/huggingface/accelerate, however the VRAM usage and speed are the same.

BlackSamorez (Owner) commented:

@sgsdxzy Now that's weird. This demo works, which means int8 should work fine, since those models won't physically fit in VRAM in fp16.
Could you please attach the result of pip freeze in your environment?

sgsdxzy (Author) commented Apr 7, 2023

@BlackSamorez it's here. This is a conda environment; tell me if you suspect any specific package whose version isn't listed by pip freeze.

accelerate @ git+https://github.com/huggingface/accelerate@b757b6232516da4ece0fbcfec66855b37523f39a
aiofiles @ file:///home/conda/feedstock_root/build_artifacts/aiofiles_1664378549280/work
aiohttp==3.8.4
aiosignal==1.3.1
aiosqlite @ file:///home/conda/feedstock_root/build_artifacts/aiosqlite_1671461885930/work
altair==4.2.2
anyio @ file:///home/conda/feedstock_root/build_artifacts/anyio_1666191106763/work/dist
appdirs==1.4.4
argon2-cffi @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi_1640817743617/work
argon2-cffi-bindings @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi-bindings_1666850768662/work
astroid @ file:///home/conda/feedstock_root/build_artifacts/astroid_1679923748219/work
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1670263926556/work
async-timeout==4.0.2
attrs @ file:///home/conda/feedstock_root/build_artifacts/attrs_1671632566681/work
autopep8 @ file:///home/conda/feedstock_root/build_artifacts/autopep8_1635267974115/work
Babel @ file:///home/conda/feedstock_root/build_artifacts/babel_1677767029043/work
backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work
backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1618230623929/work
beautifulsoup4 @ file:///home/conda/feedstock_root/build_artifacts/beautifulsoup4_1679322162244/work
bitsandbytes==0.37.2
black @ file:///home/conda/feedstock_root/build_artifacts/black-recipe_1675252854302/work
bleach @ file:///home/conda/feedstock_root/build_artifacts/bleach_1674535352125/work
brotlipy @ file:///home/conda/feedstock_root/build_artifacts/brotlipy_1666764671472/work
certifi==2022.12.7
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1671179353105/work
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1661170624537/work
click @ file:///home/conda/feedstock_root/build_artifacts/click_1666798198223/work
cmake==3.26.1
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work
comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1679481329611/work
contourpy @ file:///home/conda/feedstock_root/build_artifacts/contourpy_1673633665736/work
cryptography @ file:///home/conda/feedstock_root/build_artifacts/cryptography-split_1679811212387/work
cycler @ file:///home/conda/feedstock_root/build_artifacts/cycler_1635519461629/work
Cython @ file:///home/conda/feedstock_root/build_artifacts/cython_1673054058583/work
daal4py==2023.0.2
datasets==2.11.0
debugpy @ file:///home/conda/feedstock_root/build_artifacts/debugpy_1674522362098/work
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
deepspeed==0.8.3
defusedxml @ file:///home/conda/feedstock_root/build_artifacts/defusedxml_1615232257335/work
dill @ file:///home/conda/feedstock_root/build_artifacts/dill_1666603105584/work
docstring-to-markdown @ file:///home/conda/feedstock_root/build_artifacts/docstring-to-markdown_1679424273982/work
entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1667317341051/work
fastapi==0.95.0
fastjsonschema @ file:///home/conda/feedstock_root/build_artifacts/python-fastjsonschema_1677336799617/work/dist
ffmpy==0.3.0
filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1679932713187/work
fire==0.5.0
flake8 @ file:///home/conda/feedstock_root/build_artifacts/flake8_1669396691980/work
flexgen==0.1.7
flit_core @ file:///home/conda/feedstock_root/build_artifacts/flit-core_1667734568827/work/source/flit_core
fonttools @ file:///home/conda/feedstock_root/build_artifacts/fonttools_1680021152278/work
frozenlist==1.3.3
fsspec==2023.3.0
gmpy2 @ file:///home/conda/feedstock_root/build_artifacts/gmpy2_1666808654411/work
gradio==3.24.1
gradio_client==0.0.5
h11==0.14.0
hjson==3.1.0
httpcore==0.16.3
httpx==0.23.3
huggingface-hub==0.13.3
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1663625384323/work
importlib-metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1679167925176/work
importlib-resources @ file:///home/conda/feedstock_root/build_artifacts/importlib_resources_1676919000169/work
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1679336319192/work
ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1677617093347/work
ipython-genutils==0.2.0
isort @ file:///home/conda/feedstock_root/build_artifacts/isort_1675033873689/work
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1669134318875/work
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1654302431367/work
joblib @ file:///home/conda/feedstock_root/build_artifacts/joblib_1663332044897/work
json5 @ file:///home/conda/feedstock_root/build_artifacts/json5_1600692310011/work
jsonschema @ file:///home/conda/feedstock_root/build_artifacts/jsonschema-meta_1669810440410/work
jupyter-events @ file:///home/conda/feedstock_root/build_artifacts/jupyter_events_1673559782596/work
jupyter-ydoc @ file:///home/conda/feedstock_root/build_artifacts/jupyter_ydoc_1679325289144/work/dist
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1679365123476/work
jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1678994169527/work
jupyter_server @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_1679073341944/work
jupyter_server_fileid @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_fileid_1677220209229/work
jupyter_server_terminals @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_terminals_1673491454549/work
jupyter_server_ydoc @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_ydoc_1678043727957/work
jupyterlab @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_1679327603632/work
jupyterlab-code-formatter @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_code_formatter_1679847042826/work
jupyterlab-pygments @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_pygments_1649936611996/work
jupyterlab_server @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_server_1679528718717/work
kiwisolver @ file:///home/conda/feedstock_root/build_artifacts/kiwisolver_1666805701884/work
lazy-object-proxy @ file:///home/conda/feedstock_root/build_artifacts/lazy-object-proxy_1672877787898/work
linkify-it-py==2.0.0
lit==16.0.0
loralib==0.1.1
Markdown==3.4.3
markdown-it-py==2.2.0
MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1674135787083/work
matplotlib @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-suite_1678135565516/work
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work
mccabe @ file:///home/conda/feedstock_root/build_artifacts/mccabe_1643049622439/work
mdit-py-plugins==0.3.3
mdurl==0.1.2
mistune @ file:///home/conda/feedstock_root/build_artifacts/mistune_1675771498296/work
mpmath @ file:///home/conda/feedstock_root/build_artifacts/mpmath_1678228039184/work
multidict==6.0.4
multiprocess==0.70.14
munkres==1.1.4
mypy-extensions @ file:///home/conda/feedstock_root/build_artifacts/mypy_extensions_1675543315189/work
nbclassic @ file:///home/conda/feedstock_root/build_artifacts/nbclassic_1678277563913/work
nbclient @ file:///home/conda/feedstock_root/build_artifacts/nbclient_1669795076334/work
nbconvert @ file:///home/conda/feedstock_root/build_artifacts/nbconvert-meta_1680034059411/work
nbformat @ file:///home/conda/feedstock_root/build_artifacts/nbformat_1679336765223/work
nest-asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1664684991461/work
networkx @ file:///home/conda/feedstock_root/build_artifacts/networkx_1673151334029/work
ninja==1.11.1
notebook @ file:///home/conda/feedstock_root/build_artifacts/notebook_1678109761260/work
notebook_shim @ file:///home/conda/feedstock_root/build_artifacts/notebook-shim_1667478401171/work
numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1675642512762/work
orjson==3.8.9
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1673482170163/work
pandas==1.5.3
pandocfilters @ file:///home/conda/feedstock_root/build_artifacts/pandocfilters_1631603243851/work
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work
pathspec @ file:///home/conda/feedstock_root/build_artifacts/pathspec_1678853982175/work
peft @ git+https://github.com/huggingface/peft.git@445940fb7b5d38390ffb6707e2a989e89fff03b5
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1667297516076/work
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
Pillow @ file:///home/conda/feedstock_root/build_artifacts/pillow_1675487172403/work
pkgutil_resolve_name @ file:///home/conda/feedstock_root/build_artifacts/pkgutil-resolve-name_1633981968097/work
platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1679871349196/work
pluggy @ file:///home/conda/feedstock_root/build_artifacts/pluggy_1667232663820/work
ply==3.11
pooch @ file:///home/conda/feedstock_root/build_artifacts/pooch_1679580333621/work
prometheus-client @ file:///home/conda/feedstock_root/build_artifacts/prometheus_client_1674535637125/work
prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1677600924538/work
psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1667885877572/work
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
PuLP==2.7.0
pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
py-cpuinfo==9.0.0
pyarrow==11.0.0
pybind11==2.10.4
pycodestyle @ file:///home/conda/feedstock_root/build_artifacts/pycodestyle_1669306857274/work
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work
pydantic==1.10.7
pydocstyle @ file:///home/conda/feedstock_root/build_artifacts/pydocstyle_1673997095229/work
pydub==0.25.1
pyflakes @ file:///home/conda/feedstock_root/build_artifacts/pyflakes_1669319921641/work
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1672682006896/work
pylint @ file:///home/conda/feedstock_root/build_artifacts/pylint_1679515272965/work
pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1680037383858/work
pyparsing @ file:///home/conda/feedstock_root/build_artifacts/pyparsing_1652235407899/work
PyQt5==5.15.7
PyQt5-sip==12.11.0
pyrsistent @ file:///home/conda/feedstock_root/build_artifacts/pyrsistent_1672681463845/work
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work
python-json-logger @ file:///home/conda/feedstock_root/build_artifacts/python-json-logger_1677079630776/work
python-lsp-jsonrpc @ file:///home/conda/feedstock_root/build_artifacts/python-lsp-jsonrpc_1618530352985/work
python-lsp-server @ file:///home/conda/feedstock_root/build_artifacts/python-lsp-server-meta_1674005136083/work
python-multipart==0.0.6
pytoolconfig @ file:///home/conda/feedstock_root/build_artifacts/pytoolconfig_1675124745143/work
pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1680088766131/work
PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1666772395347/work
pyzmq @ file:///home/conda/feedstock_root/build_artifacts/pyzmq_1679316826707/work
regex==2023.3.23
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1673863902341/work
responses==0.18.0
rfc3339-validator @ file:///home/conda/feedstock_root/build_artifacts/rfc3339-validator_1638811747357/work
rfc3986==1.5.0
rfc3986-validator @ file:///home/conda/feedstock_root/build_artifacts/rfc3986-validator_1598024191506/work
rope @ file:///home/conda/feedstock_root/build_artifacts/rope_1674988456931/work
rwkv==0.7.3
safetensors==0.3.0
scikit-learn @ file:///home/conda/feedstock_root/build_artifacts/scikit-learn_1679675836718/work
scikit-learn-intelex==20230131.200059
scipy==1.10.1
semantic-version==2.10.0
Send2Trash @ file:///home/conda/feedstock_root/build_artifacts/send2trash_1628511208346/work
sentencepiece==0.1.97
sip @ file:///home/conda/feedstock_root/build_artifacts/sip_1675696581052/work
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
sniffio @ file:///home/conda/feedstock_root/build_artifacts/sniffio_1662051266223/work
snowballstemmer @ file:///home/conda/feedstock_root/build_artifacts/snowballstemmer_1637143057757/work
soupsieve @ file:///home/conda/feedstock_root/build_artifacts/soupsieve_1658207591808/work
stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
starlette==0.26.1
sympy @ file:///home/conda/feedstock_root/build_artifacts/sympy_1679342590084/work
tensor-parallel @ file:///home/sgsdxzy/Programs/tensor_parallel
termcolor==2.2.0
terminado @ file:///home/conda/feedstock_root/build_artifacts/terminado_1670253674810/work
threadpoolctl @ file:///home/conda/feedstock_root/build_artifacts/threadpoolctl_1643647933166/work
tinycss2 @ file:///home/conda/feedstock_root/build_artifacts/tinycss2_1666100256010/work
tokenize-rt==5.0.0
tokenizers==0.13.3
toml @ file:///home/conda/feedstock_root/build_artifacts/toml_1604308577558/work
tomli @ file:///home/conda/feedstock_root/build_artifacts/tomli_1644342247877/work
tomlkit @ file:///home/conda/feedstock_root/build_artifacts/tomlkit_1679924068997/work
toolz==0.12.0
torch==2.0.0
torchaudio==2.0.0
torchvision==0.15.0
tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1666788589303/work
tqdm==4.65.0
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1675110562325/work
transformers @ git+https://github.com/huggingface/transformers.git@ee8e80a060d65ab349743ffcb5842365eb0e5606
triton==2.0.0
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1678559861143/work
uc-micro-py==1.0.1
ujson @ file:///home/conda/feedstock_root/build_artifacts/ujson_1675191915931/work
unicodedata2 @ file:///home/conda/feedstock_root/build_artifacts/unicodedata2_1667239886688/work
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1678635778344/work
uvicorn==0.21.1
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1673864653149/work
webencodings==0.5.1
websocket-client @ file:///home/conda/feedstock_root/build_artifacts/websocket-client_1675567828044/work
websockets==10.4
whatthepatch @ file:///home/conda/feedstock_root/build_artifacts/whatthepatch_1675090462655/work
wrapt @ file:///home/conda/feedstock_root/build_artifacts/wrapt_1677485519705/work
xxhash==3.2.0
y-py @ file:///home/conda/feedstock_root/build_artifacts/y-py_1677231008299/work
yapf @ file:///home/conda/feedstock_root/build_artifacts/yapf_1641487982943/work
yarl==1.8.2
ypy-websocket @ file:///home/conda/feedstock_root/build_artifacts/ypy-websocket_1670333059911/work
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1677313463193/work

BlackSamorez (Owner) commented Apr 7, 2023

@sgsdxzy By the way here's what I get on my setup with decapoda-research/llama-7b-hf:

  • GTX 1080 x 2 tensor_parallel: 25.87 seconds
  • GTX 1080 x 2 sequential: 16.11 seconds
  • GTX 1080 x 3 tensor_parallel: 22.40 seconds
  • GTX 1080 x 3 sequential: 18.14 seconds
  • GTX 1080 x 4 tensor_parallel: 75.03 seconds
  • GTX 1080 x 4 sequential: 19.41 seconds
  • RTX 3060 x 2 tensor_parallel: 15.25 seconds
  • RTX 3060 x 2 sequential: 19.90 seconds
  • RTX 3060 x 2 + GTX 1080 x 2 tensor_parallel: 73.69 seconds
  • RTX 3060 x 2 + GTX 1080 x 2 sequential: 31.91 seconds
  • RTX 3060 x 2 + GTX 1080 x 4 tensor_parallel: 123.15 seconds
  • RTX 3060 x 2 + GTX 1080 x 4 sequential: 29.55 seconds

Only RTX 3060 x 2 speeds things up. Something's definitely very wrong.

BlackSamorez (Owner) commented:

I've tested pure forward passes and it looks good:

  • Done 10 passes with batch_size=8 length=512 in 45.55 seconds with tensor_parallel
  • Done 10 passes with batch_size=8 length=512 in 194.08 seconds sequential

On the same GTX 1080 x 4. Maybe something's wrong with past_key_values processing, which makes generation slow. I'll look into it.

sgsdxzy (Author) commented Apr 8, 2023

@BlackSamorez is it that past_key_values are gathered to cuda:0 and redistributed to each rank every time?

BlackSamorez (Owner) commented:

@BlackSamorez is it that past_key_values are gathered to cuda:0 and redistributed to each rank every time?

I'm not sure. There is a different data structure for ungathered tensors called PerDeviceTensors and it's used for past_key_values. They should not be gathered at all. I'll need to verify that it's working as expected.
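
One way to narrow this down (a sketch, not a fix): time a hand-rolled greedy loop that passes past_key_values explicitly and compare it with model.generate() for the same number of tokens. If the manual loop is fast, the overhead sits in how the generation utilities shuffle the cache rather than in the sharded forward pass itself.

import time
import torch

@torch.no_grad()
def manual_greedy(model, input_ids, max_new_tokens=100):
    past_key_values = None
    generated = input_ids
    next_input = input_ids
    for _ in range(max_new_tokens):
        out = model(next_input, past_key_values=past_key_values, use_cache=True)
        past_key_values = out.past_key_values  # reuse the cache instead of recomputing
        next_token = out.logits[:, -1:].argmax(dim=-1)
        generated = torch.cat([generated, next_token], dim=-1)
        next_input = next_token
    return generated

t0 = time.time()
manual_greedy(model, batch["input_ids"])  # `model` / `batch` as in the scripts above
print(f"manual greedy loop: {time.time() - t0:.2f} s")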

sgsdxzy (Author) commented Apr 11, 2023

Have you identified the issue?
With 1.2.1, load_in_8bit actually saves VRAM for me, but the performance is still bad.

cxxz commented Apr 14, 2023

I also observed a slowdown with tensor_parallel 1.2.1 compared to native single-GPU performance.

Setup

Llama-7b on 8 x A100 80GB (NVLink)

Prompt

"Count up from 100 to 130"

so the number of newly generated tokens is fixed (155)

Inference Performance

1-GPU w/o TP: inference time 7.08s, GPU-util by nvidia-smi about 69%
2-way TP: inference time 10.24s, GPU-util by nvidia-smi only about 23%
The only code difference between the two tests is:

### 1-GPU w/o TP
model = LlamaForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16, device_map="sequential")

vs.

### 2-way TP
model = LlamaForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16)
model = TensorParallelPreTrainedModel(model, ["cuda:0", "cuda:1"])

Any hints on what might have gone wrong?

BlackSamorez (Owner) commented:

I've measured the performance of LLaMA 13B on Kaggle 2x T4 and here's what I got:

Forward passes

(benchmark plot)

Generation

(benchmark plot)

It's definitely a .generate() problem. I'll look into it and, hopefully, release a fix soon.

cxxz commented Apr 20, 2023

Thank you for sharing your findings on the performance of LLaMA 13B on Kaggle 2x T4. Good to know that you've identified the .generate() issue. I appreciate your efforts in looking into it and eagerly await the release of a fix. Keep up the good work!

@dmgcsilva
Copy link

dmgcsilva commented May 19, 2023

Hi @BlackSamorez, have you been able to identify and fix the issue? I am having similar problems: 2-way or even 4-way TP slows down inference on 2x A100 40GB with NVLink.

eric-mitchell commented Jun 4, 2023

Would love to know if there is any update on this issue @BlackSamorez. tensor_parallel works great for us for training (nice job!), but the inability to actually sample from the model is a dealbreaker for us. We're seeing slow generation for non-llama models too (e.g., Pythia-6.9b).

BlackSamorez (Owner) commented:

@eric-mitchell @dmgcsilva Sadly, I have neither the time nor the resources to properly test and benchmark this right now. I'll do it in a month or so.

152334H commented Jul 19, 2023

Has anyone found an alternative efficient TP solution yet?

chujiezheng commented:

I also found that 4-GPU TP is much slower than 2-GPU TP, while the latter is still a bit faster than 2-GPU PP.

dutsc commented Mar 15, 2024

This work is very meaningful. Following @sgsdxzy, I ran the following test on 3090s:

| Model setup | opt-6.7b 1gpu | opt-6.7b 2gpu | opt-1.3b 1gpu | opt-1.3b 2gpu | opt-13b 4gpu |
| --- | --- | --- | --- | --- | --- |
| Naive per token time (ms) | 21.5 | 21.5 (single card) | 12.5 | 12.5 (single card) | 52.11 |
| Naive memory per gpu (GB) | 12.8 | 12.8 | 2.9 | 2.9 | - |
| TP time (ms) | - | 76.89 | - | 62.1 | 373.71 |
| TP memory per gpu (GB) | - | 6.5 | - | 1.6 | 6.7 |

But the performance issue seems to be the same. Are there any other useful tensor parallel tools?

sgsdxzy (Author) commented Mar 15, 2024

@dutsc I use Aphrodite-engine or vLLM for TP inference.
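
For anyone landing here later, a minimal vLLM sketch with tensor parallelism (the model path and GPU count are placeholders; the exact API may differ between vLLM versions):

from vllm import LLM, SamplingParams

llm = LLM(model="models/llama-30b", tensor_parallel_size=4)  # shard across 4 GPUs
params = SamplingParams(max_tokens=200)
outputs = llm.generate(["DeepSpeed first included offloading capabilities with ZeRO-Offload, a system "], params)
print(outputs[0].outputs[0].text)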

dutsc commented Mar 15, 2024

Thank you for your answer.
