Multi args rebase (#79)
* [Feature] enhance advanced ptq with multi-type input.
The previous advanced PTQ only supported `torch.tensor` inputs, but
sometimes `dict` or `list` inputs are also needed.

* [Fix] getitem should not be quantized twice.

* [Feature] Add multi args cache.
Note that the code now finds the placeholder rather than the input module,
so we cache the output of the placeholder instead of the input of the module.

* [Feature] support multiple inputs to a graph

* [Fix] prune extra nodes in a block

* [Fix] fix `keep_gpu` flag for non-tensor input

* [Feature] assign node prefixes in the config to exclude certain nodes.
Sometimes there is no need to quantize every node in the network;
such nodes can be ignored and kept in float (a config sketch follows below).
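
For illustration, a minimal sketch of what such a config could look like. The `EasyDict` wrapper and the example prefixes are assumptions for readability, not a documented API; the key name mirrors the `exclude_node_prefix` check added to `ptq_reconstruction` in this diff:

from easydict import EasyDict  # hypothetical config container with dict/attribute access

config = EasyDict(dict(
    pattern='block',
    prob=0.5,
    max_count=10000,
    keep_gpu=True,
    exclude_node_prefix=['head.', 'aux_'],  # assumed example prefixes; matching nodes stay float
))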

Co-authored-by: fanyunqian <fanyunqian@sensetime.com>
PannenetsF and fanyunqian committed Apr 18, 2022
1 parent 05c915e commit 72eebeb
Showing 4 changed files with 318 additions and 57 deletions.
mqbench/advanced_ptq.py — 255 changes: 199 additions & 56 deletions
@@ -24,7 +24,8 @@
 
 from mqbench.utils.logger import logger
 from mqbench.utils.hook import DataSaverHook, StopForwardException
-from mqbench.utils import deepcopy_graphmodule
+from mqbench.utils import deepcopy_graphmodule, topology_order, getitem2node
+from mqbench.utils.utils import _fix_succ_recursivly
 from mqbench.utils.state import enable_quantization, disable_all
 
 _ADAROUND_SUPPORT_TYPE = (torch.nn.Conv2d, torch.nn.Linear)
@@ -37,6 +38,34 @@ def lp_loss(pred, tgt, p=2.0):
     return (pred - tgt).abs().pow(p).sum(1).mean()
 
 
+def to_device(data, device='cpu'):
+    if isinstance(data, torch.Tensor):
+        return data.to(device)
+    elif isinstance(data, dict):
+        for key in data:
+            data[key] = to_device(data[key], device)
+        return data
+    elif isinstance(data, list):
+        for idx, _ in enumerate(data):
+            data[idx] = to_device(data[idx], device)
+        return data
+    else:
+        return data
+
+
+def tensor_detach(data):
+    if isinstance(data, torch.Tensor):
+        return data.detach()
+    elif isinstance(data, dict):
+        for key in data:
+            data[key] = tensor_detach(data[key])
+        return data
+    elif isinstance(data, list):
+        return [tensor_detach(dat) for dat in data]
+    else:
+        return data


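As a quick illustration of the two helpers above (an editorial aside, not part of the diff): given the definitions of `to_device` and `tensor_detach`, they walk nested containers recursively so that mixed `dict`/`list` calibration batches can be moved and detached in one call:

import torch

batch = {'img': torch.randn(1, 3, 8, 8), 'meta': [torch.tensor([1.0]), 'label']}
batch = to_device(batch, 'cpu')   # recurses into the dict and the list
batch = tensor_detach(batch)      # detaches every tensor, leaves 'label' untouched
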
 def save_inp_oup_data(model: GraphModule, inp_module: Module, oup_module: Module, cali_data: list, store_inp=True, store_oup=True,
                       keep_gpu: bool = True):
     """
@@ -60,19 +89,19 @@ def save_inp_oup_data(model: GraphModule, inp_module: Module, oup_module: Module
     with torch.no_grad():
         for batch in cali_data:
             try:
-                _ = model(batch.to(device))
+                _ = model(to_device(batch, device))
             except StopForwardException:
                 pass
             if store_inp:
                 if keep_gpu:
-                    cached[0].append([inp.detach() for inp in inp_saver.input_store])
+                    cached[0].append([tensor_detach(inp) for inp in inp_saver.input_store])
                 else:
-                    cached[0].append([inp.detach().cpu() for inp in inp_saver.input_store])  # tuple/list one
+                    cached[0].append([to_device(tensor_detach(inp), 'cpu') for inp in inp_saver.input_store])  # tuple/list one
             if store_oup:
                 if keep_gpu:
-                    cached[1].append(oup_saver.output_store.detach())
+                    cached[1].append(tensor_detach(oup_saver.output_store))
                 else:
-                    cached[1].append(oup_saver.output_store.detach().cpu())
+                    cached[1].append(to_device(tensor_detach(oup_saver.output_store), 'cpu'))
     if store_inp:
         inp_handle.remove()
     if store_oup:
@@ -119,7 +148,6 @@ class LossFunction:
     r'''loss function to calculate mse reconstruction loss and relaxation loss
     use some tempdecay to balance the two losses.
     '''
-
     def __init__(self,
                  subgraph: Module,
                  weight: float = 1.,
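
The `LossFunction` referenced here combines an MSE reconstruction term with a rounding relaxation term under a temperature decay. As a point of reference, a sketch of the relaxation term from the AdaRound paper — the exact MQBench implementation may differ:

import torch

def adaround_relaxation(alpha: torch.Tensor, beta: float) -> torch.Tensor:
    # soft rounding variable h in [0, 1], AdaRound's rectified sigmoid
    h = torch.clamp(torch.sigmoid(alpha) * 1.2 - 0.1, 0, 1)
    # the term is minimized when every h settles at exactly 0 or 1
    return (1 - (2 * h - 1).abs().pow(beta)).sum()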
@@ -180,26 +208,51 @@ def _flatten_args(node):
     return flattned_args
 
 
-def append_extra_inputs(nodes, layer_node_list):
-    # there are some nodes in the block which are used but not in the list.
-    # e.g. a global dict used in UP or EOD.
-    extra_inputs = []
-    for node in layer_node_list:
-        for arg in _flatten_args(node.args):
-            if isinstance(arg, torch.fx.Node):
-                if arg not in layer_node_list:
-                    extra_inputs.append(arg)
-    return extra_inputs
+def find_used_times(nodes, target):
+    used = len([_node for _node in target.users if _node in nodes])
+    return used
 
 
 def find_cur_node(layer_node_list):
-    for node in reversed(layer_node_list):
-        if node.target == 'update':
-            continue
-        if isinstance(node.target, str) and 'const' in node.target:
-            continue
-        return node
-    raise ValueError('Bad layer node list provided.')
+    node_list = []
+    used_later = []
+    for idx, node in enumerate(layer_node_list):
+        for _node in layer_node_list[idx + 1:]:
+            if node in _flatten_args(_node.args):
+                used_later.append(node)
+                break
+    not_used_later = [node for node in layer_node_list if node not in used_later]
+    single_branch = dict()
+    for node in not_used_later:
+        single_branch[node] = set([node])
+        q = [node]
+        while True:
+            now_args = sum([_flatten_args(_node.args) for _node in q], [])
+            p = [_node for _node in now_args if isinstance(_node, torch.fx.Node) and find_used_times(layer_node_list, _node) == 1]
+            single_branch[node] = single_branch[node].union(set(p))
+            if len(p) == 0:
+                break
+            else:
+                q = p
+    for node in layer_node_list:
+        if node.op == 'call_function' or node.op == 'call_method':
+            continue
+        if node not in used_later:
+            break
+    unwanted = set()
+    for key in single_branch:
+        if key is node:
+            continue
+        else:
+            unwanted = unwanted.union(single_branch[key])
+    layer_node_list = [_node for _node in layer_node_list if _node not in unwanted]
+    for _node in layer_node_list:
+        node_list.append(_node)
+        if _node is node:
+            return node_list
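
The rewritten `find_cur_node` roughly keeps the main branch, drops single-use side branches whose heads are never consumed later in the block, and truncates the node list at the surviving output node. It leans on `torch.fx` node bookkeeping (`node.args`, `node.users`); a small self-contained sketch of that bookkeeping on a toy graph:

import torch
import torch.fx as fx

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 1)

    def forward(self, x):
        y = self.conv(x)
        return y + x          # x is used twice: by conv and by add

gm = fx.symbolic_trace(Toy())
for node in gm.graph.nodes:
    # node.users maps each consumer node to None; its length is the use count
    print(node.name, node.op, 'used by:', [u.name for u in node.users])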


 def subgraph_reconstruction(subgraph, cached_inps, cached_oups, config):
     global USE_LINK
@@ -239,23 +292,30 @@ def subgraph_reconstruction(subgraph, cached_inps, cached_oups, config):
     '''start training'''
     logger.info('start tuning by adaround')
     if config.prob < 1.0:
-        sz = len(cached_inps[0])
+        # cache inps: drop x args x batch x data
+        sz = len(cached_inps[0][0])
+        num_args = len(cached_inps[0])
     else:
-        sz = len(cached_inps)
+        # cache inps: args x batch x data
+        sz = len(cached_inps[0])
+        num_args = len(cached_inps)
     for i in range(config.max_count):
         idx = np.random.randint(0, sz)
-        if config.prob < 1.0:
-            cur_inp = [inp.to(device) for inp in cached_inps[0][idx]]
-            cur_sym = [sym.to(device) for sym in cached_inps[1][idx]]
-            cur_inp = [torch.where(torch.rand_like(inp) < config.prob, inp, sym) for inp, sym in zip(cur_inp, cur_sym)]
-        else:
-            cur_inp = cached_inps[idx]
-            cur_inp = [inp.to(device) for inp in cur_inp]
-        cur_out = cached_oups[idx].to(device)
+        cur_args = []
+        for a in range(num_args):
+            if config.prob < 1.0:
+                cur_inp = to_device(cached_inps[0][a][idx], device)
+                cur_sym = to_device(cached_inps[1][a][idx], device)
+                cur_inp = torch.where(torch.rand_like(cur_inp) < config.prob, cur_inp, cur_sym)
+            else:
+                cur_inp = to_device(cached_inps[a][idx], device)
+            cur_args.append(cur_inp)
+        cur_args = tuple(cur_args)
+        cur_out = to_device(cached_oups[idx], device)
         if a_opt:
             a_opt.zero_grad()
         w_opt.zero_grad()
-        out_quant = subgraph(*cur_inp)
+        out_quant = subgraph(*cur_args)
         err = loss_func(out_quant, cur_out)
         err /= world_size
         err.backward()
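
The `config.prob < 1.0` branch randomly swaps quantized activations for their full-precision counterparts during reconstruction (the activation-drop trick used in QDrop-style PTQ). A minimal sketch of that mixing step in isolation:

import torch

def mix_quant_fp(quant_x: torch.Tensor, fp_x: torch.Tensor, prob: float) -> torch.Tensor:
    # each element keeps its quantized value with probability `prob`,
    # otherwise falls back to the full-precision value
    keep = torch.rand_like(quant_x) < prob
    return torch.where(keep, quant_x, fp_x)

out = mix_quant_fp(torch.ones(4), torch.zeros(4), prob=0.5)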
@@ -291,24 +351,37 @@ def subgraph_reconstruction(subgraph, cached_inps, cached_oups, config):
            layer.prob = 1.0  # recover to promise that drop activation quantization only occurs at reconstruction phase
 
 
-def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[fx.Node], extra_inputs: List[fx.Node], output: fx.Node):
+def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], output: fx.Node, g2node: dict):
     """
     Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph.
     """
     new_graph = fx.Graph()
     env = dict()
-    for input in set(inputs + extra_inputs):
-        new_node = new_graph.placeholder(input.name)
-        env[input] = new_node
+    inp_lst = []
+    for node in nodes:
+        for arg in _flatten_args(node.args):
+            if isinstance(arg, torch.fx.Node):
+                if arg not in nodes and arg not in inp_lst:
+                    inp_lst.append(node)
+                    if node in g2node:
+                        arg_name = g2node[node].name
+                    else:
+                        arg_name = node.name
+                    new_node = new_graph.placeholder(arg_name)
+                    env[node] = new_node
+                    break
     for node in nodes:
+        if node in inp_lst:
+            continue
+        if node in g2node:
+            node = g2node[node]
         new_node = new_graph.node_copy(node, lambda x: env[x])
         env[node] = new_node
     # create this or there will not be return value
     new_graph.output(env[output])
     new_graph.lint()
     return fx.GraphModule(orig_module, new_graph)
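
`extract_subgraph` builds a fresh `fx.Graph`, turning every node with out-of-block inputs into a placeholder and copying the rest with `node_copy`. A self-contained sketch of the same `torch.fx` pattern, extracting a single node into its own module (simplified, with no `g2node` handling):

import torch
import torch.fx as fx

def copy_one_node(gm: fx.GraphModule, target: str) -> fx.GraphModule:
    g, env = fx.Graph(), {}
    for node in gm.graph.nodes:
        if node.op == 'call_module' and node.target == target:
            for arg in node.all_input_nodes:
                env[arg] = g.placeholder(arg.name)   # external inputs become placeholders
            env[node] = g.node_copy(node, lambda n: env[n])
            g.output(env[node])
            break
    g.lint()
    return fx.GraphModule(gm, g)

m = fx.symbolic_trace(torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()))
sub = copy_one_node(m, '0')   # submodule name of the conv in the traced Sequential
print(sub(torch.randn(1, 3, 8, 8)).shape)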


 def find_num_nodes(nodes):
     num = 0
     for node in nodes:
@@ -328,11 +401,15 @@ def extract_layer(node, fp32_modules):
         layer_node_list.append(cur_node)  # valid node here
         stop = (len(cur_node.users) == 0)
         for user in cur_node.users:
-            if user.op == 'call_module' and isinstance(fp32_modules[user.target], _ADAROUND_SUPPORT_TYPE):
+            if user.target == 'update':
+                continue
+            if user.op == 'call_module' and isinstance(
+                    fp32_modules[user.target], _ADAROUND_SUPPORT_TYPE):
                 stop = True
             # TODO: only short-cut here, consider more here
             # TODO: can also use un/completed to check here.
-            if ('add' in user.name and user.op in ['call_function', 'call_method']):
+            if ('add' in user.name
+                    and user.op in ['call_function', 'call_method']):
                 stop = True
             if user.op == 'output':
                 is_next_block, stop = True, True
@@ -354,6 +431,7 @@ def extract_block(input_nodes, fp32_modules, depth=0):
     is_block = False
     cnt = dict()
     q, p = [], []  # q records the completed nodes, p records the uncompleted nodes
+    cur_node = None
     for input in input_nodes:
         for user in input.users:
             if user not in cnt:
@@ -368,6 +446,8 @@ def extract_block(input_nodes, fp32_modules, depth=0):
     while len(q) != 0:
         cur_node = q.pop(0)  # valid node here
         logger.debug('cur node is {}'.format(cur_node))
+        if cur_node.target == 'update':
+            continue
         if len(p) == 0 and len(q) == 0:
             break
         layer_node_list.append(cur_node)
@@ -382,11 +462,14 @@ def extract_block(input_nodes, fp32_modules, depth=0):
                 q.append(user)
                 p.remove(user)
         logger.debug('uncompleted nodes are {}'.format(p))
+    if not cur_node:
+        return layer_node_list
     exp_nodes, is_next_block = extract_layer(cur_node, fp32_modules)
     if is_block or is_next_block:
         return layer_node_list + exp_nodes
     else:
-        return layer_node_list + exp_nodes + extract_block([exp_nodes[-1]], fp32_modules, depth + 1)
+        return layer_node_list + exp_nodes + extract_block(
+            [exp_nodes[-1]], fp32_modules, depth + 1)


 def ptq_reconstruction(model: GraphModule, cali_data: list, config: dict):
@@ -422,7 +505,7 @@ def ptq_reconstruction(model: GraphModule, cali_data: list, config: dict):
     """
     # assert model is on cuda
     if not config.keep_gpu:
-        cali_data = [inp.cpu() for inp in cali_data]
+        cali_data = [to_device(inp, 'cpu') for inp in cali_data]
     '''set state first'''
 
     fp32_model = model
@@ -435,8 +518,18 @@ def ptq_reconstruction(model: GraphModule, cali_data: list, config: dict):
     nodes = list(quant_model.graph.nodes)
     fp32_modules = dict(fp32_model.named_modules())
     quant_modules = dict(quant_model.named_modules())
+    g2node = getitem2node(quant_model)
     checked_nodes = dict()
     for node in nodes:
+        if 'exclude_node_prefix' in config:
+            cont = False
+            for prefix in config['exclude_node_prefix']:
+                if node.name.startswith(prefix):
+                    cont = True
+                    break
+            if cont:
+                logger.info(f'Exclude node {node}')
+                continue
         if node in checked_nodes:
             continue
         if node.op == "call_module" and isinstance(fp32_modules[node.target], _ADAROUND_SUPPORT_TYPE):
@@ -447,22 +540,72 @@ def ptq_reconstruction(model: GraphModule, cali_data: list, config: dict):
                 layer_node_list = extract_block(node.all_input_nodes, fp32_modules)
             else:
                 raise NotImplementedError
-            extra_inputs = append_extra_inputs(nodes, layer_node_list)
-            cur_node = find_cur_node(layer_node_list)
-            fp32_module = fp32_modules[cur_node.target]
-            fp32_inp_module = fp32_modules[node.target]
-            quant_module = quant_modules[node.target]
-            fp32_inps, fp32_oups = save_inp_oup_data(fp32_model, fp32_inp_module, fp32_module, cali_data,
-                                                     store_inp=(config.prob < 1.0), store_oup=True, keep_gpu=config.keep_gpu)
-            quant_inps, _ = save_inp_oup_data(quant_model, quant_module, None, cali_data, store_inp=True,
-                                              store_oup=False, keep_gpu=config.keep_gpu)
+            # if the update is not used in the block, remove it
+            if not all([n.target != 'update' for n in layer_node_list]):
+                remove_nodes = []
+                for idx, n in enumerate(layer_node_list):
+                    if n.target == 'update':
+                        src = n.args[0]
+                        remove = True
+                        for _idx in range(idx + 1, len(layer_node_list)):
+                            if src in _flatten_args(layer_node_list[_idx].args):
+                                remove = False
+                                break
+                        if remove:
+                            remove_nodes.append(n)
+                layer_node_list = [n for n in layer_node_list if n not in remove_nodes]
+            missing_inputs = []
+            for _node in layer_node_list:
+                for arg in _flatten_args(_node.args):
+                    if isinstance(arg, torch.fx.Node):
+                        if arg not in layer_node_list and arg not in missing_inputs:
+                            missing_inputs.append(arg)
+            layer_node_list.extend(missing_inputs)
+            # replace getitem nodes with their source nodes
+            layer_node_list = [n if n not in g2node else g2node[n] for n in layer_node_list]
+            for _node in layer_node_list:
+                src = [arg for arg in _flatten_args(_node.args) if arg in g2node]
+                for arg in src:
+                    _node.args = _fix_succ_recursivly(_node.args, arg, g2node[arg])
+            layer_node_list = sorted(layer_node_list, key=lambda x: topology_order(quant_model)[x])
+            layer_node_list = find_cur_node(layer_node_list)
             logger.info('the node list is below!')
             logger.info(layer_node_list)
-            subgraph = extract_subgraph(quant_modules, layer_node_list, node.all_input_nodes, extra_inputs, cur_node)
-            logger.info(subgraph)
-            cached_inps = (quant_inps, fp32_inps) if config.prob < 1.0 else quant_inps
-            cached_oups = fp32_oups
+            fp32_module = fp32_modules[layer_node_list[-1].target]
+            fp32_all_inps = []
+            quant_all_inps = []
+            fp32_final_oups = None
+            out_is_cached = False
+            for _node in layer_node_list:
+                if all([arg in layer_node_list for arg in _flatten_args(_node.args) if isinstance(arg, torch.fx.Node)]):
+                    continue
+                else:
+                    fp32_inp_module = fp32_modules[_node.target]
+                    quant_module = quant_modules[_node.target]
+                    # fp32 inps: [out_b1, out_b2, ...]
+                    _, fp32_inps = save_inp_oup_data(fp32_model, None, fp32_inp_module, cali_data,
+                                                     store_inp=False, store_oup=(config.prob < 1.0), keep_gpu=config.keep_gpu)
+                    _, fp32_oups = save_inp_oup_data(fp32_model, None, fp32_module, cali_data,
+                                                     store_inp=False, store_oup=(not out_is_cached), keep_gpu=config.keep_gpu)
+                    _, quant_inps = save_inp_oup_data(quant_model, None, quant_module, cali_data,
+                                                      store_inp=False, store_oup=True, keep_gpu=config.keep_gpu)
+                    fp32_all_inps.append(fp32_inps)
+                    quant_all_inps.append(quant_inps)
+                    if not out_is_cached:
+                        fp32_final_oups = fp32_oups
+                        out_is_cached = True
+            cached_inps = (quant_all_inps, fp32_all_inps) if config.prob < 1.0 else quant_all_inps
+            cached_oups = fp32_final_oups
+            subgraph = extract_subgraph(quant_modules, layer_node_list,
+                                        layer_node_list[-1], g2node)
+            logger.info(subgraph.code)
             subgraph_reconstruction(subgraph, cached_inps, cached_oups, config)
             for x in layer_node_list:
                 checked_nodes[x] = True
+    disable_all(quant_model)
+    for node in checked_nodes:
+        if node.op == 'call_module':
+            enable_quantization(quant_modules[node.target])
+            logger.info(f'set the node {node.target} in quant')
     return quant_model
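
To close, a hedged end-to-end sketch of calling `ptq_reconstruction` after this change. The config fields mirror those referenced in the code above (`pattern`, `prob`, `max_count`, `keep_gpu`); the `EasyDict` wrapper and the model/data names are placeholders, not part of the commit:

# usage sketch, assuming `quantized_model` is an MQBench-prepared GraphModule
# and `cali_data` is a list of calibration batches (tensors, dicts, or lists)
from easydict import EasyDict

config = EasyDict(dict(pattern='block', prob=0.5, max_count=10000, keep_gpu=True))
reconstructed = ptq_reconstruction(quantized_model, cali_data, config)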
