1 file changed: +14 −3 lines

@@ -479,13 +479,24 @@ def process_module(name, m):
             additional_layer_inputs["kv_last_layer"] = shared_kv_cache_dict.get(layer_index - 1)

             # log.info(f"MODULE Last forward: {module}")
+            module_output = None
+            if is_lm_head_module:
+                module_output = module(*layer_input)
+            else:
+                module_output = module(*layer_input, **additional_layer_inputs)
+
+            # after transformers 4.54, some models' DecoderLayer.forward() no longer returns a tuple
+            if isinstance(module_output, tuple):
+                layer_output = module_output[0]
+            else:
+                layer_output = module_output
+
             layer_output = move_to(
-                module(*layer_input)[0] if is_lm_head_module else
-                module(*layer_input, **additional_layer_inputs)[0],
+                layer_output,
                 device=cur_layer_device if calibration_enable_gpu_cache else CPU,
                 # stream=True,
             )
-
+
             layer_outputs.append([layer_output])

             del layer_input
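For context, the added lines normalize whatever the layer's forward() returns before it is handed to move_to: per the diff's comment, older transformers releases return a tuple whose first element is the hidden states, while from 4.54 onward some decoder layers return the tensor directly. Below is a minimal standalone sketch of that tuple-vs-tensor handling. The TupleReturningLayer, TensorReturningLayer, and extract_hidden_states names are hypothetical illustrations, not GPTQModel or transformers code; only the isinstance check mirrors the diff.

import torch
import torch.nn as nn

# Toy layers standing in for the two return conventions the diff guards against.
# These classes are illustrative assumptions, not GPTQModel or transformers code.
class TupleReturningLayer(nn.Module):
    def forward(self, hidden_states, **kwargs):
        # pre-4.54 style: (hidden_states, ...) tuple
        return (hidden_states * 2, None)

class TensorReturningLayer(nn.Module):
    def forward(self, hidden_states, **kwargs):
        # style seen in some models after transformers 4.54: bare tensor
        return hidden_states * 2

def extract_hidden_states(module_output):
    # Mirrors the diff's normalization: take element 0 of a tuple, otherwise pass through.
    if isinstance(module_output, tuple):
        return module_output[0]
    return module_output

if __name__ == "__main__":
    layer_input = (torch.randn(1, 4, 8),)
    additional_layer_inputs = {"attention_mask": None}

    for module in (TupleReturningLayer(), TensorReturningLayer()):
        module_output = module(*layer_input, **additional_layer_inputs)
        layer_output = extract_hidden_states(module_output)
        print(type(module).__name__, layer_output.shape)

Either way, layer_output ends up as a plain tensor, so the subsequent move_to call behaves the same regardless of which transformers version produced the output.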