[Mix Precision] Fix few bugs. (#181)
Co-authored-by: zhangqi3 <zhangqi3@sensetime.com>
Tracin and zhangqi3 committed Aug 29, 2022
1 parent 846f9ce commit 80345e5
Showing 2 changed files with 12 additions and 17 deletions.
15 changes: 5 additions & 10 deletions mqbench/mix_precision/hessian_per_layer.py
@@ -2,7 +2,7 @@
 
 import torch
 import numpy as np
-from pyhessian import hessian, hessian_vector_product, group_product, orthnormal, normalization
+from pyhessian import hessian, hessian_vector_product, group_product, normalization
 
 
 class hessian_per_layer(hessian):
@@ -15,29 +15,24 @@ def __init__(self, *args, **kwargs):
 
     def layer_eigenvalues(self, maxIter=100, tol=1e-3) -> Dict:
         """
-        compute the top_n eigenvalues in one model by layer.
+        compute the max eigenvalues in one model by layer.
         """
         device = self.device
         max_eigenvalues_dict = {}
 
         for name, mod in self.model.named_modules():
             if isinstance(mod, (torch.nn.Conv2d, torch.nn.Linear)):
                 weight = mod.weight
-                eigenvectors = []
                 eigenvalue = None
                 v = [torch.randn(weight.size()).to(device)]
                 v = normalization(v)
                 first_order_grad = self.first_order_grad_dict[name]
 
                 for i in range(maxIter):
-                    v = orthnormal(v, eigenvectors)
                     self.model.zero_grad()
 
-                    if self.full_dataset:
-                        tmp_eigenvalue, Hv = self.dataloader_hv_product(v)
-                    else:
-                        Hv = hessian_vector_product(first_order_grad, weight, v)
-                        tmp_eigenvalue = group_product(Hv, v).cpu().item()
+                    Hv = hessian_vector_product(first_order_grad, weight, v)
+                    tmp_eigenvalue = group_product(Hv, v).cpu().item()
 
                     v = normalization(Hv)
 
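With the deflation step (orthnormal against previously found eigenvectors) and the unused full_dataset branch removed, layer_eigenvalues is now plain power iteration: each Conv2d/Linear layer tracks only the largest eigenvalue of its weight Hessian. A minimal sketch of that loop on an explicit matrix, for intuition only (top_eigenvalue and the dense H are illustrative stand-ins, not the pyhessian API, which operates on lists of parameter tensors):

    import torch

    def top_eigenvalue(H: torch.Tensor, max_iter: int = 100, tol: float = 1e-3) -> float:
        # Random start vector, normalized (as normalization(v) does).
        v = torch.randn(H.shape[0])
        v = v / v.norm()
        eigenvalue = None
        for _ in range(max_iter):
            Hv = H @ v                      # stand-in for hessian_vector_product
            tmp = torch.dot(Hv, v).item()   # Rayleigh quotient, as group_product(Hv, v)
            if eigenvalue is not None and abs(tmp - eigenvalue) / (abs(eigenvalue) + 1e-6) < tol:
                return tmp
            eigenvalue = tmp
            v = Hv / Hv.norm()              # renormalize and iterate
        return eigenvalue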
@@ -73,7 +68,7 @@ def layer_trace(self, maxIter=100, tol=1e-3) -> Dict:
 
                     Hv = hessian_vector_product(first_order_grad, weight, v)
                     trace_vhv.append(group_product(Hv, v).cpu().item())
-                    if abs(np.mean(trace_vhv) - trace) / (trace + 1e-6) < tol:
+                    if abs(np.mean(trace_vhv) - trace) / (abs(trace) + 1e-6) < tol:
                         break
                     else:
                         trace = np.mean(trace_vhv)
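The hunk above fixes the stopping rule of the Hutchinson trace estimator: a layer's Hessian need not be positive semi-definite, so the running trace estimate can be negative, and dividing the relative error by (trace + 1e-6) instead of (abs(trace) + 1e-6) can flip the comparison and stop the loop too early or never. A self-contained sketch with a dense matrix standing in for the Hessian-vector product (estimate_trace and H are illustrative names, not the repository's API):

    import numpy as np
    import torch

    def estimate_trace(H: torch.Tensor, max_iter: int = 100, tol: float = 1e-3) -> float:
        trace_vhv = []   # samples of v^T H v
        trace = 0.0      # running estimate
        for _ in range(max_iter):
            # Rademacher probe: entries drawn uniformly from {-1, +1}.
            v = torch.randint(0, 2, (H.shape[0],), dtype=H.dtype) * 2 - 1
            Hv = H @ v
            trace_vhv.append(torch.dot(Hv, v).item())
            # abs() keeps the denominator positive even when the running
            # estimate is negative (indefinite Hessian).
            if abs(np.mean(trace_vhv) - trace) / (abs(trace) + 1e-6) < tol:
                break
            trace = np.mean(trace_vhv)
        return float(np.mean(trace_vhv))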
14 changes: 7 additions & 7 deletions mqbench/mix_precision/mix_precision.py
@@ -144,7 +144,7 @@ def get_new_qrange(bits, qscheme):
             with torch.no_grad():
                 output_data = quantized_model(input_data)
                 loss = creterion(output_data, label_data)
-            sensetive_dict[name].append(loss)
+            sensetive_dict[name].append(loss - fp_loss)
             logger.info("Layer {} under bit {} with sensetive {}".format(name, bits, loss - fp_loss))
             mod.weight_fake_quant.disable_observer()
             mod.weight_fake_quant.disable_fake_quant()
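The one-line change above makes the stored value match the logged one: sensetive_dict now records the loss increase over the floating-point baseline (loss - fp_loss) rather than the raw quantized loss, so layers are ranked by the damage a given bitwidth does, not by the absolute loss level. A rough sketch of the per-layer bookkeeping under that reading (hypothetical names; the step that reconfigures the layer's fake-quantizer is elided):

    import torch

    def layer_sensitivities(quantized_model, input_data, label_data,
                            fp_loss, bitwidth_list, criterion):
        """Loss increase over the FP32 baseline for each candidate bitwidth."""
        deltas = []
        for bits in bitwidth_list:
            # ... set this layer's weight fake-quantizer to `bits` here ...
            with torch.no_grad():
                loss = criterion(quantized_model(input_data), label_data)
            deltas.append(loss.item() - fp_loss)  # delta, matching the log line
        return deltas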
@@ -246,16 +246,16 @@ def ILP_bit_selection(bitwidth_list, sensetive_dict, layer_parameters_dict, mode
     layer_parameters_dict = model_size_analysis(model)
     model_size = sum(list(layer_parameters_dict.values())) * 32 / 8 / 1024 / 1024
     logger.info("FP model size: {:.2f} MB".format(model_size))
-    naive_sensetive_dict = mixprecision_profiling(model, quantized_model, test_bitwidth_list,
-                                                  data=(inputs, targets), criterion=torch.nn.CrossEntropyLoss(), algo='naive')
+    # naive_sensetive_dict = mixprecision_profiling(model, quantized_model, test_bitwidth_list,
+    #                                               data=(inputs, targets), criterion=torch.nn.CrossEntropyLoss(), algo='naive')
     # maxeigen_sensetive_dict = mixprecision_profiling(model, quantized_model, test_bitwidth_list,
     #                                                  data=(inputs, targets), criterion=torch.nn.CrossEntropyLoss(), algo='hawq_eigen')
-    # trace_sensetive_dict = mixprecision_profiling(model, quantized_model, test_bitwidth_list,
-    #                                               data=(inputs, targets), criterion=torch.nn.CrossEntropyLoss(), algo='hawq_trace')
+    trace_sensetive_dict = mixprecision_profiling(model, quantized_model, test_bitwidth_list,
+                                                  data=(inputs, targets), criterion=torch.nn.CrossEntropyLoss(), algo='hawq_trace')
 
     mixprecision_bit_selection(test_bitwidth_list,
-                               naive_sensetive_dict,
+                               # naive_sensetive_dict,
                                # maxeigen_sensetive_dict,
-                               # trace_sensetive_dict,
+                               trace_sensetive_dict,
                                layer_parameters_dict,
                                model_size_constraints=3, latency_constraints=None)
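The driver now feeds HAWQ trace sensitivities ('hawq_trace') instead of the naive loss-delta metric into bit selection. Judging from the call site, ILP_bit_selection assigns one bitwidth per layer so as to minimize total sensitivity under a model-size budget (model_size_constraints, in MB). A brute-force illustration of that objective under those assumptions; the repository solves it as an integer linear program, and select_bits below is a hypothetical helper, not its API:

    import itertools

    def select_bits(layer_names, bitwidth_list, sensetive_dict,
                    layer_parameters_dict, size_budget_mb):
        best_assignment, best_cost = None, float("inf")
        for assignment in itertools.product(bitwidth_list, repeat=len(layer_names)):
            size_mb = sum(layer_parameters_dict[n] * b / 8 / 1024 / 1024
                          for n, b in zip(layer_names, assignment))
            if size_mb > size_budget_mb:
                continue  # violates the model-size budget
            cost = sum(sensetive_dict[n][bitwidth_list.index(b)]
                       for n, b in zip(layer_names, assignment))
            if cost < best_cost:
                best_assignment, best_cost = assignment, cost
        return None if best_assignment is None else dict(zip(layer_names, best_assignment))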
