[ParamManager][Redo] Use BundleModelParams for transform_dequantize
Prior to this commit, the `ParamManager.transform_dequantize` function
took as input functions with separate parameters for each weight
tensor, and produced output functions with a single tuple parameter
for all weights.  Because `LiftTransformParams` followed the same
convention, the two transforms could not be applied as part of the
same build flow.

This commit updates the `ParamManager.transform_dequantize` pass to
produce output functions with separate tensor parameters, relying on
the `BundleModelParams` transform to combine them into a single tuple
parameter later in the build flow.  The analogous change was made to
`LiftTransformParams` in apache/tvm#15657.
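
In the updated convention the two steps compose as ordinary passes
over the module.  A minimal sketch of the resulting flow (assuming
`mod` and `param_manager` come from the usual mlc_llm model-building
step, as in `mod_transform_before_build` below):

    from tvm import relax

    # Sketch only: `mod` is the Relax IRModule produced by model export and
    # `param_manager` is the ParamManager built alongside it.
    mod = param_manager.transform_dequantize()(mod)  # weights stay as separate params
    mod = relax.transform.BundleModelParams()(mod)   # pack them into one tuple param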

In addition, prior to this commit, the
`ParamManager.transform_dequantize` function operated directly on an
`IRModule` object.  As a result, any debug instrumentation
(e.g. before/after printouts for each pass, before/after verification
with `relax.analysis.well_formed`, etc.) did not apply to
`transform_dequantize`.  This commit updates
`ParamManager.transform_dequantize` to instead return an
`ir.transform.Pass`.
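
Because the dequantization step is now an ordinary pass, it can run
under a `PassContext` and is visible to standard instrumentation.  A
sketch of how this might be used (again assuming `mod` and
`param_manager` from the surrounding build flow; the timing instrument
is only one example):

    import tvm
    from tvm import relax
    from tvm.ir.instrument import PassTimingInstrument

    timing = PassTimingInstrument()
    with tvm.transform.PassContext(opt_level=0, instruments=[timing]):
        mod = param_manager.transform_dequantize()(mod)
        mod = relax.transform.BundleModelParams()(mod)
        assert relax.analysis.well_formed(mod)
        print(timing.render())  # per-pass timings now include the dequantize pass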

This commit is a repeat of the reverted PR
mlc-ai#1056.  It resolves the bug in
the earlier implementation by removing the call to
`.without_attr("num_input")` in `ParamReplacer.rewrite_func`.  This
follows an analogous update in `LiftTransformParams`, preserving the
`"num_input"` attribute for later use by `BundleModelParams`.
Lunderberg committed Oct 24, 2023
1 parent 9cb8e8e commit b762ee4
Showing 2 changed files with 56 additions and 63 deletions.
3 changes: 2 additions & 1 deletion mlc_llm/core.py
@@ -420,7 +420,8 @@ def mod_transform_before_build(
     if args.model.lower().startswith("rwkv-"):
         model_names += ["reset_kv_cache"]
 
-    mod = param_manager.transform_dequantize(mod)
+    mod = param_manager.transform_dequantize()(mod)
+    mod = relax.transform.BundleModelParams()(mod)
 
     use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"]
     mod = mlc_llm.transform.FuseDecodeTranspose(skip_gemm=not use_ft_quant)(mod)
116 changes: 54 additions & 62 deletions mlc_llm/relax_model/param_manager.py
@@ -369,7 +369,7 @@ def set_param_loading_func(
         else:
             self.pidx2pname = dict()
 
-    def transform_dequantize(self, mod: tvm.IRModule) -> tvm.IRModule:
+    def transform_dequantize(self) -> tvm.ir.transform.Pass:
         """Apply dequantization to the input IRModule.
 
         Parameters
@@ -386,38 +386,48 @@ def transform_dequantize(self, mod: tvm.IRModule) -> tvm.IRModule:
             The IRModule updated with the dequantization computation.
         """
 
-        # For each Relax function in the input IRModule (e.g., "prefill"),
-        # we create its input relax.Var of all the quantized data, and
-        # store the mapping from function name to the var.
-        func2param_var: Dict[str, relax.Var] = {}
-        for gv, func in mod.functions.items():
-            if not isinstance(func, relax.Function):
-                continue
-            if func.attrs is None or not "num_input" in func.attrs:
-                continue
-            func2param_var[gv.name_hint] = relax.Var(
-                "params", self.get_quantized_param_info(gv.name_hint)
-            )
+        @tvm.ir.transform.module_pass(opt_level=0, name="ParamManager.transform_dequantize")
+        def transform_func(mod: tvm.IRModule, _context) -> tvm.IRModule:
+            # For each Relax function in the input IRModule (e.g., "prefill"),
+            # we create its input relax.Var of all the quantized data, and
+            # store the mapping from function name to the var.
+            func_name_to_quantized_params: Dict[str, List[relax.Var]] = {}
 
-        # Cache mapping to avoid duplicate dequantization.
-        dequantized_cache: Dict[relax.Var, relax.Var] = {}
+            for gv, func in mod.functions.items():
+                if isinstance(func, relax.Function) and func.attrs and "num_input" in func.attrs:
+                    quantized_param_info = self.get_quantized_param_info(gv.name_hint)
+                    param_vars = [
+                        relax.Var(f"param_{i}", info)
+                        for i, info in enumerate(quantized_param_info.fields)
+                    ]
+                    func_name_to_quantized_params[gv.name_hint] = param_vars
 
-        # Define a var replacement function for applying dequantization.
-        def f_replace(var: relax.Var, bb: relax.BlockBuilder, func_name: str) -> relax.Var:
-            if var in dequantized_cache:
-                return dequantized_cache[var]
-            assert var in self.func_raw_param_map
-            func_name, param = self.func_raw_param_map[var]
-            dequantized = self._dequantize(param, func2param_var[func_name], bb, func_name)
-            dequantized_cache[var] = dequantized
-            return dequantized
+            # Cache mapping to avoid duplicate dequantization.
+            dequantized_cache: Dict[relax.Var, relax.Var] = {}
 
-        # Create the function mutator for applying dequantization.
-        replacer = ParamReplacer(mod, func2param_var, f_replace)
-        # Update the input IRModule with dequantization.
-        mod = replacer.transform()
+            # Define a var replacement function for applying dequantization.
+            def f_replace(var: relax.Var, bb: relax.BlockBuilder) -> relax.Var:
+                if var in dequantized_cache:
+                    return dequantized_cache[var]
+                assert var in self.func_raw_param_map
 
-        return mod
+                func_name, param = self.func_raw_param_map[var]
+                quantized_params = func_name_to_quantized_params[func_name]
+                relevant_quantized_params = [quantized_params[i] for i in self.param2qrange[param]]
+
+                dequantized = self._dequantize(param, relevant_quantized_params, bb, func_name)
+
+                dequantized_cache[var] = dequantized
+                return dequantized
+
+            # Create the function mutator for applying dequantization.
+            replacer = ParamReplacer(mod, func_name_to_quantized_params, f_replace)
+            # Update the input IRModule with dequantization.
+            mod = replacer.transform()
+
+            return mod
+
+        return transform_func
 
     def get_quantized_param_info(self, func_name: str) -> List[relax.TensorStructInfo]:
         bb = relax.BlockBuilder()
@@ -697,10 +707,9 @@ def _register_param(
     def _dequantize(
         self,
         param: Parameter,
-        quantized_tuple: relax.Var,
+        qparams: List[relax.Var],
         bb: relax.BlockBuilder,
         func_name: str,
-        qparams: List[relax.Var] = None,
     ) -> relax.Var:
         """Applying dequantization to the input parameter.
         This method is called by `transform_module` below, and is not
@@ -711,30 +720,13 @@ def _dequantize(
         param : Parameter
             The parameter whose quantized tensors are to be dequantized.
 
-        quantized_tuple : relax.Var
-            The relax.Var of the quantized tensors of all parameters in the model.
-
-        bb : relax.BlockBuilder
-            The Relax BlockBuilder used for inserting the dequantization computations.
-
-        func_name : str
-            The name of the function which dequantization is applied to.
-
         qparams : List[relax.Var]
-            The quantized parts of the parameter.
-            By default it is `None`, in which case we will get the quantized parts
-            from `quantized_tuple`.
+            The relax.Var of the quantized tensors of all parameters in the model.
 
         Returns
         -------
         The dequantized parameter, in the form of a relax.Var.
         """
-        if not qparams:
-            # Get the corresponding Relax vars of the quantized tensors of this parameter.
-            qparams: List[relax.Var] = []
-            for qparam_idx in self.param2qrange[param]:
-                qparams.append(bb.emit(relax.TupleGetItem(quantized_tuple, qparam_idx)))
-
         # Get the dequantization function of this parameter.
         f_dequantize = param.quant_spec.get_dequantize_func(
             param_info=param.param_info_dict[func_name],
@@ -789,7 +781,7 @@ class ParamReplacer(PyExprMutator):
     mod : tvm.IRModule
         The IRModule of the model to be updated.
 
-    func2param_var : Dict[str, relax.Var]
+    func_name_to_quantized_params : Dict[str, List[relax.Var]]
         The mapping from each function name to its input var of quantized data tuple.
 
     f_replace : Callable[[relax.Var, relax.BlockBuilder], relax.Var]
@@ -801,7 +793,7 @@
     """
 
     mod: tvm.IRModule
-    func2param_var: Dict[str, relax.Var]
+    func_name_to_quantized_params: Dict[str, List[relax.Var]]
    f_replace: Callable[[relax.Var, relax.BlockBuilder], relax.Var]
    param_set: Set[relax.Var]

@@ -810,12 +802,12 @@
     def __init__(
         self,
         mod: tvm.IRModule,
-        func2param_var: Dict[str, relax.Var],
+        func_name_to_quantized_params: Dict[str, relax.Var],
         f_replace: Callable[[relax.Var, relax.BlockBuilder], relax.Var],
     ):
         super().__init__(mod)
         self.mod = mod
-        self.func2param_var = func2param_var
+        self.func_name_to_quantized_params = func_name_to_quantized_params
         self.f_replace = f_replace
         self.cur_func_name = ""

@@ -827,31 +819,31 @@ def transform(self) -> tvm.IRModule:
                 continue
 
             assert (
-                gv.name_hint in self.func2param_var
-            ), f"{gv.name_hint} not in {self.func2param_var}"
-            self.cur_func_name = gv.name_hint
-            updated_func = self.rewrite_func(func, self.func2param_var[gv.name_hint])
+                gv.name_hint in self.func_name_to_quantized_params
+            ), f"{gv.name_hint} not in {self.func_name_to_quantized_params}"
+            updated_func = self.rewrite_func(func, self.func_name_to_quantized_params[gv.name_hint])
             updated_func = remove_all_unused(updated_func)
             self.builder_.update_func(gv, updated_func)
         return self.builder_.get()
 
-    def rewrite_func(self, func: Function, param_var: relax.Var) -> relax.Function:
+    def rewrite_func(self, func: Function, quantized_params: List[relax.Var]) -> relax.Function:
         num_input = int(func.attrs["num_input"])
         self.param_set = set(func.params[num_input:])
 
         body = self.visit_expr(func.body)
         return relax.Function(
-            params=func.params[:num_input] + [param_var],
+            params=func.params[:num_input] + quantized_params,
             body=body,
             ret_struct_info=func.ret_struct_info,
             is_pure=func.is_pure,
             attrs=func.attrs,
-        ).without_attr("num_input")
+        )
 
     def visit_var_(self, var: Var) -> Expr:
-        if var not in self.param_set:
+        if var in self.param_set:
+            return self.f_replace(var, self.builder_)
+        else:
             return super().visit_var_(var)
-        return self.f_replace(var, self.builder_, self.cur_func_name)
 
 
 ##################################################################
