[BUG] Fix H100 crash/compat with Marlin #654

Merged · 6 commits · Jun 27, 2024
auto_gptq/modeling/_base.py (3 additions, 7 deletions)

@@ -833,13 +833,9 @@ def from_quantized(
             # format marlin requires marlin kernel
             use_marlin = True
 
-        marlin_compatible, marlin_optimized = _validate_marlin_device_support()
-        if use_marlin and (not MARLIN_AVAILABLE or not marlin_compatible):
-            raise TypeError("use_marlin is true but Marlin is not availble due to cuda/device support.")
-        elif use_marlin and not marlin_optimized:
-            logger.info(
-                "use_marlin is true and your gpu device is supported but not optimized for Marlin."
-            )
+        marlin_compatible = _validate_marlin_device_support()
+        if use_marlin and not MARLIN_AVAILABLE:
+            raise TypeError("use_marlin is true but Marlin is not available due to cuda/device support.")
 
         if not use_marlin and MARLIN_AVAILABLE:
             unsupported_reason = _validate_marlin_compatibility(quantize_config)
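In the hunk above, the Hopper-specific "supported but not optimized" branch is removed: `from_quantized` now raises only when `MARLIN_AVAILABLE` is false, and device suitability is reduced to the single boolean returned by `_validate_marlin_device_support()`. A minimal usage sketch follows (illustrative only, not part of this diff; the model path and device string are placeholders, and it assumes the public `AutoGPTQForCausalLM` wrapper forwards `use_marlin` into this code path):

```python
# Illustrative sketch, not part of this PR: requesting the Marlin kernel via the
# public API. "path/to/quantized-model" and "cuda:0" are placeholders.
from auto_gptq import AutoGPTQForCausalLM

model = AutoGPTQForCausalLM.from_quantized(
    "path/to/quantized-model",
    device="cuda:0",
    use_marlin=True,  # with this change, errors out only if the Marlin kernel is unavailable
)
```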
auto_gptq/utils/marlin_utils.py (4 additions, 17 deletions)

@@ -87,28 +87,15 @@ def prepare_model_for_marlin_load(


 # Validate marlin support
-def _validate_marlin_device_support() -> Tuple[bool, bool]:
+def _validate_marlin_device_support() -> bool:
     """
-    Validates if the current device is compatible and optimized for Marlin.
+    Validates if the current device is compatible for Marlin.
     ref: https://github.com/IST-DASLab/marlin?tab=readme-ov-file#requirements
 
     Returns:
-        Tuple[bool, bool]: The first indicates if CUDA device is compatible for Marlin,
-            the second indicates if CUDA device is optimized for Marlin.
+        bool: indicates if CUDA device is compatible for Marlin
     """
-    supported = False
-    optimized = False
-
-    # >=hopper is compatible but not optimized
-    if torch.cuda.get_device_capability()[0] >= 9:
-        supported = True
-        optimized = False
-    # ampere and ada are supported and optimized
-    elif torch.cuda.get_device_capability()[0] >= 8:
-        supported = True
-        optimized = True
-
-    return supported, optimized
+    return torch.cuda.get_device_capability()[0] >= 8
 
 
 # Adapted from https://github.com/rib-2/marlin/tree/conversion
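With the returned tuple collapsed to one boolean, any CUDA device whose compute capability major version is 8 or higher is treated as Marlin-compatible; that covers Ampere (A100 reports 8.0), Ada (RTX 4090 reports 8.9), and Hopper (H100 reports 9.0) without the previous "not optimized" distinction. A quick local sanity check (illustrative only, not part of this diff):

```python
# Illustrative sanity check: what the simplified capability gate evaluates to
# on the current GPU.
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    # A100 -> 8.0, RTX 4090 -> 8.9, H100 -> 9.0; all satisfy major >= 8
    print(f"compute capability {major}.{minor} -> Marlin compatible: {major >= 8}")
else:
    print("No CUDA device visible; Marlin needs a GPU with compute capability >= 8.0")
```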
autogptq_extension/marlin/marlin_cuda_kernel.cu (8 additions, 14 deletions)

@@ -68,20 +68,14 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool
 #endif
 }
 
-// Asynchronous global->shared copy with a chache hint indicating that the values may be evicted immediately; used for
-// quantized weights B, which are only accessed precisely once and should thus not pollute the L2 cache which we need
-// for inputs A and outputs C.
-__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
+// Asynchronous global->shared copy
+__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
   const int BYTES = 16;
   uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-  asm volatile(
-    "{\n"
-    "   .reg .b64 p;\n"
-    "   createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
-    "   cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
-    "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES)
-  );
+  asm volatile("{\n"
+               "   cp.async.cg.shared.global [%0], [%1], %2;\n"
+               "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES));
 #else
   assert(0);
 #endif
@@ -431,14 +425,14 @@ __global__ void Marlin(
       int4* sh_b_stage = sh_b + b_sh_stage * pipe;
       #pragma unroll
       for (int i = 0; i < b_sh_wr_iters; i++) {
-        cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
+        cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
         B_ptr[i] += b_gl_rd_delta_o;
       }
       // Only fetch scales if this tile starts a new group
       if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
         int4* sh_s_stage = sh_s + s_sh_stage * pipe;
         if (s_sh_wr_pred)
-          cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
+          cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
         s_gl_rd += s_gl_rd_delta;
       }
     }
@@ -692,7 +686,7 @@ __global__ void Marlin(
     // For per-column scales, we only fetch them here in the final step before write-out
     if (group_blocks == -1 && last) {
       if (s_sh_wr_pred)
-        cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
+        cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
       cp_async_fence();
     }
     thread_block_reduce();
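On the kernel side, `cp_async4_stream` — which routed the copy through a `createpolicy.fractional.L2::evict_first` policy and the `cp.async.cg.shared.global.L2::cache_hint` form — is replaced by a plain `cp_async4` that issues `cp.async.cg.shared.global`, and all three call sites are updated to match. This gives up the L2 evict-first optimization described in the removed comment, but it avoids the cache-hint PTX path that, per the PR title, appears to be behind the H100 crash/compatibility problem; presumably a small performance trade-off in exchange for running correctly on Hopper.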