diff --git a/.gitignore b/.gitignore
index b531b7918c30..5f41478c08e2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -177,3 +177,12 @@ _build/
# vim swap files
*.swo
*.swp
+
+# Dataset files
+*.json
+*.json.[0-9]
+
+# Hipified files
+*.hip
+*_hip.h
+*_hip.cuh
diff --git a/Dockerfile b/Dockerfile
index 72f0249490d9..85e92a229bb3 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,72 +1,64 @@
-FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 AS dev
-
-RUN apt-get update -y \
- && apt-get install -y python3-pip
-
-WORKDIR /workspace
-
-# install build and runtime dependencies
-COPY requirements.txt requirements.txt
-RUN --mount=type=cache,target=/root/.cache/pip \
- pip install -r requirements.txt
-
-# install development dependencies
-COPY requirements-dev.txt requirements-dev.txt
-RUN --mount=type=cache,target=/root/.cache/pip \
- pip install -r requirements-dev.txt
-
-# image to build pytorch extensions
-FROM dev AS build
-
-# copy input files
-COPY csrc csrc
-COPY setup.py setup.py
-COPY requirements.txt requirements.txt
-COPY pyproject.toml pyproject.toml
-COPY vllm/__init__.py vllm/__init__.py
-
-# max jobs used by Ninja to build extensions
-ENV MAX_JOBS=$max_jobs
-RUN python3 setup.py build_ext --inplace
-
-# image to run unit testing suite
-FROM dev AS test
-
-# copy pytorch extensions separately to avoid having to rebuild
-# when python code changes
-COPY --from=build /workspace/vllm/*.so /workspace/vllm/
-COPY tests tests
-COPY vllm vllm
-
-ENTRYPOINT ["python3", "-m", "pytest", "tests"]
-
-# use CUDA base as CUDA runtime dependencies are already installed via pip
-FROM nvidia/cuda:11.8.0-base-ubuntu22.04 AS vllm-base
-
-# libnccl required for ray
-RUN apt-get update -y \
- && apt-get install -y python3-pip
-
-WORKDIR /workspace
-COPY requirements.txt requirements.txt
-RUN --mount=type=cache,target=/root/.cache/pip \
- pip install -r requirements.txt
-
-FROM vllm-base AS vllm
-COPY --from=build /workspace/vllm/*.so /workspace/vllm/
-COPY vllm vllm
-
-EXPOSE 8000
-ENTRYPOINT ["python3", "-m", "vllm.entrypoints.api_server"]
-
-# openai api server alternative
-FROM vllm-base AS vllm-openai
-# install additional dependencies for openai api server
-RUN --mount=type=cache,target=/root/.cache/pip \
- pip install accelerate fschat
-
-COPY --from=build /workspace/vllm/*.so /workspace/vllm/
-COPY vllm vllm
-
-ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
-
+FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1
+
+# Install Python and pip
+RUN apt-get update && apt-get install python3 python3-pip -y
+
+# Install some basic utilities
+RUN apt-get update && apt-get install -y \
+ curl \
+ ca-certificates \
+ sudo \
+ git \
+ bzip2 \
+ libx11-6 \
+ build-essential \
+ wget \
+ unzip \
+ nvidia-cuda-toolkit \
+ tmux \
+ && rm -rf /var/lib/apt/lists/*
+
+### Mount Point ###
+# When launching the container, mount the code directory to /app
+ARG APP_MOUNT=/app
+VOLUME [ ${APP_MOUNT} ]
+WORKDIR ${APP_MOUNT}
+
+RUN python3 -m pip install --upgrade pip
+RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers
+
+ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
+ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
+ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
+ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
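+# GPU architectures to build kernels for; trimming this list to your hardware
+# (e.g. gfx90a for MI200-series accelerators) shortens build times.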
+ENV PYTORCH_ROCM_ARCH=gfx900;gfx906;gfx908;gfx90a;gfx1030;gfx1101
+
+# Install ROCm flash-attention
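+# The sed call pins --offload-arch to the architecture reported by amdgpu-offload-arch
+# instead of "native", and flash-attention's hipify_patch.patch is applied to PyTorch's
+# hipify helper before building.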
+RUN mkdir libs \
+ && cd libs \
+ && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
+ && cd flash-attention \
+ && git submodule update --init \
+ && sed -i -e "s/--offload-arch=native/--offload-arch=$(/opt/rocm/llvm/bin/amdgpu-offload-arch)/g" setup.py \
+ && patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \
+ && python3 setup.py install \
+ && cd ..
+
+COPY ./ /app/vllm-rocm/
+
+RUN cd /app \
+ && cd vllm-rocm \
+ && git checkout v0.2.1.post1-rocm \
+ && python3 setup.py install \
+ && cd ..
+
+RUN cd /app \
+ && mkdir dataset \
+ && cd ..
+
+COPY ./benchmark_throughput.sh /app/benchmark_throughput.sh
+
+RUN python3 -m pip install --upgrade pip
+RUN python3 -m pip install --no-cache-dir ray[all]
+
+CMD ["/bin/bash"]
diff --git a/README.md b/README.md
index 5ddb2a48ca83..6232d3157cb6 100644
--- a/README.md
+++ b/README.md
@@ -1,92 +1,147 @@
+
-
-
+
+
+
+
+
-
-Easy, fast, and cheap LLM serving for everyone
-
+
+vLLM ROCm port
+
+
+This version of vLLM 0.2.x supports model inference and serving on AMD GPUs with ROCm. The port was adapted from [vLLM](https://github.com/vllm-project/vllm), a ROCm [community port](https://github.com/pcmoritz/vllm-public/tree/port-to-rocm), and [xformers](https://github.com/facebookresearch/xformers), replacing the attention forward method used in xformers with the ROCm implementation of [flash attention](https://github.com/ROCmSoftwarePlatform/flash-attention). AWQ quantization is not supported yet, but SqueezeLLM has been incorporated.
+
+This port extends our previous [vLLM v0.1.4 ROCm port](https://github.com/EmbeddedLLM/vllm-rocm/tree/v0.1.4-rocm). Compared with that port, vLLM v0.2.x achieves a speedup of more than 2x for the LLaMA-70B model and more than 3x for LLaMA-7B/13B on MI210, thanks to the introduction of [efficient de-tokenization, vectorized sampling](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit#slide=id.g24c1f26d37c_10_117) and [paged attention v2](https://github.com/vllm-project/vllm/pull/1348).
-| Documentation | Blog | Paper | Discord |
+
+
+
+
+
+
+
+
---
-*Latest News* 🔥
-- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
-- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
-- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
-- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
-- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
-- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
-- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).
+*Latest News*
+- [2023/11] We have updated our ROCm port for vLLM v0.2.x.
+- [2023/10] LLaMA-2 models are now supported. 7B/13B/70B models can be run and served on AMD GPUs!
---
-vLLM is a fast and easy-to-use library for LLM inference and serving.
+## Getting Started
-vLLM is fast with:
+The following sections describe how to install this ROCm port. If you intend to use our provided container, skip ahead to the [Using Docker](#using-docker) section.
-- State-of-the-art serving throughput
-- Efficient management of attention key and value memory with **PagedAttention**
-- Continuous batching of incoming requests
-- Optimized CUDA kernels
+## Dependencies
-vLLM is flexible and easy to use with:
+To build this project, the following prerequisites must be met:
-- Seamless integration with popular Hugging Face models
-- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
-- Tensor parallelism support for distributed inference
-- Streaming outputs
-- OpenAI-compatible API server
+- [PyTorch](https://pytorch.org/) with ROCm (5.7.0 or later) support
-vLLM seamlessly supports many Hugging Face models, including the following architectures:
+- Install the ROCm port of [flash-attention](https://github.com/ROCmSoftwarePlatform/flash-attention) following the instructions in [AMD ROCm Support](https://github.com/ROCmSoftwarePlatform/flash-attention#amd-gpurocm-support); a build sketch is shown below
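+
+The following sketch mirrors the flash-attention build steps from the provided Dockerfile; the hipify patch path depends on where your PyTorch is installed:
+
+```bash
+git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git
+cd flash-attention
+git submodule update --init
+# build for your GPU architecture instead of "native"
+sed -i -e "s/--offload-arch=native/--offload-arch=$(/opt/rocm/llvm/bin/amdgpu-offload-arch)/g" setup.py
+# apply the repository's hipify patch to PyTorch's hipify helper (example path)
+patch /path/to/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch
+python3 setup.py install
+```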
-- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
-- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.)
-- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
-- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
-- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
-- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
-- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
-- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
-- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
-- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
-- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
-- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
-- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
-- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
+## Installation
-Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
+Build and install the package from source:
```bash
-pip install vllm
+git clone https://github.com/EmbeddedLLM/vllm-rocm.git
+cd vllm-rocm/
+python3 setup.py install
```
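+
+A quick way to confirm the build (assuming a working ROCm PyTorch environment) is to import the package:
+
+```bash
+python3 -c "import vllm; print(vllm.__version__)"
+```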
-## Getting Started
+## Using Docker
-Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started.
-- [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html)
-- [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html)
-- [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html)
+A base Docker image can be built from this repository:
-## Contributing
+```bash
+docker build -t vllm-rocm .
+```
+
+Run a Docker container with:
-We welcome and value any contributions and collaborations.
-Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.
+```bash
+docker run -it \
+ --network=host \
+ --group-add=video \
+ --ipc=host \
+ --cap-add=SYS_PTRACE \
+ --security-opt seccomp=unconfined \
+ --shm-size 8G \
+ --device /dev/kfd \
+ --device /dev/dri \
+ vllm-rocm \
+ bash
+```
-## Citation
+Alternatively, you can pull our pre-built Docker image:
-If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
-```bibtex
-@inproceedings{kwon2023efficient,
- title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
- author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
- booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
- year={2023}
-}
+```bash
+docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.1.post1
+
+docker run -it \
+ --network=host \
+ --group-add=video \
+ --ipc=host \
+ --cap-add=SYS_PTRACE \
+ --security-opt seccomp=unconfined \
+ --shm-size 8G \
+ --device /dev/kfd \
+ --device /dev/dri \
+ embeddedllminfo/vllm-rocm \
+ bash
```
+
+## Serving
+
+The project supports native vLLM serving, for example:
+
+```bash
+python -m vllm.entrypoints.api_server \
+ --model lmsys/vicuna-7b-v1.5 \
+ --tensor-parallel-size 2
+```
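+
+Once the server is running, requests can be sent to its `/generate` endpoint. A minimal sketch (the payload fields mirror `examples/api_client.py`; the host, port, and prompt here are placeholders):
+
+```bash
+curl http://localhost:8000/generate \
+    -H "Content-Type: application/json" \
+    -d '{"prompt": "San Francisco is a", "n": 1, "temperature": 0.0, "max_tokens": 64}'
+```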
+
+## Benchmarking
+
+The benchmark results were obtained by running the vLLM benchmark scripts under the *benchmarks* directory.
+
+If vLLM was installed using the provided [Docker environment](#using-docker), you can benchmark inference throughput by following the steps below:
+- Download the model you would like to evaluate to a directory of your choice (for example, a vicuna-7b model downloaded to /path/to/your/model/vicuna-7b-v1.5)
+- Run the container and mount the model directory to /app/model
+
+```bash
+docker run -it \
+ --network=host \
+ --group-add=video \
+ --ipc=host \
+ --cap-add=SYS_PTRACE \
+ --security-opt seccomp=unconfined \
+ --shm-size 8G \
+ --device /dev/kfd \
+ --device /dev/dri \
+ -v /path/to/your/model/vicuna-7b-v1.5:/app/model \
+ vllm-rocm \
+ bash
+```
+Inside the container, run:
+```bash
+bash /app/benchmark_throughput.sh
+```
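+
+Additional flags are forwarded to `benchmarks/benchmark_throughput.py`, so the options added in this port can be passed directly, for example:
+
+```bash
+bash /app/benchmark_throughput.sh --block-size 16 --max-num-seqs 256 --repetition-penalty 1.0
+```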
+
+## Acknowledgement
+
+This ROCm port was built upon the following amazing projects:
+
+- [vLLM](https://github.com/vllm-project/vllm) and [pcmoritz's ROCm fork](https://github.com/pcmoritz/vllm-public/tree/port-to-rocm)
+- [flash-attention](https://github.com/ROCmSoftwarePlatform/flash-attention)
+- [xformers](https://github.com/facebookresearch/xformers)
diff --git a/benchmark_throughput.sh b/benchmark_throughput.sh
new file mode 100644
index 000000000000..c0cb42989a42
--- /dev/null
+++ b/benchmark_throughput.sh
@@ -0,0 +1,51 @@
+#!/bin/bash
+
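+# Thin wrapper around benchmarks/benchmark_throughput.py: defaults the model to
+# /app/model and the dataset to the ShareGPT split under /app/dataset (downloaded
+# on first use unless --dataset is given); all other flags are passed through.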
+main () {
+
+ cd /app
+
+ positional_args=()
+
+ model_path="/app/model"
+ dataset_path="/app/dataset/ShareGPT_V3_unfiltered_cleaned_split.json"
+ dataset_path_modified=0
+
+ while [[ $# -gt 0 ]]; do
+ case $1 in
+ "-h"|"--help")
+ python3 /app/vllm-rocm/benchmarks/benchmark_throughput.py --help
+ return 0
+ ;;
+ "--dataset")
+ dataset_path="$2"
+ dataset_path_modified=1
+ shift
+ shift
+ ;;
+ "--model")
+ model_path="$2"
+ shift
+ shift
+ ;;
+ *)
+ positional_args+=("$1")
+ shift
+ ;;
+ esac
+ done
+
+ if [ ! -f "$dataset_path" ]; then
+ if [[ $dataset_path_modified -lt 1 ]]; then
+ cd /app/dataset
+ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
+ cd /app
+ fi
+ fi
+
+ python3 /app/vllm-rocm/benchmarks/benchmark_throughput.py --dataset "$dataset_path" --model "$model_path" "${positional_args[@]}"
+ return $?
+
+}
+
+main "$@"
+exit $?
diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
index e560cb1fbfc0..eb9497204630 100644
--- a/benchmarks/benchmark_latency.py
+++ b/benchmarks/benchmark_latency.py
@@ -70,7 +70,8 @@ def run_to_completion(profile: bool = False):
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
- choices=['awq', 'squeezellm', None],
+ # choices=['awq', 'squeezellm', None],
+ choices=['squeezellm', None],
default=None)
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--input-len', type=int, default=32)
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index 3a80e679191e..3d8f863770a7 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -104,6 +104,7 @@ async def send_request(
output_len: int,
best_of: int,
use_beam_search: bool,
+ repetition_penalty: float = 1.0,
) -> None:
request_start_time = time.perf_counter()
@@ -119,6 +120,7 @@ async def send_request(
"max_tokens": output_len,
"ignore_eos": True,
"stream": False,
+ "repetition_penalty": repetition_penalty,
}
elif backend == "tgi":
assert not use_beam_search
@@ -160,13 +162,15 @@ async def benchmark(
best_of: int,
use_beam_search: bool,
request_rate: float,
+ repetition_penalty: float = 1.0,
) -> None:
tasks: List[asyncio.Task] = []
async for request in get_request(input_requests, request_rate):
prompt, prompt_len, output_len = request
task = asyncio.create_task(send_request(backend, api_url, prompt,
prompt_len, output_len,
- best_of, use_beam_search))
+ best_of, use_beam_search,
+ repetition_penalty=repetition_penalty,))
tasks.append(task)
await asyncio.gather(*tasks)
@@ -182,7 +186,7 @@ def main(args: argparse.Namespace):
benchmark_start_time = time.perf_counter()
asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
- args.use_beam_search, args.request_rate))
+ args.use_beam_search, args.request_rate, args.repetition_penalty))
benchmark_end_time = time.perf_counter()
benchmark_time = benchmark_end_time - benchmark_start_time
print(f"Total time: {benchmark_time:.2f} s")
@@ -229,5 +233,7 @@ def main(args: argparse.Namespace):
parser.add_argument("--seed", type=int, default=0)
parser.add_argument('--trust-remote-code', action='store_true',
help='trust remote code from huggingface')
+ parser.add_argument("--repetition-penalty", type=float, default=1.0,
+ help="Set >1 to penalize repetition and <1 to reward repetition")
args = parser.parse_args()
main(args)
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index fc578b497286..62951eedcb33 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -65,6 +65,9 @@ def run_vllm(
use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
+ block_size: int = 16,
+ max_num_seqs: int = 256,
+    repetition_penalty: float = 1.0,
) -> float:
llm = LLM(
model=model,
@@ -74,8 +77,12 @@ def run_vllm(
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
+ block_size=block_size,
+ max_num_seqs=max_num_seqs
)
+ print("Using repetition penalty: {}".format(repetition_penalty))
+
# Add the requests to the engine.
for prompt, _, output_len in requests:
sampling_params = SamplingParams(
@@ -85,6 +92,7 @@ def run_vllm(
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
+ repetition_penalty=repetition_penalty,
)
# FIXME(woosuk): Do not use internal method.
llm._add_request(
@@ -170,10 +178,13 @@ def main(args: argparse.Namespace):
requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
if args.backend == "vllm":
+        print("Using block_size={}, max_num_seqs={}".format(args.block_size, args.max_num_seqs))
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
args.quantization, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search,
- args.trust_remote_code, args.dtype)
+ args.trust_remote_code, args.dtype,
+ block_size=args.block_size, max_num_seqs=args.max_num_seqs,
+ repetition_penalty=args.repetition_penalty)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -201,7 +212,8 @@ def main(args: argparse.Namespace):
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
- choices=['awq', 'squeezellm', None],
+ # choices=['awq', 'squeezellm', None],
+ choices=['squeezellm', None],
default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n",
@@ -230,6 +242,16 @@ def main(args: argparse.Namespace):
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
+ parser.add_argument('--block-size',
+ type=int,
+ default=16,
+ choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
+ help='token block size')
+ parser.add_argument("--max-num-seqs",
+ type=int,
+ default=256)
+ parser.add_argument("--repetition-penalty", type=float, default=1.0,
+ help="Set >1 to penalize repetition and <1 to reward repetition")
args = parser.parse_args()
if args.backend == "vllm":
diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 89d1ba2d37dd..1cca2c5fccc1 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -1,6 +1,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
+#include "cuda_compat.h"
#include "dispatch_utils.h"
namespace vllm {
@@ -18,8 +19,8 @@ __global__ void silu_and_mul_kernel(
const int d) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
- const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
- const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
+ const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
+ const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = silu(x) * y;
}
}
@@ -57,7 +58,7 @@ __global__ void activation_kernel(
const int d) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
- const scalar_t x = __ldg(&input[token_idx * d + idx]);
+ const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
out[token_idx * d + idx] = ACT_FN(x);
}
}
diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
index 78e8d8ecd6d4..9228de194bf3 100644
--- a/csrc/attention/attention_kernels.cu
+++ b/csrc/attention/attention_kernels.cu
@@ -23,7 +23,11 @@
#include <algorithm>
+#ifndef USE_ROCM
#define WARP_SIZE 32
+#else
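+// AMD GPUs execute 64-lane wavefronts, unlike the 32-lane warps on NVIDIA hardware.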
+#define WARP_SIZE 64
+#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
@@ -40,7 +44,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
- sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
+ sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}
// Warp leaders store the data to shared memory.
@@ -59,11 +63,11 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
- sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
+ sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}
// Broadcast to other threads.
- return __shfl_sync(uint32_t(-1), sum, 0);
+ return VLLM_SHFL_SYNC(sum, 0);
}
// TODO(woosuk): Merge the last two dimensions of the grid.
@@ -223,7 +227,7 @@ __device__ void paged_attention_kernel(
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
- qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
+ qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
if (lane == 0) {
red_smem[warp_idx] = qk_max;
@@ -235,10 +239,10 @@ __device__ void paged_attention_kernel(
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
- qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
+ qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
// Broadcast the max qk value to all threads.
- qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
+ qk_max = VLLM_SHFL_SYNC(qk_max, 0);
// Get the sum of the exp values.
float exp_sum = 0.f;
@@ -326,7 +330,7 @@ __device__ void paged_attention_kernel(
float acc = accs[i];
#pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
- acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
+ acc += VLLM_SHFL_XOR_SYNC(acc, mask);
}
accs[i] = acc;
}
@@ -492,7 +496,7 @@ __global__ void paged_attention_v2_reduce_kernel(
// Reduce within the warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
- max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
+ max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
if (lane == 0) {
red_smem[warp_idx] = max_logit;
@@ -502,10 +506,10 @@ __global__ void paged_attention_v2_reduce_kernel(
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
- max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
+ max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
// Broadcast the max value to all threads.
- max_logit = __shfl_sync(uint32_t(-1), max_logit, 0);
+ max_logit = VLLM_SHFL_SYNC(max_logit, 0);
// Load rescaled exp sums to shared memory.
float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
@@ -538,9 +542,10 @@ __global__ void paged_attention_v2_reduce_kernel(
} // namespace vllm
+#ifndef USE_ROCM
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
cudaFuncSetAttribute( \
- vllm::paged_attention_v1_kernel, \
+ (void*)vllm::paged_attention_v1_kernel, \
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
vllm::paged_attention_v1_kernel \
<<>>( \
@@ -557,6 +562,27 @@ __global__ void paged_attention_v2_reduce_kernel(
q_stride, \
kv_block_stride, \
kv_head_stride);
+#else
+#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
+ hipFuncSetAttribute( \
+ (void*)vllm::paged_attention_v1_kernel, \
+ hipFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
+ vllm::paged_attention_v1_kernel \
+ <<>>( \
+ out_ptr, \
+ query_ptr, \
+ key_cache_ptr, \
+ value_cache_ptr, \
+ head_mapping_ptr, \
+ scale, \
+ block_tables_ptr, \
+ context_lens_ptr, \
+ max_num_blocks_per_seq, \
+ alibi_slopes_ptr, \
+ q_stride, \
+ kv_block_stride, \
+ kv_head_stride);
+#endif
// TODO(woosuk): Tune NUM_THREADS.
template<
@@ -654,6 +680,15 @@ void paged_attention_v1_launcher(
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
+ /*case 1: */ \
+ /* CALL_V1_LAUNCHER(T, 1); */ \
+ /* break; */ \
+ /*case 2: */ \
+ /* CALL_V1_LAUNCHER(T, 2); */ \
+ /* break; */ \
+ /*case 4: */ \
+ /* CALL_V1_LAUNCHER(T, 4); */ \
+ /* break; */ \
case 8: \
CALL_V1_LAUNCHER(T, 8); \
break; \
@@ -663,6 +698,15 @@ void paged_attention_v1_launcher(
case 32: \
CALL_V1_LAUNCHER(T, 32); \
break; \
+ /*case 64: */ \
+ /* CALL_V1_LAUNCHER(T, 64); */ \
+ /* break; */ \
+ /*case 128: */ \
+ /* CALL_V1_LAUNCHER(T, 128); */ \
+ /* break; */ \
+ /*case 256: */ \
+ /* CALL_V1_LAUNCHER(T, 256); */ \
+ /* break; */ \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
@@ -826,6 +870,15 @@ void paged_attention_v2_launcher(
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
+ /*case 1: */ \
+ /* CALL_V2_LAUNCHER(T, 1); */ \
+ /* break; */ \
+ /*case 2: */ \
+ /* CALL_V2_LAUNCHER(T, 2); */ \
+ /* break; */ \
+ /*case 4: */ \
+ /* CALL_V2_LAUNCHER(T, 4); */ \
+ /* break; */ \
case 8: \
CALL_V2_LAUNCHER(T, 8); \
break; \
@@ -835,6 +888,15 @@ void paged_attention_v2_launcher(
case 32: \
CALL_V2_LAUNCHER(T, 32); \
break; \
+ /*case 64: */ \
+ /* CALL_V2_LAUNCHER(T, 64); */ \
+ /* break; */ \
+ /*case 128: */ \
+ /* CALL_V2_LAUNCHER(T, 128); */ \
+ /* break; */ \
+ /*case 256: */ \
+ /* CALL_V2_LAUNCHER(T, 256); */ \
+ /* break; */ \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh
index bb7df25b14f0..ff64c4bd8f80 100644
--- a/csrc/attention/attention_utils.cuh
+++ b/csrc/attention/attention_utils.cuh
@@ -17,6 +17,7 @@
*/
#pragma once
+#include "../cuda_compat.h"
#include "attention_dtypes.h"
#include <float.h>
@@ -39,7 +40,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
float qk = sum(qk_vec);
#pragma unroll
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
- qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
+ qk += VLLM_SHFL_XOR_SYNC(qk, mask);
}
return qk;
}
diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh
index 5786f77f7bca..31e0cee01d2e 100644
--- a/csrc/attention/dtype_bfloat16.cuh
+++ b/csrc/attention/dtype_bfloat16.cuh
@@ -21,8 +21,17 @@
#include "attention_generic.cuh"
#include "dtype_float32.cuh"
-#include <cuda_bf16.h>
-#include <cuda_fp16.h>
+#ifndef USE_ROCM
+ #include <cuda_bf16.h>
+ #include <cuda_fp16.h>
+#else
+ #include <hip/hip_bf16.h>
+ #include <hip/hip_fp16.h>
+
+ typedef __hip_bfloat162 __nv_bfloat162;
+ typedef __hip_bfloat16 __nv_bfloat16;
+#endif
+
#include <stdint.h>
namespace vllm {
@@ -98,7 +107,11 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
- return a + b;
+ #ifndef USE_ROCM
+ return a + b;
+ #else
+ return __hadd(a, b);
+ #endif
#endif
}
diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh
index e67921128d52..b5e128ce8ee0 100644
--- a/csrc/attention/dtype_float16.cuh
+++ b/csrc/attention/dtype_float16.cuh
@@ -21,6 +21,10 @@
#include "attention_generic.cuh"
#include "dtype_float32.cuh"
+#ifdef USE_ROCM
+ #include <hip/hip_fp16.h>
+#endif
+
#include <stdint.h>
namespace vllm {
@@ -63,58 +67,114 @@ struct FloatVec {
// Utility functions for type conversions.
inline __device__ uint32_t h0_h0(uint16_t a) {
+#ifndef USE_ROCM
uint32_t b;
asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
return b;
+#else
+ union {
+ uint32_t u32;
+ uint16_t u16[2];
+ } tmp;
+ tmp.u16[0] = a;
+ tmp.u16[1] = a;
+ return tmp.u32;
+#endif
}
inline __device__ float half_to_float(uint16_t h) {
+#ifndef USE_ROCM
float f;
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
return f;
+#else
+ float f;
+ asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
+ return f;
+#endif
}
inline __device__ float2 half2_to_float2(uint32_t v) {
+#ifndef USE_ROCM
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
return make_float2(half_to_float(lo), half_to_float(hi));
+#else
+ union {
+ uint32_t u32;
+ uint16_t u16[2];
+ } tmp;
+ tmp.u32 = v;
+ float2 ret;
+ ret.x = half_to_float(tmp.u16[0]);
+ ret.y = half_to_float(tmp.u16[1]);
+ return ret;
+#endif
}
inline __device__ uint16_t float_to_half(float f) {
+#ifndef USE_ROCM
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
return tmp.u16[0];
+#else
+ uint16_t ret;
+ asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(ret) : "v"(f));
+ return ret;
+#endif
}
inline __device__ uint32_t float2_to_half2(float2 f) {
+#ifndef USE_ROCM
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
- asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+ asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
+ #else
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
+ #endif
+ return tmp.u32;
#else
- asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
- asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
-#endif
+ union {
+ uint32_t u32;
+ uint16_t u16[2];
+ } tmp;
+ tmp.u16[0] = float_to_half(f.x);
+ tmp.u16[1] = float_to_half(f.y);
return tmp.u32;
+#endif
}
// Vector addition.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
+#ifndef USE_ROCM
uint16_t c;
asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
return c;
+#else
+ uint16_t c;
+ asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+ return c;
+#endif
}
inline __device__ uint32_t add(uint32_t a, uint32_t b) {
+#ifndef USE_ROCM
uint32_t c;
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
+#else
+ uint32_t c;
+ asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+ return c;
+#endif
}
inline __device__ uint2 add(uint2 a, uint2 b) {
@@ -157,16 +217,28 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
// Vector multiplication.
template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
+#ifndef USE_ROCM
uint16_t c;
asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
return c;
+#else
+ uint16_t c;
+ asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+ return c;
+#endif
}
template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
+#ifndef USE_ROCM
uint32_t c;
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
+#else
+ uint32_t c;
+ asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+ return c;
+#endif
}
template<>
@@ -271,9 +343,15 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
// Vector fused multiply-add.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
+#ifndef USE_ROCM
uint32_t d;
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
return d;
+#else
+ uint32_t d;
+ asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
+ return d;
+#endif
}
inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 3ad52b1681c0..59bacffdf464 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -1,6 +1,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
+#include "cuda_compat.h"
#include "dispatch_utils.h"
#include <algorithm>
@@ -28,8 +29,8 @@ void swap_blocks(
TORCH_CHECK(false, "Invalid device combination");
}
- void *src_ptr = src.data_ptr();
- void *dst_ptr = dst.data_ptr();
+ char *src_ptr = static_cast(src.data_ptr());
+ char *dst_ptr = static_cast(dst.data_ptr());
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -267,8 +268,8 @@ __global__ void gather_cached_kv_kernel(
+ head_offset * block_size
+ block_offset;
- key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
- value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
+ key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]);
+ value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]);
}
}
@@ -333,8 +334,8 @@ __global__ void gather_cached_kv_kernel_optimized(
src_key_indices[j] = src_key_idx;
src_value_indices[j] = src_value_idx;
- keys_to_store[j] = __ldg(&key_cache[src_key_idx]);
- values_to_store[j] = __ldg(&value_cache[src_value_idx]);
+ keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]);
+ values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]);
}
#pragma unroll
diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h
new file mode 100644
index 000000000000..8991462a862e
--- /dev/null
+++ b/csrc/cuda_compat.h
@@ -0,0 +1,19 @@
+#pragma once
+
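+// Portability shims between CUDA and HIP/ROCm: on ROCm, __ldg falls back to a plain
+// dereference and the CUDA-only *_sync warp shuffles map to the non-sync __shfl variants.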
+#ifndef USE_ROCM
+ #define VLLM_LDG(arg) __ldg(arg)
+#else
+ #define VLLM_LDG(arg) *(arg)
+#endif
+
+#ifndef USE_ROCM
+ #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
+#else
+ #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
+#endif
+
+#ifndef USE_ROCM
+ #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane);
+#else
+ #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
+#endif
diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu
index f1c30fe7ea99..2439f5922a3f 100644
--- a/csrc/cuda_utils_kernels.cu
+++ b/csrc/cuda_utils_kernels.cu
@@ -1,3 +1,7 @@
+#ifdef USE_ROCM
+ #include
+#endif
+
int get_device_attribute(
int attribute,
int device_id)
diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu
index 0a5ec95f8c0d..e1dc711778ff 100644
--- a/csrc/pos_encoding_kernels.cu
+++ b/csrc/pos_encoding_kernels.cu
@@ -1,6 +1,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
+#include "cuda_compat.h"
#include "dispatch_utils.h"
namespace vllm {
@@ -19,14 +20,14 @@ inline __device__ void apply_rotary_embedding(
// GPT-NeoX style rotary embedding.
x_index = rot_offset;
y_index = embed_dim + rot_offset;
- cos = __ldg(cos_ptr + x_index);
- sin = __ldg(sin_ptr + x_index);
+ cos = VLLM_LDG(cos_ptr + x_index);
+ sin = VLLM_LDG(sin_ptr + x_index);
} else {
// GPT-J style rotary embedding.
x_index = 2 * rot_offset;
y_index = 2 * rot_offset + 1;
- cos = __ldg(cos_ptr + x_index / 2);
- sin = __ldg(sin_ptr + x_index / 2);
+ cos = VLLM_LDG(cos_ptr + x_index / 2);
+ sin = VLLM_LDG(sin_ptr + x_index / 2);
}
const scalar_t x = arr[x_index];
diff --git a/csrc/quantization.cpp b/csrc/quantization.cpp
index dfe17a496c78..6ebcc35e4227 100644
--- a/csrc/quantization.cpp
+++ b/csrc/quantization.cpp
@@ -1,11 +1,11 @@
#include <torch/extension.h>
-torch::Tensor awq_gemm(
- torch::Tensor _in_feats,
- torch::Tensor _kernel,
- torch::Tensor _scaling_factors,
- torch::Tensor _zeros,
- int split_k_iters);
+// torch::Tensor awq_gemm(
+// torch::Tensor _in_feats,
+// torch::Tensor _kernel,
+// torch::Tensor _scaling_factors,
+// torch::Tensor _zeros,
+// int split_k_iters);
void squeezellm_gemm(
torch::Tensor vec,
@@ -14,6 +14,6 @@ void squeezellm_gemm(
torch::Tensor lookup_table);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
- m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
+ // m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
m.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
}
diff --git a/csrc/quantization/squeezellm/quant_cuda_kernel.cu b/csrc/quantization/squeezellm/quant_cuda_kernel.cu
index 1392b877397b..2c37d01e0ae5 100644
--- a/csrc/quantization/squeezellm/quant_cuda_kernel.cu
+++ b/csrc/quantization/squeezellm/quant_cuda_kernel.cu
@@ -20,9 +20,17 @@ __device__ inline unsigned int as_unsigned(int i) {
// 4-bit matvec kernel (LUT-based)
__global__ void NUQ4MatMulKernel(
+#ifndef USE_ROCM
const half2* __restrict__ vec,
+#else
+ const __half2* __restrict__ vec,
+#endif
const int* __restrict__ mat,
+#ifndef USE_ROCM
half2* __restrict__ mul,
+#else
+ float2* __restrict__ mul,
+#endif
const __half* __restrict__ lookup_table,
int height,
int width,
@@ -35,7 +43,11 @@ __global__ void NUQ4MatMulKernel(
int row = BLOCKHEIGHT4 * blockIdx.x;
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+#ifndef USE_ROCM
__shared__ half2 blockvec[blockwidth2];
+#else
+ __shared__ __half2 blockvec[blockwidth2];
+#endif
__shared__ __half deq2[16][BLOCKWIDTH];
int off = threadIdx.x;
@@ -46,8 +58,13 @@ __global__ void NUQ4MatMulKernel(
}
__half res;
+#ifndef USE_ROCM
half2 res2;
half2 tmp2;
+#else
+ __half2 res2;
+ __half2 tmp2;
+#endif
int i;
int k;
@@ -68,48 +85,96 @@ __global__ void NUQ4MatMulKernel(
while (k < blockwidth2) {
tmp1 = as_unsigned(mat[i]);
+#ifndef USE_ROCM
res2 = {};
tmp2 = {};
+#else
+ res2.x = __half_as_ushort(__float2half(0));
+ res2.y = __half_as_ushort(__float2half(0));
+ tmp2.x = __half_as_ushort(__float2half(0));
+ tmp2.y = __half_as_ushort(__float2half(0));
+#endif
lut_index1 = tmp1 & 0xF;
lut_index2 = (tmp1 >> 4) & 0xF;
+#ifndef USE_ROCM
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
+#else
+ tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
+ tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
+#endif
res2 = __hfma2(tmp2, blockvec[k + 0], res2);
lut_index1 = (tmp1 >> 8) & 0xF;
lut_index2 = (tmp1 >> 12) & 0xF;
+#ifndef USE_ROCM
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
+#else
+ tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
+ tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
+#endif
res2 = __hfma2(tmp2, blockvec[k + 1], res2);
lut_index1 = (tmp1 >> 16) & 0xF;
lut_index2 = (tmp1 >> 20) & 0xF;
+#ifndef USE_ROCM
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
+#else
+ tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
+ tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
+#endif
res2 = __hfma2(tmp2, blockvec[k + 2], res2);
lut_index1 = (tmp1 >> 24) & 0xF;
lut_index2 = (tmp1 >> 28) & 0xF;
+#ifndef USE_ROCM
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
+#else
+ tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
+ tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
+#endif
res2 = __hfma2(tmp2, blockvec[k + 3], res2);
+#ifndef USE_ROCM
res = __hadd(__hadd(res2.x, res2.y), res);
+#else
+ res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res);
+#endif
i += width;
k += 4;
}
// col%2 -> only set one of the two values
+#ifndef USE_ROCM
half2 res3 = {};
if (col % 2 == 0) {
res3.x = res;
} else {
res3.y = res;
}
+#else
+ __half2 res3;
+ res3.x = __half_as_ushort(__float2half(0));
+ res3.y = __half_as_ushort(__float2half(0));
+ if (col % 2 == 0) {
+ res3.x = __half_as_ushort(res);
+ } else {
+ res3.y = __half_as_ushort(res);
+ }
+#endif
+#ifndef USE_ROCM
atomicAdd(&mul[b * width / 2 + col / 2], res3);
+#else
+ int tmp_addr = b * width / 2 + col / 2;
+ atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
+ atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
+#endif
}
}
@@ -136,10 +201,19 @@ void squeezellm_gemm(
dim3 threads(BLOCKWIDTH);
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
+#ifndef USE_ROCM
(half2*) vec.data(),
+#else
+ (__half2*) vec.data_ptr(),
+#endif
mat.data_ptr(),
+#ifndef USE_ROCM
(half2*) mul.data(),
(__half*) lookup_table.data(),
+#else
+ (float2*) mul.data_ptr(),
+ (__half*) lookup_table.data_ptr(),
+#endif
height, width, batch, vec_height
);
}
diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh
index bc35aa0424b5..b95ccef16207 100644
--- a/csrc/reduction_utils.cuh
+++ b/csrc/reduction_utils.cuh
@@ -17,13 +17,15 @@
*/
#pragma once
+#include "cuda_compat.h"
+
namespace vllm {
template<typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
- val += __shfl_xor_sync(0xffffffff, val, mask, 32);
+ val += VLLM_SHFL_XOR_SYNC(val, mask);
return val;
}
diff --git a/docs/source/assets/benchmarks/throughput_requests.png b/docs/source/assets/benchmarks/throughput_requests.png
new file mode 100644
index 000000000000..9b4b541393ca
Binary files /dev/null and b/docs/source/assets/benchmarks/throughput_requests.png differ
diff --git a/docs/source/assets/benchmarks/throughput_requests.svg b/docs/source/assets/benchmarks/throughput_requests.svg
new file mode 100644
index 000000000000..f87c9fa9defc
--- /dev/null
+++ b/docs/source/assets/benchmarks/throughput_requests.svg
@@ -0,0 +1,434 @@
+[SVG markup omitted in this patch listing: throughput (requests/s) comparison chart, annotated with 2.2x, 3.2x, and 3.5x speedups]
diff --git a/docs/source/assets/benchmarks/throughput_tokens.png b/docs/source/assets/benchmarks/throughput_tokens.png
new file mode 100644
index 000000000000..59bf778cd08c
Binary files /dev/null and b/docs/source/assets/benchmarks/throughput_tokens.png differ
diff --git a/docs/source/assets/benchmarks/throughput_tokens.svg b/docs/source/assets/benchmarks/throughput_tokens.svg
new file mode 100644
index 000000000000..c1643afc4e3d
--- /dev/null
+++ b/docs/source/assets/benchmarks/throughput_tokens.svg
@@ -0,0 +1,535 @@
+[SVG markup omitted in this patch listing: token throughput comparison chart, annotated with 2.2x, 3.2x, and 3.5x speedups]
diff --git a/docs/source/assets/logos/16900649.png b/docs/source/assets/logos/16900649.png
new file mode 100644
index 000000000000..095beb65b139
Binary files /dev/null and b/docs/source/assets/logos/16900649.png differ
diff --git a/examples/api_client.py b/examples/api_client.py
index 70ec8c549212..6f84973e0318 100644
--- a/examples/api_client.py
+++ b/examples/api_client.py
@@ -17,15 +17,17 @@ def clear_line(n: int = 1) -> None:
def post_http_request(prompt: str,
api_url: str,
n: int = 1,
- stream: bool = False) -> requests.Response:
+ stream: bool = False,
+                      repetition_penalty: float = 1.0) -> requests.Response:
headers = {"User-Agent": "Test Client"}
pload = {
"prompt": prompt,
"n": n,
"use_beam_search": True,
"temperature": 0.0,
- "max_tokens": 16,
+ "max_tokens": 1024,
"stream": stream,
+ "repetition_penalty": repetition_penalty,
}
response = requests.post(api_url, headers=headers, json=pload, stream=True)
return response
@@ -54,6 +56,8 @@ def get_response(response: requests.Response) -> List[str]:
parser.add_argument("--n", type=int, default=4)
parser.add_argument("--prompt", type=str, default="San Francisco is a")
parser.add_argument("--stream", action="store_true")
+ parser.add_argument("--repetition-penalty", type=float, default=1.0,
+ help="Set >1 to penalize repetition and <1 to reward repetition")
args = parser.parse_args()
prompt = args.prompt
api_url = f"http://{args.host}:{args.port}/generate"
@@ -61,7 +65,7 @@ def get_response(response: requests.Response) -> List[str]:
stream = args.stream
print(f"Prompt: {prompt!r}\n", flush=True)
- response = post_http_request(prompt, api_url, n, stream)
+ response = post_http_request(prompt, api_url, n, stream, repetition_penalty=args.repetition_penalty)
if stream:
num_printed_lines = 0
diff --git a/requirements.txt b/requirements.txt
index d8597b3ec554..7d4fe2972815 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,10 +4,10 @@ ray >= 2.5.1
pandas # Required for Ray data.
pyarrow # Required for Ray data.
sentencepiece # Required for LLaMA tokenizer.
-numpy
+numpy >= 1.22.4
torch == 2.0.1
transformers >= 4.34.0 # Required for Mistral.
-xformers == 0.0.22 # Required for Mistral.
+#xformers == 0.0.22 # Required for Mistral.
fastapi
uvicorn[standard]
pydantic == 1.10.13 # Required for OpenAI server.
diff --git a/setup.py b/setup.py
index 660b5196cfd9..784de6734da4 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
from packaging.version import parse, Version
import setuptools
import torch
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
ROOT_DIR = os.path.dirname(__file__)
@@ -24,10 +24,14 @@
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
-if CUDA_HOME is None:
- raise RuntimeError(
- "Cannot find CUDA_HOME. CUDA must be available to build the package.")
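+# When building against a ROCm build of PyTorch (torch.version.hip is set), define
+# USE_ROCM so the kernels in csrc/ take their HIP code paths; the CUDA_HOME check
+# below only applies to CUDA builds.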
+if torch.version.hip:
+ if ROCM_HOME is not None:
+ NVCC_FLAGS += [f"-DUSE_ROCM"]
+if not torch.version.hip:
+ if CUDA_HOME is None:
+ raise RuntimeError(
+ "Cannot find CUDA_HOME. CUDA must be available to build the package.")
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
"""Get the CUDA version from nvcc.
@@ -76,66 +80,72 @@ def get_torch_arch_list() -> Set[str]:
f"{valid_archs}.")
return arch_list
-
-# First, check the TORCH_CUDA_ARCH_LIST environment variable.
-compute_capabilities = get_torch_arch_list()
-if not compute_capabilities:
- # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
- # GPUs on the current machine.
- device_count = torch.cuda.device_count()
- for i in range(device_count):
- major, minor = torch.cuda.get_device_capability(i)
- if major < 7:
- raise RuntimeError(
- "GPUs with compute capability below 7.0 are not supported.")
- compute_capabilities.add(f"{major}.{minor}")
-
-nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
-if not compute_capabilities:
- # If no GPU is specified nor available, add all supported architectures
- # based on the NVCC CUDA version.
- compute_capabilities = SUPPORTED_ARCHS.copy()
+def get_cuda_compute_capabilities(nvcc_cuda_version):
+ # First, check the TORCH_CUDA_ARCH_LIST environment variable.
+ compute_capabilities = get_torch_arch_list()
+ if not compute_capabilities:
+ # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
+ # GPUs on the current machine.
+ device_count = torch.cuda.device_count()
+ for i in range(device_count):
+ major, minor = torch.cuda.get_device_capability(i)
+ if major < 7:
+ raise RuntimeError(
+ "GPUs with compute capability below 7.0 are not supported.")
+ compute_capabilities.add(f"{major}.{minor}")
+
+ if not compute_capabilities:
+ # If no GPU is specified nor available, add all supported architectures
+ # based on the NVCC CUDA version.
+ compute_capabilities = SUPPORTED_ARCHS.copy()
+ if nvcc_cuda_version < Version("11.1"):
+ compute_capabilities.remove("8.6")
+ if nvcc_cuda_version < Version("11.8"):
+ compute_capabilities.remove("8.9")
+ compute_capabilities.remove("9.0")
+
+ return compute_capabilities
+
+def validate_nvcc_cuda_version(nvcc_cuda_version, compute_capabilities):
+ if nvcc_cuda_version < Version("11.0"):
+ raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
if nvcc_cuda_version < Version("11.1"):
- compute_capabilities.remove("8.6")
+ if any(cc.startswith("8.6") for cc in compute_capabilities):
+ raise RuntimeError(
+ "CUDA 11.1 or higher is required for compute capability 8.6.")
if nvcc_cuda_version < Version("11.8"):
- compute_capabilities.remove("8.9")
- compute_capabilities.remove("9.0")
-
-# Validate the NVCC CUDA version.
-if nvcc_cuda_version < Version("11.0"):
- raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
-if nvcc_cuda_version < Version("11.1"):
- if any(cc.startswith("8.6") for cc in compute_capabilities):
- raise RuntimeError(
- "CUDA 11.1 or higher is required for compute capability 8.6.")
-if nvcc_cuda_version < Version("11.8"):
- if any(cc.startswith("8.9") for cc in compute_capabilities):
- # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
- # However, GPUs with compute capability 8.9 can also run the code generated by
- # the previous versions of CUDA 11 and targeting compute capability 8.0.
- # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
- # instead of 8.9.
- warnings.warn(
- "CUDA 11.8 or higher is required for compute capability 8.9. "
- "Targeting compute capability 8.0 instead.")
- compute_capabilities = set(cc for cc in compute_capabilities
- if not cc.startswith("8.9"))
- compute_capabilities.add("8.0+PTX")
- if any(cc.startswith("9.0") for cc in compute_capabilities):
- raise RuntimeError(
- "CUDA 11.8 or higher is required for compute capability 9.0.")
+ if any(cc.startswith("8.9") for cc in compute_capabilities):
+ # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
+ # However, GPUs with compute capability 8.9 can also run the code generated by
+ # the previous versions of CUDA 11 and targeting compute capability 8.0.
+ # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
+ # instead of 8.9.
+ warnings.warn(
+ "CUDA 11.8 or higher is required for compute capability 8.9. "
+ "Targeting compute capability 8.0 instead.")
+ compute_capabilities = set(cc for cc in compute_capabilities
+ if not cc.startswith("8.9"))
+ compute_capabilities.add("8.0+PTX")
+ if any(cc.startswith("9.0") for cc in compute_capabilities):
+ raise RuntimeError(
+ "CUDA 11.8 or higher is required for compute capability 9.0.")
+
+if not torch.version.hip:
+ nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
+ compute_capabilities = get_cuda_compute_capabilities(nvcc_cuda_version)
+ validate_nvcc_cuda_version(nvcc_cuda_version, compute_capabilities)
-# Add target compute capabilities to NVCC flags.
-for capability in compute_capabilities:
- num = capability[0] + capability[2]
- NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
- if capability.endswith("+PTX"):
- NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
+ # Add target compute capabilities to NVCC flags.
+ for capability in compute_capabilities:
+ num = capability[0] + capability[2]
+ NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
+ if capability.endswith("+PTX"):
+ NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
-# Use NVCC threads to parallelize the build.
-if nvcc_cuda_version >= Version("11.2"):
- num_threads = min(os.cpu_count(), 8)
- NVCC_FLAGS += ["--threads", str(num_threads)]
+ # Use NVCC threads to parallelize the build.
+ if nvcc_cuda_version >= Version("11.2"):
+ num_threads = min(os.cpu_count(), 8)
+ NVCC_FLAGS += ["--threads", str(num_threads)]
ext_modules = []
@@ -199,7 +209,7 @@ def get_torch_arch_list() -> Set[str]:
name="vllm.quantization_ops",
sources=[
"csrc/quantization.cpp",
- "csrc/quantization/awq/gemm_kernels.cu",
+ # "csrc/quantization/awq/gemm_kernels.cu",
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
],
extra_compile_args={
diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index 7c4a84d4c7d8..35c6a6ab2fa1 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -3,8 +3,8 @@
import pytest
import torch
-from xformers import ops as xops
-from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
+from vllm.xformers.ops.fmha import memory_efficient_attention_forward
+from vllm.xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from vllm import attention_ops
from vllm.utils import get_max_shared_memory_bytes
@@ -308,7 +308,7 @@ def test_multi_query_kv_attention(
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
- output = xops.memory_efficient_attention_forward(
+ output = memory_efficient_attention_forward(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
diff --git a/vllm/config.py b/vllm/config.py
index a9e86c24b273..13db4d8a65b0 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -103,7 +103,8 @@ def _verify_tokenizer_mode(self) -> None:
self.tokenizer_mode = tokenizer_mode
def _verify_quantization(self) -> None:
- supported_quantization = ["awq", "squeezellm"]
+ # supported_quantization = ["awq", "squeezellm"]
+ supported_quantization = ["squeezellm"]
if self.quantization is None:
return
quantization = self.quantization.lower()
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index cc425a2c079e..dc00da04c183 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -132,7 +132,7 @@ def add_cli_args(
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
- choices=[8, 16, 32],
+ choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
help='token block size')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed',
@@ -168,7 +168,8 @@ def add_cli_args(
parser.add_argument('--quantization',
'-q',
type=str,
- choices=['awq', 'squeezellm', None],
+ # choices=['awq', 'squeezellm', None],
+ choices=['squeezellm', None],
default=None,
help='Method used to quantize the weights')
return parser
diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py
index ed7f1ec45e32..93cebb288174 100644
--- a/vllm/engine/ray_utils.py
+++ b/vllm/engine/ray_utils.py
@@ -74,7 +74,7 @@ def initialize_cluster(
"Ray is not installed. Please install Ray to use distributed "
"serving.")
# Connect to a ray cluster.
- ray.init(address=ray_address, ignore_reinit_error=True)
+ ray.init(address=ray_address, ignore_reinit_error=True, num_gpus=parallel_config.world_size)
if not parallel_config.worker_use_ray:
# Initialize cluster locally.
diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py
index b3b5852e4876..6df6d529f13f 100644
--- a/vllm/model_executor/input_metadata.py
+++ b/vllm/model_executor/input_metadata.py
@@ -1,7 +1,7 @@
from typing import Dict, List, Optional, Tuple
import torch
-from xformers.ops import AttentionBias
+from vllm.xformers.ops.fmha.attn_bias import AttentionBias
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData
diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py
index c1259a1b11ea..bea78c95c5c2 100644
--- a/vllm/model_executor/layers/attention.py
+++ b/vllm/model_executor/layers/attention.py
@@ -3,8 +3,8 @@
import torch
import torch.nn as nn
-from xformers import ops as xops
-from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
+from vllm.xformers.ops.fmha import memory_efficient_attention_forward
+from vllm.xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
LowerTriangularMaskWithTensorBias)
from vllm import attention_ops
@@ -104,7 +104,7 @@ def multi_query_kv_attention(
dim=1)
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
- out = xops.memory_efficient_attention_forward(
+ out = memory_efficient_attention_forward(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
@@ -465,7 +465,7 @@ def multi_query_kv_attention(
batch_size = input_metadata.num_prompts
seq_len = input_metadata.max_prompt_len
- out = xops.memory_efficient_attention_forward(
+ out = memory_efficient_attention_forward(
query.view(batch_size, seq_len, self.num_heads, self.head_size),
key.view(batch_size, seq_len, self.num_heads, self.head_size),
value.view(batch_size, seq_len, self.num_heads, self.head_size),
diff --git a/vllm/model_executor/layers/quantized_linear/__init__.py b/vllm/model_executor/layers/quantized_linear/__init__.py
index b09358261d5d..0feb3b243b2a 100644
--- a/vllm/model_executor/layers/quantized_linear/__init__.py
+++ b/vllm/model_executor/layers/quantized_linear/__init__.py
@@ -1,12 +1,12 @@
-from vllm.model_executor.layers.quantized_linear.awq import (
- AWQColumnParallelLinear, AWQRowParallelLinear)
+# from vllm.model_executor.layers.quantized_linear.awq import (
+# AWQColumnParallelLinear, AWQRowParallelLinear)
from vllm.model_executor.layers.quantized_linear.squeezellm import (
SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
RowParallelLinear)
_QUANTIZED_LINEAR_REGISTRY = {
- "awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
+ # "awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
"squeezellm":
(SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear),
}
diff --git a/vllm/model_executor/layers/quantized_linear/squeezellm.py b/vllm/model_executor/layers/quantized_linear/squeezellm.py
index 3ccbc4e579dc..c365483b58d8 100644
--- a/vllm/model_executor/layers/quantized_linear/squeezellm.py
+++ b/vllm/model_executor/layers/quantized_linear/squeezellm.py
@@ -39,9 +39,15 @@ def apply_weights(
out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
# NOTE: The output tensor should be zero-initialized.
- out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
- quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
- self.lookup_table)
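+ # On ROCm, allocate a float32 output buffer for the kernel and cast the result back to float16.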
+ if torch.version.hip:
+ out_float = torch.zeros(out_shape, device="cuda", dtype=torch.float)
+ quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out_float,
+ self.lookup_table)
+ out = out_float.to(dtype=torch.float16)
+ else:
+ out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
+ quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
+ self.lookup_table)
if bias is not None:
out = out + bias
@@ -78,7 +84,13 @@ def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
# NOTE: The output tensor should be zero-initialized.
- out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
- quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
- self.lookup_table)
- return out.reshape(out_shape)
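+ # Same ROCm path as above: run the kernel into a float32 buffer, then downcast to float16.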
+ if torch.version.hip:
+ out = torch.zeros(out_shape, device="cuda", dtype=torch.float)
+ quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
+ self.lookup_table)
+ return out.to(dtype=torch.float16).reshape(out_shape)
+ else:
+ out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
+ quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
+ self.lookup_table)
+ return out.reshape(out_shape)
\ No newline at end of file
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 735e4ad17218..b957aab3f576 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -36,10 +36,11 @@
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.layers.quantized_linear import ParallelLinear
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
+from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
+ ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.layers.quantized_linear import ParallelLinear
from vllm.model_executor.quantization_utils import QuantizationConfig
from vllm.model_executor.weight_utils import (
convert_pyslice_to_tensor, hf_model_weights_iterator,
@@ -350,6 +351,8 @@ def load_weights(self,
]
state_dict = self.state_dict()
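+ # Resolve the "auto" load format to plain PyTorch checkpoints before iterating over the weights.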
+ load_format = load_format if load_format != "auto" else "pt"
+
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py
index 8b09276e6f91..675e54ed675b 100644
--- a/vllm/model_executor/models/mistral.py
+++ b/vllm/model_executor/models/mistral.py
@@ -36,10 +36,10 @@
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.layers.quantized_linear import ParallelLinear
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
+from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
+ ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.quantization_utils import QuantizationConfig
from vllm.model_executor.weight_utils import (
convert_pyslice_to_tensor, hf_model_weights_iterator,
@@ -59,12 +59,12 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
- self.gate_up_proj = ParallelLinear.column(hidden_size,
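+ # Replace the quantization-aware ParallelLinear wrappers with the plain tensor-parallel layers.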
+ self.gate_up_proj = ColumnParallelLinear(hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
quant_config=quant_config)
- self.down_proj = ParallelLinear.row(intermediate_size,
+ self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
@@ -107,7 +107,7 @@ def __init__(self,
self.rope_theta = rope_theta
self.sliding_window = sliding_window
- self.qkv_proj = ParallelLinear.column(
+ self.qkv_proj = ColumnParallelLinear(
hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim,
@@ -115,7 +115,7 @@ def __init__(self,
gather_output=False,
quant_config=quant_config,
)
- self.o_proj = ParallelLinear.row(
+ self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
@@ -268,7 +268,7 @@ def __init__(
self.model = MistralModel(config, quant_config)
vocab_size = ((config.vocab_size + 63) // 64) * 64
# NOTE: The LM head is not quantized.
- self.lm_head = ParallelLinear.column(config.hidden_size,
+ self.lm_head = ColumnParallelLinear(config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
diff --git a/vllm/model_executor/quantization_utils/__init__.py b/vllm/model_executor/quantization_utils/__init__.py
index 345f6494bf83..d3ee39131e49 100644
--- a/vllm/model_executor/quantization_utils/__init__.py
+++ b/vllm/model_executor/quantization_utils/__init__.py
@@ -1,11 +1,11 @@
from typing import Type
-from vllm.model_executor.quantization_utils.awq import AWQConfig
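+# AWQ is likewise disabled here; SqueezeLLM is the only quantization config registered in this port.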
+# from vllm.model_executor.quantization_utils.awq import AWQConfig
from vllm.model_executor.quantization_utils.base import QuantizationConfig
from vllm.model_executor.quantization_utils.squeezellm import SqueezeLLMConfig
_QUANTIZATION_REGISTRY = {
- "awq": AWQConfig,
+ # "awq": AWQConfig,
"squeezellm": SqueezeLLMConfig,
}
diff --git a/vllm/utils.py b/vllm/utils.py
index 0e17e9070489..8ec33d2a21ac 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -30,7 +30,10 @@ def reset(self) -> None:
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes."""
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
- cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 # pylint: disable=invalid-name
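+ # HIP uses a different attribute enum value for the maximum shared memory per block.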
+ if torch.version.hip:
+ cudaDevAttrMaxSharedMemoryPerBlockOptin = 74
+ else:
+ cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 # pylint: disable=invalid-name
max_shared_mem = cuda_utils.get_device_attribute(
cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
return int(max_shared_mem)
diff --git a/vllm/xformers/LICENSE b/vllm/xformers/LICENSE
new file mode 100644
index 000000000000..f42f840d1027
--- /dev/null
+++ b/vllm/xformers/LICENSE
@@ -0,0 +1,35 @@
+From xFormers:
+
+Copyright (c) Facebook, Inc. and its affiliates
+
+
+===
+
+BSD 3-Clause License
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
+ and IDIAP Research Institute nor the names of its contributors may be
+ used to endorse or promote products derived from this software without
+ specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/vllm/xformers/__init__.py b/vllm/xformers/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/vllm/xformers/_cpp_lib.py b/vllm/xformers/_cpp_lib.py
new file mode 100644
index 000000000000..4eb6fd98140e
--- /dev/null
+++ b/vllm/xformers/_cpp_lib.py
@@ -0,0 +1,144 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import dataclasses
+import json
+import logging
+import os
+import platform
+from typing import Any, Dict, Optional
+
+import torch
+
+logger = logging.getLogger("xformers")
+
+UNAVAILABLE_FEATURES_MSG = (
+ " Memory-efficient attention, SwiGLU, sparse and more won't be available."
+)
+
+
+@dataclasses.dataclass
+class _BuildInfo:
+ metadata: Dict[str, Any]
+
+ @property
+ def cuda_version(self) -> Optional[int]:
+ return self.metadata["version"]["cuda"]
+
+ @property
+ def torch_version(self) -> str:
+ return self.metadata["version"]["torch"]
+
+ @property
+ def python_version(self) -> str:
+ return self.metadata["version"]["python"]
+
+ @property
+ def flash_version(self) -> str:
+ return self.metadata["version"].get("flash", "0.0.0")
+
+ @property
+ def build_env(self) -> Dict[str, Any]:
+ return self.metadata["env"]
+
+
+class xFormersWasNotBuiltException(Exception):
+ def __str__(self) -> str:
+ return (
+ "Need to compile C++ extensions to use all xFormers features.\n"
+ " Please install xformers properly "
+ "(see https://github.com/facebookresearch/xformers#installing-xformers)\n"
+ + UNAVAILABLE_FEATURES_MSG
+ )
+
+
+class xFormersInvalidLibException(Exception):
+ def __init__(self, build_info: Optional[_BuildInfo]) -> None:
+ self.build_info = build_info
+
+ def __str__(self) -> str:
+ if self.build_info is None:
+ msg = "xFormers was built for a different version of PyTorch or Python."
+ else:
+ msg = f"""xFormers was built for:
+ PyTorch {self.build_info.torch_version} with CUDA {self.build_info.cuda_version} (you have {torch.__version__})
+ Python {self.build_info.python_version} (you have {platform.python_version()})"""
+ return (
+ "xFormers can't load C++/CUDA extensions. "
+ + msg
+ + "\n Please reinstall xformers "
+ "(see https://github.com/facebookresearch/xformers#installing-xformers)\n"
+ + UNAVAILABLE_FEATURES_MSG
+ )
+
+
+def _register_extensions():
+ import importlib
+ import os
+
+ import torch
+
+ # load the custom_op_library and register the custom ops
+ lib_dir = os.path.dirname(__file__)
+ if os.name == "nt":
+ # Register the extension library location on the default DLL path
+ import ctypes
+ import sys
+
+ kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
+ with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
+ prev_error_mode = kernel32.SetErrorMode(0x0001)
+
+ if with_load_library_flags:
+ kernel32.AddDllDirectory.restype = ctypes.c_void_p
+
+ if sys.version_info >= (3, 8):
+ os.add_dll_directory(lib_dir)
+ elif with_load_library_flags:
+ res = kernel32.AddDllDirectory(lib_dir)
+ if res is None:
+ err = ctypes.WinError(ctypes.get_last_error())
+ err.strerror += f' Error adding "{lib_dir}" to the DLL directories.'
+ raise err
+
+ kernel32.SetErrorMode(prev_error_mode)
+
+ loader_details = (
+ importlib.machinery.ExtensionFileLoader,
+ importlib.machinery.EXTENSION_SUFFIXES,
+ )
+
+ extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
+ ext_specs = extfinder.find_spec("_C")
+ if ext_specs is None:
+ raise xFormersWasNotBuiltException()
+ cpp_lib_json = os.path.join(lib_dir, "cpp_lib.json")
+ with open(cpp_lib_json, "r") as fp:
+ build_metadata = _BuildInfo(json.load(fp))
+ try:
+ torch.ops.load_library(ext_specs.origin)
+ except OSError as exc:
+ raise xFormersInvalidLibException(build_metadata) from exc
+ return build_metadata
+
+
+_cpp_library_load_exception = None
+_build_metadata: Optional[_BuildInfo] = None
+
+try:
+ _build_metadata = _register_extensions()
+except (xFormersInvalidLibException, xFormersWasNotBuiltException) as e:
+ ENV_VAR_FOR_DETAILS = "XFORMERS_MORE_DETAILS"
+ if os.environ.get(ENV_VAR_FOR_DETAILS, False):
+ logger.warning(f"WARNING[XFORMERS]: {e}", exc_info=e)
+ else:
+ logger.warning(
+ f"WARNING[XFORMERS]: {e}\n Set {ENV_VAR_FOR_DETAILS}=1 for more details"
+ )
+ _cpp_library_load_exception = e
+
+_built_with_cuda = (
+ _build_metadata is not None and _build_metadata.cuda_version is not None
+)
diff --git a/vllm/xformers/ops/__init__.py b/vllm/xformers/ops/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/vllm/xformers/ops/common.py b/vllm/xformers/ops/common.py
new file mode 100644
index 000000000000..e20cedabfab6
--- /dev/null
+++ b/vllm/xformers/ops/common.py
@@ -0,0 +1,23 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any
+
+class BaseOperator:
+ OPERATOR: Any
+ NAME: str
+ OPERATOR_CATEGORY: str
+
+ @classmethod
+ def is_available(cls) -> bool:
+ if cls.OPERATOR is None or cls.OPERATOR.__name__ == "no_such_operator":
+ return False
+ return True
+
+ @classmethod
+ def operator_flop(cls, *inputs) -> int:
+ """Calculate number of FLOP given inputs to `OPERATOR`"""
+ return -1
+
\ No newline at end of file
diff --git a/vllm/xformers/ops/fmha/__init__.py b/vllm/xformers/ops/fmha/__init__.py
new file mode 100644
index 000000000000..449b026ce571
--- /dev/null
+++ b/vllm/xformers/ops/fmha/__init__.py
@@ -0,0 +1,51 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Optional, Tuple, Type, Union
+
+import torch
+
+#from . import cutlass, decoder, flash, small_k, triton
+from .attn_bias import AttentionBias
+from .common import (
+ AttentionFwOpBase,
+ Inputs,
+)
+from .dispatch import _dispatch_fw, _ensure_op_supports_or_raise
+
+
+def memory_efficient_attention_forward(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
+ p: float = 0.0,
+ scale: Optional[float] = None,
+ *,
+ op: Optional[Type[AttentionFwOpBase]] = None,
+) -> torch.Tensor:
+ """
+ Calculates the forward pass of :attr:`xformers.ops.memory_efficient_attention`.
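+
+ :Example: (a minimal sketch; shapes are illustrative and it assumes a GPU plus the
+ ROCm flash-attention kernels used by the vendored `flash` op are available)
+
+ .. code-block:: python
+
+ import torch
+ from vllm.xformers.ops.fmha import memory_efficient_attention_forward
+
+ # BMHK layout: [batch, seqlen, num_heads, head_dim]
+ q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
+ k = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
+ v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
+
+ out = memory_efficient_attention_forward(q, k, v)
+ assert tuple(out.shape) == (1, 128, 8, 64)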
+ """
+ return _memory_efficient_attention_forward(
+ Inputs(
+ query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale
+ ),
+ op=op,
+ )
+
+def _memory_efficient_attention_forward(
+ inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
+) -> torch.Tensor:
+ inp.validate_inputs()
+ output_shape = inp.normalize_bmhk()
+ if op is None:
+ op = _dispatch_fw(inp, False)
+ else:
+ _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
+
+ out, *_ = op.apply(inp, needs_gradient=False)
+ return out.reshape(output_shape)
+
diff --git a/vllm/xformers/ops/fmha/attn_bias.py b/vllm/xformers/ops/fmha/attn_bias.py
new file mode 100644
index 000000000000..4af444b92a90
--- /dev/null
+++ b/vllm/xformers/ops/fmha/attn_bias.py
@@ -0,0 +1,778 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import math
+from dataclasses import dataclass
+from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union
+
+import torch
+
+class AttentionBias:
+ """Base class for a custom bias that can be applied \
+ as the attn_bias argument in
+ :attr:`xformers.ops.memory_efficient_attention`.
+
+ That function has the ability to add a tensor, the
+ attention bias, to the QK^T matrix before it is used
+ in the softmax part of the attention calculation.
+ The attention bias tensor with shape
+ (B or 1, n_queries, number of keys)
+ can be given as the attn_bias input.
+ The most common use case for an attention bias is
+ to contain only zeros and negative infinities, which form
+ a mask so that some queries only attend to some keys.
+
+ Subclasses of this class provide alternative ways to express
+ such masks for some common cases, instead of passing an
+ explicit tensor as the attn_bias input.
+
+ When using an :attr:`xformers.ops.AttentionBias`
+ instead of a :attr:`torch.Tensor`, the mask matrix does
+ not need to be materialized, and can be
+ hardcoded into some kernels for better performance.
+
+ See:
+
+ - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMask`
+ - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias`
+ - :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`
+ - :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`
+
+ """
+
+ def materialize(
+ self,
+ shape: Tuple[int, ...],
+ dtype: torch.dtype = torch.float32,
+ device: Union[str, torch.device] = "cpu",
+ ) -> torch.Tensor:
+ """
+ Materializes the bias as a `torch.Tensor`. This is very slow
+ and we don't attempt to make it fast. Only use for debugging/testing.
+
+ Shape should be like `[*, q_seqlen, k_seqlen]`
+ """
+ raise NotImplementedError()
+
+
+class LowerTriangularMask(AttentionBias):
+ """
+ A lower-triangular (aka causal) mask
+
+ A query Q cannot attend to a key which is farther from the
+ initial key than Q is from the initial query.
+ """
+
+ def __init__(self, *tensor_args, **tensor_kwargs) -> None:
+ # NOTE: Unused arguments, we keep them for backward compatibility
+ super().__init__()
+
+ def materialize(
+ self,
+ shape: Tuple[int, ...],
+ dtype: torch.dtype = torch.float32,
+ device: Union[str, torch.device] = "cpu",
+ ) -> torch.Tensor:
+ create_as = dtype if dtype is not torch.bfloat16 else torch.float32
+ tensor = torch.full( # type: ignore
+ shape,
+ dtype=create_as,
+ fill_value=float("-inf"),
+ device=device,
+ )
+ return torch.triu(tensor, diagonal=1).to(dtype) # type: ignore
+
+ def add_bias(self, bias: torch.Tensor) -> "LowerTriangularMaskWithTensorBias":
+ return LowerTriangularMaskWithTensorBias(bias)
+
+
+class LowerTriangularMaskWithTensorBias(LowerTriangularMask):
+ """A lower-triangular (aka causal) mask with an additive bias"""
+
+ def __init__(self, bias: torch.Tensor) -> None:
+ self._bias = bias
+
+ def materialize(
+ self,
+ shape: Tuple[int, ...],
+ dtype: torch.dtype = torch.float32,
+ device: Union[str, torch.device] = "cpu",
+ ) -> torch.Tensor:
+ return super().materialize(shape, dtype=dtype, device=device) + self._bias
+
+
+@dataclass
+class _SeqLenInfo:
+ """
+ (Internal) Represents the division of a dimension into blocks.
+
+ For example, to represent a dimension of length 7 divided into
+ three blocks of lengths 2, 3 and 2, use `from_seqlens([2, 3, 2])`.
+ The members will be:
+ max_seqlen: 3
+ min_seqlen: 2
+ seqstart_py: [0, 2, 5, 7]
+ seqstart: torch.IntTensor([0, 2, 5, 7])
+ """
+
+ seqstart: torch.Tensor
+ max_seqlen: int
+ min_seqlen: int
+ seqstart_py: List[int]
+
+ def to(self, device: torch.device) -> None:
+ self.seqstart = self.seqstart.to(device, non_blocking=True)
+
+ def intervals(self) -> Iterable[Tuple[int, int]]:
+ yield from zip(self.seqstart_py, self.seqstart_py[1:])
+
+ @classmethod
+ def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
+ """
+ Input tensors are assumed to be in shape [B, M, *]
+ """
+ assert not isinstance(seqlens, torch.Tensor)
+ seqstart_py = [0]
+ max_seqlen = -1
+ min_seqlen = -1
+ for seqlen in seqlens:
+ min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen
+ max_seqlen = max(max_seqlen, seqlen)
+ seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen)
+ seqstart = torch.tensor(seqstart_py, dtype=torch.int32)
+ return cls(
+ max_seqlen=max_seqlen,
+ min_seqlen=min_seqlen,
+ seqstart=seqstart,
+ seqstart_py=seqstart_py,
+ )
+
+ def split(
+ self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None
+ ) -> List[torch.Tensor]:
+ if self.seqstart_py[-1] != x.shape[1] or x.shape[0] != 1:
+ raise ValueError(
+ f"Invalid `torch.Tensor` of shape {x.shape}, expected format "
+ f"(B, M, *) with B=1 and M={self.seqstart_py[-1]}\n"
+ f" seqstart: {self.seqstart_py}"
+ )
+ if batch_sizes is None:
+ batch_sizes = [1] * (len(self.seqstart_py) - 1)
+ split_chunks = []
+ it = 0
+ for batch_size in batch_sizes:
+ split_chunks.append(
+ self.seqstart_py[it + batch_size] - self.seqstart_py[it]
+ )
+ it += batch_size
+ return [
+ tensor.reshape([bs, -1, *tensor.shape[2:]])
+ for bs, tensor in zip(batch_sizes, x.split(split_chunks, dim=1))
+ ]
+
+
+@dataclass
+class _PaddedSeqLenInfo(_SeqLenInfo):
+ """
+ (Internal) Represents the division of a dimension into blocks which are
+ padded out to the same total length.
+
+ For example, to represent a dimension of length 12 with space for
+ three blocks of length 4, but where the occupied lengths are
+ 2, 3 and 2, use `from_seqlens_padded([2, 3, 2], 4)`.
+
+ The layout along the dimension is
+
+ 0 ─► block 0
+ block 0
+
+
+ 4 ─► block 1
+ block 1
+ block 1
+
+ 8 ─► block 2
+ block 2
+
+
+ 12 ─►
+
+ The members will be:
+ max_seqlen: 3
+ min_seqlen: 2
+ seqstart_py: [0, 4, 8, 12]
+ seqstart: torch.IntTensor([0, 4, 8, 12])
+ seqlen_py: [2, 3, 2]
+ seqlen: torch.IntTensor([2, 3, 2])
+ padding: 4
+ """
+
+ seqlen: torch.Tensor
+ seqlen_py: Sequence[int]
+ padding: int
+ # From parent: seqstart[i] contains the start position
+ # of the i-th sequence
+ # seqstart: torch.Tensor
+
+ def __post_init__(self) -> None:
+ assert len(self.seqstart_py) == len(self.seqlen_py) + 1
+
+ def to(self, device: torch.device) -> None:
+ self.seqlen = self.seqlen.to(device, non_blocking=True)
+ super().to(device)
+
+ def intervals(self) -> Iterable[Tuple[int, int]]:
+ for (start, _), length in zip(super().intervals(), self.seqlen_py):
+ yield start, start + length
+
+ @classmethod
+ def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
+ raise RuntimeError(
+ "Use either `_SeqLenInfo.from_seqlens` or `_PaddedSeqLenInfo.from_seqlens_padded`"
+ )
+
+ @classmethod
+ def from_seqlens_padded(
+ cls, seqlens: Sequence[int], padding: int
+ ) -> "_PaddedSeqLenInfo":
+ """
+ Input tensors are assumed to be in shape [B, M, *]
+ seqstart = padding * torch.arange(batch_size)
+ """
+ assert not isinstance(seqlens, torch.Tensor)
+ assert all(seqlen <= padding for seqlen in seqlens)
+ seqstart_py = list(range(0, len(seqlens) * padding + 1, padding))
+ return cls(
+ seqlen=torch.tensor(seqlens, dtype=torch.int32),
+ seqlen_py=seqlens,
+ max_seqlen=max(seqlens),
+ min_seqlen=min(seqlens),
+ seqstart=torch.tensor(seqstart_py, dtype=torch.int32),
+ seqstart_py=seqstart_py,
+ padding=padding,
+ )
+
+ def split(
+ self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None
+ ) -> List[torch.Tensor]:
+ raise NotImplementedError("_PaddedSeqLenInfo.split")
+
+
+@dataclass
+class BlockDiagonalMask(AttentionBias):
+ """
+ A block-diagonal mask that can be passed as ``attn_bias``
+ argument to :attr:`xformers.ops.memory_efficient_attention`.
+
+ Queries and Keys are each divided into the same number of blocks.
+ Queries in block i only attend to keys in block i.
+
+ .. figure:: /_static/block_diag_bias.png
+
+ This bias can be used to handle a batch of sequences of
+ different lengths, via :attr:`BlockDiagonalMask.from_tensor_list`
+
+ :Example:
+
+ .. code-block:: python
+
+ import torch
+ from xformers.ops import fmha
+
+ K = 16
+ dtype = torch.float16
+ device = "cuda"
+ list_x = [
+ torch.randn([1, 3, 1, K], dtype=dtype, device=device),
+ torch.randn([1, 6, 1, K], dtype=dtype, device=device),
+ torch.randn([1, 2, 1, K], dtype=dtype, device=device),
+ ]
+ attn_bias, x = fmha.BlockDiagonalMask.from_tensor_list(list_x)
+ linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype)
+
+ q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2)
+ out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ list_out = attn_bias.split(out)
+ print(list_out[0].shape) # [1, 3, 1, K]
+ assert tuple(list_out[0].shape) == (1, 3, 1, K)
+
+ """
+
+ q_seqinfo: _SeqLenInfo
+ k_seqinfo: _SeqLenInfo
+ _batch_sizes: Optional[Sequence[int]] = None
+
+ def _create_block_mask(
+ self,
+ shape: Tuple[int, ...],
+ dtype: torch.dtype = torch.float32,
+ device: Union[str, torch.device] = "cpu",
+ ) -> torch.Tensor:
+ return torch.zeros(
+ shape,
+ dtype=dtype,
+ device=device,
+ )
+
+ def materialize(
+ self,
+ shape: Tuple[int, ...],
+ dtype: torch.dtype = torch.float32,
+ device: Union[str, torch.device] = "cpu",
+ ) -> torch.Tensor:
+ """Materialize the attention bias - for debugging & testing"""
+ assert shape[-1] == self.k_seqinfo.seqstart_py[-1], (
+ shape[-1],
+ self.k_seqinfo.seqstart_py[-1],
+ )
+ assert shape[-2] == self.q_seqinfo.seqstart_py[-1], (
+ shape[-2],
+ self.q_seqinfo.seqstart_py[-1],
+ )
+ mask = torch.empty(shape[-2:], dtype=dtype, device=device)
+ mask.fill_(-math.inf)
+ for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
+ zip(
+ self.q_seqinfo.intervals(),
+ self.k_seqinfo.intervals(),
+ )
+ ):
+ mask[q_start:q_end, k_start:k_end] = self._create_block_mask(
+ (q_end - q_start, k_end - k_start),
+ dtype=dtype,
+ device=device,
+ )
+ for _ in range(len(shape) - 2):
+ mask = mask.unsqueeze(0)
+ return mask.expand(shape)
+
+ @classmethod
+ def from_seqlens(
+ cls,
+ q_seqlen: Sequence[int],
+ kv_seqlen: Optional[Sequence[int]] = None,
+ ) -> "BlockDiagonalMask":
+ """Creates a :attr:`BlockDiagonalMask` from a list of tensors lengths for query and key/value.
+
+ Args:
+ q_seqlen (Union[Sequence[int], torch.Tensor]): List or tensor of sequence lengths for query tensors
+ kv_seqlen (Union[Sequence[int], torch.Tensor], optional): List or tensor of sequence lengths for key/value.
+ (Defaults to ``q_seqlen``.)
+ Returns:
+ BlockDiagonalMask
+ """
+ assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen)
+ q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen)
+ if kv_seqlen is None or q_seqlen == kv_seqlen:
+ k_seqinfo = q_seqinfo
+ else:
+ k_seqinfo = _SeqLenInfo.from_seqlens(kv_seqlen)
+ return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo)
+
+ @classmethod
+ def from_tensor_list(
+ cls,
+ tensors: Sequence[torch.Tensor],
+ ) -> Tuple["BlockDiagonalMask", torch.Tensor]:
+ """Creates a :attr:`BlockDiagonalMask` from a list of tensors, and returns the tensors
+ concatenated on the sequence length dimension
+
+ .. figure:: /_static/block_diag_cat_split.png
+
+ See also :attr:`BlockDiagonalMask.split` to split the returned
+ :attr:`torch.Tensor` back to a list of tensors of varying sequence length
+
+ Args:
+ tensors (Sequence[torch.Tensor]): A list of tensors of shape ``[B, M_i, *]``.
+ All tensors should have the same dimension and the same batch size ``B``, but
+ they can have different sequence length ``M``.
+
+ Returns:
+ Tuple[BlockDiagonalMask, torch.Tensor]: The corresponding bias for the attention
+ along with `tensors` concatenated on the sequence length dimension, with shape ``[1, sum_i{M_i}, *]``
+ """
+ batch_sizes = [tensor.shape[0] for tensor in tensors]
+ seqlens = []
+ for x in tensors:
+ for _ in range(x.shape[0]):
+ seqlens.append(x.shape[1])
+ block_diag = cls.from_seqlens(seqlens)
+ block_diag._batch_sizes = batch_sizes
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in tensors)
+ concat_tensors = torch.cat(tensors_bs1, dim=1)
+ return block_diag, concat_tensors
+
+ @classmethod
+ def from_tensor_lists_qkv(
+ cls,
+ tensors_q: Sequence[torch.Tensor],
+ tensors_k: Sequence[torch.Tensor],
+ tensors_v: Optional[Sequence[torch.Tensor]] = None,
+ ) -> Tuple["BlockDiagonalMask", torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ assert len(tensors_q) == len(tensors_k)
+ assert tensors_v is None or len(tensors_v) == len(tensors_q)
+ batch_sizes = [tensor.shape[0] for tensor in tensors_q]
+ q_seqlens, kv_seqlens = [], []
+ for i, (q, k) in enumerate(zip(tensors_q, tensors_k)):
+ assert q.shape[0] == k.shape[0]
+ q_seqlens += [q.shape[1]] * q.shape[0]
+ kv_seqlens += [k.shape[1]] * k.shape[0]
+ assert tensors_v is None or tensors_v[i].shape[:2] == k.shape[:2]
+ block_diag = cls.from_seqlens(q_seqlens, kv_seqlens)
+ block_diag._batch_sizes = batch_sizes
+ return (
+ block_diag,
+ torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_q], dim=1),
+ torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_k], dim=1),
+ torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_v], dim=1)
+ if tensors_v is not None
+ else None,
+ )
+
+ def split_queries(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
+ return self.q_seqinfo.split(tensor, self._batch_sizes)
+
+ def split_kv(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
+ return self.k_seqinfo.split(tensor, self._batch_sizes)
+
+ def split(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
+ """The inverse operation of :attr:`BlockDiagonalCausalMask.from_tensor_list`
+
+ Args:
+ tensor (torch.Tensor): Tensor of tokens of shape ``[1, sum_i{M_i}, *]``
+
+ Returns:
+ Sequence[torch.Tensor]: A list of tokens with possibly different sequence lengths
+ """
+ assert self.q_seqinfo is self.k_seqinfo
+ return self.q_seqinfo.split(tensor, self._batch_sizes)
+
+ def make_causal(self) -> "BlockDiagonalCausalMask":
+ """Makes each block causal"""
+ return BlockDiagonalCausalMask(
+ q_seqinfo=self.q_seqinfo,
+ k_seqinfo=self.k_seqinfo,
+ _batch_sizes=self._batch_sizes,
+ )
+
+ def make_causal_from_bottomright(self) -> "BlockDiagonalCausalFromBottomRightMask":
+ """Makes each block causal with a possible non-causal prefix"""
+ return BlockDiagonalCausalFromBottomRightMask(
+ q_seqinfo=self.q_seqinfo,
+ k_seqinfo=self.k_seqinfo,
+ _batch_sizes=self._batch_sizes,
+ )
+
+ def make_local_attention(
+ self, window_size: int
+ ) -> "BlockDiagonalCausalLocalAttentionMask":
+ """Experimental: Makes each block causal with local attention"""
+ return BlockDiagonalCausalLocalAttentionMask(
+ q_seqinfo=self.q_seqinfo,
+ k_seqinfo=self.k_seqinfo,
+ _batch_sizes=self._batch_sizes,
+ _window_size=window_size,
+ )
+
+ def make_local_attention_from_bottomright(
+ self, window_size: int
+ ) -> "BlockDiagonalCausalLocalAttentionFromBottomRightMask":
+ """Experimental: Makes each block causal with local attention, start from bottom right"""
+ return BlockDiagonalCausalLocalAttentionFromBottomRightMask(
+ q_seqinfo=self.q_seqinfo,
+ k_seqinfo=self.k_seqinfo,
+ _batch_sizes=self._batch_sizes,
+ _window_size=window_size,
+ )
+
+
+@dataclass
+class BlockDiagonalCausalMask(BlockDiagonalMask):
+ """
+ Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal.
+
+ Queries and Keys are each divided into the same number of blocks.
+ A query Q in block i cannot attend to a key which is not in block i,
+ nor one which is farther from the initial key in block i than Q
+ is from the initial query in block i.
+ """
+
+ def _create_block_mask(
+ self,
+ shape: Tuple[int, ...],
+ dtype: torch.dtype = torch.float32,
+ device: Union[str, torch.device] = "cpu",
+ ) -> torch.Tensor:
+ return LowerTriangularMask().materialize(
+ shape,
+ dtype=dtype,
+ device=device,
+ )
+
+
+@dataclass
+class BlockDiagonalCausalFromBottomRightMask(BlockDiagonalMask):
+ """
+ Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal.
+ This mask allows for a non-causal prefix
+ NOTE: Each block should have `num_keys >= num_queries` otherwise the forward pass is not
+ defined (softmax of vector of `-inf` in the attention)
+
+ Queries and keys are each divided into the same number of blocks.
+ A query Q in block i cannot attend to a key which is not in block i,
+ nor one which is nearer the final key in block i than Q is to the
+ final query in block i.
+ """
+
+ def __post_init__(self) -> None:
+ for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
+ zip(
+ self.q_seqinfo.intervals(),
+ self.k_seqinfo.intervals(),
+ )
+ ):
+ num_queries = q_end - q_start
+ num_keys = k_end - k_start
+ if num_keys < num_queries:
+ raise ValueError(
+ f"Block #{i} has num_keys={num_keys} and num_queries={num_queries}."
+ " Expected `num_keys >= num_queries`"
+ )
+
+ def _create_block_mask(
+ self,
+ shape: Tuple[int, ...],
+ dtype: torch.dtype = torch.float32,
+ device: Union[str, torch.device] = "cpu",
+ ) -> torch.Tensor:
+ create_as = dtype if dtype is not torch.bfloat16 else torch.float32
+ tensor = torch.full( # type: ignore
+ shape,
+ dtype=create_as,
+ fill_value=float("-inf"),
+ device=device,
+ )
+ num_queries, num_keys = shape[-2:]
+ return torch.triu(tensor, diagonal=num_keys - num_queries + 1).to(dtype) # type: ignore
+
+
+@dataclass
+class BlockDiagonalCausalWithOffsetPaddedKeysMask(AttentionBias):
+ """
+ Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`,
+ except an offset on causality is allowed for each block and we support padding for k/v
+
+ The keys and values are divided into blocks which are padded out to
+ the same total length.
+ For example, if there is space for 12 keys, for three blocks of
+ max length 4, but we only want to use the first 2, 3 and 2
+ of each block, use `kv_padding=4` and `kv_seqlens=[2, 3, 2]`.
+ The queries are divided into blocks, without padding, of lengths given by
+ q_seqlen.
+
+ A query Q in block i cannot attend to a key which is not in block i,
+ nor one which is not in use (i.e. in the padded area),
+ nor one which is nearer to the final key in block i
+ than Q is to the final query in block i.
+ """
+
+ q_seqinfo: _SeqLenInfo
+ k_seqinfo: _PaddedSeqLenInfo
+ causal_diagonal: Any = None # unused. Exists for BC only.
+
+ def _create_block_mask(
+ self,
+ shape: Tuple[int, ...],
+ dtype: torch.dtype = torch.float32,
+ device: Union[str, torch.device] = "cpu",
+ ) -> torch.Tensor:
+ create_as = dtype if dtype is not torch.bfloat16 else torch.float32
+ tensor = torch.full( # type: ignore
+ shape,
+ dtype=create_as,
+ fill_value=float("-inf"),
+ device=device,
+ )
+ num_queries, num_keys = shape[-2:]
+ return torch.triu(tensor, diagonal=1 + num_keys - num_queries).to(dtype) # type: ignore
+
+ def materialize(
+ self,
+ shape: Tuple[int, ...],
+ dtype: torch.dtype = torch.float32,
+ device: Union[str, torch.device] = "cpu",
+ ) -> torch.Tensor:
+ """Materialize the attention bias - for debugging & testing"""
+ if shape[-1] != self.k_seqinfo.seqstart_py[-1]:
+ raise ValueError("k shapes wrong")
+ if shape[-2] != self.q_seqinfo.seqstart_py[-1]:
+ raise ValueError("q shapes wrong")
+ mask = torch.empty(shape[-2:], dtype=dtype, device=device)
+ mask.fill_(-math.inf)
+ for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
+ zip(
+ self.q_seqinfo.intervals(),
+ self.k_seqinfo.intervals(),
+ )
+ ):
+ mask[q_start:q_end, k_start:k_end] = self._create_block_mask(
+ (q_end - q_start, k_end - k_start),
+ dtype=dtype,
+ device=device,
+ )
+ for _ in range(len(shape) - 2):
+ mask = mask.unsqueeze(0)
+ return mask.expand(shape)
+
+ @classmethod
+ def from_seqlens(
+ cls,
+ q_seqlen: Sequence[int],
+ kv_padding: int,
+ kv_seqlen: Sequence[int],
+ causal_diagonal: Any = None,
+ ) -> "BlockDiagonalCausalWithOffsetPaddedKeysMask":
+ """Creates a :attr:`BlockDiagonalCausalWithOffsetPaddedKeysMask` from a list of tensor
+ lengths for query and key/value.
+
+ Args:
+ q_seqlen (Sequence[int]): List or tensor of sequence lengths for query tensors
+ kv_padding (int): Padding for k/v - also an upper bound on each individual key length
+ kv_seqlen (Sequence[int]): List or tensor of sequence lengths for key/value.
+ causal_diagonal: unused, for BC only
+ Returns:
+ BlockDiagonalCausalWithOffsetPaddedKeysMask
+ """
+ assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen), (
+ q_seqlen,
+ kv_seqlen,
+ )
+ q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen)
+ k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded(kv_seqlen, kv_padding)
+ return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo)
+
+
+@dataclass
+class BlockDiagonalCausalLocalAttentionMask(BlockDiagonalCausalMask):
+ """
+ (Experimental feature)
+ Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`.
+ This makes the mask "local" and the attention pattern banded.
+
+ Query i only attends to keys in its own block and cannot attend to keys farther
+ than "window_size" from it.
+ """
+
+ _window_size: int = 0 # forced due to inheritance and default arguments
+
+ def __post_init__(self):
+ if self._window_size <= 0:
+ raise ValueError(
+ f"Expected `window_size > 0`, but window_size={self._window_size}"
+ )
+ q_seqlen = [
+ y - x
+ for x, y in zip(
+ self.q_seqinfo.seqstart_py[:-1], self.q_seqinfo.seqstart_py[1:]
+ )
+ ]
+ kv_seqlen = [
+ y - x
+ for x, y in zip(
+ self.k_seqinfo.seqstart_py[:-1], self.k_seqinfo.seqstart_py[1:]
+ )
+ ]
+ for q, k in zip(q_seqlen, kv_seqlen):
+ if q - self._window_size >= k:
+ raise RuntimeError(
+ f"No keys are attended in q_seqlen {q} k_seqlen {k} with sliding window {self._window_size}"
+ )
+
+ def _create_block_mask(
+ self,
+ shape: Tuple[int, ...],
+ dtype: torch.dtype = torch.float32,
+ device: Union[str, torch.device] = "cpu",
+ ) -> torch.Tensor:
+ create_as = dtype if dtype is not torch.bfloat16 else torch.float32
+ tensor = torch.full( # type: ignore
+ shape,
+ dtype=create_as,
+ fill_value=1,
+ device=device,
+ )
+
+ num_queries, num_keys = shape[-2:]
+ mask = torch.tril(tensor, diagonal=0).to(dtype) # type: ignore
+ if self._window_size is not None and self._window_size > 0:
+ mask = torch.triu(mask, diagonal=-self._window_size + 1)
+ mask = torch.log(mask)
+ return mask.to(dtype)
+
+
+@dataclass
+class BlockDiagonalCausalLocalAttentionFromBottomRightMask(
+ BlockDiagonalCausalFromBottomRightMask
+):
+ """
+ (Experimental feature)
+ Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`.
+ This makes the mask "local" and the attention pattern banded.
+
+ Query i only attends to keys in its own block and cannot attend to keys farther
+ than "window_size" from it.
+ """
+
+ _window_size: int = 0 # forced due to inheritance and default arguments
+
+ def __post_init__(self):
+ super().__post_init__()
+ if self._window_size <= 0:
+ raise ValueError(
+ f"Expected `window_size > 0`, but window_size={self._window_size}"
+ )
+ q_seqlen = [
+ y - x
+ for x, y in zip(
+ self.q_seqinfo.seqstart_py[:-1], self.q_seqinfo.seqstart_py[1:]
+ )
+ ]
+ kv_seqlen = [
+ y - x
+ for x, y in zip(
+ self.k_seqinfo.seqstart_py[:-1], self.k_seqinfo.seqstart_py[1:]
+ )
+ ]
+ for q, k in zip(q_seqlen, kv_seqlen):
+ if q + (q - k) - self._window_size >= k:
+ raise RuntimeError(
+ f"No keys are attended in q_seqlen {q} k_seqlen {k} with sliding window {self._window_size}"
+ )
+ materialized = self.materialize((sum(q_seqlen), sum(kv_seqlen)))
+ if torch.max(materialized, dim=1).values.min() == -float("inf"):
+ raise RuntimeError(
+ "Invalid local-attention configuration: at least one query row "
+ f"attends to no keys (window_size={self._window_size})"
+ )
+
+ def _create_block_mask(
+ self,
+ shape: Tuple[int, ...],
+ dtype: torch.dtype = torch.float32,
+ device: Union[str, torch.device] = "cpu",
+ ) -> torch.Tensor:
+ create_as = dtype if dtype is not torch.bfloat16 else torch.float32
+ tensor = torch.full( # type: ignore
+ shape,
+ dtype=create_as,
+ fill_value=1,
+ device=device,
+ )
+ num_queries, num_keys = shape[-2:]
+ mask = torch.tril(tensor, diagonal=num_keys - num_queries).to(dtype) # type: ignore
+ if self._window_size is not None:
+ mask = torch.triu(
+ mask, diagonal=num_keys - num_queries - self._window_size + 1
+ )
+ mask = torch.log(mask)
+ return mask.to(dtype)
diff --git a/vllm/xformers/ops/fmha/common.py b/vllm/xformers/ops/fmha/common.py
new file mode 100644
index 000000000000..c87cdd1ba8ec
--- /dev/null
+++ b/vllm/xformers/ops/fmha/common.py
@@ -0,0 +1,527 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from dataclasses import dataclass
+from typing import Any, List, Mapping, Optional, Set, Tuple, Type, Union
+
+import torch
+
+from ..._cpp_lib import _built_with_cuda
+from ..common import BaseOperator
+from .attn_bias import (
+ AttentionBias,
+ BlockDiagonalMask,
+ LowerTriangularMask,
+ LowerTriangularMaskWithTensorBias,
+)
+
+
+def _is_bias_type_supported_in_BMK(attn_bias_type: Any) -> bool:
+ # NoneType
+ if isinstance(None, attn_bias_type):
+ return True
+ if attn_bias_type in [LowerTriangularMask, torch.Tensor]:
+ return True
+ return False
+
+
+@dataclass
+class Inputs:
+ """
+ Stores inputs to the `memory_efficient_attention` operators
+ """
+
+ query: torch.Tensor
+ key: torch.Tensor
+ value: torch.Tensor
+ attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None
+ p: float = 0.0
+ scale: Optional[float] = None
+
+ @property
+ def device(self) -> torch.device:
+ return self.query.device
+
+ @property
+ def scale_float(self) -> float:
+ return self.query.shape[-1] ** (-0.5) if self.scale is None else self.scale
+
+ def get_qkv_in_bmghk(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if self.query.ndim == 5:
+ return self.query, self.key, self.value
+ if self.query.ndim == 4:
+ return (
+ self.query.unsqueeze(2),
+ self.key.unsqueeze(2),
+ self.value.unsqueeze(2),
+ )
+ if self.value.ndim == 3:
+ return (
+ self.query[:, :, None, None],
+ self.key[:, :, None, None],
+ self.value[:, :, None, None],
+ )
+ assert False
+
+ def normalize_bmhk(self) -> Tuple[int, ...]:
+ if self.query.ndim not in [3, 4, 5]:
+ raise ValueError(
+ f"Invalid shape for query: {self.query.shape}. "
+ "Expected shape [batch, seqlen, head_groups, num_heads_per_group, K]"
+ ", [batch, seqlen, num_heads, K], or [batch, seqlen, K]."
+ )
+ if self.value.dtype == torch.int32:
+ # Quantized K/V case, in which the last dims of Q and K are different.
+ # NB we currently don't have any implementations for quantized KV with
+ # SUPPORTS_DIFFERENT_VALUE_EMBED.
+ output_shape = tuple(self.query.shape)
+ else:
+ output_shape = (self.query.shape[:-1]) + (self.value.shape[-1],)
+ # Convert from legacy format
+ if self.query.ndim == 3:
+ self.query = self.query.unsqueeze(2)
+ self.key = self.key.unsqueeze(2)
+ self.value = self.value.unsqueeze(2)
+ if isinstance(self.attn_bias, torch.Tensor):
+ if self.attn_bias.ndim != 3:
+ raise ValueError(
+ f"Expected BMK format for attn_bias, but got {self.attn_bias.shape}"
+ )
+ self.attn_bias = self.attn_bias.unsqueeze(1)
+ return output_shape
+
+ def validate_inputs(self) -> None:
+ qkv = (self.query, self.key, self.value)
+ if self.query.ndim not in (3, 4, 5) or any(
+ x.ndim != self.query.ndim for x in qkv
+ ):
+ raise ValueError(
+ f"Query/Key/Value should all have BMGHK, BMHK, or BMK shape.\n"
+ f" query.shape: {self.query.shape}\n"
+ f" key.shape : {self.key.shape}\n"
+ f" value.shape: {self.value.shape}"
+ )
+ if any(x.device != self.query.device for x in qkv):
+ raise ValueError("Query/Key/Value should all be on the same device")
+ quantized_dtypes = self.key.dtype == self.value.dtype == torch.int32
+ non_quantized_dtypes = all(x.dtype == self.query.dtype for x in qkv)
+ if not (quantized_dtypes or non_quantized_dtypes):
+ raise ValueError(
+ "Query/Key/Value should either all have the same dtype, or "
+ "(in the quantized case) Key/Value should have dtype torch.int32\n"
+ f" query.dtype: {self.query.dtype}\n"
+ f" key.dtype : {self.key.dtype}\n"
+ f" value.dtype: {self.value.dtype}"
+ )
+ # Biases with tensors attached are meant to be in BMHK format
+ # This would require to permute biases/gradients which can be expensive,
+ # so let's just forbid it - BMK is a legacy format anyway
+ if self.query.ndim == 3 and not _is_bias_type_supported_in_BMK(
+ type(self.attn_bias)
+ ):
+ raise ValueError(
+ f"Please provide inputs in BMHK format rather "
+ f"than BMK when using bias type `{type(self.attn_bias).__name__}`"
+ )
+ attn_bias_t: Optional[torch.Tensor] = None
+ if isinstance(self.attn_bias, torch.Tensor):
+ attn_bias_t = self.attn_bias
+ if isinstance(self.attn_bias, LowerTriangularMaskWithTensorBias):
+ attn_bias_t = self.attn_bias._bias
+ if self.query.ndim == 4 and attn_bias_t is not None:
+ expected_shape = (
+ self.query.shape[0],
+ self.query.shape[2],
+ self.query.shape[1],
+ self.key.shape[1],
+ )
+ if attn_bias_t.shape != expected_shape:
+ raise ValueError(
+ f"Invalid shape for attention bias: {attn_bias_t.shape} (expected {expected_shape})\n"
+ f" query.shape: {self.query.shape}\n"
+ f" key.shape : {self.key.shape}\n"
+ f" value.shape: {self.value.shape}"
+ )
+ if isinstance(self.attn_bias, BlockDiagonalMask):
+ if any(x.shape[0] != 1 for x in qkv):
+ raise ValueError(
+ f"Expected batch_size=1 when using block-diagonal bias\n"
+ f" query.shape: {self.query.shape}\n"
+ f" key.shape : {self.key.shape}\n"
+ f" value.shape: {self.value.shape}"
+ )
+ if self.p < 0.0 or self.p > 1.0:
+ raise ValueError(f"Invalid dropout probability: p={self.p}")
+ # Check that shapes match between inputs
+ B, Mq = self.query.shape[:2]
+ K = self.query.shape[-1]
+ B, Mkv = self.key.shape[:2]
+ Kv = self.value.shape[-1]
+
+ valid_shapes = True
+ if self.query.ndim == 3: # BMK
+ valid_shapes = (
+ self.query.shape == (B, Mq, K)
+ and self.key.shape == (B, Mkv, K)
+ and self.value.shape == (B, Mkv, Kv)
+ )
+ H = self.query.shape[-2]
+ if self.query.ndim == 4: # BMHK
+ quantized_kv_cache = self.value.dtype == torch.int32
+ key_embed_dim = Kv if quantized_kv_cache else K
+ valid_shapes = (
+ self.query.shape == (B, Mq, H, K)
+ and self.key.shape == (B, Mkv, H, key_embed_dim)
+ and self.value.shape == (B, Mkv, H, Kv)
+ )
+ G = self.query.shape[2]
+ if self.query.ndim == 5: # BMGHK
+ valid_shapes = (
+ self.query.shape == (B, Mq, G, H, K)
+ and self.key.shape == (B, Mkv, G, H, K)
+ and self.value.shape == (B, Mkv, G, H, Kv)
+ )
+ if not valid_shapes:
+ raise ValueError(
+ f"Incompatible shapes for attention inputs:\n"
+ f" query.shape: {self.query.shape}\n"
+ f" key.shape : {self.key.shape}\n"
+ f" value.shape: {self.value.shape}\n"
+ "HINT: We don't support broadcasting, please use `expand` "
+ "yourself before calling `memory_efficient_attention` if you need to"
+ )
+
+
+@dataclass
+class Context:
+ lse: torch.Tensor
+ out: torch.Tensor
+ op_bw: Optional[Type["AttentionBwOpBase"]] = None
+ rng_state: Optional[torch.Tensor] = None
+
+ def get_padded_lse(self, pad_to: int, force_pad_inf: bool = False) -> torch.Tensor:
+ pad_amount = (pad_to - (self.lse.shape[2] % pad_to)) % pad_to
+ lse = self.lse
+ if pad_amount > 0:
+ if force_pad_inf:
+ lse = lse[:, :, : self.out.shape[1]]
+ pad_amount = (pad_to - (lse.shape[2] % pad_to)) % pad_to
+ lse = torch.nn.functional.pad(lse, [0, pad_amount], value=math.inf)
+ elif force_pad_inf and self.out.shape[1] != lse.shape[2]:
+ lse[:, :, self.out.shape[1] :].fill_(math.inf)
+ return lse
+
+
+@dataclass
+class Gradients:
+ dq: torch.Tensor
+ dk: torch.Tensor
+ dv: torch.Tensor
+ # bias gradient. None if there is no tensor bias or if it doesn't require grad
+ db: Optional[torch.Tensor] = None
+
+
+class AttentionOpBase(BaseOperator):
+ """Base class for any attention operator in xFormers
+
+ See:
+
+ - :attr:`xformers.ops.fmha.cutlass.FwOp`
+ - :attr:`xformers.ops.fmha.cutlass.BwOp`
+ - :attr:`xformers.ops.fmha.flash.FwOp`
+ - :attr:`xformers.ops.fmha.flash.BwOp`
+ - :attr:`xformers.ops.fmha.triton.FwOp`
+ - :attr:`xformers.ops.fmha.triton.BwOp`
+ - :attr:`xformers.ops.fmha.small_k.FwOp`
+ - :attr:`xformers.ops.fmha.small_k.BwOp`
+ """
+
+ OPERATOR: Any
+ SUPPORTED_DEVICES: Set[str]
+ CUDA_MINIMUM_COMPUTE_CAPABILITY: Tuple[int, int] = (5, 0)
+ SUPPORTED_DTYPES: Set[torch.dtype]
+ SUPPORTED_MAX_K: float
+ SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {type(None)}
+ SUPPORTS_DROPOUT: bool
+ SUPPORTS_CUSTOM_SCALE: bool = False
+ SUPPORTS_DIFFERENT_VALUE_EMBED: bool = False
+ IS_DETERMINISTIC: bool = True
+ SUPPORTS_BMGHK: bool = False
+ NAME: str
+ OPERATOR_CATEGORY = "memory_efficient_attention"
+
+ _TEST_BATCH_SIZES: List[int] = [1, 300]
+ _TEST_K: List[int] = [32, 128]
+
+ @classmethod
+ def supports(cls, d: Inputs) -> bool:
+ return not cls.not_supported_reasons(d)
+
+ @classmethod
+ def shape_not_supported_reasons(
+ cls, Mq: int, Mkv: int, K: int, Kv: int
+ ) -> List[str]:
+ reasons = []
+ if not cls.SUPPORTS_DIFFERENT_VALUE_EMBED and K != Kv:
+ reasons.append("query.shape[-1] != value.shape[-1]")
+ if max(K, Kv) > cls.SUPPORTED_MAX_K:
+ reasons.append(
+ f"max(query.shape[-1] != value.shape[-1]) > {cls.SUPPORTED_MAX_K}"
+ )
+ return reasons
+
+ @classmethod
+ def not_supported_reasons(cls, d: Inputs) -> List[str]:
+ """
+ Returns a list of reasons why this is not supported.
+ The kernel can run these inputs only if the returned list is empty
+ """
+ reasons = cls.shape_not_supported_reasons(
+ Mq=d.query.shape[1],
+ Mkv=d.key.shape[1],
+ K=d.query.shape[-1],
+ Kv=d.value.shape[-1],
+ )
+ device_type = d.query.device.type
+ dtype = d.query.dtype
+ if device_type not in cls.SUPPORTED_DEVICES:
+ reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})")
+ if dtype not in cls.SUPPORTED_DTYPES:
+ reasons.append(f"dtype={dtype} (supported: {cls.SUPPORTED_DTYPES})")
+ if type(d.attn_bias) not in cls.SUPPORTED_ATTN_BIAS_TYPES:
+ reasons.append(f"attn_bias type is {type(d.attn_bias)}")
+ if (d.p != 0.0) and not cls.SUPPORTS_DROPOUT:
+ reasons.append("dropout > 0.0")
+ if d.scale is not None and not cls.SUPPORTS_CUSTOM_SCALE:
+ reasons.append("has custom scale")
+ # bfloat16 is only supported on A100+
+ # ... although the kernels can still run and give the
+ # correct result
+ if dtype is torch.bfloat16 and (
+ not device_type.startswith("cuda")
+ or torch.cuda.get_device_capability(d.query.device)[0] < 8
+ ):
+ reasons.append("bf16 is only supported on A100+ GPUs")
+ if not cls.is_available():
+ reasons.append(
+ "operator wasn't built - see `python -m xformers.info` for more info"
+ )
+ if not cls.IS_DETERMINISTIC and torch.are_deterministic_algorithms_enabled():
+ reasons.append(
+ "operator is non-deterministic, but `torch.use_deterministic_algorithms` is set"
+ )
+ if not cls.SUPPORTS_BMGHK and d.query.ndim == 5:
+ reasons.append("operator does not support BMGHK format")
+ return reasons
+
+
+class AttentionFwOpBase(AttentionOpBase):
+ ERROR_ATOL: Mapping[torch.dtype, float] = {
+ torch.float: 3e-4,
+ torch.half: 4e-3,
+ torch.bfloat16: 2e-2,
+ }
+ ERROR_RTOL: Mapping[torch.dtype, float] = {
+ torch.float: 2e-5,
+ torch.half: 4e-4,
+ torch.bfloat16: 5e-3,
+ }
+
+ @classmethod
+ def apply(
+ cls, inp: Inputs, needs_gradient: bool
+ ) -> Tuple[torch.Tensor, Optional[Context]]:
+ raise NotImplementedError()
+
+ @classmethod
+ def attn_operator_flop(
+ cls,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ causal: bool = False,
+ seqstart_k: Optional[torch.Tensor] = None,
+ seqstart_q: Optional[torch.Tensor] = None,
+ ) -> int:
+ """
+ Computes total flops for the attention
+ Assumes inputs in format BMHK
+ """
+ assert query.ndim == 4
+
+ if seqstart_q is not None:
+ seqstart_q_py = seqstart_q.tolist()
+ else:
+ seqstart_q_py = [0, query.shape[1]]
+ if seqstart_k is not None:
+ seqstart_k_py = seqstart_k.tolist()
+ else:
+ seqstart_k_py = [0, key.shape[1]]
+
+ total_flop = 0
+ for q_start, q_end, k_start, k_end in zip(
+ seqstart_q_py, seqstart_q_py[1:], seqstart_k_py, seqstart_k_py[1:]
+ ):
+ num_q = q_end - q_start
+ num_kv = k_end - k_start
+ # (M,K) @ (K,N) GEMM needs M*N*K*2 flop
+ # Q @ K.transpose
+ total_flop += num_q * num_kv * query.shape[-1] * 2
+ # (ignore softmax)
+ # attn @ V
+ total_flop += num_q * key.shape[-1] * num_kv * 2
+ # Multiply by num_heads and batches
+ total_flop = total_flop * value.shape[2] * value.shape[0]
+ if causal:
+ total_flop //= 2
+ return total_flop
+
+
+class AttentionBwOpBase(AttentionOpBase):
+ ERROR_ATOL: Mapping[torch.dtype, float] = {
+ torch.float: 5e-4,
+ torch.half: 9e-2,
+ torch.bfloat16: 0.7,
+ }
+ ERROR_RTOL: Mapping[torch.dtype, float] = {
+ torch.float: 1e-4,
+ torch.half: 2e-2,
+ torch.bfloat16: 0.1,
+ }
+ SUPPORTS_ATTN_BIAS_GRAD = False
+
+ @classmethod
+ def not_supported_reasons(cls, d: Inputs) -> List[str]:
+ reasons = super(AttentionBwOpBase, cls).not_supported_reasons(d)
+ if (
+ isinstance(d.attn_bias, torch.Tensor)
+ and d.attn_bias.requires_grad
+ and not cls.SUPPORTS_ATTN_BIAS_GRAD
+ ):
+ reasons.append(
+ "Computing the bias gradient is not supported (attn_bias.requires_grad = True)"
+ )
+
+ return reasons
+
+ @classmethod
+ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
+ raise NotImplementedError()
+
+ @classmethod
+ def attn_operator_flop(
+ cls,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ causal: bool = False,
+ seqstart_k: Optional[torch.Tensor] = None,
+ seqstart_q: Optional[torch.Tensor] = None,
+ ) -> int:
+ """
+ Computes total flops for the attention
+ Assumes inputs in format BMHK
+ """
+ assert query.ndim == 4
+
+ if seqstart_q is not None:
+ seqstart_q_py = seqstart_q.tolist()
+ else:
+ seqstart_q_py = [0, query.shape[1]]
+ if seqstart_k is not None:
+ seqstart_k_py = seqstart_k.tolist()
+ else:
+ seqstart_k_py = [0, key.shape[1]]
+
+ total_flop = 0
+ for q_start, q_end, k_start, k_end in zip(
+ seqstart_q_py, seqstart_q_py[1:], seqstart_k_py, seqstart_k_py[1:]
+ ):
+ num_q = q_end - q_start
+ num_kv = k_end - k_start
+ Kqk = query.shape[-1]
+ Kv = value.shape[-1]
+ # (M,K) @ (K,N) GEMM needs M*N*K*2 flop
+ # att = Q @ K.transpose
+ total_flop += num_q * num_kv * Kqk * 2
+ # att @ dO
+ total_flop += num_kv * num_q * Kv * 2
+ # dov = dO @ V
+ total_flop += num_q * Kv * num_kv * 2
+ # dov @ K
+ total_flop += num_q * Kqk * num_kv * 2
+ # dov @ Q
+ total_flop += num_q * Kqk * num_kv * 2
+ # Multiply by num_heads and batches
+ total_flop = total_flop * value.shape[2] * value.shape[0]
+ if causal:
+ total_flop //= 2
+ return total_flop
+
+
+AttentionOp = Tuple[
+ Optional[Type[AttentionFwOpBase]], Optional[Type[AttentionBwOpBase]]
+]
+
+
+@dataclass
+class AttentionOpDispatch:
+ """Dispatcher to automatically select
+ the best operator to run memory-efficient attention.
+
+ :Deprecated:
+
+ This class is deprecated and will be removed in a later version
+ """
+
+ op: AttentionOp
+
+ @classmethod
+ def from_arguments(
+ cls,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
+ p: float = 0.0,
+ scale: Optional[float] = None,
+ ) -> "AttentionOpDispatch":
+ """Here for backward compatibility"""
+ from .dispatch import _dispatch_fw
+
+ inp = Inputs(
+ query=query,
+ key=key,
+ value=value,
+ attn_bias=attn_bias,
+ p=p,
+ scale=scale,
+ )
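+ # Only a forward operator is dispatched in this vendored copy; there is no backward op.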
+ return AttentionOpDispatch(op=(_dispatch_fw(inp, True), None))
+
+
+def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor:
+ if tensor.ndim == 4:
+ return tensor
+ return tensor.reshape([-1, num_heads, tensor.shape[1], tensor.shape[2]]).permute(
+ (0, 2, 1, 3)
+ )
+
+
+def check_lastdim_alignment_stride1(
+ reasons: List[str], name: str, x: torch.Tensor, alignment: int
+) -> None:
+ if x.shape[-1] % alignment != 0:
+ reasons.append(f"{name}.shape[-1] % {alignment} != 0")
+ elif x.stride(-2) % alignment != 0:
+ reasons.append(
+ f"{name}.stride(-2) % {alignment} != 0 ({name}.stride() = {x.stride()})"
+ )
+ # We can have stride=0 sometimes if dimension=1
+ if x.stride(-1) > 1:
+ reasons.append(
+ f"{name}.stride(-1) > 1 ({name}.stride() = {x.stride()}) - you should call `.contiguous()` on the input"
+ )
diff --git a/vllm/xformers/ops/fmha/dispatch.py b/vllm/xformers/ops/fmha/dispatch.py
new file mode 100644
index 000000000000..d9120e2dd5ff
--- /dev/null
+++ b/vllm/xformers/ops/fmha/dispatch.py
@@ -0,0 +1,81 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import textwrap
+from collections import deque
+from typing import List, Sequence, Type, TypeVar
+
+from . import flash
+from .common import AttentionBwOpBase, AttentionFwOpBase, Inputs
+
+
+T = TypeVar("T", Type[AttentionFwOpBase], Type[AttentionBwOpBase])
+
+
+def _format_inputs_description(inp: Inputs) -> str:
+ return f"""query : shape={tuple(inp.query.shape)} ({inp.query.dtype})
+key : shape={tuple(inp.key.shape)} ({inp.key.dtype})
+value : shape={tuple(inp.value.shape)} ({inp.value.dtype})
+attn_bias : {type(inp.attn_bias)}
+p : {inp.p}"""
+
+
+def _ensure_op_supports_or_raise(exc_type, name: str, op, inp: Inputs) -> None:
+ reasons = op.not_supported_reasons(inp)
+ if not reasons:
+ return
+ raise exc_type(
+ f"""Operator `{name}` does not support inputs:
+{textwrap.indent(_format_inputs_description(inp), ' ')}
+{_format_not_supported_reasons(op, reasons)}"""
+ )
+
+
+def _format_not_supported_reasons(op, reasons: List[str]) -> str:
+ return f"`{op.NAME}` is not supported because:\n " + "\n ".join(reasons)
+
+
+def _run_priority_list(name: str, priority_list: Sequence[T], inp: Inputs) -> T:
+ not_supported_reasons: List[List[str]] = []
+ for op in priority_list:
+ not_supported = op.not_supported_reasons(inp)
+ if not not_supported:
+ return op
+ not_supported_reasons.append(not_supported)
+
+ # Let's write a nice message explaining what we tried and why it's not supported
+ msg = f"""No operator found for `{name}` with inputs:
+{textwrap.indent(_format_inputs_description(inp), ' ')}"""
+ for op, not_supported in zip(priority_list, not_supported_reasons):
+ msg += "\n" + _format_not_supported_reasons(op, not_supported)
+ raise NotImplementedError(msg)
+
+
+def _dispatch_fw_priority_list(
+ inp: Inputs, needs_gradient: bool
+) -> Sequence[Type[AttentionFwOpBase]]:
+ priority_list_ops = deque(
+ [
+ flash.FwOp,
+ ]
+ )
+ return priority_list_ops
+
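+# Note: upstream xformers seeds this deque with several backends and reorders it
+# based on the inputs; this port keeps only the flash forward op, so the priority
+# list is effectively a single candidate and `needs_gradient` is unused here.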
+
+def _dispatch_fw(inp: Inputs, needs_gradient: bool) -> Type[AttentionFwOpBase]:
+ """Computes the best operator for forward
+
+ Raises:
+ NotImplementedError: if no operator was found
+
+ Returns:
+ AttentionOp: The best operator for the configuration
+ """
+ return _run_priority_list(
+ "memory_efficient_attention_forward",
+ _dispatch_fw_priority_list(inp, needs_gradient),
+ inp,
+ )
diff --git a/vllm/xformers/ops/fmha/flash.py b/vllm/xformers/ops/fmha/flash.py
new file mode 100644
index 000000000000..3c2520c6dceb
--- /dev/null
+++ b/vllm/xformers/ops/fmha/flash.py
@@ -0,0 +1,426 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from dataclasses import replace
+from itertools import zip_longest
+from typing import Any, List, Optional, Set, Tuple, Union
+
+import torch
+
+from .attn_bias import (
+ AttentionBias,
+ BlockDiagonalCausalFromBottomRightMask,
+ BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+ BlockDiagonalCausalLocalAttentionMask,
+ BlockDiagonalCausalMask,
+ BlockDiagonalMask,
+ LowerTriangularMask,
+ LowerTriangularMaskWithTensorBias,
+)
+from .common import (
+ AttentionFwOpBase,
+ Context,
+ Inputs,
+ check_lastdim_alignment_stride1,
+)
+
+FLASH_VERSION = "0.0.0"
+try:
+ import flash_attn
+ from flash_attn.flash_attn_interface import (_flash_attn_forward,
+ _flash_attn_backward,
+ _flash_attn_varlen_forward,
+ _flash_attn_varlen_backward
+ )
+
+ FLASH_VERSION = flash_attn.__version__
+
+ # create library so that flash-attn goes through the PyTorch Dispatcher
+ _flash_lib = torch.library.Library("xformers_flash", "DEF")
+
+ _flash_lib.define(
+ "flash_fwd(Tensor query, Tensor key, Tensor value, "
+ "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
+ "int max_seqlen_q, int max_seqlen_k, "
+ "float p, float softmax_scale, "
+ "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+ )
+
+ _flash_lib.define(
+ "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
+ "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
+ "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
+ "int max_seqlen_q, int max_seqlen_k, "
+ "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+ )
+
+ def _flash_fwd(
+ query,
+ key,
+ value,
+ cu_seq_lens_q,
+ cu_seq_lens_k,
+ max_seq_len_q,
+ max_seq_len_k,
+ p,
+ softmax_scale,
+ is_causal,
+ window_size,
+ return_softmax,
+ ):
+ if cu_seq_lens_q is None:
+ assert cu_seq_lens_k is None
+ (
+ out,
+ q_padded,
+ k_padded,
+ v_padded,
+ out_padded,
+ softmax_lse,
+ S_dmask,
+ rng_state,
+ ) = _flash_attn_forward(
+ query,
+ key,
+ value,
+ p,
+ softmax_scale,
+ is_causal,
+ return_softmax
+ )
+ else:
+ out = query.new_empty(query.shape[0], query.shape[1], value.shape[2])
+ (
+ out,
+ q_padded,
+ k_padded,
+ v_padded,
+ out_padded,
+ softmax_lse,
+ S_dmask,
+ rng_state,
+ ) = _flash_attn_varlen_forward(
+ query,
+ key,
+ value,
+ cu_seq_lens_q,
+ cu_seq_lens_k,
+ max_seq_len_q,
+ max_seq_len_k,
+ p,
+ softmax_scale,
+ is_causal,
+ return_softmax,
+ )
+ return out, softmax_lse, rng_state
+
+ def _flash_bwd(
+ grad,
+ query,
+ key,
+ value,
+ out,
+ lse,
+ dq,
+ dk,
+ dv,
+ cu_seq_lens_q,
+ cu_seq_lens_k,
+ max_seq_len_q,
+ max_seq_len_k,
+ p,
+ softmax_scale,
+ is_causal,
+ window_size,
+ rng_state,
+ ):
+ if cu_seq_lens_k is None:
+ assert cu_seq_lens_q is None
+ (
+ dq, dk, dv, softmax_d
+ ) = _flash_attn_backward(
+ grad,
+ query,
+ key,
+ value,
+ out,
+ lse,
+ dq,
+ dk,
+ dv,
+ p,
+ softmax_scale,
+ is_causal,
+ rng_state,
+ )
+ else:
+ (
+ dq, dk, dv, softmax_d
+ ) = _flash_attn_varlen_backward(
+ grad,
+ query,
+ key,
+ value,
+ out,
+ lse,
+ dq,
+ dk,
+ dv,
+ cu_seq_lens_q,
+ cu_seq_lens_k,
+ max_seq_len_q,
+ max_seq_len_k,
+ p,
+ softmax_scale,
+ is_causal,
+ rng_state,
+ )
+ return dq, dk, dv
+
+ _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
+ _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
+except ImportError:
+ pass
+
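+# Note: if the `flash_attn` import above fails, `_flash_fwd` / `_flash_bwd` are
+# never defined, so the `FwOp.OPERATOR = _flash_fwd` reference below will raise a
+# NameError as soon as this module is imported.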
+
+def _convert_input_format(
+ inp: Inputs,
+) -> Tuple[Inputs, Optional[torch.Tensor], int, Optional[torch.Tensor], int]:
+ assert inp.query.ndim in [4, 5]
+ query, key, value = inp.query, inp.key, inp.value
+ batch = query.shape[0]
+ seqlen_q = query.shape[1]
+ seqlen_kv = key.shape[1]
+ head_dim_q = query.shape[-1]
+ head_dim_v = value.shape[-1]
+
+ attn_bias = inp.attn_bias
+ if isinstance(attn_bias, BlockDiagonalMask):
+ attn_bias.k_seqinfo.seqstart = attn_bias.k_seqinfo.seqstart.to(
+ inp.query.device, non_blocking=True
+ )
+ attn_bias.q_seqinfo.seqstart = attn_bias.q_seqinfo.seqstart.to(
+ inp.query.device, non_blocking=True
+ )
+
+ cu_seqlen_k = attn_bias.k_seqinfo.seqstart
+ cu_seqlen_q = attn_bias.q_seqinfo.seqstart
+ max_seqlen_q = attn_bias.q_seqinfo.max_seqlen
+ max_seqlen_k = attn_bias.k_seqinfo.max_seqlen
+ else:
+ cu_seqlen_k = None
+ cu_seqlen_q = None
+ max_seqlen_q = inp.query.shape[1]
+ max_seqlen_k = inp.key.shape[1]
+
+ if query.ndim == 5: # GQA (grouped-query attention)
+ # Fold the group/head_in_group dimensions together
+ def fold(x):
+ # Either the head is replicated
+ if x.stride(3) == 0:
+ return x[:, :, :, 0]
+ # Or we reshape
+ return x.reshape(
+ [
+ x.shape[0],
+ x.shape[1],
+ -1,
+ x.shape[4],
+ ]
+ )
+
+ query = fold(query)
+ key = fold(key)
+ value = fold(value)
+ # Optimize for MQA: if all heads share the same K/V (stride 0), keep a single head
+ if key.ndim == 4 and key.stride(2) == 0 and value.stride(2) == 0:
+ key = key[:, :, :1]
+ value = value[:, :, :1]
+ # Initially we have `query.shape = [batch, seqlen, num_heads, head_dim_q]`
+ # We want format `[batch * seqlen, num_heads, head_dim_q]`
+ if cu_seqlen_k is not None:
+ query = query.reshape([batch * seqlen_q, -1, head_dim_q])
+ key = key.reshape([batch * seqlen_kv, -1, head_dim_q])
+ value = value.reshape([batch * seqlen_kv, -1, head_dim_v])
+ # fall back to contiguous copies if any tensor is non-contiguous
+ if not (query.is_contiguous() and key.is_contiguous() and value.is_contiguous()):
+ query = query.contiguous()
+ key = key.contiguous()
+ value = value.contiguous()
+ new_inp = replace(
+ inp,
+ query=query,
+ key=key,
+ value=value,
+ )
+ return new_inp, cu_seqlen_q, max_seqlen_q, cu_seqlen_k, max_seqlen_k
+
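+# Shape walkthrough (illustrative): with a BlockDiagonalMask bias, BMHK inputs of
+# shape [batch, seqlen, num_heads, head_dim] are flattened to
+# [batch * seqlen, num_heads, head_dim], and the bias' seqstart tensors become the
+# cu_seqlen_q / cu_seqlen_k offsets passed to the varlen path to delimit each
+# packed sequence.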
+
+def _is_causal(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> bool:
+ return isinstance(
+ attn_bias,
+ (
+ LowerTriangularMask,
+ BlockDiagonalCausalMask,
+ BlockDiagonalCausalLocalAttentionMask,
+ BlockDiagonalCausalFromBottomRightMask,
+ BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+ ),
+ )
+
+
+def _window_size(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int:
+ if isinstance(
+ attn_bias,
+ (BlockDiagonalCausalLocalAttentionMask,),
+ ):
+ return attn_bias._window_size or 0
+ if isinstance(attn_bias, BlockDiagonalCausalLocalAttentionFromBottomRightMask):
+ return attn_bias._window_size
+ return 0
+
+
+def _check_needs_no_topleft(d: Inputs, reasons: List[str]) -> None:
+ # Flash does not support top-left causal alignment, so only allow causal masks
+ # when each batch element has an equal number of queries and keys (the two
+ # alignments coincide in that case).
+ if isinstance(d.attn_bias, BlockDiagonalCausalMask):
+ for k_start, q_start in zip_longest(
+ d.attn_bias.k_seqinfo.seqstart_py, d.attn_bias.q_seqinfo.seqstart_py
+ ):
+ if k_start != q_start:
+ reasons.append(
+ "Only support BlockDiagonalCausalMask if equal"
+ " numbers of keys and queries"
+ )
+ break
+ elif isinstance(d.attn_bias, LowerTriangularMask):
+ if d.query.shape[1] != d.key.shape[1]:
+ reasons.append(
+ "Only support LowerTriangularMask if equal number of" "keys and queries"
+ )
+
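+# Example (illustrative): a BlockDiagonalCausalMask built from sequences of
+# lengths [3, 4] has seqstart_py == [0, 3, 7] for both queries and keys and passes
+# this check; any mismatch between the two seqstart lists adds a reason and
+# disables the flash op for these inputs.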
+
+def _check_strides_for_bmghk(x: torch.Tensor, name: str, reasons: List[str]) -> None:
+ """
+ We want to be able to collapse the G/H dimensions together
+ """
+ if x.ndim == 5:
+ stride_g, stride_h = x.stride(2), x.stride(3)
+ if x.shape[2] == 1:
+ return
+ if x.shape[3] == 1 or stride_h == 0:
+ return
+ if stride_g != stride_h * x.shape[-2]:
+ reasons.append(
+ f"GQA is only supported when the G/H dimensions are contiguous\n"
+ f" {name}.stride: {x.stride()}\n"
+ f" {name}.shape : {list(x.shape)}"
+ )
+
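+# Illustrative condition: for a BMGHK tensor of shape [B, M, G, H, K], the check
+# above requires stride(G) == stride(H) * H (or a broadcast H dimension), which is
+# exactly what lets `_convert_input_format` fold G and H into one heads dimension.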
+
+class FwOp(AttentionFwOpBase):
+ """Operator that computes memory-efficient attention using \
+ `Flash-Attention `_ \
+ implementation.
+ """
+
+ OPERATOR = _flash_fwd
+ SUPPORTED_DEVICES: Set[str] = {"cuda"}
+ #CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
+ CUDA_MINIMUM_COMPUTE_CAPABILITY = (7, 5)
+ SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
+ SUPPORTED_MAX_K = 256
+ SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
+ type(None),
+ LowerTriangularMask,
+ BlockDiagonalMask,
+ BlockDiagonalCausalMask,
+ BlockDiagonalCausalLocalAttentionMask,
+ BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+ BlockDiagonalCausalFromBottomRightMask,
+ LowerTriangularMaskWithTensorBias,
+ }
+ SUPPORTS_DROPOUT = True
+ SUPPORTS_CUSTOM_SCALE = True
+ SUPPORTS_DIFFERENT_VALUE_EMBED = False
+ SUPPORTS_BMGHK = True
+ NAME = f"flshattF@{FLASH_VERSION}"
+ VERSION = FLASH_VERSION
+
+ @classmethod
+ def not_supported_reasons(cls, d: Inputs) -> List[str]:
+ reasons = super(FwOp, cls).not_supported_reasons(d)
+ check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
+ _check_needs_no_topleft(d, reasons)
+ _check_strides_for_bmghk(d.query, "query", reasons)
+ _check_strides_for_bmghk(d.key, "key", reasons)
+ _check_strides_for_bmghk(d.value, "value", reasons)
+ return reasons
+
+ @classmethod
+ def apply(
+ cls, inp: Inputs, needs_gradient: bool
+ ) -> Tuple[torch.Tensor, Optional[Context]]:
+ return_softmax = False
+ out_shape = [
+ *inp.query.shape[:-1],
+ inp.value.shape[-1],
+ ]
+ # flatten inputs to the flash-attn layout and extract cumulative seqlens (None for dense inputs)
+ (
+ inp,
+ cu_seqlens_q,
+ max_seqlen_q,
+ cu_seqlens_k,
+ max_seqlen_k,
+ ) = _convert_input_format(inp)
+
+ softmax_scale = inp.query.shape[-1] ** (-0.5) if inp.scale is None else inp.scale
+
+ ret = cls.OPERATOR(
+ inp.query,
+ inp.key,
+ inp.value,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ inp.p,
+ softmax_scale,
+ _is_causal(inp.attn_bias),
+ _window_size(inp.attn_bias),
+ return_softmax,
+ )
+
+ out = ret[0].reshape(out_shape)
+ ctx = Context(out=out, lse=ret[1])
+ return (out, ctx)
+
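+ # Usage sketch (hedged; `memory_efficient_attention` refers to the upstream
+ # xformers entry point): callers normally do not invoke `apply` directly but
+ # pass this class as `op=(FwOp, None)` or let `_dispatch_fw` select it, with
+ # query/key/value already in BMHK (or BMGHK) layout.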
+ @classmethod
+ # type: ignore
+ def operator_flop(
+ cls,
+ query,
+ key,
+ value,
+ cu_seq_lens_q,
+ cu_seq_lens_k,
+ max_seq_len_q,
+ max_seq_len_k,
+ p,
+ softmax_scale,
+ causal,
+ return_softmax,
+ ) -> int:
+ return cls.attn_operator_flop(
+ query.unsqueeze(0),
+ key.unsqueeze(0),
+ value.unsqueeze(0),
+ causal=causal,
+ seqstart_k=cu_seq_lens_k,
+ seqstart_q=cu_seq_lens_q,
+ )