remove (zeros -= 1) #559

Closed · wants to merge 10 commits
72 changes: 69 additions & 3 deletions auto_gptq/modeling/_base.py
@@ -61,6 +61,7 @@
preprocess_checkpoint_qigen,
simple_dispatch_model,
unpack_awq,
convert_new_checkpoint_format,
)


@@ -90,6 +91,7 @@ class BaseQuantizeConfig(PushToHubMixin):
model_name_or_path: Optional[str] = field(default=None)
model_file_base_name: Optional[str] = field(default=None)
awq_gemm_checkpoint: Optional[bool] = field(default=False)
new_checkpoint_format: Optional[bool] = field(default=False)
Collaborator: Can we name this is_legacy_format, with the default being True when loading models that don't have the config attribute, and defaulting to False when quantizing new models?


def __post_init__(self):
fields_info = fields(self)
@@ -100,6 +102,9 @@ def __post_init__(self):
raise ValueError("unless equal to -1, group_size must be greater than 0.")
if not (0 < self.damp_percent < 1):
raise ValueError("damp_percent must be between 0 and 1.")
if not self.sym:
self.new_checkpoint_format = True
logger.warning("sym is False; using new_checkpoint_format because sym=False is not supported by the old checkpoint format.")

def save_pretrained(self, save_dir: str, **kwargs):
with open(join(save_dir, "quantize_config.json"), "w", encoding="utf-8") as f:
@@ -194,6 +199,7 @@ def to_dict(self):
"model_file_base_name": self.model_file_base_name,
"is_marlin_format": self.is_marlin_format,
"quant_method": "gptq",
"new_checkpoint_format": self.new_checkpoint_format,
Collaborator: Let's call this is_legacy_format.

}


@@ -216,6 +222,8 @@ def __init__(
injected_fused_attention: bool = False,
injected_fused_mlp: bool = False,
trainable: bool = False,
kerenl_backend_type: Optional[str] = None,
Collaborator: Maybe call this qlinear_kernel_name.

now_format: Optional[str] = None,
):
super().__init__()

@@ -229,6 +237,8 @@ def __init__(
self.injected_fused_attention = injected_fused_attention
self.injected_fused_mlp = injected_fused_mlp
self.trainable = trainable
self.kerenl_backend_type = kerenl_backend_type
self.now_format = now_format

@property
def quantized(self):
@@ -491,7 +501,7 @@ def tmp(_, inp, out):
layer_inputs, layer_outputs = layer_outputs, []
torch.cuda.empty_cache()

pack_model(
self.kerenl_backend_type = pack_model(
model=self.model,
quantizers=quantizers,
bits=self.quantize_config.bits,
Expand All @@ -502,12 +512,14 @@ def tmp(_, inp, out):
warmup_triton=autotune_warmup_after_quantized,
force_layer_back_to_cpu=force_layer_back_to_cpu,
)

if device_map:
self.model = remove_hook_from_module(self.model, recurse=True)
self.model = simple_dispatch_model(self.model, device_map)
self.model.config.use_cache = forward_pass_use_cache

self._quantized = True
self.now_format = "new"
Collaborator: Shouldn't this be handled in the config?


torch.cuda.empty_cache()

@@ -522,17 +534,40 @@ def device(self):
def to(self, device: Union[str, torch.device]):
self.model.to(device)
return self

def forward(self, *args, **kwargs):
if self.now_format == "old":
self.model = convert_new_checkpoint_format(
self.model,
True,
self.quantize_config,
self.kerenl_backend_type
)
self.now_format = "new"
Collaborator (on lines +538 to +545): This should be in the from_quantized method, not in forward or generate. Basically, there should be three cases:

  • is_legacy_format missing from the config or argument: assume is_legacy_format=True and update accordingly.
  • is_legacy_format=True (either in config or argument): update accordingly.
  • is_legacy_format=False: the new default, do nothing.
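A minimal sketch of that load-time handling as a hypothetical helper (not part of this diff): is_legacy_format follows the reviewer's naming suggestion, convert_new_checkpoint_format is the function added in _utils.py, and kerenl_backend_type is the QuantLinear class resolved in from_quantized.

from auto_gptq.modeling._utils import convert_new_checkpoint_format

def maybe_convert_legacy_checkpoint(model, quantize_config, kerenl_backend_type):
    # Decide once, at load time, whether the packed zero-points still use the
    # legacy "minus one" layout, instead of converting lazily in forward()/generate().
    legacy = getattr(quantize_config, "is_legacy_format", None)
    if legacy is None:
        legacy = True  # attribute missing from older configs: assume a legacy checkpoint
    if legacy:
        # shift the packed zero-points to the new layout; afterwards the model is
        # treated as new-format everywhere (forward, generate, save_quantized, ...)
        model = convert_new_checkpoint_format(model, True, quantize_config, kerenl_backend_type)
    return model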

Collaborator: As this is already in from_quantized, I am not sure why it is needed here as well.

return self.model(*args, **kwargs)

def generate(self, **kwargs):
"""shortcut for model.generate"""
if self.now_format == "old":
self.model = convert_new_checkpoint_format(
self.model,
True,
self.quantize_config,
self.kerenl_backend_type
)
self.now_format = "new"
with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
return self.model.generate(**kwargs)

def prepare_inputs_for_generation(self, *args, **kwargs):
"""shortcut for model.prepare_inputs_for_generation"""
if self.now_format == "old":
self.model = convert_new_checkpoint_format(
self.model,
True,
self.quantize_config,
self.kerenl_backend_type
)
self.now_format = "new"
return self.model.prepare_inputs_for_generation(*args, **kwargs)

def push_to_hub(
@@ -624,6 +659,17 @@ def save_quantized(

if not self.quantized:
raise EnvironmentError("can only save quantized model, please execute .quantize first.")
if self.quantize_config.new_checkpoint_format:
logger.warning("New checkpoint format is enabled; the saved model is not supported by older versions of AutoGPTQ (<= 0.7.0).")

if not self.quantize_config.new_checkpoint_format and self.now_format == "new":
self.model = convert_new_checkpoint_format(
self.model,
False,
self.quantize_config,
self.kerenl_backend_type
)
self.now_format = "old"
Collaborator (on lines +662 to +672): At this point, self.quantize_config.is_legacy_format should always be False (if we always convert to the new format).


self.model.to(CPU)

@@ -1299,6 +1345,25 @@ def skip(*args, **kwargs):
)
model.load_state_dict(checkpoint)

kerenl_backend_type = dynamically_import_QuantLinear(
use_triton=use_triton,
desc_act=quantize_config.desc_act,
group_size=quantize_config.group_size,
bits=quantize_config.bits,
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2,
use_qigen=use_qigen,
disable_marlin=not use_marlin,
)

if not quantize_config.new_checkpoint_format:
model = convert_new_checkpoint_format(
model,
True,
quantize_config,
kerenl_backend_type
)
Collaborator (on lines +1359 to +1365): Can you also add an is_legacy_format: Optional[bool] = None argument to from_quantized so it can be specified when it is not in the config, defaulting first to the quantization config value if any, then to the user-specified value if any, and otherwise to True?

For example, if somebody is quantizing models with an external library, the quantization config may not contain is_legacy_format.
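A small sketch of that precedence (illustrative only; the is_legacy_format config attribute and argument are the reviewer's proposal and do not exist in this diff):

def resolve_is_legacy_format(quantize_config, is_legacy_format=None):
    # 1) the quantization config wins when it carries the attribute
    config_value = getattr(quantize_config, "is_legacy_format", None)
    if config_value is not None:
        return config_value
    # 2) otherwise fall back to the value the caller passed to from_quantized
    if is_legacy_format is not None:
        return is_legacy_format
    # 3) otherwise assume an old checkpoint written before this change
    return True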


# == step4: set seqlen == #
model_config = model.config.to_dict()
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
@@ -1337,7 +1402,6 @@ def skip(*args, **kwargs):

# Any post-initialization that require device information, for example buffers initialization on device.
model = autogptq_post_init(model, use_act_order=quantize_config.desc_act)

model.eval()

# == step6: (optional) warmup triton == #
Expand Down Expand Up @@ -1370,6 +1434,8 @@ def skip(*args, **kwargs):
injected_fused_attention=inject_fused_attention,
injected_fused_mlp=inject_fused_mlp and use_triton,
trainable=trainable,
kerenl_backend_type=kerenl_backend_type,
now_format="new",
)

def warmup_triton(self, enabled: bool = True):
54 changes: 51 additions & 3 deletions auto_gptq/modeling/_utils.py
@@ -130,6 +130,55 @@ def make_quant(
)
new_layer.device = ori_layer_device
recurse_setattr(module, name, new_layer.to(ori_layer_device))

def convert_new_checkpoint_format(
model,
to_new_format,
quantize_config,
QuantLinear,
):
use_qigen = QuantLinear.QUANT_TYPE == "qigen"
use_marlin = QuantLinear.QUANT_TYPE == "marlin"

for name, submodule in model.named_modules():
if isinstance(submodule, QuantLinear):
if to_new_format:
if use_qigen:
submodule.zeros.data += 1
elif use_marlin:
pass
else:
if quantize_config.bits == 2:
submodule.qzeros.data += 0b01010101010101010101010101010101
elif quantize_config.bits == 3:
submodule.qzeros.data[:,range(0,submodule.qzeros.data.shape[1],3)] += 0b00100100100100100100100100100100
submodule.qzeros.data[:,range(1,submodule.qzeros.data.shape[1],3)] += 0b10010010010010010010010010010010
submodule.qzeros.data[:,range(2,submodule.qzeros.data.shape[1],3)] += 0b01001001001001001001001001001001
elif quantize_config.bits == 4:
submodule.qzeros.data += 0b00010001000100010001000100010001
elif quantize_config.bits == 8:
submodule.qzeros.data += 0b00000001000000010000000100000001
Collaborator (on lines +151 to +160): This does not check for overflows, which is an issue. We used to check for overflows in the kernels. 1111 will become 0000.

else:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
else:
if use_qigen:
submodule.zeros.data -= 1
elif use_marlin:
pass
else:
if quantize_config.bits == 2:
submodule.qzeros.data -= 0b01010101010101010101010101010101
elif quantize_config.bits == 3:
submodule.qzeros.data[:,range(0,submodule.qzeros.data.shape[1],3)] -= 0b00100100100100100100100100100100
submodule.qzeros.data[:,range(1,submodule.qzeros.data.shape[1],3)] -= 0b10010010010010010010010010010010
submodule.qzeros.data[:,range(2,submodule.qzeros.data.shape[1],3)] -= 0b01001001001001001001001001001001
elif quantize_config.bits == 4:
submodule.qzeros.data -= 0b00010001000100010001000100010001
elif quantize_config.bits == 8:
submodule.qzeros.data -= 0b00000001000000010000000100000001
else:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
Collaborator (on lines +163 to +180): Why do we need this case? Inference should always be done with is_legacy_format=False.

return model
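For context on the constants above: with 4-bit quantization, eight zero-points share one int32, so adding 0b00010001000100010001000100010001 bumps every packed field by one in a single tensor add (the 2-, 3- and 8-bit constants follow the same idea for their packing layouts). As the review note points out, a field that is already 0b1111 overflows, and with a plain add the carry even spills into the neighbouring field. One overflow-safe alternative, shown only as a sketch and not what this PR implements, is to shift each field with wrap-around masking:

import torch

def shift_packed_qzeros_4bit(qzeros: torch.Tensor, delta: int) -> torch.Tensor:
    # Unpack each of the eight 4-bit zero-point fields, shift it modulo 16,
    # and repack, so no carry can leak into the adjacent field.
    out = torch.zeros_like(qzeros)
    for i in range(8):
        field = (qzeros >> (4 * i)) & 0xF   # extract one 4-bit zero-point
        field = (field + delta) & 0xF       # wrap within 4 bits instead of overflowing
        out |= field << (4 * i)             # write it back into place
    return out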

def preprocess_checkpoint_qigen(
module,
@@ -297,6 +346,7 @@ def pack_model(
"using autotune_warmup will move model to GPU, make sure you have enough VRAM to load the whole model."
)
QuantLinear.warmup(model.to(CUDA_0), seqlen=model.seqlen)
return QuantLinear
Collaborator: Maybe return the name directly here.



def check_and_get_model_type(model_dir, trust_remote_code=False):
@@ -551,8 +601,6 @@ def unpack_awq(
torch.int16 if bits == 8 else torch.int8
)

# zeros = zeros + 1

torch.bitwise_and(zeros, (2**bits) - 1, out=zeros)

zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
Expand Down Expand Up @@ -648,7 +696,6 @@ def pack_from_tensors(
qweight = qweight.astype(np.int32)
qweight = torch.from_numpy(qweight)

unpacked_qzeros = unpacked_qzeros - 1
torch.bitwise_and(unpacked_qzeros, (2**bits) - 1, out=unpacked_qzeros)

unpacked_qzeros = unpacked_qzeros.numpy().astype(np.uint32)
@@ -750,4 +797,5 @@ def get_checkpoints(model_name_or_path: str, extensions: List[str], possible_mod
"check_and_get_model_type",
"simple_dispatch_model",
"make_sure_no_tensor_in_meta_device",
"convert_new_checkpoint_format",
]
3 changes: 0 additions & 3 deletions auto_gptq/nn_modules/qlinear/qlinear_cuda.py
@@ -166,7 +166,6 @@ def pack(self, linear, scales, zeros, g_idx=None):
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)

zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
i = 0
@@ -261,7 +260,6 @@ def forward(self, x: torch.Tensor):
).to(torch.int16 if self.bits == 8 else torch.int8)
zeros = torch.bitwise_and(zeros, (2**self.bits) - 1)

zeros = zeros + 1
zeros = zeros.reshape(self.scales.shape)

weight = torch.bitwise_right_shift(
@@ -282,7 +280,6 @@ def forward(self, x: torch.Tensor):
dim=2,
)

zeros = zeros + 1
zeros = zeros.reshape(self.scales.shape)

weight = self.qweight.reshape(self.qweight.shape[0] // 3, 3, 1, self.qweight.shape[1]).expand(
3 changes: 0 additions & 3 deletions auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py
@@ -163,7 +163,6 @@ def pack(self, linear, scales, zeros, g_idx):
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)

zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
i = 0
@@ -298,7 +297,6 @@ def forward(self, x):
self.wf.unsqueeze(0),
).to(torch.int16 if self.bits == 8 else torch.int8)

zeros = zeros + 1
zeros = torch.bitwise_and(
zeros, (2**self.bits) - 1
) # NOTE: It appears that casting here is important.
@@ -327,7 +325,6 @@ def forward(self, x):
dim=2,
)

zeros = zeros + 1
zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])

scales = self.scales
1 change: 0 additions & 1 deletion auto_gptq/nn_modules/qlinear/qlinear_exllama.py
@@ -157,7 +157,6 @@ def pack(self, linear, scales, zeros, g_idx=None):
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)

zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
i = 0
2 changes: 1 addition & 1 deletion auto_gptq/nn_modules/qlinear/qlinear_marlin.py
@@ -226,7 +226,7 @@ def unpack_qzeros(qzeros):
i = col % 8
unpacked_zeros[:, col] = (qzeros[:, col // 8] >> (4 * i)) & 0xF

return unpacked_zeros + 1
return unpacked_zeros
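A quick illustration of the change above (illustrative values only): once checkpoints store the true zero-points rather than zero-point minus one, unpack_qzeros can return the raw 4-bit fields without the trailing + 1.

import torch

# pack two example 4-bit zero-points, 5 and 15, into a single int32 the way qzeros are stored
packed = torch.tensor([[5 | (15 << 4)]], dtype=torch.int32)

unpacked = torch.stack([(packed[:, 0] >> (4 * i)) & 0xF for i in range(2)], dim=1)
print(unpacked)  # [[5, 15]]: the stored values come back as-is, no "+ 1" correction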


# Copied from https://github.com/IST-DASLab/marlin/pull/1
1 change: 1 addition & 0 deletions auto_gptq/nn_modules/qlinear/qlinear_qigen.py
@@ -165,6 +165,7 @@ def __init__(
hint=1,
p=8,
l1=2**18,
**kwargs,
):
super().__init__()
if bits not in [2, 4]:
1 change: 0 additions & 1 deletion auto_gptq/nn_modules/qlinear/qlinear_triton.py
@@ -132,7 +132,6 @@ def pack(self, linear, scales, zeros, g_idx=None):
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)

zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
i = 0
2 changes: 0 additions & 2 deletions auto_gptq/nn_modules/triton_utils/kernels.py
@@ -159,7 +159,6 @@ def quant_matmul_248_kernel(
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)

zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = zeros + 1

a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
@@ -318,7 +317,6 @@ def transpose_quant_matmul_248_kernel(
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)

zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = zeros + 1

a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
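Putting the kernel edits together, the dequantization convention changes roughly as sketched below (plain PyTorch, for illustration only; the real computation happens inside the Triton/CUDA kernels):

import torch

q = torch.tensor([7.0, 0.0, 15.0])   # quantized 4-bit weight values
z = torch.tensor([8.0, 8.0, 8.0])    # zero-points as stored in the checkpoint
s = torch.tensor([0.1, 0.1, 0.1])    # per-group scales

w_legacy = (q - (z + 1)) * s  # old checkpoints stored zero-point minus one, so kernels re-added 1
w_new = (q - z) * s           # new format stores the true zero-point; the "+ 1" disappears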
1 change: 0 additions & 1 deletion auto_gptq/utils/marlin_utils.py
@@ -20,7 +20,6 @@

logger = getLogger(__name__)


def prepare_model_for_marlin_load(
model_name_or_path,
model,