4 changes: 2 additions & 2 deletions gptqmodel/looper/eora_processor.py
@@ -104,7 +104,7 @@ def is_skipped(self, module: NamedModule) -> bool:
         # dynamic override removed eora processing for this module
         return module.adapter_cfg in [None, {}]
 
-    def preprocess_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]:
+    def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]:
         def tmp(module, input: Tuple[torch.Tensor, ...], output: torch.Tensor):
             self.eora_process_input(
                 input=input,
@@ -115,7 +115,7 @@ def tmp(module, input: Tuple[torch.Tensor, ...], output: torch.Tensor):
             )
         return tmp
 
-    def pre_process_stream_hook(self, module: NamedModule):
+    def pre_process_streaming(self, module: NamedModule):
         eigen_matrix = self.eigen_scaling_diag_matrix[module.name]
         with torch_streamCtx(module.target_device_stream):
             if eigen_matrix is not None:
12 changes: 3 additions & 9 deletions gptqmodel/looper/gptq_processor.py
@@ -29,7 +29,7 @@
 from ..quantization.config import QUANT_METHOD, QuantizeConfig
 from ..utils.logger import setup_logger
 from ..utils.model import move_to, pack_model
-from ..utils.torch import CPU, DEVICE_0, torch_streamCtx, torch_sync
+from ..utils.torch import CPU, DEVICE_0, DEVICE_1, torch_streamCtx, torch_sync
 
 log = setup_logger()
 
@@ -113,26 +113,22 @@ def is_skipped(self, module: NamedModule) -> bool:
         else:
             return False
 
-    def preprocess_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]:
+    def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]:
         def tmp(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
             # gptq is mutable.
             g = self.tasks[name] # noqa: F821
             g.add_batch(inp[0].data, out.data) # noqa: F821
             del inp, out
         return tmp
 
-    def pre_process_stream_hook(self, module: NamedModule):
+    def pre_process_streaming(self, module: NamedModule):
         g = self.tasks[module.name]
         with torch_streamCtx(module.target_device_stream):
             if g.H is not None:
                 g.H = g.H.to(device=module.target_device, non_blocking=True)
             module.weight.data = module.weight.data.to(device=module.target_device, non_blocking=True)
 
     def process(self, module: NamedModule, auto_gc: bool = True):
-        # need to sync stream copies
-        # if torch.cuda.device_count() > 1:
-        #     torch.cuda.synchronize()
-
         # Reset peak memory stats
         #torch.cuda.reset_peak_memory_stats()
         self.pb.title(f"Quantizing {module.name} in layer ").draw()
@@ -239,9 +235,7 @@ def process(self, module: NamedModule, auto_gc: bool = True):
             "wq": wq, # fp16, quantized weight but not int4 (packed qweight)
         })
 
-        old = module.weight.data # TODO HACK since we cannot delete weight.data directly
         module.weight.data = wq
-        del old
 
         # if auto_gc:
         #     torch_empty_cache()
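Note on the streaming copies above: pre_process_streaming queues the Hessian and weight transfers on the module's target-device stream with non_blocking=True and relies on a later torch_sync() before process() runs. torch_streamCtx and torch_sync are this repo's helpers; a rough plain-PyTorch sketch of the same idea (stage_to_device is an illustrative name, not part of the codebase) looks like this:

import torch

def stage_to_device(tensors, device):
    # Queue async copies on a side stream so they can overlap with other work;
    # the caller is responsible for synchronizing before using the results.
    # Note: non_blocking host->device copies only truly overlap when the source
    # tensors live in pinned (page-locked) host memory.
    stream = torch.cuda.Stream(device=device)
    with torch.cuda.stream(stream):
        moved = [t.to(device=device, non_blocking=True) for t in tensors]
    return moved

# usage sketch:
# h, w = stage_to_device([g.H, module.weight.data], torch.device("cuda:1"))
# torch.cuda.synchronize()  # plays the role of torch_sync() before process() starts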
7 changes: 4 additions & 3 deletions gptqmodel/looper/loop_processor.py
@@ -290,11 +290,12 @@ def clear_cache_data(self):
         self.tasks = {}
         self.inputs_cache.layer_inputs = []
 
-    def preprocess_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]:
+    def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]:
         pass
 
-    # do work right before process where stream async/weight copies may happen
-    def pre_process_stream_hook(self, module: NamedModule):
+    # only called when more than 1 gpu devices are active
+    # do work right before process starts and after all fwd_hook ends where stream async/weight copies may happen
+    def pre_process_streaming(self, module: NamedModule):
         pass
 
     # do work and return processor.self state which will updated/merged
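The base LoopProcessor only declares the two hooks; the concrete processors above show the intended division of labor. As a minimal sketch of a subclass wiring both (MyProcessor and its task objects are hypothetical, but the attribute and helper names follow the GPTQ/QQQ processors in this diff):

class MyProcessor(LoopProcessor):
    def pre_process_fwd_hook(self, name):
        # returned closure runs as a forward hook and accumulates calibration stats
        def tmp(module, inp, out):
            task = self.tasks[name]
            task.add_batch(inp[0].data, out.data)
        return tmp

    def pre_process_streaming(self, module):
        # multi-gpu only: stage per-module state onto its target device asynchronously
        task = self.tasks[module.name]
        with torch_streamCtx(module.target_device_stream):
            if task.H is not None:
                task.H = task.H.to(device=module.target_device, non_blocking=True)
            module.weight.data = module.weight.data.to(device=module.target_device, non_blocking=True)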
13 changes: 8 additions & 5 deletions gptqmodel/looper/module_looper.py
@@ -293,12 +293,13 @@ def loop(self, auto_gc=True, calibration_enable_gpu_cache=True, buffered_fwd=Fal
                 for name in subset:
                     # log.info(f"Loop name = {name}")
                     if hasattr(subset[name], 'forward_hook'):
-                        subset[name].forward_hook = processor.preprocess_fwd_hook(name)
+                        subset[name].forward_hook = processor.pre_process_fwd_hook(name)
                     else:
                         # TODO FIXME: do we even need to hook into modules that are not quantizable?
                         assert (f"forward_hook missing for module name: `{name}`, layer name: {layer_name}")
-                        handle.append(subset[name].register_forward_hook(processor.preprocess_fwd_hook(name)))
+                        handle.append(subset[name].register_forward_hook(processor.pre_process_fwd_hook(name)))
 
+                # ---- Start Pre-Quantized Forward ----
                 # logger.info(f"layer-{i}: Begin Forward() Pass")
                 fwd_start = time.time()
 
@@ -367,6 +368,7 @@ def loop(self, auto_gc=True, calibration_enable_gpu_cache=True, buffered_fwd=Fal
                     if hasattr(subset[name], 'forward_hook'):
                         subset[name].forward_hook = None
 
+
                 # TODO FIXME: MoE modules forward() may not trigger if dataset is too small
                 # and moe gating logic does not trigger some moes
                 if isinstance(processor, GPTQProcessor):
@@ -378,7 +380,9 @@ def loop(self, auto_gc=True, calibration_enable_gpu_cache=True, buffered_fwd=Fal
 
                     for name in moe_skip_modules:
                         subset.pop(name)
+                # ---- END Pre-Quantized Forward ----
 
+                # ---- Start Proceess Hook ----
                 if len(ALL_DEVICES) <= 1:
                     for name_index, name in enumerate(subset):
                         m = subset[name]
@@ -387,12 +391,10 @@ def loop(self, auto_gc=True, calibration_enable_gpu_cache=True, buffered_fwd=Fal
                 else:
                     for name in subset:
                         m = subset[name]
-                        processor.pre_process_stream_hook(module=m)
+                        processor.pre_process_streaming(module=m)
 
                     torch_sync()
 
-                    # log.info("streams synced")
-
                     # Use ThreadPoolExecutor with 3 threads
                     max_workers = len(ALL_DEVICES) if DEFAULT_BALANCE_STRATEGY == BalanceStrategy.GPU else len(ALL_DEVICES) - 1
                     with ThreadPoolExecutor(max_workers=max_workers) as executor:
@@ -414,6 +416,7 @@ def process_module(name, m):
                         processed_subset[name] = m
 
                 torch_sync()
+                # ---- End Process Hook ----
 
                 if index == len(modules) - 1:
                     if auto_gc:
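Putting the looper changes together, the per-subset flow is now: attach pre_process_fwd_hook to every module, run the pre-quantized forward, detach the hooks, then either process modules serially (single device) or stage copies with pre_process_streaming, sync once, and process in parallel. A condensed sketch of that last step, under the assumption that results are simply collected from futures (the real code picks max_workers from the balance strategy and routes work through process_module):

from concurrent.futures import ThreadPoolExecutor

if len(ALL_DEVICES) <= 1:
    for name in subset:
        processor.process(module=subset[name])
else:
    for name in subset:
        # queue async H/weight copies onto each module's target-device stream
        processor.pre_process_streaming(module=subset[name])
    torch_sync()  # all staged copies must land before quantization starts

    with ThreadPoolExecutor(max_workers=len(ALL_DEVICES)) as executor:
        futures = [executor.submit(processor.process, subset[name]) for name in subset]
        for f in futures:
            f.result()
    torch_sync()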
2 changes: 1 addition & 1 deletion gptqmodel/looper/native_processor.py
@@ -68,7 +68,7 @@ def is_skipped(self, module: NamedModule) -> bool:
         # TODO: Add skipping certain modules
         return False
 
-    def preprocess_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]:
+    def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]:
         def tmp(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
             # gptq is mutable.
             inp = inp[0].detach()
9 changes: 3 additions & 6 deletions gptqmodel/looper/qqq_processor.py
@@ -30,7 +30,7 @@
 from ..quantization.qqq import QQQ
 from ..utils.logger import setup_logger
 from ..utils.model import move_to, pack_model
-from ..utils.torch import CPU, torch_sync, torch_streamCtx, DEVICE_0
+from ..utils.torch import CPU, DEVICE_0, torch_streamCtx, torch_sync
 
 log = setup_logger()
 
@@ -106,14 +106,14 @@ def is_skipped(self, module: NamedModule) -> bool:
         else:
             return False
 
-    def preprocess_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]:
+    def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]:
         def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
             # gptq is mutable.
             q = self.tasks[name] # noqa: F821
             q.add_batch(inp[0].data, out.data) # noqa: F821
         return tmp
 
-    def pre_process_stream_hook(self, module: NamedModule):
+    def pre_process_streaming(self, module: NamedModule):
         q = self.tasks[module.name]
         with torch_streamCtx(module.target_device_stream):
             if q.H is not None:
@@ -122,9 +122,6 @@ def pre_process_stream_hook(self, module: NamedModule):
 
 
     def process(self, module: NamedModule, auto_gc: bool = True):
-        # need to sync stream copies
-        torch_sync()
-
         self.pb.title(f"Quantizing {module.name} in layer ").draw()
         qqq = self.tasks
 
2 changes: 1 addition & 1 deletion tests/models/model_test.py
@@ -211,7 +211,7 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, torch_dtype="aut
         is_ovis_model = model.__class__.__name__ == "OvisGPTQ"
         need_create_processor = is_image_to_text_model and not is_ovis_model
         if not is_quantized:
-            model.quantize(calibration_dataset, backend=self.QUANT_BACKEND, batch_size=batch_size, buffered_fwd=True)
+            model.quantize(calibration_dataset, backend=self.QUANT_BACKEND, batch_size=batch_size, buffered_fwd=False)
 
         self.check_kernel(model, self.KERNEL_QUANT)
 