
Commit 0da2a69

[FIX] transformers compat (#1687)
* Update module_looper.py
* Update module_looper.py
1 parent: 558449b

File tree: 1 file changed (+14 −3 lines)

gptqmodel/looper/module_looper.py

Lines changed: 14 additions & 3 deletions
```diff
@@ -479,13 +479,24 @@ def process_module(name, m):
                     additional_layer_inputs["kv_last_layer"] = shared_kv_cache_dict.get(layer_index - 1)

                 # log.info(f"MODULE Last forward: {module}")
+                module_output = None
+                if is_lm_head_module:
+                    module_output = module(*layer_input)
+                else:
+                    module_output = module(*layer_input, **additional_layer_inputs)
+
+                # after transformers 4.54, some models' DecoderLayer.forward() no longer returns a tuple
+                if isinstance(module_output, tuple):
+                    layer_output = module_output[0]
+                else:
+                    layer_output = module_output
+
                 layer_output = move_to(
-                    module(*layer_input)[0] if is_lm_head_module else
-                    module(*layer_input, **additional_layer_inputs)[0],
+                    layer_output,
                     device=cur_layer_device if calibration_enable_gpu_cache else CPU,
                     # stream=True,
                 )
-
+
                 layer_outputs.append([layer_output])

                 del layer_input
```
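
The fix boils down to normalizing a decoder layer's `forward()` return value before using it: index into the result only when it is actually a tuple. Below is a minimal standalone sketch of that same pattern, assuming only PyTorch; the toy layer classes and the helper name `normalize_layer_output` are illustrative and not part of GPTQModel or transformers.

```python
import torch
import torch.nn as nn

def normalize_layer_output(module_output):
    # Legacy style: forward() returns (hidden_states, ...) -> take index 0.
    # Newer style (some models after transformers 4.54): forward() returns
    # the hidden-state tensor directly.
    if isinstance(module_output, tuple):
        return module_output[0]
    return module_output

class TupleReturningLayer(nn.Module):
    # Mimics a legacy DecoderLayer that returns a tuple.
    def forward(self, hidden_states):
        return (hidden_states * 2,)

class TensorReturningLayer(nn.Module):
    # Mimics a newer DecoderLayer that returns a bare tensor.
    def forward(self, hidden_states):
        return hidden_states * 2

x = torch.ones(1, 4)
for layer in (TupleReturningLayer(), TensorReturningLayer()):
    hidden = normalize_layer_output(layer(x))
    assert torch.equal(hidden, x * 2)  # identical result for both return styles
```

Unconditionally taking `[0]`, as the old code did, would slice the first row of a bare tensor instead of unpacking a tuple, which is why the commit replaces the inline `[0]` indexing with the explicit `isinstance` check.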
