Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 2 additions & 16 deletions gptqmodel/looper/awq_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@
class AWQProcessor(LoopProcessor):
def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration, prepare_dataset_func,
calibration_concat_size: Optional[int], calibration_sort: Optional[str], batch_size: int, gptq_model, model,
logger_board: str = "", require_fwd: bool = True, calculate_w_wq_diff: bool = False):
require_fwd: bool = True, calculate_w_wq_diff: bool = False):

super().__init__(tokenizer=tokenizer, qcfg=qcfg, calibration=calibration,
calibration_concat_size=calibration_concat_size, calibration_sort=calibration_sort,
prepare_dataset_func=prepare_dataset_func, batch_size=batch_size,
logger_board=logger_board, require_fwd=require_fwd)
require_fwd=require_fwd)

self.calculate_w_wq_diff = calculate_w_wq_diff
self.avg_losses = []
Expand Down Expand Up @@ -77,20 +77,6 @@ def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration, prepare_dataset

self.modules, self.module_kwargs, self.inps = self.init_quant()

def log_plotly(self):
task = self.logger_task
if task is not None:
from ..utils.plotly import create_plotly
x = list(range(self.layer_count))
gpu_fig = create_plotly(x=x, y=self.gpu_memorys, xaxis_title="layer", yaxis_title="GPU usage (GB)")
cpu_fig = create_plotly(x=x, y=self.cpu_memorys, xaxis_title="layer", yaxis_title="CPU usage (GB)")
loss_fig = create_plotly(x=self.module_names, y=self.avg_losses, xaxis_title="layer", yaxis_title="loss")
time_fig = create_plotly(x=self.module_names, y=self.durations, xaxis_title="layer", yaxis_title="time")
task.get_logger().report_plotly('GPU Memory', 'GPU Memory', gpu_fig)
task.get_logger().report_plotly('CPU Memory', 'CPU Memory', cpu_fig)
task.get_logger().report_plotly('avg_loss', 'avg_loss', loss_fig)
task.get_logger().report_plotly('quant_time', 'quant_time', time_fig)

def set_calibration_dataset(self, calibration_dataset):
raise NotImplementedError("AWQProcessor's calibration_dataset cannot be modified")

Expand Down
2 changes: 1 addition & 1 deletion gptqmodel/looper/dequantize_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class DequantizeProcessor(LoopProcessor):
def __init__(self, quantized_modules: Dict[str, TorchQuantLinear]):
super().__init__(tokenizer=None, qcfg=None, calibration=None, calibration_concat_size=None,
prepare_dataset_func=None, batch_size=1,
logger_board="", require_fwd=False)
require_fwd=False)

self.quantized_modules = quantized_modules

Expand Down
16 changes: 2 additions & 14 deletions gptqmodel/looper/eora_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@
class EoraProcessor(LoopProcessor):
def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration, prepare_dataset_func,
calibration_concat_size: Optional[int], calibration_sort: Optional[str], batch_size: int,
logger_board: str = "", require_fwd: bool = True
require_fwd: bool = True
):
super().__init__(tokenizer=tokenizer, qcfg=qcfg, calibration=calibration,
calibration_concat_size=calibration_concat_size,
calibration_sort=calibration_sort,
prepare_dataset_func=prepare_dataset_func, batch_size=batch_size,
logger_board=logger_board, require_fwd=require_fwd)
require_fwd=require_fwd)

# dict: key is module name, value is the accumulated eigen_scaling_diag_matrix
self.eigen_scaling_diag_matrix: Dict[str, torch.Tensor] = {}
Expand All @@ -55,18 +55,6 @@ def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration, prepare_dataset
self.eora_compute_lora = eora_compute_lora
self.eora_process_input = eora_process_input

def log_plotly(self):
task = self.logger_task
if task is not None:
from ..utils.plotly import create_plotly
x = list(range(self.layer_count))
gpu_fig = create_plotly(x=x, y=self.gpu_memorys, xaxis_title="layer", yaxis_title="GPU usage (GB)")
cpu_fig = create_plotly(x=x, y=self.cpu_memorys, xaxis_title="layer", yaxis_title="CPU usage (GB)")
time_fig = create_plotly(x=self.module_names, y=self.durations, xaxis_title="layer", yaxis_title="time")
task.get_logger().report_plotly('GPU Memory', 'GPU Memory', gpu_fig)
task.get_logger().report_plotly('CPU Memory', 'CPU Memory', cpu_fig)
task.get_logger().report_plotly('quant_time', 'quant_time', time_fig)

