remove (zeors -= 1) #559
@@ -61,6 +61,7 @@
    preprocess_checkpoint_qigen,
    simple_dispatch_model,
    unpack_awq,
    convert_new_checkpoint_format,
)
@@ -90,6 +91,7 @@ class BaseQuantizeConfig(PushToHubMixin):
    model_name_or_path: Optional[str] = field(default=None)
    model_file_base_name: Optional[str] = field(default=None)
    awq_gemm_checkpoint: Optional[bool] = field(default=False)
    new_checkpoint_format: Optional[bool] = field(default=False)

    def __post_init__(self):
        fields_info = fields(self)
@@ -100,6 +102,9 @@ def __post_init__(self):
            raise ValueError("unless equal to -1, group_size must be greater than 0.")
        if not (0 < self.damp_percent < 1):
            raise ValueError("damp_percent must be between 0 and 1.")
        if self.sym is False:
            self.new_checkpoint_format = True
            logger.warning("sym is False, so new_checkpoint_format will be used, because sym=False is not supported by the old checkpoint format.")

    def save_pretrained(self, save_dir: str, **kwargs):
        with open(join(save_dir, "quantize_config.json"), "w", encoding="utf-8") as f:
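To make the new validation path concrete, here is a small usage sketch (not part of the diff; argument values are illustrative): constructing a config with sym=False now switches the checkpoint format automatically.

```python
from auto_gptq import BaseQuantizeConfig

# sym=False has no representation in the old checkpoint layout, so
# __post_init__ flips the config to the new format and logs the warning above.
cfg = BaseQuantizeConfig(bits=4, group_size=128, sym=False)
assert cfg.new_checkpoint_format is True
```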
@@ -194,6 +199,7 @@ def to_dict(self):
            "model_file_base_name": self.model_file_base_name,
            "is_marlin_format": self.is_marlin_format,
            "quant_method": "gptq",
            "new_checkpoint_format": self.new_checkpoint_format,
        }

Review comment (on new_checkpoint_format): let's call this …
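For reference, the serialized quantize_config.json would then carry the new key next to the existing ones. A sketch of an abridged to_dict() result (values illustrative, most keys omitted):

```python
quantize_config_dict = {
    "bits": 4,
    "group_size": 128,
    "sym": False,
    "model_file_base_name": None,
    "is_marlin_format": False,
    "quant_method": "gptq",
    "new_checkpoint_format": True,
}
```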
@@ -216,6 +222,8 @@ def __init__(
        injected_fused_attention: bool = False,
        injected_fused_mlp: bool = False,
        trainable: bool = False,
        kerenl_backend_type: Optional[str] = None,
        now_format: Optional[str] = None,
    ):
        super().__init__()

Review comment (on kerenl_backend_type): Maybe call this …
@@ -229,6 +237,8 @@ def __init__(
        self.injected_fused_attention = injected_fused_attention
        self.injected_fused_mlp = injected_fused_mlp
        self.trainable = trainable
        self.kerenl_backend_type = kerenl_backend_type
        self.now_format = now_format

    @property
    def quantized(self):
@@ -491,7 +501,7 @@ def tmp(_, inp, out):
            layer_inputs, layer_outputs = layer_outputs, []
            torch.cuda.empty_cache()

        pack_model(
        self.kerenl_backend_type = pack_model(
            model=self.model,
            quantizers=quantizers,
            bits=self.quantize_config.bits,

@@ -502,12 +512,14 @@ def tmp(_, inp, out):
            warmup_triton=autotune_warmup_after_quantized,
            force_layer_back_to_cpu=force_layer_back_to_cpu,
        )

        if device_map:
            self.model = remove_hook_from_module(self.model, recurse=True)
            self.model = simple_dispatch_model(self.model, device_map)
        self.model.config.use_cache = forward_pass_use_cache

        self._quantized = True
        self.now_format = "new"

Review comment (on now_format = "new"): should be handled in the config no?

        torch.cuda.empty_cache()
@@ -522,17 +534,40 @@ def device(self):
    def to(self, device: Union[str, torch.device]):
        self.model.to(device)
        return self

    def forward(self, *args, **kwargs):
        if self.now_format == "old":
            self.model = convert_new_checkpoint_format(
                self.model,
                True,
                self.quantize_config,
                self.kerenl_backend_type
            )
            self.now_format = "new"
        return self.model(*args, **kwargs)

Review comment (on lines +538 to +545): This should be in the …
Review comment (on lines +538 to +545): As this is already in …

    def generate(self, **kwargs):
        """shortcut for model.generate"""
        if self.now_format == "old":
            self.model = convert_new_checkpoint_format(
                self.model,
                True,
                self.quantize_config,
                self.kerenl_backend_type
            )
            self.now_format = "new"
        with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
            return self.model.generate(**kwargs)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """shortcut for model.prepare_inputs_for_generation"""
        if self.now_format == "old":
            self.model = convert_new_checkpoint_format(
                self.model,
                True,
                self.quantize_config,
                self.kerenl_backend_type
            )
            self.now_format = "new"
        return self.model.prepare_inputs_for_generation(*args, **kwargs)

    def push_to_hub(
@@ -624,6 +659,17 @@ def save_quantized(
        if not self.quantized:
            raise EnvironmentError("can only save quantized model, please execute .quantize first.")
        if self.quantize_config.new_checkpoint_format:
            logger.warning("New checkpoint format is enabled; the saved model is not supported by older versions of AutoGPTQ (<= 0.7.0).")

        if not self.quantize_config.new_checkpoint_format and self.now_format == "new":
            self.model = convert_new_checkpoint_format(
                self.model,
                False,
                self.quantize_config,
                self.kerenl_backend_type
            )
            self.now_format = "old"

Review comment (on lines +662 to +672): At this point we should always have …

        self.model.to(CPU)
@@ -1299,6 +1345,25 @@ def skip(*args, **kwargs):
            )
        model.load_state_dict(checkpoint)

        kerenl_backend_type = dynamically_import_QuantLinear(
            use_triton=use_triton,
            desc_act=quantize_config.desc_act,
            group_size=quantize_config.group_size,
            bits=quantize_config.bits,
            disable_exllama=disable_exllama,
            disable_exllamav2=disable_exllamav2,
            use_qigen=use_qigen,
            disable_marlin=not use_marlin,
        )

        if not quantize_config.new_checkpoint_format:
            model = convert_new_checkpoint_format(
                model,
                True,
                quantize_config,
                kerenl_backend_type
            )

Review comment (on lines +1359 to +1365): Can you also add the argument …? For example, if somebody is quantizing models with an external library, the quantization config may not contain a …

        # == step4: set seqlen == #
        model_config = model.config.to_dict()
        seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
@@ -1337,7 +1402,6 @@ def skip(*args, **kwargs):
        # Any post-initialization that requires device information, for example buffer initialization on device.
        model = autogptq_post_init(model, use_act_order=quantize_config.desc_act)

        model.eval()

        # == step6: (optional) warmup triton == #
@@ -1370,6 +1434,8 @@ def skip(*args, **kwargs):
            injected_fused_attention=inject_fused_attention,
            injected_fused_mlp=inject_fused_mlp and use_triton,
            trainable=trainable,
            kerenl_backend_type=kerenl_backend_type,
            now_format="new",
        )

    def warmup_triton(self, enabled: bool = True):
@@ -130,6 +130,55 @@ def make_quant(
            )
            new_layer.device = ori_layer_device
            recurse_setattr(module, name, new_layer.to(ori_layer_device))
def convert_new_checkpoint_format(
    model,
    to_new_format,
    quantize_config,
    QuantLinear,
):
    use_qigen = QuantLinear.QUANT_TYPE == "qigen"
    use_marlin = QuantLinear.QUANT_TYPE == "marlin"

    for name, submodule in model.named_modules():
        if isinstance(submodule, QuantLinear):
            if to_new_format:
                if use_qigen:
                    submodule.zeros.data += 1
                elif use_marlin:
                    pass
                else:
                    if quantize_config.bits == 2:
                        submodule.qzeros.data += 0b01010101010101010101010101010101
                    elif quantize_config.bits == 3:
                        submodule.qzeros.data[:, range(0, submodule.qzeros.data.shape[1], 3)] += 0b00100100100100100100100100100100
                        submodule.qzeros.data[:, range(1, submodule.qzeros.data.shape[1], 3)] += 0b10010010010010010010010010010010
                        submodule.qzeros.data[:, range(2, submodule.qzeros.data.shape[1], 3)] += 0b01001001001001001001001001001001
                    elif quantize_config.bits == 4:
                        submodule.qzeros.data += 0b00010001000100010001000100010001
                    elif quantize_config.bits == 8:
                        submodule.qzeros.data += 0b00000001000000010000000100000001
                    else:
                        raise NotImplementedError("Only 2,3,4,8 bits are supported.")
            else:
                if use_qigen:
                    submodule.zeros.data -= 1
                elif use_marlin:
                    pass
                else:
                    if quantize_config.bits == 2:
                        submodule.qzeros.data -= 0b01010101010101010101010101010101
                    elif quantize_config.bits == 3:
                        submodule.qzeros.data[:, range(0, submodule.qzeros.data.shape[1], 3)] -= 0b00100100100100100100100100100100
                        submodule.qzeros.data[:, range(1, submodule.qzeros.data.shape[1], 3)] -= 0b10010010010010010010010010010010
                        submodule.qzeros.data[:, range(2, submodule.qzeros.data.shape[1], 3)] -= 0b01001001001001001001001001001001
                    elif quantize_config.bits == 4:
                        submodule.qzeros.data -= 0b00010001000100010001000100010001
                    elif quantize_config.bits == 8:
                        submodule.qzeros.data -= 0b00000001000000010000000100000001
                    else:
                        raise NotImplementedError("Only 2,3,4,8 bits are supported.")
    return model

Review comment (on lines +151 to +160): This does not check for overflows, this is an issue. We used to check for overflows in the kernels. 11111 will become 0000.

Review comment (on lines +163 to +180): Why do we need this case? Inference should always be done with …
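On the overflow comment above: the constants add 1 to every packed field with a single integer addition (for 4-bit, 0x32107654 + 0x11111111 == 0x43218765), but a field that already holds its maximum value carries into the neighbouring field instead of wrapping. A minimal sketch of an overflow-safe variant for the layouts whose fields do not straddle int32 boundaries (helper name hypothetical, not part of this PR; the 3-bit layout would need its own handling):

```python
import torch

def offset_packed_qzeros(qzeros: torch.Tensor, bits: int, delta: int = 1) -> torch.Tensor:
    """Add `delta` to each `bits`-wide field of packed int32 qzeros, wrapping per field."""
    assert bits in (2, 4, 8), "fields must not straddle int32 word boundaries"
    mask = (1 << bits) - 1
    out = torch.zeros_like(qzeros)
    for i in range(32 // bits):
        field = (qzeros >> (i * bits)) & mask  # unpack one zero point
        field = (field + delta) & mask         # wrap within the field, no carry into neighbours
        out |= field << (i * bits)             # repack
    return out
```

Converting back to the old layout would then be offset_packed_qzeros(qzeros, bits, delta=-1).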
def preprocess_checkpoint_qigen(
    module,
@@ -297,6 +346,7 @@ def pack_model(
            "using autotune_warmup will move model to GPU, make sure you have enough VRAM to load the whole model."
        )
        QuantLinear.warmup(model.to(CUDA_0), seqlen=model.seqlen)
    return QuantLinear

Review comment (on return QuantLinear): maybe return the name directly here
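A possible reading of that suggestion, since every QuantLinear implementation already exposes a QUANT_TYPE string (the same attribute convert_new_checkpoint_format checks above); this is a sketch, not part of the diff:

```python
def kernel_backend_name(quant_linear_cls) -> str:
    # Return the backend's string identifier (e.g. "qigen" or "marlin")
    # instead of handing the QuantLinear class object around.
    return quant_linear_cls.QUANT_TYPE
```

pack_model would then end with return QuantLinear.QUANT_TYPE, and callers would store a plain string.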
def check_and_get_model_type(model_dir, trust_remote_code=False):
@@ -551,8 +601,6 @@ def unpack_awq(
        torch.int16 if bits == 8 else torch.int8
    )

    # zeros = zeros + 1

    torch.bitwise_and(zeros, (2**bits) - 1, out=zeros)

    zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
@@ -648,7 +696,6 @@ def pack_from_tensors(
    qweight = qweight.astype(np.int32)
    qweight = torch.from_numpy(qweight)

    unpacked_qzeros = unpacked_qzeros - 1
    torch.bitwise_and(unpacked_qzeros, (2**bits) - 1, out=unpacked_qzeros)

    unpacked_qzeros = unpacked_qzeros.numpy().astype(np.uint32)
@@ -750,4 +797,5 @@ def get_checkpoints(model_name_or_path: str, extensions: List[str], possible_mod
    "check_and_get_model_type",
    "simple_dispatch_model",
    "make_sure_no_tensor_in_meta_device",
    "convert_new_checkpoint_format",
]
@@ -165,6 +165,7 @@ def __init__(
        hint=1,
        p=8,
        l1=2**18,
        **kwargs,
    ):
        super().__init__()
        if bits not in [2, 4]:
@@ -20,7 +20,6 @@

logger = getLogger(__name__)


def prepare_model_for_marlin_load(
    model_name_or_path,
    model,
Review comment: Can we name this is_legacy_format, with the default being True when loading models that don't have the config attribute, and default to False when quantizing new models?
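A rough sketch of how that proposal could behave (names and defaults taken from the comment; nothing below is part of the PR): checkpoints whose quantize_config.json lacks the key are treated as legacy, while newly quantized models default to the new layout.

```python
def resolve_is_legacy_format(raw_config: dict, quantizing_new_model: bool = False) -> bool:
    # An explicit value in the serialized config always wins.
    if "is_legacy_format" in raw_config:
        return bool(raw_config["is_legacy_format"])
    # Old checkpoints never wrote the key, so assume the legacy (zeros - 1) layout;
    # models quantized from scratch default to the new format.
    return not quantizing_new_model
```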