Rework CUDA setup and diagnostics
akx committed Mar 8, 2024
1 parent 17681f6 commit 174f575
Showing 13 changed files with 435 additions and 595 deletions.
9 changes: 2 additions & 7 deletions bitsandbytes/__init__.py
@@ -3,7 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from . import cuda_setup, research, utils
+from . import research, utils
 from .autograd._functions import (
     MatmulLtState,
     bmm_cublas,
@@ -12,11 +12,8 @@
     matmul_cublas,
     mm_cublas,
 )
-from .cextension import COMPILED_WITH_CUDA
 from .nn import modules
-
-if COMPILED_WITH_CUDA:
-    from .optim import adam
+from .optim import adam

 __pdoc__ = {
     "libbitsandbytes": False,
@@ -25,5 +22,3 @@
 }

 __version__ = "0.44.0.dev"
-
-PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes"
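The net effect of this file's changes: the optimizer import is no longer gated on COMPILED_WITH_CUDA, and PACKAGE_GITHUB_URL moves to the new consts module. A quick sketch of the new import behavior (illustrative only, not part of the diff):

    # The gated import is gone: bitsandbytes.optim is importable even on a
    # CPU-only install. Whether CUDA kernels are actually usable is decided
    # later, when cextension loads the native library.
    from bitsandbytes.optim import adam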
108 changes: 2 additions & 106 deletions bitsandbytes/__main__.py
@@ -1,108 +1,4 @@
-import glob
-import os
-import sys
-from warnings import warn
-
-import torch
-
-HEADER_WIDTH = 60
-
-
-def find_dynamic_library(folder, filename):
-    for ext in ("so", "dll", "dylib"):
-        yield from glob.glob(os.path.join(folder, "**", filename + ext))
-
-
-def generate_bug_report_information():
-    print_header("")
-    print_header("BUG REPORT INFORMATION")
-    print_header("")
-    print('')
-
-    path_sources = [
-        ("ANACONDA CUDA PATHS", os.environ.get("CONDA_PREFIX")),
-        ("/usr/local CUDA PATHS", "/usr/local"),
-        ("CUDA PATHS", os.environ.get("CUDA_PATH")),
-        ("WORKING DIRECTORY CUDA PATHS", os.getcwd()),
-    ]
-    try:
-        ld_library_path = os.environ.get("LD_LIBRARY_PATH")
-        if ld_library_path:
-            for path in set(ld_library_path.strip().split(os.pathsep)):
-                path_sources.append((f"LD_LIBRARY_PATH {path} CUDA PATHS", path))
-    except Exception as e:
-        print(f"Could not parse LD_LIBRARY_PATH: {e}")
-
-    for name, path in path_sources:
-        if path and os.path.isdir(path):
-            print_header(name)
-            print(list(find_dynamic_library(path, '*cuda*')))
-            print("")
-
-
-def print_header(
-    txt: str, width: int = HEADER_WIDTH, filler: str = "+"
-) -> None:
-    txt = f" {txt} " if txt else ""
-    print(txt.center(width, filler))
-
-
-def print_debug_info() -> None:
-    from . import PACKAGE_GITHUB_URL
-    print(
-        "\nAbove we output some debug information. Please provide this info when "
-        f"creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose ...\n"
-    )
-
-
-def main():
-    generate_bug_report_information()
-
-    from . import COMPILED_WITH_CUDA
-    from .cuda_setup.main import get_compute_capabilities
-
-    print_header("OTHER")
-    print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}")
-    print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities()}")
-    print_header("")
-    print_header("DEBUG INFO END")
-    print_header("")
-    print("Checking that the library is importable and CUDA is callable...")
-    print("\nWARNING: Please be sure to sanitize sensitive info from any such env vars!\n")
-
-    try:
-        from bitsandbytes.optim import Adam
-
-        p = torch.nn.Parameter(torch.rand(10, 10).cuda())
-        a = torch.rand(10, 10).cuda()
-
-        p1 = p.data.sum().item()
-
-        adam = Adam([p])
-
-        out = a * p
-        loss = out.sum()
-        loss.backward()
-        adam.step()
-
-        p2 = p.data.sum().item()
-
-        assert p1 != p2
-        print("SUCCESS!")
-        print("Installation was successful!")
-    except ImportError:
-        print()
-        warn(
-            f"WARNING: {__package__} is currently running as CPU-only!\n"
-            "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n"
-            "If you think this is in error, please report an issue!"
-        )
-        print_debug_info()
-    except Exception as e:
-        print(e)
-        print_debug_info()
-        sys.exit(1)
-
-
 if __name__ == "__main__":
+    from bitsandbytes.diagnostics.main import main
+
     main()
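The bug-report dump and optimizer smoke test deleted here presumably live on in the new bitsandbytes.diagnostics package that __main__ now delegates to. For reference, a standalone equivalent of the removed smoke test (a sketch that assumes a CUDA device and a working CUDA build):

    import torch
    from bitsandbytes.optim import Adam

    p = torch.nn.Parameter(torch.rand(10, 10).cuda())
    before = p.data.sum().item()

    adam = Adam([p])
    (torch.rand(10, 10).cuda() * p).sum().backward()
    adam.step()

    assert p.data.sum().item() != before, "optimizer step did not update the parameter"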
148 changes: 116 additions & 32 deletions bitsandbytes/cextension.py
@@ -1,39 +1,123 @@
"""
extract factors the build is dependent on:
[X] compute capability
[ ] TODO: Q - What if we have multiple GPUs of different makes?
- CUDA version
- Software:
- CPU-only: only CPU quantization functions (no optimizer, no matrix multiple)
- CuBLAS-LT: full-build 8-bit optimizer
- no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)
evaluation:
- if paths faulty, return meaningful error
- else:
- determine CUDA version
- determine capabilities
- based on that set the default path
"""

import ctypes as ct
from warnings import warn
import logging
import os
from pathlib import Path

import torch

from bitsandbytes.cuda_setup.main import CUDASetup
from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR
from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs

logger = logging.getLogger(__name__)


def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
"""
Get the disk path to the CUDA BNB native library specified by the
given CUDA specs, taking into account the `BNB_CUDA_VERSION` override environment variable.
The library is not guaranteed to exist at the returned path.
"""
library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}"
if not cuda_specs.has_cublaslt:
# if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt
library_name += "_nocublaslt"
library_name = f"{library_name}{DYNAMIC_LIBRARY_SUFFIX}"

override_value = os.environ.get("BNB_CUDA_VERSION")
if override_value:
binary_name_stem, _, binary_name_ext = library_name.rpartition(".")
# `binary_name_stem` will now be e.g. `libbitsandbytes_cuda118`;
# let's remove any trailing numbers:
binary_name_stem = binary_name_stem.rstrip("0123456789")
# `binary_name_stem` will now be e.g. `libbitsandbytes_cuda`;
# let's tack the new version number and the original extension back on.
binary_name = f"{binary_name_stem}{override_value}.{binary_name_ext}"
logger.warning(
f'WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {binary_name}.\n'
'This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n'
'If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n'
'If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n'
'For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64\n'
)
library_name = None

return PACKAGE_DIR / library_name
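The override branch above can be sanity-checked in isolation; a minimal sketch (file name and versions are made-up values, assuming a Linux CUDA 11.8 build with BNB_CUDA_VERSION=122):

    # Mirrors the rpartition/rstrip steps in get_cuda_bnb_library_path.
    name = "libbitsandbytes_cuda118.so"
    stem, _, ext = name.rpartition(".")  # "libbitsandbytes_cuda118", ".", "so"
    stem = stem.rstrip("0123456789")     # "libbitsandbytes_cuda"
    print(f"{stem}122.{ext}")            # libbitsandbytes_cuda122.so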


+class BNBNativeLibrary:
+    _lib: ct.CDLL
+    compiled_with_cuda = False
+
+    def __init__(self, lib: ct.CDLL):
+        self._lib = lib
+
+    def __getattr__(self, item):
+        return getattr(self._lib, item)
+
+
+class CudaBNBNativeLibrary(BNBNativeLibrary):
+    compiled_with_cuda = True
+
+    def __init__(self, lib: ct.CDLL):
+        super().__init__(lib)
+        lib.get_context.restype = ct.c_void_p
+        lib.get_cusparse.restype = ct.c_void_p
+        lib.cget_managed_ptr.restype = ct.c_void_p
+
+
+def get_native_library() -> BNBNativeLibrary:
+    binary_path = PACKAGE_DIR / f"libbitsandbytes_cpu{DYNAMIC_LIBRARY_SUFFIX}"
+    cuda_specs = get_cuda_specs()
+    if cuda_specs:
+        cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)
+        if cuda_binary_path.exists():
+            binary_path = cuda_binary_path
+        else:
+            logger.warning("Could not find the bitsandbytes CUDA binary at %r", cuda_binary_path)
+    logger.debug(f"Loading bitsandbytes native library from: {binary_path}")
+    dll = ct.cdll.LoadLibrary(str(binary_path))
+
+    if hasattr(dll, "get_context"):  # only a CUDA-built library exposes this
+        return CudaBNBNativeLibrary(dll)
+
+    logger.warning(
+        "The installed version of bitsandbytes was compiled without GPU support. "
+        "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable."
+    )
+    return BNBNativeLibrary(dll)
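Downstream call sites can treat both wrappers uniformly; a rough usage sketch (not part of the diff):

    # Attribute access is forwarded to the underlying ctypes CDLL by
    # __getattr__, so callers hold one proxy type and branch on the flag.
    lib = get_native_library()
    if lib.compiled_with_cuda:
        ctx = lib.get_context()  # CUDA-only symbol; restype is already c_void_p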

-setup = CUDASetup.get_instance()
-if setup.initialized != True:
-    setup.run_cuda_setup()
-
-lib = setup.lib
 try:
-    if lib is None and torch.cuda.is_available():
-        CUDASetup.get_instance().generate_instructions()
-        CUDASetup.get_instance().print_log_stack()
-        raise RuntimeError('''
-CUDA Setup failed despite GPU being available. Please run the following command to get more information:
-python -m bitsandbytes
-Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
-to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
-and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''')
-    _ = lib.cadam32bit_grad_fp32  # raises an error if the library could not be found -> COMPILED_WITH_CUDA=False
-    lib.get_context.restype = ct.c_void_p
-    lib.get_cusparse.restype = ct.c_void_p
-    lib.cget_managed_ptr.restype = ct.c_void_p
-    COMPILED_WITH_CUDA = True
-except AttributeError as ex:
-    warn("The installed version of bitsandbytes was compiled without GPU support. "
-         "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.")
-    COMPILED_WITH_CUDA = False
-    print(str(ex))
-
-
-# print the setup details after checking for errors so we do not print twice
-#if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
-#setup.print_log_stack()
+    lib = get_native_library()
+except Exception as e:
+    lib = None
+    logger.error(f"Could not load bitsandbytes native library: {e}", exc_info=True)
+    if torch.cuda.is_available():
+        logger.warning(
+            """
+CUDA Setup failed despite CUDA being available. Please run the following command to get more information:
+python -m bitsandbytes
+Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
+to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
+and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues
+"""
+        )
12 changes: 12 additions & 0 deletions bitsandbytes/consts.py
@@ -0,0 +1,12 @@
+from pathlib import Path
+import platform
+
+DYNAMIC_LIBRARY_SUFFIX = {
+    'Darwin': '.dylib',
+    'Linux': '.so',
+    'Windows': '.dll',
+}.get(platform.system(), '.so')
+
+PACKAGE_DIR = Path(__file__).parent
+PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes"
+NONPYTORCH_DOC_URL = "https://github.com/TimDettmers/bitsandbytes/blob/main/docs/source/nonpytorchcuda.mdx"
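These constants combine into the platform-specific binary names used by cextension.py; a quick illustration (the CUDA version shown is hypothetical):

    from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR

    # On Linux this resolves to .../bitsandbytes/libbitsandbytes_cuda118.so
    print(PACKAGE_DIR / f"libbitsandbytes_cuda118{DYNAMIC_LIBRARY_SUFFIX}")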
53 changes: 0 additions & 53 deletions bitsandbytes/cuda_setup/env_vars.py

This file was deleted.