def set_calibration_dataset(self, calibration_dataset):
self.calibration_dataset = calibration_dataset
self.num_batches = len(calibration_dataset)
Expand Down
19 changes: 2 additions & 17 deletions gptqmodel/looper/gptq_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,32 +32,17 @@
class GPTQProcessor(LoopProcessor):
def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration, prepare_dataset_func,
calibration_concat_size: Optional[int], calibration_sort: Optional[str], batch_size: int,
logger_board: str = "", require_fwd: bool = True, calculate_w_wq_diff: bool = False):
require_fwd: bool = True, calculate_w_wq_diff: bool = False):

super().__init__(tokenizer=tokenizer, qcfg=qcfg, calibration=calibration,
calibration_concat_size=calibration_concat_size,
calibration_sort=calibration_sort,
prepare_dataset_func=prepare_dataset_func, batch_size=batch_size,
logger_board=logger_board, require_fwd=require_fwd)
require_fwd=require_fwd)

self.calculate_w_wq_diff = calculate_w_wq_diff
self.avg_losses = []


def log_plotly(self):
task = self.logger_task
if task is not None:
from ..utils.plotly import create_plotly
x = list(range(self.layer_count))
gpu_fig = create_plotly(x=x, y=self.gpu_memorys, xaxis_title="layer", yaxis_title="GPU usage (GB)")
cpu_fig = create_plotly(x=x, y=self.cpu_memorys, xaxis_title="layer", yaxis_title="CPU usage (GB)")
loss_fig = create_plotly(x=self.module_names, y=self.avg_losses, xaxis_title="layer", yaxis_title="loss")
time_fig = create_plotly(x=self.module_names, y=self.durations, xaxis_title="layer", yaxis_title="time")
task.get_logger().report_plotly('GPU Memory', 'GPU Memory', gpu_fig)
task.get_logger().report_plotly('CPU Memory', 'CPU Memory', cpu_fig)
task.get_logger().report_plotly('avg_loss', 'avg_loss', loss_fig)
task.get_logger().report_plotly('quant_time', 'quant_time', time_fig)

def set_calibration_dataset(self, calibration_dataset):
raise NotImplementedError("GPTQProcessor's calibration_dataset cannot be modified")

