1 change: 1 addition & 0 deletions gptqmodel/nn_modules/qlinear/marlin.py
@@ -305,6 +305,7 @@ def forward(self, x: torch.Tensor):
is_k_full=self.is_k_full,
bias=self.bias,
use_fp32_reduce=self.fp32,
use_atomics=False, # slightly faster but reduces accuracy
)

if self.adapter:
31 changes: 10 additions & 21 deletions gptqmodel/utils/marlin.py
@@ -145,21 +145,13 @@ def get_scale_perms():
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return scale_perm, scale_perm_single


# Whether to use atomicAdd reduce in gptq/awq marlin kernel. experimental
GPTQMODEL_MARLIN_USE_ATOMIC_ADD = True


def maybe_warn_marlin_atomic_add_env():
if torch.compiler.is_dynamo_compiling():
return
if GPTQMODEL_MARLIN_USE_ATOMIC_ADD:
return
log.info_once(
"Marlin kernel can achieve better performance for small size_n "
"with experimental use_atomic_add feature. "
"You can consider set environment variable "
"GPTQMODEL_MARLIN_USE_ATOMIC_ADD to 1 if possible.")

# log.info_once(
# "Marlin kernel can achieve better performance for small size_n "
# "with experimental use_atomic_add feature.")


def maybe_warn_marlin_atomic_add(device, dtype):
@@ -180,12 +172,6 @@ def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
if n >= 2048 or k < 2048 or device.type != "cuda":
return False

# disable atomicAdd reduce by default,
# one can enable it with GPTQMODEL_MARLIN_USE_ATOMIC_ADD=1
if not GPTQMODEL_MARLIN_USE_ATOMIC_ADD:
maybe_warn_marlin_atomic_add_env()
return False

# sm8x doesn't support atomicAdd + bfloat16 natively
device_capability = torch.cuda.get_device_capability(device)
if device_capability[0] < 9 and dtype == torch.bfloat16:
@@ -208,11 +194,14 @@ def apply_gptq_marlin_linear(
input_size_per_partition: int,
is_k_full: bool,
bias: Optional[torch.Tensor] = None,
use_fp32_reduce: bool = True) -> torch.Tensor:
use_fp32_reduce: bool = True,
use_atomics: bool = False,

) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,)

use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
use_atomics = use_atomics and should_use_atomic_add_reduce(m=reshaped_x.size(0),
n=output_size_per_partition,
k=reshaped_x.size(1),
device=input.device,
@@ -233,7 +222,7 @@ def apply_gptq_marlin_linear(
size_n=output_size_per_partition,
size_k=input_size_per_partition,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_atomic_add=use_atomics,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)

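Net behavior after this change, as a minimal standalone sketch: atomicAdd reduction is now opt-in per call via the new use_atomics keyword (default False), and an opt-in is still gated by the shape/device heuristic in should_use_atomic_add_reduce. The function below is illustrative only; it condenses the checks visible in this diff, and the branches hidden by the collapsed hunks are assumed to be permissive.

    import torch

    def atomics_enabled(use_atomics: bool, n: int, k: int,
                        device: torch.device, dtype: torch.dtype) -> bool:
        # Caller must now opt in explicitly; use_atomics defaults to False.
        if not use_atomics:
            return False
        # Shape/device heuristic visible in this diff.
        if n >= 2048 or k < 2048 or device.type != "cuda":
            return False
        # sm8x has no native atomicAdd for bfloat16; assumed to fall back to False here.
        major, _ = torch.cuda.get_device_capability(device)
        if major < 9 and dtype == torch.bfloat16:
            return False
        # Remaining checks (e.g. on m) are collapsed in this diff and assumed permissive.
        return True

The result plays the role of the use_atomic_add argument passed to the kernel, so accuracy-sensitive callers simply leave use_atomics at its False default, as the forward() in marlin.py now does.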