diff --git a/.github/scripts/build-rocm.sh b/.github/scripts/build-rocm.sh new file mode 100644 index 000000000..fc7515aa7 --- /dev/null +++ b/.github/scripts/build-rocm.sh @@ -0,0 +1,19 @@ +#!/bin/bash +declare build_arch +declare build_os + +set -xeuo pipefail +if [ "${build_os:0:6}" == ubuntu ]; then + image=rocm/dev-ubuntu-22.04:6.1-complete + echo "Using image $image" + docker run --rm --platform "linux/$build_arch" -i \ + -w /src -v "$PWD:/src" "$image" sh -c \ + "apt-get update \ + && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \ + && cmake -DCOMPUTE_BACKEND=hip . \ + && cmake --build ." +fi + +#output_dir="output/${build_os}/${build_arch}" +#mkdir -p "${output_dir}" +#(shopt -s nullglob && cp bitsandbytes/*.{so,dylib,dll} "${output_dir}") diff --git a/.github/workflows/build_documentation.yml b/.github/workflows/build_documentation.yml index 10272be87..a19e7511d 100644 --- a/.github/workflows/build_documentation.yml +++ b/.github/workflows/build_documentation.yml @@ -14,5 +14,6 @@ jobs: commit_sha: ${{ github.sha }} package: bitsandbytes repo_owner: TimDettmers + custom_container: huggingface/transformers-doc-builder secrets: hf_token: ${{ secrets.HUGGINGFACE_PUSH }} diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml index d6455fd11..cc833df5d 100644 --- a/.github/workflows/build_pr_documentation.yml +++ b/.github/workflows/build_pr_documentation.yml @@ -16,3 +16,4 @@ jobs: pr_number: ${{ github.event.number }} package: bitsandbytes repo_owner: TimDettmers + custom_container: huggingface/transformers-doc-builder diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 72e1b099a..78bc747c3 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -103,6 +103,28 @@ jobs: name: shared_library_cuda_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.cuda_version }} path: output/* retention-days: 7 + build-shared-libs-rocm: + strategy: + matrix: + os: [ubuntu-latest] + arch: [x86_64] + runs-on: ${{ matrix.os }} # One day, we could run them on native agents. Azure supports this now but it's planned only for Q3 2023 for hosted agents + steps: + - uses: actions/checkout@v4 + - name: Set up Docker multiarch + if: startsWith(matrix.os, 'ubuntu') + uses: docker/setup-qemu-action@v2 + - name: Clean up disk space + run: | + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf "/usr/local/share/boost" + sudo rm -rf "$AGENT_TOOLSDIRECTORY" + - name: Build C++ + run: bash .github/scripts/build-rocm.sh + env: + build_os: ${{ matrix.os }} + build_arch: ${{ matrix.arch }} build-wheels: needs: - build-shared-libs diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 0fae0ace5..76d7327a8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -9,23 +9,12 @@ We actively welcome your pull requests. 2. If you've added code that should be tested, add tests. 3. If you've changed APIs, update the documentation. 4. Ensure the test suite passes. -5. Make sure your code lints. -6. If you haven't already, complete the Contributor License Agreement ("CLA"). - -## Contributor License Agreement ("CLA") -In order to accept your pull request, we need you to submit a CLA. You only need -to do this once to work on any of Facebook's open source projects. - -Complete your CLA here: +5. Make sure your code lints, install the [pre-commit hooks as documented here](https://huggingface.co/docs/bitsandbytes/main/en/contributing#setup-pre-commit-hooks). ## Issues We use GitHub issues to track public bugs. Please ensure your description is clear and has sufficient instructions to be able to reproduce the issue. -Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe -disclosure of security bugs. In those cases, please go through the process -outlined on that page and do not file a public issue. - ## License By contributing to bitsandbytes, you agree that your contributions will be licensed under the LICENSE file in the root directory of this source tree. diff --git a/README.md b/README.md index 2cf630dcb..7823168ac 100644 --- a/README.md +++ b/README.md @@ -12,8 +12,18 @@ There are ongoing efforts to support further hardware backends, i.e. Intel CPU + **[https://huggingface.co/docs/bitsandbytes/main](https://huggingface.co/docs/bitsandbytes/main)** +## ALPHA TESTERS WANTED: `multi-backend-refactor` AMD GPU + Intel CPU/GPU specific BNB backend implementations + +We're in the process of a complex refactor in order to allow the support of additional hardware backends, other than CUDA, in BNB. The efforts around this are already quite far along and there's plenty of functionality already in place that is in need for users to take a hands-on approach! Mac support will likely soon also see progress. However, I recommend waiting 2 weeks until the device abstraction has further consolidated (**breaking changes upcoming**). + +Currently, you still need to compile from source, after checking out the `multi-backend-refactor` branch (instructions WIP, but [the current docs on the compilation from source](https://huggingface.co/docs/bitsandbytes/main/en/installation#compile-from-source) are a good starting point; [feel free to share tips / input in this Github discussion](https://github.com/TimDettmers/bitsandbytes/discussions/1219). We'll soon enable nightly releases to make this much easier for you! + +Please give feedback to us in [this dedicated Github Discussion space](https://github.com/TimDettmers/bitsandbytes/discussions/categories/catch-all-alpha-testing-the-multi-backend-refactor)! + +We're super excited about these recent developments and grateful for any constructive input or support that you can give to help us make this a reality. BNB is a community project and we're excited for your collaboration 🤗 + ## License -The majority of bitsandbytes is licensed under MIT, however small portions of the project are available under separate license terms, as the parts adapted from Pytorch are licensed under the BSD license. +`bitsandbytes` is MIT licensed. We thank Fabio Cannizzo for his work on [FastBinarySearch](https://github.com/fabiocannizzo/FastBinarySearch) which we use for CPU quantization. diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 760a8eda4..eff7fc686 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -16,6 +16,7 @@ ) from .backends import register_backend from .backends.cpu import CPUBackend +from .backends.npu import NPUBackend from .cextension import lib from .nn import modules @@ -49,11 +50,14 @@ register_backend("xpu", XPUBackend()) +# Register Ascend NPU backend, if available. +if hasattr(torch, "npu") and torch.npu.is_available(): + register_backend("npu", NPUBackend()) + # TODO: Other potential backends: # XLA - Google TPU / PJRT runtime # HPU - Habana / Intel Gaudi # IPU - Graphcore -# NPU - Ascend # Note that we may not map 1:1 with a device type, e.g. SYCL, XLA # In this case, it will be up to each backend to dispatch as needed diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 71943915b..8e296a8ee 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -575,7 +575,8 @@ def matmul_4bit( bias=None, ): assert quant_state is not None - if A.numel() == A.shape[-1] and A.requires_grad == False: + if (A.numel() == A.shape[-1] or A.device.type == "cpu") and A.requires_grad == False: + # CPU backend does not require A to be a vector if A.shape[-1] % quant_state.blocksize != 0: warn( f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}", diff --git a/bitsandbytes/backends/cpu.py b/bitsandbytes/backends/cpu.py index d6a9192e4..5d38171d5 100644 --- a/bitsandbytes/backends/cpu.py +++ b/bitsandbytes/backends/cpu.py @@ -6,9 +6,12 @@ from .base import Backend from .cpu_xpu_common import ( + dequantize_4bit_impl, double_quant_impl, + gemm_4bit_impl, igemmlt_impl, mm_dequant_impl, + quantize_4bit_impl, ) Tensor = torch.Tensor @@ -132,7 +135,11 @@ def quantize_4bit( quant_type: Literal["fp4", "nf4"] = "fp4", quant_storage=torch.uint8, ) -> Tuple[torch.Tensor, QuantState]: - raise NotImplementedError("Not yet implemented for CPU backend") + if blocksize is None: + blocksize = 64 + assert_on_cpu([A, absmax, out]) + assert quant_storage == torch.uint8, "CPU backend only supports uint8 quant_storage" + return quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type) def dequantize_4bit( self, @@ -143,7 +150,10 @@ def dequantize_4bit( blocksize: int = 64, quant_type: Literal["fp4", "nf4"] = "fp4", ) -> torch.Tensor: - raise NotImplementedError("Not yet implemented for CPU backend") + if blocksize is None: + blocksize = 64 + assert_on_cpu([A, absmax, out]) + return dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type) def gemv_4bit( self, @@ -154,7 +164,11 @@ def gemv_4bit( transposed_B=False, state: QuantState = None, ) -> torch.Tensor: - raise NotImplementedError("Not yet implemented for CPU backend") + assert_on_cpu([A, B, out]) + if state is None: + raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()") + + return gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state) def dequantize_blockwise( self, diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index f4e5ed3ec..396234853 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -1,7 +1,13 @@ +from typing import Optional import warnings import torch +from bitsandbytes.functional import ( + QuantState, + get_4bit_type, +) + try: # to support Intel CPU/GPU (XPU) backend import intel_extension_for_pytorch as ipex @@ -49,7 +55,7 @@ def _maybe_torch_compile(func): return func -# Don't use torch.compile for now due to PyTorch issue https://github.com/pytorch/pytorch/issues/124382 +@_maybe_torch_compile def double_quant_impl(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): """ Find absolute max values of each row/column of a tensor, and symmetrically quantize it to int8. @@ -228,3 +234,290 @@ def mm_dequant_impl( out = out + bias.to(compute_dtype) out = out.to(output_dtype) return out + + +NF4_QUANT_TABLE = [ + -1.0 - 1e-2, # 0b0000 + -0.8480964004993439, # 0b0001 + -0.6106329262256622, # 0b0010 + -0.4599952697753906, # 0b0011 + -0.33967943489551544, # 0b0100 + -0.23460740596055984, # 0b0101 + -0.13791173323988914, # 0b0110 + -0.045525018125772476, # 0b0111 + 0.03979014977812767, # 0b1000 + 0.1202552504837513, # 0b1001 + 0.2035212516784668, # 0b1010 + 0.2920137718319893, # 0b1011 + 0.3893125355243683, # 0b1100 + 0.5016634166240692, # 0b1101 + 0.6427869200706482, # 0b1110 + 0.8614784181118011, # 0b1111 +] + + +FP4_QUANT_TABLE = { + 0 - 1e-2: 0, # 0b0000 + 0.00260417: 1, # 0b0001 + 0.0859375: 6, # 0b0110 + 0.20833333: 7, # 0b0111 + 0.29166667: 4, # 0b0100 + 0.4166667: 5, # 0b0101 + 0.583333: 2, # 0b0010 + 0.8333333: 3, # 0b0011 +} + + +@_maybe_torch_compile +def quantize_4bit_impl( + A: Tensor, + absmax: Tensor = None, + out: Tensor = None, + blocksize=64, + compress_statistics=False, + quant_type="nf4", +) -> Tensor: + """ + Quantize tensor A in blocks of 4-bit values. + + Quantizes tensor A by dividing it into blocks which are independently quantized to FP4. + + Parameters + ---------- + A : torch.Tensor + The input tensor. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + The output tensor (8-bit). + blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4}, only nf4 is supported now + + Returns + ------- + torch.Tensor: + The 8-bit tensor with packed 4-bit values. + tuple(torch.Tensor, torch.Size, torch.dtype, int): + The quantization state to undo the quantization. + """ + if quant_type not in ["nf4", "fp4"]: + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU.") + if quant_type == "fp4": + warnings.warn("fp4 quantization is currently slow on CPU/XPU. Please Use nf4 instead for better performance.") + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + n = A.numel() + input_shape = A.shape + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + + if absmax is None: + absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) + + if out is None: + out = torch.zeros(((n + 1) // 2), dtype=torch.uint8, device=A.device) + + rem = n % blocksize + has_rem = rem > 0 + + # Scale tensor to [-1, 1] + A_reshaped = A.reshape(n) + A_com = A_reshaped[: n - rem] + A_com_reshaped = A_com.reshape(n // blocksize, blocksize) + absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] + scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1) + scaled_A = scaled_A.reshape(-1) + if has_rem: + absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() + scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) + scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) + # map [-1, 1] to nf4/fp4 + out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8) + if quant_type == "nf4": + for i in range(len(NF4_QUANT_TABLE)): + out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i + elif quant_type == "fp4": + sign = scaled_A < 0 + abs_scaled_A = torch.abs(scaled_A) + for key, val in FP4_QUANT_TABLE.items(): + out_uint8[abs_scaled_A > key] = val + out_uint8 += sign.to(torch.uint8) * 8 + if out_uint8.size(-1) % 2: + out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0) + out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2]) + + code = get_4bit_type(quant_type, device=A.device) + + if compress_statistics: + raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU") + else: + state = QuantState( + absmax=absmax, + shape=input_shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + ) + + if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and input_shape[1] % blocksize == 0 and quant_type == "nf4": + # lowp_mode: lowest precision for computation + lowp_mode = ipex_cpu.quantization.WoqLowpMode.BF16 + state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack( + out.reshape([input_shape[0], input_shape[1] // 2]), + ipex_cpu.quantization.WoqWeightDtype.NF4, + input_shape, # weight shape + absmax.view(input_shape[0], input_shape[1] // blocksize), # scales + None, # zero_points + None, # bias + None, # g_idx + None, # batch_size + blocksize, + int(lowp_mode), + -1, # act_quant_mode. -1 means don't quant activation + ) + state.absmax = torch.Tensor() + return torch.Tensor(), state + + return out, state + + +@_maybe_torch_compile +def dequantize_4bit_impl( + A: Tensor, + quant_state=None, + absmax: Tensor = None, + out: Tensor = None, + blocksize: int = 64, + quant_type="nf4", +) -> Tensor: + """ + Dequantizes FP4 blockwise quantized values. + + Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize. + + Parameters + ---------- + A : torch.Tensor + The input 8-bit tensor (packed 4-bit values). + quant_state : QuantState + object with quantisation stats, incl. absmax values, original tensor shape and original dtype. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + Dequantized output tensor. + blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4}, only nf4 is supported now + + + Returns + ------- + torch.Tensor: + Dequantized tensor. + """ + + if quant_state is None: + assert absmax is not None and out is not None + + quant_state = QuantState( + absmax=absmax, + shape=out.shape, + dtype=out.dtype, + blocksize=blocksize, + quant_type=quant_type, + ) + + else: + absmax = quant_state.absmax + + if quant_type not in ["nf4", "fp4"]: + raise NotImplementedError( + f"4-bit quantization data type {quant_state.quant_type} is not implemented for CPU/XPU." + ) + + if quant_state.nested: + raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU") + + if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(quant_state, "op_context"): + assert quant_state.op_context is not None + A = quant_state.op_context.to_public(quant_state.op_context.get_weight()) + A = A.reshape(-1) + absmax = quant_state.op_context.get_scales().reshape(-1) + + if out is None: + out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) + + n = out.numel() + # Map nf4 to [-1, 1] + out_uint8 = torch.empty(A.size(0) * 2, dtype=torch.uint8, device=A.device) + out_uint8[::2] = A.bitwise_and(0xF) + out_uint8[1::2] = A.bitwise_right_shift(4) + out_dq = torch.empty(out_uint8.shape).to(quant_state.dtype) + for i in range(len(quant_state.code)): + out_dq[out_uint8 == i] = quant_state.code[i] + + # Apply scales + if out_dq.numel() != n: + assert out_dq.numel() == n + 1 + out_dq = torch.narrow(out_dq, 0, 0, n) + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + rem = n % blocksize + has_rem = rem > 0 + out_reshaped = out.reshape(-1) + out_reshaped[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape( + -1 + ) + if has_rem: + out_reshaped[n - rem :] = out_dq[n - rem :] * absmax[-1] + + # take transpose here because weight is transposed (again) for computation + return out.t() + + +# Do not need torch.compile here as we are calling torch/ipex kernel +def gemm_4bit_impl( + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + transposed_A=False, + transposed_B=False, + state: QuantState = None, +) -> torch.Tensor: + """ + Matrix-matrix multiplication with 4-bit quantization. + + Parameters + ---------- + A : torch.Tensor + The first input tensor. Usually the activation tensor. + B : torch.Tensor + The second input tensor. Usually the weight tensor. + out : torch.Tensor + The output tensor. + transposed_A : bool + Whether A is transposed + transposed_B : bool + Whether B is transposed + state : QuantState + Contains quantization info, such as blocksize and dtype + + Returns + ------- + torch.Tensor: + GEMM output tensor. + """ + if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(state, "op_context"): + assert state.op_context is not None + output = torch.ops.torch_ipex.ipex_woq_linear(A, state.op_context.get_data_handle()) + else: + dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize) + output = torch.matmul(A, dqB) + if out is not None: + out.copy_(output) + else: + out = output + return out diff --git a/bitsandbytes/backends/npu.py b/bitsandbytes/backends/npu.py new file mode 100644 index 000000000..1b3cb57d6 --- /dev/null +++ b/bitsandbytes/backends/npu.py @@ -0,0 +1,170 @@ +from typing import Literal, Optional, Tuple, Union + +import torch + +from bitsandbytes.utils import QuantState + +from .base import Backend + +try: + # to support Ascend NPU backend + import torch_npu # noqa: F401 +except ImportError: + pass + + +class NPUBackend(Backend): + def double_quant( + self, + A: torch.Tensor, + col_stats: Optional[torch.Tensor] = None, + row_stats: Optional[torch.Tensor] = None, + out_col: Optional[torch.Tensor] = None, + out_row: Optional[torch.Tensor] = None, + threshold=0.0, + ): + raise NotImplementedError + + def transform( + self, + A: torch.Tensor, + to_order: str, + from_order="row", + out: Optional[torch.Tensor] = None, + transpose=False, + state: Optional[Tuple[torch.Size, str]] = None, + ld=None, + ): + raise NotImplementedError + + def igemmlt( + self, + A: torch.Tensor, + B: torch.Tensor, + SA: Tuple[torch.Size, str], + SB: Tuple[torch.Size, str], + out: Optional[torch.Tensor] = None, + Sout: Optional[Tuple[torch.Size, str]] = None, + dtype=torch.int32, + ) -> Union[torch.Tensor, Tuple[Optional[Tuple[torch.Tensor, Tuple[torch.Size, str]]]]]: + raise NotImplementedError + + def mm_dequant( + self, + A: torch.Tensor, + quant_state: Tuple[torch.Size, str], + row_stats: torch.Tensor, + col_stats: torch.Tensor, + out: Optional[torch.Tensor] = None, + new_row_stats: Optional[torch.Tensor] = None, + new_col_stats: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + def extract_outliers( + self, + A: torch.Tensor, + SA: Tuple[torch.Size, str], + idx: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + def quantize_4bit( + self, + A: torch.Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=64, + compress_statistics=False, + quant_type: Literal["fp4", "nf4"] = "fp4", + quant_storage=torch.uint8, + ) -> Tuple[torch.Tensor, QuantState]: + raise NotImplementedError + + def dequantize_4bit( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 64, + quant_type: Literal["fp4", "nf4"] = "fp4", + ) -> torch.Tensor: + raise NotImplementedError + + def gemv_4bit( + self, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + transposed_A=False, + transposed_B=False, + state: QuantState = None, + ) -> torch.Tensor: + raise NotImplementedError + + def dequantize_blockwise( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 4096, + nested=False, + ) -> torch.Tensor: + raise NotImplementedError + + def quantize_blockwise( + self, + A: torch.Tensor, + code: Optional[torch.Tensor] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=4096, + nested=False, + ) -> Tuple[torch.Tensor, QuantState]: + raise NotImplementedError + + def optimizer_update_8bit_blockwise( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, + ) -> None: + raise NotImplementedError + + def optimizer_update_32bit( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + beta1: float, + eps: float, + step: int, + lr: float, + state2: Optional[torch.Tensor] = None, + beta2: float = 0.0, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Optional[torch.Tensor] = None, + max_unorm: float = 0.0, + skip_zeros=False, + ) -> None: + raise NotImplementedError diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 79b31f51f..7ab070785 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -289,7 +289,7 @@ def from_prequantized( return self def _quantize(self, device): - w = self.data.contiguous().cuda(device) + w = self.data.contiguous().to(device) w_4bit, quant_state = bnb.functional.quantize_4bit( w, blocksize=self.blocksize, @@ -307,6 +307,9 @@ def _quantize(self, device): def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False): return self.to(device="cuda" if device is None else device, non_blocking=non_blocking) + def cpu(self, non_blocking: bool = False): + return self.to(device="cpu", non_blocking=non_blocking) + @overload def to( self: T, @@ -324,7 +327,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ... def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if device is not None and device.type == "cuda" and not self.bnb_quantized: + if device is not None and device.type in ["cuda", "cpu"] and not self.bnb_quantized: return self._quantize(device) else: if self.quant_state is not None: diff --git a/docs/source/fsdp_qlora.md b/docs/source/fsdp_qlora.md index 47922cfcc..11e169ffb 100644 --- a/docs/source/fsdp_qlora.md +++ b/docs/source/fsdp_qlora.md @@ -9,25 +9,40 @@ This guide provides a brief guide on how bitsandbytes supports storing quantized ## Quantized data storage -FSDP only supports sharding float data types which can be problematic because quantized weights are typically stored as integer data types (uint8). bitsandbytes doesn't have this problem because it uses `StoreChar` to read and write quantized weights regardless of the data type storage. This makes it simple to add a `quant_storage` parameter to the [`~nn.Linear4bit`] and [`~nn.Params4bit`] classes and set it to `torch.uint8` to maintain backward compatibility with the codebase. +FSDP only supports sharding float data types which can be problematic because quantized weights are typically stored as integer data types (uint8). bitsandbytes doesn't have this problem because it uses `StoreChar` to read and write quantized weights regardless of the data type storage. This makes it simple to add a `quant_storage` parameter to the [`~nn.Linear4bit`] and [`~nn.Params4bit`] classes and set it to `torch.uint8` to maintain backward compatibility with the codebase. With the `quant_storage` parameter, you can select any of the FSDP supported data types to shard [`~nn.Linear4bit`] with such as bfloat16, float16 or float32. + +You'll typically access and configure this option from [`transformers.BitsAndBytesConfig`] by setting the `bnb_4bit_quant_storage` parameter. It is very **important** the `quant_storage` data type matches the data types used throughout the model because FSDP can only wrap layers and modules that have the *same floating data type*. Making sure the data types are aligned will ensure the model is correctly sharded. + +> [!TIP] +> The `compute_dtype` is the data type used for computation inside the CUDA kernel, where the 4-bit quantized weights are unpacked from the data type in `quant_storage` and dequantized to `compute_dtype`. We recommend using torch.bfloat16 (if available on your hardware) for better numerical stability. ```py -import torch -import bitsandbytes as bnb - -model = bnb.nn.Linear4bit( - input_features, - output_features, - quant_type="fp4", - quant_storage=torch.uint8, +from transformers import BitsAndBytesConfig, AutoModelForCausalLM + +bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_quant_storage=torch.bfloat16, +) + +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-70b", + quantization_config=bnb_config, + torch_dtype=torch.bfloat16, ) ``` -With the `quant_storage` parameter, you can select any of the FSDP supported data types to shard [`~nn.Linear4bit`] with such as bfloat16, float16 or float32. +Check out this [section](https://hf.co/docs/peft/main/en/accelerate/fsdp#use-peft-qlora-and-fsdp-for-finetuning-large-models-on-multiple-gpus) of the PEFT documentation for the config file and training code to run FSDP-QLoRA training. ## Training -bitsandbytes is deeply integrated with the Hugging Face ecosystem, making it easy to use with libraries like [Transformers](https://hf/co/docs/transformers), [PEFT](https://hf/co/docs/peft), and [TRL](https://hf/co/docs/trl). +> [!TIP] +> FSDP is a distributed training framework that needs to be launched as a distributed training job with a library like [Accelerate](https://hf.co/docs/accelerate/index) or [torchrun](https://pytorch.org/docs/stable/elastic/run.html). The launch command provided in this section uses Accelerate to launch the training script. + +bitsandbytes is deeply integrated with the Hugging Face ecosystem, making it easy to use with libraries like [Transformers](https://hf.co/docs/transformers), [PEFT](https://hf.co/docs/peft), and [TRL](https://hf.co/docs/trl). + +PEFT provides a configuration file ([fsdp_config_qlora.yaml](https://github.com/huggingface/peft/blob/main/examples/sft/configs/fsdp_config_qlora.yaml)), launch command ([run_peft_qlora_fsdp.sh](https://github.com/huggingface/peft/blob/main/examples/sft/run_peft_qlora_fsdp.sh)), and training script ([train.py](https://github.com/huggingface/peft/blob/main/examples/sft/train.py)) for running FSDP-QLoRA. To learn more, check out the [Use PEFT QLoRA and FSDP for finetuning large models on multiple GPUs](https://huggingface.co/docs/peft/main/en/accelerate/fsdp#use-peft-qlora-and-fsdp-for-finetuning-large-models-on-multiple-gpus) documentation. This section briefly covers the steps to run FSDP-QLoRA training. Before you begin, make sure you have the latest libraries installed. @@ -35,9 +50,6 @@ Before you begin, make sure you have the latest libraries installed. pip install -U bitsandbytes accelerate transformers peft trl ``` -> [!TIP] -> PEFT provides a configuration file ([fsdp_config_qlora.yaml](https://github.com/huggingface/peft/blob/main/examples/sft/configs/fsdp_config_qlora.yaml)), launch command ([run_peft_qlora_fsdp.sh](https://github.com/huggingface/peft/blob/main/examples/sft/run_peft_qlora_fsdp.sh)), and training script ([train.py](https://github.com/huggingface/peft/blob/main/examples/sft/train.py)) for FSDP-QLoRA. To learn more, check out the [Use PEFT QLoRA and FSDP for finetuning large models on multiple GPUs](https://huggingface.co/docs/peft/main/en/accelerate/fsdp#use-peft-qlora-and-fsdp-for-finetuning-large-models-on-multiple-gpus) documentation. - The important change that enables FSDP-QLoRA training is the `bnb_4bit_quant_storage` parameter in the [`~transformers.BitsAndBytesConfig`] class. This allows you to set the storage data type of the quantized weights to a float data type. ```py diff --git a/requirements-ci.txt b/requirements-ci.txt index 24e2db324..0e9dd2407 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -1,6 +1,6 @@ # Requirements used for GitHub actions -pytest==8.2.0 +pytest==8.2.1 einops==0.8.0 lion-pytorch==0.1.4 scipy==1.10.1; python_version < "3.9" -scipy==1.13.0; python_version >= "3.9" +scipy==1.13.1; python_version >= "3.9" diff --git a/requirements-dev.txt b/requirements-dev.txt index 0334896be..de7adce94 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,9 +1,9 @@ # Requirements used for local development setuptools>=63 -pytest~=8.2.0 +pytest~=8.2.1 einops~=0.8.0 wheel~=0.43.0 lion-pytorch~=0.1.4 -scipy~=1.13.0 +scipy~=1.13.1 pandas~=2.2.2 -matplotlib~=3.8.4 +matplotlib~=3.9.0 diff --git a/tests/test_functional.py b/tests/test_functional.py index 8ddee9f9a..4e82c530a 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2017,7 +2017,8 @@ def test_bench_dequantization(): @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096]) -def test_4bit_quant(dtype, quant_type, blocksize): +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +def test_4bit_quant(dtype, quant_type, blocksize, device): vals = list(product([0, 1], repeat=4)) code = {} @@ -2041,9 +2042,11 @@ def test_4bit_quant(dtype, quant_type, blocksize): result = sign * exp * frac code[idx] = result - A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype) + A1 = torch.randn(1024, 1024, device=device, dtype=dtype) qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) + if device == "cpu": + A2 = A2.t() err = (A1 - A2).abs().float() relerr = (err / (A1.abs().float() + 1e-8)).mean() @@ -2297,6 +2300,49 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): assert maxratio < 1.02 and maxratio > 0.98 +@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"]) +@pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) +def test_gemv_4bit_cpu(dtype, quant_type, kind): + """ + Test 4bit GEMV for CPU. It is simplified a lot from the cuda version, since + the CPU backend does not support double_quant or quant_storage other than uint8. + Also, the CPU backend has different numeric accuracy from that of CUDA + """ + for dim in [128, 256, 512, 1024]: + for i in range(10): + if kind == "fc1": + A = torch.randn(1, dim, dtype=dtype, device="cpu") + B = torch.randn(dim * 4, dim, dtype=dtype, device="cpu") / math.sqrt(dim) + elif kind == "fc2": + A = torch.randn(1, 4 * dim, dtype=dtype, device="cpu") + B = torch.randn(dim, 4 * dim, dtype=dtype, device="cpu") / math.sqrt(dim) + elif kind == "attn": + A = torch.randn(1, dim, dtype=dtype, device="cpu") + B = torch.randn(dim, dim, dtype=dtype, device="cpu") / math.sqrt(dim) + elif kind == "attn_packed": + A = torch.randn(1, dim, dtype=dtype, device="cpu") + B = torch.randn(dim * 3, dim, dtype=dtype, device="cpu") / math.sqrt(dim) + + qB, state = F.quantize_4bit( + B, + quant_type=quant_type, + compress_statistics=False, + quant_storage=torch.uint8, + ) + dqB = F.dequantize_4bit(qB, state) + C3 = torch.matmul(A, dqB) + C2 = F.gemv_4bit(A, qB.t(), state=state) + A.requires_grad = True + C1 = bnb.matmul_4bit(A, qB.t(), state) + + c = int(C1.numel() * 0.0014 * (dim / 256)) + 1 + rtol = 1e-3 if dtype != torch.bfloat16 else 1e-2 + atol = 1e-2 if dtype != torch.bfloat16 else 5e-2 + assert_all_approx_close(C1, C2, rtol, atol, count=c) + assert_all_approx_close(C3, C2, rtol, atol, count=c) + + @pytest.mark.skip("Row scale has some bugs for ampere") def test_managed(): n = 32 * 10