Expand Down
61 changes: 17 additions & 44 deletions gptqmodel/looper/loop_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,10 @@ def __init__(
self,
tokenizer, qcfg: QuantizeConfig,
calibration,
prepare_dataset_func,
calibration_concat_size: Optional[int],
calibration_sort: Optional[str],
prepare_dataset_func: Optional[Callable] = None,
calibration_concat_size: Optional[int] = None,
calibration_sort: Optional[str] = None,
batch_size: int = 1,
logger_board: str = "",
require_fwd: bool = True,
fwd_after_process: bool = True,
fwd_all_modules_in_single_pass: bool = False,
Expand Down Expand Up @@ -95,7 +94,6 @@ def __init__(
self.tasks = {}

self.pb = None
self.logger_task = None
self.fwd_time = None
self.layer_count = None

Expand All @@ -107,7 +105,6 @@ def __init__(

# logging
self.log = []
self.logger_board = logger_board
self.log_call_count = 0
self._log_column_labels: List[str] = []
self._log_columns = None
Expand All @@ -118,22 +115,6 @@ def __init__(
self._cpu_device_smi = self._init_cpu_device_handle()
self._device_metric_failures: Set[str] = set()

if self.logger_board == "clearml":
try:
from clearml import Task

from ..utils.plotly import create_plotly
except ImportError as _:
raise ImportError(
"The logger_board is set to 'clearml', but required dependencies are missing. "
"Please install them by running: pip install gptqmodel[logger]"
)
self.logger_task = Task.init(project_name='GPTQModel',
task_name=f'{self.__class__.__name__}-{RandomWords().get_random_word()}',
task_type=Task.TaskTypes.optimizer)
else:
self.logger_task = None


# prepare dataset
if calibration is not None:
Expand All @@ -146,6 +127,9 @@ def __init__(
log.warn(f"Calibration dataset size should be more than {min_calibration_dataset_size}. "
f"Current: {len(calibration)}.")

if prepare_dataset_func is None:
raise ValueError("prepare_dataset_func must be provided when calibration data is supplied.")

calibration = prepare_dataset_func(calibration_dataset=calibration,
calibration_dataset_concat_size=calibration_concat_size,
calibration_dataset_sort=calibration_sort,
Expand Down Expand Up @@ -175,9 +159,11 @@ def __init__(
log.warn(f"The average length of input_ids of calibration_dataset should be greater than "
f"{min_calibration_dataset_input_ids_avg_length}: actual avg: {avg}.")

self.num_batches = len(calibration)

self.calibration_dataset = calibration
self.num_batches = len(calibration)
self.calibration_dataset = calibration
else:
self.num_batches = 0
self.calibration_dataset = []

# Track the current calibration batch index on a per-thread basis so
# processors can retrieve deterministic ordering information (e.g.
Expand Down Expand Up @@ -456,26 +442,13 @@ def results(self):
return self._results

def collect_memory_info(self, layer_index: int):
if self.logger_task is not None:
device_snapshot = self._snapshot_device_memory_gib()
total_gpu_memory = sum(device_snapshot.values()) if device_snapshot else 0.0

self.logger_task.get_logger().report_scalar(
title='GPU Memory',
series='GPU Memory',
value=total_gpu_memory,
iteration=layer_index,
)

cpu_memory = self._snapshot_cpu_memory_gib() or 0.0

self.logger_task.get_logger().report_scalar(
title='CPU Memory',
series='CPU Memory',
value=cpu_memory,
iteration=layer_index,
)
device_snapshot = self._snapshot_device_memory_gib()
if device_snapshot:
total_gpu_memory = sum(device_snapshot.values())
self.gpu_memorys.append(total_gpu_memory)

cpu_memory = self._snapshot_cpu_memory_gib()
if cpu_memory is not None:
self.cpu_memorys.append(cpu_memory)

def log_plotly(self):
Expand Down
18 changes: 2 additions & 16 deletions gptqmodel/looper/native_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,17 @@
class NativeProcessor(LoopProcessor):
def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration, prepare_dataset_func,
calibration_concat_size: Optional[int], calibration_sort: Optional[str], batch_size: int,
logger_board: str = "", require_fwd: bool = True):
require_fwd: bool = True):

super().__init__(tokenizer=tokenizer, qcfg=qcfg, calibration=calibration,
calibration_concat_size=calibration_concat_size,
calibration_sort=calibration_sort,
prepare_dataset_func=prepare_dataset_func, batch_size=batch_size,
logger_board=logger_board, require_fwd=require_fwd, fwd_after_process=False,
require_fwd=require_fwd, fwd_after_process=False,
fwd_all_modules_in_single_pass=True)

self.native_inp_caches = {}

def log_plotly(self):
task = self.logger_task
if task is not None:
from ..utils.plotly import create_plotly
x = list(range(self.layer_count))
gpu_fig = create_plotly(x=x, y=self.gpu_memorys, xaxis_title="layer", yaxis_title="GPU usage (GB)")
cpu_fig = create_plotly(x=x, y=self.cpu_memorys, xaxis_title="layer", yaxis_title="CPU usage (GB)")
loss_fig = create_plotly(x=self.module_names, y=self.avg_losses, xaxis_title="layer", yaxis_title="loss")
time_fig = create_plotly(x=self.module_names, y=self.durations, xaxis_title="layer", yaxis_title="time")
task.get_logger().report_plotly('GPU Memory', 'GPU Memory', gpu_fig)
task.get_logger().report_plotly('CPU Memory', 'CPU Memory', cpu_fig)
task.get_logger().report_plotly('avg_loss', 'avg_loss', loss_fig)
task.get_logger().report_plotly('quant_time', 'quant_time', time_fig)

def set_calibration_dataset(self, calibration_dataset):
raise NotImplementedError("NativeProcessor's calibration_dataset cannot be modified")

Expand Down
18 changes: 2 additions & 16 deletions gptqmodel/looper/qqq_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,30 +28,16 @@
class QQQProcessor(LoopProcessor):
def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration, prepare_dataset_func,
calibration_concat_size: Optional[int], calibration_sort: Optional[str], batch_size: int,
logger_board: str = "", require_fwd: bool = True, calculate_w_wq_diff: bool = False):
require_fwd: bool = True, calculate_w_wq_diff: bool = False):

super().__init__(tokenizer=tokenizer, qcfg=qcfg, calibration=calibration,
calibration_concat_size=calibration_concat_size, calibration_sort=calibration_sort,
prepare_dataset_func=prepare_dataset_func, batch_size=batch_size,
logger_board=logger_board, require_fwd=require_fwd)
require_fwd=require_fwd)

self.calculate_w_wq_diff = calculate_w_wq_diff
self.avg_losses = []

def log_plotly(self):
task = self.logger_task
if task is not None:
from ..utils.plotly import create_plotly
x = list(range(self.layer_count))
gpu_fig = create_plotly(x=x, y=self.gpu_memorys, xaxis_title="layer", yaxis_title="GPU usage (GB)")
cpu_fig = create_plotly(x=x, y=self.cpu_memorys, xaxis_title="layer", yaxis_title="CPU usage (GB)")
loss_fig = create_plotly(x=self.module_names, y=self.avg_losses, xaxis_title="layer", yaxis_title="loss")
time_fig = create_plotly(x=self.module_names, y=self.durations, xaxis_title="layer", yaxis_title="time")
task.get_logger().report_plotly('GPU Memory', 'GPU Memory', gpu_fig)
task.get_logger().report_plotly('CPU Memory', 'CPU Memory', cpu_fig)
task.get_logger().report_plotly('avg_loss', 'avg_loss', loss_fig)
task.get_logger().report_plotly('quant_time', 'quant_time', time_fig)

def set_calibration_dataset(self, calibration_dataset):
raise NotImplementedError("QQQProcessor's calibration_dataset cannot be modified")

Expand Down
2 changes: 0 additions & 2 deletions gptqmodel/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,6 @@ def generate(
calibration_dataset_sort: Optional[str] = None,
batch_size: Optional[int] = 1,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
logger_board: Optional[str] = None,
# pass-through vars for load()
trust_remote_code: bool = False,
dtype: Optional[Union[str, torch.dtype]] = None,
Expand Down Expand Up @@ -664,6 +663,5 @@ def generate(
calibration_dataset_sort=calibration_dataset_sort,
batch_size=batch_size,
tokenizer=tokenizer,
logger_board=logger_board,
)
return
5 changes: 0 additions & 5 deletions gptqmodel/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,6 @@ def quantize(
calibration_sort: Optional[str] = "desc", # valid values are asc, desc, shuffle
batch_size: int = 1,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
logger_board: Optional[str] = None,
backend: Optional[BACKEND] = BACKEND.AUTO,
# eora adapter generation needs config Lora(rank=1, path='lora.safetensors')
adapter: Adapter = None,
Expand Down Expand Up @@ -871,7 +870,6 @@ def quantize(
"calibration_concat_size": calibration_concat_size,
"calibration_sort": calibration_sort,
"batch_size": batch_size,
"logger_board": logger_board,
"calculate_w_wq_diff": needs_lora, # lora needs original w - wq delta
}

Expand Down Expand Up @@ -962,7 +960,6 @@ def quantize(
calibration_concat_size=calibration_concat_size,
calibration_sort=calibration_sort,
batch_size=batch_size,
logger_board=logger_board,
)
)

Expand Down Expand Up @@ -990,7 +987,6 @@ def _eora_generate(
calibration_dataset_sort: Optional[str] = None,
batch_size: int = 1,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
logger_board: Optional[str] = None,
):
if self.quantized:
raise EnvironmentError("eora_generate() is called a model that is already quantized")
Expand Down Expand Up @@ -1026,7 +1022,6 @@ def _eora_generate(
calibration_concat_size=calibration_dataset_concat_size,
calibration_sort=calibration_dataset_sort,
batch_size=batch_size,
logger_board=logger_board,
),
]

Expand Down
23 changes: 0 additions & 23 deletions gptqmodel/utils/plotly.py

This file was deleted.

5 changes: 0 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,6 @@ bitblas = [
hf = [
"optimum>=1.21.2",
]
logger = [
"clearml",
"random_word",
"plotly",
]
eval = [
"lm_eval>=0.4.7",
"evalplus>=0.3.1",
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,6 @@ def run(self):
"sglang": ["sglang[srt]>=0.4.6", "flashinfer-python>=0.2.1"],
"bitblas": ["bitblas==0.0.1-dev13"],
"hf": ["optimum>=1.21.2"],
"logger": ["clearml", "random_word", "plotly"],
"eval": ["lm_eval>=0.4.7", "evalplus>=0.3.1"],
"triton": ["triton>=3.4.0"],
"openai": ["uvicorn", "fastapi", "pydantic"],
Expand Down
Loading