KeyError when executing quantization #273
Hi!
I'm trying to execute the example script quick_start_for_expert.py with MobileNetV3, but I got a KeyError during quantization. Could you give me an idea for fixing it?

Comments
Is it OK to modify line 2879 from `node = graph.nodes_map[name]` to the following?

```python
try:
    node = graph.nodes_map[name]
except KeyError:
    continue
```
@iksooman No, I think the correct thing to do is to remove everything in that specific directory, since it usually means that the model has been changed or updated. I think the following logic would be better:

```python
if name in graph.nodes_map:
    node = graph.nodes_map[name]
else:
    log.error(
        f'Node name {name} not found in the configuration file, which probably means '
        'that your model has been updated. Please remove the old yaml file and try again'
    )
    assert False
```
@peterjc123 Could you please give me a little more detailed explanation? I just replaced the model in `main_worker`:

```python
def main_worker(args):
    print("###### TinyNeuralNetwork quick start for expert ######")

    # If you encounter any problems, please set the global log level to `DEBUG`,
    # which may make it easier to debug.
    # set_global_log_level("DEBUG")

    model = mobilenet.Mobilenet()
    model.load_state_dict(torch.load(mobilenet.DEFAULT_STATE_DICT))

    device = get_device()
    model.to(device=device)

    if args.distillation:
        teacher = copy.deepcopy(model)

    if args.parallel:
        model = nn.DataParallel(model)

    # Provide a viable input for the model
    dummy_input = torch.rand((1, 3, 224, 224))

    context = DLContext()
    context.device = device
    context.train_loader, context.val_loader = get_dataloader(args.data_path, 224, args.batch_size, args.workers)

    print("Validation accuracy of the original model")
    validate(model, context)

    print("Start pruning the model")
    # If you need to set the sparsity of a single operator, you may refer to the examples in `examples/pruner`.
    pruner = OneShotChannelPruner(model, dummy_input, {"sparsity": 0.75, "metrics": "l2_norm"})

    st_flops = pruner.calc_flops()
    pruner.prune()  # Get the pruned model

    print("Validation accuracy of the pruned model")
    validate(model, context)
    ed_flops = pruner.calc_flops()
    print(f"Pruning over, reduced FLOPS {100 * (st_flops - ed_flops) / st_flops:.2f}% ({st_flops} -> {ed_flops})")

    print("Start finetuning the pruned model")
    # In our experiments, using the same learning rate configuration as the one used
    # during training from scratch leads to a higher final model accuracy.
    context.max_epoch = 220
    context.criterion = nn.BCEWithLogitsLoss()
    context.optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    context.scheduler = CosineAnnealingLR(context.optimizer, T_max=context.max_epoch + 1, eta_min=0)

    if args.warmup:
        print("Use warmup")
        context.warmup_iteration = len(context.train_loader) * 10  # warm up for 10 epochs
        context.warmup_scheduler = CyclicLR(
            context.optimizer, base_lr=0, max_lr=0.1, step_size_up=context.warmup_iteration
        )

    if args.distillation:
        # Using distillation may lead to better accuracy at the price of longer training time.
        print("Use distillation")
        context.custom_args = {'distill_A': 0.3, 'distill_T': 6, 'distill_teacher': teacher}
        train(model, context, train_one_epoch_distill, validate)
    else:
        train(model, context, train_one_epoch, validate)

    print("Start preparing the model for quantization")
    # We provide a `QATQuantizer` class that may rewrite the graph and perform model fusion for quantization.
    # The model returned by the `quantize` function is ready for QAT training.
    quantizer = QATQuantizer(model, dummy_input, work_dir='out')
    qat_model = quantizer.quantize()
```
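As the final comment in the snippet notes, the model returned by `quantize` is ready for QAT training. A rough sketch of the usual follow-up, assuming the standard PyTorch quantization API (`torch.quantization.convert`) rather than any repository-specific helper, and not the verbatim continuation of quick_start_for_expert.py:

```python
# Rough sketch under the assumptions above: fine-tune the QAT-prepared model,
# then convert it into an actual quantized model.
train(qat_model, context, train_one_epoch, validate)

with torch.no_grad():
    qat_model.eval()
    qat_model.cpu()
    # Standard PyTorch conversion from a QAT-prepared model to a quantized one
    quantized_model = torch.quantization.convert(qat_model)
```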
@iksooman you just have to do

@peterjc123 problem solved. Thank you!
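For anyone who runs into the same KeyError later: as discussed above, the quantizer keeps a graph description (a yaml configuration file) in its work directory, and the cached node names stop matching `graph.nodes_map` once the model architecture changes. A minimal sketch of the cleanup suggested in this thread, assuming the `work_dir='out'` from the snippet above:

```python
import shutil
from pathlib import Path

# Delete the stale cache so the quantizer regenerates the yaml on the next run.
# 'out' is an assumption: use whatever work_dir was passed to QATQuantizer.
work_dir = Path('out')
if work_dir.exists():
    shutil.rmtree(work_dir)
```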