Merge pull request #354 from PanQiWei/revert-325-main
Reverts #325 because it may break the exllama kernels.
PanQiWei committed Sep 27, 2023
2 parents ac23d6b + 3de7fbb commit 3b81fb5
Showing 14 changed files with 325 additions and 237 deletions.
23 changes: 1 addition & 22 deletions auto_gptq/modeling/_base.py
@@ -39,7 +39,7 @@ class BaseQuantizeConfig(PushToHubMixin):
damp_percent: float = field(default=0.01)
desc_act: bool = field(default=True)
static_groups: bool = field(default=False)
sym: bool = field(default=False)
sym: bool = field(default=True)
true_sequential: bool = field(default=True)
model_name_or_path: Optional[str] = field(default=None)
model_file_base_name: Optional[str] = field(default=None)
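
For context (an illustration, not part of the diff): with the default restored to sym=True, asymmetric quantization has to be requested explicitly when building a quantize config.

# Minimal sketch using the public auto_gptq API; not part of this commit.
from auto_gptq import BaseQuantizeConfig

cfg = BaseQuantizeConfig(bits=4, group_size=128, desc_act=True)   # sym defaults to True again
asym_cfg = BaseQuantizeConfig(bits=4, group_size=128, sym=False)  # opt in to asymmetric zero points
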
@@ -967,27 +967,6 @@ def skip(*args, **kwargs):
checkpoint
)
model.load_state_dict(checkpoint)
# Preprocessing for backward compatibility
if quantize_config.sym:
    QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, disable_exllama=disable_exllama, use_qigen=use_qigen,
                                                 desc_act=quantize_config.desc_act, group_size=quantize_config.group_size, bits=quantize_config.bits)
    for name, submodule in model.named_modules():
        if isinstance(submodule, QuantLinear):
            if use_qigen:
                submodule.zeros.data = torch.full_like(submodule.zeros.data, (torch.tensor(2 ** quantize_config.bits - 1) + 1) / 2)
            else:
                if quantize_config.bits == 2:
                    submodule.qzeros.data = torch.full_like(submodule.qzeros.data, -1431655766)
                elif quantize_config.bits == 3:
                    submodule.qzeros.data[:,range(0,submodule.qzeros.data.shape[1],3)] = 613566756
                    submodule.qzeros.data[:,range(1,submodule.qzeros.data.shape[1],3)] = 1227133513
                    submodule.qzeros.data[:,range(2,submodule.qzeros.data.shape[1],3)] = -1840700270
                elif quantize_config.bits == 4:
                    submodule.qzeros.data = torch.full_like(submodule.qzeros.data, -2004318072)
                elif quantize_config.bits == 8:
                    submodule.qzeros.data = torch.full_like(submodule.qzeros.data, -2139062144)
                else:
                    raise NotImplementedError("Only 2,3,4,8 bits are supported.")
# == step4: set seqlen == #
model_config = model.config.to_dict()
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
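
The backward-compatibility block deleted above fills qzeros with hard-coded int32 constants. An illustrative sketch (not part of the diff): those constants are simply the symmetric mid-range zero point 2**(bits - 1) replicated into every packed field of a 32-bit word and reinterpreted as a signed int32.

# Illustrative only: reproduce the int32 constants used in the removed block above.
def packed_mid_zero(bits: int) -> int:
    zero = 2 ** (bits - 1)          # symmetric mid-range zero point, e.g. 8 for 4-bit
    fields = 32 // bits             # packed fields per 32-bit word (10 for 3-bit, 2 bits unused)
    word = 0
    for i in range(fields):
        word |= zero << (i * bits)
    return word - (1 << 32) if word >= (1 << 31) else word  # reinterpret as signed int32

print(packed_mid_zero(2))  # -1431655766
print(packed_mid_zero(4))  # -2004318072
print(packed_mid_zero(8))  # -2139062144
print(packed_mid_zero(3))  # 613566756, the first of the three interleaved 3-bit words
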
17 changes: 7 additions & 10 deletions auto_gptq/nn_modules/fused_gptj_attn.py
@@ -8,8 +8,6 @@
from ._fused_base import FusedBaseAttentionModule
from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear

from logging import getLogger
logger = getLogger(__name__)

def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
dim = x.shape[-1]
@@ -242,13 +240,8 @@ def inject_to_model(
**kwargs
):
config = model.config

QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2)
if QuantLinear.QUANT_TYPE in ["exllama", "exllamav2"] and desc_act:
    # See fused_llama_attn.py comment
    logger.warning(f"Exllama kernel does not support query/key/value fusion with act-order. Because of this, Fused attention is automatically disabled.")
    return False


for name, m in model.named_modules():
    if not isinstance(m, GPTJAttention):
        continue
@@ -264,7 +257,11 @@ def inject_to_model(
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)

if QuantLinear.QUANT_TYPE == "exllama":
    g_idx = None
    if desc_act:
        # See fused_llama_attn.py comment
        raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
    else:
        g_idx = None
else:
    g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)

@@ -301,6 +298,6 @@ def inject_to_model(

setattr(parent, child_name, attn)
del m
return True


__all__ = ["FusedGPTJAttentionForQuantizedModel"]
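
A hedged usage sketch of the two workarounds named in the new error message (the model id below is a placeholder): either skip q/k/v fusion, or fall back from the exllama kernel when loading an act-order (desc_act=True) checkpoint.

# Sketch only; "TheBloke/some-model-GPTQ" is a placeholder repository id.
from auto_gptq import AutoGPTQForCausalLM

# Option 1: keep the exllama kernel but skip q/k/v fusion.
model = AutoGPTQForCausalLM.from_quantized(
    "TheBloke/some-model-GPTQ",
    inject_fused_attention=False,
)

# Option 2: keep fused attention and use the non-exllama CUDA kernel instead.
# model = AutoGPTQForCausalLM.from_quantized(
#     "TheBloke/some-model-GPTQ",
#     disable_exllama=True,
# )
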
20 changes: 8 additions & 12 deletions auto_gptq/nn_modules/fused_llama_attn.py
@@ -7,8 +7,6 @@
from ._fused_base import FusedBaseAttentionModule
from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear

from logging import getLogger
logger = getLogger(__name__)

class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -144,15 +142,8 @@ def inject_to_model(
"""
Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
"""

QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2)
if QuantLinear.QUANT_TYPE in ["exllama", "exllamav2"] and desc_act:
    # TODO: support it. The issue lies maybe in the line:
    # int groups = qzeros.size(0);
    # in exllama_ext.cpp
    logger.warning(f"Exllama kernel does not support query/key/value fusion with act-order. Because of this, Fused attention is automatically disabled.")
    return False


for name, m in model.named_modules():
    if not isinstance(m, LlamaAttention):
        continue
@@ -166,7 +157,13 @@ def inject_to_model(
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)

if QuantLinear.QUANT_TYPE == "exllama":
    g_idx = None
    if desc_act:
        # TODO: support it. The issue lies maybe in the line:
        # int groups = qzeros.size(0);
        # in exllama_ext.cpp
        raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
    else:
        g_idx = None
else:
    g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)

@@ -201,7 +198,6 @@ def inject_to_model(
child_name = name

setattr(parent, child_name, attn)
return True


__all__ = ["FusedLlamaAttentionForQuantizedModel"]
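
For readers unfamiliar with act-order, a short illustration (values invented, not from this repository) of why g_idx is the sticking point: without desc_act the group of a column is simply column // group_size, while act-order leaves an arbitrary permutation that the fused exllama path cannot consume, hence the raise above.

# Illustration only; the act-order permutation below is invented.
import torch

group_size = 4
in_features = 8

# Without act-order, groups are contiguous, so g_idx is implicit:
g_idx_plain = torch.arange(in_features) // group_size      # tensor([0, 0, 0, 0, 1, 1, 1, 1])

# With desc_act=True, columns were reordered by activation statistics during quantization,
# so every column carries an explicit, permuted group index:
g_idx_act_order = torch.tensor([1, 0, 1, 0, 0, 1, 0, 1])   # hypothetical
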
3 changes: 3 additions & 0 deletions auto_gptq/nn_modules/qlinear/qlinear_cuda.py
@@ -157,6 +157,7 @@ def pack(self, linear, scales, zeros, g_idx=None):
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)

zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
i = 0
@@ -220,6 +221,7 @@ def forward(self, x: torch.Tensor):
).to(torch.int16 if self.bits == 8 else torch.int8)
torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros)

zeros = zeros + 1
zeros = zeros.reshape(self.scales.shape)

weight = torch.bitwise_right_shift(
@@ -237,6 +239,7 @@ def forward(self, x: torch.Tensor):
zeros = zeros & 0x7
zeros = torch.cat([zeros[:, :, 0, :11], zeros[:, :, 1, 1:12], zeros[:, :, 2, 1:11]], dim=2)

zeros = zeros + 1
zeros = zeros.reshape(self.scales.shape)

weight = self.qweight.reshape(
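
A minimal round-trip sketch (illustrative, not repository code) of the convention these pack/forward changes restore: the true zero points are written to qzeros minus one, and the unpacking path adds the one back.

# Illustrative 4-bit round trip mirroring the restored "-1 on pack, +1 on unpack" convention.
bits = 4
mask = (1 << bits) - 1
zeros = [8, 8, 7, 9, 8, 8, 8, 8]                  # true zero points, one per group

# pack(): store (zero - 1), eight 4-bit fields per 32-bit word
qzeros_word = 0
for i, z in enumerate(zeros):
    qzeros_word |= ((z - 1) & mask) << (i * bits)

# forward(): unpack each field and add the 1 back
unpacked = [((qzeros_word >> (i * bits)) & mask) + 1 for i in range(len(zeros))]
assert unpacked == zeros
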
3 changes: 3 additions & 0 deletions auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py
@@ -157,6 +157,7 @@ def pack(self, linear, scales, zeros, g_idx):
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)

zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
i = 0
@@ -230,6 +231,7 @@ def forward(self, x):
zeros = torch.bitwise_right_shift(torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), self.wf.unsqueeze(0)).to(torch.int16 if self.bits == 8 else torch.int8)
torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros)

zeros = zeros + 1
zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])

scales = self.scales
@@ -246,6 +248,7 @@ def forward(self, x):
zeros = zeros & 0x7
zeros = torch.cat([zeros[:,:,0,:11], zeros[:,:,1,1:12], zeros[:,:,2,1:11]], dim=2)

zeros = zeros + 1
zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])

scales = self.scales
1 change: 1 addition & 0 deletions auto_gptq/nn_modules/qlinear/qlinear_exllama.py
@@ -146,6 +146,7 @@ def pack(self, linear, scales, zeros, g_idx=None):
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)

zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
i = 0
1 change: 1 addition & 0 deletions auto_gptq/nn_modules/qlinear/qlinear_triton.py
@@ -114,6 +114,7 @@ def pack(self, linear, scales, zeros, g_idx=None):
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)

zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
i = 0
2 changes: 2 additions & 0 deletions auto_gptq/nn_modules/triton_utils/kernels.py
@@ -144,6 +144,7 @@ def quant_matmul_248_kernel(
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)

zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1)

a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
@@ -289,6 +290,7 @@ def transpose_quant_matmul_248_kernel(
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)

zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1)

a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
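
To make the restored "zeros = (zeros + 1)" line concrete, a scalar plain-Python rendering (illustrative, not the actual Triton code) of the dequantization the kernel performs per element.

# Illustrative scalar version of the kernel's per-element dequantization.
bits = 4
maxq = (1 << bits) - 1

scale = 0.01                         # per-group scale
true_zero = 8                        # zero point chosen at quantization time
stored_zero = true_zero - 1          # what pack() wrote into qzeros
q = 11                               # a packed 4-bit weight value

zero = (stored_zero & maxq) + 1      # the restored "+ 1" step
w = (q - zero) * scale               # dequantized weight, here (11 - 8) * 0.01 ≈ 0.03
assert zero == true_zero
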