diff --git a/.gitignore b/.gitignore index b531b7918c30..5f41478c08e2 100644 --- a/.gitignore +++ b/.gitignore @@ -177,3 +177,12 @@ _build/ # vim swap files *.swo *.swp + +# Dataset files +*.json +*.json.[0-9] + +# Hipified files +*.hip +*_hip.h +*_hip.cuh diff --git a/Dockerfile b/Dockerfile index 72f0249490d9..85e92a229bb3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,72 +1,64 @@ -FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 AS dev - -RUN apt-get update -y \ - && apt-get install -y python3-pip - -WORKDIR /workspace - -# install build and runtime dependencies -COPY requirements.txt requirements.txt -RUN --mount=type=cache,target=/root/.cache/pip \ - pip install -r requirements.txt - -# install development dependencies -COPY requirements-dev.txt requirements-dev.txt -RUN --mount=type=cache,target=/root/.cache/pip \ - pip install -r requirements-dev.txt - -# image to build pytorch extensions -FROM dev AS build - -# copy input files -COPY csrc csrc -COPY setup.py setup.py -COPY requirements.txt requirements.txt -COPY pyproject.toml pyproject.toml -COPY vllm/__init__.py vllm/__init__.py - -# max jobs used by Ninja to build extensions -ENV MAX_JOBS=$max_jobs -RUN python3 setup.py build_ext --inplace - -# image to run unit testing suite -FROM dev AS test - -# copy pytorch extensions separately to avoid having to rebuild -# when python code changes -COPY --from=build /workspace/vllm/*.so /workspace/vllm/ -COPY tests tests -COPY vllm vllm - -ENTRYPOINT ["python3", "-m", "pytest", "tests"] - -# use CUDA base as CUDA runtime dependencies are already installed via pip -FROM nvidia/cuda:11.8.0-base-ubuntu22.04 AS vllm-base - -# libnccl required for ray -RUN apt-get update -y \ - && apt-get install -y python3-pip - -WORKDIR /workspace -COPY requirements.txt requirements.txt -RUN --mount=type=cache,target=/root/.cache/pip \ - pip install -r requirements.txt - -FROM vllm-base AS vllm -COPY --from=build /workspace/vllm/*.so /workspace/vllm/ -COPY vllm vllm - -EXPOSE 8000 -ENTRYPOINT ["python3", "-m", "vllm.entrypoints.api_server"] - -# openai api server alternative -FROM vllm-base AS vllm-openai -# install additional dependencies for openai api server -RUN --mount=type=cache,target=/root/.cache/pip \ - pip install accelerate fschat - -COPY --from=build /workspace/vllm/*.so /workspace/vllm/ -COPY vllm vllm - -ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] - +FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1 + +# Install some basic utilities +RUN apt-get update && apt-get install python3 python3-pip -y + +# Install some basic utilities +RUN apt-get update && apt-get install -y \ + curl \ + ca-certificates \ + sudo \ + git \ + bzip2 \ + libx11-6 \ + build-essential \ + wget \ + unzip \ + nvidia-cuda-toolkit \ + tmux \ + && rm -rf /var/lib/apt/lists/* + +### Mount Point ### +# When launching the container, mount the code directory to /app +ARG APP_MOUNT=/app +VOLUME [ ${APP_MOUNT} ] +WORKDIR ${APP_MOUNT} + +RUN python3 -m pip install --upgrade pip +RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers + +ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer +ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin: +ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib: +ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/: +ENV PYTORCH_ROCM_ARCH=gfx900;gfx906;gfx908;gfx90a;gfx1030;gfx1101 + +# Install ROCm flash-attention +RUN mkdir libs \ + && cd libs \ + && git clone 
https://github.com/ROCmSoftwarePlatform/flash-attention.git \ + && cd flash-attention \ + && git submodule update --init \ + && sed -i -e "s/--offload-arch=native/--offload-arch=$(/opt/rocm/llvm/bin/amdgpu-offload-arch)/g" setup.py \ + && patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \ + && python3 setup.py install \ + && cd .. + +COPY ./ /app/vllm-rocm/ + +RUN cd /app \ + && cd vllm-rocm \ + && git checkout v0.2.1.post1-rocm \ + && python3 setup.py install \ + && cd .. + +RUN cd /app \ + && mkdir dataset \ + && cd .. + +COPY ./benchmark_throughput.sh /app/benchmark_throughput.sh + +RUN python3 -m pip install --upgrade pip +RUN python3 -m pip install --no-cache-dir ray[all] + +CMD ["/bin/bash"] diff --git a/README.md b/README.md index 5ddb2a48ca83..6232d3157cb6 100644 --- a/README.md +++ b/README.md @@ -1,92 +1,147 @@ +

- [vLLM logo image] + [vLLM logo image] + [ROCm logo image]

-

-Easy, fast, and cheap LLM serving for everyone -

+

+vLLM ROCm port +

+ +This version of vLLM 0.2.x supports model inference and serving on AMD GPUs with ROCm. This ROCm port was adapted from [vLLM](https://github.com/vllm-project/vllm), a ROCm [community port](https://github.com/pcmoritz/vllm-public/tree/port-to-rocm), and [xformers](https://github.com/facebookresearch/xformers), replacing the attention forward method used in xformers with the ROCm implementation of [flash attention](https://github.com/ROCmSoftwarePlatform/flash-attention). This port does not yet support AWQ quantization, but SqueezeLLM is supported. + +This port is an extension of our previous [vLLM v0.1.4 ROCm port](https://github.com/EmbeddedLLM/vllm-rocm/tree/v0.1.4-rocm). Compared with our previous port, vLLM v0.2.x achieves a speedup of more than 2x for the LLaMA-70B model and more than 3x for LLaMA-7B/13B on MI210, thanks to the introduction of [efficient de-tokenization, vectorized sampling](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit#slide=id.g24c1f26d37c_10_117) and [paged attention v2](https://github.com/vllm-project/vllm/pull/1348).
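As a quick orientation, here is a minimal offline-inference sketch for this port. The model name is only an example, and the `LLM` constructor arguments mirror the ones passed in `benchmarks/benchmark_throughput.py`; treat this as a sketch under those assumptions rather than a prescribed configuration.

```python
from vllm import LLM, SamplingParams

# Example model; any architecture supported by this port can be used instead.
llm = LLM(
    model="lmsys/vicuna-7b-v1.5",
    tensor_parallel_size=1,
    dtype="auto",
    block_size=16,     # token block size (see --block-size in the benchmark script)
    max_num_seqs=256,  # maximum batched sequences (see --max-num-seqs)
)

sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=128)

outputs = llm.generate(["San Francisco is a"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```

For online serving, see the Serving section below.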

-| Documentation | Blog | Paper | Discord | + [throughput_tokens benchmark figure]

+

+ [throughput_requests benchmark figure]

--- -*Latest News* 🔥 -- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing). -- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there. -- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv! -- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM. -- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command! -- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds. -- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai). +*Latest News* +- [2023/11] We have updated our ROCm port for vLLM v0.2.x. +- [2023/10] LLaMA-2 models are now supported. 7B/13B/70B models can be run and served on AMD GPUs! --- -vLLM is a fast and easy-to-use library for LLM inference and serving. +## Getting Started -vLLM is fast with: +The following sections describes the installation of this ROCm port. If you intend to use our provided container, please skip to the [using docker](#using-docker) section. -- State-of-the-art serving throughput -- Efficient management of attention key and value memory with **PagedAttention** -- Continuous batching of incoming requests -- Optimized CUDA kernels +## Dependencies -vLLM is flexible and easy to use with: +To build this project, the following pre-requisites must be met: -- Seamless integration with popular Hugging Face models -- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more -- Tensor parallelism support for distributed inference -- Streaming outputs -- OpenAI-compatible API server +- [PyTorch](https://pytorch.org/) with ROCm (5.7.0 or later) support -vLLM seamlessly supports many Hugging Face models, including the following architectures: +- Install ROCm [flash-attention](https://github.com/ROCmSoftwarePlatform/flash-attention) following the instructions in [AMD ROCm Support](https://github.com/ROCmSoftwarePlatform/flash-attention#amd-gpurocm-support) -- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.) -- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.) -- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.) -- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.) -- GPT-2 (`gpt2`, `gpt2-xl`, etc.) -- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.) -- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.) -- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.) -- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.) 
-- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) -- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) -- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.) -- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) -- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.) +## Installation -Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source): +Build the repository ```bash -pip install vllm +git clone https://github.com/EmbeddedLLM/vllm-rocm.git +cd vllm-rocm/ +python3 setup.py install ``` -## Getting Started +## Using Docker -Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started. -- [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html) -- [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html) -- [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html) +A base docker image can be built from this repository: -## Contributing +```bash +docker build -t vllm-rocm . +``` + +Run a docker container with -We welcome and value any contributions and collaborations. -Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved. +```bash +docker run -it \ + --network=host \ + --group-add=video \ + --ipc=host \ + --cap-add=SYS_PTRACE \ + --security-opt seccomp=unconfined \ + --shm-size 8G \ + --device /dev/kfd \ + --device /dev/dri \ + vllm-rocm \ + bash +``` -## Citation +Alternatively, you can pull from our pre-built docker image: -If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180): -```bibtex -@inproceedings{kwon2023efficient, - title={Efficient Memory Management for Large Language Model Serving with PagedAttention}, - author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica}, - booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles}, - year={2023} -} +```bash +docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.1.post1 + +docker run -it \ + --network=host \ + --group-add=video \ + --ipc=host \ + --cap-add=SYS_PTRACE \ + --security-opt seccomp=unconfined \ + --shm-size 8G \ + --device /dev/kfd \ + --device /dev/dri \ + embeddedllminfo/vllm-rocm \ + bash ``` + +## Serving + +The project supports native vLLM serving + +```bash +python -m vllm.entrypoints.api_server \ + --model lmsys/vicuna-7b-v1.5 \ + --tensor-parallel-size 2 +``` + +## Benchmarking + +The benchmark results were obtained by running the vLLM benchmark scripts under the *benchmark* directory. 
+ +If your vLLM is installed using the provided [docker environment](#using-docker), you can benchmark the inferencing throughput following the steps below: +- Download the model you would like to evaluate to a directory of your choice (say a vicuna-7b model is downloaded to /path/to/your/model/vicuna-7b-v1.5) +- Run the docker and mount the model to /app/model + +```bash +docker run -it \ + --network=host \ + --group-add=video \ + --ipc=host \ + --cap-add=SYS_PTRACE \ + --security-opt seccomp=unconfined \ + --shm-size 8G \ + --device /dev/kfd \ + --device /dev/dri \ + -v /path/to/your/model/vicuna-7b-v1.5:/app/model \ + vllm-rocm \ + bash +``` +Inside the container, run +```bash +bash /app/benchmark_throughput.sh +``` + +## Acknowledgement + +This ROCm port was built upon the following amazing projects: + +- [vLLM](https://github.com/vllm-project/vllm) and [pcmoritz's ROCm fork](https://github.com/pcmoritz/vllm-public/tree/port-to-rocm) +- [flash-attention](https://github.com/ROCmSoftwarePlatform/flash-attention) +- [xformers](https://github.com/facebookresearch/xformers) diff --git a/benchmark_throughput.sh b/benchmark_throughput.sh new file mode 100644 index 000000000000..c0cb42989a42 --- /dev/null +++ b/benchmark_throughput.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +main () { + + cd /app + + positional_args=() + + model_path="/app/model" + dataset_path="/app/dataset/ShareGPT_V3_unfiltered_cleaned_split.json" + dataset_path_modified=0 + + while [[ $# -gt 0 ]]; do + case $1 in + "-h"|"--help") + python3 /app/vllm-rocm/benchmarks/benchmark_throughput.py --help + return 0 + ;; + "--dataset") + dataset_path="$2" + dataset_path_modified=1 + shift + shift + ;; + "--model") + model_path="$2" + shift + shift + ;; + *) + positional_args+=("$1") + shift + ;; + esac + done + + if [ ! -f "$dataset_path" ]; then + if [[ $dataset_path_modified -lt 1 ]]; then + cd /app/dataset + wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + cd /app + fi + fi + + python3 /app/vllm-rocm/benchmarks/benchmark_throughput.py --dataset "$dataset_path" --model "$model_path" "${positional_args[@]}" + return $? + +} + +main "$@" +exit $? 
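For completeness, here is a minimal client sketch against the API server started in the Serving section of the README above. The request payload and response handling mirror `examples/api_client.py` (updated in this diff to accept `repetition_penalty`); the host, port, and sampling values are assumptions for a locally running server.

```python
import json

import requests

# Assumes the api_server from the Serving section is running locally on port 8000.
api_url = "http://localhost:8000/generate"
payload = {
    "prompt": "San Francisco is a",
    "n": 1,
    "use_beam_search": False,
    "temperature": 0.7,
    "max_tokens": 128,
    "stream": False,
    # >1 penalizes repetition, <1 rewards it (exposed by this port's client and benchmark scripts).
    "repetition_penalty": 1.1,
}

response = requests.post(api_url, headers={"User-Agent": "Test Client"}, json=payload)
for generated_text in json.loads(response.content)["text"]:
    print(generated_text)
```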
diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index e560cb1fbfc0..eb9497204630 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -70,7 +70,8 @@ def run_to_completion(profile: bool = False): parser.add_argument('--tokenizer', type=str, default=None) parser.add_argument('--quantization', '-q', - choices=['awq', 'squeezellm', None], + # choices=['awq', 'squeezellm', None], + choices=['squeezellm', None], default=None) parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) parser.add_argument('--input-len', type=int, default=32) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 3a80e679191e..3d8f863770a7 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -104,6 +104,7 @@ async def send_request( output_len: int, best_of: int, use_beam_search: bool, + repetition_penalty: float = 1.0, ) -> None: request_start_time = time.perf_counter() @@ -119,6 +120,7 @@ async def send_request( "max_tokens": output_len, "ignore_eos": True, "stream": False, + "repetition_penalty": repetition_penalty, } elif backend == "tgi": assert not use_beam_search @@ -160,13 +162,15 @@ async def benchmark( best_of: int, use_beam_search: bool, request_rate: float, + repetition_penalty: float = 1.0, ) -> None: tasks: List[asyncio.Task] = [] async for request in get_request(input_requests, request_rate): prompt, prompt_len, output_len = request task = asyncio.create_task(send_request(backend, api_url, prompt, prompt_len, output_len, - best_of, use_beam_search)) + best_of, use_beam_search, + repetition_penalty=repetition_penalty,)) tasks.append(task) await asyncio.gather(*tasks) @@ -182,7 +186,7 @@ def main(args: argparse.Namespace): benchmark_start_time = time.perf_counter() asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of, - args.use_beam_search, args.request_rate)) + args.use_beam_search, args.request_rate, args.repetition_penalty)) benchmark_end_time = time.perf_counter() benchmark_time = benchmark_end_time - benchmark_start_time print(f"Total time: {benchmark_time:.2f} s") @@ -229,5 +233,7 @@ def main(args: argparse.Namespace): parser.add_argument("--seed", type=int, default=0) parser.add_argument('--trust-remote-code', action='store_true', help='trust remote code from huggingface') + parser.add_argument("--repetition-penalty", type=float, default=1.0, + help="Set >1 to penalize repetition and <1 to reward repetition") args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index fc578b497286..62951eedcb33 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -65,6 +65,9 @@ def run_vllm( use_beam_search: bool, trust_remote_code: bool, dtype: str, + block_size: int = 16, + max_num_seqs: int = 256, + repetition_penalty = 1.0, ) -> float: llm = LLM( model=model, @@ -74,8 +77,12 @@ def run_vllm( seed=seed, trust_remote_code=trust_remote_code, dtype=dtype, + block_size=block_size, + max_num_seqs=max_num_seqs ) + print("Using repetition penalty: {}".format(repetition_penalty)) + # Add the requests to the engine. for prompt, _, output_len in requests: sampling_params = SamplingParams( @@ -85,6 +92,7 @@ def run_vllm( use_beam_search=use_beam_search, ignore_eos=True, max_tokens=output_len, + repetition_penalty=repetition_penalty, ) # FIXME(woosuk): Do not use internal method. 
llm._add_request( @@ -170,10 +178,13 @@ def main(args: argparse.Namespace): requests = sample_requests(args.dataset, args.num_prompts, tokenizer) if args.backend == "vllm": + print("Using block_size={}, batch_size={}".format(args.block_size, args.max_num_seqs)) elapsed_time = run_vllm(requests, args.model, args.tokenizer, args.quantization, args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, - args.trust_remote_code, args.dtype) + args.trust_remote_code, args.dtype, + block_size=args.block_size, max_num_seqs=args.max_num_seqs, + repetition_penalty=args.repetition_penalty) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -201,7 +212,8 @@ def main(args: argparse.Namespace): parser.add_argument("--tokenizer", type=str, default=None) parser.add_argument('--quantization', '-q', - choices=['awq', 'squeezellm', None], + # choices=['awq', 'squeezellm', None], + choices=['squeezellm', None], default=None) parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) parser.add_argument("--n", @@ -230,6 +242,16 @@ def main(args: argparse.Namespace): 'The "auto" option will use FP16 precision ' 'for FP32 and FP16 models, and BF16 precision ' 'for BF16 models.') + parser.add_argument('--block-size', + type=int, + default=16, + choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], + help='token block size') + parser.add_argument("--max-num-seqs", + type=int, + default=256) + parser.add_argument("--repetition-penalty", type=float, default=1.0, + help="Set >1 to penalize repetition and <1 to reward repetition") args = parser.parse_args() if args.backend == "vllm": diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 89d1ba2d37dd..1cca2c5fccc1 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -1,6 +1,7 @@ #include #include +#include "cuda_compat.h" #include "dispatch_utils.h" namespace vllm { @@ -18,8 +19,8 @@ __global__ void silu_and_mul_kernel( const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { - const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]); - const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]); + const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); + const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); out[token_idx * d + idx] = silu(x) * y; } } @@ -57,7 +58,7 @@ __global__ void activation_kernel( const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { - const scalar_t x = __ldg(&input[token_idx * d + idx]); + const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); out[token_idx * d + idx] = ACT_FN(x); } } diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 78e8d8ecd6d4..9228de194bf3 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -23,7 +23,11 @@ #include +#ifndef USE_ROCM #define WARP_SIZE 32 +#else +#define WARP_SIZE 64 +#endif #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) @@ -40,7 +44,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Compute the sum per warp. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + sum += VLLM_SHFL_XOR_SYNC(sum, mask); } // Warp leaders store the data to shared memory. 
@@ -59,11 +63,11 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Parallel reduction inside the warp. #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + sum += VLLM_SHFL_XOR_SYNC(sum, mask); } // Broadcast to other threads. - return __shfl_sync(uint32_t(-1), sum, 0); + return VLLM_SHFL_SYNC(sum, 0); } // TODO(woosuk): Merge the last two dimensions of the grid. @@ -223,7 +227,7 @@ __device__ void paged_attention_kernel( // The 0-th thread of each thread group already has its max qk value. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); } if (lane == 0) { red_smem[warp_idx] = qk_max; @@ -235,10 +239,10 @@ __device__ void paged_attention_kernel( qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); } // Broadcast the max qk value to all threads. - qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + qk_max = VLLM_SHFL_SYNC(qk_max, 0); // Get the sum of the exp values. float exp_sum = 0.f; @@ -326,7 +330,7 @@ __device__ void paged_attention_kernel( float acc = accs[i]; #pragma unroll for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { - acc += __shfl_xor_sync(uint32_t(-1), acc, mask); + acc += VLLM_SHFL_XOR_SYNC(acc, mask); } accs[i] = acc; } @@ -492,7 +496,7 @@ __global__ void paged_attention_v2_reduce_kernel( // Reduce within the warp. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); } if (lane == 0) { red_smem[warp_idx] = max_logit; @@ -502,10 +506,10 @@ __global__ void paged_attention_v2_reduce_kernel( max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); } // Broadcast the max value to all threads. - max_logit = __shfl_sync(uint32_t(-1), max_logit, 0); + max_logit = VLLM_SHFL_SYNC(max_logit, 0); // Load rescaled exp sums to shared memory. 
float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); @@ -538,9 +542,10 @@ __global__ void paged_attention_v2_reduce_kernel( } // namespace vllm +#ifndef USE_ROCM #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ cudaFuncSetAttribute( \ - vllm::paged_attention_v1_kernel, \ + (void*)vllm::paged_attention_v1_kernel, \ cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ vllm::paged_attention_v1_kernel \ <<>>( \ @@ -557,6 +562,27 @@ __global__ void paged_attention_v2_reduce_kernel( q_stride, \ kv_block_stride, \ kv_head_stride); +#else +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + hipFuncSetAttribute( \ + (void*)vllm::paged_attention_v1_kernel, \ + hipFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + head_mapping_ptr, \ + scale, \ + block_tables_ptr, \ + context_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride); +#endif // TODO(woosuk): Tune NUM_THREADS. template< @@ -654,6 +680,15 @@ void paged_attention_v1_launcher( // 1, 2, 4, 64, 128, 256. #define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \ switch (block_size) { \ + /*case 1: */ \ + /* CALL_V1_LAUNCHER(T, 1); */ \ + /* break; */ \ + /*case 2: */ \ + /* CALL_V1_LAUNCHER(T, 2); */ \ + /* break; */ \ + /*case 4: */ \ + /* CALL_V1_LAUNCHER(T, 4); */ \ + /* break; */ \ case 8: \ CALL_V1_LAUNCHER(T, 8); \ break; \ @@ -663,6 +698,15 @@ void paged_attention_v1_launcher( case 32: \ CALL_V1_LAUNCHER(T, 32); \ break; \ + /*case 64: */ \ + /* CALL_V1_LAUNCHER(T, 64); */ \ + /* break; */ \ + /*case 128: */ \ + /* CALL_V1_LAUNCHER(T, 128); */ \ + /* break; */ \ + /*case 256: */ \ + /* CALL_V1_LAUNCHER(T, 256); */ \ + /* break; */ \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ break; \ @@ -826,6 +870,15 @@ void paged_attention_v2_launcher( // 1, 2, 4, 64, 128, 256. 
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \ switch (block_size) { \ + /*case 1: */ \ + /* CALL_V2_LAUNCHER(T, 1); */ \ + /* break; */ \ + /*case 2: */ \ + /* CALL_V2_LAUNCHER(T, 2); */ \ + /* break; */ \ + /*case 4: */ \ + /* CALL_V2_LAUNCHER(T, 4); */ \ + /* break; */ \ case 8: \ CALL_V2_LAUNCHER(T, 8); \ break; \ @@ -835,6 +888,15 @@ void paged_attention_v2_launcher( case 32: \ CALL_V2_LAUNCHER(T, 32); \ break; \ + /*case 64: */ \ + /* CALL_V2_LAUNCHER(T, 64); */ \ + /* break; */ \ + /*case 128: */ \ + /* CALL_V2_LAUNCHER(T, 128); */ \ + /* break; */ \ + /*case 256: */ \ + /* CALL_V2_LAUNCHER(T, 256); */ \ + /* break; */ \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ break; \ diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index bb7df25b14f0..ff64c4bd8f80 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -17,6 +17,7 @@ */ #pragma once +#include "../cuda_compat.h" #include "attention_dtypes.h" #include @@ -39,7 +40,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { float qk = sum(qk_vec); #pragma unroll for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + qk += VLLM_SHFL_XOR_SYNC(qk, mask); } return qk; } diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 5786f77f7bca..31e0cee01d2e 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -21,8 +21,17 @@ #include "attention_generic.cuh" #include "dtype_float32.cuh" -#include -#include +#ifndef USE_ROCM + #include + #include +#else + #include + #include + + typedef __hip_bfloat162 __nv_bfloat162; + typedef __hip_bfloat16 __nv_bfloat16; +#endif + #include namespace vllm { @@ -98,7 +107,11 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else - return a + b; + #ifndef USE_ROCM + return a + b; + #else + return __hadd(a, b); + #endif #endif } diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index e67921128d52..b5e128ce8ee0 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -21,6 +21,10 @@ #include "attention_generic.cuh" #include "dtype_float32.cuh" +#ifdef USE_ROCM + #include +#endif + #include namespace vllm { @@ -63,58 +67,114 @@ struct FloatVec { // Utility functions for type conversions. 
inline __device__ uint32_t h0_h0(uint16_t a) { +#ifndef USE_ROCM uint32_t b; asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); return b; +#else + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u16[0] = a; + tmp.u16[1] = a; + return tmp.u32; +#endif } inline __device__ float half_to_float(uint16_t h) { +#ifndef USE_ROCM float f; asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); return f; +#else + float f; + asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h)); + return f; +#endif } inline __device__ float2 half2_to_float2(uint32_t v) { +#ifndef USE_ROCM uint16_t lo, hi; asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); return make_float2(half_to_float(lo), half_to_float(hi)); +#else + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u32 = v; + float2 ret; + ret.x = half_to_float(tmp.u16[0]); + ret.y = half_to_float(tmp.u16[1]); + return ret; +#endif } inline __device__ uint16_t float_to_half(float f) { +#ifndef USE_ROCM union { uint32_t u32; uint16_t u16[2]; } tmp; asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); return tmp.u16[0]; +#else + uint16_t ret; + asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(ret) : "v"(f)); + return ret; +#endif } inline __device__ uint32_t float2_to_half2(float2 f) { +#ifndef USE_ROCM union { uint32_t u32; uint16_t u16[2]; } tmp; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); + #else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); + #endif + return tmp.u32; #else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); -#endif + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u16[0] = float_to_half(f.x); + tmp.u16[1] = float_to_half(f.y); return tmp.u32; +#endif } // Vector addition. inline __device__ uint16_t add(uint16_t a, uint16_t b) { +#ifndef USE_ROCM uint16_t c; asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); return c; +#else + uint16_t c; + asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); + return c; +#endif } inline __device__ uint32_t add(uint32_t a, uint32_t b) { +#ifndef USE_ROCM uint32_t c; asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; +#else + uint32_t c; + asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); + return c; +#endif } inline __device__ uint2 add(uint2 a, uint2 b) { @@ -157,16 +217,28 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { // Vector multiplication. template<> inline __device__ uint16_t mul(uint16_t a, uint16_t b) { +#ifndef USE_ROCM uint16_t c; asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); return c; +#else + uint16_t c; + asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); + return c; +#endif } template<> inline __device__ uint32_t mul(uint32_t a, uint32_t b) { +#ifndef USE_ROCM uint32_t c; asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; +#else + uint32_t c; + asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); + return c; +#endif } template<> @@ -271,9 +343,15 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { // Vector fused multiply-add. 
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { +#ifndef USE_ROCM uint32_t d; asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); return d; +#else + uint32_t d; + asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); + return d; +#endif } inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 3ad52b1681c0..59bacffdf464 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1,6 +1,7 @@ #include #include +#include "cuda_compat.h" #include "dispatch_utils.h" #include @@ -28,8 +29,8 @@ void swap_blocks( TORCH_CHECK(false, "Invalid device combination"); } - void *src_ptr = src.data_ptr(); - void *dst_ptr = dst.data_ptr(); + char *src_ptr = static_cast(src.data_ptr()); + char *dst_ptr = static_cast(dst.data_ptr()); const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -267,8 +268,8 @@ __global__ void gather_cached_kv_kernel( + head_offset * block_size + block_offset; - key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]); - value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]); + key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]); + value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]); } } @@ -333,8 +334,8 @@ __global__ void gather_cached_kv_kernel_optimized( src_key_indices[j] = src_key_idx; src_value_indices[j] = src_value_idx; - keys_to_store[j] = __ldg(&key_cache[src_key_idx]); - values_to_store[j] = __ldg(&value_cache[src_value_idx]); + keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]); + values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]); } #pragma unroll diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h new file mode 100644 index 000000000000..8991462a862e --- /dev/null +++ b/csrc/cuda_compat.h @@ -0,0 +1,19 @@ +#pragma once + +#ifndef USE_ROCM + #define VLLM_LDG(arg) __ldg(arg) +#else + #define VLLM_LDG(arg) *(arg) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) +#else + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane); +#else + #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) +#endif diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu index f1c30fe7ea99..2439f5922a3f 100644 --- a/csrc/cuda_utils_kernels.cu +++ b/csrc/cuda_utils_kernels.cu @@ -1,3 +1,7 @@ +#ifdef USE_ROCM + #include +#endif + int get_device_attribute( int attribute, int device_id) diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 0a5ec95f8c0d..e1dc711778ff 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -1,6 +1,7 @@ #include #include +#include "cuda_compat.h" #include "dispatch_utils.h" namespace vllm { @@ -19,14 +20,14 @@ inline __device__ void apply_rotary_embedding( // GPT-NeoX style rotary embedding. x_index = rot_offset; y_index = embed_dim + rot_offset; - cos = __ldg(cos_ptr + x_index); - sin = __ldg(sin_ptr + x_index); + cos = VLLM_LDG(cos_ptr + x_index); + sin = VLLM_LDG(sin_ptr + x_index); } else { // GPT-J style rotary embedding. 
x_index = 2 * rot_offset; y_index = 2 * rot_offset + 1; - cos = __ldg(cos_ptr + x_index / 2); - sin = __ldg(sin_ptr + x_index / 2); + cos = VLLM_LDG(cos_ptr + x_index / 2); + sin = VLLM_LDG(sin_ptr + x_index / 2); } const scalar_t x = arr[x_index]; diff --git a/csrc/quantization.cpp b/csrc/quantization.cpp index dfe17a496c78..6ebcc35e4227 100644 --- a/csrc/quantization.cpp +++ b/csrc/quantization.cpp @@ -1,11 +1,11 @@ #include -torch::Tensor awq_gemm( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters); +// torch::Tensor awq_gemm( +// torch::Tensor _in_feats, +// torch::Tensor _kernel, +// torch::Tensor _scaling_factors, +// torch::Tensor _zeros, +// int split_k_iters); void squeezellm_gemm( torch::Tensor vec, @@ -14,6 +14,6 @@ void squeezellm_gemm( torch::Tensor lookup_table); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); + // m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); m.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); } diff --git a/csrc/quantization/squeezellm/quant_cuda_kernel.cu b/csrc/quantization/squeezellm/quant_cuda_kernel.cu index 1392b877397b..2c37d01e0ae5 100644 --- a/csrc/quantization/squeezellm/quant_cuda_kernel.cu +++ b/csrc/quantization/squeezellm/quant_cuda_kernel.cu @@ -20,9 +20,17 @@ __device__ inline unsigned int as_unsigned(int i) { // 4-bit matvec kernel (LUT-based) __global__ void NUQ4MatMulKernel( +#ifndef USE_ROCM const half2* __restrict__ vec, +#else + const __half2* __restrict__ vec, +#endif const int* __restrict__ mat, +#ifndef USE_ROCM half2* __restrict__ mul, +#else + float2* __restrict__ mul, +#endif const __half* __restrict__ lookup_table, int height, int width, @@ -35,7 +43,11 @@ __global__ void NUQ4MatMulKernel( int row = BLOCKHEIGHT4 * blockIdx.x; int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; +#ifndef USE_ROCM __shared__ half2 blockvec[blockwidth2]; +#else + __shared__ __half2 blockvec[blockwidth2]; +#endif __shared__ __half deq2[16][BLOCKWIDTH]; int off = threadIdx.x; @@ -46,8 +58,13 @@ __global__ void NUQ4MatMulKernel( } __half res; +#ifndef USE_ROCM half2 res2; half2 tmp2; +#else + __half2 res2; + __half2 tmp2; +#endif int i; int k; @@ -68,48 +85,96 @@ __global__ void NUQ4MatMulKernel( while (k < blockwidth2) { tmp1 = as_unsigned(mat[i]); +#ifndef USE_ROCM res2 = {}; tmp2 = {}; +#else + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); + tmp2.x = __half_as_ushort(__float2half(0)); + tmp2.y = __half_as_ushort(__float2half(0)); +#endif lut_index1 = tmp1 & 0xF; lut_index2 = (tmp1 >> 4) & 0xF; +#ifndef USE_ROCM tmp2.x = deq2[lut_index1][off]; tmp2.y = deq2[lut_index2][off]; +#else + tmp2.x = __half_as_ushort(deq2[lut_index1][off]); + tmp2.y = __half_as_ushort(deq2[lut_index2][off]); +#endif res2 = __hfma2(tmp2, blockvec[k + 0], res2); lut_index1 = (tmp1 >> 8) & 0xF; lut_index2 = (tmp1 >> 12) & 0xF; +#ifndef USE_ROCM tmp2.x = deq2[lut_index1][off]; tmp2.y = deq2[lut_index2][off]; +#else + tmp2.x = __half_as_ushort(deq2[lut_index1][off]); + tmp2.y = __half_as_ushort(deq2[lut_index2][off]); +#endif res2 = __hfma2(tmp2, blockvec[k + 1], res2); lut_index1 = (tmp1 >> 16) & 0xF; lut_index2 = (tmp1 >> 20) & 0xF; +#ifndef USE_ROCM tmp2.x = deq2[lut_index1][off]; tmp2.y = deq2[lut_index2][off]; +#else + tmp2.x = __half_as_ushort(deq2[lut_index1][off]); + tmp2.y = __half_as_ushort(deq2[lut_index2][off]); +#endif res2 = __hfma2(tmp2, blockvec[k + 2], 
res2); lut_index1 = (tmp1 >> 24) & 0xF; lut_index2 = (tmp1 >> 28) & 0xF; +#ifndef USE_ROCM tmp2.x = deq2[lut_index1][off]; tmp2.y = deq2[lut_index2][off]; +#else + tmp2.x = __half_as_ushort(deq2[lut_index1][off]); + tmp2.y = __half_as_ushort(deq2[lut_index2][off]); +#endif res2 = __hfma2(tmp2, blockvec[k + 3], res2); +#ifndef USE_ROCM res = __hadd(__hadd(res2.x, res2.y), res); +#else + res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res); +#endif i += width; k += 4; } // col%2 -> only set one of the two values +#ifndef USE_ROCM half2 res3 = {}; if (col % 2 == 0) { res3.x = res; } else { res3.y = res; } +#else + __half2 res3; + res3.x = __half_as_ushort(__float2half(0)); + res3.y = __half_as_ushort(__float2half(0)); + if (col % 2 == 0) { + res3.x = __half_as_ushort(res); + } else { + res3.y = __half_as_ushort(res); + } +#endif +#ifndef USE_ROCM atomicAdd(&mul[b * width / 2 + col / 2], res3); +#else + int tmp_addr = b * width / 2 + col / 2; + atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x))); + atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y))); +#endif } } @@ -136,10 +201,19 @@ void squeezellm_gemm( dim3 threads(BLOCKWIDTH); vllm::squeezellm::NUQ4MatMulKernel<<>>( +#ifndef USE_ROCM (half2*) vec.data(), +#else + (__half2*) vec.data_ptr(), +#endif mat.data_ptr(), +#ifndef USE_ROCM (half2*) mul.data(), (__half*) lookup_table.data(), +#else + (float2*) mul.data_ptr(), + (__half*) lookup_table.data_ptr(), +#endif height, width, batch, vec_height ); } diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index bc35aa0424b5..b95ccef16207 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -17,13 +17,15 @@ */ #pragma once +#include "cuda_compat.h" + namespace vllm { template __inline__ __device__ T warpReduceSum(T val) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) - val += __shfl_xor_sync(0xffffffff, val, mask, 32); + val += VLLM_SHFL_XOR_SYNC(val, mask); return val; } diff --git a/docs/source/assets/benchmarks/throughput_requests.png b/docs/source/assets/benchmarks/throughput_requests.png new file mode 100644 index 000000000000..9b4b541393ca Binary files /dev/null and b/docs/source/assets/benchmarks/throughput_requests.png differ diff --git a/docs/source/assets/benchmarks/throughput_requests.svg b/docs/source/assets/benchmarks/throughput_requests.svg new file mode 100644 index 000000000000..f87c9fa9defc --- /dev/null +++ b/docs/source/assets/benchmarks/throughput_requests.svg @@ -0,0 +1,434 @@ + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 2.2x 3.2x 3.5x + + + diff --git a/docs/source/assets/benchmarks/throughput_tokens.png b/docs/source/assets/benchmarks/throughput_tokens.png new file mode 100644 index 000000000000..59bf778cd08c Binary files /dev/null and b/docs/source/assets/benchmarks/throughput_tokens.png differ diff --git a/docs/source/assets/benchmarks/throughput_tokens.svg b/docs/source/assets/benchmarks/throughput_tokens.svg new file mode 100644 index 000000000000..c1643afc4e3d --- /dev/null +++ b/docs/source/assets/benchmarks/throughput_tokens.svg @@ -0,0 +1,535 @@ + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 2.2x 3.2x 3.5x + + + diff --git a/docs/source/assets/logos/16900649.png b/docs/source/assets/logos/16900649.png new 
file mode 100644 index 000000000000..095beb65b139 Binary files /dev/null and b/docs/source/assets/logos/16900649.png differ diff --git a/examples/api_client.py b/examples/api_client.py index 70ec8c549212..6f84973e0318 100644 --- a/examples/api_client.py +++ b/examples/api_client.py @@ -17,15 +17,17 @@ def clear_line(n: int = 1) -> None: def post_http_request(prompt: str, api_url: str, n: int = 1, - stream: bool = False) -> requests.Response: + stream: bool = False, + repetition_penalty = 1.0) -> requests.Response: headers = {"User-Agent": "Test Client"} pload = { "prompt": prompt, "n": n, "use_beam_search": True, "temperature": 0.0, - "max_tokens": 16, + "max_tokens": 1024, "stream": stream, + "repetition_penalty": repetition_penalty, } response = requests.post(api_url, headers=headers, json=pload, stream=True) return response @@ -54,6 +56,8 @@ def get_response(response: requests.Response) -> List[str]: parser.add_argument("--n", type=int, default=4) parser.add_argument("--prompt", type=str, default="San Francisco is a") parser.add_argument("--stream", action="store_true") + parser.add_argument("--repetition-penalty", type=float, default=1.0, + help="Set >1 to penalize repetition and <1 to reward repetition") args = parser.parse_args() prompt = args.prompt api_url = f"http://{args.host}:{args.port}/generate" @@ -61,7 +65,7 @@ def get_response(response: requests.Response) -> List[str]: stream = args.stream print(f"Prompt: {prompt!r}\n", flush=True) - response = post_http_request(prompt, api_url, n, stream) + response = post_http_request(prompt, api_url, n, stream, repetition_penalty=args.repetition_penalty) if stream: num_printed_lines = 0 diff --git a/requirements.txt b/requirements.txt index d8597b3ec554..7d4fe2972815 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,10 +4,10 @@ ray >= 2.5.1 pandas # Required for Ray data. pyarrow # Required for Ray data. sentencepiece # Required for LLaMA tokenizer. -numpy +numpy >= 1.22.4 torch == 2.0.1 transformers >= 4.34.0 # Required for Mistral. -xformers == 0.0.22 # Required for Mistral. +#xformers == 0.0.22 # Required for Mistral. fastapi uvicorn[standard] pydantic == 1.10.13 # Required for OpenAI server. diff --git a/setup.py b/setup.py index 660b5196cfd9..784de6734da4 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ from packaging.version import parse, Version import setuptools import torch -from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME +from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME ROOT_DIR = os.path.dirname(__file__) @@ -24,10 +24,14 @@ CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] -if CUDA_HOME is None: - raise RuntimeError( - "Cannot find CUDA_HOME. CUDA must be available to build the package.") +if torch.version.hip: + if ROCM_HOME is not None: + NVCC_FLAGS += [f"-DUSE_ROCM"] +if not torch.version.hip: + if CUDA_HOME is None: + raise RuntimeError( + "Cannot find CUDA_HOME. CUDA must be available to build the package.") def get_nvcc_cuda_version(cuda_dir: str) -> Version: """Get the CUDA version from nvcc. @@ -76,66 +80,72 @@ def get_torch_arch_list() -> Set[str]: f"{valid_archs}.") return arch_list - -# First, check the TORCH_CUDA_ARCH_LIST environment variable. -compute_capabilities = get_torch_arch_list() -if not compute_capabilities: - # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available - # GPUs on the current machine. 
- device_count = torch.cuda.device_count() - for i in range(device_count): - major, minor = torch.cuda.get_device_capability(i) - if major < 7: - raise RuntimeError( - "GPUs with compute capability below 7.0 are not supported.") - compute_capabilities.add(f"{major}.{minor}") - -nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) -if not compute_capabilities: - # If no GPU is specified nor available, add all supported architectures - # based on the NVCC CUDA version. - compute_capabilities = SUPPORTED_ARCHS.copy() +def get_cuda_compute_capabilities(nvcc_cuda_version): + # First, check the TORCH_CUDA_ARCH_LIST environment variable. + compute_capabilities = get_torch_arch_list() + if not compute_capabilities: + # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available + # GPUs on the current machine. + device_count = torch.cuda.device_count() + for i in range(device_count): + major, minor = torch.cuda.get_device_capability(i) + if major < 7: + raise RuntimeError( + "GPUs with compute capability below 7.0 are not supported.") + compute_capabilities.add(f"{major}.{minor}") + + if not compute_capabilities: + # If no GPU is specified nor available, add all supported architectures + # based on the NVCC CUDA version. + compute_capabilities = SUPPORTED_ARCHS.copy() + if nvcc_cuda_version < Version("11.1"): + compute_capabilities.remove("8.6") + if nvcc_cuda_version < Version("11.8"): + compute_capabilities.remove("8.9") + compute_capabilities.remove("9.0") + + return compute_capabilities + +def validate_nvcc_cuda_version(nvcc_cuda_version, compute_capabilities): + if nvcc_cuda_version < Version("11.0"): + raise RuntimeError("CUDA 11.0 or higher is required to build the package.") if nvcc_cuda_version < Version("11.1"): - compute_capabilities.remove("8.6") + if any(cc.startswith("8.6") for cc in compute_capabilities): + raise RuntimeError( + "CUDA 11.1 or higher is required for compute capability 8.6.") if nvcc_cuda_version < Version("11.8"): - compute_capabilities.remove("8.9") - compute_capabilities.remove("9.0") - -# Validate the NVCC CUDA version. -if nvcc_cuda_version < Version("11.0"): - raise RuntimeError("CUDA 11.0 or higher is required to build the package.") -if nvcc_cuda_version < Version("11.1"): - if any(cc.startswith("8.6") for cc in compute_capabilities): - raise RuntimeError( - "CUDA 11.1 or higher is required for compute capability 8.6.") -if nvcc_cuda_version < Version("11.8"): - if any(cc.startswith("8.9") for cc in compute_capabilities): - # CUDA 11.8 is required to generate the code targeting compute capability 8.9. - # However, GPUs with compute capability 8.9 can also run the code generated by - # the previous versions of CUDA 11 and targeting compute capability 8.0. - # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0 - # instead of 8.9. - warnings.warn( - "CUDA 11.8 or higher is required for compute capability 8.9. " - "Targeting compute capability 8.0 instead.") - compute_capabilities = set(cc for cc in compute_capabilities - if not cc.startswith("8.9")) - compute_capabilities.add("8.0+PTX") - if any(cc.startswith("9.0") for cc in compute_capabilities): - raise RuntimeError( - "CUDA 11.8 or higher is required for compute capability 9.0.") + if any(cc.startswith("8.9") for cc in compute_capabilities): + # CUDA 11.8 is required to generate the code targeting compute capability 8.9. + # However, GPUs with compute capability 8.9 can also run the code generated by + # the previous versions of CUDA 11 and targeting compute capability 8.0. 
+ # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0 + # instead of 8.9. + warnings.warn( + "CUDA 11.8 or higher is required for compute capability 8.9. " + "Targeting compute capability 8.0 instead.") + compute_capabilities = set(cc for cc in compute_capabilities + if not cc.startswith("8.9")) + compute_capabilities.add("8.0+PTX") + if any(cc.startswith("9.0") for cc in compute_capabilities): + raise RuntimeError( + "CUDA 11.8 or higher is required for compute capability 9.0.") + +if not torch.version.hip: + nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) + compute_capabilities = get_cuda_compute_capabilities(nvcc_cuda_version) + validate_nvcc_cuda_version(nvcc_cuda_version, compute_capabilities) -# Add target compute capabilities to NVCC flags. -for capability in compute_capabilities: - num = capability[0] + capability[2] - NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] - if capability.endswith("+PTX"): - NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] + # Add target compute capabilities to NVCC flags. + for capability in compute_capabilities: + num = capability[0] + capability[2] + NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] + if capability.endswith("+PTX"): + NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] -# Use NVCC threads to parallelize the build. -if nvcc_cuda_version >= Version("11.2"): - num_threads = min(os.cpu_count(), 8) - NVCC_FLAGS += ["--threads", str(num_threads)] + # Use NVCC threads to parallelize the build. + if nvcc_cuda_version >= Version("11.2"): + num_threads = min(os.cpu_count(), 8) + NVCC_FLAGS += ["--threads", str(num_threads)] ext_modules = [] @@ -199,7 +209,7 @@ def get_torch_arch_list() -> Set[str]: name="vllm.quantization_ops", sources=[ "csrc/quantization.cpp", - "csrc/quantization/awq/gemm_kernels.cu", + # "csrc/quantization/awq/gemm_kernels.cu", "csrc/quantization/squeezellm/quant_cuda_kernel.cu", ], extra_compile_args={ diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 7c4a84d4c7d8..35c6a6ab2fa1 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -3,8 +3,8 @@ import pytest import torch -from xformers import ops as xops -from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask +from vllm.xformers.ops.fmha import memory_efficient_attention_forward +from vllm.xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from vllm import attention_ops from vllm.utils import get_max_shared_memory_bytes @@ -308,7 +308,7 @@ def test_multi_query_kv_attention( key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens) - output = xops.memory_efficient_attention_forward( + output = memory_efficient_attention_forward( query.unsqueeze(0), key.unsqueeze(0), value.unsqueeze(0), diff --git a/vllm/config.py b/vllm/config.py index a9e86c24b273..13db4d8a65b0 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -103,7 +103,8 @@ def _verify_tokenizer_mode(self) -> None: self.tokenizer_mode = tokenizer_mode def _verify_quantization(self) -> None: - supported_quantization = ["awq", "squeezellm"] + # supported_quantization = ["awq", "squeezellm"] + supported_quantization = ["squeezellm"] if self.quantization is None: return quantization = self.quantization.lower() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index cc425a2c079e..dc00da04c183 100644 --- 
a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -132,7 +132,7 @@ def add_cli_args( parser.add_argument('--block-size', type=int, default=EngineArgs.block_size, - choices=[8, 16, 32], + choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size') # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). parser.add_argument('--seed', @@ -168,7 +168,8 @@ def add_cli_args( parser.add_argument('--quantization', '-q', type=str, - choices=['awq', 'squeezellm', None], + # choices=['awq', 'squeezellm', None], + choices=['squeezellm', None], default=None, help='Method used to quantize the weights') return parser diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index ed7f1ec45e32..93cebb288174 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -74,7 +74,7 @@ def initialize_cluster( "Ray is not installed. Please install Ray to use distributed " "serving.") # Connect to a ray cluster. - ray.init(address=ray_address, ignore_reinit_error=True) + ray.init(address=ray_address, ignore_reinit_error=True, num_gpus=parallel_config.world_size) if not parallel_config.worker_use_ray: # Initialize cluster locally. diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index b3b5852e4876..6df6d529f13f 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -1,7 +1,7 @@ from typing import Dict, List, Optional, Tuple import torch -from xformers.ops import AttentionBias +from vllm.xformers.ops.fmha.attn_bias import AttentionBias from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SequenceData diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index c1259a1b11ea..bea78c95c5c2 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -3,8 +3,8 @@ import torch import torch.nn as nn -from xformers import ops as xops -from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, +from vllm.xformers.ops.fmha import memory_efficient_attention_forward +from vllm.xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, LowerTriangularMaskWithTensorBias) from vllm import attention_ops @@ -104,7 +104,7 @@ def multi_query_kv_attention( dim=1) # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize. 
- out = xops.memory_efficient_attention_forward( + out = memory_efficient_attention_forward( query.unsqueeze(0), key.unsqueeze(0), value.unsqueeze(0), @@ -465,7 +465,7 @@ def multi_query_kv_attention( batch_size = input_metadata.num_prompts seq_len = input_metadata.max_prompt_len - out = xops.memory_efficient_attention_forward( + out = memory_efficient_attention_forward( query.view(batch_size, seq_len, self.num_heads, self.head_size), key.view(batch_size, seq_len, self.num_heads, self.head_size), value.view(batch_size, seq_len, self.num_heads, self.head_size), diff --git a/vllm/model_executor/layers/quantized_linear/__init__.py b/vllm/model_executor/layers/quantized_linear/__init__.py index b09358261d5d..0feb3b243b2a 100644 --- a/vllm/model_executor/layers/quantized_linear/__init__.py +++ b/vllm/model_executor/layers/quantized_linear/__init__.py @@ -1,12 +1,12 @@ -from vllm.model_executor.layers.quantized_linear.awq import ( - AWQColumnParallelLinear, AWQRowParallelLinear) +# from vllm.model_executor.layers.quantized_linear.awq import ( +# AWQColumnParallelLinear, AWQRowParallelLinear) from vllm.model_executor.layers.quantized_linear.squeezellm import ( SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear) from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear, RowParallelLinear) _QUANTIZED_LINEAR_REGISTRY = { - "awq": (AWQColumnParallelLinear, AWQRowParallelLinear), + # "awq": (AWQColumnParallelLinear, AWQRowParallelLinear), "squeezellm": (SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear), } diff --git a/vllm/model_executor/layers/quantized_linear/squeezellm.py b/vllm/model_executor/layers/quantized_linear/squeezellm.py index 3ccbc4e579dc..c365483b58d8 100644 --- a/vllm/model_executor/layers/quantized_linear/squeezellm.py +++ b/vllm/model_executor/layers/quantized_linear/squeezellm.py @@ -39,9 +39,15 @@ def apply_weights( out_shape = x.shape[:-1] + (self.qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) # NOTE: The output tensor should be zero-initialized. - out = torch.zeros(out_shape, device="cuda", dtype=torch.float16) - quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out, - self.lookup_table) + if torch.version.hip: + out_float = torch.zeros(out_shape, device="cuda", dtype=torch.float) + quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out_float, + self.lookup_table) + out = out_float.to(dtype=torch.float16) + else: + out = torch.zeros(out_shape, device="cuda", dtype=torch.float16) + quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out, + self.lookup_table) if bias is not None: out = out + bias @@ -78,7 +84,13 @@ def apply_weights(self, x: torch.Tensor) -> torch.Tensor: reshaped_x = x.reshape(-1, x.shape[-1]) out_shape = x.shape[:-1] + (self.qweight.shape[-1], ) # NOTE: The output tensor should be zero-initialized. 
- out = torch.zeros(out_shape, device="cuda", dtype=torch.float16) - quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out, - self.lookup_table) - return out.reshape(out_shape) + if torch.version.hip: + out = torch.zeros(out_shape, device="cuda", dtype=torch.float) + quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out, + self.lookup_table) + return out.to(dtype=torch.float16).reshape(out_shape) + else: + out = torch.zeros(out_shape, device="cuda", dtype=torch.float16) + quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out, + self.lookup_table) + return out.reshape(out_shape) \ No newline at end of file diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 735e4ad17218..b957aab3f576 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -36,10 +36,11 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.quantized_linear import ParallelLinear from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding +from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, + ColumnParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.quantized_linear import ParallelLinear from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import ( convert_pyslice_to_tensor, hf_model_weights_iterator, @@ -350,6 +351,8 @@ def load_weights(self, ] state_dict = self.state_dict() + load_format = load_format if load_format != "auto" else "pt" + for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index 8b09276e6f91..675e54ed675b 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -36,10 +36,10 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.quantized_linear import ParallelLinear from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding +from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, + ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import ( convert_pyslice_to_tensor, hf_model_weights_iterator, @@ -59,12 +59,12 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() - self.gate_up_proj = ParallelLinear.column(hidden_size, + self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size, bias=False, gather_output=False, quant_config=quant_config) - self.down_proj = ParallelLinear.row(intermediate_size, + self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, input_is_parallel=True, @@ -107,7 +107,7 @@ def __init__(self, self.rope_theta = rope_theta self.sliding_window = 
sliding_window - self.qkv_proj = ParallelLinear.column( + self.qkv_proj = ColumnParallelLinear( hidden_size, (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim, @@ -115,7 +115,7 @@ def __init__(self, gather_output=False, quant_config=quant_config, ) - self.o_proj = ParallelLinear.row( + self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, @@ -268,7 +268,7 @@ def __init__( self.model = MistralModel(config, quant_config) vocab_size = ((config.vocab_size + 63) // 64) * 64 # NOTE: The LM head is not quantized. - self.lm_head = ParallelLinear.column(config.hidden_size, + self.lm_head = ColumnParallelLinear(config.hidden_size, vocab_size, bias=False, gather_output=False, diff --git a/vllm/model_executor/quantization_utils/__init__.py b/vllm/model_executor/quantization_utils/__init__.py index 345f6494bf83..d3ee39131e49 100644 --- a/vllm/model_executor/quantization_utils/__init__.py +++ b/vllm/model_executor/quantization_utils/__init__.py @@ -1,11 +1,11 @@ from typing import Type -from vllm.model_executor.quantization_utils.awq import AWQConfig +# from vllm.model_executor.quantization_utils.awq import AWQConfig from vllm.model_executor.quantization_utils.base import QuantizationConfig from vllm.model_executor.quantization_utils.squeezellm import SqueezeLLMConfig _QUANTIZATION_REGISTRY = { - "awq": AWQConfig, + # "awq": AWQConfig, "squeezellm": SqueezeLLMConfig, } diff --git a/vllm/utils.py b/vllm/utils.py index 0e17e9070489..8ec33d2a21ac 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -30,7 +30,10 @@ def reset(self) -> None: def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html - cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 # pylint: disable=invalid-name + if torch.version.hip: + cudaDevAttrMaxSharedMemoryPerBlockOptin = 74 + else: + cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 # pylint: disable=invalid-name max_shared_mem = cuda_utils.get_device_attribute( cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu) return int(max_shared_mem) diff --git a/vllm/xformers/LICENSE b/vllm/xformers/LICENSE new file mode 100644 index 000000000000..f42f840d1027 --- /dev/null +++ b/vllm/xformers/LICENSE @@ -0,0 +1,35 @@ +From xFormers: + +Copyright (c) Facebook, Inc. and its affiliates + + +=== + +BSD 3-Clause License + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America + and IDIAP Research Institute nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/vllm/xformers/__init__.py b/vllm/xformers/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/xformers/_cpp_lib.py b/vllm/xformers/_cpp_lib.py new file mode 100644 index 000000000000..4eb6fd98140e --- /dev/null +++ b/vllm/xformers/_cpp_lib.py @@ -0,0 +1,144 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import dataclasses +import json +import logging +import os +import platform +from typing import Any, Dict, Optional + +import torch + +logger = logging.getLogger("xformers") + +UNAVAILABLE_FEATURES_MSG = ( + " Memory-efficient attention, SwiGLU, sparse and more won't be available." +) + + +@dataclasses.dataclass +class _BuildInfo: + metadata: Dict[str, Any] + + @property + def cuda_version(self) -> Optional[int]: + return self.metadata["version"]["cuda"] + + @property + def torch_version(self) -> str: + return self.metadata["version"]["torch"] + + @property + def python_version(self) -> str: + return self.metadata["version"]["python"] + + @property + def flash_version(self) -> str: + return self.metadata["version"].get("flash", "0.0.0") + + @property + def build_env(self) -> Dict[str, Any]: + return self.metadata["env"] + + +class xFormersWasNotBuiltException(Exception): + def __str__(self) -> str: + return ( + "Need to compile C++ extensions to use all xFormers features.\n" + " Please install xformers properly " + "(see https://github.com/facebookresearch/xformers#installing-xformers)\n" + + UNAVAILABLE_FEATURES_MSG + ) + + +class xFormersInvalidLibException(Exception): + def __init__(self, build_info: Optional[_BuildInfo]) -> None: + self.build_info = build_info + + def __str__(self) -> str: + if self.build_info is None: + msg = "xFormers was built for a different version of PyTorch or Python." + else: + msg = f"""xFormers was built for: + PyTorch {self.build_info.torch_version} with CUDA {self.build_info.cuda_version} (you have {torch.__version__}) + Python {self.build_info.python_version} (you have {platform.python_version()})""" + return ( + "xFormers can't load C++/CUDA extensions. 
" + + msg + + "\n Please reinstall xformers " + "(see https://github.com/facebookresearch/xformers#installing-xformers)\n" + + UNAVAILABLE_FEATURES_MSG + ) + + +def _register_extensions(): + import importlib + import os + + import torch + + # load the custom_op_library and register the custom ops + lib_dir = os.path.dirname(__file__) + if os.name == "nt": + # Register the main torchvision library location on the default DLL path + import ctypes + import sys + + kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) + with_load_library_flags = hasattr(kernel32, "AddDllDirectory") + prev_error_mode = kernel32.SetErrorMode(0x0001) + + if with_load_library_flags: + kernel32.AddDllDirectory.restype = ctypes.c_void_p + + if sys.version_info >= (3, 8): + os.add_dll_directory(lib_dir) + elif with_load_library_flags: + res = kernel32.AddDllDirectory(lib_dir) + if res is None: + err = ctypes.WinError(ctypes.get_last_error()) + err.strerror += f' Error adding "{lib_dir}" to the DLL directories.' + raise err + + kernel32.SetErrorMode(prev_error_mode) + + loader_details = ( + importlib.machinery.ExtensionFileLoader, + importlib.machinery.EXTENSION_SUFFIXES, + ) + + extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) + ext_specs = extfinder.find_spec("_C") + if ext_specs is None: + raise xFormersWasNotBuiltException() + cpp_lib_json = os.path.join(lib_dir, "cpp_lib.json") + with open(cpp_lib_json, "r") as fp: + build_metadata = _BuildInfo(json.load(fp)) + try: + torch.ops.load_library(ext_specs.origin) + except OSError as exc: + raise xFormersInvalidLibException(build_metadata) from exc + return build_metadata + + +_cpp_library_load_exception = None +_build_metadata: Optional[_BuildInfo] = None + +try: + _build_metadata = _register_extensions() +except (xFormersInvalidLibException, xFormersWasNotBuiltException) as e: + ENV_VAR_FOR_DETAILS = "XFORMERS_MORE_DETAILS" + if os.environ.get(ENV_VAR_FOR_DETAILS, False): + logger.warning(f"WARNING[XFORMERS]: {e}", exc_info=e) + else: + logger.warning( + f"WARNING[XFORMERS]: {e}\n Set {ENV_VAR_FOR_DETAILS}=1 for more details" + ) + _cpp_library_load_exception = e + +_built_with_cuda = ( + _build_metadata is not None and _build_metadata.cuda_version is not None +) diff --git a/vllm/xformers/ops/__init__.py b/vllm/xformers/ops/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/xformers/ops/common.py b/vllm/xformers/ops/common.py new file mode 100644 index 000000000000..e20cedabfab6 --- /dev/null +++ b/vllm/xformers/ops/common.py @@ -0,0 +1,23 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any + +class BaseOperator: + OPERATOR: Any + NAME: str + OPERATOR_CATEGORY: str + + @classmethod + def is_available(cls) -> bool: + if cls.OPERATOR is None or cls.OPERATOR.__name__ == "no_such_operator": + return False + return True + + @classmethod + def operator_flop(cls, *inputs) -> int: + """Calculate number of FLOP given inputs to `OPERATOR`""" + return -1 + \ No newline at end of file diff --git a/vllm/xformers/ops/fmha/__init__.py b/vllm/xformers/ops/fmha/__init__.py new file mode 100644 index 000000000000..449b026ce571 --- /dev/null +++ b/vllm/xformers/ops/fmha/__init__.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Optional, Tuple, Type, Union + +import torch + +#from . import cutlass, decoder, flash, small_k, triton +from .attn_bias import AttentionBias +from .common import ( + AttentionFwOpBase, + Inputs, +) +from .dispatch import _dispatch_fw, _ensure_op_supports_or_raise + + +def memory_efficient_attention_forward( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None, + p: float = 0.0, + scale: Optional[float] = None, + *, + op: Optional[Type[AttentionFwOpBase]] = None, +) -> torch.Tensor: + """ + Calculates the forward pass of :attr:`xformers.ops.memory_efficient_attention`. + """ + return _memory_efficient_attention_forward( + Inputs( + query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale + ), + op=op, + ) + +def _memory_efficient_attention_forward( + inp: Inputs, op: Optional[Type[AttentionFwOpBase]] +) -> torch.Tensor: + inp.validate_inputs() + output_shape = inp.normalize_bmhk() + if op is None: + op = _dispatch_fw(inp, False) + else: + _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp) + + out, *_ = op.apply(inp, needs_gradient=False) + return out.reshape(output_shape) + diff --git a/vllm/xformers/ops/fmha/attn_bias.py b/vllm/xformers/ops/fmha/attn_bias.py new file mode 100644 index 000000000000..4af444b92a90 --- /dev/null +++ b/vllm/xformers/ops/fmha/attn_bias.py @@ -0,0 +1,778 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import math +from dataclasses import dataclass +from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union + +import torch + +class AttentionBias: + """Base class for a custom bias that can be applied \ + as the attn_bias argument in + :attr:`xformers.ops.memory_efficient_attention`. + + That function has the ability to add a tensor, the + attention bias, to the QK^T matrix before it is used + in the softmax part of the attention calculation. + The attention bias tensor with shape + (B or 1, n_queries, number of keys) + can be given as the attn_bias input. + The most common use case is for an attention bias is + to contain only zeros and negative infinities, which forms + a mask so that some queries only attend to some keys. + + Children of this class define alternative things which can + be used as the attn_bias input to define an attention bias which + forms such a mask, for some common cases. + + When using an :attr:`xformers.ops.AttentionBias` + instead of a :attr:`torch.Tensor`, the mask matrix does + not need to be materialized, and can be + hardcoded into some kernels for better performance. + + See: + + - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMask` + - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias` + - :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask` + - :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask` + + """ + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + """ + Materializes the bias as a `torch.Tensor`. This is very slow + and we don't attempt to make it fast. Only use for debugging/testing. 
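
`memory_efficient_attention_forward` above is the only public entry point kept from xFormers; vLLM's prompt attention calls it with packed sequences and a block-diagonal causal bias. A hedged usage sketch, assuming a ROCm/CUDA device and the flash-attention build from the Dockerfile; shapes and lengths are illustrative only:

```python
import torch
from vllm.xformers.ops.fmha import memory_efficient_attention_forward
from vllm.xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

num_heads, head_size = 8, 64
prompt_lens = [5, 3, 7]                       # three prompts packed into one batch
total_tokens = sum(prompt_lens)

q = torch.randn(total_tokens, num_heads, head_size,
                dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# Queries in prompt i may only attend (causally) to keys of prompt i.
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)

out = memory_efficient_attention_forward(
    q.unsqueeze(0),                           # [1, sum(prompt_lens), H, K]
    k.unsqueeze(0),
    v.unsqueeze(0),
    attn_bias=attn_bias,
    scale=head_size ** -0.5,
)
print(out.shape)                              # torch.Size([1, 15, 8, 64])
```
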
+ + Shape should be like `[*, q_seqlen, k_seqlen]` + """ + raise NotImplementedError() + + +class LowerTriangularMask(AttentionBias): + """ + A lower-triangular (aka causal) mask + + A query Q cannot attend to a key which is farther from the + initial key than Q is from the initial query. + """ + + def __init__(self, *tensor_args, **tensor_kwargs) -> None: + # NOTE: Unused arguments, we keep them for backward compatibility + super().__init__() + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + create_as = dtype if dtype is not torch.bfloat16 else torch.float32 + tensor = torch.full( # type: ignore + shape, + dtype=create_as, + fill_value=float("-inf"), + device=device, + ) + return torch.triu(tensor, diagonal=1).to(dtype) # type: ignore + + def add_bias(self, bias: torch.Tensor) -> "LowerTriangularMaskWithTensorBias": + return LowerTriangularMaskWithTensorBias(bias) + + +class LowerTriangularMaskWithTensorBias(LowerTriangularMask): + """A lower-triangular (aka causal) mask with an additive bias""" + + def __init__(self, bias: torch.Tensor) -> None: + self._bias = bias + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return super().materialize(shape, dtype=dtype, device=device) + self._bias + + +@dataclass +class _SeqLenInfo: + """ + (Internal) Represents the division of a dimension into blocks. + + For example, to represents a dimension of length 7 divided into + three blocks of lengths 2, 3 and 2, use `from_seqlength([2, 3, 2])`. + The members will be: + max_seqlen: 3 + min_seqlen: 2 + seqstart_py: [0, 2, 5, 7] + seqstart: torch.IntTensor([0, 2, 5, 7]) + """ + + seqstart: torch.Tensor + max_seqlen: int + min_seqlen: int + seqstart_py: List[int] + + def to(self, device: torch.device) -> None: + self.seqstart = self.seqstart.to(device, non_blocking=True) + + def intervals(self) -> Iterable[Tuple[int, int]]: + yield from zip(self.seqstart_py, self.seqstart_py[1:]) + + @classmethod + def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo": + """ + Input tensors are assumed to be in shape [B, M, *] + """ + assert not isinstance(seqlens, torch.Tensor) + seqstart_py = [0] + max_seqlen = -1 + min_seqlen = -1 + for seqlen in seqlens: + min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen + max_seqlen = max(max_seqlen, seqlen) + seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen) + seqstart = torch.tensor(seqstart_py, dtype=torch.int32) + return cls( + max_seqlen=max_seqlen, + min_seqlen=min_seqlen, + seqstart=seqstart, + seqstart_py=seqstart_py, + ) + + def split( + self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None + ) -> List[torch.Tensor]: + if self.seqstart_py[-1] != x.shape[1] or x.shape[0] != 1: + raise ValueError( + f"Invalid `torch.Tensor` of shape {x.shape}, expected format " + f"(B, M, *) with B=1 and M={self.seqstart_py[-1]}\n" + f" seqstart: {self.seqstart_py}" + ) + if batch_sizes is None: + batch_sizes = [1] * (len(self.seqstart_py) - 1) + split_chunks = [] + it = 0 + for batch_size in batch_sizes: + split_chunks.append( + self.seqstart_py[it + batch_size] - self.seqstart_py[it] + ) + it += batch_size + return [ + tensor.reshape([bs, -1, *tensor.shape[2:]]) + for bs, tensor in zip(batch_sizes, x.split(split_chunks, dim=1)) + ] + + +@dataclass +class _PaddedSeqLenInfo(_SeqLenInfo): + """ + (Internal) Represents the 
division of a dimension into blocks which are + padded out to the same total length. + + For example, to represent a dimension of length 12 with space for + three blocks of length 4, but where the occupied lengths are + 2, 3 and 2, use `from_seqlens_padded([2, 3, 2], 4)`. + + The layout along the dimension is + + 0 ─► block 0 + block 0 + + + 4 ─► block 1 + block 1 + block 1 + + 8 ─► block 2 + block 2 + + + 12 ─► + + The members will be: + max_seqlen: 3 + min_seqlen: 2 + seqstart_py: [0, 4, 8, 12] + seqstart: torch.IntTensor([0, 4, 8, 12]) + seqlen_py: [2, 3, 2] + seqlen: torch.IntTensor([2, 3, 2]) + padding: 4 + """ + + seqlen: torch.Tensor + seqlen_py: Sequence[int] + padding: int + # From parent: seqstart[i] contains the start position + # of the i-th sequence + # seqstart: torch.Tensor + + def __post_init__(self) -> None: + assert len(self.seqstart_py) == len(self.seqlen_py) + 1 + + def to(self, device: torch.device) -> None: + self.seqlen = self.seqlen.to(device, non_blocking=True) + super().to(device) + + def intervals(self) -> Iterable[Tuple[int, int]]: + for (start, _), length in zip(super().intervals(), self.seqlen_py): + yield start, start + length + + @classmethod + def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo": + raise RuntimeError( + "Use either `_SeqLenInfo.from_seqlens` or `_PaddedSeqLenInfo.from_seqlens_padded`" + ) + + @classmethod + def from_seqlens_padded( + cls, seqlens: Sequence[int], padding: int + ) -> "_PaddedSeqLenInfo": + """ + Input tensors are assumed to be in shape [B, M, *] + seqstart = padding * torch.arange(batch_size) + """ + assert not isinstance(seqlens, torch.Tensor) + assert all(seqlen <= padding for seqlen in seqlens) + seqstart_py = list(range(0, len(seqlens) * padding + 1, padding)) + return cls( + seqlen=torch.tensor(seqlens, dtype=torch.int32), + seqlen_py=seqlens, + max_seqlen=max(seqlens), + min_seqlen=min(seqlens), + seqstart=torch.tensor(seqstart_py, dtype=torch.int32), + seqstart_py=seqstart_py, + padding=padding, + ) + + def split( + self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None + ) -> List[torch.Tensor]: + raise NotImplementedError("_PaddedSeqLenInfo.split") + + +@dataclass +class BlockDiagonalMask(AttentionBias): + """ + A block-diagonal mask that can be passed as ``attn_bias`` + argument to :attr:`xformers.ops.memory_efficient_attention`. + + Queries and Keys are each divided into the same number of blocks. + Queries in block i only attend to keys in block i. + + .. figure:: /_static/block_diag_bias.png + + This bias can be used to handle a batch of sequences of + different lengths, via :attr:`BlockDiagonalMask.from_tensor_list` + + :Example: + + .. 
code-block:: python + + import torch + from xformers.ops import fmha + + K = 16 + dtype = torch.float16 + device = "cuda" + list_x = [ + torch.randn([1, 3, 1, K], dtype=dtype, device=device), + torch.randn([1, 6, 1, K], dtype=dtype, device=device), + torch.randn([1, 2, 1, K], dtype=dtype, device=device), + ] + attn_bias, x = fmha.BlockDiagonalMask.from_tensor_list(list_x) + linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype) + + q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2) + out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias) + list_out = attn_bias.split(out) + print(list_out[0].shape) # [1, 3, 1, K] + assert tuple(list_out[0].shape) == (1, 3, 1, K) + + """ + + q_seqinfo: _SeqLenInfo + k_seqinfo: _SeqLenInfo + _batch_sizes: Optional[Sequence[int]] = None + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return torch.zeros( + shape, + dtype=dtype, + device=device, + ) + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + """Materialize the attention bias - for debugging & testing""" + assert shape[-1] == self.k_seqinfo.seqstart_py[-1], ( + shape[-1], + self.k_seqinfo.seqstart_py[-1], + ) + assert shape[-2] == self.q_seqinfo.seqstart_py[-1], ( + shape[-2], + self.q_seqinfo.seqstart_py[-1], + ) + mask = torch.empty(shape[-2:], dtype=dtype, device=device) + mask.fill_(-math.inf) + for i, ((q_start, q_end), (k_start, k_end)) in enumerate( + zip( + self.q_seqinfo.intervals(), + self.k_seqinfo.intervals(), + ) + ): + mask[q_start:q_end, k_start:k_end] = self._create_block_mask( + (q_end - q_start, k_end - k_start), + dtype=dtype, + device=device, + ) + for _ in range(len(shape) - 2): + mask = mask.unsqueeze(0) + return mask.expand(shape) + + @classmethod + def from_seqlens( + cls, + q_seqlen: Sequence[int], + kv_seqlen: Optional[Sequence[int]] = None, + ) -> "BlockDiagonalMask": + """Creates a :attr:`BlockDiagonalMask` from a list of tensors lengths for query and key/value. + + Args: + q_seqlen (Union[Sequence[int], torch.Tensor]): List or tensor of sequence lengths for query tensors + kv_seqlen (Union[Sequence[int], torch.Tensor], optional): List or tensor of sequence lengths for key/value. + (Defaults to ``q_seqlen``.) + Returns: + BlockDiagonalMask + """ + assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen) + q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen) + if kv_seqlen is None or q_seqlen == kv_seqlen: + k_seqinfo = q_seqinfo + else: + k_seqinfo = _SeqLenInfo.from_seqlens(kv_seqlen) + return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo) + + @classmethod + def from_tensor_list( + cls, + tensors: Sequence[torch.Tensor], + ) -> Tuple["BlockDiagonalMask", torch.Tensor]: + """Creates a :attr:`BlockDiagonalMask` from a list of tensors, and returns the tensors + concatenated on the sequence length dimension + + .. figure:: /_static/block_diag_cat_split.png + + See also :attr:`BlockDiagonalMask.split` to split the returned + :attr:`torch.Tensor` back to a list of tensors of varying sequence length + + Args: + tensors (Sequence[torch.Tensor]): A list of tensors of shape ``[B, M_i, *]``. + All tensors should have the same dimension and the same batch size ``B``, but + they can have different sequence length ``M``. 
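
The `_SeqLenInfo` helpers above only book-keep offsets. The following CPU-only sketch shows the values they produce for the docstring example; both are underscore-private classes and are used here purely to make the layout concrete:

```python
from vllm.xformers.ops.fmha.attn_bias import _SeqLenInfo, _PaddedSeqLenInfo

info = _SeqLenInfo.from_seqlens([2, 3, 2])
print(info.seqstart_py)                  # [0, 2, 5, 7]  (prefix sums of block lengths)
print(info.max_seqlen, info.min_seqlen)  # 3 2
print(list(info.intervals()))            # [(0, 2), (2, 5), (5, 7)]

padded = _PaddedSeqLenInfo.from_seqlens_padded([2, 3, 2], padding=4)
print(padded.seqstart_py)                # [0, 4, 8, 12]  (each block reserves 4 slots)
print(list(padded.intervals()))          # [(0, 2), (4, 7), (8, 10)]  (occupied part only)
```
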
+ + Returns: + Tuple[BlockDiagonalMask, torch.Tensor]: The corresponding bias for the attention + along with `tensors` concatenated on the sequence length dimension, with shape ``[1, sum_i{M_i}, *]`` + """ + batch_sizes = [tensor.shape[0] for tensor in tensors] + seqlens = [] + for x in tensors: + for _ in range(x.shape[0]): + seqlens.append(x.shape[1]) + block_diag = cls.from_seqlens(seqlens) + block_diag._batch_sizes = batch_sizes + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in tensors) + concat_tensors = torch.cat(tensors_bs1, dim=1) + return block_diag, concat_tensors + + @classmethod + def from_tensor_lists_qkv( + cls, + tensors_q: Sequence[torch.Tensor], + tensors_k: Sequence[torch.Tensor], + tensors_v: Optional[Sequence[torch.Tensor]] = None, + ) -> Tuple["BlockDiagonalMask", torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + assert len(tensors_q) == len(tensors_k) + assert tensors_v is None or len(tensors_v) == len(tensors_q) + batch_sizes = [tensor.shape[0] for tensor in tensors_q] + q_seqlens, kv_seqlens = [], [] + for i, (q, k) in enumerate(zip(tensors_q, tensors_k)): + assert q.shape[0] == k.shape[0] + q_seqlens += [q.shape[1]] * q.shape[0] + kv_seqlens += [k.shape[1]] * k.shape[0] + assert tensors_v is None or tensors_v[i].shape[:2] == k.shape[:2] + block_diag = cls.from_seqlens(q_seqlens, kv_seqlens) + block_diag._batch_sizes = batch_sizes + return ( + block_diag, + torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_q], dim=1), + torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_k], dim=1), + torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_v], dim=1) + if tensors_v is not None + else None, + ) + + def split_queries(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]: + return self.q_seqinfo.split(tensor, self._batch_sizes) + + def split_kv(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]: + return self.k_seqinfo.split(tensor, self._batch_sizes) + + def split(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]: + """The inverse operation of :attr:`BlockDiagonalCausalMask.from_tensor_list` + + Args: + tensor (torch.Tensor): Tensor of tokens of shape ``[1, sum_i{M_i}, *]`` + + Returns: + Sequence[torch.Tensor]: A list of tokens with possibly different sequence lengths + """ + assert self.q_seqinfo is self.k_seqinfo + return self.q_seqinfo.split(tensor, self._batch_sizes) + + def make_causal(self) -> "BlockDiagonalCausalMask": + """Makes each block causal""" + return BlockDiagonalCausalMask( + q_seqinfo=self.q_seqinfo, + k_seqinfo=self.k_seqinfo, + _batch_sizes=self._batch_sizes, + ) + + def make_causal_from_bottomright(self) -> "BlockDiagonalCausalFromBottomRightMask": + """Makes each block causal with a possible non-causal prefix""" + return BlockDiagonalCausalFromBottomRightMask( + q_seqinfo=self.q_seqinfo, + k_seqinfo=self.k_seqinfo, + _batch_sizes=self._batch_sizes, + ) + + def make_local_attention( + self, window_size: int + ) -> "BlockDiagonalCausalLocalAttentionMask": + """Experimental: Makes each block causal with local attention""" + return BlockDiagonalCausalLocalAttentionMask( + q_seqinfo=self.q_seqinfo, + k_seqinfo=self.k_seqinfo, + _batch_sizes=self._batch_sizes, + _window_size=window_size, + ) + + def make_local_attention_from_bottomright( + self, window_size: int + ) -> "BlockDiagonalCausalLocalAttentionFromBottomRightMask": + """Experimental: Makes each block causal with local attention, start from bottom right""" + return BlockDiagonalCausalLocalAttentionFromBottomRightMask( + 
q_seqinfo=self.q_seqinfo, + k_seqinfo=self.k_seqinfo, + _batch_sizes=self._batch_sizes, + _window_size=window_size, + ) + + +@dataclass +class BlockDiagonalCausalMask(BlockDiagonalMask): + """ + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal. + + Queries and Keys are each divided into the same number of blocks. + A query Q in block i cannot attend to a key which is not in block i, + nor one which is farther from the initial key in block i than Q + is from the initial query in block i. + """ + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return LowerTriangularMask().materialize( + shape, + dtype=dtype, + device=device, + ) + + +@dataclass +class BlockDiagonalCausalFromBottomRightMask(BlockDiagonalMask): + """ + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal. + This mask allows for a non-causal prefix + NOTE: Each block should have `num_keys >= num_queries` otherwise the forward pass is not + defined (softmax of vector of `-inf` in the attention) + + Queries and keys are each divided into the same number of blocks. + A query Q in block i cannot attend to a key which is not in block i, + nor one which nearer the final key in block i than Q is to the + final query in block i. + """ + + def __post_init__(self) -> None: + for i, ((q_start, q_end), (k_start, k_end)) in enumerate( + zip( + self.q_seqinfo.intervals(), + self.k_seqinfo.intervals(), + ) + ): + num_queries = q_end - q_start + num_keys = k_end - k_start + if num_keys < num_queries: + raise ValueError( + f"Block #{i} has num_keys={num_keys} and num_queries={num_queries}." + " Expected `num_keys >= num_queries`" + ) + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + create_as = dtype if dtype is not torch.bfloat16 else torch.float32 + tensor = torch.full( # type: ignore + shape, + dtype=create_as, + fill_value=float("-inf"), + device=device, + ) + num_queries, num_keys = shape[-2:] + return torch.triu(tensor, diagonal=num_keys - num_queries + 1).to(dtype) # type: ignore + + +@dataclass +class BlockDiagonalCausalWithOffsetPaddedKeysMask(AttentionBias): + """ + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`, + except an offset on causality is allowed for each block and we support padding for k/v + + The keys and values are divided into blocks which are padded out to + the same total length. + For example, if there is space for 12 keys, for three blocks of + max length 4, but we only want to use the first 2, 3 and 2 + of each block, use `kv_padding=4` and `kv_seqlens=[2, 3, 2]`. + The queries are divided into blocks, without padding, of lengths given by + q_seqlen. + + A query Q in block i cannot attend to a key which is not in block i, + nor one which is not in use (i.e. in the padded area), + nor one which is nearer to the final key in block i + than Q is to the final query in block i. + """ + + q_seqinfo: _SeqLenInfo + k_seqinfo: _PaddedSeqLenInfo + causal_diagonal: Any = None # unused. Exists for BC only. 
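
To make the difference between the two causal variants defined above concrete, here is a small CPU-only sketch that materializes each one for a single block of 2 queries and 4 keys (materialization is intended for debugging, exactly as the docstrings note):

```python
from vllm.xformers.ops.fmha.attn_bias import (
    BlockDiagonalCausalMask,
    BlockDiagonalCausalFromBottomRightMask,
)

top_left = BlockDiagonalCausalMask.from_seqlens([2], [4]).materialize((2, 4))
print(top_left)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf]])    causality anchored at the top-left corner

bottom_right = (BlockDiagonalCausalFromBottomRightMask
                .from_seqlens([2], [4]).materialize((2, 4)))
print(bottom_right)
# tensor([[0., 0., 0., -inf],
#         [0., 0., 0.,   0.]])        causality anchored at the bottom-right corner
```
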
+ + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + create_as = dtype if dtype is not torch.bfloat16 else torch.float32 + tensor = torch.full( # type: ignore + shape, + dtype=create_as, + fill_value=float("-inf"), + device=device, + ) + num_queries, num_keys = shape[-2:] + return torch.triu(tensor, diagonal=1 + num_keys - num_queries).to(dtype) # type: ignore + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + """Materialize the attention bias - for debugging & testing""" + if shape[-1] != self.k_seqinfo.seqstart_py[-1]: + raise ValueError("k shapes wrong") + if shape[-2] != self.q_seqinfo.seqstart_py[-1]: + raise ValueError("q shapes wrong") + mask = torch.empty(shape[-2:], dtype=dtype, device=device) + mask.fill_(-math.inf) + for i, ((q_start, q_end), (k_start, k_end)) in enumerate( + zip( + self.q_seqinfo.intervals(), + self.k_seqinfo.intervals(), + ) + ): + mask[q_start:q_end, k_start:k_end] = self._create_block_mask( + (q_end - q_start, k_end - k_start), + dtype=dtype, + device=device, + ) + for _ in range(len(shape) - 2): + mask = mask.unsqueeze(0) + return mask.expand(shape) + + @classmethod + def from_seqlens( + cls, + q_seqlen: Sequence[int], + kv_padding: int, + kv_seqlen: Sequence[int], + causal_diagonal: Any = None, + ) -> "BlockDiagonalCausalWithOffsetPaddedKeysMask": + """Creates a :attr:`BlockDiagonalCausalWithOffsetPaddedKeysMask` from a list of tensor + lengths for query and key/value. + + Args: + q_seqlen (Sequence[int]): List or tensor of sequence lengths for query tensors + kv_padding (int): Padding for k/v - also an upperbound on each individual key length + kv_seqlen (Sequence[int]): List or tensor of sequence lengths for key/value. + causal_diagonal: unused, for BC only + Returns: + BlockDiagonalCausalWithOffsetPaddedKeysMask + """ + assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen), ( + q_seqlen, + kv_seqlen, + ) + q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen) + k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded(kv_seqlen, kv_padding) + return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo) + + +@dataclass +class BlockDiagonalCausalLocalAttentionMask(BlockDiagonalCausalMask): + """ + (Experimental feature) + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`. + This makes the mask "local" and the attention pattern banded. + + Query i only attends to keys in its block and cannot attend keys further than "window_size" + from it. 
+ """ + + _window_size: int = 0 # forced due to inheritance and default arguments + + def __post_init__(self): + if self._window_size <= 0: + raise ValueError( + f"Expected `window_size > 0`, but window_size={self._window_size}" + ) + q_seqlen = [ + y - x + for x, y in zip( + self.q_seqinfo.seqstart_py[:-1], self.q_seqinfo.seqstart_py[1:] + ) + ] + kv_seqlen = [ + y - x + for x, y in zip( + self.k_seqinfo.seqstart_py[:-1], self.k_seqinfo.seqstart_py[1:] + ) + ] + for q, k in zip(q_seqlen, kv_seqlen): + if q - self._window_size >= k: + raise RuntimeError( + f"No keys are attended in q_seqlen {q} k_seqlen {k} with sliding window {self._window_size}" + ) + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + create_as = dtype if dtype is not torch.bfloat16 else torch.float32 + tensor = torch.full( # type: ignore + shape, + dtype=create_as, + fill_value=1, + device=device, + ) + + num_queries, num_keys = shape[-2:] + mask = torch.tril(tensor, diagonal=0).to(dtype) # type: ignore + if self._window_size is not None and self._window_size > 0: + mask = torch.triu(mask, diagonal=-self._window_size + 1) + mask = torch.log(mask) + return mask.to(dtype) + + +@dataclass +class BlockDiagonalCausalLocalAttentionFromBottomRightMask( + BlockDiagonalCausalFromBottomRightMask +): + """ + (Experimental feature) + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`. + This makes the mask "local" and the attention pattern banded. + + Query i only attends to keys in its block and cannot attend keys further than "window_size" + from it. + """ + + _window_size: int = 0 # forced due to inheritance and default arguments + + def __post_init__(self): + super().__post_init__() + if self._window_size <= 0: + raise ValueError( + f"Expected `window_size > 0`, but window_size={self._window_size}" + ) + q_seqlen = [ + y - x + for x, y in zip( + self.q_seqinfo.seqstart_py[:-1], self.q_seqinfo.seqstart_py[1:] + ) + ] + kv_seqlen = [ + y - x + for x, y in zip( + self.k_seqinfo.seqstart_py[:-1], self.k_seqinfo.seqstart_py[1:] + ) + ] + for q, k in zip(q_seqlen, kv_seqlen): + if q + (q - k) - self._window_size >= k: + raise RuntimeError( + f"No keys are attended in q_seqlen {q} k_seqlen {k} with sliding window {self._window_size}" + ) + materialized = self.materialize((sum(q_seqlen), sum(kv_seqlen))) + if torch.max(materialized, dim=1).values.min() == -float("inf"): + raise RuntimeError("FUCKING FUCK FUCK") + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + create_as = dtype if dtype is not torch.bfloat16 else torch.float32 + tensor = torch.full( # type: ignore + shape, + dtype=create_as, + fill_value=1, + device=device, + ) + num_queries, num_keys = shape[-2:] + mask = torch.tril(tensor, diagonal=num_keys - num_queries).to(dtype) # type: ignore + if self._window_size is not None: + mask = torch.triu( + mask, diagonal=num_keys - num_queries - self._window_size + 1 + ) + mask = torch.log(mask) + return mask.to(dtype) diff --git a/vllm/xformers/ops/fmha/common.py b/vllm/xformers/ops/fmha/common.py new file mode 100644 index 000000000000..c87cdd1ba8ec --- /dev/null +++ b/vllm/xformers/ops/fmha/common.py @@ -0,0 +1,527 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass +from typing import Any, List, Mapping, Optional, Set, Tuple, Type, Union + +import torch + +from ..._cpp_lib import _built_with_cuda +from ..common import BaseOperator +from .attn_bias import ( + AttentionBias, + BlockDiagonalMask, + LowerTriangularMask, + LowerTriangularMaskWithTensorBias, +) + + +def _is_bias_type_supported_in_BMK(attn_bias_type: Any) -> bool: + # NoneType + if isinstance(None, attn_bias_type): + return True + if attn_bias_type in [LowerTriangularMask, torch.Tensor]: + return True + return False + + +@dataclass +class Inputs: + """ + Stores inputs to the `memory_efficient_attention` operators + """ + + query: torch.Tensor + key: torch.Tensor + value: torch.Tensor + attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None + p: float = 0.0 + scale: Optional[float] = None + + @property + def device(self) -> torch.device: + return self.query.device + + @property + def scale_float(self) -> float: + return self.query.shape[-1] ** (-0.5) if self.scale is None else self.scale + + def get_qkv_in_bmghk(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if self.query.ndim == 5: + return self.query, self.key, self.value + if self.query.ndim == 4: + return ( + self.query.unsqueeze(2), + self.key.unsqueeze(2), + self.value.unsqueeze(2), + ) + if self.value.ndim == 3: + return ( + self.query[:, :, None, None], + self.key[:, :, None, None], + self.value[:, :, None, None], + ) + assert False + + def normalize_bmhk(self) -> Tuple[int, ...]: + if self.query.ndim not in [3, 4, 5]: + raise ValueError( + f"Invalid shape for query: {self.query.shape}. " + "Expected shape [batch, seqlen, head_groups, num_heads_per_group, K]" + ", [batch, seqlen, num_heads, K], or [batch, seqlen, K]." + ) + if self.value.dtype == torch.int32: + # Quantized K/V case, in which the last dims of Q and K are different. + # NB we currently don't have any implementations for quantized KV with + # SUPPORTS_DIFFERENT_VALUE_EMBED. 
+ output_shape = tuple(self.query.shape) + else: + output_shape = (self.query.shape[:-1]) + (self.value.shape[-1],) + # Convert from legacy format + if self.query.ndim == 3: + self.query = self.query.unsqueeze(2) + self.key = self.key.unsqueeze(2) + self.value = self.value.unsqueeze(2) + if isinstance(self.attn_bias, torch.Tensor): + if self.attn_bias.ndim != 3: + raise ValueError( + f"Expected BMK format for attn_bias, but got {self.attn_bias.shape}" + ) + self.attn_bias = self.attn_bias.unsqueeze(1) + return output_shape + + def validate_inputs(self) -> None: + qkv = (self.query, self.key, self.value) + if self.query.ndim not in (3, 4, 5) or any( + x.ndim != self.query.ndim for x in qkv + ): + raise ValueError( + f"Query/Key/Value should all have BMGHK, BMHK, or BMK shape.\n" + f" query.shape: {self.query.shape}\n" + f" key.shape : {self.key.shape}\n" + f" value.shape: {self.value.shape}" + ) + if any(x.device != self.query.device for x in qkv): + raise ValueError("Query/Key/Value should all be on the same device") + quantized_dtypes = self.key.dtype == self.value.dtype == torch.int32 + non_quantized_dtypes = all(x.dtype == self.query.dtype for x in qkv) + if not (quantized_dtypes or non_quantized_dtypes): + raise ValueError( + "Query/Key/Value should either all have the same dtype, or " + "(in the quantized case) Key/Value should have dtype torch.int32\n" + f" query.dtype: {self.query.dtype}\n" + f" key.dtype : {self.key.dtype}\n" + f" value.dtype: {self.value.dtype}" + ) + # Biases with tensors attached are meant to be in BMHK format + # This would require to permute biases/gradients which can be expensive, + # so let's just forbid it - BMK is a legacy format anyway + if self.query.ndim == 3 and not _is_bias_type_supported_in_BMK( + type(self.attn_bias) + ): + raise ValueError( + f"Please provide inputs in BMHK format rather " + f"than BMK when using bias type `{type(self.attn_bias).__name__}`" + ) + attn_bias_t: Optional[torch.Tensor] = None + if isinstance(self.attn_bias, torch.Tensor): + attn_bias_t = self.attn_bias + if isinstance(self.attn_bias, LowerTriangularMaskWithTensorBias): + attn_bias_t = self.attn_bias._bias + if self.query.ndim == 4 and attn_bias_t is not None: + expected_shape = ( + self.query.shape[0], + self.query.shape[2], + self.query.shape[1], + self.key.shape[1], + ) + if attn_bias_t.shape != expected_shape: + raise ValueError( + f"Invalid shape for attention bias: {attn_bias_t.shape} (expected {expected_shape})\n" + f" query.shape: {self.query.shape}\n" + f" key.shape : {self.key.shape}\n" + f" value.shape: {self.value.shape}" + ) + if isinstance(self.attn_bias, BlockDiagonalMask): + if any(x.shape[0] != 1 for x in qkv): + raise ValueError( + f"Expected batch_size=1 when using block-diagonal bias\n" + f" query.shape: {self.query.shape}\n" + f" key.shape : {self.key.shape}\n" + f" value.shape: {self.value.shape}" + ) + if self.p < 0.0 or self.p > 1.0: + raise ValueError(f"Invalid dropout probability: p={self.p}") + # Check that shapes match between inputs + B, Mq = self.query.shape[:2] + K = self.query.shape[-1] + B, Mkv = self.key.shape[:2] + Kv = self.value.shape[-1] + + valid_shapes = True + if self.query.ndim == 3: # BMK + valid_shapes = ( + self.query.shape == (B, Mq, K) + and self.key.shape == (B, Mkv, K) + and self.value.shape == (B, Mkv, Kv) + ) + H = self.query.shape[-2] + if self.query.ndim == 4: # BMHK + quantized_kv_cache = self.value.dtype == torch.int32 + key_embed_dim = Kv if quantized_kv_cache else K + valid_shapes = ( + self.query.shape == (B, Mq, 
H, K) + and self.key.shape == (B, Mkv, H, key_embed_dim) + and self.value.shape == (B, Mkv, H, Kv) + ) + G = self.query.shape[2] + if self.query.ndim == 5: # BMNHK + valid_shapes = ( + self.query.shape == (B, Mq, G, H, K) + and self.key.shape == (B, Mkv, G, H, K) + and self.value.shape == (B, Mkv, G, H, Kv) + ) + if not valid_shapes: + raise ValueError( + f"Incompatible shapes for attention inputs:\n" + f" query.shape: {self.query.shape}\n" + f" key.shape : {self.key.shape}\n" + f" value.shape: {self.value.shape}\n" + "HINT: We don't support broadcasting, please use `expand` " + "yourself before calling `memory_efficient_attention` if you need to" + ) + + +@dataclass +class Context: + lse: torch.Tensor + out: torch.Tensor + op_bw: Optional[Type["AttentionBwOpBase"]] = None + rng_state: Optional[torch.Tensor] = None + + def get_padded_lse(self, pad_to: int, force_pad_inf: bool = False) -> torch.Tensor: + pad_amount = (pad_to - (self.lse.shape[2] % pad_to)) % pad_to + lse = self.lse + if pad_amount > 0: + if force_pad_inf: + lse = lse[:, :, : self.out.shape[1]] + pad_amount = (pad_to - (lse.shape[2] % pad_to)) % pad_to + lse = torch.nn.functional.pad(lse, [0, pad_amount], value=math.inf) + elif force_pad_inf and self.out.shape[1] != lse.shape[2]: + lse[:, :, self.out.shape[1] :].fill_(math.inf) + return lse + + +@dataclass +class Gradients: + dq: torch.Tensor + dk: torch.Tensor + dv: torch.Tensor + # bias gradient. None if there is no tensor bias or if it doesn't require grad + db: Optional[torch.Tensor] = None + + +class AttentionOpBase(BaseOperator): + """Base class for any attention operator in xFormers + + See: + + - :attr:`xformers.ops.fmha.cutlass.FwOp` + - :attr:`xformers.ops.fmha.cutlass.BwOp` + - :attr:`xformers.ops.fmha.flash.FwOp` + - :attr:`xformers.ops.fmha.flash.BwOp` + - :attr:`xformers.ops.fmha.triton.FwOp` + - :attr:`xformers.ops.fmha.triton.BwOp` + - :attr:`xformers.ops.fmha.small_k.FwOp` + - :attr:`xformers.ops.fmha.small_k.BwOp` + """ + + OPERATOR: Any + SUPPORTED_DEVICES: Set[str] + CUDA_MINIMUM_COMPUTE_CAPABILITY: Tuple[int, int] = (5, 0) + SUPPORTED_DTYPES: Set[torch.dtype] + SUPPORTED_MAX_K: float + SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {type(None)} + SUPPORTS_DROPOUT: bool + SUPPORTS_CUSTOM_SCALE: bool = False + SUPPORTS_DIFFERENT_VALUE_EMBED: bool = False + IS_DETERMINISTIC: bool = True + SUPPORTS_BMGHK: bool = False + NAME: str + OPERATOR_CATEGORY = "memory_efficient_attention" + + _TEST_BATCH_SIZES: List[int] = [1, 300] + _TEST_K: List[int] = [32, 128] + + @classmethod + def supports(cls, d: Inputs) -> bool: + return not cls.not_supported_reasons(d) + + @classmethod + def shape_not_supported_reasons( + cls, Mq: int, Mkv: int, K: int, Kv: int + ) -> List[str]: + reasons = [] + if not cls.SUPPORTS_DIFFERENT_VALUE_EMBED and K != Kv: + reasons.append("query.shape[-1] != value.shape[-1]") + if max(K, Kv) > cls.SUPPORTED_MAX_K: + reasons.append( + f"max(query.shape[-1] != value.shape[-1]) > {cls.SUPPORTED_MAX_K}" + ) + return reasons + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + """ + Returns a list of reasons why this is not supported. 
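
Most of `Inputs` is shape plumbing. The following sketch (CPU-only, no compiled kernels involved) shows how the legacy BMK layout is validated and normalized to BMHK before dispatch; the tensor sizes are arbitrary:

```python
import torch
from vllm.xformers.ops.fmha.common import Inputs

q = torch.randn(2, 5, 64)            # [batch, seqlen, K] -- legacy BMK layout
k = torch.randn(2, 7, 64)
v = torch.randn(2, 7, 64)

inp = Inputs(query=q, key=k, value=v)
inp.validate_inputs()                # raises ValueError on inconsistent shapes/dtypes
out_shape = inp.normalize_bmhk()     # inserts the missing head dimension in place
print(inp.query.shape)               # torch.Size([2, 5, 1, 64])  (now BMHK)
print(out_shape)                     # torch.Size([2, 5, 64]) -- shape the output is reshaped to
```
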
+ The kernel can run these inputs only if the returned list is empty + """ + reasons = cls.shape_not_supported_reasons( + Mq=d.query.shape[1], + Mkv=d.key.shape[1], + K=d.query.shape[-1], + Kv=d.query.shape[-1], + ) + device_type = d.query.device.type + dtype = d.query.dtype + if device_type not in cls.SUPPORTED_DEVICES: + reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})") + if dtype not in cls.SUPPORTED_DTYPES: + reasons.append(f"dtype={dtype} (supported: {cls.SUPPORTED_DTYPES})") + if type(d.attn_bias) not in cls.SUPPORTED_ATTN_BIAS_TYPES: + reasons.append(f"attn_bias type is {type(d.attn_bias)}") + if (d.p != 0.0) and not cls.SUPPORTS_DROPOUT: + reasons.append("dropout > 0.0") + if d.scale is not None and not cls.SUPPORTS_CUSTOM_SCALE: + reasons.append("has custom scale") + # bfloat16 is only supported on A100+ + # ... although the kernels can still run and give the + # correct result + if dtype is torch.bfloat16 and ( + not device_type.startswith("cuda") + or torch.cuda.get_device_capability(d.query.device)[0] < 8 + ): + reasons.append("bf16 is only supported on A100+ GPUs") + if not cls.is_available(): + reasons.append( + "operator wasn't built - see `python -m xformers.info` for more info" + ) + if not cls.IS_DETERMINISTIC and torch.are_deterministic_algorithms_enabled(): + reasons.append( + "operator is non-deterministic, but `torch.use_deterministic_algorithms` is set" + ) + if not cls.SUPPORTS_BMGHK and d.query.ndim == 5: + reasons.append("operator does not support BMGHK format") + return reasons + + +class AttentionFwOpBase(AttentionOpBase): + ERROR_ATOL: Mapping[torch.dtype, float] = { + torch.float: 3e-4, + torch.half: 4e-3, + torch.bfloat16: 2e-2, + } + ERROR_RTOL: Mapping[torch.dtype, float] = { + torch.float: 2e-5, + torch.half: 4e-4, + torch.bfloat16: 5e-3, + } + + @classmethod + def apply( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + raise NotImplementedError() + + @classmethod + def attn_operator_flop( + cls, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + causal: bool = False, + seqstart_k: Optional[torch.Tensor] = None, + seqstart_q: Optional[torch.Tensor] = None, + ) -> int: + """ + Computes total flops for the attention + Assumes inputs in format BMHK + """ + assert query.ndim == 4 + + if seqstart_q is not None: + seqstart_q_py = seqstart_q.tolist() + else: + seqstart_q_py = [0, query.shape[1]] + if seqstart_k is not None: + seqstart_k_py = seqstart_k.tolist() + else: + seqstart_k_py = [0, key.shape[1]] + + total_flop = 0 + for q_start, q_end, k_start, k_end in zip( + seqstart_q_py, seqstart_q_py[1:], seqstart_k_py, seqstart_k_py[1:] + ): + num_q = q_end - q_start + num_kv = k_end - k_start + # (M,K) @ (K,N) GEMM needs M*N*K*2 flop + # Q @ K.transpose + total_flop += num_q * num_kv * query.shape[-1] * 2 + # (ignore softmax) + # attn @ V + total_flop += num_q * key.shape[-1] * num_kv * 2 + # Multiply by num_heads and batches + total_flop = total_flop * value.shape[2] * value.shape[0] + if causal: + total_flop //= 2 + return total_flop + + +class AttentionBwOpBase(AttentionOpBase): + ERROR_ATOL: Mapping[torch.dtype, float] = { + torch.float: 5e-4, + torch.half: 9e-2, + torch.bfloat16: 0.7, + } + ERROR_RTOL: Mapping[torch.dtype, float] = { + torch.float: 1e-4, + torch.half: 2e-2, + torch.bfloat16: 0.1, + } + SUPPORTS_ATTN_BIAS_GRAD = False + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(AttentionBwOpBase, cls).not_supported_reasons(d) + 
if ( + isinstance(d.attn_bias, torch.Tensor) + and d.attn_bias.requires_grad + and not cls.SUPPORTS_ATTN_BIAS_GRAD + ): + reasons.append( + "Computing the bias gradient is not supported (attn_bias.requires_grad = True)" + ) + + return reasons + + @classmethod + def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients: + raise NotImplementedError() + + @classmethod + def attn_operator_flop( + cls, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + causal: bool = False, + seqstart_k: Optional[torch.Tensor] = None, + seqstart_q: Optional[torch.Tensor] = None, + ) -> int: + """ + Computes total flops for the attention + Assumes inputs in format BMHK + """ + assert query.ndim == 4 + + if seqstart_q is not None: + seqstart_q_py = seqstart_q.tolist() + else: + seqstart_q_py = [0, query.shape[1]] + if seqstart_k is not None: + seqstart_k_py = seqstart_k.tolist() + else: + seqstart_k_py = [0, key.shape[1]] + + total_flop = 0 + for q_start, q_end, k_start, k_end in zip( + seqstart_q_py, seqstart_q_py[1:], seqstart_k_py, seqstart_k_py[1:] + ): + num_q = q_end - q_start + num_kv = k_end - k_start + Kqk = query.shape[-1] + Kv = value.shape[-1] + # (M,K) @ (K,N) GEMM needs M*N*K*2 flop + # att = Q @ K.transpose + total_flop += num_q * num_kv * Kqk * 2 + # att @ dO + total_flop += num_kv * num_q * Kv * 2 + # dov = dO @ V + total_flop += num_q * Kv * num_kv * 2 + # dov @ K + total_flop += num_q * Kqk * num_kv * 2 + # dov @ Q + total_flop += num_q * Kqk * num_kv * 2 + # Multiply by num_heads and batches + total_flop = total_flop * value.shape[2] * value.shape[0] + if causal: + total_flop //= 2 + return total_flop + + +AttentionOp = Tuple[ + Optional[Type[AttentionFwOpBase]], Optional[Type[AttentionBwOpBase]] +] + + +@dataclass +class AttentionOpDispatch: + """Dispatcher to automatically select + the best operator to run memory-efficient attention. + + :Deprecated: + + This class is deprecated and will be removed in a later version + """ + + op: AttentionOp + + @classmethod + def from_arguments( + cls, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None, + p: float = 0.0, + scale: Optional[float] = None, + ) -> "AttentionOpDispatch": + """Here for backward compatibility""" + from .dispatch import _dispatch_fw + + inp = Inputs( + query=query, + key=key, + value=value, + attn_bias=attn_bias, + p=p, + scale=scale, + ) + return AttentionOpDispatch(op=(_dispatch_fw(inp, True), )) + + +def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor: + if tensor.ndim == 4: + return tensor + return tensor.reshape([-1, num_heads, tensor.shape[1], tensor.shape[2]]).permute( + (0, 2, 1, 3) + ) + + +def check_lastdim_alignment_stride1( + reasons: List[str], name: str, x: torch.Tensor, alignment: int +) -> None: + if x.shape[-1] % alignment != 0: + reasons.append(f"{name}.shape[-1] % {alignment} != 0") + elif x.stride(-2) % alignment != 0: + reasons.append( + f"{name}.stride(-2) % {alignment} != 0 ({name}.stride() = {x.stride()})" + ) + # We can have stride=0 sometimes if dimension=1 + if x.stride(-1) > 1: + reasons.append( + f"{name}.stride(-1) > 1 ({name}.stride() = {x.stride()}) - you should call `.contiguous()` on the input" + ) diff --git a/vllm/xformers/ops/fmha/dispatch.py b/vllm/xformers/ops/fmha/dispatch.py new file mode 100644 index 000000000000..d9120e2dd5ff --- /dev/null +++ b/vllm/xformers/ops/fmha/dispatch.py @@ -0,0 +1,81 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
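
`check_lastdim_alignment_stride1` at the end of common.py is how individual ops veto mis-aligned inputs. A quick CPU-only sketch of the reasons it appends (tensor shapes are illustrative):

```python
import torch
from vllm.xformers.ops.fmha.common import check_lastdim_alignment_stride1

reasons = []
check_lastdim_alignment_stride1(reasons, "query", torch.randn(1, 4, 2, 10), alignment=8)
print(reasons)   # ['query.shape[-1] % 8 != 0']

reasons = []
check_lastdim_alignment_stride1(reasons, "query", torch.randn(1, 4, 2, 64), alignment=8)
print(reasons)   # []  -- contiguous and 8-aligned, so no objections
```
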
+# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import textwrap +from collections import deque +from typing import List, Sequence, Type, TypeVar + +from . import flash +from .common import AttentionBwOpBase, AttentionFwOpBase, Inputs + + +T = TypeVar("T", Type[AttentionFwOpBase], Type[AttentionBwOpBase]) + + +def _format_inputs_description(inp: Inputs) -> str: + return f"""query : shape={tuple(inp.query.shape)} ({inp.query.dtype}) +key : shape={tuple(inp.key.shape)} ({inp.key.dtype}) +value : shape={tuple(inp.value.shape)} ({inp.value.dtype}) +attn_bias : {type(inp.attn_bias)} +p : {inp.p}""" + + +def _ensure_op_supports_or_raise(exc_type, name: str, op, inp: Inputs) -> None: + reasons = op.not_supported_reasons(inp) + if not reasons: + return + raise exc_type( + f"""Operator `{name}` does not support inputs: +{textwrap.indent(_format_inputs_description(inp), ' ')} +{_format_not_supported_reasons(op, reasons)}""" + ) + + +def _format_not_supported_reasons(op, reasons: List[str]) -> str: + return f"`{op.NAME}` is not supported because:\n " + "\n ".join(reasons) + + +def _run_priority_list(name: str, priority_list: Sequence[T], inp: Inputs) -> T: + not_supported_reasons: List[List[str]] = [] + for op in priority_list: + not_supported = op.not_supported_reasons(inp) + if not not_supported: + return op + not_supported_reasons.append(not_supported) + + # Let's write a nice message explaining what we tried and why it's not supported + msg = f"""No operator found for `{name}` with inputs: +{textwrap.indent(_format_inputs_description(inp), ' ')}""" + for op, not_supported in zip(priority_list, not_supported_reasons): + msg += "\n" + _format_not_supported_reasons(op, not_supported) + raise NotImplementedError(msg) + + +def _dispatch_fw_priority_list( + inp: Inputs, needs_gradient: bool +) -> Sequence[Type[AttentionFwOpBase]]: + priority_list_ops = deque( + [ + flash.FwOp, + ] + ) + return priority_list_ops + + +def _dispatch_fw(inp: Inputs, needs_gradient: bool) -> Type[AttentionFwOpBase]: + """Computes the best operator for forward + + Raises: + NotImplementedError: if not operator was found + + Returns: + AttentionOp: The best operator for the configuration + """ + return _run_priority_list( + "memory_efficient_attention_forward", + _dispatch_fw_priority_list(inp, needs_gradient), + inp, + ) diff --git a/vllm/xformers/ops/fmha/flash.py b/vllm/xformers/ops/fmha/flash.py new file mode 100644 index 000000000000..3c2520c6dceb --- /dev/null +++ b/vllm/xformers/ops/fmha/flash.py @@ -0,0 +1,426 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +from dataclasses import replace +from itertools import zip_longest +from typing import Any, List, Optional, Set, Tuple, Union + +import torch + +from .attn_bias import ( + AttentionBias, + BlockDiagonalCausalFromBottomRightMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalMask, + BlockDiagonalMask, + LowerTriangularMask, + LowerTriangularMaskWithTensorBias, +) +from .common import ( + AttentionFwOpBase, + Context, + Inputs, + check_lastdim_alignment_stride1, +) + +FLASH_VERSION = "0.0.0" +try: + import flash_attn + from flash_attn.flash_attn_interface import (_flash_attn_forward, + _flash_attn_backward, + _flash_attn_varlen_forward, + _flash_attn_varlen_backward + ) + + FLASH_VERSION = flash_attn.__version__ + + # create library so that flash-attn goes through the PyTorch Dispatcher + _flash_lib = torch.library.Library("xformers_flash", "DEF") + + _flash_lib.define( + "flash_fwd(Tensor query, Tensor key, Tensor value, " + "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, " + "int max_seqlen_q, int max_seqlen_k, " + "float p, float softmax_scale, " + "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)" + ) + + _flash_lib.define( + "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, " + "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, " + "Tensor cu_seqlens_q, Tensor cu_seqlens_k, " + "int max_seqlen_q, int max_seqlen_k, " + "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)" + ) + + def _flash_fwd( + query, + key, + value, + cu_seq_lens_q, + cu_seq_lens_k, + max_seq_len_q, + max_seq_len_k, + p, + softmax_scale, + is_causal, + window_size, + return_softmax, + ): + if cu_seq_lens_q is None: + assert cu_seq_lens_k is None + ( + out, + q_padded, + k_padded, + v_padded, + out_padded, + softmax_lse, + S_dmask, + rng_state, + ) = _flash_attn_forward( + query, + key, + value, + p, + softmax_scale, + is_causal, + return_softmax + ) + else: + out = query.new_empty(query.shape[0], query.shape[1], value.shape[2]) + ( + out, + q_padded, + k_padded, + v_padded, + out_padded, + softmax_lse, + S_dmask, + rng_state, + ) = _flash_attn_varlen_forward( + query, + key, + value, + cu_seq_lens_q, + cu_seq_lens_k, + max_seq_len_q, + max_seq_len_k, + p, + softmax_scale, + is_causal, + return_softmax, + ) + return out, softmax_lse, rng_state + + def _flash_bwd( + grad, + query, + key, + value, + out, + lse, + dq, + dk, + dv, + cu_seq_lens_q, + cu_seq_lens_k, + max_seq_len_q, + max_seq_len_k, + p, + softmax_scale, + is_causal, + window_size, + rng_state, + ): + if cu_seq_lens_k is None: + assert cu_seq_lens_q is None + ( + dq, dk, dv, softmax_d + ) = _flash_attn_backward( + grad, + query, + key, + value, + out, + lse, + dq, + dk, + dv, + p, + softmax_scale, + is_causal, + rng_state, + ) + else: + ( + dq, dk, dv, softmax_d + ) = _flash_attn_varlen_backward( + grad, + query, + key, + value, + out, + lse, + dq, + dk, + dv, + cu_seq_lens_q, + cu_seq_lens_k, + max_seq_len_q, + max_seq_len_k, + p, + softmax_scale, + is_causal, + rng_state, + ) + return dq, dk, dv + + _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA") + _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA") +except ImportError: + pass + + +def _convert_input_format( + inp: Inputs, +) -> Tuple[Inputs, Optional[torch.Tensor], int, Optional[torch.Tensor], int]: + assert inp.query.ndim in [4, 5] + query, key, value = inp.query, inp.key, inp.value + batch = query.shape[0] + 
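    # The remainder of this helper normalizes the inputs for the flash kernels:
+    # for BlockDiagonalMask biases, the cumulative seqstart tensors are moved to
+    # the query's device and [batch, seqlen] is flattened into a single packed
+    # "varlen" dimension; 5-D (grouped-query) inputs get their group/head
+    # dimensions folded together. +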
seqlen_q = query.shape[1] + seqlen_kv = key.shape[1] + head_dim_q = query.shape[-1] + head_dim_v = value.shape[-1] + + attn_bias = inp.attn_bias + if isinstance(attn_bias, BlockDiagonalMask): + attn_bias.k_seqinfo.seqstart = attn_bias.k_seqinfo.seqstart.to( + inp.query.device, non_blocking=True + ) + attn_bias.q_seqinfo.seqstart = attn_bias.q_seqinfo.seqstart.to( + inp.query.device, non_blocking=True + ) + + cu_seqlen_k = attn_bias.k_seqinfo.seqstart + cu_seqlen_q = attn_bias.q_seqinfo.seqstart + max_seqlen_q = attn_bias.q_seqinfo.max_seqlen + max_seqlen_k = attn_bias.k_seqinfo.max_seqlen + else: + cu_seqlen_k = None + cu_seqlen_q = None + max_seqlen_q = inp.query.shape[1] + max_seqlen_k = inp.key.shape[1] + + if query.ndim == 5: # QGA + # Fold the group/head_in_group dimensions together + def fold(x): + # Either the head is replicated + if x.stride(3) == 0: + return x[:, :, :, 0] + # Or we reshape + return x.reshape( + [ + x.shape[0], + x.shape[1], + -1, + x.shape[4], + ] + ) + + query = fold(query) + key = fold(key) + value = fold(value) + # Optimize for MHA + if key.ndim == 4 and key.stride(2) == 0 and value.stride(2) == 0: + key = key[:, :, :1] + value = value[:, :, :1] + # Initially we have `query.shape = [batch, seqlen, head_dim_q]` + # We want format `[batch * seqlen, num_heads, head_dim_q]` + if cu_seqlen_k is not None: + query = query.reshape([batch * seqlen_q, -1, head_dim_q]) + key = key.reshape([batch * seqlen_kv, -1, head_dim_q]) + value = value.reshape([batch * seqlen_kv, -1, head_dim_v]) + if query.is_contiguous() or key.is_contiguous() or value.is_contiguous(): + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + new_inp = replace( + inp, + query=query, + key=key, + value=value, + ) + return new_inp, cu_seqlen_q, max_seqlen_q, cu_seqlen_k, max_seqlen_k + + +def _is_causal(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> bool: + return isinstance( + attn_bias, + ( + LowerTriangularMask, + BlockDiagonalCausalMask, + BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalFromBottomRightMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + ), + ) + + +def _window_size(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int: + if isinstance( + attn_bias, + (BlockDiagonalCausalLocalAttentionMask,), + ): + return attn_bias._window_size or 0 + if isinstance(attn_bias, BlockDiagonalCausalLocalAttentionFromBottomRightMask): + return attn_bias._window_size + return 0 + + +def _check_needs_no_topleft(d: Inputs, reasons: List[str]) -> None: + # Flash does not support TopLeft, so only allow causal masks with TopLeft + # if each batch element has equal number of queries and keys. + if isinstance(d.attn_bias, BlockDiagonalCausalMask): + # Flash does not support TopLeft, so only allow BlockDiagonalCausalMask + # if each batch element has equal number of queries and keys. 
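+        # Note: zip_longest pads the shorter seqstart list with None, so a
+        # length mismatch between the query and key offsets also ends up in
+        # the "equal numbers of keys and queries" reason appended below.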
+ for k_start, q_start in zip_longest( + d.attn_bias.k_seqinfo.seqstart_py, d.attn_bias.q_seqinfo.seqstart_py + ): + if k_start != q_start: + reasons.append( + "Only support BlockDiagonalCausalMask if equal" + " numbers of keys and queries" + ) + break + elif isinstance(d.attn_bias, LowerTriangularMask): + if d.query.shape[1] != d.key.shape[1]: + reasons.append( + "Only support LowerTriangularMask if equal number of" "keys and queries" + ) + + +def _check_strides_for_bmghk(x: torch.Tensor, name: str, reasons: List[str]) -> None: + """ + We want to be able to collapse the G/H dimensions together + """ + if x.ndim == 5: + stride_g, stride_h = x.stride(2), x.stride(3) + if x.shape[2] == 1: + return + if x.shape[3] == 1 or stride_h == 0: + return + if stride_g != stride_h * x.shape[-2]: + reasons.append( + f"GQA is only supported when the G/H dimensions are contiguous\n" + f" {name}.stride: {x.stride()}\n" + f" {name}.shape : {list(x.shape)}" + ) + + +class FwOp(AttentionFwOpBase): + """Operator that computes memory-efficient attention using \ + `Flash-Attention `_ \ + implementation. + """ + + OPERATOR = _flash_fwd + SUPPORTED_DEVICES: Set[str] = {"cuda"} + #CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) + CUDA_MINIMUM_COMPUTE_CAPABILITY = (7, 5) + SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16} + SUPPORTED_MAX_K = 256 + SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = { + type(None), + LowerTriangularMask, + BlockDiagonalMask, + BlockDiagonalCausalMask, + BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + BlockDiagonalCausalFromBottomRightMask, + LowerTriangularMaskWithTensorBias, + } + SUPPORTS_DROPOUT = True + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_DIFFERENT_VALUE_EMBED = False + SUPPORTS_BMGHK = True + NAME = f"flshattF@{FLASH_VERSION}" + VERSION = FLASH_VERSION + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(FwOp, cls).not_supported_reasons(d) + check_lastdim_alignment_stride1(reasons, "query", d.query, 8) + _check_needs_no_topleft(d, reasons) + _check_strides_for_bmghk(d.query, "query", reasons) + _check_strides_for_bmghk(d.key, "key", reasons) + _check_strides_for_bmghk(d.value, "value", reasons) + return reasons + + @classmethod + def apply( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + return_softmax = False + out_shape = [ + *inp.query.shape[:-1], + inp.value.shape[-1], + ] + # no cumulative seqlen + ( + inp, + cu_seqlens_q, + max_seqlen_q, + cu_seqlens_k, + max_seqlen_k, + ) = _convert_input_format(inp) + + softmax_scale = inp.query.shape[-1] ** (-0.5) if inp.scale is None else inp.scale + + ret = cls.OPERATOR( + inp.query, + inp.key, + inp.value, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + inp.p, + softmax_scale, + _is_causal(inp.attn_bias), + _window_size(inp.attn_bias), + return_softmax, + ) + + out = ret[0].reshape(out_shape) + ctx = Context(out=out, lse=ret[1]) + return (out, ctx) + + @classmethod + # type: ignore + def operator_flop( + cls, + query, + key, + value, + cu_seq_lens_q, + cu_seq_lens_k, + max_seq_len_q, + max_seq_len_k, + p, + softmax_scale, + causal, + return_softmax, + ) -> int: + return cls.attn_operator_flop( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + causal=causal, + seqstart_k=cu_seq_lens_k, + seqstart_q=cu_seq_lens_q, + )
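+
+    # Illustrative usage sketch (the q, k, v tensors below are hypothetical):
+    # with the flash_attn package imported above installed, the forward op is
+    # reached by building an `Inputs` bundle and letting the dispatcher pick
+    # the operator, e.g.
+    #
+    #   from .common import Inputs
+    #   from .dispatch import _dispatch_fw
+    #
+    #   # q, k, v: [batch, seqlen, num_heads, head_dim], fp16/bf16, on the GPU
+    #   inp = Inputs(query=q, key=k, value=v, attn_bias=None, p=0.0, scale=None)
+    #   op = _dispatch_fw(inp, needs_gradient=False)    # returns FwOp when supported
+    #   out, ctx = op.apply(inp, needs_gradient=False)  # out: q.shape[:-1] + (v.shape[-1],)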